diff --git a/.github/workflows/cppcheck.yml b/.github/workflows/cppcheck.yml
new file mode 100644
index 00000000..e4a8d9af
--- /dev/null
+++ b/.github/workflows/cppcheck.yml
@@ -0,0 +1,39 @@
+name: Static Analysis
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+
+jobs:
+  cppcheck:
+    name: Run Cppcheck
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          submodules: true
+          lfs: true
+
+      - name: Install cppcheck
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y cppcheck
+
+      - name: Run cppcheck
+        run: |
+          mkdir -p cppcheck-report
+          cppcheck --enable=all --inconclusive --quiet \
+            --output-file=cppcheck-report/cppcheck.txt \
+            $GITHUB_WORKSPACE/framework/src/ \
+            -I $GITHUB_WORKSPACE/include/ \
+            -I $GITHUB_WORKSPACE/framework/include/
+          cat cppcheck-report/cppcheck.txt
+
+      - name: Upload cppcheck report artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: cppcheck-report
+          path: cppcheck-report/cppcheck.txt
diff --git a/MODULE.bazel b/MODULE.bazel
index 63282ba8..0fec2bc3 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -14,6 +14,17 @@ bazel_dep(name = "rules_python", version = "0.37.2")
 bazel_dep(name = "platforms", version = "0.0.10")
 bazel_dep(name = "googletest", version = "1.15.2")
 bazel_dep(name = "apple_support", version = "1.17.1", repo_name = "build_bazel_apple_support")
+bazel_dep(name = "curl", version = "8.8.0")
+bazel_dep(name = "nlohmann_json", version = "3.11.3")
+bazel_dep(name = "hedron_compile_commands", dev_dependency = True)
+bazel_dep(name = "flatbuffers", version = "24.3.25")
+
+# Hedron's Compile Commands Extractor for Bazel
+git_override(
+    module_name = "hedron_compile_commands",
+    remote = "https://github.com/hedronvision/bazel-compile-commands-extractor.git",
+    commit = "4f28899228fb3ad0126897876f147ca15026151e",
+)
 
 # Use archive_override to patch rules_foreign_cc to default to specific cmake version
 archive_override(
diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock
index 52f43e79..f210eeb0 100644
--- a/MODULE.bazel.lock
+++ b/MODULE.bazel.lock
@@ -4,6 +4,7 @@
   "https://bcr.bazel.build/bazel_registry.json": "8a28e4aff06ee60aed2a8c281907fb8bcbf3b753c91fb5a5c57da3215d5b3497",
   "https://bcr.bazel.build/modules/abseil-cpp/20210324.2/MODULE.bazel": "7cd0312e064fde87c8d1cd79ba06c876bd23630c83466e9500321be55c96ace2",
   "https://bcr.bazel.build/modules/abseil-cpp/20211102.0/MODULE.bazel": "70390338f7a5106231d20620712f7cccb659cd0e9d073d1991c038eb9fc57589",
+  "https://bcr.bazel.build/modules/abseil-cpp/20220623.1/MODULE.bazel": "73ae41b6818d423a11fd79d95aedef1258f304448193d4db4ff90e5e7a0f076c",
   "https://bcr.bazel.build/modules/abseil-cpp/20230125.1/MODULE.bazel": "89047429cb0207707b2dface14ba7f8df85273d484c2572755be4bab7ce9c3a0",
   "https://bcr.bazel.build/modules/abseil-cpp/20230802.0.bcr.1/MODULE.bazel": "1c8cec495288dccd14fdae6e3f95f772c1c91857047a098fad772034264cc8cb",
   "https://bcr.bazel.build/modules/abseil-cpp/20230802.0/MODULE.bazel": "d253ae36a8bd9ee3c5955384096ccb6baf16a1b1e93e858370da0a3b94f77c16",
@@ -12,7 +13,21 @@
   "https://bcr.bazel.build/modules/apple_support/1.15.1/MODULE.bazel": "a0556fefca0b1bb2de8567b8827518f94db6a6e7e7d632b4c48dc5f865bc7c85",
   "https://bcr.bazel.build/modules/apple_support/1.17.1/MODULE.bazel": "655c922ab1209978a94ef6ca7d9d43e940cd97d9c172fb55f94d91ac53f8610b",
   "https://bcr.bazel.build/modules/apple_support/1.17.1/source.json": "6b2b8c74d14e8d485528a938e44bdb72a5ba17632b9e14ef6e68a5ee96c8347f",
+  "https://bcr.bazel.build/modules/apple_support/1.3.1/MODULE.bazel":
"6d04819e9f8775a6eabe3c232585454d5393c6c4600029d063566a4f2326a600", "https://bcr.bazel.build/modules/apple_support/1.5.0/MODULE.bazel": "50341a62efbc483e8a2a6aec30994a58749bd7b885e18dd96aa8c33031e558ef", + "https://bcr.bazel.build/modules/aspect_bazel_lib/1.29.2/MODULE.bazel": "3ca4ed580f4d7e7e47c0b4f2e4799e8c895a2e59e7fab922078cdd5fb631095b", + "https://bcr.bazel.build/modules/aspect_bazel_lib/1.31.2/MODULE.bazel": "7bee702b4862612f29333590f4b658a5832d433d6f8e4395f090e8f4e85d442f", + "https://bcr.bazel.build/modules/aspect_bazel_lib/1.39.0/MODULE.bazel": "4b9135560d1b9f9520b85739da72de105fc919346c83c874ebf0789794075340", + "https://bcr.bazel.build/modules/aspect_bazel_lib/1.40.0/MODULE.bazel": "eac4cf71482009e142804f72b2b102fb7e9812c326702d8b7206385b24f7805f", + "https://bcr.bazel.build/modules/aspect_bazel_lib/1.40.0/source.json": "035a1023f17bde54c2a695158e473ab23619c659099e69264dd9d52a9475b7b1", + "https://bcr.bazel.build/modules/aspect_rules_esbuild/0.15.0/MODULE.bazel": "35508f042286d2074b080df9d88b5faa2b97f98558b68101a39be8e5ab2837b1", + "https://bcr.bazel.build/modules/aspect_rules_esbuild/0.15.0/source.json": "be3ba638076fdb3f369b86b399990500f45c1c8251526a72d882d12d13d81ae5", + "https://bcr.bazel.build/modules/aspect_rules_js/1.29.2/MODULE.bazel": "1b8f06192f8372e33139e1f6aa97e1d56295eb5134b288d70a750d4d13a37736", + "https://bcr.bazel.build/modules/aspect_rules_js/1.34.1/MODULE.bazel": "d86fd8dcc3e09d17df383e994b3adf87ed645b065c604eb4560404603c46fa8d", + "https://bcr.bazel.build/modules/aspect_rules_js/1.34.1/source.json": "ecb32e41d9a1e1dc84525141bbfdc7155b8b71048b0a4e3f2740ed31febda874", + "https://bcr.bazel.build/modules/aspect_rules_ts/1.4.5/MODULE.bazel": "9e6520f1aa823e7f707968124e1bbe87598ec5495df3162d0749fa19a29973bb", + "https://bcr.bazel.build/modules/aspect_rules_ts/1.4.5/source.json": "40b03d827dd656b775318fe205a54481219d4729a67bf8293e6706a8a41ab2cd", + "https://bcr.bazel.build/modules/bazel_features/0.1.0/MODULE.bazel": "47011d645b0f949f42ee67f2e8775188a9cf4a0a1528aa2fa4952f2fd00906fd", "https://bcr.bazel.build/modules/bazel_features/1.1.1/MODULE.bazel": "27b8c79ef57efe08efccbd9dd6ef70d61b4798320b8d3c134fd571f78963dbcd", "https://bcr.bazel.build/modules/bazel_features/1.10.0/MODULE.bazel": "f75e8807570484a99be90abcd52b5e1f390362c258bcb73106f4544957a48101", "https://bcr.bazel.build/modules/bazel_features/1.11.0/MODULE.bazel": "f9382337dd5a474c3b7d334c2f83e50b6eaedc284253334cf823044a26de03e8", @@ -20,6 +35,7 @@ "https://bcr.bazel.build/modules/bazel_features/1.19.0/source.json": "d7bf14517c1b25b9d9c580b0f8795fceeae08a7590f507b76aace528e941375d", "https://bcr.bazel.build/modules/bazel_features/1.9.1/MODULE.bazel": "8f679097876a9b609ad1f60249c49d68bfab783dd9be012faf9d82547b14815a", "https://bcr.bazel.build/modules/bazel_skylib/1.0.3/MODULE.bazel": "bcb0fd896384802d1ad283b4e4eb4d718eebd8cb820b0a2c3a347fb971afd9d8", + "https://bcr.bazel.build/modules/bazel_skylib/1.1.1/MODULE.bazel": "1add3e7d93ff2e6998f9e118022c84d163917d912f5afafb3058e3d2f1545b5e", "https://bcr.bazel.build/modules/bazel_skylib/1.2.0/MODULE.bazel": "44fe84260e454ed94ad326352a698422dbe372b21a1ac9f3eab76eb531223686", "https://bcr.bazel.build/modules/bazel_skylib/1.2.1/MODULE.bazel": "f35baf9da0efe45fa3da1696ae906eea3d615ad41e2e3def4aeb4e8bc0ef9a7a", "https://bcr.bazel.build/modules/bazel_skylib/1.3.0/MODULE.bazel": "20228b92868bf5cfc41bda7afc8a8ba2a543201851de39d990ec957b513579c5", @@ -29,15 +45,33 @@ "https://bcr.bazel.build/modules/bazel_skylib/1.6.1/MODULE.bazel": 
"8fdee2dbaace6c252131c00e1de4b165dc65af02ea278476187765e1a617b917", "https://bcr.bazel.build/modules/bazel_skylib/1.7.1/MODULE.bazel": "3120d80c5861aa616222ec015332e5f8d3171e062e3e804a2a0253e1be26e59b", "https://bcr.bazel.build/modules/bazel_skylib/1.7.1/source.json": "f121b43eeefc7c29efbd51b83d08631e2347297c95aac9764a701f2a6a2bb953", + "https://bcr.bazel.build/modules/boringssl/0.0.0-20211025-d4f1ab9/MODULE.bazel": "6ee6353f8b1a701fe2178e1d925034294971350b6d3ac37e67e5a7d463267834", + "https://bcr.bazel.build/modules/boringssl/0.0.0-20211025-d4f1ab9/source.json": "323bafff99739f6aba35b69a84f0bc04ddb4540a46c1694355f60f073dff3001", "https://bcr.bazel.build/modules/buildozer/7.1.2/MODULE.bazel": "2e8dd40ede9c454042645fd8d8d0cd1527966aa5c919de86661e62953cd73d84", "https://bcr.bazel.build/modules/buildozer/7.1.2/source.json": "c9028a501d2db85793a6996205c8de120944f50a0d570438fcae0457a5f9d1f8", + "https://bcr.bazel.build/modules/c-ares/1.15.0/MODULE.bazel": "ba0a78360fdc83f02f437a9e7df0532ad1fbaa59b722f6e715c11effebaa0166", + "https://bcr.bazel.build/modules/c-ares/1.15.0/source.json": "5e3ed991616c5ec4cc09b0893b29a19232de4a1830eb78c567121bfea87453f7", + "https://bcr.bazel.build/modules/curl/8.8.0/MODULE.bazel": "7da3b3e79b0b4ee8f8c95d640bc6ad7b430ce66ef6e9c9d2bc29b3b5ef85f6fe", + "https://bcr.bazel.build/modules/curl/8.8.0/source.json": "d7d138b6878cf38891692fee0649ace35357fd549b425614d571786f054374d4", + "https://bcr.bazel.build/modules/flatbuffers/24.3.25/MODULE.bazel": "2794b084ee385ecd08a22fd90614b93851508ceb7a97e63da399886dedbc696c", + "https://bcr.bazel.build/modules/flatbuffers/24.3.25/source.json": "0cea4d62612a34154ffe0208a85f9f197edbb1f8f37a8855ec4aa722fea69276", + "https://bcr.bazel.build/modules/gazelle/0.26.0/MODULE.bazel": "6bf5f61b15648e7e35db25fb23cef6b4164fc71c3064ac42ecacafcb6d02abe6", + "https://bcr.bazel.build/modules/gazelle/0.32.0/MODULE.bazel": "b499f58a5d0d3537f3cf5b76d8ada18242f64ec474d8391247438bf04f58c7b8", + "https://bcr.bazel.build/modules/gazelle/0.32.0/source.json": "ef7e2d5194a004d902f5a745eb8f466c90b63a539e9d59311197b87e4d1caee7", "https://bcr.bazel.build/modules/google_benchmark/1.8.2/MODULE.bazel": "a70cf1bba851000ba93b58ae2f6d76490a9feb74192e57ab8e8ff13c34ec50cb", "https://bcr.bazel.build/modules/googletest/1.11.0/MODULE.bazel": "3a83f095183f66345ca86aa13c58b59f9f94a2f81999c093d4eeaa2d262d12f4", "https://bcr.bazel.build/modules/googletest/1.14.0.bcr.1/MODULE.bazel": "22c31a561553727960057361aa33bf20fb2e98584bc4fec007906e27053f80c6", "https://bcr.bazel.build/modules/googletest/1.14.0/MODULE.bazel": "cfbcbf3e6eac06ef9d85900f64424708cc08687d1b527f0ef65aa7517af8118f", "https://bcr.bazel.build/modules/googletest/1.15.2/MODULE.bazel": "6de1edc1d26cafb0ea1a6ab3f4d4192d91a312fd2d360b63adaa213cd00b2108", "https://bcr.bazel.build/modules/googletest/1.15.2/source.json": "dbdda654dcb3a0d7a8bc5d0ac5fc7e150b58c2a986025ae5bc634bb2cb61f470", + "https://bcr.bazel.build/modules/grpc/1.41.0/MODULE.bazel": "5bcbfc2b274dabea628f0649dc50c90cf36543b1cfc31624832538644ad1aae8", + "https://bcr.bazel.build/modules/grpc/1.48.1/MODULE.bazel": "3ca31ff176210449f280cb7765b59f3c6497abe10fa6f888de7b7bf00de53176", + "https://bcr.bazel.build/modules/grpc/1.48.1/source.json": "fb95df9c53c0a004f6681fa0e4a87d7b8c85c2182a73ada28c06339dbee78e42", "https://bcr.bazel.build/modules/libpfm/4.11.0/MODULE.bazel": "45061ff025b301940f1e30d2c16bea596c25b176c8b6b3087e92615adbd52902", + "https://bcr.bazel.build/modules/mbedtls/3.6.0/MODULE.bazel": 
"8e380e4698107c5f8766264d4df92e36766248447858db28187151d884995a09", + "https://bcr.bazel.build/modules/mbedtls/3.6.0/source.json": "1dbe7eb5258050afcc3806b9d43050f71c6f539ce0175535c670df606790b30c", + "https://bcr.bazel.build/modules/nlohmann_json/3.11.3/MODULE.bazel": "87023db2f55fc3a9949c7b08dc711fae4d4be339a80a99d04453c4bb3998eefc", + "https://bcr.bazel.build/modules/nlohmann_json/3.11.3/source.json": "296c63a90c6813e53b3812d24245711981fc7e563d98fe15625f55181494488a", "https://bcr.bazel.build/modules/platforms/0.0.10/MODULE.bazel": "8cb8efaf200bdeb2150d93e162c40f388529a25852b332cec879373771e48ed5", "https://bcr.bazel.build/modules/platforms/0.0.10/source.json": "f22828ff4cf021a6b577f1bf6341cb9dcd7965092a439f64fc1bb3b7a5ae4bd5", "https://bcr.bazel.build/modules/platforms/0.0.4/MODULE.bazel": "9b328e31ee156f53f3c416a64f8491f7eb731742655a47c9eec4703a71644aee", @@ -51,10 +85,12 @@ "https://bcr.bazel.build/modules/protobuf/24.4/MODULE.bazel": "7bc7ce5f2abf36b3b7b7c8218d3acdebb9426aeb35c2257c96445756f970eb12", "https://bcr.bazel.build/modules/protobuf/24.4/source.json": "ace4b8c65d4cfe64efe544f09fc5e5df77faf3a67fbb29c5341e0d755d9b15d6", "https://bcr.bazel.build/modules/protobuf/3.19.0/MODULE.bazel": "6b5fbb433f760a99a22b18b6850ed5784ef0e9928a72668b66e4d7ccd47db9b0", + "https://bcr.bazel.build/modules/protobuf/3.19.2/MODULE.bazel": "532ffe5f2186b69fdde039efe6df13ba726ff338c6bc82275ad433013fa10573", "https://bcr.bazel.build/modules/protobuf/3.19.6/MODULE.bazel": "9233edc5e1f2ee276a60de3eaa47ac4132302ef9643238f23128fea53ea12858", "https://bcr.bazel.build/modules/pybind11_bazel/2.11.1/MODULE.bazel": "88af1c246226d87e65be78ed49ecd1e6f5e98648558c14ce99176da041dc378e", "https://bcr.bazel.build/modules/pybind11_bazel/2.12.0/MODULE.bazel": "e6f4c20442eaa7c90d7190d8dc539d0ab422f95c65a57cc59562170c58ae3d34", "https://bcr.bazel.build/modules/pybind11_bazel/2.12.0/source.json": "6900fdc8a9e95866b8c0d4ad4aba4d4236317b5c1cd04c502df3f0d33afed680", + "https://bcr.bazel.build/modules/re2/2021-09-01/MODULE.bazel": "bcb6b96f3b071e6fe2d8bed9cc8ada137a105f9d2c5912e91d27528b3d123833", "https://bcr.bazel.build/modules/re2/2023-09-01/MODULE.bazel": "cb3d511531b16cfc78a225a9e2136007a48cf8a677e4264baeab57fe78a80206", "https://bcr.bazel.build/modules/re2/2024-07-02/MODULE.bazel": "0eadc4395959969297cbcf31a249ff457f2f1d456228c67719480205aa306daa", "https://bcr.bazel.build/modules/re2/2024-07-02/source.json": "547d0111a9d4f362db32196fef805abbf3676e8d6afbe44d395d87816c1130ca", @@ -65,7 +101,13 @@ "https://bcr.bazel.build/modules/rules_cc/0.0.9/MODULE.bazel": "836e76439f354b89afe6a911a7adf59a6b2518fafb174483ad78a2a2fde7b1c5", "https://bcr.bazel.build/modules/rules_cc/0.1.0/MODULE.bazel": "2fef03775b9ba995ec543868840041cc69e8bc705eb0cb6604a36eee18c87d8b", "https://bcr.bazel.build/modules/rules_cc/0.1.0/source.json": "8a4e832d75e073ab56c74dd77008cf7a81e107dec4544019eb1eefc1320d55be", + "https://bcr.bazel.build/modules/rules_go/0.33.0/MODULE.bazel": "a2b11b64cd24bf94f57454f53288a5dacfe6cb86453eee7761b7637728c1910c", + "https://bcr.bazel.build/modules/rules_go/0.34.0/MODULE.bazel": "20240361d6ff5cb752121af8c64aa41adc5a72ade59c90040606070e1690be09", + "https://bcr.bazel.build/modules/rules_go/0.41.0/MODULE.bazel": "55861d8e8bb0e62cbd2896f60ff303f62ffcb0eddb74ecb0e5c0cbe36fc292c8", + "https://bcr.bazel.build/modules/rules_go/0.41.0/source.json": "a46e5f523176e3bd60b1c9cfdcb6c878b9cd14c21fe1a563c4ba0e6d0e7c4dd8", "https://bcr.bazel.build/modules/rules_java/4.0.0/MODULE.bazel": 
"5a78a7ae82cd1a33cef56dc578c7d2a46ed0dca12643ee45edbb8417899e6f74", + "https://bcr.bazel.build/modules/rules_java/5.1.0/MODULE.bazel": "324b6478b0343a3ce7a9add8586ad75d24076d6d43d2f622990b9c1cfd8a1b15", + "https://bcr.bazel.build/modules/rules_java/5.3.5/MODULE.bazel": "a4ec4f2db570171e3e5eb753276ee4b389bae16b96207e9d3230895c99644b86", "https://bcr.bazel.build/modules/rules_java/6.3.0/MODULE.bazel": "a97c7678c19f236a956ad260d59c86e10a463badb7eb2eda787490f4c969b963", "https://bcr.bazel.build/modules/rules_java/7.1.0/MODULE.bazel": "30d9135a2b6561c761bd67bd4990da591e6bdc128790ce3e7afd6a3558b2fb64", "https://bcr.bazel.build/modules/rules_java/7.6.5/MODULE.bazel": "481164be5e02e4cab6e77a36927683263be56b7e36fef918b458d7a8a1ebadb1", @@ -77,6 +119,9 @@ "https://bcr.bazel.build/modules/rules_license/0.0.3/MODULE.bazel": "627e9ab0247f7d1e05736b59dbb1b6871373de5ad31c3011880b4133cafd4bd0", "https://bcr.bazel.build/modules/rules_license/0.0.7/MODULE.bazel": "088fbeb0b6a419005b89cf93fe62d9517c0a2b8bb56af3244af65ecfe37e7d5d", "https://bcr.bazel.build/modules/rules_license/0.0.7/source.json": "355cc5737a0f294e560d52b1b7a6492d4fff2caf0bef1a315df5a298fca2d34a", + "https://bcr.bazel.build/modules/rules_nodejs/5.8.2/MODULE.bazel": "6bc03c8f37f69401b888023bf511cb6ee4781433b0cb56236b2e55a21e3a026a", + "https://bcr.bazel.build/modules/rules_nodejs/5.8.3/MODULE.bazel": "9fac1897d2067a37693e47f48e11cdb386a455902313c85e9e46fe0aaaa2e4e1", + "https://bcr.bazel.build/modules/rules_nodejs/5.8.3/source.json": "adc580471187345e43dd874d951a84d2256455fbeaedca539174f1e4ab49f9a4", "https://bcr.bazel.build/modules/rules_pkg/0.7.0/MODULE.bazel": "df99f03fc7934a4737122518bb87e667e62d780b610910f0447665a7e2be62dc", "https://bcr.bazel.build/modules/rules_pkg/0.7.0/source.json": "c2557066e0c0342223ba592510ad3d812d4963b9024831f7f66fd0584dd8c66c", "https://bcr.bazel.build/modules/rules_proto/4.0.0/MODULE.bazel": "a7a7b6ce9bee418c1a760b3d84f83a299ad6952f9903c67f19e4edd964894e06", @@ -92,10 +137,16 @@ "https://bcr.bazel.build/modules/rules_python/0.37.2/MODULE.bazel": "b5ffde91410745750b6c13be1c5dc4555ef5bc50562af4a89fd77807fdde626a", "https://bcr.bazel.build/modules/rules_python/0.37.2/source.json": "af5c224d27ec98a612b4dcbdc481e02502cd5a4b49d87f0093200a10a35383e9", "https://bcr.bazel.build/modules/rules_python/0.4.0/MODULE.bazel": "9208ee05fd48bf09ac60ed269791cf17fb343db56c8226a720fbb1cdf467166c", + "https://bcr.bazel.build/modules/rules_swift/1.2.0/MODULE.bazel": "9559e7b880723a274845b92bc760bb2d4c9f9f562388155e357f05932b941789", + "https://bcr.bazel.build/modules/rules_swift/1.2.0/source.json": "c9344551abbd8544e128be8130277da6cd2f54a7a40182700b15c1fb8adb9f81", + "https://bcr.bazel.build/modules/stardoc/0.5.0/MODULE.bazel": "f9f1f46ba8d9c3362648eea571c6f9100680efc44913618811b58cc9c02cd678", "https://bcr.bazel.build/modules/stardoc/0.5.1/MODULE.bazel": "1a05d92974d0c122f5ccf09291442580317cdd859f07a8655f1db9a60374f9f8", "https://bcr.bazel.build/modules/stardoc/0.5.3/MODULE.bazel": "c7f6948dae6999bf0db32c1858ae345f112cacf98f174c7a8bb707e41b974f1c", + "https://bcr.bazel.build/modules/stardoc/0.5.4/MODULE.bazel": "6569966df04610b8520957cb8e97cf2e9faac2c0309657c537ab51c16c18a2a4", "https://bcr.bazel.build/modules/stardoc/0.6.2/MODULE.bazel": "7060193196395f5dd668eda046ccbeacebfd98efc77fed418dbe2b82ffaa39fd", "https://bcr.bazel.build/modules/stardoc/0.6.2/source.json": "d2ff8063b63b4a85e65fe595c4290f99717434fa9f95b4748a79a7d04dfed349", + "https://bcr.bazel.build/modules/upb/0.0.0-20211020-160625a/MODULE.bazel": 
"6cced416be2dc5b9c05efd5b997049ba795e5e4e6fafbe1624f4587767638928", + "https://bcr.bazel.build/modules/upb/0.0.0-20220602-e5f2601/MODULE.bazel": "84a1b5fc76719c2841759d150637cca2fdc19abccc680d6d02614def044379de", "https://bcr.bazel.build/modules/upb/0.0.0-20220923-a547704/MODULE.bazel": "7298990c00040a0e2f121f6c32544bab27d4452f80d9ce51349b1a28f3005c43", "https://bcr.bazel.build/modules/upb/0.0.0-20230516-61a97ef/MODULE.bazel": "c0df5e35ad55e264160417fd0875932ee3c9dda63d9fccace35ac62f45e1b6f9", "https://bcr.bazel.build/modules/upb/0.0.0-20230516-61a97ef/source.json": "b2150404947339e8b947c6b16baa39fa75657f4ddec5e37272c7b11c7ab533bc", @@ -134,6 +185,11490 @@ ] } }, + "@@aspect_bazel_lib~//lib:extensions.bzl%toolchains": { + "general": { + "bzlTransitiveDigest": "cyiMvevu77OGi9zTS4peB4cqYPpX4JK6nZW3vbnEjMI=", + "usagesDigest": "MqlTLnt+KowkMXXH3DzwrO4g8VlBc9gSrGIl9NEh+S4=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "copy_directory_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "darwin_amd64" + } + }, + "copy_directory_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "darwin_arm64" + } + }, + "copy_directory_freebsd_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "freebsd_amd64" + } + }, + "copy_directory_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "linux_amd64" + } + }, + "copy_directory_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "linux_arm64" + } + }, + "copy_directory_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_platform_repo", + "attributes": { + "platform": "windows_amd64" + } + }, + "copy_directory_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_directory_toolchain.bzl", + "ruleClassName": "copy_directory_toolchains_repo", + "attributes": { + "user_repository_name": "copy_directory" + } + }, + "copy_to_directory_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "darwin_amd64" + } + }, + "copy_to_directory_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "darwin_arm64" + } + }, + "copy_to_directory_freebsd_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "freebsd_amd64" + } + }, + "copy_to_directory_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "linux_amd64" + } + }, + "copy_to_directory_linux_arm64": { + "bzlFile": 
"@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "linux_arm64" + } + }, + "copy_to_directory_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_platform_repo", + "attributes": { + "platform": "windows_amd64" + } + }, + "copy_to_directory_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:copy_to_directory_toolchain.bzl", + "ruleClassName": "copy_to_directory_toolchains_repo", + "attributes": { + "user_repository_name": "copy_to_directory" + } + }, + "jq_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_platform_repo", + "attributes": { + "platform": "darwin_amd64", + "version": "1.6" + } + }, + "jq_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_platform_repo", + "attributes": { + "platform": "darwin_arm64", + "version": "1.6" + } + }, + "jq_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_platform_repo", + "attributes": { + "platform": "linux_amd64", + "version": "1.6" + } + }, + "jq_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_platform_repo", + "attributes": { + "platform": "windows_amd64", + "version": "1.6" + } + }, + "jq": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_host_alias_repo", + "attributes": {} + }, + "jq_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:jq_toolchain.bzl", + "ruleClassName": "jq_toolchains_repo", + "attributes": { + "user_repository_name": "jq" + } + }, + "yq_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "darwin_amd64", + "version": "4.25.2" + } + }, + "yq_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "darwin_arm64", + "version": "4.25.2" + } + }, + "yq_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_amd64", + "version": "4.25.2" + } + }, + "yq_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_arm64", + "version": "4.25.2" + } + }, + "yq_linux_s390x": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_s390x", + "version": "4.25.2" + } + }, + "yq_linux_ppc64le": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "linux_ppc64le", + "version": "4.25.2" + } + }, + "yq_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_platform_repo", + "attributes": { + "platform": "windows_amd64", + "version": "4.25.2" + } + }, + "yq": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_host_alias_repo", + "attributes": {} + }, + "yq_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:yq_toolchain.bzl", + "ruleClassName": "yq_toolchains_repo", + "attributes": { + 
"user_repository_name": "yq" + } + }, + "coreutils_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_platform_repo", + "attributes": { + "platform": "darwin_amd64", + "version": "0.0.16" + } + }, + "coreutils_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_platform_repo", + "attributes": { + "platform": "darwin_arm64", + "version": "0.0.16" + } + }, + "coreutils_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_platform_repo", + "attributes": { + "platform": "linux_amd64", + "version": "0.0.16" + } + }, + "coreutils_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_platform_repo", + "attributes": { + "platform": "linux_arm64", + "version": "0.0.16" + } + }, + "coreutils_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_platform_repo", + "attributes": { + "platform": "windows_amd64", + "version": "0.0.16" + } + }, + "coreutils_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:coreutils_toolchain.bzl", + "ruleClassName": "coreutils_toolchains_repo", + "attributes": { + "user_repository_name": "coreutils" + } + }, + "expand_template_darwin_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "darwin_amd64" + } + }, + "expand_template_darwin_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "darwin_arm64" + } + }, + "expand_template_freebsd_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "freebsd_amd64" + } + }, + "expand_template_linux_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "linux_amd64" + } + }, + "expand_template_linux_arm64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "linux_arm64" + } + }, + "expand_template_windows_amd64": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_platform_repo", + "attributes": { + "platform": "windows_amd64" + } + }, + "expand_template_toolchains": { + "bzlFile": "@@aspect_bazel_lib~//lib/private:expand_template_toolchain.bzl", + "ruleClassName": "expand_template_toolchains_repo", + "attributes": { + "user_repository_name": "expand_template" + } + } + }, + "recordedRepoMappingEntries": [ + [ + "aspect_bazel_lib~", + "aspect_bazel_lib", + "aspect_bazel_lib~" + ], + [ + "aspect_bazel_lib~", + "bazel_skylib", + "bazel_skylib~" + ], + [ + "aspect_bazel_lib~", + "bazel_tools", + "bazel_tools" + ] + ] + } + }, + "@@aspect_rules_esbuild~//esbuild:extensions.bzl%esbuild": { + "general": { + "bzlTransitiveDigest": "V7gqrTgzsNF5oyT9MPb6/wXsP0JTzzlsNO813vDYL3o=", + "usagesDigest": "MXYZ9socGXzRzqMvkGiWBTmlQNiIg/0jV3eix/hm9IM=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + 
"esbuild_darwin-x64": { + "bzlFile": "@@aspect_rules_esbuild~//esbuild:repositories.bzl", + "ruleClassName": "esbuild_repositories", + "attributes": { + "esbuild_version": "0.16.7", + "platform": "darwin-x64" + } + }, + "esbuild_darwin-arm64": { + "bzlFile": "@@aspect_rules_esbuild~//esbuild:repositories.bzl", + "ruleClassName": "esbuild_repositories", + "attributes": { + "esbuild_version": "0.16.7", + "platform": "darwin-arm64" + } + }, + "esbuild_linux-x64": { + "bzlFile": "@@aspect_rules_esbuild~//esbuild:repositories.bzl", + "ruleClassName": "esbuild_repositories", + "attributes": { + "esbuild_version": "0.16.7", + "platform": "linux-x64" + } + }, + "esbuild_linux-arm64": { + "bzlFile": "@@aspect_rules_esbuild~//esbuild:repositories.bzl", + "ruleClassName": "esbuild_repositories", + "attributes": { + "esbuild_version": "0.16.7", + "platform": "linux-arm64" + } + }, + "esbuild_win32-x64": { + "bzlFile": "@@aspect_rules_esbuild~//esbuild:repositories.bzl", + "ruleClassName": "esbuild_repositories", + "attributes": { + "esbuild_version": "0.16.7", + "platform": "win32-x64" + } + }, + "esbuild_toolchains": { + "bzlFile": "@@aspect_rules_esbuild~//esbuild/private:toolchains_repo.bzl", + "ruleClassName": "toolchains_repo", + "attributes": { + "esbuild_version": "0.16.7", + "user_repository_name": "esbuild" + } + }, + "npm__esbuild_0.16.7": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "esbuild", + "version": "0.16.7", + "root_package": "", + "link_workspace": "", + "link_packages": {}, + "integrity": "sha512-P6OBFYFSQOGzfApqCeYKqfKRRbCIRsdppTXFo4aAvtiW3o8TTyiIplBvHJI171saPAiy3WlawJHCveJVIOIx1A==", + "url": "", + "commit": "", + "patch_args": [ + "-p0" + ], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__esbuild_0.16.7__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "esbuild", + "version": "0.16.7", + "dev": false, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": {}, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "", + "package_visibility": [ + "//visibility:public" + ] + } + } + }, + "recordedRepoMappingEntries": [ + [ + "aspect_bazel_lib~", + "bazel_skylib", + "bazel_skylib~" + ], + [ + "aspect_bazel_lib~", + "bazel_tools", + "bazel_tools" + ], + [ + "aspect_rules_esbuild~", + "aspect_rules_js", + "aspect_rules_js~" + ], + [ + "aspect_rules_esbuild~", + "bazel_skylib", + "bazel_skylib~" + ], + [ + "aspect_rules_js~", + "aspect_bazel_lib", + "aspect_bazel_lib~" + ], + [ + "aspect_rules_js~", + "bazel_features", + "bazel_features~" + ], + [ + "aspect_rules_js~", + "bazel_skylib", + "bazel_skylib~" + ], + [ + "aspect_rules_js~", + "bazel_tools", + "bazel_tools" + ], + [ + "bazel_features~", + "bazel_tools", + "bazel_tools" + ] + ] + } + }, + "@@aspect_rules_js~//npm:extensions.bzl%npm": { + "general": { + "bzlTransitiveDigest": "w6lgYO0RBRKMcZfpeusmQItcjQEtzi3yrl4pr7bCcxY=", + "usagesDigest": "X+Ly8Gee3CATehlJzM3ECpcqMkB382icTycKvOy8CIU=", + "recordedFileInputs": { + "@@flatbuffers~//pnpm-lock.yaml": 
"130fab1c4307b9bdac43bf88332f2311209a33aead012922d7a97e58a50b7de4", + "@@flatbuffers~//.npmrc": "d94d573d5aa644cdd09ff46d9b9c5e9b59185533420308c9a55ad5dc3176f22b" + }, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "npm": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_translate_lock.bzl", + "ruleClassName": "npm_translate_lock_rule", + "attributes": { + "pnpm_lock": "@@flatbuffers~//:pnpm-lock.yaml", + "update_pnpm_lock": false, + "npmrc": "@@flatbuffers~//:.npmrc", + "use_home_npmrc": false, + "patches": {}, + "patch_args": {}, + "custom_postinstalls": {}, + "package_visibility": {}, + "prod": false, + "public_hoist_packages": {}, + "dev": false, + "no_optional": false, + "lifecycle_hooks": { + "*": [ + "preinstall", + "install", + "postinstall" + ] + }, + "lifecycle_hooks_envs": {}, + "lifecycle_hooks_execution_requirements": { + "*": [ + "no-sandbox" + ] + }, + "bins": {}, + "verify_node_modules_ignored": "@@flatbuffers~//:.bazelignore", + "external_repository_action_cache": ".aspect/rules/external_repository_action_cache", + "link_workspace": "", + "root_package": ".", + "additional_file_contents": {}, + "repositories_bzl_filename": "repositories.bzl", + "defs_bzl_filename": "defs.bzl", + "generate_bzl_library_targets": false, + "data": [], + "preupdate": [], + "quiet": true, + "update_pnpm_lock_node_toolchain_prefix": "nodejs", + "npm_package_target_name": "{dirname}" + } + }, + "npm__at_aashutoshrathi_word-wrap__1.2.6": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@aashutoshrathi/word-wrap", + "version": "1.2.6", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-1Yjs2SvM8TflER/OD3cOjhWWOZb58A2t7wpE2S9XfBYTiIl+XFhQG2bjy4Pu1I+EAlCNUzRDYDdFwFYUKvXcIA==", + "url": "https://registry.npmjs.org/@aashutoshrathi/word-wrap/-/word-wrap-1.2.6.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_aashutoshrathi_word-wrap__1.2.6__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@aashutoshrathi/word-wrap", + "version": "1.2.6", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_android-arm64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/android-arm64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-B8JbS61bEunhfx8kasogFENgQfr/dIp+ggYXwTqdbMAgGDhRa3AaPpQMuQU0rNxDLECj6FhDzk1cF9WHMVwrtA==", + "url": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + 
"lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_android-arm64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/android-arm64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/android-arm64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_android-arm__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/android-arm", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-31E2lxlGM1KEfivQl8Yf5aYU/mflz9g06H6S15ITUFQueMFtFjESRMoDSkvMo8thYvLBax+VKTPlpnx+sPicOA==", + "url": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_android-arm__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/android-arm", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/android-arm": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_android-x64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/android-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-rdqqYfRIn4jWOp+lzQttYMa2Xar3OK9Yt2fhOhzFXqg0rVWEfSclJvZq5fZslnz6ypHvVf3CT7qyf0A5pM682A==", + "url": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_android-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/android-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/android-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + 
"no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_darwin-arm64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/darwin-arm64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-RQw9DemMbIq35Bprbboyf8SmOr4UXsRVxJ97LgB55VKKeJOOdvsIPy0nFyF2l8U+h4PtBx/1kRf0BelOYCiQcw==", + "url": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_darwin-arm64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/darwin-arm64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/darwin-arm64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_darwin-x64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/darwin-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-3sur80OT9YdeZwIVgERAysAbwncom7b4bCI2XKLjMfPymTud7e/oY4y+ci1XVp5TfQp/bppn7xLw1n/oSQY3/Q==", + "url": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_darwin-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/darwin-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/darwin-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_freebsd-arm64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/freebsd-arm64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-WAnPJSDattvS/XtPCTj1tPoTxERjcTpH6HsMr6ujTT+X6rylVe8ggxk8pVxzf5U1wh5sPODpawNicF5ta/9Tmw==", + "url": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.19.8.tgz", + "commit": 
"", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_freebsd-arm64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/freebsd-arm64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/freebsd-arm64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_freebsd-x64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/freebsd-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-ICvZyOplIjmmhjd6mxi+zxSdpPTKFfyPPQMQTK/w+8eNK6WV01AjIztJALDtwNNfFhfZLux0tZLC+U9nSyA5Zg==", + "url": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_freebsd-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/freebsd-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/freebsd-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-arm64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-arm64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-z1zMZivxDLHWnyGOctT9JP70h0beY54xDDDJt4VpTX+iwA77IFsE1vCXWmprajJGa+ZYSqkSbRQ4eyLCpCmiCQ==", + "url": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-arm64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-arm64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + 
"@esbuild/linux-arm64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-arm__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-arm", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-H4vmI5PYqSvosPaTJuEppU9oz1dq2A7Mr2vyg5TF9Ga+3+MGgBdGzcyBP7qK9MrwFQZlvNyJrvz6GuCaj3OukQ==", + "url": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-arm__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-arm", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-arm": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-ia32__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-ia32", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-1a8suQiFJmZz1khm/rDglOc8lavtzEMRo0v6WhPgxkrjcU0LkHj+TwBrALwoz/OtMExvsqbbMI0ChyelKabSvQ==", + "url": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-ia32__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-ia32", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-ia32": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-loong64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-loong64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": 
"sha512-fHZWS2JJxnXt1uYJsDv9+b60WCc2RlvVAy1F76qOLtXRO+H4mjt3Tr6MJ5l7Q78X8KgCFudnTuiQRBhULUyBKQ==", + "url": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-loong64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-loong64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-loong64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-mips64el__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-mips64el", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Wy/z0EL5qZYLX66dVnEg9riiwls5IYnziwuju2oUiuxVc+/edvqXa04qNtbrs0Ukatg5HEzqT94Zs7J207dN5Q==", + "url": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-mips64el__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-mips64el", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-mips64el": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-ppc64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-ppc64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-ETaW6245wK23YIEufhMQ3HSeHO7NgsLx8gygBVldRHKhOlD1oNeNy/P67mIh1zPn2Hr2HLieQrt6tWrVwuqrxg==", + "url": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-ppc64__0.19.8__links": { + "bzlFile": 
"@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-ppc64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-ppc64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-riscv64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-riscv64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-T2DRQk55SgoleTP+DtPlMrxi/5r9AeFgkhkZ/B0ap99zmxtxdOixOMI570VjdRCs9pE4Wdkz7JYrsPvsl7eESg==", + "url": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-riscv64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-riscv64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-riscv64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-s390x__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-s390x", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-NPxbdmmo3Bk7mbNeHmcCd7R7fptJaczPYBaELk6NcXxy7HLNyWwCyDJ/Xx+/YcNH7Im5dHdx9gZ5xIwyliQCbg==", + "url": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-s390x__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-s390x", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-s390x": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_linux-x64__0.19.8": { + "bzlFile": 
"@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/linux-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-lytMAVOM3b1gPypL2TRmZ5rnXl7+6IIk8uB3eLsV1JwcizuolblXRrc5ShPrO9ls/b+RTp+E6gbsuLWHWi2zGg==", + "url": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_linux-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/linux-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/linux-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_netbsd-x64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/netbsd-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-hvWVo2VsXz/8NVt1UhLzxwAfo5sioj92uo0bCfLibB0xlOmimU/DeAEsQILlBQvkhrGjamP0/el5HU76HAitGw==", + "url": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_netbsd-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/netbsd-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/netbsd-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_openbsd-x64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/openbsd-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-/7Y7u77rdvmGTxR83PgaSvSBJCC2L3Kb1M/+dmSIvRvQPXXCuC97QAwMugBNG0yGcbEGfFBH7ojPzAOxfGNkwQ==", + "url": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + 
"install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_openbsd-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/openbsd-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/openbsd-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_sunos-x64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/sunos-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-9Lc4s7Oi98GqFA4HzA/W2JHIYfnXbUYgekUP/Sm4BG9sfLjyv6GKKHKKVs83SMicBF2JwAX6A1PuOLMqpD001w==", + "url": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_sunos-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/sunos-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/sunos-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_win32-arm64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/win32-arm64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-rq6WzBGjSzihI9deW3fC2Gqiak68+b7qo5/3kmB6Gvbh/NYPA0sJhrnp7wgV4bNwjqM+R2AApXGxMO7ZoGhIJg==", + "url": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_win32-arm64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/win32-arm64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/win32-arm64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + 
"package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_win32-ia32__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/win32-ia32", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-AIAbverbg5jMvJznYiGhrd3sumfwWs8572mIJL5NQjJa06P8KfCPWZQ0NwZbPQnbQi9OWSZhFVSUWjjIrn4hSw==", + "url": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_win32-ia32__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/win32-ia32", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/win32-ia32": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_esbuild_win32-x64__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@esbuild/win32-x64", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-bfZ0cQ1uZs2PqpulNL5j/3w+GDhP36k1K5c38QdQg+Swy51jFZWWeIkteNsufkQxp986wnqRRsb/bHbY1WQ7TA==", + "url": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_esbuild_win32-x64__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@esbuild/win32-x64", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@esbuild/win32-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_eslint-community_eslint-utils__4.4.0__eslint_8.55.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@eslint-community/eslint-utils", + "version": "4.4.0_eslint_8.55.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-1/sA4dwrzBAyeUoQ6oxahHKmrZvsnLCg4RfxW3ZFGGmQkSNQPFNLV9CUEFQP1x9EYXHTo5p6xdhZM1Ne9p/AfA==", + "url": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + 
"custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_eslint-community_eslint-utils__4.4.0__eslint_8.55.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@eslint-community/eslint-utils", + "version": "4.4.0_eslint_8.55.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "eslint": "8.55.0", + "eslint-visitor-keys": "3.4.3" + }, + "transitive_closure": { + "@eslint-community/eslint-utils": [ + "4.4.0_eslint_8.55.0" + ], + "eslint": [ + "8.55.0" + ], + "eslint-visitor-keys": [ + "3.4.3" + ], + "@eslint-community/regexpp": [ + "4.10.0" + ], + "@eslint/eslintrc": [ + "2.1.4" + ], + "@eslint/js": [ + "8.55.0" + ], + "@humanwhocodes/config-array": [ + "0.11.13" + ], + "@humanwhocodes/module-importer": [ + "1.0.1" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "@ungap/structured-clone": [ + "1.2.0" + ], + "ajv": [ + "6.12.6" + ], + "chalk": [ + "4.1.2" + ], + "cross-spawn": [ + "7.0.3" + ], + "debug": [ + "4.3.4" + ], + "doctrine": [ + "3.0.0" + ], + "escape-string-regexp": [ + "4.0.0" + ], + "eslint-scope": [ + "7.2.2" + ], + "espree": [ + "9.6.1" + ], + "esquery": [ + "1.5.0" + ], + "esutils": [ + "2.0.3" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "file-entry-cache": [ + "6.0.1" + ], + "find-up": [ + "5.0.0" + ], + "glob-parent": [ + "6.0.2" + ], + "globals": [ + "13.23.0" + ], + "graphemer": [ + "1.4.0" + ], + "ignore": [ + "5.3.0" + ], + "imurmurhash": [ + "0.1.4" + ], + "is-glob": [ + "4.0.3" + ], + "is-path-inside": [ + "3.0.3" + ], + "js-yaml": [ + "4.1.0" + ], + "json-stable-stringify-without-jsonify": [ + "1.0.1" + ], + "levn": [ + "0.4.1" + ], + "lodash.merge": [ + "4.6.2" + ], + "minimatch": [ + "3.1.2" + ], + "natural-compare": [ + "1.4.0" + ], + "optionator": [ + "0.9.3" + ], + "strip-ansi": [ + "6.0.1" + ], + "text-table": [ + "0.2.0" + ], + "ansi-regex": [ + "5.0.1" + ], + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ], + "deep-is": [ + "0.1.4" + ], + "fast-levenshtein": [ + "2.0.6" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "argparse": [ + "2.0.1" + ], + "is-extglob": [ + "2.1.1" + ], + "type-fest": [ + "0.20.2" + ], + "locate-path": [ + "6.0.0" + ], + "path-exists": [ + "4.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ], + "flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "json-buffer": [ + "3.0.1" + ], + "estraverse": [ + "5.3.0" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "esrecurse": [ + "4.3.0" + ], + "ms": [ + "2.1.2" + ], + "path-key": [ + "3.1.1" + ], + "shebang-command": [ + "2.0.0" + ], + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ], + "ansi-styles": [ + "4.3.0" + ], + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ], + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ], + 
"fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "@humanwhocodes/object-schema": [ + "2.0.1" + ], + "import-fresh": [ + "3.3.0" + ], + "strip-json-comments": [ + "3.1.1" + ], + "parent-module": [ + "1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_eslint-community_regexpp__4.10.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@eslint-community/regexpp", + "version": "4.10.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Cu96Sd2By9mCNTx2iyKOmq10v22jUVQv0lQnlGNy16oE9589yE+QADPbrMGCkA51cKZSg3Pu/aTJVTGfL/qjUA==", + "url": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.10.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_eslint-community_regexpp__4.10.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@eslint-community/regexpp", + "version": "4.10.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@eslint-community/regexpp": [ + "4.10.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_eslint_eslintrc__2.1.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@eslint/eslintrc", + "version": "2.1.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-269Z39MS6wVJtsoUl10L60WdkhJVdPG24Q4eZTH3nnF6lpvSShEK3wQjDX9JRWAUPvPh7COouPpU9IrqaZFvtQ==", + "url": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-2.1.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_eslint_eslintrc__2.1.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@eslint/eslintrc", + "version": "2.1.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "ajv": "6.12.6", + "debug": "4.3.4", + "espree": "9.6.1", + "globals": "13.23.0", + "ignore": "5.3.0", + "import-fresh": "3.3.0", + "js-yaml": "4.1.0", + "minimatch": "3.1.2", + "strip-json-comments": "3.1.1" + }, + "transitive_closure": { + "@eslint/eslintrc": 
[ + "2.1.4" + ], + "ajv": [ + "6.12.6" + ], + "debug": [ + "4.3.4" + ], + "espree": [ + "9.6.1" + ], + "globals": [ + "13.23.0" + ], + "ignore": [ + "5.3.0" + ], + "import-fresh": [ + "3.3.0" + ], + "js-yaml": [ + "4.1.0" + ], + "minimatch": [ + "3.1.2" + ], + "strip-json-comments": [ + "3.1.1" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "argparse": [ + "2.0.1" + ], + "parent-module": [ + "1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ], + "type-fest": [ + "0.20.2" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ], + "ms": [ + "2.1.2" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_eslint_js__8.55.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@eslint/js", + "version": "8.55.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-qQfo2mxH5yVom1kacMtZZJFVdW+E70mqHMJvVg6WTLo+VBuQJ4TojZlfWBjK0ve5BdEeNAVxOsl/nvNMpJOaJA==", + "url": "https://registry.npmjs.org/@eslint/js/-/js-8.55.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_eslint_js__8.55.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@eslint/js", + "version": "8.55.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@eslint/js": [ + "8.55.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_humanwhocodes_config-array__0.11.13": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@humanwhocodes/config-array", + "version": "0.11.13", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-JSBDMiDKSzQVngfRjOdFXgFfklaXI4K9nLF49Auh21lmBWRLIK3+xTErTWD4KU54pb6coM6ESE7Awz/FNU3zgQ==", + "url": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_humanwhocodes_config-array__0.11.13__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@humanwhocodes/config-array", + "version": "0.11.13", + "dev": true, + "root_package": "", + 
"link_packages": {}, + "deps": { + "@humanwhocodes/object-schema": "2.0.1", + "debug": "4.3.4", + "minimatch": "3.1.2" + }, + "transitive_closure": { + "@humanwhocodes/config-array": [ + "0.11.13" + ], + "@humanwhocodes/object-schema": [ + "2.0.1" + ], + "debug": [ + "4.3.4" + ], + "minimatch": [ + "3.1.2" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "ms": [ + "2.1.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_humanwhocodes_module-importer__1.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@humanwhocodes/module-importer", + "version": "1.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==", + "url": "https://registry.npmjs.org/@humanwhocodes/module-importer/-/module-importer-1.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_humanwhocodes_module-importer__1.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@humanwhocodes/module-importer", + "version": "1.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@humanwhocodes/module-importer": [ + "1.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_humanwhocodes_object-schema__2.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@humanwhocodes/object-schema", + "version": "2.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-dvuCeX5fC9dXgJn9t+X5atfmgQAzUOWqS1254Gh0m6i8wKd10ebXkfNKiRK+1GWi/yTvvLDHpoxLr0xxxeslWw==", + "url": "https://registry.npmjs.org/@humanwhocodes/object-schema/-/object-schema-2.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_humanwhocodes_object-schema__2.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@humanwhocodes/object-schema", + "version": "2.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@humanwhocodes/object-schema": [ + "2.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + 
"npm__at_nodelib_fs.scandir__2.1.5": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@nodelib/fs.scandir", + "version": "2.1.5", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "url": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_nodelib_fs.scandir__2.1.5__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@nodelib/fs.scandir", + "version": "2.1.5", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "1.2.0" + }, + "transitive_closure": { + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_nodelib_fs.stat__2.0.5": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@nodelib/fs.stat", + "version": "2.0.5", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "url": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_nodelib_fs.stat__2.0.5__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@nodelib/fs.stat", + "version": "2.0.5", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@nodelib/fs.stat": [ + "2.0.5" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_nodelib_fs.walk__1.2.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@nodelib/fs.walk", + "version": "1.2.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "url": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + 
"lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_nodelib_fs.walk__1.2.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@nodelib/fs.walk", + "version": "1.2.8", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "1.15.0" + }, + "transitive_closure": { + "@nodelib/fs.walk": [ + "1.2.8" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_types_json-schema__7.0.15": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@types/json-schema", + "version": "7.0.15", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==", + "url": "https://registry.npmjs.org/@types/json-schema/-/json-schema-7.0.15.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_types_json-schema__7.0.15__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@types/json-schema", + "version": "7.0.15", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@types/json-schema": [ + "7.0.15" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_types_node__20.10.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@types/node", + "version": "20.10.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": { + "": [ + "@types/node" + ] + }, + "integrity": "sha512-D08YG6rr8X90YB56tSIuBaddy/UXAA9RKJoFvrsnogAum/0pmjkgi4+2nx96A330FmioegBWmEYQ+syqCFaveg==", + "url": "https://registry.npmjs.org/@types/node/-/node-20.10.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_types_node__20.10.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@types/node", + "version": "20.10.4", + "dev": true, + "root_package": "", + "link_packages": { + "": [ + "@types/node" + ] + }, + "deps": { + "undici-types": "5.26.5" + }, + "transitive_closure": { + "@types/node": [ + "20.10.4" + ], + "undici-types": [ + 
"5.26.5" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_types_semver__7.5.6": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@types/semver", + "version": "7.5.6", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-dn1l8LaMea/IjDoHNd9J52uBbInB796CDffS6VdIxvqYCPSG0V0DzHp76GpaWnlhg88uYyPbXCDIowa86ybd5A==", + "url": "https://registry.npmjs.org/@types/semver/-/semver-7.5.6.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_types_semver__7.5.6__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@types/semver", + "version": "7.5.6", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@types/semver": [ + "7.5.6" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_eslint-plugin__6.13.2__-1224903089": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/eslint-plugin", + "version": "6.13.2_-1224903089", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": { + "": [ + "@typescript-eslint/eslint-plugin" + ] + }, + "integrity": "sha512-3+9OGAWHhk4O1LlcwLBONbdXsAhLjyCFogJY/cWy2lxdVJ2JrcTF2pTGMaLl2AE7U1l31n8Py4a8bx5DLf/0dQ==", + "url": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_eslint-plugin__6.13.2__-1224903089__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/eslint-plugin", + "version": "6.13.2_-1224903089", + "dev": true, + "root_package": "", + "link_packages": { + "": [ + "@typescript-eslint/eslint-plugin" + ] + }, + "deps": { + "@eslint-community/regexpp": "4.10.0", + "@typescript-eslint/parser": "6.13.2_1796040679", + "@typescript-eslint/scope-manager": "6.13.2", + "@typescript-eslint/type-utils": "6.13.2_1796040679", + "@typescript-eslint/utils": "6.13.2_1796040679", + "@typescript-eslint/visitor-keys": "6.13.2", + "debug": "4.3.4", + "eslint": "8.55.0", + "graphemer": "1.4.0", + "ignore": "5.3.0", + "natural-compare": "1.4.0", + "semver": "7.5.4", + "ts-api-utils": "1.0.3_typescript_5.3.3", + "typescript": "5.3.3" + }, + "transitive_closure": { + "@typescript-eslint/eslint-plugin": [ + "6.13.2_-1224903089" + ], + "@eslint-community/regexpp": [ + "4.10.0" + ], + "@typescript-eslint/parser": 
[ + "6.13.2_1796040679" + ], + "@typescript-eslint/scope-manager": [ + "6.13.2" + ], + "@typescript-eslint/type-utils": [ + "6.13.2_1796040679" + ], + "@typescript-eslint/utils": [ + "6.13.2_1796040679" + ], + "@typescript-eslint/visitor-keys": [ + "6.13.2" + ], + "debug": [ + "4.3.4" + ], + "eslint": [ + "8.55.0" + ], + "graphemer": [ + "1.4.0" + ], + "ignore": [ + "5.3.0" + ], + "natural-compare": [ + "1.4.0" + ], + "semver": [ + "7.5.4" + ], + "ts-api-utils": [ + "1.0.3_typescript_5.3.3" + ], + "typescript": [ + "5.3.3" + ], + "lru-cache": [ + "6.0.0" + ], + "yallist": [ + "4.0.0" + ], + "@eslint-community/eslint-utils": [ + "4.4.0_eslint_8.55.0" + ], + "@eslint/eslintrc": [ + "2.1.4" + ], + "@eslint/js": [ + "8.55.0" + ], + "@humanwhocodes/config-array": [ + "0.11.13" + ], + "@humanwhocodes/module-importer": [ + "1.0.1" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "@ungap/structured-clone": [ + "1.2.0" + ], + "ajv": [ + "6.12.6" + ], + "chalk": [ + "4.1.2" + ], + "cross-spawn": [ + "7.0.3" + ], + "doctrine": [ + "3.0.0" + ], + "escape-string-regexp": [ + "4.0.0" + ], + "eslint-scope": [ + "7.2.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ], + "espree": [ + "9.6.1" + ], + "esquery": [ + "1.5.0" + ], + "esutils": [ + "2.0.3" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "file-entry-cache": [ + "6.0.1" + ], + "find-up": [ + "5.0.0" + ], + "glob-parent": [ + "5.1.2", + "6.0.2" + ], + "globals": [ + "13.23.0" + ], + "imurmurhash": [ + "0.1.4" + ], + "is-glob": [ + "4.0.3" + ], + "is-path-inside": [ + "3.0.3" + ], + "js-yaml": [ + "4.1.0" + ], + "json-stable-stringify-without-jsonify": [ + "1.0.1" + ], + "levn": [ + "0.4.1" + ], + "lodash.merge": [ + "4.6.2" + ], + "minimatch": [ + "3.1.2" + ], + "optionator": [ + "0.9.3" + ], + "strip-ansi": [ + "6.0.1" + ], + "text-table": [ + "0.2.0" + ], + "ansi-regex": [ + "5.0.1" + ], + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ], + "deep-is": [ + "0.1.4" + ], + "fast-levenshtein": [ + "2.0.6" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "argparse": [ + "2.0.1" + ], + "is-extglob": [ + "2.1.1" + ], + "type-fest": [ + "0.20.2" + ], + "locate-path": [ + "6.0.0" + ], + "path-exists": [ + "4.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ], + "flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "json-buffer": [ + "3.0.1" + ], + "estraverse": [ + "5.3.0" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "esrecurse": [ + "4.3.0" + ], + "path-key": [ + "3.1.1" + ], + "shebang-command": [ + "2.0.0" + ], + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ], + "ansi-styles": [ + "4.3.0" + ], + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ], + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ], + "fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + 
"run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "@humanwhocodes/object-schema": [ + "2.0.1" + ], + "import-fresh": [ + "3.3.0" + ], + "strip-json-comments": [ + "3.1.1" + ], + "parent-module": [ + "1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ], + "ms": [ + "2.1.2" + ], + "@typescript-eslint/types": [ + "6.13.2" + ], + "@types/json-schema": [ + "7.0.15" + ], + "@types/semver": [ + "7.5.6" + ], + "@typescript-eslint/typescript-estree": [ + "6.13.2_typescript_5.3.3" + ], + "globby": [ + "11.1.0" + ], + "array-union": [ + "2.1.0" + ], + "dir-glob": [ + "3.0.1" + ], + "fast-glob": [ + "3.3.2" + ], + "merge2": [ + "1.4.1" + ], + "slash": [ + "3.0.0" + ], + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ], + "path-type": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_parser__6.13.2__1796040679": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/parser", + "version": "6.13.2_1796040679", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": { + "": [ + "@typescript-eslint/parser" + ] + }, + "integrity": "sha512-MUkcC+7Wt/QOGeVlM8aGGJZy1XV5YKjTpq9jK6r6/iLsGXhBVaGP5N0UYvFsu9BFlSpwY9kMretzdBH01rkRXg==", + "url": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_parser__6.13.2__1796040679__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/parser", + "version": "6.13.2_1796040679", + "dev": true, + "root_package": "", + "link_packages": { + "": [ + "@typescript-eslint/parser" + ] + }, + "deps": { + "@typescript-eslint/scope-manager": "6.13.2", + "@typescript-eslint/types": "6.13.2", + "@typescript-eslint/typescript-estree": "6.13.2_typescript_5.3.3", + "@typescript-eslint/visitor-keys": "6.13.2", + "debug": "4.3.4", + "eslint": "8.55.0", + "typescript": "5.3.3" + }, + "transitive_closure": { + "@typescript-eslint/parser": [ + "6.13.2_1796040679" + ], + "@typescript-eslint/scope-manager": [ + "6.13.2" + ], + "@typescript-eslint/types": [ + "6.13.2" + ], + "@typescript-eslint/typescript-estree": [ + "6.13.2_typescript_5.3.3" + ], + "@typescript-eslint/visitor-keys": [ + "6.13.2" + ], + "debug": [ + "4.3.4" + ], + "eslint": [ + "8.55.0" + ], + "typescript": [ + "5.3.3" + ], + "@eslint-community/eslint-utils": [ + "4.4.0_eslint_8.55.0" + ], + "@eslint-community/regexpp": [ + "4.10.0" + ], + "@eslint/eslintrc": [ + "2.1.4" + ], + "@eslint/js": [ + "8.55.0" + ], + "@humanwhocodes/config-array": [ + "0.11.13" + ], + "@humanwhocodes/module-importer": [ + "1.0.1" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "@ungap/structured-clone": [ + "1.2.0" + ], + "ajv": [ + "6.12.6" + ], + "chalk": [ + "4.1.2" + ], + "cross-spawn": [ + 
"7.0.3" + ], + "doctrine": [ + "3.0.0" + ], + "escape-string-regexp": [ + "4.0.0" + ], + "eslint-scope": [ + "7.2.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ], + "espree": [ + "9.6.1" + ], + "esquery": [ + "1.5.0" + ], + "esutils": [ + "2.0.3" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "file-entry-cache": [ + "6.0.1" + ], + "find-up": [ + "5.0.0" + ], + "glob-parent": [ + "5.1.2", + "6.0.2" + ], + "globals": [ + "13.23.0" + ], + "graphemer": [ + "1.4.0" + ], + "ignore": [ + "5.3.0" + ], + "imurmurhash": [ + "0.1.4" + ], + "is-glob": [ + "4.0.3" + ], + "is-path-inside": [ + "3.0.3" + ], + "js-yaml": [ + "4.1.0" + ], + "json-stable-stringify-without-jsonify": [ + "1.0.1" + ], + "levn": [ + "0.4.1" + ], + "lodash.merge": [ + "4.6.2" + ], + "minimatch": [ + "3.1.2" + ], + "natural-compare": [ + "1.4.0" + ], + "optionator": [ + "0.9.3" + ], + "strip-ansi": [ + "6.0.1" + ], + "text-table": [ + "0.2.0" + ], + "ansi-regex": [ + "5.0.1" + ], + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ], + "deep-is": [ + "0.1.4" + ], + "fast-levenshtein": [ + "2.0.6" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "argparse": [ + "2.0.1" + ], + "is-extglob": [ + "2.1.1" + ], + "type-fest": [ + "0.20.2" + ], + "locate-path": [ + "6.0.0" + ], + "path-exists": [ + "4.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ], + "flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "json-buffer": [ + "3.0.1" + ], + "estraverse": [ + "5.3.0" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "esrecurse": [ + "4.3.0" + ], + "path-key": [ + "3.1.1" + ], + "shebang-command": [ + "2.0.0" + ], + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ], + "ansi-styles": [ + "4.3.0" + ], + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ], + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ], + "fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "@humanwhocodes/object-schema": [ + "2.0.1" + ], + "import-fresh": [ + "3.3.0" + ], + "strip-json-comments": [ + "3.1.1" + ], + "parent-module": [ + "1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ], + "ms": [ + "2.1.2" + ], + "globby": [ + "11.1.0" + ], + "semver": [ + "7.5.4" + ], + "ts-api-utils": [ + "1.0.3_typescript_5.3.3" + ], + "lru-cache": [ + "6.0.0" + ], + "yallist": [ + "4.0.0" + ], + "array-union": [ + "2.1.0" + ], + "dir-glob": [ + "3.0.1" + ], + "fast-glob": [ + "3.3.2" + ], + "merge2": [ + "1.4.1" + ], + "slash": [ + "3.0.0" + ], + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ], + "path-type": [ + "4.0.0" + ] + }, + "lifecycle_build_target": 
false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_scope-manager__6.13.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/scope-manager", + "version": "6.13.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-CXQA0xo7z6x13FeDYCgBkjWzNqzBn8RXaE3QVQVIUm74fWJLkJkaHmHdKStrxQllGh6Q4eUGyNpMe0b1hMkXFA==", + "url": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_scope-manager__6.13.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/scope-manager", + "version": "6.13.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@typescript-eslint/types": "6.13.2", + "@typescript-eslint/visitor-keys": "6.13.2" + }, + "transitive_closure": { + "@typescript-eslint/scope-manager": [ + "6.13.2" + ], + "@typescript-eslint/types": [ + "6.13.2" + ], + "@typescript-eslint/visitor-keys": [ + "6.13.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_type-utils__6.13.2__1796040679": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/type-utils", + "version": "6.13.2_1796040679", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Qr6ssS1GFongzH2qfnWKkAQmMUyZSyOr0W54nZNU1MDfo+U4Mv3XveeLZzadc/yq8iYhQZHYT+eoXJqnACM1tw==", + "url": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_type-utils__6.13.2__1796040679__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/type-utils", + "version": "6.13.2_1796040679", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@typescript-eslint/typescript-estree": "6.13.2_typescript_5.3.3", + "@typescript-eslint/utils": "6.13.2_1796040679", + "debug": "4.3.4", + "eslint": "8.55.0", + "ts-api-utils": "1.0.3_typescript_5.3.3", + "typescript": "5.3.3" + }, + "transitive_closure": { + "@typescript-eslint/type-utils": [ + "6.13.2_1796040679" + ], + "@typescript-eslint/typescript-estree": [ + "6.13.2_typescript_5.3.3" + ], + "@typescript-eslint/utils": [ + "6.13.2_1796040679" + ], + "debug": [ + "4.3.4" + ], + 
"eslint": [ + "8.55.0" + ], + "ts-api-utils": [ + "1.0.3_typescript_5.3.3" + ], + "typescript": [ + "5.3.3" + ], + "@eslint-community/eslint-utils": [ + "4.4.0_eslint_8.55.0" + ], + "@eslint-community/regexpp": [ + "4.10.0" + ], + "@eslint/eslintrc": [ + "2.1.4" + ], + "@eslint/js": [ + "8.55.0" + ], + "@humanwhocodes/config-array": [ + "0.11.13" + ], + "@humanwhocodes/module-importer": [ + "1.0.1" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "@ungap/structured-clone": [ + "1.2.0" + ], + "ajv": [ + "6.12.6" + ], + "chalk": [ + "4.1.2" + ], + "cross-spawn": [ + "7.0.3" + ], + "doctrine": [ + "3.0.0" + ], + "escape-string-regexp": [ + "4.0.0" + ], + "eslint-scope": [ + "7.2.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ], + "espree": [ + "9.6.1" + ], + "esquery": [ + "1.5.0" + ], + "esutils": [ + "2.0.3" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "file-entry-cache": [ + "6.0.1" + ], + "find-up": [ + "5.0.0" + ], + "glob-parent": [ + "5.1.2", + "6.0.2" + ], + "globals": [ + "13.23.0" + ], + "graphemer": [ + "1.4.0" + ], + "ignore": [ + "5.3.0" + ], + "imurmurhash": [ + "0.1.4" + ], + "is-glob": [ + "4.0.3" + ], + "is-path-inside": [ + "3.0.3" + ], + "js-yaml": [ + "4.1.0" + ], + "json-stable-stringify-without-jsonify": [ + "1.0.1" + ], + "levn": [ + "0.4.1" + ], + "lodash.merge": [ + "4.6.2" + ], + "minimatch": [ + "3.1.2" + ], + "natural-compare": [ + "1.4.0" + ], + "optionator": [ + "0.9.3" + ], + "strip-ansi": [ + "6.0.1" + ], + "text-table": [ + "0.2.0" + ], + "ansi-regex": [ + "5.0.1" + ], + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ], + "deep-is": [ + "0.1.4" + ], + "fast-levenshtein": [ + "2.0.6" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "argparse": [ + "2.0.1" + ], + "is-extglob": [ + "2.1.1" + ], + "type-fest": [ + "0.20.2" + ], + "locate-path": [ + "6.0.0" + ], + "path-exists": [ + "4.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ], + "flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "json-buffer": [ + "3.0.1" + ], + "estraverse": [ + "5.3.0" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "esrecurse": [ + "4.3.0" + ], + "path-key": [ + "3.1.1" + ], + "shebang-command": [ + "2.0.0" + ], + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ], + "ansi-styles": [ + "4.3.0" + ], + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ], + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ], + "fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "@humanwhocodes/object-schema": [ + "2.0.1" + ], + "import-fresh": [ + "3.3.0" + ], + "strip-json-comments": [ + "3.1.1" + ], + "parent-module": [ + "1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ], + "ms": [ + "2.1.2" + ], + 
"@types/json-schema": [ + "7.0.15" + ], + "@types/semver": [ + "7.5.6" + ], + "@typescript-eslint/scope-manager": [ + "6.13.2" + ], + "@typescript-eslint/types": [ + "6.13.2" + ], + "semver": [ + "7.5.4" + ], + "lru-cache": [ + "6.0.0" + ], + "yallist": [ + "4.0.0" + ], + "@typescript-eslint/visitor-keys": [ + "6.13.2" + ], + "globby": [ + "11.1.0" + ], + "array-union": [ + "2.1.0" + ], + "dir-glob": [ + "3.0.1" + ], + "fast-glob": [ + "3.3.2" + ], + "merge2": [ + "1.4.1" + ], + "slash": [ + "3.0.0" + ], + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ], + "path-type": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_types__6.13.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/types", + "version": "6.13.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-7sxbQ+EMRubQc3wTfTsycgYpSujyVbI1xw+3UMRUcrhSy+pN09y/lWzeKDbvhoqcRbHdc+APLs/PWYi/cisLPg==", + "url": "https://registry.npmjs.org/@typescript-eslint/types/-/types-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_types__6.13.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/types", + "version": "6.13.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "@typescript-eslint/types": [ + "6.13.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_typescript-estree__6.13.2__typescript_5.3.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/typescript-estree", + "version": "6.13.2_typescript_5.3.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-SuD8YLQv6WHnOEtKv8D6HZUzOub855cfPnPMKvdM/Bh1plv1f7Q/0iFUDLKKlxHcEstQnaUU4QZskgQq74t+3w==", + "url": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_typescript-estree__6.13.2__typescript_5.3.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/typescript-estree", + "version": "6.13.2_typescript_5.3.3", + "dev": true, + 
"root_package": "", + "link_packages": {}, + "deps": { + "@typescript-eslint/types": "6.13.2", + "@typescript-eslint/visitor-keys": "6.13.2", + "debug": "4.3.4", + "globby": "11.1.0", + "is-glob": "4.0.3", + "semver": "7.5.4", + "ts-api-utils": "1.0.3_typescript_5.3.3", + "typescript": "5.3.3" + }, + "transitive_closure": { + "@typescript-eslint/typescript-estree": [ + "6.13.2_typescript_5.3.3" + ], + "@typescript-eslint/types": [ + "6.13.2" + ], + "@typescript-eslint/visitor-keys": [ + "6.13.2" + ], + "debug": [ + "4.3.4" + ], + "globby": [ + "11.1.0" + ], + "is-glob": [ + "4.0.3" + ], + "semver": [ + "7.5.4" + ], + "ts-api-utils": [ + "1.0.3_typescript_5.3.3" + ], + "typescript": [ + "5.3.3" + ], + "lru-cache": [ + "6.0.0" + ], + "yallist": [ + "4.0.0" + ], + "is-extglob": [ + "2.1.1" + ], + "array-union": [ + "2.1.0" + ], + "dir-glob": [ + "3.0.1" + ], + "fast-glob": [ + "3.3.2" + ], + "ignore": [ + "5.3.0" + ], + "merge2": [ + "1.4.1" + ], + "slash": [ + "3.0.0" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "glob-parent": [ + "5.1.2" + ], + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "path-type": [ + "4.0.0" + ], + "ms": [ + "2.1.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_utils__6.13.2__1796040679": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/utils", + "version": "6.13.2_1796040679", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-b9Ptq4eAZUym4idijCRzl61oPCwwREcfDI8xGk751Vhzig5fFZR9CyzDz4Sp/nxSLBYxUPyh4QdIDqWykFhNmQ==", + "url": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_utils__6.13.2__1796040679__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/utils", + "version": "6.13.2_1796040679", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@eslint-community/eslint-utils": "4.4.0_eslint_8.55.0", + "@types/json-schema": "7.0.15", + "@types/semver": "7.5.6", + "@typescript-eslint/scope-manager": "6.13.2", + "@typescript-eslint/types": "6.13.2", + "@typescript-eslint/typescript-estree": "6.13.2_typescript_5.3.3", + "eslint": "8.55.0", + "semver": "7.5.4" + }, + "transitive_closure": { + "@typescript-eslint/utils": [ + "6.13.2_1796040679" + ], + "@eslint-community/eslint-utils": [ + "4.4.0_eslint_8.55.0" + ], + "@types/json-schema": [ + "7.0.15" + ], + "@types/semver": [ + "7.5.6" + ], + "@typescript-eslint/scope-manager": [ + "6.13.2" + ], + 
"@typescript-eslint/types": [ + "6.13.2" + ], + "@typescript-eslint/typescript-estree": [ + "6.13.2_typescript_5.3.3" + ], + "eslint": [ + "8.55.0" + ], + "semver": [ + "7.5.4" + ], + "lru-cache": [ + "6.0.0" + ], + "yallist": [ + "4.0.0" + ], + "@eslint-community/regexpp": [ + "4.10.0" + ], + "@eslint/eslintrc": [ + "2.1.4" + ], + "@eslint/js": [ + "8.55.0" + ], + "@humanwhocodes/config-array": [ + "0.11.13" + ], + "@humanwhocodes/module-importer": [ + "1.0.1" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "@ungap/structured-clone": [ + "1.2.0" + ], + "ajv": [ + "6.12.6" + ], + "chalk": [ + "4.1.2" + ], + "cross-spawn": [ + "7.0.3" + ], + "debug": [ + "4.3.4" + ], + "doctrine": [ + "3.0.0" + ], + "escape-string-regexp": [ + "4.0.0" + ], + "eslint-scope": [ + "7.2.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ], + "espree": [ + "9.6.1" + ], + "esquery": [ + "1.5.0" + ], + "esutils": [ + "2.0.3" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "file-entry-cache": [ + "6.0.1" + ], + "find-up": [ + "5.0.0" + ], + "glob-parent": [ + "5.1.2", + "6.0.2" + ], + "globals": [ + "13.23.0" + ], + "graphemer": [ + "1.4.0" + ], + "ignore": [ + "5.3.0" + ], + "imurmurhash": [ + "0.1.4" + ], + "is-glob": [ + "4.0.3" + ], + "is-path-inside": [ + "3.0.3" + ], + "js-yaml": [ + "4.1.0" + ], + "json-stable-stringify-without-jsonify": [ + "1.0.1" + ], + "levn": [ + "0.4.1" + ], + "lodash.merge": [ + "4.6.2" + ], + "minimatch": [ + "3.1.2" + ], + "natural-compare": [ + "1.4.0" + ], + "optionator": [ + "0.9.3" + ], + "strip-ansi": [ + "6.0.1" + ], + "text-table": [ + "0.2.0" + ], + "ansi-regex": [ + "5.0.1" + ], + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ], + "deep-is": [ + "0.1.4" + ], + "fast-levenshtein": [ + "2.0.6" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "argparse": [ + "2.0.1" + ], + "is-extglob": [ + "2.1.1" + ], + "type-fest": [ + "0.20.2" + ], + "locate-path": [ + "6.0.0" + ], + "path-exists": [ + "4.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ], + "flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "json-buffer": [ + "3.0.1" + ], + "estraverse": [ + "5.3.0" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "esrecurse": [ + "4.3.0" + ], + "ms": [ + "2.1.2" + ], + "path-key": [ + "3.1.1" + ], + "shebang-command": [ + "2.0.0" + ], + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ], + "ansi-styles": [ + "4.3.0" + ], + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ], + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ], + "fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "@humanwhocodes/object-schema": [ + "2.0.1" + ], + "import-fresh": [ + "3.3.0" + ], + "strip-json-comments": [ + "3.1.1" + ], + "parent-module": [ + 
"1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ], + "@typescript-eslint/visitor-keys": [ + "6.13.2" + ], + "globby": [ + "11.1.0" + ], + "ts-api-utils": [ + "1.0.3_typescript_5.3.3" + ], + "typescript": [ + "5.3.3" + ], + "array-union": [ + "2.1.0" + ], + "dir-glob": [ + "3.0.1" + ], + "fast-glob": [ + "3.3.2" + ], + "merge2": [ + "1.4.1" + ], + "slash": [ + "3.0.0" + ], + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ], + "path-type": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_typescript-eslint_visitor-keys__6.13.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@typescript-eslint/visitor-keys", + "version": "6.13.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-OGznFs0eAQXJsp+xSd6k/O1UbFi/K/L7WjqeRoFE7vadjAF9y0uppXhYNQNEqygjou782maGClOoZwPqF0Drlw==", + "url": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-6.13.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_typescript-eslint_visitor-keys__6.13.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@typescript-eslint/visitor-keys", + "version": "6.13.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@typescript-eslint/types": "6.13.2", + "eslint-visitor-keys": "3.4.3" + }, + "transitive_closure": { + "@typescript-eslint/visitor-keys": [ + "6.13.2" + ], + "@typescript-eslint/types": [ + "6.13.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__at_ungap_structured-clone__1.2.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "@ungap/structured-clone", + "version": "1.2.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==", + "url": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.2.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__at_ungap_structured-clone__1.2.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "@ungap/structured-clone", + "version": "1.2.0", + "dev": true, + "root_package": "", + 
"link_packages": {}, + "deps": {}, + "transitive_closure": { + "@ungap/structured-clone": [ + "1.2.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__acorn-jsx__5.3.2__acorn_8.11.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "acorn-jsx", + "version": "5.3.2_acorn_8.11.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==", + "url": "https://registry.npmjs.org/acorn-jsx/-/acorn-jsx-5.3.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__acorn-jsx__5.3.2__acorn_8.11.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "acorn-jsx", + "version": "5.3.2_acorn_8.11.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "acorn": "8.11.2" + }, + "transitive_closure": { + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "acorn": [ + "8.11.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__acorn__8.11.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "acorn", + "version": "8.11.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-nc0Axzp/0FILLEVsm4fNwLCwMttvhEI263QtVPQcbpfZZ3ts0hLsZGOpE6czNlid7CJ9MlyH8reXkpsf3YUY4w==", + "url": "https://registry.npmjs.org/acorn/-/acorn-8.11.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__acorn__8.11.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "acorn", + "version": "8.11.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "acorn": [ + "8.11.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__ajv__6.12.6": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "ajv", + "version": "6.12.6", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "url": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "commit": "", + "patch_args": [], + "patches": [], + 
"custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__ajv__6.12.6__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "ajv", + "version": "6.12.6", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "fast-deep-equal": "3.1.3", + "fast-json-stable-stringify": "2.1.0", + "json-schema-traverse": "0.4.1", + "uri-js": "4.4.1" + }, + "transitive_closure": { + "ajv": [ + "6.12.6" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__ansi-regex__5.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "ansi-regex", + "version": "5.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "url": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__ansi-regex__5.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "ansi-regex", + "version": "5.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "ansi-regex": [ + "5.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__ansi-styles__4.3.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "ansi-styles", + "version": "4.3.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "url": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__ansi-styles__4.3.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "ansi-styles", + "version": "4.3.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "color-convert": "2.0.1" + }, + "transitive_closure": { + "ansi-styles": [ + "4.3.0" + ], + "color-convert": [ + 
"2.0.1" + ], + "color-name": [ + "1.1.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__argparse__2.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "argparse", + "version": "2.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "url": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__argparse__2.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "argparse", + "version": "2.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "argparse": [ + "2.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__array-union__2.1.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "array-union", + "version": "2.1.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==", + "url": "https://registry.npmjs.org/array-union/-/array-union-2.1.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__array-union__2.1.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "array-union", + "version": "2.1.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "array-union": [ + "2.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__balanced-match__1.0.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "balanced-match", + "version": "1.0.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "url": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", 
+ "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__balanced-match__1.0.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "balanced-match", + "version": "1.0.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "balanced-match": [ + "1.0.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__brace-expansion__1.1.11": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "brace-expansion", + "version": "1.1.11", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "url": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__brace-expansion__1.1.11__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "brace-expansion", + "version": "1.1.11", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "balanced-match": "1.0.2", + "concat-map": "0.0.1" + }, + "transitive_closure": { + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__braces__3.0.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "braces", + "version": "3.0.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "url": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__braces__3.0.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "braces", + "version": "3.0.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "fill-range": "7.0.1" + }, + "transitive_closure": { + "braces": [ + "3.0.2" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + 
"//visibility:public" + ] + } + }, + "npm__callsites__3.1.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "callsites", + "version": "3.1.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==", + "url": "https://registry.npmjs.org/callsites/-/callsites-3.1.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__callsites__3.1.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "callsites", + "version": "3.1.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "callsites": [ + "3.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__chalk__4.1.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "chalk", + "version": "4.1.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "url": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__chalk__4.1.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "chalk", + "version": "4.1.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "ansi-styles": "4.3.0", + "supports-color": "7.2.0" + }, + "transitive_closure": { + "chalk": [ + "4.1.2" + ], + "ansi-styles": [ + "4.3.0" + ], + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ], + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__color-convert__2.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "color-convert", + "version": "2.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "url": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": 
"", + "generate_bzl_library_targets": false + } + }, + "npm__color-convert__2.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "color-convert", + "version": "2.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "color-name": "1.1.4" + }, + "transitive_closure": { + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__color-name__1.1.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "color-name", + "version": "1.1.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "url": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__color-name__1.1.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "color-name", + "version": "1.1.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "color-name": [ + "1.1.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__concat-map__0.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "concat-map", + "version": "0.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "url": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__concat-map__0.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "concat-map", + "version": "0.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "concat-map": [ + "0.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__cross-spawn__7.0.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "cross-spawn", + "version": "7.0.3", + "root_package": "", + 
"link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", + "url": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__cross-spawn__7.0.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "cross-spawn", + "version": "7.0.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "path-key": "3.1.1", + "shebang-command": "2.0.0", + "which": "2.0.2" + }, + "transitive_closure": { + "cross-spawn": [ + "7.0.3" + ], + "path-key": [ + "3.1.1" + ], + "shebang-command": [ + "2.0.0" + ], + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__debug__4.3.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "debug", + "version": "4.3.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "url": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__debug__4.3.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "debug", + "version": "4.3.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "ms": "2.1.2" + }, + "transitive_closure": { + "debug": [ + "4.3.4" + ], + "ms": [ + "2.1.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__deep-is__0.1.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "deep-is", + "version": "0.1.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==", + "url": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__deep-is__0.1.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + 
"package": "deep-is", + "version": "0.1.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "deep-is": [ + "0.1.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__dir-glob__3.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "dir-glob", + "version": "3.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==", + "url": "https://registry.npmjs.org/dir-glob/-/dir-glob-3.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__dir-glob__3.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "dir-glob", + "version": "3.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "path-type": "4.0.0" + }, + "transitive_closure": { + "dir-glob": [ + "3.0.1" + ], + "path-type": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__doctrine__3.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "doctrine", + "version": "3.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w==", + "url": "https://registry.npmjs.org/doctrine/-/doctrine-3.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__doctrine__3.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "doctrine", + "version": "3.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "esutils": "2.0.3" + }, + "transitive_closure": { + "doctrine": [ + "3.0.0" + ], + "esutils": [ + "2.0.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__esbuild__0.19.8": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "esbuild", + "version": "0.19.8", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": { + "": [ + "esbuild" + ] + }, + "integrity": "sha512-l7iffQpT2OrZfH2rXIp7/FkmaeZM0vxbxN9KfiCwGYuZqzMg/JdvX26R31Zxn/Pxvsrg3Y9N6XTcnknqDyyv4w==", + "url": 
"https://registry.npmjs.org/esbuild/-/esbuild-0.19.8.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [ + "preinstall", + "install", + "postinstall" + ], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__esbuild__0.19.8__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "esbuild", + "version": "0.19.8", + "dev": true, + "root_package": "", + "link_packages": { + "": [ + "esbuild" + ] + }, + "deps": { + "@esbuild/android-arm": "0.19.8", + "@esbuild/android-arm64": "0.19.8", + "@esbuild/android-x64": "0.19.8", + "@esbuild/darwin-arm64": "0.19.8", + "@esbuild/darwin-x64": "0.19.8", + "@esbuild/freebsd-arm64": "0.19.8", + "@esbuild/freebsd-x64": "0.19.8", + "@esbuild/linux-arm": "0.19.8", + "@esbuild/linux-arm64": "0.19.8", + "@esbuild/linux-ia32": "0.19.8", + "@esbuild/linux-loong64": "0.19.8", + "@esbuild/linux-mips64el": "0.19.8", + "@esbuild/linux-ppc64": "0.19.8", + "@esbuild/linux-riscv64": "0.19.8", + "@esbuild/linux-s390x": "0.19.8", + "@esbuild/linux-x64": "0.19.8", + "@esbuild/netbsd-x64": "0.19.8", + "@esbuild/openbsd-x64": "0.19.8", + "@esbuild/sunos-x64": "0.19.8", + "@esbuild/win32-arm64": "0.19.8", + "@esbuild/win32-ia32": "0.19.8", + "@esbuild/win32-x64": "0.19.8" + }, + "transitive_closure": { + "esbuild": [ + "0.19.8" + ], + "@esbuild/android-arm": [ + "0.19.8" + ], + "@esbuild/android-arm64": [ + "0.19.8" + ], + "@esbuild/android-x64": [ + "0.19.8" + ], + "@esbuild/darwin-arm64": [ + "0.19.8" + ], + "@esbuild/darwin-x64": [ + "0.19.8" + ], + "@esbuild/freebsd-arm64": [ + "0.19.8" + ], + "@esbuild/freebsd-x64": [ + "0.19.8" + ], + "@esbuild/linux-arm": [ + "0.19.8" + ], + "@esbuild/linux-arm64": [ + "0.19.8" + ], + "@esbuild/linux-ia32": [ + "0.19.8" + ], + "@esbuild/linux-loong64": [ + "0.19.8" + ], + "@esbuild/linux-mips64el": [ + "0.19.8" + ], + "@esbuild/linux-ppc64": [ + "0.19.8" + ], + "@esbuild/linux-riscv64": [ + "0.19.8" + ], + "@esbuild/linux-s390x": [ + "0.19.8" + ], + "@esbuild/linux-x64": [ + "0.19.8" + ], + "@esbuild/netbsd-x64": [ + "0.19.8" + ], + "@esbuild/openbsd-x64": [ + "0.19.8" + ], + "@esbuild/sunos-x64": [ + "0.19.8" + ], + "@esbuild/win32-arm64": [ + "0.19.8" + ], + "@esbuild/win32-ia32": [ + "0.19.8" + ], + "@esbuild/win32-x64": [ + "0.19.8" + ] + }, + "lifecycle_build_target": true, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [ + "no-sandbox" + ], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__escape-string-regexp__4.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "escape-string-regexp", + "version": "4.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==", + "url": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + 
"npm__escape-string-regexp__4.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "escape-string-regexp", + "version": "4.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "escape-string-regexp": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__eslint-scope__7.2.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "eslint-scope", + "version": "7.2.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-dOt21O7lTMhDM+X9mB4GX+DZrZtCUJPL/wlcTqxyrx5IvO0IYtILdtrQGQp+8n5S0gwSVmOf9NQrjMOgfQZlIg==", + "url": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__eslint-scope__7.2.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "eslint-scope", + "version": "7.2.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "esrecurse": "4.3.0", + "estraverse": "5.3.0" + }, + "transitive_closure": { + "eslint-scope": [ + "7.2.2" + ], + "esrecurse": [ + "4.3.0" + ], + "estraverse": [ + "5.3.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__eslint-visitor-keys__3.4.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "eslint-visitor-keys", + "version": "3.4.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==", + "url": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__eslint-visitor-keys__3.4.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "eslint-visitor-keys", + "version": "3.4.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "eslint-visitor-keys": [ + "3.4.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__eslint__8.55.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + 
"attributes": { + "package": "eslint", + "version": "8.55.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": { + "": [ + "eslint" + ] + }, + "integrity": "sha512-iyUUAM0PCKj5QpwGfmCAG9XXbZCWsqP/eWAWrG/W0umvjuLRBECwSFdt+rCntju0xEH7teIABPwXpahftIaTdA==", + "url": "https://registry.npmjs.org/eslint/-/eslint-8.55.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__eslint__8.55.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "eslint", + "version": "8.55.0", + "dev": true, + "root_package": "", + "link_packages": { + "": [ + "eslint" + ] + }, + "deps": { + "@eslint-community/eslint-utils": "4.4.0_eslint_8.55.0", + "@eslint-community/regexpp": "4.10.0", + "@eslint/eslintrc": "2.1.4", + "@eslint/js": "8.55.0", + "@humanwhocodes/config-array": "0.11.13", + "@humanwhocodes/module-importer": "1.0.1", + "@nodelib/fs.walk": "1.2.8", + "@ungap/structured-clone": "1.2.0", + "ajv": "6.12.6", + "chalk": "4.1.2", + "cross-spawn": "7.0.3", + "debug": "4.3.4", + "doctrine": "3.0.0", + "escape-string-regexp": "4.0.0", + "eslint-scope": "7.2.2", + "eslint-visitor-keys": "3.4.3", + "espree": "9.6.1", + "esquery": "1.5.0", + "esutils": "2.0.3", + "fast-deep-equal": "3.1.3", + "file-entry-cache": "6.0.1", + "find-up": "5.0.0", + "glob-parent": "6.0.2", + "globals": "13.23.0", + "graphemer": "1.4.0", + "ignore": "5.3.0", + "imurmurhash": "0.1.4", + "is-glob": "4.0.3", + "is-path-inside": "3.0.3", + "js-yaml": "4.1.0", + "json-stable-stringify-without-jsonify": "1.0.1", + "levn": "0.4.1", + "lodash.merge": "4.6.2", + "minimatch": "3.1.2", + "natural-compare": "1.4.0", + "optionator": "0.9.3", + "strip-ansi": "6.0.1", + "text-table": "0.2.0" + }, + "transitive_closure": { + "eslint": [ + "8.55.0" + ], + "@eslint-community/eslint-utils": [ + "4.4.0_eslint_8.55.0" + ], + "@eslint-community/regexpp": [ + "4.10.0" + ], + "@eslint/eslintrc": [ + "2.1.4" + ], + "@eslint/js": [ + "8.55.0" + ], + "@humanwhocodes/config-array": [ + "0.11.13" + ], + "@humanwhocodes/module-importer": [ + "1.0.1" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "@ungap/structured-clone": [ + "1.2.0" + ], + "ajv": [ + "6.12.6" + ], + "chalk": [ + "4.1.2" + ], + "cross-spawn": [ + "7.0.3" + ], + "debug": [ + "4.3.4" + ], + "doctrine": [ + "3.0.0" + ], + "escape-string-regexp": [ + "4.0.0" + ], + "eslint-scope": [ + "7.2.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ], + "espree": [ + "9.6.1" + ], + "esquery": [ + "1.5.0" + ], + "esutils": [ + "2.0.3" + ], + "fast-deep-equal": [ + "3.1.3" + ], + "file-entry-cache": [ + "6.0.1" + ], + "find-up": [ + "5.0.0" + ], + "glob-parent": [ + "6.0.2" + ], + "globals": [ + "13.23.0" + ], + "graphemer": [ + "1.4.0" + ], + "ignore": [ + "5.3.0" + ], + "imurmurhash": [ + "0.1.4" + ], + "is-glob": [ + "4.0.3" + ], + "is-path-inside": [ + "3.0.3" + ], + "js-yaml": [ + "4.1.0" + ], + "json-stable-stringify-without-jsonify": [ + "1.0.1" + ], + "levn": [ + "0.4.1" + ], + "lodash.merge": [ + "4.6.2" + ], + "minimatch": [ + "3.1.2" + ], + "natural-compare": [ + "1.4.0" + ], + "optionator": [ + "0.9.3" + ], + "strip-ansi": [ + "6.0.1" + ], + "text-table": [ + "0.2.0" + ], + "ansi-regex": [ + "5.0.1" + ], + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ], + 
"deep-is": [ + "0.1.4" + ], + "fast-levenshtein": [ + "2.0.6" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "argparse": [ + "2.0.1" + ], + "is-extglob": [ + "2.1.1" + ], + "type-fest": [ + "0.20.2" + ], + "locate-path": [ + "6.0.0" + ], + "path-exists": [ + "4.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ], + "flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "json-buffer": [ + "3.0.1" + ], + "estraverse": [ + "5.3.0" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "esrecurse": [ + "4.3.0" + ], + "ms": [ + "2.1.2" + ], + "path-key": [ + "3.1.1" + ], + "shebang-command": [ + "2.0.0" + ], + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ], + "ansi-styles": [ + "4.3.0" + ], + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ], + "color-convert": [ + "2.0.1" + ], + "color-name": [ + "1.1.4" + ], + "fast-json-stable-stringify": [ + "2.1.0" + ], + "json-schema-traverse": [ + "0.4.1" + ], + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "@humanwhocodes/object-schema": [ + "2.0.1" + ], + "import-fresh": [ + "3.3.0" + ], + "strip-json-comments": [ + "3.1.1" + ], + "parent-module": [ + "1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__espree__9.6.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "espree", + "version": "9.6.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-oruZaFkjorTpF32kDSI5/75ViwGeZginGGy2NoOSg3Q9bnwlnmDm4HLnkl0RE3n+njDXR037aY1+x58Z/zFdwQ==", + "url": "https://registry.npmjs.org/espree/-/espree-9.6.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__espree__9.6.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "espree", + "version": "9.6.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "acorn": "8.11.2", + "acorn-jsx": "5.3.2_acorn_8.11.2", + "eslint-visitor-keys": "3.4.3" + }, + "transitive_closure": { + "espree": [ + "9.6.1" + ], + "acorn": [ + "8.11.2" + ], + "acorn-jsx": [ + "5.3.2_acorn_8.11.2" + ], + "eslint-visitor-keys": [ + "3.4.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + 
"lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__esquery__1.5.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "esquery", + "version": "1.5.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-YQLXUplAwJgCydQ78IMJywZCceoqk1oH01OERdSAJc/7U2AylwjhSCLDEtqwg811idIS/9fIU5GjG73IgjKMVg==", + "url": "https://registry.npmjs.org/esquery/-/esquery-1.5.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__esquery__1.5.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "esquery", + "version": "1.5.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "estraverse": "5.3.0" + }, + "transitive_closure": { + "esquery": [ + "1.5.0" + ], + "estraverse": [ + "5.3.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__esrecurse__4.3.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "esrecurse", + "version": "4.3.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==", + "url": "https://registry.npmjs.org/esrecurse/-/esrecurse-4.3.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__esrecurse__4.3.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "esrecurse", + "version": "4.3.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "estraverse": "5.3.0" + }, + "transitive_closure": { + "esrecurse": [ + "4.3.0" + ], + "estraverse": [ + "5.3.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__estraverse__5.3.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "estraverse", + "version": "5.3.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "url": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + 
"extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__estraverse__5.3.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "estraverse", + "version": "5.3.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "estraverse": [ + "5.3.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__esutils__2.0.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "esutils", + "version": "2.0.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "url": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__esutils__2.0.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "esutils", + "version": "2.0.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "esutils": [ + "2.0.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__fast-deep-equal__3.1.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "fast-deep-equal", + "version": "3.1.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "url": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__fast-deep-equal__3.1.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "fast-deep-equal", + "version": "3.1.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "fast-deep-equal": [ + "3.1.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__fast-glob__3.3.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "fast-glob", + "version": "3.3.2", + "root_package": "", + "link_workspace": "flatbuffers~", + 
"link_packages": {}, + "integrity": "sha512-oX2ruAFQwf/Orj8m737Y5adxDQO0LAB7/S5MnxCdTNDd4p6BsyIVsv9JQsATbTSq8KHRpLwIHbVlUNatxd+1Ow==", + "url": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__fast-glob__3.3.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "fast-glob", + "version": "3.3.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@nodelib/fs.stat": "2.0.5", + "@nodelib/fs.walk": "1.2.8", + "glob-parent": "5.1.2", + "merge2": "1.4.1", + "micromatch": "4.0.5" + }, + "transitive_closure": { + "fast-glob": [ + "3.3.2" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "glob-parent": [ + "5.1.2" + ], + "merge2": [ + "1.4.1" + ], + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ], + "is-glob": [ + "4.0.3" + ], + "is-extglob": [ + "2.1.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__fast-json-stable-stringify__2.1.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "fast-json-stable-stringify", + "version": "2.1.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "url": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__fast-json-stable-stringify__2.1.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "fast-json-stable-stringify", + "version": "2.1.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "fast-json-stable-stringify": [ + "2.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__fast-levenshtein__2.0.6": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "fast-levenshtein", + "version": "2.0.6", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": 
"sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==", + "url": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__fast-levenshtein__2.0.6__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "fast-levenshtein", + "version": "2.0.6", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "fast-levenshtein": [ + "2.0.6" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__fastq__1.15.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "fastq", + "version": "1.15.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "url": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__fastq__1.15.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "fastq", + "version": "1.15.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "reusify": "1.0.4" + }, + "transitive_closure": { + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__file-entry-cache__6.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "file-entry-cache", + "version": "6.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-7Gps/XWymbLk2QLYK4NzpMOrYjMhdIxXuIvy2QBsLE6ljuodKvdkWs/cpyJJ3CVIVpH0Oi1Hvg1ovbMzLdFBBg==", + "url": "https://registry.npmjs.org/file-entry-cache/-/file-entry-cache-6.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__file-entry-cache__6.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "file-entry-cache", + "version": "6.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "flat-cache": "3.2.0" + }, + "transitive_closure": { + "file-entry-cache": [ + "6.0.1" + ], + 
"flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "minimatch": [ + "3.1.2" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "json-buffer": [ + "3.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__fill-range__7.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "fill-range", + "version": "7.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "url": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__fill-range__7.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "fill-range", + "version": "7.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "to-regex-range": "5.0.1" + }, + "transitive_closure": { + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__find-up__5.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "find-up", + "version": "5.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==", + "url": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__find-up__5.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "find-up", + "version": "5.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "locate-path": "6.0.0", + "path-exists": "4.0.0" + }, + "transitive_closure": { + "find-up": [ + "5.0.0" + ], + "locate-path": [ + "6.0.0" + ], + "path-exists": [ + "4.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + 
"bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__flat-cache__3.2.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "flat-cache", + "version": "3.2.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-CYcENa+FtcUKLmhhqyctpclsq7QF38pKjZHsGNiSQF5r4FtoKDWabFDl3hzaEQMvT1LHEysw5twgLvpYYb4vbw==", + "url": "https://registry.npmjs.org/flat-cache/-/flat-cache-3.2.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__flat-cache__3.2.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "flat-cache", + "version": "3.2.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "flatted": "3.2.9", + "keyv": "4.5.4", + "rimraf": "3.0.2" + }, + "transitive_closure": { + "flat-cache": [ + "3.2.0" + ], + "flatted": [ + "3.2.9" + ], + "keyv": [ + "4.5.4" + ], + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "minimatch": [ + "3.1.2" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ], + "json-buffer": [ + "3.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__flatted__3.2.9": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "flatted", + "version": "3.2.9", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==", + "url": "https://registry.npmjs.org/flatted/-/flatted-3.2.9.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__flatted__3.2.9__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "flatted", + "version": "3.2.9", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "flatted": [ + "3.2.9" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__fs.realpath__1.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "fs.realpath", + "version": "1.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + 
"integrity": "sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==", + "url": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__fs.realpath__1.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "fs.realpath", + "version": "1.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "fs.realpath": [ + "1.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__glob-parent__5.1.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "glob-parent", + "version": "5.1.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "url": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__glob-parent__5.1.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "glob-parent", + "version": "5.1.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "is-glob": "4.0.3" + }, + "transitive_closure": { + "glob-parent": [ + "5.1.2" + ], + "is-glob": [ + "4.0.3" + ], + "is-extglob": [ + "2.1.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__glob-parent__6.0.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "glob-parent", + "version": "6.0.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "url": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__glob-parent__6.0.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "glob-parent", + "version": "6.0.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "is-glob": "4.0.3" + }, + "transitive_closure": { + "glob-parent": [ + 
"6.0.2" + ], + "is-glob": [ + "4.0.3" + ], + "is-extglob": [ + "2.1.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__glob__7.2.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "glob", + "version": "7.2.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==", + "url": "https://registry.npmjs.org/glob/-/glob-7.2.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__glob__7.2.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "glob", + "version": "7.2.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "fs.realpath": "1.0.0", + "inflight": "1.0.6", + "inherits": "2.0.4", + "minimatch": "3.1.2", + "once": "1.4.0", + "path-is-absolute": "1.0.1" + }, + "transitive_closure": { + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "minimatch": [ + "3.1.2" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__globals__13.23.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "globals", + "version": "13.23.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-XAmF0RjlrjY23MA51q3HltdlGxUpXPvg0GioKiD9X6HD28iMjo2dKC8Vqwm7lne4GNr78+RHTfliktR6ZH09wA==", + "url": "https://registry.npmjs.org/globals/-/globals-13.23.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__globals__13.23.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "globals", + "version": "13.23.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "type-fest": "0.20.2" + }, + "transitive_closure": { + "globals": [ + "13.23.0" + ], + "type-fest": [ + "0.20.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__globby__11.1.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + 
"attributes": { + "package": "globby", + "version": "11.1.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==", + "url": "https://registry.npmjs.org/globby/-/globby-11.1.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__globby__11.1.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "globby", + "version": "11.1.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "array-union": "2.1.0", + "dir-glob": "3.0.1", + "fast-glob": "3.3.2", + "ignore": "5.3.0", + "merge2": "1.4.1", + "slash": "3.0.0" + }, + "transitive_closure": { + "globby": [ + "11.1.0" + ], + "array-union": [ + "2.1.0" + ], + "dir-glob": [ + "3.0.1" + ], + "fast-glob": [ + "3.3.2" + ], + "ignore": [ + "5.3.0" + ], + "merge2": [ + "1.4.1" + ], + "slash": [ + "3.0.0" + ], + "@nodelib/fs.stat": [ + "2.0.5" + ], + "@nodelib/fs.walk": [ + "1.2.8" + ], + "glob-parent": [ + "5.1.2" + ], + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ], + "is-glob": [ + "4.0.3" + ], + "is-extglob": [ + "2.1.1" + ], + "@nodelib/fs.scandir": [ + "2.1.5" + ], + "fastq": [ + "1.15.0" + ], + "reusify": [ + "1.0.4" + ], + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ], + "path-type": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__graphemer__1.4.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "graphemer", + "version": "1.4.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", + "url": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__graphemer__1.4.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "graphemer", + "version": "1.4.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "graphemer": [ + "1.4.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__has-flag__4.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "has-flag", + "version": 
"4.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "url": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__has-flag__4.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "has-flag", + "version": "4.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "has-flag": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__ignore__5.3.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "ignore", + "version": "5.3.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-g7dmpshy+gD7mh88OC9NwSGTKoc3kyLAZQRU1mt53Aw/vnvfXnbC+F/7F7QoYVKbV+KNvJx8wArewKy1vXMtlg==", + "url": "https://registry.npmjs.org/ignore/-/ignore-5.3.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__ignore__5.3.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "ignore", + "version": "5.3.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "ignore": [ + "5.3.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__import-fresh__3.3.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "import-fresh", + "version": "3.3.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-veYYhQa+D1QBKznvhUHxb8faxlrwUnxseDAbAp457E0wLNio2bOSKnjYDhMj+YiAq61xrMGhQk9iXVk5FzgQMw==", + "url": "https://registry.npmjs.org/import-fresh/-/import-fresh-3.3.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__import-fresh__3.3.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "import-fresh", + "version": "3.3.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "parent-module": "1.0.1", + "resolve-from": "4.0.0" + }, + "transitive_closure": { + "import-fresh": [ + "3.3.0" + 
], + "parent-module": [ + "1.0.1" + ], + "resolve-from": [ + "4.0.0" + ], + "callsites": [ + "3.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__imurmurhash__0.1.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "imurmurhash", + "version": "0.1.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==", + "url": "https://registry.npmjs.org/imurmurhash/-/imurmurhash-0.1.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__imurmurhash__0.1.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "imurmurhash", + "version": "0.1.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "imurmurhash": [ + "0.1.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__inflight__1.0.6": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "inflight", + "version": "1.0.6", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==", + "url": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__inflight__1.0.6__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "inflight", + "version": "1.0.6", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "once": "1.4.0", + "wrappy": "1.0.2" + }, + "transitive_closure": { + "inflight": [ + "1.0.6" + ], + "once": [ + "1.4.0" + ], + "wrappy": [ + "1.0.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__inherits__2.0.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "inherits", + "version": "2.0.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "url": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "commit": "", + "patch_args": [], + 
"patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__inherits__2.0.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "inherits", + "version": "2.0.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "inherits": [ + "2.0.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__is-extglob__2.1.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "is-extglob", + "version": "2.1.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "url": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__is-extglob__2.1.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "is-extglob", + "version": "2.1.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "is-extglob": [ + "2.1.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__is-glob__4.0.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "is-glob", + "version": "4.0.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "url": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__is-glob__4.0.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "is-glob", + "version": "4.0.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "is-extglob": "2.1.1" + }, + "transitive_closure": { + "is-glob": [ + "4.0.3" + ], + "is-extglob": [ + "2.1.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__is-number__7.0.0": { + "bzlFile": 
"@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "is-number", + "version": "7.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "url": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__is-number__7.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "is-number", + "version": "7.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "is-number": [ + "7.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__is-path-inside__3.0.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "is-path-inside", + "version": "3.0.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ==", + "url": "https://registry.npmjs.org/is-path-inside/-/is-path-inside-3.0.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__is-path-inside__3.0.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "is-path-inside", + "version": "3.0.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "is-path-inside": [ + "3.0.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__isexe__2.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "isexe", + "version": "2.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "url": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__isexe__2.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "isexe", + "version": "2.0.0", + "dev": true, + 
"root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "isexe": [ + "2.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__js-yaml__4.1.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "js-yaml", + "version": "4.1.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==", + "url": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__js-yaml__4.1.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "js-yaml", + "version": "4.1.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "argparse": "2.0.1" + }, + "transitive_closure": { + "js-yaml": [ + "4.1.0" + ], + "argparse": [ + "2.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__json-buffer__3.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "json-buffer", + "version": "3.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==", + "url": "https://registry.npmjs.org/json-buffer/-/json-buffer-3.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__json-buffer__3.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "json-buffer", + "version": "3.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "json-buffer": [ + "3.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__json-schema-traverse__0.4.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "json-schema-traverse", + "version": "0.4.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "url": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "commit": "", + "patch_args": [], + 
"patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__json-schema-traverse__0.4.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "json-schema-traverse", + "version": "0.4.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "json-schema-traverse": [ + "0.4.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__json-stable-stringify-without-jsonify__1.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "json-stable-stringify-without-jsonify", + "version": "1.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==", + "url": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__json-stable-stringify-without-jsonify__1.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "json-stable-stringify-without-jsonify", + "version": "1.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "json-stable-stringify-without-jsonify": [ + "1.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__keyv__4.5.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "keyv", + "version": "4.5.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==", + "url": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__keyv__4.5.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "keyv", + "version": "4.5.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "json-buffer": "3.0.1" + }, + "transitive_closure": { + "keyv": [ + "4.5.4" + ], + "json-buffer": [ + "3.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + 
"lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__levn__0.4.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "levn", + "version": "0.4.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==", + "url": "https://registry.npmjs.org/levn/-/levn-0.4.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__levn__0.4.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "levn", + "version": "0.4.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "prelude-ls": "1.2.1", + "type-check": "0.4.0" + }, + "transitive_closure": { + "levn": [ + "0.4.1" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__locate-path__6.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "locate-path", + "version": "6.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==", + "url": "https://registry.npmjs.org/locate-path/-/locate-path-6.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__locate-path__6.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "locate-path", + "version": "6.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "p-locate": "5.0.0" + }, + "transitive_closure": { + "locate-path": [ + "6.0.0" + ], + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__lodash.merge__4.6.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "lodash.merge", + "version": "4.6.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "url": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + 
"npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__lodash.merge__4.6.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "lodash.merge", + "version": "4.6.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "lodash.merge": [ + "4.6.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__lru-cache__6.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "lru-cache", + "version": "6.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "url": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__lru-cache__6.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "lru-cache", + "version": "6.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "yallist": "4.0.0" + }, + "transitive_closure": { + "lru-cache": [ + "6.0.0" + ], + "yallist": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__merge2__1.4.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "merge2", + "version": "1.4.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "url": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__merge2__1.4.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "merge2", + "version": "1.4.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "merge2": [ + "1.4.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__micromatch__4.0.5": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + 
"attributes": { + "package": "micromatch", + "version": "4.0.5", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "url": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__micromatch__4.0.5__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "micromatch", + "version": "4.0.5", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "braces": "3.0.2", + "picomatch": "2.3.1" + }, + "transitive_closure": { + "micromatch": [ + "4.0.5" + ], + "braces": [ + "3.0.2" + ], + "picomatch": [ + "2.3.1" + ], + "fill-range": [ + "7.0.1" + ], + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__minimatch__3.1.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "minimatch", + "version": "3.1.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "url": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__minimatch__3.1.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "minimatch", + "version": "3.1.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "brace-expansion": "1.1.11" + }, + "transitive_closure": { + "minimatch": [ + "3.1.2" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__ms__2.1.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "ms", + "version": "2.1.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "url": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, 
+ "npm__ms__2.1.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "ms", + "version": "2.1.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "ms": [ + "2.1.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__natural-compare__1.4.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "natural-compare", + "version": "1.4.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==", + "url": "https://registry.npmjs.org/natural-compare/-/natural-compare-1.4.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__natural-compare__1.4.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "natural-compare", + "version": "1.4.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "natural-compare": [ + "1.4.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__once__1.4.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "once", + "version": "1.4.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "url": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__once__1.4.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "once", + "version": "1.4.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "wrappy": "1.0.2" + }, + "transitive_closure": { + "once": [ + "1.4.0" + ], + "wrappy": [ + "1.0.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__optionator__0.9.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "optionator", + "version": "0.9.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": 
"sha512-JjCoypp+jKn1ttEFExxhetCKeJt9zhAgAve5FXHixTvFDW/5aEktX9bufBKLRRMdU7bNtpLfcGu94B3cdEJgjg==", + "url": "https://registry.npmjs.org/optionator/-/optionator-0.9.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__optionator__0.9.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "optionator", + "version": "0.9.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "@aashutoshrathi/word-wrap": "1.2.6", + "deep-is": "0.1.4", + "fast-levenshtein": "2.0.6", + "levn": "0.4.1", + "prelude-ls": "1.2.1", + "type-check": "0.4.0" + }, + "transitive_closure": { + "optionator": [ + "0.9.3" + ], + "@aashutoshrathi/word-wrap": [ + "1.2.6" + ], + "deep-is": [ + "0.1.4" + ], + "fast-levenshtein": [ + "2.0.6" + ], + "levn": [ + "0.4.1" + ], + "prelude-ls": [ + "1.2.1" + ], + "type-check": [ + "0.4.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__p-limit__3.1.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "p-limit", + "version": "3.1.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "url": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__p-limit__3.1.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "p-limit", + "version": "3.1.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "yocto-queue": "0.1.0" + }, + "transitive_closure": { + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__p-locate__5.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "p-locate", + "version": "5.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==", + "url": "https://registry.npmjs.org/p-locate/-/p-locate-5.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__p-locate__5.0.0__links": { + "bzlFile": 
"@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "p-locate", + "version": "5.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "p-limit": "3.1.0" + }, + "transitive_closure": { + "p-locate": [ + "5.0.0" + ], + "p-limit": [ + "3.1.0" + ], + "yocto-queue": [ + "0.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__parent-module__1.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "parent-module", + "version": "1.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==", + "url": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__parent-module__1.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "parent-module", + "version": "1.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "callsites": "3.1.0" + }, + "transitive_closure": { + "parent-module": [ + "1.0.1" + ], + "callsites": [ + "3.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__path-exists__4.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "path-exists", + "version": "4.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==", + "url": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__path-exists__4.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "path-exists", + "version": "4.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "path-exists": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__path-is-absolute__1.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "path-is-absolute", + "version": "1.0.1", + "root_package": "", + 
"link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==", + "url": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__path-is-absolute__1.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "path-is-absolute", + "version": "1.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "path-is-absolute": [ + "1.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__path-key__3.1.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "path-key", + "version": "3.1.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "url": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__path-key__3.1.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "path-key", + "version": "3.1.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "path-key": [ + "3.1.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__path-type__4.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "path-type", + "version": "4.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==", + "url": "https://registry.npmjs.org/path-type/-/path-type-4.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__path-type__4.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "path-type", + "version": "4.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "path-type": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + 
"lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__picomatch__2.3.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "picomatch", + "version": "2.3.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "url": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__picomatch__2.3.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "picomatch", + "version": "2.3.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "picomatch": [ + "2.3.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__prelude-ls__1.2.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "prelude-ls", + "version": "1.2.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==", + "url": "https://registry.npmjs.org/prelude-ls/-/prelude-ls-1.2.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__prelude-ls__1.2.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "prelude-ls", + "version": "1.2.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "prelude-ls": [ + "1.2.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__punycode__2.3.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "punycode", + "version": "2.3.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "url": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + 
"npm__punycode__2.3.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "punycode", + "version": "2.3.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "punycode": [ + "2.3.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__queue-microtask__1.2.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "queue-microtask", + "version": "1.2.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "url": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__queue-microtask__1.2.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "queue-microtask", + "version": "1.2.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "queue-microtask": [ + "1.2.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__resolve-from__4.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "resolve-from", + "version": "4.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==", + "url": "https://registry.npmjs.org/resolve-from/-/resolve-from-4.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__resolve-from__4.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "resolve-from", + "version": "4.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "resolve-from": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__reusify__1.0.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "reusify", + "version": "1.0.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": 
"sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "url": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__reusify__1.0.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "reusify", + "version": "1.0.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "reusify": [ + "1.0.4" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__rimraf__3.0.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "rimraf", + "version": "3.0.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-JZkJMZkAGFFPP2YqXZXPbMlMBgsxzE8ILs4lMIX/2o0L9UBw9O/Y3o6wFw/i9YLapcUJWwqbi3kdxIPdC62TIA==", + "url": "https://registry.npmjs.org/rimraf/-/rimraf-3.0.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__rimraf__3.0.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "rimraf", + "version": "3.0.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "glob": "7.2.3" + }, + "transitive_closure": { + "rimraf": [ + "3.0.2" + ], + "glob": [ + "7.2.3" + ], + "fs.realpath": [ + "1.0.0" + ], + "inflight": [ + "1.0.6" + ], + "inherits": [ + "2.0.4" + ], + "minimatch": [ + "3.1.2" + ], + "once": [ + "1.4.0" + ], + "path-is-absolute": [ + "1.0.1" + ], + "wrappy": [ + "1.0.2" + ], + "brace-expansion": [ + "1.1.11" + ], + "balanced-match": [ + "1.0.2" + ], + "concat-map": [ + "0.0.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__run-parallel__1.2.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "run-parallel", + "version": "1.2.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "url": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__run-parallel__1.2.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": 
"npm_import_links", + "attributes": { + "package": "run-parallel", + "version": "1.2.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "queue-microtask": "1.2.3" + }, + "transitive_closure": { + "run-parallel": [ + "1.2.0" + ], + "queue-microtask": [ + "1.2.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__semver__7.5.4": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "semver", + "version": "7.5.4", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", + "url": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__semver__7.5.4__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "semver", + "version": "7.5.4", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "lru-cache": "6.0.0" + }, + "transitive_closure": { + "semver": [ + "7.5.4" + ], + "lru-cache": [ + "6.0.0" + ], + "yallist": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__shebang-command__2.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "shebang-command", + "version": "2.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "url": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__shebang-command__2.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "shebang-command", + "version": "2.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "shebang-regex": "3.0.0" + }, + "transitive_closure": { + "shebang-command": [ + "2.0.0" + ], + "shebang-regex": [ + "3.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__shebang-regex__3.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "shebang-regex", + "version": "3.0.0", + "root_package": "", + 
"link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "url": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__shebang-regex__3.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "shebang-regex", + "version": "3.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "shebang-regex": [ + "3.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__slash__3.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "slash", + "version": "3.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==", + "url": "https://registry.npmjs.org/slash/-/slash-3.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__slash__3.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "slash", + "version": "3.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "slash": [ + "3.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__strip-ansi__6.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "strip-ansi", + "version": "6.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "url": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__strip-ansi__6.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "strip-ansi", + "version": "6.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "ansi-regex": "5.0.1" + }, + "transitive_closure": { + "strip-ansi": [ + "6.0.1" + ], + "ansi-regex": [ + "5.0.1" + ] + }, + 
"lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__strip-json-comments__3.1.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "strip-json-comments", + "version": "3.1.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", + "url": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__strip-json-comments__3.1.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "strip-json-comments", + "version": "3.1.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "strip-json-comments": [ + "3.1.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__supports-color__7.2.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "supports-color", + "version": "7.2.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "url": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__supports-color__7.2.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "supports-color", + "version": "7.2.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "has-flag": "4.0.0" + }, + "transitive_closure": { + "supports-color": [ + "7.2.0" + ], + "has-flag": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__text-table__0.2.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "text-table", + "version": "0.2.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==", + "url": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + 
"npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__text-table__0.2.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "text-table", + "version": "0.2.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "text-table": [ + "0.2.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__to-regex-range__5.0.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "to-regex-range", + "version": "5.0.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "url": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__to-regex-range__5.0.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "to-regex-range", + "version": "5.0.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "is-number": "7.0.0" + }, + "transitive_closure": { + "to-regex-range": [ + "5.0.1" + ], + "is-number": [ + "7.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__ts-api-utils__1.0.3__typescript_5.3.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "ts-api-utils", + "version": "1.0.3_typescript_5.3.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-wNMeqtMz5NtwpT/UZGY5alT+VoKdSsOOP/kqHFcUW1P/VRhH2wJ48+DN2WwUliNbQ976ETwDL0Ifd2VVvgonvg==", + "url": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.0.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__ts-api-utils__1.0.3__typescript_5.3.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "ts-api-utils", + "version": "1.0.3_typescript_5.3.3", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "typescript": "5.3.3" + }, + "transitive_closure": { + "ts-api-utils": [ + "1.0.3_typescript_5.3.3" + ], + "typescript": [ + "5.3.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + 
"npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__type-check__0.4.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "type-check", + "version": "0.4.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==", + "url": "https://registry.npmjs.org/type-check/-/type-check-0.4.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__type-check__0.4.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "type-check", + "version": "0.4.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "prelude-ls": "1.2.1" + }, + "transitive_closure": { + "type-check": [ + "0.4.0" + ], + "prelude-ls": [ + "1.2.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__type-fest__0.20.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "type-fest", + "version": "0.20.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-Ne+eE4r0/iWnpAxD852z3A+N0Bt5RN//NjJwRd2VFHEmrywxf5vsZlh4R6lixl6B+wz/8d+maTSAkN1FIkI3LQ==", + "url": "https://registry.npmjs.org/type-fest/-/type-fest-0.20.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__type-fest__0.20.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "type-fest", + "version": "0.20.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "type-fest": [ + "0.20.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__typescript__5.3.3": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "typescript", + "version": "5.3.3", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": { + "": [ + "typescript" + ] + }, + "integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==", + "url": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } 
+ }, + "npm__typescript__5.3.3__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "typescript", + "version": "5.3.3", + "dev": true, + "root_package": "", + "link_packages": { + "": [ + "typescript" + ] + }, + "deps": {}, + "transitive_closure": { + "typescript": [ + "5.3.3" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__undici-types__5.26.5": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "undici-types", + "version": "5.26.5", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "url": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__undici-types__5.26.5__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "undici-types", + "version": "5.26.5", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "undici-types": [ + "5.26.5" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__uri-js__4.4.1": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "uri-js", + "version": "4.4.1", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "url": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__uri-js__4.4.1__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "uri-js", + "version": "4.4.1", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "punycode": "2.3.1" + }, + "transitive_closure": { + "uri-js": [ + "4.4.1" + ], + "punycode": [ + "2.3.1" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__which__2.0.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "which", + "version": "2.0.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + 
"integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "url": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__which__2.0.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "which", + "version": "2.0.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": { + "isexe": "2.0.0" + }, + "transitive_closure": { + "which": [ + "2.0.2" + ], + "isexe": [ + "2.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__wrappy__1.0.2": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "wrappy", + "version": "1.0.2", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "url": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__wrappy__1.0.2__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "wrappy", + "version": "1.0.2", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "wrappy": [ + "1.0.2" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__yallist__4.0.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "yallist", + "version": "4.0.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==", + "url": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__yallist__4.0.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "yallist", + "version": "4.0.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "yallist": [ + "4.0.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + 
"npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + }, + "npm__yocto-queue__0.1.0": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_rule", + "attributes": { + "package": "yocto-queue", + "version": "0.1.0", + "root_package": "", + "link_workspace": "flatbuffers~", + "link_packages": {}, + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "url": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "commit": "", + "patch_args": [], + "patches": [], + "custom_postinstall": "", + "npm_auth": "", + "npm_auth_basic": "", + "npm_auth_username": "", + "npm_auth_password": "", + "lifecycle_hooks": [], + "extra_build_content": "", + "generate_bzl_library_targets": false + } + }, + "npm__yocto-queue__0.1.0__links": { + "bzlFile": "@@aspect_rules_js~//npm/private:npm_import.bzl", + "ruleClassName": "npm_import_links", + "attributes": { + "package": "yocto-queue", + "version": "0.1.0", + "dev": true, + "root_package": "", + "link_packages": {}, + "deps": {}, + "transitive_closure": { + "yocto-queue": [ + "0.1.0" + ] + }, + "lifecycle_build_target": false, + "lifecycle_hooks_env": [], + "lifecycle_hooks_execution_requirements": [], + "bins": {}, + "npm_translate_lock_repo": "npm", + "package_visibility": [ + "//visibility:public" + ] + } + } + }, + "recordedRepoMappingEntries": [ + [ + "aspect_bazel_lib~", + "bazel_skylib", + "bazel_skylib~" + ], + [ + "aspect_bazel_lib~", + "bazel_tools", + "bazel_tools" + ], + [ + "aspect_rules_js~", + "aspect_bazel_lib", + "aspect_bazel_lib~" + ], + [ + "aspect_rules_js~", + "bazel_features", + "bazel_features~" + ], + [ + "aspect_rules_js~", + "bazel_skylib", + "bazel_skylib~" + ], + [ + "aspect_rules_js~", + "bazel_tools", + "bazel_tools" + ], + [ + "bazel_features~", + "bazel_tools", + "bazel_tools" + ] + ] + } + }, "@@platforms//host:extension.bzl%host_platform": { "general": { "bzlTransitiveDigest": "xelQcPZH8+tmuOHVjL9vDxMnnQNMlwj0SlvgoqBkm4U=", @@ -151,35 +11686,6 @@ "recordedRepoMappingEntries": [] } }, - "@@protobuf~//:non_module_deps.bzl%non_module_deps": { - "general": { - "bzlTransitiveDigest": "n42CE1R95fa5ddK2PVwgWYAZfG476FzMuRvz0zo5gs8=", - "usagesDigest": "1JwsUDre7ljlZoaD2WfcvUlKnXUonmxIKAVBQ82j6Ig=", - "recordedFileInputs": {}, - "recordedDirentsInputs": {}, - "envVariables": {}, - "generatedRepoSpecs": { - "utf8_range": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "urls": [ - "https://github.com/protocolbuffers/utf8_range/archive/de0b4a8ff9b5d4c98108bdfe723291a33c52c54f.zip" - ], - "strip_prefix": "utf8_range-de0b4a8ff9b5d4c98108bdfe723291a33c52c54f", - "sha256": "5da960e5e5d92394c809629a03af3c7709d2d3d0ca731dacb3a9fb4bf28f7702" - } - } - }, - "recordedRepoMappingEntries": [ - [ - "protobuf~", - "bazel_tools", - "bazel_tools" - ] - ] - } - }, "@@pybind11_bazel~//:internal_configure.bzl%internal_configure_extension": { "general": { "bzlTransitiveDigest": "CyAKLVVonohnkTSqg9II/HA7M49sOlnMkgMHL3CmDuc=", @@ -597,6 +12103,68 @@ ] } }, + "@@rules_go~//go:extensions.bzl%go_sdk": { + "general": { + "bzlTransitiveDigest": "6OpUR/yglzmu6OR0l9BvoXNEmRETCk2i9/mg6yhIbMA=", + "usagesDigest": "d+jWsKUXmjXLstb8Ps8lKcqQSS92aURSbdbgcoFp7Ao=", + "recordedFileInputs": {}, + "recordedDirentsInputs": {}, + "envVariables": {}, + "generatedRepoSpecs": { + "go_default_sdk": { + "bzlFile": "@@rules_go~//go/private:sdk.bzl", + 
"ruleClassName": "go_download_sdk_rule", + "attributes": { + "goos": "", + "goarch": "", + "sdks": {}, + "urls": [ + "https://dl.google.com/go/{}" + ], + "version": "1.20.2" + } + }, + "go_host_compatible_sdk_label": { + "bzlFile": "@@rules_go~//go/private:extensions.bzl", + "ruleClassName": "host_compatible_toolchain", + "attributes": { + "toolchain": "@go_default_sdk//:ROOT" + } + }, + "go_toolchains": { + "bzlFile": "@@rules_go~//go/private:sdk.bzl", + "ruleClassName": "go_multiple_toolchains", + "attributes": { + "prefixes": [ + "_0000_go_default_sdk_" + ], + "geese": [ + "" + ], + "goarchs": [ + "" + ], + "sdk_repos": [ + "go_default_sdk" + ], + "sdk_types": [ + "remote" + ], + "sdk_versions": [ + "1.20.2" + ] + } + } + }, + "recordedRepoMappingEntries": [ + [ + "rules_go~", + "bazel_tools", + "bazel_tools" + ] + ] + } + }, "@@rules_jvm_external~//:extensions.bzl%maven": { "general": { "bzlTransitiveDigest": "ZZwUwwzxkACVpF3u5nup1ClQKp1WEF5TLy//fGjPiKU=", @@ -1804,32 +13372,93 @@ ] } }, - "@@rules_jvm_external~//:non-module-deps.bzl%non_module_deps": { + "@@rules_nodejs~//nodejs:extensions.bzl%node": { "general": { - "bzlTransitiveDigest": "ZOivBbbZUakRexeLO/N26oX4Bcph6HHnqNmfxt7yoCc=", - "usagesDigest": "53kHAQcKNmL0k7OtizNBnaTWq84lbKdGYv7383Wp/fc=", + "bzlTransitiveDigest": "KOk+Te5m8n3d0B9F5+lgyrzLbtEzqeqWset0MugBbOY=", + "usagesDigest": "Hpfezedx02zaAfjLSaFZ52QtTY1hd57RMXgRknmjzA0=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, "generatedRepoSpecs": { - "io_bazel_rules_kotlin": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", + "nodejs_linux_amd64": { + "bzlFile": "@@rules_nodejs~//nodejs:repositories.bzl", + "ruleClassName": "node_repositories", "attributes": { - "sha256": "946747acdbeae799b085d12b240ec346f775ac65236dfcf18aa0cd7300f6de78", - "urls": [ - "https://github.com/bazelbuild/rules_kotlin/releases/download/v1.7.0-RC-2/rules_kotlin_release.tgz" - ] + "platform": "linux_amd64", + "node_version": "16.20.0" + } + }, + "nodejs_linux_arm64": { + "bzlFile": "@@rules_nodejs~//nodejs:repositories.bzl", + "ruleClassName": "node_repositories", + "attributes": { + "platform": "linux_arm64", + "node_version": "16.20.0" + } + }, + "nodejs_linux_s390x": { + "bzlFile": "@@rules_nodejs~//nodejs:repositories.bzl", + "ruleClassName": "node_repositories", + "attributes": { + "platform": "linux_s390x", + "node_version": "16.20.0" + } + }, + "nodejs_linux_ppc64le": { + "bzlFile": "@@rules_nodejs~//nodejs:repositories.bzl", + "ruleClassName": "node_repositories", + "attributes": { + "platform": "linux_ppc64le", + "node_version": "16.20.0" + } + }, + "nodejs_darwin_amd64": { + "bzlFile": "@@rules_nodejs~//nodejs:repositories.bzl", + "ruleClassName": "node_repositories", + "attributes": { + "platform": "darwin_amd64", + "node_version": "16.20.0" + } + }, + "nodejs_darwin_arm64": { + "bzlFile": "@@rules_nodejs~//nodejs:repositories.bzl", + "ruleClassName": "node_repositories", + "attributes": { + "platform": "darwin_arm64", + "node_version": "16.20.0" + } + }, + "nodejs_windows_amd64": { + "bzlFile": "@@rules_nodejs~//nodejs:repositories.bzl", + "ruleClassName": "node_repositories", + "attributes": { + "platform": "windows_amd64", + "node_version": "16.20.0" + } + }, + "nodejs": { + "bzlFile": "@@rules_nodejs~//nodejs/private:nodejs_repo_host_os_alias.bzl", + "ruleClassName": "nodejs_repo_host_os_alias", + "attributes": { + "user_node_repository_name": "nodejs" + } + }, + "nodejs_host": { + "bzlFile": 
"@@rules_nodejs~//nodejs/private:nodejs_repo_host_os_alias.bzl", + "ruleClassName": "nodejs_repo_host_os_alias", + "attributes": { + "user_node_repository_name": "nodejs" + } + }, + "nodejs_toolchains": { + "bzlFile": "@@rules_nodejs~//nodejs/private:toolchains_repo.bzl", + "ruleClassName": "toolchains_repo", + "attributes": { + "user_node_repository_name": "nodejs" } } }, - "recordedRepoMappingEntries": [ - [ - "rules_jvm_external~", - "bazel_tools", - "bazel_tools" - ] - ] + "recordedRepoMappingEntries": [] } }, "@@rules_python~//python/private/pypi:pip.bzl%pip_internal": { @@ -4116,35 +15745,6 @@ ] ] } - }, - "@@upb~//:non_module_deps.bzl%non_module_deps": { - "general": { - "bzlTransitiveDigest": "n42CE1R95fa5ddK2PVwgWYAZfG476FzMuRvz0zo5gs8=", - "usagesDigest": "jUN0s3TyKWQVNLdkIwSzKkk73kEAiVZpjP3qSq+wCWA=", - "recordedFileInputs": {}, - "recordedDirentsInputs": {}, - "envVariables": {}, - "generatedRepoSpecs": { - "utf8_range": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "urls": [ - "https://github.com/protocolbuffers/utf8_range/archive/de0b4a8ff9b5d4c98108bdfe723291a33c52c54f.zip" - ], - "strip_prefix": "utf8_range-de0b4a8ff9b5d4c98108bdfe723291a33c52c54f", - "sha256": "5da960e5e5d92394c809629a03af3c7709d2d3d0ca731dacb3a9fb4bf28f7702" - } - } - }, - "recordedRepoMappingEntries": [ - [ - "upb~", - "bazel_tools", - "bazel_tools" - ] - ] - } } } } diff --git a/framework/src/vx_context.cpp b/framework/src/vx_context.cpp index affb901d..084ad36c 100644 --- a/framework/src/vx_context.cpp +++ b/framework/src/vx_context.cpp @@ -34,6 +34,8 @@ vx_char targetModules[][VX_MAX_TARGET_NAME] = { #endif "openvx-c_model", "openvx-onnxRT", + "openvx-ai-server", + "openvx-liteRT", }; const vx_char extensions[] = diff --git a/include/VX/vx_corevx_ext.h b/include/VX/vx_corevx_ext.h index 2180b30e..7469c5be 100644 --- a/include/VX/vx_corevx_ext.h +++ b/include/VX/vx_corevx_ext.h @@ -1,6 +1,6 @@ /** * @file vx_corevx_ext.h - * @brief Extensions enabled for corevs + * @brief Extensions enabled for corevx * @version 0.1 * @date 2024-12-15 * @@ -13,6 +13,24 @@ #include #include +#ifdef __cplusplus +#include + +/*! \brief A character array (string) type. + * \note This is a C++ string type. It is not a C string. + * \ingroup group_basic_features + */ +using vx_string = std::string; +#endif /* __cplusplus */ + +/*! \brief The type enumeration lists additional types to extend the known types in OpenVX. + * \ingroup group_basic_features + */ +enum vx_type_ext_e +{ + VX_TYPE_STRING = 0x818, /*!< \brief A \ref vx_string. */ +}; + /*! \brief Define Edge AI Vendor ID * \ingroup group_basic_features */ @@ -30,6 +48,14 @@ enum vx_kernel_ext_e * \brief The ONNX Runtime CPU Inference kernel. */ VX_KERNEL_ORT_CPU_INF = VX_KERNEL_BASE(VX_ID_EDGE_AI, VX_LIBRARY_KHR_BASE) + 0x1, + /*! + * \brief The AI Model Server Chatbot kernel. + */ + VX_KERNEL_AIS_CHATBOT = VX_KERNEL_BASE(VX_ID_EDGE_AI, VX_LIBRARY_KHR_BASE) + 0x2, + /*! + * \brief The LiteRT CPU Inference kernel. + */ + VX_KERNEL_LITERT_CPU_INF = VX_KERNEL_BASE(VX_ID_EDGE_AI, VX_LIBRARY_KHR_BASE) + 0x3, }; /*! \brief addtitional tensor attributes. 
diff --git a/kernels/ai_server/BUILD b/kernels/ai_server/BUILD new file mode 100644 index 00000000..274f9687 --- /dev/null +++ b/kernels/ai_server/BUILD @@ -0,0 +1,20 @@ +cc_library( + name = "llm_kernels", + srcs = glob([ + "*.cpp", + ]), + hdrs = glob([ + "*.h", + "*.hpp", + ]), + includes = [ + ".", + "//framework/include" + ], + deps = [ + "//:corevx", + "@curl//:curl", + "@nlohmann_json//:json" + ], + visibility = ["//visibility:public"] +) \ No newline at end of file diff --git a/kernels/ai_server/chatbot.hpp b/kernels/ai_server/chatbot.hpp new file mode 100644 index 00000000..617c7c8b --- /dev/null +++ b/kernels/ai_server/chatbot.hpp @@ -0,0 +1,110 @@ +/** + * @file chatbot.hpp + * @brief Kernel for AI Model Server Chatbot + * @version 0.1 + * @date 2025-04-04 + * + * @copyright Copyright (c) 2025 + * + */ +#include +#include +#include +#include +#include + +#define DEFAULT_MODEL "gpt-4o-mini" +#define SERVER_URL "http://localhost:8000" +#define API_KEY "hardcoded-api-key" + +class RemoteModelClient +{ +private: + // Helper function for non-streaming response + static size_t WriteCallback(void *contents, size_t size, size_t nmemb, void *userp) + { + size_t totalSize = size * nmemb; + ((std::string *)userp)->append((char *)contents, totalSize); + return totalSize; + } + +public: + // kernel function (non-streaming) + vx_status AiServerQuery(const std::string &input_text, std::string &output_text, const std::string &api_path) + { + CURL *curl = curl_easy_init(); + if (!curl) + return VX_FAILURE; + + nlohmann::json request_json = { + {"model", DEFAULT_MODEL}, + {"messages", {{{"role", "user"}, {"content", input_text}}}}, + {"max_tokens", 100}, + {"stream", false}}; + + std::string request_payload = request_json.dump(); + std::string response_string; + std::string api_url = std::string(SERVER_URL) + api_path; + + struct curl_slist *headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, ("Authorization: Bearer " + std::string(API_KEY)).c_str()); + + curl_easy_setopt(curl, CURLOPT_URL, api_url.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request_payload.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_string); + + CURLcode res = curl_easy_perform(curl); + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + if (res != CURLE_OK) + return VX_FAILURE; + + auto json_response = nlohmann::json::parse(response_string); + output_text = json_response["choices"][0]["message"]["content"]; + + return VX_SUCCESS; + } + + // kernel function (streaming) + vx_status AiServerQueryStream(const std::string &input_text, std::string &output_text, const std::string &api_path) + { + CURL *curl = curl_easy_init(); + if (!curl) + return VX_FAILURE; + + nlohmann::json request_json = { + {"model", DEFAULT_MODEL}, + {"messages", {{{"role", "user"}, {"content", input_text}}}}, + {"max_tokens", 100}, + {"stream", true}}; + + std::string request_payload = request_json.dump(); + std::string response_chunk; + std::string api_url = std::string(SERVER_URL) + api_path; + + struct curl_slist *headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + headers = curl_slist_append(headers, ("Authorization: Bearer " + std::string(API_KEY)).c_str()); + + curl_easy_setopt(curl, CURLOPT_URL, api_url.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, 
request_payload.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response_chunk); + + CURLcode res = curl_easy_perform(curl); + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + if (res != CURLE_OK) + return VX_FAILURE; + + // Just return raw streamed response (newline-delimited JSON chunks) + output_text = response_chunk; + return VX_SUCCESS; + } +}; diff --git a/kernels/liteRT/BUILD b/kernels/liteRT/BUILD new file mode 100644 index 00000000..e9783d06 --- /dev/null +++ b/kernels/liteRT/BUILD @@ -0,0 +1,21 @@ + +cc_library( + name = "liteRT_kernels", + srcs = glob([ + "*.cpp", + ]), + hdrs = glob([ + "*.h", + "*.hpp", + ]), + includes = [ + ".", + "//framework/include", + ], + deps = [ + "//:corevx", + "//third_party:tflite", + "//third_party:tflite-hdrs", + ], + visibility = ["//visibility:public"] +) \ No newline at end of file diff --git a/kernels/liteRT/tflite.hpp b/kernels/liteRT/tflite.hpp new file mode 100644 index 00000000..c6c24e2b --- /dev/null +++ b/kernels/liteRT/tflite.hpp @@ -0,0 +1,255 @@ +/** + * @file tflite.hpp + * @brief + * @version 0.1 + * @date 2025-04-19 + * + * @copyright Copyright (c) 2025 + * + */ +#include +#include +#include + +#include "tensorflow/lite/core/interpreter_builder.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/optional_debug_tools.h" + +#define TFLITE_MINIMAL_CHECK(x) \ + if (!(x)) \ + { \ + fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ + return VX_FAILURE; \ + } + +/** + * @brief Class to run TFLite models + * + */ +class TFLiteRunner +{ +public: + /** + * @brief TFLiteRunner Constructor + */ + TFLiteRunner() : modelLoaded(false) {}; + + /** + * @brief Initialize the TFLite interpreter (load the model) + * @param filename Path to the ONNX model file + * @return VX_SUCCESS on success, VX_FAILURE otherwise + */ + vx_status init(std::string &filename) + { + TFLITE_MINIMAL_CHECK(false == filename.empty()) + + if (!modelLoaded) + { + // Load model + model = tflite::FlatBufferModel::BuildFromFile(filename.c_str()); + TFLITE_MINIMAL_CHECK(model != nullptr); + + // Build the interpreter with the InterpreterBuilder. + // Note: all Interpreters should be built with the InterpreterBuilder, + // which allocates memory for the Interpreter and does various set up + // tasks so that the Interpreter can read the provided model. 
+ tflite::ops::builtin::BuiltinOpResolver resolver; + tflite::InterpreterBuilder builder(*model, resolver); + builder(&interpreter); + TFLITE_MINIMAL_CHECK(interpreter != nullptr); + + printf("=== Pre-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); + } + + return VX_SUCCESS; + } + + /** + * @brief Validate input/output parameters + * @param inputDims Input tensor dimensions + * @param outputDims Output tensor dimensions + * @return VX_SUCCESS on success, VX_FAILURE otherwise + */ + vx_status validate(std::vector> &inputDims, std::vector> &outputDims) + { + vx_status status = VX_SUCCESS; + + // Validate input dimensions + if (inputDims.size() != interpreter->inputs().size()) + { + fprintf(stderr, "Mismatch in number of input tensors: expected %zu, got %zu\n", + inputDims.size(), interpreter->inputs().size()); + return VX_FAILURE; + } + + for (std::size_t i = 0; i < interpreter->inputs().size(); ++i) + { + TfLiteTensor *input_tensor = interpreter->tensor(interpreter->inputs()[i]); + if (input_tensor == nullptr) + { + fprintf(stderr, "Input tensor at index %zu is null.\n", i); + return VX_FAILURE; + } + + // Get the shape of the input tensor + std::vector tensor_shape(input_tensor->dims->size); + for (int j = 0; j < input_tensor->dims->size; ++j) + { + tensor_shape[j] = input_tensor->dims->data[j]; + } + + // Compare with the expected shape + if (tensor_shape != inputDims[i]) + { + fprintf(stderr, "Mismatch in input tensor %zu shape: expected {", i); + for (size_t dim : inputDims[i]) + fprintf(stderr, "%zu,", dim); + fprintf(stderr, "} but got {"); + for (size_t dim : tensor_shape) + fprintf(stderr, "%zu,", dim); + fprintf(stderr, "}\n"); + return VX_FAILURE; + } + } + + // Validate output dimensions + if (outputDims.size() != interpreter->outputs().size()) + { + fprintf(stderr, "Mismatch in number of output tensors: expected %zu, got %zu\n", + outputDims.size(), interpreter->outputs().size()); + return VX_FAILURE; + } + + for (std::size_t i = 0; i < interpreter->outputs().size(); ++i) + { + TfLiteTensor *output_tensor = interpreter->tensor(interpreter->outputs()[i]); + if (output_tensor == nullptr) + { + fprintf(stderr, "Output tensor at index %zu is null.\n", i); + return VX_FAILURE; + } + + // Get the shape of the output tensor + std::vector tensor_shape(output_tensor->dims->size); + for (int j = 0; j < output_tensor->dims->size; ++j) + { + tensor_shape[j] = output_tensor->dims->data[j]; + } + + // Compare with the expected shape + if (tensor_shape != outputDims[i]) + { + fprintf(stderr, "Mismatch in output tensor %zu shape: expected {", i); + for (size_t dim : outputDims[i]) + fprintf(stderr, "%zu,", dim); + fprintf(stderr, "} but got {"); + for (size_t dim : tensor_shape) + fprintf(stderr, "%zu,", dim); + fprintf(stderr, "}\n"); + return VX_FAILURE; + } + } + + return status; + } + + /** + * @brief Allocate memory for input and output tensors + * @param inputTensors Input tensors + * @param outputTensors Output tensors + * @return VX_SUCCESS on success, VX_FAILURE otherwise + */ + vx_status allocate(std::vector> &inputTensors, std::vector> &outputTensors) + { + vx_status status = VX_SUCCESS; + + // Fill input buffers + // TODO(user): Insert code to fill input tensors. 
+ // Note: The buffer of the input tensor with index `i` of type T can + // be accessed with `T* input = interpreter->typed_input_tensor(i);` + for (std::size_t i = 0; i < interpreter->inputs().size(); ++i) + { + status = bindMemory(interpreter->inputs()[i], inputTensors[i].first, inputTensors[i].second); + } + + // Read output buffers + // TODO(user): Insert getting data out code. + // Note: The buffer of the output tensor with index `i` of type T can + // be accessed with `T* output = interpreter->typed_output_tensor(i);` + for (std::size_t i = 0; i < interpreter->outputs().size(); ++i) + { + status |= bindMemory(interpreter->outputs()[i], outputTensors[i].first, outputTensors[i].second); + } + + // Allocate tensor buffers. + TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); + + return status; + } + + /** + * @brief Run the kernel (execute the model) + * @param inputTensors Input tensors + * @param outputTensosrs Output tensors + * @return VX_SUCCESS on success, VX_FAILURE otherwise + */ + vx_status run() + { + // Run inference + TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); + printf("\n\n=== Post-invoke Interpreter State ===\n"); + tflite::PrintInterpreterState(interpreter.get()); + return VX_SUCCESS; + } + +private: + bool modelLoaded = false; + std::unique_ptr model; + // Pointer to the TFLite interpreter + std::unique_ptr interpreter; + + /** + * @brief Bind pre-allocated memory to a tensor + * @param tensor_index Index of the tensor to bind + * @param pre_allocated_memory Pointer to the pre-allocated memory + * @param size_in_bytes Size of the pre-allocated memory in bytes + * @return VX_SUCCESS on success, VX_FAILURE otherwise + */ + vx_status bindMemory(int tensor_index, void* pre_allocated_memory, size_t size_in_bytes) + { + vx_status status = VX_SUCCESS; + + // Get the tensor + TfLiteTensor* tensor = interpreter->tensor(tensor_index); + + // Check if the tensor exists + if (tensor == nullptr) + { + fprintf(stderr, "Tensor at index %d does not exist.\n", tensor_index); + status = VX_FAILURE; + } + + // Ensure the tensor type and size match your pre-allocated memory + if (VX_SUCCESS == status && + tensor->bytes != size_in_bytes) + { + fprintf(stderr, "Pre-allocated memory size (%ld) does not match tensor size (%ld).\n", + size_in_bytes, tensor->bytes); + status = VX_FAILURE; + } + + if (VX_SUCCESS == status) + { + // Bind the pre-allocated memory to the tensor + TFLITE_MINIMAL_CHECK(kTfLiteOk == interpreter->SetCustomAllocationForTensor( + tensor_index, + {pre_allocated_memory, size_in_bytes}, + kTfLiteCustomAllocationFlagsSkipAlignCheck)); + } + + return status; + } +}; diff --git a/targets/ai_server/BUILD b/targets/ai_server/BUILD new file mode 100644 index 00000000..6ee430a1 --- /dev/null +++ b/targets/ai_server/BUILD @@ -0,0 +1,32 @@ + +cc_library( + name = "ai-server", + srcs = glob([ + "*.cpp", + "*.h", + ]), + includes = [ + ".", + "//framework/include", + "//kernels/ai-server", + ], + deps = [ + "//:corevx", + "//kernels/ai_server:llm_kernels" + ], + visibility = ["//visibility:public"] +) + +cc_shared_library( + name = "openvx-ai-server", + deps = [ + ":ai-server", + ], + visibility = ["//visibility:public"] +) + +cc_import( + name = "imported_openvx_ai_server", + shared_library = ":openvx-ai-server", + visibility = ["//visibility:public"] +) diff --git a/targets/ai_server/vx_chatbot.cpp b/targets/ai_server/vx_chatbot.cpp new file mode 100644 index 00000000..6e3ab048 --- /dev/null +++ b/targets/ai_server/vx_chatbot.cpp @@ -0,0 +1,108 @@ +/** + 
* @file vx_chatbot.cpp + * @brief OpenVX Interface Into AI Model Server + * @version 0.1 + * @date 2025-01-20 + * + * @copyright Copyright (c) 2025 + * + */ +#include +#include +#include + +#include +#include +#include +#include + +#include "chatbot.hpp" +#include "vx_internal.h" + +// Create an instance of ORT runner +static const std::shared_ptr kernel = std::make_shared(); + +static std::unordered_map api_map = { + {"chat", "/v1/chat/completions"}, +}; + +class VxRemoteModelClient +{ +private: + static vx_status store_vx_string_to_array(vx_array arr, const vx_string &in) + { + vx_status status = vxTruncateArray(arr, 0); // clear existing contents + if (status != VX_SUCCESS) + return status; + + return vxAddArrayItems(arr, in.size(), in.data(), sizeof(char)); + } + + static vx_status load_vx_string_from_array(vx_array arr, vx_string &out) + { + vx_size size = 0; + vx_status status = vxQueryArray(arr, VX_ARRAY_ATTRIBUTE_NUMITEMS, &size, sizeof(size)); + if (status != VX_SUCCESS || size == 0) + return VX_FAILURE; + + out.resize(size); // allocate space directly in std::string + status = vxCopyArrayRange(arr, 0, size, sizeof(char), out.data(), VX_READ_ONLY, VX_MEMORY_TYPE_HOST); + return status; + } + +public: + static constexpr vx_param_description_t kernelParams[] = { + {VX_INPUT, VX_TYPE_ARRAY, VX_PARAMETER_STATE_REQUIRED}, // Parameter 0: Input text + {VX_OUTPUT, VX_TYPE_ARRAY, VX_PARAMETER_STATE_REQUIRED}, // Parameter 1: Output text + }; + + static vx_status VX_CALLBACK init(vx_node node, const vx_reference parameters[], vx_uint32 num) + { + (void)node; + (void)parameters; + (void)num; + return VX_SUCCESS; + } + + static vx_status VX_CALLBACK validate(vx_node node, const vx_reference parameters[], vx_uint32 num, vx_meta_format metas[]) + { + (void)node; + (void)parameters; + (void)num; + (void)metas; + return VX_SUCCESS; + } + + static vx_status VX_CALLBACK run(vx_node node, const vx_reference *parameters, vx_uint32 num) + { + (void)node; + (void)parameters; + (void)num; + vx_status status = VX_SUCCESS; + vx_string input_text, output_text; + + status = load_vx_string_from_array((vx_array)parameters[0], input_text); + status |= kernel->AiServerQuery( + input_text, // Input text + output_text, // Output text + api_map["chat"]); // API path + status |= store_vx_string_to_array((vx_array)parameters[1], output_text); + + return status; + } +}; + +/** + * @brief Ai Model Server Chatbot Kernel description structure + */ +vx_kernel_description_t chatbot_kernel = { + VX_KERNEL_AIS_CHATBOT, // Unique kernel ID + "remote.model.chat", // Kernel name + VxRemoteModelClient::run, // Kernel execution function + const_cast(VxRemoteModelClient::kernelParams), + dimof(VxRemoteModelClient::kernelParams), // Number of parameters + VxRemoteModelClient::validate, // Kernel validation function + nullptr, + nullptr, + VxRemoteModelClient::init, // Kernel initialization function + nullptr}; \ No newline at end of file diff --git a/targets/ai_server/vx_interface.cpp b/targets/ai_server/vx_interface.cpp new file mode 100644 index 00000000..c4b3e3a8 --- /dev/null +++ b/targets/ai_server/vx_interface.cpp @@ -0,0 +1,258 @@ +/** + * @file vx_interface.cpp + * @brief AI Model Server Target Interface + * @version 0.1 + * @date 2025-01-20 + * + * @copyright Copyright (c) 2025 + * + */ + +/*! + * \file + * \brief The AI Model Server Target Interface + */ + +#include + +#include "vx_internal.h" +#include "vx_interface.h" + +static const vx_char name[VX_MAX_TARGET_NAME] = "corevx.ai.server"; + +/*! 
\brief Declares the list of all supported base kernels. + * \ingroup group_implementation + * \note This is the list of all supported base kernels! It must at least + * match the OpenVX 1.0 Specification. + */ +static vx_kernel_description_t *target_kernels[] = + { + &chatbot_kernel}; + +/*! \brief Declares the number of base supported kernels. + * \ingroup group_implementation + */ +static vx_uint32 num_target_kernels = dimof(target_kernels); + +/******************************************************************************/ +/* EXPORTED FUNCTIONS */ +/******************************************************************************/ +extern "C" vx_status vxTargetInit(vx_target target) +{ + if (target) + { + strncpy(target->name, name, VX_MAX_TARGET_NAME); + target->priority = VX_TARGET_PRIORITY_ORT; + } + return target->initializeTarget(target_kernels, num_target_kernels); +} + +extern "C" vx_status vxTargetDeinit(vx_target target) +{ + return target->deinitializeTarget(); +} + +extern "C" vx_status vxTargetSupports(vx_target target, + vx_char targetName[VX_MAX_TARGET_NAME], + vx_char kernelName[VX_MAX_KERNEL_NAME], + vx_uint32 *pIndex) +{ + vx_status status = VX_ERROR_NOT_SUPPORTED; + if (strncmp(targetName, name, VX_MAX_TARGET_NAME) == 0) + { + vx_uint32 k = 0u; + for (k = 0u; k < VX_INT_MAX_KERNELS; k++) + { + vx_char targetKernelName[VX_MAX_KERNEL_NAME]; + vx_char *kernel; + vx_char def[8] = "default"; + + if (target->kernels[k]) + { + strncpy(targetKernelName, target->kernels[k]->name, VX_MAX_KERNEL_NAME); + kernel = strtok(targetKernelName, ":"); + if (kernel == nullptr) + { + kernel = def; + } + + if (strncmp(kernelName, kernel, VX_MAX_KERNEL_NAME) == 0) + { + status = VX_SUCCESS; + if (pIndex) + *pIndex = k; + break; + } + } + } + } + return status; +} + +extern "C" vx_action vxTargetProcess(vx_target target, vx_node nodes[], vx_size startIndex, vx_size numNodes) +{ + vx_action action = VX_ACTION_CONTINUE; + vx_status status = VX_SUCCESS; + vx_size n = 0; + (void)target; + + for (n = startIndex; (n < (startIndex + numNodes)) && (action == VX_ACTION_CONTINUE); n++) + { + vx_context context = vxGetContext((vx_reference)nodes[n]); + VX_PRINT(VX_ZONE_GRAPH, "Executing Kernel %s:%d in Nodes[%u] on target %s\n", + nodes[n]->kernel->name, + nodes[n]->kernel->enumeration, + n, + nodes[n]->context->targets[nodes[n]->affinity]->name); + + if (context->perf_enabled) + Osal::startCapture(&nodes[n]->perf); + + if (nodes[n]->is_replicated == vx_true_e) + { + vx_size num_replicas = 0; + vx_uint32 param; + vx_uint32 num_parameters = nodes[n]->kernel->signature.num_parameters; + vx_reference parameters[VX_INT_MAX_PARAMS] = {nullptr}; + + for (param = 0; param < num_parameters; ++param) + { + if (nodes[n]->replicated_flags[param] == vx_true_e) + { + vx_size numItems = 0; + if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_PYRAMID) + { + vx_pyramid pyr = (vx_pyramid)(nodes[n]->parameters[param])->scope; + numItems = pyr->numLevels; + } + else if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_OBJECT_ARRAY) + { + vx_object_array arr = (vx_object_array)(nodes[n]->parameters[param])->scope; + numItems = arr->num_items; + } + else + { + status = VX_ERROR_INVALID_PARAMETERS; + break; + } + + if (num_replicas == 0) + num_replicas = numItems; + else if (numItems != num_replicas) + { + status = VX_ERROR_INVALID_PARAMETERS; + break; + } + } + else + { + parameters[param] = nodes[n]->parameters[param]; + } + } + + if (status == VX_SUCCESS) + { + vx_size replica; + for (replica = 0; replica < 
num_replicas; ++replica) + { + for (param = 0; param < num_parameters; ++param) + { + if (nodes[n]->replicated_flags[param] == vx_true_e) + { + if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_PYRAMID) + { + vx_pyramid pyr = (vx_pyramid)(nodes[n]->parameters[param])->scope; + parameters[param] = (vx_reference)pyr->levels[replica]; + } + else if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_OBJECT_ARRAY) + { + vx_object_array arr = (vx_object_array)(nodes[n]->parameters[param])->scope; + parameters[param] = (vx_reference)arr->items[replica]; + } + } + } + + status = nodes[n]->kernel->function((vx_node)nodes[n], + parameters, + num_parameters); + } + } + } + else + { + status = nodes[n]->kernel->function((vx_node)nodes[n], + (vx_reference *)nodes[n]->parameters, + nodes[n]->kernel->signature.num_parameters); + } + + nodes[n]->executed = vx_true_e; + nodes[n]->status = status; + + if (context->perf_enabled) + Osal::stopCapture(&nodes[n]->perf); + + VX_PRINT(VX_ZONE_GRAPH, "kernel %s returned %d\n", nodes[n]->kernel->name, status); + + if (status == VX_SUCCESS) + { + /* call the callback if it is attached */ + if (nodes[n]->callback) + { + action = nodes[n]->callback((vx_node)nodes[n]); + VX_PRINT(VX_ZONE_GRAPH, "callback returned action %d\n", action); + } + } + else + { + action = VX_ACTION_ABANDON; + VX_PRINT(VX_ZONE_ERROR, "Abandoning Graph due to error (%d)!\n", status); + } + } + return action; +} + +extern "C" vx_status vxTargetVerify(vx_target target, vx_node node) +{ + vx_status status = VX_SUCCESS; + (void)target; + (void)node; + + return status; +} + +extern "C" vx_kernel vxTargetAddKernel(vx_target target, + vx_char name[VX_MAX_KERNEL_NAME], + vx_enum enumeration, + vx_kernel_f func_ptr, + vx_uint32 numParams, + vx_kernel_validate_f validate, + vx_kernel_input_validate_f input, + vx_kernel_output_validate_f output, + vx_kernel_initialize_f initialize, + vx_kernel_deinitialize_f deinitialize) +{ + VX_PRINT(VX_ZONE_INFO, "Entered %s\n", __func__); + vx_uint32 k = 0u; + vx_kernel kernel = nullptr; + Osal::semWait(&target->lock); + + for (k = 0; k < VX_INT_MAX_KERNELS; k++) + { + if (target->kernels[k] == nullptr || target->kernels[k]->enabled == vx_false_e) + { + target->kernels[k] = reinterpret_cast(Reference::createReference(target->context, VX_TYPE_KERNEL, VX_INTERNAL, target->context)); + target->kernels[k]->initializeKernel(enumeration, func_ptr, name, + nullptr, numParams, + validate, input, output, + initialize, deinitialize); + VX_PRINT(VX_ZONE_KERNEL, "Reserving %s Kernel[%u] for %s\n", target->name, k, target->kernels[k]->name); + target->num_kernels++; + kernel = target->kernels[k]; + break; + } + kernel = nullptr; + } + Osal::semPost(&target->lock); + + return kernel; +} diff --git a/targets/ai_server/vx_interface.h b/targets/ai_server/vx_interface.h new file mode 100644 index 00000000..7d35de14 --- /dev/null +++ b/targets/ai_server/vx_interface.h @@ -0,0 +1,17 @@ +/** + * @file vx_interface.h + * @brief AI Model Server Target Interface + * @version 0.1 + * @date 2025-01-20 + * + * @copyright Copyright (c) 2025 + * + */ +#ifndef OPENVX_INTERFACE_H +#define OPENVX_INTERFACE_H + +#include + +extern vx_kernel_description_t chatbot_kernel; + +#endif /* OPENVX_INTERFACE_H */ diff --git a/targets/liteRT/BUILD b/targets/liteRT/BUILD new file mode 100644 index 00000000..6bf9f710 --- /dev/null +++ b/targets/liteRT/BUILD @@ -0,0 +1,29 @@ + +cc_library( + name = "liteRT", + srcs = glob([ + "*.cpp", + "*.h", + ]), + includes = [ + ".", + "//framework/include", + ], + 
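+    # NOTE: Bazel's `includes` attribute takes package-relative include paths,
+    # not labels, so "//framework/include" above may not resolve as intended;
+    # the header search path may need to come from the `//:corevx` dependency instead.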
deps = [ + "//:corevx", + "//kernels/liteRT:liteRT_kernels", + ], + visibility = ["//visibility:public"] +) + +cc_shared_library( + name = "openvx-liteRT", + deps = [":liteRT"], + visibility = ["//visibility:public"] +) + +cc_import( + name = "imported_openvx_liteRT", + shared_library = ":openvx-liteRT", + visibility = ["//visibility:public"] +) \ No newline at end of file diff --git a/targets/liteRT/vx_interface.cpp b/targets/liteRT/vx_interface.cpp new file mode 100644 index 00000000..52335621 --- /dev/null +++ b/targets/liteRT/vx_interface.cpp @@ -0,0 +1,258 @@ +/** + * @file vx_interface.cpp + * @brief TFLITE Runtime Target Interface + * @version 0.1 + * @date 2025-01-20 + * + * @copyright Copyright (c) 2025 + * + */ + +/*! + * \file + * \brief The TFLITE-RT Target Interface + */ + +#include + +#include "vx_internal.h" +#include "vx_interface.h" + +static const vx_char name[VX_MAX_TARGET_NAME] = "corevx.tflite.rt"; + +/*! \brief Declares the list of all supported base kernels. + * \ingroup group_implementation + * \note This is the list of all supported base kernels! It must at least + * match the OpenVX 1.0 Specification. + */ +static vx_kernel_description_t *target_kernels[] = + { + &tflite_cpu_inf_kernel}; + +/*! \brief Declares the number of base supported kernels. + * \ingroup group_implementation + */ +static vx_uint32 num_target_kernels = dimof(target_kernels); + +/******************************************************************************/ +/* EXPORTED FUNCTIONS */ +/******************************************************************************/ +extern "C" vx_status vxTargetInit(vx_target target) +{ + if (target) + { + strncpy(target->name, name, VX_MAX_TARGET_NAME); + target->priority = VX_TARGET_PRIORITY_ORT; + } + return target->initializeTarget(target_kernels, num_target_kernels); +} + +extern "C" vx_status vxTargetDeinit(vx_target target) +{ + return target->deinitializeTarget(); +} + +extern "C" vx_status vxTargetSupports(vx_target target, + vx_char targetName[VX_MAX_TARGET_NAME], + vx_char kernelName[VX_MAX_KERNEL_NAME], + vx_uint32 *pIndex) +{ + vx_status status = VX_ERROR_NOT_SUPPORTED; + if (strncmp(targetName, name, VX_MAX_TARGET_NAME) == 0) + { + vx_uint32 k = 0u; + for (k = 0u; k < VX_INT_MAX_KERNELS; k++) + { + vx_char targetKernelName[VX_MAX_KERNEL_NAME]; + vx_char *kernel; + vx_char def[8] = "default"; + + if (target->kernels[k]) + { + strncpy(targetKernelName, target->kernels[k]->name, VX_MAX_KERNEL_NAME); + kernel = strtok(targetKernelName, ":"); + if (kernel == nullptr) + { + kernel = def; + } + + if (strncmp(kernelName, kernel, VX_MAX_KERNEL_NAME) == 0) + { + status = VX_SUCCESS; + if (pIndex) + *pIndex = k; + break; + } + } + } + } + return status; +} + +extern "C" vx_action vxTargetProcess(vx_target target, vx_node nodes[], vx_size startIndex, vx_size numNodes) +{ + vx_action action = VX_ACTION_CONTINUE; + vx_status status = VX_SUCCESS; + vx_size n = 0; + (void)target; + + for (n = startIndex; (n < (startIndex + numNodes)) && (action == VX_ACTION_CONTINUE); n++) + { + vx_context context = vxGetContext((vx_reference)nodes[n]); + VX_PRINT(VX_ZONE_GRAPH, "Executing Kernel %s:%d in Nodes[%u] on target %s\n", + nodes[n]->kernel->name, + nodes[n]->kernel->enumeration, + n, + nodes[n]->context->targets[nodes[n]->affinity]->name); + + if (context->perf_enabled) + Osal::startCapture(&nodes[n]->perf); + + if (nodes[n]->is_replicated == vx_true_e) + { + vx_size num_replicas = 0; + vx_uint32 param; + vx_uint32 num_parameters = 
nodes[n]->kernel->signature.num_parameters; + vx_reference parameters[VX_INT_MAX_PARAMS] = {nullptr}; + + for (param = 0; param < num_parameters; ++param) + { + if (nodes[n]->replicated_flags[param] == vx_true_e) + { + vx_size numItems = 0; + if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_PYRAMID) + { + vx_pyramid pyr = (vx_pyramid)(nodes[n]->parameters[param])->scope; + numItems = pyr->numLevels; + } + else if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_OBJECT_ARRAY) + { + vx_object_array arr = (vx_object_array)(nodes[n]->parameters[param])->scope; + numItems = arr->num_items; + } + else + { + status = VX_ERROR_INVALID_PARAMETERS; + break; + } + + if (num_replicas == 0) + num_replicas = numItems; + else if (numItems != num_replicas) + { + status = VX_ERROR_INVALID_PARAMETERS; + break; + } + } + else + { + parameters[param] = nodes[n]->parameters[param]; + } + } + + if (status == VX_SUCCESS) + { + vx_size replica; + for (replica = 0; replica < num_replicas; ++replica) + { + for (param = 0; param < num_parameters; ++param) + { + if (nodes[n]->replicated_flags[param] == vx_true_e) + { + if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_PYRAMID) + { + vx_pyramid pyr = (vx_pyramid)(nodes[n]->parameters[param])->scope; + parameters[param] = (vx_reference)pyr->levels[replica]; + } + else if ((nodes[n]->parameters[param])->scope->type == VX_TYPE_OBJECT_ARRAY) + { + vx_object_array arr = (vx_object_array)(nodes[n]->parameters[param])->scope; + parameters[param] = (vx_reference)arr->items[replica]; + } + } + } + + status = nodes[n]->kernel->function((vx_node)nodes[n], + parameters, + num_parameters); + } + } + } + else + { + status = nodes[n]->kernel->function((vx_node)nodes[n], + (vx_reference *)nodes[n]->parameters, + nodes[n]->kernel->signature.num_parameters); + } + + nodes[n]->executed = vx_true_e; + nodes[n]->status = status; + + if (context->perf_enabled) + Osal::stopCapture(&nodes[n]->perf); + + VX_PRINT(VX_ZONE_GRAPH, "kernel %s returned %d\n", nodes[n]->kernel->name, status); + + if (status == VX_SUCCESS) + { + /* call the callback if it is attached */ + if (nodes[n]->callback) + { + action = nodes[n]->callback((vx_node)nodes[n]); + VX_PRINT(VX_ZONE_GRAPH, "callback returned action %d\n", action); + } + } + else + { + action = VX_ACTION_ABANDON; + VX_PRINT(VX_ZONE_ERROR, "Abandoning Graph due to error (%d)!\n", status); + } + } + return action; +} + +extern "C" vx_status vxTargetVerify(vx_target target, vx_node node) +{ + vx_status status = VX_SUCCESS; + (void)target; + (void)node; + + return status; +} + +extern "C" vx_kernel vxTargetAddKernel(vx_target target, + vx_char name[VX_MAX_KERNEL_NAME], + vx_enum enumeration, + vx_kernel_f func_ptr, + vx_uint32 numParams, + vx_kernel_validate_f validate, + vx_kernel_input_validate_f input, + vx_kernel_output_validate_f output, + vx_kernel_initialize_f initialize, + vx_kernel_deinitialize_f deinitialize) +{ + VX_PRINT(VX_ZONE_INFO, "Entered %s\n", __func__); + vx_uint32 k = 0u; + vx_kernel kernel = nullptr; + Osal::semWait(&target->lock); + + for (k = 0; k < VX_INT_MAX_KERNELS; k++) + { + if (target->kernels[k] == nullptr || target->kernels[k]->enabled == vx_false_e) + { + target->kernels[k] = reinterpret_cast(Reference::createReference(target->context, VX_TYPE_KERNEL, VX_INTERNAL, target->context)); + target->kernels[k]->initializeKernel(enumeration, func_ptr, name, + nullptr, numParams, + validate, input, output, + initialize, deinitialize); + VX_PRINT(VX_ZONE_KERNEL, "Reserving %s Kernel[%u] for %s\n", 
target->name, k, target->kernels[k]->name); + target->num_kernels++; + kernel = target->kernels[k]; + break; + } + kernel = nullptr; + } + Osal::semPost(&target->lock); + + return kernel; +} diff --git a/targets/liteRT/vx_interface.h b/targets/liteRT/vx_interface.h new file mode 100644 index 00000000..ec0e7d4b --- /dev/null +++ b/targets/liteRT/vx_interface.h @@ -0,0 +1,17 @@ +/** + * @file vx_interface.h + * @brief TFLITE Runtime Target Interface + * @version 0.1 + * @date 2025-01-20 + * + * @copyright Copyright (c) 2025 + * + */ +#ifndef OPENVX_INTERFACE_H +#define OPENVX_INTERFACE_H + +#include + +extern vx_kernel_description_t tflite_cpu_inf_kernel; + +#endif /* OPENVX_INTERFACE_H */ diff --git a/targets/liteRT/vx_litert_inf.cpp b/targets/liteRT/vx_litert_inf.cpp new file mode 100644 index 00000000..43168350 --- /dev/null +++ b/targets/liteRT/vx_litert_inf.cpp @@ -0,0 +1,280 @@ +/** + * @file vx_ort_inf.cpp + * @brief OpenVX Interface Into LiteRT + * @version 0.1 + * @date 2025-01-20 + * + * @copyright Copyright (c) 2025 + * + */ +#include + +#include +#include +#include +#include + +#include "tflite.hpp" +#include "vx_internal.h" + +// Create an instance of ORT runner +static const std::shared_ptr kernel = std::make_shared(); + +class VxLiteRTRunner +{ +public: + static constexpr vx_param_description_t kernelParams[] = { + {VX_INPUT, VX_TYPE_ARRAY, VX_PARAMETER_STATE_REQUIRED}, // Parameter 0: Model path + {VX_INPUT, VX_TYPE_OBJECT_ARRAY, VX_PARAMETER_STATE_REQUIRED}, // Parameter 1: Input tensors + {VX_OUTPUT, VX_TYPE_OBJECT_ARRAY, VX_PARAMETER_STATE_REQUIRED} // Parameter 2: Output tensors + }; + + // Initialization function + static vx_status VX_CALLBACK litertInitWrapper(vx_node node, const vx_reference parameters[], vx_uint32 num) + { + vx_status status = VX_SUCCESS; + std::string modelPath; + // Get the tensor pointers, total size of each, and cache them in a vector of pairs + std::vector> inputTensors; + std::vector> outputTensors; + // Get the tensor dimensions + std::vector> inputDims; + std::vector> outputDims; + + if (nullptr == node || + nullptr == parameters || + num != dimof(kernelParams)) + { + status = VX_FAILURE; + } + + if (VX_SUCCESS == status) + { + // Get the model path from the first parameter + vx_array array = (vx_array)parameters[0]; + status = readStringFromVxArray(array, modelPath); + + if (VX_SUCCESS == status) + { + VX_PRINT(VX_ZONE_INFO, "Reading from model path: %s\n", modelPath.c_str()); + // Initialize the kernel with the model path + status |= kernel->init(modelPath); + } + } + + if (VX_SUCCESS == status) + { + // Process input tensors + status = processTensors((vx_object_array)parameters[1], inputTensors); + // Process output tensors if input processing was successful + status |= processTensors((vx_object_array)parameters[2], outputTensors); + } + + if (VX_SUCCESS == status) + { + // Bind the input and output tensors + status = kernel->allocate(inputTensors, outputTensors); + } + + if (VX_SUCCESS == status) + { + // Get the input tensor dimensions from the tensors + status = processTensorDims(reinterpret_cast(parameters[1]), inputDims); + // Get the output tensor dimensions from the tensors + status = processTensorDims(reinterpret_cast(parameters[2]), outputDims); + } + + if (VX_SUCCESS == status) + { + // Call the validate member function + status = kernel->validate(inputDims, outputDims); + } + + return status; + } + + // Validation function + static vx_status VX_CALLBACK litertValidateWrapper(vx_node node, const vx_reference parameters[], 
vx_uint32 num, vx_meta_format metas[]) + { + vx_status status = VX_SUCCESS; + + if (nullptr == node || + nullptr == parameters || + num != dimof(kernelParams) || + nullptr == metas) + { + std::cerr << "Error: Invalid parameters during validation!" << std::endl; + status = VX_FAILURE; + } + + if (VX_SUCCESS == status) + { + // Retrieve the kernel instance from the node's local data + if (!kernel) + { + std::cerr << "Error: Kernel instance is null during validation!" << std::endl; + status = VX_FAILURE; + } + } + + if (VX_SUCCESS == status) + { + vx_object_array outputObjArr = reinterpret_cast(parameters[2]); + vx_size numItems = 0; + vx_enum itemType = VX_TYPE_TENSOR; + + status = vxQueryObjectArray(outputObjArr, VX_OBJECT_ARRAY_NUMITEMS, &numItems, sizeof(numItems)); + status |= vxSetMetaFormatAttribute(metas[2], VX_OBJECT_ARRAY_NUMITEMS, &numItems, sizeof(numItems)); + status |= vxSetMetaFormatAttribute(metas[2], VX_OBJECT_ARRAY_ITEMTYPE, &itemType, sizeof(vx_enum)); + } + + return status; + } + + // Execution function + static vx_status VX_CALLBACK litertRunWrapper(vx_node node, const vx_reference *parameters, vx_uint32 num) + { + vx_status status = VX_SUCCESS; + + if (nullptr == node || + nullptr == parameters || + num != dimof(kernelParams)) + { + status = VX_FAILURE; + } + + if (VX_SUCCESS == status) + { + // Retrieve the kernel instance from the node's local data + if (!kernel) + { + std::cerr << "Error: Kernel instance is null during execution!" << std::endl; + status = VX_FAILURE; + } + } + + if (VX_SUCCESS == status) + { + // Call the run member function + status = kernel->run(); + } + + return status; + } + +private: + /** + * @brief Helper function to read a string from a VX char array + * + * @param[in] array openvx char array to read from + * @param[out] str Output string containing the read data + * @return vx_status VX_SUCCESS on success, otherwise an error code + */ + static vx_status readStringFromVxArray(vx_array array, std::string &str) + { + vx_status status = VX_SUCCESS; + vx_size num_items = 0u, stride = 0u; + vx_map_id map_id = 0; + void *ptr = nullptr; + + status = vxQueryArray(array, VX_ARRAY_ATTRIBUTE_NUMITEMS, &num_items, sizeof(num_items)); + status |= vxMapArrayRange(array, 0, num_items, &map_id, &stride, &ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST, VX_NOGAP_X); + + if (VX_SUCCESS == status) + { + str = std::string(static_cast(ptr)); + status |= vxUnmapArrayRange(array, map_id); + } + + return status; + } + + /** + * @brief Helper function to process tensor dimensions from an object array + * + * @param[in] objArr Object array containing tensors + * @param[out] dims Vector of vectors containing tensor dimensions + * @return vx_status VX_SUCCESS on success, otherwise an error code + */ + static vx_status processTensorDims(vx_object_array objArr, std::vector> &dims) + { + vx_status status = VX_SUCCESS; + vx_size numItems = 0, numDims = 0; + std::vector tensorDims; + + status = vxQueryObjectArray(objArr, VX_OBJECT_ARRAY_NUMITEMS, &numItems, sizeof(numItems)); + + for (vx_uint32 i = 0; i < numItems && status == VX_SUCCESS; ++i) + { + vx_tensor tensor = reinterpret_cast(vxGetObjectArrayItem(objArr, i)); + status |= vxQueryTensor(tensor, VX_TENSOR_NUMBER_OF_DIMS, &numDims, sizeof(numDims)); + tensorDims.resize(numDims); + status |= vxQueryTensor(tensor, VX_TENSOR_DIMS, tensorDims.data(), sizeof(vx_size) * tensorDims.size()); + + if (VX_SUCCESS != status) + { + std::cerr << "Error: Unable to query tensor in " << __func__ << " " << status << std::endl; + break; + } + 
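+            // Shape query succeeded: record this tensor's dimensions so they can
+            // be checked against the model's expected shapes in TFLiteRunner::validate().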
dims.push_back(tensorDims); + } + return status; + } + + /** + * @brief Helper function to process tensors from an object array + * + * @param[in] objArr Object array containing tensors + * @param[out] tensors Vector of pairs containing tensor data and size + * @return vx_status VX_SUCCESS on success, otherwise an error code + */ + static vx_status processTensors(vx_object_array objArr, std::vector> &tensors) + { + vx_status status = VX_SUCCESS; + vx_size numItems = 0; + vxQueryObjectArray(objArr, VX_OBJECT_ARRAY_NUMITEMS, &numItems, sizeof(numItems)); + + for (vx_uint32 i = 0; i < numItems && status == VX_SUCCESS; ++i) + { + vx_tensor tensor = (vx_tensor)vxGetObjectArrayItem(objArr, i); + vx_size dims[VX_MAX_TENSOR_DIMENSIONS]; + vx_size stride[VX_MAX_TENSOR_DIMENSIONS]; + vx_size viewStart[VX_MAX_TENSOR_DIMENSIONS] = {0}; + void *ptr = nullptr; + vx_size numDims = 0, size = 0; + vx_map_id map_id = 0; + + status |= vxQueryTensor(tensor, VX_TENSOR_NUMBER_OF_DIMS, &numDims, sizeof(numDims)); + status |= vxQueryTensor(tensor, VX_TENSOR_DIMS, dims, sizeof(dims)); + status |= vxQueryTensor(tensor, VX_TENSOR_STRIDE, stride, sizeof(stride)); + status |= vxQueryTensor(tensor, VX_TENSOR_TOTAL_SIZE, &size, sizeof(size)); + status |= vxMapTensorPatch(tensor, numDims, viewStart, dims, &map_id, stride, &ptr, VX_READ_ONLY, VX_MEMORY_TYPE_HOST); + + if (VX_SUCCESS != status) + { + std::cerr << "Error: Unable to prep tensor in " << __func__ << ", status: " << status << std::endl; + break; + } + + tensors.emplace_back((float *)ptr, size); + status |= vxUnmapTensorPatch(tensor, map_id); + } + return status; + } +}; + +/** + * @brief LiteRT CPU Inference Kernel description structure + */ +vx_kernel_description_t tflite_cpu_inf_kernel = + { + VX_KERNEL_LITERT_CPU_INF, // Unique kernel ID + "tflite.cpu.runner", // Kernel name + VxLiteRTRunner::litertRunWrapper, // Kernel execution function + const_cast(VxLiteRTRunner::kernelParams), + dimof(VxLiteRTRunner::kernelParams), // Number of parameters + VxLiteRTRunner::litertValidateWrapper, // Kernel validation function + nullptr, + nullptr, + VxLiteRTRunner::litertInitWrapper, // Kernel initialization function + nullptr}; \ No newline at end of file diff --git a/tests/integration_test/BUILD b/tests/integration_test/BUILD index 539aec0d..a1ab40d6 100644 --- a/tests/integration_test/BUILD +++ b/tests/integration_test/BUILD @@ -25,4 +25,33 @@ # "//tests/raw:models", # ], # size = "small" -# ) \ No newline at end of file +# ) + +cc_test( + name = "test_tflite", + srcs = [ + "test_tflite.cpp" + ], + includes = [ + "include", + "framework/include" + ], + deps = [ + "//:corevx", + "@googletest//:gtest_main", + "//targets/c_model:imported_openvx_c_model", + "//targets/debug:imported_openvx_debug", + "//targets/extras:imported_openvx_extras", + "//targets/opencl:imported_openvx_opencl", + "//targets/liteRT:imported_openvx_liteRT", + ], + linkopts = select({ + "@platforms//os:linux": ["-Wl,-rpath,$ORIGIN"], + "@platforms//os:macos": ["-Wl,-rpath,@executable_path"], + "//conditions:default": [], + }), + data = [ + "//tests/raw:models", + ], + size = "small" +) \ No newline at end of file diff --git a/tests/integration_test/test_tflite.cpp b/tests/integration_test/test_tflite.cpp new file mode 100644 index 00000000..cfe70b9e --- /dev/null +++ b/tests/integration_test/test_tflite.cpp @@ -0,0 +1,133 @@ +/** + * @file test_tflite.cpp + * @brief Test TFLite Target + * @version 0.1 + * @date 2025-04-26 + * + * @copyright Copyright (c) 2025 + * + */ +#include +#include +#include 
+#include + +#include + +#include "vx_internal.h" + +class TFLiteIntegrationTest : public ::testing::Test +{ +protected: + vx_context context; + vx_graph graph; + std::string model_path = "./tests/raw/matmul_model.tflite"; + + void SetUp() override + { + // Initialize OpenVX context + context = vxCreateContext(); + ASSERT_EQ(vxGetStatus(context), VX_SUCCESS); + } + + void TearDown() override + { + vxReleaseGraph(&graph); + vxReleaseContext(&context); + } +}; + +TEST_F(TFLiteIntegrationTest, TfliteMatMul) +{ + const vx_size numDims = 2u; + vx_size inputADims[] = {3, 4}; + vx_size inputBDims[] = {4, 3}; + vx_size outputDims[] = {3, 3}; + + // Create input tensors + vx_tensor input_a = vxCreateTensor(context, numDims, inputADims, VX_TYPE_FLOAT32, 0); + vx_tensor input_b = vxCreateTensor(context, numDims, inputBDims, VX_TYPE_FLOAT32, 0); + vx_tensor output_c = vxCreateTensor(context, numDims, outputDims, VX_TYPE_FLOAT32, 0); + ASSERT_EQ(vxGetStatus(input_a), VX_SUCCESS); + ASSERT_EQ(vxGetStatus(input_b), VX_SUCCESS); + ASSERT_EQ(vxGetStatus(output_c), VX_SUCCESS); + + // Query tensor strides + vx_size inputAStride[numDims]; + vx_size inputBStride[numDims]; + vx_size outputStride[numDims]; + ASSERT_EQ(VX_SUCCESS, vxQueryTensor(input_a, VX_TENSOR_STRIDE, inputAStride, sizeof(inputAStride))); + ASSERT_EQ(VX_SUCCESS, vxQueryTensor(input_b, VX_TENSOR_STRIDE, inputBStride, sizeof(inputBStride))); + ASSERT_EQ(VX_SUCCESS, vxQueryTensor(output_c, VX_TENSOR_STRIDE, outputStride, sizeof(outputStride))); + + // Create object arrays for inputs and outputs + vx_object_array input_tensors = vxCreateObjectArrayWithType(context, VX_TYPE_TENSOR); + vx_object_array output_tensors = vxCreateObjectArrayWithType(context, VX_TYPE_TENSOR); + ASSERT_EQ(vxGetStatus(input_tensors), VX_SUCCESS); + ASSERT_EQ(vxGetStatus(output_tensors), VX_SUCCESS); + + // Set object array with items + ASSERT_EQ(VX_SUCCESS, vxSetObjectArrayItem(input_tensors, 0, (vx_reference)input_a)); + ASSERT_EQ(VX_SUCCESS, vxSetObjectArrayItem(input_tensors, 1, (vx_reference)input_b)); + ASSERT_EQ(VX_SUCCESS, vxSetObjectArrayItem(output_tensors, 0, (vx_reference)output_c)); + ASSERT_EQ(input_a, (vx_tensor)vxGetObjectArrayItem(input_tensors, 0)); + ASSERT_EQ(input_b, (vx_tensor)vxGetObjectArrayItem(input_tensors, 1)); + ASSERT_EQ(output_c, (vx_tensor)vxGetObjectArrayItem(output_tensors, 0)); + + // Create model path array + vx_array model_path_array = vxCreateArray(context, VX_TYPE_CHAR, model_path.length() + 1); + ASSERT_EQ(vxGetStatus(model_path_array), VX_SUCCESS); + ASSERT_EQ(VX_SUCCESS, vxAddArrayItems(model_path_array, model_path.length() + 1, model_path.c_str(), sizeof(char))); + + // Create graph + graph = vxCreateGraph(context); + ASSERT_EQ(vxGetStatus(graph), VX_SUCCESS); + + // Get tflite kernel + vx_kernel kernel = vxGetKernelByEnum(context, VX_KERNEL_LITERT_CPU_INF); + ASSERT_EQ(vxGetStatus(kernel), VX_SUCCESS); + + // Create node + vx_node node = vxCreateGenericNode(graph, kernel); + ASSERT_EQ(vxGetStatus(node), VX_SUCCESS); + + // Set node parameters + ASSERT_EQ(VX_SUCCESS, vxSetParameterByIndex(node, 0, (vx_reference)model_path_array)); + ASSERT_EQ(VX_SUCCESS, vxSetParameterByIndex(node, 1, (vx_reference)input_tensors)); + ASSERT_EQ(VX_SUCCESS, vxSetParameterByIndex(node, 2, (vx_reference)output_tensors)); + + // Verify graph + ASSERT_EQ(vxVerifyGraph(graph), VX_SUCCESS); + + // Fill input data + vx_float32 input_data_a[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + vx_float32 input_data_b[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12}; + vx_size viewStart[VX_MAX_TENSOR_DIMENSIONS] = {0}; + + ASSERT_EQ(VX_SUCCESS, vxCopyTensorPatch((vx_tensor)vxGetObjectArrayItem(input_tensors, 0), numDims, viewStart, inputADims, inputAStride, input_data_a, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST)); + ASSERT_EQ(VX_SUCCESS, vxCopyTensorPatch((vx_tensor)vxGetObjectArrayItem(input_tensors, 1), numDims, viewStart, inputBDims, inputBStride, input_data_b, VX_WRITE_ONLY, VX_MEMORY_TYPE_HOST)); + + // Process graph + ASSERT_EQ(vxProcessGraph(graph), VX_SUCCESS); + + // Read output + float output_data[9]; + ASSERT_EQ(VX_SUCCESS, vxCopyTensorPatch((vx_tensor)vxGetObjectArrayItem(output_tensors, 0), numDims, viewStart, outputDims, outputStride, output_data, VX_READ_ONLY, VX_MEMORY_TYPE_HOST)); + + // Validate results + float expected[9] = {70, 80, 90, 158, 184, 210, 246, 288, 330}; + for (vx_uint8 i = 0; i < 9; i++) + { + EXPECT_NEAR(output_data[i], expected[i], 1e-5); + } + + // Cleanup + vxReleaseTensor(&input_a); + vxReleaseTensor(&input_b); + vxReleaseTensor(&output_c); + vxReleaseArray(&model_path_array); + vxReleaseObjectArray(&input_tensors); + vxReleaseObjectArray(&output_tensors); + vxReleaseKernel(&kernel); + vxReleaseNode(&node); +} \ No newline at end of file diff --git a/tests/raw/BUILD b/tests/raw/BUILD index 462eb091..07741d0b 100644 --- a/tests/raw/BUILD +++ b/tests/raw/BUILD @@ -15,6 +15,7 @@ filegroup( name = "models", srcs = glob([ "*.onnx", + "*.tflite", ]), visibility = ["//visibility:public"], ) \ No newline at end of file diff --git a/tests/raw/matmul_model.tflite b/tests/raw/matmul_model.tflite new file mode 100644 index 00000000..d920cbc9 Binary files /dev/null and b/tests/raw/matmul_model.tflite differ diff --git a/tests/raw/tf.py b/tests/raw/tf.py new file mode 100644 index 00000000..5220446a --- /dev/null +++ b/tests/raw/tf.py @@ -0,0 +1,26 @@ +""" +TensorFlow Lite Model Conversion Example +This script demonstrates how to convert a TensorFlow model to TensorFlow Lite format. +It includes a simple matrix multiplication model and shows how to save the converted model. 
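+
+Usage (assuming TensorFlow is installed): running `python tf.py` writes
+matmul_model.tflite to the current working directory; run it from tests/raw/
+so the generated model is picked up by the `models` filegroup used by the
+TFLite integration test.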
+""" +import tensorflow as tf + +# Create a simple MatMul model +class MatMulModel(tf.Module): + @tf.function(input_signature=[ + tf.TensorSpec(shape=[3, 4], dtype=tf.float32), + tf.TensorSpec(shape=[4, 3], dtype=tf.float32) + ]) + def matmul(self, a, b): + return tf.matmul(a, b) + +# Instantiate the model +model = MatMulModel() + +# Convert to TFLite +converter = tf.lite.TFLiteConverter.from_concrete_functions([model.matmul.get_concrete_function()]) +tflite_model = converter.convert() + +# Save the TFLite model +with open("matmul_model.tflite", "wb") as f: + f.write(tflite_model) \ No newline at end of file diff --git a/third_party/BUILD b/third_party/BUILD index fe8628c3..5d45505f 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -1,3 +1,20 @@ """ BUILD file for build defs for third party deps -""" \ No newline at end of file +""" + +cc_import( + name = "tflite", + shared_library = select({ + "@platforms//os:linux": "tflite-hdrs/libtensorflowlite.so", + "@platforms//os:macos": "tflite-hdrs/libtensorflowlite.dylib", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "tflite-hdrs", + hdrs = glob(["tflite-hdrs/**/*.h"]), + includes = ["tflite-hdrs"], + deps = ["@flatbuffers"], + visibility = ["//visibility:public"], +) \ No newline at end of file diff --git a/third_party/patch/onnx.patch b/third_party/patch/onnx.patch index ac25ff81..8a7deae3 100644 --- a/third_party/patch/onnx.patch +++ b/third_party/patch/onnx.patch @@ -22,4 +22,24 @@ index 9219f16be0..e1559bd3da 100644 +google_nsync;https://github.com/amikhail48/nsync/archive/refs/tags/1.29.3.zip;1cdfb3b740dadf9a6cc6d6b65976d31f9d9c2900 googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349 #xnnpack 2024.09.04 - googlexnnpack;https://github.com/google/XNNPACK/archive/309b75c9e56e0a674bf78d59872ce131f814dfb6.zip;39FA5259EAEACE0547284B63D5CEDC4F05553F5A \ No newline at end of file + googlexnnpack;https://github.com/google/XNNPACK/archive/309b75c9e56e0a674bf78d59872ce131f814dfb6.zip;39FA5259EAEACE0547284B63D5CEDC4F05553F5A +diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +index af6f52090a..37ae94f1ae 100644 +--- onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h ++++ onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +@@ -6,10 +6,11 @@ + #include "sqnbitgemm.h" + #include "sqnbitgemm_kernel_avx_common.h" + ++#pragma GCC diagnostic push ++#pragma GCC diagnostic ignored "-Warray-bounds" + + MLAS_FORCEINLINE void +@@ -1044,6 +1050,7 @@ MlasQ4Int8TileGemmKernelBlkLen32Avx2( + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } // m + return CountM; + } ++#pragma GCC diagnostic pop diff --git a/third_party/tflite-hdrs/libtensorflowlite.dylib b/third_party/tflite-hdrs/libtensorflowlite.dylib new file mode 100755 index 00000000..60a3827e Binary files /dev/null and b/third_party/tflite-hdrs/libtensorflowlite.dylib differ diff --git a/third_party/tflite-hdrs/libtensorflowlite.so b/third_party/tflite-hdrs/libtensorflowlite.so new file mode 100755 index 00000000..03fb0b88 Binary files /dev/null and b/third_party/tflite-hdrs/libtensorflowlite.so differ diff --git a/third_party/tflite-hdrs/tensorflow/c/c_api.h b/third_party/tflite-hdrs/tensorflow/c/c_api.h new file mode 100644 index 00000000..9812b0a7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/c_api.h @@ -0,0 +1,1667 @@ +/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_H_ +#define TENSORFLOW_C_C_API_H_ + +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_attrtype.h" +#include "tensorflow/c/tf_buffer.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/c/tf_tstring.h" + +// -------------------------------------------------------------------------- +// C API for TensorFlow. +// +// The API leans towards simplicity and uniformity instead of convenience +// since most usage will be by language specific wrappers. +// +// Conventions: +// * We use the prefix TF_ for everything in the API. +// * Objects are always passed around as pointers to opaque structs +// and these structs are allocated/deallocated via the API. +// * TF_Status holds error information. It is an object type +// and therefore is passed around as a pointer to an opaque +// struct as mentioned above. +// * Every call that has a TF_Status* argument clears it on success +// and fills it with error info on failure. +// * unsigned char is used for booleans (instead of the 'bool' type). +// In C++ bool is a keyword while in C99 bool is a macro defined +// in stdbool.h. It is possible for the two to be inconsistent. +// For example, neither the C99 nor the C++11 standard force a byte +// size on the bool type, so the macro defined in stdbool.h could +// be inconsistent with the bool keyword in C++. Thus, the use +// of stdbool.h is avoided and unsigned char is used instead. +// * size_t is used to represent byte sizes of objects that are +// materialized in the address space of the calling process. +// * int is used as an index into arrays. +// * Deletion functions are safe to call on nullptr. +// +// Questions left to address: +// * Might at some point need a way for callers to provide their own Env. +// * Maybe add TF_TensorShape that encapsulates dimension info. +// +// Design decisions made: +// * Backing store for tensor memory has an associated deallocation +// function. This deallocation function will point to client code +// for tensors populated by the client. So the client can do things +// like shadowing a numpy array. +// * We do not provide TF_OK since it is not strictly necessary and we +// are not optimizing for convenience. +// * We make assumption that one session has one graph. This should be +// fine since we have the ability to run sub-graphs. +// * We could allow NULL for some arguments (e.g., NULL options arg). +// However since convenience is not a primary goal, we don't do this. +// * Devices are not in this API. Instead, they are created/used internally +// and the API just provides high level controls over the number of +// devices of each type. 
+ +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// TF_Version returns a string describing version information of the +// TensorFlow library. TensorFlow uses semantic versioning. +TF_CAPI_EXPORT extern const char* TF_Version(void); + +// Parsing a serialized TensorProto into a TF_Tensor. +TF_CAPI_EXPORT extern void TF_TensorFromProto(const TF_Buffer* from, + TF_Tensor* to, TF_Status* status); + +// -------------------------------------------------------------------------- +// Used to return strings across the C API. The caller does not take ownership +// of the underlying data pointer and is not responsible for freeing it. +typedef struct TF_StringView { + const char* data; + size_t len; +} TF_StringView; + +// -------------------------------------------------------------------------- +// TF_SessionOptions holds options that can be passed during session creation. +typedef struct TF_SessionOptions TF_SessionOptions; + +// Return a new options object. +TF_CAPI_EXPORT extern TF_SessionOptions* TF_NewSessionOptions(void); + +// Set the target in TF_SessionOptions.options. +// target can be empty, a single entry, or a comma separated list of entries. +// Each entry is in one of the following formats : +// "local" +// ip:port +// host:port +TF_CAPI_EXPORT extern void TF_SetTarget(TF_SessionOptions* options, + const char* target); + +// Set the config in TF_SessionOptions.options. +// config should be a serialized tensorflow.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +TF_CAPI_EXPORT extern void TF_SetConfig(TF_SessionOptions* options, + const void* proto, size_t proto_len, + TF_Status* status); + +// Destroy an options object. +TF_CAPI_EXPORT extern void TF_DeleteSessionOptions(TF_SessionOptions*); + +// TODO(jeff,sanjay): +// - export functions to set Config fields + +// -------------------------------------------------------------------------- +// The new graph construction API, still under development. + +// Represents a computation graph. Graphs may be shared between sessions. +// Graphs are thread-safe when used as directed below. +typedef struct TF_Graph TF_Graph; + +// Return a new graph object. +TF_CAPI_EXPORT extern TF_Graph* TF_NewGraph(void); + +// Destroy an options object. Graph will be deleted once no more +// TFSession's are referencing it. +TF_CAPI_EXPORT extern void TF_DeleteGraph(TF_Graph*); + +// Operation being built. The underlying graph must outlive this. +typedef struct TF_OperationDescription TF_OperationDescription; + +// Operation that has been added to the graph. Valid until the graph is +// deleted -- in particular adding a new operation to the graph does not +// invalidate old TF_Operation* pointers. +typedef struct TF_Operation TF_Operation; + +// Represents a specific input of an operation. +typedef struct TF_Input { + TF_Operation* oper; + int index; // The index of the input within oper. +} TF_Input; + +// Represents a specific output of an operation. +typedef struct TF_Output { + TF_Operation* oper; + int index; // The index of the output within oper. +} TF_Output; + +// TF_Function is a grouping of operations with defined inputs and outputs. +// Once created and added to graphs, functions can be invoked by creating an +// operation whose operation type matches the function name. +typedef struct TF_Function TF_Function; + +// Function definition options. 
TODO(iga): Define and implement +typedef struct TF_FunctionOptions TF_FunctionOptions; + +// Sets the shape of the Tensor referenced by `output` in `graph` to +// the shape described by `dims` and `num_dims`. +// +// If the number of dimensions is unknown, `num_dims` must be set to +// -1 and `dims` can be null. If a dimension is unknown, the +// corresponding entry in the `dims` array must be -1. +// +// This does not overwrite the existing shape associated with `output`, +// but merges the input shape with the existing shape. For example, +// setting a shape of [-1, 2] with an existing shape [2, -1] would set +// a final shape of [2, 2] based on shape merging semantics. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +// * An invalid shape is being set (e.g., the shape being set +// is incompatible with the existing shape). +TF_CAPI_EXPORT extern void TF_GraphSetTensorShape(TF_Graph* graph, + TF_Output output, + const int64_t* dims, + const int num_dims, + TF_Status* status); + +// Returns the number of dimensions of the Tensor referenced by `output` +// in `graph`. +// +// If the number of dimensions in the shape is unknown, returns -1. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +TF_CAPI_EXPORT extern int TF_GraphGetTensorNumDims(TF_Graph* graph, + TF_Output output, + TF_Status* status); + +// Returns the shape of the Tensor referenced by `output` in `graph` +// into `dims`. `dims` must be an array large enough to hold `num_dims` +// entries (e.g., the return value of TF_GraphGetTensorNumDims). +// +// If the number of dimensions in the shape is unknown or the shape is +// a scalar, `dims` will remain untouched. Otherwise, each element of +// `dims` will be set corresponding to the size of the dimension. An +// unknown dimension is represented by `-1`. +// +// Returns an error into `status` if: +// * `output` is not in `graph`. +// * `num_dims` does not match the actual number of dimensions. +TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, + TF_Output output, + int64_t* dims, int num_dims, + TF_Status* status); + +// Creates a new operation - see `TF_NewOperation` for more details. +// +// The lock for `graph` must be held when calling this function. +// +// Unless implementing advanced behavior, like custom gradient functions, you +// most likely need to call `TF_NewOperation` instead. +TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperationLocked( + TF_Graph* graph, const char* op_type, const char* oper_name); + +// Operation will only be added to *graph when TF_FinishOperation() is +// called (assuming TF_FinishOperation() does not return an error). +// *graph must not be deleted until after TF_FinishOperation() is +// called. +TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperation( + TF_Graph* graph, const char* op_type, const char* oper_name); + +// Specify the device for `desc`. Defaults to empty, meaning unconstrained. +TF_CAPI_EXPORT extern void TF_SetDevice(TF_OperationDescription* desc, + const char* device); + +// The calls to TF_AddInput and TF_AddInputList must match (in number, +// order, and type) the op declaration. For example, the "Concat" op +// has registration: +// REGISTER_OP("Concat") +// .Input("concat_dim: int32") +// .Input("values: N * T") +// .Output("output: T") +// .Attr("N: int >= 2") +// .Attr("T: type"); +// that defines two inputs, "concat_dim" and "values" (in that order). 
+// You must use TF_AddInput() for the first input (since it takes a +// single tensor), and TF_AddInputList() for the second input (since +// it takes a list, even if you were to pass a list with a single +// tensor), as in: +// TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c"); +// TF_Output concat_dim_input = {...}; +// TF_AddInput(desc, concat_dim_input); +// TF_Output values_inputs[5] = {{...}, ..., {...}}; +// TF_AddInputList(desc, values_inputs, 5); + +// For inputs that take a single tensor. +TF_CAPI_EXPORT extern void TF_AddInput(TF_OperationDescription* desc, + TF_Output input); + +// For inputs that take a list of tensors. +// inputs must point to TF_Output[num_inputs]. +TF_CAPI_EXPORT extern void TF_AddInputList(TF_OperationDescription* desc, + const TF_Output* inputs, + int num_inputs); + +// Call once per control input to `desc`. +TF_CAPI_EXPORT extern void TF_AddControlInput(TF_OperationDescription* desc, + TF_Operation* input); + +// Request that `desc` be co-located on the device where `op` +// is placed. +// +// Use of this is discouraged since the implementation of device placement is +// subject to change. Primarily intended for internal libraries +TF_CAPI_EXPORT extern void TF_ColocateWith(TF_OperationDescription* desc, + TF_Operation* op); + +// Call some TF_SetAttr*() function for every attr that is not +// inferred from an input and doesn't have a default value you wish to +// keep. + +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrString(TF_OperationDescription* desc, + const char* attr_name, + const void* value, size_t length); +// `values` and `lengths` each must have lengths `num_values`. +// `values[i]` must point to a string of length `lengths[i]` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrStringList(TF_OperationDescription* desc, + const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrInt(TF_OperationDescription* desc, + const char* attr_name, int64_t value); +TF_CAPI_EXPORT extern void TF_SetAttrIntList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrFloat(TF_OperationDescription* desc, + const char* attr_name, float value); +TF_CAPI_EXPORT extern void TF_SetAttrFloatList(TF_OperationDescription* desc, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrBool(TF_OperationDescription* desc, + const char* attr_name, + unsigned char value); +TF_CAPI_EXPORT extern void TF_SetAttrBoolList(TF_OperationDescription* desc, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrType(TF_OperationDescription* desc, + const char* attr_name, + TF_DataType value); +TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, + const char* attr_name, + const TF_DataType* values, + int num_values); +TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc, + const char* attr_name, + const char* placeholder); + +// Set a 'func' attribute to the specified name. +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, + const char* attr_name, + const char* value, size_t length); + +// Set `num_dims` to -1 to represent "unknown rank". Otherwise, +// `dims` points to an array of length `num_dims`. 
`dims[i]` must be +// >= -1, with -1 meaning "unknown dimension". +TF_CAPI_EXPORT extern void TF_SetAttrShape(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* dims, int num_dims); +// `dims` and `num_dims` must point to arrays of length `num_shapes`. +// Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, +// `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` +// must be >= -1, with -1 meaning "unknown dimension". +TF_CAPI_EXPORT extern void TF_SetAttrShapeList(TF_OperationDescription* desc, + const char* attr_name, + const int64_t* const* dims, + const int* num_dims, + int num_shapes); +// `proto` must point to an array of `proto_len` bytes representing a +// binary-serialized TensorShapeProto. +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProto( + TF_OperationDescription* desc, const char* attr_name, const void* proto, + size_t proto_len, TF_Status* status); +// `protos` and `proto_lens` must point to arrays of length `num_shapes`. +// `protos[i]` must point to an array of `proto_lens[i]` bytes +// representing a binary-serialized TensorShapeProto. +TF_CAPI_EXPORT extern void TF_SetAttrTensorShapeProtoList( + TF_OperationDescription* desc, const char* attr_name, + const void* const* protos, const size_t* proto_lens, int num_shapes, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_SetAttrTensor(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* value, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_SetAttrTensorList(TF_OperationDescription* desc, + const char* attr_name, + TF_Tensor* const* values, + int num_values, + TF_Status* status); + +// `proto` should point to a sequence of bytes of length `proto_len` +// representing a binary serialization of an AttrValue protocol +// buffer. +TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Adds this operation to the graph - see `TF_FinishOperation` for more details. +// +// The lock for `graph` must be held when calling this function. +// +// Unless implementing advanced behavior, like custom gradient functions, you +// most likely need to call `TF_FinishOperation` instead. +TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperationLocked( + TF_OperationDescription* desc, TF_Status* status); + +// If this function succeeds: +// * *status is set to an OK value, +// * a TF_Operation is added to the graph, +// * a non-null value pointing to the added operation is returned -- +// this value is valid until the underlying graph is deleted. +// Otherwise: +// * *status is set to a non-OK value, +// * the graph is not modified, +// * a null value is returned. +// In either case, it deletes `desc`. +TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperation( + TF_OperationDescription* desc, TF_Status* status); + +// TF_Operation functions. Operations are immutable once created, so +// these are all query functions. 
+ +TF_CAPI_EXPORT extern const char* TF_OperationName(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationOpType(TF_Operation* oper); +TF_CAPI_EXPORT extern const char* TF_OperationDevice(TF_Operation* oper); + +TF_CAPI_EXPORT extern int TF_OperationNumOutputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationOutputType(TF_Output oper_out); +TF_CAPI_EXPORT extern int TF_OperationOutputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); + +TF_CAPI_EXPORT extern int TF_OperationNumInputs(TF_Operation* oper); +TF_CAPI_EXPORT extern TF_DataType TF_OperationInputType(TF_Input oper_in); +TF_CAPI_EXPORT extern int TF_OperationInputListLength(TF_Operation* oper, + const char* arg_name, + TF_Status* status); + +// In this code: +// TF_Output producer = TF_OperationInput(consumer); +// There is an edge from producer.oper's output (given by +// producer.index) to consumer.oper's input (given by consumer.index). +TF_CAPI_EXPORT extern TF_Output TF_OperationInput(TF_Input oper_in); + +// Get list of all inputs of a specific operation. `inputs` must point to +// an array of length at least `max_inputs` (ideally set to +// TF_OperationNumInputs(oper)). Beware that a concurrent +// modification of the graph can increase the number of inputs of +// an operation. +TF_CAPI_EXPORT extern void TF_OperationAllInputs(TF_Operation* oper, + TF_Output* inputs, + int max_inputs); + +// Get the number of current consumers of a specific output of an +// operation. Note that this number can change when new operations +// are added to the graph. +TF_CAPI_EXPORT extern int TF_OperationOutputNumConsumers(TF_Output oper_out); + +// Get list of all current consumers of a specific output of an +// operation. `consumers` must point to an array of length at least +// `max_consumers` (ideally set to +// TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent +// modification of the graph can increase the number of consumers of +// an operation. Returns the number of output consumers (should match +// TF_OperationOutputNumConsumers(oper_out)). +TF_CAPI_EXPORT extern int TF_OperationOutputConsumers(TF_Output oper_out, + TF_Input* consumers, + int max_consumers); + +// Get the number of control inputs to an operation. +TF_CAPI_EXPORT extern int TF_OperationNumControlInputs(TF_Operation* oper); + +// Get list of all control inputs to an operation. `control_inputs` must +// point to an array of length `max_control_inputs` (ideally set to +// TF_OperationNumControlInputs(oper)). Returns the number of control +// inputs (should match TF_OperationNumControlInputs(oper)). +TF_CAPI_EXPORT extern int TF_OperationGetControlInputs( + TF_Operation* oper, TF_Operation** control_inputs, int max_control_inputs); + +// Get the number of operations that have `*oper` as a control input. +// Note that this number can change when new operations are added to +// the graph. +TF_CAPI_EXPORT extern int TF_OperationNumControlOutputs(TF_Operation* oper); + +// Get the list of operations that have `*oper` as a control input. +// `control_outputs` must point to an array of length at least +// `max_control_outputs` (ideally set to +// TF_OperationNumControlOutputs(oper)). Beware that a concurrent +// modification of the graph can increase the number of control +// outputs. Returns the number of control outputs (should match +// TF_OperationNumControlOutputs(oper)). 
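+//
+// A sketch of the count-then-fill pattern shared by these list getters
+// (`oper` is assumed to be a valid TF_Operation*; <stdlib.h> provides malloc):
+//   int n = TF_OperationNumControlOutputs(oper);
+//   TF_Operation** ctl = (TF_Operation**)malloc(n * sizeof(TF_Operation*));
+//   int filled = TF_OperationGetControlOutputs(oper, ctl, n);
+//   for (int i = 0; i < filled; ++i) {
+//     /* inspect ctl[i], e.g. TF_OperationName(ctl[i]) */
+//   }
+//   free(ctl);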
+TF_CAPI_EXPORT extern int TF_OperationGetControlOutputs( + TF_Operation* oper, TF_Operation** control_outputs, + int max_control_outputs); + +// TF_AttrMetadata describes the value of an attribute on an operation. +typedef struct TF_AttrMetadata { + // A boolean: 1 if the attribute value is a list, 0 otherwise. + unsigned char is_list; + + // Length of the list if is_list is true. Undefined otherwise. + int64_t list_size; + + // Type of elements of the list if is_list != 0. + // Type of the single value stored in the attribute if is_list == 0. + TF_AttrType type; + + // Total size the attribute value. + // The units of total_size depend on is_list and type. + // (1) If type == TF_ATTR_STRING and is_list == 0 + // then total_size is the byte size of the string + // valued attribute. + // (2) If type == TF_ATTR_STRING and is_list == 1 + // then total_size is the cumulative byte size + // of all the strings in the list. + // (3) If type == TF_ATTR_SHAPE and is_list == 0 + // then total_size is the number of dimensions + // of the shape valued attribute, or -1 + // if its rank is unknown. + // (4) If type == TF_ATTR_SHAPE and is_list == 1 + // then total_size is the cumulative number + // of dimensions of all shapes in the list. + // (5) Otherwise, total_size is undefined. + int64_t total_size; +} TF_AttrMetadata; + +// Returns metadata about the value of the attribute `attr_name` of `oper`. +TF_CAPI_EXPORT extern TF_AttrMetadata TF_OperationGetAttrMetadata( + TF_Operation* oper, const char* attr_name, TF_Status* status); + +// Fills in `value` with the value of the attribute `attr_name`. `value` must +// point to an array of length at least `max_length` (ideally set to +// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrString(TF_Operation* oper, + const char* attr_name, + void* value, + size_t max_length, + TF_Status* status); + +// Get the list of strings in the value of the attribute `attr_name`. Fills in +// `values` and `lengths`, each of which must point to an array of length at +// least `max_values`. +// +// The elements of values will point to addresses in `storage` which must be at +// least `storage_size` bytes in length. Ideally, max_values would be set to +// TF_AttrMetadata.list_size and `storage` would be at least +// TF_AttrMetadata.total_size, obtained from TF_OperationGetAttrMetadata(oper, +// attr_name). +// +// Fails if storage_size is too small to hold the requested number of strings. +TF_CAPI_EXPORT extern void TF_OperationGetAttrStringList( + TF_Operation* oper, const char* attr_name, void** values, size_t* lengths, + int max_values, void* storage, size_t storage_size, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrInt(TF_Operation* oper, + const char* attr_name, + int64_t* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrIntList(TF_Operation* oper, + const char* attr_name, + int64_t* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloat(TF_Operation* oper, + const char* attr_name, + float* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. 
+// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrFloatList(TF_Operation* oper, + const char* attr_name, + float* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrBool(TF_Operation* oper, + const char* attr_name, + unsigned char* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrBoolList(TF_Operation* oper, + const char* attr_name, + unsigned char* values, + int max_values, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_OperationGetAttrType(TF_Operation* oper, + const char* attr_name, + TF_DataType* value, + TF_Status* status); + +// Fills in `values` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `max_values` (ideally set +// to TF_AttrMetadata.list_size from TF_OperationGetAttrMetadata(oper, +// attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTypeList(TF_Operation* oper, + const char* attr_name, + TF_DataType* values, + int max_values, + TF_Status* status); + +// Fills in `value` with the value of the attribute `attr_name` of `oper`. +// `values` must point to an array of length at least `num_dims` (ideally set to +// TF_Attr_Meta.size from TF_OperationGetAttrMetadata(oper, attr_name)). +TF_CAPI_EXPORT extern void TF_OperationGetAttrShape(TF_Operation* oper, + const char* attr_name, + int64_t* value, + int num_dims, + TF_Status* status); + +// Fills in `dims` with the list of shapes in the attribute `attr_name` of +// `oper` and `num_dims` with the corresponding number of dimensions. On return, +// for every i where `num_dims[i]` > 0, `dims[i]` will be an array of +// `num_dims[i]` elements. A value of -1 for `num_dims[i]` indicates that the +// i-th shape in the list is unknown. +// +// The elements of `dims` will point to addresses in `storage` which must be +// large enough to hold at least `storage_size` int64_ts. Ideally, `num_shapes` +// would be set to TF_AttrMetadata.list_size and `storage_size` would be set to +// TF_AttrMetadata.total_size from TF_OperationGetAttrMetadata(oper, +// attr_name). +// +// Fails if storage_size is insufficient to hold the requested shapes. +TF_CAPI_EXPORT extern void TF_OperationGetAttrShapeList( + TF_Operation* oper, const char* attr_name, int64_t** dims, int* num_dims, + int num_shapes, int64_t* storage, int storage_size, TF_Status* status); + +// Sets `value` to the binary-serialized TensorShapeProto of the value of +// `attr_name` attribute of `oper`. +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* value, + TF_Status* status); + +// Fills in `values` with binary-serialized TensorShapeProto values of the +// attribute `attr_name` of `oper`. `values` must point to an array of length at +// least `num_values` (ideally set to TF_AttrMetadata.list_size from +// TF_OperationGetAttrMetadata(oper, attr_name)). 
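+//
+// A sketch of the usual "metadata first, then fetch" pattern for these
+// getters, assuming `oper` has a shape-valued attribute named "shape"
+// (the attribute name is illustrative; error handling abbreviated):
+//   TF_AttrMetadata m = TF_OperationGetAttrMetadata(oper, "shape", status);
+//   if (TF_GetCode(status) == TF_OK && m.type == TF_ATTR_SHAPE &&
+//       !m.is_list && m.total_size >= 0) {
+//     int64_t* dims = (int64_t*)malloc(m.total_size * sizeof(int64_t));
+//     TF_OperationGetAttrShape(oper, "shape", dims, (int)m.total_size, status);
+//     free(dims);
+//   }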
+TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorShapeProtoList( + TF_Operation* oper, const char* attr_name, TF_Buffer** values, + int max_values, TF_Status* status); + +// Gets the TF_Tensor valued attribute of `attr_name` of `oper`. +// +// Allocates a new TF_Tensor which the caller is expected to take +// ownership of (and can deallocate using TF_DeleteTensor). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensor(TF_Operation* oper, + const char* attr_name, + TF_Tensor** value, + TF_Status* status); + +// Fills in `values` with the TF_Tensor values of the attribute `attr_name` of +// `oper`. `values` must point to an array of TF_Tensor* of length at least +// `max_values` (ideally set to TF_AttrMetadata.list_size from +// TF_OperationGetAttrMetadata(oper, attr_name)). +// +// The caller takes ownership of all the non-null TF_Tensor* entries in `values` +// (which can be deleted using TF_DeleteTensor(values[i])). +TF_CAPI_EXPORT extern void TF_OperationGetAttrTensorList(TF_Operation* oper, + const char* attr_name, + TF_Tensor** values, + int max_values, + TF_Status* status); + +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `oper`. +TF_CAPI_EXPORT extern void TF_OperationGetAttrValueProto( + TF_Operation* oper, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); + +// Get the number of attributes the operation has. +TF_CAPI_EXPORT extern int TF_OperationGetNumAttrs(TF_Operation* oper); + +// Get the length of the name of the ith attribute, or -1 if there is not an +// ith attribute. +TF_CAPI_EXPORT extern int TF_OperationGetAttrNameLength(TF_Operation* oper, + int i); + +// Get the name of the ith attribute. output should have the size of +// TF_OperationGetAttrNameLength(oper, i). +TF_CAPI_EXPORT extern void TF_OperationGetAttrName(TF_Operation* oper, int i, + char* output, + TF_Status* status); + +// Returns the operation in the graph with `oper_name`. Returns nullptr if +// no operation found. +TF_CAPI_EXPORT extern TF_Operation* TF_GraphOperationByName( + TF_Graph* graph, const char* oper_name); + +// Iterate through the operations of a graph. To use: +// size_t pos = 0; +// TF_Operation* oper; +// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) { +// DoSomethingWithOperation(oper); +// } +TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, + size_t* pos); + +// Write out a serialized representation of `graph` (as a GraphDef protocol +// message) to `output_graph_def` (allocated by TF_NewBuffer()). +// `output_graph_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, + TF_Buffer* output_graph_def, + TF_Status* status); + +// Returns the serialized OpDef proto with name `op_name`, or a bad status if no +// such op exists. This can return OpDefs of functions copied into the graph. +TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, + const char* op_name, + TF_Buffer* output_op_def, + TF_Status* status); + +// Returns the serialized VersionDef proto for this graph. +TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, + TF_Buffer* output_version_def, + TF_Status* status); + +// TF_ImportGraphDefOptions holds options that can be passed to +// TF_GraphImportGraphDef. 
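+//
+// A roundtrip sketch: serialize `src_graph` and import it into `dst_graph`
+// under a prefix (`src_graph`, `dst_graph`, and `status` are assumed to
+// exist; error handling abbreviated):
+//   TF_Buffer* gdef = TF_NewBuffer();
+//   TF_GraphToGraphDef(src_graph, gdef, status);
+//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
+//   TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
+//   TF_GraphImportGraphDef(dst_graph, gdef, opts, status);
+//   TF_DeleteImportGraphDefOptions(opts);
+//   TF_DeleteBuffer(gdef);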
+typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; + +TF_CAPI_EXPORT extern TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions( + void); +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( + TF_ImportGraphDefOptions* opts); + +// Set the prefix to be prepended to the names of nodes in `graph_def` that will +// be imported into `graph`. `prefix` is copied and has no lifetime +// requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( + TF_ImportGraphDefOptions* opts, const char* prefix); + +// Set the execution device for nodes in `graph_def`. +// Only applies to nodes where a device was not already explicitly specified. +// `device` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetDefaultDevice( + TF_ImportGraphDefOptions* opts, const char* device); + +// Set whether to uniquify imported operation names. If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + +// Set any imported nodes with input `src_name:src_index` to have that input +// replaced with `dst`. `src_name` refers to a node in the graph to be imported, +// `dst` references a node already existing in the graph being imported into. +// `src_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping( + TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, + TF_Output dst); + +// Set any imported nodes with control input `src_name` to have that input +// replaced with `dst`. `src_name` refers to a node in the graph to be imported, +// `dst` references an operation already existing in the graph being imported +// into. `src_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency( + TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst); + +// Cause the imported graph to have a control dependency on `oper`. `oper` +// should exist in the graph being imported into. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency( + TF_ImportGraphDefOptions* opts, TF_Operation* oper); + +// Add an output in `graph_def` to be returned via the `return_outputs` output +// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input +// mapping, the corresponding existing tensor in `graph` will be returned. +// `oper_name` is copied and has no lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( + TF_ImportGraphDefOptions* opts, const char* oper_name, int index); + +// Returns the number of return outputs added via +// TF_ImportGraphDefOptionsAddReturnOutput(). 
+TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( + const TF_ImportGraphDefOptions* opts); + +// Add an operation in `graph_def` to be returned via the `return_opers` output +// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no +// lifetime requirements. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( + TF_ImportGraphDefOptions* opts, const char* oper_name); + +// Returns the number of return operations added via +// TF_ImportGraphDefOptionsAddReturnOperation(). +TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts); + +// TF_ImportGraphDefResults holds results that are generated by +// TF_GraphImportGraphDefWithResults(). +typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults; + +// Fetches the return outputs requested via +// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is +// returned in `num_outputs`. The array of return outputs is returned in +// `outputs`. `*outputs` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs( + TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs); + +// Fetches the return operations requested via +// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched +// operations is returned in `num_opers`. The array of return operations is +// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( + TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); + +// Fetches any input mappings requested via +// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. +// +// `*src_names`, `*src_indexes`, and the memory backing each string in +// `src_names` are owned by and have the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, + const char*** src_names, int** src_indexes); + +// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults( + TF_ImportGraphDefResults* results); + +// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and +// a bad status on error. Otherwise, returns a populated +// TF_ImportGraphDefResults instance. The returned instance must be deleted via +// TF_DeleteImportGraphDefResults(). 
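+//
+// A sketch of requesting return outputs and reading them back, assuming
+// `graph`, `graph_def`, and `status` exist and the GraphDef contains a node
+// named "out" (the node name is illustrative; error handling abbreviated):
+//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
+//   TF_ImportGraphDefOptionsAddReturnOutput(opts, "out", 0);
+//   TF_ImportGraphDefResults* results =
+//       TF_GraphImportGraphDefWithResults(graph, graph_def, opts, status);
+//   if (TF_GetCode(status) == TF_OK) {
+//     int num = 0;
+//     TF_Output* outs = NULL;
+//     TF_ImportGraphDefResultsReturnOutputs(results, &num, &outs);
+//     /* outs[0] now refers to "out:0" in `graph`; owned by `results`. */
+//     TF_DeleteImportGraphDefResults(results);
+//   }
+//   TF_DeleteImportGraphDefOptions(opts);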
+TF_CAPI_EXPORT extern TF_ImportGraphDefResults* +TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, + TF_Status* status); + +// Has the same behavior as TF_GraphImportGraphDefWithResults, but instead of +// taking in a serialized tensorflow::GraphDef, it takes in a *pointer* to the +// C++ *in memory representation* of the GraphDef, stored in `graph_def->data` +TF_CAPI_EXPORT extern TF_ImportGraphDefResults* +TF_GraphImportGraphDefWithResultsNoSerialization( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status); + +// Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when only return outputs are needed. +// +// `num_return_outputs` must be the number of return outputs added (i.e. the +// result of TF_ImportGraphDefOptionsNumReturnOutputs()). If +// `num_return_outputs` is non-zero, `return_outputs` must be of length +// `num_return_outputs`. Otherwise it can be null. +TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, + int num_return_outputs, TF_Status* status); + +// Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when no results are needed. +TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status); + +// Adds a copy of function `func` and optionally its gradient function `grad` +// to `g`. Once `func`/`grad` is added to `g`, it can be called by creating +// an operation using the function's name. +// Any changes to `func`/`grad` (including deleting it) done after this method +// returns, won't affect the copy of `func`/`grad` in `g`. +// If `func` or `grad` are already in `g`, TF_GraphCopyFunction has no +// effect on them, but can establish the function->gradient relationship +// between them if `func` does not already have a gradient. If `func` already +// has a gradient different from `grad`, an error is returned. +// +// `func` must not be null. +// If `grad` is null and `func` is not in `g`, `func` is added without a +// gradient. +// If `grad` is null and `func` is in `g`, TF_GraphCopyFunction is a noop. +// `grad` must have appropriate signature as described in the doc of +// GradientDef in tensorflow/core/framework/function.proto. +// +// If successful, status is set to OK and `func` and `grad` are added to `g`. +// Otherwise, status is set to the encountered error and `g` is unmodified. +TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g, + const TF_Function* func, + const TF_Function* grad, + TF_Status* status); + +// Returns the number of TF_Functions registered in `g`. +TF_CAPI_EXPORT extern int TF_GraphNumFunctions(TF_Graph* g); + +// Fills in `funcs` with the TF_Function* registered in `g`. +// `funcs` must point to an array of TF_Function* of length at least +// `max_func`. In usual usage, max_func should be set to the result of +// TF_GraphNumFunctions(g). In this case, all the functions registered in +// `g` will be returned. Else, an unspecified subset. +// +// If successful, returns the number of TF_Function* successfully set in +// `funcs` and sets status to OK. The caller takes ownership of +// all the returned TF_Functions. They must be deleted with TF_DeleteFunction. 
+// On error, returns 0, sets status to the encountered error, and the contents +// of funcs will be undefined. +TF_CAPI_EXPORT extern int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, + int max_func, TF_Status* status); + +// Note: The following function may fail on very large protos in the future. + +TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, + TF_Buffer* output_node_def, + TF_Status* status); + +typedef struct TF_WhileParams { + // The number of inputs to the while loop, i.e. the number of loop variables. + // This is the size of cond_inputs, body_inputs, and body_outputs. + const int ninputs; + + // The while condition graph. The inputs are the current values of the loop + // variables. The output should be a scalar boolean. + TF_Graph* const cond_graph; + const TF_Output* const cond_inputs; + TF_Output cond_output; + + // The loop body graph. The inputs are the current values of the loop + // variables. The outputs are the updated values of the loop variables. + TF_Graph* const body_graph; + const TF_Output* const body_inputs; + TF_Output* const body_outputs; + + // Unique null-terminated name for this while loop. This is used as a prefix + // for created operations. + const char* name; +} TF_WhileParams; + +// Creates a TF_WhileParams for creating a while loop in `g`. `inputs` are +// outputs that already exist in `g` used as initial values for the loop +// variables. +// +// The returned TF_WhileParams will have all fields initialized except +// `cond_output`, `body_outputs`, and `name`. The `body_outputs` buffer will be +// allocated to size `ninputs`. The caller should build `cond_graph` and +// `body_graph` starting from the inputs, and store the final outputs in +// `cond_output` and `body_outputs`. +// +// If `status` is OK, the caller must call either TF_FinishWhile or +// TF_AbortWhile on the returned TF_WhileParams. If `status` isn't OK, the +// returned TF_WhileParams is not valid, and the caller should not call +// TF_FinishWhile() or TF_AbortWhile(). +// +// Missing functionality (TODO): +// - Gradients +// - Reference-type inputs +// - Directly referencing external tensors from the cond/body graphs (this is +// possible in the Python API) +TF_CAPI_EXPORT extern TF_WhileParams TF_NewWhile(TF_Graph* g, TF_Output* inputs, + int ninputs, + TF_Status* status); + +// Builds the while loop specified by `params` and returns the output tensors of +// the while loop in `outputs`. `outputs` should be allocated to size +// `params.ninputs`. +// +// `params` is no longer valid once this returns. +// +// Either this or TF_AbortWhile() must be called after a successful +// TF_NewWhile() call. +TF_CAPI_EXPORT extern void TF_FinishWhile(const TF_WhileParams* params, + TF_Status* status, + TF_Output* outputs); + +// Frees `params`s resources without building a while loop. `params` is no +// longer valid after this returns. Either this or TF_FinishWhile() must be +// called after a successful TF_NewWhile() call. +TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params); + +// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all +// shapes in `y`. +// The partial derivatives are returned in `dy`. 
`dy` should be allocated to +// size `nx`. +// +// Gradient nodes are automatically named under the "gradients/" prefix. To +// guarantee name uniqueness, subsequent calls to the same graph will +// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ... +// See TF_AddGradientsWithPrefix, which provides a means to specify a custom +// name prefix for operations added to a graph to compute the gradients. +// +// WARNING: This function does not yet support all the gradients that python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add C++ more gradients. +TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, + TF_Output* x, int nx, TF_Output* dx, + TF_Status* status, TF_Output* dy); + +// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, +// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2... +// This is a variant of TF_AddGradients that allows to caller to pass a custom +// name prefix to the operations added to a graph to compute the gradients. +// +// `dx` are used as initial gradients (which represent the symbolic partial +// derivatives of some loss function `L` w.r.t. `y`). +// `dx` must be nullptr or have size `ny`. +// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all +// shapes in `y`. +// The partial derivatives are returned in `dy`. `dy` should be allocated to +// size `nx`. +// `prefix` names the scope into which all gradients operations are being added. +// `prefix` must be unique within the provided graph otherwise this operation +// will fail. If `prefix` is nullptr, the default prefixing behaviour takes +// place, see TF_AddGradients for more details. +// +// WARNING: This function does not yet support all the gradients that python +// supports. See +// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md +// for instructions on how to add C++ more gradients. +TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, + TF_Output* y, int ny, + TF_Output* x, int nx, + TF_Output* dx, TF_Status* status, + TF_Output* dy); + +// Create a TF_Function from a TF_Graph +// +// Params: +// fn_body - the graph whose operations (or subset of whose operations) will be +// converted to TF_Function. +// fn_name - the name of the new TF_Function. Should match the operation +// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. +// If `append_hash_to_fn_name` is false, `fn_name` must be distinct +// from other function and operation names (at least those +// registered in graphs where this function will be used). +// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name +// of the function will be `fn_name` appended with +// '_'. +// If set to 0, the function's name will be `fn_name`. +// num_opers - `num_opers` contains the number of elements in the `opers` array +// or a special value of -1 meaning that no array is given. +// The distinction between an empty array of operations and no +// array of operations is necessary to distinguish the case of +// creating a function with no body (e.g. identity or permutation) +// and the case of creating a function whose body contains all +// the nodes in the graph (except for the automatic skipping, see +// below). +// opers - Array of operations to become the body of the function or null. 
+// - If no array is given (`num_opers` = -1), all the +// operations in `fn_body` will become part of the function +// except operations referenced in `inputs`. These operations +// must have a single output (these operations are typically +// placeholders created for the sole purpose of representing +// an input. We can relax this constraint if there are +// compelling use cases). +// - If an array is given (`num_opers` >= 0), all operations +// in it will become part of the function. In particular, no +// automatic skipping of dummy input operations is performed. +// ninputs - number of elements in `inputs` array +// inputs - array of TF_Outputs that specify the inputs to the function. +// If `ninputs` is zero (the function takes no inputs), `inputs` +// can be null. The names used for function inputs are normalized +// names of the operations (usually placeholders) pointed to by +// `inputs`. These operation names should start with a letter. +// Normalization will convert all letters to lowercase and +// non-alphanumeric characters to '_' to make resulting names match +// the "[a-z][a-z0-9_]*" pattern for operation argument names. +// `inputs` cannot contain the same tensor twice. +// noutputs - number of elements in `outputs` array +// outputs - array of TF_Outputs that specify the outputs of the function. +// If `noutputs` is zero (the function returns no outputs), `outputs` +// can be null. `outputs` can contain the same tensor more than once. +// output_names - The names of the function's outputs. `output_names` array +// must either have the same length as `outputs` +// (i.e. `noutputs`) or be null. In the former case, +// the names should match the regular expression for ArgDef +// names - "[a-z][a-z0-9_]*". In the latter case, +// names for outputs will be generated automatically. +// opts - various options for the function, e.g. XLA's inlining control. +// description - optional human-readable description of this function. +// status - Set to OK on success and an appropriate error on failure. +// +// Note that when the same TF_Output is listed as both an input and an output, +// the corresponding function's output will equal to this input, +// instead of the original node's output. +// +// Callers must also satisfy the following constraints: +// - `inputs` cannot refer to TF_Outputs within a control flow context. For +// example, one cannot use the output of "switch" node as input. +// - `inputs` and `outputs` cannot have reference types. Reference types are +// not exposed through C API and are being replaced with Resources. We support +// reference types inside function's body to support legacy code. Do not +// use them in new code. +// - Every node in the function's body must have all of its inputs (including +// control inputs). In other words, for every node in the body, each input +// must be either listed in `inputs` or must come from another node in +// the body. In particular, it is an error to have a control edge going from +// a node outside of the body into a node in the body. This applies to control +// edges going from nodes referenced in `inputs` to nodes in the body when +// the former nodes are not in the body (automatically skipped or not +// included in explicitly specified body). +// +// Returns: +// On success, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. 
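+//
+// A sketch using the whole graph as the function body, with one placeholder
+// input and one output (`fn_body`, `input_op`, `output_op`, and `status` are
+// assumed to exist; names are illustrative, error handling abbreviated):
+//   TF_Output inputs[1] = {{input_op, 0}};
+//   TF_Output outputs[1] = {{output_op, 0}};
+//   TF_Function* fn = TF_GraphToFunction(
+//       fn_body, "my_fn", /*append_hash_to_fn_name=*/0,
+//       /*num_opers=*/-1, /*opers=*/NULL,
+//       /*ninputs=*/1, inputs, /*noutputs=*/1, outputs,
+//       /*output_names=*/NULL, /*opts=*/NULL, "example function", status);
+//   if (TF_GetCode(status) == TF_OK) {
+//     /* e.g. TF_GraphCopyFunction(other_graph, fn, NULL, status); */
+//     TF_DeleteFunction(fn);
+//   }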
+TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + const TF_FunctionOptions* opts, const char* description, TF_Status* status); + +// Similar to TF_GraphToFunction but allows specifying control outputs of the +// function. +// +// The arguments of TF_GraphToFunction have the same meaning, but the new +// arguments are as follows: +// +// ncontrol_outputs: Number of control outputs of the function. +// control_outputs: vector of TF_Operation objects to be marked as control +// outputs of the function. Operations marked as control outputs are +// guaranteed to execute. +// control_output_names: Optional. If not nullptr, vector of strings, one +// per control output, with their names to be added to the function's +// OpDef. +TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunctionWithControlOutputs( + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, + const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, + int noutputs, const TF_Output* outputs, const char* const* output_names, + int ncontrol_outputs, const TF_Operation* const* control_outputs, + const char* const* control_output_names, const TF_FunctionOptions* opts, + const char* description, TF_Status* status); + +// Returns the name of the graph function. +// The return value points to memory that is only usable until the next +// mutation to *func. +TF_CAPI_EXPORT extern const char* TF_FunctionName(TF_Function* func); + +// Write out a serialized representation of `func` (as a FunctionDef protocol +// message) to `output_func_def` (allocated by TF_NewBuffer()). +// `output_func_def`'s underlying buffer will be freed when TF_DeleteBuffer() +// is called. +// +// May fail on very large graphs in the future. +TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, + TF_Buffer* output_func_def, + TF_Status* status); + +// Construct and return the function whose FunctionDef representation is +// serialized in `proto`. `proto_len` must equal the number of bytes +// pointed to by `proto`. +// Returns: +// On success, a newly created TF_Function instance. It must be deleted by +// calling TF_DeleteFunction. +// +// On failure, null. +TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( + const void* proto, size_t proto_len, TF_Status* status); + +// Sets function attribute named `attr_name` to value stored in `proto`. +// If this attribute is already set to another value, it is overridden. +// `proto` should point to a sequence of bytes of length `proto_len` +// representing a binary serialization of an AttrValue protocol +// buffer. +TF_CAPI_EXPORT extern void TF_FunctionSetAttrValueProto(TF_Function* func, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Sets `output_attr_value` to the binary-serialized AttrValue proto +// representation of the value of the `attr_name` attr of `func`. +// If `attr_name` attribute is not present, status is set to an error. +TF_CAPI_EXPORT extern void TF_FunctionGetAttrValueProto( + TF_Function* func, const char* attr_name, TF_Buffer* output_attr_value, + TF_Status* status); + +// Frees the memory used by the `func` struct. +// TF_DeleteFunction is a noop if `func` is null. 
+// Deleting a function does not remove it from any graphs it was copied to. +TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); + +// Attempts to evaluate `output`. This will only be possible if `output` doesn't +// depend on any graph inputs (this function is safe to call if this isn't the +// case though). +// +// If the evaluation is successful, this function returns true and `output`s +// value is returned in `result`. Otherwise returns false. An error status is +// returned if something is wrong with the graph or input. Note that this may +// return false even if no error status is set. +TF_CAPI_EXPORT extern unsigned char TF_TryEvaluateConstant(TF_Graph* graph, + TF_Output output, + TF_Tensor** result, + TF_Status* status); + +// TODO(josh11b): Register OpDef, available to all operations added +// to this graph. + +// -------------------------------------------------------------------------- +// API for driving Graph execution. + +typedef struct TF_Session TF_Session; + +// Return a new execution session with the associated graph, or NULL on +// error. Does not take ownership of any input parameters. +// +// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be +// kept alive for the lifetime of the returned TF_Session. New nodes can still +// be added to `graph` after this call. +TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, + const TF_SessionOptions* opts, + TF_Status* status); + +// This function creates a new TF_Session (which is created on success) using +// `session_options`, and then initializes state (restoring tensors and other +// assets) using `run_options`. +// +// Any NULL and non-NULL value combinations for (`run_options, `meta_graph_def`) +// are valid. +// +// - `export_dir` must be set to the path of the exported SavedModel. +// - `tags` must include the set of tags used to identify one MetaGraphDef in +// the SavedModel. +// - `graph` must be a graph newly allocated with TF_NewGraph(). +// +// If successful, populates `graph` with the contents of the Graph and +// `meta_graph_def` with the MetaGraphDef of the loaded model. +TF_CAPI_EXPORT extern TF_Session* TF_LoadSessionFromSavedModel( + const TF_SessionOptions* session_options, const TF_Buffer* run_options, + const char* export_dir, const char* const* tags, int tags_len, + TF_Graph* graph, TF_Buffer* meta_graph_def, TF_Status* status); + +// Close a session. +// +// Contacts any other processes associated with the session, if applicable. +// May not be called after TF_DeleteSession(). +TF_CAPI_EXPORT extern void TF_CloseSession(TF_Session*, TF_Status* status); + +// Destroy a session object. +// +// Even if error information is recorded in *status, this call discards all +// local resources associated with the session. The session may not be used +// during or after this call (and the session drops its reference to the +// corresponding graph). +TF_CAPI_EXPORT extern void TF_DeleteSession(TF_Session*, TF_Status* status); + +// Run the graph associated with the session starting with the supplied inputs +// (inputs[0,ninputs-1] with corresponding values in input_values[0,ninputs-1]). +// +// Any NULL and non-NULL value combinations for (`run_options`, +// `run_metadata`) are valid. +// +// - `run_options` may be NULL, in which case it will be ignored; or +// non-NULL, in which case it must point to a `TF_Buffer` containing the +// serialized representation of a `RunOptions` protocol buffer. 
+// - `run_metadata` may be NULL, in which case it will be ignored; or +// non-NULL, in which case it must point to an empty, freshly allocated +// `TF_Buffer` that may be updated to contain the serialized representation +// of a `RunMetadata` protocol buffer. +// +// The caller retains ownership of `input_values` (which can be deleted using +// TF_DeleteTensor). The caller also retains ownership of `run_options` and/or +// `run_metadata` (when not NULL) and should manually call TF_DeleteBuffer on +// them. +// +// On success, the tensors corresponding to outputs[0,noutputs-1] are placed in +// output_values[]. Ownership of the elements of output_values[] is transferred +// to the caller, which must eventually call TF_DeleteTensor on them. +// +// On failure, output_values[] contains NULLs. +TF_CAPI_EXPORT extern void TF_SessionRun( + TF_Session* session, + // RunOptions + const TF_Buffer* run_options, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // RunMetadata + TF_Buffer* run_metadata, + // Output status + TF_Status*); + +// Set up the graph with the intended feeds (inputs) and fetches (outputs) for a +// sequence of partial run calls. +// +// On success, returns a handle that is used for subsequent PRun calls. The +// handle should be deleted with TF_DeletePRunHandle when it is no longer +// needed. +// +// On failure, out_status contains a tensorflow::Status with an error +// message. *handle is set to nullptr. +TF_CAPI_EXPORT extern void TF_SessionPRunSetup( + TF_Session*, + // Input names + const TF_Output* inputs, int ninputs, + // Output names + const TF_Output* outputs, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output handle + const char** handle, + // Output status + TF_Status*); + +// Continue to run the graph with additional feeds and fetches. The +// execution state is uniquely identified by the handle. +TF_CAPI_EXPORT extern void TF_SessionPRun( + TF_Session*, const char* handle, + // Input tensors + const TF_Output* inputs, TF_Tensor* const* input_values, int ninputs, + // Output tensors + const TF_Output* outputs, TF_Tensor** output_values, int noutputs, + // Target operations + const TF_Operation* const* target_opers, int ntargets, + // Output status + TF_Status*); + +// Deletes a handle allocated by TF_SessionPRunSetup. +// Once called, no more calls to TF_SessionPRun should be made. +TF_CAPI_EXPORT extern void TF_DeletePRunHandle(const char* handle); + +// -------------------------------------------------------------------------- +// The deprecated session API. Please switch to the above instead of +// TF_ExtendGraph(). This deprecated API can be removed at any time without +// notice. 
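+//
+// For reference, a minimal sketch of the TF_Session flow recommended above,
+// assuming `graph` contains operations named "input" and "output" and that
+// `in_tensor` was created with TF_NewTensor (names are illustrative; error
+// handling abbreviated):
+//   TF_SessionOptions* opts = TF_NewSessionOptions();
+//   TF_Session* session = TF_NewSession(graph, opts, status);
+//   TF_Output feed = {TF_GraphOperationByName(graph, "input"), 0};
+//   TF_Output fetch = {TF_GraphOperationByName(graph, "output"), 0};
+//   TF_Tensor* out_tensor = NULL;
+//   TF_SessionRun(session, /*run_options=*/NULL,
+//                 &feed, &in_tensor, 1,
+//                 &fetch, &out_tensor, 1,
+//                 /*target_opers=*/NULL, 0,
+//                 /*run_metadata=*/NULL, status);
+//   if (TF_GetCode(status) == TF_OK) TF_DeleteTensor(out_tensor);
+//   TF_CloseSession(session, status);
+//   TF_DeleteSession(session, status);
+//   TF_DeleteSessionOptions(opts);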
+ +typedef struct TF_DeprecatedSession TF_DeprecatedSession; + +TF_CAPI_EXPORT extern TF_DeprecatedSession* TF_NewDeprecatedSession( + const TF_SessionOptions*, TF_Status* status); +TF_CAPI_EXPORT extern void TF_CloseDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteDeprecatedSession(TF_DeprecatedSession*, + TF_Status* status); +TF_CAPI_EXPORT extern void TF_Reset(const TF_SessionOptions* opt, + const char** containers, int ncontainers, + TF_Status* status); +// Treat the bytes proto[0,proto_len-1] as a serialized GraphDef and +// add the nodes in that GraphDef to the graph for the session. +// +// Prefer use of TF_Session and TF_GraphImportGraphDef over this. +TF_CAPI_EXPORT extern void TF_ExtendGraph(TF_DeprecatedSession*, + const void* proto, size_t proto_len, + TF_Status*); + +// See TF_SessionRun() above. +TF_CAPI_EXPORT extern void TF_Run(TF_DeprecatedSession*, + const TF_Buffer* run_options, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Buffer* run_metadata, TF_Status*); + +// See TF_SessionPRunSetup() above. +TF_CAPI_EXPORT extern void TF_PRunSetup(TF_DeprecatedSession*, + const char** input_names, int ninputs, + const char** output_names, int noutputs, + const char** target_oper_names, + int ntargets, const char** handle, + TF_Status*); + +// See TF_SessionPRun above. +TF_CAPI_EXPORT extern void TF_PRun(TF_DeprecatedSession*, const char* handle, + const char** input_names, TF_Tensor** inputs, + int ninputs, const char** output_names, + TF_Tensor** outputs, int noutputs, + const char** target_oper_names, int ntargets, + TF_Status*); + +typedef struct TF_DeviceList TF_DeviceList; + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_SessionListDevices(TF_Session* session, + TF_Status* status); + +// Lists all devices in a TF_Session. +// +// Caller takes ownership of the returned TF_DeviceList* which must eventually +// be freed with a call to TF_DeleteDeviceList. +TF_CAPI_EXPORT extern TF_DeviceList* TF_DeprecatedSessionListDevices( + TF_DeprecatedSession* session, TF_Status* status); + +// Deallocates the device list. +TF_CAPI_EXPORT extern void TF_DeleteDeviceList(TF_DeviceList* list); + +// Counts the number of elements in the device list. +TF_CAPI_EXPORT extern int TF_DeviceListCount(const TF_DeviceList* list); + +// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) +// The return value will be a pointer to a null terminated string. The caller +// must not modify or delete the string. It will be deallocated upon a call to +// TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. +TF_CAPI_EXPORT extern const char* TF_DeviceListName(const TF_DeviceList* list, + int index, + TF_Status* status); + +// Retrieves the type of the device at the given index. +// +// The caller must not modify or delete the string. It will be deallocated upon +// a call to TF_DeleteDeviceList. +// +// If index is out of bounds, an error code will be set in the status object, +// and a null pointer will be returned. 
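+//
+// A device enumeration sketch (assumes a valid `session` and `status`;
+// error handling abbreviated):
+//   TF_DeviceList* devices = TF_SessionListDevices(session, status);
+//   int count = TF_DeviceListCount(devices);
+//   for (int i = 0; i < count; ++i) {
+//     const char* dev_name = TF_DeviceListName(devices, i, status);
+//     const char* dev_type = TF_DeviceListType(devices, i, status);
+//     /* e.g. "/job:localhost/replica:0/task:0/device:CPU:0" and "CPU" */
+//   }
+//   TF_DeleteDeviceList(devices);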
+TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list, + int index, + TF_Status* status); + +// Retrieve the amount of memory associated with a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and -1 will be returned. +TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes( + const TF_DeviceList* list, int index, TF_Status* status); + +// Retrieve the incarnation number of a given device. +// +// If index is out of bounds, an error code will be set in the status object, +// and 0 will be returned. +TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation( + const TF_DeviceList* list, int index, TF_Status* status); + +// -------------------------------------------------------------------------- +// Load plugins containing custom ops and kernels + +// TF_Library holds information about dynamically loaded TensorFlow plugins. +typedef struct TF_Library TF_Library; + +// Load the library specified by library_filename and register the ops and +// kernels present in that library. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, place OK in status and return the newly created library handle. +// The caller owns the library handle. +// +// On failure, place an error status in status and return NULL. +TF_CAPI_EXPORT extern TF_Library* TF_LoadLibrary(const char* library_filename, + TF_Status* status); + +// Get the OpList of OpDefs defined in the library pointed by lib_handle. +// +// Returns a TF_Buffer. The memory pointed to by the result is owned by +// lib_handle. The data in the buffer will be the serialized OpList proto for +// ops defined in the library. +TF_CAPI_EXPORT extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); + +// Get the OpList of all OpDefs defined in this address space. +// Returns a TF_Buffer, ownership of which is transferred to the caller +// (and can be freed using TF_DeleteBuffer). +// +// The data in the buffer will be the serialized OpList proto for ops registered +// in this address space. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(void); + +// TF_ApiDefMap encapsulates a collection of API definitions for an operation. +// +// This object maps the name of a TensorFlow operation to a description of the +// API to generate for it, as defined by the ApiDef protocol buffer ( +// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) +// +// The ApiDef messages are typically used to generate convenience wrapper +// functions for TensorFlow operations in various language bindings. +typedef struct TF_ApiDefMap TF_ApiDefMap; + +// Creates a new TF_ApiDefMap instance. +// +// Params: +// op_list_buffer - TF_Buffer instance containing serialized OpList +// protocol buffer. (See +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto +// for the OpList proto definition). +// status - Set to OK on success and an appropriate error on failure. +TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, + TF_Status* status); + +// Deallocates a TF_ApiDefMap. +TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); + +// Add ApiDefs to the map. 
+// +// `text` corresponds to a text representation of an ApiDefs protocol message. +// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). +// +// The provided ApiDefs will be merged with existing ones in the map, with +// precedence given to the newly added version in case of conflicts with +// previous calls to TF_ApiDefMapPut. +TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, + const char* text, size_t text_len, + TF_Status* status); + +// Returns a serialized ApiDef protocol buffer for the TensorFlow operation +// named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, + const char* name, + size_t name_len, + TF_Status* status); + +// -------------------------------------------------------------------------- +// Kernel definition information. + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// registered kernels. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllRegisteredKernels(TF_Status* status); + +// Returns a serialized KernelList protocol buffer containing KernelDefs for all +// kernels registered for the operation named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetRegisteredKernelsForOp( + const char* name, TF_Status* status); + +// Update edge, switch input/ output in a node +TF_CAPI_EXPORT extern void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, + TF_Input dst, TF_Status* status); + +// -------------------------------------------------------------------------- +// In-process TensorFlow server functionality, for use in distributed training. +// A Server instance encapsulates a set of devices and a Session target that +// can participate in distributed training. A server belongs to a cluster +// (specified by a ClusterSpec), and corresponds to a particular task in a +// named job. The server can communicate with any other server in the same +// cluster. + +// In-process TensorFlow server. +typedef struct TF_Server TF_Server; + +// Creates a new in-process TensorFlow server configured using a serialized +// ServerDef protocol buffer provided via `proto` and `proto_len`. +// +// The server will not serve any requests until TF_ServerStart is invoked. +// The server will stop serving requests once TF_ServerStop or +// TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern TF_Server* TF_NewServer(const void* proto, + size_t proto_len, + TF_Status* status); + +// Starts an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStart(TF_Server* server, TF_Status* status); + +// Stops an in-process TensorFlow server. +TF_CAPI_EXPORT extern void TF_ServerStop(TF_Server* server, TF_Status* status); + +// Blocks until the server has been successfully stopped (via TF_ServerStop or +// TF_ServerClose). +TF_CAPI_EXPORT extern void TF_ServerJoin(TF_Server* server, TF_Status* status); + +// Returns the target string that can be provided to TF_SetTarget() to connect +// a TF_Session to `server`. +// +// The returned string is valid only until TF_DeleteServer is invoked. +TF_CAPI_EXPORT extern const char* TF_ServerTarget(TF_Server* server); + +// Destroy an in-process TensorFlow server, frees memory. If server is running +// it will be stopped and joined. +TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server); + +// Register a listener method that processes printed messages. +// +// If any listeners are registered, the print operator will call all listeners +// with the printed messages and immediately return without writing to the +// logs. 
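+//
+// A registration sketch (the listener below is an illustrative placeholder
+// and uses <stdio.h>):
+//   static void MyPrintListener(const char* msg) {
+//     fprintf(stderr, "[tf print] %s\n", msg);
+//   }
+//   /* later, e.g. during program setup: */
+//   TF_RegisterLogListener(MyPrintListener);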
+TF_CAPI_EXPORT extern void TF_RegisterLogListener( + void (*listener)(const char*)); + +// Register a FileSystem plugin from filename `plugin_filename`. +// +// On success, place OK in status. +// On failure, place an error status in status. +TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( + const char* plugin_filename, TF_Status* status); + +// Apis that are corresponding to python c api. -------------------- + +// Add control input to `op`. +TF_CAPI_EXPORT extern void TF_AddOperationControlInput(TF_Graph* graph, + TF_Operation* op, + TF_Operation* input); + +// Changes an attr value in the node_def Protocol Buffer and sets a status upon +// completion. +TF_CAPI_EXPORT extern void TF_SetAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Buffer* attr_value_proto, + TF_Status* status); + +// Clears the attr in the node_def Protocol Buffer and sets a status upon +// completion. +TF_CAPI_EXPORT extern void TF_ClearAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Status* status); + +// Sets the experimental_type` field in the node_def Protocol Buffer. +TF_CAPI_EXPORT extern void TF_SetFullType(TF_Graph* graph, TF_Operation* op, + const TF_Buffer* full_type_proto); + +// Set the requested device for `graph`. +TF_CAPI_EXPORT extern void TF_SetRequestedDevice(TF_Graph* graph, + TF_Operation* op, + const char* device); + +// Remove all the control inputs from `op` in `graph`. +TF_CAPI_EXPORT extern void TF_RemoveAllControlInputs(TF_Graph* graph, + TF_Operation* op); + +// Set if `graph` requires shape inference functions. +TF_CAPI_EXPORT extern void TF_SetRequireShapeInferenceFns(TF_Graph* graph, + bool require); + +// Extends `session` with any new operations added to its associated graph. +// Usually this happens automatically in TF_SessionRun. After this is called, +// TF_SessionRun will no longer extend the session on every call. +// +// We expose this here to allow fine-grained synchronization in multi-threaded +// workloads, which is required since the Python implementation depends on the +// above mutation methods. This allows us to prevent modifications to nodes in +// the graph after the session has been made aware of them. +TF_CAPI_EXPORT extern void TF_ExtendSession(TF_Session* session, + TF_Status* status); + +// Returns the serialized CppShapeInferenceResult::HandleData proto for +// `output` if its a resource or variant tensor, or otherwise returns the empty +// string. +TF_CAPI_EXPORT extern TF_Buffer* TF_GetHandleShapeAndType(TF_Graph* graph, + TF_Output output); + +// Sets `output` based on `proto`, which should be a serialized +// CppShapeInferenceResult::HandleData proto. `output` should be a resource +// or variant tensor. +// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string +// because I couldn't get SWIG to work otherwise. +TF_CAPI_EXPORT extern void TF_SetHandleShapeAndType(TF_Graph* graph, + TF_Output output, + const void* proto, + size_t proto_len, + TF_Status* status); + +// This method is used to add a new input edge to 'dst', which must be a While +// op. The While op's "T" attribute must have already been updated to include +// the new edge. This is used to construct tf.while_loop gradients. 
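+//
+// A small sketch of the mutation helpers above (assumes `graph`, `op`, and
+// `ctl_op` exist; the device string is illustrative):
+//   TF_SetRequestedDevice(graph, op, "/device:CPU:0");
+//   TF_AddOperationControlInput(graph, op, ctl_op);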
+TF_CAPI_EXPORT extern void TF_AddWhileInputHack(TF_Graph* graph, + TF_Output new_src, + TF_Operation* dst, + TF_Status* status); + +// ---------------------------------------------------------------- + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_C_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/c_api_experimental.h b/third_party/tflite-hdrs/tensorflow/c/c_api_experimental.h new file mode 100644 index 00000000..abae68cf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/c_api_experimental.h @@ -0,0 +1,324 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_C_API_EXPERIMENTAL_H_ + +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" + +// -------------------------------------------------------------------------- +// Experimental C API for TensorFlow. +// +// The API here is subject to changes in the future. +// -------------------------------------------------------------------------- + +#ifdef __cplusplus +extern "C" { +#endif + +// When `enable` is true, set +// tensorflow.ConfigProto.OptimizerOptions.global_jit_level to ON_1, and also +// set XLA flag values to prepare for XLA compilation. Otherwise set +// global_jit_level to OFF. +// +// This and the next API are syntax sugar over TF_SetConfig(), and is used by +// clients that cannot read/write the tensorflow.ConfigProto proto. +// TODO: Migrate to TF_CreateConfig() below. +TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, + unsigned char enable); + +// Set XLA's internal BuildXlaOpsPassFlags.tf_xla_enable_lazy_compilation to the +// value of 'enabled'. Also returns the original value of that flag. +// +// Use in tests to allow XLA to fallback to TF classic. This has global effect. +TF_CAPI_EXPORT unsigned char TF_SetXlaEnableLazyCompilation( + unsigned char enable); +TF_CAPI_EXPORT unsigned char TF_SetTfXlaCpuGlobalJit(unsigned char enable); + +// Sets XLA's auto jit mode according to the specified string, which is parsed +// as if passed in XLA_FLAGS. This has global effect. +TF_CAPI_EXPORT void TF_SetXlaAutoJitMode(const char* mode); + +// Returns whether the single GPU or general XLA auto jit optimizations are +// enabled through MarkForCompilationPassFlags. +TF_CAPI_EXPORT unsigned char TF_GetXlaAutoJitEnabled(); + +// Sets XLA's minimum cluster size. This has global effect. +TF_CAPI_EXPORT void TF_SetXlaMinClusterSize(int size); + +// Gets/Sets TF/XLA flag for whether(true) or not(false) to disable constant +// folding. This is for testing to ensure that XLA is being tested rather than +// Tensorflow's CPU implementation through constant folding. 
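+//
+// Illustrative test-only sketch (non-normative): disable constant folding for
+// the duration of a test, then restore the previous value.
+//
+//   unsigned char prev = TF_GetXlaConstantFoldingDisabled();
+//   TF_SetXlaConstantFoldingDisabled(1);
+//   /* ... run the XLA-backed test ... */
+//   TF_SetXlaConstantFoldingDisabled(prev);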
+TF_CAPI_EXPORT unsigned char TF_GetXlaConstantFoldingDisabled(); +TF_CAPI_EXPORT void TF_SetXlaConstantFoldingDisabled( + unsigned char should_enable); + +// Create a serialized tensorflow.ConfigProto proto, where: +// +// a) ConfigProto.optimizer_options.global_jit_level is set to ON_1 if +// `enable_xla_compilation` is non-zero, and OFF otherwise. +// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`. +// c) ConfigProto.device_count is set to `num_cpu_devices`. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig( + unsigned char enable_xla_compilation, unsigned char gpu_memory_allow_growth, + unsigned int num_cpu_devices); + +// Create a serialized tensorflow.RunOptions proto, where RunOptions.trace_level +// is set to FULL_TRACE if `enable_full_trace` is non-zero, and NO_TRACE +// otherwise. +TF_CAPI_EXPORT extern TF_Buffer* TF_CreateRunOptions( + unsigned char enable_full_trace); + +// Returns the graph content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph, + size_t* len); + +// Returns the function content in a human-readable format, with length set in +// `len`. The format is subject to change in the future. +// The returned string is heap-allocated, and caller should call free() on it. +// +// Do not return const char*, because some foreign language binding +// (e.g. swift) cannot then call free() on the returned pointer. +TF_CAPI_EXPORT extern char* TF_FunctionDebugString(TF_Function* func, + size_t* len); + +// On success, dequeues a tensor from a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_dequeue_", to be executed by this API call. + +// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is +// empty, this call is blocked. +// +// Tensors are enqueued via the corresponding TF enqueue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, + int tensor_id, + TF_Status* status); + +// On success, enqueues `tensor` into a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_enqueue_", to be executed by this API call. It reads +// from a placeholder node "arg_tensor_enqueue_". +// +// `tensor` is still owned by the caller. This call will be blocked if the queue +// has reached its capacity, and will be unblocked when the queued tensors again +// drop below the capacity due to dequeuing. +// +// Tensors are dequeued via the corresponding TF dequeue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TF_Tensor* tensor, + TF_Status* status); +// Create a serialized tensorflow.ServerDef proto. 
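+//
+// Illustrative usage sketch (non-normative); the text proto below is a
+// made-up placeholder:
+//
+//   TF_Status* s = TF_NewStatus();
+//   TF_Buffer* server_def = TFE_GetServerDef("cluster { /* ... */ }", s);
+//   if (TF_GetCode(s) == TF_OK) {
+//     /* consume server_def->data / server_def->length */
+//     TF_DeleteBuffer(server_def);
+//   }
+//   TF_DeleteStatus(s);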
+TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status, + const char* errMsg); + +// TF_NewCheckpointReader() return the CheckpointReader that can be use to +// investigate or load the variable from the checkpoint file +typedef struct TF_CheckpointReader TF_CheckpointReader; +TF_CAPI_EXPORT extern TF_CheckpointReader* TF_NewCheckpointReader( + const char* filename, TF_Status* status); +TF_CAPI_EXPORT extern void TF_DeleteCheckpointReader( + TF_CheckpointReader* reader); +TF_CAPI_EXPORT extern int TF_CheckpointReaderHasTensor( + TF_CheckpointReader* reader, const char* name); +// Get the variable name at the given index +TF_CAPI_EXPORT extern const char* TF_CheckpointReaderGetVariable( + TF_CheckpointReader* reader, int index); +// Get the number of variable in the checkpoint +TF_CAPI_EXPORT extern int TF_CheckpointReaderSize(TF_CheckpointReader* reader); +// Get the DataType of a variable +TF_CAPI_EXPORT extern TF_DataType TF_CheckpointReaderGetVariableDataType( + TF_CheckpointReader* reader, const char* name); +// Read the shape of a variable and write to `dims` +TF_CAPI_EXPORT extern void TF_CheckpointReaderGetVariableShape( + TF_CheckpointReader* reader, const char* name, int64_t* dims, int num_dims, + TF_Status* status); +// Get the number of dimension of a variable +TF_CAPI_EXPORT extern int TF_CheckpointReaderGetVariableNumDims( + TF_CheckpointReader* reader, const char* name); +// Load the weight of a variable +TF_CAPI_EXPORT extern TF_Tensor* TF_CheckpointReaderGetTensor( + TF_CheckpointReader* reader, const char* name, TF_Status* status); + +// TF_NewAttrBuilder() returns an object that you can set attributes on as +// though it were an op. This allows querying properties of that op for +// type-checking purposes like if the op will run on a particular device type. +typedef struct TF_AttrBuilder TF_AttrBuilder; +TF_CAPI_EXPORT extern TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name); +TF_CAPI_EXPORT extern void TF_DeleteAttrBuilder(TF_AttrBuilder* builder); +TF_CAPI_EXPORT extern void TF_AttrBuilderSetType(TF_AttrBuilder* builder, + const char* attr_name, + TF_DataType value); +TF_CAPI_EXPORT extern void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, + const char* attr_name, + const TF_DataType* values, + int num_values); + +// Checks the tensorflow::NodeDef built via the methods above to see if it can +// run on device_type. +TF_CAPI_EXPORT extern void TF_AttrBuilderCheckCanRunOnDevice( + TF_AttrBuilder* builder, const char* device_type, TF_Status* status); + +// For argument number input_index, fetch the corresponding number_attr that +// needs to be updated with the argument length of the input list. +// Returns nullptr if there is any problem like op_name is not found, or the +// argument does not support this attribute type. +TF_CAPI_EXPORT extern const char* TF_GetNumberAttrForOpListInput( + const char* op_name, int input_index, TF_Status* status); + +// Returns 1 if the op is stateful, 0 otherwise. The return value is undefined +// if the status is not ok. +TF_CAPI_EXPORT extern int TF_OpIsStateful(const char* op_type, + TF_Status* status); + +// Platform specific initialization routine. Very few platforms actually require +// this to be called. +TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv); + +// Platform-specific implementation to return an unused port. (This should used +// in tests only.) 
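+//
+// Illustrative test-only sketch (non-normative):
+//
+//   int port = TF_PickUnusedPortOrDie();
+//   /* e.g. format a "localhost:<port>" target for an in-process server */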
+TF_CAPI_EXPORT int TF_PickUnusedPortOrDie(void); + +// Fast path method that makes constructing a single scalar tensor require less +// overhead and copies. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromScalar( + TF_DataType data_type, void* data, size_t len, TF_Status* status); + +// Specify the server_def that enables collective ops. +// This is different to the above function in that it doesn't create remote +// contexts, and remotely executing ops is not possible. It just enables +// communication for collective ops. +TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Aborts all ongoing collectives with the specified status. After abortion, +// subsequent collectives will error with this status immediately. To reset the +// collectives, create a new EagerContext. +// +// This is intended to be used when a peer failure is detected. +TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, + TF_Status* status); + +// Checks the health of collective ops peers. Explicit health check is needed in +// multi worker collective ops to detect failures in the cluster. If a peer is +// down, collective ops may hang. +TF_CAPI_EXPORT extern void TFE_CollectiveOpsCheckPeerHealth( + TFE_Context* ctx, const char* task, int64_t timeout_in_ms, + TF_Status* status); + +// Information about the shape of a Tensor and its type. +struct TF_ShapeAndType { + // Number of dimensions. -1 indicates unknown rank. + int num_dims; + // Array of dimensions. -1 indicates unknown dim. + int64_t* dims; + // The data type. May be 0 to denote unknown type. + TF_DataType dtype; +}; + +typedef struct TF_ShapeAndType TF_ShapeAndType; + +// A list of TF_ShapeAndType elements.. +struct TF_ShapeAndTypeList { + int num_items; + TF_ShapeAndType* items; +}; +typedef struct TF_ShapeAndTypeList TF_ShapeAndTypeList; + +// API for manipulating TF_ShapeAndTypeList objects. +// +TF_CAPI_EXPORT extern TF_ShapeAndTypeList* TF_NewShapeAndTypeList( + int num_shapes); +TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetShape( + TF_ShapeAndTypeList* shape_list, int index, const int64_t* dims, + int num_dims); +TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetUnknownShape( + TF_ShapeAndTypeList* shape_list, int index); +TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetDtype( + TF_ShapeAndTypeList* shape_list, int index, TF_DataType dtype); +TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeList( + TF_ShapeAndTypeList* shape_list); +TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray( + TF_ShapeAndTypeList** shape_list_array, int num_items); + +// Infer shapes for the given `op`. The arguments mimic the arguments of the +// `shape_inference::InferenceContext` constructor. Note the following: +// - The inputs of the `op` are not used for shape inference. So, it is +// OK to not have the inputs properly set in `op`. See `input_tensors` +// if you want shape inference to consider the input tensors of the +// op for shape inference. +// - The types need not be set in `input_shapes` as it is not used. +// - The number of `input_tensors` should be the same as the number of items +// in `input_shapes`. +// +// The results are returned in `output_shapes` and +// `output_resource_shapes_and_types`. The caller is responsible for freeing the +// memory in these buffers by calling `TF_DeleteShapeAndTypeList`. 
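+//
+// Illustrative usage sketch (non-normative): infer the output shape of an op
+// with a single [2, 3] input. `op` (a TFE_Op*) and `status` are assumed to
+// already exist; unused in/out parameters are passed as NULL here.
+//
+//   TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(1);
+//   const int64_t dims[2] = {2, 3};
+//   TF_ShapeAndTypeListSetShape(input_shapes, 0, dims, 2);
+//   TF_ShapeAndTypeList* output_shapes = NULL;
+//   TFE_InferShapes(op, input_shapes, /*input_tensors=*/NULL, NULL, NULL,
+//                   &output_shapes, NULL, status);
+//   TF_DeleteShapeAndTypeList(input_shapes);
+//   TF_DeleteShapeAndTypeList(output_shapes);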
+TF_CAPI_EXPORT extern void TFE_InferShapes( + TFE_Op* op, TF_ShapeAndTypeList* input_shapes, TF_Tensor** input_tensors, + TF_ShapeAndTypeList* input_tensor_as_shapes, + TF_ShapeAndTypeList** input_resource_shapes_and_types, + TF_ShapeAndTypeList** output_shapes, + TF_ShapeAndTypeList*** output_resource_shapes_and_types, TF_Status* status); + +TF_CAPI_EXPORT extern void +TF_ImportGraphDefOptionsSetValidateColocationConstraints( + TF_ImportGraphDefOptions* opts, unsigned char enable); + +// Load the library specified by library_filename and register the pluggable +// device and related kernels present in that library. This function is not +// supported on embedded on mobile and embedded platforms and will fail if +// called. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, returns the newly created library handle and places OK in status. +// The caller owns the library handle. +// +// On failure, returns nullptr and places an error status in status. +TF_CAPI_EXPORT extern TF_Library* TF_LoadPluggableDeviceLibrary( + const char* library_filename, TF_Status* status); + +// Frees the memory associated with the library handle. +// Does NOT unload the library. +TF_CAPI_EXPORT extern void TF_DeletePluggableDeviceLibraryHandle( + TF_Library* lib_handle); + +// Removes `func_name` from `g`. If `func_name` is not in `g`, an error will be +// returned. +TF_CAPI_EXPORT extern void TF_GraphRemoveFunction(TF_Graph* g, + const char* func_name, + TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_C_API_EXPERIMENTAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/c_api_internal.h b/third_party/tflite-hdrs/tensorflow/c/c_api_internal.h new file mode 100644 index 00000000..15d279b6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/c_api_internal.h @@ -0,0 +1,227 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_INTERNAL_H_ +#define TENSORFLOW_C_C_API_INTERNAL_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" + +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on + +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/c/tf_tensor_internal.h" +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/core/framework/op_gen_lib.h" +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +class Device; +class DeviceMgr; +class ServerInterface; +} // namespace tensorflow + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. + +struct TF_SessionOptions { + tensorflow::SessionOptions options; +}; + +struct TF_DeprecatedSession { + tensorflow::Session* session; +}; + +struct TF_Library { + void* lib_handle; + TF_Buffer op_list; +}; + +struct TF_Graph { + TF_Graph(); + + mutable tensorflow::mutex mu; + tensorflow::Graph graph TF_GUARDED_BY(mu); + + // Runs shape inference. + tensorflow::ShapeRefiner refiner TF_GUARDED_BY(mu); + + // Maps from name of an operation to the Node* in 'graph'. + std::unordered_map name_map + TF_GUARDED_BY(mu); + + // The keys of this map are all the active sessions using this graph. Each + // value records whether the graph has been mutated since the corresponding + // session has been run (this is detected in RecordMutation function). If the + // string is empty, no mutation has occurred. Otherwise the string is a + // description of the mutation suitable for returning to the user. + // + // Sessions are added to this map in TF_NewSession, and removed in + // TF_DeleteSession. + // TF_Graph may only / must be deleted when + // sessions.size() == 0 && delete_requested == true + // + // TODO(b/74949947): mutations currently trigger a warning instead of a bad + // status, this should be reverted when possible. + tensorflow::gtl::FlatMap sessions + TF_GUARDED_BY(mu); + bool delete_requested TF_GUARDED_BY(mu); // set true by TF_DeleteGraph + + // Used to link graphs contained in TF_WhileParams to the parent graph that + // will eventually contain the full while loop. 
+ TF_Graph* parent; + TF_Output* parent_inputs; +}; + +struct TF_OperationDescription { + TF_OperationDescription(TF_Graph* g, const char* op_type, + const char* node_name) + : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} + + tensorflow::NodeBuilder node_builder; + TF_Graph* graph; + std::set colocation_constraints; +}; + +struct TF_Operation { + tensorflow::Node node; + + private: + ~TF_Operation() = default; +}; + +struct TF_Session { + TF_Session(tensorflow::Session* s, TF_Graph* g); + + tensorflow::Session* session; + TF_Graph* const graph; + + tensorflow::mutex mu TF_ACQUIRED_AFTER(TF_Graph::mu); + int last_num_graph_nodes; + + // If true, TF_SessionRun and similar methods will call + // ExtendSessionGraphHelper before running the graph (this is the default + // public behavior). Can be set to false if the caller needs to call + // ExtendSessionGraphHelper manually. + std::atomic extend_before_run; +}; + +struct TF_ImportGraphDefOptions { + tensorflow::ImportGraphDefOptions opts; + + // Backing memory for TensorId fields in opts. + // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. + std::vector tensor_id_data; +}; + +struct TF_ImportGraphDefResults { + std::vector return_tensors; + std::vector return_nodes; + std::vector missing_unused_key_names; + std::vector missing_unused_key_indexes; + + // Backing memory for missing_unused_key_names values. + std::vector missing_unused_key_names_data; +}; + +struct TF_DeviceList { + std::vector response; +}; + +struct TF_Function { + tensorflow::FunctionRecord* record; +}; + +struct TF_ApiDefMap { + explicit TF_ApiDefMap(const tensorflow::OpList& op_list) + : +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + api_def_map(op_list), +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + update_docs_called(false) { + } + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + tensorflow::ApiDefMap api_def_map TF_GUARDED_BY(lock); +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + bool update_docs_called TF_GUARDED_BY(lock); + tensorflow::mutex lock; +}; + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +struct TF_Server { + TF_Server(std::unique_ptr server); + + const tensorflow::string target; + std::unique_ptr server; +}; +#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) + +namespace tensorflow { + +// Set the shapes and types of the output's handle. +// +// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must +// all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the +// rank is known), then it must be equal to the length of `shapes[i]`; if +// `ranks[i] == 1`, then `shapes[i]` may be nullptr. +// +// TODO(akshayka): Implement a corresponding getter method. +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status); + +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) + TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu); + +bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) + TF_LOCKS_EXCLUDED(session->graph->mu, session->mu); + +std::string getTF_OutputDebugString(TF_Output node); + +// Set whether to propagate assigned device information when constructing a new +// Graph from a GraphDef. By default assigned device information is not copied +// and is re-computed by the runtime. 
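+//
+// Illustrative usage sketch (non-normative); `graph`, `graph_def` and `status`
+// are assumed to already exist:
+//
+//   TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();
+//   tensorflow::TF_ImportGraphDefOptionsSetPropagateDeviceSpec(opts, 1);
+//   TF_GraphImportGraphDef(graph, graph_def, opts, status);
+//   TF_DeleteImportGraphDefOptions(opts);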
+inline void TF_ImportGraphDefOptionsSetPropagateDeviceSpec( + TF_ImportGraphDefOptions* opts, unsigned char propagate_device_spec) { + opts->opts.propagate_device_spec = propagate_device_spec; +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/c_api_macros.h b/third_party/tflite-hdrs/tensorflow/c/c_api_macros.h new file mode 100644 index 00000000..d73546ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/c_api_macros.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_MACROS_H_ +#define TENSORFLOW_C_C_API_MACROS_H_ + +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(_WIN32) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#ifdef TF_CAPI_WEAK +#define TF_CAPI_EXPORT \ + __attribute__((visibility("default"))) __attribute((weak)) +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // TF_CAPI_WEAK +#endif // _WIN32 +#endif // SWIG + +// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is +// the datatype for boolean tensors. +#ifndef TF_Bool +#define TF_Bool unsigned char +#endif // TF_Bool + +// Macro used to calculate struct size for maintaining ABI stability across +// different struct implementations. +#ifndef TF_OFFSET_OF_END +#define TF_OFFSET_OF_END(TYPE, MEMBER) \ + (offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER)) +#endif // TF_OFFSET_OF_END + +#endif // TENSORFLOW_C_C_API_MACROS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/c_api_macros_internal.h b/third_party/tflite-hdrs/tensorflow/c/c_api_macros_internal.h new file mode 100644 index 00000000..b2bc61d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/c_api_macros_internal.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_MACROS_INTERNAL_H_ +#define TENSORFLOW_C_C_API_MACROS_INTERNAL_H_ + +#ifdef __cplusplus +#include "tensorflow/core/platform/status.h" + +// Macro to verify that the field `struct_size` of STRUCT_OBJ is initialized. 
+// `struct_size` is used for struct member compatibility check between core TF +// and plug-ins with the same C API minor version. More info here: +// https://github.com/tensorflow/community/blob/master/rfcs/20200612-stream-executor-c-api/C_API_versioning_strategy.md +#define TF_VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \ + do { \ + if (STRUCT_OBJ.struct_size == 0) { \ + return tensorflow::Status(absl::StatusCode::kFailedPrecondition, \ + "Expected initialized `" #STRUCT_NAME \ + "` structure with `struct_size` field " \ + "set to " #SIZE_VALUE_NAME \ + ". Found `struct_size` = 0."); \ + } \ + } while (0) + +// Macro to verify that the field NAME of STRUCT_OBJ is not null. +#define TF_VALIDATE_NOT_NULL(STRUCT_NAME, STRUCT_OBJ, NAME) \ + do { \ + if (STRUCT_OBJ.NAME == 0) { \ + return tensorflow::Status(absl::StatusCode::kFailedPrecondition, \ + "'" #NAME "' field in " #STRUCT_NAME \ + " must be set."); \ + } \ + } while (0) + +#endif // __cplusplus +#endif // TENSORFLOW_C_C_API_MACROS_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/c_op_requires.h b/third_party/tflite-hdrs/tensorflow/c/c_op_requires.h new file mode 100644 index 00000000..1a515bb1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/c_op_requires.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_OP_REQUIRES_H_ +#define TENSORFLOW_C_C_OP_REQUIRES_H_ + +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// Convenience macros for asserting and handling exceptional conditions, for +// C structs, including `TF_OpKernelContext`, `TF_Status`, etc. This is analogus +// to the macros in tensorflow/core/framework/op_requires.h. This is provided +// for plugin OpKernel developer's convenience. + +#define C_OPKERNELCONTEXT_REQUIRES_OK(CTX, C_STATUS, __VA_ARGS__) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + ::tensorflow::Set_TF_Status_from_Status(C_STATUS, _s); \ + TF_OpKernelContext_Failure(CTX, C_STATUS); \ + TF_DeleteStatus(C_STATUS); \ + return; \ + } \ + } while (0) + +#define TF_CLEANUP_AND_RETURN_IF_ERROR(C_STATUS, BUFFER, __VA_ARGS__) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (TF_PREDICT_TRUE(!_s.ok())) { \ + TF_DeleteStatus(C_STATUS); \ + TF_DeleteBuffer(BUFFER); \ + return _s; \ + } \ + } while (0) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_C_OP_REQUIRES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/c_test_util.h b/third_party/tflite-hdrs/tensorflow/c/c_test_util.h new file mode 100644 index 00000000..7eeb1ee5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/c_test_util.h @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_TEST_UTIL_H_ +#define TENSORFLOW_C_C_TEST_UTIL_H_ + +#include "tensorflow/c/c_api.h" + +#include +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/test.h" + +using ::tensorflow::string; + +typedef std::unique_ptr + unique_tensor_ptr; + +TF_Tensor* BoolTensor(int32_t v); + +// Create a tensor with values of type TF_INT8 provided by `values`. +TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values); + +// Create a tensor with values of type TF_INT32 provided by `values`. +TF_Tensor* Int32Tensor(const int64_t* dims, int num_dims, + const int32_t* values); + +// Create 1 dimensional tensor with values from `values` +TF_Tensor* Int32Tensor(const std::vector& values); + +TF_Tensor* Int32Tensor(int32_t v); + +TF_Tensor* DoubleTensor(double v); + +TF_Tensor* FloatTensor(float v); + +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, + const char* name = "feed", + TF_DataType dtype = TF_INT32, + const std::vector& dims = {}); + +TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, + const char* name = "const"); + +TF_Operation* ScalarConst(bool v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + +TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + +TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + +TF_Operation* ScalarConst(float v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + +TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "add"); + +TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "add"); + +TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Operation* ctrl_op, + TF_Status* s, const char* name = "add"); + +TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, + const char* name = "add"); + +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "min"); + +TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "mul"); + +// If `op_device` is non-empty, set the created op on that device. 
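+//
+// Illustrative usage sketch (non-normative); `l`, `r`, `graph` and `s` are
+// assumed to already exist:
+//
+//   TF_Operation* m = MinWithDevice(
+//       l, r, graph, "/job:localhost/replica:0/task:0/device:CPU:0", s);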
+TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + const string& op_device, TF_Status* s, + const char* name = "min"); + +TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, + const char* name = "neg"); + +TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); + +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s); + +// Split `input` along the first dimension into 3 tensors +TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, + const char* name = "split3"); + +bool IsPlaceholder(const tensorflow::NodeDef& node_def); + +bool IsScalarConst(const tensorflow::NodeDef& node_def, int v); + +bool IsAddN(const tensorflow::NodeDef& node_def, int n); + +bool IsNeg(const tensorflow::NodeDef& node_def, const string& input); + +bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def); + +bool GetNodeDef(TF_Operation* oper, tensorflow::NodeDef* node_def); + +bool GetFunctionDef(TF_Function* func, tensorflow::FunctionDef* func_def); + +bool GetAttrValue(TF_Operation* oper, const char* attr_name, + tensorflow::AttrValue* attr_value, TF_Status* s); + +// Returns a sorted vector of std::pair from +// graph_def.library().gradient() +std::vector> GetGradDefs( + const tensorflow::GraphDef& graph_def); + +// Returns a sorted vector of names contained in `grad_def` +std::vector GetFuncNames(const tensorflow::GraphDef& graph_def); + +class CSession { + public: + CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false); + explicit CSession(TF_Session* session); + + ~CSession(); + + void SetInputs(std::vector> inputs); + void SetOutputs(std::initializer_list outputs); + void SetOutputs(const std::vector& outputs); + void SetTargets(std::initializer_list targets); + + void Run(TF_Status* s); + + void CloseAndDelete(TF_Status* s); + + TF_Tensor* output_tensor(int i) { return output_values_[i]; } + + TF_Session* mutable_session() { return session_; } + + private: + void DeleteInputValues(); + void ResetOutputValues(); + + TF_Session* session_; + std::vector inputs_; + std::vector input_values_; + std::vector outputs_; + std::vector output_values_; + std::vector targets_; +}; + +#endif // TENSORFLOW_C_C_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/checkpoint_reader.h b/third_party/tflite-hdrs/tensorflow/c/checkpoint_reader.h new file mode 100644 index 00000000..75008ffa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/checkpoint_reader.h @@ -0,0 +1,83 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_CHECKPOINT_READER_H_ +#define TENSORFLOW_C_CHECKPOINT_READER_H_ + +#include +#include + +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" +#include "tensorflow/core/util/tensor_slice_reader.h" + +namespace tensorflow { +namespace checkpoint { + +class TensorSliceReader; + +// A wrapper around BundleReader (for V2 checkpoints) and +// checkpoint::TensorSliceReader (for V1), that is more easily SWIG wrapped for +// other languages. +// +// The class currently only interacts with single-slice (i.e., non-partitioned) +// variables. +class CheckpointReader { + public: + CheckpointReader(const string& filename, TF_Status* status); + + bool HasTensor(const string& name) const; + const string DebugString() const; + + // Returns a map from variable names to their shapes. Slices of a partitioned + // tensor are combined into a single entry. + const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const; + + // Returns a map from variable names to their data types. Slices of a + // partitioned tensor are combined into a single entry. + const TensorSliceReader::VarToDataTypeMap& GetVariableToDataTypeMap() const; + + // Attempts to look up the tensor named "name" and stores the found result in + // "out_tensor". + void GetTensor(const string& name, + std::unique_ptr* out_tensor, + TF_Status* out_status) const; + + private: + // Uses "v2_reader_" to build "var name -> shape" and "var name -> data type" + // maps; both owned by caller. + // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()". + std::pair, + std::unique_ptr > + BuildV2VarMaps(); + + // Invariant: exactly one of "reader_" and "v2_reader_" is non-null. + std::unique_ptr reader_; + std::unique_ptr v2_reader_; + + std::unique_ptr var_to_shape_map_; + std::unique_ptr var_to_data_type_map_; + + CheckpointReader(const CheckpointReader&) = delete; + void operator=(const CheckpointReader&) = delete; +}; + +} // namespace checkpoint +} // namespace tensorflow + +#endif // TENSORFLOW_C_CHECKPOINT_READER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/conversion_macros.h b/third_party/tflite-hdrs/tensorflow/c/conversion_macros.h new file mode 100644 index 00000000..d1f99b7b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/conversion_macros.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_CONVERSION_MACROS_H_ +#define TENSORFLOW_C_CONVERSION_MACROS_H_ + +#define DEFINE_CONVERSION_FUNCTIONS(cpp_impl, wrapper) \ + inline cpp_impl *unwrap(wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline const cpp_impl *unwrap(const wrapper *w) { \ + return reinterpret_cast(w); \ + } \ + \ + inline wrapper *wrap(cpp_impl *i) { return reinterpret_cast(i); } \ + inline const wrapper *wrap(const cpp_impl *i) { \ + return reinterpret_cast(i); \ + } + +#endif // TENSORFLOW_C_CONVERSION_MACROS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/abstract_context.h b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_context.h new file mode 100644 index 00000000..4bf6ff9b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_context.h @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ + +#include + +#include "tensorflow/c/eager/abstract_function.h" +#include "tensorflow/c/eager/abstract_operation.h" + +namespace tensorflow { + +// Abstract interface to a context. +// +// This serves as a factory for creating `AbstractOperation`s and for +// registering traced functions. +// Operations creation within a context can only be executed in that context +// (for now at least). +// Implementations of the context may contain some state e.g. an execution +// environment, a traced representation etc. +class AbstractContext { + protected: + enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler }; + explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {} + virtual ~AbstractContext() {} + + public: + AbstractContextKind getKind() const { return kind_; } + + // Release any underlying resources, including the interface object. + // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage its own + // lifetime through ref counting. Thus clients MUST call Release() in order to + // destroy an instance of this class. + virtual void Release() = 0; + + // Creates an operation builder and ties it to this context. + // The returned object can be used for setting operation's attributes, + // adding inputs and finally executing (immediately or lazily as in tracing) + // it in this context. + virtual AbstractOperation* CreateOperation() = 0; + + // Registers a function with this context, after this the function is + // available to be called/referenced by its name in this context. + virtual absl::Status RegisterFunction(AbstractFunction*) = 0; + // Remove a function. 'func' argument is the name of a previously added + // FunctionDef. The name is in fdef.signature.name. 
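+  //
+  // Illustrative sketch (non-normative); `ctx` is an AbstractContext* and
+  // `fn` an AbstractFunction* whose FunctionDef is named "my_fn":
+  //
+  //   absl::Status s = ctx->RegisterFunction(fn);
+  //   /* ... trace or execute operations that call "my_fn" ... */
+  //   s = ctx->RemoveFunction("my_fn");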
+ virtual absl::Status RemoveFunction(const string& func) = 0; + + private: + const AbstractContextKind kind_; +}; + +namespace internal { +struct AbstractContextDeleter { + void operator()(AbstractContext* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using AbstractContextPtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/abstract_function.h b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_function.h new file mode 100644 index 00000000..7bc8f8bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_function.h @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ + +#include "absl/status/statusor.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/intrusive_ptr.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class FunctionRecord; + +// A traced function: this hides the complexity of converting the serialized +// representation between various supported formats e.g. FunctionDef and Mlir +// function. +class AbstractFunction : public core::RefCounted { + protected: + enum AbstractFunctionKind { kGraph, kMlir }; + explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {} + + public: + // Returns which subclass is this instance of. + AbstractFunctionKind getKind() const { return kind_; } + + // Returns the AbstractFunction as a FunctionDef. + virtual absl::Status GetFunctionDef(const FunctionDef**) = 0; + + // Returns a shared reference to the wrapped function. + virtual absl::StatusOr> + GetFunctionRecord() = 0; + + private: + const AbstractFunctionKind kind_; +}; + +using AbstractFunctionPtr = + tensorflow::core::IntrusivePtr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/abstract_op_attrs.h b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_op_attrs.h new file mode 100644 index 00000000..e799552a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_op_attrs.h @@ -0,0 +1,54 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OP_ATTRS_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_OP_ATTRS_H_ + +#include "absl/container/inlined_vector.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Attributes of an op. +class AbstractOpAttrs { + protected: + enum AbstractOpAttrsKind { kEager, kTfrt }; + explicit AbstractOpAttrs(AbstractOpAttrsKind kind) : kind_(kind) {} + + public: + // Returns which subclass is this instance of. + AbstractOpAttrsKind getKind() const { return kind_; } + virtual ~AbstractOpAttrs() = default; + + // Returns the AbstractFunction as a FunctionDef. + virtual void GetNameAttrList( + tensorflow::NameAttrList* name_and_attrs) const = 0; + + virtual bool GetInt(absl::string_view, int64_t* result) const = 0; + virtual bool GetFloat(absl::string_view attr_name, float* result) const = 0; + virtual bool GetBool(absl::string_view attr_name, bool* result) const = 0; + virtual bool GetType(absl::string_view attr_name, DataType* result) const = 0; + virtual absl::Status GetTypeList( + absl::string_view attr_name, + absl::InlinedVector* type_list) const = 0; + + private: + const AbstractOpAttrsKind kind_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_OP_ATTRS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/abstract_operation.h b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_operation.h new file mode 100644 index 00000000..95142210 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_operation.h @@ -0,0 +1,172 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Abstract interface to an operation. +// This interface allows building and executing an operation in either +// tracing or immediate execution mode. +class AbstractOperation { + protected: + enum AbstractOperationKind { + kGraph, + kMlir, + kEager, + kTfrt, + kTape, + kOpHandler + }; + explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} + virtual ~AbstractOperation() {} + + public: + AbstractOperationKind getKind() const { return kind_; } + + // Release any underlying resources, including the interface object. 
+ // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage it's own + // lifetime through ref counting. Thus this must be allocated on the heap and + // clients MUST call Release() in order to destroy an instance of this class. + virtual void Release() = 0; + + virtual absl::Status Reset(const char* op, const char* raw_device_name) = 0; + + virtual const string& Name() const = 0; + + // Returns the operation's device name. + // + // The value returned may be different from the one set by SetDeviceName, but + // it will be compatible with it: the name will be updated by device placement + // logic to refer to the specific device chosen. + // + // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value + // returned by DeviceName should be "/device:GPU:*" until a particular GPU is + // chosen for the operation by the device placement logic in the + // executor. After that, the value returned by DeviceName will be a full + // device name such as "/job:localhost/replica:0/task:0/device:GPU:1". + virtual const string& DeviceName() const = 0; + + // Sets the operation device name. + // + // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and + // the result will be used as a constraint for device placement. See the + // documentation for DeviceName for more details. + // + // The value will override the previous value - that is, no "merging" of + // existing and given constraints will be performed. + virtual absl::Status SetDeviceName(const char* name) = 0; + + virtual absl::Status AddInput(AbstractTensorHandle* input) = 0; + virtual absl::Status AddInputList( + absl::Span inputs) = 0; + virtual absl::Status Execute(absl::Span retvals, + int* num_retvals) = 0; + + virtual absl::Status SetAttrString(const char* attr_name, const char* data, + size_t length) = 0; + virtual absl::Status SetAttrInt(const char* attr_name, int64_t value) = 0; + virtual absl::Status SetAttrFloat(const char* attr_name, float value) = 0; + virtual absl::Status SetAttrBool(const char* attr_name, bool value) = 0; + virtual absl::Status SetAttrType(const char* attr_name, DataType value) = 0; + virtual absl::Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) = 0; + virtual absl::Status SetAttrShape(const char* attr_name, + const PartialTensorShape shape); + virtual absl::Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) = 0; + virtual absl::Status SetAttrFunctionName(const char* attr_name, + const char* value, + size_t length) = 0; + virtual absl::Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) = 0; + virtual absl::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) = 0; + virtual absl::Status SetAttrStringList(const char* attr_name, + absl::Span values); + virtual absl::Status SetAttrFloatList(const char* attr_name, + const float* values, + int num_values) = 0; + virtual absl::Status SetAttrIntList(const char* attr_name, + const int64_t* values, + int num_values) = 0; + virtual absl::Status SetAttrTypeList(const char* attr_name, + const DataType* values, + int num_values) = 0; + virtual absl::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) = 0; + virtual absl::Status SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, + int num_values) = 0; + virtual 
absl::Status SetAttrFunctionList( + const char* attr_name, absl::Span values) = 0; + + private: + const AbstractOperationKind kind_; +}; + +// TODO(b/193656009): Defining these in a cc file causes linker errors with +// fastbuild. +inline absl::Status AbstractOperation::SetAttrShape( + const char* attr_name, const PartialTensorShape shape) { + return SetAttrShape(attr_name, shape.dim_sizes().data(), shape.dims()); +} + +inline absl::Status AbstractOperation::SetAttrStringList( + const char* attr_name, absl::Span values) { + std::vector raw_strs; + std::vector lengths; + raw_strs.reserve(values.size()); + lengths.reserve(values.size()); + for (const auto& s : values) { + raw_strs.emplace_back(s.data()); + lengths.emplace_back(s.size()); + } + return SetAttrStringList(attr_name, + reinterpret_cast(raw_strs.data()), + lengths.data(), values.size()); +} + +namespace internal { +struct AbstractOperationDeleter { + void operator()(AbstractOperation* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using AbstractOperationPtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/abstract_tensor_handle.h b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_tensor_handle.h new file mode 100644 index 00000000..4a40b1c9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/abstract_tensor_handle.h @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ + +#include + +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Abstract interface to a Tensor handle in either tracing or immediate +// execution mode. +class AbstractTensorHandle : public core::RefCounted { + protected: + enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt, kCustomDevice }; + explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} + ~AbstractTensorHandle() override {} + + public: + // Returns tensor dtype. + virtual tensorflow::DataType DataType() const = 0; + + // Returns the status of the tensor handle. If it is a tfrt::TensorHandle, + // the tensor handle can be an error and return non-OK status. + virtual absl::Status TensorHandleStatus() const; + + // Returns tensor shape. If tensor has unknown rank, shape remains untouched. + virtual absl::Status Shape(tensorflow::PartialTensorShape* shape) const = 0; + + // Returns tensor (full) type. 
+ // While there is no immediate plan to deprecate dtype and shape in favor + // of only using full type type information, this is a future possibility. + // + // Note that map_dtype_to_child_of_tensor() from core/framework/types.h + // can be used to set a FullTypeDef based on dtype in a derived class if + // appropriate. + virtual tensorflow::FullTypeDef FullType() const = 0; + + // The default debug string includes a shape, dtype and FullType. + // Implementations are free to override it with something more informative. + virtual std::string DebugString() const; + + AbstractTensorHandleKind getKind() const { return kind_; } + + private: + const AbstractTensorHandleKind kind_; +}; + +namespace internal { +struct AbstractTensorHandleDeleter { + void operator()(AbstractTensorHandle* p) const { + if (p != nullptr) { + p->Unref(); + } + } +}; +} // namespace internal + +// TODO(b/185908092): Make AbstractTensorHandlePtr an IntrusivePtr. +using AbstractTensorHandlePtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api.h new file mode 100644 index 00000000..7f458ac5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api.h @@ -0,0 +1,448 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_C_API_H_ +#define TENSORFLOW_C_EAGER_C_API_H_ + +// C API extensions to experiment with eager execution of kernels. +// WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be +// stable and can change without notice. + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TFE_ContextOptions TFE_ContextOptions; + +// Return a new options object. +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(void); + +// Set the config in TF_ContextOptions.options. +// config should be a serialized tensorflow.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + +// Controls how to act when we try to run an operation on a given device but +// some input tensors are not on that device. +// LINT.IfChange +// Note: Keep in sync with internal copy of enum in eager/context.h. +typedef enum TFE_ContextDevicePlacementPolicy { + // Running operations with input tensors on the wrong device will fail. + TFE_DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + TFE_DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. 
This is the default placement + // policy. + TFE_DEVICE_PLACEMENT_SILENT = 2, + // Placement policy which silently copies int32 tensors but not other dtypes. + TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, +} TFE_ContextDevicePlacementPolicy; +// LINT.ThenChange(//tensorflow/c/eager/immediate_execution_context.h) + +// Sets the default execution mode (sync/async). Note that this can be +// overridden per thread using TFE_ContextSetExecutorForThread. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, + unsigned char enable); + +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( + TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); + +// Destroy an options object. +TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); + +// "Context" under which operations/functions are executed. It encapsulates +// things like the available devices, resource manager etc. +// TFE_Context must outlive all tensor handles created using it. In other +// words, TFE_DeleteContext() must be called after all tensor handles have +// been deleted (with TFE_DeleteTensorHandle). +// +// TODO(ashankar): Merge with TF_Session? +typedef struct TFE_Context TFE_Context; + +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( + const TFE_ContextOptions* opts, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx); +TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, + TF_Status* status); + +// Clears the internal caches in the TFE context. Useful when reseeding random +// ops. +TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx); + +// Sets a thread-local device placement policy. After this call, other calls to +// TFE_Execute in the same thread will use the device policy specified here +// instead of the device policy used to construct the context. This has no +// effect on the device policy used by other program threads. +TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( + TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy); + +// Returns the device placement policy to be used by this context in the current +// thread. +TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy +TFE_ContextGetDevicePlacementPolicy(TFE_Context* ctx); + +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created in this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + +// A handle to a tensor on a device. +// +// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, +// type etc. Unlike a TF_Tensor, a TFE_TensorHandle may refer to such tensors +// placed in the memory of different devices or remote address spaces. +typedef struct TFE_TensorHandle TFE_TensorHandle; + +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(const TF_Tensor* t, + TF_Status* status); +// Indicates that the caller will not be using `h` any more. +TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); +// This function will block till the operation that produces `h` has completed. 
+TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, + TF_Status* status); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, + TF_Status* status); +// This function will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, + int dim_index, + TF_Status* status); + +// Returns the device of the operation that produced `h`. If `h` was produced by +// a copy, returns the destination device of the copy. Note that the returned +// device name is not always the device holding the tensor handle's memory. If +// you want the latter, use TFE_TensorHandleBackingDeviceName. This function +// will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( + TFE_TensorHandle* h, TF_Status* status); + +// Returns the name of the device in whose memory `h` resides. +// +// This function will block till the operation that produces `h` has completed. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName( + TFE_TensorHandle* h, TF_Status* status); + +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + +// This function will block till the operation that produces `h` has +// completed. The memory returned might alias the internal memory used by +// TensorFlow. Hence, callers should not mutate this memory (for example by +// modifying the memory region pointed to by TF_TensorData() on the returned +// TF_Tensor). +TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, + TF_Status* status); + +// Create a new TFE_TensorHandle with the same contents as 'h' but placed +// in the memory of the device name 'device_name'. +// If source and destination are the same device, then this creates a new handle +// that shares the underlying buffer. Otherwise, it currently requires at least +// one of the source or destination devices to be CPU (i.e., for the source or +// destination tensor to be placed in host memory). +// If async execution is enabled, the copy may be enqueued and the call will +// return "non-ready" handle. Else, this function returns after the copy has +// been done. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( + TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, + TF_Status* status); + +// Debugging/Profiling information for TFE_TensorHandle +// +// TFE_TensorDebugInfo contains information useful for debugging and +// profiling tensors. +typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo; + +// Retrieves TFE_TensorDebugInfo for `handle`. +// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller +// is responsible for deleting returned TFE_TensorDebugInfo. +// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate +// error and nullptr is returned. This function can block till the operation +// that produces `handle` has completed. +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* h, TF_Status* status); + +// Deletes `debug_info`. 
+TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of dimensions used to represent the tensor on its device. +// The number of dimensions used to represent the tensor on device can be +// different from the number returned by TFE_TensorHandleNumDims. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of elements in dimension `dim_index`. +// Tensor representation on device can be transposed from its representation +// on host. The data contained in dimension `dim_index` on device +// can correspond to the data contained in another dimension in on-host +// representation. The dimensions are indexed using the standard TensorFlow +// major-to-minor order (slowest varying dimension first), +// not the XLA's minor-to-major order. +// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns +// the number of elements in a dimension after padding. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index); + +// Description of the TensorFlow op to execute. +// +// Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e., +// TFE_DeleteOp() is called before TFE_DeleteContext(). +// +// Very similar to TF_OperationDescription with some differences: +// (1) TF_Output or TFE_TensorHandle* as arguments to TF_AddInput, +// TF_AddInputList +// (2) TF_ColocateWith, TF_AddControlInput etc. do not make sense. +// (3) Implementation detail: Avoid use of NodeBuilder/NodeDefBuilder since +// the additional sanity checks there seem unnecessary; +typedef struct TFE_Op TFE_Op; + +TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, + const char* op_or_function_name, + TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); + +// Returns the op or function name `op` will execute. +// +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetName(const TFE_Op* op, + TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_OpGetContext(const TFE_Op* op, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, + TF_Status* status); +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(const TFE_Op* op, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_OpAddInputList(TFE_Op* op, + TFE_TensorHandle** inputs, + int num_inputs, + TF_Status* status); + +// Fetches the current number of inputs attached to `op`. +// +// Does not use the operation's definition to determine how many inputs should +// be attached. It is intended for use with TFE_OpGetFlatInput to inspect an +// already-finalized operation. +// +// Note that TFE_OpGetFlatInputCount and TFE_OpGetFlatInput operate on a flat +// sequence of inputs, unlike TFE_OpGetInputLength (for getting the length of a +// particular named input list, which may only be part of the op's inputs). +TF_CAPI_EXPORT extern int TFE_OpGetFlatInputCount(const TFE_Op* op, + TF_Status* status); +// Returns a borrowed reference to one of `op`'s inputs. 
Use +// `TFE_TensorHandleCopySharingTensor` to make a new reference. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_OpGetFlatInput(const TFE_Op* op, + int index, + TF_Status* status); + +TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, + const char* attr_name, + unsigned char* is_list, + TF_Status* status); +// Get an attribute type given an op name; a fusion of TFE_NewOp and +// TFE_OpGetAttrType for use from Python without the overhead of the individual +// calls and memory management of TFE_Op. +TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( + TFE_Context* ctx, const char* op_or_function_name, const char* attr_name, + unsigned char* is_list, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, + const char* attr_name, + const void* value, + size_t length); +TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, + int64_t value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, + float value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, + unsigned char value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, + TF_DataType value); +// If the number of dimensions is unknown, `num_dims` must be set to +// -1 and `dims` can be null. If a dimension is unknown, the +// corresponding entry in the `dims` array must be -1. +TF_CAPI_EXPORT extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, + const int64_t* dims, + const int num_dims, + TF_Status* out_status); + +// Sets the attribute attr_name to be a function specified by 'function'. +// +// TODO(ashankar,iga): Add this functionality to the C API for graph +// construction. Perhaps we want an AttrValueMap equivalent in the C API? +TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, + const char* attr_name, + const TFE_Op* value); + +TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length); + +TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op, + const char* attr_name, + TF_Tensor* tensor, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, + const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, + const char* attr_name, + const TF_DataType* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList( + TFE_Op* op, const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values, TF_Status* out_status); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, + const char* attr_name, + const TFE_Op** value, + int num_values); + +// Returns the length (number of tensors) of the input argument `input_name` +// found in the provided `op`. +TF_CAPI_EXPORT extern int TFE_OpGetInputLength(TFE_Op* op, + const char* input_name, + TF_Status* status); + +// Returns the length (number of tensors) of the output argument `output_name` +// found in the provided `op`. 
+TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, + const char* output_name, + TF_Status* status); + +// Execute the operation defined by 'op' and return handles to computed +// tensors in `retvals`. +// +// 'retvals' must point to a pre-allocated array of TFE_TensorHandle* and +// '*num_retvals' should be set to the size of this array. It is an error if +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. +// +// If async execution is enabled, the call may simply enqueue the execution +// and return "non-ready" handles in `retvals`. Note that any handles contained +// in 'op' should not be mutated till the kernel execution actually finishes. +// +// For sync execution, if any of the inputs to `op` are not ready, this call +// will block till they become ready and then return when the kernel execution +// is done. +// TODO(agarwal): change num_retvals to int from int*. +TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, + int* num_retvals, TF_Status* status); + +// Add a function (serialized FunctionDef protocol buffer) to ctx so +// that it can be invoked using TFE_Execute. +TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef( + TFE_Context* ctx, const char* serialized_function_def, size_t size, + TF_Status* status); + +// Adds a function (created from TF_GraphToFunction or +// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with +// TFE_Execute by creating an op with the same name as the function. +TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, + TF_Function* function, + TF_Status* status); + +// Removes a function from the context. Once removed, you can no longer +// TFE_Execute it or TFE_Execute any TFE_Op which has it as an attribute or any +// other function which calls it as an attribute. +TF_CAPI_EXPORT extern void TFE_ContextRemoveFunction(TFE_Context* ctx, + const char* name, + TF_Status* status); + +// Checks whether a function is registered under `name`. +TF_CAPI_EXPORT unsigned char TFE_ContextHasFunction(TFE_Context* ctx, + const char* name); + +// Enables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); + +// Disables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); + +// Populates the passed-in buffer with a serialized RunMetadata protocol buffer +// containing any run metadata information accumulated so far and clears this +// information. +// If async mode is enabled, this call blocks till all currently pending ops are +// done. +TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, + TF_Buffer* buf, + TF_Status* status); + +// Some TF ops need a step container to be set to limit the lifetime of some +// resources (mostly TensorArray and Stack, used in while loop gradients in +// graph mode). Calling this on a context tells it to start a step. +TF_CAPI_EXPORT extern void TFE_ContextStartStep(TFE_Context* ctx); + +// Ends a step. When there is no active step (that is, every started step has +// been ended) step containers will be cleared. Note: it is not safe to call +// TFE_ContextEndStep while ops that rely on the step container may be running. 
+TF_CAPI_EXPORT extern void TFE_ContextEndStep(TFE_Context* ctx); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#ifdef __cplusplus +// A workaround to ease conversion to and from numpy objects and +// TFE_TensorHandle's. +// +// TODO(ashankar): Figure out an alternative scheme that precludes the need for +// these API-boundary breaking methods. +namespace tensorflow { +class Tensor; +} // namespace tensorflow + +TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, + TF_Status* status); +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api_experimental.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_experimental.h new file mode 100644 index 00000000..ab50b470 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_experimental.h @@ -0,0 +1,797 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This +// is for performance optimization by reusing an exiting unused op rather than +// creating a new op every time. If `raw_device_name` is `NULL` or empty, it +// does not set the device name. If it's not `NULL`, then it attempts to parse +// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster +// than separately calling it because if the existing op has the same +// `raw_device_name`, it skips parsing and just leave as it is. +TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset, + const char* op_or_function_name, + const char* raw_device_name, + TF_Status* status); + +// Enables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); + +// Disables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx); + +// TODO(fishx): Move these monitoring APIs into a separate file. +// ----------------------------------------------------------------------------- +// Monitoring Counter APIs. +// These APIs de-templated monitoring Counter for swig. + +typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell; + +// Atomically increments the value of the cell. The value must be non-negative. +TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy( + TFE_MonitoringCounterCell* cell, int64_t value); + +// Retrieves the current value of the cell. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue( + TFE_MonitoringCounterCell* cell); + +// APIs for Counter without label. 
+typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0; +// Returns a new Counter metric object. The caller should manage lifetime of +// the object. Using duplicate metric name will crash the program with fatal +// error. +TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0( + const char* name, TF_Status* status, const char* description); +// Deletes the Counter object. +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0( + TFE_MonitoringCounter0* counter); +// Retrieves the cell from the Counter object. The Counter object will manage +// lifetime of the cell. +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0( + TFE_MonitoringCounter0* counter); + +// APIs for Counter with 1 label. +typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1; +TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1( + const char* name, TF_Status* status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1( + TFE_MonitoringCounter1* counter); +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1( + TFE_MonitoringCounter1* counter, const char* label1); + +// APIs for Counter with 2 labels. +typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2; +TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2( + const char* name, TF_Status* status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2( + TFE_MonitoringCounter2* counter); +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2( + TFE_MonitoringCounter2* counter, const char* label1, const char* label2); + +// ----------------------------------------------------------------------------- +// Monitoring Gauge APIs. +// These APIs de-templated monitoring Gauge for swig. + +typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell; + +// Atomically set the value of the cell. +TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet( + TFE_MonitoringIntGaugeCell* cell, int64_t value); + +// Retrieves the current value of the cell. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue( + TFE_MonitoringIntGaugeCell* cell); + +// APIs for Int Gauge without label. +typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0( + TFE_MonitoringIntGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge); + +// APIs for Int Gauge with 1 label. +typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1( + TFE_MonitoringIntGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge, + const char* label1); + +// APIs for Int Gauge with 2 label. 
+typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2( + TFE_MonitoringIntGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge, + const char* label1, const char* label2); + +typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell; +TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet( + TFE_MonitoringStringGaugeCell* cell, const char* value); +// Retrieves the string value and saves it in the buffer. +TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue( + TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf); + +// APIs for String Gauge without label. +typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0( + TFE_MonitoringStringGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge); + +// APIs for String Gauge with 1 label. +typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1( + TFE_MonitoringStringGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge, + const char* label1); + +// APIs for String Gauge with 2 label. +typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2( + TFE_MonitoringStringGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge, + const char* label1, const char* label2); + +// APIs for String Gauge with 3 labels. +typedef struct TFE_MonitoringStringGauge3 TFE_MonitoringStringGauge3; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge3* TFE_MonitoringNewStringGauge3( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2, const char* label3); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge3( + TFE_MonitoringStringGauge3* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge3(TFE_MonitoringStringGauge3* gauge, + const char* label1, const char* label2, + const char* label3); + +// APIs for String Gauge with 4 labels. 
+typedef struct TFE_MonitoringStringGauge4 TFE_MonitoringStringGauge4; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge4* TFE_MonitoringNewStringGauge4( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2, const char* label3, + const char* label4); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge4( + TFE_MonitoringStringGauge4* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge4(TFE_MonitoringStringGauge4* gauge, + const char* label1, const char* label2, + const char* label3, const char* label4); + +typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell; +TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet( + TFE_MonitoringBoolGaugeCell* cell, bool value); +TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue( + TFE_MonitoringBoolGaugeCell* cell); + +// APIs for Bool Gauge without label. +typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0( + TFE_MonitoringBoolGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge); + +// APIs for Bool Gauge with 1 label. +typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1( + TFE_MonitoringBoolGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge, + const char* label1); + +// APIs for Bool Gauge with 2 label. +typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2( + TFE_MonitoringBoolGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge, + const char* label1, const char* label2); + +// ----------------------------------------------------------------------------- +// Monitoring Sampler APIs. +// These APIs de-templated monitoring Sampler for swig. + +typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell; + +// Atomically add the value of the cell. +TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd( + TFE_MonitoringSamplerCell* cell, double value); + +// Retrieves the current value of the cell. The return value is a HistogramProto +// saved in the buffer. +TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue( + TFE_MonitoringSamplerCell* cell, TF_Buffer* buf); + +// APIs for sampler buckets +typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets; +TF_CAPI_EXPORT extern TFE_MonitoringBuckets* +TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor, + int bucket_count); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets( + TFE_MonitoringBuckets* buckets); + +// APIs for Sampler without label. 
+typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0; +TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0( + TFE_MonitoringSampler0* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0( + TFE_MonitoringSampler0* sampler); + +// APIs for Sampler with 1 label. +typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1; +TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description, const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1( + TFE_MonitoringSampler1* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1( + TFE_MonitoringSampler1* sampler, const char* label1); + +// APIs for Sampler with 2 label. +typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2; +TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description, const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2( + TFE_MonitoringSampler2* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( + TFE_MonitoringSampler2* sampler, const char* label1, const char* label2); + +// Sets whether to use TFRT +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, + bool use_tfrt); + +// Returns the context_id from the EagerContext which is used by the +// EagerService to maintain consistency between client and worker. The +// context_id is initialized with a dummy value and is later set when the worker +// is initialized (either locally or remotely). The context_id can change during +// the process lifetime although this should cause the worker to be +// reinitialized (e.g. cleared caches) as well. +TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx); + +// ----------------------------------------------------------------------------- +// Cancellation APIs. 
+ +typedef struct TFE_CancellationManager TFE_CancellationManager; +typedef int64_t TFE_CancellationToken; +typedef struct TFE_CancelCallback { + void (*callback)(void* context); + void* context; +} TFE_CancelCallback; +TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager(); +TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelling( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern TFE_CancellationToken TFE_CancellationManagerGetToken( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern bool TFE_CancellationManagerRegisterCallback( + TFE_CancellationManager*, TFE_CancellationToken token, + const TFE_CancelCallback* c_callback, const char* callback_name); +TF_CAPI_EXPORT extern bool TFE_CancellationManagerDeregisterCallback( + TFE_CancellationManager*, TFE_CancellationToken token); +TF_CAPI_EXPORT extern bool TFE_CancellationManagerTryDeregisterCallback( + TFE_CancellationManager*, TFE_CancellationToken token); +TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager( + TFE_CancellationManager*); + +// Associates the given `cancellation_manager` with `op`, so that invoking +// `TFE_CancellationManagerStartCancel(cancellation_manager)` will cancel the +// execution of `op`. +typedef struct TFE_CancellationManager TFE_CancellationManager; +TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager( + TFE_Op* op, TFE_CancellationManager* cancellation_manager, + TF_Status* status); + +// ----------------------------------------------------------------------------- +// Eager Executor APIs. +typedef struct TFE_Executor TFE_Executor; + +// Creates a new eager Executor. Nodes in one executor are guaranteed to be +// executed in sequence. Assigning nodes to different executors allows executing +// nodes in parallel. +// in_flight_nodes_limit: when is_async is true, this value controls the +// maximum number of in flight async nodes. Enqueuing of additional async ops +// after the limit is reached blocks until some inflight nodes finishes. +// The effect is bounding the memory held by inflight TensorHandles that are +// referenced by the inflight nodes. +// A recommended value has not been established. +// A value of 0 removes the limit, which is the behavior of TensorFlow 2.11. +// When is_async is false, the value is ignored. +TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor( + bool is_async, bool enable_streaming_enqueue, int in_flight_nodes_limit); + +// Deletes the eager Executor without waiting for enqueued nodes. Please call +// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to +// make sure all nodes are finished. +TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*); + +// Returns true if the executor is in async mode. +TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*); + +// Causes the calling thread to block till all ops dispatched in this executor +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. 
+TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes( + TFE_Executor*, TF_Status* status); + +// When an error happens, any pending operations are discarded, and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*); + +// Sets a custom Executor for the current thread. All nodes created by this +// thread will be added to this Executor. It will override the current executor. +TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*, + TFE_Executor*); + +// Returns the Executor for the current thread. +TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread( + TFE_Context*); + +// ----------------------------------------------------------------------------- +// Dynamic cluster API. + +// Update an existing context with a new set of servers defined in a ServerDef +// proto. Servers can be added to and removed from the list of remote workers +// in the context. A New set of servers identified by the ServerDef must be up +// when the context is updated. +// +// This API is for experimental usage and may be subject to change. +TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + +// This API is for experimental usage and may be subject to change. +TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDefWithTimeout( + TFE_Context* ctx, int keep_alive_secs, const void* proto, size_t proto_len, + int64_t init_timeout_in_ms, TF_Status* status); + +// This API is for experimental usage and may be subject to change. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeout( + TFE_Context* ctx, int keep_alive_secs, const void* proto, size_t proto_len, + int64_t init_timeout_in_ms, TF_Status* status, + bool clear_existing_contexts); + +// Set server def with retries and timeout. This is helpful for fault-tolerant +// initial connection in high-preemption environments, such as +// ParameterServerStrategy training. +// This API is for experimental usage and may be subject to change. +TF_CAPI_EXPORT extern void TFE_ContextSetServerDefWithTimeoutAndRetries( + TFE_Context* ctx, int keep_alive_secs, const void* proto, size_t proto_len, + int64_t init_timeout_in_ms, int retries, TF_Status* status, + bool clear_existing_contexts); + +// Checks whether a remote worker is alive or not. This will return true even if +// the context doesn't exist on the remote worker. +TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, + const char* worker_name, + TF_Status* status); + +// Sync pending nodes in local executors (including the context default executor +// and thread executors) and streaming requests to remote executors, and get the +// combined status. +TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, + TF_Status* status); + +// This function will block till the operation that produces `h` has +// completed. This is only valid on local TFE_TensorHandles. The pointer +// returned will be on the device in which the TFE_TensorHandle resides (so e.g. +// for a GPU tensor this will return a pointer to GPU memory). The pointer is +// only guaranteed to be valid until TFE_DeleteTensorHandle is called on this +// TensorHandle. 
Only supports POD data types. +TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*, + TF_Status*); + +// This function will block till the operation that produces `h` has +// completed. This is only valid on local TFE_TensorHandles. Returns the size in +// bytes of the memory pointed to by the device pointer returned above. +TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*, + TF_Status*); + +// Creates a new TensorHandle from memory residing in the physical device +// device_name. Takes ownership of the memory, and will call deleter to release +// it after TF no longer needs it or in case of error. +// +// Custom devices must use TFE_NewCustomDeviceTensorHandle instead. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( + TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims, + int num_dims, void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg, TF_Status* status); + +// Retrieves the address space (i.e. job, replia, task) of the local host and +// saves it in the buffer. +TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx, + TF_Buffer* buf); + +// APIs for generically dealing with op attributes (e.g. when forwarding them +// through custom device implementations). +// +// TODO(allenl): Currently these are black boxes, but we should have some way to +// inspect values. This would let people e.g. copy over most attributes and then +// modify some based on their values. + +// A reference to an op's name -> attribute mapping +typedef struct TFE_OpAttrs TFE_OpAttrs; + +// Fetch a reference to `op`'s attributes. The returned reference is only valid +// while `op` is alive. +TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op); +// Add attributes in `attrs` to `op`. +// +// Does not overwrite or update existing attributes, but adds new ones. +TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs); + +// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`), +// containing the op name and a map of its attributes. +TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, + TF_Buffer* buf, + TF_Status* status); + +// Set an op's attribute from a serialized AttrValue protocol buffer. +// +// Analogous to TF_SetAttrValueProto for building graph operations. +TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// TODO(b/166642410): It would be nice, for custom devices and for other users, +// to have a non-string representation of devices (TF_Device) extracted from +// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. + +#define TFE_CUSTOM_DEVICE_VERSION 4 + +// Struct to be filled in. Functions are required except where indicated. +typedef struct TFE_CustomDevice { + int version = TFE_CUSTOM_DEVICE_VERSION; + // Method to copy a tensor to the custom device. + TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context, + TFE_TensorHandle* tensor, + TF_Status* status, + void* device_info); + + // Method to copy a tensor from the custom device to a target device. + TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context, + TFE_TensorHandle* tensor, + const char* target_device_name, + TF_Status* status, + void* device_info); + + // Method to execute an operation. 
+ // + // Arguments provide enough information to reconstruct the original `TFE_Op`, + // or construct a transformed version, by inspecting the passed `op`. + // + // TFE_OpGetDevice(op) records the original placement of the operation. It may + // be an empty string if no device was explicitly requested, but will + // otherwise be the name of this custom device. Ops are placed onto a custom + // device if any of their inputs are on that custom device, but custom devices + // are free to set a bad status in order to require explicit placement. + void (*execute)(const TFE_Op* op, int* num_outputs, + TFE_TensorHandle** outputs, TF_Status* s, void* device_info); + + // Method to delete a device. + void (*delete_device)(void* device_info); + + // Implements TFE_CreatePackedTensorHandle when one of `handles` is on this + // custom device. + // + // Many devices will want to simply return an "unimplemented" status + // here. This is the default behavior if `pack` is null when passed to + // TFE_RegisterCustomDevice. + TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles, + int num_handles, TF_Status* s, + void* device_info) = nullptr; + + // Pins the op to `device` based on inputs to `op`. Returns true + // signifying to pin to the current custom device. Returns false + // to pin to the physical device. + // + // This function is guaranteed to be called only when all of the custom-device + // inputs are on this device. + bool (*shall_pin_to_this_device)(const TFE_Op* op, TF_Status* s) = nullptr; +} TFE_CustomDevice; + +// Registers a custom device for use with eager execution. +// +// Eager operations may be placed on this device, e.g. `with +// tf.device("CUSTOM"):` from Python if `device_name` for this call is +// "/job:localhost/replica:0/task:0/device:CUSTOM:0". +// +// The custom device defines copy operations for moving TensorHandles on and +// off, and an execution operation for named operations. Often execution will +// simply wrap op execution on one or more physical devices. +// +// device_info is an opaque caller-defined type stored with the custom device +// which is passed to the functions referenced in the TFE_CustomDevice struct +// `device` (execute, delete_device, etc.). It can for example contain the +// names of wrapped devices. +// +// There are currently no graph semantics implemented for registered custom +// devices, so executing tf.functions which contain operations placed on the +// custom devices will fail. +// +// `device_name` must not name an existing physical or custom device. It must +// follow the format: +// +// /job:/replica:/task:/device:: +// +// If the device is successfully registered, `status` is set to TF_OK. Otherwise +// the device is not usable. In case of a bad status, `device.delete_device` is +// still called on `device_info` (i.e. the caller does not retain ownership). +// +// This API is highly experimental, and in particular is expected to change when +// it starts supporting operations with attributes and when tf.function support +// is added. +TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx, + TFE_CustomDevice device, + const char* device_name, + void* device_info, + TF_Status* status); + +// Returns whether `device_name` maps to a registered custom device. +TF_CAPI_EXPORT extern bool TFE_IsCustomDevice(TFE_Context* ctx, + const char* device_name); + +// Struct to be filled in to define a custom device tensor handle. Fields are +// required except where indicated. 
+typedef struct TFE_CustomDeviceTensorHandleMethods { + int version = TFE_CUSTOM_DEVICE_VERSION; + + // Computes the rank of the tensor handle. + // + // Shapes are specified via callbacks because retrieving the shape of a tensor + // is a blocking operation for async eager; custom devices should avoid + // retrieving shapes of tensors they wrap until the custom device tensor's + // shape is explicitly requested where possible. + int (*num_dims)(void* data, TF_Status* status); + + // Computes the axis length at `dim_index`. + int64_t (*dim)(void* data, int dim_index, TF_Status* status); + + void (*deallocator)(void* data); + + // Summarizes the value of this tensor. The caller takes ownership of the + // returned buffer. If `status` is not TF_OK, instead returns a null pointer. + // + // Does not include the shape and dtype of the tensor (which is generally + // appended later), but should include any information specific to this custom + // device which would be useful for debugging. + // + // Optional. If null, defaults to resolving the TFE_TensorHandle into a + // TF_Tensor and summarizing that. + TF_Buffer* (*summarize)(void* data, TF_Status* status) = nullptr; +} TFE_CustomDeviceTensorHandle; + +// Creates a new TensorHandle from memory residing in a custom device. Takes +// ownership of the memory pointed to by `tensor_handle_data`, and calls +// `methods.deallocator` to release it after TF no longer needs it or in case of +// an error. +// +// This call is similar to `TFE_NewTensorHandleFromDeviceMemory`, but supports +// custom devices instead of physical devices and does not require blocking +// waiting for exact shapes. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( + TFE_Context*, const char* device_name, TF_DataType, void* data, + TFE_CustomDeviceTensorHandle methods, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, + const char* function_name, + TF_Buffer* buf, + TF_Status* status); + +// Get GraphDebugInfo containing stack traces mapping to node names +TF_CAPI_EXPORT extern void TFE_ContextGetGraphDebugInfo( + TFE_Context* ctx, const char* function_name, TF_Buffer* buf, + TF_Status* status); + +// Extracts a TF_Function from the context. +// Must call TF_DeleteFunction on the returned value. +TF_CAPI_EXPORT extern TF_Function* TFE_ContextGetFunction(TFE_Context* ctx, + const char* name, + TF_Status* status); + +// Allocate and return a new Tensor on the host. +// +// The caller must set the Tensor values by writing them to the pointer returned +// by TF_TensorData with length TF_TensorByteSize. +TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, + TF_DataType dtype, + const int64_t* dims, + int num_dims, + TF_Status* status); + +// Given a Tensor, wrap it with a TensorHandle +// +// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context. +// The context should be identical to that of the Tensor. +TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor( + TFE_Context* ctx, TF_Tensor* t, TF_Status* status); + +// Create a packed TensorHandle with the given list of TensorHandles. +// If `handles` are on the same device, assign the same device to the packed +// handle; if `handles` are on different deivces, assign a CompositeDevice to +// it. 
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle( + TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, + TF_Status* status); + +// Configure soft device placement policy for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Configure device placement policy logging for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Enables running eager ops as function. +TF_CAPI_EXPORT void TFE_ContextSetRunEagerOpAsFunction(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Enables rewrite jit_compile functions. +TF_CAPI_EXPORT void TFE_ContextSetJitCompileRewrite(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Returns the device type of the operation that produced `h`. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( + TFE_TensorHandle* h, TF_Status* status); + +// Returns the device ID of the operation that produced `h`. +TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, + TF_Status* status); + +// Returns the status for the tensor handle. In TFRT, a tensor handle can carry +// error info if error happens. If so, the status will be set with the error +// info. If not, status will be set as OK. +TF_CAPI_EXPORT extern void TFE_TensorHandleGetStatus(TFE_TensorHandle* h, + TF_Status* status); + +// Get a comma-separated list of op names executed in graph functions dispatched +// to `ctx`. This feature is currently only enabled for TFRT debug builds, for +// performance and simplicity reasons. +TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx, + TF_Buffer* buf, + TF_Status* status); + +// Set logical devices to the context's device manager. +// If logical devices are already configured at context initialization +// through TFE_ContextOptions, this method should not be called. +TF_CAPI_EXPORT extern void TFE_SetLogicalCpuDevices(TFE_Context* ctx, + int num_cpus, + const char* prefix, + TF_Status* status); + +// Set configuration key and value using coordination service. +// If coordination service is enabled, the key-value will be stored on the +// leader and become accessible to all workers in the cluster. +// Currently, a config key can only be set with one value, and subsequently +// setting the same key will lead to errors. +// +// Note that the key-values are only expected to be used for cluster +// configuration data, and should not be used for storing a large amount of data +// or being accessed very frequently. +TF_CAPI_EXPORT extern void TFE_InsertConfigKeyValue(TFE_Context* ctx, + const char* key, + const char* value, + TF_Status* status); + +// Get configuration key and value using coordination service. +// The config key must be set before getting its value. Getting value of +// non-existing config keys will result in errors. +// If `timeout_in_ms=0`, this call will block until the key-value is set or the +// worker shuts down. +TF_CAPI_EXPORT extern void TFE_GetConfigKeyValue(TFE_Context* ctx, + const char* key, + int64_t timeout_in_ms, + TF_Buffer* value_buf, + TF_Status* status); + +// Delete configuration key-value. If `key` is a directory, recursively clean up +// all key-values under the path specified by `key`. 
+TF_CAPI_EXPORT extern void TFE_DeleteConfigKeyValue(TFE_Context* ctx, + const char* key, + TF_Status* status); + +// Report error (specified by error_code and error_message) to other tasks in +// the cluster. +TF_CAPI_EXPORT extern void TFE_ReportErrorToCluster(TFE_Context* ctx, + int error_code, + const char* error_message, + TF_Status* status); + +// Get task states from the Coordination Service. +TF_CAPI_EXPORT extern void TFE_GetTaskStates(TFE_Context* ctx, + const TF_Buffer& tasks, + void* states, TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_WaitAtBarrier(TFE_Context* ctx, + const char* barrier_id, + int64_t barrier_timeout_in_ms, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_InitializeLocalOnlyContext(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api_experimental_reader.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_experimental_reader.h new file mode 100644 index 00000000..71c2e465 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_experimental_reader.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License");; +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_READER_H_ +#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_READER_H_ + +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Test only exports of the monitoring Cell Reader API which allows tests to +// read current values from streamz counters defined in other modules. +// +// The code under test will have created streamz counters like this: +// auto* streamz = tensorflow::monitoring::Counter<1>::New("name", +// "description", "label"); +// and then incremented that counter for various values of label: +// streamz->GetCell("label-value")->IncrementBy(1); +// +// The test code can then read and test the value of that counter: +// +// auto* reader = TFE_MonitoringNewCounterReader("name"); +// test(); +// int64_t value = TFE_MonitoringReadCounter1(reader, "label-value"); + +// Opaque handle to a reader. +typedef struct TFE_MonitoringCounterReader TFE_MonitoringCounterReader; + +// Returns a handle to be used for reading values from streamz counter. The +// counter can have been created with any number of labels. +TF_CAPI_EXPORT extern TFE_MonitoringCounterReader* +TFE_MonitoringNewCounterReader(const char* name); + +// Reads the value of a counter that was created with 0 labels. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringReadCounter0( + TFE_MonitoringCounterReader*); + +// Reads the value of specific cell of a counter that was created with 1 label. 
+TF_CAPI_EXPORT extern int64_t TFE_MonitoringReadCounter1( + TFE_MonitoringCounterReader*, const char* label_value); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_READER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_internal.h new file mode 100644 index 00000000..eff96826 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_internal.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/tfe_cancellation_manager_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_executor_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_monitoring_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_op_attrs_internal.h" // IWYU pragma: export +#include "tensorflow/c/eager/tfe_tensor_debug_info_internal.h" // IWYU pragma: export + +// TODO(b/154564140): Move this to its own header. This requires splitting +// c_api_experimental.h +struct TFE_ContextOptions { + TF_SessionOptions session_options; + // true if async execution is enabled. + bool async = false; + TFE_ContextDevicePlacementPolicy device_placement_policy{ + TFE_DEVICE_PLACEMENT_SILENT}; + // If true, use TFRT backend + bool use_tfrt = false; + // Whether to run elementary eager ops wrapped in a call op. + bool run_eager_op_as_function = false; + // Whether to rewrite jit_compile functions. + bool jit_compile_rewrite = false; +}; + +#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api_remote_test_util.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_remote_test_util.h new file mode 100644 index 00000000..6d9edb65 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_remote_test_util.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ + +// Run a function containing a MatMul op and check its output. +// If heavy_load_on_streaming_rpc is true, send some rpc requests before the one +// which creates a remote input, to simulate a scenario that the remote input +// is not ready when we start running an op or a function. +void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, + bool heavy_load_on_streaming_rpc, + bool remote_func_outputs = false, + bool has_packed_input = false); + +#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api_test_util.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_test_util.h new file mode 100644 index 00000000..ff5b0736 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_test_util.h @@ -0,0 +1,174 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +// Return a tensor handle containing a float scalar +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, float value); + +// Return a tensor handle containing a int scalar +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, int value); + +// Return a tensor handle containing a bool scalar +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, bool value); + +// Return a tensor handle containing a tstring scalar +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, + const tensorflow::tstring& value); + +// Return a tensor handle containing a 2x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx); + +// Return a tensor handle containing a 2x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx); + +// Return a tensor handle containing 2D matrix containing given data and +// dimensions +TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx, + float data[], int64_t dims[], + int num_dims); + +// Get a Matrix TensorHandle with given float values and dimensions +TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[], + int64_t dims[], int num_dims); + +// Get a Matrix TensorHandle with given int values and dimensions +TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[], + int64_t dims[], int num_dims); + +// Return a tensor handle with given type, values and dimensions. 
+template <typename T, TF_DataType datatype> +TFE_TensorHandle* TestTensorHandleWithDims(TFE_Context* ctx, const T* data, + const int64_t* dims, int num_dims) { + TF_Status* status = TF_NewStatus(); + TF_Tensor* t = TFE_AllocateHostTensor(ctx, datatype, dims, num_dims, status); + memcpy(TF_TensorData(t), data, TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +// Return a scalar tensor handle with the given value. +template <typename T, TF_DataType datatype> +TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, const T value) { + T data[] = {value}; + return TestTensorHandleWithDims<T, datatype>(ctx, data, nullptr, 0); +} + +// Return a tensor handle containing a 100x100 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx); + +// Return a tensor handle containing a 3x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(TFE_Context* ctx); + +// Return a tensor handle containing a 3x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle3X2(TFE_Context* ctx); + +// Return a variable handle referring to a variable with the given initial value +// on the given device. +TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name = ""); + +// Return an add op adding `a` and `b`. +TFE_Op* AddOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); + +// Return a matmul op multiplying `a` by `b`. +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); + +// Return an identity op. +TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a); + +// Return a shape op fetching the shape of `a`. +TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a); + +// Return an allreduce op adding up input tensor `in` from `group_size` workers. +TFE_Op* AllReduceOp(TFE_Context* ctx, TFE_TensorHandle* in, int group_size); + +// Return a SendOp op `op_name` with send input tensor `in` and attributes +// `send_device`, `recv_device`, and `send_device_incarnation` set. +TFE_Op* SendOp(TFE_Context* ctx, TFE_TensorHandle* in, + const std::string& op_name, const std::string& send_device, + const std::string& recv_device, + tensorflow::uint64 send_device_incarnation); + +// Return a RecvOp op `op_name` with the attributes `send_device`, +// `recv_device`, and `send_device_incarnation` set. +TFE_Op* RecvOp(TFE_Context* ctx, const std::string& op_name, + const std::string& send_device, const std::string& recv_device, + tensorflow::uint64 send_device_incarnation); + +// Return a 1-D INT32 tensor containing a single value 1. +TFE_TensorHandle* TestAxisTensorHandle(TFE_Context* ctx); + +// Return an op taking the minimum of `input` along the `axis` dimension. +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis); + +// If there is a device of type `device_type`, returns true +// and sets 'device_name' accordingly. +// `device_type` must be either "GPU" or "TPU". +bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, + const char* device_type); + +// Create a ServerDef with the given `job_name` and add `num_tasks` tasks in it. +tensorflow::ServerDef GetServerDef(const tensorflow::string& job_name, + int num_tasks); + +// Create a ServerDef with job name "localhost" and add `num_tasks` tasks in it.
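// A sketch of combining these helpers in a test: build two 2x2 float handles
// with TestTensorHandleWithDims, wrap them in MatMulOp, and execute eagerly.
// `ctx` is assumed to be a live TFE_Context; the <float, TF_FLOAT> arguments
// correspond to the template parameters of the helper above.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"

void ExampleMatMul(TFE_Context* ctx) {
  TF_Status* status = TF_NewStatus();
  float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  TFE_TensorHandle* a =
      TestTensorHandleWithDims<float, TF_FLOAT>(ctx, data, dims, 2);
  TFE_TensorHandle* b =
      TestTensorHandleWithDims<float, TF_FLOAT>(ctx, data, dims, 2);

  TFE_Op* matmul = MatMulOp(ctx, a, b);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(a);
  TFE_DeleteTensorHandle(b);
  TFE_DeleteTensorHandle(retvals[0]);
  TF_DeleteStatus(status);
}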
+tensorflow::ServerDef GetServerDef(int num_tasks); + +// Create a multi-client ServerDef with the given `job_name`, add `num_tasks` +// tasks and `num_virtual_gpus` virtual GPUs in it. +tensorflow::ServerDef GetMultiClientServerDef(const std::string& job_name, + int num_tasks, + int num_virtual_gpus = 0); + +// Create a variable handle with name `variable_name` on a device with name +// `device_name`. +TFE_TensorHandle* CreateVarHandle(TFE_Context* ctx, + const tensorflow::string& device_name, + const tensorflow::string& variable_name); + +// Create a variable with value `value` and name `variable_name` on a device +// with name `device_name`. +TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value, + const tensorflow::string& device_name, + const tensorflow::string& variable_name); + +TFE_Context* CreateContext(const std::string& serialized_server_def, + bool isolate_session_state, + int64_t init_timeout_in_ms); + +tensorflow::ServerDef ReplaceTaskInServerDef( + const tensorflow::ServerDef& server_def, int task_index); + +void ReplaceTaskInServerDef(tensorflow::ServerDef* server_def, int task_index, + const std::string& host, int port); + +std::vector ListDeviceNames(TFE_Context* ctx); + +#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api_unified_experimental.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_unified_experimental.h new file mode 100644 index 00000000..41228f07 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_unified_experimental.h @@ -0,0 +1,153 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ +#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ============================================================================= +// Unified Execution APIs for Eager and tracing backends. +// ============================================================================= + +// ----------------------------------------------------------------------------- +// Core APIs +// ----------------------------------------------------------------------------- + +// A TF_ExecutionContext stores knowledge about how to execute an operation. +// E.g. it could know whether we're in eager mode or graph mode, keeps track +// of gradient tapes, etc. +typedef struct TF_ExecutionContext TF_ExecutionContext; + +// A TF_AbstractTensor is an input to an operation. E.g. it could be a union +// type of eager and graph tensors. It is also the result of executing an +// operation. +typedef struct TF_AbstractTensor TF_AbstractTensor; + +// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this +// could contain the op type and other attributes. 
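// A sketch of driving these types in eager mode, using functions declared
// later in this header: wrap two eager handles, execute AddV2, and unwrap the
// result. Error handling and most cleanup are elided for brevity; `a` and `b`
// are assumed to be live TFE_TensorHandles of matching dtype.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"

TFE_TensorHandle* ExampleUnifiedAdd(TFE_ContextOptions* opts,
                                    TFE_TensorHandle* a, TFE_TensorHandle* b,
                                    TF_Status* s) {
  TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, s);
  TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor(a, s);
  TF_AbstractTensor* bt = TF_CreateAbstractTensorFromEagerTensor(b, s);

  TF_AbstractOp* add = TF_NewAbstractOp(ctx);
  TF_AbstractOpSetOpType(add, "AddV2", s);

  TF_OutputList* outputs = TF_NewOutputList();
  TF_OutputListSetNumOutputs(outputs, 1, s);
  TF_AbstractTensor* inputs[] = {at, bt};
  TF_ExecuteOperation(add, 2, inputs, outputs, s);

  TF_AbstractTensor* result = TF_OutputListGet(outputs, 0);
  TFE_TensorHandle* out = TF_AbstractTensorGetEagerTensor(result, s);

  // Cleanup of the abstract tensors, output list, and execution context is
  // elided here; the upstream unit tests show the full lifetime handling.
  TF_DeleteAbstractOp(add);
  return out;
}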
+typedef struct TF_AbstractOp TF_AbstractOp; + +// Stores a function representation that can be used for execution or for +// setting functional attributes of other composite ops e.g. control flow. +typedef struct TF_AbstractFunction TF_AbstractFunction; + +// This allows the client to swap the implementation of the tracing engine. +// Any future call to TF_CreateFunction will use the implementation defined +// here. +void TF_SetTracingImplementation(const char* name, TF_Status*); + +// Creates a new TensorFlow function. A Function is an execution context, and as +// such it can trace operations through TF_ExecuteOperation. After completing +// tracing, a function can be obtained by TF_FinalizeFunction. +TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status); + +// Creates a context for eager execution of operations. +TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*, + TF_Status* s); +void TF_DeleteExecutionContext(TF_ExecutionContext*); + +// Represents a (partially-defined) shape. +typedef struct TF_Shape { + int num_dims; // Must be >= -1; -1 represents unknown rank. + int64_t* dim_sizes; +} TF_Shape; + +// Add a new parameter to a TensorFlow Function. +TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, + TF_DataType dtype, TF_Shape shape, + TF_Status* s); + +// Create an operation suitable to use with the provided context. The operation +// requires its type (e.g. "AddV2") to be set independently. +TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx); +void TF_DeleteAbstractOp(TF_AbstractOp*); + +// TODO(srbs): Add APIs for specifying attrs etc. +// `op_type` must outlive `op`. +void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, + TF_Status* s); +// `op_name` must outlive `op`. +void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, + TF_Status* s); +// `attr_name` must outlive `op`. +void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, + TF_DataType value, TF_Status* s); + +void TF_DeleteAbstractTensor(TF_AbstractTensor*); + +// TF_OutputList holds the list of TF_AbstractTensor that results from executing +// an operation, or provided to create a function. +// When executing an operation in an eager context, the expected number of +// outputs must be set beforehand with `TF_OutputListSetNumOutputs`. +typedef struct TF_OutputList TF_OutputList; +TF_OutputList* TF_NewOutputList(); +void TF_DeleteOutputList(TF_OutputList* o); +// Prepare tracing to the expected number of output for an operation. +void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status*); +// Return the number of outputs in the list. +int TF_OutputListNumOutputs(TF_OutputList* o); +// Return the `i`th output in the list. +TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i); +// Append a tensor at the end of the output list, growing its size by one. +void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, + TF_Status*); + +// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe +// capture some inputs and then add a node in the graph. The output tensors are +// returned through the provided TF_OutputList. +// Any active tape will observe the effects of this execution. +void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, + TF_AbstractTensor* const* inputs, TF_OutputList* o, + TF_Status* s); + +// Creates a new TF_AbstractFunction from the current tracing states in the +// context. 
The provided `ctx` is consumed by this API call and deleted. +// The returned TF_AbstractFunction must be deleted by the client, +// TODO(aminim): clarify the contract on the state of the context after this +// call. +TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, + TF_OutputList*, TF_Status*); + +void TF_DeleteAbstractFunction(TF_AbstractFunction*); + +// Register the function with the given context. This is particularly useful for +// making a function available to an eager context. +void TF_ExecutionContextRegisterFunction(TF_ExecutionContext*, + TF_AbstractFunction*, TF_Status*); + +// ----------------------------------------------------------------------------- +// APIs specific to Eager modes +// ----------------------------------------------------------------------------- + +// Temporary APIs till we figure out how to create scalar valued Eager +// tensors and how to get value out of eager abstract tensors. +TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t, + TF_Status* s); +TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, + TF_Status* s); +TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*, + TF_Status* s); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/c_api_unified_experimental_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_unified_experimental_internal.h new file mode 100644 index 00000000..872b9081 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -0,0 +1,138 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the results of the execution of an operation. +struct OutputList { + std::vector outputs; + int expected_num_outputs = -1; +}; + +namespace tracing { + +// ============================================================================= +// Implementation detail for the unified execution APIs for Eager and tracing +// backends (graph/MLIR). 
+// +// This defines a set of abstract classes that are intended to provide the +// functionality of the opaque C types exposed in the public APIs defined in the +// `c_api_unified_experimental.h` header. +// ============================================================================= + +// Represents either a MlirTensor or a GraphTensor. +// This base class does not expose any public methods other than to distinguish +// which subclass it actually is. The user is responsible to use the right +// type of AbstractTensor in their context (do not pass an MlirTensor to a +// GraphContext and vice-versa). +class TracingTensorHandle : public AbstractTensorHandle { + protected: + explicit TracingTensorHandle(AbstractTensorHandleKind kind) + : AbstractTensorHandle(kind) {} + + public: + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kGraph || ptr->getKind() == kMlir; + } +}; + +// An abstract operation describes an operation by its type, name, and +// attributes. It can be "executed" by the context with some input tensors. +// It is allowed to reusing the same abstract operation for multiple execution +// on a given context, with the same or different input tensors. +class TracingOperation : public AbstractOperation { + protected: + explicit TracingOperation(AbstractOperationKind kind) + : AbstractOperation(kind) {} + + public: + // Sets the name of the operation: this is an optional identifier that is + // not intended to carry semantics and preserved/propagated without + // guarantees. + virtual absl::Status SetOpName(const char* op_name) = 0; + + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kGraph || ptr->getKind() == kMlir; + } +}; + +namespace internal { +struct TracingOperationDeleter { + void operator()(TracingOperation* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using TracingOperationPtr = + std::unique_ptr; + +// This holds the context for the execution: dispatching operations either to an +// MLIR implementation or to a graph implementation. +class TracingContext : public AbstractContext { + protected: + explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {} + + public: + // Add a function parameter and return the corresponding tensor. + virtual absl::Status AddParameter(DataType dtype, + const PartialTensorShape& shape, + TracingTensorHandle**) = 0; + + // Finalize this context and make a function out of it. The context is in a + // invalid state after this call and must be destroyed. + virtual absl::Status Finalize(OutputList* outputs, AbstractFunction**) = 0; + + // For LLVM style RTTI. 
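// A sketch of using the LLVM-style `classof` hooks above to narrow an
// AbstractOperation to a TracingOperation before calling a tracing-only
// method such as SetOpName; eager operations are simply passed through.
#include "absl/status/status.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"

absl::Status MaybeSetOpName(tensorflow::AbstractOperation* op,
                            const char* name) {
  using tensorflow::tracing::TracingOperation;
  if (!TracingOperation::classof(op)) {
    return absl::OkStatus();  // Not tracing: nothing to name.
  }
  return static_cast<TracingOperation*>(op)->SetOpName(name);
}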
+ static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kGraph || ptr->getKind() == kMlir; + } +}; + +typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); +absl::Status SetDefaultTracingEngine(const char* name); +void RegisterTracingEngineFactory(const ::tensorflow::string& name, + FactoryFunction factory); +} // namespace tracing + +DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext) +DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor) +DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction) +DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp) +DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList) +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/custom_device_testutil.h b/third_party/tflite-hdrs/tensorflow/c/eager/custom_device_testutil.h new file mode 100644 index 00000000..a7c60080 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/custom_device_testutil.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_ +#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_ + +// A simple logging device to test custom device registration. +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/tf_status.h" + +void RegisterLoggingDevice(TFE_Context* context, const char* name, + bool strict_scope_placement, bool* arrived_flag, + bool* executed_flag, TF_Status* status); +void AllocateLoggingDevice(const char* name, bool* arrived_flag, + bool* executed_flag, TFE_CustomDevice** device, + void** device_info); +TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle, + TF_Status* status); + +#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/dlpack.h b/third_party/tflite-hdrs/tensorflow/c/eager/dlpack.h new file mode 100644 index 00000000..8c85dee6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/dlpack.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_DLPACK_H_ +#define TENSORFLOW_C_EAGER_DLPACK_H_ + +#include "tensorflow/c/eager/c_api.h" + +namespace tensorflow { + +// PyCapsule name for DLPack Tensor +const char* const kDlTensorCapsuleName = "dltensor"; + +// Converts eager tensor handle to DLPack (DLManagedTensor*), and return the +// void* for further PyCapsule construction. +TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, + TF_Status* status); + +// Converts DLPack (DLManagedTensor*) to eager tensor handle. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, + TF_Status* status, + TFE_Context* ctx); + +// Calls the destructor of DLManagedTensor, used in the destructor of PyCapsule. +TF_CAPI_EXPORT extern void TFE_CallDLManagedTensorDeleter(void* dlm_ptr); +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_DLPACK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/gradient_checker.h b/third_party/tflite-hdrs/tensorflow/c/eager/gradient_checker.h new file mode 100644 index 00000000..d64ad448 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/gradient_checker.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_ +#define TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/unified_api_testutil.h" + +namespace tensorflow { +namespace gradients { + +/* Returns numerical grad inside `dtheta_approx` given `forward` model and + * parameter specified by `input_index`. + * + * I.e. if y = and w = inputs[input_index], + * this will calculate dy/dw numerically. + * + * `use_function` indicates whether to use graph mode(true) or eager(false). + * + * `numerical_grad` is the pointer to the AbstractTensorHandle* which will + * hold the numerical gradient data at the end of the function. + */ +absl::Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + absl::Span inputs, + int input_index, bool use_function, + AbstractTensorHandle** numerical_grad); + +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/gradients.h b/third_party/tflite-hdrs/tensorflow/c/eager/gradients.h new file mode 100644 index 00000000..88c1df24 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/gradients.h @@ -0,0 +1,178 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_ +#define TENSORFLOW_C_EAGER_GRADIENTS_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/tape.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" + +namespace tensorflow { +namespace gradients { + +// =============== Experimental C++ API for computing gradients =============== + +// Sample gradient function: +// +// class AddGradientFunction : public GradientFunction { +// public: +// Status Compute(Context* ctx, +// absl::Span grad_inputs, +// absl::Span grad_outputs) override { +// grad_outputs[0] = grad_inputs[0]; +// grad_outputs[1] = grad_inputs[0]; +// grad_outputs[0]->Ref(); +// grad_outputs[1]->Ref(); +// return OkStatus(); +// } +// ~AddGradientFunction() override {} +// }; +// +// GradientFunction* AddRegisterer(const ForwardOperation& op) { +// // More complex gradient functions can use inputs/attrs etc. from the +// // forward `op`. +// return new AddGradientFunction; +// } +// +// Status RegisterGradients(GradientRegistry* registry) { +// return registry->Register("Add", AddRegisterer); +// } +class GradientFunction { + public: + virtual absl::Status Compute( + AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) = 0; + virtual ~GradientFunction() {} +}; + +// Metadata from the forward operation that is made available to the +// gradient registerer to instantiate a GradientFunction. +struct ForwardOperation { + public: + string op_name; + std::vector inputs; + std::vector outputs; + std::vector skip_input_indices; + AttrBuilder attrs; +}; + +using GradientFunctionFactory = + std::function; + +// Map from op name to a `GradientFunctionFactory`. +class GradientRegistry { + public: + absl::Status Register(const string& op, + GradientFunctionFactory gradient_function_factory); + absl::Status Lookup( + const ForwardOperation& op, + std::unique_ptr* gradient_function) const; + + private: + absl::flat_hash_map registry_; +}; + +// TODO(srbs): Figure out if we can avoid declaring this in the public header. +// Wrapper for a tensor output of an operation executing under a tape. +// +// `GetID` returns a unique id for the wrapped tensor which is used to maintain +// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of +// the op that produced it (or -1 if this tensor was watched using +// `GradientTape::Watch`.) The op_id is simply a unique index assigned to each +// op executed under the tape. A separate map (`tensorflow::eager::OpTape`) +// maintains the map from `op_id` to a `OpTapeEntry` which stores the `op_type`, +// inputs and outputs and the gradient function These data structures combined +// allow us to trace the data dependencies between operations and hence compute +// gradients. +// +// `ZerosLike` is not expected to be called and returns a nullptr. The creation +// of default zeros grads is handled by the `DefaultGradientFunction` registered +// for each op. 
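// A sketch of a gradient function written against the GradientFunction and
// GradientRegistry declarations above, using the absl::Status signature; the
// absl::Span element types (AbstractTensorHandle* const for incoming grads,
// AbstractTensorHandle* for outgoing grads) are assumed here, consistent with
// the sample in the comment block.
#include "absl/status/status.h"
#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {

// Gradient of Identity: pass the incoming gradient straight through.
class IdentityGradientFunction : public GradientFunction {
 public:
  absl::Status Compute(
      AbstractContext* ctx,
      absl::Span<AbstractTensorHandle* const> grad_outputs,
      absl::Span<AbstractTensorHandle*> grad_inputs) override {
    grad_inputs[0] = grad_outputs[0];
    grad_inputs[0]->Ref();
    return absl::OkStatus();
  }
};

GradientFunction* IdentityRegisterer(const ForwardOperation& op) {
  return new IdentityGradientFunction;
}

absl::Status RegisterIdentityGradient(GradientRegistry* registry) {
  return registry->Register("Identity", IdentityRegisterer);
}

}  // namespace gradients
}  // namespace tensorflow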
+// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy. +// Figure out a way to avoid this. +// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr? +class TapeTensor { + public: + explicit TapeTensor(AbstractTensorHandle* handle); + TapeTensor(const TapeTensor& other); + ~TapeTensor(); + + int64_t GetID() const; + tensorflow::DataType GetDType() const; + + AbstractTensorHandle* ZerosLike() const; + + AbstractTensorHandle* GetHandle() const; + + private: + AbstractTensorHandle* handle_; +}; + +// A tracing/immediate-execution agnostic tape. +// +// Gradient functions defined for this tape must support handling null incoming +// gradients. +class Tape : protected eager::GradientTape { + public: + using GradientTape::GradientTape; + // Returns whether the tape is persistent, i.e., whether the tape will hold + // onto its internal state after a call to `ComputeGradient`. + using GradientTape::IsPersistent; + + // Adds this tensor to the list of watched tensors. + // + // This is a no-op if the tensor is already being watched either from an + // earlier call to `GradientTape::Watch` or being an output of an op with + // watched inputs. + void Watch(const AbstractTensorHandle*); + // Records an operation with given inputs and outputs + // on the tape and marks all its outputs as watched if at + // least one input of the op is watched and has a trainable dtype. + // op_name is optional and is used for debugging only. + void RecordOperation(absl::Span inputs, + absl::Span outputs, + GradientFunction* gradient_function, + const string& op_name = ""); + // Returns whether any tensor in a list of tensors is being watched and has + // a trainable dtype. + bool ShouldRecord( + absl::Span tensors) const; + // Unwatches this tensor on the tape. Mainly used for cleanup when deleting + // eager tensors. + void DeleteTrace(const AbstractTensorHandle*); + + // Consumes the internal state of the tape (so cannot be called more than + // once unless the tape is persistent) and produces the gradient of the target + // tensors with respect to the source tensors. The output gradients are used + // if not empty and not null. The result is populated with one tensor per + // target element. + absl::Status ComputeGradient( + AbstractContext* ctx, absl::Span targets, + absl::Span sources, + absl::Span output_gradients, + absl::Span result); +}; + +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_GRADIENTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/gradients_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/gradients_internal.h new file mode 100644 index 00000000..93c2d36b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/gradients_internal.h @@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +// Helper functions which delegate to `AbstractOperation`, update +// the state of the ForwardOperation and call the tape as appropriate. +// These APIs are mainly to facilitate testing and are subject to change. + +// Records the op name in the `ForwardOperation`. +absl::Status Reset(AbstractOperation*, const char* op, + const char* raw_device_name, ForwardOperation*); + +// Records the inputs in the `ForwardOperation`. +absl::Status AddInput(AbstractOperation*, AbstractTensorHandle*, + ForwardOperation*); +absl::Status AddInputList(AbstractOperation*, + absl::Span inputs, + ForwardOperation*); + +// Sets the attrs in the `ForwardOperation`. +absl::Status SetAttrString(AbstractOperation*, const char* attr_name, + const char* data, size_t length, ForwardOperation*); +absl::Status SetAttrInt(AbstractOperation*, const char* attr_name, + int64_t value, ForwardOperation*); +absl::Status SetAttrFloat(AbstractOperation*, const char* attr_name, + float value, ForwardOperation*); +absl::Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value, + ForwardOperation*); +absl::Status SetAttrType(AbstractOperation*, const char* attr_name, + DataType value, ForwardOperation*); +absl::Status SetAttrShape(AbstractOperation*, const char* attr_name, + const int64_t* dims, const int num_dims, + ForwardOperation*); +absl::Status SetAttrFunction(AbstractOperation*, const char* attr_name, + const AbstractOperation* value, ForwardOperation*); +absl::Status SetAttrFunctionName(AbstractOperation*, const char* attr_name, + const char* value, size_t length, + ForwardOperation*); +absl::Status SetAttrTensor(AbstractOperation*, const char* attr_name, + AbstractTensorInterface* tensor, ForwardOperation*); +absl::Status SetAttrStringList(AbstractOperation*, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values, ForwardOperation*); +absl::Status SetAttrFloatList(AbstractOperation*, const char* attr_name, + const float* values, int num_values, + ForwardOperation*); +absl::Status SetAttrIntList(AbstractOperation*, const char* attr_name, + const int64_t* values, int num_values, + ForwardOperation*); +absl::Status SetAttrTypeList(AbstractOperation*, const char* attr_name, + const DataType* values, int num_values, + ForwardOperation*); +absl::Status SetAttrBoolList(AbstractOperation*, const char* attr_name, + const unsigned char* values, int num_values, + ForwardOperation*); +absl::Status SetAttrShapeList(AbstractOperation*, const char* attr_name, + const int64_t** dims, const int* num_dims, + int num_values, ForwardOperation*); +absl::Status SetAttrFunctionList(AbstractOperation*, const char* attr_name, + absl::Span values, + ForwardOperation*); + +// Make the call to `Tape::RecordOperation`. 
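// A sketch of the intended call sequence for these helpers: reset the op, add
// inputs, set attrs, then Execute (declared just below), which also records
// the op on the tape. The span element type for `retvals`
// (AbstractTensorHandle*) is assumed; errors propagate via TF_RETURN_IF_ERROR.
#include "absl/status/status.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {
namespace internal {

absl::Status TapedMatMul(AbstractContext* ctx, AbstractTensorHandle* a,
                         AbstractTensorHandle* b, AbstractTensorHandle** out,
                         Tape* tape, const GradientRegistry& registry) {
  AbstractOperationPtr op(ctx->CreateOperation());
  ForwardOperation forward_op;
  TF_RETURN_IF_ERROR(
      Reset(op.get(), "MatMul", /*raw_device_name=*/nullptr, &forward_op));
  TF_RETURN_IF_ERROR(AddInput(op.get(), a, &forward_op));
  TF_RETURN_IF_ERROR(AddInput(op.get(), b, &forward_op));
  TF_RETURN_IF_ERROR(
      SetAttrBool(op.get(), "transpose_a", false, &forward_op));

  AbstractTensorHandle* retvals[1] = {nullptr};
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(Execute(op.get(), ctx,
                             absl::Span<AbstractTensorHandle*>(retvals, 1),
                             &num_retvals, &forward_op, tape, registry));
  *out = retvals[0];
  return absl::OkStatus();
}

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow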
+absl::Status Execute(AbstractOperation*, AbstractContext*, + absl::Span retvals, + int* num_retvals, ForwardOperation*, Tape*, + const GradientRegistry&); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/graph_function.h b/third_party/tflite-hdrs/tensorflow/c/eager/graph_function.h new file mode 100644 index 00000000..b15d1b4b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/graph_function.h @@ -0,0 +1,53 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_GRAPH_FUNCTION_H_ +#define TENSORFLOW_C_EAGER_GRAPH_FUNCTION_H_ + +#include "tensorflow/c/eager/abstract_function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/platform/refcount.h" +namespace tensorflow { +namespace tracing { +namespace graph { +using tensorflow::AbstractFunction; +// Thin wrapper around a FunctionDef. +class GraphFunction : public AbstractFunction { + public: + explicit GraphFunction(FunctionDef fdef); + ~GraphFunction() override; + + // GraphFunction maybe stay alive for the duration of the returned + // FunctionDef. + absl::Status GetFunctionDef(const FunctionDef** fdef) override; + + // Returns a shared reference to the wrapped function. + absl::StatusOr> GetFunctionRecord() + override { + return func_record_.GetNewRef(); + } + + // For LLVM style RTTI. + static bool classof(const AbstractFunction* ptr) { + return ptr->getKind() == kGraph; + } + + private: + core::RefCountPtr func_record_; +}; +} // namespace graph +} // namespace tracing +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_GRAPH_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_context.h b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_context.h new file mode 100644 index 00000000..216fcfe9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_context.h @@ -0,0 +1,294 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/immediate_execution_distributed_manager.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +class EagerExecutor; +class EagerContext; +class CustomDevice; +class CustomDeviceOpHandler; +class Device; + +// LINT.IfChange +// Note: Keep in sync with exported copy of enum in eager/c_api.h. +enum ContextDevicePlacementPolicy { + // Running operations with input tensors on the wrong device will fail. + DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the operation + // will be blocked till the copy completes. This is the default policy. + DEVICE_PLACEMENT_SILENT = 2, + // Placement policy which silently copies int32 tensors but not other dtypes. + DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, +}; +// LINT.ThenChange(//tensorflow/c/eager/c_api.h) + +// Abstract interface to a context. +// +// A context is responsible for creating key objects such as Tensors, +// TensorHandles & Operations. +class ImmediateExecutionContext : public AbstractContext { + public: + // Optimized scalar creation functions + virtual AbstractTensorInterface* CreateInt64Scalar(int64_t value) = 0; + virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0; + virtual AbstractTensorInterface* CreateInt32Scalar(int32_t value) = 0; + virtual AbstractTensorInterface* CreateFloatScalar(float value) = 0; + virtual AbstractTensorInterface* CreateDoubleScalar(double value) = 0; + virtual AbstractTensorInterface* CreateHalfScalar(Eigen::half value) = 0; + virtual AbstractTensorInterface* CreateStringScalar(tstring value) = 0; + virtual AbstractTensorInterface* CreateComplex128Scalar(complex128 value) = 0; + virtual AbstractTensorInterface* CreateBoolScalar(bool value) = 0; + + // Tensor creation functions + virtual AbstractTensorInterface* CreateTensor( + DataType dtype, absl::Span dim_sizes) = 0; + + typedef void (*MemoryReleaser)(void* data, size_t len, void* arg); + + // Create a tensor instance from the given data buffer and description. + // `memory_releaser` will be called on destruction, and it's responsible for + // cleaning up the underlying buffer. 
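// A sketch of wrapping an externally owned buffer with the CreateTensor
// overload declared just below: the MemoryReleaser callback frees the buffer
// once the runtime no longer needs the tensor.
#include <cstdlib>

#include "tensorflow/c/eager/immediate_execution_context.h"

namespace {

void FreeHeapBuffer(void* data, size_t /*len*/, void* /*arg*/) {
  std::free(data);
}

tensorflow::AbstractTensorInterface* WrapFloatBuffer(
    tensorflow::ImmediateExecutionContext* ctx, int64_t rows, int64_t cols) {
  const size_t len = rows * cols * sizeof(float);
  void* data = std::malloc(len);  // Ownership passes to the tensor.
  const int64_t dims[] = {rows, cols};
  return ctx->CreateTensor(tensorflow::DT_FLOAT, dims, /*num_dims=*/2, data,
                           len, &FreeHeapBuffer,
                           /*memory_releaser_arg=*/nullptr);
}

}  // namespace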
+ virtual AbstractTensorInterface* CreateTensor( + DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, + MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0; + + // Create a handle to wrap and manage a Tensor + virtual ImmediateExecutionTensorHandle* CreateLocalHandle( + AbstractTensorInterface* t) = 0; + // Copy the handle to another device. + virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( + ImmediateExecutionTensorHandle* handle, const char* device_name, + absl::Status* status) = 0; + + // Create an operation to perform op execution + ImmediateExecutionOperation* CreateOperation() override = 0; + + // Returns whether the runtime is backed by TFRT or the legacy TF Eager + // Runtime. This is necessary to decouple runtime-dependent + // code that is layered on top of the runtime. + virtual bool UsesTFRT() = 0; + + // List attributes of available devices + virtual void ListDevices(std::vector* devices) = 0; + + // Add `devices` into context's device manager. Context's device manager + // will take ownership and maintain devices' lifetime. + virtual absl::Status AddDevices( + std::vector> devices) = 0; + + // Block until all pending nodes are finished. + virtual absl::Status AsyncWait() = 0; + + // Add a function (serialized FunctionDef protocol buffer) so that it can + // be executed as an op. Return error if the function with the same name + // already exists. + virtual absl::Status AddFunctionDef(const FunctionDef& fdef) = 0; + + // Notifies about the function removal. + virtual absl::Status AddRemoveFunctionNotifier( + const string& func, std::function notifier) = 0; + + // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under + // the key of the function definition name (to be retrieved during function + // instantiation). + virtual absl::Status AddFunctionDefWithStackTraces( + const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0; + + // Find and return a added function by its name. + virtual const FunctionDef* FindFunctionDef(const string& name) const = 0; + + // Find and return a function record added by its name. + virtual core::RefCountPtr FindRecord( + const string& name) const = 0; + + // Return the ParsedName of Host CPU device. + virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; + virtual const string& HostCPUName() const = 0; + + // Configure soft device placement policy. + virtual void SetAllowSoftPlacement(bool enable) = 0; + + // Configure device placement policy logging. + virtual void SetLogDevicePlacement(bool enable) = 0; + + // Enables running eager ops as functions. + virtual void SetRunEagerOpAsFunction(bool enable) = 0; + + // Enables rewriting jit_compile functions. + virtual void SetJitCompileRewrite(bool enable) = 0; + + // Sets the device placement policy for the current thread. + virtual void SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) = 0; + // Returns the device placement policy for the current thread. + virtual ContextDevicePlacementPolicy GetDevicePlacementPolicy() const = 0; + + // Configure graph collection in RunMetadata. + virtual void SetShouldStoreGraphs(bool value) = 0; + + // Return the collected RunMetadata. This method will transfer the ownership + // to the caller. + virtual std::unique_ptr ExportRunMetadata() = 0; + + // For LLVM style RTTI. 
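// A sketch of registering a FunctionDef with the context and looking it up
// again through the methods above; constructing a meaningful FunctionDef is
// elided, only the registration round-trip is shown.
#include "absl/status/status.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"

absl::Status RegisterAndCheck(tensorflow::ImmediateExecutionContext* ctx,
                              const tensorflow::FunctionDef& fdef) {
  TF_RETURN_IF_ERROR(ctx->AddFunctionDef(fdef));
  const tensorflow::FunctionDef* found =
      ctx->FindFunctionDef(fdef.signature().name());
  if (found == nullptr) {
    return absl::InternalError("FunctionDef was not registered.");
  }
  return absl::OkStatus();
}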
+ static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kEager || ptr->getKind() == kTfrt; + } + + //===--------------------------------------------------------------------===// + // Experimental Custom Device. + //===--------------------------------------------------------------------===// + virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0; + + // Returns whether `device_name` is registered as a custom device. + virtual bool IsCustomDevice(const string& device_name) = 0; + + // Register a custom device. It will return error is the device name is + // already registered. + // TODO(tfrt-devs): Remove this method. Let caller register it directly into + // CustomDeviceOpHandler. + virtual absl::Status RegisterCustomDevice( + const string& name, std::unique_ptr device) = 0; + + // Return FunctionLibraryDefinition. Transformations need to use it to use it + // to invoke MLIR compiler passes. + virtual FunctionLibraryDefinition* FuncLibDef() = 0; + + // Resets the global rendezvous used for functions. + virtual void ResetGlobalRendezvousForFunction() = 0; + + //===--------------------------------------------------------------------===// + // Following are features in current TF Eager Runtime. + // TODO(tfrt-devs): Figure out a way to deprecate following features after + // migrated to TFRT. + //===--------------------------------------------------------------------===// + // Clear pending nodes in thread executors and kernel caches. + virtual void ClearCachesAndThreadExecutors() = 0; + + // Initialize the step resource container for a training step. This is used + // in current TF runtime. For tfrt, it is used by fallback op handler. + virtual void StartStep() = 0; + // Destroy the step resource container for a training step. + virtual void EndStep() = 0; + + // Return the Eager Executor for current thread. Please note that Eager + // Executor is only used in current TF but not in TFRT. + virtual EagerExecutor& Executor() = 0; + // Update the Eager Executor for current thread. + virtual void SetExecutorForThread(EagerExecutor* executor) = 0; + + // Return a list of local tensorflow::Device*. + // TODO(tfrt-devs): We shouldn't expose legacy device in this API. + virtual std::vector ListLocalTfDevices() = 0; + + // Return a list of all tensorflow::Device*. + virtual std::vector ListAllTfDevices() = 0; + + //===--------------------------------------------------------------------===// + // Following are helper functions to assist integrating TFRT with current + // TF eager runtime. + // TODO(b/172877902): These helper functions are currently used to support + // PyFuncOp on TFRT, and might be useful for ops that directly use low + // level TF APIs. Remove/replace the following functions when TFRT native + // ops are implemented. + //===--------------------------------------------------------------------===// + // Create an abstract tensor handle from tensorflow::Tensor. + virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor( + tensorflow::Tensor& t, const char* d_name) = 0; + + // Convert a TFRT TensorHandle to tensorflow::TensorHandle. + virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( + ImmediateExecutionTensorHandle* handle) = 0; + + virtual std::vector GetLoggedOpsTestonly() { return {}; } + + // Get a list of the names of functions that have been registered. 
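// A small debugging sketch built on the query methods above (ListFunctionNames
// is declared just below): log every registered function name and whether a
// given device name refers to a custom device. The std::string element type of
// the returned vector is assumed.
#include <string>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/platform/logging.h"

void LogContextState(tensorflow::ImmediateExecutionContext* ctx,
                     const std::string& device_name) {
  for (const std::string& name : ctx->ListFunctionNames()) {
    LOG(INFO) << "Registered function: " << name;
  }
  LOG(INFO) << device_name << " is a custom device: "
            << (ctx->IsCustomDevice(device_name) ? "yes" : "no");
}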
+ virtual std::vector ListFunctionNames() = 0; + + struct CacheStats { + int64_t kernel_cache_size; + int64_t device_cache_size; + std::map func_kernel_cache_entries; + int64_t local_rendezvous_cache_active_size; + }; + virtual CacheStats GetCacheStats() = 0; + + //===--------------------------------------------------------------------===// + // Distributed runtime related functions. + //===--------------------------------------------------------------------===// +#if !defined(IS_MOBILE_PLATFORM) + // Set up a multi-client distributed execution environment. Must be called on + // all tasks in the cluster. + // This call internally coordinates with other tasks to initialize the eager + // context and TF server for multi-client execution. + virtual absl::Status EnableCollectiveOps(const ServerDef& server_def) = 0; + + // Set a distributed manager that helps set up, update, and check liveness + // of member tasks in the cluster. + virtual void SetDistributedManager( + std::unique_ptr distributed) = 0; + + virtual ImmediateExecutionDistributedManager* GetDistributedManager() = 0; +#endif // !IS_MOBILE_PLATFORM + + protected: + explicit ImmediateExecutionContext(AbstractContextKind kind) + : AbstractContext(kind) {} + ~ImmediateExecutionContext() override {} +}; + +namespace internal { +struct ImmediateExecutionContextDeleter { + void operator()(ImmediateExecutionContext* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using ImmediateContextPtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_distributed_manager.h b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_distributed_manager.h new file mode 100644 index 00000000..f4f4f093 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_distributed_manager.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_ + +#include +#include + +#include "tensorflow/core/platform/status.h" + +namespace tsl { +class CoordinationServiceAgent; +} + +namespace tensorflow { +class ImmediateExecutionContext; +class ServerDef; +class WorkerEnv; +class WorkerCacheInterface; + +class ImmediateExecutionDistributedManager { + public: + virtual ~ImmediateExecutionDistributedManager() {} + + // Set up distributed execution environment on local and remote tasks. + // When `reset_context` is true, initialize new cluster context state based + // on cluster configurations provided in `server_def`; otherwise, update + // existing context state with the provided `server_def`. 
Contexts created + // on remote tasks will be considered stale and garbage collected after + // `keep_alive_secs` of inactivity. + virtual absl::Status SetOrUpdateServerDef( + const ServerDef& server_def, bool reset_context, int keep_alive_secs, + int64_t init_timeout_in_ms, int retries, + bool clear_existing_contexts = false) = 0; + + // Initializes context for the local worker and no contexts will be created + // for remote workers. Currently this only works for resetting context. + // TODO(b/289445025): Consider removing this when we find a proper fix. + virtual absl::Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) = 0; + + // Set up a multi-client distributed execution environment. Must be called + // on all tasks in the cluster. This call internally coordinates with other + // tasks to initialize the eager context and TF server for multi-client + // execution. + virtual absl::Status EnableCollectiveOps(const ServerDef& server_def) = 0; + + // Check if the remote task is alive. + virtual absl::Status CheckRemoteAlive(const std::string& remote_task_name, + bool* is_alive) = 0; + + // Get pointer to the coordination service agent instance. + virtual tsl::CoordinationServiceAgent* GetCoordinationServiceAgent() = 0; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_DISTRIBUTED_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_operation.h b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_operation.h new file mode 100644 index 00000000..fb76af9d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_operation.h @@ -0,0 +1,104 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ + +#include + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/managed_stack_trace.h" + +struct TFE_Op; + +namespace tensorflow { + +class ImmediateExecutionContext; +class AbstractOpAttrs; + +// Abstract interface to an operation. +class ImmediateExecutionOperation : public AbstractOperation { + public: + virtual void Clear() = 0; + + // Returns the inputs of this op. 
+ virtual absl::Span GetInputs() + const = 0; + virtual absl::Status SetInput(size_t index, + ImmediateExecutionTensorHandle* input) = 0; + + virtual ImmediateExecutionContext* GetContext() const = 0; + + // Following two methods are used to support custom device. + // Return true if the inputs contain custom device tensor handle. It means + // that the argument need to be handled by a custom device. + virtual bool HasCustomDeviceInput() const = 0; + + virtual const tensorflow::OpDef* OpDef() const = 0; + + virtual absl::Status InputLength(const char* input_name, int* length) = 0; + virtual absl::Status OutputLength(const char* output_name, int* length) = 0; + + // Set stack trace to be used for potential async error reporting. + virtual void SetStackTrace(ManagedStackTrace stack_trace) = 0; + + virtual const tensorflow::AbstractOpAttrs* GetOpAttrs() const = 0; + virtual void AddAttrs(const AbstractOpAttrs* op_attrs) = 0; + + virtual void SetCancellationManager( + CancellationManager* cancellation_manager) = 0; + + // Returns the stack trace set by `SetStackTrace` if exists. + virtual absl::optional GetStackTrace() = 0; + + virtual void SetStepId(int64_t step_id) = 0; + + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kEager || ptr->getKind() == kTfrt; + } + + protected: + explicit ImmediateExecutionOperation(AbstractOperationKind kind) + : AbstractOperation(kind) {} + ~ImmediateExecutionOperation() override {} +}; + +namespace internal { +struct ImmediateExecutionOperationDeleter { + void operator()(ImmediateExecutionOperation* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using ImmediateOpPtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_tensor_handle.h b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_tensor_handle.h new file mode 100644 index 00000000..61fc0fe8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ + +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Abstract interface to a TensorHandle. +// +// A TensorHandle is management class around a Tensor which may track additional +// metadata and synchronization. +// +// This allows us to hide concrete implementations of TensorHandle from header +// files. 
The interface lists the common functionality that must be provided by +// any concrete implementation. However, in cases where the true concrete class +// is needed a static_cast can be applied. +class ImmediateExecutionTensorHandle : public AbstractTensorHandle { + public: + // Returns number of dimensions. + virtual absl::Status NumDims(int* num_dims) const = 0; + // Returns number of elements across all dimensions. + virtual absl::Status NumElements(int64_t* num_elements) const = 0; + // Returns size of specified dimension + // + // -1 indicates an unknown axis length; this is unreachable for most standard + // ImmediateExecutionTensorHandles, but comes up for example when computing + // the shape of a parallel tensor with component shapes differing across + // devices. + virtual absl::Status Dim(int dim_index, int64_t* dim) const = 0; + + // Returns the device which created the handle. + virtual const char* DeviceName(absl::Status* status) const = 0; + // Returns the device where the tensor was placed. + virtual const char* BackingDeviceName(absl::Status* status) const = 0; + // Returns the device type which created the handle. + virtual const char* DeviceType(absl::Status* status) const = 0; + // Returns the device ID which created the handle. + virtual int DeviceId(absl::Status* status) const = 0; + // Returns a tensor for the handle. If tensor is remote, it will be copied. + virtual AbstractTensorInterface* Resolve(absl::Status* status) = 0; + + std::string DebugString() const override; + + // Returns a Boolean hint indicating whether callers should prefer + // `SummarizeValue` to resolving this handle and formatting the tensor. + // + // For example some tensor handles may represent distributed values, in which + // case placement information is lost when resolving the handle. + // + // If false, a caller might implement pretty-printing by resolving and + // iterating over the resulting tensor. This may still be viable if resolving + // the handle loses information, but `SummarizeValue` would be more precise. + virtual bool PreferCustomSummarizer() const { return false; } + + // Returns a string which summarizes the value of this TensorHandle, for + // debugging. Does not include a shape or dtype. + // + // Included in the default implementation of DebugString. + virtual absl::Status SummarizeValue(std::string& summary) const; + + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kEager || ptr->getKind() == kTfrt; + } + + protected: + explicit ImmediateExecutionTensorHandle(AbstractTensorHandleKind kind) + : AbstractTensorHandle(kind) {} + ~ImmediateExecutionTensorHandle() override {} +}; + +namespace internal { +struct ImmediateExecutionTensorHandleDeleter { + void operator()(ImmediateExecutionTensorHandle* p) const { + if (p != nullptr) { + p->Unref(); + } + } +}; +} // namespace internal + +using ImmediateTensorHandlePtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device.h b/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device.h new file mode 100644 index 00000000..b8e571b8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device.h @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ +#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" + +namespace tensorflow { +namespace parallel_device { + +// Allocate a parallel device named `device_name` which forwards operations to +// `underlying_devices`, maintaining "parallel tensors" with components placed +// on each underlying device. +// +// For example if `device_name` is +// "/job:localhost/replica:0/task:0/device:CUSTOM:0" +// and `underlying_devices` is +// {"/job:localhost/replica:0/task:0/device:GPU:0", +// "/job:localhost/replica:0/task:0/device:GPU:1"} +// Then executing an operation on CUSTOM:0 will execute it on GPU:0 and GPU:1. +// +// Implicit copies onto `device_name` are allowed, replicating the value once +// per device in `underlying_devices`. Implicit copies off of the device throw +// an error. +// +// All component tensors must have the same dtype. Currently they must also have +// the same shape, although this requirement may be relaxed in the future. +// +// `device_name` must not name an existing physical or custom device (see +// the documentation for TFE_RegisterCustomDevice for more information). +// +// Tensors may be copied on or off the device explicitly using +// TPUReplicatedInput and TPUReplicatedOutput respectively. For example, with +// two component devices, running `x = TPUReplicatedInput(inputs=[a, b])` on the +// parallel device creates a parallel tensor `x` with `a` on the first of +// `underlying_devices` and `b` on the second. Running `a_unpacked, b_unpacked = +// TPUReplicatedOutput(input=x, num_replicas=2)` un-packs the parallel tensor +// into its components. +// +// The filled `device` struct and the allocated `device_info` struct may be +// passed to TFE_RegisterCustomDevice. The `device_name` arguments must match. +void AllocateParallelDevice(const char* device_name, + const char* const* underlying_devices, + int num_underlying_devices, + TFE_CustomDevice* device, void** device_info); + +} // namespace parallel_device +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device_lib.h new file mode 100644 index 00000000..03845d15 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -0,0 +1,299 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_ +#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/safe_ptr.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace parallel_device { + +using TensorHandlePtr = tensorflow::Safe_TFE_TensorHandlePtr; + +class ParallelTensor; +class DeviceThread; + +// Forwards operations to `devices`, maintaining ParallelTensor with components +// placed on each underlying device. +class ParallelDevice { + public: + // Eager async execution is only supported when remote eager is not in use + // (b/157523095). + explicit ParallelDevice(const std::vector& devices, + bool is_async = false, int in_flight_nodes_limit = 0); + + ~ParallelDevice(); + + // Helper to copy a tensor handle from another device once for each component + // of the ParallelDevice. + // + // Sets a bad status and returns a nullptr if `tensor` is already on the + // ParallelDevice, or if the individual copies fail. + std::unique_ptr CopyToParallelDevice(TFE_Context* context, + TFE_TensorHandle* tensor, + TF_Status* status) const; + + // Construct a parallel tensor consisting of the scalar values from `values`. + template + std::unique_ptr ScalarsFromSequence( + absl::Span values, TFE_Context* context, + TF_Status* status) const; + + // A parallel tensor with scalar integers numbering component devices. + std::unique_ptr DeviceIDs(TFE_Context* context, + TF_Status* status) const; + + // The number of devices operations run on. + size_t num_underlying_devices() const { return underlying_devices_.size(); } + + // The devices operations run on. + const std::vector& underlying_devices() const { + return underlying_devices_; + } + + // Takes a description of a single operation being executed on the + // ParallelDevice, and in turn runs one operation per component device with + // its corresponding inputs from the input ParallelTensors. Wraps the + // resulting per-device and per-output TFE_TensorHandles into one + // ParallelTensor per output of the original operation. + // + // Attributes are forwarded to executed operations unmodified. + // + // The returned optional has a value if and only if `status` evaluates to + // TF_OK. Bad statuses are forwarded from underlying `TFE_Execute` calls, or + // if sanity checks on dtypes/metadata fail. 
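+  //
+  // Illustrative call sequence (a sketch only; `context`, `device`, `x`, `y`
+  // and `forwarded_attrs` are hypothetical, and error handling is elided):
+  //
+  //   TF_Status* status = TF_NewStatus();
+  //   std::vector<ParallelTensor*> inputs{x.get(), y.get()};
+  //   auto outputs = device.Execute(context, inputs, "Mul",
+  //                                 /*attributes=*/forwarded_attrs,
+  //                                 /*expected_max_outputs=*/1, status);
+  //   if (TF_GetCode(status) == TF_OK && outputs.has_value()) {
+  //     ParallelTensor* product = (*outputs)[0].get();  // one component per device
+  //   }
+  //   TF_DeleteStatus(status);
+  //
+  // where `forwarded_attrs` would typically be the TFE_OpAttrs captured from
+  // the op being rewritten onto this parallel device.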
+ absl::optional>> Execute( + TFE_Context* context, const std::vector& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, TF_Status* status) const; + + // A non-blocking version of `Execute`. After each call, `Join` must be called + // before `StartExecute` is called again. Using `StartExecute` with `Join` + // allows the caller to schedule computation on multiple ParallelDevices + // without sequencing those operations (first call `StartExecute` on each + // parallel device, then call `Join` on each; even if some of the `Join`s + // return a bad status the caller must run all of the `Join`s or any future + // `StartExecute`s will deadlock). + // + // If `is_async=false` (constructor argument), `cancellation_manager` must + // live until `Join` finishes. If `is_async=true` it must live until `Join` is + // followed by `TFE_ContextAsyncWait` to clear pending operations. It will be + // used to cancel all other operations if any fails. + // + // Set step_id to configure the step id used for rendezvous creation. step id + // of value -1 is reserved for global rendezvous and should not be set here. + void StartExecute(TFE_Context* context, + const std::vector& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, + CancellationManager& cancellation_manager, + std::optional step_id = std::nullopt) const; + + void StartExecute(TFE_Context* context, + const std::vector>& inputs, + const char* operation_name, const TFE_OpAttrs* attributes, + int expected_max_outputs, + CancellationManager& cancellation_manager, + std::optional step_id = std::nullopt) const; + + // Blocks until the previous `StartExecute` has run `TFE_Execute` on each + // device. If is_async=false (constructor argument) this means the ops have + // run and have results. If is_async=true it means that all of the + // device-specific executors have scheduled the op. + // + // Accepts inferred shapes for outputs (`expected_output_shapes`), which if + // fully defined will avoid querying the shapes of the underlying + // TensorHandles when ParallelTensor::Shape is called. This allows async + // computation to continue without blocking. + // + // The return status and value is the same as `Execute`. + absl::optional>> Join( + const std::vector& expected_output_shapes, + TF_Status* status) const; + + void AsyncWait(TFE_Context* context, TF_Status* status) const; + + // Device strings for component devices that only include a + // worker/task/replica if any of those differ across components. Useful for + // printing debug messages. + std::vector SummarizeDeviceNames() const; + + private: + // A sequence of device names, indicating which devices replicated operations + // are forwarded to. + const std::vector underlying_devices_; + // A sequence of thread wrappers, one per device, for executing operations in + // parallel. + // + // Conceptually this is a thread pool with one thread per device. It requires + // less synchronization than a thread pool would for this task, since Execute + // acquires each thread in order (and so only one Execute will schedule + // blocking collective operations at a time), and avoids some dynamic + // allocation/scheduling. + // + // TODO(allenl): Keep a map from outer thread to list of inner threads rather + // than a single list of threads so aliased nested parallel devices don't + // re-use a thread. + std::vector> device_threads_; + // A cancellation manager to use if the caller does not provide one. 
When ops + // are executed asynchronously this must outlive the queued op, so it can't be + // function-local to Execute. + mutable std::unique_ptr default_cancellation_manager_; +}; + +// Contains a tuple of tensors, one on each of the `underlying_devices_` of the +// ParallelDevice. +class ParallelTensor { + public: + // Construct a ParallelTensor from TensorHandles placed on the component + // devices of a ParallelDevice. If called, ParallelTensor::Shape inspects + // `components` to determine a shape. + static std::unique_ptr FromTensorHandles( + const ParallelDevice& parallel_device, + std::vector components, TF_Status* status); + // Uses the provided shape without additional checks, which avoids blocking + // when ParallelTensor::Shape is called. + static std::unique_ptr FromTensorHandles( + const ParallelDevice& parallel_device, + std::vector components, absl::Span shape, + TF_Status* status); + + size_t num_tensors() const { return tensors_.size(); } + TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); } + + // If the `shape` argument to `FromTensorHandles` is specified, returns that. + // + // Otherwise if all of the tensors have the same shape, returns that via the + // `shape` output argument. This blocks waiting for async tensors, may return + // a delayed bad status encountered during async execution, and will return a + // bad status unless all tensors have the same shape. + absl::Status Shape(const std::vector** shape) const; + TF_DataType dtype() const { return dtype_; } + + // Sets its output argument to a summary of the values of this tensor on every + // component device. + absl::Status SummarizeValue(std::string& summary); + + std::vector release_tensors() { return std::move(tensors_); } + + std::vector tensors() const { + std::vector result; + result.reserve(tensors_.size()); + for (const TensorHandlePtr& tensor : tensors_) { + result.emplace_back(tensor.get()); + } + return result; + } + + private: + ParallelTensor(const ParallelDevice& device, + std::vector tensors, + absl::Span shape, const TF_DataType dtype) + : device_(device), + tensors_(std::move(tensors)), + shape_(std::vector(shape.begin(), shape.end())), + dtype_(dtype) {} + ParallelTensor(const ParallelDevice& device, + std::vector tensors, const TF_DataType dtype) + : device_(device), + tensors_(std::move(tensors)), + shape_(absl::nullopt), + dtype_(dtype) {} + + const ParallelDevice& device_; + std::vector tensors_; + // Parallel tensors are immutable but compute their shape lazily unless it is + // provided on construction. The optional has a value if the lazy computation + // has been completed or the shape was provided on construction. 
+  mutable absl::optional<std::vector<int64_t>> shape_;
+  const TF_DataType dtype_;
+};
+
+template <typename DataType>
+std::unique_ptr<ParallelTensor> ParallelDevice::ScalarsFromSequence(
+    absl::Span<const DataType> values, TFE_Context* context,
+    TF_Status* status) const {
+  std::vector<TensorHandlePtr> components;
+  components.reserve(underlying_devices_.size());
+
+  if (values.size() != num_underlying_devices()) {
+    TF_SetStatus(
+        status, TF_INVALID_ARGUMENT,
+        "Number of values did not match number of underlying devices.");
+    return nullptr;
+  }
+  TF_DataType datatype_enum(
+      static_cast<TF_DataType>(DataTypeToEnum<DataType>().value));
+  for (int device_index = 0; device_index < num_underlying_devices();
+       ++device_index) {
+    auto device_value = absl::make_unique<DataType>();
+    *device_value = values[device_index];
+    std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor(
+        TF_NewTensor(
+            datatype_enum, /*dims=*/nullptr, /*num_dims=*/0,
+            device_value.release(), sizeof(DataType),
+            [](void* data, size_t, void* arg) {
+              delete reinterpret_cast<DataType*>(data);
+            },
+            nullptr),
+        TF_DeleteTensor);
+    // TODO(allenl): Here and when executing regular operations, we could hold
+    // on to one TFE_Op per device and just call TFE_ResetOp to avoid parsing
+    // device names repeatedly.
+    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> const_op(
+        TFE_NewOp(context, "Const", status), TFE_DeleteOp);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetDevice(const_op.get(), underlying_devices_[device_index].c_str(),
+                    status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    TFE_OpSetAttrType(const_op.get(), "dtype", datatype_enum);
+    TFE_TensorHandle* device_handle;
+    int num_outputs = 1;
+    TFE_Execute(const_op.get(), &device_handle, &num_outputs, status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+    components.emplace_back(device_handle);
+  }
+  return ParallelTensor::FromTensorHandles(*this, std::move(components),
+                                           status);
+}
+
+}  // namespace parallel_device
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_LIB_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device_testlib.h b/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device_testlib.h
new file mode 100644
index 00000000..d55a23bd
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/eager/parallel_device/parallel_device_testlib.h
@@ -0,0 +1,172 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ +#define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/parallel_device/parallel_device.h" +#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace parallel_device { + +// A helper for performing common operations on variables. A much more +// restricted stand-in for tf.Variable in Python. +class Variable { + public: + // Construct a Variable from a resource-dtype TFE_TensorHandle and an + // indication of the dtype of the variable's value. + // + // Note that creating this resource-dtype handle can fail, so `Create` is a + // separate static method which returns a status. + Variable(TFE_TensorHandle* handle, TF_DataType type) + : handle_(handle), type_(type) {} + + // Helper for constructing a resource handle and wrapping it in a `Variable` + // object. + static Variable* Create(TFE_Context* context, TF_DataType type, + const int64_t* dims, const int num_dims, + const char* device, TF_Status* status); + // Dereferences the backing buffer for the variable. Note that since this can + // fail (it runs operations), it must be called explicitly and the resulting + // `status` checked. + void Destroy(TFE_Context* context, TF_Status* status); + + // Reads from the variable. + TensorHandlePtr Read(TFE_Context* context, TF_Status* status); + // Assigns a new value to the variable. + void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status); + // Adds `value` to the existing value of the variable. + void AssignAdd(TFE_Context* context, TFE_TensorHandle* value, + TF_Status* status); + + private: + // Helper for running any single-argument assignment ops (Assign, AssignAdd, + // AssignSub, ...). + void GeneralAssignment(const char* op_name, TFE_Context* context, + TFE_TensorHandle* value, TF_Status* status); + + // The a handle for the resource-dtype tensor pointing to the variable's + // buffer. + TFE_TensorHandle* handle_; + // The dtype of the variable's buffer (input dtype for assignments, output + // dtype of read operations). + TF_DataType type_; +}; + +// Creates a TFE_TensorHandle with value `v`. +TensorHandlePtr FloatTensorHandle(float v, TF_Status* status); + +// Creates a rank-one TFE_TensorHandle with value `v`. +TensorHandlePtr VectorFloatTensorHandle(const std::vector& v, + TF_Status* status); + +// Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle. +template +void ExtractPerDeviceValues( + TFE_Context* context, TFE_TensorHandle* input, + std::array* components, TF_Status* status); + +// Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle. +template +TensorHandlePtr CreatePerDeviceValues( + TFE_Context* context, + const std::array& components, + const char* device, TF_Status* status); + +TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first, + TFE_TensorHandle* second, TF_Status* status); + +// Assert that `handle` is equal to `expected_value`. 
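+// For example (a sketch; `handle` would come from an earlier TFE_Execute call):
+//
+//   ExpectScalarEq<float>(handle, 3.f);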
+template +void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value); + +template +void RegisterParallelDevice( + TFE_Context* context, const char* device_name, + const std::array& underlying_devices, + TF_Status* status); + +// Create and modify a variable placed on a parallel device which composes +// `first_device` and `second_device`. +void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, + const char* second_device); + +// Implementations of templated functions ****************************** + +template +TensorHandlePtr CreatePerDeviceValues( + TFE_Context* context, + const std::array& components, + const char* device, TF_Status* status) { + std::unique_ptr op( + TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp); + if (TF_GetCode(status) != TF_OK) return nullptr; + TFE_OpSetAttrInt(op.get(), "N", num_replicas); + for (int i = 0; i < num_replicas; ++i) { + TFE_OpAddInput(op.get(), components[i], status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + TFE_OpSetDevice(op.get(), device, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + + TFE_TensorHandle* result_handle; + int num_retvals = 1; + TFE_Execute(op.get(), &result_handle, &num_retvals, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + return TensorHandlePtr(result_handle); +} + +template +void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr actual_value( + TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor); + ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + ASSERT_EQ(TF_TensorType(actual_value.get()), + static_cast(DataTypeToEnum().value)); + EXPECT_EQ(expected_value, + *static_cast(TF_TensorData(actual_value.get()))); +} + +template +void RegisterParallelDevice( + TFE_Context* context, const char* device_name, + const std::array& underlying_devices, + TF_Status* status) { + TFE_CustomDevice device; + void* device_info; + tensorflow::parallel_device::AllocateParallelDevice( + device_name, underlying_devices.data(), underlying_devices.size(), + &device, &device_info); + TFE_RegisterCustomDevice(context, device, device_name, device_info, status); +} + +} // namespace parallel_device +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tape.h b/third_party/tflite-hdrs/tensorflow/c/eager/tape.h new file mode 100644 index 00000000..7ed8025b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tape.h @@ -0,0 +1,1168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TAPE_H_ +#define TENSORFLOW_C_EAGER_TAPE_H_ + +// Language-agnostic gradient tape. Does not perform backpropagation, just +// maintains the data structures required to do so. 
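+//
+// Rough shape of how a language binding drives the tape (a sketch; the
+// Gradient/BackwardFunction/TapeTensor types, the ids, and the callbacks below
+// are supplied by the binding and are hypothetical here):
+//
+//   GradientTape<Gradient, BackwardFunction, TapeTensor> tape(/*persistent=*/false);
+//   tape.Watch(x_id);
+//   // After running y = op(x), record the op:
+//   tape.RecordOperation("Square", {y_info}, {x_id}, {DT_FLOAT},
+//                        backward_function_getter, backward_function_deleter);
+//   std::vector<Gradient*> result(1);
+//   absl::Status s = tape.ComputeGradient(
+//       vspace, /*target_tensor_ids=*/{y_id}, /*source_tensor_ids=*/{x_id},
+//       /*sources_that_are_targets=*/{}, /*output_gradients=*/{},
+//       absl::MakeSpan(result));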
+ +#include +#include +#include +#include + +#include "tensorflow/core/config/flag_defs.h" +#include "tensorflow/core/config/flags.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace eager { + +// Represents an entry in the tape. +template +struct OpTapeEntry { + string op_type; + std::vector output_tensor_info; + std::vector input_tensor_id; + + // TODO(apassos) consider narrowing down this interface. + BackwardFunction* backward_function; + + // Should be called before deleting the backward function. TODO(apassos) use + // unique_ptrs to ensure this happens. + std::function backward_function_deleter; +}; + +// Map from tensor_id to internally-defined operation-id of the operation which +// produced this tensor. A value of -1 means that the tensor was directly +// watched and not the result of any operation in the tape. +using TensorTape = std::unordered_map; + +// Map from operation-id to tape entry. +template +using OpTape = + std::unordered_map>; + +// Operations the tape needs to perform on tensors to do backpropagation. Named +// "vspace" because a subset of these are related to a vector space, such as +// adding gradients, getting zeroes, etc. Currently cannot be implemented +// without using tensorflow python code, hence left unspecified here. +// +// Gradient is the type returned by gradient functions. In Python TF it's either +// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need +// to allow their size to be computed and they need to be passable to a backward +// function and deleted (as the backprop code creates lots of gradients the user +// is not interested in). +// +// BackwardFunction needs to be a closure which stores intermediate activations +// from the forward computation and calls a vector-jacobian product function +// (also known as adjoint function) to compute, given downstream gradients, +// upstream gradients. +// +// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle +// specialization, which is blocked by quite a few things needing to loop back +// into python now. +template +class VSpace { + public: + virtual ~VSpace() {} + + // Returns the number of elements in the gradient tensor. + virtual int64_t NumElements(Gradient* tensor) const = 0; + + // Consumes references to the tensors in the gradient_tensors list and returns + // a tensor with the result. + virtual Gradient* AggregateGradients( + gtl::ArraySlice gradient_tensors) const = 0; + + // Calls the passed-in backward function. + // + // `unneeded_gradients` contains sorted list of input indices for which a + // gradient is not required. + virtual absl::Status CallBackwardFunction( + const string& op_type, BackwardFunction* backward_function, + const std::vector& unneeded_gradients, + gtl::ArraySlice output_gradients, + absl::Span result) const = 0; + + // Builds a tensor filled with ones with the same shape and dtype as `t`. + virtual absl::Status BuildOnesLike(const TapeTensor& t, + Gradient** result) const = 0; + + // Looks up the ID of a Gradient. + virtual int64_t TensorId(Gradient* tensor) const = 0; + + // Converts a Gradient to a TapeTensor. 
+ virtual TapeTensor TapeTensorFromGradient(Gradient* gradient) const = 0; + + // Marks the following gradient as a result so it's not consumed by backward + // functions. + virtual void MarkAsResult(Gradient* gradient) const = 0; + + // Deletes the input tensor. + virtual void DeleteGradient(Gradient* gradient) const = 0; +}; + +// Traces the execution of operations, doing eager garbage collection, and +// exporting a full trace so other code can do backpropagation. Not thread-safe. +template +class GradientTape { + public: + // If `persistent` is true, GradientTape will not eagerly delete backward + // functions (and hence the tensors they keep alive). Instead, everything + // is deleted in ~GradientTape. Persistent GradientTapes are useful when + // users want to compute multiple gradients over the same tape. + explicit GradientTape(bool persistent) : persistent_(persistent) {} + ~GradientTape() { + for (const auto& pair : op_tape_) { + pair.second.backward_function_deleter(pair.second.backward_function); + } + } + + // Returns whether any tensor in a list of tensors is being watched and has + // a trainable dtype. + bool ShouldRecord(absl::Span tensor_ids, + absl::Span dtypes) const; + + // Adds this tensor to the list of watched tensors. + // + // This is a no-op if the tensor is already being watched either from an + // earlier call to `GradientTape::Watch` or being an output of an op with + // watched inputs. + void Watch(int64_t tensor_id); + + // Records an operation with inputs `input_tensor_id` and outputs + // `output_tensors` on the tape and marks all its outputs as watched if at + // least one input of the op is watched and has trainable dtype. + // + // op_type is used to decide which of the incoming gradients can be left as + // nullptr instead of building zeros when build_default_zeros_grads == true. + void RecordOperation( + const string& op_type, const std::vector& output_tensors, + absl::Span input_tensor_id, + absl::Span input_dtypes, + const std::function& backward_function_getter, + const std::function& backward_function_deleter); + + void DeleteTrace(int64_t tensor_id); + + // Consumes the internal state of the tape (so cannot be called more than + // once) and produces the gradient of the target tensors with respect to the + // source tensors. The output gradients are used if not empty and not + // null. The result is populated with one tensor per target element. + // When running backward functions, builds zeros-like tensors for + // incoming grads which are nullptrs, unless `build_default_zeros_grads` + // is set to false. + absl::Status ComputeGradient( + const VSpace& vspace, + const absl::Span target_tensor_ids, + const absl::Span source_tensor_ids, + const std::unordered_map& sources_that_are_targets, + gtl::ArraySlice output_gradients, absl::Span result, + bool build_default_zeros_grads = true); + + // Whether the tape is persistent. See ctor for detailed description. + bool IsPersistent() const { return persistent_; } + + private: + TensorTape tensor_tape_; + OpTape op_tape_; + int64_t next_op_id_{0}; + + // Map from tensor id to number of remaining usages (i.e. how many entries in + // the tape refer to it); to aid in tape garbage collection. + std::unordered_map tensor_usage_; + + // If false, all activations are deleted in the first call to ComputeGradient. + // Else, only when this is destructed. + bool persistent_; +}; + +// Describes a callback for special-cased and more efficient jvp computation. 
+// +// Could just be a simple typedef in ForwardAccumulator, but MSVC chokes on +// that. +template +class ForwardFunction + : public std::function&, + std::vector*, bool)> { + public: + template + explicit ForwardFunction(lambda_type lambda) + : std::function&, + std::vector*, bool)>(lambda) {} +}; + +// Computes Jacobian-vector products using forward-mode automatic +// differentiation. +// +// While GradientTape's RecordOperation is trivial, ForwardAccumulator's +// Accumulate runs the gradient computation immediately. +// +// Keeps references to Tensors watched via Watch and computed in Accumulate +// corresponding to output_tensors, and releases these references in its +// destructor. However, waiting until the destructor runs loses the memory +// efficiency of forward-mode autodiff. Instead, language bindings should call +// DeleteGradient as soon as a Tensor which was `Watch`ed or was an output +// Tensor passed to Accumulate goes out of scope. +// +// Not thread-safe. +template +class ForwardAccumulator { + public: + // Does not take ownership of `vspace`, which must outlive the + // ForwardAccumulator. + explicit ForwardAccumulator( + const VSpace& vspace, + bool use_batch) + : vspace_(vspace), use_batch_(use_batch) { + call_state_.emplace(nullptr, false); + } + + virtual ~ForwardAccumulator() { + for (auto accumulated : accumulated_gradients_) { + vspace_.DeleteGradient(accumulated.second); + } + } + + // Tell the forward accumulator to watch tensor_id, with a Tensor tangent + // vector `tangent` of matching shape and dtype. Tangents are the "vector" in + // "Jacobian-vector product"; `Watch`ing a new Tensor and immediately calling + // FetchJVP for it would return `tangent`. + void Watch(int64_t tensor_id, Gradient* tangent); + + // Removes the gradient associated with tensor_id. Should be called when the + // Tensor associated with `tensor_id` is deleted. + void DeleteGradient(int64_t tensor_id); + + // Runs forward autodiff. Should be called whenever a new operation is + // available and the accumulator is active. + // + // Like GradientTape::RecordOperation, this method takes the operation type + // `op_type` (e.g. "Add"), the operation's inputs (`input_tensors`, + // `input_tensor_id`, and `input_dtypes`; the latter two are somewhat + // redundant but taken as arguments to avoid repeatedly fetching these values + // between calls to ShouldRecord and Accumulator), and its outputs + // (`output_tensors`). + // + // If provided, a non-null `forward_function` will be used instead of the + // backward function (`backward_function_getter` / + // `backward_function_deleter`) to compute jvps for this operation. If + // `forward_function` is null, a GradientTape is used on the backward function + // to compute the jvp, which will waste computation when executing eagerly. + // + // Unlike GradientTape::RecordOperation, Accumulate runs gradient computation + // immediately. It stores the results, which feed into Accumulate for future + // operations and may be fetched by calling FetchJVP. ForwardAccumulator + // maintains a reference to these JVPs: if an `output_tensors` Tensor is + // deleted, `DeleteGradient` should be called as soon as possible to free the + // (now inaccessible) corresponding JVPs, but ForwardAccumulator's destructor + // will release remaining references. + // + // This method is not thread-safe (and in general ForwardAccumulator is not + // thread-safe). 
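+  //
+  // Illustrative call sequence (a sketch; the ids, tangents and the
+  // getter/deleter callbacks are hypothetical and come from the language
+  // binding):
+  //
+  //   ForwardAccumulator<Gradient, BackwardFunction, TapeTensor> acc(
+  //       vspace, /*use_batch=*/false);
+  //   acc.Watch(x_id, x_tangent);  // seed the "vector" in the JVP
+  //   // After executing y = Square(x), report the op:
+  //   acc.Accumulate("Square", {x_info}, {y_info}, {x_id}, {DT_FLOAT},
+  //                  /*forward_function=*/nullptr, getter, deleter);
+  //   Gradient* jvp = acc.FetchJVP(y_id);  // borrowed; may be nullptr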
+ absl::Status Accumulate( + const string& op_type, const std::vector& input_tensors, + const std::vector& output_tensors, + absl::Span input_tensor_id, + absl::Span input_dtypes, + const ForwardFunction* forward_function, + const std::function& backward_function_getter, + const std::function& backward_function_deleter); + + // Returns true if `Accumulate` is active somewhere above on the stack and + // there isn't an intervening PushState. This is useful for ordering + // ForwardAccumulators, where more deeply nested accumulators should not see + // computations from less deeply nested accumulators. + bool BusyAccumulating() const { return call_state_.top().accumulating; } + + // Fetches the current Jacobian-vector product associated with `tensor_id`, or + // a nullptr if none is available. + // + // Returns a borrowed reference, i.e. does not run VSpace::MarkAsResult on its + // return value. The caller should increment the reference count before + // deleting the ForwardAccumulator or calling DeleteGradient if keeping a + // persistent reference to a non-null result. + Gradient* FetchJVP(int64_t tensor_id); + + // Indicates whether the forward accumulator should run on an operation with + // the specified inputs and dtypes. + bool ShouldRecord(absl::Span tensor_ids, + absl::Span dtypes); + + // Temporarily push or pop transient state for this accumulator. + // + // Allows an accumulator which is currently processing an operation to + // temporarily reset its state. Without pushing and popping, accumulators + // ignore operations executed as a direct result of their own jvp + // computations. + void PushState() { call_state_.emplace(nullptr, false); } + void PopState() { call_state_.pop(); } + + private: + // Helper for Accumulate: uses a GradientTape to compute forward gradients + // from a backward gradient function. Fills `out_grads` corresponding to + // `output_tensors`. `out_grads` must not be null. + // + // Executes the backward function in order to trace its gradient, which will + // waste computation if executing eagerly (when graph building the unneeded + // computation is pruned). Temporarily sets `backward_tape` so that + // Accumulate will forward op executions to the tape while the backward + // function is running; this effectively adds the backward tape to the active + // set (but does not require complicated callbacks to the language bindings). + absl::Status ForwardpropFromTape( + const string& op_type, const std::vector& output_tensors, + const std::function& backward_function_getter, + const std::function& backward_function_deleter, + const std::vector& in_grads, absl::Span out_grads); + + // Maps from tensor IDs to corresponding JVPs. + std::unordered_map accumulated_gradients_; + // Not owned; provides operations on Tensors which are currently only + // available in language bindings (e.g. Python). + const VSpace& vspace_; + + // Decides if tangents are vectorized or not + bool use_batch_; + + struct AccumulatorCallState { + AccumulatorCallState( + GradientTape* backward_tape, + bool accumulating) + : backward_tape(backward_tape), accumulating(accumulating) {} + // Set temporarily while in the Accumulate method; if backward_tape is not + // nullptr then we forward op executions to it so Accumulate can compute a + // backward pass on its backward function. + // + // Not owned by the ForwardAccumulator. The method which sets + // `backward_tape` keeps ownership. 
+ GradientTape* backward_tape; + // While the Accumulate method is running (accumulating is True), any op + // executions not forwarded to backward_tape should be ignored. + bool accumulating; + }; + // A deque-backed stack, whose element references are not invalidated by + // pushes and pops at the back. + std::stack call_state_; +}; + +// Template instantiations here + +inline bool IsDtypeTrainable(DataType dtype) { + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + case DT_COMPLEX64: + case DT_COMPLEX128: + case DT_RESOURCE: + case DT_VARIANT: + return true; + case DT_QINT8: + case DT_QINT16: + case DT_QINT32: + case DT_QUINT8: + case DT_QUINT16: + return tensorflow::flags::Global() + .enable_quantized_dtypes_training.value(); + default: + return false; + } +} + +template +bool GradientTape::ShouldRecord( + absl::Span tensor_ids, + absl::Span dtypes) const { + CHECK_EQ(tensor_ids.size(), dtypes.size()); + for (int i = 0; i < tensor_ids.size(); ++i) { + if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { + if (IsDtypeTrainable(dtypes[i])) { + return true; + } + } + } + return false; +} + +template +void GradientTape::Watch( + int64_t tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +template +void GradientTape::RecordOperation( + const string& op_type, const std::vector& output_tensors, + absl::Span input_tensor_id, + absl::Span input_dtypes, + const std::function& backward_function_getter, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id, input_dtypes)) { + return; + } + std::vector ids; + ids.reserve(input_tensor_id.size()); + for (int64_t i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64_t op_id = next_op_id_++; + std::vector tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. + tensor_tape_[o.GetID()] = op_id; + tensor_usage_[o.GetID()] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry{ + op_type, std::move(tensors), std::move(ids), backward_function_getter(), + backward_function_deleter}; +} + +template +void GradientTape::DeleteTrace( + int64_t tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64_t op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64_t id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(op_it->second.backward_function); + op_tape_.erase(op_it); +} + +// Terminology: +// +// - op: a possibly composite operation, which has an entry in the tape +// - target: dy in dx/dy +// - source: dx in dx/dy +// - tensor: one of the many inputs or outputs of an operation +// +// Below here we do the gradient algorithm. 
It works as follows: +// +// First we filter the tape to just the subset of operations we want to +// differentiate. In the process of doing so we count how many times each Tensor +// is used as an input to an op (so we know when we're done computing gradients +// for that Tensor). We also count, for each tape entry, how many of its output +// Tensors need gradients to be computed (Tensors which are not used do not need +// any gradients to be computed). +// +// Finally, we start a backprop stack with a set of tape entries for which we +// have all gradients available. This set usually is a subset of the set of +// targets (not all since targets which have outputs in the tape will not have +// gradients available initially). +// +// Then we repeatedly pop an entry from the stack, run its backprop, and update +// the gradients of its inputs. Once we have computed all gradients for a single +// input we can mark this input as done, and this can trigger adding an entry to +// the stack if all outputs of that entry are now done. +// +// When the stack is empty we have gradients for all tensors we're interested +// in. + +namespace { + +template +struct BackpropInitialState { + OpTape op_tape; + + // Map from tensor ID to how many references still exist for this tensor in + // the tape. + std::unordered_map tensor_usage_counts; + + // Maps from op ID to how many output tensors of this op still need to have + // their gradients computed. + std::unordered_map op_missing_tensor; +}; + +// If `persistent_tape` is true, op_tape is not changed and none of the +// backwards functions are deleted. +// If `persistent_tape` is false, op_tape is cleared and backwards functions +// not needed for gradient computation are deleted. Backwards functions that +// are needed, are copied and returned in BackpropInitialState. +template +BackpropInitialState PrepareBackprop( + absl::Span target, const TensorTape& tensor_tape, + OpTape* op_tape, + const std::unordered_set& sources_set, bool persistent_tape) { + std::vector tensor_stack; + tensor_stack.reserve(target.size()); + for (auto t : target) { + tensor_stack.push_back(t); + } + BackpropInitialState result; + while (!tensor_stack.empty()) { + int64_t tensor_id = tensor_stack.back(); + tensor_stack.pop_back(); + auto op_id_it = tensor_tape.find(tensor_id); + if (op_id_it == tensor_tape.end()) { + continue; + } + int64_t op_id = op_id_it->second; + auto op_it = op_tape->find(op_id); + auto result_op_it = result.op_tape.find(op_id); + if (op_id == -1 || op_it == op_tape->end() || + result_op_it != result.op_tape.end()) { + continue; + } + CHECK(result.op_tape.emplace(op_id, op_it->second).second); + for (auto it : op_it->second.input_tensor_id) { + auto count_it = result.tensor_usage_counts.find(it); + if (count_it != result.tensor_usage_counts.end()) { + count_it->second++; + } else { + result.tensor_usage_counts[it] = 1; + if (tensor_tape.find(it) != tensor_tape.end()) { + tensor_stack.push_back(it); + } + } + } + if (!persistent_tape) { + op_tape->erase(op_it); + } + } + for (auto& pair : result.tensor_usage_counts) { + auto it = tensor_tape.find(pair.first); + if (it != tensor_tape.end() && it->second != -1) { + result.op_missing_tensor[it->second] += 1; + } + } + if (!persistent_tape) { + // Call destructors for all unneeded gradient functions and + // clear the op_tape. We can clear the tape because ownership of + // backward functions that will be used for gradient computation + // has been transferred to `result`. 
+ for (const auto& op_pair : *op_tape) { + op_pair.second.backward_function_deleter( + op_pair.second.backward_function); + } + op_tape->clear(); + } + return result; +} + +template +std::vector InitialStack( + const OpTape& op_tape, + const std::unordered_map& op_missing_tensor) { + std::vector result; + for (auto& op_entry : op_tape) { + if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { + result.push_back(op_entry.first); + } + } + return result; +} + +template +absl::Status InitialGradients( + const VSpace& vspace, + absl::Span target_tensor_ids, + const std::unordered_map& sources_that_are_targets, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + std::unordered_map>* result) { + for (int i = 0, end = target_tensor_ids.size(); i < end; ++i) { + const int64_t id = target_tensor_ids[i]; + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].GetID() == id) { + found = true; + Gradient* ones_like = nullptr; + TF_RETURN_IF_ERROR(vspace.BuildOnesLike( + op_it->second.output_tensor_info[j], &ones_like)); + (*result)[id].push_back(ones_like); + break; + } + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); + } + } else { + // This target tensor was not generated by any operation recorded on + // the tape, so no gradient needs to be computed from it unless this + // target is also a source. + auto source_tensor = sources_that_are_targets.find(id); + if (source_tensor != sources_that_are_targets.end()) { + Gradient* ones_like = nullptr; + TF_RETURN_IF_ERROR( + vspace.BuildOnesLike(source_tensor->second, &ones_like)); + (*result)[id].push_back(ones_like); + } + } + } else { + (*result)[id].push_back(output_gradients[i]); + } + } + return absl::OkStatus(); +} + +// TODO(agarwal): use an automatic mechanism for handling None arguments to +// gradient functions. +// +// Some gradient functions can accept None arguments for gradients. The +// following maps the operation name to the indices at which the corresponding +// gradient function can accept None values. e.g. FusedBatchNorm outputs 5 +// values and hence receives 5 gradient values during backprop. However the +// gradient function uses only the first of those values and ignores the rest. +// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient +// corresponding to index 0 is used, and the gradient values at indices 1-4 are +// ignored (and hence can be None). The backprop algorithm can then leverage +// this by not constructing zeros to pass for those indices. 
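+// For example, the backprop loop below performs lookups of this form (sketch):
+//
+//   auto it = FunctionsAcceptingNoneForIndicesMap()->find("FusedBatchNorm");
+//   const bool may_be_none =
+//       it != FunctionsAcceptingNoneForIndicesMap()->end() &&
+//       it->second.find(/*output_index=*/2) != it->second.end();  // true
+//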
+std::unordered_map>* +FunctionsAcceptingNoneForIndicesMap() { + static auto* const m = + new std::unordered_map>({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + return m; +} + +} // namespace + +// If over kMinAggregateCount gradients are accumulated and the total +// memory consumption is over kMinAggregateBytes, do an early aggregation +// so as to release the gradient tensor to save memory. +constexpr int kMinAggregateCount = 4; +constexpr int kMinAggregateBytes = 128 * 1024 * 1024; + +template +absl::Status +GradientTape::ComputeGradient( + const VSpace& vspace, + const absl::Span target_tensor_ids, + const absl::Span source_tensor_ids, + const std::unordered_map& sources_that_are_targets, + gtl::ArraySlice output_gradients, absl::Span result, + bool build_default_zeros_grads) { + std::unordered_set sources_set(source_tensor_ids.begin(), + source_tensor_ids.end()); + BackpropInitialState state = PrepareBackprop( + target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); + std::vector op_stack = + InitialStack(state.op_tape, state.op_missing_tensor); + std::unordered_map> gradients; + absl::Status s = InitialGradients(vspace, target_tensor_ids, + sources_that_are_targets, output_gradients, + tensor_tape_, state.op_tape, &gradients); + auto cleanup = gtl::MakeCleanup([this, &state]() { + if (!persistent_) { + // Release all backprop functions + for (const auto& pair : state.op_tape) { + pair.second.backward_function_deleter(pair.second.backward_function); + } + } + }); + if (!s.ok()) { + return s; + } + + std::unordered_map gradients_size; + // TODO(apassos) multiple threads could be dequeuing from op_stack at the same + // time, for better CPU backprop performance. + VLOG(1) << "Initial stack:"; + if (VLOG_IS_ON(1)) { + for (auto t : op_stack) { + VLOG(1) << " " << t; + } + } + while (!op_stack.empty()) { + const int64_t op = op_stack.back(); + VLOG(1) << "Popped " << op; + op_stack.pop_back(); + auto op_it = state.op_tape.find(op); + if (op_it == state.op_tape.end()) { + // It is possible for ops to end up on the stack if they are unrelated to + // the target; we should just skip them. 
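[Editor's note: a standalone sketch of the lookup that ComputeGradient performs below when building zero_indices, i.e. deciding which missing output gradients must be materialized as zeros versus left as null for ops listed in FunctionsAcceptingNoneForIndicesMap above. The map contents here copy the FusedBatchNorm entry; the loop bound of 5 outputs is illustrative.]

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

int main() {
  // Same shape as the map above: op name -> output indices whose incoming
  // gradient may be left as None/nullptr.
  const std::unordered_map<std::string, std::unordered_set<int>> accepts_none =
      {{"FusedBatchNorm", {1, 2, 3, 4}}};
  const std::string op = "FusedBatchNorm";
  for (int i = 0; i < 5; ++i) {
    auto it = accepts_none.find(op);
    const bool need_zeros =
        it == accepts_none.end() || it->second.count(i) == 0;
    std::cout << "output " << i
              << (need_zeros ? ": materialize zeros\n" : ": may stay null\n");
  }
}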
+ continue; + } + auto trace = std::move(op_it->second); + state.op_tape.erase(op_it); + std::vector out_gradients; + out_gradients.reserve(trace.output_tensor_info.size()); + std::vector unneeded_gradients; + for (int i = 0, end = trace.input_tensor_id.size(); i < end; i++) { + const auto& in_tensor_id = trace.input_tensor_id[i]; + if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() && + sources_set.find(in_tensor_id) == sources_set.end()) { + unneeded_gradients.push_back(i); + } + } + + bool any_gradient_nonzero = false; + std::vector zero_indices; + for (int i = 0, end = trace.output_tensor_info.size(); i < end; ++i) { + const int64_t id = trace.output_tensor_info[i].GetID(); + auto grad_it = gradients.find(id); + if (grad_it == gradients.end()) { + out_gradients.push_back(nullptr); + if (build_default_zeros_grads) { + auto func_name_it = + FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() || + func_name_it->second.find(i) == func_name_it->second.end()) { + zero_indices.push_back(i); + } + } + } else { + any_gradient_nonzero = true; + Gradient* new_gradients = nullptr; + if (grad_it->second.size() == 1) { + new_gradients = grad_it->second.at(0); + } else { + new_gradients = vspace.AggregateGradients(grad_it->second); + } + if (sources_set.find(grad_it->first) == sources_set.end()) { + gradients.erase(grad_it); + } else { + grad_it->second.clear(); + grad_it->second.push_back(new_gradients); + vspace.MarkAsResult(new_gradients); + } + out_gradients.push_back(new_gradients); + } + } + VLOG(1) << "Calling gradient function for '" << trace.op_type << "'"; + std::vector in_gradients(trace.input_tensor_id.size()); + DCHECK(build_default_zeros_grads || zero_indices.empty()); + if (any_gradient_nonzero) { + for (const auto i : zero_indices) { + out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); + } + absl::Status s; + s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function, + unneeded_gradients, out_gradients, + absl::MakeSpan(in_gradients)); + if (!persistent_) { + trace.backward_function_deleter(trace.backward_function); + } + if (!s.ok()) { + return s; + } + } else { + if (!persistent_) { + trace.backward_function_deleter(trace.backward_function); + } + for (Gradient* grad : out_gradients) { + if (grad != nullptr) { + vspace.DeleteGradient(grad); + } + } + } + for (int i = 0, end = in_gradients.size(); i < end; ++i) { + const int64_t id = trace.input_tensor_id[i]; + if (in_gradients[i] != nullptr) { + auto& unaggregated_grads = gradients[id]; + unaggregated_grads.push_back(in_gradients[i]); + if (unaggregated_grads.size() > kMinAggregateCount) { + auto size_it = gradients_size.find(id); + int64_t size; + if (size_it == gradients_size.end()) { + size = vspace.NumElements(unaggregated_grads[0]); + gradients_size.emplace(id, size); + } else { + size = size_it->second; + } + if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) { + Gradient* grad = vspace.AggregateGradients(unaggregated_grads); + unaggregated_grads.clear(); + unaggregated_grads.push_back(grad); + } + } + } + auto usage_count_it = state.tensor_usage_counts.find(id); + if (usage_count_it == state.tensor_usage_counts.end()) { + VLOG(1) << "Tensor " << id << " not used"; + continue; + } + usage_count_it->second--; + if (usage_count_it->second > 0) { + VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second; + continue; + } + auto tape_it = tensor_tape_.find(id); + if (tape_it == tensor_tape_.end()) { + 
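[Editor's note: a standalone sketch of the arithmetic behind the early-aggregation check above; the `* 4` mirrors the code, which apparently assumes 4-byte (e.g. float32) elements, and the gradient count and size are made-up values.]

#include <cstdint>
#include <iostream>

int main() {
  constexpr int kMinAggregateCount = 4;
  constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
  const int64_t pending_grads = 5;            // unaggregated gradients for one tensor
  const int64_t elements_per_grad = 8000000;  // ~32 MB each at 4 bytes/element
  const bool aggregate_early =
      pending_grads > kMinAggregateCount &&
      pending_grads * elements_per_grad * 4 > kMinAggregateBytes;
  std::cout << std::boolalpha << aggregate_early << "\n";  // true: ~160 MB pending
}

When the check fires, the pending gradients are collapsed into one tensor so the others can be released, trading a little extra aggregation work for lower peak memory.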
VLOG(1) << "Tensor " << id + << " has no associated op. Deleting gradient"; + auto grad_it = gradients.find(id); + if (grad_it != gradients.end()) { + for (auto g : grad_it->second) { + vspace.DeleteGradient(g); + } + gradients.erase(grad_it); + } + continue; + } + const int64_t op_id = tape_it->second; + if (op_id == -1) { + VLOG(1) << "Tensor " << id << " is source"; + continue; + } + auto missing_it = state.op_missing_tensor.find(op_id); + if (missing_it != state.op_missing_tensor.end()) { + missing_it->second--; + VLOG(1) << "Op " << op_id << " missing " << missing_it->second + << " output gradients"; + if (missing_it->second == 0) { + op_stack.insert(op_stack.begin(), op_id); + } + } + } + } + if (!state.op_tape.empty()) { + return tensorflow::errors::Internal("Invalid tape state."); + } + if (result.size() != source_tensor_ids.size()) { + return errors::Internal("Expected result Span to be of size ", + source_tensor_ids.size(), " found ", result.size(), + " in call to Tape::ComputeGradient."); + } + std::unordered_set used_gradient_ids(source_tensor_ids.size()); + for (int i = 0; i < source_tensor_ids.size(); i++) { + int64_t tensor_id = source_tensor_ids[i]; + auto grad_it = gradients.find(tensor_id); + if (grad_it == gradients.end()) { + result[i] = nullptr; + } else { + if (grad_it->second.size() > 1) { + Gradient* grad = vspace.AggregateGradients(grad_it->second); + grad_it->second.clear(); + grad_it->second.push_back(grad); + } + result[i] = grad_it->second[0]; + used_gradient_ids.insert(tensor_id); + } + } + VLOG(1) << "Final gradients size: " + << gradients.size() - used_gradient_ids.size(); + for (const auto& grad_pair : gradients) { + if (used_gradient_ids.find(grad_pair.first) == used_gradient_ids.end()) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } + } + } + return absl::OkStatus(); +} + +template +bool ForwardAccumulator::ShouldRecord( + absl::Span tensor_ids, + absl::Span dtypes) { + if (call_state_.top().backward_tape != nullptr) { + // If we're forwarding Accumulate calls to backward_tape's RecordOperation, + // we should also delegate ShouldRecord. 
+ return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes); + } + if (call_state_.top().accumulating) { + return false; + } + for (int i = 0; i < tensor_ids.size(); ++i) { + if (accumulated_gradients_.find(tensor_ids[i]) != + accumulated_gradients_.end()) { + if (IsDtypeTrainable(dtypes[i])) { + return true; + } + } + } + return false; +} + +template +absl::Status +ForwardAccumulator::ForwardpropFromTape( + const string& op_type, const std::vector& output_tensors, + const std::function& backward_function_getter, + const std::function& backward_function_deleter, + const std::vector& in_grads, absl::Span out_grads) { + /* This function is approximately equivalent to this Python code: + + forwardprop_aids = tf.ones_like(output_tensors) + with tf.GradientTape() as g: + g.watch(forwardprop_aids) + grad = backward_function(forwardprop_aids) + forward_grads = g.gradient(grad, forwardprop_aids, output_gradients=in_grads) + accumulated_gradients_[ID(output_tensors)] = forward_grads + */ + std::unique_ptr> tape( + new GradientTape(false)); + AccumulatorCallState& call_state = call_state_.top(); + call_state.backward_tape = tape.get(); + auto pop_backward_tape = + gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; }); + std::vector forwardprop_aids; + std::vector sources; + std::unordered_set sources_set; + sources.reserve(output_tensors.size()); + for (const TapeTensor& output_tensor : output_tensors) { + // Ownership of `aid` transferred to CallBackwardFunction below. + Gradient* aid; + if (output_tensor.GetDType() == tensorflow::DT_VARIANT) { + // Note: Needs to be zeros rather than ones since there's currently no + // ones_like for variants. + aid = output_tensor.ZerosLike(); + } else { + // TODO(allenl): Figure out why using zeros_like everywhere causes issues + // for some gradient functions and if there's another way to work around + // it (e.g. conds instead of ifs). The value shouldn't really matter. + TF_RETURN_IF_ERROR(vspace_.BuildOnesLike(output_tensor, &aid)); + } + if (TF_PREDICT_FALSE(aid == nullptr)) { + return tensorflow::errors::Internal( + "Failed to create ones tensor for tensor ", output_tensor.GetID(), + " with dtype ", output_tensor.GetDType()); + } + forwardprop_aids.push_back(aid); + int64_t aid_id = vspace_.TensorId(aid); + sources.push_back(aid_id); + sources_set.insert(aid_id); + tape->Watch(aid_id); + } + std::vector grad(in_grads.size()); + auto delete_grad = gtl::MakeCleanup([&grad, this] { + for (Gradient* tensor : grad) { + this->vspace_.DeleteGradient(tensor); + } + }); + { + std::vector unneeded_gradients; + std::unique_ptr> + backward_function(backward_function_getter(), + backward_function_deleter); + TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction( + op_type, backward_function.get(), unneeded_gradients, forwardprop_aids, + absl::MakeSpan(grad))); + } + + // Stop the tape from recording + pop_backward_tape.release()(); + + std::vector targets; + std::vector used_in_grads; + // We may end up with slightly fewer elements than we reserve, but grad.size() + // should be a reasonably tight upper bound. 
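[Editor's note: a standalone numeric illustration of the double-backprop trick in the Python-equivalent comment above, for the single function y = x*x at x = 3 with input tangent v = 1. The lambda stands in for the op's backward (VJP) function; nothing here is TensorFlow API.]

#include <iostream>

int main() {
  const double x = 3.0;  // primal input
  const double v = 1.0;  // input tangent (plays the role of in_grads)
  // VJP of y = x * x: maps an output cotangent u to 2 * x * u.
  auto backward = [&](double u) { return 2.0 * x * u; };
  // The backward function is linear in the forwardprop aid u, so its
  // derivative with respect to u can be read off by evaluating it at 1.
  const double d_backward_d_aid = backward(1.0);
  const double jvp = d_backward_d_aid * v;  // gradient of backward w.r.t. the aid, scaled by in_grad
  std::cout << jvp << "\n";                 // prints 6 = dy/dx * v at x = 3
}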
+ targets.reserve(grad.size()); + used_in_grads.reserve(grad.size()); + std::unordered_map sources_that_are_targets; + for (int grad_index = 0, end = grad.size(); grad_index < end; ++grad_index) { + Gradient* grad_tensor = grad[grad_index]; + if (grad_tensor != nullptr) { + int64_t tensor_id = vspace_.TensorId(grad_tensor); + targets.push_back(tensor_id); + if (sources_set.find(tensor_id) != sources_set.end()) { + sources_that_are_targets.emplace( + tensor_id, vspace_.TapeTensorFromGradient(grad_tensor)); + } + Gradient* in_grad = in_grads[grad_index]; + if (in_grad != nullptr) { + // ComputeGradient steals a reference + vspace_.MarkAsResult(in_grad); + } + used_in_grads.push_back(in_grad); + } + } + + return tape->ComputeGradient(vspace_, targets, sources, + sources_that_are_targets, used_in_grads, + out_grads); +} + +template +absl::Status +ForwardAccumulator::Accumulate( + const string& op_type, const std::vector& input_tensors, + const std::vector& output_tensors, + absl::Span input_tensor_id, + absl::Span input_dtypes, + const ForwardFunction* forward_function, + const std::function& backward_function_getter, + const std::function& backward_function_deleter) { + if (call_state_.top().backward_tape != nullptr) { + // If backward_tape is not null, then this call to Accumulate is the result + // of a still-active call to Accumulate which is running operations. We + // forward these operations to backward_tape so the outer Accumulate call + // can do its work. + // + // Rather than re-entering and delegating Accumulate like this, we could + // instead allow ForwardAccumulator some control over the current tape set + // (so it can deactivate itself and activate its GradientTape). Currently + // that is managed by the language binding and would require relatively + // messy callbacks. + call_state_.top().backward_tape->RecordOperation( + op_type, output_tensors, input_tensor_id, input_dtypes, + backward_function_getter, backward_function_deleter); + return absl::OkStatus(); + } + if (!ShouldRecord(input_tensor_id, input_dtypes)) { + return absl::OkStatus(); + } + + // We may need to allocate zero inputs for trainable dtypes we don't have JVPs + // for. Make sure they get cleaned up. + std::vector new_zeros; + auto delete_new_zeros = gtl::MakeCleanup([&new_zeros, this] { + for (Gradient* tensor : new_zeros) { + this->vspace_.DeleteGradient(tensor); + } + }); + std::vector in_grads; + in_grads.reserve(input_tensors.size()); + for (int target_index = 0; target_index < input_tensors.size(); + ++target_index) { + const auto current_grad = + accumulated_gradients_.find(input_tensors[target_index].GetID()); + if (current_grad == accumulated_gradients_.end()) { + if (IsDtypeTrainable(input_tensors[target_index].GetDType())) { + // ForwardAccumulator defaults to zeros for unwatched Tensors, unlike + // GradientTape which uses ones. + Gradient* zero = input_tensors[target_index].ZerosLike(); + new_zeros.push_back(zero); + in_grads.push_back(zero); + } else { + in_grads.push_back(nullptr); + } + } else { + in_grads.push_back(current_grad->second); + } + } + + // Avoid infinite recursion. Whichever forward function we run, it'll end up + // executing ops, and we don't want to watch those with this accumulator. + call_state_.emplace(nullptr, true); + auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); }); + + std::vector forward_grads; + if (forward_function == nullptr) { + // We have no special-cased forward gradient. 
Fall back to running the + // backward function under a gradient tape. + forward_grads.resize(output_tensors.size()); + TF_RETURN_IF_ERROR(ForwardpropFromTape( + op_type, output_tensors, backward_function_getter, + backward_function_deleter, in_grads, absl::MakeSpan(forward_grads))); + } else { + TF_RETURN_IF_ERROR( + (*forward_function)(in_grads, &forward_grads, use_batch_)); + } + for (int i = 0; i < forward_grads.size(); ++i) { + if (forward_grads[i] != nullptr) { + int64_t tensor_id = output_tensors[i].GetID(); + auto existing = accumulated_gradients_.find(tensor_id); + if (existing != accumulated_gradients_.end()) { + // This is a somewhat odd case to be in, since it means we have two + // operations which supposedly both created the same Tensor. It comes up + // in recompute_grad, where the gradients have the same value. However, + // only the original gradient is connected to everything else, so we + // should still use that. + vspace_.DeleteGradient(forward_grads[i]); + } else { + accumulated_gradients_[output_tensors[i].GetID()] = forward_grads[i]; + } + } + } + return absl::OkStatus(); +} + +template +void ForwardAccumulator::Watch( + int64_t tensor_id, Gradient* tangent) { + typename std::unordered_map::iterator existing = + accumulated_gradients_.find(tensor_id); + vspace_.MarkAsResult(tangent); + if (existing == accumulated_gradients_.end()) { + accumulated_gradients_.emplace(tensor_id, tangent); + } else { + std::array to_aggregate; + to_aggregate[0] = tangent; + to_aggregate[1] = existing->second; + // AggregateGradients steals a reference to each of its arguments. We + // MarkAsResult on `tangent` above so we don't steal a reference to it. + existing->second = vspace_.AggregateGradients(to_aggregate); + } +} + +template +void ForwardAccumulator::DeleteGradient( + int64_t tensor_id) { + auto existing = accumulated_gradients_.find(tensor_id); + if (existing != accumulated_gradients_.end()) { + vspace_.DeleteGradient(existing->second); + accumulated_gradients_.erase(existing); + } +} + +template +Gradient* ForwardAccumulator::FetchJVP( + int64_t tensor_id) { + auto lookup = accumulated_gradients_.find(tensor_id); + if (lookup == accumulated_gradients_.end()) { + return nullptr; + } else { + return lookup->second; + } +} + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TAPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_cancellation_manager_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_cancellation_manager_internal.h new file mode 100644 index 00000000..6fdecd78 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_cancellation_manager_internal.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/core/framework/cancellation.h" + +struct TFE_CancellationManager; +typedef struct TFE_CancellationManager TFE_CancellationManager; + +namespace tensorflow { +DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager, + TFE_CancellationManager); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::CancellationManager*, + TFE_CancellationManager*); +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_CANCELLATION_MANAGER_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_context_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_context_internal.h new file mode 100644 index 00000000..1f203531 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_context_internal.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/immediate_execution_context.h" + +// Wraps a pointer to a context implementation. +// +// WARNING: Since the underlying object could be ref-counted a user of this +// interface cannot destruct the underlying context object. Instead, call +// TFE_DeleteContext who calls Release() on the context pointer and deletes +// the TFE_Context structure. +typedef struct TFE_Context TFE_Context; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_executor_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_executor_internal.h new file mode 100644 index 00000000..7f55532a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_executor_internal.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
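[Editor's note: a hedged sketch of how the wrap/unwrap helpers generated by DEFINE_CONVERSION_FUNCTIONS for TFE_Context above are typically used inside the C API layer. The example namespace and function names are made up; only the generated tensorflow::wrap/unwrap calls come from the header.]

#include "tensorflow/c/eager/tfe_context_internal.h"

namespace example {
// unwrap/wrap are generated by DEFINE_CONVERSION_FUNCTIONS above.
tensorflow::ImmediateExecutionContext* Unwrapped(TFE_Context* ctx) {
  return tensorflow::unwrap(ctx);
}
TFE_Context* Wrapped(tensorflow::ImmediateExecutionContext* ctx) {
  return tensorflow::wrap(ctx);
}
}  // namespace example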
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_ + +#include + +#include "tensorflow/core/common_runtime/eager/eager_executor.h" + +struct TFE_Executor { + explicit TFE_Executor(bool async, bool enable_streaming_enqueue, + int in_flight_nodes_limit) + : owned_executor(new tensorflow::EagerExecutor( + async, enable_streaming_enqueue, in_flight_nodes_limit)) {} + + explicit TFE_Executor(tensorflow::EagerExecutor* executor) + : owned_executor(nullptr), unowned_executor(executor) {} + + tensorflow::EagerExecutor* executor() { + return owned_executor == nullptr ? unowned_executor : owned_executor.get(); + } + + std::unique_ptr owned_executor; + tensorflow::EagerExecutor* unowned_executor; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_EXECUTOR_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_monitoring_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_monitoring_internal.h new file mode 100644 index 00000000..e33eaa23 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_monitoring_internal.h @@ -0,0 +1,152 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_ + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/platform/types.h" + +struct TFE_MonitoringCounterCell { + tensorflow::monitoring::CounterCell cell; +}; + +template +struct TFE_MonitoringCounter { + template + TFE_MonitoringCounter(const char* name, const char* description, + LabelDesc&&... label) { + counter = absl::WrapUnique(tensorflow::monitoring::Counter::New( + name, description, label...)); + } + + std::unique_ptr> counter; +}; + +struct TFE_MonitoringCounter0 : TFE_MonitoringCounter<0> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; +struct TFE_MonitoringCounter1 : TFE_MonitoringCounter<1> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; +struct TFE_MonitoringCounter2 : TFE_MonitoringCounter<2> { + using TFE_MonitoringCounter::TFE_MonitoringCounter; +}; + +struct TFE_MonitoringIntGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; +struct TFE_MonitoringStringGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; +struct TFE_MonitoringBoolGaugeCell { + tensorflow::monitoring::GaugeCell cell; +}; + +template +struct TFE_MonitoringGauge { + template + TFE_MonitoringGauge(const char* name, const char* description, + LabelDesc&&... 
label) { + gauge = absl::WrapUnique( + tensorflow::monitoring::Gauge::New( + name, description, label...)); + } + + std::unique_ptr> gauge; +}; + +struct TFE_MonitoringIntGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringIntGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringIntGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringStringGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge3 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringStringGauge4 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringBoolGauge0 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringBoolGauge1 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; +struct TFE_MonitoringBoolGauge2 : TFE_MonitoringGauge { + using TFE_MonitoringGauge::TFE_MonitoringGauge; +}; + +struct TFE_MonitoringBuckets { + explicit TFE_MonitoringBuckets( + std::function(void)> + fn) { + create_buckets = fn; + } + + std::function(void)> + create_buckets; +}; + +struct TFE_MonitoringSamplerCell { + tensorflow::monitoring::SamplerCell cell; +}; + +template +struct TFE_MonitoringSampler { + template + TFE_MonitoringSampler( + const char* name, + std::unique_ptr buckets, + const char* description, LabelDesc&&... label) { + sampler = absl::WrapUnique(tensorflow::monitoring::Sampler::New( + {name, description, label...}, std::move(buckets))); + } + + std::unique_ptr> sampler; +}; + +struct TFE_MonitoringSampler0 : TFE_MonitoringSampler<0> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; +struct TFE_MonitoringSampler1 : TFE_MonitoringSampler<1> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; +struct TFE_MonitoringSampler2 : TFE_MonitoringSampler<2> { + using TFE_MonitoringSampler::TFE_MonitoringSampler; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_monitoring_reader_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_monitoring_reader_internal.h new file mode 100644 index 00000000..3c63e672 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_monitoring_reader_internal.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
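[Editor's note: a hedged sketch of the underlying tensorflow::monitoring counter API that the TFE_MonitoringCounter wrappers earlier in tfe_monitoring_internal.h delegate to. The metric path "/example/requests" and the label value are made up.]

#include "tensorflow/core/lib/monitoring/counter.h"

void CountRequest() {
  // One label descriptor ("status"); the metric path is invented for this sketch.
  static auto* const requests = tensorflow::monitoring::Counter<1>::New(
      "/example/requests", "Requests seen by the example.", "status");
  requests->GetCell("ok")->IncrementBy(1);
}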
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EAGER_TFE_MONITORING_READER_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_MONITORING_READER_INTERNAL_H_ + +#include + +#include "tensorflow/core/lib/monitoring/cell_reader.h" + +struct TFE_MonitoringCounterReader { + explicit TFE_MonitoringCounterReader(const char* name) { + counter = std::make_unique< + ::tensorflow::monitoring::testing::CellReader>(name); + } + template + int64_t Read(const LabelType&... labels); + std::unique_ptr<::tensorflow::monitoring::testing::CellReader> + counter; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_MONITORING_READER_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_op_attrs_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_op_attrs_internal.h new file mode 100644 index 00000000..24e3692a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_op_attrs_internal.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_op_attrs.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +// An equivalent of a tensorflow::NameAttrList protocol buffer, but used in ways +// that sometimes do not require serialization. +typedef struct TFE_OpAttrs TFE_OpAttrs; + +typedef struct TFE_Context TFE_Context; +typedef struct TFE_Op TFE_Op; + +namespace tensorflow { +DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOpAttrs, TFE_OpAttrs); + +// Set an AttrValue on the op. Doesn't handle the list types. +void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, + const tensorflow::AttrValue& default_value, + const char* attr_name, TF_Status* status); +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_OP_ATTRS_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_op_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_op_internal.h new file mode 100644 index 00000000..3fe94d35 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_op_internal.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
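[Editor's note: a hedged sketch of the testing-side CellReader that TFE_MonitoringCounterReader above wraps, reading back the made-up metric from the counter sketch earlier; the function name is illustrative.]

#include <cstdint>

#include "tensorflow/core/lib/monitoring/cell_reader.h"

int64_t ReadOkRequests() {
  static auto* const reader =
      new tensorflow::monitoring::testing::CellReader<int64_t>(
          "/example/requests");
  return reader->Read("ok");  // value of the cell labelled "ok"
}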
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" + +// Wraps a pointer to an operation implementation. +// +// WARNING: Since the underlying object could be ref-counted a user of this +// interface cannot destruct the underlying operation object. Instead, call +// TFE_DeleteOp who calls Release() on the operation pointer and deletes +// the TFE_Op structure. +typedef struct TFE_Op TFE_Op; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_tensor_debug_info_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_tensor_debug_info_internal.h new file mode 100644 index 00000000..0c570660 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_tensor_debug_info_internal.h @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +struct TFE_TensorDebugInfo { + explicit TFE_TensorDebugInfo(const std::vector& dims) + : dev_dims(dims) {} + + // Fully-padded, minor-to-major. + std::vector dev_dims; +}; + +#endif // TENSORFLOW_C_EAGER_TFE_TENSOR_DEBUG_INFO_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tfe_tensorhandle_internal.h b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_tensorhandle_internal.h new file mode 100644 index 00000000..308e8c24 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tfe_tensorhandle_internal.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" + +// Wraps a pointer to a tensor handle implementation. +// +// WARNING: Since the underlying object could be ref-counted a user of this +// interface cannot destruct the underlying handle object. Instead, call +// TFE_DeleteTensorHandle who calls Release() on the handle pointer and deletes +// the TFE_TensorHandle structure. +typedef struct TFE_TensorHandle TFE_TensorHandle; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle, + TFE_TensorHandle); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*, + TFE_TensorHandle*); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/tracing_utils.h b/third_party/tflite-hdrs/tensorflow/c/eager/tracing_utils.h new file mode 100644 index 00000000..1c336322 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/tracing_utils.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TRACING_UTILS_H_ +#define TENSORFLOW_C_EAGER_TRACING_UTILS_H_ + +#include "tensorflow/c/eager/abstract_operation.h" + +namespace tensorflow { +namespace tracing { +absl::Status MaybeSetOpName(AbstractOperation*, const char* op_name); +} // namespace tracing +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TRACING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/eager/unified_api_testutil.h b/third_party/tflite-hdrs/tensorflow/c/eager/unified_api_testutil.h new file mode 100644 index 00000000..2df18c13 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/eager/unified_api_testutil.h @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ +#define TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Builds and returns a `TracingContext` using the default tracing impl. +AbstractContext* BuildFunction(const char* fn_name); + +// Creates parameters (placeholders) in the tracing `ctx` using the shape and +// dtype of `inputs`. +absl::Status CreateParamsForInputs( + AbstractContext* ctx, absl::Span inputs, + std::vector* params); + +// A callable that takes tensor inputs and returns zero or more tensor outputs. +using Model = std::function, + absl::Span)>; + +// Runs `model` maybe wrapped in a function call op. This can be thought as +// being equivalent to the following python code. +// +// if use_function: +// outputs = tf.function(model)(inputs) +// else: +// outputs = model(inputs) +absl::Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + bool use_function); + +absl::Status BuildImmediateExecutionContext(bool use_tfrt, + AbstractContext** ctx); + +// Return a tensor handle with given type, values and dimensions. +template +absl::Status TestTensorHandleWithDims(AbstractContext* ctx, const T* data, + const int64_t* dims, int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestTensorHandleWithDims(eager_ctx, data, dims, num_dims); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return absl::OkStatus(); +} + +// Return a scalar tensor handle with given value. +template +absl::Status TestScalarTensorHandle(AbstractContext* ctx, const T value, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestScalarTensorHandle(eager_ctx, value); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return absl::OkStatus(); +} + +// Places data from `t` into *result_tensor. +absl::Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/env.h b/third_party/tflite-hdrs/tensorflow/c/env.h new file mode 100644 index 00000000..ac6a9e32 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/env.h @@ -0,0 +1,212 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
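[Editor's note: a hedged sketch of how the unified-API test helpers above fit together. The <float, TF_FLOAT> template arguments are an assumption (the template parameter list was collapsed out of the declaration above), and the function name is made up; error propagation uses TF_RETURN_IF_ERROR as the header itself does.]

#include "tensorflow/c/eager/unified_api_testutil.h"

absl::Status MakeScalarAndFetch(TF_Tensor** out) {
  tensorflow::AbstractContext* ctx = nullptr;
  TF_RETURN_IF_ERROR(
      tensorflow::BuildImmediateExecutionContext(/*use_tfrt=*/false, &ctx));
  tensorflow::AbstractTensorHandle* x = nullptr;
  // Assumed template arguments: element type float, dtype TF_FLOAT.
  TF_RETURN_IF_ERROR(
      (tensorflow::TestScalarTensorHandle<float, TF_FLOAT>(ctx, 2.0f, &x)));
  return tensorflow::GetValue(x, out);  // copies the scalar into *out
}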
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_ENV_H_ +#define TENSORFLOW_C_ENV_H_ + +#include +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_file_statistics.h" +#include "tensorflow/c/tf_status.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Env. + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_WritableFileHandle TF_WritableFileHandle; +typedef struct TF_StringStream TF_StringStream; +typedef struct TF_Thread TF_Thread; + +typedef struct TF_ThreadOptions { + // Thread stack size to use (in bytes), zero implies that the system default + // will be used. + size_t stack_size; + + // Guard area size to use near thread stacks to use (in bytes), zero implies + // that the system default will be used. + size_t guard_size; + + // The NUMA node to use, -1 implies that there should be no NUMA affinity for + // this thread. + int numa_node; +} TF_ThreadOptions; + +// Creates the specified directory. Typical status code are: +// * TF_OK - successfully created the directory +// * TF_ALREADY_EXISTS - directory already exists +// * TF_PERMISSION_DENIED - dirname is not writable +TF_CAPI_EXPORT extern void TF_CreateDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory. Typical status codes are: +// * TF_OK - successfully deleted the directory +// * TF_FAILED_PRECONDITION - the directory is not empty +TF_CAPI_EXPORT extern void TF_DeleteDir(const char* dirname, TF_Status* status); + +// Deletes the specified directory and all subdirectories and files underneath +// it. This is accomplished by traversing the directory tree rooted at dirname +// and deleting entries as they are encountered. +// +// If dirname itself is not readable or does not exist, *undeleted_dir_count is +// set to 1, *undeleted_file_count is set to 0 and an appropriate status (e.g. +// TF_NOT_FOUND) is returned. +// +// If dirname and all its descendants were successfully deleted, TF_OK is +// returned and both error counters are set to zero. +// +// Otherwise, while traversing the tree, undeleted_file_count and +// undeleted_dir_count are updated if an entry of the corresponding type could +// not be deleted. The returned error status represents the reason that any one +// of these entries could not be deleted. +// +// Typical status codes: +// * TF_OK - dirname exists and we were able to delete everything underneath +// * TF_NOT_FOUND - dirname doesn't exist +// * TF_PERMISSION_DENIED - dirname or some descendant is not writable +// * TF_UNIMPLEMENTED - some underlying functions (like Delete) are not +// implemented +TF_CAPI_EXPORT extern void TF_DeleteRecursively(const char* dirname, + int64_t* undeleted_file_count, + int64_t* undeleted_dir_count, + TF_Status* status); + +// Obtains statistics for the given path. If status is TF_OK, *stats is +// updated, otherwise it is not touched. 
+TF_CAPI_EXPORT extern void TF_FileStat(const char* filename, + TF_FileStatistics* stats, + TF_Status* status); + +// Creates or truncates the given filename and returns a handle to be used for +// appending data to the file. If status is TF_OK, *handle is updated and the +// caller is responsible for freeing it (see TF_CloseWritableFile). +TF_CAPI_EXPORT extern void TF_NewWritableFile(const char* filename, + TF_WritableFileHandle** handle, + TF_Status* status); + +// Closes the given handle and frees its memory. If there was a problem closing +// the file, it is indicated by status. Memory is freed in any case. +TF_CAPI_EXPORT extern void TF_CloseWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Syncs content of the handle to the filesystem. Blocks waiting for the +// filesystem to indicate that the content has been persisted. +TF_CAPI_EXPORT extern void TF_SyncWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Flush local buffers to the filesystem. If the process terminates after a +// successful flush, the contents may still be persisted, since the underlying +// filesystem may eventually flush the contents. If the OS or machine crashes +// after a successful flush, the contents may or may not be persisted, depending +// on the implementation. +TF_CAPI_EXPORT extern void TF_FlushWritableFile(TF_WritableFileHandle* handle, + TF_Status* status); + +// Appends the given bytes to the file. Any failure to do so is indicated in +// status. +TF_CAPI_EXPORT extern void TF_AppendWritableFile(TF_WritableFileHandle* handle, + const char* data, + size_t length, + TF_Status* status); + +// Deletes the named file and indicates whether successful in *status. +TF_CAPI_EXPORT extern void TF_DeleteFile(const char* filename, + TF_Status* status); + +// Retrieves the next item from the given TF_StringStream and places a pointer +// to it in *result. If no more items are in the list, *result is set to NULL +// and false is returned. +// +// Ownership of the items retrieved with this function remains with the library. +// Item points are invalidated after a call to TF_StringStreamDone. +TF_CAPI_EXPORT extern bool TF_StringStreamNext(TF_StringStream* list, + const char** result); + +// Frees the resources associated with given string list. All pointers returned +// by TF_StringStreamNext are invalid after this call. +TF_CAPI_EXPORT extern void TF_StringStreamDone(TF_StringStream* list); + +// Retrieves the list of children of the given directory. You can iterate +// through the list with TF_StringStreamNext. The caller is responsible for +// freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, + TF_Status* status); + +// Retrieves a list of directory names on the local machine that may be used for +// temporary storage. You can iterate through the list with TF_StringStreamNext. +// The caller is responsible for freeing the list (see TF_StringStreamDone). +TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); + +// Creates a temporary file name with an extension. +// The caller is responsible for freeing the returned pointer. +TF_CAPI_EXPORT extern char* TF_GetTempFileName(const char* extension); + +// Returns the number of nanoseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); + +// Returns the number of microseconds since the Unix epoch. +TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void); + +// Returns the number of seconds since the Unix epoch. 
+TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void); + +// Populates a TF_ThreadOptions struct with system-default values. +TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options); + +// Returns a new thread that is running work_func and is identified +// (for debugging/performance-analysis) by thread_name. +// +// The given param (which may be null) is passed to work_func when the thread +// starts. In this way, data may be passed from the thread back to the caller. +// +// Caller takes ownership of the result and must call TF_JoinThread on it +// eventually. +TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, + const char* thread_name, + void (*work_func)(void*), + void* param); + +// Waits for the given thread to finish execution, then deletes it. +TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); + +// \brief Load a dynamic library. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, place OK in status and return the newly created library handle. +// Otherwise returns nullptr and set error status. +TF_CAPI_EXPORT extern void* TF_LoadSharedLibrary(const char* library_filename, + TF_Status* status); + +// \brief Get a pointer to a symbol from a dynamic library. +// +// "handle" should be a pointer returned from a previous call to +// TF_LoadLibraryFromEnv. On success, place OK in status and return a pointer to +// the located symbol. Otherwise returns nullptr and set error status. +TF_CAPI_EXPORT extern void* TF_GetSymbolFromLibrary(void* handle, + const char* symbol_name, + TF_Status* status); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_ENV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/filesystem_interface.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/filesystem_interface.h new file mode 100644 index 00000000..13fd7632 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/filesystem_interface.h @@ -0,0 +1,1125 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_FILESYSTEM_INTERFACE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_FILESYSTEM_INTERFACE_H_ + +#include +#include + +#include "tensorflow/c/tf_file_statistics.h" +#include "tensorflow/c/tf_status.h" + +/// This is the interop header between core TensorFlow and modular filesystem +/// plugins (see initial RFC https://github.com/tensorflow/community/pull/101). +/// +/// Both core TensorFlow and every plugin will use this header. The associated +/// `.cc` file is only used by core TensorFlow to implement checking needed for +/// plugin registration and ensuring API and ABI compatibility. 
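[Editor's note: before the filesystem plugin interface continues, a hedged usage sketch for the thread helpers declared in env.h above. The thread name, message, and wrapper function are made up; the TF_* calls and signatures are the ones declared in the header.]

#include <cstdio>

#include "tensorflow/c/env.h"

static void Work(void* param) {
  std::printf("%s\n", static_cast<const char*>(param));
}

void RunOneThread() {
  TF_ThreadOptions options;
  TF_DefaultThreadOptions(&options);  // system defaults for stack/guard/NUMA
  static const char kMessage[] = "hello from worker";
  TF_Thread* thread = TF_StartThread(&options, "example-thread", Work,
                                     const_cast<char*>(kMessage));
  TF_JoinThread(thread);  // blocks until Work returns, then frees the thread
}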
Plugin authors +/// don't need to read the `.cc` file but they should consult every section of +/// this file to ensure a compliant plugin can be built and that the plugin can +/// be used without recompilation in the widest range of TensorFlow versions. +/// +/// The header is divided into sections, as follows: +/// 1. Opaque plugin private data structures and wrappers for type safety; +/// 2. Function tables for plugin functionality; +/// 3. Versioning metadata; +/// 4. Plugin registration API and the DSO entry point. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// SECTION 1. Opaque data structures to hold plugin specific data +/// ---------------------------------------------------------------------------- +/// +/// The following data structures incorporate a `void*` that is opaque to +/// TensorFlow but can be used by each filesystem plugin to represent internal +/// data. +/// +/// We prefer to have these structures instead of passing `void*` into +/// method signatures to have some type of type safety: for example, operations +/// that are only valid on random access files have a `TF_RandomAccessFile` +/// argument. +/// +/// Lifetime: The wrapper data structures are owned by core TensorFlow. The data +/// pointed to by the `void*` members is always owned by the plugin. The plugin +/// will provide functions to call to allocate and deallocate this data (see +/// next sections) and core TensorFlow ensures to call these at the proper time. +/// +/// Plugins will never receive a `TF_*` pointer that is `nullptr`. Core +/// TensorFlow will never touch the `void*` wrapped by these structures, except +/// to initialize it as `nullptr`. + +typedef struct TF_RandomAccessFile { + void* plugin_file; +} TF_RandomAccessFile; + +typedef struct TF_WritableFile { + void* plugin_file; +} TF_WritableFile; + +typedef struct TF_ReadOnlyMemoryRegion { + void* plugin_memory_region; +} TF_ReadOnlyMemoryRegion; + +typedef struct TF_Filesystem { + void* plugin_filesystem; +} TF_Filesystem; + +typedef struct TF_TransactionToken { + void* token; + TF_Filesystem* owner; +} TF_TransactionToken; + +// The named union is needed here (as opposed to +// inside the `TF_Filesystem_Option_Value` struct) +// as MSVC does not recognize `typeof`. +typedef union TF_Filesystem_Option_Value_Union { + int64_t int_val; + double real_val; + struct { + char* buf; + int buf_length; + } buffer_val; +} TF_Filesystem_Option_Value_Union; + +typedef struct TF_Filesystem_Option_Value { + int type_tag; // type of values in the values union + int num_values; // number of values + TF_Filesystem_Option_Value_Union* + values; // owned (plugins must make a copy if storing this) +} TF_Filesystem_Option_Value; + +typedef enum TF_Filesystem_Option_Type { + TF_Filesystem_Option_Type_Int = 0, + TF_Filesystem_Option_Type_Real, + TF_Filesystem_Option_Type_Buffer, + TF_Filesystem_Num_Option_Types, // must always be the last item +} TF_Filesystem_Option_Type; + +typedef struct TF_Filesystem_Option { + char* name; // null terminated, owned + char* description; // null terminated, owned + int per_file; // bool actually, but bool is not a C type + TF_Filesystem_Option_Value* value; // owned +} TF_Filesystem_Option; + +/// SECTION 2. 
Function tables for functionality provided by plugins +/// ---------------------------------------------------------------------------- +/// +/// The following data structures represent the function tables for operations +/// that plugins provide (some are mandatory, some are optional, with or without +/// a default implementation). +/// +/// Each plugin implements the operations that are supported and TensorFlow will +/// properly handle the cases when an operation is not supported (i.e., return +/// the corresponding `Status` value). +/// +/// REQUIRED OPERATIONS: All required operations are marked as such, including +/// operations which are conditionally required. If the presence of an operation +/// `foo` requires operation `bar` to be present, this is specified in `foo`. If +/// the entire set of operations in a table is not provided, use `nullptr` for +/// the struct pointer (e.g., when a file type is not supported). +/// +/// DEFAULT IMPLEMENTATIONS: Some operations have default implementations that +/// TensorFlow uses in case the plugin doesn't supply its own version. An +/// operation `foo` might have a default implementation which uses `bar` and +/// `foobar`. If the plugin supplies `bar` and `foobar`, TensorFlow can use the +/// default implementation of `foo`. +/// +/// During plugin loading, plugins will call the registration function provided +/// by this interface, supplying values for each of these structures. Core +/// TensorFlow checks that the plugin supplies all mandatory operations and +/// then copies these tables to a different memory location, marking the new +/// operation tables as read-only. Once a plugin is loaded, none of these +/// operation pointers may change. +/// +/// There are 4 function tables: one for each of the 3 file objects in +/// TensorFlow (i.e., `RandomAccessFile`, `WritableFile`, +/// `ReadOnlyMemoryRegion`) and one for all the operations a `Filesystem` +/// implements. Each of them is in a 1-to-1 correspondence with the wrapper +/// structures from the first section: these tables only contain function +/// pointers that operate on the corresponding data. Thus, the first argument of +/// each of these functions is a pointer to the paired struct and this argument +/// can be used to track state in between calls (from an object oriented point +/// of view, this can be viewed as a "vtable" for a "class" -- that is the +/// corresponding struct above --; the first argument is in place of `this`). +/// +/// Except where noted otherwise, all pointer arguments are owned by core +/// TensorFlow and are guaranteed to not be `nullptr`. +/// +/// All path-like arguments are null terminated `char*` strings. Plugins can +/// assume that before any function using path arguments is invoked, the path is +/// made canonical by calling the function provided by `translate_name` or a +/// default implementation of that (supplied by core TensorFlow). +/// +/// The only time the pointer to the `TF_*` structures from section 1 is not +/// marked `const` in these functions is when these function are either +/// allocating or deallocating the plugin specific data. That is, in the 4 +/// `cleanup` functions (one for each data structure), the `init` function for +/// `TF_Filesystem` and the `new_*` methods of `TF_FilesystemOps` to initialize +/// the 3 types of files. In all other cases, there is no need to modify the +/// address of the opaque data pointer, hence the wrapper pointer is marked +/// `const`. 
+///
+/// For consistency, the arguments of all these functions follow the same
+/// pattern: first we have the opaque pointer argument ("this" above), then the
+/// input arguments, then the in-out arguments (if any) and we finish the
+/// argument list with the out arguments. We only use the return type for an out
+/// parameter if that is a plain C type, as this ensures ABI compatibility
+/// (returning structures has issues in case compiler options affect
+/// optimizations such as RVO). If a status needs to be returned from these
+/// methods, the last argument is always a `TF_Status *` (or an array of such
+/// pointers) owned by core TensorFlow and guaranteed to not be `nullptr`.
+///
+/// To ensure ABI and API compatibility, we have out-of-band data that is used
+/// by both core TensorFlow and the plugin at load time. We don't include this
+/// data in the structures here to prevent cases when padding/packing enabled by
+/// different compiler options breaks compatibility. For more details about how
+/// this is used, please consult the next sections. Here we just wrap these
+/// tables in lint warnings so that changes here cause changes to the versioning
+/// data as well. Here is a short summary of what changes are allowed:
+///   * adding a new method at the end of a table is allowed at any time;
+///   * any other change to these tables is only allowed on a major TensorFlow
+///     version change (e.g., from 2.x to 3.0). This is provided as an escape
+///     hatch to allow cleaning up these tables. Since any of these changes
+///     break ABI compatibility and cause all plugins to be recompiled, these
+///     types of changes should be extremely rare.
+///
+/// The next section details this as well as some corner cases that are out of
+/// scope for now.
+
+// LINT.IfChange
+typedef struct TF_RandomAccessFileOps {
+  /// Releases resources associated with `*file`.
+  ///
+  /// Requires that `*file` is not used in any concurrent or subsequent
+  /// operations.
+  ///
+  /// This operation must be provided. See "REQUIRED OPERATIONS" above.
+  void (*cleanup)(TF_RandomAccessFile* file);
+
+  /// Reads up to `n` bytes from `*file` starting at `offset`.
+  ///
+  /// The output is in `buffer`, core TensorFlow owns the buffer and guarantees
+  /// that at least `n` bytes are available.
+  ///
+  /// Returns number of bytes read or -1 in case of error. Because of this
+  /// constraint and the fact that `ssize_t` is not defined in `stdint.h`/C++
+  /// standard, the return type is `int64_t`.
+  ///
+  /// This is thread safe.
+  ///
+  /// Note: the `buffer` argument is NOT a null terminated string!
+  ///
+  /// Plugins:
+  ///   * Must set `status` to `TF_OK` if exactly `n` bytes have been read.
+  ///   * Must set `status` to `TF_OUT_OF_RANGE` if fewer than `n` bytes have
+  ///     been read due to EOF.
+  ///   * Must return -1 for any other error and must set `status` to any
+  ///     other value to provide more information about the error.
+  int64_t (*read)(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
+                  char* buffer, TF_Status* status);
+} TF_RandomAccessFileOps;
+// LINT.ThenChange(:random_access_file_ops_version)
+
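+// Illustrative sketch (not part of the interface): a plugin backed by local
+// POSIX files might implement `read` roughly as follows, assuming it keeps an
+// open file descriptor in a hypothetical `PosixFile` struct behind
+// `file->plugin_file`. Error handling is abbreviated; a real plugin would map
+// `errno` to a more specific `TF_Code`.
+//
+//   static int64_t Read(const TF_RandomAccessFile* file, uint64_t offset,
+//                       size_t n, char* buffer, TF_Status* status) {
+//     const PosixFile* posix_file = (const PosixFile*)file->plugin_file;
+//     ssize_t read_bytes = pread(posix_file->fd, buffer, n, (off_t)offset);
+//     if (read_bytes < 0) {
+//       TF_SetStatus(status, TF_UNKNOWN, "pread failed");
+//       return -1;
+//     }
+//     if ((size_t)read_bytes < n) {
+//       TF_SetStatus(status, TF_OUT_OF_RANGE, "Read fewer bytes than requested");
+//     } else {
+//       TF_SetStatus(status, TF_OK, "");
+//     }
+//     return read_bytes;
+//   }
+
+// LINT.IfChange
+typedef struct TF_WritableFileOps {
+  /// Releases resources associated with `*file`.
+  ///
+  /// Requires that `*file` is not used in any concurrent or subsequent
+  /// operations.
+  ///
+  /// This operation must be provided. See "REQUIRED OPERATIONS" above.
+  void (*cleanup)(TF_WritableFile* file);
+
+  /// Appends `buffer` of size `n` to `*file`.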
+ /// + /// Core TensorFlow owns `buffer` and guarantees at least `n` bytes of storage + /// that can be used to write data. + /// + /// Note: the `buffer` argument is NOT a null terminated string! + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if exactly `n` bytes have been written. + /// * Must set `status` to `TF_RESOURCE_EXHAUSTED` if fewer than `n` bytes + /// have been written, potentially due to quota/disk space. + /// * Might use any other error value for `status` to signal other errors. + void (*append)(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status); + + /// Returns the current write position in `*file`. + /// + /// Plugins should ensure that the implementation is idempotent, 2 identical + /// calls result in the same answer. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` and return current position if no error. + /// * Must set `status` to any other value and return -1 in case of error. + int64_t (*tell)(const TF_WritableFile* file, TF_Status* status); + + /// Flushes `*file` and syncs contents to filesystem. + /// + /// This call might not block, and when it returns the contents might not have + /// been fully persisted. + /// + /// DEFAULT IMPLEMENTATION: No op. + void (*flush)(const TF_WritableFile* file, TF_Status* status); + + /// Syncs contents of `*file` with the filesystem. + /// + /// This call should block until filesystem confirms that all buffers have + /// been flushed and persisted. + /// + /// DEFAULT IMPLEMENTATION: No op. + void (*sync)(const TF_WritableFile* file, TF_Status* status); + + /// Closes `*file`. + /// + /// Flushes all buffers and deallocates all resources. + /// + /// Calling `close` must not result in calling `cleanup`. + /// + /// Core TensorFlow will never call `close` twice. + void (*close)(const TF_WritableFile* file, TF_Status* status); +} TF_WritableFileOps; +// LINT.ThenChange(:writable_file_ops_version) + +// LINT.IfChange +typedef struct TF_ReadOnlyMemoryRegionOps { + /// Releases resources associated with `*region`. + /// + /// Requires that `*region` is not used in any concurrent or subsequent + /// operations. + /// + /// This operation must be provided. See "REQUIRED OPERATIONS" above. + void (*cleanup)(TF_ReadOnlyMemoryRegion* region); + + /// Returns a pointer to the memory region. + /// + /// This operation must be provided. See "REQUIRED OPERATIONS" above. + const void* (*data)(const TF_ReadOnlyMemoryRegion* region); + + /// Returns the length of the memory region in bytes. + /// + /// This operation must be provided. See "REQUIRED OPERATIONS" above. + uint64_t (*length)(const TF_ReadOnlyMemoryRegion* region); +} TF_ReadOnlyMemoryRegionOps; +// LINT.ThenChange(:read_only_memory_region_ops_version) + +// LINT.IfChange +typedef struct TF_FilesystemOps { + /// Acquires all resources used by the filesystem. + /// + /// This operation must be provided. See "REQUIRED OPERATIONS" above. + void (*init)(TF_Filesystem* filesystem, TF_Status* status); + + /// Releases all resources used by the filesystem + /// + /// NOTE: TensorFlow does not unload DSOs. Thus, the only way a filesystem + /// won't be registered anymore is if this function gets called by core + /// TensorFlow and the `TF_Filesystem*` object is destroyed. However, due to + /// registration being done in a static instance of `Env`, the destructor of + /// `FileSystem` is never called (see + /// https://github.com/tensorflow/tensorflow/issues/27535). In turn, this + /// function will never be called. 
There are plans to refactor registration + /// and fix this. + /// + /// TODO(b/139060984): After all filesystems are converted, revisit note. + /// + /// This operation must be provided. See "REQUIRED OPERATIONS" above. + void (*cleanup)(TF_Filesystem* filesystem); + + /// Creates a new random access read-only file from given `path`. + /// + /// After this call `file` may be concurrently accessed by multiple threads. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `file` was updated. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to an + /// existing file or one of the parent entries in `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` points to a + /// directory or if it is invalid (e.g., malformed, or has a parent entry + /// which is a file). + /// * Might use any other error value for `status` to signal other errors. + /// + /// REQUIREMENTS: If plugins implement this, they must also provide a filled + /// `TF_RandomAccessFileOps` table. See "REQUIRED OPERATIONS" above. + void (*new_random_access_file)(const TF_Filesystem* filesystem, + const char* path, TF_RandomAccessFile* file, + TF_Status* status); + + /// Creates an object to write to a file with the specified `path`. + /// + /// If the file already exists, it is deleted and recreated. The `file` object + /// must only be accessed by one thread at a time. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `file` was updated. + /// * Must set `status` to `TF_NOT_FOUND` if one of the parents entries in + /// `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` points to a + /// directory or if it is invalid. + /// * Might use any other error value for `status` to signal other errors. + /// + /// REQUIREMENTS: If plugins implement this, they must also provide a filled + /// `TF_WritableFileOps` table. See "REQUIRED OPERATIONS" above. + void (*new_writable_file)(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); + + /// Creates an object to append to a file with the specified `path`. + /// + /// If the file doesn't exists, it is first created with empty contents. + /// The `file` object must only be accessed by one thread at a time. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `file` was updated. + /// * Must set `status` to `TF_NOT_FOUND` if one of the parents entries in + /// `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` points to a + /// directory or if it is invalid. + /// * Might use any other error value for `status` to signal other errors. + /// + /// REQUIREMENTS: If plugins implement this, they must also provide a filled + /// `TF_WritableFileOps` table. See "REQUIRED OPERATIONS" above. + void (*new_appendable_file)(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); + + /// Creates a read-only region of memory from contents of `path`. + /// + /// After this call `region` may be concurrently accessed by multiple threads. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `region` was updated. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to an + /// existing file or one of the parent entries in `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` points to a + /// directory or if it is invalid. + /// * Must set `status` to `TF_INVALID_ARGUMENT` if `path` points to an + /// empty file. 
+ /// * Might use any other error value for `status` to signal other errors. + /// + /// REQUIREMENTS: If plugins implement this, they must also provide a filled + /// `TF_ReadOnlyMemoryRegionOps` table. See "REQUIRED OPERATIONS" above. + void (*new_read_only_memory_region_from_file)(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status); + + /// Creates the directory specified by `path`, assuming parent exists. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if directory was created. + /// * Must set `status` to `TF_NOT_FOUND` if one of the parents entries in + /// `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is invalid. + /// * Must set `status` to `TF_ALREADY_EXISTS` if `path` already exists. + /// * Might use any other error value for `status` to signal other errors. + void (*create_dir)(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); + + /// Creates the directory specified by `path` and all needed ancestors. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if directory was created. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is invalid or + /// if it exists but is not a directory. + /// * Might use any other error value for `status` to signal other errors. + /// + /// NOTE: The requirements specify that `TF_ALREADY_EXISTS` is not returned if + /// directory exists. Similarly, `TF_NOT_FOUND` is not be returned, as the + /// missing directory entry and all its descendants will be created by the + /// plugin. + /// + /// DEFAULT IMPLEMENTATION: Creates directories one by one. Needs + /// `path_exists`, `is_directory`, and `create_dir`. + void (*recursively_create_dir)(const TF_Filesystem* filesystem, + const char* path, TF_Status* status); + + /// Deletes the file specified by `path`. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if file was deleted. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` points to a + /// directory or if it is invalid. + /// * Might use any other error value for `status` to signal other errors. + void (*delete_file)(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); + + /// Deletes the empty directory specified by `path`. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if directory was deleted. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` does not point + /// to a directory, if `path` is invalid, or if directory is not empty. + /// * Might use any other error value for `status` to signal other errors. + void (*delete_dir)(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); + + /// Deletes the directory specified by `path` and all its contents. + /// + /// This is accomplished by traversing directory tree rooted at `path` and + /// deleting entries as they are encountered, from leaves to root. Each plugin + /// is free to choose a different approach which obtains similar results. + /// + /// On successful deletion, `status` must be `TF_OK` and `*undeleted_files` + /// and `*undeleted_dirs` must be 0. On unsuccessful deletion, `status` must + /// be set to the reason why one entry couldn't be removed and the proper + /// count must be updated. 
If the deletion is unsuccessful because the + /// traversal couldn't start, `*undeleted_files` must be set to 0 and + /// `*undeleted_dirs` must be set to 1. + /// + /// TODO(b/139060984): After all filesystems are converted, consider + /// invariant about `*undeleted_files` and `*undeleted_dirs`. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if directory was deleted. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is invalid. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: Does a BFS traversal of tree rooted at `path`, + /// deleting entries as needed. Needs `path_exists`, `get_children`, + /// `is_directory`, `delete_file`, and `delete_dir`. + void (*delete_recursively)(const TF_Filesystem* filesystem, const char* path, + uint64_t* undeleted_files, + uint64_t* undeleted_dirs, TF_Status* status); + + /// Renames the file given by `src` to that in `dst`. + /// + /// Replaces `dst` if it exists. In case of error, both `src` and `dst` keep + /// the same state as before the call. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if rename was completed. + /// * Must set `status` to `TF_NOT_FOUND` if one of the parents entries in + /// either `src` or `dst` doesn't exist or if the specified `src` path + /// doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if either `src` or + /// `dst` is a directory or if either of them is invalid. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: Copies file and deletes original. Needs + /// `copy_file`. and `delete_file`. + void (*rename_file)(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); + + /// Copies the file given by `src` to that in `dst`. + /// + /// Similar to `rename_file`, but both `src` and `dst` exist after this call + /// with the same contents. In case of error, both `src` and `dst` keep the + /// same state as before the call. + /// + /// If `dst` is a directory, creates a file with the same name as the source + /// inside the target directory. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if rename was completed. + /// * Must set `status` to `TF_NOT_FOUND` if one of the parents entries in + /// either `src` or `dst` doesn't exist or if the specified `src` path + /// doesn't exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if either `src` or + /// `dst` is a directory or if either of them is invalid. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: Reads from `src` and writes to `dst`. Needs + /// `new_random_access_file` and `new_writable_file`. + void (*copy_file)(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); + + /// Checks if `path` exists. + /// + /// Note that this doesn't differentiate between files and directories. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `path` exists. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a + /// filesystem entry. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is invalid. + /// * Might use any other error value for `status` to signal other errors. + void (*path_exists)(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); + + /// Checks if all values in `paths` exist in the filesystem. 
+ /// + /// Returns `true` if and only if calling `path_exists` on each entry in + /// `paths` would set `status` to `TF_OK`. + /// + /// Caller guarantees that: + /// * `paths` has exactly `num_files` entries. + /// * `statuses` is either null or an array of `num_files` non-null elements + /// of type `TF_Status*`. + /// + /// If `statuses` is not null, plugins must fill each element with detailed + /// status for each file, as if calling `path_exists` on each one. Core + /// TensorFlow initializes the `statuses` array and plugins must use + /// `TF_SetStatus` to set each element instead of directly assigning. + /// + /// DEFAULT IMPLEMENTATION: Checks existence of every file. Needs + /// `path_exists`. + bool (*paths_exist)(const TF_Filesystem* filesystem, char** paths, + int num_files, TF_Status** statuses); + + /// Obtains statistics for the given `path`. + /// + /// Updates `stats` only if `status` is set to `TF_OK`. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `path` exists. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a + /// filesystem entry. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is invalid. + /// * Might use any other error value for `status` to signal other errors. + void (*stat)(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status); + + /// Checks whether the given `path` is a directory or not. + /// + /// If `status` is not `TF_OK`, returns `false`, otherwise returns the same + /// as the `is_directory` member of a `TF_FileStatistics` that would be used + /// on the equivalent call of `stat`. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `path` exists. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a + /// filesystem entry. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is invalid. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: Gets statistics about `path`. Needs `stat`. + bool (*is_directory)(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); + + /// Returns the size of the file given by `path`. + /// + /// If `status` is not `TF_OK`, return value is undefined. Otherwise, returns + /// the same as `length` member of a `TF_FileStatistics` that would be used on + /// the equivalent call of `stat`. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `path` exists. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a + /// filesystem entry. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is invalid or + /// points to a directory. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: Gets statistics about `path`. Needs `stat`. + int64_t (*get_file_size)(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); + + /// Translates `uri` to a filename for the filesystem + /// + /// A filesystem is registered for a specific scheme and all of the methods + /// should work with URIs. Hence, each filesystem needs to be able to + /// translate from an URI to a path on the filesystem. For example, this + /// function could translate `fs:///path/to/a/file` into `/path/to/a/file`, if + /// implemented by a filesystem registered to handle the `fs://` scheme. + /// + /// A new `char*` buffer must be allocated by this method. Core TensorFlow + /// manages the lifetime of the buffer after the call. 
Thus, all callers of + /// this method must take ownership of the returned pointer. + /// + /// The implementation should clean up paths, including but not limited to, + /// removing duplicate `/`s, and resolving `..` and `.`. + /// + /// Plugins must not return `nullptr`. Returning empty strings is allowed. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// This function will be called by core TensorFlow to clean up all path + /// arguments for all other methods in the filesystem API. + /// + /// DEFAULT IMPLEMENTATION: Uses `io::CleanPath` and `io::ParseURI`. + char* (*translate_name)(const TF_Filesystem* filesystem, const char* uri); + + /// Finds all entries in the directory given by `path`. + /// + /// The returned entries are paths relative to `path`. + /// + /// Plugins must allocate `entries` to hold all names that need to be returned + /// and return the size of `entries`. Caller takes ownership of `entries` + /// after the call. + /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `entries` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if all children were returned. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to a + /// filesystem entry or if one of the parents entries in `path` doesn't + /// exist. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if one of the parent + /// entries in `path` is not a directory, or if `path` is a file. + /// * Might use any other error value for `status` to signal other errors. + int (*get_children)(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status); + + /// Finds all entries matching the regular expression given by `glob`. + /// + /// Pattern must match the entire entry name, not just a substring. + /// + /// pattern: { term } + /// term: + /// '*': matches any sequence of non-'/' characters + /// '?': matches a single non-'/' character + /// '[' [ '^' ] { match-list } ']': + /// matches any single character (not) on the list + /// c: matches character c (c != '*', '?', '\\', '[') + /// '\\' c: matches character c + /// character-range: + /// c: matches character c (c != '\\', '-', ']') + /// '\\' c: matches character c + /// lo '-' hi: matches character c for lo <= c <= hi + /// + /// Implementations must allocate `entries` to hold all names that need to be + /// returned and return the size of `entries`. Caller takes ownership of + /// `entries` after the call. + /// + /// In case of error, the implementations must set `status` to a value + /// different than `TF_OK`, free any memory that might have been allocated for + /// `entries` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if all matches were returned. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: Scans the directory tree (in parallel if possible) + /// and fills `*entries`. 
Needs `get_children` and `is_directory`.
+  int (*get_matching_paths)(const TF_Filesystem* filesystem, const char* glob,
+                            char*** entries, TF_Status* status);
+
+  /// Flushes any filesystem cache currently in memory.
+  ///
+  /// DEFAULT IMPLEMENTATION: No op.
+  void (*flush_caches)(const TF_Filesystem* filesystem);
+
+  /// Starts a new transaction.
+  ///
+  /// An opaque transaction token is returned in `token`. Ownership of the token
+  /// remains with the filesystem. The token will be freed in the
+  /// `end_transaction` call and any access to the token after that is invalid.
+  ///
+  /// In case of error, plugins must set `status` to a value different than
+  /// `TF_OK`, free memory allocated for `token` and return -1.
+  ///
+  /// The allocation and freeing of memory must happen via the functions sent to
+  /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
+  /// structure in Section 4).
+  ///
+  /// Plugins:
+  ///   * Must set `status` to `TF_OK` if the transaction was successfully
+  ///     started.
+  ///   * Must set `status` to `TF_FAILED_PRECONDITION` if multiple transactions
+  ///     are not supported.
+  ///   * Might use any other error value for `status` to signal other errors.
+  int (*start_transaction)(const TF_Filesystem* filesystem,
+                           TF_TransactionToken** token, TF_Status* status);
+
+  /// Ends the transaction and frees the `token`. Any access to the token after
+  /// that will be invalid.
+  ///
+  /// In case of error, plugins must set `status` to a value different than
+  /// `TF_OK`, free memory allocated for `token` and return -1.
+  ///
+  /// The allocation and freeing of memory must happen via the functions sent to
+  /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
+  /// structure in Section 4).
+  ///
+  /// Plugins:
+  ///   * Must set `status` to `TF_OK` if the transaction was successfully
+  ///     finalized.
+  ///   * Must set `status` to `TF_NOT_FOUND` if the token is invalid/not found.
+  ///   * Might use any other error value for `status` to signal other errors.
+  int (*end_transaction)(const TF_Filesystem* filesystem,
+                         TF_TransactionToken* token, TF_Status* status);
+
+  /// Adds the file/directory in the `path` to the transaction in `token`. It is
+  /// a valid operation to add a path that doesn't exist yet to a transaction.
+  ///
+  /// In case of error, plugins must set `status` to a value different than
+  /// `TF_OK`, free memory allocated for `token` and return -1.
+  ///
+  /// The allocation and freeing of memory must happen via the functions sent to
+  /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
+  /// structure in Section 4).
+  ///
+  /// Plugins:
+  ///   * Must set `status` to `TF_OK` if the path was added to the transaction
+  ///     successfully.
+  ///   * Must set `status` to `TF_NOT_FOUND` if `token` is invalid.
+  ///   * Must set `status` to `TF_FAILED_PRECONDITION` if the file/directory is
+  ///     in another transaction and multiple transactions are not supported.
+  ///   * Might use any other error value for `status` to signal other errors.
+  int (*add_to_transaction)(const TF_Filesystem* filesystem, const char* path,
+                            TF_TransactionToken* token, TF_Status* status);
+
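+  // Illustrative sketch (not part of the interface): from the plugin's point
+  // of view, core TensorFlow (or a direct caller) might drive these transaction
+  // hooks roughly as follows. The `fs` pointer and path are placeholders.
+  //
+  //   TF_TransactionToken* token = NULL;
+  //   ops->start_transaction(fs, &token, status);
+  //   if (TF_GetCode(status) == TF_OK) {
+  //     ops->add_to_transaction(fs, "/some/path", token, status);
+  //     // ... perform operations that should belong to the transaction ...
+  //     ops->end_transaction(fs, token, status);  // also frees `token`
+  //   }
+
+  /// Returns transaction token for file/directory in the `path`. Note that path
+  /// may not exist yet but still might be part of a transaction.
+  ///
+  /// Transaction token is returned in `token`. Ownership of the token remains
+  /// with the filesystem. The token will be freed in the `end_transaction` call
+  /// and any access to the token after that is invalid.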
+ /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `token` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if a transaction for path is found + /// * Must set `status` to `TF_NOT_FOUND` if `path` is not part of any + /// transaction + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is + /// not in this filesystem. + /// * Might use any other error value for `status` to signal other errors. + int (*get_transaction_for_path)(const TF_Filesystem* filesystem, + const char* path, TF_TransactionToken** token, + TF_Status* status); + + /// Returns transaction token for `path` if it is part of a transaction else + /// starts a new transaction and adds `path` to that transaction + /// + /// Transaction token is returned in `token`. Ownership of the token is in + /// filesystem. Token will be freed in `end_transaction` call and any access + /// to token after that is invalid. + /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `token` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if transaction found or successfuly + /// started. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to this + /// filesystem + /// * Must set `status` to `TF_FAILED_PRECONDITION` if file/directory is + /// not in any transaction and multiple transactions are not supported. + /// * Might use any other error value for `status` to signal other errors. + int (*get_or_start_transaction_for_path)(const TF_Filesystem* filesystem, + const char* path, + TF_TransactionToken** token, + TF_Status* status); + + /// Decodes transaction token in `token` to human readable format for + /// debugging. + /// + /// A new `char*` buffer must be allocated by this method. Core TensorFlow + /// manages the lifetime of the buffer after the call. Thus, all callers of + /// this method must take ownership of the returned pointer. + /// + /// Plugins must not return `nullptr`. Returning empty strings is allowed. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// DEFAULT IMPLEMENTATION: Dump token and owner address. + char* (*decode_transaction_token)(const TF_Filesystem* filesystem, + const TF_TransactionToken* token); + + /// Returns pointer to an array of available configuration options and their + /// current/default values in `options` and number of options in array in + /// `num_options`. Ownership of the array is transferred to caller and the + /// caller is responsible of freeing the buffers using respective file systems + /// allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `options` and `num_options` set. + /// If there is no configurable option, `num_options` should be 0. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return 0 options and `TF_OK`. 
+ void (*get_filesystem_configuration)(const TF_Filesystem* filesystem, + TF_Filesystem_Option** options, + int* num_options, TF_Status* status); + + /// Updates filesystem configuration with options passed in `options`. It can + /// contain full set of options supported by the filesystem or just a subset + /// of them. Ownership of options and buffers therein belongs to the caller + /// and any buffers need to be allocated through filesystem allocation API. + /// Filesystems may choose to ignore configuration errors but should at least + /// display a warning or error message to warn the users. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if options are updated. + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_NOT_FOUND`. + void (*set_filesystem_configuration)(const TF_Filesystem* filesystem, + const TF_Filesystem_Option* options, + int num_options, TF_Status* status); + + /// Returns the value of the filesystem option given in `key` in `option`. + /// Valid values of the `key` are returned by + /// `get_file_system_configuration_keys` call. Ownership of the + /// `option` is transferred to caller. Buffers therein should be allocated and + /// freed by the relevant filesystems allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `option` is set + /// * Must set `status` to `TF_NOT_FOUND` if the key is invalid + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_NOT_FOUND`. + void (*get_filesystem_configuration_option)(const TF_Filesystem* filesystem, + const char* key, + TF_Filesystem_Option** option, + TF_Status* status); + + /// Sets the value of the filesystem option given in `key` to value in + /// `option`. Valid values of the `key` are returned by + /// `get_file_system_configuration_keys` call. Ownership of the `option` and + /// the `key` belogs to the caller. Buffers therein should be allocated and + /// freed by the filesystems allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if `option` is set/updated + /// * Must set `status` to `TF_NOT_FOUND` if the key is invalid + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_NOT_FOUND`. + void (*set_filesystem_configuration_option)( + const TF_Filesystem* filesystem, const TF_Filesystem_Option* option, + TF_Status* status); + + /// Returns a list of valid configuration keys in `keys` array and number of + /// keys in `num_keys`. Ownership of the buffers in `keys` are transferred to + /// caller and needs to be freed using relevant filesystem allocation API. + /// + /// Plugins: + /// * Must set `status` to `TF_OK` on success. If there are no configurable + /// keys, `num_keys` should be set to 0 + /// * Might use any other error value for `status` to signal other errors. + /// + /// DEFAULT IMPLEMENTATION: return `TF_OK` and `num_keys`=0. + void (*get_filesystem_configuration_keys)(const TF_Filesystem* filesystem, + char** keys, int* num_keys, + TF_Status* status); +} TF_FilesystemOps; +// LINT.ThenChange(:filesystem_ops_version) + +/// SECTION 3. 
ABI and API compatibility
+/// ----------------------------------------------------------------------------
+///
+/// In this section we define constants and macros to record versioning
+/// information for each of the structures in section 2: ABI and API versions
+/// and the number of functions in each of the function tables (which is
+/// determined automatically, so it is ignored for the rest of this comment).
+///
+/// Since filesystem plugins are outside of TensorFlow's code tree, they are not
+/// tied to TensorFlow releases and should have their own versioning metadata
+/// in addition to the data discussed in this section. Each plugin author can
+/// use a custom scheme, but it should only relate to changes in plugin code.
+/// This section only touches metadata related to the versioning of this
+/// interface that is shared by all possible plugins.
+///
+/// The API number increases whenever we break API compatibility while still
+/// maintaining ABI compatibility. This happens only in the following cases:
+///   1. A new method is added _at the end_ of the function table.
+///   2. Preconditions or postconditions for one operation in these function
+///      tables change. Note that only core TensorFlow is able to impose these
+///      invariants (i.e., guarantee the preconditions before calling the
+///      operation and check the postconditions after the operation returns).
+///      If plugins need additional invariants, they should be checked on the
+///      plugin side and the `status` out variable should be updated accordingly
+///      (e.g., to include plugin version information that relates to the
+///      condition change).
+///
+/// All other changes to the data structures (e.g., method removal, method
+/// reordering, argument reordering, adding or removing arguments, changing the
+/// type or the constness of a parameter, etc.) result in an ABI breakage.
+/// Thus, we should not do any of these types of changes, except, potentially,
+/// when we are releasing a new major version of TensorFlow. This is an escape
+/// hatch, to be used rarely, preferably only to clean up these structures.
+/// Whenever we do these changes, the ABI number must be increased.
+///
+/// The next section details how this metadata is used at plugin registration to
+/// only load compatible plugins and discard all others.
+
+// LINT.IfChange(random_access_file_ops_version)
+constexpr int TF_RANDOM_ACCESS_FILE_OPS_API = 0;
+constexpr int TF_RANDOM_ACCESS_FILE_OPS_ABI = 0;
+constexpr size_t TF_RANDOM_ACCESS_FILE_OPS_SIZE =
+    sizeof(TF_RandomAccessFileOps);
+// LINT.ThenChange()
+
+// LINT.IfChange(writable_file_ops_version)
+constexpr int TF_WRITABLE_FILE_OPS_API = 0;
+constexpr int TF_WRITABLE_FILE_OPS_ABI = 0;
+constexpr size_t TF_WRITABLE_FILE_OPS_SIZE = sizeof(TF_WritableFileOps);
+// LINT.ThenChange()
+
+// LINT.IfChange(read_only_memory_region_ops_version)
+constexpr int TF_READ_ONLY_MEMORY_REGION_OPS_API = 0;
+constexpr int TF_READ_ONLY_MEMORY_REGION_OPS_ABI = 0;
+constexpr size_t TF_READ_ONLY_MEMORY_REGION_OPS_SIZE =
+    sizeof(TF_ReadOnlyMemoryRegionOps);
+// LINT.ThenChange()
+
+// LINT.IfChange(filesystem_ops_version)
+constexpr int TF_FILESYSTEM_OPS_API = 0;
+constexpr int TF_FILESYSTEM_OPS_ABI = 0;
+constexpr size_t TF_FILESYSTEM_OPS_SIZE = sizeof(TF_FilesystemOps);
+// LINT.ThenChange()
+
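+// Illustrative sketch (not part of the interface): the metadata above is
+// compared at plugin load time against the values the plugin was compiled
+// with, which are recorded in the `TF_FilesystemPluginOps` struct defined in
+// the next section. A simplified version of such a check, assuming the actual
+// logic lives in core TensorFlow, could look like:
+//
+//   bool FilesystemOpsAbiCompatible(const TF_FilesystemPluginOps* ops) {
+//     // An ABI mismatch means the plugin cannot be loaded; API and size
+//     // differences can be handled more leniently (e.g., with a warning).
+//     return ops->filesystem_ops_abi == TF_FILESYSTEM_OPS_ABI;
+//   }
+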
+/// SECTION 4. Plugin registration and initialization
+/// ----------------------------------------------------------------------------
+///
+/// In this section we define the API used by core TensorFlow to initialize a
+/// filesystem provided by a plugin. That is, we define the following:
+///   * `TF_InitPlugin` function: must be present in the plugin shared object as
+///     it will be called by core TensorFlow when the filesystem plugin is
+///     loaded;
+///   * `TF_FilesystemPluginOps` struct: used to transfer information between
+///     plugins and core TensorFlow about the operations provided and metadata;
+///   * `TF_FilesystemPluginInfo` struct: similar to the above structure, but
+///     collects information about all the file schemes that the plugin provides
+///     support for, as well as about the plugin's memory handling routines;
+///   * `TF_SetFilesystemVersionMetadata` function: must be called by plugins in
+///     their `TF_InitPlugin` to record the versioning information the plugins
+///     are compiled against.
+///
+/// The `TF_InitPlugin` function is used by plugins to set up the data
+/// structures that implement this interface, as presented in Section 2. To
+/// avoid having plugin shared objects call back into symbols defined in core
+/// TensorFlow, `TF_InitPlugin` has a `TF_FilesystemPluginInfo` argument which
+/// the plugin must fill (using `TF_SetFilesystemVersionMetadata` for the
+/// metadata and setting up all the supported operations and the URI schemes
+/// that are supported).
+
+/// This structure incorporates the operations defined in Section 2 and the
+/// metadata defined in Section 3, allowing plugins to define different ops
+/// for different URI schemes.
+///
+/// Every URI scheme is a string such as "fs" for URIs of the form
+/// "fs:///path/to/file". For local filesystems (i.e., when the URI is
+/// "/path/to/file"), the scheme must be "". The scheme must never be `nullptr`.
+///
+/// Every plugin fills this in `TF_InitPlugin`, using the allocator passed as
+/// argument to allocate memory. After `TF_InitPlugin` finishes, core
+/// TensorFlow uses the information present in this structure to initialize
+/// filesystems for the URI schemes that the plugin requests.
+///
+/// All pointers defined in this structure point to memory allocated by the DSO
+/// using an allocator provided by core TensorFlow when calling `TF_InitPlugin`.
+///
+/// IMPORTANT: To maintain binary compatibility, the layout of this structure
+/// must not change! In the unlikely case that a new type of file needs to be
+/// supported, add the new ops and metadata at the end of the structure.
+typedef struct TF_FilesystemPluginOps {
+  char* scheme;
+  int filesystem_ops_abi;
+  int filesystem_ops_api;
+  size_t filesystem_ops_size;
+  TF_FilesystemOps* filesystem_ops;
+  int random_access_file_ops_abi;
+  int random_access_file_ops_api;
+  size_t random_access_file_ops_size;
+  TF_RandomAccessFileOps* random_access_file_ops;
+  int writable_file_ops_abi;
+  int writable_file_ops_api;
+  size_t writable_file_ops_size;
+  TF_WritableFileOps* writable_file_ops;
+  int read_only_memory_region_ops_abi;
+  int read_only_memory_region_ops_api;
+  size_t read_only_memory_region_ops_size;
+  TF_ReadOnlyMemoryRegionOps* read_only_memory_region_ops;
+} TF_FilesystemPluginOps;
+
+/// This structure gathers together all the operations provided by the plugin.
+///
+/// Plugins must provide exactly `num_schemes` elements in the `ops` array.
+///
+/// Since memory that is allocated by the DSO gets transferred to core
+/// TensorFlow, we need to provide a way for the allocation and deallocation to
+/// match. This is why this structure also defines `plugin_memory_allocate` and
+/// `plugin_memory_free` members.
+///
+/// All memory allocated by the plugin that will be owned by core TensorFlow
+/// must be allocated using the allocator in this structure. Core TensorFlow
+/// will use the deallocator to free this memory once it no longer needs it.
+///
+/// IMPORTANT: To maintain binary compatibility, the layout of this structure
+/// must not change! In the unlikely case that new global operations must be
+/// provided, add them at the end of the structure.
+typedef struct TF_FilesystemPluginInfo {
+  size_t num_schemes;
+  TF_FilesystemPluginOps* ops;
+  void* (*plugin_memory_allocate)(size_t size);
+  void (*plugin_memory_free)(void* ptr);
+} TF_FilesystemPluginInfo;
+
+/// Convenience function for setting the versioning metadata.
+///
+/// The argument is guaranteed to not be `nullptr`.
+///
+/// We want this to be defined in the plugin's memory space and we guarantee
+/// that core TensorFlow will never call this.
+static inline void TF_SetFilesystemVersionMetadata(
+    TF_FilesystemPluginOps* ops) {
+  ops->filesystem_ops_abi = TF_FILESYSTEM_OPS_ABI;
+  ops->filesystem_ops_api = TF_FILESYSTEM_OPS_API;
+  ops->filesystem_ops_size = TF_FILESYSTEM_OPS_SIZE;
+  ops->random_access_file_ops_abi = TF_RANDOM_ACCESS_FILE_OPS_ABI;
+  ops->random_access_file_ops_api = TF_RANDOM_ACCESS_FILE_OPS_API;
+  ops->random_access_file_ops_size = TF_RANDOM_ACCESS_FILE_OPS_SIZE;
+  ops->writable_file_ops_abi = TF_WRITABLE_FILE_OPS_ABI;
+  ops->writable_file_ops_api = TF_WRITABLE_FILE_OPS_API;
+  ops->writable_file_ops_size = TF_WRITABLE_FILE_OPS_SIZE;
+  ops->read_only_memory_region_ops_abi = TF_READ_ONLY_MEMORY_REGION_OPS_ABI;
+  ops->read_only_memory_region_ops_api = TF_READ_ONLY_MEMORY_REGION_OPS_API;
+  ops->read_only_memory_region_ops_size = TF_READ_ONLY_MEMORY_REGION_OPS_SIZE;
+}
+
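+// Illustrative sketch (not part of the interface): a minimal `TF_InitPlugin`
+// for a plugin that registers a single hypothetical "myfs" scheme and only
+// supports random access reads might look roughly like this. The
+// `plugin_memory_*` helpers and the `My*` functions are placeholders the
+// plugin would define itself.
+//
+//   void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
+//     info->plugin_memory_allocate = plugin_memory_allocate;
+//     info->plugin_memory_free = plugin_memory_free;
+//     info->num_schemes = 1;
+//     info->ops = (TF_FilesystemPluginOps*)plugin_memory_allocate(
+//         info->num_schemes * sizeof(info->ops[0]));
+//
+//     TF_FilesystemPluginOps* ops = &info->ops[0];
+//     TF_SetFilesystemVersionMetadata(ops);
+//     // NOTE: per the ownership rules above, this string should come from the
+//     // registered allocator; strdup is shown only for brevity.
+//     ops->scheme = strdup("myfs");
+//
+//     ops->filesystem_ops = (TF_FilesystemOps*)plugin_memory_allocate(
+//         TF_FILESYSTEM_OPS_SIZE);
+//     ops->filesystem_ops->init = MyFilesystemInit;
+//     ops->filesystem_ops->cleanup = MyFilesystemCleanup;
+//     ops->filesystem_ops->new_random_access_file = MyNewRandomAccessFile;
+//
+//     ops->random_access_file_ops = (TF_RandomAccessFileOps*)
+//         plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE);
+//     ops->random_access_file_ops->cleanup = MyRandomAccessCleanup;
+//     ops->random_access_file_ops->read = MyRandomAccessRead;
+//   }
+
+/// Initializes a TensorFlow plugin.
+///
+/// Must be implemented by the plugin DSO. It is called by the TensorFlow
+/// runtime.
+///
+/// Filesystem plugins can be loaded on demand by users via
+/// `Env::LoadLibrary` or during TensorFlow's startup if they are on certain
+/// paths (although this has a security risk if two plugins register for the
+/// same filesystem and the malicious one loads before the legitimate one -
+/// but we consider this to be something that users should care about and
+/// manage themselves). In both of these cases, core TensorFlow looks for
+/// the `TF_InitPlugin` symbol and calls this function.
+///
+/// For every filesystem URI scheme that this plugin supports, the plugin must
+/// add one `TF_FilesystemPluginOps` entry in `plugin_info->ops` and call
+/// `TF_SetFilesystemVersionMetadata` for that entry.
+///
+/// Plugins must also initialize `plugin_info->plugin_memory_allocate` and
+/// `plugin_info->plugin_memory_free` to ensure memory allocated by the plugin
+/// is freed in a compatible way.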
+TF_CAPI_EXPORT extern void TF_InitPlugin(TF_FilesystemPluginInfo* plugin_info); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_FILESYSTEM_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/modular_filesystem.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/modular_filesystem.h new file mode 100644 index 00000000..a19ee27d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/modular_filesystem.h @@ -0,0 +1,210 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/core/platform/file_statistics.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/file_system.h" + +/// This file builds classes needed to hold a filesystem implementation in the +/// modular world. Once all TensorFlow filesystems are converted to use the +/// plugin based approach, this file will replace the one in core/platform and +/// the names will lose the `Modular` part. Until that point, the `Modular*` +/// classes here are experimental and subject to breaking changes. +/// For documentation on these methods, consult `core/platform/filesystem.h`. + +namespace tensorflow { + +// TODO(b/143949615): After all filesystems are converted, this file will be +// moved to core/platform, and this class can become a singleton and replace the +// need for `Env::Default()`. At that time, we might decide to remove the need +// for `Env::Default()` altogether, but that's a different project, not in +// scope for now. I'm just mentioning this here as that transition will mean +// removal of the registration part from `Env` and adding it here instead: we +// will need tables to hold for each scheme the function tables that implement +// the needed functionality instead of the current `FileSystemRegistry` code in +// `core/platform/env.cc`. 
+class ModularFileSystem final : public FileSystem {
+ public:
+  ModularFileSystem(
+      std::unique_ptr<TF_Filesystem> filesystem,
+      std::unique_ptr<const TF_FilesystemOps> filesystem_ops,
+      std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops,
+      std::unique_ptr<const TF_WritableFileOps> writable_file_ops,
+      std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
+          read_only_memory_region_ops,
+      std::function<void*(size_t)> plugin_memory_allocate,
+      std::function<void(void*)> plugin_memory_free)
+      : filesystem_(std::move(filesystem)),
+        ops_(std::move(filesystem_ops)),
+        random_access_file_ops_(std::move(random_access_file_ops)),
+        writable_file_ops_(std::move(writable_file_ops)),
+        read_only_memory_region_ops_(std::move(read_only_memory_region_ops)),
+        plugin_memory_allocate_(std::move(plugin_memory_allocate)),
+        plugin_memory_free_(std::move(plugin_memory_free)) {}
+
+  ~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
+
+  TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT;
+
+  absl::Status NewRandomAccessFile(
+      const std::string& fname, TransactionToken* token,
+      std::unique_ptr<RandomAccessFile>* result) override;
+  absl::Status NewWritableFile(const std::string& fname,
+                               TransactionToken* token,
+                               std::unique_ptr<WritableFile>* result) override;
+  absl::Status NewAppendableFile(
+      const std::string& fname, TransactionToken* token,
+      std::unique_ptr<WritableFile>* result) override;
+  absl::Status NewReadOnlyMemoryRegionFromFile(
+      const std::string& fname, TransactionToken* token,
+      std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
+  absl::Status FileExists(const std::string& fname,
+                          TransactionToken* token) override;
+  bool FilesExist(const std::vector<std::string>& files,
+                  TransactionToken* token,
+                  std::vector<absl::Status>* status) override;
+  absl::Status GetChildren(const std::string& dir, TransactionToken* token,
+                           std::vector<std::string>* result) override;
+  absl::Status GetMatchingPaths(const std::string& pattern,
+                                TransactionToken* token,
+                                std::vector<std::string>* results) override;
+  absl::Status DeleteFile(const std::string& fname,
+                          TransactionToken* token) override;
+  absl::Status DeleteRecursively(const std::string& dirname,
+                                 TransactionToken* token,
+                                 int64_t* undeleted_files,
+                                 int64_t* undeleted_dirs) override;
+  absl::Status DeleteDir(const std::string& dirname,
+                         TransactionToken* token) override;
+  absl::Status RecursivelyCreateDir(const std::string& dirname,
+                                    TransactionToken* token) override;
+  absl::Status CreateDir(const std::string& dirname,
+                         TransactionToken* token) override;
+  absl::Status Stat(const std::string& fname, TransactionToken* token,
+                    FileStatistics* stat) override;
+  absl::Status IsDirectory(const std::string& fname,
+                           TransactionToken* token) override;
+  absl::Status GetFileSize(const std::string& fname, TransactionToken* token,
+                           uint64* file_size) override;
+  absl::Status RenameFile(const std::string& src, const std::string& target,
+                          TransactionToken* token) override;
+  absl::Status CopyFile(const std::string& src, const std::string& target,
+                        TransactionToken* token) override;
+  std::string TranslateName(const std::string& name) const override;
+  void FlushCaches(TransactionToken* token) override;
+  absl::Status SetOption(const std::string& name,
+                         const std::vector<std::string>& values) override;
+  absl::Status SetOption(const std::string& name,
+                         const std::vector<int64_t>& values) override;
+  absl::Status SetOption(const std::string& name,
+                         const std::vector<double>& values) override;
+
+ private:
+  std::unique_ptr<TF_Filesystem> filesystem_;
+  std::unique_ptr<const TF_FilesystemOps> ops_;
+  std::unique_ptr<const TF_RandomAccessFileOps> random_access_file_ops_;
+  std::unique_ptr<const TF_WritableFileOps> writable_file_ops_;
+  std::unique_ptr<const TF_ReadOnlyMemoryRegionOps>
+      read_only_memory_region_ops_;
+  std::function<void*(size_t)> plugin_memory_allocate_;
+  std::function<void(void*)> plugin_memory_free_;
+  ModularFileSystem(const ModularFileSystem&) = delete;
+  void operator=(const ModularFileSystem&) = delete;
+};
+
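+// Illustrative sketch (not part of this header): once a plugin DSO has been
+// registered, e.g. via `RegisterFilesystemPlugin` declared below, files under
+// the plugin's scheme are expected to be reachable through the usual
+// `FileSystem`/`Env` entry points. The DSO path, scheme and object name are
+// placeholders.
+//
+//   TF_CHECK_OK(tensorflow::RegisterFilesystemPlugin("/path/to/libmyfs.so"));
+//   std::unique_ptr<tensorflow::RandomAccessFile> file;
+//   TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(
+//       "myfs://bucket/object", &file));
+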
+class ModularRandomAccessFile final : public RandomAccessFile {
+ public:
+  ModularRandomAccessFile(const std::string& filename,
+                          std::unique_ptr<TF_RandomAccessFile> file,
+                          const TF_RandomAccessFileOps* ops)
+      : filename_(filename), file_(std::move(file)), ops_(ops) {}
+
+  ~ModularRandomAccessFile() override { ops_->cleanup(file_.get()); }
+
+  absl::Status Read(uint64 offset, size_t n, StringPiece* result,
+                    char* scratch) const override;
+  absl::Status Name(StringPiece* result) const override;
+
+ private:
+  std::string filename_;
+  std::unique_ptr<TF_RandomAccessFile> file_;
+  const TF_RandomAccessFileOps* ops_;  // not owned
+  ModularRandomAccessFile(const ModularRandomAccessFile&) = delete;
+  void operator=(const ModularRandomAccessFile&) = delete;
+};
+
+class ModularWritableFile final : public WritableFile {
+ public:
+  ModularWritableFile(const std::string& filename,
+                      std::unique_ptr<TF_WritableFile> file,
+                      const TF_WritableFileOps* ops)
+      : filename_(filename), file_(std::move(file)), ops_(ops) {}
+
+  ~ModularWritableFile() override { ops_->cleanup(file_.get()); }
+
+  absl::Status Append(StringPiece data) override;
+  absl::Status Close() override;
+  absl::Status Flush() override;
+  absl::Status Sync() override;
+  absl::Status Name(StringPiece* result) const override;
+  absl::Status Tell(int64_t* position) override;
+
+ private:
+  std::string filename_;
+  std::unique_ptr<TF_WritableFile> file_;
+  const TF_WritableFileOps* ops_;  // not owned
+  ModularWritableFile(const ModularWritableFile&) = delete;
+  void operator=(const ModularWritableFile&) = delete;
+};
+
+class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion {
+ public:
+  ModularReadOnlyMemoryRegion(std::unique_ptr<TF_ReadOnlyMemoryRegion> region,
+                              const TF_ReadOnlyMemoryRegionOps* ops)
+      : region_(std::move(region)), ops_(ops) {}
+
+  ~ModularReadOnlyMemoryRegion() override { ops_->cleanup(region_.get()); };
+
+  const void* data() override { return ops_->data(region_.get()); }
+  uint64 length() override { return ops_->length(region_.get()); }
+
+ private:
+  std::unique_ptr<TF_ReadOnlyMemoryRegion> region_;
+  const TF_ReadOnlyMemoryRegionOps* ops_;  // not owned
+  ModularReadOnlyMemoryRegion(const ModularReadOnlyMemoryRegion&) = delete;
+  void operator=(const ModularReadOnlyMemoryRegion&) = delete;
+};
+
+// Registers a filesystem plugin so that core TensorFlow can use it.
+absl::Status RegisterFilesystemPlugin(const std::string& dso_path);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h
new file mode 100644
index 00000000..e119857f
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h
@@ -0,0 +1,34 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
+#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
+
+#include "absl/status/status.h"
+#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
+namespace filesystem_registration {
+
+// Implementation for filesystem registration
+//
+// Don't call this directly. Instead call `RegisterFilesystemPlugin`.
+// Exposed only for static registration of local filesystems.
+absl::Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info);
+
+}  // namespace filesystem_registration
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_MODULAR_FILESYSTEM_REGISTRATION_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h
new file mode 100644
index 00000000..cc7a7451
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h
@@ -0,0 +1,109 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its
+// destructor. The easiest way to use MakeCleanup is with a lambda argument,
+// capturing the return value in an 'auto' local variable. Most users will not
+// need more sophisticated syntax than that.
+//
+// Example:
+//   void func() {
+//     FILE* fp = fopen("data.txt", "r");
+//     if (fp == nullptr) return;
+//     auto fp_cleaner = gtl::MakeCleanup([fp] { fclose(fp); });
+//     // No matter what, fclose(fp) will happen.
+//     DataObject d;
+//     while (ReadDataObject(fp, &d)) {
+//       if (d.IsBad()) {
+//         LOG(ERROR) << "Bad Data";
+//         return;
+//       }
+//       PushGoodData(d);
+//     }
+//   }
+//
+// You can use Cleanup<F> directly, instead of using MakeCleanup and auto,
+// but there's rarely a reason to do that.
+//
+// You can call 'release()' on a Cleanup object to cancel the cleanup.
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
+#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
+
+#include <type_traits>
+#include <utility>
+
+namespace tf_gcs_filesystem {
+
+// A move-only RAII object that calls a stored cleanup functor when
+// destroyed. Cleanup<F> is the return type of gtl::MakeCleanup(F).
+template <typename F>
+class Cleanup {
+ public:
+  Cleanup() : released_(true), f_() {}
+
+  template <typename G>
+  explicit Cleanup(G&& f)          // NOLINT
+      : f_(std::forward<G>(f)) {}  // NOLINT(build/c++11)
+
+  Cleanup(Cleanup&& src)  // NOLINT
+      : released_(src.is_released()), f_(src.release()) {}
+
+  // Implicitly move-constructible from any compatible Cleanup<G>.
+  // The source will be released as if src.release() were called.
+  // A moved-from Cleanup can be safely destroyed or reassigned.
+ template + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) {} + + // Assignment to a Cleanup object behaves like destroying it + // and making a new one in its place, analogous to unique_ptr + // semantics. + Cleanup& operator=(Cleanup&& src) { // NOLINT + if (!released_) f_(); + released_ = src.released_; + f_ = src.release(); + return *this; + } + + ~Cleanup() { + if (!released_) f_(); + } + + // Releases the cleanup function instead of running it. + // Hint: use c.release()() to run early. + F release() { + released_ = true; + return std::move(f_); + } + + bool is_released() const { return released_; } + + private: + static_assert(!std::is_reference::value, "F must not be a reference"); + + bool released_ = false; + F f_; +}; + +template ::type> +Cleanup MakeCleanup(F&& f) { + return Cleanup(std::forward(f)); +} + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h new file mode 100644 index 00000000..c0347faa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h @@ -0,0 +1,191 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_EXPIRING_LRU_CACHE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_EXPIRING_LRU_CACHE_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/tf_status.h" + +namespace tf_gcs_filesystem { + +/// \brief An LRU cache of string keys and arbitrary values, with configurable +/// max item age (in seconds) and max entries. +/// +/// This class is thread safe. +template +class ExpiringLRUCache { + public: + /// A `max_age` of 0 means that nothing is cached. A `max_entries` of 0 means + /// that there is no limit on the number of entries in the cache (however, if + /// `max_age` is also 0, the cache will not be populated). + ExpiringLRUCache(uint64_t max_age, size_t max_entries, + std::function timer_seconds = TF_NowSeconds) + : max_age_(max_age), + max_entries_(max_entries), + timer_seconds_(timer_seconds) {} + + /// Insert `value` with key `key`. This will replace any previous entry with + /// the same key. + void Insert(const std::string& key, const T& value) { + if (max_age_ == 0) { + return; + } + absl::MutexLock lock(&mu_); + InsertLocked(key, value); + } + + // Delete the entry with key `key`. Return true if the entry was found for + // `key`, false if the entry was not found. In both cases, there is no entry + // with key `key` existed after the call. 
+ bool Delete(const std::string& key) { + absl::MutexLock lock(&mu_); + return DeleteLocked(key); + } + + /// Look up the entry with key `key` and copy it to `value` if found. Returns + /// true if an entry was found for `key`, and its timestamp is not more than + /// max_age_ seconds in the past. + bool Lookup(const std::string& key, T* value) { + if (max_age_ == 0) { + return false; + } + absl::MutexLock lock(&mu_); + return LookupLocked(key, value); + } + + typedef std::function ComputeFunc; + + /// Look up the entry with key `key` and copy it to `value` if found. If not + /// found, call `compute_func`. If `compute_func` set `status` to `TF_OK`, + /// store a copy of the output parameter in the cache, and another copy in + /// `value`. + void LookupOrCompute(const std::string& key, T* value, + const ComputeFunc& compute_func, TF_Status* status) { + if (max_age_ == 0) { + return compute_func(key, value, status); + } + + // Note: we hold onto mu_ for the rest of this function. In practice, this + // is okay, as stat requests are typically fast, and concurrent requests are + // often for the same file. Future work can split this up into one lock per + // key if this proves to be a significant performance bottleneck. + absl::MutexLock lock(&mu_); + if (LookupLocked(key, value)) { + return TF_SetStatus(status, TF_OK, ""); + } + compute_func(key, value, status); + if (TF_GetCode(status) == TF_OK) { + InsertLocked(key, *value); + } + } + + /// Clear the cache. + void Clear() { + absl::MutexLock lock(&mu_); + cache_.clear(); + lru_list_.clear(); + } + + /// Accessors for cache parameters. + uint64_t max_age() const { return max_age_; } + size_t max_entries() const { return max_entries_; } + + private: + struct Entry { + /// The timestamp (seconds) at which the entry was added to the cache. + uint64_t timestamp; + + /// The entry's value. + T value; + + /// A list iterator pointing to the entry's position in the LRU list. + std::list::iterator lru_iterator; + }; + + bool LookupLocked(const std::string& key, T* value) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto it = cache_.find(key); + if (it == cache_.end()) { + return false; + } + lru_list_.erase(it->second.lru_iterator); + if (timer_seconds_() - it->second.timestamp > max_age_) { + cache_.erase(it); + return false; + } + *value = it->second.value; + lru_list_.push_front(it->first); + it->second.lru_iterator = lru_list_.begin(); + return true; + } + + void InsertLocked(const std::string& key, const T& value) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + lru_list_.push_front(key); + Entry entry{timer_seconds_(), value, lru_list_.begin()}; + auto insert = cache_.insert(std::make_pair(key, entry)); + if (!insert.second) { + lru_list_.erase(insert.first->second.lru_iterator); + insert.first->second = entry; + } else if (max_entries_ > 0 && cache_.size() > max_entries_) { + cache_.erase(lru_list_.back()); + lru_list_.pop_back(); + } + } + + bool DeleteLocked(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto it = cache_.find(key); + if (it == cache_.end()) { + return false; + } + lru_list_.erase(it->second.lru_iterator); + cache_.erase(it); + return true; + } + + /// The maximum age of entries in the cache, in seconds. A value of 0 means + /// that no entry is ever placed in the cache. + const uint64_t max_age_; + + /// The maximum number of entries in the cache. A value of 0 means there is no + /// limit on entry count. + const size_t max_entries_; + + /// The callback to read timestamps. 
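As a usage sketch of the `ExpiringLRUCache` declared above (upstream the class is templated on the value type `T` and `timer_seconds` is a `uint64_t()` callable; those template arguments are stripped in this flattened diff), with an injected clock so the `max_age` expiry is deterministic:

```cpp
#include <cstdint>
#include <string>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h"

// Sketch only: Insert/Lookup with a fake time source so the expiry behavior
// described above can be observed without waiting.
void ExpiringCacheExample() {
  uint64_t now = 0;  // fake clock, in seconds
  tf_gcs_filesystem::ExpiringLRUCache<std::string> cache(
      /*max_age=*/10, /*max_entries=*/2,
      /*timer_seconds=*/[&now] { return now; });

  cache.Insert("gs://bucket/object", "cached-stat");
  std::string value;
  bool hit = cache.Lookup("gs://bucket/object", &value);  // true, "cached-stat"

  now += 11;                                         // older than max_age
  hit = cache.Lookup("gs://bucket/object", &value);  // false: entry expired
  (void)hit;
}
```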
+ std::function timer_seconds_; + + /// Guards access to the cache and the LRU list. + absl::Mutex mu_; + + /// The cache (a map from string key to Entry). + std::map cache_ ABSL_GUARDED_BY(mu_); + + /// The LRU list of entries. The front of the list identifies the most + /// recently accessed entry. + std::list lru_list_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_EXPIRING_LRU_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h new file mode 100644 index 00000000..c7781f52 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h @@ -0,0 +1,117 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "google/cloud/storage/client.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h" +#include "tensorflow/c/tf_status.h" + +void ParseGCSPath(const std::string& fname, bool object_empty_ok, + std::string* bucket, std::string* object, TF_Status* status); + +namespace tf_random_access_file { +void Cleanup(TF_RandomAccessFile* file); +int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* buffer, TF_Status* status); +} // namespace tf_random_access_file + +namespace tf_writable_file { +void Cleanup(TF_WritableFile* file); +void Append(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status); +int64_t Tell(const TF_WritableFile* file, TF_Status* status); +void Flush(const TF_WritableFile* file, TF_Status* status); +void Sync(const TF_WritableFile* file, TF_Status* status); +void Close(const TF_WritableFile* file, TF_Status* status); +} // namespace tf_writable_file + +namespace tf_read_only_memory_region { +void Cleanup(TF_ReadOnlyMemoryRegion* region); +const void* Data(const TF_ReadOnlyMemoryRegion* region); +uint64_t Length(const TF_ReadOnlyMemoryRegion* region); +} // namespace tf_read_only_memory_region + +namespace tf_gcs_filesystem { +typedef struct GcsFileStat { + TF_FileStatistics base; + int64_t generation_number; +} GcsFileStat; + +typedef struct GCSFile { + google::cloud::storage::Client gcs_client; // owned + bool compose; + absl::Mutex block_cache_lock; + std::shared_ptr file_block_cache + ABSL_GUARDED_BY(block_cache_lock); + uint64_t block_size; // Reads smaller 
than block_size will trigger a read + // of block_size. + std::unique_ptr> stat_cache; + GCSFile(google::cloud::storage::Client&& gcs_client); + // This constructor is used for testing purpose only. + GCSFile(google::cloud::storage::Client&& gcs_client, bool compose, + uint64_t block_size, size_t max_bytes, uint64_t max_staleness, + uint64_t stat_cache_max_age, size_t stat_cache_max_entries); +} GCSFile; + +// This function is used to initialize a filesystem without the need of setting +// manually environement variables. +void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size, + size_t max_bytes, uint64_t max_staleness, + uint64_t stat_cache_max_age, size_t stat_cache_max_entries, + TF_Status* status); + +void Init(TF_Filesystem* filesystem, TF_Status* status); +void Cleanup(TF_Filesystem* filesystem); +void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, + TF_RandomAccessFile* file, TF_Status* status); +void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status); +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status); +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status); +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, + TF_Status* status); +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h new file mode 100644 index 00000000..dfe182e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
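A short sketch of calling the `ParseGCSPath` helper declared at the top of gcs_filesystem.h above:

```cpp
#include <string>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status.h"

// Sketch only: splits a GCS URI into its bucket and object components and
// checks the status reported by the parser.
bool SplitGcsUri() {
  std::string bucket, object;
  TF_Status* status = TF_NewStatus();
  ParseGCSPath("gs://my-bucket/models/ckpt.index", /*object_empty_ok=*/false,
               &bucket, &object, status);
  bool ok = TF_GetCode(status) == TF_OK;
  // On success: bucket == "my-bucket", object == "models/ckpt.index".
  TF_DeleteStatus(status);
  return ok;
}
```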
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_ + +#include +#include +#include + +class TempFile : public std::fstream { + public: + // We should specify openmode each time we call TempFile. + TempFile(const std::string& temp_file_name, std::ios::openmode mode); + TempFile(TempFile&& rhs); + ~TempFile() override; + const std::string getName() const; + bool truncate(); + + private: + const std::string name_; +}; + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h new file mode 100644 index 00000000..7e674722 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h @@ -0,0 +1,269 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/logging.h" +#include "tensorflow/c/tf_status.h" + +namespace tf_gcs_filesystem { + +/// \brief An LRU block cache of file contents, keyed by {filename, offset}. +/// +/// This class should be shared by read-only random access files on a remote +/// filesystem (e.g. GCS). +class RamFileBlockCache { + public: + /// The callback executed when a block is not found in the cache, and needs to + /// be fetched from the backing filesystem. This callback is provided when the + /// cache is constructed. It returns total bytes read ( -1 in case of errors + /// ). The `status` should be `TF_OK` as long as the read from the remote + /// filesystem succeeded (similar to the semantics of the read(2) system + /// call). 
+ typedef std::function + BlockFetcher; + + RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness, + BlockFetcher block_fetcher, + std::function timer_seconds = TF_NowSeconds) + : block_size_(block_size), + max_bytes_(max_bytes), + max_staleness_(max_staleness), + block_fetcher_(block_fetcher), + timer_seconds_(timer_seconds), + pruning_thread_(nullptr, + [](TF_Thread* thread) { TF_JoinThread(thread); }) { + if (max_staleness_ > 0) { + TF_ThreadOptions thread_options; + TF_DefaultThreadOptions(&thread_options); + pruning_thread_.reset( + TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); + } + TF_VLog(1, "GCS file block cache is %s.\n", + (IsCacheEnabled() ? "enabled" : "disabled")); + } + + ~RamFileBlockCache() { + if (pruning_thread_) { + stop_pruning_thread_.Notify(); + // Destroying pruning_thread_ will block until Prune() receives the above + // notification and returns. + pruning_thread_.reset(); + } + } + + /// Read `n` bytes from `filename` starting at `offset` into `buffer`. It + /// returns total bytes read ( -1 in case of errors ). This method will set + /// `status` to: + /// + /// 1) The error from the remote filesystem, if the read from the remote + /// filesystem failed. + /// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem + /// succeeded, + /// but the read returned a partial block, and the LRU cache contained a + /// block at a higher offset (indicating that the partial block should have + /// been a full block). + /// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but + /// the file contents do not extend past `offset` and thus nothing was + /// placed in `out`. + /// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was + /// placed + /// in `buffer`). + /// + /// Caller is responsible for allocating memory for `buffer`. + /// `buffer` will be left unchanged in case of errors. + int64_t Read(const std::string& filename, size_t offset, size_t n, + char* buffer, TF_Status* status); + + // Validate the given file signature with the existing file signature in the + // cache. Returns true if the signature doesn't change or the file doesn't + // exist before. If the signature changes, update the existing signature with + // the new one and remove the file from cache. + bool ValidateAndUpdateFileSignature(const std::string& filename, + int64_t file_signature) + ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all cached blocks for `filename`. + void RemoveFile(const std::string& filename) ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all cached data. + void Flush() ABSL_LOCKS_EXCLUDED(mu_); + + /// Accessors for cache parameters. + size_t block_size() const { return block_size_; } + size_t max_bytes() const { return max_bytes_; } + uint64_t max_staleness() const { return max_staleness_; } + + /// The current size (in bytes) of the cache. + size_t CacheSize() const ABSL_LOCKS_EXCLUDED(mu_); + + // Returns true if the cache is enabled. If false, the BlockFetcher callback + // is always executed during Read. + bool IsCacheEnabled() const { return block_size_ > 0 && max_bytes_ > 0; } + + // We can not pass a lambda with capture as a function pointer to + // `TF_StartThread`, so we have to wrap `Prune` inside a static function. + static void PruneThread(void* param) { + auto ram_file_block_cache = static_cast(param); + ram_file_block_cache->Prune(); + } + + private: + /// The size of the blocks stored in the LRU cache, as well as the size of the + /// reads from the underlying filesystem. 
+ const size_t block_size_; + /// The maximum number of bytes (sum of block sizes) allowed in the LRU cache. + const size_t max_bytes_; + /// The maximum staleness of any block in the LRU cache, in seconds. + const uint64_t max_staleness_; + /// The callback to read a block from the underlying filesystem. + const BlockFetcher block_fetcher_; + /// The callback to read timestamps. + const std::function timer_seconds_; + + /// \brief The key type for the file block cache. + /// + /// The file block cache key is a {filename, offset} pair. + typedef std::pair Key; + + /// \brief The state of a block. + /// + /// A block begins in the CREATED stage. The first thread will attempt to read + /// the block from the filesystem, transitioning the state of the block to + /// FETCHING. After completing, if the read was successful the state should + /// be FINISHED. Otherwise the state should be ERROR. A subsequent read can + /// re-fetch the block if the state is ERROR. + enum class FetchState { + CREATED, + FETCHING, + FINISHED, + ERROR, + }; + + /// \brief A block of a file. + /// + /// A file block consists of the block data, the block's current position in + /// the LRU cache, the timestamp (seconds since epoch) at which the block + /// was cached, a coordination lock, and state & condition variables. + /// + /// Thread safety: + /// The iterator and timestamp fields should only be accessed while holding + /// the block-cache-wide mu_ instance variable. The state variable should only + /// be accessed while holding the Block's mu lock. The data vector should only + /// be accessed after state == FINISHED, and it should never be modified. + /// + /// In order to prevent deadlocks, never grab the block-cache-wide mu_ lock + /// AFTER grabbing any block's mu lock. It is safe to grab mu without locking + /// mu_. + struct Block { + /// The block data. + std::vector data; + /// A list iterator pointing to the block's position in the LRU list. + std::list::iterator lru_iterator; + /// A list iterator pointing to the block's position in the LRA list. + std::list::iterator lra_iterator; + /// The timestamp (seconds since epoch) at which the block was cached. + uint64_t timestamp; + /// Mutex to guard state variable + absl::Mutex mu; + /// The state of the block. + FetchState state ABSL_GUARDED_BY(mu) = FetchState::CREATED; + /// Wait on cond_var if state is FETCHING. + absl::CondVar cond_var; + }; + + /// \brief The block map type for the file block cache. + /// + /// The block map is an ordered map from Key to Block. + typedef std::map> BlockMap; + + /// Prune the cache by removing files with expired blocks. + void Prune() ABSL_LOCKS_EXCLUDED(mu_); + + bool BlockNotStale(const std::shared_ptr& block) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Look up a Key in the block cache. + std::shared_ptr Lookup(const Key& key) ABSL_LOCKS_EXCLUDED(mu_); + + void MaybeFetch(const Key& key, const std::shared_ptr& block, + TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_); + + /// Trim the block cache to make room for another entry. + void Trim() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Update the LRU iterator for the block at `key`. + void UpdateLRU(const Key& key, const std::shared_ptr& block, + TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all blocks of a file, with mu_ already held. + void RemoveFile_Locked(const std::string& filename) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Remove the block `entry` from the block map and LRU list, and update the + /// cache size accordingly. 
+ void RemoveBlock(BlockMap::iterator entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// The cache pruning thread that removes files with expired blocks. + std::unique_ptr> pruning_thread_; + + /// Notification for stopping the cache pruning thread. + absl::Notification stop_pruning_thread_; + + /// Guards access to the block map, LRU list, and cached byte count. + mutable absl::Mutex mu_; + + /// The block map (map from Key to Block). + BlockMap block_map_ ABSL_GUARDED_BY(mu_); + + /// The LRU list of block keys. The front of the list identifies the most + /// recently accessed block. + std::list lru_list_ ABSL_GUARDED_BY(mu_); + + /// The LRA (least recently added) list of block keys. The front of the list + /// identifies the most recently added block. + /// + /// Note: blocks are added to lra_list_ only after they have successfully been + /// fetched from the underlying block store. + std::list lra_list_ ABSL_GUARDED_BY(mu_); + + /// The combined number of bytes in all of the cached blocks. + size_t cache_size_ ABSL_GUARDED_BY(mu_) = 0; + + // A filename->file_signature map. + std::map file_signature_map_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/copy_file.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/copy_file.h new file mode 100644 index 00000000..d7c2f970 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/copy_file.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_COPY_FILE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_COPY_FILE_H_ + +#include + +namespace tf_posix_filesystem { + +// Transfers up to `size` bytes from `dst_fd` to `src_fd`. +// +// This method uses `sendfile` if available (i.e., linux 2.6.33 or later) or an +// intermediate buffer if not. +// +// Returns number of bytes transferred or -1 on failure. +int CopyFileContents(int dst_fd, int src_fd, off_t size); + +} // namespace tf_posix_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_COPY_FILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h new file mode 100644 index 00000000..0a444ef8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
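Pulling the pieces of ram_file_block_cache.h above together, a minimal sketch that feeds the cache from an in-memory fetcher. The fetcher's parameter list mirrors the `BlockFetcher` typedef above (its template arguments are stripped in this flattened diff); the file name and contents are illustrative.

```cpp
#include <algorithm>
#include <cstring>
#include <string>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include "tensorflow/c/tf_status.h"

// Sketch only: serves reads from a fixed in-memory string so the caching
// behavior can be exercised without a remote filesystem.
int64_t FakeFetch(const std::string& filename, size_t offset, size_t n,
                  char* buffer, TF_Status* status) {
  static const std::string kContents = "hello, block cache";
  TF_SetStatus(status, TF_OK, "");
  if (offset >= kContents.size()) return 0;  // nothing past EOF
  size_t copied = std::min(n, kContents.size() - offset);
  std::memcpy(buffer, kContents.data() + offset, copied);
  return static_cast<int64_t>(copied);
}

void BlockCacheExample() {
  // 4-byte blocks, at most 16 cached bytes, no staleness bound (no pruning).
  tf_gcs_filesystem::RamFileBlockCache cache(
      /*block_size=*/4, /*max_bytes=*/16, /*max_staleness=*/0, FakeFetch);

  char out[8];
  TF_Status* status = TF_NewStatus();
  int64_t read = cache.Read("gs://bucket/object", /*offset=*/0, sizeof(out),
                            out, status);
  (void)read;  // 8 on success; `status` carries the codes documented above
  TF_DeleteStatus(status);
}
```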
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_ + +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" + +// Initialize the POSIX filesystem. +// +// In general, the `TF_InitPlugin` symbol doesn't need to be exposed in a header +// file, since the plugin registration will look for the symbol in the DSO file +// that provides the filesystem functionality. However, the POSIX filesystem +// needs to be statically registered in some tests and utilities for building +// the API files at the time of creating the pip package. Hence, we need to +// expose this function so that this filesystem can be statically registered +// when needed. +void TF_InitPlugin(TF_FilesystemPluginInfo* info); + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h new file mode 100644 index 00000000..612366ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_HELPER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_HELPER_H_ + +#include <dirent.h> +#include <sys/types.h> + +namespace tf_posix_filesystem { + +// Copies up to `size` bytes of `src` to `dst`, creating the destination if needed. +// +// Callers should pass the size of `src` in `size` and the permissions of `src` in +// `mode`. The latter is only used if `dst` needs to be created. +int TransferFileContents(const char* src, const char* dst, mode_t mode, + off_t size); + +// Returns true only if `entry` points to an entry other than `.` or `..`. +// +// This is a filter for `scandir`.
+int RemoveSpecialDirectoryEntries(const struct dirent* entry); + +} // namespace tf_posix_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_POSIX_POSIX_FILESYSTEM_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/array_grad.h b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/array_grad.h new file mode 100644 index 00000000..3dcf98b0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/array_grad.h @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +GradientFunction* IdentityNRegisterer(const ForwardOperation& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/grad_test_helper.h b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/grad_test_helper.h new file mode 100644 index 00000000..84761f96 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/grad_test_helper.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
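A small sketch of the `scandir` filter usage described for `RemoveSpecialDirectoryEntries` above; the listing function itself is illustrative.

```cpp
#include <dirent.h>

#include <cstdlib>

#include "tensorflow/c/experimental/filesystem/plugins/posix/posix_filesystem_helper.h"

// Sketch only: lists a directory's children while skipping "." and "..".
int CountChildren(const char* path) {
  struct dirent** entries = nullptr;
  int n = scandir(path, &entries,
                  tf_posix_filesystem::RemoveSpecialDirectoryEntries, alphasort);
  if (n < 0) return -1;  // errno describes the failure
  for (int i = 0; i < n; ++i) {
    // entries[i]->d_name is the name of a child of `path`.
    free(entries[i]);
  }
  free(entries);
  return n;
}
```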
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ + +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/unified_api_testutil.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +void CompareNumericalAndAutodiffGradients( + Model model, Model grad_model, AbstractContext* ctx, + absl::Span inputs, bool use_function, + double abs_error = 1e-2); + +void CheckTensorValue(AbstractTensorHandle* t, absl::Span manuals, + absl::Span dims, double abs_error = 1e-2); + +Model BuildGradModel(Model forward, GradientRegistry registry); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/math_grad.h b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/math_grad.h new file mode 100644 index 00000000..e26ee899 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/math_grad.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { + +GradientFunction* AddRegisterer(const ForwardOperation& op); +GradientFunction* ExpRegisterer(const ForwardOperation& op); +GradientFunction* MatMulRegisterer(const ForwardOperation& op); +GradientFunction* SqrtRegisterer(const ForwardOperation& op); +GradientFunction* NegRegisterer(const ForwardOperation& op); +GradientFunction* SubRegisterer(const ForwardOperation& op); +GradientFunction* MulRegisterer(const ForwardOperation& op); +GradientFunction* Log1pRegisterer(const ForwardOperation& op); +GradientFunction* DivNoNanRegisterer(const ForwardOperation& op); + +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/nn_grad.h b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/nn_grad.h new file mode 100644 index 00000000..2a635f54 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/nn_grad.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +GradientFunction* ReluRegisterer(const ForwardOperation& op); +GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( + const ForwardOperation& op); +GradientFunction* BiasAddRegisterer(const ForwardOperation& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/not_differentiable.h b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/not_differentiable.h new file mode 100644 index 00000000..7167340a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/not_differentiable.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +// Ignores `grad_outputs` and sets all entries in grad_inputs to nullptr. +class NotDifferentiableGradientFunction : public GradientFunction { + absl::Status Compute(AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) override; +}; +// Shorthand for registry->Register(op, new NotDifferentiableGradientFunction) +absl::Status RegisterNotDifferentiable(GradientRegistry* registry, + const string& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/tape/tape_context.h b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/tape/tape_context.h new file mode 100644 index 00000000..f92c35f2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/tape/tape_context.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
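To show how the gradient registerers above are meant to be consumed, a sketch that fills a `GradientRegistry`; the `Register` signature is assumed from `tensorflow/c/eager/gradients.h`, and "Shape" is only an example of an op treated as not differentiable.

```cpp
#include "absl/status/status.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/gradients/not_differentiable.h"

// Sketch only: maps op names to the gradient-function factories declared above.
absl::Status RegisterExampleGradients(
    tensorflow::gradients::GradientRegistry* registry) {
  namespace g = tensorflow::gradients;
  absl::Status s = registry->Register("MatMul", g::MatMulRegisterer);
  if (!s.ok()) return s;
  s = registry->Register("Relu", g::ReluRegisterer);
  if (!s.ok()) return s;
  // Ops without a gradient go through the NotDifferentiable shorthand.
  return g::RegisterNotDifferentiable(registry, "Shape");
}
```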
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ + +#include "absl/status/status.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_function.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gradients { +class TapeContext : public AbstractContext { + public: + explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&); + void Release() override; + TapeOperation* CreateOperation() override; + absl::Status RegisterFunction(AbstractFunction*) override; + absl::Status RemoveFunction(const string& func) override; + // For LLVM style RTTI. + static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kTape; + } + ~TapeContext() override; + + private: + AbstractContext* parent_ctx_; // Not owned. + Tape* tape_; + const GradientRegistry& registry_; +}; +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/tape/tape_operation.h b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/tape/tape_operation.h new file mode 100644 index 00000000..8f447440 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/gradients/tape/tape_operation.h @@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gradients { +class TapeOperation : public AbstractOperation { + public: + explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&); + void Release() override; + absl::Status Reset(const char* op, const char* raw_device_name) override; + const string& Name() const override; + const string& DeviceName() const override; + absl::Status SetDeviceName(const char* name) override; + absl::Status AddInput(AbstractTensorHandle* input) override; + absl::Status AddInputList( + absl::Span inputs) override; + absl::Status Execute(absl::Span retvals, + int* num_retvals) override; + absl::Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + absl::Status SetAttrInt(const char* attr_name, int64_t value) override; + absl::Status SetAttrFloat(const char* attr_name, float value) override; + absl::Status SetAttrBool(const char* attr_name, bool value) override; + absl::Status SetAttrType(const char* attr_name, DataType value) override; + absl::Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + absl::Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + absl::Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override; + absl::Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + absl::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) override; + absl::Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + absl::Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + absl::Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + absl::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) override; + absl::Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + absl::Status SetAttrFunctionList( + const char* attr_name, + absl::Span values) override; + AbstractOperation* GetBackingOperation(); + // For LLVM style RTTI. 
+ static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kTape; + } + ~TapeOperation() override; + + private: + AbstractOperation* parent_op_; + ForwardOperation forward_op_; + Tape* tape_; + const GradientRegistry& registry_; +}; + +} // namespace gradients +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/grappler/grappler.h b/third_party/tflite-hdrs/tensorflow/c/experimental/grappler/grappler.h new file mode 100644 index 00000000..0a293c66 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/grappler/grappler.h @@ -0,0 +1,294 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_ + +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_buffer.h" +#include "tensorflow/c/tf_status.h" + +// -------------------------------------------------------------------------- +// C API for Graph. The API is under active development and eventually +// should allow registering a plugin graph optimizer with TensorFlow. +// +// Conventions: +// * Struct prefix indicates whether struct fields should be filled by the +// plugin or core implementation: +// * Struct that should be filled by the plugin: `TP_OptimizerConfigs`, +// `TP_Optimizer`, `TP_OptimizerRegistrationParams` +// * Struct that should be filled by the proper: `TF_GrapplerItem`, +// `TF_GraphProperties`, `TF_FunctionLibraryDefinition` +// * We use `struct_size` for version checking. It should be set both by +// core and the plugin. +// * For example, `TF_InitGraph` function receives +// `TP_OptimizerRegistrationParams*` as input with `struct_size` +// populated by core. The plugin is responsible for setting +// `struct_size` as well, along with all other fields. +// * Refer to "TensorFlow Versioning Strategy" section at +// https://github.com/tensorflow/community/pull/257/files. +// * Note that the API is still under active development and doesn't have +// versioning guarantees yet. +// * `void* ext` is a free-form field that can be populated by +// a plugin in `TP_*` structs or potential future extension points . +// +// Example usage: +// +// /* Sample TensorFlow code below, exact implementation might differ. */ +// // Version checking uses `struct_size`. It should be set both by core +// // and the plugin. 
+// TP_OptimizerRegistrationParams params{ +// TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE}; +// TP_Optimizer optimizer{TP_OPTIMIZER_STRUCT_SIZE}; +// TP_OptimizerConfigs configs{TP_OPTIMIZER_CONFIGS_STRUCT_SIZE}; +// params.optimizer = &optimizer; +// params.configs = &configs; +// +// /* Plugin code below */ +// void TF_InitGraph(TP_OptimizerRegistrationParams* params, +// TF_Status* status) { +// params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE; +// params->device_type = "MY_DEVICE"; +// +// // Disable certain optimizer. +// params->optimizer_configs->struct_size = +// TP_OPTIMIZER_CONFIGS_STRUCT_SIZE; params->optimizer_configs->remapping = +// TF_TriState_Off; +// +// // Set functions to create a new optimizer. +// params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE; +// params->optimizer->create_func = (My_optimizer::create_func); +// } + +#define GO_MAJOR 0 +#define GO_MINOR 0 +#define GO_PATCH 1 + +#ifdef __cplusplus +extern "C" { +#endif + +// TF_TriState is the C API typedef for tri-state. +typedef enum TF_TriState { + TF_TriState_Default = 0, + TF_TriState_Off, + TF_TriState_On, +} TF_TriState; + +// TF_GrapplerItem represents a combination of a graph, one of more fetch nodes, +// and potentially a set of nodes to feed. +typedef struct TF_GrapplerItem TF_GrapplerItem; + +// Flags indicating whether existing optimizers should be turned off. +// It's optional for plugin to set functions to return true/false. If not +// set, proper uses configuration set by user. +typedef struct TP_OptimizerConfigs { + size_t struct_size; + void* ext; // reserved for future use + TF_TriState disable_model_pruning; + TF_TriState implementation_selector; + TF_TriState function_optimization; + TF_TriState common_subgraph_elimination; + TF_TriState arithmetic_optimization; + TF_TriState debug_stripper; + TF_TriState constant_folding; + TF_TriState shape_optimization; + TF_TriState auto_mixed_precision; + TF_TriState auto_mixed_precision_onednn_bfloat16; + TF_TriState auto_mixed_precision_mkl; + TF_TriState pin_to_host_optimization; + TF_TriState layout_optimizer; + TF_TriState remapping; + TF_TriState loop_optimization; + TF_TriState dependency_optimization; + TF_TriState auto_parallel; + TF_TriState memory_optimization; + TF_TriState scoped_allocator_optimization; +} TP_OptimizerConfigs; + +#define TP_OPTIMIZER_CONFIGS_STRUCT_SIZE \ + TF_OFFSET_OF_END(TP_OptimizerConfigs, scoped_allocator_optimization) + +// Struct for Optimizer. Plugin authors must provide an optimize function. +// Creation and deletion functions are optional. +typedef struct TP_Optimizer { + size_t struct_size; + void* ext; // reserved for future use + + // [Optional] + // Create function for optimizer. + void* (*create_func)(); + + // Optimizer function for optimizer. The first param is an optimizer created + // by create_func. The second param is input graph. The third param is + // GrapplerItem. The fourth param is output graph. + void (*optimize_func)(void*, const TF_Buffer*, const TF_GrapplerItem*, + TF_Buffer*, TF_Status*); + + // [Optional] + // Destroy function for optimizer. If Create function is provided, destroy + // function is must. + void (*destroy_func)(void*); +} TP_Optimizer; + +#define TP_OPTIMIZER_STRUCT_SIZE TF_OFFSET_OF_END(TP_Optimizer, destroy_func) + +typedef struct TP_OptimizerRegistrationParams { + size_t struct_size; + void* ext; // reserved for future use + + // Graph C API version. 
+ int32_t major_version; + int32_t minor_version; + int32_t patch_version; + + // Backend device type supported by the optimizer. + const char* device_type; + TP_OptimizerConfigs* optimizer_configs; // output, set by plugin + TP_Optimizer* optimizer; // output, set by plugin +} TP_OptimizerRegistrationParams; + +#define TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(TP_OptimizerRegistrationParams, optimizer) + +// TF_InitGraph is used to do graph optimizer registration. +// Plugin should implement TF_InitGraph to register graph optimizers. +void TF_InitGraph(TP_OptimizerRegistrationParams* params, TF_Status* status); + +// Get a set of node names that must be preserved. They can not be transformed +// or removed during the graph transformation. This includes feed and fetch +// nodes, keep_ops, init_ops. Fills in `num_values` and `storage_size`, they +// will be used in `TF_GetNodesToPreserveList`. +TF_CAPI_EXPORT extern void TF_GetNodesToPreserveListSize( + const TF_GrapplerItem* item, int* num_values, size_t* storage_size, + TF_Status* status); + +// Get a set of node names that must be preserved. They can not be transformed +// or removed during the graph transformation. This includes feed and fetch +// nodes, keep_ops, init_ops. Fills in `values` and `lengths`, each of which +// must point to an array of length at least `num_values`. +// +// The elements of values will point to addresses in `storage` which must be at +// least `storage_size` bytes in length. `num_values` and `storage` can be +// obtained from TF_GetNodesToPreserveSize +// +// Fails if storage_size is too small to hold the requested number of strings. +TF_CAPI_EXPORT extern void TF_GetNodesToPreserveList( + const TF_GrapplerItem* item, char** values, size_t* lengths, int num_values, + void* storage, size_t storage_size, TF_Status* status); + +// Get a set of node names for fetch nodes. Fills in `values` and `lengths`, +// they will be used in `TF_GetFetchNodesList` +TF_CAPI_EXPORT extern void TF_GetFetchNodesListSize(const TF_GrapplerItem* item, + int* num_values, + size_t* storage_size, + TF_Status* status); + +// Get a set of node names for fetch nodes. Fills in `values` and `lengths`, +// each of which must point to an array of length at least `num_values`. +// +// The elements of values will point to addresses in `storage` which must be at +// least `storage_size` bytes in length. `num_values` and `storage` can be +// obtained from TF_GetFetchNodesSize +// +// Fails if storage_size is too small to hold the requested number of strings. +TF_CAPI_EXPORT extern void TF_GetFetchNodesList(const TF_GrapplerItem* item, + char** values, size_t* lengths, + int num_values, void* storage, + size_t storage_size, + TF_Status* status); + +// Infer OpInfo::TensorProperties for graph nodes inputs/outputs. +// +// Typical use case, is to infer tensor properties from a graph, before doing +// optimization pass. Nodes modified during optimization pass have to be +// invalidated, to prevent further incorrect optimizations based on wrong shape +// and data type properties. +typedef struct TF_GraphProperties TF_GraphProperties; + +// Create GraphProperties. The item must outlive the properties. +TF_CAPI_EXPORT extern TF_GraphProperties* TF_NewGraphProperties( + const TF_GrapplerItem* item); + +// Delete GraphProperties. +TF_CAPI_EXPORT extern void TF_DeleteGraphProperties( + TF_GraphProperties* graph_properties); + +// Infer tensor shapes through abstract interpretation. 
+// If assume_valid_feeds is true, it can help infer shapes in the fanout of fed +// nodes. This may cause incorrectness in graph analyses, but is useful for +// simulation or scheduling. +// If aggressive_shape_inference is true, nodes are executed on the host to +// identify output values when possible and does other aggressive strategies. +// This may cause incorrectness in graph analyses, but is useful for simulation +// or scheduling. +// If include_input_tensor_values is true, the values of constant +// tensors will included in the input properties. +// If include_output_tensor_values is true, the values of constant tensors will +// be included in the output properties. +TF_CAPI_EXPORT extern void TF_InferStatically( + TF_GraphProperties* graph_properties, TF_Bool assume_valid_feeds, + TF_Bool aggressive_shape_inference, TF_Bool include_input_tensor_values, + TF_Bool include_output_tensor_values, TF_Status* s); + +// Get the size of input OpInfo::TensorProperties given node name. +TF_CAPI_EXPORT extern void TF_GetInputPropertiesListSize( + TF_GraphProperties* graph_properties, const char* name, int* num_values, + TF_Status* status); + +// Get the size of output OpInfo::TensorProperties given node name. +TF_CAPI_EXPORT extern void TF_GetOutputPropertiesListSize( + TF_GraphProperties* graph_properties, const char* name, int* num_values, + TF_Status* status); + +// Get a list of input OpInfo::TensorProperties given node name. +// Return the serialized list `properties`. +TF_CAPI_EXPORT extern void TF_GetInputPropertiesList( + TF_GraphProperties* graph_properties, const char* name, + TF_Buffer** properties, int num_values, TF_Status* status); + +// Get a list of output OpInfo::TensorProperties given node name. +// Return the serialized list `properties`. +TF_CAPI_EXPORT extern void TF_GetOutputPropertiesList( + TF_GraphProperties* graph_properties, const char* name, + TF_Buffer** properties, int num_values, TF_Status* status); + +// Helper to maintain a map between function names in a given +// FunctionDefLibrary and function definitions. +// Typical use case, is to look up an OpDef by type name. +typedef struct TF_FunctionLibraryDefinition TF_FunctionLibraryDefinition; + +// Create NewFunctionLibraryDefinition. +TF_CAPI_EXPORT extern TF_FunctionLibraryDefinition* +TF_NewFunctionLibraryDefinition(const TF_Buffer* graph_buf, TF_Status* status); + +// Delete NewFunctionLibraryDefinition. +TF_CAPI_EXPORT extern void TF_DeleteFunctionLibraryDefinition( + TF_FunctionLibraryDefinition* fn_lib); + +// Shorthand for calling LookUp to get the OpDef from FunctionLibraryDefinition +// given op name. The returned OpDef is represented by TF_Buffer. +TF_CAPI_EXPORT extern void TF_LookUpOpDef(TF_FunctionLibraryDefinition* fn_lib, + const char* name, TF_Buffer* buf, + TF_Status* s); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/grappler/grappler_internal.h b/third_party/tflite-hdrs/tensorflow/c/experimental/grappler/grappler_internal.h new file mode 100644 index 00000000..799d3bef --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/grappler/grappler_internal.h @@ -0,0 +1,104 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
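Complementing the in-header example above (which stops at `create_func`), a sketch of the plugin side with a pass-through `optimize_func`; "MY_DEVICE" and the optimizer bodies are illustrative only.

```cpp
#include <cstdlib>
#include <cstring>

#include "tensorflow/c/experimental/grappler/grappler.h"

namespace {

// Sketch only: opaque plugin state and a no-op optimizer.
void* CreateOptimizer() { return new int(0); }
void DestroyOptimizer(void* optimizer) { delete static_cast<int*>(optimizer); }

// Copies the serialized input GraphDef to the output buffer unchanged.
void OptimizeGraph(void* optimizer, const TF_Buffer* input_graph,
                   const TF_GrapplerItem* item, TF_Buffer* output_graph,
                   TF_Status* status) {
  (void)optimizer;
  (void)item;
  void* copy = std::malloc(input_graph->length);
  std::memcpy(copy, input_graph->data, input_graph->length);
  output_graph->data = copy;
  output_graph->length = input_graph->length;
  output_graph->data_deallocator = [](void* data, size_t) { std::free(data); };
  TF_SetStatus(status, TF_OK, "");
}

}  // namespace

void TF_InitGraph(TP_OptimizerRegistrationParams* params, TF_Status* status) {
  params->struct_size = TP_OPTIMIZER_REGISTRATION_PARAMS_STRUCT_SIZE;
  params->device_type = "MY_DEVICE";  // illustrative backend name
  params->optimizer_configs->struct_size = TP_OPTIMIZER_CONFIGS_STRUCT_SIZE;
  params->optimizer->struct_size = TP_OPTIMIZER_STRUCT_SIZE;
  params->optimizer->create_func = CreateOptimizer;
  params->optimizer->optimize_func = OptimizeGraph;
  params->optimizer->destroy_func = DestroyOptimizer;
  TF_SetStatus(status, TF_OK, "");
}
```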
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Classes and utilities that work with Graph C API for internal use.
+// This includes functions used for optimizer registration and interfaces needed
+// for testing.
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
+#define TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/experimental/grappler/grappler.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Plugin initialization function that a device plugin
+// must define.
+typedef void (*TFInitGraphPluginFn)(TP_OptimizerRegistrationParams* const,
+                                    TF_Status* const);
+
+// Registers Graph optimizers.
+Status InitGraphPlugin(void* dso_handle);
+
+// Allow registering a graph optimizer using a function (used for
+// testing).
+Status InitGraphPlugin(TFInitGraphPluginFn init_fn);
+
+struct GrapplerItem;
+class Cluster;
+
+struct TFStatusDeleter {
+  void operator()(TF_Status* s) const { TF_DeleteStatus(s); }
+};
+using OwnedTFStatus = std::unique_ptr<TF_Status, TFStatusDeleter>;
+
+struct TFBufferDeleter {
+  void operator()(TF_Buffer* buf) const { TF_DeleteBuffer(buf); }
+};
+using OwnedTFBuffer = std::unique_ptr<TF_Buffer, TFBufferDeleter>;
+
+class CGraphOptimizer : public CustomGraphOptimizer {
+ public:
+  explicit CGraphOptimizer(TP_Optimizer optimizer, const char* device_type)
+      : optimizer_(optimizer), device_type_(device_type) {
+    if (optimizer.create_func != nullptr) {
+      c_optimizer_ = (*optimizer_.create_func)();
+    } else {
+      c_optimizer_ = nullptr;
+    }
+  }
+  std::string name() const override { return "PluggableGraphOptimizer"; }
+  bool UsesFunctionLibrary() const override { return false; }
+  Status Init(
+      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+    return OkStatus();
+  }
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph_def) override;
+
+  ~CGraphOptimizer() override {
+    if (optimizer_.destroy_func != nullptr) {
+      (*optimizer_.destroy_func)(c_optimizer_);
+    }
+  }
+
+ private:
+  TP_Optimizer optimizer_;
+  std::string device_type_;
+  void* c_optimizer_;
+};
+
+// Registration function to register a CGraphOptimizer along with plugin configs
+// and device type.
+void CGraphOptimizerRegister( + const PluginGraphOptimizerRegistry::Creator& creator, + const TP_OptimizerConfigs tp_configs, const char* device_type); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRAPPLER_GRAPPLER_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/next_pluggable_device/c_api.h b/third_party/tflite-hdrs/tensorflow/c/experimental/next_pluggable_device/c_api.h new file mode 100644 index 00000000..036d33dc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/next_pluggable_device/c_api.h @@ -0,0 +1,156 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/kernels_experimental.h" +#include "tensorflow/c/tf_buffer.h" +#include "tensorflow/c/tf_status.h" +#include "xla/pjrt/c/pjrt_c_api.h" + +// -------------------------------------------------------------------------- +// C API for device. The API is under active development and eventually +// should allow registering a plugin device with TensorFlow. + +#ifdef __cplusplus +extern "C" { +#endif + +// TF_Device is a C wrapper to the C++ TF Device class. This is to be passed +// through TF_OpKernelContext, and is opaque to plugin. +typedef struct TF_Device TF_Device; + +typedef struct TF_VariableInfo TF_VariableInfo; + +// Returns a `TF_Device` pointer, which actually points to a C++ `Device`. +// Currently we only allow `NextPluggableDevice` to be casted as `TF_Device`, +// but in theory every this is a C API for every kind of device. +TF_CAPI_EXPORT extern TF_Device* TF_GetDevice(TF_OpKernelContext* ctx); + +// -------------------------- Resource --------------------------------------- +// Create a `tensorflow::PluginResource` to the ResourceMgr provided by the +// `ctx`. The `tensorflow::PluginResource` wraps a resource by plugin (as a +// opaque pointer, since TensorFlow cannot parse it). `delete_func` is needed +// for ResourceMgr to clean up the resource. `status` will be set. +TF_CAPI_EXPORT extern void TF_CreatePluginResource( + TF_OpKernelContext* ctx, const char* container_name, + const char* plugin_resource_name, void* plugin_resource, + void (*delete_func)(void*), TF_Status* status); + +// If the ResourceMgr provided by the `ctx` has a resource +// `plugin_resource_name`, returns it in `*result_plugin_resource`. Otherwise, +// invokes create_func to create the resource. `delete_func` is needed for +// ResourceMgr to clean up the resource. `status` will be set. If `status` is +// not OK, `*result_plugin_resource` will be set as nullptr. +// +// Caller does not take ownership of the `plugin_resource`. 
+TF_CAPI_EXPORT extern void TF_LookupOrCreatePluginResource( + TF_OpKernelContext* ctx, const char* container_name, + const char* plugin_resource_name, void** result_plugin_resource, + void* (*create_func)(void*), void* create_func_args, + void (*delete_func)(void*), TF_Status* status); + +// ------------------------- VariableInfo ------------------------------------ +TF_CAPI_EXPORT extern TF_VariableInfo* TF_CreateVariableInfoFromContext( + TF_OpKernelContext* ctx, int index, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_LockVariableInfos(TF_VariableInfo** vars, + int num_vars, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_AllocateTempForVariableInfo( + TF_OpKernelContext* ctx, TF_VariableInfo* var_info, TF_Status* status); + +TF_CAPI_EXPORT extern TF_Tensor* TF_GetTensorFromVariableInfo( + TF_VariableInfo* var_info, TF_Status* status); + +TF_CAPI_EXPORT extern void TF_DeleteVariableInfo(TF_VariableInfo* var_info); + +// --------------------- Coordination service -------------------------------- +// Returns a not owning pointer to the coordination service agent, which is +// opaque to plugin. Plugin OpKernels need to use the accompanying C APIs to +// access coordination service functionalities. +TF_CAPI_EXPORT extern TF_CoordinationServiceAgent* +TF_GetCoordinationServiceAgent(TF_OpKernelContext* ctx); + +// Returns true if the coordination service agent has been initialized. +TF_CAPI_EXPORT extern bool TF_CoordinationServiceIsInitialized( + TF_CoordinationServiceAgent* agent); + +TF_CAPI_EXPORT extern void TF_CoordinationServiceInsertKeyValue( + const char* key, int64_t key_size, const char* value, int64_t value_size, + TF_CoordinationServiceAgent* agent, TF_Status* status); + +// Obtains key-value from coordination service agent. The returned `TF_Buffer` +// is a newly allocated buffer to hold the string key-value, and caller is +// responsible for managing the lifetime. If error, `status` will be set and a +// nullptr will be returned. +TF_CAPI_EXPORT extern TF_Buffer* TF_CoordinationServiceGetKeyValue( + const char* key, int64_t key_size, TF_CoordinationServiceAgent* agent, + TF_Status* status); + +TF_CAPI_EXPORT extern TF_Buffer* TF_CoordinationServiceGetKeyValueWithTimeout( + const char* key, int64_t key_size, int64_t timeout_seconds, + TF_CoordinationServiceAgent* agent, TF_Status* status); + +TF_CAPI_EXPORT extern TF_Buffer* TF_CoordinationServiceTryGetKeyValue( + const char* key, int64_t key_size, TF_CoordinationServiceAgent* agent, + TF_Status* status); + +TF_CAPI_EXPORT extern void TF_CoordinationServiceDeleteKeyValue( + const char* key, int64_t key_size, TF_CoordinationServiceAgent* agent, + TF_Status* status); + +// ---------------------------- PJRT ----------------------------------------- +// Passes the pointer to a vector of PJRT_NamedValue and number of options to +// set options for creating a PJRT client. Passes nullptr for create_options and +// 0 for num_options if no options need to be set. You can use +// ConvertToPjRtNamedValueList in +// tensorflow/compiler/xla/pjrt/c/pjrt_c_api_helpers.h to generate the options. +TF_CAPI_EXPORT extern void TF_CreateAndSetPjRtCApiClient( + const char* device_type, TF_Status* status, PJRT_NamedValue* create_options, + int num_options); + +// Resets the PjRt client for a device. After this, `TF_GetPjRtCClient` will +// returns an error for that device. +TF_CAPI_EXPORT extern void TF_ResetPjRtCClient(const char* device_type, + TF_Status* status); + +// Gets the `PJRT_Client*` stored in TF global ResourceManager. 
+TF_CAPI_EXPORT extern PJRT_Client* TF_GetPjRtCClient(const char* device_type,
+                                                     TF_Status* status);
+
+// Gets the `PJRT_Buffer*` stored in the tensor. The status will contain error
+// if the tensor does not have a `PjRtCApiBuffer`.
+TF_CAPI_EXPORT extern PJRT_Buffer* TF_GetPjRtCBuffer(TF_Tensor* c_tensor,
+                                                     TF_Status* status);
+
+// Creates a `PjRtCApiBuffer` with the `PJRT_Buffer*` passed in and set to the
+// tensor.
+TF_CAPI_EXPORT extern void TF_CreatePjRtBuffer(TF_Tensor* c_tensor,
+                                               PJRT_Buffer* c_buffer,
+                                               const char* device_type,
+                                               TF_Status* status);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_C_API_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h b/third_party/tflite-hdrs/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h
new file mode 100644
index 00000000..c2378b68
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util.h
@@ -0,0 +1,39 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_TENSOR_PJRT_BUFFER_UTIL_H_
+#define TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_TENSOR_PJRT_BUFFER_UTIL_H_
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "xla/pjrt/c/pjrt_c_api.h"
+#include "xla/pjrt/pjrt_c_api_client.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+absl::StatusOr<PJRT_Buffer*> GetPjRtCBufferFromTensor(const Tensor* tensor);
+
+absl::Status SetPjRtCBufferToTensor(PJRT_Buffer* c_buffer,
+                                    xla::PjRtCApiClient* c_api_client,
+                                    Tensor* tensor);
+
+absl::StatusOr<xla::PjRtCApiClient*> GetPjRtCApiClient(
+    const DeviceType& device_type);
+
+absl::Status ResetPjRtClient(const DeviceType& device_type);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_NEXT_PLUGGABLE_DEVICE_TENSOR_PJRT_BUFFER_UTIL_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/array_ops.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/array_ops.h
new file mode 100644
index 00000000..0af99e9f
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/array_ops.h
@@ -0,0 +1,70 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
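The coordination-service getters declared in next_pluggable_device/c_api.h above return caller-owned TF_Buffer objects. A minimal sketch of the lookup-and-release pattern, assuming a valid `agent` obtained from TF_GetCoordinationServiceAgent (the key is a placeholder):

  TF_Status* status = TF_NewStatus();
  const char key[] = "worker/0/ready";  // placeholder key
  TF_Buffer* value =
      TF_CoordinationServiceGetKeyValue(key, sizeof(key) - 1, agent, status);
  if (TF_GetCode(status) == TF_OK && value != nullptr) {
    // value->data / value->length hold the stored bytes; the caller must
    // release the buffer when done.
    TF_DeleteBuffer(value);
  }
  TF_DeleteStatus(status);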
+==============================================================================*/
+
+// This file is MACHINE GENERATED! Do not edit.
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
+#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
+namespace ops {
+
+// Return a tensor with the same shape and contents as the input tensor or
+// value.
+absl::Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input,
+                      AbstractTensorHandle** output, const char* name = nullptr,
+                      const char* raw_device_name = nullptr);
+
+// Returns a list of tensors with the same shapes and contents as the input
+absl::Status IdentityN(AbstractContext* ctx,
+                       absl::Span<AbstractTensorHandle* const> input,
+                       absl::Span<AbstractTensorHandle*> output,
+                       const char* name = nullptr,
+                       const char* raw_device_name = nullptr);
+
+// Returns a tensor of zeros with the same shape and type as x.
+absl::Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* const x,
+                       AbstractTensorHandle** y, const char* name = nullptr,
+                       const char* raw_device_name = nullptr);
+
+// Returns the shape of a tensor.
+absl::Status Shape(AbstractContext* ctx, AbstractTensorHandle* const input,
+                   AbstractTensorHandle** output, DataType out_type = DT_INT32,
+                   const char* name = nullptr,
+                   const char* raw_device_name = nullptr);
+
+// Inserts a dimension of 1 into a tensor's shape.
+absl::Status ExpandDims(AbstractContext* ctx, AbstractTensorHandle* const input,
+                        AbstractTensorHandle* const dim,
+                        AbstractTensorHandle** output,
+                        const char* name = nullptr,
+                        const char* raw_device_name = nullptr);
+
+// Returns a tensor of ones with the same shape and type as x.
+absl::Status OnesLike(AbstractContext* ctx, AbstractTensorHandle* const x,
+                      AbstractTensorHandle** y, const char* name = nullptr,
+                      const char* raw_device_name = nullptr);
+
+}  // namespace ops
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/case_format.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/case_format.h
new file mode 100644
index 00000000..f8255f6a
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/case_format.h
@@ -0,0 +1,46 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CASE_FORMAT_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CASE_FORMAT_H_ + +#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +// Conversion routines between upper/lower camel/snake case formats, e.g.: +// "lowerCamelCase" +// "lower_snake_case" +// "UpperCamelCase" +// "UPPER_SNAKE_CASE" +// +// The input format is automatically detected. +// The delimiter must be specified if it is other than an underscore ('_') +// for conversion either *to* or *from* snake case. +// +// Leading and trailing delimiters are supported, e.g.: +// "__OneTwo__" (in camel case) <==> "__ONE_TWO__" (in snake case) +// +// Note: performance not yet tested. +string toLowerCamel(const string &s, const char delimiter = '_'); +string toLowerSnake(const string &s, const char delimiter = '_'); +string toUpperCamel(const string &s, const char delimiter = '_'); +string toUpperSnake(const string &s, const char delimiter = '_'); + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CASE_FORMAT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/controller.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/controller.h new file mode 100644 index 00000000..e152efeb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/controller.h @@ -0,0 +1,57 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
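A short sketch of the case_format.h helpers above, assuming the default '_' delimiter (the inputs and expected results are illustrative, inferred from the format descriptions in the header comment):

  using tensorflow::generator::toUpperCamel;
  using tensorflow::generator::toUpperSnake;
  const std::string camel = toUpperCamel("add_v2");  // expected "AddV2"
  const std::string snake = toUpperSnake("AddV2");   // expected "ADD_V2"

The input format is detected automatically, so the same call works whether the argument arrives in camel or snake case.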
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CONTROLLER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CONTROLLER_H_ + +#include + +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "tensorflow/c/experimental/ops/gen/model/op_spec.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +class Controller { + public: + explicit Controller(PathConfig path_config, Env* env = Env::Default()); + virtual ~Controller(); + const void WriteFile(const string& file_path, const SourceCode& code) const; + const std::vector& GetModelOps() const; + + private: + void InitializeOpApi(); + void BuildModel(); + + // Data model: Ops to generate + std::vector operators_; + + // Configuration + Env* env_; + PathConfig path_config_; + + // Initialized TensorFlow Op/API definitions + OpList op_list_; + ApiDefMap* api_def_map_; +}; + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_CONTROLLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/path_config.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/path_config.h new file mode 100644 index 00000000..ce29063b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/path_config.h @@ -0,0 +1,42 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_PATH_CONFIG_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_PATH_CONFIG_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +struct PathConfig { + string output_path; + std::vector op_names; + std::vector api_dirs; + string tf_prefix_dir; + string tf_root_dir; + string tf_output_dir; + + explicit PathConfig() = default; + explicit PathConfig(const string &output_dir, const string &source_dir, + const string &api_dir_list, + const std::vector op_names); +}; + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_PATH_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/source_code.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/source_code.h new file mode 100644 index 00000000..df1aa90a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/source_code.h @@ -0,0 +1,54 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_SOURCE_CODE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_SOURCE_CODE_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +class SourceCode { + public: + string Render() const; + void SetSpacesPerIndent(int spaces_per_indent) { + spaces_per_indent_ = spaces_per_indent; + } + + void AddLineWithIndent(const string &line); + void AddLineWithoutIndent(const string &line); + void AddBlankLine(); + void IncreaseIndent(); + void DecreaseIndent(); + + private: + struct Line { + int indent; + string text; + }; + + void ValidateAndAddLine(int indent_level, const string &raw_line); + + int spaces_per_indent_ = 2; + int current_indent_ = 0; + std::vector lines_; +}; + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_SOURCE_CODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/view_util.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/view_util.h new file mode 100644 index 00000000..7ab437a9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/common/view_util.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_VIEW_UTIL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_VIEW_UTIL_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +string Call(const string &function, std::vector arguments); +string Call(const string &object, const string &method, + std::vector arguments, const char *oper = "->"); +string Quoted(const string &s); + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_COMMON_VIEW_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h new file mode 100644 index 00000000..0a7b08cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_CPP_GENERATOR_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_CPP_GENERATOR_H_ + +#include "tensorflow/c/experimental/ops/gen/common/controller.h" +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +class CppGenerator { + public: + explicit CppGenerator(cpp::CppConfig cpp_config, PathConfig path_config); + SourceCode HeaderFileContents() const; + SourceCode SourceFileContents() const; + string HeaderFileName() const; + string SourceFileName() const; + void WriteHeaderFile() const; + void WriteSourceFile() const; + + private: + SourceCode GenerateOneFile(cpp::RendererContext::Mode mode) const; + + Controller controller_; + cpp::CppConfig cpp_config_; + PathConfig path_config_; +}; + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_CPP_GENERATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h new file mode 100644 index 00000000..fa7571d9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_CONFIG_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_CONFIG_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +struct CppConfig { + string category; + string unit; + std::vector namespaces; + + explicit CppConfig() = default; + explicit CppConfig(const string &category, + const string &name_space = "tensorflow::ops"); +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.h new file mode 100644 index 00000000..4bfc3f92 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_FILE_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_FILE_RENDERER_H_ + +#include + +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class CppFileRenderer : public Renderer { + public: + explicit CppFileRenderer(RendererContext context, + const std::vector &ops); + void Render(); + + private: + GuardRenderer guard_; + NamespaceRenderer name_space_; + IncludeRenderer includes_; + std::vector ops_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_CPP_FILE_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h new file mode 100644 index 00000000..a45fe89a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/guard_renderer.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
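Tying the generator pieces above together, a hypothetical driver would build a PathConfig and a CppConfig and hand them to CppGenerator. This is only a sketch; every path, API directory, and op name below is a placeholder:

  using tensorflow::generator::CppGenerator;
  using tensorflow::generator::PathConfig;
  using tensorflow::generator::cpp::CppConfig;

  // Placeholder arguments: output dir, source dir, api_def dir list, op names.
  PathConfig path_config("/tmp/ops_out", "tensorflow",
                         "tensorflow/core/api_def/base_api",
                         {"Identity", "ZerosLike"});
  CppConfig cpp_config("array");  // namespace defaults to "tensorflow::ops"

  CppGenerator generator(cpp_config, path_config);
  generator.WriteHeaderFile();
  generator.WriteSourceFile();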
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_GUARD_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_GUARD_RENDERER_H_ + +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class GuardRenderer : public Renderer { + public: + explicit GuardRenderer(RendererContext context); + + void Open(); + void Close(); + + private: + string guard_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_GUARD_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h new file mode 100644 index 00000000..e43715a6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/include_renderer.h @@ -0,0 +1,42 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_INCLUDE_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_INCLUDE_RENDERER_H_ + +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class IncludeRenderer : public Renderer { + public: + explicit IncludeRenderer(RendererContext context); + + string SelfHeaderPath() const; + void SelfHeader(); + void Headers(); + + private: + void Include(const string &tf_file_path); +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_INCLUDE_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h new file mode 100644 index 00000000..fd8ccf95 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/namespace_renderer.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_NAMESPACE_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_NAMESPACE_RENDERER_H_ + +#include + +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class NamespaceRenderer : public Renderer { + public: + explicit NamespaceRenderer(RendererContext context); + + void Open(); + void Close(); +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_NAMESPACE_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h new file mode 100644 index 00000000..9131cc94 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_COMMENT_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_COMMENT_RENDERER_H_ + +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class OpCommentRenderer : public Renderer { + public: + explicit OpCommentRenderer(RendererContext context, OpView op); + + void Render(); + + private: + OpView op_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_COMMENT_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h new file mode 100644 index 00000000..98c3b0d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_implementation_renderer.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_IMPLEMENTATION_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_IMPLEMENTATION_RENDERER_H_ + +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class OpImplementationRenderer : public Renderer { + public: + explicit OpImplementationRenderer(RendererContext context, OpView op); + void Render(); + + private: + void RenderInitialization(); + void RenderExecutionListOp(); + void RenderExecutionMultipleOutputs(); + void RenderExecutionZeroOutputs(); + void RenderExecutionSingleOutput(); + + OpView op_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_IMPLEMENTATION_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h new file mode 100644 index 00000000..3360e14e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/op_renderer.h @@ -0,0 +1,44 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_RENDERER_H_ + +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/op_comment_renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class OpRenderer : public Renderer { + public: + explicit OpRenderer(RendererContext context, OpView op); + void Render(); + + private: + OpView op_; + OpCommentRenderer comment_; + + string Signature() const; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_OP_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h new file mode 100644 index 00000000..b6168b19 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.h @@ -0,0 +1,100 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_H_ + +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class Renderer { + public: + explicit Renderer(RendererContext context); + + protected: + // Append a blank line. + Renderer &BlankLine(); + + // Append a line of source code, left-justified (not indented). + // Use for preprocessors directives ("#include"), namespaces, etc. + Renderer &CodeLine(const string &text); + template + Renderer CodeLine(absl::string_view text, const Args &...args) { + return CodeLine(absl::Substitute(text, args...)); + } + + // Append a multiline string of source code, left-justified (not indented). + // Note: Trims leading/trailing whitespace including newlines, making this + // method convenient for multiline raw strings. + // Newlines ('\n') are allowed/expected. 
+ Renderer &CodeLines(const string &text); + template + Renderer CodeLines(absl::string_view text, const Args &...args) { + return CodeLines(absl::Substitute(text, args...)); + } + + // Indent and append a C++ statement. + // Note: do *not* include a trailing semicolon in the statement text. + Renderer &Statement(const string &text); + template + Renderer Statement(absl::string_view text, const Args &...args) { + return Statement(absl::Substitute(text, args...)); + } + + // Indent and append a call to a TF method returning a Status to check. + // Note: do *not* include a trailing semicolon in the statement text. + Renderer &TFStatement(const string &text); + template + Renderer TFStatement(absl::string_view text, const Args &...args) { + return TFStatement(absl::Substitute(text, args...)); + } + + // Indent and append a C++ single-line style comment (using '//'). + Renderer &CommentLine(const string &text = ""); + template + Renderer CommentLine(absl::string_view text, const Args &...args) { + return CommentLine(absl::Substitute(text, args...)); + } + + // Append a line of code which starts a new block: trailing with '{') and + // indenting. + Renderer &BlockOpen(const string &text); + template + Renderer BlockOpen(absl::string_view text, const Args &...args) { + return BlockOpen(absl::Substitute(text, args...)); + } + + // Append a line of code ending a block: unindenting and adding '}'. + // Note: optional trailing text is often a comment, e.g. '// namespace xyz'. + Renderer &BlockClose(const string &text = ""); + template + Renderer BlockClose(absl::string_view text, const Args &...args) { + return BlockClose(absl::Substitute(text, args...)); + } + + protected: + RendererContext context_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h new file mode 100644 index 00000000..c0eb03e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h @@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_CONTEXT_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_CONTEXT_H_ + +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +struct RendererContext { + enum Mode { kHeader = 0, kSource }; + + Mode mode; + SourceCode &code; + CppConfig cpp_config; + PathConfig path_config; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_RENDERERS_RENDERER_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/arg_type_view.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/arg_type_view.h new file mode 100644 index 00000000..d071f62c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/arg_type_view.h @@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ARG_TYPE_VIEW_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ARG_TYPE_VIEW_H_ + +#include "tensorflow/c/experimental/ops/gen/model/arg_type.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class ArgTypeView { + public: + explicit ArgTypeView(ArgType arg_type); + + string TypeName() const; + + private: + ArgType arg_type_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ARG_TYPE_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h new file mode 100644 index 00000000..49085d3a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h @@ -0,0 +1,47 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
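Concrete renderers derive from the Renderer class declared in renderer.h above, since its helpers are protected. A minimal sketch of a hypothetical subclass emitting one stub function (the class name and the emitted function are illustrative, not part of the generator):

  class StatusStubRenderer : public tensorflow::generator::cpp::Renderer {
   public:
    explicit StatusStubRenderer(tensorflow::generator::cpp::RendererContext context)
        : Renderer(context) {}
    void Render() {
      CommentLine("Auto-generated stub.");
      BlockOpen("absl::Status Stub(AbstractContext* ctx)");
      Statement("return absl::OkStatus()");  // no trailing ';', per Statement()
      BlockClose();                          // emits the closing '}'
    }
  };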
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ARG_VIEW_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ARG_VIEW_H_ + +#include + +#include "tensorflow/c/experimental/ops/gen/cpp/views/arg_type_view.h" +#include "tensorflow/c/experimental/ops/gen/model/arg_spec.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class ArgView { + public: + explicit ArgView(ArgSpec arg); + + string VariableName() const; + string SetterMethod() const; + std::vector SetterArgs() const; + int Position() const; + + bool IsList() const; + + private: + ArgSpec arg_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ARG_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h new file mode 100644 index 00000000..70149aa6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ATTR_VIEW_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ATTR_VIEW_H_ + +#include + +#include "tensorflow/c/experimental/ops/gen/model/attr_spec.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class AttrView { + public: + explicit AttrView(AttrSpec attr) : attr_(attr) {} + + string VariableName() const; + string VariableType() const; + string AttrNameString() const; + string VariableStrLen() const; + string VariableSpanData() const; + string VariableSpanLen() const; + string DefaultValue() const; + string InputArg(bool with_default_value) const; + string SetterMethod() const; + std::vector SetterArgs() const; + + private: + AttrSpec attr_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_ATTR_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/op_argument_view.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/op_argument_view.h new file mode 100644 index 00000000..ff3e2b51 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/op_argument_view.h @@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_OP_ARGUMENT_VIEW_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_OP_ARGUMENT_VIEW_H_ + +#include "tensorflow/c/experimental/ops/gen/model/arg_spec.h" +#include "tensorflow/c/experimental/ops/gen/model/attr_spec.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class OpArgumentView { + public: + explicit OpArgumentView(ArgSpec arg); + explicit OpArgumentView(AttrSpec attr); + explicit OpArgumentView(string type, string var, string def = ""); + + string Declaration() const; + string Initializer() const; + bool HasDefaultValue() const; + + private: + string type_name_; + string variable_name_; + string default_value_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_OP_ARGUMENT_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/op_view.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/op_view.h new file mode 100644 index 00000000..35b8858b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/cpp/views/op_view.h @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_OP_VIEW_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_OP_VIEW_H_ + +#include + +#include "tensorflow/c/experimental/ops/gen/cpp/views/arg_view.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/attr_view.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_argument_view.h" +#include "tensorflow/c/experimental/ops/gen/model/op_spec.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { +namespace cpp { + +class OpView { + public: + explicit OpView(OpSpec op); + + const std::vector &Inputs() const; + const std::vector &Outputs() const; + const std::vector &Attributes() const; + const std::vector &AllArguments() const; + + int NumInputs() const; + int NumOutputs() const; + ArgView OnlyInput() const; + ArgView OnlyOutput() const; + + string FunctionName() const; + string VariableName() const; + string OpNameString() const; + string Summary() const; + std::vector Description() const; + bool IsListOp() const; + + private: + OpSpec op_; + std::vector input_args_; + std::vector output_args_; + std::vector argument_attrs_; + std::vector all_arguments_; +}; + +} // namespace cpp +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_VIEWS_OP_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/arg_spec.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/arg_spec.h new file mode 100644 index 00000000..d18c0d62 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/arg_spec.h @@ -0,0 +1,53 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ARG_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ARG_SPEC_H_ + +#include "tensorflow/c/experimental/ops/gen/model/arg_type.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +// An input or output argument to an Op. +// +// Essentially, this represents an OpDef::ArgDef and its context within the Op. 
+class ArgSpec { + public: + ArgSpec() = default; + ArgSpec(const ArgSpec& other) = default; + static ArgSpec CreateInput(const OpDef::ArgDef& arg_def, int position); + static ArgSpec CreateOutput(const OpDef::ArgDef& arg_def, int position); + + const string& name() const { return name_; } + const string& description() const { return description_; } + const ArgType arg_type() const { return arg_type_; } + const int position() const { return position_; } + + private: + explicit ArgSpec(const OpDef::ArgDef& arg_def, ArgType arg_type, + int position); + + string name_; + string description_; + ArgType arg_type_; + int position_; +}; + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ARG_SPEC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/arg_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/arg_type.h new file mode 100644 index 00000000..df3b9e94 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/arg_type.h @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ARG_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ARG_TYPE_H_ + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +// Type information of an Op argument (ArgSpec).. +// +// This represents the type information with OpDef::ArgDef and any type-related +// context. +class ArgType { + public: + ArgType() = default; + ArgType(const ArgType& other) = default; + static ArgType CreateInput(const OpDef::ArgDef& arg_def); + static ArgType CreateInputRef(const OpDef::ArgDef& arg_def); + static ArgType CreateOutput(const OpDef::ArgDef& arg_def); + + const tensorflow::DataType data_type() const { return data_type_; } + const string type_attr_name() const { return type_attr_name_; } + const bool is_read_only() const { return kind_ == kInput; } + const bool is_list() const { return is_list_; } + + private: + enum Kind { kInput = 0, kInputRef, kOutput }; + + explicit ArgType(const OpDef::ArgDef& arg_def, Kind kind); + + Kind kind_; + tensorflow::DataType data_type_; + string type_attr_name_; + bool is_list_; +}; + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ARG_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/attr_spec.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/attr_spec.h new file mode 100644 index 00000000..8c9488bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/attr_spec.h @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ATTR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ATTR_SPEC_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +// An attribute for an Op, such as an input/output type or for passing options. +// +// Essentially, this represents an OpDef::AttrDef and its context within the Op. +class AttrSpec { + public: + AttrSpec() = default; + AttrSpec(const AttrSpec& other) = default; + static AttrSpec Create(const OpDef::AttrDef& attr_def); + + const string& name() const { return name_; } + const string& description() const { return description_; } + const string& full_type() const { return full_type_; } + const string& base_type() const { return base_type_; } + const AttrValue& default_value() const { return default_value_; } + const bool is_list() const { return is_list_; } + + private: + explicit AttrSpec(const OpDef::AttrDef& attr_def); + + string name_; + string description_; + string full_type_; + string base_type_; + AttrValue default_value_; + bool is_list_; +}; + +} // namespace generator +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_ATTR_SPEC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/op_spec.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/op_spec.h new file mode 100644 index 00000000..986ece00 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/gen/model/op_spec.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_OP_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_OP_SPEC_H_ + +#include +#include + +#include "tensorflow/c/experimental/ops/gen/model/arg_spec.h" +#include "tensorflow/c/experimental/ops/gen/model/attr_spec.h" +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace generator { + +// An Op. +// +// Essentially, this represents an OpDef and any necessary context (e.g ApiDef). 
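+// Usage sketch (illustrative, not from the upstream header): assuming the
+// generator has already looked up an OpDef `op_def` and a matching ApiDef
+// `api_def`, an OpSpec is built once and then read through its accessors:
+//
+//   OpSpec spec = OpSpec::Create(op_def, api_def);
+//   for (const ArgSpec& arg : spec.Inputs()) {
+//     // e.g. emit one generated parameter per input argument
+//   }
+//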
+class OpSpec {
+ public:
+  static OpSpec Create(const OpDef& op_def, const ApiDef& api_def);
+
+  const string& name() const { return name_; }
+  const string& summary() const { return summary_; }
+  const string& description() const { return description_; }
+  const std::vector<ArgSpec>& Inputs() const { return input_args_; }
+  const std::vector<ArgSpec>& Outputs() const { return output_args_; }
+  const std::vector<AttrSpec>& Attributes() const { return argument_attrs_; }
+
+ private:
+  explicit OpSpec(const OpDef& op_def, const ApiDef& api_def);
+
+ private:
+  string name_;
+  string summary_;
+  string description_;
+  std::vector<ArgSpec> input_args_;
+  std::vector<ArgSpec> output_args_;
+  std::vector<AttrSpec> argument_attrs_;
+  std::map<string, AttrSpec> type_attrs_;
+};
+
+}  // namespace generator
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_MODEL_OP_SPEC_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/io_ops.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/io_ops.h
new file mode 100644
index 00000000..939c8536
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/io_ops.h
@@ -0,0 +1,50 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file is MACHINE GENERATED! Do not edit.
+
+#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_
+#define TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_
+
+#include "absl/status/status.h"
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
+namespace ops {
+
+// Restores tensors from a V2 checkpoint.
+absl::Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix,
+                       AbstractTensorHandle* const tensor_names,
+                       AbstractTensorHandle* const shape_and_slices,
+                       absl::Span<AbstractTensorHandle*> tensors,
+                       absl::Span<DataType> dtypes, const char* name = nullptr,
+                       const char* raw_device_name = nullptr);
+
+// Saves tensors in V2 checkpoint format.
+absl::Status SaveV2(AbstractContext* ctx, AbstractTensorHandle* const prefix,
+                    AbstractTensorHandle* const tensor_names,
+                    AbstractTensorHandle* const shape_and_slices,
+                    absl::Span<AbstractTensorHandle* const> tensors,
+                    const char* name = nullptr,
+                    const char* raw_device_name = nullptr);
+
+}  // namespace ops
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_OPS_IO_OPS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/math_ops.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/math_ops.h
new file mode 100644
index 00000000..c33c89fd
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/math_ops.h
@@ -0,0 +1,107 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is MACHINE GENERATED! Do not edit. + +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ + +#include "absl/status/status.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace ops { + +// Returns x * y element-wise. +absl::Status Mul(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Returns the complex conjugate of a complex number. +absl::Status Conj(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle** output, const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Returns x + y element-wise. +absl::Status AddV2(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Multiply the matrix "a" by the matrix "b". +absl::Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, + AbstractTensorHandle* const b, + AbstractTensorHandle** product, bool transpose_a = false, + bool transpose_b = false, const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes numerical negative value element-wise. +absl::Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes the sum of elements across dimensions of a tensor. +absl::Status Sum(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle* const reduction_indices, + AbstractTensorHandle** output, bool keep_dims = false, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Returns x - y element-wise. +absl::Status Sub(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Returns x / y element-wise. +absl::Status Div(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Returns 0 if the denominator is zero. +absl::Status DivNoNan(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes exponential of x element-wise. \\(y = e^x\\). +absl::Status Exp(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes square root of x element-wise. 
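+// Usage sketch (illustrative, not from the upstream header): every op in this
+// file follows the same calling convention, e.g. for the Sqrt declared below,
+// assuming a valid `AbstractContext* ctx` and an input handle `x`:
+//
+//   AbstractTensorHandle* y = nullptr;
+//   absl::Status s = tensorflow::ops::Sqrt(ctx, x, &y);
+//   if (!s.ok()) { /* handle the error */ }
+//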
+absl::Status Sqrt(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes the gradient for the sqrt of `x` wrt its input. +absl::Status SqrtGrad(AbstractContext* ctx, AbstractTensorHandle* const y, + AbstractTensorHandle* const dy, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes natural logarithm of (1 + x) element-wise. +absl::Status Log1p(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/nn_ops.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/nn_ops.h new file mode 100644 index 00000000..0006267f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/nn_ops.h @@ -0,0 +1,69 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is MACHINE GENERATED! Do not edit. + +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ + +#include "absl/status/status.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace ops { + +// Computes softmax cross entropy cost and gradients to backpropagate. +absl::Status SparseSoftmaxCrossEntropyWithLogits( + AbstractContext* ctx, AbstractTensorHandle* const features, + AbstractTensorHandle* const labels, AbstractTensorHandle** loss, + AbstractTensorHandle** backprop, const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes rectified linear gradients for a Relu operation. +absl::Status ReluGrad(AbstractContext* ctx, + AbstractTensorHandle* const gradients, + AbstractTensorHandle* const features, + AbstractTensorHandle** backprops, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Computes rectified linear: `max(features, 0)`. +absl::Status Relu(AbstractContext* ctx, AbstractTensorHandle* const features, + AbstractTensorHandle** activations, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Adds `bias` to `value`. +absl::Status BiasAdd(AbstractContext* ctx, AbstractTensorHandle* const value, + AbstractTensorHandle* const bias, + AbstractTensorHandle** output, + const char* data_format = "NHWC", + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// The backward operation for "BiasAdd" on the "bias" tensor. 
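+// Usage sketch (illustrative, not from the upstream header): BiasAddGrad below
+// reduces the incoming gradient `out_backprop` over the non-bias dimensions,
+// using the same `data_format` that was passed to the forward BiasAdd:
+//
+//   AbstractTensorHandle* bias_grad = nullptr;
+//   absl::Status s = tensorflow::ops::BiasAddGrad(
+//       ctx, out_backprop, &bias_grad, /*data_format=*/"NHWC");
+//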
+absl::Status BiasAddGrad(AbstractContext* ctx, + AbstractTensorHandle* const out_backprop, + AbstractTensorHandle** output, + const char* data_format = "NHWC", + const char* name = nullptr, + const char* raw_device_name = nullptr); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/ops/resource_variable_ops.h b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/resource_variable_ops.h new file mode 100644 index 00000000..02b42bf4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/ops/resource_variable_ops.h @@ -0,0 +1,67 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is MACHINE GENERATED! Do not edit. + +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_RESOURCE_VARIABLE_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_RESOURCE_VARIABLE_OPS_H_ + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace ops { + +// Creates a handle to a Variable resource. +absl::Status VarHandleOp(AbstractContext* ctx, AbstractTensorHandle** resource, + DataType dtype, const PartialTensorShape shape, + const char* container = "", + const char* shared_name = "", + absl::Span allowed_devices = {}, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Reads the value of a variable. +absl::Status ReadVariableOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + AbstractTensorHandle** value, DataType dtype, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Assigns a new value to a variable. +absl::Status AssignVariableOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + AbstractTensorHandle* const value, + bool validate_shape = false, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +// Deletes the resource specified by the handle. 
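+// Usage sketch (illustrative, not from the upstream header): a typical
+// variable lifecycle built from the ops in this file, assuming a valid
+// `AbstractContext* ctx` and a value handle `value` of dtype DT_FLOAT:
+//
+//   tensorflow::PartialTensorShape shape;  // unknown shape, for the sketch
+//   AbstractTensorHandle* var = nullptr;
+//   absl::Status s = tensorflow::ops::VarHandleOp(ctx, &var, DT_FLOAT, shape);
+//   if (s.ok()) s = tensorflow::ops::AssignVariableOp(ctx, var, value);
+//   AbstractTensorHandle* read = nullptr;
+//   if (s.ok()) s = tensorflow::ops::ReadVariableOp(ctx, var, &read, DT_FLOAT);
+//   if (s.ok()) s = tensorflow::ops::DestroyResourceOp(ctx, var);
+//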
+absl::Status DestroyResourceOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + bool ignore_lookup_error = true, + const char* name = nullptr, + const char* raw_device_name = nullptr); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_RESOURCE_VARIABLE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.h b/third_party/tflite-hdrs/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.h new file mode 100644 index 00000000..bf8dbc49 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.h @@ -0,0 +1,178 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_PLUGGABLE_PROFILER_PLUGGABLE_PROFILER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_PLUGGABLE_PROFILER_PLUGGABLE_PROFILER_H_ +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_status.h" + +// C API for Pluggable Profiler. The API is under active development and +// eventually should allow registering a profiler with TensorFlow. +// +// Conventions: +// * Struct prefix indicates whether struct fields should be filled by the +// plug-in or core TensorFlow implementation: +// * TF_: Set/filled by core, unless marked otherwise. +// * TP_: Set/filled by plug-in, unless marked otherwise. +// * This prefix rule only applies to structures. Enumerations and methods +// are all prefixed with TP_. +// * Structs begin with two fields: +// * size_t struct_size: Stores the unpadded size of the struct. +// * void* ext: A reserved field that may be populated by a plugin in TP_* +// structs or potential future extension points in TF_ structs. Must be set +// to zero by default if it unused. +// * We use struct_size for version checking by both core and plug-in. +// * It is exempt from the TF/TP rule above and must be set both by core and +// plug-in. +// * It can be checked programmatically to determine which struct fields are +// available in the structure. +// * When a member is added to a struct, the struct size definition must be +// updated to use the new last member of the struct. +// +// Example usage: +// /* Sample TensorFlow code below, exact implementation might differ. */ +// // Version checking uses `struct_size`. It is exempt from the `TF/TP` rule +// // above and should be set both by core and the plugin." +// +// /* Plugin code below */ +// void profiler_start(const TP_Profiler* profiler, TF_Status* status) { +// /* Enable profiler */ +// ... +// } +// +// void profiler_stop(const TP_Profiler* profiler, TF_Status* status) { +// /* Disable Profiler */ +// ... +// } +// +// void profiler_collect_data_xspace(const TP_Profiler* profiler, uint8_t* +// buffer, size_t* size_in_bytes, TF_Status* status) { +// /* Plugin generates Xspace based on collected profiler data. 
*/ +// Xspace xspace = get_my_xspace(); +// size_t buffer_size_in_bytes = *size_in_bytes; +// *size_in_bytes = xspace.ByteSizeLong(); /* get the size of Xspace */ +// if (buffer == nullptr) { +// /* TensorFlow will first get the size of Xspace, then allocate the big +// enough buffer and pass it to the plugin for retrieving Xspace. */ +// return; +// } +// bool success = xspace.SerializeToArray(buffer, buffer_size_in_bytes); +// } +// +// void TF_InitProfiler(TF_ProfilerRegistrationParams* params, TF_Status* +// status) { +// *params = { TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE }; +// params->profiler->struct_size = TP_PROFILER_STRUCT_SIZE; +// params->profiler_fns->struct_size = TP_PROFILER_FNS_STRUCT_SIZE; +// +// params->profiler->type = "MyDeviceType"; +// +// params->profiler_fns->start = profiler_start; +// params->profiler_fns->stop = profiler_stop; +// params->profiler_fns->collect_data_xspace = profiler_collect_data_xspace; +// params->destroy_profiler = profiler_destroy_profiler; +// params->destroy_profiler_fns = profiler_destroy_profiler_fns; +// } + +#define TP_MAJOR 0 +#define TP_MINOR 0 +#define TP_PATCH 1 + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// TP_Profiler holds a pointer to device type filed by the plug-in. +typedef struct TP_Profiler { + size_t struct_size; + void* ext; // free-form data set by plugin. + const char* device_type; + + // The struct size must be updated when adding new members. +#define TP_PROFILER_STRUCT_SIZE TF_OFFSET_OF_END(TP_Profiler, device_type) +} TP_Profiler; + +// -------------------------------------------------------------------------- +// TP_ProfilerFns holds the profiler interface function pointers filled by the +// plug-in. +typedef struct TP_ProfilerFns { + size_t struct_size; + + void* ext; // reserved for future use. + // Starts profiling. + void (*start)(const TP_Profiler* profiler, TF_Status* status); + // Stops profiling. + void (*stop)(const TP_Profiler* profiler, TF_Status* status); + + // Saves collected profile data into XSpace and serializes it to the buffer. + // - If `buffer` is null, returns the required buffer size in `size_in_bytes`. + // - If `buffer` is not null and `size_in_bytes` is the required buffer size, + // `buffer` is populated with profile data in serialized XSpace format. + // + // Only the first call with a non-null `buffer` following successful calls to + // start and stop might return data. Subsequent calls might return empty data + // unless start and stop are successfully called again. + void (*collect_data_xspace)(const TP_Profiler* profiler, uint8_t* buffer, + size_t* size_in_bytes, TF_Status* status); + + // The struct size must be updated when adding new members. +#define TP_PROFILER_FNS_STRUCT_SIZE \ + TF_OFFSET_OF_END(TP_ProfilerFns, collect_data_xspace) +} TP_ProfilerFns; + +// TF_ProfilerRegistrationParams holds the pointers to TP_Profiler and +// TP_ProfilerFns, the memory of TP_Profiler and TP_ProfilerFns is owned by Core +// TensorFlow and populated by the plug-in. +typedef struct TF_ProfilerRegistrationParams { + size_t struct_size; + void* ext; // reserved for future use + + // TensorFlow Profiler C API version. + int32_t major_version; + int32_t minor_version; + int32_t patch_version; + + // [in/out] Memory owned by core but attributes within are populated by the + // plugin. + TP_Profiler* profiler; + // [in/out] Memory owned by core but attributes within are populated by the + // plugin. 
+ TP_ProfilerFns* profiler_fns; + // [out] Pointer to plugin's `TP_Profiler` clean up function. + // Cleans up fields inside `TP_Profiler` that were allocated + // by the plugin. `profiler` itself must not be deleted by the plugin. + void (*destroy_profiler)(TP_Profiler* profiler); + // [out] Pointer to plugin's `TP_ProfilerFns` clean up function. + // Cleans up fields inside `TP_ProfilerFns` that were allocated + // by the plugin. `profiler_fns` itself must not be deleted by the plugin. + void (*destroy_profiler_fns)(TP_ProfilerFns* profiler_fns); + + // The struct size must be updated when adding new members. +#define TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(TF_ProfilerRegistrationParams, destroy_profiler_fns) +} TF_ProfilerRegistrationParams; + +// TF_InitProfiler to do profiler registration. +// Plug-in should implement TF_InitProfiler to register the profiler. +void TF_InitProfiler(TF_ProfilerRegistrationParams* params, TF_Status* status); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_C_EXPERIMENTAL_PLUGGABLE_PROFILER_PLUGGABLE_PROFILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h b/third_party/tflite-hdrs/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h new file mode 100644 index 00000000..55af07ad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_PLUGGABLE_PROFILER_PLUGGABLE_PROFILER_INTERNAL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_PLUGGABLE_PROFILER_PLUGGABLE_PROFILER_INTERNAL_H_ +#include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/protobuf/profiler_options.pb.h" + +namespace tensorflow { +namespace profiler { + +// Plugin initialization function that a device plugin must define. Returns +// a TF_Status output specifying whether the initialization is successful. +using TFInitProfilerFn = void (*)(TF_ProfilerRegistrationParams* const, + TF_Status* const); + +// Registers plugin's profiler to TensorFlow's profiler registry. 
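+// Usage sketch (illustrative, not from the upstream header): core TensorFlow
+// is expected to resolve the plug-in's TF_InitProfiler symbol from the shared
+// library (the dlsym lookup and `plugin_lib_handle` here are assumptions of
+// the sketch) and pass it to InitPluginProfiler declared below:
+//
+//   void* sym = dlsym(plugin_lib_handle, "TF_InitProfiler");
+//   auto init_fn = reinterpret_cast<TFInitProfilerFn>(sym);
+//   absl::Status s = tensorflow::profiler::InitPluginProfiler(init_fn);
+//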
+absl::Status InitPluginProfiler(TFInitProfilerFn init_fn); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_PLUGGABLE_PROFILER_PLUGGABLE_PROFILER_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/concrete_function.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/concrete_function.h new file mode 100644 index 00000000..0fc60557 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" + +namespace tensorflow { + +// ConcreteFunctions correspond to an instance of a tf.function with a known set +// of inputs (either through get_concrete_function) or an input_signature. +// ConcreteFunction attempts to preserve the user-facing semantics of the +// tf.function python API and can take a limited set of types as arguments +// (to be modeled in tensorflow::Value), not just Tensors. +// +// SavedModelAPI's ConcreteFunctions' lifetimes are bound to the SavedModel they +// are loaded from, since they retain pointers to the TensorHandles owned by the +// SavedModel, and the FunctionDef of the SavedModel. +// +// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock +// TFRT integration with TF Serving. Do not add more virtual implementations of +// this class. Eventually we want to remove this virtual base class indirection +// and have only a single implementation. +class ConcreteFunction { + public: + virtual ~ConcreteFunction() = default; + + // This method returns the "Call" Op used to execute the function. + virtual absl::Status MakeCallOp( + absl::Span inputs, + ImmediateOpPtr* out) const = 0; + + virtual const FunctionMetadata& GetFunctionMetadata() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/function_metadata.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/function_metadata.h new file mode 100644 index 00000000..8499288f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/function_metadata.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_ + +namespace tensorflow { + +class FunctionMetadata { + // TODO(bmzhao): Fill in with fields as necessary +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_FUNCTION_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h new file mode 100644 index 00000000..5a0ec2bc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/ops/restore_ops.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OPS_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace internal { + +// TODO(bmzhao): Add a function to restore multiple tensors in one call. + +// Restores a single non-partioned tensorhandle of dtype `dtype`, using +// checkpoint at `prefix`, with a value stored in `checkpoint_key`. +absl::Status SingleRestore(ImmediateExecutionContext* ctx, + const std::string& prefix, + const std::string& checkpoint_key, DataType dtype, + ImmediateTensorHandlePtr* out); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h new file mode 100644 index 00000000..ee01935b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H_ + +#include "absl/status/status.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace internal { + +// Executes a VarHandleOp using `ctx`, and fills `handle` with the DT_RESOURCE +// TensorHandle associated with the variable. This is equivalent to creating an +// unitialized TF2 tf.Variable. +// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872 +absl::Status CreateUninitializedResourceVariable( + ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + const char* raw_device_name, ImmediateTensorHandlePtr* handle); + +// Executes an AssignVariableOp using `ctx`, assigning the variable associated +// with `variable_handle` with `value`. `dtype` must be the datatype of the +// underlying variable for `variable_handle`. Note that it is illegal to assign +// a variable to a Tensor with a different dtype than what the variable was +// created with. +absl::Status AssignVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, + ImmediateExecutionTensorHandle* value); + +// Executes a ReadVariableOp using `ctx`. This reads the underlying variable +// value of `variable_handle` and copies the value to `output`. `dtype` must be +// the dtype of the variable associated with `variable_handle`. +absl::Status ReadVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, ImmediateTensorHandlePtr* output); + +// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to +// the cleanup that occurs in a tf.Variable's EagerResourceDeleter: +// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290 +absl::Status DestroyResource(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* handle); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/asset.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/asset.h new file mode 100644 index 00000000..4f4bff86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/asset.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class Asset : public TensorHandleConvertible { + public: + static absl::Status Create(ImmediateExecutionContext* ctx, + const std::string& saved_model_dir, + const std::string& asset_filename, + std::unique_ptr* output); + + // Asset is movable, but not copyable. + Asset(Asset&& other) = default; + Asset& operator=(Asset&& other) = default; + + ~Asset() override = default; + + private: + explicit Asset(ImmediateTensorHandlePtr handle); + Asset(const Asset&) = delete; + Asset& operator=(const Asset&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_ASSET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/constant.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/constant.h new file mode 100644 index 00000000..0d89cf37 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/constant.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_CONSTANT_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_CONSTANT_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// This class corresponds to python's tf.constant, which is effectively a +// TensorHandle explicitly initialized to some value. 
+// For now this doesn't do much beyond wrap Context's CreateLocalHandle method, +// and offer a subclass of TensorHandleConvertible. Note that similar to +// the python's eager mode logic, we bypass calling the "Const" op: +// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301 +class Constant : public TensorHandleConvertible { + public: + static absl::Status Create(ImmediateExecutionContext* ctx, + AbstractTensorInterface* tensor, + std::unique_ptr* output); + + // RevivedConstant is movable, but not copyable. + Constant(Constant&& other) = default; + Constant& operator=(Constant&& other) = default; + + ~Constant() override = default; + + private: + explicit Constant(ImmediateTensorHandlePtr handle); + Constant(const Constant&) = delete; + Constant& operator=(const Constant&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_CONSTANT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h new file mode 100644 index 00000000..810a42ec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// FlatTensorFunction models a TF2 eager runtime view of a callable function, +// taking + returning flat lists of tensors, including any captures. +// Effectively, it is a thin wrapper around a FunctionDef owned by the +// EagerContext, and any TensorHandle captures associated with the function. The +// MakeCallOp method handles the logic of marshaling captures after the user +// provided inputs automatically. +// Note(bmzhao): This class is mainly intended to house low-level reusable +// function logic between SignatureDefFunction and ConcreteFunction, which +// present higher level interfaces. This type does *not* hold any "function +// metadata". 
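+// Usage sketch (illustrative, not from the upstream header): assuming a
+// FunctionDef* `fdef` from the loaded SavedModel, its captured handles
+// `captures`, and an `ImmediateExecutionContext* ctx`:
+//
+//   std::unique_ptr<FlatTensorFunction> fn;
+//   absl::Status s = FlatTensorFunction::Create(fdef, captures, ctx, &fn);
+//   ImmediateOpPtr call_op;
+//   if (s.ok()) s = fn->MakeCallOp(user_inputs, &call_op);  // captures are
+//                                                           // appended by the
+//                                                           // call itself
+//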
+class FlatTensorFunction {
+ public:
+  // Factory for creating a FlatTensorFunction.
+  //
+  // Params:
+  //  function_def - The function_def associated with the created
+  //                 FlatTensorFunction. FlatTensorFunction will register this
+  //                 function_def with `ctx` on creation, and de-register it on
+  //                 destruction. function_def must be non-null, but
+  //                 otherwise has no lifetime requirements.
+  //  captures - The captured TensorHandles associated with this
+  //             FlatTensorFunction. FlatTensorFunction will participate in
+  //             ownership of the handles (it explicitly increments the
+  //             refcount of each handle, and will decrement them on
+  //             destruction).
+  //  ctx - A handle to the Tensorflow runtime. This MUST be non-null and
+  //        outlive TFConcreteFunction.
+  //  out - The output FlatTensorFunction.
+  static absl::Status Create(
+      const FunctionDef* function_def,
+      std::vector<ImmediateExecutionTensorHandle*> captures,
+      ImmediateExecutionContext* ctx, std::unique_ptr<FlatTensorFunction>* out);
+
+  // This method creates a "Call" Op used to execute the function.
+  absl::Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs,
+                          ImmediateOpPtr* out) const;
+
+  ~FlatTensorFunction();
+
+ private:
+  FlatTensorFunction(const std::string& name,
+                     std::vector<ImmediateTensorHandlePtr> captures,
+                     ImmediateExecutionContext* ctx);
+
+  FlatTensorFunction(const FlatTensorFunction&) = delete;
+  FlatTensorFunction& operator=(const FlatTensorFunction&) = delete;
+
+  // Name of the FunctionDef corresponding to this TFConcreteFunction
+  std::string name_;
+  std::vector<ImmediateTensorHandlePtr> captures_;
+  ImmediateExecutionContext* ctx_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_FLAT_TENSOR_FUNCTION_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h
new file mode 100644
index 00000000..07a4e185
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h
@@ -0,0 +1,65 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// Container for objects during the revival step in SavedModel's loading. +// Notably, resources and functions can be in a state where they reference +// other resources/functions that have not been constructed yet. We collect +// *all* objects in a partially valid state here, then properly initialize +// resources and functions. Implementation-wise, PartiallyRevivedObjects +// contains maps keyed by the node number of the SavedObjectGraph, and map to an +// object of the corresponding type. So, if node 2 in the object graph is a +// variable, PartiallyRevivedObjects.variables[2] exists, and corresponds to a +// tensorflow::Variable object. The only exception to this is the +// "signatures_map", which is keyed by the "signature" key +// (https://github.com/tensorflow/tensorflow/blob/372918decee7f558b3c194b04f77c20dcc679a31/tensorflow/core/protobuf/meta_graph.proto#L89), +// and maps to the SignatureDefFunction node in the SavedObjectGraph. +struct PartiallyRevivedObjects { + gtl::FlatMap> variables; + gtl::FlatMap> assets; + gtl::FlatMap> constants; + gtl::FlatMap concrete_functions; + gtl::FlatMap signature_def_functions; + gtl::FlatMap restored_resources; + gtl::FlatMap signatures_map; + + absl::Status Build(ImmediateExecutionContext* ctx, + const SavedObjectGraph& obj_graph, + RevivedObjects* revived); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_PARTIALLY_REVIVED_OBJECTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h new file mode 100644 index 00000000..691a591c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// RestoredResource represents a TF2 "Resource" object loaded from a savedmodel, +// analogous to the Python _RestoredResource object: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/saved_model/load.py#L481 +// TF2 resource objects typically extend TrackableResource: +// https://github.com/tensorflow/tensorflow/blob/fda326e542ca67534e8411edb180e8760a4828b7/tensorflow/python/training/tracking/tracking.py#L285 +// and are expected to implement "_create_resource", "_initialize", and +// "_destroy_resource" functions: +// https://github.com/tensorflow/tensorflow/blob/139ba9c5284799beafdd1d7f895127cf00e7c48f/tensorflow/python/training/tracking/tracking.py#L262-L281 +class RestoredResource : TensorHandleConvertible { + public: + // Note(bmzhao): RestoredResource stores non-owning pointers to its associated + // functions because SavedModel internally owns all functions and objects in + // the RevivedObjects struct (which owns all functions). One alternative would + // be to have RevivedObjects store shared_ptr instead, and + // change RestoredResource's constructor take shared_ptr. + // To keep things simple, I've stuck to raw pointers for now. + // + // Params: + // device - The device string associated with the SavedResource + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/saved_object_graph.proto#L182 + // Conceptually, this is the same device used in CapturableResource: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L222-L225 + // Implementation-wise, it is device used when invoking the + // create_resource function to produce the resource_handle + // associated with the object: + // https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/python/training/tracking/tracking.py#L246-L247 + // create_resource - Non owning pointer to the create_resource function + // associated with this object. Must be NON-NULL. + // initialize - Non owning pointer to the initialize function associated with + // this object. Must be NON-NULL. + // destroy_resource - Non owning pointer to the destroy_resource function + // associated with this object. 
Ideally this should be + // NON-NULL, but in order to support models saved prior to + // https://github.com/tensorflow/tensorflow/commit/3c806101f57768e479f8646e7518bbdff1632ca3 + // we allow null here. This will, however, leak resources. + RestoredResource(const std::string& device, + TFConcreteFunction* create_resource, + TFConcreteFunction* initialize, + TFConcreteFunction* destroy_resource, + ImmediateTensorHandlePtr resource_handle); + + absl::Status Initialize() const; + + // RestoredResource is movable, but not copyable. + RestoredResource(RestoredResource&& other) = default; + RestoredResource& operator=(RestoredResource&& other) = default; + + ~RestoredResource() override; + + private: + std::string device_; + TFConcreteFunction* create_resource_; + TFConcreteFunction* initialize_; + TFConcreteFunction* destroy_resource_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h new file mode 100644 index 00000000..48d00308 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/restored_resource_revival_state.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ + +#include + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h" + +namespace tensorflow { + +// All "Resources" should have these 3 saved functions: +// https://github.com/tensorflow/tensorflow/blob/86dc281333d7d277ddc1882f2bca4b17e7ec40e5/tensorflow/python/training/tracking/tracking.py#L277-L281 +struct RestoredResourceRevivalState { + std::string device; + TFConcreteFunctionRevivalState* create_resource = nullptr; + TFConcreteFunctionRevivalState* initialize = nullptr; + TFConcreteFunctionRevivalState* destroy_resource = nullptr; + ImmediateTensorHandlePtr resource_handle = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_RESTORED_RESOURCE_REVIVAL_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h new file mode 100644 index 00000000..0f09c743 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/restored_resource.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { + +// A container for revived saved model objects. +// +// Most of the objects will be revived from nodes in the object graph, and for +// those objects this container provides a map from node id to the revived +// objects. +// +// For objects that have to be revived but are not part of the object graph, +// this container provides a place where the objects can be stored so they are +// available to the runtime. 
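For illustration only, here is a minimal, hypothetical sketch of how loading code might exercise the RevivedObjectContainer declared just below. `DummyFunction` and `ExampleUsage` are stand-ins invented for this sketch (a real caller would use a revived type such as TFConcreteFunction); the only assumed API is the container's own Insert/Find contract from this header.

```cpp
// Sketch only: exercises RevivedObjectContainer's Insert/Find contract.
#include <memory>

#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h"

namespace {

struct DummyFunction {};  // Hypothetical stand-in for a revived object type.

void ExampleUsage() {
  tensorflow::RevivedObjectContainer<DummyFunction> functions;

  // An object revived from node 7 of the SavedObjectGraph.
  functions.Insert(std::make_unique<DummyFunction>(), /*node_id=*/7);

  // An object that is needed at runtime but has no node in the object graph.
  functions.Insert(std::make_unique<DummyFunction>());

  DummyFunction* by_id = functions.Find(7);     // non-null
  DummyFunction* missing = functions.Find(42);  // nullptr
  (void)by_id;
  (void)missing;
}

}  // namespace
```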
+template <typename T> +class RevivedObjectContainer { + public: + // Insert an object that is not related to a node id. This usually means the + // object was not referenced by the object_graph, but is needed by other + // objects. + void Insert(std::unique_ptr<T> object) { + objects_.push_back(std::move(object)); + } + + // Insert an object that is tied to the given object graph node id. + void Insert(std::unique_ptr<T> object, int node_id) { + objects_by_id_[node_id] = object.get(); + Insert(std::move(object)); + } + + // Find an object by the object graph node id. + // Returns nullptr if there is no such object. + T* Find(int node_id) { + auto it = objects_by_id_.find(node_id); + return it == objects_by_id_.end() ? nullptr : it->second; + } + + private: + std::vector<std::unique_ptr<T>> objects_; + absl::flat_hash_map<int, T*> objects_by_id_; +}; + +// RevivedObjects is mainly used as a container for all the "state" owned by +// SavedModel. It stores all non-"user object" nodes from a SavedModel +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L57-L62) +// in a "fully constructed" state. It is effectively a strongly typed map, where +// each member is a map from the node id in the SavedObjectGraph's nodes +// (https://github.com/tensorflow/tensorflow/blob/568e2bef00f24af1159a0846abf67c099ca78a21/tensorflow/core/protobuf/saved_object_graph.proto#L25-L29) +// to the revived object of the corresponding type. +struct RevivedObjects { + // Order of declaration is important here: we want the RestoredResources to be + // freed after TFConcreteFunctions, for example. + gtl::FlatMap<int, std::unique_ptr<Variable>> variables; + gtl::FlatMap<int, std::unique_ptr<Asset>> assets; + gtl::FlatMap<int, std::unique_ptr<Constant>> constants; + gtl::FlatMap<int, std::unique_ptr<TFSignatureDefFunction>> + signature_def_functions; + RevivedObjectContainer<TFConcreteFunction> concrete_functions; + gtl::FlatMap<int, RestoredResource> restored_resources; + gtl::FlatMap<std::string, int> signatures_map; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_REVIVED_OBJECTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h new file mode 100644 index 00000000..4c2c874e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TENSORHANDLE_CONVERTIBLE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TENSORHANDLE_CONVERTIBLE_H_ + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" + +namespace tensorflow { + +// A common interface for objects that can be converted to a TensorHandle. +// Examples of objects that implement this include Variables, Constants, Assets, +// etc.
This is used to convert captured objects into a ConcreteFunction's +// captured TensorHandles: +// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240 +class TensorHandleConvertible { + public: + explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle) + : handle_(std::move(handle)) {} + + ImmediateExecutionTensorHandle* handle() { return handle_.get(); } + + // TensorHandleConvertible is movable, but not copyable. + TensorHandleConvertible(TensorHandleConvertible&& other) = default; + TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default; + + virtual ~TensorHandleConvertible() = default; + + protected: + TensorHandleConvertible(const TensorHandleConvertible&) = delete; + TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete; + ImmediateTensorHandlePtr handle_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TENSORHANDLE_CONVERTIBLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h new file mode 100644 index 00000000..669d77b5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TF Eager Runtime-based implementation of a "ConcreteFunction" loaded from a +// saved model. +class TFConcreteFunction : public ConcreteFunction { + public: + // Factory function for creating a TFConcreteFunction. + // + // Params: + // function_def - The function_def associated with the created + // TFConcreteFunction. TFConcreteFunction will register this + // function_def with `ctx` on creation, and de-register it on + // destruction. 
function_def must be non-null, but + // otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // TFConcreteFunction. + // metadata - The FunctionMetadata associated with this TFConcreteFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFConcreteFunction. + // out - The output TFConcreteFunction. + static absl::Status Create( + const FunctionDef* function_def, + std::vector<ImmediateExecutionTensorHandle*> captures, + FunctionMetadata metadata, ImmediateExecutionContext* ctx, + std::unique_ptr<TFConcreteFunction>* out); + + // This method returns the "Call" Op used to execute the function. + absl::Status MakeCallOp(absl::Span<AbstractTensorHandle* const> inputs, + ImmediateOpPtr* out) const override; + + const FunctionMetadata& GetFunctionMetadata() const override; + + ~TFConcreteFunction() override = default; + + private: + TFConcreteFunction(std::unique_ptr<FlatTensorFunction> func, + FunctionMetadata metadata); + + TFConcreteFunction(const TFConcreteFunction&) = delete; + TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; + + std::unique_ptr<FlatTensorFunction> func_; + FunctionMetadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h new file mode 100644 index 00000000..3dd7a6ee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function_revival_state.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TFConcreteFunctionRevivalState wraps the state needed for building a +// TF_ConcreteFunction. This is mainly used in PartiallyRevivedObjects, which +// wraps partially constructed Function and Resource objects. +struct TFConcreteFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null.
+ const FunctionDef* fdef; + + // TensorHandle captures for this function + std::vector<ImmediateExecutionTensorHandle*> captures; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func; + + // This field is only present on TF2 ConcreteFunctions, and is useful for + // determining the original argument *names* of the function (since the + // "canonicalized_input_signature" may append extra uniquifying integers). + // However, SavedBareConcreteFunctions do not have a FunctionSpec. + // Note(bmzhao): if function_spec_.has_value(), *function_spec_ is guaranteed + // to be non-null. + absl::optional<const FunctionSpec*> function_spec; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_CONCRETE_FUNCTION_REVIVAL_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h new file mode 100644 index 00000000..c9b98189 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/flat_tensor_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// This is the TF eager runtime implementation of SignatureDefFunction (separate
The user-facing API of SignatureDefFunctions +// and their semantic differences from ConcreteFunction are described here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/cc/saved_model/experimental/public/signature_def_function.h#L30-L59 +// Additional implementation notes are available here: +// https://github.com/tensorflow/tensorflow/blob/e2db60c9d9598ebae0b7741587ce6f5d473584d9/tensorflow/c/experimental/saved_model/core/signature_def_function.h#L31-L48 +class TFSignatureDefFunction : public SignatureDefFunction { + public: + // Factory function for creating a TFSignatureDefFunction. + // + // Params: + // function_def - The function_def associated with the created + // TFSignatureDefFunction. TFSignatureDefFunction will + // register this function_def with `ctx` on creation, and + // de-register it on destruction. function_def must be + // non-null, but otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // TFConcreteFunction. + // metadata - FunctionMetadata associated with this TFSignatureDefFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFSignatureDefFunction. + // out - The output TFSignatureDefFunction. + static absl::Status Create( + const FunctionDef* function_def, + std::vector captures, + SignatureDefFunctionMetadata metadata, ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method creates a "Call" Op used to execute the function. + absl::Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const override; + + const SignatureDefFunctionMetadata& GetFunctionMetadata() const override; + + ~TFSignatureDefFunction() override = default; + + private: + TFSignatureDefFunction(std::unique_ptr func, + SignatureDefFunctionMetadata metadata); + + TFSignatureDefFunction(const TFSignatureDefFunction&) = delete; + TFSignatureDefFunction& operator=(const TFSignatureDefFunction&) = delete; + + std::unique_ptr func_; + SignatureDefFunctionMetadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h new file mode 100644 index 00000000..ac1b20e4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function_revival_state.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_signature_def_function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TFSignatureDefFunctionRevivalState wraps the state needed for building a +// SignatureDefFunction. This is mainly used in PartiallyRevivedObjects, which +// wraps partially constructed Function and Resource objects. +struct TFSignatureDefFunctionRevivalState { + // Index of the node in the SavedObjectGraph it was loaded from. + int node_id = 0; + + // Pointer to the original functiondef. fdef_ is guaranteed to be + // non-null. + const FunctionDef* fdef = nullptr; + + // SavedConcreteFunction contains much of the metadata of the expected "types" + // of the inputs and outputs of a function. + // Note(bmzhao): saved_concrete_func_ is guaranteed to be non-null. + const SavedConcreteFunction* saved_concrete_func = nullptr; + + // The name of the SignatureDef key. + std::string signature_key; + + // TensorHandle captures for this function + std::vector<ImmediateExecutionTensorHandle*> captures; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_TF_SIGNATURE_DEF_FUNCTION_REVIVAL_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/variable.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/variable.h new file mode 100644 index 00000000..5a9ad51a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/revived_types/variable.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_VARIABLE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_VARIABLE_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +class Variable : public TensorHandleConvertible { + public: + // Creates an uninitialized resource variable. Note that a caller must + // call "assign" to associate a value with the variable. + static absl::Status CreateUninitialized( + ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + absl::optional name, const char* raw_device_name, + const std::vector& component_devices, + std::unique_ptr* output); + + // The dtype of the underlying variable. + DataType dtype(); + + // The shape of the underlying variable. + TensorShape shape(); + + // Updates the variable's contents with `handle`. + absl::Status Assign(ImmediateExecutionTensorHandle* handle); + + // Reads the value of the variable, and stores it in `out` + absl::Status ReadValue(ImmediateTensorHandlePtr* out); + + // Variable is movable, but not copyable. + Variable(Variable&& other) = default; + Variable& operator=(Variable&& other) = default; + + ~Variable() override; + + private: + Variable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + absl::optional name, ImmediateTensorHandlePtr handle); + Variable(const Variable& variable) = delete; + Variable& operator=(const Variable&) = delete; + + std::string name_; + DataType dtype_; + TensorShape shape_; + + // ctx_ must outlive Variable. + ImmediateExecutionContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_TYPES_VARIABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/saved_model_api.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/saved_model_api.h new file mode 100644 index 00000000..1fd56822 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/saved_model_api.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock +// TFRT integration with TF Serving. Do not add more virtual implementations of +// this class. Eventually we want to remove this virtual base class indirection +// and have only a single implementation. +class SavedModelAPI { + public: + // Retrieve a function from the TF2 SavedModel, using the "path" to a function + // in a TF2 savedmodel. + // + // Note: `function` is a double pointer, so that implementations are + // able to return a pointer to an internal member. + virtual absl::Status GetFunction(const std::string& function_path, + ConcreteFunction** function) = 0; + + // Retrieve a list of child functions from a SavedModel given a starting node. + // 0 is the root node. + virtual absl::Status GetFunctions( + int node_id, + absl::flat_hash_map* functions) = 0; + + // Retrieve a SignatureDefFunction from a SavedModel, using the key of the + // SignatureDef map: + // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 + virtual absl::Status GetSignatureDefFunction( + const std::string& signature_def_key, + SignatureDefFunction** function) = 0; + + virtual SavedModelV2Bundle* GetBundle() = 0; + + virtual ~SavedModelAPI() = default; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/saved_model_utils.h new file mode 100644 index 00000000..9a6108db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ + +// Some internal utility functions for the SavedModelAPI, factored out into a +// separately unit-testable header. 
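As a usage sketch for the SavedModelAPI interface above (assumptions: the TFSavedModelAPI::Load factory and ConcreteFunction::MakeCallOp signatures that appear later in this diff, an already-initialized ImmediateExecutionContext, and a hypothetical function path "my_function"), a client might look roughly like this:

```cpp
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"

namespace {

// Loads a SavedModel from `dir` and builds (but does not run) a call op for
// one of its ConcreteFunctions. "my_function" is a hypothetical object-graph
// path; real models expose their own function names.
absl::Status LoadAndPrepareCall(tensorflow::ImmediateExecutionContext* ctx,
                                const std::string& dir) {
  std::unique_ptr<tensorflow::TFSavedModelAPI> saved_model;
  absl::Status status = tensorflow::TFSavedModelAPI::Load(
      dir, /*tags=*/absl::nullopt, ctx, &saved_model);
  if (!status.ok()) return status;

  tensorflow::ConcreteFunction* fn = nullptr;
  status = saved_model->GetFunction("my_function", &fn);
  if (!status.ok()) return status;

  // MakeCallOp assembles the eager "call" operation for the function;
  // executing the op (and feeding inputs) is runtime-specific and omitted.
  tensorflow::ImmediateOpPtr call_op;
  return fn->MakeCallOp(/*inputs=*/{}, &call_op);
}

}  // namespace
```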
+ +#include +#include + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/asset.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/partially_revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace internal { + +// Load a TensorProto into a tensorflow::Constant. This is similar to the +// constant loading logic in python: +// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437 +absl::Status TensorProtoToConstant(ImmediateExecutionContext* ctx, + const TensorProto& proto, + std::unique_ptr* output); + +// Creates a tensorflow::Variable from a SavedVariable. This is similar to the +// logic in: +// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L407 +// Note that the caller **must assign a value** to the loaded variable. +absl::Status LoadSavedVariable(ImmediateExecutionContext* ctx, + const SavedVariable& variable, + std::unique_ptr* output); + +absl::Status LoadSavedAsset(ImmediateExecutionContext* ctx, + const SavedAsset& asset, + const std::string& saved_model_dir, + absl::Span assets, + std::unique_ptr* output); + +// Creates a TFConcreteFunction from a SavedConcreteFunction. +absl::Status LoadTFConcreteFunction( + const SavedConcreteFunction& saved_concrete_function, + const FunctionDef* function_def, + const std::unordered_map>& + captured_objects, + ImmediateExecutionContext* ctx, std::unique_ptr* out); + +// Flattens `signature` into a vector of TensorSpecProto pointers back into +// `signature`. `signature` must outlive flattened_specs. `signature` must also +// be the input or output signature of a SavedConcreteFunction (i.e. "nested +// structures of tensorspecs"). +absl::Status FlattenSignature( + const StructuredValue& signature, + std::vector* flattened_specs); + +// Find the node id in `object_graph` at location `path`. `path` must be +// a dot-delimited string of object names relative to the root object. If no +// object is found, returns absl::nullopt. +absl::optional FindNodeAtPath(absl::string_view path, + const SavedObjectGraph& object_graph); + +// Maps each node in `graphdef` to its corresponding Attribute Map. +// Callers must ensure that `graphdef` outlives the returned map. +gtl::FlatMap +NodeToAttrMap(const tensorflow::GraphDef& graphdef); + +// Maps the name of each FunctionDef in `library` to its corresponding +// FunctionDef. Callers must ensure `library` outlives the returned map. 
+gtl::FlatMap +FunctionNameToFunctionDefMap(const FunctionDefLibrary& library); + +// Finds the "signatures" object in the object graph, and fills a mapping of +// each signature's name to the corresponding function's node in the object +// graph. +absl::Status GetSignaturesMap(const SavedObjectGraph& saved_objects, + gtl::FlatMap* signatures_map); + +// Validates the `saved_function`. +absl::Status ValidateSingleConcreteFunction( + const SavedFunction& saved_function); + +// Walks through the SavedObjectGraph in metagraph, and restores all nodes +// (except "UserDefinedObjects") with their corresponding type in +// "PartiallyRevivedObjects". +absl::Status PartiallyReviveSavedModelObjects( + const MetaGraphDef& metagraph, ImmediateExecutionContext* context, + const std::string& directory, PartiallyRevivedObjects* objects); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/signature_def_function.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/signature_def_function.h new file mode 100644 index 00000000..71e6a432 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/signature_def_function.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +namespace tensorflow { + +// See tensorflow/cc/experimental/saved_model/public/signature_def_function.h +// for SignatureDefFunction's intended user-facing semantics. +// This class is the "implementation" C++ part of the C++/C/C++ sandwich for +// a SignatureDefFunction. +// Note(bmzhao): Implementation-wise, SignatureDefFunctions are always saved as +// a "BareConcreteFunction", w/o a FunctionSpec, rather than a SavedFunction: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/saved_object_graph.proto#L60 +// Additionally they are guaranteed to be children of the .signatures attribute +// of the root object, where the child object "name" is the signature_def key: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/python/saved_model/signature_serialization.py#L181-L230 +// One of the critical requirements of SignatureDef functions is that their +// inputs and outputs are "named". For example, a `.signatures` function: +// a. 
Requires users to pass: kwargs of all inputs: +// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L119-L126 +// b. Returns a dictionary of named outputs. +// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L153-L161 +// Since SignatureDefFunctions do not have FunctionSpecs, but guarantee the +// dictionary of inputs/outputs, we can parse these dictionaries' keys to obtain +// the input/output names of the SignatureDef: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/meta_graph.proto#L318-L321 +class SignatureDefFunction { + public: + virtual ~SignatureDefFunction() = default; + + // Creates a "Call" Op used to execute the function. + virtual absl::Status MakeCallOp( + absl::Span inputs, + ImmediateOpPtr* out) const = 0; + + virtual const SignatureDefFunctionMetadata& GetFunctionMetadata() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h new file mode 100644 index 00000000..e9cc0b11 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +#include +#include + +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +// SignatureDefParam represents a named Tensor input or output to a +// SignatureDefFunction. 
+class SignatureDefParam { + public: + SignatureDefParam(std::string name, TensorSpec spec); + + const std::string& name() const; + + const TensorSpec& spec() const; + + private: + std::string name_; + TensorSpec spec_; +}; + +class SignatureDefFunctionMetadata { + public: + SignatureDefFunctionMetadata() = default; + SignatureDefFunctionMetadata(std::vector arguments, + std::vector returns); + + const std::vector& arguments() const; + const std::vector& returns() const; + + private: + std::vector arguments_; + std::vector returns_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tensor_spec.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tensor_spec.h new file mode 100644 index 00000000..dcdff890 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tensor_spec.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +// Note(bmzhao): TensorSpec deliberately does not store the "name" from a +// TensorSpecProto. From edloper@, "Names should really be associated with +// parameters, not the tensors inside those parameters. This would be +// inconsistent with the corresponding Python class, but I don't think that's +// necessarily a problem. If it turns out later that we really need a name +// attribute here, we can always add it back in; but let's see how far we can +// get without it." +class TensorSpec { + public: + // Constructs a scalar, DT_FLOAT TensorSpec + TensorSpec(); + + TensorSpec(PartialTensorShape shape, DataType dtype); + + explicit TensorSpec(const TensorSpecProto& proto); + + const PartialTensorShape& shape() const; + DataType dtype() const; + + private: + PartialTensorShape shape_; + DataType dtype_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSOR_SPEC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/test_utils.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/test_utils.h new file mode 100644 index 00000000..f3e6548d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/test_utils.h @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace testing { + +// Creates a DeviceMgr suitable for local tests. +std::unique_ptr<DeviceMgr> CreateTestingDeviceMgr(); + +// Creates an EagerContext suitable for local tests. Does not take ownership +// of `device_mgr`. +EagerContextPtr CreateTestingEagerContext(DeviceMgr* device_mgr); + +// Converts a tensorflow::DataTypeSet to std::vector<DataType>. +// This is useful for tests using GTest's ::testing::ValuesIn, since +// DataTypeSet doesn't fulfill all the constraints of an STL-like iterable. +std::vector<DataType> DataTypeSetToVector(DataTypeSet set); + +// Returns a vector of shapes intended to be "interesting" test cases. +// Currently, this returns scalar, 1D vector, 2D matrix, and 4D tensor shapes. +std::vector<std::vector<int64_t>> InterestingShapes(); + +// Returns a TensorHandle of `dtype` and `shape`, filled with `value`. +// `dtype` must be an integer dtype, float, or double. +// If a TensorHandle cannot be created successfully, this function will +// CHECK fail. This should only be used for testing purposes. +ImmediateTensorHandlePtr CreateTensorHandle(ImmediateExecutionContext* ctx, + DataType dtype, + absl::Span<const int64_t> shape, + int8_t value); + +// Fills a numeric tensor's buffer with `value`. +// dtype must be any integer dtype, float, or double. +void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer, + int8_t value); + +// Checks that the underlying data is equal for the buffers of two numeric +// tensors. +// Note: The caller must check that the dtypes and sizes of the underlying +// buffers are the same before calling this. +// dtype must be any integer dtype, float, or double. +void CheckBufferDataIsEqual(DataType dtype, int64_t num_elements, void* a, + void* b); + +// Converts a TensorHandle to a Tensor, and dies if unsuccessful. This should +// only be used for testing purposes.
+AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle); + +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h new file mode 100644 index 00000000..e3fbfefe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ + +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace testing { + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 0 inputs +StructuredValue ZeroArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 1 input +StructuredValue SingleArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 3 inputs +StructuredValue ThreeArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with no return values +StructuredValue ZeroReturnOutputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with a single tensor output +StructuredValue SingleReturnOutputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with three tensor outputs +StructuredValue ThreeReturnOutputSignature(); + +} // namespace testing +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h new file mode 100644 index 00000000..17c71258 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -0,0 +1,93 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_API_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/revived_objects.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// An implementation of the SavedModelAPI using the TF Eager runtime. See +// https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md +// Conceptually, there are many differences between a tf.function and +// a FunctionDef is executed by the C API. +// 1. A tf.function is polymorphic, meaning it can correspond to multiple +// ConcreteFunctions (of differing shapes, python arguments, etc). A +// FunctionDef corresponds to a single ConcreteFunction. +// 2. A tf.function can take arbitrary python inputs, whereas the FunctionDef +// only accepts tensors. +// 3. A tf.function is a closure that can contain captured inputs, whereas +// FunctionDefs loaded from SavedModels are "functional" (all inputs are +// explicitly passed as arguments). +// The SavedModelAPI only supports loading tf.functions annotated with input +// signatures so that we ensure that there is a 1:1 mapping between tf.function +// -> FunctionDef, and have a guarantee that all inputs are tensors. 
+// (https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/eager/def_function.py#L1167-L1171), +class TFSavedModelAPI : public SavedModelAPI { + public: + absl::Status GetFunction(const std::string& function_path, + ConcreteFunction** function) override; + + absl::Status GetFunctions( + int node_id, + absl::flat_hash_map* functions) override; + + absl::Status GetSignatureDefFunction( + const std::string& signature_def_key, + SignatureDefFunction** function) override; + + static absl::Status Load( + const std::string& directory, + const absl::optional>& tags, + ImmediateExecutionContext* context, + std::unique_ptr* out); + + ~TFSavedModelAPI() override = default; + + absl::Status GetVariable(const std::string& variable_path, + Variable** variable); + + SavedModelV2Bundle* GetBundle() override; + + private: + TFSavedModelAPI(const std::string& directory, SavedModelV2Bundle bundle, + RevivedObjects revived_objects); + + std::string directory_; + SavedModelV2Bundle bundle_; + RevivedObjects revived_objects_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h new file mode 100644 index 00000000..66e0a8f9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" + +// Internal structures used by the SavedModel C API. These are likely to change +// and should not be depended on. + +struct TF_ConcreteFunctionList { + std::vector list; +}; + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h new file mode 100644 index 00000000..bc36b0c6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/concrete_function_type.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" + +// Internal structures used by the SavedModel C API. These are likely to change +// and should not be depended on. + +// It doesn't make sense to wrap tensorflow::ConcreteFunction* in a separate +// struct, since the lifetime of the struct and the raw pointer it wraps would +// be different. Therefore TF_ConcreteFunction* = tensorflow::ConcreteFunction*. +typedef struct TF_ConcreteFunction TF_ConcreteFunction; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ConcreteFunction, TF_ConcreteFunction) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h new file mode 100644 index 00000000..40f05f91 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/function_metadata_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" + +typedef struct TF_FunctionMetadata TF_FunctionMetadata; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::FunctionMetadata, TF_FunctionMetadata) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_FUNCTION_METADATA_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h new file mode 100644 index 00000000..380c3703 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ + +#include + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" + +// Internal structures used by the SavedModel C API. These are likely to change +// and should not be depended on. + +typedef struct TF_SavedModel TF_SavedModel; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SavedModelAPI, TF_SavedModel) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SAVED_MODEL_API_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h new file mode 100644 index 00000000..fa6d0f65 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunctionMetadata, + TF_SignatureDefFunctionMetadata) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h new file mode 100644 index 00000000..ca44dc43 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" + +typedef struct TF_SignatureDefFunction TF_SignatureDefFunction; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunction, + TF_SignatureDefFunction) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h new file mode 100644 index 00000000..6f535110 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_param_list_type.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ + +#include + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefParamList TF_SignatureDefParamList; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(std::vector, + TF_SignatureDefParamList) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_LIST_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h new file mode 100644 index 00000000..fd634bcd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/signature_def_param_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefParam TF_SignatureDefParam; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefParam, TF_SignatureDefParam) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_PARAM_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h new file mode 100644 index 00000000..7284c8a8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/internal/tensor_spec_type.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/tensor_spec.h" + +typedef struct TF_TensorSpec TF_TensorSpec; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::TensorSpec, TF_TensorSpec) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_TENSOR_SPEC_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h new file mode 100644 index 00000000..68f1ece2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ + +// IWYU pragma: begin_exports +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" +// IWYU pragma: end_exports + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/concrete_function.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/concrete_function.h new file mode 100644 index 00000000..ff8a2459 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a Function loaded from a SavedModel. +// TODO(bmzhao): Work together w/srbs@ to make sure this composes w/the +// C++ Unified Eager/Graph API's AbstractFunction +typedef struct TF_ConcreteFunction TF_ConcreteFunction; + +// Returns FunctionMetadata associated with `func`. Metadata's lifetime is +// bound to `func`, which is bound to the TF_SavedModel it was loaded from. +TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( + TF_ConcreteFunction* func); + +// Returns a TFE_Op suitable for executing this function. Caller must provide +// all function inputs in `inputs`, and must not add any additional inputs on +// the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList). +// The caller is responsible for deleting the returned TFE_Op. If op +// construction fails, `status` will be non-OK and the returned pointer will be +// null. 
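+//
+// Example usage (an illustrative sketch, not a normative recipe; it assumes
+// the TFE_* eager C API from tensorflow/c/eager/c_api.h and that `func`,
+// `inputs`, and `num_inputs` were obtained elsewhere):
+//
+//   TF_Status* status = TF_NewStatus();
+//   TFE_Op* call_op =
+//       TF_ConcreteFunctionMakeCallOp(func, inputs, num_inputs, status);
+//   if (TF_GetCode(status) == TF_OK) {
+//     TFE_TensorHandle* retvals[1] = {nullptr};  // size to the function's arity
+//     int num_retvals = 1;
+//     // Do not call TFE_OpAddInput here; all inputs were already provided.
+//     TFE_Execute(call_op, retvals, &num_retvals, status);
+//     TFE_DeleteOp(call_op);  // the caller owns the returned op
+//   }
+//   TF_DeleteStatus(status);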
+// TODO(bmzhao): Remove this function in a subsequent change; Design + implement +// a Function Execution interface for ConcreteFunction that accepts a tagged +// union of types (tensorflow::Value). This effectively requires moving much of +// the implementation of function.py/def_function.py to C++, and exposing a +// high-level API here. A strawman for what this interface could look like: +// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value* +// inputs, int num_inputs, TF_Status* status); +TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionMakeCallOp( + TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/concrete_function_list.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/concrete_function_list.h new file mode 100644 index 00000000..e3554675 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/concrete_function_list.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that is acts like a list of TF_ConcreteFunction pointers. +typedef struct TF_ConcreteFunctionList TF_ConcreteFunctionList; + +// Returns the size of `list`. +TF_CAPI_EXPORT extern size_t TF_ConcreteFunctionListSize( + TF_ConcreteFunctionList* list); + +// Returns the `i`th TF_ConcreteFunction in the list. +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_ConcreteFunctionListGet( + TF_ConcreteFunctionList* list, int i); + +// Deletes `list`. +TF_CAPI_EXPORT extern void TF_DeleteConcreteFunctionList( + TF_ConcreteFunctionList* list); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/function_metadata.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/function_metadata.h new file mode 100644 index 00000000..83ca3c73 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/function_metadata.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_ + +#include "tensorflow/c/c_api_macros.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type used to store any metadata associated with a function. +typedef struct TF_FunctionMetadata TF_FunctionMetadata; + +// TODO(bmzhao): Add getters for fields as we determine what metadata +// we want to expose. + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/saved_model_api.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/saved_model_api.h new file mode 100644 index 00000000..cef7fe86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/saved_model_api.h @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/c/tf_status.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type representing a Tensorflow "SavedModel" +// (https://www.tensorflow.org/guide/saved_model) that we always pass by pointer +// to achieve ABI stability. +typedef struct TF_SavedModel TF_SavedModel; + +// Load a SavedModel from `dirname`. We expect the SavedModel to contain a +// single Metagraph (as for those exported from TF2's `tf.saved_model.save`). +// +// Params: +// dirname - A directory filepath that the SavedModel is at. +// ctx - A TFE_Context containing optional load/TF runtime options. +// `ctx` must outlive the returned TF_SavedModel pointer. +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a newly created +// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel. 
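+//
+// Example usage (an illustrative sketch; it assumes the TFE_* eager C API and
+// a hypothetical "/path/to/saved_model" directory and function path):
+//
+//   TF_Status* status = TF_NewStatus();
+//   TFE_ContextOptions* opts = TFE_NewContextOptions();
+//   TFE_Context* ctx = TFE_NewContext(opts, status);
+//   TFE_DeleteContextOptions(opts);
+//   TF_SavedModel* model =
+//       TF_LoadSavedModel("/path/to/saved_model", ctx, status);
+//   if (TF_GetCode(status) == TF_OK) {
+//     TF_ConcreteFunction* fn = TF_GetSavedModelConcreteFunction(
+//         model, "my_module.my_fn", status);
+//     // ... use `fn` only while `model` is alive ...
+//     TF_DeleteSavedModel(model);
+//   }
+//   TFE_DeleteContext(ctx);
+//   TF_DeleteStatus(status);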
+TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModel(const char* dirname, + TFE_Context* ctx, + TF_Status* status); + +// Load a SavedModel from `dirname`. +// +// Params: +// dirname - A directory filepath that the SavedModel is at. +// ctx - A TFE_Context containing optional load/TF runtime options. +// `ctx` must outlive the returned TF_SavedModel pointer. +// tags - char* array of SavedModel tags. We will load the metagraph matching +// the tags. +// tags_len - number of elements in the `tags` array. +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a newly created +// TF_SavedModel instance. It must be deleted by calling TF_DeleteSavedModel. +TF_CAPI_EXPORT extern TF_SavedModel* TF_LoadSavedModelWithTags( + const char* dirname, TFE_Context* ctx, const char* const* tags, + int tags_len, TF_Status* status); + +// Deletes a TF_SavedModel, and frees any resources owned by it. +TF_CAPI_EXPORT extern void TF_DeleteSavedModel(TF_SavedModel* model); + +// Retrieve a function from the TF2 SavedModel via function path. +// +// Params: +// model - The TF2 SavedModel to load a function from. +// function_path - A string containing the path from the root saved python +// object to a tf.function method. +// TODO(bmzhao): Add a detailed example of this with a +// python tf.module before moving this out of experimental. +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a +// TF_ConcreteFunction instance. The lifetime of this instance is +// "conceptually" bound to `model`. Once `model` is deleted, all +// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. +TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( + TF_SavedModel* model, const char* function_path, TF_Status* status); + +// Retrieve a function from the TF SavedModel via a SignatureDef key. +// +// Params: +// model - The SavedModel to load a function from. +// signature_def_key - The string key of the SignatureDef map of a SavedModel: +// https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 +// status - Set to OK on success and an appropriate error on failure. +// Returns: +// If status is not OK, returns nullptr. Otherwise, returns a +// TF_SignatureDefFunction instance. Once `model` is deleted, all +// `TF_SignatureDefFunctions` retrieved from it are invalid, and have been +// deleted. +TF_CAPI_EXPORT extern TF_SignatureDefFunction* +TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, + const char* signature_def_key, + TF_Status* status); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SAVED_MODEL_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_function.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_function.h new file mode 100644 index 00000000..16471fdc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_function.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a SignatureDefFunction loaded from a +// SavedModel. +typedef struct TF_SignatureDefFunction TF_SignatureDefFunction; + +// Returns FunctionMetadata associated with `func`. Metadata's lifetime is +// bound to `func`, which is bound to the TF_SavedModel it was loaded from. +TF_CAPI_EXPORT extern TF_SignatureDefFunctionMetadata* +TF_SignatureDefFunctionGetMetadata(TF_SignatureDefFunction* func); + +// Returns a TFE_Op suitable for executing this function. Caller must provide +// all function inputs in `inputs`, and must not add any additional inputs on +// the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList). +// The caller is responsible for deleting the returned TFE_Op. If op +// construction fails, `status` will be non-OK and the returned pointer will be +// null. +TF_CAPI_EXPORT extern TFE_Op* TF_SignatureDefFunctionMakeCallOp( + TF_SignatureDefFunction* func, TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h new file mode 100644 index 00000000..b7a7f67e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param_list.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a SignatureDefFunction loaded from a +// SavedModel. +typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; + +// Retrieves the arguments of the SignatureDefFunction. The caller is not +// responsible for freeing the returned pointer. +TF_CAPI_EXPORT extern const TF_SignatureDefParamList* +TF_SignatureDefFunctionMetadataArgs( + const TF_SignatureDefFunctionMetadata* list); + +// Retrieves the returns of the SignatureDefFunction. The caller is not +// responsible for freeing the returned pointer. +TF_CAPI_EXPORT extern const TF_SignatureDefParamList* +TF_SignatureDefFunctionMetadataReturns( + const TF_SignatureDefFunctionMetadata* list); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_param.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_param.h new file mode 100644 index 00000000..82993d7f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_param.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/tensor_spec.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that containing metadata of an input/output of a +// TF_SignatureDefFunction loaded from a SavedModel. +typedef struct TF_SignatureDefParam TF_SignatureDefParam; + +// Returns the name of the given parameter. The caller is not responsible for +// freeing the returned char*. +TF_CAPI_EXPORT extern const char* TF_SignatureDefParamName( + const TF_SignatureDefParam* param); + +// Returns the TensorSpec associated with the given parameter. The caller is +// not reponsible for freeing the returned TF_TensorSpec*. 
+TF_CAPI_EXPORT extern const TF_TensorSpec* TF_SignatureDefParamTensorSpec( + const TF_SignatureDefParam* param); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h new file mode 100644 index 00000000..0cb3a0d6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/signature_def_param_list.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_param.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that containing metadata of an input/output of a +// ConcreteFunction loaded from a SavedModel. +typedef struct TF_SignatureDefParamList TF_SignatureDefParamList; + +// Returns the size of `list`. +TF_CAPI_EXPORT extern size_t TF_SignatureDefParamListSize( + const TF_SignatureDefParamList* list); + +// Returns the `i`th TF_SignatureDefParam in the list. +TF_CAPI_EXPORT extern const TF_SignatureDefParam* TF_SignatureDefParamListGet( + const TF_SignatureDefParamList* list, int i); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_PARAM_LIST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/tensor_spec.h b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/tensor_spec.h new file mode 100644 index 00000000..82972ef7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/saved_model/public/tensor_spec.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_shape.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type corresponding to TensorSpec +typedef struct TF_TensorSpec TF_TensorSpec; + +// Returns the dtype associated with the TensorSpec. +TF_CAPI_EXPORT extern TF_DataType TF_TensorSpecDataType( + const TF_TensorSpec* spec); + +// Returns the shape associated with the TensorSpec. The returned Shape is not +// owned by the caller. Caller must not call TF_DeleteShape on the returned +// shape. +TF_CAPI_EXPORT extern const TF_Shape* TF_TensorSpecShape( + const TF_TensorSpec* spec); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSOR_SPEC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor.h b/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor.h new file mode 100644 index 00000000..eebbae6c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor.h @@ -0,0 +1,536 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ +#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_status.h" + +// -------------------------------------------------------------------------- +// C API for StreamExecutor. The API is under active development and eventually +// should allow registering a pluggable device with TensorFlow. +// +// Conventions: +// * Struct prefix indicates whether struct fields should be filled by the +// plugin or core implementation: +// * SE_ : set/filled by core unless explicitly marked otherwise. +// * SP_ : set/filled by plugin unless explicitly marked otherwise. +// * We use `struct_size` for version checking. It is exempt from the `SE/SP` +// rule above and should be set both by core and the plugin. +// * For example, `create_device` function receives `SP_Device*` as input +// with `struct_size` populated by core. The plugin is responsible for +// setting `struct_size` as well, along with all other fields. +// * Refer to "TensorFlow Versioning Strategy" section at +// https://github.com/tensorflow/community/pull/257/files. +// * Note that the API is still under active development and doesn't have +// versioning guarantees yet. +// * `void* ext` is a free-form field that can be populated by +// a plugin in `SP_*` structs or potential future extension points in `SE_` +// structs. 
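+//
+// A concrete illustration of the `struct_size` convention (a sketch, assuming
+// TF_OFFSET_OF_END(Type, member) from tensorflow/c/c_api_macros.h evaluates to
+// the offset of `member` plus its size):
+//
+//   // Plugin-side check before reading a field that a (possibly older) core
+//   // runtime might not have populated.
+//   void use_device(const SP_Device* device) {
+//     if (device->struct_size >= SP_DEVICE_STRUCT_SIZE) {
+//       // Every field up to and including `pci_bus_id` was filled in
+//       // (individual optional fields may still be null).
+//     }
+//   }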
+// +// Example usage: +// +// /* Sample TensorFlow code below, exact implementation might differ. */ +// // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule +// // above and should be set both by core and the plugin." +// SP_Device device { SP_DEVICE_STRUCT_SIZE }; +// SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE } ; +// params.device = &device; +// +// /* Plugin code below */ +// constexpr char DEVICE_NAME[] = "MY_DEVICE"; +// constexpr char DEVICE_TYPE[] = "GPU"; +// +// void create_device(const SP_Platform* platform, +// SE_CreateDeviceParams* params, TF_Status* status) { +// // Custom actions based on TensorFlow's view of SP_Device. +// OnTFDeviceView(params->device->struct_size); +// params->device = { SP_DEVICE_STRUCT_SIZE }; +// params->device->device_handle = get_my_device_handle(device->ordinal); +// params->device->ordinal = params->ordinal; +// ... +// } +// +// void destroy_device(const SP_Platform* platform, SP_Device* device) { +// delete_my_device_handle(device->device_handle); +// } +// +// void SE_InitPlugin( +// SE_PlatformRegistrationParams* params, +// TF_Status* status) { +// params->platform = { SP_PLATFORM_STRUCT_SIZE }; +// // Values such as `name` and `type` must outlive SE_InitPlugin call. +// params->platform->name = DEVICE_NAME; +// params->platform->type = DEVICE_TYPE; +// params->platform_fns->get_device_count = get_device_count; +// params->platform_fns->create_device = create_device; +// params->platform_fns->destroy_device = destroy_device; +// ... +// } + +#define SE_MAJOR 0 +#define SE_MINOR 0 +#define SE_PATCH 1 + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct SP_Stream_st* SP_Stream; +typedef struct SP_Event_st* SP_Event; +typedef struct SP_Timer_st* SP_Timer; +// Takes `callback_arg` passed to `host_callback` as the first argument. +typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const); + +typedef struct SP_TimerFns { + size_t struct_size; + void* ext; // reserved for future use + uint64_t (*nanoseconds)(SP_Timer timer); +} SP_TimerFns; + +#define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds) + +typedef struct SP_AllocatorStats { + size_t struct_size; + int64_t num_allocs; + int64_t bytes_in_use; + int64_t peak_bytes_in_use; + int64_t largest_alloc_size; + + int8_t has_bytes_limit; + int64_t bytes_limit; + + int64_t bytes_reserved; + int64_t peak_bytes_reserved; + + int8_t has_bytes_reservable_limit; + int64_t bytes_reservable_limit; + + int64_t largest_free_block_bytes; +} SP_AllocatorStats; + +#define SP_ALLOCATORSTATS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes) + +// Potential states for an SP_Event. If `poll_for_status` returns anything aside +// from kPending or kComplete, an error has occurred; kUnknown is a bad state. +typedef enum SE_EventStatus { + SE_EVENT_UNKNOWN, + SE_EVENT_ERROR, + SE_EVENT_PENDING, + SE_EVENT_COMPLETE, +} SE_EventStatus; + +// Memory allocation information. +// This matches DeviceMemoryBase defined here: +// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/compiler/xla/stream_executor/device_memory.h;l=57 +typedef struct SP_DeviceMemoryBase { + size_t struct_size; + void* ext; // Reserved for future use + // Platform-dependent value representing allocated memory. + // Note that the pointer does not have to be to the virtual address itself. + void* opaque; + uint64_t size; // Size in bytes of this allocation. 
+ uint64_t payload; // Value for plugin's use +} SP_DeviceMemoryBase; + +#define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload) + +typedef struct SP_Device { + size_t struct_size; + void* ext; // free-form data set by plugin + int32_t ordinal; // device index + + // Device vendor can store handle to their device representation + // here. + void* device_handle; + + // [Optional] + // Device hardware name. Used for printing. + // Must be null-terminated. + const char* hardware_name; + + // [Optional] + // Device vendor name. Used for printing. + // Must be null-terminated. + const char* device_vendor; + + // [Optional] + // Returns the PCI bus identifier for this device, of the form + // [domain]:[bus]:[device].[function] + // where domain number is usually 0000. + // Example: 0000:00:02.1 + // For more information see: + // https://en.wikipedia.org/wiki/PCI_configuration_space + // https://www.oreilly.com/library/view/linux-device-drivers/0596005903/ch12.html + // Used for printing. Must be null-terminated. + const char* pci_bus_id; +} SP_Device; + +#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, pci_bus_id) + +typedef struct SE_CreateDeviceParams { + size_t struct_size; + void* ext; // reserved for future use + int32_t ordinal; // device index + + SP_Device* device; // Input/output, struct_size set by TF for plugin to read. + // Subsequently plugin fills the entire struct. +} SE_CreateDeviceParams; + +#define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateDeviceParams, device) + +typedef struct SP_DeviceFns { + size_t struct_size; + void* ext; // reserved for future use + + // [Optional] + // Returns the NUMA node associated with this device, for use in + // determining socket locality. If the NUMA node could not be determined, -1 + // is returned. + // Negative values are treated as "unset". + int32_t (*get_numa_node)(const SP_Device* device); + + // [Optional] + // Device's memory bandwidth in bytes/sec. (This is for reads/writes to/from + // the device's own memory, not for transfers between the host and device.) + // Negative values are treated as "unset". + int64_t (*get_memory_bandwidth)(const SP_Device* device); + + // [Optional] + // Estimate of average number of floating point operations per second for + // this device * 10e-9. + // Negative values are treated as "unset". + double (*get_gflops)(const SP_Device* device); +} SP_DeviceFns; + +#define SP_DEVICE_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_DeviceFns, get_gflops) + +typedef struct SE_CreateDeviceFnsParams { + size_t struct_size; + void* ext; // reserved for future use + + SP_DeviceFns* device_fns; // output, to be filled by plugin +} SE_CreateDeviceFnsParams; + +#define SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateDeviceFnsParams, device_fns) + +typedef struct SP_StreamExecutor { + size_t struct_size; + void* ext; // reserved for future use + + /*** ALLOCATION CALLBACKS ***/ + // Synchronously allocates `size` bytes on the underlying platform and returns + // `SP_DeviceMemoryBase` representing that allocation. In the case of failure, + // nullptr is returned. + // `memory_space` is reserved for a potential future usage and should be set + // to 0. + void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space, + SP_DeviceMemoryBase* mem); + + // Deallocate the device memory previously allocated via this interface. + // Deallocation of a nullptr-representative value is permitted. 
+ void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory); + + // Allocates a region of host memory and registers it with the platform API. + // Memory allocated in this manner is required for use in asynchronous memcpy + // operations, such as `memcpy_dtoh`. + void* (*host_memory_allocate)(const SP_Device* device, uint64_t size); + + // Deallocates a region of host memory allocated by `host_memory_allocate`. + void (*host_memory_deallocate)(const SP_Device* device, void* mem); + + // Allocates unified memory space of the given size, if supported. Unified + // memory support should be added by setting `supports_unified_memory` field + // in `SP_Platform`. + void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes); + + // Deallocates unified memory space previously allocated with + // `unified_memory_allocate`. Unified + // memory support should be added by setting `supports_unified_memory` field + // in `SP_Platform`. + void (*unified_memory_deallocate)(const SP_Device* device, void* location); + + // Fills SP_AllocatorStats with allocator statistics, if it is available. + // If it is not available, return false. + TF_Bool (*get_allocator_stats)(const SP_Device* device, + SP_AllocatorStats* stats); + // Fills the underlying device memory usage information, if it is + // available. If it is not available (false is returned), free/total need not + // be initialized. + TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free, + int64_t* total); + + /*** STREAM CALLBACKS ***/ + // Creates SP_Stream. This call should also allocate stream + // resources on the underlying platform and initializes its + // internals. + void (*create_stream)(const SP_Device* device, SP_Stream* stream, + TF_Status* status); + + // Destroys SP_Stream and deallocates any underlying resources. + void (*destroy_stream)(const SP_Device* device, SP_Stream stream); + + // Causes `dependent` to not begin execution until `other` has finished its + // last-enqueued work. + void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent, + SP_Stream other, TF_Status* status); + + // Without blocking the device, retrieve the current stream status. + void (*get_stream_status)(const SP_Device* device, SP_Stream stream, + TF_Status* status); + + /*** EVENT CALLBACKS ***/ + // Create SP_Event. Performs platform-specific allocation and initialization + // of an event. + void (*create_event)(const SP_Device* device, SP_Event* event, + TF_Status* status); + + // Destroy SE_Event and perform any platform-specific deallocation and + // cleanup of an event. + void (*destroy_event)(const SP_Device* device, SP_Event event); + + // Requests the current status of the event from the underlying platform. + SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event); + // Inserts the specified event at the end of the specified stream. + void (*record_event)(const SP_Device* device, SP_Stream stream, + SP_Event event, TF_Status* status); + + // Wait for the specified event at the end of the specified stream. + void (*wait_for_event)(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status); + + /*** TIMER CALLBACKS ***/ + // Creates SP_Timer. Allocates timer resources on the underlying platform + // and initializes its internals, setting `timer` output variable. Sets + // values in `timer_fns` struct. 
+ void (*create_timer)(const SP_Device* device, SP_Timer* timer, + TF_Status* status); + + // Destroy timer and deallocates timer resources on the underlying platform. + void (*destroy_timer)(const SP_Device* device, SP_Timer timer); + + // Records a start event for an interval timer. + void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, + TF_Status* status); + + // Records a stop event for an interval timer. + void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, + TF_Status* status); + + /*** MEMCPY CALLBACKS ***/ + // Enqueues a memcpy operation onto stream, with a host destination location + // `host_dst` and a device memory source, with target size `size`. + void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Enqueues a memcpy operation onto stream, with a device destination + // location and a host memory source, with target size `size`. + void (*memcpy_htod)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, const void* host_src, + uint64_t size, TF_Status* status); + + // Enqueues a memcpy operation onto stream, with a device destination + // location and a device memory source, with target size `size`. + void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is + // copied from the device source to the host destination. + void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is + // copied from the host source to the device destination. + void (*sync_memcpy_htod)(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is copied from the + // device source to the device destination. + void (*sync_memcpy_dtod)(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Causes the host code to synchronously wait for the event to complete. + void (*block_host_for_event)(const SP_Device* device, SP_Event event, + TF_Status* status); + + // [Optional] + // Causes the host code to synchronously wait for operations entrained onto + // stream to complete. Effectively a join on the asynchronous device + // operations enqueued on the stream before this program point. + // If not set, then corresponding functionality will be implemented + // by registering an event on the `stream` and waiting for it using + // `block_host_for_event`. + void (*block_host_until_done)(const SP_Device* device, SP_Stream stream, + TF_Status* status); + + // Synchronizes all activity occurring in the StreamExecutor's context (most + // likely a whole device). + void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status); + + // Zero out `size` bytes starting at the location. + void (*mem_zero)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* location, uint64_t size, + TF_Status* status); + + // Set the 8-bit patterns starting at the location with `size` bytes. 
+ void (*memset)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* location, uint8_t pattern, uint64_t size, + TF_Status* status); + + // Set the 32-bit patterns starting at the location with `size` bytes. + void (*memset32)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* location, uint32_t pattern, + uint64_t size, TF_Status* status); + + // Enqueues on a stream a user-specified function to be run on the host. + // `callback_arg` should be passed as the first argument to `callback_fn`. + TF_Bool (*host_callback)(const SP_Device* device, SP_Stream stream, + SE_StatusCallbackFn callback_fn, void* callback_arg); +} SP_StreamExecutor; + +#define SP_STREAMEXECUTOR_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_StreamExecutor, host_callback) + +typedef struct SE_CreateStreamExecutorParams { + size_t struct_size; + void* ext; // reserved for future use + + SP_StreamExecutor* stream_executor; // output, to be filled by plugin +} SE_CreateStreamExecutorParams; + +#define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor) + +typedef struct SP_Platform { + size_t struct_size; + + void* ext; // free-form data set by plugin + + // Platform name (also referred to as subtype), for example MY_DEVICE. + // The name must start with a capital letter and consist of + // capital letters and underscores. + // Must be null-terminated. + const char* name; + + // Device type name, for example GPU. Must be null-terminated. + // The name must start with a capital letter and consist of + // capital letters and underscores. + const char* type; + + // Whether this platform supports unified memory. + // Unified memory is a single memory address space accessible from any device. + TF_Bool supports_unified_memory; + + // Whether to wrap allocator for this device with an allocator that uses BFC + // (best-fit with coalescing) strategy. + TF_Bool use_bfc_allocator; + + // Whether to force the memory allocations to grow over time instead of + // allocating it all at once. When this is set to true, the value of + // allow_growth is ignored. + TF_Bool force_memory_growth; +} SP_Platform; + +#define SP_PLATFORM_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_Platform, force_memory_growth) + +typedef struct SP_PlatformFns { + size_t struct_size; + + void* ext; // reserved for future use + + // Callbacks for getting device count + void (*get_device_count)(const SP_Platform* platform, int* device_count, + TF_Status* status); + // Callbacks for creating/destroying SP_Device. + void (*create_device)(const SP_Platform* platform, + SE_CreateDeviceParams* params, TF_Status* status); + + // Clean up fields inside SP_Device that were allocated + // by the plugin. `device` itself should not be deleted here. + void (*destroy_device)(const SP_Platform* platform, SP_Device* device); + + // Callbacks for creating/destroying SP_DeviceFns. + void (*create_device_fns)(const SP_Platform* platform, + SE_CreateDeviceFnsParams* params, + TF_Status* status); + + // Clean up fields inside SP_DeviceFns that were allocated + // by the plugin. `device_fns` itself should not be deleted here. + void (*destroy_device_fns)(const SP_Platform* platform, + SP_DeviceFns* device_fns); + + // Callbacks for creating/destroying SP_StreamExecutor. + void (*create_stream_executor)(const SP_Platform* platform, + SE_CreateStreamExecutorParams* params, + TF_Status* status); + // Clean up fields inside SP_StreamExecutor that were allocated + // by the plugin. 
`stream_executor` itself should not be deleted here. + void (*destroy_stream_executor)(const SP_Platform* platform, + SP_StreamExecutor* stream_executor); + + // Callbacks for creating/destroying SP_TimerFns. + void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer, + TF_Status* status); + + void (*destroy_timer_fns)(const SP_Platform* platform, + SP_TimerFns* timer_fns); +} SP_PlatformFns; + +#define SP_PLATFORM_FNS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns) + +typedef struct SE_PlatformRegistrationParams { + size_t struct_size; + void* ext; // reserved for future use + + // StreamExecutor C API version. + int32_t major_version; + int32_t minor_version; + int32_t patch_version; + + SP_Platform* platform; // output, set by plugin + SP_PlatformFns* platform_fns; // output, set by plugin + // Clean up fields inside SP_Platform that were allocated + // by the plugin. `platform` itself should not be deleted here. + void (*destroy_platform)(SP_Platform* platform); // out, set by plugin + void (*destroy_platform_fns)( + SP_PlatformFns* platform_fns); // out, set by plugin +} SE_PlatformRegistrationParams; + +#define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns) + +void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor_internal.h new file mode 100644 index 00000000..b8217ea3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -0,0 +1,336 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Classes and utilities that work with StreamExecutor C API for internal use. +// This includes functions used for device registration and interfaces needed +// for testing. 
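+//
+// Example (an illustrative sketch; `MyTestInitPlugin` is a hypothetical plugin
+// entry point with the SE_InitPlugin signature declared in stream_executor.h):
+//
+//   void MyTestInitPlugin(SE_PlatformRegistrationParams* params,
+//                         TF_Status* status) {
+//     // Fill params->platform and params->platform_fns here, following the
+//     // SE_InitPlugin example in stream_executor.h.
+//   }
+//
+//   std::string device_type, platform_name;
+//   absl::Status s = stream_executor::InitStreamExecutorPlugin(
+//       &MyTestInitPlugin, &device_type, &platform_name);
+//   // On success, `device_type` and `platform_name` name the registered
+//   // platform.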
+#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
+#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
+
+#include
+#include
+#include
+#include
+
+#include "absl/functional/any_invocable.h"
+#include "absl/status/status.h"
+#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/stream_executor/event.h"
+#include "xla/stream_executor/executor_cache.h"
+#include "xla/stream_executor/platform.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_common.h"
+#include "xla/stream_executor/stream_executor.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace stream_executor {
+
+// Plugin initialization function that a device plugin
+// must define.
+typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
+                               TF_Status* const);
+
+// Registers StreamExecutor platform. `device_type` and `platform_name` are
+// output parameters.
+absl::Status InitStreamExecutorPlugin(void* dso_handle,
+                                      std::string* device_type,
+                                      std::string* platform_name);
+
+// Allow registering a StreamExecutor plugin using a function (used for
+// testing).
+absl::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
+                                      std::string* device_type,
+                                      std::string* platform_name);
+
+// Converts DeviceMemoryBase to a C struct.
+inline SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
+  SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
+  // `opaque` field inside SP_DeviceMemoryBase is not const.
+  // Therefore, we need to cast away the constness before setting it.
+  device_memory_base.opaque = const_cast<void*>(mem->opaque());
+  device_memory_base.size = mem->size();
+  device_memory_base.payload = mem->payload();
+  return device_memory_base;
+}
+
+// This file implements core stream executor base classes in terms of
+// the C API defined in stream_executor.h. A class "CSomething" represents a
+// "Something" that can be manipulated via calls in the C interface.
+class CPlatform : public Platform {
+ public:
+  explicit CPlatform(SP_Platform platform,
+                     void (*destroy_platform)(SP_Platform*),
+                     SP_PlatformFns platform_fns,
+                     void (*destroy_platform_fns)(SP_PlatformFns*),
+                     SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
+                     SP_TimerFns timer_fns);
+  ~CPlatform() override;
+
+  Id id() const override { return const_cast<int*>(&plugin_id_value_); }
+  const std::string& Name() const override { return name_; }
+  int VisibleDeviceCount() const override {
+    int visible_device_count = 0;
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    platform_fns_.get_device_count(&platform_, &visible_device_count,
+                                   c_status.get());
+    if (TF_GetCode(c_status.get()) != TF_OK) {
+      LOG(ERROR) << TF_Message(c_status.get());
+      return 0;
+    }
+    return visible_device_count;
+  }
+  bool UseBfcAllocator() const { return platform_.use_bfc_allocator; }
+  bool ForceMemoryGrowth() const { return platform_.force_memory_growth; }
+  absl::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
+      int ordinal) const override;
+  absl::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+  absl::StatusOr<StreamExecutor*> FindExisting(int ordinal) override;
+
+ private:
+  // Returns a device constructed with the ordinal without
+  // looking in or storing to the Platform's executor cache.
+  // Ownership IS transferred to the caller.
+  absl::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+      int ordinal);
+
+  SP_Platform platform_;
+  void (*destroy_platform_)(SP_Platform*);
+  SP_PlatformFns platform_fns_;
+  void (*destroy_platform_fns_)(SP_PlatformFns*);
+  SP_DeviceFns device_fns_;
+  SP_StreamExecutor stream_executor_;
+  SP_TimerFns timer_fns_;
+  const std::string name_;
+  int plugin_id_value_;
+  stream_executor::ExecutorCache executor_cache_;
+};
+
+class CEvent : public Event {
+ public:
+  CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
+      : device_(device),
+        stream_executor_(stream_executor),
+        event_handle_(nullptr) {}
+  ~CEvent() override { Destroy(); }
+
+  Event::Status PollForStatus() override {
+    SE_EventStatus event_status =
+        stream_executor_->get_event_status(device_, event_handle_);
+
+    switch (event_status) {
+      case SE_EVENT_ERROR:
+        return Event::Status::kError;
+      case SE_EVENT_PENDING:
+        return Event::Status::kPending;
+      case SE_EVENT_COMPLETE:
+        return Event::Status::kComplete;
+      default:
+        return Event::Status::kUnknown;
+    }
+  }
+
+  absl::Status Create() {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    stream_executor_->create_event(device_, &event_handle_, c_status.get());
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+
+  absl::Status Record(SP_Stream stream_handle) {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    stream_executor_->record_event(device_, stream_handle, event_handle_,
+                                   c_status.get());
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+
+  void Destroy() {
+    if (event_handle_ != nullptr) {
+      stream_executor_->destroy_event(device_, event_handle_);
+      event_handle_ = nullptr;
+    }
+  }
+
+  SP_Event Handle() { return event_handle_; }
+
+ private:
+  SP_Device* device_;
+  SP_StreamExecutor* stream_executor_;
+  SP_Event event_handle_;
+};
+
+class CStream : public StreamCommon {
+ public:
+  CStream(SP_Device* device, SP_StreamExecutor* stream_executor,
+          StreamExecutor* executor)
+      : StreamCommon(executor),
+        device_(device),
+        stream_executor_(stream_executor),
+        stream_handle_(nullptr) {}
+  ~CStream() override {
+    BlockHostUntilDone().IgnoreError();
+    parent()->DeallocateStream(this);
+    Destroy();
+  }
+
+  absl::Status Create() {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+
+  void Destroy() {
+    if (stream_handle_ != nullptr) {
+      stream_executor_->destroy_stream(device_, stream_handle_);
+      stream_handle_ = nullptr;
+    }
+  }
+  absl::Status RefreshStatus() override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    stream_executor_->get_stream_status(device_, stream_handle_,
+                                        c_status.get());
+    absl::Status status = tensorflow::StatusFromTF_Status(c_status.get());
+    CheckStatus(status);
+    return status;
+  }
+
+  absl::Status RecordEvent(Event* event) override {
+    return static_cast<CEvent*>(event)->Record(stream_handle_);
+  }
+
+  absl::Status BlockHostUntilDone() override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    SP_Stream stream_handle = Handle();
+
+    // If `block_host_until_done` is set, use it.
+    if (stream_executor_->block_host_until_done != nullptr) {
+      stream_executor_->block_host_until_done(device_, stream_handle,
+                                              c_status.get());
+      return tensorflow::StatusFromTF_Status(c_status.get());
+    }
+    // Create and record an event and then wait for it.
+    SP_Event event_handle;
+    stream_executor_->create_event(device_, &event_handle, c_status.get());
+    TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
+    stream_executor_->record_event(device_, stream_handle, event_handle,
+                                   c_status.get());
+    absl::Status s = tensorflow::StatusFromTF_Status(c_status.get());
+    if (!s.ok()) {
+      stream_executor_->destroy_event(device_, event_handle);
+      return s;
+    }
+    stream_executor_->block_host_for_event(device_, event_handle,
+                                           c_status.get());
+    stream_executor_->destroy_event(device_, event_handle);
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+
+  absl::Status WaitFor(Stream* other) override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    SP_Stream other_handle = static_cast<CStream*>(other)->Handle();
+    stream_executor_->create_stream_dependency(device_, stream_handle_,
+                                               other_handle, c_status.get());
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+  absl::Status WaitFor(Event* event) override {
+    SP_Event event_handle = static_cast<CEvent*>(event)->Handle();
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    stream_executor_->wait_for_event(device_, stream_handle_, event_handle,
+                                     c_status.get());
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+  absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
+    stream_executor_->mem_zero(device_, stream_handle_, &device_mem, size,
+                               c_status.get());
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+  absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern,
+                        uint64_t size) override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
+    stream_executor_->memset32(device_, stream_handle_, &device_mem, pattern,
+                               size, c_status.get());
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+  absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src,
+                      uint64_t size) override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
+    stream_executor_->memcpy_htod(device_, stream_handle_, &device_mem_dst,
+                                  host_src, size, c_status.get());
+    if (TF_GetCode(c_status.get()) != TF_OK) {
+      LOG(ERROR) << TF_Message(c_status.get());
+    }
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+  absl::Status Memcpy(DeviceMemoryBase* gpu_dst,
+                      const DeviceMemoryBase& gpu_src, uint64_t size) override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
+    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
+    stream_executor_->memcpy_dtod(device_, stream_handle_, &device_mem_dst,
+                                  &device_mem_src, size, c_status.get());
+    if (TF_GetCode(c_status.get()) != TF_OK) {
+      LOG(ERROR) << TF_Message(c_status.get());
+    }
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
+  absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src,
+                      uint64_t size) override {
+    tensorflow::TF_StatusPtr c_status(TF_NewStatus());
+    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
+    stream_executor_->memcpy_dtoh(device_, stream_handle_, host_dst,
+                                  &device_mem_src, size, c_status.get());
+    if (TF_GetCode(c_status.get()) != TF_OK) {
+      LOG(ERROR) << TF_Message(c_status.get());
+    }
+    return tensorflow::StatusFromTF_Status(c_status.get());
+  }
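+
+  // On the plugin side, `host_callback` (see SP_StreamExecutor) is expected to
+  // enqueue host work on the stream and eventually invoke
+  // `callback_fn(callback_arg, status)` exactly once. A hypothetical,
+  // synchronous plugin implementation is sketched here; the name and the
+  // status-ownership convention are assumptions, not part of this header:
+  //
+  //   TF_Bool MyHostCallback(const SP_Device* device, SP_Stream stream,
+  //                          SE_StatusCallbackFn callback_fn,
+  //                          void* callback_arg) {
+  //     TF_Status* status = TF_NewStatus();
+  //     callback_fn(callback_arg, status);
+  //     TF_Bool ok = TF_GetCode(status) == TF_OK;
+  //     TF_DeleteStatus(status);
+  //     return ok;
+  //   }
+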
+  // Wrapper that allows passing std::function across C API.
+  struct HostCallbackContext {
+    absl::AnyInvocable<absl::Status() &&> callback;
+  };
+
+  // This wrapper allows calling `HostCallbackContext::callback` across C API.
+  // This function matches `SE_StatusCallbackFn` signature and will be passed as
+  // `callback_fn` to `host_callback` in `SP_StreamExecutor`.
+  static void HostCallbackTrampoline(void* ctx, TF_Status* status) {
+    HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
+    absl::Status s = std::move(host_ctx->callback)();
+    tsl::Set_TF_Status_from_Status(status, s);
+    delete host_ctx;
+  }
+  absl::Status DoHostCallbackWithStatus(
+      absl::AnyInvocable<absl::Status() &&> callback) override {
+    HostCallbackContext* ctx = new HostCallbackContext{std::move(callback)};
+    if (stream_executor_->host_callback(device_, stream_handle_,
+                                        &HostCallbackTrampoline, ctx)) {
+      return absl::OkStatus();
+    }
+    return absl::InternalError("Failed to host callback.");
+  }
+  SP_Stream Handle() { return stream_handle_; }
+
+ private:
+  SP_Device* device_;
+  SP_StreamExecutor* stream_executor_;
+  SP_Stream stream_handle_;
+};
+
+}  // namespace stream_executor
+#endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h b/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h
new file mode 100644
index 00000000..0bebf6f4
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/experimental/stream_executor/stream_executor_test_util.h
@@ -0,0 +1,56 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_ + +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +struct SP_Stream_st { + explicit SP_Stream_st(int id) : stream_id(id) {} + int stream_id; +}; + +struct SP_Event_st { + explicit SP_Event_st(int id) : event_id(id) {} + int event_id; +}; + +struct SP_Timer_st { + explicit SP_Timer_st(int id) : timer_id(id) {} + int timer_id; +}; + +namespace stream_executor { +namespace test_util { + +constexpr int kDeviceCount = 2; +constexpr char kDeviceName[] = "MY_DEVICE"; +constexpr char kDeviceType[] = "GPU"; + +void PopulateDefaultStreamExecutor(SP_StreamExecutor* se); +void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns); +void PopulateDefaultTimerFns(SP_TimerFns* timer_fns); +void PopulateDefaultPlatform(SP_Platform* platform, + SP_PlatformFns* platform_fns); +void PopulateDefaultPlatformRegistrationParams( + SE_PlatformRegistrationParams* const params); + +void DestroyPlatform(SP_Platform* platform); +void DestroyPlatformFns(SP_PlatformFns* platform_fns); + +} // namespace test_util +} // namespace stream_executor + +#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/kernels.h b/third_party/tflite-hdrs/tensorflow/c/kernels.h new file mode 100644 index 00000000..fd7f99cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/kernels.h @@ -0,0 +1,543 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_KERNELS_H_ +#define TENSORFLOW_C_KERNELS_H_ + +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include "tensorflow/c/tf_buffer.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" + +// Required for IS_MOBILE_PLATFORM definition +#include "tsl/platform/platform.h" // IWYU pragma: keep + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_Tensor TF_Tensor; + +// -------------------------------------------------------------------------- +// C API for TensorFlow Kernels. +// +// This API allows developers to register custom kernel implementations for +// TensorFlow. +// +// See c_api.h header comments for a discussion about API conventions. +// +// Users wishing to extend TensorFlow with new kernels will call +// `TF_NewKernelBuilder`. 
The resulting kernel builder can be registered with +// `TF_RegisterKernelBuilder`, which will allow TF to construct user-provided +// kernels when necessary. + +typedef struct TF_KernelBuilder TF_KernelBuilder; +typedef struct TF_OpKernelConstruction TF_OpKernelConstruction; +typedef struct TF_OpKernelContext TF_OpKernelContext; +typedef struct TF_AsyncOpKernelDoneCallback TF_AsyncOpKernelDoneCallback; + +// Run callback function for async kernel. +TF_CAPI_EXPORT extern void TF_RunAsyncOpKernelDoneCallback( + TF_AsyncOpKernelDoneCallback*); + +// TF_InitKernel to do op/kernel registration. +// Plugin should implement TF_InitKernel to register kernels. This function +// should register all kernels in a plugin. +void TF_InitKernel(); + +// Allocates a new kernel builder and returns a pointer to it. +// +// If non-null, TensorFlow will call create_func when it needs to instantiate +// the kernel. The pointer returned by create_func will be passed to +// compute_func and delete_func, thereby functioning as a "this" pointer for +// referring to kernel instances. +// +// The TF_OpKernelConstruction pointer passed to create_func is owned by +// TensorFlow and will be deleted once create_func returns. It must not be used +// after this. +// +// When TensorFlow needs to perform a computation with this kernel, it will +// call compute_func. This function will receive the pointer returned by +// create_func (or null if no create_func was provided), along with the inputs +// to the computation. +// +// The TF_OpKernelContext pointer received by compute_func is owned by +// TensorFlow and will be deleted once compute_func returns. It must not be used +// after this. +// +// Finally, when TensorFlow no longer needs the kernel, it will call +// delete_func if one is provided. This function will receive the pointer +// returned in `create_func` or nullptr if no `create_func` was provided. +// +// The caller should pass the result of this function to +// TF_RegisterKernelBuilder, which will take ownership of the pointer. If, for +// some reason, the kernel builder will not be registered, the caller should +// delete it with TF_DeleteKernelBuilder. +TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_func)(void*, TF_OpKernelContext*), + void (*delete_func)(void*)); + +// Allocates a new kernel builder and returns a pointer to it. +// +// It is similar as TF_NewKernelBuilder, except compute_async_func. +// It creates an AsyncOpKernel, and performs async computation through +// compute_async_func. +TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewAsyncKernelBuilder( + const char* op_name, const char* device_name, + void* (*create_func)(TF_OpKernelConstruction*), + void (*compute_async_func)(void*, TF_OpKernelContext*, + TF_AsyncOpKernelDoneCallback* done), + void (*delete_func)(void*)); + +// Specifies that this kernel's attribute only supports the given type. +TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint( + TF_KernelBuilder* kernel_builder, const char* attr_name, + const TF_DataType type, TF_Status* status); + +// Specify that this kernel requires/provides an input/output arg +// in host memory (instead of the default, device memory). +TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory( + TF_KernelBuilder* kernel_builder, const char* arg_name); + +// Specify a priority number for this kernel. 
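+// For example, a plugin might build and register a kernel as sketched below;
+// the op, device, attribute and helper names are hypothetical and error
+// handling is omitted:
+//
+//   TF_Status* status = TF_NewStatus();
+//   TF_KernelBuilder* builder = TF_NewKernelBuilder(
+//       "MyOp", "MY_DEVICE", &MyCreateKernel, &MyComputeKernel,
+//       &MyDeleteKernel);
+//   TF_KernelBuilder_TypeConstraint(builder, "T", TF_FLOAT, status);
+//   TF_KernelBuilder_Priority(builder, 1);
+//   TF_RegisterKernelBuilder("MyOp", builder, status);
+//   TF_DeleteStatus(status);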
+TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority( + TF_KernelBuilder* kernel_builder, int32_t priority_number); + +// Specify a label for this kernel. +TF_CAPI_EXPORT extern void TF_KernelBuilder_Label( + TF_KernelBuilder* kernel_builder, const char* label); + +// Register the given kernel builder with the TensorFlow runtime. If +// registration fails, the given status will be populated. +// +// This call takes ownership of the `builder` pointer. +TF_CAPI_EXPORT extern void TF_RegisterKernelBuilder(const char* kernel_name, + TF_KernelBuilder* builder, + TF_Status* status); + +// Register the given kernel builder with the TensorFlow runtime. If +// registration fails, the given status will be populated. +// +// This method is the same as TF_RegisterKernelBuilder except it takes in a +// serialized KernelDef, and uses it for registration, instead of building a new +// one. Users can choose to not provide a serialized KernelDef and in that case +// it's identical to TF_RegisterKernelBuilder. +TF_CAPI_EXPORT extern void TF_RegisterKernelBuilderWithKernelDef( + const char* serialized_kernel_def, const char* name, + TF_KernelBuilder* builder, TF_Status* status); + +// Deletes the given TF_KernelBuilder. This should be called only if the kernel +// builder is not registered with TensorFlow via TF_RegisterKernelBuilder. +TF_CAPI_EXPORT extern void TF_DeleteKernelBuilder(TF_KernelBuilder* builder); + +// -------------------------------------------------------------------------- +// OpKernelContext routines + +// TF_GetStream returns the SP_Stream available in ctx. +// This function returns a stream only for devices registered using the +// StreamExecutor C API +// (tensorflow/c/experimental/stream_executor/stream_executor.h). It will return +// nullptr and set error status in all other cases. +// Experimental: this function doesn't have compatibility guarantees and subject +// to change at any time. +TF_CAPI_EXPORT extern SP_Stream TF_GetStream(TF_OpKernelContext* ctx, + TF_Status* status); + +// TF_NumInputs returns the number of inputs available in ctx. +TF_CAPI_EXPORT extern int TF_NumInputs(TF_OpKernelContext* ctx); + +// TF_NumOutputs returns the number of outputs to be placed in *ctx by the +// kernel. +TF_CAPI_EXPORT extern int TF_NumOutputs(TF_OpKernelContext* ctx); + +// Retrieves the ith input from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumInputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_GetInput(TF_OpKernelContext* ctx, int i, + TF_Tensor** tensor, TF_Status* status); + +typedef struct { + size_t struct_size; + void* priv; // Not used, for possible extension. + int start; // output + int stop; // output + TF_Status* status; // output +} TF_InputRange_Args; +const size_t TF_InputRange_Args_STRUCT_SIZE = + TF_OFFSET_OF_END(TF_InputRange_Args, status); + +// Retrieves the start and stop indices, given the input name. Equivalent to +// OpKernel::InputRange(). `args` will contain the result indices and status. +TF_CAPI_EXPORT extern void TF_InputRange(TF_OpKernelContext* ctx, + const char* name, + TF_InputRange_Args* args); + +// Returns the data type of the index-th input. If index < 0 or index >= +// TF_NumInputs(ctx), the program aborts. +TF_CAPI_EXPORT extern TF_DataType TF_InputDatatype(TF_OpKernelContext* ctx, + int index); + +// Sets the ith output of ctx to tensor. 
If TF_GetCode(status) is anything but +// TF_OK, ctx is left unmodified. +// +// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern void TF_SetOutput(TF_OpKernelContext* ctx, int i, + const TF_Tensor* tensor, + TF_Status* status); + +// Retrieves the ith output from ctx. If TF_GetCode(status) is TF_OK, *tensor is +// populated and its ownership is passed to the caller. In any other case, +// *tensor is not modified. +// +// If i < 0 or i >= TF_NumOutputs(ctx), *status is set to TF_OUT_OF_RANGE. +TF_CAPI_EXPORT extern TF_Tensor* TF_GetMutableOutput(TF_OpKernelContext* ctx, + int i, TF_Status* status); + +// Retrieves a serialized FunctionDefLibrary. Status will be set. +TF_CAPI_EXPORT extern void TF_GetSerializedFunctionDefLibrary( + TF_OpKernelContext* ctx, TF_Buffer* serialized_function_def_library, + TF_Status* status); + +// Retrieves a serialized ConfigProto. Status will be set. +TF_CAPI_EXPORT extern void TF_GetSerializedConfigProto( + TF_OpKernelContext* ctx, TF_Buffer* serialized_config_proto, + TF_Status* status); + +// Retrieves a serialized ResourceHandleProto. Status will be set. +TF_CAPI_EXPORT extern void TF_GetSerializedResourceHandleProto( + TF_OpKernelContext* ctx, int i, TF_Buffer* serialized_resource_handle_proto, + TF_Status* status); + +// Notifies the given OpKernelConstruction that kernel construction has failed. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_Failure( + TF_OpKernelConstruction* ctx, TF_Status* status); + +// Notifies the given OpKernelContext that the kernel's compute function has +// failed. +TF_CAPI_EXPORT extern void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, + TF_Status* status); + +// Returns the expected output data type of the ith output. If i < 0 or +// i >= TF_NumOutputs(ctx), the program aborts. +TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType( + TF_OpKernelContext* ctx, int i); + +// Returns true if the ith input is allocated in host memory. If i < 0 or i >= +// TF_NumInputs(ctx), the program aborts. +TF_CAPI_EXPORT extern bool TF_IsHostMemoryInput(TF_OpKernelContext* ctx, int i, + TF_Status* status); + +// Returns true if the ith output is allocated in host memory. If i < 0 or i >= +// TF_NumOutputs(ctx), the program aborts. +TF_CAPI_EXPORT extern bool TF_IsHostMemoryOutput(TF_OpKernelContext* ctx, int i, + TF_Status* status); + +// Returns the step ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx); + +// Returns the serialized NodeDef protocol buffer for the kernel +TF_CAPI_EXPORT extern TF_Buffer* TF_OpKernelConstruction_GetNodeDef( + TF_OpKernelConstruction* ctx, TF_Status* status); + +// Returns the frame ID of the given context. +TF_CAPI_EXPORT extern uint64_t TF_GetFrameId(TF_OpKernelContext* ctx); + +// Returns the Iter ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_GetIterId(TF_OpKernelContext* ctx); + +// Returns the Step ID of the given context. +TF_CAPI_EXPORT extern int64_t TF_GetStepId(TF_OpKernelContext* ctx); + +// Returns the Device ID of the device that the context possesses. Returns the +// PlatformDeviceId if a mapping between between TfDeviceId and PlatformDeviceId +// is set; otherwise returns the id in the device name. Please refer to +// tensorflow/compiler/xla/tsl/framework/device_id.h for more details. +// For mobile or slim build, returns the id in the device name. 
+TF_CAPI_EXPORT extern int TF_GetDeviceId(TF_OpKernelContext* ctx); + +// Returns the Device Name of the device that the context possesses. +// +// The returned TF_StringView's underlying string is owned by the OpKernel and +// has the same lifetime as the OpKernel. +TF_CAPI_EXPORT TF_StringView TF_GetDeviceName(TF_OpKernelContext* ctx); + +#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD) +// Returns the rendezvous in the context. Not supported on mobile. +TF_CAPI_EXPORT TF_RendezvousThunk TF_GetRendezvous(TF_OpKernelContext* ctx); +#endif + +// Returns the graph def version of the given context. +TF_CAPI_EXPORT extern int TF_GetGraphDefVersion(TF_OpKernelContext* ctx); + +// Returns the name of the OpKernel. +// +// The returned TF_StringView's underlying string is owned by the OpKernel and +// has the same lifetime as the OpKernel. +TF_CAPI_EXPORT extern TF_StringView TF_GetOpKernelName(TF_OpKernelContext* ctx); + +// Returns the default container of the resource manager in OpKernelContext. +// +// The returned TF_StringView's underlying string is owned by the OpKernel and +// has the same lifetime as the OpKernel. +TF_CAPI_EXPORT extern TF_StringView TF_GetResourceMgrDefaultContainerName( + TF_OpKernelContext* ctx); + +// Returns the name of the requested input at `index` from the OpKernel. +// +// The returned TF_StringView's underlying string is owned by the OpKernel and +// has the same lifetime as the OpKernel. +TF_CAPI_EXPORT extern TF_StringView TF_GetOpKernelRequestedInput( + TF_OpKernelContext* ctx, size_t index); + +// Get the list_size and total_size of the attribute `attr_name` of `oper`. +// list_size - the length of the list. +// total_size - total size of the list. +// (1) If attr_type == TF_ATTR_STRING +// then total_size is the cumulative byte size +// of all the strings in the list. +// (3) If attr_type == TF_ATTR_SHAPE +// then total_size is the number of dimensions +// of the shape valued attribute, or -1 +// if its rank is unknown. +// (4) If attr_type == TF_ATTR_SHAPE +// then total_size is the cumulative number +// of dimensions of all shapes in the list. +// (5) Otherwise, total_size is undefined. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize( + TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size, + int32_t* total_size, TF_Status* status); + +// Interprets the named kernel construction attribute as a TF_DataType and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrType( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as int32_t and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// int32, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32( + TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as int64_t and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// int64, *status is populated with an error. 
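+// For example, inside a kernel's create_func (the attribute name "N" is
+// hypothetical):
+//
+//   int64_t n = 0;
+//   TF_OpKernelConstruction_GetAttrInt64(ctx, "N", &n, status);
+//   if (TF_GetCode(status) != TF_OK) {
+//     TF_OpKernelConstruction_Failure(ctx, status);
+//   }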
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64( + TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as float and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// float, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat( + TF_OpKernelConstruction* ctx, const char* attr_name, float* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as bool and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// bool, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val, + TF_Status* status); + +// Interprets the named kernel construction attribute as string and +// places it into *val. `val` must +// point to an array of length at least `max_length` (ideally set to +// total_size from TF_OpKernelConstruction_GetAttrSize(ctx, +// attr_name, list_size, total_size)). *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// string, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString( + TF_OpKernelConstruction* ctx, const char* attr_name, char* val, + size_t max_length, TF_Status* status); + +// Interprets the named kernel construction attribute as tensor and places it +// into *val. Allocates a new TF_Tensor which the caller is expected to take +// ownership of (and can deallocate using TF_DeleteTensor). *status is set to +// TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// tensor, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensor( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Tensor** val, + TF_Status* status); + +// Interprets the named kernel construction attribute as a TF_DataType array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as int32_t array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List( + TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as int64_t array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). 
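+// For example, querying the list length first and then fetching the values
+// (the attribute name "values" is hypothetical; error handling and free()
+// are omitted):
+//
+//   int32_t list_size = 0;
+//   int32_t total_size = 0;
+//   TF_OpKernelConstruction_GetAttrSize(ctx, "values", &list_size,
+//                                       &total_size, status);
+//   int64_t* vals = (int64_t*)malloc(list_size * sizeof(int64_t));
+//   TF_OpKernelConstruction_GetAttrInt64List(ctx, "values", vals, list_size,
+//                                            status);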
+TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List( + TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as float array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList( + TF_OpKernelConstruction* ctx, const char* attr_name, float* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as bool array and +// places it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` (ideally set +// to list_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals, + int max_vals, TF_Status* status); + +// Interprets the named kernel construction attribute as string array and fills +// in `vals` and `lengths`, each of which must point to an array of length at +// least `max_values`. *status is set to TF_OK. The elements of values will +// point to addresses in `storage` which must be at least `storage_size` bytes +// in length. Ideally, max_values would be set to list_size and `storage` would +// be at least total_size, obtained from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size, +// total_size). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList( + TF_OpKernelConstruction* ctx, const char* attr_name, char** values, + size_t* lengths, int max_values, void* storage, size_t storage_size, + TF_Status* status); + +// Interprets the named kernel construction attribute as tensor array and places +// it into *vals. *status is set to TF_OK. +// `vals` must point to an array of length at least `max_values` +// (ideally set to list_size from TF_OpKernelConstruction_GetAttrSize(ctx, +// attr_name, list_size, total_size)). +// +// The caller takes ownership of all the non-null TF_Tensor* entries in `vals` +// (which can be deleted using TF_DeleteTensor(vals[i])). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorList( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Tensor** vals, + int max_values, TF_Status* status); + +// Interprets the named kernel construction attribute as a +// tensorflow::NameAttrList and returns the serialized proto as TF_Buffer. +// `status` will be set. The caller takes ownership of the returned TF_Buffer +// (if not null) and is responsible for managing its lifetime. +TF_CAPI_EXPORT extern TF_Buffer* TF_OpKernelConstruction_GetAttrFunction( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status); + +// Return true if the kernel construction has the attr_name +TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr( + TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status); + +// Returns the unique operation name for this OpKernel. +TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName( + TF_OpKernelConstruction* ctx); + +// Allocates Tensor for output at given index. Caller takes ownership of +// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor). 
+// +// This function should be used to allocate outputs inside kernel +// compute function. +TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, + int index, TF_DataType dtype, + const int64_t* dims, int num_dims, + size_t len, TF_Status* status); + +// Tries to forward one of the inputs given in input_indices to +// output[output_index]. If none of the given inputs can be forwarded, calls +// allocate_output() to allocate a new output buffer. The index of the +// forwarded input will be assign to output argument forwarded_input (if it's +// not nullptr). If no inputs are forwarded, forwarded_input will be assigned +// -1. +TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput( + TF_OpKernelContext* context, const int* candidate_input_indices, + int num_candidate_input_indices, int output_index, + const int64_t* output_dims, int output_num_dims, int* forwarded_input, + TF_Status* status); + +// Allocates a temporary Tensor of the specified type and shape. The +// Tensor must not be used after kernel construction is +// complete. +// +// num_dims must equal the size of array dims +TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp( + TF_OpKernelContext* context, TF_DataType dtype, const int64_t* dims, + int num_dims, TF_AllocatorAttributes* alloc_attrs, TF_Status* status); + +// Used by OpKernel implementations to track actively running deferred ops. +// +// A deferred op is one whose Compute method returns (or whose ComputeAsync +// method invokes the callback) when work is scheduled onto a device. At that +// point, we don't know when the work will actually complete (or if it has +// already completed) on the device. These functions allow the executor to +// track the status of deferred ops and act accordingly. +// +// Deferred OpKernel implementations must use these methods to get two +// functions. It then must call these two functions in pairs, before and after +// device execution, respectively. +TF_CAPI_EXPORT extern void TF_IncNumDeferredOps(TF_OpKernelContext* context); +TF_CAPI_EXPORT extern void TF_DecNumDeferredOps(TF_OpKernelContext* context); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/kernels/tensor_shape_utils.h b/third_party/tflite-hdrs/tensorflow/c/kernels/tensor_shape_utils.h new file mode 100644 index 00000000..27167b39 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/kernels/tensor_shape_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains shape utilities to be used by kernels and is not part of +// the C API. As such, it is subject to change at any time. 
+ +#ifndef TENSORFLOW_C_KERNELS_TENSOR_SHAPE_UTILS_H_ +#define TENSORFLOW_C_KERNELS_TENSOR_SHAPE_UTILS_H_ + +#include + +#include "tensorflow/c/tf_tensor.h" + +namespace tensorflow { + +// The following are utils for the shape of a TF_Tensor type. +// These functions may later be subsumed by the methods for a +// TF_TensorShape type. + +// Returns a string representation of the TF_Tensor shape. +std::string ShapeDebugString(TF_Tensor* tensor); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_KERNELS_TENSOR_SHAPE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/kernels_experimental.h b/third_party/tflite-hdrs/tensorflow/c/kernels_experimental.h new file mode 100644 index 00000000..2f93e6b2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/kernels_experimental.h @@ -0,0 +1,192 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ +#define TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/kernels.h" + +// -------------------------------------------------------------------------- +// Experimental kernel C API for TensorFlow. +// +// The API here is subject to changes in the future. +// -------------------------------------------------------------------------- + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_VariableInputLockHolder TF_VariableInputLockHolder; + +// Expose higher level Assignment operation for Pluggable vendors to implement +// in the plugin for Training. The API takes in the context with indices for +// the input and value tensors. It also accepts the copy callback provided by +// pluggable vendor to do the copying of the tensors. The caller takes ownership +// of the `source` and `dest` tensors and is responsible for freeing them with +// TF_DeleteTensor. This function will return an error when the following +// conditions are met: +// 1. `validate_shape` is set to `true` +// 2. The variable is initialized +// 3. The shape of the value tensor doesn't match the shape of the variable +// tensor. +TF_CAPI_EXPORT extern void TF_AssignVariable( + TF_OpKernelContext* ctx, int input_index, int value_index, + bool validate_shape, + void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, + TF_Tensor* dest), + TF_Status* status); + +// Expose higher level Assignment operation for Pluggable vendors to implement +// in the plugin for Training on ref variables. The API takes in the context +// with indices for the input and value tensors. It also accepts the copy +// callback provided by pluggable vendor to do the copying of the tensors. The +// caller takes ownership of the `source` and `dest` tensors and is responsible +// for freeing them with TF_DeleteTensor. 
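+// The copy callback passed below has the plain signature
+// void(TF_OpKernelContext*, TF_Tensor*, TF_Tensor*); a hypothetical plugin
+// implementation (the name is a placeholder) might simply enqueue a
+// device-side copy:
+//
+//   void MyCopyFunc(TF_OpKernelContext* ctx, TF_Tensor* source,
+//                   TF_Tensor* dest) {
+//     /* enqueue a copy of `source` into `dest` on the plugin's stream for
+//        `ctx` */
+//   }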
+TF_CAPI_EXPORT extern void TF_AssignRefVariable( + TF_OpKernelContext* ctx, int input_ref_index, int output_ref_index, + int value_index, bool use_locking, bool validate_shape, + void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, + TF_Tensor* dest), + TF_Status* status); + +// Expose higher level AssignUpdate operation for Pluggable vendors to implement +// in the plugin for Training. The API takes in the context with indices for the +// input and value tensors. It also accepts the copy callback provided by +// pluggable vendor to do the copying of the tensors and the update callback to +// apply the arithmetic operation. The caller takes ownership of the `source`, +// `dest`, `tensor` and `value` tensors and is responsible for freeing them with +// TF_DeleteTensor. +TF_CAPI_EXPORT extern void TF_AssignUpdateVariable( + TF_OpKernelContext* ctx, int input_index, int value_index, int Op, + int isVariantType, + void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, + TF_Tensor* dest), + void (*updateFunc)(TF_OpKernelContext* ctx, TF_Tensor* tensor, + TF_Tensor* value, int Op), + TF_Status* status); + +// Expose higher level temporary variable operator for Pluggable vendors to +// implement in the plugin for managing temporary variables. The API takes in +// the context with indices for the input and value tensors. It also accepts the +// allocator provided by pluggable vendor to do the allocate_temp of the +// tensors. The caller takes ownership of temporary variables and is responsible +// for freeing them with TF_DestroyTemporaryVariable. This function will return +// an error when the following conditions are met: +// 1. Cannot allocate a new temporary variable +// 2. Calling plugin allocator failed +TF_CAPI_EXPORT extern void TF_TemporaryVariable( + TF_OpKernelContext* ctx, TF_DataType dtype, const int64_t* dims, + int num_dims, TF_StringView* var_name, + void (*plugin_allocator)(TF_OpKernelContext*, TF_Tensor*, TF_DataType, + const int64_t*, int, TF_Status*), + TF_Status* tf_status); + +// Expose higher level temporary variable operator for Pluggable vendors to +// implement in the plugin for destroying temporary variables. The API takes in +// the context with indices for the input and variable name. This function will +// return an error when either of the following conditions is met: +// 1. `input data type` is not ref type +// 2. Cannot find temporary variable by name in arguments +TF_CAPI_EXPORT extern void TF_DestroyTemporaryVariable(TF_OpKernelContext* ctx, + const int index, + TF_StringView* var_name, + TF_Status* tf_status); + +// This is a helper function which acquires mutexes in-order to provide +// thread-safe way of performing weights update during the optimizer op. It +// returns an opaque LockHolder handle back to plugin. This handle is passed to +// the Release API for releasing the locks when the weight update is done. The +// caller takes ownership of the `source` and `dest` tensors and is responsible +// for freeing them with TF_DeleteTensor. +TF_CAPI_EXPORT extern void TF_MaybeLockVariableInputMutexesInOrder( + TF_OpKernelContext* ctx, bool do_lock, bool sparse, const int* const inputs, + size_t len, + void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, + TF_Tensor* dest), + TF_VariableInputLockHolder** lockHolder, TF_Status* status); + +// This interface returns `out` tensor which is updated corresponding to the +// variable passed with input index. 
The caller takes ownership of the `source` +// and `dest` tensors and is responsible for freeing them with TF_DeleteTensor. +TF_CAPI_EXPORT extern void TF_GetInputTensorFromVariable( + TF_OpKernelContext* ctx, int input, bool lock_held, bool isVariantType, + bool sparse, + void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, + TF_Tensor* dest), + TF_Tensor** out, TF_Status* status); + +// This interface forwards the reference from input to the output tensors +// corresponding to the indices provided with `input_index` and `output_index` +TF_CAPI_EXPORT extern void TF_OpKernelContext_ForwardRefInputToRefOutput( + TF_OpKernelContext* ctx, int32_t input_index, int32_t output_index); + +// The API releases the opaque lock handle returned with +// `TF_MaybeLockVariableInputMutexesInOrder` API +TF_CAPI_EXPORT extern void TF_ReleaseVariableInputLockHolder( + TF_VariableInputLockHolder* lockHolder); + +// Allows plugin to get TF_Tensor when passed its input_name +TF_CAPI_EXPORT extern void TF_GetInputByName(TF_OpKernelContext* ctx, + const char* inputName, + TF_Tensor** tensor, + TF_Status* status); + +// Interprets the named kernel construction attribute as a shape attribute and +// fills in `vals` with the size of each dimension. `vals` must point to an +// array of length at least `max_values` (ideally set to total_size from +// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, &list_size, +// &total_size)). +TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTensorShape( + TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* dims, + size_t num_dims, TF_Status* status); + +TF_CAPI_EXPORT extern bool TF_IsRefInput(TF_OpKernelContext* ctx, int i, + TF_Status* status); + +#ifndef IS_MOBILE_PLATFORM +// Expose higher level AddN operation for Pluggable vendors to implement +// in the plugin for Variant data types. The API takes in the context and a +// callback provided by pluggable vendor to do a Binary Add operation on the +// tensors unwrapped from the Variant tensors. The caller takes ownership of the +// `a`, `b` and `out` tensors and is responsible for freeing them with +// TF_DeleteTensor. +TF_CAPI_EXPORT extern void TF_AddNVariant( + TF_OpKernelContext* ctx, + void (*binary_add_func)(TF_OpKernelContext* ctx, TF_Tensor* a, TF_Tensor* b, + TF_Tensor* out), + TF_Status* status); + +// Expose higher level ZerosLike operation for Pluggable vendors to implement +// in the plugin for Variant data types. The API takes in the context and a +// callback provided by pluggable vendor to do a ZerosLike operation on the +// tensors unwrapped from the Variant tensors. The caller takes ownership of the +// `input` and `out` tensors and is responsible for freeing them with +// TF_DeleteTensor. +TF_CAPI_EXPORT extern void TF_ZerosLikeVariant( + TF_OpKernelContext* ctx, + void (*zeros_like_func)(TF_OpKernelContext* ctx, TF_Tensor* input, + TF_Tensor* out), + TF_Status* status); + +typedef struct TF_CoordinationServiceAgent TF_CoordinationServiceAgent; + +#endif // IS_MOBILE_PLATFORM + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_KERNELS_EXPERIMENTAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/logging.h b/third_party/tflite-hdrs/tensorflow/c/logging.h new file mode 100644 index 00000000..9583777b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/logging.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_LOGGING_H_ +#define TENSORFLOW_C_LOGGING_H_ + +#include "tensorflow/c/c_api_macros.h" + +// -------------------------------------------------------------------------- +// C API for tensorflow::Logging. + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum TF_LogLevel { + TF_INFO = 0, + TF_WARNING = 1, + TF_ERROR = 2, + TF_FATAL = 3, +} TF_LogLevel; + +TF_CAPI_EXPORT extern void TF_Log(TF_LogLevel level, const char* fmt, ...); +TF_CAPI_EXPORT extern void TF_VLog(int level, const char* fmt, ...); +TF_CAPI_EXPORT extern void TF_DVLog(int level, const char* fmt, ...); + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_C_LOGGING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/ops.h b/third_party/tflite-hdrs/tensorflow/c/ops.h new file mode 100644 index 00000000..5d3a1e89 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/ops.h @@ -0,0 +1,364 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Routines for registering new ops and for implementing op shape inference +// functions. +// +// This API is alpha software and is subject to change. +// +// REGISTRATION +// ------------ +// +// In order to register a new op, create a new TF_OpDefinitionBuilder: +// +// TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("OpName"); +// +// Inputs, outputs and attributes can be added to the builder with the +// corresponding functions, e.g. +// +// TF_OpDefinitionBuilderAddInput(builder, "input1: int32"); +// TF_OpDefinitionBuilderAddOutput(builder, "output1: int64"); +// TF_OpDefinitionBuilderAddAttr(builder, "attr: int32"); +// +// The builder may then be registered with TensorFlow using the +// TF_RegisterOpDefinition function. E.g. +// +// TF_Status* status = TF_NewStatus(); +// TF_RegisterOpDefinition(builder, &status); +// if (TF_GetCode(status) != TF_OK) { +// // handle error +// } +// +// SHAPE INFERENCE +// --------------- +// +// You can provide a shape inference function that TensorFlow will call when it +// wants to understand the shape of outputs that the op will produce. Use the +// TF_OpDefinitionBuilderSetShapeInferenceFunction function to register a shape +// inference function pointer with TensorFlow. 
The following is an example of a +// very simple shape inference function: +// +// void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) { +// TF_ShapeHandle* input = TF_NewShapeHandle(); +// TF_ShapeInferenceContextGetInput(ctx, 0, input, status); +// if (TF_GetCode(status) == TF_OK) { +// TF_ShapeInferenceContextSetOutput(ctx, 0, input, status); +// } +// TF_DeleteShapeHandle(input); +// } +// +// The following code registers the inference function with TensorFlow: +// +// TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn); +// +// For more details about shape inference, see the documentation for +// TF_OpDefinitionBuilderSetShapeInferenceFunction. + +#ifndef TENSORFLOW_C_OPS_H_ +#define TENSORFLOW_C_OPS_H_ + +#include +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct TF_DimensionHandle; +struct TF_OpDefinitionBuilder; +struct TF_ShapeHandle; +struct TF_ShapeInferenceContext; + +// Returns a newly allocated op definition builder for the given op name. The +// returned builder may be customized with the `TF_OpDefinitionBuilder...` +// functions and then registered with TensorFlow with TF_RegisterOpDefinition. +// +// The returned pointer is either freed by a call to TF_RegisterOpDefinition, or +// can be manually deleted by TF_DeleteOpDefinitionBuilder if it is never +// registered. +TF_CAPI_EXPORT extern TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder( + const char* op_name); + +// Registers the given op builder with TensorFlow. Indicates success or +// otherwise in the given status. +// +// `builder` is freed whether the op was successfully registered or not. You +// must call either this function or TF_DeleteOpDefinitionBuilder to free the +// builder, but never both. +TF_CAPI_EXPORT extern void TF_RegisterOpDefinition( + TF_OpDefinitionBuilder* builder, TF_Status* status); + +// Frees the given op definition builder. You must call either this function or +// TF_RegisterOpDefinition to free the builder, but never both. +TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder( + TF_OpDefinitionBuilder* builder); + +//---------------------------------------------------- +// Attribute functions. + +// Adds an attr to the given TF_OpDefinitionBuilder. The spec has +// format ":" or ":=" +// where matches regexp [a-zA-Z][a-zA-Z0-9_]*. +// By convention, names containing only capital letters are reserved for +// attributes whose values can be inferred by the operator implementation if not +// supplied by the user. If the attribute name contains characters other than +// capital letters, the operator expects the user to provide the attribute value +// at operation runtime. +// +// can be: +// "string", "int", "float", "bool", "type", "shape", or "tensor" +// "numbertype", "realnumbertype", "quantizedtype" +// (meaning "type" with a restriction on valid values) +// "{int32,int64}" or {realnumbertype,quantizedtype,string}" +// (meaning "type" with a restriction containing unions of value types) +// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}" +// (meaning "string" with a restriction on valid values) +// "list(string)", ..., "list(tensor)", "list(numbertype)", ... 
+// (meaning lists of the above types) +// "int >= 2" (meaning "int" with a restriction on valid values) +// "list(string) >= 2", "list(int) >= 2" +// (meaning "list(string)" / "list(int)" with length at least 2) +// , if included, should use the Proto text format +// of . For lists use [a, b, c] format. +// +// Note that any attr specifying the length of an input or output will +// get a default minimum of 1 unless the >= # syntax is used. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddAttr( + TF_OpDefinitionBuilder* builder, const char* attr_spec); + +// Adds an input to this TF_OpDefinitionBuilder. +// The spec has form ":" or ":Ref()" +// where matches regexp [a-z][a-z0-9_]* and can be: +// * For a single tensor: +// * For a sequence of tensors with the same type: * +// * For a sequence of tensors with different types: +// Where: +// is either one of "float", "int32", "string", ... +// or the name of an attr (see TF_OpDefinitionBuilderAddAttr) +// with type "type". +// is the name of an attr with type "int". +// is the name of an attr with type "list(type)". +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput( + TF_OpDefinitionBuilder* builder, const char* input_spec); + +// Adds an output to this TF_OpDefinitionBuilder. +// The spec has form ":" or ":Ref()" +// where matches regexp [a-z][a-z0-9_]* and can be: +// * For a single tensor: +// * For a sequence of tensors with the same type: * +// * For a sequence of tensors with different types: +// Where: +// is either one of "float", "int32", "string", ... +// or the name of an attr (see TF_OpDefinitionBuilderAddAttr) +// with type "type". +// is the name of an attr with type "int". +// is the name of an attr with type "list(type)". +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput( + TF_OpDefinitionBuilder* builder, const char* output_spec); + +// Sets the commutative property for the op built by the given builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative( + TF_OpDefinitionBuilder* builder, bool is_commutative); + +// Sets the is_aggregate property of the builder to the given value. +// +// If is_aggregate is true, then the operation produced by this builder accepts +// N >= 2 inputs and produces 1 output all of the same type. Should be +// associative and commutative, and produce output with the same shape as the +// input. The optimizer may replace an aggregate op taking input from multiple +// devices with a tree of aggregate ops that aggregate locally within each +// device (and possibly within groups of nearby devices) before communicating. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsAggregate( + TF_OpDefinitionBuilder* builder, bool is_aggregate); + +// Sets the is_stateful property of the builder to the given value. +// +// The op built by this builder is stateful if its behavior depends on some +// state beyond its input tensors (e.g. variable reading op) or if it has a +// side-effect (e.g. printing or asserting ops). Equivalently, stateless ops +// must always produce the same output for the same input and have no +// side-effects. +// +// By default Ops may be moved between devices. Stateful ops should either not +// be moved, or should only be moved if that state can also be moved (e.g. via +// some sort of save / restore). Stateful ops are guaranteed to never be +// optimized away by Common Subexpression Elimination (CSE). 
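+//
+// As an illustrative sketch (not from the original header), an op that draws
+// random values would typically be marked stateful so that CSE never merges
+// two of its instances; "MyRandomOp" is a hypothetical name:
+//
+//   TF_OpDefinitionBuilder* b = TF_NewOpDefinitionBuilder("MyRandomOp");
+//   TF_OpDefinitionBuilderAddOutput(b, "sample: float");
+//   TF_OpDefinitionBuilderSetIsStateful(b, true);
+//   TF_Status* s = TF_NewStatus();
+//   TF_RegisterOpDefinition(b, s);  // also frees `b`
+//   TF_DeleteStatus(s);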
+TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsStateful( + TF_OpDefinitionBuilder* builder, bool is_stateful); + +// Sets the allows_uninitialized_input property of the operation built by this +// builder. +// +// By default, all inputs to an Op must be initialized Tensors. Ops that may +// initialize tensors for the first time should set this field to true, to allow +// the Op to take an uninitialized Tensor as input. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetAllowsUninitializedInput( + TF_OpDefinitionBuilder* builder, bool allows_uninitialized_input); + +// Adds a deprecation warning for the given op. This indicates to the user that +// `version` is the first TensorFlow GraphDef version for which the operation is +// deprecated. `explanation` should contain the reason for the deprecation and +// what to use instead. +// +// This function is only an indicator that the operation may disappear in a +// version of TensorFlow after `version`. It does not affect op registration. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderDeprecated( + TF_OpDefinitionBuilder* builder, int version, const char* explanation); + +// Sets the shape inference function for the op. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetShapeInferenceFunction( + TF_OpDefinitionBuilder* builder, + void (*shape_inference_func)(TF_ShapeInferenceContext* ctx, + TF_Status* status)); + +//---------------------------------------------------- +// Functions for TF_ShapeInferenceContext. +// +// Functions for implementing shape inference functions. TensorFlow uses these +// functions to determine the shape of tensors produced by an operation without +// having to actually run the operation. If an operation chooses to provide a +// shape inference function, it will be invoked by TensorFlow as needed. +// +// When invoked by TensorFlow, the shape inference function is provided with a +// TF_ShapeInferenceContext pointer. The function's implementation will use the +// accessor and mutator functions with names beginning with +// TF_ShapeInferenceContext to examine the input state and determine the output +// shape. + +// Returns the number of inputs in the given shape inference context. +TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextNumInputs( + TF_ShapeInferenceContext* ctx); + +// Returns a newly allocated shape handle. The shapes represented by these +// handles may be queried or mutated with the corresponding +// TF_ShapeInferenceContext... functions. +TF_CAPI_EXPORT extern TF_ShapeHandle* TF_NewShapeHandle(); + +// Places the ith input of the given shape inference context into the given +// shape handle, or returns a status other than TF_OK indicating why the input +// could not be retrieved +// (for example, if i < 0 || i >= TF_ShapeInferenceContextNumInputs(ctx)). +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextGetInput( + TF_ShapeInferenceContext* ctx, int i, TF_ShapeHandle* handle, + TF_Status* status); + +// Places the given shape handle into the `i`th output position of the given +// context. Internally, the shape handle is copied; the caller may subsequently +// delete `handle`. +TF_CAPI_EXPORT +extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, + int i, TF_ShapeHandle* handle, + TF_Status* status); + +// Returns a newly-allocated scalar shape handle. The returned handle should +// be freed with TF_DeleteShapeHandle. 
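+//
+// For example, a shape inference function for an op that always produces a
+// scalar output could be written as follows (illustrative sketch only):
+//
+//   void scalar_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
+//     TF_ShapeHandle* s = TF_ShapeInferenceContextScalar(ctx);
+//     TF_ShapeInferenceContextSetOutput(ctx, 0, s, status);
+//     TF_DeleteShapeHandle(s);  // SetOutput copies the handle internally.
+//   }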
+TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar( + TF_ShapeInferenceContext* ctx); + +// Returns a newly-allocate shape handle representing a vector of the given +// size. The returned handle should be freed with TF_DeleteShapeHandle. +TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize( + TF_ShapeInferenceContext* ctx, size_t size); + +// Returns a newly allocated dimension handle. It must be freed with +// TF_DeleteDimensionHandle. +TF_CAPI_EXPORT extern TF_DimensionHandle* TF_NewDimensionHandle(); + +// Interprets the named shape inference context attribute as a TF_DataType and +// places it into *val. *status is set to TF_OK. +// +// If the attribute could not be found or could not be interpreted as +// TF_DataType, *status is populated with an error. +TF_CAPI_EXPORT extern void TF_ShapeInferenceContext_GetAttrType( + TF_ShapeInferenceContext* ctx, const char* attr_name, TF_DataType* val, + TF_Status* status); + +// Returns the rank of the shape represented by the given handle. +TF_CAPI_EXPORT extern int64_t TF_ShapeInferenceContextRank( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle); + +// Returns 1 if `handle` has a known rank, 0 otherwise. +TF_CAPI_EXPORT extern int TF_ShapeInferenceContextRankKnown( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle); + +// If has rank , or its rank is unknown, return OK and return the +// shape with asserted rank in <*result>. Otherwise an error is placed into +// `status`. +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRank( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, + TF_ShapeHandle* result, TF_Status* status); + +// If has rank at least , or its rank is unknown, return OK and +// return the shape with asserted rank in <*result>. Otherwise an error is +// placed into `status`. +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtLeast( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, + TF_ShapeHandle* result, TF_Status* status); + +// If has rank at most , or its rank is unknown, return OK and +// return the shape with asserted rank in <*result>. Otherwise an error is +// placed into `status`. +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextWithRankAtMost( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, + TF_ShapeHandle* result, TF_Status* status); + +// Places a handle to the ith dimension of the given shape into *result. +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i, + TF_DimensionHandle* result); + +// Returns in <*result> a sub-shape of , with dimensions +// [start:end]. and can be negative, to index from the end of the +// shape. and are set to the rank of if > rank of +// . +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSubshape( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t start, + int64_t end, TF_ShapeHandle* result, TF_Status* status); + +// Places an unknown shape in all outputs for the given inference context. Used +// for shape inference functions with ops whose output shapes are unknown. +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextSetUnknownShape( + TF_ShapeInferenceContext* ctx, TF_Status* status); + +// Returns whether the given handle represents a known dimension. +TF_CAPI_EXPORT extern int TF_DimensionHandleValueKnown( + TF_DimensionHandle* dim_handle); + +// Returns the value of the given dimension. 
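+//
+// For example, a shape inference function might read the size of the first
+// dimension of input 0 like this (illustrative sketch; error handling
+// omitted):
+//
+//   TF_ShapeHandle* in = TF_NewShapeHandle();
+//   TF_DimensionHandle* d = TF_NewDimensionHandle();
+//   TF_ShapeInferenceContextGetInput(ctx, 0, in, status);
+//   TF_ShapeInferenceContextDim(ctx, in, 0, d);
+//   int64_t n = TF_DimensionHandleValueKnown(d) ? TF_DimensionHandleValue(d) : -1;
+//   TF_DeleteDimensionHandle(d);
+//   TF_DeleteShapeHandle(in);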
+TF_CAPI_EXPORT extern int64_t TF_DimensionHandleValue( + TF_DimensionHandle* dim_handle); + +// Returns in <*result> the result of appending the dimensions of to +// those of . +TF_CAPI_EXPORT extern void TF_ShapeInferenceContextConcatenateShapes( + TF_ShapeInferenceContext* ctx, TF_ShapeHandle* first, + TF_ShapeHandle* second, TF_ShapeHandle* result, TF_Status* status); + +// Frees the given shape handle. +TF_CAPI_EXPORT extern void TF_DeleteShapeHandle(TF_ShapeHandle* handle); + +// Frees the given dimension handle. +TF_CAPI_EXPORT extern void TF_DeleteDimensionHandle(TF_DimensionHandle* handle); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/python_api.h b/third_party/tflite-hdrs/tensorflow/c/python_api.h new file mode 100644 index 00000000..043b7668 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/python_api.h @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_PYTHON_API_H_ +#define TENSORFLOW_C_PYTHON_API_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/full_type.pb.h" + +// These functions can be removed without notice. They exist to facilitate some +// refactoring of graph construction code in the Python API. + +namespace tensorflow { + +void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); + +// Changes an attr value in the node_def Protocol Buffer and sets a status upon +// completion. +void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status); + +// Clears the attr in the node_def Protocol Buffer and sets a status upon +// completion. +void ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status); + +// Sets the experimental_type` field in the node_def Protocol Buffer. +void SetFullType(TF_Graph* graph, TF_Operation* op, + const TF_Buffer* full_type_proto); + +void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); + +// Updates 'dst' to consume 'new_src'. +void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status); + +// Extends `session` with any new operations added to its associated graph. +// Usually this happens automatically in TF_SessionRun. After this is called, +// TF_SessionRun will no longer extend the session on every call. +// +// We expose this here to allow fine-grained synchronization in multi-threaded +// workloads, which is required since the Python implementation depends on the +// above mutation methods. This allows us to prevent modifications to nodes in +// the graph after the session has been made aware of them. 
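+//
+// An illustrative call sequence from a language binding (sketch only; `graph`,
+// `new_src`, `dst`, `session` and `status` are caller-provided):
+//
+//   UpdateEdge(graph, new_src, dst, status);  // mutate the graph first
+//   ExtendSession(session, status);           // then sync the session once
+//   // subsequent TF_SessionRun calls will not re-extend the session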
+void ExtendSession(TF_Session* session, TF_Status* status);
+
+// Returns the serialized CppShapeInferenceResult::HandleData proto for
+// `output` if it's a resource or variant tensor, or otherwise returns the
+// empty string.
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
+
+// Sets `output` based on `proto`, which should be a serialized
+// CppShapeInferenceResult::HandleData proto. `output` should be a resource
+// or variant tensor.
+// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
+// because I couldn't get SWIG to work otherwise.
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+                           size_t proto_len, TF_Status* status);
+
+// This method is used to add a new input edge to 'dst', which must be a While
+// op. The While op's "T" attribute must have already been updated to include
+// the new edge. This is used to construct tf.while_loop gradients.
+void AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+                       TF_Status* status);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_PYTHON_API_H_
diff --git a/third_party/tflite-hdrs/tensorflow/c/safe_ptr.h b/third_party/tflite-hdrs/tensorflow/c/safe_ptr.h
new file mode 100644
index 00000000..8d8b8141
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/c/safe_ptr.h
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_SAFE_PTR_H_
+#define TENSORFLOW_C_SAFE_PTR_H_
+
+#include <memory>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+
+namespace tensorflow {
+namespace detail {
+
+struct TFTensorDeleter {
+  void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
+};
+
+struct TFETensorHandleDeleter {
+  void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
+};
+
+struct TFStatusDeleter {
+  void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
+};
+
+struct TFBufferDeleter {
+  void operator()(TF_Buffer* p) const { TF_DeleteBuffer(p); }
+};
+
+}  // namespace detail
+
+// Safe containers for an owned TF_Tensor. On destruction, the tensor will be
+// deleted by TF_DeleteTensor.
+using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, detail::TFTensorDeleter>;
+Safe_TF_TensorPtr make_safe(TF_Tensor* tensor);
+
+// Safe containers for an owned TFE_TensorHandle. On destruction, the handle
+// will be deleted by TFE_DeleteTensorHandle.
+using Safe_TFE_TensorHandlePtr =
+    std::unique_ptr<TFE_TensorHandle, detail::TFETensorHandleDeleter>;
+Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle);
+
+// Safe containers for an owned TF_Status. On destruction, the handle
+// will be deleted by TF_DeleteStatus.
+using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, detail::TFStatusDeleter>;
+Safe_TF_StatusPtr make_safe(TF_Status* status);
+
+// Safe containers for an owned TF_Buffer. On destruction, the handle
+// will be deleted by TF_DeleteBuffer.
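+//
+// As an illustrative sketch, the aliases declared in this header are used like
+// this (error handling omitted):
+//
+//   Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
+//   Safe_TF_BufferPtr buffer = make_safe(TF_NewBuffer());
+//   // ... pass status.get() / buffer.get() to C API calls ...
+//   // Both objects are released automatically when they go out of scope.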
+using Safe_TF_BufferPtr = std::unique_ptr; +Safe_TF_BufferPtr make_safe(TF_Buffer* buffer); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_SAFE_PTR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tensor_interface.h b/third_party/tflite-hdrs/tensorflow/c/tensor_interface.h new file mode 100644 index 00000000..0b352f56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tensor_interface.h @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TENSOR_INTERFACE_H_ +#define TENSORFLOW_C_TENSOR_INTERFACE_H_ + +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Abstract interface to a Tensor. +// +// This allows us to hide concrete implementations of Tensor from header +// files. The interface lists the common functionality that must be provided by +// any concrete implementation. However, in cases where the true concrete class +// is needed a static_cast can be applied. +class AbstractTensorInterface { + public: + // Release any underlying resources, including the interface object. + virtual void Release() = 0; + + // Returns tensor dtype. + virtual DataType Type() const = 0; + // Returns number of dimensions. + virtual int NumDims() const = 0; + // Returns size of specified dimension + virtual int64_t Dim(int dim_index) const = 0; + // Returns number of elements across all dimensions. + virtual int64_t NumElements() const = 0; + // Return size in bytes of the Tensor + virtual size_t ByteSize() const = 0; + // Returns a pointer to tensor data + virtual void* Data() const = 0; + + // Returns if the tensor is aligned + virtual bool IsAligned() const = 0; + // Returns if their is sole ownership of this Tensor and thus it can be moved. + virtual bool CanMove() const = 0; + + virtual std::string SummarizeValue() const = 0; + + protected: + virtual ~AbstractTensorInterface() {} +}; + +namespace internal { +struct AbstractTensorInterfaceDeleter { + void operator()(AbstractTensorInterface* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using AbstractTensorPtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_TENSOR_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_attrtype.h b/third_party/tflite-hdrs/tensorflow/c/tf_attrtype.h new file mode 100644 index 00000000..0c1545db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_attrtype.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_TF_ATTRTYPE_H_ +#define TENSORFLOW_C_TF_ATTRTYPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// TF_AttrType describes the type of the value of an attribute on an operation. +typedef enum TF_AttrType { + TF_ATTR_STRING = 0, + TF_ATTR_INT = 1, + TF_ATTR_FLOAT = 2, + TF_ATTR_BOOL = 3, + TF_ATTR_TYPE = 4, + TF_ATTR_SHAPE = 5, + TF_ATTR_TENSOR = 6, + TF_ATTR_PLACEHOLDER = 7, + TF_ATTR_FUNC = 8, +} TF_AttrType; + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_ATTRTYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_buffer.h b/third_party/tflite-hdrs/tensorflow/c/tf_buffer.h new file mode 100644 index 00000000..71a9aef8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_buffer.h @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_BUFFER_H_ +#define TENSORFLOW_C_TF_BUFFER_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// TF_Buffer holds a pointer to a block of data and its associated length. +// Typically, the data consists of a serialized protocol buffer, but other data +// may also be held in a buffer. +// +// By default, TF_Buffer itself does not do any memory management of the +// pointed-to block. If need be, users of this struct should specify how to +// deallocate the block by setting the `data_deallocator` function pointer. +typedef struct TF_Buffer { + const void* data; + size_t length; + void (*data_deallocator)(void* data, size_t length); +} TF_Buffer; + +// Makes a copy of the input and sets an appropriate deallocator. Useful for +// passing in read-only, input protobufs. +TF_CAPI_EXPORT extern TF_Buffer* TF_NewBufferFromString(const void* proto, + size_t proto_len); + +// Useful for passing *out* a protobuf. 
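+// For example, to receive a serialized GraphDef (illustrative sketch;
+// TF_GraphToGraphDef is declared in c_api.h, not in this header):
+//
+//   TF_Buffer* out = TF_NewBuffer();
+//   TF_GraphToGraphDef(graph, out, status);
+//   // ... read out->data / out->length ...
+//   TF_DeleteBuffer(out);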
+TF_CAPI_EXPORT extern TF_Buffer* TF_NewBuffer(void); + +TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); + +TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_BUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_buffer_internal.h b/third_party/tflite-hdrs/tensorflow/c/tf_buffer_internal.h new file mode 100644 index 00000000..85436f42 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_buffer_internal.h @@ -0,0 +1,45 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_BUFFER_INTERNAL_H_ +#define TENSORFLOW_C_TF_BUFFER_INTERNAL_H_ + +#include + +#include "tensorflow/c/tf_buffer.h" +#include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +absl::Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out); + +absl::Status BufferToMessage(const TF_Buffer* in, + tensorflow::protobuf::MessageLite* out); + +namespace internal { + +struct TF_BufferDeleter { + void operator()(TF_Buffer* buf) const { TF_DeleteBuffer(buf); } +}; + +} // namespace internal + +using TF_BufferPtr = std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_TF_BUFFER_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_datatype.h b/third_party/tflite-hdrs/tensorflow/c/tf_datatype.h new file mode 100644 index 00000000..9a9eaadc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_datatype.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_DATATYPE_H_ +#define TENSORFLOW_C_TF_DATATYPE_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// -------------------------------------------------------------------------- +// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. +// The enum values here are identical to corresponding values in types.proto. +typedef enum TF_DataType { + TF_FLOAT = 1, + TF_DOUBLE = 2, + TF_INT32 = 3, // Int32 tensors are always in 'host' memory. 
+ TF_UINT8 = 4, + TF_INT16 = 5, + TF_INT8 = 6, + TF_STRING = 7, + TF_COMPLEX64 = 8, // Single-precision complex + TF_COMPLEX = 8, // Old identifier kept for API backwards compatibility + TF_INT64 = 9, + TF_BOOL = 10, + TF_QINT8 = 11, // Quantized int8 + TF_QUINT8 = 12, // Quantized uint8 + TF_QINT32 = 13, // Quantized int32 + TF_BFLOAT16 = 14, // Float32 truncated to 16 bits. + TF_QINT16 = 15, // Quantized int16 + TF_QUINT16 = 16, // Quantized uint16 + TF_UINT16 = 17, + TF_COMPLEX128 = 18, // Double-precision complex + TF_HALF = 19, + TF_RESOURCE = 20, + TF_VARIANT = 21, + TF_UINT32 = 22, + TF_UINT64 = 23, + TF_FLOAT8_E5M2 = 24, // 5 exponent bits, 2 mantissa bits. + TF_FLOAT8_E4M3FN = 25, // 4 exponent bits, 3 mantissa bits, finite-only, with + // 2 NaNs (0bS1111111). + // TODO - b/299182407: Leaving room for remaining float8 types. + // TF_FLOAT8_E4M3FNUZ = 26, + // TF_FLOAT8_E4M3B11FNUZ = 27, + // TF_FLOAT8_E5M2FNUZ = 28, + TF_INT4 = 29, + TF_UINT4 = 30, +} TF_DataType; + +// TF_DataTypeSize returns the sizeof() for the underlying type corresponding +// to the given TF_DataType enum value. Returns 0 for variable length types +// (eg. TF_STRING) or on failure. +TF_CAPI_EXPORT extern size_t TF_DataTypeSize(TF_DataType dt); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_DATATYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_file_statistics.h b/third_party/tflite-hdrs/tensorflow/c/tf_file_statistics.h new file mode 100644 index 00000000..117d9501 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_file_statistics.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_FILE_STATISTICS_H_ +#define TENSORFLOW_C_TF_FILE_STATISTICS_H_ + +#include + +typedef struct TF_FileStatistics { + // The length of the file in bytes. + int64_t length; + // The last modified time in nanoseconds. + int64_t mtime_nsec; + // Whether the name refers to a directory. + bool is_directory; +} TF_FileStatistics; + +// TODO(b/139060984): `tensorflow::FileStatistics` from +// `core/platform/file_statistics.h` is a duplicate of this so maybe try to +// remove duplication later? + +#endif // TENSORFLOW_C_TF_FILE_STATISTICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_shape.h b/third_party/tflite-hdrs/tensorflow/c/tf_shape.h new file mode 100644 index 00000000..f218d05e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_shape.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/c_api_macros.h" + +#ifndef TENSORFLOW_C_TF_SHAPE_H_ +#define TENSORFLOW_C_TF_SHAPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +// An opaque type corresponding to a shape in tensorflow. In the future, +// we may expose the ABI of TF_Shape for performance reasons. +typedef struct TF_Shape TF_Shape; + +// Return a new, unknown rank shape object. The caller is responsible for +// calling TF_DeleteShape to deallocate and destroy the returned shape. +TF_CAPI_EXPORT extern TF_Shape* TF_NewShape(); + +// Returns the rank of `shape`. If `shape` has unknown rank, returns -1. +TF_CAPI_EXPORT extern int TF_ShapeDims(const TF_Shape* shape); + +// Returns the `d`th dimension of `shape`. If `shape` has unknown rank, +// invoking this function is undefined behavior. Returns -1 if dimension is +// unknown. +TF_CAPI_EXPORT extern int64_t TF_ShapeDimSize(const TF_Shape* shape, int d); + +// Deletes `shape`. +TF_CAPI_EXPORT extern void TF_DeleteShape(TF_Shape* shape); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_SHAPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_shape_internal.h b/third_party/tflite-hdrs/tensorflow/c/tf_shape_internal.h new file mode 100644 index 00000000..fe977264 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_shape_internal.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ +#define TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/core/framework/tensor_shape.h" + +typedef struct TF_Shape TF_Shape; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::PartialTensorShape, TF_Shape); + +} + +#endif // TENSORFLOW_C_TF_SHAPE_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_status.h b/third_party/tflite-hdrs/tensorflow/c/tf_status.h new file mode 100644 index 00000000..8979e42c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_status.h @@ -0,0 +1,98 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_STATUS_H_ +#define TENSORFLOW_C_TF_STATUS_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "xla/tsl/c/tsl_status.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TSL_Status TF_Status; + +// -------------------------------------------------------------------------- +// TF_Code holds an error code. The enum values here are identical to +// corresponding values in error_codes.proto. +typedef TSL_Code TF_Code; +// LINT.IfChange +#define TF_OK TSL_OK +#define TF_CANCELLED TSL_CANCELLED +#define TF_UNKNOWN TSL_UNKNOWN +#define TF_INVALID_ARGUMENT TSL_INVALID_ARGUMENT +#define TF_DEADLINE_EXCEEDED TSL_DEADLINE_EXCEEDED +#define TF_NOT_FOUND TSL_NOT_FOUND +#define TF_ALREADY_EXISTS TSL_ALREADY_EXISTS +#define TF_PERMISSION_DENIED TSL_PERMISSION_DENIED +#define TF_UNAUTHENTICATED TSL_UNAUTHENTICATED +#define TF_RESOURCE_EXHAUSTED TSL_RESOURCE_EXHAUSTED +#define TF_FAILED_PRECONDITION TSL_FAILED_PRECONDITION +#define TF_ABORTED TSL_ABORTED +#define TF_OUT_OF_RANGE TSL_OUT_OF_RANGE +#define TF_UNIMPLEMENTED TSL_UNIMPLEMENTED +#define TF_INTERNAL TSL_INTERNAL +#define TF_UNAVAILABLE TSL_UNAVAILABLE +#define TF_DATA_LOSS TSL_DATA_LOSS +// LINT.ThenChange(//tensorflow/python/py_exception_registry_wrapper.cc) + +// -------------------------------------------------------------------------- + +// Return a new status object. +TF_CAPI_EXPORT extern TF_Status* TF_NewStatus(void); + +// Delete a previously created status object. +TF_CAPI_EXPORT extern void TF_DeleteStatus(TF_Status*); + +// Record in *s. Any previous information is lost. +// A common use is to clear a status: TF_SetStatus(s, TF_OK, ""); +TF_CAPI_EXPORT extern void TF_SetStatus(TF_Status* s, TF_Code code, + const char* msg); + +// Record as a payload in *s. The previous payload having the +// same key (if any) is overwritten. Payload will not be added if the Status +// is OK. +TF_CAPI_EXPORT void TF_SetPayload(TF_Status* s, const char* key, + const char* value); + +// Iterates over the stored payloads and calls the `visitor(key, value)` +// callable for each one. `key` and `value` is only usable during the callback. +// `capture` will be passed to the callback without modification. +#define TF_PayloadVisitor TSL_PayloadVisitor +TF_CAPI_EXPORT extern void TF_ForEachPayload(const TF_Status* s, + TF_PayloadVisitor visitor, + void* capture); + +// Convert from an I/O error code (e.g., errno) to a TF_Status value. +// Any previous information is lost. Prefer to use this instead of TF_SetStatus +// when the error comes from I/O operations. +TF_CAPI_EXPORT extern void TF_SetStatusFromIOError(TF_Status* s, int error_code, + const char* context); + +// Return the code record in *s. +TF_CAPI_EXPORT extern TF_Code TF_GetCode(const TF_Status* s); + +// Return a pointer to the (null-terminated) error message in *s. The +// return value points to memory that is only usable until the next +// mutation to *s. Always returns an empty string if TF_GetCode(s) is +// TF_OK. 
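+//
+// A typical (illustrative) call pattern, where SomeApiCall stands in for any
+// C API function that reports errors through a TF_Status out-parameter:
+//
+//   TF_Status* s = TF_NewStatus();
+//   SomeApiCall(..., s);
+//   if (TF_GetCode(s) != TF_OK) {
+//     fprintf(stderr, "error: %s\n", TF_Message(s));
+//   }
+//   TF_DeleteStatus(s);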
+TF_CAPI_EXPORT extern const char* TF_Message(const TF_Status* s); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_STATUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_status_helper.h b/third_party/tflite-hdrs/tensorflow/c/tf_status_helper.h new file mode 100644 index 00000000..ce833c39 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_status_helper.h @@ -0,0 +1,74 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H_ +#define TENSORFLOW_C_TF_STATUS_HELPER_H_ + +#include +#include + +#include "tensorflow/c/tf_status.h" +#include "tsl/platform/status.h" + +namespace tsl { +// Set the attribute of "tf_status" from the attributes of "status". +void Set_TF_Status_from_Status(TF_Status* tf_status, + const absl::Status& status); + +// Returns a "status" from "tf_status". +absl::Status StatusFromTF_Status(const TF_Status* tf_status); +} // namespace tsl + +namespace tensorflow { +using tsl::Set_TF_Status_from_Status; // NOLINT +using tsl::StatusFromTF_Status; // NOLINT + +namespace internal { +struct TF_StatusDeleter { + void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); } +}; +} // namespace internal + +using TF_StatusPtr = std::unique_ptr; + +} // namespace tensorflow + +#define TF_STATUS_ASSIGN_OR_RETURN(lhs, rexpr, c_status) \ + _TF_STATUS_ASSIGN_OR_RETURN_IMPL( \ + _TF_STATUS_CONCAT(_status_or_value, __COUNTER__), lhs, rexpr, c_status); + +#define _TF_STATUS_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr, c_status) \ + auto statusor = (rexpr); \ + if (!statusor.ok()) { \ + tensorflow::Set_TF_Status_from_Status(c_status, statusor.status()); \ + return; \ + } \ + lhs = std::move(*statusor) + +#define TF_STATUS_RETURN_IF_ERROR(rexpr, c_status) \ + _TF_STATUS_RETURN_IF_ERROR_IMPL(_TF_STATUS_CONCAT(_status, __COUNTER__), \ + rexpr, c_status); + +#define _TF_STATUS_RETURN_IF_ERROR_IMPL(status, rexpr, c_status) \ + auto status = (rexpr); \ + if (!status.ok()) { \ + tensorflow::Set_TF_Status_from_Status(c_status, status); \ + return; \ + } + +#define _TF_STATUS_CONCAT(x, y) _TF_STATUS_CONCAT_IMPL(x, y) +#define _TF_STATUS_CONCAT_IMPL(x, y) x##y + +#endif // TENSORFLOW_C_TF_STATUS_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_status_internal.h b/third_party/tflite-hdrs/tensorflow/c/tf_status_internal.h new file mode 100644 index 00000000..4aa273fc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_status_internal.h @@ -0,0 +1,23 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_STATUS_INTERNAL_H_ +#define TENSORFLOW_C_TF_STATUS_INTERNAL_H_ + +#include "xla/tsl/c/tsl_status_internal.h" + +typedef struct TSL_Status TF_Status; + +#endif // TENSORFLOW_C_TF_STATUS_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_tensor.h b/third_party/tflite-hdrs/tensorflow/c/tf_tensor.h new file mode 100644 index 00000000..b2855d28 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_tensor.h @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_TENSOR_H_ +#define TENSORFLOW_C_TF_TENSOR_H_ + +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_status.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Allocator Attributes used for tensor allocation. +typedef struct TF_AllocatorAttributes { + size_t struct_size; + // Set boolean to 1 for CPU allocation, else 0. + TF_Bool on_host; +} TF_AllocatorAttributes; + +#define TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE \ + TF_OFFSET_OF_END(TF_AllocatorAttributes, on_host) + +// -------------------------------------------------------------------------- +// TF_Tensor holds a multi-dimensional array of elements of a single data type. +// For all types other than TF_STRING, the data buffer stores elements +// in row major order. E.g. if data is treated as a vector of TF_DataType: +// +// element 0: index (0, ..., 0) +// element 1: index (0, ..., 1) +// ... +// +// The format for TF_STRING tensors is: +// start_offset: array[uint64] +// data: byte[...] +// +// The string length (as a varint, start_offset[i + 1] - start_offset[i]), +// followed by the contents of the string is encoded at data[start_offset[i]]. +// TF_StringEncode and TF_StringDecode facilitate this encoding. + +typedef struct TF_Tensor TF_Tensor; + +// Return a new tensor that holds the bytes data[0,len-1]. +// +// The data will be deallocated by a subsequent call to TF_DeleteTensor via: +// (*deallocator)(data, len, deallocator_arg) +// Clients must provide a custom deallocator function so they can pass in +// memory managed by something like numpy. +// +// May return NULL (and invoke the deallocator) if the provided data buffer +// (data, len) is inconsistent with a tensor of the given TF_DataType +// and the shape specified by (dima, num_dims). 
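+//
+// For example, wrapping caller-owned memory that TensorFlow must not free
+// (illustrative sketch; the no-op deallocator is a common pattern for memory
+// whose lifetime is managed elsewhere):
+//
+//   void noop_dealloc(void* data, size_t len, void* arg) {}
+//
+//   int64_t dims[1] = {4};
+//   float values[4] = {1.0f, 2.0f, 3.0f, 4.0f};
+//   TF_Tensor* t = TF_NewTensor(TF_FLOAT, dims, 1, values, sizeof(values),
+//                               noop_dealloc, NULL);
+//   // ... use t ...
+//   TF_DeleteTensor(t);  // invokes noop_dealloc; `values` remains valid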
+TF_CAPI_EXPORT extern TF_Tensor* TF_NewTensor( + TF_DataType, const int64_t* dims, int num_dims, void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg); + +// Returns the alignment, in bytes, required for allocating aligned tensors. +// +// This can be used in combination with TF_NewTensor to manually manage +// memory while ensuring the resulting tensors satisfy TensorFlow's +// memory alignment preferences. +TF_CAPI_EXPORT extern size_t TF_TensorDefaultAlignment(); + +// Allocate and return a new Tensor. +// +// This function is an alternative to TF_NewTensor and should be used when +// memory is allocated to pass the Tensor to the C API. The allocated memory +// satisfies TensorFlow's memory alignment preferences and should be preferred +// over calling malloc and free. +// +// The caller must set the Tensor values by writing them to the pointer returned +// by TF_TensorData with length TF_TensorByteSize. +TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTensor(TF_DataType, + const int64_t* dims, + int num_dims, size_t len); + +// Deletes `tensor` and returns a new TF_Tensor with the same content if +// possible. Returns nullptr and leaves `tensor` untouched if not. +TF_CAPI_EXPORT extern TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor); + +// Destroy a tensor. +TF_CAPI_EXPORT extern void TF_DeleteTensor(TF_Tensor*); + +// Return the type of a tensor element. +TF_CAPI_EXPORT extern TF_DataType TF_TensorType(const TF_Tensor*); + +// Set a new shape for the Tensor. +TF_CAPI_EXPORT extern void TF_SetShape(TF_Tensor* tensor, const int64_t* dims, + int num_dims); + +// Return the number of dimensions that the tensor has. +TF_CAPI_EXPORT extern int TF_NumDims(const TF_Tensor*); + +// Return the length of the tensor in the "dim_index" dimension. +// REQUIRES: 0 <= dim_index < TF_NumDims(tensor) +TF_CAPI_EXPORT extern int64_t TF_Dim(const TF_Tensor* tensor, int dim_index); + +// Return the size of the underlying data in bytes. +TF_CAPI_EXPORT extern size_t TF_TensorByteSize(const TF_Tensor*); + +// Return a pointer to the underlying data buffer. +TF_CAPI_EXPORT extern void* TF_TensorData(const TF_Tensor*); + +// Returns the number of elements in the tensor. +TF_CAPI_EXPORT extern int64_t TF_TensorElementCount(const TF_Tensor* tensor); + +// Copy the internal data representation of `from` to `to`. `new_dims` and +// `num_new_dims` specify the new shape of the `to` tensor, `type` specifies its +// data type. On success, *status is set to TF_OK and the two tensors share the +// same data buffer. +// +// This call requires that the `from` tensor and the given type and shape (dims +// and num_dims) are "compatible" (i.e. they occupy the same number of bytes). +// Specifically, given from_type_size = TF_DataTypeSize(TF_TensorType(from)): +// +// ShapeElementCount(dims, num_dims) * TF_DataTypeSize(type) +// +// must equal +// +// TF_TensorElementCount(from) * from_type_size +// +// where TF_ShapeElementCount would be the number of elements in a tensor with +// the given shape. +// +// In addition, this function requires: +// * TF_DataTypeSize(TF_TensorType(from)) != 0 +// * TF_DataTypeSize(type) != 0 +// +// If any of the requirements are not met, *status is set to +// TF_INVALID_ARGUMENT. +TF_CAPI_EXPORT extern void TF_TensorBitcastFrom(const TF_Tensor* from, + TF_DataType type, TF_Tensor* to, + const int64_t* new_dims, + int num_new_dims, + TF_Status* status); + +// Returns bool iff this tensor is aligned. 
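+//
+// Illustrative sketch: allocating a tensor and filling it in place (error
+// handling omitted):
+//
+//   int64_t dims[2] = {2, 3};
+//   TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, dims, 2, 6 * sizeof(float));
+//   float* data = (float*)TF_TensorData(t);
+//   for (int64_t i = 0; i < TF_TensorElementCount(t); ++i) data[i] = 0.0f;
+//   bool aligned = TF_TensorIsAligned(t);  // expected to hold for
+//                                          // TF_AllocateTensor output
+//   TF_DeleteTensor(t);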
+TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_tensor_helper.h b/third_party/tflite-hdrs/tensorflow/c/tf_tensor_helper.h new file mode 100644 index 00000000..b77d5a78 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_tensor_helper.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_TENSOR_HELPER_H_ +#define TENSORFLOW_C_TF_TENSOR_HELPER_H_ + +#include + +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class Tensor; + +absl::Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); + +TF_Tensor* TF_TensorFromTensor(const Tensor& src, absl::Status* status); + +TF_Tensor* TF_TensorFromTensorShallow(const Tensor& src, absl::Status* status); + +namespace internal { + +struct TFTensorDeleter { + void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); } +}; + +} // namespace internal + +// Struct that wraps TF_Tensor to delete once out of scope. +using TF_TensorPtr = std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_TF_TENSOR_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_tensor_internal.h b/third_party/tflite-hdrs/tensorflow/c/tf_tensor_internal.h new file mode 100644 index 00000000..61bceee5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_tensor_internal.h @@ -0,0 +1,136 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ +#define TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/c/tf_tensor_helper.h" // IWYU pragma: export +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" + +// Internal structures used by the C API. These are likely to change and should +// not be depended on. 
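+//
+// As an illustrative sketch (not part of the original header), the conversion
+// helpers pulled in from tf_tensor_helper.h above are typically used like
+// this; `cpp_tensor` is a hypothetical tensorflow::Tensor:
+//
+//   absl::Status status;
+//   tensorflow::TF_TensorPtr c_tensor(
+//       tensorflow::TF_TensorFromTensor(cpp_tensor, &status));
+//   if (status.ok()) {
+//     // hand c_tensor.get() across the C API boundary
+//   }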
+ +// This struct forms part of the C API's public interface. It must strictly be +// passed to or returned from C functions *by pointer*. Otherwise, changes to +// its internal structure will break the C API's binary interface. +typedef struct TF_Tensor { + tensorflow::AbstractTensorInterface* tensor; +} TF_Tensor; + +class TF_ManagedBuffer : public tensorflow::TensorBuffer { + public: + TF_ManagedBuffer(void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg, bool owns_memory) + : TensorBuffer(data), + len_(len), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg), + owns_memory_(owns_memory) {} + + ~TF_ManagedBuffer() override { + (*deallocator_)(data(), len_, deallocator_arg_); + } + + size_t size() const override { return len_; } + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription( + tensorflow::AllocationDescription* proto) const override { + int64_t rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); + } + + bool OwnsMemory() const override { return owns_memory_; } + + private: + const size_t len_; + void (*const deallocator_)(void* data, size_t len, void* arg); + void* const deallocator_arg_; + bool owns_memory_; +}; + +namespace tensorflow { + +class TensorCApi { + public: + static TensorBuffer* Buffer(const Tensor& tensor) { return tensor.buf_; } + static Tensor MakeTensor(TF_DataType type, const TensorShape& shape, + TensorBuffer* buf) { + return Tensor(static_cast(type), shape, buf); + } +}; + +// Allocates tensor data buffer using specified allocator. +// `operation` is a name for this operation. +void* allocate_tensor(const char* operation, size_t len, Allocator* allocator); + +// Deallocates tensor data buffer. +// Defaults to deallocating using CPU allocator. You can pass pointer to +// a different Allocator as `arg`. +void deallocate_buffer(void* data, size_t len, void* arg); + +class TensorInterface : public AbstractTensorInterface { + public: + TensorInterface() {} + explicit TensorInterface(tensorflow::Tensor t) : tensor_(std::move(t)) {} + ~TensorInterface() override {} + + void Release() override; + + DataType Type() const override; + int NumDims() const override; + int64_t Dim(int dim_index) const override; + int64_t NumElements() const override; + size_t ByteSize() const override; + void* Data() const override; + bool IsAligned() const override; + bool CanMove() const override; + std::string SummarizeValue() const override; + + void SetShape(const int64_t* dims, int num_dims); + absl::Status ToTensor(tensorflow::Tensor* dst) const; + absl::Status BitcastFrom(const TensorInterface& from, DataType type, + const int64_t* new_dims, int num_new_dims); + absl::Status FromProto(const tensorflow::TensorProto& from); + + tensorflow::Tensor& Tensor() { return tensor_; } + + private: + tensorflow::Tensor tensor_; +}; + +inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) { + return down_cast(tensor)->Tensor(); +} + +AbstractTensorInterface* TensorInterfaceFromTensor(const Tensor& src, + absl::Status* status); + +} // namespace tensorflow + +#endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/c/tf_tstring.h b/third_party/tflite-hdrs/tensorflow/c/tf_tstring.h new file mode 100644 index 00000000..876fd5f3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/c/tf_tstring.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_TF_TSTRING_H_ +#define TENSORFLOW_C_TF_TSTRING_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/ctstring.h" + +#ifdef __cplusplus +extern "C" { +#endif + +TF_CAPI_EXPORT extern void TF_StringInit(TF_TString *t); + +TF_CAPI_EXPORT extern void TF_StringCopy(TF_TString *dst, const char *src, + size_t size); + +TF_CAPI_EXPORT extern void TF_StringAssignView(TF_TString *dst, const char *src, + size_t size); + +TF_CAPI_EXPORT extern const char *TF_StringGetDataPointer( + const TF_TString *tstr); + +TF_CAPI_EXPORT extern TF_TString_Type TF_StringGetType(const TF_TString *str); + +TF_CAPI_EXPORT extern size_t TF_StringGetSize(const TF_TString *tstr); + +TF_CAPI_EXPORT extern size_t TF_StringGetCapacity(const TF_TString *str); + +TF_CAPI_EXPORT extern void TF_StringDealloc(TF_TString *tstr); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_TF_TSTRING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/client/client_session.h b/third_party/tflite-hdrs/tensorflow/cc/client/client_session.h new file mode 100644 index 00000000..9dc790d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/client/client_session.h @@ -0,0 +1,164 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ +#define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ + +#include +#include +#include +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/public/session_options.h" + +namespace tsl { +namespace thread { +struct ThreadPoolOptions; +} +} // namespace tsl + +namespace tensorflow { + +namespace thread { +using tsl::thread::ThreadPoolOptions; +} + +/// @addtogroup core +/// @{ + +/// A `ClientSession` object lets the caller drive the evaluation of the +/// TensorFlow graph constructed with the C++ API. +/// +/// Example: +/// +/// Scope root = Scope::NewRootScope(); +/// auto a = Placeholder(root, DT_INT32); +/// auto c = Add(root, a, {41}); +/// +/// ClientSession session(root); +/// std::vector outputs; +/// +/// Status s = session.Run({ {a, {1}} }, {c}, &outputs); +/// if (!s.ok()) { ... 
} +class ClientSession { + public: + /// A data type to represent feeds to a Run call. + /// + /// This is a map of `Output` objects returned by op-constructors to the value + /// to feed them with. See `Input::Initializer` for details on what can be + /// used as feed values. + typedef std::unordered_map FeedType; + + /// Create a new session to evaluate the graph contained in `scope` by + /// connecting to the TensorFlow runtime specified by `target`. + ClientSession(const Scope& scope, const string& target); + + /// Same as above, but use the empty string ("") as the target specification. + explicit ClientSession(const Scope& scope); + + /// Create a new session, configuring it with `session_options`. + ClientSession(const Scope& scope, const SessionOptions& session_options); + + ~ClientSession(); + + /// Evaluate the tensors in `fetch_outputs`. The values are returned as + /// `Tensor` objects in `outputs`. The number and order of `outputs` will + /// match `fetch_outputs`. + absl::Status Run(const std::vector& fetch_outputs, + std::vector* outputs) const; + + /// Same as above, but use the mapping in `inputs` as feeds. + absl::Status Run(const FeedType& inputs, + const std::vector& fetch_outputs, + std::vector* outputs) const; + + /// Same as above. Additionally runs the operations ins `run_outputs`. + absl::Status Run(const FeedType& inputs, + const std::vector& fetch_outputs, + const std::vector& run_outputs, + std::vector* outputs) const; + + /// Use `run_options` to turn on performance profiling. `run_metadata`, if not + /// null, is filled in with the profiling results. + absl::Status Run(const RunOptions& run_options, const FeedType& inputs, + const std::vector& fetch_outputs, + const std::vector& run_outputs, + std::vector* outputs, + RunMetadata* run_metadata) const; + + /// Same as above. Additionally allows user to provide custom threadpool + /// implementation via ThreadPoolOptions. + absl::Status Run(const RunOptions& run_options, const FeedType& inputs, + const std::vector& fetch_outputs, + const std::vector& run_outputs, + std::vector* outputs, RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options) const; + + /// \brief A handle to a subgraph, created with + /// `ClientSession::MakeCallable()`. + typedef int64_t CallableHandle; + + /// \brief Creates a `handle` for invoking the subgraph defined by + /// `callable_options`. + /// NOTE: This API is still experimental and may change. + absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle); + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must match the order of names in + /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will + /// match the order of names in `CallableOptions::fetch()` when this subgraph + /// was created. + /// NOTE: This API is still experimental and may change. + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata); + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must match the order of names in + /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will + /// match the order of names in `CallableOptions::fetch()` when this subgraph + /// was created. 
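// --- Annotation (not part of the vendored headers): illustrative sketch ---
// A minimal end-to-end use of ClientSession::Run with a feed, following the
// example in the class comment above. The op wrappers (Placeholder, Add) come
// from the generated ops headers; the concrete graph and values are
// illustrative assumptions.
#include <vector>
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

absl::Status RunAddExample() {
  using namespace tensorflow;
  Scope root = Scope::NewRootScope();
  auto a = ops::Placeholder(root, DT_INT32);
  auto c = ops::Add(root, a, {41});

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Feed `a` with the scalar 1 and fetch `c`; outputs[0] holds 42 on success.
  return session.Run({{a, {1}}}, {c}, &outputs);
}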
+ /// NOTE: This API is still experimental and may change. + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata, + const thread::ThreadPoolOptions& options); + + /// \brief Releases resources associated with the given `handle` in this + /// session. + /// NOTE: This API is still experimental and may change. + absl::Status ReleaseCallable(CallableHandle handle); + + private: + class Impl; + std::unique_ptr impl_; + Impl* impl() { return impl_.get(); } + const Impl* impl() const { return impl_.get(); } +}; + +/// @} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/runtime.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/runtime.h new file mode 100644 index 00000000..711a38c2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/runtime.h @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_ + +#include + +#include "tensorflow/c/eager/c_api_experimental.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Runtime represents an opaque instance of a Tensorflow runtime, with its own +// resources, threadpools, etc. Clients are expected to construct a Runtime +// object through tensorflow::cc::RuntimeBuilder::Build, after setting any +// relevant configuration options. Many Tensorflow functions take a reference to +// the runtime as an argument (eg: tensorflow::cc::SavedModelAPI::Load), and +// may have different implementations depending on the runtime. For many of +// these Runtime-attached objects (such as tensorflow::cc::TensorHandle), the +// Runtime must outlive these objects. +class Runtime { + public: + // Runtime is movable, but not copyable. + Runtime(Runtime&&) = default; + Runtime& operator=(Runtime&&) = default; + + private: + friend class RuntimeBuilder; + friend class SavedModelAPI; + friend class TensorHandle; + + // Wraps a TFE_Context. Takes ownership of ctx. + explicit Runtime(TFE_Context* ctx) : ctx_(ctx) {} + + // Deletes the currently wrapped TFE_Context, swaps it with ctx, + // and takes ownership of ctx. + void Reset(TFE_Context* ctx) { ctx_.reset(ctx); } + + // Returns the TFE_Context that this object wraps. This object + // retains ownership of the pointer. 
+  TFE_Context* GetTFEContext() const { return ctx_.get(); }
+
+  // Runtime is not copyable
+  Runtime(const Runtime&) = delete;
+  Runtime& operator=(const Runtime&) = delete;
+
+  struct TFEContextDeleter {
+    void operator()(TFE_Context* p) const { TFE_DeleteContext(p); }
+  };
+  std::unique_ptr<TFE_Context, TFEContextDeleter> ctx_;
+};
+
+}  // namespace cc
+}  // namespace experimental
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_H_
diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/runtime_builder.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/runtime_builder.h
new file mode 100644
index 00000000..737e06cb
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/runtime_builder.h
@@ -0,0 +1,86 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
+
+#include <memory>
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/cc/experimental/base/public/runtime.h"
+#include "tensorflow/cc/experimental/base/public/status.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// RuntimeBuilder is a builder used to construct a tensorflow::cc::Runtime.
+// Use this to set configuration options, like threadpool size, etc.
+class RuntimeBuilder {
+ public:
+  RuntimeBuilder() : options_(TFE_NewContextOptions()) {}
+
+  // If `use_tfrt` is true, we will use the new Tensorflow Runtime
+  // (https://blog.tensorflow.org/2020/04/tfrt-new-tensorflow-runtime.html) as
+  // our runtime implementation.
+  RuntimeBuilder& SetUseTFRT(bool use_tfrt);
+
+  // Build a Tensorflow Runtime.
+  //
+  // Params:
+  //  status - Set to OK on success and an appropriate error on failure.
+  // Returns:
+  //  If status is not OK, returns nullptr. Otherwise, returns a
+  //  unique_ptr<tensorflow::cc::Runtime>.
+  std::unique_ptr<Runtime> Build(Status* status);
+
+  // RuntimeBuilder is movable, but not copyable.
+  RuntimeBuilder(RuntimeBuilder&&) = default;
+  RuntimeBuilder& operator=(RuntimeBuilder&&) = default;
+
+ private:
+  // RuntimeBuilder is not copyable
+  RuntimeBuilder(const RuntimeBuilder&) = delete;
+  RuntimeBuilder& operator=(const RuntimeBuilder&) = delete;
+
+  struct TFEContextOptionsDeleter {
+    void operator()(TFE_ContextOptions* p) const {
+      TFE_DeleteContextOptions(p);
+    }
+  };
+  std::unique_ptr<TFE_ContextOptions, TFEContextOptionsDeleter> options_;
+};
+
+inline RuntimeBuilder& RuntimeBuilder::SetUseTFRT(bool use_tfrt) {
+  TFE_ContextOptionsSetTfrt(options_.get(), use_tfrt);
+  return *this;
+}
+
+inline std::unique_ptr<Runtime> RuntimeBuilder::Build(Status* status) {
+  TFE_Context* result = TFE_NewContext(options_.get(), status->GetTFStatus());
+  if (!status->ok()) {
+    return nullptr;
+  }
+  // We can't use std::make_unique here because of its interaction with a
+  // private constructor: https://abseil.io/tips/134
+  return std::unique_ptr<Runtime>(new Runtime(result));
+}
+
+}  // namespace cc
+}  // namespace experimental
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_RUNTIME_BUILDER_H_
diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/status.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/status.h
new file mode 100644
index 00000000..98c8cf6c
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/status.h
@@ -0,0 +1,96 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
+#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/c/tf_status.h"
+
+namespace tensorflow {
+namespace experimental {
+namespace cc {
+
+// Status is a wrapper around an error code and an optional error message.
+// The set of error codes are defined here:
+// https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/c/tf_status.h#L39-L60
+// Many Tensorflow APIs return a Status, or take a Status as an out parameter.
+// Clients should check for status.ok() after calling these APIs, and either
+// handle or propagate the error appropriately.
+// TODO(bmzhao): Add a detailed code example before moving out of experimental.
+class Status {
+ public:
+  // Create a success status
+  Status() : status_(TF_NewStatus()) {}
+
+  // Return the status code
+  TF_Code code() const;
+
+  // Returns the error message in Status.
+  std::string message() const;
+
+  // Returns true if the status is OK (i.e. code() == TF_OK).
+  bool ok() const;
+
+  // Record <code, msg> in Status. Any previous information is lost.
+  // A common use is to clear a status: SetStatus(TF_OK, "");
+  void SetStatus(TF_Code code, const std::string& msg);
+
+  // Status is movable, but not copyable.
+ Status(Status&&) = default; + Status& operator=(Status&&) = default; + + private: + friend class RuntimeBuilder; + friend class Runtime; + friend class SavedModelAPI; + friend class TensorHandle; + + // Wraps a TF_Status*, and takes ownership of it. + explicit Status(TF_Status* status) : status_(status) {} + + // Status is not copyable + Status(const Status&) = delete; + Status& operator=(const Status&) = delete; + + // Returns the TF_Status that this object wraps. This object + // retains ownership of the pointer. + TF_Status* GetTFStatus() const { return status_.get(); } + + struct TFStatusDeleter { + void operator()(TF_Status* p) const { TF_DeleteStatus(p); } + }; + std::unique_ptr status_; +}; + +inline TF_Code Status::code() const { return TF_GetCode(status_.get()); } + +inline std::string Status::message() const { + return std::string(TF_Message(status_.get())); +} + +inline bool Status::ok() const { return code() == TF_OK; } + +inline void Status::SetStatus(TF_Code code, const std::string& msg) { + TF_SetStatus(status_.get(), code, msg.c_str()); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_STATUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/tensor.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/tensor.h new file mode 100644 index 00000000..7aab1cce --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/tensor.h @@ -0,0 +1,175 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ + +#include +#include + +#include +#include +#include + +#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/cc/experimental/base/public/status.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// Tensor represents an n-dimensional array of values. +class Tensor { + public: + using DeleterCallback = std::function; + + // Constructs a Tensor from user provided buffer. + // + // Params: + // dtype - The dtype of the tensor's data. + // shape - A shape vector, where each element corresponds to the size of + // the tensor's corresponding dimension. + // data - Pointer to a buffer of memory to construct a Tensor out of. + // len - The length (in bytes) of `data` + // deleter - A std::function to be called when the Tensor no longer needs the + // memory in `data`. This can be used to free `data`, or + // perhaps decrement a refcount associated with `data`, etc. + // status - Set to OK on success and an error on failure. + // Returns: + // If an error occurred, status->ok() will be false, and the returned + // Tensor must not be used. + // TODO(bmzhao): Add Runtime as an argument to this function so we can swap to + // a TFRT backed tensor. 
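// --- Annotation (not part of the vendored headers): illustrative sketch ---
// Tensor::FromBuffer (declared just below) adopts caller-provided memory and
// invokes the DeleterCallback once the runtime no longer needs it. A minimal
// sketch, assuming a heap-allocated float buffer; the function name, shape,
// and values are illustrative assumptions.
#include <cstdlib>
#include <vector>
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"

tensorflow::experimental::cc::Tensor MakeTensorFromHeapBuffer(
    tensorflow::experimental::cc::Status* status) {
  using tensorflow::experimental::cc::Tensor;
  constexpr size_t kNumElements = 6;
  float* data = static_cast<float*>(std::malloc(kNumElements * sizeof(float)));
  for (size_t i = 0; i < kNumElements; ++i) data[i] = static_cast<float>(i);
  // The deleter runs when the Tensor releases the buffer.
  auto deleter = [](void* memory, size_t /*len*/) { std::free(memory); };
  Tensor t = Tensor::FromBuffer(TF_FLOAT, /*shape=*/{2, 3}, data,
                                kNumElements * sizeof(float), deleter, status);
  // Callers must check status->ok() before using `t`.
  return t;
}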
+ // TODO(bmzhao): Add benchmarks on overhead for this function; we can + // consider using int64_t* + length rather than vector. + static Tensor FromBuffer(TF_DataType dtype, const std::vector& shape, + void* data, size_t len, DeleterCallback deleter, + Status* status); + + // TODO(bmzhao): In the case we construct a tensor from non-owned memory, + // we should offer a way to deep copy the tensor into a new tensor, which + // owns the underlying memory. This could be a .deepcopy()/clone() method. + + // TODO(bmzhao): In the future, we want to relax the non-copyability + // constraint. To do so, we can add a C API function that acts like + // CopyFrom: + // https://github.com/tensorflow/tensorflow/blob/08931c1e3e9eb2e26230502d678408e66730826c/tensorflow/core/framework/tensor.h#L301-L311 + + // Tensor is movable, but not copyable + Tensor(Tensor&&) = default; + Tensor& operator=(Tensor&&) = default; + + // Returns the number of dimensions in the tensor. Can be -1, which represents + // unknown rank. + int dims() const; + + // Returns the number of elements in dimension `d`. + // REQUIRES: `0 <= d < dims()` + int64_t dim_size(int d) const; + + // Returns a pointer to the underlying data buffer. + void* data() const; + + // Returns the data type of the tensor. + TF_DataType dtype() const; + + // Returns the number of elements in the tensor. For a tensor with a partially + // defined shape, -1 means not fully defined. + int64_t num_elements() const; + + // Returns the size of the underlying data in bytes. + size_t num_bytes() const; + + private: + friend class TensorHandle; + friend class Runtime; + + // Wraps a TF_Tensor. Takes ownership of handle. + explicit Tensor(TF_Tensor* tensor) : tensor_(tensor) {} + + // Tensor is not copyable + Tensor(const Tensor&) = delete; + Tensor& operator=(const Tensor&) = delete; + + // Returns the underlying TF_Tensor that this object wraps. + // This object retains ownership of the pointer. + TF_Tensor* GetTFTensor() const { return tensor_.get(); } + + struct DeleterStruct { + std::function deleter; + }; + + static void DeleterFunction(void* memory, size_t len, void* deleter_struct) { + DeleterStruct* deleter = reinterpret_cast(deleter_struct); + deleter->deleter(memory, len); + delete deleter; + } + + struct TFTensorDeleter { + void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); } + }; + std::unique_ptr tensor_; +}; + +inline void* Tensor::data() const { return TF_TensorData(tensor_.get()); } + +inline int Tensor::dims() const { return TF_NumDims(tensor_.get()); } + +inline int64_t Tensor::dim_size(int d) const { + return TF_Dim(tensor_.get(), d); +} + +inline TF_DataType Tensor::dtype() const { + return TF_TensorType(tensor_.get()); +} + +inline int64_t Tensor::num_elements() const { + return TF_TensorElementCount(tensor_.get()); +} + +inline size_t Tensor::num_bytes() const { + return TF_TensorByteSize(tensor_.get()); +} + +inline Tensor Tensor::FromBuffer(TF_DataType dtype, + const std::vector& shape, void* data, + size_t len, DeleterCallback deleter, + Status* status) { + // Credit to apassos@ for this technique: + // Despite the fact that our API takes a std::function deleter, we are able + // to maintain ABI stability because: + // 1. Only a function pointer is sent across the C API (&DeleterFunction) + // 2. DeleterFunction is defined in the same build artifact that constructed + // the std::function (so there isn't confusion about std::function ABI). + // Note that 2. 
is satisfied by the fact that this is a header-only API, where + // the function implementations are inline. + + DeleterStruct* deleter_struct = new DeleterStruct{deleter}; + TF_Tensor* tensor = TF_NewTensor(dtype, shape.data(), shape.size(), data, len, + &DeleterFunction, deleter_struct); + if (tensor == nullptr) { + status->SetStatus(TF_INVALID_ARGUMENT, + "Failed to create tensor for input buffer"); + return Tensor(nullptr); + } + return Tensor(tensor); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/tensorhandle.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/tensorhandle.h new file mode 100644 index 00000000..99453ee7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/public/tensorhandle.h @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ + +#include +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/experimental/base/public/tensor.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// An opaque representation of a tensor computed/managed by the Tensorflow +// runtime (tensorflow:cc::Runtime). Unlike a tensor, a Tensorhandle may refer +// to tensors placed in memory of different devices or remote address spaces. +// Note that tensorflow::cc::Runtime MUST outlive all TensorHandles created +// from it. +class TensorHandle { + public: + // Unwraps a Tensor from the given TensorHandle. If an error occurred, + // status->ok() will be false, and the returned Tensor must not be used. + Tensor Resolve(Status* status); + + // Constructs a TensorHandle from a Tensor. If an error occurred, + // status->ok() will be false, and the returned TensorHandle must not be used. + static TensorHandle FromTensor(const Tensor& tensor, const Runtime& runtime, + Status* status); + + // TensorHandle is movable, and not copyable + TensorHandle(TensorHandle&&) = default; + TensorHandle& operator=(TensorHandle&&) = default; + + private: + // Wraps a TFE_TensorHandle. Takes ownership of handle. + explicit TensorHandle(TFE_TensorHandle* handle) : handle_(handle) {} + + // TensorHandle is not copyable + TensorHandle(const TensorHandle&) = delete; + TensorHandle& operator=(const TensorHandle&) = delete; + + // Returns the underlying TFE_TensorHandle that this object wraps. + // This object retains ownership of the pointer. 
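// --- Annotation (not part of the vendored headers): illustrative sketch ---
// A sketch of moving a host Tensor onto the runtime as a TensorHandle and
// resolving it back. It assumes the hypothetical MakeTensorFromHeapBuffer
// helper from the sketch further above and a Runtime built via RuntimeBuilder;
// all names are illustrative.
#include <memory>
#include "tensorflow/cc/experimental/base/public/runtime.h"
#include "tensorflow/cc/experimental/base/public/runtime_builder.h"
#include "tensorflow/cc/experimental/base/public/status.h"
#include "tensorflow/cc/experimental/base/public/tensor.h"
#include "tensorflow/cc/experimental/base/public/tensorhandle.h"

void TensorHandleRoundTrip() {
  namespace cc = tensorflow::experimental::cc;
  cc::Status status;
  cc::RuntimeBuilder builder;
  std::unique_ptr<cc::Runtime> runtime = builder.Build(&status);
  if (!status.ok()) return;

  cc::Tensor host_tensor = MakeTensorFromHeapBuffer(&status);
  if (!status.ok()) return;

  // The Runtime must outlive every TensorHandle created from it.
  cc::TensorHandle handle =
      cc::TensorHandle::FromTensor(host_tensor, *runtime, &status);
  if (!status.ok()) return;

  cc::Tensor resolved = handle.Resolve(&status);
  (void)resolved;  // Check status.ok() before reading resolved.data().
}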
+ TFE_TensorHandle* GetTFETensorHandle() const { return handle_.get(); } + + // Deletes the currently wrapped TFE_TensorHandle, and swaps it with handle, + // and takes ownership of handle. + void Reset(TFE_TensorHandle* handle) { handle_.reset(handle); } + + struct TFETensorHandleDeleter { + void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); } + }; + std::unique_ptr handle_; +}; + +inline Tensor TensorHandle::Resolve(Status* status) { + TF_Tensor* tensor = + TFE_TensorHandleResolve(handle_.get(), status->GetTFStatus()); + if (!status->ok()) { + return Tensor(nullptr); + } + return Tensor(tensor); +} + +inline TensorHandle TensorHandle::FromTensor(const Tensor& tensor, + const Runtime& runtime, + Status* status) { + TFE_TensorHandle* tensor_handle = TFE_NewTensorHandleFromTensor( + runtime.GetTFEContext(), tensor.GetTFTensor(), status->GetTFStatus()); + if (!status->ok()) { + return TensorHandle(nullptr); + } + return TensorHandle(tensor_handle); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_PUBLIC_TENSORHANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h new file mode 100644 index 00000000..1e649d5d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/base/tests/tensor_types_test_util.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_ + +#include + +#include "tensorflow/c/tf_datatype.h" + +namespace tensorflow { + +// Each of the following struct types have two members: a kDType that +// corresponds to a TF_Datatype enum value, and a typedef "type" +// of its corresponding C++ type. 
These types allow us to write Dtype-agnostic +// tests via GoogleTest's TypedTests: +// https://github.com/google/googletest/blob/e589a337170554c48bc658cc857cf15080c9eacc/googletest/docs/advanced.md#typed-tests +struct FloatType { + using type = float; + static constexpr TF_DataType kDType = TF_FLOAT; +}; + +struct DoubleType { + using type = double; + static constexpr TF_DataType kDType = TF_DOUBLE; +}; + +struct Int32Type { + using type = int32_t; + static constexpr TF_DataType kDType = TF_INT32; +}; + +struct UINT8Type { + using type = uint8_t; + static constexpr TF_DataType kDType = TF_UINT8; +}; + +struct INT8Type { + using type = int8_t; + static constexpr TF_DataType kDType = TF_INT8; +}; + +struct INT64Type { + using type = int64_t; + static constexpr TF_DataType kDType = TF_INT64; +}; + +struct UINT16Type { + using type = uint16_t; + static constexpr TF_DataType kDType = TF_UINT16; +}; + +struct UINT32Type { + using type = uint32_t; + static constexpr TF_DataType kDType = TF_UINT32; +}; + +struct UINT64Type { + using type = uint64_t; + static constexpr TF_DataType kDType = TF_UINT64; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_BASE_TESTS_TENSOR_TYPES_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/libexport/load.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/libexport/load.h new file mode 100644 index 00000000..6775f73b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/libexport/load.h @@ -0,0 +1,108 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" + +namespace tensorflow { +namespace libexport { + +// A low-level representation of a SavedModel. +// +// This class should only ever be a thin wrapper around disk (or other storage) +// access for a SavedModel. Higher level functionality should be layered on top +// by other functions and classes. +// +// In the future, this class can also provide a mechanism for automatic version +// migration. This will allow the calling code to always work against the most +// recent version of SavedModel. +class TFPackage { + public: + // Load a SavedModel, parsing the associated protobuf for later access. + static absl::StatusOr Load(const std::string& path); + + // Reads and returns a checkpoint key associated with a variable. 
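// --- Annotation (not part of the vendored headers): illustrative sketch ---
// The dtype structs above pair a TF_DataType constant with its C++ type so a
// single GoogleTest typed test can cover many dtypes. A minimal sketch of how
// such a test list might be wired up; the suite name and test body are
// illustrative assumptions.
#include <gtest/gtest.h>
#include "tensorflow/cc/experimental/base/tests/tensor_types_test_util.h"

namespace {

template <typename DtypeStruct>
class DtypePairingTest : public ::testing::Test {};

using DtypeList =
    ::testing::Types<tensorflow::FloatType, tensorflow::DoubleType,
                     tensorflow::Int32Type, tensorflow::INT64Type>;
TYPED_TEST_SUITE(DtypePairingTest, DtypeList);

TYPED_TEST(DtypePairingTest, TypeSizeMatchesDtypeSize) {
  // TF_DataTypeSize comes from tensorflow/c/tf_datatype.h.
  EXPECT_EQ(sizeof(typename TypeParam::type),
            TF_DataTypeSize(TypeParam::kDType));
}

}  // namespace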
+ // + // The variable is identified by the index in the object graph node list. + // + // RestoreV2 is the operation that will ultimately be responsible for reading + // and restoring the variable(s)' values. Variable values are indexed in the + // checkpoint files by "checkpoint keys". These keys along with dtype and + // shape / slice information allow RestoreV2 to look up a variable's value in + // the SavedModel and restore it into a tensor. + absl::StatusOr GetVariableCheckpointKey(int index); + + // Retrieves the object graph from the SavedModel. + // + // For now, we're returning the object graph directly (i.e. the parsed proto) + // rather than adding abstraction on top. We may later find we would like an + // intermediate abstraction layer to make traversal easier, but for now the + // extra complexity doesn't seem justified. Regardless of what we choose, + // that logic should live outside this class; this class should continue to + // have the clearly-defined, singular responsibility of reading and parsing + // the low-level, serialized format. + const SavedObjectGraph& GetObjectGraph(); + + // Retrieves a specific GraphDef node by name. + // + // GraphDef nodes are stored as a repeating list of nodes. At module load + // time, a module may have constants that need to be restored. To restore + // these constants, they are looked up in the GraphDef's nodes by their name. + // Since we may need to load many constants, we create a hash map of these + // names to their corresponding nodes at load time in order to look them up + // in constant time. + absl::StatusOr GetGraphDefNode(std::string name); + + // Returns a list of function defs in the SavedModel. + const protobuf::RepeatedPtrField& GetFunctionDefs(); + + // Returns a BundleReader for reading variable values. + // + // This TFPackage retains ownership of the underlying reader. + tensorflow::BundleReader* GetVariableReader() { + return variable_reader_.get(); + } + + // Returns whether or not we found a valid checkpoint when loading the + // package. + bool HasCheckpoint() { return has_checkpoint_; } + + // Returns the path to the variables file. + const std::string GetVariablesFilepath() const { return variables_filepath_; } + + private: + SavedModel saved_model_proto_; + TrackableObjectGraph trackable_object_graph_; + std::unique_ptr variable_reader_; + std::string variables_filepath_; + bool has_checkpoint_; + absl::flat_hash_map graph_def_nodes_by_name_; +}; + +} // namespace libexport +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_LOAD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/experimental/libexport/save.h b/third_party/tflite-hdrs/tensorflow/cc/experimental/libexport/save.h new file mode 100644 index 00000000..382f4645 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/experimental/libexport/save.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
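// --- Annotation (not part of the vendored headers): illustrative sketch ---
// A sketch of loading a SavedModel with TFPackage (declared above) and walking
// a few of its accessors; the export directory is a placeholder and error
// handling is reduced to early returns.
#include <string>
#include "tensorflow/cc/experimental/libexport/load.h"

void InspectSavedModel(const std::string& export_dir) {
  using tensorflow::libexport::TFPackage;
  absl::StatusOr<TFPackage> package = TFPackage::Load(export_dir);
  if (!package.ok()) return;

  const tensorflow::SavedObjectGraph& graph = package->GetObjectGraph();
  (void)graph;

  if (package->HasCheckpoint()) {
    // The BundleReader is owned by the TFPackage; do not delete it.
    tensorflow::BundleReader* reader = package->GetVariableReader();
    (void)reader;
  }
}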
+==============================================================================*/ +#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_SAVE_H_ +#define TENSORFLOW_CC_EXPERIMENTAL_LIBEXPORT_SAVE_H_ + +#include + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace libexport { + +// Writes a saved model to disk. +// +// Writes a saved model to the given `export_dir`. +TF_EXPORT Status Save(const std::string& export_dir); + +} // namespace libexport +} // namespace tensorflow + +#endif // TENSORFLOW_CC_EXPERIMENTAL_EXPORT_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/cc_op_gen.h b/third_party/tflite-hdrs/tensorflow/cc/framework/cc_op_gen.h new file mode 100644 index 00000000..7b348365 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/cc_op_gen.h @@ -0,0 +1,34 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ + +#include + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace cc_op { +/// Result is written to files dot_h and dot_cc. +void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& dot_h_fname, const string& dot_cc_fname); + +} // namespace cc_op +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/cc_op_gen_util.h b/third_party/tflite-hdrs/tensorflow/cc/framework/cc_op_gen_util.h new file mode 100644 index 00000000..4e3272c7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/cc_op_gen_util.h @@ -0,0 +1,148 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_UTIL_H_ +#define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_UTIL_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace cc_op { + +absl::StatusOr LoadOpsAndApiDefs( + OpList& ops, bool include_internal, + const std::vector& api_def_dirs); + +// Converts: +// bazel-out/.../(bin|genfiles)/(external/YYY/)?XX +// to: XX. +string GetPath(absl::string_view dot_h_fname); + +// Converts: some/path/to/file.xx +// to: file +// (note that suffix is removed) +string GetFilename(absl::string_view path); + +// Converts: +// cc/ops/gen_foo_ops.h +// to: +// CC_OPS_GEN_FOO_OPS_H_ +string ToGuard(absl::string_view path); + +// Converts: some_name_xyz +// to: Some Name Xyz +string ToTitle(absl::string_view name); + +// Change: Into: +// ABC /// ABC +// /// +// DEF /// DEF +string MakeComment(absl::string_view text, absl::string_view indent); + +string PrintString(absl::string_view str); + +string PrintTensorShape(const TensorShapeProto& shape_proto); + +template +string PrintArray(int64_t num_elts, const T* array) { + string ret; + for (int64_t i = 0; i < num_elts; ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, array[i]); + } + return ret; +} + +string PrintTensor(const TensorProto& tensor_proto); + +string PrintTensorProto(const TensorProto& proto); + +string PrintAttrValue(absl::string_view, const AttrValue& attr_value); + +bool IsEmptyList(const AttrValue::ListValue& list); + +string ToCamelCase(absl::string_view str); + +string SeparateNamespaces(absl::string_view str); + +// Returns a pair. The string is the C++ type name to be used for +// attr_type when defining an object of that type. The bool is a flag to +// indicate whether to treat the type as const when accepting the C++ type as an +// argument to a function. +std::pair AttrTypeName(absl::string_view attr_type); + +absl::string_view ListElementTypeName(absl::string_view attr_type); + +bool IsCPPKeyword(absl::string_view name); + +string AvoidCPPKeywords(absl::string_view name); + +void InferArgAttributes(const OpDef::ArgDef& arg, + std::unordered_map* inferred_attrs); + +void InferOpAttributes( + const OpDef& op_def, + std::unordered_map* inferred_input_attrs); + +bool ArgIsList(const OpDef::ArgDef& arg); + +bool HasOptionalAttrs( + const ApiDef& api_def, + const std::unordered_map& inferred_input_attrs); + +struct OpInfo { + // graph_op_def: The OpDef used by the runtime, has the names that + // must be used when calling NodeBuilder. + // interface_op_def: The OpDef used in the interface in the generated + // code, with possibly overridden names and defaults. 
+ OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, + const std::vector& aliases); + OpInfo(const OpDef& graph_op_def, const ApiDef& api_def); + string GetOpAttrStruct() const; + string GetConstructorDecl(absl::string_view op_name_prefix, + bool include_attr) const; + + string op_name; + std::vector arg_types; + std::vector arg_names; + std::vector output_types; + std::vector output_names; + std::vector is_list_output; + bool has_optional_attrs; + string comment; + + const OpDef& graph_op_def; + const ApiDef& api_def; + const std::vector& aliases; + // Map from type attribute to corresponding original argument name. + std::unordered_map inferred_input_attrs; +}; + +} // namespace cc_op +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h b/third_party/tflite-hdrs/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h new file mode 100644 index 00000000..c11c9635 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/fuzzing/cc_op_fuzz_gen.h @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_FUZZING_CC_OP_FUZZ_GEN_H_ +#define TENSORFLOW_CC_FRAMEWORK_FUZZING_CC_OP_FUZZ_GEN_H_ + +#include "tensorflow/cc/framework/cc_op_gen_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace cc_op { + +// String with single fuzzer file content. +string WriteSingleFuzzer(const OpInfo& op_info, bool is_fuzzable); + +// Do we have all we need to create a fuzzer +bool OpFuzzingIsOk(const OpInfo& op_info); + +} // namespace cc_op +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_FUZZING_CC_OP_FUZZ_GEN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/grad_op_registry.h b/third_party/tflite-hdrs/tensorflow/cc/framework/grad_op_registry.h new file mode 100644 index 00000000..b0847844 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/grad_op_registry.h @@ -0,0 +1,77 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ + +#include +#include +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { +namespace ops { + +/// GradFunc is the signature for all gradient functions in GradOpRegistry. +/// Implementations should add operations to compute the gradient outputs of +/// 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. +typedef absl::Status (*GradFunc)(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs); + +/// GradOpRegistry maintains a static registry of gradient functions. +/// Gradient functions are indexed in the registry by the forward op name (i.e. +/// "MatMul" -> MatMulGrad func). +class GradOpRegistry { + public: + /// Registers 'func' as the gradient function for 'op'. + /// Returns true if registration was successful, check fails otherwise. + bool Register(const string& op, GradFunc func); + + /// Sets 'func' to the gradient function for 'op' and returns Status OK if + /// the gradient function for 'op' exists in the registry. + /// Note that 'func' can be null for ops that have registered no-gradient with + /// the registry. + /// Returns error status otherwise. + absl::Status Lookup(const string& op, GradFunc* func) const; + + /// Returns a pointer to the global gradient function registry. + static GradOpRegistry* Global(); + + private: + std::unordered_map registry_; +}; + +} // namespace ops + +// Macros used to define gradient functions for ops. +#define REGISTER_GRADIENT_OP(name, fn) \ + REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, fn) + +#define REGISTER_NO_GRADIENT_OP(name) \ + REGISTER_GRADIENT_OP_UNIQ_HELPER(__COUNTER__, name, nullptr) + +#define REGISTER_GRADIENT_OP_UNIQ_HELPER(ctr, name, fn) \ + REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) + +#define REGISTER_GRADIENT_OP_UNIQ(ctr, name, fn) \ + static bool unused_ret_val_##ctr = \ + ::tensorflow::ops::GradOpRegistry::Global()->Register(name, fn) + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/gradient_checker.h b/third_party/tflite-hdrs/tensorflow/cc/framework/gradient_checker.h new file mode 100644 index 00000000..20b6545f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/gradient_checker.h @@ -0,0 +1,65 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ + +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +/// Returns in 'max_error' the maximum element-wise error for dy/dx between the +/// computed and numeric Jacobian matrices where 'xs' and 'ys' are tensors. +/// X_T and Y_T are the c++ types for the x and y tensors, and JAC_T is a +/// real-valued type to store the Jacobian derivatives dy/dx. +/// This function adds operations to the graph associated with 'scope'. +/// +/// Examples: +/// if y = Square(x), where x (and so y) are DT_FLOAT, +/// should be +/// +/// if y = Square(x), where x (and so y) are DT_DOUBLE, +/// should be +/// +/// if y = Square(x), where x (and so y) are DT_COMPLEX64, +/// should be +/// Note that JAC_T is always real-valued, and should be an appropriate +/// precision to host the partial derivatives for dy/dx +/// +/// if y = ComplexAbs(x) where x is DT_COMPLEX64 (so y is DT_FLOAT) +/// should be +/// +/// if y = Complex(x, x) where x is DT_FLOAT (so y is DT_COMPLEX64) +/// should be +template +absl::Status ComputeGradientError(const Scope& scope, const OutputList& xs, + const std::vector& x_shapes, + const OutputList& ys, + const std::vector& y_shapes, + JAC_T* max_error); + +/// Overload of ComputeGradientError which takes an initial value for 'x'. +template +absl::Status ComputeGradientError(const Scope& scope, const Output& x, + const Tensor& x_init_value, const Output& y, + const TensorShape& y_shape, JAC_T* max_error); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/gradients.h b/third_party/tflite-hdrs/tensorflow/cc/framework/gradients.h new file mode 100644 index 00000000..c79269fd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/gradients.h @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ + +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { + +/// NOTE: This API is a work in progress and will likely be changing frequently. +/// +/// Given initial gradients 'grad_inputs' (which represent the symbolic partial +/// derivatives of some loss function 'L' w.r.t 'outputs'), adds gradient nodes +/// to the graph associated with 'scope', which compute (and return in +/// 'grad_outputs') the symbolic partial derivatives of 'L' w.r.t 'inputs'. 
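// --- Annotation (not part of the vendored headers): illustrative sketch ---
// A sketch of the GradFunc / REGISTER_GRADIENT_OP machinery declared in
// grad_op_registry.h above, using a made-up op name "MySquare" whose gradient
// is dy * 2 * x; the op and its gradient exist only to illustrate the
// registration pattern.
#include <vector>
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

absl::Status MySquareGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  // dL/dx = dL/dy * 2 * x
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dx = Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0)));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("MySquare", MySquareGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow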
+absl::Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + const std::vector& grad_inputs, + std::vector* grad_outputs); + +// Same as above, but uses 'OnesLike' for all shapes in +// 'outputs' as grad_inputs. +absl::Status AddSymbolicGradients(const Scope& scope, + const std::vector& outputs, + const std::vector& inputs, + std::vector* grad_outputs); + +/// Returns a sentinel Output that represents 'no gradient' (i.e. no gradient +/// flows along some graph edge during backpropagation). +/// Can be returned in 'grad_outputs' by an invocation of 'AddSymbolicGradients' +/// (note that gradient flow through an Output can be stopped through the use of +/// the StopGradient node). +Output NoGradient(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/ops.h b/third_party/tflite-hdrs/tensorflow/cc/framework/ops.h new file mode 100644 index 00000000..e856e311 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/ops.h @@ -0,0 +1,304 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#define TENSORFLOW_CC_FRAMEWORK_OPS_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +/// @defgroup core Core Tensorflow API + +class Output; + +/// @addtogroup core +/// @{ + +/// Represents a node in the computation graph. +class Operation { + public: + Operation() : node_(nullptr) {} + explicit Operation(Node* n); + + int32 num_inputs() const { return node_->num_inputs(); } + DataType input_type(int32_t o) const { return node_->input_type(o); } + Output input(int32_t i) const; + + int32 num_outputs() const { return node_->num_outputs(); } + DataType output_type(int32_t o) const { return node_->output_type(o); } + Output output(int32_t i) const; + + Node* node() const { return node_; } + + uint64 hash(int32_t index) const; + + bool operator==(const Operation& other) const { return node_ == other.node_; } + + private: + typedef std::vector> Inputs; + static Inputs GetInputs(Node* node); + + Inputs inputs_; + Node* node_; +}; + +/// Represents a tensor value produced by an Operation. 
+class Output { + public: + Output() = default; + explicit Output(Node* n) : op_(n) {} + Output(Node* n, int32_t index) : op_(n), index_(index) {} + Output(const Operation& op, int32_t index) : op_(op), index_(index) {} + + Operation op() const { return op_; } + Node* node() const { return op().node(); } + int32 index() const { return index_; } + DataType type() const { return op_.output_type(index_); } + std::string name() const { + return strings::StrCat(node()->name(), ":", index()); + } + bool operator==(const Output& other) const { + return op_ == other.op_ && index_ == other.index_; + } + + uint64 hash() const { return op_.hash(index_); } + + private: + Operation op_ = Operation(nullptr); + int32 index_ = 0; +}; + +/// Hash class that can be used for e.g. storing Outputs in an unordered_map +struct OutputHash { + std::size_t operator()(const Output& output) const { + return Hash64Combine(std::hash()(output.node()), + std::hash()(output.index())); + } +}; + +/// Represents a tensor value that can be used as an operand to an Operation. +class Input { + public: + /// Initializer enables constructing an Input object from various kinds of C++ + /// constants such as simple primitive constants and nested initializer lists + /// representing a multi-dimensional array. Initializer constructors are all + /// templates, so the aforementioned kinds of C++ constants can be used to + /// construct an Initializer. Initializer stores the value it got constructed + /// with in a Tensor object. + struct Initializer { + /// Construct from a scalar value of an arithmetic type or a type that can + /// be converted to a string (eg. a string literal). + template ::value || + std::is_convertible::value>::type> + Initializer(const T& v) { // NOLINT(runtime/explicit) + typedef typename RealType::type RealT; + Tensor t(DataTypeToEnum::v(), TensorShape()); + t.flat()(0) = RealT(v); + tensor = t; + } + + Initializer(const Tensor& t) : tensor(t) {} // NOLINT(runtime/explicit) + + /// Construct from a scalar value and an explicit shape + template ::value || + std::is_convertible::value>::type> + Initializer(const T& v, const TensorShape& shape) { + typedef typename RealType::type RealT; + Tensor t(DataTypeToEnum::v(), shape); + for (int64_t i = 0; i < t.NumElements(); ++i) { + t.flat()(i) = RealT(v); + } + tensor = t; + } + + /// Construct from a initializer list of scalars (a one-dimensional tensor). + template ::value || + std::is_convertible::value>::type> + Initializer( + const std::initializer_list& v) { // NOLINT(runtime/explicit) + typedef typename RealType::type RealT; + Tensor t(DataTypeToEnum::v(), + TensorShape{static_cast(v.size())}); + std::copy_n(v.begin(), v.size(), t.flat().data()); + tensor = t; + } + + /// Construct from a initializer list of scalars and an explicit shape. + template ::value || + std::is_convertible::value>::type> + Initializer(const std::initializer_list& v, const TensorShape& shape) { + typedef typename RealType::type RealT; + Tensor t(DataTypeToEnum::v(), shape); + if (t.NumElements() != static_cast(v.size())) { + status = absl::InvalidArgumentError(absl::StrCat( + "Cannot construct a tensor with ", t.NumElements(), + " from an initializer list with ", v.size(), " elements")); + return; + } + std::copy_n(v.begin(), v.size(), t.flat().data()); + tensor = t; + } + + /// Construct a multi-dimensional tensor from a nested initializer + /// list. 
Note that C++ syntax allows nesting of arbitrarily typed + /// initializer lists, so such invalid initializers cannot be disallowed at + /// compile time. This function performs checks to make sure that the nested + /// initializer list is indeed a valid multi-dimensional tensor. + Initializer(const std::initializer_list& v); + + // START_SKIP_DOXYGEN + template ::value> + struct RealType { + typedef tstring type; + }; + + template + struct RealType { + typedef T type; + }; + // END_SKIP_DOXYGEN + + TensorProto AsTensorProto() { + TensorProto tensor_proto; + if (tensor.NumElements() > 1) { + tensor.AsProtoTensorContent(&tensor_proto); + } else { + tensor.AsProtoField(&tensor_proto); + } + return tensor_proto; + } + + absl::Status status; + Tensor tensor; + }; + + /// All of Input's constructors are implicit. Input can be implicitly + /// constructed from the following objects : + /// * Output: This is so that the output of an Operation can be directly used + /// as the input to a op wrapper, which takes Inputs. + /// * A scalar, or a multi-dimensional tensor specified as a recursive + /// initializer list. This enables directly passing constants as + /// inputs to op wrappers. + /// * A Tensor object. + Input(const Output& o) : output_(o) {} // NOLINT(runtime/explicit) + + template ::value || + std::is_convertible::value>::type> + Input(const T& v) // NOLINT(runtime/explicit) + : Input(Initializer(v)) {} + + Input(const Initializer& init) // NOLINT(runtime/explicit) + : status_(init.status), + tensor_(init.tensor) {} + + Input(const Tensor& t) // NOLINT(runtime/explicit) + : status_(absl::OkStatus()), tensor_(t) {} + + Input(const std::initializer_list& + init) { // NOLINT(runtime/explicit) + for (const auto& i : init) { + if (!i.status.ok()) { + status_ = i.status; + return; + } + } + tensor_ = Initializer(init).tensor; + } + + /// Constructor specifying a node name, index and datatype. This should only + /// be used for specifying a backward edge, needed by control flow. + Input(const std::string& name, int32_t i, DataType dt) + : node_name_(name), index_(i), data_type_(dt) {} + + Node* node() const { return output_.node(); } + std::string node_name() const { return node_name_; } + int32 index() const { return node_name_.empty() ? output_.index() : index_; } + DataType data_type() const { return data_type_; } + absl::Status status() const { return status_; } + const Tensor& tensor() const { return tensor_; } + + private: + absl::Status status_; + Output output_ = Output(Operation(nullptr), 0); + Tensor tensor_; + const std::string node_name_ = ""; + int32 index_ = 0; + DataType data_type_ = DT_INVALID; +}; + +/// A type for representing the output of ops that produce more than one output, +/// or a list of tensors. +typedef std::vector OutputList; + +/// A type for representing the input to ops that require a list of tensors. +class InputList { + public: + /// Implicitly convert a list of outputs to a list of inputs. This is useful + /// to write code such as ops::Concat(ops::Split(x, 4)). 
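// --- Annotation (not part of the vendored headers): illustrative sketch ---
// Examples of the implicit conversions documented above: scalars, nested
// initializer lists, and op outputs all convert to Input, and an OutputList
// converts to an InputList. The specific ops and shapes are illustrative.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

void InputConversionExamples() {
  using namespace tensorflow;
  using namespace tensorflow::ops;
  Scope root = Scope::NewRootScope();

  // Scalar and nested-initializer-list Inputs (both become Const tensors).
  auto c0 = Const(root, 42);
  auto c1 = Const(root, { {1, 2}, {3, 4} });

  // An Output is also an Input, so op results feed ops directly.
  auto sum = Add(root, c0, c0);

  // A list of Outputs (OutputList) converts implicitly to an InputList.
  auto parts = Split(root, /*split_dim=*/0, c1, /*num_split=*/2);
  auto merged = Concat(root, parts.output, /*axis=*/0);
  (void)sum;
  (void)merged;
}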
+  InputList(const OutputList& out) {  // NOLINT(runtime/explicit)
+    for (auto const& x : out) {
+      inputs_.push_back(x);
+    }
+  }
+
+  InputList(
+      const std::initializer_list<Input>& inputs)  // NOLINT(runtime/explicit)
+      : inputs_(inputs.begin(), inputs.end()) {}
+
+  InputList(const absl::Span<const Input>& inputs)  // NOLINT(runtime/explicit)
+      : inputs_(inputs.begin(), inputs.end()) {}
+
+  InputList(
+      const std::initializer_list<Output>& out) {  // NOLINT(runtime/explicit)
+    for (auto const& x : out) {
+      inputs_.push_back(x);
+    }
+  }
+
+  typename std::vector<Input>::iterator begin() { return inputs_.begin(); }
+  typename std::vector<Input>::iterator end() { return inputs_.end(); }
+  typename std::vector<Input>::const_iterator begin() const {
+    return inputs_.begin();
+  }
+  typename std::vector<Input>::const_iterator end() const {
+    return inputs_.end();
+  }
+
+ private:
+  std::vector<Input> inputs_;
+};
+
+/// @}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_FRAMEWORK_OPS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/scope.h b/third_party/tflite-hdrs/tensorflow/cc/framework/scope.h
new file mode 100644
index 00000000..9b8896e4
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/cc/framework/scope.h
@@ -0,0 +1,270 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
+#define TENSORFLOW_CC_FRAMEWORK_SCOPE_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/core/common_runtime/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+class Graph;
+class GraphDef;
+class NodeBuilder;
+struct CompositeOpScopes;
+
+/// @addtogroup core
+/// @{
+
+/// A `Scope` object represents a set of related TensorFlow ops that have the
+/// same properties such as a common name prefix.
+///
+/// A Scope object is a container for TensorFlow Op properties. Op constructors
+/// get a Scope object as a mandatory first argument and the constructed op
+/// acquires the properties in the object.
+///
+/// A simple example:
+///
+///     using namespace ops;
+///     Scope root = Scope::NewRootScope();
+///     auto c1 = Const(root, { {1, 1} });
+///     auto m = MatMul(root, c1, { {41}, {1} });
+///     GraphDef gdef;
+///     Status s = root.ToGraphDef(&gdef);
+///     if (!s.ok()) { ... }
+///
+/// Scope hierarchy:
+///
+/// The Scope class provides various With<> functions that create a new scope.
+/// The new scope typically has one property changed while other properties are
+/// inherited from the parent scope.
+/// NewSubScope(name) method appends `name` to the prefix of names for ops
+/// created within the scope, and WithOpName() changes the suffix which
+/// otherwise defaults to the type of the op.
+/// +/// Name examples: +/// +/// Scope root = Scope::NewRootScope(); +/// Scope linear = root.NewSubScope("linear"); +/// // W will be named "linear/W" +/// auto W = Variable(linear.WithOpName("W"), +/// {2, 2}, DT_FLOAT); +/// // b will be named "linear/b_3" +/// int idx = 3; +/// auto b = Variable(linear.WithOpName("b_", idx), +/// {2}, DT_FLOAT); +/// auto x = Const(linear, {...}); // name: "linear/Const" +/// auto m = MatMul(linear, x, W); // name: "linear/MatMul" +/// auto r = BiasAdd(linear, m, b); // name: "linear/BiasAdd" +/// +/// Scope lifetime: +/// +/// A new scope is created by calling Scope::NewRootScope. This creates some +/// resources that are shared by all the child scopes that inherit from this +/// scope, directly or transitively. For instance, a new scope creates a new +/// Graph object to which operations are added when the new scope or its +/// children are used by an Op constructor. The new scope also has a Status +/// object which will be used to indicate errors by Op-constructor functions +/// called on any child scope. The Op-constructor functions have to check the +/// scope's status by calling the ok() method before proceeding to construct the +/// op. +/// +/// Thread safety: +/// +/// A `Scope` object is NOT thread-safe. Threads cannot concurrently call +/// op-constructor functions on the same `Scope` object. +class Scope { + public: + Scope(const Scope& other); + ~Scope(); + Scope& operator=(const Scope& other); + + // The following functions are for users making graphs. They return brand new + // scopes, or scopes derived from an existing scope object. + + /// Return a new scope. + /// This creates a new graph and all operations constructed in this graph + /// should use the returned object as the "root" scope. + static Scope NewRootScope(); + + /// Return a new scope. Ops created with this scope will have + /// `name/child_scope_name` as the prefix. The actual name will be unique + /// in the current scope. All other properties are inherited from the current + /// scope. If `child_scope_name` is empty, the `/` is elided. + Scope NewSubScope(const string& child_scope_name) const; + + /// Return a new scope. All ops created within the returned scope will have + /// names of the form `name/StrCat(fragments...)[_suffix]` + template + Scope WithOpName(Ty... fragments) const { + return WithOpNameImpl(absl::StrCat(fragments...)); + } + + /// Return a new scope. All ops created within the returned scope will have as + /// control dependencies the union of operations in the control_deps vector + /// and the control dependencies of the current scope. + Scope WithControlDependencies(absl::Span control_deps) const; + /// Same as above, but convenient to add control dependency on the operation + /// producing the control_dep output. + Scope WithControlDependencies(const Output& control_dep) const; + + /// Return a new scope. All ops created within the returned scope will have no + /// control dependencies on other operations. + Scope WithNoControlDependencies() const; + + /// Return a new scope. All ops created within the returned scope will have + /// the device field set to 'device'. + Scope WithDevice(const string& device) const; + + /// Returns a new scope. All ops created within the returned scope will have + /// their assigned device set to `assigned_device`. + Scope WithAssignedDevice(const string& assigned_device) const; + + /// Returns a new scope. All ops created within the returned scope will have + /// their _XlaCluster attribute set to `xla_cluster`. 
+ Scope WithXlaCluster(const string& xla_cluster) const; + + /// Return a new scope. All ops created within the returned scope will be + /// co-located on the device where op is placed. + /// NOTE: This function is intended to be use internal libraries only for + /// controlling placement of ops on to devices. Public use is not encouraged + /// because the implementation of device placement is subject to change. + Scope ColocateWith(const Operation& op) const; + /// Convenience function for above. + Scope ColocateWith(const Output& out) const { return ColocateWith(out.op()); } + /// Clear all colocation constraints. + Scope ClearColocation() const; + + /// Return a new scope. The op-constructor functions taking the returned scope + /// as the scope argument will exit as soon as an error is detected, instead + /// of setting the status on the scope. + Scope ExitOnError() const; + + /// Return a new scope. All ops created with the new scope will have + /// kernel_label as the value for their '_kernel' attribute; + Scope WithKernelLabel(const string& kernel_label) const; + + // The following functions are for scope object consumers. + + /// Return a unique name, using default_name if an op name has not been + /// specified. + string GetUniqueNameForOp(const string& default_name) const; + + /// Update the status on this scope. + /// Note: The status object is shared between all children of this scope. + /// If the resulting status is not OkStatus() and exit_on_error_ is set on + /// this scope, this function exits by calling LOG(FATAL). + void UpdateStatus(const absl::Status& s) const; + + // START_SKIP_DOXYGEN + + /// Update the builder with properties accumulated in this scope. Does not set + /// status(). + // TODO(skyewm): NodeBuilder is not part of public API + void UpdateBuilder(NodeBuilder* builder) const; + // END_SKIP_DOXYGEN + + CompositeOpScopes GetCompositeOpScopes(const string& composite_op_name) const; + + bool ok() const; + + // TODO(skyewm): Graph is not part of public API + Graph* graph() const; + + // TODO(skyewm): Graph is not part of public API + std::shared_ptr graph_as_shared_ptr() const; + + absl::Status status() const; + + /// If status() is ok, convert the Graph object stored in this scope + /// to a GraphDef proto and return an ok Status. Otherwise, return the error + /// status as is without performing GraphDef conversion. If + /// `include_debug_info` is true, populate the `debug_info` field of the + /// GraphDef from stack traces in this Graph. + absl::Status ToGraphDef(GraphDef* gdef, + bool include_debug_info = false) const; + + // START_SKIP_DOXYGEN + + /// If status() is OkStatus(), construct a Graph object using `opts` as the + /// GraphConstructorOptions, and return Status::OK if graph construction was + /// successful. Otherwise, return the error status. + // TODO(josh11b, keveman): Make this faster; right now it converts + // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds + // edges from the source and to the sink node, resolves back edges + // by name), and makes sure the resulting graph is valid. + absl::Status ToGraph( + Graph* g, GraphConstructorOptions opts = GraphConstructorOptions{}) const; + + // Calls AddNode() using this scope's ShapeRefiner. This exists in the public + // API to prevent custom op wrappers from needing access to shape_refiner.h or + // scope_internal.h. 
+ // TODO(skyewm): remove this from public API + absl::Status DoShapeInference(Node* node) const; + + // Creates a new root scope that causes all DoShapeInference() calls to return + // OkStatus() (on the returned scope and any subscopes). Used for testing. + // TODO(skyewm): fix tests that still require this and eventually remove, or + // at least remove from public API + static Scope DisabledShapeInferenceScope(); + // END_SKIP_DOXYGEN + + const std::vector& control_deps() const; + + // START_SKIP_DOXYGEN + class Impl; + Impl* impl() { return impl_.get(); } + const Impl* impl() const { return impl_.get(); } + // END_SKIP_DOXYGEN + + private: + Scope WithOpNameImpl(const string& op_name) const; + + friend class InternalScope; + std::unique_ptr impl_; + explicit Scope(Impl*); +}; + +/// A helper struct to hold the scopes that would be used by a function +/// constructing a composite op. +struct CompositeOpScopes { + /// Scope to be used for creating the local ops (primitive or other composite + /// ops). + Scope child; + /// Scope to be used for creating the last op. + Scope last; +}; + +// Creates a node of the given operation, with the given inputs, and assigns the +// result to output. This does not support the ability to add additional +// attributes. +absl::Status CreateOutputWithScope(string op_name, + absl::Span inputs, + const Scope& scope, Output* output); +/// @} + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/scope_internal.h b/third_party/tflite-hdrs/tensorflow/cc/framework/scope_internal.h new file mode 100644 index 00000000..0cf6af68 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/scope_internal.h @@ -0,0 +1,134 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#define TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { + +class ShapeRefiner; + +// NewInternalScope returns a new scope which doesn't take ownership of +// graph, status, name_map, and refiner. +// This is intended to enable the C API (which are used by other language +// bindings) to create a Scope and access C++ functionality (i.e. gradients). +// +// Shape inference is disabled if `refiner` is nullptr. +Scope NewInternalScope(Graph* graph, absl::Status* status, + ShapeRefiner* refiner); + +class Scope::Impl { + public: + // A NameMap is used to keep track of suffixes for names used in a scope. A + // name that has not been used so far in a scope will get no suffix. Later + // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes + // can share the same NameMap. For instance, a new scope created using + // WithControlDependencies() would share the same NameMap with the parent. 
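+  //
+  // As an illustration of the suffixing rule above, successive calls with the
+  // same default name hand out names like the following (the op name "Add" is
+  // just an example):
+  //
+  //   scope.GetUniqueNameForOp("Add");  // -> "Add"
+  //   scope.GetUniqueNameForOp("Add");  // -> "Add_1"
+  //   scope.GetUniqueNameForOp("Add");  // -> "Add_2"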
+ typedef std::unordered_map NameMap; + + Impl(const std::shared_ptr& graph, + const std::shared_ptr& status, + const std::shared_ptr& name_map, + const std::shared_ptr& refiner); + + const string& name() const { return name_; } + const std::vector& control_deps() const { return control_deps_; } + + private: + friend class Scope; + + // Tag types to choose the constructor to dispatch. + struct Tags { + enum class ScopeName; + enum class OpName; + enum class ControlDeps; + enum class Device; + enum class SingleUseScope; + enum class ExitOnError; + enum class KernelLabel; + enum class Colocate; + enum class AssignedDevice; + enum class XlaCluster; + }; + + Impl(Graph* graph, absl::Status* status, NameMap* name_map, + ShapeRefiner* refiner, bool disable_shape_inference); + Impl(const Scope& other, Tags::ScopeName, const string& name, + bool copy_names); + Impl(const Scope& other, Tags::OpName, const string& name, + const string& op_name); + Impl(const Scope& other, Tags::ControlDeps, + std::vector control_deps, bool clear_control_deps); + Impl(const Scope& other, Tags::Device, const string& device); + Impl(const Scope& other, Tags::SingleUseScope, const string& op_name); + Impl(const Scope& other, Tags::ExitOnError); + Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label); + Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, + bool clear_colocations); + Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device); + Impl(const Scope& other, Tags::XlaCluster, const string& xla_cluster); + + std::unordered_set GetColocationConstraints( + const Operation& colocate_with_op) const; + + // Helper functions to get a unique names. + string GetUniqueName(const string& prefix, bool check_single_use) const; + string GetNameForOp(const string& default_name) const; + + bool single_use_scope() const { return scope_used_ != nullptr; } + + // The graph, status, and name maps are shared by all child scopes + // created from a single 'root' scope. A root scope is created by calling the + // Scope::NewRootScope function, which creates a new graph, a new status and + // the name maps. + std::shared_ptr graph_ = nullptr; + std::shared_ptr status_ = nullptr; + std::shared_ptr name_map_ = nullptr; + std::shared_ptr refiner_ = nullptr; + + // If scope_used_ is not nullptr, op_name_ should be empty and + // GetUniqueNameForOp can only be called once on this scope. More calls to + // GetUniqueNameForOp will cause an error status to be set on this scope. + std::shared_ptr scope_used_ = nullptr; + + const std::vector control_deps_; + + // The fully-qualified name of this scope (i.e. includes any parent scope + // names). + const string name_ = ""; + const string op_name_ = ""; + const bool exit_on_error_ = false; + const string kernel_label_ = ""; + const string device_ = ""; + const string assigned_device_ = ""; + const string xla_cluster_ = ""; + const std::unordered_set colocation_constraints_; + + // If true, Scope::DoShapeInference() always returns Status:OK(). + // TODO(skyewm): remove this when possible + const bool disable_shape_inference_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/testutil.h b/third_party/tflite-hdrs/tensorflow/cc/framework/testutil.h new file mode 100644 index 00000000..2464b491 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/testutil.h @@ -0,0 +1,49 @@ +/* Copyright 2016 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#define TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ + +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { +namespace test { + +/// Computes the outputs listed in 'tensors', returns the tensors in 'out'. +void GetTensors(const Scope& scope, OutputList tensors, + std::vector* out); + +// Computes the outputs listed in 'tensors', returns the tensors in 'out'. +// assign_vars are extra outputs that should be run +// e.g. to assign values to variables. +void GetTensors(const Scope& scope, const std::vector& assign_vars, + const OutputList& tensors, std::vector* out); + +/// Computes the output 'tensor', returning the resulting tensor in 'out'. +void GetTensor(const Scope& scope, Output tensor, Tensor* out); + +// Computes the output 'tensor', returning the resulting tensor in 'out'. +// assign_vars are extra outputs that should be run +// e.g. to assign values to variables. +void GetTensor(const Scope& scope, const std::vector& assign_vars, + Output tensor, Tensor* out); + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/framework/while_gradients.h b/third_party/tflite-hdrs/tensorflow/cc/framework/while_gradients.h new file mode 100644 index 00000000..1f31de15 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/framework/while_gradients.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#define TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ + +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/graph/while_context.h" + +// Utility functions for constructing while loop gradients + +namespace tensorflow { + +// Adds the gradient computation for the while loop associated with +// `while_ctx`. `grad_inputs` are the partial derivatives w.r.t. the loop +// outputs, i.e. the exit nodes. The partial derivatives w.r.t. the loop +// inputs, i.e. the input loop vars, are returned in `grad_outputs`. 
+// `grad_inputs` and `grad_outputs` are both in loop-variable order, as defined +// by the original inputs to BuildWhileLoop(). +// TODO(skyewm): maybe comment on NoGradient once it's supported +absl::Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, + const std::vector& grad_inputs, + std::vector* grad_outputs); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/gradients/grad_helper.h b/third_party/tflite-hdrs/tensorflow/cc/gradients/grad_helper.h new file mode 100644 index 00000000..2a50d648 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/gradients/grad_helper.h @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_GRADIENTS_GRAD_HELPER_H_ +#define TENSORFLOW_CC_GRADIENTS_GRAD_HELPER_H_ + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { + +// Helper function for reduction ops. +// +// input_shape: 1-D Tensor, the shape of the Tensor being reduced. +// axes: 1-D Tensor, the reduction axes. +// Note that the reduction indices are in the range +// -rank(input_shape), rank(input_shape) +// returns a 1-D Tensor, the output shape as if keep_dims were set to True. +Output ReducedShapeHelper(const Scope& scope, const Output& input_shape, + const Output& reduction_axes); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_GRADIENTS_GRAD_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/gradients/grad_testutil.h b/third_party/tflite-hdrs/tensorflow/cc/gradients/grad_testutil.h new file mode 100644 index 00000000..acde3075 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/gradients/grad_testutil.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#define TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ + +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" + +namespace tensorflow { +namespace test { + +/// Calls the gradient function registered for 'op', adding gradient operations +/// to the graph associated with 'scope'. Gradient outputs for each 'op' input +/// are returned in 'grad_outputs'. 
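+///
+/// A minimal usage sketch (assuming generated op wrappers such as ops::Const,
+/// ops::Square and ops::OnesLike, plus the TF_ASSERT_OK test macro; the names
+/// here are illustrative):
+///
+///     Scope scope = Scope::NewRootScope();
+///     auto x = ops::Const(scope, 2.0f);
+///     Output y = ops::Square(scope, x);
+///     std::vector<Output> grad_outputs;
+///     TF_ASSERT_OK(test::CallGradFunction(
+///         scope, y.op(), {ops::OnesLike(scope, y)}, &grad_outputs));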
+absl::Status CallGradFunction(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs); + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/ops/const_op.h b/third_party/tflite-hdrs/tensorflow/cc/ops/const_op.h new file mode 100644 index 00000000..9c888701 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/ops/const_op.h @@ -0,0 +1,87 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_ +#define TENSORFLOW_CC_OPS_CONST_OP_H_ + +#include + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace ops { + +/// @defgroup const_op Const Op +/// @{ + +Output Const(const Scope& scope, const Input::Initializer& val); + +Output ConstFromProto(const Scope& scope, const TensorProto& proto); + +NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); + +template +Output Const(const Scope& scope, const Input::Initializer& val) { + auto orig_const_output = Const(scope, val); + if (!scope.ok()) return Output(); + + typedef typename Input::Initializer::RealType::type DstT; + + if (val.tensor.dtype() == DataTypeToEnum::v()) { + return orig_const_output; + } + if (val.tensor.NumElements() == 0) { + Tensor t(DataTypeToEnum::v(), val.tensor.shape()); + return Const(scope, Input::Initializer(t)); + } + + // TODO(keveman): Refactor Cast op's kernel implementation such that the code + // can be directly called here instead of adding the Cast op to the graph. 
+ auto orig_const = AsNodeOut(scope, orig_const_output); + const auto cast_op_name = scope.GetUniqueNameForOp("Cast"); + + auto cast_builder = NodeBuilder(cast_op_name, "Cast") + .Input(orig_const) + .Attr("DstT", DataTypeToEnum::v()); + scope.UpdateBuilder(&cast_builder); + Node* ret; + scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret)); + if (!scope.ok()) return Output(); + scope.UpdateStatus(scope.DoShapeInference(ret)); + return Output(ret, 0); +} + +template +Output Const(const Scope& scope, const T& v, const TensorShape shape) { + return Const(scope, Input::Initializer(v, shape)); +} + +template +Output Const(const Scope& scope, const std::initializer_list& v, + const TensorShape shape) { + return Const(scope, Input::Initializer(v, shape)); +} + +std::vector AsNodeOutList(const Scope& scope, + const InputList& inp); + +/// }@ + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_CC_OPS_CONST_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/ops/standard_ops.h b/third_party/tflite-hdrs/tensorflow/cc/ops/standard_ops.h new file mode 100644 index 00000000..98f53010 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/ops/standard_ops.h @@ -0,0 +1,40 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_ + +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/candidate_sampling_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/data_flow_ops.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/io_ops.h" +#include "tensorflow/cc/ops/linalg_ops.h" +#include "tensorflow/cc/ops/logging_ops.h" +#include "tensorflow/cc/ops/lookup_ops.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/no_op.h" +#include "tensorflow/cc/ops/parsing_ops.h" +#include "tensorflow/cc/ops/random_ops.h" +#include "tensorflow/cc/ops/sparse_ops.h" +#include "tensorflow/cc/ops/state_ops.h" +#include "tensorflow/cc/ops/string_ops.h" +#include "tensorflow/cc/ops/training_ops.h" +#include "tensorflow/cc/ops/user_ops.h" + +#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/ops/while_loop.h b/third_party/tflite-hdrs/tensorflow/cc/ops/while_loop.h new file mode 100644 index 00000000..5a1a45da --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/ops/while_loop.h @@ -0,0 +1,80 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+#define TENSORFLOW_CC_OPS_WHILE_LOOP_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+namespace ops {
+
+// Function that takes cond graph inputs and returns cond graph boolean output.
+// 'output' need not be set if an error is returned.
+typedef std::function<absl::Status(const Scope&,
+                                   const std::vector<Output>& inputs,
+                                   Output* output)>
+    CondGraphBuilderFn;
+
+// Function that takes body graph inputs and returns body graph outputs.
+// 'outputs' need not be populated if an error is returned.
+typedef std::function<absl::Status(const Scope&,
+                                   const std::vector<Output>& inputs,
+                                   std::vector<Output>* outputs)>
+    BodyGraphBuilderFn;
+
+// Constructs a while loop.
+//
+// Arguments:
+// * scope: used to construct the while loop.
+// * inputs: the initial values of the loop variables. Must be non-empty.
+// * cond: a function that builds the condition graph of the loop. Takes the
+//   current loop variables as inputs and returns a scalar boolean Output
+//   indicating whether the loop should continue.
+// * body: a function that builds the body graph of the loop. Takes the current
+//   loop variables as inputs and returns the updated loop variables.
+// * frame_name: the frame name to use for this while loop. This should be a
+//   unique name. This will be used as a prefix for created operations.
+// * outputs: output param that returns final loop variable outputs in non-error
+//   case. Must be non-null and empty.
+// * create_while_ctx: if true, a WhileContext is created and populated for this
+//   loop. See core/graph/while_context.h for more details on
+//   WhileContexts. This is set to false for loops used as part of gradient
+//   computations, since they're part of the gradient for a loop in the
+//   forward-pass.
+//   TODO(skyewm): revisit this. Should we create WhileContexts for all loops,
+//   even if we don't need them?
+// * cond_output: if non-null, the output of the predicate is returned. This
+//   will always be a LoopCond node.
+//
+// Returns an error if the while loop could not be fully constructed.
+//
+// TODO(skyewm): clean up partially-constructed loop in error case
+// TODO(skyewm): create public interface to this method
+absl::Status BuildWhileLoop(const Scope& scope,
+                            const std::vector<Output>& inputs,
+                            const CondGraphBuilderFn& cond,
+                            const BodyGraphBuilderFn& body,
+                            const string& frame_name, OutputList* outputs,
+                            bool create_while_ctx = true,
+                            Output* cond_output = nullptr);
+
+}  // namespace ops
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_OPS_WHILE_LOOP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/bundle_v2.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/bundle_v2.h
new file mode 100644
index 00000000..ec85d14f
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/bundle_v2.h
@@ -0,0 +1,90 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for loading the persistent representation of a SavedModelV2. +// Please note that this is depended on by code that does not make use of +// the full runtime and its dependencies should be restricted. + +#ifndef TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ +#define TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" +#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" + +namespace tensorflow { + +/// Represents a version 2 SavedModel that is loaded from storage (but not yet +/// loaded into an executable in-memory representation). +class SavedModelV2Bundle { + public: + using RestoreObjectsCallback = std::function; + + /// Loads persistent representations for a SavedModelV2 from the specified + /// export directory. + static absl::Status Load(const std::string& export_dir, + SavedModelV2Bundle* bundle); + + /// MetaGraphDef from the loaded SavedModel. + MetaGraphDef& meta_graph_def() { return meta_graph_def_; } + + /// SavedObjectGraph from the MetaGraphDef. + const SavedObjectGraph& saved_object_graph() { + return meta_graph_def().object_graph_def(); + } + + /// TrackableObjectGraph loaded from the variable_reader() checkpoint. + TrackableObjectGraph& trackable_object_graph() { + return trackable_object_graph_; + } + + /// BundleReader for accessing the variables bundle. + BundleReader* variable_reader() { return variable_reader_.get(); } + + /// The GraphDebugInfo (or nullptr if none). + GraphDebugInfo* debug_info() { return debug_info_.get(); } + + /// Restores objects, invoking the callback with the node id in the + /// saved_object_graph() and the corresponding TrackableObject from the + /// trackable_object_graph(). The callback may use the variable_reader() but + /// must not modify the underlying saved_object_graph(). + absl::Status VisitObjectsToRestore(RestoreObjectsCallback callback); + + private: + absl::Status RecurseObjectsToRestore( + const SavedObject* saved_object, int saved_object_node_id, + const TrackableObjectGraph::TrackableObject* trackable_object, + std::string object_name, + absl::flat_hash_set* seen_trackable_node_ids, + RestoreObjectsCallback callback); + + MetaGraphDef meta_graph_def_; + TrackableObjectGraph trackable_object_graph_; + std::unique_ptr variable_reader_; + std::unique_ptr debug_info_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/constants.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/constants.h new file mode 100644 index 00000000..e8a267e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/constants.h @@ -0,0 +1,82 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ + +namespace tensorflow { + +// SavedModel assets directory. +inline constexpr char kSavedModelAssetsDirectory[] = "assets"; + +// SavedModel assets.extra directory. +inline constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra"; + +// SavedModel assets key for graph collection-def. +inline constexpr char kSavedModelAssetsKey[] = "saved_model_assets"; + +/// SavedModel legacy init op collection key. Used in v1 SavedModels. +inline constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; + +/// SavedModel main op collection key. Used in v1 SavedModels. +inline constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; + +// CollectionDef key for the SavedModel train op. +// Not exported while export_all_saved_models is experimental. +inline constexpr char kSavedModelTrainOpKey[] = "saved_model_train_op"; + +// Schema version for SavedModel. +inline constexpr int kSavedModelSchemaVersion = 1; + +// SavedModel proto filename prefix. +inline constexpr char kSavedModelFilenamePrefix[] = "saved_model"; +// SavedModel proto filename. +inline constexpr char kSavedModelFilenamePb[] = "saved_model.pb"; + +// SavedModel chunked proto filename. +inline constexpr char kSavedModelFilenameCpb[] = "saved_model.cpb"; + +// SavedModel text format proto filename. +inline constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; + +// Subdirectory where debugging related files are written. +inline constexpr char kSavedModelDebugDirectory[] = "debug"; + +// File name for GraphDebugInfo protocol buffer which corresponds to the +// SavedModel. +inline constexpr char kSavedModelDebugInfoFilenamePb[] = + "saved_model_debug_info.pb"; + +// Directory in which to save the SavedModel variables. +inline constexpr char kSavedModelVariablesDirectory[] = "variables"; + +// SavedModel variables filename. +inline constexpr char kSavedModelVariablesFilename[] = "variables"; + +// SavedModel SignatureDef keys for the initialization and train ops. Used in +// V2 SavedModels. +inline constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op"; +inline constexpr char kSavedModelTrainOpSignatureKey[] = + "__saved_model_train_op"; + +// Key in the TensorBundle for the object graph proto. +inline constexpr char kObjectGraphProtoKey[] = "_CHECKPOINTABLE_OBJECT_GRAPH"; + +// Filename for the FingerprintDef protocol buffer. 
+inline constexpr char kFingerprintFilenamePb[] = "fingerprint.pb"; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/concrete_function.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/concrete_function.h new file mode 100644 index 00000000..1adaf70b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/concrete_function.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// ConcreteFunction is an executable "function" loaded from a SavedModelAPI. +class ConcreteFunction final { + public: + // TODO(bmzhao): Adding ConcreteFunction::Run in subsequent CL, since + // it depends on tensorflow::cc::Tensor and tensorflow::cc::TensorHandle + + // Returns FunctionMetadata associated with this ConcreteFunction. + const FunctionMetadata* GetFunctionMetadata(); + + private: + friend class SavedModelAPI; + friend class ConcreteFunctionList; + + // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping + // when moving out of experimental. + static ConcreteFunction* wrap(TF_ConcreteFunction* p) { + return reinterpret_cast(p); + } + static TF_ConcreteFunction* unwrap(ConcreteFunction* p) { + return reinterpret_cast(p); + } +}; + +inline const FunctionMetadata* ConcreteFunction::GetFunctionMetadata() { + return FunctionMetadata::wrap(TF_ConcreteFunctionGetMetadata(unwrap(this))); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h new file mode 100644 index 00000000..88cb779e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/concrete_function_list.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// ConcreteFunctionList helps convert an opaque pointer to an array of +// ConcreteFunction pointers to a std::vector. +class ConcreteFunctionList { + public: + // Converts this object to a std::vector + std::vector ToVector(); + + private: + friend class SavedModelAPI; + // Wraps a TF_ConcreteFunctionList. Takes ownership of list. + explicit ConcreteFunctionList(TF_ConcreteFunctionList* list) : list_(list) {} + + struct TFConcreteFunctionListDeleter { + void operator()(TF_ConcreteFunctionList* p) const { + TF_DeleteConcreteFunctionList(p); + } + }; + std::unique_ptr list_; +}; + +inline std::vector ConcreteFunctionList::ToVector() { + int size = TF_ConcreteFunctionListSize(list_.get()); + std::vector result; + result.reserve(size); + for (int i = 0; i < size; ++i) { + result.push_back( + ConcreteFunction::wrap(TF_ConcreteFunctionListGet(list_.get(), i))); + } + return result; +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_CONCRETE_FUNCTION_LIST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/function_metadata.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/function_metadata.h new file mode 100644 index 00000000..11e1a860 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/function_metadata.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// FunctionMetadata stores additional function information, including +// optional signaturedef feeds/fetches (for TF1-based ConcreteFunctions), +// a valid function path (for TF2-based ConcreteFunctions), and +// the types + number of inputs and outputs. +class FunctionMetadata final { + // TODO(bmzhao): Add getters here as necessary. + private: + friend class ConcreteFunction; + static FunctionMetadata* wrap(TF_FunctionMetadata* p) { + return reinterpret_cast(p); + } + static TF_FunctionMetadata* unwrap(FunctionMetadata* p) { + return reinterpret_cast(p); + } +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_FUNCTION_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/saved_model_api.h new file mode 100644 index 00000000..9d30a4a2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -0,0 +1,155 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/cc/experimental/base/public/runtime.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" +#include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" +#include "tensorflow/cc/saved_model/experimental/public/signature_def_function.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SavedModelAPI offers a way to load Tensorflow Saved Models +// (https://www.tensorflow.org/guide/saved_model) and execute saved +// tf.functions or legacy SignatureDefs in a TF2-idiomatic fashion. +// See RFC 207 +// (https://github.com/tensorflow/community/blob/master/rfcs/20200218-tf-c-saved-model.md) +// TODO(bmzhao): Add an e2e example here, once ConcreteFunction::Run is added. +class SavedModelAPI { + public: + // Load a SavedModel from `dirname`. + // + // Params: + // saved_model_path - A directory filepath that the SavedModel is at. + // runtime - A runtime used to load SavedModelAPI. `runtime` must outlive the + // returned TF_SavedModel pointer. 
+ // tags - Optional set of tags. If tags = nullptr, we expect the SavedModel + // to contain a single Metagraph (as for those exported from TF2's + // `tf.saved_model.save`). If tags != nullptr, we load the metagraph + // matching the tags: + // https://github.com/tensorflow/tensorflow/blob/428cdeda09aef81e958eeb274b83d27ad635b57b/tensorflow/core/protobuf/meta_graph.proto#L50-L56 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. + static std::unique_ptr Load( + const std::string& saved_model_path, const Runtime& runtime, + Status* status, const std::unordered_set* tags = nullptr); + + // Retrieve a function from the TF2 SavedModel via function path. + // + // Params: + // function_path - A string containing the path from the root saved python + // object to a tf.function method. + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + ConcreteFunction* GetConcreteFunction(const std::string& function_path, + Status* status); + + // Retrieve a function from the TF SavedModel via a SignatureDef key. + // + // Params: + // signature_def_key - String key of SignatureDef map of a SavedModel: + // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 + // status - Set to OK on success and an appropriate error on failure. + // Returns: + // If status is not OK, returns nullptr. Otherwise, returns a + // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer + // is bound to SavedModelAPI it was loaded from. + SignatureDefFunction* GetSignatureDefFunction( + const std::string& function_path, Status* status); + + // SavedModelAPI is movable, but not copyable. 
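+  //
+  // Putting the above together, a minimal load-and-lookup sketch (error
+  // handling elided; `Runtime` and `RuntimeBuilder` come from
+  // tensorflow/cc/experimental/base/public/, and the path and function name
+  // below are illustrative):
+  //
+  //   Status status;
+  //   std::unique_ptr<Runtime> runtime = RuntimeBuilder().Build(&status);
+  //   auto model = SavedModelAPI::Load("/path/to/saved_model", *runtime,
+  //                                    &status);
+  //   ConcreteFunction* fn = model->GetConcreteFunction("my_func", &status);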
+ SavedModelAPI(SavedModelAPI&&) = default; + SavedModelAPI& operator=(SavedModelAPI&&) = default; + + private: + SavedModelAPI(const SavedModelAPI&) = delete; + SavedModelAPI& operator=(const SavedModelAPI&) = delete; + + explicit SavedModelAPI(TF_SavedModel* model) : saved_model_(model) {} + struct TFSavedModelDeleter { + void operator()(TF_SavedModel* p) const { TF_DeleteSavedModel(p); } + }; + std::unique_ptr saved_model_; +}; + +inline std::unique_ptr SavedModelAPI::Load( + const std::string& saved_model_path, const Runtime& runtime, Status* status, + const std::unordered_set* tags) { + TF_SavedModel* saved_model = nullptr; + + if (tags == nullptr) { + saved_model = + TF_LoadSavedModel(saved_model_path.c_str(), runtime.GetTFEContext(), + status->GetTFStatus()); + } else { + std::vector tags_vector; + tags_vector.reserve(tags->size()); + for (const std::string& tag : *tags) { + tags_vector.push_back(tag.c_str()); + } + saved_model = TF_LoadSavedModelWithTags( + saved_model_path.c_str(), runtime.GetTFEContext(), tags_vector.data(), + tags_vector.size(), status->GetTFStatus()); + } + + if (!status->ok()) { + return nullptr; + } + + // We can't use std::make_unique here because of its interaction with a + // private constructor: https://abseil.io/tips/134 + return std::unique_ptr(new SavedModelAPI(saved_model)); +} + +inline ConcreteFunction* SavedModelAPI::GetConcreteFunction( + const std::string& function_path, Status* status) { + TF_ConcreteFunction* function = TF_GetSavedModelConcreteFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return ConcreteFunction::wrap(function); +} + +inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction( + const std::string& function_path, Status* status) { + TF_SignatureDefFunction* function = TF_GetSavedModelSignatureDefFunction( + saved_model_.get(), function_path.c_str(), status->GetTFStatus()); + if (!status->ok()) { + return nullptr; + } + return SignatureDefFunction::wrap(function); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SAVED_MODEL_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/signature_def_function.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/signature_def_function.h new file mode 100644 index 00000000..bc72d208 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/signature_def_function.h @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SignatureDefFunctions are functions that correspond to either: +// "signatures" saved from a TF2 SavedModel APIs: +// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/save.py#L830-L854 +// Or the "SignatureDefMap" saved from TF1 SavedModel APIs: +// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/load_v1_in_v2_test.py#L170-L174 +// In both cases, a SignatureDef is serialized as a SignatureDef protobuf: +// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/core/protobuf/meta_graph.proto#L260-L330 +// and represents a computation defined by a TF subgraph. +// These Signatures were primarily designed to be interoperable with the legacy +// TF 1 Session-based C++ SavedModelBundle loading APIs: +// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/cc/saved_model/loader.h#L96-L108 +// SignatureDefFunctions have different semantics from regular TF2 +// ConcreteFunctions, and are mainly intended provide a serving-friendly +// transition point from the TF1 Session API. +// First, SignatureDefFunctions have different calling conventions. +// SignatureDefFunctions' inputs and outputs are constrained to **flattened +// lists of TensorHandles only**. They do not support more exotic input/output +// types (like optionals, generators, etc). Additionally, this flattening means +// they will not preserve the exact interface of the original tf.function they +// were traced from, as things like composite tensors decay into their +// internal dense tensor representation. +// Second, all inputs and outputs are "named", and these names are load bearing +// (eg: they are part of the interface of tensorflow_serving): +// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L21 +// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L39 +// The name of each input/output is stored in the corresponding tf::Argument in +// SignatureDefFunctionMetadata::arguments(). Users must ensure the order of +// TensorHandles passed to the function matches with the order of named +// arguments. Similarly the name of the outputs is stored in +// SignatureDefFunctionMetadata::returns(). +class SignatureDefFunction final { + public: + // Returns FunctionMetadata associated with this ConcreteFunction. + const SignatureDefFunctionMetadata* GetFunctionMetadata(); + + private: + friend class SavedModelAPI; + friend class ConcreteFunctionList; + + // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping + // when moving out of experimental. 
+ static SignatureDefFunction* wrap(TF_SignatureDefFunction* p) { + return reinterpret_cast(p); + } + static TF_SignatureDefFunction* unwrap(SignatureDefFunction* p) { + return reinterpret_cast(p); + } +}; + +inline const SignatureDefFunctionMetadata* +SignatureDefFunction::GetFunctionMetadata() { + return SignatureDefFunctionMetadata::wrap( + TF_SignatureDefFunctionGetMetadata(unwrap(this))); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h new file mode 100644 index 00000000..6cb01bf1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SignatureDefFunctionMetadata stores additional information on each input +// and output's names, dtypes, and shape. +class SignatureDefFunctionMetadata final { + // TODO(bmzhao): Add getters here as necessary. + private: + friend class SignatureDefFunction; + static SignatureDefFunctionMetadata* wrap( + TF_SignatureDefFunctionMetadata* p) { + return reinterpret_cast(p); + } + static TF_SignatureDefFunctionMetadata* unwrap( + SignatureDefFunctionMetadata* p) { + return reinterpret_cast(p); + } +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/fingerprinting.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/fingerprinting.h new file mode 100644 index 00000000..2b232481 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/fingerprinting.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_ +#define TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/protobuf/fingerprint.pb.h" + +namespace tensorflow::saved_model::fingerprinting { + +// Creates a FingerprintDef proto from a SavedModel (regular or chunked) and the +// checkpoint meta file (.index) in `export_dir`. +absl::StatusOr CreateFingerprintDef( + absl::string_view export_dir); + +// Loads the `fingerprint.pb` from `export_dir`, returns an error if there is +// none. +absl::StatusOr ReadSavedModelFingerprint( + absl::string_view export_dir); + +// Canonical fingerprinting ID for a SavedModel. +std::string Singleprint(uint64_t graph_def_program_hash, + uint64_t signature_def_hash, + uint64_t saved_object_graph_hash, + uint64_t checkpoint_hash); +std::string Singleprint(const FingerprintDef& fingerprint); +absl::StatusOr Singleprint(absl::string_view export_dir); + +} // namespace tensorflow::saved_model::fingerprinting + +#endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/fingerprinting_utils.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/fingerprinting_utils.h new file mode 100644 index 00000000..306abec8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/fingerprinting_utils.h @@ -0,0 +1,137 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_UTILS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_UTILS_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "riegeli/bytes/fd_reader.h" // from @riegeli +#include "riegeli/records/record_reader.h" // from @riegeli +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep +#include "tensorflow/core/protobuf/fingerprint.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/tools/proto_splitter/chunk.pb.h" + +namespace tensorflow::saved_model::fingerprinting { + +namespace fingerprinting_utils_internal { + +using ::tensorflow::protobuf::Map; +using ::tensorflow::protobuf::Message; +using ::tensorflow::protobuf::RepeatedPtrField; + +// Number of sequential FieldIndex matches of `a` in `b`. (Length of initial +// subsequence.) 
+// Example: `a = {4, 2}`, `b = {4, 2, 1, 3}`, `fieldTagMatches(a, b) == 2` +absl::StatusOr fieldTagMatches( + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& a, + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& b); + +// Pull out the relevant data within `chunked_message`. A `chunked_field` is +// relevant if its `field_tags` are an initial subsequence any of the +// `target_fields` in the provided `target_fields_list`. +absl::StatusOr<::tensorflow::proto_splitter::ChunkedMessage> +PruneChunkedMessage( + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, + riegeli::RecordReader>& reader, + std::vector<::tensorflow::proto_splitter::ChunkInfo> chunks_info, + std::vector> + target_fields_list); + +// Deterministically serializes the proto `message`. +std::string SerializeProto(const Message& message); + +// Uses metadata contained in `chunked_message` to hash fields within the +// data accessed by the `reader` using `chunks_info`. +absl::StatusOr HashFields( + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, + riegeli::RecordReader>& reader, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& + field_tags, + Message* merged_message); + +// Gets the field tags for `graph_def`.::tensorflow +inline RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex> +GraphDefFieldTags(); + +// Gets the field tags for `signature_def`. +inline RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex> +SignatureDefFieldTags(); + +// Gets the field tags for `saved_object_graph`. +inline RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex> +SavedObjectGraphFieldTags(); + +// Returns a `SavedModel` containing only fields (up to those) specified by +// `GraphDefFieldTags()`, `SignatureDefFieldTags()`, and +// `SavedObjectGraphFieldTags()`. +absl::StatusOr PrunedSavedModel( + absl::string_view export_dir, + riegeli::RecordReader>& reader, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, + ::tensorflow::proto_splitter::ChunkMetadata& chunk_metadata); + +// Hashes the contents of `message` specified by `field_tags`. +absl::StatusOr HashMessage( + Message* message, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, + riegeli::RecordReader>& reader, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info, + const RepeatedPtrField<::tensorflow::proto_splitter::FieldIndex>& + field_tags); + +// Hashes the contents of `graph_def`. +absl::StatusOr HashGraphDef( + tensorflow::GraphDef* graph_def, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, + riegeli::RecordReader>& reader, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info); + +// Hashes the contents of `signature_def`. +absl::StatusOr HashSignatureDef( + const Map& signature_def_map, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, + riegeli::RecordReader>& reader, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info); + +// Hashes the contents of `saved_object_graph`. 
+absl::StatusOr HashSavedObjectGraph( + tensorflow::SavedObjectGraph* saved_object_graph, + const ::tensorflow::proto_splitter::ChunkedMessage& chunked_message, + riegeli::RecordReader>& reader, + const std::vector<::tensorflow::proto_splitter::ChunkInfo>& chunks_info); + +} // namespace fingerprinting_utils_internal + +// Returns the hash of the checkpoint .index file, 0 if there is none. +uint64_t HashCheckpointIndexFile(absl::string_view model_dir); + +// Creates a FingerprintDef proto from a chunked SavedModel and the checkpoint +// meta file (.index) in `export_dir`. +absl::StatusOr CreateFingerprintDefCpb( + absl::string_view export_dir, std::string cpb_file); + +} // namespace tensorflow::saved_model::fingerprinting + +#endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/image_format/internal_api.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/image_format/internal_api.h new file mode 100644 index 00000000..5c9b13d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/image_format/internal_api.h @@ -0,0 +1,65 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_IMAGE_FORMAT_INTERNAL_API_H_ +#define TENSORFLOW_CC_SAVED_MODEL_IMAGE_FORMAT_INTERNAL_API_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +#define IS_OSS false + +namespace tensorflow { +namespace image_format { + +// Reads the SavedModel proto from {file_prefix}{.pb|.cpb}. +// Returns a failure status when the SavedModel file does not exist. +absl::Status ReadSavedModel(const std::string& file_prefix, + SavedModel* saved_model_proto); + +// Writes the SavedModel proto to a file or to string. If the proto is < the +// protobuf maximum size, then it will be serialized as a `.pb` proto binary. +// When larger than the maximum size, the SavedModel proto is destructively +// separated into chunks and written to +// `.cpb` (chunked proto). +// +// Write SavedModel to {file_prefix}{.pb|.cpb}. +absl::Status WriteSavedModel(SavedModel* saved_model_proto, + const std::string& file_prefix); +// Writes the SavedModel proto to std::string +// The bool field record whether it's saved as a chunked protobuf (true) or +// regular protobuf (false) +absl::StatusOr> WriteSavedModelToString( + SavedModel* saved_model_proto); +#if !IS_OSS +absl::StatusOr> WriteSavedModelToCord( + SavedModel* saved_model_proto); +#endif + +// See above. The `debug_max_size` argument can be used to the maximum size to +// less than 2GB for testing purposes. 
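The `image_format` helpers declared above read and write a SavedModel proto that may be split into chunks (`.pb` vs `.cpb`). A small sketch, assuming an on-disk prefix such as `/tmp/model/saved_model`; the prefix strings are illustrative only, and the extension is chosen by the library.

```cpp
// Hedged sketch of the chunked SavedModel read/write helpers declared above.
#include "absl/status/status.h"
#include "tensorflow/cc/saved_model/image_format/internal_api.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"

void RoundTripSavedModel() {
  tensorflow::SavedModel proto;
  absl::Status read_status = tensorflow::image_format::ReadSavedModel(
      "/tmp/model/saved_model", &proto);
  if (!read_status.ok()) return;

  // Writes {prefix}.pb, or {prefix}.cpb when the proto exceeds the protobuf
  // size limit and has to be chunked.
  absl::Status write_status = tensorflow::image_format::WriteSavedModel(
      &proto, "/tmp/model_copy/saved_model");
  (void)write_status;
}
```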
+absl::Status WriteSavedModel(SavedModel* saved_model_proto, + const std::string& file_prefix, + int debug_max_size); + +} // namespace image_format +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_IMAGE_FORMAT_INTERNAL_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/loader.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/loader.h new file mode 100644 index 00000000..f549645e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/loader.h @@ -0,0 +1,150 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/// SavedModel loading functions and SavedModelBundle struct. + +#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ +#define TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ + +#include +#include + +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { + +/// Represents a SavedModel that is loaded from storage. +class SavedModelBundleInterface { + public: + virtual ~SavedModelBundleInterface(); + + /// Returns the TensorFlow Session that can be used to interact with the + /// SavedModel. + virtual Session* GetSession() const = 0; + + /// Returns a map from signature name to SignatureDef for all signatures in + /// in the SavedModel. + virtual const protobuf::Map& GetSignatures() const = 0; +}; + +/// SavedModel representation once the SavedModel is loaded from storage. +/// +/// NOTE: Prefer to use SavedModelBundleLite in new code, as it consumes less +/// RAM. +struct SavedModelBundle : public SavedModelBundleInterface { + /// A TensorFlow Session does not Close itself on destruction. To avoid + /// resource leaks, we explicitly call Close on Sessions that we create. + ~SavedModelBundle() override { + if (session) { + session->Close().IgnoreError(); + } + } + + SavedModelBundle() = default; + + Session* GetSession() const override { return session.get(); } + const protobuf::Map& GetSignatures() const override { + return meta_graph_def.signature_def(); + } + + std::unique_ptr session; + MetaGraphDef meta_graph_def; + std::unique_ptr debug_info; +}; + +// A version of SavedModelBundle that avoids storing a potentially large +// MetaGraphDef. Prefer to use SavedModelBundleLite in new code. +class SavedModelBundleLite : public SavedModelBundleInterface { + public: + SavedModelBundleLite() = default; + SavedModelBundleLite(SavedModelBundleLite&& other) = default; + SavedModelBundleLite& operator=(SavedModelBundleLite&& other) = default; + + SavedModelBundleLite(std::unique_ptr session, + protobuf::Map signatures) + : session_(std::move(session)), signatures_(std::move(signatures)) {} + + /// A TensorFlow Session does not Close itself on destruction. To avoid + /// resource leaks, we explicitly call Close on Sessions that we create. 
+ ~SavedModelBundleLite() override { + if (session_) { + session_->Close().IgnoreError(); + } + } + + Session* GetSession() const override { return session_.get(); } + const protobuf::Map& GetSignatures() const override { + return signatures_; + } + + private: + std::unique_ptr session_; + protobuf::Map signatures_; +}; + +// Restore variable and resources in the SavedModel export dir for the +// indicated metagraph. +// The recommended way to load a saved model is to call LoadSavedModel, +// which provides an already initialized Metagraph, Session, and DebugInfo. +absl::Status RestoreSession(const RunOptions& run_options, + const MetaGraphDef& meta_graph, + const string& export_dir, + std::unique_ptr* session); + +// Initialize a session which wraps this metagraph. +// The recommended way to load a saved model is to call LoadSavedModel, +// which provides an already initialized Metagraph, Session, and DebugInfo. +absl::Status LoadMetagraphIntoSession(const SessionOptions& session_options, + const MetaGraphDef& meta_graph, + std::unique_ptr* session); + +/// Loads a SavedModel from the specified export directory. The MetaGraphDef +/// to be loaded is identified by the supplied tags, corresponding exactly to +/// the set of tags used at SavedModel build time. Stores a SavedModel bundle in +/// *bundle with a session and the requested MetaGraphDef, if found. +/// +/// NOTE: Prefer the overload that takes a SavedModelBundleLite* in new code. +absl::Status LoadSavedModel(const SessionOptions& session_options, + const RunOptions& run_options, + const string& export_dir, + const std::unordered_set& tags, + SavedModelBundle* bundle); + +/// Loads a SavedModel from the specified export directory. The MetaGraphDef +/// to be loaded is identified by the supplied tags, corresponding exactly to +/// the set of tags used at SavedModel build time. Stores a SavedModel bundle +/// in *bundle with a session created from the requested MetaGraphDef if found. +/// +/// This overload creates a SavedModelBundleLite, which consumes less RAM than +/// an equivalent SavedModelBundle. +absl::Status LoadSavedModel(const SessionOptions& session_options, + const RunOptions& run_options, + const string& export_dir, + const std::unordered_set& tags, + SavedModelBundleLite* bundle); + +/// Checks whether the provided directory could contain a SavedModel. Note that +/// the method does not load any data by itself. If the method returns `false`, +/// the export directory definitely does not contain a SavedModel. If the method +/// returns `true`, the export directory may contain a SavedModel but provides +/// no guarantee that it can be loaded. +bool MaybeSavedModelDirectory(const std::string& export_dir); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/loader_util.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/loader_util.h new file mode 100644 index 00000000..9ce3500c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/loader_util.h @@ -0,0 +1,40 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
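`loader.h` above is the classic Session-based loading path. A minimal sketch using the `SavedModelBundleLite` overload follows; the export directory and the feed/fetch tensor names are illustrative assumptions and should be read from the signatures in practice.

```cpp
// Hedged sketch: loading and running a SavedModel via the Session API above.
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/framework/tensor.h"

bool RunModel(const std::string& export_dir, const tensorflow::Tensor& input) {
  tensorflow::SavedModelBundleLite bundle;
  const std::unordered_set<std::string> tags = {tensorflow::kSavedModelTagServe};
  auto status = tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
                                           tensorflow::RunOptions(), export_dir,
                                           tags, &bundle);
  if (!status.ok()) return false;

  // Feed/fetch names below are hypothetical; look them up via GetSignatures().
  std::vector<tensorflow::Tensor> outputs;
  status = bundle.GetSession()->Run({{"serving_default_x:0", input}},
                                    {"StatefulPartitionedCall:0"}, {}, &outputs);
  return status.ok();
}
```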
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_ +#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace internal { + +// A SavedModel may store the name of the initialization op to run in the +// in the SignatureDef (v2) or a collection (v1). If an init_op collection +// exists, then the collection must contain exactly one op. +absl::Status GetInitOp(const string& export_dir, + const MetaGraphDef& meta_graph_def, + string* init_op_name); + +absl::Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, + std::vector* asset_file_defs); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/metrics.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/metrics.h new file mode 100644 index 00000000..fa587d60 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/metrics.h @@ -0,0 +1,147 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// APIs for accessing SavedModel and checkpoint metric objects. +// +// In order to collect the data from these metrics, please add the metrics to +// the provided monitoring platform. Unless configured with a user-specified +// monitoring platform, the data is not collected in OSS. + +#ifndef TENSORFLOW_CC_SAVED_MODEL_METRICS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_METRICS_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/sampler.h" +#include "tensorflow/core/protobuf/fingerprint.pb.h" + +namespace tensorflow { +namespace metrics { + +const char kFingerprintFound[] = "FOUND"; +const char kFingerprintNotFound[] = "NOT_FOUND"; +const char kFingerprintError[] = "ERROR"; + +// Returns "/tensorflow/core/saved_model/write/count" cell. This metric +// has 1 field "write_version", which is equal to the +// `tensorflow::libexport::GetWriteVersion` of the protobuf and should be +// incremented when a SavedModel has been successfully written. +monitoring::CounterCell& SavedModelWriteCount(absl::string_view write_version); + +// Returns "/tensorflow/core/saved_model/read/count" cell. 
This metric +// has 1 field "write_version", which is equal to the +// `tensorflow::libexport::GetWriteVersion` of the protobuf, and should be +// incremented when a SavedModel has been successfully read. +monitoring::CounterCell& SavedModelReadCount(absl::string_view write_version); + +// Returns "/tensorflow/core/saved_model/write/api" cell. This metric has 1 +// field "api_label" which corresponds to a SavedModel write API. The cell for +// `foo` should be incremented when the write API `foo` is called. +monitoring::CounterCell& SavedModelWriteApi(absl::string_view api_label); + +// Returns "/tensorflow/core/saved_model/read/api" cell. This metric has 1 +// field "api_label" which corresponds to a SavedModel read API. The cell for +// `foo` should be incremented when the read API `foo` is called. +monitoring::CounterCell& SavedModelReadApi(absl::string_view api_label); + +// Returns "/tensorflow/core/saved_model/write/fingerprint" cell, which contains +// the saved_model_checksum of the SM's fingerprint when it is exported. +monitoring::GaugeCell& SavedModelWriteFingerprint(); + +// Returns "/tensorflow/core/saved_model/write/path" cell, which contains +// the saved_model_path of the SM when it is exported. +monitoring::GaugeCell& SavedModelWritePath(); + +// Returns "/tensorflow/core/saved_model/write/path_and_fingerprint" cell, which +// contains the path (saved_model_path) and fingerprint (concatenation of +// graph_def_program_hash, signature_def_hash, saved_object_graph_hash, +// and checkpoint_hash) of the SavedModel when it is exported. +monitoring::GaugeCell& SavedModelWritePathAndSingleprint(); + +// Returns "/tensorflow/core/saved_model/read/fingerprint" cell, wich contains +// the saved_model_checksum of the SM's fingerprint when it is imported. +monitoring::GaugeCell& SavedModelReadFingerprint(); + +// Returns "/tensorflow/core/saved_model/read/path" cell, wich contains +// the saved_model_path of the SM when it is imported. +monitoring::GaugeCell& SavedModelReadPath(); + +// Returns "/tensorflow/core/saved_model/read/path_and_fingerprint" cell, which +// contains the path (saved_model_path) and singleprint (concatenation of +// graph_def_program_hash, signature_def_hash, saved_object_graph_hash, +// and checkpoint_hash) of the SavedModel when it is imported. +monitoring::GaugeCell& SavedModelReadPathAndSingleprint(); + +// Returns the fingerprint as a Json string. +std::string MakeFingerprintJson(FingerprintDef fingerprint_def); + +// Returns canonical string concatenation of path and singleprint. +absl::StatusOr MakeSavedModelPathAndSingleprint( + std::string path, std::string singleprint); + +// Returns path and singleprint as a pair, parsed canonically from the string +// metric. +absl::StatusOr> +ParseSavedModelPathAndSingleprint(std::string path_and_singleprint); + +// Returns string status indicating whether or not the fingerprint.pb file was +// found when loading the SavedModel. +monitoring::GaugeCell& SavedModelFoundFingerprintOnLoad(); + +// Returns "/tensorflow/core/checkpoint/read/read_durations" cell belonging to +// field `api_label`. +monitoring::SamplerCell& CheckpointReadDuration(absl::string_view api_label); + +// Returns "/tensorflow/core/checkpoint/write/write_durations" cell belonging to +// field `api_label`. +monitoring::SamplerCell& CheckpointWriteDuration(absl::string_view api_label); + +// Returns "/tensorflow/core/checkpoint/write/async_write_durations" cell +// belonging to field `api_label`. 
+monitoring::SamplerCell& AsyncCheckpointWriteDuration( + absl::string_view api_label); + +// Returns "/tensorflow/core/checkpoint/write/training_time_saved" cell +// belonging to field `api_label`. +monitoring::CounterCell& TrainingTimeSaved(absl::string_view api_label); + +// Returns "/tensorflow/core/checkpoint/write/checkpoint_size" cell +// belonging to field (`api_label`, `filesize`). +monitoring::CounterCell& CheckpointSize(absl::string_view api_label, + int64_t filesize); + +// Returns "/tensorflow/core/checkpoint/sharding/callback_duration" cell which +// describes how long it took to execute the checkpoint sharding callback in +// microseconds. +monitoring::CounterCell& ShardingCallbackDuration(); + +// Returns "/tensorflow/core/checkpoint/sharding/num_checkpoint_shards_written" +// cell which describes how many checkpoint shard files were written during +// saving. +monitoring::CounterCell& NumCheckpointShardsWritten(); + +// Returns "/tensorflow/core/checkpoint/sharding/callback_description" cell +// which describes the callback used to shard the checkpoint during saving. +monitoring::GaugeCell& ShardingCallbackDescription(); + +} // namespace metrics +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_METRICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/reader.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/reader.h new file mode 100644 index 00000000..b5e81f9e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/reader.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/// Functions to read the SavedModel proto, or parts of it. + +#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_ +#define TENSORFLOW_CC_SAVED_MODEL_READER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +namespace tensorflow { +absl::Status ReadSavedModel(absl::string_view export_dir, + SavedModel* saved_model_proto); + +// Finds and returns the MetaGraphDef (within the provided SavedModel) that +// matches the given set of tags. The lifetime of the returned MetaGraphDef is +// the same as the lifetime of `saved_model_proto`. +// +// FindMetaGraphDef returns a failure status when no MetaGraphDef matches the +// provided tags. +absl::StatusOr FindMetaGraphDef( + const std::unordered_set& tags, SavedModel* saved_model_proto); + +// Reads the SavedModel proto from saved_model.pb(txt) in the given directory, +// finds the MetaGraphDef that matches the given set of tags and writes it to +// the `meta_graph_def` parameter. Returns a failure status when the SavedModel +// file does not exist or no MetaGraphDef matches the tags. 
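The metrics accessors above return plain monitoring counter/gauge cells, so call sites simply fetch a cell and update it. A sketch with illustrative `api_label` and write-version values (these strings are assumptions, not values mandated by the header):

```cpp
// Hedged sketch of updating the SavedModel metrics cells declared above.
#include <string>

#include "tensorflow/cc/saved_model/metrics.h"

void RecordSavedModelRead(const std::string& export_dir) {
  // "2" stands in for the SavedModel write version; the API label is arbitrary.
  tensorflow::metrics::SavedModelReadCount("2").IncrementBy(1);
  tensorflow::metrics::SavedModelReadApi("loader_example").IncrementBy(1);
  tensorflow::metrics::SavedModelReadPath().Set(export_dir);
}
```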
+absl::Status ReadMetaGraphDefFromSavedModel(
+    absl::string_view export_dir, const std::unordered_set<string>& tags,
+    MetaGraphDef* meta_graph_def);
+
+// Store debug info from the SavedModel export dir.
+absl::Status ReadSavedModelDebugInfoIfPresent(
+    absl::string_view export_dir,
+    std::unique_ptr<GraphDebugInfo>* debug_info_proto);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CC_SAVED_MODEL_READER_H_
diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/signature_constants.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/signature_constants.h
new file mode 100644
index 00000000..7d8c07f5
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/signature_constants.h
@@ -0,0 +1,69 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
+#define TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_
+
+namespace tensorflow {
+
+/// Key in the signature def map for `default` serving signatures. The default
+/// signature is used in inference requests where a specific signature was not
+/// specified.
+static constexpr char kDefaultServingSignatureDefKey[] = "serving_default";
+
+////////////////////////////////////////////////////////////////////////////////
+/// Classification API constants.
+
+/// Classification inputs.
+static constexpr char kClassifyInputs[] = "inputs";
+
+/// Classification method name used in a SignatureDef.
+static constexpr char kClassifyMethodName[] = "tensorflow/serving/classify";
+
+/// Classification classes output.
+static constexpr char kClassifyOutputClasses[] = "classes";
+
+/// Classification scores output.
+static constexpr char kClassifyOutputScores[] = "scores";
+
+////////////////////////////////////////////////////////////////////////////////
+/// Predict API constants.
+
+/// Predict inputs.
+static constexpr char kPredictInputs[] = "inputs";
+
+/// Predict method name used in a SignatureDef.
+static constexpr char kPredictMethodName[] = "tensorflow/serving/predict";
+
+/// Predict outputs.
+static constexpr char kPredictOutputs[] = "outputs";
+
+////////////////////////////////////////////////////////////////////////////////
+/// Regression API constants.
+
+/// Regression inputs.
+static constexpr char kRegressInputs[] = "inputs";
+
+/// Regression method name used in a SignatureDef.
+static constexpr char kRegressMethodName[] = "tensorflow/serving/regress";
+
+/// Regression outputs.
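`ReadMetaGraphDefFromSavedModel` above is the lighter-weight alternative to a full bundle load when only the MetaGraphDef is needed. A sketch with an assumed export directory and the `serve` tag:

```cpp
// Hedged sketch: reading just the MetaGraphDef for the "serve" tag.
#include <string>
#include <unordered_set>

#include "tensorflow/cc/saved_model/reader.h"

bool CountSignatures(const std::string& export_dir, int* num_signatures) {
  tensorflow::MetaGraphDef meta_graph_def;
  const std::unordered_set<std::string> tags = {"serve"};
  absl::Status status = tensorflow::ReadMetaGraphDefFromSavedModel(
      export_dir, tags, &meta_graph_def);
  if (!status.ok()) return false;
  *num_signatures = meta_graph_def.signature_def_size();
  return true;
}
```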
+static constexpr char kRegressOutputs[] = "outputs"; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/tag_constants.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/tag_constants.h new file mode 100644 index 00000000..68a090e0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/tag_constants.h @@ -0,0 +1,35 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ + +namespace tensorflow { + +/// Tag for the `gpu` graph. +constexpr char kSavedModelTagGpu[] = "gpu"; + +/// Tag for the `tpu` graph. +constexpr char kSavedModelTagTpu[] = "tpu"; + +/// Tag for the `serving` graph. +constexpr char kSavedModelTagServe[] = "serve"; + +/// Tag for the `training` graph. +constexpr char kSavedModelTagTrain[] = "train"; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/test_utils.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/test_utils.h new file mode 100644 index 00000000..3e131951 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/test_utils.h @@ -0,0 +1,53 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_SAVED_MODEL_TEST_UTILS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_TEST_UTILS_H_ + +#include +#include + +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow::saved_model { + +// TODO(b/229726259) Switch to OSS version after it's available. +// Simple implementation of a proto matcher comparing string representations. +// Only works as ShapeProto's textual representation is deterministic. 
+class ProtoStringMatcher { + public: + explicit ProtoStringMatcher(const tensorflow::protobuf::Message& expected) + : expected_(expected.DebugString()) {} + + template + bool MatchAndExplain(const Message& p, + ::testing::MatchResultListener*) const { + return p.DebugString() == expected_; + } + + void DescribeTo(::std::ostream* os) const { *os << expected_; } + void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_; + } + + private: + const std::string expected_; +}; + +inline ::testing::PolymorphicMatcher EqualsProto( + const tensorflow::protobuf::Message& x) { + return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); +} + +} // namespace tensorflow::saved_model + +#endif // TENSORFLOW_CC_SAVED_MODEL_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/saved_model/util.h b/third_party/tflite-hdrs/tensorflow/cc/saved_model/util.h new file mode 100644 index 00000000..2489f837 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/saved_model/util.h @@ -0,0 +1,56 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_SAVED_MODEL_UTIL_H_ +#define TENSORFLOW_CC_SAVED_MODEL_UTIL_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +namespace tensorflow { +namespace saved_model { + +// Utility functions for SavedModel reading and writing. + +// Returns "WriteVersion" ("1" or "2") of the SavedModel protobuf. If the +// protobuf has exactly one MetaGraphDef, which contains a SavedObjectGraph, it +// is version 2. Else, the protobuf is version 1. +// +// NOTE: The "WriteVersion" does *not* equal the major version of TF. +std::string GetWriteVersion(const SavedModel& saved_model); + +// Get view of string keys of a map. +std::set GetMapKeys( + const ::google::protobuf::Map& map); + +// Get the default input value from signature if it's missing in the request +// inputs. If `is_alias` is set to true, the keys of the `request_inputs` are +// alias names rather than the feed names in the graph. +absl::Status GetInputValues( + const SignatureDef& signature, + const ::google::protobuf::Map& request_inputs, + std::vector>& inputs); + +} // namespace saved_model +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/tools/freeze_saved_model.h b/third_party/tflite-hdrs/tensorflow/cc/tools/freeze_saved_model.h new file mode 100644 index 00000000..8a35bafe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/tools/freeze_saved_model.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ +#define TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ + +#include + +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Returns a frozen GraphDef, input tensors, and output tensors from the loaded +// SavedModelBundle. +// `inputs` and `outputs` consist of the union of all inputs and outputs in the +// SignatureDefs in the SavedModelBundle. +// FreezeSavedModel sets `frozen_graph_def` to a GraphDef of all nodes needed by +// `outputs`. All variables in the supplied SavedModelBundle are converted to +// constants, set to the value of the variables, by running the restored Session +// in the SavedModelBundle. +// WARNING: Only the variable checkpoints will be reflected in the frozen +// graph_def. All saved_model assets will be ignored. +absl::Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set* inputs, + std::unordered_set* outputs); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/training/coordinator.h b/third_party/tflite-hdrs/tensorflow/cc/training/coordinator.h new file mode 100644 index 00000000..2a52d743 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/training/coordinator.h @@ -0,0 +1,136 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#define TENSORFLOW_CC_TRAINING_COORDINATOR_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { + +/// The abstract interface for runners which must implement the Join and the +/// IsRunning function. 
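`FreezeSavedModel` above folds the bundle's variables into constants by running the restored Session. A short sketch, assuming a `SavedModelBundle` already loaded with `LoadSavedModel` as in the earlier example:

```cpp
// Hedged sketch: freezing a loaded SavedModelBundle into a standalone GraphDef.
#include <string>
#include <unordered_set>

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/tools/freeze_saved_model.h"
#include "tensorflow/core/framework/graph.pb.h"

absl::Status Freeze(const tensorflow::SavedModelBundle& bundle,
                    tensorflow::GraphDef* frozen_graph_def) {
  std::unordered_set<std::string> inputs, outputs;
  // On success, `inputs`/`outputs` hold the union of all signature inputs
  // and outputs, and `frozen_graph_def` contains only constant variables.
  return tensorflow::FreezeSavedModel(bundle, frozen_graph_def, &inputs,
                                      &outputs);
}
```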
+class RunnerInterface { + public: + virtual ~RunnerInterface() {} + virtual absl::Status Join() = 0; + virtual absl::Status ExportCostGraph(CostGraphDef* cost_graph) const { + return absl::Status(absl::StatusCode::kInvalidArgument, + "No cost model to export."); + } + /// Returns true iff the runner is running, i.e. if it is trying to populate + /// its queue. + virtual bool IsRunning() const = 0; +}; + +/// Coordinator class manages the termination of a collection of QueueRunners. +/// Without a coordinator, QueueRunners have to be joined in a specific order; +/// otherwise the QueueRunner::Join() could sometimes hang. The +/// Coordinator::RequestStop() plays the key role which notifies all running +/// threads under a coordinator to stop. This function could be called by any +/// thread or any client. +/// Usage, in the client: +/// Coordinator coord; +/// std::unique_ptr qr(&coord, ...); +/// qr.Start(session); +/// coord.RegisterRunner(std::move(qr)); +/// /// do some work +/// TF_CHECK_OK(coord.Join()); +/// In each thread of QueueRunner, the coordinator needs to be used as: +/// void Run() { +/// while (!coord->ShouldStop()) { +/// /// do some work +/// if (error) { +/// coord->RequestStop(); +/// coord->ReportStatus(error_status); +/// } +/// } +/// } +class Coordinator { + public: + Coordinator(); + + /// Constructor with a list of error codes which would not be taken as errors + /// in status reporting. + Coordinator(const std::vector& clean_stop_errors); + + /// In the destructor, RequestStop() and Join() would be called. + ~Coordinator(); + + /// Registers a runner, i.e. a unit of running threads which is usually a + /// QueueRunner. It takes the ownership of runner to avoid lifecycle-related + /// problems. Note, the coordinator would not start these threads; they are + /// supposed to be in running state when they are registered here. + absl::Status RegisterRunner(std::unique_ptr runner); + + /// Returns true iff all the registered runners have been stopped. + bool AllRunnersStopped(); + + /// Requests all running threads to stop. + absl::Status RequestStop(); + + /// Returns true if its RequestStop() has been called. + bool ShouldStop(); + + /// Joins all threads, returns OK or the first reported and unexpected status. + absl::Status Join(); + + /// Reports status to the coordinator. This is usually called by threads. + void ReportStatus(const absl::Status& status); + + /// Returns the latest status. + absl::Status GetStatus(); + + /// Returns immediately if the coordinator is stopped or blocks until + /// RequestStop() is called. + void WaitForStop(); + + // Returns the cost graph from stored run metadata in registered runners. 
+ absl::Status ExportCostGraph(CostGraphDef* cost_graph) const; + + private: + std::unordered_set clean_stop_errors_; + condition_variable wait_for_stop_; + + mutex mu_; + bool should_stop_ TF_GUARDED_BY(mu_); + + mutex status_lock_; + absl::Status status_ TF_GUARDED_BY(status_lock_); + + mutable mutex runners_lock_; + std::vector> runners_ + TF_GUARDED_BY(runners_lock_); + + Coordinator(const Coordinator&) = delete; + void operator=(const Coordinator&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_TRAINING_COORDINATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/cc/training/queue_runner.h b/third_party/tflite-hdrs/tensorflow/cc/training/queue_runner.h new file mode 100644 index 00000000..3122ff31 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/cc/training/queue_runner.h @@ -0,0 +1,144 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ + +#include +#include +#include +#include + +#include "tensorflow/cc/training/coordinator.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/blocking_counter.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/protobuf/queue_runner.pb.h" +#include "tensorflow/core/public/session.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { + +/// QueueRunner class imitates the behavior of the python version of QueueRunner +/// which creates a thread for each enqueue op, runs close op on completion. +class QueueRunner : public RunnerInterface { + public: + /// Creates a new QueueRunner from proto. + // TODO(yuefengz): we may want to initialize from queues and ops in the + // future. + static absl::Status New(const QueueRunnerDef& queue_runner_def, + std::unique_ptr* result); + + /// Creates a new QueueRunner with a coordinator, see coordinator.h for usage. + static absl::Status New(const QueueRunnerDef& queue_runner_def, + Coordinator* coord, + std::unique_ptr* result); + + /// Adds a callback that the queue runner will call when it detects an error. + void AddErrorCallback(const std::function& cb); + + /// Delete the previously registered callbacks. + void ClearErrorCallbacks(); + + /// The destructor would join all the threads. + ~QueueRunner(); + + /// Starts the queue runner with the given session. + absl::Status Start(Session* sess); + + /// Starts the queue runner with the given session and sets the run arguments + /// for sess->Run. It also collects and stores the cost model. 
+ absl::Status StartAndCollectCostGraph( + Session* sess, const RunOptions& run_options = RunOptions()); + + /// Starts the queue runner with the given session, and wait for up to the + /// specified time (in milliseconds) for the queues to start to fill up. + absl::Status Start(Session* sess, int wait_for_ms); + absl::Status StartAndCollectCostGraph( + Session* session, int wait_for_ms, + const RunOptions& run_options = RunOptions()); + + /// Requests to stop and runs the cancel op. It would be called in a separate + /// thread when coordinator is set. If there is no coordinator it should be + /// called before calling Join. + void Stop(Session* sess); + + /// Joins all the threads. Returns okay if all threads run successfully; + /// otherwise returns the first captured failure status. + absl::Status Join() final; + + /// Returns the latest status. + absl::Status GetStatus(); + + // Returns the stored cost model. + absl::Status ExportCostGraph(CostGraphDef* cost_graph) const override; + + private: + QueueRunner() : coord_(nullptr), stopped_(false), cg_mu_(nullptr) {} + + // Initializes the instance with the QueueRunnerDef proto. + absl::Status Init(const QueueRunnerDef& queue_runner_def); + + // The Run function for each thread. + void Run(Session* sess, const string& enqueue_op); + + // Updates the internal status; it only keeps OK or the first unexpected error + // status. + void UpdateStatus(const absl::Status& status); + + bool IsQueueClosed(absl::Status status) const { + return queue_closed_exception_types_.count( + static_cast(status.code())) > 0; + } + + bool IsRunning() const override { return !stopped_; } + + void SetRunArgumentsAndCostGraph(const RunOptions& run_options); + + absl::Status RealRun(Session* sess, const string& op, bool update_costs); + + string queue_name_; + std::vector enqueue_op_names_; + string close_op_name_; + string cancel_op_name_; + // code::Code casted to int to avoid a hash function. + std::unordered_set queue_closed_exception_types_; + + std::unique_ptr thread_pool_; + mutex mu_; + int runs_ = 0; + absl::Status status_ TF_GUARDED_BY(mu_); + absl::Status enqueue_status_ TF_GUARDED_BY(mu_); + std::unique_ptr counter_; + + Coordinator* coord_; + + std::atomic stopped_; + + mutex cb_mu_; + std::vector> callbacks_; + + mutable std::unique_ptr cg_mu_; + std::unique_ptr cost_graph_ TF_GUARDED_BY(cg_mu_); + RunOptions run_options_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/aot/aot_only_var_handle_op.h b/third_party/tflite-hdrs/tensorflow/compiler/aot/aot_only_var_handle_op.h new file mode 100644 index 00000000..43a8196e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/aot/aot_only_var_handle_op.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
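The `Coordinator`/`QueueRunner` pair above mirrors the Python queue-runner workflow, and the class comment in `coordinator.h` sketches the same flow. A cleaned-up sketch follows; the `QueueRunnerDef` proto and the running `Session` are assumed to be created elsewhere.

```cpp
// Hedged sketch of driving a QueueRunner through a Coordinator, following the
// usage notes in coordinator.h above. `queue_runner_def` and `session` are
// assumed inputs.
#include <memory>

#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/cc/training/queue_runner.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/public/session.h"

absl::Status RunQueues(const tensorflow::QueueRunnerDef& queue_runner_def,
                       tensorflow::Session* session) {
  tensorflow::Coordinator coord;

  std::unique_ptr<tensorflow::QueueRunner> qr;
  TF_RETURN_IF_ERROR(
      tensorflow::QueueRunner::New(queue_runner_def, &coord, &qr));
  TF_RETURN_IF_ERROR(qr->Start(session));
  TF_RETURN_IF_ERROR(coord.RegisterRunner(std::move(qr)));

  // ... do training work while the queues fill ...

  TF_RETURN_IF_ERROR(coord.RequestStop());
  return coord.Join();
}
```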
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
+#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
+
+namespace tensorflow {
+namespace tfcompile {
+
+static constexpr const char* const kXlaAotOnlyVarHandleOp =
+    "_XlaAotOnlyVarHandleOp";
+
+}  // namespace tfcompile
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/aot/benchmark.h b/third_party/tflite-hdrs/tensorflow/compiler/aot/benchmark.h
new file mode 100644
index 00000000..526c76c2
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/aot/benchmark.h
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains benchmark functions used with the code-generated benchmarks that can
+// be used to test a model on android. See also code generation rules in
+// tfcompile.bzl.
+//
+// This is separate from the built-in micro-benchmarks, because we want to:
+// 1. show a binary with minimal dependencies, to show a close-to-lower-bound
+//    binary size.
+// 2. compile on Android.
+#ifndef TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
+#define TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace benchmark {
+
+// Options specifies options for benchmarks of functions generated by tfcompile.
+struct Options {
+  // kDefaultMicros specifies the default time to run the benchmark, and is used
+  // if neither max_iters nor max_micros is set.
+  static constexpr int64_t kDefaultMicros = 3000000;
+
+  int64_t max_iters = 0;   // Maximum iterations to run, ignored if <= 0.
+  int64_t max_micros = 0;  // Maximum microseconds to run, ignored if <= 0.
+};
+
+// Stats holds statistics collected during benchmarking.
+struct Stats {
+  std::vector<int64_t> per_iter_us;  // Per-iteration deltas in us.
+  int64_t total_us;                  // Total time in us.
+
+  Stats() : total_us(0) { per_iter_us.reserve(5000); }
+};
+
+// DumpStatsToStdout printfs to stdout stats in a multi-line human-friendly
+// form.
+void DumpStatsToStdout(const Stats& stats);
+
+// BenchmarkFn is the signature of the function generated by tfcompile.
+typedef std::function<void()> BenchmarkFn;
+
+// Benchmark runs a benchmark of the function `fn`, collecting stats in `stats`.
+// Use `options` to configure benchmarking options.
+void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats); + +} // namespace benchmark +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_BENCHMARK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/aot/codegen.h b/third_party/tflite-hdrs/tensorflow/compiler/aot/codegen.h new file mode 100644 index 00000000..993196b1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/aot/codegen.h @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_CODEGEN_H_ +#define TENSORFLOW_COMPILER_AOT_CODEGEN_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/aot/compile.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" + +namespace tensorflow { +namespace tfcompile { + +// CodegenOpts specifies code generation options for the generated header file +// and the generated metadata object file. +struct CodegenOpts { + // The name of the generated C++ class, wrapping the generated function. + string class_name; + + // Target triple for the architecture we're targeting. + string target_triple; + + // Namespaces specifies a list of C++ namespaces to add to the generated + // header. If empty, all symbols will be in the global namespace. + std::vector namespaces; + + // If true, generate name-to-index data for Lookup{Arg,Result}Index methods. + bool gen_name_to_index = false; + + // If true, generate program shape data for the ProgramShape method. + bool gen_program_shape = false; + + // If true, emit a serialized HloProfilePrinterData protobuf that can be used + // to pretty print HLO profile counters. + bool gen_hlo_profile_printer_data = false; + + // If true, sets this executable as an XLA Runtime one. + bool use_xla_runtime = false; +}; + +// Describes a generated metadata object file. +struct MetadataResult { + // These are top level "extern C" declarations that are expected to be visible + // wherever program_shape_access_shim is emitted. + std::vector header_variable_decls; + + // program_shape_access_shim is a C++ expression that constructs the + // xla::ProgramShapeProto instance for the CompileResult passed to + // GenerateMetadata. + string program_shape_access_shim; + + // hlo_profile_printer_data_access_shim is a C++ expression that constructs + // the xla::HloProfilePrinterData instance for the CompileResult passed to + // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a + // C++ expression that evaluates to nullptr at runtime. + string hlo_profile_printer_data_access_shim; + + // The contents of the object (".o") file. + string object_file_data; +}; + +// Generates a metadata object file according to `opts` and `compile_result`. +// The generated object file is returned via `metadata_result`. 
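`benchmark.h` above drives a timing loop over a tfcompile-generated function. A sketch follows; `MyComputation`/`computation` stand in for the generated class and are purely hypothetical.

```cpp
// Hedged sketch of the AOT benchmark harness declared above.
#include "tensorflow/compiler/aot/benchmark.h"

namespace benchmark = tensorflow::tfcompile::benchmark;

template <typename Computation>  // e.g. the tfcompile-generated class (assumed)
void BenchmarkComputation(Computation& computation) {
  benchmark::Options options;
  options.max_micros = 2000000;  // run for ~2 seconds instead of the 3s default

  benchmark::Stats stats;
  benchmark::Benchmark(options, [&]() { computation.Run(); }, &stats);
  benchmark::DumpStatsToStdout(stats);
}
```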
+absl::Status GenerateMetadata(const CodegenOpts& opts, + const CompileResult& compile_result, + MetadataResult* metadata_result); + +// GenerateHeader uses the meta-information from compile_result to generate a +// C++ header giving access to the function in the generated object file. The +// header includes API usage documentation. +// +// metadata_result is an instance of MetadataResult obtained by a previous +// invocation to GenerateMetadata. +absl::Status GenerateHeader(const CodegenOpts& opts, + const tf2xla::Config& config, + const CompileResult& compile_result, + const MetadataResult& metadata_result, + string* header); + +// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` +// components. The syntax is [[::],...]. This +// mirrors the C++ syntax for referring to a class, where multiple namespaces +// may precede the class name, separated by double-colons. +absl::Status ParseCppClass(const string& cpp_class, string* class_name, + std::vector* namespaces); + +// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is +// appended to error messages. +absl::Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_CODEGEN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/aot/compile.h b/third_party/tflite-hdrs/tensorflow/compiler/aot/compile.h new file mode 100644 index 00000000..9d3ff78a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/aot/compile.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_COMPILE_H_ +#define TENSORFLOW_COMPILER_AOT_COMPILE_H_ + +#include +#include + +#include "tensorflow/compiler/aot/flags.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "xla/service/cpu/cpu_compiler.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tfcompile { + +// CompileResult describes the output of CompileGraph, where the object file +// data and meta-information is available in aot. +struct CompileResult { + // Contains object file and meta-info. + std::unique_ptr aot; + xla::ProgramShapeProto program_shape; // Static shape of args and results. + string entry_point; // Name of generated function. + int pointer_size = 0; // Size of a pointer in bytes. +}; + +// CompileGraph compiles the graph_def into an object file containing a function +// that performs the graph operations. +// +// The XLA compilation options are specified in the flags. +absl::Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, + const MainFlags& flags, + CompileResult* compile_result); + +// The full compilation method, for reuse in a library setting. 
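+//
+// Library-style call sketch (paths and class name are illustrative):
+//
+//   MainFlags flags;
+//   flags.graph = "/path/to/graph.pb";
+//   flags.config = "/path/to/config.pbtxt";
+//   flags.cpp_class = "mynamespace::MyGraph";
+//   TF_RETURN_IF_ERROR(Main(flags));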
+absl::Status Main(const MainFlags& flags); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_COMPILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/aot/embedded_protocol_buffers.h b/third_party/tflite-hdrs/tensorflow/compiler/aot/embedded_protocol_buffers.h new file mode 100644 index 00000000..0af4d4a3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines utilities to help "embed" protocol buffers into object +// (".o") files. These C++ binaries and shared objects can link in these .o to +// get access to said protocol buffers at runtime. + +#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace tfcompile { +using absl::StatusOr; + +// Represents a set of protocol buffers embedded into an object file and +// describes how to access them at runtime. +struct EmbeddedProtocolBuffers { + // Each instance CPPShim describes how to generate C++ code to instantiate a + // protobuf instance from the corresponding static data emitted into the + // object file. + struct CPPShim { + // `expression` is a C++ expression that creates an instance of said + // protocol buffer when executed. + string expression; + + // `variable_decl` is an "extern C" array declaration that is used in + // `expression`. It must be visible wherever `expression` is emitted. + string variable_decl; + }; + + // Each cpp_shim corresponds to one embedded protocol buffer. + std::vector cpp_shims; + + // The contents of the object (".o") file the protocol buffers are embbed in. + // This needs to be linked in to any program that wants to execute any of the + // expressions in `cpp_shims`. + string object_file_data; +}; + +// Describes a protocol buffer to embed into an object file. +struct ProtobufToEmbed { + // `symbol_prefix` is prefix that is guaranteed to be unique across the binary + // or DSO the generated object file will be linked into. + string symbol_prefix; + + // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ + // namespace qualified) protocol buffer name. This is only used in + // CPPShim::expression so relatively qualified names are fine as long as + // they're valid wherever CPPShim::expression is emitted. + string qualified_cpp_protobuf_name; + + // `message` is the protocol buffer to be embedded. It is allowed to be + // nullptr, in which case the generated C++ shim expression is just `nullptr`, + // and the generated object file does not define any symbols. 
+ const ::tensorflow::protobuf::MessageLite* message; +}; + +// Embeds a sequence of protocol buffers into an object file. +// +// `target_triple` is the target triple for the target architecture for the +// generated object file. +// +// `protobufs_to_embed` describes the protocol buffers to embed into the +// resulting object file. The C++ shim for protobufs_to_embed[i] is +// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance. The contents +// of all the protocol buffers are embedded into a single .o file whose content +// is stored in the object_file_data field in the returned +// EmbeddedProtocolBuffers instance. +absl::StatusOr CreateEmbeddedProtocolBuffers( + absl::string_view target_triple, + absl::Span protobufs_to_embed); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/aot/flags.h b/third_party/tflite-hdrs/tensorflow/compiler/aot/flags.h new file mode 100644 index 00000000..7b02f172 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/aot/flags.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_FLAGS_H_ +#define TENSORFLOW_COMPILER_AOT_FLAGS_H_ + +#include +#include + +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tfcompile { + +// Flags for the tfcompile binary. See *.cc file for descriptions. + +struct MainFlags { + string graph; + string debug_info; + string debug_info_path_begin_marker; + string config; + bool dump_fetch_nodes = false; + string target_triple; + string target_cpu; + string target_features; + string entry_point; + string cpp_class; + string out_function_object; + string out_metadata_object; + string out_header; + string out_session_module; + string mlir_components; + bool experimental_quantize = false; + + // Sanitizer pass options + bool sanitize_dataflow = false; + string sanitize_abilists_dataflow; + + // C++ codegen options + bool gen_name_to_index = false; + bool gen_program_shape = false; +}; + +// Appends to flag_list a tensorflow::Flag for each field in MainFlags. +void AppendMainFlags(std::vector* flag_list, MainFlags* flags); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/aot/quantize.h b/third_party/tflite-hdrs/tensorflow/compiler/aot/quantize.h new file mode 100644 index 00000000..62f03808 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/aot/quantize.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ +#define TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "xla/hlo/builder/xla_computation.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tfcompile { + +using QuantizeXlaFn = std::function; + +// Set the static quantization function to the `fn` if it hasn't been set. +// Return false if the static function has been set. +bool RegisterQuantizeFn(const QuantizeXlaFn& fn); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_QUANTIZE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/build_xla_ops_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/build_xla_ops_pass.h new file mode 100644 index 00000000..c1219d7c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ + +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Replaces TF function calls marked with `_XlaCompiledKernel` with _XlaCompile +// and _XlaRun nodes (which compile and launch, respectively, the corresponding +// HLO module). +class BuildXlaOpsPass : public GraphOptimizationPass { + public: + // If enable_lazy_compilation is not nullopt then *enable_lazy_compilation + // overrides --tf_xla_enable_lazy_compilation flag in deciding whether lazy + // compilation is enabled. 
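+  //
+  // For example, BuildXlaOpsPass(/*enable_lazy_compilation=*/false) disables
+  // lazy compilation regardless of the flag, while a default-constructed
+  // BuildXlaOpsPass() defers to --tf_xla_enable_lazy_compilation.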
+ explicit BuildXlaOpsPass( + std::optional enable_lazy_compilation = std::nullopt) + : enable_lazy_compilation_(enable_lazy_compilation) {} + + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + private: + std::optional enable_lazy_compilation_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/clone_constants_for_better_clustering.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/clone_constants_for_better_clustering.h new file mode 100644 index 00000000..ebe51008 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/clone_constants_for_better_clustering.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ +#define TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { +// Clones small host constants in the graph to make it easier to form larger +// clusters. +// +// This helps us in two ways: +// +// - It reduces dependencies between clusters. Let's say a constant C is used +// by nodes X and Y. If X and Y are put in different clusters (for whatever +// reason) Y's cluster now has to wait for all the operations in X's cluster +// to finish before it starts running. +// +// - It lets us create bigger clusters in multi-GPU benchmarks. Consider the +// following graph: +// +// digraph { +// Const -> GPU_1 +// Const -> GPU_0_Y +// GPU_0_X -> GPU_0_Y +// } +// +// We'd cluster Const and GPU_1 together (and place it on GPU_1), and this +// will block us from clustering GPU_0_X and GPU_0_Y together since that +// would increase the amount of work on GPU 0 waiting on work on GPU 1. +// However, cloning Const into two copies, one for GPU_0_Y and one for GPU_1 +// will let us create one cluster containing {Const/copy_0, GPU_1} and +// another containing {Const/copy_1, GPU_0_X, GPU_0_Y}. +// +// We only clone small host constants now to avoid increasing memory consumption +// too much. Moreover, in practice the constants we have to duplicate are +// things like the `perm` input to `Transpose` and the `size` input to `Slice` +// which tend to be small anyway. 
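+//
+// Standalone invocation sketch (normally the pass is run for you by the
+// optimization pass registry; the names below are illustrative):
+//
+//   GraphOptimizationPassOptions options;
+//   options.graph = &graph;        // std::unique_ptr<Graph>*
+//   options.flib_def = &flib_def;  // FunctionLibraryDefinition*
+//   CloneConstantsForBetterClusteringPass pass;
+//   TF_RETURN_IF_ERROR(pass.Run(options));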
+ +class CloneConstantsForBetterClusteringPass : public GraphOptimizationPass { + public: + CloneConstantsForBetterClusteringPass() = default; + + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CLONE_CONSTANTS_FOR_BETTER_CLUSTERING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/cluster_scoping_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/cluster_scoping_pass.h new file mode 100644 index 00000000..0b0c2ccf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/cluster_scoping_pass.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// This pass adds scopes to nodes in the _XlaInternalScope attribute to guide +// the later clustering passes. A major reason to do this is to prevent the +// clustering from losing critical parallelism in the Tensorflow graph, which +// can incur great performance degradation. +// +// This pass must be run before MarkForCompilationPass, as it stores the +// scoping information that MarkForCompilationPass will need to respect for +// clustering decision. +class ClusterScopingPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/compilability_check_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/compilability_check_util.h new file mode 100644 index 00000000..0d86c22d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/compilability_check_util.h @@ -0,0 +1,340 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ + +#include + +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/device_util.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/service/graphcycles/graphcycles.h" +#include "xla/union_find.h" +#include "xla/util.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { +// Checks whether a TF node can be compiled or not. "Recursive" as in for call +// and functional while nodes it recursively checks whether the callee functions +// can be compiled. +class RecursiveCompilabilityChecker { + public: + // Contains node name and function name. If the node is not inside a function + // body, function name is an empty string. + struct StackFrame { + std::string name; + std::string function_name; + std::shared_ptr stack_trace; + }; + + // Contains information about uncompilable node inside a function body. + struct UncompilableNodeInfo { + std::string name; + // A list representing a stacktrace from the highest level node in + // increasing call depth to immediate node that fails the + // compilability checker. + std::vector stack_trace; + std::string uncompilable_reason; + }; + + // Aggregates information about what kinds of ops are allowed. + struct OperationFilter { // TODO(lzr): Add AllowEverything() helper. + // Whether resource variable ops are allowed are allowed in callees. We do + // not allow resource variable ops in called functions (either as direct TF + // calls or as higher order control flow ops) because we do not yet model + // their memory effects in jit/resource_operation_safety_analysis. + bool allow_resource_ops_in_called_functions = false; + + // Whether Stack operations are allowed. We avoid auto-clustering Stack + // operations in general because we do not support snapshotting them. + // + // TODO(b/112837194): This restriction can be lifted with some work. + bool allow_stack_ops = false; + + // Whether TensorArray operations are allowed. We avoid auto-clustering + // TensorArray operations in general because we do not support snapshotting + // them. 
+ // + // TODO(b/112837194): This restriction can be lifted with some work. + bool allow_tensor_array_ops = false; + + // Whether stateful RNG ops are allowed. XLA's RNG does not have the same + // seeding behavior as TensorFlow's RNG (b/34749654). So we avoid + // auto-clustering stateful RNG ops. + bool allow_stateful_rng_ops = false; + + // TODO(b/118970344): Whether ControlTrigger ops are allowed. It is unsound + // to cluster ControlTrigger because of how we use deadness analysis. + bool allow_control_trigger = false; + + // Whether it is okay to "cluster" Assert and CheckNumerics by simply + // removing them (they're not removed during clustering, but their + // XlaOpKernel is a no-op kernel). We avoid auto-clustering these ops so + // that the user is not surprised when XLA is implicitly enabled. If the + // user explicitly specifies to use XLA, it is fine to resort to a dummy + // implementation. Currently Assert and CheckNumerics ops have dummy XLA + // implementations. + bool allow_eliding_assert_and_checknumerics_ops = false; + + // Whether ops that produce or consume DT_VARIANT values are allowed. We + // don't auto-cluster these ops because we don't yet support live-in or + // live-out DT_VARIANT values. + bool allow_ops_producing_or_consuming_variant = false; + + // Whether ops known to be slow on XLA-GPU should be considered compilable. + bool allow_slow_ops = false; + + // Whether ops known to have numerical accuracy issues should be considered + // compilable.. + bool allow_inaccurate_ops = false; + + // Require the function to be always compilable, regardless whether some + // control flow branches might be dead for a given input. + bool require_always_compilable = false; + + // Whether string constants are compilable. + bool allow_string_consts = true; + + // Whether to allow the compilation of CollectiveReduceV2Op. + bool allow_collective_reduce_v2 = true; + + // Whether to allow the compilation of WhereOp. + bool allow_where_op = true; + + // Whether to allow the compilation of UniqueOp. Compilation of the UniqueOp + // generates output with bounded dynamic shape that may cause failures with + // auto clustering. + // TODO(b/209813421): Enable tf.unique during + // autoclustering once all failures are rfixed. + bool allow_unique_op = true; + + // Whether ops that are marked as outside compiled are always considered + // compilable. + // TODO(b/191502757): Make this behavior true by default and remove this + // option once inference converter supports outside compilation. + bool allow_outside_compiled = false; + }; + + RecursiveCompilabilityChecker(OperationFilter op_filter, + DeviceType jit_device_type) + : op_filter_(std::move(op_filter)), + jit_device_type_(std::move(jit_device_type)) {} + + using UncompilableNodesMap = + std::map>>; + + // Returns a map where the key is the function identifier(short debug + // string) of the function encapsulating the uncompilable nodes, and the + // value is a pair of NameAttrList of the function and a vector of + // uncompilable node info. When uncompilable node is not inside any + // function call nodes, then key is a ShortDebugString() of an empty + // NameAttrList. + // + // Also, when `node` is inside a function body, users can set + // `node_stack_trace` to provide an additional context for `node`'s + // placement within the outer most graph. 
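+  //
+  // Usage sketch (illustrative; `node` and `lib_runtime` come from the
+  // caller):
+  //
+  //   RecursiveCompilabilityChecker::OperationFilter filter;
+  //   RecursiveCompilabilityChecker checker(filter,
+  //                                         DeviceType(DEVICE_GPU_XLA_JIT));
+  //   auto uncompilable = checker.FindUncompilableNodes(*node, lib_runtime);
+  //   for (const auto& [function_id, nodes_in_function] : uncompilable) {
+  //     // Inspect nodes_in_function.second for per-node reasons.
+  //   }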
+ UncompilableNodesMap FindUncompilableNodes( + const Node& node, FunctionLibraryRuntime* lib_runtime, + const std::vector* node_stack_trace = nullptr) const; + + // Returns true if `node` can be compiled by XLA. + bool IsCompilableNode(const Node& node, + FunctionLibraryRuntime* lib_runtime) const { + std::vector stack_trace; + stack_trace.emplace_back(StackFrameView{node.name(), ""}); + return IsCompilableNode(node, lib_runtime, &stack_trace); + } + + // Returns true if XLA supports this Op, but we don't want to cluster it (ie: + // due to performance or correctness concerns). + bool OpIsInaccurate(const Node& node) const; + bool OpIsSlow(const Node& node) const; + + private: + struct StackFrameView { + absl::string_view name; + absl::string_view function_name; + std::shared_ptr stack_trace; + }; + + bool IsCompilableNode( + const Node& node, FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function = nullptr, + UncompilableNodesMap* uncompilable_nodes = nullptr) const; + bool IsCompilableCall( + const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function = nullptr, + UncompilableNodesMap* uncompilable_nodes = nullptr) const; + bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; + bool IsCompilableWhile(const Node& while_node, + FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; + + // Tests whether 'case_node' is compilable. Every operator in all branches + // must be compilable. + bool IsCompilableCase(const Node& case_node, + FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; + + // Returns compilability of node def retrieved from `node`'s attribute with + // name `attr_name`. 
+ bool ExtractNodeDefAndCheckCompilability( + const Node& node, const std::string& attr_name, + const std::string& call_name, NameAttrList* encapsulating_function, + FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + UncompilableNodesMap* uncompilable_nodes) const; + + bool IsStackOp(const Node& node) const { + const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() == XlaResourceKind::kStack; + } + + bool IsTensorArrayOp(const Node& node) const { + const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() == XlaResourceKind::kTensorArray; + } + + bool IsAssertOrCheckNumerics(absl::string_view op_name) const { + return op_name == "Assert" || op_name == "CheckNumerics"; + } + + bool IsStatefulRandomOp(absl::string_view op_name) const { + return op_name == "RandomUniform" || op_name == "RandomShuffle" || + op_name == "RandomUniformInt" || op_name == "RandomStandardNormal" || + op_name == "TruncatedNormal" || op_name == "Multinomial"; + } + + bool OpProducesOrConsumesVariant(const Node& node) const { + auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; }; + return absl::c_any_of(node.input_types(), is_variant) || + absl::c_any_of(node.output_types(), is_variant); + } + + bool HasXLAKernel(const Node& node, + string* uncompilable_reason = nullptr) const; + + static void MaybeMarkUncompilableNode( + const absl::string_view reason, + const std::vector& stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes_map); + + // Make sure we don't recurse infinitely on recursive functions. + const size_t kMaxRecursionDepth = 50; + + const OperationFilter op_filter_; + const DeviceType jit_device_type_; +}; + +RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( + const XlaOpRegistry::DeviceRegistration& registration); + +// Given a FunctionLibraryRuntime and a `function`, returns this function's body +// in `fbody` as well as the indices of its constant and resource arguments. +// `fbody` is owned by `flr`. +// `constant_arg_indices` and `resource_arg_indices` should be empty vector. +// They are sorted in ascending order on this function's return. +absl::Status GetBodyAndConstantsAndResources( + FunctionLibraryRuntime* flr, const NameAttrList& function, + const FunctionBody** fbody, std::vector* constant_arg_indices, + std::vector* resource_arg_indices); + +// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr +// set. +bool CanCreateXlaKernel(const NodeDef& node_def); + +// Returns memory types for the input. +// `constant_arg_indices` and `resource_arg_indices` are sorted arrays of +// indices corresponding to constant and resource arguments respectively. +// +// One might wonder, about the case where a compile-time constant argument +// (which must be in host memory) is also used as an input into an op, +// e.g. `Add`, that expects its inputs in device memory. Here is how it +// works now. +// First, what do we mean by "op expects an input in XYZ memory"? +// There are two types of "ops" here: the tf2xla kernel and the HLO +// computation it builds. The tf2xla kernel needs to retrieve the actual +// numeric value of the compile-time constant tensors, so it really expects +// them to be on in host memory. 
However, for other inputs, it refers to them +// using xla::ComputationDataHandle, which is just a symbolic handle that +// xla::ComputationBuilder assigns. How does this handle gets assigned for +// constant arguments? Even constant arguments get an _Arg node in the graph +// instantiated for Function compilation. The tf2xla kernel for constant _Arg +// nodes takes the constant value, converts it to XlaLiteral, and feeds it +// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This +// constant XlaLiteral is included in the HLO graph, and subsequently, in +// the actual executable, which is copied to the device before being +// executed. Thus, when this executable runs, the constant is available in +// device memory. +tensorflow::MemoryTypeVector GetInputMemoryTypes( + const tensorflow::FunctionBody* fbody, + absl::Span constant_arg_indices, + absl::Span resource_arg_indices); + +// Returns output memory types. +// +// XlaLaunch kernel keeps all outputs (including constants, which it copies), +// in device memory except for resources. +tensorflow::MemoryTypeVector GetOutputMemoryTypes( + const tensorflow::FunctionBody* fbody); + +// Check whether graph can trigger XLA compilation. +bool CanTriggerXlaCompilation(const GraphDef& graph); + +// Returns true iff the node can trigger XLA compilation. +bool NodeCanTriggerXlaCompilation(const NodeDef& node); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/deadness_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/deadness_analysis.h new file mode 100644 index 00000000..80fa9a20 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/deadness_analysis.h @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// This analyzes a TensorFlow graph to identify nodes which may have partially +// dead inputs (i.e. these nodes may have some dead inputs and some alive +// inputs). +// +// For example, the ADD node in the following graph +// +// V0 PRED0 V1 PRED1 +// | | | | +// v v v v +// SWITCH SWITCH +// | | +// +---+ + ---+ +// | | +// v v +// ADD +// +// can have its inputs independently dead or alive based on the runtime values +// of PRED0 and PRED1. +// +// It is tempting to call this a liveness analysis but I avoided that because +// "liveness" already has other connotations. +class DeadnessAnalysis { + public: + // An opaque representation of a predicate. DeadnessPredicate + // instances that compare equal via operator== represent predicates + // that always evaluate to the same value. 
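+  //
+  // Typical use (sketch): run the analysis over a graph, then compare the
+  // predicates of two output edges for equality.
+  //
+  //   std::unique_ptr<DeadnessAnalysis> deadness;
+  //   TF_RETURN_IF_ERROR(DeadnessAnalysis::Run(graph, &deadness));
+  //   TF_ASSIGN_OR_RETURN(auto p0, deadness->GetPredicateFor(node_a, 0));
+  //   TF_ASSIGN_OR_RETURN(auto p1, deadness->GetPredicateFor(node_b, 0));
+  //   if (p0 == p1) { /* both outputs are alive or dead together */ }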
+ struct DeadnessPredicate { + public: + DeadnessPredicate(const DeadnessPredicate&) = default; + DeadnessPredicate(DeadnessPredicate&&) = default; + + DeadnessPredicate& operator=(const DeadnessPredicate&) = default; + DeadnessPredicate& operator=(DeadnessPredicate&&) = default; + + bool operator==(const DeadnessPredicate& other) const { + return other.pred_ == pred_; + } + + bool operator!=(const DeadnessPredicate& other) const { + return other.pred_ != pred_; + } + + private: + explicit DeadnessPredicate(void* pred) : pred_(pred) {} + + // This is really a Predicate*, but we don't want to expose that + // implementation detail to our clients. `pred_` has pointer equality so we + // can just compare the pointer in operator== and operator!=. + void* pred_; + + friend class DeadnessAnalysis; + }; + + virtual absl::StatusOr GetPredicateFor(Node* n, + int oidx) const = 0; + + // Prints out the internal state of this instance. For debugging purposes + // only. + virtual void Print() const = 0; + virtual ~DeadnessAnalysis(); + + string DebugString(DeadnessPredicate predicate) const; + + // Run the deadness analysis over `graph` and returns an error or a populated + // instance of DeadnessAnalysis in `result`. + static absl::Status Run(const Graph& graph, + std::unique_ptr* result); + + protected: + static DeadnessPredicate MakeDeadnessPredicate(void* pred) { + return DeadnessPredicate(pred); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/deadness_analysis_internal.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/deadness_analysis_internal.h new file mode 100644 index 00000000..0dc18d3e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/graph/tensor_id.h" + +namespace tensorflow { +namespace deadness_analysis_internal { + +// Returns a map describing the predicate each Tensor was mapped to. For +// testing purposes only. +using PredicateMapTy = absl::flat_hash_map; +absl::Status ComputePredicates(const Graph& graph, + PredicateMapTy* out_predicate_map, + bool enable_optimistic = true); + +} // namespace deadness_analysis_internal +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/defs.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/defs.h new file mode 100644 index 00000000..58bd4bdd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/defs.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides definitions needed for use of the TensorFlow XLA +// device. + +#ifndef TENSORFLOW_COMPILER_JIT_DEFS_H_ +#define TENSORFLOW_COMPILER_JIT_DEFS_H_ + +namespace tensorflow { + +// Name of attribute used to tag operators for compilation with XLA + +// Implies must-compile semantics: either it will be compiled +// with XLA, or an error will be thrown. +extern const char* const kXlaMustCompileAttr; // "_XlaMustCompile" + +// Implies auto-clustering: tagged nodes will be clustered and compiled with XLA +// on a best-effort basis. +extern const char* const kXlaCompileAttr; // "_XlaCompile" + +// Implies auto-clustering within the given scope. +extern const char* const kXlaScopeAttr; // "_XlaScope" +extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope" + +// The id of the compiled cluster. +extern const char* const kXlaClusterIdAttr; // "_xla_compile_id" + +[[deprecated("XLA:CPU/GPU devices are deprecated")]] void +RequestXlaDevicesCreation(); + +[[deprecated("XLA:CPU/GPU devices are deprecated")]] bool +XlaDevicesCreationRequired(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_cache.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_cache.h new file mode 100644 index 00000000..ad871349 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_cache.h @@ -0,0 +1,258 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CACHE_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CACHE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/local_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace device_compilation_cache_internal { +template +int64_t ExecutableSize(const ExecutableType* executable) { + return 0; +} + +template <> +inline int64_t ExecutableSize( + const xla::LocalExecutable* executable) { + if (executable != nullptr && executable->executable() != nullptr) { + return executable->executable()->SizeOfGeneratedCodeInBytes(); + } + + return 0; +} + +template <> +inline int64_t ExecutableSize( + const xla::PjRtLoadedExecutable* executable) { + if (executable != nullptr) { + return executable->SizeOfGeneratedCodeInBytes(); + } + + return 0; +} +} // namespace device_compilation_cache_internal + +// Cache to store compiled HLO, executables and related metadata keyed by +// `DeviceCompilationClusterSignature`. The cache owns the stored +// CompilationResults and Executables. +// Currently no cache eviction policy is implemented and the cache grows without +// bound. +template +class DeviceCompilationCache { + public: + DeviceCompilationCache() = default; + ~DeviceCompilationCache() = default; + + using Key = DeviceCompilationClusterSignature; + struct Value { + DeviceCompileState compile_state = DeviceCompileState::kUncompiled; + absl::Status compilation_status; + int64_t request_count = 0; + const XlaCompiler::CompilationResult* compilation_result = nullptr; + ExecutableType* executable = nullptr; + }; + + // Returns std::nullopt if value for the supplied key is not found. If a value + // is found, `request_count` is incremented before returning the value. + std::optional Lookup(const Key& key) const; + + // Inserts an empty value if value is not found and returns it. If a value is + // found, `request_count` is incremented before returning the value. + Value LookupOrCreate(const Key& key); + + // Caches `compile_state`, `compilation_status`, `compilation_result` and + // `executable` and associates them with the provided `key`. Takes ownership + // of `compilation_result` and `executable`. Does not increment the + // corresponding `request_count`. Only arguments that are not std::nullopt are + // updated in the cache. + void Store(const Key& key, std::optional compile_state, + std::optional compilation_status, + std::optional> + compilation_result, + std::optional> executable); + + std::string DebugString() const; + + private: + // The value associated with a cache entry. + struct Entry { + mutable mutex mu; + + // The current compilation state for this entry. + DeviceCompileState compile_state TF_GUARDED_BY(mu) = + DeviceCompileState::kUncompiled; + + // The number of times a compilation with this signature has been requested. + int64_t request_count TF_GUARDED_BY(mu) = 0; + + // Did compilation succeed? + absl::Status compilation_status TF_GUARDED_BY(mu); + + // Output of the XlaCompiler. + std::unique_ptr compilation_result + TF_GUARDED_BY(mu); + + // The XLA executable compiled from . May be null if no + // executable has been built. 
+ std::unique_ptr executable TF_GUARDED_BY(mu); + + std::string DebugString() const { + mutex_lock lock(mu); + + int64_t executable_size = + device_compilation_cache_internal::ExecutableSize( + executable.get()); + + int64_t hlo_module_size = 0; + if (compilation_result != nullptr && + compilation_result->computation != nullptr) { + hlo_module_size = + compilation_result->computation->proto().ByteSizeLong(); + } + + return absl::StrCat( + "{compile_state: ", compile_state, ", request_count: ", request_count, + ", compilation_status: ", compilation_status.ToString(), + ", compilation_result?: ", compilation_result != nullptr, + ", hlo_module_size: ", hlo_module_size, " bytes", + ", executable?: ", executable != nullptr, + ", executable_size: ", executable_size, " bytes}"); + } + }; + + mutable mutex compile_cache_mu_; + absl::flat_hash_map, Key::Hash> cache_ + TF_GUARDED_BY(compile_cache_mu_); + + DeviceCompilationCache(const DeviceCompilationCache&) = delete; + void operator=(const DeviceCompilationCache&) = delete; +}; + +template +std::optional::Value> +DeviceCompilationCache::Lookup(const Key& key) const { + // The outer lock protects the existence of the cache entry. It does not + // protect the contents of the cache entry. + Entry* entry; + { + mutex_lock lock(compile_cache_mu_); + // Find cache entry. + auto it = cache_.find(key); + if (it == cache_.cend()) { + return std::nullopt; + } + + entry = it->second.get(); + } + + mutex_lock lock(entry->mu); + Value value = {/*compile_state=*/entry->compile_state, + /*compilation_status=*/entry->compilation_status, + /*request_count=*/++entry->request_count, + /*compilation_result=*/entry->compilation_result.get(), + /*executable=*/entry->executable.get()}; + return value; +} + +template +typename DeviceCompilationCache::Value +DeviceCompilationCache::LookupOrCreate(const Key& key) { + // The outer lock protects the existence of the cache entry. It does not + // protect the contents of the cache entry. + Entry* entry; + { + mutex_lock lock(compile_cache_mu_); + // Emplace empty cache entry if not found. + auto it = cache_.emplace(key, std::make_unique()).first; + entry = it->second.get(); + } + + mutex_lock lock(entry->mu); + Value value = {/*compile_state=*/entry->compile_state, + /*compilation_status=*/entry->compilation_status, + /*request_count=*/++entry->request_count, + /*compilation_result=*/entry->compilation_result.get(), + /*executable=*/entry->executable.get()}; + return value; +} + +template +void DeviceCompilationCache::Store( + const Key& key, std::optional compile_state, + std::optional compilation_status, + std::optional> + compilation_result, + std::optional> executable) { + Entry* entry; + { + mutex_lock lock(compile_cache_mu_); + // Emplace empty cache entry if not found. 
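+    // Note: when `key` is already present, emplace() leaves the map unchanged
+    // and simply returns the existing entry (the freshly allocated Entry
+    // temporary is discarded), so this line doubles as the lookup.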
+ auto it = cache_.emplace(key, std::make_unique()).first; + entry = it->second.get(); + } + + { + mutex_lock lock(entry->mu); + if (compile_state.has_value()) { + entry->compile_state = *compile_state; + } + if (compilation_status.has_value()) { + entry->compilation_status = *compilation_status; + } + if (compilation_result.has_value()) { + entry->compilation_result = std::move(*compilation_result); + } + if (executable.has_value()) { + entry->executable = std::move(*executable); + } + } + + VLOG(4) << "Added/updated cache entry: key=" << key.HumanString() + << ", entry=" << entry->DebugString(); +} + +template +std::string DeviceCompilationCache::DebugString() const { + std::string s = "DeviceCompilationCache {\n"; + { + mutex_lock lock(compile_cache_mu_); + for (const auto& [key, entry] : cache_) { + absl::StrAppend(&s, key.HumanString(), " : ", entry->DebugString(), + ",\n"); + } + } + absl::StrAppend(&s, "}"); + + return s; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_cluster_signature.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_cluster_signature.h new file mode 100644 index 00000000..4acea2a0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_cluster_signature.h @@ -0,0 +1,56 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CLUSTER_SIGNATURE_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CLUSTER_SIGNATURE_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" + +namespace tensorflow { + +// Describes the types, shapes and any compile-time constant arguments +// to a kernel. Key that uniquely identifies a compilation output. +struct DeviceCompilationClusterSignature { + // Name of the cluster, built from the function name and it's attributes. + string name; + + // List of args (either as a TensorTypeAndShape or as a Tensor value) + // for compile-time constant arguments to the compilation, ordered by + // argument number. Tensors must be in host memory. + using TensorTypeAndShape = + std::pair>; + absl::InlinedVector, 8> args; + + bool operator==(const DeviceCompilationClusterSignature& other) const; + + struct Hash { + uint64 operator()(const DeviceCompilationClusterSignature& signature) const; + }; + + // Returns a human-readable description of the signature. + string HumanString() const; + + // Builds the signature for a compilation. 
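+  //
+  // For example (sketch; `fn` and `args` come from the surrounding
+  // compilation request):
+  //
+  //   TF_ASSIGN_OR_RETURN(
+  //       auto sig, DeviceCompilationClusterSignature::Build(fn, args));
+  //   VLOG(2) << "Compiling cluster " << sig.HumanString();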
+ static absl::StatusOr Build( + const NameAttrList& function, + absl::Span args); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_CLUSTER_SIGNATURE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_profiler.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_profiler.h new file mode 100644 index 00000000..9f1d9521 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compilation_profiler.h @@ -0,0 +1,101 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_PROFILER_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_PROFILER_H_ + +#include +#include + +#include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// Tracks statistics for device compilation and uses these to determine whether +// the given cluster should be compiled or not. +class DeviceCompilationProfiler : public ResourceBase { + public: + DeviceCompilationProfiler() = default; + ~DeviceCompilationProfiler() override; + + struct ClusterCompileStats { + // Number of times the cluster has been (re-)compiled. + int64_t compile_count = 0; + + // The number of times this cluster has been executed. + int64_t execution_count = 0; + + // Cumulative time spent compiling the cluster. + int64_t cumulative_compile_time_us = 0; + + // True if we have decided that this cluster is too dynamic (i.e. its shapes + // change too frequently) to profitably JIT compile. Once a cluster is + // tagged megamorphic, it stays megamorphic forever. + bool is_megamorphic = false; + + std::string DebugString() const { + return absl::StrCat( + "DeviceCompilationProfiler::ClusterCompileStats {compile_count=", + compile_count, ", execution_count=", execution_count, + ", cumulative_compile_time_us=", cumulative_compile_time_us, + ", is_megamorphic=", is_megamorphic, "}"); + } + }; + + // Returns the compilation statistics for the given cluster. + absl::StatusOr GetCompileStats( + const NameAttrList& function) const; + + // Determines whether the cluster should be compiled. Creates and inserts an + // entry into stats (also calls `RegisterExecution`) for `function` if it + // doesn't already exist. + virtual bool ShouldCompileCluster(const NameAttrList& function, + DeviceCompileMode compile_mode, + int64_t current_request_count); + + // Registers a cluster execution. Increments the execution count for the given + // cluster and also determines whether the cluster has gone megamorphic (and + // sets the megamorphic bit accordingly). + void RegisterExecution(const NameAttrList& function); + + // Registers a cluster compilation. Increments the compilation count and + // accumulates the compile time for the given cluster. Also broadcasts an + // XlaJitCompilationActivity. 
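+  //
+  // Typical flow (sketch; `function`, `request_count` and the compile step
+  // are supplied by the caller):
+  //
+  //   profiler->RegisterExecution(function);
+  //   if (profiler->ShouldCompileCluster(function, DeviceCompileMode::kLazy,
+  //                                      request_count)) {
+  //     // ... compile the cluster, measuring compile_time_us ...
+  //     TF_RETURN_IF_ERROR(profiler->RegisterCompilation(
+  //         function, compile_time_us, /*used_persistent_cache=*/false));
+  //   }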
+ virtual absl::Status RegisterCompilation(const NameAttrList& function, + int64_t compile_time_us, + bool used_persistent_cache); + + void IncrementOngoingAsyncCompilations(); + void DecrementOngoingAsyncCompilations(); + int64_t GetNumOngoingAsyncCompilations() const; + std::string DebugString() const override; + + private: + mutable mutex mu_; + + // Maps cluster names to compilation statistics for said cluster. + absl::flat_hash_map cluster_compile_stats_ + TF_GUARDED_BY(mu_); + + int64_t num_ongoing_compilations_ TF_GUARDED_BY(mu_) = 0; + + DeviceCompilationProfiler(const DeviceCompilationProfiler&) = delete; + void operator=(const DeviceCompilationProfiler&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILATION_PROFILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compiler.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compiler.h new file mode 100644 index 00000000..1baa7085 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compiler.h @@ -0,0 +1,504 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/jit/device_compilation_cache.h" +#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h" +#include "tensorflow/compiler/jit/device_compilation_profiler.h" +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "tensorflow/compiler/jit/device_executable_persistor.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h" +#include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/local_client.h" +#include "tensorflow/core/framework/metrics.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// Compiles/lowers a given Tensorflow graph/function/cluster into a compiled XLA +// compilation (HLO) using the XlaCompiler and compiles the resulting +// XlaCompilationResult into an `ExecutableType` (eg. xla::LocalExecutable) by +// calling `ClientType` (eg. xla::LocalClient). +// +// Caches the compiled XlaCompilationResult and Executable using a +// DeviceCompilationCache. Compilation is done only when there's a cache miss. +// +// Uses the DeviceExecutablePersistor class for persistence and tries to load a +// serialized executable from disk upon a request for compilation. 
If the +// appropriate executable isn't found on disk, compiles the given Tensorflow +// function/graph/cluster into an XlaCompilationResult (HLO) and +// `ExecutableType` and tries saving/persisting the compiled HLO and executable +// to disk. +// +// Since XLA computations must have static shapes, DeviceCompiler generates a +// new XLA computation for each new set of input shapes. +// TODO(b/255826209): De-templatize once we've moved to Device API completely. +template +class DeviceCompiler : public ResourceBase { + public: + DeviceCompiler( + std::unique_ptr> + persistor, + std::unique_ptr> + compiler_client); + ~DeviceCompiler() override; + + enum class CompileScope { + kOp, + kFunction, + }; + + // Compiles a function into a XlaCompiler::CompilationResult that can be used + // to execute an XLA Computation. Compilation results are cached. Compilation + // is skipped if there is a cache hit. `function` is the name of a Tensorflow + // function to compile. `args` is a description of the arguments to the + // computation. + // + // `compile_mode` controls the behavior of the compilation cache on a cache + // miss. If `compile_mode` is `kLazy` then, based on some profitability + // heuristics, the compilation cache may decide not to compile the cluster at + // this time. In this case it returns null into both `out_compilation_result` + // and `out_executable`. If `compile_mode` is `kStrict` then the compilation + // cache always attempts the compilation on a cache miss. If compilation mode + // is 'kAsync' compilation of the cluster happens in the background while the + // fallback path executes. + // + // The result of compilation is written to `*out_compilation_result`, which + // must be non-null. If `out_executable` is non-null, also builds an + // `ExecutableType` and sets `out_executable` to point to it. The + // resulting executable pointer may be null if the computation has no + // non-constant outputs. + absl::Status CompileIfNeeded( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, + DeviceCompileMode compile_mode, DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable); + + // As above, but for a single op. + absl::Status CompileSingleOpIfNeeded( + const XlaCompiler::Options& options, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable); + + ClientType* client() const { return compiler_client_->client(); } + const DeviceType& device_type() const { return persistor_->device_type(); } + DeviceCompilationCache* cache() { return cache_.get(); } + DeviceExecutablePersistor* persistor() { + return persistor_.get(); + } + DeviceCompilerClient* compiler_client() { + return compiler_client_.get(); + } + + string DebugString() const override; + + private: + // Common implementation of Compile and CompileSingleOp. The `OpKernelContext` + // parameter is always null for the former. 
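As a usage sketch (not part of the vendored header), a kLazy compile request against the cache, instantiated with the xla::LocalExecutable / xla::LocalClient pair that the class comment gives as an example; all arguments are assumed to be supplied by the surrounding op kernel.

#include <vector>

#include "tensorflow/compiler/jit/device_compiler.h"

namespace tensorflow {

using XlaDeviceCompiler = DeviceCompiler<xla::LocalExecutable, xla::LocalClient>;

absl::Status CompileClusterLazily(
    XlaDeviceCompiler* device_compiler, const XlaCompiler::Options& options,
    const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::CompileOptions& compile_options,
    DeviceCompilationProfiler* profiler) {
  const XlaCompiler::CompilationResult* compilation_result = nullptr;
  xla::LocalExecutable* executable = nullptr;
  TF_RETURN_IF_ERROR(device_compiler->CompileIfNeeded(
      options, function, args, compile_options, DeviceCompileMode::kLazy,
      profiler, &compilation_result, &executable));
  if (compilation_result == nullptr) {
    // kLazy declined to compile on this request; the caller falls back to the
    // regular TF executor path.
    VLOG(2) << "Deferring compilation of " << function.name();
  }
  return absl::OkStatus();
}

}  // namespace tensorflow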
+ absl::Status CompileImpl( + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, CompileScope scope, + DeviceCompileMode compile_mode, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable); + + StatusOr::Value> + CompileStrict( + const DeviceCompilationClusterSignature& sig, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, + typename DeviceCompilationCache::Value cache_value, + CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, mutex* mu) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu); + + absl::Status CompileAsynchronous( + const DeviceCompilationClusterSignature& sig, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler); + + std::unique_ptr> + persistor_; + std::unique_ptr> + compiler_client_; + std::unique_ptr> cache_; + + // Pool of threads for asynchronous compilations. + std::unique_ptr async_compiler_threads_; + + mutex cluster_mutexes_mu_; + absl::flat_hash_map, + DeviceCompilationClusterSignature::Hash> + cluster_mutexes_ TF_GUARDED_BY(cluster_mutexes_mu_); + + DeviceCompiler(const DeviceCompiler&) = delete; + void operator=(const DeviceCompiler&) = delete; +}; + +namespace device_compiler_internal { +// Print something that users can search for to definitively ascertain that XLA +// was used for their TF model. +// Prints only once to avoid spamming LOG(INFO). +inline void LogOnceXlaCompiledFirstCluster() { + static absl::once_flag log_once; + absl::call_once(log_once, [] { + LOG(INFO) << "Compiled cluster using XLA! This line is logged at most " + "once for the lifetime of the process."; + }); +} + +template +inline absl::Status EligibleToPersist(DeviceCompileState compile_state, + const ExecutableType* executable) { + if (compile_state != DeviceCompileState::kCompiled) { + return errors::FailedPrecondition( + "Cache entry to serialize is not compiled."); + } + if (executable == nullptr) { + return errors::FailedPrecondition( + "LocalExecutable not found for cache entry to serialize."); + } + return absl::OkStatus(); +} +} // namespace device_compiler_internal + +template +DeviceCompiler::DeviceCompiler( + std::unique_ptr> + persistor, + std::unique_ptr> + compiler_client) + : persistor_(std::move(persistor)), + compiler_client_(std::move(compiler_client)) { + cache_ = std::make_unique>(); + async_compiler_threads_ = std::make_unique( + tensorflow::Env::Default(), "async_compiler_threads", + kNumAsyncDeviceCompilerThreads); +} + +template +DeviceCompiler::~DeviceCompiler() { + // Since programs are owned by the cache, ensure any use of our programs have + // completed by waiting for all stream executors to complete. + compiler_client_->WaitForProgramsToFinish(); + // Wait for all outstanding compilations to finish. + // Resetting the pointer explicitly in the top level destructor. + // Without this, the pointer would be reset when the AsyncCompilationState + // is destructed, which is dependent on the order of the members in the + // DeviceCompiler class, which is error prone if the order changes. 
+ async_compiler_threads_.reset(); + // TODO(b/110813685): Think about the program ownership model. Programs are + // currently owned by the compilation cache which means we must wait for + // program completion in the destructor. There are multiple compilation caches + // around, which complicates things a little. Perhaps having programs be + // shared_ptrs (an invasive change) would make the model easier to reason + // about? +} + +template +string DeviceCompiler::DebugString() const { + return "DeviceCompiler"; +} + +template +absl::Status DeviceCompiler::CompileIfNeeded( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, + DeviceCompileMode compile_mode, DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable) { + return CompileImpl(compile_options, options, function, args, + CompileScope::kFunction, compile_mode, /*ctx=*/nullptr, + profiler, out_compilation_result, out_executable); +} + +template +absl::Status +DeviceCompiler::CompileSingleOpIfNeeded( + const XlaCompiler::Options& options, + const std::vector& args, + const XlaCompiler::CompileOptions& compile_options, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable) { + const NodeDef& def = ctx->op_kernel().def(); + NameAttrList name; + name.set_name(def.op()); + *name.mutable_attr() = def.attr(); + // Remove the "_class" attribute from the attribute set used to create the + // compilation cache key. This attribute is information for the colocator + // and causes false uniqueness between nodes. + name.mutable_attr()->erase("_class"); + return CompileImpl(compile_options, options, name, args, CompileScope::kOp, + DeviceCompileMode::kStrict, ctx, profiler, + out_compilation_result, out_executable); +} + +template +StatusOr::Value> +DeviceCompiler::CompileStrict( + const DeviceCompilationClusterSignature& sig, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, + typename DeviceCompilationCache::Value cache_value, + CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, mutex* mu) { + tensorflow::Env* env = tensorflow::Env::Default(); + const uint64 compile_start_us = env->NowMicros(); + + TfGraphToHloCompiler compiler(options); + cache_value.compile_state = DeviceCompileState::kCompiled; + + std::unique_ptr out_executable; + auto out_compilation_result = + std::make_unique(); + + if (scope == CompileScope::kOp) { + cache_value.compilation_status = compiler.CompileSingleOp( + compile_options, ctx, args, out_compilation_result.get()); + } else { + CHECK(scope == CompileScope::kFunction); // Crash OK + cache_value.compilation_status = compiler.Compile( + compile_options, function, args, out_compilation_result.get()); + } + TF_RETURN_IF_ERROR(cache_value.compilation_status); + TF_RET_CHECK(cache_value.executable == nullptr); + TF_RET_CHECK(out_compilation_result->computation != nullptr); + + auto loaded_executable = persistor_->TryToLoadExecutable( + DeviceCompilationClusterSignature::Hash()(sig), sig.HumanString(), + options, *out_compilation_result, compiler_client_.get()); + + if (loaded_executable.has_value()) { + cache_value.compilation_status = loaded_executable->status(); + if (loaded_executable->ok()) { + out_executable = 
*std::move(*loaded_executable); + metrics::UpdatePersistentCacheLoadCount(); + } + } else { + auto built_executable = + compiler_client_->BuildExecutable(options, *out_compilation_result); + TF_RETURN_IF_ERROR(built_executable.status()); + out_executable = *std::move(built_executable); + + TF_RETURN_IF_ERROR( + device_compiler_internal::EligibleToPersist( + cache_value.compile_state, out_executable.get())); + TF_RETURN_IF_ERROR(persistor_->TryToPersistExecutable( + DeviceCompilationClusterSignature::Hash()(sig), sig.HumanString(), + options, *out_compilation_result, *out_executable, + compiler_client_.get())); + } + + cache_value.compilation_result = out_compilation_result.get(); + cache_value.executable = out_executable.get(); + cache_->Store(sig, cache_value.compile_state, cache_value.compilation_status, + std::move(out_compilation_result), std::move(out_executable)); + + const uint64 compile_end_us = env->NowMicros(); + const uint64 compile_time_us = compile_end_us - compile_start_us; + + device_compiler_internal::LogOnceXlaCompiledFirstCluster(); + TF_RETURN_IF_ERROR(profiler->RegisterCompilation( + function, compile_time_us, loaded_executable.has_value())); + return cache_value; +} + +template +absl::Status DeviceCompiler::CompileAsynchronous( + const DeviceCompilationClusterSignature& signature, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler) { + // Explicitly capture all required data by value for async compilation. + // Update compilation state in cache. + cache_->Store(signature, DeviceCompileState::kCompiling, std::nullopt, + std::nullopt, std::nullopt); + profiler->IncrementOngoingAsyncCompilations(); + // Don't move the above code into the thread function as it synchronously + // updates the async compilation state! + + // When the ThreadPool for the compilation cache is destroyed, it waits for + // compilations to have finished. This means that both 'entry' and 'this' will + // be alive for the duration of the compilation. + // !!Pay attention when additional variables must be captured by this lambda!! + // All values are captured by value. Make sure that all pointer values (like + // entry) do not get freed until the lambda has finished. + const std::string& function_name = function.name(); + async_compiler_threads_->Schedule([=] { + VLOG(2) << "Starting asynchronous compilation of cluster " << function_name + << '.'; + // We don't need to lock mu, but do it anyway to satisfy thread safety + // analysis. + mutex mu; + mutex_lock lock(mu); + auto cache_value = typename DeviceCompilationCache::Value(); + auto s = CompileStrict(signature, compile_options, options, args, function, + cache_value, scope, ctx, profiler, &mu); + VLOG(2) << "Finished asynchronous compililation of cluster " + << function_name << '.'; + profiler->DecrementOngoingAsyncCompilations(); + // Update compilation status in cache. 
+ if (!s.ok()) { + cache_->Store(signature, std::nullopt, s.status(), std::nullopt, + std::nullopt); + } + }); + return absl::OkStatus(); +} + +template +absl::Status DeviceCompiler::CompileImpl( + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, const NameAttrList& function, + const std::vector& args, CompileScope scope, + DeviceCompileMode compile_mode, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler, + const XlaCompiler::CompilationResult** out_compilation_result, + ExecutableType** out_executable) { + DCHECK_NE(out_executable, nullptr); + VLOG(2) << "DeviceCompiler::Compile " << DebugString(); + + if (VLOG_IS_ON(2)) { + VLOG(2) << "num_inputs=" << args.size(); + for (int i = 0, end = args.size(); i < end; i++) { + VLOG(3) << i << ": " << args[i].HumanString(); + } + } + TF_ASSIGN_OR_RETURN(auto signature, + DeviceCompilationClusterSignature::Build(function, args)); + + // The outer lock protects the existence of the mutex in the map. + mutex* cluster_mutex; + { + mutex_lock lock(cluster_mutexes_mu_); + auto it = + cluster_mutexes_.emplace(signature, std::make_unique()).first; + cluster_mutex = it->second.get(); + } + + profiler->RegisterExecution(function); + + string human_signature; + if (VLOG_IS_ON(2)) { + human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name(); + VLOG(2) << "DeviceCompilationClusterSignature: " << human_signature; + } + + // Acquire the cache entry lock and compile, if necessary. + // TODO(phawkins): this locking will need to be restructured when we implement + // cache eviction. + mutex_lock cluster_compile_lock(*cluster_mutex); + auto cache_value = cache_->LookupOrCreate(signature); + + int64_t current_request_count = cache_value.request_count; + VLOG(2) << "Compilation cache entry hit: " + << static_cast(cache_value.compile_state) + << " signature: " << human_signature << " with request count " + << current_request_count; + + DeviceCompileState state = cache_value.compile_state; + *out_compilation_result = nullptr; + *out_executable = nullptr; + + // Check if the requested entry is uncompiled and return an error if + // compilation is disabled. This will raise an error for kLazy even if we have + // not yet hit the compilation threshold and no compilation happens this + // round. This is to avoid non-determanism of when compilation is disallowed, + // for example by changing the threshold. 
+ if (state == DeviceCompileState::kUncompiled && FailOnXlaCompilation()) { + VLOG(1) << "XLA compilation disabled: " << function.name() << "\n" + << absl::StrJoin( + args, "\n", + [](std::string* out, const XlaCompiler::Argument& arg) { + absl::StrAppend(out, " arg: ", arg.HumanString()); + }); + + return errors::Internal("XLA compilation disabled"); + } + + if (state == DeviceCompileState::kUncompiled) { + XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable"); + if (!profiler->ShouldCompileCluster(function, compile_mode, + current_request_count)) { + VLOG(2) << "Not compiling for signature: " << human_signature; + return absl::OkStatus(); + } else if (compile_mode == DeviceCompileMode::kAsync) { + VLOG(2) << "Queueing asynchronous compilation for signature: " + << human_signature; + TF_RETURN_IF_ERROR(CompileAsynchronous(signature, compile_options, + options, args, function, scope, + ctx, profiler)); + return absl::OkStatus(); + } else { + VLOG(2) << "Instantly compiling for signature: " << human_signature; + TF_ASSIGN_OR_RETURN( + cache_value, + CompileStrict(signature, compile_options, options, args, function, + cache_value, scope, ctx, profiler, cluster_mutex)); + } + } else if (state == DeviceCompileState::kCompiling) { + VLOG(2) << "Ongoing asynchronous compilation for signature: " + << human_signature; + return absl::OkStatus(); + } else if (state == DeviceCompileState::kCompiled) { + VLOG(2) << "Already Compiled for signature: " << human_signature; + } + + TF_RETURN_IF_ERROR(cache_value.compilation_status); + *out_compilation_result = cache_value.compilation_result; + *out_executable = cache_value.executable; + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compiler_client.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compiler_client.h new file mode 100644 index 00000000..358cb923 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_compiler_client.h @@ -0,0 +1,76 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_CLIENT_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/executable_build_options.h" + +namespace tensorflow { + +template +class DeviceCompilerClient { + public: + DeviceCompilerClient() = default; + virtual ~DeviceCompilerClient() = default; + + // Compiles `result` (HLO) to an `ExecutableType` using `ClientType` and + // returns it. + virtual StatusOr> BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) = 0; + + // Serializes an available `executable` to string using `ClientType` and + // returns it. 
+ virtual absl::StatusOr SerializeExecutable( + const ExecutableType& executable) = 0; + + // Compiles `result` (HLO) to a serializable executable (eg. + // xla::AotCompilationResult) using `ClientType`, serializes it to string and + // returns it. + virtual absl::StatusOr BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) = 0; + + // Loads `serialized_executable` into an `ExecutableType` using `ClientType`. + virtual StatusOr> LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) = 0; + + // Waits for the underlying `ClientType` backend's programs to finish + // executing before returning. + virtual void WaitForProgramsToFinish() = 0; + + virtual ClientType* client() const = 0; + + private: + DeviceCompilerClient(const DeviceCompilerClient&) = delete; + void operator=(const DeviceCompilerClient&) = delete; +}; + +// Generates the ExecutableBuildOptions for compilation from HLO to +// executable. +xla::ExecutableBuildOptions GetExecutableBuildOptions( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, int default_device_ordinal); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_COMPILER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/device_executable_persistor.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_executable_persistor.h new file mode 100644 index 00000000..458441c8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_executable_persistor.h @@ -0,0 +1,404 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_EXECUTABLE_PERSISTOR_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_EXECUTABLE_PERSISTOR_H_ + +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" +#include "tensorflow/compiler/jit/xla_device_compiler_client.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/service/hlo.pb.h" +#include "xla/util.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/lib/strings/proto_serialization.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { + +// Returns the persisted compilation cache file name for the given key. 
+std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key); + +// Offers a way to persist and/or load compiled `ExecutableType`s along with the +// corresponding HLO (`CompilationResult`) to/from `persistent_cache_directory` +// (if one was provided during construction) on disk using `ClientType`. +template +class DeviceExecutablePersistor { + public: + // Configuration for setting up persistence (directory, filename prefix, etc). + struct Config { + Config() = default; + explicit Config(absl::string_view persistent_cache_directory, + bool disable_strict_signature_checks, + absl::string_view persistence_prefix, + bool persistent_cache_directory_read_only) + : persistent_cache_directory(persistent_cache_directory), + disable_strict_signature_checks(disable_strict_signature_checks), + persistence_prefix(persistence_prefix), + persistent_cache_directory_read_only( + persistent_cache_directory_read_only) {} + + explicit Config(absl::string_view persistent_cache_directory, + bool disable_strict_signature_checks, + absl::string_view persistence_prefix) + : persistent_cache_directory(persistent_cache_directory), + disable_strict_signature_checks(disable_strict_signature_checks), + persistence_prefix(persistence_prefix) {} + + // If non-empty, JIT-compiled executables are saved to and loaded from the + // specified file system directory path. + std::string persistent_cache_directory; + + // Disable strict signature checks for entries loaded into the cache from + // external sources. + bool disable_strict_signature_checks = false; + + // The cache persistence prefix to use if serializing/deserialzing entries. + std::string persistence_prefix; + + // Cache is read-only if set to true. + bool persistent_cache_directory_read_only = false; + }; + + DeviceExecutablePersistor(const Config& config, + const DeviceType& device_type); + virtual ~DeviceExecutablePersistor() = default; + + // Returns std::nullopt if persistence is not enabled (i.e. + // `persistent_cache_directory_` is empty) or if the serialized entry is not + // found on disk. Otherwise, loads and returns the serialized executable + // (or returns a status). + // TODO(b/255826209): Take in Signature instead of hash and string once cache + // is refactored. + std::optional>> TryToLoadExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + DeviceCompilerClient* client) const; + + // Tries to serialize an already built `executable` and persist it on disk. If + // unable to do so, tries to build a serialized executable using the AOT + // pipeline and persists that to disk. + // TODO(b/255826209): Take in Signature instead hash and string once cache + // is refactored. + virtual absl::Status TryToPersistExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, + DeviceCompilerClient* client) const; + + const DeviceType& device_type() const { return device_type_; } + const std::string& persistence_prefix() const { return persistence_prefix_; } + const std::string& persistent_cache_directory() const { + return persistent_cache_directory_; + } + + private: + // Returns a cache key proto that identifies an entry in the compilation + // cache. 
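A small sketch (not part of the vendored header) of enabling on-disk persistence through the Config above. The template arguments (the usual ExecutableType/ClientType pair named in the class comment) and the directory, prefix, and device-type values are placeholders chosen for illustration.

#include <memory>

#include "tensorflow/compiler/jit/device_executable_persistor.h"

namespace tensorflow {

using LocalPersistor =
    DeviceExecutablePersistor<xla::LocalExecutable, xla::LocalClient>;

std::unique_ptr<LocalPersistor> MakePersistor() {
  LocalPersistor::Config config(
      /*persistent_cache_directory=*/"/tmp/xla_compile_cache",
      /*disable_strict_signature_checks=*/false,
      /*persistence_prefix=*/"xla",
      /*persistent_cache_directory_read_only=*/false);
  return std::make_unique<LocalPersistor>(config, DeviceType(DEVICE_CPU));
}

}  // namespace tensorflow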
+ XlaSerializedCacheKey BuildSerializedCacheKey( + uint64 signature_hash, const xla::HloModuleProto& hlo_module) const; + + XlaSerializedCacheKey BuildSerializedCacheKey( + uint64 signature_hash, const xla::HloModuleProto& hlo_module, + bool compiled_using_pjrt) const; + + // Serializes the signature and its corresponding entry to a proto message. + absl::StatusOr SerializeEntry( + uint64 signature_hash, const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, + DeviceCompilerClient* compiler_client) const; + + // Saves the cache entry in the file directory supplied during the + // construction of this class. Overwrites existing entries. + absl::Status SaveSerializedEntry(const XlaSerializedCacheEntry& entry) const; + + // Tries to read a cache entry given a `key` by searching the file directory + // supplied during the construction of this class. Returns std::nullopt if no + // cache entry is found. + absl::StatusOr> + TryToReadSerializedEntry(const XlaSerializedCacheKey& key) const; + + // Checks if the loaded `entry` matches the expected `key` and `hlo_module`. + absl::Status VerifyLoadedCacheEntry( + const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module, + const XlaSerializedCacheEntry& entry) const; + + std::string GetFilePath(const XlaSerializedCacheKey& key) const; + + const DeviceType device_type_; + const bool disable_strict_signature_checks_; + const std::string persistence_prefix_; + + // If non-empty, JIT-compiled executables are saved to and loaded from the + // specified file system directory path. + const std::string persistent_cache_directory_; + + // Cache is read-only if set to true. + const bool persistent_cache_directory_read_only_; + + DeviceExecutablePersistor(const DeviceExecutablePersistor&) = delete; + void operator=(const DeviceExecutablePersistor&) = delete; +}; + +template +DeviceExecutablePersistor:: + DeviceExecutablePersistor(const Config& config, + const DeviceType& device_type) + : device_type_(device_type), + disable_strict_signature_checks_(config.disable_strict_signature_checks), + persistence_prefix_(config.persistence_prefix), + persistent_cache_directory_(config.persistent_cache_directory), + persistent_cache_directory_read_only_( + config.persistent_cache_directory_read_only) {} + +template +std::string DeviceExecutablePersistor::GetFilePath( + const XlaSerializedCacheKey& key) const { + const std::string file_name = XlaSerializedCacheKeyToFileName(key); + return io::JoinPath(persistent_cache_directory_, file_name); +} + +template +XlaSerializedCacheKey +DeviceExecutablePersistor::BuildSerializedCacheKey( + uint64 signature_hash, const xla::HloModuleProto& hlo_module, + bool compiled_using_pjrt) const { + XlaSerializedCacheKey key; + key.set_signature_fingerprint(signature_hash); + key.set_cluster_fingerprint(DeterministicProtoHash64(hlo_module)); + key.set_device_type(device_type().type_string()); + key.set_prefix(persistence_prefix()); + key.set_compiled_using_pjrt(compiled_using_pjrt); + return key; +} + +template +XlaSerializedCacheKey +DeviceExecutablePersistor::BuildSerializedCacheKey( + uint64 signature_hash, const xla::HloModuleProto& hlo_module) const { + return BuildSerializedCacheKey(signature_hash, hlo_module, false); +} + +// This template specialization sets compiled_using_prjt to true in the cache +// key when the template arguments are PjRtLoadedExecutable and PjRtClient. 
+template <> +inline XlaSerializedCacheKey +DeviceExecutablePersistor:: + BuildSerializedCacheKey(uint64 signature_hash, + const xla::HloModuleProto& hlo_module) const { + return BuildSerializedCacheKey(signature_hash, hlo_module, true); +} + +template +absl::StatusOr> +DeviceExecutablePersistor::TryToReadSerializedEntry( + const XlaSerializedCacheKey& key) const { + Env* env = Env::Default(); + const std::string file_path = GetFilePath(key); + if (!env->FileExists(file_path).ok()) { + return absl::StatusOr>(std::nullopt); + } + + XlaSerializedCacheEntry entry; + TF_RETURN_IF_ERROR(ReadTextOrBinaryProto(env, file_path, &entry)); + return std::optional(entry); +} + +template +absl::Status +DeviceExecutablePersistor::VerifyLoadedCacheEntry( + const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module, + const XlaSerializedCacheEntry& entry) const { + XLA_SCOPED_LOGGING_TIMER(absl::StrCat("Verifying loaded cache entry: ", + hlo_module.entry_computation_name())); + + if (!AreSerializedProtosEqual(key, entry.key())) { + VLOG(2) << "Serialized cache key does not match:\n" + << "got:\n" + << entry.key().DebugString() << "\nexpected:\n" + << key.DebugString() << "\n"; + return errors::InvalidArgument("Serialized cache key does not match."); + } + + // Perform a stricter (slower) check of the snapshot to verify that they + // match exactly. + if (!disable_strict_signature_checks_) { + if (!AreSerializedProtosEqual(hlo_module, entry.hlo_module())) { + VLOG(2) << "HLOs do not match:\n" + << "got:\n" + << hlo_module.DebugString() << "\nexpected:\n" + << entry.hlo_module().DebugString() << "\n"; + return errors::InvalidArgument("Serialized HLO does not match."); + } + } + + if (entry.executable().empty()) { + return errors::InvalidArgument("No binary found in serialized entry."); + } + return absl::OkStatus(); +} + +template +absl::Status +DeviceExecutablePersistor::SaveSerializedEntry( + const XlaSerializedCacheEntry& entry) const { + Env* env = Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(persistent_cache_directory_)); + + // The cache on the filesystem can be read while we're writing out the proto. + // To prevent reads of partially-written files, we write the proto to a temp + // file, then move it into place once we're done writing. And we warn the + // user if these moves are not known to be atomic. + bool has_atomic_move = false; + env->HasAtomicMove(persistent_cache_directory_, &has_atomic_move) + .IgnoreError(); + if (!has_atomic_move) { + LOG_EVERY_POW_2(WARNING) + << "Filesystem for XLA persistent cache at " + << persistent_cache_directory_ + << " does not support atomic moves. Therefore the persistent cache is " + "racy if you have multiple XLA compilations occurring " + "simultaneously! You have been warned. :)"; + } + + // Write to temp location, then when that completes, atomically move into the + // final location. 
+ std::string temp_path = + io::JoinPath(persistent_cache_directory_, + XlaSerializedCacheKeyToFileName(entry.key())); + if (!env->CreateUniqueFileName(&temp_path, ".tmp")) { + return absl::UnavailableError(absl::StrCat( + "Could not create a unique file inside ", persistent_cache_directory_)); + } + TF_RETURN_IF_ERROR(WriteBinaryProto(env, temp_path, entry)); + return env->RenameFile(temp_path, GetFilePath(entry.key())); +} + +template +absl::StatusOr +DeviceExecutablePersistor::SerializeEntry( + uint64 signature_hash, const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, + DeviceCompilerClient* compiler_client) const { + XlaSerializedCacheEntry serialized_entry; + const xla::HloModuleProto& hlo_module = + compilation_result.computation->proto(); + *serialized_entry.mutable_key() = + BuildSerializedCacheKey(signature_hash, hlo_module); + *serialized_entry.mutable_hlo_module() = hlo_module; + + // XLA compiler supports exporting executables as an AOT compilation result + // to avoid running potentially expensive compilation pipeline twice. + // Check if XLA compiler can export available executable. + if (auto serialized_executable = + compiler_client->SerializeExecutable(executable); + serialized_executable.ok()) { + serialized_entry.set_executable(std::move(*serialized_executable)); + return serialized_entry; + } else if (serialized_executable.status().code() == error::UNIMPLEMENTED) { + VLOG(1) << "Executable export is not implemented"; + } else { + return serialized_executable.status(); + } + + TF_ASSIGN_OR_RETURN( + auto serialized_executable, + compiler_client->BuildSerializedExecutable(options, compilation_result)); + serialized_entry.set_executable(std::move(serialized_executable)); + return serialized_entry; +} + +template +std::optional>> +DeviceExecutablePersistor::TryToLoadExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + DeviceCompilerClient* compiler_client) const { + if (persistent_cache_directory_.empty()) { + return std::nullopt; + } + + const xla::HloModuleProto& hlo_module = + compilation_result.computation->proto(); + + XlaSerializedCacheKey cache_key = + BuildSerializedCacheKey(signature_hash, hlo_module); + + std::optional serialized_entry; + { + XLA_SCOPED_LOGGING_TIMER( + absl::StrCat("Try loading serialized cache entry:", signature_str)); + TF_ASSIGN_OR_RETURN(serialized_entry, TryToReadSerializedEntry(cache_key)); + } + + if (!serialized_entry.has_value()) { + return std::nullopt; + } + + TF_RETURN_IF_ERROR( + VerifyLoadedCacheEntry(cache_key, hlo_module, *serialized_entry)); + + VLOG(1) << "Loading cached entry for: " << signature_str; + return compiler_client->LoadExecutable(options, compilation_result, + serialized_entry->executable()); +} + +template +absl::Status +DeviceExecutablePersistor::TryToPersistExecutable( + uint64 signature_hash, const std::string& signature_str, + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& compilation_result, + const ExecutableType& executable, + DeviceCompilerClient* client) const { + if (persistent_cache_directory_.empty() || + persistent_cache_directory_read_only_) { + VLOG(1) << "Not persisting executable. 
No `persistent_cache_directory` " + "provided or cache is read-only."; + return absl::OkStatus(); + } + + XLA_SCOPED_LOGGING_TIMER( + absl::StrCat("Serializing and saving cache entry: ", signature_str)); + TF_ASSIGN_OR_RETURN(XlaSerializedCacheEntry serialized_entry, + SerializeEntry(signature_hash, options, + compilation_result, executable, client)); + TF_RETURN_IF_ERROR(SaveSerializedEntry(std::move(serialized_entry))); + VLOG(2) << "XlaSerializedCacheEntry saved for signature: [" << signature_str + << "] with signature hash: " << signature_hash; + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_EXECUTABLE_PERSISTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/device_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_util.h new file mode 100644 index 00000000..745f8730 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/device_util.h @@ -0,0 +1,203 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_DEVICE_UTIL_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/numeric/bits.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/status_macros.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace jit { +class DeviceInfoCache; +class DeviceSet; + +// Instances of DeviceId represent TensorFlow devices as integers. +// +// This helps avoid having to manipulate device names as strings when +// auto-clustering. +class DeviceId { + public: + DeviceId(DeviceId&&) = default; + DeviceId(const DeviceId&) = default; + DeviceId& operator=(const DeviceId&) = default; + + bool operator==(const DeviceId& other) const { return id() == other.id(); } + bool operator!=(const DeviceId& other) const { return !(*this == other); } + + private: + int id_; + + explicit DeviceId(int id) : id_(id) {} + + int id() const { return id_; } + + friend class DeviceInfoCache; + friend class DeviceSet; +}; + +// A set of DeviceIds, represented as a bitmap. +class DeviceSet { + public: + void Insert(DeviceId device_id); + void UnionWith(const DeviceSet& other); + bool IsEmpty() const; + + // Calls `func` on each DeviceId in the set. Stops iterating early if `func` + // return false. + // + // TODO(sanjoy): Change this to take a typed std::function if that's + // performance neutral. + template + void ForEach(FnTy func) const { + // This is really a poor man's iterator, we should consider writing a proper + // iterator if this ends up being used widely. 
+ for (int word_index = 0, end = storage_.size(); word_index < end; + word_index++) { + uint64 word = storage_[word_index]; + while (word != 0) { + uint64 only_lowest_bit_set = word & -word; + // The number of trailing zeros in a non-zero word is the index of the + // least significant 1. + int bit_index = absl::countr_zero(word); + if (!func(DeviceId(word_index * kWordSize + bit_index))) { + return; + } + word ^= only_lowest_bit_set; + } + } + } + + private: + absl::InlinedVector storage_; + + const int kWordSize = 64; +}; + +// Caches some miscellaneous information about TF devices. Thread compatible. +class DeviceInfoCache { + public: + bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; } + bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; } + + absl::string_view GetNameFor(DeviceId device) const { + return names_[device.id()]; + } + + absl::StatusOr GetIdFor(absl::string_view name); + + using DeviceRegistration = const XlaOpRegistry::DeviceRegistration; + + DeviceRegistration* GetCompilationDevice(DeviceId device) const { + return id_to_compilation_device_[device.id()]; + } + + absl::StatusOr GetCompilationDevice( + absl::string_view name) { + TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name)); + return GetCompilationDevice(device_id); + } + + const DeviceType& GetDeviceTypeFor(DeviceId device) const { + return *id_to_device_type_[device.id()]; + } + + using DeviceTypeConstRef = std::reference_wrapper; + + absl::StatusOr GetDeviceTypeFor( + absl::string_view device_name) { + TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name)); + return std::cref(*id_to_device_type_[device_id.id()]); + } + + string DebugString(const DeviceSet& device_set) const; + + private: + absl::flat_hash_map name_to_id_; + + // These fields are populated for a device in GetIdFor, *before* we give out a + // DeviceId. + std::vector + id_to_compilation_device_; + std::vector> id_to_device_type_; + std::vector names_; + std::vector is_cpu_; + std::vector is_gpu_; +}; + +} // namespace jit + +// Returns the DeviceType corresponding to 'device'. +absl::Status DeviceNameToDeviceType(const string& device, + DeviceType* device_type); + +// Picks the device for which XLA should compile a cluster that contains +// operations placed in devices in `devices`. For instance a cluster that +// contains operations solely placed on the CPU will be compiled into a CPU +// executable by XLA, whereas a cluster that contains operations placed on the +// CPU and also operations placed on the GPU will be compiled into a GPU +// executable. +// +// Returns a non-OK Status if no unambiguous choice of device exists. +// +// We choose the device using the following rules: +// +// - It is an error for `device_names` to contain more than one device of the +// same type. +// - GPU is preferred over CPU. +// - If `allow_mixing_unknown_and_cpu` is true then unknown devices are +// preferred over CPU. +// - XLA devices count as "unrecognized devices". +// +// This set of rules above implicitly assume that XLA:GPU can compile all +// operations in the cluster that XLA:CPU can compile, and if +// `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile +// all operations in the cluster that XLA:CPU can compile. 
+// +// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of +// the following things: +// +// - Let MarkForCompilationPass not inject CPU-placed operations into clusters +// that will run on unknown devices (because the unknown XLA backend may not +// support every operation supported by CPU). +// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster +// that contains nodes placed on both the CPU and on unknown devices. In this +// case it is the responsibility of the optimization pass that injected the +// CPU nodes into the cluster to ensure that these nodes can be compiled by +// the unknown XLA backend. +absl::StatusOr PickDeviceForXla( + const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); + +// This is like `PickDeviceForXla` except that it returns nullopt (instead of a +// non-OK Status) if no unambiguous choice of device exists. +// +// We return a failing Status for errors unrelated to the device choice +// algorithm itself. +absl::StatusOr> MaybePickDeviceForXla( + const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_DEVICE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h new file mode 100644 index 00000000..0c7729f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An optimization pass that groups nodes marked with a common +// kXlaClusterAttr into functions, and replaces the original nodes by +// calls. The calls are annotated with kXlaCompiledKernelAttr. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// EncapsulateSubgraphs pass takes all the nodes with the same cluster ID +// (derived from kXlaClusterAttr=ID (kXlaClusterAttr) attribute), puts them into +// a TF function, and replaces the subgraph in the main graph with a call to +// that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel). +class EncapsulateSubgraphsPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +// A rewriting function to apply to each subgraph during encapsulation. +// 'arg_source_tensors' are the tensors corresponding to the arguments in the +// original source graph (*not* 'graph'). 
+// +// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs; +// 'input_permutation' is a mapping from old argument numbers to new argument +// numbers, whereas 'output_permutation' is the same for outputs. Both +// 'input_permutation' and 'output_permutation' are initialized to the identity +// permutation. 'nodedef' is the NodeDef for the call to the function under +// construction, provided to allow additional attributes to be set. +// The rewrite may also change the NodeDef's operator name, and that +// name will be used as the name of the generated function. +typedef std::function& arg_source_tensors, + std::unique_ptr* graph, std::vector* input_permutation, + std::vector* output_permutation, NodeDef* node_def)> + RewriteSubgraphFn; + +// Transformation that finds subgraphs whose nodes are marked with +// 'group_attribute', splits those subgraphs into functions, and replaces +// the originals with function calls. +// +// 'group_attribute' must be a string valued-attribute that names the new +// functions to introduce. +// +// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before +// function conversion. +// +// If 'reuse_existing_functions' is set, use an existing function with the +// same name, if any. +// +// TODO(phawkins): currently, some information in control edges +// is not preserved. Suppose you have A and B in the main +// graph, C and D in a subgraph. B and C have control deps from A, D has control +// dep from B. Originally D must run after C, post-transformation this +// dependency is lost. +absl::Status EncapsulateSubgraphsInFunctions( + string group_attribute, const Graph& graph_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library); + +// The attribute that marks function calls produced by the encapsulate +// subgraphs pass and that should in turn be compiled via XlaLaunch operators. +extern const char* const kXlaCompiledKernelAttr; + +// Does `node` have the kXlaCompiledKernelAttr attribute? +bool IsXlaCompiledKernel(const Node& node); + +// Functions produced by the EncapsulateSubgraphs pass have their arguments in +// the order: +// 1) compile-time constant arguments, in host memory, +// 2) other arguments, in device memory. +// 3) resource variable arguments, in host memory. Note that only the resource +// Tensor itself is in host memory; the underlying value may be in device +// memory. +// The functions are annotated with the following attributes that describe how +// many constant and resource arguments there are: + +// Name of the attribute containing the number of constant arguments. +extern const char* const kXlaNumConstantArgsAttr; + +// Name of the attribute containing the number of resource variable arguments. +extern const char* const kXlaNumResourceArgsAttr; + +// Name of the attribute defining whether the cluster has reference variables. +extern const char* const kXlaHasReferenceVarsAttr; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_util.h new file mode 100644 index 00000000..7c99763c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_util.h @@ -0,0 +1,155 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains some utility functions for encapsulating XLA computation +// in host graph and encapsulating outside compilation in XLA computation. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Attribute marking output tensor shapes inferred by XLA. Attribute value is +// a list of PartialTensorShape objects. +extern const char kXlaInferredShapesAttrName[]; + +// Infers output shapes for all nodes in graph `g`. The output shapes will be +// stored in node attribute `kXlaInferredShapesAttrName`. +// +// We have to perform shape inference before encapsulation because after +// encapsulation, some nodes will be encapsulated into function call, and shape +// inference does not handle function call at the moment. +absl::Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); + +// Attribute indicating that some ops in this node's XLA computation has control +// dependency on this node. Attribute value will always be "true". +extern const char kXlaConnectedToXlaComputationAttrName[]; + +// Attribute indicating that this node has control dependency on some ops in +// this node's XLA computation. Attribute value will always be "true". +extern const char kXlaConnectedFromXlaComputationAttrName[]; + +// Attribute indicating that this is an Placeholder node added to act as a +// temporary input node for an outside compilation node. Attribute value will be +// string (original input node name). +extern const char kOutsideCompilationOriginalNodeAttrName[]; + +// Attribute indicating that this is an Placeholder node added to act as a +// temporary input node for an outside compilation node. Attribute value will be +// int (src_output for original edge). +extern const char kOutsideCompilationSrcOutputAttrName[]; + +// Attribute indicating that this node has control dependencies on some other +// nodes within the same XLA cluster. Attribute value will be a list of string +// (node names). +extern const char kXlaControlDependenciesWithinXlaClusterAttrName[]; + +// Attribute indicating that this node is an outside compilation node which is +// lifted out of If/While/function node. Attribute value will always be boolean +// value "true". +extern const char kXlaIsLiftedArgAttrName[]; + +// Attribute indicating that this node is a Placeholder node for an _Arg node +// lifted out of If/While/function node. Attribute value will be a string, which +// is the outside compilation cluster name sending the lifted arg node to host. 
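As an illustration (not part of the vendored header) of how the shape attribute described at the top of this header is consumed, a sketch that reads the inferred output shapes back off a node after PerformStaticShapeInferenceBeforeEncapsulation has run; it assumes the standard GetNodeAttr overload for lists of PartialTensorShape.

#include <vector>

#include "tensorflow/compiler/jit/encapsulate_util.h"
#include "tensorflow/core/framework/node_def_util.h"

absl::StatusOr<std::vector<tensorflow::PartialTensorShape>>
InferredOutputShapes(const tensorflow::Node& node) {
  std::vector<tensorflow::PartialTensorShape> shapes;
  TF_RETURN_IF_ERROR(tensorflow::GetNodeAttr(
      node.attrs(), tensorflow::kXlaInferredShapesAttrName, &shapes));
  return shapes;  // shapes[i] is the inferred shape of output i of `node`.
}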
+extern const char kXlaLiftedArgOutsideCompilationAttrName[]; + +// Attribute indicating that this is an IdentityN node receiving inputs for a +// outside compilation Placeholder node (the original outside compilation node +// is moved out of TPU computation, and we left a Placeholder node there). +// Attribute value will be a string, which is the outside compilation cluster +// name for the outside compilation Placeholder node. +extern const char kXlaOutsideCompilationInputsAttrName[]; + +// Attribute indicating that this is a Placeholder node for an _Arg node used in +// outside compilation. We should not move this node out of XLA computation. +// Attribute value will always be boolean value "true". +extern const char kXlaIsPlaceholderForArg[]; + +// Information for XLA computation. +struct XlaClusterInfo { + // Add an explicitly-defined default constructor for this class. + // + // The compiler may delete the default constructor here because + // host_compute_core is a const member whose type (std::map) doesn't + // necessarily have a user provided constructor -- while libc++ and + // libstdc++ 4.8 provide a user defined default constructor, libstdc++ at + // least >= 7.3 does not. See also c++11 [class.ctor] p5. + // + // TODO(klimek): In c++17 we'll be able to initialize host_compute_core + // without losing aggregate initialization, which allows us to get rid of + // the constructor definitions again. + XlaClusterInfo() {} + XlaClusterInfo(const string& cluster_name, + const NameAttrList& func_name_attrs, Node* node, + const std::map& host_compute_core) + : cluster_name(cluster_name), + func_name_attrs(func_name_attrs), + node(node), + host_compute_core(host_compute_core) {} + // XLA cluster name. It might be different from `func_name`. + const string cluster_name; + // Name and attributes of XLA computation function. + const NameAttrList func_name_attrs; + // The XLA computation node in the graph. + Node* node; + // A mapping from outside compilation cluster name to its device assignment. + const std::map host_compute_core; +}; + +// Finds dependencies between outside compilation clusters, including both data +// dependencies and control dependencies. cluster_deps maps the name name of an +// outside compilation cluster to a set of names of outside compilation clusters +// that it depends on. +absl::StatusOr< + std::unique_ptr>>> +OutsideCompilationClusterDependencies( + const Graph* g, const string& outside_compilation_attr_name); + +// Preprocesses edges within the same XLA cluster. It will perform the following +// operations in order: +// +// 0. Remove edges from source node to outside compilation nodes, and edges +// from outside compilation nodes to sink node. +// 1a. For edges between different outside compilation clusters, remove the edge +// and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node +// name" to dst node. +// 1b. For control edges between outside compilation and its XLA computation, +// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the +// outside compilation node. +// 2. For data edges between different outside compilations, remove the edge +// and create a Placeholder node as dst node's input. +absl::Status PreprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); + +// Postprocesses edges within the same XLA cluster. This function reverts what +// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the +// following operations in order: +// +// 1. 
Remove Placeholder nodes between different outside compilations (created +// in `PreprocessEdgesBetweenOutsideCompilations` step 2). +// 2a. Reconnect control edges between different outside compilations (marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1a). +// Notice that control edges marked by +// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. +// They are handled in `RewriteOutsideCompilationSubgraphFn`. +absl::Status PostprocessEdgesBetweenOutsideCompilations( + Graph* g, const string& outside_compilation_attr_name); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h new file mode 100644 index 00000000..6301e963 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +// Rewrites computations generated by the xla.compile() Python code into +// XlaLaunch nodes. +// +// xla.compile() does two main things: +// a) marks operators that make up an XLA computation with the attribute +// _xla_compile_id=XYZ, where XYZ is a unique key. +// b) adds XlaClusterOutput nodes to represent outputs of the computation. +// These nodes are not marked with the _xla_compile_id attribute. + +#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Encapsulates nodes marked with the _xla_compile_id attribute into +// XlaLaunch operators. +class EncapsulateXlaComputationsPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _xla_compile_id attribute into functions. These + // functions contain the computations to be passed to XlaLaunch. During + // encapsulation, we sort the arguments into the order expected by + // XlaLaunch. + static absl::Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into XlaLaunch + // operators. We also convert the XlaClusterOutput output nodes of the + // function call into the outputs of the XlaLaunch operator. 
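A minimal sketch (not part of the vendored header) of the two-stage flow described in (a) and (b) above, wired together with the static helpers this pass exposes for tests; the single-argument BuildXlaLaunchOps overload it calls is declared immediately below.

#include <memory>

#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"

absl::Status EncapsulateAndBuildLaunchOps(
    std::unique_ptr<tensorflow::Graph>* graph,
    tensorflow::FunctionLibraryDefinition* flib_def) {
  // Stage (a): group nodes sharing an _xla_compile_id into functions.
  TF_RETURN_IF_ERROR(
      tensorflow::EncapsulateXlaComputationsPass::Encapsulate(graph, flib_def));
  // Stage (b): rewrite the resulting calls into XlaLaunch operators.
  return tensorflow::EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
      graph->get());
}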
+ static absl::Status BuildXlaLaunchOps(Graph* graph); + + struct XlaFunctionInfo { + int variable_start_index = -1; + std::string function_name; + }; + + // We need to introduce this version to adapt to the output of gpu inference + // converter. The single argument overload version calls this function. + // + // When add_edges_to_output_of_downstream_nodes is true, the output edges of + // the xla_launch_node's immediate downstream nodes would be attached to the + // generated xla node. For example, if the original graph is + // StatefulPartitionedCall{_xla_compile_id=1} -> XlaClusterOutput -> NodeA + // The output graph of this function would look like the following when + // add_edges_to_output_of_downstream_nodes is true: + // XlaLaunch -> NodeA + static absl::Status BuildXlaLaunchOps( + Graph* graph, + const std::function(const Node&)>& + is_xla_launch_node, + const std::function(const Node&)>& + get_xla_function_info, + bool add_edges_to_output_of_downstream_nodes); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/extract_outside_compilation_pass.h new file mode 100644 index 00000000..7631ccd0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include "xla/status_macros.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Rewrite function for outside compilation subgraphs. It will perform the +// following steps: +// +// 1. Add a XLA computation key placeholder node (it will be used as input for +// XlaRecvAtHost and XlaSendFromHost); +// 2. Replace all _Arg nodes with one single XlaRecvAtHost node; +// 3. Replace all _Retval nodes with one single XlaSendFromHost node; +// 4. Mark all nodes except key placeholder with attr `xla_cluster_attr_name` +// and `outside_compilation_attr_name`; +// 5. For nodes marked with attr kXlaConnectedToXlaComputationAttrName, add a +// control edge from the node to XlaSendFromHost; for nodes marked with attr +// kXlaConnectedFromXlaComputationAttrName, add a control edge from +// XlaRecvAtHost node to the node; +// 6. Try pruning XlaRecvAtHost/XlaSendFromHost/key placeholder node. +// 7. Add necessary attributes to `node_def`, so we can replace it with a +// XlaHostCompute node later. 
If all input shapes for XlaSendFromHost are +// known, "shapes" attr will be set to the list of input shapes; otherwise +// "shape_inference_graph" attr will be set to shape inference function name. +class RewriteOutsideCompilationSubgraphFn { + public: + RewriteOutsideCompilationSubgraphFn( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, + const string& xla_cluster_name, const string& new_function_name) + : xla_cluster_attr_name_(xla_cluster_attr_name), + outside_compilation_attr_name_(outside_compilation_attr_name), + xla_cluster_name_(xla_cluster_name), + new_function_name_(new_function_name) {} + + absl::Status operator()(const std::vector&, + std::unique_ptr* graph, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* node_def); + + private: + string xla_cluster_attr_name_; + string outside_compilation_attr_name_; + string xla_cluster_name_; + string new_function_name_; +}; + +// For an XLA computation function, replace all outside compilations with +// XlaHostCompute nodes. Each outside compilation subgraph will be rewritten by +// `RewriteOutsideCompilationSubgraphFn`, and they will be merged into one +// single host side graph (`host_graph`). +// +// xla_cluster_attr_name and outside_compilation_attr_name: attr name for XLA +// computation and outside compilation. Required for +// `RewriteOutsideCompilationSubgraphFn`. +// xla_cluster_name: XLA cluster name for this XLA computation. We need it +// because XLA cluster name might be different from `func_name`. +// func_name_attrs: they will be used to instantiate the XLA computation func. +// new_func_name: new function name for rewritten XLA computation func. +// host_compute_core: mapping from outside compilation cluster name to XLA +// device assignment. +// fld: FunctionLibraryDefinition object. +// host_graph: Graph object to store host side graph for all outside +// compilations within this XLA computation func. If there is no outside +// compilation, it will be empty. +// shape_inference_graphs: a list of outside compilation shape inference +// function names. These functions need to be rewritten later. +// has_outside_compilation: a bool indicating whether this function has any +// outside compilation nodes. +absl::Status ExtractOutsideCompilationForFunction( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const NameAttrList& func_name_attrs, const string& new_func_name, + const string& host_graph_func_name, + const std::map& host_compute_core, FunctionLibraryRuntime* flr, + FunctionLibraryDefinition* fld, std::vector* shape_inference_graphs, + bool* has_outside_compilation); + +// Rewrites XLA computation in `clusters` to replace outside compilation nodes +// with XlaHostCompute, and moves those outside compilations into `g`. If shapes +// of outside compilation outputs cannot be determined now, we will store shape +// inference graph into `fld`. 
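+// An informal sketch of the extraction performed by the functions in this
+// header (node names are hypothetical): an outside compilation subgraph
+//
+//   _Arg -> HostOp -> _Retval
+//
+// is replaced inside the XLA computation by an XlaHostCompute node, while the
+// host side graph roughly becomes
+//
+//   key placeholder -> XlaRecvAtHost -> HostOp -> XlaSendFromHost
+//
+// as described for `RewriteOutsideCompilationSubgraphFn` above.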
+absl::Status ExtractOutsideCompilation( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, + const std::unordered_map& clusters, Graph* g, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + bool* modified); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_EXTRACT_OUTSIDE_COMPILATION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/flags.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/flags.h new file mode 100644 index 00000000..9dbd6106 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/flags.h @@ -0,0 +1,360 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_FLAGS_H_ +#define TENSORFLOW_COMPILER_JIT_FLAGS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/optional.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { + +struct XlaAutoJitFlag { + // Control compilation of operators into XLA computations on CPU and GPU + // devices. 0 = use ConfigProto setting; -1 = off; 1 = on for things very + // likely to be improved; 2 = on for everything. + // + // If all non-CPU ops in the graph being optimized are placed on a single GPU + // and there is at least one node placed on that GPU then + // `optimization_level_single_gpu` applies. Otherwise + // `optimization_level_general` applies. + // + // Experimental. + int32 optimization_level_single_gpu; + int32 optimization_level_general; +}; + +// Sets the xla_auto_jit_flag based on the given flag string. Supported syntax +// is: +// : sets general and single_gpu setting to the provided number. +// single-gpu(): sets the single_gpu setting to the provided number. +bool SetXlaAutoJitFlagFromFlagString(const string& value); + +// Flags associated with the XLA bridge's mark_for_compilation_pass module. +struct MarkForCompilationPassFlags { + XlaAutoJitFlag xla_auto_jit_flag; + + // Minimum number of operators in an XLA compilation. Ignored for operators + // placed on an XLA device or operators explicitly marked for compilation. + int32 tf_xla_min_cluster_size; + + // Maximum number of operators in an XLA compilation. + int32 tf_xla_max_cluster_size; + + // If non-empty, limit XLA clustering to the following TF operations. + string tf_xla_ops_to_cluster; + + // If non-empty, remove following operations from XLA clustering excludelist. + string tf_xla_cluster_exclude_ops; + + // Dump graphs during XLA compilation. + bool tf_xla_clustering_debug; + + // Enables global JIT compilation for CPU via SessionOptions. + bool tf_xla_cpu_global_jit; + + // "Compiler fuel" for clustering. Only this many ops will be marked as + // eligible for clustering. 
+ int64_t tf_xla_clustering_fuel; + + // If tf_xla_disable_deadness_safety_checks_for_debugging is set to true then + // we do not do deadness related safety checks. This is unsound in general, + // but can be used as a debugging aid. + bool tf_xla_disable_deadness_safety_checks_for_debugging; + + // If tf_xla_disable_resource_variable_safety_checks_for_debugging is set to + // true then we do not do safety checks to preserve TensorFlow's resource + // variable concurrency semantics. This is unsound in general, but can be + // used as a debugging aid. + bool tf_xla_disable_resource_variable_safety_checks_for_debugging; + + // If true names of clustered operations will be computed deterministically + // so that they remain stable from run to run of auto clusteing. + bool tf_xla_deterministic_cluster_names; + + // If non-empty, JIT-compiled executables are saved to and loaded from the + // specified file system directory path. + std::string tf_xla_persistent_cache_directory; + + // If non-empty, the persistent cache will only be used for the specified + // devices (comma separated). Each device type should be able to be converted + // to `DeviceType`. + std::string tf_xla_persistent_cache_device_types; + + bool tf_xla_persistent_cache_read_only; + + // If true, entries loaded into the XLA compile cache will not have their + // signatures checked strictly. This should generally not be disabled except + // for debugging. Defaults to false. + bool tf_xla_disable_strict_signature_checks; + + // Specifies the persistance cache prefix. Default is "xla_compile_cache" + string tf_xla_persistent_cache_prefix; +}; + +// Flags associated with XLA Sparse Core. +struct XlaSparseCoreFlags { + // Max level of division to split input data into minibatches. + int tf_xla_sparse_core_minibatch_max_division_level; + + // Disable table stacking for all the tables passed to the SparseCore + // mid level API. + bool tf_xla_sparse_core_disable_table_stacking; + + // If non-zero, limits the size of the activations for a given table to + // be below these many bytes. + int64_t tf_xla_sparse_core_stacking_mem_limit_bytes; + + // If non-zero, limits the size of any table shard to be below these + // many bytes. + int64_t tf_xla_sparse_core_stacking_table_shard_limit_bytes; +}; + +// Flags associated with the XLA bridge's xla_device module. +struct XlaDeviceFlags { + // Switch the CPU device into "on-demand" mode, where instead of + // auto-clustering ops are compiled one by one just-in-time. + // Enabling this mode by a legacy flag is a temporary mechanism. When this + // feature is battle-tested, we will switch this to be a session option. + bool tf_xla_compile_on_demand; + + // Enables "XLA" devices if this flag is set. + bool tf_xla_enable_xla_devices; +}; + +// Flags common to the _Xla* ops and their kernels. +struct XlaOpsCommonFlags { + // If true, _XlaCompile always refuses to compile the cluster, which means the + // XLA clusters always run in the TF executor. Defaults to false. + bool tf_xla_always_defer_compilation; + // If true, _XlaCompile compiles the cluster asynchronously with respect to + // the main execution. The fallback path is taken while compilation happens. + bool tf_xla_async_compilation; + + class PjRtForSingleDeviceCompilationRollout { + public: + // Allow using Device API (PjRt) for `device_type` in the XlaLaunch op. 
+ // Please note that `enabled_for_xla_launch_` needs to be true in addition + // to the `device_type` being allowed in order to use the Device API for + // single device compilation and execution in the XlaLaunch op. + void AllowForDeviceInXlaLaunch(const DeviceType& device_type) { + xla_launch_allowed_devices_.insert(device_type.type_string()); + } + + bool IsEnabledInXlaLaunchForDevice(const DeviceType& device_type) const { + if (!enabled_for_gpu_ && device_type.type_string() == "GPU") return false; + return enabled_for_all_ || + (enabled_for_xla_launch_ && + xla_launch_allowed_devices_.contains(device_type.type_string())); + } + + // Allow using Device API (PjRt) for `device_type` in the XlaCompileOnDemand + // op. Please note that `enabled_for_compile_on_demand_` needs to be true in + // addition to the `device_type` being allowed in order to use the Device + // API for single device compilation and execution in the XlaCompileOnDemand + // op. + void AllowForDeviceInXlaCompileOnDemand(const DeviceType& device_type) { + xla_compile_on_demand_allowed_devices_.insert(device_type.type_string()); + } + + bool IsEnabledInXlaCompileOnDemandForDevice( + const DeviceType& device_type) const { + if (!enabled_for_gpu_ && device_type.type_string() == "GPU") return false; + return enabled_for_all_ || + (enabled_for_compile_on_demand_ && + xla_compile_on_demand_allowed_devices_.contains( + device_type.type_string())); + } + + // Allow using Device API (PjRt) for `device_type` in the XlaCompile and + // XlaRun ops. Please note that `enabled_for_compile_and_run_` needs to be + // true in addition to the `device_type` being allowed in order to use the + // Device API for single device compilation and execution in the XlaCompile + // and XlaRun ops. + void AllowForDeviceInXlaCompileAndRun(const DeviceType& device_type) { + xla_compile_and_run_allowed_devices_.insert(device_type.type_string()); + } + + bool IsEnabledInXlaCompileAndRunForDevice( + const DeviceType& device_type) const { + if (!enabled_for_gpu_ && device_type.type_string() == "GPU") return false; + return enabled_for_all_ || (enabled_for_compile_and_run_ && + xla_compile_and_run_allowed_devices_.contains( + device_type.type_string())); + } + + bool IsEnabledForGpu() const { return enabled_for_gpu_; } + + // If true, uses Device API (PjRt) for single device compilation and + // execution of functions marked for JIT compilation i.e. jit_compile=True. + // Defaults to false. + bool enabled_for_xla_launch_; + + // If true, uses Device API (PjRt) for compiling and executing ops one by + // one in "on-demand" mode. Defaults to false. + bool enabled_for_compile_on_demand_; + + // If true, uses Device API (PjRt) for compilation and execution when + // auto-clustering is enabled. Defaults to false. + bool enabled_for_compile_and_run_; + + // If true, uses Device API (PjRt) for compilation and execution everywhere + // i.e. for functions marked for JIT compilation, for ops in "on-demand" + // mode and auto-clustering. Defaults to false. + // + // Note that this flag can be overridden by device flag like + // `enabled_for_gpu_` below. + bool enabled_for_all_; + + // If true, enable Device API (PjRt) for TF GPU device. This is a helper + // flag so that individual tests can turn on PjRt for GPU specifically. + // Once the rollout to GPU is complete, this flag can be deprecated. + bool enabled_for_gpu_; + + private: + // Devices for which using Device API (PjRt) is allowed in the XlaLaunch op. + // This can only be modified programmatically. 
+ absl::flat_hash_set xla_launch_allowed_devices_; + // Devices for which using Device API (PjRt) is allowed in the + // XlaCompileOnDemand op. This can only be modified programmatically. + absl::flat_hash_set xla_compile_on_demand_allowed_devices_; + // Devices for which using Device API (PjRt) is allowed in the + // XlaCompile and XlaRun ops. This can only be modified programmatically. + absl::flat_hash_set xla_compile_and_run_allowed_devices_; + } tf_xla_use_device_api; +}; + +// Flags for the XlaCallModule kernel. +struct XlaCallModuleFlags { + // Used by XlaCallModuleOp to specify safety checks to disable. + absl::flat_hash_set disabled_checks; +}; + +// Flags for the build_xla_ops pass. +struct BuildXlaOpsPassFlags { + // Enables lazy compilation for TF/XLA (only when auto-clustering) if true. + // Defaults to true. + bool tf_xla_enable_lazy_compilation; + + // If true then insert Print nodes to print out values produced by XLA + // clusters. Useful for debugging. + bool tf_xla_print_cluster_outputs; + + // If true, insert CheckNumerics nodes for every floating point typed input to + // an XLA cluster. + bool tf_xla_check_cluster_input_numerics; + + // If true, insert CheckNumerics nodes for every floating point typed output + // from an XLA cluster. + bool tf_xla_check_cluster_output_numerics; + + // Disables all constant folding. The primary use for this is for testing to + // guarantee that tests are run on XLA and not on TF's CPU implementation. + bool tf_xla_disable_constant_folding; + + // Disables full embedding pipelining when true. Instead, strict SparseCore + // TensorCore sequencing will be used. + bool tf_xla_disable_full_embedding_pipelining; + + // Force the WhileOps in embedding_pipelining and embedding_sequencing to use + // this many parallel_iterations + int tf_xla_embedding_parallel_iterations; +}; + +// Flags for common MLIR configurations. +struct MlirCommonFlags { + ConfigProto::Experimental::MlirBridgeRollout tf_mlir_enable_mlir_bridge; + + bool tf_mlir_enable_merge_control_flow_pass; + bool tf_mlir_enable_convert_control_to_data_outputs_pass; + bool tf_mlir_enable_composite_tpuexecute_side_effects; + bool tf_mlir_enable_strict_clusters; + bool tf_mlir_enable_tpu_variable_runtime_reformatting_pass; + // TODO(pineapplejuice233): Revisit this flag once the performance impact is verified + // with different local CPU devices settings. + bool tf_mlir_enable_multiple_local_cpu_devices; +}; + +// Flags for the JitRt pipeline -- see tf_jitrt_pipeline.h for details. +struct JitRtFlags { + bool always_specialize; + bool cost_driven_async_parallel_for; + + // Enables tracking of the "live" JitRt queries to, on a crash, identify the + // "query of death". See TfJitRtQueryOfDeathLogger. + bool log_query_of_death; + + // Enable vectorization, which requires tiling and peeling on different ops. + bool vectorize; + + // Enables crash reproducer for JitRt MLIR pass manager. + bool enable_crash_reproducer; +}; + +// Return a pointer to the DumpGraphFlags struct; +// repeated calls return the same pointer. +// This should be called only after Flags::Parse() has returned. + +// Getters for flags structs defined above. The first call to any of these +// parses TF_XLA_FLAGS for all of them. Those functions which return a pointer +// always return the same pointer. 
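+// An informal usage sketch (flag values are examples only): flags can be set
+// through the TF_XLA_FLAGS environment variable before the first getter call,
+// e.g.
+//
+//   TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_min_cluster_size=4"
+//   TF_XLA_FLAGS="--tf_xla_auto_jit=single-gpu(2)"
+//
+// and read or adjusted programmatically afterwards, e.g.
+//
+//   MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
+//   int64_t fuel = flags->tf_xla_clustering_fuel;
+//   auto& rollout = GetXlaOpsCommonFlags()->tf_xla_use_device_api;
+//   rollout.enabled_for_xla_launch_ = true;
+//   rollout.AllowForDeviceInXlaLaunch(DeviceType("GPU"));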
+MarkForCompilationPassFlags* GetMarkForCompilationPassFlags(); +BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags(); +XlaSparseCoreFlags* GetXlaSparseCoreFlags(); +XlaDeviceFlags* GetXlaDeviceFlags(); +XlaOpsCommonFlags* GetXlaOpsCommonFlags(); +XlaCallModuleFlags* GetXlaCallModuleFlags(); + +MlirCommonFlags* GetMlirCommonFlags(); + +void ResetJitCompilerFlags(); + +const JitRtFlags& GetJitRtFlags(); + +// Returns the effective MLIR bridge rollout state based on the flags and the +// optional configuration. +ConfigProto::Experimental::MlirBridgeRollout GetMlirBridgeRolloutState( + std::optional config_proto); + +// Appends the flag definitions associated with +// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`. +// +// Has the side-effect of parsing TF_XLA_FLAGS if that hasn't happened yet. +void AppendMarkForCompilationPassFlags( + std::vector* flag_list); + +// Disables XLA compilation, forces it to return an error message instead. Can +// be used by a server to ensure that JIT compilation is opt-in. +void DisableXlaCompilation(); + +// Enables XLA compilation. Can be used with `DisableXlaCompilation` to +// enable/disable JIT compilation at different stages. +void EnableXlaCompilation(); + +// Returns `false` unless `DisableXlaCompilation` was called. +bool FailOnXlaCompilation(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/force_xla_constants_on_host_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/force_xla_constants_on_host_pass.h new file mode 100644 index 00000000..ae7cf149 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/force_xla_constants_on_host_pass.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_ + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/jit/compilability_check_util.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// An optimization pass which marks the constants which have to be resolved for +// XLA compilation with `_input_hostmem`. +class ForceXlaConstantsOnHostPass : public GraphOptimizationPass { + public: + ForceXlaConstantsOnHostPass() = default; + + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/get_compiler_ir.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/get_compiler_ir.h new file mode 100644 index 00000000..a4352d11 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/get_compiler_ir.h @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ +#define TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +class ProcessFunctionLibraryRuntime; +class Device; +class Tensor; +class TensorHandle; +class EagerContext; + +enum class IrExportStage { + STABLEHLO, + STABLEHLO_SERIALIZED, + HLO, + HLO_NO_METADATA, + HLO_SERIALIZED, + OPTIMIZED_HLO, + OPTIMIZED_HLO_SERIALIZED, + OPTIMIZED_HLO_PROTO_SERIALIZED, + OPTIMIZED_HLO_DOT +}; + +struct ArgShapeAndDType { + TensorShape shape; + DataType dtype; +}; + +enum class CompilerArgSource { + TENSOR_SPEC, + CONCRETE_INPUT, +}; + +// Returns the IR format of the selected stage for a given function `func_name` +// using library runtime `runtime` on a device `dev` with given +// `inputs_arg_shape_and_dtype` and `input_handles`. +absl::StatusOr GetCompilerIr( + IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, + absl::string_view func_name, Device* dev, EagerContext* context, + absl::Span input_arg_shape_and_dtype, + absl::Span input_handles, + CompilerArgSource compiler_arg_source); + +// Returns the IR format of the selected stage for a given function `func_name` +// using library runtime `runtime` on a platform `platform_name` with given +// `inputs_arg_shape_and_dtype` and `input_handles`. +absl::StatusOr GetCompilerIr( + IrExportStage stage, ProcessFunctionLibraryRuntime* pflr, + absl::string_view func_name, absl::string_view platform_name, + EagerContext* context, + absl::Span input_arg_shape_and_dtype, + absl::Span input_handles, + CompilerArgSource compiler_arg_source); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_GET_COMPILER_IR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h new file mode 100644 index 00000000..23f54afe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Increases the amount of "dynamism" representable by XLA clusters by rewriting +// the TensorFlow graph. This pass does the following rewrites: +// +// Slice +// ----- +// +// Slice(op, begin, size ) => +// Slice(op, begin, actual_size(op.shape(), size, begin)); +// _XlaCompileTimeConstantInputs={2} +// +// where +// +// actual_size(op_shape, size, begin)[i] = +// size[i] == -1 ? (op_shape[i] - size[i]) +// : size[i] +// +// This pass, combined with jit/partially_decluster_pass, reduces the number of +// unnecessary cluster recompilations in some common cases. After the rewrite +// shown above jit/partially_decluster_pass extracts the actual_size(...) +// computation to outside the XLA cluster, causing the cluster to be versioned +// only on the actual size of the XlaDynamicSlice. This avoids recompilation +// due to superficial changes that don't affect tensor shapes. +// +// Future Work TODO(b/111210515) +// ----------------------------- +// +// In the future we will also translate StridedSlice and Pad a similar way. +class IncreaseDynamismForAutoJitPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/kernels/xla_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/kernels/xla_ops.h new file mode 100644 index 00000000..911b5cae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/kernels/xla_ops.h @@ -0,0 +1,140 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ + +#include + +#include "tensorflow/compiler/jit/device_compiler.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + + +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +// +// `has_ref_vars`: whether the input computation can have reference variables. +// TODO(cheshire): instead derive this information from the input graph. +class XlaLocalLaunchBase : public AsyncOpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function, bool has_ref_vars); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + protected: + // Indexes of compile-time constant inputs + const std::vector constants_; + // Indexes of resource inputs + const std::vector resources_; + + const NameAttrList function_; + const XlaPlatformInfo platform_info_; + + bool has_ref_vars_; +}; + +// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph +// which will be compiled and executed using XLA. The XlaLocalLaunchOp is +// responsible for handling interactions with the TensorFlow executor. +// Once all inputs are present, and their shapes are known, the op can +// use a 'DeviceCompiler' to compile and execute code which is specific +// to the shapes of input Tensors. +// XlaLocalLaunchOp uses xla::LocalClient::Compile() and +// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device +// memory. +class XlaLocalLaunchOp : public XlaLocalLaunchBase { + public: + explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); + ~XlaLocalLaunchOp() override; + + private: + XlaLocalLaunchOp(const XlaLocalLaunchOp&) = delete; + void operator=(const XlaLocalLaunchOp&) = delete; +}; + +class XlaCompileOp : public OpKernel { + public: + explicit XlaCompileOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + // Indexes of compile-time constant inputs + const std::vector constants_; + // Indexes of resource inputs + const std::vector resources_; + + const NameAttrList function_; + + XlaPlatformInfo platform_info_; + + const bool must_compile_; + + // Whether the graph has TF reference variables. 
+ const bool has_ref_vars_; + + // cannot_compile_cluster_ is set to true if XLA returns an Unimplemented + // error when compiling the cluster this _XlaCompile is supposed to compile. + // If `cannot_compile_cluster_` is true then we avoid compiling this cluster + // on any future calls to _XlaCompile. + bool cannot_compile_cluster_ TF_GUARDED_BY(cannot_compile_cluster_mu_) = + false; + + mutex cannot_compile_cluster_mu_; +}; + +class XlaRunOp : public OpKernel { + public: + explicit XlaRunOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + const XlaPlatformInfo platform_info_; +}; + +class XlaMergeOp : public OpKernel { + public: + explicit XlaMergeOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/mark_for_compilation_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/mark_for_compilation_pass.h new file mode 100644 index 00000000..558912f2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An optimization passes that marks nodes that are to be compiled with +// attribute kXlaClusterAttr. Nodes with the same cluster ID will be compiled +// together. + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/jit/compilability_check_util.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// Marks a subset of nodes in the graph which are to be clustered +// with an attribute _XlaCluster= so they are picked up by the +// EncapsulateSubgraphsPass. +class MarkForCompilationPass : public GraphOptimizationPass { + public: + MarkForCompilationPass() = default; + + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + private: + absl::Status RunForTest(const GraphOptimizationPassOptions& options, + bool disable_deadness_analysis, + bool deterministic_cluster_names); + + friend class MarkForCompilationPassTestHelper; +}; + +absl::flat_hash_map>* GetAllowlistTable(); + +namespace testing { +// DO NOT USE IN PRODUCTION. +// +// Resets some internal state to let us write reliable unit tests. +void ResetClusterSequenceNumber(); + +// Return a list of operation that we choose not to put into the allowlist. 
+absl::flat_hash_set GetKnownXLAAllowlistOp(); +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h new file mode 100644 index 00000000..84d24898 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +namespace tensorflow { +class MarkForCompilationPassTestHelper { + public: + struct Options { + bool enable_global_jit; + bool disable_deadness_analysis; + bool enable_cluster_scoping; + bool deterministic_cluster_names; + std::string session_name; // ConfigProto.Experimental.SessionMetadata.name + + Options() + : enable_global_jit(true), + disable_deadness_analysis(true), + enable_cluster_scoping(true), + deterministic_cluster_names(false) {} + + Options WithNoGlobalJit() { + Options copy = *this; + copy.enable_global_jit = false; + return copy; + } + + Options WithDeadnessAnalysis() { + Options copy = *this; + copy.disable_deadness_analysis = false; + return copy; + } + + Options WithNoClusterScoping() { + Options copy = *this; + copy.enable_cluster_scoping = false; + return copy; + } + + Options WithDeterministicClusterNames() { + Options copy = *this; + copy.deterministic_cluster_names = true; + return copy; + } + + Options WithSessionName(std::string name) { + Options copy = *this; + copy.session_name = std::move(name); + return copy; + } + }; + + // Runs the MarkForCompilation pass on `graph` after assigning all nodes in + // `graph` to the CPU device. To make testing easier, ignores device + // registration and _XlaCompile attributes. + static absl::Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + Options options = Options()); + + // Like `MarkForCompilation` but creates `flib_def` from the op registry. + static absl::Status MarkForCompilation(std::unique_ptr* graph, + Options options = Options()); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/node_matchers.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/node_matchers.h new file mode 100644 index 00000000..a0208680 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/node_matchers.h @@ -0,0 +1,251 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides a set of matchers for tensorflow nodes. +// +// Example usage: +// +// tensorflow::Node* node = ...; +// EXPECT_THAT(node, NodeWith(Name("name"), Op("op"), +// Inputs(Out(3, NodeWith(Name("input")))))) +// +// Matchable node properties (the expressions that go inside NodeWith(...)) +// are: +// +// - Name(string): matches the node name exactly. We will probably need to +// have this take a string matcher soon in the future. +// +// - Op(string): matches the op exactly. +// +// - AssignedDevice(string): matches the assigned device exactly. +// +// - Inputs(): matches the list of non-control inputs to the node +// exactly (i.e. does not match a suffix or a prefix) where each element +// matches an output of a node (see Out(idx, node) below). +// +// - CtrlDeps(): matches the list of control dependences on the +// node exactly but in any order. +// +// - ConstantValue(tensorflow::Input::Initializer init): matches a Const node +// with the constant value `init`. Implies Op("Const"). +// +// - Attr(name, value): Matches a single attribute with name `name` and value +// `value`. Right now only boolean values are supported. +// +// Overlapping node properties may not be repeated in a single NodeWith(...) +// matcher. E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since +// ConstantValue implies Op("Const"), a single NodeWith matcher can't have both +// ConstantValue(...) and Op(...). Multiple Attr() values can be combined as +// long as the attribute names are different. +// +// Out(idx, node) matches the `idx`'th output of a node that matches `node`. + +#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ +#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/ops.h" +#include "xla/test.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace testing { +namespace matchers { + +namespace impl { + +using OutEdge = std::pair; + +// ----------------------------------------------------------------------------- +// Implementation details. + +// Properties that we match on for a particular Node. If a particular property +// is nullopt then any value for it is allowed. 
+class NodeMatcherProperties { + public: + using NodeSeqMatcher = std::vector<::testing::Matcher>; + using InputSeqMatcher = std::vector<::testing::Matcher>; + using AttrKeyValuePair = std::pair>; + + const std::optional& name() const { return name_; } + const std::optional& op() const { return op_; } + const std::optional& assigned_device() const { + return assigned_device_; + } + const std::optional& constant_value() const { + return constant_value_; + } + const std::optional& inputs() const { + return input_matchers_; + } + const std::optional& control_deps() const { + return control_deps_; + } + const std::optional& attr() const { return attr_; } + + void set_name(string name) { + DCHECK(IsEmpty()); + name_ = std::move(name); + } + + void set_op(string op) { + DCHECK(IsEmpty()); + op_ = std::move(op); + } + + void set_assigned_device(string assigned_device) { + DCHECK(IsEmpty()); + assigned_device_ = std::move(assigned_device); + } + + void set_constant_value(Tensor constant_value) { + DCHECK(IsEmpty()); + constant_value_ = std::move(constant_value); + op_ = "Const"; + } + + void set_inputs(InputSeqMatcher inputs) { + DCHECK(IsEmpty()); + input_matchers_ = std::move(inputs); + } + + void set_control_deps(NodeSeqMatcher control_deps) { + DCHECK(IsEmpty()); + control_deps_ = std::move(control_deps); + } + + void set_attr(AttrKeyValuePair attr) { + DCHECK(IsEmpty()); + attr_ = std::move(attr); + } + + bool IsEmpty() const { + return !name().has_value() && !op().has_value() && !inputs().has_value() && + !control_deps().has_value() && !attr().has_value(); + } + + private: + std::optional name_; + std::optional op_; + std::optional assigned_device_; + std::optional constant_value_; + std::optional input_matchers_; + std::optional control_deps_; + std::optional attr_; +}; + +::testing::Matcher NodeWith( + absl::Span props); + +impl::NodeMatcherProperties Inputs( + absl::Span> inputs); + +impl::NodeMatcherProperties CtrlDeps( + absl::Span> control_deps); + +impl::NodeMatcherProperties Attr(std::pair attrs); +impl::NodeMatcherProperties Attr(string name); + +std::pair AttrLiteralHelper( + const std::pair& bool_attr); + +std::pair AttrLiteralHelper( + const std::pair>& int_list_attr); + +std::pair AttrLiteralHelper( + const std::pair>& string_list_attr); +} // namespace impl + +// ----------------------------------------------------------------------------- +// Public interface. + +// Matches a node with name `name`. +impl::NodeMatcherProperties Name(string name); + +// Matches a node with op `op`. +impl::NodeMatcherProperties Op(string op); + +// Matches a node with assigned device `assigned_device`. +impl::NodeMatcherProperties AssignedDevice(string assigned_device); + +// Matches a node with a boolean typed attribute named `name` and with value +// `value`. +template +impl::NodeMatcherProperties Attr(const string& name, ValueTy value) { + return impl::Attr({impl::AttrLiteralHelper({name, value})}); +} + +inline impl::NodeMatcherProperties Attr(const string& name) { + return impl::Attr(name); +} + +// Matches a node with inputs `inputs`. +// +// `inputs` are ordered; `inputs`[i] must match input i. +template +impl::NodeMatcherProperties Inputs(Ts... inputs) { + return impl::Inputs({inputs...}); +} + +// Matches the `idx`'th output of a node that matches `node`. +::testing::Matcher Out(int oidx, + ::testing::Matcher node); + +// Matches the first output of a node that matches `node`. 
+inline ::testing::Matcher Out( + ::testing::Matcher node) { + return Out(0, node); +} + +// Matches a node with control dependences `control_deps`. +// +// `control_deps` are unordered and will match the control deps of a node in any +// order. +template +impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) { + return impl::CtrlDeps({control_deps...}); +} + +// Matches a constant node with value `val`. +impl::NodeMatcherProperties ConstantValue( + const ::tensorflow::Input::Initializer& val); + +// The main gmock matcher. See file comment for example usage. +template +::testing::Matcher NodeWith(Ts... args) { + std::array array = {args...}; + return impl::NodeWith(array); +} + +::testing::Matcher Const( + const ::tensorflow::Input::Initializer& val); +} // namespace matchers + +// If `g` has a node named `name` returns it, otherwise returns null. +Node* FindNodeByName(Graph* g, absl::string_view name); +} // namespace testing + +void PrintTo(const Node* n, ::std::ostream* os); +void PrintTo(Node* n, ::std::ostream* os); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/partially_decluster_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/partially_decluster_pass.h new file mode 100644 index 00000000..18b0091c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/partially_decluster_pass.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Clones or moves nodes from within a cluster to outside the cluster if +// profitable. There are two reasons why we do this: +// +// - Reducing device-to-host copies. +// - Reducing the number of XLA recompilations. +class PartiallyDeclusterPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_base_device.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_base_device.h new file mode 100644 index 00000000..b2135745 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_base_device.h @@ -0,0 +1,112 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_BASE_DEVICE_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_BASE_DEVICE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_base.h" + +namespace tensorflow { + +// tensorflow::PjRtBaseDevice replaces the deprecated tensorflow::XlaDevice. +// This accelerator agnostic device is mainly used to store metadata. +class PjRtBaseDevice : public LocalDevice { + public: + // Stores metadata about the PjRtBaseDevice. + class Metadata { + public: + Metadata(const DeviceType& jit_device_type, + std::vector + shape_determination_fns) + : jit_device_type_(jit_device_type), + shape_determination_fns_(std::move(shape_determination_fns)) {} + + // The index of the device on this host. + int device_ordinal() const; + + const DeviceType& jit_device_type() const { return jit_device_type_; } + const XlaShapeLayoutHelpers::ShapeDeterminationFns& + default_shape_determination_fns() const { + return shape_determination_fns_.at(0); + } + + const XlaShapeLayoutHelpers::ShapeDeterminationFns& + shape_determination_fns_at(int i) const { + return shape_determination_fns_[i]; + } + + private: + const DeviceType jit_device_type_; + std::vector + shape_determination_fns_; + + Metadata(const Metadata&) = delete; + void operator=(const Metadata&) = delete; + }; + + struct Options { + // The device name's prefix (e.g., "/task:7") + std::string device_name_prefix; + + // The name of the device (e.g., "TPU") + std::string device_name; + + // The index of the device. + int device_ordinal = -1; + + // The name of the compilation device, also referred to as jit_device_type. + // (e.g., "XLA_CPU_JIT"); + std::string compilation_device_name; + + // A vector of ShapeDeterminationFn (i.e., a bundle of LayoutSelectionFn, + // ShapeRepresentationFn). Each bundle describes how the on-host shapes of + // a) argument and return value, for entry computations b) variables, for + // all computations, should be represented in XLA. Parameters/return values + // will be shaped according to the function pair, and reshaped back to/from + // their declared shapes for computations. Must be non-empty. + std::vector + shape_determination_fns; + + Options(std::string device_name_prefix, std::string device_name, + int device_ordinal, std::string compilation_device_name, + std::vector + shape_determination_fns) + : device_name_prefix(device_name_prefix), + device_name(device_name), + device_ordinal(device_ordinal), + compilation_device_name(compilation_device_name), + shape_determination_fns(shape_determination_fns) {} + }; + + // Creates a new PJRT base device. + PjRtBaseDevice(const SessionOptions& session_options, const Options& options); + + static absl::StatusOr GetMetadataFromDevice( + DeviceBase* device); + + private: + // The metadata of this PjRtBaseDevice. 
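+  // An informal usage sketch (the `device` pointer below is hypothetical; the
+  // exact StatusOr payload type is the one declared for GetMetadataFromDevice
+  // above):
+  //
+  //   TF_ASSIGN_OR_RETURN(auto* metadata,
+  //                       PjRtBaseDevice::GetMetadataFromDevice(device));
+  //   const DeviceType& jit_type = metadata->jit_device_type();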
+ const Metadata metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_BASE_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_compile_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_compile_util.h new file mode 100644 index 00000000..11645651 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_compile_util.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_COMPILE_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_COMPILE_UTIL_H_ + +#include "tensorflow/compiler/jit/xla_compile_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Compiles a `function` to PjRtLoadedExecutable `executable` with `ctx`. +// The compilation result is output in `compilation_result`. The PJRT client +// used for compilation is output in `client`. The PJRT executable is output in +// `executable`. +absl::Status CompileToPjRtLoadedExecutable( + const OpKernelContext& ctx, const XlaPlatformInfo& platform_info, + const NameAttrList& function, + const std::vector& args, + DeviceCompileMode compile_mode, bool has_ref_vars, + bool may_alias_resource_update, + const XlaCompiler::CompilationResult** compilation_result, + xla::PjRtClient** client, xla::PjRtLoadedExecutable** executable); + +// Similar to the above function but it does not take a OpKernelContext. +// Instead, it takes the following arguments that are obtained from +// OpKernelContext in the above function. +// - `device`: the device used to compile the function. +// - `rm`: the resource manager for DeviceCompiler to store JIT-compiled XLA +// computation. +// - `flr`: the FunctionLibraryRuntime for the `function`. 
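+// An informal usage sketch of the overload declared below (error handling and
+// the setup of `device`, `platform_info`, `function`, `args`, `flr` and `rm`
+// are assumed to exist in the caller):
+//
+//   const XlaCompiler::CompilationResult* compilation_result = nullptr;
+//   xla::PjRtClient* client = nullptr;
+//   xla::PjRtLoadedExecutable* executable = nullptr;
+//   TF_RETURN_IF_ERROR(CompileToPjRtLoadedExecutable(
+//       device, platform_info, function, args, DeviceCompileMode::kStrict,
+//       /*has_ref_vars=*/true, /*may_alias_resource_update=*/true, flr, rm,
+//       &compilation_result, &client, &executable));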
+absl::Status CompileToPjRtLoadedExecutable( + const DeviceBase* device, const XlaPlatformInfo& platform_info, + const NameAttrList& function, + const std::vector& args, + DeviceCompileMode compile_mode, bool has_ref_vars, + bool may_alias_resource_update, FunctionLibraryRuntime* flr, + ResourceMgr* rm, const XlaCompiler::CompilationResult** compilation_result, + xla::PjRtClient** client, xla::PjRtLoadedExecutable** executable); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_COMPILE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_device_compiler_client.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_device_compiler_client.h new file mode 100644 index 00000000..8c590b57 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_device_compiler_client.h @@ -0,0 +1,85 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_COMPILER_CLIENT_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_COMPILER_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "xla/pjrt/pjrt_client.h" + +namespace tensorflow { + +// Calls into PjRtClient to provide functionality for building, serializing and +// loading PjRtLoadedExecutables. +class PjRtDeviceCompilerClient + : public DeviceCompilerClient { + public: + explicit PjRtDeviceCompilerClient(xla::PjRtClient* client) + : client_(client) {} + + absl::StatusOr> BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // Returns a platform-specific serialization of `executable`. The + // serialization is not guaranteed to be stable over time. `executable` must + // have been produced by this client. + absl::StatusOr SerializeExecutable( + const xla::PjRtLoadedExecutable& executable) override; + + // PjRt doesn't support AOT compilation yet. Builds a PjRtLoadedExecutable and + // serializes it to string. + absl::StatusOr BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // Deserializes a serialized executable as produced by + // PjRtExecutable::SerializeExecutable(). `serialized_executable` must have + // been produced by a compiler of the same platform and version as this one. + // + // PjRt doesn't support AOT compilation yet. Loading a serialized executable + // is currently only implemented for TfrtTpuPjrtClient and hence, this + // function doesn't use PjRtClient::LoadSerializedExecutable() and uses + // PjRtClient::DeserializeExecutable() instead. + absl::StatusOr> LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) override; + + // No-op. PJRT uses futures and waiting for programs to finish isn't + // necessary. 
+ void WaitForProgramsToFinish() override; + + xla::PjRtClient* client() const override { return client_; } + + private: + xla::PjRtClient* const client_; + + PjRtDeviceCompilerClient(const PjRtDeviceCompilerClient&) = delete; + void operator=(const PjRtDeviceCompilerClient&) = delete; +}; + +// Generates CompileOptions for PJRT compilation. +xla::CompileOptions GetPjRtCompileOptions( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_COMPILER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_device_context.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_device_context.h new file mode 100644 index 00000000..7637d396 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_device_context.h @@ -0,0 +1,64 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Helper class for managing data transfers between host and accelerator +// devices using PjRt. +class PjRtDeviceContext : public DeviceContext { + public: + explicit PjRtDeviceContext( + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + bool use_pjrt_tensor_buffer = false) + : shape_determination_fns_(std::move(shape_determination_fns)), + use_pjrt_tensor_buffer_(use_pjrt_tensor_buffer) {} + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; + + bool use_pjrt_tensor_buffer() const { return use_pjrt_tensor_buffer_; } + + private: + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns_; + // Note: we currently assume the PjRtBuffer is a PjRtStreamExecutorBuffer. 
+ bool use_pjrt_tensor_buffer_; +}; + +void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, + DeviceContext* recv_dev_context, Device* src, + Device* dst, AllocatorAttributes src_alloc_attr, + AllocatorAttributes dst_alloc_attr, + const Tensor* input, Tensor* output, + int dev_to_dev_stream_index, StatusCallback done); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_DEVICE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_tensor_buffer.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_tensor_buffer.h new file mode 100644 index 00000000..0dd496c9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_tensor_buffer.h @@ -0,0 +1,57 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_TENSOR_BUFFER_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_TENSOR_BUFFER_H_ + +#include +#include + +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// PjRtTensorBuffer is derived from TensorBuffer, which holds a device memory +// pointer so that legacy TF kernel can access it directly. PjRtTensorBuffer +// also owns a PjRtBuffer for XLA kernel's usage. +class PjRtTensorBuffer : public TensorBuffer { + public: + PjRtTensorBuffer(const void* ptr, size_t expected_size, + std::unique_ptr pjrt_buffer) + : TensorBuffer(const_cast(ptr)), + expected_size_(expected_size), + pjrt_buffer_(std::move(pjrt_buffer)) {} + + size_t size() const override { return expected_size_; } + + TensorBuffer* root_buffer() override { return this; } + + xla::PjRtBuffer* pjrt_buffer() const { return pjrt_buffer_.get(); } + + // TODO(b/288965065): Implement this. + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_requested_bytes(static_cast(expected_size_)); + } + + private: + size_t expected_size_; + std::unique_ptr pjrt_buffer_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_TENSOR_BUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_tensor_buffer_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_tensor_buffer_util.h new file mode 100644 index 00000000..f73834b5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/pjrt_tensor_buffer_util.h @@ -0,0 +1,56 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PJRT_TENSOR_BUFFER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_PJRT_TENSOR_BUFFER_UTIL_H_ + +#include + +#include "absl/status/statusor.h" +#include "tensorflow/compiler/jit/pjrt_tensor_buffer.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +// Takes the device memory pointer from the PjRtBuffer and create a Tensor that +// contains a PjRtTensorBuffer. The PjRtTensorBuffer holds the pointer to the +// device memory. It also owns the PjRtBuffer. +// +// TODO(b/289001822): Create a unit test to cover this function. +absl::StatusOr MakeTensorFromPjRtBuffer( + DataType dtype, const TensorShape& shape, + std::unique_ptr pjrt_buffer); + +// For TensorFlow internal use only. +class PjRtTensorBufferUtil { + public: + // Takes the device memory pointer from the PjRtBuffer and create a + // PjRtTensorBuffer. The PjRtTensorBuffer holds the pointer to the device + // memory. It also owns the PjRtBuffer. If output_tensor does not use + // PjRtTensorBuffer and the opaque device memory is the same, update the + // output_tensor->buf_ so that the same device memory will not be double-free. + // Otherwise a new Tensor will be created with the PjRtTensorBuffer. + // + // TODO(b/289001822): Create a unit test to cover this function. + static absl::Status UpdateOrMakeTensorWithPjRtBuffer( + DataType dtype, const TensorShape& shape, + std::unique_ptr pjrt_buffer, Tensor* output_tensor); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PJRT_TENSOR_BUFFER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/report_clustering_info_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/report_clustering_info_pass.h new file mode 100644 index 00000000..2ac67bf1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/report_clustering_info_pass.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// This is not really an optimization pass. It does not change the graph in any +// way; instead it computes a summary of the XLA clusters in the graph and +// broadcasts it via xla_activity_listener. 
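A pass like the one declared next is hooked into the graph optimization pipeline with the registration macro from optimization_registry.h. The following is only an illustrative sketch: the grouping and the phase number 50 are assumptions, not necessarily the values the upstream pass registers with.

#include "tensorflow/compiler/jit/report_clustering_info_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Illustrative registration; the real grouping and phase may differ.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 50,
                      ReportClusteringInfoPass);

}  // namespace tensorflow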
+class ReportClusteringInfoPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_REPORT_CLUSTERING_INFO_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/resource_operation_safety_analysis.h new file mode 100644 index 00000000..eea18fb1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ + +#include "xla/service/graphcycles/graphcycles.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +// An XLA cluster hoists all resource reads to be beginning of the cluster +// execution and all the resource writes to the end. This means it cannot +// enforce arbitrary ordering dependencies (via control or data edges) between +// resource operations. Since all resource reads happen before all resource +// writes, edges constraining resource writes to happen before resource reads +// are problematic. This analysis returns the set of pairs of resource +// operations that cannot be put in the same cluster because XLA cannot respect +// the dependencies between them in the TensorFlow program. +// +// The restrictions are not transitive: it is fine to put A and C in the same +// cluster even if the returned set contains (A,B) and (B,C). +// +// In other words, if these pairs are seen as edges in an undirected graph of +// the nodes in `g` then auto-clustering is at least as constrained as the graph +// coloring problem on this graph. +// +// +// For instance if we auto-cluster all operations in this TensorFlow graph: +// +// AssignVariablepOp0 -> AssignVariableOp1 +// | +// v +// ReadVariableOp0 -> ReadVariableOp1 +// +// we will lose the AssignVariablepOp1 -> ReadVariableOp0. The ReadVariableOp0 +// -> ReadVariableOp1 and AssignVariableOp0 -> AssignVariableOp1 edges will be +// respected by XlaLaunchOp though because all reads happen before all writes +// with that limited clustering.. +// +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// back-edges (i.e. the edges from NextIteration to Merge). +// +// NB! The result computed by this analysis assumes that we don't auto-cluster +// functional control flow nodes containing resource operations. +// +// If `resource_ops_to_ignore` is set then nodes for which it returns true are +// ignored (we pretend these nodes are not resource operations). 
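A minimal sketch of invoking the analysis declared immediately below, assuming the template arguments elided in this header dump are std::function<bool(const Node&)> for `resource_ops_to_ignore` and std::vector<std::pair<int, int>> for `result`; `FindUnclusterablePairs` is a hypothetical wrapper and the graph is assumed to come from the caller.

#include <utility>
#include <vector>

#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

absl::Status FindUnclusterablePairs(const Graph& graph,
                                    const FunctionLibraryDefinition* flib_def) {
  std::vector<std::pair<int, int>> incompatible_pairs;
  // Passing an empty std::function means no resource ops are ignored.
  TF_RETURN_IF_ERROR(ComputeIncompatibleResourceOperationPairs(
      graph, flib_def, /*resource_ops_to_ignore=*/{}, &incompatible_pairs));
  // Each pair names two resource operations that must not be placed in the
  // same XLA cluster, per the contract described in the comment above.
  return absl::OkStatus();
}

}  // namespace tensorflow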
+absl::Status ComputeIncompatibleResourceOperationPairs( + const Graph& g, const FunctionLibraryDefinition* flib_def, + const std::function& + resource_ops_to_ignore, + std::vector>* result); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/shape_inference.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/shape_inference.h new file mode 100644 index 00000000..467ecb83 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/shape_inference.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ +#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +struct InferredShape { + // Shape of the argument tensor. + PartialTensorShape shape; + + // If the argument is a resource variable, the type and shape of the + // variable's value. + DataType handle_type = DT_INVALID; + PartialTensorShape handle_shape; +}; +typedef std::unordered_map> GraphShapeInfo; + +// Infer shapes for all Tensors in a graph, and save them in a map. The vector +// for a Node contains the information about each of its outputs. +// TODO(phawkins): this code does not infer accurate shapes for cyclic graphs. +// `arg_shapes`: user given map from the `index` to shapes of this +// node, where `index` is the `index` attribute of `_Arg` op or `_index` +// attribute of `Placeholder` op. +absl::Status InferShapes(Graph* graph, + const std::map& arg_shapes, + const tensorflow::FunctionLibraryDefinition* fnlib_def, + GraphShapeInfo* shape_info); + +// Merges two InferredShapes. Return an error if the two shapes cannot be +// merged. +absl::StatusOr MergeInferredShapes(const InferredShape& a, + const InferredShape& b); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/shape_inference_helpers.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/shape_inference_helpers.h new file mode 100644 index 00000000..d4c81954 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/shape_inference_helpers.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ +#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Helper class to temporarily remove, then replace, the back edges in a +// graph. Simple algorithms for shape inference don't work with cycles, and this +// class can be used to remove cycles before running inference and replace them +// after. Correct usage requires exactly one call to Remove(), followed by any +// number of calls to RemovedEdges() and at most one call to Replace(). The call +// to Replace() is optional if the graph will be discarded without being +// executed, e.g., if it is being used purely for a shape inference pass. +class BackEdgeHelper { + public: + struct BackEdge { + const Edge* edge; + Node* src; + int src_output; + Node* dst; + int dst_input; + }; + + BackEdgeHelper() = default; + // Disallows copy and assign. + BackEdgeHelper(const BackEdgeHelper& other) = delete; + BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete; + + // Temporarily removes all the back edges in graph. + absl::Status Remove(Graph* graph); + + // Gets the list of removed edges. + const std::vector& RemovedEdges() const; + + // Replaces the back edges removed by a prior call to Remove. + absl::Status Replace(); + + private: + Graph* graph_ = nullptr; // not owned + std::vector back_edges_; + // Set once Replace has been called. + bool replaced_ = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/test_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/test_util.h new file mode 100644 index 00000000..ec694662 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/test_util.h @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for tests. 
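To illustrate the test helpers declared further down in this header, here is a hedged sketch of driving a graph optimization pass in a unit test; `MyOptimizationPass` is a hypothetical GraphOptimizationPass subclass and the graph construction is elided.

#include <memory>

#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

absl::Status RunPassForTest() {
  GraphOptimizationPassWrapper wrapper;
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  // ... populate `graph` with test nodes here ...

  GraphOptimizationPassOptions options =
      wrapper.CreateGraphOptimizationPassOptions(&graph);

  MyOptimizationPass pass;  // hypothetical pass under test
  return pass.Run(options);
}

}  // namespace tensorflow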
+ +#ifndef TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// Tests that the shapes in 'shape_info' for the nodes in `graph` match +// `expected_shapes`. Returns an error if there are nodes in `expected_shapes` +// that do not have shape information. Ignores nodes in `graph` that do not have +// `expected_shapes` entries. +absl::Status ShapeAnnotationsMatch( + const Graph& graph, const GraphShapeInfo& shape_info, + std::map> expected_shapes); + +// A helper object to create GraphOptimizationPassOptions. +struct GraphOptimizationPassWrapper { + explicit GraphOptimizationPassWrapper() + : library(OpRegistry::Global(), FunctionDefLibrary()) { + session_options.env = Env::Default(); + } + + // Create GraphOptimizationPassOptions with a graph passed in constructor and + // sensible options. + GraphOptimizationPassOptions CreateGraphOptimizationPassOptions( + std::unique_ptr* graph) { + GraphOptimizationPassOptions options; + options.session_options = &session_options; + options.flib_def = &library; + options.graph = graph; + return options; + } + + FunctionLibraryDefinition library; + SessionOptions session_options; +}; + +// Helps set up devices for unit tests. +class DeviceSetup { + public: + void AddDevicesAndSetUp( + const std::vector& device_names, + const std::optional& fdef = std::nullopt); + Device* GetDevice(const string& device_name); + FunctionLibraryRuntime* flr() { return flr_; } + + private: + FunctionLibraryRuntime* flr_; + std::unique_ptr device_mgr_; + std::unique_ptr lib_def_; + std::unique_ptr pflr_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h new file mode 100644 index 00000000..4750803c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_TESTS_AUTO_CLUSTERING_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_TESTS_AUTO_CLUSTERING_TEST_HELPER_H_ + +#include "absl/status/statusor.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +// Helper to write integration tests and benchmarks for the auto-clustering pass +// pipeline. These tests run auto-clustering on a graphdef and compare a +// summary of the auto-clustering decisions with a "golden" summary. +// +// To create a new test from an TF workload first run the workload with the +// following environment variables set: +// +// TF_DUMP_GRAPH_PREFIX= +// TF_XLA_FLAGS="--tf_xla_clustering_debug" +// +// If auto-clustering is enabled this should produce files named +// before_mark_for_compilation_.pbtxt in the temporary directory. As the +// file name suggests, these are graphdefs that have been dumped right before +// the mark_for_compilation pass. There should be one +// before_mark_for_compilation_.pbtxt for every TF graph that was +// auto-clustered, out of which usually only one is the "main" graph that's +// running training/inference. +// +// Copy the pbtxt for that "main" graph to tensorflow/compiler/jit/tests/ +// (i.e. this directory) and create a corresponding empty .golden_summary file. +// Add the .pbtxt and .golden_summary files to the "data" section of the cc_test +// rule for :auto_clustering_test and then see the comment on update_golden on +// how to auto-generate the .golden_summary file. + +class AutoClusteringTest : public ::testing::Test { + protected: + absl::Status RunAutoClusteringTestWithPbtxt( + absl::string_view pbtxt_file_path, + absl::string_view golden_summary_file_path); + absl::Status RunAutoClusteringTestWithGzippedPbtxt( + absl::string_view gzipped_pbtxt_file_path, + absl::string_view golden_summary_file_path); + + private: + absl::Status RunAutoClusteringTestImpl( + GraphDef graphdef, absl::string_view golden_summary_file_path); +}; + +#if defined(PLATFORM_GOOGLE) +// Reads the GraphDef stored in graph_def_path (which must be a pbtxt file) and +// benchmarks MarkForCompilationPass on this graphdef. +absl::Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, + benchmark::State& state); +#endif // PLATFORM_GOOGLE + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_TESTS_AUTO_CLUSTERING_TEST_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/tests/device_compiler_test_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/tests/device_compiler_test_helper.h new file mode 100644 index 00000000..58e0a034 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/tests/device_compiler_test_helper.h @@ -0,0 +1,104 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_TESTS_DEVICE_COMPILER_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_TESTS_DEVICE_COMPILER_TEST_HELPER_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/jit/xla_activity_listener.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// A listener to inspect the use of XLA's persistent compilation cache entries. +class JitCompilationListener : public XlaActivityListener { + public: + absl::Status Listen( + const XlaAutoClusteringActivity& auto_clustering_activity) override { + return absl::OkStatus(); + } + + absl::Status Listen( + const XlaJitCompilationActivity& jit_compilation_activity) override { + activity_history_.push_back(jit_compilation_activity); + return absl::OkStatus(); + } + + absl::Status Listen( + const XlaOptimizationRemark& optimization_remark) override { + return absl::OkStatus(); + } + + ~JitCompilationListener() override = default; + + absl::Status VerifyPersistentCacheUseListenerHistory( + bool expect_persistent_cache_use) { + for (const auto& activity : activity_history_) { + if (activity.used_persistent_cache() != expect_persistent_cache_use) { + return absl::FailedPreconditionError("Unexpected listener history."); + } + } + return absl::OkStatus(); + } + + std::vector GetListenerHistory() { + return activity_history_; + } + + void ClearListenerHistory() { activity_history_.clear(); } + + private: + std::vector activity_history_; +}; + +// Fixture for testing XLA compilation cache serialization. +class DeviceCompilerSerializeTest : public ::testing::Test { + protected: + DeviceCompilerSerializeTest() { + auto listener = std::make_unique(); + listener_ = listener.get(); + RegisterXlaActivityListener(std::move(listener)); + } + + JitCompilationListener* listener() const { return listener_; } + + // Returns a test graph that will split into two XLA clusters (due to a node + // with _XlaCompile = false). + GraphDef GetTestGraph(const PartialTensorShape& input_shape); + + // Runs the graph using specified batch size both with and without XLA JIT + // compilation. Returns an error if the results between the two do not match. + absl::Status ExecuteWithBatch(const GraphDef& graph, int batch); + + // Adds the suffix "_altered" to the HLO module names of all of the persistent + // XLA compilation cache entries found at the specified directory. If none are + // found, returns NOT_FOUND error. + absl::Status AlterPersistentCacheEntryHloModuleNames( + absl::string_view persistent_cache_dir_path, + absl::string_view file_prefix = "xla_compile_cache"); + + private: + JitCompilationListener* listener_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_TESTS_DEVICE_COMPILER_TEST_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h new file mode 100644 index 00000000..adc2a74e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h @@ -0,0 +1,60 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_TF_GRAPH_TO_HLO_COMPILER_H_ +#define TENSORFLOW_COMPILER_JIT_TF_GRAPH_TO_HLO_COMPILER_H_ + +#include +#include + +#include "tensorflow/compiler/jit/tf_to_hlo_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" + +namespace tensorflow { + +class TfGraphToHloCompiler : public TfToHloCompiler { + public: + TfGraphToHloCompiler() = delete; + + explicit TfGraphToHloCompiler(const XlaCompiler::Options& options) + : xla_compiler_(options) {} + + // Compiles a Tensorflow `function` into an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result` by calling + // XlaCompiler::CompileFunction. + absl::Status Compile(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + absl::Span args, + XlaCompilationResult* result) override; + + // Compiles a Tensorflow single op into an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result` by calling + // XlaCompiler::CompileSingleOp. + absl::Status CompileSingleOp(const XlaCompiler::CompileOptions& options, + const OpKernelContext* ctx, + absl::Span args, + XlaCompilationResult* result) override; + + private: + XlaCompiler xla_compiler_; + + TfGraphToHloCompiler(const TfGraphToHloCompiler&) = delete; + void operator=(const TfGraphToHloCompiler&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_TF_GRAPH_TO_HLO_COMPILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/tf_to_hlo_compiler.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/tf_to_hlo_compiler.h new file mode 100644 index 00000000..f9937a65 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/tf_to_hlo_compiler.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_TF_TO_HLO_COMPILER_H_ +#define TENSORFLOW_COMPILER_JIT_TF_TO_HLO_COMPILER_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class TfToHloCompiler { + public: + TfToHloCompiler() = default; + virtual ~TfToHloCompiler() = default; + + // Compiles a Tensorflow `function` to an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result`. 
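As a sketch of how the abstract interface declared below pairs with the concrete TfGraphToHloCompiler declared above: callers can own the concrete compiler through the base class and call Compile on it. This assumes the span element type elided in this header dump is XlaCompiler::Argument; the surrounding helper is hypothetical.

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h"
#include "tensorflow/compiler/jit/tf_to_hlo_compiler.h"

namespace tensorflow {

absl::Status CompileFunctionToHlo(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompileOptions& compile_options,
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    XlaCompilationResult* result) {
  // Hold the concrete compiler through the abstract interface.
  std::unique_ptr<TfToHloCompiler> compiler =
      std::make_unique<TfGraphToHloCompiler>(options);
  return compiler->Compile(compile_options, function, args, result);
}

}  // namespace tensorflow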
+ virtual absl::Status Compile(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + absl::Span args, + XlaCompilationResult* result) = 0; + + // Compiles a Tensorflow single op to an HloModuleProto stored in the + // XlaCompilationResult pointed to by `result`. + virtual absl::Status CompileSingleOp( + const XlaCompiler::CompileOptions& options, const OpKernelContext* ctx, + absl::Span args, XlaCompilationResult* result) = 0; + + private: + TfToHloCompiler(const TfToHloCompiler&) = delete; + void operator=(const TfToHloCompiler&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_TF_TO_HLO_COMPILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/variable_info.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/variable_info.h new file mode 100644 index 00000000..9294c5e4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/variable_info.h @@ -0,0 +1,95 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_VARIABLE_INFO_H_ +#define TENSORFLOW_COMPILER_JIT_VARIABLE_INFO_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/resource_var.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// Information about the state of a variable passed as input to the _XlaCompile +// and _XlaRun operators. Unlocks the resource variable and decrements its +// refcount on destruction. +class VariableInfo { + public: + explicit VariableInfo(int index, absl::string_view name, Var* var, + const std::optional& + definition_stack_trace = std::nullopt); + VariableInfo(VariableInfo&& other); + + VariableInfo& operator=(VariableInfo&& other); + + VariableInfo(const VariableInfo&) = delete; + VariableInfo& operator=(const VariableInfo&) = delete; + + // The index of the DT_RESOURCE input to the _XlaCompile/_XlaRun operator. + // Note that the indices can be different between _XlaCompile and _XlaRun. + int index() const { return index_; } + + // A pointer to the resource variable. May be null if this VariableInfo is + // "empty", i.e. it does not track a resource variable. + Var* var() const { return var_; } + + // Returns the variable name. + absl::string_view name() const { return name_; } + + // Returns true if the resource variable lock was successfully acquired by + // this thread. + bool lock_held() const { return lock_held_; } + void set_lock_held() { lock_held_ = true; } + + // Returns true if the resource variable reader lock was successfully acquired + // by this thread. 
+ bool shared_lock_held() const { return shared_lock_held_; } + void set_shared_lock_held() { shared_lock_held_ = true; } + + bool read_only() const { return read_only_; } + void set_read_only() { read_only_ = true; } + + const std::optional& definition_stack_trace() const { + return definition_stack_trace_; + } + + ~VariableInfo(); + + private: + int index_; + std::string name_; + Var* var_; + std::optional definition_stack_trace_; + + // We can't use a optional here because it confuses the compiler's + // thread safety analysis. Instead we use a boolean flag and release the lock + // in the VariableInfo destructor. + bool lock_held_ = false; + bool shared_lock_held_ = false; + + // Whether this variable is going to be mutated. Left false if the caller + // doesn't provide this information. + bool read_only_ = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_VARIABLE_INFO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/variable_info_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/variable_info_util.h new file mode 100644 index 00000000..ac825d14 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/variable_info_util.h @@ -0,0 +1,93 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_VARIABLE_INFO_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_VARIABLE_INFO_UTIL_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/variable_info.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// Snapshot of resource variables for a TF kernel invocation, mapping from +// parameter number to values at execution time. If the resource variable is not +// initialized, the value will not be present. +using ResourceVarsSnapshot = absl::flat_hash_map>; + +// Takes a snapshot of the values of resource variable arguments, whose indices +// are specified in `variable_indices` argument. We snapshot tensors that back +// resource variables since concurrent updates may modify the shape, and it is +// important that the shapes used for compilation match the true shapes of the +// buffers. +// +// We snapshot the entire set of resource variables as one atomic operation. +// This models Read->* dependencies between resource variable operations. See +// jit/resource_operation_safety_analysis for details. +absl::Status SnapshotResourceVariables( + OpKernelContext* ctx, absl::Span variable_indices, + absl::Span variable_infos, + ResourceVarsSnapshot* result); + +// Acquires the mutexes for all the variables in `variables` using a +// deadlock-safe protocol (acquire the mutexes in increasing-address order). 
+// +// `variables` is allowed to contain instances that don't track a resource +// variable (i.e. variables[i].var() can be null for some i). +// +// If the variable is read_only(), only acquires reader locks. +absl::Status LockVariables(absl::Span variables) + TF_EXCLUSIVE_LOCK_FUNCTION(); +absl::Status LockVariables(absl::Span variables) + TF_EXCLUSIVE_LOCK_FUNCTION(); + +// Returns a vector of VariableInfo instances for the resource variable inputs, +// given that *all* inputs are in `inputs`. The input indices for the resource +// variable inputs are in `variable_indices`. +// +// When using the VariableInfos generated by this version, all variables would +// be writer-locked. +absl::Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + std::vector* result); + +// variables_updated is a set containing the indices of the variables that are +// going to be mutated. If variables_updated is empty, then in LockVariables all +// variables would only be reader-locked. If variables_updated is null, then we +// consider this information unknown and will acquire writer-lock for all +// variables. +absl::Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + const std::set* variables_updated, + std::vector* result); + +std::vector GetResourceVariableIndicesFromContext(OpKernelContext* ctx); + +absl::Status CreateVariableInfoLookup( + absl::Span variable_args, + absl::flat_hash_map& variable_info_lookup); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_VARIABLE_INFO_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_activity_listener.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_activity_listener.h new file mode 100644 index 00000000..d8be8309 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_activity_listener.h @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_ + +#include + +#include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +// Broadcast `auto_clustering_activity` to all the registered listeners. +absl::Status BroadcastXlaActivity( + XlaAutoClusteringActivity auto_clustering_activity); + +// Broadcast `jit_compilation_activity` to all the registered listeners. +absl::Status BroadcastXlaActivity( + XlaJitCompilationActivity jit_compilation_activity); + +// Broadcast `jit_compilation_activity` to all the registered listeners. +absl::Status BroadcastOptimizationRemark( + XlaOptimizationRemark optimization_remark); + +// LINT.IfChange +// Called after TensorFlow realizes possible lost performance. 
The parameters in +// this should match all of the values in the XlaOptimizationRemark proto. +absl::Status BroadcastOptimizationRemark( + XlaOptimizationRemark::Warning optimization_warning, + string debug_information); + +// LINT.ThenChange(//tensorflow/compiler/jit/xla_activity.proto) + +// Various components of the system can subclass XlaActivityListener to +// notifications on auto-clustering and JIT compilation events. +// +// Subclasses of XlaActivityListener must be thread safe. +class XlaActivityListener { + public: + // Called after TensorFlow auto-clusters a graph. + virtual absl::Status Listen( + const XlaAutoClusteringActivity& auto_clustering_activity) = 0; + + // Called after TensorFlow JIT compiles an XLA cluster. + virtual absl::Status Listen( + const XlaJitCompilationActivity& jit_compilation_activity) = 0; + + // Called after TensorFlow realizes possible lost performance. + virtual absl::Status Listen( + const XlaOptimizationRemark& optimization_remark) = 0; + + // Called at program exit in best-effort manner to give listeners a chance to + // flush their state. + // + // Default implementation is a no-op. + virtual void Flush(); + + virtual ~XlaActivityListener(); +}; + +// Registers an `XlaActivityListener`, which will be invoked on all subsequent +// `BroadcastXlaActivity` calls. +void RegisterXlaActivityListener(std::unique_ptr listener); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_ACTIVITY_LISTENER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_cluster_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_cluster_util.h new file mode 100644 index 00000000..6fe0b485 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_cluster_util.h @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for clustering compilable graph nodes via XLA. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "xla/service/graphcycles/graphcycles.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// The attribute that marks certain inputs to a Node as required to be a +// constant at compile time. If this attribute is present then the +// CompileTimeConstantInput information in the corresponding XlaOpKernel is +// ignored. 
+// +// The value for this attribute, if present, has to be a list of strings naming +// the inputs to the node that must be constant. +extern const char* const kXlaCompileTimeConstantInputsAttr; + +using OrderedNodeSet = std::set; + +// Returns true if `node` has a ref tensor input that it forwards to its output. +bool HasForwardedRefInput(const Node& node); + +// Creates a graph representation to enable cycle detection when clustering. +// This representation handles loops in graph by disconnecting each loop from +// the enclosing graph. +// +// Returns true for success and false for valid graphs that we can't handle yet +// (b/127521408). +absl::StatusOr CreateCycleDetectionGraph(const Graph* graph, + xla::GraphCycles* cycles); + +// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, +// otherwise returns nullopt. +std::optional GetXlaClusterForNode(const Node& node); + +// Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(NodeDef* node_def); + +// Removes `node` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(Node* node); + +// Returns true if `node` has a DT_RESOURCE typed input or output. +bool HasResourceInputOrOutput(const Node& node); + +// Determines the global jit level based on GraphOptimizationPassOptions, +// --tf_xla_auto_jit and whether the graph is a single GPU graph. +OptimizerOptions::GlobalJitLevel GetGlobalJitLevelForGraph( + const GraphOptimizationPassOptions& options); + +// Returns true if `g` is a single-GPU graph. A single-GPU graph uses exactly +// one GPU (and any number of CPUs). +bool IsSingleGpuGraph(const Graph& g); + +// Returns true if it is possible (but not guaranteed) that `n` calls a +// function. +bool MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def); + +// Returns true if `node` an operator that consumes only the shape of its input, +// not the data itself. +bool IsShapeConsumerOp(const Node& node); + +// Computes a clustering summary for `graph`. See documentation on +// `XlaAutoClusteringSummary` for details. +XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph); + +// Returns the set of nodes that have a path to or from nodes that may have ref +// variables as input or output. +// +// We assume each node has a trivial path to itself so the returned set includes +// all of the nodes that have ref variables as input or output. +absl::StatusOr> GetNodesRelatedToRefVariables( + const Graph& graph, FunctionLibraryRuntime* lib_runtime); + +// Deterministically serialized the graph to a byte string. +absl::StatusOr SerializeGraphDeterministic(const Graph& graph); + +// Computes a fingerprint of the given `graph`. The fingerprint can use used to +// check if two graphs are likely the same but should not be relied on +// determining if the graphs are identical. +absl::StatusOr FingerprintGraph(const Graph& graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compile_on_demand_op.h new file mode 100644 index 00000000..dfe9ddaa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The XlaCompileOnDemandOp is an OpKernel that, when its Compute method is +// called, will generate an xla::Computation and run it asynchronously. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ + +#include + +#include "tensorflow/compiler/jit/device_compilation_profiler.h" +#include "tensorflow/compiler/jit/variable_info.h" +#include "tensorflow/compiler/jit/variable_info_util.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/local_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// An OpKernel that compiles an op to an XLA computation and runs it. Unlike +// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// vanilla TensorFlow op as long as the bridge supports it. +class XlaCompileOnDemandOp : public OpKernel { + public: + explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) + : OpKernel(ctx), + platform_info_(XlaPlatformInfoFromDevice(ctx->device())) {} + void Compute(OpKernelContext* ctx) override; + + private: + absl::Status Compile(const std::vector& args, + OpKernelContext* ctx, + DeviceCompiler** + xla_device_compiler, + DeviceCompilationProfiler** profiler, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable); + + absl::Status Compile(const std::vector& args, + OpKernelContext* ctx, + DeviceCompiler** pjrt_device_compiler, + DeviceCompilationProfiler** profiler, + const XlaCompiler::CompilationResult** result, + xla::PjRtLoadedExecutable** executable); + + absl::Status Run(const ResourceVarsSnapshot& variable_args, + const XlaCompiler::CompilationResult* result, + const DeviceCompiler* + xla_device_compiler, + xla::LocalExecutable* executable, OpKernelContext* ctx); + + const XlaPlatformInfo platform_info_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compile_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compile_util.h new file mode 100644 index 00000000..d722ba8e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compile_util.h @@ -0,0 +1,67 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +// The number of compiler threads to use for asynchronous device compilation. +inline constexpr int64_t kNumAsyncDeviceCompilerThreads = 10; + +enum class DeviceCompileMode { + kLazy, + kStrict, + kAsync, +}; + +enum class DeviceCompileState { + kUncompiled, + kCompiling, + kCompiled, +}; + +// Creates a single-node graph using the specified `node_def` as the only op +// apart from the arg and retval nodes corresponding to `args` and +// `result_types` respectively. +absl::StatusOr> CreateSingleOpGraph( + const NodeDef& node_def, absl::Span args, + absl::Span result_types); + +// Checks if single device compilation and execution with PJRT is enabled for +// `device_type` in either the XlaLaunch op or the XlaCompileOnDemand op. +bool UsePjRtForSingleDeviceCompilation(const DeviceType& device_type); + +// Gets the resource name of the PjRt DeviceCompiler for `device_type`. +std::string GetPjRtDeviceCompilerResourceName(const DeviceType& device_type); + +// Gets the resource name of the DeviceCompilationProfiler for `device_type` +// when PjRt is used for compilation and execution. +std::string GetPjRtDeviceCompilationProfilerResourceName( + const DeviceType& device_type); + +// Gets the ResourceMgr where the DeviceCompiler is/should be stored for the +// given `device_type`. +absl::StatusOr GetResourceMgrForDeviceCompiler( + const OpKernelContext& ctx, const DeviceType& device_type); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compiler_options_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compiler_options_util.h new file mode 100644 index 00000000..23cb5f86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_compiler_options_util.h @@ -0,0 +1,64 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILER_OPTIONS_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_COMPILER_OPTIONS_UTIL_H_ + +#include "tensorflow/compiler/jit/device_compiler.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/local_client.h" +#include "xla/pjrt/pjrt_client.h" + +namespace tensorflow { + +// Returns created options for the XLA compiler. +XlaCompiler::Options GenerateCompilerOptions( + const DeviceCompiler& + xla_device_compiler, + const FunctionLibraryRuntime& function_library, DeviceBase* device, + se::Stream* stream, const XlaPlatformInfo& platform_info, + bool has_ref_vars); + +// Returns created options for XLA compiler when TFRT-TPU is used. +XlaCompiler::Options GenerateCompilerOptionsForTfrtTpu( + const DeviceCompiler& + xla_device_compiler, + const FunctionLibraryRuntime& function_library); + +// Returns created options for XLA compiler when PjRt (Device API) is used for +// compilation and execution. +XlaCompiler::Options GenerateCompilerOptionsForPjRt( + const FunctionLibraryRuntime& function_library, + const DeviceBase* device_base, const XlaPlatformInfo& platform_info, + const DeviceCompiler* + pjrt_device_compiler); + +// Returns created options for XLA compiler when PjRt (Device API) is used for +// compilation and execution. +XlaCompiler::Options GenerateCompilerOptionsForPjRt( + const FunctionLibraryDefinition* function_library_def, + int graph_def_version, const DeviceBase* device_base, + const XlaPlatformInfo& platform_info, + const DeviceCompiler* + pjrt_device_compiler); + +// Returns created CompileOptions for XLA compiler. +XlaCompiler::CompileOptions GenerateCompileOptions( + bool has_ref_vars, bool may_alias_resource_update); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILER_OPTIONS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device.h new file mode 100644 index 00000000..877d208d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device.h @@ -0,0 +1,321 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The XlaDevice executes a TensorFlow graph using the XLA linear algebra +// runtime. +// +// Operators assigned to an XlaDevice are compiled into XLA computations. +// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. +// +// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), +// under different names (e.g., XLA_CPU or XLA_GPU). 
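As an illustration of how a backend wires this up, the sketch below fills in the `XlaDevice::Options` struct declared later in this header and constructs a device from it. It is not part of this patch: the device names, the caller-supplied `platform` and `session_options`, and the single default `ShapeDeterminationFns` entry are placeholder assumptions.

```cpp
// Hedged sketch only; "XLA_EXAMPLE"/"XLA_EXAMPLE_JIT" are hypothetical names.
#include <memory>

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

std::unique_ptr<XlaDevice> MakeExampleXlaDevice(
    const SessionOptions& session_options, se::Platform* platform) {
  XlaDevice::Options options;
  options.platform = platform;  // Not owned; must be non-null.
  options.device_name_prefix = "/job:localhost/replica:0/task:0";
  options.device_name = "XLA_EXAMPLE";
  options.device_ordinal = 0;
  options.compilation_device_name = "XLA_EXAMPLE_JIT";
  options.use_multiple_streams = false;
  // Must be non-empty; a single default-constructed bundle is the simplest case.
  options.shape_determination_fns = {
      XlaShapeLayoutHelpers::ShapeDeterminationFns{}};
  return std::make_unique<XlaDevice>(session_options, options);
}

}  // namespace tensorflow
```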
+ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/tfrt/common/async_value_tensor.h" + +namespace tensorflow { + +class XlaDevice : public LocalDevice { + public: + // Given a tensor, sets `xla::Shape*` the shape of tensor's representation + // on device, fully padded. On error, the contents of `xla::Shape*` + // are undefined. + typedef std::function PaddedShapeFn; + + // Wrapper class to store metadata about the XlaDevice, where it can be + // retrieved e.g., when lazily creating the XlaCompilationCache device. + class Metadata { + public: + Metadata(int device_ordinal, se::Platform* platform, + const DeviceType& device_type, + std::vector + shape_determination_fns, + PaddedShapeFn padded_shape_fn, bool use_multiple_streams); + + // The index of the device on this host. + int device_ordinal() const; + + se::Platform* platform() const; + xla::LocalClient* client() const; + const DeviceType& jit_device_type() const; + const XlaShapeLayoutHelpers::ShapeDeterminationFns& + default_shape_determination_fns() const { + return shape_determination_fns_.at(0); + } + const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } + + bool UseMultipleStreams() const { return use_multiple_streams_; } + + private: + const int device_ordinal_; + const DeviceType device_type_; + se::Platform* platform_; // Not owned. + std::vector + shape_determination_fns_; + PaddedShapeFn padded_shape_fn_; + const bool use_multiple_streams_; + + Metadata(const Metadata&) = delete; + void operator=(const Metadata&) = delete; + }; + + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. + static absl::Status GetMetadata(OpKernelContext* ctx, + const Metadata** metadata); + + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. + static absl::Status GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata); + + // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by + // `device`. + static absl::Status GetMetadataFromDevice( + DeviceBase* device, const XlaDevice::Metadata** metadata); + + struct Options { + // The StreamExecutor platform. Not owned. Must be non-null. + se::Platform* platform = nullptr; + + // The device name's prefix (e.g., "/task:7") + string device_name_prefix; + + // The name of the XLA device (e.g., "XLA_CPU") + string device_name; + + // The number of the device. 
+ int device_ordinal = -1; + + // The name of the compilation device (e.g., "XLA_CPU_JIT"); + string compilation_device_name; + + // If 'use_multiple_streams' is true, we create separate streams for + // compute, host-to-device, and device-to-host communication. + bool use_multiple_streams = false; + + // If true, the XLA devices with the same device ordinal will share the same + // compute stream. Otherwise each XLA device will having their own compute + // streams. + bool use_global_compute_stream = false; + + // A vector of ShapeDeterminationFn (i.e., a bundle of LayoutSelectionFn, + // ShapeRepresentationFn). Each bundle describes how the on-host shapes of + // a) argument and return value, for entry computations b) variables, for + // all computations, should be represented in XLA. Parameters/return values + // will be shaped according to the function pair, and reshaped back to/from + // their declared shapes for computations. Must be non-empty. + std::vector + shape_determination_fns; + + // If padded_shape_fn is empty, a default implementation that returns + // the logical on-device shape without padding is used. + PaddedShapeFn padded_shape_fn; + + // Set of devices to use. This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. + std::optional> allowed_devices; + }; + + // Creates a new XLA Device. + XlaDevice(const SessionOptions& session_options, const Options& options); + ~XlaDevice() override; + + Allocator* GetAllocator(AllocatorAttributes attr) override + TF_LOCKS_EXCLUDED(mu_); + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + absl::Status Sync() override; + + absl::Status TryGetDeviceContext(DeviceContext** out_context) override + TF_LOCKS_EXCLUDED(mu_); + + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override + TF_LOCKS_EXCLUDED(mu_); + + absl::Status MakeTensorFromProto(DeviceContext* device_context, + const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor); + + const Metadata& metadata() { return xla_metadata_; } + + // Ensures the DeviceContext associated with this XlaDevice is created and + // valid (i.e. all streams are ok). If any state is not valid, a new + // DeviceContext will be created. + // + // TODO(b/111859745): The Eager context needs to call this method to recover + // from failures. + absl::Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_); + + // Two convenient methods to get the underlying device context. + // Get the default device context, created by the first + // shape_representation_fn. + absl::StatusOr GetDeviceContextDefault(); + // Get the device context given the index. + absl::StatusOr GetDeviceContextWithIndex(int index); + + // Instructs this XlaDevice to set a AcceleratorDeviceInfo, which holds extra + // information for GPU and TPU devices. + absl::Status UseAcceleratorDeviceInfo() TF_LOCKS_EXCLUDED(mu_); + + // Instructs this XlaDevice to return 'sync_on_completion' for + // AllowsSyncOnCompletion(). + void SetAllowsSyncOnCompletion(bool sync_on_completion) + TF_LOCKS_EXCLUDED(mu_); + bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_); + + // Installs an error handling callback when RefreshStatus sees !status.ok(). 
+ void SetHandleDeviceErrorCallback(std::function callback); + + absl::Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_); + + private: + absl::StatusOr GetOrCreateClient() const; + Allocator* GetAllocatorLocked(AllocatorAttributes attr) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, + std::shared_ptr* stream, + bool* stream_was_changed) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Return a vector of device context, ordered by the sequence in the given + // shape_representation_fns. + absl::StatusOr> GetDeviceContextLocked() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Handles error when RefreshStatus sees !status.ok(). + absl::Status HandleDeviceError(); + + mutable mutex mu_; + // The metadata of this XlaDevice. + const Metadata xla_metadata_; + // Which hardware device in the client's platform this XlaDevice controls. + const int device_ordinal_; + // The name/type of this XlaDevice. eg. "XLA_GPU". + const DeviceType device_name_; + // The name of the device that is used to compile Ops for this XlaDevice. + const DeviceType jit_device_name_; + // The platform for this device. + se::Platform* const platform_; // Not owned. + // Intra-op threads to spawn (from SessionOptions). + const int intra_op_parallelism_threads_; + // Memory allocator associated with this device. + Allocator* xla_allocator_ TF_GUARDED_BY(mu_) = nullptr; // Not owned. + std::unique_ptr pjrt_allocator_ TF_GUARDED_BY(mu_); + + // Stream associated with this device. Operations enqueued on this + // stream are executed on the device. Operations include data + // copying back and forth between CPU and the device, and + // computations enqueued by XLA. + std::shared_ptr stream_ TF_GUARDED_BY(mu_); + // If false, only stream_ is valid and all computation and transfers use + // stream_. If true, computation is performed by stream_ and transfers are + // performed by host_to_device/device_to_device stream or borrowing a stream + // for each device to host transfer. + const bool use_multiple_streams_; + // If use_multiple_streams_, host to device transfers are performed using this + // stream. + std::shared_ptr host_to_device_stream_ TF_GUARDED_BY(mu_); + // If use_multiple_streams_, transfers between different devices are performed + // using these streams. + std::vector> device_to_device_streams_ + TF_GUARDED_BY(mu_); + + // See comments in options. + std::vector + shape_determination_fns_; + + // A list of the device context accessed by all users of the XlaDevice, set by + // calls to EnsureDeviceContextOk. The number of device conetexts is based on + // the number of shape representation functions in XlaDevice::Options. If + // accelerator_device_info_ is non-null, this pointer is also filled in to + // that struct. DeviceContext is a ref-counted object. + std::vector device_contexts_ TF_GUARDED_BY(mu_); + + // Holds extra information for GPU and TPU devices, e.g. the device context. + bool use_accelerator_device_info_ TF_GUARDED_BY(mu_) = false; + std::unique_ptr accelerator_device_info_ + TF_GUARDED_BY(mu_); + + // Thread pool used for running closures + std::unique_ptr thread_pool_; + + // True if the device allows XlaDevice::Sync to be called on completion + // regardless of status. + bool sync_on_completion_ TF_GUARDED_BY(mu_) = true; + + // A callback that will be invoked when RefreshStatus sees a status error. + std::function device_error_callback_ TF_GUARDED_BY(mu_); + + // Set of devices to use. 
This controls which of the devices on the given + // platform will have resources allocated. For GPUs this will be + // filled from visible_gpu_devices list from session configuration. + std::optional> allowed_devices_; + + const bool use_global_compute_stream_; + + // A static vector with device_ordinal as its index, describing the global + // compute streams used in each XLA device. It is only used if + // `use_global_compute_stream` in `XlaDevice::Options` is set to true. + static mutex global_mu_; + static std::vector>* global_compute_streams_ + TF_GUARDED_BY(global_mu_); +}; + +// Builds OpKernel registrations on 'device' for the JIT operators +// registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations +// object that encapsulates the kernel registrations. +struct XlaDeviceOpRegistrations { + std::vector> + op_kernel_registrars; +}; + +XlaDeviceOpRegistrations* RegisterXlaDeviceKernels( + const char* device, const char* jit_device, + OpKernel* (*factory)(OpKernelConstruction*), + absl::string_view kernel_class_name); + +XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, + const char* jit_device); + +absl::Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_compiler_client.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_compiler_client.h new file mode 100644 index 00000000..3967897c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_compiler_client.h @@ -0,0 +1,69 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_COMPILER_CLIENT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_COMPILER_CLIENT_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/device_compiler_client.h" +#include "xla/client/local_client.h" + +namespace tensorflow { + +class XlaDeviceCompilerClient + : public DeviceCompilerClient { + public: + explicit XlaDeviceCompilerClient(xla::LocalClient* client) + : client_(client) {} + + absl::StatusOr> BuildExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // Returns a serialized AOT result obtained by exporting the available + // `executable` using the XlaCompiler. + absl::StatusOr SerializeExecutable( + const xla::LocalExecutable& executable) override; + + // Returns a serialized AOT result obtained by compiling `result` into an AOT + // result. + absl::StatusOr BuildSerializedExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result) override; + + // Loads a serialized AOT result (`serialized_executable`) into an + // xla::LocalExecutable and returns it. 
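A hedged round-trip sketch, not taken from this patch: `options` and `result` are assumed to be outputs of an earlier XlaCompiler compilation, and the blob produced by `BuildSerializedExecutable` is fed straight back through `LoadExecutable`, declared next.

```cpp
// Hedged sketch: serialize an AOT result and load it back as an executable.
#include <memory>
#include <string>

#include "tensorflow/compiler/jit/xla_device_compiler_client.h"
#include "tensorflow/core/platform/statusor.h"  // TF_ASSIGN_OR_RETURN

namespace tensorflow {

absl::Status RoundTripExample(XlaDeviceCompilerClient& compiler_client,
                              const XlaCompiler::Options& options,
                              const XlaCompiler::CompilationResult& result) {
  // Compile directly to a serialized AOT blob...
  TF_ASSIGN_OR_RETURN(std::string serialized,
                      compiler_client.BuildSerializedExecutable(options, result));
  // ...then turn the blob back into something the local client can run.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                      compiler_client.LoadExecutable(options, result, serialized));
  return executable ? absl::OkStatus()
                    : absl::InternalError("load returned null executable");
}

}  // namespace tensorflow
```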
+ absl::StatusOr> LoadExecutable( + const XlaCompiler::Options& options, + const XlaCompiler::CompilationResult& result, + const std::string& serialized_executable) override; + + void WaitForProgramsToFinish() override; + + xla::LocalClient* client() const override { return client_; } + + private: + xla::LocalClient* const client_; + + XlaDeviceCompilerClient(const XlaDeviceCompilerClient&) = delete; + void operator=(const XlaDeviceCompilerClient&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_COMPILER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_context.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_context.h new file mode 100644 index 00000000..4e8a769e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_context.h @@ -0,0 +1,128 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/global_data.h" +#include "xla/client/local_client.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// The allocator used for Tensors assigned to the XLA device. The allocator +// ignores the alignment and size of the request and always returns a new, +// empty, XlaTensor. +class XlaDeviceAllocator : public Allocator { + public: + XlaDeviceAllocator(se::StreamExecutor* stream_executor); + ~XlaDeviceAllocator() override; + + string Name() override; + + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + std::optional GetStats() override; + bool ClearStats() override; + + private: + // The stream executor of the device. + se::StreamExecutor* stream_executor_; +}; + +// Helper class for managing data transfers between host and XLA devices. 
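A hedged sketch of a single host-to-device transfer through the context declared below; the `device_context`, `device`, and tensor objects are assumed to be created by the surrounding runtime, and the notification is only there to make the asynchronous callback easy to wait on.

```cpp
// Hedged sketch: copy one host tensor to the device and wait for completion.
#include "absl/synchronization/notification.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void CopyToDeviceExample(XlaDeviceContext* device_context, Device* device,
                         const Tensor& cpu_tensor, Tensor* device_tensor) {
  absl::Notification done;
  device_context->CopyCPUTensorToDevice(
      &cpu_tensor, device, device_tensor,
      [&done](const absl::Status& s) {
        if (!s.ok()) LOG(ERROR) << "host->device copy failed: " << s;
        done.Notify();
      },
      /*sync_dst_compute=*/true);
  done.WaitForNotification();  // The callback fires once the copy has completed.
}

}  // namespace tensorflow
```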
+class XlaDeviceContext : public DeviceContext { + public: + explicit XlaDeviceContext( + std::shared_ptr compute_stream, + std::shared_ptr host_to_device_stream, + std::shared_ptr device_to_host_stream, + std::vector> device_to_device_streams, + xla::LocalClient* client, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + thread::ThreadPool* thread_pool); + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; + + xla::LocalClient* client() const { return client_; } + se::Stream* stream() const override { return stream_.get(); } + se::Stream* host_to_device_stream() const { + return host_to_device_stream_.get(); + } + se::Stream* device_to_device_stream(int index) const { + return device_to_device_streams_.at(index).get(); + } + xla::TransferManager* transfer_manager() const { return transfer_manager_; } + const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns() + const { + return shape_determination_fns_; + } + + // Returns a device-to-device stream, in round-robin fashion. + se::Stream* GetDeviceToDeviceStream(); + + absl::Status ThenExecute(Device* device, stream_executor::Stream* stream, + std::function func) override; + + private: + bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } + + // The main compute stream of the device, used to synchronize the transfer + // streams if they are set. + std::shared_ptr stream_; + // The stream to use for transferring data from host to device. Can be + // idential to stream_, but must not be nullptr. + std::shared_ptr host_to_device_stream_; + // The stream to use for transferring data from device to host. Can be + // idential to stream_. If nullptr, borrow a stream from backend for each + // transfer request to support out-of-order requests. + std::shared_ptr device_to_host_stream_; + // Streams to use for transferring data directly between different devices, + // e.g., over NVLINK. + std::vector> device_to_device_streams_; + + // For the underlying memory allocator and XLA's TransferManager. + xla::LocalClient* client_; + // Transfer manager, for marshalling data to and from the device. + xla::TransferManager* transfer_manager_; + + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns_; + + // Thread pool used for running closures + thread::ThreadPool* thread_pool_; + + absl::Mutex mu_; + int next_stream_ TF_GUARDED_BY(mu_) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_ops.h new file mode 100644 index 00000000..fdb28446 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_device_ops.h @@ -0,0 +1,256 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Common kernel registrations for XLA devices. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/constant_op.h" +#include "tensorflow/core/kernels/data/finalize_dataset_op.h" +#include "tensorflow/core/kernels/data/generator_dataset_op.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/kernels/data/optional_ops.h" +#include "tensorflow/core/kernels/data/options_dataset_op.h" +#include "tensorflow/core/kernels/data/prefetch_dataset_op.h" +#include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/function_ops.h" +#include "tensorflow/core/kernels/identity_op.h" +#include "tensorflow/core/kernels/resource_variable_ops.h" +#include "tensorflow/core/kernels/shape_ops.h" +#include "tensorflow/core/kernels/variable_ops.h" + +namespace tensorflow { + +// Dummy OpKernel, used for kernels assigned to an XLA device that should be +// compiled. Should never be called at runtime since such ops should be +// rewritten to a XlaLaunch op. If it is called, it means the placer placed an +// operator on an XLA device but the compiler did not compile it. 
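For context on how the macros defined further down are consumed, here is a hedged sketch of a backend registration file. The device name, the type list, and the use of `XlaLocalLaunchOp` as the launch kernel are placeholder assumptions, not something this patch adds.

```cpp
// Hedged sketch: a hypothetical backend expanding the registration macros
// declared below. DEVICE_XLA_EXAMPLE and kExampleTypes are made-up names.
#include <array>

#include "tensorflow/compiler/jit/kernels/xla_ops.h"  // XlaLocalLaunchOp (assumed)
#include "tensorflow/compiler/jit/xla_device_ops.h"

namespace tensorflow {

const char* const DEVICE_XLA_EXAMPLE = "XLA_EXAMPLE";
constexpr std::array<DataType, 2> kExampleTypes = {DT_FLOAT, DT_INT32};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_EXAMPLE, XlaLocalLaunchOp, kExampleTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_EXAMPLE, kExampleTypes);

}  // namespace tensorflow
```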
+class XlaDeviceDummyOp : public OpKernel { + public: + explicit XlaDeviceDummyOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; +}; + +class XlaAssignVariableOp : public OpKernel { + public: + explicit XlaAssignVariableOp(OpKernelConstruction* c); + void Compute(OpKernelContext* context) override; + + private: + DataType dtype_; +}; + +#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("resources"), \ + KERNEL); + +#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("key") \ + .HostMemory("compilation_successful") \ + .HostMemory("resources"), \ + KERNEL); + +#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); + +#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ + REGISTER_KERNEL_BUILDER( \ + Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ + ConstantOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), VarHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \ + ResourceHandlesOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ + ReadVariableOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \ + ReadVariablesOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \ + DestroyResourceOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("VariableShape") \ + .Device(DEVICE) \ + .TypeConstraint("out_type") \ + .HostMemory("output") \ + .HostMemory("input"), \ + VariableShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("VariableShape") \ + .Device(DEVICE) \ + .TypeConstraint("out_type") \ + .HostMemory("output") \ + .HostMemory("input"), \ + VariableShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + RankOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ + XlaAssignVariableOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \ + \ + 
REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kArgOp).Device(DEVICE).TypeConstraint("T"), ArgOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp); \ + REGISTER_KERNEL_BUILDER(Name(kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); \ + REGISTER_KERNEL_BUILDER( \ + Name(kDeviceRetOp).Device(DEVICE).TypeConstraint("T"), RetvalOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp); \ + \ + REGISTER_KERNEL_BUILDER( \ + Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ + data::GeneratorDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ + .Device(DEVICE) \ + .HostMemory("buffer_size") \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + data::PrefetchDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionsDataset") \ + .Device(DEVICE) \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + data::OptionsDatasetOp); \ + REGISTER_KERNEL_BUILDER(Name("FinalizeDataset") \ + .Device(DEVICE) \ + .HostMemory("input_dataset") \ + .HostMemory("handle"), \ + data::FinalizeDatasetOp); \ + \ + REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ + data::IteratorHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ + data::MakeIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ + data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE), \ + data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV3").Device(DEVICE), \ + data::AnonymousIteratorHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE), \ + data::DeleteIteratorOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ + data::IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \ + data::IteratorGetNextAsOptionalOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ + data::IteratorGetNextOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + data::IteratorToStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ + .Device(DEVICE) \ + .HostMemory("string_handle"), \ + data::IteratorFromStringHandleOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE), \ + data::OptionalNoneOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE), \ + data::OptionalFromValueOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"), \ + data::OptionalHasValueOp); \ + REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE), \ + data::OptionalGetValueOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + ArgOp); \ + REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \ + .Device(DEVICE) \ + .TypeConstraint("T") \ + .HostMemory("input"), \ + RetvalOp); + +// TODO(b/118881356): currently we do not register the QueueEnqueueMany, +// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read +// and write 
the tensors they access in order to concatenate them into a batch. +// We would need either to call out to an XLA computation to perform the +// concatenation, or we would need to refactor those kernels so the splitting +// or merging is done in a separate operator that can be compiled. + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_host_recv_device_context.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_host_recv_device_context.h new file mode 100644 index 00000000..d6dfc6f1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_host_recv_device_context.h @@ -0,0 +1,93 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_ + +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" +#include "tensorflow/core/framework/device_base.h" +#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime + +namespace tensorflow { + +// XlaHostRecvDeviceContext is a DeviceContext that is intended to be +// used to transfer from device->host using Rendezvous. It transfers the +// content of `device_memory_base` with `shape` using `stream`. Only +// `CopyDeviceTensorToCPU` method is implemented. The `done_event` is marked as +// Concrete once transfer is completed. +// +// Example usage: +// +// Device device; +// stream_executor::Stream stream(executor); +// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2})); +// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; +// xla::Shape shape(xla::F32, {2, 2}, {}, {}) +// tsl::AsyncValueRef> done_event = +// tsl::MakeConstructedAsyncValueRef>(stream.parent()); +// done_event->Init(); +// Tensor dest_cpu_tensor; +// +// XlaHostRecvDeviceContext device_context(&stream, gpu_dst, +// shape, done_event); +// device_context.CopyDeviceTensorToCPUSync( +// &device_tensor, "", &device, &dest_cpu_tensor); + +class XlaHostRecvDeviceContext : public DeviceContext { + public: + XlaHostRecvDeviceContext( + se::Stream* stream, const se::DeviceMemoryBase& device_memory_base, + const xla::Shape& shape, + tsl::AsyncValueRef>& done_event) + : stream_(stream), + device_memory_base_(device_memory_base), + shape_(shape), + done_event_(done_event) {} + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override { + done(errors::Internal("host->device copy not implemented.")); + } + + // Copies `device_memory_base_` with `shape_` into `cpu_tensor`. + // `device_tensor` is unused. 
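The usage example in the class comment above drives the blocking `CopyDeviceTensorToCPUSync` wrapper; a hedged sketch of the asynchronous overload declared next looks roughly like this, reusing the objects set up in that comment.

```cpp
// Hedged sketch: asynchronous device-to-host receive with an explicit callback.
absl::Notification done_copy;
device_context.CopyDeviceTensorToCPU(
    &device_tensor, /*tensor_name=*/"", &device, &dest_cpu_tensor,
    [&done_copy](const absl::Status& s) {
      if (!s.ok()) LOG(ERROR) << "device->host copy failed: " << s;
      done_copy.Notify();
    });
done_copy.WaitForNotification();
```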
+ void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override { + done(errors::Internal("device->device copy not implemented.")); + } + + private: + se::Stream* stream_; // Not owned. + // This is copied rather than a reference or pointer since its lifetime + // is not guaranteed to outlast the original object. Object slicing is + // not an issue here since only DeviceMemoryBase methods/members are used. + const se::DeviceMemoryBase device_memory_base_; + const xla::Shape shape_; + tsl::AsyncValueRef> done_event_; + + XlaHostRecvDeviceContext(const XlaHostRecvDeviceContext&) = delete; + void operator=(const XlaHostRecvDeviceContext&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_RECV_DEVICE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_host_send_device_context.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_host_send_device_context.h new file mode 100644 index 00000000..52ca6125 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_host_send_device_context.h @@ -0,0 +1,90 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_ + +#include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/stream.h" +#include "tensorflow/core/framework/device_base.h" +#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime + +namespace tensorflow { + +// XlaHostSendDeviceContext is a DeviceContext that is intended to be +// used to transfer from host->device using Rendezvous. It transfers the +// content of `device_memory_base` with `shape` using `stream`. Only +// `CopyCPUTensorToDevice` method is implemented. The `done_event` is marked as +// Concrete once transfer is completed. 
+// +// Example usage: +// +// Device device; +// stream_executor::Stream stream(executor); +// Tensor cpu_tensor(host_allocator, DT_FLOAT, TensorShape({2, 2})); +// Tensor device_tensor(device_allocator, DT_FLOAT, TensorShape({2, 2})); +// se::DeviceMemoryBase gpu_dst{device_tensor.data(), 4 * sizeof(float)}; +// xla::Shape shape(xla::F32, {2, 2}, {}, {}) +// tsl::AsyncValueRef> done_event = +// tsl::MakeConstructedAsyncValueRef>(stream.parent()); +// done_event->Init(); +// +// XlaHostSendDeviceContext device_context(&stream, &gpu_dst, +// shape, done_event); +// device_context.CopyCPUTensorToDeviceSync( +// &cpu_tensor, &device, &device_tensor); + +class XlaHostSendDeviceContext : public DeviceContext { + public: + XlaHostSendDeviceContext( + se::Stream* stream, se::DeviceMemoryBase* device_memory_base, + const xla::Shape& shape, + tsl::AsyncValueRef>& done_event) + : stream_(stream), + device_memory_base_(device_memory_base), + shape_(shape), + done_event_(done_event) {} + + // Copies 'cpu_tensor' to `device_memory_base_` with `shape_`. + // `device_tensor` is unused. + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override { + done(errors::Internal("host->device copy not implemented.")); + } + + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override { + done(errors::Internal("device->device copy not implemented.")); + } + + private: + se::Stream* stream_; // Not owned. + se::DeviceMemoryBase* device_memory_base_; // Not owned. + const xla::Shape shape_; + tsl::AsyncValueRef> done_event_; + + XlaHostSendDeviceContext(const XlaHostSendDeviceContext&) = delete; + void operator=(const XlaHostSendDeviceContext&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_HOST_SEND_DEVICE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_kernel_creator.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_kernel_creator.h new file mode 100644 index 00000000..67c843bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_kernel_creator.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ + +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_properties.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryRuntime; +class OpKernel; + +class XlaKernelCreator : public CustomKernelCreator { + public: + // Given a NodeDef 'node_def' and the function library runtime 'flr', returns + // true if 'node_def' is a call to a compilable function defined in 'flr', + // with the kXlaCompileAttr set. + bool CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const override; + + // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. + absl::Status CreateKernel(FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const override; +}; + +bool RegisterLaunchOpCreator(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_launch_util.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_launch_util.h new file mode 100644 index 00000000..5e5128d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_launch_util.h @@ -0,0 +1,267 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for launching compiled XLA kernels for a KernelContext. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/variable_info.h" +#include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/local_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/service/shaped_buffer.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// Creates a list of updated resource variables. +absl::StatusOr> GatherVariableInfo( + OpKernelContext* ctx, + const XlaCompiler::CompilationResult& compilation_result, + int missing_ctx_input_prefix); + +// Returns pointers to inputs stored in `ctx`. 
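To show how the helpers declared in this header compose, a hedged sketch of a minimal PjRt launch path follows; the `compilation_result`, `pjrt_client`, and `executable` arguments are assumed to come from an earlier compilation step, and resource variables are left out for brevity.

```cpp
// Hedged sketch: gather kernel inputs and run a compiled PjRt executable.
#include <vector>

#include "tensorflow/compiler/jit/xla_launch_util.h"

namespace tensorflow {

absl::Status LaunchExample(OpKernelContext* ctx,
                           const XlaCompiler::CompilationResult& compilation_result,
                           xla::PjRtClient* pjrt_client,
                           xla::PjRtLoadedExecutable* executable) {
  // Raw pointers to every input tensor currently held by the kernel context.
  std::vector<const Tensor*> inputs = InputsFromContext(ctx);
  // No resource variables in this simplified sketch.
  std::vector<VariableInfo> variables;
  // Executes and writes the outputs (and any variable updates) back into ctx.
  return RunPjRtExecutable(inputs, variables, compilation_result, pjrt_client,
                           executable, ctx);
}

}  // namespace tensorflow
```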
+std::vector InputsFromContext(OpKernelContext* ctx); + +absl::StatusOr> GetConstantInputIndicesFromContext( + OpKernelContext* ctx); + +absl::Status SetOutputForConstant( + OpKernelContext* ctx, bool requires_copy_to_device, + const XlaCompiler::CompilationResult* compilation_result, int output_num); + +// Converts input tensors and variables which are parameters of the +// XlaComputation into PjRtBuffers to be fed as input to the +// PjRtLoadedExecutable. +// +// Assumes that the first `num_missing_prefix_ctx_inputs` inputs to the +// compilation_result are missing in `inputs` and adjusts indexing into `inputs` +// accordingly. +// `input_mapping` is a vector that maps from the parameters of the +// XlaComputation to their original argument positions. This can be sourced from +// `XlaCompiler::CompilationResult::input_mapping`. +// `variable_snapshots` is a map of {index of the input to the +// compilation_result -> underlying Tensor the variable is/was pointing to (i.e. +// the value of the variable at the time of lowering/compilation)}. +// +// The obtained PjRtBuffers are populated to `args` vector. +// `non_donatable_input_indices` will also be set, which contains the indices of +// the input that should not be donated to output. +// +// There can be three types of input: 1. Tensor with PjRtTensorBuffer; 2. +// Tensor with AsyncValueTensor; 3. Tensor with raw device mem pointer. +// For case 3, we need to create a PjRtBuffer from the raw device mem pointer, +// and we need to ensure the PjRtBuffer persists till XLA computation is +// complete. Therefore we put the newly created PjRtBuffer into `owned_args`. +// Caller is responsible to ensure `owned_args` lives till the end of XLA +// computation. +absl::Status PreparePjRtExecutableArguments( + int num_missing_prefix_ctx_inputs, const std::vector& input_mapping, + const std::vector& inputs, + const absl::flat_hash_map& variable_snapshots, + xla::PjRtClient* pjrt_client, xla::PjRtDevice* pjrt_device, + bool use_pjrt_tensor_buffer, std::vector* args, + std::vector>* owned_args, + absl::flat_hash_set* non_donatable_input_indices); + +// Populates the OpKernelContext outputs with the outputs of the +// PjRtLoadedExecutable. Requires the `compilation_result` used to build the +// PjRtLoadedExecutable. +// This function only looks at variables that were updated, so `variables` can +// either be all the variables or only the ones that were updated. +// Assumes that the first `num_missing_prefix_ctx_inputs` inputs to the +// compilation_result are missing in `inputs` and adjusts indexing into `inputs` +// accordingly. +absl::Status PopulateCtxOutputsFromPjRtExecutableOutputs( + int num_missing_prefix_ctx_inputs, const std::vector& inputs, + const std::vector& variables, + const XlaCompiler::CompilationResult& compilation_result, + bool use_pjrt_tensor_buffer, + std::vector>& executable_outputs, + OpKernelContext* ctx); + +// Returns the options used for executing a PjRtLoadedExecutable. +xla::ExecuteOptions GetPjRtExecuteOptions( + const DeviceType& device_type, + absl::flat_hash_set non_donatable_input_indices); + +// Returns the device ordinal from the parsed name of the device. +int GetDeviceOrdinal(const DeviceBase* device); + +// Returns the device type from the OpKernelContext. +DeviceType GetDeviceType(OpKernelContext* ctx); + +// Runs `executable` and populates the outputs in `ctx`. `inputs` and +// `variables` are the input arguments to the computation, usually read from the +// OpKernelContext, `ctx`. 
Requires the device-appropriate `pjrt_client` and the +// `compilation_result` used to build the `executable`. +absl::Status RunPjRtExecutable( + const std::vector& inputs, + const std::vector& variables, + const XlaCompiler::CompilationResult& compilation_result, + xla::PjRtClient* pjrt_client, xla::PjRtLoadedExecutable* executable, + OpKernelContext* ctx); + +// Same as the above function but takes in `updated_variables` and +// `variable_snapshots` which is a map of {index of the input to the +// compilation_result -> underlying Tensor the variable is/was pointing to +// (i.e. the value of the variable at the time of lowering/compilation)}. +// Assumes that the first `num_missing_prefix_ctx_inputs` inputs to the +// compilation_result are missing in `inputs` and adjusts indexing into `inputs` +// accordingly. +absl::Status RunPjRtExecutable( + int num_missing_prefix_ctx_inputs, const std::vector& inputs, + const absl::flat_hash_map& variable_snapshots, + const std::vector& updated_variables, + const XlaCompiler::CompilationResult& compilation_result, + xla::PjRtClient* pjrt_client, xla::PjRtLoadedExecutable* executable, + OpKernelContext* ctx); + +// Similar to the above function but it does not take an OpKernelContext, and +// it returns the output in PjRtBuffers, instead of populating results into +// OpKernelContext. +absl::StatusOr>> RunPjRtExecutable( + int num_missing_prefix_ctx_inputs, const std::vector& inputs, + const absl::flat_hash_map& variable_snapshots, + const std::vector& updated_variables, + const DeviceType& device_type, bool use_pjrt_tensor_buffer, + const XlaCompiler::CompilationResult& compilation_result, + xla::PjRtDevice* device, xla::PjRtClient* pjrt_client, + xla::PjRtLoadedExecutable* executable); + +// Helper class to perform the marshalling of TensorFlow inputs and outputs to +// ShapedBuffers suitable for passing to an XLA computation. +class XlaComputationLaunchContext { + public: + // Create a new launch context. 'allocate_xla_tensors' is true if allocated + // output tensors and variables are always XlaTensors. If false they are + // assumed to be "normal" device pointers. + // If 'use_multiple_streams' is true, tensors may be defined and used on + // multiple streams and so se::Events must be defined and waited for. If + // 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true + // because we track inter-stream dependencies through events inside XlaTensor + // objects. + XlaComputationLaunchContext(xla::LocalClient* client, + se::DeviceMemoryAllocator* xla_allocator, + int device_ordinal, bool allocate_xla_tensors, + bool use_multiple_streams); + + // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch + // op. + // Precondition: variables in `variable_args` are locked. + static absl::StatusOr> + BuildXlaCompilerArguments(absl::Span must_be_constant_idxs, + absl::Span inputs, + absl::Span variable_args, + Device* device); + + // Add all inputs within `ctx` as XLA arguments (returned by arguments()). + // `variables` is a map from TensorFlow argument number to resource variable. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. All elements in kernel's + // input_mapping must be greater than or equal to `missing_ctx_input_prefix` + // (in other words, no inputs actually required by the kernel can be missing). 
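A hedged end-to-end sketch of this class, not part of the patch: the client, allocator, compilation result, alias config, and the `xla::ScopedShapedBuffer` produced by the executable are all assumed to be supplied by the surrounding launch machinery, and the actual `xla::LocalExecutable::Run` call is elided.

```cpp
// Hedged sketch: marshal kernel inputs into XLA form and write outputs back.
#include <map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/core/platform/statusor.h"  // TF_ASSIGN_OR_RETURN

namespace tensorflow {

absl::Status LegacyLaunchExample(
    OpKernelContext* ctx, xla::LocalClient* client,
    se::DeviceMemoryAllocator* allocator,
    const XlaCompiler::CompilationResult* compilation_result,
    const xla::HloInputOutputAliasConfig& input_output_alias,
    xla::ScopedShapedBuffer execution_output) {
  XlaComputationLaunchContext launch_context(
      client, allocator, /*device_ordinal=*/0,
      /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/false);

  std::map<int, const Tensor*> resource_vars;  // None in this sketch.
  TF_ASSIGN_OR_RETURN(auto execution_inputs,
                      launch_context.PopulateInputs(
                          ctx, compilation_result, resource_vars,
                          /*missing_ctx_input_prefix=*/0, input_output_alias));
  (void)execution_inputs;  // Would be handed to xla::LocalExecutable::Run here.

  // Write the executable's result buffers (and variable updates) back into ctx.
  return launch_context.PopulateOutputs(
      ctx, compilation_result, std::move(execution_output),
      /*missing_ctx_input_prefix=*/0, /*variable_infos=*/{}, input_output_alias,
      resource_vars);
}

}  // namespace tensorflow
```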
+ absl::StatusOr> PopulateInputs( + OpKernelContext* ctx, + const XlaCompiler::CompilationResult* compilation_result, + const std::map& resource_vars, + int missing_ctx_input_prefix, + const xla::HloInputOutputAliasConfig& input_output_alias); + + // Given the XLA output in `output`, populate all outputs of `ctx`. Also + // writes out the resource variable updates. + // + // Updates to all resource variables are written in a single atomic operation. + // This models *->Write dependencies between resource variable operations. + // See jit/resource_operation_safety_analysis for details. + // + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the + // compilation_result are missing and adjusts input indices accordingly. + absl::Status PopulateOutputs( + OpKernelContext* ctx, + const XlaCompiler::CompilationResult* compilation_result, + xla::ScopedShapedBuffer output, int missing_ctx_input_prefix, + absl::Span variable_infos, + const xla::HloInputOutputAliasConfig& input_output_alias, + const std::map& resource_vars); + + private: + xla::LocalClient* client_; + se::DeviceMemoryAllocator* xla_allocator_; + bool allocate_xla_tensors_; + bool use_multiple_streams_; + int device_ordinal_; +}; + +// A simple TensorBuffer implementation that allows us to create Tensors that +// take ownership of pre-allocated memory. +class XlaTensorBuffer : public TensorBuffer { + public: + XlaTensorBuffer(const void* ptr, size_t expected_size, size_t actual_size, + Allocator* allocator) + : TensorBuffer(const_cast(ptr)), + expected_size_(expected_size), + actual_size_(actual_size), + allocator_(allocator) {} + + ~XlaTensorBuffer() override { + if (data()) { + allocator_->DeallocateRaw(data()); + } + } + + size_t size() const override { return expected_size_; } + + TensorBuffer* root_buffer() override { return this; } + + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_requested_bytes(static_cast(expected_size_)); + proto->set_allocator_name(allocator_->Name()); + proto->set_ptr(reinterpret_cast(data())); + if (allocator_->TracksAllocationSizes()) { + auto ab = static_cast(allocator_->AllocatedSize(data())); + proto->set_allocated_bytes(ab); + int64_t id = allocator_->AllocationId(data()); + if (id > 0) { + proto->set_allocation_id(id); + } + if (RefCountIsOne()) { + proto->set_has_single_reference(true); + } + } + } + + private: + size_t expected_size_; + size_t actual_size_; + Allocator* allocator_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_platform_info.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_platform_info.h new file mode 100644 index 00000000..7c5099f0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_platform_info.h @@ -0,0 +1,172 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ + +#include +#include +#include + +#include "tensorflow/compiler/jit/device_compiler.h" +#include "tensorflow/compiler/jit/pjrt_base_device.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// Holds some information about the platform on which an +// XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of +// abstraction for normal, XLA devices and devices inheriting from +// PjRtBaseDevice. +class XlaPlatformInfo { + public: + XlaPlatformInfo() : device_type_("") {} + XlaPlatformInfo(XlaPlatformInfo&&) = default; + explicit XlaPlatformInfo( + const DeviceType device_type, se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + const PjRtBaseDevice::Metadata* pjrt_device_metadata, + std::shared_ptr device_allocator) + : device_type_(device_type), + platform_id_(platform_id), + xla_device_metadata_(xla_device_metadata), + pjrt_device_metadata_(pjrt_device_metadata), + device_allocator_(device_allocator) {} + + XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; + + bool UseMultipleStreams() const { + return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); + } + + // Non-null only when run on an XLA device. + std::shared_ptr custom_allocator() const { + return device_allocator_; + } + + DeviceType device_type() const { return device_type_; } + + // This is equal to xla_device_metadata()->platform()->id() if + // xla_device_metadata() is not nullptr. + se::Platform::Id platform_id() const { return platform_id_; } + + // This may be null if the op this XlaPlatformInfo is for was not placed on an + // XLA device. + const XlaDevice::Metadata* xla_device_metadata() const { + return xla_device_metadata_; + } + bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } + + const PjRtBaseDevice::Metadata* pjrt_device_metadata() const { + return pjrt_device_metadata_; + } + + private: + DeviceType device_type_; + se::Platform::Id platform_id_; + + // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the + // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the + // XlaLaunch/_XlaCompile/_XlaRun OpKernel. + const XlaDevice::Metadata* xla_device_metadata_; + + // pjrt_device_metadata_ lives in tensorflow::PjRtBaseDevice in which the + // XlaLaunch/XlaCompileOnDemand op is placed and thus does not die before the + // op kernel. + const PjRtBaseDevice::Metadata* pjrt_device_metadata_; + + // If the op associated with this XlaPlatformInfo is placed on an XLA device + // then device_allocator_ is the xla::Backend's memory allocator. If the op + // is placed on a regular CPU or GPU device then device_allocator_ is null. + // The allocator is of unknown provenance; keep it in a shared pointer to + // set an artificial refcount of one. + std::shared_ptr device_allocator_; + + XlaPlatformInfo(const XlaPlatformInfo&) = delete; + void operator=(const XlaPlatformInfo&) = delete; +}; + +// Returns a set containing the device ids contained in visible_device_list or +// nullopt if it is empty. It returns error in case of malformed configuration +// string. 
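A small hedged sketch of the parser declared next; the `"0,2"` string is an arbitrary example of a `visible_device_list` value, and the result type is assumed to be an optional set of device ordinals, matching the comment above.

```cpp
// Hedged sketch: restrict resource allocation to an explicit set of ordinals.
#include <optional>
#include <set>

#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void VisibleDeviceListExample() {
  absl::StatusOr<std::optional<std::set<int>>> ids =
      ParseVisibleDeviceList("0,2");
  if (!ids.ok()) {
    LOG(ERROR) << "malformed visible_device_list: " << ids.status();
    return;
  }
  if (!ids->has_value()) return;  // An empty list means "no restriction".
  for (int ordinal : **ids) {
    VLOG(1) << "allocating resources for ordinal " << ordinal;
  }
}

}  // namespace tensorflow
```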
+absl::StatusOr>> ParseVisibleDeviceList( + absl::string_view visible_device_list); + +// Returns the device type for building a DeviceCompiler from the given platform +// type. +absl::StatusOr GetCompilationDeviceType( + const DeviceType& platform_device_type); + +// Builds a DeviceCompiler that uses xla::LocalClient using `platform_info` and +// `compilation_device_type` (in non-TPU case) and sets *xla_device_compiler to +// point to it. Uses flags from `MarkForCompilationPassFlags` for configuring +// the persistor used in the DeviceCompiler. The platform ID from +// `platform_info` must not be null in CPU case. +absl::Status BuildXlaDeviceCompiler( + DeviceBase* dev, FunctionLibraryRuntime* flr, + const XlaPlatformInfo& platform_info, DeviceType compilation_device_type, + DeviceCompiler** + xla_device_compiler); + +// Fetches a DeviceCompiler from the tfrt_global resource manager (or creates +// one there if not found) that uses xla::PjRtClient using an appropriate +// PjRtClient for `platform_info.device_type()` and sets *pjrt_device_compiler +// to point to it. Also fetches/creates a DeviceCompilationProfiler from/in the +// tfrt_global resource manager for `platform_info.device_type()` and sets +// *profiler to point to it. Uses flags from `MarkForCompilationPassFlags` for +// configuring the persistor used in the DeviceCompiler. Please note that +// non-XLA devices aren't supported yet. This is because: +// 1. PjRtClient doesn't support data transfer for non-XLA devices yet +// 2. Fetching the PjRtClient for non-XLA devices is also not supported yet +absl::Status GetOrCreatePjRtDeviceCompilerAndProfiler( + const OpKernelContext& ctx, const XlaPlatformInfo& platform_info, + FunctionLibraryRuntime* flr, + DeviceCompiler** + pjrt_device_compiler, + DeviceCompilationProfiler** profiler); + +// Same as the above function but takes the resource manager `rm` instead of an +// OpKernelContext. +absl::Status GetOrCreatePjRtDeviceCompilerAndProfiler( + const XlaPlatformInfo& platform_info, ResourceMgr* rm, + FunctionLibraryRuntime* flr, + DeviceCompiler** + pjrt_device_compiler, + DeviceCompilationProfiler** profiler); + +// Returns information about the platform from kernel context. +XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); + +// Obtains persistent cache directory for executables that target a given device +// based off xla flags. If you shouldn't use persistent caching, returns "". +std::string GetPersistentCacheDirectory( + const DeviceType& compilation_device_type); + +// Returns allocator from platform info if non-null, or populate and return a +// pointer to the allocator adapter with allocator from context. +// +// This is necessary because for XLA devices the underlying TF allocator returns +// dummy tensors. +// +// `stream` parameter is nullable when running on host. +std::shared_ptr GetAllocator( + DeviceBase* device, se::Stream* stream, + const XlaPlatformInfo& platform_info); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_tensor.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_tensor.h new file mode 100644 index 00000000..91e06ddf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_tensor.h @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
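Illustrative sketch (not part of the vendored patch): the free functions above are normally used together from inside an XlaLaunch-style kernel, deriving the XlaPlatformInfo from the device, mapping the platform device type to a compilation device type, then building the XLA DeviceCompiler. The sketch assumes the template arguments stripped from this header are DeviceCompiler<xla::LocalExecutable, xla::LocalClient>, and that the TF_ASSIGN_OR_RETURN / TF_RETURN_IF_ERROR status macros are available; BuildCompilerForKernel is a hypothetical helper name.

// Sketch: building a DeviceCompiler for the device a kernel runs on.
absl::Status BuildCompilerForKernel(tensorflow::OpKernelContext* ctx) {
  const tensorflow::XlaPlatformInfo platform_info =
      tensorflow::XlaPlatformInfoFromDevice(ctx->device());

  TF_ASSIGN_OR_RETURN(
      tensorflow::DeviceType compilation_device_type,
      tensorflow::GetCompilationDeviceType(platform_info.device_type()));

  tensorflow::DeviceCompiler<xla::LocalExecutable, xla::LocalClient>*
      xla_device_compiler = nullptr;  // assumed template arguments
  TF_RETURN_IF_ERROR(tensorflow::BuildXlaDeviceCompiler(
      ctx->device(), ctx->function_library(), platform_info,
      compilation_device_type, &xla_device_compiler));

  // ... compile and run, then release the resource reference as appropriate.
  return absl::OkStatus();
}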
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ + +#include + +#include "absl/memory/memory.h" +#include "xla/client/local_client.h" +#include "xla/service/shaped_buffer.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// The implementation of a Tensor for an XlaDevice. All device tensors are +// actually one of these. +// +// To distinguish between "normal" device tensors and XlaTensors, the raw +// pointer data stored in the TensorBuffer is a tagged pointer. +class XlaTensor { + public: + // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast + // fails. + static XlaTensor* FromTensor(const Tensor* tensor); + + // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in + // which case the returned value is shaped_buffer()->root_buffer(), or a + // normal Tensor in which case the returned value is + // {tensor.tensor_data().data(), tensor.tensor_data().size}. + static se::DeviceMemoryBase DeviceMemoryFromTensor(const Tensor& tensor); + + // Assign the internal ShapedBuffer to new memory for the given dtype and + // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it + // is replaced and the managed memory deallocated. + absl::Status AllocateShapedBuffer(DataType dtype, + const xla::Shape& on_device_shape, + xla::LocalClient* client, + int device_ordinal); + + // Some Tensors can have complex on-device shapes, including tuple shapes. To + // manage the memory for these tensors a ShapedBuffer may be required. + + // Return true if this XlaTensor contains a ShapedBuffer. + bool has_shaped_buffer() const { return shaped_buffer_.has_value(); } + // Return the contained ShapedBuffer. + // REQUIRES: has_shaped_buffer() + const xla::ShapedBuffer& shaped_buffer() const { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } + xla::ShapedBuffer& shaped_buffer() { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } + // Mutates the XlaTensor to set the ShapedBuffer. + void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { + shaped_buffer_ = std::move(shaped_buffer); + } + + // Adds synchronization events to 'stream' that wait for this tensor to be + // defined on 'stream'. Does nothing if the tensor is already defined on that + // stream. + void WaitForDefinitionEventOnStream(se::Stream* stream); + + // (Re)sets the definition event of the tensor to 'event', and promises that + // the tensor has already been defined on stream. Removes any previous + // definition event or any previous promises about the tensor being defined on + // streams. + // It is legal to reset the definition event of a tensor when overwriting the + // tensor's value (at which point, it is effectively a new tensor once again.) + void ResetDefinitionEvent(std::shared_ptr event, + se::Stream* stream); + + // Refresh the status of streams_defined_on_. 
Return the first not-OK stream's + // status or OK. + absl::Status RefreshStatusOfStreams(); + + // Convert from a raw pointer to an XlaTensor, removing the pointer tag. + static XlaTensor* FromOpaquePointer(void* ptr); + // Convert to a raw pointer from an XlaTensor, adding the pointer tag. + static void* ToOpaquePointer(XlaTensor* tensor); + + private: + // The optional contained ShapedBuffer. + std::optional shaped_buffer_; + // An optional host tensor value. + std::optional host_tensor_; + // An optional event that is triggered when the tensor's content has been + // defined. If this event is nullptr, it is assumed that the tensor's content + // is always defined. + std::shared_ptr definition_event_; + // A list of all streams for which the tensor's content is defined for any + // newly enqueued command. + absl::InlinedVector streams_defined_on_ TF_GUARDED_BY(mu_); + mutex mu_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_tpu_device.h b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_tpu_device.h new file mode 100644 index 00000000..bb31c65b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/jit/xla_tpu_device.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +void RegisterTpuDeviceToDeviceCopy(); + +void RegisterTpuNodeDevice( + bool tpu_autoclustering, bool tpu_xla_device_failure_closes_chips, + bool tpu_use_substreams_for_cross_tpu_device_transfers); + +void RegisterTpuSystemDevice(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_TPU_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/init_mlir.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/init_mlir.h new file mode 100644 index 00000000..290ef361 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/init_mlir.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
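Illustrative sketch (not part of the vendored patch): a common use of the XlaTensor API above is deciding whether a device-resident Tensor is backed by a ShapedBuffer (possibly with a tuple shape) or by plain device memory. The sketch uses only the accessors declared in xla_tensor.h; the helper name RootDeviceBuffer is hypothetical.

// Sketch: obtain the root device buffer backing a device-resident Tensor.
se::DeviceMemoryBase RootDeviceBuffer(const tensorflow::Tensor& tensor) {
  tensorflow::XlaTensor* xla_tensor =
      tensorflow::XlaTensor::FromTensor(&tensor);
  if (xla_tensor != nullptr && xla_tensor->has_shaped_buffer()) {
    // XLA-managed tensor: the memory lives in a (possibly tuple) ShapedBuffer.
    return xla_tensor->shaped_buffer().root_buffer();
  }
  // Plain device tensor: the raw tensor data is the device memory.
  return tensorflow::XlaTensor::DeviceMemoryFromTensor(tensor);
}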
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ + +namespace tensorflow { + +// Initializer to perform TF's InitMain initialization. +// InitMain also performs flag parsing and '--' is used to separate flags passed +// to it: Flags before the first '--' are parsed by InitMain and argc and argv +// progressed to the flags post. If there is no separator, then no flags are +// parsed by InitMain and argc/argv left unadjusted. +class InitMlir { + public: + InitMlir(int *argc, char ***argv); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/allocation.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/allocation.h new file mode 100644 index 00000000..a82d8c04 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/allocation.h @@ -0,0 +1,158 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Memory management for TF Lite. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ + +#include + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite { + +/// A memory allocation handle. This could be a mmap or shared memory. +class Allocation { + public: + using Ptr = std::unique_ptr; + + virtual ~Allocation() {} + + enum class Type { + kMMap, + kFileCopy, + kMemory, + }; + + /// Base pointer of this allocation + virtual const void* base() const = 0; + /// Size in bytes of the allocation + virtual size_t bytes() const = 0; + /// Whether the allocation is valid + virtual bool valid() const = 0; + /// Return the type of the Allocation. + Type type() const { return type_; } + + protected: + Allocation(ErrorReporter* error_reporter, Type type) + : error_reporter_(error_reporter), type_(type) {} + ErrorReporter* error_reporter_; + + private: + const Type type_; +}; + +/// Note that not all platforms support MMAP-based allocation. +/// Use `IsSupported()` to check. +class MMAPAllocation : public Allocation { + public: + /// Loads and maps the provided file to a memory region. + MMAPAllocation(const char* filename, ErrorReporter* error_reporter); + + /// Maps the provided file descriptor to a memory region. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. + MMAPAllocation(int fd, ErrorReporter* error_reporter); + + /// Maps the provided file descriptor, with the given offset and length (both + /// in bytes), to a memory region. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. 
+ MMAPAllocation(int fd, size_t offset, size_t length, + ErrorReporter* error_reporter); + + ~MMAPAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + int fd() const { return mmap_fd_; } + + // The start address of the mmapped buffer. + // This will be base() rounded down to the nearest page boundary. + const void* mmapped_buffer() const { return mmapped_buffer_; } + + // The size of the mmapped buffer. + size_t mmapped_buffer_size() const { return bytes() + offset_in_buffer_; } + + // Offset of mmapped_buffer() in the file referenced by the file descriptor. + size_t mmapped_buffer_offset_in_file() const { + return offset_of_buffer_in_file_; + } + + static bool IsSupported(); + + protected: + // Data required for mmap. + int mmap_fd_ = -1; // mmap file descriptor + const void* mmapped_buffer_; + size_t buffer_size_bytes_ = 0; + // Used when the address to mmap is not page-aligned. + size_t offset_in_buffer_ = 0; + size_t offset_of_buffer_in_file_ = 0; + + private: + // Assumes ownership of the provided `owned_fd` instance. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd); + + // Assumes ownership of the provided `owned_fd` instance, and uses the given + // offset and length (both in bytes) for memory mapping. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, size_t offset, + size_t length); +}; + +class FileCopyAllocation : public Allocation { + public: + /// Loads the provided file into a heap memory region. + FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); + ~FileCopyAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + std::unique_ptr copied_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class MemoryAllocation : public Allocation { + public: + /// Provides a (read-only) view of the provided buffer region as an + /// allocation. + /// Note: The caller retains ownership of `ptr`, and must ensure it remains + /// valid for the lifetime of the class instance. + MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter); + ~MemoryAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + const void* buffer_; +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + void* aligned_ptr_ = nullptr; +#endif + size_t buffer_size_bytes_ = 0; +}; + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h new file mode 100644 index 00000000..db9715e9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -0,0 +1,148 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
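Illustrative sketch (not part of the vendored patch): the three Allocation subclasses above are alternative ways of getting model bytes into memory, and a loader typically prefers mmap when the platform supports it and falls back to a heap copy otherwise. The helper name LoadModelFile is hypothetical.

#include <memory>

// Sketch: choose the cheapest Allocation the platform supports.
std::unique_ptr<tflite::Allocation> LoadModelFile(
    const char* filename, tflite::ErrorReporter* reporter) {
  std::unique_ptr<tflite::Allocation> allocation;
  if (tflite::MMAPAllocation::IsSupported()) {
    allocation = std::make_unique<tflite::MMAPAllocation>(filename, reporter);
  } else {
    allocation =
        std::make_unique<tflite::FileCopyAllocation>(filename, reporter);
  }
  if (allocation == nullptr || !allocation->valid()) {
    reporter->Report("Failed to load model file %s", filename);
    return nullptr;
  }
  return allocation;
}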
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ + +#include +#include + +#include "absl/strings/str_join.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" + +namespace mlir { +namespace TFL { + +// A config that controls which passes get run as part TFLite converter. +struct PassConfig { + explicit PassConfig(quant::QuantizationSpecs specs) + : quant_specs(std::move(specs)) {} + + // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be + // added, which produces TF Lite ops. + bool emit_builtin_tflite_ops = true; + // If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic + // TF ops before legalization to TF Lite dialect. + bool lower_tensor_list_ops = false; + // The allowlist of functions that would be preserved after trimming. + llvm::ArrayRef trim_functions_allowlist; + // All information about quantization. + quant::QuantizationSpecs quant_specs; + // If `form_clusters` is true , clusters are formed by grouping consecutive + // ops of the same device, under a `tf_device.launch` op. + bool form_clusters = false; + // If `unfold_batch_matmul` is true, the tf.BatchMatMul is unfolded to a set + // of tfl.fully_connected ops. + bool unfold_batch_matmul = true; + // Whether to outline WhileOp at the end of the pipeline. + bool outline_tf_while = false; + // Whether to do shape inference. + bool shape_inference = true; + // Whether to do TFLite runtime verification. + bool runtime_verification = true; + // Whether to enable TFLite variables or not, this will allow + // mutable variables and produce ReadVariable/AssignVariable ops in TFLite. + bool enable_tflite_variables = false; + // Whether to unfold large splat constant tensors and replace them with + // fill operation. + bool unfold_large_splat_constant = false; + // Whether to run the `GuaranteeAllFuncsOneUsePass` to ensure each function + // has a single use. + bool guarantee_all_funcs_one_use = false; + // Whether to enable the hlo/stablehlo to tf conversion. This also supports + // the case where a saved model contains both TF module and serialized + // StableHLO module. + bool enable_hlo_to_tf_conversion = false; + // Whether to disable the direct hlo/stablehlo to Tensorflow Lite conversion. + // + // This prevents from directly converting from HLO to TFLite without going + // through TF for some of the ops. Some conversions are only supported through + // this path. + bool disable_hlo_to_tfl_conversion = false; + // Whether to enable to use DynamicUpdateSlice op. + bool enable_dynamic_update_slice = false; + // Whether to preserve AssertOp during legalization. + bool preserve_assert_op = false; + // Whether to enable TF->stablehlo passes. + bool enable_stablehlo_conversion = false; + // Whether to convert `tf.TensorList*` to `tfl.custom_op` if they can all + // be supported. + bool legalize_custom_tensor_list_ops = false; + // Whether to convert some tensor types to a lower precision if all values + // within that tensor are within the range of the lower precision. This could + // have side effects e.g. reduced flatbuffer size. Only certain type + // conversions are supported. 
+ bool reduce_type_precision = false; + // Whether to consider this model a quantized model with quantize/dequantize + // ops and to convert kernels to quantized kernels wherever appropriate. + quant::QDQConversionMode qdq_conversion_mode = + quant::QDQConversionMode::kQDQNone; + + // When set to true, StableHLO Quantizer is run. The full configuration for + // the quantizer is at `ConverterFlags::quantization_config`. + bool enable_stablehlo_quantizer = false; + + // Enables the attempt to directly lower composites into tflite ops. + bool enable_composite_direct_lowering = true; + + // Specifies the framework of the original model. + tflite::ConverterFlags::ModelOriginFramework model_origin_framework = + tflite::ConverterFlags::UNSET; + + // When set to true, convert +Inf/-Inf to MIN/MAX float value and output of + // convert only contains finite values. + bool canonicalizing_inf_as_min_max_float = true; +}; + +inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const PassConfig& pass_config) { + return os << "emit_builtin_tflite_ops: " + << pass_config.emit_builtin_tflite_ops + << "\nlower_tensor_list_ops: " << pass_config.lower_tensor_list_ops + << "\ntrim_functions_allowlist: " + << absl::StrJoin(pass_config.trim_functions_allowlist.vec(), ",") + << "\nform_clusters: " << pass_config.form_clusters + << "\nunfold_batch_matmul: " << pass_config.unfold_batch_matmul + << "\noutline_tf_while: " << pass_config.outline_tf_while + << "\nshape_inference: " << pass_config.shape_inference + << "\nruntime_verification: " << pass_config.runtime_verification + << "\nenable_tflite_variables: " + << pass_config.enable_tflite_variables + << "\nunfold_large_splat_constant: " + << pass_config.unfold_large_splat_constant + << "\nguarantee_all_funcs_one_use: " + << pass_config.guarantee_all_funcs_one_use + << "\nenable_hlo_to_tf_conversion: " + << pass_config.enable_hlo_to_tf_conversion + << "\nenable_stablehlo_conversion: " + << pass_config.enable_stablehlo_conversion + << "\nlegalize_custom_tensor_list_ops: " + << pass_config.legalize_custom_tensor_list_ops + << "\nreduce_type_precision: " << pass_config.reduce_type_precision + << "\nconvert_qdq_format: " + << GetQDQQuantModeString(pass_config.qdq_conversion_mode) + << "\nmodel_origin_framework: " + << tflite::ConverterFlags::ModelOriginFramework_Name( + pass_config.model_origin_framework) + << "\n"; +} + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h new file mode 100644 index 00000000..c3d76e2b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
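Illustrative sketch (not part of the vendored patch): PassConfig is a plain struct, so a converter driver builds one from a QuantizationSpecs value and flips individual flags before handing it to the pass pipeline. The sketch assumes the specs type resolves to mlir::quant::QuantizationSpecs; ConfigureAndDump is a hypothetical helper name.

// Sketch: a typical converter configuration.
void ConfigureAndDump() {
  mlir::quant::QuantizationSpecs specs;      // assumed fully-qualified name
  mlir::TFL::PassConfig pass_config(specs);
  pass_config.lower_tensor_list_ops = true;  // lower tensorlist ops to TF ops
  pass_config.outline_tf_while = true;       // outline WhileOp at the end
  pass_config.unfold_batch_matmul = false;   // keep BatchMatMul as-is
  llvm::errs() << pass_config;               // uses the operator<< above
}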
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_ + +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h" + +namespace mlir::TFL { + +// An error reporter that uses absl logging. +class AbslErrorReporter : public tflite::ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +tflite::ErrorReporter* GetAbslErrorReporter(); + +class FlatBufferModelAbslError + : public tflite::impl::FlatBufferModelBase { + public: + // Use stderr_reporter as the default error reporter. + static tflite::ErrorReporter* GetDefaultErrorReporter() { + return GetAbslErrorReporter(); + } + + // Inherit all constructors from FlatBufferModelBase since inherited factory + // methods refer to them. + using FlatBufferModelBase::FlatBufferModelBase; +}; + +} // namespace mlir::TFL + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/error_reporter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/error_reporter.h new file mode 100644 index 00000000..79c9fc93 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/error_reporter.h @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ + +#include + +namespace tflite { + +/// A functor that reports error to supporting system. Invoked similar to +/// printf. +/// +/// Usage: +/// ErrorReporter foo; +/// foo.Report("test %d", 5); +/// or +/// va_list args; +/// foo.Report("test %d", args); // where args is va_list +/// +/// Subclass ErrorReporter to provide another reporting destination. +/// For example, if you have a GUI program, you might redirect to a buffer +/// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter() = default; + /// Converts `args` to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + virtual int Report(const char* format, va_list args) = 0; + + /// Converts arguments to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + int Report(const char* format, ...); + + /// Equivalent to `Report` above. The additional `void*` parameter is unused. 
+ /// This method is for compatibility with macros that takes `TfLiteContext`, + /// like TF_LITE_ENSURE and related macros. + int ReportError(void*, const char* format, ...); +}; + +} // namespace tflite + +// You should not make bare calls to the error reporter, instead use the +// TF_LITE_REPORT_ERROR macro, since this allows message strings to be +// stripped when the binary size has to be optimized. If you are looking to +// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and +// every call will be stubbed out, taking no memory. +#ifndef TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) \ + do { \ + static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \ + } while (false) +#else // TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) +#endif // TF_LITE_STRIP_ERROR_STRINGS + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h new file mode 100644 index 00000000..ed452c90 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h @@ -0,0 +1,428 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +// The namespace tflite_file is for the data structures that define the .tflite +// file format, and code that is tightly coupled with those data structures. +// The .tflite file format is the serialized flatbuffer representation of +// computations on tensors that TF Lite uses for distribution of compiled ML +// models. +namespace tflite_file { + +// This namespace contains functions that transform code and data structures +// that are defined in the flatbuffer serialization format into +// in-memory values that are used by the runtime API, interpreter and compiler. +namespace flatbuffer_conversions { + +using tflite::Operator; + +// Interface class for builtin data allocations. +class BuiltinDataAllocator { + public: + virtual void* Allocate(size_t size, size_t alignment_hint) = 0; + virtual void Deallocate(void* data) = 0; + + // Allocate a structure, but make sure it is a POD structure that doesn't + // require constructors to run. The reason we do this, is that Interpreter's C + // extension part will take ownership so destructors will not be run during + // deallocation. 
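Illustrative sketch (not part of the vendored patch): as the comments above describe, error reporting goes through the TF_LITE_REPORT_ERROR macro so that message strings can be stripped from size-sensitive builds, and alternative destinations are provided by subclassing ErrorReporter. Below is a minimal subclass that accumulates messages in memory; the class name BufferErrorReporter is hypothetical.

#include <cstdarg>
#include <cstdio>
#include <string>

// Sketch: an ErrorReporter that appends formatted messages to a string.
class BufferErrorReporter : public tflite::ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    char buffer[512];
    int written = std::vsnprintf(buffer, sizeof(buffer), format, args);
    if (written > 0) log_.append(buffer);
    return written;  // negative on formatting error, per the contract above
  }
  const std::string& log() const { return log_; }

 private:
  std::string log_;
};

// Usage (compiles away if TF_LITE_STRIP_ERROR_STRINGS is defined):
//   BufferErrorReporter reporter;
//   TF_LITE_REPORT_ERROR(&reporter, "tensor %d has an unexpected type", 3);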
+ template + T* AllocatePOD() { + // TODO(b/154346074): Change this to is_trivially_destructible when all + // platform targets support that properly. + static_assert(std::is_pod::value, "Builtin data structure must be POD."); + void* allocated_memory = this->Allocate(sizeof(T), alignof(T)); + return new (allocated_memory) T(); + } + + virtual ~BuiltinDataAllocator() = default; +}; + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The +// calling function has to pass in an allocator object, and this allocator +// will be called to reserve space for the output data. If the calling +// function's allocator reserves memory on the heap, then it's the calling +// function's responsibility to free it. +// If it returns kTfLiteError, `builtin_data` will be `nullptr`. +absl::Status ParseOpData(const tflite::Operator* op, + tflite::BuiltinOperator op_type, + BuiltinDataAllocator* allocator, void** builtin_data); + +// Converts the tensor data type used in the flat buffer to the representation +// used by the runtime. +absl::Status ConvertTensorType(tflite::TensorType tensor_type, + TfLiteType* type); + +absl::Status ParseAbs(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseAdd(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseAddN(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseArgMax(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseArgMin(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseAssignVariable(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBatchMatMul(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBatchToSpaceNd(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBroadcastArgs(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBroadcastTo(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCallOnce(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCeil(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCast(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseConcatenation(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseConv2D(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCos(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseCumsum(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDepthToSpace(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDepthwiseConv2D(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDequantize(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseDiv(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseElu(const 
Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseEmbeddingLookup(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseEqual(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseExp(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseExpandDims(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFill(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFloor(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFloorDiv(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFloorMod(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseFullyConnected(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGather(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGatherNd(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGreater(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseGreaterEqual(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseHardSwish(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseIf(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseL2Normalization(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLeakyRelu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLess(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLessEqual(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLog(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogicalAnd(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogicalNot(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogicalOr(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogistic(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLogSoftmax(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseLSTM(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMaximum(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMinimum(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMirrorPad(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseMul(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseNeg(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseNotEqual(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePack(const Operator* op, 
BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePad(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePadV2(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePool(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePow(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParsePrelu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseQuantize(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseReadVariable(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseReducer(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRelu(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRelu6(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseReshape(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseResizeBilinear(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseResizeNearestNeighbor(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRound(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRsqrt(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSelectV2(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseShape(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSin(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSlice(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSoftmax(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSpaceToBatchNd(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSpaceToDepth(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSplit(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSplitV(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); +absl::Status ParseSqueeze(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSqrt(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); +absl::Status ParseSquare(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSquaredDifference(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStridedSlice(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSub(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseSvdf(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseTanh(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseTranspose(const Operator* op, BuiltinDataAllocator* allocator, + 
void** builtin_data); + +absl::Status ParseTransposeConv(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseUnpack(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseUnidirectionalSequenceLSTM(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseVarHandle(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseWhile(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseZerosLike(const Operator* op, BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseBitwiseXor(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseRightShift(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloScatter(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloRngBitGenerator(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloGather(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloReduceWindow(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloPad(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloComposite(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloShiftLeft(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +absl::Status ParseStablehloCase(const Operator* op, + BuiltinDataAllocator* allocator, + void** builtin_data); + +} // namespace flatbuffer_conversions +} // namespace tflite_file + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/verifier.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/verifier.h new file mode 100644 index 00000000..2e24347d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/api/verifier.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Abstract interface for verifying a model. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite { + +/// Abstract interface that verifies whether a given model is legit. +/// It facilitates the use-case to verify and build a model without loading it +/// twice. +/// (See also "tensorflow/lite/tools/verifier.h".) 
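Illustrative sketch (not part of the vendored patch): ParseOpData and the per-op parsers above all write their result through an allocator supplied by the caller, which keeps the parsed builtin data out of the parser's ownership; per the contract above, the caller deallocates `builtin_data` afterwards. A minimal malloc-backed allocator is sketched below; the class name MallocDataAllocator is hypothetical.

#include <cstdlib>

// Sketch: hand ParseOpData a trivial heap-backed allocator.
class MallocDataAllocator
    : public tflite_file::flatbuffer_conversions::BuiltinDataAllocator {
 public:
  void* Allocate(size_t size, size_t alignment_hint) override {
    (void)alignment_hint;  // assume malloc's alignment is sufficient here
    return std::malloc(size);
  }
  void Deallocate(void* data) override { std::free(data); }
};

// Usage, given a `const tflite::Operator* op` from the flatbuffer:
//   MallocDataAllocator allocator;
//   void* builtin_data = nullptr;
//   absl::Status status = tflite_file::flatbuffer_conversions::ParseOpData(
//       op, tflite::BuiltinOperator_CONV_2D, &allocator, &builtin_data);
//   // ... use builtin_data, then allocator.Deallocate(builtin_data);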
+class TfLiteVerifier { + public: + /// Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h new file mode 100644 index 00000000..1327162f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h @@ -0,0 +1,670 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// WARNING: Users of TensorFlow Lite should not include this file directly, +/// but should instead include +/// "third_party/tensorflow/lite/c/builtin_op_data.h". +/// Only the TensorFlow Lite implementation itself should include this +/// file directly. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ + +#include // IWYU pragma: keep +#include +#include + +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TfLiteReshapeParams can't have dynamic data so we fix the maximum possible +// number of dimensions. +#define TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT 8 +#define TFLITE_STABLEHLO_CASE_PARAMS_MAX_BRANCHES_COUNT 20 + +// TODO(aselle): Consider using "if this then that" for testing. + +// Useful placeholder to put in otherwise empty structs to avoid size warnings. +typedef struct { + char dummy; +} EmptyStructPlaceholder; + +// IMPORTANT: All new members of structs must be added at the end to ensure +// backwards compatibility. + +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef enum { + kTfLiteMirrorPaddingUnknown = 0, + kTfLiteMirrorPaddingReflect, + kTfLiteMirrorPaddingSymmetric, +} TfLiteMirrorPaddingMode; + +// TODO(b/130259536): We should move this out of builtin_op_data. +typedef struct { + int width; + int height; + int width_offset; + int height_offset; +} TfLitePaddingValues; + +typedef struct { + TfLiteMirrorPaddingMode mode; +} TfLiteMirrorPaddingParams; + +// Possible fused activation functions. +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActReluN1To1, // min(max(-1, x), 1) + kTfLiteActRelu6, // min(max(0, x), 6) + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + // Parameters for CONV_2D version 1. 
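Illustrative sketch (not part of the vendored patch): TfLiteVerifier lets a loader validate raw model bytes before building a model from them. The sketch below only checks the flatbuffer file identifier; it assumes the standard "TFL3" identifier stored at the usual 4-byte flatbuffers offset, and the class name MinimalVerifier is hypothetical.

#include <cstring>

// Sketch: reject buffers that cannot be TFLite flatbuffers at all.
class MinimalVerifier : public tflite::TfLiteVerifier {
 public:
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    // Assumption: the "TFL3" file identifier sits at bytes 4..7 of the buffer.
    if (data == nullptr || length < 8 ||
        std::memcmp(data + 4, "TFL3", 4) != 0) {
      TF_LITE_REPORT_ERROR(reporter, "Buffer is not a TFLite flatbuffer");
      return false;
    }
    return true;
  }
};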
+ TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters for CONV_2D version 2. + // Note: Version 2 supports dilation values not equal to 1. + int dilation_width_factor; + int dilation_height_factor; + + // Parameters for CONV_2D version 7 or above. + // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int stride_depth; + int dilation_width_factor; + int dilation_height_factor; + int dilation_depth_factor; + TfLiteFusedActivation activation; +} TfLiteConv3DParams; + +typedef TfLiteConv3DParams TfLiteConv3DTransposeParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + // Parameters for DepthwiseConv version 1 or above. + TfLitePadding padding; + int stride_width; + int stride_height; + // `depth_multiplier` is redundant. It's used by CPU kernels in + // TensorFlow 2.0 or below, but ignored in versions above. + // + // The information can be deduced from the shape of input and the shape of + // weights. Since the TFLiteConverter toolchain doesn't support partially + // specified shapes, relying on `depth_multiplier` stops us from supporting + // graphs with dynamic shape tensors. + // + // Note: Some of the delegates (e.g. NNAPI, GPU) are still relying on this + // field. + int depth_multiplier; + TfLiteFusedActivation activation; + // Parameters for DepthwiseConv version 2 or above. + int dilation_width_factor; + int dilation_height_factor; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; + + // Parameter for SVDF version 4. + bool asymmetric_quantize_inputs; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; + + // Parameter for RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + + // Parameter for Sequence RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteSequenceRNNParams; + +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + bool merge_outputs; + + // Parameter for Bidirectional RNN version 3. + bool asymmetric_quantize_inputs; +} TfLiteBidirectionalSequenceRNNParams; + +typedef enum { + kTfLiteFullyConnectedWeightsFormatDefault = 0, + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, +} TfLiteFullyConnectedWeightsFormat; + +typedef struct { + // Parameters for FullyConnected version 1 or above. + TfLiteFusedActivation activation; + + // Parameters for FullyConnected version 2 or above. + TfLiteFullyConnectedWeightsFormat weights_format; + + // Parameters for FullyConnected version 5 or above. + // If set to true, then the number of dimensions in the input and the output + // tensors are the same. Furthermore, all but the last dimension of the input + // and output shapes will be equal. + bool keep_num_dims; + + // Parameters for FullyConnected version 7 or above. + // If set to true and the weights are quantized, then non constant inputs + // are quantized at evaluation time with asymmetric quantization. + bool asymmetric_quantize_inputs; + + // Parameters for FullyConnected version 10 or above. 
+ // Used to determine the default value for the quantized bias. + TfLiteType quantized_bias_type; +} TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; + +typedef struct { + float beta; +} TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; + // Parameter added for the version 4. + bool pot_scale_int16; +} TfLiteAddParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteSpaceToBatchNDParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteBatchToSpaceNDParams; + +typedef struct { + bool adj_x; + bool adj_y; + // Parameters for BatchMatMul version 4 or above. + // If set to true and the weights are quantized, then non constant inputs + // are quantized at evaluation time with asymmetric quantization. + bool asymmetric_quantize_inputs; +} TfLiteBatchMatMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; + // Parameter added for the version 5. + bool pot_scale_int16; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + +typedef struct { + // Parameters for LSTM version 1. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; + + // Parameters for LSTM version 4. + bool asymmetric_quantize_inputs; +} TfLiteLSTMParams; + +typedef struct { + // Parameters needed for the underlying LSTM. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If set to true then the first dimension is time, otherwise batch. + bool time_major; + + // Parameter for unidirectional sequence RNN version 3. + bool asymmetric_quantize_inputs; + + // Parameter for unidirectional sequence RNN version 4. + bool diagonal_recurrent_tensors; +} TfLiteUnidirectionalSequenceLSTMParams; + +typedef struct { + // Parameters supported by version 1: + // Parameters inherited for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If true, store the outputs of both directions in the first output. + bool merge_outputs; + + // Parameters supported by version 2: + // If set to true then the first dimension is time, otherwise batch. + bool time_major; + + // Parameters supported by version 3: + // If set to true, then hybrid ops use asymmetric quantization for inputs. + bool asymmetric_quantize_inputs; +} TfLiteBidirectionalSequenceLSTMParams; + +typedef struct { + bool align_corners; + // half_pixel_centers assumes pixels are of half the actual dimensions, and + // yields more accurate resizes. Corresponds to the same argument for the + // original TensorFlow op in TF2.0. 
+ bool half_pixel_centers; +} TfLiteResizeBilinearParams; + +typedef struct { + bool align_corners; + bool half_pixel_centers; +} TfLiteResizeNearestNeighborParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLitePadParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLitePadV2Params; + +typedef struct { + // These fields are only used in old models for backward compatibility. + // In the current implementation, we use the 2nd input of the op as the shape, + // and these fields are unused. + int32_t shape[TFLITE_RESHAPE_PARAMS_MAX_DIMENSION_COUNT]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef struct { + int block_size; +} TfLiteDepthToSpaceParams; + +typedef struct { + TfLiteType in_data_type; + TfLiteType out_data_type; +} TfLiteCastParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +typedef struct { + int axis; + int batch_dims; +} TfLiteGatherParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteTransposeParams; + +typedef struct { + bool keep_dims; +} TfLiteReducerParams; + +typedef struct { + int num_splits; +} TfLiteSplitParams; + +typedef struct { + int num_splits; +} TfLiteSplitVParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int32_t squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; + + // Parameters supported by version 8: + // If true, then the end tensor is an offset of the begin tensor. + bool offset; +} TfLiteStridedSliceParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMaxParams; + +typedef struct { + TfLiteType output_type; +} TfLiteArgMinParams; + +typedef struct { + // Parameters supported by version 1: + TfLitePadding padding; + int stride_width; + int stride_height; + + // Parameters supported by version 4: + TfLiteFusedActivation activation; + + // Parameters for TransposeConv version 5 or above. + // Used to determine the default value for the quantized bias. 
+ TfLiteType quantized_bias_type; +} TfLiteTransposeConvParams; + +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + +typedef struct { + TfLiteType out_type; +} TfLiteShapeParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteRankParams; + +typedef struct { + // Parameters supported by version 1: + float min; + float max; + int num_bits; + + // Parameters supported by version 2: + bool narrow_range; +} TfLiteFakeQuantParams; + +typedef struct { + int values_count; + int axis; +} TfLitePackParams; + +typedef struct { + int axis; +} TfLiteOneHotParams; + +typedef struct { + int num; + int axis; +} TfLiteUnpackParams; + +typedef struct { + float alpha; +} TfLiteLeakyReluParams; + +typedef struct { + TfLiteType index_out_type; +} TfLiteUniqueParams; + +typedef struct { + int seq_dim; + int batch_dim; +} TfLiteReverseSequenceParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixDiagParams; + +typedef struct { + EmptyStructPlaceholder placeholder; +} TfLiteMatrixSetDiagParams; + +typedef struct { + int then_subgraph_index; + int else_subgraph_index; +} TfLiteIfParams; + +typedef struct { + int cond_subgraph_index; + int body_subgraph_index; +} TfLiteWhileParams; + +typedef struct { + bool exclusive; + bool reverse; +} TfLiteCumsumParams; + +typedef struct { + int init_subgraph_index; +} TfLiteCallOnceParams; + +typedef struct { + int table_id; + TfLiteType key_dtype; + TfLiteType value_dtype; +} TfLiteHashtableParams; + +typedef struct { + const char* container; + const char* shared_name; +} TfLiteVarHandleParams; + +typedef struct { + int seed; + int seed2; +} TfLiteRandomParams; + +typedef struct { + int num_boundaries; + // This points to the memory stored in the model (flatbuffer), + // and is not owned. + const float* boundaries; +} TfLiteBucketizeParams; + +typedef struct { + bool approximate; +} TfLiteGeluParams; + +typedef struct { + int64_t dimension; +} TfLiteStablehloConcatenateParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter + bool indices_are_sorted; + int64_t + update_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_update_window_dims; + int64_t + inserted_window_dims[TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_inserted_window_dims; + int64_t scatter_dims_to_operand_dims + [TFLITE_STABLEHLO_SCATTER_PARAMS_MAX_DIMENSION_COUNT]; + int num_scatter_dims_to_operand_dims; + int64_t index_vector_dim; + bool unique_indices; + int update_computation_subgraph_index; +} TfLiteStablehloScatterParams; + +typedef enum { + kTfLiteRngAlgorithmUnknown = 0, + // An algorithm auto-selected by the system according to device type. 
+ kTfLiteRngAlgorithmDefault, + // The Philox algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + kTfLiteRngAlgorithmPhilox, + // The ThreeFry algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + kTfLiteRngAlgorithmThreefry, +} TfLiteRngAlgorithm; + +typedef struct { + TfLiteRngAlgorithm algorithm; +} TfLiteStablehloRngBitGeneratorParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#gather + int64_t offset_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_offset_dims; + int64_t + collapsed_slice_dims[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_collapsed_slice_dims; + int64_t start_index_map[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_start_index_map; + int64_t index_vector_dim; + int64_t slice_sizes[TFLITE_STABLEHLO_GATHER_PARAMS_MAX_DIMENSION_COUNT]; + int num_slice_sizes; + bool indices_are_sorted; +} TfLiteStablehloGatherParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#reduce_window + int64_t window_dimensions + [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + window_strides[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + base_dilations[TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t window_dilations + [TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int64_t + padding[2 * TFLITE_STABLEHLO_REDUCE_WINDOW_PARAMS_MAX_DIMENSION_COUNT]; + int body_subgraph_index; +} TfLiteStablehloReduceWindowParams; + +enum TfLiteReduceWindowFunction { + TfLiteReduceWindowFunctionUnsupported, + TfLiteReduceWindowFunctionAdd, + TfLiteReduceWindowFunctionMul, + TfLiteReduceWindowFunctionMin, + TfLiteReduceWindowFunctionMax, + TfLiteReduceWindowFunctionAll, + TfLiteReduceWindowFunctionAny +}; + +typedef struct { + enum TfLiteReduceWindowFunction reduce_function; +} TfLiteReduceWindowParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#pad + int64_t edge_padding_low[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t edge_padding_high[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; + int64_t interior_padding[TFLITE_STABLEHLO_PAD_PARAMS_MAX_DIMENSION_COUNT]; +} TfLiteStablehloPadParams; + +typedef struct { + const char* name; + int32_t subgraph_index; + int32_t version; + const uint8_t* attributes; + size_t attributes_size; +} TfLiteStablehloCompositeParams; + +typedef struct { + // See the stablehlo spec for the explanation of the attributes: + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case + int32_t + branch_subgraph_indices[TFLITE_STABLEHLO_CASE_PARAMS_MAX_BRANCHES_COUNT]; + uint32_t num_branches; +} TfLiteStablehloCaseParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/c/tflite_types.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/c/tflite_types.h new file mode 100644 index 00000000..068facb1 --- /dev/null +++ 
b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/c/tflite_types.h @@ -0,0 +1,90 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file hosts data structures that are needed both for LiteRT and +// Compiler. + +// WARNING: Users of TensorFlow Lite should not include this file directly, but +// should instead include "third_party/tensorflow/lite/c/c_api_types.h". +// Only the TensorFlow Lite implementation itself should include this file +// directly. + +// clang-format off +// NOLINTBEGIN(whitespace/line_length) +/// \note Users of TensorFlow Lite should use +/// \code +/// #include "tensorflow/lite/c/c_api_types.h" +/// \endcode +/// to access the APIs documented on this page. +// NOLINTEND(whitespace/line_length) +// clang-format on + +// IWYU pragma: private, include "third_party/tensorflow/lite/c/c_api_types.h" + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Types supported by tensor +// LINT.IfChange +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, + kTfLiteBool = 6, + kTfLiteInt16 = 7, + kTfLiteComplex64 = 8, + kTfLiteInt8 = 9, + kTfLiteFloat16 = 10, + kTfLiteFloat64 = 11, + kTfLiteComplex128 = 12, + kTfLiteUInt64 = 13, + kTfLiteResource = 14, + kTfLiteVariant = 15, + kTfLiteUInt32 = 16, + kTfLiteUInt16 = 17, + kTfLiteInt4 = 18, + kTfLiteBFloat16 = 19, +} TfLiteType; +// LINT.ThenChange(//tensorflow/lite/profiling/proto/model_runtime_info.proto:EdgeDataType) + +/// Legacy. Will be deprecated in favor of `TfLiteAffineQuantization`. +/// If per-layer quantization is specified this field will still be populated in +/// addition to `TfLiteAffineQuantization`. +/// Parameters for asymmetric quantization. Quantized values can be converted +/// back to float using: `real_value = scale * (quantized_value - zero_point)` +typedef struct TfLiteQuantizationParams { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +/// Storage format of each dimension in a sparse tensor. +typedef enum TfLiteDimensionType { + kTfLiteDimDense = 0, + kTfLiteDimSparseCSR, +} TfLiteDimensionType; + +#ifdef __cplusplus +} // extern C +#endif + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_TFLITE_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/macros.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/macros.h new file mode 100644 index 00000000..c18984d3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/macros.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This provides utility macros and functions that are inherently platform +// specific or shared across runtime & converter. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_MACROS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_MACROS_H_ + +#ifndef TF_LITE_STATIC_MEMORY +// maximum size of a valid flatbuffer +inline constexpr unsigned int flatbuffer_size_max = 2147483648; +// If none zero then the buffer is stored outside of the flatbuffers, string +inline constexpr char tflite_metadata_buffer_location[] = "buffer_location"; +// field for minimum runtime version, string +inline constexpr char tflite_metadata_min_runtime_version[] = + "min_runtime_version"; +// the stablehlo op version is supported by the tflite runtime +inline constexpr char tflite_supported_stablehlo_version[] = "1.0.0"; +#endif + +// LINT.IfChange(TFLITE_NOINLINE) + +#ifdef _WIN32 +#define TFLITE_NOINLINE __declspec(noinline) +#else +#ifdef __has_attribute +#if __has_attribute(noinline) +#define TFLITE_NOINLINE __attribute__((noinline)) +#else +#define TFLITE_NOINLINE +#endif // __has_attribute(noinline) +#else +#define TFLITE_NOINLINE +#endif // __has_attribute +#endif // _WIN32 + +// LINT.ThenChange(//tensorflow/lite/core/macros.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_MACROS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/model_builder_base.h new file mode 100644 index 00000000..e7892cc0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/core/model_builder_base.h @@ -0,0 +1,614 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Deserialization infrastructure for tflite. Provides functionality +/// to go from a serialized tflite model in flatbuffer format to an +/// in-memory representation of the model. +/// +/// WARNING: Users of TensorFlow Lite should not include this file directly, +/// but should instead include "third_party/tensorflow/lite/model_builder.h". +/// Only the TensorFlow Lite implementation itself should include this +/// file directly. 
+// IWYU pragma: private, include "third_party/tensorflow/lite/model_builder.h" + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "flatbuffers/base.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "flatbuffers/verifier.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/verifier.h" +#include "tensorflow/compiler/mlir/lite/core/macros.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace tflite { + +std::unique_ptr GetAllocationFromFile( + const char* filename, ErrorReporter* error_reporter); + +std::unique_ptr GetAllocationFromFile( + int fd, ErrorReporter* error_reporter); + +namespace impl { + +/// An RAII object that represents a read-only tflite model, copied from disk, +/// or mmapped. This uses flatbuffers as the serialization format. +/// +/// NOTE: The current API requires that a FlatBufferModelBase instance be kept +/// alive by the client as long as it is in use by any dependent Interpreter +/// instances. As the FlatBufferModelBase instance is effectively immutable +/// after creation, the client may safely use a single model with multiple +/// dependent Interpreter instances, even across multiple threads (though note +/// that each Interpreter instance is *not* thread-safe). +/// +///
+/// <pre><code>
+/// using namespace tflite;
+/// StderrReporter error_reporter;
+/// auto model = FlatBufferModelBase::BuildFromFile("interesting_model.tflite",
+///                                             &error_reporter);
+/// MyOpResolver resolver;  // You need to subclass OpResolver to provide
+///                         // implementations.
+/// InterpreterBuilder builder(*model, resolver);
+/// std::unique_ptr<Interpreter> interpreter;
+/// if(builder(&interpreter) == kTfLiteOk) {
+///   .. run model inference with interpreter
+/// }
+/// </code></pre>
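FlatBufferModelBase is parameterized on the concrete model class, which is expected to supply GetDefaultErrorReporter() and to let the base's factory methods reach its constructors. The following is a minimal sketch, not part of this header; MyFlatBufferModel and MyStderrReporter are hypothetical names, and the friend/using pattern is an assumption about how the template is meant to be specialized.

#include <cstdarg>
#include <cstdio>

// Hypothetical reporter; ErrorReporter::Report's signature is assumed to be
// int Report(const char* format, va_list args).
class MyStderrReporter : public tflite::ErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    return std::vfprintf(stderr, format, args);
  }
};

class MyFlatBufferModel
    : public tflite::impl::FlatBufferModelBase<MyFlatBufferModel> {
 public:
  // Used by the Build*/Verify* factories when no reporter is passed in.
  static tflite::ErrorReporter* GetDefaultErrorReporter() {
    static MyStderrReporter reporter;
    return &reporter;
  }

 private:
  // The factories call `new MyFlatBufferModel(...)`, so the base needs access
  // to the inherited (protected) constructors.
  friend class tflite::impl::FlatBufferModelBase<MyFlatBufferModel>;
  using FlatBufferModelBase::FlatBufferModelBase;
};

// e.g. auto model = MyFlatBufferModel::BuildFromFile("interesting_model.tflite");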
+/// +/// OpResolver must be defined to provide your kernel implementations to the +/// interpreter. This is environment specific and may consist of just the +/// builtin ops, or some custom operators you defined to extend tflite. +template +class FlatBufferModelBase { + public: + /// Builds a model based on a file. + /// Caller retains ownership of `error_reporter` and must ensure its lifetime + /// is longer than the FlatBufferModelBase instance. + /// Returns a nullptr in case of failure. + static std::unique_ptr BuildFromFile( + const char* filename, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + std::unique_ptr model = BuildFromAllocation( + GetAllocationFromFile(filename, error_reporter), error_reporter); +#if FLATBUFFERS_LITTLEENDIAN == 1 + return model; +#else + return ByteConvertModel(std::move(model), error_reporter); +#endif + } + + /// Verifies whether the content of the file is legit, then builds a model + /// based on the file. + /// The extra_verifier argument is an additional optional verifier for the + /// file contents. By default, we always check with tflite::VerifyModelBuffer. + /// If extra_verifier is supplied, the file contents is also checked against + /// the extra_verifier after the check against tflite::VerifyModelBuilder. + /// Caller retains ownership of `error_reporter` and must ensure its lifetime + /// is longer than the FlatBufferModelBase instance. + /// Returns a nullptr in case of failure. + static std::unique_ptr VerifyAndBuildFromFile( + const char* filename, TfLiteVerifier* extra_verifier = nullptr, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + std::unique_ptr model = VerifyAndBuildFromAllocation( + GetAllocationFromFile(filename, error_reporter), extra_verifier, + error_reporter); +#if FLATBUFFERS_LITTLEENDIAN == 1 + return model; +#else + return ByteConvertModel(std::move(model), error_reporter); +#endif + } + + /// Builds a model based on a file descriptor. + /// Caller retains ownership of `error_reporter` and must ensure its lifetime + /// is longer than the FlatBufferModelBase instance. Caller retains ownership + /// of `fd` and must ensure it is closed after BuildFromFile returns. Returns + /// a nullptr in case of failure. + static std::unique_ptr BuildFromFileDescriptor( + int fd, ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + std::unique_ptr model = BuildFromAllocation( + GetAllocationFromFile(fd, error_reporter), error_reporter); +#if FLATBUFFERS_LITTLEENDIAN == 1 + return model; +#else + return ByteConvertModel(std::move(model), error_reporter); +#endif + } + + /// Verifies whether the content of the file descriptor is legit, then builds + /// a model based on the file. + /// The extra_verifier argument is an additional optional verifier for the + /// file contents. By default, we always check with tflite::VerifyModelBuffer. + /// If extra_verifier is supplied, the file contents is also checked against + /// the extra_verifier after the check against tflite::VerifyModelBuilder. + /// Caller retains ownership of `error_reporter` and must ensure its lifetime + /// is longer than the FlatBufferModelBase instance. + /// Returns a nullptr in case of failure. 
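The extra_verifier hook lets callers layer their own checks on top of tflite::VerifyModelBuffer. A short sketch, not part of this header, follows; the exact TfLiteVerifier::Verify parameter types are an assumption inferred from the call site in VerifyAndBuildFromAllocation further down, and the size cap is an arbitrary illustration value.

// Sketch only; TfLiteVerifier is declared in core/api/verifier.h and the
// Verify() signature below is assumed.
class SizeCappedVerifier : public tflite::TfLiteVerifier {
 public:
  bool Verify(const char* data, int length,
              tflite::ErrorReporter* reporter) override {
    (void)data;
    if (length > 64 * 1024 * 1024) {  // arbitrary 64 MiB cap
      TF_LITE_REPORT_ERROR(reporter, "Rejecting models larger than 64 MiB");
      return false;
    }
    return true;
  }
};

// e.g., paired with the hypothetical MyFlatBufferModel sketched earlier:
//   SizeCappedVerifier verifier;
//   auto model =
//       MyFlatBufferModel::VerifyAndBuildFromFile("model.tflite", &verifier);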
+ static std::unique_ptr VerifyAndBuildFromFileDescriptor( + int fd, TfLiteVerifier* extra_verifier = nullptr, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + std::unique_ptr> model = + VerifyAndBuildFromAllocation(GetAllocationFromFile(fd, error_reporter), + extra_verifier, error_reporter); +#if FLATBUFFERS_LITTLEENDIAN == 1 + return model; +#else + return ByteConvertModel(std::move(model), error_reporter); +#endif + } + + /// Builds a model based on a pre-loaded flatbuffer. + /// Caller retains ownership of the buffer and should keep it alive until + /// the returned object is destroyed. Caller also retains ownership of + /// `error_reporter` and must ensure its lifetime is longer than the + /// FlatBufferModelBase instance. + /// Returns a nullptr in case of failure. + /// NOTE: this does NOT validate the buffer so it should NOT be called on + /// invalid/untrusted input. Use VerifyAndBuildFromBuffer in that case + static std::unique_ptr BuildFromBuffer( + const char* caller_owned_buffer, size_t buffer_size, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + std::unique_ptr allocation( + new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter)); + return BuildFromAllocation(std::move(allocation), error_reporter); + } + + /// Verifies whether the content of the buffer is legit, then builds a model + /// based on the pre-loaded flatbuffer. + /// The extra_verifier argument is an additional optional verifier for the + /// buffer. By default, we always check with tflite::VerifyModelBuffer. If + /// extra_verifier is supplied, the buffer is checked against the + /// extra_verifier after the check against tflite::VerifyModelBuilder. The + /// caller retains ownership of the buffer and should keep it alive until the + /// returned object is destroyed. Caller retains ownership of `error_reporter` + /// and must ensure its lifetime is longer than the FlatBufferModelBase + /// instance. Returns a nullptr in case of failure. + static std::unique_ptr VerifyAndBuildFromBuffer( + const char* caller_owned_buffer, size_t buffer_size, + TfLiteVerifier* extra_verifier = nullptr, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + std::unique_ptr allocation( + new MemoryAllocation(caller_owned_buffer, buffer_size, error_reporter)); + return VerifyAndBuildFromAllocation(std::move(allocation), extra_verifier, + error_reporter); + } + +#if FLATBUFFERS_LITTLEENDIAN == 0 + + void ByteSwapSerializedModel(std::string* serialized_model, + bool from_big_endian) { + const uint8_t* buffer = + reinterpret_cast(serialized_model->c_str()); + const tflite::Model* input_model = tflite::GetModel(buffer); + ByteSwapTFLiteModel(input_model, from_big_endian); + } + + void ByteSwapBuffer(int8_t tensor_type, size_t buffer_size, uint8_t* buffer, + bool from_big_endian) { + switch (tensor_type) { + case tflite::TensorType_STRING: { + auto bp = reinterpret_cast(buffer); + int num_of_strings = + from_big_endian ? 
bp[0] : flatbuffers::EndianSwap(bp[0]); + for (int i = 0; i < num_of_strings + 2; i++) + bp[i] = flatbuffers::EndianSwap(bp[i]); + break; + } + // 16-bit types + case tflite::TensorType_FLOAT16: + case tflite::TensorType_INT16: + case tflite::TensorType_UINT16: { + auto bp = reinterpret_cast(buffer); + for (int i = 0; i < buffer_size / 2; i++) + bp[i] = flatbuffers::EndianSwap(bp[i]); + break; + } + // 32-bit types + case tflite::TensorType_FLOAT32: + case tflite::TensorType_INT32: + case tflite::TensorType_UINT32: + case tflite::TensorType_COMPLEX64: { + auto bp = reinterpret_cast(buffer); + for (int i = 0; i < buffer_size / 4; i++) + bp[i] = flatbuffers::EndianSwap(bp[i]); + break; + } + // 64-bit types + case tflite::TensorType_INT64: + case tflite::TensorType_FLOAT64: + case tflite::TensorType_UINT64: + case tflite::TensorType_COMPLEX128: { + auto bp = reinterpret_cast(buffer); + for (int i = 0; i < buffer_size / 8; i++) + bp[i] = flatbuffers::EndianSwap(bp[i]); + break; + } + default: + break; + } + } + + void ByteSwapTFLiteModel(const tflite::Model* tfl_model, + bool from_big_endian) { + std::vector buffer_swapped(tfl_model->buffers()->size(), false); + for (size_t subgraph_idx = 0; subgraph_idx < tfl_model->subgraphs()->size(); + subgraph_idx++) { + const tflite::SubGraph* subgraph = + tfl_model->subgraphs()->Get(subgraph_idx); + for (size_t ts_idx = 0; ts_idx < subgraph->tensors()->size(); ts_idx++) { + const tflite::Tensor* tensor = subgraph->tensors()->Get(ts_idx); + if (tensor->buffer() > 0 && + tensor->buffer() < tfl_model->buffers()->size() && + !buffer_swapped[tensor->buffer()]) { + const tflite::Buffer* buffer_ = + (*tfl_model->buffers())[tensor->buffer()]; + if (!buffer_ || !buffer_->data()) continue; + auto* buffer = buffer_->data(); + uint8_t* buff_ = const_cast(buffer->data()); + ByteSwapBuffer(tensor->type(), buffer->size(), buff_, + from_big_endian); + buffer_swapped[tensor->buffer()] = true; + } + } + } + } + + std::unique_ptr ByteConvertModel(std::unique_ptr model, + ErrorReporter* error_reporter, + bool from_big_endian) { + if (model == nullptr) return model; + auto tfl_model = model->GetModel(); + if (tfl_model->subgraphs()->size() == 0) return model; + if (tfl_model->subgraphs()->Get(0)->tensors()->size() == 0) return model; + if (tfl_model->buffers()->size() < 2) return model; + return ByteSwapFlatBufferModelBase(std::move(model), error_reporter, + from_big_endian); + } + + std::unique_ptr ByteSwapFlatBufferModelBase(std::unique_ptr model, + ErrorReporter* error_reporter, + bool from_big_endian) { + FlatBufferModelBase* modelp = model.release(); + auto tflite_model = modelp->GetModel(); + auto copied_model = std::make_unique(); + tflite_model->UnPackTo(copied_model.get(), nullptr); + ByteSwapTFLiteModelT(copied_model.get(), from_big_endian); + std::unique_ptr builder( + new flatbuffers::FlatBufferBuilder()); + auto packed_model = tflite::Model::Pack(*builder, copied_model.get()); + tflite::FinishModelBuffer(*builder, packed_model); + flatbuffers::FlatBufferBuilder* builder_ = builder.release(); + return BuildFromBuffer( + reinterpret_cast(builder_->GetBufferPointer()), + builder_->GetSize(), error_reporter); + } + + void ByteSwapTFLiteModelT(tflite::ModelT* tfl_modelt, bool from_big_endian) { + size_t bytes_per_elem = 0; + std::vector buffer_swapped(tfl_modelt->buffers.size(), false); + for (size_t subgraph_idx = 0; subgraph_idx < tfl_modelt->subgraphs.size(); + subgraph_idx++) { + tflite::SubGraphT* subgraph = + tfl_modelt->subgraphs.at(subgraph_idx).get(); + 
for (size_t ts_idx = 0; ts_idx < subgraph->tensors.size(); ts_idx++) { + tflite::TensorT* tensor = subgraph->tensors[ts_idx].get(); + if (tensor->buffer > 0 && tensor->buffer < tfl_modelt->buffers.size() && + !buffer_swapped[tensor->buffer]) { + const auto* buffer = + &(tfl_modelt->buffers[tensor->buffer].get()->data); + if (buffer && buffer->data()) { + uint8_t* buff_ = const_cast(buffer->data()); + ByteSwapBuffer(tensor->type, buffer->size(), buff_, + from_big_endian); + buffer_swapped[tensor->buffer] = true; + } + } + } + } + } + +#endif + + /// Builds a model directly from an allocation. + /// Ownership of the allocation is passed to the model, but the caller + /// retains ownership of `error_reporter` and must ensure its lifetime is + /// longer than the FlatBufferModelBase instance. + /// Returns a nullptr in case of failure (e.g., the allocation is invalid). + static std::unique_ptr BuildFromAllocation( + std::unique_ptr allocation, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + std::unique_ptr model( + new T(std::move(allocation), ValidateErrorReporter(error_reporter))); + if (!model->initialized()) { + model.reset(); + } else { + model->ValidateModelBuffers(error_reporter); + } + return model; + } + + /// Verifies whether the content of the allocation is legit, then builds a + /// model based on the provided allocation. + /// The extra_verifier argument is an additional optional verifier for the + /// buffer. By default, we always check with tflite::VerifyModelBuffer. If + /// extra_verifier is supplied, the buffer is checked against the + /// extra_verifier after the check against tflite::VerifyModelBuilder. + /// Ownership of the allocation is passed to the model, but the caller + /// retains ownership of `error_reporter` and must ensure its lifetime is + /// longer than the FlatBufferModelBase instance. + /// Returns a nullptr in case of failure. + static std::unique_ptr VerifyAndBuildFromAllocation( + std::unique_ptr allocation, + TfLiteVerifier* extra_verifier = nullptr, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + if (!allocation || !allocation->valid()) { + TF_LITE_REPORT_ERROR(error_reporter, + "The model allocation is null/empty"); + return nullptr; + } + + { + // Flatbuffers can only be smaller than 2GB. The file format appends some + // data after the actual flabuffer. We truncate the allocation size to 2GB + // so that the verifier doesn't early exit on us. + size_t allocation_size = + std::min(allocation->bytes(), + static_cast(FLATBUFFERS_MAX_BUFFER_SIZE - 1)); + flatbuffers::Verifier base_verifier( + reinterpret_cast(allocation->base()), allocation_size, + flatbuffers::Verifier::Options()); + if (!VerifyModelBuffer(base_verifier)) { + TF_LITE_REPORT_ERROR(error_reporter, + "The model is not a valid Flatbuffer buffer"); + return nullptr; + } + + if (extra_verifier && + !extra_verifier->Verify(static_cast(allocation->base()), + allocation_size, error_reporter)) { + // The verifier will have already logged an appropriate error message. + return nullptr; + } + } + + return BuildFromAllocation(std::move(allocation), error_reporter); + } + + /// Builds a model directly from a flatbuffer pointer + /// Caller retains ownership of the buffer and should keep it alive until the + /// returned object is destroyed. Caller retains ownership of `error_reporter` + /// and must ensure its lifetime is longer than the FlatBufferModelBase + /// instance. 
Returns a nullptr in case of failure. + static std::unique_ptr BuildFromModel( + const tflite::Model* caller_owned_model_spec, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) { + error_reporter = ValidateErrorReporter(error_reporter); + + if (CheckBufferOutsideModel(caller_owned_model_spec)) { + TF_LITE_REPORT_ERROR(error_reporter, + "The model contains weights not accessible from " + "tflite::Model *, please use other api"); + return nullptr; + } + + std::unique_ptr model(new T(caller_owned_model_spec, error_reporter)); + if (!model->initialized()) { + model.reset(); + } else { + model->ValidateModelBuffers(error_reporter); + } + return model; + } + + // Releases memory or unmaps mmaped memory. + ~FlatBufferModelBase() = default; + + // Copying or assignment is disallowed to simplify ownership semantics. + FlatBufferModelBase(const FlatBufferModelBase&) = delete; + FlatBufferModelBase& operator=(const FlatBufferModelBase&) = delete; + + bool initialized() const { return model_ != nullptr; } + const tflite::Model* operator->() const { return model_; } + const tflite::Model* GetModel() const { return model_; } + ErrorReporter* error_reporter() const { return error_reporter_; } + const Allocation* allocation() const { return allocation_.get(); } + + // Returns the minimum runtime version from the flatbuffer. This runtime + // version encodes the minimum required interpreter version to run the + // flatbuffer model. If the minimum version can't be determined, an empty + // string will be returned. + // Note that the returned minimum version is a lower-bound but not a strict + // lower-bound; ops in the graph may not have an associated runtime version, + // in which case the actual required runtime might be greater than the + // reported minimum. + std::string GetMinimumRuntime() const { + if (!model_ || !model_->metadata()) return ""; + + for (int i = 0; i < model_->metadata()->size(); ++i) { + auto metadata = model_->metadata()->Get(i); + if (metadata->name()->str() == tflite_metadata_min_runtime_version) { + auto buf = metadata->buffer(); + auto* buffer = (*model_->buffers())[buf]; + auto* array = buffer->data(); + // Get the real length of the runtime string, since there might be + // trailing + // '\0's in the buffer. + for (int len = 0; len < array->size(); ++len) { + if (array->data()[len] == '\0') { + return std::string(reinterpret_cast(array->data()), + len); + } + } + // If there is no '\0' in the buffer, this indicates that the flatbuffer + // is malformed. + TF_LITE_REPORT_ERROR( + error_reporter_, + "Min_runtime_version in model metadata is malformed"); + break; + } + } + return ""; + } + + // Return model metadata as a mapping of name & buffer strings. + // See Metadata table in TFLite schema. + std::map ReadAllMetadata() const { + return ReadAllMetadata(model_); + } + + // // Return model metadata as a mapping of name & buffer strings. + // // See Metadata table in TFLite schema. 
+ static std::map ReadAllMetadata( + const ::tflite::Model* model) { + std::map keys_values; + if (!model || !model->metadata() || !model->buffers()) return keys_values; + + for (size_t i = 0; i < model->metadata()->size(); ++i) { + auto metadata = model->metadata()->Get(i); + auto buf = metadata->buffer(); + if (buf >= model->buffers()->size()) continue; + const tflite::Buffer* buffer = (*model->buffers())[buf]; + if (!buffer || !buffer->data()) continue; + const flatbuffers::Vector* array = buffer->data(); + if (!array) continue; + std::string val = std::string( + reinterpret_cast(array->data()), array->size()); + // Skip if key or value of metadata is empty. + if (!metadata->name() || val.empty()) continue; + keys_values[metadata->name()->str()] = val; + } + return keys_values; + } + + // Validates if the FlatBufferModelBase's buffer is well-formed. Specifically, + // it checks if the 0th entry of the model buffers is an empty buffer + // (sentinel). This is a convention so that tensors without a buffer can + // provide 0 as their buffer. NOTE: The function doesn't explicitly fail for + // backward compatibility reasons; it just provides a warning in case of + // failures. + void ValidateModelBuffers(ErrorReporter* error_reporter) { + auto buffers = model_->buffers(); + if (buffers && buffers->size() > 0) { + auto first_buffer = buffers->Get(0); + if (first_buffer && first_buffer->size() != 0) { + // Note the 0th entry of this array must be an empty buffer (sentinel). + // This is a convention so that tensors without a buffer can provide 0 + // as their buffer. + TF_LITE_REPORT_ERROR( + error_reporter, + "The 0th entry of the model buffer must be an empty buffer."); + } + } + } + + /// Returns true if the model identifier is correct (otherwise false and + /// reports an error). + bool CheckModelIdentifier() const { + if (allocation_->bytes() < 7) { + TF_LITE_REPORT_ERROR( + error_reporter_, + "Model provided must have at least 7 bytes to hold identifier.\n"); + return false; + } + if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { + const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); + // Suppress unused variable warning. + (void)ident; + TF_LITE_REPORT_ERROR( + error_reporter_, + "Model provided has model identifier '%c%c%c%c', should be '%s'\n", + ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); + return false; + } + return true; + } + + /// Check If the buffer is stored as part of the Flatbuffer or outside + /// Return false if the buffers are part of the Flatbuffer + static bool CheckBufferOutsideModel(const tflite::Model* model) { + if (!model || !model->metadata()) return false; + + for (int i = 0; i < model->metadata()->size(); ++i) { + auto metadata = model->metadata()->Get(i); + if (metadata->name()->str() == tflite_metadata_buffer_location) { + return true; + } + } + return false; + } + + protected: + /// Loads a model from a given allocation. FlatBufferModelBase will take over + /// the ownership of `allocation`, and delete it in destructor. The ownership + /// of `error_reporter`remains with the caller and must have lifetime at least + /// as much as FlatBufferModelBase. This is to allow multiple models to use + /// the same ErrorReporter instance. 
+ explicit FlatBufferModelBase( + std::unique_ptr allocation, + ErrorReporter* error_reporter = T::GetDefaultErrorReporter()) + : error_reporter_(ValidateErrorReporter(error_reporter)), + allocation_(std::move(allocation)) { + if (!allocation_ || !allocation_->valid() || !CheckModelIdentifier()) { + return; + } + + model_ = ::tflite::GetModel(allocation_->base()); + } + + /// Loads a model from Model flatbuffer. The `model` has to remain alive and + /// unchanged until the end of this flatbuffer model's lifetime. + FlatBufferModelBase(const Model* model, ErrorReporter* error_reporter) + : model_(model), error_reporter_(ValidateErrorReporter(error_reporter)) {} + + static ErrorReporter* ValidateErrorReporter(ErrorReporter* error_reporter) { + return error_reporter ? error_reporter : T::GetDefaultErrorReporter(); + } + + /// Flatbuffer traverser pointer. (Model* is a pointer that is within the + /// allocated memory of the data allocated by allocation's internals. + const tflite::Model* model_ = nullptr; + /// The error reporter to use for model errors and subsequent errors when + /// the interpreter is created + ErrorReporter* error_reporter_; + /// The allocator used for holding memory of the model. Note that this will + /// be null if the client provides a tflite::Model directly. + std::unique_ptr allocation_; +}; + +} // namespace impl + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_MODEL_BUILDER_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/debug/debug.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/debug/debug.h new file mode 100644 index 00000000..3d7f6fe5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/debug/debug.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_ + +#include "llvm/Support/raw_ostream.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" + +namespace tensorflow { + +// Initializes the pass manager with default options that make debugging easier. +// The `out` method parameter is exposed for testing purposes and not intended +// to be specified by client code. 
+void InitPassManager(mlir::PassManager& pm, + const converter::DebugOptions& options, + llvm::raw_ostream& out = llvm::outs()); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops.h new file mode 100644 index 00000000..fc75816d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_ + +#include +#include + +namespace tflite { +namespace flex { + +// Whether the given op has been statically allowlisted for flex export. +// +// This static allowlist is formed by the intersection of ops supported by +// TensorFlowMobile on both iOS and Android. As the converter is likely running +// on a host that has the full suite of TensorFlow ops available, we use this +// static allowlist to ensure compatibility when deploying to a mobile device. +// TODO(b/118389105): Automate generation of the allowlisted flex ops. +bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name); + +// Return the list of allowlisted flex ops. +const std::set& GetFlexAllowlist(); + +// Return the list of TF.Text flex ops. +const std::set& GetTFTextFlexAllowlist(); + +// Return the list of SentencePiece flex ops. +const std::set& GetSentencePieceFlexAllowlist(); + +} // namespace flex +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops_internal.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops_internal.h new file mode 100644 index 00000000..420516c0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops_internal.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_INTERNAL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_INTERNAL_H_ + +#include + +namespace tflite { +namespace flex { + +// Return true if op_name is a tf.text op need to be supported by flex delegate. +bool IsAllowedTFTextOpForFlex(const std::string& op_name); + +} // namespace flex +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h new file mode 100644 index 00000000..358392ed --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.h @@ -0,0 +1,133 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_COMMON_OUTLINE_OPERATIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_COMMON_OUTLINE_OPERATIONS_H_ + +#include +#include +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_os_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" + +namespace mlir { +namespace TFL { +namespace common { + +// Returns true if the `op` is a constant-like op or produces none type. +bool IsConstantOrNone(Operation* op); + +// Computes the list of Value(s) referenced by Subgraph Operations that are +// not defined within the Subgraph. Any such Value(s) +// are validly in-scope for the initial Operation. They must be either +// defined above the subgraph or appear as an argument to the containing func. +// These Value(s) are taken to be the arguments of the new raised func. 
+// An operand dependency is a Value referenced anywhere in an Op +// that is defined above the Op. All SSA Values are assigned/defined in a +// BlockArg or as a result of an Operation. +llvm::SmallVector AccumulateOperandsDefinedAbove( + const llvm::SetVector& partition_ops); + +// Similar to `AccumulateOperandsDefinedAbove()`, computes the Value(s) that are +// defined within a Subgraph and referenced in a descendant Operation. These +// Values(s) are to be returned by the new raised function. +llvm::SmallVector AccumulateResultsDefinedWithin( + const llvm::SetVector& partition_ops); + +// Represents a view of a set of mlir Operations that form a subgraph of the +// entire Module's DAG. `Subgraph` can be thought of as segment of sequential +// Operations within a func definition. Additional facts: +// 1. Subgraphs are restricted to a single Block. They do not span +// branching instructions. Thus the subgraph is a simple 1-degree path. +// 2. All Operations in a subgraph belong to the same block in a +// funtion body. +// 3. Function bodies are assumed to have only one block in some places. +class Subgraph { + // Set vector preserves insertion order, must insert Ops in topological order. + public: + const llvm::SetVector partition_ops_; + + // Subgraphs are given a unique incremented integer id based on when + // they were encountered in this pass. + const int subgraph_id_; + + const llvm::StringRef dialect_namespace_; + + Subgraph(const llvm::SetVector partition_ops, int num_subgraphs) + : partition_ops_(partition_ops), + subgraph_id_(num_subgraphs), + func_arguments_(AccumulateOperandsDefinedAbove(partition_ops)), + func_outputs_(AccumulateResultsDefinedWithin(partition_ops)) {} + + const llvm::SmallVector& FuncArguments() const { + // `Value`s in MLIR library are implemented as having "value semantics" + // see "llvm/llvm-project/mlir/include/mlir/IR/Value.h" so copying is fine. + return func_arguments_; + } + const llvm::SmallVector& FuncOutputs() const { return func_outputs_; } + + private: + // Compute once at construction and save as field. + const llvm::SmallVector func_arguments_; + const llvm::SmallVector func_outputs_; +}; + +// Helper data structure for output parameters to `ExtractSubgraphToFunc`. +// `ExtractSubgraphToFunc` adds exactly two "new" `Operations`, a FuncOp and +// a CallOp. Pass these back to the caller for setting more specific attributes +// after graph mutation has taken place. +struct OpsAdded { + mlir::func::FuncOp func_op; + mlir::func::CallOp call_op; +}; + +// Given a `Subgraph` containing a sequence of adjacent `Operations` from +// the `module`, raise these `Operations` (and any ops contained nested within) +// to the body of a new seperate root level function. Replace in their current +// location with a `CallOp` which invokes said `FuncOp`. The inputs to +// this new functions are taken to be the `Values` that appear as operands +// to ops in the subgraph, which are not self-contained within the subgraph. +// The outputs of this function are taken to be the results of ops in the +// subgraph which are referenced as operands outside of the subgraph. +// Also refer to documention of `AccumulateOperandsDefinedAbove` & +// `AccumulateResultsDefinedWithin`. 
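As a usage illustration for the declarations above and below (not part of this header), a pass might drive the outlining roughly as in the sketch that follows; the attribute name applied to the raised function is hypothetical.

// Sketch only: outline `partition_ops` into a new FuncOp and replace them
// with a CallOp, then tag the raised function.
void OutlinePartition(const llvm::SetVector<mlir::Operation*>& partition_ops,
                      mlir::ModuleOp module, mlir::OpBuilder& builder,
                      int next_subgraph_id) {
  using namespace mlir::TFL::common;
  Subgraph subgraph(partition_ops, next_subgraph_id);
  OpsAdded ops_added;
  ExtractSubgraphToFunc(subgraph, builder, module, ops_added);
  // "outlined" is a hypothetical attribute name used for illustration.
  ops_added.func_op->setAttr("outlined", builder.getUnitAttr());
}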
+void ExtractSubgraphToFunc(const Subgraph& subgraph, OpBuilder& builder, + ModuleOp& module, OpsAdded& ops_added); + +} // namespace common +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_COMMON_OUTLINE_OPERATIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h new file mode 100644 index 00000000..6036c468 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h @@ -0,0 +1,62 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Functions for serializiation/deserialization of control dependency +/// information to/from model metadata. +/// + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/utils/control_edges.h" + +namespace tflite { + +/// Control dependencies for the model is the collection of control dependencies +/// for its subgraphs. +using ModelControlDependencies = std::vector; + +/// Serializes `in` into the returned string. The result is parseable with +/// ParseModelControlDependencies. +std::string SerializeModelControlDependencies( + const ModelControlDependencies& in); + +/// Deserializes `*out` from a character buffer of size `size` at `data`. +/// Returns true iff successful. `*out` needn't be empty before invocation. +/// When returning false, `*out`'s state is undefined. +bool ParseModelControlDependencies(const char* data, size_t size, + ModelControlDependencies* out); + +/// The key under which to store the serialized control dependencies in the +/// model's metadata. +constexpr char kModelControlDependenciesMetadataKey[] = + "model_control_dependencies"; + +/// To allow future changes to the format, serialized control dependency data +/// will contain a version; this constant is the version that will be used for +/// serialization. For deserialization, past versions should remain parseable. +constexpr uint32_t kModelControlDependenciesMetadataVersion = 1; + +inline constexpr char kModelUseStablehloTensorKey[] = "keep_stablehlo_constant"; + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_METADATA_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/remat/rematerializer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/remat/rematerializer.h new file mode 100644 index 00000000..02f66046 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/remat/rematerializer.h @@ -0,0 +1,262 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_REMATERIALIZER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_REMATERIALIZER_H_ + +// This file declares the Rematerializer class, which is used by an MLIR-based +// set of transformations for TFLite IR that lower memory usage by redoing +// operations with small inputs and large outputs instead of keeping the result +// in memory. This class allows us to compactly and efficiently represent the +// (idealized) memory profile of a TFLite graph and simulate the effect of +// re-inserting operations on that memory profile. + +#include +#include +#include +#include +#include + +namespace mlir { +namespace TFL { + +// A class that +// (1) Encodes in concise form the memory requirements of a computational graph +// (2) Allows for the fast simulation of changes to the peak memory requirement +// under rematerialization of intermediate results in the graph +// (3) Implements a greedy algorithm for finding rematerializations of +// intermediate results in that graph to lower peak memory requirements. +class Rematerializer { + public: + Rematerializer() = default; + virtual ~Rematerializer() = default; + + // The type used for memory sizes (in bytes) and differences thereof. + using SizeT = int64_t; + + // The memory profile: The i-th element gives the amount of memory + // that is needed when performing the i-th operation. This is the + // sum of the sizes of + // + // (1) input tensors of that operation, + // (2) output tensors of that operation, + // (3) output tensors of preceding operations that are input tensors + // of subsequent operations. + using MemProfile = std::vector; + + // Used for specifying memory consumption at a certain operation in the + // computational graph. + struct MemSpec { + int op_index; // The index of the operation + SizeT size; // The amount of memory needed in order to execute this + // operation, i.e., the sum of input and output sizes and the + // sizes of outputs of previous operations that are needed as + // inputs of subsequent operations. + explicit MemSpec(int op_index = 0, SizeT size = 0) + : op_index(op_index), size(size) {} + }; + + static bool BySize(const MemSpec& a, const MemSpec& b) { + return std::tie(a.size, a.op_index) < std::tie(b.size, b.op_index); + } + + static bool ByOpIndex(const MemSpec& a, const MemSpec& b) { + return std::tie(a.op_index, a.size) < std::tie(b.op_index, b.size); + } + + // Specifies an elementary rematerialization operation: The operations in + // operations [`begin`, `end`) will be rescheduled before operation `insert`. + // A valid `RematSpec` requires begin <= end <= insert <= number of + // operations. Note that (1) `end` is exclusive -- begin == end signifies a + // trivial RematSpec (no operation will be rescheduled), (2) the + // zero-initialized RematSpec {} is trivial and always valid. 
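To make the intended workflow concrete, here is a minimal sketch, not part of this header, of a derived class that populates the low-level representation with the protected helpers declared further down (AddTensor, AddOperation, AddUse) and then runs the greedy search; sizes and thresholds are arbitrary illustration values.

// Sketch only.
class ToyRematerializer : public mlir::TFL::Rematerializer {
 public:
  ToyRematerializer() {
    // One 1 MiB tensor produced by op 0 and consumed by op 1.
    const int t = AddTensor(/*size=*/1 << 20);
    const int producer = AddOperation(/*is_stateful=*/false);
    const int consumer = AddOperation(/*is_stateful=*/false);
    AddUse(producer, t);
    AddUse(consumer, t);
  }

  void ApplyRemat(const RematSpec& remat) override {
    // A real subclass would clone the high-level ops in
    // [remat.begin, remat.end) before position remat.insert here.
  }
};

// Querying the simulated profile and searching for savings:
//   ToyRematerializer r;
//   Rematerializer::MemSpec peak = r.GetPeakMemory();
//   r.RunGreedyAlgorithm(/*max_cost=*/-1, /*max_block_length=*/1,
//                        /*min_savings=*/1);  // invokes ApplyRemat per find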
+ struct RematSpec { + int begin; + int end; + int insert; + }; + + // Gives the peak memory location and size after inserting operations + // according to `remat` (but doesn't actually insert them.) Ties are broken + // towards later locations. `remat` must be valid (see above). + MemSpec GetPeakMemory(const RematSpec& remat = {}) const; + + // Gives memory profile after inserting operations according to `remat` (but + // doesn't actually insert them). `remat` must be valid (see above). + MemProfile GetMemProfile(const RematSpec& remat = {}) const; + + // Runs the greedy incremental block algorithm: Finds a sequence of + // rematerializations of block size up to max_block_length, each reducing peak + // memory by at least min_savings. If max_cost >= 0, at most max_cost + // operations will be re-inserted. For each rematerialization found, + // ApplyRemat is invoked (which can be used to apply the rematerialization to + // the higher- level representation, e.g., MLIR, flatbuffer, ...) + void RunGreedyAlgorithm(int max_cost, int max_block_length, + SizeT min_savings); + + virtual void ApplyRemat(const RematSpec& remat) {} + + protected: + // Rematerializes the outputs of the operations [`remat.begin`, `remat.end`) + // before operation remat.insert by copying that operation range before + // remat.insert and updating tensor references so that any operation that can + // will make use of the rematerialized outputs rather than the original ones. + // `remat` must be valid (see above). + void Remat(const RematSpec& remat); + + // The protected methods below are to be used by derived classes to create the + // low-level (this class) representation from a high-level one. + + // Creates a new tensor-like object that takes `size` bytes. Returns a + // contiguous increasing index for each new object, starting at 0. + int AddTensor(SizeT size); + + // Creates an operation. If `is_stateful`, the operation (and any block of + // operations containing it) will never be considered for rematerialization. + // Returns a contiguous increasing index for each new object, starting at 0. + int AddOperation(bool is_stateful); + + // The operator with index `ioperation` will be assumed to produce and/or + // consume the tensor with index `itensor`. NoOp if that's already the case. + // The arguments must be valid indices (i.e., obtained with + // `AddOperation`/`AddTensor`). + void AddUse(int ioperation, int itensor); + + // Undoes an AddUse(ioperation, itensor). NoOp if there was no prior `AddUse`. + // The arguments must be valid indices (i.e., obtained with + // `AddOperation`/`AddTensor`). + void DelUse(int ioperation, int itensor); + + private: + // Find the best remat operation that saves at least `min_savings` bytes for a + // block of operators with a length is between [`begin_len`, `end_len`). + // 'Best' means with the highest savings, ties are broken towards shorter + // blocks. + std::tuple FindBestRemat(SizeT min_savings, int begin_len, + int end_len) const; + + // Optimization: Estimate (from above) the remat savings of instruction block + // [begin, end) after operation `peak_location` + SizeT MaxSavings(int begin, int end, int peak_loc) const; + + // If I want to remat ops [begin, end) after the op at operation `peak_loc`, + // find the latest point at which to reinsert them (the op before which to + // insert.) + int FindBestRematPoint(int begin, int end, int peak_loc) const; + + // The memory objects. + struct Tensor { + SizeT size; // The size of the object (in bytes.) 
+    std::vector<int> operations;  // The operations it is used in. This vector
+                                  // is kept sorted + unique.
+
+    // The operation that makes the first use of this tensor.
+    int first_use() const { return *operations.begin(); }
+
+    // The operation that makes the last use of this tensor.
+    int last_use() const { return *operations.rbegin(); }
+  };
+
+  // The operators.
+  struct Operation {
+    bool is_stateful = false;  // Results of an Operation can be rematerialized
+                               // only if `!is_stateful`. This probably should
+                               // be replaced with a more fine-grained
+                               // approach--for example, the results of a "read
+                               // resource variable" operation can be
+                               // rematerialized as long as this doesn't happen
+                               // after the corresponding "write resource
+                               // variable" operation.
+
+    std::vector<int> tensors;  // The tensors that are used (input or output) by
+                               // this operation. They needn't correspond to
+                               // tensors in the TF graph -- we may add fake
+                               // tensors to model memory consumed in addition
+                               // to input and output tensors. This vector is
+                               // kept sorted + unique.
+
+    SizeT alloc = 0;    // The number of bytes that need to be allocated before
+                        // this operation.
+    SizeT dealloc = 0;  // The number of bytes that can be deallocated after
+                        // this operation.
+  };
+
+  // Given the current state of `operations_` and `tensors_`, return a vector of
+  // corrections that transform the current memory profile into the one that we
+  // would get after applying `remat`.
+  //
+  // The memory profile of a sequence of operations is the partial sum of the
+  // sizes of the allocations that are necessary before an operation and the
+  // negative sizes of the deallocations that are possible after the previous
+  // operation.
+  //
+  // If we modify the operation sequence by cloning an operation range, that
+  // memory profile will change--cloning makes it necessary to extend the
+  // lifetime of some tensors, while other tensors can be deallocated early and
+  // rematerialized later.
+  //
+  // This method represents these changes in compact form: It returns a vector
+  // of (position of operation, delta) pairs in lexicographic order; one
+  // obtains the memory profile after `remat` by adding the deltas from any
+  // entries (i, delta) to the i-th entry of the partial sum.
+  //
+  // This allows us to efficiently compute the change to the peak of a memory
+  // profile due to cloning an operation range without having to actually clone
+  // that range and without having to build a profile vector.
+  //
+  // The returned vector has at most 2 entries for each tensor referenced in
+  // [remat.begin, remat.end). There may be multiple entries for a single
+  // operation position; operation positions refer to the sequence *after*
+  // cloning [`remat.begin`, `remat.end`) before `remat.insert`.
+  std::vector<MemSpec> GetDeltas(const RematSpec& remat) const;
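To illustrate the profile definition used in the GetDeltas() comment above, here is a tiny standalone sketch; the sizes are invented and the helper is not part of the header:

#include <cstdint>
#include <vector>

// Illustration only: the memory profile is the partial sum of per-operation
// allocations, with the previous operation's deallocations subtracted before
// the next step -- the same bookkeeping MapMem() below performs incrementally.
std::vector<int64_t> ProfileFromAllocDealloc(const std::vector<int64_t>& alloc,
                                             const std::vector<int64_t>& dealloc) {
  std::vector<int64_t> profile(alloc.size());
  int64_t running = 0;
  for (size_t i = 0; i < alloc.size(); ++i) {
    running += alloc[i];    // memory that must be live to run op i
    profile[i] = running;
    running -= dealloc[i];  // memory that can be freed once op i is done
  }
  return profile;
}
// E.g. alloc = {4, 3, 2} and dealloc = {1, 3, 2} give the profile {4, 6, 5},
// so the peak is 6 bytes at operation 1.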
+
+  // Helper template: Iterates through all `MemSpec`s (i.e., operation
+  // index/memory usage pairs) for the current graph in operation order and
+  // calls `mapper` on them. This is an optimization -- by instantiating with
+  // an appropriate `Mapper`, it allows us to e.g. compute the peak memory
+  // without having to instantiate an actual memory profile vector.
+  template <class Mapper>
+  void MapMem(const Mapper& mapper, const RematSpec& remat) const {
+    const auto deltas = GetDeltas(remat);
+    const auto len = (remat.end - remat.begin);
+    auto idelta = deltas.begin();
+
+    for (MemSpec m; m.op_index < operations_.size() + len; ++m.op_index) {
+      // Are we in the cloned portion of the new operation sequence?
+      // Then all alloc/dealloc information must come from deltas.
+      const bool patch =
+          (m.op_index >= remat.insert) && (m.op_index < remat.insert + len);
+      // Are we past the insertion portion of the new operation sequence?
+      // Then we need to convert indices back to the original sequence.
+      const int shift = (m.op_index >= remat.insert + len) ? len : 0;
+      m.size += patch ? 0 : operations_[m.op_index - shift].alloc;
+      // deltas is sorted by location; apply any corrections to the current
+      // operator.
+      for (; idelta != deltas.end() && idelta->op_index == m.op_index;
+           ++idelta) {
+        m.size += idelta->size;
+      }
+      mapper(m);
+      m.size -= patch ? 0 : operations_[m.op_index - shift].dealloc;
+    }
+  }
+
+  std::vector<Operation> operations_;
+  std::vector<Tensor> tensors_;
+};
+
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_REMAT_REMATERIALIZER_H_
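Before moving on to the TAC headers, a hedged usage sketch for the Rematerializer above. Since AddTensor/AddOperation/AddUse are protected, a toy subclass is the natural way to drive the API; the class name, graph shape, sizes, and include path (taken from the header guard) are all invented for illustration.

#include <cstdio>

#include "tensorflow/compiler/mlir/lite/experimental/remat/rematerializer.h"

namespace {

// Illustration only: exposes the protected graph-building API.
class ToyRematerializer : public mlir::TFL::Rematerializer {
 public:
  ToyRematerializer() {
    // Two operations sharing one 1 MiB intermediate tensor.
    const int t0 = AddTensor(/*size=*/1 << 20);
    const int producer = AddOperation(/*is_stateful=*/false);
    const int consumer = AddOperation(/*is_stateful=*/false);
    AddUse(producer, t0);
    AddUse(consumer, t0);
  }

  // A real subclass would update its higher-level representation here
  // (MLIR, flatbuffer, ...), as the header's comment suggests.
  void ApplyRemat(const RematSpec& remat) override {}
};

}  // namespace

int main() {
  ToyRematerializer remat;
  const auto peak = remat.GetPeakMemory();
  std::printf("peak of %lld bytes at operation %d\n",
              static_cast<long long>(peak.size), peak.op_index);
  // Look for rematerializations of blocks of up to 2 ops, each saving at
  // least 1 KiB, re-inserting at most 10 operations in total.
  remat.RunGreedyAlgorithm(/*max_cost=*/10, /*max_block_length=*/2,
                           /*min_savings=*/1 << 10);
  return 0;
}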
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h
new file mode 100644
index 00000000..2d79e8d3
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/cost.h
@@ -0,0 +1,50 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_COST_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_COST_H_
+
+#include
+
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFL {
+namespace tac {
+
+// Cost attribute string on the TFL dialect.
+constexpr char kCost[] = "tac.cost";
+
+inline void UpdateCost(Operation* op, float cost, OpBuilder* builder) {
+  op->setAttr(kCost, builder->getF32FloatAttr(cost));
+}
+
+// Get the cost annotated with kCost.
+inline bool GetCostOnOp(Operation* op, float* cost) {
+  auto cost_type = op->getAttrOfType<FloatAttr>(kCost);
+  if (cost_type == nullptr) {
+    return false;
+  }
+
+  *cost = cost_type.getValueAsDouble();
+  return true;
+}
+
+}  // namespace tac
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_COST_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h
new file mode 100644
index 00000000..ed61f74c
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/subgraph.h
@@ -0,0 +1,52 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_SUBGRAPH_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_SUBGRAPH_H_
+
+#include <optional>
+#include <string>
+
+#include "llvm/ADT/StringRef.h"
+#include "mlir/IR/Operation.h"  // from @llvm-project
+
+namespace mlir {
+namespace TFL {
+namespace tac {
+
+// The interface name is the "hook" between the CallOp and the FuncOps.
+// Take the following example:
+//
+//   call @func_1_CPU {tac.interface_name = "func_1"}
+//
+// "func_1" is the interface name, while "func_1_cpu" is the real
+// implementation. We can have multiple FuncOps like "func_1_cpu" and
+// "func_1_gpu", and they both implement "func_1".
+//
+// The attribute on the FuncOp states what it actually implements, while the
+// attribute on the CallOp states what it is looking for.
+constexpr char kInterfaceNameAttr[] = "tac.interface_name";
+
+inline std::optional<std::string> GetInterFaceName(Operation* op) {
+  auto name_attr = op->getAttrOfType<StringAttr>(kInterfaceNameAttr);
+  if (!name_attr) return std::nullopt;
+  return name_attr.getValue().str();
+}
+
+}  // namespace tac
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_SUBGRAPH_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h
new file mode 100644
index 00000000..2f299287
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h
@@ -0,0 +1,150 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_ + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace tac { + +// Device attribute string on the TFL dialect. +constexpr char kDevice[] = "tac.device"; + +// Inference type. +constexpr char kInferenceType[] = "tac.inference_type"; + +// Inference type. +constexpr char kSkipTargetAnnotation[] = "tac.skip_target_annotation"; + +// TODO(renjieliu): Add more inference types. +enum InferenceType { + UNKNOWN = 0, + FLOAT = 1, + QUANTIZED_INT8 = 2, + QUANTIZED_UINT8 = 3, + HYBRID = 4 +}; + +inline InferenceType GetInferenceTypeEnum(llvm::StringRef inference_type_str) { + if (inference_type_str == "FLOAT") { + return FLOAT; + } else if (inference_type_str == "QUANTIZED_INT8") { + return QUANTIZED_INT8; + } else if (inference_type_str == "QUANTIZED_UINT8") { + return QUANTIZED_UINT8; + } else if (inference_type_str == "HYBRID") { + return HYBRID; + } else { + return UNKNOWN; + } +} + +inline std::string GetInferenceString(InferenceType inference_type) { + if (inference_type == FLOAT) { + return "FLOAT"; + } else if (inference_type == QUANTIZED_INT8) { + return "QUANTIZED_INT8"; + } else if (inference_type == QUANTIZED_UINT8) { + return "QUANTIZED_UINT8"; + } else if (inference_type == HYBRID) { + return "HYBRID"; + } else { + return "UNKNOWN"; + } +} + +// Returns canonical representation for hardware name (All uppercase). +// TODO(b/177376459): Remove this in favor of the string defined by hardwares +// MyHardware::kId. +inline std::string GetCanonicalHardwareName(const std::string& hardware_name) { + std::string name = hardware_name; + std::transform( + name.begin(), name.end(), name.begin(), + [](unsigned char c) -> unsigned char { return std::toupper(c); }); + return name; +} + +// Get the target annotation form the op. +inline std::optional GetTargetAnnotation(Operation* op) { + auto device = op->getAttrOfType(kDevice); + if (device == nullptr || device.getValue().empty()) return std::nullopt; + + return GetCanonicalHardwareName(device.getValue().str()); +} + +// Get inference type attribute from the operation if available. +inline std::optional GetInferenceTypeAnnotation(Operation* op) { + auto inference_type = op->getAttrOfType(kInferenceType); + if (inference_type == nullptr) return std::nullopt; + + llvm::StringRef device_name_str = inference_type.getValue(); + return GetInferenceTypeEnum(device_name_str); +} + +// InferenceDeviceType is a combination of the hardware with inference type. 
+struct InferenceDeviceType { + std::string hardware; + InferenceType inference_type; + + bool operator==(const InferenceDeviceType& other) const { + return (hardware == other.hardware) && + (inference_type == other.inference_type); + } + + bool operator!=(const InferenceDeviceType& other) const { + return !(*this == other); + } + + struct inference_device_type_hash { + size_t operator()(const InferenceDeviceType& p) const { + auto hash1 = std::hash{}(p.hardware); + auto hash2 = std::hash{}(p.inference_type); + return hash1 ^ hash2; + } + }; +}; + +// Get InferenceDeviceType attribute from the operation if available. +inline std::optional GetInferenceDeviceTypeForOp( + Operation* op) { + auto hardware = GetTargetAnnotation(op); + if (!hardware.has_value()) return std::nullopt; + + auto inference_type = GetInferenceTypeAnnotation(op); + if (!inference_type.has_value()) return std::nullopt; + + InferenceDeviceType inference_device_type; + inference_device_type.hardware = hardware.value(); + inference_device_type.inference_type = inference_type.value(); + return inference_device_type; +} + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_TARGETS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h new file mode 100644 index 00000000..741ee5d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h @@ -0,0 +1,94 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_UTILS_H_ + +#include "llvm/Support/Casting.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/CastInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" + +namespace mlir { +namespace TFL { +namespace tac { + +// Returns true if 'op' is non const op. Returns false otherwise or if +// 'op' is null. 
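Stepping back to the targets.h helpers above for a moment (utils.h continues below): an illustrative sketch of how they compose, assuming GetInferenceDeviceTypeForOp returns std::optional<InferenceDeviceType> as its body implies. The GroupByDeviceType helper and its inputs are invented for the example.

#include <unordered_map>
#include <vector>

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"

namespace {

using mlir::TFL::tac::GetInferenceDeviceTypeForOp;
using mlir::TFL::tac::InferenceDeviceType;

// Buckets ops that carry both tac.device and tac.inference_type annotations
// by their (hardware, inference type) pair, using the hash functor the
// struct provides.
std::unordered_map<InferenceDeviceType, std::vector<mlir::Operation*>,
                   InferenceDeviceType::inference_device_type_hash>
GroupByDeviceType(const std::vector<mlir::Operation*>& ops) {
  std::unordered_map<InferenceDeviceType, std::vector<mlir::Operation*>,
                     InferenceDeviceType::inference_device_type_hash>
      grouped;
  for (mlir::Operation* op : ops) {
    if (auto device_type = GetInferenceDeviceTypeForOp(op)) {
      grouped[*device_type].push_back(op);
    }
  }
  return grouped;
}

}  // namespace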
+inline bool IsNonConstOp(Operation* op) { + if (!op) return false; + if (llvm::isa(op)) return false; + if (op->hasTrait()) return false; + if (llvm::isa(op)) return false; + return true; +} + +// Returns true if 'op' is a terminator op, otherwise false. +bool IsTerminatorOp(Operation* op); + +// Returns true if 'op' is not TFL Quant / Dequant op. Returns False otherwise +// or if 'op' is null. +bool NotTFLQuantDequantizeOp(Operation* op); + +// Returns true if it is a shaped type of f32 elements. +inline bool IsF32ShapedType(Type t) { + if (auto shaped_type = mlir::dyn_cast_or_null(t)) { + return shaped_type.getElementType().isF32(); + } + return false; +} + +// Return true when the given element_type is QI8. +inline bool IsQI8Type(Type t) { + auto quantized_type = quant::QuantizedType::getQuantizedElementType(t); + return quantized_type != nullptr && + quantized_type.getStorageTypeIntegralWidth() == 8 && + quantized_type.isSigned(); +} + +// Return true when the given element_type is QUI8. +inline bool IsQUI8Type(Type t) { + auto quantized_type = quant::QuantizedType::getQuantizedElementType(t); + return quantized_type != nullptr && + quantized_type.getStorageTypeIntegralWidth() == 8 && + !quantized_type.isSigned(); +} + +// Return true when the given element_type is QI32. +inline bool IsQI32Type(Type t) { + auto quantized_type = quant::QuantizedType::getQuantizedElementType(t); + return quantized_type != nullptr && + quantized_type.getStorageTypeIntegralWidth() == 32 && + quantized_type.isSigned(); +} + +// Try to guess the inference type of the op. +InferenceType GetInferenceType(Operation* op); + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_COMMON_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.h new file mode 100644 index 00000000..84264907 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/examples/example_hardware.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_EXAMPLES_EXAMPLE_HARDWARE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_EXAMPLES_EXAMPLE_HARDWARE_H_ + +#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h" + +namespace mlir { +namespace TFL { +namespace tac { + +class ExampleHardware : public SimpleHardware { + public: + static constexpr char kId[] = "ExampleHardware"; + + mlir::RewritePatternSet GetTransformations( + MLIRContext* context) const override; + + mlir::TypeID GetTypeId() const override { + return mlir::TypeID::get(); + } + + bool IsNotSupportedOp(mlir::Operation* op) const override { return false; } + + float AdvantageOverCPU() const override { return 5.0; } +}; + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_EXAMPLES_EXAMPLE_HARDWARE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h new file mode 100644 index 00000000..4a5f5f11 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.h @@ -0,0 +1,31 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_EXECUTION_METADATA_EXPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_EXECUTION_METADATA_EXPORTER_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace tflite { + +// Returns serialized string for the generated flatbuffer. +std::optional ExportRuntimeMetadata(mlir::ModuleOp module); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_EXECUTION_METADATA_EXPORTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h new file mode 100644 index 00000000..149c2076 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_GPU_HARDWARE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_GPU_HARDWARE_H_ + +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { +namespace tac { +// Gpu hardware class which handles GPU capabilities in TFLite. +// This is used by TAC to get op supported/ op cost estimates on GPU. +class GpuHardware : public TargetHardware { + public: + static constexpr char kId[] = "GPU"; + mlir::RewritePatternSet GetTransformations( + MLIRContext* context) const override; + + mlir::TypeID GetTypeId() const override { + return mlir::TypeID::get(); + } + + double GetHardwareSwitchingCost(const TargetHardware* from, + size_t buffer_size) const override; +}; +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_GPU_HARDWARE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.h new file mode 100644 index 00000000..51c1c117 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/nnapi_hardware.h @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +/* NNAPI Hardware Implementation */ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_NNAPI_HARDWARE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_NNAPI_HARDWARE_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h" + +namespace mlir { +namespace TFL { +namespace tac { + +class NNAPIHardware : public SimpleHardware { + public: + static constexpr char kId[] = "NNAPI"; + + mlir::RewritePatternSet GetTransformations( + MLIRContext* context) const override; + + mlir::TypeID GetTypeId() const override { + return mlir::TypeID::get(); + } + + bool IsNotSupportedOp(mlir::Operation* op) const override { return false; } + + float AdvantageOverCPU() const override { return 5.0; } +}; + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_NNAPI_HARDWARE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h new file mode 100644 index 00000000..ca371544 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/simple_hardware.h @@ -0,0 +1,67 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_SIMPLE_HARDWARE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_SIMPLE_HARDWARE_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" + +namespace mlir { +namespace TFL { +namespace tac { + +// A simple hardware is an interface makes you add a target backend easily if +// you don't want too much customization. +// +// It allows you to easily specify the ops capabilities (by +// specifying the denylist), the rest ops will be considered supported. Also you +// can also specify the advantage over CPU. +// +// If you need more customization, e.g., if you have your own hardware dialect, +// please consider use TargetHardware directly. +class SimpleHardware : public TargetHardware { + public: + // This is essentially a denylist. + // TODO(renjieliu): Consider whether we want an allowlist for custom op as + // well. + virtual bool IsNotSupportedOp(mlir::Operation* op) const = 0; + + // The larger the value is, the more preferrable over CPU. + // If the value > 1, means the hardware has advantage over CPU. + // If the value < 1, means CPU is more preferred. 
+ // If we specify 10.0, meaning the hardware is 10x faster than CPU. + // The value should be > 0. + // TODO(renjieliu): Consider add an interface for more detailed customization, + // for example, users should be able to specify some ops are preferred and + // some are less preferred. + virtual float AdvantageOverCPU() const = 0; + + private: + bool IsOpSupported(mlir::Operation* op) const override; + + double GetHardwareSwitchingCost(const TargetHardware* from, + size_t buffer_size) const override; + + double GetOpCost(mlir::Operation* op) const override; +}; + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_SIMPLE_HARDWARE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h new file mode 100644 index 00000000..136bb5ec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h @@ -0,0 +1,195 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_TARGET_HARDWARE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_TARGET_HARDWARE_H_ + +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace tac { + +// Default fixed values for ops. +constexpr static float kDefaultFixedValuedCost = 1000000.0; + +// This is just fake data. +constexpr static float kCrossHardwareTransferPerByteCost = 5.0f; + +// This is just fake data. +constexpr static float kCrossHardwareTransferFixedCost = 10.f; + +// Interface for an Operation capabilities which should be tied to +// a specific hardware. +// Users should implement the interface and use TargetHardwareOpRegistration +// for registering the operation. +class TargetHardwareOperation { + public: + virtual ~TargetHardwareOperation() = default; + + virtual double GetOpCost(mlir::Operation* op) const = 0; + + virtual bool IsOpSupported(mlir::Operation* op) const = 0; +}; + +// Abstract base class for a hardware. +// To introduce new hardware +// users should implement the interface and use TargetHardwareRegistration +// for registering the hardware. +// Subclasses must implement the pure virtual function interface and +// define static member variable that retrieves string identifying the Target +// Hardware. 
Example, +// class MyType : public TargetHardware { +// public: +// static constexpr char kId[] = "MyHardware"; +// }; +class TargetHardware { + public: + virtual ~TargetHardware() = default; + + // Initializes all TargetHardwareOperation registered for this hardware. + // Users overriding this function, should call the base class method to + // initialize the ops. + virtual bool Init(); + + // Returns the cost of running 'op' on this Hardware. + virtual double GetOpCost(mlir::Operation* op) const; + + // Returns the cost of running the whole function on this hardware. + // By default this is the sum of the cost of individual cost for each op. + virtual double GetFuncCost(func::FuncOp* func) const; + + // Returns true if 'op' can run on this Hardware. + virtual bool IsOpSupported(mlir::Operation* op) const; + + // Switching cost between from hardware and this hardware. + // If both the hardwares are the same, the transfer cost is basically 0. + virtual double GetHardwareSwitchingCost(const TargetHardware* from, + size_t buffer_size) const = 0; + + // Returns a list of all patterns to apply for this hardware. + virtual mlir::RewritePatternSet GetTransformations( + MLIRContext* context) const = 0; + + // Returns TypeId for the provided hardware. + // Usually should be something like mlir::TypeID::get() + virtual mlir::TypeID GetTypeId() const = 0; + + virtual void GetDependentDialects(mlir::DialectRegistry& registry) const {} + + protected: + // All registered hardware ops. + std::vector> hardware_ops_; +}; + +// Returns pointer to the Hardware identified by 'hardware_name'. +// If not found nullptr is returned. +// DEPRECATED: Do not use, prefer GetTargetHardwareFactory instead. +const TargetHardware* GetTargetHardware(const std::string& hardware_name); + +// Returns the factory method for the requested hardware if present. +std::function()> GetTargetHardwareFactory( + const std::string& hardware_name); + +namespace internal { + +void RegisterTargetHardwareFactory( + const std::string& unique_name, const std::string& description, + mlir::TypeID type_id, + std::function()> target_hardware_factory); + +// Registers the provided target hardware factory. +template +void RegisterTargetHardwareFactory( + const std::string& description, + std::function()> target_hardware_factory) { + RegisterTargetHardwareFactory(T::kId, description, mlir::TypeID::get(), + target_hardware_factory); +} + +// DEPRECATED: Do not use, prefer RegisterTargetHardwareOpFactory intstead. +void RegisterTargetHardwareOp( + mlir::TypeID hardware_type, mlir::TypeID op_type, + std::function()> + target_hardware_op_factory); + +void RegisterTargetHardwareOpFactory( + mlir::TypeID hardware_type, mlir::TypeID op_type, + std::function()> + target_hardware_op_factory); +} // namespace internal + +// Register target hardware. +template +struct TargetHardwareRegistration { + TargetHardwareRegistration(const std::string& description, + std::function()> + target_hardware_factory) { + internal::RegisterTargetHardwareFactory(description, + target_hardware_factory); + } +}; + +// Register Op capabilities for specific hardware. +template +struct TargetHardwareOpRegistration { + explicit TargetHardwareOpRegistration( + std::function()> + target_hardware_op_factory) { + // TODO(b/177376459): remove this. 
+ internal::RegisterTargetHardwareOp(mlir::TypeID::get(), + mlir::TypeID::get(), + target_hardware_op_factory); + internal::RegisterTargetHardwareOpFactory(mlir::TypeID::get(), + mlir::TypeID::get(), + target_hardware_op_factory); + } +}; + +//======== util functions ========== + +// Process user specified device specs, will always add CPU if it's not there. +// specified_device_specs: ',' separated, like "GPU,DSP,CPU". +// device_specs: processed device specs enum. +bool ProcessTargetDevices(llvm::ArrayRef specified_device_specs, + std::vector* device_specs); + +// Check whether two hardwares are the same. +inline bool IsSameHardware(const TargetHardware* lhs, + const TargetHardware* rhs) { + return lhs->GetTypeId() == rhs->GetTypeId(); +} + +// Returns the ID identifying 'hardware'. This should match the ID defined +// in the hardware field ID. +// For example, if MyHardware is passed the value returned should match +// MyHardware::kId. +std::string GetHardwareName(const TargetHardware* hardware); + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_HARDWARES_TARGET_HARDWARE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.h new file mode 100644 index 00000000..5776891f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/tac_wrapper.h @@ -0,0 +1,42 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_PY_WRAPPER_TAC_WRAPPER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_PY_WRAPPER_TAC_WRAPPER_H_ + +#include +#include +#include +#include + +// Place `` before to avoid build failures in macOS. +#include + +// The empty line above is on purpose as otherwise clang-format will +// automatically move before . +#include + +namespace tflite { + +// Run target-aware-conversion for the given tflite model with the given device +// specs. +// Warning: The API is experimental and subject to change. +bool run_tac(const std::string& model_file_path, + const std::vector& device_specs, + const std::string& model_output_path); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_PY_WRAPPER_TAC_WRAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h new file mode 100644 index 00000000..a40a3b94 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_IMPORTER_EXPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_IMPORTER_EXPORTER_H_ + +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace tac { + +// Interface for Importing program to TAC (Target Aware Conversion) Module. +// This class is an interface for importing program in TAC. +// See TacModule in how to register it with the module and use it. +class TacImporter { + public: + virtual ~TacImporter() = default; + + // Imports and returns the Module for the imported program. + virtual absl::StatusOr> Import() = 0; +}; + +// Interface for exporting a module. +// Users should implement the interface for exporting the result from TAC +// in their preferred way. +// See TacModule in how to register it with the module and use it. +class TacExporter { + public: + virtual ~TacExporter() = default; + + // Imports and returns the Module for the imported program. + virtual absl::Status Export(mlir::ModuleOp module) = 0; +}; +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_IMPORTER_EXPORTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h new file mode 100644 index 00000000..7733a9bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h @@ -0,0 +1,122 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ + +#include +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" +#include "tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h" + +namespace mlir { +namespace TFL { +namespace tac { + +// Main class for using Target Aware Conversion (TAC). +// To run TAC: +// 1) users should create object form this class, with desired options +// (TacModule::Options). +// 2) Use SetImporter/SetExporter to the desired importer +// and exporter. +// 3) Call Run() +// +// The module fetches all TargetHardware backends registered in the binary +// and only create TargetHardware requested in Options. +// +// This class is not thread safe. +class TacModule { + public: + // TAC options. Contains knobs to configure TAC as needed. + struct Options { + // List of names for the requested Target hardware. + std::vector hardware_backends; + // Debug mode. + // This will output different alternative subgraphs in mlir format for debug + // purpose. + bool debug_mode = false; + // Whether to enable inliner passes or not. + bool enable_inliner = false; + // Whether to legalize ops to TFLite ops before exporting. + bool legalize_to_tflite_ops = false; + }; + + virtual ~TacModule() = default; + + explicit TacModule(const Options& options) : options_(options) {} + + void SetImporter(std::unique_ptr importer) { + importer_ = std::move(importer); + } + + void SetExporter(std::unique_ptr exporter) { + exporter_ = std::move(exporter); + } + + // Returns pointer to the TargetHardware that is identified by 'hardware_name' + // Returns NULL If no hardware with this name found. + const tac::TargetHardware* GetTargetHardware( + const std::string& hardware_name) const; + + // Runs the TAC workflow, configured as in the options provided during + // construction. + // SetImporter/SetExporter should be called prior to invoking `Run`. + // Returns Status of the Run. + virtual absl::Status Run(); + + // Returns all available hardware backends registered in this module + // instance. + const std::vector& GetAvailableHardwares() const { + return const_backends_; + } + + // Registers all dialects in 'registry' with the module. + // This to allow clients to register extra dialects required. + void RegisterExtraDialects(mlir::DialectRegistry& registry); + + protected: + // Adds TAC passes to the 'pass_manager'. + virtual void AddTACPass(mlir::OpPassManager* pass_manager, + llvm::ArrayRef device_specs); + + private: + // Runs all TAC passes on the provided module. + absl::Status RunTacPasses(mlir::ModuleOp* module, bool debug_mode = false); + + // Create instances of all registered hardwares. + std::vector> InstantiateBackends(); + + std::unique_ptr importer_; + std::unique_ptr exporter_; + // Owned list of all target hardware backends. + std::vector> backends_; + // Holder for const pointers for the data in 'backends_' + std::vector const_backends_; + // Extra dialects requested by the user. 
+ mlir::DialectRegistry registry_; + + const Options options_; +}; + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TAC_MODULE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h new file mode 100644 index 00000000..ed59787f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/tflite_import_export.h @@ -0,0 +1,73 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TFLITE_IMPORT_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TFLITE_IMPORT_EXPORT_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "llvm/Support/SourceMgr.h" +#include "tensorflow/compiler/mlir/lite/experimental/tac/tac_importer_exporter.h" + +namespace mlir { +namespace TFL { +namespace tac { +// TAC Importer for TFLite. +// This import to MLIR from tflite file or MLIR +class TfLiteImporter : public mlir::TFL::tac::TacImporter { + public: + // Options for configuring the importer. + struct Options { + std::string file_name; + // Whether the input file is an MLIR not tflite file. + bool input_mlir = false; + }; + + explicit TfLiteImporter(const Options& options) : options_(options) {} + + absl::StatusOr> Import() override; + + private: + Options options_; + mlir::MLIRContext context_; + llvm::SourceMgr source_mgr_; + std::unique_ptr source_mgr_handler_; +}; + +// TAC Exporter. It exports the provided Module to a tflite file. +class TfLiteExporter : public mlir::TFL::tac::TacExporter { + public: + // Exporter configuration options. + struct Options { + bool export_runtime_metadata = false; + bool output_mlir = false; + std::string output_file_name; + std::vector target_hardware_backends; + }; + + explicit TfLiteExporter(const Options& options) : options_(options) {} + + absl::Status Export(mlir::ModuleOp module) override; + + private: + Options options_; +}; +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TFLITE_IMPORT_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h new file mode 100644 index 00000000..86445235 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.h @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_COST_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_COST_MODEL_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" + +namespace mlir { +namespace TFL { +namespace tac { + +// TODO(renjieliu): We need to come up with a better strategy to do cost +// estimatation. Maybe build a big lookup table for all the ops. + +// TODO(renjieliu): We need to consider what's the default value if we cannot +// analyze the cost. + +// ================== Interface ======================== + +// Get the estimated cost for the op under the given hardware spec senario. +float GetCostForOp(Operation* op, const std::string& hardware); + +// Get the estimated cost for the whole function under the given hardware. +float GetCostForFunc(func::FuncOp* func, const std::string& hardware); + +// Get the transfer cost given from & to hardware info. +// We will only calculate for the "necessary" tensor transferred. +// from_graph & to_graph are used to compute the "necessary" tensors. +// from_graph +// / \ \ +// out1 out2 out3 +// \ / +// to_graph +// So only out2 & out3 are counted. +float GetTransferCost(const std::string& from_hardware_str, + const std::string& to_hardware_str, + func::CallOp from_graph, func::CallOp to_graph); + +// Get the cross quantization/dequantization boundary cost. +float GetQuantDequantCost(InferenceType from_inference_type, + InferenceType to_inference_type, + func::CallOp from_graph, func::CallOp to_graph); + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_COST_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h new file mode 100644 index 00000000..e6d77838 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.h @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { +namespace tac { + +// Returns true if 'op' is supported to run on 'hardware'. +bool IsSupported(Operation* op, const std::string& hardware); + +// Return proper rewriter patterns for different hardwares. +RewritePatternSet GetHardwareRewritePatterns(MLIRContext* context, + const std::string& hardware); + +// Convert quantized ops to float, this will essentially insert dequantize & +// quantize pair around the op. +void ConvertQuantizedOpToFloat(func::FuncOp func, OpBuilder* builder); + +// This will optimize the quantized ops -> float graph. +void OptimizeQuantizedOpToFloat(func::FuncOp func, MLIRContext* context); + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h new file mode 100644 index 00000000..9de0e3c2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_GPU_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_GPU_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace tac { + +// nit: Returns all the gpu suitable transformation patterns. 
+RewritePatternSet GetHardwareRewritePatternsGPU(MLIRContext* context);
+
+}  // namespace tac
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_GPU_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h
new file mode 100644
index 00000000..3866d576
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h
@@ -0,0 +1,115 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace TFL {
+namespace tac {
+
+// TODO(renjieliu): add more patterns.
+
+// This basically:
+// Pack => (Concat -> Reshape)
+struct LowerPackIntoConcatReshape : public OpRewritePattern<TFL::PackOp> {
+  using OpRewritePattern<TFL::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::PackOp pack_op,
+                                PatternRewriter& rewriter) const override;
+};
+
+struct SquaredDifference : public OpRewritePattern<TFL::SquaredDifferenceOp> {
+  using OpRewritePattern<TFL::SquaredDifferenceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::SquaredDifferenceOp squared_diff_op,
+                                PatternRewriter& rewriter) const override;
+};
+
+// Unroll split into a bunch of slice ops.
+struct UnrollSplit : public OpRewritePattern<TFL::SplitOp> {
+  using OpRewritePattern<TFL::SplitOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::SplitOp split_op,
+                                PatternRewriter& rewriter) const override;
+};
+
+// Unroll splitv into a bunch of slice ops.
+struct UnrollSplitV : public OpRewritePattern<TFL::SplitVOp> {
+  using OpRewritePattern<TFL::SplitVOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::SplitVOp splitv_op,
+                                PatternRewriter& rewriter) const override;
+};
+
+// Ensure bias for conv2d op.
+struct EnsureBiasForConv2d : public OpRewritePattern<TFL::Conv2DOp> {
+  using OpRewritePattern<TFL::Conv2DOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::Conv2DOp conv_op,
+                                PatternRewriter& rewriter) const override;
+};
+
+// Pad slice to 4d.
+struct PadSlice : public OpRewritePattern<TFL::SliceOp> {
+  using OpRewritePattern<TFL::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TFL::SliceOp slice_op,
+                                PatternRewriter& rewriter) const override;
+};
+
+// Fully connected to conv2d.
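(The remaining patterns, starting with the fully-connected-to-conv rewrite below, follow the same shape.) As a hedged sketch of how such patterns are typically gathered and applied: GetHardwareRewritePatternsGPU above presumably populates its set in a similar way, but the exact pattern mix per hardware is a guess, and the greedy-driver entry point name depends on the MLIR revision in use.

#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.h"

namespace {

// Illustration: build a pattern set from a few of the structs declared above
// and run them on a function.
void ApplyGpuLikePatterns(mlir::func::FuncOp func) {
  mlir::MLIRContext* context = func.getContext();
  mlir::RewritePatternSet patterns(context);
  patterns.add<mlir::TFL::tac::LowerPackIntoConcatReshape,
               mlir::TFL::tac::UnrollSplit, mlir::TFL::tac::UnrollSplitV,
               mlir::TFL::tac::PadSlice>(context);
  // Newer MLIR revisions rename this to applyPatternsGreedily.
  (void)mlir::applyPatternsAndFoldGreedily(func.getOperation(),
                                           std::move(patterns));
}

}  // namespace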
+struct FullyConnectedToConv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc_op, + PatternRewriter& rewriter) const override; +}; + +// Pad concat to 4d. +struct PadConcat : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::ConcatenationOp concat_op, + PatternRewriter& rewriter) const override; +}; + +// Convert reduce mean 4d to avg pool. +struct ReduceMeanToAvgPool : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::MeanOp mean_op, + PatternRewriter& rewriter) const override; +}; + +// Insert Requant ops for reduce_mean. +struct InsertRequantForReduceMean : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TFL::MeanOp mean_op, + PatternRewriter& rewriter) const override; +}; + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_DEVICE_TRANSFORM_PATTERNS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h new file mode 100644 index 00000000..a16b0f77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/passes.h @@ -0,0 +1,77 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_PASSES_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/tac_filter.pb.h" + +namespace mlir { +namespace TFL { +namespace tac { +class TacModule; + +// Create an instance of the TargetAnnotationPass. +// TODO(b/177376459): Remove in favor of the one below. +std::unique_ptr> CreateTargetAnnotationPass( + llvm::ArrayRef device_specs); + +// Create and instance of TargetAnnotationPass. +std::unique_ptr> CreateTargetAnnotationPass( + const TacModule* module); + +// Create an instance of the RaiseTargetSubgraphsPass. If `skip_raise_cpu_ops`, +// we skip clustering for CPU ops for better clustering of ops running on other +// ML accelerators. When `ignore_inference_type` is set to true, the inference +// types are set to "NOT_CARE" for better clustering. +std::unique_ptr> CreateRaiseTargetSubgraphsPass( + bool skip_raise_cpu_ops = false, bool ignore_inference_type = false); + +// Create an instance of the AlternativeSubgraphPass. 
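// Illustrative sketch (editorial, not part of the upstream TensorFlow header):
// one way the factory functions above might be wired into a pass pipeline.
// Assumes an MLIR build with "mlir/Pass/PassManager.h" included at file scope;
// whether a given pass nests on functions or runs on the whole module is an
// assumption here, and `BuildTacPipelineSketch` is a hypothetical name.
inline void BuildTacPipelineSketch(mlir::PassManager& pm,
                                   llvm::ArrayRef<std::string> device_specs) {
  // Annotate each op with candidate targets for the given device specs.
  pm.addNestedPass<mlir::func::FuncOp>(
      CreateTargetAnnotationPass(device_specs));
  // Then cluster annotated ops into per-target subgraphs.
  pm.addPass(CreateRaiseTargetSubgraphsPass());
}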
+std::unique_ptr> CreateAlternativeSubgraphPass( + llvm::ArrayRef device_specs); + +// Create an instance of ComputeCostPass. +std::unique_ptr> CreateComputeCostPass(); + +// Create an instance of PickSubgraphsPass. +std::unique_ptr> CreatePickSubgraphsPass(); + +// Create an instance of DeviceTransformGPUPass. +std::unique_ptr> CreateDeviceTransformGPUPass(); + +// Create an instance of GetOpCostPass. +std::unique_ptr> CreateGetOpCostPass(); + +// Create an instance of FoldConstantsToSubgraphPass. +std::unique_ptr> CreateFoldConstantsToSubgraphPass( + bool fold_all_constants); + +// Create an instance of TacFilterPass. +std::unique_ptr> CreateTacFilterPass( + ::third_party::tensorflow::compiler::mlir::lite::experimental::tac:: + TacFilters* tac_filters); + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h new file mode 100644 index 00000000..6e61dbe9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_pass.h @@ -0,0 +1,90 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_TAC_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_TAC_PASS_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" +#include "tensorflow/compiler/mlir/lite/experimental/tac/tac_module.h" + +namespace mlir { +namespace TFL { +namespace tac { +// An OperationPass<> with access to the TAC module instance that the +// pass is running part of. +// See OperationPass<> comments for all details/restrictions of OperationPass. +// +// When adding new Pass to TAC, users should use this class as the base class +// as it provides access to the TAC module. +template +class TacPass : public OperationPass { + public: + using OperationPass::OperationPass; + explicit TacPass(const TacModule* module) + : OperationPass::OperationPass(mlir::TypeID::get()), + module_(module) {} + + ~TacPass() override = default; + + const TargetHardware* GetTargetHardware( + const std::string& hardware_name) const { + return module_ != nullptr + ? module_->GetTargetHardware(hardware_name) + : mlir::TFL::tac::GetTargetHardware(hardware_name); + } + + protected: + const TacModule* module_ = nullptr; // Not owned. +}; + +// A FunctionPass but with access to TAC module. +// See FunctionPass comments for all details/restrictions of FunctionPass. 
+// +// When adding new Pass to TAC, users should use this class as the base class +// as it provides access to the TAC module. +template +class TacFunctionPass : public TacPass { + public: + using TacPass::TacPass; + + ~TacFunctionPass() override = default; + + mlir::func::FuncOp getFunction() { return getOperation(); } + + virtual void runOnFunction() = 0; + + void runOnOperation() final { + if (!getFunction().isExternal()) runOnFunction(); + } + + protected: + // Returns the derived pass name. + StringRef getName() const override { return llvm::getTypeName(); } + + // A clone method to create a copy of this pass. + std::unique_ptr clonePass() const override { + return std::make_unique(*static_cast(this)); + } +}; + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_TRANSFORMS_TAC_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h new file mode 100644 index 00000000..049cc186 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.h @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_UTILS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_UTILS_UTILS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace tac { + +// Import the file as mlir module, the input maybe flatbuffer or mlir file. +absl::StatusOr> ImportFlatbufferOrMlir( + const std::string& input_filename, bool input_mlir, + bool experimental_prune_unreachable_nodes_unconditionally, + llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); + +// Export the module to file, can be either mlir or flatbuffer. +absl::Status ExportFlatbufferOrMlir( + const std::string& output_filename, bool output_mlir, mlir::ModuleOp module, + bool enable_select_tf_ops, + std::optional custom_option_alignment = std::nullopt); + +} // namespace tac +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_TAC_UTILS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_export.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_export.h new file mode 100644 index 00000000..27cf2852 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_export.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ + +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" + +namespace tflite { +// Options for exporting to Flatbuffer. +struct FlatbufferExportOptions { + // ConverterFlags proto. The following fields are migrated. + // bool emit_builtin_tflite_ops -> !converter_flags.force_select_tf_ops() + // bool emit_select_tf_ops -> converter_flags.enable_select_tf_ops() + // bool emit_custom_ops -> converter_flags.allow_custom_ops() + // bool allow_all_select_tf_ops -> converter_flags.allow_all_select_tf_ops() + // std::set<> select_user_tf_ops -> converter_flags.select_user_tf_ops() + tflite::ConverterFlags converter_flags; + // When exporting from SavedModel, this will have the requested tags. + std::unordered_set saved_model_tags; + // Metadata key/value pairs to write to the flatbuffer. + std::map metadata; + // OpOrArgNameMapper to convert location of the op to name in flatbuffer. + // If not set, a default mapper will be used. + tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper = nullptr; + // User-specified value of flatbuffer alignment requirement for custom + // options. If specified, the value should be multiplier of 16 (default + // alignment for TFL flatbuffer). + std::optional custom_option_alignment = std::nullopt; +}; + +// Translates the given MLIR `module` into a FlatBuffer and stores the +// serialized flatbuffer into the string. +// Returns true on successful exporting, false otherwise. +bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, + const FlatbufferExportOptions& options, + std::string* serialized_flatbuffer, + bool serialize_stablehlo_ops = false); +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h new file mode 100644 index 00000000..ba97ede7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ + +#include + +// These flags are used to control the emission or not of different kinds of ops +// during the flatbuffer translation. +extern bool emit_builtin_tflite_ops; +extern bool emit_select_tf_ops; +extern bool emit_custom_ops; +// The flag to control whether to lower tensorlist ops into TF ops. +extern bool lower_tensor_list_ops; +// The flag to control whether debug info gets stripped on export. +extern bool strip_debug_info; +// The flag to control whether to store constant & custom buffers inside +// flatbuffer +extern bool use_buffer_offset; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_EXPORT_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_import.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_import.h new file mode 100644 index 00000000..f0f70114 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_import.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace tflite { +// Converts a TFLite flatbuffer stored in `buffer` to a MLIR module +// The buffer must live for the duration of the function call, +// The caller receives ownership of the module. +// `base_loc` is used for error reporting and debug info. +// If ordered_output_arrays is not empty, then the imported mlir function will +// only return nodes in ordered_output_arrays in the same order. Returns nullptr +// on failure, and more specific errors will be emitted via the context. +// If `use_external_constant` is true, it will create `tfl.external_const` +// instead of `tfl.const`. +// If `experimental_prune_unreachable_nodes_unconditionally` is true, nodes that +// are not ancestors of the output nodes will be pruned. 
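// Editorial usage sketch (not part of the upstream TensorFlow header) for the
// importer declared below; file reading and error handling are simplified and
// the names (`path`, locals) are hypothetical:
//
//   std::ifstream file(path, std::ios::binary);
//   std::stringstream contents;
//   contents << file.rdbuf();
//   const std::string buffer = contents.str();
//   mlir::MLIRContext context;
//   mlir::OwningOpRef<mlir::ModuleOp> module = tflite::FlatBufferToMlir(
//       buffer, &context,
//       mlir::FileLineColLoc::get(&context, path, /*line=*/0, /*column=*/0));
//   if (!module) { /* import failed; diagnostics were emitted on the context */ }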
+mlir::OwningOpRef FlatBufferToMlir( + absl::string_view buffer, mlir::MLIRContext* context, + mlir::Location base_loc, bool use_external_constant = false, + const std::vector& ordered_input_arrays = {}, + const std::vector& ordered_output_arrays = {}, + bool experimental_prune_unreachable_nodes_unconditionally = false, + bool disable_vhlo_to_stablehlo = false); +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_operator.h new file mode 100644 index 00000000..f0afe15f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -0,0 +1,309 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_ + +#include + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumeBundleQueries.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloTypes.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace mlir { + +// duplicated from +// https://github.com/openxla/stablehlo/blob/e5ad51715a11721c78b6748ab5de7945df24b1b8/stablehlo/transforms/StablehloLegalizeToVhlo.cpp#L756 +// so we can create correct vhlo types +class StablehloVhloTypeConverter : public mlir::vhlo::VhloTypeConverter { + public: + StablehloVhloTypeConverter() : mlir::vhlo::VhloTypeConverter() { + addConversion([](mlir::Type type) -> mlir::Type { + if (type.getDialect().getNamespace() == + mlir::vhlo::VhloDialect::getDialectNamespace()) { + return type; + } + return {}; + }); + addConversion([](mlir::stablehlo::TokenType token) -> mlir::Type { + return mlir::vhlo::TokenV1Type::get(token.getContext()); + }); + addBuiltinToVhloConversions(); + } + + mlir::Attribute convertEncoding(mlir::Attribute attr) const final { + // Must be VHLO encoding, or convertible to VHLO encoding. 
+ if (attr.getDialect().getNamespace() == + mlir::vhlo::VhloDialect::getDialectNamespace()) + return attr; + + if (auto stablehloAttr = + mlir::dyn_cast_or_null(attr)) { + return mlir::vhlo::TypeExtensionsV1Attr::get(stablehloAttr.getContext(), + stablehloAttr.getBounds()); + } + + // Was not VHLO encoding, or convertible. + return {}; + } +}; + +// from +// https://github.com/openxla/stablehlo/blob/e5ad51715a11721c78b6748ab5de7945df24b1b8/stablehlo/transforms/VhloLegalizeToStablehlo.cpp#L45C70-L45C70 +class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { + public: + VhloToStablehloTypeConverter() : vhlo::VhloTypeConverter() { + addConversion([](Type type) -> Type { return type; }); + addConversion([](vhlo::TokenV1Type token) -> Type { + return stablehlo::TokenType::get(token.getContext()); + }); + addVhloToBuiltinConversions(); + } + + Attribute convertEncoding(Attribute attr) const final { + if (auto vhloAttr = + mlir::dyn_cast_or_null(attr)) { + return stablehlo::TypeExtensionsAttr::get(vhloAttr.getContext(), + vhloAttr.getBounds()); + } + // All encodings supported in StableHLO. + return attr; + } +}; + +// Returns true if the op_code belongs to a stablehlo operation. +bool IsStablehloOp(const tflite::OperatorCodeT &op_code); + +// Returns the MLIR op name for the flatbuffer operator corresponding to +// `op_code`. +std::string GetMlirOpNameFromOpCode(const ::tflite::OperatorCodeT &op_code); + +// Returns the builtin op code for the given MLIR operation on success; emits +// error and returns std::nullopt on failure. +std::optional GetBuiltinOpCode(Operation *mlir_op); + +// Packs the given MLIR operation into a TFLite FlatBuffer operator object. +// Returns the FlatBuffer offset for the operator on success; emits error and +// returns std::nullopt on failure. +std::optional> CreateFlatBufferOperator( + Operation *mlir_op, uint32_t opcode_index, + const std::vector &operands, const std::vector &results, + const std::vector &intermediates, + flatbuffers::FlatBufferBuilder *fbb, + std::optional debug_metadata_index = -1); + +// Populates the array of mlir::NamedAttributes corresponding to the given +// tflite::FlatbufferOptionsUnion. +// We use an out parameter per LLVM convention +void BuiltinOptionsToAttributes( + tflite::BuiltinOptionsUnion op_union, mlir::Builder builder, + // NOLINTNEXTLINE + llvm::SmallVectorImpl &attributes); + +// While the last several tensors could be optional tensors for an tfl op, the +// number of input operands could vary. This function gets the min/max number of +// operands from tflite op name. +llvm::MinMax OperandNumbersMinMax(llvm::StringRef op_name); + +// Populates the `custom_code` and `custom_options` to attributes. +// `custom_code` is used to identify CustomOp. +// `custom_options` are opaque attribute used to store infomations for this +// custom op. +absl::Status CustomOptionsToAttributes( + const std::string &custom_code, const std::vector &custom_options, + mlir::Builder builder, + // NOLINTNEXTLINE + Location loc, llvm::SmallVectorImpl *attributes); + +// TODO(zichuanwei@): Populate Builtin_options_2 manual for now, should automate +// these in the future +void BuiltinOptions2ToAttributes( + tflite::BuiltinOptions2Union op_union, mlir::Builder builder, + llvm::SmallVectorImpl &attributes); + +// Function calls with a non-specialized type will result to a linker error. 
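// Editorial, self-contained analogue (not part of the upstream TensorFlow
// header) of the note above: because only explicit specializations are
// defined, a call with an unspecialized type compiles but fails to link.
// `ToVectorDemo` is a hypothetical name; it relies on <vector> and <cstdint>,
// which are assumed to be available in this header.
template <typename T>
std::vector<T> ToVectorDemo(const int32_t* data, int n);  // declared, never defined

template <>
inline std::vector<int32_t> ToVectorDemo<int32_t>(const int32_t* data, int n) {
  return std::vector<int32_t>(data, data + n);
}
// ToVectorDemo<int32_t>(...) links; ToVectorDemo<float>(...) would produce an
// "undefined reference" at link time, exactly like the GetVector family below.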
+template +inline std::vector GetVector(DenseElementsAttr elements); + +// TODO(zichuanwei@): for each type, we need to make sure the element type +// matches the expected type otherwise an error should be thrown, but for now +// we're just returning empty vector +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isSignlessInteger(1)) { + auto vec = llvm::to_vector( + llvm::map_range(elements.getValues(), + [&](bool value) -> uint8_t { return value ? 1 : 0; })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isSignlessInteger(8)) { + auto vec = llvm::to_vector(llvm::map_range( + elements.getValues(), + [&](APInt value) -> int8_t { return value.getSExtValue(); })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isSignlessInteger(16)) { + auto vec = llvm::to_vector(llvm::map_range( + elements.getValues(), + [&](APInt value) -> int16_t { return value.getSExtValue(); })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isSignlessInteger(32)) { + auto vec = llvm::to_vector(llvm::map_range( + elements.getValues(), + [&](APInt value) -> int32_t { return value.getSExtValue(); })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isSignlessInteger(64)) { + auto vec = llvm::to_vector(llvm::map_range( + elements.getValues(), + [&](APInt value) -> int64_t { return value.getSExtValue(); })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isSignlessInteger(64)) { + auto vec = llvm::to_vector(llvm::map_range( + elements.getValues(), + [&](APInt value) -> uint64_t { return value.getSExtValue(); })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isF32()) { + auto vec = llvm::to_vector(llvm::map_range( + elements.getValues(), + [&](APFloat value) -> float { return value.convertToFloat(); })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +template <> +inline std::vector GetVector(DenseElementsAttr elements) { + auto type = elements.getType(); + auto elemType = type.getElementType(); + if (elemType.isF64()) { + auto vec = llvm::to_vector(llvm::map_range( + elements.getValues(), + [&](APFloat value) -> double { return value.convertToFloat(); })); + return std::vector(vec.begin(), vec.end()); + } + + return std::vector(); +} + +// Handles the case when the DenseElementsAttr doesn't exist, and when it +// 
doesn't returns a vector of length `default_size` all with the same value +// `default_value`. +template +static inline std::vector GetOptionalVector( + std::optional elements, int64_t default_size = 0, + int64_t default_value = 0) { + if (elements.has_value()) { + return GetVector(elements.value()); + } + return std::vector(default_size, default_value); +} + +// Handles the case when the ArrayRef doesn't exist, and when it +// doesn't returns a vector of length `default_size` all with the same value +// `default_value`. +template +static inline std::vector GetOptionalVector( + std::optional> values, int64_t default_size = 0, + int64_t default_value = 0) { + if (values.has_value()) { + return std::vector(values->begin(), values->end()); + } + return std::vector(default_size, default_value); +} + +template +static inline std::vector GetVector( + vhlo::TensorV1Attr elements, + mlir::vhlo::VhloTypeConverter &vhlo_type_converter) { + return GetOptionalVector(mlir::DenseIntElementsAttr::getFromRawBuffer( + mlir::cast( + vhlo_type_converter.convertType(elements.getType())), + elements.getData())); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_translate.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_translate.h new file mode 100644 index 00000000..f344fc28 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_translate.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" + +namespace tflite { + +// Translates the given MLIR `module` into a FlatBuffer and stores the +// serialized flatbuffer into the string. This uses OpOrArgLocNameMapper to +// convert location of the op to name in flatbuffer. Returns true if translation +// fails, otherwise returns false. +bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, + std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, + bool emit_select_tf_ops, + bool emit_custom_ops); + +// Same as the above but with a custom op name mapper. 
+bool MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + tensorflow::OpOrArgNameMapper* op_or_arg_name_mapper); +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h new file mode 100644 index 00000000..6c8f80d4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/flatbuffer_translate_flags.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ + +#include + +// These flags are used to control the emission or not of different kinds of ops +// during the flatbuffer translation. +extern bool emit_builtin_tflite_ops; +extern bool emit_select_tf_ops; +extern bool emit_custom_ops; +// The flag to control whether to lower tensorlist ops into TF ops. +extern bool lower_tensor_list_ops; +// The flag to control whether debug info gets stripped on export. +extern bool strip_debug_info; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/ir/tfl_ops.h new file mode 100644 index 00000000..5946ce0f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the MLIR TensorFlow Lite dialect. 
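// Editorial note (not part of the upstream TensorFlow header): before ops from
// this dialect can be built or parsed, the dialect has to be registered with
// the MLIRContext, roughly as follows (`context` is a hypothetical local; the
// typedef TensorFlowLiteDialect appears further down in this header):
//
//   mlir::MLIRContext context;
//   context.loadDialect<mlir::TFL::TensorFlowLiteDialect>();
//   context.loadDialect<mlir::func::FuncDialect>();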
+ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeSupport.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_dialect.h.inc" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_enums.h.inc" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#define GET_ATTRDEF_CLASSES +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_attrdefs.h.inc" + +namespace mlir { +namespace TFL { + +typedef TFLDialect TensorFlowLiteDialect; + +// The Control type is a token-like value that models control dependencies +class ControlType : public Type::TypeBase { + public: + using Base::Base; + static constexpr StringLiteral name = "tfl.control"; +}; + +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" + +} // end namespace TFL +} // end namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/common.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/common.h new file mode 100644 index 00000000..fd9fdf81 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/common.h @@ -0,0 +1,236 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_COMMON_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_COMMON_H_ + +#include +#include +#include + +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif +#endif + +#include +#include + +#include "fixedpoint/fixedpoint.h" +#include "tensorflow/compiler/mlir/lite/core/macros.h" +#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" +#include "tensorflow/compiler/mlir/lite/kernels/internal/optimized/neon_check.h" + +// LINT.IfChange + +namespace tflite_migration { + +constexpr int kReverseShift = -1; + +TFLITE_NOINLINE int32_t MultiplyByQuantizedMultiplier( + int32_t x, int32_t quantized_multiplier, int shift); + +TFLITE_NOINLINE int32_t MultiplyByQuantizedMultiplier( + int64_t x, int32_t quantized_multiplier, int shift); + +// Single-rounding MultiplyByQuantizedMultiplier +#if TFLITE_SINGLE_ROUNDING +inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp( + int32_t x, int32_t quantized_multiplier, int shift) { + TFLITE_DCHECK_LE(shift, 0); + return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift); +} + +inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne( + int32_t x, int32_t quantized_multiplier, int shift) { + TFLITE_DCHECK_GE(shift, 0); + return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift); +} + +#ifdef USE_NEON +inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( + int32x4x4_t input_val, int32_t quantized_multiplier, int shift) { + TFLITE_DCHECK(quantized_multiplier >= 0); + + const int right_shift = std::min(-1, shift); + const int left_shift = shift - right_shift; + + const int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier); + const int32x4_t left_shift_dup = vdupq_n_s32(left_shift); + const int32x4_t right_shift_dup = vdupq_n_s32(right_shift); + + int32x4x4_t result; + result.val[0] = vrshlq_s32( + vqdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), multiplier_dup), + right_shift_dup); + + result.val[1] = vrshlq_s32( + vqdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), multiplier_dup), + right_shift_dup); + + result.val[2] = vrshlq_s32( + vqdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup), multiplier_dup), + right_shift_dup); + + result.val[3] = vrshlq_s32( + vqdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup), multiplier_dup), + right_shift_dup); + + return result; +} +#endif // USE_NEON +// Double-rounding MultiplyByQuantizedMultiplier +#else +inline int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp( + int32_t x, int32_t quantized_multiplier, int left_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift); +} + +inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne( + int32_t x, int32_t quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +#ifdef USE_NEON +// Round uses ARM's rounding shift right. 
+inline int32x4x4_t MultiplyByQuantizedMultiplier4Rows( + int32x4x4_t input_val, int32_t quantized_multiplier, int shift) { + const int left_shift = std::max(shift, 0); + const int right_shift = std::min(shift, 0); + int32x4x4_t result; + + int32x4_t multiplier_dup = vdupq_n_s32(quantized_multiplier); + int32x4_t left_shift_dup = vdupq_n_s32(left_shift); + int32x4_t right_shift_dup = vdupq_n_s32(right_shift); + + result.val[0] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[0], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[1] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[1], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[2] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[2], left_shift_dup), + multiplier_dup), + right_shift_dup); + + result.val[3] = + vrshlq_s32(vqrdmulhq_s32(vshlq_s32(input_val.val[3], left_shift_dup), + multiplier_dup), + right_shift_dup); + + return result; +} +#endif // USE_NEON +#endif // TFLITE_SINGLE_ROUNDING + +template +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned::value, + "Only unsigned integer types handled."); + if (integer_input == 0) { + return std::numeric_limits::digits; + } +#if defined(__GNUC__) + if (std::is_same::value) { + return __builtin_clz(integer_input); + } else if (std::is_same::value) { + return __builtin_clzll(integer_input); + } +#endif + const T one_in_leading_positive = static_cast(1) + << (std::numeric_limits::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift, + int32_t* output_inv_sqrt, + int* output_shift) { + TFLITE_DCHECK_GE(input, 0); + if (input <= 1) { + // Handle the input value 1 separately to avoid overflow in that case + // in the general computation below (b/143972021). Also handle 0 as if it + // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid + // but rare/unrealistic input value. We can expect both to occur in some + // incompletely trained models, but probably not in fully trained models. + *output_inv_sqrt = std::numeric_limits::max(); + *output_shift = 0; + return; + } + TFLITE_DCHECK_GT(input, 1); + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + const unsigned max_left_shift_bits = + CountLeadingZeros(static_cast(input)) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. 
+ using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } + // Convert right shift (right is positive) to left shift. + *output_shift *= reverse_shift; +} + +} // namespace tflite_migration + +// LINT.ThenChange(//tensorflow/lite/kernels/internal/common.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h new file mode 100644 index 00000000..3233fe00 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h @@ -0,0 +1,63 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_COMPATIBILITY_MACROS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_COMPATIBILITY_MACROS_H_ + +#ifndef TFLITE_ABORT +#define TFLITE_ABORT abort() +#endif + +#ifndef TFLITE_ASSERT_FALSE +#if defined(NDEBUG) +#define TFLITE_ASSERT_FALSE (static_cast(0)) +#else +#define TFLITE_ASSERT_FALSE TFLITE_ABORT +#endif +#endif + +// LINT.IfChange + +#ifndef TFLITE_DCHECK +#define TFLITE_DCHECK(condition) (condition) ? (void)0 : TFLITE_ASSERT_FALSE +#endif + +#ifndef TFLITE_DCHECK_EQ +#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : TFLITE_ASSERT_FALSE +#endif + +#ifndef TFLITE_DCHECK_NE +#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : TFLITE_ASSERT_FALSE +#endif + +#ifndef TFLITE_DCHECK_GE +#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : TFLITE_ASSERT_FALSE +#endif + +#ifndef TFLITE_DCHECK_GT +#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : TFLITE_ASSERT_FALSE +#endif + +#ifndef TFLITE_DCHECK_LE +#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : TFLITE_ASSERT_FALSE +#endif + +#ifndef TFLITE_DCHECK_LT +#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? 
(void)0 : TFLITE_ASSERT_FALSE +#endif + +// LINT.ThenChange(//tensorflow/lite/kernels/internal/compatibility.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_COMPATIBILITY_MACROS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/cppmath.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/cppmath.h new file mode 100644 index 00000000..49b66e10 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/cppmath.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_CPPMATH_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_CPPMATH_H_ + +#include + +// LINT.IfChange + +namespace tflite_migration { + +#if defined(TF_LITE_USE_GLOBAL_CMATH_FUNCTIONS) || \ + (defined(__ANDROID__) && !defined(__NDK_MAJOR__)) || defined(__ZEPHYR__) +#define TF_LITE_GLOBAL_STD_PREFIX +#else +#define TF_LITE_GLOBAL_STD_PREFIX std +#endif + +#define DECLARE_STD_GLOBAL_SWITCH1(tf_name, std_name) \ + template \ + inline T tf_name(const T x) { \ + return TF_LITE_GLOBAL_STD_PREFIX::std_name(x); \ + } + +DECLARE_STD_GLOBAL_SWITCH1(TfLiteRound, round) + +} // namespace tflite_migration + +// LINT.ThenChange(//tensorflow/lite/kernels/internal/cppmath.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_CPPMATH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/optimized/neon_check.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/optimized/neon_check.h new file mode 100644 index 00000000..ec3908d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/optimized/neon_check.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_CHECK_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_CHECK_H_ + +// LINT.IfChange + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#include // IWYU pragma: export +#endif + +#if defined __GNUC__ && defined __SSE4_1__ && !defined TF_LITE_DISABLE_X86_NEON +#define USE_NEON +#include "NEON_2_SSE.h" // IWYU pragma: export +#endif + +// LINT.ThenChange(//tensorflow/lite/kernels/internal/optimized/neon_check.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_CHECK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h new file mode 100644 index 00000000..b38391c3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h @@ -0,0 +1,166 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ + +#include +#include +#include + +namespace tflite_migration { + +// LINT.IfChange + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier > 1. +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Handles an arbitrary positive multiplier. The 'shift' output-value is +// basically the 'floating-point exponent' of the multiplier: +// Negative for a right-shift (when the multiplier is <1), positive for a +// left-shift (when the multiplier is >1) +void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, + int* shift); + +// Splits a double input value into a returned fraction, and a shift value from +// the exponent, using only bitwise and integer operations to support +// microcontrollers and other environments without floating-point support. +// +// This is designed to be a replacement for how std::frexp() is used within the +// QuantizeMultiplier() function, and so has a different signature than the +// standard version, returning a 64-bit integer rather than a double. This +// result has a maximum value of 1<<31, with the fraction expressed as a +// proportion of that maximum. 
+// +// std::frexp() returns NaNs and infinities unmodified, but since we're +// returning integers that can't represent those values, instead we return +// a shift of std::numeric_limits::max() for all bad numbers, with an int64 +// result of 0 for NaNs, std:numeric_limits::max() for +INFINITY, and +// std::numeric_limits::min() for -INFINITY. Denormalized inputs will +// result in return values that end up truncating some bits at the end, +// reflecting the loss of precision inherent in denormalization. +int64_t IntegerFrExp(double input, int* shift); + +// Converts an integer fraction in the format produced by IntegerFrExp (where +// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an +// IEEE binary64 double format result. The implementation uses only integer and +// bitwise operators, so no floating point hardware support or emulation is +// needed. This is here so quantized operations can run non-time-critical +// preparation calculations on microcontrollers and other platforms without +// float support. +double DoubleFromFractionAndShift(int64_t fraction, int shift); + +// Performs a multiplication of two numbers in double format, using only integer +// and bitwise instructions. This is aimed at supporting housekeeping functions +// for quantized operations on microcontrollers without floating-point hardware. +double IntegerDoubleMultiply(double a, double b); + +// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is +// greater than b. It is implemented using only integer and logical instructions +// so that it can be easily run on microcontrollers for quantized operations. +int IntegerDoubleCompare(double a, double b); + +// This first creates a multiplier in a double equivalent of +// Q(input_integer_bits).(31-input_integer_bits) representation, with extra +// precision in the double's fractional bits. It then splits the result into +// significand and exponent. +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift); +// Like PreprocessSoftmaxScaling, but inverse scaling factors also calculated. + +// Calculate the largest input that will result in a within-bounds intermediate +// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, +// it must not overflow before we reduce the value by multiplication by the +// input multiplier. The negative radius is used as the minimum difference in +// Softmax. +int CalculateInputRadius(int input_integer_bits, int input_left_shift, + int total_signed_bits = 31); + +// Converts a floating-point number to an integer. For all inputs x where +// static_cast(x) is legal according to the C++ standard, the result +// is identical to that cast (i.e. the result is x with its fractional part +// truncated whenever that is representable as IntOut). +// +// static_cast would cause undefined behavior for the following cases, which +// have well-defined behavior for this function: +// +// 1. If x is NaN, the result is zero. +// +// 2. If the truncated form of x is above the representable range of IntOut, +// the result is std::numeric_limits::max(). +// +// 3. If the truncated form of x is below the representable range of IntOut, +// the result is std::numeric_limits::min(). +// +// Note that cases #2 and #3 cover infinities as well as finite numbers. +// +// The range of FloatIn must include the range of IntOut, otherwise +// the results are undefined. +// TODO(sfeuz): Replace by absl::SafeCast once available. 
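// Editorial, standalone restatement (not part of the upstream TensorFlow
// header) of the contract documented above, specialized to int8_t so the
// branches are easy to see. `SafeCastToInt8Demo` is a hypothetical name; it
// relies only on <cmath>, <cstdint> and <limits>, assumed to be included above.
inline int8_t SafeCastToInt8Demo(double x) {
  if (std::isnan(x)) return 0;  // case 1: NaN maps to zero
  if (x >= static_cast<double>(std::numeric_limits<int8_t>::max()))
    return std::numeric_limits<int8_t>::max();  // case 2: e.g. 300.7 -> 127
  if (x <= static_cast<double>(std::numeric_limits<int8_t>::min()))
    return std::numeric_limits<int8_t>::min();  // case 3: e.g. -1e9 -> -128
  return static_cast<int8_t>(x);  // in range: truncate toward zero, 3.9 -> 3
}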
+template +IntOut SafeCast(FloatIn x) { + static_assert(!std::numeric_limits::is_integer, + "FloatIn is integer"); + static_assert(std::numeric_limits::is_integer, + "IntOut is not integer"); + static_assert(std::numeric_limits::radix == 2, "IntOut is base 2"); + + // Special case NaN, for which the logic below doesn't work. + if (std::isnan(x)) { + return 0; + } + + // Negative values all clip to zero for unsigned results. + if (!std::numeric_limits::is_signed && x < 0) { + return 0; + } + + // Handle infinities. + if (std::isinf(x)) { + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); + } + + // Set exp such that x == f * 2^exp for some f with |f| in [0.5, 1.0), + // unless x is zero in which case exp == 0. Note that this implies that the + // magnitude of x is strictly less than 2^exp. + int exp = 0; + std::frexp(x, &exp); + + // Let N be the number of non-sign bits in the representation of IntOut. If + // the magnitude of x is strictly less than 2^N, the truncated version of x + // is representable as IntOut. The only representable integer for which this + // is not the case is kMin for signed types (i.e. -2^N), but that is covered + // by the fall-through below. + if (exp <= std::numeric_limits::digits) { + return x; + } + + // Handle numbers with magnitude >= 2^N. + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); +} +// LINT.ThenChange(//tensorflow/lite/kernels/internal/quantization_util.h) +} // namespace tflite_migration + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h new file mode 100644 index 00000000..3a602ba9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h @@ -0,0 +1,263 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_ + +// This file is the MLIR copy of runtime_shape as part of the effort to +// decouple TFLite from MLIR. +// LINT.IfChange + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" + +namespace mlir { + +template +struct Dims { + int sizes[N]; + int strides[N]; +}; + +class RuntimeShape { + public: + // Shapes with dimensions up to 6 are stored directly in the structure, while + // larger shapes are separately allocated. 
+  static constexpr int kMaxSmallSize = 6;
+
+  RuntimeShape& operator=(RuntimeShape const&) = delete;
+
+  RuntimeShape() : size_(0) {}
+
+  explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
+    if (dimensions_count > kMaxSmallSize) {
+      dims_pointer_ = new int32_t[dimensions_count];
+    }
+  }
+
+  RuntimeShape(int shape_size, int32_t value) : size_(0) {
+    Resize(shape_size);
+    for (int i = 0; i < shape_size; ++i) {
+      SetDim(i, value);
+    }
+  }
+
+  RuntimeShape(int dimensions_count, const int32_t* dims_data) : size_(0) {
+    ReplaceWith(dimensions_count, dims_data);
+  }
+
+  RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
+    BuildFrom(init_list);
+  }
+
+  // Avoid using this constructor. We should be able to delete it when C++17
+  // rolls out.
+  RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
+    if (size_ > kMaxSmallSize) {
+      dims_pointer_ = new int32_t[size_];
+    }
+    std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * size_);
+  }
+
+  bool operator==(const RuntimeShape& comp) const {
+    return this->size_ == comp.size_ &&
+           std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32_t)) ==
+               0;
+  }
+
+  ~RuntimeShape();
+
+  inline int32_t DimensionsCount() const { return size_; }
+
+  int32_t Dims(int i) const;
+
+  inline void SetDim(int i, int32_t val) {
+    TFLITE_DCHECK_GE(i, 0);
+    TFLITE_DCHECK_LT(i, size_);
+    if (size_ > kMaxSmallSize) {
+      dims_pointer_[i] = val;
+    } else {
+      dims_[i] = val;
+    }
+  }
+
+  inline int32_t* DimsData() {
+    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+  }
+  inline const int32_t* DimsData() const {
+    return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+  }
+  // The caller must ensure that the shape is no bigger than 5-D.
+  inline const int32_t* DimsDataUpTo5D() const { return dims_; }
+
+  inline void Resize(int dimensions_count) {
+    const int32_t old_size = size_;
+    size_ = dimensions_count;
+
+    if (old_size <= kMaxSmallSize) {
+      if (dimensions_count <= kMaxSmallSize) {
+        return;
+      } else {  // Small to big.
+        int32_t* new_big_data = new int32_t[dimensions_count];
+        memcpy(new_big_data, dims_, sizeof(int32_t) * old_size);
+        dims_pointer_ = new_big_data;
+      }
+    } else {
+      if (dimensions_count > kMaxSmallSize && dimensions_count <= old_size) {
+        return;
+      }
+      std::unique_ptr<int32_t[]> old_data(dims_pointer_);
+      if (dimensions_count <= old_size) {  // Big to small.
+        memcpy(dims_, old_data.get(), sizeof(int32_t) * dimensions_count);
+      } else {  // Big to bigger.
+        dims_pointer_ = new int32_t[dimensions_count];
+        memcpy(dims_pointer_, old_data.get(), sizeof(int32_t) * old_size);
+      }
+    }
+  }
+
+  void ReplaceWith(int dimensions_count, const int32_t* dims_data);
+
+  template <typename T>
+  inline void BuildFrom(const T& src_iterable) {
+    const int dimensions_count =
+        std::distance(src_iterable.begin(), src_iterable.end());
+    Resize(dimensions_count);
+    int32_t* data = DimsData();
+    for (auto it : src_iterable) {
+      *data = it;
+      ++data;
+    }
+  }
+
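For orientation, a small hypothetical usage sketch of the API declared in this class (see the remaining members below); it is not part of the header and assumes the accompanying implementation file, which defines Dims(), FlatSize(), and the destructor, is linked in.

#include <cassert>

#include "tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h"

int main() {
  // A 4-D shape fits in the inline storage (kMaxSmallSize == 6).
  mlir::RuntimeShape shape({2, 3, 4, 5});
  assert(shape.DimensionsCount() == 4);
  assert(shape.Dims(1) == 3);
  assert(shape.FlatSize() == 2 * 3 * 4 * 5);

  // Growing past kMaxSmallSize moves the dims into heap storage transparently.
  shape.Resize(8);
  assert(shape.DimensionsCount() == 8);
  return 0;
}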
+  // This will probably be factored out. Old code made substantial use of
+  // 4-D shapes, and so this function is used to extend smaller shapes.
+  // Note that (a) as Dims<4>-dependent code is eliminated, the reliance on
+  // this should be reduced, and (b) some kernels are strictly 4-D, but then
+  // the shapes of their inputs should already be 4-D, so this function
+  // should not be needed.
+  inline static RuntimeShape ExtendedShape(int new_shape_size,
+                                           const RuntimeShape& shape) {
+    return RuntimeShape(new_shape_size, shape, 1);
+  }
+
+  inline void BuildFrom(const std::initializer_list<int> init_list) {
+    BuildFrom<std::initializer_list<int>>(init_list);
+  }
+
+  // Returns the total count of elements, that is the size when flattened
+  // into a vector.
+  int FlatSize() const;
+
+  bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
+
+ private:
+  // For use only by ExtendedShape(), written to guarantee (return-value) copy
+  // elision in C++17.
+  // This creates a shape padded to the desired size with the specified value.
+  RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
+      : size_(0) {
+    // If the following check fails, it is likely because a 4D-only kernel is
+    // being used with an array of larger dimension count.
+    TFLITE_DCHECK_GE(new_shape_size, shape.DimensionsCount());
+    Resize(new_shape_size);
+    const int size_increase = new_shape_size - shape.DimensionsCount();
+    for (int i = 0; i < size_increase; ++i) {
+      SetDim(i, pad_value);
+    }
+    std::memcpy(DimsData() + size_increase, shape.DimsData(),
+                sizeof(int32_t) * shape.DimensionsCount());
+  }
+
+  int32_t size_;
+  union {
+    int32_t dims_[kMaxSmallSize];
+    int32_t* dims_pointer_;
+  };
+};
+
+// Converts inference-style shape to legacy tflite::Dims<4>.
+inline mlir::Dims<4> ToRuntimeDims(const mlir::RuntimeShape& array_shape) {
+  mlir::Dims<4> result;
+  const int dimensions_count = array_shape.DimensionsCount();
+  TFLITE_DCHECK_LE(dimensions_count, 4);
+  int cum_prod = 1;
+  for (int i = 0; i < 4; i++) {
+    const int new_dim =
+        (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
+    result.sizes[i] = new_dim;
+    result.strides[i] = cum_prod;
+    cum_prod *= new_dim;
+  }
+  return result;
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+inline RuntimeShape DimsToShape(const mlir::Dims<4>& dims) {
+  return RuntimeShape(
+      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
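To make the flattening rule concrete before the definitions that follow, here is a small illustrative sketch (not part of the header) of the index arithmetic computed by the 4-D Offset() overload declared just below.

#include <cassert>

#include "tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h"

int main() {
  const mlir::RuntimeShape shape({2, 3, 4, 5});
  // Offset(shape, i0, i1, i2, i3) == ((i0 * d1 + i1) * d2 + i2) * d3 + i3,
  // i.e. row-major order with the last index varying fastest.
  assert(mlir::Offset(shape, 0, 0, 0, 1) == 1);
  assert(mlir::Offset(shape, 0, 0, 1, 0) == 5);
  assert(mlir::Offset(shape, 1, 2, 3, 4) == ((1 * 3 + 2) * 4 + 3) * 5 + 4);
  return 0;
}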
+// Since tensors with '0' in their shape are valid in TF, these offset
+// functions allow that as long as the corresponding index is also 0. It is
+// up to the calling ops to ensure that they perform verification checks on
+// tensor shapes if they don't support a particular behavior.
+
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
+  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
+  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
+  TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
+                (i0 >= 0 && i0 < dims_data[0]));
+  TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
+                (i1 >= 0 && i1 < dims_data[1]));
+  TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
+                (i2 >= 0 && i2 < dims_data[2]));
+  TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
+                (i3 >= 0 && i3 < dims_data[3]));
+  return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
+}
+
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3,
+                  int i4) {
+  TFLITE_DCHECK_EQ(shape.DimensionsCount(), 5);
+  const int* dims_data = reinterpret_cast<const int*>(shape.DimsDataUpTo5D());
+  TFLITE_DCHECK((dims_data[0] == 0 && i0 == 0) ||
+                (i0 >= 0 && i0 < dims_data[0]));
+  TFLITE_DCHECK((dims_data[1] == 0 && i1 == 0) ||
+                (i1 >= 0 && i1 < dims_data[1]));
+  TFLITE_DCHECK((dims_data[2] == 0 && i2 == 0) ||
+                (i2 >= 0 && i2 < dims_data[2]));
+  TFLITE_DCHECK((dims_data[3] == 0 && i3 == 0) ||
+                (i3 >= 0 && i3 < dims_data[3]));
+  TFLITE_DCHECK((dims_data[4] == 0 && i4 == 0) ||
+                (i4 >= 0 && i4 < dims_data[4]));
+  return (((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3) *
+             dims_data[4] +
+         i4;
+}
+
+inline int Offset(const RuntimeShape& shape, int* index) {
+  return Offset(shape, index[0], index[1], index[2], index[3]);
+}
+
+}  // namespace mlir
+
+// LINT.ThenChange(//tensorflow/lite/kernels/internal/runtime_shape.h)
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_RUNTIME_SHAPE_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h
new file mode 100644
index 00000000..56ba7181
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h
@@ -0,0 +1,102 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_UTILS_SPARSITY_FORMAT_CONVERTER_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_UTILS_SPARSITY_FORMAT_CONVERTER_H_
+
+#include <vector>
+
+#include "Eigen/Core"  // from @eigen_archive
+#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h"
+
+namespace tflite_migration {
+namespace internal {
+namespace sparsity {
+
+// LINT.IfChange
+
+// A converter that keeps an internal representation of sparse tensor
+// parameters and converts tensors between dense and sparse formats.
+template <typename T>
+class FormatConverter {
+ public:
+  /*
+   * Creates a dense to sparse converter.
+   * @param shape             Shape of the dense tensor.
+   * @param traversal_order   In what order to traverse all dimensions,
+   *                          including block dimensions.
+ * @param format Whether each dimension in the dense tensor is + * dense or sparse (not in the traversal order). + * @param block_size Size of each block dimension. + * @param block_map Map from block dimension to original tensor + * dimension. + */ + FormatConverter(const std::vector& shape, + const std::vector& traversal_order, + const std::vector& format, + const std::vector& block_size = {}, + const std::vector& block_map = {}); + + const std::vector& GetData() { return data_; } + + const std::vector>& GetDimMetadata() { + return dim_metadata_; + } + + // Method for dense to sparse conversion. Need to call GetData() method to get + // the compressed data. + + void DenseToSparse(const T* src_data); + + // Check if val is equal to zero. + bool IsZero(const T val); + + // Shape of the conceptual dense tensor. + std::vector dense_shape_; + // Shape of the dense tensor with inner blocks reduced. For example, a (4, 4) + // tensor with (2, 2) block has blocked_shape (2, 2). + std::vector blocked_shape_; + // Total number of elements in the dense tensor. + size_t dense_size_; + // Has n(original dimension)+k(block_dimension) elements. + std::vector traversal_order_; + // Format of each dimension in the traversal order. + std::vector format_; + // Size of each block dimension, in the same order as block map. + std::vector block_size_; + // Map from block dimension to the original tensor dimension. + std::vector block_map_; + // Metadata of each dimension in the traversal order. + // Each dimension needs two vectors. For dense dimensions, the first vector + // stores the size of that dimension, and the second vector is empty. For + // sparse dimensions, the first vector stores the segments and the second one + // stores the indices. + std::vector> dim_metadata_; + // Actual buffer holding data after conversion. Could be sparse buffer or + // dense buffer. + std::vector data_; +}; + +extern template class FormatConverter; +extern template class FormatConverter; +extern template class FormatConverter; +extern template class FormatConverter; + +// LINT.ThenChange(//tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h) + +} // namespace sparsity +} // namespace internal +} // namespace tflite_migration + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_UTILS_SPARSITY_FORMAT_CONVERTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/padding.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/padding.h new file mode 100644 index 00000000..b0dd6daf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/kernels/padding.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_PADDING_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_PADDING_H_ + +// LINT.IfChange +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" + +namespace tflite_migration { + +// Matching GetWindowedOutputSize in TensorFlow. +inline int ComputeOutSize(TfLitePadding padding, int image_size, + int filter_size, int stride, int dilation_rate = 1) { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; + + // TODO(b/186448822): This uses 0 since the function has no other way to + // report error case + if (stride == 0) return 0; + + switch (padding) { + case kTfLitePaddingSame: + return (image_size + stride - 1) / stride; + case kTfLitePaddingValid: + return (image_size + stride - effective_filter_size) / stride; + default: + return 0; + } +} + +// It's not guaranteed that padding is symmetric. It's important to keep +// offset for algorithms need all paddings. +inline int ComputePaddingWithOffset(int stride, int dilation_rate, int in_size, + int filter_size, int out_size, + int* offset) { + int effective_filter_size = (filter_size - 1) * dilation_rate + 1; + int total_padding = + ((out_size - 1) * stride + effective_filter_size - in_size); + total_padding = total_padding > 0 ? total_padding : 0; + *offset = total_padding % 2; + return total_padding / 2; +} + +} // namespace tflite_migration + +// LINT.ThenChange(//tensorflow/lite/kernels/padding.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_PADDING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/error_collector.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/error_collector.h new file mode 100644 index 00000000..f21b0c47 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/error_collector.h @@ -0,0 +1,58 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_H_ + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/metrics/converter_error_data.pb.h" +#include "tensorflow/compiler/mlir/lite/metrics/types_util.h" + +namespace mlir { +namespace TFL { + +// A singleton to store errors collected by the instrumentation. +class ErrorCollector { + using ConverterErrorData = tflite::metrics::ConverterErrorData; + using ConverterErrorDataSet = + std::unordered_set; + + public: + const ConverterErrorDataSet &CollectedErrors() { return collected_errors_; } + + void ReportError(const ConverterErrorData &error) { + collected_errors_.insert(error); + } + + // Clear the set of collected errors. + void Clear() { collected_errors_.clear(); } + + // Returns the global instance of ErrorCollector. 
+ static ErrorCollector* GetErrorCollector(); + + private: + ErrorCollector() {} + + ConverterErrorDataSet collected_errors_; + + static ErrorCollector* error_collector_instance_; +}; + +} // namespace TFL +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h new file mode 100644 index 00000000..e3ac59a2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h @@ -0,0 +1,78 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_INST_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_INST_H_ + +#include +#include +#include +#include + +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/metrics/converter_error_data.pb.h" +#include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" +#include "tensorflow/compiler/mlir/lite/metrics/types_util.h" + +namespace mlir { +namespace TFL { + +// Collects errors when running the pass manager. +class ErrorCollectorInstrumentation : public PassInstrumentation { + using ConverterErrorData = tflite::metrics::ConverterErrorData; + using ErrorCode = ConverterErrorData::ErrorCode; + + public: + explicit ErrorCollectorInstrumentation(MLIRContext *context); + + private: + // Instrumentation hooks. These hooks don't need to be thread-safe. The pass + // manager runs each pass for the entire module, then it walks through + // each op in the module and runs the pass on them, may be in async mode. + void runBeforePass(Pass *pass, Operation *module) override; + void runAfterPass(Pass *pass, Operation *module) override; + void runAfterPassFailed(Pass *pass, Operation *module) override; + + // The handler to capture error messages. + std::unique_ptr handler_; + // A map from location to op name. + std::unordered_map loc_to_name_; + // Stores the error message for errors without op name and error code. + std::string common_error_message_; + // Name of the running pass. + std::string pass_name_; + // Pointer to the global ErrorCollector instance. + ErrorCollector *error_collector_; +}; + +// Prefix when adding error code as a note in Diagnostic. +constexpr char kErrorCodePrefix[] = "Error code: "; + +// Adds error code to a newly created InFlightDiagnostic. 
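A hypothetical call site for the AttachErrorCode helper defined immediately below. This is a sketch only; the op, the message, and the ERROR_NEEDS_FLEX_OPS enum value are assumptions made for illustration and are not taken from this patch.

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/metrics/converter_error_data.pb.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"

// Emits an error diagnostic for `op` and tags it with a converter error code
// so that ErrorCollectorInstrumentation can later recover the code from the
// attached note.
inline mlir::LogicalResult ReportUnsupportedOp(mlir::Operation* op) {
  using tflite::metrics::ConverterErrorData;
  return mlir::TFL::AttachErrorCode(
      op->emitOpError("op is not supported by the TFLite runtime"),
      ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
}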
+inline InFlightDiagnostic AttachErrorCode(InFlightDiagnostic &&diag,
+                                          int error_code) {
+  using tflite::metrics::ConverterErrorData;
+  diag.attachNote() << kErrorCodePrefix
+                    << ConverterErrorData::ErrorCode_Name(error_code);
+  return std::move(diag);
+}
+
+}  // namespace TFL
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_METRICS_ERROR_COLLECTOR_INST_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/types_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/types_util.h
new file mode 100644
index 00000000..7fe31a38
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/metrics/types_util.h
@@ -0,0 +1,71 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_METRICS_TYPES_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_METRICS_TYPES_UTIL_H_
+
+#include <cstddef>
+#include <functional>
+#include <string>
+
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/lite/metrics/converter_error_data.pb.h"
+
+namespace mlir {
+namespace TFL {
+
+// The hash function for mlir::Location.
+struct LocationHash {
+  std::size_t operator()(const Location& v) const noexcept {
+    return hash_value(v);
+  }
+};
+
+// The hash function for ConverterErrorData.
+struct ConverterErrorDataHash {
+  std::size_t operator()(
+      const tflite::metrics::ConverterErrorData& v) const noexcept {
+    std::size_t hash_result = std::hash<std::string>{}(v.error_message());
+    if (v.has_subcomponent()) {
+      hash_result ^= std::hash<std::string>{}(v.subcomponent()) << 1;
+    }
+    if (v.has_error_code()) {
+      hash_result ^= std::hash<int>{}(v.error_code()) << 2;
+    }
+    if (v.has_operator_() && v.operator_().has_name()) {
+      hash_result ^= std::hash<std::string>{}(v.operator_().name()) << 3;
+    }
+    return hash_result;
+  }
+};
+
+// The comparison function for ConverterErrorData.
+struct ConverterErrorDataComparison {
+  std::size_t operator()(
+      const tflite::metrics::ConverterErrorData& a,
+      const tflite::metrics::ConverterErrorData& b) const noexcept {
+    return ConverterErrorDataHash()(a) == ConverterErrorDataHash()(b);
+  }
+};
+
+// Helper function to create a new ConverterErrorData.
+tflite::metrics::ConverterErrorData NewConverterErrorData(
+    const std::string& pass_name, const std::string& error_message,
+    tflite::metrics::ConverterErrorData::ErrorCode error_code,
+    const std::string& op_name, const Location& location);
+
+}  // namespace TFL
+}  // namespace mlir
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_METRICS_TYPES_UTIL_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/offset_buffer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/offset_buffer.h
new file mode 100644
index 00000000..79e9d3f9
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/offset_buffer.h
@@ -0,0 +1,31 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_OFFSET_BUFFER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_OFFSET_BUFFER_H_ + +#include + +namespace tflite { + +// Check if the model is using custom_option_offset to store custom op +// buffers. When this field is not explicitly set by the user, then FlatBuffer +// will omit the field and interpret this as 0, to ensure this field is +// populated. The flatbuffer exporter will always set it to 1, and it's also not +// a valid buffer offset value. So it's only valid when it's > 1. +inline bool IsValidBufferOffset(const int64_t offset) { return offset > 1; } + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_OFFSET_BUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/converter_python_api.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/converter_python_api.h new file mode 100644 index 00000000..cfcba696 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/converter_python_api.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_CONVERTER_PYTHON_API_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_CONVERTER_PYTHON_API_H_ + +#include + +#include +#include + +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" + +namespace tflite { + +// Convert a model represented in `input_contents`. `model_flags_proto` +// describes model parameters. `flags_proto` describes conversion +// parameters (see relevant .protos for more information). Returns a string +// representing the contents of the converted model. When extended_return +// flag is set to true returns a dictionary that contains string representation +// of the converted model and some statistics like arithmetic ops count. +// `debug_info_str` contains the `GraphDebugInfo` proto. +PyObject* Convert(PyObject* model_flags_proto_txt_raw, + PyObject* converter_flags_proto_txt_raw, + PyObject* input_contents_txt_raw, + bool extended_return = false, + PyObject* debug_info_txt_raw = nullptr, + const tensorflow::quantization::PyFunctionLibrary* + quantization_py_function_library = nullptr); + +// Quantize the model with calibration data. 
Throw errors if `fully_quantize` +// is specified by the calibration data are not sufficient to quantize the +// model. +PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, + bool fully_quantize, int inference_type, + int input_data_type, int output_data_type, + bool enable_numeric_verify = false, + bool enable_whole_model_verify = false, + PyObject* op_denylist = nullptr, + PyObject* node_denylist = nullptr, + bool enable_variable_quantization = false, + bool disable_per_channel_for_dense_layers = false, + PyObject* debug_options_proto_txt_raw = nullptr); + +// Sparsifies model to encode sparse tensors with proper format. Throws error if +// sparsification fails. +PyObject* MlirSparsifyModel(PyObject* data); + +// Registers the given custom opdefs to TensorFlow global op registry. +PyObject* RegisterCustomOpdefs(PyObject* list); + +// Returns the collected TFLite conversion errors. +std::vector RetrieveCollectedErrors(); + +// Returns MLIR string dump of the given Flatbuffer model. +std::string FlatBufferFileToMlir(const std::string& model, + bool input_is_filepath); + +// All the exported functions should be listed in +// tensorflow/tools/def_file_filter/symbols_pybind.txt for the Windows build. +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_CONVERTER_PYTHON_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h new file mode 100644 index 00000000..3164265f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_FLATBUFFER_TO_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_FLATBUFFER_TO_MLIR_H_ + +#include + +namespace tensorflow { + +// Translates the given FlatBuffer filename or buffer into MLIR and returns +// translated MLIR as string. +std::string FlatBufferFileToMlir(const std::string& model_file_or_buffer, + bool input_is_filepath); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_FLATBUFFER_TO_MLIR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h new file mode 100644 index 00000000..a1a73863 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/model_flags.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" + +namespace tensorflow { + +// Converts the given GraphDef to a TF Lite FlatBuffer string according to the +// given model flags, converter flags and debug information. Returns error +// status if it fails to convert the input. +absl::Status ConvertGraphDefToTFLiteFlatBuffer( + const tflite::ModelFlags& model_flags, + tflite::ConverterFlags& converter_flags, const GraphDebugInfo& debug_info, + const GraphDef& input, std::string* result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_GRAPHDEF_TO_TFL_FLATBUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h new file mode 100644 index 00000000..f98a3522 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ + +#include + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/stateful_error_reporter.h" + +namespace tflite_migration { +namespace interpreter_wrapper { + +class PythonErrorReporter : public tflite_migration::StatefulErrorReporter { + public: + PythonErrorReporter() = default; + + // Report an error message + int Report(const char* format, va_list args) override; + + // Sets a Python runtime exception with the last error and + // clears the error message buffer. + PyObject* exception(); + + // Gets the last error message and clears the buffer. 
+ std::string message() override; + + private: + std::stringstream buffer_; +}; + +} // namespace interpreter_wrapper +} // namespace tflite_migration +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_utils.h new file mode 100644 index 00000000..8afc03ee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ + +#include + +#include + +namespace mlirlite { +namespace python_utils { + +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length); +PyObject* ConvertToPyString(const char* data, size_t length); + +} // namespace python_utils +} // namespace mlirlite +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h new file mode 100644 index 00000000..9008560f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_JAX_TO_TFL_FLATBUFFER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_JAX_TO_TFL_FLATBUFFER_H_ + +#include + +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/model_flags.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Converts the given Jax model to a TF Lite FlatBuffer +// string according to the given model flags, converter flags and tags. 
Returns +// error status if it fails to convert the input. +absl::Status ConvertJaxToTFLiteFlatBuffer( + const std::string& input, const tflite::ModelFlags& model_flags, + tflite::ConverterFlags& converter_flags, string* result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_JAX_TO_TFL_FLATBUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h new file mode 100644 index 00000000..92801047 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ + +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/model_flags.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Converts the given saved_model(either v1 or v2) to a TF Lite FlatBuffer +// string according to the given model flags, converter flags and tags. Returns +// error status if it fails to convert the input. +absl::Status ConvertSavedModelToTFLiteFlatBuffer( + const tflite::ModelFlags& model_flags, + tflite::ConverterFlags& converter_flags, string* result, + const quantization::PyFunctionLibrary* quantization_py_function_lib); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_SAVED_MODEL_TO_TFL_FLATBUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h new file mode 100644 index 00000000..de1e33f0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ + +#include +#include +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/model_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/types.pb.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace internal { + +// Register all custom ops including user specified custom ops. +absl::Status RegisterAllCustomOps( + const tflite::ConverterFlags& converter_flags); + +// Populate quantization specs (or not) given user specified ranges for each +// input arrays. +absl::Status PopulateQuantizationSpecs( + const tflite::ModelFlags& model_flags, + tflite::ConverterFlags& converter_flags, + mlir::quant::QuantizationSpecs* quant_specs, + std::vector* node_names, std::vector* node_dtypes, + std::vector>>* node_shapes, + std::vector>* node_mins, + std::vector>* node_maxs); + +// Convert imported MLIR file to TfLite flatbuffer. +// This will also run relevant passes as well. +absl::Status ConvertMLIRToTFLiteFlatBuffer( + const tflite::ModelFlags& model_flags, + tflite::ConverterFlags& converter_flags, + std::unique_ptr&& context, + mlir::OwningOpRef module, + const mlir::TFL::PassConfig& pass_config, + const std::unordered_set& saved_model_tags, string* result, + const quantization::PyFunctionLibrary* quantization_py_function_lib); + +// Give a warning for any unused flags that have been specified. +void WarningUnusedFlags(const tflite::ModelFlags& model_flags, + const tflite::ConverterFlags& converter_flags); +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_TF_TFL_FLATBUFFER_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/device_target.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/device_target.h new file mode 100644 index 00000000..01072c50 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/device_target.h @@ -0,0 +1,196 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/numerical_utils.h" + +namespace mlir { +namespace quant { + +class QuantizeContext; + +using AdjacentOperations = llvm::SmallVectorImpl; +using QuantizedMultipliers = llvm::SmallVector; +using QuantizedRanges = llvm::SmallVector; +using ScaleFn = std::function; + +using ScaleDecomposeFn = + std::function; + +static const QuantizedMultiplier kUnitQuantizedMultiplier{1, 0}; + +enum class ScaleConstraintType { + OutputInputSameScale, + OutputInputFreeScale, + CustomScale, +}; + +// Each kernel signature has its own specification for scales. +struct KernelSpec { + // Scale constraint + ScaleConstraintType type; + + // Custom function to derive the scales. Only available when the scale + // constraint is `CustomScale`. + ScaleFn scale_fn; +}; + +class KernelSpecs { + public: + using Signature = llvm::SmallVector; + + // Returns the kernel specification for the kernel signature. + std::optional Find(const Signature& signature) const { + auto spec_it = all_signatures_.find(signature); + if (spec_it != all_signatures_.end()) { + return spec_it->second; + } else { + return std::nullopt; + } + } + + ScaleDecomposeFn GetDecomposeFn() const { return decompose_fn_; } + + // Adds the kernel signature with the kernel specification. + LogicalResult Add(const Signature& signature, const KernelSpec& spec) { + if (all_signatures_.insert({signature, spec}).second) return success(); + return failure(); + } + + KernelSpecs& WithSignature(const KernelSpecs::Signature& signature, + const ScaleFn& fn) { + (void)Add(signature, {ScaleConstraintType::CustomScale, fn}); + return *this; + } + + KernelSpecs& WithImpl(const ScaleDecomposeFn& dfn) { + decompose_fn_ = dfn; + return *this; + } + + private: + // The signature is pattern match based. + struct SignatureInfo : public llvm::DenseMapInfo { + static inline Signature getEmptyKey() { return {}; } + static inline Signature getTombstoneKey() { return {nullptr}; } + static unsigned getHashValue(Signature val) { + return llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(Signature LHS, Signature RHS) { + if (RHS == getEmptyKey()) return LHS == getEmptyKey(); + if (RHS == getTombstoneKey()) return LHS == getTombstoneKey(); + if (LHS.size() != RHS.size()) return false; + for (auto arg : llvm::zip(LHS, RHS)) { + if (std::get<0>(arg) != std::get<1>(arg)) return false; + } + return true; + } + }; + + // Maps the signature to the kernel spec. 
Note that the matching is + // pattern match based. + llvm::DenseMap all_signatures_; + + // A method to compute the effective multipliers. This is independent on the + // bits of the ports, thus all the signature shares the same here. + ScaleDecomposeFn decompose_fn_; +}; + +class DeviceTarget { + public: + explicit DeviceTarget(MLIRContext* ctx); + + // Retrieves the kernel spec for the quant region op. + std::optional GetKernelSpec( + llvm::StringRef kernel, const KernelSpecs::Signature& signature) const; + + // Retrieves the scale decomposition function for the quant region op. + ScaleDecomposeFn GetDecomposeFn(quantfork::QuantizeRegionOp op) const; + + // converts specification to signature: + // - UniformedQuantizedType -> AnyQuantizedType + // - AnyQuantizedType (int) -> AnyQuantizedType + // - Float -> {} + static void AppendToSignature(Type spec, KernelSpecs::Signature* signature); + + protected: + // Adds the kernel spec with the custom scale function for the kernel. + LogicalResult RegisterKernel(llvm::StringRef kernel, + const KernelSpecs::Signature& signature, + const ScaleFn& fn, const ScaleDecomposeFn& dfn); + + // Adds the kernel spec with the scale constraint type for the kernel. + LogicalResult RegisterKernel(llvm::StringRef kernel, + const KernelSpecs::Signature& signature, + ScaleConstraintType constraint); + + // Adds the kernel with the name. Retrun an existing one if it has been + // added before. + KernelSpecs& RegisterKernel(llvm::StringRef kernel) { return specs_[kernel]; } + + // For "mulmat->add" type of kernels, convert the scales of all the ports to + // multipliers. + static LogicalResult DecomposeMultiplyAccumulateScale( + Operation* op, QuantizedMultipliers* input_multipliers, + QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges); + + // For "reshape" type of kernels. + static LogicalResult DecomposeSameScale( + Operation* op, QuantizedMultipliers* input_multipliers, + QuantizedMultipliers* output_multipliers, QuantizedRanges* output_ranges); + + // A set of parameters are required to build the signatures. + FloatType f32_; + IntegerType i8_, i32_; + int64_t i8_min_, i8_max_, i32_min_, i32_max_; + quant::AnyQuantizedType any_, qi8_, qi8n_, qi32_; + + private: + // Maps the kernel names to all the available kernels. + llvm::StringMap specs_; + + // Points to the global MLIRContext. + MLIRContext* ctx_; +}; + +} // namespace quant +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_DEVICE_TARGET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/Passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/Passes.h new file mode 100644 index 00000000..06f4697f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/Passes.h @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// +// This file defines all of the passes owned by the quantization dialect. As +// things mature, it is expected that passes specific to certain frontend or +// backend dialects will move to those dialects directly. For now, they are +// incubated here. +// +//===----------------------------------------------------------------------===// + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_PASSES_H_ + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace func { +class FuncOp; +} // namespace func + +namespace quantfork { + +/// Creates a pass that converts quantization simulation operations (i.e. +/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. +std::unique_ptr> createConvertSimulatedQuantPass(); + +/// Creates a pass that converts constants followed by a qbarrier to a +/// constant whose value is quantized. This is typically one of the last +/// passes done when lowering to express actual quantized arithmetic in a +/// low level representation. Because it modifies the constant, it is +/// destructive and cannot be undone. +std::unique_ptr> createConvertConstPass(); + +//===----------------------------------------------------------------------===// +// Registration +//===----------------------------------------------------------------------===// + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h.inc" + +} // namespace quantfork +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h new file mode 100644 index 00000000..bee081a1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_QUANTOPS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_QUANTOPS_H_ + +#include "llvm/Support/MathExtras.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsDialect.h.inc" +#define GET_OP_CLASSES + +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_QUANTOPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h new file mode 100644 index 00000000..bfc6afb8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h @@ -0,0 +1,71 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_QUANTIZEUTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_QUANTIZEUTILS_H_ + +namespace mlir { +class Attribute; +class Type; + +namespace quant { +class QuantizedType; +class UniformQuantizedType; +} // namespace quant +namespace quantfork { +class UniformQuantizedValueConverter; + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType(). +/// Returns nullptr if the conversion is not supported. On success, stores the +/// converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. 
realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttr(Attribute realValue, + quant::QuantizedType quantizedElementType, + Type &outConvertedType); + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType() and casted to an +/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On +/// success, stores the converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttrUniform(Attribute realValue, + quant::UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, + Type &outConvertedType); +} // namespace quantfork +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_IR_QUANTIZEUTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h new file mode 100644 index 00000000..9257f533 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_MODEL_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { + +// Quantizes the input model represented as `model_buffer` and writes the result +// to the `output_buffer`. Both `model_buffer` and `output_buffer` should be a +// valid FlatBuffer format for Model supported by TFLite. +// +// The `input_type`, `output_type` and `inference_type` can be float32 / qint8 / +// int8 / int16. +// +// Returns a partially quantized model if `fully_quantize` is false. Returns a +// non-OK status if the quantization fails. 
+// +// When `verify_numeric` is true, the model will have it's original float ops +// and NumericVerify ops to compare output values from the quantized and float +// ops. +// +// When `legacy_float_scale` is true, the quantizer will use float scale instead +// of double, and call TOCO's quantization routines to maintain bit-exactness of +// the values with the TOCO quantizer. +absl::Status QuantizeModel( + absl::string_view model_buffer, const tflite::TensorType &input_type, + const tflite::TensorType &output_type, + const tflite::TensorType &inference_type, + const std::unordered_set &operator_names, + bool disable_per_channel, bool fully_quantize, std::string &output_buffer, + bool verify_numeric = false, bool whole_model_verify = false, + bool legacy_float_scale = true, + const absl::flat_hash_set &denylisted_ops = {}, + const absl::flat_hash_set &denylisted_nodes = {}, + bool enable_variable_quantization = false, + bool disable_per_channel_for_dense_layers = false, + const std::optional + &debug_options = std::nullopt); + +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h new file mode 100644 index 00000000..65d044e4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h @@ -0,0 +1,91 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_WEIGHTS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_WEIGHTS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { + +// Supported resulting types from quantization process. +enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 }; + +// Stores information about how to quantize a user-specified custom operation. +// CustomOpInfo contains info of its corresponding CustomOp registered in the +// CustomOpMap. 'quantizable_input_indices' is used to determine which indices +// of the CustomOp are quantizable. 'is_weight_only' is used specify whether the +// custom op is quantized only for storage and dequantized at runtime. +// 'no_side_effect' is used to determine whether the op can be pruned if +// considered as trivially dead. +struct CustomOpInfo { + std::vector quantizable_input_indices; + bool is_weight_only = false; + bool no_side_effect = true; +}; + +using BuiltinOperatorSet = absl::flat_hash_set; +// Map from custom op code to custom op quantization information. 
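A caller-side sketch of the QuantizeModel entry point declared just above. The wrapper name, the include paths, and the choice of a float32 model interface with int8 inference are assumptions for illustration; only the parameters without default values are passed:

    #include <string>

    #include "absl/status/status.h"
    #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.h"
    #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"

    // `model_data` is assumed to hold the bytes of a valid .tflite flatbuffer.
    // Performs full integer quantization with a float32 model interface and
    // int8 inference, writing the quantized flatbuffer into `output_buffer`.
    absl::Status QuantizeToInt8(const std::string& model_data,
                                std::string& output_buffer) {
      return mlir::lite::QuantizeModel(
          model_data, tflite::TensorType_FLOAT32, tflite::TensorType_FLOAT32,
          tflite::TensorType_INT8, /*operator_names=*/{},
          /*disable_per_channel=*/false, /*fully_quantize=*/true,
          output_buffer);
    }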
+using CustomOpMap = std::unordered_map; + +// Applies dynamic range quantization for the given model wehre the input_model +// type is flatbuffer but is converted to MLIR during quantization process and +// then converted back to flatbuffer for return. Note that this is part of +// reaching feature parity with the old quantizer for dynamic range +// quantization, specifically for +// third_party/tensorflow/lite/tools/optimize/quantize_weights.h. +// TODO(b/202468183): Selective quantization + quant debugger support for +// dynamic range quantization for verify_numeric and whole_model_verify flags. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const tflite::Model* input_model, + const tflite::TensorType& inference_type, + const absl::flat_hash_set& denylisted_ops, + const CustomOpMap& custom_op_map, + int64_t minimum_elements_for_weights = 1024, + bool disable_per_channel = false, bool weight_only_quantization = false, + bool legacy_float_scale = false); + +// Overloading methods to support old quantizer versions API +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const tflite::Model* input_model, + int64_t weights_min_num_elements, + bool use_hybrid_evaluation = true); + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const tflite::Model* input_model, + BufferType quant_type = BufferType::QUANTIZED_INT8, + bool use_updated_hybrid_scheme = true); + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const tflite::Model* input_model, + int64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme = true, + const BuiltinOperatorSet& op_denylist = {}); + +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_WEIGHTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h new file mode 100644 index 00000000..8953a384 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ + +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace mlir { +namespace lite { +namespace internal { +// Test model with a single convolution. +// Floating point weights of the model are all integers and lie in +// range[-127, 127]. The weights have been put in such a way that each +// channel has at least one weight as -127 and one weight as 127. +// The activations are all in range: [-128, 127] +// This means all bias computations should result in 1.0 scale. 
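A sketch of weight quantization through the mlir::lite::QuantizeWeights overload declared above that takes a BufferType. The wrapper function, the include paths, and the copy of the builder contents into a string are illustrative assumptions:

    #include <string>

    #include "absl/status/status.h"
    #include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
    #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h"
    #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"

    // `model_data` is assumed to hold a valid .tflite flatbuffer.
    absl::Status QuantizeWeightsToInt8(const std::string& model_data,
                                       std::string& output_buffer) {
      const tflite::Model* input_model = tflite::GetModel(model_data.data());
      flatbuffers::FlatBufferBuilder builder;
      absl::Status status = mlir::lite::QuantizeWeights(
          &builder, input_model, mlir::lite::BufferType::QUANTIZED_INT8,
          /*use_updated_hybrid_scheme=*/true);
      if (!status.ok()) return status;
      // Copy the quantized model out of the builder.
      output_buffer.assign(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize());
      return absl::OkStatus();
    }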
+extern const char* kConvModelWithMinus128Plus127Weights; + +// Test model with single convolution where all weights are integers between +// [0, 10] weights are randomly distributed. It is not guaranteed that min max +// for weights are going to appear in each channel. +// Activations have min = 0, max = 10. +extern const char* kConvModelWith0Plus10Weights; + +// Test model where no bias is in the conv. +extern const char* kConvModelWithNoBias; + +// A floating point model with a single softmax. The input tensor has min +// and max in range [-5, 5], not necessarily -5 or +5. +extern const char* kSingleSoftmaxModelMinMinus5MaxPlus5; + +// A floating point model with a single average pool. The input tensor has min +// and max in range [-5, 5], not necessarily -5 or +5. +extern const char* kSingleAvgPoolModelMinMinus5MaxPlus5; + +// Test model with a weights variable that is shared between a convolution layer +// and an add operation. +extern const char* kModelWithSharedWeights; + +// Test model with Add followed by a reshape. Model has 2 inputs for add. +extern const char* kMultiInputAddWithReshape; + +// Test gather operation with quantized input. +extern const char* kQuantizedWithGather; + +// Test model with a tf.constant input to tf.add. Model has 2 inputs one +// constant and other placeholder. +extern const char* kConstInputAddModel; + +// A float test model with concat that has [0, 5] and [0, 10] for inputs and [0, +// 10] as output. +extern const char* kFloatConcatMax5Max10Max10; + +// Test model with broadcast_to op. +extern const char* kModelWithBroadcastToOp; + +// Test model with a custom op. +extern const char* kModelWithCustomOp; + +// Test model with a argmax op. +extern const char* kModelWithArgMaxOp; + +// Test model with a fully connected op. +extern const char* kModelWithFCOp; + +// Test model with a gather_nd op. +extern const char* kModelWithGatherNDOp; + +// Test model with a Where op. +extern const char* kModelWithWhereOp; + +// Test model with mixed quantizable and un-quantizable ops. +// reshape->custom->custom->squeeze. +extern const char* kModelMixed; + +// Test model with mixed quantizable and +// and un-quantizable ops for +// activations in 16-bit. +extern const char* kModelMixed16x8; + +// Test model with split op. +extern const char* kModelSplit; + +// Test model with pack op. +extern const char* kModelPack; + +// Test model with LSTM op that has layer norm, has projection, without +// peephole, without cifg. +extern const char* kLstmCalibrated; +extern const char* kLstmQuantized; + +// Test model with LSTM op that has peephole, without layer norm, without +// projection, without cifg. +extern const char* kLstmCalibrated2; +extern const char* kLstmQuantized2; + +extern const char* kUnidirectionalSequenceLstmCalibrated; +extern const char* kUnidirectionalSequenceLstmQuantized; + +// Test model with a minimum op. +extern const char* kModelWithMinimumOp; + +// Test model with a maximum op. +extern const char* kModelWithMaximumOp; + +// Test model with a transpose op. +extern const char* kModelWithTranspose; + +// Test model with SVDF op. +extern const char* kSvdfCalibrated; +extern const char* kSvdfQuantized; + +// Test model with an unpack op. +extern const char* kModelWithUnpack; + +// Test QAT model with fc op. +extern const char* kQatModelWithFc; + +// Test calibrated model with resource variables. +extern const char* kModelWithResourceVarsCalibrated; + +// An error reporter that fails on testing. 
+class FailOnErrorReporter : public tflite::ErrorReporter { + public: + int Report(const char* format, va_list args) override; +}; +} // namespace internal +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h new file mode 100644 index 00000000..94742d11 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// Converts all the tfl.quantize/tfl.dequantize ops to the ops in the mlir.quant +// dialect ones in the function. +void ConvertTFLQuantOpsToMlirQuantOps(func::FuncOp func); + +// Converts all the mlir.quant dialect ops to the tfl.quantize/tfl.dequantize +// ops in the function. +void ConvertMlirQuantOpsToTFLQuantOps(func::FuncOp func); + +// A helper class to convert target function to another representation using +// `ConvertForward` function during construction and convert target function +// back to the original representation using `ConvertBackward` function during +// deconstruction. +template +class ScopedOpsConverter { + public: + explicit ScopedOpsConverter(func::FuncOp func) : func_(func) { + ConvertForward(func_); + } + + ScopedOpsConverter(const ScopedOpsConverter&) = delete; + ScopedOpsConverter operator=(const ScopedOpsConverter&) = delete; + ScopedOpsConverter(const ScopedOpsConverter&&) = delete; + ScopedOpsConverter operator=(const ScopedOpsConverter&&) = delete; + + ~ScopedOpsConverter() { ConvertBackward(func_); } + + private: + func::FuncOp func_; +}; + +using ScopedTFLQuantOpsToMlirQuantOpsConverter = + ScopedOpsConverter; +using ScopedMlirQuantOpsToTFLQuantOpsConverter = + ScopedOpsConverter; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TFL_TO_STD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h new file mode 100644 index 00000000..5841a4c7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
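A short sketch of the RAII aliases defined above in tfl_to_std.h: the converter rewrites tfl.quantize/tfl.dequantize ops into the mlir.quant dialect on construction and restores them on destruction. The surrounding pass logic is assumed:

    #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
    #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h"

    void RunOnQuantDialect(mlir::func::FuncOp func) {
      // Construction calls ConvertTFLQuantOpsToMlirQuantOps(func).
      mlir::TFL::ScopedTFLQuantOpsToMlirQuantOpsConverter converter(func);
      // ... logic that expects mlir.quant ops goes here ...
    }  // Destruction calls ConvertMlirQuantOpsToTFLQuantOps(func).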
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is the MLIR copy of part of +// third_party/tensorflow/lite/tools/optimize/model_utils.h as part of the +// effort to decouple TFLite from MLIR. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_MODEL_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_MODEL_UTILS_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { + +using std::string; +using tflite::ModelT; +using tflite::OperatorT; +using tflite::TensorT; +using tflite::TensorType; + +// LINT.IfChange(MakeDequantizeOperator) +// Creates a Dequantize OperatorT object. +void MakeDequantizeOperator(ModelT* model, std::unique_ptr* op, + int32_t input, int32_t output); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/model_utils.h:MakeDequantizeOperator) + +// LINT.IfChange(MakeTensor) +// Create a new TensorT object without quantization parameters. +void MakeTensor(const string& name, const std::vector& shape, + const std::vector& shape_signature, + const TensorType& type, std::unique_ptr* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/model_utils.h:MakeTensor) + +// LINT.IfChange(HasMinMax) +bool HasMinMax(const TensorT* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/model_utils.h:HasMinMax) + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_MODEL_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h new file mode 100644 index 00000000..7bc80a1b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is the MLIR copy of part of +// third_party/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h +// as part of the effort to decouple TFLite from MLIR. 
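A sketch of MakeTensor and MakeDequantizeOperator declared above in model_utils.h. The element types of the vector and unique_ptr parameters (int32_t and TensorT/OperatorT) are assumed here, and the shape values are arbitrary:

    #include <cstdint>
    #include <memory>
    #include <utility>
    #include <vector>

    #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h"
    #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"

    // Appends a float32 output tensor and a Dequantize op reading
    // `quantized_tensor_index` to subgraph 0 of `model`.
    void AppendDequantize(tflite::ModelT* model,
                          int32_t quantized_tensor_index) {
      auto& subgraph = model->subgraphs[0];

      std::unique_ptr<tflite::TensorT> dequant_output;  // assumed TensorT
      mlir::lite::toco_legacy::MakeTensor(
          "dequantized", /*shape=*/{1, 224, 224, 3}, /*shape_signature=*/{},
          tflite::TensorType_FLOAT32, &dequant_output);
      subgraph->tensors.push_back(std::move(dequant_output));
      const int32_t output_index =
          static_cast<int32_t>(subgraph->tensors.size()) - 1;

      std::unique_ptr<tflite::OperatorT> dequant_op;  // assumed OperatorT
      mlir::lite::toco_legacy::MakeDequantizeOperator(
          model, &dequant_op, /*input=*/quantized_tensor_index,
          /*output=*/output_index);
      subgraph->operators.push_back(std::move(dequant_op));
    }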
+ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_PORTABLE_TENSOR_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_PORTABLE_TENSOR_UTILS_H_ + +#include + +namespace mlir { +namespace lite { +namespace toco_legacy { + +// LINT.IfChange(portable_symmetric_quantize_floats) +void PortableSymmetricQuantizeFloats(const float* values, const int size, + int8_t* quantized_values, float* min_value, + float* max_value, float* scaling_factor); + +void PortableSymmetricQuantizeFloats(const float* values, const int size, + int8_t* quantized_values, float min_value, + float max_value, float* scaling_factor); +// LINT.ThenChange(//tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h:portable_symmetric_quantize_floats) + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_PORTABLE_TENSOR_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h new file mode 100644 index 00000000..bd68ed1c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is the MLIR copy of part of +// third_party/tensorflow/lite/tools/optimize/quantization_utils.h as part of +// the effort to decouple TFLite from MLIR. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZATION_UTILS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { + +using tflite::ModelT; +using tflite::QuantizationParametersT; +using tflite::TensorT; +using tflite::TensorType; + +// LINT.IfChange(num_elements) +// Returns the number of elements in the given tensor. +absl::Status NumElements(const TensorT& tensor, uint64_t* num_elements); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:num_elements) + +// LINT.IfChange(fill_per_channel_min_max) +// Populates the max and min values for per channel quantization. +absl::Status FillPerChannelMinMax(const float* const input, + const std::vector& dimension, + int32_t channel_dim_index, + QuantizationParametersT* quantization_params); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:fill_per_channel_min_max) + +// LINT.IfChange(symmetric_per_channel_quantization) +// Per-channel quantize a tensor at the given index and returns both scales and +// quantized values. 
+// Parameters: +// - tensor is the tensor to be quantized, needed to access associated +// quantization parameters +// - input is the float input data to be quantized. +// - channel_dim_index is the channel index within "dimension". +// dimension[channel_dim_index] gives the number of channels. +// - output_scale is the output scale, the size of which equals the number of +// channels. +// - output_value is the output data, the size of which equals the number of +// inputs. +absl::Status SymmetricPerChannelQuantization(TensorT* tensor, + const float* const input, + int32_t channel_dim_index, + std::vector* output_scales, + std::vector* output_value); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_per_channel_quantization) + +// LINT.IfChange(symmetric_per_channel_quantize_values) +// Quantize the values given an array of scales. +void SymmetricPerChannelQuantizeValues(const float* const input, + const std::vector& scales_inv, + const std::vector& dimension, + int32_t channel_dim_index, + std::vector* output_value); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_per_channel_quantize_values) + +// LINT.IfChange(symmetric_quantize_tensor) +// Quantizes tensor using symmetric quantization with the min and max elements +// of the tensor. +absl::Status SymmetricQuantizeTensor(ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_quantize_tensor) + +// LINT.IfChange(symmetric_quantize_tensor_per_channel) +// Quantizes tensor with per channel. +absl::Status SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor, + int32_t channel_dim_index); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_quantize_tensor_per_channel) + +// LINT.IfChange(quantize_tensor_float16) +// Quantizes tensor to float16. +absl::Status QuantizeTensorFloat16(ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:quantize_tensor_float16) + +// LINT.IfChange(add_quantization_params) +absl::Status AddQuantizationParams(const std::vector& scales, + const std::vector& zero_point, + int quantized_dimension, + const uint8_t* buffer_data, + size_t buffer_size, TensorType output_type, + ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:add_quantization_params) + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h new file mode 100644 index 00000000..039c18d8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { + +using ::tflite::BuiltinOperator; +using ::tflite::Model; + +// Supported resulting types from quantization process. +enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 }; +enum class QuantizerType { OLD_QUANTIZER, MLIR_QUANTIZER }; + +// Stores information about how to quantize a user-specified custom operation. +struct CustomOpInfo { + std::vector quantizable_input_indices; + bool is_hybrid; +}; + +// Map from custom op code to custom op quantization information. +using CustomOpMap = std::unordered_map; + +// This macro is for internal use for conversions requiring previous behavior. +#ifdef TFLITE_USE_PREVIOUS_HYBRID_SCHEME +// Use asymmetric quantized activations and per-channel quantized weights. +constexpr bool kUseUpdatedHybridSchemeDefault = false; +#else +// Use symmetric quantized activations and per-channel quantized weights. +constexpr bool kUseUpdatedHybridSchemeDefault = true; +#endif + +// Quantizes input_model and populates the provided builder with the new model. +// By default only weights tensors weight more than 1024 elements will be +// quantized. +// +// A tflite::Model can be obtained from the builder with: +// const uint8_t* buffer = builder->GetBufferPointer(); +// tflite::Model* model = GetModel(buffer); +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + BufferType quant_type = BufferType::QUANTIZED_INT8, + bool use_updated_hybrid_scheme = kUseUpdatedHybridSchemeDefault, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but only weights with greater than or equal +// weights_min_num_elements elements will be quantized. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but with entry point of quantizing custom ops. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but if use updated_hybrid_scheme is false, +// use previous quantization scheme. Optional op_denylist argument +// disables hybrid evaluation for provided BuiltinOperators. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme, + const absl::flat_hash_set& op_denylist = {}, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +namespace internal { +// If use_hybrid_evaluation is false, will disable using hybrid eval for +// operations that support it. 
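A sketch of the simplest toco_legacy QuantizeWeights overload declared above, followed by reading the result back out of the builder in the way the header comment above describes. The float16 buffer type and the wrapper function are illustrative choices:

    #include <cstdint>

    #include "absl/status/status.h"
    #include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
    #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h"
    #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"

    absl::Status QuantizeWeightsToFloat16(
        const tflite::Model* input_model,
        flatbuffers::FlatBufferBuilder* builder) {
      absl::Status status = mlir::lite::toco_legacy::QuantizeWeights(
          builder, input_model,
          mlir::lite::toco_legacy::BufferType::QUANTIZED_FLOAT16,
          mlir::lite::toco_legacy::kUseUpdatedHybridSchemeDefault);
      if (!status.ok()) return status;
      // Obtain the quantized tflite::Model from the builder, as noted above.
      const uint8_t* buffer = builder->GetBufferPointer();
      const tflite::Model* quantized_model = tflite::GetModel(buffer);
      (void)quantized_model;  // e.g. hand off to a verifier or serializer
      return absl::OkStatus();
    }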
+// +// We use this internal QuantizeWeights call to test models with hybrid +// evaluation disabled. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, bool use_hybrid_evaluation, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); +} // namespace internal + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/numerical_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/numerical_utils.h new file mode 100644 index 00000000..d938cd2c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/numerical_utils.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_ + +#include +#include +#include + +#include "absl/types/optional.h" + +namespace mlir { +namespace quant { + +using QuantizedMultiplier = std::pair; +using QuantizedRange = std::pair; + +// Decompose double precision multiplier to integer multiplier and exponent. +// double_multiplier = int_multiplier * 2 ^ (-31 + exponent) +// int_multiplier will be range of (2^31, 2^30]. +QuantizedMultiplier QuantizeMultiplier(double double_multiplier); + +// Calculate the effective quantized value range for the scale, zero point. The +// range is the minimum range defined by [rmin, rmax] and [qmin, qmax]. +QuantizedRange CalculateQuantizedRange(double scale, int32_t zero_point, + std::optional rmin, + std::optional rmax, int32_t qmin, + int32_t qmax); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_NUMERICAL_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/quantization_context.h new file mode 100644 index 00000000..a1f40f86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -0,0 +1,245 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_ + +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/device_target.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" + +namespace mlir { +namespace quant { + +static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); } + +// The state for each op result during the quantization parameters propagation. +struct QuantState { + // Quantization parameters propagated to an op result. + QuantParams params; + // A flag indicates this state (the params) shouldn't be changed after it is + // initialized. This flag will be set to true if the quantization parameters + // are from the quantization-aware training. + const bool immutable; + + bool IsEmpty() { return EmptyParams(params); } +}; + +// The state for rescaling the propagated quantization parameters. This can be +// on the input side to satisfy the constraint of previous operation, or on the +// output side to satisfy the constraint of the next operation. +struct RequantizeState { + // Sometimes, we have to "requantize" the quantization result to satisfy all + // the constraints. The "requantize" can happen either on the input or output + // of the quantization result. + enum RequantizePosition { + NO_REQUANTIZE, + ON_INPUT, + ON_OUTPUT + } pos = NO_REQUANTIZE; + + // Quantization parameters will be used to add the requantize ops. + QuantParams params; +}; + +// This class manages all the intermediate quantization states. +class QuantizeContext { + public: + QuantizeContext(func::FuncOp func, const DeviceTarget &spec); + + // Returns all the quant region ops. + std::vector GetAllOps(); + + // For each quant region op, propagates its quantization parameters according + // to the kernel specification and also returns the adjacent quant region ops + // which get the new quantization parameters propagated. + LogicalResult Handle(quantfork::QuantizeRegionOp op, + llvm::SmallVectorImpl *new_items, + bool *changed); + + // Updates the port quantization specifications of all the quant region ops + // with the propagation results. + LogicalResult Finalize(); + + // Dumps the states stores in the state manager. + void DumpStates(quantfork::QuantizeRegionOp current_op = {}); + + // Update the quantization parameter for certain result of the op. By this + // method, the quantization parameter is propagated to all the users of the + // result as well. + bool SetResultParams(Operation *op, int index, QuantParams params) { + return states_manager_.SetResultParams(op, index, params); + } + + // Update the quantization parameter for certain operand of the op. By this + // method, the quantization parameter is propagated to the defining op of + // operand as well. 
+ bool SetOperandParams(Operation *op, int index, QuantParams params) { + return states_manager_.SetOperandParams(op, index, params); + } + + // Return the quantization parameter of certain result of the op. + QuantParams GetResultParams(Operation *op, int index) { + return states_manager_.GetResultParams(op, index); + } + + // Return the quantization parameter of certain operand of the op. + QuantParams GetOperandParams(Operation *op, int index) { + return states_manager_.GetOperandParams(op, index); + } + + // Return the signature of the op. + KernelSpecs::Signature GetSignature(quantfork::QuantizeRegionOp op); + + // A heuristic to get quantization parameters satisfies the same scale + // constraints: + // - If there are immutable states, + // - use the single input, or, + // - use the single output, or, + // - use the first one in the collection, + // - use the single input if it is ready, or, + // - use the single output if it is ready, or, + // - use the first ready one in the collection. + QuantParams GetQuantParamsForSameScaleConstraint(Operation *op); + + // Propagate `params` to all the quantizable port of the `op`. The adjacent + // ops, which have the parameters propagated to, are collected by `new_items`, + // so they can be added to the working queue. `changed` is set to true if + // there are any new elements being added to `new_items`. + LogicalResult PropagateQuantParams(Operation *op, const QuantParams params, + AdjacentOperations *new_items, + bool *changed); + + private: + class StatesManager { + public: + // Sets the quantization parameters of the constant result according to its + // content. + // + // Always returns true. + bool SetConstantResultParams(Operation *op); + + // Sets the quantization parameters of the result to a fixed value. If any + // quantization parameters have been propagated, a `requantize` will happen + // on the input of propagated quantization. + // + // Returns true, if the users of the result needs to be added to the + // worklist. + bool SetResultParams(Operation *op, int index, QuantParams params); + + // Sets the quantization parameters of the operand to a fixed value. If any + // quantization parameters have been propagated, a `requantize` will happen + // on the output of propagated quantization. + // + // Returns true, if the defining op of the operand needs to be added to the + // worklist. + bool SetOperandParams(Operation *op, int index, QuantParams params); + + // Returns the quantization parameters of the index-th result of the op. + QuantParams GetResultParams(Operation *op, int index) { + return states_[result_states_[{op, index}]].params; + } + + // Returns the quantization parameters of the index-th operand of the op. + QuantParams GetOperandParams(Operation *op, int index) { + return states_[operand_states_[{op, index}]].params; + } + + private: + friend class QuantizeContext; + + // Uses the type of `val` to set the initial state of the index-th result if + // `as_result` is true or index-th operand if `as_result` is false. The + // state is immutable if the type is a quantized type. Returns the index of + // this new state in the state vector. + int InitializeState(quantfork::QuantizeRegionOp op, int index, + bool as_result); + + // Sets the state of the index-th operand of the op. If this operand is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. 
+ void InitializeOperandState(quantfork::QuantizeRegionOp op, int index, + llvm::DenseMap *cache); + + // Sets the state of the index-th result of the op. If this result is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. + void InitializeResultState(quantfork::QuantizeRegionOp op, int index, + llvm::DenseMap *cache); + + // Returns the state of the index-th operand of the op. + QuantState &GetOperandQuantState(Operation *op, int index) { + return states_[operand_states_[{op, index}]]; + } + + // Returns the state of the index-th result of the op. + QuantState &GetResultQuantState(Operation *op, int index) { + return states_[result_states_[{op, index}]]; + } + + // Returns the state of the index-th operand of the op. + RequantizeState &GetOperandRequantizeState(Operation *op, int index) { + return rescale_states_[operand_states_[{op, index}]]; + } + + // Returns the state of the index-th result of the op. + RequantizeState &GetResultRequantizeState(Operation *op, int index) { + return rescale_states_[result_states_[{op, index}]]; + } + + private: + // This is used to identify an operand or result of an op. The second + // element of this pair is the index of the operand or result. + using OpValue = std::pair; + + // The vector contains all the quantization parameters propagated from the + // defining operations of the value, or from the quantization aware + // training. + std::vector states_; + + // The map contains all the quantization parameters which are required to + // satisfy the same operands and results constraint. The keys of this map + // are the values from `operand_states_` and `result_state_`. + std::unordered_map rescale_states_; + + // Maps of indexes to the propagation state vector from the ops operands, + // results and arguments. + llvm::DenseMap operand_states_; + llvm::DenseMap result_states_; + }; + + func::FuncOp func_; + + DeviceTarget target_spec_; + + StatesManager states_manager_; +}; + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h new file mode 100644 index 00000000..5c119a65 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/quantization_passes.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace mlir { +namespace quant { + +using OperationToName = std::function; + +// Creates an instance pass to import quantization stats to the operations in +// the function. A custom method to get the name from the op is used because +// different dialect ops might have different ways to assign the name. +std::unique_ptr> CreateImportQuantStatsPass( + OperationToName op_to_name, const std::string& stats_str); + +// Creates an instance pass to import quantization stats to the operations in +// the function. A custom method to get the name from the op is used because +// different dialect ops might have different ways to assign the name. +std::unique_ptr> +CreateImportQuantStatsPassForTFControlDialect(const std::string& stats_str); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h new file mode 100644 index 00000000..c55d59ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Adaptor functions for StableHLO Quantizer. +// Provides simpler interfaces when integrating StableHLO Quantizer into TFLite +// Converter. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" + +namespace tensorflow { + +// Runs quantization on `module_op`. `saved_model_bundle` is required to +// retrieve information about the original model (e.g. signature def mapping) +// because quantization requires exporting the intermediate `ModuleOp` back to +// SavedModel for calibration. Similarly, `saved_model_dir` is required to +// access the assets of the original model. `saved_model_tags` uniquely +// identifies the `MetaGraphDef`. `quantization_config` determines the behavior +// of StableHLO Quantizer. 
`quantization_py_function_lib` contains python +// implementations of certain APIs that are required for calibration. +// `module_op` is the input graph to be quantized and it should contain +// StableHLO ops. +// +// Returns a quantized `ModuleOp` in StableHLO, potentially wrapped inside a +// XlaCallModuleOp. Returns a non-OK status if quantization fails, or any of +// `saved_model_bundle` or `quantization_py_function_lib` is a nullptr. +absl::StatusOr RunQuantization( + const SavedModelBundle* saved_model_bundle, + absl::string_view saved_model_dir, + const std::unordered_set& saved_model_tags, + const stablehlo::quantization::QuantizationConfig& quantization_config, + const tensorflow::quantization::PyFunctionLibrary* + quantization_py_function_lib, + mlir::ModuleOp module_op); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_STABLEHLO_QUANTIZATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h new file mode 100644 index 00000000..a552cc65 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Legalize the tf ops to the quant ops, so the quantization passes can work. +std::unique_ptr> CreateLegalizeTFToQuantPass(); + +// Fallbacks ops that are not supported by TF Quantization to TFLite Flex ops. +std::unique_ptr> CreateFallbackToFlexOpsPass( + const std::string &mode = "DEFAULT"); + +} // namespace TF +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_TENSORFLOW_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/conversion_metadata_generated.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/conversion_metadata_generated.h new file mode 100755 index 00000000..12af129c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/conversion_metadata_generated.h @@ -0,0 +1,672 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_CONVERSIONMETADATA_TFLITE_H_ +#define FLATBUFFERS_GENERATED_CONVERSIONMETADATA_TFLITE_H_ + +#include "flatbuffers/flatbuffers.h" + +// Ensure the included flatbuffers.h is the same version as when this file was +// generated, otherwise it may not be compatible. +static_assert(FLATBUFFERS_VERSION_MAJOR == 24 && + FLATBUFFERS_VERSION_MINOR == 3 && + FLATBUFFERS_VERSION_REVISION == 25, + "Non-compatible flatbuffers version included"); + +namespace tflite { + +struct Environment; +struct EnvironmentBuilder; +struct EnvironmentT; + +struct SparsityBlockSize; +struct SparsityBlockSizeBuilder; +struct SparsityBlockSizeT; + +struct ConversionOptions; +struct ConversionOptionsBuilder; +struct ConversionOptionsT; + +struct ConversionMetadata; +struct ConversionMetadataBuilder; +struct ConversionMetadataT; + +enum ModelType : int32_t { + ModelType_NONE = 0, + ModelType_TF_SAVED_MODEL = 1, + ModelType_KERAS_MODEL = 2, + ModelType_TF_CONCRETE_FUNCTIONS = 3, + ModelType_TF_GRAPH_DEF = 4, + ModelType_TF_SESSION = 5, + ModelType_JAX = 6, + ModelType_MIN = ModelType_NONE, + ModelType_MAX = ModelType_JAX +}; + +inline const ModelType (&EnumValuesModelType())[7] { + static const ModelType values[] = { + ModelType_NONE, + ModelType_TF_SAVED_MODEL, + ModelType_KERAS_MODEL, + ModelType_TF_CONCRETE_FUNCTIONS, + ModelType_TF_GRAPH_DEF, + ModelType_TF_SESSION, + ModelType_JAX + }; + return values; +} + +inline const char * const *EnumNamesModelType() { + static const char * const names[8] = { + "NONE", + "TF_SAVED_MODEL", + "KERAS_MODEL", + "TF_CONCRETE_FUNCTIONS", + "TF_GRAPH_DEF", + "TF_SESSION", + "JAX", + nullptr + }; + return names; +} + +inline const char *EnumNameModelType(ModelType e) { + if (::flatbuffers::IsOutRange(e, ModelType_NONE, ModelType_JAX)) return ""; + const size_t index = static_cast(e); + return EnumNamesModelType()[index]; +} + +enum ModelOptimizationMode : int32_t { + ModelOptimizationMode_PTQ_FLOAT16 = 1001, + ModelOptimizationMode_PTQ_DYNAMIC_RANGE = 1002, + ModelOptimizationMode_PTQ_FULL_INTEGER = 1003, + ModelOptimizationMode_PTQ_INT16 = 1004, + ModelOptimizationMode_QUANTIZATION_AWARE_TRAINING = 2000, + ModelOptimizationMode_RANDOM_SPARSITY = 3001, + ModelOptimizationMode_BLOCK_SPARSITY = 3002, + ModelOptimizationMode_STRUCTURED_SPARSITY = 3003, + ModelOptimizationMode_MIN = ModelOptimizationMode_PTQ_FLOAT16, + ModelOptimizationMode_MAX = ModelOptimizationMode_STRUCTURED_SPARSITY +}; + +inline const ModelOptimizationMode (&EnumValuesModelOptimizationMode())[8] { + static const ModelOptimizationMode values[] = { + ModelOptimizationMode_PTQ_FLOAT16, + ModelOptimizationMode_PTQ_DYNAMIC_RANGE, + ModelOptimizationMode_PTQ_FULL_INTEGER, + ModelOptimizationMode_PTQ_INT16, + ModelOptimizationMode_QUANTIZATION_AWARE_TRAINING, + ModelOptimizationMode_RANDOM_SPARSITY, + ModelOptimizationMode_BLOCK_SPARSITY, + ModelOptimizationMode_STRUCTURED_SPARSITY + }; + return values; +} + +inline const char 
*EnumNameModelOptimizationMode(ModelOptimizationMode e) { + switch (e) { + case ModelOptimizationMode_PTQ_FLOAT16: return "PTQ_FLOAT16"; + case ModelOptimizationMode_PTQ_DYNAMIC_RANGE: return "PTQ_DYNAMIC_RANGE"; + case ModelOptimizationMode_PTQ_FULL_INTEGER: return "PTQ_FULL_INTEGER"; + case ModelOptimizationMode_PTQ_INT16: return "PTQ_INT16"; + case ModelOptimizationMode_QUANTIZATION_AWARE_TRAINING: return "QUANTIZATION_AWARE_TRAINING"; + case ModelOptimizationMode_RANDOM_SPARSITY: return "RANDOM_SPARSITY"; + case ModelOptimizationMode_BLOCK_SPARSITY: return "BLOCK_SPARSITY"; + case ModelOptimizationMode_STRUCTURED_SPARSITY: return "STRUCTURED_SPARSITY"; + default: return ""; + } +} + +struct EnvironmentT : public ::flatbuffers::NativeTable { + typedef Environment TableType; + std::string tensorflow_version{}; + uint32_t api_version = 0; + tflite::ModelType model_type = tflite::ModelType_NONE; +}; + +struct Environment FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef EnvironmentT NativeTableType; + typedef EnvironmentBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TENSORFLOW_VERSION = 4, + VT_API_VERSION = 6, + VT_MODEL_TYPE = 8 + }; + const ::flatbuffers::String *tensorflow_version() const { + return GetPointer(VT_TENSORFLOW_VERSION); + } + uint32_t api_version() const { + return GetField(VT_API_VERSION, 0); + } + tflite::ModelType model_type() const { + return static_cast(GetField(VT_MODEL_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TENSORFLOW_VERSION) && + verifier.VerifyString(tensorflow_version()) && + VerifyField(verifier, VT_API_VERSION, 4) && + VerifyField(verifier, VT_MODEL_TYPE, 4) && + verifier.EndTable(); + } + EnvironmentT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EnvironmentT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const EnvironmentT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EnvironmentBuilder { + typedef Environment Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_tensorflow_version(::flatbuffers::Offset<::flatbuffers::String> tensorflow_version) { + fbb_.AddOffset(Environment::VT_TENSORFLOW_VERSION, tensorflow_version); + } + void add_api_version(uint32_t api_version) { + fbb_.AddElement(Environment::VT_API_VERSION, api_version, 0); + } + void add_model_type(tflite::ModelType model_type) { + fbb_.AddElement(Environment::VT_MODEL_TYPE, static_cast(model_type), 0); + } + explicit EnvironmentBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateEnvironment( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> tensorflow_version = 0, + uint32_t api_version = 0, + tflite::ModelType model_type = tflite::ModelType_NONE) { + EnvironmentBuilder builder_(_fbb); + builder_.add_model_type(model_type); + builder_.add_api_version(api_version); + builder_.add_tensorflow_version(tensorflow_version); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateEnvironmentDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char 
*tensorflow_version = nullptr, + uint32_t api_version = 0, + tflite::ModelType model_type = tflite::ModelType_NONE) { + auto tensorflow_version__ = tensorflow_version ? _fbb.CreateString(tensorflow_version) : 0; + return tflite::CreateEnvironment( + _fbb, + tensorflow_version__, + api_version, + model_type); +} + +::flatbuffers::Offset CreateEnvironment(::flatbuffers::FlatBufferBuilder &_fbb, const EnvironmentT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SparsityBlockSizeT : public ::flatbuffers::NativeTable { + typedef SparsityBlockSize TableType; + std::vector values{}; +}; + +struct SparsityBlockSize FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SparsityBlockSizeT NativeTableType; + typedef SparsityBlockSizeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4 + }; + const ::flatbuffers::Vector *values() const { + return GetPointer *>(VT_VALUES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + verifier.EndTable(); + } + SparsityBlockSizeT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SparsityBlockSizeT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityBlockSizeT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparsityBlockSizeBuilder { + typedef SparsityBlockSize Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values(::flatbuffers::Offset<::flatbuffers::Vector> values) { + fbb_.AddOffset(SparsityBlockSize::VT_VALUES, values); + } + explicit SparsityBlockSizeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSparsityBlockSize( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> values = 0) { + SparsityBlockSizeBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateSparsityBlockSizeDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *values = nullptr) { + auto values__ = values ? 
_fbb.CreateVector(*values) : 0; + return tflite::CreateSparsityBlockSize( + _fbb, + values__); +} + +::flatbuffers::Offset CreateSparsityBlockSize(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityBlockSizeT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConversionOptionsT : public ::flatbuffers::NativeTable { + typedef ConversionOptions TableType; + std::vector model_optimization_modes{}; + bool allow_custom_ops = false; + bool enable_select_tf_ops = false; + bool force_select_tf_ops = false; + std::vector> sparsity_block_sizes{}; + ConversionOptionsT() = default; + ConversionOptionsT(const ConversionOptionsT &o); + ConversionOptionsT(ConversionOptionsT&&) FLATBUFFERS_NOEXCEPT = default; + ConversionOptionsT &operator=(ConversionOptionsT o) FLATBUFFERS_NOEXCEPT; +}; + +struct ConversionOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ConversionOptionsT NativeTableType; + typedef ConversionOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MODEL_OPTIMIZATION_MODES = 4, + VT_ALLOW_CUSTOM_OPS = 6, + VT_ENABLE_SELECT_TF_OPS = 8, + VT_FORCE_SELECT_TF_OPS = 10, + VT_SPARSITY_BLOCK_SIZES = 12 + }; + const ::flatbuffers::Vector *model_optimization_modes() const { + return GetPointer *>(VT_MODEL_OPTIMIZATION_MODES); + } + bool allow_custom_ops() const { + return GetField(VT_ALLOW_CUSTOM_OPS, 0) != 0; + } + bool enable_select_tf_ops() const { + return GetField(VT_ENABLE_SELECT_TF_OPS, 0) != 0; + } + bool force_select_tf_ops() const { + return GetField(VT_FORCE_SELECT_TF_OPS, 0) != 0; + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *sparsity_block_sizes() const { + return GetPointer> *>(VT_SPARSITY_BLOCK_SIZES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MODEL_OPTIMIZATION_MODES) && + verifier.VerifyVector(model_optimization_modes()) && + VerifyField(verifier, VT_ALLOW_CUSTOM_OPS, 1) && + VerifyField(verifier, VT_ENABLE_SELECT_TF_OPS, 1) && + VerifyField(verifier, VT_FORCE_SELECT_TF_OPS, 1) && + VerifyOffset(verifier, VT_SPARSITY_BLOCK_SIZES) && + verifier.VerifyVector(sparsity_block_sizes()) && + verifier.VerifyVectorOfTables(sparsity_block_sizes()) && + verifier.EndTable(); + } + ConversionOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConversionOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConversionOptionsBuilder { + typedef ConversionOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_model_optimization_modes(::flatbuffers::Offset<::flatbuffers::Vector> model_optimization_modes) { + fbb_.AddOffset(ConversionOptions::VT_MODEL_OPTIMIZATION_MODES, model_optimization_modes); + } + void add_allow_custom_ops(bool allow_custom_ops) { + fbb_.AddElement(ConversionOptions::VT_ALLOW_CUSTOM_OPS, static_cast(allow_custom_ops), 0); + } + void add_enable_select_tf_ops(bool enable_select_tf_ops) { + fbb_.AddElement(ConversionOptions::VT_ENABLE_SELECT_TF_OPS, static_cast(enable_select_tf_ops), 0); + } + void add_force_select_tf_ops(bool force_select_tf_ops) { + fbb_.AddElement(ConversionOptions::VT_FORCE_SELECT_TF_OPS, static_cast(force_select_tf_ops), 0); + } + void 
add_sparsity_block_sizes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> sparsity_block_sizes) { + fbb_.AddOffset(ConversionOptions::VT_SPARSITY_BLOCK_SIZES, sparsity_block_sizes); + } + explicit ConversionOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConversionOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> model_optimization_modes = 0, + bool allow_custom_ops = false, + bool enable_select_tf_ops = false, + bool force_select_tf_ops = false, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> sparsity_block_sizes = 0) { + ConversionOptionsBuilder builder_(_fbb); + builder_.add_sparsity_block_sizes(sparsity_block_sizes); + builder_.add_model_optimization_modes(model_optimization_modes); + builder_.add_force_select_tf_ops(force_select_tf_ops); + builder_.add_enable_select_tf_ops(enable_select_tf_ops); + builder_.add_allow_custom_ops(allow_custom_ops); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateConversionOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *model_optimization_modes = nullptr, + bool allow_custom_ops = false, + bool enable_select_tf_ops = false, + bool force_select_tf_ops = false, + const std::vector<::flatbuffers::Offset> *sparsity_block_sizes = nullptr) { + auto model_optimization_modes__ = model_optimization_modes ? _fbb.CreateVector(*model_optimization_modes) : 0; + auto sparsity_block_sizes__ = sparsity_block_sizes ? _fbb.CreateVector<::flatbuffers::Offset>(*sparsity_block_sizes) : 0; + return tflite::CreateConversionOptions( + _fbb, + model_optimization_modes__, + allow_custom_ops, + enable_select_tf_ops, + force_select_tf_ops, + sparsity_block_sizes__); +} + +::flatbuffers::Offset CreateConversionOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConversionMetadataT : public ::flatbuffers::NativeTable { + typedef ConversionMetadata TableType; + std::unique_ptr environment{}; + std::unique_ptr options{}; + ConversionMetadataT() = default; + ConversionMetadataT(const ConversionMetadataT &o); + ConversionMetadataT(ConversionMetadataT&&) FLATBUFFERS_NOEXCEPT = default; + ConversionMetadataT &operator=(ConversionMetadataT o) FLATBUFFERS_NOEXCEPT; +}; + +struct ConversionMetadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ConversionMetadataT NativeTableType; + typedef ConversionMetadataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ENVIRONMENT = 4, + VT_OPTIONS = 6 + }; + const tflite::Environment *environment() const { + return GetPointer(VT_ENVIRONMENT); + } + const tflite::ConversionOptions *options() const { + return GetPointer(VT_OPTIONS); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_ENVIRONMENT) && + verifier.VerifyTable(environment()) && + VerifyOffset(verifier, VT_OPTIONS) && + verifier.VerifyTable(options()) && + verifier.EndTable(); + } + ConversionMetadataT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConversionMetadataT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static 
::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionMetadataT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConversionMetadataBuilder { + typedef ConversionMetadata Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_environment(::flatbuffers::Offset environment) { + fbb_.AddOffset(ConversionMetadata::VT_ENVIRONMENT, environment); + } + void add_options(::flatbuffers::Offset options) { + fbb_.AddOffset(ConversionMetadata::VT_OPTIONS, options); + } + explicit ConversionMetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConversionMetadata( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset environment = 0, + ::flatbuffers::Offset options = 0) { + ConversionMetadataBuilder builder_(_fbb); + builder_.add_options(options); + builder_.add_environment(environment); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateConversionMetadata(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionMetadataT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +inline EnvironmentT *Environment::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new EnvironmentT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Environment::UnPackTo(EnvironmentT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = tensorflow_version(); if (_e) _o->tensorflow_version = _e->str(); } + { auto _e = api_version(); _o->api_version = _e; } + { auto _e = model_type(); _o->model_type = _e; } +} + +inline ::flatbuffers::Offset Environment::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const EnvironmentT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateEnvironment(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateEnvironment(::flatbuffers::FlatBufferBuilder &_fbb, const EnvironmentT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const EnvironmentT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _tensorflow_version = _o->tensorflow_version.empty() ? 
0 : _fbb.CreateString(_o->tensorflow_version); + auto _api_version = _o->api_version; + auto _model_type = _o->model_type; + return tflite::CreateEnvironment( + _fbb, + _tensorflow_version, + _api_version, + _model_type); +} + +inline SparsityBlockSizeT *SparsityBlockSize::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SparsityBlockSizeT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SparsityBlockSize::UnPackTo(SparsityBlockSizeT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } else { _o->values.resize(0); } } +} + +inline ::flatbuffers::Offset SparsityBlockSize::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityBlockSizeT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparsityBlockSize(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSparsityBlockSize(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityBlockSizeT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SparsityBlockSizeT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values = _o->values.size() ? _fbb.CreateVector(_o->values) : 0; + return tflite::CreateSparsityBlockSize( + _fbb, + _values); +} + +inline ConversionOptionsT::ConversionOptionsT(const ConversionOptionsT &o) + : model_optimization_modes(o.model_optimization_modes), + allow_custom_ops(o.allow_custom_ops), + enable_select_tf_ops(o.enable_select_tf_ops), + force_select_tf_ops(o.force_select_tf_ops) { + sparsity_block_sizes.reserve(o.sparsity_block_sizes.size()); + for (const auto &sparsity_block_sizes_ : o.sparsity_block_sizes) { sparsity_block_sizes.emplace_back((sparsity_block_sizes_) ? 
new tflite::SparsityBlockSizeT(*sparsity_block_sizes_) : nullptr); } +} + +inline ConversionOptionsT &ConversionOptionsT::operator=(ConversionOptionsT o) FLATBUFFERS_NOEXCEPT { + std::swap(model_optimization_modes, o.model_optimization_modes); + std::swap(allow_custom_ops, o.allow_custom_ops); + std::swap(enable_select_tf_ops, o.enable_select_tf_ops); + std::swap(force_select_tf_ops, o.force_select_tf_ops); + std::swap(sparsity_block_sizes, o.sparsity_block_sizes); + return *this; +} + +inline ConversionOptionsT *ConversionOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ConversionOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ConversionOptions::UnPackTo(ConversionOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = model_optimization_modes(); if (_e) { _o->model_optimization_modes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->model_optimization_modes[_i] = static_cast(_e->Get(_i)); } } else { _o->model_optimization_modes.resize(0); } } + { auto _e = allow_custom_ops(); _o->allow_custom_ops = _e; } + { auto _e = enable_select_tf_ops(); _o->enable_select_tf_ops = _e; } + { auto _e = force_select_tf_ops(); _o->force_select_tf_ops = _e; } + { auto _e = sparsity_block_sizes(); if (_e) { _o->sparsity_block_sizes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->sparsity_block_sizes[_i]) { _e->Get(_i)->UnPackTo(_o->sparsity_block_sizes[_i].get(), _resolver); } else { _o->sparsity_block_sizes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->sparsity_block_sizes.resize(0); } } +} + +inline ::flatbuffers::Offset ConversionOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateConversionOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateConversionOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ConversionOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _model_optimization_modes = _o->model_optimization_modes.size() ? _fbb.CreateVectorScalarCast(::flatbuffers::data(_o->model_optimization_modes), _o->model_optimization_modes.size()) : 0; + auto _allow_custom_ops = _o->allow_custom_ops; + auto _enable_select_tf_ops = _o->enable_select_tf_ops; + auto _force_select_tf_ops = _o->force_select_tf_ops; + auto _sparsity_block_sizes = _o->sparsity_block_sizes.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->sparsity_block_sizes.size(), [](size_t i, _VectorArgs *__va) { return CreateSparsityBlockSize(*__va->__fbb, __va->__o->sparsity_block_sizes[i].get(), __va->__rehasher); }, &_va ) : 0; + return tflite::CreateConversionOptions( + _fbb, + _model_optimization_modes, + _allow_custom_ops, + _enable_select_tf_ops, + _force_select_tf_ops, + _sparsity_block_sizes); +} + +inline ConversionMetadataT::ConversionMetadataT(const ConversionMetadataT &o) + : environment((o.environment) ? new tflite::EnvironmentT(*o.environment) : nullptr), + options((o.options) ? 
new tflite::ConversionOptionsT(*o.options) : nullptr) { +} + +inline ConversionMetadataT &ConversionMetadataT::operator=(ConversionMetadataT o) FLATBUFFERS_NOEXCEPT { + std::swap(environment, o.environment); + std::swap(options, o.options); + return *this; +} + +inline ConversionMetadataT *ConversionMetadata::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ConversionMetadataT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ConversionMetadata::UnPackTo(ConversionMetadataT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = environment(); if (_e) { if(_o->environment) { _e->UnPackTo(_o->environment.get(), _resolver); } else { _o->environment = std::unique_ptr(_e->UnPack(_resolver)); } } else if (_o->environment) { _o->environment.reset(); } } + { auto _e = options(); if (_e) { if(_o->options) { _e->UnPackTo(_o->options.get(), _resolver); } else { _o->options = std::unique_ptr(_e->UnPack(_resolver)); } } else if (_o->options) { _o->options.reset(); } } +} + +inline ::flatbuffers::Offset ConversionMetadata::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionMetadataT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateConversionMetadata(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateConversionMetadata(::flatbuffers::FlatBufferBuilder &_fbb, const ConversionMetadataT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ConversionMetadataT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _environment = _o->environment ? CreateEnvironment(_fbb, _o->environment.get(), _rehasher) : 0; + auto _options = _o->options ? 
CreateConversionOptions(_fbb, _o->options.get(), _rehasher) : 0; + return tflite::CreateConversionMetadata( + _fbb, + _environment, + _options); +} + +inline const tflite::ConversionMetadata *GetConversionMetadata(const void *buf) { + return ::flatbuffers::GetRoot<tflite::ConversionMetadata>(buf); +} + +inline const tflite::ConversionMetadata *GetSizePrefixedConversionMetadata(const void *buf) { + return ::flatbuffers::GetSizePrefixedRoot<tflite::ConversionMetadata>(buf); +} + +inline bool VerifyConversionMetadataBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer<tflite::ConversionMetadata>(nullptr); +} + +inline bool VerifySizePrefixedConversionMetadataBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer<tflite::ConversionMetadata>(nullptr); +} + +inline void FinishConversionMetadataBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset<tflite::ConversionMetadata> root) { + fbb.Finish(root); +} + +inline void FinishSizePrefixedConversionMetadataBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset<tflite::ConversionMetadata> root) { + fbb.FinishSizePrefixed(root); +} + +inline std::unique_ptr<tflite::ConversionMetadataT> UnPackConversionMetadata( + const void *buf, + const ::flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr<tflite::ConversionMetadataT>(GetConversionMetadata(buf)->UnPack(res)); +} + +inline std::unique_ptr<tflite::ConversionMetadataT> UnPackSizePrefixedConversionMetadata( + const void *buf, + const ::flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr<tflite::ConversionMetadataT>(GetSizePrefixedConversionMetadata(buf)->UnPack(res)); +} + +} // namespace tflite + +#endif // FLATBUFFERS_GENERATED_CONVERSIONMETADATA_TFLITE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h new file mode 100644 index 00000000..ebf9219f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_CONVERSION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_CONVERSION_UTILS_H_ + +#include "flatbuffers/flatbuffers.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace tflite { + +int8_t ConvertBuiltinCodeToDeprecatedBuiltinCode(BuiltinOperator builtin_code); + +// The following methods are for backward compatibility for the early version +// three, which does not have an extended builtin code.
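An aside on the backward-compatibility helpers introduced by the comment above and declared immediately below: the TFLite schema keeps the original 8-bit builtin-code field alongside the extended 32-bit one, and operators numbered above 127 cannot be represented in the old field, which is why the BuiltinOperator enum later in this diff reserves PLACEHOLDER_FOR_GREATER_OP_CODES = 127. The snippet below is a minimal, hypothetical sketch of that clamping rule, not the vendored implementation; the helper name and local constant exist only for this example.

    // Hypothetical sketch: how an extended builtin code presumably collapses
    // into the legacy int8 field. Not part of the vendored headers.
    #include <algorithm>
    #include <cstdint>

    // Mirrors BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES (= 127) from
    // schema_generated.h; redefined here so the sketch stands alone.
    constexpr int32_t kPlaceholderForGreaterOpCodes = 127;

    inline int8_t ToDeprecatedBuiltinCodeSketch(int32_t extended_builtin_code) {
      // Codes that fit in int8 pass through unchanged; anything larger is
      // clamped to the placeholder so old readers still see a valid value.
      return static_cast<int8_t>(
          std::min(extended_builtin_code, kPlaceholderForGreaterOpCodes));
    }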
+flatbuffers::Offset<OperatorCode> CreateOperatorCode( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + flatbuffers::Offset<flatbuffers::String> custom_code = 0, + int32_t version = 1); + +flatbuffers::Offset<OperatorCode> CreateOperatorCodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + BuiltinOperator builtin_code = BuiltinOperator_ADD, + const char *custom_code = nullptr, int32_t version = 1); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_CONVERSION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_generated.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_generated.h new file mode 100755 index 00000000..6262406f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_generated.h @@ -0,0 +1,25321 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ +#define FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ + +#include "flatbuffers/flatbuffers.h" + +// Ensure the included flatbuffers.h is the same version as when this file was +// generated, otherwise it may not be compatible.
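The comment above and the static_assert that follows pin this generated header to the exact FlatBuffers release it was produced with (24.3.25), so a mismatched flatbuffers.h on the include path fails the build instead of compiling against incompatible headers. When that assert fires it can help to print which headers the build is actually seeing; the throwaway program below is illustrative only and assumes nothing beyond the FLATBUFFERS_VERSION_* macros already used by the assert.

    // Illustrative only: report the FlatBuffers header version on the include path.
    #include <cstdio>
    #include "flatbuffers/flatbuffers.h"  // provides the FLATBUFFERS_VERSION_* macros

    int main() {
      std::printf("flatbuffers headers: %d.%d.%d (schema_generated.h expects 24.3.25)\n",
                  FLATBUFFERS_VERSION_MAJOR, FLATBUFFERS_VERSION_MINOR,
                  FLATBUFFERS_VERSION_REVISION);
      return 0;
    }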
+static_assert(FLATBUFFERS_VERSION_MAJOR == 24 && + FLATBUFFERS_VERSION_MINOR == 3 && + FLATBUFFERS_VERSION_REVISION == 25, + "Non-compatible flatbuffers version included"); + +namespace tflite { + +struct CustomQuantization; +struct CustomQuantizationBuilder; +struct CustomQuantizationT; + +struct BlockwiseQuantization; +struct BlockwiseQuantizationBuilder; +struct BlockwiseQuantizationT; + +struct QuantizationParameters; +struct QuantizationParametersBuilder; +struct QuantizationParametersT; + +struct Int32Vector; +struct Int32VectorBuilder; +struct Int32VectorT; + +struct Uint16Vector; +struct Uint16VectorBuilder; +struct Uint16VectorT; + +struct Uint8Vector; +struct Uint8VectorBuilder; +struct Uint8VectorT; + +struct DimensionMetadata; +struct DimensionMetadataBuilder; +struct DimensionMetadataT; + +struct SparsityParameters; +struct SparsityParametersBuilder; +struct SparsityParametersT; + +struct VariantSubType; +struct VariantSubTypeBuilder; +struct VariantSubTypeT; + +struct Tensor; +struct TensorBuilder; +struct TensorT; + +struct StablehloGatherOptions; +struct StablehloGatherOptionsBuilder; +struct StablehloGatherOptionsT; + +struct StablehloTransposeOptions; +struct StablehloTransposeOptionsBuilder; +struct StablehloTransposeOptionsT; + +struct StablehloDotGeneralOptions; +struct StablehloDotGeneralOptionsBuilder; +struct StablehloDotGeneralOptionsT; + +struct StablehloReduceWindowOptions; +struct StablehloReduceWindowOptionsBuilder; +struct StablehloReduceWindowOptionsT; + +struct StablehloWhileOptions; +struct StablehloWhileOptionsBuilder; +struct StablehloWhileOptionsT; + +struct StablehloSortOptions; +struct StablehloSortOptionsBuilder; +struct StablehloSortOptionsT; + +struct StablehloConcatenateOptions; +struct StablehloConcatenateOptionsBuilder; +struct StablehloConcatenateOptionsT; + +struct StablehloBroadcastInDimOptions; +struct StablehloBroadcastInDimOptionsBuilder; +struct StablehloBroadcastInDimOptionsT; + +struct StablehloCompareOptions; +struct StablehloCompareOptionsBuilder; +struct StablehloCompareOptionsT; + +struct StablehloDynamicSliceOptions; +struct StablehloDynamicSliceOptionsBuilder; +struct StablehloDynamicSliceOptionsT; + +struct StablehloPadOptions; +struct StablehloPadOptionsBuilder; +struct StablehloPadOptionsT; + +struct StablehloIotaOptions; +struct StablehloIotaOptionsBuilder; +struct StablehloIotaOptionsT; + +struct StablehloCustomCallOptions; +struct StablehloCustomCallOptionsBuilder; +struct StablehloCustomCallOptionsT; + +struct StablehloReduceOptions; +struct StablehloReduceOptionsBuilder; +struct StablehloReduceOptionsT; + +struct StablehloSliceOptions; +struct StablehloSliceOptionsBuilder; +struct StablehloSliceOptionsT; + +struct StablehloConvolutionOptions; +struct StablehloConvolutionOptionsBuilder; +struct StablehloConvolutionOptionsT; + +struct StablehloScatterOptions; +struct StablehloScatterOptionsBuilder; +struct StablehloScatterOptionsT; + +struct StablehloCaseOptions; +struct StablehloCaseOptionsBuilder; +struct StablehloCaseOptionsT; + +struct StablehloRngBitGeneratorOptions; +struct StablehloRngBitGeneratorOptionsBuilder; +struct StablehloRngBitGeneratorOptionsT; + +struct Conv2DOptions; +struct Conv2DOptionsBuilder; +struct Conv2DOptionsT; + +struct Conv3DOptions; +struct Conv3DOptionsBuilder; +struct Conv3DOptionsT; + +struct Pool2DOptions; +struct Pool2DOptionsBuilder; +struct Pool2DOptionsT; + +struct DepthwiseConv2DOptions; +struct DepthwiseConv2DOptionsBuilder; +struct DepthwiseConv2DOptionsT; + +struct 
ConcatEmbeddingsOptions; +struct ConcatEmbeddingsOptionsBuilder; +struct ConcatEmbeddingsOptionsT; + +struct LSHProjectionOptions; +struct LSHProjectionOptionsBuilder; +struct LSHProjectionOptionsT; + +struct SVDFOptions; +struct SVDFOptionsBuilder; +struct SVDFOptionsT; + +struct RNNOptions; +struct RNNOptionsBuilder; +struct RNNOptionsT; + +struct SequenceRNNOptions; +struct SequenceRNNOptionsBuilder; +struct SequenceRNNOptionsT; + +struct BidirectionalSequenceRNNOptions; +struct BidirectionalSequenceRNNOptionsBuilder; +struct BidirectionalSequenceRNNOptionsT; + +struct FullyConnectedOptions; +struct FullyConnectedOptionsBuilder; +struct FullyConnectedOptionsT; + +struct SoftmaxOptions; +struct SoftmaxOptionsBuilder; +struct SoftmaxOptionsT; + +struct ConcatenationOptions; +struct ConcatenationOptionsBuilder; +struct ConcatenationOptionsT; + +struct AddOptions; +struct AddOptionsBuilder; +struct AddOptionsT; + +struct MulOptions; +struct MulOptionsBuilder; +struct MulOptionsT; + +struct L2NormOptions; +struct L2NormOptionsBuilder; +struct L2NormOptionsT; + +struct LocalResponseNormalizationOptions; +struct LocalResponseNormalizationOptionsBuilder; +struct LocalResponseNormalizationOptionsT; + +struct LSTMOptions; +struct LSTMOptionsBuilder; +struct LSTMOptionsT; + +struct UnidirectionalSequenceLSTMOptions; +struct UnidirectionalSequenceLSTMOptionsBuilder; +struct UnidirectionalSequenceLSTMOptionsT; + +struct BidirectionalSequenceLSTMOptions; +struct BidirectionalSequenceLSTMOptionsBuilder; +struct BidirectionalSequenceLSTMOptionsT; + +struct ResizeBilinearOptions; +struct ResizeBilinearOptionsBuilder; +struct ResizeBilinearOptionsT; + +struct ResizeNearestNeighborOptions; +struct ResizeNearestNeighborOptionsBuilder; +struct ResizeNearestNeighborOptionsT; + +struct CallOptions; +struct CallOptionsBuilder; +struct CallOptionsT; + +struct PadOptions; +struct PadOptionsBuilder; +struct PadOptionsT; + +struct PadV2Options; +struct PadV2OptionsBuilder; +struct PadV2OptionsT; + +struct ReshapeOptions; +struct ReshapeOptionsBuilder; +struct ReshapeOptionsT; + +struct SpaceToBatchNDOptions; +struct SpaceToBatchNDOptionsBuilder; +struct SpaceToBatchNDOptionsT; + +struct BatchToSpaceNDOptions; +struct BatchToSpaceNDOptionsBuilder; +struct BatchToSpaceNDOptionsT; + +struct SkipGramOptions; +struct SkipGramOptionsBuilder; +struct SkipGramOptionsT; + +struct SpaceToDepthOptions; +struct SpaceToDepthOptionsBuilder; +struct SpaceToDepthOptionsT; + +struct DepthToSpaceOptions; +struct DepthToSpaceOptionsBuilder; +struct DepthToSpaceOptionsT; + +struct SubOptions; +struct SubOptionsBuilder; +struct SubOptionsT; + +struct DivOptions; +struct DivOptionsBuilder; +struct DivOptionsT; + +struct TopKV2Options; +struct TopKV2OptionsBuilder; +struct TopKV2OptionsT; + +struct EmbeddingLookupSparseOptions; +struct EmbeddingLookupSparseOptionsBuilder; +struct EmbeddingLookupSparseOptionsT; + +struct GatherOptions; +struct GatherOptionsBuilder; +struct GatherOptionsT; + +struct TransposeOptions; +struct TransposeOptionsBuilder; +struct TransposeOptionsT; + +struct ExpOptions; +struct ExpOptionsBuilder; +struct ExpOptionsT; + +struct CosOptions; +struct CosOptionsBuilder; +struct CosOptionsT; + +struct ReducerOptions; +struct ReducerOptionsBuilder; +struct ReducerOptionsT; + +struct SqueezeOptions; +struct SqueezeOptionsBuilder; +struct SqueezeOptionsT; + +struct SplitOptions; +struct SplitOptionsBuilder; +struct SplitOptionsT; + +struct SplitVOptions; +struct SplitVOptionsBuilder; +struct SplitVOptionsT; + +struct 
StridedSliceOptions; +struct StridedSliceOptionsBuilder; +struct StridedSliceOptionsT; + +struct LogSoftmaxOptions; +struct LogSoftmaxOptionsBuilder; +struct LogSoftmaxOptionsT; + +struct CastOptions; +struct CastOptionsBuilder; +struct CastOptionsT; + +struct DequantizeOptions; +struct DequantizeOptionsBuilder; +struct DequantizeOptionsT; + +struct MaximumMinimumOptions; +struct MaximumMinimumOptionsBuilder; +struct MaximumMinimumOptionsT; + +struct TileOptions; +struct TileOptionsBuilder; +struct TileOptionsT; + +struct ArgMaxOptions; +struct ArgMaxOptionsBuilder; +struct ArgMaxOptionsT; + +struct ArgMinOptions; +struct ArgMinOptionsBuilder; +struct ArgMinOptionsT; + +struct GreaterOptions; +struct GreaterOptionsBuilder; +struct GreaterOptionsT; + +struct GreaterEqualOptions; +struct GreaterEqualOptionsBuilder; +struct GreaterEqualOptionsT; + +struct LessOptions; +struct LessOptionsBuilder; +struct LessOptionsT; + +struct LessEqualOptions; +struct LessEqualOptionsBuilder; +struct LessEqualOptionsT; + +struct NegOptions; +struct NegOptionsBuilder; +struct NegOptionsT; + +struct SelectOptions; +struct SelectOptionsBuilder; +struct SelectOptionsT; + +struct SliceOptions; +struct SliceOptionsBuilder; +struct SliceOptionsT; + +struct TransposeConvOptions; +struct TransposeConvOptionsBuilder; +struct TransposeConvOptionsT; + +struct ExpandDimsOptions; +struct ExpandDimsOptionsBuilder; +struct ExpandDimsOptionsT; + +struct SparseToDenseOptions; +struct SparseToDenseOptionsBuilder; +struct SparseToDenseOptionsT; + +struct EqualOptions; +struct EqualOptionsBuilder; +struct EqualOptionsT; + +struct NotEqualOptions; +struct NotEqualOptionsBuilder; +struct NotEqualOptionsT; + +struct ShapeOptions; +struct ShapeOptionsBuilder; +struct ShapeOptionsT; + +struct RankOptions; +struct RankOptionsBuilder; +struct RankOptionsT; + +struct PowOptions; +struct PowOptionsBuilder; +struct PowOptionsT; + +struct FakeQuantOptions; +struct FakeQuantOptionsBuilder; +struct FakeQuantOptionsT; + +struct PackOptions; +struct PackOptionsBuilder; +struct PackOptionsT; + +struct LogicalOrOptions; +struct LogicalOrOptionsBuilder; +struct LogicalOrOptionsT; + +struct OneHotOptions; +struct OneHotOptionsBuilder; +struct OneHotOptionsT; + +struct AbsOptions; +struct AbsOptionsBuilder; +struct AbsOptionsT; + +struct HardSwishOptions; +struct HardSwishOptionsBuilder; +struct HardSwishOptionsT; + +struct LogicalAndOptions; +struct LogicalAndOptionsBuilder; +struct LogicalAndOptionsT; + +struct LogicalNotOptions; +struct LogicalNotOptionsBuilder; +struct LogicalNotOptionsT; + +struct UnpackOptions; +struct UnpackOptionsBuilder; +struct UnpackOptionsT; + +struct FloorDivOptions; +struct FloorDivOptionsBuilder; +struct FloorDivOptionsT; + +struct SquareOptions; +struct SquareOptionsBuilder; +struct SquareOptionsT; + +struct ZerosLikeOptions; +struct ZerosLikeOptionsBuilder; +struct ZerosLikeOptionsT; + +struct FillOptions; +struct FillOptionsBuilder; +struct FillOptionsT; + +struct FloorModOptions; +struct FloorModOptionsBuilder; +struct FloorModOptionsT; + +struct RangeOptions; +struct RangeOptionsBuilder; +struct RangeOptionsT; + +struct LeakyReluOptions; +struct LeakyReluOptionsBuilder; +struct LeakyReluOptionsT; + +struct SquaredDifferenceOptions; +struct SquaredDifferenceOptionsBuilder; +struct SquaredDifferenceOptionsT; + +struct MirrorPadOptions; +struct MirrorPadOptionsBuilder; +struct MirrorPadOptionsT; + +struct UniqueOptions; +struct UniqueOptionsBuilder; +struct UniqueOptionsT; + +struct ReverseV2Options; +struct 
ReverseV2OptionsBuilder; +struct ReverseV2OptionsT; + +struct AddNOptions; +struct AddNOptionsBuilder; +struct AddNOptionsT; + +struct GatherNdOptions; +struct GatherNdOptionsBuilder; +struct GatherNdOptionsT; + +struct WhereOptions; +struct WhereOptionsBuilder; +struct WhereOptionsT; + +struct ReverseSequenceOptions; +struct ReverseSequenceOptionsBuilder; +struct ReverseSequenceOptionsT; + +struct MatrixDiagOptions; +struct MatrixDiagOptionsBuilder; +struct MatrixDiagOptionsT; + +struct QuantizeOptions; +struct QuantizeOptionsBuilder; +struct QuantizeOptionsT; + +struct MatrixSetDiagOptions; +struct MatrixSetDiagOptionsBuilder; +struct MatrixSetDiagOptionsT; + +struct IfOptions; +struct IfOptionsBuilder; +struct IfOptionsT; + +struct CallOnceOptions; +struct CallOnceOptionsBuilder; +struct CallOnceOptionsT; + +struct WhileOptions; +struct WhileOptionsBuilder; +struct WhileOptionsT; + +struct NonMaxSuppressionV4Options; +struct NonMaxSuppressionV4OptionsBuilder; +struct NonMaxSuppressionV4OptionsT; + +struct NonMaxSuppressionV5Options; +struct NonMaxSuppressionV5OptionsBuilder; +struct NonMaxSuppressionV5OptionsT; + +struct ScatterNdOptions; +struct ScatterNdOptionsBuilder; +struct ScatterNdOptionsT; + +struct SelectV2Options; +struct SelectV2OptionsBuilder; +struct SelectV2OptionsT; + +struct DensifyOptions; +struct DensifyOptionsBuilder; +struct DensifyOptionsT; + +struct SegmentSumOptions; +struct SegmentSumOptionsBuilder; +struct SegmentSumOptionsT; + +struct BatchMatMulOptions; +struct BatchMatMulOptionsBuilder; +struct BatchMatMulOptionsT; + +struct CumsumOptions; +struct CumsumOptionsBuilder; +struct CumsumOptionsT; + +struct BroadcastToOptions; +struct BroadcastToOptionsBuilder; +struct BroadcastToOptionsT; + +struct Rfft2dOptions; +struct Rfft2dOptionsBuilder; +struct Rfft2dOptionsT; + +struct HashtableOptions; +struct HashtableOptionsBuilder; +struct HashtableOptionsT; + +struct HashtableFindOptions; +struct HashtableFindOptionsBuilder; +struct HashtableFindOptionsT; + +struct HashtableImportOptions; +struct HashtableImportOptionsBuilder; +struct HashtableImportOptionsT; + +struct HashtableSizeOptions; +struct HashtableSizeOptionsBuilder; +struct HashtableSizeOptionsT; + +struct VarHandleOptions; +struct VarHandleOptionsBuilder; +struct VarHandleOptionsT; + +struct ReadVariableOptions; +struct ReadVariableOptionsBuilder; +struct ReadVariableOptionsT; + +struct AssignVariableOptions; +struct AssignVariableOptionsBuilder; +struct AssignVariableOptionsT; + +struct RandomOptions; +struct RandomOptionsBuilder; +struct RandomOptionsT; + +struct BucketizeOptions; +struct BucketizeOptionsBuilder; +struct BucketizeOptionsT; + +struct GeluOptions; +struct GeluOptionsBuilder; +struct GeluOptionsT; + +struct DynamicUpdateSliceOptions; +struct DynamicUpdateSliceOptionsBuilder; +struct DynamicUpdateSliceOptionsT; + +struct UnsortedSegmentProdOptions; +struct UnsortedSegmentProdOptionsBuilder; +struct UnsortedSegmentProdOptionsT; + +struct UnsortedSegmentMaxOptions; +struct UnsortedSegmentMaxOptionsBuilder; +struct UnsortedSegmentMaxOptionsT; + +struct UnsortedSegmentSumOptions; +struct UnsortedSegmentSumOptionsBuilder; +struct UnsortedSegmentSumOptionsT; + +struct ATan2Options; +struct ATan2OptionsBuilder; +struct ATan2OptionsT; + +struct UnsortedSegmentMinOptions; +struct UnsortedSegmentMinOptionsBuilder; +struct UnsortedSegmentMinOptionsT; + +struct SignOptions; +struct SignOptionsBuilder; +struct SignOptionsT; + +struct BitcastOptions; +struct BitcastOptionsBuilder; +struct BitcastOptionsT; 
+ +struct BitwiseXorOptions; +struct BitwiseXorOptionsBuilder; +struct BitwiseXorOptionsT; + +struct RightShiftOptions; +struct RightShiftOptionsBuilder; +struct RightShiftOptionsT; + +struct DilateOptions; +struct DilateOptionsBuilder; +struct DilateOptionsT; + +struct ReduceWindowOptions; +struct ReduceWindowOptionsBuilder; +struct ReduceWindowOptionsT; + +struct OperatorCode; +struct OperatorCodeBuilder; +struct OperatorCodeT; + +struct StableHLOCompositeOptions; +struct StableHLOCompositeOptionsBuilder; +struct StableHLOCompositeOptionsT; + +struct StablehloShiftLeftOptions; +struct StablehloShiftLeftOptionsBuilder; +struct StablehloShiftLeftOptionsT; + +struct Operator; +struct OperatorBuilder; +struct OperatorT; + +struct SubGraph; +struct SubGraphBuilder; +struct SubGraphT; + +struct Buffer; +struct BufferBuilder; +struct BufferT; + +struct Metadata; +struct MetadataBuilder; +struct MetadataT; + +struct TensorMap; +struct TensorMapBuilder; +struct TensorMapT; + +struct SignatureDef; +struct SignatureDefBuilder; +struct SignatureDefT; + +struct Model; +struct ModelBuilder; +struct ModelT; + +enum TensorType : int8_t { + TensorType_FLOAT32 = 0, + TensorType_FLOAT16 = 1, + TensorType_INT32 = 2, + TensorType_UINT8 = 3, + TensorType_INT64 = 4, + TensorType_STRING = 5, + TensorType_BOOL = 6, + TensorType_INT16 = 7, + TensorType_COMPLEX64 = 8, + TensorType_INT8 = 9, + TensorType_FLOAT64 = 10, + TensorType_COMPLEX128 = 11, + TensorType_UINT64 = 12, + TensorType_RESOURCE = 13, + TensorType_VARIANT = 14, + TensorType_UINT32 = 15, + TensorType_UINT16 = 16, + TensorType_INT4 = 17, + TensorType_BFLOAT16 = 18, + TensorType_MIN = TensorType_FLOAT32, + TensorType_MAX = TensorType_BFLOAT16 +}; + +inline const TensorType (&EnumValuesTensorType())[19] { + static const TensorType values[] = { + TensorType_FLOAT32, + TensorType_FLOAT16, + TensorType_INT32, + TensorType_UINT8, + TensorType_INT64, + TensorType_STRING, + TensorType_BOOL, + TensorType_INT16, + TensorType_COMPLEX64, + TensorType_INT8, + TensorType_FLOAT64, + TensorType_COMPLEX128, + TensorType_UINT64, + TensorType_RESOURCE, + TensorType_VARIANT, + TensorType_UINT32, + TensorType_UINT16, + TensorType_INT4, + TensorType_BFLOAT16 + }; + return values; +} + +inline const char * const *EnumNamesTensorType() { + static const char * const names[20] = { + "FLOAT32", + "FLOAT16", + "INT32", + "UINT8", + "INT64", + "STRING", + "BOOL", + "INT16", + "COMPLEX64", + "INT8", + "FLOAT64", + "COMPLEX128", + "UINT64", + "RESOURCE", + "VARIANT", + "UINT32", + "UINT16", + "INT4", + "BFLOAT16", + nullptr + }; + return names; +} + +inline const char *EnumNameTensorType(TensorType e) { + if (::flatbuffers::IsOutRange(e, TensorType_FLOAT32, TensorType_BFLOAT16)) return ""; + const size_t index = static_cast(e); + return EnumNamesTensorType()[index]; +} + +enum QuantizationDetails : uint8_t { + QuantizationDetails_NONE = 0, + QuantizationDetails_CustomQuantization = 1, + QuantizationDetails_BlockwiseQuantization = 2, + QuantizationDetails_MIN = QuantizationDetails_NONE, + QuantizationDetails_MAX = QuantizationDetails_BlockwiseQuantization +}; + +inline const QuantizationDetails (&EnumValuesQuantizationDetails())[3] { + static const QuantizationDetails values[] = { + QuantizationDetails_NONE, + QuantizationDetails_CustomQuantization, + QuantizationDetails_BlockwiseQuantization + }; + return values; +} + +inline const char * const *EnumNamesQuantizationDetails() { + static const char * const names[4] = { + "NONE", + "CustomQuantization", + "BlockwiseQuantization", + 
nullptr + }; + return names; +} + +inline const char *EnumNameQuantizationDetails(QuantizationDetails e) { + if (::flatbuffers::IsOutRange(e, QuantizationDetails_NONE, QuantizationDetails_BlockwiseQuantization)) return ""; + const size_t index = static_cast(e); + return EnumNamesQuantizationDetails()[index]; +} + +template struct QuantizationDetailsTraits { + static const QuantizationDetails enum_value = QuantizationDetails_NONE; +}; + +template<> struct QuantizationDetailsTraits { + static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization; +}; + +template<> struct QuantizationDetailsTraits { + static const QuantizationDetails enum_value = QuantizationDetails_BlockwiseQuantization; +}; + +template struct QuantizationDetailsUnionTraits { + static const QuantizationDetails enum_value = QuantizationDetails_NONE; +}; + +template<> struct QuantizationDetailsUnionTraits { + static const QuantizationDetails enum_value = QuantizationDetails_CustomQuantization; +}; + +template<> struct QuantizationDetailsUnionTraits { + static const QuantizationDetails enum_value = QuantizationDetails_BlockwiseQuantization; +}; + +struct QuantizationDetailsUnion { + QuantizationDetails type; + void *value; + + QuantizationDetailsUnion() : type(QuantizationDetails_NONE), value(nullptr) {} + QuantizationDetailsUnion(QuantizationDetailsUnion&& u) FLATBUFFERS_NOEXCEPT : + type(QuantizationDetails_NONE), value(nullptr) + { std::swap(type, u.type); std::swap(value, u.value); } + QuantizationDetailsUnion(const QuantizationDetailsUnion &); + QuantizationDetailsUnion &operator=(const QuantizationDetailsUnion &u) + { QuantizationDetailsUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; } + QuantizationDetailsUnion &operator=(QuantizationDetailsUnion &&u) FLATBUFFERS_NOEXCEPT + { std::swap(type, u.type); std::swap(value, u.value); return *this; } + ~QuantizationDetailsUnion() { Reset(); } + + void Reset(); + + template + void Set(T&& val) { + typedef typename std::remove_reference::type RT; + Reset(); + type = QuantizationDetailsUnionTraits::enum_value; + if (type != QuantizationDetails_NONE) { + value = new RT(std::forward(val)); + } + } + + static void *UnPack(const void *obj, QuantizationDetails type, const ::flatbuffers::resolver_function_t *resolver); + ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr) const; + + tflite::CustomQuantizationT *AsCustomQuantization() { + return type == QuantizationDetails_CustomQuantization ? + reinterpret_cast(value) : nullptr; + } + const tflite::CustomQuantizationT *AsCustomQuantization() const { + return type == QuantizationDetails_CustomQuantization ? + reinterpret_cast(value) : nullptr; + } + tflite::BlockwiseQuantizationT *AsBlockwiseQuantization() { + return type == QuantizationDetails_BlockwiseQuantization ? + reinterpret_cast(value) : nullptr; + } + const tflite::BlockwiseQuantizationT *AsBlockwiseQuantization() const { + return type == QuantizationDetails_BlockwiseQuantization ? 
+ reinterpret_cast(value) : nullptr; + } +}; + +bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type); +bool VerifyQuantizationDetailsVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + +enum DimensionType : int8_t { + DimensionType_DENSE = 0, + DimensionType_SPARSE_CSR = 1, + DimensionType_MIN = DimensionType_DENSE, + DimensionType_MAX = DimensionType_SPARSE_CSR +}; + +inline const DimensionType (&EnumValuesDimensionType())[2] { + static const DimensionType values[] = { + DimensionType_DENSE, + DimensionType_SPARSE_CSR + }; + return values; +} + +inline const char * const *EnumNamesDimensionType() { + static const char * const names[3] = { + "DENSE", + "SPARSE_CSR", + nullptr + }; + return names; +} + +inline const char *EnumNameDimensionType(DimensionType e) { + if (::flatbuffers::IsOutRange(e, DimensionType_DENSE, DimensionType_SPARSE_CSR)) return ""; + const size_t index = static_cast(e); + return EnumNamesDimensionType()[index]; +} + +enum SparseIndexVector : uint8_t { + SparseIndexVector_NONE = 0, + SparseIndexVector_Int32Vector = 1, + SparseIndexVector_Uint16Vector = 2, + SparseIndexVector_Uint8Vector = 3, + SparseIndexVector_MIN = SparseIndexVector_NONE, + SparseIndexVector_MAX = SparseIndexVector_Uint8Vector +}; + +inline const SparseIndexVector (&EnumValuesSparseIndexVector())[4] { + static const SparseIndexVector values[] = { + SparseIndexVector_NONE, + SparseIndexVector_Int32Vector, + SparseIndexVector_Uint16Vector, + SparseIndexVector_Uint8Vector + }; + return values; +} + +inline const char * const *EnumNamesSparseIndexVector() { + static const char * const names[5] = { + "NONE", + "Int32Vector", + "Uint16Vector", + "Uint8Vector", + nullptr + }; + return names; +} + +inline const char *EnumNameSparseIndexVector(SparseIndexVector e) { + if (::flatbuffers::IsOutRange(e, SparseIndexVector_NONE, SparseIndexVector_Uint8Vector)) return ""; + const size_t index = static_cast(e); + return EnumNamesSparseIndexVector()[index]; +} + +template struct SparseIndexVectorTraits { + static const SparseIndexVector enum_value = SparseIndexVector_NONE; +}; + +template<> struct SparseIndexVectorTraits { + static const SparseIndexVector enum_value = SparseIndexVector_Int32Vector; +}; + +template<> struct SparseIndexVectorTraits { + static const SparseIndexVector enum_value = SparseIndexVector_Uint16Vector; +}; + +template<> struct SparseIndexVectorTraits { + static const SparseIndexVector enum_value = SparseIndexVector_Uint8Vector; +}; + +template struct SparseIndexVectorUnionTraits { + static const SparseIndexVector enum_value = SparseIndexVector_NONE; +}; + +template<> struct SparseIndexVectorUnionTraits { + static const SparseIndexVector enum_value = SparseIndexVector_Int32Vector; +}; + +template<> struct SparseIndexVectorUnionTraits { + static const SparseIndexVector enum_value = SparseIndexVector_Uint16Vector; +}; + +template<> struct SparseIndexVectorUnionTraits { + static const SparseIndexVector enum_value = SparseIndexVector_Uint8Vector; +}; + +struct SparseIndexVectorUnion { + SparseIndexVector type; + void *value; + + SparseIndexVectorUnion() : type(SparseIndexVector_NONE), value(nullptr) {} + SparseIndexVectorUnion(SparseIndexVectorUnion&& u) FLATBUFFERS_NOEXCEPT : + type(SparseIndexVector_NONE), value(nullptr) + { std::swap(type, u.type); std::swap(value, u.value); } + SparseIndexVectorUnion(const SparseIndexVectorUnion &); + 
SparseIndexVectorUnion &operator=(const SparseIndexVectorUnion &u) + { SparseIndexVectorUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; } + SparseIndexVectorUnion &operator=(SparseIndexVectorUnion &&u) FLATBUFFERS_NOEXCEPT + { std::swap(type, u.type); std::swap(value, u.value); return *this; } + ~SparseIndexVectorUnion() { Reset(); } + + void Reset(); + + template + void Set(T&& val) { + typedef typename std::remove_reference::type RT; + Reset(); + type = SparseIndexVectorUnionTraits::enum_value; + if (type != SparseIndexVector_NONE) { + value = new RT(std::forward(val)); + } + } + + static void *UnPack(const void *obj, SparseIndexVector type, const ::flatbuffers::resolver_function_t *resolver); + ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr) const; + + tflite::Int32VectorT *AsInt32Vector() { + return type == SparseIndexVector_Int32Vector ? + reinterpret_cast(value) : nullptr; + } + const tflite::Int32VectorT *AsInt32Vector() const { + return type == SparseIndexVector_Int32Vector ? + reinterpret_cast(value) : nullptr; + } + tflite::Uint16VectorT *AsUint16Vector() { + return type == SparseIndexVector_Uint16Vector ? + reinterpret_cast(value) : nullptr; + } + const tflite::Uint16VectorT *AsUint16Vector() const { + return type == SparseIndexVector_Uint16Vector ? + reinterpret_cast(value) : nullptr; + } + tflite::Uint8VectorT *AsUint8Vector() { + return type == SparseIndexVector_Uint8Vector ? + reinterpret_cast(value) : nullptr; + } + const tflite::Uint8VectorT *AsUint8Vector() const { + return type == SparseIndexVector_Uint8Vector ? + reinterpret_cast(value) : nullptr; + } +}; + +bool VerifySparseIndexVector(::flatbuffers::Verifier &verifier, const void *obj, SparseIndexVector type); +bool VerifySparseIndexVectorVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + +enum BuiltinOperator : int32_t { + BuiltinOperator_ADD = 0, + BuiltinOperator_AVERAGE_POOL_2D = 1, + BuiltinOperator_CONCATENATION = 2, + BuiltinOperator_CONV_2D = 3, + BuiltinOperator_DEPTHWISE_CONV_2D = 4, + BuiltinOperator_DEPTH_TO_SPACE = 5, + BuiltinOperator_DEQUANTIZE = 6, + BuiltinOperator_EMBEDDING_LOOKUP = 7, + BuiltinOperator_FLOOR = 8, + BuiltinOperator_FULLY_CONNECTED = 9, + BuiltinOperator_HASHTABLE_LOOKUP = 10, + BuiltinOperator_L2_NORMALIZATION = 11, + BuiltinOperator_L2_POOL_2D = 12, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION = 13, + BuiltinOperator_LOGISTIC = 14, + BuiltinOperator_LSH_PROJECTION = 15, + BuiltinOperator_LSTM = 16, + BuiltinOperator_MAX_POOL_2D = 17, + BuiltinOperator_MUL = 18, + BuiltinOperator_RELU = 19, + BuiltinOperator_RELU_N1_TO_1 = 20, + BuiltinOperator_RELU6 = 21, + BuiltinOperator_RESHAPE = 22, + BuiltinOperator_RESIZE_BILINEAR = 23, + BuiltinOperator_RNN = 24, + BuiltinOperator_SOFTMAX = 25, + BuiltinOperator_SPACE_TO_DEPTH = 26, + BuiltinOperator_SVDF = 27, + BuiltinOperator_TANH = 28, + BuiltinOperator_CONCAT_EMBEDDINGS = 29, + BuiltinOperator_SKIP_GRAM = 30, + BuiltinOperator_CALL = 31, + BuiltinOperator_CUSTOM = 32, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE = 33, + BuiltinOperator_PAD = 34, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35, + BuiltinOperator_GATHER = 36, + BuiltinOperator_BATCH_TO_SPACE_ND = 37, + BuiltinOperator_SPACE_TO_BATCH_ND = 38, + BuiltinOperator_TRANSPOSE = 39, + BuiltinOperator_MEAN = 40, + BuiltinOperator_SUB = 41, + BuiltinOperator_DIV = 42, + 
BuiltinOperator_SQUEEZE = 43, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44, + BuiltinOperator_STRIDED_SLICE = 45, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46, + BuiltinOperator_EXP = 47, + BuiltinOperator_TOPK_V2 = 48, + BuiltinOperator_SPLIT = 49, + BuiltinOperator_LOG_SOFTMAX = 50, + BuiltinOperator_DELEGATE = 51, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52, + BuiltinOperator_CAST = 53, + BuiltinOperator_PRELU = 54, + BuiltinOperator_MAXIMUM = 55, + BuiltinOperator_ARG_MAX = 56, + BuiltinOperator_MINIMUM = 57, + BuiltinOperator_LESS = 58, + BuiltinOperator_NEG = 59, + BuiltinOperator_PADV2 = 60, + BuiltinOperator_GREATER = 61, + BuiltinOperator_GREATER_EQUAL = 62, + BuiltinOperator_LESS_EQUAL = 63, + BuiltinOperator_SELECT = 64, + BuiltinOperator_SLICE = 65, + BuiltinOperator_SIN = 66, + BuiltinOperator_TRANSPOSE_CONV = 67, + BuiltinOperator_SPARSE_TO_DENSE = 68, + BuiltinOperator_TILE = 69, + BuiltinOperator_EXPAND_DIMS = 70, + BuiltinOperator_EQUAL = 71, + BuiltinOperator_NOT_EQUAL = 72, + BuiltinOperator_LOG = 73, + BuiltinOperator_SUM = 74, + BuiltinOperator_SQRT = 75, + BuiltinOperator_RSQRT = 76, + BuiltinOperator_SHAPE = 77, + BuiltinOperator_POW = 78, + BuiltinOperator_ARG_MIN = 79, + BuiltinOperator_FAKE_QUANT = 80, + BuiltinOperator_REDUCE_PROD = 81, + BuiltinOperator_REDUCE_MAX = 82, + BuiltinOperator_PACK = 83, + BuiltinOperator_LOGICAL_OR = 84, + BuiltinOperator_ONE_HOT = 85, + BuiltinOperator_LOGICAL_AND = 86, + BuiltinOperator_LOGICAL_NOT = 87, + BuiltinOperator_UNPACK = 88, + BuiltinOperator_REDUCE_MIN = 89, + BuiltinOperator_FLOOR_DIV = 90, + BuiltinOperator_REDUCE_ANY = 91, + BuiltinOperator_SQUARE = 92, + BuiltinOperator_ZEROS_LIKE = 93, + BuiltinOperator_FILL = 94, + BuiltinOperator_FLOOR_MOD = 95, + BuiltinOperator_RANGE = 96, + BuiltinOperator_RESIZE_NEAREST_NEIGHBOR = 97, + BuiltinOperator_LEAKY_RELU = 98, + BuiltinOperator_SQUARED_DIFFERENCE = 99, + BuiltinOperator_MIRROR_PAD = 100, + BuiltinOperator_ABS = 101, + BuiltinOperator_SPLIT_V = 102, + BuiltinOperator_UNIQUE = 103, + BuiltinOperator_CEIL = 104, + BuiltinOperator_REVERSE_V2 = 105, + BuiltinOperator_ADD_N = 106, + BuiltinOperator_GATHER_ND = 107, + BuiltinOperator_COS = 108, + BuiltinOperator_WHERE = 109, + BuiltinOperator_RANK = 110, + BuiltinOperator_ELU = 111, + BuiltinOperator_REVERSE_SEQUENCE = 112, + BuiltinOperator_MATRIX_DIAG = 113, + BuiltinOperator_QUANTIZE = 114, + BuiltinOperator_MATRIX_SET_DIAG = 115, + BuiltinOperator_ROUND = 116, + BuiltinOperator_HARD_SWISH = 117, + BuiltinOperator_IF = 118, + BuiltinOperator_WHILE = 119, + BuiltinOperator_NON_MAX_SUPPRESSION_V4 = 120, + BuiltinOperator_NON_MAX_SUPPRESSION_V5 = 121, + BuiltinOperator_SCATTER_ND = 122, + BuiltinOperator_SELECT_V2 = 123, + BuiltinOperator_DENSIFY = 124, + BuiltinOperator_SEGMENT_SUM = 125, + BuiltinOperator_BATCH_MATMUL = 126, + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES = 127, + BuiltinOperator_CUMSUM = 128, + BuiltinOperator_CALL_ONCE = 129, + BuiltinOperator_BROADCAST_TO = 130, + BuiltinOperator_RFFT2D = 131, + BuiltinOperator_CONV_3D = 132, + BuiltinOperator_IMAG = 133, + BuiltinOperator_REAL = 134, + BuiltinOperator_COMPLEX_ABS = 135, + BuiltinOperator_HASHTABLE = 136, + BuiltinOperator_HASHTABLE_FIND = 137, + BuiltinOperator_HASHTABLE_IMPORT = 138, + BuiltinOperator_HASHTABLE_SIZE = 139, + BuiltinOperator_REDUCE_ALL = 140, + BuiltinOperator_CONV_3D_TRANSPOSE = 141, + BuiltinOperator_VAR_HANDLE = 142, + BuiltinOperator_READ_VARIABLE = 143, + BuiltinOperator_ASSIGN_VARIABLE = 144, + 
BuiltinOperator_BROADCAST_ARGS = 145, + BuiltinOperator_RANDOM_STANDARD_NORMAL = 146, + BuiltinOperator_BUCKETIZE = 147, + BuiltinOperator_RANDOM_UNIFORM = 148, + BuiltinOperator_MULTINOMIAL = 149, + BuiltinOperator_GELU = 150, + BuiltinOperator_DYNAMIC_UPDATE_SLICE = 151, + BuiltinOperator_RELU_0_TO_1 = 152, + BuiltinOperator_UNSORTED_SEGMENT_PROD = 153, + BuiltinOperator_UNSORTED_SEGMENT_MAX = 154, + BuiltinOperator_UNSORTED_SEGMENT_SUM = 155, + BuiltinOperator_ATAN2 = 156, + BuiltinOperator_UNSORTED_SEGMENT_MIN = 157, + BuiltinOperator_SIGN = 158, + BuiltinOperator_BITCAST = 159, + BuiltinOperator_BITWISE_XOR = 160, + BuiltinOperator_RIGHT_SHIFT = 161, + BuiltinOperator_STABLEHLO_LOGISTIC = 162, + BuiltinOperator_STABLEHLO_ADD = 163, + BuiltinOperator_STABLEHLO_DIVIDE = 164, + BuiltinOperator_STABLEHLO_MULTIPLY = 165, + BuiltinOperator_STABLEHLO_MAXIMUM = 166, + BuiltinOperator_STABLEHLO_RESHAPE = 167, + BuiltinOperator_STABLEHLO_CLAMP = 168, + BuiltinOperator_STABLEHLO_CONCATENATE = 169, + BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM = 170, + BuiltinOperator_STABLEHLO_CONVOLUTION = 171, + BuiltinOperator_STABLEHLO_SLICE = 172, + BuiltinOperator_STABLEHLO_CUSTOM_CALL = 173, + BuiltinOperator_STABLEHLO_REDUCE = 174, + BuiltinOperator_STABLEHLO_ABS = 175, + BuiltinOperator_STABLEHLO_AND = 176, + BuiltinOperator_STABLEHLO_COSINE = 177, + BuiltinOperator_STABLEHLO_EXPONENTIAL = 178, + BuiltinOperator_STABLEHLO_FLOOR = 179, + BuiltinOperator_STABLEHLO_LOG = 180, + BuiltinOperator_STABLEHLO_MINIMUM = 181, + BuiltinOperator_STABLEHLO_NEGATE = 182, + BuiltinOperator_STABLEHLO_OR = 183, + BuiltinOperator_STABLEHLO_POWER = 184, + BuiltinOperator_STABLEHLO_REMAINDER = 185, + BuiltinOperator_STABLEHLO_RSQRT = 186, + BuiltinOperator_STABLEHLO_SELECT = 187, + BuiltinOperator_STABLEHLO_SUBTRACT = 188, + BuiltinOperator_STABLEHLO_TANH = 189, + BuiltinOperator_STABLEHLO_SCATTER = 190, + BuiltinOperator_STABLEHLO_COMPARE = 191, + BuiltinOperator_STABLEHLO_CONVERT = 192, + BuiltinOperator_STABLEHLO_DYNAMIC_SLICE = 193, + BuiltinOperator_STABLEHLO_DYNAMIC_UPDATE_SLICE = 194, + BuiltinOperator_STABLEHLO_PAD = 195, + BuiltinOperator_STABLEHLO_IOTA = 196, + BuiltinOperator_STABLEHLO_DOT_GENERAL = 197, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW = 198, + BuiltinOperator_STABLEHLO_SORT = 199, + BuiltinOperator_STABLEHLO_WHILE = 200, + BuiltinOperator_STABLEHLO_GATHER = 201, + BuiltinOperator_STABLEHLO_TRANSPOSE = 202, + BuiltinOperator_DILATE = 203, + BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR = 204, + BuiltinOperator_REDUCE_WINDOW = 205, + BuiltinOperator_STABLEHLO_COMPOSITE = 206, + BuiltinOperator_STABLEHLO_SHIFT_LEFT = 207, + BuiltinOperator_STABLEHLO_CBRT = 208, + BuiltinOperator_STABLEHLO_CASE = 209, + BuiltinOperator_MIN = BuiltinOperator_ADD, + BuiltinOperator_MAX = BuiltinOperator_STABLEHLO_CASE +}; + +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[210] { + static const BuiltinOperator values[] = { + BuiltinOperator_ADD, + BuiltinOperator_AVERAGE_POOL_2D, + BuiltinOperator_CONCATENATION, + BuiltinOperator_CONV_2D, + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_DEPTH_TO_SPACE, + BuiltinOperator_DEQUANTIZE, + BuiltinOperator_EMBEDDING_LOOKUP, + BuiltinOperator_FLOOR, + BuiltinOperator_FULLY_CONNECTED, + BuiltinOperator_HASHTABLE_LOOKUP, + BuiltinOperator_L2_NORMALIZATION, + BuiltinOperator_L2_POOL_2D, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOperator_LOGISTIC, + BuiltinOperator_LSH_PROJECTION, + BuiltinOperator_LSTM, + BuiltinOperator_MAX_POOL_2D, + BuiltinOperator_MUL, 
+ BuiltinOperator_RELU, + BuiltinOperator_RELU_N1_TO_1, + BuiltinOperator_RELU6, + BuiltinOperator_RESHAPE, + BuiltinOperator_RESIZE_BILINEAR, + BuiltinOperator_RNN, + BuiltinOperator_SOFTMAX, + BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOperator_SVDF, + BuiltinOperator_TANH, + BuiltinOperator_CONCAT_EMBEDDINGS, + BuiltinOperator_SKIP_GRAM, + BuiltinOperator_CALL, + BuiltinOperator_CUSTOM, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + BuiltinOperator_PAD, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_GATHER, + BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOperator_TRANSPOSE, + BuiltinOperator_MEAN, + BuiltinOperator_SUB, + BuiltinOperator_DIV, + BuiltinOperator_SQUEEZE, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_STRIDED_SLICE, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_EXP, + BuiltinOperator_TOPK_V2, + BuiltinOperator_SPLIT, + BuiltinOperator_LOG_SOFTMAX, + BuiltinOperator_DELEGATE, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_CAST, + BuiltinOperator_PRELU, + BuiltinOperator_MAXIMUM, + BuiltinOperator_ARG_MAX, + BuiltinOperator_MINIMUM, + BuiltinOperator_LESS, + BuiltinOperator_NEG, + BuiltinOperator_PADV2, + BuiltinOperator_GREATER, + BuiltinOperator_GREATER_EQUAL, + BuiltinOperator_LESS_EQUAL, + BuiltinOperator_SELECT, + BuiltinOperator_SLICE, + BuiltinOperator_SIN, + BuiltinOperator_TRANSPOSE_CONV, + BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOperator_TILE, + BuiltinOperator_EXPAND_DIMS, + BuiltinOperator_EQUAL, + BuiltinOperator_NOT_EQUAL, + BuiltinOperator_LOG, + BuiltinOperator_SUM, + BuiltinOperator_SQRT, + BuiltinOperator_RSQRT, + BuiltinOperator_SHAPE, + BuiltinOperator_POW, + BuiltinOperator_ARG_MIN, + BuiltinOperator_FAKE_QUANT, + BuiltinOperator_REDUCE_PROD, + BuiltinOperator_REDUCE_MAX, + BuiltinOperator_PACK, + BuiltinOperator_LOGICAL_OR, + BuiltinOperator_ONE_HOT, + BuiltinOperator_LOGICAL_AND, + BuiltinOperator_LOGICAL_NOT, + BuiltinOperator_UNPACK, + BuiltinOperator_REDUCE_MIN, + BuiltinOperator_FLOOR_DIV, + BuiltinOperator_REDUCE_ANY, + BuiltinOperator_SQUARE, + BuiltinOperator_ZEROS_LIKE, + BuiltinOperator_FILL, + BuiltinOperator_FLOOR_MOD, + BuiltinOperator_RANGE, + BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, + BuiltinOperator_LEAKY_RELU, + BuiltinOperator_SQUARED_DIFFERENCE, + BuiltinOperator_MIRROR_PAD, + BuiltinOperator_ABS, + BuiltinOperator_SPLIT_V, + BuiltinOperator_UNIQUE, + BuiltinOperator_CEIL, + BuiltinOperator_REVERSE_V2, + BuiltinOperator_ADD_N, + BuiltinOperator_GATHER_ND, + BuiltinOperator_COS, + BuiltinOperator_WHERE, + BuiltinOperator_RANK, + BuiltinOperator_ELU, + BuiltinOperator_REVERSE_SEQUENCE, + BuiltinOperator_MATRIX_DIAG, + BuiltinOperator_QUANTIZE, + BuiltinOperator_MATRIX_SET_DIAG, + BuiltinOperator_ROUND, + BuiltinOperator_HARD_SWISH, + BuiltinOperator_IF, + BuiltinOperator_WHILE, + BuiltinOperator_NON_MAX_SUPPRESSION_V4, + BuiltinOperator_NON_MAX_SUPPRESSION_V5, + BuiltinOperator_SCATTER_ND, + BuiltinOperator_SELECT_V2, + BuiltinOperator_DENSIFY, + BuiltinOperator_SEGMENT_SUM, + BuiltinOperator_BATCH_MATMUL, + BuiltinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES, + BuiltinOperator_CUMSUM, + BuiltinOperator_CALL_ONCE, + BuiltinOperator_BROADCAST_TO, + BuiltinOperator_RFFT2D, + BuiltinOperator_CONV_3D, + BuiltinOperator_IMAG, + BuiltinOperator_REAL, + BuiltinOperator_COMPLEX_ABS, + BuiltinOperator_HASHTABLE, + BuiltinOperator_HASHTABLE_FIND, + BuiltinOperator_HASHTABLE_IMPORT, + BuiltinOperator_HASHTABLE_SIZE, + BuiltinOperator_REDUCE_ALL, 
+ BuiltinOperator_CONV_3D_TRANSPOSE, + BuiltinOperator_VAR_HANDLE, + BuiltinOperator_READ_VARIABLE, + BuiltinOperator_ASSIGN_VARIABLE, + BuiltinOperator_BROADCAST_ARGS, + BuiltinOperator_RANDOM_STANDARD_NORMAL, + BuiltinOperator_BUCKETIZE, + BuiltinOperator_RANDOM_UNIFORM, + BuiltinOperator_MULTINOMIAL, + BuiltinOperator_GELU, + BuiltinOperator_DYNAMIC_UPDATE_SLICE, + BuiltinOperator_RELU_0_TO_1, + BuiltinOperator_UNSORTED_SEGMENT_PROD, + BuiltinOperator_UNSORTED_SEGMENT_MAX, + BuiltinOperator_UNSORTED_SEGMENT_SUM, + BuiltinOperator_ATAN2, + BuiltinOperator_UNSORTED_SEGMENT_MIN, + BuiltinOperator_SIGN, + BuiltinOperator_BITCAST, + BuiltinOperator_BITWISE_XOR, + BuiltinOperator_RIGHT_SHIFT, + BuiltinOperator_STABLEHLO_LOGISTIC, + BuiltinOperator_STABLEHLO_ADD, + BuiltinOperator_STABLEHLO_DIVIDE, + BuiltinOperator_STABLEHLO_MULTIPLY, + BuiltinOperator_STABLEHLO_MAXIMUM, + BuiltinOperator_STABLEHLO_RESHAPE, + BuiltinOperator_STABLEHLO_CLAMP, + BuiltinOperator_STABLEHLO_CONCATENATE, + BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM, + BuiltinOperator_STABLEHLO_CONVOLUTION, + BuiltinOperator_STABLEHLO_SLICE, + BuiltinOperator_STABLEHLO_CUSTOM_CALL, + BuiltinOperator_STABLEHLO_REDUCE, + BuiltinOperator_STABLEHLO_ABS, + BuiltinOperator_STABLEHLO_AND, + BuiltinOperator_STABLEHLO_COSINE, + BuiltinOperator_STABLEHLO_EXPONENTIAL, + BuiltinOperator_STABLEHLO_FLOOR, + BuiltinOperator_STABLEHLO_LOG, + BuiltinOperator_STABLEHLO_MINIMUM, + BuiltinOperator_STABLEHLO_NEGATE, + BuiltinOperator_STABLEHLO_OR, + BuiltinOperator_STABLEHLO_POWER, + BuiltinOperator_STABLEHLO_REMAINDER, + BuiltinOperator_STABLEHLO_RSQRT, + BuiltinOperator_STABLEHLO_SELECT, + BuiltinOperator_STABLEHLO_SUBTRACT, + BuiltinOperator_STABLEHLO_TANH, + BuiltinOperator_STABLEHLO_SCATTER, + BuiltinOperator_STABLEHLO_COMPARE, + BuiltinOperator_STABLEHLO_CONVERT, + BuiltinOperator_STABLEHLO_DYNAMIC_SLICE, + BuiltinOperator_STABLEHLO_DYNAMIC_UPDATE_SLICE, + BuiltinOperator_STABLEHLO_PAD, + BuiltinOperator_STABLEHLO_IOTA, + BuiltinOperator_STABLEHLO_DOT_GENERAL, + BuiltinOperator_STABLEHLO_REDUCE_WINDOW, + BuiltinOperator_STABLEHLO_SORT, + BuiltinOperator_STABLEHLO_WHILE, + BuiltinOperator_STABLEHLO_GATHER, + BuiltinOperator_STABLEHLO_TRANSPOSE, + BuiltinOperator_DILATE, + BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR, + BuiltinOperator_REDUCE_WINDOW, + BuiltinOperator_STABLEHLO_COMPOSITE, + BuiltinOperator_STABLEHLO_SHIFT_LEFT, + BuiltinOperator_STABLEHLO_CBRT, + BuiltinOperator_STABLEHLO_CASE + }; + return values; +} + +inline const char * const *EnumNamesBuiltinOperator() { + static const char * const names[211] = { + "ADD", + "AVERAGE_POOL_2D", + "CONCATENATION", + "CONV_2D", + "DEPTHWISE_CONV_2D", + "DEPTH_TO_SPACE", + "DEQUANTIZE", + "EMBEDDING_LOOKUP", + "FLOOR", + "FULLY_CONNECTED", + "HASHTABLE_LOOKUP", + "L2_NORMALIZATION", + "L2_POOL_2D", + "LOCAL_RESPONSE_NORMALIZATION", + "LOGISTIC", + "LSH_PROJECTION", + "LSTM", + "MAX_POOL_2D", + "MUL", + "RELU", + "RELU_N1_TO_1", + "RELU6", + "RESHAPE", + "RESIZE_BILINEAR", + "RNN", + "SOFTMAX", + "SPACE_TO_DEPTH", + "SVDF", + "TANH", + "CONCAT_EMBEDDINGS", + "SKIP_GRAM", + "CALL", + "CUSTOM", + "EMBEDDING_LOOKUP_SPARSE", + "PAD", + "UNIDIRECTIONAL_SEQUENCE_RNN", + "GATHER", + "BATCH_TO_SPACE_ND", + "SPACE_TO_BATCH_ND", + "TRANSPOSE", + "MEAN", + "SUB", + "DIV", + "SQUEEZE", + "UNIDIRECTIONAL_SEQUENCE_LSTM", + "STRIDED_SLICE", + "BIDIRECTIONAL_SEQUENCE_RNN", + "EXP", + "TOPK_V2", + "SPLIT", + "LOG_SOFTMAX", + "DELEGATE", + "BIDIRECTIONAL_SEQUENCE_LSTM", + "CAST", + "PRELU", + "MAXIMUM", + "ARG_MAX", + 
"MINIMUM", + "LESS", + "NEG", + "PADV2", + "GREATER", + "GREATER_EQUAL", + "LESS_EQUAL", + "SELECT", + "SLICE", + "SIN", + "TRANSPOSE_CONV", + "SPARSE_TO_DENSE", + "TILE", + "EXPAND_DIMS", + "EQUAL", + "NOT_EQUAL", + "LOG", + "SUM", + "SQRT", + "RSQRT", + "SHAPE", + "POW", + "ARG_MIN", + "FAKE_QUANT", + "REDUCE_PROD", + "REDUCE_MAX", + "PACK", + "LOGICAL_OR", + "ONE_HOT", + "LOGICAL_AND", + "LOGICAL_NOT", + "UNPACK", + "REDUCE_MIN", + "FLOOR_DIV", + "REDUCE_ANY", + "SQUARE", + "ZEROS_LIKE", + "FILL", + "FLOOR_MOD", + "RANGE", + "RESIZE_NEAREST_NEIGHBOR", + "LEAKY_RELU", + "SQUARED_DIFFERENCE", + "MIRROR_PAD", + "ABS", + "SPLIT_V", + "UNIQUE", + "CEIL", + "REVERSE_V2", + "ADD_N", + "GATHER_ND", + "COS", + "WHERE", + "RANK", + "ELU", + "REVERSE_SEQUENCE", + "MATRIX_DIAG", + "QUANTIZE", + "MATRIX_SET_DIAG", + "ROUND", + "HARD_SWISH", + "IF", + "WHILE", + "NON_MAX_SUPPRESSION_V4", + "NON_MAX_SUPPRESSION_V5", + "SCATTER_ND", + "SELECT_V2", + "DENSIFY", + "SEGMENT_SUM", + "BATCH_MATMUL", + "PLACEHOLDER_FOR_GREATER_OP_CODES", + "CUMSUM", + "CALL_ONCE", + "BROADCAST_TO", + "RFFT2D", + "CONV_3D", + "IMAG", + "REAL", + "COMPLEX_ABS", + "HASHTABLE", + "HASHTABLE_FIND", + "HASHTABLE_IMPORT", + "HASHTABLE_SIZE", + "REDUCE_ALL", + "CONV_3D_TRANSPOSE", + "VAR_HANDLE", + "READ_VARIABLE", + "ASSIGN_VARIABLE", + "BROADCAST_ARGS", + "RANDOM_STANDARD_NORMAL", + "BUCKETIZE", + "RANDOM_UNIFORM", + "MULTINOMIAL", + "GELU", + "DYNAMIC_UPDATE_SLICE", + "RELU_0_TO_1", + "UNSORTED_SEGMENT_PROD", + "UNSORTED_SEGMENT_MAX", + "UNSORTED_SEGMENT_SUM", + "ATAN2", + "UNSORTED_SEGMENT_MIN", + "SIGN", + "BITCAST", + "BITWISE_XOR", + "RIGHT_SHIFT", + "STABLEHLO_LOGISTIC", + "STABLEHLO_ADD", + "STABLEHLO_DIVIDE", + "STABLEHLO_MULTIPLY", + "STABLEHLO_MAXIMUM", + "STABLEHLO_RESHAPE", + "STABLEHLO_CLAMP", + "STABLEHLO_CONCATENATE", + "STABLEHLO_BROADCAST_IN_DIM", + "STABLEHLO_CONVOLUTION", + "STABLEHLO_SLICE", + "STABLEHLO_CUSTOM_CALL", + "STABLEHLO_REDUCE", + "STABLEHLO_ABS", + "STABLEHLO_AND", + "STABLEHLO_COSINE", + "STABLEHLO_EXPONENTIAL", + "STABLEHLO_FLOOR", + "STABLEHLO_LOG", + "STABLEHLO_MINIMUM", + "STABLEHLO_NEGATE", + "STABLEHLO_OR", + "STABLEHLO_POWER", + "STABLEHLO_REMAINDER", + "STABLEHLO_RSQRT", + "STABLEHLO_SELECT", + "STABLEHLO_SUBTRACT", + "STABLEHLO_TANH", + "STABLEHLO_SCATTER", + "STABLEHLO_COMPARE", + "STABLEHLO_CONVERT", + "STABLEHLO_DYNAMIC_SLICE", + "STABLEHLO_DYNAMIC_UPDATE_SLICE", + "STABLEHLO_PAD", + "STABLEHLO_IOTA", + "STABLEHLO_DOT_GENERAL", + "STABLEHLO_REDUCE_WINDOW", + "STABLEHLO_SORT", + "STABLEHLO_WHILE", + "STABLEHLO_GATHER", + "STABLEHLO_TRANSPOSE", + "DILATE", + "STABLEHLO_RNG_BIT_GENERATOR", + "REDUCE_WINDOW", + "STABLEHLO_COMPOSITE", + "STABLEHLO_SHIFT_LEFT", + "STABLEHLO_CBRT", + "STABLEHLO_CASE", + nullptr + }; + return names; +} + +inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { + if (::flatbuffers::IsOutRange(e, BuiltinOperator_ADD, BuiltinOperator_STABLEHLO_CASE)) return ""; + const size_t index = static_cast(e); + return EnumNamesBuiltinOperator()[index]; +} + +enum BuiltinOptions : uint8_t { + BuiltinOptions_NONE = 0, + BuiltinOptions_Conv2DOptions = 1, + BuiltinOptions_DepthwiseConv2DOptions = 2, + BuiltinOptions_ConcatEmbeddingsOptions = 3, + BuiltinOptions_LSHProjectionOptions = 4, + BuiltinOptions_Pool2DOptions = 5, + BuiltinOptions_SVDFOptions = 6, + BuiltinOptions_RNNOptions = 7, + BuiltinOptions_FullyConnectedOptions = 8, + BuiltinOptions_SoftmaxOptions = 9, + BuiltinOptions_ConcatenationOptions = 10, + BuiltinOptions_AddOptions = 11, + 
BuiltinOptions_L2NormOptions = 12, + BuiltinOptions_LocalResponseNormalizationOptions = 13, + BuiltinOptions_LSTMOptions = 14, + BuiltinOptions_ResizeBilinearOptions = 15, + BuiltinOptions_CallOptions = 16, + BuiltinOptions_ReshapeOptions = 17, + BuiltinOptions_SkipGramOptions = 18, + BuiltinOptions_SpaceToDepthOptions = 19, + BuiltinOptions_EmbeddingLookupSparseOptions = 20, + BuiltinOptions_MulOptions = 21, + BuiltinOptions_PadOptions = 22, + BuiltinOptions_GatherOptions = 23, + BuiltinOptions_BatchToSpaceNDOptions = 24, + BuiltinOptions_SpaceToBatchNDOptions = 25, + BuiltinOptions_TransposeOptions = 26, + BuiltinOptions_ReducerOptions = 27, + BuiltinOptions_SubOptions = 28, + BuiltinOptions_DivOptions = 29, + BuiltinOptions_SqueezeOptions = 30, + BuiltinOptions_SequenceRNNOptions = 31, + BuiltinOptions_StridedSliceOptions = 32, + BuiltinOptions_ExpOptions = 33, + BuiltinOptions_TopKV2Options = 34, + BuiltinOptions_SplitOptions = 35, + BuiltinOptions_LogSoftmaxOptions = 36, + BuiltinOptions_CastOptions = 37, + BuiltinOptions_DequantizeOptions = 38, + BuiltinOptions_MaximumMinimumOptions = 39, + BuiltinOptions_ArgMaxOptions = 40, + BuiltinOptions_LessOptions = 41, + BuiltinOptions_NegOptions = 42, + BuiltinOptions_PadV2Options = 43, + BuiltinOptions_GreaterOptions = 44, + BuiltinOptions_GreaterEqualOptions = 45, + BuiltinOptions_LessEqualOptions = 46, + BuiltinOptions_SelectOptions = 47, + BuiltinOptions_SliceOptions = 48, + BuiltinOptions_TransposeConvOptions = 49, + BuiltinOptions_SparseToDenseOptions = 50, + BuiltinOptions_TileOptions = 51, + BuiltinOptions_ExpandDimsOptions = 52, + BuiltinOptions_EqualOptions = 53, + BuiltinOptions_NotEqualOptions = 54, + BuiltinOptions_ShapeOptions = 55, + BuiltinOptions_PowOptions = 56, + BuiltinOptions_ArgMinOptions = 57, + BuiltinOptions_FakeQuantOptions = 58, + BuiltinOptions_PackOptions = 59, + BuiltinOptions_LogicalOrOptions = 60, + BuiltinOptions_OneHotOptions = 61, + BuiltinOptions_LogicalAndOptions = 62, + BuiltinOptions_LogicalNotOptions = 63, + BuiltinOptions_UnpackOptions = 64, + BuiltinOptions_FloorDivOptions = 65, + BuiltinOptions_SquareOptions = 66, + BuiltinOptions_ZerosLikeOptions = 67, + BuiltinOptions_FillOptions = 68, + BuiltinOptions_BidirectionalSequenceLSTMOptions = 69, + BuiltinOptions_BidirectionalSequenceRNNOptions = 70, + BuiltinOptions_UnidirectionalSequenceLSTMOptions = 71, + BuiltinOptions_FloorModOptions = 72, + BuiltinOptions_RangeOptions = 73, + BuiltinOptions_ResizeNearestNeighborOptions = 74, + BuiltinOptions_LeakyReluOptions = 75, + BuiltinOptions_SquaredDifferenceOptions = 76, + BuiltinOptions_MirrorPadOptions = 77, + BuiltinOptions_AbsOptions = 78, + BuiltinOptions_SplitVOptions = 79, + BuiltinOptions_UniqueOptions = 80, + BuiltinOptions_ReverseV2Options = 81, + BuiltinOptions_AddNOptions = 82, + BuiltinOptions_GatherNdOptions = 83, + BuiltinOptions_CosOptions = 84, + BuiltinOptions_WhereOptions = 85, + BuiltinOptions_RankOptions = 86, + BuiltinOptions_ReverseSequenceOptions = 87, + BuiltinOptions_MatrixDiagOptions = 88, + BuiltinOptions_QuantizeOptions = 89, + BuiltinOptions_MatrixSetDiagOptions = 90, + BuiltinOptions_HardSwishOptions = 91, + BuiltinOptions_IfOptions = 92, + BuiltinOptions_WhileOptions = 93, + BuiltinOptions_DepthToSpaceOptions = 94, + BuiltinOptions_NonMaxSuppressionV4Options = 95, + BuiltinOptions_NonMaxSuppressionV5Options = 96, + BuiltinOptions_ScatterNdOptions = 97, + BuiltinOptions_SelectV2Options = 98, + BuiltinOptions_DensifyOptions = 99, + BuiltinOptions_SegmentSumOptions = 100, + 
BuiltinOptions_BatchMatMulOptions = 101, + BuiltinOptions_CumsumOptions = 102, + BuiltinOptions_CallOnceOptions = 103, + BuiltinOptions_BroadcastToOptions = 104, + BuiltinOptions_Rfft2dOptions = 105, + BuiltinOptions_Conv3DOptions = 106, + BuiltinOptions_HashtableOptions = 107, + BuiltinOptions_HashtableFindOptions = 108, + BuiltinOptions_HashtableImportOptions = 109, + BuiltinOptions_HashtableSizeOptions = 110, + BuiltinOptions_VarHandleOptions = 111, + BuiltinOptions_ReadVariableOptions = 112, + BuiltinOptions_AssignVariableOptions = 113, + BuiltinOptions_RandomOptions = 114, + BuiltinOptions_BucketizeOptions = 115, + BuiltinOptions_GeluOptions = 116, + BuiltinOptions_DynamicUpdateSliceOptions = 117, + BuiltinOptions_UnsortedSegmentProdOptions = 118, + BuiltinOptions_UnsortedSegmentMaxOptions = 119, + BuiltinOptions_UnsortedSegmentMinOptions = 120, + BuiltinOptions_UnsortedSegmentSumOptions = 121, + BuiltinOptions_ATan2Options = 122, + BuiltinOptions_SignOptions = 123, + BuiltinOptions_BitcastOptions = 124, + BuiltinOptions_BitwiseXorOptions = 125, + BuiltinOptions_RightShiftOptions = 126, + BuiltinOptions_MIN = BuiltinOptions_NONE, + BuiltinOptions_MAX = BuiltinOptions_RightShiftOptions +}; + +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[127] { + static const BuiltinOptions values[] = { + BuiltinOptions_NONE, + BuiltinOptions_Conv2DOptions, + BuiltinOptions_DepthwiseConv2DOptions, + BuiltinOptions_ConcatEmbeddingsOptions, + BuiltinOptions_LSHProjectionOptions, + BuiltinOptions_Pool2DOptions, + BuiltinOptions_SVDFOptions, + BuiltinOptions_RNNOptions, + BuiltinOptions_FullyConnectedOptions, + BuiltinOptions_SoftmaxOptions, + BuiltinOptions_ConcatenationOptions, + BuiltinOptions_AddOptions, + BuiltinOptions_L2NormOptions, + BuiltinOptions_LocalResponseNormalizationOptions, + BuiltinOptions_LSTMOptions, + BuiltinOptions_ResizeBilinearOptions, + BuiltinOptions_CallOptions, + BuiltinOptions_ReshapeOptions, + BuiltinOptions_SkipGramOptions, + BuiltinOptions_SpaceToDepthOptions, + BuiltinOptions_EmbeddingLookupSparseOptions, + BuiltinOptions_MulOptions, + BuiltinOptions_PadOptions, + BuiltinOptions_GatherOptions, + BuiltinOptions_BatchToSpaceNDOptions, + BuiltinOptions_SpaceToBatchNDOptions, + BuiltinOptions_TransposeOptions, + BuiltinOptions_ReducerOptions, + BuiltinOptions_SubOptions, + BuiltinOptions_DivOptions, + BuiltinOptions_SqueezeOptions, + BuiltinOptions_SequenceRNNOptions, + BuiltinOptions_StridedSliceOptions, + BuiltinOptions_ExpOptions, + BuiltinOptions_TopKV2Options, + BuiltinOptions_SplitOptions, + BuiltinOptions_LogSoftmaxOptions, + BuiltinOptions_CastOptions, + BuiltinOptions_DequantizeOptions, + BuiltinOptions_MaximumMinimumOptions, + BuiltinOptions_ArgMaxOptions, + BuiltinOptions_LessOptions, + BuiltinOptions_NegOptions, + BuiltinOptions_PadV2Options, + BuiltinOptions_GreaterOptions, + BuiltinOptions_GreaterEqualOptions, + BuiltinOptions_LessEqualOptions, + BuiltinOptions_SelectOptions, + BuiltinOptions_SliceOptions, + BuiltinOptions_TransposeConvOptions, + BuiltinOptions_SparseToDenseOptions, + BuiltinOptions_TileOptions, + BuiltinOptions_ExpandDimsOptions, + BuiltinOptions_EqualOptions, + BuiltinOptions_NotEqualOptions, + BuiltinOptions_ShapeOptions, + BuiltinOptions_PowOptions, + BuiltinOptions_ArgMinOptions, + BuiltinOptions_FakeQuantOptions, + BuiltinOptions_PackOptions, + BuiltinOptions_LogicalOrOptions, + BuiltinOptions_OneHotOptions, + BuiltinOptions_LogicalAndOptions, + BuiltinOptions_LogicalNotOptions, + BuiltinOptions_UnpackOptions, + 
BuiltinOptions_FloorDivOptions, + BuiltinOptions_SquareOptions, + BuiltinOptions_ZerosLikeOptions, + BuiltinOptions_FillOptions, + BuiltinOptions_BidirectionalSequenceLSTMOptions, + BuiltinOptions_BidirectionalSequenceRNNOptions, + BuiltinOptions_UnidirectionalSequenceLSTMOptions, + BuiltinOptions_FloorModOptions, + BuiltinOptions_RangeOptions, + BuiltinOptions_ResizeNearestNeighborOptions, + BuiltinOptions_LeakyReluOptions, + BuiltinOptions_SquaredDifferenceOptions, + BuiltinOptions_MirrorPadOptions, + BuiltinOptions_AbsOptions, + BuiltinOptions_SplitVOptions, + BuiltinOptions_UniqueOptions, + BuiltinOptions_ReverseV2Options, + BuiltinOptions_AddNOptions, + BuiltinOptions_GatherNdOptions, + BuiltinOptions_CosOptions, + BuiltinOptions_WhereOptions, + BuiltinOptions_RankOptions, + BuiltinOptions_ReverseSequenceOptions, + BuiltinOptions_MatrixDiagOptions, + BuiltinOptions_QuantizeOptions, + BuiltinOptions_MatrixSetDiagOptions, + BuiltinOptions_HardSwishOptions, + BuiltinOptions_IfOptions, + BuiltinOptions_WhileOptions, + BuiltinOptions_DepthToSpaceOptions, + BuiltinOptions_NonMaxSuppressionV4Options, + BuiltinOptions_NonMaxSuppressionV5Options, + BuiltinOptions_ScatterNdOptions, + BuiltinOptions_SelectV2Options, + BuiltinOptions_DensifyOptions, + BuiltinOptions_SegmentSumOptions, + BuiltinOptions_BatchMatMulOptions, + BuiltinOptions_CumsumOptions, + BuiltinOptions_CallOnceOptions, + BuiltinOptions_BroadcastToOptions, + BuiltinOptions_Rfft2dOptions, + BuiltinOptions_Conv3DOptions, + BuiltinOptions_HashtableOptions, + BuiltinOptions_HashtableFindOptions, + BuiltinOptions_HashtableImportOptions, + BuiltinOptions_HashtableSizeOptions, + BuiltinOptions_VarHandleOptions, + BuiltinOptions_ReadVariableOptions, + BuiltinOptions_AssignVariableOptions, + BuiltinOptions_RandomOptions, + BuiltinOptions_BucketizeOptions, + BuiltinOptions_GeluOptions, + BuiltinOptions_DynamicUpdateSliceOptions, + BuiltinOptions_UnsortedSegmentProdOptions, + BuiltinOptions_UnsortedSegmentMaxOptions, + BuiltinOptions_UnsortedSegmentMinOptions, + BuiltinOptions_UnsortedSegmentSumOptions, + BuiltinOptions_ATan2Options, + BuiltinOptions_SignOptions, + BuiltinOptions_BitcastOptions, + BuiltinOptions_BitwiseXorOptions, + BuiltinOptions_RightShiftOptions + }; + return values; +} + +inline const char * const *EnumNamesBuiltinOptions() { + static const char * const names[128] = { + "NONE", + "Conv2DOptions", + "DepthwiseConv2DOptions", + "ConcatEmbeddingsOptions", + "LSHProjectionOptions", + "Pool2DOptions", + "SVDFOptions", + "RNNOptions", + "FullyConnectedOptions", + "SoftmaxOptions", + "ConcatenationOptions", + "AddOptions", + "L2NormOptions", + "LocalResponseNormalizationOptions", + "LSTMOptions", + "ResizeBilinearOptions", + "CallOptions", + "ReshapeOptions", + "SkipGramOptions", + "SpaceToDepthOptions", + "EmbeddingLookupSparseOptions", + "MulOptions", + "PadOptions", + "GatherOptions", + "BatchToSpaceNDOptions", + "SpaceToBatchNDOptions", + "TransposeOptions", + "ReducerOptions", + "SubOptions", + "DivOptions", + "SqueezeOptions", + "SequenceRNNOptions", + "StridedSliceOptions", + "ExpOptions", + "TopKV2Options", + "SplitOptions", + "LogSoftmaxOptions", + "CastOptions", + "DequantizeOptions", + "MaximumMinimumOptions", + "ArgMaxOptions", + "LessOptions", + "NegOptions", + "PadV2Options", + "GreaterOptions", + "GreaterEqualOptions", + "LessEqualOptions", + "SelectOptions", + "SliceOptions", + "TransposeConvOptions", + "SparseToDenseOptions", + "TileOptions", + "ExpandDimsOptions", + "EqualOptions", + "NotEqualOptions", + 
"ShapeOptions", + "PowOptions", + "ArgMinOptions", + "FakeQuantOptions", + "PackOptions", + "LogicalOrOptions", + "OneHotOptions", + "LogicalAndOptions", + "LogicalNotOptions", + "UnpackOptions", + "FloorDivOptions", + "SquareOptions", + "ZerosLikeOptions", + "FillOptions", + "BidirectionalSequenceLSTMOptions", + "BidirectionalSequenceRNNOptions", + "UnidirectionalSequenceLSTMOptions", + "FloorModOptions", + "RangeOptions", + "ResizeNearestNeighborOptions", + "LeakyReluOptions", + "SquaredDifferenceOptions", + "MirrorPadOptions", + "AbsOptions", + "SplitVOptions", + "UniqueOptions", + "ReverseV2Options", + "AddNOptions", + "GatherNdOptions", + "CosOptions", + "WhereOptions", + "RankOptions", + "ReverseSequenceOptions", + "MatrixDiagOptions", + "QuantizeOptions", + "MatrixSetDiagOptions", + "HardSwishOptions", + "IfOptions", + "WhileOptions", + "DepthToSpaceOptions", + "NonMaxSuppressionV4Options", + "NonMaxSuppressionV5Options", + "ScatterNdOptions", + "SelectV2Options", + "DensifyOptions", + "SegmentSumOptions", + "BatchMatMulOptions", + "CumsumOptions", + "CallOnceOptions", + "BroadcastToOptions", + "Rfft2dOptions", + "Conv3DOptions", + "HashtableOptions", + "HashtableFindOptions", + "HashtableImportOptions", + "HashtableSizeOptions", + "VarHandleOptions", + "ReadVariableOptions", + "AssignVariableOptions", + "RandomOptions", + "BucketizeOptions", + "GeluOptions", + "DynamicUpdateSliceOptions", + "UnsortedSegmentProdOptions", + "UnsortedSegmentMaxOptions", + "UnsortedSegmentMinOptions", + "UnsortedSegmentSumOptions", + "ATan2Options", + "SignOptions", + "BitcastOptions", + "BitwiseXorOptions", + "RightShiftOptions", + nullptr + }; + return names; +} + +inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { + if (::flatbuffers::IsOutRange(e, BuiltinOptions_NONE, BuiltinOptions_RightShiftOptions)) return ""; + const size_t index = static_cast(e); + return EnumNamesBuiltinOptions()[index]; +} + +template struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NONE; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; +}; + +template<> struct 
BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PadOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SubOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DivOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TopKV2Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SplitOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CastOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MaximumMinimumOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = 
BuiltinOptions_LessOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NegOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PadV2Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SliceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PackOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FillOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions 
enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FloorModOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RangeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ResizeNearestNeighborOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LeakyReluOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SquaredDifferenceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MirrorPadOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AbsOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SplitVOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CosOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RankOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_IfOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DepthToSpaceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV4Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SelectV2Options; 
+}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CumsumOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Rfft2dOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Conv3DOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HashtableOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HashtableFindOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HashtableImportOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HashtableSizeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_VarHandleOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReadVariableOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AssignVariableOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RandomOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BucketizeOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DynamicUpdateSliceOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentProdOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMaxOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMinOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentSumOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ATan2Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SignOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RightShiftOptions; +}; + +template struct 
BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NONE; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PadOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions 
enum_value = BuiltinOptions_ReducerOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SubOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DivOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TopKV2Options; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SplitOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CastOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MaximumMinimumOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LessOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NegOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PadV2Options; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GreaterEqualOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LessEqualOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SelectOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SliceOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const 
BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_PackOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalAndOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogicalNotOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FillOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_FloorModOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RangeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ResizeNearestNeighborOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LeakyReluOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SquaredDifferenceOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MirrorPadOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_AbsOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SplitVOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_UniqueOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = 
BuiltinOptions_AddNOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CosOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_WhereOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_RankOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ReverseSequenceOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MatrixDiagOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_QuantizeOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_MatrixSetDiagOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HardSwishOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_IfOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_WhileOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DepthToSpaceOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV4Options; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ScatterNdOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SelectV2Options; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DensifyOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SegmentSumOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BatchMatMulOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CumsumOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_CallOnceOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_BroadcastToOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Rfft2dOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_Conv3DOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HashtableOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HashtableFindOptions; +}; + +template<> struct BuiltinOptionsUnionTraits { + static const BuiltinOptions enum_value = BuiltinOptions_HashtableImportOptions; +}; + +template<> struct 
BuiltinOptionsUnionTraits<tflite::HashtableSizeOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_HashtableSizeOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::VarHandleOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_VarHandleOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::ReadVariableOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_ReadVariableOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::AssignVariableOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_AssignVariableOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::RandomOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_RandomOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::BucketizeOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_BucketizeOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::GeluOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_GeluOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::DynamicUpdateSliceOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_DynamicUpdateSliceOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::UnsortedSegmentProdOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentProdOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::UnsortedSegmentMaxOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMaxOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::UnsortedSegmentMinOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentMinOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::UnsortedSegmentSumOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_UnsortedSegmentSumOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::ATan2OptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_ATan2Options;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::SignOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_SignOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::BitcastOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_BitcastOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::BitwiseXorOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_BitwiseXorOptions;
+};
+
+template<> struct BuiltinOptionsUnionTraits<tflite::RightShiftOptionsT> {
+  static const BuiltinOptions enum_value = BuiltinOptions_RightShiftOptions;
+};
+
+struct BuiltinOptionsUnion {
+  BuiltinOptions type;
+  void *value;
+
+  BuiltinOptionsUnion() : type(BuiltinOptions_NONE), value(nullptr) {}
+  BuiltinOptionsUnion(BuiltinOptionsUnion&& u) FLATBUFFERS_NOEXCEPT :
+    type(BuiltinOptions_NONE), value(nullptr)
+    { std::swap(type, u.type); std::swap(value, u.value); }
+  BuiltinOptionsUnion(const BuiltinOptionsUnion &);
+  BuiltinOptionsUnion &operator=(const BuiltinOptionsUnion &u)
+    { BuiltinOptionsUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; }
+  BuiltinOptionsUnion &operator=(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT
+    { std::swap(type, u.type); std::swap(value, u.value); return *this; }
+  ~BuiltinOptionsUnion() { Reset(); }
+
+  void Reset();
+
+  template <typename T>
+  void Set(T&& val) {
+    typedef typename std::remove_reference<T>::type RT;
+    Reset();
+    type = BuiltinOptionsUnionTraits<RT>::enum_value;
+    if (type != BuiltinOptions_NONE) {
+      value = new RT(std::forward<T>(val));
+    }
+  }
+
+  static void *UnPack(const void *obj, BuiltinOptions type, const ::flatbuffers::resolver_function_t *resolver);
+  ::flatbuffers::Offset<void> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr) const;
+
+  tflite::Conv2DOptionsT *AsConv2DOptions() {
+    return type == BuiltinOptions_Conv2DOptions ?
+ reinterpret_cast(value) : nullptr; + } + const tflite::Conv2DOptionsT *AsConv2DOptions() const { + return type == BuiltinOptions_Conv2DOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() { + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() const { + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() { + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() const { + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LSHProjectionOptionsT *AsLSHProjectionOptions() { + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LSHProjectionOptionsT *AsLSHProjectionOptions() const { + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::Pool2DOptionsT *AsPool2DOptions() { + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::Pool2DOptionsT *AsPool2DOptions() const { + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SVDFOptionsT *AsSVDFOptions() { + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SVDFOptionsT *AsSVDFOptions() const { + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::RNNOptionsT *AsRNNOptions() { + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::RNNOptionsT *AsRNNOptions() const { + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::FullyConnectedOptionsT *AsFullyConnectedOptions() { + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::FullyConnectedOptionsT *AsFullyConnectedOptions() const { + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SoftmaxOptionsT *AsSoftmaxOptions() { + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SoftmaxOptionsT *AsSoftmaxOptions() const { + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ConcatenationOptionsT *AsConcatenationOptions() { + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ConcatenationOptionsT *AsConcatenationOptions() const { + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::AddOptionsT *AsAddOptions() { + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::AddOptionsT *AsAddOptions() const { + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::L2NormOptionsT *AsL2NormOptions() { + return type == BuiltinOptions_L2NormOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::L2NormOptionsT *AsL2NormOptions() const { + return type == BuiltinOptions_L2NormOptions ? 
+ reinterpret_cast(value) : nullptr; + } + tflite::LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() { + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() const { + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LSTMOptionsT *AsLSTMOptions() { + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LSTMOptionsT *AsLSTMOptions() const { + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ResizeBilinearOptionsT *AsResizeBilinearOptions() { + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ResizeBilinearOptionsT *AsResizeBilinearOptions() const { + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::CallOptionsT *AsCallOptions() { + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::CallOptionsT *AsCallOptions() const { + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ReshapeOptionsT *AsReshapeOptions() { + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ReshapeOptionsT *AsReshapeOptions() const { + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SkipGramOptionsT *AsSkipGramOptions() { + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SkipGramOptionsT *AsSkipGramOptions() const { + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SpaceToDepthOptionsT *AsSpaceToDepthOptions() { + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SpaceToDepthOptionsT *AsSpaceToDepthOptions() const { + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() { + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() const { + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::MulOptionsT *AsMulOptions() { + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::MulOptionsT *AsMulOptions() const { + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::PadOptionsT *AsPadOptions() { + return type == BuiltinOptions_PadOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::PadOptionsT *AsPadOptions() const { + return type == BuiltinOptions_PadOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::GatherOptionsT *AsGatherOptions() { + return type == BuiltinOptions_GatherOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::GatherOptionsT *AsGatherOptions() const { + return type == BuiltinOptions_GatherOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() { + return type == BuiltinOptions_BatchToSpaceNDOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const tflite::BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() const { + return type == BuiltinOptions_BatchToSpaceNDOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() { + return type == BuiltinOptions_SpaceToBatchNDOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() const { + return type == BuiltinOptions_SpaceToBatchNDOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::TransposeOptionsT *AsTransposeOptions() { + return type == BuiltinOptions_TransposeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::TransposeOptionsT *AsTransposeOptions() const { + return type == BuiltinOptions_TransposeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ReducerOptionsT *AsReducerOptions() { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ReducerOptionsT *AsReducerOptions() const { + return type == BuiltinOptions_ReducerOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SubOptionsT *AsSubOptions() { + return type == BuiltinOptions_SubOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SubOptionsT *AsSubOptions() const { + return type == BuiltinOptions_SubOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::DivOptionsT *AsDivOptions() { + return type == BuiltinOptions_DivOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::DivOptionsT *AsDivOptions() const { + return type == BuiltinOptions_DivOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SqueezeOptionsT *AsSqueezeOptions() { + return type == BuiltinOptions_SqueezeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SqueezeOptionsT *AsSqueezeOptions() const { + return type == BuiltinOptions_SqueezeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SequenceRNNOptionsT *AsSequenceRNNOptions() { + return type == BuiltinOptions_SequenceRNNOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SequenceRNNOptionsT *AsSequenceRNNOptions() const { + return type == BuiltinOptions_SequenceRNNOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StridedSliceOptionsT *AsStridedSliceOptions() { + return type == BuiltinOptions_StridedSliceOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StridedSliceOptionsT *AsStridedSliceOptions() const { + return type == BuiltinOptions_StridedSliceOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ExpOptionsT *AsExpOptions() { + return type == BuiltinOptions_ExpOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ExpOptionsT *AsExpOptions() const { + return type == BuiltinOptions_ExpOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::TopKV2OptionsT *AsTopKV2Options() { + return type == BuiltinOptions_TopKV2Options ? + reinterpret_cast(value) : nullptr; + } + const tflite::TopKV2OptionsT *AsTopKV2Options() const { + return type == BuiltinOptions_TopKV2Options ? + reinterpret_cast(value) : nullptr; + } + tflite::SplitOptionsT *AsSplitOptions() { + return type == BuiltinOptions_SplitOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SplitOptionsT *AsSplitOptions() const { + return type == BuiltinOptions_SplitOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LogSoftmaxOptionsT *AsLogSoftmaxOptions() { + return type == BuiltinOptions_LogSoftmaxOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const tflite::LogSoftmaxOptionsT *AsLogSoftmaxOptions() const { + return type == BuiltinOptions_LogSoftmaxOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::CastOptionsT *AsCastOptions() { + return type == BuiltinOptions_CastOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::CastOptionsT *AsCastOptions() const { + return type == BuiltinOptions_CastOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::DequantizeOptionsT *AsDequantizeOptions() { + return type == BuiltinOptions_DequantizeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::DequantizeOptionsT *AsDequantizeOptions() const { + return type == BuiltinOptions_DequantizeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::MaximumMinimumOptionsT *AsMaximumMinimumOptions() { + return type == BuiltinOptions_MaximumMinimumOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::MaximumMinimumOptionsT *AsMaximumMinimumOptions() const { + return type == BuiltinOptions_MaximumMinimumOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ArgMaxOptionsT *AsArgMaxOptions() { + return type == BuiltinOptions_ArgMaxOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ArgMaxOptionsT *AsArgMaxOptions() const { + return type == BuiltinOptions_ArgMaxOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LessOptionsT *AsLessOptions() { + return type == BuiltinOptions_LessOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LessOptionsT *AsLessOptions() const { + return type == BuiltinOptions_LessOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::NegOptionsT *AsNegOptions() { + return type == BuiltinOptions_NegOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::NegOptionsT *AsNegOptions() const { + return type == BuiltinOptions_NegOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::PadV2OptionsT *AsPadV2Options() { + return type == BuiltinOptions_PadV2Options ? + reinterpret_cast(value) : nullptr; + } + const tflite::PadV2OptionsT *AsPadV2Options() const { + return type == BuiltinOptions_PadV2Options ? + reinterpret_cast(value) : nullptr; + } + tflite::GreaterOptionsT *AsGreaterOptions() { + return type == BuiltinOptions_GreaterOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::GreaterOptionsT *AsGreaterOptions() const { + return type == BuiltinOptions_GreaterOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::GreaterEqualOptionsT *AsGreaterEqualOptions() { + return type == BuiltinOptions_GreaterEqualOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::GreaterEqualOptionsT *AsGreaterEqualOptions() const { + return type == BuiltinOptions_GreaterEqualOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LessEqualOptionsT *AsLessEqualOptions() { + return type == BuiltinOptions_LessEqualOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LessEqualOptionsT *AsLessEqualOptions() const { + return type == BuiltinOptions_LessEqualOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SelectOptionsT *AsSelectOptions() { + return type == BuiltinOptions_SelectOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SelectOptionsT *AsSelectOptions() const { + return type == BuiltinOptions_SelectOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SliceOptionsT *AsSliceOptions() { + return type == BuiltinOptions_SliceOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const tflite::SliceOptionsT *AsSliceOptions() const { + return type == BuiltinOptions_SliceOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::TransposeConvOptionsT *AsTransposeConvOptions() { + return type == BuiltinOptions_TransposeConvOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::TransposeConvOptionsT *AsTransposeConvOptions() const { + return type == BuiltinOptions_TransposeConvOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SparseToDenseOptionsT *AsSparseToDenseOptions() { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SparseToDenseOptionsT *AsSparseToDenseOptions() const { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::TileOptionsT *AsTileOptions() { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::TileOptionsT *AsTileOptions() const { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ExpandDimsOptionsT *AsExpandDimsOptions() { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ExpandDimsOptionsT *AsExpandDimsOptions() const { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::EqualOptionsT *AsEqualOptions() { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::EqualOptionsT *AsEqualOptions() const { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::NotEqualOptionsT *AsNotEqualOptions() { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::NotEqualOptionsT *AsNotEqualOptions() const { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ShapeOptionsT *AsShapeOptions() { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ShapeOptionsT *AsShapeOptions() const { + return type == BuiltinOptions_ShapeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::PowOptionsT *AsPowOptions() { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::PowOptionsT *AsPowOptions() const { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ArgMinOptionsT *AsArgMinOptions() { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ArgMinOptionsT *AsArgMinOptions() const { + return type == BuiltinOptions_ArgMinOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::FakeQuantOptionsT *AsFakeQuantOptions() { + return type == BuiltinOptions_FakeQuantOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::FakeQuantOptionsT *AsFakeQuantOptions() const { + return type == BuiltinOptions_FakeQuantOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::PackOptionsT *AsPackOptions() { + return type == BuiltinOptions_PackOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::PackOptionsT *AsPackOptions() const { + return type == BuiltinOptions_PackOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LogicalOrOptionsT *AsLogicalOrOptions() { + return type == BuiltinOptions_LogicalOrOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const tflite::LogicalOrOptionsT *AsLogicalOrOptions() const { + return type == BuiltinOptions_LogicalOrOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::OneHotOptionsT *AsOneHotOptions() { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::OneHotOptionsT *AsOneHotOptions() const { + return type == BuiltinOptions_OneHotOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LogicalAndOptionsT *AsLogicalAndOptions() { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LogicalAndOptionsT *AsLogicalAndOptions() const { + return type == BuiltinOptions_LogicalAndOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LogicalNotOptionsT *AsLogicalNotOptions() { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LogicalNotOptionsT *AsLogicalNotOptions() const { + return type == BuiltinOptions_LogicalNotOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::UnpackOptionsT *AsUnpackOptions() { + return type == BuiltinOptions_UnpackOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::UnpackOptionsT *AsUnpackOptions() const { + return type == BuiltinOptions_UnpackOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::FloorDivOptionsT *AsFloorDivOptions() { + return type == BuiltinOptions_FloorDivOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::FloorDivOptionsT *AsFloorDivOptions() const { + return type == BuiltinOptions_FloorDivOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SquareOptionsT *AsSquareOptions() { + return type == BuiltinOptions_SquareOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SquareOptionsT *AsSquareOptions() const { + return type == BuiltinOptions_SquareOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ZerosLikeOptionsT *AsZerosLikeOptions() { + return type == BuiltinOptions_ZerosLikeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ZerosLikeOptionsT *AsZerosLikeOptions() const { + return type == BuiltinOptions_ZerosLikeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::FillOptionsT *AsFillOptions() { + return type == BuiltinOptions_FillOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::FillOptionsT *AsFillOptions() const { + return type == BuiltinOptions_FillOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() { + return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() const { + return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() { + return type == BuiltinOptions_BidirectionalSequenceRNNOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() const { + return type == BuiltinOptions_BidirectionalSequenceRNNOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() { + return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const tflite::UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() const { + return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::FloorModOptionsT *AsFloorModOptions() { + return type == BuiltinOptions_FloorModOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::FloorModOptionsT *AsFloorModOptions() const { + return type == BuiltinOptions_FloorModOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::RangeOptionsT *AsRangeOptions() { + return type == BuiltinOptions_RangeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::RangeOptionsT *AsRangeOptions() const { + return type == BuiltinOptions_RangeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ResizeNearestNeighborOptionsT *AsResizeNearestNeighborOptions() { + return type == BuiltinOptions_ResizeNearestNeighborOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ResizeNearestNeighborOptionsT *AsResizeNearestNeighborOptions() const { + return type == BuiltinOptions_ResizeNearestNeighborOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::LeakyReluOptionsT *AsLeakyReluOptions() { + return type == BuiltinOptions_LeakyReluOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::LeakyReluOptionsT *AsLeakyReluOptions() const { + return type == BuiltinOptions_LeakyReluOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SquaredDifferenceOptionsT *AsSquaredDifferenceOptions() { + return type == BuiltinOptions_SquaredDifferenceOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SquaredDifferenceOptionsT *AsSquaredDifferenceOptions() const { + return type == BuiltinOptions_SquaredDifferenceOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::MirrorPadOptionsT *AsMirrorPadOptions() { + return type == BuiltinOptions_MirrorPadOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::MirrorPadOptionsT *AsMirrorPadOptions() const { + return type == BuiltinOptions_MirrorPadOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::AbsOptionsT *AsAbsOptions() { + return type == BuiltinOptions_AbsOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::AbsOptionsT *AsAbsOptions() const { + return type == BuiltinOptions_AbsOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SplitVOptionsT *AsSplitVOptions() { + return type == BuiltinOptions_SplitVOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SplitVOptionsT *AsSplitVOptions() const { + return type == BuiltinOptions_SplitVOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::UniqueOptionsT *AsUniqueOptions() { + return type == BuiltinOptions_UniqueOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::UniqueOptionsT *AsUniqueOptions() const { + return type == BuiltinOptions_UniqueOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ReverseV2OptionsT *AsReverseV2Options() { + return type == BuiltinOptions_ReverseV2Options ? + reinterpret_cast(value) : nullptr; + } + const tflite::ReverseV2OptionsT *AsReverseV2Options() const { + return type == BuiltinOptions_ReverseV2Options ? + reinterpret_cast(value) : nullptr; + } + tflite::AddNOptionsT *AsAddNOptions() { + return type == BuiltinOptions_AddNOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::AddNOptionsT *AsAddNOptions() const { + return type == BuiltinOptions_AddNOptions ? 
+ reinterpret_cast(value) : nullptr; + } + tflite::GatherNdOptionsT *AsGatherNdOptions() { + return type == BuiltinOptions_GatherNdOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::GatherNdOptionsT *AsGatherNdOptions() const { + return type == BuiltinOptions_GatherNdOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::CosOptionsT *AsCosOptions() { + return type == BuiltinOptions_CosOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::CosOptionsT *AsCosOptions() const { + return type == BuiltinOptions_CosOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::WhereOptionsT *AsWhereOptions() { + return type == BuiltinOptions_WhereOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::WhereOptionsT *AsWhereOptions() const { + return type == BuiltinOptions_WhereOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::RankOptionsT *AsRankOptions() { + return type == BuiltinOptions_RankOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::RankOptionsT *AsRankOptions() const { + return type == BuiltinOptions_RankOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ReverseSequenceOptionsT *AsReverseSequenceOptions() { + return type == BuiltinOptions_ReverseSequenceOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ReverseSequenceOptionsT *AsReverseSequenceOptions() const { + return type == BuiltinOptions_ReverseSequenceOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::MatrixDiagOptionsT *AsMatrixDiagOptions() { + return type == BuiltinOptions_MatrixDiagOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::MatrixDiagOptionsT *AsMatrixDiagOptions() const { + return type == BuiltinOptions_MatrixDiagOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::QuantizeOptionsT *AsQuantizeOptions() { + return type == BuiltinOptions_QuantizeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::QuantizeOptionsT *AsQuantizeOptions() const { + return type == BuiltinOptions_QuantizeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() { + return type == BuiltinOptions_MatrixSetDiagOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::MatrixSetDiagOptionsT *AsMatrixSetDiagOptions() const { + return type == BuiltinOptions_MatrixSetDiagOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::HardSwishOptionsT *AsHardSwishOptions() { + return type == BuiltinOptions_HardSwishOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::HardSwishOptionsT *AsHardSwishOptions() const { + return type == BuiltinOptions_HardSwishOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::IfOptionsT *AsIfOptions() { + return type == BuiltinOptions_IfOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::IfOptionsT *AsIfOptions() const { + return type == BuiltinOptions_IfOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::WhileOptionsT *AsWhileOptions() { + return type == BuiltinOptions_WhileOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::WhileOptionsT *AsWhileOptions() const { + return type == BuiltinOptions_WhileOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::DepthToSpaceOptionsT *AsDepthToSpaceOptions() { + return type == BuiltinOptions_DepthToSpaceOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::DepthToSpaceOptionsT *AsDepthToSpaceOptions() const { + return type == BuiltinOptions_DepthToSpaceOptions ? 
+ reinterpret_cast(value) : nullptr; + } + tflite::NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() { + return type == BuiltinOptions_NonMaxSuppressionV4Options ? + reinterpret_cast(value) : nullptr; + } + const tflite::NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() const { + return type == BuiltinOptions_NonMaxSuppressionV4Options ? + reinterpret_cast(value) : nullptr; + } + tflite::NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() { + return type == BuiltinOptions_NonMaxSuppressionV5Options ? + reinterpret_cast(value) : nullptr; + } + const tflite::NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() const { + return type == BuiltinOptions_NonMaxSuppressionV5Options ? + reinterpret_cast(value) : nullptr; + } + tflite::ScatterNdOptionsT *AsScatterNdOptions() { + return type == BuiltinOptions_ScatterNdOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ScatterNdOptionsT *AsScatterNdOptions() const { + return type == BuiltinOptions_ScatterNdOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SelectV2OptionsT *AsSelectV2Options() { + return type == BuiltinOptions_SelectV2Options ? + reinterpret_cast(value) : nullptr; + } + const tflite::SelectV2OptionsT *AsSelectV2Options() const { + return type == BuiltinOptions_SelectV2Options ? + reinterpret_cast(value) : nullptr; + } + tflite::DensifyOptionsT *AsDensifyOptions() { + return type == BuiltinOptions_DensifyOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::DensifyOptionsT *AsDensifyOptions() const { + return type == BuiltinOptions_DensifyOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::SegmentSumOptionsT *AsSegmentSumOptions() { + return type == BuiltinOptions_SegmentSumOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SegmentSumOptionsT *AsSegmentSumOptions() const { + return type == BuiltinOptions_SegmentSumOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() { + return type == BuiltinOptions_BatchMatMulOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BatchMatMulOptionsT *AsBatchMatMulOptions() const { + return type == BuiltinOptions_BatchMatMulOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::CumsumOptionsT *AsCumsumOptions() { + return type == BuiltinOptions_CumsumOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::CumsumOptionsT *AsCumsumOptions() const { + return type == BuiltinOptions_CumsumOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::CallOnceOptionsT *AsCallOnceOptions() { + return type == BuiltinOptions_CallOnceOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::CallOnceOptionsT *AsCallOnceOptions() const { + return type == BuiltinOptions_CallOnceOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BroadcastToOptionsT *AsBroadcastToOptions() { + return type == BuiltinOptions_BroadcastToOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BroadcastToOptionsT *AsBroadcastToOptions() const { + return type == BuiltinOptions_BroadcastToOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::Rfft2dOptionsT *AsRfft2dOptions() { + return type == BuiltinOptions_Rfft2dOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::Rfft2dOptionsT *AsRfft2dOptions() const { + return type == BuiltinOptions_Rfft2dOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::Conv3DOptionsT *AsConv3DOptions() { + return type == BuiltinOptions_Conv3DOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const tflite::Conv3DOptionsT *AsConv3DOptions() const { + return type == BuiltinOptions_Conv3DOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::HashtableOptionsT *AsHashtableOptions() { + return type == BuiltinOptions_HashtableOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::HashtableOptionsT *AsHashtableOptions() const { + return type == BuiltinOptions_HashtableOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::HashtableFindOptionsT *AsHashtableFindOptions() { + return type == BuiltinOptions_HashtableFindOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::HashtableFindOptionsT *AsHashtableFindOptions() const { + return type == BuiltinOptions_HashtableFindOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::HashtableImportOptionsT *AsHashtableImportOptions() { + return type == BuiltinOptions_HashtableImportOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::HashtableImportOptionsT *AsHashtableImportOptions() const { + return type == BuiltinOptions_HashtableImportOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::HashtableSizeOptionsT *AsHashtableSizeOptions() { + return type == BuiltinOptions_HashtableSizeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::HashtableSizeOptionsT *AsHashtableSizeOptions() const { + return type == BuiltinOptions_HashtableSizeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::VarHandleOptionsT *AsVarHandleOptions() { + return type == BuiltinOptions_VarHandleOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::VarHandleOptionsT *AsVarHandleOptions() const { + return type == BuiltinOptions_VarHandleOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ReadVariableOptionsT *AsReadVariableOptions() { + return type == BuiltinOptions_ReadVariableOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ReadVariableOptionsT *AsReadVariableOptions() const { + return type == BuiltinOptions_ReadVariableOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::AssignVariableOptionsT *AsAssignVariableOptions() { + return type == BuiltinOptions_AssignVariableOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::AssignVariableOptionsT *AsAssignVariableOptions() const { + return type == BuiltinOptions_AssignVariableOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::RandomOptionsT *AsRandomOptions() { + return type == BuiltinOptions_RandomOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::RandomOptionsT *AsRandomOptions() const { + return type == BuiltinOptions_RandomOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BucketizeOptionsT *AsBucketizeOptions() { + return type == BuiltinOptions_BucketizeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BucketizeOptionsT *AsBucketizeOptions() const { + return type == BuiltinOptions_BucketizeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::GeluOptionsT *AsGeluOptions() { + return type == BuiltinOptions_GeluOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::GeluOptionsT *AsGeluOptions() const { + return type == BuiltinOptions_GeluOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::DynamicUpdateSliceOptionsT *AsDynamicUpdateSliceOptions() { + return type == BuiltinOptions_DynamicUpdateSliceOptions ? 
+ reinterpret_cast(value) : nullptr; + } + const tflite::DynamicUpdateSliceOptionsT *AsDynamicUpdateSliceOptions() const { + return type == BuiltinOptions_DynamicUpdateSliceOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::UnsortedSegmentProdOptionsT *AsUnsortedSegmentProdOptions() { + return type == BuiltinOptions_UnsortedSegmentProdOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::UnsortedSegmentProdOptionsT *AsUnsortedSegmentProdOptions() const { + return type == BuiltinOptions_UnsortedSegmentProdOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::UnsortedSegmentMaxOptionsT *AsUnsortedSegmentMaxOptions() { + return type == BuiltinOptions_UnsortedSegmentMaxOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::UnsortedSegmentMaxOptionsT *AsUnsortedSegmentMaxOptions() const { + return type == BuiltinOptions_UnsortedSegmentMaxOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::UnsortedSegmentMinOptionsT *AsUnsortedSegmentMinOptions() { + return type == BuiltinOptions_UnsortedSegmentMinOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::UnsortedSegmentMinOptionsT *AsUnsortedSegmentMinOptions() const { + return type == BuiltinOptions_UnsortedSegmentMinOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::UnsortedSegmentSumOptionsT *AsUnsortedSegmentSumOptions() { + return type == BuiltinOptions_UnsortedSegmentSumOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::UnsortedSegmentSumOptionsT *AsUnsortedSegmentSumOptions() const { + return type == BuiltinOptions_UnsortedSegmentSumOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ATan2OptionsT *AsATan2Options() { + return type == BuiltinOptions_ATan2Options ? + reinterpret_cast(value) : nullptr; + } + const tflite::ATan2OptionsT *AsATan2Options() const { + return type == BuiltinOptions_ATan2Options ? + reinterpret_cast(value) : nullptr; + } + tflite::SignOptionsT *AsSignOptions() { + return type == BuiltinOptions_SignOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::SignOptionsT *AsSignOptions() const { + return type == BuiltinOptions_SignOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BitcastOptionsT *AsBitcastOptions() { + return type == BuiltinOptions_BitcastOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BitcastOptionsT *AsBitcastOptions() const { + return type == BuiltinOptions_BitcastOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() { + return type == BuiltinOptions_BitwiseXorOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::BitwiseXorOptionsT *AsBitwiseXorOptions() const { + return type == BuiltinOptions_BitwiseXorOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::RightShiftOptionsT *AsRightShiftOptions() { + return type == BuiltinOptions_RightShiftOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::RightShiftOptionsT *AsRightShiftOptions() const { + return type == BuiltinOptions_RightShiftOptions ? 
+ reinterpret_cast(value) : nullptr; + } +}; + +bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); +bool VerifyBuiltinOptionsVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + +enum BuiltinOptions2 : uint8_t { + BuiltinOptions2_NONE = 0, + BuiltinOptions2_StablehloConcatenateOptions = 1, + BuiltinOptions2_StablehloBroadcastInDimOptions = 2, + BuiltinOptions2_StablehloSliceOptions = 3, + BuiltinOptions2_StablehloConvolutionOptions = 4, + BuiltinOptions2_StablehloCustomCallOptions = 5, + BuiltinOptions2_StablehloReduceOptions = 6, + BuiltinOptions2_StablehloScatterOptions = 7, + BuiltinOptions2_StablehloCompareOptions = 8, + BuiltinOptions2_StablehloDynamicSliceOptions = 9, + BuiltinOptions2_StablehloPadOptions = 10, + BuiltinOptions2_StablehloIotaOptions = 11, + BuiltinOptions2_StablehloDotGeneralOptions = 12, + BuiltinOptions2_StablehloReduceWindowOptions = 13, + BuiltinOptions2_StablehloSortOptions = 14, + BuiltinOptions2_StablehloWhileOptions = 15, + BuiltinOptions2_StablehloGatherOptions = 16, + BuiltinOptions2_StablehloTransposeOptions = 17, + BuiltinOptions2_DilateOptions = 18, + BuiltinOptions2_StablehloRngBitGeneratorOptions = 19, + BuiltinOptions2_ReduceWindowOptions = 20, + BuiltinOptions2_StableHLOCompositeOptions = 21, + BuiltinOptions2_StablehloShiftLeftOptions = 22, + BuiltinOptions2_StablehloCaseOptions = 23, + BuiltinOptions2_MIN = BuiltinOptions2_NONE, + BuiltinOptions2_MAX = BuiltinOptions2_StablehloCaseOptions +}; + +inline const BuiltinOptions2 (&EnumValuesBuiltinOptions2())[24] { + static const BuiltinOptions2 values[] = { + BuiltinOptions2_NONE, + BuiltinOptions2_StablehloConcatenateOptions, + BuiltinOptions2_StablehloBroadcastInDimOptions, + BuiltinOptions2_StablehloSliceOptions, + BuiltinOptions2_StablehloConvolutionOptions, + BuiltinOptions2_StablehloCustomCallOptions, + BuiltinOptions2_StablehloReduceOptions, + BuiltinOptions2_StablehloScatterOptions, + BuiltinOptions2_StablehloCompareOptions, + BuiltinOptions2_StablehloDynamicSliceOptions, + BuiltinOptions2_StablehloPadOptions, + BuiltinOptions2_StablehloIotaOptions, + BuiltinOptions2_StablehloDotGeneralOptions, + BuiltinOptions2_StablehloReduceWindowOptions, + BuiltinOptions2_StablehloSortOptions, + BuiltinOptions2_StablehloWhileOptions, + BuiltinOptions2_StablehloGatherOptions, + BuiltinOptions2_StablehloTransposeOptions, + BuiltinOptions2_DilateOptions, + BuiltinOptions2_StablehloRngBitGeneratorOptions, + BuiltinOptions2_ReduceWindowOptions, + BuiltinOptions2_StableHLOCompositeOptions, + BuiltinOptions2_StablehloShiftLeftOptions, + BuiltinOptions2_StablehloCaseOptions + }; + return values; +} + +inline const char * const *EnumNamesBuiltinOptions2() { + static const char * const names[25] = { + "NONE", + "StablehloConcatenateOptions", + "StablehloBroadcastInDimOptions", + "StablehloSliceOptions", + "StablehloConvolutionOptions", + "StablehloCustomCallOptions", + "StablehloReduceOptions", + "StablehloScatterOptions", + "StablehloCompareOptions", + "StablehloDynamicSliceOptions", + "StablehloPadOptions", + "StablehloIotaOptions", + "StablehloDotGeneralOptions", + "StablehloReduceWindowOptions", + "StablehloSortOptions", + "StablehloWhileOptions", + "StablehloGatherOptions", + "StablehloTransposeOptions", + "DilateOptions", + "StablehloRngBitGeneratorOptions", + "ReduceWindowOptions", + "StableHLOCompositeOptions", + "StablehloShiftLeftOptions", + "StablehloCaseOptions", + 
nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameBuiltinOptions2(BuiltinOptions2 e) {
+  if (::flatbuffers::IsOutRange(e, BuiltinOptions2_NONE, BuiltinOptions2_StablehloCaseOptions)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesBuiltinOptions2()[index];
+}
+
+template<typename T> struct BuiltinOptions2Traits {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_NONE;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloConcatenateOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloConcatenateOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloBroadcastInDimOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloBroadcastInDimOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloSliceOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloSliceOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloConvolutionOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloConvolutionOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloCustomCallOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCustomCallOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloReduceOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloReduceOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloScatterOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloScatterOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloCompareOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCompareOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloDynamicSliceOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloDynamicSliceOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloPadOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloPadOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloIotaOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloIotaOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloDotGeneralOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloDotGeneralOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloReduceWindowOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloReduceWindowOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloSortOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloSortOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloWhileOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloWhileOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloGatherOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloGatherOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloTransposeOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloTransposeOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::DilateOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_DilateOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloRngBitGeneratorOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloRngBitGeneratorOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::ReduceWindowOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_ReduceWindowOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StableHLOCompositeOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StableHLOCompositeOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloShiftLeftOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloShiftLeftOptions;
+};
+
+template<> struct BuiltinOptions2Traits<tflite::StablehloCaseOptions> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCaseOptions;
+};
+
+template<typename T> struct BuiltinOptions2UnionTraits {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_NONE;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloConcatenateOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloConcatenateOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloBroadcastInDimOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloBroadcastInDimOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloSliceOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloSliceOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloConvolutionOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloConvolutionOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloCustomCallOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCustomCallOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloReduceOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloReduceOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloScatterOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloScatterOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloCompareOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCompareOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloDynamicSliceOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloDynamicSliceOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloPadOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloPadOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloIotaOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloIotaOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloDotGeneralOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloDotGeneralOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloReduceWindowOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloReduceWindowOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloSortOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloSortOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloWhileOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloWhileOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloGatherOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloGatherOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloTransposeOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloTransposeOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::DilateOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_DilateOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloRngBitGeneratorOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloRngBitGeneratorOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::ReduceWindowOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_ReduceWindowOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StableHLOCompositeOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StableHLOCompositeOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloShiftLeftOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloShiftLeftOptions;
+};
+
+template<> struct BuiltinOptions2UnionTraits<tflite::StablehloCaseOptionsT> {
+  static const BuiltinOptions2 enum_value = BuiltinOptions2_StablehloCaseOptions;
+};
+
+struct BuiltinOptions2Union {
+  BuiltinOptions2 type;
+  void *value;
+
+  BuiltinOptions2Union() : type(BuiltinOptions2_NONE), value(nullptr) {}
+  BuiltinOptions2Union(BuiltinOptions2Union&& u) FLATBUFFERS_NOEXCEPT :
+    type(BuiltinOptions2_NONE), value(nullptr)
+    {
std::swap(type, u.type); std::swap(value, u.value); }
+  BuiltinOptions2Union(const BuiltinOptions2Union &);
+  BuiltinOptions2Union &operator=(const BuiltinOptions2Union &u)
+    { BuiltinOptions2Union t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; }
+  BuiltinOptions2Union &operator=(BuiltinOptions2Union &&u) FLATBUFFERS_NOEXCEPT
+    { std::swap(type, u.type); std::swap(value, u.value); return *this; }
+  ~BuiltinOptions2Union() { Reset(); }
+
+  void Reset();
+
+  template <typename T>
+  void Set(T&& val) {
+    typedef typename std::remove_reference<T>::type RT;
+    Reset();
+    type = BuiltinOptions2UnionTraits<RT>::enum_value;
+    if (type != BuiltinOptions2_NONE) {
+      value = new RT(std::forward<T>(val));
+    }
+  }
+
+  static void *UnPack(const void *obj, BuiltinOptions2 type, const ::flatbuffers::resolver_function_t *resolver);
+  ::flatbuffers::Offset<void> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr) const;
+
+  tflite::StablehloConcatenateOptionsT *AsStablehloConcatenateOptions() {
+    return type == BuiltinOptions2_StablehloConcatenateOptions ?
+      reinterpret_cast<tflite::StablehloConcatenateOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloConcatenateOptionsT *AsStablehloConcatenateOptions() const {
+    return type == BuiltinOptions2_StablehloConcatenateOptions ?
+      reinterpret_cast<const tflite::StablehloConcatenateOptionsT *>(value) : nullptr;
+  }
+  tflite::StablehloBroadcastInDimOptionsT *AsStablehloBroadcastInDimOptions() {
+    return type == BuiltinOptions2_StablehloBroadcastInDimOptions ?
+      reinterpret_cast<tflite::StablehloBroadcastInDimOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloBroadcastInDimOptionsT *AsStablehloBroadcastInDimOptions() const {
+    return type == BuiltinOptions2_StablehloBroadcastInDimOptions ?
+      reinterpret_cast<const tflite::StablehloBroadcastInDimOptionsT *>(value) : nullptr;
+  }
+  tflite::StablehloSliceOptionsT *AsStablehloSliceOptions() {
+    return type == BuiltinOptions2_StablehloSliceOptions ?
+      reinterpret_cast<tflite::StablehloSliceOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloSliceOptionsT *AsStablehloSliceOptions() const {
+    return type == BuiltinOptions2_StablehloSliceOptions ?
+      reinterpret_cast<const tflite::StablehloSliceOptionsT *>(value) : nullptr;
+  }
+  tflite::StablehloConvolutionOptionsT *AsStablehloConvolutionOptions() {
+    return type == BuiltinOptions2_StablehloConvolutionOptions ?
+      reinterpret_cast<tflite::StablehloConvolutionOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloConvolutionOptionsT *AsStablehloConvolutionOptions() const {
+    return type == BuiltinOptions2_StablehloConvolutionOptions ?
+      reinterpret_cast<const tflite::StablehloConvolutionOptionsT *>(value) : nullptr;
+  }
+  tflite::StablehloCustomCallOptionsT *AsStablehloCustomCallOptions() {
+    return type == BuiltinOptions2_StablehloCustomCallOptions ?
+      reinterpret_cast<tflite::StablehloCustomCallOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloCustomCallOptionsT *AsStablehloCustomCallOptions() const {
+    return type == BuiltinOptions2_StablehloCustomCallOptions ?
+      reinterpret_cast<const tflite::StablehloCustomCallOptionsT *>(value) : nullptr;
+  }
+  tflite::StablehloReduceOptionsT *AsStablehloReduceOptions() {
+    return type == BuiltinOptions2_StablehloReduceOptions ?
+      reinterpret_cast<tflite::StablehloReduceOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloReduceOptionsT *AsStablehloReduceOptions() const {
+    return type == BuiltinOptions2_StablehloReduceOptions ?
+      reinterpret_cast<const tflite::StablehloReduceOptionsT *>(value) : nullptr;
+  }
+  tflite::StablehloScatterOptionsT *AsStablehloScatterOptions() {
+    return type == BuiltinOptions2_StablehloScatterOptions ?
+      reinterpret_cast<tflite::StablehloScatterOptionsT *>(value) : nullptr;
+  }
+  const tflite::StablehloScatterOptionsT *AsStablehloScatterOptions() const {
+    return type == BuiltinOptions2_StablehloScatterOptions ?
+ reinterpret_cast(value) : nullptr; + } + tflite::StablehloCompareOptionsT *AsStablehloCompareOptions() { + return type == BuiltinOptions2_StablehloCompareOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloCompareOptionsT *AsStablehloCompareOptions() const { + return type == BuiltinOptions2_StablehloCompareOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloDynamicSliceOptionsT *AsStablehloDynamicSliceOptions() { + return type == BuiltinOptions2_StablehloDynamicSliceOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloDynamicSliceOptionsT *AsStablehloDynamicSliceOptions() const { + return type == BuiltinOptions2_StablehloDynamicSliceOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloPadOptionsT *AsStablehloPadOptions() { + return type == BuiltinOptions2_StablehloPadOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloPadOptionsT *AsStablehloPadOptions() const { + return type == BuiltinOptions2_StablehloPadOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloIotaOptionsT *AsStablehloIotaOptions() { + return type == BuiltinOptions2_StablehloIotaOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloIotaOptionsT *AsStablehloIotaOptions() const { + return type == BuiltinOptions2_StablehloIotaOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloDotGeneralOptionsT *AsStablehloDotGeneralOptions() { + return type == BuiltinOptions2_StablehloDotGeneralOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloDotGeneralOptionsT *AsStablehloDotGeneralOptions() const { + return type == BuiltinOptions2_StablehloDotGeneralOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloReduceWindowOptionsT *AsStablehloReduceWindowOptions() { + return type == BuiltinOptions2_StablehloReduceWindowOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloReduceWindowOptionsT *AsStablehloReduceWindowOptions() const { + return type == BuiltinOptions2_StablehloReduceWindowOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloSortOptionsT *AsStablehloSortOptions() { + return type == BuiltinOptions2_StablehloSortOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloSortOptionsT *AsStablehloSortOptions() const { + return type == BuiltinOptions2_StablehloSortOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloWhileOptionsT *AsStablehloWhileOptions() { + return type == BuiltinOptions2_StablehloWhileOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloWhileOptionsT *AsStablehloWhileOptions() const { + return type == BuiltinOptions2_StablehloWhileOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloGatherOptionsT *AsStablehloGatherOptions() { + return type == BuiltinOptions2_StablehloGatherOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloGatherOptionsT *AsStablehloGatherOptions() const { + return type == BuiltinOptions2_StablehloGatherOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloTransposeOptionsT *AsStablehloTransposeOptions() { + return type == BuiltinOptions2_StablehloTransposeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloTransposeOptionsT *AsStablehloTransposeOptions() const { + return type == BuiltinOptions2_StablehloTransposeOptions ? 
+ reinterpret_cast(value) : nullptr; + } + tflite::DilateOptionsT *AsDilateOptions() { + return type == BuiltinOptions2_DilateOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::DilateOptionsT *AsDilateOptions() const { + return type == BuiltinOptions2_DilateOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloRngBitGeneratorOptionsT *AsStablehloRngBitGeneratorOptions() { + return type == BuiltinOptions2_StablehloRngBitGeneratorOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloRngBitGeneratorOptionsT *AsStablehloRngBitGeneratorOptions() const { + return type == BuiltinOptions2_StablehloRngBitGeneratorOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::ReduceWindowOptionsT *AsReduceWindowOptions() { + return type == BuiltinOptions2_ReduceWindowOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::ReduceWindowOptionsT *AsReduceWindowOptions() const { + return type == BuiltinOptions2_ReduceWindowOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StableHLOCompositeOptionsT *AsStableHLOCompositeOptions() { + return type == BuiltinOptions2_StableHLOCompositeOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StableHLOCompositeOptionsT *AsStableHLOCompositeOptions() const { + return type == BuiltinOptions2_StableHLOCompositeOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloShiftLeftOptionsT *AsStablehloShiftLeftOptions() { + return type == BuiltinOptions2_StablehloShiftLeftOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloShiftLeftOptionsT *AsStablehloShiftLeftOptions() const { + return type == BuiltinOptions2_StablehloShiftLeftOptions ? + reinterpret_cast(value) : nullptr; + } + tflite::StablehloCaseOptionsT *AsStablehloCaseOptions() { + return type == BuiltinOptions2_StablehloCaseOptions ? + reinterpret_cast(value) : nullptr; + } + const tflite::StablehloCaseOptionsT *AsStablehloCaseOptions() const { + return type == BuiltinOptions2_StablehloCaseOptions ? 
+ reinterpret_cast(value) : nullptr; + } +}; + +bool VerifyBuiltinOptions2(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions2 type); +bool VerifyBuiltinOptions2Vector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types); + +enum StablehloPrecisionConfig : uint32_t { + StablehloPrecisionConfig_DEFAULT = 0, + StablehloPrecisionConfig_HIGH = 1, + StablehloPrecisionConfig_HIGHEST = 2, + StablehloPrecisionConfig_MIN = StablehloPrecisionConfig_DEFAULT, + StablehloPrecisionConfig_MAX = StablehloPrecisionConfig_HIGHEST +}; + +inline const StablehloPrecisionConfig (&EnumValuesStablehloPrecisionConfig())[3] { + static const StablehloPrecisionConfig values[] = { + StablehloPrecisionConfig_DEFAULT, + StablehloPrecisionConfig_HIGH, + StablehloPrecisionConfig_HIGHEST + }; + return values; +} + +inline const char * const *EnumNamesStablehloPrecisionConfig() { + static const char * const names[4] = { + "DEFAULT", + "HIGH", + "HIGHEST", + nullptr + }; + return names; +} + +inline const char *EnumNameStablehloPrecisionConfig(StablehloPrecisionConfig e) { + if (::flatbuffers::IsOutRange(e, StablehloPrecisionConfig_DEFAULT, StablehloPrecisionConfig_HIGHEST)) return ""; + const size_t index = static_cast(e); + return EnumNamesStablehloPrecisionConfig()[index]; +} + +enum StablehloComparisonDirection : uint32_t { + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_EQ = 0, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_NE = 1, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_GE = 2, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_GT = 3, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_LE = 4, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_LT = 5, + StablehloComparisonDirection_MIN = StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_EQ, + StablehloComparisonDirection_MAX = StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_LT +}; + +inline const StablehloComparisonDirection (&EnumValuesStablehloComparisonDirection())[6] { + static const StablehloComparisonDirection values[] = { + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_EQ, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_NE, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_GE, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_GT, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_LE, + StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_LT + }; + return values; +} + +inline const char * const *EnumNamesStablehloComparisonDirection() { + static const char * const names[7] = { + "STABLEHLO_COMPARISON_DIRECTION_EQ", + "STABLEHLO_COMPARISON_DIRECTION_NE", + "STABLEHLO_COMPARISON_DIRECTION_GE", + "STABLEHLO_COMPARISON_DIRECTION_GT", + "STABLEHLO_COMPARISON_DIRECTION_LE", + "STABLEHLO_COMPARISON_DIRECTION_LT", + nullptr + }; + return names; +} + +inline const char *EnumNameStablehloComparisonDirection(StablehloComparisonDirection e) { + if (::flatbuffers::IsOutRange(e, StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_EQ, StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_LT)) return ""; + const size_t index = static_cast(e); + return EnumNamesStablehloComparisonDirection()[index]; +} + +enum StablehloComparisonType : uint32_t { + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE = 0, + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_FLOAT = 1, + 
StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER = 2, + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_SIGNED = 3, + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_UNSIGNED = 4, + StablehloComparisonType_MIN = StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE, + StablehloComparisonType_MAX = StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_UNSIGNED +}; + +inline const StablehloComparisonType (&EnumValuesStablehloComparisonType())[5] { + static const StablehloComparisonType values[] = { + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE, + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_FLOAT, + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER, + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_SIGNED, + StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_UNSIGNED + }; + return values; +} + +inline const char * const *EnumNamesStablehloComparisonType() { + static const char * const names[6] = { + "STABLEHLO_COMPARISON_TYPE_NOTYPE", + "STABLEHLO_COMPARISON_TYPE_FLOAT", + "STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER", + "STABLEHLO_COMPARISON_TYPE_SIGNED", + "STABLEHLO_COMPARISON_TYPE_UNSIGNED", + nullptr + }; + return names; +} + +inline const char *EnumNameStablehloComparisonType(StablehloComparisonType e) { + if (::flatbuffers::IsOutRange(e, StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE, StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_UNSIGNED)) return ""; + const size_t index = static_cast(e); + return EnumNamesStablehloComparisonType()[index]; +} + +enum RngAlgorithm : int8_t { + RngAlgorithm_DEFAULT = 0, + RngAlgorithm_PHILOX = 1, + RngAlgorithm_THREEFRY = 2, + RngAlgorithm_MIN = RngAlgorithm_DEFAULT, + RngAlgorithm_MAX = RngAlgorithm_THREEFRY +}; + +inline const RngAlgorithm (&EnumValuesRngAlgorithm())[3] { + static const RngAlgorithm values[] = { + RngAlgorithm_DEFAULT, + RngAlgorithm_PHILOX, + RngAlgorithm_THREEFRY + }; + return values; +} + +inline const char * const *EnumNamesRngAlgorithm() { + static const char * const names[4] = { + "DEFAULT", + "PHILOX", + "THREEFRY", + nullptr + }; + return names; +} + +inline const char *EnumNameRngAlgorithm(RngAlgorithm e) { + if (::flatbuffers::IsOutRange(e, RngAlgorithm_DEFAULT, RngAlgorithm_THREEFRY)) return ""; + const size_t index = static_cast(e); + return EnumNamesRngAlgorithm()[index]; +} + +enum Padding : int8_t { + Padding_SAME = 0, + Padding_VALID = 1, + Padding_MIN = Padding_SAME, + Padding_MAX = Padding_VALID +}; + +inline const Padding (&EnumValuesPadding())[2] { + static const Padding values[] = { + Padding_SAME, + Padding_VALID + }; + return values; +} + +inline const char * const *EnumNamesPadding() { + static const char * const names[3] = { + "SAME", + "VALID", + nullptr + }; + return names; +} + +inline const char *EnumNamePadding(Padding e) { + if (::flatbuffers::IsOutRange(e, Padding_SAME, Padding_VALID)) return ""; + const size_t index = static_cast(e); + return EnumNamesPadding()[index]; +} + +enum ActivationFunctionType : int8_t { + ActivationFunctionType_NONE = 0, + ActivationFunctionType_RELU = 1, + ActivationFunctionType_RELU_N1_TO_1 = 2, + ActivationFunctionType_RELU6 = 3, + ActivationFunctionType_TANH = 4, + ActivationFunctionType_SIGN_BIT = 5, + ActivationFunctionType_MIN = ActivationFunctionType_NONE, + ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT +}; + +inline const ActivationFunctionType (&EnumValuesActivationFunctionType())[6] { + static const ActivationFunctionType values[] = { + ActivationFunctionType_NONE, + 
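// ---------------------------------------------------------------------------
// Illustrative sketch (not emitted by flatc): the EnumName*/EnumValues*
// helpers above map schema enums to their string names, e.g. for logging;
// out-of-range values return an empty string. Example values are assumptions.
//
//   const char *pad = tflite::EnumNamePadding(tflite::Padding_SAME);        // "SAME"
//   const char *act = tflite::EnumNameActivationFunctionType(
//       tflite::ActivationFunctionType_RELU6);                              // "RELU6"
//   const char *bad = tflite::EnumNamePadding(static_cast<tflite::Padding>(42)); // ""
// ---------------------------------------------------------------------------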
ActivationFunctionType_RELU, + ActivationFunctionType_RELU_N1_TO_1, + ActivationFunctionType_RELU6, + ActivationFunctionType_TANH, + ActivationFunctionType_SIGN_BIT + }; + return values; +} + +inline const char * const *EnumNamesActivationFunctionType() { + static const char * const names[7] = { + "NONE", + "RELU", + "RELU_N1_TO_1", + "RELU6", + "TANH", + "SIGN_BIT", + nullptr + }; + return names; +} + +inline const char *EnumNameActivationFunctionType(ActivationFunctionType e) { + if (::flatbuffers::IsOutRange(e, ActivationFunctionType_NONE, ActivationFunctionType_SIGN_BIT)) return ""; + const size_t index = static_cast(e); + return EnumNamesActivationFunctionType()[index]; +} + +enum LSHProjectionType : int8_t { + LSHProjectionType_UNKNOWN = 0, + LSHProjectionType_SPARSE = 1, + LSHProjectionType_DENSE = 2, + LSHProjectionType_MIN = LSHProjectionType_UNKNOWN, + LSHProjectionType_MAX = LSHProjectionType_DENSE +}; + +inline const LSHProjectionType (&EnumValuesLSHProjectionType())[3] { + static const LSHProjectionType values[] = { + LSHProjectionType_UNKNOWN, + LSHProjectionType_SPARSE, + LSHProjectionType_DENSE + }; + return values; +} + +inline const char * const *EnumNamesLSHProjectionType() { + static const char * const names[4] = { + "UNKNOWN", + "SPARSE", + "DENSE", + nullptr + }; + return names; +} + +inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { + if (::flatbuffers::IsOutRange(e, LSHProjectionType_UNKNOWN, LSHProjectionType_DENSE)) return ""; + const size_t index = static_cast(e); + return EnumNamesLSHProjectionType()[index]; +} + +enum FullyConnectedOptionsWeightsFormat : int8_t { + FullyConnectedOptionsWeightsFormat_DEFAULT = 0, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 = 1, + FullyConnectedOptionsWeightsFormat_MIN = FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 +}; + +inline const FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] { + static const FullyConnectedOptionsWeightsFormat values[] = { + FullyConnectedOptionsWeightsFormat_DEFAULT, + FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 + }; + return values; +} + +inline const char * const *EnumNamesFullyConnectedOptionsWeightsFormat() { + static const char * const names[3] = { + "DEFAULT", + "SHUFFLED4x16INT8", + nullptr + }; + return names; +} + +inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e) { + if (::flatbuffers::IsOutRange(e, FullyConnectedOptionsWeightsFormat_DEFAULT, FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8)) return ""; + const size_t index = static_cast(e); + return EnumNamesFullyConnectedOptionsWeightsFormat()[index]; +} + +enum LSTMKernelType : int8_t { + LSTMKernelType_FULL = 0, + LSTMKernelType_BASIC = 1, + LSTMKernelType_MIN = LSTMKernelType_FULL, + LSTMKernelType_MAX = LSTMKernelType_BASIC +}; + +inline const LSTMKernelType (&EnumValuesLSTMKernelType())[2] { + static const LSTMKernelType values[] = { + LSTMKernelType_FULL, + LSTMKernelType_BASIC + }; + return values; +} + +inline const char * const *EnumNamesLSTMKernelType() { + static const char * const names[3] = { + "FULL", + "BASIC", + nullptr + }; + return names; +} + +inline const char *EnumNameLSTMKernelType(LSTMKernelType e) { + if (::flatbuffers::IsOutRange(e, LSTMKernelType_FULL, LSTMKernelType_BASIC)) return ""; + const size_t index = static_cast(e); + return EnumNamesLSTMKernelType()[index]; +} + +enum CombinerType : int8_t { 
+  CombinerType_SUM = 0,
+  CombinerType_MEAN = 1,
+  CombinerType_SQRTN = 2,
+  CombinerType_MIN = CombinerType_SUM,
+  CombinerType_MAX = CombinerType_SQRTN
+};
+
+inline const CombinerType (&EnumValuesCombinerType())[3] {
+  static const CombinerType values[] = {
+    CombinerType_SUM,
+    CombinerType_MEAN,
+    CombinerType_SQRTN
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesCombinerType() {
+  static const char * const names[4] = {
+    "SUM",
+    "MEAN",
+    "SQRTN",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameCombinerType(CombinerType e) {
+  if (::flatbuffers::IsOutRange(e, CombinerType_SUM, CombinerType_SQRTN)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesCombinerType()[index];
+}
+
+enum MirrorPadMode : int8_t {
+  MirrorPadMode_REFLECT = 0,
+  MirrorPadMode_SYMMETRIC = 1,
+  MirrorPadMode_MIN = MirrorPadMode_REFLECT,
+  MirrorPadMode_MAX = MirrorPadMode_SYMMETRIC
+};
+
+inline const MirrorPadMode (&EnumValuesMirrorPadMode())[2] {
+  static const MirrorPadMode values[] = {
+    MirrorPadMode_REFLECT,
+    MirrorPadMode_SYMMETRIC
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesMirrorPadMode() {
+  static const char * const names[3] = {
+    "REFLECT",
+    "SYMMETRIC",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameMirrorPadMode(MirrorPadMode e) {
+  if (::flatbuffers::IsOutRange(e, MirrorPadMode_REFLECT, MirrorPadMode_SYMMETRIC)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesMirrorPadMode()[index];
+}
+
+enum ReduceWindowFunction : int32_t {
+  ReduceWindowFunction_UNSUPPORTED = 0,
+  ReduceWindowFunction_ADD = 1,
+  ReduceWindowFunction_MUL = 2,
+  ReduceWindowFunction_MINIMUM = 3,
+  ReduceWindowFunction_MAXIMUM = 4,
+  ReduceWindowFunction_ALL = 5,
+  ReduceWindowFunction_ANY = 6,
+  ReduceWindowFunction_MIN = ReduceWindowFunction_UNSUPPORTED,
+  ReduceWindowFunction_MAX = ReduceWindowFunction_ANY
+};
+
+inline const ReduceWindowFunction (&EnumValuesReduceWindowFunction())[7] {
+  static const ReduceWindowFunction values[] = {
+    ReduceWindowFunction_UNSUPPORTED,
+    ReduceWindowFunction_ADD,
+    ReduceWindowFunction_MUL,
+    ReduceWindowFunction_MINIMUM,
+    ReduceWindowFunction_MAXIMUM,
+    ReduceWindowFunction_ALL,
+    ReduceWindowFunction_ANY
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesReduceWindowFunction() {
+  static const char * const names[8] = {
+    "UNSUPPORTED",
+    "ADD",
+    "MUL",
+    "MINIMUM",
+    "MAXIMUM",
+    "ALL",
+    "ANY",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameReduceWindowFunction(ReduceWindowFunction e) {
+  if (::flatbuffers::IsOutRange(e, ReduceWindowFunction_UNSUPPORTED, ReduceWindowFunction_ANY)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return EnumNamesReduceWindowFunction()[index];
+}
+
+enum CustomOptionsFormat : int8_t {
+  CustomOptionsFormat_FLEXBUFFERS = 0,
+  CustomOptionsFormat_MIN = CustomOptionsFormat_FLEXBUFFERS,
+  CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS
+};
+
+inline const CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] {
+  static const CustomOptionsFormat values[] = {
+    CustomOptionsFormat_FLEXBUFFERS
+  };
+  return values;
+}
+
+inline const char * const *EnumNamesCustomOptionsFormat() {
+  static const char * const names[2] = {
+    "FLEXBUFFERS",
+    nullptr
+  };
+  return names;
+}
+
+inline const char *EnumNameCustomOptionsFormat(CustomOptionsFormat e) {
+  if (::flatbuffers::IsOutRange(e, CustomOptionsFormat_FLEXBUFFERS, CustomOptionsFormat_FLEXBUFFERS)) return "";
+  const size_t index = static_cast<size_t>(e);
+  return
EnumNamesCustomOptionsFormat()[index]; +} + +struct CustomQuantizationT : public ::flatbuffers::NativeTable { + typedef CustomQuantization TableType; + std::vector custom{}; +}; + +struct CustomQuantization FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef CustomQuantizationT NativeTableType; + typedef CustomQuantizationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_CUSTOM = 4 + }; + const ::flatbuffers::Vector *custom() const { + return GetPointer *>(VT_CUSTOM); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_CUSTOM) && + verifier.VerifyVector(custom()) && + verifier.EndTable(); + } + CustomQuantizationT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CustomQuantizationT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CustomQuantizationBuilder { + typedef CustomQuantization Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_custom(::flatbuffers::Offset<::flatbuffers::Vector> custom) { + fbb_.AddOffset(CustomQuantization::VT_CUSTOM, custom); + } + explicit CustomQuantizationBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateCustomQuantization( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> custom = 0) { + CustomQuantizationBuilder builder_(_fbb); + builder_.add_custom(custom); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateCustomQuantizationDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *custom = nullptr) { + if (custom) { _fbb.ForceVectorAlignment(custom->size(), sizeof(uint8_t), 16); } + auto custom__ = custom ? 
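// ---------------------------------------------------------------------------
// Illustrative sketch (not emitted by flatc): packing a CustomQuantization
// table from a raw byte payload with the Direct helper defined here. The
// 16-byte ForceVectorAlignment call above is what keeps the custom buffer
// aligned inside the finished flatbuffer. The payload bytes are assumptions.
//
//   ::flatbuffers::FlatBufferBuilder fbb;
//   std::vector<uint8_t> payload = {0x01, 0x02, 0x03};
//   auto cq = tflite::CreateCustomQuantizationDirect(fbb, &payload);
//   fbb.Finish(cq);
// ---------------------------------------------------------------------------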
_fbb.CreateVector(*custom) : 0; + return tflite::CreateCustomQuantization( + _fbb, + custom__); +} + +::flatbuffers::Offset CreateCustomQuantization(::flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BlockwiseQuantizationT : public ::flatbuffers::NativeTable { + typedef BlockwiseQuantization TableType; + int32_t scales = 0; + int32_t zero_points = 0; + int32_t block_size = 0; +}; + +struct BlockwiseQuantization FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BlockwiseQuantizationT NativeTableType; + typedef BlockwiseQuantizationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SCALES = 4, + VT_ZERO_POINTS = 6, + VT_BLOCK_SIZE = 8 + }; + int32_t scales() const { + return GetField(VT_SCALES, 0); + } + int32_t zero_points() const { + return GetField(VT_ZERO_POINTS, 0); + } + int32_t block_size() const { + return GetField(VT_BLOCK_SIZE, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_SCALES, 4) && + VerifyField(verifier, VT_ZERO_POINTS, 4) && + VerifyField(verifier, VT_BLOCK_SIZE, 4) && + verifier.EndTable(); + } + BlockwiseQuantizationT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BlockwiseQuantizationT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BlockwiseQuantizationBuilder { + typedef BlockwiseQuantization Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_scales(int32_t scales) { + fbb_.AddElement(BlockwiseQuantization::VT_SCALES, scales, 0); + } + void add_zero_points(int32_t zero_points) { + fbb_.AddElement(BlockwiseQuantization::VT_ZERO_POINTS, zero_points, 0); + } + void add_block_size(int32_t block_size) { + fbb_.AddElement(BlockwiseQuantization::VT_BLOCK_SIZE, block_size, 0); + } + explicit BlockwiseQuantizationBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBlockwiseQuantization( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t scales = 0, + int32_t zero_points = 0, + int32_t block_size = 0) { + BlockwiseQuantizationBuilder builder_(_fbb); + builder_.add_block_size(block_size); + builder_.add_zero_points(zero_points); + builder_.add_scales(scales); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBlockwiseQuantization(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct QuantizationParametersT : public ::flatbuffers::NativeTable { + typedef QuantizationParameters TableType; + std::vector min{}; + std::vector max{}; + std::vector scale{}; + std::vector zero_point{}; + tflite::QuantizationDetailsUnion details{}; + int32_t quantized_dimension = 0; +}; + +struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef QuantizationParametersT NativeTableType; + typedef QuantizationParametersBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MIN = 4, 
+ VT_MAX = 6, + VT_SCALE = 8, + VT_ZERO_POINT = 10, + VT_DETAILS_TYPE = 12, + VT_DETAILS = 14, + VT_QUANTIZED_DIMENSION = 16 + }; + const ::flatbuffers::Vector *min() const { + return GetPointer *>(VT_MIN); + } + const ::flatbuffers::Vector *max() const { + return GetPointer *>(VT_MAX); + } + const ::flatbuffers::Vector *scale() const { + return GetPointer *>(VT_SCALE); + } + const ::flatbuffers::Vector *zero_point() const { + return GetPointer *>(VT_ZERO_POINT); + } + tflite::QuantizationDetails details_type() const { + return static_cast(GetField(VT_DETAILS_TYPE, 0)); + } + const void *details() const { + return GetPointer(VT_DETAILS); + } + template const T *details_as() const; + const tflite::CustomQuantization *details_as_CustomQuantization() const { + return details_type() == tflite::QuantizationDetails_CustomQuantization ? static_cast(details()) : nullptr; + } + const tflite::BlockwiseQuantization *details_as_BlockwiseQuantization() const { + return details_type() == tflite::QuantizationDetails_BlockwiseQuantization ? static_cast(details()) : nullptr; + } + int32_t quantized_dimension() const { + return GetField(VT_QUANTIZED_DIMENSION, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MIN) && + verifier.VerifyVector(min()) && + VerifyOffset(verifier, VT_MAX) && + verifier.VerifyVector(max()) && + VerifyOffset(verifier, VT_SCALE) && + verifier.VerifyVector(scale()) && + VerifyOffset(verifier, VT_ZERO_POINT) && + verifier.VerifyVector(zero_point()) && + VerifyField(verifier, VT_DETAILS_TYPE, 1) && + VerifyOffset(verifier, VT_DETAILS) && + VerifyQuantizationDetails(verifier, details(), details_type()) && + VerifyField(verifier, VT_QUANTIZED_DIMENSION, 4) && + verifier.EndTable(); + } + QuantizationParametersT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(QuantizationParametersT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +template<> inline const tflite::CustomQuantization *QuantizationParameters::details_as() const { + return details_as_CustomQuantization(); +} + +template<> inline const tflite::BlockwiseQuantization *QuantizationParameters::details_as() const { + return details_as_BlockwiseQuantization(); +} + +struct QuantizationParametersBuilder { + typedef QuantizationParameters Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_min(::flatbuffers::Offset<::flatbuffers::Vector> min) { + fbb_.AddOffset(QuantizationParameters::VT_MIN, min); + } + void add_max(::flatbuffers::Offset<::flatbuffers::Vector> max) { + fbb_.AddOffset(QuantizationParameters::VT_MAX, max); + } + void add_scale(::flatbuffers::Offset<::flatbuffers::Vector> scale) { + fbb_.AddOffset(QuantizationParameters::VT_SCALE, scale); + } + void add_zero_point(::flatbuffers::Offset<::flatbuffers::Vector> zero_point) { + fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point); + } + void add_details_type(tflite::QuantizationDetails details_type) { + fbb_.AddElement(QuantizationParameters::VT_DETAILS_TYPE, static_cast(details_type), 0); + } + void add_details(::flatbuffers::Offset details) { + fbb_.AddOffset(QuantizationParameters::VT_DETAILS, details); + } + void add_quantized_dimension(int32_t quantized_dimension) { + 
fbb_.AddElement(QuantizationParameters::VT_QUANTIZED_DIMENSION, quantized_dimension, 0); + } + explicit QuantizationParametersBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateQuantizationParameters( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> min = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> max = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> scale = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> zero_point = 0, + tflite::QuantizationDetails details_type = tflite::QuantizationDetails_NONE, + ::flatbuffers::Offset details = 0, + int32_t quantized_dimension = 0) { + QuantizationParametersBuilder builder_(_fbb); + builder_.add_quantized_dimension(quantized_dimension); + builder_.add_details(details); + builder_.add_zero_point(zero_point); + builder_.add_scale(scale); + builder_.add_max(max); + builder_.add_min(min); + builder_.add_details_type(details_type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateQuantizationParametersDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *min = nullptr, + const std::vector *max = nullptr, + const std::vector *scale = nullptr, + const std::vector *zero_point = nullptr, + tflite::QuantizationDetails details_type = tflite::QuantizationDetails_NONE, + ::flatbuffers::Offset details = 0, + int32_t quantized_dimension = 0) { + auto min__ = min ? _fbb.CreateVector(*min) : 0; + auto max__ = max ? _fbb.CreateVector(*max) : 0; + auto scale__ = scale ? _fbb.CreateVector(*scale) : 0; + auto zero_point__ = zero_point ? 
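// ---------------------------------------------------------------------------
// Illustrative sketch (not emitted by flatc): per-tensor affine quantization
// metadata via the Direct helper, leaving the optional details union empty.
// Element types (float scale, int64 zero_point) follow the tflite schema;
// the numeric values are assumptions.
//
//   ::flatbuffers::FlatBufferBuilder fbb;
//   std::vector<float>   scale      = {0.0078125f};
//   std::vector<int64_t> zero_point = {0};
//   auto qp = tflite::CreateQuantizationParametersDirect(
//       fbb, /*min=*/nullptr, /*max=*/nullptr, &scale, &zero_point,
//       tflite::QuantizationDetails_NONE, /*details=*/0,
//       /*quantized_dimension=*/0);
// ---------------------------------------------------------------------------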
_fbb.CreateVector(*zero_point) : 0; + return tflite::CreateQuantizationParameters( + _fbb, + min__, + max__, + scale__, + zero_point__, + details_type, + details, + quantized_dimension); +} + +::flatbuffers::Offset CreateQuantizationParameters(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Int32VectorT : public ::flatbuffers::NativeTable { + typedef Int32Vector TableType; + std::vector values{}; +}; + +struct Int32Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef Int32VectorT NativeTableType; + typedef Int32VectorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4 + }; + const ::flatbuffers::Vector *values() const { + return GetPointer *>(VT_VALUES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + verifier.EndTable(); + } + Int32VectorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Int32VectorT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Int32VectorBuilder { + typedef Int32Vector Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values(::flatbuffers::Offset<::flatbuffers::Vector> values) { + fbb_.AddOffset(Int32Vector::VT_VALUES, values); + } + explicit Int32VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateInt32Vector( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> values = 0) { + Int32VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateInt32VectorDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *values = nullptr) { + auto values__ = values ? 
_fbb.CreateVector(*values) : 0; + return tflite::CreateInt32Vector( + _fbb, + values__); +} + +::flatbuffers::Offset CreateInt32Vector(::flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Uint16VectorT : public ::flatbuffers::NativeTable { + typedef Uint16Vector TableType; + std::vector values{}; +}; + +struct Uint16Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef Uint16VectorT NativeTableType; + typedef Uint16VectorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4 + }; + const ::flatbuffers::Vector *values() const { + return GetPointer *>(VT_VALUES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + verifier.EndTable(); + } + Uint16VectorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Uint16VectorT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Uint16VectorBuilder { + typedef Uint16Vector Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values(::flatbuffers::Offset<::flatbuffers::Vector> values) { + fbb_.AddOffset(Uint16Vector::VT_VALUES, values); + } + explicit Uint16VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUint16Vector( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> values = 0) { + Uint16VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateUint16VectorDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *values = nullptr) { + if (values) { _fbb.ForceVectorAlignment(values->size(), sizeof(uint16_t), 4); } + auto values__ = values ? 
_fbb.CreateVector(*values) : 0; + return tflite::CreateUint16Vector( + _fbb, + values__); +} + +::flatbuffers::Offset CreateUint16Vector(::flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Uint8VectorT : public ::flatbuffers::NativeTable { + typedef Uint8Vector TableType; + std::vector values{}; +}; + +struct Uint8Vector FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef Uint8VectorT NativeTableType; + typedef Uint8VectorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES = 4 + }; + const ::flatbuffers::Vector *values() const { + return GetPointer *>(VT_VALUES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_VALUES) && + verifier.VerifyVector(values()) && + verifier.EndTable(); + } + Uint8VectorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Uint8VectorT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Uint8VectorBuilder { + typedef Uint8Vector Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values(::flatbuffers::Offset<::flatbuffers::Vector> values) { + fbb_.AddOffset(Uint8Vector::VT_VALUES, values); + } + explicit Uint8VectorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUint8Vector( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> values = 0) { + Uint8VectorBuilder builder_(_fbb); + builder_.add_values(values); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateUint8VectorDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *values = nullptr) { + if (values) { _fbb.ForceVectorAlignment(values->size(), sizeof(uint8_t), 4); } + auto values__ = values ? 
_fbb.CreateVector(*values) : 0; + return tflite::CreateUint8Vector( + _fbb, + values__); +} + +::flatbuffers::Offset CreateUint8Vector(::flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DimensionMetadataT : public ::flatbuffers::NativeTable { + typedef DimensionMetadata TableType; + tflite::DimensionType format = tflite::DimensionType_DENSE; + int32_t dense_size = 0; + tflite::SparseIndexVectorUnion array_segments{}; + tflite::SparseIndexVectorUnion array_indices{}; +}; + +struct DimensionMetadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DimensionMetadataT NativeTableType; + typedef DimensionMetadataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FORMAT = 4, + VT_DENSE_SIZE = 6, + VT_ARRAY_SEGMENTS_TYPE = 8, + VT_ARRAY_SEGMENTS = 10, + VT_ARRAY_INDICES_TYPE = 12, + VT_ARRAY_INDICES = 14 + }; + tflite::DimensionType format() const { + return static_cast(GetField(VT_FORMAT, 0)); + } + int32_t dense_size() const { + return GetField(VT_DENSE_SIZE, 0); + } + tflite::SparseIndexVector array_segments_type() const { + return static_cast(GetField(VT_ARRAY_SEGMENTS_TYPE, 0)); + } + const void *array_segments() const { + return GetPointer(VT_ARRAY_SEGMENTS); + } + template const T *array_segments_as() const; + const tflite::Int32Vector *array_segments_as_Int32Vector() const { + return array_segments_type() == tflite::SparseIndexVector_Int32Vector ? static_cast(array_segments()) : nullptr; + } + const tflite::Uint16Vector *array_segments_as_Uint16Vector() const { + return array_segments_type() == tflite::SparseIndexVector_Uint16Vector ? static_cast(array_segments()) : nullptr; + } + const tflite::Uint8Vector *array_segments_as_Uint8Vector() const { + return array_segments_type() == tflite::SparseIndexVector_Uint8Vector ? static_cast(array_segments()) : nullptr; + } + tflite::SparseIndexVector array_indices_type() const { + return static_cast(GetField(VT_ARRAY_INDICES_TYPE, 0)); + } + const void *array_indices() const { + return GetPointer(VT_ARRAY_INDICES); + } + template const T *array_indices_as() const; + const tflite::Int32Vector *array_indices_as_Int32Vector() const { + return array_indices_type() == tflite::SparseIndexVector_Int32Vector ? static_cast(array_indices()) : nullptr; + } + const tflite::Uint16Vector *array_indices_as_Uint16Vector() const { + return array_indices_type() == tflite::SparseIndexVector_Uint16Vector ? static_cast(array_indices()) : nullptr; + } + const tflite::Uint8Vector *array_indices_as_Uint8Vector() const { + return array_indices_type() == tflite::SparseIndexVector_Uint8Vector ? 
static_cast(array_indices()) : nullptr; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FORMAT, 1) && + VerifyField(verifier, VT_DENSE_SIZE, 4) && + VerifyField(verifier, VT_ARRAY_SEGMENTS_TYPE, 1) && + VerifyOffset(verifier, VT_ARRAY_SEGMENTS) && + VerifySparseIndexVector(verifier, array_segments(), array_segments_type()) && + VerifyField(verifier, VT_ARRAY_INDICES_TYPE, 1) && + VerifyOffset(verifier, VT_ARRAY_INDICES) && + VerifySparseIndexVector(verifier, array_indices(), array_indices_type()) && + verifier.EndTable(); + } + DimensionMetadataT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DimensionMetadataT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +template<> inline const tflite::Int32Vector *DimensionMetadata::array_segments_as() const { + return array_segments_as_Int32Vector(); +} + +template<> inline const tflite::Uint16Vector *DimensionMetadata::array_segments_as() const { + return array_segments_as_Uint16Vector(); +} + +template<> inline const tflite::Uint8Vector *DimensionMetadata::array_segments_as() const { + return array_segments_as_Uint8Vector(); +} + +template<> inline const tflite::Int32Vector *DimensionMetadata::array_indices_as() const { + return array_indices_as_Int32Vector(); +} + +template<> inline const tflite::Uint16Vector *DimensionMetadata::array_indices_as() const { + return array_indices_as_Uint16Vector(); +} + +template<> inline const tflite::Uint8Vector *DimensionMetadata::array_indices_as() const { + return array_indices_as_Uint8Vector(); +} + +struct DimensionMetadataBuilder { + typedef DimensionMetadata Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_format(tflite::DimensionType format) { + fbb_.AddElement(DimensionMetadata::VT_FORMAT, static_cast(format), 0); + } + void add_dense_size(int32_t dense_size) { + fbb_.AddElement(DimensionMetadata::VT_DENSE_SIZE, dense_size, 0); + } + void add_array_segments_type(tflite::SparseIndexVector array_segments_type) { + fbb_.AddElement(DimensionMetadata::VT_ARRAY_SEGMENTS_TYPE, static_cast(array_segments_type), 0); + } + void add_array_segments(::flatbuffers::Offset array_segments) { + fbb_.AddOffset(DimensionMetadata::VT_ARRAY_SEGMENTS, array_segments); + } + void add_array_indices_type(tflite::SparseIndexVector array_indices_type) { + fbb_.AddElement(DimensionMetadata::VT_ARRAY_INDICES_TYPE, static_cast(array_indices_type), 0); + } + void add_array_indices(::flatbuffers::Offset array_indices) { + fbb_.AddOffset(DimensionMetadata::VT_ARRAY_INDICES, array_indices); + } + explicit DimensionMetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDimensionMetadata( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::DimensionType format = tflite::DimensionType_DENSE, + int32_t dense_size = 0, + tflite::SparseIndexVector array_segments_type = tflite::SparseIndexVector_NONE, + ::flatbuffers::Offset array_segments = 0, + tflite::SparseIndexVector array_indices_type = tflite::SparseIndexVector_NONE, + ::flatbuffers::Offset 
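// ---------------------------------------------------------------------------
// Illustrative sketch (not emitted by flatc): a single dense dimension entry;
// a sparse dimension would instead populate array_segments/array_indices
// through the SparseIndexVector union (Int32Vector, Uint16Vector or
// Uint8Vector above). The dense_size value is an assumption.
//
//   ::flatbuffers::FlatBufferBuilder fbb;
//   auto dense_dim = tflite::CreateDimensionMetadata(
//       fbb, tflite::DimensionType_DENSE, /*dense_size=*/4);
// ---------------------------------------------------------------------------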
array_indices = 0) { + DimensionMetadataBuilder builder_(_fbb); + builder_.add_array_indices(array_indices); + builder_.add_array_segments(array_segments); + builder_.add_dense_size(dense_size); + builder_.add_array_indices_type(array_indices_type); + builder_.add_array_segments_type(array_segments_type); + builder_.add_format(format); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateDimensionMetadata(::flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SparsityParametersT : public ::flatbuffers::NativeTable { + typedef SparsityParameters TableType; + std::vector traversal_order{}; + std::vector block_map{}; + std::vector> dim_metadata{}; + SparsityParametersT() = default; + SparsityParametersT(const SparsityParametersT &o); + SparsityParametersT(SparsityParametersT&&) FLATBUFFERS_NOEXCEPT = default; + SparsityParametersT &operator=(SparsityParametersT o) FLATBUFFERS_NOEXCEPT; +}; + +struct SparsityParameters FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SparsityParametersT NativeTableType; + typedef SparsityParametersBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TRAVERSAL_ORDER = 4, + VT_BLOCK_MAP = 6, + VT_DIM_METADATA = 8 + }; + const ::flatbuffers::Vector *traversal_order() const { + return GetPointer *>(VT_TRAVERSAL_ORDER); + } + const ::flatbuffers::Vector *block_map() const { + return GetPointer *>(VT_BLOCK_MAP); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *dim_metadata() const { + return GetPointer> *>(VT_DIM_METADATA); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TRAVERSAL_ORDER) && + verifier.VerifyVector(traversal_order()) && + VerifyOffset(verifier, VT_BLOCK_MAP) && + verifier.VerifyVector(block_map()) && + VerifyOffset(verifier, VT_DIM_METADATA) && + verifier.VerifyVector(dim_metadata()) && + verifier.VerifyVectorOfTables(dim_metadata()) && + verifier.EndTable(); + } + SparsityParametersT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SparsityParametersT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityParametersT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparsityParametersBuilder { + typedef SparsityParameters Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_traversal_order(::flatbuffers::Offset<::flatbuffers::Vector> traversal_order) { + fbb_.AddOffset(SparsityParameters::VT_TRAVERSAL_ORDER, traversal_order); + } + void add_block_map(::flatbuffers::Offset<::flatbuffers::Vector> block_map) { + fbb_.AddOffset(SparsityParameters::VT_BLOCK_MAP, block_map); + } + void add_dim_metadata(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> dim_metadata) { + fbb_.AddOffset(SparsityParameters::VT_DIM_METADATA, dim_metadata); + } + explicit SparsityParametersBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSparsityParameters( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> traversal_order = 0, + 
::flatbuffers::Offset<::flatbuffers::Vector> block_map = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> dim_metadata = 0) { + SparsityParametersBuilder builder_(_fbb); + builder_.add_dim_metadata(dim_metadata); + builder_.add_block_map(block_map); + builder_.add_traversal_order(traversal_order); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateSparsityParametersDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *traversal_order = nullptr, + const std::vector *block_map = nullptr, + const std::vector<::flatbuffers::Offset> *dim_metadata = nullptr) { + auto traversal_order__ = traversal_order ? _fbb.CreateVector(*traversal_order) : 0; + auto block_map__ = block_map ? _fbb.CreateVector(*block_map) : 0; + auto dim_metadata__ = dim_metadata ? _fbb.CreateVector<::flatbuffers::Offset>(*dim_metadata) : 0; + return tflite::CreateSparsityParameters( + _fbb, + traversal_order__, + block_map__, + dim_metadata__); +} + +::flatbuffers::Offset CreateSparsityParameters(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityParametersT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct VariantSubTypeT : public ::flatbuffers::NativeTable { + typedef VariantSubType TableType; + std::vector shape{}; + tflite::TensorType type = tflite::TensorType_FLOAT32; + bool has_rank = false; +}; + +struct VariantSubType FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef VariantSubTypeT NativeTableType; + typedef VariantSubTypeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SHAPE = 4, + VT_TYPE = 6, + VT_HAS_RANK = 8 + }; + const ::flatbuffers::Vector *shape() const { + return GetPointer *>(VT_SHAPE); + } + tflite::TensorType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + bool has_rank() const { + return GetField(VT_HAS_RANK, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + VerifyField(verifier, VT_TYPE, 1) && + VerifyField(verifier, VT_HAS_RANK, 1) && + verifier.EndTable(); + } + VariantSubTypeT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(VariantSubTypeT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const VariantSubTypeT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct VariantSubTypeBuilder { + typedef VariantSubType Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_shape(::flatbuffers::Offset<::flatbuffers::Vector> shape) { + fbb_.AddOffset(VariantSubType::VT_SHAPE, shape); + } + void add_type(tflite::TensorType type) { + fbb_.AddElement(VariantSubType::VT_TYPE, static_cast(type), 0); + } + void add_has_rank(bool has_rank) { + fbb_.AddElement(VariantSubType::VT_HAS_RANK, static_cast(has_rank), 0); + } + explicit VariantSubTypeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateVariantSubType( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> shape = 0, + tflite::TensorType type = tflite::TensorType_FLOAT32, + bool has_rank = false) { + 
VariantSubTypeBuilder builder_(_fbb); + builder_.add_shape(shape); + builder_.add_has_rank(has_rank); + builder_.add_type(type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateVariantSubTypeDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *shape = nullptr, + tflite::TensorType type = tflite::TensorType_FLOAT32, + bool has_rank = false) { + auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; + return tflite::CreateVariantSubType( + _fbb, + shape__, + type, + has_rank); +} + +::flatbuffers::Offset CreateVariantSubType(::flatbuffers::FlatBufferBuilder &_fbb, const VariantSubTypeT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TensorT : public ::flatbuffers::NativeTable { + typedef Tensor TableType; + std::vector shape{}; + tflite::TensorType type = tflite::TensorType_FLOAT32; + uint32_t buffer = 0; + std::string name{}; + std::unique_ptr quantization{}; + bool is_variable = false; + std::unique_ptr sparsity{}; + std::vector shape_signature{}; + bool has_rank = false; + std::vector> variant_tensors{}; + TensorT() = default; + TensorT(const TensorT &o); + TensorT(TensorT&&) FLATBUFFERS_NOEXCEPT = default; + TensorT &operator=(TensorT o) FLATBUFFERS_NOEXCEPT; +}; + +struct Tensor FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef TensorT NativeTableType; + typedef TensorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SHAPE = 4, + VT_TYPE = 6, + VT_BUFFER = 8, + VT_NAME = 10, + VT_QUANTIZATION = 12, + VT_IS_VARIABLE = 14, + VT_SPARSITY = 16, + VT_SHAPE_SIGNATURE = 18, + VT_HAS_RANK = 20, + VT_VARIANT_TENSORS = 22 + }; + const ::flatbuffers::Vector *shape() const { + return GetPointer *>(VT_SHAPE); + } + tflite::TensorType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + uint32_t buffer() const { + return GetField(VT_BUFFER, 0); + } + const ::flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const tflite::QuantizationParameters *quantization() const { + return GetPointer(VT_QUANTIZATION); + } + bool is_variable() const { + return GetField(VT_IS_VARIABLE, 0) != 0; + } + const tflite::SparsityParameters *sparsity() const { + return GetPointer(VT_SPARSITY); + } + const ::flatbuffers::Vector *shape_signature() const { + return GetPointer *>(VT_SHAPE_SIGNATURE); + } + bool has_rank() const { + return GetField(VT_HAS_RANK, 0) != 0; + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *variant_tensors() const { + return GetPointer> *>(VT_VARIANT_TENSORS); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.VerifyVector(shape()) && + VerifyField(verifier, VT_TYPE, 1) && + VerifyField(verifier, VT_BUFFER, 4) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_QUANTIZATION) && + verifier.VerifyTable(quantization()) && + VerifyField(verifier, VT_IS_VARIABLE, 1) && + VerifyOffset(verifier, VT_SPARSITY) && + verifier.VerifyTable(sparsity()) && + VerifyOffset(verifier, VT_SHAPE_SIGNATURE) && + verifier.VerifyVector(shape_signature()) && + VerifyField(verifier, VT_HAS_RANK, 1) && + VerifyOffset(verifier, VT_VARIANT_TENSORS) && + verifier.VerifyVector(variant_tensors()) && + verifier.VerifyVectorOfTables(variant_tensors()) && + verifier.EndTable(); + } + TensorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TensorT *_o, const 
::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TensorBuilder { + typedef Tensor Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_shape(::flatbuffers::Offset<::flatbuffers::Vector> shape) { + fbb_.AddOffset(Tensor::VT_SHAPE, shape); + } + void add_type(tflite::TensorType type) { + fbb_.AddElement(Tensor::VT_TYPE, static_cast(type), 0); + } + void add_buffer(uint32_t buffer) { + fbb_.AddElement(Tensor::VT_BUFFER, buffer, 0); + } + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(Tensor::VT_NAME, name); + } + void add_quantization(::flatbuffers::Offset quantization) { + fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); + } + void add_is_variable(bool is_variable) { + fbb_.AddElement(Tensor::VT_IS_VARIABLE, static_cast(is_variable), 0); + } + void add_sparsity(::flatbuffers::Offset sparsity) { + fbb_.AddOffset(Tensor::VT_SPARSITY, sparsity); + } + void add_shape_signature(::flatbuffers::Offset<::flatbuffers::Vector> shape_signature) { + fbb_.AddOffset(Tensor::VT_SHAPE_SIGNATURE, shape_signature); + } + void add_has_rank(bool has_rank) { + fbb_.AddElement(Tensor::VT_HAS_RANK, static_cast(has_rank), 0); + } + void add_variant_tensors(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors) { + fbb_.AddOffset(Tensor::VT_VARIANT_TENSORS, variant_tensors); + } + explicit TensorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateTensor( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> shape = 0, + tflite::TensorType type = tflite::TensorType_FLOAT32, + uint32_t buffer = 0, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + ::flatbuffers::Offset quantization = 0, + bool is_variable = false, + ::flatbuffers::Offset sparsity = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> shape_signature = 0, + bool has_rank = false, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> variant_tensors = 0) { + TensorBuilder builder_(_fbb); + builder_.add_variant_tensors(variant_tensors); + builder_.add_shape_signature(shape_signature); + builder_.add_sparsity(sparsity); + builder_.add_quantization(quantization); + builder_.add_name(name); + builder_.add_buffer(buffer); + builder_.add_shape(shape); + builder_.add_has_rank(has_rank); + builder_.add_is_variable(is_variable); + builder_.add_type(type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateTensorDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *shape = nullptr, + tflite::TensorType type = tflite::TensorType_FLOAT32, + uint32_t buffer = 0, + const char *name = nullptr, + ::flatbuffers::Offset quantization = 0, + bool is_variable = false, + ::flatbuffers::Offset sparsity = 0, + const std::vector *shape_signature = nullptr, + bool has_rank = false, + const std::vector<::flatbuffers::Offset> *variant_tensors = nullptr) { + auto shape__ = shape ? _fbb.CreateVector(*shape) : 0; + auto name__ = name ? _fbb.CreateString(name) : 0; + auto shape_signature__ = shape_signature ? 
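// ---------------------------------------------------------------------------
// Illustrative sketch (not emitted by flatc): a minimal float tensor built
// with the Direct helper; buffer 0 conventionally refers to the empty
// sentinel buffer, and quantization/sparsity stay unset. The shape and
// tensor name are assumptions.
//
//   ::flatbuffers::FlatBufferBuilder fbb;
//   std::vector<int32_t> shape = {1, 224, 224, 3};
//   auto tensor = tflite::CreateTensorDirect(
//       fbb, &shape, tflite::TensorType_FLOAT32, /*buffer=*/0, "input");
// ---------------------------------------------------------------------------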
_fbb.CreateVector(*shape_signature) : 0; + auto variant_tensors__ = variant_tensors ? _fbb.CreateVector<::flatbuffers::Offset>(*variant_tensors) : 0; + return tflite::CreateTensor( + _fbb, + shape__, + type, + buffer, + name__, + quantization, + is_variable, + sparsity, + shape_signature__, + has_rank, + variant_tensors__); +} + +::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloGatherOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloGatherOptions TableType; + std::vector offset_dims{}; + std::vector collapsed_slice_dims{}; + std::vector start_index_map{}; + int64_t index_vector_dim = 0; + std::vector slice_sizes{}; + bool indices_are_sorted = false; +}; + +struct StablehloGatherOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloGatherOptionsT NativeTableType; + typedef StablehloGatherOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OFFSET_DIMS = 4, + VT_COLLAPSED_SLICE_DIMS = 6, + VT_START_INDEX_MAP = 8, + VT_INDEX_VECTOR_DIM = 10, + VT_SLICE_SIZES = 12, + VT_INDICES_ARE_SORTED = 14 + }; + const ::flatbuffers::Vector *offset_dims() const { + return GetPointer *>(VT_OFFSET_DIMS); + } + const ::flatbuffers::Vector *collapsed_slice_dims() const { + return GetPointer *>(VT_COLLAPSED_SLICE_DIMS); + } + const ::flatbuffers::Vector *start_index_map() const { + return GetPointer *>(VT_START_INDEX_MAP); + } + int64_t index_vector_dim() const { + return GetField(VT_INDEX_VECTOR_DIM, 0); + } + const ::flatbuffers::Vector *slice_sizes() const { + return GetPointer *>(VT_SLICE_SIZES); + } + bool indices_are_sorted() const { + return GetField(VT_INDICES_ARE_SORTED, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_OFFSET_DIMS) && + verifier.VerifyVector(offset_dims()) && + VerifyOffset(verifier, VT_COLLAPSED_SLICE_DIMS) && + verifier.VerifyVector(collapsed_slice_dims()) && + VerifyOffset(verifier, VT_START_INDEX_MAP) && + verifier.VerifyVector(start_index_map()) && + VerifyField(verifier, VT_INDEX_VECTOR_DIM, 8) && + VerifyOffset(verifier, VT_SLICE_SIZES) && + verifier.VerifyVector(slice_sizes()) && + VerifyField(verifier, VT_INDICES_ARE_SORTED, 1) && + verifier.EndTable(); + } + StablehloGatherOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloGatherOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloGatherOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloGatherOptionsBuilder { + typedef StablehloGatherOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_offset_dims(::flatbuffers::Offset<::flatbuffers::Vector> offset_dims) { + fbb_.AddOffset(StablehloGatherOptions::VT_OFFSET_DIMS, offset_dims); + } + void add_collapsed_slice_dims(::flatbuffers::Offset<::flatbuffers::Vector> collapsed_slice_dims) { + fbb_.AddOffset(StablehloGatherOptions::VT_COLLAPSED_SLICE_DIMS, collapsed_slice_dims); + } + void add_start_index_map(::flatbuffers::Offset<::flatbuffers::Vector> start_index_map) { + fbb_.AddOffset(StablehloGatherOptions::VT_START_INDEX_MAP, start_index_map); + } + void add_index_vector_dim(int64_t index_vector_dim) { + 
fbb_.AddElement(StablehloGatherOptions::VT_INDEX_VECTOR_DIM, index_vector_dim, 0); + } + void add_slice_sizes(::flatbuffers::Offset<::flatbuffers::Vector> slice_sizes) { + fbb_.AddOffset(StablehloGatherOptions::VT_SLICE_SIZES, slice_sizes); + } + void add_indices_are_sorted(bool indices_are_sorted) { + fbb_.AddElement(StablehloGatherOptions::VT_INDICES_ARE_SORTED, static_cast(indices_are_sorted), 0); + } + explicit StablehloGatherOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloGatherOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> offset_dims = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> collapsed_slice_dims = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> start_index_map = 0, + int64_t index_vector_dim = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> slice_sizes = 0, + bool indices_are_sorted = false) { + StablehloGatherOptionsBuilder builder_(_fbb); + builder_.add_index_vector_dim(index_vector_dim); + builder_.add_slice_sizes(slice_sizes); + builder_.add_start_index_map(start_index_map); + builder_.add_collapsed_slice_dims(collapsed_slice_dims); + builder_.add_offset_dims(offset_dims); + builder_.add_indices_are_sorted(indices_are_sorted); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloGatherOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *offset_dims = nullptr, + const std::vector *collapsed_slice_dims = nullptr, + const std::vector *start_index_map = nullptr, + int64_t index_vector_dim = 0, + const std::vector *slice_sizes = nullptr, + bool indices_are_sorted = false) { + auto offset_dims__ = offset_dims ? _fbb.CreateVector(*offset_dims) : 0; + auto collapsed_slice_dims__ = collapsed_slice_dims ? _fbb.CreateVector(*collapsed_slice_dims) : 0; + auto start_index_map__ = start_index_map ? _fbb.CreateVector(*start_index_map) : 0; + auto slice_sizes__ = slice_sizes ? 
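// ---------------------------------------------------------------------------
// Illustrative sketch (not emitted by flatc): gather attributes mirroring
// stablehlo.gather, passed as int64 dimension vectors through the Direct
// helper. The dimension numbers are assumptions.
//
//   ::flatbuffers::FlatBufferBuilder fbb;
//   std::vector<int64_t> offset_dims = {1}, collapsed = {0}, start_map = {0};
//   std::vector<int64_t> slice_sizes = {1, 8};
//   auto gather = tflite::CreateStablehloGatherOptionsDirect(
//       fbb, &offset_dims, &collapsed, &start_map,
//       /*index_vector_dim=*/1, &slice_sizes, /*indices_are_sorted=*/false);
// ---------------------------------------------------------------------------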
_fbb.CreateVector(*slice_sizes) : 0; + return tflite::CreateStablehloGatherOptions( + _fbb, + offset_dims__, + collapsed_slice_dims__, + start_index_map__, + index_vector_dim, + slice_sizes__, + indices_are_sorted); +} + +::flatbuffers::Offset CreateStablehloGatherOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloGatherOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloTransposeOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloTransposeOptions TableType; + std::vector permutation{}; +}; + +struct StablehloTransposeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloTransposeOptionsT NativeTableType; + typedef StablehloTransposeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PERMUTATION = 4 + }; + const ::flatbuffers::Vector *permutation() const { + return GetPointer *>(VT_PERMUTATION); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PERMUTATION) && + verifier.VerifyVector(permutation()) && + verifier.EndTable(); + } + StablehloTransposeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloTransposeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloTransposeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloTransposeOptionsBuilder { + typedef StablehloTransposeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_permutation(::flatbuffers::Offset<::flatbuffers::Vector> permutation) { + fbb_.AddOffset(StablehloTransposeOptions::VT_PERMUTATION, permutation); + } + explicit StablehloTransposeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloTransposeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> permutation = 0) { + StablehloTransposeOptionsBuilder builder_(_fbb); + builder_.add_permutation(permutation); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloTransposeOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *permutation = nullptr) { + auto permutation__ = permutation ? 
_fbb.CreateVector(*permutation) : 0; + return tflite::CreateStablehloTransposeOptions( + _fbb, + permutation__); +} + +::flatbuffers::Offset CreateStablehloTransposeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloTransposeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloDotGeneralOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloDotGeneralOptions TableType; + std::vector lhs_batching_dimensions{}; + std::vector rhs_batching_dimensions{}; + std::vector lhs_contracting_dimensions{}; + std::vector rhs_contracting_dimensions{}; + std::vector precision_config{}; +}; + +struct StablehloDotGeneralOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloDotGeneralOptionsT NativeTableType; + typedef StablehloDotGeneralOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_LHS_BATCHING_DIMENSIONS = 4, + VT_RHS_BATCHING_DIMENSIONS = 6, + VT_LHS_CONTRACTING_DIMENSIONS = 8, + VT_RHS_CONTRACTING_DIMENSIONS = 10, + VT_PRECISION_CONFIG = 12 + }; + const ::flatbuffers::Vector *lhs_batching_dimensions() const { + return GetPointer *>(VT_LHS_BATCHING_DIMENSIONS); + } + const ::flatbuffers::Vector *rhs_batching_dimensions() const { + return GetPointer *>(VT_RHS_BATCHING_DIMENSIONS); + } + const ::flatbuffers::Vector *lhs_contracting_dimensions() const { + return GetPointer *>(VT_LHS_CONTRACTING_DIMENSIONS); + } + const ::flatbuffers::Vector *rhs_contracting_dimensions() const { + return GetPointer *>(VT_RHS_CONTRACTING_DIMENSIONS); + } + const ::flatbuffers::Vector *precision_config() const { + return GetPointer *>(VT_PRECISION_CONFIG); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_LHS_BATCHING_DIMENSIONS) && + verifier.VerifyVector(lhs_batching_dimensions()) && + VerifyOffset(verifier, VT_RHS_BATCHING_DIMENSIONS) && + verifier.VerifyVector(rhs_batching_dimensions()) && + VerifyOffset(verifier, VT_LHS_CONTRACTING_DIMENSIONS) && + verifier.VerifyVector(lhs_contracting_dimensions()) && + VerifyOffset(verifier, VT_RHS_CONTRACTING_DIMENSIONS) && + verifier.VerifyVector(rhs_contracting_dimensions()) && + VerifyOffset(verifier, VT_PRECISION_CONFIG) && + verifier.VerifyVector(precision_config()) && + verifier.EndTable(); + } + StablehloDotGeneralOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloDotGeneralOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDotGeneralOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloDotGeneralOptionsBuilder { + typedef StablehloDotGeneralOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_lhs_batching_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> lhs_batching_dimensions) { + fbb_.AddOffset(StablehloDotGeneralOptions::VT_LHS_BATCHING_DIMENSIONS, lhs_batching_dimensions); + } + void add_rhs_batching_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> rhs_batching_dimensions) { + fbb_.AddOffset(StablehloDotGeneralOptions::VT_RHS_BATCHING_DIMENSIONS, rhs_batching_dimensions); + } + void add_lhs_contracting_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> lhs_contracting_dimensions) { + fbb_.AddOffset(StablehloDotGeneralOptions::VT_LHS_CONTRACTING_DIMENSIONS, 
lhs_contracting_dimensions); + } + void add_rhs_contracting_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> rhs_contracting_dimensions) { + fbb_.AddOffset(StablehloDotGeneralOptions::VT_RHS_CONTRACTING_DIMENSIONS, rhs_contracting_dimensions); + } + void add_precision_config(::flatbuffers::Offset<::flatbuffers::Vector> precision_config) { + fbb_.AddOffset(StablehloDotGeneralOptions::VT_PRECISION_CONFIG, precision_config); + } + explicit StablehloDotGeneralOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloDotGeneralOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> lhs_batching_dimensions = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> rhs_batching_dimensions = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> lhs_contracting_dimensions = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> rhs_contracting_dimensions = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> precision_config = 0) { + StablehloDotGeneralOptionsBuilder builder_(_fbb); + builder_.add_precision_config(precision_config); + builder_.add_rhs_contracting_dimensions(rhs_contracting_dimensions); + builder_.add_lhs_contracting_dimensions(lhs_contracting_dimensions); + builder_.add_rhs_batching_dimensions(rhs_batching_dimensions); + builder_.add_lhs_batching_dimensions(lhs_batching_dimensions); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloDotGeneralOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *lhs_batching_dimensions = nullptr, + const std::vector *rhs_batching_dimensions = nullptr, + const std::vector *lhs_contracting_dimensions = nullptr, + const std::vector *rhs_contracting_dimensions = nullptr, + const std::vector *precision_config = nullptr) { + auto lhs_batching_dimensions__ = lhs_batching_dimensions ? _fbb.CreateVector(*lhs_batching_dimensions) : 0; + auto rhs_batching_dimensions__ = rhs_batching_dimensions ? _fbb.CreateVector(*rhs_batching_dimensions) : 0; + auto lhs_contracting_dimensions__ = lhs_contracting_dimensions ? _fbb.CreateVector(*lhs_contracting_dimensions) : 0; + auto rhs_contracting_dimensions__ = rhs_contracting_dimensions ? _fbb.CreateVector(*rhs_contracting_dimensions) : 0; + auto precision_config__ = precision_config ? 
_fbb.CreateVector(*precision_config) : 0; + return tflite::CreateStablehloDotGeneralOptions( + _fbb, + lhs_batching_dimensions__, + rhs_batching_dimensions__, + lhs_contracting_dimensions__, + rhs_contracting_dimensions__, + precision_config__); +} + +::flatbuffers::Offset CreateStablehloDotGeneralOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDotGeneralOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloReduceWindowOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloReduceWindowOptions TableType; + std::vector window_dimensions{}; + std::vector window_strides{}; + std::vector base_dilations{}; + std::vector window_dilations{}; + std::vector padding{}; + int32_t body_subgraph_index = 0; +}; + +struct StablehloReduceWindowOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloReduceWindowOptionsT NativeTableType; + typedef StablehloReduceWindowOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_WINDOW_DIMENSIONS = 4, + VT_WINDOW_STRIDES = 6, + VT_BASE_DILATIONS = 8, + VT_WINDOW_DILATIONS = 10, + VT_PADDING = 12, + VT_BODY_SUBGRAPH_INDEX = 14 + }; + const ::flatbuffers::Vector *window_dimensions() const { + return GetPointer *>(VT_WINDOW_DIMENSIONS); + } + const ::flatbuffers::Vector *window_strides() const { + return GetPointer *>(VT_WINDOW_STRIDES); + } + const ::flatbuffers::Vector *base_dilations() const { + return GetPointer *>(VT_BASE_DILATIONS); + } + const ::flatbuffers::Vector *window_dilations() const { + return GetPointer *>(VT_WINDOW_DILATIONS); + } + const ::flatbuffers::Vector *padding() const { + return GetPointer *>(VT_PADDING); + } + int32_t body_subgraph_index() const { + return GetField(VT_BODY_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_WINDOW_DIMENSIONS) && + verifier.VerifyVector(window_dimensions()) && + VerifyOffset(verifier, VT_WINDOW_STRIDES) && + verifier.VerifyVector(window_strides()) && + VerifyOffset(verifier, VT_BASE_DILATIONS) && + verifier.VerifyVector(base_dilations()) && + VerifyOffset(verifier, VT_WINDOW_DILATIONS) && + verifier.VerifyVector(window_dilations()) && + VerifyOffset(verifier, VT_PADDING) && + verifier.VerifyVector(padding()) && + VerifyField(verifier, VT_BODY_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + StablehloReduceWindowOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloReduceWindowOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceWindowOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloReduceWindowOptionsBuilder { + typedef StablehloReduceWindowOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_window_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> window_dimensions) { + fbb_.AddOffset(StablehloReduceWindowOptions::VT_WINDOW_DIMENSIONS, window_dimensions); + } + void add_window_strides(::flatbuffers::Offset<::flatbuffers::Vector> window_strides) { + fbb_.AddOffset(StablehloReduceWindowOptions::VT_WINDOW_STRIDES, window_strides); + } + void add_base_dilations(::flatbuffers::Offset<::flatbuffers::Vector> base_dilations) { + fbb_.AddOffset(StablehloReduceWindowOptions::VT_BASE_DILATIONS, 
base_dilations); + } + void add_window_dilations(::flatbuffers::Offset<::flatbuffers::Vector> window_dilations) { + fbb_.AddOffset(StablehloReduceWindowOptions::VT_WINDOW_DILATIONS, window_dilations); + } + void add_padding(::flatbuffers::Offset<::flatbuffers::Vector> padding) { + fbb_.AddOffset(StablehloReduceWindowOptions::VT_PADDING, padding); + } + void add_body_subgraph_index(int32_t body_subgraph_index) { + fbb_.AddElement(StablehloReduceWindowOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0); + } + explicit StablehloReduceWindowOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloReduceWindowOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> window_dimensions = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> window_strides = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> base_dilations = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> window_dilations = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> padding = 0, + int32_t body_subgraph_index = 0) { + StablehloReduceWindowOptionsBuilder builder_(_fbb); + builder_.add_body_subgraph_index(body_subgraph_index); + builder_.add_padding(padding); + builder_.add_window_dilations(window_dilations); + builder_.add_base_dilations(base_dilations); + builder_.add_window_strides(window_strides); + builder_.add_window_dimensions(window_dimensions); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloReduceWindowOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *window_dimensions = nullptr, + const std::vector *window_strides = nullptr, + const std::vector *base_dilations = nullptr, + const std::vector *window_dilations = nullptr, + const std::vector *padding = nullptr, + int32_t body_subgraph_index = 0) { + auto window_dimensions__ = window_dimensions ? _fbb.CreateVector(*window_dimensions) : 0; + auto window_strides__ = window_strides ? _fbb.CreateVector(*window_strides) : 0; + auto base_dilations__ = base_dilations ? _fbb.CreateVector(*base_dilations) : 0; + auto window_dilations__ = window_dilations ? _fbb.CreateVector(*window_dilations) : 0; + auto padding__ = padding ? 
_fbb.CreateVector(*padding) : 0; + return tflite::CreateStablehloReduceWindowOptions( + _fbb, + window_dimensions__, + window_strides__, + base_dilations__, + window_dilations__, + padding__, + body_subgraph_index); +} + +::flatbuffers::Offset CreateStablehloReduceWindowOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceWindowOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloWhileOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloWhileOptions TableType; + int32_t cond_subgraph_index = 0; + int32_t body_subgraph_index = 0; +}; + +struct StablehloWhileOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloWhileOptionsT NativeTableType; + typedef StablehloWhileOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COND_SUBGRAPH_INDEX = 4, + VT_BODY_SUBGRAPH_INDEX = 6 + }; + int32_t cond_subgraph_index() const { + return GetField(VT_COND_SUBGRAPH_INDEX, 0); + } + int32_t body_subgraph_index() const { + return GetField(VT_BODY_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_COND_SUBGRAPH_INDEX, 4) && + VerifyField(verifier, VT_BODY_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + StablehloWhileOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloWhileOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloWhileOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloWhileOptionsBuilder { + typedef StablehloWhileOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_cond_subgraph_index(int32_t cond_subgraph_index) { + fbb_.AddElement(StablehloWhileOptions::VT_COND_SUBGRAPH_INDEX, cond_subgraph_index, 0); + } + void add_body_subgraph_index(int32_t body_subgraph_index) { + fbb_.AddElement(StablehloWhileOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0); + } + explicit StablehloWhileOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloWhileOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t cond_subgraph_index = 0, + int32_t body_subgraph_index = 0) { + StablehloWhileOptionsBuilder builder_(_fbb); + builder_.add_body_subgraph_index(body_subgraph_index); + builder_.add_cond_subgraph_index(cond_subgraph_index); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateStablehloWhileOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloWhileOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloSortOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloSortOptions TableType; + int64_t dimension = 0; + bool is_stable = false; + int32_t comparator_subgraph_index = 0; +}; + +struct StablehloSortOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloSortOptionsT NativeTableType; + typedef StablehloSortOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DIMENSION = 4, + VT_IS_STABLE = 6, + 
VT_COMPARATOR_SUBGRAPH_INDEX = 8
+  };
+  int64_t dimension() const {
+    return GetField<int64_t>(VT_DIMENSION, 0);
+  }
+  bool is_stable() const {
+    return GetField<uint8_t>(VT_IS_STABLE, 0) != 0;
+  }
+  int32_t comparator_subgraph_index() const {
+    return GetField<int32_t>(VT_COMPARATOR_SUBGRAPH_INDEX, 0);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int64_t>(verifier, VT_DIMENSION, 8) &&
+           VerifyField<uint8_t>(verifier, VT_IS_STABLE, 1) &&
+           VerifyField<int32_t>(verifier, VT_COMPARATOR_SUBGRAPH_INDEX, 4) &&
+           verifier.EndTable();
+  }
+  StablehloSortOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(StablehloSortOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<StablehloSortOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSortOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct StablehloSortOptionsBuilder {
+  typedef StablehloSortOptions Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_dimension(int64_t dimension) {
+    fbb_.AddElement<int64_t>(StablehloSortOptions::VT_DIMENSION, dimension, 0);
+  }
+  void add_is_stable(bool is_stable) {
+    fbb_.AddElement<uint8_t>(StablehloSortOptions::VT_IS_STABLE, static_cast<uint8_t>(is_stable), 0);
+  }
+  void add_comparator_subgraph_index(int32_t comparator_subgraph_index) {
+    fbb_.AddElement<int32_t>(StablehloSortOptions::VT_COMPARATOR_SUBGRAPH_INDEX, comparator_subgraph_index, 0);
+  }
+  explicit StablehloSortOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<StablehloSortOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<StablehloSortOptions>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<StablehloSortOptions> CreateStablehloSortOptions(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    int64_t dimension = 0,
+    bool is_stable = false,
+    int32_t comparator_subgraph_index = 0) {
+  StablehloSortOptionsBuilder builder_(_fbb);
+  builder_.add_dimension(dimension);
+  builder_.add_comparator_subgraph_index(comparator_subgraph_index);
+  builder_.add_is_stable(is_stable);
+  return builder_.Finish();
+}
+
+::flatbuffers::Offset<StablehloSortOptions> CreateStablehloSortOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSortOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct StablehloConcatenateOptionsT : public ::flatbuffers::NativeTable {
+  typedef StablehloConcatenateOptions TableType;
+  int64_t dimension = 0;
+};
+
+struct StablehloConcatenateOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef StablehloConcatenateOptionsT NativeTableType;
+  typedef StablehloConcatenateOptionsBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_DIMENSION = 4
+  };
+  int64_t dimension() const {
+    return GetField<int64_t>(VT_DIMENSION, 0);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int64_t>(verifier, VT_DIMENSION, 8) &&
+           verifier.EndTable();
+  }
+  StablehloConcatenateOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(StablehloConcatenateOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<StablehloConcatenateOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConcatenateOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct StablehloConcatenateOptionsBuilder {
+  typedef 
StablehloConcatenateOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_dimension(int64_t dimension) { + fbb_.AddElement(StablehloConcatenateOptions::VT_DIMENSION, dimension, 0); + } + explicit StablehloConcatenateOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloConcatenateOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int64_t dimension = 0) { + StablehloConcatenateOptionsBuilder builder_(_fbb); + builder_.add_dimension(dimension); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateStablehloConcatenateOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConcatenateOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloBroadcastInDimOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloBroadcastInDimOptions TableType; + std::vector broadcast_dimensions{}; +}; + +struct StablehloBroadcastInDimOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloBroadcastInDimOptionsT NativeTableType; + typedef StablehloBroadcastInDimOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BROADCAST_DIMENSIONS = 4 + }; + const ::flatbuffers::Vector *broadcast_dimensions() const { + return GetPointer *>(VT_BROADCAST_DIMENSIONS); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BROADCAST_DIMENSIONS) && + verifier.VerifyVector(broadcast_dimensions()) && + verifier.EndTable(); + } + StablehloBroadcastInDimOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloBroadcastInDimOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloBroadcastInDimOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloBroadcastInDimOptionsBuilder { + typedef StablehloBroadcastInDimOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_broadcast_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dimensions) { + fbb_.AddOffset(StablehloBroadcastInDimOptions::VT_BROADCAST_DIMENSIONS, broadcast_dimensions); + } + explicit StablehloBroadcastInDimOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloBroadcastInDimOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> broadcast_dimensions = 0) { + StablehloBroadcastInDimOptionsBuilder builder_(_fbb); + builder_.add_broadcast_dimensions(broadcast_dimensions); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloBroadcastInDimOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *broadcast_dimensions = nullptr) { + auto broadcast_dimensions__ = broadcast_dimensions ? 
_fbb.CreateVector(*broadcast_dimensions) : 0; + return tflite::CreateStablehloBroadcastInDimOptions( + _fbb, + broadcast_dimensions__); +} + +::flatbuffers::Offset CreateStablehloBroadcastInDimOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloBroadcastInDimOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloCompareOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloCompareOptions TableType; + tflite::StablehloComparisonDirection comparison_direction = tflite::StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_EQ; + tflite::StablehloComparisonType compare_type = tflite::StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE; +}; + +struct StablehloCompareOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloCompareOptionsT NativeTableType; + typedef StablehloCompareOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COMPARISON_DIRECTION = 4, + VT_COMPARE_TYPE = 6 + }; + tflite::StablehloComparisonDirection comparison_direction() const { + return static_cast(GetField(VT_COMPARISON_DIRECTION, 0)); + } + tflite::StablehloComparisonType compare_type() const { + return static_cast(GetField(VT_COMPARE_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_COMPARISON_DIRECTION, 4) && + VerifyField(verifier, VT_COMPARE_TYPE, 4) && + verifier.EndTable(); + } + StablehloCompareOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloCompareOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCompareOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloCompareOptionsBuilder { + typedef StablehloCompareOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_comparison_direction(tflite::StablehloComparisonDirection comparison_direction) { + fbb_.AddElement(StablehloCompareOptions::VT_COMPARISON_DIRECTION, static_cast(comparison_direction), 0); + } + void add_compare_type(tflite::StablehloComparisonType compare_type) { + fbb_.AddElement(StablehloCompareOptions::VT_COMPARE_TYPE, static_cast(compare_type), 0); + } + explicit StablehloCompareOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloCompareOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::StablehloComparisonDirection comparison_direction = tflite::StablehloComparisonDirection_STABLEHLO_COMPARISON_DIRECTION_EQ, + tflite::StablehloComparisonType compare_type = tflite::StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE) { + StablehloCompareOptionsBuilder builder_(_fbb); + builder_.add_compare_type(compare_type); + builder_.add_comparison_direction(comparison_direction); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateStablehloCompareOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCompareOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloDynamicSliceOptionsT : public ::flatbuffers::NativeTable { + typedef 
StablehloDynamicSliceOptions TableType;
+  std::vector<int64_t> slice_sizes{};
+};
+
+struct StablehloDynamicSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef StablehloDynamicSliceOptionsT NativeTableType;
+  typedef StablehloDynamicSliceOptionsBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_SLICE_SIZES = 4
+  };
+  const ::flatbuffers::Vector<int64_t> *slice_sizes() const {
+    return GetPointer<const ::flatbuffers::Vector<int64_t> *>(VT_SLICE_SIZES);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyOffset(verifier, VT_SLICE_SIZES) &&
+           verifier.VerifyVector(slice_sizes()) &&
+           verifier.EndTable();
+  }
+  StablehloDynamicSliceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(StablehloDynamicSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<StablehloDynamicSliceOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDynamicSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct StablehloDynamicSliceOptionsBuilder {
+  typedef StablehloDynamicSliceOptions Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_slice_sizes(::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> slice_sizes) {
+    fbb_.AddOffset(StablehloDynamicSliceOptions::VT_SLICE_SIZES, slice_sizes);
+  }
+  explicit StablehloDynamicSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<StablehloDynamicSliceOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<StablehloDynamicSliceOptions>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<StablehloDynamicSliceOptions> CreateStablehloDynamicSliceOptions(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    ::flatbuffers::Offset<::flatbuffers::Vector<int64_t>> slice_sizes = 0) {
+  StablehloDynamicSliceOptionsBuilder builder_(_fbb);
+  builder_.add_slice_sizes(slice_sizes);
+  return builder_.Finish();
+}
+
+inline ::flatbuffers::Offset<StablehloDynamicSliceOptions> CreateStablehloDynamicSliceOptionsDirect(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<int64_t> *slice_sizes = nullptr) {
+  auto slice_sizes__ = slice_sizes ? 
_fbb.CreateVector(*slice_sizes) : 0; + return tflite::CreateStablehloDynamicSliceOptions( + _fbb, + slice_sizes__); +} + +::flatbuffers::Offset CreateStablehloDynamicSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDynamicSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloPadOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloPadOptions TableType; + std::vector edge_padding_low{}; + std::vector edge_padding_high{}; + std::vector interior_padding{}; +}; + +struct StablehloPadOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloPadOptionsT NativeTableType; + typedef StablehloPadOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_EDGE_PADDING_LOW = 4, + VT_EDGE_PADDING_HIGH = 6, + VT_INTERIOR_PADDING = 8 + }; + const ::flatbuffers::Vector *edge_padding_low() const { + return GetPointer *>(VT_EDGE_PADDING_LOW); + } + const ::flatbuffers::Vector *edge_padding_high() const { + return GetPointer *>(VT_EDGE_PADDING_HIGH); + } + const ::flatbuffers::Vector *interior_padding() const { + return GetPointer *>(VT_INTERIOR_PADDING); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_EDGE_PADDING_LOW) && + verifier.VerifyVector(edge_padding_low()) && + VerifyOffset(verifier, VT_EDGE_PADDING_HIGH) && + verifier.VerifyVector(edge_padding_high()) && + VerifyOffset(verifier, VT_INTERIOR_PADDING) && + verifier.VerifyVector(interior_padding()) && + verifier.EndTable(); + } + StablehloPadOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloPadOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloPadOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloPadOptionsBuilder { + typedef StablehloPadOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_edge_padding_low(::flatbuffers::Offset<::flatbuffers::Vector> edge_padding_low) { + fbb_.AddOffset(StablehloPadOptions::VT_EDGE_PADDING_LOW, edge_padding_low); + } + void add_edge_padding_high(::flatbuffers::Offset<::flatbuffers::Vector> edge_padding_high) { + fbb_.AddOffset(StablehloPadOptions::VT_EDGE_PADDING_HIGH, edge_padding_high); + } + void add_interior_padding(::flatbuffers::Offset<::flatbuffers::Vector> interior_padding) { + fbb_.AddOffset(StablehloPadOptions::VT_INTERIOR_PADDING, interior_padding); + } + explicit StablehloPadOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloPadOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> edge_padding_low = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> edge_padding_high = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> interior_padding = 0) { + StablehloPadOptionsBuilder builder_(_fbb); + builder_.add_interior_padding(interior_padding); + builder_.add_edge_padding_high(edge_padding_high); + builder_.add_edge_padding_low(edge_padding_low); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloPadOptionsDirect( 
+ ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *edge_padding_low = nullptr, + const std::vector *edge_padding_high = nullptr, + const std::vector *interior_padding = nullptr) { + auto edge_padding_low__ = edge_padding_low ? _fbb.CreateVector(*edge_padding_low) : 0; + auto edge_padding_high__ = edge_padding_high ? _fbb.CreateVector(*edge_padding_high) : 0; + auto interior_padding__ = interior_padding ? _fbb.CreateVector(*interior_padding) : 0; + return tflite::CreateStablehloPadOptions( + _fbb, + edge_padding_low__, + edge_padding_high__, + interior_padding__); +} + +::flatbuffers::Offset CreateStablehloPadOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloPadOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloIotaOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloIotaOptions TableType; + int64_t iota_dimension = 0; +}; + +struct StablehloIotaOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloIotaOptionsT NativeTableType; + typedef StablehloIotaOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_IOTA_DIMENSION = 4 + }; + int64_t iota_dimension() const { + return GetField(VT_IOTA_DIMENSION, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_IOTA_DIMENSION, 8) && + verifier.EndTable(); + } + StablehloIotaOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloIotaOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloIotaOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloIotaOptionsBuilder { + typedef StablehloIotaOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_iota_dimension(int64_t iota_dimension) { + fbb_.AddElement(StablehloIotaOptions::VT_IOTA_DIMENSION, iota_dimension, 0); + } + explicit StablehloIotaOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloIotaOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int64_t iota_dimension = 0) { + StablehloIotaOptionsBuilder builder_(_fbb); + builder_.add_iota_dimension(iota_dimension); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateStablehloIotaOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloIotaOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloCustomCallOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloCustomCallOptions TableType; + std::string call_target_name{}; + bool has_side_effect = false; + std::string backend_config{}; + int32_t api_version = 0; + std::vector called_computations{}; + std::vector custom_attributes{}; +}; + +struct StablehloCustomCallOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloCustomCallOptionsT NativeTableType; + typedef StablehloCustomCallOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_CALL_TARGET_NAME = 4, + VT_HAS_SIDE_EFFECT = 6, + VT_BACKEND_CONFIG = 8, + VT_API_VERSION = 10, + 
VT_CALLED_COMPUTATIONS = 12, + VT_CUSTOM_ATTRIBUTES = 14 + }; + const ::flatbuffers::String *call_target_name() const { + return GetPointer(VT_CALL_TARGET_NAME); + } + bool has_side_effect() const { + return GetField(VT_HAS_SIDE_EFFECT, 0) != 0; + } + const ::flatbuffers::String *backend_config() const { + return GetPointer(VT_BACKEND_CONFIG); + } + int32_t api_version() const { + return GetField(VT_API_VERSION, 0); + } + const ::flatbuffers::Vector *called_computations() const { + return GetPointer *>(VT_CALLED_COMPUTATIONS); + } + const ::flatbuffers::Vector *custom_attributes() const { + return GetPointer *>(VT_CUSTOM_ATTRIBUTES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_CALL_TARGET_NAME) && + verifier.VerifyString(call_target_name()) && + VerifyField(verifier, VT_HAS_SIDE_EFFECT, 1) && + VerifyOffset(verifier, VT_BACKEND_CONFIG) && + verifier.VerifyString(backend_config()) && + VerifyField(verifier, VT_API_VERSION, 4) && + VerifyOffset(verifier, VT_CALLED_COMPUTATIONS) && + verifier.VerifyVector(called_computations()) && + VerifyOffset(verifier, VT_CUSTOM_ATTRIBUTES) && + verifier.VerifyVector(custom_attributes()) && + verifier.EndTable(); + } + StablehloCustomCallOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloCustomCallOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCustomCallOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloCustomCallOptionsBuilder { + typedef StablehloCustomCallOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_call_target_name(::flatbuffers::Offset<::flatbuffers::String> call_target_name) { + fbb_.AddOffset(StablehloCustomCallOptions::VT_CALL_TARGET_NAME, call_target_name); + } + void add_has_side_effect(bool has_side_effect) { + fbb_.AddElement(StablehloCustomCallOptions::VT_HAS_SIDE_EFFECT, static_cast(has_side_effect), 0); + } + void add_backend_config(::flatbuffers::Offset<::flatbuffers::String> backend_config) { + fbb_.AddOffset(StablehloCustomCallOptions::VT_BACKEND_CONFIG, backend_config); + } + void add_api_version(int32_t api_version) { + fbb_.AddElement(StablehloCustomCallOptions::VT_API_VERSION, api_version, 0); + } + void add_called_computations(::flatbuffers::Offset<::flatbuffers::Vector> called_computations) { + fbb_.AddOffset(StablehloCustomCallOptions::VT_CALLED_COMPUTATIONS, called_computations); + } + void add_custom_attributes(::flatbuffers::Offset<::flatbuffers::Vector> custom_attributes) { + fbb_.AddOffset(StablehloCustomCallOptions::VT_CUSTOM_ATTRIBUTES, custom_attributes); + } + explicit StablehloCustomCallOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloCustomCallOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> call_target_name = 0, + bool has_side_effect = false, + ::flatbuffers::Offset<::flatbuffers::String> backend_config = 0, + int32_t api_version = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> called_computations = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> 
custom_attributes = 0) { + StablehloCustomCallOptionsBuilder builder_(_fbb); + builder_.add_custom_attributes(custom_attributes); + builder_.add_called_computations(called_computations); + builder_.add_api_version(api_version); + builder_.add_backend_config(backend_config); + builder_.add_call_target_name(call_target_name); + builder_.add_has_side_effect(has_side_effect); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloCustomCallOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *call_target_name = nullptr, + bool has_side_effect = false, + const char *backend_config = nullptr, + int32_t api_version = 0, + const std::vector *called_computations = nullptr, + const std::vector *custom_attributes = nullptr) { + auto call_target_name__ = call_target_name ? _fbb.CreateString(call_target_name) : 0; + auto backend_config__ = backend_config ? _fbb.CreateString(backend_config) : 0; + auto called_computations__ = called_computations ? _fbb.CreateVector(*called_computations) : 0; + auto custom_attributes__ = custom_attributes ? _fbb.CreateVector(*custom_attributes) : 0; + return tflite::CreateStablehloCustomCallOptions( + _fbb, + call_target_name__, + has_side_effect, + backend_config__, + api_version, + called_computations__, + custom_attributes__); +} + +::flatbuffers::Offset CreateStablehloCustomCallOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCustomCallOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloReduceOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloReduceOptions TableType; + std::vector dimensions{}; + int32_t body_subgraph_index = 0; +}; + +struct StablehloReduceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloReduceOptionsT NativeTableType; + typedef StablehloReduceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DIMENSIONS = 4, + VT_BODY_SUBGRAPH_INDEX = 6 + }; + const ::flatbuffers::Vector *dimensions() const { + return GetPointer *>(VT_DIMENSIONS); + } + int32_t body_subgraph_index() const { + return GetField(VT_BODY_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DIMENSIONS) && + verifier.VerifyVector(dimensions()) && + VerifyField(verifier, VT_BODY_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + StablehloReduceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloReduceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloReduceOptionsBuilder { + typedef StablehloReduceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> dimensions) { + fbb_.AddOffset(StablehloReduceOptions::VT_DIMENSIONS, dimensions); + } + void add_body_subgraph_index(int32_t body_subgraph_index) { + fbb_.AddElement(StablehloReduceOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0); + } + explicit StablehloReduceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = 
::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloReduceOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> dimensions = 0, + int32_t body_subgraph_index = 0) { + StablehloReduceOptionsBuilder builder_(_fbb); + builder_.add_body_subgraph_index(body_subgraph_index); + builder_.add_dimensions(dimensions); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloReduceOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *dimensions = nullptr, + int32_t body_subgraph_index = 0) { + auto dimensions__ = dimensions ? _fbb.CreateVector(*dimensions) : 0; + return tflite::CreateStablehloReduceOptions( + _fbb, + dimensions__, + body_subgraph_index); +} + +::flatbuffers::Offset CreateStablehloReduceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloSliceOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloSliceOptions TableType; + std::vector start_indices{}; + std::vector limit_indices{}; + std::vector strides{}; +}; + +struct StablehloSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloSliceOptionsT NativeTableType; + typedef StablehloSliceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_START_INDICES = 4, + VT_LIMIT_INDICES = 6, + VT_STRIDES = 8 + }; + const ::flatbuffers::Vector *start_indices() const { + return GetPointer *>(VT_START_INDICES); + } + const ::flatbuffers::Vector *limit_indices() const { + return GetPointer *>(VT_LIMIT_INDICES); + } + const ::flatbuffers::Vector *strides() const { + return GetPointer *>(VT_STRIDES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_START_INDICES) && + verifier.VerifyVector(start_indices()) && + VerifyOffset(verifier, VT_LIMIT_INDICES) && + verifier.VerifyVector(limit_indices()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + verifier.EndTable(); + } + StablehloSliceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloSliceOptionsBuilder { + typedef StablehloSliceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_start_indices(::flatbuffers::Offset<::flatbuffers::Vector> start_indices) { + fbb_.AddOffset(StablehloSliceOptions::VT_START_INDICES, start_indices); + } + void add_limit_indices(::flatbuffers::Offset<::flatbuffers::Vector> limit_indices) { + fbb_.AddOffset(StablehloSliceOptions::VT_LIMIT_INDICES, limit_indices); + } + void add_strides(::flatbuffers::Offset<::flatbuffers::Vector> strides) { + fbb_.AddOffset(StablehloSliceOptions::VT_STRIDES, strides); + } + explicit StablehloSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloSliceOptions( + 
::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> start_indices = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> limit_indices = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> strides = 0) { + StablehloSliceOptionsBuilder builder_(_fbb); + builder_.add_strides(strides); + builder_.add_limit_indices(limit_indices); + builder_.add_start_indices(start_indices); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloSliceOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *start_indices = nullptr, + const std::vector *limit_indices = nullptr, + const std::vector *strides = nullptr) { + auto start_indices__ = start_indices ? _fbb.CreateVector(*start_indices) : 0; + auto limit_indices__ = limit_indices ? _fbb.CreateVector(*limit_indices) : 0; + auto strides__ = strides ? _fbb.CreateVector(*strides) : 0; + return tflite::CreateStablehloSliceOptions( + _fbb, + start_indices__, + limit_indices__, + strides__); +} + +::flatbuffers::Offset CreateStablehloSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloConvolutionOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloConvolutionOptions TableType; + std::vector window_strides{}; + std::vector padding{}; + std::vector lhs_dilation{}; + std::vector rhs_dilation{}; + std::vector window_reversal{}; + int64_t input_batch_dimension = 0; + int64_t input_feature_dimension = 0; + std::vector input_spatial_dimensions{}; + int64_t kernel_input_feature_dimension = 0; + int64_t kernel_output_feature_dimension = 0; + std::vector kernel_spatial_dimensions{}; + int64_t output_batch_dimension = 0; + int64_t output_feature_dimension = 0; + std::vector output_spatial_dimensions{}; + int64_t feature_group_count = 0; + int64_t batch_group_count = 0; + std::vector precision_config{}; +}; + +struct StablehloConvolutionOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloConvolutionOptionsT NativeTableType; + typedef StablehloConvolutionOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_WINDOW_STRIDES = 4, + VT_PADDING = 6, + VT_LHS_DILATION = 8, + VT_RHS_DILATION = 10, + VT_WINDOW_REVERSAL = 12, + VT_INPUT_BATCH_DIMENSION = 14, + VT_INPUT_FEATURE_DIMENSION = 16, + VT_INPUT_SPATIAL_DIMENSIONS = 18, + VT_KERNEL_INPUT_FEATURE_DIMENSION = 20, + VT_KERNEL_OUTPUT_FEATURE_DIMENSION = 22, + VT_KERNEL_SPATIAL_DIMENSIONS = 24, + VT_OUTPUT_BATCH_DIMENSION = 26, + VT_OUTPUT_FEATURE_DIMENSION = 28, + VT_OUTPUT_SPATIAL_DIMENSIONS = 30, + VT_FEATURE_GROUP_COUNT = 32, + VT_BATCH_GROUP_COUNT = 34, + VT_PRECISION_CONFIG = 36 + }; + const ::flatbuffers::Vector *window_strides() const { + return GetPointer *>(VT_WINDOW_STRIDES); + } + const ::flatbuffers::Vector *padding() const { + return GetPointer *>(VT_PADDING); + } + const ::flatbuffers::Vector *lhs_dilation() const { + return GetPointer *>(VT_LHS_DILATION); + } + const ::flatbuffers::Vector *rhs_dilation() const { + return GetPointer *>(VT_RHS_DILATION); + } + const ::flatbuffers::Vector *window_reversal() const { + return GetPointer *>(VT_WINDOW_REVERSAL); + } + int64_t input_batch_dimension() const { + return GetField(VT_INPUT_BATCH_DIMENSION, 0); + } + int64_t input_feature_dimension() const { + return GetField(VT_INPUT_FEATURE_DIMENSION, 0); + } + const ::flatbuffers::Vector *input_spatial_dimensions() const { + return GetPointer 
*>(VT_INPUT_SPATIAL_DIMENSIONS); + } + int64_t kernel_input_feature_dimension() const { + return GetField(VT_KERNEL_INPUT_FEATURE_DIMENSION, 0); + } + int64_t kernel_output_feature_dimension() const { + return GetField(VT_KERNEL_OUTPUT_FEATURE_DIMENSION, 0); + } + const ::flatbuffers::Vector *kernel_spatial_dimensions() const { + return GetPointer *>(VT_KERNEL_SPATIAL_DIMENSIONS); + } + int64_t output_batch_dimension() const { + return GetField(VT_OUTPUT_BATCH_DIMENSION, 0); + } + int64_t output_feature_dimension() const { + return GetField(VT_OUTPUT_FEATURE_DIMENSION, 0); + } + const ::flatbuffers::Vector *output_spatial_dimensions() const { + return GetPointer *>(VT_OUTPUT_SPATIAL_DIMENSIONS); + } + int64_t feature_group_count() const { + return GetField(VT_FEATURE_GROUP_COUNT, 0); + } + int64_t batch_group_count() const { + return GetField(VT_BATCH_GROUP_COUNT, 0); + } + const ::flatbuffers::Vector *precision_config() const { + return GetPointer *>(VT_PRECISION_CONFIG); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_WINDOW_STRIDES) && + verifier.VerifyVector(window_strides()) && + VerifyOffset(verifier, VT_PADDING) && + verifier.VerifyVector(padding()) && + VerifyOffset(verifier, VT_LHS_DILATION) && + verifier.VerifyVector(lhs_dilation()) && + VerifyOffset(verifier, VT_RHS_DILATION) && + verifier.VerifyVector(rhs_dilation()) && + VerifyOffset(verifier, VT_WINDOW_REVERSAL) && + verifier.VerifyVector(window_reversal()) && + VerifyField(verifier, VT_INPUT_BATCH_DIMENSION, 8) && + VerifyField(verifier, VT_INPUT_FEATURE_DIMENSION, 8) && + VerifyOffset(verifier, VT_INPUT_SPATIAL_DIMENSIONS) && + verifier.VerifyVector(input_spatial_dimensions()) && + VerifyField(verifier, VT_KERNEL_INPUT_FEATURE_DIMENSION, 8) && + VerifyField(verifier, VT_KERNEL_OUTPUT_FEATURE_DIMENSION, 8) && + VerifyOffset(verifier, VT_KERNEL_SPATIAL_DIMENSIONS) && + verifier.VerifyVector(kernel_spatial_dimensions()) && + VerifyField(verifier, VT_OUTPUT_BATCH_DIMENSION, 8) && + VerifyField(verifier, VT_OUTPUT_FEATURE_DIMENSION, 8) && + VerifyOffset(verifier, VT_OUTPUT_SPATIAL_DIMENSIONS) && + verifier.VerifyVector(output_spatial_dimensions()) && + VerifyField(verifier, VT_FEATURE_GROUP_COUNT, 8) && + VerifyField(verifier, VT_BATCH_GROUP_COUNT, 8) && + VerifyOffset(verifier, VT_PRECISION_CONFIG) && + verifier.VerifyVector(precision_config()) && + verifier.EndTable(); + } + StablehloConvolutionOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloConvolutionOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConvolutionOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloConvolutionOptionsBuilder { + typedef StablehloConvolutionOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_window_strides(::flatbuffers::Offset<::flatbuffers::Vector> window_strides) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_WINDOW_STRIDES, window_strides); + } + void add_padding(::flatbuffers::Offset<::flatbuffers::Vector> padding) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_PADDING, padding); + } + void add_lhs_dilation(::flatbuffers::Offset<::flatbuffers::Vector> lhs_dilation) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_LHS_DILATION, lhs_dilation); + } + void 
add_rhs_dilation(::flatbuffers::Offset<::flatbuffers::Vector> rhs_dilation) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_RHS_DILATION, rhs_dilation); + } + void add_window_reversal(::flatbuffers::Offset<::flatbuffers::Vector> window_reversal) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_WINDOW_REVERSAL, window_reversal); + } + void add_input_batch_dimension(int64_t input_batch_dimension) { + fbb_.AddElement(StablehloConvolutionOptions::VT_INPUT_BATCH_DIMENSION, input_batch_dimension, 0); + } + void add_input_feature_dimension(int64_t input_feature_dimension) { + fbb_.AddElement(StablehloConvolutionOptions::VT_INPUT_FEATURE_DIMENSION, input_feature_dimension, 0); + } + void add_input_spatial_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> input_spatial_dimensions) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_INPUT_SPATIAL_DIMENSIONS, input_spatial_dimensions); + } + void add_kernel_input_feature_dimension(int64_t kernel_input_feature_dimension) { + fbb_.AddElement(StablehloConvolutionOptions::VT_KERNEL_INPUT_FEATURE_DIMENSION, kernel_input_feature_dimension, 0); + } + void add_kernel_output_feature_dimension(int64_t kernel_output_feature_dimension) { + fbb_.AddElement(StablehloConvolutionOptions::VT_KERNEL_OUTPUT_FEATURE_DIMENSION, kernel_output_feature_dimension, 0); + } + void add_kernel_spatial_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> kernel_spatial_dimensions) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_KERNEL_SPATIAL_DIMENSIONS, kernel_spatial_dimensions); + } + void add_output_batch_dimension(int64_t output_batch_dimension) { + fbb_.AddElement(StablehloConvolutionOptions::VT_OUTPUT_BATCH_DIMENSION, output_batch_dimension, 0); + } + void add_output_feature_dimension(int64_t output_feature_dimension) { + fbb_.AddElement(StablehloConvolutionOptions::VT_OUTPUT_FEATURE_DIMENSION, output_feature_dimension, 0); + } + void add_output_spatial_dimensions(::flatbuffers::Offset<::flatbuffers::Vector> output_spatial_dimensions) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_OUTPUT_SPATIAL_DIMENSIONS, output_spatial_dimensions); + } + void add_feature_group_count(int64_t feature_group_count) { + fbb_.AddElement(StablehloConvolutionOptions::VT_FEATURE_GROUP_COUNT, feature_group_count, 0); + } + void add_batch_group_count(int64_t batch_group_count) { + fbb_.AddElement(StablehloConvolutionOptions::VT_BATCH_GROUP_COUNT, batch_group_count, 0); + } + void add_precision_config(::flatbuffers::Offset<::flatbuffers::Vector> precision_config) { + fbb_.AddOffset(StablehloConvolutionOptions::VT_PRECISION_CONFIG, precision_config); + } + explicit StablehloConvolutionOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloConvolutionOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> window_strides = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> padding = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> lhs_dilation = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> rhs_dilation = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> window_reversal = 0, + int64_t input_batch_dimension = 0, + int64_t input_feature_dimension = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> input_spatial_dimensions = 0, + int64_t kernel_input_feature_dimension = 0, + int64_t 
kernel_output_feature_dimension = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> kernel_spatial_dimensions = 0, + int64_t output_batch_dimension = 0, + int64_t output_feature_dimension = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> output_spatial_dimensions = 0, + int64_t feature_group_count = 0, + int64_t batch_group_count = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> precision_config = 0) { + StablehloConvolutionOptionsBuilder builder_(_fbb); + builder_.add_batch_group_count(batch_group_count); + builder_.add_feature_group_count(feature_group_count); + builder_.add_output_feature_dimension(output_feature_dimension); + builder_.add_output_batch_dimension(output_batch_dimension); + builder_.add_kernel_output_feature_dimension(kernel_output_feature_dimension); + builder_.add_kernel_input_feature_dimension(kernel_input_feature_dimension); + builder_.add_input_feature_dimension(input_feature_dimension); + builder_.add_input_batch_dimension(input_batch_dimension); + builder_.add_precision_config(precision_config); + builder_.add_output_spatial_dimensions(output_spatial_dimensions); + builder_.add_kernel_spatial_dimensions(kernel_spatial_dimensions); + builder_.add_input_spatial_dimensions(input_spatial_dimensions); + builder_.add_window_reversal(window_reversal); + builder_.add_rhs_dilation(rhs_dilation); + builder_.add_lhs_dilation(lhs_dilation); + builder_.add_padding(padding); + builder_.add_window_strides(window_strides); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloConvolutionOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *window_strides = nullptr, + const std::vector *padding = nullptr, + const std::vector *lhs_dilation = nullptr, + const std::vector *rhs_dilation = nullptr, + const std::vector *window_reversal = nullptr, + int64_t input_batch_dimension = 0, + int64_t input_feature_dimension = 0, + const std::vector *input_spatial_dimensions = nullptr, + int64_t kernel_input_feature_dimension = 0, + int64_t kernel_output_feature_dimension = 0, + const std::vector *kernel_spatial_dimensions = nullptr, + int64_t output_batch_dimension = 0, + int64_t output_feature_dimension = 0, + const std::vector *output_spatial_dimensions = nullptr, + int64_t feature_group_count = 0, + int64_t batch_group_count = 0, + const std::vector *precision_config = nullptr) { + auto window_strides__ = window_strides ? _fbb.CreateVector(*window_strides) : 0; + auto padding__ = padding ? _fbb.CreateVector(*padding) : 0; + auto lhs_dilation__ = lhs_dilation ? _fbb.CreateVector(*lhs_dilation) : 0; + auto rhs_dilation__ = rhs_dilation ? _fbb.CreateVector(*rhs_dilation) : 0; + auto window_reversal__ = window_reversal ? _fbb.CreateVector(*window_reversal) : 0; + auto input_spatial_dimensions__ = input_spatial_dimensions ? _fbb.CreateVector(*input_spatial_dimensions) : 0; + auto kernel_spatial_dimensions__ = kernel_spatial_dimensions ? _fbb.CreateVector(*kernel_spatial_dimensions) : 0; + auto output_spatial_dimensions__ = output_spatial_dimensions ? _fbb.CreateVector(*output_spatial_dimensions) : 0; + auto precision_config__ = precision_config ? 
_fbb.CreateVector(*precision_config) : 0; + return tflite::CreateStablehloConvolutionOptions( + _fbb, + window_strides__, + padding__, + lhs_dilation__, + rhs_dilation__, + window_reversal__, + input_batch_dimension, + input_feature_dimension, + input_spatial_dimensions__, + kernel_input_feature_dimension, + kernel_output_feature_dimension, + kernel_spatial_dimensions__, + output_batch_dimension, + output_feature_dimension, + output_spatial_dimensions__, + feature_group_count, + batch_group_count, + precision_config__); +} + +::flatbuffers::Offset CreateStablehloConvolutionOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConvolutionOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloScatterOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloScatterOptions TableType; + bool indices_are_sorted = false; + std::vector update_window_dims{}; + std::vector inserted_window_dims{}; + std::vector scatter_dims_to_operand_dims{}; + int64_t index_vector_dim = 0; + bool unique_indices = false; + int32_t update_computation_subgraph_index = 0; +}; + +struct StablehloScatterOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloScatterOptionsT NativeTableType; + typedef StablehloScatterOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INDICES_ARE_SORTED = 4, + VT_UPDATE_WINDOW_DIMS = 6, + VT_INSERTED_WINDOW_DIMS = 8, + VT_SCATTER_DIMS_TO_OPERAND_DIMS = 10, + VT_INDEX_VECTOR_DIM = 12, + VT_UNIQUE_INDICES = 14, + VT_UPDATE_COMPUTATION_SUBGRAPH_INDEX = 16 + }; + bool indices_are_sorted() const { + return GetField(VT_INDICES_ARE_SORTED, 0) != 0; + } + const ::flatbuffers::Vector *update_window_dims() const { + return GetPointer *>(VT_UPDATE_WINDOW_DIMS); + } + const ::flatbuffers::Vector *inserted_window_dims() const { + return GetPointer *>(VT_INSERTED_WINDOW_DIMS); + } + const ::flatbuffers::Vector *scatter_dims_to_operand_dims() const { + return GetPointer *>(VT_SCATTER_DIMS_TO_OPERAND_DIMS); + } + int64_t index_vector_dim() const { + return GetField(VT_INDEX_VECTOR_DIM, 0); + } + bool unique_indices() const { + return GetField(VT_UNIQUE_INDICES, 0) != 0; + } + int32_t update_computation_subgraph_index() const { + return GetField(VT_UPDATE_COMPUTATION_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_INDICES_ARE_SORTED, 1) && + VerifyOffset(verifier, VT_UPDATE_WINDOW_DIMS) && + verifier.VerifyVector(update_window_dims()) && + VerifyOffset(verifier, VT_INSERTED_WINDOW_DIMS) && + verifier.VerifyVector(inserted_window_dims()) && + VerifyOffset(verifier, VT_SCATTER_DIMS_TO_OPERAND_DIMS) && + verifier.VerifyVector(scatter_dims_to_operand_dims()) && + VerifyField(verifier, VT_INDEX_VECTOR_DIM, 8) && + VerifyField(verifier, VT_UNIQUE_INDICES, 1) && + VerifyField(verifier, VT_UPDATE_COMPUTATION_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + StablehloScatterOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloScatterOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloScatterOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloScatterOptionsBuilder { + typedef StablehloScatterOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + 
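// A minimal usage sketch for the CreateStablehloConvolutionOptionsDirect helper
// above, packing dimension numbers for an NHWC input / HWIO kernel 2-D
// convolution. The int64_t element type of the dimension vectors is an
// assumption (the template arguments are not visible in this view of the
// generated code), as are the include path and the helper-function name.
#include <cstdint>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // generated header added by this change (path assumed)

inline flatbuffers::Offset<tflite::StablehloConvolutionOptions>
BuildExampleStablehloConvOptions(flatbuffers::FlatBufferBuilder &fbb) {
  std::vector<int64_t> window_strides{1, 1};
  std::vector<int64_t> padding{0, 0, 0, 0};   // [lo, hi] per spatial dim, flattened
  std::vector<int64_t> lhs_dilation{1, 1};
  std::vector<int64_t> rhs_dilation{1, 1};
  std::vector<int64_t> input_spatial{1, 2};   // H, W of an NHWC input
  std::vector<int64_t> kernel_spatial{0, 1};  // H, W of an HWIO kernel
  std::vector<int64_t> output_spatial{1, 2};
  return tflite::CreateStablehloConvolutionOptionsDirect(
      fbb,
      &window_strides, &padding, &lhs_dilation, &rhs_dilation,
      /*window_reversal=*/nullptr,
      /*input_batch_dimension=*/0, /*input_feature_dimension=*/3, &input_spatial,
      /*kernel_input_feature_dimension=*/2, /*kernel_output_feature_dimension=*/3,
      &kernel_spatial,
      /*output_batch_dimension=*/0, /*output_feature_dimension=*/3, &output_spatial,
      /*feature_group_count=*/1, /*batch_group_count=*/1,
      /*precision_config=*/nullptr);
}
// The *Direct form copies each std::vector into the buffer via CreateVector and
// then forwards to the offset-based CreateStablehloConvolutionOptions shown above.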
::flatbuffers::uoffset_t start_; + void add_indices_are_sorted(bool indices_are_sorted) { + fbb_.AddElement(StablehloScatterOptions::VT_INDICES_ARE_SORTED, static_cast(indices_are_sorted), 0); + } + void add_update_window_dims(::flatbuffers::Offset<::flatbuffers::Vector> update_window_dims) { + fbb_.AddOffset(StablehloScatterOptions::VT_UPDATE_WINDOW_DIMS, update_window_dims); + } + void add_inserted_window_dims(::flatbuffers::Offset<::flatbuffers::Vector> inserted_window_dims) { + fbb_.AddOffset(StablehloScatterOptions::VT_INSERTED_WINDOW_DIMS, inserted_window_dims); + } + void add_scatter_dims_to_operand_dims(::flatbuffers::Offset<::flatbuffers::Vector> scatter_dims_to_operand_dims) { + fbb_.AddOffset(StablehloScatterOptions::VT_SCATTER_DIMS_TO_OPERAND_DIMS, scatter_dims_to_operand_dims); + } + void add_index_vector_dim(int64_t index_vector_dim) { + fbb_.AddElement(StablehloScatterOptions::VT_INDEX_VECTOR_DIM, index_vector_dim, 0); + } + void add_unique_indices(bool unique_indices) { + fbb_.AddElement(StablehloScatterOptions::VT_UNIQUE_INDICES, static_cast(unique_indices), 0); + } + void add_update_computation_subgraph_index(int32_t update_computation_subgraph_index) { + fbb_.AddElement(StablehloScatterOptions::VT_UPDATE_COMPUTATION_SUBGRAPH_INDEX, update_computation_subgraph_index, 0); + } + explicit StablehloScatterOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloScatterOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool indices_are_sorted = false, + ::flatbuffers::Offset<::flatbuffers::Vector> update_window_dims = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> inserted_window_dims = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> scatter_dims_to_operand_dims = 0, + int64_t index_vector_dim = 0, + bool unique_indices = false, + int32_t update_computation_subgraph_index = 0) { + StablehloScatterOptionsBuilder builder_(_fbb); + builder_.add_index_vector_dim(index_vector_dim); + builder_.add_update_computation_subgraph_index(update_computation_subgraph_index); + builder_.add_scatter_dims_to_operand_dims(scatter_dims_to_operand_dims); + builder_.add_inserted_window_dims(inserted_window_dims); + builder_.add_update_window_dims(update_window_dims); + builder_.add_unique_indices(unique_indices); + builder_.add_indices_are_sorted(indices_are_sorted); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloScatterOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool indices_are_sorted = false, + const std::vector *update_window_dims = nullptr, + const std::vector *inserted_window_dims = nullptr, + const std::vector *scatter_dims_to_operand_dims = nullptr, + int64_t index_vector_dim = 0, + bool unique_indices = false, + int32_t update_computation_subgraph_index = 0) { + auto update_window_dims__ = update_window_dims ? _fbb.CreateVector(*update_window_dims) : 0; + auto inserted_window_dims__ = inserted_window_dims ? _fbb.CreateVector(*inserted_window_dims) : 0; + auto scatter_dims_to_operand_dims__ = scatter_dims_to_operand_dims ? 
_fbb.CreateVector(*scatter_dims_to_operand_dims) : 0; + return tflite::CreateStablehloScatterOptions( + _fbb, + indices_are_sorted, + update_window_dims__, + inserted_window_dims__, + scatter_dims_to_operand_dims__, + index_vector_dim, + unique_indices, + update_computation_subgraph_index); +} + +::flatbuffers::Offset CreateStablehloScatterOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloScatterOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloCaseOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloCaseOptions TableType; + std::vector branch_subgraph_indices{}; +}; + +struct StablehloCaseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloCaseOptionsT NativeTableType; + typedef StablehloCaseOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BRANCH_SUBGRAPH_INDICES = 4 + }; + const ::flatbuffers::Vector *branch_subgraph_indices() const { + return GetPointer *>(VT_BRANCH_SUBGRAPH_INDICES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BRANCH_SUBGRAPH_INDICES) && + verifier.VerifyVector(branch_subgraph_indices()) && + verifier.EndTable(); + } + StablehloCaseOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloCaseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloCaseOptionsBuilder { + typedef StablehloCaseOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_branch_subgraph_indices(::flatbuffers::Offset<::flatbuffers::Vector> branch_subgraph_indices) { + fbb_.AddOffset(StablehloCaseOptions::VT_BRANCH_SUBGRAPH_INDICES, branch_subgraph_indices); + } + explicit StablehloCaseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloCaseOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> branch_subgraph_indices = 0) { + StablehloCaseOptionsBuilder builder_(_fbb); + builder_.add_branch_subgraph_indices(branch_subgraph_indices); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStablehloCaseOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *branch_subgraph_indices = nullptr) { + auto branch_subgraph_indices__ = branch_subgraph_indices ? 
_fbb.CreateVector(*branch_subgraph_indices) : 0; + return tflite::CreateStablehloCaseOptions( + _fbb, + branch_subgraph_indices__); +} + +::flatbuffers::Offset CreateStablehloCaseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloRngBitGeneratorOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloRngBitGeneratorOptions TableType; + tflite::RngAlgorithm algorithm = tflite::RngAlgorithm_DEFAULT; +}; + +struct StablehloRngBitGeneratorOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloRngBitGeneratorOptionsT NativeTableType; + typedef StablehloRngBitGeneratorOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ALGORITHM = 4 + }; + tflite::RngAlgorithm algorithm() const { + return static_cast(GetField(VT_ALGORITHM, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ALGORITHM, 1) && + verifier.EndTable(); + } + StablehloRngBitGeneratorOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloRngBitGeneratorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloRngBitGeneratorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloRngBitGeneratorOptionsBuilder { + typedef StablehloRngBitGeneratorOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_algorithm(tflite::RngAlgorithm algorithm) { + fbb_.AddElement(StablehloRngBitGeneratorOptions::VT_ALGORITHM, static_cast(algorithm), 0); + } + explicit StablehloRngBitGeneratorOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloRngBitGeneratorOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::RngAlgorithm algorithm = tflite::RngAlgorithm_DEFAULT) { + StablehloRngBitGeneratorOptionsBuilder builder_(_fbb); + builder_.add_algorithm(algorithm); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateStablehloRngBitGeneratorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloRngBitGeneratorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Conv2DOptionsT : public ::flatbuffers::NativeTable { + typedef Conv2DOptions TableType; + tflite::Padding padding = tflite::Padding_SAME; + int32_t stride_w = 0; + int32_t stride_h = 0; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + int32_t dilation_w_factor = 1; + int32_t dilation_h_factor = 1; + tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32; +}; + +struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef Conv2DOptionsT NativeTableType; + typedef Conv2DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FUSED_ACTIVATION_FUNCTION = 10, + VT_DILATION_W_FACTOR = 12, + VT_DILATION_H_FACTOR = 14, + VT_QUANTIZED_BIAS_TYPE = 16 + }; + tflite::Padding padding() const { + return 
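// A short sketch using the StablehloCaseOptions and StablehloRngBitGeneratorOptions
// helpers defined just above. The int32_t element type for branch_subgraph_indices
// is assumed from the schema; the header include path is likewise assumed.
#include <cstdint>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // path assumed

inline void BuildExampleStablehloControlFlowOptions(flatbuffers::FlatBufferBuilder &fbb) {
  // stablehlo.case: one subgraph index per branch.
  std::vector<int32_t> branch_subgraphs{1, 2, 3};
  auto case_opts = tflite::CreateStablehloCaseOptionsDirect(fbb, &branch_subgraphs);

  // stablehlo.rng_bit_generator: a single enum field. Passing the schema default
  // (RngAlgorithm_DEFAULT) stores nothing in the table, and the algorithm()
  // accessor still returns the default on read-back.
  auto rng_opts = tflite::CreateStablehloRngBitGeneratorOptions(
      fbb, tflite::RngAlgorithm_DEFAULT);
  (void)case_opts;
  (void)rng_opts;
}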
static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + int32_t dilation_w_factor() const { + return GetField(VT_DILATION_W_FACTOR, 1); + } + int32_t dilation_h_factor() const { + return GetField(VT_DILATION_H_FACTOR, 1); + } + tflite::TensorType quantized_bias_type() const { + return static_cast(GetField(VT_QUANTIZED_BIAS_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING, 1) && + VerifyField(verifier, VT_STRIDE_W, 4) && + VerifyField(verifier, VT_STRIDE_H, 4) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_DILATION_W_FACTOR, 4) && + VerifyField(verifier, VT_DILATION_H_FACTOR, 4) && + VerifyField(verifier, VT_QUANTIZED_BIAS_TYPE, 1) && + verifier.EndTable(); + } + Conv2DOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Conv2DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Conv2DOptionsBuilder { + typedef Conv2DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(tflite::Padding padding) { + fbb_.AddElement(Conv2DOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(Conv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(Conv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_dilation_w_factor(int32_t dilation_w_factor) { + fbb_.AddElement(Conv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1); + } + void add_dilation_h_factor(int32_t dilation_h_factor) { + fbb_.AddElement(Conv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1); + } + void add_quantized_bias_type(tflite::TensorType quantized_bias_type) { + fbb_.AddElement(Conv2DOptions::VT_QUANTIZED_BIAS_TYPE, static_cast(quantized_bias_type), 0); + } + explicit Conv2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConv2DOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::Padding padding = tflite::Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + int32_t dilation_w_factor = 1, + int32_t dilation_h_factor = 1, + tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32) { + Conv2DOptionsBuilder builder_(_fbb); + builder_.add_dilation_h_factor(dilation_h_factor); + builder_.add_dilation_w_factor(dilation_w_factor); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_quantized_bias_type(quantized_bias_type); + 
builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Conv3DOptionsT : public ::flatbuffers::NativeTable { + typedef Conv3DOptions TableType; + tflite::Padding padding = tflite::Padding_SAME; + int32_t stride_d = 0; + int32_t stride_w = 0; + int32_t stride_h = 0; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + int32_t dilation_d_factor = 1; + int32_t dilation_w_factor = 1; + int32_t dilation_h_factor = 1; +}; + +struct Conv3DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef Conv3DOptionsT NativeTableType; + typedef Conv3DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_STRIDE_D = 6, + VT_STRIDE_W = 8, + VT_STRIDE_H = 10, + VT_FUSED_ACTIVATION_FUNCTION = 12, + VT_DILATION_D_FACTOR = 14, + VT_DILATION_W_FACTOR = 16, + VT_DILATION_H_FACTOR = 18 + }; + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_d() const { + return GetField(VT_STRIDE_D, 0); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + int32_t dilation_d_factor() const { + return GetField(VT_DILATION_D_FACTOR, 1); + } + int32_t dilation_w_factor() const { + return GetField(VT_DILATION_W_FACTOR, 1); + } + int32_t dilation_h_factor() const { + return GetField(VT_DILATION_H_FACTOR, 1); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING, 1) && + VerifyField(verifier, VT_STRIDE_D, 4) && + VerifyField(verifier, VT_STRIDE_W, 4) && + VerifyField(verifier, VT_STRIDE_H, 4) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_DILATION_D_FACTOR, 4) && + VerifyField(verifier, VT_DILATION_W_FACTOR, 4) && + VerifyField(verifier, VT_DILATION_H_FACTOR, 4) && + verifier.EndTable(); + } + Conv3DOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Conv3DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Conv3DOptionsBuilder { + typedef Conv3DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(tflite::Padding padding) { + fbb_.AddElement(Conv3DOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_d(int32_t stride_d) { + fbb_.AddElement(Conv3DOptions::VT_STRIDE_D, stride_d, 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(Conv3DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(Conv3DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Conv3DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_dilation_d_factor(int32_t 
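// Usage sketch for CreateConv2DOptions above, showing how scalar defaults behave:
// fields equal to their default (e.g. dilation factors of 1) are not written to
// the buffer, yet the accessors still return the default value. Finishing a
// buffer whose root is the bare options table is for illustration only; in a
// real .tflite model the table sits under Operator.builtin_options.
#include <cassert>
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // path assumed

inline void Conv2DOptionsExample() {
  flatbuffers::FlatBufferBuilder fbb;
  auto opts = tflite::CreateConv2DOptions(
      fbb, tflite::Padding_SAME,
      /*stride_w=*/2, /*stride_h=*/2,
      tflite::ActivationFunctionType_RELU,
      /*dilation_w_factor=*/1, /*dilation_h_factor=*/1);
  fbb.Finish(opts);

  auto *conv = flatbuffers::GetRoot<tflite::Conv2DOptions>(fbb.GetBufferPointer());
  assert(conv->stride_w() == 2);
  assert(conv->dilation_w_factor() == 1);  // not stored (equals default), still reported as 1
  assert(conv->quantized_bias_type() == tflite::TensorType_FLOAT32);
}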
dilation_d_factor) { + fbb_.AddElement(Conv3DOptions::VT_DILATION_D_FACTOR, dilation_d_factor, 1); + } + void add_dilation_w_factor(int32_t dilation_w_factor) { + fbb_.AddElement(Conv3DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1); + } + void add_dilation_h_factor(int32_t dilation_h_factor) { + fbb_.AddElement(Conv3DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1); + } + explicit Conv3DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConv3DOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::Padding padding = tflite::Padding_SAME, + int32_t stride_d = 0, + int32_t stride_w = 0, + int32_t stride_h = 0, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + int32_t dilation_d_factor = 1, + int32_t dilation_w_factor = 1, + int32_t dilation_h_factor = 1) { + Conv3DOptionsBuilder builder_(_fbb); + builder_.add_dilation_h_factor(dilation_h_factor); + builder_.add_dilation_w_factor(dilation_w_factor); + builder_.add_dilation_d_factor(dilation_d_factor); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_stride_d(stride_d); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateConv3DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Pool2DOptionsT : public ::flatbuffers::NativeTable { + typedef Pool2DOptions TableType; + tflite::Padding padding = tflite::Padding_SAME; + int32_t stride_w = 0; + int32_t stride_h = 0; + int32_t filter_width = 0; + int32_t filter_height = 0; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; +}; + +struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef Pool2DOptionsT NativeTableType; + typedef Pool2DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FILTER_WIDTH = 10, + VT_FILTER_HEIGHT = 12, + VT_FUSED_ACTIVATION_FUNCTION = 14 + }; + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + int32_t filter_width() const { + return GetField(VT_FILTER_WIDTH, 0); + } + int32_t filter_height() const { + return GetField(VT_FILTER_HEIGHT, 0); + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING, 1) && + VerifyField(verifier, VT_STRIDE_W, 4) && + VerifyField(verifier, VT_STRIDE_H, 4) && + VerifyField(verifier, VT_FILTER_WIDTH, 4) && + VerifyField(verifier, VT_FILTER_HEIGHT, 4) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + verifier.EndTable(); + } + Pool2DOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Pool2DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset 
Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Pool2DOptionsBuilder { + typedef Pool2DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(tflite::Padding padding) { + fbb_.AddElement(Pool2DOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(Pool2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(Pool2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_filter_width(int32_t filter_width) { + fbb_.AddElement(Pool2DOptions::VT_FILTER_WIDTH, filter_width, 0); + } + void add_filter_height(int32_t filter_height) { + fbb_.AddElement(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit Pool2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreatePool2DOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::Padding padding = tflite::Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t filter_width = 0, + int32_t filter_height = 0, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + Pool2DOptionsBuilder builder_(_fbb); + builder_.add_filter_height(filter_height); + builder_.add_filter_width(filter_width); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +::flatbuffers::Offset CreatePool2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DepthwiseConv2DOptionsT : public ::flatbuffers::NativeTable { + typedef DepthwiseConv2DOptions TableType; + tflite::Padding padding = tflite::Padding_SAME; + int32_t stride_w = 0; + int32_t stride_h = 0; + int32_t depth_multiplier = 0; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + int32_t dilation_w_factor = 1; + int32_t dilation_h_factor = 1; +}; + +struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DepthwiseConv2DOptionsT NativeTableType; + typedef DepthwiseConv2DOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_DEPTH_MULTIPLIER = 10, + VT_FUSED_ACTIVATION_FUNCTION = 12, + VT_DILATION_W_FACTOR = 14, + VT_DILATION_H_FACTOR = 16 + }; + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + int32_t depth_multiplier() const { + return GetField(VT_DEPTH_MULTIPLIER, 0); + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + int32_t dilation_w_factor() const { + return GetField(VT_DILATION_W_FACTOR, 
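// The Create* helpers above are thin wrappers over the *Builder structs; the
// builder can also be used directly, as sketched here for Pool2DOptions. Each
// add_* setter may be called at most once, and no other table, vector, or
// string may be started on the same FlatBufferBuilder between the builder's
// construction (StartTable) and Finish (EndTable).
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // path assumed

inline flatbuffers::Offset<tflite::Pool2DOptions>
BuildExampleAveragePoolOptions(flatbuffers::FlatBufferBuilder &fbb) {
  tflite::Pool2DOptionsBuilder builder(fbb);
  builder.add_padding(tflite::Padding_VALID);
  builder.add_stride_w(2);
  builder.add_stride_h(2);
  builder.add_filter_width(2);
  builder.add_filter_height(2);
  // fused_activation_function is left at its default (NONE), so no bytes are
  // emitted for that field.
  return builder.Finish();
}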
1); + } + int32_t dilation_h_factor() const { + return GetField(VT_DILATION_H_FACTOR, 1); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING, 1) && + VerifyField(verifier, VT_STRIDE_W, 4) && + VerifyField(verifier, VT_STRIDE_H, 4) && + VerifyField(verifier, VT_DEPTH_MULTIPLIER, 4) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_DILATION_W_FACTOR, 4) && + VerifyField(verifier, VT_DILATION_H_FACTOR, 4) && + verifier.EndTable(); + } + DepthwiseConv2DOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DepthwiseConv2DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DepthwiseConv2DOptionsBuilder { + typedef DepthwiseConv2DOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(tflite::Padding padding) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_H, stride_h, 0); + } + void add_depth_multiplier(int32_t depth_multiplier) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, depth_multiplier, 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_dilation_w_factor(int32_t dilation_w_factor) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1); + } + void add_dilation_h_factor(int32_t dilation_h_factor) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1); + } + explicit DepthwiseConv2DOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDepthwiseConv2DOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::Padding padding = tflite::Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t depth_multiplier = 0, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + int32_t dilation_w_factor = 1, + int32_t dilation_h_factor = 1) { + DepthwiseConv2DOptionsBuilder builder_(_fbb); + builder_.add_dilation_h_factor(dilation_h_factor); + builder_.add_dilation_w_factor(dilation_w_factor); + builder_.add_depth_multiplier(depth_multiplier); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateDepthwiseConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConcatEmbeddingsOptionsT : public ::flatbuffers::NativeTable { + typedef ConcatEmbeddingsOptions TableType; + int32_t num_channels = 0; + 
std::vector num_columns_per_channel{}; + std::vector embedding_dim_per_channel{}; +}; + +struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ConcatEmbeddingsOptionsT NativeTableType; + typedef ConcatEmbeddingsOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NUM_CHANNELS = 4, + VT_NUM_COLUMNS_PER_CHANNEL = 6, + VT_EMBEDDING_DIM_PER_CHANNEL = 8 + }; + int32_t num_channels() const { + return GetField(VT_NUM_CHANNELS, 0); + } + const ::flatbuffers::Vector *num_columns_per_channel() const { + return GetPointer *>(VT_NUM_COLUMNS_PER_CHANNEL); + } + const ::flatbuffers::Vector *embedding_dim_per_channel() const { + return GetPointer *>(VT_EMBEDDING_DIM_PER_CHANNEL); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_CHANNELS, 4) && + VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) && + verifier.VerifyVector(num_columns_per_channel()) && + VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) && + verifier.VerifyVector(embedding_dim_per_channel()) && + verifier.EndTable(); + } + ConcatEmbeddingsOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatEmbeddingsOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConcatEmbeddingsOptionsBuilder { + typedef ConcatEmbeddingsOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num_channels(int32_t num_channels) { + fbb_.AddElement(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, num_channels, 0); + } + void add_num_columns_per_channel(::flatbuffers::Offset<::flatbuffers::Vector> num_columns_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, num_columns_per_channel); + } + void add_embedding_dim_per_channel(::flatbuffers::Offset<::flatbuffers::Vector> embedding_dim_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, embedding_dim_per_channel); + } + explicit ConcatEmbeddingsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConcatEmbeddingsOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> num_columns_per_channel = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> embedding_dim_per_channel = 0) { + ConcatEmbeddingsOptionsBuilder builder_(_fbb); + builder_.add_embedding_dim_per_channel(embedding_dim_per_channel); + builder_.add_num_columns_per_channel(num_columns_per_channel); + builder_.add_num_channels(num_channels); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateConcatEmbeddingsOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + const std::vector *num_columns_per_channel = nullptr, + const std::vector *embedding_dim_per_channel = nullptr) { + auto num_columns_per_channel__ = num_columns_per_channel ? _fbb.CreateVector(*num_columns_per_channel) : 0; + auto embedding_dim_per_channel__ = embedding_dim_per_channel ? 
_fbb.CreateVector(*embedding_dim_per_channel) : 0; + return tflite::CreateConcatEmbeddingsOptions( + _fbb, + num_channels, + num_columns_per_channel__, + embedding_dim_per_channel__); +} + +::flatbuffers::Offset CreateConcatEmbeddingsOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LSHProjectionOptionsT : public ::flatbuffers::NativeTable { + typedef LSHProjectionOptions TableType; + tflite::LSHProjectionType type = tflite::LSHProjectionType_UNKNOWN; +}; + +struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LSHProjectionOptionsT NativeTableType; + typedef LSHProjectionOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4 + }; + tflite::LSHProjectionType type() const { + return static_cast(GetField(VT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TYPE, 1) && + verifier.EndTable(); + } + LSHProjectionOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSHProjectionOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LSHProjectionOptionsBuilder { + typedef LSHProjectionOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_type(tflite::LSHProjectionType type) { + fbb_.AddElement(LSHProjectionOptions::VT_TYPE, static_cast(type), 0); + } + explicit LSHProjectionOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLSHProjectionOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::LSHProjectionType type = tflite::LSHProjectionType_UNKNOWN) { + LSHProjectionOptionsBuilder builder_(_fbb); + builder_.add_type(type); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLSHProjectionOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SVDFOptionsT : public ::flatbuffers::NativeTable { + typedef SVDFOptions TableType; + int32_t rank = 0; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + bool asymmetric_quantize_inputs = false; +}; + +struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SVDFOptionsT NativeTableType; + typedef SVDFOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_RANK = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 + }; + int32_t rank() const { + return GetField(VT_RANK, 0); + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_RANK, 4) && + VerifyField(verifier, 
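// Sketch for CreateConcatEmbeddingsOptionsDirect above. The int32_t element type
// of the two per-channel vectors is assumed (the vector template arguments are
// not visible in this view); the values are purely illustrative.
#include <cstdint>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // path assumed

inline flatbuffers::Offset<tflite::ConcatEmbeddingsOptions>
BuildExampleConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &fbb) {
  std::vector<int32_t> num_columns_per_channel{16, 16};
  std::vector<int32_t> embedding_dim_per_channel{8, 8};
  return tflite::CreateConcatEmbeddingsOptionsDirect(
      fbb, /*num_channels=*/2,
      &num_columns_per_channel, &embedding_dim_per_channel);
}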
VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + verifier.EndTable(); + } + SVDFOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SVDFOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SVDFOptionsBuilder { + typedef SVDFOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_rank(int32_t rank) { + fbb_.AddElement(SVDFOptions::VT_RANK, rank, 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(SVDFOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + explicit SVDFOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSVDFOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t rank = 0, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { + SVDFOptionsBuilder builder_(_fbb); + builder_.add_rank(rank); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSVDFOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RNNOptionsT : public ::flatbuffers::NativeTable { + typedef RNNOptions TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + bool asymmetric_quantize_inputs = false; +}; + +struct RNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef RNNOptionsT NativeTableType; + typedef RNNOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 6 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + verifier.EndTable(); + } + RNNOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RNNOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RNNOptionsBuilder { + typedef RNNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void 
add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(RNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + explicit RNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateRNNOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { + RNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SequenceRNNOptionsT : public ::flatbuffers::NativeTable { + typedef SequenceRNNOptions TableType; + bool time_major = false; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + bool asymmetric_quantize_inputs = false; +}; + +struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SequenceRNNOptionsT NativeTableType; + typedef SequenceRNNOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 + }; + bool time_major() const { + return GetField(VT_TIME_MAJOR, 0) != 0; + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TIME_MAJOR, 1) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + verifier.EndTable(); + } + SequenceRNNOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SequenceRNNOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SequenceRNNOptionsBuilder { + typedef SequenceRNNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_time_major(bool time_major) { + fbb_.AddElement(SequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(SequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + 
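// Each table above also has an object ("T"-suffixed) form for mutation:
// UnPack() copies a packed table into the NativeTable struct, and the
// Create*(fbb, obj) overload declared above re-packs it. A minimal round trip
// for RNNOptions, assuming the TANH activation enumerator from the schema:
#include <memory>
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // path assumed

inline void RNNOptionsRoundTripExample() {
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(tflite::CreateRNNOptions(fbb, tflite::ActivationFunctionType_TANH));
  auto *packed = flatbuffers::GetRoot<tflite::RNNOptions>(fbb.GetBufferPointer());

  std::unique_ptr<tflite::RNNOptionsT> obj(packed->UnPack());  // caller owns the copy
  obj->asymmetric_quantize_inputs = true;                      // edit in object form

  flatbuffers::FlatBufferBuilder repacked;
  repacked.Finish(tflite::CreateRNNOptions(repacked, obj.get()));
}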
explicit SequenceRNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSequenceRNNOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool time_major = false, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool asymmetric_quantize_inputs = false) { + SequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_time_major(time_major); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSequenceRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BidirectionalSequenceRNNOptionsT : public ::flatbuffers::NativeTable { + typedef BidirectionalSequenceRNNOptions TableType; + bool time_major = false; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + bool merge_outputs = false; + bool asymmetric_quantize_inputs = false; +}; + +struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BidirectionalSequenceRNNOptionsT NativeTableType; + typedef BidirectionalSequenceRNNOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_MERGE_OUTPUTS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10 + }; + bool time_major() const { + return GetField(VT_TIME_MAJOR, 0) != 0; + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool merge_outputs() const { + return GetField(VT_MERGE_OUTPUTS, 0) != 0; + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TIME_MAJOR, 1) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_MERGE_OUTPUTS, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + verifier.EndTable(); + } + BidirectionalSequenceRNNOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BidirectionalSequenceRNNOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BidirectionalSequenceRNNOptionsBuilder { + typedef BidirectionalSequenceRNNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_time_major(bool time_major) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_merge_outputs(bool merge_outputs) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, 
static_cast(merge_outputs), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + explicit BidirectionalSequenceRNNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBidirectionalSequenceRNNOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool time_major = false, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool merge_outputs = false, + bool asymmetric_quantize_inputs = false) { + BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_merge_outputs(merge_outputs); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_time_major(time_major); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBidirectionalSequenceRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FullyConnectedOptionsT : public ::flatbuffers::NativeTable { + typedef FullyConnectedOptions TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; + bool keep_num_dims = false; + bool asymmetric_quantize_inputs = false; + tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32; +}; + +struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef FullyConnectedOptionsT NativeTableType; + typedef FullyConnectedOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_WEIGHTS_FORMAT = 6, + VT_KEEP_NUM_DIMS = 8, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 10, + VT_QUANTIZED_BIAS_TYPE = 12 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + tflite::FullyConnectedOptionsWeightsFormat weights_format() const { + return static_cast(GetField(VT_WEIGHTS_FORMAT, 0)); + } + bool keep_num_dims() const { + return GetField(VT_KEEP_NUM_DIMS, 0) != 0; + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + tflite::TensorType quantized_bias_type() const { + return static_cast(GetField(VT_QUANTIZED_BIAS_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_WEIGHTS_FORMAT, 1) && + VerifyField(verifier, VT_KEEP_NUM_DIMS, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + VerifyField(verifier, VT_QUANTIZED_BIAS_TYPE, 1) && + verifier.EndTable(); + } + FullyConnectedOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FullyConnectedOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const 
::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FullyConnectedOptionsBuilder { + typedef FullyConnectedOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_weights_format(tflite::FullyConnectedOptionsWeightsFormat weights_format) { + fbb_.AddElement(FullyConnectedOptions::VT_WEIGHTS_FORMAT, static_cast(weights_format), 0); + } + void add_keep_num_dims(bool keep_num_dims) { + fbb_.AddElement(FullyConnectedOptions::VT_KEEP_NUM_DIMS, static_cast(keep_num_dims), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(FullyConnectedOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + void add_quantized_bias_type(tflite::TensorType quantized_bias_type) { + fbb_.AddElement(FullyConnectedOptions::VT_QUANTIZED_BIAS_TYPE, static_cast(quantized_bias_type), 0); + } + explicit FullyConnectedOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateFullyConnectedOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + tflite::FullyConnectedOptionsWeightsFormat weights_format = tflite::FullyConnectedOptionsWeightsFormat_DEFAULT, + bool keep_num_dims = false, + bool asymmetric_quantize_inputs = false, + tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32) { + FullyConnectedOptionsBuilder builder_(_fbb); + builder_.add_quantized_bias_type(quantized_bias_type); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_keep_num_dims(keep_num_dims); + builder_.add_weights_format(weights_format); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateFullyConnectedOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SoftmaxOptionsT : public ::flatbuffers::NativeTable { + typedef SoftmaxOptions TableType; + float beta = 0.0f; +}; + +struct SoftmaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SoftmaxOptionsT NativeTableType; + typedef SoftmaxOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BETA = 4 + }; + float beta() const { + return GetField(VT_BETA, 0.0f); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BETA, 4) && + verifier.EndTable(); + } + SoftmaxOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SoftmaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SoftmaxOptionsBuilder { + typedef SoftmaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void 
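// Sketch for CreateFullyConnectedOptions above. TensorType_INT32 is used here
// only as an illustrative quantized_bias_type; whether that combination is
// valid for a given kernel is outside the scope of this header.
#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // path assumed

inline flatbuffers::Offset<tflite::FullyConnectedOptions>
BuildExampleFullyConnectedOptions(flatbuffers::FlatBufferBuilder &fbb) {
  return tflite::CreateFullyConnectedOptions(
      fbb,
      tflite::ActivationFunctionType_RELU,
      tflite::FullyConnectedOptionsWeightsFormat_DEFAULT,
      /*keep_num_dims=*/false,
      /*asymmetric_quantize_inputs=*/false,
      /*quantized_bias_type=*/tflite::TensorType_INT32);
}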
add_beta(float beta) { + fbb_.AddElement(SoftmaxOptions::VT_BETA, beta, 0.0f); + } + explicit SoftmaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSoftmaxOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + float beta = 0.0f) { + SoftmaxOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ConcatenationOptionsT : public ::flatbuffers::NativeTable { + typedef ConcatenationOptions TableType; + int32_t axis = 0; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; +}; + +struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ConcatenationOptionsT NativeTableType; + typedef ConcatenationOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_AXIS = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS, 4) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + verifier.EndTable(); + } + ConcatenationOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatenationOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ConcatenationOptionsBuilder { + typedef ConcatenationOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(ConcatenationOptions::VT_AXIS, axis, 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit ConcatenationOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateConcatenationOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + ConcatenationOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateConcatenationOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AddOptionsT : public ::flatbuffers::NativeTable { + typedef AddOptions TableType; + tflite::ActivationFunctionType 
fused_activation_function = tflite::ActivationFunctionType_NONE; + bool pot_scale_int16 = true; +}; + +struct AddOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef AddOptionsT NativeTableType; + typedef AddOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_POT_SCALE_INT16 = 6 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool pot_scale_int16() const { + return GetField(VT_POT_SCALE_INT16, 1) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_POT_SCALE_INT16, 1) && + verifier.EndTable(); + } + AddOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AddOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AddOptionsBuilder { + typedef AddOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_pot_scale_int16(bool pot_scale_int16) { + fbb_.AddElement(AddOptions::VT_POT_SCALE_INT16, static_cast(pot_scale_int16), 1); + } + explicit AddOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateAddOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool pot_scale_int16 = true) { + AddOptionsBuilder builder_(_fbb); + builder_.add_pot_scale_int16(pot_scale_int16); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateAddOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MulOptionsT : public ::flatbuffers::NativeTable { + typedef MulOptions TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; +}; + +struct MulOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef MulOptionsT NativeTableType; + typedef MulOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + verifier.EndTable(); + } + MulOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MulOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset 
Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MulOptionsBuilder { + typedef MulOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit MulOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateMulOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + MulOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateMulOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct L2NormOptionsT : public ::flatbuffers::NativeTable { + typedef L2NormOptions TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; +}; + +struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef L2NormOptionsT NativeTableType; + typedef L2NormOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + verifier.EndTable(); + } + L2NormOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(L2NormOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct L2NormOptionsBuilder { + typedef L2NormOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit L2NormOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateL2NormOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + L2NormOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateL2NormOptions(::flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LocalResponseNormalizationOptionsT 
: public ::flatbuffers::NativeTable { + typedef LocalResponseNormalizationOptions TableType; + int32_t radius = 0; + float bias = 0.0f; + float alpha = 0.0f; + float beta = 0.0f; +}; + +struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LocalResponseNormalizationOptionsT NativeTableType; + typedef LocalResponseNormalizationOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_RADIUS = 4, + VT_BIAS = 6, + VT_ALPHA = 8, + VT_BETA = 10 + }; + int32_t radius() const { + return GetField(VT_RADIUS, 0); + } + float bias() const { + return GetField(VT_BIAS, 0.0f); + } + float alpha() const { + return GetField(VT_ALPHA, 0.0f); + } + float beta() const { + return GetField(VT_BETA, 0.0f); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_RADIUS, 4) && + VerifyField(verifier, VT_BIAS, 4) && + VerifyField(verifier, VT_ALPHA, 4) && + VerifyField(verifier, VT_BETA, 4) && + verifier.EndTable(); + } + LocalResponseNormalizationOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LocalResponseNormalizationOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LocalResponseNormalizationOptionsBuilder { + typedef LocalResponseNormalizationOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_radius(int32_t radius) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_RADIUS, radius, 0); + } + void add_bias(float bias) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BIAS, bias, 0.0f); + } + void add_alpha(float alpha) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_ALPHA, alpha, 0.0f); + } + void add_beta(float beta) { + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BETA, beta, 0.0f); + } + explicit LocalResponseNormalizationOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLocalResponseNormalizationOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t radius = 0, + float bias = 0.0f, + float alpha = 0.0f, + float beta = 0.0f) { + LocalResponseNormalizationOptionsBuilder builder_(_fbb); + builder_.add_beta(beta); + builder_.add_alpha(alpha); + builder_.add_bias(bias); + builder_.add_radius(radius); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLocalResponseNormalizationOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LSTMOptionsT : public ::flatbuffers::NativeTable { + typedef LSTMOptions TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + float cell_clip = 0.0f; + float proj_clip = 0.0f; + tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL; + bool asymmetric_quantize_inputs = false; +}; + +struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LSTMOptionsT NativeTableType; + typedef LSTMOptionsBuilder Builder; + 
enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8, + VT_KERNEL_TYPE = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField(VT_PROJ_CLIP, 0.0f); + } + tflite::LSTMKernelType kernel_type() const { + return static_cast(GetField(VT_KERNEL_TYPE, 0)); + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_CELL_CLIP, 4) && + VerifyField(verifier, VT_PROJ_CLIP, 4) && + VerifyField(verifier, VT_KERNEL_TYPE, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + verifier.EndTable(); + } + LSTMOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSTMOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LSTMOptionsBuilder { + typedef LSTMOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) { + fbb_.AddElement(LSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) { + fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + void add_kernel_type(tflite::LSTMKernelType kernel_type) { + fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(LSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + explicit LSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLSTMOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f, + tflite::LSTMKernelType kernel_type = tflite::LSTMKernelType_FULL, + bool asymmetric_quantize_inputs = false) { + LSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_kernel_type(kernel_type); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct UnidirectionalSequenceLSTMOptionsT : public ::flatbuffers::NativeTable { + typedef UnidirectionalSequenceLSTMOptions 
TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + float cell_clip = 0.0f; + float proj_clip = 0.0f; + bool time_major = false; + bool asymmetric_quantize_inputs = false; + bool diagonal_recurrent_tensors = false; +}; + +struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UnidirectionalSequenceLSTMOptionsT NativeTableType; + typedef UnidirectionalSequenceLSTMOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8, + VT_TIME_MAJOR = 10, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 12, + VT_DIAGONAL_RECURRENT_TENSORS = 14 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField(VT_PROJ_CLIP, 0.0f); + } + bool time_major() const { + return GetField(VT_TIME_MAJOR, 0) != 0; + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool diagonal_recurrent_tensors() const { + return GetField(VT_DIAGONAL_RECURRENT_TENSORS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_CELL_CLIP, 4) && + VerifyField(verifier, VT_PROJ_CLIP, 4) && + VerifyField(verifier, VT_TIME_MAJOR, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + VerifyField(verifier, VT_DIAGONAL_RECURRENT_TENSORS, 1) && + verifier.EndTable(); + } + UnidirectionalSequenceLSTMOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UnidirectionalSequenceLSTMOptionsBuilder { + typedef UnidirectionalSequenceLSTMOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + void add_time_major(bool time_major) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast(time_major), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + void add_diagonal_recurrent_tensors(bool diagonal_recurrent_tensors) { + fbb_.AddElement(UnidirectionalSequenceLSTMOptions::VT_DIAGONAL_RECURRENT_TENSORS, static_cast(diagonal_recurrent_tensors), 0); + } + explicit UnidirectionalSequenceLSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + 
::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUnidirectionalSequenceLSTMOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f, + bool time_major = false, + bool asymmetric_quantize_inputs = false, + bool diagonal_recurrent_tensors = false) { + UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_diagonal_recurrent_tensors(diagonal_recurrent_tensors); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_time_major(time_major); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateUnidirectionalSequenceLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BidirectionalSequenceLSTMOptionsT : public ::flatbuffers::NativeTable { + typedef BidirectionalSequenceLSTMOptions TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + float cell_clip = 0.0f; + float proj_clip = 0.0f; + bool merge_outputs = false; + bool time_major = true; + bool asymmetric_quantize_inputs = false; +}; + +struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BidirectionalSequenceLSTMOptionsT NativeTableType; + typedef BidirectionalSequenceLSTMOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8, + VT_MERGE_OUTPUTS = 10, + VT_TIME_MAJOR = 12, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 14 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField(VT_PROJ_CLIP, 0.0f); + } + bool merge_outputs() const { + return GetField(VT_MERGE_OUTPUTS, 0) != 0; + } + bool time_major() const { + return GetField(VT_TIME_MAJOR, 1) != 0; + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_CELL_CLIP, 4) && + VerifyField(verifier, VT_PROJ_CLIP, 4) && + VerifyField(verifier, VT_MERGE_OUTPUTS, 1) && + VerifyField(verifier, VT_TIME_MAJOR, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + verifier.EndTable(); + } + BidirectionalSequenceLSTMOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BidirectionalSequenceLSTMOptionsBuilder { + typedef BidirectionalSequenceLSTMOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t 
start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + void add_merge_outputs(bool merge_outputs) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_MERGE_OUTPUTS, static_cast(merge_outputs), 0); + } + void add_time_major(bool time_major) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_TIME_MAJOR, static_cast(time_major), 1); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(BidirectionalSequenceLSTMOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + explicit BidirectionalSequenceLSTMOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBidirectionalSequenceLSTMOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f, + bool merge_outputs = false, + bool time_major = true, + bool asymmetric_quantize_inputs = false) { + BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_time_major(time_major); + builder_.add_merge_outputs(merge_outputs); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBidirectionalSequenceLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ResizeBilinearOptionsT : public ::flatbuffers::NativeTable { + typedef ResizeBilinearOptions TableType; + bool align_corners = false; + bool half_pixel_centers = false; +}; + +struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ResizeBilinearOptionsT NativeTableType; + typedef ResizeBilinearOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ALIGN_CORNERS = 8, + VT_HALF_PIXEL_CENTERS = 10 + }; + bool align_corners() const { + return GetField(VT_ALIGN_CORNERS, 0) != 0; + } + bool half_pixel_centers() const { + return GetField(VT_HALF_PIXEL_CENTERS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ALIGN_CORNERS, 1) && + VerifyField(verifier, VT_HALF_PIXEL_CENTERS, 1) && + verifier.EndTable(); + } + ResizeBilinearOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ResizeBilinearOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct 
ResizeBilinearOptionsBuilder { + typedef ResizeBilinearOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_align_corners(bool align_corners) { + fbb_.AddElement(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast(align_corners), 0); + } + void add_half_pixel_centers(bool half_pixel_centers) { + fbb_.AddElement(ResizeBilinearOptions::VT_HALF_PIXEL_CENTERS, static_cast(half_pixel_centers), 0); + } + explicit ResizeBilinearOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateResizeBilinearOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool align_corners = false, + bool half_pixel_centers = false) { + ResizeBilinearOptionsBuilder builder_(_fbb); + builder_.add_half_pixel_centers(half_pixel_centers); + builder_.add_align_corners(align_corners); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateResizeBilinearOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ResizeNearestNeighborOptionsT : public ::flatbuffers::NativeTable { + typedef ResizeNearestNeighborOptions TableType; + bool align_corners = false; + bool half_pixel_centers = false; +}; + +struct ResizeNearestNeighborOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ResizeNearestNeighborOptionsT NativeTableType; + typedef ResizeNearestNeighborOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ALIGN_CORNERS = 4, + VT_HALF_PIXEL_CENTERS = 6 + }; + bool align_corners() const { + return GetField(VT_ALIGN_CORNERS, 0) != 0; + } + bool half_pixel_centers() const { + return GetField(VT_HALF_PIXEL_CENTERS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ALIGN_CORNERS, 1) && + VerifyField(verifier, VT_HALF_PIXEL_CENTERS, 1) && + verifier.EndTable(); + } + ResizeNearestNeighborOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ResizeNearestNeighborOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeNearestNeighborOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ResizeNearestNeighborOptionsBuilder { + typedef ResizeNearestNeighborOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_align_corners(bool align_corners) { + fbb_.AddElement(ResizeNearestNeighborOptions::VT_ALIGN_CORNERS, static_cast(align_corners), 0); + } + void add_half_pixel_centers(bool half_pixel_centers) { + fbb_.AddElement(ResizeNearestNeighborOptions::VT_HALF_PIXEL_CENTERS, static_cast(half_pixel_centers), 0); + } + explicit ResizeNearestNeighborOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateResizeNearestNeighborOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool align_corners = false, + bool half_pixel_centers = false) { + 
ResizeNearestNeighborOptionsBuilder builder_(_fbb);
+  builder_.add_half_pixel_centers(half_pixel_centers);
+  builder_.add_align_corners(align_corners);
+  return builder_.Finish();
+}
+
+::flatbuffers::Offset<ResizeNearestNeighborOptions> CreateResizeNearestNeighborOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeNearestNeighborOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct CallOptionsT : public ::flatbuffers::NativeTable {
+  typedef CallOptions TableType;
+  uint32_t subgraph = 0;
+};
+
+struct CallOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef CallOptionsT NativeTableType;
+  typedef CallOptionsBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_SUBGRAPH = 4
+  };
+  uint32_t subgraph() const {
+    return GetField<uint32_t>(VT_SUBGRAPH, 0);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<uint32_t>(verifier, VT_SUBGRAPH, 4) &&
+           verifier.EndTable();
+  }
+  CallOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(CallOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<CallOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct CallOptionsBuilder {
+  typedef CallOptions Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_subgraph(uint32_t subgraph) {
+    fbb_.AddElement<uint32_t>(CallOptions::VT_SUBGRAPH, subgraph, 0);
+  }
+  explicit CallOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<CallOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<CallOptions>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<CallOptions> CreateCallOptions(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    uint32_t subgraph = 0) {
+  CallOptionsBuilder builder_(_fbb);
+  builder_.add_subgraph(subgraph);
+  return builder_.Finish();
+}
+
+::flatbuffers::Offset<CallOptions> CreateCallOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PadOptionsT : public ::flatbuffers::NativeTable {
+  typedef PadOptions TableType;
+};
+
+struct PadOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef PadOptionsT NativeTableType;
+  typedef PadOptionsBuilder Builder;
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  PadOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(PadOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<PadOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PadOptionsBuilder {
+  typedef PadOptions Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  explicit PadOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<PadOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<PadOptions>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<PadOptions> CreatePadOptions(
+    ::flatbuffers::FlatBufferBuilder &_fbb) {
+  PadOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
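A minimal usage sketch (illustrative only — not flatc output and not part of this patch), assuming an existing ::flatbuffers::FlatBufferBuilder, showing how the Create*Options helpers declared in this generated header are typically called:

  ::flatbuffers::FlatBufferBuilder fbb;
  // Field-less tables such as PadOptions take only the builder.
  auto pad_opts = tflite::CreatePadOptions(fbb);
  // Scalar fields are exposed as defaulted parameters.
  auto call_opts = tflite::CreateCallOptions(fbb, /*subgraph=*/1);
  auto softmax_opts = tflite::CreateSoftmaxOptions(fbb, /*beta=*/1.0f);
  // Each helper returns a ::flatbuffers::Offset<T> that is later stored in the
  // enclosing Operator table before the buffer is finished.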
+::flatbuffers::Offset CreatePadOptions(::flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PadV2OptionsT : public ::flatbuffers::NativeTable { + typedef PadV2Options TableType; +}; + +struct PadV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef PadV2OptionsT NativeTableType; + typedef PadV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + PadV2OptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PadV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PadV2OptionsBuilder { + typedef PadV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit PadV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreatePadV2Options( + ::flatbuffers::FlatBufferBuilder &_fbb) { + PadV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreatePadV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReshapeOptionsT : public ::flatbuffers::NativeTable { + typedef ReshapeOptions TableType; + std::vector new_shape{}; +}; + +struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ReshapeOptionsT NativeTableType; + typedef ReshapeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NEW_SHAPE = 4 + }; + const ::flatbuffers::Vector *new_shape() const { + return GetPointer *>(VT_NEW_SHAPE); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NEW_SHAPE) && + verifier.VerifyVector(new_shape()) && + verifier.EndTable(); + } + ReshapeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReshapeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReshapeOptionsBuilder { + typedef ReshapeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_new_shape(::flatbuffers::Offset<::flatbuffers::Vector> new_shape) { + fbb_.AddOffset(ReshapeOptions::VT_NEW_SHAPE, new_shape); + } + explicit ReshapeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateReshapeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> new_shape = 0) { + ReshapeOptionsBuilder builder_(_fbb); + builder_.add_new_shape(new_shape); + return builder_.Finish(); +} + +inline 
::flatbuffers::Offset CreateReshapeOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *new_shape = nullptr) { + auto new_shape__ = new_shape ? _fbb.CreateVector(*new_shape) : 0; + return tflite::CreateReshapeOptions( + _fbb, + new_shape__); +} + +::flatbuffers::Offset CreateReshapeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SpaceToBatchNDOptionsT : public ::flatbuffers::NativeTable { + typedef SpaceToBatchNDOptions TableType; +}; + +struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SpaceToBatchNDOptionsT NativeTableType; + typedef SpaceToBatchNDOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SpaceToBatchNDOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SpaceToBatchNDOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SpaceToBatchNDOptionsBuilder { + typedef SpaceToBatchNDOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SpaceToBatchNDOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSpaceToBatchNDOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SpaceToBatchNDOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSpaceToBatchNDOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BatchToSpaceNDOptionsT : public ::flatbuffers::NativeTable { + typedef BatchToSpaceNDOptions TableType; +}; + +struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BatchToSpaceNDOptionsT NativeTableType; + typedef BatchToSpaceNDOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BatchToSpaceNDOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BatchToSpaceNDOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BatchToSpaceNDOptionsBuilder { + typedef BatchToSpaceNDOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit BatchToSpaceNDOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBatchToSpaceNDOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + BatchToSpaceNDOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset 
CreateBatchToSpaceNDOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SkipGramOptionsT : public ::flatbuffers::NativeTable { + typedef SkipGramOptions TableType; + int32_t ngram_size = 0; + int32_t max_skip_size = 0; + bool include_all_ngrams = false; +}; + +struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SkipGramOptionsT NativeTableType; + typedef SkipGramOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NGRAM_SIZE = 4, + VT_MAX_SKIP_SIZE = 6, + VT_INCLUDE_ALL_NGRAMS = 8 + }; + int32_t ngram_size() const { + return GetField(VT_NGRAM_SIZE, 0); + } + int32_t max_skip_size() const { + return GetField(VT_MAX_SKIP_SIZE, 0); + } + bool include_all_ngrams() const { + return GetField(VT_INCLUDE_ALL_NGRAMS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NGRAM_SIZE, 4) && + VerifyField(verifier, VT_MAX_SKIP_SIZE, 4) && + VerifyField(verifier, VT_INCLUDE_ALL_NGRAMS, 1) && + verifier.EndTable(); + } + SkipGramOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SkipGramOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SkipGramOptionsBuilder { + typedef SkipGramOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_ngram_size(int32_t ngram_size) { + fbb_.AddElement(SkipGramOptions::VT_NGRAM_SIZE, ngram_size, 0); + } + void add_max_skip_size(int32_t max_skip_size) { + fbb_.AddElement(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, 0); + } + void add_include_all_ngrams(bool include_all_ngrams) { + fbb_.AddElement(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, static_cast(include_all_ngrams), 0); + } + explicit SkipGramOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSkipGramOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t ngram_size = 0, + int32_t max_skip_size = 0, + bool include_all_ngrams = false) { + SkipGramOptionsBuilder builder_(_fbb); + builder_.add_max_skip_size(max_skip_size); + builder_.add_ngram_size(ngram_size); + builder_.add_include_all_ngrams(include_all_ngrams); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSkipGramOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SpaceToDepthOptionsT : public ::flatbuffers::NativeTable { + typedef SpaceToDepthOptions TableType; + int32_t block_size = 0; +}; + +struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SpaceToDepthOptionsT NativeTableType; + typedef SpaceToDepthOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BLOCK_SIZE = 4 + }; + int32_t block_size() const { + return GetField(VT_BLOCK_SIZE, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, 
VT_BLOCK_SIZE, 4) &&
+           verifier.EndTable();
+  }
+  SpaceToDepthOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(SpaceToDepthOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<SpaceToDepthOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SpaceToDepthOptionsBuilder {
+  typedef SpaceToDepthOptions Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_block_size(int32_t block_size) {
+    fbb_.AddElement<int32_t>(SpaceToDepthOptions::VT_BLOCK_SIZE, block_size, 0);
+  }
+  explicit SpaceToDepthOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<SpaceToDepthOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<SpaceToDepthOptions>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<SpaceToDepthOptions> CreateSpaceToDepthOptions(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t block_size = 0) {
+  SpaceToDepthOptionsBuilder builder_(_fbb);
+  builder_.add_block_size(block_size);
+  return builder_.Finish();
+}
+
+::flatbuffers::Offset<SpaceToDepthOptions> CreateSpaceToDepthOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct DepthToSpaceOptionsT : public ::flatbuffers::NativeTable {
+  typedef DepthToSpaceOptions TableType;
+  int32_t block_size = 0;
+};
+
+struct DepthToSpaceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table {
+  typedef DepthToSpaceOptionsT NativeTableType;
+  typedef DepthToSpaceOptionsBuilder Builder;
+  enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+    VT_BLOCK_SIZE = 4
+  };
+  int32_t block_size() const {
+    return GetField<int32_t>(VT_BLOCK_SIZE, 0);
+  }
+  bool Verify(::flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           VerifyField<int32_t>(verifier, VT_BLOCK_SIZE, 4) &&
+           verifier.EndTable();
+  }
+  DepthToSpaceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(DepthToSpaceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static ::flatbuffers::Offset<DepthToSpaceOptions> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DepthToSpaceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct DepthToSpaceOptionsBuilder {
+  typedef DepthToSpaceOptions Table;
+  ::flatbuffers::FlatBufferBuilder &fbb_;
+  ::flatbuffers::uoffset_t start_;
+  void add_block_size(int32_t block_size) {
+    fbb_.AddElement<int32_t>(DepthToSpaceOptions::VT_BLOCK_SIZE, block_size, 0);
+  }
+  explicit DepthToSpaceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ::flatbuffers::Offset<DepthToSpaceOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = ::flatbuffers::Offset<DepthToSpaceOptions>(end);
+    return o;
+  }
+};
+
+inline ::flatbuffers::Offset<DepthToSpaceOptions> CreateDepthToSpaceOptions(
+    ::flatbuffers::FlatBufferBuilder &_fbb,
+    int32_t block_size = 0) {
+  DepthToSpaceOptionsBuilder builder_(_fbb);
+  builder_.add_block_size(block_size);
+  return builder_.Finish();
+}
+
+::flatbuffers::Offset<DepthToSpaceOptions> CreateDepthToSpaceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DepthToSpaceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct SubOptionsT : public ::flatbuffers::NativeTable {
+  typedef SubOptions TableType;
+  tflite::ActivationFunctionType
fused_activation_function = tflite::ActivationFunctionType_NONE; + bool pot_scale_int16 = true; +}; + +struct SubOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SubOptionsT NativeTableType; + typedef SubOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_POT_SCALE_INT16 = 6 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool pot_scale_int16() const { + return GetField(VT_POT_SCALE_INT16, 1) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_POT_SCALE_INT16, 1) && + verifier.EndTable(); + } + SubOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SubOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SubOptionsBuilder { + typedef SubOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SubOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_pot_scale_int16(bool pot_scale_int16) { + fbb_.AddElement(SubOptions::VT_POT_SCALE_INT16, static_cast(pot_scale_int16), 1); + } + explicit SubOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSubOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + bool pot_scale_int16 = true) { + SubOptionsBuilder builder_(_fbb); + builder_.add_pot_scale_int16(pot_scale_int16); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSubOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DivOptionsT : public ::flatbuffers::NativeTable { + typedef DivOptions TableType; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; +}; + +struct DivOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DivOptionsT NativeTableType; + typedef DivOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + verifier.EndTable(); + } + DivOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DivOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset 
Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DivOptionsBuilder { + typedef DivOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_fused_activation_function(tflite::ActivationFunctionType fused_activation_function) { + fbb_.AddElement(DivOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit DivOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDivOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE) { + DivOptionsBuilder builder_(_fbb); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateDivOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TopKV2OptionsT : public ::flatbuffers::NativeTable { + typedef TopKV2Options TableType; +}; + +struct TopKV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef TopKV2OptionsT NativeTableType; + typedef TopKV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TopKV2OptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TopKV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TopKV2OptionsBuilder { + typedef TopKV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit TopKV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateTopKV2Options( + ::flatbuffers::FlatBufferBuilder &_fbb) { + TopKV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateTopKV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct EmbeddingLookupSparseOptionsT : public ::flatbuffers::NativeTable { + typedef EmbeddingLookupSparseOptions TableType; + tflite::CombinerType combiner = tflite::CombinerType_SUM; +}; + +struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef EmbeddingLookupSparseOptionsT NativeTableType; + typedef EmbeddingLookupSparseOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COMBINER = 4 + }; + tflite::CombinerType combiner() const { + return static_cast(GetField(VT_COMBINER, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_COMBINER, 1) && + verifier.EndTable(); + } + EmbeddingLookupSparseOptionsT *UnPack(const 
::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EmbeddingLookupSparseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EmbeddingLookupSparseOptionsBuilder { + typedef EmbeddingLookupSparseOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_combiner(tflite::CombinerType combiner) { + fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, static_cast(combiner), 0); + } + explicit EmbeddingLookupSparseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateEmbeddingLookupSparseOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::CombinerType combiner = tflite::CombinerType_SUM) { + EmbeddingLookupSparseOptionsBuilder builder_(_fbb); + builder_.add_combiner(combiner); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateEmbeddingLookupSparseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GatherOptionsT : public ::flatbuffers::NativeTable { + typedef GatherOptions TableType; + int32_t axis = 0; + int32_t batch_dims = 0; +}; + +struct GatherOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef GatherOptionsT NativeTableType; + typedef GatherOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_AXIS = 4, + VT_BATCH_DIMS = 6 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + int32_t batch_dims() const { + return GetField(VT_BATCH_DIMS, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS, 4) && + VerifyField(verifier, VT_BATCH_DIMS, 4) && + verifier.EndTable(); + } + GatherOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GatherOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GatherOptionsBuilder { + typedef GatherOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(GatherOptions::VT_AXIS, axis, 0); + } + void add_batch_dims(int32_t batch_dims) { + fbb_.AddElement(GatherOptions::VT_BATCH_DIMS, batch_dims, 0); + } + explicit GatherOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateGatherOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0, + int32_t batch_dims = 0) { + GatherOptionsBuilder builder_(_fbb); + builder_.add_batch_dims(batch_dims); + builder_.add_axis(axis); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateGatherOptions(::flatbuffers::FlatBufferBuilder &_fbb, 
const GatherOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TransposeOptionsT : public ::flatbuffers::NativeTable { + typedef TransposeOptions TableType; +}; + +struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef TransposeOptionsT NativeTableType; + typedef TransposeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TransposeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TransposeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TransposeOptionsBuilder { + typedef TransposeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit TransposeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateTransposeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + TransposeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateTransposeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ExpOptionsT : public ::flatbuffers::NativeTable { + typedef ExpOptions TableType; +}; + +struct ExpOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ExpOptionsT NativeTableType; + typedef ExpOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpOptionsBuilder { + typedef ExpOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ExpOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateExpOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + ExpOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateExpOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct CosOptionsT : public ::flatbuffers::NativeTable { + typedef CosOptions TableType; +}; + +struct CosOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef CosOptionsT NativeTableType; + typedef CosOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + CosOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) 
const; + void UnPackTo(CosOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CosOptionsBuilder { + typedef CosOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit CosOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateCosOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + CosOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateCosOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReducerOptionsT : public ::flatbuffers::NativeTable { + typedef ReducerOptions TableType; + bool keep_dims = false; +}; + +struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ReducerOptionsT NativeTableType; + typedef ReducerOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEEP_DIMS = 4 + }; + bool keep_dims() const { + return GetField(VT_KEEP_DIMS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_KEEP_DIMS, 1) && + verifier.EndTable(); + } + ReducerOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReducerOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReducerOptionsBuilder { + typedef ReducerOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_keep_dims(bool keep_dims) { + fbb_.AddElement(ReducerOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); + } + explicit ReducerOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateReducerOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool keep_dims = false) { + ReducerOptionsBuilder builder_(_fbb); + builder_.add_keep_dims(keep_dims); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateReducerOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SqueezeOptionsT : public ::flatbuffers::NativeTable { + typedef SqueezeOptions TableType; + std::vector squeeze_dims{}; +}; + +struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SqueezeOptionsT NativeTableType; + typedef SqueezeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SQUEEZE_DIMS = 4 + }; + const ::flatbuffers::Vector *squeeze_dims() const { + return GetPointer *>(VT_SQUEEZE_DIMS); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, 
VT_SQUEEZE_DIMS) && + verifier.VerifyVector(squeeze_dims()) && + verifier.EndTable(); + } + SqueezeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SqueezeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SqueezeOptionsBuilder { + typedef SqueezeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_squeeze_dims(::flatbuffers::Offset<::flatbuffers::Vector> squeeze_dims) { + fbb_.AddOffset(SqueezeOptions::VT_SQUEEZE_DIMS, squeeze_dims); + } + explicit SqueezeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSqueezeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> squeeze_dims = 0) { + SqueezeOptionsBuilder builder_(_fbb); + builder_.add_squeeze_dims(squeeze_dims); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateSqueezeOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *squeeze_dims = nullptr) { + auto squeeze_dims__ = squeeze_dims ? _fbb.CreateVector(*squeeze_dims) : 0; + return tflite::CreateSqueezeOptions( + _fbb, + squeeze_dims__); +} + +::flatbuffers::Offset CreateSqueezeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SplitOptionsT : public ::flatbuffers::NativeTable { + typedef SplitOptions TableType; + int32_t num_splits = 0; +}; + +struct SplitOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SplitOptionsT NativeTableType; + typedef SplitOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NUM_SPLITS = 4 + }; + int32_t num_splits() const { + return GetField(VT_NUM_SPLITS, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_SPLITS, 4) && + verifier.EndTable(); + } + SplitOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SplitOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SplitOptionsBuilder { + typedef SplitOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num_splits(int32_t num_splits) { + fbb_.AddElement(SplitOptions::VT_NUM_SPLITS, num_splits, 0); + } + explicit SplitOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSplitOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_splits = 0) { + SplitOptionsBuilder builder_(_fbb); + builder_.add_num_splits(num_splits); + return builder_.Finish(); +} + +::flatbuffers::Offset 
CreateSplitOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SplitVOptionsT : public ::flatbuffers::NativeTable { + typedef SplitVOptions TableType; + int32_t num_splits = 0; +}; + +struct SplitVOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SplitVOptionsT NativeTableType; + typedef SplitVOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NUM_SPLITS = 4 + }; + int32_t num_splits() const { + return GetField(VT_NUM_SPLITS, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_SPLITS, 4) && + verifier.EndTable(); + } + SplitVOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SplitVOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SplitVOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SplitVOptionsBuilder { + typedef SplitVOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num_splits(int32_t num_splits) { + fbb_.AddElement(SplitVOptions::VT_NUM_SPLITS, num_splits, 0); + } + explicit SplitVOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSplitVOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_splits = 0) { + SplitVOptionsBuilder builder_(_fbb); + builder_.add_num_splits(num_splits); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSplitVOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SplitVOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StridedSliceOptionsT : public ::flatbuffers::NativeTable { + typedef StridedSliceOptions TableType; + int32_t begin_mask = 0; + int32_t end_mask = 0; + int32_t ellipsis_mask = 0; + int32_t new_axis_mask = 0; + int32_t shrink_axis_mask = 0; + bool offset = false; +}; + +struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StridedSliceOptionsT NativeTableType; + typedef StridedSliceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BEGIN_MASK = 4, + VT_END_MASK = 6, + VT_ELLIPSIS_MASK = 8, + VT_NEW_AXIS_MASK = 10, + VT_SHRINK_AXIS_MASK = 12, + VT_OFFSET = 14 + }; + int32_t begin_mask() const { + return GetField(VT_BEGIN_MASK, 0); + } + int32_t end_mask() const { + return GetField(VT_END_MASK, 0); + } + int32_t ellipsis_mask() const { + return GetField(VT_ELLIPSIS_MASK, 0); + } + int32_t new_axis_mask() const { + return GetField(VT_NEW_AXIS_MASK, 0); + } + int32_t shrink_axis_mask() const { + return GetField(VT_SHRINK_AXIS_MASK, 0); + } + bool offset() const { + return GetField(VT_OFFSET, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_BEGIN_MASK, 4) && + VerifyField(verifier, VT_END_MASK, 4) && + VerifyField(verifier, VT_ELLIPSIS_MASK, 4) && + VerifyField(verifier, VT_NEW_AXIS_MASK, 4) && + VerifyField(verifier, VT_SHRINK_AXIS_MASK, 4) && + VerifyField(verifier, VT_OFFSET, 1) && + 
verifier.EndTable(); + } + StridedSliceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StridedSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StridedSliceOptionsBuilder { + typedef StridedSliceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_begin_mask(int32_t begin_mask) { + fbb_.AddElement(StridedSliceOptions::VT_BEGIN_MASK, begin_mask, 0); + } + void add_end_mask(int32_t end_mask) { + fbb_.AddElement(StridedSliceOptions::VT_END_MASK, end_mask, 0); + } + void add_ellipsis_mask(int32_t ellipsis_mask) { + fbb_.AddElement(StridedSliceOptions::VT_ELLIPSIS_MASK, ellipsis_mask, 0); + } + void add_new_axis_mask(int32_t new_axis_mask) { + fbb_.AddElement(StridedSliceOptions::VT_NEW_AXIS_MASK, new_axis_mask, 0); + } + void add_shrink_axis_mask(int32_t shrink_axis_mask) { + fbb_.AddElement(StridedSliceOptions::VT_SHRINK_AXIS_MASK, shrink_axis_mask, 0); + } + void add_offset(bool offset) { + fbb_.AddElement(StridedSliceOptions::VT_OFFSET, static_cast(offset), 0); + } + explicit StridedSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStridedSliceOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t begin_mask = 0, + int32_t end_mask = 0, + int32_t ellipsis_mask = 0, + int32_t new_axis_mask = 0, + int32_t shrink_axis_mask = 0, + bool offset = false) { + StridedSliceOptionsBuilder builder_(_fbb); + builder_.add_shrink_axis_mask(shrink_axis_mask); + builder_.add_new_axis_mask(new_axis_mask); + builder_.add_ellipsis_mask(ellipsis_mask); + builder_.add_end_mask(end_mask); + builder_.add_begin_mask(begin_mask); + builder_.add_offset(offset); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateStridedSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogSoftmaxOptionsT : public ::flatbuffers::NativeTable { + typedef LogSoftmaxOptions TableType; +}; + +struct LogSoftmaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LogSoftmaxOptionsT NativeTableType; + typedef LogSoftmaxOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogSoftmaxOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogSoftmaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogSoftmaxOptionsBuilder { + typedef LogSoftmaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LogSoftmaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + 
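The declarations in this hunk follow the standard flatc-generated pattern: a `...T` native table, a `FLATBUFFERS_FINAL_CLASS` accessor table with `Verify`, a `...Builder`, and an inline `Create...` helper. As a minimal usage sketch (not part of the patch): the example below builds a `StridedSliceOptions` table with the generated `Create` helper and reads it back through the accessor API. The include path `schema_generated.h` and the `main` wrapper are assumptions for illustration; only FlatBuffers calls that are part of the public library API (`FlatBufferBuilder`, `Finish`, `GetBufferPointer`, `GetRoot`) are used.

```cpp
// Usage sketch, assuming the generated header from this patch is available as
// "schema_generated.h" (path assumed) and flatbuffers headers are on the include path.
#include <iostream>

#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // assumed include path for this generated file

int main() {
  flatbuffers::FlatBufferBuilder fbb;

  // The inline CreateStridedSliceOptions helper wraps StridedSliceOptionsBuilder
  // (declared above) and returns an offset into the buffer being built.
  auto opts = tflite::CreateStridedSliceOptions(
      fbb, /*begin_mask=*/1, /*end_mask=*/2, /*ellipsis_mask=*/0,
      /*new_axis_mask=*/0, /*shrink_axis_mask=*/4, /*offset=*/false);
  fbb.Finish(opts);

  // Read the finished buffer back through the generated accessor table.
  const auto *read =
      flatbuffers::GetRoot<tflite::StridedSliceOptions>(fbb.GetBufferPointer());
  std::cout << read->begin_mask() << " " << read->shrink_axis_mask() << "\n";
  return 0;
}
```

The same pattern applies to every option table in this hunk; the `Pack`/`UnPack` members declared alongside each table provide the equivalent round trip through the `...T` native objects when the object API is preferred over raw accessors.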
+inline ::flatbuffers::Offset CreateLogSoftmaxOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + LogSoftmaxOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLogSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct CastOptionsT : public ::flatbuffers::NativeTable { + typedef CastOptions TableType; + tflite::TensorType in_data_type = tflite::TensorType_FLOAT32; + tflite::TensorType out_data_type = tflite::TensorType_FLOAT32; +}; + +struct CastOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef CastOptionsT NativeTableType; + typedef CastOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_IN_DATA_TYPE = 4, + VT_OUT_DATA_TYPE = 6 + }; + tflite::TensorType in_data_type() const { + return static_cast(GetField(VT_IN_DATA_TYPE, 0)); + } + tflite::TensorType out_data_type() const { + return static_cast(GetField(VT_OUT_DATA_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_IN_DATA_TYPE, 1) && + VerifyField(verifier, VT_OUT_DATA_TYPE, 1) && + verifier.EndTable(); + } + CastOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CastOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CastOptionsBuilder { + typedef CastOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_in_data_type(tflite::TensorType in_data_type) { + fbb_.AddElement(CastOptions::VT_IN_DATA_TYPE, static_cast(in_data_type), 0); + } + void add_out_data_type(tflite::TensorType out_data_type) { + fbb_.AddElement(CastOptions::VT_OUT_DATA_TYPE, static_cast(out_data_type), 0); + } + explicit CastOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateCastOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::TensorType in_data_type = tflite::TensorType_FLOAT32, + tflite::TensorType out_data_type = tflite::TensorType_FLOAT32) { + CastOptionsBuilder builder_(_fbb); + builder_.add_out_data_type(out_data_type); + builder_.add_in_data_type(in_data_type); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateCastOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DequantizeOptionsT : public ::flatbuffers::NativeTable { + typedef DequantizeOptions TableType; +}; + +struct DequantizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DequantizeOptionsT NativeTableType; + typedef DequantizeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + DequantizeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DequantizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset 
Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DequantizeOptionsBuilder { + typedef DequantizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit DequantizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDequantizeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + DequantizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateDequantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MaximumMinimumOptionsT : public ::flatbuffers::NativeTable { + typedef MaximumMinimumOptions TableType; +}; + +struct MaximumMinimumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef MaximumMinimumOptionsT NativeTableType; + typedef MaximumMinimumOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + MaximumMinimumOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MaximumMinimumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MaximumMinimumOptionsBuilder { + typedef MaximumMinimumOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit MaximumMinimumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateMaximumMinimumOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + MaximumMinimumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateMaximumMinimumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TileOptionsT : public ::flatbuffers::NativeTable { + typedef TileOptions TableType; +}; + +struct TileOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef TileOptionsT NativeTableType; + typedef TileOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TileOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TileOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TileOptionsBuilder { + typedef TileOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit TileOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + 
::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateTileOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + TileOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateTileOptions(::flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ArgMaxOptionsT : public ::flatbuffers::NativeTable { + typedef ArgMaxOptions TableType; + tflite::TensorType output_type = tflite::TensorType_FLOAT32; +}; + +struct ArgMaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ArgMaxOptionsT NativeTableType; + typedef ArgMaxOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OUTPUT_TYPE = 4 + }; + tflite::TensorType output_type() const { + return static_cast(GetField(VT_OUTPUT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUTPUT_TYPE, 1) && + verifier.EndTable(); + } + ArgMaxOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ArgMaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ArgMaxOptionsBuilder { + typedef ArgMaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_output_type(tflite::TensorType output_type) { + fbb_.AddElement(ArgMaxOptions::VT_OUTPUT_TYPE, static_cast(output_type), 0); + } + explicit ArgMaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateArgMaxOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::TensorType output_type = tflite::TensorType_FLOAT32) { + ArgMaxOptionsBuilder builder_(_fbb); + builder_.add_output_type(output_type); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateArgMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ArgMinOptionsT : public ::flatbuffers::NativeTable { + typedef ArgMinOptions TableType; + tflite::TensorType output_type = tflite::TensorType_FLOAT32; +}; + +struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ArgMinOptionsT NativeTableType; + typedef ArgMinOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OUTPUT_TYPE = 4 + }; + tflite::TensorType output_type() const { + return static_cast(GetField(VT_OUTPUT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUTPUT_TYPE, 1) && + verifier.EndTable(); + } + ArgMinOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ArgMinOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const ::flatbuffers::rehasher_function_t 
*_rehasher = nullptr); +}; + +struct ArgMinOptionsBuilder { + typedef ArgMinOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_output_type(tflite::TensorType output_type) { + fbb_.AddElement(ArgMinOptions::VT_OUTPUT_TYPE, static_cast(output_type), 0); + } + explicit ArgMinOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateArgMinOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::TensorType output_type = tflite::TensorType_FLOAT32) { + ArgMinOptionsBuilder builder_(_fbb); + builder_.add_output_type(output_type); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateArgMinOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GreaterOptionsT : public ::flatbuffers::NativeTable { + typedef GreaterOptions TableType; +}; + +struct GreaterOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef GreaterOptionsT NativeTableType; + typedef GreaterOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GreaterOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GreaterOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GreaterOptionsBuilder { + typedef GreaterOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit GreaterOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateGreaterOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + GreaterOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateGreaterOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GreaterEqualOptionsT : public ::flatbuffers::NativeTable { + typedef GreaterEqualOptions TableType; +}; + +struct GreaterEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef GreaterEqualOptionsT NativeTableType; + typedef GreaterEqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GreaterEqualOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GreaterEqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GreaterEqualOptionsBuilder { + typedef GreaterEqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit 
GreaterEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateGreaterEqualOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + GreaterEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateGreaterEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LessOptionsT : public ::flatbuffers::NativeTable { + typedef LessOptions TableType; +}; + +struct LessOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LessOptionsT NativeTableType; + typedef LessOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LessOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LessOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LessOptionsBuilder { + typedef LessOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LessOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLessOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + LessOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLessOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LessEqualOptionsT : public ::flatbuffers::NativeTable { + typedef LessEqualOptions TableType; +}; + +struct LessEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LessEqualOptionsT NativeTableType; + typedef LessEqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LessEqualOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LessEqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LessEqualOptionsBuilder { + typedef LessEqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LessEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLessEqualOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + LessEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLessEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const 
LessEqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NegOptionsT : public ::flatbuffers::NativeTable { + typedef NegOptions TableType; +}; + +struct NegOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef NegOptionsT NativeTableType; + typedef NegOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NegOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NegOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NegOptionsBuilder { + typedef NegOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit NegOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateNegOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + NegOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateNegOptions(::flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SelectOptionsT : public ::flatbuffers::NativeTable { + typedef SelectOptions TableType; +}; + +struct SelectOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SelectOptionsT NativeTableType; + typedef SelectOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SelectOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SelectOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SelectOptionsBuilder { + typedef SelectOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SelectOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSelectOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SelectOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSelectOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SliceOptionsT : public ::flatbuffers::NativeTable { + typedef SliceOptions TableType; +}; + +struct SliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SliceOptionsT NativeTableType; + typedef SliceOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SliceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SliceOptionsT 
*_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SliceOptionsBuilder { + typedef SliceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSliceOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SliceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TransposeConvOptionsT : public ::flatbuffers::NativeTable { + typedef TransposeConvOptions TableType; + tflite::Padding padding = tflite::Padding_SAME; + int32_t stride_w = 0; + int32_t stride_h = 0; + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE; + tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32; +}; + +struct TransposeConvOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef TransposeConvOptionsT NativeTableType; + typedef TransposeConvOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PADDING = 4, + VT_STRIDE_W = 6, + VT_STRIDE_H = 8, + VT_FUSED_ACTIVATION_FUNCTION = 10, + VT_QUANTIZED_BIAS_TYPE = 12 + }; + tflite::Padding padding() const { + return static_cast(GetField(VT_PADDING, 0)); + } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + tflite::ActivationFunctionType fused_activation_function() const { + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + tflite::TensorType quantized_bias_type() const { + return static_cast(GetField(VT_QUANTIZED_BIAS_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_PADDING, 1) && + VerifyField(verifier, VT_STRIDE_W, 4) && + VerifyField(verifier, VT_STRIDE_H, 4) && + VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION, 1) && + VerifyField(verifier, VT_QUANTIZED_BIAS_TYPE, 1) && + verifier.EndTable(); + } + TransposeConvOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TransposeConvOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TransposeConvOptionsBuilder { + typedef TransposeConvOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_padding(tflite::Padding padding) { + fbb_.AddElement(TransposeConvOptions::VT_PADDING, static_cast(padding), 0); + } + void add_stride_w(int32_t stride_w) { + fbb_.AddElement(TransposeConvOptions::VT_STRIDE_W, stride_w, 0); + } + void add_stride_h(int32_t stride_h) { + fbb_.AddElement(TransposeConvOptions::VT_STRIDE_H, stride_h, 0); + } + void add_fused_activation_function(tflite::ActivationFunctionType 
fused_activation_function) { + fbb_.AddElement(TransposeConvOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + void add_quantized_bias_type(tflite::TensorType quantized_bias_type) { + fbb_.AddElement(TransposeConvOptions::VT_QUANTIZED_BIAS_TYPE, static_cast(quantized_bias_type), 0); + } + explicit TransposeConvOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateTransposeConvOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::Padding padding = tflite::Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + tflite::ActivationFunctionType fused_activation_function = tflite::ActivationFunctionType_NONE, + tflite::TensorType quantized_bias_type = tflite::TensorType_FLOAT32) { + TransposeConvOptionsBuilder builder_(_fbb); + builder_.add_stride_h(stride_h); + builder_.add_stride_w(stride_w); + builder_.add_quantized_bias_type(quantized_bias_type); + builder_.add_fused_activation_function(fused_activation_function); + builder_.add_padding(padding); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateTransposeConvOptions(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ExpandDimsOptionsT : public ::flatbuffers::NativeTable { + typedef ExpandDimsOptions TableType; +}; + +struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ExpandDimsOptionsT NativeTableType; + typedef ExpandDimsOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpandDimsOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpandDimsOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpandDimsOptionsBuilder { + typedef ExpandDimsOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ExpandDimsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateExpandDimsOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + ExpandDimsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateExpandDimsOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SparseToDenseOptionsT : public ::flatbuffers::NativeTable { + typedef SparseToDenseOptions TableType; + bool validate_indices = false; +}; + +struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SparseToDenseOptionsT NativeTableType; + typedef SparseToDenseOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALIDATE_INDICES = 4 + }; + bool validate_indices() const { + return GetField(VT_VALIDATE_INDICES, 0) != 0; + } + bool 
Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALIDATE_INDICES, 1) && + verifier.EndTable(); + } + SparseToDenseOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SparseToDenseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparseToDenseOptionsBuilder { + typedef SparseToDenseOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_validate_indices(bool validate_indices) { + fbb_.AddElement(SparseToDenseOptions::VT_VALIDATE_INDICES, static_cast(validate_indices), 0); + } + explicit SparseToDenseOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSparseToDenseOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool validate_indices = false) { + SparseToDenseOptionsBuilder builder_(_fbb); + builder_.add_validate_indices(validate_indices); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSparseToDenseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct EqualOptionsT : public ::flatbuffers::NativeTable { + typedef EqualOptions TableType; +}; + +struct EqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef EqualOptionsT NativeTableType; + typedef EqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + EqualOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EqualOptionsBuilder { + typedef EqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit EqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateEqualOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + EqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NotEqualOptionsT : public ::flatbuffers::NativeTable { + typedef NotEqualOptions TableType; +}; + +struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef NotEqualOptionsT NativeTableType; + typedef NotEqualOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NotEqualOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = 
nullptr) const; + void UnPackTo(NotEqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NotEqualOptionsBuilder { + typedef NotEqualOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit NotEqualOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateNotEqualOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + NotEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateNotEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ShapeOptionsT : public ::flatbuffers::NativeTable { + typedef ShapeOptions TableType; + tflite::TensorType out_type = tflite::TensorType_FLOAT32; +}; + +struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ShapeOptionsT NativeTableType; + typedef ShapeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OUT_TYPE = 4 + }; + tflite::TensorType out_type() const { + return static_cast(GetField(VT_OUT_TYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OUT_TYPE, 1) && + verifier.EndTable(); + } + ShapeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ShapeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ShapeOptionsBuilder { + typedef ShapeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_out_type(tflite::TensorType out_type) { + fbb_.AddElement(ShapeOptions::VT_OUT_TYPE, static_cast(out_type), 0); + } + explicit ShapeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateShapeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::TensorType out_type = tflite::TensorType_FLOAT32) { + ShapeOptionsBuilder builder_(_fbb); + builder_.add_out_type(out_type); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateShapeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RankOptionsT : public ::flatbuffers::NativeTable { + typedef RankOptions TableType; +}; + +struct RankOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef RankOptionsT NativeTableType; + typedef RankOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + RankOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RankOptionsT *_o, 
const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RankOptionsBuilder { + typedef RankOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit RankOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateRankOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + RankOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateRankOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PowOptionsT : public ::flatbuffers::NativeTable { + typedef PowOptions TableType; +}; + +struct PowOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef PowOptionsT NativeTableType; + typedef PowOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + PowOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PowOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PowOptionsBuilder { + typedef PowOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit PowOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreatePowOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + PowOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreatePowOptions(::flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FakeQuantOptionsT : public ::flatbuffers::NativeTable { + typedef FakeQuantOptions TableType; + float min = 0.0f; + float max = 0.0f; + int32_t num_bits = 0; + bool narrow_range = false; +}; + +struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef FakeQuantOptionsT NativeTableType; + typedef FakeQuantOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MIN = 4, + VT_MAX = 6, + VT_NUM_BITS = 8, + VT_NARROW_RANGE = 10 + }; + float min() const { + return GetField(VT_MIN, 0.0f); + } + float max() const { + return GetField(VT_MAX, 0.0f); + } + int32_t num_bits() const { + return GetField(VT_NUM_BITS, 0); + } + bool narrow_range() const { + return GetField(VT_NARROW_RANGE, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_MIN, 4) && + VerifyField(verifier, VT_MAX, 4) && + VerifyField(verifier, VT_NUM_BITS, 4) && + VerifyField(verifier, VT_NARROW_RANGE, 1) && + verifier.EndTable(); + } + FakeQuantOptionsT *UnPack(const 
::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FakeQuantOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FakeQuantOptionsBuilder { + typedef FakeQuantOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_min(float min) { + fbb_.AddElement(FakeQuantOptions::VT_MIN, min, 0.0f); + } + void add_max(float max) { + fbb_.AddElement(FakeQuantOptions::VT_MAX, max, 0.0f); + } + void add_num_bits(int32_t num_bits) { + fbb_.AddElement(FakeQuantOptions::VT_NUM_BITS, num_bits, 0); + } + void add_narrow_range(bool narrow_range) { + fbb_.AddElement(FakeQuantOptions::VT_NARROW_RANGE, static_cast(narrow_range), 0); + } + explicit FakeQuantOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateFakeQuantOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + float min = 0.0f, + float max = 0.0f, + int32_t num_bits = 0, + bool narrow_range = false) { + FakeQuantOptionsBuilder builder_(_fbb); + builder_.add_num_bits(num_bits); + builder_.add_max(max); + builder_.add_min(min); + builder_.add_narrow_range(narrow_range); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateFakeQuantOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct PackOptionsT : public ::flatbuffers::NativeTable { + typedef PackOptions TableType; + int32_t values_count = 0; + int32_t axis = 0; +}; + +struct PackOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef PackOptionsT NativeTableType; + typedef PackOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VALUES_COUNT = 4, + VT_AXIS = 6 + }; + int32_t values_count() const { + return GetField(VT_VALUES_COUNT, 0); + } + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALUES_COUNT, 4) && + VerifyField(verifier, VT_AXIS, 4) && + verifier.EndTable(); + } + PackOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PackOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PackOptionsBuilder { + typedef PackOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_values_count(int32_t values_count) { + fbb_.AddElement(PackOptions::VT_VALUES_COUNT, values_count, 0); + } + void add_axis(int32_t axis) { + fbb_.AddElement(PackOptions::VT_AXIS, axis, 0); + } + explicit PackOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreatePackOptions( + ::flatbuffers::FlatBufferBuilder 
&_fbb, + int32_t values_count = 0, + int32_t axis = 0) { + PackOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_values_count(values_count); + return builder_.Finish(); +} + +::flatbuffers::Offset CreatePackOptions(::flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalOrOptionsT : public ::flatbuffers::NativeTable { + typedef LogicalOrOptions TableType; +}; + +struct LogicalOrOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LogicalOrOptionsT NativeTableType; + typedef LogicalOrOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalOrOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalOrOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalOrOptionsBuilder { + typedef LogicalOrOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LogicalOrOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLogicalOrOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + LogicalOrOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLogicalOrOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OneHotOptionsT : public ::flatbuffers::NativeTable { + typedef OneHotOptions TableType; + int32_t axis = 0; +}; + +struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef OneHotOptionsT NativeTableType; + typedef OneHotOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS, 4) && + verifier.EndTable(); + } + OneHotOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OneHotOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OneHotOptionsBuilder { + typedef OneHotOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(OneHotOptions::VT_AXIS, axis, 0); + } + explicit OneHotOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateOneHotOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { + OneHotOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + 
return builder_.Finish(); +} + +::flatbuffers::Offset CreateOneHotOptions(::flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AbsOptionsT : public ::flatbuffers::NativeTable { + typedef AbsOptions TableType; +}; + +struct AbsOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef AbsOptionsT NativeTableType; + typedef AbsOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + AbsOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AbsOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AbsOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AbsOptionsBuilder { + typedef AbsOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit AbsOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateAbsOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + AbsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateAbsOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AbsOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct HardSwishOptionsT : public ::flatbuffers::NativeTable { + typedef HardSwishOptions TableType; +}; + +struct HardSwishOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef HardSwishOptionsT NativeTableType; + typedef HardSwishOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + HardSwishOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(HardSwishOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HardSwishOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct HardSwishOptionsBuilder { + typedef HardSwishOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HardSwishOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateHardSwishOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + HardSwishOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateHardSwishOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HardSwishOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalAndOptionsT : public ::flatbuffers::NativeTable { + typedef LogicalAndOptions TableType; +}; + +struct LogicalAndOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LogicalAndOptionsT NativeTableType; + typedef LogicalAndOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + 
return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalAndOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalAndOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalAndOptionsBuilder { + typedef LogicalAndOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LogicalAndOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLogicalAndOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + LogicalAndOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLogicalAndOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogicalNotOptionsT : public ::flatbuffers::NativeTable { + typedef LogicalNotOptions TableType; +}; + +struct LogicalNotOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef LogicalNotOptionsT NativeTableType; + typedef LogicalNotOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogicalNotOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogicalNotOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogicalNotOptionsBuilder { + typedef LogicalNotOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit LogicalNotOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLogicalNotOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + LogicalNotOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLogicalNotOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct UnpackOptionsT : public ::flatbuffers::NativeTable { + typedef UnpackOptions TableType; + int32_t num = 0; + int32_t axis = 0; +}; + +struct UnpackOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UnpackOptionsT NativeTableType; + typedef UnpackOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NUM = 4, + VT_AXIS = 6 + }; + int32_t num() const { + return GetField(VT_NUM, 0); + } + int32_t axis() const { + return GetField(VT_AXIS, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM, 4) && + VerifyField(verifier, VT_AXIS, 4) && + verifier.EndTable(); + } + UnpackOptionsT *UnPack(const 
::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnpackOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UnpackOptionsBuilder { + typedef UnpackOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_num(int32_t num) { + fbb_.AddElement(UnpackOptions::VT_NUM, num, 0); + } + void add_axis(int32_t axis) { + fbb_.AddElement(UnpackOptions::VT_AXIS, axis, 0); + } + explicit UnpackOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUnpackOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t num = 0, + int32_t axis = 0) { + UnpackOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + builder_.add_num(num); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateUnpackOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FloorDivOptionsT : public ::flatbuffers::NativeTable { + typedef FloorDivOptions TableType; +}; + +struct FloorDivOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef FloorDivOptionsT NativeTableType; + typedef FloorDivOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + FloorDivOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FloorDivOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FloorDivOptionsBuilder { + typedef FloorDivOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit FloorDivOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateFloorDivOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + FloorDivOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateFloorDivOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SquareOptionsT : public ::flatbuffers::NativeTable { + typedef SquareOptions TableType; +}; + +struct SquareOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SquareOptionsT NativeTableType; + typedef SquareOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SquareOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SquareOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const 
SquareOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SquareOptionsBuilder { + typedef SquareOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SquareOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSquareOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SquareOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSquareOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ZerosLikeOptionsT : public ::flatbuffers::NativeTable { + typedef ZerosLikeOptions TableType; +}; + +struct ZerosLikeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ZerosLikeOptionsT NativeTableType; + typedef ZerosLikeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ZerosLikeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ZerosLikeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ZerosLikeOptionsBuilder { + typedef ZerosLikeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ZerosLikeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateZerosLikeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + ZerosLikeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateZerosLikeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FillOptionsT : public ::flatbuffers::NativeTable { + typedef FillOptions TableType; +}; + +struct FillOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef FillOptionsT NativeTableType; + typedef FillOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + FillOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FillOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FillOptionsBuilder { + typedef FillOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit FillOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset 
CreateFillOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + FillOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateFillOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct FloorModOptionsT : public ::flatbuffers::NativeTable { + typedef FloorModOptions TableType; +}; + +struct FloorModOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef FloorModOptionsT NativeTableType; + typedef FloorModOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + FloorModOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FloorModOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FloorModOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct FloorModOptionsBuilder { + typedef FloorModOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit FloorModOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateFloorModOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + FloorModOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateFloorModOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FloorModOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RangeOptionsT : public ::flatbuffers::NativeTable { + typedef RangeOptions TableType; +}; + +struct RangeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef RangeOptionsT NativeTableType; + typedef RangeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + RangeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RangeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RangeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RangeOptionsBuilder { + typedef RangeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit RangeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateRangeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + RangeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateRangeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RangeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LeakyReluOptionsT : public ::flatbuffers::NativeTable { + typedef LeakyReluOptions TableType; + float alpha = 0.0f; +}; + +struct LeakyReluOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef 
LeakyReluOptionsT NativeTableType; + typedef LeakyReluOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ALPHA = 4 + }; + float alpha() const { + return GetField(VT_ALPHA, 0.0f); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ALPHA, 4) && + verifier.EndTable(); + } + LeakyReluOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LeakyReluOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LeakyReluOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LeakyReluOptionsBuilder { + typedef LeakyReluOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_alpha(float alpha) { + fbb_.AddElement(LeakyReluOptions::VT_ALPHA, alpha, 0.0f); + } + explicit LeakyReluOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateLeakyReluOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + float alpha = 0.0f) { + LeakyReluOptionsBuilder builder_(_fbb); + builder_.add_alpha(alpha); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateLeakyReluOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LeakyReluOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SquaredDifferenceOptionsT : public ::flatbuffers::NativeTable { + typedef SquaredDifferenceOptions TableType; +}; + +struct SquaredDifferenceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SquaredDifferenceOptionsT NativeTableType; + typedef SquaredDifferenceOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SquaredDifferenceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SquaredDifferenceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SquaredDifferenceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SquaredDifferenceOptionsBuilder { + typedef SquaredDifferenceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SquaredDifferenceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSquaredDifferenceOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SquaredDifferenceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSquaredDifferenceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SquaredDifferenceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MirrorPadOptionsT : public ::flatbuffers::NativeTable { + typedef MirrorPadOptions TableType; + tflite::MirrorPadMode mode = tflite::MirrorPadMode_REFLECT; +}; + +struct MirrorPadOptions 
FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef MirrorPadOptionsT NativeTableType; + typedef MirrorPadOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_MODE = 4 + }; + tflite::MirrorPadMode mode() const { + return static_cast(GetField(VT_MODE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_MODE, 1) && + verifier.EndTable(); + } + MirrorPadOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MirrorPadOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MirrorPadOptionsBuilder { + typedef MirrorPadOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_mode(tflite::MirrorPadMode mode) { + fbb_.AddElement(MirrorPadOptions::VT_MODE, static_cast(mode), 0); + } + explicit MirrorPadOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateMirrorPadOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::MirrorPadMode mode = tflite::MirrorPadMode_REFLECT) { + MirrorPadOptionsBuilder builder_(_fbb); + builder_.add_mode(mode); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateMirrorPadOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct UniqueOptionsT : public ::flatbuffers::NativeTable { + typedef UniqueOptions TableType; + tflite::TensorType idx_out_type = tflite::TensorType_INT32; +}; + +struct UniqueOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UniqueOptionsT NativeTableType; + typedef UniqueOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_IDX_OUT_TYPE = 4 + }; + tflite::TensorType idx_out_type() const { + return static_cast(GetField(VT_IDX_OUT_TYPE, 2)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_IDX_OUT_TYPE, 1) && + verifier.EndTable(); + } + UniqueOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UniqueOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UniqueOptionsBuilder { + typedef UniqueOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_idx_out_type(tflite::TensorType idx_out_type) { + fbb_.AddElement(UniqueOptions::VT_IDX_OUT_TYPE, static_cast(idx_out_type), 2); + } + explicit UniqueOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUniqueOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::TensorType 
idx_out_type = tflite::TensorType_INT32) { + UniqueOptionsBuilder builder_(_fbb); + builder_.add_idx_out_type(idx_out_type); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateUniqueOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReverseV2OptionsT : public ::flatbuffers::NativeTable { + typedef ReverseV2Options TableType; +}; + +struct ReverseV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ReverseV2OptionsT NativeTableType; + typedef ReverseV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ReverseV2OptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReverseV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReverseV2OptionsBuilder { + typedef ReverseV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ReverseV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateReverseV2Options( + ::flatbuffers::FlatBufferBuilder &_fbb) { + ReverseV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateReverseV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AddNOptionsT : public ::flatbuffers::NativeTable { + typedef AddNOptions TableType; +}; + +struct AddNOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef AddNOptionsT NativeTableType; + typedef AddNOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + AddNOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AddNOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AddNOptionsBuilder { + typedef AddNOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit AddNOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateAddNOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + AddNOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateAddNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GatherNdOptionsT : public ::flatbuffers::NativeTable { + typedef GatherNdOptions TableType; +}; + +struct GatherNdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef 
GatherNdOptionsT NativeTableType; + typedef GatherNdOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + GatherNdOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GatherNdOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GatherNdOptionsBuilder { + typedef GatherNdOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit GatherNdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateGatherNdOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + GatherNdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateGatherNdOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct WhereOptionsT : public ::flatbuffers::NativeTable { + typedef WhereOptions TableType; +}; + +struct WhereOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef WhereOptionsT NativeTableType; + typedef WhereOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + WhereOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(WhereOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct WhereOptionsBuilder { + typedef WhereOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit WhereOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateWhereOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + WhereOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateWhereOptions(::flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReverseSequenceOptionsT : public ::flatbuffers::NativeTable { + typedef ReverseSequenceOptions TableType; + int32_t seq_dim = 0; + int32_t batch_dim = 0; +}; + +struct ReverseSequenceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ReverseSequenceOptionsT NativeTableType; + typedef ReverseSequenceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SEQ_DIM = 4, + VT_BATCH_DIM = 6 + }; + int32_t seq_dim() const { + return GetField(VT_SEQ_DIM, 0); + } + int32_t batch_dim() const { + return GetField(VT_BATCH_DIM, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_SEQ_DIM, 
4) && + VerifyField(verifier, VT_BATCH_DIM, 4) && + verifier.EndTable(); + } + ReverseSequenceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReverseSequenceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReverseSequenceOptionsBuilder { + typedef ReverseSequenceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_seq_dim(int32_t seq_dim) { + fbb_.AddElement(ReverseSequenceOptions::VT_SEQ_DIM, seq_dim, 0); + } + void add_batch_dim(int32_t batch_dim) { + fbb_.AddElement(ReverseSequenceOptions::VT_BATCH_DIM, batch_dim, 0); + } + explicit ReverseSequenceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateReverseSequenceOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t seq_dim = 0, + int32_t batch_dim = 0) { + ReverseSequenceOptionsBuilder builder_(_fbb); + builder_.add_batch_dim(batch_dim); + builder_.add_seq_dim(seq_dim); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateReverseSequenceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MatrixDiagOptionsT : public ::flatbuffers::NativeTable { + typedef MatrixDiagOptions TableType; +}; + +struct MatrixDiagOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef MatrixDiagOptionsT NativeTableType; + typedef MatrixDiagOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + MatrixDiagOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MatrixDiagOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MatrixDiagOptionsBuilder { + typedef MatrixDiagOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit MatrixDiagOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateMatrixDiagOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + MatrixDiagOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateMatrixDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct QuantizeOptionsT : public ::flatbuffers::NativeTable { + typedef QuantizeOptions TableType; +}; + +struct QuantizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef QuantizeOptionsT NativeTableType; + typedef QuantizeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) 
&& + verifier.EndTable(); + } + QuantizeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(QuantizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct QuantizeOptionsBuilder { + typedef QuantizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit QuantizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateQuantizeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + QuantizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateQuantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MatrixSetDiagOptionsT : public ::flatbuffers::NativeTable { + typedef MatrixSetDiagOptions TableType; +}; + +struct MatrixSetDiagOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef MatrixSetDiagOptionsT NativeTableType; + typedef MatrixSetDiagOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + MatrixSetDiagOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MatrixSetDiagOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MatrixSetDiagOptionsBuilder { + typedef MatrixSetDiagOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit MatrixSetDiagOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateMatrixSetDiagOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + MatrixSetDiagOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateMatrixSetDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct IfOptionsT : public ::flatbuffers::NativeTable { + typedef IfOptions TableType; + int32_t then_subgraph_index = 0; + int32_t else_subgraph_index = 0; +}; + +struct IfOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef IfOptionsT NativeTableType; + typedef IfOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_THEN_SUBGRAPH_INDEX = 4, + VT_ELSE_SUBGRAPH_INDEX = 6 + }; + int32_t then_subgraph_index() const { + return GetField(VT_THEN_SUBGRAPH_INDEX, 0); + } + int32_t else_subgraph_index() const { + return GetField(VT_ELSE_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_THEN_SUBGRAPH_INDEX, 4) && + 
VerifyField(verifier, VT_ELSE_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + IfOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(IfOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct IfOptionsBuilder { + typedef IfOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_then_subgraph_index(int32_t then_subgraph_index) { + fbb_.AddElement(IfOptions::VT_THEN_SUBGRAPH_INDEX, then_subgraph_index, 0); + } + void add_else_subgraph_index(int32_t else_subgraph_index) { + fbb_.AddElement(IfOptions::VT_ELSE_SUBGRAPH_INDEX, else_subgraph_index, 0); + } + explicit IfOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateIfOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t then_subgraph_index = 0, + int32_t else_subgraph_index = 0) { + IfOptionsBuilder builder_(_fbb); + builder_.add_else_subgraph_index(else_subgraph_index); + builder_.add_then_subgraph_index(then_subgraph_index); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateIfOptions(::flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct CallOnceOptionsT : public ::flatbuffers::NativeTable { + typedef CallOnceOptions TableType; + int32_t init_subgraph_index = 0; +}; + +struct CallOnceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef CallOnceOptionsT NativeTableType; + typedef CallOnceOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INIT_SUBGRAPH_INDEX = 4 + }; + int32_t init_subgraph_index() const { + return GetField(VT_INIT_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_INIT_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + CallOnceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CallOnceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CallOnceOptionsBuilder { + typedef CallOnceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_init_subgraph_index(int32_t init_subgraph_index) { + fbb_.AddElement(CallOnceOptions::VT_INIT_SUBGRAPH_INDEX, init_subgraph_index, 0); + } + explicit CallOnceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateCallOnceOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t init_subgraph_index = 0) { + CallOnceOptionsBuilder builder_(_fbb); + builder_.add_init_subgraph_index(init_subgraph_index); + return builder_.Finish(); +} + +::flatbuffers::Offset 
CreateCallOnceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct WhileOptionsT : public ::flatbuffers::NativeTable { + typedef WhileOptions TableType; + int32_t cond_subgraph_index = 0; + int32_t body_subgraph_index = 0; +}; + +struct WhileOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef WhileOptionsT NativeTableType; + typedef WhileOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_COND_SUBGRAPH_INDEX = 4, + VT_BODY_SUBGRAPH_INDEX = 6 + }; + int32_t cond_subgraph_index() const { + return GetField(VT_COND_SUBGRAPH_INDEX, 0); + } + int32_t body_subgraph_index() const { + return GetField(VT_BODY_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_COND_SUBGRAPH_INDEX, 4) && + VerifyField(verifier, VT_BODY_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + WhileOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(WhileOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct WhileOptionsBuilder { + typedef WhileOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_cond_subgraph_index(int32_t cond_subgraph_index) { + fbb_.AddElement(WhileOptions::VT_COND_SUBGRAPH_INDEX, cond_subgraph_index, 0); + } + void add_body_subgraph_index(int32_t body_subgraph_index) { + fbb_.AddElement(WhileOptions::VT_BODY_SUBGRAPH_INDEX, body_subgraph_index, 0); + } + explicit WhileOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateWhileOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t cond_subgraph_index = 0, + int32_t body_subgraph_index = 0) { + WhileOptionsBuilder builder_(_fbb); + builder_.add_body_subgraph_index(body_subgraph_index); + builder_.add_cond_subgraph_index(cond_subgraph_index); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateWhileOptions(::flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NonMaxSuppressionV4OptionsT : public ::flatbuffers::NativeTable { + typedef NonMaxSuppressionV4Options TableType; +}; + +struct NonMaxSuppressionV4Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef NonMaxSuppressionV4OptionsT NativeTableType; + typedef NonMaxSuppressionV4OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NonMaxSuppressionV4OptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NonMaxSuppressionV4OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NonMaxSuppressionV4OptionsBuilder { + typedef NonMaxSuppressionV4Options 
Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit NonMaxSuppressionV4OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateNonMaxSuppressionV4Options( + ::flatbuffers::FlatBufferBuilder &_fbb) { + NonMaxSuppressionV4OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateNonMaxSuppressionV4Options(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NonMaxSuppressionV5OptionsT : public ::flatbuffers::NativeTable { + typedef NonMaxSuppressionV5Options TableType; +}; + +struct NonMaxSuppressionV5Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef NonMaxSuppressionV5OptionsT NativeTableType; + typedef NonMaxSuppressionV5OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NonMaxSuppressionV5OptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NonMaxSuppressionV5OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NonMaxSuppressionV5OptionsBuilder { + typedef NonMaxSuppressionV5Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit NonMaxSuppressionV5OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateNonMaxSuppressionV5Options( + ::flatbuffers::FlatBufferBuilder &_fbb) { + NonMaxSuppressionV5OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateNonMaxSuppressionV5Options(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ScatterNdOptionsT : public ::flatbuffers::NativeTable { + typedef ScatterNdOptions TableType; +}; + +struct ScatterNdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ScatterNdOptionsT NativeTableType; + typedef ScatterNdOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ScatterNdOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ScatterNdOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ScatterNdOptionsBuilder { + typedef ScatterNdOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ScatterNdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end 
= fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateScatterNdOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + ScatterNdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateScatterNdOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SelectV2OptionsT : public ::flatbuffers::NativeTable { + typedef SelectV2Options TableType; +}; + +struct SelectV2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SelectV2OptionsT NativeTableType; + typedef SelectV2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SelectV2OptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SelectV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SelectV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SelectV2OptionsBuilder { + typedef SelectV2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SelectV2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSelectV2Options( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SelectV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSelectV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const SelectV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DensifyOptionsT : public ::flatbuffers::NativeTable { + typedef DensifyOptions TableType; +}; + +struct DensifyOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DensifyOptionsT NativeTableType; + typedef DensifyOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + DensifyOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DensifyOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DensifyOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DensifyOptionsBuilder { + typedef DensifyOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit DensifyOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDensifyOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + DensifyOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateDensifyOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DensifyOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SegmentSumOptionsT : public ::flatbuffers::NativeTable { + 
typedef SegmentSumOptions TableType; +}; + +struct SegmentSumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SegmentSumOptionsT NativeTableType; + typedef SegmentSumOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SegmentSumOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SegmentSumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SegmentSumOptionsBuilder { + typedef SegmentSumOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SegmentSumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSegmentSumOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SegmentSumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BatchMatMulOptionsT : public ::flatbuffers::NativeTable { + typedef BatchMatMulOptions TableType; + bool adj_x = false; + bool adj_y = false; + bool asymmetric_quantize_inputs = false; +}; + +struct BatchMatMulOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BatchMatMulOptionsT NativeTableType; + typedef BatchMatMulOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_ADJ_X = 4, + VT_ADJ_Y = 6, + VT_ASYMMETRIC_QUANTIZE_INPUTS = 8 + }; + bool adj_x() const { + return GetField(VT_ADJ_X, 0) != 0; + } + bool adj_y() const { + return GetField(VT_ADJ_Y, 0) != 0; + } + bool asymmetric_quantize_inputs() const { + return GetField(VT_ASYMMETRIC_QUANTIZE_INPUTS, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_ADJ_X, 1) && + VerifyField(verifier, VT_ADJ_Y, 1) && + VerifyField(verifier, VT_ASYMMETRIC_QUANTIZE_INPUTS, 1) && + verifier.EndTable(); + } + BatchMatMulOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BatchMatMulOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BatchMatMulOptionsBuilder { + typedef BatchMatMulOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_adj_x(bool adj_x) { + fbb_.AddElement(BatchMatMulOptions::VT_ADJ_X, static_cast(adj_x), 0); + } + void add_adj_y(bool adj_y) { + fbb_.AddElement(BatchMatMulOptions::VT_ADJ_Y, static_cast(adj_y), 0); + } + void add_asymmetric_quantize_inputs(bool asymmetric_quantize_inputs) { + fbb_.AddElement(BatchMatMulOptions::VT_ASYMMETRIC_QUANTIZE_INPUTS, static_cast(asymmetric_quantize_inputs), 0); + } + explicit BatchMatMulOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = 
fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBatchMatMulOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool adj_x = false, + bool adj_y = false, + bool asymmetric_quantize_inputs = false) { + BatchMatMulOptionsBuilder builder_(_fbb); + builder_.add_asymmetric_quantize_inputs(asymmetric_quantize_inputs); + builder_.add_adj_y(adj_y); + builder_.add_adj_x(adj_x); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBatchMatMulOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct CumsumOptionsT : public ::flatbuffers::NativeTable { + typedef CumsumOptions TableType; + bool exclusive = false; + bool reverse = false; +}; + +struct CumsumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef CumsumOptionsT NativeTableType; + typedef CumsumOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_EXCLUSIVE = 4, + VT_REVERSE = 6 + }; + bool exclusive() const { + return GetField(VT_EXCLUSIVE, 0) != 0; + } + bool reverse() const { + return GetField(VT_REVERSE, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_EXCLUSIVE, 1) && + VerifyField(verifier, VT_REVERSE, 1) && + verifier.EndTable(); + } + CumsumOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CumsumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct CumsumOptionsBuilder { + typedef CumsumOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_exclusive(bool exclusive) { + fbb_.AddElement(CumsumOptions::VT_EXCLUSIVE, static_cast(exclusive), 0); + } + void add_reverse(bool reverse) { + fbb_.AddElement(CumsumOptions::VT_REVERSE, static_cast(reverse), 0); + } + explicit CumsumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateCumsumOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool exclusive = false, + bool reverse = false) { + CumsumOptionsBuilder builder_(_fbb); + builder_.add_reverse(reverse); + builder_.add_exclusive(exclusive); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateCumsumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BroadcastToOptionsT : public ::flatbuffers::NativeTable { + typedef BroadcastToOptions TableType; +}; + +struct BroadcastToOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BroadcastToOptionsT NativeTableType; + typedef BroadcastToOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BroadcastToOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BroadcastToOptionsT *_o, const 
::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BroadcastToOptionsBuilder { + typedef BroadcastToOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit BroadcastToOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBroadcastToOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + BroadcastToOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBroadcastToOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct Rfft2dOptionsT : public ::flatbuffers::NativeTable { + typedef Rfft2dOptions TableType; +}; + +struct Rfft2dOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef Rfft2dOptionsT NativeTableType; + typedef Rfft2dOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + Rfft2dOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Rfft2dOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct Rfft2dOptionsBuilder { + typedef Rfft2dOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit Rfft2dOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateRfft2dOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + Rfft2dOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateRfft2dOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct HashtableOptionsT : public ::flatbuffers::NativeTable { + typedef HashtableOptions TableType; + int32_t table_id = 0; + tflite::TensorType key_dtype = tflite::TensorType_FLOAT32; + tflite::TensorType value_dtype = tflite::TensorType_FLOAT32; +}; + +struct HashtableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef HashtableOptionsT NativeTableType; + typedef HashtableOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TABLE_ID = 4, + VT_KEY_DTYPE = 6, + VT_VALUE_DTYPE = 8 + }; + int32_t table_id() const { + return GetField(VT_TABLE_ID, 0); + } + tflite::TensorType key_dtype() const { + return static_cast(GetField(VT_KEY_DTYPE, 0)); + } + tflite::TensorType value_dtype() const { + return static_cast(GetField(VT_VALUE_DTYPE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_TABLE_ID, 4) && + VerifyField(verifier, VT_KEY_DTYPE, 1) && + VerifyField(verifier, 
VT_VALUE_DTYPE, 1) && + verifier.EndTable(); + } + HashtableOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(HashtableOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct HashtableOptionsBuilder { + typedef HashtableOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_table_id(int32_t table_id) { + fbb_.AddElement(HashtableOptions::VT_TABLE_ID, table_id, 0); + } + void add_key_dtype(tflite::TensorType key_dtype) { + fbb_.AddElement(HashtableOptions::VT_KEY_DTYPE, static_cast(key_dtype), 0); + } + void add_value_dtype(tflite::TensorType value_dtype) { + fbb_.AddElement(HashtableOptions::VT_VALUE_DTYPE, static_cast(value_dtype), 0); + } + explicit HashtableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateHashtableOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int32_t table_id = 0, + tflite::TensorType key_dtype = tflite::TensorType_FLOAT32, + tflite::TensorType value_dtype = tflite::TensorType_FLOAT32) { + HashtableOptionsBuilder builder_(_fbb); + builder_.add_table_id(table_id); + builder_.add_value_dtype(value_dtype); + builder_.add_key_dtype(key_dtype); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateHashtableOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct HashtableFindOptionsT : public ::flatbuffers::NativeTable { + typedef HashtableFindOptions TableType; +}; + +struct HashtableFindOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef HashtableFindOptionsT NativeTableType; + typedef HashtableFindOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + HashtableFindOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(HashtableFindOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableFindOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct HashtableFindOptionsBuilder { + typedef HashtableFindOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HashtableFindOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateHashtableFindOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + HashtableFindOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateHashtableFindOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableFindOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct HashtableImportOptionsT : public ::flatbuffers::NativeTable { + typedef HashtableImportOptions 
TableType; +}; + +struct HashtableImportOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef HashtableImportOptionsT NativeTableType; + typedef HashtableImportOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + HashtableImportOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(HashtableImportOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableImportOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct HashtableImportOptionsBuilder { + typedef HashtableImportOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HashtableImportOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateHashtableImportOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + HashtableImportOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateHashtableImportOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableImportOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct HashtableSizeOptionsT : public ::flatbuffers::NativeTable { + typedef HashtableSizeOptions TableType; +}; + +struct HashtableSizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef HashtableSizeOptionsT NativeTableType; + typedef HashtableSizeOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + HashtableSizeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(HashtableSizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableSizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct HashtableSizeOptionsBuilder { + typedef HashtableSizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit HashtableSizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateHashtableSizeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + HashtableSizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateHashtableSizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableSizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct VarHandleOptionsT : public ::flatbuffers::NativeTable { + typedef VarHandleOptions TableType; + std::string container{}; + std::string shared_name{}; +}; + +struct VarHandleOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef VarHandleOptionsT NativeTableType; + typedef VarHandleOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + 
VT_CONTAINER = 4, + VT_SHARED_NAME = 6 + }; + const ::flatbuffers::String *container() const { + return GetPointer(VT_CONTAINER); + } + const ::flatbuffers::String *shared_name() const { + return GetPointer(VT_SHARED_NAME); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_CONTAINER) && + verifier.VerifyString(container()) && + VerifyOffset(verifier, VT_SHARED_NAME) && + verifier.VerifyString(shared_name()) && + verifier.EndTable(); + } + VarHandleOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(VarHandleOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const VarHandleOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct VarHandleOptionsBuilder { + typedef VarHandleOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_container(::flatbuffers::Offset<::flatbuffers::String> container) { + fbb_.AddOffset(VarHandleOptions::VT_CONTAINER, container); + } + void add_shared_name(::flatbuffers::Offset<::flatbuffers::String> shared_name) { + fbb_.AddOffset(VarHandleOptions::VT_SHARED_NAME, shared_name); + } + explicit VarHandleOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateVarHandleOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> container = 0, + ::flatbuffers::Offset<::flatbuffers::String> shared_name = 0) { + VarHandleOptionsBuilder builder_(_fbb); + builder_.add_shared_name(shared_name); + builder_.add_container(container); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateVarHandleOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *container = nullptr, + const char *shared_name = nullptr) { + auto container__ = container ? _fbb.CreateString(container) : 0; + auto shared_name__ = shared_name ? 
_fbb.CreateString(shared_name) : 0; + return tflite::CreateVarHandleOptions( + _fbb, + container__, + shared_name__); +} + +::flatbuffers::Offset CreateVarHandleOptions(::flatbuffers::FlatBufferBuilder &_fbb, const VarHandleOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReadVariableOptionsT : public ::flatbuffers::NativeTable { + typedef ReadVariableOptions TableType; +}; + +struct ReadVariableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ReadVariableOptionsT NativeTableType; + typedef ReadVariableOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ReadVariableOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReadVariableOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReadVariableOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReadVariableOptionsBuilder { + typedef ReadVariableOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ReadVariableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateReadVariableOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + ReadVariableOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateReadVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReadVariableOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct AssignVariableOptionsT : public ::flatbuffers::NativeTable { + typedef AssignVariableOptions TableType; +}; + +struct AssignVariableOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef AssignVariableOptionsT NativeTableType; + typedef AssignVariableOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + AssignVariableOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AssignVariableOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AssignVariableOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct AssignVariableOptionsBuilder { + typedef AssignVariableOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit AssignVariableOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateAssignVariableOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + AssignVariableOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateAssignVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AssignVariableOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RandomOptionsT : public 
::flatbuffers::NativeTable { + typedef RandomOptions TableType; + int64_t seed = 0; + int64_t seed2 = 0; +}; + +struct RandomOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef RandomOptionsT NativeTableType; + typedef RandomOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_SEED = 4, + VT_SEED2 = 6 + }; + int64_t seed() const { + return GetField(VT_SEED, 0); + } + int64_t seed2() const { + return GetField(VT_SEED2, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_SEED, 8) && + VerifyField(verifier, VT_SEED2, 8) && + verifier.EndTable(); + } + RandomOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RandomOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RandomOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RandomOptionsBuilder { + typedef RandomOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_seed(int64_t seed) { + fbb_.AddElement(RandomOptions::VT_SEED, seed, 0); + } + void add_seed2(int64_t seed2) { + fbb_.AddElement(RandomOptions::VT_SEED2, seed2, 0); + } + explicit RandomOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateRandomOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + int64_t seed = 0, + int64_t seed2 = 0) { + RandomOptionsBuilder builder_(_fbb); + builder_.add_seed2(seed2); + builder_.add_seed(seed); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateRandomOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RandomOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BucketizeOptionsT : public ::flatbuffers::NativeTable { + typedef BucketizeOptions TableType; + std::vector boundaries{}; +}; + +struct BucketizeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BucketizeOptionsT NativeTableType; + typedef BucketizeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_BOUNDARIES = 4 + }; + const ::flatbuffers::Vector *boundaries() const { + return GetPointer *>(VT_BOUNDARIES); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_BOUNDARIES) && + verifier.VerifyVector(boundaries()) && + verifier.EndTable(); + } + BucketizeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BucketizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BucketizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BucketizeOptionsBuilder { + typedef BucketizeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_boundaries(::flatbuffers::Offset<::flatbuffers::Vector> boundaries) { + fbb_.AddOffset(BucketizeOptions::VT_BOUNDARIES, boundaries); + } + explicit BucketizeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : 
fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBucketizeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> boundaries = 0) { + BucketizeOptionsBuilder builder_(_fbb); + builder_.add_boundaries(boundaries); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateBucketizeOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *boundaries = nullptr) { + auto boundaries__ = boundaries ? _fbb.CreateVector(*boundaries) : 0; + return tflite::CreateBucketizeOptions( + _fbb, + boundaries__); +} + +::flatbuffers::Offset CreateBucketizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BucketizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct GeluOptionsT : public ::flatbuffers::NativeTable { + typedef GeluOptions TableType; + bool approximate = false; +}; + +struct GeluOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef GeluOptionsT NativeTableType; + typedef GeluOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_APPROXIMATE = 4 + }; + bool approximate() const { + return GetField(VT_APPROXIMATE, 0) != 0; + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_APPROXIMATE, 1) && + verifier.EndTable(); + } + GeluOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GeluOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GeluOptionsBuilder { + typedef GeluOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_approximate(bool approximate) { + fbb_.AddElement(GeluOptions::VT_APPROXIMATE, static_cast(approximate), 0); + } + explicit GeluOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateGeluOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + bool approximate = false) { + GeluOptionsBuilder builder_(_fbb); + builder_.add_approximate(approximate); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateGeluOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DynamicUpdateSliceOptionsT : public ::flatbuffers::NativeTable { + typedef DynamicUpdateSliceOptions TableType; +}; + +struct DynamicUpdateSliceOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DynamicUpdateSliceOptionsT NativeTableType; + typedef DynamicUpdateSliceOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + DynamicUpdateSliceOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DynamicUpdateSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset 
Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DynamicUpdateSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DynamicUpdateSliceOptionsBuilder { + typedef DynamicUpdateSliceOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit DynamicUpdateSliceOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDynamicUpdateSliceOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + DynamicUpdateSliceOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateDynamicUpdateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DynamicUpdateSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct UnsortedSegmentProdOptionsT : public ::flatbuffers::NativeTable { + typedef UnsortedSegmentProdOptions TableType; +}; + +struct UnsortedSegmentProdOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UnsortedSegmentProdOptionsT NativeTableType; + typedef UnsortedSegmentProdOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + UnsortedSegmentProdOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnsortedSegmentProdOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UnsortedSegmentProdOptionsBuilder { + typedef UnsortedSegmentProdOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit UnsortedSegmentProdOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUnsortedSegmentProdOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + UnsortedSegmentProdOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateUnsortedSegmentProdOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct UnsortedSegmentMaxOptionsT : public ::flatbuffers::NativeTable { + typedef UnsortedSegmentMaxOptions TableType; +}; + +struct UnsortedSegmentMaxOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UnsortedSegmentMaxOptionsT NativeTableType; + typedef UnsortedSegmentMaxOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + UnsortedSegmentMaxOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnsortedSegmentMaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct 
UnsortedSegmentMaxOptionsBuilder { + typedef UnsortedSegmentMaxOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit UnsortedSegmentMaxOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUnsortedSegmentMaxOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + UnsortedSegmentMaxOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateUnsortedSegmentMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct UnsortedSegmentSumOptionsT : public ::flatbuffers::NativeTable { + typedef UnsortedSegmentSumOptions TableType; +}; + +struct UnsortedSegmentSumOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UnsortedSegmentSumOptionsT NativeTableType; + typedef UnsortedSegmentSumOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + UnsortedSegmentSumOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnsortedSegmentSumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentSumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UnsortedSegmentSumOptionsBuilder { + typedef UnsortedSegmentSumOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit UnsortedSegmentSumOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUnsortedSegmentSumOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + UnsortedSegmentSumOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateUnsortedSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentSumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ATan2OptionsT : public ::flatbuffers::NativeTable { + typedef ATan2Options TableType; +}; + +struct ATan2Options FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ATan2OptionsT NativeTableType; + typedef ATan2OptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ATan2OptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ATan2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ATan2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ATan2OptionsBuilder { + typedef ATan2Options Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit ATan2OptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const 
auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateATan2Options( + ::flatbuffers::FlatBufferBuilder &_fbb) { + ATan2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateATan2Options(::flatbuffers::FlatBufferBuilder &_fbb, const ATan2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct UnsortedSegmentMinOptionsT : public ::flatbuffers::NativeTable { + typedef UnsortedSegmentMinOptions TableType; +}; + +struct UnsortedSegmentMinOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef UnsortedSegmentMinOptionsT NativeTableType; + typedef UnsortedSegmentMinOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + UnsortedSegmentMinOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(UnsortedSegmentMinOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMinOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct UnsortedSegmentMinOptionsBuilder { + typedef UnsortedSegmentMinOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit UnsortedSegmentMinOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateUnsortedSegmentMinOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + UnsortedSegmentMinOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateUnsortedSegmentMinOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMinOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SignOptionsT : public ::flatbuffers::NativeTable { + typedef SignOptions TableType; +}; + +struct SignOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SignOptionsT NativeTableType; + typedef SignOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + SignOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SignOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SignOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SignOptionsBuilder { + typedef SignOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit SignOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSignOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + SignOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateSignOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SignOptionsT *_o, const 
::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BitcastOptionsT : public ::flatbuffers::NativeTable { + typedef BitcastOptions TableType; +}; + +struct BitcastOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BitcastOptionsT NativeTableType; + typedef BitcastOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BitcastOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BitcastOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitcastOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BitcastOptionsBuilder { + typedef BitcastOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit BitcastOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBitcastOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + BitcastOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBitcastOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitcastOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BitwiseXorOptionsT : public ::flatbuffers::NativeTable { + typedef BitwiseXorOptions TableType; +}; + +struct BitwiseXorOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BitwiseXorOptionsT NativeTableType; + typedef BitwiseXorOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + BitwiseXorOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BitwiseXorOptionsBuilder { + typedef BitwiseXorOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit BitwiseXorOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBitwiseXorOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + BitwiseXorOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct RightShiftOptionsT : public ::flatbuffers::NativeTable { + typedef RightShiftOptions TableType; +}; + +struct RightShiftOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef RightShiftOptionsT NativeTableType; + typedef RightShiftOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + 
RightShiftOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RightShiftOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RightShiftOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct RightShiftOptionsBuilder { + typedef RightShiftOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit RightShiftOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateRightShiftOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + RightShiftOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateRightShiftOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RightShiftOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct DilateOptionsT : public ::flatbuffers::NativeTable { + typedef DilateOptions TableType; +}; + +struct DilateOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef DilateOptionsT NativeTableType; + typedef DilateOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + DilateOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DilateOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DilateOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DilateOptionsBuilder { + typedef DilateOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit DilateOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateDilateOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + DilateOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateDilateOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DilateOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ReduceWindowOptionsT : public ::flatbuffers::NativeTable { + typedef ReduceWindowOptions TableType; + tflite::ReduceWindowFunction reduce_function = tflite::ReduceWindowFunction_UNSUPPORTED; +}; + +struct ReduceWindowOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ReduceWindowOptionsT NativeTableType; + typedef ReduceWindowOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_REDUCE_FUNCTION = 4 + }; + tflite::ReduceWindowFunction reduce_function() const { + return static_cast(GetField(VT_REDUCE_FUNCTION, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_REDUCE_FUNCTION, 4) && + verifier.EndTable(); + } + ReduceWindowOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void 
UnPackTo(ReduceWindowOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReduceWindowOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ReduceWindowOptionsBuilder { + typedef ReduceWindowOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_reduce_function(tflite::ReduceWindowFunction reduce_function) { + fbb_.AddElement(ReduceWindowOptions::VT_REDUCE_FUNCTION, static_cast(reduce_function), 0); + } + explicit ReduceWindowOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateReduceWindowOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + tflite::ReduceWindowFunction reduce_function = tflite::ReduceWindowFunction_UNSUPPORTED) { + ReduceWindowOptionsBuilder builder_(_fbb); + builder_.add_reduce_function(reduce_function); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateReduceWindowOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReduceWindowOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorCodeT : public ::flatbuffers::NativeTable { + typedef OperatorCode TableType; + int8_t deprecated_builtin_code = 0; + std::string custom_code{}; + int32_t version = 1; + tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD; +}; + +struct OperatorCode FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef OperatorCodeT NativeTableType; + typedef OperatorCodeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DEPRECATED_BUILTIN_CODE = 4, + VT_CUSTOM_CODE = 6, + VT_VERSION = 8, + VT_BUILTIN_CODE = 10 + }; + int8_t deprecated_builtin_code() const { + return GetField(VT_DEPRECATED_BUILTIN_CODE, 0); + } + const ::flatbuffers::String *custom_code() const { + return GetPointer(VT_CUSTOM_CODE); + } + int32_t version() const { + return GetField(VT_VERSION, 1); + } + tflite::BuiltinOperator builtin_code() const { + return static_cast(GetField(VT_BUILTIN_CODE, 0)); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DEPRECATED_BUILTIN_CODE, 1) && + VerifyOffset(verifier, VT_CUSTOM_CODE) && + verifier.VerifyString(custom_code()) && + VerifyField(verifier, VT_VERSION, 4) && + VerifyField(verifier, VT_BUILTIN_CODE, 4) && + verifier.EndTable(); + } + OperatorCodeT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorCodeT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct OperatorCodeBuilder { + typedef OperatorCode Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_deprecated_builtin_code(int8_t deprecated_builtin_code) { + fbb_.AddElement(OperatorCode::VT_DEPRECATED_BUILTIN_CODE, deprecated_builtin_code, 0); + } + void add_custom_code(::flatbuffers::Offset<::flatbuffers::String> custom_code) { + fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); + } + void add_version(int32_t version) { + 
fbb_.AddElement(OperatorCode::VT_VERSION, version, 1); + } + void add_builtin_code(tflite::BuiltinOperator builtin_code) { + fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); + } + explicit OperatorCodeBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateOperatorCode( + ::flatbuffers::FlatBufferBuilder &_fbb, + int8_t deprecated_builtin_code = 0, + ::flatbuffers::Offset<::flatbuffers::String> custom_code = 0, + int32_t version = 1, + tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD) { + OperatorCodeBuilder builder_(_fbb); + builder_.add_builtin_code(builtin_code); + builder_.add_version(version); + builder_.add_custom_code(custom_code); + builder_.add_deprecated_builtin_code(deprecated_builtin_code); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateOperatorCodeDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + int8_t deprecated_builtin_code = 0, + const char *custom_code = nullptr, + int32_t version = 1, + tflite::BuiltinOperator builtin_code = tflite::BuiltinOperator_ADD) { + auto custom_code__ = custom_code ? _fbb.CreateString(custom_code) : 0; + return tflite::CreateOperatorCode( + _fbb, + deprecated_builtin_code, + custom_code__, + version, + builtin_code); +} + +::flatbuffers::Offset CreateOperatorCode(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StableHLOCompositeOptionsT : public ::flatbuffers::NativeTable { + typedef StableHLOCompositeOptions TableType; + std::string name{}; + int32_t decomposition_subgraph_index = 0; + std::vector composite_attributes{}; + tflite::CustomOptionsFormat composite_attributes_format = tflite::CustomOptionsFormat_FLEXBUFFERS; + int32_t version = 0; +}; + +struct StableHLOCompositeOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StableHLOCompositeOptionsT NativeTableType; + typedef StableHLOCompositeOptionsBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_DECOMPOSITION_SUBGRAPH_INDEX = 6, + VT_COMPOSITE_ATTRIBUTES = 8, + VT_COMPOSITE_ATTRIBUTES_FORMAT = 10, + VT_VERSION = 12 + }; + const ::flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + int32_t decomposition_subgraph_index() const { + return GetField(VT_DECOMPOSITION_SUBGRAPH_INDEX, 0); + } + const ::flatbuffers::Vector *composite_attributes() const { + return GetPointer *>(VT_COMPOSITE_ATTRIBUTES); + } + tflite::CustomOptionsFormat composite_attributes_format() const { + return static_cast(GetField(VT_COMPOSITE_ATTRIBUTES_FORMAT, 0)); + } + int32_t version() const { + return GetField(VT_VERSION, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_DECOMPOSITION_SUBGRAPH_INDEX, 4) && + VerifyOffset(verifier, VT_COMPOSITE_ATTRIBUTES) && + verifier.VerifyVector(composite_attributes()) && + VerifyField(verifier, VT_COMPOSITE_ATTRIBUTES_FORMAT, 1) && + VerifyField(verifier, VT_VERSION, 4) && + verifier.EndTable(); + } + StableHLOCompositeOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StableHLOCompositeOptionsT *_o, const 
::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StableHLOCompositeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StableHLOCompositeOptionsBuilder { + typedef StableHLOCompositeOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(StableHLOCompositeOptions::VT_NAME, name); + } + void add_decomposition_subgraph_index(int32_t decomposition_subgraph_index) { + fbb_.AddElement(StableHLOCompositeOptions::VT_DECOMPOSITION_SUBGRAPH_INDEX, decomposition_subgraph_index, 0); + } + void add_composite_attributes(::flatbuffers::Offset<::flatbuffers::Vector> composite_attributes) { + fbb_.AddOffset(StableHLOCompositeOptions::VT_COMPOSITE_ATTRIBUTES, composite_attributes); + } + void add_composite_attributes_format(tflite::CustomOptionsFormat composite_attributes_format) { + fbb_.AddElement(StableHLOCompositeOptions::VT_COMPOSITE_ATTRIBUTES_FORMAT, static_cast(composite_attributes_format), 0); + } + void add_version(int32_t version) { + fbb_.AddElement(StableHLOCompositeOptions::VT_VERSION, version, 0); + } + explicit StableHLOCompositeOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStableHLOCompositeOptions( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + int32_t decomposition_subgraph_index = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> composite_attributes = 0, + tflite::CustomOptionsFormat composite_attributes_format = tflite::CustomOptionsFormat_FLEXBUFFERS, + int32_t version = 0) { + StableHLOCompositeOptionsBuilder builder_(_fbb); + builder_.add_version(version); + builder_.add_composite_attributes(composite_attributes); + builder_.add_decomposition_subgraph_index(decomposition_subgraph_index); + builder_.add_name(name); + builder_.add_composite_attributes_format(composite_attributes_format); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateStableHLOCompositeOptionsDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + int32_t decomposition_subgraph_index = 0, + const std::vector *composite_attributes = nullptr, + tflite::CustomOptionsFormat composite_attributes_format = tflite::CustomOptionsFormat_FLEXBUFFERS, + int32_t version = 0) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto composite_attributes__ = composite_attributes ? 
_fbb.CreateVector(*composite_attributes) : 0; + return tflite::CreateStableHLOCompositeOptions( + _fbb, + name__, + decomposition_subgraph_index, + composite_attributes__, + composite_attributes_format, + version); +} + +::flatbuffers::Offset CreateStableHLOCompositeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StableHLOCompositeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StablehloShiftLeftOptionsT : public ::flatbuffers::NativeTable { + typedef StablehloShiftLeftOptions TableType; +}; + +struct StablehloShiftLeftOptions FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef StablehloShiftLeftOptionsT NativeTableType; + typedef StablehloShiftLeftOptionsBuilder Builder; + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + StablehloShiftLeftOptionsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StablehloShiftLeftOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloShiftLeftOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct StablehloShiftLeftOptionsBuilder { + typedef StablehloShiftLeftOptions Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + explicit StablehloShiftLeftOptionsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateStablehloShiftLeftOptions( + ::flatbuffers::FlatBufferBuilder &_fbb) { + StablehloShiftLeftOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +::flatbuffers::Offset CreateStablehloShiftLeftOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloShiftLeftOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct OperatorT : public ::flatbuffers::NativeTable { + typedef Operator TableType; + uint32_t opcode_index = 0; + std::vector inputs{}; + std::vector outputs{}; + tflite::BuiltinOptionsUnion builtin_options{}; + std::vector custom_options{}; + tflite::CustomOptionsFormat custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; + std::vector mutating_variable_inputs{}; + std::vector intermediates{}; + uint64_t large_custom_options_offset = 0; + uint64_t large_custom_options_size = 0; + tflite::BuiltinOptions2Union builtin_options_2{}; + int32_t debug_metadata_index = -1; +}; + +struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef OperatorT NativeTableType; + typedef OperatorBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_OPCODE_INDEX = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_BUILTIN_OPTIONS_TYPE = 10, + VT_BUILTIN_OPTIONS = 12, + VT_CUSTOM_OPTIONS = 14, + VT_CUSTOM_OPTIONS_FORMAT = 16, + VT_MUTATING_VARIABLE_INPUTS = 18, + VT_INTERMEDIATES = 20, + VT_LARGE_CUSTOM_OPTIONS_OFFSET = 22, + VT_LARGE_CUSTOM_OPTIONS_SIZE = 24, + VT_BUILTIN_OPTIONS_2_TYPE = 26, + VT_BUILTIN_OPTIONS_2 = 28, + VT_DEBUG_METADATA_INDEX = 30 + }; + uint32_t opcode_index() const { + return GetField(VT_OPCODE_INDEX, 0); + } + const ::flatbuffers::Vector *inputs() const { + return GetPointer *>(VT_INPUTS); + } + const ::flatbuffers::Vector *outputs() const { + return GetPointer 
*>(VT_OUTPUTS); + } + tflite::BuiltinOptions builtin_options_type() const { + return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); + } + const void *builtin_options() const { + return GetPointer(VT_BUILTIN_OPTIONS); + } + template const T *builtin_options_as() const; + const tflite::Conv2DOptions *builtin_options_as_Conv2DOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::Pool2DOptions *builtin_options_as_Pool2DOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SVDFOptions *builtin_options_as_SVDFOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SVDFOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::RNNOptions *builtin_options_as_RNNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::AddOptions *builtin_options_as_AddOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_AddOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::L2NormOptions *builtin_options_as_L2NormOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_L2NormOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LocalResponseNormalizationOptions *builtin_options_as_LocalResponseNormalizationOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LocalResponseNormalizationOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LSTMOptions *builtin_options_as_LSTMOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LSTMOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ResizeBilinearOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::CallOptions *builtin_options_as_CallOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CallOptions ? 
static_cast(builtin_options()) : nullptr; + } + const tflite::ReshapeOptions *builtin_options_as_ReshapeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ReshapeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SkipGramOptions *builtin_options_as_SkipGramOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SkipGramOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SpaceToDepthOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::EmbeddingLookupSparseOptions *builtin_options_as_EmbeddingLookupSparseOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_EmbeddingLookupSparseOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::MulOptions *builtin_options_as_MulOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MulOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::PadOptions *builtin_options_as_PadOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_PadOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::GatherOptions *builtin_options_as_GatherOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GatherOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BatchToSpaceNDOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SpaceToBatchNDOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::TransposeOptions *builtin_options_as_TransposeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ReducerOptions *builtin_options_as_ReducerOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ReducerOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SubOptions *builtin_options_as_SubOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::DivOptions *builtin_options_as_DivOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DivOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SqueezeOptions *builtin_options_as_SqueezeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SqueezeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SequenceRNNOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::StridedSliceOptions *builtin_options_as_StridedSliceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_StridedSliceOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ExpOptions *builtin_options_as_ExpOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ExpOptions ? 
static_cast(builtin_options()) : nullptr; + } + const tflite::TopKV2Options *builtin_options_as_TopKV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_TopKV2Options ? static_cast(builtin_options()) : nullptr; + } + const tflite::SplitOptions *builtin_options_as_SplitOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SplitOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogSoftmaxOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::CastOptions *builtin_options_as_CastOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CastOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::DequantizeOptions *builtin_options_as_DequantizeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DequantizeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::MaximumMinimumOptions *builtin_options_as_MaximumMinimumOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MaximumMinimumOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ArgMaxOptions *builtin_options_as_ArgMaxOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ArgMaxOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LessOptions *builtin_options_as_LessOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LessOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::NegOptions *builtin_options_as_NegOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_NegOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::PadV2Options *builtin_options_as_PadV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_PadV2Options ? static_cast(builtin_options()) : nullptr; + } + const tflite::GreaterOptions *builtin_options_as_GreaterOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GreaterOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::GreaterEqualOptions *builtin_options_as_GreaterEqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GreaterEqualOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LessEqualOptions *builtin_options_as_LessEqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LessEqualOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SelectOptions *builtin_options_as_SelectOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SelectOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SliceOptions *builtin_options_as_SliceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SliceOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SparseToDenseOptions ? 
static_cast(builtin_options()) : nullptr; + } + const tflite::TileOptions *builtin_options_as_TileOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_TileOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::EqualOptions *builtin_options_as_EqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_EqualOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::NotEqualOptions *builtin_options_as_NotEqualOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ShapeOptions *builtin_options_as_ShapeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ShapeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::PowOptions *builtin_options_as_PowOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_PowOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ArgMinOptions *builtin_options_as_ArgMinOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ArgMinOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::FakeQuantOptions *builtin_options_as_FakeQuantOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FakeQuantOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::PackOptions *builtin_options_as_PackOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_PackOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LogicalOrOptions *builtin_options_as_LogicalOrOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogicalOrOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::OneHotOptions *builtin_options_as_OneHotOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_OneHotOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LogicalAndOptions *builtin_options_as_LogicalAndOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogicalAndOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LogicalNotOptions *builtin_options_as_LogicalNotOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LogicalNotOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::UnpackOptions *builtin_options_as_UnpackOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnpackOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::FloorDivOptions *builtin_options_as_FloorDivOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FloorDivOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SquareOptions *builtin_options_as_SquareOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SquareOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ZerosLikeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::FillOptions *builtin_options_as_FillOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FillOptions ? 
static_cast(builtin_options()) : nullptr; + } + const tflite::BidirectionalSequenceLSTMOptions *builtin_options_as_BidirectionalSequenceLSTMOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::UnidirectionalSequenceLSTMOptions *builtin_options_as_UnidirectionalSequenceLSTMOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::FloorModOptions *builtin_options_as_FloorModOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_FloorModOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::RangeOptions *builtin_options_as_RangeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RangeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ResizeNearestNeighborOptions *builtin_options_as_ResizeNearestNeighborOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ResizeNearestNeighborOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::LeakyReluOptions *builtin_options_as_LeakyReluOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_LeakyReluOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SquaredDifferenceOptions *builtin_options_as_SquaredDifferenceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SquaredDifferenceOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::MirrorPadOptions *builtin_options_as_MirrorPadOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MirrorPadOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::AbsOptions *builtin_options_as_AbsOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_AbsOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SplitVOptions *builtin_options_as_SplitVOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SplitVOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::UniqueOptions *builtin_options_as_UniqueOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UniqueOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ReverseV2Options *builtin_options_as_ReverseV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_ReverseV2Options ? static_cast(builtin_options()) : nullptr; + } + const tflite::AddNOptions *builtin_options_as_AddNOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_AddNOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::GatherNdOptions *builtin_options_as_GatherNdOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GatherNdOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::CosOptions *builtin_options_as_CosOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CosOptions ? 
static_cast(builtin_options()) : nullptr; + } + const tflite::WhereOptions *builtin_options_as_WhereOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_WhereOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::RankOptions *builtin_options_as_RankOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RankOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ReverseSequenceOptions *builtin_options_as_ReverseSequenceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ReverseSequenceOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::MatrixDiagOptions *builtin_options_as_MatrixDiagOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MatrixDiagOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::QuantizeOptions *builtin_options_as_QuantizeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_QuantizeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::MatrixSetDiagOptions *builtin_options_as_MatrixSetDiagOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_MatrixSetDiagOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::HardSwishOptions *builtin_options_as_HardSwishOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_HardSwishOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::IfOptions *builtin_options_as_IfOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_IfOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::WhileOptions *builtin_options_as_WhileOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_WhileOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::DepthToSpaceOptions *builtin_options_as_DepthToSpaceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DepthToSpaceOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::NonMaxSuppressionV4Options *builtin_options_as_NonMaxSuppressionV4Options() const { + return builtin_options_type() == tflite::BuiltinOptions_NonMaxSuppressionV4Options ? static_cast(builtin_options()) : nullptr; + } + const tflite::NonMaxSuppressionV5Options *builtin_options_as_NonMaxSuppressionV5Options() const { + return builtin_options_type() == tflite::BuiltinOptions_NonMaxSuppressionV5Options ? static_cast(builtin_options()) : nullptr; + } + const tflite::ScatterNdOptions *builtin_options_as_ScatterNdOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ScatterNdOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SelectV2Options *builtin_options_as_SelectV2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_SelectV2Options ? static_cast(builtin_options()) : nullptr; + } + const tflite::DensifyOptions *builtin_options_as_DensifyOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DensifyOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::SegmentSumOptions *builtin_options_as_SegmentSumOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SegmentSumOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::BatchMatMulOptions *builtin_options_as_BatchMatMulOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BatchMatMulOptions ? 
static_cast(builtin_options()) : nullptr; + } + const tflite::CumsumOptions *builtin_options_as_CumsumOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CumsumOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::CallOnceOptions *builtin_options_as_CallOnceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_CallOnceOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::BroadcastToOptions *builtin_options_as_BroadcastToOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BroadcastToOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::Rfft2dOptions *builtin_options_as_Rfft2dOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_Rfft2dOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::Conv3DOptions *builtin_options_as_Conv3DOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_Conv3DOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::HashtableOptions *builtin_options_as_HashtableOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_HashtableOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::HashtableFindOptions *builtin_options_as_HashtableFindOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_HashtableFindOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::HashtableImportOptions *builtin_options_as_HashtableImportOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_HashtableImportOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::HashtableSizeOptions *builtin_options_as_HashtableSizeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_HashtableSizeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::VarHandleOptions *builtin_options_as_VarHandleOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_VarHandleOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ReadVariableOptions *builtin_options_as_ReadVariableOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_ReadVariableOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::AssignVariableOptions *builtin_options_as_AssignVariableOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_AssignVariableOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::RandomOptions *builtin_options_as_RandomOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RandomOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::BucketizeOptions *builtin_options_as_BucketizeOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BucketizeOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::GeluOptions *builtin_options_as_GeluOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_GeluOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::DynamicUpdateSliceOptions *builtin_options_as_DynamicUpdateSliceOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_DynamicUpdateSliceOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::UnsortedSegmentProdOptions *builtin_options_as_UnsortedSegmentProdOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnsortedSegmentProdOptions ? 
static_cast(builtin_options()) : nullptr; + } + const tflite::UnsortedSegmentMaxOptions *builtin_options_as_UnsortedSegmentMaxOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnsortedSegmentMaxOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::UnsortedSegmentMinOptions *builtin_options_as_UnsortedSegmentMinOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnsortedSegmentMinOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::UnsortedSegmentSumOptions *builtin_options_as_UnsortedSegmentSumOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_UnsortedSegmentSumOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::ATan2Options *builtin_options_as_ATan2Options() const { + return builtin_options_type() == tflite::BuiltinOptions_ATan2Options ? static_cast(builtin_options()) : nullptr; + } + const tflite::SignOptions *builtin_options_as_SignOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_SignOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::BitcastOptions *builtin_options_as_BitcastOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BitcastOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::BitwiseXorOptions *builtin_options_as_BitwiseXorOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_BitwiseXorOptions ? static_cast(builtin_options()) : nullptr; + } + const tflite::RightShiftOptions *builtin_options_as_RightShiftOptions() const { + return builtin_options_type() == tflite::BuiltinOptions_RightShiftOptions ? static_cast(builtin_options()) : nullptr; + } + const ::flatbuffers::Vector *custom_options() const { + return GetPointer *>(VT_CUSTOM_OPTIONS); + } + tflite::CustomOptionsFormat custom_options_format() const { + return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); + } + const ::flatbuffers::Vector *mutating_variable_inputs() const { + return GetPointer *>(VT_MUTATING_VARIABLE_INPUTS); + } + const ::flatbuffers::Vector *intermediates() const { + return GetPointer *>(VT_INTERMEDIATES); + } + uint64_t large_custom_options_offset() const { + return GetField(VT_LARGE_CUSTOM_OPTIONS_OFFSET, 0); + } + uint64_t large_custom_options_size() const { + return GetField(VT_LARGE_CUSTOM_OPTIONS_SIZE, 0); + } + tflite::BuiltinOptions2 builtin_options_2_type() const { + return static_cast(GetField(VT_BUILTIN_OPTIONS_2_TYPE, 0)); + } + const void *builtin_options_2() const { + return GetPointer(VT_BUILTIN_OPTIONS_2); + } + template const T *builtin_options_2_as() const; + const tflite::StablehloConcatenateOptions *builtin_options_2_as_StablehloConcatenateOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloConcatenateOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloBroadcastInDimOptions *builtin_options_2_as_StablehloBroadcastInDimOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloBroadcastInDimOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloSliceOptions *builtin_options_2_as_StablehloSliceOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloSliceOptions ? 
static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloConvolutionOptions *builtin_options_2_as_StablehloConvolutionOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloConvolutionOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloCustomCallOptions *builtin_options_2_as_StablehloCustomCallOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloCustomCallOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloReduceOptions *builtin_options_2_as_StablehloReduceOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloReduceOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloScatterOptions *builtin_options_2_as_StablehloScatterOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloScatterOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloCompareOptions *builtin_options_2_as_StablehloCompareOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloCompareOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloDynamicSliceOptions *builtin_options_2_as_StablehloDynamicSliceOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloDynamicSliceOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloPadOptions *builtin_options_2_as_StablehloPadOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloPadOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloIotaOptions *builtin_options_2_as_StablehloIotaOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloIotaOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloDotGeneralOptions *builtin_options_2_as_StablehloDotGeneralOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloDotGeneralOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloReduceWindowOptions *builtin_options_2_as_StablehloReduceWindowOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloReduceWindowOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloSortOptions *builtin_options_2_as_StablehloSortOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloSortOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloWhileOptions *builtin_options_2_as_StablehloWhileOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloWhileOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloGatherOptions *builtin_options_2_as_StablehloGatherOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloGatherOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloTransposeOptions *builtin_options_2_as_StablehloTransposeOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloTransposeOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::DilateOptions *builtin_options_2_as_DilateOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_DilateOptions ? 
static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloRngBitGeneratorOptions *builtin_options_2_as_StablehloRngBitGeneratorOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloRngBitGeneratorOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::ReduceWindowOptions *builtin_options_2_as_ReduceWindowOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_ReduceWindowOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StableHLOCompositeOptions *builtin_options_2_as_StableHLOCompositeOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StableHLOCompositeOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloShiftLeftOptions *builtin_options_2_as_StablehloShiftLeftOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloShiftLeftOptions ? static_cast(builtin_options_2()) : nullptr; + } + const tflite::StablehloCaseOptions *builtin_options_2_as_StablehloCaseOptions() const { + return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloCaseOptions ? static_cast(builtin_options_2()) : nullptr; + } + int32_t debug_metadata_index() const { + return GetField(VT_DEBUG_METADATA_INDEX, -1); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_OPCODE_INDEX, 4) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + VerifyField(verifier, VT_BUILTIN_OPTIONS_TYPE, 1) && + VerifyOffset(verifier, VT_BUILTIN_OPTIONS) && + VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) && + VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && + verifier.VerifyVector(custom_options()) && + VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT, 1) && + VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) && + verifier.VerifyVector(mutating_variable_inputs()) && + VerifyOffset(verifier, VT_INTERMEDIATES) && + verifier.VerifyVector(intermediates()) && + VerifyField(verifier, VT_LARGE_CUSTOM_OPTIONS_OFFSET, 8) && + VerifyField(verifier, VT_LARGE_CUSTOM_OPTIONS_SIZE, 8) && + VerifyField(verifier, VT_BUILTIN_OPTIONS_2_TYPE, 1) && + VerifyOffset(verifier, VT_BUILTIN_OPTIONS_2) && + VerifyBuiltinOptions2(verifier, builtin_options_2(), builtin_options_2_type()) && + VerifyField(verifier, VT_DEBUG_METADATA_INDEX, 4) && + verifier.EndTable(); + } + OperatorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +template<> inline const tflite::Conv2DOptions *Operator::builtin_options_as() const { + return builtin_options_as_Conv2DOptions(); +} + +template<> inline const tflite::DepthwiseConv2DOptions *Operator::builtin_options_as() const { + return builtin_options_as_DepthwiseConv2DOptions(); +} + +template<> inline const tflite::ConcatEmbeddingsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ConcatEmbeddingsOptions(); +} + +template<> inline const tflite::LSHProjectionOptions *Operator::builtin_options_as() const { + return builtin_options_as_LSHProjectionOptions(); +} + +template<> inline const tflite::Pool2DOptions 
*Operator::builtin_options_as() const { + return builtin_options_as_Pool2DOptions(); +} + +template<> inline const tflite::SVDFOptions *Operator::builtin_options_as() const { + return builtin_options_as_SVDFOptions(); +} + +template<> inline const tflite::RNNOptions *Operator::builtin_options_as() const { + return builtin_options_as_RNNOptions(); +} + +template<> inline const tflite::FullyConnectedOptions *Operator::builtin_options_as() const { + return builtin_options_as_FullyConnectedOptions(); +} + +template<> inline const tflite::SoftmaxOptions *Operator::builtin_options_as() const { + return builtin_options_as_SoftmaxOptions(); +} + +template<> inline const tflite::ConcatenationOptions *Operator::builtin_options_as() const { + return builtin_options_as_ConcatenationOptions(); +} + +template<> inline const tflite::AddOptions *Operator::builtin_options_as() const { + return builtin_options_as_AddOptions(); +} + +template<> inline const tflite::L2NormOptions *Operator::builtin_options_as() const { + return builtin_options_as_L2NormOptions(); +} + +template<> inline const tflite::LocalResponseNormalizationOptions *Operator::builtin_options_as() const { + return builtin_options_as_LocalResponseNormalizationOptions(); +} + +template<> inline const tflite::LSTMOptions *Operator::builtin_options_as() const { + return builtin_options_as_LSTMOptions(); +} + +template<> inline const tflite::ResizeBilinearOptions *Operator::builtin_options_as() const { + return builtin_options_as_ResizeBilinearOptions(); +} + +template<> inline const tflite::CallOptions *Operator::builtin_options_as() const { + return builtin_options_as_CallOptions(); +} + +template<> inline const tflite::ReshapeOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReshapeOptions(); +} + +template<> inline const tflite::SkipGramOptions *Operator::builtin_options_as() const { + return builtin_options_as_SkipGramOptions(); +} + +template<> inline const tflite::SpaceToDepthOptions *Operator::builtin_options_as() const { + return builtin_options_as_SpaceToDepthOptions(); +} + +template<> inline const tflite::EmbeddingLookupSparseOptions *Operator::builtin_options_as() const { + return builtin_options_as_EmbeddingLookupSparseOptions(); +} + +template<> inline const tflite::MulOptions *Operator::builtin_options_as() const { + return builtin_options_as_MulOptions(); +} + +template<> inline const tflite::PadOptions *Operator::builtin_options_as() const { + return builtin_options_as_PadOptions(); +} + +template<> inline const tflite::GatherOptions *Operator::builtin_options_as() const { + return builtin_options_as_GatherOptions(); +} + +template<> inline const tflite::BatchToSpaceNDOptions *Operator::builtin_options_as() const { + return builtin_options_as_BatchToSpaceNDOptions(); +} + +template<> inline const tflite::SpaceToBatchNDOptions *Operator::builtin_options_as() const { + return builtin_options_as_SpaceToBatchNDOptions(); +} + +template<> inline const tflite::TransposeOptions *Operator::builtin_options_as() const { + return builtin_options_as_TransposeOptions(); +} + +template<> inline const tflite::ReducerOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReducerOptions(); +} + +template<> inline const tflite::SubOptions *Operator::builtin_options_as() const { + return builtin_options_as_SubOptions(); +} + +template<> inline const tflite::DivOptions *Operator::builtin_options_as() const { + return builtin_options_as_DivOptions(); +} + +template<> inline const tflite::SqueezeOptions 
*Operator::builtin_options_as() const { + return builtin_options_as_SqueezeOptions(); +} + +template<> inline const tflite::SequenceRNNOptions *Operator::builtin_options_as() const { + return builtin_options_as_SequenceRNNOptions(); +} + +template<> inline const tflite::StridedSliceOptions *Operator::builtin_options_as() const { + return builtin_options_as_StridedSliceOptions(); +} + +template<> inline const tflite::ExpOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpOptions(); +} + +template<> inline const tflite::TopKV2Options *Operator::builtin_options_as() const { + return builtin_options_as_TopKV2Options(); +} + +template<> inline const tflite::SplitOptions *Operator::builtin_options_as() const { + return builtin_options_as_SplitOptions(); +} + +template<> inline const tflite::LogSoftmaxOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogSoftmaxOptions(); +} + +template<> inline const tflite::CastOptions *Operator::builtin_options_as() const { + return builtin_options_as_CastOptions(); +} + +template<> inline const tflite::DequantizeOptions *Operator::builtin_options_as() const { + return builtin_options_as_DequantizeOptions(); +} + +template<> inline const tflite::MaximumMinimumOptions *Operator::builtin_options_as() const { + return builtin_options_as_MaximumMinimumOptions(); +} + +template<> inline const tflite::ArgMaxOptions *Operator::builtin_options_as() const { + return builtin_options_as_ArgMaxOptions(); +} + +template<> inline const tflite::LessOptions *Operator::builtin_options_as() const { + return builtin_options_as_LessOptions(); +} + +template<> inline const tflite::NegOptions *Operator::builtin_options_as() const { + return builtin_options_as_NegOptions(); +} + +template<> inline const tflite::PadV2Options *Operator::builtin_options_as() const { + return builtin_options_as_PadV2Options(); +} + +template<> inline const tflite::GreaterOptions *Operator::builtin_options_as() const { + return builtin_options_as_GreaterOptions(); +} + +template<> inline const tflite::GreaterEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_GreaterEqualOptions(); +} + +template<> inline const tflite::LessEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_LessEqualOptions(); +} + +template<> inline const tflite::SelectOptions *Operator::builtin_options_as() const { + return builtin_options_as_SelectOptions(); +} + +template<> inline const tflite::SliceOptions *Operator::builtin_options_as() const { + return builtin_options_as_SliceOptions(); +} + +template<> inline const tflite::TransposeConvOptions *Operator::builtin_options_as() const { + return builtin_options_as_TransposeConvOptions(); +} + +template<> inline const tflite::SparseToDenseOptions *Operator::builtin_options_as() const { + return builtin_options_as_SparseToDenseOptions(); +} + +template<> inline const tflite::TileOptions *Operator::builtin_options_as() const { + return builtin_options_as_TileOptions(); +} + +template<> inline const tflite::ExpandDimsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpandDimsOptions(); +} + +template<> inline const tflite::EqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_EqualOptions(); +} + +template<> inline const tflite::NotEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_NotEqualOptions(); +} + +template<> inline const tflite::ShapeOptions *Operator::builtin_options_as() const { + 
return builtin_options_as_ShapeOptions(); +} + +template<> inline const tflite::PowOptions *Operator::builtin_options_as() const { + return builtin_options_as_PowOptions(); +} + +template<> inline const tflite::ArgMinOptions *Operator::builtin_options_as() const { + return builtin_options_as_ArgMinOptions(); +} + +template<> inline const tflite::FakeQuantOptions *Operator::builtin_options_as() const { + return builtin_options_as_FakeQuantOptions(); +} + +template<> inline const tflite::PackOptions *Operator::builtin_options_as() const { + return builtin_options_as_PackOptions(); +} + +template<> inline const tflite::LogicalOrOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalOrOptions(); +} + +template<> inline const tflite::OneHotOptions *Operator::builtin_options_as() const { + return builtin_options_as_OneHotOptions(); +} + +template<> inline const tflite::LogicalAndOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalAndOptions(); +} + +template<> inline const tflite::LogicalNotOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogicalNotOptions(); +} + +template<> inline const tflite::UnpackOptions *Operator::builtin_options_as() const { + return builtin_options_as_UnpackOptions(); +} + +template<> inline const tflite::FloorDivOptions *Operator::builtin_options_as() const { + return builtin_options_as_FloorDivOptions(); +} + +template<> inline const tflite::SquareOptions *Operator::builtin_options_as() const { + return builtin_options_as_SquareOptions(); +} + +template<> inline const tflite::ZerosLikeOptions *Operator::builtin_options_as() const { + return builtin_options_as_ZerosLikeOptions(); +} + +template<> inline const tflite::FillOptions *Operator::builtin_options_as() const { + return builtin_options_as_FillOptions(); +} + +template<> inline const tflite::BidirectionalSequenceLSTMOptions *Operator::builtin_options_as() const { + return builtin_options_as_BidirectionalSequenceLSTMOptions(); +} + +template<> inline const tflite::BidirectionalSequenceRNNOptions *Operator::builtin_options_as() const { + return builtin_options_as_BidirectionalSequenceRNNOptions(); +} + +template<> inline const tflite::UnidirectionalSequenceLSTMOptions *Operator::builtin_options_as() const { + return builtin_options_as_UnidirectionalSequenceLSTMOptions(); +} + +template<> inline const tflite::FloorModOptions *Operator::builtin_options_as() const { + return builtin_options_as_FloorModOptions(); +} + +template<> inline const tflite::RangeOptions *Operator::builtin_options_as() const { + return builtin_options_as_RangeOptions(); +} + +template<> inline const tflite::ResizeNearestNeighborOptions *Operator::builtin_options_as() const { + return builtin_options_as_ResizeNearestNeighborOptions(); +} + +template<> inline const tflite::LeakyReluOptions *Operator::builtin_options_as() const { + return builtin_options_as_LeakyReluOptions(); +} + +template<> inline const tflite::SquaredDifferenceOptions *Operator::builtin_options_as() const { + return builtin_options_as_SquaredDifferenceOptions(); +} + +template<> inline const tflite::MirrorPadOptions *Operator::builtin_options_as() const { + return builtin_options_as_MirrorPadOptions(); +} + +template<> inline const tflite::AbsOptions *Operator::builtin_options_as() const { + return builtin_options_as_AbsOptions(); +} + +template<> inline const tflite::SplitVOptions *Operator::builtin_options_as() const { + return builtin_options_as_SplitVOptions(); +} + +template<> inline 
const tflite::UniqueOptions *Operator::builtin_options_as() const { + return builtin_options_as_UniqueOptions(); +} + +template<> inline const tflite::ReverseV2Options *Operator::builtin_options_as() const { + return builtin_options_as_ReverseV2Options(); +} + +template<> inline const tflite::AddNOptions *Operator::builtin_options_as() const { + return builtin_options_as_AddNOptions(); +} + +template<> inline const tflite::GatherNdOptions *Operator::builtin_options_as() const { + return builtin_options_as_GatherNdOptions(); +} + +template<> inline const tflite::CosOptions *Operator::builtin_options_as() const { + return builtin_options_as_CosOptions(); +} + +template<> inline const tflite::WhereOptions *Operator::builtin_options_as() const { + return builtin_options_as_WhereOptions(); +} + +template<> inline const tflite::RankOptions *Operator::builtin_options_as() const { + return builtin_options_as_RankOptions(); +} + +template<> inline const tflite::ReverseSequenceOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReverseSequenceOptions(); +} + +template<> inline const tflite::MatrixDiagOptions *Operator::builtin_options_as() const { + return builtin_options_as_MatrixDiagOptions(); +} + +template<> inline const tflite::QuantizeOptions *Operator::builtin_options_as() const { + return builtin_options_as_QuantizeOptions(); +} + +template<> inline const tflite::MatrixSetDiagOptions *Operator::builtin_options_as() const { + return builtin_options_as_MatrixSetDiagOptions(); +} + +template<> inline const tflite::HardSwishOptions *Operator::builtin_options_as() const { + return builtin_options_as_HardSwishOptions(); +} + +template<> inline const tflite::IfOptions *Operator::builtin_options_as() const { + return builtin_options_as_IfOptions(); +} + +template<> inline const tflite::WhileOptions *Operator::builtin_options_as() const { + return builtin_options_as_WhileOptions(); +} + +template<> inline const tflite::DepthToSpaceOptions *Operator::builtin_options_as() const { + return builtin_options_as_DepthToSpaceOptions(); +} + +template<> inline const tflite::NonMaxSuppressionV4Options *Operator::builtin_options_as() const { + return builtin_options_as_NonMaxSuppressionV4Options(); +} + +template<> inline const tflite::NonMaxSuppressionV5Options *Operator::builtin_options_as() const { + return builtin_options_as_NonMaxSuppressionV5Options(); +} + +template<> inline const tflite::ScatterNdOptions *Operator::builtin_options_as() const { + return builtin_options_as_ScatterNdOptions(); +} + +template<> inline const tflite::SelectV2Options *Operator::builtin_options_as() const { + return builtin_options_as_SelectV2Options(); +} + +template<> inline const tflite::DensifyOptions *Operator::builtin_options_as() const { + return builtin_options_as_DensifyOptions(); +} + +template<> inline const tflite::SegmentSumOptions *Operator::builtin_options_as() const { + return builtin_options_as_SegmentSumOptions(); +} + +template<> inline const tflite::BatchMatMulOptions *Operator::builtin_options_as() const { + return builtin_options_as_BatchMatMulOptions(); +} + +template<> inline const tflite::CumsumOptions *Operator::builtin_options_as() const { + return builtin_options_as_CumsumOptions(); +} + +template<> inline const tflite::CallOnceOptions *Operator::builtin_options_as() const { + return builtin_options_as_CallOnceOptions(); +} + +template<> inline const tflite::BroadcastToOptions *Operator::builtin_options_as() const { + return builtin_options_as_BroadcastToOptions(); +} + 
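The run of builtin_options_as_*() helpers above, together with the builtin_options_as<T>() specializations in this part of the header, is the type-safe way to read the BuiltinOptions union from a serialized model: the stored type tag (builtin_options_type()) is consulted first, and the accessor yields nullptr when the union holds a different option table. A minimal read-side sketch, assuming this generated header is reachable as "schema_generated.h" and that buf points at a valid serialized tflite::Model; the function name and include path are illustrative, not part of this change:

#include <cstdint>
#include <cstdio>

#include "schema_generated.h"  // assumed include path for this generated header

// Print the stride of every Conv2D operator in the first subgraph of an
// already-verified .tflite buffer.
void DumpConv2DStrides(const uint8_t *buf) {
  const tflite::Model *model = tflite::GetModel(buf);
  const tflite::SubGraph *graph = model->subgraphs()->Get(0);
  for (const tflite::Operator *op : *graph->operators()) {
    // builtin_options_as<T>() checks builtin_options_type() internally and
    // returns nullptr on a mismatch, so the downcast is safe to attempt.
    if (const tflite::Conv2DOptions *conv =
            op->builtin_options_as<tflite::Conv2DOptions>()) {
      std::printf("conv2d stride: %dx%d\n", conv->stride_w(), conv->stride_h());
    }
  }
}

The same pattern applies to builtin_options_2_as<T>() for the StableHLO option tables defined earlier in this header.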
+template<> inline const tflite::Rfft2dOptions *Operator::builtin_options_as() const { + return builtin_options_as_Rfft2dOptions(); +} + +template<> inline const tflite::Conv3DOptions *Operator::builtin_options_as() const { + return builtin_options_as_Conv3DOptions(); +} + +template<> inline const tflite::HashtableOptions *Operator::builtin_options_as() const { + return builtin_options_as_HashtableOptions(); +} + +template<> inline const tflite::HashtableFindOptions *Operator::builtin_options_as() const { + return builtin_options_as_HashtableFindOptions(); +} + +template<> inline const tflite::HashtableImportOptions *Operator::builtin_options_as() const { + return builtin_options_as_HashtableImportOptions(); +} + +template<> inline const tflite::HashtableSizeOptions *Operator::builtin_options_as() const { + return builtin_options_as_HashtableSizeOptions(); +} + +template<> inline const tflite::VarHandleOptions *Operator::builtin_options_as() const { + return builtin_options_as_VarHandleOptions(); +} + +template<> inline const tflite::ReadVariableOptions *Operator::builtin_options_as() const { + return builtin_options_as_ReadVariableOptions(); +} + +template<> inline const tflite::AssignVariableOptions *Operator::builtin_options_as() const { + return builtin_options_as_AssignVariableOptions(); +} + +template<> inline const tflite::RandomOptions *Operator::builtin_options_as() const { + return builtin_options_as_RandomOptions(); +} + +template<> inline const tflite::BucketizeOptions *Operator::builtin_options_as() const { + return builtin_options_as_BucketizeOptions(); +} + +template<> inline const tflite::GeluOptions *Operator::builtin_options_as() const { + return builtin_options_as_GeluOptions(); +} + +template<> inline const tflite::DynamicUpdateSliceOptions *Operator::builtin_options_as() const { + return builtin_options_as_DynamicUpdateSliceOptions(); +} + +template<> inline const tflite::UnsortedSegmentProdOptions *Operator::builtin_options_as() const { + return builtin_options_as_UnsortedSegmentProdOptions(); +} + +template<> inline const tflite::UnsortedSegmentMaxOptions *Operator::builtin_options_as() const { + return builtin_options_as_UnsortedSegmentMaxOptions(); +} + +template<> inline const tflite::UnsortedSegmentMinOptions *Operator::builtin_options_as() const { + return builtin_options_as_UnsortedSegmentMinOptions(); +} + +template<> inline const tflite::UnsortedSegmentSumOptions *Operator::builtin_options_as() const { + return builtin_options_as_UnsortedSegmentSumOptions(); +} + +template<> inline const tflite::ATan2Options *Operator::builtin_options_as() const { + return builtin_options_as_ATan2Options(); +} + +template<> inline const tflite::SignOptions *Operator::builtin_options_as() const { + return builtin_options_as_SignOptions(); +} + +template<> inline const tflite::BitcastOptions *Operator::builtin_options_as() const { + return builtin_options_as_BitcastOptions(); +} + +template<> inline const tflite::BitwiseXorOptions *Operator::builtin_options_as() const { + return builtin_options_as_BitwiseXorOptions(); +} + +template<> inline const tflite::RightShiftOptions *Operator::builtin_options_as() const { + return builtin_options_as_RightShiftOptions(); +} + +template<> inline const tflite::StablehloConcatenateOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloConcatenateOptions(); +} + +template<> inline const tflite::StablehloBroadcastInDimOptions *Operator::builtin_options_2_as() const { + return 
builtin_options_2_as_StablehloBroadcastInDimOptions(); +} + +template<> inline const tflite::StablehloSliceOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloSliceOptions(); +} + +template<> inline const tflite::StablehloConvolutionOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloConvolutionOptions(); +} + +template<> inline const tflite::StablehloCustomCallOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloCustomCallOptions(); +} + +template<> inline const tflite::StablehloReduceOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloReduceOptions(); +} + +template<> inline const tflite::StablehloScatterOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloScatterOptions(); +} + +template<> inline const tflite::StablehloCompareOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloCompareOptions(); +} + +template<> inline const tflite::StablehloDynamicSliceOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloDynamicSliceOptions(); +} + +template<> inline const tflite::StablehloPadOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloPadOptions(); +} + +template<> inline const tflite::StablehloIotaOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloIotaOptions(); +} + +template<> inline const tflite::StablehloDotGeneralOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloDotGeneralOptions(); +} + +template<> inline const tflite::StablehloReduceWindowOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloReduceWindowOptions(); +} + +template<> inline const tflite::StablehloSortOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloSortOptions(); +} + +template<> inline const tflite::StablehloWhileOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloWhileOptions(); +} + +template<> inline const tflite::StablehloGatherOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloGatherOptions(); +} + +template<> inline const tflite::StablehloTransposeOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloTransposeOptions(); +} + +template<> inline const tflite::DilateOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_DilateOptions(); +} + +template<> inline const tflite::StablehloRngBitGeneratorOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloRngBitGeneratorOptions(); +} + +template<> inline const tflite::ReduceWindowOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_ReduceWindowOptions(); +} + +template<> inline const tflite::StableHLOCompositeOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StableHLOCompositeOptions(); +} + +template<> inline const tflite::StablehloShiftLeftOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloShiftLeftOptions(); +} + +template<> inline const tflite::StablehloCaseOptions *Operator::builtin_options_2_as() const { + return builtin_options_2_as_StablehloCaseOptions(); +} + +struct OperatorBuilder { + typedef Operator Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + 
::flatbuffers::uoffset_t start_; + void add_opcode_index(uint32_t opcode_index) { + fbb_.AddElement(Operator::VT_OPCODE_INDEX, opcode_index, 0); + } + void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector> inputs) { + fbb_.AddOffset(Operator::VT_INPUTS, inputs); + } + void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector> outputs) { + fbb_.AddOffset(Operator::VT_OUTPUTS, outputs); + } + void add_builtin_options_type(tflite::BuiltinOptions builtin_options_type) { + fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, static_cast(builtin_options_type), 0); + } + void add_builtin_options(::flatbuffers::Offset builtin_options) { + fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS, builtin_options); + } + void add_custom_options(::flatbuffers::Offset<::flatbuffers::Vector> custom_options) { + fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options); + } + void add_custom_options_format(tflite::CustomOptionsFormat custom_options_format) { + fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); + } + void add_mutating_variable_inputs(::flatbuffers::Offset<::flatbuffers::Vector> mutating_variable_inputs) { + fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs); + } + void add_intermediates(::flatbuffers::Offset<::flatbuffers::Vector> intermediates) { + fbb_.AddOffset(Operator::VT_INTERMEDIATES, intermediates); + } + void add_large_custom_options_offset(uint64_t large_custom_options_offset) { + fbb_.AddElement(Operator::VT_LARGE_CUSTOM_OPTIONS_OFFSET, large_custom_options_offset, 0); + } + void add_large_custom_options_size(uint64_t large_custom_options_size) { + fbb_.AddElement(Operator::VT_LARGE_CUSTOM_OPTIONS_SIZE, large_custom_options_size, 0); + } + void add_builtin_options_2_type(tflite::BuiltinOptions2 builtin_options_2_type) { + fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_2_TYPE, static_cast(builtin_options_2_type), 0); + } + void add_builtin_options_2(::flatbuffers::Offset builtin_options_2) { + fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS_2, builtin_options_2); + } + void add_debug_metadata_index(int32_t debug_metadata_index) { + fbb_.AddElement(Operator::VT_DEBUG_METADATA_INDEX, debug_metadata_index, -1); + } + explicit OperatorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateOperator( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> outputs = 0, + tflite::BuiltinOptions builtin_options_type = tflite::BuiltinOptions_NONE, + ::flatbuffers::Offset builtin_options = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> custom_options = 0, + tflite::CustomOptionsFormat custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS, + ::flatbuffers::Offset<::flatbuffers::Vector> mutating_variable_inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> intermediates = 0, + uint64_t large_custom_options_offset = 0, + uint64_t large_custom_options_size = 0, + tflite::BuiltinOptions2 builtin_options_2_type = tflite::BuiltinOptions2_NONE, + ::flatbuffers::Offset builtin_options_2 = 0, + int32_t debug_metadata_index = -1) { + OperatorBuilder builder_(_fbb); + builder_.add_large_custom_options_size(large_custom_options_size); + 
builder_.add_large_custom_options_offset(large_custom_options_offset); + builder_.add_debug_metadata_index(debug_metadata_index); + builder_.add_builtin_options_2(builtin_options_2); + builder_.add_intermediates(intermediates); + builder_.add_mutating_variable_inputs(mutating_variable_inputs); + builder_.add_custom_options(custom_options); + builder_.add_builtin_options(builtin_options); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_opcode_index(opcode_index); + builder_.add_builtin_options_2_type(builtin_options_2_type); + builder_.add_custom_options_format(custom_options_format); + builder_.add_builtin_options_type(builtin_options_type); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateOperatorDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, + const std::vector *inputs = nullptr, + const std::vector *outputs = nullptr, + tflite::BuiltinOptions builtin_options_type = tflite::BuiltinOptions_NONE, + ::flatbuffers::Offset builtin_options = 0, + const std::vector *custom_options = nullptr, + tflite::CustomOptionsFormat custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS, + const std::vector *mutating_variable_inputs = nullptr, + const std::vector *intermediates = nullptr, + uint64_t large_custom_options_offset = 0, + uint64_t large_custom_options_size = 0, + tflite::BuiltinOptions2 builtin_options_2_type = tflite::BuiltinOptions2_NONE, + ::flatbuffers::Offset builtin_options_2 = 0, + int32_t debug_metadata_index = -1) { + auto inputs__ = inputs ? _fbb.CreateVector(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector(*outputs) : 0; + auto custom_options__ = custom_options ? _fbb.CreateVector(*custom_options) : 0; + auto mutating_variable_inputs__ = mutating_variable_inputs ? _fbb.CreateVector(*mutating_variable_inputs) : 0; + auto intermediates__ = intermediates ? 
_fbb.CreateVector(*intermediates) : 0; + return tflite::CreateOperator( + _fbb, + opcode_index, + inputs__, + outputs__, + builtin_options_type, + builtin_options, + custom_options__, + custom_options_format, + mutating_variable_inputs__, + intermediates__, + large_custom_options_offset, + large_custom_options_size, + builtin_options_2_type, + builtin_options_2, + debug_metadata_index); +} + +::flatbuffers::Offset CreateOperator(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SubGraphT : public ::flatbuffers::NativeTable { + typedef SubGraph TableType; + std::vector> tensors{}; + std::vector inputs{}; + std::vector outputs{}; + std::vector> operators{}; + std::string name{}; + int32_t debug_metadata_index = -1; + SubGraphT() = default; + SubGraphT(const SubGraphT &o); + SubGraphT(SubGraphT&&) FLATBUFFERS_NOEXCEPT = default; + SubGraphT &operator=(SubGraphT o) FLATBUFFERS_NOEXCEPT; +}; + +struct SubGraph FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SubGraphT NativeTableType; + typedef SubGraphBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TENSORS = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_OPERATORS = 10, + VT_NAME = 12, + VT_DEBUG_METADATA_INDEX = 14 + }; + const ::flatbuffers::Vector<::flatbuffers::Offset> *tensors() const { + return GetPointer> *>(VT_TENSORS); + } + const ::flatbuffers::Vector *inputs() const { + return GetPointer *>(VT_INPUTS); + } + const ::flatbuffers::Vector *outputs() const { + return GetPointer *>(VT_OUTPUTS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *operators() const { + return GetPointer> *>(VT_OPERATORS); + } + const ::flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + int32_t debug_metadata_index() const { + return GetField(VT_DEBUG_METADATA_INDEX, -1); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TENSORS) && + verifier.VerifyVector(tensors()) && + verifier.VerifyVectorOfTables(tensors()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + VerifyOffset(verifier, VT_OPERATORS) && + verifier.VerifyVector(operators()) && + verifier.VerifyVectorOfTables(operators()) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_DEBUG_METADATA_INDEX, 4) && + verifier.EndTable(); + } + SubGraphT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SubGraphT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SubGraphBuilder { + typedef SubGraph Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_tensors(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> tensors) { + fbb_.AddOffset(SubGraph::VT_TENSORS, tensors); + } + void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector> inputs) { + fbb_.AddOffset(SubGraph::VT_INPUTS, inputs); + } + void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector> outputs) { + fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs); + } + void add_operators(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> operators) 
{ + fbb_.AddOffset(SubGraph::VT_OPERATORS, operators); + } + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(SubGraph::VT_NAME, name); + } + void add_debug_metadata_index(int32_t debug_metadata_index) { + fbb_.AddElement(SubGraph::VT_DEBUG_METADATA_INDEX, debug_metadata_index, -1); + } + explicit SubGraphBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSubGraph( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> tensors = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> outputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> operators = 0, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + int32_t debug_metadata_index = -1) { + SubGraphBuilder builder_(_fbb); + builder_.add_debug_metadata_index(debug_metadata_index); + builder_.add_name(name); + builder_.add_operators(operators); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_tensors(tensors); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateSubGraphDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset> *tensors = nullptr, + const std::vector *inputs = nullptr, + const std::vector *outputs = nullptr, + const std::vector<::flatbuffers::Offset> *operators = nullptr, + const char *name = nullptr, + int32_t debug_metadata_index = -1) { + auto tensors__ = tensors ? _fbb.CreateVector<::flatbuffers::Offset>(*tensors) : 0; + auto inputs__ = inputs ? _fbb.CreateVector(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector(*outputs) : 0; + auto operators__ = operators ? _fbb.CreateVector<::flatbuffers::Offset>(*operators) : 0; + auto name__ = name ? 
_fbb.CreateString(name) : 0; + return tflite::CreateSubGraph( + _fbb, + tensors__, + inputs__, + outputs__, + operators__, + name__, + debug_metadata_index); +} + +::flatbuffers::Offset CreateSubGraph(::flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct BufferT : public ::flatbuffers::NativeTable { + typedef Buffer TableType; + std::vector data{}; + uint64_t offset = 0; + uint64_t size = 0; +}; + +struct Buffer FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef BufferT NativeTableType; + typedef BufferBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4, + VT_OFFSET = 6, + VT_SIZE = 8 + }; + const ::flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + uint64_t offset() const { + return GetField(VT_OFFSET, 0); + } + uint64_t size() const { + return GetField(VT_SIZE, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + VerifyField(verifier, VT_OFFSET, 8) && + VerifyField(verifier, VT_SIZE, 8) && + verifier.EndTable(); + } + BufferT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BufferT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BufferBuilder { + typedef Buffer Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_data(::flatbuffers::Offset<::flatbuffers::Vector> data) { + fbb_.AddOffset(Buffer::VT_DATA, data); + } + void add_offset(uint64_t offset) { + fbb_.AddElement(Buffer::VT_OFFSET, offset, 0); + } + void add_size(uint64_t size) { + fbb_.AddElement(Buffer::VT_SIZE, size, 0); + } + explicit BufferBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateBuffer( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector> data = 0, + uint64_t offset = 0, + uint64_t size = 0) { + BufferBuilder builder_(_fbb); + builder_.add_size(size); + builder_.add_offset(offset); + builder_.add_data(data); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateBufferDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr, + uint64_t offset = 0, + uint64_t size = 0) { + if (data) { _fbb.ForceVectorAlignment(data->size(), sizeof(uint8_t), 16); } + auto data__ = data ? 
_fbb.CreateVector<uint8_t>(*data) : 0; + return tflite::CreateBuffer( + _fbb, + data__, + offset, + size); +} + +::flatbuffers::Offset<Buffer> CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct MetadataT : public ::flatbuffers::NativeTable { + typedef Metadata TableType; + std::string name{}; + uint32_t buffer = 0; +}; + +struct Metadata FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef MetadataT NativeTableType; + typedef MetadataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_BUFFER = 6 + }; + const ::flatbuffers::String *name() const { + return GetPointer<const ::flatbuffers::String *>(VT_NAME); + } + uint32_t buffer() const { + return GetField<uint32_t>(VT_BUFFER, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField<uint32_t>(verifier, VT_BUFFER, 4) && + verifier.EndTable(); + } + MetadataT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MetadataT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset<Metadata> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MetadataT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct MetadataBuilder { + typedef Metadata Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(Metadata::VT_NAME, name); + } + void add_buffer(uint32_t buffer) { + fbb_.AddElement<uint32_t>(Metadata::VT_BUFFER, buffer, 0); + } + explicit MetadataBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<Metadata> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<Metadata>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<Metadata> CreateMetadata( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + uint32_t buffer = 0) { + MetadataBuilder builder_(_fbb); + builder_.add_buffer(buffer); + builder_.add_name(name); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<Metadata> CreateMetadataDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + uint32_t buffer = 0) { + auto name__ = name ?
_fbb.CreateString(name) : 0; + return tflite::CreateMetadata( + _fbb, + name__, + buffer); +} + +::flatbuffers::Offset<Metadata> CreateMetadata(::flatbuffers::FlatBufferBuilder &_fbb, const MetadataT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TensorMapT : public ::flatbuffers::NativeTable { + typedef TensorMap TableType; + std::string name{}; + uint32_t tensor_index = 0; +}; + +struct TensorMap FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef TensorMapT NativeTableType; + typedef TensorMapBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_TENSOR_INDEX = 6 + }; + const ::flatbuffers::String *name() const { + return GetPointer<const ::flatbuffers::String *>(VT_NAME); + } + uint32_t tensor_index() const { + return GetField<uint32_t>(VT_TENSOR_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField<uint32_t>(verifier, VT_TENSOR_INDEX, 4) && + verifier.EndTable(); + } + TensorMapT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TensorMapT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset<TensorMap> Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TensorMapBuilder { + typedef TensorMap Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { + fbb_.AddOffset(TensorMap::VT_NAME, name); + } + void add_tensor_index(uint32_t tensor_index) { + fbb_.AddElement<uint32_t>(TensorMap::VT_TENSOR_INDEX, tensor_index, 0); + } + explicit TensorMapBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset<TensorMap> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset<TensorMap>(end); + return o; + } +}; + +inline ::flatbuffers::Offset<TensorMap> CreateTensorMap( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + uint32_t tensor_index = 0) { + TensorMapBuilder builder_(_fbb); + builder_.add_tensor_index(tensor_index); + builder_.add_name(name); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset<TensorMap> CreateTensorMapDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + uint32_t tensor_index = 0) { + auto name__ = name ?
_fbb.CreateString(name) : 0; + return tflite::CreateTensorMap( + _fbb, + name__, + tensor_index); +} + +::flatbuffers::Offset CreateTensorMap(::flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SignatureDefT : public ::flatbuffers::NativeTable { + typedef SignatureDef TableType; + std::vector> inputs{}; + std::vector> outputs{}; + std::string signature_key{}; + uint32_t subgraph_index = 0; + SignatureDefT() = default; + SignatureDefT(const SignatureDefT &o); + SignatureDefT(SignatureDefT&&) FLATBUFFERS_NOEXCEPT = default; + SignatureDefT &operator=(SignatureDefT o) FLATBUFFERS_NOEXCEPT; +}; + +struct SignatureDef FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef SignatureDefT NativeTableType; + typedef SignatureDefBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_INPUTS = 4, + VT_OUTPUTS = 6, + VT_SIGNATURE_KEY = 8, + VT_SUBGRAPH_INDEX = 12 + }; + const ::flatbuffers::Vector<::flatbuffers::Offset> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + const ::flatbuffers::String *signature_key() const { + return GetPointer(VT_SIGNATURE_KEY); + } + uint32_t subgraph_index() const { + return GetField(VT_SUBGRAPH_INDEX, 0); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfTables(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && + VerifyOffset(verifier, VT_SIGNATURE_KEY) && + verifier.VerifyString(signature_key()) && + VerifyField(verifier, VT_SUBGRAPH_INDEX, 4) && + verifier.EndTable(); + } + SignatureDefT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SignatureDefT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SignatureDefT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SignatureDefBuilder { + typedef SignatureDef Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_inputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> inputs) { + fbb_.AddOffset(SignatureDef::VT_INPUTS, inputs); + } + void add_outputs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> outputs) { + fbb_.AddOffset(SignatureDef::VT_OUTPUTS, outputs); + } + void add_signature_key(::flatbuffers::Offset<::flatbuffers::String> signature_key) { + fbb_.AddOffset(SignatureDef::VT_SIGNATURE_KEY, signature_key); + } + void add_subgraph_index(uint32_t subgraph_index) { + fbb_.AddElement(SignatureDef::VT_SUBGRAPH_INDEX, subgraph_index, 0); + } + explicit SignatureDefBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateSignatureDef( + ::flatbuffers::FlatBufferBuilder &_fbb, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> inputs = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> outputs = 0, + 
::flatbuffers::Offset<::flatbuffers::String> signature_key = 0, + uint32_t subgraph_index = 0) { + SignatureDefBuilder builder_(_fbb); + builder_.add_subgraph_index(subgraph_index); + builder_.add_signature_key(signature_key); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateSignatureDefDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + const std::vector<::flatbuffers::Offset> *inputs = nullptr, + const std::vector<::flatbuffers::Offset> *outputs = nullptr, + const char *signature_key = nullptr, + uint32_t subgraph_index = 0) { + auto inputs__ = inputs ? _fbb.CreateVector<::flatbuffers::Offset>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector<::flatbuffers::Offset>(*outputs) : 0; + auto signature_key__ = signature_key ? _fbb.CreateString(signature_key) : 0; + return tflite::CreateSignatureDef( + _fbb, + inputs__, + outputs__, + signature_key__, + subgraph_index); +} + +::flatbuffers::Offset CreateSignatureDef(::flatbuffers::FlatBufferBuilder &_fbb, const SignatureDefT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ModelT : public ::flatbuffers::NativeTable { + typedef Model TableType; + uint32_t version = 0; + std::vector> operator_codes{}; + std::vector> subgraphs{}; + std::string description{}; + std::vector> buffers{}; + std::vector metadata_buffer{}; + std::vector> metadata{}; + std::vector> signature_defs{}; + ModelT() = default; + ModelT(const ModelT &o); + ModelT(ModelT&&) FLATBUFFERS_NOEXCEPT = default; + ModelT &operator=(ModelT o) FLATBUFFERS_NOEXCEPT; +}; + +struct Model FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { + typedef ModelT NativeTableType; + typedef ModelBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_VERSION = 4, + VT_OPERATOR_CODES = 6, + VT_SUBGRAPHS = 8, + VT_DESCRIPTION = 10, + VT_BUFFERS = 12, + VT_METADATA_BUFFER = 14, + VT_METADATA = 16, + VT_SIGNATURE_DEFS = 18 + }; + uint32_t version() const { + return GetField(VT_VERSION, 0); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *operator_codes() const { + return GetPointer> *>(VT_OPERATOR_CODES); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *subgraphs() const { + return GetPointer> *>(VT_SUBGRAPHS); + } + const ::flatbuffers::String *description() const { + return GetPointer(VT_DESCRIPTION); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *buffers() const { + return GetPointer> *>(VT_BUFFERS); + } + const ::flatbuffers::Vector *metadata_buffer() const { + return GetPointer *>(VT_METADATA_BUFFER); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *metadata() const { + return GetPointer> *>(VT_METADATA); + } + const ::flatbuffers::Vector<::flatbuffers::Offset> *signature_defs() const { + return GetPointer> *>(VT_SIGNATURE_DEFS); + } + bool Verify(::flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VERSION, 4) && + VerifyOffset(verifier, VT_OPERATOR_CODES) && + verifier.VerifyVector(operator_codes()) && + verifier.VerifyVectorOfTables(operator_codes()) && + VerifyOffset(verifier, VT_SUBGRAPHS) && + verifier.VerifyVector(subgraphs()) && + verifier.VerifyVectorOfTables(subgraphs()) && + VerifyOffset(verifier, VT_DESCRIPTION) && + verifier.VerifyString(description()) && + VerifyOffset(verifier, VT_BUFFERS) && + verifier.VerifyVector(buffers()) && + verifier.VerifyVectorOfTables(buffers()) && + VerifyOffset(verifier, VT_METADATA_BUFFER) && + 
verifier.VerifyVector(metadata_buffer()) && + VerifyOffset(verifier, VT_METADATA) && + verifier.VerifyVector(metadata()) && + verifier.VerifyVectorOfTables(metadata()) && + VerifyOffset(verifier, VT_SIGNATURE_DEFS) && + verifier.VerifyVector(signature_defs()) && + verifier.VerifyVectorOfTables(signature_defs()) && + verifier.EndTable(); + } + ModelT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ModelT *_o, const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; + static ::flatbuffers::Offset Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ModelBuilder { + typedef Model Table; + ::flatbuffers::FlatBufferBuilder &fbb_; + ::flatbuffers::uoffset_t start_; + void add_version(uint32_t version) { + fbb_.AddElement(Model::VT_VERSION, version, 0); + } + void add_operator_codes(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> operator_codes) { + fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes); + } + void add_subgraphs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> subgraphs) { + fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs); + } + void add_description(::flatbuffers::Offset<::flatbuffers::String> description) { + fbb_.AddOffset(Model::VT_DESCRIPTION, description); + } + void add_buffers(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> buffers) { + fbb_.AddOffset(Model::VT_BUFFERS, buffers); + } + void add_metadata_buffer(::flatbuffers::Offset<::flatbuffers::Vector> metadata_buffer) { + fbb_.AddOffset(Model::VT_METADATA_BUFFER, metadata_buffer); + } + void add_metadata(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> metadata) { + fbb_.AddOffset(Model::VT_METADATA, metadata); + } + void add_signature_defs(::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs) { + fbb_.AddOffset(Model::VT_SIGNATURE_DEFS, signature_defs); + } + explicit ModelBuilder(::flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ::flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = ::flatbuffers::Offset(end); + return o; + } +}; + +inline ::flatbuffers::Offset CreateModel( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> operator_codes = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> subgraphs = 0, + ::flatbuffers::Offset<::flatbuffers::String> description = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> buffers = 0, + ::flatbuffers::Offset<::flatbuffers::Vector> metadata_buffer = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> metadata = 0, + ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> signature_defs = 0) { + ModelBuilder builder_(_fbb); + builder_.add_signature_defs(signature_defs); + builder_.add_metadata(metadata); + builder_.add_metadata_buffer(metadata_buffer); + builder_.add_buffers(buffers); + builder_.add_description(description); + builder_.add_subgraphs(subgraphs); + builder_.add_operator_codes(operator_codes); + builder_.add_version(version); + return builder_.Finish(); +} + +inline ::flatbuffers::Offset CreateModelDirect( + ::flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + const std::vector<::flatbuffers::Offset> *operator_codes = nullptr, + const 
std::vector<::flatbuffers::Offset> *subgraphs = nullptr, + const char *description = nullptr, + const std::vector<::flatbuffers::Offset> *buffers = nullptr, + const std::vector *metadata_buffer = nullptr, + const std::vector<::flatbuffers::Offset> *metadata = nullptr, + const std::vector<::flatbuffers::Offset> *signature_defs = nullptr) { + auto operator_codes__ = operator_codes ? _fbb.CreateVector<::flatbuffers::Offset>(*operator_codes) : 0; + auto subgraphs__ = subgraphs ? _fbb.CreateVector<::flatbuffers::Offset>(*subgraphs) : 0; + auto description__ = description ? _fbb.CreateString(description) : 0; + auto buffers__ = buffers ? _fbb.CreateVector<::flatbuffers::Offset>(*buffers) : 0; + auto metadata_buffer__ = metadata_buffer ? _fbb.CreateVector(*metadata_buffer) : 0; + auto metadata__ = metadata ? _fbb.CreateVector<::flatbuffers::Offset>(*metadata) : 0; + auto signature_defs__ = signature_defs ? _fbb.CreateVector<::flatbuffers::Offset>(*signature_defs) : 0; + return tflite::CreateModel( + _fbb, + version, + operator_codes__, + subgraphs__, + description__, + buffers__, + metadata_buffer__, + metadata__, + signature_defs__); +} + +::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); + +inline CustomQuantizationT *CustomQuantization::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new CustomQuantizationT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void CustomQuantization::UnPackTo(CustomQuantizationT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = custom(); if (_e) { _o->custom.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->custom.begin()); } } +} + +inline ::flatbuffers::Offset CustomQuantization::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateCustomQuantization(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateCustomQuantization(::flatbuffers::FlatBufferBuilder &_fbb, const CustomQuantizationT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const CustomQuantizationT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + _fbb.ForceVectorAlignment(_o->custom.size(), sizeof(uint8_t), 16); + auto _custom = _o->custom.size() ? 
_fbb.CreateVector(_o->custom) : 0; + return tflite::CreateCustomQuantization( + _fbb, + _custom); +} + +inline BlockwiseQuantizationT *BlockwiseQuantization::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BlockwiseQuantizationT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BlockwiseQuantization::UnPackTo(BlockwiseQuantizationT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = scales(); _o->scales = _e; } + { auto _e = zero_points(); _o->zero_points = _e; } + { auto _e = block_size(); _o->block_size = _e; } +} + +inline ::flatbuffers::Offset BlockwiseQuantization::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBlockwiseQuantization(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBlockwiseQuantization(::flatbuffers::FlatBufferBuilder &_fbb, const BlockwiseQuantizationT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BlockwiseQuantizationT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _scales = _o->scales; + auto _zero_points = _o->zero_points; + auto _block_size = _o->block_size; + return tflite::CreateBlockwiseQuantization( + _fbb, + _scales, + _zero_points, + _block_size); +} + +inline QuantizationParametersT *QuantizationParameters::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new QuantizationParametersT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void QuantizationParameters::UnPackTo(QuantizationParametersT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = min(); if (_e) { _o->min.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->min[_i] = _e->Get(_i); } } else { _o->min.resize(0); } } + { auto _e = max(); if (_e) { _o->max.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->max[_i] = _e->Get(_i); } } else { _o->max.resize(0); } } + { auto _e = scale(); if (_e) { _o->scale.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scale[_i] = _e->Get(_i); } } else { _o->scale.resize(0); } } + { auto _e = zero_point(); if (_e) { _o->zero_point.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zero_point[_i] = _e->Get(_i); } } else { _o->zero_point.resize(0); } } + { auto _e = details_type(); _o->details.type = _e; } + { auto _e = details(); if (_e) _o->details.value = tflite::QuantizationDetailsUnion::UnPack(_e, details_type(), _resolver); } + { auto _e = quantized_dimension(); _o->quantized_dimension = _e; } +} + +inline ::flatbuffers::Offset QuantizationParameters::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateQuantizationParameters(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateQuantizationParameters(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const QuantizationParametersT* __o; const 
::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _min = _o->min.size() ? _fbb.CreateVector(_o->min) : 0; + auto _max = _o->max.size() ? _fbb.CreateVector(_o->max) : 0; + auto _scale = _o->scale.size() ? _fbb.CreateVector(_o->scale) : 0; + auto _zero_point = _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0; + auto _details_type = _o->details.type; + auto _details = _o->details.Pack(_fbb); + auto _quantized_dimension = _o->quantized_dimension; + return tflite::CreateQuantizationParameters( + _fbb, + _min, + _max, + _scale, + _zero_point, + _details_type, + _details, + _quantized_dimension); +} + +inline Int32VectorT *Int32Vector::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new Int32VectorT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Int32Vector::UnPackTo(Int32VectorT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } else { _o->values.resize(0); } } +} + +inline ::flatbuffers::Offset Int32Vector::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateInt32Vector(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateInt32Vector(::flatbuffers::FlatBufferBuilder &_fbb, const Int32VectorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const Int32VectorT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values = _o->values.size() ? _fbb.CreateVector(_o->values) : 0; + return tflite::CreateInt32Vector( + _fbb, + _values); +} + +inline Uint16VectorT *Uint16Vector::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new Uint16VectorT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Uint16Vector::UnPackTo(Uint16VectorT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->values[_i] = _e->Get(_i); } } else { _o->values.resize(0); } } +} + +inline ::flatbuffers::Offset Uint16Vector::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateUint16Vector(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateUint16Vector(::flatbuffers::FlatBufferBuilder &_fbb, const Uint16VectorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const Uint16VectorT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + _fbb.ForceVectorAlignment(_o->values.size(), sizeof(uint16_t), 4); + auto _values = _o->values.size() ? 
_fbb.CreateVector(_o->values) : 0; + return tflite::CreateUint16Vector( + _fbb, + _values); +} + +inline Uint8VectorT *Uint8Vector::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new Uint8VectorT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Uint8Vector::UnPackTo(Uint8VectorT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values(); if (_e) { _o->values.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->values.begin()); } } +} + +inline ::flatbuffers::Offset Uint8Vector::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateUint8Vector(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateUint8Vector(::flatbuffers::FlatBufferBuilder &_fbb, const Uint8VectorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const Uint8VectorT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + _fbb.ForceVectorAlignment(_o->values.size(), sizeof(uint8_t), 4); + auto _values = _o->values.size() ? _fbb.CreateVector(_o->values) : 0; + return tflite::CreateUint8Vector( + _fbb, + _values); +} + +inline DimensionMetadataT *DimensionMetadata::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new DimensionMetadataT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void DimensionMetadata::UnPackTo(DimensionMetadataT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = format(); _o->format = _e; } + { auto _e = dense_size(); _o->dense_size = _e; } + { auto _e = array_segments_type(); _o->array_segments.type = _e; } + { auto _e = array_segments(); if (_e) _o->array_segments.value = tflite::SparseIndexVectorUnion::UnPack(_e, array_segments_type(), _resolver); } + { auto _e = array_indices_type(); _o->array_indices.type = _e; } + { auto _e = array_indices(); if (_e) _o->array_indices.value = tflite::SparseIndexVectorUnion::UnPack(_e, array_indices_type(), _resolver); } +} + +inline ::flatbuffers::Offset DimensionMetadata::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateDimensionMetadata(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateDimensionMetadata(::flatbuffers::FlatBufferBuilder &_fbb, const DimensionMetadataT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DimensionMetadataT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _format = _o->format; + auto _dense_size = _o->dense_size; + auto _array_segments_type = _o->array_segments.type; + auto _array_segments = _o->array_segments.Pack(_fbb); + auto _array_indices_type = _o->array_indices.type; + auto _array_indices = _o->array_indices.Pack(_fbb); + return tflite::CreateDimensionMetadata( + _fbb, + _format, + _dense_size, + _array_segments_type, + _array_segments, + _array_indices_type, + _array_indices); +} + +inline SparsityParametersT::SparsityParametersT(const SparsityParametersT &o) + : traversal_order(o.traversal_order), + block_map(o.block_map) { + 
dim_metadata.reserve(o.dim_metadata.size()); + for (const auto &dim_metadata_ : o.dim_metadata) { dim_metadata.emplace_back((dim_metadata_) ? new tflite::DimensionMetadataT(*dim_metadata_) : nullptr); } +} + +inline SparsityParametersT &SparsityParametersT::operator=(SparsityParametersT o) FLATBUFFERS_NOEXCEPT { + std::swap(traversal_order, o.traversal_order); + std::swap(block_map, o.block_map); + std::swap(dim_metadata, o.dim_metadata); + return *this; +} + +inline SparsityParametersT *SparsityParameters::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr<SparsityParametersT>(new SparsityParametersT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SparsityParameters::UnPackTo(SparsityParametersT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = traversal_order(); if (_e) { _o->traversal_order.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->traversal_order[_i] = _e->Get(_i); } } else { _o->traversal_order.resize(0); } } + { auto _e = block_map(); if (_e) { _o->block_map.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->block_map[_i] = _e->Get(_i); } } else { _o->block_map.resize(0); } } + { auto _e = dim_metadata(); if (_e) { _o->dim_metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->dim_metadata[_i]) { _e->Get(_i)->UnPackTo(_o->dim_metadata[_i].get(), _resolver); } else { _o->dim_metadata[_i] = std::unique_ptr<tflite::DimensionMetadataT>(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->dim_metadata.resize(0); } } +} + +inline ::flatbuffers::Offset<SparsityParameters> SparsityParameters::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityParametersT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparsityParameters(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset<SparsityParameters> CreateSparsityParameters(::flatbuffers::FlatBufferBuilder &_fbb, const SparsityParametersT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SparsityParametersT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _traversal_order = _o->traversal_order.size() ? _fbb.CreateVector(_o->traversal_order) : 0; + auto _block_map = _o->block_map.size() ? _fbb.CreateVector(_o->block_map) : 0; + auto _dim_metadata = _o->dim_metadata.size() ?
_fbb.CreateVector<::flatbuffers::Offset> (_o->dim_metadata.size(), [](size_t i, _VectorArgs *__va) { return CreateDimensionMetadata(*__va->__fbb, __va->__o->dim_metadata[i].get(), __va->__rehasher); }, &_va ) : 0; + return tflite::CreateSparsityParameters( + _fbb, + _traversal_order, + _block_map, + _dim_metadata); +} + +inline VariantSubTypeT *VariantSubType::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new VariantSubTypeT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void VariantSubType::UnPackTo(VariantSubTypeT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = shape(); if (_e) { _o->shape.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape[_i] = _e->Get(_i); } } else { _o->shape.resize(0); } } + { auto _e = type(); _o->type = _e; } + { auto _e = has_rank(); _o->has_rank = _e; } +} + +inline ::flatbuffers::Offset VariantSubType::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const VariantSubTypeT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateVariantSubType(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateVariantSubType(::flatbuffers::FlatBufferBuilder &_fbb, const VariantSubTypeT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const VariantSubTypeT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _shape = _o->shape.size() ? _fbb.CreateVector(_o->shape) : 0; + auto _type = _o->type; + auto _has_rank = _o->has_rank; + return tflite::CreateVariantSubType( + _fbb, + _shape, + _type, + _has_rank); +} + +inline TensorT::TensorT(const TensorT &o) + : shape(o.shape), + type(o.type), + buffer(o.buffer), + name(o.name), + quantization((o.quantization) ? new tflite::QuantizationParametersT(*o.quantization) : nullptr), + is_variable(o.is_variable), + sparsity((o.sparsity) ? new tflite::SparsityParametersT(*o.sparsity) : nullptr), + shape_signature(o.shape_signature), + has_rank(o.has_rank) { + variant_tensors.reserve(o.variant_tensors.size()); + for (const auto &variant_tensors_ : o.variant_tensors) { variant_tensors.emplace_back((variant_tensors_) ? 
new tflite::VariantSubTypeT(*variant_tensors_) : nullptr); } +} + +inline TensorT &TensorT::operator=(TensorT o) FLATBUFFERS_NOEXCEPT { + std::swap(shape, o.shape); + std::swap(type, o.type); + std::swap(buffer, o.buffer); + std::swap(name, o.name); + std::swap(quantization, o.quantization); + std::swap(is_variable, o.is_variable); + std::swap(sparsity, o.sparsity); + std::swap(shape_signature, o.shape_signature); + std::swap(has_rank, o.has_rank); + std::swap(variant_tensors, o.variant_tensors); + return *this; +} + +inline TensorT *Tensor::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new TensorT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Tensor::UnPackTo(TensorT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = shape(); if (_e) { _o->shape.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape[_i] = _e->Get(_i); } } else { _o->shape.resize(0); } } + { auto _e = type(); _o->type = _e; } + { auto _e = buffer(); _o->buffer = _e; } + { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = quantization(); if (_e) { if(_o->quantization) { _e->UnPackTo(_o->quantization.get(), _resolver); } else { _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); } } else if (_o->quantization) { _o->quantization.reset(); } } + { auto _e = is_variable(); _o->is_variable = _e; } + { auto _e = sparsity(); if (_e) { if(_o->sparsity) { _e->UnPackTo(_o->sparsity.get(), _resolver); } else { _o->sparsity = std::unique_ptr(_e->UnPack(_resolver)); } } else if (_o->sparsity) { _o->sparsity.reset(); } } + { auto _e = shape_signature(); if (_e) { _o->shape_signature.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape_signature[_i] = _e->Get(_i); } } else { _o->shape_signature.resize(0); } } + { auto _e = has_rank(); _o->has_rank = _e; } + { auto _e = variant_tensors(); if (_e) { _o->variant_tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->variant_tensors[_i]) { _e->Get(_i)->UnPackTo(_o->variant_tensors[_i].get(), _resolver); } else { _o->variant_tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->variant_tensors.resize(0); } } +} + +inline ::flatbuffers::Offset Tensor::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateTensor(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateTensor(::flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const TensorT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _shape = _o->shape.size() ? _fbb.CreateVector(_o->shape) : 0; + auto _type = _o->type; + auto _buffer = _o->buffer; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + auto _is_variable = _o->is_variable; + auto _sparsity = _o->sparsity ? CreateSparsityParameters(_fbb, _o->sparsity.get(), _rehasher) : 0; + auto _shape_signature = _o->shape_signature.size() ? 
_fbb.CreateVector(_o->shape_signature) : 0; + auto _has_rank = _o->has_rank; + auto _variant_tensors = _o->variant_tensors.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->variant_tensors.size(), [](size_t i, _VectorArgs *__va) { return CreateVariantSubType(*__va->__fbb, __va->__o->variant_tensors[i].get(), __va->__rehasher); }, &_va ) : 0; + return tflite::CreateTensor( + _fbb, + _shape, + _type, + _buffer, + _name, + _quantization, + _is_variable, + _sparsity, + _shape_signature, + _has_rank, + _variant_tensors); +} + +inline StablehloGatherOptionsT *StablehloGatherOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloGatherOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloGatherOptions::UnPackTo(StablehloGatherOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = offset_dims(); if (_e) { _o->offset_dims.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->offset_dims[_i] = _e->Get(_i); } } else { _o->offset_dims.resize(0); } } + { auto _e = collapsed_slice_dims(); if (_e) { _o->collapsed_slice_dims.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->collapsed_slice_dims[_i] = _e->Get(_i); } } else { _o->collapsed_slice_dims.resize(0); } } + { auto _e = start_index_map(); if (_e) { _o->start_index_map.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->start_index_map[_i] = _e->Get(_i); } } else { _o->start_index_map.resize(0); } } + { auto _e = index_vector_dim(); _o->index_vector_dim = _e; } + { auto _e = slice_sizes(); if (_e) { _o->slice_sizes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->slice_sizes[_i] = _e->Get(_i); } } else { _o->slice_sizes.resize(0); } } + { auto _e = indices_are_sorted(); _o->indices_are_sorted = _e; } +} + +inline ::flatbuffers::Offset StablehloGatherOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloGatherOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloGatherOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloGatherOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloGatherOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloGatherOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _offset_dims = _o->offset_dims.size() ? _fbb.CreateVector(_o->offset_dims) : 0; + auto _collapsed_slice_dims = _o->collapsed_slice_dims.size() ? _fbb.CreateVector(_o->collapsed_slice_dims) : 0; + auto _start_index_map = _o->start_index_map.size() ? _fbb.CreateVector(_o->start_index_map) : 0; + auto _index_vector_dim = _o->index_vector_dim; + auto _slice_sizes = _o->slice_sizes.size() ? 
_fbb.CreateVector(_o->slice_sizes) : 0; + auto _indices_are_sorted = _o->indices_are_sorted; + return tflite::CreateStablehloGatherOptions( + _fbb, + _offset_dims, + _collapsed_slice_dims, + _start_index_map, + _index_vector_dim, + _slice_sizes, + _indices_are_sorted); +} + +inline StablehloTransposeOptionsT *StablehloTransposeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloTransposeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloTransposeOptions::UnPackTo(StablehloTransposeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = permutation(); if (_e) { _o->permutation.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->permutation[_i] = _e->Get(_i); } } else { _o->permutation.resize(0); } } +} + +inline ::flatbuffers::Offset StablehloTransposeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloTransposeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloTransposeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloTransposeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloTransposeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloTransposeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _permutation = _o->permutation.size() ? _fbb.CreateVector(_o->permutation) : 0; + return tflite::CreateStablehloTransposeOptions( + _fbb, + _permutation); +} + +inline StablehloDotGeneralOptionsT *StablehloDotGeneralOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloDotGeneralOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloDotGeneralOptions::UnPackTo(StablehloDotGeneralOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = lhs_batching_dimensions(); if (_e) { _o->lhs_batching_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->lhs_batching_dimensions[_i] = _e->Get(_i); } } else { _o->lhs_batching_dimensions.resize(0); } } + { auto _e = rhs_batching_dimensions(); if (_e) { _o->rhs_batching_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->rhs_batching_dimensions[_i] = _e->Get(_i); } } else { _o->rhs_batching_dimensions.resize(0); } } + { auto _e = lhs_contracting_dimensions(); if (_e) { _o->lhs_contracting_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->lhs_contracting_dimensions[_i] = _e->Get(_i); } } else { _o->lhs_contracting_dimensions.resize(0); } } + { auto _e = rhs_contracting_dimensions(); if (_e) { _o->rhs_contracting_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->rhs_contracting_dimensions[_i] = _e->Get(_i); } } else { _o->rhs_contracting_dimensions.resize(0); } } + { auto _e = precision_config(); if (_e) { _o->precision_config.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->precision_config[_i] = static_cast(_e->Get(_i)); } } else { _o->precision_config.resize(0); } } +} + +inline 
::flatbuffers::Offset StablehloDotGeneralOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDotGeneralOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloDotGeneralOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloDotGeneralOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDotGeneralOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloDotGeneralOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _lhs_batching_dimensions = _o->lhs_batching_dimensions.size() ? _fbb.CreateVector(_o->lhs_batching_dimensions) : 0; + auto _rhs_batching_dimensions = _o->rhs_batching_dimensions.size() ? _fbb.CreateVector(_o->rhs_batching_dimensions) : 0; + auto _lhs_contracting_dimensions = _o->lhs_contracting_dimensions.size() ? _fbb.CreateVector(_o->lhs_contracting_dimensions) : 0; + auto _rhs_contracting_dimensions = _o->rhs_contracting_dimensions.size() ? _fbb.CreateVector(_o->rhs_contracting_dimensions) : 0; + auto _precision_config = _o->precision_config.size() ? _fbb.CreateVectorScalarCast(::flatbuffers::data(_o->precision_config), _o->precision_config.size()) : 0; + return tflite::CreateStablehloDotGeneralOptions( + _fbb, + _lhs_batching_dimensions, + _rhs_batching_dimensions, + _lhs_contracting_dimensions, + _rhs_contracting_dimensions, + _precision_config); +} + +inline StablehloReduceWindowOptionsT *StablehloReduceWindowOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloReduceWindowOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloReduceWindowOptions::UnPackTo(StablehloReduceWindowOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = window_dimensions(); if (_e) { _o->window_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->window_dimensions[_i] = _e->Get(_i); } } else { _o->window_dimensions.resize(0); } } + { auto _e = window_strides(); if (_e) { _o->window_strides.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->window_strides[_i] = _e->Get(_i); } } else { _o->window_strides.resize(0); } } + { auto _e = base_dilations(); if (_e) { _o->base_dilations.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->base_dilations[_i] = _e->Get(_i); } } else { _o->base_dilations.resize(0); } } + { auto _e = window_dilations(); if (_e) { _o->window_dilations.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->window_dilations[_i] = _e->Get(_i); } } else { _o->window_dilations.resize(0); } } + { auto _e = padding(); if (_e) { _o->padding.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->padding[_i] = _e->Get(_i); } } else { _o->padding.resize(0); } } + { auto _e = body_subgraph_index(); _o->body_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset StablehloReduceWindowOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceWindowOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloReduceWindowOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset 
CreateStablehloReduceWindowOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceWindowOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloReduceWindowOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _window_dimensions = _o->window_dimensions.size() ? _fbb.CreateVector(_o->window_dimensions) : 0; + auto _window_strides = _o->window_strides.size() ? _fbb.CreateVector(_o->window_strides) : 0; + auto _base_dilations = _o->base_dilations.size() ? _fbb.CreateVector(_o->base_dilations) : 0; + auto _window_dilations = _o->window_dilations.size() ? _fbb.CreateVector(_o->window_dilations) : 0; + auto _padding = _o->padding.size() ? _fbb.CreateVector(_o->padding) : 0; + auto _body_subgraph_index = _o->body_subgraph_index; + return tflite::CreateStablehloReduceWindowOptions( + _fbb, + _window_dimensions, + _window_strides, + _base_dilations, + _window_dilations, + _padding, + _body_subgraph_index); +} + +inline StablehloWhileOptionsT *StablehloWhileOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloWhileOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloWhileOptions::UnPackTo(StablehloWhileOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = cond_subgraph_index(); _o->cond_subgraph_index = _e; } + { auto _e = body_subgraph_index(); _o->body_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset StablehloWhileOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloWhileOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloWhileOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloWhileOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloWhileOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloWhileOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _cond_subgraph_index = _o->cond_subgraph_index; + auto _body_subgraph_index = _o->body_subgraph_index; + return tflite::CreateStablehloWhileOptions( + _fbb, + _cond_subgraph_index, + _body_subgraph_index); +} + +inline StablehloSortOptionsT *StablehloSortOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloSortOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloSortOptions::UnPackTo(StablehloSortOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = dimension(); _o->dimension = _e; } + { auto _e = is_stable(); _o->is_stable = _e; } + { auto _e = comparator_subgraph_index(); _o->comparator_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset StablehloSortOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSortOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloSortOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloSortOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSortOptionsT *_o, const 
::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloSortOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _dimension = _o->dimension; + auto _is_stable = _o->is_stable; + auto _comparator_subgraph_index = _o->comparator_subgraph_index; + return tflite::CreateStablehloSortOptions( + _fbb, + _dimension, + _is_stable, + _comparator_subgraph_index); +} + +inline StablehloConcatenateOptionsT *StablehloConcatenateOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloConcatenateOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloConcatenateOptions::UnPackTo(StablehloConcatenateOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = dimension(); _o->dimension = _e; } +} + +inline ::flatbuffers::Offset StablehloConcatenateOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConcatenateOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloConcatenateOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloConcatenateOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConcatenateOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloConcatenateOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _dimension = _o->dimension; + return tflite::CreateStablehloConcatenateOptions( + _fbb, + _dimension); +} + +inline StablehloBroadcastInDimOptionsT *StablehloBroadcastInDimOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloBroadcastInDimOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloBroadcastInDimOptions::UnPackTo(StablehloBroadcastInDimOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = broadcast_dimensions(); if (_e) { _o->broadcast_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->broadcast_dimensions[_i] = _e->Get(_i); } } else { _o->broadcast_dimensions.resize(0); } } +} + +inline ::flatbuffers::Offset StablehloBroadcastInDimOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloBroadcastInDimOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloBroadcastInDimOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloBroadcastInDimOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloBroadcastInDimOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloBroadcastInDimOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _broadcast_dimensions = _o->broadcast_dimensions.size() ? 
_fbb.CreateVector(_o->broadcast_dimensions) : 0; + return tflite::CreateStablehloBroadcastInDimOptions( + _fbb, + _broadcast_dimensions); +} + +inline StablehloCompareOptionsT *StablehloCompareOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloCompareOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloCompareOptions::UnPackTo(StablehloCompareOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = comparison_direction(); _o->comparison_direction = _e; } + { auto _e = compare_type(); _o->compare_type = _e; } +} + +inline ::flatbuffers::Offset StablehloCompareOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCompareOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloCompareOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloCompareOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCompareOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloCompareOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _comparison_direction = _o->comparison_direction; + auto _compare_type = _o->compare_type; + return tflite::CreateStablehloCompareOptions( + _fbb, + _comparison_direction, + _compare_type); +} + +inline StablehloDynamicSliceOptionsT *StablehloDynamicSliceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloDynamicSliceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloDynamicSliceOptions::UnPackTo(StablehloDynamicSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = slice_sizes(); if (_e) { _o->slice_sizes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->slice_sizes[_i] = _e->Get(_i); } } else { _o->slice_sizes.resize(0); } } +} + +inline ::flatbuffers::Offset StablehloDynamicSliceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDynamicSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloDynamicSliceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloDynamicSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloDynamicSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloDynamicSliceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _slice_sizes = _o->slice_sizes.size() ? 
_fbb.CreateVector(_o->slice_sizes) : 0; + return tflite::CreateStablehloDynamicSliceOptions( + _fbb, + _slice_sizes); +} + +inline StablehloPadOptionsT *StablehloPadOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloPadOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloPadOptions::UnPackTo(StablehloPadOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = edge_padding_low(); if (_e) { _o->edge_padding_low.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->edge_padding_low[_i] = _e->Get(_i); } } else { _o->edge_padding_low.resize(0); } } + { auto _e = edge_padding_high(); if (_e) { _o->edge_padding_high.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->edge_padding_high[_i] = _e->Get(_i); } } else { _o->edge_padding_high.resize(0); } } + { auto _e = interior_padding(); if (_e) { _o->interior_padding.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->interior_padding[_i] = _e->Get(_i); } } else { _o->interior_padding.resize(0); } } +} + +inline ::flatbuffers::Offset StablehloPadOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloPadOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloPadOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloPadOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloPadOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloPadOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _edge_padding_low = _o->edge_padding_low.size() ? _fbb.CreateVector(_o->edge_padding_low) : 0; + auto _edge_padding_high = _o->edge_padding_high.size() ? _fbb.CreateVector(_o->edge_padding_high) : 0; + auto _interior_padding = _o->interior_padding.size() ? 
_fbb.CreateVector(_o->interior_padding) : 0; + return tflite::CreateStablehloPadOptions( + _fbb, + _edge_padding_low, + _edge_padding_high, + _interior_padding); +} + +inline StablehloIotaOptionsT *StablehloIotaOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloIotaOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloIotaOptions::UnPackTo(StablehloIotaOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = iota_dimension(); _o->iota_dimension = _e; } +} + +inline ::flatbuffers::Offset StablehloIotaOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloIotaOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloIotaOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloIotaOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloIotaOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloIotaOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _iota_dimension = _o->iota_dimension; + return tflite::CreateStablehloIotaOptions( + _fbb, + _iota_dimension); +} + +inline StablehloCustomCallOptionsT *StablehloCustomCallOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloCustomCallOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloCustomCallOptions::UnPackTo(StablehloCustomCallOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = call_target_name(); if (_e) _o->call_target_name = _e->str(); } + { auto _e = has_side_effect(); _o->has_side_effect = _e; } + { auto _e = backend_config(); if (_e) _o->backend_config = _e->str(); } + { auto _e = api_version(); _o->api_version = _e; } + { auto _e = called_computations(); if (_e) { _o->called_computations.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->called_computations[_i] = _e->Get(_i); } } else { _o->called_computations.resize(0); } } + { auto _e = custom_attributes(); if (_e) { _o->custom_attributes.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->custom_attributes.begin()); } } +} + +inline ::flatbuffers::Offset StablehloCustomCallOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCustomCallOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloCustomCallOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloCustomCallOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCustomCallOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloCustomCallOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _call_target_name = _o->call_target_name.empty() ? 0 : _fbb.CreateString(_o->call_target_name); + auto _has_side_effect = _o->has_side_effect; + auto _backend_config = _o->backend_config.empty() ? 
0 : _fbb.CreateString(_o->backend_config); + auto _api_version = _o->api_version; + auto _called_computations = _o->called_computations.size() ? _fbb.CreateVector(_o->called_computations) : 0; + auto _custom_attributes = _o->custom_attributes.size() ? _fbb.CreateVector(_o->custom_attributes) : 0; + return tflite::CreateStablehloCustomCallOptions( + _fbb, + _call_target_name, + _has_side_effect, + _backend_config, + _api_version, + _called_computations, + _custom_attributes); +} + +inline StablehloReduceOptionsT *StablehloReduceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloReduceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloReduceOptions::UnPackTo(StablehloReduceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = dimensions(); if (_e) { _o->dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->dimensions[_i] = _e->Get(_i); } } else { _o->dimensions.resize(0); } } + { auto _e = body_subgraph_index(); _o->body_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset StablehloReduceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloReduceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloReduceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloReduceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloReduceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _dimensions = _o->dimensions.size() ? 
_fbb.CreateVector(_o->dimensions) : 0; + auto _body_subgraph_index = _o->body_subgraph_index; + return tflite::CreateStablehloReduceOptions( + _fbb, + _dimensions, + _body_subgraph_index); +} + +inline StablehloSliceOptionsT *StablehloSliceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloSliceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloSliceOptions::UnPackTo(StablehloSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = start_indices(); if (_e) { _o->start_indices.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->start_indices[_i] = _e->Get(_i); } } else { _o->start_indices.resize(0); } } + { auto _e = limit_indices(); if (_e) { _o->limit_indices.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->limit_indices[_i] = _e->Get(_i); } } else { _o->limit_indices.resize(0); } } + { auto _e = strides(); if (_e) { _o->strides.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->strides[_i] = _e->Get(_i); } } else { _o->strides.resize(0); } } +} + +inline ::flatbuffers::Offset StablehloSliceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloSliceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloSliceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _start_indices = _o->start_indices.size() ? _fbb.CreateVector(_o->start_indices) : 0; + auto _limit_indices = _o->limit_indices.size() ? _fbb.CreateVector(_o->limit_indices) : 0; + auto _strides = _o->strides.size() ? 
_fbb.CreateVector(_o->strides) : 0; + return tflite::CreateStablehloSliceOptions( + _fbb, + _start_indices, + _limit_indices, + _strides); +} + +inline StablehloConvolutionOptionsT *StablehloConvolutionOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloConvolutionOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloConvolutionOptions::UnPackTo(StablehloConvolutionOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = window_strides(); if (_e) { _o->window_strides.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->window_strides[_i] = _e->Get(_i); } } else { _o->window_strides.resize(0); } } + { auto _e = padding(); if (_e) { _o->padding.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->padding[_i] = _e->Get(_i); } } else { _o->padding.resize(0); } } + { auto _e = lhs_dilation(); if (_e) { _o->lhs_dilation.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->lhs_dilation[_i] = _e->Get(_i); } } else { _o->lhs_dilation.resize(0); } } + { auto _e = rhs_dilation(); if (_e) { _o->rhs_dilation.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->rhs_dilation[_i] = _e->Get(_i); } } else { _o->rhs_dilation.resize(0); } } + { auto _e = window_reversal(); if (_e) { _o->window_reversal.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->window_reversal[_i] = _e->Get(_i) != 0; } } else { _o->window_reversal.resize(0); } } + { auto _e = input_batch_dimension(); _o->input_batch_dimension = _e; } + { auto _e = input_feature_dimension(); _o->input_feature_dimension = _e; } + { auto _e = input_spatial_dimensions(); if (_e) { _o->input_spatial_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->input_spatial_dimensions[_i] = _e->Get(_i); } } else { _o->input_spatial_dimensions.resize(0); } } + { auto _e = kernel_input_feature_dimension(); _o->kernel_input_feature_dimension = _e; } + { auto _e = kernel_output_feature_dimension(); _o->kernel_output_feature_dimension = _e; } + { auto _e = kernel_spatial_dimensions(); if (_e) { _o->kernel_spatial_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->kernel_spatial_dimensions[_i] = _e->Get(_i); } } else { _o->kernel_spatial_dimensions.resize(0); } } + { auto _e = output_batch_dimension(); _o->output_batch_dimension = _e; } + { auto _e = output_feature_dimension(); _o->output_feature_dimension = _e; } + { auto _e = output_spatial_dimensions(); if (_e) { _o->output_spatial_dimensions.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->output_spatial_dimensions[_i] = _e->Get(_i); } } else { _o->output_spatial_dimensions.resize(0); } } + { auto _e = feature_group_count(); _o->feature_group_count = _e; } + { auto _e = batch_group_count(); _o->batch_group_count = _e; } + { auto _e = precision_config(); if (_e) { _o->precision_config.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->precision_config[_i] = static_cast(_e->Get(_i)); } } else { _o->precision_config.resize(0); } } +} + +inline ::flatbuffers::Offset StablehloConvolutionOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConvolutionOptionsT* _o, const ::flatbuffers::rehasher_function_t 
*_rehasher) { + return CreateStablehloConvolutionOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloConvolutionOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloConvolutionOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloConvolutionOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _window_strides = _o->window_strides.size() ? _fbb.CreateVector(_o->window_strides) : 0; + auto _padding = _o->padding.size() ? _fbb.CreateVector(_o->padding) : 0; + auto _lhs_dilation = _o->lhs_dilation.size() ? _fbb.CreateVector(_o->lhs_dilation) : 0; + auto _rhs_dilation = _o->rhs_dilation.size() ? _fbb.CreateVector(_o->rhs_dilation) : 0; + auto _window_reversal = _o->window_reversal.size() ? _fbb.CreateVector(_o->window_reversal) : 0; + auto _input_batch_dimension = _o->input_batch_dimension; + auto _input_feature_dimension = _o->input_feature_dimension; + auto _input_spatial_dimensions = _o->input_spatial_dimensions.size() ? _fbb.CreateVector(_o->input_spatial_dimensions) : 0; + auto _kernel_input_feature_dimension = _o->kernel_input_feature_dimension; + auto _kernel_output_feature_dimension = _o->kernel_output_feature_dimension; + auto _kernel_spatial_dimensions = _o->kernel_spatial_dimensions.size() ? _fbb.CreateVector(_o->kernel_spatial_dimensions) : 0; + auto _output_batch_dimension = _o->output_batch_dimension; + auto _output_feature_dimension = _o->output_feature_dimension; + auto _output_spatial_dimensions = _o->output_spatial_dimensions.size() ? _fbb.CreateVector(_o->output_spatial_dimensions) : 0; + auto _feature_group_count = _o->feature_group_count; + auto _batch_group_count = _o->batch_group_count; + auto _precision_config = _o->precision_config.size() ? 
_fbb.CreateVectorScalarCast(::flatbuffers::data(_o->precision_config), _o->precision_config.size()) : 0; + return tflite::CreateStablehloConvolutionOptions( + _fbb, + _window_strides, + _padding, + _lhs_dilation, + _rhs_dilation, + _window_reversal, + _input_batch_dimension, + _input_feature_dimension, + _input_spatial_dimensions, + _kernel_input_feature_dimension, + _kernel_output_feature_dimension, + _kernel_spatial_dimensions, + _output_batch_dimension, + _output_feature_dimension, + _output_spatial_dimensions, + _feature_group_count, + _batch_group_count, + _precision_config); +} + +inline StablehloScatterOptionsT *StablehloScatterOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloScatterOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloScatterOptions::UnPackTo(StablehloScatterOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = indices_are_sorted(); _o->indices_are_sorted = _e; } + { auto _e = update_window_dims(); if (_e) { _o->update_window_dims.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->update_window_dims[_i] = _e->Get(_i); } } else { _o->update_window_dims.resize(0); } } + { auto _e = inserted_window_dims(); if (_e) { _o->inserted_window_dims.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inserted_window_dims[_i] = _e->Get(_i); } } else { _o->inserted_window_dims.resize(0); } } + { auto _e = scatter_dims_to_operand_dims(); if (_e) { _o->scatter_dims_to_operand_dims.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scatter_dims_to_operand_dims[_i] = _e->Get(_i); } } else { _o->scatter_dims_to_operand_dims.resize(0); } } + { auto _e = index_vector_dim(); _o->index_vector_dim = _e; } + { auto _e = unique_indices(); _o->unique_indices = _e; } + { auto _e = update_computation_subgraph_index(); _o->update_computation_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset StablehloScatterOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloScatterOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloScatterOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloScatterOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloScatterOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloScatterOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _indices_are_sorted = _o->indices_are_sorted; + auto _update_window_dims = _o->update_window_dims.size() ? _fbb.CreateVector(_o->update_window_dims) : 0; + auto _inserted_window_dims = _o->inserted_window_dims.size() ? _fbb.CreateVector(_o->inserted_window_dims) : 0; + auto _scatter_dims_to_operand_dims = _o->scatter_dims_to_operand_dims.size() ? 
_fbb.CreateVector(_o->scatter_dims_to_operand_dims) : 0; + auto _index_vector_dim = _o->index_vector_dim; + auto _unique_indices = _o->unique_indices; + auto _update_computation_subgraph_index = _o->update_computation_subgraph_index; + return tflite::CreateStablehloScatterOptions( + _fbb, + _indices_are_sorted, + _update_window_dims, + _inserted_window_dims, + _scatter_dims_to_operand_dims, + _index_vector_dim, + _unique_indices, + _update_computation_subgraph_index); +} + +inline StablehloCaseOptionsT *StablehloCaseOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloCaseOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloCaseOptions::UnPackTo(StablehloCaseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = branch_subgraph_indices(); if (_e) { _o->branch_subgraph_indices.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->branch_subgraph_indices[_i] = _e->Get(_i); } } else { _o->branch_subgraph_indices.resize(0); } } +} + +inline ::flatbuffers::Offset StablehloCaseOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloCaseOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloCaseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloCaseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloCaseOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _branch_subgraph_indices = _o->branch_subgraph_indices.size() ? 
_fbb.CreateVector(_o->branch_subgraph_indices) : 0;
+  return tflite::CreateStablehloCaseOptions(
+      _fbb,
+      _branch_subgraph_indices);
+}
+
+inline StablehloRngBitGeneratorOptionsT *StablehloRngBitGeneratorOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<StablehloRngBitGeneratorOptionsT>(new StablehloRngBitGeneratorOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void StablehloRngBitGeneratorOptions::UnPackTo(StablehloRngBitGeneratorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = algorithm(); _o->algorithm = _e; }
+}
+
+inline ::flatbuffers::Offset<StablehloRngBitGeneratorOptions> StablehloRngBitGeneratorOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloRngBitGeneratorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateStablehloRngBitGeneratorOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<StablehloRngBitGeneratorOptions> CreateStablehloRngBitGeneratorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloRngBitGeneratorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloRngBitGeneratorOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _algorithm = _o->algorithm;
+  return tflite::CreateStablehloRngBitGeneratorOptions(
+      _fbb,
+      _algorithm);
+}
+
+inline Conv2DOptionsT *Conv2DOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<Conv2DOptionsT>(new Conv2DOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void Conv2DOptions::UnPackTo(Conv2DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = padding(); _o->padding = _e; }
+  { auto _e = stride_w(); _o->stride_w = _e; }
+  { auto _e = stride_h(); _o->stride_h = _e; }
+  { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }
+  { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; }
+  { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; }
+  { auto _e = quantized_bias_type(); _o->quantized_bias_type = _e; }
+}
+
+inline ::flatbuffers::Offset<Conv2DOptions> Conv2DOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateConv2DOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const Conv2DOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _padding = _o->padding;
+  auto _stride_w = _o->stride_w;
+  auto _stride_h = _o->stride_h;
+  auto _fused_activation_function = _o->fused_activation_function;
+  auto _dilation_w_factor = _o->dilation_w_factor;
+  auto _dilation_h_factor = _o->dilation_h_factor;
+  auto _quantized_bias_type = _o->quantized_bias_type;
+  return tflite::CreateConv2DOptions(
+      _fbb,
+      _padding,
+      _stride_w,
+      _stride_h,
+      _fused_activation_function,
+      _dilation_w_factor,
+      _dilation_h_factor,
+      _quantized_bias_type);
+}
+
+inline Conv3DOptionsT *Conv3DOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<Conv3DOptionsT>(new 
Conv3DOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Conv3DOptions::UnPackTo(Conv3DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; } + { auto _e = stride_d(); _o->stride_d = _e; } + { auto _e = stride_w(); _o->stride_w = _e; } + { auto _e = stride_h(); _o->stride_h = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = dilation_d_factor(); _o->dilation_d_factor = _e; } + { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; } + { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; } +} + +inline ::flatbuffers::Offset Conv3DOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateConv3DOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateConv3DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Conv3DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const Conv3DOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_d = _o->stride_d; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _fused_activation_function = _o->fused_activation_function; + auto _dilation_d_factor = _o->dilation_d_factor; + auto _dilation_w_factor = _o->dilation_w_factor; + auto _dilation_h_factor = _o->dilation_h_factor; + return tflite::CreateConv3DOptions( + _fbb, + _padding, + _stride_d, + _stride_w, + _stride_h, + _fused_activation_function, + _dilation_d_factor, + _dilation_w_factor, + _dilation_h_factor); +} + +inline Pool2DOptionsT *Pool2DOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new Pool2DOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Pool2DOptions::UnPackTo(Pool2DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; } + { auto _e = stride_w(); _o->stride_w = _e; } + { auto _e = stride_h(); _o->stride_h = _e; } + { auto _e = filter_width(); _o->filter_width = _e; } + { auto _e = filter_height(); _o->filter_height = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } +} + +inline ::flatbuffers::Offset Pool2DOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreatePool2DOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreatePool2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const Pool2DOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _filter_width = _o->filter_width; + auto _filter_height = _o->filter_height; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreatePool2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _filter_width, + _filter_height, + 
_fused_activation_function); +} + +inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new DepthwiseConv2DOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; } + { auto _e = stride_w(); _o->stride_w = _e; } + { auto _e = stride_h(); _o->stride_h = _e; } + { auto _e = depth_multiplier(); _o->depth_multiplier = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; } + { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; } +} + +inline ::flatbuffers::Offset DepthwiseConv2DOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateDepthwiseConv2DOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateDepthwiseConv2DOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DepthwiseConv2DOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _depth_multiplier = _o->depth_multiplier; + auto _fused_activation_function = _o->fused_activation_function; + auto _dilation_w_factor = _o->dilation_w_factor; + auto _dilation_h_factor = _o->dilation_h_factor; + return tflite::CreateDepthwiseConv2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _depth_multiplier, + _fused_activation_function, + _dilation_w_factor, + _dilation_h_factor); +} + +inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ConcatEmbeddingsOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ConcatEmbeddingsOptions::UnPackTo(ConcatEmbeddingsOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num_channels(); _o->num_channels = _e; } + { auto _e = num_columns_per_channel(); if (_e) { _o->num_columns_per_channel.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->num_columns_per_channel[_i] = _e->Get(_i); } } else { _o->num_columns_per_channel.resize(0); } } + { auto _e = embedding_dim_per_channel(); if (_e) { _o->embedding_dim_per_channel.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_dim_per_channel[_i] = _e->Get(_i); } } else { _o->embedding_dim_per_channel.resize(0); } } +} + +inline ::flatbuffers::Offset ConcatEmbeddingsOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateConcatEmbeddingsOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateConcatEmbeddingsOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct 
_VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ConcatEmbeddingsOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num_channels = _o->num_channels; + auto _num_columns_per_channel = _o->num_columns_per_channel.size() ? _fbb.CreateVector(_o->num_columns_per_channel) : 0; + auto _embedding_dim_per_channel = _o->embedding_dim_per_channel.size() ? _fbb.CreateVector(_o->embedding_dim_per_channel) : 0; + return tflite::CreateConcatEmbeddingsOptions( + _fbb, + _num_channels, + _num_columns_per_channel, + _embedding_dim_per_channel); +} + +inline LSHProjectionOptionsT *LSHProjectionOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LSHProjectionOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LSHProjectionOptions::UnPackTo(LSHProjectionOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = type(); _o->type = _e; } +} + +inline ::flatbuffers::Offset LSHProjectionOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLSHProjectionOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLSHProjectionOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LSHProjectionOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _type = _o->type; + return tflite::CreateLSHProjectionOptions( + _fbb, + _type); +} + +inline SVDFOptionsT *SVDFOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SVDFOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = rank(); _o->rank = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } +} + +inline ::flatbuffers::Offset SVDFOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSVDFOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSVDFOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _rank = _o->rank; + auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + return tflite::CreateSVDFOptions( + _fbb, + _rank, + _fused_activation_function, + _asymmetric_quantize_inputs); +} + +inline RNNOptionsT *RNNOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new RNNOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const 
::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } +} + +inline ::flatbuffers::Offset RNNOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateRNNOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + return tflite::CreateRNNOptions( + _fbb, + _fused_activation_function, + _asymmetric_quantize_inputs); +} + +inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SequenceRNNOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = time_major(); _o->time_major = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } +} + +inline ::flatbuffers::Offset SequenceRNNOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSequenceRNNOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSequenceRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _time_major = _o->time_major; + auto _fused_activation_function = _o->fused_activation_function; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + return tflite::CreateSequenceRNNOptions( + _fbb, + _time_major, + _fused_activation_function, + _asymmetric_quantize_inputs); +} + +inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BidirectionalSequenceRNNOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = time_major(); _o->time_major = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = merge_outputs(); _o->merge_outputs = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } +} + +inline ::flatbuffers::Offset BidirectionalSequenceRNNOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, 
const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBidirectionalSequenceRNNOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBidirectionalSequenceRNNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceRNNOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _time_major = _o->time_major; + auto _fused_activation_function = _o->fused_activation_function; + auto _merge_outputs = _o->merge_outputs; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + return tflite::CreateBidirectionalSequenceRNNOptions( + _fbb, + _time_major, + _fused_activation_function, + _merge_outputs, + _asymmetric_quantize_inputs); +} + +inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new FullyConnectedOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = weights_format(); _o->weights_format = _e; } + { auto _e = keep_num_dims(); _o->keep_num_dims = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } + { auto _e = quantized_bias_type(); _o->quantized_bias_type = _e; } +} + +inline ::flatbuffers::Offset FullyConnectedOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateFullyConnectedOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateFullyConnectedOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _weights_format = _o->weights_format; + auto _keep_num_dims = _o->keep_num_dims; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + auto _quantized_bias_type = _o->quantized_bias_type; + return tflite::CreateFullyConnectedOptions( + _fbb, + _fused_activation_function, + _weights_format, + _keep_num_dims, + _asymmetric_quantize_inputs, + _quantized_bias_type); +} + +inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SoftmaxOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SoftmaxOptions::UnPackTo(SoftmaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = beta(); _o->beta = _e; } +} + +inline ::flatbuffers::Offset SoftmaxOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSoftmaxOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset 
CreateSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SoftmaxOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _beta = _o->beta; + return tflite::CreateSoftmaxOptions( + _fbb, + _beta); +} + +inline ConcatenationOptionsT *ConcatenationOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ConcatenationOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ConcatenationOptions::UnPackTo(ConcatenationOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } +} + +inline ::flatbuffers::Offset ConcatenationOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateConcatenationOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateConcatenationOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ConcatenationOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateConcatenationOptions( + _fbb, + _axis, + _fused_activation_function); +} + +inline AddOptionsT *AddOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new AddOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void AddOptions::UnPackTo(AddOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = pot_scale_int16(); _o->pot_scale_int16 = _e; } +} + +inline ::flatbuffers::Offset AddOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateAddOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateAddOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const AddOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _pot_scale_int16 = _o->pot_scale_int16; + return tflite::CreateAddOptions( + _fbb, + _fused_activation_function, + _pot_scale_int16); +} + +inline MulOptionsT *MulOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new MulOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void MulOptions::UnPackTo(MulOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); 
_o->fused_activation_function = _e; } +} + +inline ::flatbuffers::Offset MulOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateMulOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateMulOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const MulOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateMulOptions( + _fbb, + _fused_activation_function); +} + +inline L2NormOptionsT *L2NormOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new L2NormOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void L2NormOptions::UnPackTo(L2NormOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } +} + +inline ::flatbuffers::Offset L2NormOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateL2NormOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateL2NormOptions(::flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const L2NormOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateL2NormOptions( + _fbb, + _fused_activation_function); +} + +inline LocalResponseNormalizationOptionsT *LocalResponseNormalizationOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LocalResponseNormalizationOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LocalResponseNormalizationOptions::UnPackTo(LocalResponseNormalizationOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = radius(); _o->radius = _e; } + { auto _e = bias(); _o->bias = _e; } + { auto _e = alpha(); _o->alpha = _e; } + { auto _e = beta(); _o->beta = _e; } +} + +inline ::flatbuffers::Offset LocalResponseNormalizationOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLocalResponseNormalizationOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLocalResponseNormalizationOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LocalResponseNormalizationOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _radius = _o->radius; + auto _bias = _o->bias; + auto _alpha = _o->alpha; + auto _beta = _o->beta; + return tflite::CreateLocalResponseNormalizationOptions( + _fbb, 
+ _radius, + _bias, + _alpha, + _beta); +} + +inline LSTMOptionsT *LSTMOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LSTMOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = cell_clip(); _o->cell_clip = _e; } + { auto _e = proj_clip(); _o->proj_clip = _e; } + { auto _e = kernel_type(); _o->kernel_type = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } +} + +inline ::flatbuffers::Offset LSTMOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLSTMOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LSTMOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _cell_clip = _o->cell_clip; + auto _proj_clip = _o->proj_clip; + auto _kernel_type = _o->kernel_type; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + return tflite::CreateLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip, + _kernel_type, + _asymmetric_quantize_inputs); +} + +inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new UnidirectionalSequenceLSTMOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = cell_clip(); _o->cell_clip = _e; } + { auto _e = proj_clip(); _o->proj_clip = _e; } + { auto _e = time_major(); _o->time_major = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } + { auto _e = diagonal_recurrent_tensors(); _o->diagonal_recurrent_tensors = _e; } +} + +inline ::flatbuffers::Offset UnidirectionalSequenceLSTMOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateUnidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateUnidirectionalSequenceLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const UnidirectionalSequenceLSTMOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _cell_clip = _o->cell_clip; + auto _proj_clip = _o->proj_clip; + auto _time_major = _o->time_major; + auto _asymmetric_quantize_inputs = 
_o->asymmetric_quantize_inputs; + auto _diagonal_recurrent_tensors = _o->diagonal_recurrent_tensors; + return tflite::CreateUnidirectionalSequenceLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip, + _time_major, + _asymmetric_quantize_inputs, + _diagonal_recurrent_tensors); +} + +inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BidirectionalSequenceLSTMOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = cell_clip(); _o->cell_clip = _e; } + { auto _e = proj_clip(); _o->proj_clip = _e; } + { auto _e = merge_outputs(); _o->merge_outputs = _e; } + { auto _e = time_major(); _o->time_major = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } +} + +inline ::flatbuffers::Offset BidirectionalSequenceLSTMOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBidirectionalSequenceLSTMOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceLSTMOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _cell_clip = _o->cell_clip; + auto _proj_clip = _o->proj_clip; + auto _merge_outputs = _o->merge_outputs; + auto _time_major = _o->time_major; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + return tflite::CreateBidirectionalSequenceLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip, + _merge_outputs, + _time_major, + _asymmetric_quantize_inputs); +} + +inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ResizeBilinearOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ResizeBilinearOptions::UnPackTo(ResizeBilinearOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = align_corners(); _o->align_corners = _e; } + { auto _e = half_pixel_centers(); _o->half_pixel_centers = _e; } +} + +inline ::flatbuffers::Offset ResizeBilinearOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateResizeBilinearOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateResizeBilinearOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ResizeBilinearOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; 
(void)_va; + auto _align_corners = _o->align_corners; + auto _half_pixel_centers = _o->half_pixel_centers; + return tflite::CreateResizeBilinearOptions( + _fbb, + _align_corners, + _half_pixel_centers); +} + +inline ResizeNearestNeighborOptionsT *ResizeNearestNeighborOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ResizeNearestNeighborOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ResizeNearestNeighborOptions::UnPackTo(ResizeNearestNeighborOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = align_corners(); _o->align_corners = _e; } + { auto _e = half_pixel_centers(); _o->half_pixel_centers = _e; } +} + +inline ::flatbuffers::Offset ResizeNearestNeighborOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeNearestNeighborOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateResizeNearestNeighborOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateResizeNearestNeighborOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ResizeNearestNeighborOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ResizeNearestNeighborOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _align_corners = _o->align_corners; + auto _half_pixel_centers = _o->half_pixel_centers; + return tflite::CreateResizeNearestNeighborOptions( + _fbb, + _align_corners, + _half_pixel_centers); +} + +inline CallOptionsT *CallOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new CallOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void CallOptions::UnPackTo(CallOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = subgraph(); _o->subgraph = _e; } +} + +inline ::flatbuffers::Offset CallOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateCallOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateCallOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const CallOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _subgraph = _o->subgraph; + return tflite::CreateCallOptions( + _fbb, + _subgraph); +} + +inline PadOptionsT *PadOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new PadOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void PadOptions::UnPackTo(PadOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset PadOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreatePadOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreatePadOptions(::flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + 
(void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const PadOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePadOptions( + _fbb); +} + +inline PadV2OptionsT *PadV2Options::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new PadV2OptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void PadV2Options::UnPackTo(PadV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset PadV2Options::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreatePadV2Options(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreatePadV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const PadV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const PadV2OptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePadV2Options( + _fbb); +} + +inline ReshapeOptionsT *ReshapeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ReshapeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ReshapeOptions::UnPackTo(ReshapeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = new_shape(); if (_e) { _o->new_shape.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->new_shape[_i] = _e->Get(_i); } } else { _o->new_shape.resize(0); } } +} + +inline ::flatbuffers::Offset ReshapeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateReshapeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateReshapeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ReshapeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _new_shape = _o->new_shape.size() ? 
_fbb.CreateVector(_o->new_shape) : 0; + return tflite::CreateReshapeOptions( + _fbb, + _new_shape); +} + +inline SpaceToBatchNDOptionsT *SpaceToBatchNDOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SpaceToBatchNDOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SpaceToBatchNDOptions::UnPackTo(SpaceToBatchNDOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset SpaceToBatchNDOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSpaceToBatchNDOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSpaceToBatchNDOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SpaceToBatchNDOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSpaceToBatchNDOptions( + _fbb); +} + +inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BatchToSpaceNDOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BatchToSpaceNDOptions::UnPackTo(BatchToSpaceNDOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset BatchToSpaceNDOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBatchToSpaceNDOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBatchToSpaceNDOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BatchToSpaceNDOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBatchToSpaceNDOptions( + _fbb); +} + +inline SkipGramOptionsT *SkipGramOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SkipGramOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SkipGramOptions::UnPackTo(SkipGramOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = ngram_size(); _o->ngram_size = _e; } + { auto _e = max_skip_size(); _o->max_skip_size = _e; } + { auto _e = include_all_ngrams(); _o->include_all_ngrams = _e; } +} + +inline ::flatbuffers::Offset SkipGramOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSkipGramOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSkipGramOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SkipGramOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, 
_rehasher}; (void)_va; + auto _ngram_size = _o->ngram_size; + auto _max_skip_size = _o->max_skip_size; + auto _include_all_ngrams = _o->include_all_ngrams; + return tflite::CreateSkipGramOptions( + _fbb, + _ngram_size, + _max_skip_size, + _include_all_ngrams); +} + +inline SpaceToDepthOptionsT *SpaceToDepthOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SpaceToDepthOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SpaceToDepthOptions::UnPackTo(SpaceToDepthOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = block_size(); _o->block_size = _e; } +} + +inline ::flatbuffers::Offset SpaceToDepthOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSpaceToDepthOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSpaceToDepthOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SpaceToDepthOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _block_size = _o->block_size; + return tflite::CreateSpaceToDepthOptions( + _fbb, + _block_size); +} + +inline DepthToSpaceOptionsT *DepthToSpaceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new DepthToSpaceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void DepthToSpaceOptions::UnPackTo(DepthToSpaceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = block_size(); _o->block_size = _e; } +} + +inline ::flatbuffers::Offset DepthToSpaceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DepthToSpaceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateDepthToSpaceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateDepthToSpaceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DepthToSpaceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DepthToSpaceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _block_size = _o->block_size; + return tflite::CreateDepthToSpaceOptions( + _fbb, + _block_size); +} + +inline SubOptionsT *SubOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SubOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SubOptions::UnPackTo(SubOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = pot_scale_int16(); _o->pot_scale_int16 = _e; } +} + +inline ::flatbuffers::Offset SubOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSubOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSubOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, const 
::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SubOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _pot_scale_int16 = _o->pot_scale_int16; + return tflite::CreateSubOptions( + _fbb, + _fused_activation_function, + _pot_scale_int16); +} + +inline DivOptionsT *DivOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new DivOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void DivOptions::UnPackTo(DivOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } +} + +inline ::flatbuffers::Offset DivOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateDivOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateDivOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DivOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + return tflite::CreateDivOptions( + _fbb, + _fused_activation_function); +} + +inline TopKV2OptionsT *TopKV2Options::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new TopKV2OptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void TopKV2Options::UnPackTo(TopKV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset TopKV2Options::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateTopKV2Options(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateTopKV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const TopKV2OptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTopKV2Options( + _fbb); +} + +inline EmbeddingLookupSparseOptionsT *EmbeddingLookupSparseOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new EmbeddingLookupSparseOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void EmbeddingLookupSparseOptions::UnPackTo(EmbeddingLookupSparseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = combiner(); _o->combiner = _e; } +} + +inline ::flatbuffers::Offset EmbeddingLookupSparseOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateEmbeddingLookupSparseOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset 
CreateEmbeddingLookupSparseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const EmbeddingLookupSparseOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _combiner = _o->combiner; + return tflite::CreateEmbeddingLookupSparseOptions( + _fbb, + _combiner); +} + +inline GatherOptionsT *GatherOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new GatherOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void GatherOptions::UnPackTo(GatherOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; } + { auto _e = batch_dims(); _o->batch_dims = _e; } +} + +inline ::flatbuffers::Offset GatherOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateGatherOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateGatherOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const GatherOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + auto _batch_dims = _o->batch_dims; + return tflite::CreateGatherOptions( + _fbb, + _axis, + _batch_dims); +} + +inline TransposeOptionsT *TransposeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new TransposeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void TransposeOptions::UnPackTo(TransposeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset TransposeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateTransposeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateTransposeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const TransposeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTransposeOptions( + _fbb); +} + +inline ExpOptionsT *ExpOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ExpOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ExpOptions::UnPackTo(ExpOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset ExpOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateExpOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const ::flatbuffers::rehasher_function_t 
*_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ExpOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpOptions( + _fbb); +} + +inline CosOptionsT *CosOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new CosOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void CosOptions::UnPackTo(CosOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset CosOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateCosOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateCosOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CosOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const CosOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateCosOptions( + _fbb); +} + +inline ReducerOptionsT *ReducerOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ReducerOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ReducerOptions::UnPackTo(ReducerOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = keep_dims(); _o->keep_dims = _e; } +} + +inline ::flatbuffers::Offset ReducerOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateReducerOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateReducerOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ReducerOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _keep_dims = _o->keep_dims; + return tflite::CreateReducerOptions( + _fbb, + _keep_dims); +} + +inline SqueezeOptionsT *SqueezeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SqueezeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SqueezeOptions::UnPackTo(SqueezeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = squeeze_dims(); if (_e) { _o->squeeze_dims.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->squeeze_dims[_i] = _e->Get(_i); } } else { _o->squeeze_dims.resize(0); } } +} + +inline ::flatbuffers::Offset SqueezeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSqueezeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSqueezeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const 
SqueezeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _squeeze_dims = _o->squeeze_dims.size() ? _fbb.CreateVector(_o->squeeze_dims) : 0; + return tflite::CreateSqueezeOptions( + _fbb, + _squeeze_dims); +} + +inline SplitOptionsT *SplitOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SplitOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SplitOptions::UnPackTo(SplitOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num_splits(); _o->num_splits = _e; } +} + +inline ::flatbuffers::Offset SplitOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSplitOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSplitOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SplitOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num_splits = _o->num_splits; + return tflite::CreateSplitOptions( + _fbb, + _num_splits); +} + +inline SplitVOptionsT *SplitVOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SplitVOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SplitVOptions::UnPackTo(SplitVOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num_splits(); _o->num_splits = _e; } +} + +inline ::flatbuffers::Offset SplitVOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SplitVOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSplitVOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSplitVOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SplitVOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SplitVOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num_splits = _o->num_splits; + return tflite::CreateSplitVOptions( + _fbb, + _num_splits); +} + +inline StridedSliceOptionsT *StridedSliceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StridedSliceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StridedSliceOptions::UnPackTo(StridedSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = begin_mask(); _o->begin_mask = _e; } + { auto _e = end_mask(); _o->end_mask = _e; } + { auto _e = ellipsis_mask(); _o->ellipsis_mask = _e; } + { auto _e = new_axis_mask(); _o->new_axis_mask = _e; } + { auto _e = shrink_axis_mask(); _o->shrink_axis_mask = _e; } + { auto _e = offset(); _o->offset = _e; } +} + +inline ::flatbuffers::Offset StridedSliceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStridedSliceOptions(_fbb, _o, _rehasher); +} + +inline 
::flatbuffers::Offset CreateStridedSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StridedSliceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _begin_mask = _o->begin_mask; + auto _end_mask = _o->end_mask; + auto _ellipsis_mask = _o->ellipsis_mask; + auto _new_axis_mask = _o->new_axis_mask; + auto _shrink_axis_mask = _o->shrink_axis_mask; + auto _offset = _o->offset; + return tflite::CreateStridedSliceOptions( + _fbb, + _begin_mask, + _end_mask, + _ellipsis_mask, + _new_axis_mask, + _shrink_axis_mask, + _offset); +} + +inline LogSoftmaxOptionsT *LogSoftmaxOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LogSoftmaxOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LogSoftmaxOptions::UnPackTo(LogSoftmaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset LogSoftmaxOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogSoftmaxOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLogSoftmaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LogSoftmaxOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogSoftmaxOptions( + _fbb); +} + +inline CastOptionsT *CastOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new CastOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void CastOptions::UnPackTo(CastOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = in_data_type(); _o->in_data_type = _e; } + { auto _e = out_data_type(); _o->out_data_type = _e; } +} + +inline ::flatbuffers::Offset CastOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateCastOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateCastOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const CastOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _in_data_type = _o->in_data_type; + auto _out_data_type = _o->out_data_type; + return tflite::CreateCastOptions( + _fbb, + _in_data_type, + _out_data_type); +} + +inline DequantizeOptionsT *DequantizeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new DequantizeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void DequantizeOptions::UnPackTo(DequantizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset 
DequantizeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateDequantizeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateDequantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DequantizeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateDequantizeOptions( + _fbb); +} + +inline MaximumMinimumOptionsT *MaximumMinimumOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new MaximumMinimumOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void MaximumMinimumOptions::UnPackTo(MaximumMinimumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset MaximumMinimumOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateMaximumMinimumOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateMaximumMinimumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const MaximumMinimumOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateMaximumMinimumOptions( + _fbb); +} + +inline TileOptionsT *TileOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new TileOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void TileOptions::UnPackTo(TileOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset TileOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateTileOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateTileOptions(::flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const TileOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTileOptions( + _fbb); +} + +inline ArgMaxOptionsT *ArgMaxOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ArgMaxOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ArgMaxOptions::UnPackTo(ArgMaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = output_type(); _o->output_type = _e; } +} + +inline ::flatbuffers::Offset ArgMaxOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateArgMaxOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset 
CreateArgMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ArgMaxOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _output_type = _o->output_type; + return tflite::CreateArgMaxOptions( + _fbb, + _output_type); +} + +inline ArgMinOptionsT *ArgMinOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ArgMinOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = output_type(); _o->output_type = _e; } +} + +inline ::flatbuffers::Offset ArgMinOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateArgMinOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateArgMinOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _output_type = _o->output_type; + return tflite::CreateArgMinOptions( + _fbb, + _output_type); +} + +inline GreaterOptionsT *GreaterOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new GreaterOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void GreaterOptions::UnPackTo(GreaterOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset GreaterOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateGreaterOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateGreaterOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const GreaterOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGreaterOptions( + _fbb); +} + +inline GreaterEqualOptionsT *GreaterEqualOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new GreaterEqualOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void GreaterEqualOptions::UnPackTo(GreaterEqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset GreaterEqualOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateGreaterEqualOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateGreaterEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GreaterEqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct 
_VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const GreaterEqualOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateGreaterEqualOptions( + _fbb); +} + +inline LessOptionsT *LessOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LessOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LessOptions::UnPackTo(LessOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset LessOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLessOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLessOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LessOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLessOptions( + _fbb); +} + +inline LessEqualOptionsT *LessEqualOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LessEqualOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LessEqualOptions::UnPackTo(LessEqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset LessEqualOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLessEqualOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLessEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LessEqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LessEqualOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLessEqualOptions( + _fbb); +} + +inline NegOptionsT *NegOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new NegOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void NegOptions::UnPackTo(NegOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset NegOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateNegOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateNegOptions(::flatbuffers::FlatBufferBuilder &_fbb, const NegOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const NegOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNegOptions( + _fbb); +} + +inline SelectOptionsT *SelectOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SelectOptionsT()); + UnPackTo(_o.get(), _resolver); + return 
_o.release(); +} + +inline void SelectOptions::UnPackTo(SelectOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset SelectOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSelectOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSelectOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SelectOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SelectOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSelectOptions( + _fbb); +} + +inline SliceOptionsT *SliceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SliceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SliceOptions::UnPackTo(SliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset SliceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSliceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SliceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSliceOptions( + _fbb); +} + +inline TransposeConvOptionsT *TransposeConvOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new TransposeConvOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void TransposeConvOptions::UnPackTo(TransposeConvOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = padding(); _o->padding = _e; } + { auto _e = stride_w(); _o->stride_w = _e; } + { auto _e = stride_h(); _o->stride_h = _e; } + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; } + { auto _e = quantized_bias_type(); _o->quantized_bias_type = _e; } +} + +inline ::flatbuffers::Offset TransposeConvOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateTransposeConvOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateTransposeConvOptions(::flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const TransposeConvOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _padding = _o->padding; + auto _stride_w = _o->stride_w; + auto _stride_h = _o->stride_h; + auto _fused_activation_function = _o->fused_activation_function; + auto _quantized_bias_type = _o->quantized_bias_type; + return tflite::CreateTransposeConvOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + 
_fused_activation_function, + _quantized_bias_type); +} + +inline ExpandDimsOptionsT *ExpandDimsOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ExpandDimsOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ExpandDimsOptions::UnPackTo(ExpandDimsOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset ExpandDimsOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpandDimsOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateExpandDimsOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ExpandDimsOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpandDimsOptions( + _fbb); +} + +inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SparseToDenseOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SparseToDenseOptions::UnPackTo(SparseToDenseOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = validate_indices(); _o->validate_indices = _e; } +} + +inline ::flatbuffers::Offset SparseToDenseOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparseToDenseOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSparseToDenseOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SparseToDenseOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _validate_indices = _o->validate_indices; + return tflite::CreateSparseToDenseOptions( + _fbb, + _validate_indices); +} + +inline EqualOptionsT *EqualOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new EqualOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset EqualOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateEqualOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateEqualOptions( + _fbb); +} + +inline NotEqualOptionsT *NotEqualOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + 
auto _o = std::unique_ptr(new NotEqualOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset NotEqualOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateNotEqualOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateNotEqualOptions(::flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNotEqualOptions( + _fbb); +} + +inline ShapeOptionsT *ShapeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ShapeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ShapeOptions::UnPackTo(ShapeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = out_type(); _o->out_type = _e; } +} + +inline ::flatbuffers::Offset ShapeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateShapeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateShapeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ShapeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _out_type = _o->out_type; + return tflite::CreateShapeOptions( + _fbb, + _out_type); +} + +inline RankOptionsT *RankOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new RankOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void RankOptions::UnPackTo(RankOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset RankOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateRankOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateRankOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RankOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const RankOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateRankOptions( + _fbb); +} + +inline PowOptionsT *PowOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new PowOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void PowOptions::UnPackTo(PowOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset PowOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const 
PowOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreatePowOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreatePowOptions(::flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const PowOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePowOptions( + _fbb); +} + +inline FakeQuantOptionsT *FakeQuantOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new FakeQuantOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void FakeQuantOptions::UnPackTo(FakeQuantOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = min(); _o->min = _e; } + { auto _e = max(); _o->max = _e; } + { auto _e = num_bits(); _o->num_bits = _e; } + { auto _e = narrow_range(); _o->narrow_range = _e; } +} + +inline ::flatbuffers::Offset FakeQuantOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateFakeQuantOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateFakeQuantOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const FakeQuantOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _min = _o->min; + auto _max = _o->max; + auto _num_bits = _o->num_bits; + auto _narrow_range = _o->narrow_range; + return tflite::CreateFakeQuantOptions( + _fbb, + _min, + _max, + _num_bits, + _narrow_range); +} + +inline PackOptionsT *PackOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new PackOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void PackOptions::UnPackTo(PackOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = values_count(); _o->values_count = _e; } + { auto _e = axis(); _o->axis = _e; } +} + +inline ::flatbuffers::Offset PackOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreatePackOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreatePackOptions(::flatbuffers::FlatBufferBuilder &_fbb, const PackOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const PackOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _values_count = _o->values_count; + auto _axis = _o->axis; + return tflite::CreatePackOptions( + _fbb, + _values_count, + _axis); +} + +inline LogicalOrOptionsT *LogicalOrOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LogicalOrOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LogicalOrOptions::UnPackTo(LogicalOrOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + 
(void)_resolver; +} + +inline ::flatbuffers::Offset LogicalOrOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalOrOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLogicalOrOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LogicalOrOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalOrOptions( + _fbb); +} + +inline OneHotOptionsT *OneHotOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new OneHotOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = axis(); _o->axis = _e; } +} + +inline ::flatbuffers::Offset OneHotOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateOneHotOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateOneHotOptions(::flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _axis = _o->axis; + return tflite::CreateOneHotOptions( + _fbb, + _axis); +} + +inline AbsOptionsT *AbsOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new AbsOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void AbsOptions::UnPackTo(AbsOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset AbsOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AbsOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateAbsOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateAbsOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AbsOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const AbsOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateAbsOptions( + _fbb); +} + +inline HardSwishOptionsT *HardSwishOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new HardSwishOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void HardSwishOptions::UnPackTo(HardSwishOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset HardSwishOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HardSwishOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateHardSwishOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateHardSwishOptions(::flatbuffers::FlatBufferBuilder &_fbb, 
const HardSwishOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const HardSwishOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateHardSwishOptions( + _fbb); +} + +inline LogicalAndOptionsT *LogicalAndOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LogicalAndOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LogicalAndOptions::UnPackTo(LogicalAndOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset LogicalAndOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalAndOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLogicalAndOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalAndOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LogicalAndOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalAndOptions( + _fbb); +} + +inline LogicalNotOptionsT *LogicalNotOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LogicalNotOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LogicalNotOptions::UnPackTo(LogicalNotOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset LogicalNotOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogicalNotOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLogicalNotOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LogicalNotOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LogicalNotOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogicalNotOptions( + _fbb); +} + +inline UnpackOptionsT *UnpackOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new UnpackOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void UnpackOptions::UnPackTo(UnpackOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num(); _o->num = _e; } + { auto _e = axis(); _o->axis = _e; } +} + +inline ::flatbuffers::Offset UnpackOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateUnpackOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateUnpackOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const UnpackOptionsT* __o; const ::flatbuffers::rehasher_function_t 
*__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num = _o->num; + auto _axis = _o->axis; + return tflite::CreateUnpackOptions( + _fbb, + _num, + _axis); +} + +inline FloorDivOptionsT *FloorDivOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new FloorDivOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void FloorDivOptions::UnPackTo(FloorDivOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset FloorDivOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateFloorDivOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateFloorDivOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const FloorDivOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateFloorDivOptions( + _fbb); +} + +inline SquareOptionsT *SquareOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SquareOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SquareOptions::UnPackTo(SquareOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset SquareOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSquareOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSquareOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SquareOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSquareOptions( + _fbb); +} + +inline ZerosLikeOptionsT *ZerosLikeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ZerosLikeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ZerosLikeOptions::UnPackTo(ZerosLikeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset ZerosLikeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateZerosLikeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateZerosLikeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ZerosLikeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateZerosLikeOptions( + _fbb); +} + +inline FillOptionsT *FillOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new FillOptionsT()); + UnPackTo(_o.get(), _resolver); + 
return _o.release(); +} + +inline void FillOptions::UnPackTo(FillOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset FillOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateFillOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateFillOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const FillOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateFillOptions( + _fbb); +} + +inline FloorModOptionsT *FloorModOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new FloorModOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void FloorModOptions::UnPackTo(FloorModOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset FloorModOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const FloorModOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateFloorModOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateFloorModOptions(::flatbuffers::FlatBufferBuilder &_fbb, const FloorModOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const FloorModOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateFloorModOptions( + _fbb); +} + +inline RangeOptionsT *RangeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new RangeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void RangeOptions::UnPackTo(RangeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset RangeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RangeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateRangeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateRangeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RangeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const RangeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateRangeOptions( + _fbb); +} + +inline LeakyReluOptionsT *LeakyReluOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new LeakyReluOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void LeakyReluOptions::UnPackTo(LeakyReluOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = alpha(); _o->alpha = _e; } +} + +inline ::flatbuffers::Offset LeakyReluOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const LeakyReluOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return 
CreateLeakyReluOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateLeakyReluOptions(::flatbuffers::FlatBufferBuilder &_fbb, const LeakyReluOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const LeakyReluOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _alpha = _o->alpha; + return tflite::CreateLeakyReluOptions( + _fbb, + _alpha); +} + +inline SquaredDifferenceOptionsT *SquaredDifferenceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SquaredDifferenceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SquaredDifferenceOptions::UnPackTo(SquaredDifferenceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset SquaredDifferenceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SquaredDifferenceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSquaredDifferenceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSquaredDifferenceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SquaredDifferenceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SquaredDifferenceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSquaredDifferenceOptions( + _fbb); +} + +inline MirrorPadOptionsT *MirrorPadOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new MirrorPadOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void MirrorPadOptions::UnPackTo(MirrorPadOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = mode(); _o->mode = _e; } +} + +inline ::flatbuffers::Offset MirrorPadOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateMirrorPadOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateMirrorPadOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const MirrorPadOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _mode = _o->mode; + return tflite::CreateMirrorPadOptions( + _fbb, + _mode); +} + +inline UniqueOptionsT *UniqueOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new UniqueOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void UniqueOptions::UnPackTo(UniqueOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = idx_out_type(); _o->idx_out_type = _e; } +} + +inline ::flatbuffers::Offset UniqueOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateUniqueOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset 
CreateUniqueOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UniqueOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const UniqueOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _idx_out_type = _o->idx_out_type; + return tflite::CreateUniqueOptions( + _fbb, + _idx_out_type); +} + +inline ReverseV2OptionsT *ReverseV2Options::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ReverseV2OptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ReverseV2Options::UnPackTo(ReverseV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset ReverseV2Options::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateReverseV2Options(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateReverseV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ReverseV2OptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateReverseV2Options( + _fbb); +} + +inline AddNOptionsT *AddNOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new AddNOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void AddNOptions::UnPackTo(AddNOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset AddNOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateAddNOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateAddNOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const AddNOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateAddNOptions( + _fbb); +} + +inline GatherNdOptionsT *GatherNdOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new GatherNdOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void GatherNdOptions::UnPackTo(GatherNdOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset GatherNdOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateGatherNdOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateGatherNdOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const GatherNdOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, 
_o, _rehasher}; (void)_va; + return tflite::CreateGatherNdOptions( + _fbb); +} + +inline WhereOptionsT *WhereOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new WhereOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void WhereOptions::UnPackTo(WhereOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset WhereOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateWhereOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateWhereOptions(::flatbuffers::FlatBufferBuilder &_fbb, const WhereOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const WhereOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateWhereOptions( + _fbb); +} + +inline ReverseSequenceOptionsT *ReverseSequenceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ReverseSequenceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ReverseSequenceOptions::UnPackTo(ReverseSequenceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = seq_dim(); _o->seq_dim = _e; } + { auto _e = batch_dim(); _o->batch_dim = _e; } +} + +inline ::flatbuffers::Offset ReverseSequenceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateReverseSequenceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateReverseSequenceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReverseSequenceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ReverseSequenceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _seq_dim = _o->seq_dim; + auto _batch_dim = _o->batch_dim; + return tflite::CreateReverseSequenceOptions( + _fbb, + _seq_dim, + _batch_dim); +} + +inline MatrixDiagOptionsT *MatrixDiagOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new MatrixDiagOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void MatrixDiagOptions::UnPackTo(MatrixDiagOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset MatrixDiagOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateMatrixDiagOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateMatrixDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixDiagOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const MatrixDiagOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateMatrixDiagOptions( + _fbb); +} + +inline 
QuantizeOptionsT *QuantizeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new QuantizeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void QuantizeOptions::UnPackTo(QuantizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset QuantizeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateQuantizeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateQuantizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const QuantizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const QuantizeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateQuantizeOptions( + _fbb); +} + +inline MatrixSetDiagOptionsT *MatrixSetDiagOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new MatrixSetDiagOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void MatrixSetDiagOptions::UnPackTo(MatrixSetDiagOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset MatrixSetDiagOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateMatrixSetDiagOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateMatrixSetDiagOptions(::flatbuffers::FlatBufferBuilder &_fbb, const MatrixSetDiagOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const MatrixSetDiagOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateMatrixSetDiagOptions( + _fbb); +} + +inline IfOptionsT *IfOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new IfOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void IfOptions::UnPackTo(IfOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = then_subgraph_index(); _o->then_subgraph_index = _e; } + { auto _e = else_subgraph_index(); _o->else_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset IfOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateIfOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateIfOptions(::flatbuffers::FlatBufferBuilder &_fbb, const IfOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const IfOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _then_subgraph_index = _o->then_subgraph_index; + auto _else_subgraph_index = _o->else_subgraph_index; + return tflite::CreateIfOptions( + _fbb, + _then_subgraph_index, + _else_subgraph_index); +} + +inline CallOnceOptionsT *CallOnceOptions::UnPack(const 
::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new CallOnceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void CallOnceOptions::UnPackTo(CallOnceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = init_subgraph_index(); _o->init_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset CallOnceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateCallOnceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateCallOnceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CallOnceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const CallOnceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _init_subgraph_index = _o->init_subgraph_index; + return tflite::CreateCallOnceOptions( + _fbb, + _init_subgraph_index); +} + +inline WhileOptionsT *WhileOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new WhileOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void WhileOptions::UnPackTo(WhileOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = cond_subgraph_index(); _o->cond_subgraph_index = _e; } + { auto _e = body_subgraph_index(); _o->body_subgraph_index = _e; } +} + +inline ::flatbuffers::Offset WhileOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateWhileOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateWhileOptions(::flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const WhileOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _cond_subgraph_index = _o->cond_subgraph_index; + auto _body_subgraph_index = _o->body_subgraph_index; + return tflite::CreateWhileOptions( + _fbb, + _cond_subgraph_index, + _body_subgraph_index); +} + +inline NonMaxSuppressionV4OptionsT *NonMaxSuppressionV4Options::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new NonMaxSuppressionV4OptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void NonMaxSuppressionV4Options::UnPackTo(NonMaxSuppressionV4OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset NonMaxSuppressionV4Options::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateNonMaxSuppressionV4Options(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateNonMaxSuppressionV4Options(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const NonMaxSuppressionV4OptionsT* __o; const 
::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNonMaxSuppressionV4Options( + _fbb); +} + +inline NonMaxSuppressionV5OptionsT *NonMaxSuppressionV5Options::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new NonMaxSuppressionV5OptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void NonMaxSuppressionV5Options::UnPackTo(NonMaxSuppressionV5OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset NonMaxSuppressionV5Options::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateNonMaxSuppressionV5Options(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateNonMaxSuppressionV5Options(::flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const NonMaxSuppressionV5OptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNonMaxSuppressionV5Options( + _fbb); +} + +inline ScatterNdOptionsT *ScatterNdOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ScatterNdOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ScatterNdOptions::UnPackTo(ScatterNdOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset ScatterNdOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateScatterNdOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateScatterNdOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ScatterNdOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ScatterNdOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateScatterNdOptions( + _fbb); +} + +inline SelectV2OptionsT *SelectV2Options::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SelectV2OptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SelectV2Options::UnPackTo(SelectV2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset SelectV2Options::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SelectV2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSelectV2Options(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSelectV2Options(::flatbuffers::FlatBufferBuilder &_fbb, const SelectV2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SelectV2OptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSelectV2Options( + _fbb); +} + +inline DensifyOptionsT 
*DensifyOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new DensifyOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void DensifyOptions::UnPackTo(DensifyOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset DensifyOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DensifyOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateDensifyOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateDensifyOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DensifyOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DensifyOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateDensifyOptions( + _fbb); +} + +inline SegmentSumOptionsT *SegmentSumOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SegmentSumOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SegmentSumOptions::UnPackTo(SegmentSumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset SegmentSumOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSegmentSumOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SegmentSumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SegmentSumOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSegmentSumOptions( + _fbb); +} + +inline BatchMatMulOptionsT *BatchMatMulOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BatchMatMulOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BatchMatMulOptions::UnPackTo(BatchMatMulOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = adj_x(); _o->adj_x = _e; } + { auto _e = adj_y(); _o->adj_y = _e; } + { auto _e = asymmetric_quantize_inputs(); _o->asymmetric_quantize_inputs = _e; } +} + +inline ::flatbuffers::Offset BatchMatMulOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBatchMatMulOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBatchMatMulOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BatchMatMulOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BatchMatMulOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _adj_x = _o->adj_x; + auto _adj_y = _o->adj_y; + auto _asymmetric_quantize_inputs = _o->asymmetric_quantize_inputs; + return tflite::CreateBatchMatMulOptions( + _fbb, + _adj_x, + _adj_y, + 
_asymmetric_quantize_inputs); +} + +inline CumsumOptionsT *CumsumOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new CumsumOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void CumsumOptions::UnPackTo(CumsumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = exclusive(); _o->exclusive = _e; } + { auto _e = reverse(); _o->reverse = _e; } +} + +inline ::flatbuffers::Offset CumsumOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateCumsumOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateCumsumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const CumsumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const CumsumOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _exclusive = _o->exclusive; + auto _reverse = _o->reverse; + return tflite::CreateCumsumOptions( + _fbb, + _exclusive, + _reverse); +} + +inline BroadcastToOptionsT *BroadcastToOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BroadcastToOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BroadcastToOptions::UnPackTo(BroadcastToOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset BroadcastToOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBroadcastToOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBroadcastToOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BroadcastToOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BroadcastToOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBroadcastToOptions( + _fbb); +} + +inline Rfft2dOptionsT *Rfft2dOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new Rfft2dOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Rfft2dOptions::UnPackTo(Rfft2dOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset Rfft2dOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateRfft2dOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateRfft2dOptions(::flatbuffers::FlatBufferBuilder &_fbb, const Rfft2dOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const Rfft2dOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateRfft2dOptions( + _fbb); +} + +inline HashtableOptionsT *HashtableOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new 
HashtableOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void HashtableOptions::UnPackTo(HashtableOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = table_id(); _o->table_id = _e; } + { auto _e = key_dtype(); _o->key_dtype = _e; } + { auto _e = value_dtype(); _o->value_dtype = _e; } +} + +inline ::flatbuffers::Offset HashtableOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateHashtableOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateHashtableOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const HashtableOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _table_id = _o->table_id; + auto _key_dtype = _o->key_dtype; + auto _value_dtype = _o->value_dtype; + return tflite::CreateHashtableOptions( + _fbb, + _table_id, + _key_dtype, + _value_dtype); +} + +inline HashtableFindOptionsT *HashtableFindOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new HashtableFindOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void HashtableFindOptions::UnPackTo(HashtableFindOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset HashtableFindOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableFindOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateHashtableFindOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateHashtableFindOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableFindOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const HashtableFindOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateHashtableFindOptions( + _fbb); +} + +inline HashtableImportOptionsT *HashtableImportOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new HashtableImportOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void HashtableImportOptions::UnPackTo(HashtableImportOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset HashtableImportOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableImportOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateHashtableImportOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateHashtableImportOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableImportOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const HashtableImportOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateHashtableImportOptions( + _fbb); +} + +inline HashtableSizeOptionsT 
*HashtableSizeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new HashtableSizeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void HashtableSizeOptions::UnPackTo(HashtableSizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset HashtableSizeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableSizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateHashtableSizeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateHashtableSizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const HashtableSizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const HashtableSizeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateHashtableSizeOptions( + _fbb); +} + +inline VarHandleOptionsT *VarHandleOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new VarHandleOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void VarHandleOptions::UnPackTo(VarHandleOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = container(); if (_e) _o->container = _e->str(); } + { auto _e = shared_name(); if (_e) _o->shared_name = _e->str(); } +} + +inline ::flatbuffers::Offset VarHandleOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const VarHandleOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateVarHandleOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateVarHandleOptions(::flatbuffers::FlatBufferBuilder &_fbb, const VarHandleOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const VarHandleOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _container = _o->container.empty() ? 0 : _fbb.CreateString(_o->container); + auto _shared_name = _o->shared_name.empty() ? 
0 : _fbb.CreateString(_o->shared_name); + return tflite::CreateVarHandleOptions( + _fbb, + _container, + _shared_name); +} + +inline ReadVariableOptionsT *ReadVariableOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ReadVariableOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void ReadVariableOptions::UnPackTo(ReadVariableOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset ReadVariableOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReadVariableOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateReadVariableOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateReadVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReadVariableOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ReadVariableOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateReadVariableOptions( + _fbb); +} + +inline AssignVariableOptionsT *AssignVariableOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new AssignVariableOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void AssignVariableOptions::UnPackTo(AssignVariableOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset AssignVariableOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const AssignVariableOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateAssignVariableOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateAssignVariableOptions(::flatbuffers::FlatBufferBuilder &_fbb, const AssignVariableOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const AssignVariableOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateAssignVariableOptions( + _fbb); +} + +inline RandomOptionsT *RandomOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new RandomOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void RandomOptions::UnPackTo(RandomOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = seed(); _o->seed = _e; } + { auto _e = seed2(); _o->seed2 = _e; } +} + +inline ::flatbuffers::Offset RandomOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RandomOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateRandomOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateRandomOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RandomOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const RandomOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _seed = _o->seed; + auto _seed2 = _o->seed2; + return tflite::CreateRandomOptions( + _fbb, + 
_seed, + _seed2); +} + +inline BucketizeOptionsT *BucketizeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BucketizeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void BucketizeOptions::UnPackTo(BucketizeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = boundaries(); if (_e) { _o->boundaries.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->boundaries[_i] = _e->Get(_i); } } else { _o->boundaries.resize(0); } } +} + +inline ::flatbuffers::Offset BucketizeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BucketizeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBucketizeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBucketizeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BucketizeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BucketizeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _boundaries = _o->boundaries.size() ? _fbb.CreateVector(_o->boundaries) : 0; + return tflite::CreateBucketizeOptions( + _fbb, + _boundaries); +} + +inline GeluOptionsT *GeluOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new GeluOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void GeluOptions::UnPackTo(GeluOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = approximate(); _o->approximate = _e; } +} + +inline ::flatbuffers::Offset GeluOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateGeluOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateGeluOptions(::flatbuffers::FlatBufferBuilder &_fbb, const GeluOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const GeluOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _approximate = _o->approximate; + return tflite::CreateGeluOptions( + _fbb, + _approximate); +} + +inline DynamicUpdateSliceOptionsT *DynamicUpdateSliceOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new DynamicUpdateSliceOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void DynamicUpdateSliceOptions::UnPackTo(DynamicUpdateSliceOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset DynamicUpdateSliceOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DynamicUpdateSliceOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateDynamicUpdateSliceOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateDynamicUpdateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DynamicUpdateSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const 
+
+inline ::flatbuffers::Offset<DynamicUpdateSliceOptions> CreateDynamicUpdateSliceOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DynamicUpdateSliceOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DynamicUpdateSliceOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateDynamicUpdateSliceOptions(
+      _fbb);
+}
+
+inline UnsortedSegmentProdOptionsT *UnsortedSegmentProdOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<UnsortedSegmentProdOptionsT>(new UnsortedSegmentProdOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void UnsortedSegmentProdOptions::UnPackTo(UnsortedSegmentProdOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentProdOptions> UnsortedSegmentProdOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateUnsortedSegmentProdOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentProdOptions> CreateUnsortedSegmentProdOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentProdOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const UnsortedSegmentProdOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateUnsortedSegmentProdOptions(
+      _fbb);
+}
+
+inline UnsortedSegmentMaxOptionsT *UnsortedSegmentMaxOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<UnsortedSegmentMaxOptionsT>(new UnsortedSegmentMaxOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void UnsortedSegmentMaxOptions::UnPackTo(UnsortedSegmentMaxOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentMaxOptions> UnsortedSegmentMaxOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMaxOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateUnsortedSegmentMaxOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentMaxOptions> CreateUnsortedSegmentMaxOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMaxOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const UnsortedSegmentMaxOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateUnsortedSegmentMaxOptions(
+      _fbb);
+}
+
+inline UnsortedSegmentSumOptionsT *UnsortedSegmentSumOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<UnsortedSegmentSumOptionsT>(new UnsortedSegmentSumOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void UnsortedSegmentSumOptions::UnPackTo(UnsortedSegmentSumOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentSumOptions> UnsortedSegmentSumOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentSumOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateUnsortedSegmentSumOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentSumOptions> CreateUnsortedSegmentSumOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentSumOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const UnsortedSegmentSumOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateUnsortedSegmentSumOptions(
+      _fbb);
+}
+
+inline ATan2OptionsT *ATan2Options::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<ATan2OptionsT>(new ATan2OptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void ATan2Options::UnPackTo(ATan2OptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<ATan2Options> ATan2Options::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ATan2OptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateATan2Options(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<ATan2Options> CreateATan2Options(::flatbuffers::FlatBufferBuilder &_fbb, const ATan2OptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ATan2OptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateATan2Options(
+      _fbb);
+}
+
+inline UnsortedSegmentMinOptionsT *UnsortedSegmentMinOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<UnsortedSegmentMinOptionsT>(new UnsortedSegmentMinOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void UnsortedSegmentMinOptions::UnPackTo(UnsortedSegmentMinOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentMinOptions> UnsortedSegmentMinOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMinOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateUnsortedSegmentMinOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<UnsortedSegmentMinOptions> CreateUnsortedSegmentMinOptions(::flatbuffers::FlatBufferBuilder &_fbb, const UnsortedSegmentMinOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const UnsortedSegmentMinOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateUnsortedSegmentMinOptions(
+      _fbb);
+}
+
+inline SignOptionsT *SignOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<SignOptionsT>(new SignOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void SignOptions::UnPackTo(SignOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<SignOptions> SignOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SignOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateSignOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<SignOptions> CreateSignOptions(::flatbuffers::FlatBufferBuilder &_fbb, const SignOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SignOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateSignOptions(
+      _fbb);
+}
+
+inline BitcastOptionsT *BitcastOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<BitcastOptionsT>(new BitcastOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void BitcastOptions::UnPackTo(BitcastOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<BitcastOptions> BitcastOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitcastOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateBitcastOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<BitcastOptions> CreateBitcastOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitcastOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BitcastOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateBitcastOptions(
+      _fbb);
+}
+
+inline BitwiseXorOptionsT *BitwiseXorOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<BitwiseXorOptionsT>(new BitwiseXorOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void BitwiseXorOptions::UnPackTo(BitwiseXorOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<BitwiseXorOptions> BitwiseXorOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateBitwiseXorOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<BitwiseXorOptions> CreateBitwiseXorOptions(::flatbuffers::FlatBufferBuilder &_fbb, const BitwiseXorOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BitwiseXorOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateBitwiseXorOptions(
+      _fbb);
+}
+
+inline RightShiftOptionsT *RightShiftOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<RightShiftOptionsT>(new RightShiftOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void RightShiftOptions::UnPackTo(RightShiftOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<RightShiftOptions> RightShiftOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const RightShiftOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateRightShiftOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<RightShiftOptions> CreateRightShiftOptions(::flatbuffers::FlatBufferBuilder &_fbb, const RightShiftOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const RightShiftOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateRightShiftOptions(
+      _fbb);
+}
+
+inline DilateOptionsT *DilateOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<DilateOptionsT>(new DilateOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void DilateOptions::UnPackTo(DilateOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline ::flatbuffers::Offset<DilateOptions> DilateOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const DilateOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateDilateOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<DilateOptions> CreateDilateOptions(::flatbuffers::FlatBufferBuilder &_fbb, const DilateOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const DilateOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateDilateOptions(
+      _fbb);
+}
+
+inline ReduceWindowOptionsT *ReduceWindowOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<ReduceWindowOptionsT>(new ReduceWindowOptionsT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void ReduceWindowOptions::UnPackTo(ReduceWindowOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = reduce_function(); _o->reduce_function = _e; }
+}
+
+inline ::flatbuffers::Offset<ReduceWindowOptions> ReduceWindowOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ReduceWindowOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateReduceWindowOptions(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<ReduceWindowOptions> CreateReduceWindowOptions(::flatbuffers::FlatBufferBuilder &_fbb, const ReduceWindowOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ReduceWindowOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _reduce_function = _o->reduce_function;
+  return tflite::CreateReduceWindowOptions(
+      _fbb,
+      _reduce_function);
+}
+
+inline OperatorCodeT *OperatorCode::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = std::unique_ptr<OperatorCodeT>(new OperatorCodeT());
+  UnPackTo(_o.get(), _resolver);
+  return _o.release();
+}
+
+inline void OperatorCode::UnPackTo(OperatorCodeT *_o, const ::flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  { auto _e = deprecated_builtin_code(); _o->deprecated_builtin_code = _e; }
+  { auto _e = custom_code(); if (_e) _o->custom_code = _e->str(); }
+  { auto _e = version(); _o->version = _e; }
+  { auto _e = builtin_code(); _o->builtin_code = _e; }
+}
+
+inline ::flatbuffers::Offset<OperatorCode> OperatorCode::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateOperatorCode(_fbb, _o, _rehasher);
+}
+
+inline ::flatbuffers::Offset<OperatorCode> CreateOperatorCode(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const OperatorCodeT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  auto _deprecated_builtin_code = _o->deprecated_builtin_code;
+  auto _custom_code = _o->custom_code.empty() ?
0 : _fbb.CreateString(_o->custom_code); + auto _version = _o->version; + auto _builtin_code = _o->builtin_code; + return tflite::CreateOperatorCode( + _fbb, + _deprecated_builtin_code, + _custom_code, + _version, + _builtin_code); +} + +inline StableHLOCompositeOptionsT *StableHLOCompositeOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StableHLOCompositeOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StableHLOCompositeOptions::UnPackTo(StableHLOCompositeOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = decomposition_subgraph_index(); _o->decomposition_subgraph_index = _e; } + { auto _e = composite_attributes(); if (_e) { _o->composite_attributes.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->composite_attributes.begin()); } } + { auto _e = composite_attributes_format(); _o->composite_attributes_format = _e; } + { auto _e = version(); _o->version = _e; } +} + +inline ::flatbuffers::Offset StableHLOCompositeOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StableHLOCompositeOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStableHLOCompositeOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStableHLOCompositeOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StableHLOCompositeOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StableHLOCompositeOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _decomposition_subgraph_index = _o->decomposition_subgraph_index; + auto _composite_attributes = _o->composite_attributes.size() ? 
_fbb.CreateVector(_o->composite_attributes) : 0; + auto _composite_attributes_format = _o->composite_attributes_format; + auto _version = _o->version; + return tflite::CreateStableHLOCompositeOptions( + _fbb, + _name, + _decomposition_subgraph_index, + _composite_attributes, + _composite_attributes_format, + _version); +} + +inline StablehloShiftLeftOptionsT *StablehloShiftLeftOptions::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new StablehloShiftLeftOptionsT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void StablehloShiftLeftOptions::UnPackTo(StablehloShiftLeftOptionsT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline ::flatbuffers::Offset StablehloShiftLeftOptions::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloShiftLeftOptionsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateStablehloShiftLeftOptions(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateStablehloShiftLeftOptions(::flatbuffers::FlatBufferBuilder &_fbb, const StablehloShiftLeftOptionsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const StablehloShiftLeftOptionsT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateStablehloShiftLeftOptions( + _fbb); +} + +inline OperatorT *Operator::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new OperatorT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Operator::UnPackTo(OperatorT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = opcode_index(); _o->opcode_index = _e; } + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } else { _o->inputs.resize(0); } } + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } else { _o->outputs.resize(0); } } + { auto _e = builtin_options_type(); _o->builtin_options.type = _e; } + { auto _e = builtin_options(); if (_e) _o->builtin_options.value = tflite::BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); } + { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->custom_options.begin()); } } + { auto _e = custom_options_format(); _o->custom_options_format = _e; } + { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } else { _o->mutating_variable_inputs.resize(0); } } + { auto _e = intermediates(); if (_e) { _o->intermediates.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->intermediates[_i] = _e->Get(_i); } } else { _o->intermediates.resize(0); } } + { auto _e = large_custom_options_offset(); _o->large_custom_options_offset = _e; } + { auto _e = large_custom_options_size(); _o->large_custom_options_size = _e; } + { auto _e = builtin_options_2_type(); _o->builtin_options_2.type = _e; } + { auto _e = builtin_options_2(); if (_e) _o->builtin_options_2.value = 
tflite::BuiltinOptions2Union::UnPack(_e, builtin_options_2_type(), _resolver); } + { auto _e = debug_metadata_index(); _o->debug_metadata_index = _e; } +} + +inline ::flatbuffers::Offset Operator::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateOperator(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateOperator(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const OperatorT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _opcode_index = _o->opcode_index; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; + auto _builtin_options_type = _o->builtin_options.type; + auto _builtin_options = _o->builtin_options.Pack(_fbb); + auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; + auto _custom_options_format = _o->custom_options_format; + auto _mutating_variable_inputs = _o->mutating_variable_inputs.size() ? _fbb.CreateVector(_o->mutating_variable_inputs) : 0; + auto _intermediates = _o->intermediates.size() ? _fbb.CreateVector(_o->intermediates) : 0; + auto _large_custom_options_offset = _o->large_custom_options_offset; + auto _large_custom_options_size = _o->large_custom_options_size; + auto _builtin_options_2_type = _o->builtin_options_2.type; + auto _builtin_options_2 = _o->builtin_options_2.Pack(_fbb); + auto _debug_metadata_index = _o->debug_metadata_index; + return tflite::CreateOperator( + _fbb, + _opcode_index, + _inputs, + _outputs, + _builtin_options_type, + _builtin_options, + _custom_options, + _custom_options_format, + _mutating_variable_inputs, + _intermediates, + _large_custom_options_offset, + _large_custom_options_size, + _builtin_options_2_type, + _builtin_options_2, + _debug_metadata_index); +} + +inline SubGraphT::SubGraphT(const SubGraphT &o) + : inputs(o.inputs), + outputs(o.outputs), + name(o.name), + debug_metadata_index(o.debug_metadata_index) { + tensors.reserve(o.tensors.size()); + for (const auto &tensors_ : o.tensors) { tensors.emplace_back((tensors_) ? new tflite::TensorT(*tensors_) : nullptr); } + operators.reserve(o.operators.size()); + for (const auto &operators_ : o.operators) { operators.emplace_back((operators_) ? 
new tflite::OperatorT(*operators_) : nullptr); } +} + +inline SubGraphT &SubGraphT::operator=(SubGraphT o) FLATBUFFERS_NOEXCEPT { + std::swap(tensors, o.tensors); + std::swap(inputs, o.inputs); + std::swap(outputs, o.outputs); + std::swap(operators, o.operators); + std::swap(name, o.name); + std::swap(debug_metadata_index, o.debug_metadata_index); + return *this; +} + +inline SubGraphT *SubGraph::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SubGraphT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SubGraph::UnPackTo(SubGraphT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->tensors[_i]) { _e->Get(_i)->UnPackTo(_o->tensors[_i].get(), _resolver); } else { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->tensors.resize(0); } } + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } else { _o->inputs.resize(0); } } + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } else { _o->outputs.resize(0); } } + { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operators[_i]) { _e->Get(_i)->UnPackTo(_o->operators[_i].get(), _resolver); } else { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->operators.resize(0); } } + { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = debug_metadata_index(); _o->debug_metadata_index = _e; } +} + +inline ::flatbuffers::Offset SubGraph::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSubGraph(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSubGraph(::flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SubGraphT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _tensors = _o->tensors.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->tensors.size(), [](size_t i, _VectorArgs *__va) { return CreateTensor(*__va->__fbb, __va->__o->tensors[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; + auto _operators = _o->operators.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->operators.size(), [](size_t i, _VectorArgs *__va) { return CreateOperator(*__va->__fbb, __va->__o->operators[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _name = _o->name.empty() ? 
0 : _fbb.CreateString(_o->name); + auto _debug_metadata_index = _o->debug_metadata_index; + return tflite::CreateSubGraph( + _fbb, + _tensors, + _inputs, + _outputs, + _operators, + _name, + _debug_metadata_index); +} + +inline BufferT *Buffer::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new BufferT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Buffer::UnPackTo(BufferT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = data(); if (_e) { _o->data.resize(_e->size()); std::copy(_e->begin(), _e->end(), _o->data.begin()); } } + { auto _e = offset(); _o->offset = _e; } + { auto _e = size(); _o->size = _e; } +} + +inline ::flatbuffers::Offset Buffer::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateBuffer(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateBuffer(::flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const BufferT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + _fbb.ForceVectorAlignment(_o->data.size(), sizeof(uint8_t), 16); + auto _data = _o->data.size() ? _fbb.CreateVector(_o->data) : 0; + auto _offset = _o->offset; + auto _size = _o->size; + return tflite::CreateBuffer( + _fbb, + _data, + _offset, + _size); +} + +inline MetadataT *Metadata::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new MetadataT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Metadata::UnPackTo(MetadataT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = buffer(); _o->buffer = _e; } +} + +inline ::flatbuffers::Offset Metadata::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MetadataT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateMetadata(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateMetadata(::flatbuffers::FlatBufferBuilder &_fbb, const MetadataT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const MetadataT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _name = _o->name.empty() ? 
0 : _fbb.CreateString(_o->name); + auto _buffer = _o->buffer; + return tflite::CreateMetadata( + _fbb, + _name, + _buffer); +} + +inline TensorMapT *TensorMap::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new TensorMapT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void TensorMap::UnPackTo(TensorMapT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = tensor_index(); _o->tensor_index = _e; } +} + +inline ::flatbuffers::Offset TensorMap::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateTensorMap(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateTensorMap(::flatbuffers::FlatBufferBuilder &_fbb, const TensorMapT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const TensorMapT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _tensor_index = _o->tensor_index; + return tflite::CreateTensorMap( + _fbb, + _name, + _tensor_index); +} + +inline SignatureDefT::SignatureDefT(const SignatureDefT &o) + : signature_key(o.signature_key), + subgraph_index(o.subgraph_index) { + inputs.reserve(o.inputs.size()); + for (const auto &inputs_ : o.inputs) { inputs.emplace_back((inputs_) ? new tflite::TensorMapT(*inputs_) : nullptr); } + outputs.reserve(o.outputs.size()); + for (const auto &outputs_ : o.outputs) { outputs.emplace_back((outputs_) ? new tflite::TensorMapT(*outputs_) : nullptr); } +} + +inline SignatureDefT &SignatureDefT::operator=(SignatureDefT o) FLATBUFFERS_NOEXCEPT { + std::swap(inputs, o.inputs); + std::swap(outputs, o.outputs); + std::swap(signature_key, o.signature_key); + std::swap(subgraph_index, o.subgraph_index); + return *this; +} + +inline SignatureDefT *SignatureDef::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new SignatureDefT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void SignatureDef::UnPackTo(SignatureDefT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->inputs[_i]) { _e->Get(_i)->UnPackTo(_o->inputs[_i].get(), _resolver); } else { _o->inputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->inputs.resize(0); } } + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->outputs[_i]) { _e->Get(_i)->UnPackTo(_o->outputs[_i].get(), _resolver); } else { _o->outputs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->outputs.resize(0); } } + { auto _e = signature_key(); if (_e) _o->signature_key = _e->str(); } + { auto _e = subgraph_index(); _o->subgraph_index = _e; } +} + +inline ::flatbuffers::Offset SignatureDef::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SignatureDefT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateSignatureDef(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateSignatureDef(::flatbuffers::FlatBufferBuilder &_fbb, const 
SignatureDefT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const SignatureDefT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _inputs = _o->inputs.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->inputs.size(), [](size_t i, _VectorArgs *__va) { return CreateTensorMap(*__va->__fbb, __va->__o->inputs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _outputs = _o->outputs.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->outputs.size(), [](size_t i, _VectorArgs *__va) { return CreateTensorMap(*__va->__fbb, __va->__o->outputs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _signature_key = _o->signature_key.empty() ? 0 : _fbb.CreateString(_o->signature_key); + auto _subgraph_index = _o->subgraph_index; + return tflite::CreateSignatureDef( + _fbb, + _inputs, + _outputs, + _signature_key, + _subgraph_index); +} + +inline ModelT::ModelT(const ModelT &o) + : version(o.version), + description(o.description), + metadata_buffer(o.metadata_buffer) { + operator_codes.reserve(o.operator_codes.size()); + for (const auto &operator_codes_ : o.operator_codes) { operator_codes.emplace_back((operator_codes_) ? new tflite::OperatorCodeT(*operator_codes_) : nullptr); } + subgraphs.reserve(o.subgraphs.size()); + for (const auto &subgraphs_ : o.subgraphs) { subgraphs.emplace_back((subgraphs_) ? new tflite::SubGraphT(*subgraphs_) : nullptr); } + buffers.reserve(o.buffers.size()); + for (const auto &buffers_ : o.buffers) { buffers.emplace_back((buffers_) ? new tflite::BufferT(*buffers_) : nullptr); } + metadata.reserve(o.metadata.size()); + for (const auto &metadata_ : o.metadata) { metadata.emplace_back((metadata_) ? new tflite::MetadataT(*metadata_) : nullptr); } + signature_defs.reserve(o.signature_defs.size()); + for (const auto &signature_defs_ : o.signature_defs) { signature_defs.emplace_back((signature_defs_) ? 
new tflite::SignatureDefT(*signature_defs_) : nullptr); } +} + +inline ModelT &ModelT::operator=(ModelT o) FLATBUFFERS_NOEXCEPT { + std::swap(version, o.version); + std::swap(operator_codes, o.operator_codes); + std::swap(subgraphs, o.subgraphs); + std::swap(description, o.description); + std::swap(buffers, o.buffers); + std::swap(metadata_buffer, o.metadata_buffer); + std::swap(metadata, o.metadata); + std::swap(signature_defs, o.signature_defs); + return *this; +} + +inline ModelT *Model::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { + auto _o = std::unique_ptr(new ModelT()); + UnPackTo(_o.get(), _resolver); + return _o.release(); +} + +inline void Model::UnPackTo(ModelT *_o, const ::flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = version(); _o->version = _e; } + { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operator_codes[_i]) { _e->Get(_i)->UnPackTo(_o->operator_codes[_i].get(), _resolver); } else { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->operator_codes.resize(0); } } + { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->subgraphs[_i]) { _e->Get(_i)->UnPackTo(_o->subgraphs[_i].get(), _resolver); } else { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->subgraphs.resize(0); } } + { auto _e = description(); if (_e) _o->description = _e->str(); } + { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->buffers[_i]) { _e->Get(_i)->UnPackTo(_o->buffers[_i].get(), _resolver); } else { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->buffers.resize(0); } } + { auto _e = metadata_buffer(); if (_e) { _o->metadata_buffer.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->metadata_buffer[_i] = _e->Get(_i); } } else { _o->metadata_buffer.resize(0); } } + { auto _e = metadata(); if (_e) { _o->metadata.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->metadata[_i]) { _e->Get(_i)->UnPackTo(_o->metadata[_i].get(), _resolver); } else { _o->metadata[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->metadata.resize(0); } } + { auto _e = signature_defs(); if (_e) { _o->signature_defs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->signature_defs[_i]) { _e->Get(_i)->UnPackTo(_o->signature_defs[_i].get(), _resolver); } else { _o->signature_defs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->signature_defs.resize(0); } } +} + +inline ::flatbuffers::Offset Model::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { + return CreateModel(_fbb, _o, _rehasher); +} + +inline ::flatbuffers::Offset CreateModel(::flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const ::flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { ::flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const ::flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _version = _o->version; + auto _operator_codes = _o->operator_codes.size() ? 
_fbb.CreateVector<::flatbuffers::Offset> (_o->operator_codes.size(), [](size_t i, _VectorArgs *__va) { return CreateOperatorCode(*__va->__fbb, __va->__o->operator_codes[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _subgraphs = _o->subgraphs.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->subgraphs.size(), [](size_t i, _VectorArgs *__va) { return CreateSubGraph(*__va->__fbb, __va->__o->subgraphs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _description = _o->description.empty() ? 0 : _fbb.CreateString(_o->description); + auto _buffers = _o->buffers.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _metadata_buffer = _o->metadata_buffer.size() ? _fbb.CreateVector(_o->metadata_buffer) : 0; + auto _metadata = _o->metadata.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->metadata.size(), [](size_t i, _VectorArgs *__va) { return CreateMetadata(*__va->__fbb, __va->__o->metadata[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _signature_defs = _o->signature_defs.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->signature_defs.size(), [](size_t i, _VectorArgs *__va) { return CreateSignatureDef(*__va->__fbb, __va->__o->signature_defs[i].get(), __va->__rehasher); }, &_va ) : 0; + return tflite::CreateModel( + _fbb, + _version, + _operator_codes, + _subgraphs, + _description, + _buffers, + _metadata_buffer, + _metadata, + _signature_defs); +} + +inline bool VerifyQuantizationDetails(::flatbuffers::Verifier &verifier, const void *obj, QuantizationDetails type) { + switch (type) { + case QuantizationDetails_NONE: { + return true; + } + case QuantizationDetails_CustomQuantization: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case QuantizationDetails_BlockwiseQuantization: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyQuantizationDetailsVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyQuantizationDetails( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline void *QuantizationDetailsUnion::UnPack(const void *obj, QuantizationDetails type, const ::flatbuffers::resolver_function_t *resolver) { + (void)resolver; + switch (type) { + case QuantizationDetails_CustomQuantization: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case QuantizationDetails_BlockwiseQuantization: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; + } +} + +inline ::flatbuffers::Offset QuantizationDetailsUnion::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher) const { + (void)_rehasher; + switch (type) { + case QuantizationDetails_CustomQuantization: { + auto ptr = reinterpret_cast(value); + return CreateCustomQuantization(_fbb, ptr, _rehasher).Union(); + } + case QuantizationDetails_BlockwiseQuantization: { + auto ptr = reinterpret_cast(value); + return CreateBlockwiseQuantization(_fbb, ptr, _rehasher).Union(); + } + default: return 0; + } +} + +inline 
QuantizationDetailsUnion::QuantizationDetailsUnion(const QuantizationDetailsUnion &u) : type(u.type), value(nullptr) { + switch (type) { + case QuantizationDetails_CustomQuantization: { + value = new tflite::CustomQuantizationT(*reinterpret_cast(u.value)); + break; + } + case QuantizationDetails_BlockwiseQuantization: { + value = new tflite::BlockwiseQuantizationT(*reinterpret_cast(u.value)); + break; + } + default: + break; + } +} + +inline void QuantizationDetailsUnion::Reset() { + switch (type) { + case QuantizationDetails_CustomQuantization: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case QuantizationDetails_BlockwiseQuantization: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: break; + } + value = nullptr; + type = QuantizationDetails_NONE; +} + +inline bool VerifySparseIndexVector(::flatbuffers::Verifier &verifier, const void *obj, SparseIndexVector type) { + switch (type) { + case SparseIndexVector_NONE: { + return true; + } + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifySparseIndexVectorVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifySparseIndexVector( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline void *SparseIndexVectorUnion::UnPack(const void *obj, SparseIndexVector type, const ::flatbuffers::resolver_function_t *resolver) { + (void)resolver; + switch (type) { + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; + } +} + +inline ::flatbuffers::Offset SparseIndexVectorUnion::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher) const { + (void)_rehasher; + switch (type) { + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast(value); + return CreateInt32Vector(_fbb, ptr, _rehasher).Union(); + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast(value); + return CreateUint16Vector(_fbb, ptr, _rehasher).Union(); + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast(value); + return CreateUint8Vector(_fbb, ptr, _rehasher).Union(); + } + default: return 0; + } +} + +inline SparseIndexVectorUnion::SparseIndexVectorUnion(const SparseIndexVectorUnion &u) : type(u.type), value(nullptr) { + switch (type) { + case SparseIndexVector_Int32Vector: { + value = new tflite::Int32VectorT(*reinterpret_cast(u.value)); + break; + } + case SparseIndexVector_Uint16Vector: { + value = new tflite::Uint16VectorT(*reinterpret_cast(u.value)); + break; + } + case SparseIndexVector_Uint8Vector: { + value = new tflite::Uint8VectorT(*reinterpret_cast(u.value)); + break; + } + 
default: + break; + } +} + +inline void SparseIndexVectorUnion::Reset() { + switch (type) { + case SparseIndexVector_Int32Vector: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case SparseIndexVector_Uint16Vector: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case SparseIndexVector_Uint8Vector: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: break; + } + value = nullptr; + type = SparseIndexVector_NONE; +} + +inline bool VerifyBuiltinOptions(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type) { + switch (type) { + case BuiltinOptions_NONE: { + return true; + } + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PadOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SpaceToBatchNDOptions: { 
+ auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TransposeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SubOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DivOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MaximumMinimumOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ArgMaxOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LessOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return 
verifier.VerifyTable(ptr); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnpackOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SquareOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ZerosLikeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FillOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_FloorModOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RangeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ResizeNearestNeighborOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LeakyReluOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SquaredDifferenceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MirrorPadOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AbsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SplitVOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UniqueOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(obj); + return 
verifier.VerifyTable(ptr); + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_QuantizeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HardSwishOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DepthToSpaceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ScatterNdOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SelectV2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DensifyOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Rfft2dOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_Conv3DOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableFindOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableImportOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_HashtableSizeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_VarHandleOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ReadVariableOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_AssignVariableOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RandomOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BucketizeOptions: { + auto ptr = 
reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_DynamicUpdateSliceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentProdOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentMaxOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentMinOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_UnsortedSegmentSumOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ATan2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SignOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BitcastOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_RightShiftOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyBuiltinOptionsVector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyBuiltinOptions( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, const ::flatbuffers::resolver_function_t *resolver) { + (void)resolver; + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); 
+ } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PadOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SpaceToBatchNDOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TransposeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SubOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DivOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MaximumMinimumOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ArgMaxOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LessOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr 
= reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_UnpackOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SquareOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ZerosLikeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FillOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_FloorModOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RangeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ResizeNearestNeighborOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LeakyReluOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SquaredDifferenceOptions: { + auto 
ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MirrorPadOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_AbsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SplitVOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_UniqueOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_QuantizeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_HardSwishOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DepthToSpaceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ScatterNdOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SelectV2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DensifyOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_Rfft2dOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_Conv3DOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_HashtableOptions: { + auto ptr = reinterpret_cast(obj); + return 
ptr->UnPack(resolver); + } + case BuiltinOptions_HashtableFindOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_HashtableImportOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_HashtableSizeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_VarHandleOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ReadVariableOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_AssignVariableOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RandomOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BucketizeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_DynamicUpdateSliceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_UnsortedSegmentProdOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_UnsortedSegmentMaxOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_UnsortedSegmentMinOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_UnsortedSegmentSumOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ATan2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SignOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BitcastOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_RightShiftOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; + } +} + +inline ::flatbuffers::Offset BuiltinOptionsUnion::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher) const { + (void)_rehasher; + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(value); + return CreateConv2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(value); + return CreateDepthwiseConv2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(value); + return CreateConcatEmbeddingsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(value); + return CreateLSHProjectionOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(value); + return CreatePool2DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(value); + return CreateSVDFOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(value); + return CreateRNNOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FullyConnectedOptions: { + 
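Editor's note: BuiltinOptionsUnion::UnPack, completed in the hunk above, dispatches on the union tag and deep-copies the wire table into the matching object-API struct (the `...OptionsT` types). It is rarely called directly; it runs as part of unpacking a whole model. A minimal usage sketch, assuming the usual flatc object API (Model::UnPack, OperatorT, and the generated AsConv2DOptions() accessor); the include name is a placeholder.

```cpp
#include <cstdint>
#include <memory>
#include "tflite_generated.h"  // placeholder name for this generated header

// Unpack a finished buffer into the mutable object API and read one
// operator's Conv2D options, if that is what its union holds.
void InspectFirstOp(const uint8_t *data) {
  std::unique_ptr<tflite::ModelT> model(tflite::GetModel(data)->UnPack());
  const std::unique_ptr<tflite::OperatorT> &op =
      model->subgraphs[0]->operators[0];
  if (op->builtin_options.type == tflite::BuiltinOptions_Conv2DOptions) {
    // AsConv2DOptions() is the generated typed accessor over `value`.
    auto *conv = op->builtin_options.AsConv2DOptions();
    (void)conv->stride_w;  // a field of Conv2DOptionsT
  }
}
```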
auto ptr = reinterpret_cast(value); + return CreateFullyConnectedOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(value); + return CreateSoftmaxOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(value); + return CreateConcatenationOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(value); + return CreateAddOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(value); + return CreateL2NormOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(value); + return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(value); + return CreateLSTMOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(value); + return CreateResizeBilinearOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(value); + return CreateCallOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateReshapeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(value); + return CreateSkipGramOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(value); + return CreateSpaceToDepthOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(value); + return CreateEmbeddingLookupSparseOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(value); + return CreateMulOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PadOptions: { + auto ptr = reinterpret_cast(value); + return CreatePadOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(value); + return CreateGatherOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast(value); + return CreateBatchToSpaceNDOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SpaceToBatchNDOptions: { + auto ptr = reinterpret_cast(value); + return CreateSpaceToBatchNDOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TransposeOptions: { + auto ptr = reinterpret_cast(value); + return CreateTransposeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); + return CreateReducerOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SubOptions: { + auto ptr = reinterpret_cast(value); + return CreateSubOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DivOptions: { + auto ptr = reinterpret_cast(value); + return CreateDivOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast(value); + return CreateSqueezeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast(value); + return CreateSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = 
reinterpret_cast(value); + return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(value); + return CreateTopKV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(value); + return CreateSplitOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogSoftmaxOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(value); + return CreateCastOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(value); + return CreateDequantizeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MaximumMinimumOptions: { + auto ptr = reinterpret_cast(value); + return CreateMaximumMinimumOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ArgMaxOptions: { + auto ptr = reinterpret_cast(value); + return CreateArgMaxOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LessOptions: { + auto ptr = reinterpret_cast(value); + return CreateLessOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(value); + return CreateNegOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast(value); + return CreatePadV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast(value); + return CreateGreaterOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateGreaterEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateLessEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast(value); + return CreateSelectOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(value); + return CreateSliceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(value); + return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + return CreateTileOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + return CreateShapeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + return CreatePowOptions(_fbb, ptr, _rehasher).Union(); + } + case 
BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(value); + return CreateArgMinOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(value); + return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(value); + return CreatePackOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + return CreateOneHotOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalAndOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogicalNotOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_UnpackOptions: { + auto ptr = reinterpret_cast(value); + return CreateUnpackOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(value); + return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SquareOptions: { + auto ptr = reinterpret_cast(value); + return CreateSquareOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ZerosLikeOptions: { + auto ptr = reinterpret_cast(value); + return CreateZerosLikeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FillOptions: { + auto ptr = reinterpret_cast(value); + return CreateFillOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(value); + return CreateBidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast(value); + return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(value); + return CreateUnidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_FloorModOptions: { + auto ptr = reinterpret_cast(value); + return CreateFloorModOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RangeOptions: { + auto ptr = reinterpret_cast(value); + return CreateRangeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ResizeNearestNeighborOptions: { + auto ptr = reinterpret_cast(value); + return CreateResizeNearestNeighborOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LeakyReluOptions: { + auto ptr = reinterpret_cast(value); + return CreateLeakyReluOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SquaredDifferenceOptions: { + auto ptr = reinterpret_cast(value); + return CreateSquaredDifferenceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MirrorPadOptions: { + auto ptr = reinterpret_cast(value); + return CreateMirrorPadOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_AbsOptions: { + auto ptr = reinterpret_cast(value); + return CreateAbsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SplitVOptions: { + auto ptr = reinterpret_cast(value); + return CreateSplitVOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_UniqueOptions: { + auto ptr = reinterpret_cast(value); + return CreateUniqueOptions(_fbb, 
ptr, _rehasher).Union(); + } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(value); + return CreateReverseV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(value); + return CreateAddNOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(value); + return CreateGatherNdOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(value); + return CreateCosOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(value); + return CreateWhereOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(value); + return CreateRankOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(value); + return CreateReverseSequenceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(value); + return CreateMatrixDiagOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_QuantizeOptions: { + auto ptr = reinterpret_cast(value); + return CreateQuantizeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(value); + return CreateMatrixSetDiagOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_HardSwishOptions: { + auto ptr = reinterpret_cast(value); + return CreateHardSwishOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(value); + return CreateIfOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(value); + return CreateWhileOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DepthToSpaceOptions: { + auto ptr = reinterpret_cast(value); + return CreateDepthToSpaceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr = reinterpret_cast(value); + return CreateNonMaxSuppressionV4Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(value); + return CreateNonMaxSuppressionV5Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ScatterNdOptions: { + auto ptr = reinterpret_cast(value); + return CreateScatterNdOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SelectV2Options: { + auto ptr = reinterpret_cast(value); + return CreateSelectV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DensifyOptions: { + auto ptr = reinterpret_cast(value); + return CreateDensifyOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(value); + return CreateSegmentSumOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(value); + return CreateBatchMatMulOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(value); + return CreateCumsumOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast(value); + return CreateCallOnceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(value); + return CreateBroadcastToOptions(_fbb, ptr, _rehasher).Union(); + } 
+ case BuiltinOptions_Rfft2dOptions: { + auto ptr = reinterpret_cast(value); + return CreateRfft2dOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_Conv3DOptions: { + auto ptr = reinterpret_cast(value); + return CreateConv3DOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_HashtableOptions: { + auto ptr = reinterpret_cast(value); + return CreateHashtableOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_HashtableFindOptions: { + auto ptr = reinterpret_cast(value); + return CreateHashtableFindOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_HashtableImportOptions: { + auto ptr = reinterpret_cast(value); + return CreateHashtableImportOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_HashtableSizeOptions: { + auto ptr = reinterpret_cast(value); + return CreateHashtableSizeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_VarHandleOptions: { + auto ptr = reinterpret_cast(value); + return CreateVarHandleOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ReadVariableOptions: { + auto ptr = reinterpret_cast(value); + return CreateReadVariableOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_AssignVariableOptions: { + auto ptr = reinterpret_cast(value); + return CreateAssignVariableOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RandomOptions: { + auto ptr = reinterpret_cast(value); + return CreateRandomOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BucketizeOptions: { + auto ptr = reinterpret_cast(value); + return CreateBucketizeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast(value); + return CreateGeluOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_DynamicUpdateSliceOptions: { + auto ptr = reinterpret_cast(value); + return CreateDynamicUpdateSliceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_UnsortedSegmentProdOptions: { + auto ptr = reinterpret_cast(value); + return CreateUnsortedSegmentProdOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_UnsortedSegmentMaxOptions: { + auto ptr = reinterpret_cast(value); + return CreateUnsortedSegmentMaxOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_UnsortedSegmentMinOptions: { + auto ptr = reinterpret_cast(value); + return CreateUnsortedSegmentMinOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_UnsortedSegmentSumOptions: { + auto ptr = reinterpret_cast(value); + return CreateUnsortedSegmentSumOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ATan2Options: { + auto ptr = reinterpret_cast(value); + return CreateATan2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SignOptions: { + auto ptr = reinterpret_cast(value); + return CreateSignOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BitcastOptions: { + auto ptr = reinterpret_cast(value); + return CreateBitcastOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(value); + return CreateBitwiseXorOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_RightShiftOptions: { + auto ptr = reinterpret_cast(value); + return CreateRightShiftOptions(_fbb, ptr, _rehasher).Union(); + } + default: return 0; + } +} + +inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) : type(u.type), value(nullptr) { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + value = new 
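Editor's note: Pack, which finishes in the hunk above, is the inverse of UnPack: each `...OptionsT` object is re-serialized through its CreateXxxOptions builder and returned as a type-erased union offset. A short round-trip sketch under the same assumptions as before (standard flatc object API, placeholder header name):

```cpp
#include <cstdint>
#include <vector>
#include "flatbuffers/flatbuffers.h"
#include "tflite_generated.h"  // placeholder name for this generated header

// Serialize a mutated ModelT back into a flatbuffer, preserving the file
// identifier so existing readers keep accepting the buffer.
std::vector<uint8_t> Repack(const tflite::ModelT &model) {
  ::flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(tflite::Model::Pack(fbb, &model), tflite::ModelIdentifier());
  return {fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()};
}
```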
tflite::Conv2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DepthwiseConv2DOptions: { + value = new tflite::DepthwiseConv2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + value = new tflite::ConcatEmbeddingsOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LSHProjectionOptions: { + value = new tflite::LSHProjectionOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_Pool2DOptions: { + value = new tflite::Pool2DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SVDFOptions: { + value = new tflite::SVDFOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RNNOptions: { + value = new tflite::RNNOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FullyConnectedOptions: { + value = new tflite::FullyConnectedOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SoftmaxOptions: { + value = new tflite::SoftmaxOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ConcatenationOptions: { + value = new tflite::ConcatenationOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_AddOptions: { + value = new tflite::AddOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_L2NormOptions: { + value = new tflite::L2NormOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + value = new tflite::LocalResponseNormalizationOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LSTMOptions: { + value = new tflite::LSTMOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ResizeBilinearOptions: { + value = new tflite::ResizeBilinearOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CallOptions: { + value = new tflite::CallOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReshapeOptions: { + value = new tflite::ReshapeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SkipGramOptions: { + value = new tflite::SkipGramOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SpaceToDepthOptions: { + value = new tflite::SpaceToDepthOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + value = new tflite::EmbeddingLookupSparseOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MulOptions: { + value = new tflite::MulOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PadOptions: { + value = new tflite::PadOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_GatherOptions: { + value = new tflite::GatherOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BatchToSpaceNDOptions: { + value = new tflite::BatchToSpaceNDOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SpaceToBatchNDOptions: { + value = new tflite::SpaceToBatchNDOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TransposeOptions: { + value = new tflite::TransposeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReducerOptions: { + value = new tflite::ReducerOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SubOptions: { + value = new tflite::SubOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DivOptions: { + value = new tflite::DivOptionsT(*reinterpret_cast(u.value)); + break; + } + 
case BuiltinOptions_SqueezeOptions: { + value = new tflite::SqueezeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SequenceRNNOptions: { + value = new tflite::SequenceRNNOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_StridedSliceOptions: { + value = new tflite::StridedSliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpOptions: { + value = new tflite::ExpOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TopKV2Options: { + value = new tflite::TopKV2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SplitOptions: { + value = new tflite::SplitOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogSoftmaxOptions: { + value = new tflite::LogSoftmaxOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CastOptions: { + value = new tflite::CastOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DequantizeOptions: { + value = new tflite::DequantizeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MaximumMinimumOptions: { + value = new tflite::MaximumMinimumOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ArgMaxOptions: { + value = new tflite::ArgMaxOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LessOptions: { + value = new tflite::LessOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NegOptions: { + value = new tflite::NegOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PadV2Options: { + value = new tflite::PadV2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_GreaterOptions: { + value = new tflite::GreaterOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_GreaterEqualOptions: { + value = new tflite::GreaterEqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LessEqualOptions: { + value = new tflite::LessEqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SelectOptions: { + value = new tflite::SelectOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SliceOptions: { + value = new tflite::SliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TransposeConvOptions: { + value = new tflite::TransposeConvOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SparseToDenseOptions: { + value = new tflite::SparseToDenseOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TileOptions: { + value = new tflite::TileOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpandDimsOptions: { + value = new tflite::ExpandDimsOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_EqualOptions: { + value = new tflite::EqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NotEqualOptions: { + value = new tflite::NotEqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ShapeOptions: { + value = new tflite::ShapeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PowOptions: { + value = new tflite::PowOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ArgMinOptions: { + value = new tflite::ArgMinOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FakeQuantOptions: { + value = new tflite::FakeQuantOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_PackOptions: { + value = new 
tflite::PackOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalOrOptions: { + value = new tflite::LogicalOrOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_OneHotOptions: { + value = new tflite::OneHotOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalAndOptions: { + value = new tflite::LogicalAndOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogicalNotOptions: { + value = new tflite::LogicalNotOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_UnpackOptions: { + value = new tflite::UnpackOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FloorDivOptions: { + value = new tflite::FloorDivOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SquareOptions: { + value = new tflite::SquareOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ZerosLikeOptions: { + value = new tflite::ZerosLikeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FillOptions: { + value = new tflite::FillOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + value = new tflite::BidirectionalSequenceLSTMOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + value = new tflite::BidirectionalSequenceRNNOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + value = new tflite::UnidirectionalSequenceLSTMOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_FloorModOptions: { + value = new tflite::FloorModOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RangeOptions: { + value = new tflite::RangeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ResizeNearestNeighborOptions: { + value = new tflite::ResizeNearestNeighborOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LeakyReluOptions: { + value = new tflite::LeakyReluOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SquaredDifferenceOptions: { + value = new tflite::SquaredDifferenceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MirrorPadOptions: { + value = new tflite::MirrorPadOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_AbsOptions: { + value = new tflite::AbsOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SplitVOptions: { + value = new tflite::SplitVOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_UniqueOptions: { + value = new tflite::UniqueOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReverseV2Options: { + value = new tflite::ReverseV2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_AddNOptions: { + value = new tflite::AddNOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_GatherNdOptions: { + value = new tflite::GatherNdOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CosOptions: { + value = new tflite::CosOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_WhereOptions: { + value = new tflite::WhereOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RankOptions: { + value = new tflite::RankOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReverseSequenceOptions: { + value = new 
tflite::ReverseSequenceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MatrixDiagOptions: { + value = new tflite::MatrixDiagOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_QuantizeOptions: { + value = new tflite::QuantizeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_MatrixSetDiagOptions: { + value = new tflite::MatrixSetDiagOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_HardSwishOptions: { + value = new tflite::HardSwishOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_IfOptions: { + value = new tflite::IfOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_WhileOptions: { + value = new tflite::WhileOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DepthToSpaceOptions: { + value = new tflite::DepthToSpaceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NonMaxSuppressionV4Options: { + value = new tflite::NonMaxSuppressionV4OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + value = new tflite::NonMaxSuppressionV5OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ScatterNdOptions: { + value = new tflite::ScatterNdOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SelectV2Options: { + value = new tflite::SelectV2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DensifyOptions: { + value = new tflite::DensifyOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SegmentSumOptions: { + value = new tflite::SegmentSumOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BatchMatMulOptions: { + value = new tflite::BatchMatMulOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CumsumOptions: { + value = new tflite::CumsumOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_CallOnceOptions: { + value = new tflite::CallOnceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BroadcastToOptions: { + value = new tflite::BroadcastToOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_Rfft2dOptions: { + value = new tflite::Rfft2dOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_Conv3DOptions: { + value = new tflite::Conv3DOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_HashtableOptions: { + value = new tflite::HashtableOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_HashtableFindOptions: { + value = new tflite::HashtableFindOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_HashtableImportOptions: { + value = new tflite::HashtableImportOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_HashtableSizeOptions: { + value = new tflite::HashtableSizeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_VarHandleOptions: { + value = new tflite::VarHandleOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ReadVariableOptions: { + value = new tflite::ReadVariableOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_AssignVariableOptions: { + value = new tflite::AssignVariableOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RandomOptions: { + value = new tflite::RandomOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BucketizeOptions: { + value = new 
tflite::BucketizeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_GeluOptions: { + value = new tflite::GeluOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_DynamicUpdateSliceOptions: { + value = new tflite::DynamicUpdateSliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_UnsortedSegmentProdOptions: { + value = new tflite::UnsortedSegmentProdOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_UnsortedSegmentMaxOptions: { + value = new tflite::UnsortedSegmentMaxOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_UnsortedSegmentMinOptions: { + value = new tflite::UnsortedSegmentMinOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_UnsortedSegmentSumOptions: { + value = new tflite::UnsortedSegmentSumOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ATan2Options: { + value = new tflite::ATan2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SignOptions: { + value = new tflite::SignOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BitcastOptions: { + value = new tflite::BitcastOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_BitwiseXorOptions: { + value = new tflite::BitwiseXorOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_RightShiftOptions: { + value = new tflite::RightShiftOptionsT(*reinterpret_cast(u.value)); + break; + } + default: + break; + } +} + +inline void BuiltinOptionsUnion::Reset() { + switch (type) { + case BuiltinOptions_Conv2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DepthwiseConv2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ConcatEmbeddingsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LSHProjectionOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_Pool2DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SVDFOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RNNOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FullyConnectedOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SoftmaxOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ConcatenationOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_AddOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_L2NormOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LocalResponseNormalizationOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LSTMOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ResizeBilinearOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CallOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReshapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SkipGramOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case 
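Editor's note: the copy constructor completed above gives BuiltinOptionsUnion value semantics: copying an OperatorT deep-copies whichever `...OptionsT` the union owns, and Reset() (the next function in the hunk) is what destruction and reassignment use to release it. A hedged sketch of populating the union when building an op in the object API; Set() is assumed to be the usual flatc-generated helper that resets the union, records the tag, and takes a heap-owned copy of its argument.

```cpp
#include <memory>
#include "tflite_generated.h"  // placeholder name for this generated header

// Build an ADD operator in the object API. Set() clears any previous value,
// stores BuiltinOptions_AddOptions as the tag, and copies `add` into heap
// storage owned by the union, so copies of `op` deep-copy these options.
std::unique_ptr<tflite::OperatorT> MakeAddOp() {
  auto op = std::make_unique<tflite::OperatorT>();
  tflite::AddOptionsT add;
  add.fused_activation_function = tflite::ActivationFunctionType_RELU;
  op->builtin_options.Set(add);
  return op;
}
```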
BuiltinOptions_SpaceToDepthOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_EmbeddingLookupSparseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MulOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PadOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BatchToSpaceNDOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SpaceToBatchNDOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TransposeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReducerOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SubOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DivOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SqueezeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SequenceRNNOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_StridedSliceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CastOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MaximumMinimumOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ArgMaxOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LessOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NegOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PadV2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_GreaterOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_GreaterEqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LessEqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SelectOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SliceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TransposeConvOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = 
reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ShapeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ArgMinOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FakeQuantOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_PackOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalOrOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_OneHotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalAndOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogicalNotOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_UnpackOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FloorDivOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SquareOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ZerosLikeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FillOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_UnidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_FloorModOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RangeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ResizeNearestNeighborOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LeakyReluOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SquaredDifferenceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MirrorPadOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_AbsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SplitVOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_UniqueOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReverseV2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_AddNOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_GatherNdOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CosOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_WhereOptions: { + auto ptr = reinterpret_cast(value); 
+ delete ptr; + break; + } + case BuiltinOptions_RankOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReverseSequenceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MatrixDiagOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_QuantizeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_MatrixSetDiagOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_HardSwishOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_IfOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_WhileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DepthToSpaceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ScatterNdOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SelectV2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DensifyOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SegmentSumOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BatchMatMulOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CumsumOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_CallOnceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BroadcastToOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_Rfft2dOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_Conv3DOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_HashtableOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_HashtableFindOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_HashtableImportOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_HashtableSizeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_VarHandleOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ReadVariableOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_AssignVariableOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RandomOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BucketizeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_GeluOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_DynamicUpdateSliceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_UnsortedSegmentProdOptions: { + auto ptr = 
reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_UnsortedSegmentMaxOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_UnsortedSegmentMinOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_UnsortedSegmentSumOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ATan2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SignOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BitcastOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_BitwiseXorOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_RightShiftOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: break; + } + value = nullptr; + type = BuiltinOptions_NONE; +} + +inline bool VerifyBuiltinOptions2(::flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions2 type) { + switch (type) { + case BuiltinOptions2_NONE: { + return true; + } + case BuiltinOptions2_StablehloConcatenateOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloBroadcastInDimOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloSliceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloConvolutionOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloCustomCallOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloReduceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloScatterOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloCompareOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloDynamicSliceOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloPadOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloIotaOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloDotGeneralOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloReduceWindowOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloSortOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloWhileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloGatherOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloTransposeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_DilateOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloRngBitGeneratorOptions: { + auto ptr = reinterpret_cast(obj); + 
return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_ReduceWindowOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StableHLOCompositeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloShiftLeftOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions2_StablehloCaseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyBuiltinOptions2Vector(::flatbuffers::Verifier &verifier, const ::flatbuffers::Vector<::flatbuffers::Offset> *values, const ::flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (::flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyBuiltinOptions2( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline void *BuiltinOptions2Union::UnPack(const void *obj, BuiltinOptions2 type, const ::flatbuffers::resolver_function_t *resolver) { + (void)resolver; + switch (type) { + case BuiltinOptions2_StablehloConcatenateOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloBroadcastInDimOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloSliceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloConvolutionOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloCustomCallOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloReduceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloScatterOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloCompareOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloDynamicSliceOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloPadOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloIotaOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloDotGeneralOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloReduceWindowOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloSortOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloWhileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloGatherOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloTransposeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_DilateOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloRngBitGeneratorOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case 
BuiltinOptions2_ReduceWindowOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StableHLOCompositeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloShiftLeftOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions2_StablehloCaseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; + } +} + +inline ::flatbuffers::Offset BuiltinOptions2Union::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const ::flatbuffers::rehasher_function_t *_rehasher) const { + (void)_rehasher; + switch (type) { + case BuiltinOptions2_StablehloConcatenateOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloConcatenateOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloBroadcastInDimOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloBroadcastInDimOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloSliceOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloSliceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloConvolutionOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloConvolutionOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloCustomCallOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloCustomCallOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloReduceOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloReduceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloScatterOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloScatterOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloCompareOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloCompareOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloDynamicSliceOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloDynamicSliceOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloPadOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloPadOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloIotaOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloIotaOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloDotGeneralOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloDotGeneralOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloReduceWindowOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloReduceWindowOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloSortOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloSortOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloWhileOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloWhileOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloGatherOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloGatherOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloTransposeOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloTransposeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_DilateOptions: { + auto ptr = 
reinterpret_cast(value); + return CreateDilateOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloRngBitGeneratorOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloRngBitGeneratorOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_ReduceWindowOptions: { + auto ptr = reinterpret_cast(value); + return CreateReduceWindowOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StableHLOCompositeOptions: { + auto ptr = reinterpret_cast(value); + return CreateStableHLOCompositeOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloShiftLeftOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloShiftLeftOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions2_StablehloCaseOptions: { + auto ptr = reinterpret_cast(value); + return CreateStablehloCaseOptions(_fbb, ptr, _rehasher).Union(); + } + default: return 0; + } +} + +inline BuiltinOptions2Union::BuiltinOptions2Union(const BuiltinOptions2Union &u) : type(u.type), value(nullptr) { + switch (type) { + case BuiltinOptions2_StablehloConcatenateOptions: { + value = new tflite::StablehloConcatenateOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloBroadcastInDimOptions: { + value = new tflite::StablehloBroadcastInDimOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloSliceOptions: { + value = new tflite::StablehloSliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloConvolutionOptions: { + value = new tflite::StablehloConvolutionOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloCustomCallOptions: { + value = new tflite::StablehloCustomCallOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloReduceOptions: { + value = new tflite::StablehloReduceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloScatterOptions: { + value = new tflite::StablehloScatterOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloCompareOptions: { + value = new tflite::StablehloCompareOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloDynamicSliceOptions: { + value = new tflite::StablehloDynamicSliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloPadOptions: { + value = new tflite::StablehloPadOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloIotaOptions: { + value = new tflite::StablehloIotaOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloDotGeneralOptions: { + value = new tflite::StablehloDotGeneralOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloReduceWindowOptions: { + value = new tflite::StablehloReduceWindowOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloSortOptions: { + value = new tflite::StablehloSortOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloWhileOptions: { + value = new tflite::StablehloWhileOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloGatherOptions: { + value = new tflite::StablehloGatherOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloTransposeOptions: { + value = new tflite::StablehloTransposeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_DilateOptions: { + value = new 
tflite::DilateOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloRngBitGeneratorOptions: { + value = new tflite::StablehloRngBitGeneratorOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_ReduceWindowOptions: { + value = new tflite::ReduceWindowOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StableHLOCompositeOptions: { + value = new tflite::StableHLOCompositeOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloShiftLeftOptions: { + value = new tflite::StablehloShiftLeftOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions2_StablehloCaseOptions: { + value = new tflite::StablehloCaseOptionsT(*reinterpret_cast(u.value)); + break; + } + default: + break; + } +} + +inline void BuiltinOptions2Union::Reset() { + switch (type) { + case BuiltinOptions2_StablehloConcatenateOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloBroadcastInDimOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloSliceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloConvolutionOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloCustomCallOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloReduceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloScatterOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloCompareOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloDynamicSliceOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloPadOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloIotaOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloDotGeneralOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloReduceWindowOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloSortOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloWhileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloGatherOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloTransposeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_DilateOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloRngBitGeneratorOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_ReduceWindowOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StableHLOCompositeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloShiftLeftOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions2_StablehloCaseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: break; + } + value = 
nullptr; + type = BuiltinOptions2_NONE; +} + +inline const tflite::Model *GetModel(const void *buf) { + return ::flatbuffers::GetRoot(buf); +} + +inline const tflite::Model *GetSizePrefixedModel(const void *buf) { + return ::flatbuffers::GetSizePrefixedRoot(buf); +} + +inline const char *ModelIdentifier() { + return "TFL3"; +} + +inline bool ModelBufferHasIdentifier(const void *buf) { + return ::flatbuffers::BufferHasIdentifier( + buf, ModelIdentifier()); +} + +inline bool SizePrefixedModelBufferHasIdentifier(const void *buf) { + return ::flatbuffers::BufferHasIdentifier( + buf, ModelIdentifier(), true); +} + +inline bool VerifyModelBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(ModelIdentifier()); +} + +inline bool VerifySizePrefixedModelBuffer( + ::flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(ModelIdentifier()); +} + +inline const char *ModelExtension() { + return "tflite"; +} + +inline void FinishModelBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset root) { + fbb.Finish(root, ModelIdentifier()); +} + +inline void FinishSizePrefixedModelBuffer( + ::flatbuffers::FlatBufferBuilder &fbb, + ::flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root, ModelIdentifier()); +} + +inline std::unique_ptr UnPackModel( + const void *buf, + const ::flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr(GetModel(buf)->UnPack(res)); +} + +inline std::unique_ptr UnPackSizePrefixedModel( + const void *buf, + const ::flatbuffers::resolver_function_t *res = nullptr) { + return std::unique_ptr(GetSizePrefixedModel(buf)->UnPack(res)); +} + +} // namespace tflite + +#endif // FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_utils.h new file mode 100644 index 00000000..7498aa02 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/schema/schema_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ + +#include "flatbuffers/flatbuffers.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace tflite { + +// The following methods are introduced to resolve op builtin code shortage +// problem. The new builtin operator will be assigned to the extended builtin +// code field in the flatbuffer schema. Those methods helps to hide builtin code +// details. 
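For orientation, resolvers like these are commonly implemented by reconciling the schema's legacy int8 `deprecated_builtin_code` field with the wider int32 `builtin_code` field. The sketch below assumes that two-field layout; it is not the library's actual definition (which normally lives in the accompanying schema_utils.cc, not shipped in this header-only patch).

// Sketch only: assumes OperatorCode exposes builtin_code() (int32 enum) and
// deprecated_builtin_code() (int8) accessors, as in the generated schema above.
#include <algorithm>

inline tflite::BuiltinOperator ResolveBuiltinCodeSketch(
    const tflite::OperatorCode* op_code) {
  // Older writers could only store codes up to 127 in the int8 field; newer
  // writers use the int32 field. Taking the max reads both kinds of model.
  return std::max(
      op_code->builtin_code(),
      static_cast<tflite::BuiltinOperator>(op_code->deprecated_builtin_code()));
}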
+BuiltinOperator GetBuiltinCode(const OperatorCode *op_code); + +BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h new file mode 100644 index 00000000..4fa1b5e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_ + +#include "absl/status/status.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { + +// Sparsify the `input_model` and write the result to a flatbuffer `builder`. +absl::Status SparsifyModel(const tflite::ModelT& input_model, + flatbuffers::FlatBufferBuilder* builder); +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_SPARSITY_SPARSIFY_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h new file mode 100644 index 00000000..6f3d2d55 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.h @@ -0,0 +1,26 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_ + +namespace mlir::odml { + +// Populates the pattern set with all folding patterns. These patterns +// are intended to have precedence over any other patterns added to the set. 
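A minimal caller sketch (assumed MLIR boilerplate, not part of this header) showing how the populated set would typically be driven:

// Hypothetical driver; assumes the standard greedy rewrite entry point.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

void RunOdmlFolders(mlir::ModuleOp module) {
  mlir::RewritePatternSet patterns(module.getContext());
  mlir::odml::PopulateFolderPatterns(patterns);  // declared just below
  // Callers may append their own patterns to the same set before applying.
  (void)mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}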
+void PopulateFolderPatterns(RewritePatternSet &patternSet); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_FOLDERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h new file mode 100644 index 00000000..bb0c02cc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::odml { + +std::unique_ptr> CreateSHLOSimplifyPass(); + +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h.inc" + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_ODML_CONVERTER_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h new file mode 100644 index 00000000..c6461d81 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_CHECK_DIALECTS_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_CHECK_DIALECTS_PASS_H_ + +#include +#include +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Creates a pass which checks if there exists allowed dialect ops only or not. +// Based on the list of dialect and op names, it signals failure or not. +// If some ops are in the `optional_accepted_dialects`, then it warns them. 
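As a hedged usage sketch (the pass-manager plumbing and the dialect name are illustrative, not taken from this patch):

// Hypothetical pipeline snippet.
#include "mlir/Pass/PassManager.h"

void AddAcceptanceCheck(mlir::PassManager& pm) {
  // Ops from dialects in the optional list only produce warnings; any other
  // op outside the accepted set makes the pass signal failure.
  pm.addPass(mlir::odml::createCheckAcceptedOpsPass(
      /*optional_accepted_dialects=*/{"chlo"}));
}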
+std::unique_ptr createCheckAcceptedOpsPass( + const std::vector &optional_accepted_dialects = {}); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_CHECK_DIALECTS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h new file mode 100644 index 00000000..2afa2066 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_AVG_POOL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_AVG_POOL_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +// Given a Composite op that wraps a core.aten.avg_pool2d, returns the padding +// configuration required for the `tfl.pad` if the padding part of the op is +// to be done before average pooling. +DenseIntElementsAttr GetPadOpAttr(Builder& builder, mhlo::CompositeOp op); + +// Given a Composite op that wraps a core.aten.avg_pool2d, and assuming that +// the padding part is extracted into a tfl.pad op prior to a +// tfl.average_pool_2d, this function finds the return type of the needed +// tfl.pad . +ShapedType GetPadOpType(mhlo::CompositeOp op); + +// Given a Composite op that wraps a core.aten.avg_pool2d, finds the padding +// attribute to be passed to the a tfl.average_pool_2d that can fully replace +// this composite (here, padding is done directly by the tfl.average_pool_2d as +// opposed to being extracted into a separate tfl.pad). +StringAttr GetAvgPoolOpPadAttr(Builder& builder, mhlo::CompositeOp op); + +// Get dense attr for a matrix that corrects the over counting of divisors when +// casting an average pool with ceil mode on in terms of average pool with it +// off. 
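To make the division of labor concrete, a hedged sketch of a rewrite consulting these helpers for a composite that wraps core.aten.avg_pool2d; the surrounding pattern plumbing is assumed and no TFL ops are built here:

// Illustration only; all four helpers are declared in this header.
void InspectAvgPoolComposite(mlir::Builder& builder, mlir::mhlo::CompositeOp op) {
  // Path 1: peel padding into an explicit tfl.pad feeding tfl.average_pool_2d.
  mlir::DenseIntElementsAttr pad_cfg = mlir::odml::GetPadOpAttr(builder, op);
  mlir::ShapedType padded_type = mlir::odml::GetPadOpType(op);
  // Path 2: let tfl.average_pool_2d handle padding via its own padding attribute.
  mlir::StringAttr pool_padding = mlir::odml::GetAvgPoolOpPadAttr(builder, op);
  // With ceil mode on, divisor over-counting is compensated by this matrix
  // (declared immediately below).
  mlir::DenseFPElementsAttr correction = mlir::odml::GetCorrectionMatrix(builder, op);
  (void)pad_cfg; (void)padded_type; (void)pool_padding; (void)correction;
}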
+DenseFPElementsAttr GetCorrectionMatrix(Builder& builder, mhlo::CompositeOp op); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_AVG_POOL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.h new file mode 100644 index 00000000..0bb758ad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_LOWERING_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_LOWERING_PASS_H_ + +namespace mlir { +namespace odml { + +std::unique_ptr CreateCompositeLoweringPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_LOWERING_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h new file mode 100644 index 00000000..fbd131bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h @@ -0,0 +1,84 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_UTILS_H_ + +#include +#include +#include +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +// Ensure an attribute named attr_name exists and it is of type AttrType. +// If so, sets the `out_attr` pointer to point to the casted attribute. +template +bool EnsureAttribute(const DictionaryAttr& composite_attributes, + const std::string& attr_name, AttrType* out_attr) { + Attribute attr = composite_attributes.get(attr_name); + if (!mlir::isa_and_nonnull(attr)) { + return false; + } + if (AttrType content = mlir::dyn_cast(attr)) { + *out_attr = content; + return true; + } else { + return false; + } +} + +// Changes a DenseIntElementsAttr **containing I64** elements to an I32 Vector. +bool DenseI64AttrToI32Vector(const DenseIntElementsAttr& dense_attr, + std::vector* out_vec); + +// Gets boolean from composite attrs if it exists. +std::optional GetBoolFromCompositeAttr( + const DictionaryAttr& composite_attrs, llvm::StringRef attr_name); + +// Given a DictionaryAttr, checks if it has a DenseIntElementsAttr attribute +// with the name attr_name. If so, extracts its values and stores as a vector +// of int32_t elements. +// Note: This assumes the DenseIntElementsAttr has its values stored as int64_t. +bool GetI32VectorFromDenseI64CompositeAttr( + const DictionaryAttr& composite_attrs, const std::string& attr_name, + std::vector* out_vec); + +// Get a DenseIntElementsAttr of type I64 and convert it to an I32 attribute. +DenseIntElementsAttr DenseI64AttrToI32Attr( + const DenseIntElementsAttr& dense_attr, PatternRewriter& builder); + +// Returns a NHWC shaped type from an NCHW shaped type op. +// For example- Given a Composite op that wraps a core.aten.avg_pool2d, this +// returns the return type of the tfl.average_pool_2d emitted. Note that the +// aten.avg_pool2d works with the NCHW layout while tfl.average_pool_2d assumes +// NHWC. +ShapedType GetNhwcReturnTypeFromNchw(Operation* old_op); + +} // namespace odml + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_COMPOSITE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h new file mode 100644 index 00000000..444a3c46 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_DROP_SAVEDMODEL_SEMANTICS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_DROP_SAVEDMODEL_SEMANTICS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +std::unique_ptr CreateDropSavedModelSemanticsPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_DROP_SAVEDMODEL_SEMANTICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h new file mode 100644 index 00000000..ff91176a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_HLO_MATCHERS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_HLO_MATCHERS_H_ + +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project + +namespace mlir { +namespace odml { +// The following 5 different forms of mhlo::iota will be matched: +// 1. IotaOp. +// 2. IotaOp + BroadCastInDim. +// 3. IotaOp + Reshape. +// 4. Constant (folded Iota) + BroadCastInDim. +// 5. Constant (folded result). +// Moreover, the dimensions has to match the iota_dimension. +bool MatchIota(DenseIntElementsAttr dimensions, Value iota); +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_HLO_MATCHERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h new file mode 100644 index 00000000..0f741d9c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +// Prepares mhlo.convolutions and legalizes to the corresponding tfl op. +// +// Note: "tfl-native" layouts are as follows: +// 2D : [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// 3D : [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] +// 2D (depthwise) : [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f] +// +// Matches: mhlo.convolution +// layout: any (will transpose to tfl-native) +// padding: any (will pull into explicit pad_op) +// lhs_dilations: trivial (all 1) +// rhs_dilations: any +// strides: any +// feature_group: see decision tree below +// batch_group: trivial (1) +// reversal: trivial (all False) +// shape: static, rank 4 or 5 +// +// This pattern emits TFL convs based on the following decision tree: +// if lhs_dilations are trivial && kernel_out_features == output_features +// if feature_group == 1: +// if rank == 5: tfl.conv_3D +// if rank == 4: tfl.conv_2D +// else if input_features == feature_group: +// if rank == 4: tfl.depthwise_conv TODO: b/352954597 - Add support. +// else: +// if rank == 4: tfl.conv_2D +// else: +// tfl.transpose_conv TODO: b/352954597 - Add support. +void PopulateLegalizeConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +void PopulatePrepareConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h new file mode 100644 index 00000000..fe9664c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h @@ -0,0 +1,298 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +// Helpers for working with mhlo.convolution attrs in the mlir api as +// native cc types. + +namespace mlir::odml { + +class ConvView { + public: + // int for each spatial dim. Default 1. + llvm::ArrayRef Strides() const { return strides_; } + + // 2d array for each spatial dim. Default 0. + llvm::ArrayRef Padding() const { return padding_; } + + int64_t BatchGroupCount() const { return batch_group_count_; } + + int64_t FeatureGroupCount() const { return feature_group_count_; } + + // int for each spatial dim. Default 1. + llvm::ArrayRef InputDilations() const { return input_dilations_; } + + // int for each spatial dim. Default 1. + llvm::ArrayRef KernelDilations() const { return kernel_dilations_; } + + // bool for each spatial dim. Default false. + llvm::ArrayRef WindowReversal() const { return window_reversal_; } + + llvm::ArrayRef InputShape() const { return input_shape_; } + + const Layout& InputLayout() const { return input_layout_; } + + llvm::ArrayRef KernelShape() const { return kernel_shape_; } + + const Layout& KernelLayout() const { return kernel_layout_; } + + llvm::ArrayRef OutputShape() const { return output_shape_; } + + const Layout& OutputLayout() const { return output_layout_; } + + mlir::Type ElementType() const { return element_type_; } + + explicit ConvView(mhlo::ConvolutionOp op); + + private: + llvm::SmallVector strides_; + + llvm::SmallVector padding_; + + llvm::SmallVector input_dilations_; + llvm::SmallVector kernel_dilations_; + + llvm::SmallVector window_reversal_; + + Layout input_layout_; + Layout kernel_layout_; + Layout output_layout_; + + llvm::SmallVector input_shape_; + llvm::SmallVector kernel_shape_; + llvm::SmallVector output_shape_; + + int64_t batch_group_count_; + int64_t feature_group_count_; + + mlir::Type element_type_; +}; + +inline bool HasSupportedRank(const ConvView& data) { + return data.InputLayout().Rank() == 4 || data.InputLayout().Rank() == 5; +} + +inline bool HasSupportedOutFeatureDims(const ConvView& data) { + const int64_t kernel_out_features = + data.KernelLayout().SpecialDim2(data.KernelShape()); + const int64_t out_features = + data.OutputLayout().SpecialDim2(data.OutputShape()); + return kernel_out_features == out_features; +} + +inline bool IsTrivialConv(const ConvView& data) { + return llvm::all_of(data.InputDilations(), [](auto d) { return d == 1; }); +} + +// +// Supported non-trivial conv predicates +//=----- + +bool MatchWithResizeBilinearOp(const ConvView& data, bool& align_corners); + +inline bool MatchWithResizeBilinearOp(const ConvView& data) { + bool align_corners = false; + return MatchWithResizeBilinearOp(data, align_corners); +} + +bool IsTransposeConvPaddingValid(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding); + 
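To show how the view and the predicates above compose (mirroring the decision tree documented in conv.h), a hypothetical classification helper that is not part of the header:

// Illustration only; the real patterns use the finer-grained predicates that follow.
inline const char* ClassifyConvSketch(mhlo::ConvolutionOp op) {
  const ConvView data(op);
  if (!HasSupportedRank(data)) return "unsupported rank";
  if (!HasSupportedOutFeatureDims(data)) return "kernel/output feature dims disagree";
  if (!IsTrivialConv(data)) return "non-trivial conv (input dilations present)";
  return data.FeatureGroupCount() == 1 ? "standard conv candidate"
                                       : "grouped or depthwise conv candidate";
}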
+bool IsTransposeConvPaddingSame(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding); + +inline bool IsSupportedNonTrivialConv(const ConvView& data) { + // Only non-trivial 2d convolutions are supported. + const bool valid_rank = data.InputLayout().Rank() == 4; + + // Negative padding is unsupported. + bool has_nagative_padding = llvm::all_of( + data.Padding(), + [](const DimPadding& p) { return p.Hi() < 0 || p.Lo() < 0; }); + + return (valid_rank && !IsTrivialConv(data) && !has_nagative_padding); +} + +inline bool IsSupportedNonTrivialConv(mhlo::ConvolutionOp op) { + const ConvView data(op); + return IsSupportedNonTrivialConv(data); +} + +// +// Standard conv predicates +//=----- + +inline bool HasStandardConvInFeatureDims(const ConvView& data) { + // kernel_in_features * feature_groups = input_features by definition. + const int64_t input_features = + data.InputLayout().SpecialDim2(data.InputShape()); + + const bool trivial_kernel_in_features = + data.FeatureGroupCount() == input_features; + const bool is_grouped_conv = data.FeatureGroupCount() != 1; + + const int64_t rank = data.InputLayout().Rank(); + return !trivial_kernel_in_features && (!is_grouped_conv || rank == 4); +} + +inline bool IsStandardConv(const ConvView& data) { + return HasSupportedRank(data) && IsTrivialConv(data) && + HasStandardConvInFeatureDims(data) && HasSupportedOutFeatureDims(data); +} + +// Does this convolution map to a standard conv_2d or conv_3d +// (not depthwise or tranpose conv)? +inline bool IsStandardConv(mhlo::ConvolutionOp op) { + const ConvView data(op); + return IsStandardConv(data); +} + +// +// Depthwise conv predicates +//=----- + +inline bool IsDepthwiseConv(const ConvView& data) { + const bool valid_rank = data.InputLayout().Rank() == 4; + if (!valid_rank || !HasSupportedOutFeatureDims(data) || + !IsTrivialConv(data)) { + return false; + } + const int64_t in_channel_dim = + data.InputLayout().SpecialDim2(data.InputShape()); + return data.FeatureGroupCount() == in_channel_dim; +} + +// Does this convolution map to depthwise conv? 
+inline bool IsDepthwiseConv(mhlo::ConvolutionOp op) { + const ConvView data(op); + return IsDepthwiseConv(data); +} + +// +// Tfl native layouts +//=----- + +inline int64_t DnumRank(mhlo::ConvDimensionNumbersAttr dnums) { + return dnums.getInputSpatialDimensions().size() + 2; +} + +inline Layout GetTFLNativeInputOrOutputLayout(int64_t rank) { + auto spatials = llvm::to_vector(llvm::seq(1, rank - 1)); + return Layout(0, rank - 1, spatials); +} + +inline Layout GetTFLNativeInputOrOutputLayout( + mhlo::ConvDimensionNumbersAttr dnums) { + return GetTFLNativeInputOrOutputLayout((DnumRank(dnums))); +} + +inline Layout GetTFLNativeStandardConvKernelLayout(int64_t rank) { + if (rank != 5) { + auto spatials = llvm::to_vector(llvm::seq(1, rank - 1)); + return Layout(rank - 1, 0, spatials); + } + auto spatials = llvm::to_vector(llvm::seq(rank - 2)); + return Layout(rank - 2, rank - 1, spatials); +} + +inline Layout GetTFLNativeDepthwiseConvKernelLayout() { + return Layout(0, 3, {1, 2}); +} + +inline Layout GetTFLNativeStandardConvKernelLayout( + mhlo::ConvDimensionNumbersAttr dnums) { + return GetTFLNativeStandardConvKernelLayout(DnumRank(dnums)); +} + +inline bool IsTFLNativeLayout(const ConvView& data) { + const int64_t rank = data.KernelLayout().Rank(); + const auto native_io_layout = GetTFLNativeInputOrOutputLayout(rank); + + std::optional native_kernel_layout = std::nullopt; + if (IsDepthwiseConv(data)) { + native_kernel_layout = GetTFLNativeDepthwiseConvKernelLayout(); + } else if (IsStandardConv(data) || IsSupportedNonTrivialConv(data)) { + native_kernel_layout = GetTFLNativeStandardConvKernelLayout(rank); + } + if (!native_kernel_layout.has_value()) { + return false; + } + + return data.InputLayout() == native_io_layout && + data.KernelLayout() == *native_kernel_layout && + data.OutputLayout() == native_io_layout; +} + +// +// ConvDimensionNumbers utils +//=----- + +inline mhlo::ConvDimensionNumbersAttr CloneDnumsWithInputLayout( + OpBuilder& b, mhlo::ConvDimensionNumbersAttr dnums, const Layout& layout) { + return mhlo::ConvDimensionNumbersAttr::get( + b.getContext(), layout.SpecialDim1(), layout.SpecialDim2(), + layout.Spatials(), dnums.getKernelInputFeatureDimension(), + dnums.getKernelOutputFeatureDimension(), + dnums.getKernelSpatialDimensions(), dnums.getOutputBatchDimension(), + dnums.getOutputFeatureDimension(), dnums.getOutputSpatialDimensions()); +} + +inline mhlo::ConvDimensionNumbersAttr CloneDnumsWithKernelLayout( + OpBuilder& b, mhlo::ConvDimensionNumbersAttr dnums, const Layout& layout) { + return mhlo::ConvDimensionNumbersAttr::get( + b.getContext(), dnums.getInputBatchDimension(), + dnums.getInputFeatureDimension(), dnums.getInputSpatialDimensions(), + layout.SpecialDim1(), layout.SpecialDim2(), layout.Spatials(), + dnums.getOutputBatchDimension(), dnums.getOutputFeatureDimension(), + dnums.getOutputSpatialDimensions()); +} + +inline mhlo::ConvDimensionNumbersAttr CloneDnumsWithOutputLayout( + OpBuilder& b, mhlo::ConvDimensionNumbersAttr dnums, const Layout& layout) { + return mhlo::ConvDimensionNumbersAttr::get( + b.getContext(), dnums.getInputBatchDimension(), + dnums.getInputFeatureDimension(), dnums.getInputSpatialDimensions(), + dnums.getKernelInputFeatureDimension(), + dnums.getKernelOutputFeatureDimension(), + dnums.getKernelSpatialDimensions(), layout.SpecialDim1(), + layout.SpecialDim2(), layout.Spatials()); +} + +// Wraps the lhs of given conv op in an explicit pad op matching the same +// behavior implicit in the paddings attribute. Gets result of new pad op. 
+Value CreatePadOpFromConvPadding(OpBuilder& b, mhlo::ConvolutionOp op); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h new file mode 100644 index 00000000..c7c3bdde --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CUSTOM_CALL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CUSTOM_CALL_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace odml { + +void PopulateCustomCallPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +void PopulateCustomCallPreparePatterns(MLIRContext* ctx, + RewritePatternSet& patterns); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CUSTOM_CALL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h new file mode 100644 index 00000000..91df1b63 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Legalize mhlo.dot_general to tflite.batch_matmul. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_DOT_GENERAL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_DOT_GENERAL_H_ + +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { +// Converts mhlo.dot_general to tfl.BatchMatMul. Reshape and Transpose ops will +// be inserted to convert to well-formed matrix multiply; i.e., mhlo.dot_general +// -> tfl.batch_matmul(mhlo.transpose(mhlo.reshape(operand)), ...). +// Note: +// 1) Reshape/transpose are inserted because tfl.BatchMatMul requires +// size(contracting_dimensions) = 1 and size(output_dim) = 1, whereas +// mhlo.dot_general has no such restriction. +// 2) Inserted mhlo.reshape/transpose will be legalized to tf.reshape/transpose +// in LegalizeHloToTf (then from tf to tfl later). +// 3) If the operands are dynamic shaped tensors, mhlo.DynamicReshapeOp is +// inserted instead of the regular reshape, and additional ops (e.g. Gather, +// Concat ) are inserted for shape inference purposes. +// 4) All the DotOp are converted to DotGeneral during the optimization pass +// (ConvertDotOp). +class LowerDotGeneralOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DotGeneralOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_DOT_GENERAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h new file mode 100644 index 00000000..0c9cf35f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_FFT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_FFT_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +// Patterns to legalize mhlo.fft to TFL. +void PopulateLegalizeFftPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +// Patterns to prepare mhlo.fft to TFL. 
+void PopulatePrepareFftPatterns(MLIRContext* ctx, RewritePatternSet& patterns); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_FFT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h new file mode 100644 index 00000000..35a36613 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GATHER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GATHER_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +// Patterns to legalize mhlo.gather to TFL +// +// Emits: tfl.gather_nd or a combination of tfl.slice, tfl.squeeze, tfl.concat +void PopulateGatherPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GATHER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gelu.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gelu.h new file mode 100644 index 00000000..6dfc67e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gelu.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GELU_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GELU_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir::odml { + +// Matches non-approximate GELU patterns. +// +// -> mul 1/sqrt(2) -> erf -> add 1 -> +// in mul +// ---------> mul 0.5 ---------------> +// +// This pattern assumes all binary ewise ops with one constant argument +// have that constant argument as the second operand. It works by +// identifying `erf` ops and validate the structure around them. +class LowerGELU : public RewritePattern { + public: + explicit LowerGELU(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GELU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h new file mode 100644 index 00000000..6cd63730 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GET_DIMENSION_SIZE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GET_DIMENSION_SIZE_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +void PopulateGetDimensionSizePatterns(MLIRContext* ctx, + RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GET_DIMENSION_SIZE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/if.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/if.h new file mode 100644 index 00000000..459aabf9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/if.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IF_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IF_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { + +// Patterns to legalize mhlo.if to TFL. +void PopulateIfPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h new file mode 100644 index 00000000..7d4f76bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h @@ -0,0 +1,29 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +void PopulateIotaPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h new file mode 100644 index 00000000..9b0e19aa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h @@ -0,0 +1,148 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::odml { + +// Class that encodes the "layout" of a tensor. Layouts, generically +// are some naming of the dimensions of a tensor. In all cases, 2 dimensions +// are "special" (e.g. batch / feature) and the rest are referred to as "spatial +// dims". When the special dims are batch and feature, batch is special dim 1 +// and feature is special dim 2. When special dims are input and output features +// (conv filter), input features is special dim 1 and output features is special +// dim 2. +class Layout { + public: + llvm::ArrayRef Spatials() const { return spatials_; } + + int64_t NumSpatials() const { return spatials_.size(); } + + int64_t Rank() const { return NumSpatials() + 2; } + + Layout(int64_t special_dim1, int64_t special_dim2, ArrayRef spatials) + : special_dim1_(special_dim1), + special_dim2_(special_dim2), + spatials_(spatials) {} + + // TODO: b/351437662 - Consider just using 2 arrays for the case where + // there are more than 2 special dims. + int64_t SpecialDim1() const { return special_dim1_; } + + // Conveniance accesor for getting the dimension size of the first + // special dimension from a shape. 
+ int64_t SpecialDim1(llvm::ArrayRef shape) const { + return shape[special_dim1_]; + } + + int64_t SpecialDim2() const { return special_dim2_; } + + // Convenience accesor for getting the dimension size of the second + // special dimension from a shape. + int64_t SpecialDim2(llvm::ArrayRef shape) const { + return shape[special_dim2_]; + } + + // Conveniance method for equality checking special dims. + bool HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const; + + // Determines if the spatial dimensions are all adjacent and in + // ascending order. + bool AreSpatialsIota() const; + + // Gets a "permutation array" to be used for transposing a tensor + // of "this" layout to the given layout. A permutation array is some + // permutation of [0, 1, i...] for i < rank(layout). Assumes + // "this" and given layout have the same rank. + llvm::SmallVector GetPermForReLayout( + const Layout& to_layout) const; + + // Permutes given shape based on the permutaion implied to take this Layout to + // the given one. + llvm::SmallVector PermuteShape(const Layout& to_layout, + ArrayRef shape) const; + + bool operator==(const Layout& other) const { + return SpecialDim1() == other.SpecialDim1() && + SpecialDim2() == other.SpecialDim2() && + Spatials() == other.Spatials(); + } + + bool operator!=(const Layout& other) const { return !(*this == other); } + + private: + int64_t special_dim1_; + int64_t special_dim2_; + llvm::SmallVector spatials_; +}; + +// Wrapper for the padding attrs along a single dimension. +class DimPadding { + public: + int64_t Hi() const { return hi_; } + + int64_t Lo() const { return lo_; } + + bool Trivial() const { return Hi() == 0 && Lo() == 0; } + + DimPadding(int64_t lo, int64_t hi) : lo_(lo), hi_(hi) {} + + private: + int64_t lo_; + int64_t hi_; +}; + +inline llvm::SmallVector UnrollI64Splat(DenseElementsAttr data) { + if (!data.isSplat()) { + return llvm::SmallVector(data.getValues()); + } + return llvm::SmallVector(data.getType().getNumElements(), + data.getSplatValue()); +} + +// Resolves optional strides or dilations attributes. If not present, +// will return trivial 1's vector. +llvm::SmallVector ResolveStridesOrDilations( + int64_t rank, std::optional opt_attr); + +// Resolves optional paddings attributes. If not present, will return +// trivial [0, 0] paddings on each dim. +llvm::SmallVector ResolvePadding( + int64_t rank, std::optional opt_padding); + +// Does the padding correspond to "SAME" on given dimension configuration. +// Assumes given dimension configuration is well formed. +bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k, + const DimPadding& pad); + +template +inline DenseElementsAttr BuildScalarDense(Type e_type, T val) { + auto type = RankedTensorType::get({}, e_type); + return DenseElementsAttr::get(type, val); +} + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h new file mode 100644 index 00000000..a9c0940b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { + +// Patterns to legalize mhlo.pad to TFL +// +// Prefers tfl.pad over tfl.padv2 when it can be asserted that the pad +// values are zero. +// +// Matches: mhlo.pad +// padding_high/low: all positive or zero +//. interior_padding: all zero +// +// Emits: tfl.pad, tfl.padv2 +void PopulatePadPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h new file mode 100644 index 00000000..50419039 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_UTIL_H_ + +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { + +// Gets elements corresponding to slice starts from negative padding +// values. +DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op); + +// Gets elements corresponding to slice ends from negative padding +// values. +DenseIntElementsAttr SliceEndFromNegPadHighs(mhlo::PadOp op); + +// Gets a copy of `data` with negative values replaced with 0. 
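
// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of this patch. pad.h
// above prefers tfl.pad when the edge padding is non-negative, the interior
// padding is zero, and the pad value is provably zero; pad_util.h handles
// negative edge padding by splitting it into a non-negative pad followed by a
// slice (cf. SliceStartFromNegPadLows / SliceEndFromNegPadHighs above). The
// standalone snippet below shows that split for a single dimension;
// `PadThenSlice` and `SplitNegativePad` are hypothetical names.
#include <algorithm>
#include <cstdint>
#include <iostream>

// For one dimension with possibly negative edge padding (and zero interior
// padding), emit
//   1) a non-negative pad:     lo_pad = max(lo, 0), hi_pad = max(hi, 0)
//   2) a slice of the result:  start  = max(-lo, 0), size = in + lo + hi
// so that pad(lo, hi) == slice(pad(lo_pad, hi_pad)).
struct PadThenSlice {
  int64_t lo_pad, hi_pad;  // non-negative padding (tfl.pad / tfl.padv2)
  int64_t start, size;     // slice emitted when any padding was negative
};

PadThenSlice SplitNegativePad(int64_t in, int64_t lo, int64_t hi) {
  PadThenSlice r;
  r.lo_pad = std::max<int64_t>(lo, 0);
  r.hi_pad = std::max<int64_t>(hi, 0);
  r.start = std::max<int64_t>(-lo, 0);
  r.size = in + lo + hi;  // final extent; assumes in + lo + hi >= 0
  return r;
}

int main() {
  // in = 8, lo = -2 (drop two leading elements), hi = 3 (append three pads).
  PadThenSlice r = SplitNegativePad(8, -2, 3);
  std::cout << r.lo_pad << ' ' << r.hi_pad << ' '  // 0 3
            << r.start << ' ' << r.size << '\n';   // 2 9
  return 0;
}
// ---------------------------------------------------------------------------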
+DenseIntElementsAttr ReplaceNegsWithZero(DenseElementsAttr data); + +bool AnyNegativePads(mhlo::PadOp op); + +bool TrivialInterior(mhlo::PadOp op); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_PAD_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h new file mode 100644 index 00000000..3bf03aec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +void PopulateReduceArgMinMaxTFPatterns(MLIRContext* ctx, + RewritePatternSet& patterns); + +void PopulateReducePatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h new file mode 100644 index 00000000..ccc9c27f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +// Patterns to legalize mhlo.reduce_window to TFL. +// +// Maps the following representations of AvgPool in MHLO into a tfl.avg_pool +// operation when they cleanly map to 2D or 3D average pool with VALID or SAME +// padding: +// * div(reduce_sum_window(x), constant(sizeof(window))) +// * div(reduce_sum_window(x), reduce_sum_window(constant(1))) +// +// Emits: tfl.average_pool2d +void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx, + RewritePatternSet& patterns, + ConversionTarget& target); + +// Patterns to prepare mhlo.reduce_window for legalization. +// Transposes reduce_windows to be NHWC. +// +// Emits: tfl.transpose +void PopulatePrepareReduceWindowPatterns(MLIRContext* ctx, + RewritePatternSet& patterns); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h new file mode 100644 index 00000000..69834345 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h @@ -0,0 +1,63 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +// Helpers for working with mhlo.reduce_window attrs in the mlir api as +// native cc types. 
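
// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of this patch.
// reduce_window.h above recognizes AvgPool as div(reduce_sum_window(x), c),
// where c is either the constant window size (VALID padding) or
// reduce_sum_window(ones) (SAME padding, so border windows are divided by the
// number of in-bounds elements only). The standalone 1-D snippet below mirrors
// that formulation; `AvgPool1D` is a hypothetical name for the example.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<double> AvgPool1D(const std::vector<double>& x, int64_t window,
                              int64_t stride, int64_t pad_lo, int64_t pad_hi) {
  const int64_t n = static_cast<int64_t>(x.size());
  const int64_t out = (n + pad_lo + pad_hi - window) / stride + 1;
  std::vector<double> result(out);
  for (int64_t o = 0; o < out; ++o) {
    double sum = 0.0;    // reduce_sum_window(x)
    double count = 0.0;  // reduce_sum_window(constant(1)) over the same window
    for (int64_t k = 0; k < window; ++k) {
      const int64_t i = o * stride - pad_lo + k;
      if (i >= 0 && i < n) {
        sum += x[i];
        count += 1.0;
      }
    }
    result[o] = sum / count;
  }
  return result;
}

int main() {
  // Window 3, stride 1, SAME-style padding of 1 on each side of 4 elements.
  for (double v : AvgPool1D({1, 2, 3, 4}, 3, 1, 1, 1)) std::cout << v << ' ';
  std::cout << '\n';  // 1.5 2 3 3.5
  return 0;
}
// ---------------------------------------------------------------------------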
+ +namespace mlir::odml { + +class ReduceWindowView { + public: + explicit ReduceWindowView(mhlo::ReduceWindowOp op); + + llvm::ArrayRef WindowDims() const { return window_dims_; } + int64_t WindowSize() const { return window_size_; } + llvm::ArrayRef WindowStrides() const { return window_strides_; } + llvm::ArrayRef Paddings() const { return paddings_; } + llvm::ArrayRef WindowDilations() const { return window_dilations_; } + llvm::ArrayRef BaseDilations() const { return base_dilations_; } + int64_t Rank() const { return rank_; } + + std::optional GuessLayout() const; + + private: + int64_t rank_; + + llvm::SmallVector window_dims_; + llvm::SmallVector window_strides_; + llvm::SmallVector window_dilations_; + + llvm::SmallVector paddings_; + + llvm::SmallVector base_dilations_; + + int64_t window_size_; +}; + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h new file mode 100644 index 00000000..a7363c68 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h @@ -0,0 +1,109 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Convert updates into canonical form as expected by tf.scatter ops. +// +// tf.scatter expects `update_window_dims` to be the trailing dimensions. +// +// To support scatter ops generated by numpy-like slice updates: +// nd_array[:, [i,j]] = [i_values, j_values] +// +// `updates` must be transposed when the update_window_dims are the leading +// dimensions of `updates`. +// +// Other values of `update_window_dims` are left unsupported. +// +// Eg 1. 
An update in canonical form: +// * indices shape(A,B,C) +// * updates shape(A,B,D,E,F) +// Then: +// * D,E,F are the update window dims [2,3,4] +// * C is the index vector dimension +// * A,B iterate over the updates and indices +// +// If `update_window_dims` are not the trailing dimensions then updates must be +// transposed. +// +// Eg 2. An update in non-canonical form: +// * indices shape(a,b,c) +// * updates shape(d,e,f,a,b) +// Then: +// * d,e,f are the update window dims [0,1,2] +// * c is the index vector dimension +// * a,b iterate over the updates and indices +// +// The update needs permuting to be in the form (a,b,d,e,f) so that the update +// window dims are the trailing dimensions. +// +// To canonicalize the updates above, replace the updates with: +// transpose(updates, permutation={3,4,0,1,2}) +// +// Note: NormalizeIndexVector is assumed to have run on the indices already so +// that the index_vector_dim is the trailing dimension in `indices`. +LogicalResult CanonicalizeScatterUpdates( + Operation* scatter_op, llvm::ArrayRef update_window_dims, + const Value& indices, const ShapedType& indices_type, Value& updates, + ShapedType& updates_type, ConversionPatternRewriter& rewriter); + +template +class ConvertScatterOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ScatterOp scatter_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +using ConvertScatterAddOp = + ConvertScatterOp; +using ConvertScatterMaxOp = + ConvertScatterOp; +using ConvertScatterMinOp = + ConvertScatterOp; +using ConvertScatterSubOp = + ConvertScatterOp; +using ConvertScatterUpdateOp = + ConvertScatterOp; + +template class ConvertScatterOp; +template class ConvertScatterOp; +template class ConvertScatterOp; +template class ConvertScatterOp; +template class ConvertScatterOp; + +} // end namespace odml +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SCATTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h new file mode 100644 index 00000000..024cbb4a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { + +// Patterns to legalize mhlo.slice to TFL. +void PopulateLegalizeSlicePatterns(MLIRContext* ctx, + RewritePatternSet& patterns, + ConversionTarget& target); + +void PopulatePrepareSlicePatterns(MLIRContext* ctx, + RewritePatternSet& patterns); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h new file mode 100644 index 00000000..c293bad9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h @@ -0,0 +1,29 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +void PopulateSortPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h new file mode 100644 index 00000000..c72fce3f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h @@ -0,0 +1,168 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_ + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +struct PermutationAndShape { + DenseIntElementsAttr permutation; + ShapedType shape; +}; + +// Check that `arr` is an R1 iota with integer element type starting from +// `start` with `size` number of values. +bool IsIotaAttr(ArrayRef arr, int64_t size, int64_t start = 0); + +// Returns a DenseIntElementsAttr for a permutation and the shape after +// applying the permutation to a given shape through a transpose. +PermutationAndShape GetPermutationAndTransposedShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter& rewriter); + +// Create a single const integer. +Value BuildIntConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, int64_t const_value, + Type type); + +// Create a const integer vector tensor (1-dim). +template +Value BuildIntArrayConstOp(ImplicitLocOpBuilder& builder, + ConversionPatternRewriter& rewriter, + ArrayRef const_value, Type type) { + DenseIntElementsAttr const_value_raw; + if (type == rewriter.getI64Type()) { + const_value_raw = rewriter.getI64TensorAttr(const_value); + } else { + // Convert I64 const array to I32. + llvm::SmallVector const_i32_vec; + for (auto element : const_value) { + const_i32_vec.push_back(static_cast(element)); + } + const_value_raw = rewriter.getI32TensorAttr(const_i32_vec); + } + Value result_const = builder.create(const_value_raw); + return result_const; +} + +// Returns the inverse permutation array for a permutation array. +llvm::SmallVector GetInversePermutationArray( + llvm::ArrayRef permutation_array); + +// Returns the DenseIntElementsAttr for an inverse permutation given a +// permutation_array. +DenseIntElementsAttr GetInversePermutation( + llvm::ArrayRef permutation_array, + ConversionPatternRewriter& rewriter); + +// Returns a DenseIntElementsAttr for an inverse permutation and the shape after +// applying the inverse permutation to a given shape through a transpose. +PermutationAndShape GetInversePermutationAndShape( + llvm::ArrayRef permutation_array, ShapedType input_type, + ConversionPatternRewriter& rewriter); + +// Returns true if the op needs reformat. 
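
// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of this patch. util.h
// above declares inverse-permutation helpers, and scatter.h earlier in this
// patch canonicalizes updates by transposing the leading update_window_dims to
// the back (its rank-5 example uses permutation {3,4,0,1,2}). The standalone
// snippet below shows both computations; the function names are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

// Inverse of a permutation p: inv[p[i]] = i.
std::vector<int64_t> InversePermutation(const std::vector<int64_t>& p) {
  std::vector<int64_t> inv(p.size());
  for (int64_t i = 0; i < static_cast<int64_t>(p.size()); ++i) inv[p[i]] = i;
  return inv;
}

// Permutation that moves leading window dims [0, num_window) of a rank-`rank`
// updates tensor behind the batch dims, making the window dims trailing.
std::vector<int64_t> CanonicalizeUpdatesPerm(int64_t rank, int64_t num_window) {
  std::vector<int64_t> perm;
  for (int64_t d = num_window; d < rank; ++d) perm.push_back(d);  // batch dims
  for (int64_t d = 0; d < num_window; ++d) perm.push_back(d);     // window dims
  return perm;
}

int main() {
  for (int64_t d : CanonicalizeUpdatesPerm(5, 3)) std::cout << d << ' ';
  std::cout << '\n';  // 3 4 0 1 2
  for (int64_t d : InversePermutation({3, 4, 0, 1, 2})) std::cout << d << ' ';
  std::cout << '\n';  // 2 3 4 0 1
  return 0;
}
// ---------------------------------------------------------------------------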
+bool NeedsReformatTypeAndPermutation(int batch_dim, int feature_dim, + int spatial_dim_start, + int default_batch_dim, + int default_feature_dim, + int default_spatial_dim_start); + +// Gets reformat type and permutation attribute. Call this function only if +// NeedsReformatTypeAndPermutation returns true. If +// NeedsReformatTypeAndPermutation returns false, this function returns the pair +// of input type and no-op permutation. + +std::pair GetReformatTypeAndPermutation( + int batch_dim, int feature_dim, int spatial_dim_start, + int default_batch_dim, int default_feature_dim, + int default_spatial_dim_start, int num_spatial_dims, RankedTensorType type, + ConversionPatternRewriter& rewriter); + +// Insert transpose so the input value is converted to the format specified by +// the default dims +Value InsertTranspose(Value value, int batch_dim, int feature_dim, + ArrayRef spatial_dimensions, + int default_batch_dim, int default_feature_dim, + int default_spatial_dim_start, int num_spatial_dims, + ConversionPatternRewriter& rewriter); + +// If index_vector_dim == indices.rank() then insert the implicit extra +// dimension into indices to normalize everything to index_vector_dim == +// indices.rank() - 1. +LogicalResult NormalizeIndexVector(Operation* parent_op, Value& indices, + ShapedType& indices_type, + int64_t index_vector_dim, + ConversionPatternRewriter& rewriter); + +// Checks if the specified region is a binary reduction function that takes 2 +// inputs, passes it to an instance of the specified reduction op and then +// returns the result. +template +LogicalResult MatchBinaryReduceFunction(mlir::Region& function) { + Block& body = function.front(); + if (body.getNumArguments() != 2) return failure(); + + mhlo::ReturnOp return_op = dyn_cast(body.back()); + if (!return_op) return failure(); + if (return_op.getNumOperands() != 1) return failure(); + + ReductionOp reduce_op = dyn_cast_or_null( + return_op.getOperands().front().getDefiningOp()); + if (!reduce_op) return failure(); + if (reduce_op.getLhs() != body.getArgument(0) || + reduce_op.getRhs() != body.getArgument(1)) + return failure(); + + return success(); +} + +// Check if the specified region is a binary reduction function that takes 2 +// inputs and returns the second input. Functions like this are used by update +// scatter like ops. +template <> +LogicalResult MatchBinaryReduceFunction(mlir::Region& function); + +// Util that casts 'val' to Int32 by adding a tfl cast Op. +Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter); + +// Replaces `region`'s terminator to TFL::Yield. +void ReplaceTerminatorWithYield(Region& region, PatternRewriter& rewriter); +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h new file mode 100644 index 00000000..3b302215 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_WHILE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_WHILE_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { + +void PopulateWhilePatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_WHILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h new file mode 100644 index 00000000..9594769e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { + +namespace func { +class FuncOp; +} +class ModuleOp; +class Operation; +template +class OperationPass; +class Pass; + +namespace odml { + +/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern +/// list. 
+void PopulateLegalizeTfPatterns(MLIRContext* context, + RewritePatternSet* patterns); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h new file mode 100644 index 00000000..9bcee095 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Adds passes which transform TF_XlaCallModule Op to StableHLO Ops. +// Note that this pass only supports static shape tensors for now. +std::unique_ptr> +CreateLegalizeTFXlaCallModuleToStablehloPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h new file mode 100644 index 00000000..8d57016b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_OP_STAT_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_OP_STAT_PASS_H_ + +#include +#include +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Creates a pass which prints out a detailed report of conversion stats with: +// success or not, % of Ops non-converted, list of non-converted Ops, etc. +std::unique_ptr createPrintOpStatsPass( + std::vector accepted_dialects); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_OP_STAT_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h new file mode 100644 index 00000000..e56b7130 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +std::unique_ptr CreateRenameEntrypointToMainPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h new file mode 100644 index 00000000..61e076e8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_SMUGGLE_DISALLOWED_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_SMUGGLE_DISALLOWED_OPS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +std::unique_ptr CreateSmuggleDisallowedOpsPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_SMUGGLE_DISALLOWED_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h new file mode 100644 index 00000000..7a02085c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h @@ -0,0 +1,84 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Unfuses MHLO batch norm inference op into arithmetic ops. +std::unique_ptr createUnfuseBatchNormPass(); + +// Constant folds broadcast_in_dim op conditionally. +std::unique_ptr createFoldBroadcastPass(); + +// Fuses MHLO binary element-wise ops and convolution op. +std::unique_ptr createFuseConvolutionPass(); + +// Applies various optimizations on MHLO IR. +std::unique_ptr createOptimizePass(); + +// Finds quantization patterns and compose them to uniform +// quantized types. +std::unique_ptr> +CreateComposeUniformQuantizedTypePass(); + +// Finds stablehlo ops that accept or produce uniform +// quantized typed tensors and converts them to equivalent ops in the TFLite +// dialect. +std::unique_ptr> +CreateUniformQuantizedStableHloToTflPass(); + +// Commutes transposes through specific ops +std::unique_ptr> CreateTransposeCommuteOpsPass(); + +// Legalizes MHLO to TF dialect. +std::unique_ptr> CreateLegalizeHloToTfPass(); + +// Replaces a splat constant tensor with a BroadcastInDim +// op. +std::unique_ptr> CreateUnfoldSplatConstantPass(); + +// Legalizes MHLO to TFLite dialect. +std::unique_ptr> CreateLegalizeHloToTfLitePass(); + +// Lowers stablehlo composite ops to tflite ops. +std::unique_ptr> CreateCompositeLoweringPass(); + +// Legalizes CHLO to tflite dialect. +std::unique_ptr> CreateLegalizeChloToTflPass(); + +// Rewrites MHLO in preparation for tflite legalization. +std::unique_ptr> CreatePrepareHloPass(); + +// Adds the HLO to TF rewrite patterns to the specified pattern list. 
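
// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of this patch. The
// headers above only declare pass factories (CreateRenameEntrypointToMainPass,
// CreatePrepareHloPass, createUnfuseBatchNormPass, createOptimizePass,
// CreateLegalizeHloToTfLitePass, ...); a caller assembles them into a pass
// pipeline. The ordering and the module- vs. function-level nesting below are
// assumptions made for the example, not the converter's real pipeline, and the
// snippet only builds inside the TensorFlow/TFLite tree.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir::odml {

// Hypothetical pipeline assembly over the factories declared above.
LogicalResult RunExampleMhloToTflPipeline(ModuleOp module, MLIRContext* ctx) {
  PassManager pm(ctx);
  pm.addPass(CreateRenameEntrypointToMainPass());
  pm.addPass(CreatePrepareHloPass());
  // These two are assumed here to be function-level passes.
  pm.addNestedPass<func::FuncOp>(createUnfuseBatchNormPass());
  pm.addNestedPass<func::FuncOp>(createOptimizePass());
  pm.addPass(CreateLegalizeHloToTfLitePass());
  return pm.run(module);
}

}  // namespace mlir::odml
// ---------------------------------------------------------------------------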
+void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, + MLIRContext* context); + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h new file mode 100644 index 00000000..066bcc00 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h @@ -0,0 +1,43 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_UTIL_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace odml { + +std::vector GetAcceptedStableHLODialects(); + +std::vector GetAcceptedTFLiteDialects(); + +// Can we find the given `dialect_name` in the `accepted_dialects`? +bool IsAcceptedDialect(llvm::StringRef dialect_name, + const std::vector &accepted_dialects); + +// The consolidated logic to verify if each final op is acceptable or not. +// Also see `PrintOpStatsPass` and `CheckAcceptedOpsPass`. +bool IsAcceptedOp(llvm::StringRef dialect_name, llvm::StringRef op_name, + const std::vector &accepted_dialects); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h new file mode 100644 index 00000000..c26a3f36 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Adds passes which transform TF Ops to StableHLO Ops. +void AddLegalizeTFToStablehloPasses(OpPassManager& pm, + bool skip_quantization_ops, + bool skip_resize, + bool skip_partitioned_calls); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h new file mode 100644 index 00000000..e6e40762 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_STABLEHLO_PASS_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Creates a pass which transforms TFLite to StableHLO Ops. +std::unique_ptr> CreateTflToStablehloPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TFL_STABLEHLO_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h new file mode 100644 index 00000000..abcdd827 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h @@ -0,0 +1,44 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
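Reviewer note: AddLegalizeTFToStablehloPasses above only populates a pass manager; a caller still has to build one and run it over a module. A hedged sketch under that assumption (the flag values and the module-scoped PassManager are illustrative choices, not mandated by the header):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h"

// Hedged sketch: legalize a TF module to StableHLO in place.
static mlir::LogicalResult LegalizeTfModuleToStablehlo(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::odml::AddLegalizeTFToStablehloPasses(
      pm,
      /*skip_quantization_ops=*/false,
      /*skip_resize=*/false,
      /*skip_partitioned_calls=*/false);
  return pm.run(module);
}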
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TRANSFORMS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TRANSFORMS_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Adds all the necessary passes to lower a TF module to StableHLO. +// `skip_resize` enables or disables skipping conversion of tf.ResizeBilinear +// and tf.ResizeNearestNeighbor ops. +// `smuggle_disallowed_ops` enables or disables converting disallowed ops +// like tf.ResizeBilinear or tf.ResizeNearestNeighbor to mhlo.custom_call ops. +void AddTFToStablehloPasses(OpPassManager& pm, bool skip_resize, + bool smuggle_disallowed_ops); + +// This function is a common entry point for all graph optimizations that are +// not specific to any hardware. It legalizes SHLO->MHLO, does MHLO->MHLO +// optimizations by calling `AddMhloOptimizationPasses` internally, and +// legalizes MHLO->SHLO +void AddStablehloOptimizationPasses(OpPassManager& pm); + +// Adds all the backend-agonstic stableHLO optimization passes +void AddMhloOptimizationPasses(OpPassManager& pm, bool add_fold_broadcast_pass); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TRANSFORMS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h new file mode 100644 index 00000000..fc7c2316 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Builds body for reduce op by using the template binary op as the +// reducer op. +template +void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { + OpBuilder::InsertionGuard guard(*builder); + Block* block = builder->createBlock(body); + + // Block arguments are scalars of the given element type. 
+ Type type = RankedTensorType::get(/*shape=*/{}, element_type); + Location loc = body->getLoc(); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = + builder->create(loc, block->getArgument(0), block->getArgument(1)); + builder->create(loc, reducer.getResult()); +} + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder); + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder); + +// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); +DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, + Builder* builder); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stateful_error_reporter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stateful_error_reporter.h new file mode 100644 index 00000000..fbb82d3e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/stateful_error_reporter.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ + +// LINT.IfChange +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite_migration { + +// Similar to tflite::ErrorReporter, except that it allows callers to get the +// last error message. +class StatefulErrorReporter : public tflite::ErrorReporter { + public: + // Returns last error message. Returns empty string if no error is reported. + virtual std::string message() = 0; +}; + +} // namespace tflite_migration +// LINT.ThenChange(//tensorflow/lite/stateful_error_reporter.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_tfl_passes.h new file mode 100644 index 00000000..3ad5e52b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
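Reviewer note: StatefulErrorReporter above only adds the message() accessor; a concrete reporter still has to implement the base class's Report hook. A minimal sketch of one possible subclass follows; the Report(const char*, va_list) signature is assumed from the upstream tflite::ErrorReporter interface, and the class name and buffer size are illustrative.

#include <cstdarg>
#include <cstdio>
#include <string>

#include "tensorflow/compiler/mlir/lite/stateful_error_reporter.h"

namespace tflite_migration {

// Hedged sketch: buffers the most recent formatted error message so callers
// can retrieve it through message().
class BufferedErrorReporter : public StatefulErrorReporter {
 public:
  int Report(const char* format, va_list args) override {
    char buffer[1024];
    int size = vsnprintf(buffer, sizeof(buffer), format, args);
    if (size > 0) last_message_ = buffer;  // vsnprintf null-terminates.
    return size;
  }
  std::string message() override { return last_message_; }

 private:
  std::string last_message_;
};

}  // namespace tflite_migration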
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" + +namespace tensorflow { + +// Add the TF to TFLite passes, specified in the pass_config, into a +// pass_manager. The session object will be provided when the TF MLIR is +// imported from saved model version one and utilized for capturing resource +// variables. If the `saved_model_dir` directory path is provided, then the +// `tf_saved_model.asset` ops will be freezed. +void AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir, + const tflite::ConverterFlags& converter_flags, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager* pass_manager); + +// Adds the first portion of StableHLO->TF passes happening before quantization. +// The `pass_manager` that runs on a `mlir::ModuleOp` expects a graph containing +// a `mlir::TF::XlaCallModuleOp` with serialized StableHLO module. The resulting +// `mlir::ModuleOp` after running these passes will be an MHLO module, or a +// StableHLO module if `pass_config.enable_stablehlo_quantizer` is `true`. This +// is because StableHLO Quantizer accepts StableHLO modules. +void AddPreQuantizationStableHloToTfPasses( + mlir::StringRef entry_function_name, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); + +// Adds the second portion of StableHlo->TF passes happening after quantization. +// The input module is expected to be an MHLO module, or a quantized StableHLO +// graph (expressed as `mlir::TF::XlaCallModuleOp`s) if +// `pass_config.enable_stablehlo_quantizer` is `true`. +void AddPostQuantizationStableHloToTfPasses( + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); + +// This is the early part of the conversion in isolation. This enables a caller +// to inject more information in the middle of the conversion before resuming it +// (like freezing variables for example). +void AddPreVariableFreezingTFToTFLConversionPasses( + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager* pass_manager); + +// This is the later part of the conversion in isolation. This enables a caller +// to resume the conversion after injecting more information in the middle of +// it. +void AddPostVariableFreezingTFToTFLConversionPasses( + llvm::StringRef saved_model_dir, + const tflite::ConverterFlags& converter_flags, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager* pass_manager); + +// Adds the passes that freeze variables from global tensors and unfreeze +// mutable global tensors. `pass_config` is used to determine whether to freeze +// variables and `pass_manager` will be populated with the passes to run. +void AddVariableFreezingFromGlobalTensorsPasses( + const tflite::ConverterFlags& converter_flags, + const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager* pass_manager); + +// Simplified API for TF->TFLite conversion with default flags. +void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager* pass_manager); + +// Add the Quantization passes, specified in the pass_config, into a pass +// manager. 
+void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); + +// Add the DynamicRangeQuantization passes, specified in the pass_config, into a +// pass manager. +void AddDynamicRangeQuantizationPasses(const mlir::TFL::PassConfig& pass_config, + mlir::OpPassManager& pass_manager); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h new file mode 100644 index 00000000..e002fd34 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_ + +// This file contains command-line options aimed to provide the parameters +// required by the TensorFlow Graph(Def) to TF Lite Flatbuffer conversion. It is +// only intended to be included by binaries. + +#include + +#include "llvm/Support/CommandLine.h" + +// The commandline options are defined in LLVM style, so the caller should +// use llvm::InitLLVM to initialize the options. +// +// Please see the implementation file for documentation of details of these +// options. +// TODO(jpienaar): Revise the command line option parsing here. +extern llvm::cl::opt input_file_name; +extern llvm::cl::opt output_file_name; +extern llvm::cl::opt use_splatted_constant; +extern llvm::cl::opt input_mlir; +extern llvm::cl::opt output_mlir; +extern llvm::cl::list custom_opdefs; +extern llvm::cl::opt emit_quant_adaptor_ops; +extern llvm::cl::opt quant_stats_file_name; +extern llvm::cl::opt convert_tf_while_to_tfl_while; +extern llvm::cl::opt select_user_tf_ops; +extern llvm::cl::opt allow_all_select_tf_ops; +extern llvm::cl::opt unfold_batchmatmul; +extern llvm::cl::opt unfold_large_splat_constant; +extern llvm::cl::opt guarantee_all_funcs_one_use; +extern llvm::cl::opt enable_dynamic_update_slice; +extern llvm::cl::opt preserve_assert_op; +extern llvm::cl::opt legalize_custom_tensor_list_ops; +extern llvm::cl::opt reduce_type_precision; + +// Import saved model. +extern llvm::cl::opt import_saved_model_object_graph; +extern llvm::cl::opt import_saved_model_signature_defs; +extern llvm::cl::opt saved_model_tags; +extern llvm::cl::opt saved_model_exported_names; + +// Import HLO. +enum HloImportType { proto, hlotxt, mlir_text }; + +extern llvm::cl::opt import_hlo; +extern llvm::cl::opt hlo_import_type; + +// enable_hlo_to_tf_conversion and disable_hlo_to_tfl_conversion are used to +// control the HLO to TF and HLO to TFLite conversion while debugging an +// input_mlir. 
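Reviewer note: as a usage illustration for the pipeline builders declared above, here is a hedged sketch of the simplified TF-to-TFLite flow. It deliberately takes an already-populated PassConfig so it does not assume anything about PassConfig's constructors; the wrapper function itself is illustrative.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"

// Hedged sketch: lower a TF module to the TFLite dialect using the simplified
// AddTFToTFLConversionPasses overload declared above.
static mlir::LogicalResult RunTfToTflPipeline(
    mlir::ModuleOp module, const mlir::TFL::PassConfig& pass_config) {
  mlir::PassManager pm(module.getContext());
  tensorflow::AddTFToTFLConversionPasses(pass_config, &pm);
  return pm.run(module);
}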
The default value of enable_hlo_to_tf_conversion is false, and +// the default value of disable_hlo_to_tfl_conversion is true. +extern llvm::cl::opt enable_hlo_to_tf_conversion; +extern llvm::cl::opt disable_hlo_to_tfl_conversion; + +// quantization related flags +extern llvm::cl::opt post_training_quantization; + +// TF to stablehlo pass flags +extern llvm::cl::opt enable_stablehlo_conversion; + +// Whether serialize stablehlo ops or not +extern llvm::cl::opt serialize_stablehlo_ops; +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_TRANSLATE_CL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h new file mode 100644 index 00000000..ec8569a1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -0,0 +1,91 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Load a TF model from a GraphDef definition or a TF control flow dialect MLIR +// source into a MLIR module. If `input_mlir` is true, load from a MLIR source +// file; otherwise, load from a GraphDef. +// Setting prune_unused_nodes to true, would prune unreachable nodes if +// output_arrays is specified. +absl::StatusOr> LoadFromGraphdefOrMlirSource( + const std::string& input_filename, bool input_mlir, + bool use_splatted_constant, const std::vector& extra_tf_opdefs, + const GraphImportConfig& specs, absl::string_view debug_info_file, + absl::string_view input_arrays, absl::string_view input_dtypes, + absl::string_view input_shapes, absl::string_view output_arrays, + absl::string_view control_output_arrays, llvm::SourceMgr* source_mgr, + mlir::MLIRContext* context); + +// Load Saved model (either v1 or v2) into MLIR. +// 'saved_model_bundle' will be initialized if V1 model was loaded. 
+absl::StatusOr> ImportSavedModel( + const std::string& input_filename, int saved_model_version, + const std::unordered_set& tags, + absl::Span extra_tf_opdefs, + absl::Span exported_names, const GraphImportConfig& specs, + bool enable_variable_lifting, mlir::MLIRContext* context, + std::unique_ptr* saved_model_bundle); + +// Taking a MLIR module in TF executor dialect and a set of parameters, +// applies a set of passes (configured accordingly to the provided +// `pass_config`) to convert the module to TF Lite dialect and serializes the +// result to a string. Depending on an attribute in the module main function, +// full integer quantization is applied. +// * `quantizated_buffer_type` can be set to INT8 or FLOAT16 to trigger the +// corresponding weight quantization. +// * `export_to_mlir` enables exporting to MLIR text format, otherwise exported +// in flat buffer. If the +// * `session` pointer may provided, it will be used to freeze resource +// variables. If the `saved_model_dir` directory path is provided, then the +// `tf_saved_model.asset` ops will be freezed. +absl::Status ConvertTFExecutorToTFLOrFlatbuffer( + std::unique_ptr&& context, + mlir::OwningOpRef module, + tflite::ConverterFlags& converter_flags, + const mlir::TFL::PassConfig& pass_config, + const std::unordered_set& saved_model_tags, + llvm::StringRef saved_model_dir, std::string* result, + bool serialize_stablehlo_ops, bool export_to_mlir, + const quantization::PyFunctionLibrary* quantization_py_function_lib = + nullptr); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/command_line_flags.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/command_line_flags.h new file mode 100644 index 00000000..41e70c94 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/command_line_flags.h @@ -0,0 +1,170 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ + +#include +#include +#include + +// TODO(b/321735756): Remove this file once common library is implemented with +// the originial file. + +// LINT.IfChange + +namespace mlir { +// A simple command-line argument parsing module. +// Dependency free simplified port of core/util/command_line_flags. +// This class is written for benchmarks and uses inefficient string +// concatenation. This was written to avoid dependency on tensorflow/core/util +// which transitively brings in a lot of other dependencies that are not +// necessary for tflite benchmarking code. 
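Reviewer note: a hedged sketch of how ImportSavedModel might be invoked for a v2 SavedModel. The template arguments stripped from the declarations in this diff are assumed here to be OwningOpRef<mlir::ModuleOp>, std::string containers/spans, and SavedModelBundle; the tag "serve", signature name "serving_default", and the wrapper function are illustrative. A real caller would keep the SavedModelBundle alive for later variable freezing rather than letting it go out of scope.

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"

// Hedged sketch: import a SavedModel into an MLIR module for conversion.
static absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportForConversion(
    const std::string& saved_model_path, mlir::MLIRContext* context) {
  std::unordered_set<std::string> tags = {"serve"};
  std::vector<std::string> exported_names = {"serving_default"};
  tensorflow::GraphImportConfig specs;
  std::unique_ptr<tensorflow::SavedModelBundle> bundle;
  return tensorflow::ImportSavedModel(
      saved_model_path, /*saved_model_version=*/2, tags,
      /*extra_tf_opdefs=*/{}, absl::MakeSpan(exported_names), specs,
      /*enable_variable_lifting=*/true, context, &bundle);
}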
+// The recommended way of using it is with local variables and an initializer +// list of Flag objects, for example: +// +// int some_int = 10; +// bool some_switch = false; +// std::string some_name = "something"; +// +// std::vector flag_list = { +// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"), +// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"), +// Flag::CreateFlag("some_name", &some_name, "a string that affects Z") +// }; +// // Get usage message before ParseFlags() to capture default values. +// std::string usage = Flag::Usage(argv[0], flag_list); +// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list); +// +// tensorflow::port::InitMain(usage.c_str(), &argc, &argv); +// if (argc != 1 || !parsed_values_ok) { +// ...output usage and error message... +// } +// +// The argc and argv values are adjusted by the Parse function so all that +// remains is the program name (at argv[0]) and any unknown arguments fill the +// rest of the array. This means you can check for flags that weren't understood +// by seeing if argv is greater than 1. +// The result indicates if there were any errors parsing the values that were +// passed to the command-line switches. For example, --some_int=foo would return +// false because the argument is expected to be an integer. +// +// NOTE: Unlike gflags-style libraries, this library is intended to be +// used in the `main()` function of your binary. It does not handle +// flag definitions that are scattered around the source code. + +// A description of a single command line flag, holding its name, type, usage +// text, and a pointer to the corresponding variable. +class Flag { + public: + enum FlagType { + kPositional = 0, + kRequired, + kOptional, + }; + + // The order of the positional flags is the same as they are added. + // Positional flags are supposed to be required. + template + static Flag CreateFlag(const char* name, T* val, const char* usage, + FlagType flag_type = kOptional) { + return Flag( + name, [val](const T& v) { *val = v; }, *val, usage, flag_type); + } + +// "flag_T" is same as "default_value_T" for trivial types, like int32, bool +// etc. But when it's a complex type, "default_value_T" is generally a const +// reference "flag_T". 
+#define CONSTRUCTOR_WITH_ARGV_INDEX(flag_T, default_value_T) \ + Flag(const char* name, \ + const std::function& hook, \ + default_value_T default_value, const std::string& usage_text, \ + FlagType flag_type); + +#define CONSTRUCTOR_WITHOUT_ARGV_INDEX(flag_T, default_value_T) \ + Flag(const char* name, const std::function& hook, \ + default_value_T default_value, const std::string& usage_text, \ + FlagType flag_type) \ + : Flag( \ + name, [hook](const flag_T& flag_val, int) { hook(flag_val); }, \ + default_value, usage_text, flag_type) {} + + CONSTRUCTOR_WITH_ARGV_INDEX(int32_t, int32_t) + CONSTRUCTOR_WITHOUT_ARGV_INDEX(int32_t, int32_t) + + CONSTRUCTOR_WITH_ARGV_INDEX(int64_t, int64_t) + CONSTRUCTOR_WITHOUT_ARGV_INDEX(int64_t, int64_t) + + CONSTRUCTOR_WITH_ARGV_INDEX(float, float) + CONSTRUCTOR_WITHOUT_ARGV_INDEX(float, float) + + CONSTRUCTOR_WITH_ARGV_INDEX(bool, bool) + CONSTRUCTOR_WITHOUT_ARGV_INDEX(bool, bool) + + CONSTRUCTOR_WITH_ARGV_INDEX(std::string, const std::string&) + CONSTRUCTOR_WITHOUT_ARGV_INDEX(std::string, const std::string&) + +#undef CONSTRUCTOR_WITH_ARGV_INDEX +#undef CONSTRUCTOR_WITHOUT_ARGV_INDEX + + FlagType GetFlagType() const { return flag_type_; } + + private: + friend class Flags; + + bool Parse(const std::string& arg, int argv_position, + bool* value_parsing_ok) const; + + std::string name_; + enum { + TYPE_INT32, + TYPE_INT64, + TYPE_BOOL, + TYPE_STRING, + TYPE_FLOAT, + } type_; + + std::function + value_hook_; + std::string default_for_display_; + + std::string usage_text_; + FlagType flag_type_; +}; + +class Flags { + public: + // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag + // instances matching flags in flaglist[]. Update the variables associated + // with matching flags, and remove the matching arguments from (*argc, argv). + // Return true iff all recognized flag values were parsed correctly, and the + // first remaining argument is not "--help". + // Note: + // 1. when there are duplicate args in argv for the same flag, the flag value + // and the parse result will be based on the 1st arg. + // 2. when there are duplicate flags in flag_list (i.e. two flags having the + // same name), all of them will be checked against the arg list and the parse + // result will be false if any of the parsing fails. + // See *Duplicate* unit tests in command_line_flags_test.cc for the + // illustration of such behaviors. + static bool Parse(int* argc, const char** argv, + const std::vector& flag_list); +}; +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ + +// LINT.ThenChange(//tensorflow/lite/tools/command_line_flags.h) diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/optimize/operator_property.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/optimize/operator_property.h new file mode 100644 index 00000000..5401fcdd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/optimize/operator_property.h @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_OPTIMIZE_OPERATOR_PROPERTY_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_OPTIMIZE_OPERATOR_PROPERTY_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace tflite { +namespace optimize { +namespace operator_property { + +// The scales of a certain tensor can be derived from the multiplications of all +// the scales. For example, for bias in conv, derived_scale = {{0, 1}, {}, {}} +// and for lstm gate bias, the derived scale is {{}, {0}, {2^-10}} +struct DerivedScale { + // MSVC2015 version 14.0 and below doesn't support struct initialization with + // initializer lists so emulate the behavior using a float initializer list. +#if _MSC_VER <= 1900 + DerivedScale() = default; + // Construct this object with a list of initializer lists. All list elements + // are cast to float values to avoid ambiguous construction of a union-style + // object that could take either std::initializer_list or + // std::initializer_list. + DerivedScale(std::initializer_list> values) { + assert(values.size() == 3); + std::vector> items(values); + for (auto& it : items[0]) { + input_tensors.push_back(static_cast(it)); + } + for (auto& it : items[1]) { + intermediate_tensors.push_back(static_cast(it)); + } + factors.assign(items[2]); + } +#endif // _MSC_VER <= 1900 + + std::vector input_tensors = {}; + std::vector intermediate_tensors = {}; + // This is a list of extra factors that are not associated with any other + // tensor. + std::vector factors = {}; +}; + +struct TensorProperty { + // per_axis also implies symmetric currently. + bool per_axis = false; + // TODO(jianlijianli): remove dimension index and read it from tensor instead. + int per_axis_index = 0; + bool symmetric = false; + + // Constraints. + bool restriction = false; + // scale/zero_point hardcoded. + std::pair restricted_value_int8 = {0.0f, 0}; + std::pair restricted_value_int16 = {0.0f, 0}; + + // Use derived scale. + bool use_derived_scale = false; + // The derived scale. + DerivedScale derived_scale; + + // The number of bits for this tensor. It could be 8, 16, 32 or even not power + // of two. + int number_of_bits = 8; + + // Extend the range to power of two. + bool extend_to_power_of_two = false; + + // State tensor. + bool state_tensor = false; +}; + +struct OperatorProperty { + // Is a quantized operations currently supported. + bool quantizable = true; + // Is a quantized operations currently supported for 16x8 + bool quantizable_int16 = true; + // Op has arbitrary number of inputs, such as concat. + bool arbitrary_inputs = false; + // Op has arbitrary number of outputs, such as slice. + bool arbitrary_outputs = false; + // Input indexes -> input tensor property. + // Must be topologically sorted since there are derived scales. + std::vector> inputs = {}; + // Output indexes -> output tensor property. + std::vector> outputs = {}; + // Bias indexes. + // TODO(jianlijianli): remove this by putting biases into inputs as well since + // we now can model "derived scale". + std::vector biases = {}; + + // Intermediate indexes -> intermediate tensor property. 
+ std::vector> intermediates = {}; + + // Force output to reuse the same scale and zero point of input when the + // certain type support must require the same scale and zero point + // requirement. + std::function restrict_same_input_output_scale = + [](TensorType) { return false; }; + + // Use same min of min and max of max for each group. + // Incompatible with restrict_same_input_output_scale and restricted_value. + // Currently it only supports scale pair of {input_index, output_index}. + std::vector> restrict_scale = {}; + + // Op version. + int version = 1; + + // When we quantize activations into 16 bit and weights into 8 bit, + // we want to quantize all inputs, including constant tensors, + // for the operators like Add, Mul into 16-bit as well. The constant + // inputs are quantized as weights and this variable indicates + // that we want to do quantizations of these tensors as activations. + bool quantize_input_as_activations = false; +}; + +// The op as well as it variants. +struct OpVariant { + BuiltinOperator op_code; + bool use_layer_norm = false; + bool use_projection = false; + bool use_peephole = false; + // An attribute to indicate if quantization is supported for this Op. + // This attribute is equivalent to the "quantizable" attribute in + // "OperatorProperty". It added here since OpVariants peeks inside the Op and + // determines its quantization related properties. + bool is_quantizable = true; +}; + +OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, + int op_index, int number_of_bits = 8); +OperatorProperty GetOperatorProperty(OpVariant op_variant, + int number_of_bits = 8); + +} // namespace operator_property +} // namespace optimize +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_OPTIMIZE_OPERATOR_PROPERTY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h new file mode 100644 index 00000000..104cc638 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h @@ -0,0 +1,119 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
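Reviewer note: a small hedged sketch of how the operator_property API above is typically queried, here to check 16x8 support by asking for properties at 16 activation bits. The helper name is illustrative; the fields used (quantizable, quantizable_int16) are taken directly from the OperatorProperty struct above.

#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/compiler/mlir/lite/tools/optimize/operator_property.h"

// Hedged sketch: does this op support 16-bit activations with 8-bit weights?
static bool SupportsInt16Activations(const tflite::ModelT* model,
                                     int subgraph_index, int op_index) {
  tflite::optimize::operator_property::OperatorProperty property =
      tflite::optimize::operator_property::GetOperatorProperty(
          model, subgraph_index, op_index, /*number_of_bits=*/16);
  return property.quantizable && property.quantizable_int16;
}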
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_METADATA_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_METADATA_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h" + +namespace tflite { +namespace optimize { +static constexpr char kTfLiteReducedPrecisionKey[] = + "reduced_precision_support"; + +static constexpr char kTfLiteFloat16String[] = "fp16"; +static constexpr char kTfLiteBfloat16String[] = "bf16"; +static constexpr char kTfLiteFloat32String[] = "fp32"; +static constexpr char kTfLiteAccumulationString[] = "acc"; + +enum class ReducedPrecisionSupport : std::uint8_t { + None = 0, + Float16Inference = 0x1, + Bfloat16Inference = 0x2, + Float16Accumulation = 0x4, + Float32Accumulation = 0x8, +}; + +inline ReducedPrecisionSupport operator|(ReducedPrecisionSupport a, + ReducedPrecisionSupport b) { + return static_cast(static_cast(a) | + static_cast(b)); +} + +inline ReducedPrecisionSupport& operator|=(ReducedPrecisionSupport& a, + ReducedPrecisionSupport b) { + return a = static_cast( + static_cast(a) | static_cast(b)); +} + +inline ReducedPrecisionSupport operator&(ReducedPrecisionSupport a, + ReducedPrecisionSupport b) { + return static_cast(static_cast(a) & + static_cast(b)); +} + +inline ReducedPrecisionSupport& operator&=(ReducedPrecisionSupport& a, + ReducedPrecisionSupport b) { + return a = static_cast( + static_cast(a) & static_cast(b)); +} + +inline bool SupportsFP16Inference(const ReducedPrecisionSupport& mask) { + return static_cast(mask & ReducedPrecisionSupport::Float16Inference); +} + +inline bool SupportsBfloat16Inference(const ReducedPrecisionSupport& mask) { + return static_cast(mask & ReducedPrecisionSupport::Bfloat16Inference); +} + +inline bool SupportsFP16Accumulation(const ReducedPrecisionSupport& mask) { + return static_cast(mask & ReducedPrecisionSupport::Float16Accumulation); +} + +inline bool SupportsFP32Accumulation(const ReducedPrecisionSupport& mask) { + return static_cast(mask & ReducedPrecisionSupport::Float32Accumulation); +} + +inline bool SupportsReducedPrecisionInference( + const ReducedPrecisionSupport& mask) { + return SupportsFP16Inference(mask) || SupportsBfloat16Inference(mask); +} + +inline bool SupportsEitherFP16OrFP32Accumulation( + const ReducedPrecisionSupport& mask) { + return SupportsFP16Accumulation(mask) != SupportsFP32Accumulation(mask); +} + +// Return the key-value pair for reduced precision support metadata. +// Example: mask = Float16Inference | Bfloat16Inference | Float32Accumulation; +// Returned value would be <"reduced_precision_support", "fp16bf16accfp32">. 
+inline std::pair MetadataForReducedPrecisionSupport( + const ReducedPrecisionSupport& mask) { + TFLITE_DCHECK(SupportsReducedPrecisionInference(mask)); + TFLITE_DCHECK(SupportsEitherFP16OrFP32Accumulation(mask)); + std::string value = ""; + if (SupportsFP16Inference(mask)) { + value += kTfLiteFloat16String; + } + if (SupportsBfloat16Inference(mask)) { + value += kTfLiteBfloat16String; + } + value += kTfLiteAccumulationString; + if (SupportsFP16Accumulation(mask)) { + value += kTfLiteFloat16String; + } else if (SupportsFP32Accumulation(mask)) { + value += kTfLiteFloat32String; + } + return std::make_pair(std::string(kTfLiteReducedPrecisionKey), value); +} + +} // namespace optimize +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_OPTIMIZE_REDUCED_PRECISION_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h new file mode 100644 index 00000000..b3da62ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ + +// This file contains command-line options aimed to provide the parameters +// required by the TensorFlow Graph(Def) to MLIR module conversion. It is only +// intended to be included by binaries. + +#include + +#include "llvm/Support/CommandLine.h" + +// Please see the implementation file for documentation of these options. + +// Import options. +extern llvm::cl::opt input_arrays; +extern llvm::cl::opt input_dtypes; +extern llvm::cl::opt input_shapes; +extern llvm::cl::opt output_arrays; +extern llvm::cl::opt control_output_arrays; +extern llvm::cl::opt inference_type; +extern llvm::cl::opt min_values; +extern llvm::cl::opt max_values; +extern llvm::cl::opt debug_info_file; +extern llvm::cl::opt xla_compile_device_type; +extern llvm::cl::opt prune_unused_nodes; +extern llvm::cl::opt convert_legacy_fed_inputs; +extern llvm::cl::opt graph_as_function; +extern llvm::cl::opt upgrade_legacy; +// TODO(jpienaar): Temporary flag, flip default and remove. +extern llvm::cl::opt enable_shape_inference; +extern llvm::cl::opt unconditionally_use_set_output_shapes; +extern llvm::cl::opt enable_soft_placement; +extern llvm::cl::opt set_original_tf_func_name; + +// Export options. 
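Reviewer note: a short hedged sketch exercising the reduced-precision mask helpers above. Following the function body shown, a mask of fp16 inference plus fp32 accumulation yields the pair <"reduced_precision_support", "fp16accfp32">; the wrapper function is illustrative only.

#include <string>
#include <utility>

#include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h"

// Hedged sketch: build the metadata entry for fp16 inference / fp32 accumulation.
static std::pair<std::string, std::string> Fp16InferenceMetadata() {
  tflite::optimize::ReducedPrecisionSupport mask =
      tflite::optimize::ReducedPrecisionSupport::Float16Inference |
      tflite::optimize::ReducedPrecisionSupport::Float32Accumulation;
  // Satisfies both DCHECKs: reduced-precision inference is requested and
  // exactly one of the accumulation modes is set.
  return tflite::optimize::MetadataForReducedPrecisionSupport(mask);
}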
+extern llvm::cl::opt export_entry_func_to_flib; +extern llvm::cl::opt export_original_tf_func_name; + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h new file mode 100644 index 00000000..5799194f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h @@ -0,0 +1,96 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/core/c/tflite_types.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace tflite { + +// OpSignature contains operator parameters for version functions. +typedef struct { + TfLiteType type; + std::vector dims; + bool is_const; + bool is_shape_dynamic; +} OpSignatureTensorSpec; + +typedef struct { + BuiltinOperator op; + std::vector inputs; + std::vector outputs; + void* builtin_data; + int version; + const void* custom_initial_data; + std::string custom_name; + union { + struct { + bool is_per_channel_quantized; + bool is_grouped_convolution; + } conv_2d; + struct { + bool is_per_channel_quantized; + } depthwise_conv_2d; + struct { + // TODO(b/156530611): Make this global when more ops support sparse + // computation. + bool sparse_weight; + bool is_per_channel_quantized; + } fully_connected; + struct { + float input1_scale; + float input2_scale; + float output_scale; + bool input_quantized; + } mul; + struct { + int32_t num_dims; + } strided_slice; + struct { + bool input_quantized; + } abs; + struct { + bool is_per_channel_quantized; + } dequantize; + struct { + bool is_per_channel_quantized; + } quantize; + struct { + bool input_quantized; + } add; + struct { + bool is_per_channel_quantized; + } embedding_lookup; + } ext_options; +} OpSignature; + +// Generate OpSignature with the given OperatorCode, Operator and Tensors (from +// SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and +// mostly input and output tensor types are enough to figure out op version. +// But some ops (DEPTHWISE_CONV_2D, FULLY_CONNECTED, ...) require to pass their +// options to decide op version. +// +// WARNING: The caller is responsible to free the allocated +// OpSignature.builtin_data memory. 
+OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op, + const SubGraph* subgraph, const Model* model); + +} // namespace tflite +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_SIGNATURE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h new file mode 100644 index 00000000..bd1f5516 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/op_version.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ + +#include + +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h" + +namespace tflite { + +// Returns version of builtin ops by the given signature. +int GetBuiltinOperatorVersion(const OpSignature& op_sig); + +// Update operator's version of the given TFL flatbuffer model. +void UpdateOpVersion(uint8_t* model_buffer_pointer); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_OP_VERSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h new file mode 100644 index 00000000..7d586df5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/tools/versioning/runtime_version.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ + +#include +#include + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" + +namespace tflite { +// Update minimum runtime version of the given TFL flatbuffer model. +void UpdateMinimumRuntimeVersionForModel(uint8_t* model_buffer_pointer); + +// Find the minimum runtime version of a given op version. 
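Reviewer note: a hedged sketch combining op_signature.h and op_version.h above: build an OpSignature for one operator of a parsed flatbuffer model, ask for its builtin version, and release builtin_data as the warning requires. Freeing with free() is an assumption about how that buffer is allocated; the helper name is illustrative.

#include <cstdlib>

#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
#include "tensorflow/compiler/mlir/lite/tools/versioning/op_signature.h"
#include "tensorflow/compiler/mlir/lite/tools/versioning/op_version.h"

// Hedged sketch: compute the builtin operator version for a single op.
static int BuiltinVersionFor(const tflite::OperatorCode* op_code,
                             const tflite::Operator* op,
                             const tflite::SubGraph* subgraph,
                             const tflite::Model* model) {
  tflite::OpSignature op_sig =
      tflite::GetOpSignature(op_code, op, subgraph, model);
  int version = tflite::GetBuiltinOperatorVersion(op_sig);
  if (op_sig.builtin_data) {
    free(op_sig.builtin_data);  // Caller owns this per the header comment.
  }
  return version;
}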
Return an empty +// string the version is not registered. +std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, + int op_version); + +// Returns true if the first version string precedes the second. +// For example, '1.9' should precede '1.14', also '1.14' should precede +// '1.14.1'. If two version string is equal, then false will be returned. +bool CompareRuntimeVersion(const std::string&, const std::string&); + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_VERSIONING_RUNTIME_VERSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.h new file mode 100644 index 00000000..e9bd67f8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CANONICALIZE_BOUNDARY_VALUE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CANONICALIZE_BOUNDARY_VALUE_PASS_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace mlir { +namespace TFL { + +// Pass to canonicalize the IR representations of boundary values. 
+ +class CanonicalizeBoundaryValuePass + : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CanonicalizeBoundaryValuePass) + + CanonicalizeBoundaryValuePass() = default; + CanonicalizeBoundaryValuePass(const CanonicalizeBoundaryValuePass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "CanonicalizeBoundaryValuePass"; } + static llvm::StringRef GetArgument() { + return "tfl-canonicalize-boundary-value"; + } + static llvm::StringRef GetDescription() { + return "Pass to canonicalize the IR representations of boundary values"; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CANONICALIZE_BOUNDARY_VALUE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h new file mode 100644 index 00000000..01f71afe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CONVERTER_PASS_OPTIONS_SETTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CONVERTER_PASS_OPTIONS_SETTER_H_ + +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h" + +namespace mlir { +namespace TFL { + +class OptimizePassOptions; +class VariableFreezingPipelineOptions; +class EmptyPassOptions; + +// PassOptionsSetter to set TFLite Converter Pass/Pipeline Options based on +// ConverterFlags and TFL::PassConfig values. 
+class ConverterPassOptionsSetter : public PassOptionsSetter { + public: + explicit ConverterPassOptionsSetter( + const tflite::ConverterFlags& converter_flags, + const mlir::TFL::PassConfig& pass_config) + : converter_flags_(converter_flags), pass_config_(pass_config) {}; + ~ConverterPassOptionsSetter() override = default; + + void SetOptions(OptimizePassOptions& options) const override; + void SetOptions(VariableFreezingPipelineOptions& options) const override; + void SetOptions(EmptyPassOptions& options) const override; + + private: + tflite::ConverterFlags converter_flags_; + mlir::TFL::PassConfig pass_config_; +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CONVERTER_PASS_OPTIONS_SETTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h new file mode 100644 index 00000000..fa39e09c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h @@ -0,0 +1,80 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass convert dense tensor to sparse format. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" + +namespace mlir { +namespace TFL { + +// This pass encodes sparse weights in the model in the proper format, and adds +// Densify() op if necessary. The general algorithm is: +// 1. Get list of operands (weights) of an op that can be sparse. +// 2. Get list of supported block configurations of the op. +// 3. Calculate random sparsity of the weight. +// 3.1. If sparsity level is below the encoding threshold, keep in dense. +// 3.2. If sparsity level is above the encoding threshold, go to 4. +// 4. Try to encode the weight with supported block configurations. If the +// weight was pruned with the same block config, the blocked sparsity level +// should match the random sparsity. +// 4.1. Return the matching block config if found. +// 4.2. If no matching block config is found, encode the weight with random +// sparsity, and add Densify() op to fall back to dense execution. 
+ +class DenseToSparsePass + : public Pass { + public: + DenseToSparsePass() = default; + DenseToSparsePass(const DenseToSparsePass &other) {} + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. + static llvm::StringRef GetArgument() { return "tfl-dense-to-sparse"; } + + static llvm::StringRef GetDescription() { + return "Convert dense tensor to sparse format."; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "DenseToSparsePass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + /// Explicitly declare the TypeID for this class. We declare an explicit + /// private instantiation because Pass classes should only be visible by the + /// current library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DenseToSparsePass) +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h new file mode 100644 index 00000000..fe8bb7d2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -0,0 +1,498 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This pass identifies patterns for dilated convolution and replace it with +// a real convolution op. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_ + +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFL { + +// A dilated convolution can be emulated with a regular convolution by chaining +// SpaceToBatch and BatchToSpace ops before and after it: +// +// SpaceToBatchND -> Conv2D -> BatchToSpaceND +// +// This method was common before Conv2D fully supported dilated convolution in +// TensorFlow. This transformation detects this "emulation", and replaces it +// with a true dilated convolution, eliminating the SpaceToBatch and +// BatchtoSpace ops. 
+// +// Detecting this alone would be relatively easy. However, in practice some +// extra ops are used, so we detect the following patterns: +// +// +// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd +// +// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND -> +// BiasAdd +// +// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND +// +// SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd +// +// SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd +// +// +// The Expand/Squeeze combination is used to adapt a 3D array (such as in +// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are +// thrown in just for the extra headache. Padding adapts non-conforming input +// sizes, and can be discarded. The bias is necessary, so is kept. +template +class ConvertTFDilatedConvOp : public OpRewritePattern { + private: + using OpRewritePattern::OpRewritePattern; + + // Extract the dilation factor from `block_shape` and pack it in an ArrayAttr. + std::optional ExtractDilationsAttrFromBlockShape( + Value stb_block_shape, Value bts_block_shape, int64_t expand_axis, + PatternRewriter& rewriter) const; + + public: + LogicalResult matchAndRewrite(Conv2dOpTy op, + PatternRewriter& rewriter) const override; +}; + +template +LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( + Conv2dOpTy op, PatternRewriter& rewriter) const { + if (!op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + op, "result for current op has more than 1 use"); + } + // Make sure Conv2D has 'VALID' padding. + if (op->template getAttrOfType("padding").getValue() != "VALID") { + return rewriter.notifyMatchFailure(op, + "Conv2D op doesn't have valid padding"); + } + // Make sure dilations are all ones if set. + const ArrayAttr& dilations = + op->template getAttrOfType("dilations"); + if (dilations && !TFIntListIsAllOnes(dilations)) { + return rewriter.notifyMatchFailure(op, "dilations should be all 1"); + } + + if (!TFL::TFTypeIsFloat32Tensor(op.getInput()) && + !TFL::TFTypeIsBFloat16OrHalfTensor(op.getInput())) { + return rewriter.notifyMatchFailure( + op, "op's input is not float or half or bfloat16"); + } + if (!TFL::TFDataFormatIsNHWC(op)) { + return rewriter.notifyMatchFailure(op, "op's data format isn't NHWC"); + } + + // Allow dynamic width and height dimensions only. + auto result_ty = mlir::cast(op.getResult().getType()); + if (!result_ty.hasRank() || result_ty.getRank() != 4 || + result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) { + return rewriter.notifyMatchFailure( + op, "only dynamic width and height dimensions are allowed"); + } + + // Check if the ConvOp's input is defined by `Expand` op, and the output used + // by `Squeeze` op. + Operation* producer_op = op.getOperand(0).getDefiningOp(); + if (!producer_op || producer_op->getNumResults() != 1) { + return rewriter.notifyMatchFailure( + op, "op doesn't have a producer node that has a single result"); + } + if (!producer_op->hasOneUse() || + *(producer_op->getResult(0).user_begin()) != op) { + return rewriter.notifyMatchFailure( + op, "op's input isn't produced by previous operation"); + } + + auto tryGetDirectConsumerOp = + [&rewriter](Operation* current) -> std::pair { + // Check the current operation has a single result. + if (current->getNumResults() != 1) { + return { + rewriter.notifyMatchFailure(current, "op doesn't have single result"), + nullptr}; + } + // Check the current operation has a consumer node. 
+ Operation* consumer_op = + current->getResult(0).getUses().begin()->getOwner(); + if (!consumer_op) { + return { + rewriter.notifyMatchFailure(current, "op doesn't have consumer node"), + nullptr}; + } + // Check the current operation's result is used by its successor node. + if (!current->hasOneUse() || + *(current->getResult(0).user_begin()) != consumer_op) { + return { + rewriter.notifyMatchFailure( + current, "op's result isn't directly consumed by the next op"), + nullptr}; + } + return {LogicalResult::success(), consumer_op}; + }; + + std::pair maybeConsumer = + tryGetDirectConsumerOp(op.getOperation()); + if (failed(maybeConsumer.first)) { + return maybeConsumer.first; + } + Operation* consumer_op = maybeConsumer.second; + + TF::ExpandDimsOp expand_op; + TF::SqueezeOp squeeze_op; + int64_t expand_axis = -1; + // Expand + Squeeze op. + if (llvm::isa(producer_op)) { + if (!llvm::isa(consumer_op)) { + // Expand/Squeeze op must come in pair. + return rewriter.notifyMatchFailure( + op, "ExpandDimsOp and SqueezeOp should come in pair"); + } + expand_op = llvm::cast(producer_op); + squeeze_op = llvm::cast(consumer_op); + if (!expand_op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + expand_op, "result for current op has more than 1 use"); + } + if (!squeeze_op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + squeeze_op, "result for current op has more than 1 use"); + } + // Make sure that the axis in `expand_op` is constant. + if (auto const_op = + llvm::dyn_cast(expand_op.getDim().getDefiningOp())) { + expand_axis = (*mlir::cast(const_op.getValue()) + .getValues() + .begin()) + .getSExtValue(); + // Canonicalize axis. Some TF python functions, such as + // `tf.nn.convolution`, use negative axis. + if (expand_axis < 0) { + // Always expand 3D input to 4D input. + expand_axis += 4; + } + } else { + return rewriter.notifyMatchFailure( + expand_op, "ExpandDimsOp doesn't have a constant axis"); + } + // Make sure that the `squeeze_dims` is equal to `expand_axis`. + auto squeeze_dims = squeeze_op.getSqueezeDims(); + if (squeeze_dims.size() != 1) { + return rewriter.notifyMatchFailure( + squeeze_op, "squeeze dims should have exactly 1 dimension specified"); + } + int64_t squeeze_axis = mlir::cast(squeeze_dims[0]).getInt(); + if (squeeze_axis < 0) { + // Always squeeze 4D input to 3D input. + squeeze_axis += 4; + } + if (squeeze_axis != expand_axis) { + return rewriter.notifyMatchFailure( + op, "squeeze axis and expand axis doesn't match"); + } + + // Update previous/next op pointer. + Operation* tmp = expand_op.getInput().getDefiningOp(); + if (!tmp || tmp->getNumResults() != 1) { + return rewriter.notifyMatchFailure( + producer_op, + "op doesn't have a producer node that has a single result"); + } + if (!tmp->hasOneUse() || *(tmp->getResult(0).user_begin()) != producer_op) { + return rewriter.notifyMatchFailure( + producer_op, "op's input isn't defined by its previous node"); + } + producer_op = tmp; + std::pair maybeConsumer = + tryGetDirectConsumerOp(consumer_op); + if (failed(maybeConsumer.first)) { + return maybeConsumer.first; + } + consumer_op = maybeConsumer.second; + } + + // SpaceToBatchND op. + if (!llvm::isa(producer_op)) { + return rewriter.notifyMatchFailure(producer_op, + "op should be a SpaceToBatchND op"); + } + // TODO(b/149936532): Check `padding` input, currently ignored. 
+ TF::SpaceToBatchNDOp stb_op = llvm::cast(producer_op); + if (!stb_op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + stb_op, "result for current op has more than 1 use"); + } + + // Pad op. + TF::PadOp pad_op; + ElementsAttr pad_attr; + if (llvm::isa(consumer_op)) { + pad_op = llvm::cast(consumer_op); + if (!pad_op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + pad_op, "result for current op has more than 1 use"); + } + std::pair maybeConsumer = + tryGetDirectConsumerOp(consumer_op); + if (failed(maybeConsumer.first)) { + return maybeConsumer.first; + } + consumer_op = maybeConsumer.second; + if (!matchPattern(pad_op.getPaddings(), m_Constant(&pad_attr))) { + // If the padding value isn't constant, we can't determine the padding + // scheme for Conv2D below, in this case just reject the pattern. + return rewriter.notifyMatchFailure( + pad_op, "PadOp's padding value isn't constant"); + } + } + + // BatchToSpaceND + BiasAdd. + TF::BatchToSpaceNDOp bts_op; + TF::BiasAddOp biasadd_op; + bool final_op_is_bts = true; + if (llvm::isa(consumer_op)) { + // Must be BiasAdd + BatchToSpaceND. + biasadd_op = llvm::cast(consumer_op); + if (!biasadd_op.getResult().hasOneUse()) { + return rewriter.notifyMatchFailure( + biasadd_op, "result for current op has more than 1 use"); + } + std::pair maybeConsumer = + tryGetDirectConsumerOp(consumer_op); + if (failed(maybeConsumer.first)) { + return maybeConsumer.first; + } + if (!llvm::isa(maybeConsumer.second)) { + return rewriter.notifyMatchFailure( + consumer_op, "op's next node isn't BatchToSpaceND op"); + } + consumer_op = maybeConsumer.second; + bts_op = llvm::cast(consumer_op); + } else if (llvm::isa(consumer_op)) { + // BatchToSpaceND + (optional) BiasAdd. + bts_op = llvm::cast(consumer_op); + std::pair maybeConsumer = + tryGetDirectConsumerOp(consumer_op); + Operation* tmp = maybeConsumer.second; + if (tmp && llvm::isa(tmp)) { + consumer_op = tmp; + biasadd_op = llvm::cast(consumer_op); + final_op_is_bts = false; + } + } else { + return rewriter.notifyMatchFailure( + consumer_op, "next op is neither BiasAdd nor BatchToSpaceND"); + } + + std::optional dilations_attr = ExtractDilationsAttrFromBlockShape( + stb_op.getBlockShape(), bts_op.getBlockShape(), expand_axis, rewriter); + if (!dilations_attr.has_value()) { + return rewriter.notifyMatchFailure(op, "failed to extract dilation rate"); + } + + if (expand_op) { + if (mlir::dyn_cast(stb_op.getInput().getType()) == + nullptr) { + return rewriter.notifyMatchFailure( + stb_op, "SpaceToBatchND op's input should have RankedTensorType"); + } + } + + // TODO(b/149936532): Check that the input width & height are multiples of + // dilation rate. + // TF python library will rewrite dilated conv to + // "SpaceToBatch->Conv->BatchToSpace" pattern, and the Conv in the middle + // always has 'VALID' padding. The padding tensor in `SpaceToBatch` has two + // parts of contributions, one is to reduce padding of CONV from 'SAME' to + // 'VALID', and another is to make input shape multiples of dilation rate. The + // first part of padding, which is also called `base_padding` will be used + // here to determine if the original padding format is 'SAME' or 'VALID'. + // According to the following formula we will compute the `base_padding` if + // it's a constant. Basically, `paddings` tensor in `SpaceToBatch` and `crops` + // tensor in `BatchToSpace` must satisfy the following: + // paddings[i, 0] = base_paddings[i, 0]. 
+ // 0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i] + // (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0. + // crops[i, 0] = 0. + // crops[i, 1] = paddings[i, 1] - base_paddings[i, 1]. + + // If `paddings` - `crops` != 0, this means that `base_paddings` != 0, which + // tells us the original padding is 'SAME' (with one caveat presented below). + // Here we need to reset the padding back to `SAME` if `base_padding` + // != 0. + // TODO(b/149936532): We might not simply rely on `paddings - crops != 0` to + // determine the original padding format. For example, users can build + // arbitrary valid examples of `STB->Conv->BTS` which doesn't represent a + // dilated conv, hence we shouldn't pattern match here. Instead, we need to + // check values of `paddings` and `crops` to make sure it really stands for + // a dilated conv. + auto stb_paddings = stb_op.getPaddings(); + auto bts_crops = bts_op.getCrops(); + ElementsAttr stb_paddings_attr, bts_crops_attr; + if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) || + !matchPattern(bts_crops, m_Constant(&bts_crops_attr))) { + return rewriter.notifyMatchFailure( + op, + "either SpaceToBatchND or BatchToSpaceND " + "doesn't have constant padding/crops value"); + } + if (stb_paddings_attr.getType() != bts_crops_attr.getType()) { + return rewriter.notifyMatchFailure( + stb_op, + "SpaceToBatchND op's padding doesn't have same shape/type with " + "BatchToSpaceND op's crops"); + } + int64_t m = stb_paddings_attr.getShapedType().getDimSize(0); + // padding - crop. + for (uint64_t i = 0; i < m; ++i) { + for (uint64_t j = 0; j < 2; ++j) { + // `crops` tensor has shape [M, 2], crops[i] = [crop_start, crop_end] + // specifies the amount to crop from input dimension i + 1. If the input + // of `BatchToSpaceND` has been padded explicitly, then we need to + // take into account the additional padding when determining the padding + // scheme for `Conv2D`. + int64_t addtional_pad = + pad_attr ? pad_attr.getValues()[{i + 1, j}].getSExtValue() : 0; + if (stb_paddings_attr.getValues()[{i, j}].getSExtValue() + + addtional_pad != + bts_crops_attr.getValues()[{i, j}].getSExtValue()) { + op->setAttr("padding", rewriter.getStringAttr("SAME")); + break; + } + } + } + + // Set dilations + op->setAttr("dilations", dilations_attr.value()); + + if (expand_op) { + // If there is `expand_op`, we need to rewire the inputs to bypass the + // `SpaceToBatch`, `BatchToSpace` and `Pad` op. E.g, turning + // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> + // BiasAdd' to 'Expand -> Conv2D ->Squeeze -> BiasAdd'. + + // Connect `expand_op` with the input of `stb_op`. + expand_op.setOperand(0, stb_op.getInput()); + // Calculate the shape for expand. + auto input_shape = + mlir::cast(stb_op.getInput().getType()).getShape(); + SmallVector expand_shape(input_shape.begin(), + input_shape.end()); + expand_shape.insert(expand_shape.begin() + expand_axis, 1); + + auto expand_result_type = RankedTensorType::get( + expand_shape, getElementTypeOrSelf(stb_op.getInput())); + expand_op.getResult().setType(expand_result_type); + + // Update the conv op's output shape. 
+ auto bts_output_shape = + mlir::cast(bts_op.getOutput().getType()).getShape(); + SmallVector conv_result_shape(bts_output_shape.begin(), + bts_output_shape.end()); + conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1); + auto conv_result_type = RankedTensorType::get( + conv_result_shape, getElementTypeOrSelf(stb_op.getInput())); + op.getResult().setType(conv_result_type); + + squeeze_op.getResult().setType(bts_op.getOutput().getType()); + + // Connect `biasadd_op` with the output of `squeeze_op`. + if (biasadd_op) { + biasadd_op.setOperand(0, squeeze_op.getOutput()); + biasadd_op.getOutput().setType(squeeze_op.getOutput().getType()); + } + } else { + if (biasadd_op) biasadd_op.setOperand(0, op.getOutput()); + op.setOperand(0, stb_op.getInput()); + op.getResult().setType(bts_op.getResult().getType()); + } + + if (final_op_is_bts) { + if (bts_op.getInput().getDefiningOp()) { + bts_op.getResult().replaceAllUsesWith(pad_op.getInput()); + } else { + bts_op.getResult().replaceAllUsesWith(bts_op.getInput()); + } + } + + stb_op.getResult().dropAllUses(); + return success(); +} + +template +std::optional +ConvertTFDilatedConvOp::ExtractDilationsAttrFromBlockShape( + Value stb_block_shape, Value bts_block_shape, int64_t expand_axis, + PatternRewriter& rewriter) const { + ElementsAttr stb_bs_attr, bts_bs_attr; + if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) || + !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) { + // Returns failure status if block_shape is not a constant. + return {}; + } + // Check that the block_shape of `stb_op` and `bts_op` are equal. + if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {}; + for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) { + if (stb_bs_attr.getValues()[i] != + bts_bs_attr.getValues()[i]) + return {}; + } + + int dilation_h_factor = -1, dilation_w_factor = -1; + // Set dilation factor. + if (stb_bs_attr.getNumElements() >= 2) { + dilation_h_factor = stb_bs_attr.getValues()[0].getSExtValue(); + dilation_w_factor = stb_bs_attr.getValues()[1].getSExtValue(); + } else if (stb_bs_attr.getNumElements() == 1) { + // For 1d conv, `tf.nn.convolution` expands NWC to NHWC format after + // `SpaceToBatchND`. Therefore, `block_shape` of `stb_op` only has one + // dilation factor of W dim, and dilation factor of H dim is set to 1. + if (expand_axis == 1) { + // NWC -> NHWC + dilation_h_factor = 1; + dilation_w_factor = stb_bs_attr.getValues()[0].getSExtValue(); + } else if (expand_axis == 2) { + // NHC -> NHWC + dilation_h_factor = stb_bs_attr.getValues()[0].getSExtValue(); + dilation_w_factor = 1; + } + } + + if (dilation_h_factor == -1 || dilation_w_factor == -1) { + return {}; + } + + return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1}); +} + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.h new file mode 100644 index 00000000..e4530480 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
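The 'SAME' vs 'VALID' decision in the dilated-conv rewrite above reduces to comparing the SpaceToBatchND paddings (plus any explicit Pad amounts) against the BatchToSpaceND crops: any mismatch means `base_paddings` was non-zero, i.e. the original convolution used 'SAME' padding. A plain-integer sketch of that check, with shapes and names that are illustrative only:

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

namespace padding_sketch {

// stb_paddings, bts_crops, and extra_pad are all [M][2] tables, mirroring the
// `paddings`, `crops`, and explicit Pad amounts discussed above.
inline std::string InferOriginalPadding(
    const std::vector<std::vector<int64_t>>& stb_paddings,
    const std::vector<std::vector<int64_t>>& bts_crops,
    const std::vector<std::vector<int64_t>>& extra_pad) {
  for (std::size_t i = 0; i < stb_paddings.size(); ++i) {
    for (std::size_t j = 0; j < 2; ++j) {
      if (stb_paddings[i][j] + extra_pad[i][j] != bts_crops[i][j]) {
        return "SAME";  // base_paddings != 0, so padding must be restored.
      }
    }
  }
  return "VALID";
}

}  // namespace padding_sketch
```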
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LIFT_TFLITE_FLEX_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LIFT_TFLITE_FLEX_OPS_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// Creates an instance of the lift TFLite Flex ops pass that lifts TFLite Flex +// ops into TF dialect operations. +std::unique_ptr> CreateLiftTfliteFlexOpsPass(); + +void AddLiftTfliteFlexOpsPatterns(MLIRContext *context, + RewritePatternSet &patterns); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LIFT_TFLITE_FLEX_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h new file mode 100644 index 00000000..85fffcf2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h @@ -0,0 +1,55 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
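`AddLiftTfliteFlexOpsPatterns` declared above populates a `RewritePatternSet`, which is normally handed to MLIR's greedy rewrite driver inside a pass. A hedged usage sketch follows; the driver function name can vary across MLIR revisions, and the wrapper function here is purely illustrative:

```cpp
#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h"                // from @llvm-project
#include "mlir/IR/PatternMatch.h"                        // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.h"

// Illustrative only: apply the lift-flex-ops patterns to a single function.
inline void RunLiftFlexOpsPatterns(mlir::func::FuncOp func) {
  mlir::RewritePatternSet patterns(func.getContext());
  mlir::TFL::AddLiftTfliteFlexOpsPatterns(func.getContext(), patterns);
  // Greedy driver; ignore the "did anything converge" result for brevity.
  (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
}
```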
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LOWER_QUANT_ANNOTATIONS_HELPER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LOWER_QUANT_ANNOTATIONS_HELPER_H_ + +#include + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir::TFL { + +LogicalResult FillCompositeParams(stablehlo::CompositeOp op, + SmallVector& scales, + SmallVector& zero_points, + int& num_bits, bool& is_signed); + +LogicalResult GetStorageParams(unsigned num_bits, bool narrow_range, + bool is_signed, MLIRContext* ctx, + Type& storage_type, int64_t& qmin, + int64_t& qmax); + +Type GetPerTensorQuantizedTensorType(Builder& builder, double scale, + int64_t zero_point, Type expressed_type, + int num_bits, Location loc, + bool narrow_range, bool is_signed); + +Type GetPerAxisQuantizedTensorType(Builder& builder, + SmallVector scales, + SmallVector zero_points, + int32_t quantized_dimension, + Type expressed_type, int num_bits, + Location loc, bool narrow_range, + bool is_signed); + +} // namespace mlir::TFL +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_LOWER_QUANT_ANNOTATIONS_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h new file mode 100644 index 00000000..c81548b3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h @@ -0,0 +1,68 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BATCH_MATMUL_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BATCH_MATMUL_PASS_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" + +namespace mlir { +namespace TFL { + +// Optimize FC with BatchMatmul within the TensorFlow Lite dialect. 
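One equivalence typically exploited when folding a (Batch)MatMul into a FullyConnected is that a matmul against a constant rank-2 weight computes the same result as a fully connected layer whose filter is that weight transposed. A small numeric sketch of that fact, independent of the pass's actual pattern set:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace bmm_sketch {

using Matrix = std::vector<std::vector<float>>;

inline Matrix MatMul(const Matrix& a, const Matrix& b) {
  Matrix out(a.size(), std::vector<float>(b[0].size(), 0.0f));
  for (std::size_t i = 0; i < a.size(); ++i)
    for (std::size_t k = 0; k < b.size(); ++k)
      for (std::size_t j = 0; j < b[0].size(); ++j)
        out[i][j] += a[i][k] * b[k][j];
  return out;
}

// FullyConnected-style product: out = x * filter^T.
inline Matrix FullyConnected(const Matrix& x, const Matrix& filter) {
  Matrix out(x.size(), std::vector<float>(filter.size(), 0.0f));
  for (std::size_t i = 0; i < x.size(); ++i)
    for (std::size_t j = 0; j < filter.size(); ++j)
      for (std::size_t k = 0; k < x[0].size(); ++k)
        out[i][j] += x[i][k] * filter[j][k];
  return out;
}

inline void CheckEquivalence() {
  Matrix x = {{1, 2, 3}, {4, 5, 6}};
  Matrix b = {{1, 0}, {0, 1}, {1, 1}};   // constant rank-2 weight
  Matrix b_t = {{1, 0, 1}, {0, 1, 1}};   // its transpose, used as the filter
  assert(MatMul(x, b) == FullyConnected(x, b_t));
}

}  // namespace bmm_sketch
```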
+ +class OptimizeBatchMatmulPass + : public TFL::Pass { + public: + OptimizeBatchMatmulPass() = default; + OptimizeBatchMatmulPass(const OptimizeBatchMatmulPass &other) {} + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. + static llvm::StringRef GetArgument() { return "tfl-optimize-batch-matmul"; } + + static llvm::StringRef GetDescription() { + return "Optimize FC with BatchMatmul within the TensorFlow Lite dialect."; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "OptimizeBatchMatmulPass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + /// Explicitly declare the TypeID for this class. We declare an explicit + /// private instantiation because Pass classes should only be visible by the + /// current library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeBatchMatmulPass) +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BATCH_MATMUL_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h new file mode 100644 index 00000000..f13048a1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h @@ -0,0 +1,53 @@ + +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" + +namespace mlir { +namespace TFL { + +// Pass to optimize explicit broadcasting-like patterns. 
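Explicit broadcasts are often redundant because TFLite's binary element-wise ops already broadcast implicitly under the usual right-aligned rule; when a broadcast op only produces the shape that implicit broadcasting would yield anyway, it can frequently be folded away. A sketch of that shape rule (illustrative, not this pass's implementation):

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

namespace broadcast_sketch {

// Right-aligned broadcast of two shapes; nullopt if they are incompatible.
inline std::optional<std::vector<int64_t>> BroadcastShape(
    std::vector<int64_t> a, std::vector<int64_t> b) {
  if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
  if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<int64_t> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] == b[i] || b[i] == 1) {
      out[i] = a[i];
    } else if (a[i] == 1) {
      out[i] = b[i];
    } else {
      return std::nullopt;  // incompatible dimensions
    }
  }
  return out;
}

}  // namespace broadcast_sketch
```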
+class OptimizeBroadcastLikePass + : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeBroadcastLikePass) + + OptimizeBroadcastLikePass() = default; + OptimizeBroadcastLikePass(const OptimizeBroadcastLikePass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "OptimizeBroadcastLikePass"; } + static llvm::StringRef GetArgument() { return "tfl-optimize-broadcast-like"; } + static llvm::StringRef GetDescription() { + return "Pass optimizing explicit broadcasting-like patterns."; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_pass.h new file mode 100644 index 00000000..86e47726 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_pass.h @@ -0,0 +1,56 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_PASS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" + +namespace mlir { +namespace TFL { + +// Optimize TFLite operations in functions. +class OptimizePass + : public Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass) + + OptimizePass() = default; + OptimizePass(const OptimizePass &) {} + explicit OptimizePass(const mlir::detail::PassOptions &options) + : Pass(options) {} + + /// Returns the command-line argument attached to this pass. + static llvm::StringRef GetArgument() { return "tfl-optimize"; } + + static llvm::StringRef GetDescription() { + return "Optimize within the TensorFlow Lite dialect"; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "OptimizePass"; } + + void runOnOperation() override; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h new file mode 100644 index 00000000..915dc380 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h @@ -0,0 +1,43 @@ +/* Copyright 2024 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_PASS_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_PASS_OPTIONS_H_ + +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +//////////////////////////////////////////////////////////////////////////////// +// Pass Options +//////////////////////////////////////////////////////////////////////////////// + +struct OptimizePassOptions : public mlir::detail::PassOptions { + mlir::detail::PassOptions::Option enable_canonicalization{ + *this, "enable-canonicalization", + llvm::cl::desc("Enable canonicalization in the optimize pass"), + llvm::cl::init(true)}; + mlir::detail::PassOptions::Option disable_fuse_mul_and_fc{ + *this, "disable-fuse-mul-and-fc", + llvm::cl::desc("Disable fuse mul and fc in the optimize pass"), + llvm::cl::init(false)}; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_PASS_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass.h new file mode 100644 index 00000000..f2eed518 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass.h @@ -0,0 +1,113 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h" + +// Forward declaration for the visitor interface +// class PassOptionsVisitor; + +namespace mlir { +namespace TFL { + +// Interface for setting options for TFLite Converter Pass/Pipeline Options. 
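The interplay between this interface and the `PassOptionsSetter` declared elsewhere in this diff is a plain overload-based visitor: each pass hands its options object to the setter, and overload resolution picks the matching `SetOptions`. A reduced stand-alone sketch with generic stand-in names (not the converter's real types):

```cpp
namespace visitor_sketch {

struct OptionsA { bool enable_x = false; };
struct OptionsB { int level = 0; };

// Plays the role of PassOptionsSetter: one overload per options type.
struct OptionsSetter {
  void SetOptions(OptionsA& o) const { o.enable_x = true; }
  void SetOptions(OptionsB& o) const { o.level = 2; }
};

// Each pass "accepts" the visitor by exposing its own options object.
struct PassWithA {
  OptionsA options;
  void ApplyOptionsVisitor(const OptionsSetter& v) { v.SetOptions(options); }
};

struct PassWithB {
  OptionsB options;
  void ApplyOptionsVisitor(const OptionsSetter& v) { v.SetOptions(options); }
};

}  // namespace visitor_sketch
```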
+class MutableOptionsPass { + public: + virtual ~MutableOptionsPass() = default; + virtual void ApplyOptionsVisitor(const PassOptionsSetter &visitor) = 0; +}; + +// CRTP Class to ensure that the derived passes implement a Options struct +template +class Pass : public PassWrapper, + mlir::OperationPass>, + public MutableOptionsPass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(Pass); + + Pass() = default; + Pass(const Pass &pass) { + static_cast(this)->GetOptions().copyOptionValuesFrom( + pass.GetOptions()); + } + explicit Pass(const DerivedPassOptions &options) { + static_cast(this)->GetOptions().copyOptionValuesFrom( + options); + } + + explicit Pass(const mlir::detail::PassOptions &options) { + static_cast(this)->GetOptions().copyOptionValuesFrom( + options); + } + + /// Functions to satisfy the mlir::Pass interface + llvm::StringRef getArgument() const override { + return DerivedPass::GetArgument(); + } + + llvm::StringRef getDescription() const override { + return DerivedPass::GetDescription(); + } + + llvm::StringRef getName() const override { return DerivedPass::GetName(); } + + /// Support isa/dyn_cast functionality for the derived pass class. + static bool classof(const ::mlir::Pass *pass) { + return pass->getTypeID() == ::mlir::TypeID::get(); + } + + /// A clone method to create a copy of this pass. + std::unique_ptr<::mlir::Pass> clonePass() const override { + auto pass = + std::make_unique(*static_cast(this)); + pass->GetOptions().copyOptionValuesFrom(GetOptions()); + return std::move(pass); + } + void runOnOperation() override {} + + // ApplyOptionsVisitor method to `accept` the visitor + void ApplyOptionsVisitor(const PassOptionsSetter &visitor) override { + visitor.SetOptions(GetOptions()); + } + + protected: + DerivedPassOptions &GetOptions() { + return static_cast(this)->options_; + } + + const DerivedPassOptions &GetOptions() const { + return static_cast(this)->options_; + } + + private: + DerivedPassOptions options_; +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_options.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_options.h new file mode 100644 index 00000000..7f5bb198 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_options.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_OPTIONS_H_ + +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace TFL { +struct EmptyPassOptions : public mlir::detail::PassOptions {}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h new file mode 100644 index 00000000..534b1402 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_OPTIONS_SETTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_OPTIONS_SETTER_H_ + +namespace mlir { +namespace TFL { + +class OptimizePassOptions; +class VariableFreezingPipelineOptions; +class EmptyPassOptions; + +// Interface for setting options for TFLite Converter Pass/Pipeline Options. +class PassOptionsSetter { + public: + virtual ~PassOptionsSetter() = default; + virtual void SetOptions(OptimizePassOptions& options) const = 0; + virtual void SetOptions(VariableFreezingPipelineOptions& options) const = 0; + virtual void SetOptions(EmptyPassOptions& options) const = 0; +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_OPTIONS_SETTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h new file mode 100644 index 00000000..43f064c4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h @@ -0,0 +1,98 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_REGISTRY_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_REGISTRY_UTILS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" +#include "tensorflow/compiler/mlir/lite/transforms/pipeline.h" + +namespace mlir { +namespace TFL { + +//////////////////////////////////////////////////////////////////////////////// +// Pass, Pipeline and Options Creation Utilities +//////////////////////////////////////////////////////////////////////////////// + +template +std::unique_ptr Create() { + if constexpr (std::is_base_of_v, + PassType>) { + return std::make_unique>(); + } else { + return std::make_unique(); + } +} + +template +std::unique_ptr Create() { + if constexpr (std::is_base_of_v, + PassType>) { + return std::make_unique>(); + } else { + return std::make_unique(PassOptionsType()); + } +} + +template +std::unique_ptr Create(const mlir::detail::PassOptions& options) { + return std::make_unique(options); +} + +//////////////////////////////////////////////////////////////////////////////// +// Registration Utilities +//////////////////////////////////////////////////////////////////////////////// + +// Utility to register a pass without options. +template +void Register() { + PassRegistration pass([] { return Create(); }); +} + +// Utility to register a pass with options. +template +void Register() { + auto pass_argument = PassType::GetArgument(); + auto pass_description = PassType::GetDescription(); + + if constexpr (std::is_base_of_v, + PassType>) { + // PassType is derived from PipelinePass, proceed with registration + // of the pipeline. + PassPipelineRegistration( + pass_argument, pass_description, + [](OpPassManager& pm, const PassOptionsType& options) { + auto pipeline = PassType(); + pipeline.AddPasses(); + pipeline.GetPipeline(pm, options); + }); + } else { + PassPipelineRegistration( + pass_argument, pass_description, + [](OpPassManager& pm, const PassOptionsType& options) { + pm.addPass(std::move(Create(options))); + }); + } +} + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASS_REGISTRY_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/passes.h new file mode 100644 index 00000000..4d8eccca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -0,0 +1,356 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
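Typical usage of the `Create<>`/`Register<>` helpers above looks roughly like the sketch below. The pass and options types are the ones declared in this diff, but the exact template instantiations are an assumption based on how the registration block in passes.h uses these utilities:

```cpp
#include "tensorflow/compiler/mlir/lite/transforms/optimize_pass.h"
#include "tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h"
#include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h"

// Makes "tfl-optimize" available to command-line pass-pipeline parsing
// (normally done once at tool start-up).
inline void RegisterTflOptimizeOnCommandLine() {
  mlir::TFL::Register<mlir::TFL::OptimizePass,
                      mlir::TFL::OptimizePassOptions>();
}

// Programmatic construction, mirroring CreateOptimizePass() in passes.h.
inline auto MakeTflOptimizePass() {
  return mlir::TFL::Create<mlir::TFL::OptimizePass>();
}
```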
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h" +#include "tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/tf_legalizations/analyze_variables_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/split_merged_operands_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" + +namespace mlir { +namespace quant { +class QuantDialect; +} +namespace quantfork { +class QuantizationForkDialect; +} +namespace mhlo { +class MhloDialect; +} +namespace TF { +class TensorFlowDialect; +} +namespace TFL { +class TFLDialect; +typedef TFLDialect TensorFlowLiteDialect; +} // namespace TFL +namespace func { +class FuncOp; +} +class ModuleOp; +template +class OperationPass; +class Type; + +namespace TFL { + +//////////////////////////////////////////////////////////////////////////////// +// Forward declarations +//////////////////////////////////////////////////////////////////////////////// + +struct OptimizePassOptions; + +//////////////////////////////////////////////////////////////////////////////// +// Utilities for backward compatibility +//////////////////////////////////////////////////////////////////////////////// + +// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. +// When the given run_tfl_runtime_verification value is true, it will check each +// TFL builtin op towards the TFL runtime capability and the incompatible TF ops +// will be left in the graph without getting legalized. If `preserve_assert_op` +// is true, the TF::AssertOp will not be removed. +std::unique_ptr> CreateLegalizeTFPass( + bool run_tfl_runtime_verification, bool preserve_assert_op = false); +std::unique_ptr> CreateLegalizeTFPass(); + +// Creates an instance of the TensorFlow Lite dialect Optimize pass. +inline std::unique_ptr CreateOptimizePass() { + return Create(); +} + +// Creates an instance of the Tensorflow Lite batch matmul Optimize pass. +inline std::unique_ptr CreateOptimizeBatchMatmulPass() { + return Create(); +} + +// Creates an instance of the TensorFlow Lite dialect PrepareTF pass. 
+std::unique_ptr> CreatePrepareTFPass( + bool unfold_batch_matmul, bool allow_bf16_and_f16_type_legalization, + bool use_fake_quant_num_bits = false); +std::unique_ptr> CreatePrepareTFPass(); + +// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList +// pass. +std::unique_ptr> CreateLowerStaticTensorListPass( + bool allow_tensorlist_pass_through, bool default_to_single_batch, + bool enable_dynamic_update_slice); + +std::unique_ptr> CreateLowerStaticTensorListPass(); + +// Creates an instance of the TensorFlow Lite dialect Quantize pass. +// Use quant_specs.ops_blocklist and quant_specs.nodes_blocklist if possible +// as they are now structure variables of QuantizationSpecs. +std::unique_ptr> CreateQuantizePass( + const quant::QuantizationSpecs& quant_specs, + const absl::flat_hash_set& ops_blocklist = {}, + const absl::flat_hash_set& nodes_blocklist = {}); + +std::unique_ptr> CreateDefaultQuantizePass(); + +std::unique_ptr> CreateLowerQuantAnnotationsPass(); + +// Overloading of CreateQuantizePass which takes only necessary flags to reduce +// the binary size. +std::unique_ptr> CreateQuantizePass( + bool verify_numeric = false, bool whole_model_verify = false, + bool legacy_float_scale = false, + const absl::flat_hash_set& ops_blocklist = {}, + const absl::flat_hash_set& nodes_blocklist = {}); + +// Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. +std::unique_ptr> CreatePrepareQuantizePass( + const quant::QuantizationSpecs& quant_specs); + +std::unique_ptr> CreatePrepareQuantizePass(); + +// Creates an instance of the TensorFlow Lite dialect +// PrepareDynamicRangeQuantize pass. +std::unique_ptr> +CreatePrepareDynamicRangeQuantizePass( + const quant::QuantizationSpecs& quant_specs); + +std::unique_ptr> +CreatePrepareDynamicRangeQuantizePass(); + +// Creates an instance of the TensorFlow Lite dialect PostQuantize pass. +std::unique_ptr> CreatePostQuantizePass(); +std::unique_ptr> CreatePostQuantizePass( + bool emit_quant_adaptor_ops, const quant::CustomOpMap& custom_op_map = {}); + +// Creates an instance of the TensorFlow Lite dialect QuantizeVariables pass. +std::unique_ptr> CreatePrepareQuantizeVariablesPass(); + +// Creates an instance of the TensorFlow Lite pass that decomposes hybrid +// quantization patterns to the same dense operation with tfl dequantization +// and quantization patterns. +std::unique_ptr> +CreateDecomposeHybridQuantizationPass(); + +// Creates an instance of the TensorFlow Lite optimize op order pass. +std::unique_ptr> CreateOptimizeOpOrderPass(); + +// Creates an instance of the TensorFlow Lite dialect TrimFunctions +// pass. +std::unique_ptr> CreateTrimFunctionsPass(); + +std::unique_ptr> CreateTrimFunctionsPass( + const std::vector& trim_funcs_allowlist); + +// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions +// pass. +std::unique_ptr> CreatePrepareCompositeFunctionsPass(); + +// Creates an instance of the TensorFlow Lite dialect SplitMergedOperandsPass. +inline std::unique_ptr CreateSplitMergedOperandsPass() { + return Create(); +} + +// Creates an instance of the TensorFlow Lite dialect OptimizeFunctionalOpsPass. +std::unique_ptr> CreateOptimizeFunctionalOpsPass(); + +std::unique_ptr> CreateModifyIONodesPass( + mlir::Type input_type, mlir::Type output_type); + +std::unique_ptr> CreateModifyIONodesPass(); + +// Creates an instance of the TensorFlow Lite dialect PostQuantizeRemoveQDQ +// pass. 
+std::unique_ptr> CreatePostQuantizeRemoveQDQPass(); + +// Creates an instance of the TensorFlow Lite dialect pass to add default +// quantization parameters. +std::unique_ptr> CreateDefaultQuantParamsPass( + double default_min, double default_max, bool is_signed); + +std::unique_ptr> CreateDefaultQuantParamsPass(); + +// Creates an instance of the IdentifyDilatedConvPass. +std::unique_ptr> CreateIdentifyDilatedConvPass(); + +// Creates function pass to legalize TF While to TFL While. +std::unique_ptr> CreateLegalizeTFWhilePass(); + +// Legalize tflite flex ops to TF ops. +std::unique_ptr> CreateLiftTfliteFlexOpsPass(); + +// Creates an instance of the TensorFlow Lite dialect WhileOp outline pass. +inline std::unique_ptr CreateWhileOutlinePass() { + return Create(); +} + +// Creates an instance of the TensorFlow Lite dialect IfOp outline pass. +std::unique_ptr> CreateIfOutlinePass(); + +// Creates a pass to remove operands of TFL WhileOp without changing outcomes. +std::unique_ptr> CreateReduceWhileOperandsPass(); + +// Verifies runtime constraints. +std::unique_ptr> CreateRuntimeVerifyPass(); + +// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp +std::unique_ptr> CreateRaiseCustomOpsPass(); +std::unique_ptr> CreateRaiseCustomOpsPass( + const std::vector& target_ops); + +// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp +std::unique_ptr> CreateLowerCustomOpsPass(); + +// Inserts an TFL::CallOnce op when the tf_saved_model's session initialzer is +// given. +std::unique_ptr> +CreateInsertCallOnceOpFromSessionInitializerPass(); + +// Replace the tfl wrapped random function body with tfl.customOp. +std::unique_ptr> CreateLegalizeJaxRandomPass(); + +// Creates a pass which is responsible for legalizing TensorFlow variables to +// TensorFlow Lite variables. +std::unique_ptr> CreateLegalizeVariablesPass(); + +// Creates a pass which analyze the model whether it is safe to use +// native TFLite variables or not. +inline std::unique_ptr CreateAnalyzeVariablesPass() { + return Create(); +} + +// Creates a pass which is responsible for legalizing TensorFlow static hash +// tables to TensorFlow Lite hash tables. +std::unique_ptr> CreateLegalizeHashTablesPass(); + +// Creates get arithmetic count pass, which will calculate the arithmetic count +// for each ops. +std::unique_ptr> CreateGetArithmeticCountPass(); + +// Creates unfold large constant pass, which will replace large splat constant +// tensors with fill op. +inline std::unique_ptr CreateUnfoldLargeSplatConstantPass() { + return Create(); +} + +// Creates a pass which is responsible for unfreezing mutable global tensors. +inline std::unique_ptr CreateUnfreezeMutableGlobalTensorsPass() { + return Create(); +} + +// Creates a pass that adds control dependencies to keep the relative +// execution order of operations with side effects frozen. +std::unique_ptr> CreatePinOpsWithSideEffectsPass(); + +// Legalize TensorList Ops iff all of them are supported. +inline std::unique_ptr CreateLegalizeTensorListPass() { + return Create(); +} + +// Reduce the type precision of some tensor types if all values within that +// tensor are within the range of the reduced precision. +std::unique_ptr> CreateReduceTypePrecisionPass(); + +// Conservatively pushes transposes through element-wise ops to prepare +// so redundant ones may be grouped and removed. +inline std::unique_ptr CreatePushTransposeThroughEwisePass() { + return Create(); +} + +// Create a pass that canonicalize the boundary values. 
+inline std::unique_ptr CreateCanonicalizeBoundaryValuePass() { + return Create(); +} + +// Creates a pass that brings operations into the same order as graph_info.cc. +std::unique_ptr> +CreatePartitionedTopologicalSortPass(); + +#define GEN_PASS_DECL_DEFAULTQUANTPARAMSPASS +#define GEN_PASS_DECL_LEGALIZETFPASS +#define GEN_PASS_DECL_LOWERSTATICTENSORLISTPASS +#define GEN_PASS_DECL_MODIFYIONODESPASS +#define GEN_PASS_DECL_POSTQUANTIZEPASS +#define GEN_PASS_DECL_PREPARECOMPOSITEFUNCTIONSPASS +#define GEN_PASS_DECL_PREPAREDYNAMICRANGEQUANTIZEPASS +#define GEN_PASS_DECL_PREPAREQUANTIZEPASS +#define GEN_PASS_DECL_PREPARETFPASS +#define GEN_PASS_DECL_QUANTIZEPASS +#define GEN_PASS_DECL_RAISECUSTOMOPSPASS +#define GEN_PASS_DECL_TRIMFUNCTIONSPASS +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" + +// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. +std::unique_ptr> CreateLegalizeTFPass( + const LegalizeTFPassOptions& options); + +// Creates an instance of the TensorFlow Lite dialect PrepareTF pass. +std::unique_ptr> CreatePrepareTFPass( + const PrepareTFPassOptions& options); + +// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList +// pass. +std::unique_ptr> CreateLowerStaticTensorListPass( + const LowerStaticTensorListPassOptions& options); + +// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp +std::unique_ptr> CreateRaiseCustomOpsPass( + const RaiseCustomOpsPassOptions& options); + +// Creates an instance of the TensorFlow Lite dialect pass to add default +// quantization parameters. +std::unique_ptr> CreateDefaultQuantParamsPass( + const DefaultQuantParamsPassOptions& options); + +inline void registerTensorFlowLitePasses() { + registerTensorFlowLiteTdPasses(); + // Register TFLite Converter Passes + Register(); + + // TF Legalization Passes + Register(); + Register(); + Register(); + + // TFL Optimization Passes + Register(); + Register(); + Register(); + Register(); + Register(); + Register(); + + // Other TFLite Passes + Register(); + Register(); +} + +} // namespace TFL + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pipeline.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pipeline.h new file mode 100644 index 00000000..f0420b11 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/pipeline.h @@ -0,0 +1,173 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PIPELINE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PIPELINE_H_ + +#include +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h" + +namespace mlir { +namespace TFL { + +/// Pipeline is a base class for pipelines of passes. +/// +/// A pipeline is a collection of passes that are run in a specific order. The +/// pipeline can be configured with options that control which passes are +/// enabled and how they are run. +/// +/// To create a new pipeline, derive from this class and implement the +/// `AddPasses` method. This method should add passes to the pipeline using the +/// `AddPass` method. +/// +/// Example: +/// +/// ```cpp +/// class MyPipeline : public Pipeline { +/// public: +/// void AddPasses() override { +/// AddPass(); +/// AddPass(); +/// } +/// }; +/// ``` +template +class Pipeline { + public: + struct PipelineEntry { + std::unique_ptr pass; + std::function enable_condition; + }; + + Pipeline() = default; + virtual ~Pipeline() = default; + virtual void AddPasses() = 0; + + /// Function to force the derived pipeline to implement the metadata + // method. + llvm::StringRef getArgument() const { return DerivedPipeline::GetArgument(); } + + llvm::StringRef getDescription() const { + return DerivedPipeline::GetDescription(); + } + + llvm::StringRef getName() const { return DerivedPipeline::GetName(); } + + void GetPipeline(mlir::OpPassManager &pm, + const DerivedPipelineOptions &options) { + for (auto &&entry : passes_) { + if (entry.enable_condition(options)) { + pm.addPass(std::move(entry.pass)); + } + } + }; + + protected: + void AddPass( + std::unique_ptr pass, + std::function enable_condition) { + passes_.push_back({std::move(pass), enable_condition}); + } + + template + friend class PipelinePass; + + std::vector GetPasses() { + std::vector passes; + passes.reserve(passes_.size()); + for (auto &&entry : passes_) { + passes.push_back(entry.pass.get()); + } + return passes; + } + + private: + std::vector passes_; +}; + +/// PipelinePass is a wrapper class to run a pipeline of passes as a single +/// pass. This is an implementation detail of the pipelines mechanism in TFL +/// Converter framework. Users should not need to interact with this class +/// directly. 
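Building on the `MyPipeline` example in the doc comment above, a concrete pipeline with a conditionally enabled pass might look like the sketch below. The options struct and its flag are hypothetical, the `Pipeline<Derived, DerivedOptions>` template arguments and the `AddPass` signature are inferred from `PipelineEntry` and `GetPipeline` above:

```cpp
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassOptions.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h"
#include "tensorflow/compiler/mlir/lite/transforms/optimize_pass.h"
#include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h"
#include "tensorflow/compiler/mlir/lite/transforms/pipeline.h"

// Hypothetical options for the example pipeline.
struct MyPipelineOptions : public mlir::detail::PassOptions {
  Option<bool> enable_batch_matmul{*this, "enable-batch-matmul",
                                   llvm::cl::init(true)};
};

class MyPipeline
    : public mlir::TFL::Pipeline<MyPipeline, MyPipelineOptions> {
 public:
  void AddPasses() override {
    // Always-on pass.
    AddPass(mlir::TFL::Create<mlir::TFL::OptimizePass>(),
            [](const MyPipelineOptions&) { return true; });
    // Pass gated on a pipeline option.
    AddPass(mlir::TFL::Create<mlir::TFL::OptimizeBatchMatmulPass>(),
            [](const MyPipelineOptions& o) { return o.enable_batch_matmul; });
  }
  static llvm::StringRef GetArgument() { return "my-example-pipeline"; }
  static llvm::StringRef GetDescription() { return "Example TFL pipeline."; }
  static llvm::StringRef GetName() { return "MyPipeline"; }
};
```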
+template +class PipelinePass + : public Pass, PipelineOptions> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PipelinePass); + + PipelinePass() { pipeline_->AddPasses(); }; + PipelinePass(const PipelinePass &) {}; + explicit PipelinePass(const PipelineOptions &options) + : Pass, PipelineOptions>( + options) { + pipeline_.AddPasses(); + }; + + std::unique_ptr<::mlir::Pass> clonePass() const override { + auto pass = std::make_unique>(); + pass->GetOptions().copyOptionValuesFrom(this->GetOptions()); + return std::move(pass); + } + + /// Function to satisfy the mlir::Pass interface + static llvm::StringRef GetArgument() { return Pipeline::GetArgument(); } + + static llvm::StringRef GetDescription() { return Pipeline::GetDescription(); } + + static llvm::StringRef GetName() { return Pipeline::GetName(); } + + void runOnOperation() final { + ModuleOp module_op = this->getOperation(); + + // Create a temporary OpPassManager to run the passes. Nesting is set to be + // implicit to allow for the nesting to happen under-the-hood. + OpPassManager pm(ModuleOp::getOperationName(), + OpPassManager::Nesting::Implicit); + pipeline_->GetPipeline(pm, this->GetOptions()); + if (failed(this->runPipeline(pm, module_op))) { + this->signalPassFailure(); + } + }; + + void ApplyOptionsVisitor(const PassOptionsSetter &visitor) final { + visitor.SetOptions(this->GetOptions()); + + for (auto &&pass : pipeline_->GetPasses()) { + if (auto *derived_pass = dynamic_cast(pass)) { + derived_pass->ApplyOptionsVisitor(visitor); + } + } + } + + private: + std::unique_ptr pipeline_ = std::make_unique(); +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PIPELINE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h new file mode 100644 index 00000000..824976e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -0,0 +1,668 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Transform pass for LSTMs. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/tools/optimize/operator_property.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/core/framework/types.pb.h" + +//===----------------------------------------------------------------------===// +// The prepare-quantize Pass for LSTM. +// +namespace mlir { +namespace TFL { + +constexpr double power_of_two_scale = 32768.0; + +// Same with the ordering of //tensorflow/compiler/mlir/lite/ir/tfl_ops.td +constexpr const char* intermediate_attributes[] = { + "input_to_input_intermediate", "input_to_forget_intermediate", + "input_to_cell_intermediate", "input_to_output_intermediate", + "effective_hidden_scale_intermediate"}; + +// Calculates the minimum power of two that is not less than the value. +double PowerOfTwoBound(double value); + +tensorflow::DataType GetQuantizedInferenceType(bool is_signed, + int activation_number_of_bits); + +// Returns the element type of LSTM's intermediate tensor designated by the +// index. 
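`PowerOfTwoBound` is only declared in this header; its definition lives in the corresponding .cc file, which is not part of this patch. As a rough illustration of the stated contract (the smallest power of two not less than the input), one possible implementation is sketched below. It is a sketch under that assumption, not necessarily the implementation TensorFlow ships, and the handling of non-positive inputs is invented for the example.

#include <cmath>

// Sketch only: smallest power of two that is >= value.
inline double PowerOfTwoBoundSketch(double value) {
  if (value <= 0.0) return 1.0;  // assumption: clamp non-positive input
  int exponent = 0;
  // value == mantissa * 2^exponent, with mantissa in [0.5, 1).
  const double mantissa = std::frexp(value, &exponent);
  // Exact powers of two (mantissa == 0.5) are returned unchanged; otherwise
  // round the exponent up to the next power of two.
  return (mantissa == 0.5) ? value : std::ldexp(1.0, exponent);
}

For example, `PowerOfTwoBoundSketch(900.0) == 1024.0`. Further down in this header, the 16-bit `extend_to_power_of_two` path divides such a bound by 32768 (the `power_of_two_scale` constant above), so a tensor whose largest magnitude is 900 ends up with a scale of 1024 / 32768 = 2^-5.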
+template +inline QuantizedType GetIntermediateElementType(LstmOp op, int tensor_index) { + if (tensor_index < 0 || tensor_index > 4) return nullptr; + TypeAttr attr = op->template getAttrOfType( + intermediate_attributes[tensor_index]); + if (!attr) { + return nullptr; + } + return QuantizedType::getQuantizedElementType(attr.getValue()); +} + +namespace operator_property = ::tflite::optimize::operator_property; +using Q = quantfork::QuantizeCastOp; +using DQ = quantfork::DequantizeCastOp; + +template +LogicalResult GetLstmProperty(LstmOp op, + operator_property::OpVariant* lstm_variant, + operator_property::OperatorProperty* op_property, + int activation_number_of_bits = 8) { + if (llvm::isa(op.getOperation())) { + lstm_variant->op_code = tflite::BuiltinOperator_LSTM; + } else if (llvm::isa(op.getOperation())) { + lstm_variant->op_code = + tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM; + } else { + op.emitError("ConvertLstmStatsToQDQs pass only supports LSTMs."); + return failure(); + } + lstm_variant->use_projection = + !mlir::isa(op.getProjectionWeights().getType()); + lstm_variant->use_peephole = + !mlir::isa(op.getCellToOutputWeights().getType()); + lstm_variant->use_layer_norm = + !mlir::isa(op.getForgetLayerNormCoefficients().getType()); + + *op_property = operator_property::GetOperatorProperty( + *lstm_variant, activation_number_of_bits); + + // TODO(b/176258587) move this to operator_property.cc if this is needed in + // other components, too. + bool use_cifg = mlir::isa(op.getInputToInputWeights().getType()); + if (use_cifg) { + const absl::flat_hash_set cifg_non_inputs = {1, 5, 9, 12, 20}; + const int cifg_non_intermediate = 0; + op_property->inputs.erase( + std::remove_if( + op_property->inputs.begin(), op_property->inputs.end(), + [&](std::pair input) { + return cifg_non_inputs.find(input.first) != cifg_non_inputs.end(); + }), + op_property->inputs.end()); + op_property->intermediates.erase( + std::remove_if(op_property->intermediates.begin(), + op_property->intermediates.end(), + [&](std::pair + intermediate) { + return intermediate.first == cifg_non_intermediate; + }), + op_property->intermediates.end()); + } + return success(); +} + +template +class PrepareLstmOutputScale : public OpRewritePattern { + public: + explicit PrepareLstmOutputScale(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(SourceOp op, + PatternRewriter& rewriter) const override { + operator_property::OpVariant lstm_variant; + operator_property::OperatorProperty lstm_property; + + if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) { + return failure(); + } + if (lstm_property.restrict_scale.size() != 1) { + op.emitError() << "The LSTM's operator property expects exactly one " + << "restrict scale requirement. Got " + << lstm_property.restrict_scale.size() + << " restrict scale requirements."; + return failure(); + } + + // Use same scale for input and output specified in restrict_scale. + const std::vector& tensors = lstm_property.restrict_scale[0]; + if (tensors.size() != 2) { + op.emitError( + "Unexpected restricted_scale from operator property." 
+ " Should only have a pair of indices."); + return failure(); + } + return processRestrictScale(op, tensors[0], tensors[1], rewriter); + } + + private: + // For LSTM's recurrent input activation and output, they are quantized with + // the collective range of both tensors, because theoretically the input + // activation value for the very first inference is not reflected in the + // output and the input activation is not captured. + LogicalResult processRestrictScale(SourceOp op, int input_index, + int output_index, + PatternRewriter& rewriter) const { + assert(output_index == 0); + if (!op.getResult().hasOneUse()) { + op.emitError() + << "output " << output_index + << " should have only one use, which should be quant.stats."; + return failure(); + } + + llvm::SmallVector stats_ops = { + llvm::dyn_cast_or_null( + op.getOperand(input_index).getDefiningOp()), + llvm::dyn_cast_or_null( + *op.getResult().getUsers().begin()), + }; + + if (!stats_ops[0] || !stats_ops[1]) { + return failure(); // Already converted to Q-DQ pair. + } + + llvm::SmallVector min_max_values; + + for (auto& stats_op : stats_ops) { + auto values = + mlir::dyn_cast(stats_op.getLayerStats()) + .getValues(); + min_max_values.insert(min_max_values.end(), values.begin(), values.end()); + } + + // min and max values of two stats are already the same. + if (min_max_values[0] == min_max_values[2] && + min_max_values[1] == min_max_values[3]) { + return failure(); + } + + mlir::ElementsAttr layer_stats = mlir::DenseFPElementsAttr::get( + mlir::RankedTensorType::get({2}, rewriter.getF32Type()), + {llvm::minimum(min_max_values[0], min_max_values[2]), + llvm::maximum(min_max_values[1], min_max_values[3])}); + mlir::ElementsAttr axis_stats; + mlir::IntegerAttr axis; + for (auto& stats_op : stats_ops) { + rewriter.setInsertionPointAfter(stats_op); + rewriter.replaceOpWithNewOp( + stats_op, stats_op.getArg(), layer_stats, axis_stats, axis); + } + return success(); + } +}; + +template +class ConvertOpStatsToQDQs : public OpRewritePattern { + public: + explicit ConvertOpStatsToQDQs(MLIRContext* context, + const quant::QuantizationSpecs& quant_specs, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + quant_specs_(quant_specs) {} + + protected: + quant::QuantizationSpecs quant_specs_; + + LogicalResult processInputs( + SourceOp op, const operator_property::OpVariant& op_variant, + const operator_property::OperatorProperty& op_property, + PatternRewriter& rewriter) const { + for (auto& enumerated_inputs : op_property.inputs) { + int index = enumerated_inputs.first; + auto& tensor_property = enumerated_inputs.second; + + Value input = op.getOperand(index); + + if (input.getDefiningOp() == nullptr) continue; + + // TODO(b/172517537): make this work with non-PTQ case. + if (llvm::isa( + input.getDefiningOp())) { + // Tensors with derived scale are biases, and handled in propagation. + if (tensor_property.use_derived_scale) continue; + // For weights, use quantization scale inferred from the values. 
+ if (failed(processConstantOp(op, input.getDefiningOp(), index, + tensor_property, rewriter))) { + return failure(); + } + } else { + if (auto stats_op = llvm::dyn_cast( + input.getDefiningOp())) { + if (failed(replaceStatsOp(op, stats_op, index, tensor_property, + rewriter))) { + return failure(); + } + } else if (!llvm::isa(input.getDefiningOp()) && + !llvm::isa( + input.getDefiningOp())) { + // Continue if StatisticsOp is already converted to Q-DQ pair, or + // stats op is not immediately available to the input because either + // it's connected to ops with same scale requirements or it has + // fixed output range. + // TODO(b/172517537): make this work with non-PTQ case. + return failure(); + } + } + } + return success(); + } + + LogicalResult processConstantOp( + SourceOp op, Operation* const_op, int input_index, + const operator_property::TensorProperty& tensor_property, + PatternRewriter& rewriter) const { + // Non-float tensors are neither weights nor require quantization. + auto type = mlir::dyn_cast(const_op->getResult(0).getType()); + if (!type || !mlir::isa(type.getElementType())) return success(); + + DenseFPElementsAttr attr; + if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) { + const_op->emitError("Not a constant op."); + return failure(); + } + + UniformQuantizedType quant_type = nullptr; + // When the number of bits is 10 (instead of 16), quantize the tensor to + // [-512, 512], instead of [-32767, 32767]. + // For now this behavior is specific for SVDF, where 6 bits are reserved for + // the reduce operation after element-wise multiplication between state and + // time weights. + if (tensor_property.number_of_bits == 10) { + SmallVector mins(1, std::numeric_limits::max()); + SmallVector maxs(1, std::numeric_limits::min()); + // Computes the effective min/max values of the attribute values. + quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1, + /*symmetric=*/true, mins, maxs); + double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits); + quant_type = UniformQuantizedType::getChecked( + const_op->getLoc(), quant::QuantizationFlags::Signed, + rewriter.getIntegerType(16), attr.getType().getElementType(), scale, + /*zeroPoint=*/0, llvm::minIntN(10), -llvm::minIntN(10)); + } else { + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/true, + /*num_bits=*/tensor_property.number_of_bits, + /*is_signed=*/true, + /*narrow_range=*/true, quant_specs_.legacy_float_scale)); + } + if (!quant_type) { + const_op->emitError("Failed to get quantized type"); + return failure(); + } + + // TODO(b/172517537): duplicate the constant when the bias is shared. + Type expressed_type = const_op->getResult(0).getType(); + Type cast_type = quant_type.castFromExpressedType(expressed_type); + rewriter.setInsertionPointAfter(const_op); + auto q = rewriter.create(const_op->getLoc(), cast_type, + const_op->getResult(0)); + auto dq = rewriter.create(const_op->getLoc(), expressed_type, q); + op.setOperand(input_index, dq.getResult()); + return success(); + } + + LogicalResult replaceStatsOp( + SourceOp op, quantfork::StatisticsOp stats_op, int input_index, + const operator_property::TensorProperty& tensor_property, + PatternRewriter& rewriter) const { + if (tensor_property.state_tensor && !stats_op.getResult().hasOneUse()) { + // TODO(b/172517537): check if other tensors should go through this + // check too. 
+ op.emitError() << "Input tensor [" << input_index + << "] is a state tensor, but has more than one use."; + return failure(); + } + auto stats = mlir::dyn_cast(stats_op.getLayerStats()); + if (!stats || stats.getNumElements() != 2) { + stats_op.emitError("Stats should have 2 values."); + return failure(); + } + quant::QuantizedType quant_type; + double min = FloatAttr::getValueAsDouble(stats.getValues()[0]); + double max = FloatAttr::getValueAsDouble(stats.getValues()[1]); + // Make sure the range includes zero. + min = std::min(min, 0.0); + max = std::max(max, 0.0); + Type expressed = getElementTypeOrSelf(stats_op.getType()); + + if (tensor_property.extend_to_power_of_two) { + if (tensor_property.number_of_bits != 16) { + op.emitError( + "extended power of 2 scale is only supported for 16-bit" + " quantization."); + return failure(); + } + + double bound = PowerOfTwoBound(std::max(std::abs(min), std::abs(max))); + // Set flags to 1 for signed type. + quant_type = UniformQuantizedType::getChecked( + op.getLoc(), quant::QuantizationFlags::Signed, + rewriter.getIntegerType(tensor_property.number_of_bits), expressed, + /*scale=*/bound / -llvm::minIntN(tensor_property.number_of_bits), + /*zeroPoint=*/0, llvm::minIntN(tensor_property.number_of_bits), + llvm::maxIntN(tensor_property.number_of_bits)); + } else { + // int16 uses range [-32767, 32767] + if (tensor_property.number_of_bits == 16) { + max = std::max(std::abs(min), std::abs(max)); + min = -max; + quant_type = quantfork::fakeQuantAttrsToType( + op.getLoc(), tensor_property.number_of_bits, min, max, + /*narrowRange=*/true, expressed, + /*isSigned=*/true); + } else { + quant_type = quantfork::fakeQuantAttrsToType( + op.getLoc(), tensor_property.number_of_bits, min, max, + /*narrowRange=*/false, expressed, + /*isSigned=*/true); + } + if (quant_specs_.legacy_float_scale) { + quant_type = quant::DownCastScale(quant_type, min, max, op.getLoc()); + } + } + rewriter.setInsertionPointAfter(stats_op); + Type result_type = quant_type.castFromExpressedType(stats_op.getType()); + auto q = + rewriter.create(stats_op.getLoc(), result_type, stats_op.getArg()); + rewriter.replaceOpWithNewOp(stats_op, stats_op.getType(), q); + return success(); + } +}; + +// Quantize LSTM according to its quantization recipe. +template +class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { + public: + ConvertLstmStatsToQDQs(MLIRContext* context, + const quant::QuantizationSpecs& quant_specs) + : ConvertOpStatsToQDQs(context, quant_specs), + activation_number_of_bits_(quant_specs.GetQuantizationTypeWidth()) {} + LogicalResult matchAndRewrite(SourceOp op, + PatternRewriter& rewriter) const override { + operator_property::OpVariant lstm_variant; + operator_property::OperatorProperty lstm_property; + if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property, + activation_number_of_bits_))) { + return failure(); + } + + if (failed(processIntermediates(op, lstm_variant, lstm_property)) || + failed(ConvertOpStatsToQDQs::processInputs( + op, lstm_variant, lstm_property, rewriter))) { + return failure(); + } + + return success(); + } + + private: + LogicalResult processIntermediates( + SourceOp op, const operator_property::OpVariant& lstm_variant, + const operator_property::OperatorProperty& lstm_property) const { + for (auto& enumerated_intermediates : lstm_property.intermediates) { + int index = enumerated_intermediates.first; + auto& tensor_property = enumerated_intermediates.second; + // intermediate tensors 0, 1, 2, 3 are only used with layer normalization. 
+ if (!lstm_variant.use_layer_norm && index != 4) { + continue; + } + + TypeAttr attr = + op->template getAttrOfType(intermediate_attributes[index]); + auto quant_type = GetIntermediateElementType(op, index); + if (!quant_type) { + // intermediate tensor 4 is optional, unless the LSTM uses projection. + if (index == 4 && !lstm_variant.use_projection) { + return success(); + } + op.emitError() << intermediate_attributes[index] + << " is not quantized."; + return failure(); + } + auto calibrated_type = + mlir::dyn_cast(quant_type); + if (!calibrated_type) { + int num_storage_bits = quant_type.getStorageTypeIntegralWidth(); + if (tensor_property.number_of_bits != num_storage_bits) { + op.emitError() << intermediate_attributes[index] + << " is expected to be quantized with " + << tensor_property.number_of_bits << " bits, but got " + << num_storage_bits << " bits instead."; + return failure(); + } + continue; // skip if it is already quantized. + } + quant::UniformQuantizedType qtype; + if (tensor_property.number_of_bits == 8) { + qtype = quantfork::fakeQuantAttrsToType( + op.getLoc(), tensor_property.number_of_bits, + calibrated_type.getMin(), calibrated_type.getMax(), + /*narrowRange=*/false, calibrated_type.getExpressedType(), + /*isSigned=*/this->quant_specs_.IsSignedInferenceType()); + if (this->quant_specs_.legacy_float_scale) { + qtype = mlir::cast( + quant::DownCastScale(qtype, calibrated_type.getMin(), + calibrated_type.getMax(), op.getLoc())); + } + } else if (tensor_property.number_of_bits == 16) { + double max = std::max(std::abs(calibrated_type.getMin()), + std::abs(calibrated_type.getMax())); + qtype = quantfork::fakeQuantAttrsToType( + op.getLoc(), tensor_property.number_of_bits, -max, max, + /*narrowRange=*/true, calibrated_type.getExpressedType(), + /*isSigned=*/true); + } else { + op.emitError() << "Unsupported quantization bits: " + << tensor_property.number_of_bits; + return failure(); + } + op->setAttr(intermediate_attributes[index], + TypeAttr::get(qtype.castFromExpressedType( + qtype.castToExpressedType(attr.getValue())))); + } + return success(); + } + + int activation_number_of_bits_; +}; + +// Returns a function that returns the quantized type of a bias input. +// The scale of bias is a multiplication of given scale and scales from the +// quantization type of other operands. +inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( + double scale) { + return [=](const std::vector& quant_params, + const int adjusted_quant_dim, + const bool legacy_float_scale) -> quant::QuantParams { + if (auto qtype = mlir::dyn_cast_or_null( + quant::GetUniformQuantizedTypeForBias( + quant_params, legacy_float_scale, adjusted_quant_dim))) { + return quant::UniformQuantizedType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + qtype.getScale() * scale, qtype.getZeroPoint(), + qtype.getStorageTypeMin(), qtype.getStorageTypeMax()); + } + return {}; + }; +} + +// Returns quantization spec for LSTMs based on their operator properties. 
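To make the bias-scale rule above concrete, the snippet below walks through the arithmetic with made-up numbers. It only mirrors the math; `GetUniformQuantizedTypeForBiasWithScale` simply rescales whatever `quant::GetUniformQuantizedTypeForBias` derives from the other operands.

// Illustrative arithmetic only; all values below are hypothetical.
inline double ExampleBiasScaleSketch() {
  const double activation_scale = 0.05;  // scale of a quantized input operand
  const double weight_scale = 0.002;     // scale of a quantized weight operand
  // quant::GetUniformQuantizedTypeForBias conventionally derives the bias
  // scale from the product of the operand scales:
  const double base_bias_scale = activation_scale * weight_scale;  // 1.0e-4
  // The callback returned by GetUniformQuantizedTypeForBiasWithScale(extra)
  // multiplies that derived scale by the extra factor, e.g. an intermediate
  // tensor's power-of-two scale:
  const double extra = 1.0 / 32768.0;
  return base_bias_scale * extra;  // roughly 3.05e-9
}

`GetLstmOpQuantSpec`, declared next, uses exactly this mechanism: it multiplies the scales of the intermediate tensors named in `derived_scale` with the listed factors and installs the result as the bias quantization callback.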
+template +std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { + operator_property::OpVariant lstm_variant; + operator_property::OperatorProperty lstm_property; + if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) { + return nullptr; + } + + auto spec = std::make_unique(); + + for (const auto& enumerated_inputs : lstm_property.inputs) { + int index = enumerated_inputs.first; + auto& tensor_property = enumerated_inputs.second; + if (tensor_property.use_derived_scale) { + double scale = 1.0; + for (int tensor_index : + tensor_property.derived_scale.intermediate_tensors) { + auto quant_type = GetIntermediateElementType(op, tensor_index); + if (!quant_type || + !mlir::isa(quant_type)) { + op->emitError() << "While processing derived scale, intermediate " + << intermediate_attributes[tensor_index] + << " is not quantized."; + return nullptr; + } + scale *= + mlir::dyn_cast(quant_type).getScale(); + } + for (float factor : tensor_property.derived_scale.factors) { + scale *= factor; + } + spec->biases_params.emplace( + index, + std::make_pair(tensor_property.derived_scale.input_tensors, + GetUniformQuantizedTypeForBiasWithScale(scale))); + } + } + return spec; +} + +class ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs { + public: + explicit ConvertSvdfStatsToQDQs( + MLIRContext* context, const quant::QuantizationSpecs& quant_specs_param) + : ConvertOpStatsToQDQs(context, quant_specs_param) {} + LogicalResult matchAndRewrite(TFL::SVDFOp op, + PatternRewriter& rewriter) const override { + operator_property::OpVariant op_variant; + op_variant.op_code = tflite::BuiltinOperator_SVDF; + auto op_property = operator_property::GetOperatorProperty(op_variant); + return ConvertOpStatsToQDQs::processInputs( + op, op_variant, op_property, rewriter); + } +}; + +class PropagateTransposedPerAxisQuantDim + : public OpRewritePattern { + public: + explicit PropagateTransposedPerAxisQuantDim(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op, + PatternRewriter& rewriter) const override { + // Check if the quantization is per-axis + auto dq_op = dyn_cast_or_null( + transpose_op.getOperand(0).getDefiningOp()); + if (!dq_op) return failure(); + auto q_op = dyn_cast_or_null( + dq_op.getOperand().getDefiningOp()); + if (!q_op) return failure(); + auto qtype = + mlir::cast(dq_op.getArg().getType()).getElementType(); + auto aqtype = dyn_cast_or_null(qtype); + if (!aqtype) return failure(); + + // Return if the result of TransposeOp is already quantized + if (!transpose_op.getResult().hasOneUse()) return failure(); + auto next_op = *transpose_op.getResult().getUsers().begin(); + if (dyn_cast_or_null(next_op)) return failure(); + + auto input_type = mlir::cast(transpose_op.getInput().getType()); + auto perm_type = mlir::cast(transpose_op.getPerm().getType()); + if (input_type.hasStaticShape() && perm_type.hasStaticShape()) { + if (perm_type.getNumElements() != input_type.getRank()) { + return transpose_op.emitOpError( + "perm tensor elements size is not equal to input tensor rank"); + } + } + + // Get permutation axes of the TransposeOp + DenseIntElementsAttr perm; + if (!matchPattern(transpose_op.getPerm(), m_Constant(&perm))) { + return failure(); + } + + SmallVector axes; + for (const auto& axis_int : perm.getValues()) { + int64_t axis = axis_int.getSExtValue(); + if (axis < 0) { + axis += input_type.getRank(); + } + if (axis < 0 || (input_type.hasRank() && axis >= input_type.getRank())) { + return transpose_op.emitOpError("perm must be in 
[-rank, rank)"); + } + if (std::count(axes.begin(), axes.end(), axis) > 0) { + return transpose_op.emitOpError("perm cannot have duplicated axis"); + } + axes.push_back(axis); + } + + // Find what the quantized dimension has been transposed to + int new_out_quant_dim = -1; + for (int i = 0; i < axes.size(); ++i) { + if (axes[i] == aqtype.getQuantizedDimension()) { + new_out_quant_dim = i; + break; + } + } + if (new_out_quant_dim == -1) { + return transpose_op.emitOpError( + "new quantization dimension not found in perm"); + } + + // Insert a QDQ pair with the new quantized dimension after TransposeOp + auto new_qtype = quant::CreateI8F32UniformQuantizedPerAxisType( + transpose_op.getLoc(), *rewriter.getContext(), aqtype.getScales(), + aqtype.getZeroPoints(), new_out_quant_dim, /*narrow_range=*/true); + auto new_tensor_type = RankedTensorType::getChecked( + transpose_op.getLoc(), transpose_op.getType().getShape(), new_qtype); + rewriter.setInsertionPointAfter(transpose_op); + auto new_q_op = rewriter.create( + transpose_op.getLoc(), new_tensor_type, q_op.getArg()); + auto new_dq_op = rewriter.create( + new_q_op.getLoc(), transpose_op.getResult().getType(), + new_q_op.getResult()); + transpose_op.getResult().replaceAllUsesWith(new_dq_op.getResult()); + new_q_op.setOperand(transpose_op.getResult()); + + return success(); + } +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PREPARE_QUANTIZE_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise_pass.h new file mode 100644 index 00000000..41114864 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise_pass.h @@ -0,0 +1,65 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PUSH_TRANSPOSE_THROUGH_EWISE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PUSH_TRANSPOSE_THROUGH_EWISE_PASS_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" + +namespace mlir { +namespace TFL { + +class PushTransposeThroughEwisePass + : public Pass { + public: + PushTransposeThroughEwisePass() = default; + PushTransposeThroughEwisePass(const PushTransposeThroughEwisePass &other) {} + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. 
+ static llvm::StringRef GetArgument() { + return "tfl-push-transpose-through-ewise"; + } + + static llvm::StringRef GetDescription() { + return "Push transpose ops through element-wise ops."; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "PushTransposeThroughEwisePass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + /// Explicitly declare the TypeID for this class. We declare an explicit + /// private instantiation because Pass classes should only be visible by the + /// current library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PushTransposeThroughEwisePass) +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PUSH_TRANSPOSE_THROUGH_EWISE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/analyze_variables_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/analyze_variables_pass.h new file mode 100644 index 00000000..8d5914d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/analyze_variables_pass.h @@ -0,0 +1,55 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_ANALYZE_VARIABLES_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_ANALYZE_VARIABLES_PASS_H_ + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" + +namespace mlir { +namespace TFL { + +// Pass which analyzes the variables in the graph and add an attribute whether +// variables should be legalized to TFLite native ones. +// This pass needs to run post TF->TFL legalization and before variable +// legalization. 
+ +class AnalyzeVariablesPass : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnalyzeVariablesPass) + + AnalyzeVariablesPass() = default; + AnalyzeVariablesPass(const AnalyzeVariablesPass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "AnalyzeVariablesPass"; } + static llvm::StringRef GetArgument() { return "tfl-analyze-variables-pass"; } + static llvm::StringRef GetDescription() { + return "Pass to analyze variables in the graph"; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_ANALYZE_VARIABLES_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.h new file mode 100644 index 00000000..8eb9f728 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_LEGALIZE_TENSORLIST_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_LEGALIZE_TENSORLIST_PASS_H_ + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" + +namespace mlir { +namespace TFL { + +// Pass to Legalize TensorFlow tensorlist ops to TensorFlow Lite custom. 
+ +class LegalizeTensorListPass : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeTensorListPass) + + LegalizeTensorListPass() = default; + LegalizeTensorListPass(const LegalizeTensorListPass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "LegalizeTensorListPass"; } + static llvm::StringRef GetArgument() { return "tfl-legalize-tensorlist"; } + static llvm::StringRef GetDescription() { + return "Pass to Legalize TensorFlow tensorlist ops to TensorFlow Lite " + "custom."; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_LEGALIZE_TENSORLIST_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.h new file mode 100644 index 00000000..6c114ced --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/while_loop_outline_pass.h @@ -0,0 +1,66 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_WHILE_LOOP_OUTLINE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_WHILE_LOOP_OUTLINE_PASS_H_ + +#include + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" + +namespace mlir { +namespace TFL { + +// Pass to hoist while op regions into functions. +// This pass outlines the cond/body region of the TFL WhileOp into functions and +// replaces the regions with calls to these outlined functions. +class WhileOutlinePass : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WhileOutlinePass) + + WhileOutlinePass() = default; + WhileOutlinePass(const WhileOutlinePass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "WhileOutlinePass"; } + static llvm::StringRef GetArgument() { return "tfl-while-loop-outline"; } + static llvm::StringRef GetDescription() { + return "Pass to hoist while op regions into functions"; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + // Outlines the regions of the WhileOp's cond and body and insert function + // calls instead, + void OutlineWhile(WhileOp while_op); + + // Get unique name by using the loc to name mapping. 
+ std::string GetName(Operation* op, StringRef suffix); + + tensorflow::OpOrArgLocNameMapper mapper_; +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TF_LEGALIZATIONS_WHILE_LOOP_OUTLINE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tflite_passes/split_merged_operands_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tflite_passes/split_merged_operands_pass.h new file mode 100644 index 00000000..54be99ff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tflite_passes/split_merged_operands_pass.h @@ -0,0 +1,89 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_SPLIT_MERGED_OPERANDS_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_SPLIT_MERGED_OPERANDS_PASS_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" + +namespace mlir { +namespace TFL { + +// Background info: +// Currently the model taken to MLIRConverter is frozen (all the variables have +// been converted to constants, all the assign ops are gone, etc.). However, +// TFLite has these variable tensors semantics. So the variable mapping from TF +// to TFLite is actually broken here, we sort of hard-code the variable tensors +// based on the actual ops using them, such as unidirectional_sequence_lstm. +// +// MLIRConverter also benefits from lots of typical compiler optimization like +// merging same input values if they're identical. These optimizations are +// desirable but not for those TFLite ops which have variable tensors as inputs. +// Yes, they have identical input values, but those identical values are +// "stateful", their values can change during invocations. +// +// A typical example is unidirectional_sequence_lstm have two variable tensor +// inputs: activation state & cell state. They may have same initial values +// (typical zero-initialized), but their values will be changed. So we cannot +// just merge those values. +// +// This pass is more like short-term workaround since we don't have a good +// variable representation right now. +// +// This pass will duplicate input values for those variable tensor inputs. 
+ +class SplitMergedOperandsPass + : public TFL::Pass { + public: + SplitMergedOperandsPass() = default; + SplitMergedOperandsPass(const SplitMergedOperandsPass &other) {} + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. + static llvm::StringRef GetArgument() { return "tfl-split-merged-operands"; } + + static llvm::StringRef GetDescription() { + return "Split merged stateful operands for tfl operations."; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "SplitMergedOperandsPass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + /// Explicitly declare the TypeID for this class. We declare an explicit + /// private instantiation because Pass classes should only be visible by the + /// current library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SplitMergedOperandsPass) +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_SPLIT_MERGED_OPERANDS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.h new file mode 100644 index 00000000..18ee20ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.h @@ -0,0 +1,54 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_UNFOLD_LARGE_SPLAT_CONSTANTS_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_UNFOLD_LARGE_SPLAT_CONSTANTS_PASS_H_ + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" + +namespace mlir { +namespace TFL { + +// Pass to unfold large splat constant tensors. +// This Pass will replace large splat constant tensors to `tfl.Fill` op to +// reduce the size of the generated flatbuffer model size. 
+class UnfoldLargeSplatConstantPass + : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnfoldLargeSplatConstantPass) + + UnfoldLargeSplatConstantPass() = default; + UnfoldLargeSplatConstantPass(const UnfoldLargeSplatConstantPass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "UnfoldLargeSplatConstantPass"; } + static llvm::StringRef GetArgument() { return "unfold-large-splat-constant"; } + static llvm::StringRef GetDescription() { + return "Pass to unfold large splat constant tensors."; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_UNFOLD_LARGE_SPLAT_CONSTANTS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.h new file mode 100644 index 00000000..f79c9ccf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.h @@ -0,0 +1,62 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNFREEZE_GLOBAL_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNFREEZE_GLOBAL_CONSTANTS_H_ + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace TFL { + +// This pass "unfreezes" the use of global constant tensor ops found in the +// module and converts them to `tf.VarHandleOp`s. Also, an initialization +// pattern `tf.AssignVariableOp(tf.VarHandleOp, tf.ConstOp)` is inserted to the +// initializer function of type "init_op" for each of the unfrozen constants. 
+ +class UnfreezeMutableGlobalTensorsPass + : public Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnfreezeMutableGlobalTensorsPass) + + UnfreezeMutableGlobalTensorsPass() = default; + UnfreezeMutableGlobalTensorsPass(const UnfreezeMutableGlobalTensorsPass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { + return "UnfreezeMutableGlobalTensorsPass"; + } + static llvm::StringRef GetArgument() { + return "unfreeze-mutable-global-tensors"; + } + static llvm::StringRef GetDescription() { + return "Pass to unfreeze mutable global tensor ops"; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_UNFREEZE_GLOBAL_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline.h new file mode 100644 index 00000000..cfd5a1c3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_VARIABLE_FREEZING_PIPELINE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_VARIABLE_FREEZING_PIPELINE_H_ + +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/pipeline.h" +#include "tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline_options.h" + +namespace mlir { +namespace TFL { + +class VariableFreezingPipeline + : public Pipeline { + public: + void AddPasses() override; + + /// Returns the command-line argument attached to this pass. + static llvm::StringRef GetArgument() { + return "tfl-variable-freezing-pipeline"; + } + + static llvm::StringRef GetDescription() { + return "Variable Freezing Pipeline"; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "VariableFreezingPipeline"; } +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_VARIABLE_FREEZING_PIPELINE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline_options.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline_options.h new file mode 100644 index 00000000..d7e9ed8d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline_options.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_VARIABLE_FREEZING_PIPELINE_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_VARIABLE_FREEZING_PIPELINE_OPTIONS_H_ + +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +//////////////////////////////////////////////////////////////////////////////// +// Pass Options +//////////////////////////////////////////////////////////////////////////////// + +struct VariableFreezingPipelineOptions : public mlir::detail::PassOptions { + mlir::detail::PassOptions::Option enable_tflite_variables{ + *this, "enable_tflite_variables", + llvm::cl::desc("Enable Mutable Variables in TFLite")}; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_VARIABLE_FREEZING_PIPELINE_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h new file mode 100644 index 00000000..c851d73b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ARITHMETIC_COUNT_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ARITHMETIC_COUNT_UTIL_H_ + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// For add/mul/div/sub and other broadcastable ops. 
+class ArithmeticCountUtilHelper { + public: + static bool GetFirstOutputCount(mlir::Operation* op, int64_t* count) { + auto output = op->getResult(0); + auto output_type = + mlir::dyn_cast_or_null(output.getType()); + if (!output_type || !output_type.hasStaticShape()) return false; + + *count = output_type.getNumElements(); + return true; + } + + static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) { + int64_t total_count = 0; + for (auto input : op->getOperands()) { + auto input_type = + mlir::dyn_cast_or_null(input.getType()); + if (!input_type || !input_type.hasStaticShape()) { + return false; + } + total_count += input_type.getNumElements(); + } + *count = total_count; + return true; + } + + // For conv2d/depthwise_conv/fully_connected ops. + // This algorithm actually comes from TOCO tooling_util.cc + static bool GetArithmeticCountForConvAndFullyconnectedOp(mlir::Operation* op, + int64_t* count) { + auto weight = op->getOperand(1); + auto weight_type = + mlir::dyn_cast_or_null(weight.getType()); + if (weight_type == nullptr || !weight_type.hasStaticShape()) return false; + + auto output = op->getResult(0); + auto output_type = + mlir::dyn_cast_or_null(output.getType()); + if (output_type == nullptr || !output_type.hasStaticShape()) return false; + + int64_t cols = 1; + for (int i = 0; i < output_type.getRank() - 1; ++i) { + cols *= output_type.getDimSize(i); + } + const int64_t cost_per_col = 2 * weight_type.getNumElements(); + + *count = cost_per_col * cols; + + auto bias = op->getOperand(2); + if (bias) { + auto bias_type = + mlir::dyn_cast_or_null(bias.getType()); + if (bias_type && bias_type.hasStaticShape()) { + *count += output_type.getNumElements(); + } + } + + return true; + } +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ARITHMETIC_COUNT_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/attribute_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/attribute_utils.h new file mode 100644 index 00000000..565b71c2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/attribute_utils.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// Returns true if none of the three attributes are empty. +inline bool HasAll3Attrs(Attribute a, Attribute b, Attribute c) { + return a != Attribute() && b != Attribute() && c != Attribute(); +} + +// Returns the single float element from an ElementsAttr. 
Returns empty +// attribute if the number of elements in the attribute is not 1 or the +// element isn't a float attribute. +FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr); + +// Returns the single float element if the input is an ElementsAttr, or return +// itself as a float element. Returns empty attribute if the number of elements +// in the attribute is not 1, the element or itself isn't a float attribute. +FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr); + +// Returns the single integer element from an ElementsAttr. Returns empty +// attribute if the number of elements in the attribute is not 1 or the +// element isn't a integer attribute. +IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr); + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h new file mode 100644 index 00000000..477c5c67 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h @@ -0,0 +1,111 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONST_TENSOR_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONST_TENSOR_UTILS_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/meta/type_traits.h" +#include "absl/status/statusor.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" + +namespace mlir { +namespace TFL { + +bool IsQuantized(const tflite::TensorT& tensor); + +absl::StatusOr GetQuantizedType( + const tflite::TensorT& tensor, mlir::Builder builder, + bool is_constant = false, mlir::Type storage_type = {}); + +// Imports float tensor with calibration value into calibrated quantized type. +absl::StatusOr GetCalibratedQuantizedType( + const tflite::TensorT& tensor, mlir::Builder builder); + +absl::StatusOr GetTensorType(const tflite::TensorT& tensor, + mlir::Builder builder, + bool is_constant = false, + bool is_intermediate = false, + bool get_storage = false); + +// Gets a constant splat for the given value of type. Requires value to be of +// type static shaped RankedTensorType. `unique_index` is used to get the unique +// value for the attribute. 
+mlir::ElementsAttr GetSplat(mlir::RankedTensorType type, int unique_index, + mlir::Builder builder); + +absl::StatusOr ConvertIntBuffer( + mlir::RankedTensorType shaped_type, const std::vector& buffer, + bool truncate = false); + +absl::StatusOr ConvertFloatBuffer( + mlir::RankedTensorType shaped_type, const std::vector& buffer); + +tensorflow::TensorProto ConvertTfliteConstTensor( + const tflite::TensorT& tensor, const std::vector& buffer); + +// Get the size of the type in bits. The type can be ComplexType, FloatType, +// IntegerType, QuantizedType, or ShapeType of other supported types. +// +// Sub-byte types, e.g. qu4 and i2, are treated as a full i8. +int64_t GetSizeInBits(mlir::ShapedType shaped_type); +int64_t GetSizeInBits(mlir::Type type); +int64_t GetSizeInBits(mlir::quant::QuantizedType quant_type); + +// Get the size of the type in bytes. +// +// Sub-byte element types, e.g. qu4 and i2, are treated as a full i8. +// e.g. GetSizeInBytes(tensor<4xi2>) == 4, instead of 1. +int64_t GetSizeInBytes(mlir::Type type); + +// Performs an integer divide and checks that the remainder is zero. +// It supports int64 version as well. +template ::value || + std::is_same::value || + std::is_same::value || + std::is_same::value>> +ABSL_ATTRIBUTE_ALWAYS_INLINE Integer ExactIntegerDivide(Integer numerator, + int64_t denominator) { + const Integer ratio = numerator / denominator; + assert((numerator % denominator) == 0); + return ratio; +} + +template ::value, int> = 0> +ABSL_ATTRIBUTE_ALWAYS_INLINE bool IsPowerOfTwo(IntType n) { + static_assert(std::is_integral::value, ""); + return n > 0 && (n & (n - 1)) == 0; +} + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONST_TENSOR_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/constant_utils.h new file mode 100644 index 00000000..1340aa0f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/constant_utils.h @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ + +#include "absl/status/statusor.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "tsl/platform/statusor.h" + +namespace mlir { +namespace TFL { + +// Returns a Constant op with a single value. 
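The size helpers above round sub-byte element types up to a full byte. A minimal standalone sketch of that rule (the names and signatures below are illustrative, not the declarations in this header):

#include <cassert>
#include <cstdint>

// Sub-byte element types (e.g. 4-bit or 2-bit) are charged a full byte per
// element, so a 4-element 2-bit tensor is reported as 4 bytes rather than 1.
int64_t IllustrativeElementSizeInBytes(int64_t element_bit_width) {
  assert(element_bit_width > 0);
  return element_bit_width < 8 ? 1 : element_bit_width / 8;
}

int64_t IllustrativeTensorSizeInBytes(int64_t num_elements,
                                      int64_t element_bit_width) {
  return num_elements * IllustrativeElementSizeInBytes(element_bit_width);
}
// IllustrativeTensorSizeInBytes(/*num_elements=*/4, /*element_bit_width=*/2) == 4.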
+absl::StatusOr CreateConstOpWithSingleValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); + +// Returns a Constant op with a splat vector value. +absl::StatusOr CreateConstOpWithVectorValue( + PatternRewriter* rewriter, Location loc, ShapedType shaped_type, int value); + +} // namespace TFL +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/control_edges.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/control_edges.h new file mode 100644 index 00000000..e5a16ba7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/control_edges.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONTROL_EDGES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONTROL_EDGES_H_ + +#include +#include +#include + +namespace tflite { + +// LINT.IfChange + +using ControlEdge = std::pair; +using ControlEdges = std::vector; + +// LINT.ThenChange(//tensorflow/lite/graph_info.h) + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONTROL_EDGES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/convert_type.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/convert_type.h new file mode 100644 index 00000000..118f9cd4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/convert_type.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ + +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +class Builder; +} // namespace mlir + +namespace tflite { +// Convert the MLIR type to the corresponding TFLite tensor. 
+tflite::TensorType ConvertTypeToTensorType(mlir::Type type); + +// Convert the scalar type of a TFlite tensor to the corresponding MLIR type. +mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder); + +// Convert the scalar type of a TFLite tensor to the corresponding +// Tensorflow type +tensorflow::DataType TflTypeToTfType(tflite::TensorType type); + +// Convert the Tensorflow scalar type to the corresponding TFLite type +absl::StatusOr TfTypeToTflType(tensorflow::DataType type); + +// Returns element type from attribute Type 'type_attr'. +mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr); + +// Returns true if 'val' is not from Quantize op or +// from Quantize Op with same quant type as 'qtype_attr' +bool NotFromQuantOpOrSameQuantType(mlir::Value val, mlir::TypeAttr qtype_attr); + +} // namespace tflite +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h new file mode 100644 index 00000000..146cae1f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h @@ -0,0 +1,177 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with tf.FakeQuant* ops. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +namespace mlir { +namespace TFL { + +template +struct FetchMinMaxAttrs { + using AttrType = FloatAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + min_value = tf_op.getMinAttr(); + max_value = tf_op.getMaxAttr(); + return true; // Successfully matched and fetched. 
+ } +}; + +template +struct FetchConstantMinMaxInputs { + using AttrType = DenseFPElementsAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + Value min = tf_op.getMin(), max = tf_op.getMax(); + if (!matchPattern(min, m_Constant(&min_value))) { + return false; + } + if (!matchPattern(max, m_Constant(&max_value))) { + return false; + } + return true; // Successfully matched and fetched. + } +}; + +// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the +// tf.FakeQyantWithMinMax{Vars|VarsPerChannel|Args}Op +// before the op being constant folded. Since the constant +// folding logic will use a "arith.constant" op to replace the +// "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve +// the quantization parameters as a TypeAttr and "tfl.dequantize" op used to +// convert the output type to the next op. Here are the transformations: +// +// input min cst max cst input min cst max cst +// \ | | \ | | +// \ (tf.Identity) (tf.Identity) => \ (tf.Identity) (tf.Identity) +// \ | | \ | | +// tf.FakeQuantWithMinMaxVars tf.FakeQuantWithMinMaxVars +// | | +// tfl.quantize +// | +// tfl.dequantize +// | +// If the input is a constant, the result pattern will eventually converted to +// +// quant-emulated input +// | +// tfl.quantize +// | +// tfl.dequantize +// | +// +// +// Warns if the (most likely unwanted, currently not quite correctly handled) +// case of back-to-back tf.FakeQuant occurs +// +// tf.FakeQuant* +// | +// tf.FakeQuant* +// +template +class InsertTFLQuantOpsAfterTFFakeQuantOp { + public: + explicit InsertTFLQuantOpsAfterTFFakeQuantOp(bool use_fake_quant_num_bits) + : use_fake_quant_num_bits_(use_fake_quant_num_bits) {} + + FetchMinMax fetch_min_max_; + + using FetchAttrType = typename FetchMinMax::AttrType; + LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, + OpBuilder &rewriter) const { + // We don't want to insert quantize/dequantize if the quantize op exists. + auto res = tf_op.getOutputs(); + if (!res.hasOneUse() || isa(*res.user_begin())) { + return failure(); + } + + // Extract the min/max constant values from the operands. We also consider + // a special case that there are tf.Identity ops between the min/max + // constants and the tf.FakeQuantWithMinMaxVarsOp. + + FetchAttrType min_value, max_value; + if (!fetch_min_max_(tf_op, min_value, max_value)) { + return failure(); + } + + int quant_dim = -1; + if (PerAxis) { + // This is a special case that the quant_dim is the last dimensions. + quant_dim = mlir::cast(res.getType()).getRank() - 1; + } + // Use the min/max from the operands and the num_bits and narrow_range + // attribute to create the quantization parameter for the new quantize op. + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); + Type res_type = tf_op.getType(); + TypeAttr qtype = quant::GetQuantizedTypeAttr( + rewriter, res_type, min_value, max_value, quant_dim, num_bits, + narrow_range, /*is_signed=*/false, /*legacy_float_scale=*/false, + use_fake_quant_num_bits_); + if (!qtype) { + return failure(); + } + + // Finally, use the quantization parameter to create the quantize and + // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp + // and its users. 
+ Value value = tf_op.getOutputs(); + auto quantize = rewriter.create( + tf_op.getLoc(), qtype.getValue(), value, qtype); + auto dequantize = rewriter.create( + tf_op.getLoc(), res_type, quantize.getOutput()); + value.replaceAllUsesWith(dequantize); + quantize.getOperation()->replaceUsesOfWith(dequantize, value); + + return success(); + } + + bool use_fake_quant_num_bits_; +}; + +// Removes the wrapper of the tf.FakeQuant* ops and creates the tfl.quantize +// and tfl.dequantize pairs before tf.FakeQuant* being foled. +LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext *ctx, + bool use_fake_quant_num_bits = false); + +// Returns the names of all the considered tf.FakeQuant* ops. +std::vector AllTfFakeQuantOps(); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_FAKE_QUANT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/low_bit_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/low_bit_utils.h new file mode 100644 index 00000000..fa9bd851 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/low_bit_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LOW_BIT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LOW_BIT_UTILS_H_ + +#include +#include + +namespace tflite { +// Assumes that `src_tensor` is a buffer where each element is a 4-bit value +// stored in 8-bit. +// Returns a new buffer that is packed densely with 2 4-bit values in a byte. +// The packing format is low-bits-first, i.e. the lower nibble of a byte is +// filled first, followed by the upper nibble. +std::vector PackInt4ValuesDensely(std::vector src_buffer); + +// Assumes `src_buffer` contains 2 4-bit elements packed in 8-bit. +// Returns a vector where each int8 element contains a int4 sign-extended value. +std::vector UnpackDenseInt4IntoInt8( + const std::vector& src_buffer, int64_t num_elements); +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LOW_BIT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/lstm_utils.h new file mode 100644 index 00000000..8d9a5ab1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -0,0 +1,224 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
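As a standalone illustration of the low-bits-first packing documented in low_bit_utils.h above, the sketch below packs and unpacks 4-bit values using plain byte vectors; the function names are illustrative and are not the declarations in that header.

#include <cstdint>
#include <vector>

// Packs one 4-bit value per nibble, low nibble first: element 2*i lands in the
// low nibble of byte i, element 2*i+1 in the high nibble.
std::vector<uint8_t> PackInt4LowBitsFirst(const std::vector<int8_t>& values) {
  std::vector<uint8_t> packed((values.size() + 1) / 2, 0);
  for (size_t i = 0; i < values.size(); ++i) {
    const uint8_t nibble = static_cast<uint8_t>(values[i]) & 0x0F;
    packed[i / 2] |= (i % 2 == 0) ? nibble : static_cast<uint8_t>(nibble << 4);
  }
  return packed;
}

// Reverses the packing above, sign-extending each nibble back to int8.
std::vector<int8_t> UnpackInt4LowBitsFirst(const std::vector<uint8_t>& packed,
                                           int64_t num_elements) {
  std::vector<int8_t> out(num_elements);
  for (int64_t i = 0; i < num_elements; ++i) {
    const uint8_t byte = packed[i / 2];
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    // Shift the nibble into the high bits and arithmetic-shift back down to
    // sign-extend the 4-bit value.
    out[i] = static_cast<int8_t>(static_cast<int8_t>(nibble << 4) >> 4);
  }
  return out;
}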
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" + +namespace mlir { +namespace TFL { + +constexpr char kTFImplements[] = "tf._implements"; +constexpr char kLstmCellSimple[] = "LSTMCellSimple"; +constexpr char kLayerNormalizedLstmCellSimple[] = + "LayerNormalizedLstmCellSimple"; +constexpr char kCoupleInputForgetGates[] = "CoupleInputForgetGates"; + +// A utility class that enables the conversion of the LSTMCellSimple composite +// op into a fused TFL LSTM op. The fused op is contained within a FuncOp +// that also contains other supporting ops needed to construct the operands for +// the fused op. The caller provides the containing FuncOp as input with +// arguments specifying the input, weight, projection and bias. +// The weight, projection, bias and layer norm scale all need to be +// RankedTensorType. +// This class sets the layer norm coefficients to NoneType. +class ConvertLSTMCellSimpleToFusedLSTM { + public: + explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::func::FuncOp fused_func_op) + : fused_func_op_(fused_func_op), + couple_input_forget_gates_(false), + builder_(fused_func_op.getBody()) {} + + // not copyable. + ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) = + delete; + ConvertLSTMCellSimpleToFusedLSTM& operator=( + const ConvertLSTMCellSimpleToFusedLSTM&) = delete; + virtual ~ConvertLSTMCellSimpleToFusedLSTM() = default; + + virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; } + + // Rewrite the func body with constructed fused lstm. + LogicalResult RewriteFunc(); + + int GetNumInputs() { return n_input_; } + + protected: + // verify input func op arguments/attributes and initialize internal state. 
+ virtual LogicalResult InitializeFromFuncAttributes(); + virtual LogicalResult Initialize(); + + void UpdateFuncSignature(); + void GenerateFusedOpOperands(); + + void SetWeightForInputToCellGate(); + void SetWeightForInputToInputGate(); + void SetWeightForInputToForgetGate(); + void SetWeightForInputToOutputGate(); + + void SetWeightForRecurrentToCellGate(); + void SetWeightForRecurrentToInputGate(); + void SetWeightForRecurrentToForgetGate(); + void SetWeightForRecurrentToOutputGate(); + + void SetBiasToCellGate(); + void SetBiasToInputGate(); + void SetBiasToForgetGate(); + void SetBiasToOutputGate(); + + void SetProjection(); + void SetProjectionBias(); + + void SetInputActivationState(); + void SetInputCellState(); + + virtual void SetCellLayerNormCoefficients(); + virtual void SetInputLayerNormCoefficients(); + virtual void SetForgetLayerNormCoefficients(); + virtual void SetOutputLayerNormCoefficients(); + + // specified state + func::FuncOp fused_func_op_; + Value input_; + Value weight_; + Value bias_; + Value projection_; + bool couple_input_forget_gates_; + + // internal state + Value weight_transposed_; + Value projection_transposed_; + RankedTensorType weight_type_; + RankedTensorType projection_type_; + int num_gates_; + int n_cell_; + int n_output_; + int n_input_; + int num_cols_weight_transposed_; + int num_cols_projection_transposed_; + + // input -> cifg + Value input2input_; + Value input2forget_; + Value input2cell_; + Value input2output_; + + // recurrent -> cifg + Value rec2input_; + Value rec2forget_; + Value rec2cell_; + Value rec2output_; + + // bias -> cifg + Value bias2input_; + Value bias2forget_; + Value bias2cell_; + Value bias2output_; + + // projection + Value proj_weight_; + Value proj_bias_; + + // state + Value input_activation_state_; + Value input_cell_state_; + + // layer norm coefficients + Value input_layer_norm_coefficients_; + Value forget_layer_norm_coefficients_; + Value cell_layer_norm_coefficients_; + Value output_layer_norm_coefficients_; + + mlir::TFL::LSTMOp lstm_; + + Value none_; + SmallVector bias_slice_shape_; + SmallVector bias_size_values_; + SmallVector weight_slice_shape_; + SmallVector weight_slice_size_input_values_; + SmallVector weight_slice_size_recurrent_values_; + OpBuilder builder_; +}; + +// A utility class that enables the conversion of the +// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The +// fused op is contained within a FuncOp that also contains other supporting ops +// needed to construct the operands for the fused op. The caller provides the +// containing FuncOp as input with arguments specifying the input, weight, +// projection, bias and layer norm scale. The weight, projection, bias and +// layer norm scale all need to be RankedTensorType. +// This class overrides the layer norm coefficient setters from the base class. +class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM + : public ConvertLSTMCellSimpleToFusedLSTM { + public: + explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM( + mlir::func::FuncOp fused_func_op) + : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op) {} + + // not copyable. 
+ ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM( + const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; + ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=( + const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete; + ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override = default; + + llvm::StringRef GetCompositeOpName() override { + return kLayerNormalizedLstmCellSimple; + } + + protected: + LogicalResult Initialize() override; + + void SetCellLayerNormCoefficients() override; + void SetInputLayerNormCoefficients() override; + void SetForgetLayerNormCoefficients() override; + void SetOutputLayerNormCoefficients() override; + + private: + // specified state + Value layer_norm_scale_; + + // internal state + RankedTensorType layer_norm_scale_type_; + SmallVector layer_norm_slice_shape_; + SmallVector layer_norm_size_values_; +}; + +LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, + OpBuilder* builder); + +LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, + OpBuilder* builder, bool indy); + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/nms_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/nms_utils.h new file mode 100644 index 00000000..e3487ba9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/nms_utils.h @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with NMS ops in TFLite. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_NMS_UTILS_H_ + +#include + +#include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" + +namespace mlir { +namespace TFL { + +// Abstracts the conversion of the padded NMS composite function. +class ConvertNMSPaddedFunc { + public: + explicit ConvertNMSPaddedFunc(func::FuncOp func) : func_(func) {} + + void RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + func::FuncOp func_; +}; + +// Abstracts the conversion of the SSD post-processing composite function to +// TFLite. 
+class ConvertSSDPostProcessFunc { + public: + explicit ConvertSSDPostProcessFunc(func::FuncOp func, mlir::TF::FuncAttr attr) + : func_(func), attr_(attr) {} + + LogicalResult RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + LogicalResult CreateNMSCustomOptions(func::FuncOp func, DictionaryAttr attrs, + std::string& custom_option_buffer); + + LogicalResult AddIntAttr(func::FuncOp func, DictionaryAttr attrs, + const std::string& attribute, + flexbuffers::Builder* builder); + + LogicalResult AddFloatAttr(func::FuncOp func, DictionaryAttr attrs, + const std::string& attribute, + flexbuffers::Builder* builder); + + LogicalResult HasIntAttr(func::FuncOp func, DictionaryAttr attrs, + const std::string& attribute); + + LogicalResult HasFloatAttr(func::FuncOp func, DictionaryAttr attrs, + const std::string& attribute); + + func::FuncOp func_; + mlir::TF::FuncAttr attr_; +}; + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h new file mode 100644 index 00000000..609534f4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" + +namespace mlir { +namespace TFL { + +// Fuse MaxUnpooling2D ops annotated by tf.function to a TFLite custom op. +class ConvertMaxUnpoolingFunc { + public: + explicit ConvertMaxUnpoolingFunc(func::FuncOp func, mlir::TF::FuncAttr attr) + : func_(func), attr_(attr) {} + + LogicalResult RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + LogicalResult CreateCustomOptions(std::string& custom_option_buffer); + + func::FuncOp func_; + mlir::TF::FuncAttr attr_; +}; + +// Fuse DenseImageWarp ops annotated by tf.function to a TFLite custom op. 
+class ConvertDenseImageWarpFunc { + public: + explicit ConvertDenseImageWarpFunc(func::FuncOp func) : func_(func) {} + + LogicalResult RewriteFunc(); + + LogicalResult VerifySignature(); + + private: + func::FuncOp func_; +}; + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_PERCEPTION_OPS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/region_isolation.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/region_isolation.h new file mode 100644 index 00000000..b32b2df2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/region_isolation.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_REGION_ISOLATION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_REGION_ISOLATION_H_ + +#include + +#include "llvm/ADT/SetVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// Isolates op's contained regions. Replaces all references to values defined +// above these (single block) regions with a block argument. The union of all +// values referenced this way is returned. Each region will have an identical +// signature, which is the types of the returned vector in the same order. +// NOTE: Critically, llvm::SetVector iterates deterministically in order of +// insertion. +std::optional> IsolateRegions(Operation* op_with_regions, + OpBuilder& b); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_REGION_ISOLATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/size_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/size_utils.h new file mode 100644 index 00000000..52aa50c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/size_utils.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
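The custom-option builders declared in nms_utils.h and perception_ops_utils.h above (CreateNMSCustomOptions, CreateCustomOptions, AddIntAttr, AddFloatAttr) serialize op attributes into a flexbuffer map whose raw bytes become the TFLite custom op's options buffer. A minimal sketch of that encoding, with illustrative keys and values rather than the exact attribute set those passes emit:

#include <cstdint>
#include <string>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers

// Builds a flexbuffer map of scalar attributes and returns the raw bytes that
// would be stored as a custom op's options buffer.
std::string BuildExampleCustomOptions() {
  flexbuffers::Builder fbb;
  const size_t map_start = fbb.StartMap();
  fbb.Int("max_detections", 10);           // illustrative key/value
  fbb.Float("nms_score_threshold", 0.3f);  // illustrative key/value
  fbb.Float("nms_iou_threshold", 0.6f);    // illustrative key/value
  fbb.EndMap(map_start);
  fbb.Finish();
  const std::vector<uint8_t>& bytes = fbb.GetBuffer();
  return std::string(bytes.begin(), bytes.end());
}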
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_SIZE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_SIZE_UTILS_H_ + +#include + +namespace mlir { +namespace TFL { + +// Converts a TF size (64-bit) to TFLite (32-bit) and properly converts TF's +// value for dynamic size (`std::numeric_limits::min()`) to the +// TFLite-specific value. +int32_t ConvertToTfliteSize(int64_t size); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_SIZE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h new file mode 100644 index 00000000..e7e3e721 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// Check if the given op has stateful operands and return their stateful +// operand indices. +bool IsStatefulOp(Operation* op, std::vector* stateful_operand_indices); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STATEFUL_OPS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/string_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/string_utils.h new file mode 100644 index 00000000..e1ede084 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/string_utils.h @@ -0,0 +1,110 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Util methods to store a an ordered collection of strings in a char buffer. +// The format of the char buffer is: +// [0, 3] 4 bytes: N, num of strings in the collection. +// [(i+1)*4, (i+1)*4+3] 4 bytes: offset of i-th string in little endian, +// for i from 0 to N-1. +// [(N+1)*4, (N+1)*4+3] 4 bytes: length of the whole char buffer. 
+// [offset(i), offset(i+1) - 1] : content of i-th string. +// +// A typical usage: +// SimpleDynamicBuffer buf; +// char* buffer; +// # Add string "AB", string is stored in dynamic buffer. +// buf.AddString("AB", 2); +// # Write content of SimpleDynamicBuffer to buffer in format described above. +// buf.WriteToBuffer(&buffer) + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STRING_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STRING_UTILS_H_ + +#include + +#include +#include +#include + +namespace mlir::TFL { + +constexpr uint64_t kDefaultMaxLength = std::numeric_limits::max(); + +class SimpleDynamicBuffer { + public: + explicit SimpleDynamicBuffer(size_t max_length = kDefaultMaxLength) + : offset_({0}), max_length_(max_length) {} + + // Add string to dynamic buffer by resizing the buffer and copying the data. + bool AddString(const char* str, size_t len); + + // Fill content into a buffer and returns the number of bytes stored. + // The function allocates space for the buffer but does NOT take ownership. + int WriteToBuffer(char** buffer); + + protected: + // Data buffer to store contents of strings, not including headers. + std::vector data_; + // Offset of the starting index of each string in data buffer. + std::vector offset_; + // Max length in number of characters that we permit the total + // buffer containing the concatenation of all added strings to be. + // For historical reasons this is limited to 32bit length. At this files + // inception, sizes were represented using 32bit which forced an implicit cap + // on the size of the buffer. When this was refactored to use size_t (which + // could be 64bit) we enforce that the buffer remains at most 32bit length to + // avoid a change in behavior. + const size_t max_length_; +}; + +// Convenient structure to store string pointer and length. Note that +// methods on SimpleDynamicBuffer enforce that the whole buffer (and by +// extension every contained string) is of max length (2ul << 30) - 1. See +// string_util.cc for more info. +typedef struct { + const char* str; + size_t len; +} StringRef; + +// Return num of strings in a String tensor. +inline int GetStringCount(const void* raw_buffer) { + // The first integers in the raw buffer is the number of strings. + // + // NOTE: The string buffer is accessed here as if it's native endian (instead + // of small endian, as documented in the header). This will protentially break + // when TFLite is ported to big endian platforms. + // TODO(b/165919229): This code will need changing if/when we port to a + // big-endian platform. + return *static_cast(raw_buffer); +} + +// Get String pointer and length of index-th string in tensor. +// NOTE: This will not create a copy of string data. +inline StringRef GetString(const void* raw_buffer, int string_index) { + // NOTE: The string buffer is accessed here as if it's native endian (instead + // of small endian, as documented in the header). This will protentially break + // when TFLite is ported to big endian platforms. + // TODO(b/165919229): This code will need changing if/when we port to a + // big-endian platform. 
+ const int32_t* offset = + static_cast(raw_buffer) + (string_index + 1); + const size_t string_len = (*(offset + 1)) - (*offset); + return StringRef{static_cast(raw_buffer) + (*offset), + string_len}; +} + +} // namespace mlir::TFL + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_STRING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/tftext_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/tftext_utils.h new file mode 100644 index 00000000..eafa2d44 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/tftext_utils.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/core/framework/op.h" + +namespace mlir { +namespace TFL { + +// Fuse TF.Text APIs annotated by tf.function to a TFLite custom op. +LogicalResult ConvertTFTextAPI(mlir::func::FuncOp func, llvm::StringRef api, + mlir::TF::FuncAttr attr); + +// Check if TF.Text Tensorflow ops are registered. +bool IsTFTextRegistered(const tensorflow::OpRegistry* op_registery); + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_TFTEXT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/utils.h new file mode 100644 index 00000000..53f6a038 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/utils.h @@ -0,0 +1,408 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
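A standalone writer for the string-tensor layout documented in string_utils.h above makes the offsets concrete: a 4-byte string count, one 4-byte start offset per string, the total buffer length (which doubles as the end offset of the last string), then the concatenated contents. This is only an illustration of the documented layout, not the SimpleDynamicBuffer implementation.

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Serializes strings into the layout read back by GetStringCount/GetString
// above: header[0] = N, header[1..N] = start offsets, header[N+1] = total size.
std::vector<char> SerializeStringsForIllustration(
    const std::vector<std::string>& strings) {
  const int32_t num_strings = static_cast<int32_t>(strings.size());
  const int32_t header_bytes =
      static_cast<int32_t>((num_strings + 2) * sizeof(int32_t));
  int32_t total_bytes = header_bytes;
  for (const auto& s : strings) total_bytes += static_cast<int32_t>(s.size());

  std::vector<char> buffer(total_bytes);
  int32_t* header = reinterpret_cast<int32_t*>(buffer.data());
  header[0] = num_strings;
  int32_t offset = header_bytes;
  for (int32_t i = 0; i < num_strings; ++i) {
    header[i + 1] = offset;
    std::memcpy(buffer.data() + offset, strings[i].data(), strings[i].size());
    offset += static_cast<int32_t>(strings[i].size());
  }
  header[num_strings + 1] = offset;  // Total length, also offset(N).
  return buffer;
}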
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +using llvm::ArrayRef; +using mlir::Operation; +using mlir::ShapedType; +using mlir::Value; + +// Returns true if the value is the min float value. +inline bool IsNegInfiniteValue(APFloat value) { + if (!value.isNegative()) return false; + return value.isInfinity(); +} + +// Returns true if the value is the max float value. +inline bool IsPosInfiniteValue(APFloat value) { + if (value.isNegative()) return false; + return value.isInfinity(); +} + +// Returns true if all tensor value in `values` has static shape and same shape. +inline bool OpHasSameStaticShapes(Operation* op) { + auto values = op->getOperands(); + int operand_num = 0; + ArrayRef shape; + for (Value value : values) { + auto shaped_type = value.getType().dyn_cast(); + if (!shaped_type || !shaped_type.hasStaticShape()) { + return false; + } + if (operand_num == 0) { + shape = shaped_type.getShape(); + } else { + if (shape != shaped_type.getShape()) { + return false; + } + } + ++operand_num; + } + return true; +} + +// Utility function to map final permutation to initial permutation +// initial -> permutation1 -> permutation2 -> final +inline DenseElementsAttr RemapPermutation(Value permutation1, + DenseElementsAttr perm2_const) { + SmallVector initial_permutation; + DenseElementsAttr perm1_const; + + SmallVector new_permutation; + if (matchPattern(permutation1, m_Constant(&perm1_const))) { + for (int32_t idx = 0; idx < perm1_const.getNumElements(); ++idx) { + initial_permutation.push_back(idx); + } + for (auto perm : perm2_const.getValues()) { + new_permutation.push_back( + initial_permutation[perm1_const + .getValues()[perm.getSExtValue()] + .getSExtValue()]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(new_permutation.size())}, + mlir::IntegerType::get(permutation1.getContext(), 32)), + llvm::ArrayRef(new_permutation)); +} + +// Utility function to map final permutation to initial permutation +// initial -> permutation1 -> permutation2 -> final +inline DenseElementsAttr RemapPermutation(Value permutation1, + Value permutation2) { + DenseElementsAttr perm2_const; + (void)matchPattern(permutation2, m_Constant(&perm2_const)); + + return RemapPermutation(permutation1, perm2_const); +} + +// Returns true if the transpose op is trivial. Trivial means that +// the permutation is a cyclic permutation of the original shape with only the +// identity dimensions permuted. 
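The RemapPermutation helpers above compose two transpose permutations into one. Expressed over plain vectors the rule is compact and may be easier to follow than the attribute plumbing; the sketch below is illustrative only.

#include <cstdint>
#include <vector>

// Composing a transpose with permutation `perm1` followed by one with `perm2`
// yields a single transpose whose output axis i reads input axis
// perm1[perm2[i]].
std::vector<int32_t> ComposePermutations(const std::vector<int32_t>& perm1,
                                         const std::vector<int32_t>& perm2) {
  std::vector<int32_t> combined(perm2.size());
  for (size_t i = 0; i < perm2.size(); ++i) {
    combined[i] = perm1[perm2[i]];
  }
  return combined;
}
// Example: perm1 = {0, 2, 1}, perm2 = {2, 1, 0} composes to {1, 2, 0}.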
+inline bool IsTransposeTrivial(llvm::ArrayRef input_shape, + Value perm) { + DenseElementsAttr perm_values_attr; + if (!matchPattern(perm, m_Constant(&perm_values_attr))) return false; + + SmallVector perm_values; + for (const auto& dim : perm_values_attr.getValues()) + perm_values.push_back(dim.getSExtValue()); + + // This should never happen unless the input graph is malformed. + if (input_shape.size() != perm_values.size()) { + return false; + } + + SmallVector old_major_index_ordering; + SmallVector new_major_index_ordering; + for (int i = 0, end = input_shape.size(); i < end; i++) { + if (input_shape[i] != 1) { + old_major_index_ordering.push_back(i); + } + + if (input_shape[perm_values[i]] != 1) { + new_major_index_ordering.push_back(perm_values[i]); + } + } + return (old_major_index_ordering == new_major_index_ordering); +} + +// Returns the permutation that maps the input shape to the output shape. +// This is only valid for trivial reshape ops. +inline DenseElementsAttr GetPermutationFromTrivialReshape( + ShapedType input_type, ShapedType output_type) { + ArrayRef in_shape = input_type.getShape(); + ArrayRef out_shape = output_type.getShape(); + + // Get the indexes of the non-identity dimensions and the identity dimensions + // in the input shape. + SmallVector input_nonidentity_dims_index_array; + SmallVector input_identity_dims_index_array; + + // Since the reshape is trivial, the input and output shapes should have the + // same number of dimensions. And the non-identity dimensions must be in the + // same cyclic order. + for (size_t idx = 0; idx < in_shape.size(); ++idx) { + if (in_shape[idx] != 1) { + input_nonidentity_dims_index_array.push_back(idx); + } else { + input_identity_dims_index_array.push_back(idx); + } + } + + // Get the permutation that maps the input shape to the output shape. + SmallVector permutation; + size_t nonidentity_dims_index_poiter = 0; + size_t identity_dims_index_pointer = 0; + for (auto out_dim : out_shape) { + if (out_dim != 1) { + permutation.push_back( + input_nonidentity_dims_index_array[nonidentity_dims_index_poiter++]); + } else { + permutation.push_back( + input_identity_dims_index_array[identity_dims_index_pointer++]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(permutation.size())}, + mlir::IntegerType::get(input_type.getContext(), 32)), + llvm::ArrayRef(permutation)); +} + +// Returns true if the reshape op is equivalent to a transpose op. +// This is true if the reshape op is a trivial reshape op, meaning no change in +// the order of non-identity dimensions. +inline bool IsReshapeEquivalentToTranspose(ShapedType input_type, + ShapedType output_type) { + std::vector in_shape{input_type.getShape().vec()}; + std::vector out_shape{output_type.getShape().vec()}; + + // If the reshape changes the number of dimensions so it cannot be interpreted + // as a transpose. + if (in_shape.size() != out_shape.size()) { + return false; + } + + in_shape.erase(std::remove(in_shape.begin(), in_shape.end(), 1), + in_shape.end()); + out_shape.erase(std::remove(out_shape.begin(), out_shape.end(), 1), + out_shape.end()); + return in_shape == out_shape; +} + +// Checks if all elements in the constant attribute value are 1. +inline bool IsAllOnesConstant(Attribute value) { + auto values = value.cast().getValues(); + return !std::any_of(values.begin(), values.end(), + [](int32_t element_value) { return element_value != 1; }); +} + +// Checks if all elements in the constant attribute value are non-negative. 
+inline bool HasNonNegativeValues(Attribute value) { + auto values = value.cast().getValues(); + return !std::any_of( + values.begin(), values.end(), + [](const APInt& element_value) { return element_value.isNegative(); }); +} + +// Utility function to get the offset between two dense attribute values. +inline TypedAttr GetOffSet(Attribute begin, Attribute end) { + auto begin_values = begin.cast().getValues(); + auto end_values = end.cast().getValues(); + + SmallVector offsets; + if (begin_values.size() == end_values.size()) { + for (size_t i = 0; i < begin_values.size(); ++i) { + offsets.push_back(end_values[i] - begin_values[i]); + } + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get({static_cast(offsets.size())}, + mlir::IntegerType::get(begin.getContext(), 32)), + llvm::ArrayRef(offsets)); +} + +// Check if the offset between two dense attribute values is non-negative. +inline bool HasNonNegativeOffset(Attribute begin, Attribute end) { + return HasNonNegativeValues(GetOffSet(begin, end)); +} + +// Return true if the permutation value only swaps the last two dimensions +inline bool AreLastTwoDimsTransposed(Value permutation) { + if (!permutation) return false; + DenseElementsAttr perm_values_attr; + + if (!matchPattern(permutation, m_Constant(&perm_values_attr))) return false; + auto perm_values = perm_values_attr.getValues(); + size_t idx = 0; + for (; idx < perm_values_attr.size() - 2; ++idx) { + if (perm_values[idx].getSExtValue() != idx) return false; + } + + return (perm_values[idx].getSExtValue() == perm_values_attr.size() - 1) && + (perm_values[idx + 1].getSExtValue() == idx); +} + +// Gets the new type after transposing the last 2 dimensions. +inline Type TransposeLastTwoDims(Type type) { + auto shaped_type = type.dyn_cast(); + if (!shaped_type.hasStaticShape() || shaped_type.getRank() < 2) { + return nullptr; + } + int rank = shaped_type.getRank(); + if (rank < 2) { + return nullptr; + } + SmallVector new_shape(shaped_type.getShape().begin(), + shaped_type.getShape().end()); + std::swap(new_shape[rank - 1], new_shape[rank - 2]); + return shaped_type.clone(new_shape); +} + +// Returns a ShapedType for a permutation and the shape of input after +// applying the permutation to the given shape through a transpose. +inline ShapedType GetTransposedType(Value input, + llvm::ArrayRef permutation_array) { + auto input_type = input.getType().cast(); + if (permutation_array.size() != input_type.getRank()) { + return nullptr; + } + llvm::SmallVector transposed_shape(permutation_array.size()); + for (int64_t i = 0; i < permutation_array.size(); ++i) { + transposed_shape[i] = input_type.getDimSize(permutation_array[i]); + } + auto transposed_type = + RankedTensorType::get(transposed_shape, input_type.getElementType()); + return transposed_type; +} + +// Return the resultant shape if the shape of the supplied attribute/value is +// expanded by n leading 1s'. +inline SmallVector GetExpandedShape(Value input_val, int n) { + auto input_shape = mlir::cast(input_val.getType()).getShape(); + SmallVector expanded_shape; + expanded_shape.reserve(input_shape.size() + n); + for (int i = 0; i < n; ++i) { + expanded_shape.push_back(1); + } + expanded_shape.insert(expanded_shape.end(), input_shape.begin(), + input_shape.end()); + return expanded_shape; +} + +// Return the resultant shape as a DenseElementsAttr if the shape of the +// supplied attribute/value is expanded by n leading 1s'. 
+inline DenseElementsAttr GetExpandedShapeAttr(Value input_val, int n) { + auto expanded_shape = GetExpandedShape(input_val, n); + + return mlir::DenseElementsAttr::get( + RankedTensorType::get({static_cast(expanded_shape.size())}, + mlir::IntegerType::get(input_val.getContext(), 32)), + llvm::ArrayRef(expanded_shape)); +} + +// Return the resultant shape type if the shape of the supplied attribute/value +// is expanded by n leading 1s'. +inline ShapedType GetExpandedShapeType(Value input_val, int n) { + auto expanded_shape = GetExpandedShape(input_val, n); + return RankedTensorType::get( + SmallVector{expanded_shape.begin(), expanded_shape.end()}, + mlir::cast(input_val.getType()).getElementType()); +} + +// Returns shape of a ranked tensor. +// Precondition: output_val's is ranked tensor. +// Returns a truncated shape when `truncate` is set to true. +inline DenseElementsAttr GetShape(Value output_val, bool truncate = false) { + auto output_shape = output_val.getType().dyn_cast().getShape(); + + SmallVector shape; + shape.reserve(output_shape.size()); + + bool needs_truncation = true; + for (size_t dim_idx = 0; dim_idx < output_shape.size(); ++dim_idx) { + int64_t dim = output_shape[dim_idx]; + if (truncate && needs_truncation && dim == 1) { + continue; + } else if (needs_truncation && dim != 1) { + needs_truncation = false; + } + shape.push_back(ShapedType::isDynamic(dim) ? -1 + : static_cast(dim)); + } + + return mlir::DenseElementsAttr::get( + RankedTensorType::get( + {static_cast(shape.size())}, + mlir::IntegerType::get(output_val.getContext(), 32)), + llvm::ArrayRef(shape)); +} + +//////////////////////////////////////////////////////////////////////////////// +///////////////// OP BROADCASTING UTILITIES //////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +// Returns whether the resultant type of any broadcastable operation with +// operands `a` and `b` matches `expected_output`. Returns false if `a` is not +// broadcast-compatible with `b`. +inline bool OperandsBroadcastToOutputType(Type a, Type b, + Type expected_output) { + Type output_element_type = + mlir::cast(expected_output).getElementType(); + Type broadcasted_type = + OpTrait::util::getBroadcastedType(a, b, output_element_type); + return broadcasted_type != Type() && broadcasted_type == expected_output; +} + +// Returns int, float or complex DenseElementsAttr with scalar shape with the +// given element type and the integer value. 
+template +DenseElementsAttr GetScalarOfType(Type ty, T raw_value) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + if (auto float_ty = mlir::dyn_cast(ty)) { + FloatAttr attr = FloatAttr::get(float_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto int_ty = mlir::dyn_cast(ty)) { + IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto complex_ty = mlir::dyn_cast(ty)) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } + } + llvm_unreachable("unsupported type"); +} + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/validators.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/validators.h new file mode 100644 index 00000000..be24f40f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/validators.h @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common validators used by TFLite transformation +// passes to validate op attributes or values. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// TODO(jpienaar): Change these to being one of these variants and/or generate +// these predicates. + +// Returns true if the given TensorFlow op does not have a `data_format` +// attribute (then default to "NHWC"), or its `data_format` attribute is "NHWC". +inline bool TFDataFormatIsNHWC(Operation *op) { + auto attr = op->getAttrOfType("data_format"); + return !attr || attr.getValue() == "NHWC"; +} + +// Returns true if the given TensorFlow op does not have a `data_format` +// attribute (then default to "NDHWC"), or its `data_format` attribute is +// "NDHWC". 
+inline bool TFDataFormatIsNDHWC(Operation *op) { + auto attr = op->getAttrOfType("data_format"); + return !attr || attr.getValue() == "NDHWC"; +} + +// Returns true if the given `op` +// * has an attribute with the given `name`, +// * and the attribute is an integer list of the form [1, X, Y, 1], +// and writes X, Y as 32-bit integer attribute to `x`, `y`. +bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x, + IntegerAttr *y); + +// Returns true if the attribute is an integer list of the form [1, X, Y, 1]. +bool TFIntListIs1XY1(Attribute attr); + +// Returns true if the attribute is an integer list of the form [1, 1, X, Y]. +bool TFIntListIs11XY(Attribute attr); + +// Returns true if the given `op` +// * has an attribute with the given `name`, +// * and the attribute is an integer list of the form [1, X, Y, Z, 1], +// and writes X, Y as 32-bit integer attribute to `x`, `y`, z. +bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, + IntegerAttr *y, IntegerAttr *z); + +// Returns true if every element of the attribute is 1. All elements of `attr` +// must be `IntegerAttr`. +bool TFIntListIsAllOnes(Attribute attr); + +// Returns true iff the given value is a float32 tensor. +// is "DT_FLOAT". +inline bool TFTypeIsFloat32Tensor(Value value) { + auto tensorType = mlir::dyn_cast(value.getType()); + if (!tensorType) return false; + return tensorType.getElementType().isF32(); +} + +// Returns true iff the given value is a bf16 tensor. +inline bool TFTypeIsBFloat16Tensor(Value value) { + auto tensorType = mlir::dyn_cast(value.getType()); + if (!tensorType) return false; + return tensorType.getElementType().isBF16(); +} + +// Returns true iff the given value is a f16 tensor. +inline bool TFTypeIsHalfTensor(Value value) { + auto tensorType = mlir::dyn_cast(value.getType()); + if (!tensorType) return false; + return tensorType.getElementType().isF16(); +} + +// Returns true iff the given value is a f16 or bf16 tensor. +inline bool TFTypeIsBFloat16OrHalfTensor(Value value) { + return TFTypeIsBFloat16Tensor(value) || TFTypeIsHalfTensor(value); +} + +// Returns true iff the given TensorFlow op has a `padding` attribute whose +// value is "SAME" or "VALID", and writes the attribute to `padding`. +inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) { + auto padding_attr = op->getAttrOfType("padding"); + if (padding_attr.getValue() != "SAME" && padding_attr.getValue() != "VALID") + return false; + *padding = padding_attr; + return true; +} + +/// Returns whether the given `a` and `b` have broadcast-compatible +/// types. +bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b); +// Returns true if every dimension of the attribute is 1 except the last one. +bool IsDimensionsDegenerateExceptLastOne(mlir::TypedAttr val); +// Returns true if every element is 1 except the last one. +bool IsDimensionsDegenerateExceptLastOne(ArrayRef elements_shape); + +} // end namespace TFL +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/variables_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/variables_utils.h new file mode 100644 index 00000000..570f9afd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/utils/variables_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VARIABLES_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VARIABLES_UTILS_H_ + +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace utils { + +// Returns true if 'op' has type that is supported by native TFLite +// variables. +bool IsSupportedVariableType(Operation* op); + +// Returns true if 'type' is supported by native tflite variables. +bool IsSupportedVariableType(ShapedType type); + +} // namespace utils +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VARIABLES_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/version.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/version.h new file mode 100644 index 00000000..321bd395 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/lite/version.h @@ -0,0 +1,25 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_VERSION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_VERSION_H_ + +// LINT.IfChange(tflite_schema_version) +// The version number of the Schema. Ideally all changes will be backward +// compatible. If that ever changes, we must ensure that version is the first +// entry in the new tflite root so that we can see that version is not 1. +#define TFLITE_SCHEMA_VERSION (3) +// LINT.ThenChange(//tensorflow/lite/version.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_VERSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h new file mode 100644 index 00000000..1e817d0a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/mlir_graph_optimization_pass.h @@ -0,0 +1,232 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include "absl/log/check.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/function_optimization_registry.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +// -------------------------------------------------------------------------- // +// MLIR passes running on Tensorflow function graphs (Tensorflow V2). +// -------------------------------------------------------------------------- // + +// Disabled - skip execution of the pass. +// Enabled - execute the pass, propagate errors to the caller if any. +// FallbackEnabled - execute the pass and commit all the changes to the MLIR +// module in case of success. Do not commit any changes in case of failures, +// let the rest of the pipeline run. +enum class MlirOptimizationPassState { Disabled, Enabled, FallbackEnabled }; + +// An API for registering MLIR ModulePass with the Tensorflow runtime. These +// passes are running only for function graphs built by Tensorflow V2 and +// instantiated by the process_function_library_runtime (see +// FunctionOptimizationPass for details). +class MlirOptimizationPass { + public: + virtual ~MlirOptimizationPass() = default; + virtual llvm::StringRef name() const = 0; + + // Returns an enum value: + // Enabled if the pass is enabled for the given graph with specified config. + // Disabled if the pass is disabled. + // FallbackEnabled if the pass needs to be executed in fallback mode. + // + // When the pass is FallbackEnabled, the pass is executed and the changes it + // makes to the MLIR module will be committed only if the pass was successful, + // otherwise no changes are committed and the rest of the pipeline is run. + // + // `device_set` can be nullptr if the devices information is not + // available or no device specific filtering is required. + // `function_library` contains function definitions for function calls in + // `graph` not included in the `graph` FunctionLibraryDefinition. 
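// A hypothetical subclass sketch (illustrative only; `MyPass` is not part of
// this header) showing how the pure-virtual methods declared below are
// typically overridden:
//
//   class MyPass : public MlirOptimizationPass {
//    public:
//     llvm::StringRef name() const override { return "my-pass"; }
//     MlirOptimizationPassState GetPassState(
//         const DeviceSet* device_set, const ConfigProto& config_proto,
//         const Graph& graph,
//         const FunctionLibraryDefinition& function_library) const override {
//       // Commit changes only on success; never abort the rest of the
//       // pipeline.
//       return MlirOptimizationPassState::FallbackEnabled;
//     }
//     absl::Status Run(const std::string& function_name,
//                      const ConfigProto& config_proto, mlir::ModuleOp module,
//                      const Graph& graph,
//                      const FunctionLibraryDefinition& function_library)
//         override {
//       return absl::OkStatus();
//     }
//   };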
+ virtual MlirOptimizationPassState GetPassState( + const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library) const = 0; + + virtual absl::Status Run( + const std::string& function_name, const ConfigProto& config_proto, + mlir::ModuleOp module, const Graph& graph, + const FunctionLibraryDefinition& function_library) = 0; +}; + +class MlirOptimizationPassRegistry { + public: + struct PassRegistration { + int priority; + std::unique_ptr pass; + }; + + struct PriorityComparator { + bool operator()(const PassRegistration& x, + const PassRegistration& y) const { + return x.priority < y.priority; + } + }; + + using Passes = std::set; + + // Returns the global registry of MLIR optimization passes. + static MlirOptimizationPassRegistry& Global(); + + // Register optimization `pass` with the given `priority`. + void Add(int priority, std::unique_ptr pass) { + auto inserted = passes_.insert({priority, std::move(pass)}); + CHECK(inserted.second) + << "Pass priority must be unique. " + << "Previously registered pass with the same priority: " + << inserted.first->pass->name().str(); + } + + // Free the memory allocated for all passes. + void ClearPasses() { passes_.clear(); } + + const Passes& passes() const { return passes_; } + + private: + Passes passes_; +}; + +// Function optimization pass that runs all MLIR passes registered in +// MlirOptimizationPassRegistry. +class MlirFunctionOptimizationPass : public FunctionOptimizationPass { + public: + explicit MlirFunctionOptimizationPass( + const MlirOptimizationPassRegistry* registry = + &MlirOptimizationPassRegistry::Global()) + : registry_(registry) {} + + // Executes all of the underlying registered MlirOptimizationPasses. + absl::Status Run( + const std::string& function_name, const DeviceSet& device_set, + const ConfigProto& config_proto, + const FunctionOptimizationPass::FunctionOptions& function_options, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override; + + private: + const MlirOptimizationPassRegistry* registry_; +}; + +// -------------------------------------------------------------------------- // +// MLIR passes running on Tensorflow V1 graphs. +// -------------------------------------------------------------------------- // + +// An API for registering MLIR ModulePass with the Tensorflow runtime. These +// passes are running only for V1 graphs (legacy graphs) executed via Session +// runtime. Graph importer updates legacy graph behavior to V2 constructs (e.g. +// it raises control flow from Switch/Merge nodes to functional control flow +// with If/While operations). +class MlirV1CompatOptimizationPass { + public: + virtual ~MlirV1CompatOptimizationPass() = default; + virtual llvm::StringRef name() const = 0; + + // Returns a MlirOptimizationPassState based on the given graph and + // config. See comments on `MlirOptimizationPassState` enum for more info + // on exact values. + virtual MlirOptimizationPassState GetPassState( + const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library) const = 0; + + virtual absl::Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) = 0; +}; + +class MlirV1CompatOptimizationPassRegistry { + public: + // Returns the global registry of MLIR optimization passes. 
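// Illustrative registration sketch (assumes `MyV1Pass` is a hypothetical
// MlirV1CompatOptimizationPass subclass; not part of the original header):
//
//   static mlir_pass_registration::MlirV1CompatOptimizationPassRegistration
//       register_my_v1_pass(std::make_unique<MyV1Pass>());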
+ static MlirV1CompatOptimizationPassRegistry& Global(); + + void Add(std::unique_ptr pass) { + CHECK(pass_ == nullptr) << "Only a single pass can be registered"; + pass_ = std::move(pass); + } + + MlirV1CompatOptimizationPass* pass() const { + return pass_ ? pass_.get() : nullptr; + } + + // Free the memory allocated for the single pass. + // This method is used for testing mostly. + void ClearPass() { pass_.reset(); } + + private: + std::unique_ptr pass_{}; +}; + +class MlirV1CompatGraphOptimizationPass : public GraphOptimizationPass { + public: + explicit MlirV1CompatGraphOptimizationPass( + const MlirV1CompatOptimizationPassRegistry* registry = + &MlirV1CompatOptimizationPassRegistry::Global()) + : registry_(registry) {} + + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + private: + const MlirV1CompatOptimizationPassRegistry* registry_; +}; + +// -------------------------------------------------------------------------- // +// Helper classes for static registration of MLIR (V1 Compat) passes in the +// corresponding registry. +// -------------------------------------------------------------------------- // + +namespace mlir_pass_registration { + +class MlirOptimizationPassRegistration { + public: + explicit MlirOptimizationPassRegistration( + int priority, std::unique_ptr pass) { + MlirOptimizationPassRegistry::Global().Add(priority, std::move(pass)); + } +}; + +class MlirV1CompatOptimizationPassRegistration { + public: + explicit MlirV1CompatOptimizationPassRegistration( + std::unique_ptr pass) { + MlirV1CompatOptimizationPassRegistry::Global().Add(std::move(pass)); + } +}; + +} // namespace mlir_pass_registration + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/op_or_arg_name_mapper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/op_or_arg_name_mapper.h new file mode 100644 index 00000000..f8c596ff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/op_or_arg_name_mapper.h @@ -0,0 +1,105 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_ +#define TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project + +namespace tensorflow { + +// PointerUnion for operation and value. +// TODO(jpienaar): Rename the files. +using OpOrVal = llvm::PointerUnion; + +// Mapper from operation or value to name. +class OpOrArgNameMapper { + public: + // Returns unique name for the given prefix. 
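// Illustrative usage sketch (assumes `op` is an Operation* with a location;
// not part of the original header):
//
//   OpOrArgLocNameMapper mapper;
//   mapper.InitOpName(op, "my_op");
//   llvm::StringRef mapped = mapper.GetUniqueName(op);      // "my_op"
//   llvm::StringRef fresh = mapper.GetUniqueName("my_op");   // e.g. "my_op1"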
+ llvm::StringRef GetUniqueName(llvm::StringRef prefix, int hash_value = 0); + + // Returns unique name for the operation or value. + llvm::StringRef GetUniqueName(OpOrVal op_or_val, int hash_value = 0); + + // Returns unique name as a string_view for the operation or value. + absl::string_view GetUniqueNameView(OpOrVal op_or_val); + + // Initializes operation or value to map to name. Returns number of + // operations or value already named 'name' which should be 0 else + // GetUniqueName could return the same names for different operations or + // values. + // Note: Its up to the caller to decide the behavior when assigning two + // operations or values to the same name. + int InitOpName(OpOrVal op_or_val, llvm::StringRef name); + + virtual ~OpOrArgNameMapper(); + + protected: + // Returns true if the name is unique. A derived class can override it if the + // class maintains uniqueness in a different scope. + virtual bool IsUnique(llvm::StringRef name); + + // Returns a constant view of the underlying map. + const llvm::DenseMap& GetMap() const { + return op_or_val_to_name_; + } + + // Returns the separator used before uniqueing suffix. + virtual llvm::StringRef GetSuffixSeparator() { return ""; } + + virtual llvm::StringRef GetDashSeparator() { return "_"; } + + private: + // Returns name from the location of the operation or value. + virtual std::string GetName(OpOrVal op_or_val) = 0; + + // Maps string name to count. This map is used to help keep track of unique + // names for operations or values. + llvm::StringMap name_to_count_; + // Maps operation or values to name. Value in map is a view of the string + // name in `name_to_count_`. Names in `name_to_count_` are never removed. + llvm::DenseMap op_or_val_to_name_; +}; + +// OpOrArgNameMapper that returns, for operations or values not initialized +// to a specific name, a name based on the location of the operation or +// value. +class OpOrArgLocNameMapper : public OpOrArgNameMapper { + protected: + std::string GetName(OpOrVal op_or_val) override; +}; + +// OpOrArgNameMapper that returns, for operations or values not initialized +// to a specific name, a short name. +class OpOrArgStripNameMapper : public OpOrArgNameMapper { + private: + std::string GetName(OpOrVal op_or_val) override; + + // Number of ops mapped. + int count_ = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_OP_OR_ARG_NAME_MAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/python/mlir.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/python/mlir.h new file mode 100644 index 00000000..99a17ca1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/python/mlir.h @@ -0,0 +1,114 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functions for getting information about kernels registered in the binary. +// Migrated from previous SWIG file (mlir.i) authored by aminim@. 
+#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/tf_status.h" + +namespace tensorflow { + +// Simple wrapper to support tf.mlir.experimental.convert_graph_def. +// Load a GraphDef (binary or textual proto format), convert to MLIR, and +// (optionally) optimize the module before returning it as a string. +// This is an early experimental API, ideally we should return a wrapper object +// around a Python binding to the MLIR module. +std::string ImportGraphDef(const std::string &proto, + const std::string &pass_pipeline, + bool show_debug_info, TF_Status *status); + +// Simple wrapper to support tf.mlir.experimental.convert_function. +// Load FunctionDef (binary or textual proto format), convert to MLIR, and +// (optionally) optimize the module before returning it as a string. +// This is an early experimental API, ideally we should return a wrapper object +// around a Python binding to the MLIR module. +std::string ImportFunction(const std::string &functiondef_proto, + const std::string &pass_pipeline, + bool show_debug_info, TFE_Context *context, + TF_Status *status); + +// This wrapper passes the graph_def taking names of input nodes, the shapes and +// types of its inputs and the output nodes as parameters to MLIR. +std::string ImportGraphDef(const std::string &proto, + const std::string &pass_pipeline, + bool show_debug_info, absl::string_view(input_names), + absl::string_view(input_data_types), + absl::string_view(input_data_shapes), + absl::string_view(output_names), TF_Status *status); + +// Load a SavedModel and return a textual MLIR string corresponding to it. +// +// Args: +// saved_model_path: File path from which to load the SavedModel. +// exported_names_str: Comma-separated list of names to export. +// Empty means "export all". +// +// Returns: +// A string of textual MLIR representing the raw imported SavedModel. +std::string ExperimentalConvertSavedModelToMlir( + const std::string &saved_model_path, const std::string &exported_names_str, + bool show_debug_info, TF_Status *status); + +// Load a SavedModel V1 and return a textual MLIR string corresponding to it +// without any MLIR graph transformation. +// +// Args: +// saved_model_path: File path from which to load the SavedModel. +// tags: Tags to identify MetaGraphDef that need to be loaded. +// upgrade_legacy: Boolean flag that indicates whether to upgrade legacy +// graphs +// +// Returns: +// A string of textual MLIR representing the raw imported SavedModel. +std::string ExperimentalConvertSavedModelV1ToMlirLite( + const std::string &saved_model_path, const std::string &exported_names_str, + const std::string &tags, bool upgrade_legacy, bool show_debug_info, + TF_Status *status); + +// Load a SavedModel V1 and return a textual MLIR string corresponding to it. +// +// Args: +// saved_model_path: File path from which to load the SavedModel. +// tags: Tags to identify MetaGraphDef that need to be loaded. +// lift_variables: Boolean flag that indicates whether to hoist variables +// after loading the SavedModel. +// +// Returns: +// A string of textual MLIR representing the raw imported SavedModel. 
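// Illustrative call sketch for the SavedModel converter declared above (path
// and flag values are placeholders, not part of the original header):
//
//   TF_Status* status = TF_NewStatus();
//   std::string mlir_text = ExperimentalConvertSavedModelToMlir(
//       "/tmp/saved_model", /*exported_names_str=*/"",
//       /*show_debug_info=*/false, status);
//   TF_DeleteStatus(status);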
+std::string ExperimentalConvertSavedModelV1ToMlir( + const std::string &saved_model_path, const std::string &exported_names_str, + const std::string &tags, bool lift_variables, + bool include_variables_in_initializers, bool upgrade_legacy, + bool show_debug_info, TF_Status *status); + +std::string ExperimentalRunPassPipeline(const std::string &mlir_txt, + const std::string &pass_pipeline, + bool show_debug_info, + TF_Status *status); + +// Writes the input textual MLIR as bytecode to output file. +void ExperimentalWriteBytecode(const std::string &filename, + const std::string &mlir_txt, TF_Status *status); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h new file mode 100644 index 00000000..f9fbed1c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H_ +#define TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H_ + +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/stl.h" // from @pybind11 + +namespace py = pybind11; + +void init_basic_classes(py::module& m); +void init_types(py::module& m); +void init_builders(py::module& m); +void init_ops(py::module& m); +void init_attrs(py::module& m); + +#endif // TENSORFLOW_COMPILER_MLIR_PYTHON_MLIR_WRAPPER_MLIR_WRAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h new file mode 100644 index 00000000..8d805d93 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -0,0 +1,263 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +namespace mlir::quant { + +constexpr char kAttrMapAttribute[] = "attr_map"; + +// Name of the string attribute attached to `XlaCallModuleOp`, which is the +// textproto representation of `Method`. +inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; + +// Permutation from the NHWC tensor format to NCHW. This is an inverse +// permutation of `kNchwToNhwcPermutation`. +inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; + +// Permutation from the NCHW tensor format to NHWC. This is an inverse +// permutation of `kNchwToNhwcPermutation`. +inline constexpr std::array kNchwToNhwcPermutation = {0, 2, 3, 1}; + +// Permutation from the OIHW (== (output features, input features, height, +// width)) tensor format to HWIO. This is commonly used to transpose convolution +// weights represented as OIHW format to HWIO, which is more desirable for +// certain downstream optimization passes (e.g. XLA). +inline constexpr std::array kOihwToHwioPermutation = {2, 3, 1, 0}; + +// Returns true if the value has static shape. +bool HasStaticShape(Value value); + +// Returns true if the value has static shape at given dims. +bool HasStaticShapeAtDims(Value value, ArrayRef dims); + +// Whether `value` has known rank of `rank`. Returns false when it is not a +// `ShapedType` or its rank is unknown. +inline bool HasRankOf(Value value, const int64_t rank) { + auto shaped_type = mlir::dyn_cast_or_null(value.getType()); + return shaped_type && shaped_type.hasRank() && shaped_type.getRank() == rank; +} + +// Creates a new type that has the shape from the `old_type` and the element +// type from the `element_type`. +Type CloneTypeWithNewElementType(Type old_type, Type element_type); + +// Creates an array with integer/float type. 
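// Illustrative usage sketch (assumes an in-scope OpBuilder `builder` and
// Location `loc`; not part of the original header):
//
//   Value splat_zero = CreateConstValue<float>(
//       builder, loc, /*shape=*/{2, 2}, /*values=*/{0.f, 0.f, 0.f, 0.f});
//   Value axis = Create1DConstValue<int32_t>(builder, loc, {0});
//   Value one = CreateScalarConstValue<int64_t>(builder, loc, 1);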
+template || std::is_same_v), void>> +Value CreateConstValue(OpBuilder& builder, const Location loc, + const SmallVector& shape, + const SmallVector& values) { + if constexpr (std::is_integral_v) { + auto shape_type = + RankedTensorType::get(shape, builder.getIntegerType(sizeof(T) * 8)); + + const auto attr = DenseIntElementsAttr::get(shape_type, values); + return builder.create(loc, attr); + } + + const auto type = RankedTensorType::get(shape, builder.getF32Type()); + const auto value_attr = DenseFPElementsAttr::get(type, values); + return builder.create(loc, value_attr); +} + +// Creates a 1D array with integer/float type. +template +Value Create1DConstValue(OpBuilder& builder, const Location loc, + const SmallVector& values) { + return CreateConstValue(builder, loc, + {static_cast(values.size())}, values); +} + +// Creates a scalar with integer / float type. +template +Value CreateScalarConstValue(OpBuilder& builder, const Location loc, + const T value) { + return CreateConstValue(builder, loc, /*shape=*/{}, {value}); +} + +// Checks if the value is a constant and return its splat value. +template || std::is_same_v), void>> +bool GetSplatValue(Value value, T& splat_value) { + if constexpr (std::is_integral_v) { + DenseIntElementsAttr value_attr; + if (!matchPattern(value, m_Constant(&value_attr)) || + !value_attr.isSplat()) { + return false; + } + splat_value = value_attr.getSplatValue(); + return true; + } + + DenseFPElementsAttr value_attr; + if (!matchPattern(value, m_Constant(&value_attr)) || !value_attr.isSplat()) { + return false; + } + splat_value = value_attr.getSplatValue(); + return true; +} + +// Checks if the value is a constant and its splat value is equal to x. +template +bool IsSplatValueEqual(Value value, const T x) { + T splat_value; + if (!GetSplatValue(value, splat_value)) return false; + + return splat_value == x; +} + +// Checks if two values are constants and their splat values are equal. +template +bool AreSplatValuesEqual(Value x, Value y) { + T splat_x, splat_y; + if (!GetSplatValue(x, splat_x) || !GetSplatValue(y, splat_y)) { + return false; + } + + return splat_x == splat_y; +} + +// Clones an operation with new operands while keeping attributes. +SmallVector CloneOpWithReplacedOperands(OpBuilder& builder, + Operation* op, + ArrayRef new_operands); + +// Tries casting `op` with a concrete op type `T`. If the cast fails or `op` is +// a `nullptr`, returns `failure` and prints a debugging message identifying +// the cast attempt as `name`. +template +FailureOr TryCast(Operation* op, const StringRef name) { + auto cast_op = dyn_cast_or_null(op); + if (cast_op) { + return cast_op; + } else { + DEBUG_WITH_TYPE("mlir-quant-attrs-and-constraints", + llvm::dbgs() << "Failed to match " << name << " (" + << T::getOperationName() << ").\n"); + return failure(); + } +} + +FailureOr CastI64ToI32(int64_t value); + +// Tries to cast an array of int64 to int32. If any of the element in the +// array is not in the range of int32, returns failure(). +FailureOr> CastI64ArrayToI32( + ArrayRef int64_array); + +// Returns the first operation with the given type in the function. +template +OpType FindOperationOfType(func::FuncOp function) { + for (auto op : function.getBody().getOps()) { + return op; + } + return nullptr; +} + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. 
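// Illustrative usage sketch (assumes `op` is a non-null Operation*; not part
// of the original header):
//
//   if (Operation* user = FindUserOfType<TF::XlaCallModuleOp>(op)) {
//     // `op` feeds a tf.XlaCallModule call.
//   }
//   FailureOr<TF::XlaCallModuleOp> call =
//       TryCast<TF::XlaCallModuleOp>(op, /*name=*/"xla_call_module");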
+template +Operation* FindUserOfType(Operation* op) { + for (Operation* user : op->getUsers()) { + if (isa(user)) { + return user; + } + } + return nullptr; +} + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. +template +Operation* FindOperandOfType(Operation* op) { + for (Value operand_value : op->getOperands()) { + if (isa(operand_value.getDefiningOp())) { + return operand_value.getDefiningOp(); + } + } + return nullptr; +} + +// Returns the function attribute for the given call op which is lifted for +// quantization. +inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) { + return mlir::dyn_cast(call_op.getFAttr()); +} + +inline FlatSymbolRefAttr GetFuncAttr(TF::XlaCallModuleOp call_op) { + return call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); +} + +// Returns the entry function name for the given tf.XlaCallModule op. Returns +// empty string if such attribute does not exist. +StringRef GetEntryFunctionName(TF::XlaCallModuleOp op); + +// Checks whether the given op contains QuantizationTrait::FullyQuantizable. +inline bool HasQuantizableTrait(Operation* op) { + return op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; +} + +// Returns true if `op` has two operands and one result and only second operand +// is quantized. +bool IsHybridQuantizedOp(Operation* op); + +// Returns whether a given `stablehlo.dot_general` can be legalizable to +// `tfl.fully_connected`. +absl::StatusOr IsDotGeneralFullyConnected( + ::mlir::stablehlo::DotGeneralOp dot_general_op); + +// Returns the quantization dimension for a given `stablehlo.dot_general` op, +// or `std::nullopt` if the given op is not per-channel quantizable. +std::optional GetDotGeneralQuantizationDim( + ::mlir::stablehlo::DotGeneralOp dot_general_op); + +// Checks if a `StringRef` contains 'conv' or 'dot_general'. +bool ContainsConvOrDot(StringRef str); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_ATTRS_AND_CONSTRAINTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/func.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/func.h new file mode 100644 index 00000000..ade7bcfc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/func.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir::quant { + +// Returns a public `func::FuncOp` in `module_op` whose name matches either +// `main` or `serving_default`. If `func::FuncOps` with both names exist, the +// function with name "main" takes precedence. Returns null if no such a +// function exists. +func::FuncOp FindMainFuncOp(ModuleOp module_op); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_FUNC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h new file mode 100644 index 00000000..9e0e0e63 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h @@ -0,0 +1,79 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// +// This file defines support utilities for interoperating with FakeQuant* based +// QAT (Quantized Aware Training) computations, as implemented by TFLite. Note +// that FakeQuant* operators mix multiple concerns specific to how TFLite +// originally implemented quantization. As such, utilities here enforce +// opinions taken by that codebase (vs providing any amount of genericity). +// +// Specifically, it combines the following concerns, each of which would be +// independent variables in a more generic setup: +// - numBits and isSigned imply storage data type (uint8, int8, int16) +// - numBits < 8 is promoted to uint8 or int8 +// - "narrow_range" narrows the lower bound of the storage type's range by +// 1 +// - the specified min/max values are "nudged" so that the result has a zero +// that can be exactly expressed +// - min=max=0 implies scale=0 and zero_point=0 +// +// With the above assumptions applied, every conforming specified FakeQuant op +// can be represented by a UniformQuantizedType. This scheme is not expected to +// be generalized further in the future and should be considered to be a +// legacy set of rules. +// +// As canonically used in TensorFlow graphs, the presence of a FakeQuant node +// is a hint that the specific math represented here has been simulated at +// training time. As such, it is usually not advised to arbitrarily change +// quantization parameters derived from FakeQuant. 
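// Illustrative call sketch (assumes `loc` is a Location and `f32` the f32 Type
// of the enclosing MLIRContext; the parameter values are placeholders, not
// part of the original header):
//
//   quant::UniformQuantizedType qtype = quantfork::fakeQuantAttrsToType(
//       loc, /*numBits=*/8, /*rmin=*/-1.0, /*rmax=*/1.0,
//       /*narrowRange=*/false, /*expressedType=*/f32, /*isSigned=*/true);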
+// +//===----------------------------------------------------------------------===// + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ + +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace quantfork { + +/// Converts per-layer FakeQuant attributes to the corresponding type. +/// In the event that the parameters cannot be converted, returns a nullptr +/// convertible Type and issues an appropriate error. +/// Note that there are multiple variants of a per-layer FakeQuant op, so +/// this function takes the attributes discretely vs taking a reference to the +/// originating op. +quant::UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, + double rmin, double rmax, + bool narrowRange, + Type expressedType, + bool isSigned = false); + +/// Converts per-channel FakeQuant attributes to the corresponding type. +/// In the event that the parameters cannot be converted, returns a nullptr +/// convertible Type and issues an appropriate error. +quant::UniformQuantizedPerAxisType fakeQuantAttrsToType( + Location loc, unsigned numBits, int32_t quantizedDimension, + ArrayRef rmins, ArrayRef rmax, bool narrowRange, + Type expressedType, bool isSigned = false); +} // namespace quantfork +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h new file mode 100644 index 00000000..699b2582 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_QUANTOPS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_QUANTOPS_H_ + +#include "llvm/Support/MathExtras.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOpsDialect.h.inc" +#define GET_OP_CLASSES + +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_QUANTOPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h new file mode 100644 index 00000000..f4dcc8bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h @@ -0,0 +1,247 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::quantfork { + +// Performs type conversion from an arbitrary input type to a type +// that is expressed by a QuantizedType. +// +// This handles cases where the inputType is a supported primitive type +// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported +// elemental type. +// +// Since conversion often involves introspecting some attributes of the +// input type in order to determine how to represent it, this is a two step +// process. +struct ExpressedToQuantizedConverter { + // Creates a converter for the given input type. 
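// Illustrative usage sketch (assumes `tensor_type` is a tensor type with an
// f32 element type and `qtype` a quant::QuantizedType; not part of the
// original header):
//
//   auto converter = ExpressedToQuantizedConverter::forInputType(tensor_type);
//   if (converter) {
//     Type quantized = converter.convert(qtype);
//   }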
+ static ExpressedToQuantizedConverter forInputType(Type input_type); + + // Converts the inputType to be based on the given elemental type, + // returning the new type (or nullptr and emit an error on failure). + Type convert(quant::QuantizedType elemental_type) const; + + // Whether the conversion is legal. + explicit operator bool() const { return (bool)expressed_type; } + + // The input type that is being converted from. + // This may be an elemental or composite type. + const Type input_type; + + // Supported, elemental expressed type (i.e. f32). + // Will be nullptr if conversion is not supported. + const Type expressed_type; +}; + +// Reference implementation of converting between real numbers and values +// represented by a UniformQuantizedType. +// Note that this is not expected to be speedy and may be superseded eventually +// by a more optimal implementation. +// Also, the interface assumes that quantization is done per-layer and will +// need to be wider for various per-channel schemes. As such, this is a +// placeholder. +class UniformQuantizedValueConverter { + public: + explicit UniformQuantizedValueConverter( + quant::UniformQuantizedType uniform_type) + : UniformQuantizedValueConverter( + uniform_type.getScale(), + static_cast(uniform_type.getZeroPoint()), + static_cast(uniform_type.getStorageTypeMin()), + static_cast(uniform_type.getStorageTypeMax()), + uniform_type.getStorageTypeIntegralWidth(), + uniform_type.isSigned()) { + assert(isa(uniform_type.getExpressedType())); + assert(uniform_type.getStorageType().isSignlessInteger()); + } + + UniformQuantizedValueConverter(double scale, double zero_point, + double clamp_min, double clamp_max, + uint32_t storage_bit_width, bool is_signed) + : scale_(scale), + zero_point_(zero_point), + clamp_min_(clamp_min), + clamp_max_(clamp_max), + scale_double_(scale), + zero_point_double_(zero_point), + clamp_min_double_(clamp_min), + clamp_max_double_(clamp_max), + storage_bit_width_(storage_bit_width), + is_signed_(is_signed), + round_mode_(APFloat::rmNearestTiesToAway) {} + + UniformQuantizedValueConverter(double scale, double zero_point, + const APFloat& clamp_min, + const APFloat& clamp_max, + uint32_t storage_bit_width, bool is_signed) + : scale_(scale), + zero_point_(zero_point), + clamp_min_(clamp_min), + clamp_max_(clamp_max), + scale_double_(scale), + zero_point_double_(zero_point), + clamp_min_double_(clamp_min.convertToDouble()), + clamp_max_double_(clamp_max.convertToDouble()), + storage_bit_width_(storage_bit_width), + is_signed_(is_signed), + round_mode_(APFloat::rmNearestTiesToAway) {} + + virtual APInt quantizeFloatToInt(APFloat expressed_value) const { + // This function is a performance critical code path in quantization + // since it runs for each single float parameter value. + + // Specialize f32->u8/i8 case to optimize performance. 
+ if (&expressed_value.getSemantics() == &APFloat::IEEEsingle() && + storage_bit_width_ == 8 && + round_mode_ == llvm::APFloatBase::rmNearestTiesToAway) { + return quantizeF32ToInt8(expressed_value); + } + + bool lossy; + expressed_value.convert(scale_.getSemantics(), round_mode_, &lossy); + // fixed_point = clamp(clamp_min, clamp_max, ( + // roundHalfToEven(expressed / scale) + zero_point)) + APFloat scaled = (expressed_value / scale_); + scaled.roundToIntegral(round_mode_); + scaled.add(zero_point_, round_mode_); + APFloat fixed_point = llvm::minimum(scaled, clamp_max_); + fixed_point = llvm::maximum(fixed_point, clamp_min_); + + llvm::APSInt result(storage_bit_width_, !is_signed_); + fixed_point.convertToInteger(result, round_mode_, &lossy); + + return std::move(result); + } + + int64_t quantizeFloatToInt64(APFloat expressed_value) const { + const APInt q_value = quantizeFloatToInt(std::move(expressed_value)); + return is_signed_ ? q_value.getSExtValue() : q_value.getZExtValue(); + } + + virtual ~UniformQuantizedValueConverter() = default; + + private: + // An optimized implementation to quantize f32 to i8/u8 with C++ native + // arithmetic. + virtual APInt quantizeF32ToInt8(const APFloat& expressed_value) const { + assert(&expressed_value.getSemantics() == &APFloat::IEEEsingle()); + assert(storage_bit_width_ == 8); + assert(round_mode_ == llvm::APFloatBase::rmNearestTiesToAway); + + const float real_value = expressed_value.convertToFloat(); + + const double scaled = real_value / scale_double_ + zero_point_double_; + // Round to nearest integer with halfway cases rounded away from zero. + const double scaled_rounded = std::round(scaled); + const double clamped = std::min(std::max(scaled_rounded, clamp_min_double_), + clamp_max_double_); + + uint64_t signless_result; + if (is_signed_) { + int64_t clamped_int = static_cast(clamped); + memcpy(&signless_result, &clamped_int, sizeof(clamped_int)); + } else { + signless_result = static_cast(clamped); + } + return APInt(storage_bit_width_, signless_result, /*isSigned=*/is_signed_); + } + + // Keep both APFloat and double versions of the quantization parameters + // around since they will be used in generic and specialized arithmetic, + // respectively. + const APFloat scale_; + const APFloat zero_point_; + const APFloat clamp_min_; + const APFloat clamp_max_; + + const double scale_double_; + const double zero_point_double_; + const double clamp_min_double_; + const double clamp_max_double_; + + const uint32_t storage_bit_width_; + const bool is_signed_; + const llvm::APFloat::roundingMode round_mode_; +}; + +// An utility class to quantize an attribute by the per-axis quantization +// parameters. The size of the quantization dim in the converted elements +// attribute should match the size of of scales/zero_points vectors in the +// quantization parameters. 
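// Illustrative usage sketch (assumes `per_axis_type` is a
// quant::UniformQuantizedPerAxisType and `weights` a DenseFPElementsAttr; not
// part of the original header):
//
//   UniformQuantizedPerAxisValueConverter converter(per_axis_type);
//   ElementsAttr quantized = converter.convert(weights);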
+class UniformQuantizedPerAxisValueConverter { + public: + explicit UniformQuantizedPerAxisValueConverter( + quant::UniformQuantizedPerAxisType uniform_type) + : scales_(uniform_type.getScales()), + zero_points_(uniform_type.getZeroPoints()), + clamp_min_(static_cast(uniform_type.getStorageTypeMin())), + clamp_max_(static_cast(uniform_type.getStorageTypeMax())), + storage_bit_width_(uniform_type.getStorageTypeIntegralWidth()), + is_signed_(uniform_type.isSigned()), + quantization_dim_(uniform_type.getQuantizedDimension()) { + assert(isa(uniform_type.getExpressedType())); + assert(uniform_type.getStorageType().isSignlessInteger()); + assert(scales_.size() == zero_points_.size()); + } + + // Quantize an Attribute by the quantization parameters. Return nullptr if + // the conversion fails or the input array isn't an ElementsAttr. + ElementsAttr convert(Attribute real_value); + + private: + // Quantize an DenseFPElementsAttr by the quantization parameters. + DenseElementsAttr convert(DenseFPElementsAttr attr); + + // Get a uniform converter for the index-th chunk along the quantizationDim. + // All the elements in this chunk is quantized by the returned converter. + UniformQuantizedValueConverter getPerChunkConverter(int index) const { + return UniformQuantizedValueConverter(scales_[index], zero_points_[index], + clamp_min_, clamp_max_, + storage_bit_width_, is_signed_); + } + + const ArrayRef scales_; + const ArrayRef zero_points_; + const APFloat clamp_min_; + const APFloat clamp_max_; + const uint32_t storage_bit_width_; + const bool is_signed_; + int32_t quantization_dim_; +}; + +} // namespace mlir::quantfork + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h new file mode 100644 index 00000000..22e0307f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -0,0 +1,118 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::quant { + +// This attribute will be set for functions created by this pass. +// Presence of this attribute will mark the function as quantization target. +inline constexpr StringRef kFusedFunctionAttr = "tf_quant.composite_function"; +// The keyword to detect if this is a `NullAttribute`. +inline constexpr StringRef kNullAttributeValue = "N/A"; + +// Prefixes attached to lifted functions. +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kCompositeFuncPrefix = "composite_"; + +// The attribute will be used for TF::XlaCallModuleOp to restore the original +// function name when loading it back. +inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName = + "_original_entry_function"; + +// FunctionCallOpType to be generated as the function call operator when +// function lifting will happen. +enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; + +// Checks if an op is inside a lifted function. +// If the given op pointer is a nullptr, returns false. +bool IsInLiftedFunc(Operation* op); + +// Checks if the op is inside a StableHLO op with region. +// If the given op pointer is a nullptr, returns false. +bool IsInStableHloOpRegion(Operation* op); + +// Checks if a given einsum op is supported for XlaDotV2 quantization. +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); + +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns +// `absl::InvalidArgumentError` when the attribute doesn't exist. Returns +// `absl::InternalError` when parsing the attribute to `Method` failed. +// `op` must be non-null. +absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( + absl::Nonnull op); + +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns a default instance of +// `Method` iff the attribute doesn't exist or the attribute contains an invalid +// textproto for `Method`. `op` must be non-null. +::stablehlo::quantization::Method GetQuantizationMethodOrDefault( + absl::Nonnull op); + +// Creates a function to wrap the section between arguments and results. +// The generated function call op type will be decided by the given call_op_type +// argument. Currently, it supports TF::XlaCallModuleOp and +// TF::PartitionedCallOp function call op generations. 
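// Illustrative call sketch (assumes `builder`, `loc`, `inputs`, and `outputs`
// are in scope; the function name is a placeholder built from
// kCompositeFuncPrefix, not part of the original header):
//
//   SmallVector<Value> call_outputs = LiftAsFunctionCall(
//       builder, loc, FunctionCallOpType::TFXlaCallModuleOp,
//       /*func_name=*/"composite_dot_general_fn", inputs, outputs);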
+SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results, + ArrayRef attributes); + +// Same as above but with empty attributes. +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results); + +// Add the second argument to the first argument, which is expected to be an +// argument list. +// Used to attach bias to einsum argument list. +SmallVector AppendToVector(ArrayRef arguments, Value append); + +// Checks if the `Method` attatched to the given `tf.XlaCallModule` op has +// `WeightOnlyPtq`. +bool HasWeightOnlyPtqMethod(TF::XlaCallModuleOp xla_call_module_op); + +// Checks if an op is a `tf.XlaCallModule` op, contains 'conv' or 'dot_general' +// in its name and has `Method` with `WeightOnlyPtq`. +bool IsWeightOnlyQuantizableOp(const Operation& op); + +// Lists the functions in a ModuleOp sorted by their names. +SmallVector GetSortedFunctions(ModuleOp module_op); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_LIFT_AS_FUNCTION_CALL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h new file mode 100644 index 00000000..cb9dac20 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h @@ -0,0 +1,252 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines node specs for quantization and the methods to parse +// command line flags to these specs. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_CONFIG_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace quant { + +// Stores information about how to quantize a user-specified custom operation. +struct CustomOpInfo { + std::vector quantizable_input_indices; + bool is_weight_only = false; + bool no_side_effect = true; +}; + +using CustomOpMap = std::unordered_map; +enum CustomOpUpdateOptions { kInputIndices, kWeightOnly, kNoSideEffect }; +enum class QDQConversionMode { kQDQNone, kQDQStatic, kQDQDynamic }; + +struct QuantizationSpecs { + // Which function this node quant specifications belong to. 
+ std::string target_func = "main"; + + // Whether to trigger quantization passses for post-training quantization. + // If true, the model input doesn't require user specified input ranges. + bool post_training_quantization = false; + + // Whether to allow dynamic range quantization. This is the easiest + // quantization mode which doesn't require QAT or sample inputs. + // This option only targets `DT_HALF` and `DT_QINT8` inference type. + bool weight_quantization = false; + + // Whether to use the MLIR dynamic range quantizer instead of TOCO. + bool enable_mlir_dynamic_range_quantizer = false; + + // Whether to allow weight-only quantization. This scheme quantizes + // weights but will dequantize them back at runtime which is useful for + // memory bound case without kernel support available in lower precisions. + // Used in MLIR dynamic range quantizer. + bool weight_only_quantization = false; + + // The minimum number of elements in a weights array required to apply + // quantization. This is especially useful not to quantize small tensors as + // it is hard to get performance benefits from them with quantization. Used + // in MLIR dynamic range quantizer with int8 weight data type. + int64_t minimum_elements_for_weights = 1024; + + // Whether to calculate scales in float to keep quantized values the same with + // old TOCO quantizer. + bool legacy_float_scale = false; + + // Whether to perform per-tensor quantization. Currently, this option is only + // valid when the quantization parameters need to be created by scanning the + // constant content (post-training quantization or QAT without weight + // FakeQuant). + bool disable_per_channel = false; + + // Whether to disable per-channel weight quantization and enable legacy per + // tensor quantization. The legacy quantization for Dense layers is + // inconsistent with Conv 1x1 which always performs per channel quantization. + bool disable_per_channel_for_dense_layers = false; + + // Whether to use fixed output ranges of the activation ops (tanh, sigmoid, + // etc.) and not infer weight constants. + // If this option is set, quantization emulation ops should be placed after + // the ops in the input graph. This flag should be set to false for + // post-training quantization. + bool disable_infer_tensor_range = false; + + // Whether to use the unfrozen variable quantization in MLIR. Typically, + // variables are frozen for passing passes, but some variables aren't frozen. + // If it is true, QuantizeVariables pass will be added after the + // PrepareQuantizePass. + bool enable_mlir_variable_quantization = false; + + // The node type when the model is exported. Currently this is limited to + // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the + // `weight_quantization` flag needs to set to true. When DT_QUINT8 is used, + // the `weight_quantization` flag needs to set to false. + tensorflow::DataType inference_type = tensorflow::DT_FLOAT; + + // The input and output data type during inference. This flag is only used + // when `inference_type` is different from DT_FLOAT. This flag can only be set + // to DT_FLOAT or as same as `inference_type`. If this flag is different + // from `inference_type`, adaptor ops are inserted as heading and tailing ops + // in the result model. + tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT; + + // Input node ranges. These ranges are stored as the same order of function + // arguments. 
They are only used when `weight_quantization` is set to false, + // and the model is required to have quantization parameters, either from + // quantization aware training or calibration, for the remaining tensors. + std::vector, std::optional>> + input_ranges; + + // Whether to disable setting the quantization parameters of the input nodes + // using input ranges. + bool disable_set_input_nodes_quantization_params = false; + + // The default ranges can be used when a tensor doesn't have quantization + // parameters and couldn't be quantized. Used only for latency tests. + std::pair, std::optional> default_ranges; + + // A serialized "QuantizationInfo" object to specify value ranges for some of + // the tensors with known names. + std::string serialized_quant_stats = ""; + + // A bitmask to encode support for reduced precision inference in the model. + tflite::optimize::ReducedPrecisionSupport support_mask = + tflite::optimize::ReducedPrecisionSupport::None; + + // Whether to run the passes to propagate the quantization parameters and + // graph rewrites. Returns false if the inference_type is DT_FLOAT or + // `weight_quantization` flag is set. + bool RunPropagationAndRewriteQuantizationPasses() const { + return inference_type != tensorflow::DT_FLOAT && !weight_quantization; + } + + // TODO: b/202075505 - make implicit weight type clearer + // Whether run the passes and graph rewrites for dynamic range quantization. + bool RunAndRewriteDynamicRangeQuantizationPasses() const { + bool dynamic_range_quantize = + (inference_type != tensorflow::DT_FLOAT) && weight_quantization && + !post_training_quantization && !disable_infer_tensor_range && + enable_mlir_dynamic_range_quantizer; + return dynamic_range_quantize; + } + + // Returns whether this inference type represents a signed storage type. + bool IsSignedInferenceType() const { + switch (inference_type) { + case tensorflow::DT_QUINT8: + case tensorflow::DT_QUINT16: + return false; + default: + return true; + } + } + + // Gets the width of this quantization type. Returns 0 if it isn't a + // quantization type. + int64_t GetQuantizationTypeWidth() const { + switch (inference_type) { + case tensorflow::DT_INT8: + case tensorflow::DT_UINT8: + case tensorflow::DT_QINT8: + case tensorflow::DT_QUINT8: + return 8; + case tensorflow::DT_INT16: + case tensorflow::DT_UINT16: + case tensorflow::DT_QINT16: + case tensorflow::DT_QUINT16: + return 16; + case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: + return 32; + default: + return 0; + } + } + + // Whether to add the NumericVerify ops to verify numbers before and after + // quantization. + bool verify_numeric = false; + // Whether to add verification for layer by layer, or on whole model. When + // disabled (per-layer) float and quantized ops will be run from same input + // (output of previous quantized layer). When enabled, float and quantized ops + // will run with respective float and quantized output of previous ops. + bool whole_model_verify = false; + + // Whether to use fake quant attributes to calculate quantization parameters. + bool use_fake_quant_num_bits = false; + + // Names of ops to block from quantization. Used in QuantizePass. + // For dynamic range quantization, ops in blocklist are quantized in weight- + // only manner. + absl::flat_hash_set ops_blocklist; + + // Names of locations to block from quantization. Used in QuantizePass. + absl::flat_hash_set nodes_blocklist; + + // Map from custom op code to custom op quantization information. 
+ // For dynamic range quantization, among the custom ops in the graph those + // specified in this map are subject to quantization. + CustomOpMap custom_map; + + // If other than kQDQNone, the model is a floating point graph with QDQ ops + // to be eliminated and fused into quantized kernels. + QDQConversionMode qdq_conversion_mode = QDQConversionMode::kQDQNone; + + // When set, adheres to the QDQ annotations added by the framework when + // possible rather than quantizing any op that is possible to quantize. + bool strict_qdq_mode = false; +}; + +// Parses the command line flag strings to the CustomOpMap specification. +void ParseCustomOpSpecs(absl::string_view node_names, + const CustomOpUpdateOptions& update_option, + CustomOpMap& custom_op_map); + +// Parses the command line flag strings to the quantization specification for +// input arrays of a graph. The array names are not stored in the spec, and will +// be matched by position. Returns true if failed. +bool ParseInputNodeQuantSpecs(absl::string_view node_names, + absl::string_view min_values, + absl::string_view max_values, + absl::string_view inference_type, + QuantizationSpecs* quant_specs); + +// Gets the quantization specification for input arrays. The array names are not +// stored in the spec, and will be matched by position. The min/max will be +// ignored if the inference_type isn't a quantized type. Returns true if failed. +bool GetInputNodeQuantSpecs(const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + tensorflow::DataType inference_type, + QuantizationSpecs* quant_specs); + +// Returns a human-readable string of the QDQQuantMode enum class +std::string GetQDQQuantModeString(QDQConversionMode mode); +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h new file mode 100644 index 00000000..43edaab9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h @@ -0,0 +1,387 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" + +namespace mlir { +namespace quant { + +// The state for each op result during the quantization parameters propagation. +struct QuantState { + // Quantization parameters propagated to an op result. + QuantizedType params; + // A flag indicates this state (the params) shouldn't be changed after it is + // initialized. This flag will be set to true if the quantization parameters + // are from the quantization-aware training. + const bool immutable; + + bool IsEmpty() const { return params == nullptr; } +}; + +// The state for rescaling the propagated quantization parameters. This can be +// on the input side to satisfy the constraint of previous operation, or on the +// output side to satisfy the constraint of the next operation. +struct RequantizeState { + // Sometimes, we have to "requantize" the quantization result to satisfy all + // the constraints. The "requantize" can happen either on the input or output + // of the quantization result. + enum RequantizePosition { + NO_REQUANTIZE, + ON_INPUT, + ON_OUTPUT + } pos = NO_REQUANTIZE; + + // Quantization parameters will be used to add the requantize ops. + QuantizedType params; + + // Avoid clobbering all uses of the value, limit to just these ops. + SmallVector> users; +}; + +using RequantizeStates = SmallVector; + +// This is a worklist-driven driver for propagating quantization parameters +// across operations. +// +// The initial quantization parameters are extracted from the quantized type +// between adjacent `quantfork::QuantizeCastOp` and +// `quantfork::DequantizeCastOp`s. All these initial parameters are marked as +// immutable because they are from quantization-aware training. +// +// The algorithm traverses each op and sets the quantization parameters of its +// operands and results, according to its quantization specification, and then +// adds the operands and results to the worklist. If there are any conflicts +// (for example, there are quantization parameters propagated from the previous +// iteration), this process stops if the existing parameters are the immutable, +// or adding `requantize` op to resolve the conflicts. +// +// After the algorithm is converged, pairs of `quantfork::QuantizeCastOp` and +// `quantfork::DequantizeCastOp` are inserted to the right position to +// materialize the propagation and requantize results. 
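Editorial note: the comment above describes a worklist-driven, fixed-point propagation. A minimal sketch of how the three phases declared in the class below compose is given here; the actual `QuantizationDriver::Run` may differ in detail, and the helper name is made up.

```cpp
#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h"

namespace mlir::quant {

// Illustrative composition of the driver's phases, not the real entry point
// (see ApplyQuantizationParamsPropagation further down for that).
inline void RunPropagationSketch(QuantizationDriver& driver) {
  // Seed per-value states from existing quantize/dequantize pairs (QAT) and
  // constants, and fill the initial worklist.
  driver.Initialize();
  // Propagate constraints until no state changes; conflicts with mutable
  // states are resolved by recording pending requantize steps.
  if (driver.PropagateParamsAndReturnIfChanged()) {
    // Materialize the final states as quantize/dequantize pairs in the IR.
    driver.Finalize();
  }
}

}  // namespace mlir::quant
```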
+// +class QuantizationDriver { + public: + // Type alias of int used to access `states_`. + using QuantStateIndex = int; + + // (op, operand index) pair. + using OpWithOperandIndex = std::pair; + + // (op, result index) pair. + using OpWithResultIndex = std::pair; + + explicit QuantizationDriver(func::FuncOp func_op, const bool is_signed, + const int bit_width, + const bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, + const bool infer_tensor_range, + const bool legacy_float_scale = false, + const bool is_qdq_conversion = false) + : fn_(func_op), + builder_(func_op.getBody()), + is_signed_(is_signed), + bit_width_(bit_width), + disable_per_channel_(disable_per_channel), + op_quant_spec_getter_(op_quant_spec_getter), + op_quant_scale_spec_getter_(op_quant_scale_spec_getter), + infer_tensor_range_(infer_tensor_range), + legacy_float_scale_(legacy_float_scale), + is_qdq_conversion_(is_qdq_conversion) {} + + // The entry point of the quantization parameters propagation. + void Run(); + + // Sets up the states for all the op results in the function. + void Initialize(); + + // Propagates the quantization parameters across all the ops. + bool PropagateParamsAndReturnIfChanged(); + + // Inserts the Quantize and Dequantize ops according to the propagation + // result. + void Finalize(); + + SmallVector GetArgs() { return args_; } + + llvm::DenseMap, int> GetResultStates() { + return result_states_; + } + + DenseMap result_states_; + + // Returns the state of the block argument. + QuantState& GetArgQuantState(BlockArgument arg) { + return states_[arg_states_[arg]]; + } + + // Returns the state of the index-th result of the op. + QuantState& GetResultQuantState(Operation* op, const int index) { + return states_[result_states_[{op, index}]]; + } + + private: + // Duplicates the constant op if it has multiple uses, and replaces + // target_op->operand[operand_index] with the newly created op. This also + // replaces corresponsing quantization states. + arith::ConstantOp DuplicateConstantOpIfNeeded(arith::ConstantOp op, + Operation* target_op, + int operand_index); + + // Adjusts bias scale that is derived from other scales (fc, conv ops) to + // prevent overflow of quantized bias values. This also changes quantization + // state of other inputs when needed. + bool SetBiasParamsWithAdjustments(Operation* op, int bias_index, + ArrayRef input_indices, + QuantizedType params); + + // Checks preconditions to adjust bias scale. + bool ShouldCheckBiasScale(Operation* op, int bias_index, + ArrayRef input_indices, + QuantizedType quantized_type, int& input_index, + int& filter_index); + + // Preprocesses the constants by doing the following: + // - Duplicates constants if it is used by multiple ops. For example, if a + // constant is used by multiple ops as a bias, duplicate constants and + // let each op assign its own quantization parameter for bias. + // - Adds all the non-bias constants (weights) to a set for looking up + // later. + // - Adds all per-channel weights to a set for looking up later. + void PreprocessConstantOps(); + + // Sets up all the data structures for quantization propagation. + void SetupAllStates(); + + // Returns Whether the constant is a weight, which shouldn't be shared by + // different ops. + bool IsWeight(Operation* cst) { return llvm::is_contained(weights_, cst); } + + // Returns all the related quantization constraints of the op. 
+ std::unique_ptr GetQuantSpec(Operation* op); + std::unique_ptr GetQuantScaleSpec(Operation* op); + + // Returns whether quantization parameters have been propagated to the results + // of this op. + bool IsQuantized(Operation* op); + + // Adds all the users of index-th result of op to the work list. + void AddUserToList(Operation* op, const int index) { + for (Operation* user : op->getResult(index).getUsers()) { + work_list_.push_back(user); + } + } + + // Adds the defining op of index-th operand of op to the work list. + void AddOperandToList(Operation* op, const int index) { + if (Operation* operand_op = op->getOperand(index).getDefiningOp(); + operand_op != nullptr) { + work_list_.push_back(operand_op); + } + } + + // Returns the quantization params for the bias input from the non-bias + // operands which have their indexes in the `non_biases` vector. The returned + // parameters are calculated by `func`. + QuantizedType GetBiasParams(Operation* op, int bias_index, + ArrayRef non_bias_operand_indices, + AccumulatorScaleFunc func); + + // Sets the quantization parameters of the result to `quantized_type`. If + // any quantization parameters have been propagated, a requantize will + // happen on the input of propagated quantization. Returns `true` if internal + // state has been modified. + bool SetResultParams(Operation* op, int result_index, + QuantizedType quantized_type); + + // Sets the quantization parameters of the operand to `quantized_type`. If any + // quantization parameters have been propagated, a `requantize` will happen on + // the output of propagated quantization. When `override` is set, quantization + // state of the value is replaced instead of adding requantization. Returns + // `true` if internal state has been modified. + bool SetOperandParams(Operation* op, int operand_index, + QuantizedType quantized_type, bool override = false); + + // Sets the quantization parameters of the constant result according to its + // content. + bool SetConstantResultParams(Operation* op); + + // Inserts the Quantize and Dequantize ops after `op`'s `index`-th result. The + // quantized element type for the result is `quantized_type`. + void QuantizeOpResult(Operation* op, int result_index, + QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops after `arg`. The quantized element + // type for `arg` is `quantized_type`. + void QuantizeArg(BlockArgument arg, QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops (i.e. QDQ) after `value`. The + // quantized element type for `value` is `quantized_type`. + void QuantizeValue(Value value, QuantizedType quantized_type, Location loc); + + // Inserts the Quantize ops for requantizing the index-th result of the op. + void RequantizeOpResult(Operation* op, int result_index, + RequantizeStates& states); + + // Inserts the Quantize ops for requantizing a block argument. + void RequantizeArg(BlockArgument arg, RequantizeStates& states); + + // Inserts the Quantize and Dequantize ops to quantize the value and returns + // the Quantize op. + void RequantizeValue(Value value, RequantizeStates& states, Location loc); + + // Returns the quantization parameter satisfies the same scale + // constraints for the op. Returns an empty option if this quantization + // parameter doesn't exist. + QuantizedType GetQuantParamsForSameScaleConstraint(Operation* op); + + // Returns the state of the index-th operand of the op. 
+ QuantState& GetOperandQuantState(Operation* op, const int index) { + return states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th operand of the op. + RequantizeStates& GetOperandRequantizeStates(Operation* op, const int index) { + return rescale_states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th result of the op. + RequantizeStates& GetResultRequantizeStates(Operation* op, const int index) { + return rescale_states_[result_states_[{op, index}]]; + } + + // Returns the states of the arg. + RequantizeStates& GetArgRequantizeStates(BlockArgument arg) { + return rescale_states_[arg_states_[arg]]; + } + + // Sets the state of an argument. If this value is cached, uses the cached + // result without creating new entry in the state vector. Otherwise, allocate + // a new entry in the state vector. + void InitializeArgState(BlockArgument arg, Value arg_value); + + // Sets the state of the index-th operand of the op. If this operand is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. + void InitializeOperandState(Operation* op, int index, Value value); + + // Sets the state of the index-th result of the op. If this result is cached, + // uses the cached result without creating new entry in the state vector. + // Otherwise, allocate a new entry in the state vector. + void InitializeResultState(Operation* op, int index, Value value); + + func::FuncOp fn_; + OpBuilder builder_; + const bool is_signed_; + const int bit_width_; + const bool disable_per_channel_; + + // We should distinguish weights and bias constants. Biases are specified by + // the quantization spec or are the operands of ops with same scale spec. The + // rest are weights. + DenseSet weights_; + + // The weights require narrow_range quantization. This map collects all the + // weight operands defined by the op quant spec. The value of each entry is + // the quantization dimension. If it is positive, per-channel quantization is + // required. + DenseMap optimized_weights_; + + // All the ops needs to propagate the quantization parameters to. + std::vector work_list_; + absl::flat_hash_set quantized_; + + // The vector contains all the quantization parameters propagated from the + // defining operations of the value, or from the quantization aware training. + std::vector states_; + + // The map contains all the quantization parameters which are required to + // satisfy the same operands and results constraint. The keys of this map are + // the values from `operand_states_` and `result_state_`. + absl::flat_hash_map rescale_states_; + + // Maps of indexes to the propagation state vector from the ops operands, + // results and arguments. + DenseMap operand_states_; + DenseMap arg_states_; + DenseMap value_to_state_; + + // This vector is to preserve the arguments order, so the newly inserted + // quantized ops for the arguments are deterministically ordered. + SmallVector args_; + + OpQuantSpecGetter op_quant_spec_getter_; + OpQuantScaleSpecGetter op_quant_scale_spec_getter_; + + // Infer output ranges for activation ops and constants. This is usually + // required for post-training quantization. + const bool infer_tensor_range_; + + // Calculate scales in float instead of double, so that the scales and + // quantized values are exactly the same with the TOCO quantizer. 
+ const bool legacy_float_scale_; + + // If true, the model is a floating point graph with QDQ ops to be eliminated + // and fused into quantized kernels. + const bool is_qdq_conversion_; +}; + +// Propagates quantization parameters across ops in this function and satisfies +// the quantization specification of the ops. This methods assumes the initial +// quantization parameters are stored as adjacent quantize and dequantize ops +// and the propagation results are materialized by inserting pairs of quantize +// and dequantize ops to this function. Set `disable_per_channel` to true to not +// use per channel quantization even the op supports it. +// Setting `infer_tensor_range` to true, to infer quantization parameters from +// the activation ops and weight constants. This is only used for post-training +// quantization. +void ApplyQuantizationParamsPropagation(func::FuncOp func, bool is_signed, + int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + bool infer_tensor_ranges, + bool legacy_float_scale, + bool is_qdq_conversion); + +void ApplyQuantizationParamsPropagation( + func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, + bool legacy_float_scale, bool is_qdq_conversion); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h new file mode 100644 index 00000000..e93cc4cf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the op traits used in the MLIR TensorFlow Lite dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ + +#include +#include +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +using QuantizedType = mlir::quant::QuantizedType; +using UniformQuantizedType = mlir::quant::UniformQuantizedType; + +namespace mlir { +namespace quant { +// Verifies that the op satisfies the same operands and results scales +// constraints. 
Note that this constraint can only be applied on some +// storage types of the op. +LogicalResult VerifySameScales(Operation* op); +} // namespace quant + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_interface.h.inc" + +namespace OpTrait { +namespace quant { + +// The base class that all the quantization related OpTrait implements. +template class TraitType> +struct QuantizationSpecTraitBase : public TraitBase { + static bool IsBias(int index) { return false; } + static bool IsQuantizable() { return true; } +}; + +// This class provides the API for ops that has a fixed output value range. +// This is used as a trait like this: +// +// class SoftmaxOp +// : public Op::Impl> { +// +// TODO(fengliuai): create a better way to express floating point scale in the +// template argument list. +template +class FixedResultUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, FixedResultUniformScale< + BitWidth, ZeroPoint, ScaleMantissa, ScaleExp, + StorageTypeMin, StorageTypeMax, Sign>::Impl> { + public: + QuantizedType GetResultQuantizedType(int index) { + auto op = this->getOperation(); + const auto result_type = + op->getResult(index).getType().template cast(); + if (!result_type.getElementType().template isa()) return {}; + Builder builder(op->getContext()); + const IntegerType storage_type = builder.getIntegerType(BitWidth); + const double scale = static_cast(ScaleMantissa) * + std::pow(10.0, static_cast(ScaleExp)); + return UniformQuantizedType::getChecked( + Sign, storage_type, result_type.getElementType(), scale, ZeroPoint, + StorageTypeMin, StorageTypeMax, builder.getUnknownLoc()); + } + }; +}; + +// This class provides the API for ops that has input as bias. This is used +// as a trait like this: +// +// class Conv2DOp +// : public Op::Impl> +// +// TODO(fengliuai): supports a configurable accumulator bit width. +template +class AccumulatorUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, AccumulatorUniformScale::Impl> { + public: + // Whether the index-th operand is a bias. + static bool IsBias(int index) { return index == Bias; } + + // Returns the indexes of all the non-bias operands. + static std::vector GetAllNonBiasOperands() { + return std::vector({Operands...}); + } + }; +}; + +// The trait to specify the operand index of the coefficient for an affine op +// and also the quantization dimension if per-axis quantization is support. +// If the quantization dimension is -1, per-axis quantization isn't supported. +// +// class Conv2DOp +// : public Op::Impl> +// +template +class AffineOpCoefficient { + public: + template + class Impl + : public TraitBase::Impl> { + public: + static int GetCoefficientOperandIndex() { return OperandIndex; } + static int GetQuantizationDim() { return QuantDim; } + }; +}; + +// This class provides the API for ops that can be quantized. 
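Editorial note: the `FixedResultUniformScale` trait above encodes its fixed scale as `ScaleMantissa * 10^ScaleExp`. A worked example with an assumed instantiation (the angle-bracket contents are stripped in the rendered diff, and the op name is made up): an 8-bit signed op with a fixed [0, 1) output range, such as a softmax-like activation, would use `BitWidth=8, ZeroPoint=-128, ScaleMantissa=390625, ScaleExp=-8, StorageTypeMin=-128, StorageTypeMax=127`, giving a scale of 1/256.

```cpp
// Hypothetical instantiation, for illustration only:
//
//   class MyFixedRangeOp
//       : public Op<MyFixedRangeOp,
//                   OpTrait::quant::FixedResultUniformScale<
//                       /*BitWidth=*/8, /*ZeroPoint=*/-128,
//                       /*ScaleMantissa=*/390625, /*ScaleExp=*/-8,
//                       /*StorageTypeMin=*/-128, /*StorageTypeMax=*/127,
//                       /*Sign=*/true>::Impl> {};
//
// The encoded scale is ScaleMantissa * 10^ScaleExp:
#include <cmath>
#include <cstdio>

int main() {
  const double scale = 390625 * std::pow(10.0, -8);
  std::printf("scale = %.8f\n", scale);  // prints 0.00390625, i.e. 1/256
  return 0;
}
```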
+// This is as a trait like this: +// +// class LessOp : public Op { +// +template +class QuantizableResult + : public QuantizationSpecTraitBase {}; + +} // namespace quant +} // namespace OpTrait +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h new file mode 100644 index 00000000..94169e3e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -0,0 +1,973 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" +#include 
"tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace quant { + +// A unit attribute can be attached to the quantize/dequantize ops which are +// added by the quantization passes. These ops can be removed erased without +// losing accuracy. +inline constexpr char kVolatileOpAttrName[] = "volatile"; + +// Following attributes are used to mark ops that are not quantizable during +// debug model generation process for whole-model verify mode. If these +// attributes are attached, the upstream float/quantized ops know which ops to +// connect to, and it also prevents these ops from being copied again. +inline constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; +inline constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; + +// Used to annotate custom ops if they are quantizable. +inline constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; +enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 }; +inline constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", + "not_quantizable"}; +inline constexpr char kOutputQuantized[] = "_output_quantized"; + +inline constexpr double kNearZeroTolerance = 1.0e-6; + +using QuantParams = QuantizedType; +using QuantSpec = QuantizationSpecs; +using SignedInteger = std::pair; // bitwidth and sign +using QuantParamsForResults = llvm::SmallVector; +using AccumulatorScaleFunc = + std::function&, int, bool)>; +using BiasParamsMap = + absl::flat_hash_map, AccumulatorScaleFunc>>; +// UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) +using GetFixedOutputRangeFunc = std::function; +// bool RequiredSameOperandsAndResultsScale(bool sign, int $bit_width) +using RequiredSameOperandsAndResultsScaleFunc = std::function; +// bool RequiredSameQuantizedAxes() +using RequiredSameQuantizedAxesFunc = std::function; + +using CustomMap = quant::CustomOpMap; + +// Quantization spec of an op, driving the quantization algorithm. +struct OpQuantSpec { + // Maps the operand index of a bias input to its quantization specifications, + // including the non-bias operand indexes and the method retrieving + // quantization parameters from list of parameters of the non-bias operands. + // This map is empty if the op doesn't have a bias operand. + BiasParamsMap biases_params; + + // Quantization parameters for value restricted outputs. This is the + // "hard-coded" parameters and should be used unconditionally for the + // quantized op. This vector is empty if the op doesn't have value restricted + // outputs. + llvm::DenseMap restricted_output_params; + + // Coefficient operand index and whether supporting per-channel quantization. + // For QAT, this information is carried by the FakeQuant*/Quantize/Dequantize + // ops, but post-training quantization, the quantization parameters need to be + // inferred from the tensor content and op property. A "-1" value indicates + // the operand doesn't support per-channel quantization. + llvm::DenseMap coeff_op_quant_dim; + + // Indices of quantizable operands. Biases are not included in this field, + // the indices of biases can be found in the `biases_params`. + absl::flat_hash_set quantizable_operands; +}; + +// A function signature for getting the particular OpQuantSpec for the provided +// op. +using OpQuantSpecGetter = + std::function(Operation*)>; + +// Quantization scale spec of an op. The information defined in the MLIR +// interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should +// be checked first if present. 
+// TODO: b/323478683: Consider deprecating this. +struct OpQuantScaleSpec { + // Whether this op has a fixed range requirement (e.g. sigmoid) + bool has_fixed_output_range = false; + // Whether this op should have same operand and result scales (e.g. concat) + bool has_same_scale_requirement = false; + // Whether this op should have same operand and result type (e.g. gather) + bool has_same_operand_and_result_type_requirement = false; + // Returns the fixed output range, when has_fixed_output_range is set. + GetFixedOutputRangeFunc fixed_output_range_func; + // Returns whether same operands and results scales are required. + RequiredSameOperandsAndResultsScaleFunc required_same_scale_func = + [](bool sign, int bit_width) { return true; }; + // Returns whether operands and results must have the same quantized axis. + RequiredSameQuantizedAxesFunc required_same_quantized_axes_func = []() { + return true; + }; +}; + +// A function signature for getting the particular OpQuantScaleSpec for the +// provided op. +using OpQuantScaleSpecGetter = + std::function(Operation*)>; + +// Used in TFL Numeric Verify +struct NumericVerifySpec { + // Whether to enable numeric verification + bool verify_numeric = false; + + // Tolerance level from the quantized value for verification. If the tolerance + // is very small(<0.1), only the stats of the diff is displayed. + float error_tolerance = 5.0f; + + // Whether to verify numerical correctness layer by layer or by whole model + bool whole_model_verify = false; + + // Whether to enable log for failures + bool log_if_failed_flag = false; +}; + +// Used in TFL Quantize Pass +struct QuantPassSpec { + // Variables to control TFL Numeric Verify + NumericVerifySpec numeric_verify_spec; + + // Variables related to quantization + QuantSpec quant_spec; +}; + +// Re-calculates scales again in float instead of simply downcasting existing +// scales. +quant::QuantizedType DownCastScale(quant::QuantizedType type, + const SmallVectorImpl& mins, + const SmallVectorImpl& maxs, + Location loc); + +quant::QuantizedType DownCastScale(quant::QuantizedType type, double min, + double max, Location loc); + +bool IsOpQuantizable(Operation* op); +bool QuantizableOpSupportsFloatOutputType(Operation* op); + +// Specialized version of location to string for flatbuffer exported locations. 
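Editorial note: a minimal `OpQuantScaleSpecGetter` sketch, marking a concat-like op as "same scale" so the propagation driver forces a single set of quantization parameters across its operands and result. The getter's exact `std::function` signature and the op name `"my_dialect.concat"` are assumptions (the template arguments in this header appear stripped in the rendered diff); the field names come from the `OpQuantScaleSpec` struct above.

```cpp
#include <memory>

#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h"

namespace mlir::quant {

// Illustrative getter; real passes register something similar and hand it to
// the propagation driver.
std::unique_ptr<OpQuantScaleSpec> GetExampleOpQuantScaleSpec(Operation* op) {
  auto spec = std::make_unique<OpQuantScaleSpec>();
  if (op->getName().getStringRef() == "my_dialect.concat") {
    // Operands and result must share scale/zero point, for any sign/bit width.
    spec->has_same_scale_requirement = true;
    spec->required_same_scale_func = [](bool sign, int bit_width) {
      return true;
    };
  }
  return spec;
}

}  // namespace mlir::quant
```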
+inline std::string GetTensorNameFromLoc(Location loc) { + if (auto name_loc = loc.dyn_cast()) { + return name_loc.getName().str(); + } + return ""; +} + +template +struct ConvertStatsToQDQs : public OpRewritePattern { + ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed, + bool legacy_float_scale, MLIRContext* context) + : OpRewritePattern(context), + num_bits(num_bits), + narrow_range(narrow_range), + is_signed(is_signed), + legacy_float_scale(legacy_float_scale) {} + + LogicalResult matchAndRewrite(quantfork::StatisticsOp op, + PatternRewriter& rewriter) const override { + Type expressed = op.getType().cast().getElementType(); + quant::QuantizedType quant_type; + SmallVector mins, maxs; + + if (op.getAxisStats().has_value()) { + // Per axis quantization (or per channel quantization) + int stats_num = op.getAxisStats()->getNumElements(); + if (stats_num == 0 || stats_num % 2 != 0) return failure(); + auto stats = op.getAxisStats()->dyn_cast(); + if (!stats) return failure(); + + for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { + double rmin = FloatAttr::getValueAsDouble(*it++); + double rmax = FloatAttr::getValueAsDouble(*it); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. + // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer + // supports only symmetric quantization. + rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + mins.push_back(rmin); + maxs.push_back(rmax); + } + quant_type = quantfork::fakeQuantAttrsToType( + op.getLoc(), num_bits, *op.getAxis(), mins, maxs, narrow_range, + expressed, is_signed); + if (legacy_float_scale) { + quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc()); + } + } else if (auto stats = + op.getLayerStats().dyn_cast()) { + // Per tensor quantization + auto statValues = stats.getValues(); + double rmin = FloatAttr::getValueAsDouble(statValues[0]); + double rmax = FloatAttr::getValueAsDouble(statValues[1]); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. + // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer supports + // only symmetric quantization. 
+ rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + quant_type = + quantfork::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax, + narrow_range, expressed, is_signed); + if (legacy_float_scale) { + quant_type = DownCastScale(quant_type, rmin, rmax, op->getLoc()); + } + } else { + return failure(); + } + + rewriter.setInsertionPointAfter(op.getOperation()); + Type result_type = quant_type.castFromExpressedType(op.getType()); + auto q = + rewriter.create(op.getLoc(), result_type, op.getArg()); + q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); + + auto dq = rewriter.create(op.getLoc(), op.getType(), q); + op.getResult().replaceAllUsesWith(dq); + q.getOperation()->replaceUsesOfWith(dq, op.getArg()); + op.erase(); + + return success(); + } + + private: + int num_bits; + bool narrow_range; + bool is_signed; + bool legacy_float_scale; + + // Emits an op warning message if the calibrated range is larger than 10.0 and + // the storage type is less than or equal to 8 bits. + void TensorRangeSanityCheck(quantfork::StatisticsOp op, double& min, + double& max) const { + double range = std::fabs(max - min); + if (num_bits <= 8 && range >= 10.0) { + op.emitWarning() + << "Tensor range is too wide to be quantized. Use tf.clip_by_value " + "or tf.relu6 to narrow the tensor range. Range: " + << range << ", bit width: " << num_bits; + } + if (std::abs(max - min) < kNearZeroTolerance) { + op.emitWarning() << "Tensor range (" << min << ", " << max + << ") is too narrow and it might cause overflow. " + "Expanding range symmetrically by " + << kNearZeroTolerance; + min -= kNearZeroTolerance; + max += kNearZeroTolerance; + } + } +}; + +template +bool UsedBy(Operation* op) { + for (Operation* user : op->getUsers()) { + if (llvm::isa_and_nonnull(user)) return true; + } + return false; +} + +template +void CreateVerifier(Operation* quantizing_op, Operation* quantized_op, + PatternRewriter& rewriter, int result_idx, + const QuantPassSpec& quant_params) { + rewriter.setInsertionPointAfter(quantized_op); + FloatAttr tolerance = rewriter.getF32FloatAttr( + quant_params.numeric_verify_spec.error_tolerance); + BoolAttr log = + rewriter.getBoolAttr(quant_params.numeric_verify_spec.log_if_failed_flag); + // Verify the quantized value by sending the result to the verifier. + rewriter.create( + quantizing_op->getLoc(), quantized_op->getResult(result_idx).getType(), + quantized_op->getResult(result_idx), quantizing_op->getResult(result_idx), + tolerance, log); +} + +template <> +inline bool UsedBy(Operation* op) { + return false; +} + +// This specialization is not going to be called, but needed for compilation. +template <> +inline void CreateVerifier(Operation* quantizing_op, + Operation* quantized_op, + PatternRewriter& rewriter, int result_idx, + const QuantPassSpec& quant_params) {} + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// The concrete pattern, extends from this base pattern, can specify whether it +// allows dynamic range quantized operands and results for the operations in the +// current context. These "DynamicRangeQuantized" operands and results don't +// have quantization parameters propagated to, so will be in float in the +// quantized results. 
The concrete pattern should define the following two +// functions: +// +// bool AllowDynamicRangeQuantizedOperand(Operation *) const +// bool AllowDynamicRangeQuantizedResult(Operation *) const +// +// Full integer quantization disallows "DynamicRangeQuantized" operands or +// results. Dynamic range quantization allows "DynamicRangeQuantized" operands +// and results. +template +class QuantizationPattern : public RewritePattern { + public: + using BaseType = QuantizationPattern; + + explicit QuantizationPattern(MLIRContext* context, + const QuantPassSpec& quant_params) + // Set the score to a large number so it is always preferred. + : RewritePattern(RootOpT::getOperationName(), 300, context), + quant_params_(quant_params) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + llvm::SmallVector quantizing_ops; + + // Collect all the ops to quantize, as the user / producer of the root op. + if constexpr (std::is_same_v) { + if (op->getNumResults() != 1) { + return failure(); + } + auto users = op->getResult(0).getUsers(); + quantizing_ops.append(users.begin(), users.end()); + } else if constexpr (std::is_same_v) { + if (op->getNumOperands() != 1) { + return failure(); + } + Value quantize_operand = op->getOperand(0); + if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) { + // The input of this QuantizeOp has already been quantized, i.e. + // rescale. + return failure(); + } + DenseFPElementsAttr attr; + if (matchPattern(quantize_operand, m_Constant(&attr))) { + // Const-> QuantizeOp pattern will be handled separately. + return failure(); + } + if (Operation* quantizing_op = quantize_operand.getDefiningOp()) { + quantizing_ops.push_back(quantizing_op); + } + } + + tensorflow::DataType inference_type = + quant_params_.quant_spec.inference_type; + bool weight_only_quantization = + quant_params_.quant_spec.weight_only_quantization; + bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric; + bool enable_whole_model_verify = + quant_params_.numeric_verify_spec.whole_model_verify; + absl::flat_hash_set ops_blocklist = + quant_params_.quant_spec.ops_blocklist; + absl::flat_hash_set nodes_blocklist = + quant_params_.quant_spec.nodes_blocklist; + CustomMap custom_map = quant_params_.quant_spec.custom_map; + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* quantizing_op : quantizing_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (llvm::isa(quantizing_op)) { + return failure(); + } + + // If the op is terminator, not quantizable or any ops from the mlir quant + // ops dialect, we shouldn't rewrite. In case of whole-model verify debug + // mode, not-quantizable ops should be duplicated to keep parallel + // float/quant model execution. 
+ if (quantizing_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizable(quantizing_op) && + !static_cast(this)->IsQuantizableCustomOp( + quantizing_op, custom_map)) { + if (!(enable_verify && enable_whole_model_verify)) { + return failure(); + } + if (quantizing_op->hasAttr(kDebugModeOpQuantAttrName) || + quantizing_op->hasAttr(kDebugModeOpFloatAttrName)) { + return failure(); + } + + rewriter.setInsertionPoint(quantizing_op); + Operation* float_op = rewriter.clone(*quantizing_op); + quantizing_op->setAttr(kDebugModeOpQuantAttrName, + rewriter.getUnitAttr()); + float_op->setAttr(kDebugModeOpFloatAttrName, rewriter.getUnitAttr()); + RewireFloatModelBackbone(quantizing_op, float_op); + return success(); + } + + // Blocklist op is checked in advance for non-dynamic range quantization + // case. + if (!quant_params_.quant_spec.weight_quantization && + (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) != + ops_blocklist.end())) { + return failure(); + } + + if (!nodes_blocklist.empty()) { + if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + std::string sloc = name_loc.getName().str(); + if (!sloc.empty() && + (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { + return failure(); + } + } + } + + // An op with float inputs and outputs are expected when it's used by a + // NumericVerify op. Skip this op. + if (enable_verify && UsedBy(quantizing_op)) { + continue; + } + + bool is_operand_or_result_modified = false; + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(quantizing_op->getNumOperands()); + for (auto operand : quantizing_op->getOperands()) { + Type operand_type = operand.getType(); + if (operand_type.isa()) { + inputs.push_back(operand); + continue; + } + + auto ele_type = operand.getType().cast().getElementType(); + if (static_cast(this) + ->AllowDynamicRangeQuantizedOperand(quantizing_op, + custom_map)) { + auto dq_op = dyn_cast_or_null(operand.getDefiningOp()); + + if (dq_op && inference_type == tensorflow::DT_QINT8 && + !static_cast(this)->IsWeightOnlyOp( + quantizing_op, ops_blocklist, weight_only_quantization, + custom_map)) { + // Dynamic range quantization is applied by having QuantizeOp as an + // input. Only int8 weight is supported for now. + inputs.push_back(dq_op.getOperand()); + is_operand_or_result_modified = true; + } else { + // Otherwise, it's the case where the operand is activations or the + // quantizing_op is non-supported/weight-only. + inputs.push_back(operand); + } + } else { + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + is_operand_or_result_modified = true; + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. 
+ inputs.push_back(operand); + } else { + return failure(); + } + } + } + + Operation* quantized_op; + if (QuantizableOpSupportsFloatOutputType(quantizing_op)) { + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, quantizing_op->getResultTypes(), quantizing_op->getAttrs()); + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region* target_region = new_state.addRegion(); + IRMapping mapping; + indexed_regions.value().cloneInto(target_region, mapping); + } + quantized_op = rewriter.create(new_state); + rewriter.replaceOp(quantizing_op, quantized_op); + } else { + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none + // type results. + if (result_type.isa()) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + result.getType().cast().getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + is_operand_or_result_modified = true; + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (static_cast(this) + ->AllowDynamicRangeQuantizedResult(quantizing_op, + custom_map)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + // For float16 quantization if none of the operand or result is + // modified, replacing the op. See b/335025403. + if (inference_type == tensorflow::DT_HALF && + !is_operand_or_result_modified) { + return failure(); + } + + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + + // To verify the numericals, the original floating-point ops are + // preserved in the graph. The result of these floating-point ops are sent + // to a numeric verifier op as the reference. 
+ if (enable_verify && !std::is_same_v) { + // For constant operands, the floating-point constant is duplicated in + // case it is quantized. + for (int i = 0, e = quantized_op->getNumOperands(); i < e; ++i) { + auto def = quantized_op->getOperand(i).getDefiningOp(); + if (auto q = llvm::dyn_cast_or_null(def)) { + DenseFPElementsAttr attr; + if (!matchPattern(q.getOperand(), m_Constant(&attr))) { + continue; + } + auto cst = rewriter.create( + quantized_op->getLoc(), attr); + quantizing_op->setOperand(i, cst.getResult()); + } + } + + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!quantizing_op->getResult(i) + .getType() + .cast() + .getElementType() + .isa()) { + continue; + } + CreateVerifier(quantizing_op, quantized_op, rewriter, i, + quant_params_); + + if (enable_whole_model_verify) { + RewireFloatModelBackbone(quantized_op, quantizing_op); + } + } + } + } + return success(); + } + + private: + // Reconnects float ops in the whole-model verify mode. Works for both + // Quantizable ops and Unquantizable ops + void RewireFloatModelBackbone(Operation* quantized_op, + Operation* float_op) const { + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!float_op->getResult(i) + .getType() + .cast() + .getElementType() + .isF32()) { + continue; + } + // Find the Quantize/Dequantize users of the new op results, and replace + // the usage. Then all the floating-point ops are connected, forming a + // separate float "backbone" model that the quantized model can be + // compared against in parallel. + // N.B. the return op will use this floating-point result. + Value result; + if (!IsOpQuantizable(float_op)) { + // For not quantizable ops, search for dequantize attached to the + // quantized op of the output. + if (Operation* quantize_op = dyn_cast_or_null( + *quantized_op->getResult(i).getUsers().begin())) { + result = quantize_op->getResult(0); + } else { + quantized_op->emitError() + << "Output[" << i + << "] is expected to have only one user [QUANTIZE]"; + return; + } + } else { + result = quantized_op->getResult(i); + } + for (auto user : result.getUsers()) { + // Skip the Requantize op and set the user to the following dequantize + // op. This happens when the quantizer tries to match the scale conflict + // with QuantizeOp - QuantizeOp(requant) - DequantizeOp triples. The + // correct float op should be the user of the last DequantizeOp. + if (llvm::isa(user)) { + user = *user->getResult(0).getUsers().begin(); + } + if (auto dequantize = llvm::dyn_cast(user)) { + // Replace all uses, except not quantizable ops that are being used in + // the float backbone. + dequantize.getResult().replaceUsesWithIf( + float_op->getResult(i), [&](OpOperand& use) { + return !use.getOwner()->hasAttr(kDebugModeOpQuantAttrName); + }); + } + } + } + } + + QuantPassSpec quant_params_; +}; + +// A pattern that removes debug attributes that are annotated to ops during +// the debug model creation. +class RemoveDebugAttrPattern : public RewritePattern { + public: + explicit RemoveDebugAttrPattern(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +// Converts quantized tensor type with signed integer type to quantized tensor +// type with unsigned integer type. 
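Editorial sketch: the signed/unsigned conversions declared and implemented below all reduce to shifting zero points and storage bounds by a constant offset. A minimal, self-contained illustration of that arithmetic (the helper name is hypothetical and not part of the vendored header):

#include <cstdint>

// Offset between the default unsigned and signed storage minimums for a given
// bit width: 0 - (-2^(num_bits - 1)).
constexpr int64_t SignedUnsignedOffset(int num_bits) {
  return int64_t{1} << (num_bits - 1);
}

static_assert(SignedUnsignedOffset(8) == 128);
// Example: a u8 quantized type {scale = 0.5, zero_point = 200,
// storage range [0, 255]} becomes an i8 type {scale = 0.5, zero_point = 72,
// storage range [-128, 127]}; the represented real values are unchanged.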
+Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc); + +// Converts quantize ops with unsigned quantized types to these with signed +// quantized types and preserves the scales. +template +struct ConvertUnsignedToSigned : public OpRewritePattern { + using BaseType = ConvertUnsignedToSigned; + using QType = quant::QuantizedType; + + explicit ConvertUnsignedToSigned(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(QuantizeOpT op, + PatternRewriter& rewriter) const override { + Type output_type = op.getResult().getType(); + auto qtype = QType::getQuantizedElementType(output_type); + if (!qtype || qtype.isSigned()) return failure(); + + int num_bits = qtype.getStorageTypeIntegralWidth(); + if (num_bits == 8) { + // If storage is 8-bit, trained num bits may be less than 8 so check here. + num_bits = + static_cast(std::ceil(std::log2(qtype.getStorageTypeMax()))); + } + // This is a positive value, and will be applied on zero points and fixed + // point ranges. + int64_t offset = + QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) - + QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits); + + auto flags = quant::QuantizationFlags::Signed; + QType new_qtype; + if (auto uqtype = qtype.template dyn_cast()) { + new_qtype = quant::UniformQuantizedType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + uqtype.getScale(), uqtype.getZeroPoint() - offset, + uqtype.getStorageTypeMin() - offset, + uqtype.getStorageTypeMax() - offset); + } else if (auto aqtype = qtype.template dyn_cast< + quant::UniformQuantizedPerAxisType>()) { + auto zero_points = aqtype.getZeroPoints(); + llvm::SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); + for (int i = 0, e = new_zero_points.size(); i < e; ++i) { + new_zero_points[i] -= offset; + } + new_qtype = quant::UniformQuantizedPerAxisType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(), + aqtype.getStorageTypeMin() - offset, + aqtype.getStorageTypeMax() - offset); + } else { + return failure(); + } + + if (!new_qtype) return failure(); + Type new_output_type = new_qtype.castFromExpressedType( + QType::castToExpressedType(output_type)); + rewriter.replaceOpWithNewOp(op, new_output_type, op.getArg()); + return success(); + } +}; + +// Fold Extra Requantize ops if the preceding ops has free scale requirement. +template +struct FoldTrivalRequantizeOp : public OpRewritePattern { + explicit FoldTrivalRequantizeOp(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(RequantizeOpT op, + PatternRewriter& rewriter) const override { + Value pre_quantized = op->getOperand(0); + auto pre_quantized_type = + quant::QuantizedType::getQuantizedElementType(pre_quantized.getType()); + if (!pre_quantized_type) return failure(); + + Operation* def = pre_quantized.getDefiningOp(); + if (!def) return failure(); + if (llvm::isa(def) || + !def->hasTrait()) { + return failure(); + } + + // This op should not clobber def, if more than one requant of this value. + if (!pre_quantized.hasOneUse()) { + return failure(); + } + + op.emitWarning("Remove trivial `rescale` op. 
Please fix the source graph."); + + llvm::SmallVector new_output_types; + for (auto result : def->getResults()) { + if (result.hasOneUse() && *result.getUsers().begin() == op) { + new_output_types.push_back(op.getResult().getType()); + } else { + new_output_types.push_back(result.getType()); + } + } + + // Remove this rescale op. + rewriter.replaceOp(op, {pre_quantized}); + + // Replace the output scale of the preceding op. + rewriter.setInsertionPointAfter(def); + OperationState new_state(def->getLoc(), def->getName().getStringRef(), + def->getOperands(), new_output_types, + def->getAttrs()); + Operation* new_op = rewriter.create(new_state); + + rewriter.replaceOp(def, new_op->getResults()); + return success(); + } +}; + +// Given a quantized type `input`, magnifying its scales by the factor stored in +// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the +// dimension size of `input` or isn't floating-point, nullptr will be returned. +TypeAttr RescaleQuantizedType(Type input, Attribute factor); + +// Converts the min/max/num_bits/narrow_range information to a +// QuantizedType, and then returns the attribute containing the QuantizedType. +// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and +// returns UniformQuantizedType or UniformQuantizedPerAxisType respectively. +// `narrow_range` is set to true for weights and `is_signed` is set to true +// if it is using signed int symmetric quantization. +// +// Note that this method may broadcast min and max to match the dimension length +// of `input_type`, if the `quant_dim` is valid. On the other hand, the +// symmetry of min and max is not adjusted by this method. The QAT workflow +// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true) +// if symmetric quantization is required. +TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min, + Attribute max, int quant_dim, + IntegerAttr num_bits, BoolAttr narrow_range, + bool is_signed, bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); + +// Casts the `target` type to a quantized type by using the quantization +// parameters from the type in the `source` type attribute. +// Examples: +// f32 -> !quant.uniform +// tensor<4xf32> -> tensor<4x!quant.uniform> +// The result is wrapped by a type attribute. Returns nullptr if the cast +// isn't valid. +// +// `axis` is to specify the quantization dimension in the `target` and only +// used if the element type of `source` is a per-channel quantized type. During +// the casting, the quantization dimension of the result type needs to be set +// this new `axis` value. +TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder, + TypeAttr source, Type target, + int axis); + +// Quantizes the elements in the attribute `real_value` by the quantization +// parameters in `tensor_type`. Returns empty Attribute if the +// `tensor_type` is not a QuantizedType or the quantization fails. +ElementsAttr Quantize(Attribute real_value, Type tensor_type); + +// Quantizes the elements in "legacy mode", where it calls TOCO's methods to +// to quantize values with float scale. +ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type); + +// Returns the quantized type for an element attribute. The quantization +// parameters in this type is based on the min and max element of the +// attribute. When the elements in the `attr` are not in floating-point, or +// the value range isn't straddling zero, an empty type is returned. 
The min/max +// are adjusted to be symmetric if `symmetric` flag is set to True. And +// `symmetric` can only be set to true when it is signed and narrow_range. +Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric, + unsigned num_bits, bool is_signed, + bool narrow_range, + bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); + +// Returns the per channel quantized type for an element attribute. +// `quant_dim` defines the quantization axis. The channel min/max are adjusted +// to be symmetric if `symmetric` flag is set to True. And `symmetric` can only +// be set to true when it is signed and narrow_range. +Type GetUniformQuantizedPerAxisTypeForWeight( + ElementsAttr attr, int quant_dim, bool symmetric, unsigned num_bits, + bool is_signed, bool narrow_range, bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); + +// Returns the quantized type of a bias input, given the quantized types of +// other operands which are multiply-accumulated (the bias is added to the +// accumulated value). +quant::QuantizedType GetUniformQuantizedTypeForBias( + const std::vector& op_types, int adjusted_quant_dim, + bool legacy_float_scale = false); + +// Gets quantization scale specs (e.g. fixed output range, same result and +// operand scales) from the default quantization interfaces. The op should +// outlive returned spec for its interface methods to be properly referenced. +std::unique_ptr GetDefaultQuantScaleSpec(Operation* op); + +// The function might contain more stats ops than required, and it will +// introduce requantize if the calibration stats have conflicts. This method +// tries to remove all the redundant stats ops. +bool RemoveRedundantStatsOps(mlir::func::FuncOp func, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter = + GetDefaultQuantScaleSpec); + +// Given quantization parameters for int8, compute the quantization parameters +// for uint if it is required, and wrap the result in an UniformQuantizedType. +quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width, + Type tensor_type, double scale, + int64_t zero_point, + int64_t storage_min, + int64_t storage_max); + +quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width, + Type tensor_type, double scale, + int64_t zero_point); + +// Extracts min and max values from the DenseFPElementsAttr, and stores them +// into `mins` and `maxs`. When mins and maxs are extracted per-channel, +// `dim_size` is number of channels and `slice_size` is the size of slice per +// each channel. When `symmetric` is true, the range is expanded to [-M, M]. +void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size, + int slice_size, bool symmetric, + SmallVectorImpl& mins, + SmallVectorImpl& maxs); + +// Returns the quantized type for the +// input_type/min/max/storage_type_width/narrow_range. 
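Editorial sketch: the factory functions above (and `GetQuantizedType`, declared below) turn calibrated min/max ranges into quantization parameters. The standard affine formula is shown here with concrete numbers for orientation only; the vendored helpers may round or nudge slightly differently:

#include <cmath>
#include <cstdio>

int main() {
  const double rmin = -1.0, rmax = 3.0;  // calibrated real-valued range
  const int qmin = -128, qmax = 127;     // signed 8-bit storage
  const double scale = (rmax - rmin) / (qmax - qmin);        // 4 / 255 ~= 0.0157
  const long zero_point = std::lround(qmin - rmin / scale);  // -64
  // A real value r is then stored as q = round(r / scale) + zero_point.
  std::printf("scale=%.6f zero_point=%ld\n", scale, zero_point);
  return 0;
}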
+Type GetQuantizedType(Builder builder, Type input_type, ArrayRef min, + ArrayRef max, int quant_dim, + int storage_type_width, bool narrow_range, bool is_signed, + bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/test_base.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/test_base.h new file mode 100644 index 00000000..f33e586c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -0,0 +1,87 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ + +#include + +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir::quant { + +using ::testing::Test; + +class QuantizationTestBase : public Test { + protected: + QuantizationTestBase() + : ctx_(stablehlo::CreateMlirContextForQuantization()), + builder_(ctx_.get()) { + ctx_->loadDialect< + arith::ArithDialect, mlir::stablehlo::StablehloDialect, + func::FuncDialect, TF::TensorFlowDialect, TFL::TensorFlowLiteDialect, + tf_saved_model::TensorFlowSavedModelDialect, + tf_executor::TensorFlowExecutorDialect, quant::QuantDialect, + quantfork::QuantizationForkDialect>(); + } + + // Parses `module_op_str` to create a `ModuleOp`. + OwningOpRef ParseModuleOpString( + const absl::string_view module_op_str) { + return parseSourceString(module_op_str, ctx_.get()); + } + + // Convenience function that returns the first operation of type `OpT` from + // the `@main` function in `module_op`. 
Useful when testing with a text + // representation of a `ModuleOp` containing a single function `@main`. + // Returns `failure` iff there is no `@main` or no such operation is found in + // `@main`. + template + FailureOr FindFirstOpFromMainFunc(ModuleOp module_op) { + func::FuncOp main_func_op = FindMainFuncOp(module_op); + if (main_func_op == nullptr) return failure(); + + auto ops = main_func_op.getOps(); + if (ops.empty()) return failure(); + + return *ops.begin(); + } + + std::unique_ptr ctx_; + OpBuilder builder_; +}; + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TEST_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h new file mode 100644 index 00000000..99815f73 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h @@ -0,0 +1,120 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_UNIFORM_QUANTIZED_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_UNIFORM_QUANTIZED_TYPES_H_ + +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace quant { + +// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` +// values. The produced type has f32 as its expressed type and i8 as its +// storage type. The available values use the full range of the storage value, +// i.e. [-128, 127]. Assumes asymmetric quantization, meaning the zero point +// value can be a non-zero value. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. +UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc, + MLIRContext& context, + double scale, + int64_t zero_point, + bool narrow_range = false); + +// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` +// values. The produced type has f32 as its expressed type and i32 as its +// storage type. The available values use the full range of the storage value. +// Assumes asymmetric quantization, meaning the zero point value can be +// a non-zero value. 
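Editorial usage sketch for the i8 factory declared above (the i32 variant follows below). `loc` and `ctx` are assumed to come from the caller, and the scale/zero-point values are illustrative:

UniformQuantizedType MakeExampleI8Type(Location loc, MLIRContext& ctx) {
  // Full-range asymmetric i8 storage for f32 values: storage in [-128, 127].
  // Passing narrow_range = true instead restricts storage to [-127, 127], the
  // symmetric mapping typically used for weights.
  return CreateI8F32UniformQuantizedType(loc, ctx, /*scale=*/0.0157,
                                         /*zero_point=*/-64);
}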
+UniformQuantizedType CreateI32F32UniformQuantizedType(Location loc, + MLIRContext& context, + double scale, + int64_t zero_point); + +// Creates a `UniformQuantizedPerAxisType` with the given `scales` and +// `zero_points` values. The produced type has f32 as its expressed type and +// i8 as its storage type. The available values use the full range of the +// storage value, i.e. [-128, 127]. Assumes asymmetric quantization, meaning the +// zero point values can be non-zero values. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. +UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension, + bool narrow_range = false); + +// Creates a `UniformQuantizedPerAxisType` with the given `scales` and +// `zero_points` values. The produced type has f32 as its expressed type and +// i32 as its storage type. The available values use the full range of the +// storage value. Assumes asymmetric quantization, meaning the +// zero point values can be non-zero values. +UniformQuantizedPerAxisType CreateI32F32UniformQuantizedPerAxisType( + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension); + +bool IsStorageTypeI8(QuantizedType quantized_type); + +bool IsStorageTypeI32(QuantizedType quantized_type); + +bool IsExpressedTypeF32(QuantizedType quantized_type); + +// Given a value, extract the `ElementType`. +// `value` should be a non-null `TensorType`. +inline Type GetElementType(const Value value) { + return mlir::cast(value.getType()).getElementType(); +} + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedPerAxisType(Type type); + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedPerAxisType(Type type); + +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type); + +// Returns true if a type is quantized tensor type. +bool IsQuantizedTensorType(Type type); + +// Returns true if all operands and results are quantized. +bool IsOpFullyQuantized(Operation* op); + +// Returns true iff none among operand and result tensors are quantized. 
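Editorial sketch of how the predicates above compose; this is one plausible implementation written against the declared helpers, not necessarily the one in the vendored sources (`IsOpNotQuantized` itself is declared immediately below):

#include "llvm/ADT/STLExtras.h"

bool IsOpFullyQuantizedSketch(Operation* op) {
  const auto is_quantized = [](Type type) { return IsQuantizedTensorType(type); };
  return llvm::all_of(op->getOperandTypes(), is_quantized) &&
         llvm::all_of(op->getResultTypes(), is_quantized);
}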
+bool IsOpNotQuantized(Operation* op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_UNIFORM_QUANTIZED_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h new file mode 100644 index 00000000..9e1950af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h @@ -0,0 +1,77 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ + +#include +#include +#include + +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace stablehlo::quantization { + +// Calculates the bin width from the range and expected number of bins. The +// bin width is formalized to the form of 2^n. As a consequence, the actual +// number of bins might be smaller than the given `num_bins`. +inline float CalculateBinWidth(const float min_value, const float max_value, + const int32_t num_bins) { + const float raw_bin_width = (max_value - min_value) / num_bins; + return std::pow(2, std::ceil(std::log2(raw_bin_width))); +} + +// Calculates the lower bound of the histogram. The lower bound is in form of +// `N * bin_width`. +inline float CalculateLowerBound(const float min_value, const float bin_width) { + return std::floor(min_value / bin_width) * bin_width; +} + +// Calculates the bin index of the current value. +inline int32_t CalculateBinIndex(const float value, const float lower_bound, + const float bin_width) { + return std::floor((value - lower_bound) / bin_width); +} + +// Same as `CalculateBinIndex` but clamps to avoid out-of-bound. +inline int32_t CalculateBinIndexSafe(const float value, const float lower_bound, + const float bin_width, + const int32_t num_bins) { + const int32_t bin_index = CalculateBinIndex(value, lower_bound, bin_width); + return std::clamp(bin_index, 0, num_bins - 1); +} + +// Checks if the given method is a histogram-based calibration method. +inline bool IsHistogramCalibration( + const CalibrationOptions::CalibrationMethod method) { + return method == + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE || + method == + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE || + method == CalibrationOptions:: + CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY || + method == + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC; +} + +// Gets the number of bins for the given calibration method. 
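Editorial worked example for the histogram helpers above (`GetNumBins` follows below); the numbers are chosen for illustration:

#include <cassert>

void HistogramParameterExample() {
  // Calibrated range [-0.3, 12.5] with 512 requested bins.
  const float bin_width = CalculateBinWidth(-0.3f, 12.5f, 512);
  // Raw width 12.8 / 512 = 0.025 is rounded up to the next power of two,
  // 2^-5 = 0.03125, so fewer than 512 bins are actually needed.
  assert(bin_width == 0.03125f);

  const float lower_bound = CalculateLowerBound(-0.3f, bin_width);
  // floor(-0.3 / 0.03125) * 0.03125 = -10 * 0.03125 = -0.3125.
  assert(lower_bound == -0.3125f);

  // A sample value of 1.0 lands in bin floor((1.0 + 0.3125) / 0.03125) = 42.
  assert(CalculateBinIndexSafe(1.0f, lower_bound, bin_width, 512) == 42);
}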
+inline int32_t GetNumBins(const CalibrationOptions& calib_opts) { + return IsHistogramCalibration(calib_opts.calibration_method()) + ? calib_opts.calibration_parameters().num_bins() + : 0; +} + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_CALIBRATION_PARAMETERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h new file mode 100644 index 00000000..03d2dd93 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h @@ -0,0 +1,122 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_COMPONENT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_COMPONENT_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::quant::stablehlo { + +// Performs post-calibration graph transformation as part of post-training +// static-range quantization. +// +// The resulting `ModuleOp` contains quantized StableHLO ops serialized in +// `TF::XlaCallModuleOp`s. They are quantized using the statistics collected +// after the calibration step, corresponding to each `TF::CustomAggregatorOp`s +// in the input module op. +// +// TODO: b/320607042 - Add tests for this component on the python layer. +class CalibrationComponent : public Component { + public: + // Name of the post-training quantization post-calibration step. Used for + // debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_calibration"; + + // `CalibrationComponent` ctor with necessary information required to run + // calibration on a `ModuleOp`. Meta information like `function_aliases`, + // `tags`, `signature_def_map`, and `signature_keys` are required to properly + // save and load the module_op to and from SavedModel. + // `representative_dataset_file_map` contains information about the + // calibration dataset. 
+ CalibrationComponent( + absl::Nonnull ctx, + absl::Nonnull + py_function_lib, + absl::string_view src_saved_model_path, + absl::flat_hash_map function_aliases, + std::unordered_set tags, + absl::flat_hash_map + signature_def_map, + std::vector signature_keys); + + // Runs calibration on `module_op` and returns a calibrated ModuleOp with + // calibrated statistics embedded. + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + // Exports `module_op` to SavedModel at `dst_saved_model_path`. This is used + // to export the pre-calibrated `module_op` to SavedModel so that the + // calibration process can use it to load and run the graph with the + // representative dataset. Returns a failure status if the export fails. + absl::Status ExportToSavedModel(ModuleOp module_op, + absl::string_view calibration_data_dir, + bool force_regenerate_calibration_data, + absl::string_view dst_saved_model_path); + + // Imports the SavedModel at `calibrated_saved_model_path` to `ModuleOp` after + // running calibration. + absl::StatusOr ImportCalibratedSavedModel( + absl::string_view calibrated_saved_model_path); + + absl::Nonnull ctx_; + + // Contains function implementations from the python layer. Should be injected + // from the python level using pybind11. + absl::Nonnull + py_function_lib_; + + // Path to the pre-calibrated SavedModel. + std::string src_saved_model_path_; + + // Function alias mapping for pre-calibrated SavedModel. Used to preserve + // aliased functions. + absl::flat_hash_map function_aliases_; + + // Tags to identify the MetaGraphDef to load from a SavedModel. + const std::unordered_set tags_; + + const absl::flat_hash_map + signature_def_map_; + + // Signature keys to identify the functions to load & quantize. + const std::vector signature_keys_; +}; + +// Runs passes to prepare the calibration model. +absl::Status RunCalibrationPasses(mlir::ModuleOp module_op, MLIRContext& ctx, + absl::string_view calibration_data_dir, + bool force_regenerate_calibration_data); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_COMPONENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h new file mode 100644 index 00000000..5302bad4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h @@ -0,0 +1,28 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_MIN_MAX_VALUE_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_MIN_MAX_VALUE_H_
+
+#include <utility>
+
+namespace stablehlo::quantization {
+
+// Represents the (min, max) value pair: the range of values observed after
+// calibration for quantization.
+using MinMaxValue = std::pair<float, float>;
+
+}  // namespace stablehlo::quantization
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_MIN_MAX_VALUE_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h
new file mode 100644
index 00000000..33357630
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h
@@ -0,0 +1,41 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_REPRESENTATIVE_DATASET_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_REPRESENTATIVE_DATASET_H_
+
+#include <string>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
+
+namespace stablehlo::quantization {
+
+// Translates a set of `RepresentativeDatasetConfig`s to a signature key ->
+// `RepresentativeDatasetFile` mapping. This is useful when passing
+// `RepresentativeDatasetConfig`s to components that accept the legacy
+// `RepresentativeDatasetFile` mapping.
+// Returns a non-OK status when there is a duplicate signature key among
+// `representative_dataset_configs`.
+absl::StatusOr<absl::flat_hash_map<
+    std::string, tensorflow::quantization::RepresentativeDatasetFile>>
+CreateRepresentativeDatasetFileMap(absl::Span<const RepresentativeDatasetConfig>
+                                       representative_dataset_configs);
+
+}  // namespace stablehlo::quantization
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_REPRESENTATIVE_DATASET_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h
new file mode 100644
index 00000000..41f78be3
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h
@@ -0,0 +1,50 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" + +namespace stablehlo::quantization { + +// Reads the calibration statistics from the given directory. +absl::StatusOr> +ReadStatistics(absl::string_view calibration_data_dir); + +// Adds calibrated min / max values to CustomAggregator nodes in `graph_def`. +// The min and max values will be added to the "min" and "max" attributes, +// respectively. `calibration_options` provides the strategy to retrieve min and +// max values. +absl::Status AddCalibrationStatistics( + mlir::ModuleOp module_op, absl::string_view calibration_data_dir, + const stablehlo::quantization::CalibrationOptions& calibration_options, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); + +// Checks if the model required calibration. +bool IsCalibrationRequired(mlir::ModuleOp module_op); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_STATISTICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h new file mode 100644 index 00000000..a1ddb5cb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_
+
+#include "absl/status/statusor.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
+
+namespace mlir::quant::stablehlo {
+
+// Component is a public abstraction for StableHLO Quantizer that represents
+// the most basic unit of action applied to the StableHLO graph. Derived
+// classes should override the `Run` method to implement the action.
+class Component {
+ public:
+  virtual ~Component() = default;
+
+  // Runs the action on the StableHLO graph passed in `module_op`. `config`
+  // should provide the information necessary to configure the action's
+  // behavior.
+  virtual absl::StatusOr<ModuleOp> Run(
+      ModuleOp module_op,
+      const ::stablehlo::quantization::QuantizationConfig& config) = 0;
+};
+
+}  // namespace mlir::quant::stablehlo
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_COMPONENT_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h
new file mode 100644
index 00000000..f668cacd
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h
@@ -0,0 +1,65 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_
+
+#include <optional>
+
+#include "absl/base/attributes.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
+
+namespace stablehlo::quantization {
+
+// Returns a copy of `user_provided_config` with default values populated for
+// the fields the user did not explicitly specify.
+QuantizationConfig PopulateDefaults(
+    const QuantizationConfig& user_provided_config);
+
+// Returns a copy of `QuantizationConfig` where presets are expanded and
+// transformed into other fields in `QuantizationConfig`.
+//
+// The expansion rules are as follows:
+// * StaticRangePtqPreset
+//   - The preset's `representative_datasets` field will be transferred to
+//     `QuantizationConfig.calibration_options.representative_datasets`, unless
+//     the user explicitly provided representative dataset configs to
+//     `calibration_options`. In that case, the explicit configs take
+//     precedence and the preset's configs are ignored.
+//   - For `QuantizationSpecs`, the expanded `QuantizationSpec`s will be
+//     populated first and user-provided `QuantizationSpec`s, if any, will be
+//     appended. This expresses the fact that user-provided specs take
+//     precedence.
+// * Preset unspecified
+//   - No-op.
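Editorial usage sketch for the configuration entry points above (`ExpandPresets` is declared immediately below). The defaults-then-presets ordering is an assumption for illustration, not a documented contract:

// `user_config` is a caller-provided QuantizationConfig.
QuantizationConfig NormalizeConfigSketch(const QuantizationConfig& user_config) {
  // Fill unset fields with defaults, then expand presets such as
  // StaticRangePtqPreset into explicit calibration options and specs.
  return ExpandPresets(PopulateDefaults(user_config));
}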
+QuantizationConfig ExpandPresets(const QuantizationConfig& config); + +// Returns whether a given QuantizationSpecs has the given quantization method. +bool HasQuantizationMethod(const QuantizationSpecs& specs, + Method::MethodCase method_case); + +// Convenience function for converting the optional `report_file_path` field to +// `std::optional`, where `std::nullopt` represents that the +// field is not explicitly set. The returned value is a reference type +// (`absl::string_view`) so its lifetime is bound to the input `config`. +inline std::optional GetReportFilePath( + const QuantizationConfig& config ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return config.has_report_file_path() + ? std::make_optional(config.report_file_path()) + : std::nullopt; +} + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h new file mode 100644 index 00000000..7d03564a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONTEXT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONTEXT_H_ + +#include + +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir::quant::stablehlo { + +// Creates an MLIRContext with the extensions required for quantization are +// registered. +inline std::unique_ptr CreateMlirContextForQuantization() { + DialectRegistry registry{}; + func::registerAllExtensions(registry); + return std::make_unique(registry); +} + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h new file mode 100644 index 00000000..feae1444 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h @@ -0,0 +1,31 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace stablehlo::quantization { + +// Disables debugging on `DumpTensor` ops. +void DisableDebugging(mlir::ModuleOp module_op); + +// Changes the filename from `unquantized_tensor_data.pb` to +// `quantized_tensor_data.pb`. +void ChangeToQuantizedFilename(mlir::ModuleOp module_op); + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_DEBUGGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h new file mode 100644 index 00000000..5796b18e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/graph_def.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ + +#include + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace stablehlo::quantization { + +// Mutates all `NodeDef`s in `graph_def` by applying `func`. It modifies the +// top-level `NodeDef`s as well as all `NodeDef`s in the function library. +// `func` should accept a `NodeDef` reference. +template >> +void MutateNodeDefs(tensorflow::GraphDef& graph_def, FuncT&& func) { + for (tensorflow::NodeDef& node_def : *graph_def.mutable_node()) { + func(node_def); + } + + for (tensorflow::FunctionDef& function_def : + *graph_def.mutable_library()->mutable_function()) { + for (tensorflow::NodeDef& node_def : *function_def.mutable_node_def()) { + func(node_def); + } + } +} + +} // namespace stablehlo::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_GRAPH_DEF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h new file mode 100644 index 00000000..39c99436 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h @@ -0,0 +1,73 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" + +namespace stablehlo::quantization::io { + +// Generates a unique local tmp file name. This function only generates the name +// (path) and doesn't actually creates the file. +absl::StatusOr GetLocalTmpFileName(tsl::Env* env); + +// Generates a unique local tmp file name. This function only generates the name +// (path) and doesn't actually creates the file. The default environment +// `tsl::Env::Default` is used to generate the name. +absl::StatusOr GetLocalTmpFileName(); + +// Creates a temporary directory on an environment defined by the implementation +// of `tsl::Env` and returns its path. Returns an InternalError status if +// failed. +absl::StatusOr CreateTmpDir(tsl::Env* env); + +// Creates a temporary directory and returns its path. Returns an InternalError +// status if failed. The file system used will be the default environment +// returned by `tsl::Env::Default`. +absl::StatusOr CreateTmpDir(); + +// Convenience function for writing string `data` to file without the need to +// pass `tsl::Env` instance. Internally it uses the default `tsl::Env::Default`. +absl::Status WriteStringToFile(absl::string_view file_path, + absl::string_view data); + +// Convenience function for reading string data from file at `file_path` without +// the need to pass `tsl::Env` instance. Internally it uses the default +// `tsl::Env::Default`. Returns an OK status with string data containing file +// contents. Returns non-ok status upon error, e.g. file doesn't exist. +absl::StatusOr ReadFileToString(absl::string_view file_path); + +// Lists all files and directories under the given directory. +absl::StatusOr> ListDirectory( + absl::string_view directory); + +template +absl::StatusOr ReadBinaryProto(const std::string& binary_file_path) { + MessageT message; + TF_RETURN_IF_ERROR( + tsl::ReadBinaryProto(tsl::Env::Default(), binary_file_path, &message)); + return message; +} + +} // namespace stablehlo::quantization::io + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_IO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h new file mode 100644 index 00000000..408152f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h @@ -0,0 +1,75 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PASS_PIPELINE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PASS_PIPELINE_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Adds passes for static-range quantization pre-calibration. Inserts ops +// required to collect tensor statistics. +void AddPreCalibrationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::CalibrationOptions& calibration_options, + const ::stablehlo::quantization::QuantizationSpecs& specs, + const ::stablehlo::quantization::DebuggerConfig& debugger_config); + +// Adds passes for static-range quantization post-calibration. Utilizes tensor +// statistics collected from the calibration step and performs quantization. +void AddPostCalibrationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::PipelineConfig& pipeline_config, + const ::stablehlo::quantization::QuantizationSpecs& specs); + +// Adds passes for weight-only quantization. +void AddWeightOnlyQuantizationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::QuantizationSpecs& quantization_specs, + const ::stablehlo::quantization::PipelineConfig& pipeline_config, + const ::stablehlo::quantization::DebuggerConfig& debugger_config); + +// Deserializes StableHLO functions serialized and embedded in XlaCallModuleOps. +void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm); + +// Legalizes shape/tensor/arith dialect ops to StableHLO for handling dynamic +// shapes, by going through a round-trip to MHLO. +void AddShapeLegalizationPasses(OpPassManager& pm); + +// Serializes the StableHLO module into a tf.XlaCallModuleOp for compatibility +// with passes that expect TF format. This also allows the StableHLO ops to be +// exported as a TF SavedModel. +void AddCallModuleSerializationPasses(OpPassManager& pm); + +// Passes for unpacking quantized ops to int valued StableHLO ops. This is +// useful when uniform quantized types are suboptimal for the hardware. It goes +// through a StableHLO <-> MHLO roundtrip to utilize the MHLOQuantToInt pass. +void AddStablehloQuantToIntPasses(OpPassManager& pm); + +// Processes tensors with NCHW format (== (batch, channel, height, weight)) by +// converting them to NHWC formats along with extra optimizations such as +// constant folding the transpose->convolution pattern. This is useful when +// downstream pipeline (e.g. XLA) is more optimized when accepting NHWC formats. +void AddProcessNchwTensorPasses(OpPassManager& pm); + +// Registers quantization pass pipelines. This is only required when running +// MLIR opt binaries and not required when adding passes programmatically. 
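Editorial sketch of driving one of the pipeline builders above from C++ (`RegisterPassPipelines` is declared just below). `module_op`, `pipeline_config`, and `specs` come from the caller, and an `absl/status/status.h` include is assumed:

absl::Status RunPostCalibrationSketch(
    mlir::ModuleOp module_op,
    const ::stablehlo::quantization::PipelineConfig& pipeline_config,
    const ::stablehlo::quantization::QuantizationSpecs& specs) {
  // PassManager derives from OpPassManager, so it can be handed to the
  // pipeline builders declared in this header.
  mlir::PassManager pm(module_op.getContext());
  AddPostCalibrationPasses(pm, pipeline_config, specs);
  if (mlir::failed(pm.run(module_op))) {
    return absl::InternalError("Post-calibration pipeline failed.");
  }
  return absl::OkStatus();
}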
+void RegisterPassPipelines(); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PASS_PIPELINE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h new file mode 100644 index 00000000..35b1082b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PERMUTATION_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PERMUTATION_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" // IWYU pragma: keep; required to include the definition of ArrayRef +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" // IWYU pragma: keep; required to include the definition of SmallVector +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::quant { + +// Permutes `values` with `permutation`. Returns the permuted values. Sizes of +// `values` and `permutation` must be equal, and the elements of `permutation` +// should be less than `values.size()`. +template , void>> +SmallVector Permute(const ArrayRef values, + const ArrayRef permutation) { + SmallVector permuted_values(/*Size=*/values.size(), /*Value=*/T{}); + for (auto [i, permutation_idx] : llvm::enumerate(permutation)) { + permuted_values[i] = std::move(values[permutation_idx]); + } + return permuted_values; +} + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PERMUTATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h new file mode 100644 index 00000000..6e376281 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h @@ -0,0 +1,59 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_POST_CALIBRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_POST_CALIBRATION_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +// Performs post-calibration graph transformation as part of post-training +// static-range quantization. +// +// The resulting `ModuleOp` contains quantized StableHLO ops serialized in +// `TF::XlaCallModuleOp`s. They are quantized using the statistics collected +// after the calibration step, corresponding to each `TF::CustomAggregatorOp`s +// in the input module op. +class PostCalibrationComponent : public Component { + public: + // Name of the post-training quantization post-calibration step. Used for + // debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_post_calibration"; + + explicit PostCalibrationComponent(absl::Nonnull ctx); + + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + void AddPasses( + OpPassManager& pm, + const ::stablehlo::quantization::QuantizationSpecs& specs, + const ::stablehlo::quantization::PipelineConfig& pipeline_config) const; + + private: + absl::Nonnull ctx_; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_POST_CALIBRATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h new file mode 100644 index 00000000..bdc61baf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
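A sketch of how the post-calibration step above might be driven directly from C++, assuming the `MLIRContext` already has the required dialects loaded and that `Run` yields the transformed `ModuleOp` on success (the exact `StatusOr` payload is not visible in this diff).

#include "absl/status/statusor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

// Runs the post-calibration component on a module that already carries
// statistics collected during the calibration step.
inline absl::StatusOr<mlir::ModuleOp> RunPostCalibration(
    mlir::MLIRContext& ctx, mlir::ModuleOp module_op,
    const ::stablehlo::quantization::QuantizationConfig& config) {
  mlir::quant::stablehlo::PostCalibrationComponent component(&ctx);
  return component.Run(module_op, config);
}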
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PRE_CALIBRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PRE_CALIBRATION_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Performs pre-calibration graph transformation as part of post-training +// static-range quantization. + +// The resulting `ModuleOp` contains `TF::CustomAggregatorOp`s for collecting +// quantization statistics, along with `TF::XlaCallModuleOp`s that correspond to +// lifted quantizable functions. +class PreCalibrationComponent : public Component { + public: + // Name of the post-training quantization pre-calibration step. Used for + // debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_pre_calibration"; + + explicit PreCalibrationComponent(absl::Nonnull ctx); + + absl::StatusOr Run( + ModuleOp, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + absl::Nonnull ctx_; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_PRE_CALIBRATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h new file mode 100644 index 00000000..8252dda6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.h @@ -0,0 +1,71 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ + +#include + +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +// A class that manages information about `QuantizableUnit`s post-quantization, +// internally in the form of `QuantizationUnits`. It is used to collect +// quantization summary from a quantized `ModuleOp` and emit it in a human- and +// machine-readable format. +class QuantizationReport { + public: + QuantizationReport() = default; + + // Initializes `QuantizationReport` by collecting `QuantizationResults` from + // `module_op`. 
+ explicit QuantizationReport(ModuleOp module_op); + + // Adds a `QuantizationResult` to the report. + void AddQuantizationResult( + ::stablehlo::quantization::QuantizationResult&& result); + + // Returns `QuantizationResults` that are registered in this report. + const ::stablehlo::quantization::QuantizationResults& GetQuantizationResults() + const { + return quantization_results_; + } + + // Returns a human-readable string representation of this report. + std::string ToString() const; + + // Prints a human-readable report to stdout. + void Print() const; + + // Saves the report to `file_path`. The textproto representation of + // `QuantizationResults` will be written to the file. Returns non-ok status + // when the file write fails. + absl::Status Save(StringRef file_path) const; + + private: + ::stablehlo::quantization::QuantizationResults CollectResultsFromModuleOp( + ModuleOp module_op) const; + + // Quantization results that are registered in this report. A quantization + // result may be added manually by calling `AddQuantizationResult`. + ::stablehlo::quantization::QuantizationResults quantization_results_; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_REPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h new file mode 100644 index 00000000..357c5b0e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.h @@ -0,0 +1,142 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functionalities for exporting MLIR ModuleOp to TensorFlow SavedModel. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_EXPORT_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" + +namespace mlir::quant::stablehlo { + +// Suffix string for the module export step. Used for debugging. 
+constexpr absl::string_view kExportStepSuffix = "_export"; + +// Options when running passes for exporting an MLIR ModuleOp. +struct ExportOptions { + // If set to `true`, it runs `DuplicateShapeDeterminingConstantsPass` before + // lowering to tf_executor dialect. + bool duplicate_shape_determining_constants = true; + + // If set to `true`, unfreezes constants into variables and saves them to a + // checkpoint file. Setting this to `true` is an experimental feature that has + // no stability guarantees. + bool unfreeze_constants = false; + + // Path to the directory where checkpoint files are saved. + std::string checkpoint_dir = ""; + + // Name used to identify the ModuleOp this is exporting. Only used for + // debugging and does not modify the behavior of the export. + std::string debug_name = "stablehlo_quant"; +}; + +// Creates `ExportedModel` from `module_op`. `module_op` goes through post +// process passes before an `ExportModel` is created. +// TODO: b/329206105 - Add unit tests after decomposing post processing passes. +absl::StatusOr CreateExportedModel( + const std::vector& signature_keys, + const std::unordered_set& tags, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, + absl::string_view debug_name_prefix, + const absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND, ModuleOp module_op); + +// Factory function for `ExportedModel`. +[[nodiscard]] tensorflow::quantization::ExportedModel +CreateExportedModelFromGraphDef( + tensorflow::GraphDef&& graph_def, absl::string_view init_node_name, + absl::string_view checkpoint_dir, + std::optional saver_def, + const absl::flat_hash_map& function_aliases, + const std::vector& asset_file_defs); + +// Creates a new `SaverDef` instance, which contains information regarding +// checkpoint saving and restoring. This function returns a `SaverDef` instance +// with four fields populated: `version`, `filename_tensor_name`, +// `restore_op_name` and `save_tensor_name`. For valid quantized `graph_def` and +// `control_ret_node_names`, it should be able to retrieve the last three fields +// if there is at lest one variable in the graph. +// +// Returns a `std::nullopt` if there are no variables in the graph and no saving +// & restoring are required. Returns an `InternalError` status for when the +// required fields are only partially provided. +absl::StatusOr> CreateSaverDef( + const std::vector& control_ret_node_names, + const tensorflow::GraphDef& graph_def); + +// Adds passes for transforming the MLIR module op so that it can be exported +// back to GraphDef. Roughly, this consists of: +// 1) Inserting the @main function, which will become the main Graph. +// 2) Duplicating shape-determining constants. +// 3) Converting TF dialect -> tf_executor dialect. +// 4) Adding initializer function's ops into @main function for correct +// resource initialization when loading the exported model. +// +// Duplicating shape-determining constants is required to place constants that +// affect the shape of a tensor to be placed in the TPU graph instead of in the +// CPU graph, when the graph gets converted for TPU inference. This allows these +// constants to be known at XLA compilation time. +void AddExportPasses(mlir::PassManager& pm, + bool duplicate_shape_determining_constants); + +// Converts MLIR ModuleOp to `ExportedModel`. Returns `InternalError` status +// when the conversion fails. +// +// * `checkpoint_dir` is the directory where checkpoints where variable values +// are stored. 
This value will be fed to the "file_prefix" tensor to restore the +// variables. +// * `function_aliases` maps the actual function name to the function alias. +// This associates the quantized functions to the original functions' aliases. +// If there were no function aliases in the input model, this should be empty. +// * `asset_file_defs` include information about the assets, if any, that are +// used directly to initialize resources (like hash tables). If no assets are +// used in the model, this should be empty. +absl::StatusOr +ConvertMlirModuleToExportedModel( + mlir::ModuleOp module_op, absl::string_view checkpoint_dir, + const absl::flat_hash_map& function_aliases, + const std::vector& asset_file_defs); + +// Sets up and runs the passes for exporting `module_op`. The behavior of the +// exporting passes is controlled by `export_opts`. Returns `AssetFileDef`s that +// associate the input arguments of @main and the asset file names. Asset file +// names will be used to feed the corresponding tensors during initialization +// upon model loading. +// TODO: b/329206105 - Add unit tests after decomposing post processing passes. +absl::StatusOr> RunExportPasses( + const ExportOptions& export_opts, MLIRContext& ctx, ModuleOp module_op); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h new file mode 100644 index 00000000..9918b144 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h @@ -0,0 +1,90 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functionalities for importing MLIR ModuleOp from TensorFlow SavedModel. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::quant::stablehlo { + +// Represents a pair of `mlir::ModuleOp` and `tensorflow::SavedModelBundle`. 
The +// SavedModelBundle complements the imported ModuleOp by providing access to +// `tensorflow::Session` which may be useful when reading values from resources +// (e.g. `TF::VarHandleOp`s). +using ImportedMlirModuleOp = + std::pair, + std::unique_ptr<::tensorflow::SavedModelBundle>>; + +// Loads a SavedModel at `saved_model_path` and converts it to `mlir::ModuleOp`. +// +// `tags` identify the `tensorflow::MetaGraphDef` to load from the SavedModel. +// Similarly, `signature_keys` identify the functions (`SignatureDef`s) to load +// within the `MetaGraphDef`. `ctx` is the `MLIRContext`, which should outlive +// the returned `ModuleOp`, thus marked with the lifetime bound attribute. +// TODO: b/329206105 - Add unit tests after decomposing preprocessing passes. +absl::StatusOr SavedModelToMlirModuleOp( + absl::string_view saved_model_path, + const std::unordered_set& tags, + const std::vector& signature_keys, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Gets the function aliases from the SavedModel. +absl::StatusOr> +GetFunctionAliases(absl::string_view saved_model_path, + const std::unordered_set& tags); + +// Updates the function aliases. `module_op` may have different +// function names from the original model, so it re-associates the aliases +// with the new function names. Both the input `function_aliases` and the +// returned value are function name -> alias mappings. `function_aliases` is +// the function alias mapping of the original function. The original function's +// name is retrieved by looking at the "tf._original_func_name" string attribute +// attached to a `func::FuncOp`. +void UpdateFunctionAliases( + absl::flat_hash_map& function_aliases, + ModuleOp module_op); + +// Loads a SavedModel to `mlir::ModuleOp` and performs preprocesses including +// shape inference and graph freezing. +// TODO: b/329206105 - Add unit tests after decomposing preprocessing passes. +absl::StatusOr> ImportSavedModel( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, + absl::string_view mlir_dump_file_prefix, + absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_SAVED_MODEL_IMPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h new file mode 100644 index 00000000..69bd9da6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h @@ -0,0 +1,103 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
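A rough sketch of the import flow declared above: fetch the function aliases first, then import the SavedModel with them so the aliases can be re-associated after any function renaming. The paths, tags, and signature keys are placeholders, and `auto` is used because the exact `StatusOr` payloads are not spelled out in this diff.

#include <string>
#include <unordered_set>
#include <vector>

#include "absl/status/status.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

inline absl::Status ImportForQuantization(mlir::MLIRContext& ctx) {
  const std::string saved_model_path = "/tmp/fp32_saved_model";  // placeholder
  const std::unordered_set<std::string> tags = {"serve"};
  const std::vector<std::string> signature_keys = {"serving_default"};

  auto aliases =
      mlir::quant::stablehlo::GetFunctionAliases(saved_model_path, tags);
  if (!aliases.ok()) return aliases.status();

  ::stablehlo::quantization::QuantizationConfig config;
  auto module = mlir::quant::stablehlo::ImportSavedModel(
      saved_model_path, signature_keys, tags, config,
      /*mlir_dump_file_prefix=*/"quant_import", *aliases, ctx);
  if (!module.ok()) return module.status();

  // `*module` owns the imported ModuleOp for the rest of the pipeline.
  return absl::OkStatus();
}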
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_STATIC_RANGE_PTQ_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_STATIC_RANGE_PTQ_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::quant::stablehlo { + +// Component for static-range post-training quantization (PTQ). +// TODO: b/320607042 - Add tests in python level. +class StaticRangePtqComponent : public Component { + public: + // Name of this component. Used for debugging purposes. + static constexpr absl::string_view kName = "quant_static_range_ptq"; + + // Constructs `StaticRangePtqComponent` by creating three sub-components: + // `PreCalibrationComponent`, `CalibrationComponent`, and + // `PostCalibrationComponent`. These are stored in `sub_components_` in + // sequence. All arguments except `ctx` is used to initialize + // `CalibrationComponent`. For detailed explanation of each argument, see the + // comment of `CalibrationComponent`'s constructor. + StaticRangePtqComponent( + absl::Nonnull ctx, + absl::Nonnull + py_function_library, + absl::string_view src_saved_model_path, + std::vector signature_keys, + std::unordered_set tags, + absl::flat_hash_map + signature_def_map, + absl::flat_hash_map function_aliases); + + // Runs the static-range post-training quantization (PTQ) on `module_op`. + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + // A non-owning `MLIRContext`. This `MLIRContext` should exceed the lifetime + // of `StaticRangePtqComponent`. + absl::Nonnull ctx_; + // This component consists of three sub-components, `PreCalibrationComponent`, + // `CalibrationComponent`, and `PostCalibrationComponent`. + std::array, 3> sub_components_; +}; + +// Runs static-range post-training quantization (PTQ) on a SavedModel at +// `src_saved_model_path` and saves the resulting model to +// `dst_saved_model_path`. +// +// `quantization_config` configures the quantization behavior for the +// static-range PTQ. +// +// `signature_keys` specify the signatures that correspond to functions to be +// quantized. `signature_def_map` connects the signature keys to +// `SignatureDef`s. +// +// Returns a non-OK status when the quantization is not successful. 
+// LINT.IfChange +absl::Status QuantizeStaticRangePtq( + absl::string_view src_saved_model_path, + absl::string_view dst_saved_model_path, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); +// LINT.ThenChange(../python/pywrap_quantization.cc:static_range_ptq) + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_STATIC_RANGE_PTQ_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h new file mode 100644 index 00000000..c2166330 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TYPES_H_ + +#include + +namespace mlir::quant::stablehlo { + +// Introduces aliases for `std::string` to distinguish btw. function name and +// its alias, to prevent confusion when used together in a container. For +// example, it is easy to confuse function name -> alias mapping with alias -> +// function name mapping when both are just represented as `std::string`. +using FunctionAlias = std::string; +using FunctionName = std::string; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h new file mode 100644 index 00000000..bf23e932 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h @@ -0,0 +1,80 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
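A hedged sketch of calling the static-range PTQ entry point declared above. The SavedModel paths and signature key are placeholders, the `SignatureDef` map is assumed to be keyed by signature key, and a real caller receives `py_function_library` from the Python bindings.

#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

inline absl::Status RunStaticRangePtq(
    const tensorflow::quantization::PyFunctionLibrary& py_function_library,
    const absl::flat_hash_map<std::string, tensorflow::SignatureDef>&
        signature_def_map) {
  ::stablehlo::quantization::QuantizationConfig config;
  const std::vector<std::string> signature_keys = {"serving_default"};
  return mlir::quant::stablehlo::QuantizeStaticRangePtq(
      /*src_saved_model_path=*/"/tmp/fp32_model",        // placeholder
      /*dst_saved_model_path=*/"/tmp/quantized_model",   // placeholder
      config, signature_keys, signature_def_map, py_function_library);
}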
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_WEIGHT_ONLY_PTQ_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_WEIGHT_ONLY_PTQ_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::quant::stablehlo { + +// Performs int8 weight-only quantization on dot_general ops. +// +// The resulting `ModuleOp` contains quantized StableHLO ops serialized in +// `TF::XlaCallModuleOp`s. They are quantized using the weight constants, not +// relying on calibration. +class WeightOnlyPtqComponent : public Component { + public: + // Used for debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_weight_only"; + + explicit WeightOnlyPtqComponent(absl::Nonnull ctx); + + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + absl::Nonnull ctx_; +}; + +// Runs weight-only quantization on a SavedModel at +// `src_saved_model_path` and saves the resulting model to +// `dst_saved_model_path`. +// +// `quantization_config` configures the quantization behavior for the +// weight-only quantization. +// +// `signature_keys` specify the signatures that correspond to functions to be +// quantized. `signature_def_map` connects the signature keys to +// `SignatureDef`s. +// +// Returns a non-OK status when the quantization is not successful. +// LINT.IfChange +absl::Status QuantizeWeightOnlyPtq( + absl::string_view src_saved_model_path, + absl::string_view dst_saved_model_path, + ::stablehlo::quantization::QuantizationConfig quantization_config, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); +// LINT.ThenChange(../python/pywrap_quantization.cc:weight_only_ptq) + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_WEIGHT_ONLY_PTQ_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h new file mode 100644 index 00000000..e690e625 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.h @@ -0,0 +1,52 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_SAVE_REPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_SAVE_REPORT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project + +namespace mlir::quant::stablehlo { + +// A `PassInstrumentation` that saves quantization report to file after +// `QuantizeCompositeFunctionsPass` is run. It inspects the `ModuleOp` after +// quantization and analyzes the quantizable units and quantization methods +// used. The report file will be saved at the `file_path`. The report file +// contains textproto of `QuantizationResults`. `file_path`'s base directories +// should exist (this pass instrumentation will not `mkdir` them). +// +// See `QuantizationReport` for further details on the quantization report. +class SaveQuantizationReportInstrumentation : public PassInstrumentation { + public: + // `file_path` is the path to save the report file. The report file is in + // textproto format so a `.txtpb` extension is preferred but it doesn't result + // in error if other extension is used. This instrumentation will not be run + // if `file_path` is a `nullopt`. + explicit SaveQuantizationReportInstrumentation( + std::optional file_path); + + void runAfterPass(Pass* pass, Operation* op) override; + + private: + std::optional file_path_; // Path to file to save the report. +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_SAVE_REPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h new file mode 100644 index 00000000..6c688e82 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Returns StableHLO quantization specs for an op. 
+std::unique_ptr<OpQuantSpec> GetStableHloOpQuantSpec(Operation* op); + +// Returns quantization constraints (ex: fixed output, same scale) given +// a StableHLO op. +std::unique_ptr<OpQuantScaleSpec> GetStableHloQuantConstraints(Operation* op); + +// Checks if an op is quantizable in StableHLO quantizer. Argument op is not +// necessarily a StableHLO op. +bool IsOpQuantizableStableHlo(Operation* op); + +}  // namespace mlir::quant::stablehlo + +#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_STABLEHLO_OP_QUANT_SPEC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h new file mode 100644 index 00000000..9d19c6e7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h @@ -0,0 +1,62 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +    http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_BRIDGE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_BRIDGE_PASSES_H_ + +#include <memory> + +#define GEN_PASS_DECL +#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project +#include "mlir/Pass/Pass.h"  // from @llvm-project + +namespace mlir::quant::stablehlo { + +// Creates an instance of the ConvertTFQuantOpsToMHLOPass pass, which will +// convert TF uniform quantized ops to the corresponding quantized MHLO ops. +std::unique_ptr<OperationPass<func::FuncOp>> +CreateConvertTFQuantOpsToMHLOPass(); + +// TODO(b/288094093): Migrate uniform quantization legalization in a separate +// pass. +void PopulateLegalizeTfQuantizationPatterns(MLIRContext *context, +                                            RewritePatternSet *patterns); + +// Creates an instance of the ConvertTFQuantTypes pass, which will convert TF +// qint types to int types and surround TF UniformQuantized ops with qint <-> +// int casts. +std::unique_ptr<OperationPass<func::FuncOp>> CreateConvertTFQuantTypesPass(); + +// Creates an instance of the VerifyQuantLegalization pass, which verifies all +// quant ops and types are lowered. +std::unique_ptr<OperationPass<func::FuncOp>> +CreateVerifyQuantLegalizationPass(); + +// Add all passes for lowering TF quant ops and types to MHLO int. +void AddQuantizationLoweringPasses(mlir::OpPassManager &pm); + +// Creates an instance of OptimizeIntGraphPass, which optimizes the int graph +// lowered from the quantized graph.
+std::unique_ptr> CreateOptimizeIntGraphPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_CONVERTTFQUANTOPSTOMHLO +#define GEN_PASS_DECL_CONVERTTFQUANTTYPES +#define GEN_PASS_DECL_VERIFYQUANTLEGALIZATION +#define GEN_PASS_DECL_OPTIMIZEINTGRAPH +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc" +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_BRIDGE_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h new file mode 100644 index 00000000..d13c589c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h @@ -0,0 +1,62 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_PASSES_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Creates a pass that quantizes weight component of StableHLO graph. +std::unique_ptr> CreateQuantizeWeightPass( + const ::stablehlo::quantization::QuantizationComponentSpec& + quantization_component_spec = {}); + +// Converts a serialized StableHLO module to bfloat16 and output serialized +// module. +absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( + StringRef serialized_stablehlo_module); + +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const ::stablehlo::quantization::QuantizationSpecs& quantization_specs); + +// Creates a pass that inserts CalibrationStatisticsSaverOp. +std::unique_ptr> +CreateInsertCalibrationStatisticsSaverPass( + StringRef calibration_data_dir, + const std::vector& aggregator_ops_to_ignore); + +// Adds generated pass default constructors or options definitions. +#define GEN_PASS_DECL +// Adds generated pass registration functions. 
+#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h new file mode 100644 index 00000000..5e45d6d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -0,0 +1,258 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/stablehlo_op_quant_spec.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir::quant::stablehlo { + +// Checks whether an op is connected with a quantized composite function. If +// not, the same-scale op will not be quantized. This decision is based on the +// current assumption that the performance gain of the same-scale op itself +// could not beat the overhead of the quantize and dequantize routines need to +// be added around that op. When the assumption changes, this policy might +// change as well. 
+bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op); + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// Quantization method is determined by the `_quantization_method` attributes +// attached to each quantizable units. +// +// Template constraints are imposed as follows: +// +// * `QuantizeOpT` should have only one operand. +// * `DequantizeOpT` should have only one result. +template () && + DequantizeOpT::template hasTrait()>> +class StableHloQuantizationPattern : public OpRewritePattern { + public: + explicit StableHloQuantizationPattern(MLIRContext* context) + // Set the benefit to a large number so that it is always preferred. + : OpRewritePattern(context, /*benefit=*/300) {} + + private: + // Collects all candidate ops for quantization, which are the + // `dequantize_op`'s users. + FailureOr> CollectCandidateOps( + DequantizeOpT dequantize_op) const { + auto users = dequantize_op->getResult(0).getUsers(); + return SmallVector(users.begin(), users.end()); + } + + // Collects all candidate ops for quantization, which is the operand of + // `quantize_op`. If successful, this always returns one element which is the + // operand of `quantize_op`. + FailureOr> CollectCandidateOps( + QuantizeOpT quantize_op) const { + Value operand = quantize_op->getOperand(0); + if (QuantizedType::getQuantizedElementType(operand.getType())) { + // The input of the quantize op has already been quantized, i.e. + // rescale. + return failure(); + } + + Operation* operand_op = operand.getDefiningOp(); + if (operand_op == nullptr) { + // When `QuantizeOpT`'s operand does not have a defining op, it means it + // is a `BlockArgument`. The pattern does not match if there is no op to + // quantize. + return failure(); + } + + if (operand_op->hasTrait()) { + // Const-> QuantizeOp pattern will be handled separately. + return failure(); + } + + return SmallVector{operand_op}; + } + + LogicalResult matchAndRewrite(RootOpT op, + PatternRewriter& rewriter) const override { + // Collect all the candidate ops for quantization. + FailureOr> candidate_ops = CollectCandidateOps(op); + // Safeguard check to ensure that there is at least one quantizable op. + if (failed(candidate_ops) || candidate_ops->empty()) return failure(); + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* candidate_op : *candidate_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (isa(candidate_op)) { + return failure(); + } + + // If the op is terminator, we shouldn't rewrite. + if (candidate_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizableStableHlo(candidate_op)) { + return failure(); + } + + if (GetStableHloQuantConstraints(candidate_op) + ->has_same_scale_requirement && + !IsConnectedWithQuantizedCompsiteFunction(candidate_op)) { + return failure(); + } + + // Ops with regions will be quantized in a separate pattern. + if (isa(candidate_op)) { + return failure(); + } + + const bool weight_only_quantizable = + IsWeightOnlyQuantizableOp(*candidate_op); + + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. 
+ SmallVector inputs; + inputs.reserve(candidate_op->getNumOperands()); + for (auto operand : candidate_op->getOperands()) { + Type operand_type = operand.getType(); + if (mlir::isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + auto ele_type = + mlir::cast(operand.getType()).getElementType(); + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else if (weight_only_quantizable) { + inputs.push_back(operand); + } else { + return failure(); + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(candidate_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(candidate_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none type + // results. + if (mlir::isa(result_type)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + mlir::cast(result.getType()).getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && isa(*result.user_begin())) { + auto user = cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (weight_only_quantizable) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + rewriter.setInsertionPointAfter(candidate_op); + OperationState new_state(candidate_op->getLoc(), + candidate_op->getName().getStringRef(), inputs, + output_types, candidate_op->getAttrs()); + for (int i = 0; i < candidate_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + if (candidate_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(candidate_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + return success(); + } +}; + +// Populates common patterns that are usually compute heavy or memory bound. +void PopulateCommonQuantizationPatterns( + MLIRContext& ctx, RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight); + +// Populates conversion patterns for all quantizable ops, including +// ops that are not compute-heavy and data movement ops. 
+void PopulateAllQuantizablePatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_QUANTIZATION_PATTERNS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h new file mode 100644 index 00000000..a8a59d1c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TESTING_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TESTING_PASSES_H_ + +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep + +namespace mlir::quant::stablehlo::testing { + +// Identifies predefined `QuantizationSpecs` for +// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. The pass +// option argument is specified in line comments for each enum value. +enum class TestQuantizationSpecs { + kEmpty, // empty + kDisableAllDotGeneral, // disable-all-dot-general + kStaticRangePtqToAll, // static-range-ptq-to-all + kStaticRangePtqToComputeHeavy, // static-range-ptq-to-compute-heavy +}; + +// Adds generated pass default constructors or options definitions. +#define GEN_PASS_DECL +// Adds generated pass registration functions. +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h.inc" + +} // namespace mlir::quant::stablehlo::testing + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TESTING_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h new file mode 100644 index 00000000..ff724aba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
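A sketch of wiring the pattern-population helpers above into a greedy rewrite, which is a common way to apply a `RewritePatternSet`; whether the actual quantize pass uses the greedy driver is not shown in this diff.

#include <utility>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h"

inline mlir::LogicalResult ApplyQuantizationPatterns(mlir::ModuleOp module_op) {
  mlir::MLIRContext& ctx = *module_op.getContext();
  mlir::RewritePatternSet patterns(&ctx);
  mlir::quant::stablehlo::PopulateCommonQuantizationPatterns(
      ctx, patterns, /*enable_per_channel_quantized_weight=*/true);
  mlir::quant::stablehlo::PopulateAllQuantizablePatterns(ctx, patterns);
  return mlir::applyPatternsAndFoldGreedily(module_op, std::move(patterns));
}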
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PYTHON_PYWRAP_QUANTIZATION_LIB_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PYTHON_PYWRAP_QUANTIZATION_LIB_H_ + +// Contains mirror functions from StableHLO Quantizer to be exposed to python +// via `pywrap_quantization`. + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace stablehlo::quantization::pywrap { + +// Function used by the pywrap_quantization module to mirror +// `::mlir::quant::stablehlo::QuantizeStaticRangePtq`. +absl::Status PywrapQuantizeStaticRangePtq( + absl::string_view src_saved_model_path, + absl::string_view dst_saved_model_path, const QuantizationConfig& config, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); + +// Function used by the pywrap_quantization module to mirror +// `::mlir::quant::stablehlo::QuantizeWeightOnlyPtq`. +absl::Status PywrapQuantizeWeightOnlyPtq( + absl::string_view src_saved_model_path, + absl::string_view dst_saved_model_path, const QuantizationConfig& config, + const std::vector& signature_keys, + const absl::flat_hash_map& + signature_def_map, + const tensorflow::quantization::PyFunctionLibrary& py_function_library); + +// Function used by the pywrap_quantization module to mirror +// `::stablehlo::quantization::PopulateDefaults`. +QuantizationConfig PywrapPopulateDefaults( + const QuantizationConfig& user_provided_config); + +// Function used by the pywrap_quantization module to mirror +// `::stablehlo::quantization::ExpandPresets`. +QuantizationConfig PywrapExpandPresets(const QuantizationConfig& config); + +} // namespace stablehlo::quantization::pywrap + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PYTHON_PYWRAP_QUANTIZATION_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h new file mode 100644 index 00000000..d754be94 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h @@ -0,0 +1,31 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
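A sketch of the config-preparation flow these pywrap mirrors suggest: fill in defaults, then expand presets into concrete specs. The relative order of the two calls and what each populates are assumptions here; the config is left default-constructed.

#include "tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"

inline ::stablehlo::quantization::QuantizationConfig PrepareConfig(
    const ::stablehlo::quantization::QuantizationConfig& user_config) {
  namespace pywrap = ::stablehlo::quantization::pywrap;
  // Assumed ordering: defaults first, then preset expansion.
  const ::stablehlo::quantization::QuantizationConfig with_defaults =
      pywrap::PywrapPopulateDefaults(user_config);
  return pywrap::PywrapExpandPresets(with_defaults);
}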
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_QUANTIZE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_QUANTIZE_PASSES_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace stablehlo { +namespace quantization { +// Adds passes for quantization of individual quantizable components. +// (i.e. activation, weight, bias) +void AddQuantizationPasses(mlir::PassManager& pass_manager, + const QuantizationOptions& quantization_options); + +} // namespace quantization +} // namespace stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_QUANTIZE_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h new file mode 100644 index 00000000..2873b071 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_BFLOAT16_TYPE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_BFLOAT16_TYPE_H_ + +#include "mlir/IR/Types.h" // from @llvm-project + +namespace mlir::quant::stablehlo { + +// Returns true if the type or its element type is a float type with bit_width +// > 16. +bool IsLargeFloatType(Type type); + +// Converts large float type to bfloat16. Otherwise returns original type. +Type ToBfloat16Type(Type type); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_BFLOAT16_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h new file mode 100644 index 00000000..691d4c35 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/fill_quantization_options.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
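A small sketch exercising the bfloat16 helpers above on builder-created types; the expected results follow directly from the header comments (only float types wider than 16 bits are converted).

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h"

inline void Bfloat16TypeExample() {
  mlir::MLIRContext ctx;
  mlir::OpBuilder builder(&ctx);

  const mlir::Type f32 = builder.getF32Type();
  const bool is_large = mlir::quant::stablehlo::IsLargeFloatType(f32);  // true
  const mlir::Type bf16 = mlir::quant::stablehlo::ToBfloat16Type(f32);  // bf16

  const mlir::Type f16 = builder.getF16Type();
  // Already 16 bits wide, so it is returned unchanged.
  const mlir::Type unchanged = mlir::quant::stablehlo::ToBfloat16Type(f16);
  (void)is_large;
  (void)bf16;
  (void)unchanged;
}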
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_FILL_QUANTIZATION_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_FILL_QUANTIZATION_OPTIONS_H_ + +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +using ::stablehlo::quantization::QuantizationOptions; + +// Returns QuantizationOptions filled with detailed specs when the user +// specifies an optional preset method name. The preset methods are defined in +// quantization_options.proto. This function will only be executed if a user +// gives a preset method, not a custom method. +QuantizationOptions FillPresetQuantizationOptions( + QuantizationOptions quantization_options); + +// Looks up the activation bit width in the custom quantization method and +// writes it to `bit_width`. Returns success if such information exists; +// otherwise, returns failure. +LogicalResult GetActivationBitWidth(QuantizationOptions quantization_options, + int* bit_width); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_FILL_QUANTIZATION_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils.h new file mode 100644 index 00000000..f63e06a3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils.h @@ -0,0 +1,32 @@ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_MATH_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_MATH_UTILS_H_ + +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cstdint> + +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir::quant::stablehlo { + +// Decomposes a given floating point value `double_multiplier` into a +// normalized, quantized fraction and an integral power of two (`shift`). +LogicalResult QuantizeMultiplier(double double_multiplier, + int32_t& quantized_fraction, int32_t& shift); + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_MATH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h new file mode 100644 index 00000000..81dfb576 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
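`QuantizeMultiplier` above follows the usual fixed-point recipe: split a real-valued multiplier into an integer fraction and a power-of-two shift. A self-contained sketch of that decomposition; the 15-bit fraction width and the `DecomposeMultiplier` name are illustrative choices, not necessarily what the TensorFlow implementation uses:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decomposes m into q and shift such that m ~= q * 2^(shift - 15), i.e. q is
// a 15-bit fixed-point fraction of the multiplier.
bool DecomposeMultiplier(double multiplier, int32_t& quantized_fraction,
                         int32_t& shift) {
  if (multiplier < 0.0) return false;  // Only non-negative scales expected.
  int exponent = 0;
  const double fraction = std::frexp(multiplier, &exponent);  // in [0.5, 1)
  quantized_fraction = static_cast<int32_t>(std::round(fraction * (1 << 15)));
  shift = exponent;
  // Rounding may push the fraction up to exactly 2^15; renormalize.
  if (quantized_fraction == (1 << 15)) {
    quantized_fraction /= 2;
    ++shift;
  }
  return true;
}

int main() {
  int32_t q = 0, s = 0;
  DecomposeMultiplier(0.00372, q, s);  // e.g. an output rescale factor
  // Prints roughly: q=31206 shift=-8 approx=0.003720
  std::printf("q=%d shift=%d approx=%f\n", q, s, std::ldexp(q, s - 15));
}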
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_STABLEHLO_TYPE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_STABLEHLO_TYPE_UTILS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir::quant::stablehlo { + +// Checks if an op is from StableHLO dialect. +inline bool IsStablehloOp(Operation* op) { + return op->getDialect()->getNamespace() == + mlir::stablehlo::StablehloDialect::getDialectNamespace(); +} + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_STABLEHLO_TYPE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.h new file mode 100644 index 00000000..9eab6b00 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_TF_TYPE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_TF_TYPE_UTILS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir::quant::tensorflow { + +// GetDenseAttrFromTensorProtoAttr returns DenseElementsAttr from tensor proto. +FailureOr GetDenseAttrFromTensorProtoAttr( + llvm::StringRef mangled_tensor_proto, TensorType result_tensor_type); + +// Check if a type is TF qint type. +bool IsTFQintType(Type type); + +// Convert qint type to the corresponding int type. Return original type if it +// is not qint type. +Type GetIntTypeFromTFQint(Type type); + +// Check if an op is TF UniformQuantized op. 
+bool IsTFUniformQuantizedOp(Operation* op); + +} // namespace mlir::quant::tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UTILS_TF_TYPE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h new file mode 100644 index 00000000..f6a5da84 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_average_min_max.h @@ -0,0 +1,52 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_AVERAGE_MIN_MAX_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_AVERAGE_MIN_MAX_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace tensorflow { +namespace calibrator { + +using ::stablehlo::quantization::CalibrationOptions; + +// AverageMinMax calibration calculates the average of min and max values. +// average of min = sum of min values / number of samples +// average of max = sum of max values / number of samples +class CalibrationStatisticsCollectorAverageMinMax + : public CalibrationStatisticsCollectorBase { + public: + explicit CalibrationStatisticsCollectorAverageMinMax() { ClearData(); } + + void ClearData() override; + + void Collect(float min, float max, + absl::Span histogram) override; + + std::optional GetStatistics() const override; + + private: + CalibrationStatistics::AverageMinMaxStatistics average_min_max_statistics_; +}; +} // namespace calibrator +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_AVERAGE_MIN_MAX_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h new file mode 100644 index 00000000..9ce6a819 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_BASE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_BASE_H_ + +#include <cstdint> +#include <optional> + +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" + +namespace tensorflow { +namespace calibrator { + +// Abstract base class for CalibrationStatisticsCollector subclasses such as +// CalibrationStatisticsCollectorMinMax. Each subclass collects different +// statistics based on the calibration method. +class CalibrationStatisticsCollectorBase { + public: + // Collects data for calibration. + virtual void Collect(float min, float max, + absl::Span<const int64_t> histogram) = 0; + + virtual void ClearData() = 0; + // Returns the statistics needed for a given calibration method. + virtual std::optional<CalibrationStatistics> GetStatistics() const = 0; + virtual ~CalibrationStatisticsCollectorBase() = default; +}; + +} // namespace calibrator +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h new file mode 100644 index 00000000..84f641a5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_histogram.h @@ -0,0 +1,66 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
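The collector interface above (`Collect`, `ClearData`, `GetStatistics`) is what each calibration method implements, and the AverageMinMax collector declared earlier simply averages the per-batch min and max values it has seen. A standalone sketch of that bookkeeping, using a plain pair instead of the `CalibrationStatistics` proto so it compiles on its own; the class name and member layout are illustrative only:

#include <cstdint>
#include <cstdio>
#include <optional>
#include <utility>

class AverageMinMaxSketch {
 public:
  void ClearData() { min_sum_ = 0.0f; max_sum_ = 0.0f; num_samples_ = 0; }

  // Mirrors Collect(min, max, histogram); the histogram argument is unused by
  // this calibration method, so it is omitted here.
  void Collect(float min, float max) {
    min_sum_ += min;
    max_sum_ += max;
    ++num_samples_;
  }

  // average of min = sum of min values / number of samples (likewise for max).
  std::optional<std::pair<float, float>> GetStatistics() const {
    if (num_samples_ == 0) return std::nullopt;
    return std::make_pair(min_sum_ / num_samples_, max_sum_ / num_samples_);
  }

 private:
  float min_sum_ = 0.0f;
  float max_sum_ = 0.0f;
  int64_t num_samples_ = 0;
};

int main() {
  AverageMinMaxSketch collector;
  collector.Collect(-1.0f, 4.0f);
  collector.Collect(-3.0f, 2.0f);
  if (const auto stats = collector.GetStatistics()) {
    // Prints: avg_min=-2.0 avg_max=3.0
    std::printf("avg_min=%.1f avg_max=%.1f\n", stats->first, stats->second);
  }
}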
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_HISTOGRAM_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_HISTOGRAM_H_ + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace tensorflow { +namespace calibrator { + + +class CalibrationStatisticsCollectorHistogram + : public CalibrationStatisticsCollectorBase { + public: + explicit CalibrationStatisticsCollectorHistogram() { ClearData(); } + + void ClearData() override; + + void Collect(float min, float max, + absl::Span histogram) override; + + std::optional GetStatistics() const override; + + private: + // Expands the histogram so the lower_bound and upper_bound can fit in the + // histogram. Returns the indexes associated to those values. + std::pair ExpandHistogramIfNeeded(float lower_bound, + float upper_bound); + + // hist_freq_[i] saves frequency of range [bins[i], bins[i + 1]). + // bins[i] = lower_bound_ + bin_width_ * i + // bins[i + 1] = lower_bound_ + bin_width_ * (i + 1) + std::deque hist_freq_; + + // Width of bin + float bin_width_; + + // The first bin's left value. [left, right) + float lower_bound_; +}; + +} // namespace calibrator +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_HISTOGRAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h new file mode 100644 index 00000000..8ee545e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_min_max.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
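The histogram collector above keeps `hist_freq_[i]` for the half-open range [`bins[i]`, `bins[i + 1]`) with `bins[i] = lower_bound_ + bin_width_ * i`, so locating the bin for a sample is a single floor division. A small worked example with made-up bounds:

#include <cmath>
#include <cstdio>

int main() {
  const float lower_bound = -2.0f;
  const float bin_width = 0.5f;

  const float value = 1.3f;
  // index = floor((value - lower_bound) / bin_width)
  const int index =
      static_cast<int>(std::floor((value - lower_bound) / bin_width));
  const float bin_left = lower_bound + bin_width * index;
  const float bin_right = lower_bound + bin_width * (index + 1);

  // Prints: value 1.30 falls into bin 6: [1.00, 1.50)
  std::printf("value %.2f falls into bin %d: [%.2f, %.2f)\n", value, index,
              bin_left, bin_right);
}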
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_MIN_MAX_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_MIN_MAX_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_collector_base.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace tensorflow { +namespace calibrator { + +using ::stablehlo::quantization::CalibrationOptions; + +// MinMax calibration calculates the global min and global max values. +// global min = min of given sample inputs +// global max = max of given sample inputs +class CalibrationStatisticsCollectorMinMax + : public CalibrationStatisticsCollectorBase { + public: + explicit CalibrationStatisticsCollectorMinMax() { ClearData(); } + + void ClearData() override; + + void Collect(float min, float max, + absl::Span histogram) override; + + std::optional GetStatistics() const override; + + private: + CalibrationStatistics::MinMaxStatistics min_max_statistics_; +}; + +} // namespace calibrator +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CALIBRATOR_CALIBRATION_STATISTICS_COLLECTOR_MIN_MAX_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h new file mode 100644 index 00000000..884ac938 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONST_OP_SIZE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONST_OP_SIZE_H_ + +#include + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace quant { + +// Returns the size in bytes of the underlying data of `const_op`. If the +// underlying type's size cannot be determined, it assumes 4 bytes per element. 
+int64_t GetSizeInBytes(TF::ConstOp const_op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONST_OP_SIZE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h new file mode 100644 index 00000000..d0a4157b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONSTANT_FOLD_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONSTANT_FOLD_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace quant { + +// Applies constant folding recursively if the operation and all of its operands +// are foldable. Returns the constants generated by constant-folding or the +// original operation's outputs if not folded. +SmallVector ConstantFoldOpIfPossible(Operation* op); + +// This pattern tries to constant-fold the quantizable operands of supported +// TF operations. +struct ConstantFoldQuantizableOperands : public RewritePattern { + public: + explicit ConstantFoldQuantizableOperands(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONSTANT_FOLD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h new file mode 100644 index 00000000..7ff335fa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
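`ConstantFoldQuantizableOperands` above is a plain `RewritePattern`, so it can be driven by MLIR's greedy rewriter like any other pattern. A minimal sketch, assuming the TensorFlow/MLIR build environment; the `FoldQuantizableOperands` wrapper is a hypothetical illustration, not an API from these headers:

#include <utility>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h"

namespace {

// Applies the constant-folding pattern across a module until fixpoint.
void FoldQuantizableOperands(mlir::ModuleOp module_op) {
  mlir::MLIRContext* ctx = module_op.getContext();
  mlir::RewritePatternSet patterns(ctx);
  patterns.add<mlir::quant::ConstantFoldQuantizableOperands>(ctx);
  if (mlir::failed(mlir::applyPatternsAndFoldGreedily(module_op,
                                                      std::move(patterns)))) {
    module_op.emitWarning("constant folding did not converge");
  }
}

}  // namespace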
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONVERT_ASSET_ARGS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONVERT_ASSET_ARGS_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::quant { + +// Converts arguments of the @main function that are bound to +// `tf_saved_model::AssetOp`s into regular tensor args. Returns `AsestFileDef`s +// that associates the arg with the asset. +// +// In detail, this function performs the following: +// * Replaces "tf_saved_model.bound_input" attributes to +// "tf_saved_model.index_path", if the bound input is attached to the +// `tf_saved_model::AssetOp`. +// * Strips the "assets/" prefix of the filename when setting it to +// `AssetFileDef`. +FailureOr> ConvertAssetArgs( + ModuleOp module_op); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_CONVERT_ASSET_ARGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h new file mode 100644 index 00000000..32fb6f89 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h @@ -0,0 +1,54 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_QUANTIZATION_UNIT_LOC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_QUANTIZATION_UNIT_LOC_H_ + +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace quant { + +// QuantizationUnitLoc uses CallSiteLoc as the base class so it can be printed +// with AsmPrinter and used to set the node name in MLIR to GraphDef exporter. +// The callee is named as `node_name@func_name` with child loc named as +// `op_type` while the caller is the quantization unit. +class QuantizationUnitLoc : public CallSiteLoc { + public: + using QuantizationUnit = + tensorflow::quantization::UnitWiseQuantizationSpec::QuantizationUnit; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizationUnitLoc) + + QuantizationUnitLoc(MLIRContext* context, const QuantizationUnit& unit); + + // Checks if the given location is QuantizationUnitLoc. Users could call + // `isa(loc)` to check if the type matches. 
+ static bool classof(Attribute attr); +}; + +// Finds the QuantizationUnit from location info. +std::optional +FindQuantizationUnitFromLoc(Location loc); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_QUANTIZATION_UNIT_LOC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h new file mode 100644 index 00000000..06db2acb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h @@ -0,0 +1,77 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_RUN_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_RUN_PASSES_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace quantization { + +// Runs MLIR passes with `module_op`. The passes are added by calling +// `add_passes_func`, which is a callable receiving mlir::PassManager& as its +// only argument. `name` identifies the set of passes added by `add_passes_func` +// and is used for debugging. Changing the `name` does not modify the behavior +// of the passes. +// +// It will try to dump intermediate MLIRs if certain conditions are met. See the +// description from `MaybeEnableIrPrinting` for the details about the +// conditions. +// +// Returns a non-OK status when the pass run fails or it fails to create an MLIR +// dump file. +template +absl::Status RunPasses(const absl::string_view name, FuncT add_passes_func, + mlir::MLIRContext& ctx, mlir::ModuleOp module_op) { + mlir::PassManager pm{&ctx}; + add_passes_func(pm); + + mlir::StatusScopedDiagnosticHandler diagnostic_handler{&ctx}; + TF_RETURN_IF_ERROR(MaybeEnableIrPrinting(pm, name)); + + if (failed(pm.run(module_op))) { + return absl::InternalError( + absl::StrFormat("Failed to run pass: %s. %s", name, + diagnostic_handler.ConsumeStatus().message())); + } + + return absl::OkStatus(); +} + +// Runs MLIR passes with `module_op` on a `pass_manager`. +// +// It will try to dump intermediate MLIRs if certain conditions are met. See the +// description from `MaybeEnableIrPrinting` for the details about the +// conditions. +// +// Returns a non-OK status when the pass run fails or it fails to create an MLIR +// dump file. 
+absl::Status RunPassesOnModuleOp( + std::optional mlir_dump_file_name, + mlir::PassManager& pass_manager, mlir::ModuleOp module_op); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_RUN_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h new file mode 100644 index 00000000..124f2a5b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_SAVE_VARIABLES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_SAVE_VARIABLES_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace tensorflow { +namespace quantization { + +// Saves variables in `module_op` to the checkpoint file inside `prefix`. +// It finds variables that are initialized with "tf.AssignVariableOp" inside the +// initializer function with type "restore_op". The "tf.Const"s used to +// initialize the variables are saved. This function does not modify the +// `module_op`. Returns a list of saved names of the saved variables. +absl::StatusOr> SaveVariablesToCheckpoint( + absl::string_view prefix, mlir::ModuleOp module_op); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_SAVE_VARIABLES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h new file mode 100644 index 00000000..38a9c4fa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h @@ -0,0 +1,45 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
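The `RunPasses` template shown above accepts any callable that populates the `PassManager`, so a pipeline can be assembled inline with a lambda. A sketch under that assumption (the chosen passes and the `RunCleanupPipeline` wrapper are arbitrary illustrations; `name` only labels the IR dumps):

#include "absl/status/status.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h"

namespace {

absl::Status RunCleanupPipeline(mlir::MLIRContext& ctx,
                                mlir::ModuleOp module_op) {
  return tensorflow::quantization::RunPasses(
      /*name=*/"cleanup_pipeline",
      /*add_passes_func=*/
      [](mlir::PassManager& pm) {
        pm.addPass(mlir::createCanonicalizerPass());
        pm.addPass(mlir::createSymbolDCEPass());
      },
      ctx, module_op);
}

}  // namespace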
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace tensorflow { +namespace quantization { + +// Enables IR printing for `pm`. When the passes are run, each pass will dump to +// its own file with prefix `file_name_prefix`. +void EnableIrPrinting(mlir::PassManager &pm, + absl::string_view file_name_prefix); + +// If verbosity level >= 1, this will dump intermediate IRs of passes to a file. +// The dumped mlir files with be under a directory determined by +// the TF_QUANT_MLIR_DUMP_PREFIX env variable. The PassManager will dump to a +// new file for each pass. The file name will have the format +// {file_name_prefix}_{pass_number}_{pass_name}_{before|after}.mlir. +// * `file_name_prefix` is from input. +// * `pass_number` increments from 1 for each pass. +// * `pass_name` is the name of the pass. +// * `before|after` indicates whether the dump occurs before or after the pass. +absl::Status MaybeEnableIrPrinting(mlir::PassManager &pm, + absl::string_view file_name_prefix); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_DEBUGGING_MLIR_DUMP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h new file mode 100644 index 00000000..44c60b61 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functions for quantization specifications of TensorFlow ops. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_OP_QUANT_SPEC_H_ + +#include +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace quant { + +// Check if the op has data movement trait. Ops with this trait do not perform +// any computations but just move data and has one result operand. +bool IsOpWithDataMovementTrait(Operation* op); + +// Check if the op is quantizable. Currently, the scope of quantizable op is +// limited to compute intense operations and the ops that supports integer +// operands. 
+bool IsOpWithQuantizableTrait(Operation* op); + +// Check if the op's operand accepts int8 type. +bool IsOpWithInt8TypeOperand(Operation* op); + +// Check if the data is in quantizable precision. Currently, a value in f32 or +// bf16 is quantizable. +bool IsValueWithQuantizablePrecision(Value val); + +std::optional +GetWeightComponentSpec( + const tensorflow::quantization::QuantizationOptions& quantization_options); + +// Returns the spec for the given operation that can be used for both of +// dynamic and static range quantization. +std::unique_ptr GetTFOpQuantSpec(Operation* op); + +// Returns quantization scale specs (fixed output, same scale) for a TF op. +std::unique_ptr GetTfQuantScaleSpec(Operation* op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_OP_QUANT_SPEC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h new file mode 100644 index 00000000..bc6031ee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file provides a list of supported quantization algorithms in the format +// of "applyQuantization". +// After applying the function, a quantize/dequantize functions are created +// where the body of each function contains a specific quantization algorithm. +// The input of the quantize function has one operand of +// IsValueWithQuantizablePrecision and the output is a tensor with supported +// quantized precision (like int8). For dequantize function, it is the other way +// around. 
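The quantize/dequantize function bodies described above (created via `ApplyUniformQuantization`, declared just below) boil down to standard uniform affine quantization: q = round(x / scale) + zero_point, clamped to the target integer range, and x ~= (q - zero_point) * scale on the way back. A self-contained sketch of that arithmetic for int8, with made-up scale and zero point; the real pass derives these from the weight component spec:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int8_t QuantizeToInt8(float value, float scale, int32_t zero_point) {
  const int32_t q =
      static_cast<int32_t>(std::round(value / scale)) + zero_point;
  return static_cast<int8_t>(std::clamp(q, -128, 127));
}

float DequantizeFromInt8(int8_t q, float scale, int32_t zero_point) {
  return (static_cast<int32_t>(q) - zero_point) * scale;
}

int main() {
  const float scale = 0.05f;
  const int32_t zero_point = -3;
  const float x = 1.234f;
  const int8_t q = QuantizeToInt8(x, scale, zero_point);
  // round(1.234 / 0.05) = 25, 25 + (-3) = 22; dequantizes back to 1.250.
  std::printf("q=%d dequantized=%.3f\n", q,
              DequantizeFromInt8(q, scale, zero_point));
}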
+ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_QUANTIZE_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_QUANTIZE_OP_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace quant { + +std::optional ApplyUniformQuantization( + PatternRewriter& rewriter, TF::ConstOp op, + tensorflow::quantization::QuantizationComponentSpec& weight_spec); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_QUANTIZE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h new file mode 100644 index 00000000..8a062a16 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functions for quantization specifications of Uniform Quantized ops. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_UNIFORM_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_UNIFORM_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" + +namespace mlir { +namespace quant { + +// Returns the spec for the given operation that can be used for both of +// dynamic and static range quantization. +std::unique_ptr GetUniformOpQuantSpec(Operation* op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_UNIFORM_OP_QUANT_SPEC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h new file mode 100644 index 00000000..6be6f05a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_CONSTANTS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace quant { + +// Name of the save function. The "tf_quant__" prefix is for avoiding conflict +// with existing function's name. +inline constexpr StringRef kTfQuantSaveFuncName = "tf_quant__save"; + +// Name of the TensorFlow Operation to be fetched to save the variables to +// checkpoint. This save op follows the SavedModel's load semantics, so it +// should return the file prefix of the checkpoint as a string tensor. +inline constexpr StringRef kTfQuantSaveOpName = "tf_quant__save_op"; + +// Name the file prefix string tensor. The tensor is used to identify the prefix +// to the checkpoint where the variables are saved / loaded. This may be present +// in a function argument's "tf_saved_model.index_path" attribute to identify +// the file prefix function argument. +inline constexpr StringRef kTfFilePrefix = "__tf_file_prefix"; + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h new file mode 100644 index 00000000..d42ad360 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_MANIPULATE_MODEL_ATTR_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_MANIPULATE_MODEL_ATTR_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project + +namespace mlir { +namespace quant { + +// Adds a new input name to the `inputs` field of the `tf.entry_function` +// attribute if the attribute exist in the given function. Otherwise, no +// attribute is modified. 
+void AddEntryFunctionInput(StringRef input_name, func::FuncOp func_op); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_MANIPULATE_MODEL_ATTR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h new file mode 100644 index 00000000..9a0084ef --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h @@ -0,0 +1,250 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_PASSES_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace quant { + +// Creates a main function if it doesn't exist in the module. This is a +// workaround to make ConvertMlirToGraphdef work for multi-signatures graphs. +// TODO(b/204265523): Removes this pass after the exporting MLIR to SavedModel +// path is available. +std::unique_ptr> CreateInsertMainFunctionPass(); + +// Converts FakeQuant ops to quant.qcast and quant.dcast (QDQ) pairs. +std::unique_ptr> CreateConvertFakeQuantToQdqPass(); + +// Lifts the quantizable spots as composite functions. +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const tensorflow::quantization::QuantizationOptions& quant_options); + +// Apply graph optimizations such as fusing and constant folding to prepare +// lifting. +std::unique_ptr> CreatePrepareLiftingPass( + tensorflow::quantization::OpSet target_opset); + +// Lifts the dynamic range quantizable spots as composite functions. +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsDRQPass( + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + tensorflow::quantization::OpSet op_set, int min_num_elements_for_weights); + +// Replaces tf.CustomAggregator ops with quant.Stats ops for finalizing the +// calibration procedure. +std::unique_ptr> +CreateConvertCustomAggregationOpToQuantStatsPass(); + +// Inserts quantized function library. 
+std::unique_ptr> CreateInsertQuantizedFunctionsPass( + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + tensorflow::quantization::OpSet target_opset); + +// Inserts custom aggregation operators for the calibration procedure. +std::unique_ptr> +CreateInsertCustomAggregationOpsPass( + const ::stablehlo::quantization::CalibrationOptions& calib_opts); + +// Replaces composite functions with quantized composite functions. After this +// pass runs, functions in the given graph will be replaced with their quantized +// versions. By doing so, the quantization will be applied to the given input. +// mlir_dump_file_prefix is an optional field that is used for debugging to save +// mlir dump files. +std::unique_ptr> CreateQuantizeCompositeFunctionsPass( + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + tensorflow::quantization::OpSet target_opset, + bool enable_per_channel_quantization, int min_num_elements_for_weights, + bool enable_legacy_weight_only = false, + std::optional mlir_dump_file_prefix = + std::nullopt); + +// Converts dequantize-(quantizable) call-quantize pattern to a single call op +// that has quantized input and output types. It is expected for this pass to +// emit illegal IR with unsupported quantized input and output types. The +// pass following immediately after this one will be responsible for legalizing +// input and output types by unwrapping quantization parameters. +std::unique_ptr> CreateQuantizePass(); + +// Overloading of CreateQuantizePass which takes QuantizationSpecs. +std::unique_ptr> CreateQuantizePass( + QuantizationSpecs quant_specs, + tensorflow::quantization::OpSet target_opset); + +// Creates an instance of the PrepareQuantize pass, which will perform similar +// transformations as TFL::PrepareQuantizePass. +std::unique_ptr> CreatePrepareQuantizePass( + const QuantizationSpecs& quant_specs, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method); + +// Creates an instance of the PrepareQuantizeDRQ pass, which will +// perform similar transformations as TFL::PrepareQuantizeDynamicRangePass. +std::unique_ptr> CreatePrepareQuantizeDRQPass( + const QuantizationSpecs& quant_specs, + tensorflow::quantization::OpSet op_set); + +// Creates an instance of the PreprocessOp pass, which will perform op +// preprocessing to allow multi-axis quantization, prior to quantization. +std::unique_ptr> CreatePreprocessOpPass( + tensorflow::quantization::OpSet op_set, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +// Creates an instance of the PostQuantize pass, which will remove unnecessary +// ops from the final quantized graph. +std::unique_ptr> CreatePostQuantizePass(); + +// Applies optimization patterns after quantization. +std::unique_ptr> CreateOptimizePass(); + +// Creates an instance of the ReplaceCastHacksWithTFXLAOpsPass, which will +// replace mixed-type convolution and matmul cast hacks by XLA Conv2DOp and +// MatmulOp. +std::unique_ptr> +CreateReplaceCastHacksWithTFXLAOpsPass(); + +// Creates a pass that moves & merges initializer function's ops into the @main +// function. This pass should be run on a valid tf_executor dialect. The control +// output of the initializer function for non-variable resource initialization +// will be passed on as a dependency to a new `tf.NoOp`, whose control output +// will be merged into the main function's FetchOp. 
The initializer functions +// will be removed. +// +// Running this pass essentially has the effect of inlining the initializer +// functions into the main graph. This is beneficial when we wish to find and +// fetch the node that restores resources, after the ModuleOp has been exported +// as GraphDef. +std::unique_ptr> +CreateMergeInitializerFunctionOpsToMainPass(); + +// Creates a pass that moves & merges the "@tf_quant__save" function to "@main" +// function. A new `IdentityOp` will be created. It will have control dependency +// to the save function and returns the file_prefix argument (typed +// `tensor`). The file_prefix argument, which can be identified +// if the "tf_saved_model.index_path" attribute has "__tf_file_prefix", will be +// reused if it already exist in @main. Otherwise a new file prefix argument +// will be created. @tf_quant__save function will be erased. +// +// Running this pass essentially has the effect of inlining the @tf_quant__save +// into the main graph. This is beneficial when we wish to find and fetch +// the node that saves the variables, after the ModuleOp has been exported as +// GraphDef. +std::unique_ptr> CreateMergeSaveFunctionOpsToMainPass(); + +// Creates a pass that "unfreezes" ConstOps into variables. Each ConstOp's use +// will be replaced by a VarHandleOp -> ReadVariableOp pattern. The newly +// created variables will be initialized in the session initializer function via +// AssignVariableOps. +std::unique_ptr> CreateUnfreezeConstantsPass(); + +// Creates a pass that duplicates constants that affect the shape of a tensor +// after some computation. +std::unique_ptr> +CreateDuplicateShapeDeterminingConstantsPass(); + +// Creates a pass that creates a RestoreV2 op in the initializer function with +// type "restore_op" that initializes variables from the checkpoint. It finds +// tf.AssignVariableOp(tf.VarHandleOp, tf.Const) patterns in the initializer +// function and replaces tf.Consts with the results of RestoreV2. +std::unique_ptr> CreateInsertRestoreOpPass(); + +// Creates a pass that creates a new function that wraps the newly created +// SaveV2 op. The new function's name is "tf_quant__save". The function accepts +// a single string tensor as argument, which specifies the path to the +// checkpoint to which the variable's tensor values are saved. It finds +// `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` pattern in the initializer +// function of type "restore_op" to identify the VarHandleOps that should be +// saved using the SaveV2 op. +std::unique_ptr> CreateInsertSaveOpPass(); + +// Creates a pass that marks functions with the attribute `tf._noinline = true` +// to avoid being inlined by the `InlinerPass`. `noinline_functions` is the name +// of the functions to mark. +std::unique_ptr> CreateMarkFunctionsNoinlinePass( + ArrayRef noinline_functions); + +// Removes `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns from the +// initializer function (type = "restore_op"). +// Note: initializing values (`tf.Const`s) will be removed and this may result +// in an information loss and uninitialized variables eventually. Make sure that +// this effect is desired (e.g. there is a `tf.RestoreV2Op` that restores the +// variables instead). +std::unique_ptr> +CreateRemoveVariableInitializationByConstPass(); + +// Creates a pass that converts Tensorflow Xla ops to non-Xla ops. 
+std::unique_ptr> CreateConvertTfXlaOpToTfOpPass(); + +// Creates a pass that converts TPU models for CPU by removing TPU related ops +// such as TPUPartitionedCall, TPUReplicatedOp, etc. The TF quantizer does not +// work with models specifically designed for TPU, so this pass makes the input +// TPU model compatible with the TF quantizer by rewriting the TPU ops. The +// output model of this pass is expected to be ready for the TF quantizer. +std::unique_ptr> CreateConvertTpuModelToCpuPass(); + +// Creates a pass that casts BFloat16 operations to Float32 operations. This +// pass is a part of the ConvertTpuModelToCpu pass to support BF16 optimized TPU +// model quantization. +std::unique_ptr> CreateCastBf16OpsToF32Pass(); + +// Creates a pass that lifts HashTable ops as function arguments. In the graph +// execution mode, resource ops with the same `shared_name` attribute point to +// the same underlying resource. This is not true in the eager execution mode. +// Lifting resource ops as arguments will help unifying them across functions. +std::unique_ptr> CreateLiftHashTableOpsAsArgsPass(); + +// Creates a pass that merges duplicate resource ops in each function. Two +// resource ops are considered duplicated if they have the same `shared_name`. +std::unique_ptr> +CreateMergeDuplicateResourceOpsPass(); + +// Apply quantization to weights based on the provided schemes. +std::unique_ptr> CreateQuantizeWeightsPass( + const tensorflow::quantization::QuantizationOptions& quant_options); + +// Propagate quantized type through allowed ops. +std::unique_ptr> CreatePropagateQuantizeTypePass(); + +// Create a pass that inserts dump tensor to quantizable layer's output. +std::unique_ptr> CreateAddDumpTensorOpPass( + ::stablehlo::quantization::DebuggerConfig::DebuggerType debugger_type, + std::string log_dir_path); + +// Creates a pass that add QuantizationUnitLoc to quantizable layers. +std::unique_ptr> CreateAddQuantizationUnitLocPass(); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h new file mode 100644 index 00000000..8fe144d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_REMOVE_IDENTITY_OP_PATTERN_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_REMOVE_IDENTITY_OP_PATTERN_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace quant { + +// Copied from tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc. +// By removing identity ops, constant operands with dynamic shapes have static +// shape information which is necessary for correct pattern matching in this +// pass. +struct RemoveIdentity : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::IdentityOp identity, + PatternRewriter &rewriter) const override; +}; + +} // namespace quant +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_REMOVE_IDENTITY_OP_PATTERN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h new file mode 100644 index 00000000..6c8ad1ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_QUANT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_QUANT_OPS_H_ + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h.inc" + +namespace mlir { +namespace quant { + +// Function to register TensorFlow Uniform Quantized ops. +void RegisterOps(); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_QUANT_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h new file mode 100644 index 00000000..fbba7247 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h @@ -0,0 +1,114 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_PY_FUNCTION_LIB_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_PY_FUNCTION_LIB_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/min_max_value.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow::quantization { + +// Declares pure virtual member functions for a python-side derived class to +// override. This allows calling python implementations from the C++ layer. +// Member functions should be pure not stateful; they should not access or rely +// on member fields. +class PyFunctionLibrary { + public: + virtual ~PyFunctionLibrary() = default; + + // Saves `exported_model` to `dst_saved_model_path` as SavedModel. + // `src_saved_model_path` is the path to the source SavedModel from which the + // exported model is produced. It is used to copy the asset files to + // `dst_saved_model_path`. `tags` will be attached to the saved + // `MetaGraphDef`. `signature_def_map` will be passed to the + // `add_meta_graph_and_variables` function, which is internally used to add a + // `MetaGraphDef` to save to the SavedModel. + // + // Returns `true` if successful. Returns `std::nullopt` otherwise. + // + // If the function signature changes, likely its corresponding .pyi type + // hinting and definition should also change. + // LINT.IfChange(save_exported_model) + virtual std::optional SaveExportedModel( + absl::string_view dst_saved_model_path, + const ExportedModel& exported_model, + absl::string_view src_saved_model_path, + const std::unordered_set& tags, + const absl::flat_hash_map& + signature_def_map) const = 0; + // LINT.ThenChange( + // pywrap_function_lib.pyi:save_exported_model, + // py_function_lib.py:save_exported_model, + // ) + + // Runs calibration on a model saved at `saved_model_path`. `exported_model` + // should be the corresponding exported model resulting from the + // pre-calibration step. `signature_keys` is a set of keys that identify a + // SignatureDef to run the calibration on. `tags` is a set of strings that + // identify the `MetaGraphDef`. `calibration_options` provides configurations + // for the calibration behavior. `representative_dataset` is a python object + // of type `RepresentativeDatasetOrMapping`, which is used to run the + // calibration. + // + // Returns `true` if successful. Returns `std::nullopt` otherwise. + // + // If the function signature changes, likely its corresponding .pyi type + // hinting and definition should also change. 
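Because the template arguments were stripped in this hunk, the caller sketch below assumes `SaveExportedModel` returns `std::optional<bool>` and takes `std::unordered_set<std::string>` tags plus an `absl::flat_hash_map<std::string, SignatureDef>` signature map. Those element types, the helper name, and the placeholder values are assumptions for illustration only.

```cpp
// Hypothetical caller of the Python-backed SaveExportedModel() hook.
// Element types of the containers are assumed; they are elided in this diff.
#include <string>
#include <unordered_set>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow::quantization {

bool TrySaveExportedModel(const PyFunctionLibrary& py_lib,
                          const ExportedModel& exported_model,
                          const std::string& src_saved_model_path,
                          const std::string& dst_saved_model_path) {
  // Placeholder tag set and empty signature map, for illustration only.
  const std::unordered_set<std::string> tags = {"serve"};
  const absl::flat_hash_map<std::string, SignatureDef> signature_def_map = {};
  // A std::nullopt return signals a failure on the Python side.
  return py_lib
      .SaveExportedModel(dst_saved_model_path, exported_model,
                         src_saved_model_path, tags, signature_def_map)
      .has_value();
}

}  // namespace tensorflow::quantization
```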
+ // LINT.IfChange(run_calibration) + virtual std::optional RunCalibration( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + bool force_graph_mode_calibration, + const absl::flat_hash_map& + representative_dataset_file_map) const = 0; + // LINT.ThenChange( + // pywrap_function_lib.pyi:run_calibration, + // py_function_lib.py:run_calibration, + // ) + + // Retrieves min and max value from `calibration_statistics`, based on the + // calibration method specified by `calibration_options`. + // + // Returns `std::nullopt` if unsuccessful. + // + // If the function signature changes, likely its corresponding .pyi type + // hinting and definition should also change. + // LINT.IfChange(get_calibration_min_max_value) + virtual std::optional + GetCalibrationMinMaxValue(const tensorflow::calibrator::CalibrationStatistics& + calibration_statistics, + const ::stablehlo::quantization::CalibrationOptions& + calibration_options) const = 0; + // LINT.ThenChange( + // pywrap_function_lib.pyi:get_calibration_min_max_value, + // py_function_lib.py:get_calibration_min_max_value, + // ) +}; + +} // namespace tensorflow::quantization + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_PY_FUNCTION_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h new file mode 100644 index 00000000..9e36ce52 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -0,0 +1,77 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace quantization { + +// Names of the TensorFlow Quantization steps. These names are used primarily +// for debugging. 
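The quantization entry points declared in the hunk that follows return a serialized exported model wrapped in `absl::StatusOr`; the template argument is elided in this diff and assumed below to be `ExportedModel`. A hypothetical caller of the dynamic-range path, with placeholder signature keys and tags:

```cpp
// Hypothetical wrapper around QuantizeDynamicRangePtq(), declared below.
// The StatusOr payload type and the string element types are assumptions.
#include <string>
#include <unordered_set>
#include <vector>

#include "absl/status/statusor.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h"

namespace tensorflow::quantization {

absl::StatusOr<ExportedModel> RunDynamicRangePtq(
    const std::string& saved_model_path,
    const QuantizationOptions& quant_options) {
  const std::vector<std::string> signature_keys = {"serving_default"};
  const std::unordered_set<std::string> tags = {"serve"};
  return QuantizeDynamicRangePtq(saved_model_path, signature_keys, tags,
                                 quant_options);
}

}  // namespace tensorflow::quantization
```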
+inline constexpr absl::string_view kTfQuantPtqPreCalibrationStepName = + "tf_quant_ptq_pre_calibration"; +inline constexpr absl::string_view kTfQuantPtqPostCalibrationStepName = + "tf_quant_ptq_post_calibration"; +inline constexpr absl::string_view kTfQuantQatStepName = "tf_quant_qat"; +inline constexpr absl::string_view kTfQuantPtqDynamicRangeStepName = + "tf_quant_ptq_dynamic_range"; +inline constexpr absl::string_view kTfQuantWeightOnlyStepName = + "tf_quant_weight_only"; + +absl::StatusOr QuantizeQatModel( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationOptions& quantization_options); + +// Applies post-training dynamic-range quantization to the model. +absl::StatusOr QuantizeDynamicRangePtq( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationOptions& quantization_options); + +// Applies post-training static-range weight-only quantization to the model. +absl::StatusOr QuantizeWeightOnly( + absl::string_view saved_model_path, + const QuantizationOptions& quantization_options); + +// Applies post-training static-range quantization to the model. +absl::StatusOr QuantizeStaticRangePtq( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationOptions& quantization_options, + const absl::flat_hash_map& signature_def_map, + const PyFunctionLibrary& py_function_library, + const absl::flat_hash_map& + representative_dataset_file_map_serialized); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_QUANTIZE_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h new file mode 100644 index 00000000..dd5fe761 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h @@ -0,0 +1,158 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TYPE_CASTERS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TYPE_CASTERS_H_ + +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "pybind11/cast.h" // from @pybind11 +#include "pybind11/detail/common.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/python/lib/core/pybind11_lib.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace pybind11::detail { +namespace internal { + +// Serializes a protobuf object. Raises python ValueError if serialization +// fails. +inline std::string Serialize(const tsl::protobuf::Message& protobuf_object) { + const std::string serialized = protobuf_object.SerializeAsString(); + + // Empty string means it failed to serialize the protobuf with an error. See + // the docstring for SerializeAsString for details. + if (serialized.empty()) { + // Show the name of the protobuf message type to provide more information + // and easier debugging. + const absl::string_view descriptor_name = + protobuf_object.GetDescriptor() == nullptr + ? absl::string_view("unknown") + : absl::string_view(protobuf_object.GetDescriptor()->full_name()); + throw py::value_error(absl::StrFormat( + "Failed to serialize protobuf object: %s.", descriptor_name)); + } + + return serialized; +} + +// Handles `ProtoT` (c++) <-> `bytes` (python) conversion. The `bytes` +// object in the python layer is a serialization of `ProtoT`. +// +// The caller of c++ interfaces should make sure to pass valid serialized +// `ProtoT` objects as arguments. Failing to do so results in raising a +// `ValueError`. Similarly, the python implementation of a c++ virtual member +// function that return an `ProtoT` should return a valid serialized `ProtoT`. +// +// See https://pybind11.readthedocs.io/en/stable/advanced/cast/custom.html +template >> +struct SerializedProtobufCaster { + public: + PYBIND11_TYPE_CASTER(ProtoT, const_name()); + + // Loads an `ProtoT` instance from a python `bytes` object (`src`). + bool load(handle src, const bool convert) { + auto caster = make_caster(); + // Make sure the user passed a valid python string. + if (!caster.load(src, convert)) return false; + + const absl::string_view serialized_proto = + cast_op(std::move(caster)); + + // NOLINTNEXTLINE: Explicit std::string conversion required for OSS. + return value.ParseFromString(std::string(serialized_proto)); + } + + // Constructs a `bytes` object by serializing `src`. + static handle cast(ProtoT&& src, return_value_policy policy, handle parent) { + // release() prevents the reference count from decreasing upon the + // destruction of py::bytes and returns a raw python object handle. + return py::bytes(Serialize(src)).release(); + } + + // Constructs a `bytes` object by serializing `src`. 
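With the serialized-protobuf casters above in scope, a pybind11 binding can take and return protobuf messages directly; on the Python side they appear as `bytes` holding the serialized message. The module and function names below are hypothetical and exist only to show the mechanism.

```cpp
// Hypothetical pybind11 module relying on the SerializedProtobufCaster
// specializations from type_casters.h. Python passes/receives bytes; C++ sees
// a parsed ExportedModel proto.
#include "pybind11/pybind11.h"  // from @pybind11
#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h"

namespace py = pybind11;

PYBIND11_MODULE(example_quant_binding, m) {
  m.def("echo_exported_model",
        [](tensorflow::quantization::ExportedModel exported_model) {
          // The caster already parsed the incoming bytes; returning the proto
          // serializes it back to bytes for the Python caller.
          return exported_model;
        });
}
```

Serialization failures on the return path surface as `ValueError`, per `Serialize()` above.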
+ static handle cast(const ProtoT& src, return_value_policy policy, + handle parent) { + // release() prevents the reference count from decreasing upon the + // destruction of py::bytes and returns a raw python object handle. + return py::bytes(Serialize(src)).release(); + } +}; + +} // namespace internal + +// The following explicit specializations of protobuf `type_caster`s for +// specific protobuf message types are there to have higher priority over those +// defined in `native_proto_caster.h` during the resolution process. This is +// because the type casters in `native_proto_caster.h`, which allow seamlessly +// exchanging protobuf messages across c++-python boundaries, potentially +// without serialization, fail in the open-source environment. +// Explicitly-specialized type casters for serialized protobufs are added on an +// on-demand basis for quantization library. +// TODO: b/308532051 - Make `native_proto_caster.h` work in the open-source +// environment. + +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::ExportedModel> {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::QuantizationOptions> {}; + +template <> +struct type_caster<::stablehlo::quantization::CalibrationOptions> + : public internal::SerializedProtobufCaster< + ::stablehlo::quantization::CalibrationOptions> {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::calibrator::CalibrationStatistics> {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + stablehlo::quantization::QuantizationConfig> {}; + +template <> +struct type_caster + : public internal::SerializedProtobufCaster< + tensorflow::quantization::RepresentativeDatasetFile> {}; + +} // namespace pybind11::detail + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TYPE_CASTERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h new file mode 100644 index 00000000..3086d705 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_UNFREEZE_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_UNFREEZE_CONSTANTS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace tensorflow { +namespace quantization { + +inline constexpr absl::string_view kTfQuantConstantUnfreezingStepName = + "tf_quant_constant_unfreezing"; +inline constexpr absl::string_view kTfQuantInsertRestoreOpStepName = + "tf_quant_insert_restore_op"; + +absl::Status UnfreezeConstantsAndSaveVariables(absl::string_view checkpoint_dir, + mlir::MLIRContext &ctx, + mlir::ModuleOp module_op); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_UNFREEZE_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h new file mode 100644 index 00000000..b9c765c0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h @@ -0,0 +1,55 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PASSES_H_ + +#include + +#include "absl/strings/string_view.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace tensorflow { +namespace quantization { + +// mlir_dump_file_prefix is an optional field that is used for debugging to save +// mlir dump files. 
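A minimal caller sketch for `UnfreezeConstantsAndSaveVariables()` declared above. The wrapper and the checkpoint path are hypothetical; the module is assumed to be loaded already.

```cpp
// Hypothetical caller: unfreeze constants in `module_op` and persist the
// resulting variables to a caller-chosen checkpoint directory.
#include "absl/status/status.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h"

namespace tensorflow::quantization {

absl::Status UnfreezeForExport(mlir::MLIRContext& ctx,
                               mlir::ModuleOp module_op) {
  // Variables created by unfreezing are saved here so that a later RestoreV2
  // (see kTfQuantInsertRestoreOpStepName) can re-load them.
  constexpr char kCheckpointDir[] = "/tmp/tf_quant_ckpt";  // hypothetical path
  return UnfreezeConstantsAndSaveVariables(kCheckpointDir, ctx, module_op);
}

}  // namespace tensorflow::quantization
```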
+void AddQuantizeQatPasses(mlir::OpPassManager &pm, + const QuantizationOptions &quantization_options, + std::optional + mlir_dump_file_prefix = std::nullopt); + +void AddQuantizePtqDynamicRangePasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix = + std::nullopt); + +void AddQuantizeWeightOnlyPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix = + std::nullopt); + +void AddQuantizePtqPreCalibrationPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options); + +void AddQuantizePtqPostCalibrationPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix = + std::nullopt); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h new file mode 100644 index 00000000..47bed2e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h @@ -0,0 +1,86 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace quantization { + +// Default MLIR dump file prefix for TensorFlow quantization passes. +inline constexpr absl::string_view kDefaultTfQuantMlirDumpFilePrefix = + "tf_quant"; + +// Preprocesses the `module_op` for quantization. The preprocess steps include +// freezing the variables in the graph into constants. `is_inliner_run` +// determines whether the `InlinerPass` should be run after unfreezing. +// +// `mlir_dump_file_prefix` is primarily used for debugging and does not affect +// the preprocessing behavior. Instructions for producing MLIR dump files are in +// the comments of `tensorflow::quantization::MaybeEnableIrPrinting` function. 
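The `AddQuantizePtqPreCalibrationPasses` and `AddQuantizePtqPostCalibrationPasses` declarations above suggest a two-phase static-range PTQ flow with calibration in between. The sketch assembles the two pipelines; the phase split is inferred from the names, and running calibration between them is left to the driver.

```cpp
// Illustrative assembly of the two static-range PTQ pipelines declared above.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h"

namespace tensorflow::quantization {

void BuildPtqPipelines(mlir::PassManager& pre_calibration_pm,
                       mlir::PassManager& post_calibration_pm,
                       const QuantizationOptions& quant_options) {
  // Phase 1: instrument the model so calibration statistics can be collected.
  AddQuantizePtqPreCalibrationPasses(pre_calibration_pm, quant_options);
  // (Calibration runs outside MLIR, between the two phases.)
  // Phase 2: use the collected ranges to materialize quantized ops.
  AddQuantizePtqPostCalibrationPasses(post_calibration_pm, quant_options);
}

}  // namespace tensorflow::quantization
```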
+absl::Status PreprocessAndFreezeGraph( + absl::string_view mlir_dump_file_prefix, bool is_inliner_run, + const absl::flat_hash_set& noinline_functions, + mlir::ModuleOp module_op, mlir::MLIRContext* context, + std::optional session, bool run_tf_to_stablehlo, + bool deserialize_xla_call_module, + llvm::ArrayRef> input_arg_shapes = {}); + +// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file +// prefix. +inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, + mlir::MLIRContext* context, + std::optional session) { + return PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, + /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context, + session, /*run_tf_to_stablehlo=*/false, + /*deserialize_xla_call_module=*/false, /*input_arg_shapes=*/{}); +} + +// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file +// prefix. +inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, + mlir::MLIRContext* context) { + return PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, + /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context, + nullptr, /*run_tf_to_stablehlo=*/false, + /*deserialize_xla_call_module=*/false, /*input_arg_shapes=*/{}); +} + +// TF->StableHLO has limited support for dynamic shapes. +// Some models can only be converted with explicitly provided input argument +// shapes. +void AddTFToStablehloPasses( + mlir::PassManager& pm, + llvm::ArrayRef> input_arg_shapes = {}); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_QUANTIZE_PREPROCESS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h new file mode 100644 index 00000000..702e1950 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h @@ -0,0 +1,160 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TF-Quant transformation +// passes to work with tf.FakeQuant* ops. 
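Since TF-to-StableHLO lowering has limited dynamic-shape support, `AddTFToStablehloPasses()` above accepts explicit input argument shapes. The element type of `input_arg_shapes` is elided in this diff; the sketch assumes one `llvm::ArrayRef<int64_t>` per model input, and the concrete shape is a placeholder.

```cpp
// Illustrative use of AddTFToStablehloPasses() with explicit, fully static
// input shapes (assumed element type: llvm::ArrayRef<int64_t>).
#include <cstdint>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h"

namespace tensorflow::quantization {

void BuildStablehloLowering(mlir::PassManager& pm) {
  // One entry per model input argument, e.g. a single NHWC image tensor.
  const llvm::SmallVector<int64_t> image_shape = {1, 224, 224, 3};
  const llvm::SmallVector<llvm::ArrayRef<int64_t>> input_arg_shapes = {
      image_shape};
  AddTFToStablehloPasses(pm, input_arg_shapes);
}

}  // namespace tensorflow::quantization
```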
Copied and modified from +// //third_party/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_FAKE_QUANT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_FAKE_QUANT_UTILS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +namespace mlir { +namespace quant { + +template +struct FetchMinMaxAttrs { + using AttrType = FloatAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + min_value = tf_op.getMinAttr(); + max_value = tf_op.getMaxAttr(); + return true; // Successfully matched and fetched. + } +}; + +template +struct FetchConstantMinMaxInputs { + using AttrType = DenseFPElementsAttr; + bool operator()(TFFakeQuantOp tf_op, AttrType &min_value, + AttrType &max_value) const { + Value min = tf_op.getMin(), max = tf_op.getMax(); + if (auto min_id = min.getDefiningOp()) { + min = min_id.getInput(); + } + if (auto max_id = max.getDefiningOp()) { + max = max_id.getInput(); + } + + if (!matchPattern(min, m_Constant(&min_value))) { + return false; + } + if (!matchPattern(max, m_Constant(&max_value))) { + return false; + } + return true; // Successfully matched and fetched. + } +}; + +// Inserts a "quant.qcast" and "quant.dcast" op pair (QDQs) in place of the +// tf.FakeQyantWithMinMax{Vars|VarsPerChannel|Args}Op +// before the op being constant folded. Since the constant +// folding logic will use a "arith.constant" op to replace the +// "tf.FakeQuantWithMinMaxVarsOp", the "quant.qcast" op is used to preserve +// the quantization parameters as a TypeAttr and "quant.dcast" op used to +// convert the output type to the next op. Here are the transformations: +// +// input min cst max cst input +// \ | | | +// \ (tf.Identity) (tf.Identity) => quant.qcast +// \ | | | +// tf.FakeQuantWithMinMaxVars quant.dcast +// | | +// +// Warns if the (most likely unwanted, currently not quite correctly handled) +// case of back-to-back tf.FakeQuant occurs +// +// tf.FakeQuant* +// | +// tf.FakeQuant* +// +template +class ConvertFakeQuantOpToQuantOps { + public: + explicit ConvertFakeQuantOpToQuantOps(bool use_fake_quant_num_bits) + : use_fake_quant_num_bits_(use_fake_quant_num_bits) {} + + FetchMinMax fetch_min_max_; + + using FetchAttrType = typename FetchMinMax::AttrType; + LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, + OpBuilder &rewriter) const { + if (tf_op.getNumBits() != 8) { + return failure(); + } + + // Extract the min/max constant values from the operands. We also consider + // a special case that there are tf.Identity ops between the min/max + // constants and the tf.FakeQuantWithMinMaxVarsOp. 
+ FetchAttrType min_value, max_value; + if (!fetch_min_max_(tf_op, min_value, max_value)) { + return failure(); + } + + Value input = tf_op.getInputs(); + int quant_dim = -1; + auto input_type = mlir::cast(input.getType()); + if (PerAxis) { + if (!input_type.hasRank()) { + tf_op.emitError("The input should have known rank for per-channel op."); + return failure(); + } + // This is a special case that the quant_dim is the last dimensions. + quant_dim = input_type.getRank() - 1; + } + // Use the min/max from the operands and the num_bits and narrow_range + // attribute to create the quantization parameter for the new quantize op. + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); + Type res_type = tf_op.getType(); + TypeAttr qtype = quant::GetQuantizedTypeAttr( + rewriter, input_type, min_value, max_value, quant_dim, num_bits, + narrow_range, /*is_signed=*/true, /*legacy_float_scale=*/false, + use_fake_quant_num_bits_); + if (!qtype) { + return failure(); + } + + // Finally, use the quantization parameter to create the quantize and + // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp + // and its users. + auto quantize = rewriter.create( + tf_op.getLoc(), qtype.getValue(), input); + auto dequantize = rewriter.create( + tf_op.getLoc(), res_type, quantize.getResult()); + tf_op.getOutputs().replaceAllUsesWith(dequantize); + + return success(); + } + + bool use_fake_quant_num_bits_; +}; + +// Removes the wrapper of the tf.FakeQuant* ops and creates the quant.qcast +// and quant.dcast pairs before tf.FakeQuant* ops are being folded. +LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext *ctx, + bool use_fake_quant_num_bits); + +} // namespace quant +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_FAKE_QUANT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.h new file mode 100644 index 00000000..2e573e28 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.h @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_QUANTIZE_OP_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_QUANTIZE_OP_UTILS_H_ + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +namespace mlir { +namespace quant { + +UnrankedTensorType CreateUnknownShapeFromElementType(Type tensor_type); + +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_QUANTIZE_OP_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h new file mode 100644 index 00000000..922729d9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h @@ -0,0 +1,72 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This header file defines common utils used when transforming TF ops to +// Uniform Quantized ops. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ + +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" + +namespace mlir::quant { + +LogicalResult FillAttributesForUniformQuantizedDotOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedConvolutionOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedAddOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedClipByValueOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformRequantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h new file mode 100644 index 00000000..80212b9a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used when transforming TF ops to XLA +// ops. 
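The `FillAttributesForUniformQuantized*` helpers declared above are intended to be called from rewrite patterns after the corresponding UniformQuantized* op has been created, to derive its op-specific attributes. The elided `llvm::StringMap` value type is assumed to be `mlir::Attribute`, and the wrapper below is hypothetical.

```cpp
// Hypothetical helper showing the calling convention of the attribute-filling
// utilities above. The StringMap value type is an assumption; the preset
// quantization method is forwarded rather than hard-coded here.
#include "llvm/ADT/StringMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.h"

namespace mlir::quant {

LogicalResult FillDotAttrsExample(
    PatternRewriter& rewriter, Operation* new_dot_op,
    llvm::StringMap<Attribute>& identifier_to_attr,
    tensorflow::quantization::QuantizationMethod::PresetMethod method) {
  return FillAttributesForUniformQuantizedDotOp(
      rewriter, new_dot_op, identifier_to_attr, method,
      /*enable_per_channel_quantization=*/false);
}

}  // namespace mlir::quant
```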
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_XLA_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_XLA_ATTRIBUTE_UTILS_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project + +namespace mlir::quant { + +// Calculate padding values for XLA ops. +// Padding values for Uniform Quantized ops can be generated with this method as +// well, since it shares the same padding-attribute definition with the XLA ops. +Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, + Value input, Value filter, + int8_t input_zp_value, ArrayAttr strides, + ArrayAttr dilations, + StringAttr conv_padding, + ArrayAttr explicit_paddings, + Value &padding, int num_dims = 4); + +// Given a value of 8-bit type that holds 4-bit data in unpacked format, packs +// it into nibble format along pack_dim. +// If the pack_dim size is odd, adds one element of zero padding and then packs. +Value PackOperand(OpBuilder &builder, Location loc, Value value, int pack_dim); + +} // namespace mlir::quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TO_XLA_ATTRIBUTE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/register_common_dialects.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/register_common_dialects.h new file mode 100644 index 00000000..d88bcc83 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/register_common_dialects.h @@ -0,0 +1,28 @@ +/* Copyright 2023 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_REGISTER_COMMON_DIALECTS_H_ +#define TENSORFLOW_COMPILER_MLIR_REGISTER_COMMON_DIALECTS_H_ + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project + +namespace mlir { + +// Inserts the common TensorFlow dialects used by offline tools. +void RegisterCommonToolingDialects(mlir::DialectRegistry& registry); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_REGISTER_COMMON_DIALECTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h new file mode 100644 index 00000000..5ba65901 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_ + +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TF { +namespace detail { + +// This template defines an aggregate analysis base class, which analyzes a +// module but stores the analysis info per function. +template +class PerFunctionAggregateAnalysis { + public: + using Info = InfoT; + + // Returns the analysis info for the given function. + const Info& GetAnalysisForFunc(func::FuncOp func) const { + auto it = info_map_.find(func); + assert(it != info_map_.end()); + return it->second; + } + + protected: + // Since `InfoT` might be large, DenseMap is used instead of SmallDenseMap to + // avoid stack overflow. + llvm::DenseMap info_map_; +}; + +} // namespace detail + +// Base CRTP class to help write passes that consume a per-function +// aggregate analysis and operate on all non-extern functions (similar to an +// OperationPass, but with no concurrency between functions). The +// derived classes need to provide a runOnFunction() method that accepts the +// function and the analysis information for that function. +template +class PerFunctionAggregateAnalysisConsumerPass + : public PassWrapper< + PerFunctionAggregateAnalysisConsumerPass, + OperationPass> { + public: + static ::mlir::TypeID resolveTypeID() { + static ::mlir::SelfOwningTypeID id; + return id; + } + + private: + void runOnOperation() override { + ModuleOp op = this->getOperation(); + DerivedT& derived = *static_cast(this); + auto& analysis = this->template getAnalysis(); + + for (auto func : op.getOps()) + if (!func.isExternal()) + derived.runOnFunction(func, analysis.GetAnalysisForFunc(func)); + } +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h new file mode 100644 index 00000000..c49852c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h @@ -0,0 +1,175 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_ + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { +namespace detail { +class BacktrackAnalysis; +class BacktrackAnalysisInfo; + +// Resource alias analysis information for a single function. +class ResourceAliasAnalysisInfo { + public: + // Constructs analysis info by analyzing the given function. + ResourceAliasAnalysisInfo(func::FuncOp func, + const BacktrackAnalysis& backtrack_analysis, + SymbolTableCollection& symbol_table_collection); + + ResourceAliasAnalysisInfo(ResourceAliasAnalysisInfo&&) = default; + + // Returns if the analysis fails to resolve a resource-type value. + bool IsUnknownResource(Value resource) const; + + // Returns the set of unique IDs which `resource` could alias. Requires that + // IsUnknownResource(resource) == false. + const llvm::SmallSet& GetResourceUniqueIds(Value resource) const; + + // Returns the set of values that are potentially aliases of `value`. Requires + // `IsUnknownResource(resource) == false`. + llvm::SmallSetVector GetResourceAliases(Value resource) const; + + llvm::SmallSetVector GetValuesForResourceId(int64_t id) const { + auto it = id_to_resource_values_.find(id); + if (it == id_to_resource_values_.end()) { + return {}; // return empty set + } + return it->getSecond(); + } + + // Returns true iff given resource is allocated by op with + // `UniqueResourceAllocation` trait. This can be utilized for while-loop + // parallelization. + bool IsUniqueResourceAllocationId(int64_t resource_id) const { + return unique_resource_allocation_ids_.contains(resource_id); + } + + private: + // Maps resource value to unique ID and vice-versa. Returns true if the + // mapping has changed. + bool AddValueUniqueIDMapping(Value value, int64_t id) { + resource_value_to_ids_[value].insert(id); + return id_to_resource_values_[id].insert(value); + } + + // Returns the set unique Values which map to `id`. + const llvm::SmallSetVector& GetUniqueIdResources(int64_t id) const; + + // Propagates the resource IDs from an input operand to a result. Returns + // true of the mapping has changed. + bool PropagateInputToOutput(const Value& operand, const OpResult& result); + + // Analyzes while loops to compute resource IDs for the loop results. + // `body_info` is the backtrack analysis info for the loop body. + void AnalyzeWhileLoop(Operation* while_op, + const BacktrackAnalysisInfo& body_info); + + // Analyzes tf.Case/tf.If ops to compute resource IDs. 
+ template + void AnalyzeFunctionalCaseOrIfOp(CaseOrIfOp case_or_if_op, + llvm::ArrayRef functions, + const BacktrackAnalysis& backtrack_analysis); + + // Analyzes tf.CaseRegion/tf.IfRegion ops to compute resource IDs. + void AnalyzeRegionCaseOrIfOp(Operation* case_or_if_op, + const BacktrackAnalysis& backtrack_analysis); + + // Maps each resource-type value to a set of unique IDs that it could alias. + llvm::SmallDenseMap, 8> + resource_value_to_ids_; + + // Maps each unique ID to a set of resource-type values that could alias to + // it. This is inverse of `resource_value_to_ids_` map. + llvm::SmallDenseMap, 8> + id_to_resource_values_; + + // Maps MLIR type IDs for resource types to internal resource type IDs. + llvm::SmallDenseMap type_id_to_internal_type_id_; + + // Contains IDs of all resources that are allocated by ops with + // `UniqueResourceAllocation` trait. + llvm::SmallDenseSet unique_resource_allocation_ids_; + + public: + // Resource IDs have the following semantics: + // a) -1 represents an unknown resource (both instance and type unknown) + // b) IDs in range [0,kMaxResourceTypeId] represent resource type IDs; we use + // such IDs when we know the resource type but not the instance + // c) IDs > kMaxResourceTypeId represent resource instance IDs (i.e., we know + // the specific resource instance) + // + // Note: In general, there can be different ops allocating a resource of the + // same type, for one we might assign a resource type ID and for the other + // a resource instance ID. That means, they will be treated as non-aliasing. + // This is correct for all current cases. A problematic case could be if we + // had two ops A and B, A has the `ResourceHandleAllocatorInterface` and B has + // not, and both ops might return a handle to the same resource (depending on + // attributes). In this case, the return value of A would get a different ID + // than the return value of B although both could point to the same resource. + // It seems highly unlikely to encounter such a case but, to be safe, this + // should be revisited for new resource-allocators that might potentially + // break our currently guaranteed correctness. + // For context, we are very conservative here compared to + // `auto_control_deps.py` where it is assumed that allocated resource values + // NEVER alias. We should align our assumptions in the future. + static constexpr int64_t kUnknownResourceId = -1; + static constexpr int64_t kInvalidResourceId = -2; + static constexpr int64_t kMaxResourceTypeId = 9999; +}; + +} // namespace detail + +// An analysis that runs on a module and maps each resource-type value to a +// set of unique IDs representing the possible resources it could alias. +// +// Note that this is not an inter-procedural or inter-regional analysis, i.e., +// each function and region are handled separately and cross-function or cross- +// region aliasing cannot be checked by this analysis. +class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis< + detail::ResourceAliasAnalysisInfo> { + public: + // Constructs analysis by analyzing the given module operation. 
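A sketch of how a module pass might consume the per-function alias information exposed above. The pass itself is hypothetical; the resource-type check is the usual conservative guard before querying the analysis for a value.

```cpp
// Hypothetical module pass querying ResourceAliasAnalysis for each function's
// resource-typed arguments.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"

namespace mlir::TF {

struct ExampleAliasQueryPass
    : public PassWrapper<ExampleAliasQueryPass, OperationPass<ModuleOp>> {
  void runOnOperation() override {
    const auto& alias_analysis = getAnalysis<ResourceAliasAnalysis>();
    getOperation().walk([&](func::FuncOp func) {
      const auto& info = alias_analysis.GetAnalysisForFunc(func);
      for (Value arg : func.getArguments()) {
        // Only resource-typed values are meaningful to the analysis.
        if (!isa<ResourceType>(getElementTypeOrSelf(arg.getType()))) continue;
        // Unknown resources must be treated conservatively by clients.
        if (info.IsUnknownResource(arg)) continue;
        (void)info.GetResourceAliases(arg);
      }
    });
  }
};

}  // namespace mlir::TF
```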
+ explicit ResourceAliasAnalysis(ModuleOp module); +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h new file mode 100644 index 00000000..1e68ac41 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.h @@ -0,0 +1,85 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_DATAFLOW_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_DATAFLOW_H_ + +#include +#include + +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +// Used as a lattice value. +struct ResourceConstructingOps { + explicit ResourceConstructingOps(Operation *op = nullptr); + static ResourceConstructingOps EntryState(MLIRContext *context); + static ResourceConstructingOps EntryState(Value value); + bool operator==(const ResourceConstructingOps &rhs) const { + return ops == rhs.ops; + } + + static ResourceConstructingOps join(const ResourceConstructingOps &lhs, + const ResourceConstructingOps &rhs); + void print(raw_ostream &os) const; + + // The operation(s) which created the resource value. + // IR constructs (i.e., GlobalTensorOp) are not const-correct. 
+ mutable DenseSet ops; +}; + +struct IsComposite { + explicit IsComposite(Operation *op = nullptr); + static IsComposite EntryState(MLIRContext *context); + static IsComposite EntryState(Value value); + bool operator==(const IsComposite &rhs) const { + return is_on_composite_device == rhs.is_on_composite_device; + } + + static IsComposite join(const IsComposite &lhs, const IsComposite &rhs); + void print(raw_ostream &os) const; + + bool is_on_composite_device = false; +}; + +typedef dataflow::Lattice ResourceDataflowState; +typedef dataflow::Lattice IsCompositeDataflowState; + +void LoadResourceDataflowAnalysis(DataFlowSolver &solver); +void LoadIsCompositeDataflowAnalysis(DataFlowSolver &solver); + +} // namespace TF +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_DATAFLOW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h new file mode 100644 index 00000000..738d8c1d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h @@ -0,0 +1,79 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_VALUE_TYPED_ANALYZER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_VALUE_TYPED_ANALYZER_H_ + +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { + +class ResourceAnalyzer { + public: + explicit ResourceAnalyzer(ModuleOp module, bool skip_session_init = false); + + bool IsPotentiallyWritten(Value resource) const; + + private: + // Analyze the specified region for resource mutating operations, namely + // TF::AssignVariableOp, if so, set the resource associated as "potentially + // written". + LogicalResult AnalyzeRegion(Region& region); + + // If an op is not one of the handled ones, we assume all resource usages + // within its purview are mutating in nature. + void PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op); + + // Given a Region associated with the callee and operands from the + // corresponding callOp, propagate the potentially written decision to the + // callOp's operands, if the corresponding region's arguments are potentially + // written resources. 
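As an aside, the two Load*DataflowAnalysis entry points above follow the standard MLIR dataflow recipe: load the analysis into a DataFlowSolver, run the solver over the module, then read lattices off individual values. A hedged sketch; the prerequisite dead-code/constant-propagation analyses and the lookupState query are assumptions based on the generic dataflow framework, not statements from this header.

mlir::DataFlowSolver solver;
// The sparse forward framework is normally run together with these analyses.
solver.load<mlir::dataflow::DeadCodeAnalysis>();
solver.load<mlir::dataflow::SparseConstantPropagation>();
mlir::TF::LoadResourceDataflowAnalysis(solver);
if (failed(solver.initializeAndRun(module))) return;

// Read the lattice attached to a resource-typed value.
if (const auto *state =
        solver.lookupState<mlir::TF::ResourceDataflowState>(resource_value))
  state->getValue().print(llvm::errs());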
+ void PropagatePotentiallyWrittenUpFromCallee( + Region& region, Operation::operand_range propagate_to); + + // Marks 'resource' as written. + void SetPotentiallyWritten(Value resource); + + struct ResourceInfo { + bool potentially_written = false; + }; + // Key: Resource Value's + // Value: Information we know about that Value. + // Note that these Value's are in general in different functions. + DenseMap resource_infos_; + // The set of regions we already discovered. + DenseSet discovered_; + // Identifiers about mutable variables. + // All variables are identified by (device, container, shared_name). + DenseSet> + mutable_variables_; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_VALUE_TYPED_ANALYZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h new file mode 100644 index 00000000..feb90de1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -0,0 +1,343 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" + +namespace mlir { +namespace TF { +using ResourceId = int64_t; +inline constexpr ResourceId kUnknownResourceId = + ResourceAliasAnalysis::Info::kUnknownResourceId; +static_assert(kUnknownResourceId < 0, "kUnknownResourceId must be < 0"); + +// Maps group IDs to branch IDs. +using ParallelIdsMap = std::map; +using OpToParallelIdsMap = absl::flat_hash_map; + +namespace detail { + +class OpSideEffectCollector; + +using StackResourceToOps = std::vector< + absl::flat_hash_map>>; + +// Side effect analysis info for a single function. +// +// This class provides an interface for querying control predecessors and +// successors for ops of the given function. This information is computed from +// side effects, using resource alias analysis where possible. 
+// Remarks: +// - Control dependencies model execution order constraints for side-effecting +// ops. For example, two ops writing to the same resource cannot switch their +// order and cannot be executed in parallel. +// - A control dependency (A,B) means that op A has to be executed before op B. +// A is a control predecessor of B, and B is a control successor of A. +// - The control dependencies provided by side effect analysis are guaranteed to +// be sufficient for correct execution but they are not guaranteed to be +// minimal (that means, some control dependencies might not be required for +// correct execution). +class SideEffectAnalysisInfo { + public: + SideEffectAnalysisInfo() = default; + + // Constructs analysis info by analyzing the given function. + SideEffectAnalysisInfo(func::FuncOp func_op, + const OpSideEffectCollector& op_side_effect_collector, + const TF::ResourceAliasAnalysis::Info& alias_analysis, + const OpToParallelIdsMap& op_to_parallel_ids) + : op_side_effect_collector_(op_side_effect_collector), + alias_analysis_(alias_analysis), + op_to_parallel_ids_(op_to_parallel_ids) { + AnalyzeFunction(func_op); + } + + // Constructs analysis info by analyzing the given region. + SideEffectAnalysisInfo(Region* region, + const OpSideEffectCollector& op_side_effect_collector, + const TF::ResourceAliasAnalysis::Info& alias_analysis, + const OpToParallelIdsMap& op_to_parallel_ids) + : op_side_effect_collector_(op_side_effect_collector), + alias_analysis_(alias_analysis), + op_to_parallel_ids_(op_to_parallel_ids) { + AnalyzeRegion(region); + } + + SideEffectAnalysisInfo(SideEffectAnalysisInfo&&) = default; + + // Returns a vector of ops that are direct control predecessors of `op`, + // sorted in program order. If `filter` is provided, only predecessors that + // pass the filter (returning true) will be included. + const llvm::SmallVector& DirectControlPredecessors( + Operation* op) const; + llvm::SmallVector DirectControlPredecessors( + Operation* op, llvm::function_ref filter) const; + + // pass the filter (returning true) will be included. + const llvm::SmallVector& DirectControlSuccessors( + Operation* op) const; + llvm::SmallVector DirectControlSuccessors( + Operation* op, llvm::function_ref filter) const; + + // Returns a vector of ops that are control sinks (i.e. side-effecting ops + // with no control successors). + llvm::ArrayRef ControlSinks() const { + return sorted_control_sinks_; + } + + // Returns a vector with IDs of all resources that might be accessed by `op`. + // This includes both op-based and value-based resources. The bool indicates + // whether a resource is accessed read-only. + const llvm::SmallVector>& GetResourceIds( + Operation* op) const; + + // Returns true iff given resource is allocated by op with + // `UniqueResourceAllocation` trait. This can be utilized for while-loop + // parallelization. + bool IsUniqueResourceAllocationId(ResourceId resource_id) const { + return alias_analysis_.IsUniqueResourceAllocationId(resource_id); + } + + const TF::ResourceAliasAnalysis::Info& GetAliasAnalysis() const { + return alias_analysis_; + } + + private: + // Runs the analysis and populates `sorted_control_predecessors_` and + // `sorted_control_successors_` for `func_op`. Clears `control_predecessors_`. + void AnalyzeFunction(func::FuncOp func_op); + + // Runs the analysis and populates `control_predecessors_` for `region`. + void AnalyzeRegion(Region* region); + + // Runs the analysis and populates `control_predecessors_` for `op`. 
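In a transformation, the info object is typically fetched per function and queried op by op; a short usage sketch follows. The GetAnalysisForFunc accessor comes from the PerFunctionAggregateAnalysis base class, and the bool(Operation*) filter signature is our reading of the stripped template argument, so treat both as assumptions.

const auto &info = side_effect_analysis.GetAnalysisForFunc(func_op);
func_op.walk([&](mlir::Operation *op) {
  // Only keep predecessors that are not ancestors of `op` itself.
  auto preds = info.DirectControlPredecessors(
      op, [&](mlir::Operation *pred) { return !pred->isAncestor(op); });
  for (mlir::Operation *pred : preds) {
    // A pass would materialize an explicit control edge pred -> op here,
    // e.g. as a tf_executor control operand.
  }
});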
+ void AnalyzeOp(Operation* op); + + // Updates `control_predecessors_` for given `resource_id` and `op`. + void AddPredecessorsForAccess(ResourceId resource_id, Operation* op, + bool read_only); + + // Updates resource access for given `resource_id` and `op` in + // `per_resource_access_info_` and `op_to_resource_ids_`. + void UpdateAccess(ResourceId resource_id, Operation* op, bool read_only); + + // Returns true iff the last unknown resource access is already indirectly + // tracked by a previous `resource` access. `read_only` specifies the type of + // access considered. + bool IsUnknownAccessIndirectlyTrackedByResource(ResourceId resource, + bool read_only); + + // Returns a set of resource IDs that have potential dependencies to + // `resource_id` (i.e., there are potential dependencies between the + // resources corresponding to the IDs). + llvm::SmallSet GetDependentIds(ResourceId resource_id, + bool is_fetch_op) const; + + // Returns the parallel ids of the op. + ParallelIdsMap GetParallelIdsMap(Operation* op); + + // Converts from read/write state that relates ops with the same parallel id + // to a set of last accesses for use with other parallel ids. Reads/writes + // between parallel ids are conservatively approximated as writes. + absl::flat_hash_set GetLastWrites(ResourceId resource_id); + + // Sets the read/write state for ops within the same parallel id. + void SetLastWrites(ResourceId resource_id, + absl::flat_hash_set last_writes); + + // Enters a sequence of ops that have the same parallel id. This converts + // stack state to per_resource_access_info_. + void Enter(); + + // Exits a sequence of ops that have the same parallel id. This converts + // per_resource_access_info_ to stack state. + void Exit(); + + // Steps down one parallel nesting level (i.e. increase parallel id size + // by 1). + void Down(); + + // Steps laterally between parallel nesting levels. + void Lateral(); + + // Steps up one parallel nesting level. + void Up(); + + // Transitions nesting levels from `from` to `to`. + void Transition(ParallelIdsMap from, ParallelIdsMap to); + + // Transitions nesting levels from the previous parallel id to `to`. + void TransitionToParallelIdsMap(ParallelIdsMap to); + + // Transitions nesting levels from the previous parallel id to `to`. + void TransitionToOp(Operation* to); + + // Initializes stack state for a function. + void InitFunction(); + + // Uninitializes stack state for a function. + void UninitFunction(); + + // Maps from an op to its control predecessors. + llvm::SmallDenseMap, 8> + control_predecessors_; + // Maps from an op to its control predecessors sorted in program order. + llvm::SmallDenseMap, 8> + sorted_control_predecessors_; + // Maps from an op to its control successors sorted in program order. + llvm::SmallDenseMap, 8> + sorted_control_successors_; + // Side-effecting ops with no control successors in this function. + llvm::SmallVector sorted_control_sinks_; + + // Maps from an op to its resource IDs along with a bool indicating if the + // resource is accessed `read-only`. + llvm::SmallDenseMap>> + op_to_resource_ids_; + llvm::SmallVector> empty_resource_ids_; + + // For predecessor / successor queries on ops we don't track. + llvm::SmallVector empty_operation_set_; + + // Internal per-resource data structure for building the dependencies. + struct PerResourceAccessInfo { + // Last writes to resource before the current op is being analyzed. 
In + // general there can be multiple most recent accesses when ops have + // different parallel ids. + absl::flat_hash_set last_writes; + // Read ops since `last_write` before the current op is being analyzed. + llvm::SmallVector reads_since_last_write; + // Whether a previous access of this resource already tracks the last + // unknown read(s). + bool are_last_unknown_reads_tracked = false; + // Whether a previous write access of this resource already tracks the last + // unknown write. + bool is_last_unknown_write_tracked_by_write = false; + // Whether a previous read or write access of this resource already tracks + // the last unknown write. + bool is_last_unknown_write_tracked = false; + }; + + // Resource access info per resource ID. + llvm::SmallDenseMap + per_resource_access_info_; + + // Hold the last set of reads and writes that + // will be depended on by ops with greater nesting depths. + // For example, the last read/write with parallel_ids `{group0:branch0}` + // lives at stack depth 1 and is depended on by ops with parallel_ids + // of the form `{group0:branch0, ...}`. + // + // We track a set of reads/writes rather than a single read/write because + // multiple parallel ops may be live at any particular point. + StackResourceToOps stack_down_; + + // Hold the last set of reads and writes that will be depended on by + // ops with lesser nesting depths. For example, the last read/writes + // with parallel_ids `{group0:branch0}` and `{group0:branch1}` live at + // stack depth 1 and are depended on by ops with parallel_ids `{}`. + StackResourceToOps stack_up_; + + // Parallel ids of the previously traversed op in the same function. + // The transition from the previous parallel_ids to the current parallel_ids + // determines which stack actions occur. + ParallelIdsMap previous_parallel_ids_; + + const OpSideEffectCollector& op_side_effect_collector_; + const TF::ResourceAliasAnalysis::Info& alias_analysis_; + + // Map op to parallel_ids. If an op is not a key then it has empty parallel + // ids, which corresponds to nesting depth 0. + const OpToParallelIdsMap& op_to_parallel_ids_; +}; + +} // namespace detail + +// An analysis that runs on a function and infers the control predecessors and +// successors for each op, based on side effects on known and unknown resources. +// Side-effecting ops on unknown resources are conservatively treated as +// interfering with all known resource op accesses. It distinguishes accesses +// based on whether they are read-only, and read-only ops do not interfere with +// each other. +// +// If there are nested regions, each region is handled separately, and control +// dependencies are only tracked for ops under the same parent op. +class SideEffectAnalysis : public detail::PerFunctionAggregateAnalysis< + detail::SideEffectAnalysisInfo> { + public: + // Constructs analysis by analyzing the given module operation. Because no + // parallel_ids are given, the program has sequential memory semantics. + explicit SideEffectAnalysis(ModuleOp module_op); + + // Constructs analysis by analyzing the given module operation where + // `op_to_parallel_ids` supplies the group to branch map. This is the map + // that is encoded by op attribute `_parallel_execution_ids`. This map is + // used to code which ops should be executed in parallel and which + // ops should be executed in sequence after ops have been flattened. 
+ // For example, children of + // `tf_device.parallel_execute` will be executed in parallel and + // each replica child of a `tf_device.replicate` will be executed in parallel. + // Otherwise, by default, an op's children will be executed in sequence. + // + // Two ops with the same groups and different branches are considered + // parallel so are not made dependent. For example if `OpA` has parallel_ids + // `{group0:branch0, group1:branch0}` + // and `OpB` has parallel_ids + // `{group0:branch1, graph1:branch0}` + // then `OpA` and `OpB` are executed in parallel because `group0` is common + // with a different branch. + // + // Two ops with the same branches between common groups are executed in + // sequence so are made dependent. For example, if `OpA` has parallel_ids + // `{group0:branch0, group1:branch0}` + // and `OpB` has parallel_ids + // `{group0:branch0, group2:branch0}` + // then `OpA` and `OpB` are executed in sequence because the common groups + // have the same branch. + // + // If an op is not in `op_to_parallel_ids` then it is considered to have the + // empty map from groups to branches. + SideEffectAnalysis(ModuleOp module_op, OpToParallelIdsMap op_to_parallel_ids); + + private: + ResourceAliasAnalysis alias_analysis_; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/tf_dataflow.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/tf_dataflow.h new file mode 100644 index 00000000..a7d622c0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/analysis/tf_dataflow.h @@ -0,0 +1,92 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
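The group/branch rule spelled out above reduces to: two ops may run in parallel iff they share at least one group with different branches; if every common group agrees on the branch, they are ordered. A small helper capturing that reading, assuming ParallelIdsMap maps group names to branch names; the helper is ours, purely illustrative.

// Returns true if ops carrying these parallel ids may execute in parallel,
// per the semantics documented for SideEffectAnalysis.
bool MayExecuteInParallel(const mlir::TF::ParallelIdsMap &a,
                          const mlir::TF::ParallelIdsMap &b) {
  for (const auto &[group, branch] : a) {
    auto it = b.find(group);
    if (it != b.end() && it->second != branch) return true;
  }
  return false;  // No common group disagrees, so the ops are sequenced.
}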
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_TF_DATAFLOW_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_TF_DATAFLOW_H_ + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { + +template +class TensorflowDataflowAnalysis + : public dataflow::SparseForwardDataFlowAnalysis> { + public: + using StateT = dataflow::Lattice; + using dataflow::SparseForwardDataFlowAnalysis< + StateT>::SparseForwardDataFlowAnalysis; + using dataflow::SparseForwardDataFlowAnalysis::getLatticeElement; + ~TensorflowDataflowAnalysis() override = default; + + bool ForwardThroughTFOperation(Operation *op, + ArrayRef operands, + ArrayRef results) { + if (auto cast = dyn_cast(op)) { + this->join(results[0], *operands[0]); + } else if (auto while_op = dyn_cast(op)) { + for (auto ®ion : while_op->getRegions()) { + for (auto [arg, value] : + llvm::zip(region.getArguments(), while_op->getOperands())) { + this->join(getLatticeElement(arg), *getLatticeElement(value)); + } + } + } else if (auto while_op = dyn_cast(op)) { + func::FuncOp cond = SymbolTable::lookupNearestSymbolFrom( + while_op, while_op.getCondAttr()); + func::FuncOp body = SymbolTable::lookupNearestSymbolFrom( + while_op, while_op.getBodyAttr()); + for (auto &arg : while_op->getOpOperands()) { + BlockArgument cond_arg = cond.getArgument(arg.getOperandNumber()); + this->join(getLatticeElement(cond_arg), *getLatticeElement(arg.get())); + BlockArgument body_arg = body.getArgument(arg.getOperandNumber()); + this->join(getLatticeElement(body_arg), *getLatticeElement(arg.get())); + } + } else if (auto graph = dyn_cast(op)) { + for (auto &arg : graph.GetFetch()->getOpOperands()) { + if (arg.getOperandNumber() < graph.getNumResults()) { + auto result = graph.getResult(arg.getOperandNumber()); + this->join(getLatticeElement(result), *getLatticeElement(arg.get())); + } + } + } else if (auto island = dyn_cast(op)) { + for (auto &arg : island.GetYield()->getOpOperands()) { + auto result = island.getResult(arg.getOperandNumber()); + this->join(getLatticeElement(result), *getLatticeElement(arg.get())); + } + } else { + return false; + } + return true; + } + + void setToEntryState(StateT *lattice) override { + this->propagateIfChanged( + lattice, lattice->join(L::EntryState(lattice->getAnchor()))); + } +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_TF_DATAFLOW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/dialect_registration.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/dialect_registration.h new file mode 100644 index 00000000..3f8305ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/dialect_registration.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
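A concrete analysis is expected to instantiate the TensorflowDataflowAnalysis template above with its lattice value type and fall back to the framework whenever ForwardThroughTFOperation declines an op. A rough sketch, under the assumption that the MLIR revision in use exposes the LogicalResult-returning visitOperation hook (older revisions return void) and that setAllToEntryStates is the appropriate conservative fallback; the class name is illustrative and need not match the real IsComposite analysis.

class IsCompositeAnalysis
    : public mlir::TF::TensorflowDataflowAnalysis<mlir::TF::IsComposite> {
 public:
  using mlir::TF::TensorflowDataflowAnalysis<
      mlir::TF::IsComposite>::TensorflowDataflowAnalysis;

  mlir::LogicalResult visitOperation(
      mlir::Operation *op, llvm::ArrayRef<const StateT *> operands,
      llvm::ArrayRef<StateT *> results) override {
    if (!ForwardThroughTFOperation(op, operands, results))
      setAllToEntryStates(results);  // Conservative fallback for other ops.
    return mlir::success();
  }
};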
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" // from @llvm-project +#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/core/ir/ops.h" + +namespace mlir { +// Inserts all the TensorFlow dialects in the provided registry. This is +// intended for tools that need to register dialects before parsing .mlir files. +// If include_extensions is set (default), also registers extensions. Otherwise +// it is the responsibility of the caller, typically required when the registry +// is appended to the context in a parallel context, which does not allow for +// extensions to be added. +inline void RegisterAllTensorFlowDialectsImpl(DialectRegistry ®istry, + bool include_extensions = true) { + registry + .insert(); + if (include_extensions) { + mlir::func::registerAllExtensions(registry); + } +} + +// Inserts all the TensorFlow dialects in the provided registry. This is +// intended for tools that need to register dialects before parsing .mlir files. +inline void RegisterAllTensorFlowDialects(DialectRegistry ®istry) { + RegisterAllTensorFlowDialectsImpl(registry, true); +} +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_DIALECT_REGISTRATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h new file mode 100644 index 00000000..73243e2f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
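Tool-side usage of the registration helper above is just a registry, the call, and a context; a brief sketch (variable names are ours).

mlir::DialectRegistry registry;
mlir::RegisterAllTensorFlowDialects(registry);
mlir::MLIRContext context(registry);
// The context can now load/parse .mlir files that use the tf, tf_executor,
// tf_device, tf_saved_model, and tfg dialects registered above.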
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_HOST_RUNTIME_TFRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_HOST_RUNTIME_TFRT_OPS_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_HOST_RUNTIME_TFRT_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h new file mode 100644 index 00000000..64b5d2e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h @@ -0,0 +1,133 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { + +class Operation; + +namespace TF { + +class AddV2Op; +class SubOp; +class MulOp; +class DivOp; +class RealDivOp; + +// Verifies an reduction op's `input` and reduction `dims`. +LogicalResult VerifyReductionInputAndDims(Value input, Value dims, + Location loc); + +// A type range with description (in singular form) attached to it. +using TypeRangeWithDesc = std::pair; + +LogicalResult VerifyTypeRangesAreCompatible(Operation *op, + TypeRangeWithDesc range0, + TypeRangeWithDesc range1); + +// Fold Arithmetic Op if one of the operands is a constant known to be an +// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if +// known identity value is either lhs or rhs. 
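Before the implementation below, a note on where this folder is wired in: it is normally called from an op's fold hook. A hedged sketch of that call site, assuming the FoldAdaptor-based fold signature generated for current op definitions.

OpFoldResult AddV2Op::fold(FoldAdaptor adaptor) {
  // Folds x + 0 -> x, and 0 + x -> x since AddV2 is commutative.
  return IdentityArithmeticOpFolder<AddV2Op>(*this, adaptor.getOperands());
}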
+template <
+    typename OpT,
+    typename std::enable_if<llvm::is_one_of<
+        OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
+OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
+                                        ArrayRef<Attribute> operands) {
+  auto lhs_type = mlir::cast<ShapedType>(arithmetic_op.getX().getType());
+  auto rhs_type = mlir::cast<ShapedType>(arithmetic_op.getY().getType());
+  auto result_type =
+      mlir::cast<ShapedType>(arithmetic_op.getResult().getType());
+
+  // We can fold an arithmetic operation only if we can prove that we will not
+  // accidentally hide a broadcasting error.
+  auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
+                                  ShapedType result_ty) -> bool {
+    // A scalar identity is broadcastable to any operand shape; we only need to
+    // check that the operand has the same shape as the result.
+    bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
+    if (scalar_identity) return operand_ty == result_ty;
+
+    // If the identity is not a scalar, we must verify that the identity shape
+    // is statically known to be broadcastable to the operand shape and that
+    // the operand and result shapes are equal.
+    return operand_ty == result_ty && identity_ty.hasStaticShape() &&
+           result_ty.hasStaticShape() &&
+           OpTrait::util::staticallyKnownBroadcastable(operand_ty.getShape(),
+                                                       identity_ty.getShape());
+  };
+
+  // Check that we have a constant operand on one side (candidate for identity).
+  const bool is_commutative =
+      (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
+  auto lhs_attr = mlir::dyn_cast_or_null<DenseElementsAttr>(operands[0]);
+  auto rhs_attr = mlir::dyn_cast_or_null<DenseElementsAttr>(operands[1]);
+  if (!rhs_attr && !(is_commutative && lhs_attr)) return {};
+
+  // Mul and Div ops have identity value one while AddV2 and Sub ops have
+  // identity value zero.
+  const int identity =
+      (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
+       std::is_same<OpT, RealDivOp>::value)
+          ? 1
+          : 0;
+
+  Type element_ty = lhs_type.getElementType();
+  Attribute identity_attr;
+  if (auto ty = mlir::dyn_cast<FloatType>(element_ty)) {
+    identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
+  } else if (auto ty = mlir::dyn_cast<IntegerType>(element_ty)) {
+    identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
+  } else {
+    return {};
+  }
+
+  // Fold: Op(Operand, Identity) -> Operand.
+  if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
+    if (rhs_attr.isSplat() &&
+        rhs_attr.getSplatValue<Attribute>() == identity_attr)
+      return arithmetic_op.getX();
+  }
+
+  // Fold: Op(Identity, Operand) -> Operand for commutative operations.
+  if (lhs_attr && is_commutative &&
+      is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
+    if (lhs_attr.isSplat() &&
+        lhs_attr.getSplatValue<Attribute>() == identity_attr)
+      return arithmetic_op.getY();
+  }
+
+  return {};
+}
+
+} // namespace TF
+} // namespace mlir
+
+#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ARITH_OPS_FOLDER_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
new file mode 100644
index 00000000..d5223870
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the attributes used in the TensorFlow dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ + +#include "tensorflow/core/ir/types/dialect.h" + +namespace mlir { +namespace TF { + +// This all moved under tensorflow/core/ir/types and these using declaration are +// to help with the transition. +using mlir::tf_type::FuncAttr; // NOLINT +using mlir::tf_type::PlaceholderAttr; // NOLINT +using mlir::tf_type::ShapeAttr; // NOLINT +using mlir::tf_type::TensorProtoAttr; // NOLINT + +} // end namespace TF +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h new file mode 100644 index 00000000..0c7ff33e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the tf_device dialect: it contains operations that model +// TensorFlow's actions to launch computations on accelerator devices. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +namespace mlir { +namespace tf_device { + +// The TensorFlow Device dialect. +// +// This dialect contains operations to describe/launch computations on devices. +// These operations do not map 1-1 to TensorFlow ops and requires a lowering +// pass later to transform them into Compile/Run op pairs, like XlaCompile and +// XlaRun. +class TensorFlowDeviceDialect : public Dialect { + public: + static StringRef getDialectNamespace() { return "tf_device"; } + // Constructing TensorFlowDevice dialect under an non-null MLIRContext. 
+ explicit TensorFlowDeviceDialect(MLIRContext* context); +}; + +} // namespace tf_device +} // namespace mlir + +// Declares the operations for this dialect using the generated header. +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h new file mode 100644 index 00000000..cad01806 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h @@ -0,0 +1,120 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the standard MLIR TensorFlow dialect after control +// dependences are raise to the standard form. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ + +#include +#include + +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +class TensorFlowRegistryEffectInterfaceFallback; + +class TensorFlowDialect final : public Dialect { + public: + explicit TensorFlowDialect(MLIRContext *context); + ~TensorFlowDialect() override; + + static StringRef getDialectNamespace() { return "tf"; } + + // Overrides to redirect to tf_type dialect. + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; + Type parseType(DialectAsmParser &parser) const override; + + // Gradient attribute ("tf.gradient") in the list of NamedAttributes in a + // function references to its gradient function. This attribute in TensorFlow + // Dialect is used to model TF GradientDef. GetGradientAttrName() returns the + // string description of gradient attribute. + static StringRef GetGradientAttrName() { return "tf.gradient"; } + + // This attribute marks if a function is stateful. + // Returns the string description of stateful attribute. + static StringRef GetStatefulAttrName() { return "tf.signature.is_stateful"; } + + // Returns true if the op can be duplicated during transformations. + static bool CanDuplicate(Operation *op); + + // Returns true if the op can have side effects. + static bool CanHaveSideEffects(Operation *op); + + // Registered hook to materialize a constant operation from a given attribute + // value with the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; + + typedef std::function AdditionalOpFunction; + + // Register an op registration hook which is invoked during construction. + // + // A hook may use the public addOperations() method to add additional + // operations to the dialect. 
Hooks will only apply to subsequent + // instantations of the Dialect/MLIRContext. + static void RegisterAdditionalOperationHook(TypeID uniqueId, + AdditionalOpFunction fn); + + // Re-define publicly the protected addOperations() method from the Dialect + // class, usually used in a Dialect constructor. This allows hook + // functions to register operations on the TensorFlow dialect using the + // same interface. + template + void addOperations() { + Dialect::addOperations(); + } + + using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef, + SmallVectorImpl &); + static void RegisterConstantFoldHook(ConstantFoldHook fn) { + constant_fold_hook_ = std::move(fn); + } + + static LogicalResult constantFold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + if (constant_fold_hook_) return constant_fold_hook_(op, operands, results); + return failure(); + } + + static bool HasConstantFoldHook() { return constant_fold_hook_; } + + // Provides a hook for op interface. + void *getRegisteredInterfaceForOp(mlir::TypeID interface, + mlir::OperationName opName) override; + + private: + static ConstantFoldHook constant_fold_hook_; + + // Storage for a custom fallback interface. + TensorFlowRegistryEffectInterfaceFallback *fallback_effect_op_interface_; +}; + +} // namespace TF +} // namespace mlir + +#define TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS(hookFn) \ + { \ + static bool key; \ + ::mlir::TF::TensorFlowDialect::RegisterAdditionalOperationHook( \ + ::mlir::TypeID::getFromOpaquePointer(&key), hookFn); \ + } + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DIALECT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h new file mode 100644 index 00000000..a3c95bdf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the tf_executor dialect: it models the TensorFlow executor +// semantics and can represent arbitrary TensorFlow graphs. As such it follows +// the existing execution model that includes deadness propagation, concurrent +// semantics, and control dependencies. 
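The TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS macro above manufactures a unique TypeID from a function-local static and forwards the callable, so an extension library can append its generated ops before the dialect is instantiated. A sketch with a hypothetical op name, assuming AdditionalOpFunction is callable with a TensorFlowDialect reference (its exact std::function signature was not preserved above).

// Called once from the extension's initialization path; hooks only affect
// TensorFlowDialect instances created afterwards.
void RegisterMyExtraTfOps() {
  TF_DIALECT_REGISTER_ADDITIONAL_OPERATIONS(
      [](mlir::TF::TensorFlowDialect &dialect) {
        // MyExtraOp stands in for a real generated op class.
        dialect.addOperations<MyExtraOp>();
      })
}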
+ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace tf_executor { + +class TensorFlowExecutorDialect : public Dialect { + public: + static StringRef getDialectNamespace() { return "tf_executor"; } + explicit TensorFlowExecutorDialect(MLIRContext *context); + + // Parses a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + // Prints a type registered to this dialect. + void printType(Type type, DialectAsmPrinter &os) const override; +}; + +// The Control type is a token-like value that models control dependencies from +// TensorFlow graphs. +class ControlType : public Type::TypeBase { + public: + using Base::Base; + static constexpr ::mlir::StringLiteral name = "tf_executor.control"; +}; + +class TokenType : public Type::TypeBase { + public: + using Base::Base; + static constexpr ::mlir::StringLiteral name = "tf_executor.token"; +}; + +} // namespace tf_executor +} // namespace mlir + +// Declares the operations for this dialect using the generated header. +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h new file mode 100644 index 00000000..db820889 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h @@ -0,0 +1,166 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ + +#include + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace mlir { +namespace TF { + +//===----------------------------------------------------------------------===// +// TensorFlow Contraction Fusion. +//===----------------------------------------------------------------------===// + +struct ContractionFusion { + explicit ContractionFusion( + StringRef output_kernel, ArrayRef additional_arguments = {}, + ArrayRef additional_attributes = {}) + : output_kernel(output_kernel.str()), + additional_arguments(additional_arguments.begin(), + additional_arguments.end()), + additional_attributes(additional_attributes.begin(), + additional_attributes.end()) {} + + // Name of the output kernel implementing the contraction fusion. + std::string output_kernel; + + // Indices of additional arguments that will be forwarded to the fused + // operation (e.g. forward bias vector if fusing BiasAdd operation). + SmallVector additional_arguments; + + // Add additional attributes to the fused node. + SmallVector additional_attributes; +}; + +//===----------------------------------------------------------------------===// +// TensorFlow Resource Handles. +//===----------------------------------------------------------------------===// + +inline bool IsResourceHandleAnonymous(StringRef name) { + return name == ::tensorflow::ResourceHandle::ANONYMOUS_NAME; +} + +// Helper struct representing an identifier for a resource handle. For resource +// handles created explicitly and shared across resource allocator ops, +// `container`, `name`, and `device` can be set. If an resource handle is tied +// to an instance of an operation (e.g. TensorFlow runtime operation caching), +// `op` can be set instead. +struct ResourceHandle { + ResourceHandle(StringRef container, StringRef name, StringRef device, + Operation* op) + : container(container), name(name), device(device), op(op) {} + + bool operator==(const ResourceHandle& rhs) const { + return container == rhs.container && name == rhs.name && + device == rhs.device && op == rhs.op; + } + + // Make ResourceHandle hashable. + friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle); + + StringRef container; + StringRef name; + StringRef device; + Operation* op = nullptr; +}; + +// Make ResourceHandle hashable. +inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) { + return ::llvm::hash_combine(resource_handle.container, resource_handle.name, + resource_handle.device, resource_handle.op); +} + +// Helper struct holding a resource handle value and unique id associated to the +// resource handle. +struct ResourceHandleValueAndId { + ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {} + + Value value; + int64_t id = -1; +}; + +//===----------------------------------------------------------------------===// +// TF op helper functions for handling resource handles and ids. 
+//===----------------------------------------------------------------------===// + +// Returns device of op if present. If op has no device set, an empty string ref +// is returned instead. +llvm::StringRef GetDeviceOrEmpty(Operation* op); + +// Returns resource handle value and id for resource op based on attributes. If +// a resource handle is anonymous, a new id is always returned. +ResourceHandleValueAndId GetResourceHandleValueAndIdBase( + llvm::StringRef container, llvm::StringRef shared_name, + llvm::StringRef device, Value resource, + llvm::SmallDenseMap& resource_handle_id_map, + int64_t& next_id); + +// Shape functions for ops that are using TF_SameOperandsAndResultTypeResolveRef +// and have at least one operand, result type can be inferred using the first +// operand's type. + +#define INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Op) \ + LogicalResult Op::inferReturnTypeComponents( \ + MLIRContext* context, std::optional location, \ + ValueShapeRange operands, DictionaryAttr attributes, \ + OpaqueProperties properties, RegionRange regions, \ + SmallVectorImpl& inferredReturnShapes) { \ + return inferReturnTypeComponentsFromOperands( \ + context, location, operands, attributes, properties, regions, \ + inferredReturnShapes); \ + } + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc" +} // namespace TF +} // namespace mlir + +namespace llvm { +template <> +struct DenseMapInfo { + static mlir::TF::ResourceHandle getEmptyKey() { + return {/*container=*/"", /*name=*/"", /*device=*/"", + /*op=*/DenseMapInfo::getEmptyKey()}; + } + + static mlir::TF::ResourceHandle getTombstoneKey() { + return {/*container=*/"", /*name=*/"", /*device=*/"", + /*op=*/DenseMapInfo::getTombstoneKey()}; + } + + static unsigned getHashValue( + const mlir::TF::ResourceHandle& resource_handle) { + return mlir::TF::hash_value(resource_handle); + } + + static bool isEqual(const mlir::TF::ResourceHandle& lhs, + const mlir::TF::ResourceHandle& rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h new file mode 100644 index 00000000..30c503aa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the standard MLIR TensorFlow dialect +// after control dependences are raise to the standard form. 
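The llvm::DenseMapInfo specialization above is what lets ResourceHandle key a DenseMap, which is how GetResourceHandleValueAndIdBase can hand out one stable id per (container, name, device, op) tuple. A minimal sketch of that pattern; the values are illustrative.

llvm::SmallDenseMap<mlir::TF::ResourceHandle, int64_t> handle_to_id;
int64_t next_id = 0;

mlir::TF::ResourceHandle handle(/*container=*/"", /*name=*/"v0",
                                /*device=*/"/CPU:0", /*op=*/nullptr);
auto inserted = handle_to_id.try_emplace(handle, next_id);
if (inserted.second) ++next_id;       // First time this handle is seen.
int64_t id = inserted.first->second;  // Stable id for this handle.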
+ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h new file mode 100644 index 00000000..2956174f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +// IWYU pragma: private, include "third_party/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { + +class YieldOp; + +} // namespace TF +} // namespace mlir + +// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose +// purpose is to catch bug on `tensorflow::mutex_lock`. We don't use +// `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and +// `tf.ConsumeMutexLock`) with getter methods named as `mutex_lock()`. Need to +// undefine here to avoid expanding the getter symbol as macro when including +// both mutex.h and this header file. +#undef mutex_lock + +#define GET_OP_FWD_DEFINES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h new file mode 100644 index 00000000..fa171a00 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h @@ -0,0 +1,62 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_CANONICALIZATION_HELPER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_CANONICALIZATION_HELPER_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" + +namespace mlir { +namespace TF { + +// Eliminate attributes that are not needed, but can get attached to Ops +// during import. +template +struct DropAttributes : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Drop the "output_shapes" attribute. + LogicalResult matchAndRewrite(Op op, + PatternRewriter &rewriter) const override { + bool found = !!op->removeAttr("output_shapes"); + return success(found); + } +}; + +// Helper function to create TF op while copying all underscore attributes from +// another TF op. +// TODO(jpienaar): This is a workaround until behavior is established. +template +OpTy CreateTfOp(RewriterBase &b, Operation *op, Args &&...args) { + auto ret = b.create(op->getLoc(), std::forward(args)...); + CopyDeviceAndUnderscoredAttributes(op, ret.getOperation()); + return ret; +} + +// Helper function to replace TF op with another op while copying all underscore +// attributes from the TF op. +// TODO(jpienaar): This is a workaround until behavior is established. +template +OpTy ReplaceTfOpWithNewOp(RewriterBase &b, Operation *op, Args &&...args) { + auto ret = CreateTfOp(b, op, std::forward(args)...); + b.replaceOp(op, ret.getOperation()->getResults()); + return ret; +} + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_CANONICALIZATION_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h new file mode 100644 index 00000000..4657fb18 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_DEVICE_HELPER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_DEVICE_HELPER_H_ + +namespace mlir { + +class Operation; + +namespace TF { + +class RuntimeDevices; + +// Returns true if at least one GPU device is available at runtime. +bool CanUseGpuDevice(const RuntimeDevices &devices); + +// Returns true if all of the GPUs available at runtime support TensorCores +// (NVIDIA compute capability >= 7.0). +bool CanUseTensorCores(const RuntimeDevices &devices); + +// Returns true if operation does not have explicit device placement that would +// prevent it from running on GPU device. 
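For orientation, a minimal sketch of how the canonicalization helper above is typically wired in; the registration function below is hypothetical, and only `DropAttributes` and the TF op class come from the vendored headers. `CreateTfOp` / `ReplaceTfOpWithNewOp` are used the same way inside `matchAndRewrite` bodies, forwarding whatever builder arguments the new op type expects.

#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h"

// Hypothetical registration hook: drops the imported-but-unused
// "output_shapes" attribute from tf.If via the generic DropAttributes pattern.
void AddDropOutputShapesPattern(mlir::RewritePatternSet &patterns,
                                mlir::MLIRContext *context) {
  patterns.add<mlir::TF::DropAttributes<mlir::TF::IfOp>>(context);
}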
+bool CanUseGpuDevice(Operation *op); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_DEVICE_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h new file mode 100644 index 00000000..29dae271 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h @@ -0,0 +1,137 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" + +namespace mlir { + +class MLIRContext; + +namespace TF { + +SmallVector ReversePermutation(ArrayRef permutation); + +SmallVector GetDataFormatPermutation(StringRef from, StringRef to); + +// Shuffle elements in the `attr` according to the permutation. Optional +// `inner_size` allows to shuffle array attributes created from rank 2 tensors +// on outer dimension only. +ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef permutation, + int inner_size = 1); + +// Shuffle ranked tensor dimensions according to the permutation. +Type ShuffleRankedTensorType(Type type, ArrayRef permutation); + +bool AreCancellablePermutations(DenseIntElementsAttr perm0, + DenseIntElementsAttr perm1); + +// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for +// layout sensitive operations that do not have any additional layout dependent +// attributes besides `data_format` string. +template +LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { + auto perm = GetDataFormatPermutation(op->getDataFormat(), data_format); + if (perm.empty()) return failure(); + + // Update data format attribute. + (*op)->setAttr("data_format", StringAttr::get(op->getContext(), data_format)); + + // Update types for all layout sensitive results. + auto layout_sensitive = cast(op->getOperation()); + for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) { + OpResult result = op->getOperation()->getResult(idx); + result.setType(ShuffleRankedTensorType(result.getType(), perm)); + } + + return success(); +} + +// Default implementation for folding operand transpose into the operation. +// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`. 
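A sketch of how these layout defaults are meant to be consumed: a concrete layout-sensitive op usually just forwards its interface hooks to the helpers. `OpTy` is a placeholder for such an op class (one that provides `getDataFormat()` and implements the layout interfaces); the forwarding functions themselves are illustrative, not part of the header.

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"

// Rewrites the data_format attribute and shuffles layout-dependent result
// types (e.g. NHWC -> NCHW) using the default helper.
template <typename OpTy>
mlir::LogicalResult DelegateUpdateDataFormat(OpTy op,
                                             llvm::StringRef data_format) {
  return mlir::TF::UpdateDataFormat(data_format, &op);
}

// Cancels an adjacent NHWC<->NCHW transpose on the operands by flipping the
// op's own data_format (see FoldOperandsPermutation below).
template <typename OpTy>
mlir::LogicalResult DelegateFoldOperandsPermutation(
    OpTy op, llvm::ArrayRef<int64_t> permutation) {
  return mlir::TF::FoldOperandsPermutation(permutation, &op);
}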
+template +LogicalResult FoldOperandsPermutation( + ArrayRef permutation, Op *op, + ArrayRef> shuffle_attrs = {}) { + MLIRContext *context = + (*op)->template getParentOfType().getContext(); + + // We only support NHWC <-> NCHW permutations. + static constexpr std::array kNchwToNhwc = {0, 2, 3, 1}; + static constexpr std::array kNhwcToNchw = {0, 3, 1, 2}; + + // Operation data format after folding `permutation`. + StringRef target_data_format = [&]() -> StringRef { + if (op->getDataFormat() == "NHWC" && permutation.equals(kNchwToNhwc)) { + return "NCHW"; // cancel NCHW->NHWC operand permutation + } else if (op->getDataFormat() == "NCHW" && + permutation.equals(kNhwcToNchw)) { + return "NHWC"; // cancel NHWC->NCHW operand permutation + } else { + return ""; + } + }(); + if (target_data_format.empty()) return failure(); + + // To fold operand `permutation` into the `op` we need shuffle all layout + // dependent attributes and types with a reverse permutation, and change + // operation data format to `target_data_format`. + // + // Example: + // %1 = SomeOp(...) {data_format = NHWC} + // %2 = Transpose(%1) {permutation = NHWC->NCHW} + // %3 = Op(%2) {data_format = NCHW} + // + // To bypass %2 we have to change data format to shuffle data format from NCHW + // to NHWC, which is the reverse of operand permutation (function argument). + auto reverse_permutation = + GetDataFormatPermutation(op->getDataFormat(), target_data_format); + if (reverse_permutation.empty()) return failure(); + + (*op)->setAttr("data_format", StringAttr::get(context, target_data_format)); + + for (auto pair : shuffle_attrs) { + StringRef attr_name = pair.first; + ArrayAttr attr_value = pair.second; + (*op)->setAttr(attr_name, + ShuffleArrayAttr(attr_value, reverse_permutation)); + } + + auto fold = cast(op->getOperation()); + for (unsigned idx : fold.GetLayoutDependentResults()) { + OpResult result = op->getOperation()->getResult(idx); + result.setType( + ShuffleRankedTensorType(result.getType(), reverse_permutation)); + } + + return success(); +} + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h new file mode 100644 index 00000000..7812cc4c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +// IWYU pragma: private, include "third_party/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +#define GET_OP_FWD_DEFINES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h new file mode 100644 index 00000000..e77ea7d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h @@ -0,0 +1,94 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_TENSOR_HELPER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_TENSOR_HELPER_H_ + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { + +class Builder; + +namespace TF { + +// Returns the RankedTensorType for the given operand. 
TensorFlow constant ops +// may have non-static shape because the shape is not propagated during constant +// folding. If the defining op for the given operand is a constant op, this +// routine uses the constant op's attribute to get the actual shape. +RankedTensorType GetRankedTensorTypeForOperand(Value operand); + +// Returns true if the given `value` is of ranked float tensor type with the +// given `rank`. +inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { + return type && type.getRank() == rank && + mlir::isa(type.getElementType()); +} + +// Returns true if the given `value` has the specified rank or has unranked +// type. +inline bool IsOfRankOrUnranked(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() == rank; +} + +// Returns true if the given `value` has at least the specified rank or has +// unranked type. +inline bool HasRankAtLeast(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() >= rank; +} + +// Returns true if the given `value` has at most the specified rank or has +// unranked type. +inline bool HasRankAtMost(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() <= rank; +} + +inline bool IsUnknownDimOrRank(int64_t dim_or_rank) { + return dim_or_rank == -1; +} + +// Returns dimension index for the given TensorFlow axis that supports negative +// indexing. +inline int64_t GetDimForAxis(int64_t axis, int64_t rank) { + return axis >= 0 ? axis : axis + rank; +} + +// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If +// `incompatible_shape_error` is true, reports error if `x` and `y` has +// incompatible shapes. Otherwise, returns a tensor type with unknown rank. +Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, Value y, + BoolAttr incompatible_shape_error); + +Type InferReductionOpType(Value input, Value reduction_indices, + BoolAttr keep_dims); + +// Verifies that the given types are cast compatible. If not, emits appropriate +// error for the given op. If mask_one_dim is set to true, then the types are +// allowed to have one mismatching dimension. Masking one of the dimensions is +// useful for ops like Concat that requires all ranked inputs to have the same +// rank and match dimension sizes for all but one of the dimensions. +LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types, + bool mask_one_dim, Operation *op); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_TENSOR_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h new file mode 100644 index 00000000..8e9f32cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
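These rank helpers are the building blocks of many op verifiers. A small illustrative check follows; the op being verified and its operand requirements are hypothetical, only the helper calls come from the header.

#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h"

// Accepts unranked operands, but requires ranked operands to be at least 2-D.
mlir::LogicalResult VerifyMatMulLikeOperands(mlir::Operation *op) {
  for (mlir::Value operand : op->getOperands())
    if (!mlir::TF::HasRankAtLeast(operand, 2))
      return op->emitOpError("requires operands of rank >= 2");
  return mlir::success();
}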
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +#define GET_OP_FWD_DEFINES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h new file mode 100644 index 00000000..208cd7ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -0,0 +1,123 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project + +namespace mlir { +namespace tf_saved_model { + +// The name of the attribute indicating under what name an object is exported. 
+inline constexpr StringRef kTfSavedModelExportedNamesAttr = + "tf_saved_model.exported_names"; + +// The name of the attribute attached to input arguments or results of a +// function to represent the path which one would use to index into a structured +// value to reach a given tensor. +inline constexpr StringRef kTfSavedModelIndexPathAttr = + "tf_saved_model.index_path"; + +// Name of the attribute that inidicates the type of initializer. It should be +// on a function and the function should exist in the initializers attribute of +// the SessionInitializerOp. +inline constexpr StringRef kTfSavedModelInitializerTypeAttr = + "tf_saved_model.initializer_type"; + +// Indicates that the initializer corresponds to the restore op. +inline constexpr StringRef kTfSavedModelInitializerRestoreType = "restore_op"; + +// Indicates that the initializer corresponds to the init op. +inline constexpr StringRef kTfSavedModelInitializerInitType = "init_op"; + +class TensorFlowSavedModelDialect : public Dialect { + public: + explicit TensorFlowSavedModelDialect(MLIRContext *context); + LogicalResult verifyRegionArgAttribute(Operation *op, unsigned region_index, + unsigned arg_index, + NamedAttribute named_attr) override; + LogicalResult verifyRegionResultAttribute(Operation *op, + unsigned region_index, + unsigned result_index, + NamedAttribute named_attr) override; + LogicalResult verifyOperationAttribute(Operation *op, + NamedAttribute named_attr) override; + + static StringRef getDialectNamespace() { return "tf_saved_model"; } +}; + +} // namespace tf_saved_model +} // namespace mlir + +// Declares the operations for this dialect using the generated header. +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h.inc" + +namespace mlir { +namespace tf_saved_model { + +// Returns the list of exported names for `op`. +// An empty list means `op` is not exported. +SmallVector GetExportedNames(Operation *op); + +// Returns true if `op` is exported. +bool IsExported(Operation *op); + +// Returns true if `module` has tf_saved_model linkage semantics. +bool HasTfSavedModelSemantics(ModuleOp module_op); + +// Returns the tf_saved_model.global_tensor op that func's arg_index'th argument +// refers to as a bound input, or null. +Operation *LookupBoundInput(func::FuncOp func, int arg_index, + const SymbolTable &symbol_table); + +template +T LookupBoundInputOfType(func::FuncOp func, int arg_index, + const SymbolTable &symbol_table) { + return llvm::dyn_cast_or_null( + LookupBoundInput(func, arg_index, symbol_table)); +} + +// Gets the type that an exported function arg that is bound to symbol ops such +// as `global_tensor` and `asset` should have. +Type GetBoundInputArgTypeFor(mlir::Operation *op); + +// Returns the session initializer of this module if it exists. Returns null +// otherwise. +SessionInitializerOp GetSessionInitializerOp(ModuleOp module_op); + +// Returns the exported name for the session initializer function. +SmallVector GetSessionInitializerExportedName(ModuleOp module_op); + +// Returns initializer function ops. These functions' symbols are in the +// "initializers" attribute of the session initializer op. +SmallVector GetInitializerFunctions(ModuleOp module_op); + +// Returns the initializer function whose `tf_saved_model.initializer_type` +// attribute matches `initializer_type`. Returns a null op if it doesn't exist. 
+func::FuncOp GetInitializerFunction(ModuleOp module_op, + StringRef initializer_type); + +// Checks if the module restores variables from a Checkpoint. +bool IsRestoreGraph(ModuleOp module); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h new file mode 100644 index 00000000..6384d077 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the side effect definition file for TensorFlow. +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_ + +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +namespace mlir { +namespace TF { +namespace ResourceEffects { + +struct Variable : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Variable"; } +}; + +struct Stack : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Stack"; } +}; + +struct TensorArray : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "TensorArray"; } +}; + +struct Summary : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Summary"; } +}; + +struct LookupTable : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "LookupTable"; } +}; + +struct DatasetSeedGenerator + : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "DatasetSeedGenerator"; } +}; + +struct DatasetMemoryCache + : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "DatasetMemoryCache"; } +}; + +struct DatasetIterator : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "DatasetIterator"; } +}; + +// Special resource type to track TPU Embedding specific ops, which must execute +// but do not have side effects with one another or with resource variable ops. +struct TPUEmbedding : ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "TPUEmbedding"; } +}; + +// Resource corresponding to GeneratorOp. 
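A sketch of how this saved-model API is commonly driven; the inspection function and its logging are illustrative, while the `tf_saved_model` calls and attribute constants are the ones declared above.

#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"

void InspectSavedModel(mlir::ModuleOp module) {
  if (!mlir::tf_saved_model::HasTfSavedModelSemantics(module)) return;
  // List every exported function and the names it is exported under.
  module.walk([](mlir::func::FuncOp func) {
    for (llvm::StringRef name : mlir::tf_saved_model::GetExportedNames(func))
      llvm::outs() << func.getName() << " exported as " << name << "\n";
  });
  // Find the initializer that restores variables from a checkpoint, if any.
  if (auto restore = mlir::tf_saved_model::GetInitializerFunction(
          module, mlir::tf_saved_model::kTfSavedModelInitializerRestoreType))
    llvm::outs() << "restore initializer: " << restore.getName() << "\n";
}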
+struct GeneratorOp : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Generator"; } +}; + +struct Send : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Send"; } +}; + +struct Recv : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "Recv"; } +}; + +struct XlaHostCompute + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "XlaHostCompute"; } +}; + +struct RandomGenerator + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "RandomGenerator"; } +}; + +struct TPUExecute : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "TPUExecute"; } +}; + +struct MustExecute : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "MustExecute"; } +}; + +struct CollectiveReduceOrdering + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "CollectiveReduceOrdering"; } +}; + +struct NcclAllReduceOrdering + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "NcclAllReduceOrdering"; } +}; + +struct GlobalIterId : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "GlobalIterId"; } +}; + +struct XlaLaunch : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "XlaLaunch"; } +}; + +struct WriteTrainingPredictions + : public ::mlir::SideEffects::Resource::Base { + StringRef getName() final { return "WriteTrainingPredictions"; } +}; + +struct _XlaRun : public ::mlir::SideEffects::Resource::Base<_XlaRun> { + StringRef getName() final { return "_XlaRun"; } +}; + +// Returns true iff resource type with given ID is only self-dependent, i.e., +// there are no dependencies to other resource types (including unknown resource +// type). +inline bool IsOnlySelfDependent(TypeID resource_type_id) { + return resource_type_id == ResourceEffects::Send::getResourceID() || + resource_type_id == ResourceEffects::Recv::getResourceID(); +} + +} // namespace ResourceEffects +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SIDE_EFFECTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h new file mode 100644 index 00000000..2c6fd05d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the types used in the standard MLIR TensorFlow dialect. 
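For context, the resource types above are consumed through MLIR's side-effect interface. The helper below is an illustrative stand-in for what a variable-writing op reports from `getEffects()`; the function name is hypothetical, the effect/resource pairing mirrors the pattern used by the generated TensorFlow ops.

#include "llvm/ADT/SmallVector.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"

// Records a write to the shared Variable resource, which orders this op with
// every other op that reads or writes resource variables.
void AppendVariableWriteEffect(
    llvm::SmallVectorImpl<mlir::SideEffects::EffectInstance<
        mlir::MemoryEffects::Effect>> &effects) {
  effects.emplace_back(mlir::MemoryEffects::Write::get(),
                       mlir::TF::ResourceEffects::Variable::get());
}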
+ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_ + +#include + +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/core/ir/types/dialect.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace mlir { +namespace TF { + +using GpuDeviceMetadata = tf_type::GpuDeviceMetadataAttr; + +// Tensorflow devices available at runtime with corresponding metadata if it is +// available. It's completely valid to have a device without any metadata +// attached to it. +class RuntimeDevices { + using DeviceNameUtils = ::tensorflow::DeviceNameUtils; + using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; + + public: + // Adds a device with and empty metadata. Device can be of any type. + void AddDevice(const ParsedName& device); + + // Adds a GPU device with GPU specific metadata. + void AddGpuDevice(const ParsedName& device, + const GpuDeviceMetadata& metadata); + + llvm::ArrayRef device_names() const { return device_names_; } + size_t NumDevices() const { return device_names_.size(); } + + // Returns GPU device metadata if it is available, otherwise returns None. + std::optional GetGpuDeviceMetadata( + const ParsedName& device) const; + + private: + llvm::SmallVector device_names_; + // TODO(ezhulenev): Add DenseMapInfo specialization to be able to + // use ParsedName as a key in a DenseMap. + llvm::StringMap gpu_metadata_; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_STRUCTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h new file mode 100644 index 00000000..c6abd768 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -0,0 +1,329 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the op traits used in the MLIR TensorFlow dialect. 
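A brief sketch of populating `RuntimeDevices` from fully qualified device names; parsing via `DeviceNameUtils::ParseFullName` is the assumed entry point here, and GPU metadata would normally be attached separately through `AddGpuDevice`.

#include <string>

#include "llvm/ADT/ArrayRef.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/core/util/device_name_utils.h"

mlir::TF::RuntimeDevices CollectDevices(llvm::ArrayRef<std::string> names) {
  mlir::TF::RuntimeDevices devices;
  for (const std::string &name : names) {
    tensorflow::DeviceNameUtils::ParsedName parsed;
    // Expects names like "/job:worker/replica:0/task:0/device:GPU:0".
    if (tensorflow::DeviceNameUtils::ParseFullName(name, &parsed))
      devices.AddDevice(parsed);
  }
  return devices;
}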
+ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ + +#include + +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace OpTrait { +namespace TF { + +// Verifies if 'ref_type' is a REF type corresponding to 'type'. +static inline LogicalResult VerifyRefTypeMatch(mlir::Type type, + mlir::Type maybe_ref_type) { + if (auto ref_type = + mlir::dyn_cast(maybe_ref_type)) + return success(ref_type.RemoveRef().getTypeID() == type.getTypeID()); + return failure(); +} + +// This class provides verification for ops that are known to have the same +// result types and all operands are either of the same type as result or a REF +// type corresponding to the result type. +// TODO(jpienaar): Update the name and the description. +template +class OperandsSameAsResultsTypeOrRef + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op); + if (failed(shapeMatch)) return shapeMatch; + Type type = op->getResult(0).getType(); + // Verify that the first result type is same as the rest of the results. + // We skip the comparison against itself. + for (auto result_type : llvm::drop_begin(op->getResultTypes(), 1)) { + if (!mlir::tf_type::HasCompatibleElementTypes(type, result_type)) + return op->emitOpError() + << "requires all return types to have compatible element types"; + } + for (auto operand_type : op->getOperandTypes()) { + if (!mlir::tf_type::HasCompatibleElementTypes( + operand_type, type, /*may_ignore_ref_type_lhs=*/true)) + return op->emitError() << "requires all operands and results to have " + "compatible element types"; + } + return success(); + } +}; + +namespace detail { +inline LogicalResult verifySameOperandsAndResultElementTypeResolveRef( + Operation* op) { + Type element_type; + if (op->getNumResults() > 0) { + element_type = mlir::tf_type::GetElementTypeOrSelfResolveRef( + op->getResult(0).getType()); + } else if (op->getNumOperands() > 0) { + element_type = mlir::tf_type::GetElementTypeOrSelfResolveRef( + op->getOperand(0).getType()); + } else { + // Nothing to check. + return success(); + } + // Verify that all result element types are compatible to `element_type`. + for (const auto& result_type : op->getResultTypes()) { + if (mlir::tf_type::GetElementTypeOrSelfResolveRef(result_type) != + element_type) { + return op->emitOpError( + "requires compatible element types for all operands and results"); + } + } + // Verify that all operand element types are compatible to `element_type`. 
+ for (const auto& operand_type : op->getOperandTypes()) { + if (mlir::tf_type::GetElementTypeOrSelfResolveRef(operand_type) != + element_type) { + return op->emitOpError( + "requires compatible element types for all operands and results"); + } + } + return success(); +} + +inline ShapedType MergeType(ShapedType a, ShapedType b) { + if (!a.hasRank()) { + return b; + } + if (!b.hasRank()) { + return a; + } + int64_t rank = a.getRank(); + SmallVector dims; + dims.resize(rank); + for (int i = 0, e = rank; i != e; i++) { + int64_t dim0 = a.getDimSize(i); + int64_t dim1 = b.getDimSize(i); + dims[i] = (dim0 == ShapedType::kDynamic) ? dim1 : dim0; + } + return RankedTensorType::get(dims, a.getElementType()); +} +} // namespace detail + +// Verifies that op has the same operand and result element types (or type +// itself, if scalar) after resolving reference types (i.e., after converting +// reference types to their corresponding TensorFlow or standard types). +template +class SameOperandsAndResultElementTypeResolveRef + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + return detail::verifySameOperandsAndResultElementTypeResolveRef(op); + } +}; + +// Verifies that op has the same operand and result types after resolving +// reference types (i.e., after converting reference types to their +// corresponding TensorFlow or standard types). +template +class SameOperandsAndResultTypeResolveRef + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + if (failed(impl::verifySameOperandsAndResultShape(op))) return failure(); + return detail::verifySameOperandsAndResultElementTypeResolveRef(op); + } + + static LogicalResult inferReturnTypeComponentsFromOperands( + MLIRContext*, std::optional location, ValueShapeRange operands, + DictionaryAttr attributes, OpaqueProperties properties, + RegionRange regions, + SmallVectorImpl& inferredReturnShapes) { + if (operands.empty()) + return emitOptionalError( + location, + "Expected non-empty operands for [CompatibleOperandsAndResultType]"); + + auto result_ty = llvm::dyn_cast_or_null(operands[0].getType()); + if (!result_ty) { + return emitOptionalError(location, "Expected shape type for operand 0"); + } + for (auto [index, ty] : + llvm::drop_begin(llvm::enumerate(operands.getTypes()), 1)) { + auto shape_type = llvm::dyn_cast_or_null(ty); + if (!shape_type) { + return emitOptionalError(location, "Expected shape type for operand ", + index); + } + result_ty = detail::MergeType(shape_type, result_ty); + } + inferredReturnShapes.push_back(result_ty); + return success(); + } +}; + +// Layout agnostic operations do not depend on the operands data layout (data +// format), as and example all element wise operations are layout agnostic. +template +class LayoutAgnostic : public TraitBase {}; + +// Trait to indicate operations that cannot be duplicated as they might carry +// certain state around within their implementations. +template +class CannotDuplicate : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + if (isMemoryEffectFree(op)) + return op->emitError( + "operations with no side effects cannot have CannotDuplicate trait"); + return success(); + } +}; + +// Trait to indicate an operation cannot be constant folded. +template +class NoConstantFold : public TraitBase {}; + +// Coefficient-wise binary operation with implicit broadcasting support, for +// example tf.Sub operation. 
+template +class CwiseBinary : public TraitBase {}; + +// Coefficient-wise unary operation, for example tf.Sqrt operation. +template +class CwiseUnary : public TraitBase {}; + +namespace detail { + +inline LogicalResult verifyIsIdempotent(Operation* op) { + // TODO(b/246518997): Add back check for no side effects on operation. + // Currently adding it would cause the shared library build + // to fail since there would be a dependency of IR on SideEffectInterfaces + // which is cyclical. + return success(); +} + +inline OpFoldResult foldIdempotent(Operation* op) { + if (op->getNumOperands() == 1) { + auto* argumentOp = op->getOperand(0).getDefiningOp(); + if (argumentOp && op->getName() == argumentOp->getName()) { + // Replace the outer operation output with the inner operation. + return op->getOperand(0); + } + } else if (op->getOperand(0) == op->getOperand(1)) { + return op->getOperand(0); + } + + return {}; +} + +inline LogicalResult verifyIsInvolution(Operation* op) { + // TODO(b/246518997): Add back check for no side effects on operation. + // Currently adding it would cause the shared library build + // to fail since there would be a dependency of IR on SideEffectInterfaces + // which is cyclical. + return success(); +} + +inline OpFoldResult foldInvolution(Operation* op) { + auto* argumentOp = op->getOperand(0).getDefiningOp(); + if (argumentOp && op->getName() == argumentOp->getName()) { + // Replace the outer involutions output with inner's input. + return argumentOp->getOperand(0); + } + + return {}; +} + +} // namespace detail + +// This class adds property that the operation is idempotent. +// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x), +// or a binary operation "g" that satisfies g(x, x) = x. +template +class IsIdempotent : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait() || + ConcreteType::template hasTrait::Impl>(), + "expected operation to take one or two operands"); + static_assert( + ConcreteType::template hasTrait(), + "expected operation to preserve type"); + // Idempotent requires the operation to be side effect free as well + // but currently this check is under a FIXME and is not actually done. + return detail::verifyIsIdempotent(op); + } + + static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { + return detail::foldIdempotent(op); + } +}; + +/// This class adds property that the operation is an involution. +/// This means a unary to unary operation "f" that satisfies f(f(x)) = x +template +class IsInvolution : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to take one operand"); + static_assert( + ConcreteType::template hasTrait(), + "expected operation to preserve type"); + // TODO(b/246518997): Involution requires the operation to be side effect + // free as well but currently this check is under a FIXME and is not + // actually done. + return detail::verifyIsInvolution(op); + } + + static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { + return detail::foldInvolution(op); + } +}; + +// Indicates that any returned resource is unique. 
+template +class UniqueResourceAllocation + : public TraitBase { + public: + // Implements method required for `ResourceHandleAllocatorInterface`. + llvm::SmallVector + GetResourceHandleValueAndIdList( + llvm::SmallDenseMap& + resource_handle_id_map, + int64_t& next_id) { + llvm::SmallVector resource_vec; + for (Value resource : + mlir::tf_type::filter_resources(this->getOperation()->getResults())) { + resource_vec.push_back({resource, next_id++}); + } + return resource_vec; + } +}; + +} // namespace TF +} // namespace OpTrait +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h new file mode 100644 index 00000000..31233f56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -0,0 +1,55 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the types used in the standard MLIR TensorFlow dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_ + +#include "tensorflow/core/ir/types/dialect.h" + +namespace mlir { +namespace TF { + +// This all moved under tensorflow/core/ir/types and these using declaration are +// to help with the transition. 
+ +using ::mlir::tf_type::AreCastCompatible; // NOLINT +using ::mlir::tf_type::ArraysAreCastCompatible; // NOLINT +using ::mlir::tf_type::BroadcastCompatible; // NOLINT +using ::mlir::tf_type::DropRefType; // NOLINT +using ::mlir::tf_type::filter_resources; // NOLINT +using ::mlir::tf_type::GetCastCompatibleType; // NOLINT +using ::mlir::tf_type::HasCompatibleElementTypes; // NOLINT +using ::mlir::tf_type::IsValidTFTensorType; // NOLINT +using ::mlir::tf_type::OperandShapeIterator; // NOLINT +using ::mlir::tf_type::ResourceType; // NOLINT +using ::mlir::tf_type::ResultShapeIterator; // NOLINT +using ::mlir::tf_type::ResultShapeRange; // NOLINT +using ::mlir::tf_type::StringType; // NOLINT +using ::mlir::tf_type::TensorFlowRefType; // NOLINT +using ::mlir::tf_type::TensorFlowType; // NOLINT +using ::mlir::tf_type::TensorFlowTypeWithSubtype; // NOLINT +using ::mlir::tf_type::VariantType; // NOLINT + +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + using tftype##Type = mlir::tf_type::tftype##Type; +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" + + +} // end namespace TF +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h new file mode 100644 index 00000000..8fbf54c4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_VERIFIERS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_VERIFIERS_H_ + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Verifies correctness of ops implementing LayoutSensitiveInterface (see +// definition in tf_op_base.td): +// (1) Operation must have valid `data_format` attribute. +// (2) Layout dependent arguments and results indices must be in +// [0, getNumOperands/getNumResults) range. +LogicalResult VerifyLayoutSensitiveInterface(Operation* op); + +// Verifies correctness of ops implementing FoldOperandsTransposeInterface (see +// definition in tf_op_base.td): +// (1) Layout dependent arguments and results indices must be in +// [0, getNumOperands/getNumResults) range. 
+LogicalResult VerifyFoldOperandsTransposeInterface(Operation* op); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_VERIFIERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h new file mode 100644 index 00000000..c8160418 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h @@ -0,0 +1,59 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TPU_EMBEDDING_OPS_REGISTRY_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TPU_EMBEDDING_OPS_REGISTRY_H_ + +#include "llvm/ADT/DenseSet.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// A global ops registry that is used to hold TPU embedding ops. +// +// Example: +// TPUEmbeddingOpsRegistry::Global().Add(); +// for (auto op_type_id : TPUEmbeddingOpsRegistry::Global().GetOpsTypeIds()) +// { +// ... +// } +class TPUEmbeddingOpsRegistry { + public: + // Add the op to the registry. + // + // Adding an op here will allow use old bridge legalization from the MLIR + // bridge with the use of fallback mechanism. Therefore, addition of any op + // here must have a python test with MLIR bridge enabled to verify that the + // fallback works correctly. + template + void Add() { + ops_type_ids_.insert(TypeID::get()); + } + + // Returns the type id of the ops in the TPUEmbeddingOpRegistry. + const llvm::SmallDenseSet& GetOpsTypeIds(); + + // Returns the global registry. + static TPUEmbeddingOpsRegistry& Global(); + + private: + llvm::SmallDenseSet ops_type_ids_{}; +}; +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TPU_EMBEDDING_OPS_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h new file mode 100644 index 00000000..81af0f63 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_ + +#include + +#include "absl/base/attributes.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/lib/core/status.h" + +namespace mlir { +namespace TF { + +inline constexpr char kStandardPipelineBefore[] = "standard_pipeline_before"; +inline constexpr char kStandardPipelineAfter[] = "standard_pipeline_after"; + +// Runs all passes involved in transforming or optimizing an MLIR graph without +// any target specialization. When enable_logging is true, enables +// tensorflow::BridgeLogger. When enable_inliner is true, enables the inliner +// pass. +ABSL_DEPRECATED( + "This is legacy code and is unsupported. Use at your own risk. Use " + "tf2xla/api/v2/* for specific functionality") +absl::Status RunBridgeWithStandardPipeline(ModuleOp module, bool enable_logging, + bool enable_inliner); +} // namespace TF + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_BRIDGE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h new file mode 100644 index 00000000..e3c0ee5c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h @@ -0,0 +1,294 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ + +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" + +namespace mlir { +namespace TFDevice { + +// -------------------------------------------------------------------------- // +// ValueConstraint. +// -------------------------------------------------------------------------- // + +// In order to be clustered operation can require its operands to satisfy +// some constraints (e.g. reduction operation can require reduction dimension +// operand to be a constant value). 
+enum class ValueConstraint { + // Operand must have statically known rank. + kRank = 0, + // Operand must have statically known shape (all dimensions are known at + // compile time). + kShape = 1, + // Operand must have statically known value (operand must be defined by a + // constant operation). + kValue = 2, +}; + +// Returns the more restrictive constraint of `a` and `b`: +// +// Value >> Shape >> Rank +// +// If you know the value, you always know the shape and the rank. If you know +// the shape, you always know the rank. +ValueConstraint Merge(ValueConstraint a, ValueConstraint b); + +// Returns success if constraint can be resolved statically based on the value +// type, e.g. `shape` constraint can be resolved if the value is a tensor of +// statically known shape. +LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint); + +raw_ostream& operator<<(raw_ostream& os, const ValueConstraint& constraint); + +// -------------------------------------------------------------------------- // +// ValuesConstraintSet. +// -------------------------------------------------------------------------- // + +// A set of constraints for values, that either operation results or operands. +class ValuesConstraintSet { + using ConstraintsMap = llvm::SmallDenseMap; + using ConstIterator = typename ConstraintsMap::const_iterator; + + public: + ValuesConstraintSet() = default; + + // Inserts a new constraint for the `value`. If the `value` already has some + // constraint, it will merge it with a new one, and will return a new + // constraint value. Returned pair has a constraint value that was set for + // a value, and a boolean flag that is true if the constraint was updated. + std::pair Insert(Value value, + ValueConstraint constraint); + + // Inserts constraints for multiple values. + void Insert(ValueRange value, ValueConstraint constraint); + + // Walk all the constraints owned by this set. + void Walk(llvm::function_ref walk) const; + + // Returns the constraint of the value if it exists, or None otherwise. + std::optional GetConstraint(Value value) const; + bool HasConstraint(Value value) const; + + // Merges all constrains from the other constraints set into this one. + void MergeAll(const ValuesConstraintSet& other); + + // Remove constraints that can be statically resolved from the type of the + // constrained value (see `IsStaticallyResolved` defined above). + ValuesConstraintSet& Resolve(); + + // Reset all constraints. + ValuesConstraintSet& Reset(); + + // Return the number of constrained values in the set. + size_t Size() const; + + // Returns true if the constraint set is empty. + bool Empty() const; + + ConstIterator begin() const { return constraints_.begin(); } + ConstIterator end() const { return constraints_.end(); } + + private: + llvm::SmallDenseMap constraints_; +}; + +// -------------------------------------------------------------------------- // +// ClusteringPolicy. +// -------------------------------------------------------------------------- // + +// Clustering policy specifies if the operation can be clustered (in practice it +// usually means that operation can be added to a cluster that will be later +// compiled) given the set of constraints on its results, and might propagate or +// create new constraints on the operation operands. +// +// Clustering policy must make a local decision just for a single operation. It +// is the responsibility of a clustering pass to combine all these individual +// operations constraints to form a valid cluster. 
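A small sketch of the constraint-set semantics described above; the function is hypothetical and only illustrates that `Insert` merges toward the more restrictive constraint (value >> shape >> rank).

#include <cassert>

#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"

void RequireConstantReductionIndices(
    mlir::Operation *op, mlir::TFDevice::ValuesConstraintSet &operands) {
  // Operand 1 is assumed to hold reduction indices; they must be constant.
  operands.Insert(op->getOperand(1), mlir::TFDevice::ValueConstraint::kValue);
  // A weaker request on the same value does not downgrade the constraint.
  auto result = operands.Insert(op->getOperand(1),
                                mlir::TFDevice::ValueConstraint::kShape);
  assert(result.first == mlir::TFDevice::ValueConstraint::kValue &&
         !result.second);
  (void)result;
}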
+// +// Example: compilation using XLA (MHLO) lowering +// +// %0 = "tf.Transpose"(%input, %perm) +// : (tensor, tensor<2xi32>) -> tensor +// +// XLAs `mhlo.transpose` operation requires permutation to be an attribute +// (compile time value), so it means that if we want to put `tf.Transpose` +// into a cluster that will be compiled with XLA, the `%perm` operand must +// be a known compiled time value, e.g. result of a `tf.Const` operation. +// +class ClusteringPolicy { + public: + virtual ~ClusteringPolicy() = default; + + // Returns success if an operation can be clustered given the constraints on + // the operation results. Updates operands constraits to satisfy all the + // results constraints. + virtual LogicalResult MatchAndUpdateConstraints( + Operation* operation, const ValuesConstraintSet& results, + ValuesConstraintSet& operands) const = 0; +}; + +// Clustering policy for a specific operation type. +template +class OpClusteringPolicy : public ClusteringPolicy { + public: + LogicalResult MatchAndUpdateConstraints( + Operation* operation, const ValuesConstraintSet& results, + ValuesConstraintSet& operands) const final { + if (auto op = dyn_cast(operation)) + return MatchAndUpdateConstraints(op, results, operands); + return failure(); + } + + virtual LogicalResult MatchAndUpdateConstraints( + OpTy op, const ValuesConstraintSet& results, + ValuesConstraintSet& operands) const = 0; +}; + +// -------------------------------------------------------------------------- // +// ClusteringPolicySet. +// -------------------------------------------------------------------------- // + +// A set of clustering policies for different operations. +class ClusteringPolicySet { + public: + using Policies = std::vector>; + + const Policies& policies() const { return policies_; } + + // Add an instance of each of the policy types 'Ts'. Return a reference to + // `this` for chaining insertions. + template + ClusteringPolicySet& Add() { + (void)std::initializer_list{0, (AddImpl(), 0)...}; + return *this; + } + + // ClusteringPolicySet is move only type. + ClusteringPolicySet() = default; + ClusteringPolicySet(const ClusteringPolicySet&) = delete; + ClusteringPolicySet(ClusteringPolicySet&&) = default; + ClusteringPolicySet& operator=(const ClusteringPolicySet&) = delete; + ClusteringPolicySet& operator=(ClusteringPolicySet&&) = default; + + private: + template + void AddImpl(Args&&... args) { + static_assert(std::is_base_of::value, + "T must implement ClusteringPolicy"); + policies_.emplace_back(std::make_unique(std::forward(args)...)); + } + + std::vector> policies_; +}; + +// -------------------------------------------------------------------------- // +// Discovering clusters of operations based on the policy. +// -------------------------------------------------------------------------- // + +// Cluster groups together operations in the single basic block based on the +// given clustering policy set. Clusters can be outlined into nested modules +// later device specific compilation (e.g. for TFRT JIT compiler). +struct Cluster { + llvm::SmallVector operations; + ValuesConstraintSet constraints; +}; + +// Returns clusters of operations in the given `block` based on the provided +// clustering policy. If `filter` is defined, it will be used to filter +// operations that can be considered for clustering based on the policy. +// +// TODO(ezhulenev): Additional filter function is a workaround for customizing +// clustering policies at runtime for experimentation. 
In the long term, +// clustering policy should be enough. +llvm::SmallVector FindClustersInTheBlock( + Block* block, const ClusteringPolicySet& policies, + std::function filter = {}); + +// Creates a `tf_device.cluster` operation from the clustered operations. +tf_device::ClusterOp CreateClusterOp(Cluster& cluster, StringAttr policy = {}); + +// -------------------------------------------------------------------------- // +// Helper functions for value constraints propagations and analysis. +// -------------------------------------------------------------------------- // + +// Propagates initial constraints on the values defined by the `constraints` set +// with operations in the `root` as a starting point, using user provided set of +// clustering policies. +// +// Filter predicate specifies if constraints should be propagated across the +// given operation. Operations in the root set will be also filtered using +// the `filter` predicate. +// +// Optionally resolve constraints that can be statically satisfied by the +// value type, and stop constraints propagation early. +// +// Optionally emits remarks attached to operation that failed to propagate +// results constraints to its operands (for testing purpose). +// +// Returns failure if constraints can't be propagated through some of the +// operations accepted by the filter (there is no clustering policy for an +// operation, or constraints can't be satisfied by the policy), and attaches +// error diagnostics to the operation that prevented constraints propagation. +mlir::LogicalResult PropagateValuesConstraints( + llvm::ArrayRef root, std::function filter, + const ClusteringPolicySet& policies, ValuesConstraintSet& constraints, + bool resolve = false, bool emit_remarks = false); + +// Propagates initial constraints on the values in the `region` to the other +// values in the same region, using user provided set of clustering policies. +mlir::LogicalResult PropagateValuesConstraints( + mlir::Region& region, const ClusteringPolicySet& policies, + ValuesConstraintSet& constraints, bool resolve = false, + bool emit_remarks = false); + +// Emits constraints remarks for all operations that use constrained values. +void EmitValueConstraintsRemarks(const ValuesConstraintSet& constraints); + +// Emits constraints remarks for function inputs that are in the constraints +// set (entry block arguments have constraints). +void EmitInputsConstraintsRemarks(func::FuncOp func, + const ValuesConstraintSet& constraints); + +// Infers constraints for the values in the function body from the function +// results attributes. +// +// Example: +// func @test(...) -> (tensor {tf.constraint = "shape"}) { +// ..... +// %v = "some_operation"() : () -> tensor +// return %v : tensor +// } +LogicalResult InferFunctionBodyValuesConstraints( + func::FuncOp func, ValuesConstraintSet& constraints); + +} // namespace TFDevice +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CLUSTER_OPS_BY_POLICY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h new file mode 100644 index 00000000..e43a1ec4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { +namespace collection_ops_util { + +// This file includes utilities for decomposing collection ops (stack, tensor +// list, tensor array) in TF. We represent such a data structure as a buffer of +// shape [max_element_count, element_shape]. + +// Creates an i32 scalar tf.Const. +Value CreateScalarConst(int32_t value, OpBuilder builder, Location loc); + +// Creates an integer vector tf.Const. +Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc, + int bitwidth = 32); + +// Returns the type of the size tensor used to track a data structure's element +// count. It is a tensor<1xi32>, and we use R1 instead of a scalar because it is +// easier to concat it with other offsets. +TensorType GetSizeType(OpBuilder builder); + +// Reshapes a scalar value to match the size type tensor. +Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc); + +// Creates ops that represent the indices of the slice for an element in the +// buffer. Requires `index` to have tensor<1xi32> type. +Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, + Location loc); + +// Creates ops that slice the element out of a buffer at the given index. +// Requires `index` to have tensor<1xi32> type. +Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, + bool keep_slice_shape = false); + +// Creates ops that copy the buffer and update an element at the given index. +// Requires `index` to have tensor<1xi32> type. +Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, + Location loc); + +// Creates the buffer for the data structure with given element shape, type and +// maximum size. +LogicalResult CreateInitBufferValue(ArrayRef element_shape, + int64_t max_size, Operation* op, + Type element_dtype, OpBuilder builder, + Value* buffer); + +// Same as above, but uses a Value as max_size and check if it is a constant. +LogicalResult CreateInitBufferValue(ArrayRef element_shape, + Value max_size, Operation* op, + Type element_dtype, OpBuilder builder, + Value* buffer); + +// Tries to infer the element type with full shape based its write accesses. +// `infer_from_user` should check if the provided op is an accessing op that +// could be used to infer the type. 
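The declaration documented by the preceding comment follows right after this note; its callback parameter is named infer_from_op even though the comment says infer_from_user. Given the buffer layout described at the top of this header, [max_element_count, element_shape] plus a tensor<1xi32> size, a stack push decomposes into SetElement at the current size followed by a size increment. A minimal sketch, assuming GetR1Const takes an ArrayRef of 64-bit integers (the element type is elided in this hunk) and that TF::AddV2Op exposes the usual generated builder; PushToBuffer itself is hypothetical.

#include <utility>

#include "mlir/IR/Builders.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"

namespace util = mlir::TF::collection_ops_util;

// Writes `element` into `buffer` at index `size` (the current element count)
// and returns the updated buffer together with the incremented size tensor.
std::pair<mlir::Value, mlir::Value> PushToBuffer(mlir::Value buffer,
                                                 mlir::Value size,
                                                 mlir::Value element,
                                                 mlir::OpBuilder& builder,
                                                 mlir::Location loc) {
  mlir::Value new_buffer =
      util::SetElement(size, buffer, element, builder, loc);
  mlir::Value one = util::GetR1Const({1}, builder, loc);
  mlir::Value new_size =
      builder.create<mlir::TF::AddV2Op>(loc, size.getType(), size, one);
  return {new_buffer, new_size};
}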
+std::optional GetElementTypeFromAccess( + Value collection, ModuleOp module, + llvm::function_ref(Operation*)> infer_from_op); + +// Creates a ReadVariableOp on a local variable. +Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc); + +// Creates an AssignVariableOp on a local variable. +TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value, + OpBuilder builder, Location loc); + +// Adds two values, or creates a logical-or if they are boolean type. +Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc); + +// Gathers elements in buffer with the indices. +Value GatherElements(Value indices, Value buffer, OpBuilder builder, + Location loc); + +// Scatters elements into buffer, where each scattered element is accumulated +// with the old value in buffer. +Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, + OpBuilder builder, Location loc); + +} // namespace collection_ops_util +} // namespace TF +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_COLLECTION_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h new file mode 100644 index 00000000..887eea74 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TF { + +LogicalResult ConstantFoldFallbackHook( + Operation *inst, ArrayRef operands, + SmallVectorImpl &results); // NOLINT + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h new file mode 100644 index 00000000..636dde98 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_UTILS_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Checks whether the given TF operation can be folded or not. +bool CanBeFolded(Operation* inst); + +// Evaluates the operation with given operand values. +LogicalResult EvaluateOperation(Operation* inst, + llvm::ArrayRef operands, + llvm::SmallVector& results); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_CONSTANT_FOLD_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h new file mode 100644 index 00000000..c8f0c84b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Populates rewrite patterns that decompose composite resource operations into +// primitive ones like ReadVariableOp, AssignVariableOp and other computations +// to facilitate transformations like resource op lifting. +// NOTE: These patterns do not support `use_locking=true` for a lot of resource +// operations. So decomposition may not be correct outside of backends like XLA, +// which automatically locks all resource variables. 
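The populate function documented above is declared immediately below. Populate-style entry points like this one are normally consumed from inside a pass by handing the patterns to MLIR's greedy rewrite driver. A minimal sketch, assuming the applyPatternsAndFoldGreedily driver available in this MLIR revision; the wrapper function is hypothetical.

#include <utility>

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"

// Decomposes composite resource ops nested under `op` by applying the
// populated patterns until a fixed point is reached.
mlir::LogicalResult DecomposeResourceOpsIn(mlir::Operation* op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::TF::PopulateDecomposeResourceOpsPatterns(op->getContext(), &patterns);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}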
+void PopulateDecomposeResourceOpsPatterns(MLIRContext *context, + RewritePatternSet *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_DECOMPOSE_RESOURCE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h new file mode 100644 index 00000000..65e05280 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/einsum.h @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This pass identifies patterns for certain Einsum Ops and replaces them +// with other equivalent TF Ops. + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +// TF.Einsum provides fully general tensor contractions. For a few select +// cases, we can convert this op to other TF Ops, which in later passes +// properly convert to TF Lite ops. +struct ConvertTFEinsumOp : public OpRewritePattern { + public: + explicit ConvertTFEinsumOp(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter& rewriter) const override; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_EINSUM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h new file mode 100644 index 00000000..0de93ca4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ + +#include + +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" + +namespace mlir { +namespace TF { + +// Bundle generic MLIR graph optimization passes (some derived from TF Grappler +// graph optimizers) into a single MLIR optimization pass. +class MlirGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass { + public: + llvm::StringRef name() const override { return "graph_optimization"; } + + ::tensorflow::MlirOptimizationPassState GetPassState( + const ::tensorflow::DeviceSet* device_set, + const ::tensorflow::ConfigProto& config_proto, + const tensorflow::Graph& graph, + const tensorflow::FunctionLibraryDefinition& function_library) + const override { + return config_proto.experimental().enable_mlir_graph_optimization() + ? tensorflow::MlirOptimizationPassState::Enabled + : tensorflow::MlirOptimizationPassState::Disabled; + } + + absl::Status Run( + const std::string& function_name, + const ::tensorflow::ConfigProto& config_proto, ModuleOp module, + const ::tensorflow::Graph& graph, + const tensorflow::FunctionLibraryDefinition& function_library) override; +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.h new file mode 100644 index 00000000..5fe8ab12 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.h @@ -0,0 +1,35 @@ +/* Copyright 2022 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GROUP_BY_DIALECT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GROUP_BY_DIALECT_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Create a pass that groups ops into functions that only contain one dialect. +std::unique_ptr CreateGroupByDialectPass(); + +// Register this pass in the global registry of MLIR. 
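The registration function documented by the preceding comment is declared immediately below. Returning to the GetPassState gate in graph_optimization_pass.h above: the pass only reports Enabled when the session's ConfigProto sets the experimental flag it reads. A minimal sketch of that configuration; the helper name is hypothetical.

#include "tensorflow/core/protobuf/config.pb.h"

// Builds a ConfigProto with the experimental flag that
// MlirGraphOptimizationPass::GetPassState checks.
tensorflow::ConfigProto MakeConfigWithMlirGraphOptimization() {
  tensorflow::ConfigProto config;
  config.mutable_experimental()->set_enable_mlir_graph_optimization(true);
  return config;
}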
+void RegisterGroupByDialectPass(); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_GROUP_BY_DIALECT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h new file mode 100644 index 00000000..3640f53a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_LOWER_CLUSTER_TO_RUNTIME_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_LOWER_CLUSTER_TO_RUNTIME_OPS_H_ + +#include "absl/base/attributes.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "xla/tsl/framework/device_type.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tfrt_compiler { + +// Given a MLIR module with tf_device.cluster ops, insert specific Runtime ops +// such as TPUExecute or XlaExecute depending on the device type and specific +// host runtime. Also does some optimization. Will return an error if it fails. +// The output Runtime ops depends on both Device Type and Runtime Host. +// +// Input: +// Tensorflow Dialect MLIR with tf_device.cluster ops and virtual devices. +// xla_device_type - The device type that is being targeted. +// Output: +// Tensorflow Dialect MLIR with Runtime specific ops. All tf_device.cluster +// ops are removed. Physical devices are assigned to ops instead of virtual +// devices. +absl::Status RunLowerClusterToRuntimeOpsPassPipeline( + mlir::ModuleOp module, tsl::DeviceType xla_device_type, + llvm::StringRef module_name = llvm::StringRef()); + +// The same API as RunLowerClusterToRuntimeOpsPassPipeline but as an MLIR pass +// pipeline. +void RegisterTPULowerClusterToRuntimeOpsPassPipeline(); +void RegisterNonTPULowerClusterToRuntimeOpsPassPipeline(); + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_LOWER_CLUSTER_TO_RUNTIME_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h new file mode 100644 index 00000000..7012d6a7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_RUNTIME_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_RUNTIME_PASSES_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TFTPU { + +// Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime +// ops. +std::unique_ptr> CreateTPURewritePass( + llvm::StringRef module_name = llvm::StringRef()); + +// Creates a pass that adds ops which perform formatting on variables at +// run-time according to compilation result. +std::unique_ptr> +CreateTPUVariableRuntimeReformattingPass(); + +// Creates a pass that merges device variable reads/updates into the surrounded +// TPUExecute node. This allows the execute node to perform in-place variable +// updates. +std::unique_ptr> +CreateTPUMergeVariablesWithExecutePass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_TPUMERGEVARIABLESWITHEXECUTEPASS +#define GEN_PASS_DECL_TPUREWRITEPASS +#define GEN_PASS_DECL_TPUVARIABLERUNTIMEREFORMATTINGPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h.inc" + +} // namespace TFTPU +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_RUNTIME_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h new file mode 100644 index 00000000..b58401eb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_TPU_METADATA_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_TPU_METADATA_UTILS_H_ + +#include + +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace mlir { +namespace TFTPU { + +// Populates a TPUCompileMetadataProto from attributes of a +// `tf_device::ClusterFuncOp`. If any necessary attributes are missing from the +// op, a failure will be returned. +// TODO(lyandy): Support session handle and guaranteed consts. +LogicalResult SetMetadataProtoFromClusterFuncOp( + tf_device::ClusterFuncOp op, int num_replicas, int num_cores_per_replica, + std::optional&& xla_device_assignment, + tensorflow::tpu::TPUCompileMetadataProto* metadata); +} // namespace TFTPU +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_HOST_RUNTIME_TPU_METADATA_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.h new file mode 100644 index 00000000..623e5f4e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_INITIALIZE_VARIABLES_IN_SESSION_INIT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_INITIALIZE_VARIABLES_IN_SESSION_INIT_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +// Initializes all variables in Session Init function for all variables in +// 'session'. 
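The session-initializer function documented by the preceding comment is declared immediately below. Looking back at tpu_metadata_utils.h above, this is a typical call shape for SetMetadataProtoFromClusterFuncOp, with illustrative single-replica, single-core arguments and no explicit device assignment; the element type of the xla_device_assignment optional is elided in this hunk, but std::nullopt is valid regardless. The wrapper is hypothetical.

#include <optional>

#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.h"

// Fills `metadata` from the attributes of a single-replica, single-core
// cluster function; fails if required attributes are missing on the op.
mlir::LogicalResult BuildCompileMetadata(
    mlir::tf_device::ClusterFuncOp cluster_func,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  return mlir::TFTPU::SetMetadataProtoFromClusterFuncOp(
      cluster_func, /*num_replicas=*/1, /*num_cores_per_replica=*/1,
      /*xla_device_assignment=*/std::nullopt, metadata);
}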
+LogicalResult InitializeVariablesInSessionInitializer( + ModuleOp module, tensorflow::Session *session); + +} // namespace tf_saved_model + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_INITIALIZE_VARIABLES_IN_SESSION_INIT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h new file mode 100644 index 00000000..a0a218f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +// Creates GlobalTensorOp for each variable from function arguments and converts +// them to the corresponding saved model arguments. +LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session, + bool import_variables_as_dense_resources = false); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.h new file mode 100644 index 00000000..e66bf22f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.h @@ -0,0 +1,35 @@ +/* Copyright 2022 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_GLOBALS_TO_ML_PROGRAM_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_GLOBALS_TO_ML_PROGRAM_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tf_saved_model { + +// Create a pass that removes function arguments that map to global tensors. +std::unique_ptr CreateLowerGlobalsToMlProgramPass(); + +// Register this pass in the global registry of MLIR. +void RegisterLowerGlobalsToMlProgramPass(); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_GLOBALS_TO_ML_PROGRAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h new file mode 100644 index 00000000..b8e26302 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Populates TensorFlow lowering patterns to lower some of the TensorFlow +// operations that can be represented using other TensorFlow operations. +// TODO(laurenzo): For some reason, TFLite uses this pass and has exact +// requirements on what it can do. This is fragile and should be fixed (at a +// minimum, names should clearly convey scope). In the mean time, for a real +// compiler, use PopulateTFLoweringBeforeHLOPatterns. +void PopulateLoweringTFPatterns(MLIRContext *context, + RewritePatternSet *patterns); + +// Populates TensorFlow lowering patterns to lower some of the TensorFlow +// operations that can be represented by means of other TensorFlow operations. +// This pattern collection preserves those TensorFlow operations that will later +// be lowered to equivalent operations in CHLO or MHLO. This allows for +// HLO-specific lowerings. +void PopulateTFLoweringBeforeHLOPatterns(MLIRContext *context, + RewritePatternSet *patterns); + +// Populates TensorFlow lowering patterns to lower some of the TensorFlow +// operations that can be represented using other TensorFlow operations. +// Patterns are from ops with some inputs or outputs that are quantized types +// only to ops that allow non-quantized types on all inputs and outputs. 
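The quantized-pattern populate function documented above is declared immediately below. Referring back to lower_globals_to_ml_program.h earlier in this hunk: pass factories such as CreateLowerGlobalsToMlProgramPass are consumed by scheduling them on a pass manager. A minimal single-pass sketch; the wrapper is hypothetical.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.h"

// Runs the saved-model global-tensor lowering over `module` as a one-pass
// pipeline anchored on the module.
mlir::LogicalResult LowerGlobals(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::tf_saved_model::CreateLowerGlobalsToMlProgramPass());
  return pm.run(module);
}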
+void PopulateLoweringQuantizedPatterns(MLIRContext *context, + RewritePatternSet *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LOWER_TF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h new file mode 100644 index 00000000..4a18c096 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_MARK_INITIALIZED_VARIABLES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_MARK_INITIALIZED_VARIABLES_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { +// Marks all variables in 'function' whether they are initialized +// in 'session' or not by setting an attribute named 'is_initialized' +// on each variable op with value true/false based on variable is initialized +// in the session or not. +// If 'session' is NULL the function is no-op. +// Returns failure in case fetching variables from session failed, success +// otherwise. +LogicalResult MarkInitializedVariablesInFunction(func::FuncOp function, + tensorflow::Session* session); +// Apply `MarkInitializedVariablesInFunction` to every non-empty function in the +// module. +LogicalResult MarkInitializedVariablesInFunction(ModuleOp module, + tensorflow::Session* session); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_MARK_INITIALIZED_VARIABLES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.h new file mode 100644 index 00000000..61cbf4d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/mlprogram.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_MLPROGRAM_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_MLPROGRAM_H_ + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project + +namespace tensorflow { + +void PopulateLowerToMlProgramAndHloPipeline(mlir::OpPassManager& pm); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_MLPROGRAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.h new file mode 100644 index 00000000..0268a89a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/order_by_dialect.h @@ -0,0 +1,36 @@ +/* Copyright 2022 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ORDER_BY_DIALECT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ORDER_BY_DIALECT_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Create an instance of a pass that reorders ops so ops of the same dialect are +// next to each other. +std::unique_ptr CreateOrderByDialectPass(); + +// Register this pass in the global registry of MLIR. +void RegisterOrderByDialectPass(); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ORDER_BY_DIALECT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/passes.h new file mode 100644 index 00000000..54dad08e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -0,0 +1,715 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_ + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" + +namespace mlir { + +// Creates a pass that breaks up an island with multiple ops into multiple +// islands, each with a single op. +std::unique_ptr> CreateBreakUpIslandsPass(); + +// Creates a pass that converts mlir functions consisting of mlir ops into a +// tf_executor dialect as a single island. +std::unique_ptr> +CreateFunctionalToExecutorDialectConversionPass(); + +// Creates a pass that lifts inner ops of tf_executor.island ops in +// tf_executor.graph into the same block as the tf_executor.graph. +std::unique_ptr> +CreateExecutorDialectToFunctionalConversionPass(); + +namespace TF { +// Creates a pass that canonicalizes legacy compilation and replication +// attributes. +std::unique_ptr> +CreateCanonicalizeCompileAndReplicateAttributesPass(); + +// Creates a pass that drops `shape_invariant` attribute from While/WhileRegion +// ops. +std::unique_ptr> +CreateDropWhileShapeInvariantPass(); + +// Creates a pass that drops `shape_invariant` attribute from While/WhileRegion +// ops within device cluster. +std::unique_ptr> +CreateDropWhileShapeInvariantInDeviceClusterPass(); + +// Creates a pass that moves writes to replicate invariant resource variables +// outside tf_device.replicate op. +std::unique_ptr> +CreateHoistReplicateInvariantResourceWritesPass(); + +// Transforms functional control flow operations in the TensorFlow dialect to +// MLIR Control Flow Graph (CFG) form. +std::unique_ptr> +CreateTFFunctionalControlFlowToCFG(); + +// Transforms functional control flow operations in the TensorFlow dialect to +// their region based counterparts. +std::unique_ptr> +CreateTFFunctionalControlFlowToRegions(); +std::unique_ptr> CreateTFFunctionalControlFlowToRegions( + bool allow_passthrough_args); + +// Transforms region bases control flow operations in the TensorFlow dialect to +// their functional counterparts. +std::unique_ptr> +CreateTFRegionControlFlowToFunctional(); + +// Materialize the MlirPassthroughOp by replacing it with the MLIR module +// attached as an attribute. +std::unique_ptr> +CreateMaterializePassthroughOpPass(); + +// Replicates the TensorList init op by undoing some CSE needed for correct +// shape assignment in shape_inference. +std::unique_ptr> +CreateReplicateTensorListInitOpsPass(); + +// Performs Shape Inference on the TensorFlow dialect using the global registry. +std::unique_ptr> CreateTFShapeInferencePass( + ArrayRef> input_shapes = {}); + +// Performs TF.data optimizations. 
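The tf.data optimization factory documented above is declared immediately below. The factories in this header are meant to be composed into pipelines rather than run in isolation; a hedged sketch of one such composition follows. The pass selection, ordering, and module-versus-function nesting are assumptions, since the unique_ptr template arguments are elided in this hunk, and the helper is hypothetical.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

// Infers shapes at module scope, then converts each function to the
// tf_executor dialect and breaks multi-op islands apart.
void AddExecutorExportPasses(mlir::OpPassManager& pm) {
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::CreateFunctionalToExecutorDialectConversionPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::CreateBreakUpIslandsPass());
}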
+std::unique_ptr> CreateTFDataOptimizationPass(); + +std::unique_ptr> CreateMoveTransposesPass(); +std::unique_ptr> CreateLayoutAssignmentPass(); + +// Guarantee that all FuncOp's have a single use. +std::unique_ptr> CreateGuaranteeAllFuncsOneUsePass(); + +// Optional pass which will unroll BatchMatMul and use only MatMul +std::unique_ptr> CreateUnrollBatchMatMulPassPass(); + +// Optional pass which will map TF BatchMatMul to TF Einsum +std::unique_ptr> CreateBatchMatMulToEinsumPass(); + +// Pass that transform Einsum to other TF Ops for the supported variants. +std::unique_ptr> CreateTransformEinsumPass(); + +// Optimizes Tensorflow graph. +std::unique_ptr> CreateTFOptimizePass(); +void RegisterTFOptimizePassPipeline(); + +// Creates pass to rewrite RecvTPUEmbeddingActivationsOp and +// SendTPUEmbeddingGradients ops to internal variants. +std::unique_ptr> CreateRewriteTPUEmbeddingOpsPass(); + +// Performs specific fusion for GPU targets. +std::unique_ptr> CreateGpuOpFusionPass(); + +// Creates a pass that decomposes to be compiled ReduceDataset ops into a while +// loop that iterates the dataset and calls the reduction function. +std::unique_ptr> CreateDecomposeReduceDatasetPass(); + +// Create a pass that convert ops that copy tensors between devices, e.g. +// tf.Identity. +std::unique_ptr> +CreateTensorDeviceCopyConversionPass(); + +// Returns a pass that folds tf.BroadcastTo nodes with subsequent nodes if they +// have built in broadcasting support. +std::unique_ptr> CreateBroadcastFoldPass(); + +void populateTfControlFlowToScfPatterns(MLIRContext* context, + RewritePatternSet* patterns); +// Create a pass to convert TensorFlow control flow to SCF. +std::unique_ptr> createConvertTfControlFlowToScfPass(); + +struct LayoutOptimizationPipelineOptions + : public PassPipelineOptions { + Option force_data_format{ + *this, "force-data-format", + llvm::cl::desc("Force data format for all layout sensitive ops")}; + Option skip_fold_transpose_in_ops{ + *this, "skip-fold-transpose-in-ops", + llvm::cl::desc("Skip folding transpose operands in Ops which can support " + "different layouts.")}; +}; + +// Layout optimization assigns optimal data layout for layout sensitive +// operations, and cancels all redundant transposes. +void CreateLayoutOptimizationPipeline( + OpPassManager& pm, // NOLINT - MLIR contract is pass by mutable reference. + const LayoutOptimizationPipelineOptions& options); + +struct StandardPipelineOptions + : public PassPipelineOptions { + Option enable_inliner{*this, "enable-inliner", + llvm::cl::desc("Enable inliner."), + llvm::cl::init(false)}; + Option form_clusters{*this, "form-clusters", + llvm::cl::desc("Enable Cluster Formation pass."), + llvm::cl::init(false)}; +}; + +// Propagates the pass manager with the passes involved in transforming or +// optimizing an MLIR graph without any target specialization. +// NOLINTNEXTLINE - MLIR contract is pass by mutable reference. +void CreateTFStandardPipeline(OpPassManager& pm, + const StandardPipelineOptions& options); + +// Propagates device attributes of resources from callers to callees. +std::unique_ptr> CreateResourceDeviceInferencePass(); + +// Creates a pass that promotes resource reads/writes in `functions` to inputs +// and outputs of `functions`, assuming that resource operations have already +// been decomposed and function calls have already been inlined. If `functions` +// is empty, the pass is applied to the main function by default. 
The pass also +// annotates the input arguments for resources with the indices of their +// aliasing output arguments. +std::unique_ptr> CreatePromoteResourcesToArgsPass( + llvm::ArrayRef functions = {}); + +// Creates a pass that promotes tf.VarHandleOp to resource arguments for all +// functions. +std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); + +// Creates a pass that converts readonly reference variables to the +// corresponding resource variables. +std::unique_ptr> +CreateConvertReadonlyReferenceVariablesToResourceVariablesPass(); + +// Creates a simple device assignment pass on TF dialect for CoreRT use case. +std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( + llvm::StringRef default_device = "cpu"); + +// Creates a pass to perform device assignment for TF dialect ops that do not +// have device assignment, by using the device attribute of the function. +std::unique_ptr> +CreateTFDeviceAssignmentByFuncAttrPass(); + +// Performs resource lifting on the function body to hoist resource variable +// accesses outside all control flow statements. +LogicalResult ResourceLiftingForFunctionalControlFlow(func::FuncOp function); + +// Converts stack ops into operations on local variables, which can later be +// removed by resource lifting. Requires known maximum sizes of stacks and +// known element shapes of push ops. +std::unique_ptr> CreateStackOpsDecompositionPass(); + +// Creates a pass to strip the "tf._noinline" attribute from the functions in +// the module. +std::unique_ptr> CreateStripNoinlineAttributePass(); + +// Converts tensor list operations into operations on buffers and sizes. Needs +// static shapes and known max element count. +std::unique_ptr> CreateTensorListOpsDecompositionPass(); + +// Converts tensor array ops into operations on local variables, which can later +// be removed by resource lifting. Requires known sizes and known element shapes +// (either defined in TensorArrayV3 or implied in the first write). +std::unique_ptr> +CreateTensorArrayOpsDecompositionPass(); + +// Create a pass that legalize TFG to TF dialect. +std::unique_ptr CreateLegalizeTFGToTFEPass(); + +// Matches sequence of ops to TensorFlow fused kernels. This pass should not be +// generally used beyond exporting to runtimes that supports these ops. In the +// future these fusions may be codegen'd automatically. +std::unique_ptr> CreateFusedKernelMatcherPass(); + +// Creates function pass to select device index/fold tf.DeviceIndex. +std::unique_ptr> CreateDeviceIndexSelectorPass(); + +// Creates function pass to replace InitializeTableFromTextFileV2Ops with +// LookupTableImportV2Op ops. +std::unique_ptr> CreateInitTextFileToImportPass( + std::string saved_model_dir = ""); + +// Creates function pass to cluster TensorFlow ops by host. The program +// generated by this pass will have one function per host where all operations +// in the same function are placed on the same host. Each result of the per-host +// function will have a "tf.device" attribute which specifies the device +// assignment of the result. +std::unique_ptr> CreateClusterTFOpsByHostPass(); + +// Creates a pass to insert tf_device.send and tf_device.receive ops to make +// sure any argument of any op is on the same host of the op itself. +std::unique_ptr> CreateCrossHostTransferPass(); + +// Creates a pass that adds the device attribute to every tf.Const op based on +// the device attribute of the operations that read its result. 
If the result of +// a tf.Const op is read by operations placed on multiple devices, then the pass +// will replicate the tf.Const op once for each device. +std::unique_ptr> CreateConstantOpDeviceAssignmentPass(); + +// Returns pass that verifies whether all functions in module are of single +// tf_executor.graph and each tf_executor.island in tf_executor.graph only has a +// single op. +std::unique_ptr> CreateVerifySuitableForExportPass(); + +// Returns pass that prepares TPU computation to be legal for export to +// TensorFlow. +std::unique_ptr> +CreatePrepareTpuComputationForTfExportPass(); + +// Rewrites ops that require quantized inputs or outputs to ops that allow +// non-quantized inputs and outputs. +std::unique_ptr> CreateLowerQuantizedPass(); + +// Reorders ops so ops of the same dialect are next to each other. +std::unique_ptr CreateOrderByDialectPass(); + +// Groups ops into functions that only contain one dialect. +std::unique_ptr CreateGroupByDialectPass(); + +// Removes unused parameters from functions & their callers. +std::unique_ptr> CreateRemoveUnusedArgumentsPass(); + +// Removes unused results from WhileRegion ops. +std::unique_ptr> +CreateRemoveUnusedWhileResultsPass(); + +// Hoists loop invariant ops to the outside of the loop. +std::unique_ptr> CreateHoistLoopInvariantPass(); + +// Creates VarHandleOps right next to the operations that use them. +std::unique_ptr> CreateLocalizeVarHandlesPass(); + +// Removes all TF attributes +std::unique_ptr> CreateStripTfAttributesPass(); + +// Converts AnonymousIteratorOps to (named) IteratorOps. +std::unique_ptr> CreateNameAnonymousIteratorsPass(); + +// Creates a pass that breaks up an island with multiple ops into multiple +// islands, each with a single op. This pass intentionally does not propagate +// control dependencies across newly created islands and is handled by +// CreateTFExecutorUpdateControlDependenciesPass. +std::unique_ptr> CreateSplitIntoIslandPerOpPass(); + +// Prints, but otherwise pipes through without changes, the current module. +std::unique_ptr> CreatePrintPass( + raw_ostream* os = nullptr); + +// Moves TPUCompileMlir ops as far to the front as possible. +std::unique_ptr> CreateMoveTpuCompileToFrontPass(); + +// Decomposes OptionalFromValue, OptionalGetValue, OptionalNone, +// and OptionalHasValue +std::unique_ptr> CreateDecomposeOptionalsPass(); + +//===----------------------------------------------------------------------===// +// XlaCallModule +//===----------------------------------------------------------------------===// + +// Creates a pass that deserializes functions in the StableHLO modules from +// `tf.XlaCallModule` to the top-level module. +std::unique_ptr> +CreateXlaCallModuleDeserializationPass(); + +// Creates a pass that serializes StableHLO functions referenced by +// `tf.XlaCallModule` from the top-level module to `tf.XlaCallModule`'s +// `module` attribute. +std::unique_ptr> CreateXlaCallModuleSerializationPass(); + +} // namespace TF + +namespace tf_executor { + +// Creates a pass to chain control outputs of while loop body. +std::unique_ptr> +CreateTFExecutorConvertControlToDataOutputsPass(); +std::unique_ptr> +CreateTFExecutorConvertControlToDataOutputsPass( + bool composite_tpuexecute_side_effects); + +std::unique_ptr> +CreateTFExecutorCheckControlDependenciesPass(); + +// Creates a pass to merge IslandOps from TFExecutor dialect. +std::unique_ptr> +CreateTFExecutorIslandCoarseningPass(); + +// Creates a pass to merge IslandOps for operation marked for execution on TPU. 
+// This is for V1 backward compatibility. +std::unique_ptr> +CreateTFExecutorTPUV1IslandCoarseningPass(); + +// Creates a pass to outline TPU clusters from a single IslandOp into a nested +// module suitable for being processed as if it were a V2 module. +// This is for V1 backward compatibility. +std::unique_ptr> +CreateTFExecutorTPUV1IslandOutliningPass(); + +// Creates a pass to inline calls to the nested TPU module; this reverses the +// effect of the `TFExecutorTPUV1IslandOutlining` pass above. +// This is for V1 backward compatibility. +std::unique_ptr> +CreateTFExecutorTPUV1IslandInliningPass(); + +// Creates a pass to prune dead nodes from tf_executor.graph. +std::unique_ptr> CreateTFExecutorGraphPruningPass( + llvm::ArrayRef ops_to_preserve = {}); + +// Creates a pass to update control dependencies. +std::unique_ptr> +CreateTFExecutorUpdateControlDependenciesPass(); + +} // namespace tf_executor + +namespace TFDevice { +// Creates a pass that forms clusters from instructions that are assigned to the +// same device. +std::unique_ptr> CreateClusterFormationPass(); + +// Sinks `tf.Const` operations into the ClusterOp region that uses them. This is +// performed in order to limit the number of values implicitly captured in this +// region before outlining. +std::unique_ptr> CreateClusterConstantSinkingPass( + llvm::function_ref filter = {}); + +// Creates a pass that outlines regions of tf_device.cluster operations. +std::unique_ptr> CreateClusterOutliningPass(); + +// Creates a pass that outlines regions of tf_device.launch operations. +std::unique_ptr> CreateLaunchOutliningPass(); + +// Creates a pass that converts tf_device::LaunchFuncOp into +// TF::PartitionedCallOp. +std::unique_ptr> CreateConvertLaunchFuncToTFCallPass(); + +// A pass that decomposes composite resource operations into primitive ones like +// ReadVariableOp, AssignVariableOp and other computations to facilitate +// transformations like resource op lifting. +std::unique_ptr> CreateDecomposeResourceOpsPass(); + +// A pass that decomposes composite resource operations in a device cluster +// (tf_device.cluster op) into primitive ones like ReadVariableOp, +// AssignVariableOp and other computations to facilitate transformations like +// resource op lifting. +std::unique_ptr> +CreateDecomposeResourceOpsInClusterPass(); + +// Creates a pass that marks TPU cluster input-output pairs reading and writing +// to the same resource variable as aliases. +std::unique_ptr> CreateMarkInputOutputAliasesPass(); + +// Creates a pass that lifts operations on external resource variables out of +// device computation nested in `tf_device::LaunchOp`, so that resource +// variable load operations all come before the device computation while resource +// variable store operations all come after it. After this pass, the +// device computation no longer interacts with external resource variables. +std::unique_ptr> CreateResourceOpLiftingPass(); + +// Creates a pass that lifts operations from the main function. +std::unique_ptr> +CreateResourceOpLiftingForMainFunctionPass(); + +// Lifts resource operations out of tf_device.launch_func ops nested in `op`. +// Returns a failure if there are remaining resource-type values that +// cannot be lifted. +LogicalResult LiftResourceOps(Operation* op); + +// Creates a pass that hoists invariant operations out of a `tf_device.replicate`. +std::unique_ptr> +CreateReplicateInvariantOpHoistingPass(); + +// Creates a pass that forms replica `tf_executor.island` ops from a single +// `tf_device.replicate` island.
+std::unique_ptr> CreateReplicateToIslandPass( + bool legacy_graph_export = true); + +// Creates a pass that sets the device ordinal attribute of the required op +// using the replica id attribute. +std::unique_ptr> +CreateReplicaIDToDeviceOrdinalPass(); + +// Creates a pass that creates `tf_executor.island` from a single +// `tf_device.parallel_execute` island. +std::unique_ptr> CreateParallelExecuteToIslandsPass( + bool legacy_graph_export = true); + +// Creates a pass that annotates whether a LaunchFuncOp's parameters have the +// same data across replicas. +std::unique_ptr> +CreateAnnotateParameterReplicationPass(); + +// Creates a pass that merges control flow with similar predicates. +std::unique_ptr> CreateMergeControlFlowPass(); + +// Creates a pass that wraps each TensorFlow dialect with `device` attribute +// in a `tf_device.launch` op with the same `device` attribute. +std::unique_ptr> +CreateDeviceAttributeToLaunchPass(); + +// Creates a pass that hoists a `tf_device.launch` body and assigns a `device` +// attribute to each TensorFlow dialect op in the body based on the `device` +// attribute on the `tf_device.launch`. +std::unique_ptr> CreateLaunchToDeviceAttributePass( + bool legacy_graph_export = true); + +// Creates a pass to ensure that the `_xla_outside_compilation` and +// tf_device.launch op no longer exist after Outside Compilation is complete. +std::unique_ptr> +CreateVerifyNoOutsideCompilationMarkersPass(); + +// Create a pass that inlines the StatefulPartitionedCallOp op based in the +// parent region. +std::unique_ptr> CreateXlaInlineDeviceOpsPass(); + +// Creates a pass that rewrites partitioned calls with `_xla_compile_device +// type` with `tf.XlaLaunch` ops. +std::unique_ptr> CreateXlaRewritePass(); + +// Create a pass that validates the input graph to the CPU/GPU bridge. +std::unique_ptr> CreateXlaValidateInputsPass(); +} // namespace TFDevice + +namespace TFTPU { +// Creates a pass that converts unified compilation and replication +// attributes back to legacy attributes. +std::unique_ptr> +CreateConvertToLegacyCompileAndReplicateAttributesPass(); + +// Creates a pass that converts all TPUPartitionedInput to TPUPartitionedInputV2 +std::unique_ptr> +CreateTPUPartitionedOpConversionPass(); + +// Creates a pass that cleans up `_replication_info` attribute on operations +// that are inside a cluster. +std::unique_ptr> +CreateTPUClusterCleanupAttributesPass(); + +// Creates a pass that removes Identity/IdentityN ops from a cluster. +std::unique_ptr> CreateTPUIdentityPruningPass(); + +// Creates a pass that allows TPU program inputs to have layouts determined at +// run time. +std::unique_ptr> CreateTPUDynamicLayoutPass(); + +// Creates a pass that adds `tf.ReadVariableOp` to a TPU cluster for resources +// the cluster only writes to. +std::unique_ptr> CreateTPUResourceReadForWritePass(); + +// Creates a pass that reorders partitiioned resource reads and replicated +// inputs. +std::unique_ptr> +CreateTPUReorderReplicateAndPartitionedInputsPass(); + +// Creates a pass that partitions unpartitioned resource read/write to +// partitioned resource variables. +std::unique_ptr> +CreateTPUResourceReadsWritesPartitioningPass(); + +// Creates a pass that looks for usage of the result of +// TPUCopyWithDynamicShapeOp and annotate these values to be dynamic shape. This +// ensures that the generated tpu program has the correct inputs annotation. 
+std::unique_ptr> +CreateTPUAnnotateDynamicShapeInputsPass(); + +// Creates a pass that moves `tf.AssignVariableOp` into a +// `tf_device.parallel_execute` region if the `tf.AssignVariableOp` is the +// only consumer of a `tf_device.parallel_execute` result. +std::unique_ptr> +CreateTPUParallelExecuteSinkResourceWritePass(); + +// Create a pass that extract TPUCopyWithDynamicShapeOp from the host launch op +// and wrap them in device launch op. This allows this op executed on TPU while +// still compiled on host. +std::unique_ptr> +CreateExtractTPUCopyWithDynamicShapeOpPass(); + +// Creates a pass that wraps ReadVariableOp/AssignVariable op that consumes a +// packed tensor to have same device placement as underlying TPU device. +std::unique_ptr> +CreateTPUColocateCompositeResourceOps(); + +// Creates a pass that expands outside compilation cluster at the head/tail of +// TPU computation by adding outside compilation attribute to identity/cast ops +// that are only used for host computation. +std::unique_ptr> +CreateTPUHostComputationExpansionPass(); + +// Creates a pass that updates inputs to TPU embedding layer enqueue ops so that +// correct ops are invoked during training and evaluation. +std::unique_ptr> +CreateTPUUpdateEmbeddingEnqueueOpInputsPass(); + +// Creates a pass that propagates TPU devices to users. +std::unique_ptr> CreateTPUDevicePropagationPass(); + +// Create a pass that colocates each `Split` with its predecessor. +std::unique_ptr> CreateTPUColocateSplitsPass(); + +// Creates a pass that replicates the tf._TPUCompileMlir op on each host that +// needs the compiled program. It helps avoid transferring the compiled binary +// between hosts. +std::unique_ptr> +CreateTPUCompileOpReplicationPass(); + +// Creates a pass that applies space to depth transform +// for the first or frontier convolutions consume host inputs on TPU. +std::unique_ptr> CreateTPUSpaceToDepthPass(); + +// Adjusts the device on TPUCopyWithDynamicShape ops. +std::unique_ptr> +CreateColocateTPUCopyWithDynamicShapePass(); + +} // namespace TFTPU + +// Define the registrations in a detail namespace, just so that we can overload +// the main entry point `registerTensorFlowPasses` to inject +// RegisterTFOptimizePassPipeline. +namespace detail { + +// Direction in which to move transposes in MoveTransposePass. 
+enum MoveTransposeDirection { kBegin, kEnd }; + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_BATCHMATMULTOEINSUMPASS +#define GEN_PASS_DECL_BREAKUPISLANDSPASS +#define GEN_PASS_DECL_BROADCASTFOLDPASS +#define GEN_PASS_DECL_CANONICALIZECOMPILEANDREPLICATEATTRIBUTESPASS +#define GEN_PASS_DECL_CLUSTERCONSTANTSINKINGPASS +#define GEN_PASS_DECL_CLUSTERFORMATIONPASS +#define GEN_PASS_DECL_CLUSTEROUTLININGPASS +#define GEN_PASS_DECL_CLUSTERTFOPSBYHOSTPASS +#define GEN_PASS_DECL_CONSTANTOPDEVICEASSIGNMENTPASS +#define GEN_PASS_DECL_CONVERTLAUNCHFUNCTOTFCALLPASS +#define GEN_PASS_DECL_CONVERTREADONLYREFERENCEVARIABLESTORESOURCEVARIABLESPASS +#define GEN_PASS_DECL_CONVERTTFCONTROLFLOWTOSCFPASS +#define GEN_PASS_DECL_CONVERTTOLEGACYCOMPILEANDREPLICATEATTRIBUTESPASS +#define GEN_PASS_DECL_DECOMPOSEREDUCEDATASETPASS +#define GEN_PASS_DECL_DEVICEINDEXSELECTORPASS +#define GEN_PASS_DECL_DROPWHILESHAPEINVARIANTINDEVICECLUSTERPASS +#define GEN_PASS_DECL_DROPWHILESHAPEINVARIANTPASS +#define GEN_PASS_DECL_EXECUTORCHECKCONTROLDEPENDENCIESPASS +#define GEN_PASS_DECL_EXECUTORCONVERTCONTROLTODATAOUTPUTSPASS +#define GEN_PASS_DECL_EXECUTORDIALECTTOFUNCTIONALPASS +#define GEN_PASS_DECL_EXECUTORGRAPHPRUNINGPASS +#define GEN_PASS_DECL_EXECUTORISLANDCOARSENINGPASS +#define GEN_PASS_DECL_EXECUTORTPUV1ISLANDINLININGPASS +#define GEN_PASS_DECL_EXECUTORUPDATECONTROLDEPENDENCIESPASS +#define GEN_PASS_DECL_FUNCTIONALCONTROLFLOWTOCFGPASS +#define GEN_PASS_DECL_FUNCTIONALCONTROLFLOWTOREGIONSPASS +#define GEN_PASS_DECL_FUNCTIONALTOEXECUTORDIALECTCONVERSIONPASS +#define GEN_PASS_DECL_FUSEDKERNELMATCHERPASS +#define GEN_PASS_DECL_GROUPBYDIALECTPASS +#define GEN_PASS_DECL_GUARANTEEALLFUNCSONEUSEPASS +#define GEN_PASS_DECL_HOISTREPLICATEINVARIANTRESOURCEWRITESPASS +#define GEN_PASS_DECL_INITTEXTFILETOIMPORTPASS +#define GEN_PASS_DECL_LAUNCHOUTLININGPASS +#define GEN_PASS_DECL_LAYOUTASSIGNMENTPASS +#define GEN_PASS_DECL_LEGALIZEHLOTOTFPASS +#define GEN_PASS_DECL_LEGALIZETFGTOTFPASS +#define GEN_PASS_DECL_LOCALIZEVARHANDLESPASS +#define GEN_PASS_DECL_LOWERQUANTIZEDPASS +#define GEN_PASS_DECL_MARKINPUTOUTPUTALIASESPASS +#define GEN_PASS_DECL_MATERIALIZEPASSTHROUGHOP +#define GEN_PASS_DECL_MERGECONTROLFLOWPASS +#define GEN_PASS_DECL_MOVETRANSPOSESPASS +#define GEN_PASS_DECL_ORDERBYDIALECTPASS +#define GEN_PASS_DECL_PARALLELEXECUTETOISLANDSPASS +#define GEN_PASS_DECL_PREPARETPUCOMPUTATIONFORTFEXPORTPASS +#define GEN_PASS_DECL_PROMOTERESOURCESTOARGSPASS +#define GEN_PASS_DECL_PROMOTEVARHANDLESTOARGSPASS +#define GEN_PASS_DECL_REGIONCONTROLFLOWTOFUNCTIONALPASS +#define GEN_PASS_DECL_REMOVEUNUSEDARGUMENTSPASS +#define GEN_PASS_DECL_REMOVEUNUSEDWHILERESULTSPASS +#define GEN_PASS_DECL_REPLICAIDTODEVICEORDINALPASS +#define GEN_PASS_DECL_REPLICATEINVARIANTOPHOISTINGPASS +#define GEN_PASS_DECL_REPLICATETOISLANDPASS +#define GEN_PASS_DECL_RESOURCEDEVICEINFERENCEPASS +#define GEN_PASS_DECL_REWRITETPUEMBEDDINGOPSPASS +#define GEN_PASS_DECL_SIMPLETFDEVICEASSIGNMENTPASS +#define GEN_PASS_DECL_SPLITINTOISLANDPEROPPASS +#define GEN_PASS_DECL_STACKOPSDECOMPOSITIONPASS +#define GEN_PASS_DECL_STRIPNOINLINEATTRIBUTEPASS +#define GEN_PASS_DECL_TFDATAOPTIMIZATIONPASS +#define GEN_PASS_DECL_TFDEVICEASSIGNMENTBYFUNCATTRPASS +#define GEN_PASS_DECL_TPUBRIDGEEXECUTORISLANDOUTLININGPASS +#define GEN_PASS_DECL_TPUCLEANUPCLUSTERATTRIBUTESPASS +#define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS +#define GEN_PASS_DECL_TPUCOLOCATECOMPOSITERESOURCEOPSPASS +#define GEN_PASS_DECL_TPUDEVICEPROPAGATIONPASS +#define GEN_PASS_DECL_TPUDYNAMICLAYOUTPASS +#define 
GEN_PASS_DECL_TPUHOSTCOMPUTATIONEXPANSIONPASS +#define GEN_PASS_DECL_TPUIDENTITYPRUNINGPASS +#define GEN_PASS_DECL_EXTRACTTPUCOPYWITHDYNAMICSHAPEOPPASS +#define GEN_PASS_DECL_TPUPARALLELEXECUTESINKRESOURCEWRITEPASS +#define GEN_PASS_DECL_TPUREORDERREPLICATEANDPARTITIONEDINPUTSPASS +#define GEN_PASS_DECL_TPURESOURCEREADFORWRITEPASS +#define GEN_PASS_DECL_TPURESOURCEREADSWRITESPARTITIONINGPASS +#define GEN_PASS_DECL_TPUSPACETODEPTHPASS +#define GEN_PASS_DECL_TPUUPDATEEMBEDDINGENQUEUEOPINPUTSPASS +#define GEN_PASS_DECL_TENSORARRAYOPSDECOMPOSITIONPASS +#define GEN_PASS_DECL_TENSORDEVICECOPYCONVERSIONPASS +#define GEN_PASS_DECL_TENSORFLOWOPTIMIZEPASS +#define GEN_PASS_DECL_TENSORFLOWSHAPEINFERENCEPASS +#define GEN_PASS_DECL_TENSORLISTOPSDECOMPOSITIONPASS +#define GEN_PASS_DECL_TENSORFLOWGPUFUSION +#define GEN_PASS_DECL_TPUV1BRIDGEEXECUTORISLANDCOARSENINGPASS +#define GEN_PASS_DECL_TRANSFORMEINSUMPASS +#define GEN_PASS_DECL_UNROLLBATCHMATMULPASS +#define GEN_PASS_DECL_VERIFYSUITABLEFOREXPORTPASS +#define GEN_PASS_DECL_XLACALLMODULEDESERIALIZATIONPASS +#define GEN_PASS_DECL_XLACALLMODULESERIALIZATIONPASS +#define GEN_PASS_DECL_XLACALLMODULECUSTOMCALLTFFUNCTIONRENAMINGPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +} // namespace detail +using namespace detail; // NOLINT +inline void registerTensorFlowPasses() { + detail::registerTensorFlowPasses(); + TF::RegisterTFOptimizePassPipeline(); +} + +namespace TFDevice { +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_ANNOTATEPARAMETERREPLICATIONPASS +#define GEN_PASS_DECL_DECOMPOSERESOURCEOPSINCLUSTERPASS +#define GEN_PASS_DECL_DECOMPOSERESOURCEOPSPASS +#define GEN_PASS_DECL_DEVICEATTRIBUTETOLAUNCHPASS +#define GEN_PASS_DECL_HOSTLAUNCHTOOUTSIDECOMPILEDPASS +#define GEN_PASS_DECL_LAUNCHTODEVICEATTRIBUTEPASS +#define GEN_PASS_DECL_OUTSIDECOMPILEDTOHOSTLAUNCHPASS +#define GEN_PASS_DECL_RESOURCEOPLIFTINGFORMAINFUNCTIONPASS +#define GEN_PASS_DECL_RESOURCEOPLIFTINGPASS +#define GEN_PASS_DECL_VERIFYNOOUTSIDECOMPILATIONMARKERSPASS +#define GEN_PASS_DECL_XLACLUSTERFORMATIONPASS +#define GEN_PASS_DECL_XLAINLINEDEVICEOPSPASS +#define GEN_PASS_DECL_XLAREWRITEPASS +#define GEN_PASS_DECL_XLAREWRITEV2PASS +#define GEN_PASS_DECL_XLAVALIDATEINPUTSPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_device_passes.h.inc" +} // namespace TFDevice + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h new file mode 100644 index 00000000..f526acc1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
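The inline registerTensorFlowPasses() wrapper above is what a standalone opt-style driver would call before parsing its command line. A sketch of such a driver, assuming the standard MlirOptMain entry point from MLIR's tooling library; dialect registration is deliberately left out.

#include "mlir/IR/DialectRegistry.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

int main(int argc, char** argv) {
  // Expose the passes declared in this header to the command-line pipeline.
  mlir::registerTensorFlowPasses();
  mlir::DialectRegistry registry;
  // A real tool would insert the TF, tf_executor and tf_device dialects here.
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "example TF pass driver\n", registry));
}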
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +// Performs IR cleanup and canonicalization in preparation for Resource Op +// Lifting pass. It does several things: +// - Eliminate identity nodes to remove (most) of resource aliasing +// - Canonicalize functional control flow. For functional control flow we +// expect that any resource output of these ops matches the corresponding +// input, and then forward that input to the output. Fails if this is not the +// case. If successful, the following invariants will hold true: +// (a) For if/case, any resource type results will be deleted. +// (b) For while, any resource type results will be unused. +// - Canonicalize region based control flow. Again, any resource outputs are +// expected to be resolved to be one of the captured resource inputs. Fails +// if this is not the case. If successful, the following invariants will hold +// true: +// (a) For if/case, any resource type results will be deleted. +// (b) For while, any resource type results will be unused. +namespace mlir { +namespace TF { +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(ModuleOp module); +LogicalResult CleanupAndCanonicalizeForResourceOpLifting(func::FuncOp func); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_RESOURCE_OP_LIFTING_CLEANUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h new file mode 100644 index 00000000..b8bc0a1d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h @@ -0,0 +1,95 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_ + +#include + +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Returns int, float or complex DenseElementsAttr with scalar shape with the +// given element type and the integer value. 
+template +DenseElementsAttr GetScalarOfType(Type ty, T raw_value) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + if (auto float_ty = mlir::dyn_cast(ty)) { + FloatAttr attr = FloatAttr::get(float_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto int_ty = mlir::dyn_cast(ty)) { + IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto complex_ty = mlir::dyn_cast(ty)) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } + } + llvm_unreachable("unsupported type"); +} + +// Returns true if `value` is compile-time constant and its splat value equals +// to `raw_value`. +template +bool IsConstantValueOf(Value value, T raw_value) { + auto element_type = mlir::cast(value.getType()).getElementType(); + if (mlir::isa(element_type)) { + DenseFPElementsAttr float_attr; + if (matchPattern(value, m_Constant(&float_attr)) && float_attr.isSplat() && + float_attr.getSplatValue().isExactlyValue(raw_value)) + return true; + } else if (mlir::isa(element_type)) { + DenseIntElementsAttr int_attr; + if (matchPattern(value, m_Constant(&int_attr)) && int_attr.isSplat() && + int_attr.getSplatValue() == raw_value) + return true; + } + + return false; +} + +// Returns true if `op` is placed on GPU device, and false if it's on other +// devices or the device is not specified. +bool IsOnGpuDevice(mlir::Operation *op); + +// Wrappers for CopyDeviceAndUnderscoredAttributes +void CopyDeviceAndUnderscoredAttributesAdaptor(mlir::OpResult src, + mlir::OpResult dest); +void CopyDeviceAndUnderscoredAttributesAdaptor(mlir::Operation *src, + mlir::OpResult dest); +void CopyDeviceAndUnderscoredAttributesAdaptor(mlir::Operation *src, + mlir::Operation *dest); + +// Wrappers for CopyXlaOutsideCompilationAttributes +void CopyXlaOutsideCompilationAttributesAdaptor(mlir::OpResult src, + mlir::OpResult dest); +void CopyXlaOutsideCompilationAttributesAdaptor(mlir::Operation *src, + mlir::OpResult dest); +void CopyXlaOutsideCompilationAttributesAdaptor(mlir::Operation *src, + mlir::Operation *dest); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_REWRITE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h new file mode 100644 index 00000000..8b634b60 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
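The helpers declared in rewrite_util.h above are typically used inside rewrite patterns. A sketch with a hypothetical helper, assuming the elided template parameter of GetScalarOfType is the raw value type and that IsConstantValueOf takes a Value plus that raw value.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"

// Hypothetical helper: decides whether `operand` is a splat zero constant and
// builds a scalar f32 zero attribute that a pattern could materialize.
static bool FoldableAddendZero(mlir::Value operand, mlir::OpBuilder& builder) {
  mlir::DenseElementsAttr zero =
      mlir::TF::GetScalarOfType(builder.getF32Type(), 0.0f);
  (void)zero;  // A pattern would typically wrap this in a tf.Const op.
  return mlir::TF::IsConstantValueOf(operand, 0.0f);
}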
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SET_TPU_INFEED_LAYOUT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SET_TPU_INFEED_LAYOUT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace mlir { + +// Set layouts attribute of tf.InfeedDequeueTuple ops. +bool SetTPUInfeedLayout(ModuleOp mlir_module); + +// Try to determine the right TPU infeed layout. +FailureOr GetTPUInfeedLayout(ArrayRef types, + OpBuilder& rewriter); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SET_TPU_INFEED_LAYOUT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h new file mode 100644 index 00000000..9075754d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SHAPE_INFERENCE_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SHAPE_INFERENCE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project + +namespace mlir { +namespace TF { + +inline constexpr char kMLIRContextSingleThreadVar[] = + "TF_USE_SINGLE_THREAD_MLIR_CONTEXT"; + +// Returns whether type can be further refined. +bool CanBeRefined(Type type); + +// Returns a new arg type based on the shape and element type. If there are +// dynamic bounds attribute to the arg, update the bounds based on the shape +// as well. +Type GetNewArgType(Type old_arg_type, ArrayRef shape, + Type element_type, mlir::MLIRContext* context); + +// Refines all the shapes in a module, skipping the inference for all ops +// whose type is in ops_to_skip. +// Returns a failure() on error, otherwise returns true to indicate that it +// reached convergence, false otherwise. +// If input shapes are provided, first refines the `main` function using +// InferShapeForFunction. +FailureOr InferModuleShape(ModuleOp module, int64_t max_iterations = 10, + ArrayRef ops_to_skip = {}, + ArrayRef> input_shapes = {}); + +// Given a tensorflow NodeShape string, returns a vector of argument shapes +// that can be used with InferShapeForFunction. +// TF NodeShape uses `,` to separate dimensions, and `:` to separate arguments. +// Ex: 1,2:3,4,5:6,? 
--> [[1, 2], [3, 4, 5], [6, ?]] +absl::StatusOr>> ParseArgumentShapes( + absl::string_view input_shapes); + +// Given a list of refined shapes matching the function arguments of func, runs +// shape inference over the function to propagate this updated information, +// skipping the inference for all ops whose type is in ops_to_skip. +// If arg_shapes are empty, then argument shapes will be left unchanged. +// Note: This affects the entire module, and changes are not just scoped to the +// function being inferred. +// Returns a failure() on error, otherwise returns true to indicate that it +// reached convergence, false otherwise. +FailureOr InferShapeForFunction(func::FuncOp func, + ArrayRef> arg_shapes, + int64_t graph_version, + int64_t max_iterations = 10, + ArrayRef ops_to_skip = {}); + +// Create a MLIRContext based on the threading setup in the env var. +std::unique_ptr MakeMLIRContextWithThreading(); + +} // namespace TF + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SHAPE_INFERENCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h new file mode 100644 index 00000000..8944745d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TFDevice { + +// For architectures that support accelerated embedding lookups, this pass will +// rewrite the graph to use pipelining for better device utilization. +std::unique_ptr> CreateEmbeddingSequencingPass(); + +// This is a strictly sequential and formally correct fallback option for the +// embedding pipelining pass intended for debugging during pipelining +// development. +std::unique_ptr> CreateEmbeddingPipeliningPass(); + +// Passes in the program key to embedding ops, by moving the embedding ops +// after the _TPUCompileMlir op. 
+std::unique_ptr> CreateEmbeddingProgramKeyPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_EMBEDDINGSEQUENCINGPASS +#define GEN_PASS_DECL_EMBEDDINGPIPELININGPASS +#define GEN_PASS_DECL_EMBEDDINGPROGRAMKEYPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h.inc" + +} // namespace TFDevice +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SPARSECORE_SPARSECORE_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h new file mode 100644 index 00000000..f2a3eeba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h @@ -0,0 +1,79 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TEST_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TEST_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tf_test { + +// Returns test pass for variable freezing. +std::unique_ptr> CreateFreezeVariableTestPass(); + +// Test pass for applying TF->TF lowering patterns. +std::unique_ptr> CreateTestTFLowerTFPass(); + +// Test passes for visitor util. +std::unique_ptr> CreateTestVisitorUtilPass(); +std::unique_ptr> +CreateTestVisitorUtilInterruptPass(); + +// Test operation clustering based on user defined policy. +std::unique_ptr> CreateTestClusteringPolicyPass(); + +// Test pass for analyzing side-effect analysis result. +std::unique_ptr> CreateTestSideEffectAnalysisPass(); + +std::unique_ptr> CreateTestResourceAliasAnalysisPass(); + +std::unique_ptr> CreateInitTextFileToImportTestPass(); +std::unique_ptr> +CreateInitTextFileToImportSavedModelTestPass(); + +// Variable Lifting test passes: only useful for lit testing. +std::unique_ptr> CreateLiftVariablesTestPass(); +std::unique_ptr> +CreateLiftVariablesInvalidSessionTestPass(); + +// Create a test pass for the above with a "fake" session, for lit testing. +std::unique_ptr> +CreateInitializeVariablesInSessionInitializerTestPass(); + +// Create a test pass that emits remarks for each analysis result for resources. +// This pass is only used for lit testing. 
+std::unique_ptr> CreateResourceAnalyzerTestPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_FREEZEVARIABLESTESTPASS +#define GEN_PASS_DECL_INITTEXTFILETOIMPORTSAVEDMODELTESTPASS +#define GEN_PASS_DECL_INITTEXTFILETOIMPORTTESTPASS +#define GEN_PASS_DECL_INITIALIZEVARIABLESINSESSIONINITIALIZERPASS +#define GEN_PASS_DECL_LIFTVARIABLESINVALIDSESSIONTESTPASS +#define GEN_PASS_DECL_LIFTVARIABLESTESTPASS +#define GEN_PASS_DECL_RESOURCEANALYZERTESTPASS +#define GEN_PASS_DECL_TESTCLUSTERINGPOLICYPASS +#define GEN_PASS_DECL_TESTRESOURCEALIASANALYSIS +#define GEN_PASS_DECL_TESTSIDEEFFECTANALYSISPASS +#define GEN_PASS_DECL_TESTTENSORFLOWLOWERTFPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h.inc" + +} // namespace tf_test +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TEST_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h new file mode 100644 index 00000000..b8a176da --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_data_optimization.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Populates patterns to perform optimizations specific to tf.data operations. +void PopulateTFDataOptimizationPatterns(MLIRContext *context, + RewritePatternSet *patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_DATA_OPTIMIZATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h new file mode 100644 index 00000000..2b601395 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
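The shape-inference entry points declared in shape_inference.h above can be driven as in this sketch; the helper name is hypothetical, and the elided result type is assumed to be FailureOr<bool>, matching the convergence semantics described in the comments.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"

// Hypothetical helper: refines shapes across the whole module and warns if the
// iteration cap was hit before reaching a fixed point.
static mlir::LogicalResult RefineModuleShapes(mlir::ModuleOp module) {
  auto converged = mlir::TF::InferModuleShape(module, /*max_iterations=*/10);
  if (mlir::failed(converged)) return mlir::failure();
  if (!*converged) module.emitWarning("shape inference did not converge");
  return mlir::success();
}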
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ + +#include +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Create a module pass that will execute the given TF GraphOptimization passes +// in sequence. +// Pass requires that the module ran on is convertible to TF Graph. +std::unique_ptr> +CreateTensorFlowGraphOptimizationPass( + std::vector tf_passes); + +// Same as above but pass names instead of the passes provided. The registered +// passes are queried, if a TF graph optimization pass is not found in registry +// then the pass fails. +// Pass requires that the module ran on is convertible to TF Graph. +std::unique_ptr> +CreateTensorFlowGraphOptimizationPass( + const std::vector& pass_names); + +// Register the pass for command line testing. +void RegisterGraphOptimizationPasses(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_GRAPH_OPTIMIZATION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h new file mode 100644 index 00000000..9c08e2d3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_ASSET_SINKING_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_ASSET_SINKING_PASS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tf_saved_model { + +// Helper function that sets up a module for an AssetSinkingPass. The sole +// argument of the main function of `module` is prepared to be inlined with +// the value `checkpoint_path`. +// Also adds SessionInitializer op. +absl::Status AddSessionInitializerAndInlineCheckpoint( + ModuleOp module, absl::string_view checkpoint_path); + +// Creates a pass that sinks SavedModel asset filenames to constants. 
+std::unique_ptr> CreateAssetSinkingPass( + llvm::StringRef saved_model_dir); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_ASSET_SINKING_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_utils.h new file mode 100644 index 00000000..7bfb9871 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_utils.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_FREEZE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_FREEZE_UTILS_H_ + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace tf_saved_model { +// Container to hold all update actions on ops. +// Key: Operation to update. +// Value: optional list of argument indices to delete from this op. +// Note that we use MapVector because we want to iterate on the same order +// of insertion. +LogicalResult EraseObsoleteResourceUses( + llvm::MapVector> + arguments_to_erase); + +// Traces usage of 'var_handle_op' or 'resources' and replaces it's usage with +// constant value 'value'. All op operands updates are captured in +// 'arguments_to_erase'. +LogicalResult ReplaceVarWithConstant( + mlir::Value::use_range uses, ElementsAttr value, + llvm::MapVector>* + arguments_to_erase); +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_FREEZE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h new file mode 100644 index 00000000..ad8d20d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
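A sketch of wiring the asset-sinking pass declared above into a pipeline for a model rooted at `saved_model_dir`; the helper is hypothetical and the pass is assumed to be module-level, since its template arguments are elided in this diff.

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h"

// Hypothetical helper: folds asset file paths under `saved_model_dir` into
// constants so the module no longer references tf_saved_model.asset ops.
static mlir::LogicalResult SinkAssets(mlir::ModuleOp module,
                                      mlir::MLIRContext& context,
                                      llvm::StringRef saved_model_dir) {
  mlir::PassManager pm(&context);
  pm.addPass(mlir::tf_saved_model::CreateAssetSinkingPass(saved_model_dir));
  return pm.run(module);
}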
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_FREEZE_VARIABLES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_FREEZE_VARIABLES_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +// Freezes readonly variables in the graph. +LogicalResult FreezeVariables(ModuleOp module, tensorflow::Session* session); + +} // namespace tf_saved_model + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_FREEZE_VARIABLES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h new file mode 100644 index 00000000..801eaaeb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +// Creates a pass that optimizes tf_saved_model.global_tensor ops. +std::unique_ptr> CreateOptimizeGlobalTensorsPass(); + +// Creates a pass that freezes tf_saved_model.global_tensor ops. +std::unique_ptr> CreateFreezeGlobalTensorsPass( + bool allow_mutable_tensors = false); + +// Creates a pass that freezes tf_saved_model.asset ops. +std::unique_ptr> CreateFreezeAssetsPass( + std::string saved_model_dir = ""); + +// Creates as pass that removes variables in the session initializer. +// This job is required with lifting variable passes. Originally, the session +// initializer function does assigning variables. However, the read-only +// variable assignments will be done via lifting variables pass by converting +// the read-only variables to constant ops, instead. This pass removes the +// redundant operations. This pass should be located in front of the pass for +// lifting read-only variables. +std::unique_ptr> +CreateRemoveVariablesInSessionInitializerPass(); + +// Creates a pass that removes duplicate 'tf_saved_model.bound_input' bindings. 
+std::unique_ptr> CreateDedupBoundInputBindingPass(); + +// Create a pass that removes function arguments that map to global tensors. +std::unique_ptr CreateLowerGlobalsToMlProgramPass(); + +// Create a pass that lowers variable read/write ops to ml_program ops. +std::unique_ptr> +CreateLowerVariableOpsToMlProgramPass(); + +// Strips saved_model attributes from a module and its functions. +std::unique_ptr> CreateStripSavedModuleMetadataPass(); + +// Convert the session initializer to a function. +std::unique_ptr> +CreateConvertSessionInitializerToFunctionPass(); + +// Creates forwarding functions for 'exported_names'. +std::unique_ptr> +CreateAddFunctionsForExportedNamesPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_DEDUPBOUNDINPUTBINDINGPASS +#define GEN_PASS_DECL_FREEZEASSETSPASS +#define GEN_PASS_DECL_FREEZEGLOBALTENSORSPASS +#define GEN_PASS_DECL_LOWERGLOBALSTOMLPROGRAMPASS +#define GEN_PASS_DECL_LOWERVARIABLEOPSTOMLPROGRAMPASS +#define GEN_PASS_DECL_OPTIMIZEGLOBALTENSORSPASS +#define GEN_PASS_DECL_REMOVEVARIABLESINSESSIONINITIALIZERPASS +#define GEN_PASS_DECL_STRIPSAVEDMODULEMETADATAPASS +#define GEN_PASS_DECL_ADDFUNCTIONSFOREXPORTEDNAMESPASS +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_savedmodel_passes.h.inc" + +} // namespace tf_saved_model + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_TF_SAVED_MODEL_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h new file mode 100644 index 00000000..39ceab7c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace mlir { +namespace TF { + +// Populate patterns to unroll tf.BatchMatMulV2 op into a sequence of TF ops. +// Since TFLite does not support BatchMatMul operation, it unrolls a BatchMatMul +// op into tf.Reshape, tf.Slice, tf.MatMul, tf.Pack, and tf.Reshape ops. 
+void PopulateUnrollTfBatchMatMul(MLIRContext* context, + RewritePatternSet& patterns); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_UNROLL_BATCH_MATMUL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h new file mode 100644 index 00000000..47bc42e0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_TF_DIALECT_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_TF_DIALECT_OP_H_ + +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Extracts the attributes of a MLIR operation and populates the converted +// attributes in a proto map. +absl::Status GetAttrValuesFromOperation( + mlir::Operation* inst, llvm::StringRef name, + const tensorflow::OpRegistrationData* op_reg_data, + bool ignore_unregistered_attrs, AttrValueMap* attributes); + +// Converts a MLIR operation to TensorFlow NodeDef with given node name. This +// name should be unique to the graph it is being inserted to. If the +// `ignore_unregistered_attrs` argument is set to true, the attributes which are +// not in the op registry will be ignored. If the `ignore_unregistered_attrs` +// argument is not set to true, _output_shapes attribute is added to nodes with +// ShapedType for the leading values with ShapedType in the results of the +// nodes. Set it to true if the returned NodeDef will be executed by the linked +// TF Eager runtime. +absl::StatusOr> ConvertTFDialectOpToNodeDef( + mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_EXPORT_TF_DIALECT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/import_model.h new file mode 100644 index 00000000..fe7684ad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -0,0 +1,139 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
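The pattern-population function declared in unroll_batch_matmul.h above composes with MLIR's greedy rewrite driver; a sketch with a hypothetical wrapper follows.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

// Hypothetical helper: rewrites tf.BatchMatMulV2 ops in `func` into the
// unrolled tf.Reshape/tf.Slice/tf.MatMul/tf.Pack sequence described above.
static mlir::LogicalResult UnrollBatchMatMulIn(mlir::func::FuncOp func) {
  mlir::MLIRContext* context = func.getContext();
  mlir::RewritePatternSet patterns(context);
  mlir::TF::PopulateUnrollTfBatchMatMul(context, patterns);
  return mlir::applyPatternsAndFoldGreedily(func.getOperation(),
                                            std::move(patterns));
}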
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main"; + +// Given a GraphDef, returns a MLIR module containing the graph, expressed with +// tf_executor dialect. +ABSL_DEPRECATED("Use tensorflow::tf2xla::v2::ConvertGraphToTfExecutor instead.") +absl::StatusOr> ConvertGraphdefToMlir( + const GraphDef& graphdef, const GraphDebugInfo& debug_info, + const GraphImportConfig& specs, mlir::MLIRContext* context); + +// Given a SavedModel, returns a MLIR module containing the functions, expressed +// with tf_executor dialect. +absl::StatusOr> ConvertSavedModelToMlir( + SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, + absl::Span exported_names, MLIRImportOptions options = {}); + +// Given a V1 SavedModel, returns a MLIR module containing the functions, +// expressed with tf_executor dialect. +absl::StatusOr> ConvertSavedModelV1ToMlir( + const SavedModelBundle& saved_model, absl::Span exported_names, + mlir::MLIRContext* context, MLIRImportOptions options = {}); + +// Given a V1 SavedModel, returns a MLIR module containing the functions, +// expressed with tf_executor dialect. It does not require a session to be +// created and it does not perform any graph transformation. If `exported_names` +// is std::nullopt, all signatures will be imported. Otherwise, only names +// in `exported_names` are imported. +// +// Note that the word `Lite` means it is a lighter version compared to +// ConvertSavedModelV1ToMlir(), and is not related to TFLite. +// +// TODO(b/179683149): Rename this class to avoid confusion with TFLite. 
+absl::StatusOr> ConvertSavedModelV1ToMlirLite( + const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, + std::optional> exported_names, + mlir::MLIRContext* context, MLIRImportOptions options); + +// SavedModelMLIRImportInput is an adapter class for users to inject custom +// graph transformation logic on Tensorflow graphs before importing to MLIR. It +// serves as the source that provides the subgraphs requested by the savedmodel +// MLIR importer, and at the same time it allows the implementation of this +// class to transform the graph before feeding it to the importer. +class SavedModelMLIRImportInput { + public: + SavedModelMLIRImportInput(const MetaGraphDef* meta_graph_def, + const GraphDebugInfo& debug_info) + : meta_graph_def_(meta_graph_def), debug_info_(debug_info) { + DCHECK(meta_graph_def); + } + + virtual ~SavedModelMLIRImportInput(); + + // The original MetaGraphDef of the savedmodel. + const MetaGraphDef& meta_graph_def() const { return *meta_graph_def_; } + + const GraphDebugInfo& debug_info() const { return debug_info_; } + + // GetSubGraph() is expected to return a tensorflow::Graph that contains the + // node set specified in `specs`. The implementation is free to transform the + // graph in the original savedmodel as needed, as long as it produces the same + // results and effects. If the transformation requires some configs in `spec` + // (e.g., control_outputs) to be changed, they should be updated accordingly + // and remain valid for the graph. + // `name` is a unique identifier for this subgraph, so the implementation can + // use it for eg. debugging or caching compilation results. + virtual absl::StatusOr GetSubGraph( + absl::string_view name, GraphImportConfig& specs) = 0; + + private: + const MetaGraphDef* meta_graph_def_ = nullptr; + GraphDebugInfo debug_info_; +}; + +// Given the SavedModelMLIRImportInput for a saved model, returns a MLIR module +// containing the functions, expressed with tf_executor dialect. It does not +// require a session to be created. If `exported_names` is std::nullopt, all +// signatures will be imported. Otherwise, only names in `exported_names` are +// imported. + +// +// Note that the word `Lite` means it is a lighter version compared to +// ConvertSavedModelV1ToMlir(), and is not related to TFLite. +// +// TODO(b/179683149): Rename this class to avoid confusion with TFLite. +absl::StatusOr> ConvertSavedModelV1ToMlirLite( + SavedModelMLIRImportInput& input, + std::optional> exported_names, + mlir::MLIRContext* context, + bool unconditionally_use_set_output_shapes = false); + +// Serialize a MLIR module to a string. +std::string MlirModuleToString(mlir::ModuleOp m, bool show_debug_info = false); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h new file mode 100644 index 00000000..b49ed7bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
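A sketch of the intended call sequence for ConvertSavedModelToMlir declared above. The exported name is a placeholder, SavedModelV2Bundle::Load is assumed to take (export_dir, bundle*), and the elided result type is assumed to be an owning reference to the imported module.

#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"

// Hypothetical helper: loads a TF2 SavedModel from disk and imports the
// requested signatures as an MLIR module in tf_executor form.
absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
    const std::string& export_dir, mlir::MLIRContext* context) {
  tensorflow::SavedModelV2Bundle bundle;
  const auto load_status =
      tensorflow::SavedModelV2Bundle::Load(export_dir, &bundle);
  if (!load_status.ok()) return load_status;
  // Placeholder signature name; real callers pass the names they need.
  std::vector<std::string> exported_names = {"serving_default"};
  return tensorflow::ConvertSavedModelToMlir(&bundle, context,
                                             absl::MakeSpan(exported_names));
}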
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_ + +namespace tensorflow { + +// TODO(jpienaar): This file and class are confusingly named. This seems to be +// a SavedModel only import options file that exposes a subset of the +// GraphImportConfig options, but the naming would make one think it is more +// general. +struct MLIRImportOptions { + // If true, functionalize the input graph before importing it into MLIR. + bool upgrade_legacy = false; + + // Whether to unconditionally use the shape set via _output_shapes on import. + bool unconditionally_use_set_output_shapes = false; + + // Apply default attributes from the op definition to the loaded op. + bool add_default_attributes = true; + + // If set, promote tf.VarHandleOp to resource arguments for all functions. + bool lift_variables = true; + + // Keeps the variables in initializers before lifting variables (when + // `lift_variables == true`) or newly adding variable initialization patterns + // in the initializer functions. One might want to set this to `true` because + // the `RemoveVariablesInSessionInitializerPass` pass, which runs otherwise, + // may unexpectedly also remove the initialization patterns for non-variable + // resources (like hash tables) if they involve variables. Such a case is + // illustrated in the test file + // "../tests/tf_saved_model_remove_vars_in_session_initializer.mlir". + // This defaults to `false` to avoid breaking existing uses. + bool include_variables_in_initializers = false; + + // Load the model without restoring associated variables from disk. Enables + // loading raw programs without checkpoints. + bool allow_uninitialized_variables = false; + + // If true, variables are imported as DenseResourceElementsAttr; else, + // variables are imported as DenseElementsAttr. + bool import_variables_as_dense_resources = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h new file mode 100644 index 00000000..cf90b7ed --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -0,0 +1,160 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/StringMap.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +struct ArrayInfoBase { + // The node type when the input node is imported. Typically needs to be + // specified when passing arbitrary nodes (some node attributes are removed). + DataType imported_dtype; + + // Node "shape" attribute value. + TensorShapeProto shape; +}; + +struct ArrayInfo : public ArrayInfoBase { + using SubTypeInfo = ArrayInfoBase; + // DT_RESOURCE and DT_VARIANT have subtypes + std::vector subtypes; +}; + +struct GraphImportConfig { + // Returns string representation of config. + std::string str() const; + + using InputArrays = + llvm::MapVector>; + // The name assigned to the function which is the import result of the given + // graph. If empty, a default one will be used. + std::string graph_func_name; + // Maps input node names to node data types and shapes. + InputArrays inputs; + // name:index strings for the data outputs. + std::vector outputs; + // name strings for the control outputs. + std::vector control_outputs; + // Setting prune_unused_nodes to true, would prune unreachable nodes if + // output_arrays is specified. + bool prune_unused_nodes = false; + // If true, inputs of type LegacyFedInput are replaced with Placeholder ops. + // LegacyFedInput ops have two outputs unlike Placeholder which has only one + // output, so if both outputs of the LegacyFedInput ops are used then returns + // an error. + bool convert_legacy_fed_inputs = false; + // If true, the main graph will be treated as a function. + bool graph_as_function = false; + // If true, upgrade legacy features of the graph (for instance, functionalize + // control-flow). + bool upgrade_legacy = false; + // If true, functionalization is restricted to nodes that will be + // XLA-compiled. This is only needed if + // - `upgrade_legacy` is true + // - upgrading legacy features of the graph (which includes functionalization) + // runs before compilation cluster extraction (as for MLIR-based TPU bridge) + // - session runtime is used (session runtime has issues with function names + // rewritten by functionalization). + // Otherwise, this parameter should be set to false. + bool restrict_functionalization_to_compiled_nodes = false; + // If true, enables shape inference on input. + // TODO(jpienaar): This will be removed shortly. + bool enable_shape_inference = true; + // _output_shapes is an unregistered attribute which is used during + // GraphConstructor::ConvertGraph to override shapes. It is unfortunately + // not always set correctly (which is undesirable and should be addressed) + // so make it opt-in to consider it unconditionally also when importing the + // graph. + bool unconditionally_use_set_output_shapes = false; + // If set, use the value as the device type and mark the function graph for + // XLA compilation. 
+ string xla_compile_device_type; + // If true, enables moving ops to different devices or moving unsupported ops + // out of a compilation cluster. + bool enable_soft_placement = false; + // If true, a function attribute, `tf._original_func_name`, will be set in + // functions which contains the corresponding original TF function name. + bool set_original_tf_func_name = false; + + // If true, all functions in the graph will be converted to MLIR regardless of + // whether the functions are referenced by the nodes. This is needed if + // aliases and saved model object graph function matching is needed. + bool convert_all_functions_to_mlir = false; +}; + +struct GraphExportConfig { + // Whether to export the entry function to function library instead of the + // graph. + bool export_entry_func_to_flib = false; + // Whether to export functions using the name set in the attribute + // `tf._original_func_name` if it exists. + bool export_original_tf_func_name = false; +}; + +// Parses the command line flag strings to the specification of nodes in +// the Graph. +absl::Status ParseOutputArrayInfo(absl::string_view array_names, + std::vector* outputs); + +absl::Status ParseOutputArrayInfo(const std::vector& output_names, + std::vector* outputs); + +// Parses the command line flag strings to the specification of nodes in +// the Graph. `data_types` input string can be empty since the flag is optional. +absl::Status ParseInputArrayInfo(absl::string_view array_names, + absl::string_view data_types, + absl::string_view shapes, + GraphImportConfig::InputArrays* inputs); + +absl::Status ParseInputArrayInfo( + const std::vector& node_names, + const std::vector& node_dtypes, + const std::vector>>& node_shapes, + GraphImportConfig::InputArrays* inputs); + +// Parses shapes from the given string into shapes_vector which is a structured +// format. +// NOTE: If shapes_str is empty, shapes_vector will also be empty. +absl::Status ParseNodeShapes( + absl::string_view shapes_str, + std::vector>>& shapes_vector); + +// Parses names from the given string into the names_vector. +// NOTE: If names_str is empty, names_vector will also be empty. +absl::Status ParseNodeNames(absl::string_view names_str, + std::vector& names_vector); + +// Parses data types from the given string into the data_type_vector. +// NOTE: If data_types_str is empty, data_type_vector will also be empty. +absl::Status ParseNodeDataTypes(absl::string_view data_types_str, + std::vector& data_type_vector); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h new file mode 100644 index 00000000..8d404575 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -0,0 +1,140 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" + +namespace tensorflow { + +using tsl::Status; +using tsl::StatusOr; + +struct GraphdefToMlirOptions { + std::string debug_info_file; + std::string xla_compile_device_type; + bool prune_unused_nodes; + bool convert_legacy_fed_inputs; + bool graph_as_function; + bool upgrade_legacy; + bool enable_shape_inference; + bool unconditionally_use_set_output_shapes; + bool enable_soft_placement; + bool set_original_tf_func_name = false; +}; + +// TODO(antiagainst): Directly manipulating files in library functions is not +// a good idea. We should pass in a string/stream here. + +// Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. +// Creates MLIR entities into the given MLIR `context`. +absl::StatusOr> +GraphdefToMlirTranslateFunction( + llvm::StringRef input, const std::vector& input_arrays, + const std::vector& input_dtypes, + const std::vector>>& input_shapes, + const std::vector& output_arrays, + const std::vector& control_output_arrays, + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context); + +ABSL_DEPRECATED( + "Please use the other overload of this function which accepts structured " + "inputs instead of strings") +// Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. +// Creates MLIR entities into the given MLIR `context`. +absl::StatusOr> +GraphdefToMlirTranslateFunction( + llvm::StringRef input, absl::string_view input_arrays, + absl::string_view input_dtypes, absl::string_view input_shapes, + absl::string_view output_arrays, absl::string_view control_output_arrays, + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context); + +// Similar as the above function, but replaces all constant tensors +// with randomly generated splat values. +absl::StatusOr> +GraphdefToSplattedMlirTranslateFunction( + llvm::StringRef input, const std::vector& input_arrays, + const std::vector& input_dtypes, + const std::vector>& input_shapes, + const std::vector& output_arrays, + const std::vector& control_output_arrays, + const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context); + +ABSL_DEPRECATED( + "Please use the other overload of this function which accepts structured " + "inputs instead of strings") +// Similar as the above function, but replaces all constant tensors +// with randomly generated splat values. 
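+//
+// For reference, an illustrative call of the structured
+// GraphdefToMlirTranslateFunction overload above (a sketch, not part of the
+// upstream header; node names and shapes are assumptions):
+//
+//   mlir::MLIRContext context;
+//   GraphdefToMlirOptions opts{};
+//   opts.upgrade_legacy = true;
+//   auto module = GraphdefToMlirTranslateFunction(
+//       graphdef_text, /*input_arrays=*/{"input0"},
+//       /*input_dtypes=*/{"DT_FLOAT"},
+//       /*input_shapes=*/{std::vector<int>{1, 224, 224, 3}},
+//       /*output_arrays=*/{"output0"}, /*control_output_arrays=*/{}, opts,
+//       &context);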
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
+GraphdefToSplattedMlirTranslateFunction(
+    llvm::StringRef input, absl::string_view input_arrays,
+    absl::string_view input_dtypes, absl::string_view input_shapes,
+    absl::string_view output_arrays, absl::string_view control_output_arrays,
+    const GraphdefToMlirOptions& import_options, mlir::MLIRContext* context);
+
+// Converts a TensorFlow SavedModel stored in the directory with the given
+// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
+// given MLIR `context`.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
+SavedModelObjectGraphToMlirImport(
+    absl::string_view saved_model_dir,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    bool unconditionally_use_set_output_shapes = false,
+    bool import_variables_as_dense_resources = false);
+
+// Converts a TensorFlow V1 SavedModel stored in the directory with the given
+// `saved_model_dir` into a MLIR module. Creates MLIR entities into the
+// given MLIR `context`.
+// If `saved_model_bundle` is not null, it will be initialized with the model
+// bundle.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
+SavedModelSignatureDefsToMlirImport(
+    absl::string_view saved_model_dir,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    MLIRImportOptions options,
+    std::unique_ptr<tensorflow::SavedModelBundle>* saved_model_bundle =
+        nullptr);
+
+// Converts a TensorFlow V1 SavedModel stored in the directory with the given
+// `saved_model_dir` into a MLIR module. Creates MLIR entities into the given
+// MLIR `context`. This does not create a session internally, so it is faster
+// and does not perform any graph transformation.
+absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
+SavedModelSignatureDefsToMlirImportLite(
+    absl::string_view saved_model_dir,
+    const std::unordered_set<std::string>& tags,
+    absl::Span<std::string> exported_names, mlir::MLIRContext* context,
+    MLIRImportOptions options);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h
new file mode 100644
index 00000000..31baee55
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h
@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+class GraphDef;
+class MetaGraphDef;
+
+// Generates the shared_name for resource handle ops in the graph and functions
+// if their shared_names are empty. Resource handle ops with empty shared_name
+// may have undesired semantics.
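+//
+// Illustrative use (a sketch, not part of the upstream header; the GraphDef
+// value is an assumption):
+//
+//   GraphDef gdef = ...;  // e.g. read from a frozen graph
+//   absl::Status status =
+//       GenerateResourceSharedNameIfEmpty(gdef, OpRegistry::Global());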
+absl::Status GenerateResourceSharedNameIfEmpty( + GraphDef& gdef, const OpRegistryInterface* default_registry); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_UPGRADE_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h new file mode 100644 index 00000000..6ed684e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -0,0 +1,206 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/tf2xla/tf2xla_defs.h" + +namespace mlir { +namespace TF { + +// TODO(b/229028654) Use definitions from tf2xla_defs.h directly. We currently +// don't do this to avoid explicit casts (implicit conversion from +// `absl::string_view` to `llvm::StringRef` is not supported until C++17). + +// Whether soft placement is allowed. If true, the marked node is eligible for +// outside compilation. +inline constexpr llvm::StringRef kAllowSoftPlacementAttr = + "allow_soft_placement"; + +// Marks a node for XLA compilation. The attribute value indicates the +// compilation device type. +inline constexpr llvm::StringRef kCompileDeviceTypeAttr = + "_xla_compile_device_type"; +// The attribute value speicifes the preferred outlined function name in +// ClusterOutliningPass. +inline constexpr llvm::StringRef kClusterOutlinedFunctionNameAttr = + "_cluster_outlined_function_name"; +// Marks a node for replication. The attribute value indicates the replication +// metadata op. +inline constexpr llvm::StringRef kReplicationInfoAttr = "_replication_info"; +// Marks a node for XLA-TPU compilation. The attribute value indicates the +// associated compilation cluster and replication metadata op. +inline constexpr llvm::StringRef kTpuReplicateAttr = "_tpu_replicate"; +// Device types. +inline constexpr llvm::StringRef kTpuDevice = "TPU"; +// _xla_outside_compilation +inline constexpr llvm::StringRef kXlaOutsideCompilationAttr = + "_xla_outside_compilation"; +// device attr +inline constexpr llvm::StringRef kDeviceAttr = "device"; +// Function attribute to signal that a function should be skipped from TPU +// island outlining. The attribute is set in +// `TpuV1BridgeExecutorIslandCoarsening` and removed in the subsequent +// `TPUBridgeExecutorIslandOutlining` pass. 
+inline constexpr llvm::StringRef kSkipIslandOutlining = + "_skip_island_outlining"; +// Function attribute to signal which argument contains bounded dynamic +// dimension. +inline constexpr llvm::StringRef kDynamicArgIndexAttr = "_dynamic_arg_index"; + +// This string attribute encodes parallel execution groups and their associated +// branches. It has the following format: +// `_parallel_execution_ids= group1:branch1,group2:branch2,...` +// For example, if we have IR as follows: +// +// tf_executor.island wraps "tf.OpA" +// tf_executor.island { +// "tf_device.replicate" {n = 2} { +// "tf.OpB" +// "tf_device.parallel_execute"() ({ +// "tf.OpC" +// }, { +// "tf.OpD" +// }) +// } +// +// The above IR will be flattened after `ReplicateToIslandPass` and +// `ParallelExecuteToIslandsPass` as follows: +// +// tf_executor.island wraps "tf.OpA" +// tf_executor.island {_parallel_execution_ids=r0:0} wraps "tf.OpB" +// tf_executor.island {_parallel_execution_ids=r0:0,p0:0} wraps "tf.OpC" +// tf_executor.island {_parallel_execution_ids=r0:0,p0:1} wraps "tf.OpD" +// tf_executor.island {_parallel_execution_ids=r0:1} wraps "tf.OpB" +// tf_executor.island {_parallel_execution_ids=r0:1,p0:0} wraps "tf.OpC" +// tf_executor.island {_parallel_execution_ids=r0:1,p0:1} wraps "tf.OpD" +// +// "tf.OpA" will not have `_parallel_execution_ids` attr, +// means it does not belong to any parallel execution groups. +// First instance of "tf.OpB" after flattening will have +// `_parallel_execution_ids = "r0:0"`, +// which represents the first branch of replicate group 0. +// Second instance of "tf.OpB" after flattening will have +// `_parallel_execution_ids = "r0:1"` +// which represents the second branch of replicate group 0. +// First instance of "tf.OpC" after flattening will have +// `_parallel_execution_ids = "r0:0,p0:0"` +// which represents the first branch of replicate group 0 and +// the first branch of parallel group 0. +// Second instance of "tf.OpC" after flattening will have +// `_parallel_execution_ids = "r0:1,p0:0"` +// which represents the second branch of replicate group 0 and +// the first branch of parallel group 0. +// First instance of "tf.OpD" after flattening will have +// `_parallel_execution_ids = "r0:0,p0:1"` +// which represents the first branch of replicate group 0 and +// the second branch of parallel group 0. +// Second instance of "tf.OpD" after flattening will have +// `_parallel_execution_ids = "r0:1,p0:1"` +// which represents the second branch of replicate group 0 and +// the second branch of parallel group 0. +inline constexpr llvm::StringRef kParallelExecAnnotation = + "_parallel_execution_ids"; + +// Logging + +// Name of component for error logging. This name is fixed and required to +// enable logging. +inline const char kBridgeComponent[] = "TFXLABridge"; +inline const char kMlirPh1BridgeCounterReplicated[] = "replicated"; +inline const char kMlirPh1BridgeCounterNonReplicated[] = "nonreplicated"; +inline const char kMlirPh1BridgeCounterV1[] = "v1"; +inline const char kMlirPh1BridgeCounterV2[] = "v2"; +inline const char kMlirPh1BridgeCounterTpu[] = "tpu"; +inline const char kMlirPh1BridgeCounterNonTpu[] = "cpu/gpu"; +inline const char kXlaOutsideCompilation[] = "_xla_outside_compilation"; + +// Copies attributes that satisfy the given predicate from `from` to `to`. 
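+//
+// Illustrative use of the copy helpers below (a sketch, not part of the
+// upstream header; `src` and `dst` are assumed Operation pointers):
+//
+//   CopyAttributes(src, dst, [](const NamedAttribute &attr) {
+//     return attr.getName().strref().starts_with("_tpu");
+//   });
+//   CopyUnderscoredAttributes(src, dst);
+//   CopyXlaOutsideCompilationAttributes(src, dst);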
+template +void CopyAttributes(Operation *from, Operation *to, Predicate P) { + for (const NamedAttribute &attr : from->getAttrs()) + if (P(attr)) to->setAttr(attr.getName(), attr.getValue()); +} + +// Copies attributes whose name begins with an _ from `from` to `to`. +inline void CopyUnderscoredAttributes(Operation *from, Operation *to) { + CopyAttributes(from, to, [](const NamedAttribute &attr) { + return attr.getName().strref().front() == '_'; + }); +} + +// Copies outside compilation attribute from `from` to `to`. +inline void CopyXlaOutsideCompilationAttributes(Operation *from, + Operation *to) { + CopyAttributes(from, to, [](const NamedAttribute &attr) { + return attr.getName().strref() == kXlaOutsideCompilationAttr; + }); +} + +// Copies attributes that are either `device` or whose name begins with an _ +// from `from` to `to`. +// TODO(b/158769932): This should be a general feature instead post some policy +// discussion. +inline void CopyDeviceAndUnderscoredAttributes(Operation *from, Operation *to) { + auto device = mlir::StringAttr::get(from->getContext(), "device"); + CopyAttributes(from, to, [&device](const NamedAttribute &attr) { + return attr.getName().strref().front() == '_' || attr.getName() == device; + }); +} + +// Forward declare these passthrough ops. +// TODO(jpienaar): Remove these and use trait instead. +class IdentityOp; +class IdentityNOp; + +// Returns if a value corresponds to a constant, returns the matched constant +// as an attribute. +template +bool GetValueAsConstant(Value val, AttrT &attr) { + while (auto result = mlir::dyn_cast(val)) { + Operation *op = result.getOwner(); + if (!isa(op) && !isa(op)) break; + val = op->getOperand(result.getResultNumber()); + } + return matchPattern(val, m_Constant(&attr)); +} + +// Checks if both compilation and replication attributes are present in the +// operation, and if their values are valid. +LogicalResult HasValidCompilationAndReplicationAttributes(Operation &op); + +// Checks if the device attribute is valid. +LogicalResult IsValidDeviceTypeOrEmpty(StringAttr attr); + +using ParallelExecutionIdPairs = + llvm::SmallVector, 8>; +// Parses the parallel execution attribute for `op` and fills `id_pairs` with +// the corresponding (group ID,branch ID) pairs. +// Returns `failure` if the attribute is malformed. +LogicalResult ParseParallelExecutionIds(Operation *op, + ParallelExecutionIdPairs &id_pairs); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h new file mode 100644 index 00000000..84bc1c60 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ + +#include +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace tensorflow { + +// Logger for logging MLIR modules before and after passes in MLIR TPU bridge. +// +// The IR logging can be restricted to a particular set of pass invocations via +// filters that are specified with the `MLIR_BRIDGE_LOG_PASS_FILTER` and +// `MLIR_BRIDGE_LOG_STRING_FILTER` environment variables. +// `MLIR_BRIDGE_LOG_PASS_FILTER` takes a semicolon-separated list of pass class +// names, `MLIR_BRIDGE_LOG_STRING_FILTER` takes a semicolon-separated list of +// strings, and IR is only dumped for a pass invocation if the pass name exactly +// matches any of the provided pass names and if the serialized operation on +// which the pass is invoked contains any of the specified strings as a +// substring. An empty list is interpreted as no restriction. The string filter +// can be handy e.g. if one is only interested in a certain function or when +// checking where a certain attribute gets lost. Note that we use a semicolon +// instead of comma as the separator to allow strings that contain commas (which +// frequently appear in MLIR). The strings can contain any characters (including +// spaces) except semicolons. +// +// Example: Setting the environment variables +// `MLIR_BRIDGE_LOG_PASS_FILTER="LegalizeTF;Canonicalizer"` and +// `MLIR_BRIDGE_LOG_STRING_FILTER="my_string"` will dump IR only for invocations +// of `LegalizeTF` and `Canonicalizer` where the string `my_string` is contained +// in the serialized operation on which the pass is invoked. For verbose log +// level >= 1, `bridge_logger.cc` prints details about pass invocations for +// which the IR dumping was skipped because of a filter. +class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { + public: + explicit BridgeLoggerConfig( + bool print_module_scope = false, bool print_after_only_on_change = true, + mlir::OpPrintingFlags op_printing_flags = mlir::OpPrintingFlags()); + + // A hook that may be overridden by a derived config that checks if the IR + // of 'operation' should be dumped *before* the pass 'pass' has been + // executed. If the IR should be dumped, 'print_callback' should be invoked + // with the stream to dump into. + void printBeforeIfEnabled(mlir::Pass* pass, mlir::Operation* op, + PrintCallbackFn print_callback) override; + + // A hook that may be overridden by a derived config that checks if the IR + // of 'operation' should be dumped *after* the pass 'pass' has been + // executed. If the IR should be dumped, 'print_callback' should be invoked + // with the stream to dump into. + void printAfterIfEnabled(mlir::Pass* pass, mlir::Operation* op, + PrintCallbackFn print_callback) override; + + // Returns `true` iff we should log IR for given `pass` and `op`. + // Note: Visibility of this function is public for use in unit testing. + bool ShouldPrint(mlir::Pass* pass, mlir::Operation* op); + + private: + // Get `filter` encoded by environment variable `env_var`. 
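+  //
+  // Illustrative setup (a sketch, not part of the upstream header; assumes the
+  // environment variables are exported before the PassManager runs and that IR
+  // printing is enabled on a single-threaded context):
+  //
+  //   setenv("MLIR_BRIDGE_LOG_PASS_FILTER", "LegalizeTF;Canonicalizer", 1);
+  //   mlir::MLIRContext context;
+  //   context.disableMultithreading();
+  //   mlir::PassManager pm(&context);
+  //   pm.enableIRPrinting(std::make_unique<BridgeLoggerConfig>());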
+ static std::vector GetFilter(const std::string& env_var); + // Returns `true` iff any of the strings in `filter` matches `str`, either + // exactly or as a substring, depending on `exact_match`. + static bool MatchesFilter(const std::string& str, + const std::vector& filter, + bool exact_match); + // Determines whether only top-level passes should be dumped. + // Returns true unless the environment variable is set to "0" or "false". + static bool ShouldOnlyDumpTopLevelPasses(); + + // Only log pass invocations whose pass name exactly matches any string in + // `pass_filter_` (or when `pass_filter_` is empty). + const std::vector pass_filter_; + // Only log pass invocations where the serialized operation on which the pass + // is invoked contains any of the specified strings as a substring (or when + // `string_filter_` is empty). + const std::vector string_filter_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h new file mode 100644 index 00000000..ddefbd0a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.h @@ -0,0 +1,119 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CALL_GRAPH_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CALL_GRAPH_UTIL_H_ + +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { + +// Return a list of attribute names that indicates an entry function. +std::vector GetEntryFunctionAttributeNames(); + +// Check if a function is an entry in an MLIR module. +bool IsEntryFunction(func::FuncOp func); + +// Get all the entry functions in an MLIR module. +llvm::SmallVector GetEntryFunctions(ModuleOp module); + +// Get all the functions referenced in a symber user op and save them in +// `callees`. +LogicalResult GetCallees(SymbolUserOpInterface op, SymbolTable &symtab, + llvm::SmallVector &callees); + +// Find the first op with any of the specified types on each path rooted at the +// `root` node in a tree. Additional checks can be applied via `predicate`. The +// results are stored in `ops`. 
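+//
+// Illustrative instantiation (a sketch, not part of the upstream header; the
+// chosen op types, `module`, and `main_func` are assumptions and require the
+// TF dialect headers):
+//
+//   SymbolTable symtab(module);
+//   llvm::SmallVector<SymbolUserOpInterface> calls;
+//   if (failed(GetFirstOpsOfType<TF::StatefulPartitionedCallOp,
+//                                TF::PartitionedCallOp>(
+//           main_func, symtab, /*predicate=*/{}, calls)))
+//     return failure();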
+template +LogicalResult GetFirstOpsOfType( + func::FuncOp root, SymbolTable &symtab, + const std::function &predicate, + llvm::SmallVector &ops) { + std::stack worklist; + worklist.push(root); + while (!worklist.empty()) { + func::FuncOp u = worklist.top(); + worklist.pop(); + auto result = u.walk([&](SymbolUserOpInterface op) { + if (llvm::isa(op) && (!predicate || predicate(op))) { + ops.push_back(op); + return WalkResult::advance(); + } + llvm::SmallVector callees; + if (GetCallees(op, symtab, callees).failed()) { + return WalkResult::interrupt(); + } + for (auto callee : callees) { + worklist.push(callee); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) return failure(); + } + return success(); +} + +// Find the nodes with any of the specified types on the tree rooted at `root` +// node. Additional checks can be applied via `predicate`. The search skips +// the current path if a node with the specified types fails the check, and +// continues on the next path. The passing ops are stored in `hits`, while the +// first failing on on each path is stored in `first_misses`. +template +LogicalResult GetOpsOfTypeUntilMiss( + func::FuncOp root, SymbolTable &symtab, + const std::function &predicate, + llvm::SmallVector &hits, + llvm::SmallVector &first_misses) { + std::stack worklist; + worklist.push(root); + while (!worklist.empty()) { + func::FuncOp u = worklist.top(); + worklist.pop(); + auto result = u.walk([&](SymbolUserOpInterface op) { + if (llvm::isa(op)) { + if (!predicate || predicate(op)) { + hits.push_back(op); + } else { + first_misses.push_back(op); + return WalkResult::advance(); + } + } + llvm::SmallVector callees; + if (GetCallees(op, symtab, callees).failed()) { + return WalkResult::interrupt(); + } + for (auto callee : callees) { + worklist.push(callee); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) return failure(); + } + return success(); +} + +// Check if a function has one region and one block only. +bool HasSingleBlock(func::FuncOp func); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CALL_GRAPH_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/cluster_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/cluster_util.h new file mode 100644 index 00000000..c521298e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/cluster_util.h @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CLUSTER_UTIL_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" + +namespace mlir::TF { + +// Cluster structure captures all the operations that are assigned to same +// device and can form a legal strict cluster. +// Ops must follow same ordering in their parent block. We rely on this +// assumption to perform analysis. +struct Cluster { + llvm::SetVector ops; + std::string target; +}; + +// Builds the op clusters in the `block`. Ops are filtered by the function +// `get_target` that takes an op and returns the target name. `is_ignored_op` is +// a hook to ignore certain ops that are not included in any clusters. +llvm::StringMap> BuildAllClusters( + Block& block, const TF::SideEffectAnalysis::Info& side_effect_analysis, + std::function get_target, + std::function is_ignored_op); + +// Reorder all users of the given op's results to after the op. +// +// Since launch ops are inserted after the last op in the region, the region is +// guaranteed to dominate all live-in values. On the other hand, it is still +// possible that live-out values don't dominate the region. For example: +// +// ``` +// %0 = "tf.OpA"() +// %1 = "tf.OpB"(%0) +// %2 = "tf.OpC"(%0) +// ``` +// +// Assuming `tf.OpA` and `tf.OpC` are clustered together, the region will be +// inserted right after `tf.OpC`. The live-out `%0`, however, is used by +// `tf.OpB`, which won't dominate the region. This function reorders all users +// of the cluster op to be placed after the cluster op itself so that SSA +// dominance is preserved after cluster op creation. +void ReorderOpResultUses(mlir::Operation* cluster); + +} // namespace mlir::TF + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CLUSTER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h new file mode 100644 index 00000000..10271fcb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
+
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace tensorflow {
+
+using tsl::StatusOr;
+
+// Converts a non-func AttrValue proto into an MLIR attribute. Func attributes
+// are excluded in this function because the function might be renamed when the
+// function definition is imported.
+absl::StatusOr<mlir::Attribute> ConvertNonFuncAttributeValue(
+    const AttrValue& value, mlir::Builder* builder);
+
+// Converts all kinds of AttrValue proto into an MLIR attribute.
+absl::StatusOr<mlir::Attribute> ConvertAttributeValue(const AttrValue& value,
+                                                      mlir::Builder* builder);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_ATTR_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h
new file mode 100644
index 00000000..ba5cd3d8
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h
@@ -0,0 +1,70 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TENSOR_H_
+#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TENSOR_H_
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
+
+namespace tensorflow {
+
+using tsl::StatusOr;
+
+// Converts a TensorFlow tensor proto into an MLIR elements attribute.
+absl::StatusOr<mlir::ElementsAttr> ConvertTensorProto(
+    const TensorProto& input_tensor, mlir::Builder* builder,
+    bool convert_to_dense_resource = false);
+
+// Converts a TensorFlow tensor into an MLIR elements attribute.
+absl::StatusOr<mlir::ElementsAttr> ConvertTensor(
+    const Tensor& input_tensor, mlir::Builder* builder,
+    bool convert_to_dense_resource = false);
+
+// Converts a shape from MLIR to a TensorFlow tensor shape proto.
+void ConvertToTensorShapeProto(llvm::ArrayRef<int64_t> shape,
+                               TensorShapeProto* output_shape);
+
+// Converts an MLIR type to a TensorFlow tensor shape.
+PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type);
+
+// Converts an MLIR shaped type to a TensorFlow shape attribute.
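+//
+// Illustrative round trip through the helpers in this file (a sketch, not part
+// of the upstream header; the builder is an assumption):
+//
+//   Tensor tensor(DT_FLOAT, TensorShape({2, 2}));
+//   absl::StatusOr<mlir::ElementsAttr> attr = ConvertTensor(tensor, &builder);
+//   TensorProto proto;
+//   if (attr.ok()) {
+//     absl::Status s = ConvertToTensorProto(*attr, &proto);
+//   }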
+mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type); + +// Converts an MLIR shaped type to a Tensorflow tensor spec proto. +absl::StatusOr ConvertTypeToTensorSpecProto( + const mlir::Type& type); + +// Converts a TensorFlow shape attribute to an MLIR shape attribute. +absl::StatusOr ConvertTensorShapeProto( + const TensorShapeProto& shape, mlir::MLIRContext* context); + +// Converts an MLIR elements attribute to a TensorFlow tensor proto. +absl::Status ConvertToTensorProto(mlir::ElementsAttr attr, + TensorProto* output_tensor); + +// Converts an MLIR elements attribute to a TensorFlow tensor. +absl::Status ConvertToTensor(mlir::ElementsAttr attr, Tensor* output_tensor); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h new file mode 100644 index 00000000..1ce9d054 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/convert_type.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +using tsl::StatusOr; + +// Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type. +absl::Status ConvertDataType(DataType dtype, mlir::Builder builder, + mlir::Type* type); + +// Converts a scalar MLIR type to a TensorFlow Datatype. +absl::Status ConvertScalarTypeToDataType(mlir::Type type, DataType* dtype); + +// Converts an MLIR type to TensorFlow DataType. If 'type' is a scalar type, it +// is converted directly. If it is a shaped type, the element type is converted. +absl::Status ConvertToDataType(mlir::Type type, DataType* dtype); + +// Converts an TensorFlow shape to the one used in MLIR. +void ConvertToMlirShape(const TensorShape& input_shape, + llvm::SmallVectorImpl* shape); + +// Converts an TensorFlow shape proto to the one used in MLIR. +absl::Status ConvertToMlirShape(const TensorShapeProto& input_shape, + llvm::SmallVectorImpl* shape); + +// Given a tensor shape and dtype, get the corresponding MLIR tensor type. 
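+//
+// Illustrative use (a sketch, not part of the upstream header; `context` and
+// `shape_proto` are assumptions):
+//
+//   mlir::Builder builder(&context);
+//   mlir::Type elem_type;
+//   absl::Status s = ConvertDataType(DT_FLOAT, builder, &elem_type);
+//   absl::StatusOr<mlir::Type> tensor_type =
+//       ConvertToMlirTensorType(shape_proto, DT_FLOAT, &builder);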
+absl::StatusOr ConvertToMlirTensorType( + const TensorShapeProto& shape, DataType dtype, mlir::Builder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_CONVERT_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h new file mode 100644 index 00000000..e45479bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DATA_DUMPER_LOGGER_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DATA_DUMPER_LOGGER_CONFIG_H_ + +#include +#include + +#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" + +namespace tensorflow { + +class DataDumperLoggerConfig : public ::tensorflow::BridgeLoggerConfig { + public: + explicit DataDumperLoggerConfig( + std::function + get_filename, + const std::string &pass_prefix = "", bool print_module_scope = false, + bool print_after_only_on_change = true, + mlir::OpPrintingFlags op_printing_flags = mlir::OpPrintingFlags()); + + void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *op, + PrintCallbackFn print_callback) override; + + void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *op, + PrintCallbackFn print_callback) override; + + private: + static void DumpMlir(const std::string &filename, + BridgeLoggerConfig::PrintCallbackFn print_callback); + + // The function to dump the target MLIR string to file. + // The parameter that will be sent to the dump_func_ is: + // The pass name (std::string) + std::function + get_filename_; + + // The pass prefix. + std::string pass_prefix_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DATA_DUMPER_LOGGER_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/device_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/device_util.h new file mode 100644 index 00000000..14e48bf7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/device_util.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// Collects all devices known to the system by name and adds them as a +// `tf.devices` dictionary attribute with a full device name as a key, and +// device metadata as a value. +// +// Device names added in full parsed device form: +// /job:/replica:/task:/device:: +// +// Supported device metadata types: +// (1) GpuDeviceMetadata: GPU device compute capability. +void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set); + +// Collects devices information from an op `tf.devices` attributes. Returns +// failure if can't parse device metadata from the attribute. +mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, + mlir::TF::RuntimeDevices* devices); + +// Parses a device string and returns its ordinal (id). This will return an +// error if the device string is invalid or has no id. +mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc, + llvm::StringRef device, + int64_t* device_ordinal); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h new file mode 100644 index 00000000..ae6e0b61 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ + +#include +#include + +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +struct MlirDumpConfig; + +// Dumps 'graph_def' to a file, as textual IR. Returns the file name chosen. +// +// Note: This is for debugging use and is not optimized for performance. +absl::Status DumpTextualIRToFile(const MlirDumpConfig& config, + const Graph& graph, + const FunctionLibraryDefinition* flib_def, + WritableFile* file); + +// Config of the textual dump. 
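+//
+// Illustrative configuration of the struct below (a sketch, not part of the
+// upstream header; `graph`, `flib_def`, and `file` are assumptions):
+//
+//   MlirDumpConfig config;
+//   config.emit_location_information(/*pretty_form=*/true)
+//       .elide_large_attributes();
+//   absl::Status s = DumpTextualIRToFile(config, graph, &flib_def, file.get());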
+struct MlirDumpConfig { + enum class Dialect { + // Tensorflow Graph Dialect + kTFG, + }; + + // The limit of element size that gets printed. + MlirDumpConfig& elide_large_attributes(int large_element_limit = 16) { + this->op_printing_flags.elideLargeElementsAttrs(large_element_limit); + return *this; + } + + // Enable printing of debug information. If 'pretty_form' is set to true, + // debug information is printed in a more readable 'pretty' form but this + // pretty form is not parsable (so only for human readability). + MlirDumpConfig& emit_location_information(bool pretty_form = false) { + this->op_printing_flags.enableDebugInfo(/*enable=*/true, pretty_form); + return *this; + } + + MlirDumpConfig& emit_dialect(Dialect dialect) { + this->dialect = dialect; + return *this; + } + + // Op printing flags. + mlir::OpPrintingFlags op_printing_flags = std::nullopt; + + // The target MLIR dialect. + Dialect dialect = Dialect::kTFG; +}; + +// Change DumpGraphToFile to dump MLIR textual IR instead of protobuf. +void UseMlirForGraphDump(const MlirDumpConfig& = {}); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h new file mode 100644 index 00000000..87d53e8b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h @@ -0,0 +1,105 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +inline constexpr absl::string_view kCrashReproducerStdErr = "-"; +inline constexpr absl::string_view kCrashReproducerCrashAnalysis = + "crash_analysis"; + +// Creates a file to use for dumping and returns success if a file could be +// created. The opened file is placed in 'os' and the path of the file used is +// placed in 'filepath'. +// +// If the TF_DUMP_GRAPH_PREFIX environment variable is kCrashReproducerStdErr, +// then the LOG(INFO) macro is used instead. +// +// This will create a file name via prefixing `name` with the value of the +// TF_DUMP_GRAPH_PREFIX environment variable if `dirname` is empty and +// suffixing `name` with ".mlir". +absl::Status CreateFileForDumping(llvm::StringRef name, + std::unique_ptr* os, + std::string* filepath, + llvm::StringRef dirname = ""); + +// Dumps MLIR operation to a file and returns the file name used. 
+// +// If the TF_DUMP_GRAPH_PREFIX environment variable is kCrashReproducerStdErr, +// then the MLIR operation will be logged (using the LOG(INFO) macro) instead. +// +// This will create a file name via prefixing `name` with the value of the +// TF_DUMP_GRAPH_PREFIX environment variable if `dirname` is empty and +// suffixing `name` with ".mlir". +// If `pass_manager` is provided, prints a header with the pass pipeline. +std::string DumpMlirOpToFile(llvm::StringRef name, mlir::Operation* op, + llvm::StringRef dirname = "", + const mlir::PassManager* pass_manager = nullptr); + +// Reads the directory to dump the MLIR module from environment variables. +// Default is reading from TF_DUMP_GRAPH_PREFIX, and if the string is 'sponge' +// read from TEST_UNDECLARED_OUTPUTS_DIR. Returns nullptr if the directory +// cannot be determined and generates a warning message. +std::string GetDumpDirFromEnvVar(); + +// Dumps a raw string to a file and returns the file name used. +// +// This will create a file name via prefixing `name` with the value of the +// TF_DUMP_GRAPH_PREFIX environment variable if `dirname` is empty and +// suffixing `name` with ".mlir". +std::string DumpRawStringToFile(llvm::StringRef name, llvm::StringRef content, + llvm::StringRef dirname = ""); + +// Enable the crash reproducer on the provided PassManager to the provided +// directory path. +// If the provided path is empty, it is retrieved from the +// environment variable `MLIR_CRASH_REPRODUCER_DIRECTORY`. +// If the provided path is the string "sponge", the file will be included +// in the sponge "Output Files" by looking up the environment to infer +// the directory path. +// If the provided path is the string kCrashReproducerStdErr, the data is +// dumped into the stderr. +// If the provided path is the string kCrashReproducerCrashAnalysis, the data +// is dumped to the crash analysis system. Note, environment var +// `MLIR_CRASH_REPRODUCER_DIRECTORY` can be used to override +// kCrashReproducerCrashAnalysis settings. +void SetCrashReproducer(mlir::PassManager& pm, llvm::StringRef dir_path = ""); + +// This applies both the PassManagerCLOptions provided by MLIR along with any +// tensorflow specific options. +// +// Note that this function should be in a more appropriate file, but it is +// unclear what a proper file would be as no other functions would currently be +// in the file also. +void applyTensorflowAndCLOptions(mlir::PassManager& pm, + llvm::StringRef dir_path = ""); + +// Prints the pass pipeline of `pass_manager` to `os`. +void PrintPassPipeline(const mlir::PassManager& pass_manager, + mlir::Operation* op, llvm::raw_ostream& os); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DUMP_MLIR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h new file mode 100644 index 00000000..a06d9664 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DYNAMIC_SHAPE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DYNAMIC_SHAPE_UTILS_H_ + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project + +namespace tensorflow { + +llvm::SmallVector ConvertTFShapeToMlir(llvm::ArrayRef shapes); + +llvm::SmallVector ConvertMlirShapeToTF(llvm::ArrayRef shape); + +static constexpr int64_t kTFDynamicSize = -1; +mlir::RankedTensorType GetTypeFromTFTensorShape(llvm::ArrayRef shape, + mlir::Type elementType, + mlir::Attribute encoding = {}); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DYNAMIC_SHAPE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/error_util.h new file mode 100644 index 00000000..bd958c8c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/mlir/utils/error_util.h" +#include "tensorflow/core/platform/status.h" + +// Error utilities for MLIR when interacting with code using Status returns. +namespace mlir { + +// TensorFlow's Status is used for error reporting back to callers. +using ::tensorflow::Status; + +// TF customized diagnostic handler that collects all the diagnostics reported +// and can produce a Status to return to callers. This is for the case where +// MLIR functions are called from a function that will return a Status: MLIR +// code still uses the default error reporting, and the final return function +// can return the Status constructed from the diagnostics collected. +// todo: [b/253331656]. Note ConsumeStatus() and Combine() are wrappers +// of what is inherited from the BaseScopedDiagnosticHandler to +// support cases where tensorflow::Status is still being used (base class uses +// absl::Status) +class StatusScopedDiagnosticHandler : public BaseScopedDiagnosticHandler { + public: + // Constructs a diagnostic handler in a context. If propagate is true, then + // diagnostics reported are also propagated back to the original diagnostic + // handler. 
If filter_stack is true, a reduced stack will be produced. + + explicit StatusScopedDiagnosticHandler(MLIRContext* context, + bool propagate = false, + bool filter_stack = false); + + ~StatusScopedDiagnosticHandler() = default; + // Returns Status corresponding to the diagnostics reported. This consumes + // the diagnostics reported and returns a Status of type Unknown. It is + // required to consume the error status, if there is one, before destroying + // the object. + Status ConsumeStatus(); + + // Returns the combination of the passed in status and consumed diagnostics. + // This consumes the diagnostics reported and either appends the diagnostics + // to the error message of 'status' (if 'status' is already an error state), + // or returns an Unknown status (if diagnostics reported), otherwise OK. + Status Combine(Status status); +}; +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ERROR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h new file mode 100644 index 00000000..e3e14afc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/eval_util.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EVAL_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EVAL_UTIL_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/c/eager/c_api.h" + +namespace tensorflow { + +// Attempts to evaluates an MLIR Operation in TensorFlow eager mode with the +// specified operands. The op is always executed on the local host CPU +// irrespective of the device attribute of the given op. If there is a CPU +// kernel registered for the op and is executed successfully, this fills in the +// results vector. If not, results vector is unspecified. +// +mlir::LogicalResult EvaluateOperation( + mlir::Operation* inst, llvm::ArrayRef operands, + TFE_Context* context, llvm::SmallVectorImpl* results); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EVAL_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h new file mode 100644 index 00000000..28d5df0c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -0,0 +1,96 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EXPORT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EXPORT_UTILS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace mlir { +class ShapedType; +} // namespace mlir + +namespace tensorflow { + +using tsl::StatusOr; + +// Add custom op prefix for TensorFlow dialects. +absl::Status AddTensorFlowOpPrefix(std::string); + +// Maps an MLIR op name in the TensorFlow dialect or the TensorFlow control +// dialect back into a TensorFlow valid op name. +absl::StatusOr GetTensorFlowOpName(llvm::StringRef); + +// Converts an MLIR operation to TensorFlow NodeDef with given node name. This +// name should be unique to the graph it is being inserted into. +absl::StatusOr> GetOperationNodeDef( + mlir::Operation* inst, llvm::StringRef name); + +// Converts MLIR attributes with values to their tensorflow equivalent. +// "name" and "device" attributes are ignored by default. Use attrs_to_ignore to +// specify any other attributes that should be ignored. +absl::Status ConvertAttributes( + llvm::ArrayRef attrs, + const absl::flat_hash_set& attrs_to_ignore, + bool remove_ref_type, AttrValueMap* values); + +// Fill in the contents of TensorShapeProto for the given shape. +// ShapeContainerT is any type with the following methods: +// bool hasRank() +// ArrayRef getShape() +// This includes mlir::TF::ShapeAttr and mlir::ShapedType. +template +void SetTensorShapeProto(ShapeContainerT shape, TensorShapeProto* proto) { + if (shape.hasRank()) { + for (int64_t dim : shape.getShape()) { + proto->add_dim()->set_size(mlir::ShapedType::isDynamic(dim) ? -1 : dim); + } + } else { + proto->set_unknown_rank(true); + } +} + +// Sets shape attribute with the given name. If the attribute already exists +// with a different value, returns an error. +absl::Status SetShapeAttribute(absl::string_view name, mlir::ShapedType shape, + AttrValueMap* values); + +// Returns true if the given instruction is an mlir::TF::LegacyCallOp or the +// result of such an operation transformed by the +// ExecutorToControlDialectConversion pass. +// +// TODO(b/145706023): When the ExecutorToControlDialectConversion pass runs +// before the exporter, it mutates an mlir::TF::LegacyCallOp instruction to +// an instruction with a different operation name. As such, this routine checks +// both forms of a LegacyCall instruction. We only need to check for +// mlir::TF::LegacyCallOp when the ticket is resolved. 
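
Usage sketch (illustrative, not part of the vendored export_utils.h): the SetTensorShapeProto template shown above accepts any shape container with hasRank()/getShape(), e.g. an mlir::ShapedType. The function name, the MLIRContext plumbing, and the use of ShapedType::kDynamic (the dynamic-size sentinel in recent MLIR) are assumptions about the calling context, not part of the header.

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h"
    #include "tensorflow/core/framework/tensor_shape.pb.h"

    // Converts the shape of an MLIR tensor type into a TensorShapeProto.
    tensorflow::TensorShapeProto ShapeToProtoExample(mlir::MLIRContext* context) {
      mlir::Builder builder(context);
      // tensor<?x128xf32>: the leading dimension is dynamic.
      auto type = mlir::RankedTensorType::get(
          {mlir::ShapedType::kDynamic, 128}, builder.getF32Type());
      tensorflow::TensorShapeProto proto;
      // Dynamic dimensions are exported as size -1 by the helper above.
      tensorflow::SetTensorShapeProto(type, &proto);
      return proto;
    }
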
+bool IsLegacyCallInstruction(mlir::Operation* inst); +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_EXPORTER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h new file mode 100644 index 00000000..6ded27b0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/fake_session.h @@ -0,0 +1,85 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_FAKE_SESSION_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_FAKE_SESSION_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace TF { +namespace test_util { +// FakeSession is for testing only. +class FakeSession : public tensorflow::Session { + public: + FakeSession(); + + absl::Status Create(const tensorflow::GraphDef& graph) override; + absl::Status Extend(const tensorflow::GraphDef& graph) override; + + absl::Status Close() override; + + absl::Status ListDevices( + std::vector* response) override; + + absl::Status LocalDeviceManager( + const tensorflow::DeviceMgr** deviceMgrPtr) override; + + absl::Status Run( + const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector<::tensorflow::Tensor>* outputs) override; + + absl::Status Run( + const tensorflow::RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector<::tensorflow::Tensor>* outputs, + tensorflow::RunMetadata* run_metadata) override; + + absl::Status Run( + const tensorflow::RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector<::tensorflow::Tensor>* outputs, + tensorflow::RunMetadata* run_metadata, + const tensorflow::thread::ThreadPoolOptions& thread_pool_options) + override; + + private: + void InitVariables(); + void BuildDeviceManager(); + void Initialize(); + + std::unique_ptr device_mgr_; + bool initialized_ = false; +}; + +} // namespace test_util +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_FAKE_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h new file mode 100644 index 00000000..8b0aaa37 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/import_utils.h @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_IMPORT_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_IMPORT_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Reads text (.pbtext) or binary (.pb) format of a proto message from the given +// buffer. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::Message* proto); +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto); + +// Reads text (.pbtext) or binary (.pb) format of a proto message from the given +// file path. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto); +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_IMPORT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/location_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/location_utils.h new file mode 100644 index 00000000..c65cbb3e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/location_utils.h @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_LOCATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_LOCATION_UTILS_H_ + +#include "mlir/IR/Location.h" // from @llvm-project + +namespace tensorflow { + +mlir::Location GetLocationWithoutOpType(mlir::Location loc); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_IMPORT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h new file mode 100644 index 00000000..a0c14f27 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_MANGLING_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_MANGLING_UTIL_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace mangling_util { +// The type of a mangled string. +enum class MangledKind { kUnknown, kDataType, kTensorShape, kTensor }; + +// Mangles an attribute name, marking the attribute as a TensorFlow attribute. +string MangleAttributeName(absl::string_view str); + +// Returns true if 'str' was mangled with MangleAttributeName. +bool IsMangledAttributeName(absl::string_view str); + +// Demangles an attribute name that was manged with MangleAttributeName. +// REQUIRES: IsMangledAttributeName returns true. +absl::string_view DemangleAttributeName(absl::string_view str); + +// Returns the type of a mangled string, or kUnknown. +MangledKind GetMangledKind(absl::string_view str); + +// Return a TensorShapeProto mangled as a string. +string MangleShape(const TensorShapeProto& shape); +// Demangle a string mangled with MangleShape. +absl::Status DemangleShape(absl::string_view str, TensorShapeProto* proto); + +// Return a TensorProto mangled as a string. +string MangleTensor(const TensorProto& tensor); +// Demangle a string mangled with MangleTensor. +absl::Status DemangleTensor(absl::string_view str, TensorProto* proto); + +// Return a DataType mangled as a string. +string MangleDataType(const DataType& dtype); +// Demangle a string mangled with MangleDataType. 
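
Usage sketch (illustrative, not part of the vendored mangling_util.h): a round trip through MangleShape/GetMangledKind/DemangleShape, which are declared above. The function name and the concrete shape values are illustrative assumptions.

    #include "absl/status/status.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
    #include "tensorflow/core/framework/tensor_shape.pb.h"

    absl::Status MangleRoundTripExample() {
      namespace mu = tensorflow::mangling_util;

      tensorflow::TensorShapeProto shape;
      shape.add_dim()->set_size(2);
      shape.add_dim()->set_size(128);

      // Mangle the shape into a tagged string, classify it, then recover it.
      const std::string mangled = mu::MangleShape(shape);
      if (mu::GetMangledKind(mangled) != mu::MangledKind::kTensorShape) {
        return absl::InternalError("unexpected mangled kind");
      }
      tensorflow::TensorShapeProto round_tripped;
      return mu::DemangleShape(mangled, &round_tripped);
    }
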
+absl::Status DemangleDataType(absl::string_view str, DataType* proto); + +} // namespace mangling_util +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_MANGLING_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h new file mode 100644 index 00000000..0359d38c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_MLPROGRAM_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_MLPROGRAM_UTIL_H_ + +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace tensorflow { + +void RegisterMlProgramPasses(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_MLPROGRAM_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.h new file mode 100644 index 00000000..1b0e0201 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.h @@ -0,0 +1,41 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARALLEL_EXECUTE_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARALLEL_EXECUTE_UTIL_H_ + +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" + +namespace mlir { +namespace TF { + +// TODO(b/243076653): Once the ParallelExecute is added do not remove it. This +// means BuildSingletonParallelExecuteOp will be used in one location, and +// RemoveSingletonParallelExecuteOp can be removed. + +// Wrap `cluster_func` in a `ParallelExecute` with only one child. This +// can be used to canonicalize IR, so there is always one `ParallelExecute`. 
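
Usage sketch (illustrative, not part of the vendored parallel_execute_util.h): the intended pairing of the two helpers declared just below, wrapping a cluster so downstream passes can assume a tf_device.parallel_execute is always present and unwrapping it afterwards. The function name and the assumption that a ClusterFuncOp and positioned OpBuilder already exist come from the caller, not from the header.

    #include "mlir/IR/Builders.h"
    #include "mlir/Support/LogicalResult.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/parallel_execute_util.h"

    mlir::LogicalResult CanonicalizeClusterExample(
        mlir::tf_device::ClusterFuncOp cluster_func, mlir::OpBuilder& builder) {
      // Wrap the cluster so later passes can rely on a ParallelExecute wrapper.
      mlir::tf_device::ParallelExecuteOp parallel_execute =
          mlir::TF::BuildParallelExecuteOp(cluster_func, &builder);

      // ... passes that expect tf_device.parallel_execute would run here ...

      // Unwrap the single-child wrapper once it is no longer needed.
      return mlir::TF::RemoveSingletonParallelExecuteOp(parallel_execute,
                                                        &builder);
    }
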
+tf_device::ParallelExecuteOp BuildParallelExecuteOp( + tf_device::ClusterFuncOp cluster_func, OpBuilder* builder); + +// Unwrap `parallel_execute`'s contents if it only has one child. +LogicalResult RemoveSingletonParallelExecuteOp( + tf_device::ParallelExecuteOp parallel_execute, OpBuilder* builder); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARALLEL_EXECUTE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h new file mode 100644 index 00000000..fdeec88c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/parse_text_proto.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARSE_TEXT_PROTO_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARSE_TEXT_PROTO_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Sets output to the given input with `prefix` stripped, or returns an error if +// the prefix doesn't exist. +absl::Status ConsumePrefix(absl::string_view str, absl::string_view prefix, + absl::string_view* output); + +// Strips `prefix_to_strip` from `text_proto`, parses, and returns the parsed +// proto. +absl::Status ParseTextProto(absl::string_view text_proto, + absl::string_view prefix_to_strip, + protobuf::Message* parsed_proto); +inline absl::Status ParseTextProto(absl::string_view /* text_proto */, + absl::string_view /* prefix_to_strip */, + protobuf::MessageLite* /* parsed_proto */) { + return errors::Unavailable("Cannot parse text protos on mobile."); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_PARSE_TEXT_PROTO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h new file mode 100644 index 00000000..fc204413 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Prints a MLIR module `module_op` and returns it as a string. +std::string SerializeMlirModule(mlir::ModuleOp module_op); + +// Parses a MLIR module from `mlir_module_string` into `mlir_module` with +// context `mlir_context`. +absl::Status DeserializeMlirModule( + llvm::StringRef serialized_mlir_module, mlir::MLIRContext* mlir_context, + mlir::OwningOpRef* mlir_module); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SERIALIZE_MLIR_MODULE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/session_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/session_utils.h new file mode 100644 index 00000000..be2d3786 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/session_utils.h @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SESSION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SESSION_UTILS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/public/session.h" + +namespace mlir { +namespace tf_saved_model { + +// Returns the variable for the provided 'var_handle_op'. +std::string GetVariableName(TF::VarHandleOp var_handle_op); + +// Returns pointer to the variable from 'session' that 'var_handle_op' +// refers to which is in 'device_name' device. If failed to fetch the value null +// will be returned. +// Note, caller is responsible for Unref the variable. +tensorflow::Var* GetVariableFromSession(mlir::TF::VarHandleOp var_handle_op, + llvm::StringRef device_name, + const tensorflow::DeviceMgr* mgr); + +// Returns resource tensors from session for all variables in 'module'. 
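
Usage sketch (illustrative, not part of the vendored session_utils.h): reading one variable's value via GetVariableFromSession, declared above, while honoring the "caller is responsible for Unref" contract. The function name is an assumption, locking considerations are omitted, and the device name and DeviceMgr are assumed to be supplied by the caller.

    #include "llvm/ADT/StringRef.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/session_utils.h"
    #include "tensorflow/core/common_runtime/device_mgr.h"
    #include "tensorflow/core/framework/resource_var.h"
    #include "tensorflow/core/framework/tensor.h"

    tensorflow::Tensor ReadVariableExample(mlir::TF::VarHandleOp var_handle_op,
                                           llvm::StringRef device_name,
                                           const tensorflow::DeviceMgr* mgr) {
      tensorflow::Var* var = mlir::tf_saved_model::GetVariableFromSession(
          var_handle_op, device_name, mgr);
      if (!var) return tensorflow::Tensor();  // Lookup failed; empty tensor.
      tensorflow::Tensor value = *var->tensor();
      var->Unref();  // Required: the helper hands out a referenced Var.
      return value;
    }
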
+absl::StatusOr> GetResourcesFromSession( + llvm::ArrayRef var_handle_ops, + tensorflow::Session* session); + +} // namespace tf_saved_model +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SESSION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h new file mode 100644 index 00000000..28e2c93f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ + +#include + +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/ir/utils/shape_inference_utils.h" + +namespace mlir { + +class Operation; + +namespace TF { + +// Runs TensorFlow shape inference associated to the op type registered in the +// TensorFlow op registry based on the Graph version, operands, and attributes. +// Invoking this shape function will create conversions of parameters to the +// TensorFlow Graph equivalent data structures and back to MLIR equivalent data +// structures. This does not use a natively implemented shape inference in MLIR, +// and instead is temporary until shape functions are reimplemented/migrated to +// being in MLIR instead of the TensorFlow op registry. +LogicalResult InferReturnTypeComponentsForTFOp( + std::optional location, Operation* op, int64_t graph_version, + tfg::OperandAsConstantFn operand_as_constant_fn, + tfg::OpResultAsShapeFn op_result_as_shape_fn, + tfg::ResultElementTypeFn result_element_type_fn, + SmallVectorImpl& inferred_return_shapes); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SHAPE_INFERENCE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h new file mode 100644 index 00000000..0c6e1532 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ + +#include + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { + +std::string GetDeviceAttrAsResourceInstanceStr(Operation* op); + +void MarkResourceAsReadAndWrite( + OpOperand& op_operand, + SmallVectorImpl>& + effect); + +void MarkResourceAsReadOnly( + OpOperand& op_operand, + SmallVectorImpl>& + effect); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_SIDE_EFFECT_ANALYSIS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h new file mode 100644 index 00000000..ea1ae8c8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ + +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir { +namespace TF { + +// Returns whether the custom call op represents a TF function call. +bool IsTfFuncCustomCall(stablehlo::CustomCallOp op); + +// Returns the `called_func` symbol ref attribute in the `tf.backend_config` +// dictionary attribute. +// +// If the op does not represent a TF function call, returns nullptr. +// Otherwise, if the op does not have `caller_name`, returns failure. 
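
Usage sketch (illustrative, not part of the vendored stablehlo_custom_call.h): guarding a rewrite on whether a stablehlo.custom_call wraps a TF function call, using IsTfFuncCustomCall above and GetTfFuncCustomCallFuncName declared just below. The wrapped type of the FailureOr return is not visible in this vendored copy, so the sketch uses auto; the function name is an assumption.

    #include "mlir/Support/LogicalResult.h"
    #include "stablehlo/dialect/StablehloOps.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h"

    mlir::LogicalResult HandleCustomCallExample(mlir::stablehlo::CustomCallOp op) {
      // Only custom calls that represent a TF function call are of interest.
      if (!mlir::TF::IsTfFuncCustomCall(op)) return mlir::success();

      // Per the comment above, the result carries the `called_func` symbol ref.
      auto called_func = mlir::TF::GetTfFuncCustomCallFuncName(op);
      if (mlir::failed(called_func)) return mlir::failure();

      // *called_func would then be resolved against the module's symbol table.
      return mlir::success();
    }
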
+FailureOr GetTfFuncCustomCallFuncName( + stablehlo::CustomCallOp op); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/string_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/string_util.h new file mode 100644 index 00000000..56410385 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/string_util.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STRING_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STRING_UTIL_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +// Utility functions for dumping operations/attributes as strings and ostream +// bindings. + +namespace tensorflow { +std::string OpAsString(mlir::Operation& op); +std::string AttrAsString(mlir::Attribute& attr); + +// b/281863212 enable automatic without Op/AttrAsString. +// We add logging via a wrapper struct in order to respect ODS and avoid +// multiple symbol definitions if MLIR or someone else decides to add ostream +// definitions for the MLIR symbols. +struct LoggableOperation { + mlir::Operation& v; + // NOLINTNEXTLINE(google-explicit-constructor) + LoggableOperation(mlir::Operation& v) : v(v) {} +}; +std::ostream& operator<<(std::ostream& o, const LoggableOperation& op); + +struct LoggableAttribute { + mlir::Attribute& v; + // NOLINTNEXTLINE(google-explicit-constructor) + LoggableAttribute(mlir::Attribute& v) : v(v) {} +}; +std::ostream& operator<<(std::ostream& o, const LoggableAttribute& attr); + +struct LoggableStringRef { + const llvm::StringRef& v; + // NOLINTNEXTLINE(google-explicit-constructor) + LoggableStringRef(const llvm::StringRef& v) : v(v) {} +}; +std::ostream& operator<<(std::ostream& o, const LoggableStringRef& ref); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STRING_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h new file mode 100644 index 00000000..1daab855 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h @@ -0,0 +1,74 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TOPOLOGICAL_SORT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TOPOLOGICAL_SORT_H_ + +#include +#include + +#include "absl/types/span.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// A function that determines which op to emit next in the case of ties. +// The predecessor (which can be null) is the last op we emitted, +// and op is the candidate we're considering. A larger returned integer +// means the op has a higher chance of being emitted first. +typedef int (*PriorityFunction)(Operation *predecessor, Operation *op); + +// A function that returns extra dependencies for each op. These might +// e.g. be known side-effects (or control dependencies) between ops. +// If "incoming" is true, then the list of (extra) predecessors of the +// op should be returned. If "incoming" is false, the list of successors. +// The algorithm assumes that these are consistent which each other. So +// if (and only if) op1 is in extra_dependencies(op2, true), then op2 +// must also be in extra_dependencies(op1, false). +// This function is called multiple times during the topological sort, +// so the implementation should preferably be constant-time. +typedef llvm::function_ref const &( + Operation *, bool incoming)> + ExtraDependenciesFunction; + +// Convenience function if there are no extra dependencies to declare. +// (Unlike nullptr, this also works inside the ternary operator) +extern ExtraDependenciesFunction no_extra_dependencies; + +// Sort a block topologically, so that for all ops, all operands are +// available at the time of execution. This is similar to MLIR's topological +// sort (lib/Transforms/TopologicalSort.cpp) but also takes a priority +// function to determine the next op to emit in the case of ambiguity. This +// makes it possible to group operations by certain attributes. For example, +// the order_by_dialect pass uses this function to group by dialect. +// Only the operations nested directly under the block will be reordered. +// Nested blocks will be left alone. +// Also takes a list of control dependencies (vector of operation pairs, +// from->to) that will be honored when ordering the ops together with the +// data dependencies given through (the ops/results of) the operations +// themselves. 
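
Usage sketch (illustrative, not part of the vendored topological_sort.h): a priority function that groups ops by dialect, in the spirit of the order_by_dialect use the comment above mentions, passed to SortBlockTopologically declared just below. The element type of the returned vector is not visible in this vendored copy, so the sketch uses auto; the function names are assumptions.

    #include "mlir/IR/Block.h"
    #include "mlir/IR/Operation.h"
    #include "tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h"

    // Prefer the candidate op when it shares a dialect with the op emitted
    // last, which keeps ops of the same dialect together.
    static int GroupByDialect(mlir::Operation* predecessor, mlir::Operation* op) {
      return (predecessor && predecessor->getDialect() == op->getDialect()) ? 1
                                                                            : 0;
    }

    void SortExample(mlir::Block& block) {
      // `sorted` lists the block's top-level ops in an order where every op's
      // operands are available before the op itself.
      auto sorted = mlir::TF::SortBlockTopologically(block, GroupByDialect);
      (void)sorted;
    }
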
+std::vector SortBlockTopologically( + Block &block, PriorityFunction priorityFunction, + ExtraDependenciesFunction extraDependencies = no_extra_dependencies); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TOPOLOGICAL_SORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h new file mode 100644 index 00000000..46ead1b8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_CLUSTER_UTIL_H_ + +#include +#include +#include + +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" + +namespace mlir { +namespace TFTPU { + +// For each TPU cluster in `module`, walk over all ops inside the cluster +// and reachable in the call graph from the cluster. +// For each op walked, `callback` is applied to the op, the root cluster, and +// the root cluster's host device. `callback` returning WasInterrupted +// indicates failure. +// The host device is null when the tpu_cluster HasModelParallelism: The +// HasModelParallelism case is currently unsupported in combination with +// outside compilation. +mlir::LogicalResult WalkReachableFromTpuCluster( + ModuleOp module, std::function)> + callback); + +// Like above, except TPU clusters are not required to have a host device, and +// no host device is passed to `callback`. +mlir::LogicalResult WalkReachableFromTpuCluster( + ModuleOp module, + std::function callback); + +} // namespace TFTPU +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_CLUSTER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h new file mode 100644 index 00000000..cdbf7396 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -0,0 +1,311 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +using tsl::StatusOr; + +inline constexpr absl::string_view kNumCoresPerReplicaAttr = + "num_cores_per_replica"; +inline constexpr absl::string_view kTopologyAttr = "topology"; +inline constexpr absl::string_view kDeviceAssignmentAttr = "device_assignment"; + +// A TPU device for execution alongside its associated host CPU device. +struct TPUDeviceAndHost { + TPUDeviceAndHost() = default; + TPUDeviceAndHost(llvm::StringRef device, llvm::StringRef host) + : device(device), host(host) {} + + std::string device; + std::string host; +}; + +// TPU devices to be used for execution (e.g. devices for TPUExecute ops) and +// their associated host CPU devices (for outside compilation). They are ordered +// by `num_replicas` followed by `num_cores_per_replica`. +using TPUDevicesAndHosts = + llvm::SmallVector, 8>; + +// TPU compilation device, execution and associated host devices, and optionally +// execution device IDs. Execution device IDs are populated if `topology` and +// `device_assignment` are provided. +struct TPUDeviceAssignment { + TPUDeviceAssignment(llvm::StringRef compilation_device, + TPUDevicesAndHosts&& tpu_devices) + : compilation_device(compilation_device), + tpu_devices(std::move(tpu_devices)) {} + + TPUDeviceAssignment(llvm::StringRef compilation_device, + TPUDevicesAndHosts&& tpu_devices, + xla::DeviceAssignmentProto&& xla_device_assignment) + : compilation_device(compilation_device), + tpu_devices(std::move(tpu_devices)), + xla_device_assignment(std::move(xla_device_assignment)) {} + + std::string compilation_device; + TPUDevicesAndHosts tpu_devices; + std::optional xla_device_assignment; +}; + +// Extracts device coordinates from a device assignment attribute on an op. +absl::StatusOr> GetDeviceCoordinates( + mlir::ArrayAttr device_assignment_attr); + +// Finds the TPU compilation device and execution devices from `devices` for a +// TPU computation subgraph. Compilation device is determined from looking up +// all TPU_SYSTEM:0 devices and choosing the CPU device associated to the first +// TPU_SYSTEM device sorted lexicographically by replica and task. Execution +// devices are determined by looking up all TPU devices associated with each +// TPU_SYSTEM:0 device found, alongside associated `topology_attr` and +// `device_assignment_attr`. 
If `topology_attr` not an empty string (parsable to +// TopologyProto), `device_assignment_attr` must not be empty also. When +// `topology_attr` and `device_assignment_attr` are not empty, a general device +// assignment based on those two attributes are used. Otherwise when +// `topology_attr` and `device_assignment_attr` are empty, a full mesh device +// assignment is used instead. A failure will be returned if it is not possible +// (e.g. invalid devices or invalid parameters). +// +// +// For example, for `devices`: +// { +// /job:localhost/replica:0/task:0/device:CPU:0, +// /job:worker/replica:0/task:0/device:CPU:0, +// /job:worker/replica:0/task:0/device:TPU_SYSTEM:0, +// /job:worker/replica:0/task:0/device:TPU:0, +// /job:worker/replica:0/task:0/device:TPU:1, +// /job:worker/replica:0/task:0/device:TPU:2, +// /job:worker/replica:0/task:0/device:TPU:3, +// /job:worker/replica:0/task:1/device:CPU:0, +// /job:worker/replica:0/task:1/device:TPU_SYSTEM:0, +// /job:worker/replica:0/task:1/device:TPU:0, +// /job:worker/replica:0/task:1/device:TPU:1, +// /job:worker/replica:0/task:1/device:TPU:2, +// /job:worker/replica:0/task:1/device:TPU:3 +// } +// +// +// With the following parameters (full mesh device assignment): +// `num_replicas` = 8 +// `num_cores_per_replica` = 1 +// `topology_attr` = "" +// `device_assignment_attr` = {} +// +// The `compilation_device` will be: +// /job:worker/replica:0/task:0/device:CPU:0 +// +// `execution_devices` will be: +// { +// { +// /job:worker/replica:0/task:0/device:TPU:0 +// }, +// { +// /job:worker/replica:0/task:0/device:TPU:1 +// }, +// { +// /job:worker/replica:0/task:0/device:TPU:2 +// }, +// { +// /job:worker/replica:0/task:0/device:TPU:3 +// }, +// { +// /job:worker/replica:0/task:1/device:TPU:0 +// }, +// { +// /job:worker/replica:0/task:1/device:TPU:1 +// }, +// { +// /job:worker/replica:0/task:1/device:TPU:2 +// }, +// { +// /job:worker/replica:0/task:1/device:TPU:3 +// } +// } +// +// and `xla_device_assignment` will not be set. 
+// +// +// With the following parameters (general device assignment): +// `num_replicas` = 4 +// `num_cores_per_replica` = 2 +// `topology_attr` (in proto debug string format) = +// { +// mesh_shape: 2 +// mesh_shape: 2 +// mesh_shape: 2 +// num_tasks: 2 +// num_tpu_devices_per_task: 4 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 1 +// device_coordinates: 1 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 1 +// device_coordinates: 1 +// device_coordinates: 0 +// device_coordinates: 0 +// device_coordinates: 1 +// } +// `device_assignment` = +// {0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1} +// +// The `compilation_device` will be: +// /job:worker/replica:0/task:0/device:CPU:0 +// +// `execution_devices` will be: +// { +// { +// "/job:worker/replica:0/task:0/device:TPU:0", +// "/job:worker/replica:0/task:1/device:TPU:3" +// }, +// { +// "/job:worker/replica:0/task:0/device:TPU:1", +// "/job:worker/replica:0/task:1/device:TPU:2" +// }, +// { +// "/job:worker/replica:0/task:0/device:TPU:3", +// "/job:worker/replica:0/task:1/device:TPU:0" +// }, +// { +// "/job:worker/replica:0/task:0/device:TPU:2", +// "/job:worker/replica:0/task:1/device:TPU:1" +// } +// } +// +// and `xla_device_assignment` will be: +// { +// replica_count: 4 +// computation_count: 2 +// computation_devices { +// replica_device_ids: 0 +// replica_device_ids: 4 +// replica_device_ids: 2 +// replica_device_ids: 6 +// } +// computation_devices { +// replica_device_ids: 1 +// replica_device_ids: 5 +// replica_device_ids: 3 +// replica_device_ids: 7 +// } +// } +absl::StatusOr GetTPUCompilationAndExecutionDevices( + llvm::ArrayRef devices, int num_replicas, + int num_cores_per_replica, llvm::StringRef topology_attr, + llvm::ArrayRef device_assignment_attr); + +// Converts a device assignment attribute to an XLA device assignment proto. +absl::StatusOr GetXlaDeviceAssignmentProto( + llvm::StringRef topology_attr, int num_replicas, int num_cores_per_replica, + llvm::ArrayRef device_assignment_attr); + +// Virtual device name of the passed logical core. The logical core is the index +// of a core within a replica. +std::string GetDeviceAliasForLogicalCore(int core_index); + +// Virtual device name of the host that is associated with the passed logical +// core. The logical core is the index of a core within a replica. +std::string GetDeviceAliasForHostOfLogicalCore(int core_index); + +// Returns true if cluster contains model parallelism based on +// `num_cores_per_replica_attribute`. Otherwise returns false. +bool HasModelParallelism(mlir::tf_device::ClusterOp cluster); + +// Returns true if the devices list contain any TPU devices +bool HasTPUDevice(const mlir::TF::RuntimeDevices& devices); + +// Returns the host device used for outside compilation in generic pipeline. +mlir::LogicalResult GetHostDeviceOutsideCompilationInGenericPipeline( + mlir::TF::RuntimeDevices devices, std::string* host_device); + +// Parses XLA compilation and execution devices from a tf_device.cluster and +// returns the host device for the head and tail computations. 
For TPU device, +// if the computation is replicated, GetDeviceAliasForHostOfLogicalCore(0) is +// returned instead. +mlir::LogicalResult GetHostDeviceOutsideComputation( + mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster, + std::string* host_device); + +// Checks if a device string is a TPU device. +bool IsTPUDevice(llvm::StringRef device); + +// Checks if a device string is a TPU replicated core device. +bool IsTPUReplicatedCore(llvm::StringRef device); + +// Checks if `type` is allowed for XLA. String and resources are not XLA types. +// There are other TF types that are not XLA types which will be removed by +// successive passes in TF/XLA bridge phase 2. +bool TypeValidForXLA(const mlir::Type& type); + +// Returns the map from core to the host that is associated with the +// core. If `cluster` is not replicated then the core is a physical core index +// and the host is a physical host name. If `cluster` is replicated then the +// core with index `i` is a logical core (`TPU_REPLICATED_CORE_i`), and the host +// is the associated virtual device name (`TPU_REPLICATED_HOST_i`). +mlir::LogicalResult GetDeviceToHostMap( + mlir::tf_device::ClusterOp cluster, + llvm::SmallVector& core_to_host); + +// Returns the first TPU device, for use in the non-replicated case. The list of +// TPU devices is retrived from `op`'s module ancestor. +mlir::LogicalResult GetNonReplicatedTPU0(mlir::Operation* op, + std::string* tpu0_device); + +// Returns the CPU of the first TPU device, for use in the non-replicated case. +// The list of devices is retrived from `op`'s module ancestor. +mlir::LogicalResult GetNonReplicatedCPU0(mlir::Operation* op, + std::string* cpu0_device); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h new file mode 100644 index 00000000..60beacc8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TRANSLATE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TRANSLATE_UTILS_H_ + +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Populates the tf.versions attribute on a module, given a corresponding +// graph VersionDef proto. 
+void PopulateTfVersions(mlir::ModuleOp module, const VersionDef& versions); + +// Extracts TensorFlow GraphDef version information from the given module. +// Returns failure if version attribute is missing or any of the sub attributes +// are invalid. +mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module, + VersionDef* versions); + +// Returns TensorFlow GraphDef producer version for the given module. Returns an +// error if the version information is missing for the module or is not valid. +absl::StatusOr GetTfGraphProducerVersion(mlir::ModuleOp module); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TRANSLATE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h new file mode 100644 index 00000000..3ec239c4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFICATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFICATION_UTILS_H_ + +#include + +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Returns success when the given shape argument of the Reshape op is valid. +LogicalResult VerifyShapeOfReshapeOp(ArrayRef shape); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFICATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h new file mode 100644 index 00000000..31a6e25a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace tensorflow { + +// Returns whether all functions in module are of single tf_executor.graph and +// each tf_executor.island in tf_executor.graph only has a single op. +mlir::LogicalResult VerifyExportSuitable(mlir::ModuleOp module); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFY_SUITABLE_FOR_GRAPH_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/visitor.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/visitor.h new file mode 100644 index 00000000..9fd25569 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/visitor.h @@ -0,0 +1,52 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Walks the function by following function call chains and calling the callback +// for each reachable function (including `func`). Each function is visited only +// once even if it's called from multiple places and/or recursively. +// +// The current implementation follows direct calls to `mlir::func::FuncOp` only +// and returns a `mlir::WalkResult::interrupt()` when it encounters a call whose +// callee cannot be resolved to `mlir::func::FuncOp`. +mlir::WalkResult WalkReachableFunctions( + mlir::func::FuncOp func, + llvm::function_ref callback, + mlir::SymbolTableCollection* symbol_table = nullptr); + +// Creates a new MLIR module that contains only the given functions and all +// reachable functions from them. 
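The traversal contract for WalkReachableFunctions (each reachable function visited exactly once, recursion and repeated call sites tolerated) can be sketched without any MLIR machinery; the map-based call graph below stands in for func ops and symbol lookup:

#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Standalone sketch of the documented traversal: starting from one function,
// follow direct calls and visit every reachable function exactly once, even
// when a function is called from several places or calls itself.
void WalkReachable(
    const std::string& root,
    const std::map<std::string, std::vector<std::string>>& calls) {
  std::set<std::string> visited;
  std::queue<std::string> pending;
  visited.insert(root);
  pending.push(root);
  while (!pending.empty()) {
    std::string fn = pending.front();
    pending.pop();
    std::cout << "visiting " << fn << "\n";
    auto it = calls.find(fn);
    if (it == calls.end()) continue;
    for (const std::string& callee : it->second) {
      if (visited.insert(callee).second) pending.push(callee);
    }
  }
}

int main() {
  // `main` calls `f` twice and `f` calls itself; each function is still
  // visited exactly once.
  const std::map<std::string, std::vector<std::string>> calls = {
      {"main", {"f", "f", "g"}}, {"f", {"f"}}, {"g", {}}};
  WalkReachable("main", calls);
  return 0;
}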
+mlir::FailureOr> CreatePrunedModule( + mlir::ModuleOp module, llvm::ArrayRef function_names); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h new file mode 100644 index 00000000..264e1b4c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h @@ -0,0 +1,48 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_CALL_MODULE_ATTRS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_CALL_MODULE_ATTRS_H_ + +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace TF { + +// The main function's name in the serialized stablehlo module embedded in +// XlaCallModule's `module` attribute. +constexpr llvm::StringRef kStablehloMainFunctionName = "main"; + +// After deserializing the stablehlo functions from XlaCallModule, +// this XlaCallModule attribute refers to the deserialized stablehlo main +// function. +constexpr llvm::StringRef kStablehloEntryFunctionAttrName = "_entry_function"; + +// The StableHLO version of the serialized stablehlo module embedded in +// XlaCallModule's `module` attribute, set on deserialization. +constexpr llvm::StringRef kStablehloVersionAttrName = "_stablehlo_version"; + +// Every stablehlo function deserialized from XlaCallModule has this attribute. +constexpr llvm::StringRef kFromXlaCallModuleAttrName = "_from_xla_call_module"; + +// Name of `tf.XlaCallModule`'s dictionary attribute for keeping the +// deserialized stablehlo module's attributes. +constexpr llvm::StringRef kStablehloModuleAttrsAttrName = + "_stablehlo_module_attrs"; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_CALL_MODULE_ATTRS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h new file mode 100644 index 00000000..8ce5403e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h @@ -0,0 +1,66 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_REWRITE_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_REWRITE_UTIL_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +// Erase rewritten ClusterFuncOp(s). If TPUPartitionedInputV2Op / +// TPUPartitionedOutputV2Op are present, they must be removed along with the +// ClusterFuncOp(s). +mlir::LogicalResult EraseClusterFuncs( + llvm::MutableArrayRef to_be_erased); + +// Move child processes of the ParallelExecute that do not change. These are all +// children except for the child with the ClusterFunc. +// Returns the index of the child with the ClusterFunc. +int MovePreservedParallelExecuteChildren( + int num_cores_per_replica, + llvm::SmallVector& concatenated_output_types, + mlir::OpBuilder* builder, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::tf_device::ParallelExecuteOp old_parallel_execute, + mlir::tf_device::ParallelExecuteOp* new_parallel_execute); + +// Wraps single op in `tf_device.launch` for explicit device assignment. +mlir::tf_device::LaunchOp WrapOpInLaunch(mlir::OpBuilder* builder, + mlir::Location loc, + mlir::Operation* op, + llvm::StringRef device); + +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_REWRITE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h new file mode 100644 index 00000000..8b87b1c2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -0,0 +1,172 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ + +#include + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { + +inline constexpr llvm::StringRef kInputShardingAttr = + "input_sharding_configuration"; +inline constexpr llvm::StringRef kOutputShardingAttr = + "output_sharding_configuration"; + +inline constexpr llvm::StringRef kICIWeightDistributionMlirBridgeMarker = + "_ici_weight_distribution_mlir_bridge_marker"; + +// Parses the sharding string. This sharding string can be binary (serialized) +// or human readable. +mlir::LogicalResult DecodeShardingAttribute(const std::string& shard_str, + xla::OpSharding& sharding, + bool report_error = true); + +// Encodes the sharding in human readable form. +mlir::LogicalResult DecodeShardingAttribute(mlir::Attribute shard_attr, + xla::OpSharding& sharding, + bool report_error = true); + +// Parses the sharding attr. This sharding attr can be binary (serialized) +// or human readable. +void EncodeSharding(mlir::Operation* op, llvm::StringRef shard_str); + +// Parses "input_sharding_configuration" attribute and returns a list where i-th +// element is a list of mlir::Value's which represent inputs for the TPU +// computation corresponding to i-th logical device. If the attribute does not +// exist, the all inputs are placed on logical core 0. +mlir::LogicalResult ExtractInputsForLogicalDevices( + int num_cores_per_replica, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::OpBuilder* builder, + llvm::SmallVectorImpl>* input_list); + +// Same as above, except creates tf.XlaSplitND Op for split sharding if +// use_xla_nd_ops is true, otherwise creates tf.Split op. +mlir::LogicalResult ExtractInputsForLogicalDevices( + int num_cores_per_replica, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::OpBuilder* builder, bool use_xla_nd_ops, + llvm::SmallVectorImpl>* input_list); + +// Extracts a list of OpSharding that represent output sharding configuration of +// `tf_device.cluster`. +mlir::LogicalResult ParseAndValidateOutputSharding( + int num_cores_per_replica, mlir::tf_device::ClusterFuncOp cluster_func, + mlir::SmallVector* output_sharding_list); + +// Retrieves output types for TPUExecute op representing execution for provided +// logical device id. TPUExecute op for different logical device may have +// different outputs depending on the output sharding configuration. 
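A plausible reading of the two accepted sharding encodings is to try the serialized proto first and fall back to the human-readable text form; the strategy below is an assumption for illustration, not the upstream parser, and it assumes the full (non-lite) protobuf runtime is linked in:

#include <string>

#include "google/protobuf/text_format.h"
#include "xla/xla_data.pb.h"

// Sketch of the dual encodings described for DecodeShardingAttribute: accept
// either a binary-serialized xla.OpSharding proto or its text representation.
bool ParseOpSharding(const std::string& shard_str, xla::OpSharding& sharding) {
  if (sharding.ParseFromString(shard_str)) return true;  // binary form
  return google::protobuf::TextFormat::ParseFromString(shard_str,
                                                       &sharding);  // text form
}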
+mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( + int core_id, llvm::ArrayRef output_sharding_config, + mlir::tf_device::ClusterFuncOp cluster_func, + llvm::SmallVectorImpl* output_types, + llvm::SmallVectorImpl* cluster_to_core_index); + +// Same as above, except creates tf.XlaSplitND Op for split sharding if +// use_xla_nd_ops is true, otherwise creates tf.Split op. +mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( + int core_id, llvm::ArrayRef output_sharding_config, + mlir::tf_device::ClusterFuncOp cluster_func, + llvm::SmallVectorImpl* output_types, bool use_xla_nd_ops, + llvm::SmallVectorImpl* cluster_to_core_index); + +// Remaps outputs of `new_parallel_execute` op that represent concurrent +// execution of the `tf_device.cluster_func` at index `cluster_idx` of +// `old_parallel_execute` with its users. +// `num_results_pre_cluster` represent the # of outputs of +// `new_parallel_execute` which are from ops before `tf_device.cluster_func` op. +mlir::LogicalResult RemapOutputsFromLogicalDevices( + const mlir::Location& location, + llvm::ArrayRef output_sharding_config, + llvm::SmallVector, 4> cluster_to_core_index, + int num_results_pre_cluster, + mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx, + mlir::tf_device::ParallelExecuteOp new_parallel_execute, + mlir::OpBuilder* builder); + +// Same as above, except creates tf.XlaConcatNd Op for split sharding if +// use_xla_nd_ops is true, otherwise creates tf.Concat op. +mlir::LogicalResult RemapOutputsFromLogicalDevices( + const mlir::Location& location, + llvm::ArrayRef output_sharding_config, + llvm::SmallVector, 4> cluster_to_core_index, + int num_results_pre_cluster, + mlir::tf_device::ParallelExecuteOp old_parallel_execute, int cluster_idx, + mlir::tf_device::ParallelExecuteOp new_parallel_execute, + bool use_xla_nd_ops, mlir::OpBuilder* builder); + +// Determines each logical core argument to metadata argument index mapping, +// based on sharding. The return value is indexed first by logical core then by +// argument index. +llvm::SmallVector, 4> GetMetadataArgumentMapping( + const tpu::TPUCompileMetadataProto& metadata); + +// Gets the proper tensor dimension from XLA OpSharding. +// "replicate_on_last_tile_dim" and "last_tile_dims" should be deducted from the +// real Tensor dimensions when tiled. +// For example: +// f32[8,512](sharding={devices=[1,1,2]0,1 last_tile_dims={REPLICATED}) +// also means a replicated tensor over all devices. +// +// See xla_data.proto for detailed explanations on the fields. +int GetDimsFromXLAShardingTiled(const xla::OpSharding& xla_sharding); + +// A sharding with OTHER type may be REPLICATED if: +// 'replicate_on_last_tile_dim' is true OR +// 'last_tile_dims' is not empty +// AND +// other than replicated last tile dims, all other dims are not sharded. +bool IsOtherReplicatedSharding(const xla::OpSharding& xla_sharding); + +// Returns whether the sharding is split sharding. i.e. A sharding with OTHER +// type but not replicated. +bool IsSplitSharding(const xla::OpSharding& sharding); + +// Returns whether the sharding is replicated. It includes sharding with +// REPLICATED type and replicated OTHER type. +bool IsReplicatedSharding(const xla::OpSharding& sharding); + +// Returns whether the shape of inputs and outputs is statically known when +// split sharding is done on inputs or outputs. 
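The dimension accounting for tiled shardings can be made concrete with the fields of xla_data.proto's OpSharding message; the sketch below reproduces the f32[8,512] example above using assumed field accessors, and is not the upstream helper:

#include "xla/xla_data.pb.h"

// For a tiled (OTHER) sharding, the trailing replication dimensions do not
// correspond to real tensor dimensions, so they are subtracted from the
// tile-assignment rank, as described above.
int TiledDataDims(const xla::OpSharding& s) {
  int dims = s.tile_assignment_dimensions_size();
  if (s.replicate_on_last_tile_dim()) dims -= 1;
  dims -= s.last_tile_dims_size();
  return dims;
}

// An OTHER-type sharding is effectively replicated when, apart from the
// replicated trailing tile dimensions, every remaining dimension has a tile
// count of 1 (as in the devices=[1,1,2] last_tile_dims={REPLICATED} example).
bool LooksReplicated(const xla::OpSharding& s) {
  if (s.type() == xla::OpSharding::REPLICATED) return true;
  if (s.type() != xla::OpSharding::OTHER) return false;
  if (!s.replicate_on_last_tile_dim() && s.last_tile_dims_size() == 0)
    return false;
  for (int i = 0; i < TiledDataDims(s); ++i) {
    if (s.tile_assignment_dimensions(i) != 1) return false;
  }
  return true;
}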
+bool AreInputOutputShapesStaticallyKnownForSplitSharding( + llvm::ArrayRef output_sharding_config, + mlir::tf_device::ClusterFuncOp cluster_func); + +// Returns a map of dimension indices and number of splits for tiled sharding. +absl::StatusOr> GetDimensionIndicesAndNumSplitsFromSharding( + const xla::OpSharding& sharding); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_XLA_SHARDING_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h new file mode 100644 index 00000000..efc8be06 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo_lib.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace mlir::tensorflow_to_stablehlo::pywrap { + +// Converts a TensorFlow SavedModel to a StableHLO MLIR module and serializes it +// to bytecode. +// +// Args: +// input_path: The path to the SavedModel directory. +// exported_model_signatures: Comma-separated list of exported model +// signatures to convert. tag_names: Comma-separated list of tags for loading +// SavedModel. +// input_arg_shapes_str: A string representation of input argument +// shapes for 'main' entry-point, separating tensors with ':', dimension +// with ',', and using '?' for unknown sizes. For example, +// 'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and [1,?]. +// +// Returns: +// An absl::StatusOr containing the serialized bytecode of the StableHLO +// module on success, or an error status on failure. +absl::StatusOr PywrapSavedModelToStablehlo( + absl::string_view input_path, + const std::vector& exported_model_signatures, + const std::vector& tag_names, + absl::string_view input_arg_shapes_str); + +// Converts a TensorFlow MLIR module string to a StableHLO MLIR module and +// serializes it to bytecode. +// +// Args: +// module_op_str: TensorFlow MLIR module string. +// input_arg_shapes_str: A string representation of input argument +// shapes for 'main' entry-point, separating tensors with ':', dimension +// with ',', and using '?' for unknown sizes. For example, +// 'input-arg-shapes=1,2::1,?' expresses argument shapes [1,2], [] and [1,?]. +// +// Returns: +// An absl::StatusOr containing the serialized bytecode of the StableHLO +// module on success, or an error status on failure. 
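The `input_arg_shapes_str` grammar is small enough to show end to end; the parser below follows only the format documented here (':' between arguments, ',' between dimensions, '?' for unknown sizes) and is not the one used by these entry points:

#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Standalone sketch of the documented format. Unknown sizes are represented
// here as -1, so "1,2::1,?" -> {{1, 2}, {}, {1, -1}}.
std::vector<std::vector<int64_t>> ParseArgShapes(const std::string& spec) {
  std::vector<std::vector<int64_t>> shapes;
  std::string tensor;
  std::stringstream tensors(spec);
  // std::getline drops a trailing empty segment, which is fine for this sketch.
  while (std::getline(tensors, tensor, ':')) {
    std::vector<int64_t> dims;
    std::string dim;
    std::stringstream dims_stream(tensor);
    while (std::getline(dims_stream, dim, ',')) {
      dims.push_back(dim == "?" ? -1 : std::stoll(dim));
    }
    shapes.push_back(dims);  // An empty segment yields a scalar shape [].
  }
  return shapes;
}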
+absl::StatusOr PywrapTfModuleToStablehlo( + absl::string_view module_op_str, absl::string_view input_arg_shapes_str); + +} // namespace mlir::tensorflow_to_stablehlo::pywrap + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TO_STABLEHLO_PYTHON_PYWRAP_TENSORFLOW_TO_STABLEHLO_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow_to_stablehlo/tf_to_stablehlo.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow_to_stablehlo/tf_to_stablehlo.h new file mode 100644 index 00000000..bb0e2a07 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tensorflow_to_stablehlo/tf_to_stablehlo.h @@ -0,0 +1,56 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TO_STABLEHLO_TF_TO_STABLEHLO_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TO_STABLEHLO_TF_TO_STABLEHLO_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { + +// Converts a TensorFlow model (either from a SavedModel or an MLIR module) to a +// StableHLO MLIR module. +// +// Args: +// input_path: The path to the input TensorFlow SavedModel or MLIR module. +// context: The MLIR context to use for parsing or creating the MLIR module. +// exported_model_signatures: List of exported model signatures (strings) to +// convert. +// tag_names: List of tag names (strings) used for loading SavedModel. +// Ignored for MLIR input. +// input_arg_shapes_str: A string representation of input argument shapes for +// 'main' entry-point, separating tensors with ':', dimension with ',', and +// using '?' for unknown sizes. For example, 'input-arg-shapes=1,2::1,?' +// expresses argument shapes [1,2], [] and [1,?]. +// is_input_mlir_module: If true, `input_path` is treated as an MLIR +// module instead of a SavedModel. +// +// Returns: +// An absl::StatusOr containing the converted StableHLO MLIR module on +// success, or an absl::Status with an error message on failure. +absl::StatusOr> TfToStablehlo( + absl::string_view input_path, MLIRContext* context, + const std::vector& exported_model_signatures, + const std::vector& tag_names, + absl::string_view input_arg_shapes_str, bool is_input_mlir_module); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TO_STABLEHLO_TF_TO_STABLEHLO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h new file mode 100644 index 00000000..b290554d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_CLUSTER_TF_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_CLUSTER_TF_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tf2xla { +namespace v1 { + +// Run all the passes involved in transforming the graph before execution so +// that it is suitable for targeting devices when called via the TF1 Session +// API. +// These transformations take as input a Tensorflow Graph as an MLIR Module +// and transforms the module in place to cluster the given ops for compilation +// that is compatible with the given device_type. The MLIR should be in the TF +// Executor Dialect for graph nodes and edges or TF Functional. It will convert +// to TF Functional internally. Individual Op inside a node should be the +// Tensorflow Dialect. The output MLIR is in the TF Functional Dialect. The +// input MLIR should not have infeed and outfeed ops, which are unsupported via +// this API. Returns OkStatus if passed, otherwise an error. +absl::Status RunSessionTf2xlaClusteringBridge(mlir::ModuleOp module, + bool is_in_fallback_enabled_mode); + +} // namespace v1 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_CLUSTER_TF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h new file mode 100644 index 00000000..53431dfe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -0,0 +1,238 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_COMPILE_MLIR_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_COMPILE_MLIR_UTIL_H_ + +#include + +#include "absl/base/attributes.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/hlo/builder/xla_computation.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +// Lowers MLIR module to XLA HLO inside an XlaComputation. The input module +// should only contain operations in tf dialect. If the input module contains +// operation in the tf_executor dialect, for example, returns an error. +// Exception to this are tf_executor dialect ops that are optimized away through +// canonicalization. +// +// Operations in tf dialect are lowered to XLA HLO through the following steps: +// . Legalizes control flow operations. +// . Decomposes compound resource operations so that the only remaining +// operations on resource variables are resource reads/writes.. +// . Replaces resource reads/writes with function inputs/outputs and +// eliminates the use of resource variables. +// . Legalizes the operations to XLA HLO operations. +// . Canonicalizes the XLA HLO operations. +// +// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT", +// "XLA_GPU_JIT" or "XLA_TPU_JIT". +// use_tuple_args: when this is true, always create a tuple argument for the +// entry computation. +// enable_op_fallback: when this is true, prefer tf2xla fallback kernels over +// MLIR +// native kernels for legalization to HLO. +// return_tuple: when this is true, always create a tuple result for the +// entry computation. +// shape_determination_fns: Contains layout preference fn and shape +// representation fn. The two functions are used to determine argument and +// result shapes. +// custom_legalization_passes: passes to run before the default TF legalization +// passes for backend-specific ops. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") +absl::Status ConvertMLIRToXlaComputation( + mlir::ModuleOp module_op, llvm::StringRef device_type, + xla::XlaComputation* xla_computation, bool use_tuple_args, + bool enable_op_fallback, bool return_tuple, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns = {}, + llvm::MutableArrayRef> + custom_legalization_passes = {}, + llvm::StringRef module_name = llvm::StringRef()); + +// Creates a MLIR pipeline that lowers MLIR module to MHLO dialect. The input +// module should only contain operations in tf dialect. For example, if the +// input module contains operation in the tf_executor dialect, the pass raises +// an error unless the tf_executor dialect ops are optimized away by +// canonicalization. +// +// The pipeline is used in ConvertMLIRToXlaComputation. And it generally has the +// following pass structure: +// - TensorFlow passes +// - Legalization passes +// - MHLO passes +// +// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT", +// "XLA_GPU_JIT" or "XLA_TPU_JIT". 
+// enable_op_fallback: when this is true, prefer tf2xla fallback kernels over +// MLIR +// native kernels for legalization to HLO. +// custom_legalization_passes: passes to run before the default TF legalization +// passes for backend-specific ops. +// lower_to_xla_hlo: Temporary parameter to be removed in imminent update. If +// true, includes legalization and MHLO lowering passes. +// allow_partial_conversion: when this is true, allow operations that can't be +// legalized. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") +void CreateConvertMlirToXlaHloPipeline( + mlir::OpPassManager& pm, llvm::StringRef device_type, + bool enable_op_fallback, + llvm::MutableArrayRef> + custom_legalization_passes, + bool lower_to_xla_hlo = true, bool allow_partial_conversion = false); + +// Helper struct representing argument tensor or resource handle shapes. +struct TensorOrResourceShape { + TensorShape shape; + bool is_resource = false; +}; + +// Refine MLIR types based on new shape information. +ABSL_DEPRECATED("Not meant to be used directly and should be a util.") +absl::Status RefineShapes(llvm::ArrayRef arg_shapes, + mlir::ModuleOp module); + +// Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level +// inputs to module_op that have already been added to the XlaBuilder. returns +// are the returned XlaOps. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") +absl::Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder, + llvm::ArrayRef xla_params, + std::vector& returns, + llvm::ArrayRef arg_shapes, + llvm::StringRef device_type, + llvm::MutableArrayRef> + custom_legalization_passes); + +// Apply shape, description, and resource information to inputs and outputs +// in the XlaCompilationResult. This should be called after +// compilation_result->computation was set. +ABSL_DEPRECATED("Not meant to be used directly and should be a util.") +absl::Status PopulateResultIOInfo( + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, + bool use_tuple_args, bool use_resource_updates_for_aliases, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaCompilationResult* compilation_result); + +// Runs MLIR Bridge on an MLIR module. +// +// If lower_to_xla_hlo is true then compiles down into XLA HLO, generates all +// accompanying metadata and stores them in CompilationResult. +// +// If enable_op_fallback is set to false, graph is legalized only if the graph +// analysis for the graph is successful. Otherwise, an error is returned. +// +// Running the MLIR Bridge performs many transformations on the input module +// which is modified in place. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") +absl::Status CompileMlirToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, + llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, + bool use_return_tuple, bool use_resource_updates_for_aliases, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes, + llvm::StringRef module_name = llvm::StringRef(), + bool lower_to_xla_hlo = true); + +// Runs MLIR Bridge on a MLIR module. +// +// If lower_to_xla_hlo is true then compiles down into XLA HLO, generates all +// accompanying metadata and stores them in CompilationResult. +// +// If enable_op_fallback is set to false, graph is legalized only if the graph +// analysis for the graph is successful. 
Otherwise, an error is returned. +// +// On success, returns the serialized MLIR module. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") +absl::StatusOr CompileMlirToXlaHloAndSerialize( + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, + llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, + bool use_return_tuple, bool use_resource_updates_for_aliases, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes, + llvm::StringRef module_name = llvm::StringRef(), + bool lower_to_xla_hlo = true); + +// Runs MLIR Bridge on a serialized MLIR module. +// +// If lower_to_xla_hlo is true then compiles down into XLA HLO, generates all +// accompanying metadata and stores them in CompilationResult. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") +absl::StatusOr CompileSerializedMlirToXlaHlo( + llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, + llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes = {}, + llvm::StringRef module_name = llvm::StringRef(), + bool lower_to_xla_hlo = true); + +// Compiles a TensorFlow Graph (already converted to MLIR, imported with +// tf_executor dialect still present) into XLA HLO, generates all accompanying +// metadata and stores them in CompilationResult. This will rewrite arguments +// and run the TensorFlow standard pipeline prior to invoking +// `CompileMlirToXlaHlo`. +ABSL_DEPRECATED("Use v2/legalize_tf.h::LegalizeMlirToHlo instead.") +absl::Status CompileGraphToXlaHlo( + mlir::ModuleOp module_op, llvm::ArrayRef args, + llvm::StringRef device_type, bool use_tuple_args, bool enable_op_fallback, + bool use_return_tuple, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + XlaCompilationResult* compilation_result, + llvm::MutableArrayRef> + custom_legalization_passes); + +// Compiles a Graph from TF to HLO and adds the resulting HLO to the +// XlaBuilder. This function adds HLO to a larger HLO computation, so +// HLO-level inputs are supplied, and HLO-level outputs are produced. +// xla_params is the HLO-level inputs and returns is the HLO-level outputs. +// If unconditionally_use_output_shapes is true then the unregistered +// attribute _output_shapes is always used to set the output shapes of the ops. +ABSL_DEPRECATED( + "Use v1/compile_tf_graph.h::CompileTensorflowGraphToHlo instead.") +absl::Status BuildHloFromGraph( + const Graph& graph, xla::XlaBuilder& builder, + mlir::MLIRContext& mlir_context, llvm::ArrayRef xla_params, + std::vector& returns, bool unconditionally_use_output_shapes, + llvm::ArrayRef args, llvm::ArrayRef control_rets, + llvm::StringRef device_type, const FunctionLibraryDefinition& flib_def); + +static inline absl::Status CompileToHloGraphAnalysisFailedError() { + return errors::Internal("disabled after graph analysis"); +} + +// Register a convenient pipeline for invoking TF/XLA lowering from the command +// line. 
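A hypothetical call shape for the serialized variant, with the template arguments elided in this header copy filled in as assumptions (tensorflow::TensorShape for arg_shapes, std::string for the returned serialized module); treat it as an illustration of how the pieces fit, not a verified snippet:

#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/core/framework/tensor_shape.h"

// Hypothetical usage: compile a tiny tf-dialect module for XLA_CPU_JIT and
// return the serialized lowered module. Argument element types and the module
// text are assumptions made for this sketch.
absl::StatusOr<std::string> CompileAddSelf() {
  constexpr char kModule[] = R"mlir(
    module attributes {tf.versions = {producer = 179 : i32}} {
      func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
        %0 = "tf.AddV2"(%arg0, %arg0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
        return %0 : tensor<f32>
      }
    }
  )mlir";
  std::vector<tensorflow::TensorShape> arg_shapes(1);  // one scalar argument
  tensorflow::XlaCompilationResult result;
  return tensorflow::CompileSerializedMlirToXlaHlo(
      kModule, arg_shapes, /*device_type=*/"XLA_CPU_JIT",
      /*use_tuple_args=*/true, /*enable_op_fallback=*/false,
      /*shape_determination_fns=*/{}, &result);
}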
+void RegisterConvertMlirToXlaHloPipelineWithDefaults(); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_COMPILE_MLIR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h new file mode 100644 index 00000000..7007d70b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h @@ -0,0 +1,56 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_COMPILE_TF_GRAPH_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_COMPILE_TF_GRAPH_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/compile_only_client.h" +#include "xla/pjrt/compile_options.pb.h" +#include "xla/shape.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" + +namespace tensorflow { +namespace tf2xla { +namespace v1 { + +// Compiles the given Tensorflow graph into xla::HLO. The result is in +// compilation_result. If the input computation is in MLIR, it will be +// converted to a Tensorflow graph. Otherwise, the graph compiler will be run. +absl::Status CompileTensorflowGraphToHlo( + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, + const std::vector& arg_shapes, + tsl::DeviceType device_type, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client, + XlaCompiler::CompilationResult* compilation_result); + +} // namespace v1 +} // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_COMPILE_TF_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h new file mode 100644 index 00000000..d41627b3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor.h @@ -0,0 +1,57 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_TF_DIALECT_TO_EXECUTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_TF_DIALECT_TO_EXECUTOR_H_ + +#include "absl/base/attributes.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tf2xla { +namespace v1 { + +// Given the input Module op that's in the Tensorflow Dialect, convert the MLIR +// module in place to the Tensorflow Executor Dialect. Returns an OK Status if +// success, otherwise failure with an error message. +// The Tensorflow Executor Dialect is required to export an MLIR module to a +// Tensorflow GraphDef. This API will add control dependencies and verify that +// the conversion was successful. This version adds extra control dependencies +// for replication and parallel execution ops, which may slow performance. +// Prefer to use the v2 of this API. +// +// This also converts the Tensorflow Dialect MLIR into the Tensorflow Executor +// dialect that is suitable to be exported to GraphDef. Graph -> MLIR -> Graph +// is not perfectly round trippable, so this API will attempt to make the module +// exportable and verify some properties of the Tensorflow Executor MLIR that +// are required by Graph Export. It will return an error if it cannot. +// +// Input: A MLIR Module in the Tensorflow Dialect with no +// `tf_device.cluster_func` ops. +// Output: A MLIR module in the Tensorflow Executor Dialect. + +ABSL_DEPRECATED( + "Use v2/tf_dialect_to_executor.h::ExportFromTensorflowDialectToExecutor " + "instead.") +absl::Status ExportFromTensorflowDialectToExecutor( + mlir::ModuleOp module, llvm::StringRef module_name = llvm::StringRef()); + +} // namespace v1 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V1_TF_DIALECT_TO_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h new file mode 100644 index 00000000..6e9576fd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.h @@ -0,0 +1,61 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_CLUSTER_TF_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_CLUSTER_TF_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +// Run all the passes involved in transforming the graph before execution so +// that it is suitable for targeting devices when called with the TF 2 Function +// API. Users that need clustering with the Session API should use the v1 Bridge +// API. These transformations take as input a Tensorflow Graph as an MLIR Module +// and transforms the module in place to cluster the given ops for compilation +// that is compatible with the given device_type. The MLIR should be in the TF +// Executor Dialect for graph nodes and edges or be in TF Functional already. +// Individual Op inside a node should be the Tensorflow Functional Dialect. The +// output MLIR is in the TF Functional Dialect. Returns OkStatus if passed, +// otherwise an error. +// +// Inputs: +// module - The MLIR Module that will be clustered. Expected to be in TF +// Executor Dialect or TF Functional Dialect. Will convert to TF Functional. +// is_supported_by_replicated_brige - If the graph targets the replicated +// bridge. Set it to true for replicated/partitioned graphs. e.g. replicated +// and single-core TPU graphs. Set this to false if the graph is not +// replicated, e.g. CPU/GPU graphs. is_in_fallback_enabled_mode - Whether this +// was called with fallback to the non-MLIR Bridge. This is just for logging +// purposes and doesn't affect logic. module_name - What the input module name +// is for debugging help. +// +// Output: Modifies the input module in place with clustered operations. +// status - Whether the transformation to cluster the input MLIR module was +// successful. +absl::Status RunFunctionTf2xlaClusteringBridge( + mlir::ModuleOp module, bool is_supported_by_replicated_brige, + bool is_in_fallback_enabled_mode, + llvm::StringRef module_name = llvm::StringRef()); +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_CLUSTER_TF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h new file mode 100644 index 00000000..1af93e6b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h @@ -0,0 +1,52 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_GRAPH_TO_TF_EXECUTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_GRAPH_TO_TF_EXECUTOR_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main"; + +// Given a Graph, returns a MLIR module containing the graph, expressed with +// tf_executor dialect. +absl::StatusOr> ConvertGraphToTfExecutor( + const Graph& graph, const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, + mlir::MLIRContext* context, + std::unordered_map* tf_name_to_mlir_name = + nullptr, + const ConfigProto& config_proto = {}, + tensorflow::TF2XLABridgeVersion bridge_version = + tensorflow::TF2XLABridgeVersion::kNotBridgeUseCase); + +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_GRAPH_TO_TF_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h new file mode 100644 index 00000000..14a8271d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h @@ -0,0 +1,68 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_LEGALIZE_TF_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_LEGALIZE_TF_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/client/compile_only_client.h" +#include "xla/pjrt/compile_options.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +// Legalizes the given mlir::Module into XLA HLO. If successful, returns the +// compiled XLA HLO. V1 of the tf2xla uses MLIR whereas V0 does not use MLIR. +// +// Inputs: +// computation - The MLIR module op. 
It currently takes in +// tpu::FunctionToHloArgs but this is deprecated. arg_shapes - The shapes of +// the arguments in module_op. device_type - The device type to compile for. +// use_tuple_args - Pack the incoming arg shapes into a single tuple. +// custom_legalization_passes - Extra passes to lower from TF -> MHLO. +// arg_shapes - The shapes of the args. +// arg_core_mapping - Which args go on which cores. +// per_core_arg_shapes - For each core, the shapes for each argument. +// client - The Xla Compilation client. +absl::StatusOr LegalizeMlirToHlo( + const std::variant& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + llvm::StringRef device_type, + std::vector>& custom_legalization_passes, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + xla::CompileOnlyClient* client); + +}; // namespace v2 +}; // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_LEGALIZE_TF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h new file mode 100644 index 00000000..7394fe37 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_COMPILE_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_COMPILE_MLIR_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace testing { + +// Compiles the given MLIR module to XLA HLO. +absl::StatusOr CompileMlirModule( + const char* mlir_module_str, + ConfigProto::Experimental::MlirBridgeRollout rollout_state, + absl::string_view device_type = "XLA_TPU_JIT"); + +} // namespace testing +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_COMPILE_MLIR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h new file mode 100644 index 00000000..b2c2cf62 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
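// Illustrative test sketch for the CompileMlirModule helper declared above; it
// is not part of the vendored header. The elided return type is assumed to be
// absl::StatusOr<XlaCompiler::CompilationResult>, as in upstream TensorFlow,
// and the small MLIR module below is only an example input.
#include <gtest/gtest.h>
#include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h"

TEST(CompileMlirModuleExample, LegalizesTrivialModule) {
  static constexpr char kMlirModule[] = R"(
    module attributes {tf.versions = {producer = 179 : i32}} {
      func.func @main() -> tensor<1xi32> {
        %0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32>
        func.return %0 : tensor<1xi32>
      }
    })";
  auto result = tensorflow::tf2xla::v2::testing::CompileMlirModule(
      kMlirModule,
      tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
  EXPECT_TRUE(result.ok());
}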
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_UTILS_H_ + +#include + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace testing { + +// Returns the path to the testdata directory. +std::string TestDataPath(); + +} // namespace testing +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h new file mode 100644 index 00000000..185cefa5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TF_DIALECT_TO_EXECUTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TF_DIALECT_TO_EXECUTOR_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +// Given the input Module op that's in the Tensorflow Dialect, convert the MLIR +// module in place to the Tensorflow Executor Dialect. Returns an OK Status if +// success, otherwise failure with an error message. +// The Tensorflow Executor Dialect is required to export an MLIR module to a +// Tensorflow GraphDef. This API will add control dependencies and verify that +// the conversion was successful. +// +// This also converts the Tensorflow Dialect MLIR into the Tensorflow Executor +// dialect that is suitable to be exported to GraphDef. Graph -> MLIR -> Graph +// is not perfectly round trippable, so this API will attempt to make the module +// exportable and verify some properties of the Tensorflow Executor MLIR that +// are required by Graph Export. It will return an error if it cannot. +// +// Input: A MLIR Module in the Tensorflow Dialect with no +// `tf_device.cluster_func` ops. +// Output: A MLIR module in the Tensorflow Executor Dialect. 
+absl::Status ExportFromTensorflowDialectToExecutor( + mlir::ModuleOp module, llvm::StringRef module_name = llvm::StringRef()); + +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TF_DIALECT_TO_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h new file mode 100644 index 00000000..8fd7607a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TF_EXECUTOR_TO_GRAPH_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TF_EXECUTOR_TO_GRAPH_H_ + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +// Converts an MLIR module to TensorFlow graph and FunctionLibraryDefinition. +// The "main" function of the module is stored in the graph and the rest of +// functions are stored in the library. Control ret nodes are stored separately +// in `control_ret_nodes`. +absl::Status ConvertTfExecutorToGraph( + mlir::ModuleOp module, const GraphExportConfig& configs, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + absl::flat_hash_set* control_ret_nodes); + +// Converts an MLIR function and adds it to a FunctionLibraryDefinition. +absl::Status ConvertMlirFunctionToFunctionLibraryDef( + mlir::func::FuncOp func, const GraphExportConfig& configs, + FunctionDef* function_def); + +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TF_EXECUTOR_TO_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h new file mode 100644 index 00000000..6f8595cb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
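// Illustrative sketch combining the two exporters above; not part of the
// vendored headers. The elided template arguments are assumed to be
// std::unique_ptr<Graph>* and absl::flat_hash_set<Node*>*, matching upstream
// TensorFlow.
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor.h"
#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h"

absl::Status ExportTfDialectModuleToGraph(
    mlir::ModuleOp module, std::unique_ptr<tensorflow::Graph>* graph,
    tensorflow::FunctionLibraryDefinition* flib_def) {
  namespace v2 = tensorflow::tf2xla::v2;
  // Lower the TF-dialect module to the tf_executor dialect in place.
  absl::Status status = v2::ExportFromTensorflowDialectToExecutor(module);
  if (!status.ok()) return status;
  // Convert the executor-dialect module into a Graph plus function library.
  tensorflow::GraphExportConfig configs;
  absl::flat_hash_set<tensorflow::Node*> control_ret_nodes;
  return v2::ConvertTfExecutorToGraph(module, configs, graph, flib_def,
                                      &control_ret_nodes);
}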
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_CLUSTERING_BRIDGE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_CLUSTERING_BRIDGE_PASSES_H_ + +#include "absl/base/attributes.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Given the pass manager, add Bridge passes to cluster the replicated input +// graphs. +void AddReplicatedBridgeClusteringPipelinePasses( + mlir::OpPassManager& pm, llvm::StringRef module_name = llvm::StringRef()); + +// Same as above but for non replicated graphs. +void AddNonReplicatedBridgeClusteringPipelinePasses(mlir::OpPassManager& pm); + +}; // namespace internal +}; // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_CLUSTERING_BRIDGE_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h new file mode 100644 index 00000000..2eb46935 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_COMPILATION_TIMER_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_COMPILATION_TIMER_H_ + +#include // NOLINT(build/c++11) + +#include "tensorflow/core/platform/profile_utils/cpu_utils.h" + +// Time the execution of kernels (in CPU cycles). Meant to be used as RAII. 
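// Illustrative sketch of wiring the clustering pipeline declared above into a
// pass manager; not part of the vendored header. `module` and `context` are
// assumed to be an already-loaded TF-dialect module and its MLIRContext.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h"

mlir::LogicalResult RunReplicatedClustering(mlir::ModuleOp module,
                                            mlir::MLIRContext* context) {
  mlir::PassManager pm(context);
  // PassManager derives from OpPassManager, so it can be passed directly.
  tensorflow::tf2xla::internal::AddReplicatedBridgeClusteringPipelinePasses(pm);
  return pm.run(module);
}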
+struct CompilationTimer { + uint64_t start_cycles = + tensorflow::profile_utils::CpuUtils::GetCurrentClockCycle(); + + uint64_t ElapsedCycles() { + return tensorflow::profile_utils::CpuUtils::GetCurrentClockCycle() - + start_cycles; + } + + int64_t ElapsedCyclesInMilliseconds() { + std::chrono::duration duration = + tensorflow::profile_utils::CpuUtils::ConvertClockCycleToTime( + ElapsedCycles()); + + return std::chrono::duration_cast(duration) + .count(); + } +}; + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_COMPILATION_TIMER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h new file mode 100644 index 00000000..c08a2c39 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_GRAPH_TO_TF_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_GRAPH_TO_TF_EXECUTOR_UTIL_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// These are used for grouping the recorded stats appropriately. Specifically, +// we're considering different entrypoints to the bridge as having potentially +// interesting differences at least in the domain of accepted graphs so we want +// to separately track graph features based on these unique entrypoints. One key +// example of this distinction is for TFRT which uses the "nominal" TPU bridge +// pipeline, but may potentially allow graphs with v1 control flow. This +// separate grouping will allow us to dig into these differences granularly. +enum class TF2XLABridgeVersion { + kNominal = 0, + kV1Compat, + kTFRTNominal, + kNotBridgeUseCase, +}; + +// Analyzes whether the graph has features not guaranteed to be supported by the +// MLIR-based TF XLA bridge for phase 1. If MLIR bridge phase 1 is not used, +// then MLIR bridge phase 2 will not be used. The optional `function_library` +// can be provided if it contains function definitions not including in the +// `graph` FunctionLibraryDefinition. +// +// Conservatively, during the initial rollout, we are not supporting graphs for +// which any of the following are true: +// +// - Not known to be TF2 +// - Contains one or more reference variables +// - Contains one or more TPUPartitionedCall ops (which is a proxy for +// inference), but the graph is not v1 compat +// - Uses V1 control flow +// - Graph is invalid or otherwise encounters error during traversal +// If `single_core_inference_mode` is true, we skip some of check conditions +// because they are not applicable. 
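// Illustrative usage sketch for CompilationTimer above; not part of the
// vendored header. The elided include and duration template arguments are
// assumed to follow upstream TensorFlow (<chrono>, std::chrono::duration<double>,
// std::chrono::milliseconds).
#include <cstdint>
#include <functional>
#include "tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h"

int64_t TimeCompileStep(const std::function<void()>& compile_step) {
  CompilationTimer timer;  // starts counting CPU cycles on construction
  compile_step();          // the work being measured
  return timer.ElapsedCyclesInMilliseconds();
}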
+// TODO(b/241702857): remove single_core_inference_mode +bool GraphHasUnsupportedFeaturesInMlirBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library, + std::optional config_proto, TF2XLABridgeVersion bridge_version, + bool single_core_inference_mode); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_GRAPH_TO_TF_EXECUTOR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/inference/inference_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/inference/inference_passes.h new file mode 100644 index 00000000..7d4bf660 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/inference/inference_passes.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_INFERENCE_INFERENCE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_INFERENCE_INFERENCE_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tf2xla { +namespace internal { + +std::unique_ptr> CreateInferenceMetricsPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_INFERENCEMETRICSPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/inference/inference_passes.h.inc" + +} // namespace internal +} // namespace tf2xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_INFERENCE_INFERENCE_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h new file mode 100644 index 00000000..fec64c0f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h @@ -0,0 +1,49 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
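// Illustrative call sketch for GraphHasUnsupportedFeaturesInMlirBridge above;
// not part of the vendored header. The elided std::optional parameter is
// assumed to be std::optional<tensorflow::ConfigProto>, as in upstream
// TensorFlow; `graph` is assumed to be an already-built Graph.
#include <optional>
#include "tensorflow/compiler/mlir/tf2xla/internal/graph_to_tf_executor_util.h"

bool ShouldSkipMlirBridge(const tensorflow::Graph& graph) {
  return tensorflow::GraphHasUnsupportedFeaturesInMlirBridge(
      graph, /*function_library=*/nullptr, /*config_proto=*/std::nullopt,
      tensorflow::TF2XLABridgeVersion::kNominal,
      /*single_core_inference_mode=*/false);
}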
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LEGALIZE_TF_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LEGALIZE_TF_MLIR_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Runs all the MLIR Bridge passes on the given MLIR module. +// If compile_to_xla_hlo is true then those passes include all the Legalization +// to XLA HLO which is returned in the compilation_result. +absl::Status CompileFromMlirToXlaHlo( + bool lower_to_xla_hlo, mlir::ModuleOp mlir_module_op, + const tpu::TPUCompileMetadataProto& metadata, llvm::StringRef device_type, + const XlaShapeLayoutHelpers::ShapeDeterminationFns& shape_determination_fns, + bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result, + std::vector>& custom_legalization_passes, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes); + +}; // namespace internal +}; // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LEGALIZE_TF_MLIR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h new file mode 100644 index 00000000..664bd549 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LEGALIZE_TF_TO_HLO_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LEGALIZE_TF_TO_HLO_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/client/compile_only_client.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Legalize the given MLIR module to XLA HLO using a combination of the MLIR +// Bridge and XlaBuilder +absl::StatusOr LegalizeTfToHlo( + const tpu::MlirToHloArgs& computation, + const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, + llvm::StringRef device_type, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + const std::vector& arg_shapes, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + std::vector>& custom_legalization_passes, + xla::CompileOnlyClient* client, XlaCompilationResult* compilation_result); + +}; // namespace internal +}; // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LEGALIZE_TF_TO_HLO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h new file mode 100644 index 00000000..61c5028a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LOGGING_HOOKS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LOGGING_HOOKS_H_ + +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Setup the input pass manager to enable IR dumping after each pass. +// Note a side effect of this method is that multi threading will be disabled. +void EnablePassIRPrinting(mlir::PassManager& pm, + const std::string& dump_group_name, + llvm::StringRef module_name = llvm::StringRef()); + +}; // namespace internal +}; // namespace tf2xla +}; // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_LOGGING_HOOKS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h new file mode 100644 index 00000000..c0f2a5e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util.h @@ -0,0 +1,54 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
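// Illustrative sketch for EnablePassIRPrinting above; not part of the vendored
// header. The dump group name is an arbitrary example value.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h"

void ConfigureBridgeDebugDumps(mlir::PassManager& pm) {
  // After this call the pass manager dumps IR after every pass; note that it
  // also disables multi-threading on the pass manager, per the comment above.
  tensorflow::tf2xla::internal::EnablePassIRPrinting(pm, "tf2xla_bridge_debug");
}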
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { + +// Checks if a graph or reachable functions in the library have any +// StatefulPartitionedOps with _XlaMustCompile=true. The function library will +// be skipped if nullptr is provided. +bool IsSupportedByNonReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library); + +// Checks if a graph or reachable functions in the library have any ops with +// _tpu_replicate or _xla_compile_device_type=TPU. The function library will be +// skipped if nullptr is provided. + +bool IsSupportedByReplicatedBridge( + const Graph& graph, const FunctionLibraryDefinition* function_library); + +// Check if an MLIR module has any ops with _tpu_replicate or +// _xla_compile_device_type=TPU. +bool IsSupportedByReplicatedBridge(mlir::ModuleOp module); + +// Check if an MLIR module contains TPUPartitionedCall op. If so, we define +// such graph as an inference graph. Otherwise, it is non inference graph. +bool HasTPUPartitionedCallOpInModule(mlir::ModuleOp module); + +// Check if a graph contains TPUPartitionedCall op, including its reachable +// functions. The function library is used to store the functions that are +// defined in a TensorFlow program +bool IsInferenceGraph(const Graph& graph, + const FunctionLibraryDefinition* function_library); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_BRIDGE_PASS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h new file mode 100644 index 00000000..f4375dfc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_PASS_INSTRUMENTATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_PASS_INSTRUMENTATION_H_ + +#include +#include +#include +#include + +#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project + +namespace mlir { + +void RegisterPassInstrumentor( + const std::string& name, + std::function()> creator); +std::vector()>> +GetPassInstrumentors(); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_MLIR_PASS_INSTRUMENTATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/node_order.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/node_order.h new file mode 100644 index 00000000..a6f65006 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/node_order.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_NODE_ORDER_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_NODE_ORDER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +struct GroupByDevice { + std::string operator()(const Node* node) const { + return node->requested_device(); + } +}; + +// Performs a topological ordering of nodes. +// This has the property that any child node of a parent node p is emitted +// before p. A grouping function is used to break ties if multiple child nodes +// (of possibly different parents) are ready to be emitted at some point, which +// is when we prefer to stay in the current group. Remaining ties are broken by +// node name. +// The "emit" function is used for outputing the result, and is called once +// for each node. +// This algorithm is O(n * k * log k), with k the largest node degree. +void TopologicalOrdering( + const Graph& g, const std::function& emit, + const std::function& get_grouping_key); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_NODE_ORDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h new file mode 100644 index 00000000..4d91f113 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -0,0 +1,93 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
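// Illustrative sketch for TopologicalOrdering above; not part of the vendored
// header. The elided callback types are assumed to be
// std::function<void(Node*)> for `emit` and std::function<std::string(Node*)>
// for `get_grouping_key`, matching upstream TensorFlow.
#include <vector>
#include "tensorflow/compiler/mlir/tf2xla/internal/node_order.h"

std::vector<tensorflow::Node*> OrderNodesByDevice(
    const tensorflow::Graph& graph) {
  std::vector<tensorflow::Node*> order;
  tensorflow::TopologicalOrdering(
      graph,
      /*emit=*/[&order](tensorflow::Node* n) { order.push_back(n); },
      /*get_grouping_key=*/tensorflow::GroupByDevice());
  return order;
}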
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Verifies that all MLIR Ops have the expected attributes. +std::unique_ptr> +CreateVerifyClusteringPass(); + +// Creates a pass that forms clusters from operations of the same +// `_replication_info` attribute. +std::unique_ptr> +CreateTPUClusterFormationPass(bool strict_clusters = false); + +// Creates a pass that extracts outside compilation (Host ops inside device +// cluster) at head/tail of Device cluster to run before/after XLA computation. +std::unique_ptr> +CreateExtractHeadTailOutsideCompilationPass(); + +// Creates a pass that extract outside compilation (Host ops inside cevice +// cluster) ops to a separate parallel_execute region to run on CPU. +std::unique_ptr> +CreateExtractOutsideCompilationPass(); + +// Create a pass that encapsulates StatefulPartitionedCallOp within a cluster. +std::unique_ptr> +CreateXlaClusterFormationPass(); + +// Creates a pass that marks unsupported ops in device cluster for outside +// compilation. +std::unique_ptr> +CreateMarkOpsForOutsideCompilationPass(); + +// Creates a pass that hoists reads out of a replicate that are on a variable +// whose value is broacast to all replicas. +std::unique_ptr> +CreateHoistBroadcastReadPass(); + +// Creates a pass that moves broadcasts from TF host ops to XLA code, encoded as +// XlaAllReduces. This enables use of the device network for broadcasts, which +// is faster. +std::unique_ptr> +CreateXlaBroadcastPass(); + +// Creates a pass that identifies XLASharding ops in launch op for TPU +// computation. +std::unique_ptr> +CreateTPUShardingIdentificationPass(); + +// Creates a pass that validates the inputs to a TPU computation. 
+std::unique_ptr> +CreateTPUValidateSessionInputsPass(); + +std::unique_ptr> +CreateTPUValidateInputsPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS +#define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS +#define GEN_PASS_DECL_TPUEXTRACTHEADTAILOUTSIDECOMPILATIONPASS +#define GEN_PASS_DECL_TPUEXTRACTOUTSIDECOMPILATIONPASS +#define GEN_PASS_DECL_TPUSHARDINGIDENTIFICATIONPASS +#define GEN_PASS_DECL_TPUVALIDATEINPUTSPASS +#define GEN_PASS_DECL_TPUVALIDATESESSIONINPUTSPASS +#define GEN_PASS_DECL_VERIFYCLUSTERINGPASS +#define GEN_PASS_DECL_XLACLUSTERFORMATIONPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h new file mode 100644 index 00000000..0be689c6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_LOWERING_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_LOWERING_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Create a pass that just collects metrics about the input MLIR. Does not +// logically transform the program. +std::unique_ptr> +CreateInputLoweringMetricsPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_INPUTLOWERINGMETRICSPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h.inc" + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_LOWERING_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h new file mode 100644 index 00000000..4e28930b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_MLIR_TO_GRAPH_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_MLIR_TO_GRAPH_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Verifies that Executor input is of the expected format. +std::unique_ptr> +CreateVerifyInputDialectToExecutorPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_VERIFYINPUTDIALECTTOEXECUTORPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h.inc" + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_MLIR_TO_GRAPH_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h new file mode 100644 index 00000000..152b2e02 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_TPU_VALIDATE_INPUTS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_TPU_VALIDATE_INPUTS_UTILS_H_ + +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +constexpr char kTpuReplicatedCoreZeroAttr[] = "TPU_REPLICATED_CORE:0"; + +using mlir::ModuleOp; +using mlir::Operation; +using mlir::StringAttr; +using mlir::TypeID; +using mlir::TF::InfeedDequeueTupleOp; +using mlir::TF::kDeviceAttr; +using mlir::tf_executor::GraphOp; + +bool IsPotentialUnsupportedOp(Operation* op); + +bool HasV1ControlFlow(GraphOp graph); + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_TPU_VALIDATE_INPUTS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h new file mode 100644 index 00000000..57c65bbf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h @@ -0,0 +1,91 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_TEST_MATCHERS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_TEST_MATCHERS_H_ + +#include +#include "absl/status/statusor.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" +#include "tsl/platform/statusor.h" + +template +bool WasGraphAnalysisFailure(const absl::StatusOr& status) { + return (status.status() == + tensorflow::CompileToHloGraphAnalysisFailedError()); +} + +/* The third party version of the Graph Analysis always returns disabled so + * these matchers short circuit on that error. 
*/ +MATCHER(IsOkOrFiltered, + "Status was OK or equal to the Graph Analysis failure") { + bool is_ok = arg.ok(); + auto graph_analysis_failure = WasGraphAnalysisFailure(arg); + return testing::ExplainMatchResult( + testing::IsTrue(), is_ok || graph_analysis_failure, result_listener); +} + +MATCHER_P2(IncrementedOrFiltered, metric, value, + "Metric was incremented by value or Status equal to the Graph " + "Analysis failure") { + auto graph_analysis_failure = WasGraphAnalysisFailure(arg); + if (graph_analysis_failure) { + return testing::ExplainMatchResult(testing::IsTrue(), + graph_analysis_failure, result_listener); + } + return testing::ExplainMatchResult(testing::Eq(metric), value, + result_listener); +} + +MATCHER_P(ComputationProtoContains, regex, + "If not a Graph Analysis failure then matches the computation result " + "with the regex") { + auto graph_analysis_failure = WasGraphAnalysisFailure(arg); + if (graph_analysis_failure) { + return testing::ExplainMatchResult(testing::IsTrue(), + graph_analysis_failure, result_listener); + } + auto proto = arg.value().computation->proto().DebugString(); + return testing::ExplainMatchResult(testing::ContainsRegex(regex), proto, + result_listener); +} + +MATCHER_P(XlaComputationProtoContains, regex, + "If not a Graph Analysis failure then matches the computation result " + "with the regex") { + auto graph_analysis_failure = WasGraphAnalysisFailure(arg); + if (graph_analysis_failure) { + return testing::ExplainMatchResult(testing::IsTrue(), + graph_analysis_failure, result_listener); + } + auto proto = arg.value().proto().DebugString(); + return testing::ExplainMatchResult(testing::ContainsRegex(regex), proto, + result_listener); +} + +MATCHER_P( + HasMlirModuleWith, expected, + "If not a Graph Analysis failure then matches the mlir module result") { + auto graph_analysis_failure = WasGraphAnalysisFailure(arg); + if (graph_analysis_failure) { + return testing::ExplainMatchResult(testing::IsTrue(), + graph_analysis_failure, result_listener); + } + auto actual = arg.value(); + return testing::ExplainMatchResult(testing::ContainsRegex(expected), actual, + result_listener); +} + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_TEST_MATCHERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h new file mode 100644 index 00000000..6dd9851f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/utils/dialect_detection_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
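// Illustrative test sketch for the matchers above; not part of the vendored
// header. `LegalizeForTest()` is a hypothetical helper standing in for a call
// such as tf2xla::v2::testing::CompileMlirModule that yields an
// absl::StatusOr<XlaCompiler::CompilationResult>.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h"

TEST(LegalizationExample, SucceedsOrIsFilteredByGraphAnalysis) {
  auto result = LegalizeForTest();  // hypothetical helper; see lead-in comment
  // Passes when legalization succeeded or was filtered by graph analysis.
  EXPECT_THAT(result, IsOkOrFiltered());
  // Checks the produced HLO only when the result was not filtered.
  EXPECT_THAT(result, ComputationProtoContains("opcode:.*constant"));
}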
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_DIALECT_DETECTION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_DIALECT_DETECTION_UTILS_H_ + +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Returns true if the op has a valid namespace during clustering & tf dialect +// to executor components of the Bridge. +bool IsInBridgeAcceptableDialects(mlir::Operation* op); + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_DIALECT_DETECTION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h new file mode 100644 index 00000000..83d6beb2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_TEST_METADATA_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_TEST_METADATA_CONFIG_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Fills in arg_shapes and metadata_proto with appropriate values based on the +// input mlir module. +absl::Status ConfigureMetadata(absl::string_view mlir_module_str, + std::vector& arg_shapes, + tpu::TPUCompileMetadataProto& metadata_proto); + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_UTILS_TEST_METADATA_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h new file mode 100644 index 00000000..7508a8d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_MLIR_BRIDGE_ROLLOUT_POLICY_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_MLIR_BRIDGE_ROLLOUT_POLICY_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +enum class MlirBridgeRolloutPolicy { + // The MLIR bridge is explicitly disabled by the user and must not be run. + kDisabledByUser = 0, + // The MLIR bridge is explicitly enabled by the user and must be run. If the + // MLIR bridge errors, the fallback path should NOT be used. + kEnabledByUser, + // The bridge was not explicitly enabled or disabled by the user. Based on the + // features in the model, the MLIR bridge should not be run. + kDisabledAfterGraphAnalysis, + // The bridge was not explicitly enabled or disabled by the user. Based on the + // features in the model, the MLIR bridge should be run. If the MLIR Bridge + // errors, the fallback path should be used whenever possible. + kEnabledAfterGraphAnalysis, +}; + +// Analyzes the user requested policy as well as the contents of the graph and +// returns true when the MLIR Bridge should be run. +// +// If the user explicitly requests the bridge be enabled or disabled, this +// function will respect the request. If the user does not explicitly request +// enabled or disabled, it will decide whether or not to run the bridge. +// +// The config_proto param is a required input for all TF1 graphs but it is +// redundant for TF2 graphs. +// If getting rollout policy involves graph analysis, `record_stats` is used +// to decide whether to emit metrics on unsupported features of the graph. +MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy( + const tensorflow::Graph& graph, + const FunctionLibraryDefinition* function_library, + std::optional config_proto, + bool is_supported_by_replicated_brige, bool is_v1_compat, + bool record_stats); + +static inline MlirBridgeRolloutPolicy GetMlirBridge2ndPhaseRolloutPolicy( + mlir::ModuleOp module) { + return MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis; +} + +// Explicit Interface for when we want to log features vs test the validity of +// the graph for MLIR bridge processing. Note that right now the logging +// which is done in the logic used by GraphHasFeaturesUnsupportedByMlirBridge +// has diverged and logs supported features as well. Parameters are the same +// as for GetMlirBridgeRolloutPolicy with the exception of +// record_stats, which isn't needed because this interface will always record. +void LogGraphFeatures(const Graph& graph, + const FunctionLibraryDefinition* function_library, + std::optional config_proto, + bool is_v1_compat); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_MLIR_BRIDGE_ROLLOUT_POLICY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h new file mode 100644 index 00000000..b94f3370 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
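// Illustrative sketch for GetMlirBridgeRolloutPolicy above; not part of the
// vendored header. The elided std::optional parameter is assumed to be
// std::optional<tensorflow::ConfigProto>; `graph` is assumed to exist.
#include <optional>
#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"

bool ShouldRunMlirBridge(const tensorflow::Graph& graph) {
  tensorflow::MlirBridgeRolloutPolicy policy =
      tensorflow::GetMlirBridgeRolloutPolicy(
          graph, /*function_library=*/nullptr, /*config_proto=*/std::nullopt,
          /*is_supported_by_replicated_brige=*/true, /*is_v1_compat=*/false,
          /*record_stats=*/false);
  return policy == tensorflow::MlirBridgeRolloutPolicy::kEnabledByUser ||
         policy ==
             tensorflow::MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis;
}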
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZATION_OP_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZATION_OP_CONFIG_H_ + +#include "mlir/Support/TypeID.h" // from @llvm-project + +namespace mlir { +namespace mhlo { + +// Given the type ID, check if it's legalized with MLIR. +bool IsTypeLegalizedWithMlir(const TypeID& type_id); + +// Returns true if the op is considered a dynamic padder op. +bool IsDynamicPadderOp(const TypeID& type_id); + +// Returns True if this op has a Tf2XLA fallback. Currently, this is not the +// inverse of the !IsOpLegalizedWithMlir, but it should be. +bool HasTf2XlaFallback(const TypeID& type_id); + +// Whether this type is allowed to have a TF2XLA fallback. +bool IsOpAllowedTf2xlaFallback(const TypeID& type_id); + +// Whether this type is Preferred to use a TF2XLA fallback kernel when using +// the MLIR bridge. If this is true, then the TF2XLA fallback kernel will be +// used over the MLIR lowering. +bool IsOpAllowedTf2xlaPreferred(const TypeID& type_id); + +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZATION_OP_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h new file mode 100644 index 00000000..8c83fb56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZE_TF_WITH_TF2XLA_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZE_TF_WITH_TF2XLA_PASSES_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { + +namespace func { +class FuncOp; +} +class ModuleOp; +class Operation; +template +class OperationPass; +class Pass; + +namespace mhlo { + +/// Converter to be used along with the fallback Tf2Xla patterns below. +class Tf2XlaTypeConverter : public TypeConverter { + public: + Tf2XlaTypeConverter(); +}; + +/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list. +/// `prefer_tf2xla` means an op will be included iff it is not in +/// `MlirLegalizedUnderPreferTf2XlaSet`. `!prefer_tf2xla` mean an op will be +/// included if there is no native MLIR legalization for the op. +void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, + RewritePatternSet& patterns, + MLIRContext* ctx, + Tf2XlaTypeConverter& converter, + bool prefer_tf2xla = false); + + +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZE_TF_WITH_TF2XLA_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/passes.h new file mode 100644 index 00000000..0b9f5a1e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/passes.h @@ -0,0 +1,120 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_PASSES_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { + +namespace func { +class FuncOp; +} +class ModuleOp; +class Operation; +template +class OperationPass; +class Pass; + +namespace mhlo { + +/// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is +/// false, emits an error if there is any operation that can't be legalized. 
+/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization +/// patterns from TF2XLA fallback for provided device type (see +/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not +/// used. +/// Note: This is a module pass because when legalizing with TF2XLA fallback, +/// functions are imported into the module. Importing functions into a +/// module is not thread safe. +std::unique_ptr> createLegalizeTFPass( + bool legalize_chlo = true, + std::optional tf2xla_fallback_device_type = std::nullopt, + bool prefer_tf2xla = false); + +/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern +/// list. +void PopulateLegalizeTfPatterns(MLIRContext* context, + RewritePatternSet* patterns); + +// Populates TF to MHLO legalization for some of the quantization ops. +// +// TODO(hinsu): Remove this once we combine quantized and non quantized op +// legalization in the ODML conversion pipeline. +void PopulateLegalizeTfQuantizationPatterns(MLIRContext* context, + RewritePatternSet* patterns); + +/// Converts the provided Operation as well as all nested operations into HLO +/// dialect using the conversion patterns registered by the HLO dialect. When +/// allow_partial_conversion is false, emits an error if there is any operation +/// that can't be legalized. +/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization +/// patterns from TF2XLA fallback for provided device type (see +/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not +/// used. +LogicalResult legalizeTF( + Operation* op, bool allow_partial_conversion = false, + bool legalize_chlo = true, + std::optional tf2xla_fallback_device_type = std::nullopt, + bool prefer_tf2xla = false); + +// Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication +// ops. +std::unique_ptr> CreateLegalizeTFCommunicationPass(); + +// Legalizes TF/XLA collective ops (TF dialect) to HLO dialect collective +// ops. +std::unique_ptr> CreateLegalizeTFCollectivePass(); + +// Verifies that the TF/XLA ops have all been lowered to MHLO. +std::unique_ptr> CreateVerifyTFXLALegalizationPass( + bool legalize_chlo = true); + +// Transforms TFXLA Device specific ops into device independent ops. +std::unique_ptr> +CreateTFXLADeviceSpecificTransformsPass( + std::optional tf2xla_fallback_device_type = std::nullopt); + +// Adjusts XLA layout for Infeed ops. 
+std::unique_ptr> +CreateInfeedsOpsXlaAdjustLayoutPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_INFEEDSOPSXLAADJUSTLAYOUT +#define GEN_PASS_DECL_LEGALIZETF +#define GEN_PASS_DECL_LEGALIZETFCOLLECTIVE +#define GEN_PASS_DECL_LEGALIZETFMODULEPASS +#define GEN_PASS_DECL_LEGALIZETFTYPESPASS +#define GEN_PASS_DECL_TFXLADEVICESPECIFICTRANSFORMS +#define GEN_PASS_DECL_VERIFYTFXLALEGALIZATION +#include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.h.inc" + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_DECL_LEGALIZETFCOMMUNICATIONPASS +#include "tensorflow/compiler/mlir/tf2xla/transforms/tf_xla_passes.h.inc" +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.h new file mode 100644 index 00000000..21de5950 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_SPLIT_INTO_ISLAND_PER_OP_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_SPLIT_INTO_ISLAND_PER_OP_PASS_H_ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" + +namespace mlir { +namespace TF { + +// Converts a single island into multiple islands (one for each op). +void SplitIsland(mlir::tf_executor::IslandOp island_op, + mlir::tf_executor::ControlType control_type); + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_SPLIT_INTO_ISLAND_PER_OP_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h new file mode 100644 index 00000000..0ad6e9af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
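A minimal sketch of driving SplitIsland over a whole module; collecting the islands before splitting them (rather than mutating the IR mid-walk) and the helper name SplitAllIslands are assumptions made for illustration:

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tf2xla/transforms/split_into_island_per_op_pass.h"

// Splits every tf_executor island in `module` into one island per op.
void SplitAllIslands(mlir::ModuleOp module) {
  auto control_type = mlir::tf_executor::ControlType::get(module.getContext());
  llvm::SmallVector<mlir::tf_executor::IslandOp, 16> islands;
  module.walk([&](mlir::tf_executor::IslandOp island) { islands.push_back(island); });
  for (auto island : islands) mlir::TF::SplitIsland(island, control_type);
}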
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "tsl/platform/statusor.h" + +namespace mlir { +namespace mhlo { +namespace test { + +// Given a raw string, return a ModuleOp that can be used with the given +// MLIRContext. +absl::StatusOr> GetMlirModuleFromString( + absl::string_view module_string, MLIRContext* mlir_context); + +} // namespace test +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h new file mode 100644 index 00000000..c5c417e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -0,0 +1,128 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace mlir { +namespace mhlo { + +class Tf2XlaRewriterTestPeer; + +class Tf2XlaRewriter { + public: + static mlir::LogicalResult RewriteOp(mlir::Operation* op, + mlir::PatternRewriter& rewriter, + const std::string& device_type); + + private: + friend class Tf2XlaRewriterTestPeer; + + Tf2XlaRewriter(mlir::Operation* op, mlir::PatternRewriter& rewriter, + const std::string& device_type); + + ~Tf2XlaRewriter(); + + // Compiles the given Operation with XlaBuilder and imports the generated HLO + // via the HLO -> MHLO importer. 
+ absl::StatusOr CompileWithHloImporter( + tensorflow::OpKernelContext& op_context); + + // Import the given XlaComputation into the parent module. Returns the given + // generated function. + absl::StatusOr ImportXlaComputation( + xla::XlaComputation& computation); + + // Prepares OpKernelContext params common to all the ops. + // Emits an error on failure. + mlir::LogicalResult PrepareParams(); + + // Given the required_consts, it will fill the 3 output vectors with + // their respective data. + // Expressions: Output XLA expressions as required by the compiled kernel. + // Tensors: Vector of tensors that back the TensorValue inputs + // Inputs: Vector of inputs that are backed by tensors. + mlir::LogicalResult PrepareKernelInputs( + const llvm::SmallDenseSet& required_consts, + std::vector& expressions, + std::vector& tensors, + std::vector& inputs); + + mlir::LogicalResult VerifyOpResults(tensorflow::OpKernelContext& op_context); + mlir::LogicalResult GetKernelOutputs(tensorflow::OpKernelContext& op_context, + mhlo::TupleOp tuple_results, + llvm::SmallVector& outputs); + + // Given a translated function with a single return value, unpack the tuple + // results. + mlir::LogicalResult UnpackTupleResults(mhlo::TupleOp tuple_result, + llvm::SmallVector& outputs); + + // Tries to legalize the specified TensorFlow op, if supported. + // + // Emits an error and returns failure if an error is encountered during + // conversion. Note that success return value doesn't mean successful + // legalization. + mlir::LogicalResult LegalizeOp(); + + // Converts the given operand to expression of kind kConstant or kXlaOp. + // Emits a remark and returns expression of kind kInvalid on failure. + tensorflow::XlaExpression GetExprForOperand(mlir::Value operand, + mlir::Operation* op, + int64_t operand_index); + + mlir::Operation* op_; + std::string device_type_; + + mlir::PatternRewriter& rewriter_; + std::unique_ptr name_mapper_; + + tensorflow::XlaContext* context_; // Ref-counted. + + std::unique_ptr device_mgr_; + tensorflow::Device* device_; // Owned by device_mgr_; + std::unique_ptr step_container_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + tensorflow::OpKernelContext::Params params_; + + xla::XlaBuilder xla_builder_; +}; + +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/utils.h new file mode 100644 index 00000000..5dba4a4d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/utils.h @@ -0,0 +1,62 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
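Since Tf2XlaRewriter::RewriteOp is the class's only public entry point, a typical caller is a rewrite pattern; the sketch below wires it into a match-any pattern. The pattern name and the idea of passing the device type through the constructor are illustrative assumptions:

#include <string>
#include <utility>

#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h"

// Funnels every matched op through the TF2XLA-based rewriter for one device.
struct RewriteWithTf2Xla : public mlir::RewritePattern {
  RewriteWithTf2Xla(mlir::MLIRContext* ctx, std::string device_type)
      : mlir::RewritePattern(mlir::Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
        device_type_(std::move(device_type)) {}

  mlir::LogicalResult matchAndRewrite(
      mlir::Operation* op, mlir::PatternRewriter& rewriter) const override {
    return mlir::mhlo::Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
  }

  std::string device_type_;
};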
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_UTILS_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace mhlo { + +// Builds body for reduce op by using the template binary op as the +// reducer op. +template +void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { + OpBuilder::InsertionGuard guard(*builder); + Block* block = builder->createBlock(body); + + // Block arguments are scalars of the given element type. + Type type = RankedTensorType::get(/*shape=*/{}, element_type); + Location loc = body->getLoc(); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = + builder->create(loc, block->getArgument(0), block->getArgument(1)); + builder->create(loc, reducer.getResult()); +} + +ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder); + +ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, OpBuilder* builder); + +// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); +DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, + Builder* builder); + +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h new file mode 100644 index 00000000..1711e039 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_XLA_LEGALIZE_TARGETS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_XLA_LEGALIZE_TARGETS_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace mhlo { + +// Returns a ConversionTarget that includes default legalized MLIR dialects +// for conversion to XLA. +// If legalize_chlo is true, the resulting conversion target cannot have CHLO. 
+mlir::ConversionTarget GetDefaultLegalConversionTargets( + MLIRContext& mlir_context, bool legalize_chlo); + +} // namespace mhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_XLA_LEGALIZE_TARGETS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h new file mode 100644 index 00000000..73d241b8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ + +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tfr { + +extern const char* const kTFRLibEnv; + +using tsl::StatusOr; + +// An wrapper for all the objects used to decompose a module (graph mode) and +// node_def (eager mode). Note that this class owns the decomposition library. +class TFRDecomposeContext { + public: + // The entry function to get a decompose context. All the required passes have + // been initialized. + static absl::StatusOr> Get( + mlir::MLIRContext* mlir_ctx); + + // Constructor of the decompose context. To share the decompose library, the + // whole decompose TFR function library is loaded. + explicit TFRDecomposeContext(mlir::ModuleOp tfr_module); + + // Constructs the decompose context from the tfr text module and the mlir + // context. The tfr text module is added to the mlir context. + static std::unique_ptr GetFromText( + StringPiece tfr_raw_text, mlir::MLIRContext* mlir_ctx); + + // Decomposes the op in the NodeDef to a set of primitive ops according to the + // decompose library in the context. Wrap the decomposed result in a + // FunctionDef. + absl::StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name); + + // Runs the decompose passes on the user_module. + absl::Status DecomposeGraph(mlir::ModuleOp user_module); + + // Erases the tfr_module created. + void Destroy(); + + private: + mlir::ModuleOp tfr_module_; + mlir::PassManager pm_; + + GraphExportConfig export_confs_; +}; + +// Decomposes the NodeDef to a set of primitive ops according to the decompose +// library loaded. Wrap the decomposed result in a FunctionDef. 
+absl::StatusOr ExpandNode(const NodeDef& node_def, + StringPiece func_name); + +// Decomposes the ops in the ModuleOp to a set of primitive ops according to +// decompose library in the context. +absl::Status DecomposeGraph(mlir::ModuleOp user_module); + +} // namespace tfr +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h new file mode 100644 index 00000000..2066d7ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/ir/tfr_ops.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ + +#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +constexpr char kAttrArgumentNameAttr[] = "tfr.name"; +constexpr char kAttrArgumentDefaultAttr[] = "tfr.default"; +constexpr char kAttrArgumentTypeAttr[] = "tfr.type"; + +class TFRDialect : public Dialect { + public: + explicit TFRDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "tfr"; } + + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; + + // Parse a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + // Prints a type registered to this dialect. 
+ void printType(Type ty, DialectAsmPrinter &os) const override; +}; + +} // namespace TFR +} // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/ir/tfr_types.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/ir/tfr_types.h new file mode 100644 index 00000000..e0e24f4a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/ir/tfr_types.h @@ -0,0 +1,126 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ + +#include +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/TypeSupport.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +class TFRType : public Type { + public: + using Type::Type; + + static bool classof(Type type); +}; + +namespace detail { + +struct TFRTypeStorage final + : public TypeStorage, + public llvm::TrailingObjects { + using KeyTy = ArrayRef; + + explicit TFRTypeStorage(unsigned num_attrs) : num_attrs(num_attrs) {} + + static TFRTypeStorage* construct(TypeStorageAllocator& allocator, KeyTy key) { + // Allocate a new storage instance. + auto byteSize = TFRTypeStorage::totalSizeToAlloc(key.size()); + auto rawMem = allocator.allocate(byteSize, alignof(TFRTypeStorage)); + auto result = ::new (rawMem) TFRTypeStorage(key.size()); + + // Copy in the string attributes into the trailing storage. 
+ std::uninitialized_copy(key.begin(), key.end(), + result->getTrailingObjects()); + return result; + } + + bool operator==(const KeyTy& attrs) const { return attrs == GetAttrs(); } + + KeyTy GetAttrs() const { + return {getTrailingObjects(), num_attrs}; + } + + unsigned num_attrs; +}; + +template +class TFRTypeImpl : public Type::TypeBase { + public: + using Base = Type::TypeBase; + using TFRBase = TFRTypeImpl; + using Base::Base; + + static Derived get(ArrayRef attrs, MLIRContext* context) { + return Base::get(context, attrs); + } + + static Derived getChecked(ArrayRef attrs, Location loc) { + return Base::getChecked(loc, loc.getContext(), attrs); + } + static Derived getChecked(function_ref emitError, + MLIRContext* context, ArrayRef attrs) { + return Base::getChecked(emitError, context, attrs); + } + + static Derived get(MLIRContext* context) { return get({}, context); } + + // TODO(fengliuai): fix the implementation + static LogicalResult verify(function_ref emitError, + ArrayRef attrs) { + return success(); + } + + ArrayRef getAttrKeys() { return Base::getImpl()->GetAttrs(); } +}; +} // namespace detail + +class TFRTensorType : public detail::TFRTypeImpl { + public: + using TFRBase::TFRBase; + static constexpr StringLiteral name = "tfr.tensor"; + static std::string getTypeName() { return "TFRTensorType"; } +}; + +class TFRTensorListType : public detail::TFRTypeImpl { + public: + using TFRBase::TFRBase; + static constexpr StringLiteral name = "tfr.tensor_list"; + static std::string getTypeName() { return "TFRTensorListType"; } +}; + +class TFRAttrType : public Type::TypeBase { + public: + using Base::Base; + static constexpr StringLiteral name = "tfr.attr"; + static std::string getTypeName() { return "TFRAttrType"; } +}; + +} // namespace TFR +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_IR_TFR_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/passes/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/passes/passes.h new file mode 100644 index 00000000..00bf1187 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/passes/passes.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace mlir { +namespace TFR { + +// Scans the func op and adds all the canonicalization patterns of the ops +// except the tf ops, inside the function. +void populateCanonicalizationPatterns(func::FuncOp func, + RewritePatternSet &patterns); + +// Decompose ops. 
+std::unique_ptr> CreateDecomposeTFOpsPass( + std::optional tfr_module = std::nullopt); + +// Rewrites quantized operands and results with their storage types. +// This pass should be run at module level after decomposition, if there are +// quantized operands or results. +std::unique_ptr> CreateRewriteQuantizedIOPass(); + +// Raise to TF ops. +std::unique_ptr> CreateRaiseToTFOpsPass( + std::optional tfr_module = std::nullopt, + bool materialize_derived_attrs = false); + +} // namespace TFR +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_PASSES_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/utils/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/utils/utils.h new file mode 100644 index 00000000..911015ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfr/utils/utils.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_ + +#include + +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" + +namespace mlir { +namespace TFR { + +// This is a hardcoded rule for mapping a TF op name to the corresponding +// TFR function name. Examples: +// tf.Pack => tf__pack +// tf.ConcatV2 => tf__concat_v2 +// TODO(fengliuai): move to an util file. +std::string GetComposeFuncName(StringRef tf_op_name); + +// This is a hardcoded rule for mapping a TFR function op name to the +// corresponding TF opname. Examples: +// tf__pack -> tf.Pack +// tf__concat_v2 => tf.ConcatV2 +std::string GetTFOpName(StringRef compose_func_name); + +// Validate the attributes of 'src' is either contained in the registered +// attribute sets or in the allowed list. +LogicalResult ValidateAttrs(Operation* src, const StringSet<>& registered); + +// Copies all the allowed attributes in 'src' to 'dst'. The copy failed if the +// 'dst' has the attribute. Return a failure if there are any attributes are not +// allowed and also unregistered. +LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst, + const StringSet<>& registered); + +// Copies all the allowed attributes in 'src' to 'dst'. FlatSymbolRefAttr is +// excluded. +LogicalResult CopyNonSymbolRefAttrs(CallOp src, Operation* dst); + +// Propagates all the attributes in 'src' to the operations between 'begin' and +// 'end'. Operation 'end' is excluded. 
+void PropagateAttrsToOperations(CallOp src, Block::iterator begin, + Block::iterator end); + +} // namespace TFR +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TFR_UTILS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h new file mode 100644 index 00000000..b27b6aa9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h @@ -0,0 +1,98 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/fallback/op_cost_map.pb.h" + +namespace tensorflow { +namespace tfrt_compiler { + +// Analyze costs for tensorflow operations. +// +// The current heuristic used is quite simple, which is to calculate the total +// size of input tensors. The exception is that ops whose cost is irrelevant to +// input sizes, such as tf.Shape and tf.Reshape, are whitelisted to have cheap +// cost. This cost analysis is expected to be used conservatively (eg. use a low +// threshold to decide whether a cost is cheap or expensive), as it might not be +// accurate in some cases. 
+// +class CostAnalysis { + public: + explicit CostAnalysis( + mlir::func::FuncOp func_op, + const tfrt_stub::CostRecorder* cost_recorder = nullptr) { + cost_recorder_ = cost_recorder; + AnalyzeArguments(func_op); + AnalyzeBlock(&func_op.front()); + } + + int64_t GetCost(mlir::Operation* op) const; + + private: + void AnalyzeArguments(mlir::func::FuncOp func_op); + void AnalyzeBlock(mlir::Block* block); + void EvaluateCost(mlir::Operation* op); + + int64_t max_arg_size_ = 1; + llvm::DenseMap cost_map_; + const tfrt_stub::CostRecorder* cost_recorder_; +}; + +struct CostContext { + int64_t default_unranked_tensor_size; +}; + +using CostFunction = + std::function; + +void RegisterCostFunction(absl::string_view op_name, + CostFunction cost_function); + +template +void RegisterCostFunction(F f) { + RegisterCostFunction( + OpType::getOperationName().str(), + [f = std::move(f)](const CostContext& context, mlir::Operation* op) { + return f(context, llvm::cast(op)); + }); +} + +template +struct CostFunctionRegistration { + explicit CostFunctionRegistration( + std::function cost_function) { + RegisterCostFunction(std::move(cost_function)); + } +}; + +bool HasCostFunctionRegistered(absl::string_view op_name); + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_COST_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h new file mode 100644 index 00000000..4f8501b5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h @@ -0,0 +1,52 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_TENSOR_ARRAY_SIDE_EFFECT_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_TENSOR_ARRAY_SIDE_EFFECT_ANALYSIS_H_ + +#include "llvm/ADT/DenseSet.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace tensorflow { +namespace tfrt_compiler { + +// Return true if it is a TensorArrayOp, eg. TensorArrayV3Op. +bool IsTensorArrayOp(mlir::Operation* op); + +// This class provides utilities for analyzing side effects for TensorArray ops +// in the graph. mlir::TF::SideEffectAnalysis currently produces suboptimal +// side-effect analysis for TensorArray ops. On the other hand, control +// dependencies are already sorted out for TensorArray ops in the original TF +// graph. Each TensorArray op will take or produce a `flow` value and they are +// already properly chained in the origninal TF graph. 
+class TensorArraySideEffectAnalysis { + public: + explicit TensorArraySideEffectAnalysis(mlir::ModuleOp module); + + // Return if the function contains only non-side-effecting ops or TensorArray + // ops. + bool HasAtMostTensorArrayEffect(mlir::func::FuncOp func_op) const { + return set_.count(func_op) > 0; + } + + private: + llvm::DenseSet set_; +}; + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_ANALYSIS_TENSOR_ARRAY_SIDE_EFFECT_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/backend_compiler.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/backend_compiler.h new file mode 100644 index 00000000..7167c8ef --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/backend_compiler.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_BACKEND_COMPILER_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_BACKEND_COMPILER_H_ + +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "tensorflow/core/tfrt/runtime/runtime.h" + +namespace tensorflow { + +class BackendCompiler { + public: + virtual ~BackendCompiler(); + + virtual void GetDependentDialects(mlir::DialectRegistry& registry) const {} + + // Compile the `module` in TF dialect. The result module should be also in TF + // dialect. + virtual absl::Status CompileTensorflow( + tfrt_stub::ModelRuntimeContext& model_context, + mlir::ModuleOp module) const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_BACKEND_COMPILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/constants.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/constants.h new file mode 100644 index 00000000..ed6e773c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/constants.h @@ -0,0 +1,28 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_CONSTANTS_H_ + +namespace tensorflow { +namespace tfrt_compiler { + +// Use __ prefix to indicate this is internal attribute. 
+inline constexpr char kOpKeyAttrName[] = "__op_key"; + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/function/function.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/function/function.h new file mode 100644 index 00000000..8d09f8cb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/function/function.h @@ -0,0 +1,83 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_FUNCTION_FUNCTION_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_FUNCTION_FUNCTION_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/platform/status.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime + +namespace tfrt { +class CoreRuntime; +} + +namespace mlir { +class ModuleOp; +} + +namespace tensorflow { + +struct TfrtFunctionCompileOptions : public TfrtCompileOptions { + // Currently only SavedModel API inference uses the tpu_fuse_ops option + TfrtFunctionCompileOptions() { + tpu_fuse_ops = false; + // Currently grappler is not correctly applied in the eager execution of TF + // functions, as it may sometimes remove arguments and results. + enable_grappler = false; + } + + // If true, use ServingCoreSelector to pick TPU core. Otherwise, obtain core + // location from assigned device name. + // Currently we don't use core_selector for training use cases. + bool tpu_use_core_selector = false; + + // If true, use BundledTransferToTpuOp to transfer variables and input tensors + // to TPU. + bool tpu_use_bundled_transfer = false; + + // If true, lower an TF op that's placed on TPU device to be executed with + // tfrt_fallback.execute. + // Currently for training use cases we need to lower the op to corert.execute + // to execute with TPU OpHandler, and with TFRT's native implementation. + // TODO(b/188940204): remove this config after we clear up the TPU variable + // implementation. + bool tpu_lower_to_fallback = false; + // If true, transfer the result of TPUExecuteOp from TPU to host. + // Currently for training and Python bulk inference use cases, we don't need + // to proactively transfer the result to host since the consumer op (or + // function) of the result may still be on TPU. + // TODO(b/194081364): remove this option once we unify servo TPU serving + // result transfer behavior. + bool tpu_transfer_result_to_host = false; +}; + +// Compile MLIR generated by tf.function in TF dialect into BEF. 
+absl::Status CompileTFMLIRToBEF(const TfrtFunctionCompileOptions& options, + mlir::ModuleOp module, + tfrt::BefBuffer* bef_buffer); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_FUNCTION_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h new file mode 100644 index 00000000..ff06f069 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h @@ -0,0 +1,41 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_GPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_GPU_OPS_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project + +using namespace mlir; // NOLINT + +namespace tfrt { +namespace gpu { + +// Dialect for TFRT GPU operations. +class GpuRuntimeDialect : public Dialect { + public: + explicit GpuRuntimeDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "gpurt"; } +}; + +} // namespace gpu +} // namespace tfrt + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/gpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_GPU_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h new file mode 100644 index 00000000..644de261 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h @@ -0,0 +1,62 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
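A minimal sketch of the BEF compilation entry point declared above; the wrapper name CompileForTfrt is hypothetical, and the default-constructed options simply carry the eager-function defaults described in TfrtFunctionCompileOptions:

#include "absl/status/status.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/function/function.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime

// Compiles a TF-dialect module into a BEF buffer using the function-compile
// defaults (grappler off, TPU op fusion off).
absl::Status CompileForTfrt(mlir::ModuleOp module, tfrt::BefBuffer* bef) {
  tensorflow::TfrtFunctionCompileOptions options;
  return tensorflow::CompileTFMLIRToBEF(options, module, bef);
}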
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_DIALECT_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_DIALECT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project + +namespace mlrt { +namespace compiler { + +class MlrtDialect : public mlir::Dialect { + public: + explicit MlrtDialect(mlir::MLIRContext *context); + static llvm::StringRef getDialectNamespace() { return "mlrt"; } + + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + void printType(mlir::Type type, mlir::DialectAsmPrinter &os) const override; +}; + +// The MLIR type represents a C++ mlrt::Future. +class FutureType + : public mlir::Type::TypeBase { + public: + using Base::Base; + static constexpr mlir::StringLiteral name = "mlrt.compiler.future"; +}; + +// The MLIR type represents a C++ mlrt::Promise. +class PromiseType + : public mlir::Type::TypeBase { + public: + using Base::Base; + static constexpr mlir::StringLiteral name = "mlrt.compiler.promise"; +}; + +// The MLIR type represents a C++ mlrt::AsyncHandle. +class AsyncHandleType : public mlir::Type::TypeBase { + public: + using Base::Base; + static constexpr mlir::StringLiteral name = "mlrt.compiler.async_handle"; +}; + +} // namespace compiler +} // namespace mlrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_DIALECT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h new file mode 100644 index 00000000..e3922c6e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_OPS_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_MLRT_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h new file mode 100644 index 00000000..a542373e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_OPS_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tfrt/compiler/opdefs/tfrt_op_interfaces.h" // from @tf_runtime +#include "tfrt/compiler/opdefs/tfrt_traits.h" // from @tf_runtime + +namespace tensorflow { +namespace tf_mlrt { + +class TensorflowMlrtDialect : public mlir::Dialect { + public: + explicit TensorflowMlrtDialect(mlir::MLIRContext *context); + static llvm::StringRef getDialectNamespace() { return "tf_mlrt"; } + + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + void printType(mlir::Type type, mlir::DialectAsmPrinter &os) const override; +}; + +// The MLIR type represents a tensorflow::Tensor. +class TFTensorType + : public mlir::Type::TypeBase { + public: + using Base::Base; + static constexpr mlir::StringLiteral name = "tensorflow.tf_mlrt.tf_tensor"; +}; + +// The MLIR type represents a tensorflow::Device* +class TFDeviceType + : public mlir::Type::TypeBase { + public: + using Base::Base; + static constexpr mlir::StringLiteral name = "tensorflow.tf_mlirt.tf_device"; +}; + +} // namespace tf_mlrt +} // namespace tensorflow + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h new file mode 100644 index 00000000..a428488d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h @@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
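A small sketch of how the tf_mlrt types declared above are typically queried; the helper name IsTfMlrtTensor is an assumption for illustration:

#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.h"

// True if `value` carries a fallback tensorflow::Tensor in the tf_mlrt
// dialect's type system.
bool IsTfMlrtTensor(mlir::Value value) {
  return mlir::isa<tensorflow::tf_mlrt::TFTensorType>(value.getType());
}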
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_TPU_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_TPU_OPS_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +namespace tensorflow { +namespace tf_mlrt_tpu { + +class TensorflowMlrtTpuDialect : public mlir::Dialect { + public: + explicit TensorflowMlrtTpuDialect(mlir::MLIRContext *context); + static llvm::StringRef getDialectNamespace() { return "tf_mlrt_tpu"; } +}; + +} // namespace tf_mlrt_tpu +} // namespace tensorflow + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_tpu_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_MLRT_TF_MLRT_TPU_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h new file mode 100644 index 00000000..24fa464f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +using namespace mlir; // NOLINT + +namespace tfrt { +namespace fallback { + +// Dialect for fallback operations. +class FallbackDialect : public Dialect { + public: + explicit FallbackDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "tfrt_fallback"; } + + Type parseType(DialectAsmParser &parser) const override; + void printType(Type type, DialectAsmPrinter &os) const override; +}; + +// The MLIR type represents a tensorflow::Tensor. +class TFTensorType : public Type::TypeBase { + public: + using Base::Base; + static constexpr StringLiteral name = "tfrt.tf_tensor"; +}; + +// The MLIR type represents a tensorflow::Allocator. 
+class TFAllocatorType + : public Type::TypeBase { + public: + using Base::Base; + static constexpr StringLiteral name = "tfrt.tf_allocator"; +}; + +} // namespace fallback +} // namespace tfrt + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h new file mode 100644 index 00000000..eab44d1d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_ASYNC_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_ASYNC_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tfrt/compiler/opdefs/tfrt_op_interfaces.h" // from @tf_runtime +#include "tfrt/compiler/opdefs/tfrt_traits.h" // from @tf_runtime +#include "tfrt/core_runtime/opdefs/traits.h" // from @tf_runtime + +using namespace mlir; // NOLINT + +namespace tfrt { +namespace fallback_async { + +// Dialect for fallback async operations. +class FallbackAsyncDialect : public Dialect { + public: + explicit FallbackAsyncDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "tfrt_fallback_async"; } +}; + +} // namespace fallback_async +} // namespace tfrt + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_ASYNC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h new file mode 100644 index 00000000..0cddb101 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h @@ -0,0 +1,127 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_COMMON_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_COMMON_H_ + +#include <utility> + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime + +namespace tfrt { +namespace fallback_common { + +template <typename OpTy> +mlir::LogicalResult VerifyExecuteOpCommon(OpTy op) { + auto op_attr_array = op.getOpAttrs().getValue(); + for (auto op_attr : op_attr_array) { + auto key_value = mlir::dyn_cast<mlir::ArrayAttr>(op_attr); + if (!key_value || key_value.getValue().size() != 2 || + !mlir::isa<mlir::StringAttr>(key_value.getValue()[0])) + return op.emitOpError() << "each op_attr should be a key-value pair, " + "where the key is a string"; + } + return mlir::success(); +} + +template <typename OpTy> +mlir::LogicalResult VerifyFallbackExecuteOp(OpTy op) { + auto result = VerifyExecuteOpCommon(op); + if (failed(result)) return result; + + // Verify function attributes. + auto op_func_attr_array = op.getOpFuncAttrs().getValue(); + for (auto op_attr : op_func_attr_array) { + auto key_value = mlir::dyn_cast<mlir::ArrayAttr>(op_attr); + if (!key_value || key_value.getValue().size() != 2 || + !mlir::isa<mlir::StringAttr>(key_value.getValue()[0]) || + !mlir::isa<mlir::StringAttr>(key_value.getValue()[1])) + return op.emitOpError() << "each op_func_attr should be a key-value " + "pair, where both the key and the value are " + "strings"; + } + return mlir::success(); +} + +template <typename OpTy> +void PrintExecuteOpFuncAttribute(mlir::OpAsmPrinter &p, OpTy op) { + auto op_func_attrs = op.getOpFuncAttrs(); + if (!op_func_attrs.empty()) { + auto print_key_value = [&](mlir::Attribute attr) { + auto key_value = mlir::cast<mlir::ArrayAttr>(attr).getValue(); + auto key = key_value[0]; + auto value = key_value[1]; + + p << mlir::cast<mlir::StringAttr>(key).getValue(); + p << " = "; + p << value; + }; + + auto op_func_attr_array = op_func_attrs.getValue(); + p << " {"; + llvm::interleaveComma(op_func_attr_array, p, print_key_value); + p << '}'; + } +} + +template <typename OpTy> +void PrintExecuteOpCommon(mlir::OpAsmPrinter &p, OpTy op) { + auto op_attrs = op.getOpAttrs(); + if (!op_attrs.empty()) { + auto print_key_value = [&](mlir::Attribute attr) { + auto key_value = mlir::cast<mlir::ArrayAttr>(attr).getValue(); + auto key = key_value[0]; + auto value = key_value[1]; + + p << mlir::cast<mlir::StringAttr>(key).getValue(); + p << " = "; + p << value; + }; + + auto op_attr_array = op_attrs.getValue(); + p << " {"; + llvm::interleaveComma(op_attr_array, p, print_key_value); + p << '}'; + } +} + +void GetExecuteOpAttrsCommon( + mlir::MLIRContext *context, llvm::ArrayRef<mlir::Attribute> op_attr_array, + llvm::SmallVectorImpl<std::pair<llvm::StringRef, mlir::Attribute>> + *op_attrs); + +struct ParseExecuteOpOptions { + bool has_chain = false; + bool has_key = false; + bool has_device = false; + bool has_func_attr = false; + bool has_cost = false; + bool has_op_name = true; + bool has_symbol_ref = false; +}; + +mlir::ParseResult ParseExecuteOpCommon(mlir::OpAsmParser &parser, + mlir::Builder &builder, + mlir::OperationState &result, + mlir::Type tensor_type, + const ParseExecuteOpOptions &options); +} // namespace fallback_common +} // namespace tfrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h new
file mode 100644 index 00000000..78e99830 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_SYNC_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_SYNC_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tfrt/core_runtime/opdefs/traits.h" // from @tf_runtime +#include "tfrt/tensor/opdefs/tensor.h" // from @tf_runtime + +using namespace mlir; // NOLINT + +namespace tfrt { +namespace fallback_sync { + +// Dialect for fallback operations. +class FallbackSyncDialect : public Dialect { + public: + explicit FallbackSyncDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "tfrt_fallback_sync"; } +}; + +} // namespace fallback_sync +} // namespace tfrt + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_SYNC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_util.h new file mode 100644 index 00000000..93235ec6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_util.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
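As an aside on the helpers in tfrt_fallback_common.h above: `ParseExecuteOpOptions` toggles which optional segments `ParseExecuteOpCommon` expects when parsing an execute op's custom assembly. The following is a minimal sketch of how a dialect's parser hook could call it; the function name `parseMyExecuteOp` and the choice of the fallback `TFTensorType` as the tensor type are illustrative assumptions, not part of these headers.

// Illustrative sketch only: wires the shared parser helper into a
// hypothetical fallback execute op.
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h"

static mlir::ParseResult parseMyExecuteOp(mlir::OpAsmParser &parser,
                                          mlir::OperationState &result) {
  tfrt::fallback_common::ParseExecuteOpOptions options;
  options.has_key = true;        // expect an integer op key
  options.has_device = true;     // expect a device string
  options.has_func_attr = true;  // expect function-valued attributes

  mlir::Builder builder(result.getContext());
  // Assumed tensor type for the operands/results of the op being parsed.
  mlir::Type tensor_type = builder.getType<tfrt::fallback::TFTensorType>();
  return tfrt::fallback_common::ParseExecuteOpCommon(parser, builder, result,
                                                     tensor_type, options);
}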
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_UTIL_H_ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace tfrt { +namespace fallback_async { + +bool IsArgConsumedByFallback(mlir::func::FuncOp func, int arg_index); + +void ForEachArgConsumedByFallback( + mlir::func::FuncOp func, llvm::function_ref action); + +void ForEachArgConsumedByFallback( + mlir::ModuleOp module, + llvm::function_ref action); + +} // namespace fallback_async +} // namespace tfrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_IR_TFRT_FALLBACK_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h new file mode 100644 index 00000000..41c2b818 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h @@ -0,0 +1,63 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/bef_executor/bef_file.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime +#include "tfrt/support/ref_count.h" // from @tf_runtime + +namespace tensorflow { + +class RuntimeFallbackExecutor { + public: + explicit RuntimeFallbackExecutor(int64_t num_threads); + + // Prepare() needs to be called once before calling Execute(). It sets up all + // things necessary to execute the given 'mlir_input' with the fallback to + // tensorflow. + void Prepare(llvm::StringRef mlir_input); + + // Execute() can be called several times after the call to Prepare() (e.g. for + // benchmarking). 
+ llvm::SmallVector Execute(llvm::StringRef function_name, + llvm::ArrayRef arguments); + + private: + void RunTfrtInitializer(); + + std::unique_ptr intra_op_; + std::unique_ptr host_context_; + tfrt::ResourceContext resource_context_; + std::unique_ptr exec_ctx_; + tfrt::BefBuffer bef_buffer_; + tfrt::RCReference bef_file_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h new file mode 100644 index 00000000..9d77a1a7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h @@ -0,0 +1,43 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the Runtime Fallback dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tfrt/tensor/opdefs/tensor.h" // from @tf_runtime + +namespace mlir { +namespace tfd { + +// Dialect for TFRT delegate operations. +class RuntimeFallbackDialect : public Dialect { + public: + explicit RuntimeFallbackDialect(MLIRContext* context); + static StringRef getDialectNamespace() { return "tfd"; } +}; + +} // namespace tfd +} // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tfrt/runtime_fallback_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h new file mode 100644 index 00000000..087d50de --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h @@ -0,0 +1,80 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
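For orientation, `RuntimeFallbackExecutor` above is a prepare-once, execute-many wrapper. A sketch of the intended call pattern follows; it assumes the elided element types in `Execute` are `tensorflow::Tensor`, and the function name "main" is a placeholder.

#include <string>
#include <vector>

#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h"
#include "tensorflow/core/framework/tensor.h"

void RunOnce(const std::string &mlir_module_text,
             const std::vector<tensorflow::Tensor> &inputs) {
  tensorflow::RuntimeFallbackExecutor executor(/*num_threads=*/4);
  // Prepare() compiles the module and must be called once before Execute().
  executor.Prepare(mlir_module_text);
  // Execute() may then be called repeatedly, e.g. inside a benchmark loop.
  auto outputs = executor.Execute("main", inputs);
  (void)outputs;
}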
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_SAVED_MODEL_SAVED_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_SAVED_MODEL_SAVED_MODEL_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime + +namespace tfrt { +class CoreRuntime; +} + +namespace mlir { +class ModuleOp; +} + +namespace tensorflow { + +// TFRTSavedModelSignatureInfo contains the metadata for a signature in the +// savedmodel such as function name, inputs/outputs' names and types. This can +// be used to retrieve these information in a tf_saved_model module. +struct TFRTSavedModelSignatureInfo { + llvm::StringRef func_name; + + // The following are metadata for inputs. + llvm::ArrayRef input_names; + llvm::ArrayRef< + std::pair> + input_specs; + llvm::ArrayRef input_devices; + + // The following are metadata for outputs. + llvm::ArrayRef output_names; + llvm::ArrayRef< + std::pair> + output_specs; + + // The following are metadata for bound_inputs, ie. captures. + llvm::ArrayRef bound_inputs; +}; + +// Apply `map_fn` on every exported function in the module with the +// corresponding signature metadata populated in TFRTSavedModelSignatureInfo for +// the function. +absl::Status MapFunctionSignaturesFromTFSavedModelMLIR( + mlir::ModuleOp module, + llvm::function_ref map_fn); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_SAVED_MODEL_SAVED_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/tfrt_fallback_registration.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/tfrt_fallback_registration.h new file mode 100644 index 00000000..65f1554e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/tfrt_fallback_registration.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements TFRuntimeFallback tensor conversion function for +// converting to host tensor. + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TFRT_FALLBACK_REGISTRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TFRT_FALLBACK_REGISTRATION_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project + +namespace tensorflow { +namespace tfd { + +// Register conversion functions for TFRuntimeFallbackTensors. 
+void RegisterTfrtFallbackDialect(mlir::DialectRegistry ®istry); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TFRT_FALLBACK_REGISTRATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.h new file mode 100644 index 00000000..791cb346 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.h @@ -0,0 +1,42 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_ATTR_LOWERING_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_ATTR_LOWERING_UTILS_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project + +namespace tensorflow { + +// TODO(chky): attributes "_output_shapes" should be removed by any tool that +// generates TF MLIR dialect, as they are not used by CoreRuntime. Remove this +// filtering logic once unused attributes are cleaned up in the upper layer. +bool IsUnusedTfrtAttribute(llvm::StringRef name); + +// Create a single attribute that contains the named attribute lists. It is an +// array of pairs. The key must be a string attribute, and the value can be +// any attribute that is supported by CoreRuntime. +mlir::ArrayAttr CreateTfrtOpAttrs(llvm::ArrayRef attrs, + mlir::Builder& builder); + +bool IsSupportedTfrtNumericDType(mlir::Type type); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_ATTR_LOWERING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h new file mode 100644 index 00000000..be212e44 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h @@ -0,0 +1,175 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
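A small sketch of how the attribute-lowering helpers in attr_lowering_utils.h above are typically combined, assuming the elided `ArrayRef` element type in `CreateTfrtOpAttrs` is `mlir::NamedAttribute`; the wrapper function name is illustrative.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.h"

// Drops attributes that CoreRuntime does not consume and packs the rest into
// the array-of-pairs form described above.
static mlir::ArrayAttr LowerOpAttributes(mlir::Operation *op,
                                         mlir::Builder &builder) {
  llvm::SmallVector<mlir::NamedAttribute, 4> kept;
  for (mlir::NamedAttribute attr : op->getAttrs()) {
    if (tensorflow::IsUnusedTfrtAttribute(attr.getName().getValue())) continue;
    kept.push_back(attr);
  }
  return tensorflow::CreateTfrtOpAttrs(kept, builder);
}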
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ + +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime +#include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime +#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime + +namespace tensorflow { + +struct ParseDeviceNameResult { + std::string device_type; + std::string device_name; + std::string op_handler_name; +}; + +// A helper class for converting CoreRT types and attributes. +class CoreRTConverter : public mlir::TypeConverter { + public: + CoreRTConverter( + mlir::MLIRContext *context, + const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis); + // Materialize all derived attributes. Note that this is only needed by + // CoreRT ops and fallback ops. + void MaterializeDerivedAttributes(mlir::Operation *op); + + // Similar to CreateOpAttrs, create a single attribute that contains the + // named attribute lists, which is an array of pairs, with keys and values + // both being string attributes. The values represent function names. + // This method also populates a vector of attribute keys to be removed. + // If `use_mlir_func_name` is true, the function name given by MLIR will be + // used, which could be different from the original function name in the graph + // function library. This is used when the original function has been changed + // by lowering passes, and hence it needs to be exported to function library + // for runtime to use. + mlir::ArrayAttr CreateOpFuncAttrs( + const mlir::SymbolTable &symbol_table, + llvm::ArrayRef attrs, + llvm::SmallVector *func_attr_keys, + bool use_mlir_func_name = false); + + // Parse the device name of `op` to TFRT's device name. For example, "/CPU:0" + // will be parsed as "cpu". Return None if no device is assigned. + std::optional ParseDeviceName( + llvm::StringRef device_name) const; + std::optional ParseDeviceName( + mlir::Operation *op) const; + + // Convert the device name in a TF op to a op_handler value produced by the + // corresponding GetOpHandler in the current block. If there does not exist + // one, insert a GetOpHandler to the beginning of the block and return the + // device value. + mlir::Value ConvertOpHandler(mlir::Operation *op, llvm::StringRef device_name, + mlir::ConversionPatternRewriter *rewriter); + + // Get a DistributedContext value to be used by the given op. The + // DistributedContext value should be shared by all operations in the body + // of the same FuncOp. If there does not exist one, return a null Value. + mlir::Value GetDistributedContext(mlir::Operation *op, + mlir::ConversionPatternRewriter *rewriter); + + // Get a RemoteChainManager value to be used by the given op. 
The + // RemoteChainManager value should be shared by all operations in the body + // of the same FuncOp. If there does not exist one, return a null Value. + mlir::Value GetRemoteChainManager(mlir::Operation *op, + mlir::ConversionPatternRewriter *rewriter); + + // Get a TaskHandle value with the given task name. If the TaskHandle value + // has already been created for the given task name within the same FuncOp, + // return this TaskHandle value. Otherwise, return a null Value. + mlir::Value GetTaskHandle(mlir::Operation *op, StringRef task_name, + mlir::ConversionPatternRewriter *rewriter); + + // Any local operation which uses any result of the `op` should depend on the + // given `chain`. + void RegisterLocalSideEffectChain(mlir::Operation *op, mlir::Value chain) { + local_side_effect_chains_[op] = chain; + } + + // Return a local chain for side effects for `op`. If there are multiple + // chains, a merge_chains kernel will be inserted and the merged chain will be + // returned. + mlir::Value GetLocalSideEffectChain( + mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter); + + mlir::Type op_handler_type() { + return builder_.getType<::tfrt::corert::OpHandlerType>(); + } + + mlir::Type tensor_handle_type() { + return builder_.getType<::tfrt::corert::TensorHandleType>(); + } + + mlir::Type chain_type() { + return builder_.getType<::tfrt::compiler::ChainType>(); + } + + mlir::Builder &builder() { return builder_; } + + private: + // TODO(chky): attributes "_output_shapes" should be removed by any tool that + // generates TF MLIR dialect, as they are not used by CoreRuntime. Remove this + // filtering logic once unused attributes are cleaned up in the upper layer. + bool IsUnusedAttribute(llvm::StringRef name) const { + // NOTE: attributes "f.*" are function attribute related and + // are added during importing graph to MLIR TF Executor dialect. These + // attributes are not actually used by TF ops with function attributes. + // TODO(b/180399811): Re-evaluate the usage of these attributes. + static const char *const kUnusedAttributes[] = { + "_output_shapes", + "result_segment_sizes", + "operand_segment_sizes", + }; + + for (auto attr : kUnusedAttributes) { + if (name == attr) { + return true; + } + } + + return name.contains("f."); + } + + // Returns the converted attribute in TFRT dialect. If the conversion fails, + // returns a null attribute instead. + mlir::Attribute ConvertAttribute(mlir::Attribute attr); + + mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr); + + mlir::Builder builder_; + + const mlir::TF::SideEffectAnalysis::Info &side_effect_analysis_; + + llvm::DenseMap local_side_effect_chains_; + llvm::DenseMap distributed_context_by_func_; + llvm::DenseMap remote_chain_mgr_by_func_; + llvm::DenseMap> + task_handles_by_func_; + llvm::StringMap op_handler_by_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_CORERT_CONVERTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h new file mode 100644 index 00000000..c1c1d42a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h @@ -0,0 +1,96 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
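As a usage note for `CoreRTConverter` above, device resolution goes through `ParseDeviceName`. A short sketch, assuming the elided optional payload is `ParseDeviceNameResult` and taking the CPU fallback value from the comment on `ParseDeviceName` ("/CPU:0" parses to "cpu"); the wrapper name is illustrative.

#include <string>

#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"

// Returns the TFRT op handler name for the device assigned to `op`,
// defaulting to the CPU handler when no device is assigned.
static std::string GetOpHandlerName(
    const tensorflow::CoreRTConverter &converter, mlir::Operation *op) {
  auto parsed = converter.ParseDeviceName(op);
  if (!parsed) return "cpu";
  return parsed->op_handler_name;
}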
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace tensorflow { +namespace tfrt_compiler { + +inline llvm::StringRef GetDefaultCpuDeviceName() { + static constexpr char kCpuDeviceName[] = + "/job:localhost/replica:0/task:0/device:CPU:0"; + return kCpuDeviceName; +} + +class FallbackConverter : public mlir::TypeConverter { + public: + explicit FallbackConverter(mlir::MLIRContext *context); + + // Return the next dense key for fallback ops. The key is simply an array + // index so that in runtime, the fallback ops can be efficiently retrieved. + int64_t GetNextFallbackKey() const { return fallback_ops_.size(); } + + void RegisterFallbackOp(mlir::Operation *op) { fallback_ops_.push_back(op); } + + void ReplaceFallbackOp(int64_t key, mlir::Operation *op) { + fallback_ops_[key] = op; + } + + llvm::ArrayRef GetFallbackOps() const { + return fallback_ops_; + } + + private: + mlir::Builder builder_; + // Using a vector to keep fallback ops in order, and the key for a fallback op + // is its corresponding index here. + llvm::SmallVector fallback_ops_; +}; + +// Convert the `value` that is a !corert.tensorhandle to +// !tfrt_fallback.tf_tensor. If needed, tensor conversion kernels will be added. +// On error it returns nullptr. +mlir::Value ConvertCoreRTTensorHandleToFallbackTensor( + mlir::Location loc, llvm::StringRef device, mlir::Value value, + mlir::ConversionPatternRewriter &rewriter); + +// Convert the `value` that is a !tfrt_fallback.tf_tensor to +// !corert.tensorhandle. If needed, tensor conversion kernels will be added. On +// error it returns nullptr. +mlir::Value ConvertFallbackTensorToCoreRTTensorHandle( + mlir::Location loc, mlir::Value value, + mlir::ConversionPatternRewriter &rewriter); + +// Convert operands that might be !tfrt_fallback.tf_tensor for corert operations +// that take only !corert.tensorhandle. +mlir::LogicalResult ConvertCoreRTOperands( + mlir::Operation *op, mlir::ValueRange operands, + llvm::SmallVectorImpl *new_operands, + mlir::ConversionPatternRewriter &rewriter); + +// Convert operands that might be !corert.tensorhandle for fallback operations +// that take only !tfrt_fallback.tf_tensor. 
+mlir::LogicalResult ConvertFallbackOperands( + mlir::Operation *op, llvm::StringRef device, mlir::ValueRange operands, + llvm::SmallVectorImpl *new_operands, + mlir::ConversionPatternRewriter &rewriter); + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_FALLBACK_CONVERTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h new file mode 100644 index 00000000..801e10bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_GPU_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_GPU_PASSES_H_ + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace tensorflow { + +// Registers dialects used in TFRT GPU lowering. +void RegisterGpuDialects(mlir::DialectRegistry *registry); + +// Adds a target dialect and rewrite patterns for TFRT GPU lowering. +void AddGpuTargetDialectAndPatterns(mlir::MLIRContext *context, + mlir::ConversionTarget *target, + mlir::RewritePatternSet *patterns); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_GPU_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h new file mode 100644 index 00000000..a345d1d8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
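To make the key bookkeeping in `FallbackConverter` above concrete, here is a minimal sketch of the register-then-replace flow a lowering pattern would follow; the helper names are illustrative.

#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"

namespace example {

// Takes the next dense key and records the lowered op under it.
int64_t AssignFallbackKey(
    tensorflow::tfrt_compiler::FallbackConverter &converter,
    mlir::Operation *lowered_op) {
  int64_t key = converter.GetNextFallbackKey();  // == number of registered ops
  converter.RegisterFallbackOp(lowered_op);
  return key;  // typically attached to `lowered_op` as an integer attribute
}

// If a later rewrite replaces the op, keep the key pointing at the new op.
void UpdateFallbackOp(tensorflow::tfrt_compiler::FallbackConverter &converter,
                      int64_t key, mlir::Operation *new_op) {
  converter.ReplaceFallbackOp(key, new_op);
}

}  // namespace example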
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_EXTRACT_CALLBACK_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_EXTRACT_CALLBACK_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Extracts a module that consists of a public callback function in name of +// `callback_key` and all its reachables. +absl::StatusOr> ExtractCallbackModule( + mlir::ModuleOp module, absl::string_view callback_key); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_EXTRACT_CALLBACK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h new file mode 100644 index 00000000..0dfaa081 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_ + +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Implements the custom backend compiler for IFRT based serving in TFRT. +class IfrtBackendCompiler : public tensorflow::BackendCompiler { + public: + struct Options { + // If true, disable running TFRTSetTPUDeviceAttrPass which set the default + // `tf.device` and `device_assignment` attributes. + // This is a server-level option for now. We can consider to make it a + // per-model option in the future. + bool disable_set_default_tpu_device_and_device_assignment_attributes = true; + }; + + explicit IfrtBackendCompiler(TpuCompiler* tpu_compiler = nullptr) + : tpu_compiler_(tpu_compiler) {} + + explicit IfrtBackendCompiler(const Options& ifrt_backend_compile_options, + TpuCompiler* tpu_compiler = nullptr) + : tpu_compiler_(tpu_compiler), + compile_options_(ifrt_backend_compile_options) {} + + void GetDependentDialects(mlir::DialectRegistry& registry) const override { + if (tpu_compiler_) { + tpu_compiler_->RegisterTPUDialects(®istry); + } + } + + // Rewrites the tensorflow graph in MLIR for IFRT serving. The methods + // extracts regions for IFRT execution on accelerator (e.g. TPU). 
+ absl::Status CompileTensorflow( + tensorflow::tfrt_stub::ModelRuntimeContext& model_context, + mlir::ModuleOp module) const override; + + private: + TpuCompiler* tpu_compiler_; // Not owned. + Options compile_options_; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_BACKEND_COMPILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h new file mode 100644 index 00000000..3e497182 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_CONSTANTS_H_ + +#include "absl/strings/string_view.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Attribute name of a text TpuCompileMetadataProto. Note that the text proto is +// not backward compatible and shall not be serialized. +inline constexpr absl::string_view kMetadataTextAttrName = + "__tpu_compile_metadata_text"; + +// Name of a variable as loaded IFRT array . +inline constexpr absl::string_view kVariableArrayNameAttr = + "__variable_array_name"; + +// Attribute of a text `VariableDeviceShardingConfigProto`. +inline constexpr absl::string_view kVariableShardingConfigTextAttr = + "__variable_sharding_config_text"; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h new file mode 100644 index 00000000..c64672cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_TYPES_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +struct DtypeAndShape { + tensorflow::DataType dtype; + tensorflow::TensorShape shape; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_IFRT_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h new file mode 100644 index 00000000..7122f26e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -0,0 +1,86 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF2HLO_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF2HLO_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_compilation.pb.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/topology.h" +#include "xla/service/hlo.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +struct Tf2HloArg { + mlir::ModuleOp module; + // `input_dtypes_and_shapes` can be mutable during Tf2HLO compilation. + std::vector input_dtypes_and_shapes; + absl::Span variable_arg_indices; + absl::string_view entry_function_name; + // `compile_metadata` can be mutable during Tf2HLO compilation. 
+ tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn; + std::shared_ptr topology; + absl::string_view platform_name; + bool enable_r1_optimization = true; + + absl::StatusOr Fingerprint() const; +}; + +struct Tf2HloResult { + xla::HloModuleProto hlo_module_proto; + tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + tf2xla::HostComputeMetadata host_compute_metadata; + Tf2HLOResultProto ToProto() const; +}; + +absl::Status UpdateCompileMetadata( + tensorflow::tpu::TPUCompileMetadataProto& metadata, + absl::Span inputs); + +absl::StatusOr GetCompileMetadata( + mlir::ModuleOp module, const xla::ifrt::Client& ifrt_client); + +class TfToHloCompiler { + public: + TfToHloCompiler() = default; + virtual ~TfToHloCompiler() = default; + + // Returns a cache key that can be used to identify the result of + // CompileTfToHlo. + virtual absl::StatusOr Key(const Tf2HloArg& arg); + + virtual absl::StatusOr CompileTfToHlo(Tf2HloArg& arg); +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF2HLO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h new file mode 100644 index 00000000..34490c74 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h @@ -0,0 +1,86 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF_IFRT_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF_IFRT_PASSES_H_ + +#include +#include + +#include "absl/status/status.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project + +namespace tensorflow { +namespace ifrt_serving { + +// Create a pass to convert tf_device.cluster_func to tf.ifrt_program_call. +std::unique_ptr> +CreateRewriteClusterToIfrtCallPass(); + +// Creates a pass that sinks variable tensor argument to `tf.IfrtCall` as named +// arrays and lowers `tf.ReadVariableOp` to `tf.IfrtLoadVariableOp`. +std::unique_ptr> +CreateSinkVariableAsNamedArrayPass(); + +// Creates a pass that splits `tf.RestoreV2` ops. +std::unique_ptr> +CreateTfRestoreSplittingPass(); + +// Creates a pass that merges `tf.RestoreV2` ops. +std::unique_ptr> +CreateTfRestoreMergingPass(); + +// Creates a pass that propagates inputs of no-op identity ops to their outputs. +std::unique_ptr> +CreateTfIdentityPropagationPass(); + +// Creates a pass that prunes unused `tf.RestoreV2` ops. 
+std::unique_ptr> +CreateTfRestorePruningPass(); + +// Creates a pass that lower `tf.RestoreVariableOp` to +// `tf.IfrtRestoreVariableOp`. +std::unique_ptr> +CreateLowerToIfrtRestoreVariablePass(); + +// Creates a pass that cleans up device attributes from all ops. +std::unique_ptr> +CreateTfDeviceCleanupPass(); + +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +// Register all passes. +void RegisterTfIfrtPasses(); + +// Setup the input pass manager to enable IR dumping after each pass. +// Note a side effect of this method is that multi threading will be disabled. +void EnablePassIRPrinting(mlir::PassManager& pm, + const std::string& dump_group_name, + llvm::StringRef module_name); + +// Convert tf_device.cluster_func to tf.ifrt_program_call. +// The callee function is converted to a ifrt_program. +absl::Status RunClusterToIfrtRuntimeOpsPassPipeline( + mlir::ModuleOp module, llvm::StringRef module_name = llvm::StringRef()); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF_IFRT_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h new file mode 100644 index 00000000..6ed9f1e9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/assign_op_key.h @@ -0,0 +1,32 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASSIGN_OP_KEY_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASSIGN_OP_KEY_H_ +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Create a pass that assigns an op_key to every fallback OP. The op_key +// provides a uniform key to look up online cost for a specific op. +// This pass is expected to run before parallerization. +std::unique_ptr> CreateAssignOpKeyPass(); + +} // namespace mlrt_compiler +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASSIGN_OP_KEY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/async_while.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/async_while.h new file mode 100644 index 00000000..684e8dd1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/async_while.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASYNC_WHILE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASYNC_WHILE_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Creates a pass that converts applicable tf.While to tf_mlrt.AsyncWhile. +// tf_mlrt.AsyncWhile dispatch iterations asynchronously, thus allowing +// pipelining between iterations to reduce latency. This is intended for +// tf.While that is not converted from tf.MapFn, but still can benefit from +// asynchronous execution of iterations to reduce latency. +std::unique_ptr> CreateAsyncWhilePass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_ASYNC_WHILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h new file mode 100644 index 00000000..93dde814 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_EXECUTE_OP_REGISTRY_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_EXECUTE_OP_REGISTRY_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +class ExecuteOpRegistry { + public: + mlir::LogicalResult RegisterExecuteOp(mlir::Operation* op, uint32_t op_key) { + if (op_key >= execute_ops_.size()) { + execute_ops_.resize(op_key + 1); + } + if (auto* register_op = execute_ops_[op_key]) { + if (register_op->getName() != op->getName() || + register_op->getAttrs() != op->getAttrs()) { + return op->emitError() << "Key " << op_key << " already registered."; + } + return mlir::success(); + } + execute_ops_[op_key] = op; + return mlir::success(); + } + + void ReplaceExecuteOp(int64_t key, mlir::Operation* op) { + execute_ops_[key] = op; + } + + llvm::ArrayRef GetExecuteOps() const { + return execute_ops_; + } + + private: + // Using a vector to keep fallback ops in order, and the key for a fallback op + // is its corresponding index here. + llvm::SmallVector execute_ops_; +}; + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_EXECUTE_OP_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h new file mode 100644 index 00000000..6f772a89 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/fuse_mlrt_ops.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_FUSE_MLRT_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_FUSE_MLRT_OPS_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +std::unique_ptr> CreateFuseMlrtOpPass(); + +} +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_FUSE_MLRT_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/ifrt_set_tpu_host_allocator.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/ifrt_set_tpu_host_allocator.h new file mode 100644 index 00000000..ddd8ee0a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/ifrt_set_tpu_host_allocator.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
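A short sketch of the `ExecuteOpRegistry` contract above: re-registering a key succeeds only if the op's name and attributes match the op already stored, and `ReplaceExecuteOp` repoints a key after a rewrite. The wrapper names are illustrative.

#include <cstdint>

#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h"

// Records `op` under `op_key`; emits an error on `op` if the key is already
// bound to a different op.
static mlir::LogicalResult RecordExecuteOp(
    tensorflow::mlrt_compiler::ExecuteOpRegistry &registry,
    mlir::Operation *op, uint32_t op_key) {
  return registry.RegisterExecuteOp(op, op_key);
}

// After a rewrite, keep the key pointing at the replacement op.
static void RepointExecuteOp(
    tensorflow::mlrt_compiler::ExecuteOpRegistry &registry, int64_t op_key,
    mlir::Operation *new_op) {
  registry.ReplaceExecuteOp(op_key, new_op);
}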
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IFRT_SET_TPU_HOST_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IFRT_SET_TPU_HOST_ALLOCATOR_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Creates a pass that set tpu input producers to use tpu host allocators. +std::unique_ptr> +CreateIfrtSetTpuHostAllocatorPass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IFRT_SET_TPU_HOST_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h new file mode 100644 index 00000000..1258c953 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h @@ -0,0 +1,54 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IMPORT_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IMPORT_MODEL_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" + +namespace tensorflow { +namespace mlrt_compiler { + +// Converts an MLIR `module` in TF dialect to MLRT's bytecode format. If +// `module_with_op_keys` is non-null, the intermediate module on which passes +// until (including) AssignOpKeyPass have run will be cloned to it. +// +// This is for initial conversion. 
+absl::StatusOr ConvertTfMlirToBytecode( + const TfrtCompileOptions& options, tfrt_stub::FallbackState& fallback_state, + mlir::ModuleOp module, tfrt_stub::ModelRuntimeContext& model_context, + mlir::OwningOpRef* module_with_op_keys = nullptr, + std::vector* added_xla_function_names = nullptr); + +// Converts an MLIR `module_with_op_keys` in TF dialect to MLRT's bytecode +// format, with op costs from `cost_recorder`. +// +// This is for re-conversion. +absl::StatusOr ConvertTfMlirWithOpKeysToBytecode( + const TfrtCompileOptions& options, + const tfrt_stub::FallbackState& fallback_state, + mlir::ModuleOp module_with_op_keys, + const tfrt_stub::CostRecorder& cost_recorder); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_IMPORT_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/mlrt_device_constants.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/mlrt_device_constants.h new file mode 100644 index 00000000..3c2c588a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/mlrt_device_constants.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_MLRT_DEVICE_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_MLRT_DEVICE_CONSTANTS_H_ + +namespace tensorflow { +namespace mlrt_compiler { + +inline constexpr char kTfMlrtCustomDevice[] = "tf_mlrt.custom_device"; +inline constexpr char kTpuHostDevice[] = "tpu_host_device"; + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_MLRT_DEVICE_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h new file mode 100644 index 00000000..71221276 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
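// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. It shows the call shape of ConvertTfMlirToBytecode from
// mlrt/import_model.h above for the initial conversion. The StatusOr payload
// type is assumed to be mlrt::bc::Buffer (consistent with the bytecode
// include); options, fallback state, module, and model context are assumed to
// be built elsewhere.
#include "absl/status/statusor.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h"

static absl::StatusOr<mlrt::bc::Buffer> CompileToMlrtBytecode(
    const tensorflow::TfrtCompileOptions& options,
    tensorflow::tfrt_stub::FallbackState& fallback_state, mlir::ModuleOp module,
    tensorflow::tfrt_stub::ModelRuntimeContext& model_context) {
  // The optional op-key module and XLA function-name outputs are omitted.
  return tensorflow::mlrt_compiler::ConvertTfMlirToBytecode(
      options, fallback_state, module, model_context);
}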
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PARALLELIZATION_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PARALLELIZATION_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" + +namespace tensorflow { +namespace mlrt_compiler { + +std::unique_ptr> CreateParallelizationPass( + uint64_t cost_threshold, bool merge_inter_dependent_streams, + const tfrt_stub::CostRecorder* cost_recorder = nullptr); + +std::unique_ptr> +CreateParallelizationPass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PARALLELIZATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h new file mode 100644 index 00000000..f9bf621b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PASSES_H_ + +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" + +namespace tensorflow { +namespace mlrt_compiler { + +void RegisterMlrtPasses(); + +// Creates a pipeline of passes that lowers MLIR TF dialect to MLRT dialects. +// The op costs from `cost_recorder` (if non-null) are used for Stream Analysis. +void CreateTfToMlrtPipeline( + mlir::OpPassManager& pm, const TfrtPipelineOptions& options, + const tfrt_stub::FallbackState* fallback_state, + const tfrt_stub::CostRecorder* cost_recorder = nullptr); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h new file mode 100644 index 00000000..1423011b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
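// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. It wires the TF-to-MLRT pipeline declared in mlrt/passes.h above into
// an OpPassManager; `options`, `fallback_state`, and `cost_recorder` are
// assumed to exist, and pass registration is shown only for completeness.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/passes.h"

static void BuildTfToMlrtPipeline(
    mlir::OpPassManager& pm, const tensorflow::TfrtPipelineOptions& options,
    const tensorflow::tfrt_stub::FallbackState* fallback_state,
    const tensorflow::tfrt_stub::CostRecorder* cost_recorder) {
  // Usually done once at startup so the passes are visible to pass tooling.
  tensorflow::mlrt_compiler::RegisterMlrtPasses();
  // Recorded op costs (if any) feed the Stream Analysis used for scheduling.
  tensorflow::mlrt_compiler::CreateTfToMlrtPipeline(pm, options, fallback_state,
                                                    cost_recorder);
}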
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_REWRITE_IFRT_LOAD_VARIABLE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_REWRITE_IFRT_LOAD_VARIABLE_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Creates a pass that converts tf.IfrtLoadVariableOp to +// tf_mlrt.TFIfrtLoadVariableOp and inserts tf_mlrt.Await on the returned future +// from tf_mlrt.TFIfrtLoadVariableOp if it is used by CPU ops. +std::unique_ptr> +CreateRewriteIfrtLoadVariablePass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_REWRITE_IFRT_LOAD_VARIABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h new file mode 100644 index 00000000..1206f66f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TF_TO_MLRT_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TF_TO_MLRT_H_ +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" + +namespace tensorflow { +namespace mlrt_compiler { + +// The conversion pass that is run before 'tf-mlrt-parallelization' passes. The +// parallelization pass changes the graph content, so any rewrite/conversion +// that depends on the graph instead of individual ops should be done before +// parallelization. +std::unique_ptr> +CreateTfToMlrtPreParallelizationConversionPass( + const TfrtPipelineOptions& options); + +// The conversion pass that is run after 'tf-mlrt-parallelization' passes. The +// parallelization pass changes the graph content, so this pass should only +// contain conversion that depends on individual ops. 
+std::unique_ptr> +CreateTfToMlrtConversionPass(const TfrtPipelineOptions& options); + +std::unique_ptr> +CreateTfToMlrtConversionPass(const TfrtPipelineOptions& options, + const tfrt_stub::FallbackState* fallback_state); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TF_TO_MLRT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h new file mode 100644 index 00000000..20592c95 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tpu_conversion_patterns.h @@ -0,0 +1,42 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TPU_CONVERSION_PATTERNS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TPU_CONVERSION_PATTERNS_H_ + +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/execute_op_registry.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" + +namespace tensorflow { +namespace mlrt_compiler { + +void RegisterTpuDialect(mlir::DialectRegistry& registry); + +void PopulateTpuPreParallelizationConversionPatterns( + mlir::ConversionTarget& target, mlir::RewritePatternSet& patterns, + const TfrtPipelineOptions& options); + +void PopulateTpuConversionPatterns(mlir::ConversionTarget& target, + mlir::RewritePatternSet& patterns, + mlir::TypeConverter& type_converter, + ExecuteOpRegistry& execute_op_registry, + const TfrtPipelineOptions& options); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_TPU_CONVERSION_PATTERNS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h new file mode 100644 index 00000000..c47471f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/util.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
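// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. It spells out the ordering implied by the comments in tf_to_mlrt.h
// above: graph-level conversion before parallelization, per-op conversion
// after. The no-argument CreateParallelizationPass() overload from
// parallelization.h is used here; real pipelines may pass explicit cost
// options instead.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/parallelization.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.h"

static void AddTfToMlrtConversionPasses(
    mlir::OpPassManager& pm, const tensorflow::TfrtPipelineOptions& options) {
  namespace mc = tensorflow::mlrt_compiler;
  pm.addPass(mc::CreateTfToMlrtPreParallelizationConversionPass(options));
  pm.addPass(mc::CreateParallelizationPass());  // restructures the graph
  pm.addPass(mc::CreateTfToMlrtConversionPass(options));
}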
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_UTIL_H_ + +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +// Use fallback by default for anything that does not have a native kernel +// with some exceptions. +bool UseFallback(mlir::Operation *op); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h new file mode 100644 index 00000000..a45c0387 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_WHILE_TO_MAP_FN_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_WHILE_TO_MAP_FN_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace mlrt_compiler { + +std::unique_ptr> CreateWhileToMapFnPass(); + +} // namespace mlrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_MLRT_WHILE_TO_MAP_FN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/passes.h new file mode 100644 index 00000000..8dad2c71 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/passes.h @@ -0,0 +1,169 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +class PassManager; +} + +namespace tensorflow { + +namespace tfrt_compiler { + +// Create a pass to insert kernels that copy fallback tensors when they are +// passed to multiple threads, to avoid atomic contention on their refcounts. +std::unique_ptr> +CreateInsertFallbackTensorCopyPass(); + +// Create a pass to reorder tf.Assert ops or tf.If ops that contains only +// tf.Assert ops to the end of the function, to avoid unnecessary control +// dependencies to other ops. +std::unique_ptr> +CreateReorderTfAssertPass(); + +// Create a pass to optimize the side-effect of control flow ops. eg. if both +// branches of a tf.If op contains only non-side-effecting ops, its +// `is_stateless` attribute will be set to true. +std::unique_ptr> +CreateOptimizeTfControlFlowSideEffectPass(); + +// Create a pass to remove tf.If ops' operands that are produced by tf.Const +// ops. +std::unique_ptr> +CreateRemoveTfIfConstArgsPass(); + +// Create a pass to merge non-side-effecting tf.If ops that have the same +// operands. +std::unique_ptr> CreateMergeTfIfOpsPass(); + +// Create a pass to deduplicate the function invoked by tf.BatchFunction with +// the same shared_name. +std::unique_ptr> +CreateDeduplicateFunctionsInovkedByBatchFunctionPass(); + +// Create a pass to lower bound the number of threads in tf.BatchFunction. +struct ReconfigBatchOpPassOptions { + int64_t min_num_batch_threads = 1; + int64_t min_max_enqueued_batches = 1; + std::string batch_padding_policy = ""; + int64_t num_batch_threads = 0; + int64_t max_batch_size = 0; + int64_t batch_timeout_micros = 0; + llvm::ArrayRef allowed_batch_sizes = {}; + int64_t max_enqueued_batches = 0; +}; +std::unique_ptr> CreateReconfigBatchOpPass( + ReconfigBatchOpPassOptions options); + +// Create a pass to fuse the TPU Ops for TFRT. +std::unique_ptr> +CreateFuseTpuCompileAndExecutePass(); + +// Create a pass to optimize TF dialect for TFRT workflow. +std::unique_ptr> +CreateOptimizeTfForTfrtPass(); + +std::unique_ptr> CreateTfrtXlaRewritePass(); + +// Create a pass to deduplicate results of tf.If ops. +std::unique_ptr> +CreateDeduplicateIfResultPass(); + +} // namespace tfrt_compiler + +class CoreRTConverter; + +// Create a pass that sink in the var handle op to the callee function when +// proper. +std::unique_ptr> +CreateSinkInInvariantOpsPass(); + +// Create a pass that rewrites tf_saved_model dialect's ops according to TFRT's +// requirements. 
+std::unique_ptr> +CreateLowerTFSavedModelPass(bool hoist_invariant_ops, + bool fuse_get_resource_ops); + +// Create a pass that converts ref variables to resource variables in a limited +// number of cases. +std::unique_ptr> +CreateConvertReferenceVariableToResourceVariablePass(); + +// Run *ToCoreRTConversionPassRun as free functions. Useful for +// reusing the pass logic in a custom pass with additional conversions. +mlir::LogicalResult TFSavedModelToCoreRTConversionPassRun( + mlir::MLIRContext* context, mlir::func::FuncOp func, + mlir::ConversionTarget* target, mlir::RewritePatternSet* patterns, + CoreRTConverter* corert_converter); + +// Create an operation pass that removes the device attribute from every +// corert.executeop. +std::unique_ptr> +CreateRemoveDeviceAttributePass(); + +// Create an operation pass that inserts corert.transfer op to make sure any +// argument of any op is on the same device of the op itself. +std::unique_ptr> +CreateCrossDeviceTransferPass(); + +// Create a pass that converts MLIR TF dialect to MLIR TFRT dialect. +std::unique_ptr> +CreateTfToTfrtConversionPass(const TfrtPipelineOptions& options); + +// Creates a pipeline of passes that lowers MLIR TF dialect to TFRT dialects. +void CreateTfToTfrtPipeline(mlir::OpPassManager& pm, + const TfrtPipelineOptions& options); + +// Creates a pipeline of passes that lowers MLIR TF dialect from tf.function to +// TFRT dialect. SavedModel related conversions are not included. +absl::Status CreateTfExecutorToTfrtPipeline(mlir::PassManager& pm, + const TfrtPipelineOptions& options); + +// Creates a pipeline of passes that lowers MLIR TF Executor dialect to TF +// dialect for CoreRT purposes. +absl::Status CreateTFExecutorToTFPipeline(mlir::PassManager& pm, + const TfrtPipelineOptions& options); + +// TODO(deqiangc): refactor below helpers once mlrt is OSSed. +void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( + mlir::OpPassManager& pm, const TfrtPipelineOptions& options); +void CreateTFExecutorToTFInvariantOptimizationPipelineHelper( + mlir::OpPassManager& pm, const TfrtPipelineOptions& options); + +absl::Status CreateTFExecutorToTFPreInvariantOptimizationPipeline( + mlir::PassManager& pm, const TfrtPipelineOptions& options); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h new file mode 100644 index 00000000..44929772 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
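// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. It fills in a ReconfigBatchOpPassOptions from transforms/passes.h
// above and adds the resulting pass; the concrete values are examples only.
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"

static void AddReconfigBatchOpPass(mlir::OpPassManager& pm) {
  tensorflow::tfrt_compiler::ReconfigBatchOpPassOptions opts;
  opts.min_num_batch_threads = 4;    // lower-bound the batch thread pool
  opts.batch_timeout_micros = 1000;  // flush incomplete batches after 1 ms
  pm.addPass(tensorflow::tfrt_compiler::CreateReconfigBatchOpPass(opts));
}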
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_SET_SHAPE_INVARIANT_IN_WHILE_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_SET_SHAPE_INVARIANT_IN_WHILE_OPS_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tfrt_compiler { + +// Create a pass to set shape_invariant attribute for all tf.While ops except +// those are on TPU. +std::unique_ptr> +CreateSetShapeInvariantInWhileOps(); + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_SET_SHAPE_INVARIANT_IN_WHILE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h new file mode 100644 index 00000000..2588d0f8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h @@ -0,0 +1,191 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_PIPELINE_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_PIPELINE_OPTIONS_H_ + +#include +#include + +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" + +namespace tensorflow { + +struct TfrtPipelineOptions + : public mlir::PassPipelineOptions { + Option saved_model_dir{*this, "saved-model-dir", + llvm::cl::desc(""), llvm::cl::init("")}; + Option default_device{ + *this, "default-device", llvm::cl::desc("default device assignment"), + llvm::cl::init("/job:localhost/replica:0/task:0/device:CPU:0")}; + Option enable_optimizer{ + *this, "enable-optimizer", + llvm::cl::desc("run optimization passes on corert dialect"), + llvm::cl::init(false)}; + Option decompose_resource_ops{ + *this, "decompose-resource-ops", + llvm::cl::desc("decompose composite resource ops into ReadVariableOp and " + "non-resource ops. This is currently used in TFRT " + "savedmodel pipeline."), + llvm::cl::init(false)}; + Option force_data_format{ + *this, "force-data-format", + llvm::cl::desc("force data format for all layout sensitive operations")}; + // TODO(tfrt-devs): consider making compiler to figure out whether to fold + // transpose or not instead of exposing the specific option. 
+ Option skip_fold_transpose_in_ops{ + *this, "skip-fold-transpose-in-ops", + llvm::cl::desc("Skip folding transpose operands in Ops which can support " + "different layouts.")}; + Option target_tpurt{*this, "target-tpurt", + llvm::cl::desc("target TPURT dialect if true"), + llvm::cl::init(false)}; + Option tpu_use_core_selector{ + *this, "tpu-use-core-selector", + llvm::cl::desc("If true, use ServingCoreSelector to pick TPU core. " + "Otherwise, use the assigned core. Currently we use " + "core selector for Servo serving use cases."), + llvm::cl::init(true)}; + Option tpu_use_bundled_transfer{ + *this, "tpu-use-bundled-transfer", + llvm::cl::desc("If true, use BundledTransferToTpuOp to transfer " + "variables and input tensors to TPU."), + llvm::cl::init(true)}; + Option tpu_lower_to_fallback{ + *this, "tpu-lower-to-fallback", + llvm::cl::desc("If true, lower an TF op that's placed on TPU device " + "to be executed by tfrt_fallback.execute."), + llvm::cl::init(true)}; + Option tpu_fuse_ops{ + *this, "tpu-fuse-ops", + llvm::cl::desc("If true, use the TPU fused compile_and_execute kernel"), + llvm::cl::init(false)}; + // TODO(b/194081364): remove this option once we unify servo TPU serving + // result transfer behavior. + Option tpu_transfer_result_to_host{ + *this, "tpu-transfer-result-to-host", + llvm::cl::desc("If true, transfer the result of tpurt.execute from TPU " + "to host."), + llvm::cl::init(true)}; + Option use_tpu_host_allocator_for_inputs{ + *this, "use-tpu-host-allocator-for-inputs", + llvm::cl::desc("If true, fallback executeops that produce inputs to tpu " + "program will use tpu host allocator."), + llvm::cl::init(false)}; + Option tpu_allow_unpadded_batch{ + *this, "tpu-allow-unpadded-batch", + llvm::cl::desc("To allow unpadded batch for TPU execution."), + llvm::cl::values( + clEnumValN(TfrtCompileOptions::TpuAllowUnpaddedBatch::kDisabled, + "disabled", "Disable this feature."), + clEnumValN(TfrtCompileOptions::TpuAllowUnpaddedBatch::kAuto, "auto", + "Enable this feature when in-graph batching is detected."), + clEnumValN(TfrtCompileOptions::TpuAllowUnpaddedBatch::kEnforced, + "enforced", "Force to enable this feature.")), + llvm::cl::init(TfrtCompileOptions::TpuAllowUnpaddedBatch::kDisabled)}; + + Option target_gpu{ + *this, "target-gpu", + llvm::cl::desc("If true, target GPU compiler passes."), + llvm::cl::init(false)}; + + // TODO(b/294895431): Remove the flag and default to the fused op. + Option use_gpu_compile_and_execute_op{ + *this, "use-gpu-compile-and-execute-op", + llvm::cl::desc("If true, gpurt.compile_and_execute is used for GPU"), + llvm::cl::init(false)}; + + Option enable_while_parallel_iterations{ + *this, "enable-while-parallel-iterations", + llvm::cl::desc("If true, tf.While op will be parallelized. 
This is " + "currently experimental."), + llvm::cl::init(false)}; + + Option hoist_invariant_ops{ + *this, "hoist-invariant-ops", + llvm::cl::desc("If true, invariant ops in savedmodels will be hoisted " + "out to run during loading."), + llvm::cl::init(false)}; + + Option fuse_get_resource_ops_in_hoisting{ + *this, "fuse-get-resource-ops-in-hoisting", + llvm::cl::desc("If true, get_resource_op will be fused during hoisting"), + llvm::cl::init(true)}; + + Option sink_in_invariant_ops{ + *this, "sink-in-invariant-ops", + llvm::cl::desc("If true, sink the selected invariant ops in to the " + "nested functions to facilitate invariant ops hoisting."), + llvm::cl::init(false)}; + + Option cost_threshold{ + *this, "tfrt-cost-threshold", + llvm::cl::desc( + "The cost threshold to decide whether a sequence of operations is " + "cheap, and then whether it can be executed inline."), + llvm::cl::init(1)}; + + Option min_num_batch_threads{ + *this, "tfrt-min-num-batch-threads", + llvm::cl::desc("The minimum number of batch threads"), llvm::cl::init(1)}; + + Option min_max_enqueued_batches{ + *this, "tfrt-min-max-enqueued-batches", + llvm::cl::desc( + "The minimum of the maximum number of outstanding enqueued batches"), + llvm::cl::init(1)}; + + Option batch_padding_policy{ + *this, "tfrt-batch-padding-policy", + llvm::cl::desc("The policy used when padding (or splitting) batches."), + llvm::cl::init("")}; + + Option num_batch_threads{ + *this, "tfrt-num-batch-threads", + llvm::cl::desc( + "The number of threads for processing batches in parallel"), + llvm::cl::init(0)}; + + Option max_batch_size{ + *this, "tfrt-max-batch-size", + llvm::cl::desc("The maximum allowed batch size"), llvm::cl::init(0)}; + + Option batch_timeout_micros{ + *this, "tfrt-batch-timeout-micros", + llvm::cl::desc("The maximum number of microseconds before outputting an " + "incomplete batch"), + llvm::cl::init(0)}; + + ListOption allowed_batch_sizes{ + *this, "tfrt-allowed-batch-sizes", + llvm::cl::desc("Allowed sizes for padding (or splitting) batches")}; + + Option max_enqueued_batches{ + *this, "tfrt-max-enqueued-batches", + llvm::cl::desc("The maximum number of batches enqueued for processing " + "before requests are failed fast"), + llvm::cl::init(0)}; + + Option merge_inter_dependent_streams{ + *this, "tfrt-merge-inter-dependent-streams", + llvm::cl::desc("If true, streams with inter data depenedencies will be " + "preferred to be merged for inline execution."), + llvm::cl::init(false)}; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TFRT_PIPELINE_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h new file mode 100644 index 00000000..3cae00e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
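// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. Because TfrtPipelineOptions above derives from
// mlir::PassPipelineOptions, a registration along these lines (the
// "tfrt-pipeline" name is hypothetical) lets the flags defined above be parsed
// from a textual pipeline such as
//   --pass-pipeline="tfrt-pipeline{target-tpurt=true tfrt-cost-threshold=16}"
#include "mlir/Pass/PassRegistry.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h"

static mlir::PassPipelineRegistration<tensorflow::TfrtPipelineOptions>
    tfrt_pipeline_registration(
        "tfrt-pipeline", "Lowers TF dialect to TFRT dialects (example only)",
        [](mlir::OpPassManager& pm,
           const tensorflow::TfrtPipelineOptions& options) {
          tensorflow::CreateTfToTfrtPipeline(pm, options);
        });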
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TPU_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TPU_PASSES_H_ + +// This file contains stub implementations for Google internal TPU APIs. + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Pass/PassOptions.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" + +namespace tensorflow { + +class CoreRTConverter; + +namespace tfrt_compiler { + +class FallbackConverter; + +} + +struct TfrtTpuCompileOptions + : mlir::PassPipelineOptions { + Option move_resource_gather_to_host{ + *this, "move-resource-gather-to-host", + llvm::cl::desc("Move resource gather ops to host"), + llvm::cl::init(false)}; + Option gather_table_width_threshold_bytes{ + *this, "gather-table-width-threshold-bytes", + llvm::cl::desc( + "The threshold to control whether a TPU resource gather op should be " + "moved to host. A negative values means all are moved."), + llvm::cl::init(-1)}; +}; + +struct TfrtTpuExecuteOpConversionOptions { + bool use_core_selector = false; + bool use_bundled_transfer = false; + bool transfer_result_to_host = false; + bool use_tpu_host_allocator_for_inputs = false; + TfrtCompileOptions::TpuAllowUnpaddedBatch allow_unpadded_batch = + TfrtCompileOptions::TpuAllowUnpaddedBatch::kDisabled; +}; + +// Registers a set of dialects used in TFRT TPU lowering. +inline void RegisterTPUDialects(mlir::DialectRegistry *registry) {} + +// Adds a target dialect and a set of rewrite patterns for TFRT TPU lowering. +inline void AddTPUTargetDialectAndPatterns( + mlir::ConversionTarget *target, mlir::RewritePatternSet *patterns, + mlir::MLIRContext *context, CoreRTConverter *corert_converter, + tfrt_compiler::FallbackConverter *fallback_converter, + const TfrtTpuExecuteOpConversionOptions &tpu_exec_conv_opts, + bool tpu_lower_to_fallback) {} + +// Rewrites specific TF TPU ops to equivalent TF ops in a module. +inline mlir::LogicalResult RunTPUBackwardCompatConversion( + mlir::ModuleOp module, const TfrtTpuCompileOptions &options) { + return mlir::failure(); +} + +// The rewrite rules to support the fallback execution of TPUPartitionedCallOp. +inline mlir::LogicalResult RunTPUPartitionedCallFallbackCompatConversion( + mlir::ModuleOp module) { + return mlir::failure(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_TPU_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h new file mode 100644 index 00000000..99b7c192 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
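// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. The OSS stubs in tpu_passes.h above unconditionally return
// mlir::failure(), so a caller would normally gate them on the TPU target
// option, roughly as below.
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h"

static mlir::LogicalResult MaybeRunTpuBackwardCompat(
    mlir::ModuleOp module, const tensorflow::TfrtTpuCompileOptions& options,
    bool target_tpurt) {
  if (!target_tpurt) return mlir::success();  // nothing to do off-TPU
  return tensorflow::RunTPUBackwardCompatConversion(module, options);
}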
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UPDATE_OP_COST_IN_TFRT_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UPDATE_OP_COST_IN_TFRT_MLIR_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" + +namespace tensorflow { +namespace tfrt_compiler { + +// Updates the existing costs for all the fallback ops with the records in +// `cost_recorder`. +void UpdateOpCostInTfrtMlir(mlir::ModuleOp op, + const tfrt_stub::CostRecorder& cost_recorder); + +} // namespace tfrt_compiler +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UPDATE_OP_COST_IN_TFRT_MLIR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/utils.h new file mode 100644 index 00000000..0b94fc79 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/transforms/utils.h @@ -0,0 +1,46 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UTILS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project + +namespace tensorflow { + +// Checks if the given `value` is a resource argument. +bool IsResourceArgument(mlir::Value value); + +// Checks if an operand is the value of a variable. +bool IsResultVariable(const mlir::Value &original_operand, + const mlir::Value &operand); + +// Canonicalize the symbol attr to the original TF function name. +std::optional CanonicalizeTensorflowFunctionName( + const mlir::SymbolTable &symbol_table, absl::string_view mlir_func_name, + bool use_mlir_func_name = false); + +// Returns true if the function is a session initializer in tf_saved_model +// dialect. +bool IsSessionInitializer(mlir::func::FuncOp op); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/import_model.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/import_model.h new file mode 100644 index 00000000..9459f90c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/import_model.h @@ -0,0 +1,73 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
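// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. It shows the re-compilation hook from update_op_cost_in_tfrt_mlir.h
// above: costs measured at runtime are written back onto the fallback ops of
// an already-lowered module. `module` and `cost_recorder` are assumed inputs.
#include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h"

static void RefreshRecordedOpCosts(
    mlir::ModuleOp module,
    const tensorflow::tfrt_stub::CostRecorder& cost_recorder) {
  tensorflow::tfrt_compiler::UpdateOpCostInTfrtMlir(module, cost_recorder);
}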
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ + +#include +#include +#include + +#include "absl/functional/function_ref.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/function/function.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime + +namespace tensorflow { + +struct FunctionBody; + +// Converts an MLIR `module` in TF dialect to TFRT's Binary Executable Format. +// If `fallback_state` is not null, the MLIR functions for XLA clusters in +// the form of XlaLaunch will be exported and added to the function library when +// needed. The nested functions will also be exported. If +// `added_xla_function_names` is not null, it will be populated with the names +// of the added XLA functions. +absl::Status ConvertTfMlirToBef( + const TfrtCompileOptions& options, mlir::ModuleOp module, + tfrt::BefBuffer* bef_buffer, tfrt_stub::ModelRuntimeContext& model_context, + tfrt_stub::FallbackState* fallback_state = nullptr, + std::vector* added_xla_function_names = nullptr); + +absl::Status ConvertTfMlirToRuntimeExecutable( + const TfrtCompileOptions& options, mlir::ModuleOp module, + absl::FunctionRef< + absl::Status(mlir::PassManager&, mlir::ModuleOp, + const tensorflow::TfrtPipelineOptions& options)> + emit_executable, + tfrt_stub::ModelRuntimeContext& model_context, + tfrt_stub::FallbackState* fallback_state = nullptr, + std::vector* added_xla_function_names = nullptr); + +std::unique_ptr GetTfrtPipelineOptions( + const TfrtCompileOptions& options); + +// Adds MLIR functions for XLA clusters to the function library. +absl::Status AddXlaFunctions( + tfrt_stub::FallbackState* fallback_state, mlir::ModuleOp mlir_module, + std::vector* added_xla_function_names = nullptr); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_IMPORT_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h new file mode 100644 index 00000000..95086564 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h @@ -0,0 +1,135 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
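// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. It shows the call shape of ConvertTfMlirToBef from
// translate/import_model.h above; the optional fallback state and XLA
// function-name outputs are left at their defaults, and all arguments are
// assumed to be constructed elsewhere.
#include "absl/status/status.h"
#include "tensorflow/compiler/mlir/tfrt/translate/import_model.h"

static absl::Status CompileModuleToBef(
    const tensorflow::TfrtCompileOptions& options, mlir::ModuleOp module,
    tfrt::BefBuffer* bef_buffer,
    tensorflow::tfrt_stub::ModelRuntimeContext& model_context) {
  return tensorflow::ConvertTfMlirToBef(options, module, bef_buffer,
                                        model_context);
}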
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_MLIR_TO_BYTECODE_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_MLIR_TO_BYTECODE_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" + +namespace mlrt { + +class ModuleEmitterContext; + +// Defines a custom attribute encoding registry. Users can register custom +// attribute encoding for their dialects in this registry. If no custom encoder +// is registered for a dialect, the default encoding with a limited support, the +// EncodeSimpleAttribute() below, will be used. +class AttributeEncoderRegistry { + public: + using EncoderFn = std::function( + const ModuleEmitterContext&, mlir::Attribute)>; + + void Register(absl::string_view dialect, EncoderFn encoder) { + encoders_[dialect] = std::move(encoder); + } + + // Returns the encoder for the specified dialect. It can be nullptr if it is + // not registered for this dialect. The returned reference will be invalidated + // if Register() is called. 
+ const EncoderFn* Get(absl::string_view dialect) const { + auto iter = encoders_.find(dialect); + if (iter != encoders_.end()) return &iter->second; + return nullptr; + } + + private: + absl::flat_hash_map encoders_; +}; + +class ModuleEmitterContext { + public: + explicit ModuleEmitterContext( + const AttributeEncoderRegistry* attribute_encoder_registry) + : attribute_encoder_registry_(*attribute_encoder_registry) {} + + void AddKernelName(std::string name) { + AddData(std::move(name), kernels_, kernel_id_map_); + } + + int GetKernelId(llvm::StringRef name) const { + return kernel_id_map_.at(name); + } + + absl::Status AddAttribute(mlir::Operation* op, mlir::Attribute attr); + + int GetAttributeId(mlir::Attribute attr) const { + return attribute_id_map_.lookup(attr); + } + + int AddFunction(mlir::func::FuncOp func); + + int GetFunctionId(absl::string_view name) const { + return function_name_id_map_.at(name); + } + + absl::Span kernels() const { return kernels_; } + absl::Span attributes() const { return attributes_; } + absl::Span functions() const { return functions_; } + + private: + int AddData(std::string data, std::vector& data_vector, + absl::flat_hash_map& data_map) { + auto iter = data_map.find(data); + if (iter != data_map.end()) return iter->second; + + int id = data_vector.size(); + data_map[data] = id; + data_vector.push_back(std::move(data)); + return id; + } + + absl::StatusOr DefaultEncodeAttribute(mlir::Attribute attr); + + const AttributeEncoderRegistry& attribute_encoder_registry_; + + std::vector kernels_; + absl::flat_hash_map kernel_id_map_; + + std::vector attributes_; + llvm::DenseMap attribute_id_map_; + absl::flat_hash_map attribute_data_id_map_; + + std::vector functions_; + absl::flat_hash_map function_name_id_map_; +}; + +// Encodes a few simple attributes. Users can use this function in their custom +// attribute encoder. +std::optional EncodeSimpleAttribute( + const ModuleEmitterContext& module_context, mlir::Attribute attr); + +absl::StatusOr EmitExecutable( + const AttributeEncoderRegistry& attribute_encoder_registry, + mlir::ModuleOp module); + +} // namespace mlrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_MLIR_TO_BYTECODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h new file mode 100644 index 00000000..6140c711 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h @@ -0,0 +1,119 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
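// Editor's illustrative sketch -- not part of the vendored headers in this
// patch. It registers a custom attribute encoder for a hypothetical
// "my_dialect" and emits bytecode via EmitExecutable from mlir_to_bytecode.h
// above. The encoder signature (absl::StatusOr<std::string>) and the buffer
// payload type are inferred from the surrounding declarations and may differ
// in detail.
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h"

static absl::StatusOr<mlrt::bc::Buffer> EmitWithCustomEncoder(
    mlir::ModuleOp module) {
  mlrt::AttributeEncoderRegistry registry;
  registry.Register(
      "my_dialect",
      [](const mlrt::ModuleEmitterContext& ctx,
         mlir::Attribute attr) -> absl::StatusOr<std::string> {
        // Fall back to the default encoding where it is supported.
        if (auto encoded = mlrt::EncodeSimpleAttribute(ctx, attr)) {
          return *encoded;
        }
        return absl::InvalidArgumentError("unsupported my_dialect attribute");
      });
  return mlrt::EmitExecutable(registry, module);
}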
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_TEST_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_TEST_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" +#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/value.h" +#include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h" +#include "tensorflow/core/tfrt/utils/tensor_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/host_allocator.h" // from @tf_runtime +#include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/support/string_util.h" // from @tf_runtime +#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime +#include "tfrt/tensor/dense_tensor_utils.h" // from @tf_runtime + +namespace mlrt { +namespace testing { + +absl::StatusOr EncodeAttribute(const tensorflow::AttrValue& attr); + +absl::Status EncodeAttributes(AttributeTable& attributes, + const tensorflow::AttrValueMap& attr_map); + +absl::StatusOr>> +CreateKernelAndAttrs(int num_inputs, int num_outputs, + mlrt::ExecutionContext& exec_ctx, mlrt::bc::Buffer* buffer, + const tensorflow::AttrValueMap& attrs = {}); + +template +absl::Status TestMlrtKernel( + absl::string_view kernel_name, absl::Span regs, + tfrt::HostContext* host, int num_inputs, int num_outputs, + absl::Span expected_outputs, + mlrt::KernelRegistry* registry, bool approx_equal = false, + const tensorflow::AttrValueMap& attrs = {}) { + mlrt::ExecutionContext execution_context(nullptr); + + mlrt::bc::Buffer buffer; + TF_ASSIGN_OR_RETURN(auto kernel_and_attrs, + CreateKernelAndAttrs(num_inputs, num_outputs, + execution_context, &buffer, attrs)); + + tensorflow::tfrt_stub::SyncResourceState sync_resource_state; + tfrt::AddSyncContext(execution_context, *host, &sync_resource_state); + + auto kernel_fn = registry->Get(kernel_name); + mlrt::KernelFrame::State state(regs, kernel_and_attrs.second, + &execution_context); + mlrt::KernelFrame frame(&state); + frame.set_kernel(kernel_and_attrs.first); + + kernel_fn(frame); + + TF_RETURN_IF_ERROR(execution_context.status()); + + for (int i = 0, j = num_inputs; i < expected_outputs.size(); ++i, ++j) { + const auto& expected_output = expected_outputs[i]; + auto expected_dht = tfrt::ConvertTfTensorToDHT(expected_output); + if (!expected_dht) { + return absl::InternalError(tfrt::StrCat(expected_dht.takeError())); + } + + if (!approx_equal) { + if (!tfrt::TensorEqual(regs[j].Get(), + *expected_dht)) { + return absl::InternalError( + absl::StrCat("wrong result for ", kernel_name)); + } + } else { + if (!tfrt::TensorApproxEqual(regs[j].Get(), + 
*expected_dht)) { + return absl::InternalError( + absl::StrCat("wrong result for ", kernel_name)); + } + } + } + + return absl::OkStatus(); +} + +} // namespace testing +} // namespace mlrt + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_MLRT_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h new file mode 100644 index 00000000..e75fdc35 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -0,0 +1,190 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_TFRT_COMPILE_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_TFRT_COMPILE_OPTIONS_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class BackendCompiler; + +enum class TfrtDeviceInfraTarget { + kCpu, // CPU only, no device support. + kTpurt, // Target TPURT dialect and kernels. + kTfFallback, // Target TPU kernels in TF Fallback. + kBridgeFallback, // TPU support but choose kTpurt or kTfFallback depending on + // whether the graph has unsupported feature in Bridge. + kGpu, // Target GPU specific compiler passes and runtime + // initializations. +}; + +std::ostream& operator<<(std::ostream& os, TfrtDeviceInfraTarget device_target); + +struct TfrtCompileOptions { + std::string saved_model_dir; + // TODO(tfrt-devs): Ideally, compiler should make the decision where + // to place the variable. + std::string variable_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + std::string default_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + + // Enable compiler optimization in TFRT dialect. + bool enable_optimizer = true; + + // If true, run grappler passes before compiling. + bool enable_grappler = true; + + // Graph rewrite options that will be applied on GraphDef before converting to + // MLIR. + GraphOptions graph_options; + + // Force data format for all layout sensitive operations, eg. setting it to + // "NHWC" will changes all data format in the graph to "NHWC" by inserting + // or removing related tf.Transpose op. Currently the supported formats are + // "NHWC" and "NCHW". + // + // TODO(tfrt-devs): Ideally compiler should figure out whether the + // data format should be changed, instead of controlled by users. + std::string force_data_format; + + // The target device infrastructure to use. This will trigger target specific + // compiler passes and runtime initialization. + TfrtDeviceInfraTarget device_target = TfrtDeviceInfraTarget::kCpu; + + // The custom compiler for device compilation. Instead of using the enum above + // to choose predefined device target, users can use this `backend_compiler` + // to inject their customized implementation. 
+ BackendCompiler* backend_compiler = nullptr; + + // If true, use the fused TPU compile_and_execute kernel, which performs all + // TPU inference related operations, e.g. core selection, h2d/d2h transfers, + // compile and execute. + bool tpu_fuse_ops = false; + + // If true, resource gather ops in the device graph are moved to host graphs + // in order to saved TPU memory usage. This option is experimental. + bool tpu_move_resource_gather_to_host = false; + + // The threshold in bytes that controls whether a resource gather op on TPU + // should be moved to host. A negative value means there is no threshold. This + // option is experimental. + int64_t tpu_gather_table_width_threshold_bytes = -1; + + // If true, fallback executeops that produce inputs to tpu program will use + // tpu host allocator. This options is experimental. + bool use_tpu_host_allocator_for_inputs = false; + + // To allow unpadded batch for TPU execution. + enum class TpuAllowUnpaddedBatch { + // Disable this feature. + kDisabled, + // Enable this feature when in-graph batching is detected. + kAuto, + // Force to enable this feature. + kEnforced, + }; + TpuAllowUnpaddedBatch tpu_allow_unpadded_batch = + TpuAllowUnpaddedBatch::kDisabled; + + // If true, the compiler will try to hoist invariant ops (e.g., const ops and + // their non-side-effecting consumers) to loading phase, which avoids the + // runtime cost during later running. + // TODO(tfrt-devs): Set the default value to true after testing as it is + // supposed to be turned on by default. + bool hoist_invariant_ops = false; + + // If true, get_resource_op will be fused during hoisting. + bool fuse_get_resource_ops_in_hoisting = true; + + // If true, the compiler will try to sink in the invariant ops (e.g. const + // ops, var handle ops, etc.) to the nested function (e.g. batch function) to + // facilitate invariant ops hoisting. + // TODO(tfrt-devs): Set the default value to true after testing as it is + // supposed to be turned on by default. + bool sink_in_invariant_ops = false; + + // This flag behaves differently for TFRT and MLRT. + // For TFRT, if true, tf.While's iterations will be parallelized on a + // best-effort basis. This is currently experimental. MLRT attempts to convert + // tf.while to tf_mlrt.map_fn regardless of this flag. For tf.While that + // cannot be converted tf_mlrt.map_fn, MLRT try to parallelize tf.while's + // iterations on a best-effort basis. + bool enable_while_parallel_iterations = false; + + // The cost threshold to decide whether a sequence of operations is cheap, and + // then whether it can be executed inline. If the cost is smaller than the + // threshold, it will be considered as cheap operations. Since the cost must + // be positive integers, setting the threshold to 1 makes all operations + // expensive. + uint64_t cost_threshold = 1; + + // The minimum number of batch threads. This number provides a lower bound on + // the number of batch threads on top of what is specified in the model. If + // the number of batch threads is too small (e.g. smaller than the number of + // parallel hardware accelerator available), it can lead to under utilization + // of resources. + int64_t min_num_batch_threads = 1; + + // The minimum of the maximum number of enqueued batches. This number provides + // a lower bound on top of what is specified in the model. If the number of + // max_enqueued_batches is too small, it can lead to under utilization of + // resources. 
+ int64_t min_max_enqueued_batches = 1; + + // The policy used by a BatchScheduler to pad (or split) batches. + std::string batch_padding_policy; + + // Batching parameters to be rewritten in the existing BatchFunction ops. + BatchingOptions batch_options; + + // If true, streams with inter data dependencies will be preferred to be + // merged for inline execution. + bool merge_inter_dependent_streams = true; + + // Whether to enable the DecomposeResourceOpsPass. + bool decompose_resource_ops = true; + + // Whether to compile to sync TFRT dialect. + bool compile_to_sync_tfrt_dialect = false; + + // Whether to use gpurt.compile_and_execute for GPU. + // TODO(b/294895431): Remove the flag and default to the fused op. + bool use_gpu_compile_and_execute_op = false; + + // If true, MLIR module will be serialized to aot_packages. + bool serialize_mlir_module_to_aot_packages = false; + + // Serialized MLIR module file under aot_packages. + std::string aot_mlir_module_file; + + // If true, BEF will be serialized to aot_packages. + bool serialize_bef_to_aot_packages = false; + + // Serialized BEF file under aot_packages. + std::string aot_bef_file; +}; + +std::ostream& operator<<(std::ostream& os, const TfrtCompileOptions& options); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_TRANSLATE_TFRT_COMPILE_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/utils/export.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/utils/export.h new file mode 100644 index 00000000..84f0e272 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/utils/export.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_EXPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_EXPORT_H_ + + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/framework/function.pb.h" + +namespace tensorflow { + +// Exports every function in `module` into `tensorflow.FunctionDef` and calls +// `callback` for each `tensorflow.FunctionDef`. Modifies `module` in place to +// be suitable for FunctionDef export. +absl::Status ExportFunctionDefs( + mlir::ModuleOp module, + absl::AnyInvocable callback, + bool export_tf_original_func_name = true); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/utils/host_context.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/utils/host_context.h new file mode 100644 index 00000000..7b2e143d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tfrt/utils/host_context.h @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
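Illustrative sketch only, not part of the vendored header or this patch: populating the TfrtCompileOptions struct declared above. The field values are arbitrary examples rather than recommended settings.

tensorflow::TfrtCompileOptions MakeExampleCpuCompileOptions() {
  tensorflow::TfrtCompileOptions options;
  options.device_target = tensorflow::TfrtDeviceInfraTarget::kCpu;
  options.enable_grappler = true;       // run grappler passes before compiling
  options.hoist_invariant_ops = true;   // opt in; the header defaults to false
  options.cost_threshold = 1024;        // op sequences cheaper than this are inlined
  options.min_num_batch_threads = 4;    // lower bound on batch threads
  return options;
}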
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_HOST_CONTEXT_H_ +#define TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_HOST_CONTEXT_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "tfrt/host_context/host_context.h" // from @tf_runtime + +namespace tensorflow { + +// The name of the default host device for running fallback kernels. +ABSL_CONST_INIT extern const char* const kDefaultHostDeviceName; + +std::unique_ptr CreateSingleThreadedHostContext(); +std::unique_ptr CreateMultiThreadedHostContext( + int64_t num_threads); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TFRT_UTILS_HOST_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h new file mode 100644 index 00000000..1241a73d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the TFFramework dialect. 
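Illustrative sketch only, not part of this patch: constructing a host context through the helpers declared in utils/host_context.h above. The thread count is an arbitrary example value.

std::unique_ptr<tfrt::HostContext> MakeExampleHostContext(bool single_threaded) {
  // Single-threaded contexts are convenient for deterministic unit tests; the
  // multi-threaded variant takes an explicit worker-thread count.
  if (single_threaded) return tensorflow::CreateSingleThreadedHostContext();
  return tensorflow::CreateMultiThreadedHostContext(/*num_threads=*/4);
}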
+// +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ + +#include "absl/status/status.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeSupport.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.h.inc" +#include "tensorflow/core/protobuf/error_codes.pb.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +/// OpKernelContextType corresponds to C++ class OpKernelContext defined in +/// tensorflow/core/framework/op_kernel.h +class OpKernelContextType + : public Type::TypeBase { + public: + using Base::Base; + static constexpr StringLiteral name = + "kernel_gen.tf_framework.op_kernel_context"; +}; + +class JITCallableType + : public Type::TypeBase { + public: + using Base::Base; + static constexpr StringLiteral name = "kernel_gen.tf_framework.jit_callable"; +}; + +absl::StatusCode ConvertAttrToEnumValue(ErrorCode error_code); + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.h.inc" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h.inc" + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h new file mode 100644 index 00000000..8fa1f26d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +//===- kernel_creator.h -----------------------------------------*- C++ -*-===// +// +// This file declares the function to compile a TF kernel function to gpu +// binary (hsaco for AMD, cubin for NVIDIA) or to a gpu binary with host side. 
+// +//===----------------------------------------------------------------------===// +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { +namespace kernel_gen { + +// Parses tf_code to create a module. An MLIRContext is taken in case any +// unexpected dialects are needed. +absl::StatusOr> SetupContextAndParseModule( + mlir::MLIRContext& context, llvm::StringRef tf_code); + +// Converts TF code to LLVM with or without GPU support. +absl::StatusOr> GenerateKernelForHloCode( + mlir::MLIRContext& context, llvm::StringRef tf_code, + llvm::ArrayRef architectures, + llvm::ArrayRef tile_sizes, llvm::ArrayRef unroll_factors, + bool print_ptx, bool print_llvmir, bool enable_ftz, bool index_64bit, + bool jit_compile, bool jit_i64_indexed_for_large_tensors, + bool apply_cl_options); + +} // namespace kernel_gen +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_KERNEL_CREATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h new file mode 100644 index 00000000..66c84df4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_FRAMEWORK_C_INTERFACE_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_FRAMEWORK_C_INTERFACE_H_ + +#include +#include + +#include "mlir/ExecutionEngine/RunnerUtils.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc( + void* op_kernel_ctx, size_t num_elements, size_t element_size, + int32_t output_index, int32_t num_candidates, + int32_t* candidate_input_indices); + +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc( + void* op_kernel_ctx, void* ptr); + +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_report_error( + void* op_kernel_ctx, int32_t error_code, char* msg); + +extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_jit_compile( + void* op_kernel_ctx, char* code, int64_t num_tile_sizes, + int64_t* tile_sizes_ptr, int64_t num_unroll_factors, + int64_t* unroll_factors_ptr, bool enable_ftz, bool index_64bit); + +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_jit_execute( + void* op_kernel_ctx, void* callable, void* result, int64_t num_args, + void* args_ptr); + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_FRAMEWORK_C_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h new file mode 100644 index 00000000..54d8b0dd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_gpu_runtime_wrappers.h @@ -0,0 +1,98 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_GPU_RUNTIME_WRAPPERS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_GPU_RUNTIME_WRAPPERS_H_ + +#include "absl/container/flat_hash_map.h" +#include "mlir/ExecutionEngine/RunnerUtils.h" // from @llvm-project +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/hash.h" +#include "tsl/platform/thread_annotations.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif +#if TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#endif + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +class GPURuntimeCache : public tensorflow::ResourceBase { + public: +#if GOOGLE_CUDA + using GPUModule = CUmodule; + using GPUFunction = CUfunction; +#endif +#if TENSORFLOW_USE_ROCM + using GPUModule = hipModule_t; + using GPUFunction = hipFunction_t; +#endif + + ~GPURuntimeCache() override; + static constexpr const char* kDefaultResourceName = "mlir-gpu-runtime-cache"; + static absl::Status Create(GPURuntimeCache** dst); + std::string DebugString() const override; + + // Assumes that no two modules are loaded from the same memory location over + // the lifetime of this cache. This allows to use the pointer as a key. All + // modules are unloaded on destruction of this cache. + GPUModule LookupOrLoadModule(void* data); + + GPUFunction LookupOrGetFunction(GPUModule module, const char* kernel_name); + + private: + struct FunctionKey { + GPUModule module; + const char* kernel_name; + + friend bool operator==(const FunctionKey& lhs, const FunctionKey& rhs) { + return lhs.module == rhs.module && lhs.kernel_name == rhs.kernel_name; + } + + struct Hash { + size_t operator()(const FunctionKey& key) const { + return tsl::Hash64Combine(tsl::hash()(key.module), + tsl::Hash64(key.kernel_name)); + } + }; + }; + + tensorflow::mutex mu_; + absl::flat_hash_map gpu_module_by_data_ptr_ + TF_GUARDED_BY(mu_); + absl::flat_hash_map + gpu_function_by_module_and_name_ TF_GUARDED_BY(mu_); +}; + +// Implements a C wrapper around the TensorFlow runtime and CUDA (or ROCm) +// library that allows launching a kernel on the current device and stream from +// a binary blob for the module and function name. +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_launch_kernel( + void* ctx, void* module_blob, char* kernel_name, intptr_t gridX, + intptr_t gridY, intptr_t gridZ, intptr_t blockX, intptr_t blockY, + intptr_t blockZ, void** params); + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_GPU_RUNTIME_WRAPPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h new file mode 100644 index 00000000..15d105ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h @@ -0,0 +1,59 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
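Illustrative sketch only, not part of this patch: the lookup flow of the GPURuntimeCache declared above, assuming a CUDA or ROCm build in which GPUModule and GPUFunction are defined. The module blob and kernel name are hypothetical, and resource reference counting is omitted for brevity.

absl::Status LoadExampleGpuKernel(void* module_blob, const char* kernel_name) {
  using mlir::kernel_gen::tf_framework::GPURuntimeCache;
  GPURuntimeCache* cache = nullptr;
  if (absl::Status s = GPURuntimeCache::Create(&cache); !s.ok()) return s;
  // The blob pointer doubles as the cache key, per the comment above.
  GPURuntimeCache::GPUModule module = cache->LookupOrLoadModule(module_blob);
  GPURuntimeCache::GPUFunction fn =
      cache->LookupOrGetFunction(module, kernel_name);
  (void)fn;  // Normally handed on to a launch call such as
             // _mlir_ciface_tf_launch_kernel.
  return absl::OkStatus();
}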
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_JIT_CACHE_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_JIT_CACHE_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/thread_annotations.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +class JITCache : public tensorflow::ResourceBase { + public: + static constexpr const char* kDefaultResourceName = "mlir-jit-cache"; + static absl::Status Create(JITCache** dst); + + std::string DebugString() const override; + ExecutionEngine* LookupOrCompile( + std::string code, + std::function>()> + compile_callback); + size_t Size(); + + private: + tensorflow::mutex mu_; + absl::flat_hash_map> + execution_engine_by_key_ TF_GUARDED_BY(mu_); +}; + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TF_JIT_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h new file mode 100644 index 00000000..45e248ce --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_PASSES_H_ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/GPU/IR/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +#define GEN_PASS_DECL_TFKERNELTOLLVMPASS +#define GEN_PASS_DECL_EMBEDTFFRAMEWORKPASS +#define GEN_PASS_DECL_REWRITETFFRAMEWORKASSERT +#define GEN_PASS_DECL_FUNCTOJITINVOCATIONPASS +#define GEN_PASS_DECL_BUFFERREUSEPASS +#define GEN_PASS_DECL_SHAPETODESCRIPTORSPASS +#define GEN_PASS_DECL_KERNELGENFINALBUFFERIZEPASS +#define GEN_PASS_DECL_GPUKERNELTOBLOBPASS +#define GEN_PASS_DECL_PARALLELLOOPSTOSEQUENTIAL +#define GEN_PASS_DECL_PROPAGATETFABIKNOWLEDGETOKERNELS +#define GEN_PASS_DECL_PROPAGATESHAPEKNOWLEDGETOKERNELS +#define GEN_PASS_DECL_FUSEINNERPARALLELLOOPSPASS +#define GEN_PASS_DECL_COPYCLEANUPPASS + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +// Pass to replace some of the Standard ops with TF Framework ops. +// * adds tf_framework::OpKernelContextType argument to the function +// * std.alloc becomes tf_framework.alloc_raw +// * std.dealloc becomes tf_framework.dealloc_raw +// * std.assert becomes tf_framework.assert +std::unique_ptr> CreateEmbedTFFrameworkPass(); + +// Pass to convert tf_framework.assert operations to calls to +// tf_framework.report_error and create the required control flow to abort the +// function on failed execution. +std::unique_ptr> CreateRewriteTFFrameworkAssert(); + +} // namespace tf_framework + +namespace transforms { + +// Pass to find and annotate candidates for buffer reuse. +std::unique_ptr> CreateBufferReusePass(); + +// Pass to rewrite all functions to JIT invocations through the TF +// framework. +std::unique_ptr> CreateFuncToJITInvocationPass( + llvm::ArrayRef tile_sizes = {}, + llvm::ArrayRef unroll_factors = {}, bool enable_ftz = false, + bool index_64bit = false, bool cpu_codegen = false, + bool jit_i64_indexed_for_large_tensors = false); + +// Pass for applying LLVM legalization patterns. +std::unique_ptr> CreateTFKernelToLLVMPass( + mlir::StringRef blob_annotation = {}); + +// Pass to tranform shape computations in shape dialect to standard and scf +// using memref descriptors. +std::unique_ptr> CreateShapeToDescriptorsPass(); + +// Pass to convert scf::ParallelOp to scf::ForOp. +std::unique_ptr> CreateParallelLoopsToSequential(); + +// Pass to annotate GPU Module with its PTX. +std::unique_ptr> CreateGpuKernelToBlobPass( + mlir::StringRef blob_annotation = {}, + ArrayRef architectures = {}, bool print_ptx = false, + bool print_llvmir = false, bool enable_ftz = false); + +// Pass to propagate tensorflow runtime ABI knowledge across kernel boundaries. +std::unique_ptr> +CreatePropagateTfAbiKnowledgeToKernels(); + +// Pass to propagate shape equalities across kernel boundaries. +std::unique_ptr> +CreatePropagateShapeKnowledgeToKernels(); + +/// Greedily maps loops to GPU hardware dimensions. +std::unique_ptr> CreateMapParallelLoopsPass(); + +/// We need to direct fusion to the inner loops. This cannot be done with +/// a passmanager alone ATM, as nested pass managers require operations to +/// be closed from above. 
+std::unique_ptr> +CreateFuseInnerParallelLoopsPass(); + +// Pass to remove copies which are consumed by a GenericOp. +std::unique_ptr> CreateCopyCleanupPass(); + +std::unique_ptr> CreateKernelgenFinalBufferizePass(); + +} // namespace transforms + +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h new file mode 100644 index 00000000..e85d14d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_REWRITERS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_REWRITERS_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { +namespace bufferization { +class BufferizeTypeConverter; +} +class ConversionTarget; +class LLVMTypeConverter; +class MLIRContext; +class RewritePatternSet; +class TypeConverter; + +namespace kernel_gen { +namespace tf_framework { + +/// Collects a set of patterns to convert from the TF Framework dialect to LLVM. +void PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter *converter, + RewritePatternSet *patterns); + +/// Collects a set of patterns to rewrite functions for use with TF framework +/// and also replace `alloc`, `dealloc` and `assert`. +void PopulateEmbedTFFrameworkPatterns(RewritePatternSet *patterns); +void PopulateEmbedTFFrameworkAssertPattern(RewritePatternSet *patterns); + +} // namespace tf_framework + +namespace transforms { + +/// Collects a set of patterns that bufferize operations from the standard and +/// other dialects. +void populateExtraBufferizeDialects(DialectRegistry ®istry); +void populateExtraBufferizePatterns(ConversionTarget &target, + MLIRContext *context, + TypeConverter *converter, + RewritePatternSet *patterns); + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_REWRITERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h new file mode 100644 index 00000000..e0b67b73 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
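Illustrative sketch only, not part of this patch: wiring a few of the kernel_gen pass factories declared above into a PassManager. The ordering and the module/function nesting chosen here are assumptions, not the upstream kernel-generation pipeline.

void AddExampleKernelGenPasses(mlir::PassManager& pm) {
  namespace kg = mlir::kernel_gen;
  pm.addPass(kg::tf_framework::CreateEmbedTFFrameworkPass());
  pm.addNestedPass<mlir::func::FuncOp>(kg::transforms::CreateBufferReusePass());
  pm.addPass(kg::transforms::CreateShapeToDescriptorsPass());
  pm.addPass(kg::transforms::CreateTFKernelToLLVMPass());
}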
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_UTILS_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace transforms { + +// Attempts to find function symbol in the module, adds it if not found. +FlatSymbolRefAttr GetOrInsertLLVMFunction(StringRef func_name, Type func_type, + Operation* op, OpBuilder* b); + +// Attemts to find a global string constant in the module, adds it if not found. +Value CreateOrFindGlobalStringConstant(Location loc, StringRef global_name, + StringRef content, OpBuilder* builder); + +// Generates a global name with the format "base_hash(content)". +std::string GetGlobalName(StringRef base, StringRef content); + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/optimize/quantization_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/optimize/quantization_utils.h new file mode 100644 index 00000000..aa22d546 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tools/optimize/quantization_utils.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ + +#include +#include + +namespace tflite_migration { +namespace optimize { +namespace utils { + +template +std::vector SymmetricBiasQuantize(const float* data, + uint64_t num_elements, + const std::vector& scales); + +std::vector SymmetricQuantizeFloatsToInt16(const float* data, + uint64_t num_elements, + float scaling_factor); + +// Quantize the values given an array of scales. 
+void SymmetricPerChannelQuantizeValues(const float* input, + const std::vector& scales_inv, + const std::vector& dimension, + int32_t channel_dim_index, + std::vector* output_value); + +} // namespace utils +} // namespace optimize +} // namespace tflite_migration + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tf_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tf_passes.h new file mode 100644 index 00000000..53388a99 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tf_passes.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TOSA_TF_PASSES_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace tosa { + +struct TOSATFLegalizationPipelineOptions + : public PassPipelineOptions {}; + +// Legalizes TF dialect(s) to Tosa. +void createTFtoTOSALegalizationPipeline( + OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts); + +void registerTFtoTOSALegalizationPipeline(); + +} // namespace tosa +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TF_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tf_tfl_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tf_tfl_passes.h new file mode 100644 index 00000000..93a67f9c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tf_tfl_passes.h @@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TF_TFL_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TOSA_TF_TFL_PASSES_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace tosa { + +struct TOSATFTFLLegalizationPipelineOptions + : public PassPipelineOptions { + bool dequantize_tfl_softmax = false; +}; + +// Legalizes TF dialect(s) to Tosa. 
+void createTFTFLtoTOSALegalizationPipeline( + OpPassManager& pm, const TOSATFTFLLegalizationPipelineOptions& opts); + +void registerTFTFLtoTOSALegalizationPipeline(); + +} // namespace tosa +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TF_TFL_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tfl_passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tfl_passes.h new file mode 100644 index 00000000..96d3cabf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/tfl_passes.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_ + +#include +#include + +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace tosa { + +struct TOSATFLLegalizationPipelineOptions + : public PassPipelineOptions { + ArrayRef disabled_patterns; + ArrayRef enabled_patterns; + + PassOptions::Option target_compilation_backend{ + *this, "target-compilation-backend", + llvm::cl::desc("Whether targetting compilation backend"), + llvm::cl::init(false)}; + + PassOptions::Option dequantize_tfl_softmax{ + *this, "dequantize-tfl-softmax", + llvm::cl::desc("Dequantize the TFLite softmax"), llvm::cl::init(false)}; + + TOSATFLLegalizationPipelineOptions() { + disabled_patterns = std::nullopt; + enabled_patterns = std::nullopt; + } +}; + +// Legalizes TFL (TensorFlow lite) dialect(s) to Tosa. +void createTFLtoTOSALegalizationPipeline( + OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts); + +void registerTFLtoTOSALegalizationPipeline(); + +} // namespace tosa +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TFL_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h new file mode 100644 index 00000000..cfe06340 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -0,0 +1,313 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
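Illustrative sketch only, not part of this patch: running the TFL-to-TOSA legalization pipeline declared in tfl_passes.h above over a module that is assumed to already hold TFLite dialect ops.

mlir::LogicalResult LegalizeTflToTosaExample(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::tosa::TOSATFLLegalizationPipelineOptions opts;
  opts.dequantize_tfl_softmax = true;  // example: enable the softmax option
  mlir::tosa::createTFLtoTOSALegalizationPipeline(pm, opts);
  return pm.run(module);
}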
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H_ +#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H_ + +#include + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +// This file contains legalizations common to mapping both TensorFlow and +// TensorFlow Lite to TOSA. +// +// Conversion functions return None on a failure or result value on success. +// Callers must check and return a LogicalResult failure on nullptr. +// +// For these functions, the framework-specific operands/attributes/defaults +// are already extracted and placed in a common form for lowering. + +namespace mlir { +namespace tosa { + +// Lowers the Pack operator to TOSA. +std::optional convertPackOp(PatternRewriter& rewriter, Operation* op, + Value result_value, + SmallVectorImpl& inputs, + int32_t axis); + +// Lowers the Unpack operator to TOSA. +std::optional> convertUnpackOp(PatternRewriter& rewriter, + Operation* op, + Value input_value, + int32_t axis); + +// Lowers the Select operator to TOSA. +std::optional convertSelectOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value condition_value, + Value x_value, Value y_value); + +// Lowers the ZerosLike operator to TOSA by creating a constant +// of the desired type and shape. +std::optional convertZerosLikeOp(PatternRewriter& rewriter, + Operation* op, Value result, + Value input); + +// Lowers the Mul operator to TOSA. For quantized types, this requires +// inserting rescale operators before and after the operation. +std::optional convertMultiplyOp(PatternRewriter& rewriter, Operation* op, + Value output_val, Value input_lhs_val, + Value input_rhs_val); + +// Lowers the SquaredDifference operator to TOSA. +std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, + Operation* op, Value result, + Value x, Value y); + +// Lowers the Round operator to TOSA. +std::optional convertRoundOp(PatternRewriter& rewriter, Operation* op, + Value result, Value input); + +// Lowers ConcatV2 to TOSA. +std::optional convertConcatV2Op(PatternRewriter& rewriter, Operation* op, + ShapedType result_type, + SmallVectorImpl& values, + int32_t axis); + +// Lowers SpaceToBatchND to TOSA. +std::optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value input_value, + Value block_shape_value, + Value paddings_value); + +// Lowers BatchToSpaceND to TOSA. +std::optional convertBatchToSpaceNDOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value input_value, + Value block_shape_value, + Value crops_value); + +// Lowers ExpandDims to TOSA. +std::optional convertExpandDimsOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value input_value, Value dim_value); + +// Lowers Squeeze to TOSA. +std::optional convertSqueezeOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value input_value, + SmallVectorImpl& squeeze_dims); + +// Lowers ELU to a sequence of TOSA ops. +std::optional convertEluOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value features_value); + +// Lowers Softmax to a sequence of TOSA ops. +std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value logits_value, + double beta); + +// Lowers LogSoftmax to a sequence of TOSA ops. 
+std::optional convertLogSoftmaxOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value logits_value); + +// Lowers SpaceToDepth to a sequence of TOSA ops. Supports NHWC. +std::optional convertSpaceToDepthOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value input_value, + IntegerAttr block_size_attr, + StringAttr data_format); + +// Lowers DepthToSpace to a sequence of TOSA ops. Supports NHWC. +std::optional convertDepthToSpaceOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value input_value, + IntegerAttr block_size_attr, + StringAttr data_format); + +// Lowers Split to a sequence of TOSA ops. +std::optional> convertSplitOp( + PatternRewriter& rewriter, Operation* op, Value result_value, + Value input_value, int32_t num_split, int32_t axis); + +// Lowers SplitV to a sequence of TOSA ops. +std::optional> convertSplitVOp( + PatternRewriter& rewriter, Operation* op, Value result_value, + Value input_value, SmallVectorImpl& size_split, int32_t axis); + +// Lowers StridedSlice to a sequence of TOSA ops. +std::optional convertStridedSliceOp( + PatternRewriter& rewriter, Operation* op, Value result_value, + Value input_value, Value begin_value, Value end_value, Value strides_value, + int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, + int32_t new_axis_mask, int32_t shrink_axis_mask); + +// Lowers FloorDiv to a sequence of TOSA operators. +std::optional convertFloorDivOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value lhs_value, + Value rhs_value); + +// Lowers FloorMod to a sequence of TOSA operators. +std::optional convertFloorModOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value lhs_value, + Value rhs_value); + +// Lowers FusedActivation to a sequence of TOSA ops. +std::optional convertFusedActivation(PatternRewriter& rewriter, + Operation* op, Value input_value, + StringAttr fused_activation_fn); + +// Helper function for implementing quantized divide by power-of-two in TOSA +// ops. +std::optional convertRoundingDivideByPOT(PatternRewriter& rewriter, + Operation* op, + Value input_value, + Value rshift_value); + +// Lowers ReduceAll to a sequence of TOSA ops. +std::optional convertReduceAllOp(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input_value, + ElementsAttr axes_elems); + +// Lowers ReduceAny to a sequence of TOSA ops. +std::optional convertReduceAnyOp(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input_value, + ElementsAttr axes_elems); + +// Lowers ReduceMin to a sequence of TOSA ops. +std::optional convertReduceMinOp(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input_value, + ElementsAttr axes_elems); + +// Lowers ReduceMax to a sequence of TOSA ops. +std::optional convertReduceMaxOp(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input_value, + ElementsAttr axes_elems); + +// Lowers ReduceProd to a sequence of TOSA ops. +std::optional convertReduceProdOp(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input_value, + ElementsAttr axes_elems); + +// Lowers ReduceSum to a sequence of TOSA ops. +std::optional convertReduceSumOp(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input_value, + ElementsAttr axes_elems); + +// Lowers ReduceMean to a sequence of TOSA ops. 
+std::optional convertReduceMeanOp(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input_value, + ElementsAttr axes_elem); + +// Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize. +std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, + RankedTensorType output_type, + Value input_value, StringRef mode, + bool align_corners, + bool half_pixel_centers); + +// Lowers Quantize to a sequence of TOSA quantization ops. +std::optional convertQuantizeOp(PatternRewriter& rewriter, Operation* op, + ShapedType output_type, + Value input_value, double scale, + int64_t zeropoint); + +// Lowers Dequantize to a sequence of TOSA dequantization ops. +std::optional convertDequantizeOp(PatternRewriter& rewriter, + Operation* op, ShapedType output_type, + Value input_value, + ArrayRef scale, + ArrayRef zeropoint, + int64_t dim); + +// Lowers FakeQuant to a sequence of TOSA quantization ops. +std::optional convertFakeQuantOp(PatternRewriter& rewriter, + Operation* op, ShapedType output_type, + Value input_value, double min, + double max, int64_t num_bits, + bool narrow_range); + +// Align to TF_MirrorPadOp::mode and TFL_MirrorPadOp::mode +enum class TFTFLMirrorPaddingType : uint32_t { + REFLECT = 0, + SYMMETRIC = 1, +}; + +std::optional convertMirrorPadCommon(PatternRewriter& rewriter, + Operation* op, + RankedTensorType output_type, + Value input, Value pad, + TFTFLMirrorPaddingType mode); + +// Lowers TensorFlow Conv2D to a sequence of TOSA quantization ops. +std::optional convertTFConv2DCommon( + PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, + Value input, Value filter, Value bias, ArrayAttr strides_attr, + ArrayAttr dilations_attr, ArrayAttr explicit_padding_attr, + StringRef padding_ref, StringRef data_format_ref); + +// Lowers TensorFlow and TensorFlow Lite Conv3D to a sequence of TOSA +// quantization ops. +std::optional convertConv3DCommon(PatternRewriter& rewriter, + Operation* op, ShapedType output_type, + Value input, Value filter, Value bias, + ArrayRef strides, + ArrayRef dilations, + StringRef padding_ref, + StringRef data_format_ref); + +// Preprocess TensorFlow Conv3D attributes prior to calling +// `convertConv3DCommon` +std::optional convertTFConv3DCommon( + PatternRewriter& rewriter, Operation* op, ShapedType output_type, + Value input, Value filter, Value bias, ArrayAttr strides_attr, + ArrayAttr dilations_attr, StringRef padding_ref, StringRef data_format_ref); + +// Lowers Gather operator to a sequence of TOSA ops. +std::optional convertGatherOp(PatternRewriter& rewriter, Operation* op, + Value params_value, Value indices_value, + int32_t batch_dims, int32_t axis, + bool tosaOnly = true); + +// Lowers GatherNd operator to a sequence of TOSA ops. +std::optional convertGatherNdOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value params_value, + Value indices_value); + +// Lowers OneHot operator to a sequence of TOSA ops. +std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, + Value result_value, Value indices_value, + Value on_value, Value off_value, + int32_t depth, int32_t axis); + +// Lowers Sign operator to a sequence of TOSA ops. +std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, + Value input, RankedTensorType output_type); + +// Lowers BroadcastTo operator to a sequence of TOSA ops. 
+std::optional convertBroadcastToOp(PatternRewriter& rewriter, + Operation* op, Value input, + Value shape); + +}; // namespace tosa +}; // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h new file mode 100644 index 00000000..c576504d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -0,0 +1,259 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/kernels/conv_grad_shape_utils.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace mlir { +namespace tosa { + +LogicalResult getDynamicDims(PatternRewriter& rewriter, Value value, + llvm::SmallVector& dims); + +std::optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, + Operation* op, + Value input_value, + ShapedType output_type, + llvm::ArrayRef dims); + +// Create a TOSA rescale op from TFLite scaling multiplier, scaling shift, zero +// points and rounding mode +Value buildRescale(PatternRewriter& rewriter, Operation* op, + ShapedType output_type, Value input_val, + int32_t scale_multiplier, int32_t scale_shit, + int64_t input_zp, int64_t output_zp, bool double_round, + bool scale32); + +// Create a TOSA rescale op from TFLite scaling, zero points and rounding mode +Value buildRescale(PatternRewriter& rewriter, Operation* op, + ShapedType output_type, Value input_val, double scale, + int64_t input_zp, int64_t output_zp, bool double_round, + bool scale32); + +// Removes the zero point and cast to int32, no need to handle roundings modes +Value removeZeroPointAndCastToInt32(PatternRewriter& rewriter, Operation* op, + 
Value input_val, int64_t input_zp); + +// Creates TOSA rescale op with int32 output +Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op, + Value input_val, int32_t input_scale_multiplier, + int32_t input_scale_shift, int64_t input_zp); + +// Creates TOSA rescale op with int32 output +Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op, + Value input_val, double input_scale, + int64_t input_zp); + +// Creates TOSA rescale op with int32 input +Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op, + ShapedType output_type, Value input_val, + double output_scale, int64_t output_zp); + +// Creates a TOSA rescale op based on conv2d parameters. +Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, + Value conv_val, ShapedType input_type, + ShapedType weight_type, ShapedType output_type); + +// Create a 8-bit TOSA TABLE constant tensor +Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, + double input_scale, int32_t input_zp, + double output_scale, int32_t output_zp, + std::function func); + +// Create a 16-bit TOSA TABLE constant tensor +Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, + std::function func, double min, + double max); + +// Create a 32-bit TOSA TABLE for Softmax Exp +void getTosaConst32bitSoftmaxExpTable(PatternRewriter& rewriter, Operation* op, + double beta, double input_scale, + Value& first_const, Value& second_const, + Value& third_const, Value& fourth_const); + +// Create 8 bit TOSA TABLE constant tensor for the RSqrt operator +Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp); + +// Create a 32-bit float constant operator from a float +Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op, + float val); + +// Create a 32-bit integer constant operator from an int +Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op, + int32_t val); + +// Create an expected bitwidth integer constant operator based on the type +// parameter. +Value getTosaConstTensorScalarInt(ImplicitLocOpBuilder& builder, Type type, + int64_t val); + +// Create a vector from a 32-bit value tensor. Returns vector size on success +// or -1 on error. +LogicalResult getVectorFromValue32(Value val, SmallVectorImpl& vec); + +// Calculates the TOSA padding values based on TF operators padded with +// SAME/VALID. +bool getPaddingValuesFromPadType(tensorflow::Padding tf_pad, + tensorflow::TensorFormat data_format_tf, + uint32_t first_filter_spatial_dim, + ShapedType input_type, ShapedType filter_type, + DenseI64ArrayAttr strides, + DenseI64ArrayAttr dilations, + PatternRewriter& rewriter, + DenseI64ArrayAttr& explicit_pad); + +// Calculates the TOSA padding values for explicit-padded TF operators. +DenseI64ArrayAttr getPaddingValuesFromExplicitPadAttr( + ArrayAttr explicit_pad, tensorflow::TensorFormat data_format_tf, + PatternRewriter& rewriter); + +// Calculates the TOSA padding values for transposeConv2d +bool getTransposeConv2dPaddingValues( + tensorflow::Padding tf_pad, tensorflow::TensorFormat data_format_tf, + uint32_t first_filter_spatial_dim, ShapedType input_type, + ShapedType filter_type, ShapedType output_type, DenseI64ArrayAttr strides, + PatternRewriter& rewriter, DenseI64ArrayAttr& explicit_pad); + +// Templated function to create a constant op for given type and shape. +// T: storage C type. +// Default template creates a constant tensor in T. 
+// To create INT48 TOSA constant, need to pass in llvm::APInt instead. +template +std::optional getConstTensor(PatternRewriter& rewriter, Operation* op, + ArrayRef vec, ArrayRef shape); + +// Check if scale32 mode is used for given output_element_type +bool isScale32(mlir::quant::UniformQuantizedType output_element_type); + +// Applies a set of patterns greedily to the specified function, then applies +// a cleanup to guarantee the function contract and constants are valid. This +// means patterns can performed shape inference while not altering immutable +// types. +LogicalResult ApplyPatternsWithShapeResolution( + func::FuncOp func, const FrozenRewritePatternSet& patterns); + +// Creates a TOSA operation and performs shape inference on the individual +// op. This allows shape inference during the TFLite to TOSA lowering. +template +TosaOp CreateOpAndInfer(ImplicitLocOpBuilder& builder, Type result_ty, + Args&&... args) { + auto op = builder.create(result_ty, args...); + + InferShapedTypeOpInterface shapeInterface = + dyn_cast(op.getOperation()); + if (!shapeInterface) return op; + + SmallVector returnedShapes; + if (shapeInterface + .inferReturnTypeComponents(op.getContext(), builder.getLoc(), + op->getOperands(), op->getAttrDictionary(), + op->getPropertiesStorage(), + op->getRegions(), returnedShapes) + .failed()) + return op; + + // We need to use the element type of the existing result type to generate + // the new result shaped type. This is because rescale can include a cast to + // different bit-width types and does not have a TypeAttr to define the + // target type. + auto result = op->getResult(0); + auto predictedShape = returnedShapes[0]; + auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty); + + // Compute the knowledge based on the inferred type. + auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); + inferredKnowledge.dtype = mlir::cast(result_ty).getElementType(); + inferredKnowledge.hasRank = predictedShape.hasRank(); + if (predictedShape.hasRank()) { + for (auto dim : predictedShape.getDims()) { + inferredKnowledge.sizes.push_back(dim); + } + } + + // Compute the new type based on the joined version. + auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge); + Type new_ty = + newKnowledge.hasRank + ? Type{tensorflow::GetTypeFromTFTensorShape( + llvm::ArrayRef(newKnowledge.sizes), newKnowledge.dtype)} + : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)}; + result.setType(new_ty); + return op; +} + +template +TosaOp CreateOpAndInfer(PatternRewriter& rewriter, Location loc, Type result_ty, + Args&&... args) { + ImplicitLocOpBuilder builder(loc, rewriter); + return CreateOpAndInfer(builder, result_ty, args...); +} + +template +void CreateReplaceOpAndInfer(PatternRewriter& rewriter, Operation* op, + Type result_ty, Args&&... 
args) { + auto result = + CreateOpAndInfer(rewriter, op->getLoc(), result_ty, args...); + rewriter.replaceOp(op, result->getResults()); +} + +void TrimQuantizedIntegerRangeMin(mlir::quant::UniformQuantizedType dtype, + int64_t& val_min); + +void TrimQuantizedIntegerRangeMax(mlir::quant::UniformQuantizedType dtype, + int64_t& val_max); + +void TrimQuantizedIntegerRange(mlir::quant::UniformQuantizedType dtype, + int64_t& val_min, int64_t& val_max); + +inline bool IsTFLDoubleRoundingMode() { +#if TFLITE_SINGLE_ROUNDING + return false; +#else + return true; +#endif // TFLITE_SINGLE_ROUNDING +} + +} // namespace tosa +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_LEGALIZE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/passes.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/passes.h new file mode 100644 index 00000000..de0872b6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/tosa/transforms/passes.h @@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H_ + +#include +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { + +namespace quant { +class QuantDialect; +} + +namespace quantfork { +class QuantizationForkDialect; +} + +namespace TFL { +class TFLDialect; +} + +namespace tosa { +class TosaDialect; + +void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns); +void populateLegalizeTFLPatterns(MLIRContext* ctx, RewritePatternSet& patterns); + +std::unique_ptr> createLegalizeTFPass(); +std::unique_ptr> createFuseBiasTFPass(); + +// `disabledPatterns` is a set of labels used to filter out input patterns with +// a debug label or debug name in this set. +// `enabledPatterns` is a set of labels used to filter out input patterns that +// do not have one of the labels in this set. 
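+// Editor's note: illustrative sketch only, not part of the upstream header.
+// A pipeline might register this pass roughly as follows; the nesting on
+// func::FuncOp and the empty pattern-name lists are assumptions made for the
+// example, with `pm` and `module` supplied by the caller:
+//
+//   mlir::PassManager pm(module.getContext());
+//   pm.addNestedPass<mlir::func::FuncOp>(mlir::tosa::createLegalizeTFLPass(
+//       /*disabled_patterns=*/{}, /*enabled_patterns=*/{}));
+//   if (failed(pm.run(module))) { /* handle legalization failure */ }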
+std::unique_ptr> createLegalizeTFLPass( + ArrayRef disabled_patterns = std::nullopt, + ArrayRef enabled_patterns = std::nullopt); + +std::unique_ptr> createRetainCallOnceFuncsPass(); +std::unique_ptr> createStripModuleMetadataPass(); +std::unique_ptr> createConvertTFLUint8Pass(); +std::unique_ptr> +createConvertFunctionMetadataPass(); +std::unique_ptr> createDequantizeTFLSoftmaxPass(); +std::unique_ptr> createLegalizeTFTFLPass(); +std::unique_ptr> createLowerComplexTypesPass(); +std::unique_ptr> createStripFunctionMetadataPass(); +std::unique_ptr> createStripQuantTypesPass(); +std::unique_ptr> createVerifyFullyConvertedPass(); +std::unique_ptr> createLegalizeTFLStatefulPass(); + +#define GEN_PASS_REGISTRATION +#define GEN_PASS_CLASSES +#define GEN_PASS_DECL_TOSALEGALIZETFPASS +#define GEN_PASS_DECL_TOSALEGALIZETFLPASS +#define GEN_PASS_DECL_TOSALEGALIZETFTFLPASS +#define GEN_PASS_DECL_TOSAFUSEBIASTFPASS +#define GEN_PASS_DECL_TOSACONVERTTFLUINT8PASS +#define GEN_PASS_DECL_TOSASTRIPQUANTTYPESPASS +#define GEN_PASS_DECL_TOSALOWERCOMPLEXTYPESPASS +#define GEN_PASS_DECL_TOSADEQUANTIZETFLSOFTMAXPASS +#define GEN_PASS_DECL_RETAINCALLONCEFUNCS +#define GEN_PASS_DECL_STRIPFUNCTIONMETADATA +#define GEN_PASS_DECL_STRIPMODULEMETADATA +#define GEN_PASS_DECL_VERIFYFULLYCONVERTED +#define GEN_PASS_DECL_CONVERTFUNCTIONMETADATA +#define GEN_PASS_DECL_TOSALEGALIZESTATEFULPASS + +#include "tensorflow/compiler/mlir/tosa/transforms/passes.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOSA_TRANSFORMS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/array_container_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/array_container_utils.h new file mode 100644 index 00000000..80fa14e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/array_container_utils.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ + +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::ArrayRef SpanToArrayRef(absl::Span span) { + return llvm::ArrayRef(span.data(), span.size()); +} + +template +inline llvm::MutableArrayRef SpanToMutableArrayRef(absl::Span span) { + return llvm::MutableArrayRef(span.data(), span.size()); +} + +template +inline absl::Span ArrayRefToSpan(llvm::ArrayRef ref) { + return absl::Span(ref.data(), ref.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_ARRAY_CONTAINER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/name_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/name_utils.h new file mode 100644 index 00000000..356b4d25 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/name_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ + +#include + +#include "mlir/IR/Location.h" // from @llvm-project + +namespace mlir { + +// Converts characters in name that are considered illegal in TensorFlow Node +// name to '.'. +void LegalizeNodeName(std::string& name); + +// Returns the TensorFlow node name associated with a location. +std::string GetNameFromLoc(Location loc); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_NAME_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/string_container_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/string_container_utils.h new file mode 100644 index 00000000..fb2fa06c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/mlir/utils/string_container_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { + +inline absl::string_view StringRefToView(llvm::StringRef ref) { + return absl::string_view(ref.data(), ref.size()); +} + +inline llvm::StringRef StringViewToRef(absl::string_view view) { + return llvm::StringRef(view.data(), view.size()); +} + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_STRING_CONTAINER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/common/datavec.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/common/datavec.h new file mode 100644 index 00000000..eff32f1f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/common/datavec.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_DATAVEC_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_DATAVEC_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensorrt { + +// Input/output data format for OpConverterTest::BuildAndRun(). +struct InputOutputData { + size_t TotalBytes() const { return tensor.TotalBytes(); } + string name; + Tensor tensor; +}; + +using DataVec = std::vector; + +} // namespace tensorrt +} // namespace tensorflow +#endif diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/common/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/common/utils.h new file mode 100644 index 00000000..0bc63ecd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/common/utils.h @@ -0,0 +1,175 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ + +#include +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +// Returns the compile time TensorRT library version information +// {Maj, Min, Patch}. 
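+// Editor's note: illustrative sketch only, not part of the upstream header.
+// The two version queries below are commonly compared to detect a mismatch
+// between the TensorRT headers TF-TRT was compiled against and the library
+// loaded at runtime, e.g.:
+//
+//   const auto linked = GetLinkedTensorRTVersion();
+//   const auto loaded = GetLoadedTensorRTVersion();
+//   if (std::get<0>(linked) != std::get<0>(loaded)) {
+//     // warn about a major-version mismatch
+//   }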
+std::tuple GetLinkedTensorRTVersion(); + +// Returns the runtime time TensorRT library version information +// {Maj, Min, Patch}. +std::tuple GetLoadedTensorRTVersion(); +} // namespace tensorrt +} // namespace tensorflow + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "third_party/tensorrt/NvInfer.h" + +#define ERROR_LOC __FILE__, ":", __LINE__ + +#define TFTRT_INTERNAL_ERROR_AT_NODE(node) \ + return errors::Internal("TFTRT::", __FUNCTION__, "\n", ERROR_LOC, \ + " failed to add TRT layer, at: ", node); + +#define TFTRT_RETURN_ERROR_IF_NULLPTR(ptr, node) \ + if (ptr == nullptr) { \ + TFTRT_INTERNAL_ERROR_AT_NODE(node); \ + } + +// Use this macro within functions that return a Status or StatusOR to check +// boolean conditions. If the condition fails, it returns an +// errors::Internal message with the file and line number. +#define TRT_ENSURE(x) \ + if (!(x)) { \ + return errors::Internal(ERROR_LOC, " TRT_ENSURE failure"); \ + } + +// Checks that a Status or StatusOr object does not carry an error message. +// If it does have an error, returns an errors::Internal instance +// containing the error message, along with the file and line number. For +// pointer-containing StatusOr, use the below TRT_ENSURE_PTR_OK macro. +#define TRT_ENSURE_OK(x) \ + if (!x.ok()) { \ + return errors::Internal(ERROR_LOC, " TRT_ENSURE_OK failure:\n ", \ + x.status().ToString()); \ + } + +// Checks that a StatusOrobject does not carry an error, and that the +// contained T* is non-null. If it does have an error status, returns an +// errors::Internal instance containing the error message, along with the file +// and line number. +#define TRT_ENSURE_PTR_OK(x) \ + TRT_ENSURE_OK(x); \ + if (*x == nullptr) { \ + return errors::Internal(ERROR_LOC, " pointer had null value"); \ + } + +namespace tensorflow { +namespace tensorrt { + +#define IS_TRT_VERSION_GE(major, minor, patch, build) \ + ((NV_TENSORRT_MAJOR > major) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH > patch) || \ + (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \ + NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build)) + +#define LOG_WARNING_WITH_PREFIX LOG(WARNING) << "TF-TRT Warning: " + +// Initializes the TensorRT plugin registry if this hasn't been done yet. +void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger); + +class IONamePrefixes { + public: + static constexpr const char* const kInputPHName = "TensorRTInputPH_"; + static constexpr const char* const kOutputPHName = "TensorRTOutputPH_"; +}; + +// Gets the binding index of a tensor in an engine. +// +// The binding index is looked up using the tensor's name and the profile index. +// Profile index should be set to zero, if we do not have optimization profiles. +Status GetTrtBindingIndex(const char* tensor_name, int profile_index, + const nvinfer1::ICudaEngine* cuda_engine, + int* binding_index); + +// Gets the binding index of a tensor in an engine. +// +// Same as above, but uses the network input index to identify the tensor. +Status GetTrtBindingIndex(int network_input_idx, int profile_index, + const nvinfer1::ICudaEngine* cuda_engine, + int* binding_index); +} // namespace tensorrt +} // namespace tensorflow + +namespace nvinfer1 { +// Prints nvinfer1::Dims or any drived type to the given ostream. 
Per GTest +// printing requirements, this must be in the nvinfer1 namespace. +inline std::ostream& operator<<(std::ostream& os, const nvinfer1::Dims& v) { + os << "nvinfer1::Dims["; + os << absl::StrJoin(std::vector(v.d, v.d + v.nbDims), ","); + os << "]"; + return os; +} + +// Returns true if any two derived nvinfer1::Dims type structs are equivalent. +inline bool operator==(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) { + if (rhs.nbDims != lhs.nbDims) { + return false; + } + for (int i = 0; i < lhs.nbDims; i++) { + if (rhs.d[i] != lhs.d[i]) { + return false; + } + } + return true; +} + +// Returns false if any 2 subclasses of nvinfer1::Dims are equivalent. +inline bool operator!=(const nvinfer1::Dims& lhs, const nvinfer1::Dims& rhs) { + return !(rhs == lhs); +} + +// Prints nvinfer1::INetworkDefinition* information to the given ostream. +inline std::ostream& operator<<(std::ostream& os, + nvinfer1::INetworkDefinition* n) { + os << "nvinfer1::INetworkDefinition{\n"; + std::vector layer_idxs(n->getNbLayers()); + std::iota(layer_idxs.begin(), layer_idxs.end(), 0); + os << absl::StrJoin(layer_idxs, "\n ", + [n](std::string* out, const int layer_idx) { + out->append(n->getLayer(layer_idx)->getName()); + }); + os << "}"; + return os; +} + +// Prints the TensorFormat enum name to the stream. +std::ostream& operator<<(std::ostream& os, + const nvinfer1::TensorFormat& format); + +// Prints the DataType enum name to the stream. +std::ostream& operator<<(std::ostream& os, const nvinfer1::DataType& data_type); + +} // namespace nvinfer1 + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h new file mode 100644 index 00000000..0a9ee702 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/algorithm_selector.h @@ -0,0 +1,121 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_ +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include +#include +#include + +#include "absl/types/optional.h" +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +// Implements core algorithm selection logic in a testable manner. The policy +// implemented depends on the given TRT version. We have this class because TRT +// interfaces make it difficult to directly test an IAlgorithmSelector +// implementation. 
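+// Editor's note: illustrative sketch only, not part of the upstream header.
+// For example, a test or caller can probe the selection policy for the
+// compile-time TensorRT version like this:
+//
+//   AlgorithmSelectorImpl impl(AlgorithmSelectorImpl::CompileTimeTRTVersion());
+//   if (impl.IsAlgorithmSelectorRequired()) {
+//     // attach a TftrtAlgorithmSelector to the builder configuration
+//   }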
+class AlgorithmSelectorImpl { + public: + using TRTVersion = std::array; + using ImplementationID = int64_t; + using TacticID = int64_t; + + static constexpr TRTVersion CompileTimeTRTVersion() { + return TRTVersion{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH, + NV_TENSORRT_BUILD}; + } + + explicit AlgorithmSelectorImpl( + const TRTVersion& version = CompileTimeTRTVersion()) + : version_(version) {} + + bool IsShuffleLayer(ImplementationID id) const; + + bool IsBannedTactic(TacticID id) const; + + // Returns true if the algorithm implementing the IShuffleLayer is acceptable. + bool AllowShuffleAlgorithm(TacticID tactic, nvinfer1::DataType input_dtype, + nvinfer1::TensorFormat input_format) const; + + bool IsTrtVersionGE(const TRTVersion& version) const; + + // Returns true if we know at compile time that the algorithm selector + // should be required. This is a conservative estimate. + bool IsAlgorithmSelectorRequired() const; + + static std::set GetBannedTRT72TuringTactics(); + + private: + TRTVersion version_; +}; + +// Implements the TRT IAlgorithmSelector interface. The method +// "selectAlgorithms" selects allowable algorithms for each layer, and +// "reportAlgorithms" summarizes the algorithms selected by TensorRT. +class TftrtAlgorithmSelector : public nvinfer1::IAlgorithmSelector { + private: + using TacticID = AlgorithmSelectorImpl::TacticID; + + // An index we should choose for all algorithms. Used for debugging. + std::optional fixed_algorithm_idx_; + + AlgorithmSelectorImpl selector_; + + public: + TftrtAlgorithmSelector(); + + // If the environment variable TF_TRT_FIXED_ALGORITHM_ID is empty, this + // function returns nullopt. Otherwise, it returns the specified number. + static std::optional GetFixedAlgorithmID(); + + // Returns true if the algorithm associated with context is acceptable. + bool AlgorithmPolicy(const nvinfer1::IAlgorithmContext& context, + const nvinfer1::IAlgorithm& alg) const; + + // This function fills the array "selection" with the indices of selected + // algorithm candidates from "algoChoices", each of which is an implementation + // for the kernel described by the given IAlgorithmContext. It should return a + // number in [0, nbChoices] indicating the number of selected indices. If 0 is + // returned, TensorRT will use its default selection mechanism. + int32_t selectAlgorithms(const nvinfer1::IAlgorithmContext& algoContext, + const nvinfer1::IAlgorithm* const* algoChoices, + int32_t nbChoices, + int32_t* selection) noexcept override; + + // Called by TensorRT to report choices it made. + void reportAlgorithms(const nvinfer1::IAlgorithmContext* const* algoContexts, + const nvinfer1::IAlgorithm* const* algoChoices, + int32_t nbAlgorithms) noexcept override; + + bool IsRequired() const { + return selector_.IsAlgorithmSelectorRequired() || + fixed_algorithm_idx_ != std::nullopt; + } +}; + +// Returns an initialized AlgorithmSelector if an algorithm selector is required +// for the current TRT version. Otherwise, returns nullptr. 
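+// Editor's note: illustrative sketch only, not part of the upstream header.
+// Typical use when configuring an engine build, assuming `builder_config` is a
+// nvinfer1::IBuilderConfig* owned by the caller:
+//
+//   auto selector = MaybeCreateAlgorithmSelector();
+//   if (selector != nullptr) {
+//     builder_config->setAlgorithmSelector(selector.get());
+//   }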
+std::unique_ptr MaybeCreateAlgorithmSelector(); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_ALGORITHM_SELECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h new file mode 100644 index 00000000..0607fb85 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ + +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +// These functions are internal implementation functions for the +// TRTOptimizationPass. + +// Performs segmentation and conversion on the given Grappler item. This method +// contains the core logic of the TRTOptimizationPass. +Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params, + grappler::GrapplerItem& grappler_item, + const std::vector& input_output_names, + grappler::Cluster* cluster, GraphDef* output); + +// Helper method for the conversion, expose for testing. +std::pair GetDeviceAndAllocator( + const grappler::Cluster* cluster, const EngineInfo& engine); + +// Helper method that registers `segment_graph` as a function to the function +// library in `graph`. +Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def, + Graph* graph, const string& engine_name); + +// Creates and serializes an ICudaEngine. Used only in is_dynamic_op=false, +// a.k.a. static engine mode. 
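+// Editor's note: illustrative sketch only, not part of the upstream header.
+// A hypothetical call site, where `params`, `engine_info`, `max_batch_size`,
+// `input_shapes`, `profile` and `cluster` are all owned by the optimization
+// pass driving the conversion:
+//
+//   string serialized_engine;
+//   TF_RETURN_IF_ERROR(CreateStaticEngine(params, engine_info, max_batch_size,
+//                                         input_shapes, &profile,
+//                                         &serialized_engine, cluster));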
+Status CreateStaticEngine(const TRTOptimizationPass::ConversionParams& params, + const EngineInfo& info, int max_batch_size, + const std::vector& input_shapes, + TrtShapeOptimizationProfile* profile, + string* segment_string, grappler::Cluster* cluster); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h new file mode 100644 index 00000000..9664f1a0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -0,0 +1,593 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/weights.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/lib/core/status.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +namespace convert { +using ::tsl::StatusOr; + +struct EngineConnection { + // Constructs a non-control edge. + EngineConnection(const string& outside, int out_id, int out_port, + const string& inside, int in_id, int in_port, + bool input_edge, int port) + : outside_node_name(outside), + outside_id(out_id), + outside_port(out_port), + inside_node_name(inside), + inside_id(in_id), + inside_port(in_port), + is_input_edge(input_edge), + port_number(port) {} + + // Constructs a control edge. 
+ EngineConnection(const string& outside, int out_id, const string& inside, + int in_id, bool input_edge) + : outside_node_name(outside), + outside_id(out_id), + outside_port(Graph::kControlSlot), + inside_node_name(inside), + inside_id(in_id), + inside_port(Graph::kControlSlot), + is_input_edge(input_edge), + port_number(Graph::kControlSlot) {} + + bool is_control_edge() const { return port_number == Graph::kControlSlot; } + + const string outside_node_name; + const int outside_id; + const int outside_port; + PartialTensorShape outside_shape; // Only set for input edge. + + const string inside_node_name; + const int inside_id; + const int inside_port; + PartialTensorShape inside_shape; // Only set for output edge. + + DataType connection_type; + const bool is_input_edge; + + // The port number of the TRT node connected with this edge. + const int port_number; +}; + +struct EngineInfo { + EngineInfo() + : engine_type(EngineType::TRTStatic), + max_workspace_size_bytes(0), + max_batch_size(std::nullopt), + maximum_cached_engines(0), + precision_mode(TrtPrecisionMode::FP32), + use_calibration(true), + + allow_build_at_runtime(true), + use_explicit_precision(false) {} + + string engine_name; + string device; + GraphDef segment_graph_def; + + // Non-control input connections inside this vector are sorted in a way such + // that, the segment nodes connecting to them are topological sorted. + // In addition, for non-control connections, there must be no duplicates. + std::vector connections; + + enum class EngineType { TRTStatic = 0, TRTDynamic = 1 }; + EngineType engine_type; + int64 max_workspace_size_bytes; + std::optional max_batch_size; + int maximum_cached_engines; + TrtPrecisionMode precision_mode; + bool use_calibration; + bool allow_build_at_runtime; + bool use_explicit_precision; +}; + +// Constructs a graphdef from the segment in the given graph and stores it to +// the engine_info. Adds _Arg nodes for input edges (InputPH_*) and _Retval +// nodes for output edges (OutputPH_*). Maintains the topological order of the +// non-input/output nodes in the graphdef. This function needs to be called +// before TensorRT layers are created because it prepares the original graph +// for TensorRT conversion. +// +// - subgraph_node_names: the node names of the subgraph. +// - subgraph_node_ids: the node ids of the subgraph, must be sorted in +// topological order. +// - engine_info: a data structure that records the information about the +// engine containing the subgraph. +// +// TODO(aaroey): add tests to validate these properties. +Status ConvertSegmentToGraphDef( + const Graph* graph, const grappler::GraphProperties& graph_properties, + const std::vector& subgraph_nodes, EngineInfo* engine_info); + +// Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff +// 'builder' successfully build the engine. If the result is not ok, 'engine' +// will be set to nullptr +// Once returned, 'builder' is not needed any more and can be safely destroyed. +// +// - convert_successfully: indicates whether the conversion to TensorRT network +// is successful. This is different than successfully building the engine: +// building can still fail afterwards. +// Note: When 'cluster' is not null, it contains the graph to be converted. +// We may perform additional optimizations to the graph before converting +// the graph. 
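+// Editor's note: illustrative sketch only, not part of the upstream header.
+// The usual flow is to build the segment GraphDef first and then convert it
+// with ConvertGraphDefToEngine below; `graph`, `graph_properties` and
+// `subgraph_nodes` come from the segmenter:
+//
+//   EngineInfo info;
+//   TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(&graph, graph_properties,
+//                                               subgraph_nodes, &info));
+//   // info.segment_graph_def is then handed to ConvertGraphDefToEngine.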
+Status ConvertGraphDefToEngine( + const GraphDef& gdef, OpKernelContext* ctx, TrtPrecisionMode precision_mode, + int max_batch_size, size_t max_workspace_size_bytes, + const std::vector& input_shapes, + nvinfer1::ILogger* logger, nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtUniquePtrType* engine, bool use_calibration, + const bool use_implicit_batch, bool* convert_successfully, + TrtShapeOptimizationProfile* profiles, absl::string_view engine_name, + bool use_explicit_precision, + tensorflow::grappler::Cluster* cluster = nullptr, + const string& device = ""); + +// Helper class for the segmenter to determine whether an output edge from the +// TRT segment is valid. +class OutputEdgeValidator { + public: + // Return true if the specified edge is eligible to be an output edge of the + // TRT segment. + bool operator()(const Edge* out_edge) const; +}; + +// Class to verify if specific TF node is supported by TRT. +class TrtNodeValidator { + public: + // 'graph_properties' is the GraphProperties of the graph whose nodes will be + // checked by IsTensorRTCandidate() later. It is used to get the shape and + // data type information of a tensor for validation purpose. + TrtNodeValidator(const grappler::GraphProperties& graph_properties, + TrtPrecisionMode precision_mode, bool use_calibration, + bool use_implicit_batch, bool use_explicit_precision); + + // Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added + // to TRT subgraph and later converted into TRT engine. + Status IsTensorRTCandidate(const Node* node); + + static const std::set* quantize_ops; + + // Returns validator by op type. If no validator is registered for + // specific op, it means no validation is needed and ValidateNode() will + // return OK. + StatusOr GetValidator(const std::string& op); + + private: + // Convert a Const node to a TRT_TensorOrWeights. + Status ConvertConstToWeights(const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output); + + // Convert a VariableV2 node to a TRT_TensorOrWeights. + Status ConvertVariableToWeights( + const NodeDef& const_node_def, + const std::vector& inputs, + TRT_TensorOrWeights* output); + + // Convert the output tensor at 'output_port' of 'node_def' to a + // TRT_TensorOrWeights which will be later used as an input to other nodes and + // passed to ValidateNode() below. + Status ConvertToTensorOrWeights(const NodeDef& node_def, int output_port, + TRT_TensorOrWeights* tensor_or_weights); + + // Store the weights added during validation. Some validations (e.g. + // validation for Const node) may produce weights. + TrtWeightStore weight_store_; + + // GraphProperties of the graph whose nodes are to be validated by + // IsTensorRTCandidate(). + const grappler::GraphProperties& graph_properties_; + + // Quantization ops are only converted when using quantized precisions. + const TrtPrecisionMode precision_mode_; + + const bool use_calibration_; + + const bool use_implicit_batch_; + + const bool use_explicit_precision_; + + friend class ValidatorTest; + friend class OpConverterTest; +}; + +// Class to convert TF nodes to TRT network. +class Converter { + public: + // Used for Converter::RenameAndMarkOutputTensors() + struct EngineOutputInfo { + // The TRT tensor name which produces the output. + string source_tensor_name; + // The TensorFlow node name which is receiving the output from the TRT + // engine. This should always be the Identity node created in + // ConvertSegmentToGraphDef. 
+ string dest_node_name; + // Output type. TensorRT requires this to be explicitly set for engine + // outputs. + nvinfer1::DataType trt_dtype; + }; + + static StatusOr> Create( + TrtPrecisionMode precision_mode, bool use_calibration, + nvinfer1::ILogger* trt_logger, const bool use_implicit_batch, + absl::string_view engine_name, bool use_explicit_precision = false, + OpKernelContext* ctx = nullptr); + + ////////////////////////////////////////////////////////////////////////////// + // Methods used by the TRT engine builder to build a TRT network from a TF + // function/subgraph. + + // Convert the node to TRT network. + Status ConvertNode(const NodeDef& node_def); + + // Add input tensor to the TRT network with given 'name', 'dtype', 'dims' and + // 'batch_size'. + Status AddInputTensor(const string& name, nvinfer1::DataType dtype, + const nvinfer1::Dims& dims, int batch_size); + + // Store the ResourceHandle as a TRT_TensorOrWeights object. This can be + // later used as input to other nodes. + Status AddInputResource(const string& name, const ResourceHandle& resource); + + // Mark the tensors with names specified by source_tensor_name as output of + // the TRT network, and set their names in the TRT network as dest_node_name. + Status RenameAndMarkOutputTensors( + const std::vector& output_tensors); + + // Build a TRT engine using the created network. + Status BuildCudaEngine(TrtUniquePtrType* engine, + int max_batch_size, size_t max_workspace_size_bytes, + nvinfer1::IGpuAllocator* allocator, + TRTInt8Calibrator* calibrator, + TrtShapeOptimizationProfile* profiles); + + ////////////////////////////////////////////////////////////////////////////// + // Methods used by op converters to convert individual TF node and add layers + // to the TRT network. + + // Op converters (e.g. ConvertReshape) need to access the TRT network in order + // to add TRT layers. + nvinfer1::INetworkDefinition* network() { return trt_network_.get(); } + + // What precision are we targeting? + TrtPrecisionMode precision_mode() const { return precision_mode_; } + + // Variable converters need the context to read variable values. + OpKernelContext* context() { return ctx_; } + + // Calibration will be or was previously performed on this network? + bool use_calibration() const { return use_calibration_; } + + // Whether implicit batch mode is enabled + bool use_implicit_batch() const { return use_implicit_batch_; } + + // This function should be called when we know the quantization range of a + // tensor from a quantize/dequantize node. + void ProvideQuantizationRange(ITensorProxyPtr* tensor, float min_range, + float max_range); + + // Should be called when full TRT network has been constructed and before + // building the engine. + void MaybeApplyQuantizationRanges(); + + // Below are helper methods for op converters to add different layers to the + // TRT network. + + // Transpose 'input_tensor' with given permutation 'order_with_batch_dim' to + // 'output_tensor'. The permutation 'order_with_batch_dim' contains the batch + // dimension which should always be 0. If this is for adding a transpose layer + // to support the conversion of 'node_def', callers need to provide a + // non-empty 'sub_op_name' appended to the name of 'node_def' to avoid layer + // name conflicts. 
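+  // Editor's note: illustrative sketch only, not part of the upstream header.
+  // For example, an NHWC-to-NCHW transpose inside an op converter, assuming
+  // `converter` is the active Converter (e.g. params->converter) and the batch
+  // dimension stays at index 0; the "to_NCHW" sub-op name is a placeholder:
+  //
+  //   ITensorProxyPtr transposed = nullptr;
+  //   TF_RETURN_IF_ERROR(converter->TransposeTensor(
+  //       input_tensor, {0, 3, 1, 2}, &transposed, node_def, "to_NCHW"));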
+ Status TransposeTensor(ITensorProxyPtr input_tensor, + const std::vector& order_with_batch_dim, + ITensorProxyPtr* output_tensor, + const NodeDef& node_def, + absl::string_view sub_op_name = ""); + + // Reshapes a dynamic shape tensor by removing or adding dimensions of size 1, + // and/or permuting the dimensions. The new shape is derived from the shape of + // the input tensor according to the slices and size_for_added_dims arguments. + // + // If there would be at most one unknown dimension, we could set the new shape + // using IShuffleLayer::setReshapeDimensions, which treats -1 as a special + // value (the same way as TF). In general, we can have more than one unknown + // dimensions, and we have to manipulate the shape tensors during runtime to + // define the new shape. This helper function defines the necessary shape + // inference layers and calls reshape using the calculated new shape. + // + // Example: + // + // Assume that we want to reshape a tensor from shape {A,B,C,D} to {C,D,A,B} + // (no transpose, just change the shape). In dynamic shape mode, the A,B,C,D + // values are not necessarily known at conversion time, they can be all -1. We + // can only define the new shape at runtime, when the actual shape is already + // known. To define the new shape: + // - We use an IShapeLayer to retrieve a shape tensor with the {A,B,C,D} + // values. + // - Create two slices {C,D} and {A,B} of the shape tensor. + // - Concatenate these slices {C,D,A,B}, + // - Set the {C,D,A,B} shape tensor as an input shape tensor for + // IShuffleLayer. + // + // This can be achieved by calling DynamicReshape(input, {{2,4},{0,2}}, + // params). + // + // Before each slice we can insert new dims if the corresponding + // size_for_added_dims element is not negative. The size_for_added_dims array + // can have more than slices.size() elements, in order to insert a dimension + // after the last slice. For example, to add two leading 1 dimensions, and + // three trailing 1 dimensions, call DynamicReshape(input, {{0,nbDims}}, + // {2, 3}). + // + // Parameters: + // input - input tensor + // slices - [start, end) pairs of slices + // params - conversion parameters + // output - reshaped tensor + // size_for_added_dims - size of dimension inserted right before slice[i]. We + // only insert a new dim if size_for_added_dims[i] >= 0. + Status DynamicReshape(ITensorProxyPtr input, + std::vector> slices, + const OpConverterParams* params, + ITensorProxyPtr* output, + std::vector size_for_added_dims = {}, + std::optional op_instance = std::nullopt); + + // Inserts a singleton dimension at axis for a dynamic shape tensor. + Status DynamicExpandDims(ITensorProxyPtr input, const nvinfer1::Dims& dims, + int axis, const OpConverterParams* params, + ITensorProxyPtr* output, + std::optional op_instance = std::nullopt); + + // Helper function to add a squeeze op to the network. + // + // The input_dims argument stores the TRT dimensions of the input tensor, + // where the dimensions to be squeezed are replaced by 0. + Status SqueezeTensor(ITensorProxyPtr input, std::vector* input_dims, + const OpConverterParams* params, ITensorProxyPtr* output, + std::optional op_instance = std::nullopt); + + // Creates an IConstantLayer using 'weights' whose dimensions are specified by + // 'dims', and returns the output ITensor. 
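+  // Editor's note: illustrative sketch only, not part of the upstream header.
+  // E.g. materializing captured weights as a tensor input for a TRT layer that
+  // only accepts tensors; `weights` and `dims` are supplied by the caller:
+  //
+  //   ITensorProxyPtr const_tensor =
+  //       converter->CreateConstantLayer(weights, dims);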
+ ITensorProxyPtr CreateConstantLayer(const TRT_ShapedWeights& weights, + const nvinfer1::Dims& dims); + + // Gets the min and max value in a TRT_ShapedWeights + Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min, + float* out_max) const; + + // Constructs a name and passed it to the TensorRT layer to support xprof. + void SetLayerName(nvinfer1::ILayer* layer, const NodeDef& node_def, + absl::string_view sub_op_name = "", + std::optional sub_op_instance = std::nullopt, + std::optional origin_node_name = std::nullopt); + + void SetLayerName(nvinfer1::ILayer* layer, absl::string_view main_op_name, + absl::string_view sub_op_name, + std::optional sub_op_instance = std::nullopt); + + std::unordered_map& TensorsMap() { + return trt_tensors_; + } + + bool UseExplicitPrecision() const { return use_explicit_precision_; } + + private: + Converter(TrtPrecisionMode precision_mode, bool use_calibration, + nvinfer1::ILogger* trt_logger, const bool use_implicit_batch, + absl::string_view engine_name, bool use_explicit_precision, + OpKernelContext* ctx); + + Status Init(nvinfer1::ILogger* trt_logger); + + // Verify the provided batch_size is consistent with batch_size_ and update it + // if necessary. + Status MaybeUpdateBatchSize(int batch_size); + + // Add the provided tensor/weights to the map trt_tensors_. + Status AddTensorOrWeights(const string& name, TRT_TensorOrWeights input); + + // Get the tensor/weights from trt_tensors_ by 'name'. + Status GetTensorOrWeights(const string& name, TRT_TensorOrWeights* output); + + // Get the inputs of 'node_def' from trt_tensors_. + Status GetInputs(const NodeDef& node_def, + std::vector* inputs) const; + + // Tensors/weights added during construction of trt_network_. + std::unordered_map trt_tensors_; + + // The TRT builder used to create the network and build the engine. Not owned. + TrtUniquePtrType trt_builder_; + + // The TRT network being built. + TrtUniquePtrType trt_network_; + + // Store the weights added during construction of trt_network_. + TrtWeightStore weight_store_; + + // Store the context. + OpKernelContext* ctx_; + + // During conversion, this table is populated with quantization ranges per + // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT + // quantization ranges. Since TRT only supports symmetric ranges, we will + // store the range as a single float = max(abs(min_range), abs(max_range)). + // Range refers to the floating point values, e.g. min_range = 0.0f, max_range + // = 6.0f for Relu6. + std::unordered_map quantization_ranges_proxy_; + std::unordered_map quantization_ranges_; + + const TrtPrecisionMode precision_mode_; + + const bool use_calibration_; + + // If this is false, all dimensions including the batch dimension are + // set explicitly. + const bool use_implicit_batch_; + + // Batch size of inputs to trt_network_ added by AddInputTensor(). During + // network construction it will update this, use it to verify the batch + // size of all inputs are compatible, and make sure individual TF node is + // acceptable by TRT. + int batch_size_ = -1; + + // Assign a ID to each constant layer we create, so that we can assign a + // unique name to the layer. + int next_constant_layer_id_ = 0; + + // The name of the TRTEngineOp node. + absl::string_view engine_name_; + + // Indicates whether to use explicit precision in TensorRT (Q/DQ support). 
+ bool use_explicit_precision_; + + friend class ConverterTest; + friend class OpConverterTest; +}; + +// Converts a TensorFlow tensor to TRT shaped weights. +Status TfTensorToTrtWeights(const Tensor& tensor, TrtWeightStore* weight_store, + TRT_ShapedWeights* weights); + +// Converts 'input' of 'node_def' into 'tensor' with shape specified by 'dims' +// (which doesn't contain the batch dimension). +// +// If validation_only is true, it doesn't do the conversion but only do some +// minimum validation for the eligibility of the conversion, and *tensor will +// be set to nullptr. +// If validation_only is false converter must not be nullptr. +Status PrepareTensorForShape( + Converter* converter, const TRT_TensorOrWeights& input, + const DimsAdapter& dims, const bool validation_only, + ITensorProxyPtr* tensor, const NodeDef& node_def, + std::optional op_instance = std::nullopt, + std::optional origin_node_name = std::nullopt); + +// Return OK if the broadcast scheme is supported and compute the shapes after +// broadcasting. check_feasibility can be set to false in cases where dimensions +// do not need to match exactly (as in the case of BatchMatMulV2). +Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l, + const TRT_TensorOrWeights& operand_r, + const bool check_feasibility, + const bool use_implicit_batch, + nvinfer1::Dims* operand_l_new_dims, + nvinfer1::Dims* operand_r_new_dims); + +template +using OperationMap = std::unordered_map; + +// Map from Tensorflow operation names to TensorRT unary operations. +using UnaryOperationMapType = OperationMap; +const UnaryOperationMapType* UnaryOperationMap(); + +// Map from Tensorflow boolean operation names to TensorRT unary operations. +const UnaryOperationMapType* UnaryBooleanOperationMap(); + +// Map of all supported ActivationTypes. +using ActivationTypeMapType = OperationMap; +const ActivationTypeMapType* ActivationTypeMap(); + +// Map from Tensorflow binary operation names to TensorRT binary operations +// types. +using BinaryOperationMapType = OperationMap; +const BinaryOperationMapType* BinaryOperationMap(); + +// Map from Tensorflow boolean binary operation names to TensorRT binary +// operations types. +const BinaryOperationMapType* BinaryBooleanOperationMap(); + +template +absl::InlinedVector GetOperationNames(const T& set) { + absl::InlinedVector result; + absl::c_transform(set, std::back_inserter(result), + [](const auto x) { return x.first; }); + return result; +} + +// Adds a matrix multiplication operation to the TensorRT graph. The "params" +// pointer is only used to access the TRT network builder. The inputs and +// parameters for the op are fully specified by input_[a|b] and transpose_[a|b]. 
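+// Editor's note: illustrative sketch only, not part of the upstream header.
+// E.g. a MatMul-style converter forwarding its two inputs; the transpose flags
+// shown here are placeholders chosen for the example:
+//
+//   auto out = ConvertMatMulImpl(params, params->inputs.at(0),
+//                                params->inputs.at(1),
+//                                /*transpose_a=*/false, /*transpose_b=*/true);
+//   TF_RETURN_IF_ERROR(out.status());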
+StatusOr ConvertMatMulImpl(const OpConverterParams* params, + TRT_TensorOrWeights input_a, + TRT_TensorOrWeights input_b, + bool transpose_a, bool transpose_b); + +Status ApplyBroadcast(std::unique_ptr& operand, + const DimsAdapter& broadcasted_dims, + const OpConverterParams* params, + std::optional op_instance); + +std::string convert_range_error_msg(float start, float limit, float delta); +std::string convert_range_expected_msg(const NodeDef& node_def); +std::string bool_weight_error_msg(const NodeDef& node_def); +std::string unexpected_type_error_msg(nvinfer1::DataType type_being_checked, + nvinfer1::DataType type_expected, + const NodeDef& node_def, int idx = 0); +std::string then_else_dtypes_error_msg(nvinfer1::DataType type_then, + nvinfer1::DataType type_else, + const NodeDef& node_def); +std::string input_shapes_error_msg(const nvinfer1::Dims& shape1, + const nvinfer1::Dims& shape2, + const NodeDef& node, + bool then_vs_else = false); +std::string batch_size_error(absl::string_view name, absl::string_view comment); + +inline bool find_name(const string& name, const std::vector names) { + return std::find(names.begin(), names.end(), name) != names.end(); +} + +Status check_type(nvinfer1::DataType type_being_checked, + nvinfer1::DataType type_expected, const NodeDef& node_def, + int idx = 0); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/logger_registry.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/logger_registry.h new file mode 100644 index 00000000..2a265cf7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/logger_registry.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +class LoggerRegistry { + public: + virtual Status Register(const string& name, nvinfer1::ILogger* logger) = 0; + virtual nvinfer1::ILogger* LookUp(const string& name) = 0; + virtual ~LoggerRegistry() {} +}; + +LoggerRegistry* GetLoggerRegistry(); + +class RegisterLogger { + public: + RegisterLogger(const string& name, nvinfer1::ILogger* logger) { + TF_CHECK_OK(GetLoggerRegistry()->Register(name, logger)); + } +}; + +#define REGISTER_TENSORRT_LOGGER(name, logger) \ + REGISTER_TENSORRT_LOGGER_UNIQ_HELPER(__COUNTER__, name, logger) +#define REGISTER_TENSORRT_LOGGER_UNIQ_HELPER(ctr, name, logger) \ + REGISTER_TENSORRT_LOGGER_UNIQ(ctr, name, logger) +#define REGISTER_TENSORRT_LOGGER_UNIQ(ctr, name, logger) \ + static ::tensorflow::tensorrt::RegisterLogger register_trt_logger##ctr \ + TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::tensorrt::RegisterLogger(name, logger) + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/op_converter.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/op_converter.h new file mode 100644 index 00000000..7ebaaeb1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/op_converter.h @@ -0,0 +1,224 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_ + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h" +#include "tensorflow/compiler/tf2tensorrt/convert/weights.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +class Converter; + +// Specifies the expected type taken by a TRT_TensorOrWeights input during op +// conversion. +// kResource is only used for resource variable ops. For an operation like +// Add(tensor, ReadVariableOp(...)), the second operand of Add is the result of +// the ReadVariableOp, which is a kWeight. +enum class TrtInputArg { kTensor = 1, kWeight = 2, kBoth = 3, kResource = 4 }; + +// Parameters for each op converter. +struct OpConverterParams { + // Constructor used for validation only. 
+ OpConverterParams(const NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs, + TrtWeightStore* weight_store, + TrtPrecisionMode precision_mode, bool use_calibration, + bool use_implicit_batch, bool use_explicit_precision); + + // Constructor used for conversion. + OpConverterParams(Converter* converter, const NodeDef& node_def, + const std::vector& inputs, + std::vector* outputs, + TrtWeightStore* weight_store); + + Converter* converter = nullptr; + const NodeDef& node_def; + const std::vector& inputs; + std::vector* outputs; + const bool validation_only; + TrtWeightStore* weight_store; + const TrtPrecisionMode precision_mode; + const bool use_calibration; + const bool use_implicit_batch; + const bool use_explicit_precision; +}; + +// Operation converter function specification. +using OpConverter = std::function; + +struct InputArgSpec { + absl::string_view name; + TrtInputArg allowed_roles; + + static constexpr InputArgSpec Create(absl::string_view n, TrtInputArg role) { + return InputArgSpec{n, role}; + } +}; + +template +std::string convert_not_supported_dtype_msg(const T& allowed_types, + DataType tf_type, + const NodeDef& node) { + string allowed_types_string = + absl::StrJoin(allowed_types, ", ", [](string* out, const DataType& type) { + absl::StrAppendFormat(out, "%s", DataTypeString(type)); + }); + + return absl::StrCat("Data type ", DataTypeString(tf_type), + " is not supported for ", node.op(), ", must be one of [", + allowed_types_string, "]"); +} + +std::string convert_not_supported_implicit(const std::string& pOpName, + const std::string& pNodeName, + const char* pOpType = NULL); + +// A Curiously recurring template pattern (CRTP) template class for operation +// converters. +template +class OpConverterBase { + public: + explicit OpConverterBase(const OpConverterParams* params, + const std::vector& data_types = + {DataType::DT_FLOAT, DataType::DT_HALF}) + : params_(params), + node_def_attrs_(params->node_def), + allowed_dtypes_(data_types) {} + + // Default NodeDef attribute name to inspect in order to determine node data + // type. The Impl class can override this by implementing the same function. + static constexpr const char* NodeDefDataTypeAttributeName() { return "T"; } + + // Validate data type of the given NodeDef against allowed types. + Status ValidateNodeDefDataType() { + // If the attribute name is empty, we should skip this check. + if (absl::string_view(Impl::NodeDefDataTypeAttributeName()).empty()) { + return OkStatus(); + } + + // Get the NodeDef data type. + auto dtype = GetAttrValue(Impl::NodeDefDataTypeAttributeName()); + if (!dtype.ok()) { + return errors::InvalidArgument("Attribute with name ", + Impl::NodeDefDataTypeAttributeName(), + " not found."); + } + + // Check allowed data types.; + if (std::find(allowed_dtypes_.begin(), allowed_dtypes_.end(), *dtype) == + allowed_dtypes_.end()) { + return errors::Unimplemented(convert_not_supported_dtype_msg( + allowed_dtypes_, *dtype, params_->node_def)); + } + return OkStatus(); + } + + static constexpr bool HasFixNumberOfInputs() { return true; } + + // Validates input argument roles and data types. 
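+  // Editor's note: illustrative sketch only, not part of the upstream header.
+  // A derived converter typically pairs this check with an InputSpec() along
+  // these lines (the argument names and container type are assumptions for the
+  // example):
+  //
+  //   static std::vector<InputArgSpec> InputSpec() {
+  //     return {InputArgSpec::Create("input", TrtInputArg::kTensor),
+  //             InputArgSpec::Create("weights", TrtInputArg::kWeight)};
+  //   }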
+ Status ValidateInputs() { + const NodeDef& node_def = params_->node_def; + const auto& inputs = params_->inputs; + if (Impl::HasFixNumberOfInputs()) { + TRT_ENSURE(inputs.size() == Impl::InputSpec().size()); + } else { + TRT_ENSURE(inputs.size() <= Impl::InputSpec().size()); + } + for (int i = 0; i < inputs.size(); i++) { + const InputArgSpec arg_spec = Impl::InputSpec()[i]; + if (arg_spec.allowed_roles == TrtInputArg::kWeight && + inputs.at(i).is_tensor()) { + return errors::Unimplemented("The input \"", arg_spec.name, "\" for ", + node_def.op(), " must be a constant, at ", + node_def.name()); + } + if (arg_spec.allowed_roles == TrtInputArg::kTensor && + inputs.at(i).is_weights()) { + return errors::Unimplemented("The input \"", arg_spec.name, "\" for ", + node_def.op(), " must be a tensor, at ", + node_def.name()); + } + } + return OkStatus(); + } + + Status operator()() { + // Validate data type and inputs. + TF_RETURN_IF_ERROR(this->ValidateNodeDefDataType()); + TF_RETURN_IF_ERROR(this->ValidateInputs()); + + // Perform op-level validation. + TF_RETURN_IF_ERROR(reinterpret_cast(this)->Validate()); + if (params_->validation_only) { + return OkStatus(); + } + + // Perform conversion. + return reinterpret_cast(this)->Convert(); + } + + protected: + Status NotSupportedInImplicitBatch(const char* pOpType = nullptr) { + if (params_->use_implicit_batch) { + const auto& op = params_->node_def.op(); + const auto& nodeName = params_->node_def.name(); + const auto& error = convert_not_supported_implicit(op, nodeName, pOpType); + return errors::Unimplemented(error); + } + return OkStatus(); + } + + void AddOutput(const TRT_TensorOrWeights& out) { + params_->outputs->push_back(out); + } + + template + StatusOr GetAttrValue(absl::string_view key) const { + T result; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def_attrs_, key, &result)); + return result; + } + + const OpConverterParams* const params_; + const AttrSlice node_def_attrs_; + const std::vector allowed_dtypes_; +}; + +// Constructs and returns a converter function for a given operation converter +// class T. This requires T to be a derived class of StructuredOpConverter. +template +OpConverter MakeConverterFunction() { + return [](const OpConverterParams* params) -> Status { + T converter(params); + return converter(); + }; +} + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h new file mode 100644 index 00000000..8780aa68 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h @@ -0,0 +1,104 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_REGISTRY_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_REGISTRY_H_ + +#include +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/op_converter.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +class OpConverterRegistry { + public: + OpConverterRegistry(); + ~OpConverterRegistry() = default; + + InitOnStartupMarker Register(const string& name, const int priority, + OpConverter converter); + + InitOnStartupMarker Register(const std::initializer_list& names, + const int priority, OpConverter converter) { + for (const auto& name : names) { + Register(name, priority, converter); + } + return {}; + } + + template ::value>::type* = nullptr> + InitOnStartupMarker Register(const T& names, const int priority, + OpConverter converter) { + for (const auto& name : names) { + Register(name, priority, converter); + } + return {}; + } + + // Clear all registered converters for the given Tensorflow operation name. + void Clear(const std::string& name); + + StatusOr LookUp(const string& name); + + std::vector ListRegisteredOps() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +OpConverterRegistry* GetOpConverterRegistry(); + +class RegisterOpConverter { + public: + RegisterOpConverter(const string& name, const int priority, + OpConverter converter) { + GetOpConverterRegistry()->Register(name, priority, converter); + } +}; + +constexpr int kDefaultConverterPriority = 1; + +} // namespace convert +} // namespace tensorrt + +#define REGISTER_TRT_OP_CONVERTER_IMPL(ctr, func, priority, ...) \ + static ::tensorflow::InitOnStartupMarker const \ + register_trt_op_converter##ctr TF_ATTRIBUTE_UNUSED = \ + TF_INIT_ON_STARTUP_IF(true) \ + << tensorrt::convert::GetOpConverterRegistry()->Register( \ + __VA_ARGS__, priority, func) + +#define REGISTER_TRT_OP_CONVERTER(func, priority, ...) \ + TF_NEW_ID_FOR_INIT(REGISTER_TRT_OP_CONVERTER_IMPL, func, priority, \ + __VA_ARGS__) + +#define REGISTER_DEFAULT_TRT_OP_CONVERTER(func, ...) \ + REGISTER_TRT_OP_CONVERTER( \ + func, tensorrt::convert::kDefaultConverterPriority, __VA_ARGS__) + +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OP_CONVERTER_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h new file mode 100644 index 00000000..f31af032 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h @@ -0,0 +1,715 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_ +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/statusor.h" +#include "third_party/tensorrt/NvInfer.h" +#include "third_party/tensorrt/NvInferRuntimeCommon.h" + +namespace tensorflow { +namespace tensorrt { + +namespace convert { + +// Facilitates the creation of TensorRT layers inside a network. The user +// provides a INetworkDefinition pointer during construction. They can then add +// operations to the network through the provided functions. Each function +// returns a struct which contains the symbolic result of the operation (ITensor +// pointer) as well as a pointer to the last TensorRT ILayer created. Some +// operations may create multiple layers in order to accomplish the desired +// result (e.g. Sign). +class TRTNetworkBuilder { + public: + static StatusOr Create( + nvinfer1::INetworkDefinition* network, TrtWeightStore* weight_store) { + TRT_ENSURE(network); + TRT_ENSURE(weight_store); + return TRTNetworkBuilder(network, weight_store); + } + + private: + TRTNetworkBuilder(nvinfer1::INetworkDefinition* network, + TrtWeightStore* weight_store) + : network_(network), weight_store_(weight_store) {} + + public: + // Adds an Add operation to the network. + StatusOr Add(nvinfer1::ITensor* lhs, + nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUM); + TRT_ENSURE(layer); + return layer; + }; + + // Adds an elementwise min(lhs, rhs) operation to the network. The output has + // the same data type as the input. + StatusOr Min(nvinfer1::ITensor* lhs, + nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kMIN); + TRT_ENSURE(layer); + return layer; + }; + + // Adds an elementwise max(lhs, rhs) operation to the network. The output has + // the same datatype as the input. + StatusOr Max(nvinfer1::ITensor* lhs, + nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kMAX); + TRT_ENSURE(layer); + return layer; + }; + + // Adds an absolute value operation to the network. Note that this unary + // operation will do an implicit float conversion. For int32 tensors, use + // "AbsInt". + StatusOr AbsFloat(nvinfer1::ITensor* input) noexcept { + TRT_ENSURE(input); + TRT_ENSURE(input->getType() != nvinfer1::DataType::kFLOAT && + input->getType() != nvinfer1::DataType::kHALF); + nvinfer1::IUnaryLayer* layer = + network_->addUnary(*input, nvinfer1::UnaryOperation::kABS); + TRT_ENSURE(layer); + return layer; + } + + // Performs Abs without implicit float conversion. The input should be of type + // kInt32. For float datatypes, use "Abs". 
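+  // (Sketch of the approach used below: the result is input * SignInt(input),
+  // which keeps the tensor in kINT32 end to end; float tensors should instead
+  // go through the unary kABS path in AbsFloat above.)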
+ StatusOr AbsInt( + nvinfer1::ITensor* input) noexcept { + TRT_ENSURE(input); + TRT_ENSURE(input->getType() == nvinfer1::DataType::kINT32); + StatusOr sign = this->SignInt(input); + return this->Mul(input, (*sign)->getOutput(0)); + } + + // Returns elementwise sign(x) for int32 input tensors where sign(x) is + // defined as 1 where x > 0, -1 where x < 0 and 0 where x == 0. + StatusOr SignInt( + nvinfer1::ITensor* input) noexcept { + TRT_ENSURE(input); + + // Create constants +1 and -1. + StatusOr one = + this->Constant(1, input->getDimensions().nbDims); + TRT_ENSURE_PTR_OK(one); + + StatusOr neg_one = + this->Constant(-1, input->getDimensions().nbDims); + TRT_ENSURE_PTR_OK(neg_one); + + // Turn all negaitve elements into -1, positive and zero elements + // unaffected. + StatusOr max = + this->Max(input, (*neg_one)->getOutput(0)); + TRT_ENSURE_PTR_OK(max); + + // Turn all positive elements into +1, negative and zero elements + // unaffected. + StatusOr min = + this->Min((*max)->getOutput(0), (*one)->getOutput(0)); + TRT_ENSURE_PTR_OK(min); + return min; + } + + // Adds a Sub operation to the network. + StatusOr Sub(nvinfer1::ITensor* lhs, + nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kSUB); + TRT_ENSURE(layer); + return layer; + } + + // Adds an Greater operation to the network. + StatusOr Greater( + nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kGREATER); + TRT_ENSURE(layer); + return layer; + } + + // Adds an Equal operation to the network. + StatusOr Equal( + nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kEQUAL); + TRT_ENSURE(layer); + return layer; + } + + // Adds a FloorDiv operation to the network. + StatusOr FloorDiv( + nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kFLOOR_DIV); + TRT_ENSURE(layer); + return layer; + } + + // Returns the equivalent of ceil_divide(abs(x)/abs(y))) operation. The inputs + // "lhs" and "rhs" should be int32 tensors. + StatusOr AbsCeilDivInt( + nvinfer1::ITensor* lhs, nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + TRT_ENSURE(lhs->getType() == nvinfer1::DataType::kINT32); + TRT_ENSURE(rhs->getType() == nvinfer1::DataType::kINT32); + + StatusOr rhs_abs = this->AbsInt(rhs); + TRT_ENSURE_PTR_OK(rhs_abs); + StatusOr lhs_abs = this->AbsInt(lhs); + TRT_ENSURE_PTR_OK(lhs_abs); + StatusOr add1 = + this->Add((*lhs_abs)->getOutput(0), (*rhs_abs)->getOutput(0)); + TRT_ENSURE_PTR_OK(add1); + StatusOr one_const = + this->Constant(1, rhs->getDimensions().nbDims); + TRT_ENSURE_PTR_OK(one_const); + StatusOr numerator = + this->Sub((*add1)->getOutput(0), (*one_const)->getOutput(0)); + TRT_ENSURE_PTR_OK(numerator); + return FloorDiv((*numerator)->getOutput(0), (*rhs_abs)->getOutput(0)); + } + + // Adds an elementwise multiplication operation to the network. 
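+  // (This is ElementWiseOperation::kPROD; for example, AbsInt above uses it to
+  // form input * sign(input), and CumulativeProd below chains it pairwise.)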
+ StatusOr Mul(nvinfer1::ITensor* lhs, + nvinfer1::ITensor* rhs) noexcept { + TRT_ENSURE(lhs); + TRT_ENSURE(rhs); + nvinfer1::IElementWiseLayer* layer = network_->addElementWise( + *lhs, *rhs, nvinfer1::ElementWiseOperation::kPROD); + TRT_ENSURE(layer); + return layer; + } + + // Adds a sequence of elementwise multiplication operations to the network. + // The returned layer's output contains the cumulative elementwise product of + // all tensors in the input. + StatusOr CumulativeProd( + absl::Span inputs) noexcept { + TRT_ENSURE(!absl::c_any_of( + inputs, [](nvinfer1::ITensor* x) { return x == nullptr; })); + nvinfer1::ILayer* out = nullptr; + if (inputs.size() == 1) { + out = network_->addIdentity(*inputs[0]); + TRT_ENSURE(out != nullptr); + return out; + } + nvinfer1::ITensor* last = inputs[0]; + for (int i = 1; i < inputs.size(); i++) { + StatusOr mul = this->Mul(last, inputs[i]); + TRT_ENSURE_PTR_OK(mul); + out = *mul; + last = (*mul)->getOutput(0); + } + return out; + } + + // Adds a Constant layer whose output is a TensorRT shape tensor. The shape + // tensor's size and values correspond to dim's nbDims and d[], respectively. + StatusOr ConstantShape( + const DimsAdapter& shape_data) noexcept { + TRT_ENSURE(shape_data.NumDims() > 0); + nvinfer1::Dims shape_dims; + shape_dims.nbDims = 1; + shape_dims.d[0] = shape_data.NumDims(); + StatusOr const_weights = + weight_store_->GetTempWeights(nvinfer1::DataType::kINT32, shape_dims); + TRT_ENSURE_OK(const_weights); + absl::c_copy(shape_data, const_weights->GetPointer()); + StatusOr trt_dims = const_weights->Shape().AsTrtDims(); + TRT_ENSURE_OK(trt_dims); + nvinfer1::IConstantLayer* const_layer = + network_->addConstant(*trt_dims, const_weights->GetTrtWeights()); + TRT_ENSURE(const_layer); + nvinfer1::ITensor* output = const_layer->getOutput(0); + TRT_ENSURE(output); + TRT_ENSURE(output->getType() == nvinfer1::DataType::kINT32); + return const_layer; + } + + // Adds a Constant layer whose output is a TensorRT shape tensor. The shape + // tensor's size and values correspond to dim's nbDims and d[], respectively. + StatusOr Constant( + const std::vector& data) noexcept { + nvinfer1::Dims shape_dims; + shape_dims.nbDims = 1; + shape_dims.d[0] = data.size(); + StatusOr const_weights = + weight_store_->GetTempWeights(nvinfer1::DataType::kINT32, shape_dims); + TRT_ENSURE_OK(const_weights); + int32* values = const_weights->GetPointer(); + for (int i = 0; i < data.size(); i++) { + values[i] = static_cast(data[i]); + } + StatusOr trt_dims = const_weights->Shape().AsTrtDims(); + TRT_ENSURE_OK(trt_dims); + nvinfer1::IConstantLayer* const_layer = + network_->addConstant(*trt_dims, const_weights->GetTrtWeights()); + TRT_ENSURE(const_layer); + nvinfer1::ITensor* output = const_layer->getOutput(0); + TRT_ENSURE(output); + TRT_ENSURE(output->getType() == nvinfer1::DataType::kINT32); + TRT_ENSURE(const_layer); + return const_layer; + } + + // Adds a Constant layer that produces a tensor of shape "shape", + // type "data_type" and filled with value "scalar". 
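+  // For example (illustrative sketch only, `builder` being a previously
+  // created TRTNetworkBuilder):
+  //
+  //   nvinfer1::Dims dims;
+  //   dims.nbDims = 2;
+  //   dims.d[0] = 1;
+  //   dims.d[1] = 4;
+  //   auto ones = builder->Constant(1.0f, dims, nvinfer1::DataType::kFLOAT);
+  //
+  // would add a 1x4 FP32 constant filled with 1.0 and return the resulting
+  // IConstantLayer wrapped in a StatusOr.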
+ template + StatusOr Constant( + const T value, nvinfer1::Dims shape, + nvinfer1::DataType data_type) noexcept { + StatusOr const_weights = + weight_store_->GetTempWeights(data_type, shape); + TRT_ENSURE_OK(const_weights); + TRT_ENSURE(const_weights->SetValues(value).ok()); + nvinfer1::IConstantLayer* const_layer = + network_->addConstant(shape, const_weights->GetTrtWeights()); + TRT_ENSURE(const_layer); + return const_layer; + } + + // Adds a Constant layer that produces a tensor with a single value "scalar". + // The tensor has "nb_dims" dimensions and each dimension has only one + // element. The data type of the tensor is determined by the data type of + // "scalar". + template ::value>::type* = nullptr> + StatusOr Constant(const T scalar, + const int nb_dims) noexcept { + TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS); + auto data_type = nvinfer1::DataType::kINT32; + if (std::is_floating_point::value) { + data_type = nvinfer1::DataType::kFLOAT; + } + nvinfer1::Dims zero_shape; + zero_shape.nbDims = nb_dims; + std::fill_n(zero_shape.d, nb_dims, 1); + return Constant(scalar, zero_shape, data_type); + } + + // Adds a Constant layer from a TRT_ShapedWeights object. + StatusOr WeightsToConstant( + const nvinfer1::Weights& weights, const DimsAdapter& dims) noexcept { + StatusOr vol = dims.Volume(); + TRT_ENSURE_OK(vol); + TRT_ENSURE(*vol == weights.count); + StatusOr trt_dims = dims.AsTrtDims(); + TRT_ENSURE_OK(trt_dims); + nvinfer1::IConstantLayer* const_layer = + network_->addConstant(*trt_dims, weights); + TRT_ENSURE(const_layer); + return const_layer; + } + + Status get_tensor4TensorOrWeights(const TRT_TensorOrWeights& input, + ITensorProxyPtr* pTensor) { + if (input.is_weights()) { + StatusOr const_layer = WeightsToConstant( + input.weights().GetTrtWeights(), input.GetTrtDims()); + if (!const_layer.status().ok()) return const_layer.status(); + *pTensor = (*const_layer)->getOutput(0); + } else { + *pTensor = input.tensor(); + } + return OkStatus(); + } + + // Creates a nvinfer1::Weights object containing a single scalar. + template ::value>::type* = nullptr> + StatusOr ScalarWeights(const T scalar, + const int nb_dims) noexcept { + TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS); + auto data_type = nvinfer1::DataType::kINT32; + if (std::is_floating_point::value) { + data_type = nvinfer1::DataType::kFLOAT; + } + nvinfer1::Dims weights_shape; + weights_shape.nbDims = nb_dims; + std::fill_n(weights_shape.d, nb_dims, 1); + StatusOr const_weights = + weight_store_->GetTempWeights(data_type, weights_shape); + TRT_ENSURE_OK(const_weights); + const_weights->GetPointer()[0] = scalar; + return const_weights->GetTrtWeights(); + } + + // Adds a TensorRT Slice operation to the network. + StatusOr Slice( + nvinfer1::ITensor* input, const nvinfer1::Dims& begin, + const nvinfer1::Dims& size, const nvinfer1::Dims& stride) noexcept { + nvinfer1::ISliceLayer* layer = + network_->addSlice(*input, begin, size, stride); + TRT_ENSURE(layer); + return layer; + } + + // Adds a TensorRT Concatenate operation to the network. + StatusOr Concat( + absl::Span inputs, const int axis) { + for (nvinfer1::ITensor* input : inputs) { + TRT_ENSURE(input); + } + nvinfer1::IConcatenationLayer* layer = network_->addConcatenation( + inputs.data(), static_cast(inputs.size())); + TRT_ENSURE(layer); + layer->setAxis(axis); + return layer; + } + + // Adds a TensorRT Concatenate operation to the network. 
+ StatusOr Concat( + const std::vector& inputs, const int axis) { + return this->Concat(absl::MakeSpan(inputs), axis); + } + + // Adds a TensorRT Shape operation, which determines the runtime shape of the + // input tensor, to the network. + StatusOr Shape(nvinfer1::ITensor* input) { + TRT_ENSURE(input); + nvinfer1::IShapeLayer* layer = network_->addShape(*input); + TRT_ENSURE(layer); + return layer; + } + + // Creates a Gather operation on the shape of the input tensor. The output of + // the gather operation is a 1D shape tensor where output[i] = (!sub_one ? + // input_shape[i] : input_shape[i] -1) if i is in "indices", otherwise zero. + StatusOr GetPartialShapeOf( + nvinfer1::ITensor* input, absl::InlinedVector indices, + bool sub_one = false) { + TRT_ENSURE(input); + TRT_ENSURE(indices.size() <= nvinfer1::Dims::MAX_DIMS); + + // Get the runtime shape of input; + StatusOr shape_layer = this->Shape(input); + TRT_ENSURE_PTR_OK(shape_layer); + nvinfer1::ITensor* runtime_shape = (*shape_layer)->getOutput(0); + + if (sub_one) { + StatusOr ones = this->Constant(1, 1); + TRT_ENSURE_PTR_OK(ones); + StatusOr sub = + this->Sub(runtime_shape, (*ones)->getOutput(0)); + TRT_ENSURE_PTR_OK(sub); + runtime_shape = (*sub)->getOutput(0); + } + + // Create a constant tensor containing the gather indices. + // For any dim not in "indices", we mark it size to gather a zero. + const int input_nb_dims = input->getDimensions().nbDims; + std::vector indices_all(input_nb_dims, input_nb_dims); + for (auto idx : indices) { + TRT_ENSURE(idx < input_nb_dims); + indices_all[idx] = idx; + } + + StatusOr indices_result = + this->Constant(indices_all); + TRT_ENSURE_PTR_OK(indices_result); + nvinfer1::ITensor* gather_indices = (*indices_result)->getOutput(0); + TRT_ENSURE(gather_indices->getDimensions().nbDims == 1); + TRT_ENSURE(gather_indices->getType() == nvinfer1::DataType::kINT32); + + // Append a zero to the shape tensor. + StatusOr zero_result = + this->Constant(std::vector{0}); + TRT_ENSURE_PTR_OK(zero_result); + std::array cat_inputs = { + runtime_shape, (*zero_result)->getOutput(0)}; + nvinfer1::IConcatenationLayer* cat_layer = + network_->addConcatenation(cat_inputs.data(), cat_inputs.size()); + TRT_ENSURE(cat_layer); + nvinfer1::ITensor* gather_input = cat_layer->getOutput(0); + TRT_ENSURE(gather_input); + + // Finally, gather the indices from the input. + nvinfer1::IGatherLayer* gather = + network_->addGather(*gather_input, *gather_indices, 0); + TRT_ENSURE(gather); + return gather; + } + + // Adds a scale layer that uniformly scales the input tensor by the specified + // amount. 
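+  // (Sketch of the semantics: this maps to an IScaleLayer in kUNIFORM mode
+  // whose shift and power weights are left empty, so the output is simply
+  // input * scale.)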
+ StatusOr AddUniformScale(nvinfer1::ITensor* input, + float scale, + const std::string& name) { + TRT_ENSURE(input); + TRT_ENSURE(!name.empty()); + StatusOr weight = this->ScalarWeights(scale, 1); + TRT_ENSURE_OK(weight); + const nvinfer1::Weights empty_weights = + nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0}; + nvinfer1::IScaleLayer* scale_layer = + network_->addScale(*input, nvinfer1::ScaleMode::kUNIFORM, empty_weights, + (*weight), empty_weights); + TRT_ENSURE(scale_layer != nullptr); + scale_layer->setName(name.c_str()); + TRT_ENSURE((*scale_layer).getPower().count == 0); + TRT_ENSURE((*scale_layer).getShift().count == 0); + TRT_ENSURE((*scale_layer).getScale().count == 1); + return scale_layer; + } + + StatusOr AddFill(const TRT_TensorOrWeights& value_input, + const TRT_TensorOrWeights& dims_input, + bool is_value_static, bool is_dims_static, + int nbDims, + const nvinfer1::Dims& trt_dims, + ITensorProxyPtr scalar_tensor = nullptr, + ITensorProxyPtr beta_tensor = nullptr, + const float delta = 0) { + // TensorRT IFillLayer requires a rank 0 scalar. + nvinfer1::Dims scalar_dims; + scalar_dims.nbDims = 0; + if (is_value_static) { + StatusOr const_layer = + WeightsToConstant(value_input.weights().GetTrtWeights(), scalar_dims); + if (!const_layer.status().ok()) return const_layer.status(); + scalar_tensor = (*const_layer)->getOutput(0); + } else { + if (scalar_tensor == nullptr) { + StatusOr shuffler_layer = + Reshape(value_input.tensor()->trt_tensor(), scalar_dims); + if (!shuffler_layer.status().ok()) return shuffler_layer.status(); + scalar_tensor = (*shuffler_layer)->getOutput(0); + } + } + + if (beta_tensor == nullptr) { + nvinfer1::Dims beta_shape{1, {nbDims}}; + StatusOr const_layer = + Constant(delta, beta_shape, value_input.TrtDType()); + TF_RETURN_IF_ERROR(const_layer.status()); + beta_tensor = (*const_layer)->getOutput(0); + } + + nvinfer1::IFillLayer* layer = + network_->addFill(trt_dims, nvinfer1::FillOperation::kLINSPACE); + TRT_ENSURE(layer); + if (!is_dims_static) { + layer->setInput(0, *dims_input.tensor()->trt_tensor()); + } + layer->setInput(1, *scalar_tensor->trt_tensor()); + layer->setInput(2, *beta_tensor->trt_tensor()); + return layer; + } + + // Adds a quantization layer that uniformly scales the input tensor + // by the given multiplicative "scaling_factor", then rounds + // (round-to-nearest-ties-to-even) to the nearest integer and clamps in the + // range of [-128, 127]. + StatusOr Quantize(nvinfer1::ITensor* input, + const float scaling_factor, + const std::string& name) { + TRT_ENSURE(input); + TRT_ENSURE(!name.empty()); + // Preprocessor usage here is unavoidable because TRT8 API is new. +#if IS_TRT_VERSION_GE(8, 0, 0, 0) + // The TensorRT IQuantizeLayer divides by the scale factor rather than + // multiplies. To be consistent, in this function we expect a multiplicative + // scale factor, so we take the reciprical. 
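+    // Concretely, with a multiplicative factor s the layer is handed 1/s, so
+    // it computes roughly
+    //   q = clamp(round(x / (1/s)), -128, 127) = clamp(round(x * s), -128, 127),
+    // matching the contract described in the comment above.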
+ StatusOr scaling_const = + this->Constant(1.0f / scaling_factor, 1); + TRT_ENSURE_PTR_OK(scaling_const); + (*scaling_const)->setDimensions(nvinfer1::Dims{0, {}}); + nvinfer1::IQuantizeLayer* quant_layer = + network_->addQuantize(*input, *(*scaling_const)->getOutput(0)); + TRT_ENSURE(quant_layer); + quant_layer->setAxis(1); + return quant_layer; +#else + StatusOr result = + this->AddUniformScale(input, scaling_factor, name); + TRT_ENSURE_PTR_OK(result); + (*result)->setOutputType(0, nvinfer1::DataType::kINT8); + (*result)->setPrecision(nvinfer1::DataType::kFLOAT); + return result; +#endif + } + + // Adds a dequantize layer that casts the input tensor to TensorRT float type + // and scales it uniformly by the given multiplicative "scaling_factor". + StatusOr Dequantize(nvinfer1::ITensor* input, + const float scaling_factor, + const std::string& name) { + TRT_ENSURE(input); + TRT_ENSURE(!name.empty()); +#if IS_TRT_VERSION_GE(8, 0, 0, 0) + StatusOr scaling_const = + this->Constant(scaling_factor, 1); + TRT_ENSURE_PTR_OK(scaling_const); + (*scaling_const)->setDimensions(nvinfer1::Dims{0, {}}); + nvinfer1::IDequantizeLayer* dequant_layer = + network_->addDequantize(*input, *(*scaling_const)->getOutput(0)); + dequant_layer->setAxis(1); + TRT_ENSURE(dequant_layer); + return dequant_layer; +#else + StatusOr result = + this->AddUniformScale(input, scaling_factor, name); + TRT_ENSURE_PTR_OK(result); + (*result)->setOutputType(0, nvinfer1::DataType::kFLOAT); + (*result)->setPrecision(nvinfer1::DataType::kINT8); + return result; +#endif + } + + // Adds TensorRT Q/DQ operations. This is for explicit precision mode. + StatusOr UniformQuantizeDequantizeExplicit( + nvinfer1::ITensor* input, float quantize_scale, float dequantize_scale, + const std::string& name) { + TRT_ENSURE(input); + if (!IS_TRT_VERSION_GE(8, 0, 0, 0)) { + TRT_ENSURE(network_->hasExplicitPrecision()); + } + TRT_ENSURE(IS_TRT_VERSION_GE(7, 1, 0, 0)); + + static int count = 0; + TRT_ENSURE(input->getType() == nvinfer1::DataType::kFLOAT); + std::string quant_name = absl::StrCat(input->getName(), "_quant_", count); + + StatusOr quant = + this->Quantize(input, quantize_scale, quant_name); + TRT_ENSURE_PTR_OK(quant); + + std::string dequant_name = + absl::StrCat(input->getName(), "_dequant_", count); + StatusOr dequant = this->Dequantize( + (*quant)->getOutput(0), dequantize_scale, dequant_name); + TRT_ENSURE_PTR_OK(dequant); + + count++; + return dequant; + } + + StatusOr Reshape(nvinfer1::ITensor* input, + const nvinfer1::Dims& new_shape) { + TRT_ENSURE(input); + nvinfer1::IShuffleLayer* layer = network_->addShuffle(*input); + TRT_ENSURE(layer); + layer->setReshapeDimensions(new_shape); + return layer; + } + + StatusOr FindProducerOf(const nvinfer1::ITensor* tensor) { + const char* name = tensor->getName(); + const int num_layers = network_->getNbLayers(); + for (int i = 0; i < num_layers; i++) { + nvinfer1::ILayer* layer = network_->getLayer(i); + const int num_outputs = layer->getNbOutputs(); + for (int j = 0; j < num_outputs; j++) { + nvinfer1::ITensor* t = layer->getOutput(j); + if (std::string(t->getName()) == name) { + return layer; + } + } + } + return errors::NotFound("could not find producing layer of ", name); + } + + StatusOr UniqueParentOf(const nvinfer1::ILayer* layer, + int input_idx = 0) { + return FindProducerOf(layer->getInput(input_idx)); + } + + nvinfer1::INetworkDefinition* Network() { return network_; } + + private: + nvinfer1::INetworkDefinition* network_; + TrtWeightStore* weight_store_; +}; + +class ShuffleBuilder 
{ + private: + explicit ShuffleBuilder(TRTNetworkBuilder* builder, nvinfer1::ITensor* input) + : builder_(builder) { + layer_ = builder->Network()->addShuffle(*input); + } + + public: + static StatusOr Create(TRTNetworkBuilder* builder, + nvinfer1::ITensor* input) { + TRT_ENSURE(builder != nullptr); + TRT_ENSURE(input != nullptr); + return ShuffleBuilder(builder, input); + } + + ShuffleBuilder& SetReshape(const nvinfer1::Dims& dims) { + layer_->setReshapeDimensions(dims); + return *this; + } + + ShuffleBuilder& SetReshape(nvinfer1::ITensor* shape) { + layer_->setInput(1, *shape); + return *this; + } + + ShuffleBuilder& SetFirstTranspose(const nvinfer1::Permutation& perm) { + layer_->setFirstTranspose(perm); + return *this; + } + + ShuffleBuilder& SetSecondTranspose(const nvinfer1::Permutation& perm) { + layer_->setSecondTranspose(perm); + return *this; + } + + StatusOr Output() { + TRT_ENSURE(layer_ != nullptr); + TRT_ENSURE(layer_->getOutput(0) != nullptr); + return layer_->getOutput(0); + } + + private: + TRTNetworkBuilder* builder_; + nvinfer1::IShuffleLayer* layer_; +}; + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_LAYER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h new file mode 100644 index 00000000..280dc1e7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/quantization_ops.h @@ -0,0 +1,76 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_QUANTIZATION_OPS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_QUANTIZATION_OPS_H_ +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +constexpr std::array kQuantizationOpNames = { + "QuantizeAndDequantizeV2", + "QuantizeAndDequantizeV3", + "FakeQuantWithMinMaxVars", + "FakeQuantWithMinMaxArgs", +}; + +// Operations with supported conversion to Q/DQ ops in TensorRT explicit +// precision mode. +constexpr std::array kExplicitQuantizationOpNames = { + "QuantizeAndDequantizeV2", +}; + +// Contains two scaling factors for quantization and dequantization +// respectively. A shift factor is omitted as TensorRT only supports symmetric +// quantization. +template +struct QuantizationScales { + std::array quantize_scale; + std::array dequantize_scale; +}; + +// In TensorRT 7 and 8, only uniform tensor scaling is supported for +// activations. 
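+// As an illustration (values hypothetical), symmetrically mapping activations
+// in [-6.0f, 6.0f] onto int8 would use quantize_scale = {127.0f / 6.0f} and
+// dequantize_scale = {6.0f / 127.0f}.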
+using UniformQuantizationScales = QuantizationScales; + +// Per-channel scaling is supported for weights in TensorRT version >= 8.0. +template +using PerChannelQuantizationScales = QuantizationScales; + +template +std::ostream& operator<<(std::ostream& os, + const QuantizationScales& scales) { + os << absl::StrFormat("QuantizationScales[quantize={%s},dequantize={%s}]", + absl::StrJoin(scales.quantize_scale, ","), + absl::StrJoin(scales.dequantize_scale, ",")); + return os; +} + +// Returns true if the Tensorflow node is a quantize and dequantize operation. +bool IsQuantizeAndDequantizeOp(const Node*); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_QUANTIZATION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.h new file mode 100644 index 00000000..4dd281ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/ops/slice_ops.h @@ -0,0 +1,70 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_SLICE_OPS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_SLICE_OPS_H_ +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/strided_slice_op.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { +using SliceDims = absl::InlinedVector; + +// Creates a strided slice operation using the given information. This function +// expects that the begin, stride, and end vectors have already been validated. +// This function converts the [begin:stride:end] specification to the TensorRT +// [begin:stride:size] ISliceLayer specification. The following algorithm is +// used to perform this conversion: 1) The given (input_dims, +// [begin:stride:end]) specification is dividied into +// "static dimensions" and "dynamic dimensions". "Dynamic dimensions" +// includes all dimensions of the slice where input_dims[i] == -1. +// 2a) If there are no dynamic dimensions, then the "begin", "stride", and +// "size" variables are passed to the ISLiceLayer creation as build-time +// constants in the form of nvinfer1::Dims objects. +// 2b) If there are any dynamic dimensions, then the "begin", "stride", and +// "size" variables are treated as runtime dynamic shape Tensors in the +// TensorRT graph. In this case, we must calculate "size" at runtime for all +// dynamic dimensions, while static dimensions use the constant values. +// +// Note that when any dynamic indices are present (2b), the "strided_slice_spec" +// must be specified. 
This structure can be obtained through the +// "tensorflow::ValidateStridedSliceOp" function, or it can be constructed +// directly. When the ValidateStridedSliceOp helper function is used, it will +// also return the "begin", "stride", and "end" vectors. When all dimensions are +// static (2a), the "strided_slice_spec" variable is not required. +// +// If the "final_shape" variable is specified, then a reshape operation will be +// added to the graph to achieve this shape. The shape must be fully specified. +// +// "op_instance" is only required if the caller needs to pass this variable +// through to the Converter functions optionally accept it (SetLayerName, +// PrepareTensorForShape). +Status ConvertStridedSliceHelper( + const OpConverterParams* params, const TRT_TensorOrWeights& input, + const PartialTensorShape& input_dims, const SliceDims& begin, + const SliceDims& stride, const SliceDims& end, + std::optional final_shape = std::nullopt, + std::optional op_instance = std::nullopt, + std::optional strided_slice_spec = std::nullopt); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_OPS_SLICE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/timing_cache.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/timing_cache.h new file mode 100644 index 00000000..4d43b1d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/timing_cache.h @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TIMING_CACHE_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TIMING_CACHE_H_ +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include + +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/core/framework/registration/registration.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/statusor.h" +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +// A registry for holding serialized TensorRT autotuner timing caches. +// For TensorRT versions < 8.0, the timing cache is not serializable, so these +// operations become no-ops. +class TimingCacheRegistry { + public: + TimingCacheRegistry() = default; + ~TimingCacheRegistry() = default; + +#if IS_TRT_VERSION_GE(8, 0, 0, 0) + using TimingCache = nvinfer1::ITimingCache; + using TimingCachePtr = std::unique_ptr; +#else + struct TimingCache {}; + using TimingCachePtr = std::unique_ptr; +#endif + + // Insert or update a registry into the map using the given name. The cache + // will be serialized before being placed into the map. + void Upsert(const string& name, TimingCache* cache); + + // Find a timing cache using the given name. 
The provided BuilderConfig is + // used to deserialize the cache. If no timing cache is found, a new timing + // cache is returned. + StatusOr LookUp(const string& name, + nvinfer1::IBuilderConfig* builder_config); + + private: + using SerializedTimingCache = std::vector; + + mutex mu_; + std::unordered_map map_; +}; + +TimingCacheRegistry* GetTimingCacheRegistry(); + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TIMING_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_layout_optimization_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_layout_optimization_pass.h new file mode 100644 index 00000000..e91b3cd8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_layout_optimization_pass.h @@ -0,0 +1,69 @@ +/* Copyright 20121 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_LAYOUT_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_LAYOUT_OPTIMIZATION_PASS_H_ + +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#if !IS_TRT_VERSION_GE(7, 0, 0, 0) +#error From version 2.6, we only support NVIDIA TensorRT version 7 or newer. +#error Please update your environment and relaunch the compilation. 
+#endif + +namespace tensorflow { +namespace tensorrt { +namespace convert { +class TRTLayoutOptimizationPass : public grappler::CustomGraphOptimizer { + public: + TRTLayoutOptimizationPass(const string& name = "TRTLayoutOptimizationPass"); + + string name() const override { return name_; }; + + bool UsesFunctionLibrary() const override { return true; } + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* optimized_graph) override; + + /* void PrintDebugInfo(grappler::Cluster* cluster, + const grappler::GrapplerItem& item); + */ + + private: + const string name_; + string trt_logger_name_; + int minimum_segment_size_; + bool is_dynamic_op_; + int max_cached_batches_; + int64_t max_workspace_size_bytes_; +}; + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_LAYOUT_OPTIMIZATION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h new file mode 100644 index 00000000..abc3bdce --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#if !IS_TRT_VERSION_GE(7, 0, 0, 0) +#error From version 2.6, we only support NVIDIA TensorRT version 7 or newer. +#error Please update your environment and relaunch the compilation. 
+#endif + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +class TRTOptimizationPass : public grappler::CustomGraphOptimizer { + public: + struct ConversionParams { + string trt_logger_name = "DefaultLogger"; + size_t max_batch_size = -1; + size_t max_workspace_size_bytes = 1 << 30; + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32; + int minimum_segment_size = 3; + // Whether to create engine on conversion or execution time + bool is_dynamic_op = false; + // maximum number of cached engines + int max_cached_engines = 1; + bool use_calibration = true; + bool use_implicit_batch = true; + ProfileStrategy profile_strategy = ProfileStrategy::kRange; + bool allow_build_at_runtime = true; + bool use_explicit_precision = false; + }; + + TRTOptimizationPass(const string& name = "TRTOptimizationPass") + : name_(name) {} + + string name() const override { return name_; }; + + bool UsesFunctionLibrary() const override { return true; } + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + const string name_; + + ConversionParams params_; + + std::vector batches_; +}; + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h new file mode 100644 index 00000000..3f44bb5f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h @@ -0,0 +1,72 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_PARAMETERS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_PARAMETERS_H_ + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tensorrt { + +// The PrecisionMode controls the precision used in TRT converted parts of the +// model. Setting PrecisionMode other than FP32 enables TensorRT to select +// lower-precision implementations when searching for the fastest kernels. +// +// For regularized models whose input dynamic range is approximately one, this +// typically produces significant speedups with negligible change in accuracy. +// There is additional complexity when working with INT8, see Calibration. +// +// - FP32 +// - FP16 Enable FP16 layer selection, with FP32 fallback. +// - INT8 Enable Int8 layer selection, with FP32 and FP16 fallback. +// +// Note that TensorRT will still choose a higher-precision kernel if it results +// in overall lower runtime, or if no low-precision implementation exists. 
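+//
+// The helpers declared after the enum convert between the enum and its string
+// spelling; assuming the canonical spellings match the enumerator names, a
+// sketch of typical use is:
+//   TrtPrecisionMode mode;
+//   TF_RETURN_IF_ERROR(TrtPrecisionModeFromName("FP16", &mode));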
+enum class TrtPrecisionMode { FP32, FP16, INT8 }; + +Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name); + +Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode); + +string DebugString(const TrtPrecisionMode mode); + +// Optimization profile generation strategies. +// - `kRange`: create one profile that works for inputs with dimension values +// in the range of [min_dims, max_dims] where min_dims and max_dims are +// derived from the provided inputs. +// - `kOptimal`: create one profile for each input. The profile only works for +// inputs with the same dimensions as the input it is created for. The GPU +// engine will be run with optimal performance with such inputs. +// - `kRangeOptimal`: create the profiles for both `Range` and `Optimal`. +// - `kImplicitBatchModeCompatible`: create the profiles that will produce the +// same GPU engines as the implicit_batch_mode would produce. +enum class ProfileStrategy { + kRange, + kOptimal, + kRangeOptimal, + kImplicitBatchModeCompatible, +}; + +string ProfileStrategyToName(const ProfileStrategy strategy); +Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_PARAMETERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/utils.h new file mode 100644 index 00000000..3e2d54f4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -0,0 +1,399 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/env_var.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +#define TFTRT_ERROR(func, ...) 
\ + do { \ + return func("TFTRT::", __FUNCTION__, ":", __LINE__, ": ", __VA_ARGS__); \ + } while (0) + +#define TFTRT_CHECK_SHAPE_TENSOR(tensor) \ + if (!IsTrtShapeTensorCompatible(tensor)) { \ + TFTRT_ERROR(errors::InvalidArgument, "Tensor of type ", \ + DebugString(tensor.dtype()), " having shape ", \ + tensor.shape().DebugString(), " is not TRT compatible"); \ + } + +namespace tensorflow { +namespace tensorrt { + +static constexpr char kCastOutputTypeAttrName[] = "DstT"; + +#if !IS_TRT_VERSION_GE(8, 2, 0, 0) +template +struct TrtDestroyer { + void operator()(T* t) { + if (t) t->destroy(); + } +}; +template +using TrtUniquePtrType = std::unique_ptr>; +#else +template +using TrtUniquePtrType = std::unique_ptr; +#endif + +// Define a hash function for vector because it is used as the key +// for the engine cache. +struct VectorTensorShapeHasher { + std::size_t operator()(const std::vector& key) const { + return std::hash()(TensorShapeUtils::ShapeListString(key)); + } +}; + +using absl::StrAppend; +using absl::StrCat; + +// This utility template converts an arithmetic type to a string. This function +// is necessary to allow the following function to behave recursively: +// `string DebugString(const std::vector&)`. +template ::value, CType>::type> +string DebugString(const CType& el) { + string el_str = std::to_string(el); + // Prettify std::to_string which can sometimes returns 1.50000 instead of 1.5. + // In short it removes trailing 0s in a string-formatted number. + el_str.erase(el_str.find_last_not_of('0') + 1, std::string::npos); + return el_str; +} +// This utility template converts nested vectors to a string for debug purposes. +template +string DebugString(const std::vector& vector) { + string tmp_s = ""; + for (const auto el : vector) { + StrAppend(&tmp_s, StrCat(DebugString(el), ", ")); + } + return StrCat("{", tmp_s.substr(0, tmp_s.length() - 2), "}"); +} +string DebugString(const nvinfer1::Dims& dims); +string DebugString(const nvinfer1::DataType trt_dtype); +string DebugString(const DataType tf_type); +string DebugString(const nvinfer1::Permutation& permutation, int len); +string DebugString(const ITensorProxyPtr& tensor); +string DebugString(const nvinfer1::ITensor& tensor); +string DebugString(const std::vector& dimvec); +string DebugString(const std::vector& shapes); +string DebugString(const std::vector& shapes); + +template +string DebugString(const absl::InlinedVector& data) { + return absl::StrCat("[", absl::StrJoin(data, ","), "]"); +} + +inline bool HasStaticShape(const nvinfer1::Dims& dims) { + if (dims.nbDims < 0) return false; + for (int d = 0; d < dims.nbDims; ++d) { + if (dims.d[d] < 0) return false; + } + return true; +} + +template +bool HasStaticShape(const T& dims) { + return !absl::c_any_of(dims, [](int i) { return i < 0; }); +} + +// Returns whether a shape is compatible with a TRT shape tensor. +template +inline bool IsTrtShapeTensorCompatible(const TensorShapeType& shape) { + return ( + shape.dims() == 0 || + (shape.dims() == 1 && shape.num_elements() <= nvinfer1::Dims::MAX_DIMS)); +} + +// Returns whether a TF tensor could be interpreted as a TRT shape tensor. +inline bool IsTrtShapeTensorCompatible(const Tensor& tensor) { + return tensor.dtype() == DT_INT32 && + IsTrtShapeTensorCompatible(tensor.shape()); +} + +// Adapts various representations of shape (TF Shape, TRT Dims, plain +// containers) and provides methods for properties (length, volume) and +// conversion between types. 
Note that unlike TF's TensorShape, the underlying +// storage will only contain active dimensions. In the case of scalar shapes, +// `NumDims` is allowed to return 0 or 1, but the `storage_` vector will contain +// 1 element in both cases. In the non-scalar case, `NumDims() == +// storage_.size()`. +class DimsAdapter { + public: + using StorageType = absl::InlinedVector; + + private: + template + using EnableIfNotTensorShapeType = + std::enable_if_t, T>::value>; + + template + using EnableIfInt = std::enable_if_t::value && + std::is_integral::value>; + + public: + //----- Constructors ------ + + // Constructs from an absl::Span. + template + explicit DimsAdapter(absl::Span shape) + : num_dims_(static_cast(shape.size())) { + absl::c_copy(shape, std::back_inserter(storage_)); + } + + // Constructs from an absl::Span. + template + explicit DimsAdapter(const std::vector& shape) + : num_dims_(static_cast(shape.size())) { + absl::c_copy(shape, std::back_inserter(storage_)); + } + + // Constructs from a TRT dims object. + DimsAdapter(const nvinfer1::Dims& dims) : num_dims_(dims.nbDims) { + absl::c_copy(absl::MakeSpan(dims.d, dims.d + std::max(dims.nbDims, 0)), + std::back_inserter(storage_)); + } + + // Constructs explicitly specifying num_dims and storage data. + DimsAdapter(int32_t num_dims, StorageType data) + : num_dims_(num_dims), storage_(std::forward(data)) {} + + // Constructs from a TensorShape or PartialTensorShape. + template + static StatusOr Create(const TensorShapeBase& shape, + bool ignore_first_dim = false) { + if (shape.dims() > nvinfer1::Dims::MAX_DIMS) + return errors::InvalidArgument("dims of TensorShape exceed MAX_DIMS"); + if (ignore_first_dim && shape.dims() <= 0) + return errors::InvalidArgument( + "removing first dim requires explicit batch dimension"); + if (shape.dims() == -1) { + return DimsAdapter(-1, StorageType{}); + } + if (shape.dims() == 0) { + return DimsAdapter(0, StorageType{1}); + } + auto offt = (ignore_first_dim ? 1 : 0); + return DimsAdapter( + absl::MakeSpan(shape.dim_sizes().begin() + offt, shape.dims() - offt)); + } + + // Constructs from a container. + template > + static StatusOr Create(const InputSequence& shape, + bool ignore_first_dim = false) { + if (ignore_first_dim && shape.size() <= 0) { + return errors::InvalidArgument( + "removing first dim requires explicit batch dimension"); + } + return DimsAdapter( + absl::MakeSpan(shape).subspan(ignore_first_dim ? 1 : 0, shape.size())); + } + + //----- Conversion Utilities ------ + + // Converts to an nvinfers::Dims and assign the result to the object passed + // in via the result pointer. + void TrtDims(nvinfer1::Dims* result) const { + result->nbDims = num_dims_; + absl::c_copy(storage_, static_cast(result->d)); + } + + // Converts to an nvinfer1::Dims and return by value. + nvinfer1::Dims AsTrtDims() const { + nvinfer1::Dims result; + TrtDims(&result); + return result; + } + + // Converts to a TensorShape and assigns the result to the object passed in + // via the shape pointer. + Status TensorShape(TensorShape* shape, + std::optional batch_size = std::nullopt) const { + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape( + static_cast(storage_.data()), storage_.size(), shape)); + if (batch_size) shape->InsertDim(0, *batch_size); + return OkStatus(); + } + + // Converts to a PartialTensorShape and assigns the result to the object + // passed in via the shape pointer. 
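+  // (As with TensorShape() above, an optional batch_size is inserted as the
+  // leading dimension when provided.)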
+ Status PartialTensorShape( + PartialTensorShape* shape, + std::optional batch_size = std::nullopt) const { + TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape( + static_cast(storage_.data()), storage_.size(), shape)); + if (batch_size) shape->InsertDim(0, *batch_size); + return OkStatus(); + } + + // Copies the dimension values to the vector passed in via the shape pointer. + template > + Status Vector(std::vector* shape) const { + shape->clear(); + absl::c_copy(storage_, std::back_inserter(*shape)); + return OkStatus(); + } + + //----- Property Accessors ------ + + // Returns true if the shape has no dynamic dimensions. + bool IsStatic() const { + return !absl::c_any_of(storage_, [](auto i) { return i < 0; }); + } + + // Returns product of all dimensions. + int64_t Volume() const { + return absl::c_accumulate(storage_, static_cast(1), + std::multiplies<>()); + } + + int32_t NumDims() const { return num_dims_; } + + // Returns true if the shape should be interpreted as a scalar. This follows + // TensorRT conversions: a scalar shape can have NumDims()==1 or NumDims()==0, + // but the underlying storage_ container has a single dimension of size 1. + bool IsScalar() const { + return (num_dims_ == 0 || num_dims_ == 1) && storage_.size() == 1 && + storage_[0] == 1; + } + + // Returns true if the dimension storage is empty. This indicates an empty + // shape in both the scalar and non-scalar case. + bool IsEmpty() const { return storage_.empty(); } + + string DebugString() const { + auto vol = absl::c_accumulate(storage_, static_cast(1), + std::multiplies<>()); + return absl::StrCat("DimsAdapter(num_dims=", num_dims_, ",shape=[", + absl::StrJoin(storage_, ","), "],", "vol=", vol, ")"); + } + + // Returns beginning iterator for the underlying storage. + StorageType::const_iterator begin() const { return storage_.begin(); } + + // Returns ending iterator for the underlying storage. + StorageType::const_iterator end() const { return storage_.end(); } + + // Returns the size of the dimension at `idx`. + StorageType::value_type dim(size_t idx) const { return storage_[idx]; } + + // Returns a references to the dimension at `idx`. + StorageType::value_type& dim(size_t idx) { return storage_[idx]; } + + //----- Non-Const Operators ------ + + DimsAdapter& Append(int32_t dim) { + StatusOr is_scalar = IsScalar(); + if (!is_scalar.ok()) return *this; + num_dims_ = *is_scalar ? 2 : num_dims_ + 1; + storage_.push_back(dim); + return *this; + } + + DimsAdapter& Prepend(std::optional dim) { + if (dim) { + num_dims_ = IsScalar() ? 
2 : num_dims_ + 1; + storage_.insert(storage_.begin(), *dim); + } + return *this; + } + + Status RemoveBatchDimension() { + if (storage_.empty()) + return errors::InvalidArgument( + "attempted to remove batch dim from scalar"); + num_dims_ -= 1; + storage_.erase(storage_.begin()); + return OkStatus(); + } + + //----- Comparison Operators ------ + + bool operator==(const DimsAdapter& rhs) const { + if (rhs.num_dims_ != num_dims_) return false; + for (int i = 0; i < num_dims_; i++) { + if (rhs.storage_[i] != storage_[i]) return false; + } + return true; + } + + bool operator!=(const DimsAdapter& rhs) const { return !(*this == rhs); } + + private: + int32_t num_dims_{0}; + StorageType storage_{}; +}; + +Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network, + std::vector* input_shapes); + +Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); +Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type); + +// Returns true if an engine built for cached_shapes can also run actual_shapes. +bool AreShapesCompatible(const std::vector& actual_shapes, + const std::vector& cached_shapes); + +// Returns the number of inputs for the engine, which also correspends to the +// number of input tensors for the network. This can differ from the number of +// input bindings, because the number of total input bindings equals the number +// of profiles times the number of engine inputs. +int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine); + +// Returns the string representation for the assigned device or the requested +// device of the given node. +absl::string_view GetDeviceName(const Node* node); + +// Returns the ParsedName representation for the assigned device or the +// requested device string of the given node. If the device string is invalid, +// returns std::nullopt. +std::optional GetDeviceParsedName( + const Node* node); + +// If the given two device assignments as compatible, returns the merge of the +// two assignments. Otherwise, returns std::nullopt. +std::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, const DeviceNameUtils::ParsedName& b); +// Similar to the above, except that the second device assignment is represented +// by a string_view. +std::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, absl::string_view b); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/weights.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/weights.h new file mode 100644 index 00000000..20b66e98 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/convert/weights.h @@ -0,0 +1,295 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
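Assuming a TF-TRT build (GOOGLE_CUDA && GOOGLE_TENSORRT) so that nvinfer1 and TensorShape are available, a hedged usage sketch of the DimsAdapter declared above; the helper name is hypothetical and only members declared in this header are called:

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace tensorrt {

// Illustrative only: convert a TF shape (with an explicit batch dimension in
// front) to TRT dims, dropping that batch dimension.
Status ShapeToTrtDimsSketch(const TensorShape& tf_shape, nvinfer1::Dims* out) {
  // Create() validates MAX_DIMS and optionally strips the first dimension.
  auto dims_or = DimsAdapter::Create(tf_shape, /*ignore_first_dim=*/true);
  TF_RETURN_IF_ERROR(dims_or.status());
  const DimsAdapter& dims = *dims_or;
  if (!dims.IsStatic()) {
    return errors::InvalidArgument("expected a static shape, got ",
                                   dims.DebugString());
  }
  *out = dims.AsTrtDims();  // by-value conversion declared above
  return OkStatus();
}

}  // namespace tensorrt
}  // namespace tensorflow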
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_WEIGHTS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_WEIGHTS_H_ + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { + +// Class to convert TF compile-time constants (e.g. Const nodes) to TRT weight. +class TRT_ShapedWeights { + public: + explicit TRT_ShapedWeights( + nvinfer1::DataType type = nvinfer1::DataType::kFLOAT); + + // Constructs a weights from another weights. + // + // NOTE: this does not copy the underlying buffer but only increase its + // reference count. + TRT_ShapedWeights(const TRT_ShapedWeights& rhs) = default; + + nvinfer1::Weights GetTrtWeights() const; + + const Tensor& GetTensor() const { return tensor_; } + + // Returns a pointer of type const T to the underlying buffer of the tensor. + template + const T* GetPointer() const { + int64 num_elem = + (tensor_.NumElements() * DataTypeSize(tensor_.dtype())) / sizeof(T); + return tensor_.bit_casted_shaped({num_elem}).data(); + } + + // Returns a pointer of type T to the underlying buffer of the tensor. + template + T* GetPointer() { + int64 num_elem = + (tensor_.NumElements() * DataTypeSize(tensor_.dtype())) / sizeof(T); + return tensor_.bit_casted_shaped({num_elem}).data(); + } + + // Fills all the weight values with value. + template + Status SetValues(T value) { + switch (type_) { + case nvinfer1::DataType::kFLOAT: { + float* ptr = tensor_.flat().data(); + std::fill(ptr, ptr + volume_, value); + break; + } + case nvinfer1::DataType::kHALF: { + Eigen::half* ptr = tensor_.flat().data(); + std::fill(ptr, ptr + volume_, Eigen::half(value)); + break; + } + case nvinfer1::DataType::kINT32: { + int32* ptr = tensor_.flat().data(); + std::fill(ptr, ptr + volume_, value); + break; + } + default: + return errors::InvalidArgument( + "Unsupported data type ", tensorflow::tensorrt::DebugString(type_)); + } + return OkStatus(); + } + + Status SetShape(DimsAdapter dims); + void SetShapeUnsafe(DimsAdapter dims) { shape_ = std::move(dims); } + + // Returns total number of elements. Returning 0 means either some dim is 0 + // or the number of dims is 0. Note that a TF scalar constant is marked as + // Dims{0, {1}}, and has a count() == 1. + int64_t count() const { return volume_; } + + size_t size_bytes() const; + + string DebugString() const; + + template + absl::Span GetSpan() const { + return absl::Span(tensor_.flat().data(), volume_); + } + + template + std::vector ToVector() const { + auto span = GetSpan(); + return std::vector(span.data(), span.data() + span.size()); + } + + nvinfer1::DataType TrtDType() const { return type_; } + + const DimsAdapter& Shape() const { return shape_; } + DimsAdapter& Shape() { return shape_; } + + private: + // The shape of the weights. Defaults to the empty shape. + DimsAdapter shape_; + + // This creation method is only used by TrtWeightStore, which creates the + // underlying buffer. 
+ static StatusOr CreateWithTensor(nvinfer1::DataType type, + DimsAdapter dims, + Tensor tensor); + + nvinfer1::DataType type_; + + // All weights should be stored inside TrtWeightStore to make sure lifetime of + // all the underlying tensors are available until the engine is built. For + // this reason, tensor_ should never be reassigned to a different value that + // is not already present in the TrtWeightStore. + Tensor tensor_; + // Contains the volume of the weight's shape. + int64_t volume_; + + friend class TrtWeightStore; +}; + +// Container for TRT_ShapedWeights. We need this container because TRT does not +// manage the lifetime of the weights buffer, it only keeps a pointer to it and +// requires that the data referenced by the pointer be available until the +// building of engine is complete. For more information see +// https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/classnvinfer1_1_1_weights.html +// +// TODO(laigd): consider adding garbage collection to the unused weights. +class TrtWeightStore { + public: + // Gets a TRT_ShapedWeights with 'type' and 'dims'. + StatusOr GetTempWeights(nvinfer1::DataType trt_type, + const DimsAdapter& dims); + + // Gets a TRT_ShapedWeights with the same data type and dimensions as + // 'weights'. + StatusOr GetTempWeights(const TRT_ShapedWeights& weights) { + return GetTempWeights(weights.TrtDType(), weights.Shape()); + } + + private: + // The backend storage of the TRT_ShapedWeights. + std::vector store_; +}; + +// Enumerates the possible types of arguments of a converter. This determines +// what object is contained in TRT_TensorOrWeights, and converters can require +// a specific type for each of their arguments. +enum class TRT_ArgumentType { + TENSOR = 0, + WEIGHTS = 1, + RESOURCE = 2, +}; + +struct OpConverterParams; + +// Represents a TRT-style input to a TF node, it can be either a +// ITensorProxyPtr (representing nvinfer1::ITensor* or SimpleITensor), +// or TRT_ShapedWeights which is compile-time constant. +// +// TODO(laigd): maybe rename it to TrtArgument, or mimic XlaCompiler::Argument. +class TRT_TensorOrWeights { + public: + TRT_TensorOrWeights() {} + TRT_TensorOrWeights(ITensorProxyPtr); + TRT_TensorOrWeights(ITensorProxyPtr tensor, int batch_size); + + // Constructs a wrapper for the given ITensor. + // This is used by Converter when building the TRT network, where the ITensor + // is owned by the TRT network being built. See comment for 'trt_tensor_' + // in trt_proxy_tensor.h. + explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor, int batch_size = -1); + + // Creates a SimpleITensor for trt_dtype and trt_dims and takes ownership of + // the object. Constructs a wrapper for the SimpleITensor. This is used by + // TrtNodeValidator to encapsulate the type and shape information for + // validation of graph nodes, and the created ITensor is fake and temporary, + // and should not be used to build any TRT network. See comment for + // 'simple_tensor_' in trt_proxy_tensor.h. + explicit TRT_TensorOrWeights(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims, int batch_size); + + // Constructs a wrapper for the given weights. + explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights); + + // Constructs a wrapper for the given resource handle. 
+ explicit TRT_TensorOrWeights(const ResourceHandle& resource); + + TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs); + + void operator=(const TRT_TensorOrWeights& rhs); + + bool is_tensor() const { + return initialized_ && arg_type_ == TRT_ArgumentType::TENSOR; + } + bool is_weights() const { + return initialized_ && arg_type_ == TRT_ArgumentType::WEIGHTS; + } + bool is_resource() const { + return initialized_ && arg_type_ == TRT_ArgumentType::RESOURCE; + } + + ITensorProxyPtr tensor() const; + + ResourceHandle resource() const; + + ITensorProxyPtr as_tensor(const OpConverterParams* params); + + TRT_ShapedWeights& weights() { + DCHECK(is_weights()); + return weights_; + } + + const TRT_ShapedWeights& weights() const { + DCHECK(is_weights()); + return weights_; + } + + nvinfer1::Dims GetTrtDims() const; + + Status GetTfType(DataType* tf_type) const; + + int batch_size() const { return batch_size_; } + + string DebugString() const; + + nvinfer1::DataType TrtDType() const { + if (arg_type_ == TRT_ArgumentType::RESOURCE) { + VLOG(0) << "Calling TrtDType() with a RESOURCE argument is undefined " + "behavior."; + } + return arg_type_ == TRT_ArgumentType::TENSOR ? tensor_proxy_ptr_->getType() + : weights_.TrtDType(); + } + + private: + void set_batch_size(int batch_size) { batch_size_ = batch_size; } + + // First dimension of the TF tensor (NOT tensor_) that is represented by + // tensor_ is treated as the "batch dimension" by TRT, and tensor_'s + // dimensions (obtained via tensor_->getDimensions()) do not contain the batch + // dimension. For example, when a TF tensor with shape (A,B,C) is represented + // in TRT, tensor_->getDimensions() will be (B,C) and batch_size_ will be A. + // + // This requires that all tensors in the subgraph that is converted to a TRT + // engine have the same batch size are represented by the first dimension of + // their shape, and Converter will verify this during conversion. The drawback + // is that currently it cannot convert a graph that doesn't have the batch + // size represented in the shapes or the batch sizes are different. See + // b/118387490 for more details. + // + // If use_implicit_batch is false, batch_size_ is unused and + // tensor_->getDimensions() will contain the entire shape (A,B,C). + // + // tensor_proxy_ptr_ is used when arg_type_ == TENSOR. + ITensorProxyPtr tensor_proxy_ptr_ = nullptr; + int batch_size_ = -1; + + // For DT_RESOURCE arguments (there is no corresponding type in TRT). + // resource_ is used when arg_type_ == RESOURCE. + ResourceHandle resource_; + + // weights_ is used when arg_type_ == WEIGHTS. + TRT_ShapedWeights weights_; + bool initialized_ = false; + TRT_ArgumentType arg_type_ = TRT_ArgumentType::WEIGHTS; + + friend class Converter; +}; +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_WEIGHTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h new file mode 100644 index 00000000..8976cc6e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
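A hedged sketch of how a converter might branch on the argument kind held by TRT_TensorOrWeights; the function name is illustrative, and only accessors declared above (is_weights()/weights(), is_tensor()/GetTrtDims()/batch_size()) plus the DebugString/StrCat helpers from utils.h are used:

#include "tensorflow/compiler/tf2tensorrt/convert/weights.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

// Illustrative only: describe one converter input, branching on its kind.
string DescribeInputSketch(const TRT_TensorOrWeights& input) {
  if (input.is_weights()) {
    // Compile-time constant: the weights carry their own shape and count.
    const TRT_ShapedWeights& w = input.weights();
    return StrCat("weights ", w.DebugString(), " with ", w.count(),
                  " elements");
  }
  if (input.is_tensor()) {
    // Runtime tensor: in implicit batch mode the dims exclude the batch dim.
    return StrCat("tensor ", DebugString(input.GetTrtDims()),
                  ", batch_size=", input.batch_size());
  }
  return "resource or uninitialized argument";
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow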
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ + +#include + +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +extern const char* kTfTrtPluginVersion; +extern const char* kTfTrtPluginNamespace; + +// A wrapper class for TensorRT plugin. User application should inherit from +// this class to write custom kernels. +class TrtPlugin : public nvinfer1::IPluginV2Ext { + public: + TrtPlugin() { setPluginNamespace(kTfTrtPluginNamespace); } + + TrtPlugin(const void* serialized_data, size_t length) {} + + TrtPlugin(const TrtPlugin& rhs) : namespace_(rhs.namespace_) {} + + int initialize() noexcept override { return 0; } + + void terminate() noexcept override {} + + void destroy() noexcept override { delete this; } + + void setPluginNamespace(const char* plugin_namespace) noexcept override { + namespace_ = plugin_namespace; + } + + const char* getPluginNamespace() const noexcept override { + return namespace_.c_str(); + } + + protected: + template + void WriteToBuffer(const T& val, char** buffer) const { + *reinterpret_cast(*buffer) = val; + *buffer += sizeof(T); + } + + template + T ReadFromBuffer(const char** buffer) { + T val = *reinterpret_cast(*buffer); + *buffer += sizeof(T); + return val; + } + + private: + std::string namespace_; +}; + +template +class TrtPluginRegistrar { + public: + TrtPluginRegistrar() { + getPluginRegistry()->registerCreator(creator, kTfTrtPluginNamespace); + } + + private: + T creator; +}; + +#define REGISTER_TFTRT_PLUGIN(name) \ + static ::tensorflow::tensorrt::TrtPluginRegistrar \ + plugin_registrar_##name {} + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/segment/segment.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/segment/segment.h new file mode 100644 index 00000000..06a3893d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
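The protected WriteToBuffer/ReadFromBuffer helpers above implement a simple append/consume scheme for trivially copyable fields during plugin (de)serialization. A standalone sketch of the same pattern (not part of the header; it uses memcpy rather than reinterpret_cast so the example stays well-defined):

#include <cstring>
#include <iostream>
#include <vector>

// Append a trivially-copyable value to a byte buffer and advance the cursor.
template <typename T>
void WriteToBuffer(const T& val, char** buffer) {
  std::memcpy(*buffer, &val, sizeof(T));
  *buffer += sizeof(T);
}

// Read a trivially-copyable value back and advance the cursor.
template <typename T>
T ReadFromBuffer(const char** buffer) {
  T val;
  std::memcpy(&val, *buffer, sizeof(T));
  *buffer += sizeof(T);
  return val;
}

int main() {
  // Serialize two plugin "attributes" into a flat buffer...
  std::vector<char> storage(sizeof(int) + sizeof(float));
  char* w = storage.data();
  WriteToBuffer(42, &w);
  WriteToBuffer(0.5f, &w);

  // ...and deserialize them in the same order.
  const char* r = storage.data();
  std::cout << ReadFromBuffer<int>(&r) << " " << ReadFromBuffer<float>(&r)
            << "\n";  // 42 0.5
}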
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace segment { + +constexpr char kTftrtOpMaxBatchSizeAttr[] = "_tftrt_op_max_batch_size"; + +struct SegmentOptions { + // This struct holds per graph segmenting parameters. + // Segment must contain at least this many nodes. + int minimum_segment_size = 2; + bool use_implicit_batch = true; + // The maximum batch size used to build the engines in the graph, when + // use_implicit_batch is true. + std::optional maximum_batch_size = std::nullopt; + // When use_implicit_batch is false or when we are building dynamic engines, + // we allow dynamic non-batch dimensions. + bool allow_dynamic_non_batch_dim = false; + // The name of the device to put the segment on. + std::set exclude_node_list; +}; + +struct NodePtrCompare { + bool operator()(const Node* lhs, const Node* rhs) const { + return lhs->name() < rhs->name(); + } +}; + +struct Segment { + Segment() {} + Segment(const ClusterProperty& property, + const std::set& nodes) + : property(property), nodes(nodes) {} + ClusterProperty property; + std::set nodes; +}; + +// Vector of segments, each entry contains a set of node pointers. +using SegmentVector = std::vector; + +// Get the subgraphs of a graph that can be handled by TensorRT. +// +// @param tf_graph Graph of the network. +// @graph_properties is the static graph properties. +// @param candidate_fn A function that returns OK for a Node* if +// that node can be handled by TensorRT. +// @param segments Returns the TensorRT segments/subgraphs. Each entry +// in the vector describes a subgraph by giving a set of the names of +// all the NodeDefs in that subgraph. +// @return the status. +Status SegmentGraph(const Graph* tf_graph, + const grappler::GraphProperties* graph_properties, + const std::function& candidate_fn, + const std::function& input_candidate_fn, + const std::function& output_candidate_fn, + const SegmentOptions& options, SegmentVector* segments); + +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/segment/union_find.h new file mode 100644 index 00000000..41dd9ff1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -0,0 +1,218 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
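A hedged sketch of populating the SegmentOptions declared above before a SegmentGraph call (the call itself is omitted because it needs a full Graph and grappler::GraphProperties). The node name is a placeholder, and exclude_node_list is assumed to hold node names, since the element type of the std::set was elided in this copy:

#include "tensorflow/compiler/tf2tensorrt/segment/segment.h"

namespace tensorflow {
namespace tensorrt {
namespace segment {

// Illustrative only: options for implicit-batch segmentation.
SegmentOptions MakeOptionsSketch() {
  SegmentOptions options;
  options.minimum_segment_size = 3;          // skip tiny clusters
  options.use_implicit_batch = true;
  options.maximum_batch_size = 8;            // std::optional, only used above
  options.allow_dynamic_non_batch_dim = false;
  options.exclude_node_list.insert("keep_in_tf/node_0");  // placeholder name
  return options;
}

}  // namespace segment
}  // namespace tensorrt
}  // namespace tensorflow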
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ + +#include "absl/types/optional.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/device_name_utils.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +namespace segment { + +// ClusterBatchSize is a data structure to record the batch size we have seen +// for a cluster during segmentation. +// +// With the help of shape inference, all the dynamic batch sizes are converted +// to a negative integer number. +// If the number is -1, then nothing is known about the dynamic batch size. +// Ideally, we should not put nodes with -1 batch size into the same cluster, +// as they will likely have different batch sizes at runtime. However, we +// currently treat -1 as an equivalent class for simple implementation. We may +// need to revise this if it causes performance issues. +// If the number is strictly less than -1, then it represents a equivalent +// class. It is inferred that all the nodes with the same equivalent class +// (strictly less than -1) shall have the same batch size at runtime. +// +// When constructing clusters for implicit batch mode, we support both +// dynamic batch sizes and static batch sizes. As all the nodes inside the same +// cluster shall have the same batch size at runtime, we restrict nodes inside a +// cluster to either have the same dynamic batch size equivalent class or the +// same static batch size value. +// +// Besides, all the nodes with an annotated max batch size inside the same +// cluster shall have the same annotated max batch size. (It is allowed if +// part or all the nodes inside the cluster doesn't have annotated max batch +// size). Static batch sizes are treated as max batch size annotations. The +// converter max batch size is used for an OP with a dynamic batch size and no +// annotated max batch size. +// +// cluster: a = a1[1,3] + a1[1,3] +// ClusterBatchSize: batch_size_ = 1 +// max_batch_size_ = 1 +// +// cluster: b = b1[-1,3] + b2[-1, 3] +// ClusterBatchSize: batch_size_ = -1 +// max_batch_size_ = null +// +// cluster: c = c1[-2,3] + c2[-2, 3](max_batch_size=100) +// ClusterBatchSize: batch_size_ = -2 +// max_batch_size_ = 100 +// +// When constructing cluster for explicit batch mode, all ClusterBatchSize is +// irrelevant. +// + +class ClusterBatchSize { + public: + ClusterBatchSize(); + + bool operator==(const ClusterBatchSize& other); + bool operator!=(const ClusterBatchSize& other) { return !(*this == other); } + + // Sets the batch size assuming that the object doesn't have a batch size yet: + // A non-negative input representing a static batch size value. + // A negative input representing a dynamic batch size equivalent class. + ClusterBatchSize& SetBatchSize(int batch_size); + bool HasBatchSize() const; + int GetBatchSize() const; + + // Sets the max batch size assuming that the object doesn't have a max batch + // size yet. 
+ ClusterBatchSize& SetMaxBatchSize(int max_batch_size); + std::optional GetOptionalMaxBatchSize() const; + + // Merge `other` into the current ClusterBatchSize if the two are not + // conflicting. Two ClusterBatchSizes are conflicting iff they both have a + // value and their values are different. + bool MergeIfCompatible(const ClusterBatchSize& other); + + // Returns a string for the batch size and the annotated max batch size. + // For the batch size: + // If the object has a static batch size, return a string representing a + // non-negative integer. + // If the object has a dynamic batch size, return a string representing a + // negative integer as an equivalent class. + // If the object doesn't have a batch size yet, return "?". + // For the annotated max batch size: + // If the cluster has annotated max batch size in at least one of the nodes, + // return a string representing the annotated max batch size. Otherwise, + // return "?". + std::string ToString() const; + + private: + ClusterBatchSize& SetBatchSize(const std::optional& batch_size); + ClusterBatchSize& SetMaxBatchSize(const std::optional& batch_size); + + std::optional batch_size_; + std::optional max_batch_size_; +}; + +inline std::ostream& operator<<(std::ostream& os, + const ClusterBatchSize& batch_size) { + return os << batch_size.ToString(); +} + +// Represents the accumulated properties of a cluster during segmentation, +// including information about batch size and device assignment. Clusters shall +// have compatible properties in order to be merged together. +class ClusterProperty { + public: + ClusterProperty() {} + ClusterProperty(const ClusterBatchSize& batch_size, + const DeviceNameUtils::ParsedName& device_name); + + // Returns the batch size of the cluster and compresses the path from this + // object to the root object. + const ClusterBatchSize& BatchSize() const { return batch_size_; } + + // Returns the device name of the cluster and compresses the path from this + // object to the root object. + const DeviceNameUtils::ParsedName& DeviceName() const { return device_name_; } + + Status Merge(const ClusterProperty& other); + + private: + ClusterBatchSize batch_size_; + DeviceNameUtils::ParsedName device_name_; +}; + +// Represents a disjoint set of copyable value with type T and accumulated +// property of the values with type P. Most of the methods in this class are +// side-effecting as they also compress the path from the object to the parent +// of its containing set. +template +class UnionFind { + public: + UnionFind() : size_(1), parent_(nullptr) {} + UnionFind(const T& v, const P& p) + : size_(1), parent_(nullptr), value_(v), property_(p) {} + UnionFind(const T& v, P&& p) + : size_(1), parent_(nullptr), value_(v), property_(p) {} + + // Returns the number of elements in the set and compresses the path from + // this object to the root of the set. + int Size() { return FindRoot()->size_; } + + // Returns the accumulated property of all the elements in the set and + // compresses the path from this object to the root of the set. + const P& Property() { return FindRoot()->property_; } + + // Merges this set with 'other'. This updates the size_ and property_ of the + // set. The size_ and property_ of 'other' becomes inaccessible as only the + // size_ and property_ of the root of the set is accessible. + Status Merge(UnionFind* other); + + // Retrieves the value for the root of the set. + const T& ParentValue() { return FindRoot()->value_; } + + // Returns the value for the object. 
+ const T& Value() const { return value_; } + + private: + // Returns the root object for the set and compresses the path from this + // object to the root object. + UnionFind* FindRoot(); + + int size_; + UnionFind* parent_; + T value_; + P property_; +}; + +template +Status UnionFind::Merge(UnionFind* other) { + UnionFind* a = FindRoot(); + UnionFind* b = other->FindRoot(); + if (a == b) return OkStatus(); + + P merged_property(a->property_); + TF_RETURN_IF_ERROR(merged_property.Merge(b->property_)); + b->parent_ = a; + a->size_ += b->size_; + a->property_ = std::move(merged_property); + return OkStatus(); +} + +template +UnionFind* UnionFind::FindRoot() { + if (!parent_) return this; + // Path compression: update intermediate nodes to point to the root of the + // equivalence class. + parent_ = parent_->FindRoot(); + return parent_; +} + +} // namespace segment +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/trt_convert_api.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/trt_convert_api.h new file mode 100644 index 00000000..bba45add --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/trt_convert_api.h @@ -0,0 +1,129 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_TRT_CONVERT_API_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_TRT_CONVERT_API_H_ + +#include +#include +#include + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +struct SavedModelBundle; + +namespace tensorrt { + +struct TfTrtConversionParams { + // Corresponds 'workspaceSize' parameter of + // nvinfer1::IBuilderConfig::setMaxWorkspaceSize. +#if IS_TRT_VERSION_GE(8, 4, 0, 0) + // Must use `LLONG_MAX - 512` to avoid overflow during casting. + size_t max_workspace_size_bytes = LLONG_MAX - 512; +#else + size_t max_workspace_size_bytes = 1 << 30; // 1,073,741,824 +#endif + + // Minimum precision used by the TRT Engine. + TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32; + + // The minimum number of nodes required for a subgraph to be replaced by + // TRTEngineOp. Note that many small TRT subgraphs could be detrimental for + // performance, increasing the minimum segment size can help avoid the + // problem. + int minimum_segment_size = 3; + + // Max number of cached TRT engines for dynamic TRT ops (by default we have + // dynamic TRT ops). + int max_cached_engines = 1; + + // Note that calibration is currently not implemented with the C++ converter. 
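UnionFind::Merge and FindRoot above implement union-by-attachment with path compression, gated by ClusterProperty::Merge. The mechanics are easier to see in a property-free standalone sketch (illustrative only, not the template above):

#include <iostream>
#include <numeric>
#include <vector>

// Minimal union-find with path compression, mirroring FindRoot()/Merge().
struct DisjointSets {
  std::vector<int> parent, size;
  explicit DisjointSets(int n) : parent(n), size(n, 1) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) {
    // Point every node on the path directly at the root.
    if (parent[x] != x) parent[x] = Find(parent[x]);
    return parent[x];
  }
  void Merge(int a, int b) {
    a = Find(a);
    b = Find(b);
    if (a == b) return;
    parent[b] = a;       // b's root now hangs off a's root
    size[a] += size[b];  // only the root's size stays meaningful
  }
};

int main() {
  DisjointSets sets(5);
  sets.Merge(0, 1);
  sets.Merge(1, 2);
  std::cout << sets.size[sets.Find(2)] << "\n";         // 3
  std::cout << (sets.Find(0) == sets.Find(2)) << "\n";  // 1
  std::cout << (sets.Find(3) == sets.Find(4)) << "\n";  // 0
}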
+ // This argument is ignored if precision_mode is not INT8. If set to True, the + // implementation will use the user provided inputs to generate calibration + // data. If set to False, quantization nodes will be expected for every tensor + // in the graph (excluding those which will be fused). If a range is missing, + // an error will occur. Please note that accuracy may be negatively affected + // if there is a mismatch between which tensors TRT quantizes and which + // tensors were trained with fake quantization. + bool use_calibration = true; + + // Whether to enable dynamic shape mode for the TRT engines. It is + // recommended to use_dynamic_shape mode to handle dynamic input shape. + // Enabling dynamic shape mode can also improve the conversion rate of graphs + // with static input shape. + bool use_dynamic_shape = true; + + // In dynamic shape mode we create an engine that can handle various input + // shape ranges. We derive the shape optimization profiles for the TRT engines + // in the graph based on user provided input data and profile_strategy. + ProfileStrategy profile_strategy = ProfileStrategy::kRange; + + // Whether to allow building TRT engines at runtime. If no TensorRT engine can + // be found in cache that can handle the given inputs during runtime, then a + // new TensorRT engine is built at runtime if allow_build_at_runtime=True, + // otherwise native TF is used. We recommend to set this value false and build + // the engine in advance, to avoid runtime overhead. + bool allow_build_at_runtime = true; + + // Record the TRT engine as an attribute of the TRTEngineOp. This is only + // valid when max_cached_engines == 1. Note: the frozen graph together with + // the serialized engines have to be below 2GiB (protobuf size limit). If + // convert_to_static_engine = false, then the converted graph_def only + // contains placeholder TRTEngineOp nodes. + bool convert_to_static_engine = true; +}; + +/** + * Converts the graph with TF-TRT. + * + * Performs TF-TRT conversion and returns the converted GraphDef. If inputs is + * not empty and convert_to_static_engine is requested, we also build the + * engines and convert the engines to static engines. + * + * Arguments: + * - frozen_graph_def input graph, it is assumed to be frozen + * - input_names names of the input tensors + * - output_names names of the output tensors + * - inputs tensors that we will use as input while building the TRT engines + * - conv_params parameters for the TF-TRT conversion + * + * Returns the converted graph_def. + */ +StatusOr ConvertAndBuild( + const GraphDef& frozen_graph_def, const std::vector& input_names, + const std::vector& output_names, + const std::vector>& inputs, + const TfTrtConversionParams& conv_params); + +StatusOr ConvertAndBuild( + SavedModelBundle* bundle, + const std::string& signature_key = "serving_default", + const std::vector>& inputs = {}, + const TfTrtConversionParams& conversion_params = TfTrtConversionParams()); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_TRT_CONVERT_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/py_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/py_utils.h new file mode 100644 index 00000000..b888dc5d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/py_utils.h @@ -0,0 +1,32 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
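A hedged sketch of driving the frozen-GraphDef overload of ConvertAndBuild declared above. The helper name, tensor names, and parameter choices are placeholders, and the sample inputs are assumed to already match the graph's input signature:

#include "tensorflow/compiler/tf2tensorrt/trt_convert_api.h"

namespace tensorflow {
namespace tensorrt {

// Illustrative only: convert a frozen graph and build static engines from
// one representative batch of inputs.
StatusOr<GraphDef> ConvertSketch(const GraphDef& frozen_graph_def,
                                 const std::vector<Tensor>& sample_inputs) {
  TfTrtConversionParams params;
  params.precision_mode = TrtPrecisionMode::FP16;
  params.minimum_segment_size = 3;
  params.use_dynamic_shape = true;
  params.profile_strategy = ProfileStrategy::kRange;
  params.convert_to_static_engine = true;  // embeds serialized engines

  // One inner vector per inference call used to build/profile the engines.
  std::vector<std::vector<Tensor>> inputs = {sample_inputs};
  return ConvertAndBuild(frozen_graph_def, /*input_names=*/{"input:0"},
                         /*output_names=*/{"output:0"}, inputs, params);
}

}  // namespace tensorrt
}  // namespace tensorflow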
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_ + +#include +#include + +namespace tensorflow { +namespace tensorrt { + +bool IsGoogleTensorRTEnabled(); + +std::vector GetRegisteredOpConverters(); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_PY_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h new file mode 100644 index 00000000..2812aa06 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ + +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/mutex.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { +// std::align is not supported, so this function mimic its behavior. +void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space); +} // namespace tensorrt +} // namespace tensorflow + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +class TRTBaseAllocator : public nvinfer1::IGpuAllocator { + // Base allocator class so we can have a virtual destructor; + public: + // python wrapper seems to be not happy with an pure virtual destructor; + virtual ~TRTBaseAllocator() = default; +}; + +class TRTDeviceAllocator : public TRTBaseAllocator { + // Allocator implementation wrapping TF device allocators. + public: + TRTDeviceAllocator(Allocator* allocator); + + // TODO(aaroey): base class doesn't have a virtual destructor, work with + // Nvidia to fix it. 
+ virtual ~TRTDeviceAllocator() { + VLOG(1) << "Destroying allocator attached to " << allocator_->Name(); + } + void* allocate(uint64_t size, uint64_t alignment, + uint32_t flags) noexcept override; + void free(void* memory) noexcept override; + + private: + mutex mu_; + Allocator* allocator_; + + // supporting alignment from allocation request requires a map to free; + std::unordered_map mem_map_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h new file mode 100644 index 00000000..b0935afb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_ + +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/common/datavec.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +using ::tsl::StatusOr; + +// Creates a TensorRT execution context. +ExecutionContext CreateExecutionContext(nvinfer1::ICudaEngine* cuda_engine); + +// Sets input buffers for TRT from a list of input tensors. The input tensors +// are either defined by ctx or by input_vec. +Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::IExecutionContext* execution_context, + const int trt_profile_idx, + std::vector& buffers, bool use_implicit_batch, + int num_batch, + const TrtShapeOptimizationProfile& profiles, + OpKernelContext* ctx = nullptr, + const DataVec* input_vec = nullptr); + +// Returns the shape of a binding from TensorRT. +// +// The binding is identified by its binding_index. The batch_size argument is +// ignored if use_implicit_batch==false. The shape is returned in the last +// argument. +Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine, + const nvinfer1::IExecutionContext* execution_context, + int binding_index, bool use_implicit_batch, + int batch_size, TensorShape& shape); + +// Defines output buffers for TRT. The buffers are allocated by ctx, if ctx is +// not null. Otherwise it is expected that the outputs DataVec is not null, and +// the Tensors in outputs are already allocated. 
+Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::IExecutionContext* execution_context, + int trt_profile_idx, std::vector& buffers, + bool use_implicit_batch, int batch_size = 0, + OpKernelContext* ctx = nullptr, + DataVec* outputs = nullptr); + +// Enqueues TensorRT inference job. The batch_size argument is only relevant in +// implicit batch mode. +Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context, + std::vector& buffers, cudaStream_t stream, + bool use_implicit_batch, int batch_size = 1); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h new file mode 100644 index 00000000..05b5cefb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h @@ -0,0 +1,43 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_EXECUTION_CONTEXT_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_EXECUTION_CONTEXT_H_ + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +// A wrapper for the TensorRT execution context which will destroy the TensorRT +// execution context when the object goes out of scope. +class ExecutionContext : public TrtUniquePtrType { + public: + ExecutionContext(nvinfer1::IExecutionContext* context, bool has_memory) + : TrtUniquePtrType(context), + has_device_memory_(has_memory) {} + static ExecutionContext Create(nvinfer1::ICudaEngine* cuda_engine); + + bool HasDeviceMemory() { return has_device_memory_; } + + private: + bool has_device_memory_; +}; + +}; // namespace tensorrt +}; // namespace tensorflow +#endif +#endif diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_experimental_features.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_experimental_features.h new file mode 100644 index 00000000..1a502c5f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_experimental_features.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
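A hedged sketch of the per-call sequence the trt_engine_utils.h helpers above support on the explicit-batch/dynamic-shape path: bind inputs, bind outputs, then enqueue. The function name is hypothetical, and the buffer vector is assumed to hold void* binding pointers (the element type was elided in this copy):

#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {

// Illustrative only: bind inputs/outputs and enqueue one inference call.
Status RunEngineSketch(nvinfer1::ICudaEngine* engine,
                       nvinfer1::IExecutionContext* context, int profile_idx,
                       const TrtShapeOptimizationProfile& profiles,
                       OpKernelContext* ctx, cudaStream_t stream) {
  std::vector<void*> buffers(engine->getNbBindings(), nullptr);

  TF_RETURN_IF_ERROR(SetTrtEngineInputs(engine, context, profile_idx, buffers,
                                        /*use_implicit_batch=*/false,
                                        /*num_batch=*/0, profiles, ctx));
  TF_RETURN_IF_ERROR(SetTrtEngineOutputs(engine, context, profile_idx, buffers,
                                         /*use_implicit_batch=*/false,
                                         /*batch_size=*/0, ctx));
  return TrtEnqueue(context, buffers, stream,
                    /*use_implicit_batch=*/false);
}

}  // namespace tensorrt
}  // namespace tensorflow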
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_EXPERIMENTAL_FEATURES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_EXPERIMENTAL_FEATURES_H_ + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +bool isExperimentalFeatureActivated(string feature_name); + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_EXPERIMENTAL_FEATURES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h new file mode 100644 index 00000000..2fa22662 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/platform/mutex.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +// This class provides a 1 element queue to match TFs push model to +// TRTs pull model for calibration. When TRT implements a means for +// a push calibration This class should be updated accordingly + +// IInt8EntropyCalibrator2 is preferred for TRT 5.1+. +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 { + public: + // Construct a calibrator for future calibration. + TRTInt8Calibrator( + const std::unordered_map>& dev_buffers, + int batch_size, string engine_name); + + // Construct a finalized calibrator where we don't need to run calibration any + // more, as the calibration data is provided. + TRTInt8Calibrator(const string& calibration_data); + + ~TRTInt8Calibrator(); + + int getBatchSize() const noexcept override; + + bool getBatch(void* bindings[], const char* names[], + int num_bindings) noexcept override; + + // Feed calibration data to the calibrator, and return true if the data is + // accepted. Return false if the calibrator has been terminated. + bool setBatch(const std::unordered_map& data, + const cudaStream_t stream); + + // Wait until the last batch is consumed by the calibrator and set done. + void waitAndSetDone(); + + // Notify that calibration is done and future batches provided by setBatch() + // will be ignored. + void setDone(); + + // If not null, calibration is skipped. 
+ const void* readCalibrationCache(std::size_t& length) noexcept override; + + void writeCalibrationCache(const void* ptr, + std::size_t length) noexcept override; + + const string& getCalibrationTableAsString() { return calibration_table_; } + + private: + const int batch_size_; + + // mutex for condition_variable + mutex cond_mtx_; + + // condition variable to implement producer-consumer queue for calibration + condition_variable cond_; + + // Is calibration finished? + bool done_; + + // Map to keep tensorrt input buffers and sizes keyed with buffer names + std::unordered_map> dev_buffers_; + + bool calib_running_; + bool batch_is_set_; + + string engine_name_; + string calibration_table_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h new file mode 100644 index 00000000..8002df53 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ + +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +// Logger for GIE info/warning/errors +class Logger : public nvinfer1::ILogger { + public: + Logger(string name = "DefaultLogger") : name_(name) {} + void log(nvinfer1::ILogger::Severity severity, + const char* msg) noexcept override; + void suppressLoggerMsgs(nvinfer1::ILogger::Severity severity); + void unsuppressLoggerMsgs(nvinfer1::ILogger::Severity severity); + void unsuppressAllLoggerMsgs() { suppressedMsg_ = 0; } + static Logger* GetLogger(); + + private: + bool isValidSeverity(nvinfer1::ILogger::Severity severity, + const char* msg = nullptr) noexcept; + const string name_; + unsigned int suppressedMsg_ = 0; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h new file mode 100644 index 00000000..dbcea12a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -0,0 +1,261 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
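A hedged sketch of the producer side of TRTInt8Calibrator's one-element queue: the op's compute thread pushes device buffers with setBatch() while TRT's builder thread pulls them via getBatch(), and waitAndSetDone() closes the queue after the last batch. The helper name is hypothetical, and the map's value type (void* device pointers) is an assumption since it was elided in this copy:

#include <string>
#include <unordered_map>

#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"

namespace tensorflow {
namespace tensorrt {

// Illustrative only: feed one calibration batch from the TF side.
// `dev_ptrs` maps engine input names to device buffers already holding the
// current batch; TRT consumes them on its own thread via getBatch().
bool FeedCalibrationBatchSketch(
    TRTInt8Calibrator* calibrator,
    const std::unordered_map<std::string, void*>& dev_ptrs,
    cudaStream_t stream, bool last_batch) {
  // Blocks until the previous batch has been consumed; returns false if the
  // calibrator was terminated in the meantime.
  bool accepted = calibrator->setBatch(dev_ptrs, stream);
  if (last_batch) {
    // Let TRT drain the final batch, then mark calibration as done.
    calibrator->waitAndSetDone();
  }
  return accepted;
}

}  // namespace tensorrt
}  // namespace tensorflow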
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/errors.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +namespace tensorflow { +namespace tensorrt { + +template +class LRUCache { + public: + typedef Value value_type; + typedef Key key_type; + typedef HashFunction hasher; + typedef typename std::unordered_map map_type; + typedef typename map_type::iterator iterator; + typedef typename map_type::const_iterator const_iterator; + + LRUCache() : capacity_(0) {} + explicit LRUCache(size_t capacity) : capacity_(capacity) {} + + size_t capacity() const { return capacity_; } + + void reserve(size_t capacity) { + capacity_ = capacity; + DiscardOld(); + } + + size_t size() const { return objects_.size(); } + + size_t count(const key_type& key) const { return objects_.count(key); } + + value_type& at(const key_type& key) { return Touch(key); } + + const_iterator begin() const { return objects_.begin(); } + const_iterator end() const { return objects_.end(); } + + iterator begin() { return objects_.begin(); } + iterator end() { return objects_.end(); } + + template + std::pair emplace(Args&&... args) { + DiscardOld(1); + std::pair result = + objects_.emplace(std::forward(args)...); + key_type key = result.first->first; + if (result.second) { + keys_.push_front(key); + } else { + TouchNoCheck(key); // The key must exist in this case. + } + return result; + } + + private: + std::unordered_map objects_; + std::list keys_; + size_t capacity_; + value_type not_found_value_; + + value_type& Touch(const key_type& key) { + // Check that the key exists, and let it return std::out_of_range error if + // not. + value_type& value = objects_.at(key); + TouchNoCheck(key); + return value; + } + + void TouchNoCheck(const key_type& key) { + auto rank = std::find(keys_.begin(), keys_.end(), key); + if (rank != keys_.begin()) { + keys_.erase(rank); + keys_.push_front(key); + } + } + + // Creates n free positions in cache + void DiscardOld(size_t n = 0) { + DCHECK(capacity_ >= n) << "Insufficient capacity in cache (capacity = " + << capacity_ << ", requested " << n << ")"; + while (objects_.size() > (capacity_ - n)) { + key_type discard_key = keys_.back(); + keys_.pop_back(); + objects_.erase(discard_key); + } + } +}; + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +struct EngineContext { + EngineContext() {} // Creates an empty context. 
+ EngineContext(TrtUniquePtrType&& cuda_engine, + ExecutionContext&& execution_context) + : cuda_engine_(std::move(cuda_engine)) { + execution_contexts.push_back(std::move(execution_context)); + device_memory_size_ = + cuda_engine_ ? cuda_engine_->getDeviceMemorySize() : 0; + } + EngineContext(TrtUniquePtrType&& cuda_engine, + std::vector&& execution_contexts) + : cuda_engine_(std::move(cuda_engine)), + execution_contexts(std::move(execution_contexts)) { + device_memory_size_ = + cuda_engine_ ? cuda_engine_->getDeviceMemorySize() : 0; + } + + mutex mu; + + nvinfer1::ICudaEngine* GetCudaEngine() { return cuda_engine_.get(); } + + Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx, + bool* has_device_memory) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + if (idx >= execution_contexts.size()) { + return errors::Internal("Requested engine context with index ", idx, + ", but only ", execution_contexts.size(), + "contexts are present."); + } + *exec_ctx = execution_contexts[idx].get(); + *has_device_memory = execution_contexts[idx].HasDeviceMemory(); + return OkStatus(); + } + + int GetNumContexts() { + mutex_lock lock(mu); + return execution_contexts.size(); + } + + size_t GetDeviceMemorySize() { return device_memory_size_; } + + private: + // Note: declaration has to come before execution_contexts, to ensure proper + // order of destruction. + TrtUniquePtrType cuda_engine_; + + public: + // In explicit batch mode, we maintain a vector of contexts for each engine, + // where each context is created for a specific profile. This is because it is + // either not possible or non-trivial to change the profile of a context for + // the following reasons: + // - To switch profiles (from TRT 7), one must first ensure that all inference + // calls in that context are finished. This would require an additional + // synchronization before we call setOptimizationProfile. To avoid this + // extra sync call, we maintain separate execution context for each profile. + // IExecutionContext object is not thread safe: only one thread should use it + // for inference at a time therefore we need a mutex. More details at + // https://docs.nvidia.com/deeplearning/sdk/tensorrt-best-practices/index.html#thread-safety + // Additional discussion about execution context management and thread safety + // at https://github.com/tensorflow/tensorflow/issues/36959 + std::vector execution_contexts TF_GUARDED_BY(mu); + + private: + // Until TRT 8.4 ICudaEngine::getDeviceMemorySize() has a non-negligible + // latency. Since its value remains constant, we can cache it. + size_t device_memory_size_; +}; +// Contains the context required to build the calibration data. +class CalibrationContext { + public: + string TerminateCalibration(); + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector device_tensors_; + + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + // TODO(sami): Use threadpool threads! 
+ std::unique_ptr thr_; + + private: + mutex mu_; + bool terminated_ TF_GUARDED_BY(mu_) = false; + std::string calibration_table_ TF_GUARDED_BY(mu_); +}; + +ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName; + +class TRTEngineCacheResource : public ResourceBase { + public: + // According to the TensorRT API, the logger is considered a singleton by the + // TensorRT library, and multiple instances of IRuntime and/or IBuilder must + // all use the same logger. So here we make it a singleton. + // + // TODO(laigd): use this logger in all places where conversion happens. + static Logger& GetLogger(); + + TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity); + + ~TRTEngineCacheResource() override; + + string DebugString() const override; + + // Returns the EngineContext that is compatible with input_shapes. + // Returns nullptr if no compatible EngineContexts is found in cache. + EngineContext* GetEngineContext(const std::vector& input_shapes); + + // Returns the EngineContext that is compatible with profile_id. + // This function should be only called in explicit batch mode where + // cache size is expected to be at most one. + // Returns nullptr if no compatible EngineContexts is found in cache. + EngineContext* GetEngineContext(const int profile_id); + + // Keep device allocator for TRT. + std::unique_ptr allocator_; + + // Declare cache after allocator so that it is destroyed before allocator is. + LRUCache, std::unique_ptr, + VectorTensorShapeHasher> + cache_; + + // TODO(hinsu): Use different calibration context for the available shapes and + // attach it to each item of the cache. + std::unique_ptr calib_ctx_; + + // This object maintains all the optimization profiles during profile + // generation and engine build. During runtime the list of profiles is used to + // look up a matching profile for the input data. + TrtShapeOptimizationProfile profiles_; +}; + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h new file mode 100644 index 00000000..e2d8fdb6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h @@ -0,0 +1,351 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
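TRTEngineCacheResource above keys its LRUCache on the vector of input shapes, so one engine context is cached per distinct shape combination. As a simplified, TensorFlow-free illustration of that keying pattern (a flat vector<int64_t> and a trivial hasher stand in for std::vector<TensorShape> and VectorTensorShapeHasher; the LRUCache parameter order is assumed as in the sketch earlier):

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Simplified stand-ins for the cache key and its hasher.
using ShapeKey = std::vector<int64_t>;

struct ShapeKeyHasher {
  size_t operator()(const ShapeKey& key) const {
    size_t h = 0;
    for (int64_t d : key) h = h * 131 + std::hash<int64_t>()(d);
    return h;
  }
};

// One cached "engine" per distinct input-shape key; a string stands in for
// std::unique_ptr<EngineContext> purely for illustration.
// LRUCache<ShapeKey, std::string, ShapeKeyHasher> cache;
// cache.reserve(4);
// cache.emplace(ShapeKey{1, 224, 224, 3}, "engine built for 1x224x224x3");
// if (cache.count(ShapeKey{1, 224, 224, 3})) { /* reuse the cached engine */ }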
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/common/datavec.h" +#include "tensorflow/compiler/tf2tensorrt/convert/trt_parameters.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { + +// Stores optimization profile parameters (min/opt/max of each input shape). +// +// A TensorRT optimization profile describes the possible min/max values of +// each dynamic input shape along with an optimum value. These values are used +// by the TensorRT builder to select the best kernel for the optimum value among +// those kernels that are valid for all input tensors in the [min, max] range. +struct OptimizationProfileConfig { + // Length of vector == 2*num_inputs to engine. min[0:num_inputs-1] are the min + // input dimensions for execution tensors. If engine has shape input tensors, + // then min[num_inputs + i] store the shape value for input i. For inputs that + // are not shape tensors min = opt = max = {0, {}}. + // + // When the OptimizationProfileConfig is created from the network definition + // (AddProfiles), then each elements of the min, opt, max vectors are defined. + // When the OptimizationProfileConfig object is restored during engine + // deserialization (RestoreProfiles), then some inputs can be pruned + // (see TrtShapeOptimizationProfile::is_pruned_input_). In that case min[i] + // is not defined for pruned inputs (same is true for opt and max). + std::vector min; + std::vector opt; + std::vector max; + + string DebugString() const { + using absl::StrCat; + return StrCat("[min: ", tensorflow::tensorrt::DebugString(min), + ", opt: : ", tensorflow::tensorrt::DebugString(opt), + ", max: ", tensorflow::tensorrt::DebugString(max), "]"); + } + + // Sets the min/opt/max dimensions for profile. + // + // The given min/opt/max dimensions should satisfy the condition + // min <= opt <= max. Additionally TRT requires that the min/opt/max values + // are compatible with the network input. Compatibility is defined the + // following way: let dim be the shape of an input binding and min/opt/max the + // corresponding profile dims. TRT requires that dim.d[k] must be -1 if + // (min.d[k] != dim.d[k] || opt.d[k] != dim.d[k] || max.d[k] != dim.d[k]). 
+ // + // Parameters: + // network - TensorRT network, used to enumerate all the input tensors + // profile - on exit the profile information will be set for each input tensor + // input_mask - 1 for TRT inputs, 0 for TF inputs that are not TRT inputs + Status SetDimensions(const nvinfer1::INetworkDefinition* network, + nvinfer1::IOptimizationProfile* profile, + const std::vector& input_mask) const { + int n_inputs_trt = network->getNbInputs(); + int n_inputs_tf = opt.size() / 2; + /// TODO(lsugy): check that the sum of the mask equals n_inputs. + if (input_mask.size() != n_inputs_tf) { + return errors::Internal("Incorrect input mask size: ", input_mask.size()); + } + int n_mask_true = 0; + for (bool mask_val : input_mask) { + if (mask_val) { + n_mask_true++; + } + } + if (n_mask_true != n_inputs_trt) { + return errors::Internal( + "Number of true elements in input_mask (", n_mask_true, + ") doesn't match expected TRT inputs (", n_inputs_trt, ")"); + } + int j = 0; + for (int i = 0; i < n_inputs_tf; i++) { + if (input_mask[i]) { + const ITensorProxyPtr input = network->getInput(j); + const char* name = input->getName(); + if (input->isShapeTensor()) { + int idx = i + n_inputs_tf; + VLOG(2) << "Setting shape values for " << name << ", " + << ::tensorflow::tensorrt::DebugString(opt[idx]); + profile->setShapeValues(name, nvinfer1::OptProfileSelector::kMIN, + min[idx].d, min[idx].nbDims); + profile->setShapeValues(name, nvinfer1::OptProfileSelector::kOPT, + opt[idx].d, opt[idx].nbDims); + profile->setShapeValues(name, nvinfer1::OptProfileSelector::kMAX, + max[idx].d, max[idx].nbDims); + } + VLOG(2) << "Setting input dimensions for " << name << ", " + << ::tensorflow::tensorrt::DebugString(opt[i]); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, + min[i]); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, + opt[i]); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, + max[i]); + + j++; + } + } + return OkStatus(); + } + + // Returns true if profile range completely includes the given shapes. + bool IncludesShapes(const std::vector& shapes, + bool has_shape_tensor, + const std::vector& shape_values, + const std::vector& is_pruned_input, + const std::vector& is_shape_tensor) const { + // min, max, and opt must have the same size which is already verified in + // SetDimensions. + if (min.size() != shapes.size() * 2 || + (has_shape_tensor && min.size() != shape_values.size() * 2)) { + VLOG(2) << "Profile size mismatch min size " << min.size() + << " vs input shapes size " << shapes.size() << " " + << shape_values.size(); + return false; + } + for (int i = 0; i < shapes.size(); i++) { + if (is_pruned_input[i]) { + continue; + } + auto current_shape = shapes[i]; + // min, max, and opt must have the same nbDims, which is already verified + // in SetDimensions. + if (min[i].nbDims != current_shape.dims()) { + return false; + } + // Check if range [min, max] includes current_shape. + for (int dim = 0; dim < current_shape.dims(); dim++) { + if ((min[i].d[dim] > current_shape.dim_size(dim)) || + (max[i].d[dim] < current_shape.dim_size(dim))) { + return false; + } + } + } + // Check shape values. + if (has_shape_tensor) { + int offset = shapes.size(); + for (int i = 0; i < shape_values.size(); i++) { + if (is_pruned_input[i] || !is_shape_tensor[i]) { + continue; + } + auto shape_val = shape_values[i]; + // min, max, and opt must have the same nbDims, which is already + // verified in SetDimensions. 
+ if (min[i + offset].nbDims != shape_val.nbDims) { + return false; + } + // Check if range [min, max] includes shape_val. + for (int dim = 0; dim < shape_val.nbDims; dim++) { + if (min[i + offset].d[dim] > shape_val.d[dim] || + max[i + offset].d[dim] < shape_val.d[dim]) { + return false; + } + } + } + } + return true; + } +}; + +// Manages Optimization profiles during TRT Engine construction. +// +// An optimization profile describes a range of dimensions for each TRT network +// input, and the optimal dimensions that the auto-tuner should use for +// optimization. +// +// This class stores the list of input shapes that were seen during the +// build/profile_generation_mode phase, and using them it creates a set of +// OptimizationProfileConfigs. These configs will be added to IBuilderConfig +// before the engine is created. +class TrtShapeOptimizationProfile { + public: + TrtShapeOptimizationProfile() {} + + // Stores input shape information during profile_generation_mode. + void AddShape(const std::vector& shapes) { + input_shapes_.push_back(shapes); + input_shape_values_.push_back(actual_shape_values_); + VLOG(1) << "Collected shape(s) " << DebugString(shapes) << " for profiles."; + } + + // Stores the input mask. + void SetInputMask(const std::vector& input_mask) { + input_mask_ = input_mask; + } + + // Collects ShapeTensorCompatible tensor values. This is needed both during + // profile_generation_mode and during normal inference calls. + Status CollectShapeValues(OpKernelContext* ctx); + + // Collects ShapeTensorCompatible tensor values, used only for unit tests. + Status CollectShapeValues(const DataVec& input); + + void clear() { profiles_.clear(); } + + // Returns the profile number that should be used to execute the network with + // the given input shapes. Returns -1 if none of cached profiles are + // compatible with the given input shapes. + int GetProfileNumber(const std::vector& shapes); + + // Creates optimization profiles and add them to the builder config. + Status ConfigureBuilder(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, + const nvinfer1::INetworkDefinition* network); + + // Creates execution contexts for each optimization profile. + Status CreateExecutionContexts(nvinfer1::ICudaEngine* engine, + std::vector* exec_contexts); + + Status SetInputShapeBinding(int input_index, int binding_index, + nvinfer1::ICudaEngine* cuda_engine, + nvinfer1::IExecutionContext* exec_context) const; + + // Creates optimization profiles profiles_ for the set of concrete input + // shapes collected in input_shapes_. The input_partial_shapes of the network + // is used to ensure that the created optimization profiles are compatible + // with the network. + void InitProfiles(const std::vector& input_partial_shapes, + ProfileStrategy strategy); + + void InitCalibProfile(const std::vector& shapes); + + // Returns number of created profiles. + int GetNumProfiles() const; + + bool HasShape() const { return !input_shapes_.empty(); } + bool NeedProfiles() const { return need_profiles_; } + + // Restores profiles from the engine (used after deserialization). + Status RestoreProfiles(const nvinfer1::ICudaEngine* engine, + int n_network_inputs); + + // Whether the network has any shape tensors. + bool HasShapeTensor() const { return has_shape_tensor_; } + + void SetShapeTensorMask(const nvinfer1::INetworkDefinition* network); + + // Whether the optimization profiles describe input that can be handled with + // a static engine (only 1 profile with min=max). 
+ bool IsStaticCompatible() { + return strategy_ == ProfileStrategy::kOptimal && profiles_.size() == 1 +#if !IS_TRT_VERSION_GE(8, 0, 0, 0) + && !HasShapeTensor() +#endif + ; + // TODO(tfeher): remove !HasShapeTensor() condition once the + // FixShapeValueProfile workaround is turned off. + } + + private: + // Set of input shape vetors that we collect during profile_generation_mode. + std::vector> input_shapes_; + + // Input shape values that we collect during profile_generation_mode. If the + // tensor is not compatible with a TRT shape tensor then an empty shape is + // stored. + std::vector> input_shape_values_; + + // Shape values present in the current inference call. + std::vector actual_shape_values_; + + // The optimization profiles generated from input_shapes_. + std::vector profiles_; + + // The optimization profile for calibration. + OptimizationProfileConfig calib_profiles_; + + // A TRTEngineOp can have resource inputs. These are treated as constants: + // their value is read during conversion and stored as weights in the TRT + // engine. This means that resource inputs have no corresponding TRT engine + // input, and we do not need to provide profile information for these. The + // input mask helps to identify the TRT inputs, where we need to define + // optimization profiles. + std::vector input_mask_; + + // Whether the network has any shape tensors. Initially we assume that the + // network might have a shape value input. This will be updated when the + // network is created / engine is deserialized. + bool has_shape_tensor_ = true; + + // Whether the network/engine requires optimization profiles. + bool need_profiles_ = false; + + // Whether an input tensor is a shape tensor. + std::vector is_shape_tensor_; + + // Whether a network input was pruned (only in TRT 7). + std::vector is_pruned_input_; + + // Optimization profile generation strategy. + ProfileStrategy strategy_; + + // Adds optimization profiles to the builder config. + Status AddProfiles(nvinfer1::IBuilder* builder, + nvinfer1::IBuilderConfig* config, + const nvinfer1::INetworkDefinition* network); + + void SetShapeTensorMask(const nvinfer1::ICudaEngine* engine, int n_inputs); + void SetShapeTensorMask( + const std::vector& input_partial_shapes); + + Status SetPrunedMask(const nvinfer1::ICudaEngine* engine, + int n_network_inputs); + + void ImplicitBatchModeCompatibleStrategy( + const std::vector>& collected_shapes); + void OptimalStrategy( + const std::vector>& collected_shapes); + Status RangeStrategy( + const std::vector>& collected_shapes); +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h new file mode 100644 index 00000000..5eea183f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h @@ -0,0 +1,458 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
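The per-dimension check inside OptimizationProfileConfig::IncludesShapes above reduces to an interval test: a concrete shape is covered by a profile when every dimension lies in [min, max]. A standalone sketch of that core test, using a simplified Dims stand-in instead of nvinfer1::Dims so it compiles without TensorRT:

#include <vector>

// Simplified stand-in for nvinfer1::Dims, for illustration only.
struct FakeDims {
  int nbDims;
  int d[8];
};

// Returns true if every dimension of `shape` lies in [lo, hi], mirroring the
// inner loop of OptimizationProfileConfig::IncludesShapes.
bool DimsInRange(const FakeDims& lo, const FakeDims& hi,
                 const std::vector<int>& shape) {
  if (lo.nbDims != static_cast<int>(shape.size()) || hi.nbDims != lo.nbDims) {
    return false;
  }
  for (int dim = 0; dim < lo.nbDims; ++dim) {
    if (lo.d[dim] > shape[dim] || hi.d[dim] < shape[dim]) return false;
  }
  return true;
}

// Example: a profile with min {1,224,224} and max {8,224,224} covers batch 1..8:
//   DimsInRange({3, {1, 224, 224}}, {3, {8, 224, 224}}, {4, 224, 224})  -> true
//   DimsInRange({3, {1, 224, 224}}, {3, {8, 224, 224}}, {16, 224, 224}) -> false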
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_TENSOR_PROXY_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_TENSOR_PROXY_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { + +namespace tensorrt { + +// SimpleITensor implements part of the ITensor interfaces to support the TF-TRT +// validator, as well as some TF-TRT tests. The former use case only utilizes +// the interfaces related to shape and type information. +class SimpleITensor { + public: + SimpleITensor(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& trt_dims) + : trt_dtype_(trt_dtype), trt_dims_(trt_dims) {} + + SimpleITensor() : dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {} + SimpleITensor(const nvinfer1::Dims& dims) + : trt_dims_(dims), dynamic_range_min_(0.0f), dynamic_range_max_(0.0f) {} + + SimpleITensor(const std::vector& dims) { + trt_dims_.nbDims = dims.size(); + for (int i = 0; i < dims.size(); ++i) { + trt_dims_.d[i] = dims[i]; + } + dynamic_range_min_ = 0.0f; + dynamic_range_max_ = 0.0f; + } + + void setName(const char* name) {} + + const char* getName() const { return ""; } + + void setDimensions(nvinfer1::Dims dimensions) { trt_dims_ = dimensions; } + + nvinfer1::Dims getDimensions() const { return trt_dims_; } + + void setType(nvinfer1::DataType trt_dtype) { trt_dtype_ = trt_dtype; } + + nvinfer1::DataType getType() const { return trt_dtype_; } + + bool isNetworkInput() const { return false; } + + bool isNetworkOutput() const { return false; } + + void setBroadcastAcrossBatch(bool broadcastAcrossBatch) {} + + bool getBroadcastAcrossBatch() const { return false; } + + nvinfer1::TensorLocation getLocation() const { return location_; } + + void setLocation(nvinfer1::TensorLocation location) { location_ = location; } + bool setDynamicRange(float min, float max) { + dynamic_range_max_ = max; + dynamic_range_min_ = min; + return true; + } + + float getDynamicRange() const { + return (std::abs(dynamic_range_min_) + dynamic_range_max_) / 2.f; + } + bool dynamicRangeIsSet() const { return true; } + + void resetDynamicRange() { + dynamic_range_min_ = 0.f; + dynamic_range_max_ = 0.f; + } + float getDynamicRangeMin() const { return dynamic_range_min_; } + + float getDynamicRangeMax() const { return dynamic_range_max_; } + + void setAllowedFormats(nvinfer1::TensorFormats formats) {} + + nvinfer1::TensorFormats getAllowedFormats() const { return 1; } + + bool isShapeTensor() const { return false; } + bool isExecutionTensor() const { return true; } + + private: + nvinfer1::DataType trt_dtype_; + nvinfer1::Dims trt_dims_; + std::string name_; + nvinfer1::TensorLocation location_; + float dynamic_range_min_; + float dynamic_range_max_; +}; + +enum class TensorType : int { kTRT, kSIMPLE }; + +class ITensorProxy { + public: + //! 
ITensor not owned + ITensorProxy(nvinfer1::ITensor* trt_tensor) + : trt_tensor_(trt_tensor), ttype_(TensorType::kTRT) {} + + //! SimpleITensor owned + ITensorProxy(SimpleITensor* simple_itensor) + : simple_tensor_(simple_itensor), ttype_(TensorType::kSIMPLE) {} + + //! SimpleITensor owned + explicit ITensorProxy(nvinfer1::DataType trt_dtype, + const nvinfer1::Dims& trt_dims) + : simple_tensor_(std::unique_ptr( + new SimpleITensor(trt_dtype, trt_dims))), + ttype_(TensorType::kSIMPLE) {} + + //! Variants for testing purposes + ITensorProxy() + : simple_tensor_(std::unique_ptr(new SimpleITensor())), + ttype_(TensorType::kSIMPLE) {} + + explicit ITensorProxy(const nvinfer1::Dims& dims) + : simple_tensor_(std::unique_ptr(new SimpleITensor(dims))), + ttype_(TensorType::kSIMPLE) {} + + explicit ITensorProxy(const std::vector& dims) + : simple_tensor_(std::unique_ptr(new SimpleITensor(dims))), + ttype_(TensorType::kSIMPLE) {} + + bool is_trt_tensor() const { + CHECK(validate()); + return trt_tensor_ != nullptr; + } + + bool is_simple_tensor() const { + CHECK(validate()); + return simple_tensor_ != nullptr; + } + + TensorType ttype() const { return ttype_; } + + nvinfer1::ITensor* trt_tensor() const { + CHECK_NOTNULL(trt_tensor_); + CHECK(ttype_ == TensorType::kTRT); + return trt_tensor_; + } + + SimpleITensor* simple_tensor() const { + CHECK_NOTNULL(simple_tensor_); + CHECK(ttype_ == TensorType::kSIMPLE); + return simple_tensor_.get(); + } + + void setName(const char* name) { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->setName(name); + case TensorType::kSIMPLE: + return simple_tensor_->setName(name); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + const char* getName() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getName(); + case TensorType::kSIMPLE: + return simple_tensor_->getName(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + void setDimensions(nvinfer1::Dims dimensions) { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->setDimensions(dimensions); + case TensorType::kSIMPLE: + return simple_tensor_->setDimensions(dimensions); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + nvinfer1::Dims getDimensions() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getDimensions(); + case TensorType::kSIMPLE: + return simple_tensor_->getDimensions(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + void setType(nvinfer1::DataType type) { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->setType(type); + case TensorType::kSIMPLE: + return simple_tensor_->setType(type); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + nvinfer1::DataType getType() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getType(); + case TensorType::kSIMPLE: + return simple_tensor_->getType(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + bool isNetworkInput() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->isNetworkInput(); + case TensorType::kSIMPLE: + return simple_tensor_->isNetworkInput(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + bool isNetworkOutput() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->isNetworkOutput(); + case TensorType::kSIMPLE: + return simple_tensor_->isNetworkOutput(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + void setBroadcastAcrossBatch(bool broadcastAcrossBatch) { + switch (ttype_) { + case TensorType::kTRT: + return 
trt_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch); + case TensorType::kSIMPLE: + return simple_tensor_->setBroadcastAcrossBatch(broadcastAcrossBatch); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + bool getBroadcastAcrossBatch() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getBroadcastAcrossBatch(); + case TensorType::kSIMPLE: + return simple_tensor_->getBroadcastAcrossBatch(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + nvinfer1::TensorLocation getLocation() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getLocation(); + case TensorType::kSIMPLE: + return simple_tensor_->getLocation(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + void setLocation(nvinfer1::TensorLocation location) { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->setLocation(location); + case TensorType::kSIMPLE: + return simple_tensor_->setLocation(location); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + bool setDynamicRange(float min, float max) { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->setDynamicRange(min, max); + case TensorType::kSIMPLE: + return simple_tensor_->setDynamicRange(min, max); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + bool dynamicRangeIsSet() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->dynamicRangeIsSet(); + case TensorType::kSIMPLE: + return simple_tensor_->dynamicRangeIsSet(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + void resetDynamicRange() { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->resetDynamicRange(); + case TensorType::kSIMPLE: + return simple_tensor_->resetDynamicRange(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + float getDynamicRangeMin() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getDynamicRangeMin(); + case TensorType::kSIMPLE: + return simple_tensor_->getDynamicRangeMin(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + float getDynamicRangeMax() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getDynamicRangeMax(); + case TensorType::kSIMPLE: + return simple_tensor_->getDynamicRangeMax(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } +#if !IS_TRT_VERSION_GE(8, 0, 0, 0) + float getDynamicRange() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getDynamicRange(); + case TensorType::kSIMPLE: + return simple_tensor_->getDynamicRange(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } +#endif + void setAllowedFormats(nvinfer1::TensorFormats formats) { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->setAllowedFormats(formats); + case TensorType::kSIMPLE: + return simple_tensor_->setAllowedFormats(formats); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + nvinfer1::TensorFormats getAllowedFormats() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->getAllowedFormats(); + case TensorType::kSIMPLE: + return simple_tensor_->getAllowedFormats(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + bool isShapeTensor() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->isShapeTensor(); + case TensorType::kSIMPLE: + return simple_tensor_->isShapeTensor(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + bool isExecutionTensor() const { + switch (ttype_) { + case TensorType::kTRT: + return trt_tensor_->isExecutionTensor(); + case TensorType::kSIMPLE: + return 
simple_tensor_->isExecutionTensor(); + } + LOG(FATAL) << "Unsupported itensor_ type"; + } + + private: + bool validate() const { + return (trt_tensor_ && !simple_tensor_) || (!trt_tensor_ && simple_tensor_); + } + + // When ITensorProxy represents an ITensor, the ITensor can be either passed + // by the caller via the constructor that takes an ITensor* as parameter, or + // be created as a SimpleITensor. + // + // In the first case, the ITensor pointer is stored in 'tensor_' below, and + // the ITensor itself is not owned by this class. This method is used by + // Converter (e.g. AddInputTensor) and op converters during TRT network + // construction, where the TRT network owns the ITensor. + // + nvinfer1::ITensor* trt_tensor_ = nullptr; // Not owned. + // In the second case, the created SimpleITensor is stored in + // 'simple_itensor_' below and is owned by this class. SimpleITensor is a fake + // implementation of ITensor and is used for testing and by TrtNodeValidator + // to validate the graph nodes. + std::shared_ptr simple_tensor_ = nullptr; + + TensorType ttype_; +}; + +class ITensorProxyPtr { + public: + ITensorProxyPtr(std::nullptr_t) : p_(nullptr) {} + ITensorProxyPtr(ITensorProxy* p) : p_(p) {} + ITensorProxyPtr(nvinfer1::ITensor* p) : p_(new ITensorProxy(p)) {} + ITensorProxyPtr(SimpleITensor* p) : p_(new ITensorProxy(p)) {} + + ITensorProxyPtr() : p_(new ITensorProxy()) {} + ITensorProxyPtr(const nvinfer1::Dims& dims) : p_(new ITensorProxy(dims)) {} + ITensorProxyPtr(const std::vector& dims) : p_(new ITensorProxy(dims)) {} + + std::shared_ptr p_{nullptr}; + ITensorProxy* operator->() { return p_.get(); } + ITensorProxy* operator->() const { return p_.get(); } + ITensorProxy* operator*() { return p_.get(); } + ITensorProxy* operator*() const { return p_.get(); } +}; + +inline bool operator==(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) { + if (p1.p_ == nullptr) { + return p2.p_ == nullptr; + } + if (p2.p_ == nullptr) { + return p1.p_ == nullptr; + } + return (p1->ttype() == p2->ttype()) && + ((p1->ttype() == TensorType::kTRT && + p1->trt_tensor() == p2->trt_tensor()) || + (p1->ttype() == TensorType::kSIMPLE && + p1->simple_tensor() == p2->simple_tensor())); +} + +inline bool operator!=(const ITensorProxyPtr& p1, const ITensorProxyPtr& p2) { + return !(p1 == p2); +} + +struct ITensorProxyHash { + size_t operator()(const ITensorProxyPtr& tensor) const { + return reinterpret_cast(tensor.p_.get()); + } +}; + +} // namespace tensorrt +} // namespace tensorflow +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_TENSOR_PROXY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h new file mode 100644 index 00000000..e0b9a036 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2tensorrt/utils/trt_testutils.h @@ -0,0 +1,183 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
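ITensorProxyPtr above is a small shared-pointer-like wrapper, and its operator== compares the identity of the wrapped tensor rather than the proxy object or its contents. A short usage sketch; the int element type of the dims vector is an assumption (the template arguments are not visible in this diff text), and the snippet relies on the headers added above plus NvInfer.h:

// Assumed element type: int.
std::vector<int> dims = {1, 28, 28};

ITensorProxyPtr a(dims);   // backed by an owned SimpleITensor
ITensorProxyPtr b = a;     // shares the same underlying ITensorProxy
ITensorProxyPtr c(dims);   // a different SimpleITensor with identical dims

bool same = (a == b);      // true: same wrapped SimpleITensor
bool diff = (a == c);      // false: equality is identity of the wrapped tensor,
                           // not structural equality of the dimensions
a->setType(nvinfer1::DataType::kFLOAT);
int nb = a->getDimensions().nbDims;   // 3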
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_TESTUTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_TESTUTILS_H_ + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include +#include +#include +#include +#include + +#include +#include +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h" +#include "tensorflow/core/framework/node_def.pb.h" // NOLINT +#include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "third_party/tensorrt/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +namespace convert { +// Creates a node with the given op, inputs, and attributes. +NodeDef MakeNodeDef(const std::string& name, const std::string& op, + const std::vector& inputs, + const std::map attrs = {}); + +// Creates a constant node with the given name and values arranged in the given +// shape. +template +NodeDef MakeConstNodeDef(const std::string& name, const std::vector& vals, + const TensorShape& shape) { + Scope s = Scope::NewRootScope(); + Tensor t = test::AsTensor(vals, shape); + auto const_op = ops::Const(s.WithOpName(name), t); + return const_op.node()->def(); +} + +// Creates a constant node with the given name and values, assuming a 1-D shape. +template +NodeDef MakeConstNodeDef(const std::string& name, const std::vector& vals) { + TensorShape shape; + const std::vector shape_dims = {static_cast(vals.size())}; + TF_EXPECT_OK(TensorShapeUtils::MakeShape(shape_dims, &shape)); + return MakeConstNodeDef(name, vals, shape); +} + +// Creates an nvinfer1::Dims struct from the given vector. +nvinfer1::Dims CreateDims(const std::vector& d); + +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. +::testing::Matcher> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5, + bool nan_sensitive = false); + +// nvinfer1::Dims gMock matchers + +// matches nvinfer1::Dims to initializer list or vector of ints +// Example: EXPECT_THAT(my_dims, DimsAreArray({1, 2, 3})) +MATCHER_P(DimsAreArrayHelper, array_value, + absl::StrFormat("%s [%s]", negation ? "are" : "are not", + ::testing::PrintToString(array_value))) { + if (arg.nbDims != array_value.size()) return false; + for (int i = 0; i < arg.nbDims; ++i) { + if (arg.d[i] != array_value[i]) { + return false; + } + } + return true; +} +using DimsAreArray = DimsAreArrayHelperMatcherP>; + +// nvinfer1::INetworkDefinition gMock matchers + +// Checks that layer names are equal to initializer list or vector of strings. +// Example: EXPECT_THAT(my_network, LayerNamesAreArray({"conv1", "conv2"})) +MATCHER_P(LayerNamesAreArrayHelper, array_value, + absl::StrFormat("layer names %s [%s]", negation ? 
"are" : "are not", + ::testing::PrintToString(array_value))) { + if (array_value.size() != arg->getNbLayers()) return false; + for (int i = 0; i < arg->getNbLayers(); ++i) { + if (arg->getLayer(i)->getName() == nullptr) { + return false; + } + } + return true; +} +using LayerNamesAreArray = + LayerNamesAreArrayHelperMatcherP>; + +// Checks layer names are all non-empty. +MATCHER(LayerNamesNonEmpty, "") { + for (int i = 0; i < arg->getNbLayers(); ++i) { + if (arg->getLayer(i)->getName() == nullptr) { + return false; + } + } + return true; +} + +// TRT_ShapedWeights gMock matchers. + +// Checks that the weight dimensions are values are equal to the given values. +// Example: EXPECT_THAT(my_weights, +// ShapedWeightsHasDimsAndValues({1, 2},{1.0f, 2.0f})) +MATCHER_P2(ShapedWeightsHasDimsAndValuesHelper, dims_vec, expected_values, "") { + DimsAdapter dims(dims_vec); + if (arg.Shape() != dims) { + return false; + } + if (arg.count() != expected_values.size()) { + return false; + } + using T = typename decltype(expected_values)::value_type; + const T* actual_values = arg.template GetPointer(); + for (int i = 0; i < expected_values.size(); ++i) { + if (expected_values[i] != actual_values[i]) { + return false; + } + } + return true; +} + +template +using ShapedWeightsHasDimsAndValues = + ShapedWeightsHasDimsAndValuesHelperMatcherP2, + std::vector>; + +// std::vector convenience utilities. + +// Creates a new vector by casting all values of the given InCType vector to +// OutCType. +template +std::vector CastVector( + const gtl::ArraySlice& vals) { // non-absl ok + std::vector res(vals.size()); + std::transform(vals.begin(), vals.end(), res.begin(), + [](const InCType in_val) -> OutCType { + return static_cast(in_val); + }); + return res; +} + +// Creates a new vector of the given size and fills it with an increasing +// sequence starting from the given start_value using std::iota. +template +std::vector CreateVectorIota(int size, CType start_value = CType(0)) { + std::vector res(size); + std::iota(res.begin(), res.end(), start_value); + return res; +} + +} // namespace convert +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_TESTUTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/const_analysis.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/const_analysis.h new file mode 100644 index 00000000..ea7d9eb8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/const_analysis.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_ + +#include + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Backwards dataflow analysis that finds nodes in a graph that must be +// compile-time constants for us to be able to lower the graph to XLA. +// +// The indices of the arguments to `graph` that must be constant are returned in +// `compile_time_const_arg_indices`, if `compile_time_const_arg_indices` is not +// null. +// +// The ids of the nodes in `graph` that must be constant are returned in +// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. +// +// If `edge_filter` is non-null, only propagate const-ness along edges for which +// `edge_filter` returns true. +absl::Status BackwardsConstAnalysis( + const Graph& g, std::vector* compile_time_const_arg_indices, + std::vector* compile_time_const_nodes, + FunctionLibraryRuntime* flib_runtime, + std::function edge_filter_input = nullptr); + +// Given an op kernel and function library runtime, return all the indices of +// inputs that need to be compile time constant. +absl::Status GetCompileTimeConstInputs(const OpKernel* op_kernel, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/frontend_attributes_util.h new file mode 100644 index 00000000..2f8436fa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ + +#include + +#include "absl/types/optional.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Return the FrontendAttributes stored in the AttrSlice if there are some. +// +// Return an InvalidArgument error if some attributes are present but +// cannot be parsed. +absl::StatusOr> +GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_cond.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_cond.h new file mode 100644 index 00000000..e37555b0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -0,0 +1,291 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ + +#include + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Functionalize all the switch-merge nodes of a loop-free graph into If +// nodes. That is, attempt to transform every remaining switch and merge nodes +// in the graph into If nodes. +// +// If `node_filter` is defined, then only conditions for whose nodes +// `node_filter` returns true are functionalized. +// +// Preconditions: +// a) Same as for `FunctionalizeControlFlow` (see comment there). +// b) While loops must have been functionalized before according to +// `node_filter` (e.g., by calling `FunctionalizeWhileLoop` with the same +// filter before calling this function). +absl::Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); + +// Internal functions/classes exposed for testing purposes. +namespace functionalize_cond { + +// All nodes are assumed to be either in no branch, then branch, else branch, +// or both branches (such as merge nodes). +// The code below relies on Else and Then being 0 and 1 (corresponding to the +// switch outputs). Both and Neither are arbitrary. +enum class BranchType { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, +}; + +// When we keep track of which switch/merge node's feed into a node, we record +// 1) predicate for non-dead switch node, +// 2) the switch node itself for dead switch node, +// 3) the merge node itself for merge node. +// Case 1) is an optimization. With this optimization, if there are nodes from +// different switch nodes but those switch nodes have the same predicate, the +// nodes will still have same AncestorState, and they will be clustered into a +// single "If". +struct AncestorNode { + enum class AncestorNodeType { + kPred = 0, + kSwitch = 1, + kMerge = 2, + }; + + OutputTensor output_tensor; + AncestorNodeType type; + + // Compare two AncestorNodes by (node id, index, type). + bool operator<(const AncestorNode& other) const; + bool operator==(const AncestorNode& other) const; + + struct Hash { + size_t operator()(const AncestorNode&) const; + }; +}; + +// StateMap is responsible for mapping from each graph Node to +// * a CondState, where each CondState is a map from predicate to branch (i,e., +// what predicates have to hold or not hold). +// * a AncestorState, where each AncestorState is a set of switch/merge nodes +// that are an ancestor of the node in the graph; +// For efficiency, this class interns the CondState (AncestorState), so that +// CondState (AncestorState) equality comparisons are simply pointer +// comparisons. 
+class StateMap { + public: + explicit StateMap(Graph* graph); + + // Compare two OutputTensors by (node id, index). + struct OutputTensorLess { + bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; + }; + + // A node in the graph is executed when multiple conditions hold. Keep track + // of the predicates that must hold for a node to execute. + using CondState = std::map; + + // Every unique ID is mapped to a CondState. + using CondId = const CondState*; + + // Keep track of which switch/merge node's feed into a node's values. + using AncestorState = std::set; + + // Every unique ID is mapped to a AncestorState. + using AncestorId = const AncestorState*; + + // Returns the CondId for a given node. + CondId LookupCondId(const Node* node) const; + + // Returns the unique CondId for CondState. + CondId GetCondId(const CondState& state); + + // Resets the CondId for a given node. + void ResetCondId(const Node* node, CondId id); + + // Returns the AncestorId for a given node. + AncestorId LookupAncestorId(const Node* node) const; + + // Returns the unique AncestorId for CondState. + AncestorId GetAncestorId(const AncestorState& state); + + // Resets the AncestorId for a given node. + void ResetAncestorId(const Node* node, AncestorId id); + + // Marks `node` as dead. + void MarkDead(const Node* node); + + // Determine branch execution of CondState. + BranchType FindBranchOf(CondId id, OutputTensor predicate) const; + + // Returns textual representation of node's CondState. + string CondStateToString(const Node* node) const; + string CondStateToString(CondId id) const; + + // Returns textual representation of node's AncestorState. + string AncestorStateToString(const Node* node) const; + + // Returns whether the cond state is the dead state. + bool IsDead(CondId id) const; + + // Returns whether the cond state is the empty state. + bool IsEmpty(CondId id) const; + + private: + // Hash for CondState and AncestorState. + struct Hash { + size_t operator()(const CondState& map) const; + size_t operator()(const AncestorState& map) const; + }; + + // Set to keep track of unique CondStates. + // Pointers to the entries in the unordered set are used as identifiers: + // unordered_set guarantees that the pointers remain the same. + std::unordered_set condstate_set_; + + // Mapping from Node id to CondId. + std::vector node_to_condid_map_; + + // Track the CondId for newly inserted nodes. We use a vector to quickly map + // from Node id in the original graph to the CondId, but there will be nodes + // added to the original graph (such as If nodes) whose CondState needs to be + // tracked too. + std::unordered_map added_node_condid_mapping_; + + // AncestorId variants of the CondId members. + std::unordered_set ancestorstate_set_; + std::vector node_to_ancestorid_map_; + std::unordered_map added_node_ancestorid_mapping_; + + // Identifier of the dead flow state. The empty flow state is represented with + // a nullptr. + CondId dead_id_; +}; + +// FunctionalizeCond groups all the state used by functionalizing conditionals +// of the given graph together. +class FunctionalizeCond { + public: + // See comment for function `FunctionalizeCond`. + static absl::Status Functionalize(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter); + + // Build identity node with the same name as the merge that will be replaced + // in case the output is fetched/colocated. 
+ absl::Status AddIdentityNode(const Node* replacee, Node* if_node, int port); + + // Add a If node to the graph defined by def that will, amongst other, replace + // replacee in the graph. + absl::StatusOr AddIfNode(const NodeDef& def, const Node* replacee, + const OutputTensor& predicate); + + // Propagates the state of a newly inserted node. + absl::Status PropagateUpdatedState(const Node* replacee); + + // Dump graph with the CondState annotated. + void DumpGraphWithCondState(const string& name); + + // Adds `switch_id` to the list of Switch node ids. + void AddSwitchId(int switch_id); + + private: + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter); + + // Performs the actual cond functionalization. Iterate over groups of merge + // nodes (linked by common predicates & ancestor IDs), from innermost to + // outermost, and extract into If nodes. + absl::Status FunctionalizeInternal(); + + // Returns the forward flow state propagated along edge `e`. + // This may modify state_map_. + StateMap::CondId StateAlongEdge(const Edge* e); + + // Determines the CondState and AncestorState of all the nodes in the given + // vector where the input is expected in reverse topological order. + // This populates the state_map_. + absl::Status DetermineStates(std::vector rev_topo_order); + + // Determine the CondState for a given node using the incoming edges + // to the node. Note: it is expected that this node's CondState is only + // determined once its input's CondState is. + absl::Status DetermineCondState(Node* dst) { + if (IsMerge(dst)) return DetermineCondStateMerge(dst); + return DetermineCondStateNonMerge(dst); + } + + // Helper functions for DetermineCondState. + absl::Status DetermineCondStateNonMerge(Node* dst); + absl::Status DetermineCondStateMerge(Node* dst); + + // Determines the dst node's CondState by joining the src and dst's CondState + // where either the dst node is a merge or not. + // These may modify state_map_. + absl::StatusOr JoinCondStatesMerge(Node* merge, + StateMap::CondId src, + StateMap::CondId dst); + absl::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst); + + // Determines which switch/merge nodes are ancestors of this node. + absl::Status DetermineAncestorState(Node* dst); + + // Checks if a merge node is redundant and if so removes it from the graph. + absl::Status RemoveRedundantMerge(Node* node); + + // Checks if a switch node is redundant and if so removes it from the graph. + absl::Status RemoveRedundantSwitch(Node* node); + + // Sorts merge nodes (in reverse topological order) in order of increasing + // nesting depth. + void SortMergeNodes(std::vector* merge_order); + + // Deletes all nodes in/consumers reachable from switch/merge nodes that were + // extracted. + void DeleteReachableAndDeadNodes(const std::vector& merge_order); + + // Member used to unique the CondState to a unique CondId (AncestorState to a + // unique AncestorId) and keep track of CondState/CondId + // (AncestorState/AncestorId) per Node. + StateMap state_map_; + + // Mapping from merge nodes to predicate. + std::unordered_map merge_to_predicate_; + + // Mapping from merge nodes to corresponding If node outputs. + std::unordered_map merge_to_replacement_; + + FunctionLibraryDefinition* library_; + Graph* graph_; + + friend class FunctionalizeCondTest; + + std::vector switch_ids_; + + // Controls which nodes are skipped for functionalization. 
+ NodeFilter node_filter_ = {}; +}; + +} // namespace functionalize_cond + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_COND_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_control_flow.h new file mode 100644 index 00000000..ec728885 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" +#include "xla/status_macros.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +const char kFunctionalizeControlFlowFailureMessage[] = + "Failed to functionalize Control Flow V1 ops. Consider using Control " + "Flow V2 ops instead. See " + "https://www.tensorflow.org/api_docs/python/tf/" + "compat/v1/enable_control_flow_v2."; + +// Transformation that converts tf.while_loop() loops into functional While +// operators and tf.cond() conditionals into function If operators, suitable for +// XLA compilation. +// +// If `node_filter` is defined, then only loops and conditions for whose +// nodes `node_filter` returns true are functionalized. + +// If `include_functions` is true, then loops and conditions inside of functions +// that are associated with nodes in `graph` (e.g., a function called from a +// node in `graph`) are also functionalized, otherwise they are not. +// This also handles transitive cases, e.g., a function body will be +// functionalized when it is called in another function that is called by some +// node in `graph` (and so on). The node filter also applies here. +// +// Precondition: +// For any node in a loop or condition for which `node_filter` returns true, +// all nodes inside of the same loop or condition must also return true +// (including nodes in other nested loops and conditions inside of that loop or +// condition). +// This means that a "not to be functionalized" loop or condition is not allowed +// inside a "to be functionalized" loop or condition. +// +// The user of this function is responsible for using a node filter that +// satisfies the above conditions. 
+absl::Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}, + bool include_functions = false); + +absl::Status FunctionalizeControlFlowForGraphDef( + GraphDef* graph_def, FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}, bool include_functions = false); + +// Rewrites the graph by turning V1 control flow structure +// (Switch/Merge/etc.) into V2 control flow structure (If/While), only modifies +// functions that will be executed by XLA. +class FunctionalizeControlFlowForXlaPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h new file mode 100644 index 00000000..970f62da --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_ + +#include "absl/strings/str_join.h" +#include "xla/status_macros.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph.h" + +// Utility functions shared between functionalize cond and while +// or used by other graph optimization passes. + +namespace tensorflow { + +using NodeFilter = std::function; + +// Information about a loop argument. +struct WhileLoopArg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct WhileLoopFrame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + WhileLoopFrame* parent = nullptr; + int num_children = 0; + + // Arguments to this loop. + std::vector args; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + Node* loop_cond = nullptr; + + // Set of nodes that belong to the loop frame. + std::unordered_set nodes; + + // After `ExtractWhileLoopFrames` this is true if for all control flow nodes + // of this frame `node_filter` returns true, i.e., the frame should be + // functionalized, and false otherwise. 
+  bool should_be_functionalized = true;
+};
+
+// Extracts v1 while loops within a graph and creates a map of
+// <frame name, WhileLoopFrame>.
+// If `node_filter` is defined, then we keep track of frames that should be
+// functionalized according to the filter (see comment for
+// `FunctionalizeControlFlow` for more details about node filters).
+absl::Status ExtractWhileLoopFrames(
+    const std::vector<ControlFlowInfo>& cf_info, const Graph* graph,
+    std::unordered_map<string, WhileLoopFrame>* frames,
+    const NodeFilter& node_filter = {});
+
+// Check that the graph has no cycle containing the given node.
+absl::Status CheckNodeNotInCycle(const Node* node, const int num_nodes);
+
+// Comparison function used for sorting nodes consistently:
+// a) resource variables are last, and
+// b) sort lexicographically by name (for deterministic output).
+struct NodeCmpByNameResourcesLast {
+  bool operator()(const Node* lhs, const Node* rhs) const;
+};
+
+// Returns the Node* created from the NodeDef in the Graph.
+absl::StatusOr<Node*> AddNodeDefToGraph(const NodeDef& node_def, Graph* graph);
+
+// Builds a retval node of given type and index.
+absl::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index);
+
+// Returns a textual representation of the names of the nodes in the input.
+template <typename T>
+string NodesToString(const T& nodes) {
+  return absl::StrCat("{",
+                      absl::StrJoin(nodes, ",",
+                                    [](string* output, const Node* node) {
+                                      absl::StrAppend(output, node->name());
+                                    }),
+                      "}");
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_UTIL_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_while.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_while.h
new file mode 100644
index 00000000..e9b361f6
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/functionalize_while.h
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_
+
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+
+// Transformation that converts tf.while_loop() loops into functional While
+// operators, suitable for XLA compilation. If lookup_library is provided, use
+// it to make the library for control flow self-contained.
+//
+// If `node_filter` is defined, then only loops for whose nodes `node_filter`
+// returns true are functionalized.
+//
+// Preconditions:
+// Same as for `FunctionalizeControlFlow` (see comment there).
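A minimal sketch of pairing a `NodeFilter` with the entry point declared immediately below, assuming `NodeFilter` is a predicate over `const Node*`; the `_my_marker` attribute and helper name are hypothetical.

#include "tensorflow/compiler/tf2xla/functionalize_while.h"

// Hypothetical helper: only functionalize loops whose nodes carry "_my_marker".
absl::Status FunctionalizeMarkedLoops(
    tensorflow::Graph* graph, tensorflow::FunctionLibraryDefinition* library) {
  tensorflow::NodeFilter filter = [](const tensorflow::Node* n) {
    return n->attrs().Find("_my_marker") != nullptr;
  };
  return tensorflow::FunctionalizeWhileLoop(graph, library, filter);
}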
+absl::Status FunctionalizeWhileLoop(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_WHILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/graph_compiler.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/graph_compiler.h new file mode 100644 index 00000000..6ab20955 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/graph_compiler.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +// GraphCompiler compiles the graph in topological order in the current +// thread. It also resolves the nondeterminism in the graph by enforcing a +// total order on all inputs to a node. This abstraction helps us create the +// same XLA computation given two structurally equivalent TensorFlow graphs. +// If a function call is visited during the graph traversal, it is then +// compiled through the xla_context into a computation and a `Call` operation +// is inserted to call into that computation. +// +// Note: GraphCompiler was created to remove our dependency to TF Executor in +// the history. There are still some todos so that we can completely decouple +// from Executor. +// +// TODO(yunxing): Remove usage of XlaCompilationDevice. +// +// TODO(yunxing): Remove the hack that wraps XlaExpression within a tensor now +// that we don't use TF Executor to pass around a tensor. +// +// TODO(yunxing): Make XlaOpkernel not a subclass of OpKernel so that it can +// handle a XlaExpression directly instead of a Tensor. This may require our own +// op registration infrastructure instead of FunctionLibraryRuntime. +class GraphCompiler { + public: + GraphCompiler(XlaCompilationDevice* device, Graph* graph, + FunctionLibraryRuntime* flib, + ScopedStepContainer* step_container) + : device_(device), + graph_(graph), + flib_(flib), + step_container_(step_container) {} + + // Compiles the graph. 
The results are written in xla_context stored in the + // resource_manager of the 'XlaCompilationDevice' that's passed into the + // constructor. + absl::Status Compile(); + + private: + // Partially sets params. This partially set params can be reused + // across multiple nodes visit. + void PartiallySetupParams(OpKernelContext::Params* params); + + // Compiles a functional node and writes result to OpkernelContext. A + // functional node represents a defined computation and should be compiled + // using `compiler_`. + absl::Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); + + XlaCompilationDevice* device_; + Graph* graph_; + FunctionLibraryRuntime* flib_; + ScopedStepContainer* step_container_; + // A buffer to hold tensor inputs to a node, this is reused across the graph + // traversal. + absl::InlinedVector tensor_inputs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/graph_compiler_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/graph_compiler_util.h new file mode 100644 index 00000000..ebdf07f7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/graph_compiler_util.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/status_macros.h" +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { + +// Fills in xla_args from the corresponding _Arg nodes in the graph. +absl::Status CreateXlaArgs(const Graph& graph, + std::vector* xla_args); + +// Populate xla_args for the given XLA config. +void PopulateXlaArgs(const tf2xla::Config& config, + std::vector* xla_args); + +// InitGraph creates a graph based on the graph_def, that may then be converted +// to an xla::XlaComputation via ConvertGraphToXla. +// +// The graph is rewritten with _Arg and _Retval nodes, representing the inputs +// and outputs of the function that will be compiled. Each feed id causes a new +// _Arg node to be created, where we first collect all existing edges pointing +// from the named node's output index, and then rewrite them to point from that +// _Arg node instead. Each fetch id causes a new _Retval node to be created, +// with a new edge pointing from the named node's output index to that _Retval +// node. 
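A hedged sketch combining `InitGraph` (declared just below) with `CreateXlaArgs` from earlier in this header; the wrapper name is invented and the argument vector's element type is assumed to be `XlaCompiler::Argument`.

#include <memory>
#include <vector>

#include "tensorflow/compiler/tf2xla/graph_compiler_util.h"
#include "tensorflow/core/platform/errors.h"

// Hypothetical wrapper: build the rewritten Graph, then derive XLA arguments
// from the _Arg nodes that InitGraph inserted.
absl::Status BuildGraphAndArgs(
    const tensorflow::GraphDef& graph_def,
    const tensorflow::tf2xla::Config& config,
    std::unique_ptr<tensorflow::Graph>* graph,
    std::vector<tensorflow::XlaCompiler::Argument>* xla_args) {
  TF_RETURN_IF_ERROR(tensorflow::InitGraph(graph_def, config, graph));
  return tensorflow::CreateXlaArgs(**graph, xla_args);
}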
+absl::Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, + std::unique_ptr* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/case_op.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/case_op.h new file mode 100644 index 00000000..a4c01bea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional switch/case primitive. +// +// The outputs of the branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in branch bodies may read from and write to resource variables. +// Resource variables may be passed as arguments to the branch function's +// bodies. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the branch bodies output. This ensures the branch bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +class XlaCaseOp : public XlaOpKernel { + public: + explicit XlaCaseOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + XlaCaseOp(const XlaCaseOp&) = delete; + void operator=(const XlaCaseOp&) = delete; + + // If the branch_index input is a constant: prunes out all but the branch + // corrresponding to that constant branch index, and returns that branch and + // the literal 0 (as the first and second component of the pair). + // + // If the branch_index input is not a constant: returns unpruned_branches_ and + // the branch_index input. + std::pair, xla::XlaOp> GetPrunedBranchesAndIndex( + XlaOpKernelContext* ctx); + + std::vector unpruned_branches_; + DataTypeVector input_types_; + DataTypeVector output_types_; + bool has_token_input_output_; + std::vector token_input_nodes_; + string original_node_name_; + // Whether to propagate compile time consts into the cond branches. + // This is not supported by default now since it may cause HBM memory + // overheads. 
+ bool propagate_compile_time_consts_ = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CASE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h new file mode 100644 index 00000000..f53f9fd0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +// This header exposes utilities for translating TensorFlow convolution ops into +// XLA ops. +// +// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g. +// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in +// this header to implement a new and exciting convolution op, for example a +// fused TensorFlow op that contains a convolution and other things. + +namespace tensorflow { + +// We don't support integers for convolutions for GPU, so we list the supported +// types for non-gpu and gpu here. +std::vector GetXlaConvTypesForNonGpu(); +std::vector GetXlaConvTypesForGpu(); + +// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA +// convolution. +struct ConvOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static absl::StatusOr Create(int num_spatial_dims, + bool depthwise, + OpKernelConstruction* ctx); + + bool depthwise; + int num_spatial_dims; + std::vector dilations; + std::vector strides; + Padding padding; + std::vector explicit_paddings; + TensorFormat data_format; +}; + +// Helper for the general Conv Op. +struct ConvNDOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static absl::StatusOr Create(OpKernelConstruction* ctx); + + int groups; + int batch_dims; + std::vector dilations; + std::vector strides; + Padding padding; + std::vector explicit_paddings; + TensorFormat data_format; +}; + +// Creates a new XLA forward or backward convolution with the given inputs and +// attributes. 
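To illustrate the intended call pattern, a hedged sketch of a hypothetical fused-convolution kernel that builds `ConvOpAttrs` at construction time and lowers through `MakeXlaForwardConvOp` (declared just below); the class name is invented and the real kernels in conv_ops.cc remain the authoritative reference.

#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

// Hypothetical kernel sketch, not one of the real conv kernels.
class MyFusedConv2DOp : public tensorflow::XlaOpKernel {
 public:
  explicit MyFusedConv2DOp(tensorflow::OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    // Strides, dilations, padding and data_format are read from node attrs.
    auto attrs = tensorflow::ConvOpAttrs::Create(/*num_spatial_dims=*/2,
                                                 /*depthwise=*/false, ctx);
    OP_REQUIRES_OK(ctx, attrs.status());
    attrs_ = *attrs;
  }

  void Compile(tensorflow::XlaOpKernelContext* ctx) override {
    auto conv = tensorflow::MakeXlaForwardConvOp(type_string(), ctx->Input(0),
                                                 ctx->Input(1), attrs_);
    OP_REQUIRES_OK(ctx, conv.status());
    ctx->SetOutput(0, *conv);  // A real fused kernel would append more HLO here.
  }

 private:
  tensorflow::ConvOpAttrs attrs_;
};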
+absl::StatusOr MakeXlaForwardConvOp(absl::string_view type_string, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs); +absl::StatusOr MakeXlaBackpropInputConvOp( + absl::string_view type_string, const xla::Shape& input_shape, + xla::XlaOp filter, xla::XlaOp out_backprop, const ConvOpAttrs& attrs, + xla::XlaOp* input_sizes = nullptr); +absl::StatusOr MakeXlaBackpropFilterConvOp( + absl::string_view type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/cwise_ops.h new file mode 100644 index 00000000..d22e6eb7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific base classes for Unary and Binary Ops. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "xla/client/client_library.h" +#include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +// Coefficient-wise binary operations. Each binary Op expects two +// inputs that can be broadcast to the same shape. The base class +// contains pure virtual methods to override: description is a textual +// description of the operation; and Computation adds the +// implementation of the operation to a xla::XlaBuilder. For most +// arithmetic Ops XLA handles the broadcasting automatically given the input +// tensors. +class XlaBinaryOp : public XlaOpKernel { + public: + explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const DataType lhs = BaseType(input_type(0)); + const DataType rhs = BaseType(input_type(1)); + OP_REQUIRES(ctx, lhs == rhs, + errors::InvalidArgument("Input types of binary op must match")); + } + ~XlaBinaryOp() override = default; + + // Implement the (tensor,tensor)->tensor lambda that should be + // applied to the inputs. The desired computation should be added to + // 'tc->builder()' and '(lhs,rhs)' are the function's inputs and + // (lhs_shape,rhs_shape) are their respective + // shapes. 'broadcast_helper' contains metadata about the shapes of + // the inputs and the dimensions that need to be broadcast, which + // may be useful for Ops that can't use standard XLA automatic + // broadcasting. 
'extend_dimension' is non-empty if lhs and rhs have + // different ranks, and indicates which dimensions of the + // higher-rank input should be matched when broadcasting the + // lower-rank input. See comment below and the documentation on broadcasting + // in the XLA documentation. + virtual xla::XlaOp Computation( + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, + const absl::Span& lhs_shape, const xla::XlaOp& rhs, + const absl::Span& rhs_shape, const BCast& broadcast_helper, + const std::vector& extend_dimensions) = 0; + + void Compile(XlaOpKernelContext* ctx) override; + + // Helper function that performs the broadcasting described by + // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same + // shape. + static std::pair Broadcast( + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/elu_op.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/elu_op.h new file mode 100644 index 00000000..09c88fcb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/elu_op.h @@ -0,0 +1,26 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_ + +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" + +namespace xla { +XlaOp Elu(XlaOp x); +XlaOp Selu(XlaOp x); +} // namespace xla + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h new file mode 100644 index 00000000..8a8a6666 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper methods for XLA Gather Ops. 
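As context for the declarations that follow, a hedged sketch of calling `XlaGather` (declared below in this header) from a kernel's `Compile`, gathering along axis 0 with scalar indices; the function name is a placeholder.

// Hypothetical free function showing the XlaGather call pattern.
void CompileGatherExample(tensorflow::XlaOpKernelContext* ctx) {
  const tensorflow::TensorShape input_shape = ctx->InputShape(0);
  const tensorflow::TensorShape indices_shape = ctx->InputShape(1);
  xla::XlaOp gather_output;
  OP_REQUIRES_OK(ctx, tensorflow::XlaGather(
                          ctx->Input(0), input_shape, ctx->Input(1),
                          indices_shape, /*axis=*/0, /*indices_are_nd=*/false,
                          ctx->expected_output_dtype(0), ctx->input_type(1),
                          ctx->builder(), &gather_output));
  ctx->SetOutput(0, gather_output);
}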
+ +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "xla/client/client_library.h" +#include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +// Adds to builder an XLA computation that performs a gather on input (of +// shape input_shape) keyed on indices (of shape indices_shape). +// +// index_type must be must be DT_INT32 or DT_INT64. +// If `indices_are_nd` is true, the last dimension of `indices` are treated as +// a multidimensional index values. Otherwise, `indices` is treated as a tensor +// of scalar indices. +absl::Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, + const TensorShape& indices_shape, int64_t axis, + bool indices_are_nd, DataType dtype, DataType index_type, + xla::XlaBuilder* builder, xla::XlaOp* gather_output); + +// The implementation of Gather and ResourceGather through XLA. Uses `input` as +// the input instead of context->input(0) in order to allow ResourceGather to +// handle obtaining the data from the ResourceVariable. +absl::Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, + xla::XlaOp input, + const TensorShape& input_shape, + int batch_dims, + xla::XlaOp* gather_output); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/if_op.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/if_op.h new file mode 100644 index 00000000..fc6dd2e0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/if_op.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional conditional primitive. +// +// The outputs of the then/else branches must agree on the number, types, and +// shapes of the Tensors carried around the two bodies. +// +// Computations in then/else bodies may read from and write to resource +// variables. +// Resource variables may be passed as arguments to the then/else function's +// bodies. 
The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the then/else bodies output. This ensures the then/else bodies output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +class XlaIfOp : public XlaOpKernel { + public: + explicit XlaIfOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + XlaIfOp(const XlaIfOp&) = delete; + void operator=(const XlaIfOp&) = delete; + + NameAttrList then_branch_; + NameAttrList else_branch_; + DataType cond_type_; + DataTypeVector input_types_; + DataTypeVector output_types_; + std::vector output_shapes_; + bool has_token_input_output_; + std::vector token_input_nodes_; + string original_node_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/if_while_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/if_while_utils.h new file mode 100644 index 00000000..1800e5a6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/if_while_utils.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +extern const char kPropagateCompileTimeConsts[]; + +// Convert arguments in `args` to constants provided they are compile-time +// constants and they satisfy the condition in `should_resolve_constant`. The +// argument `xla_expression_offset` determines what offset is needed to get the +// input expression from context given the argument index in `args`. +// +// Returns a list of indices which were converted to constants. +absl::InlinedVector ConvertCompileTimeConstArgumentsToConst( + XlaOpKernelContext* ctx, std::vector* args, + int xla_expression_offset, + std::function should_resolve_constant); + +// Find and populate `must_be_const_nodes` and `body` of the function +// corresponding to the kernel with context `ctx` with name `func_name`. 
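A hedged sketch of calling `FindMustBeConstNodes` (declared just below) for a branch function; the helper name is invented and the element type of `must_be_const_nodes` is assumed to be `bool`.

// Hypothetical helper: find which nodes of `branch` must be compile-time consts.
void ExampleFindConstNodes(tensorflow::XlaOpKernelContext* ctx,
                           const tensorflow::NameAttrList& branch) {
  std::vector<bool> must_be_const_nodes;
  const tensorflow::FunctionBody* body = nullptr;
  OP_REQUIRES_OK(ctx, tensorflow::FindMustBeConstNodes(
                          ctx, branch, &must_be_const_nodes, &body));
  // Entries set to true mark nodes of body->graph whose values must be
  // resolved to constants before compilation.
}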
+absl::Status FindMustBeConstNodes(XlaOpKernelContext* ctx, + const NameAttrList& func_name, + std::vector* must_be_const_nodes, + const FunctionBody** body); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IF_WHILE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/image_resize_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/image_resize_ops.h new file mode 100644 index 00000000..8d0fff23 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/image_resize_ops.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_IMAGE_RESIZE_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_IMAGE_RESIZE_OPS_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "xla/primitive_util.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class ResizeNearestNeighborOp : public XlaOpKernel { + public: + explicit ResizeNearestNeighborOp(OpKernelConstruction* ctx); + void Compile(XlaOpKernelContext* ctx) override; + + protected: + bool align_corners_ = true; + bool half_pixel_centers_ = true; + bool is_kernel_bilinear_ = false; +}; + +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + protected: + bool align_corners_ = true; + bool half_pixel_centers_ = true; + bool is_kernel_bilinear_ = true; +}; + +class ResizeBilinearGradOp : public XlaOpKernel { + public: + explicit ResizeBilinearGradOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + protected: + bool align_corners_; + bool half_pixel_centers_ = true; + xla::PrimitiveType output_type_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_IMAGE_RESIZE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/index_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/index_ops.h new file mode 100644 index 00000000..ef2b9e6b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/index_ops.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Declarations of the ArgMax/ArgMin ops using a pure XLA implementation. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_INDEX_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_INDEX_OPS_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class XlaArgMinMaxOp : public XlaOpKernel { + public: + explicit XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min); + void Compile(XlaOpKernelContext* ctx) override; + + private: + const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)? +}; + +class XlaArgMaxOp : public XlaArgMinMaxOp { + public: + explicit XlaArgMaxOp(OpKernelConstruction* ctx); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_INDEX_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h new file mode 100644 index 00000000..f9c42e03 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_LIGHT_OUTSIDE_COMPILATION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_LIGHT_OUTSIDE_COMPILATION_H_ + +#include + +#include "absl/status/statusor.h" +#include "tensorflow/compiler/tf2xla/kernels/callback.pb.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Using std::map as the maps are presumed to be tiny, and we want a +// deterministic iteration order. +// +// Dimension -> bound. +using DimensionBoundsMap = std::map; + +// Output -> dimension -> bound. +using OutputDimensionBoundsMap = std::map; + +// Generic kernel for registering TF2XLA kernels which call back into the TF +// runtime to run a given kernel defined by the wrapped node. +// +// Cf. example usages in light_outside_compilation_kernels_for_test.cc. +// +// Currently does not support dynamic shape or resource variables. Currently +// works only on GPU. +class LightOutsideCompilationOp : public XlaOpKernel { + public: + explicit LightOutsideCompilationOp(OpKernelConstruction* context); + void Compile(XlaOpKernelContext* ctx) override; + + // Override to provide statically known bounds on output in case of dynamic + // shapes. 
+ virtual absl::StatusOr DynamicOutputDimensions( + const NodeDef& ndef, XlaOpKernelContext* ctx) const { + return OutputDimensionBoundsMap{}; + } + + private: + absl::Status CompileToCustomCallCallingTfKernel(int graph_def_version, + const NodeDef& node_def, + XlaOpKernelContext* ctx); + static absl::Status CallTfKernel(void* stream_handle, void** buffers, + const char* opaque, int opaque_len); + + NodeDef def_; + int graph_def_version_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_LIGHT_OUTSIDE_COMPILATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/random_ops_util.h new file mode 100644 index 00000000..11ff4460 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -0,0 +1,96 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +inline constexpr int kRandomKeyInputIdx = 1; +inline constexpr int kRandomCounterInputIdx = 2; +inline constexpr int kRandomAlgInputIdx = 3; + +// Returns a tensor containing 'shape' random values uniformly distributed in +// the range [minval, maxval). The raw random bits are generated by the given +// `bit_generator` and converted to the requested data type and range. This +// routine requires 2 32-bit integer seeds and currently only supports 'shape's +// of type F32, S32 and S64. +xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, + xla::XlaOp seeds, const xla::Shape& shape, + xla::XlaOp minval, xla::XlaOp maxval); + +// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise. +// It masks the last 16 bit. With normal rounding, values near "maxval" would be +// converted to "maxval" which is out of range ["minval", "maxval"). In +// addition, the distribution near the limit is not uniform. +xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype); + +// Combines two signed 32-bit seeds into a single unsigned 64 bit seed. +xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1); + +absl::StatusOr GetAlgId(XlaOpKernelContext* ctx, int alg_input_idx); + +xla::RngOutput BitGenerator(xla::RandomAlgorithm const& alg, xla::XlaOp key, + xla::XlaOp counter, const xla::Shape& shape); + +// Gets user specified RNG algorithm. 
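A hedged sketch of the stateless-uniform helper declared earlier in this header; the 128x256 output shape is a placeholder and the device type string is passed in rather than hard-coded.

#include "xla/hlo/builder/xla_builder.h"
#include "xla/shape_util.h"

// Hypothetical helper: uniform F32 values in [0, 1) from an s32[2] seed tensor.
xla::XlaOp UniformFromSeeds(xla::XlaBuilder* b, xla::XlaOp seeds,
                            absl::string_view device_type_string) {
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 256});
  xla::XlaOp minval = xla::ConstantR0<float>(b, 0.0f);
  xla::XlaOp maxval = xla::ConstantR0<float>(b, 1.0f);
  return tensorflow::StatelessRngUniform(device_type_string, seeds, shape,
                                         minval, maxval);
}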
+absl::StatusOr AlgorithmFromInput( + XlaOpKernelContext* ctx, int alg_input_idx, + absl::string_view device_type_string); + +xla::XlaOp MaybeSliceCounter(xla::RandomAlgorithm const& alg, + TensorShape const& counter_shape, + xla::XlaOp counter); + +DataType MaybeConvertBF16ToF32(DataType const& dtype); + +// Builds uniform randoms from a stateless RNG with given data type and device +// type, in the given low and high range, where low and high are expressed in +// XLA functions. +absl::StatusOr BuildUniformRandoms( + XlaOpKernelContext* ctx, DataType dtype, string device_type_string, + TensorShape shape, + std::function lo, + std::function hi); + +// Overloads BuildUniformRandoms where low and high range are expressed in XLA +// ops. +absl::StatusOr BuildUniformRandoms(XlaOpKernelContext* ctx, + DataType dtype, + string device_type_string, + xla::Shape xla_shape, + xla::XlaOp lo, xla::XlaOp hi); +} // namespace tensorflow + +namespace xla { + +int GetCounterSize(RandomAlgorithm const& alg); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RANDOM_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/reduction_ops.h new file mode 100644 index 00000000..9c222224 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific base classes for Reduction Ops. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +// Reduction operations. The base class contains pure virtual methods +// to override: description is a textual description of the mapped +// function; InitialValue constructs the base case for the reduction; +// BuildReducer adds the implementation of the reduction lambda to a +// xla::XlaBuilder and BuildFinalizer adds the +// implementation of the finalizer lambda (if there is one) to a +// xla::XlaBuilder. +class XlaReductionOp : public XlaOpKernel { + public: + XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); + ~XlaReductionOp() override = default; + + // Return the base case for the reduction. + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; + + // Implement the (scalar,scalar)->scalar lambda that should be + // applied to each pair of elements to be reduced. The desired + // computation should be added to 'builder' and + // '(scalar_lhs,scalar_rhs)' are the function's inputs. 
+ virtual void BuildReducer(xla::XlaBuilder* builder, + const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) = 0; + + // Applies a transformation to the output of the reduction. The desired + // computation should be added to 'builder'. Argument 'input' is the original + // input of the reduction; 'reduce_output' is the output of the reduction. + // Returns the transformed reduction output. Defaults to returning + // 'reduce_output' converted to the input type. + virtual xla::XlaOp BuildFinalizer( + xla::XlaBuilder* builder, const xla::XlaOp& input, + const xla::XlaOp& reduce_output, + const std::vector& dimensions_to_reduce); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + // True if the number of dimensions should be maintained. + bool keep_dims_; + + protected: + DataType reduction_type_; + xla::PrimitiveType xla_reduction_type_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/relu_op.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/relu_op.h new file mode 100644 index 00000000..b980df77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/relu_op.h @@ -0,0 +1,26 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_ + +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" + +namespace xla { +XlaOp Relu(XlaOp x); +XlaOp Relu6(XlaOp x); +} // namespace xla + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/resampler_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/resampler_ops.h new file mode 100644 index 00000000..7ecc2e93 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/resampler_ops.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RESAMPLER_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RESAMPLER_OPS_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// XLA op kernel for both contrib and addon flavors of TenforFlow Resampler +class ResamplerOp : public XlaOpKernel { + public: + explicit ResamplerOp(OpKernelConstruction* ctx); + void Compile(XlaOpKernelContext* ctx) override; +}; + +// XLA op kernel for both contrib and addon flavors of TenforFlow Resampler +// gradient. +class ResamplerGradOp : public XlaOpKernel { + public: + explicit ResamplerGradOp(OpKernelConstruction* ctx); + void Compile(XlaOpKernelContext* ctx) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RESAMPLER_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/rng_converter_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/rng_converter_utils.h new file mode 100644 index 00000000..ec45834d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/rng_converter_utils.h @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RNG_CONVERTER_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RNG_CONVERTER_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/rng_alg.h" + +namespace tensorflow { + +// Given the XLA::RandomAlgorithm, return the Tensorflow equivalent. +Algorithm ToTensorflowAlgorithm(xla::RandomAlgorithm alg); + +// Given the device type, return the default XLA::RandomAlgorithm +xla::RandomAlgorithm DefaultRngAlgForDeviceType( + absl::string_view device_type_string); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RNG_CONVERTER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/shape_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/shape_util.h new file mode 100644 index 00000000..bfce0919 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Converts a TensorShape to a constant Tensor. +// +// The input TensorShape input_shape is used to populate the elements of +// shape_constant, which is modified in place. +absl::Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h new file mode 100644 index 00000000..e4aeb015 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -0,0 +1,135 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Whether the input expression at `index` corresponds to a TensorList. +bool IsTensorListInput(XlaOpKernelContext* ctx, int index); + +// Whether the TensorList is initialized (has known data type and shape). +absl::Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized); + +// Whether the TensorList is a nested TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +absl::Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list); + +// Builds a non-nested TensorList from `buffer` and `push_index`. +absl::Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, + xla::XlaOp* output_list); + +// Returns buffer shape for the TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +absl::Status GetTensorListBufferShape(xla::XlaOp list, + xla::Shape* buffer_shape); + +// Returns buffer for the TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +absl::Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer); + +// Returns push index for the TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. 
+absl::Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index); + +// Returns a new TensorList with given push_index. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +absl::Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, + xla::XlaOp* result); + +// Returns an uninitialized TensorList. +xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, + int64_t leading_dimension, + bool leading_size_is_dynamic, + xla::XlaOp leading_dim_size); + +// Returns leading dimension for the TensorList as well as a dynamic op +// representing the dynamic size. Input can be initialized or uninitialized +// TensorList. Non-nested and nested TensorLists are both supported. +absl::Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, + bool* leading_dim_is_dynamic, + xla::XlaOp* leading_dim_dynamic_size); + +// Returns TensorList shape for the element shape. +// Element shape must be a normal tensor shape. +absl::Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, + int64_t leading_dim, + bool leading_dim_is_dynamic, + xla::Shape* tensor_list_shape); + +// Returns a TensorList filled by zeros with the given shape. +absl::Status CreateZerosTensorListWithShape( + xla::XlaBuilder* b, const xla::Shape& list_shape, + const std::vector>& dynamic_dims, xla::XlaOp* list); + +// If the TensorList is initialized, check that its shape matches element shape; +// If the TensorList is uninitialized, initialize it with the element shape. +// Input can be initialized or uninitialized TensorList. +// "element" can be normal tensor or TensorList. +absl::Status GetInitializedTensorListForElement(xla::XlaOp list, + xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* initialized_list); + +// Executes TensorListPushBack with given TensorList and element. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* result); + +// Executes TensorListPopBack with given TensorList. +// Input must be an initialized TensorList. +// Non-nested and nested TensorLists are both supported. +absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, + xla::XlaOp* element_result, + bool* element_is_tensor_list); + +// Executes TensorListSetItem with given TensorList, index and element. +// Input must be an initialized TensorList. +// Only non-nested TensorList is supported. +absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp element, xla::XlaOp* result); + +// Executes TensorListGetItem with given TensorList and index. +// Input must be an initialized TensorList. +// Only non-nested TensorList is supported. +absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp* result); + +// Executes TensorListPushBack with given tensor and push index. +// "tensor" must be a normal tensor. 
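// Usage sketch (editorial addition, not part of the vendored header): a
// typical push flow inside a hypothetical XlaOpKernel::Compile(). Assumes the
// usual kernel includes (xla_op_kernel.h, tensor_list_utils.h) and that
// SetTensorListOutput is available on XlaOpKernelContext; the leading
// dimension and input index are illustrative.
void TensorListPushSketch(tensorflow::XlaOpKernelContext* ctx) {
  xla::XlaOp element = ctx->Input(0);
  // Start from an uninitialized list with a static leading dimension of 8.
  // leading_dim_size is only consulted when the leading size is dynamic
  // (assumption), so a default-constructed XlaOp is passed here.
  xla::XlaOp list = tensorflow::BuildUninitializedTensorList(
      ctx->builder(), /*leading_dimension=*/8,
      /*leading_size_is_dynamic=*/false, /*leading_dim_size=*/xla::XlaOp());
  // Give the list a concrete element shape, then push the element.
  xla::XlaOp initialized, pushed;
  OP_REQUIRES_OK(ctx, tensorflow::GetInitializedTensorListForElement(
                          list, element, /*element_is_tensor_list=*/false,
                          &initialized));
  OP_REQUIRES_OK(ctx, tensorflow::ExecuteTensorListPushBack(
                          initialized, element,
                          /*element_is_tensor_list=*/false, &pushed));
  ctx->SetTensorListOutput(0, pushed);
}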
+absl::Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, + xla::XlaOp* result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_TENSOR_LIST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/while_op.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/while_op.h new file mode 100644 index 00000000..8e9f317a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional iteration primitive. +// +// The inputs and outputs of the loop body must agree on the number, types, and +// shapes of the Tensors carried around the loop body. +// +// Computations in while loops may read from and write to resource variables. +// Resource variables may be passed as arguments to a function's body and +// condition functions. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_variables` +// we ensure that all variables that appear in the input also appear at the +// end of the body's output. This ensures the loop body's input and output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +// +// For example, suppose we have a loop body with arguments: +// DT_INT32, DT_RESOURCE (pointing to a DT_BOOL var), DT_FLOAT +// and return values +// DT_INT32, DT_FLOAT +// It is an error for the body to return DT_RESOURCE values. +// +// The body will be lowered into an XLA computation that takes and returns a +// tuple with XLA type (I32, F32, PRED). Note the resource variable appears at +// the end of both the loop body's input and output argument lists. +class XlaWhileOp : public XlaOpKernel { + public: + explicit XlaWhileOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + NameAttrList cond_name_attr_; + NameAttrList body_name_attr_; + bool has_token_input_output_; + std::vector token_input_nodes_; + string original_node_name_; + // Whether to propagate compile time consts into the loop body. + // This is not supported by default now since it may cause HBM memory + // overheads. 
+ bool propagate_compile_time_consts_ = false; + + XlaWhileOp(const XlaWhileOp&) = delete; + void operator=(const XlaWhileOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h new file mode 100644 index 00000000..3b75ca3b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -0,0 +1,130 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_XLA_CALL_MODULE_LOADER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_XLA_CALL_MODULE_LOADER_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "xla/hlo/builder/xla_computation.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/shape.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { + +bool IsTokenType(mlir::Type type); + +class XlaCallModuleLoader { + public: + static absl::StatusOr> Create( + mlir::MLIRContext* context, int version, mlir::StringRef module_str, + std::vector disabled_checks, + std::vector platforms, int num_invocation_args, + bool main_has_token_input_output); + + int NrInputs() { return main_.getNumArguments(); } + mlir::TypeRange InputTypes() { return main_.getArgumentTypes(); } + + int NrOutputs() { return main_.getNumResults(); } + mlir::TypeRange OutputTypes() { return main_.getResultTypes(); } + + // Sets the platform index argument, if the module is compiled for multiple + // platforms, and then erases the argument. + absl::Status SetPlatformIndex(absl::string_view compilation_platform); + + // Refines the dynamic module arguments based on the static argument shapes. + // This assumes that the module has a "main" function without dimension args, + // but possibly with dynamic shapes. We read the static shapes of the inputs, + // then set them as the types of the function parameters, and run StableHLO + // shape refinement to specialize all dynamic shapes in the StableHLO program + // to static shapes. + // Starting with version 9, the "main" function may accept token arguments. 
+ // + // If the module uses multi-platform lowering, and you called SetPlatformIndex + // then the refinement will also remove the dead platform code. + // + // This method accepts a list of `llvm::ArrayRef` instead of `mlir::Type`. + // This is to prevent callers from accidentally passing `mlir::Type` owned by + // a context that's different from the one passed to `Create`, which could + // cause lifetime issues. + // The input_shapes includes only the non-token and the non-platform-index + // arguments. + absl::Status RefineDynamicShapes(llvm::ArrayRef input_shapes); + + // Validates that the module only contains ops from valid dialects. + absl::Status ValidateDialect(); + + // Validates that the module represents a statically-shaped StableHLO program, + // otherwise all sorts of weirdness might happen in the HLO exporter which is + // much easier to detect here. + absl::Status ValidateStaticShapes(); + + // Lowers the StableHLO module to MHLO in place. + absl::Status LowerModuleToMhlo(); + + // Lowers the MHLO module to XlaComputation and returns it. + // + // REQUIRES: `LowerModuleToMhlo()` is called beforehand. + absl::StatusOr ToXlaComputation(); + + // Returns the deserialized stablehlo module. + mlir::ModuleOp module() & { return *module_; } + mlir::OwningOpRef module() && { return std::move(module_); } + + private: + XlaCallModuleLoader() = default; + + // Initializes the loader with the given serialized module string. + absl::Status LoadModule(mlir::MLIRContext* context, int version, + mlir::StringRef module_str, + std::vector disabled_checks, + std::vector platforms, + int num_invocation_args, + bool main_has_token_input_output); + + // Adds a wrapper for the "main" function to compute the platform index and + // the dimension arguments. + absl::Status AddMainWrapper(); + + mlir::MLIRContext* context_; + int version_; + mlir::OwningOpRef module_; + std::vector platforms_; + bool platform_index_arg_set_ = false; + // The disabled checks at loading time, including those from the + // disabled_checks attribute and the TF_XLA_FLAGS environment variable. + std::vector loading_disabled_checks_; + mlir::func::FuncOp main_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_XLA_CALL_MODULE_LOADER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/layout_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/layout_util.h new file mode 100644 index 00000000..dcb19561 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/layout_util.h @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA layout and shapes. 
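// Usage sketch (editorial addition, not part of the vendored header): the call
// sequence documented by XlaCallModuleLoader above, with placeholder module
// bytes, version, platform list and argument shapes. Assumes the
// disabled_checks/platforms vectors hold std::string, and that
// TF_ASSIGN_OR_RETURN / TF_RETURN_IF_ERROR from tsl/platform are available.
absl::StatusOr<xla::XlaComputation> LoadAndLowerSketch(
    mlir::MLIRContext* context, mlir::StringRef serialized_module,
    llvm::ArrayRef<xla::Shape> arg_shapes) {
  TF_ASSIGN_OR_RETURN(
      auto loader,
      tensorflow::XlaCallModuleLoader::Create(
          context, /*version=*/9, serialized_module,
          /*disabled_checks=*/{}, /*platforms=*/{"CPU"},
          /*num_invocation_args=*/static_cast<int>(arg_shapes.size()),
          /*main_has_token_input_output=*/false));
  // Only needed when the module was serialized for multiple platforms.
  TF_RETURN_IF_ERROR(loader->SetPlatformIndex("CPU"));
  TF_RETURN_IF_ERROR(loader->RefineDynamicShapes(arg_shapes));
  TF_RETURN_IF_ERROR(loader->ValidateDialect());
  TF_RETURN_IF_ERROR(loader->ValidateStaticShapes());
  TF_RETURN_IF_ERROR(loader->LowerModuleToMhlo());
  return loader->ToXlaComputation();
}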
+ +#ifndef TENSORFLOW_COMPILER_TF2XLA_LAYOUT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LAYOUT_UTIL_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +class XlaShapeLayoutHelpers { + public: + // The following defines the layout preference of an xla tensor. + // The return value of LayoutPreferenceFn can be used in + // XlaHelper::ShapeRepresentationFn. + typedef std::function)> + LayoutPreferenceFn; + + // A bundle of LayoutPreferenceFn and ShapeRepresentationFn. + struct ShapeDeterminationFns { + // Use no preference function, and identity shape representation function, + // as default value. + ShapeDeterminationFns(); + + ShapeDeterminationFns( + LayoutPreferenceFn layout_preference_fn, + XlaHelpers::ShapeRepresentationFn shape_representation_fn) + : layout_preference_fn(layout_preference_fn), + shape_representation_fn(shape_representation_fn) {} + + LayoutPreferenceFn layout_preference_fn; + XlaHelpers::ShapeRepresentationFn shape_representation_fn; + }; +}; + +// Return a LayoutPreferenceFn that always uses kNoPreference layout. +XlaShapeLayoutHelpers::LayoutPreferenceFn UseNoPreferenceLayoutFn(); + +// Rewrites the layout of xla_shape if there is tiled sharding. +absl::Status RewriteLayoutWithShardedShape( + const std::optional& sharding, bool use_fast_memory, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + xla::Shape* xla_shape); + +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. +absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + std::optional sharding, bool fast_mem); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/broadcast.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/broadcast.h new file mode 100644 index 00000000..48dec32a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Forwards to xla::BroadcastTo. +// TODO(cheshire): Call the underlying function directly. +absl::StatusOr BroadcastTo(xla::XlaOp input, + absl::Span output_dims); + +// Forwards to xla::BroadcastOpsToSame. +absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs); +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/data_format.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/data_format.h new file mode 100644 index 00000000..131f5491 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/data_format.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// Reformat from NCHW_VECT_C to NCHW. +// +// Prerequisites: the last dimension of the input must be of size 4. +absl::StatusOr NCHW_VECT_CToNCHW(xla::XlaOp input); + +// Reformat from NCHW to NCHW_VECT_C. +// +// Prerequisites: the vectorized dimension `C` must be a multiple of 4. +absl::StatusOr NCHWToNCHW_VECT_C(xla::XlaOp input); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/random.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/random.h new file mode 100644 index 00000000..3c03633d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/random.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
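// Usage sketch (editorial addition, not part of the vendored header):
// broadcasting two operands to a common shape before an elementwise op; the
// target dimensions in the commented alternative are illustrative.
#include "tensorflow/compiler/tf2xla/lib/broadcast.h"

absl::StatusOr<xla::XlaOp> AddWithBroadcastSketch(xla::XlaOp lhs,
                                                  xla::XlaOp rhs) {
  // In-place broadcast of both operands to their common shape.
  TF_RETURN_IF_ERROR(tensorflow::BroadcastOpsToSame(&lhs, &rhs));
  // Alternatively, broadcast a single operand to an explicit shape:
  //   TF_ASSIGN_OR_RETURN(xla::XlaOp b, tensorflow::BroadcastTo(lhs, {2, 3}));
  return xla::Add(lhs, rhs);
}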
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Builds an array of values sampled from a truncated normal distribution: +// +// uniform: an array of random numbers in uniform distribution (0, 1). +// mu: the mean of the normal distribution. +// sigma: the standard deviation of the normal distribution. +// a: the lower bound of the generated values. +// b: the upper bound of the generated values. +xla::XlaOp ParameterizedTruncatedNormal(xla::XlaOp uniform, xla::XlaOp mu, + xla::XlaOp sigma, xla::XlaOp a, + xla::XlaOp b); + +// A specialized version of ParameterizedTruncatedNormal, with mu=0, sigma=1, +// a=-2 and b=2. +xla::XlaOp TruncatedNormal(xla::XlaOp uniform); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/scatter.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/scatter.h new file mode 100644 index 00000000..90af6e63 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/scatter.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Builds an XLA computation that performs a scatter operation on `buffer`, +// returning an updated buffer. +// For each i0, i1, ..., sets +// buffer[indices[i0, i1, ...], ...] := updates[i0, i1, ...] +// +// If `indices_are_vectors` is false, then each index in indices is a scalar, +// and the shape of `indices` must be a prefix of the shape of updates. +// Otherwise, `indices_are_vectors`, then indices are multidimensional and the +// minor dimension of `indices` represents a vector of indices. +// +// If `updates` is a scalar, then it will be broadcasted into the expected shape +// of updates. +// +// If any part of the update region is out-of-bounds, the corresponding update +// is discarded. +// +// If a `combiner` is provided, updates are combined with the existing values in +// the buffer using the combiner function. Otherwise, the updates replace the +// existing values. The order of updates is implementation-defined. 
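// Usage sketch (editorial addition, not part of the vendored header): feeding
// uniform samples into TruncatedNormal; the builder, element type and sample
// count are illustrative.
xla::XlaOp TruncatedNormalSketch(xla::XlaBuilder* b) {
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128});
  // Uniform samples strictly above 0, since the helper expects values in the
  // open interval (0, 1).
  xla::XlaOp uniform = xla::RngUniform(xla::ConstantR0<float>(b, 1e-7f),
                                       xla::ConstantR0<float>(b, 1.0f), shape);
  // Standard truncated normal: mu = 0, sigma = 1, truncated to [-2, 2].
  return tensorflow::TruncatedNormal(uniform);
}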
+absl::StatusOr XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + bool indices_are_sorted, + const std::function& + combiner, + xla::XlaBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/util.h new file mode 100644 index 00000000..eaf52188 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/lib/util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ + +#include + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Returns a floating point scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. +xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + double value); + +// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros +// prepended until the array is length n_dims. +xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x, + absl::Span starts); + +// Returns a integer scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. +xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + int64_t value); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/literal_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/literal_util.h new file mode 100644 index 00000000..4463024e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/literal_util.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA Literals. 
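// Usage sketch (editorial addition, not part of the vendored header): a
// scatter-add over a buffer with scalar indices. The combiner signature
// (XlaOp, XlaOp, XlaBuilder*) -> XlaOp is assumed from the declaration above.
absl::StatusOr<xla::XlaOp> ScatterAddSketch(xla::XlaBuilder* builder,
                                            xla::XlaOp buffer,
                                            xla::XlaOp indices,
                                            xla::XlaOp updates) {
  // Combine each update with the existing buffer value by addition.
  auto add_combiner = [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
    return xla::Add(x, y);
  };
  return tensorflow::XlaScatter(buffer, updates, indices,
                                /*indices_are_vectors=*/false,
                                /*indices_are_sorted=*/false, add_combiner,
                                builder);
}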
+ +#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +absl::Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +absl::Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, + const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a Literal with the contents of 'host_tensor', backed by its own +// storage (i.e., not reusing 'host_tensor's buffers.) +absl::StatusOr HostTensorToLiteral(const Tensor& host_tensor); + +// Returns a MutableBorrowingLiteral that utilizes the same underlying buffer +// owned by 'host_tensor', but is mutable via the xla::Literal methods. +absl::Status HostTensorToMutableBorrowingLiteral( + Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); +// Similar as above, except the literal shape is explicitly provided and used +// instead of obtaining it from the 'host_tensor'. The provided literal shape +// 'xla_shape' must be compatible with the shape of 'host_tensor'. +absl::Status HostTensorToMutableBorrowingLiteral( + const xla::Shape& xla_shape, Tensor* host_tensor, + xla::MutableBorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +absl::Status HostTensorsToBorrowingLiteralTuple( + absl::Span host_tensors, xla::BorrowingLiteral* literal); + +// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of +// type . +// Fails if the literal's primitive type != +// DataTypeToPrimitiveType(target_type). Note that is not +// derivable from the type of , because multiple tensorflow types map +// to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in +// XLA). +absl::Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); + +// Copies the contents of 'literal' to a previously allocated tensor +// 'host_tensor'. The tensor and the literal must have the same number of +// elements and the same type. +absl::Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, + Tensor* host_tensor); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/mlir_bridge_pass.h new file mode 100644 index 00000000..eae5fb83 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
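// Usage sketch (editorial addition, not part of the vendored header):
// round-tripping a host Tensor through an owned XLA literal; the dtype and
// shape are illustrative.
absl::Status LiteralRoundTripSketch() {
  tensorflow::Tensor host(tensorflow::DT_FLOAT,
                          tensorflow::TensorShape({2, 2}));
  host.flat<float>().setZero();
  // Copying conversion: the returned literal owns its storage.
  TF_ASSIGN_OR_RETURN(xla::Literal literal,
                      tensorflow::HostTensorToLiteral(host));
  // Copy back into a freshly allocated tensor of a compatible TF dtype.
  tensorflow::Tensor back;
  return tensorflow::LiteralToHostTensor(literal, tensorflow::DT_FLOAT, &back);
}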
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ + +#include + +#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +// This pass uses MLIR to implement all the conversion steps to target XLA from +// a TensorFlow Function Graph. It is meant to expose a very limited set of +// functionalities during the bring-up of MLIR-based bridge. +class MlirBridgePass : public MlirOptimizationPass { + public: + llvm::StringRef name() const override { return "bridge"; } + + MlirOptimizationPassState GetPassState( + const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library) const override; + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime. + absl::Status Run(const std::string& function_name, + const ConfigProto& config_proto, mlir::ModuleOp module, + const Graph& graph, + const FunctionLibraryDefinition& function_library) override; +}; + +// This pass uses MLIR to implement all the conversion steps to target XLA from +// a TensorFlow V1 Graph. It is meant to expose a very limited set of +// functionalities during the bring-up of MLIR-based bridge. +class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { + public: + llvm::StringRef name() const override { return "bridge"; } + + MlirOptimizationPassState GetPassState( + const DeviceSet* device_set, const ConfigProto& config_proto, + const Graph& graph, + const FunctionLibraryDefinition& function_library) const override; + + // This should be used as a thin mapper around mlir::ModulePass::runOnModule + // API integrated with the Tensorflow runtime. + absl::Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_BRIDGE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h new file mode 100644 index 00000000..6053f5d6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_ + +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// An XlaOpKernel that's implemented by lowering using MLIR TensorFlow to HLO +// legalization. +class MlirXlaOpKernel : public XlaOpKernel { + public: + explicit MlirXlaOpKernel(OpKernelConstruction* ctx); + + private: + absl::Status ContextToXlaArgs(XlaOpKernelContext* ctx, + std::vector& xla_args); + void Compile(XlaOpKernelContext* ctx) override; + absl::Status ConstructXlaOp(XlaOpKernelContext* ctx); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_MLIR_XLA_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/rearrange_function_argument.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/rearrange_function_argument.h new file mode 100644 index 00000000..1a290017 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/rearrange_function_argument.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// For the given graph `g`: +// 1. Rewrite If/While node functions to rearrange arguments and return values, +// so that all resource arguments/return values are placed in the end (as +// required by XlaCompiler), +// 2. Inline StatefulPartitionedCall nodes so we do not need to rearrange +// arguments and return values. +// `get_function_body_fn` is used to instantiate FunctionDef. +// `fld` is used to store rewritten functions. +// `global_fld` is used to potentially supply stack traces for functions when +// they are not found in `fld`. 
+absl::Status RearrangeFunctionArguments( + std::function + get_function_body_fn, + Graph* g, FunctionLibraryDefinition* fld, + const FunctionLibraryDefinition* global_fld = nullptr); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_REARRANGE_FUNCTION_ARGUMENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/resource_operation_table.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/resource_operation_table.h new file mode 100644 index 00000000..61c7a56f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" + +// Exposes information about the resource operations supported by tf2xla in a +// structured form. + +namespace tensorflow { +enum class XlaResourceOpKind { + kRead, // Only reads from resources. + kWrite, // Only writes to resources. + kReadWrite // Reads from and writes to resources. +}; + +enum class XlaResourceKind { + kVariable, // Operates on resource variables. + kStack, // Operates on stacks. + kTensorArray // Operates on tensor arrays. +}; + +class XlaResourceOpInfo { + public: + explicit XlaResourceOpInfo(XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) + : op_kind_(op_kind), resource_kind_(resource_kind) {} + + XlaResourceOpKind kind() const { return op_kind_; } + XlaResourceKind resource_kind() const { return resource_kind_; } + + static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind); + + private: + XlaResourceOpKind op_kind_; + XlaResourceKind resource_kind_; +}; + +// Returns a XlaResourceOpInfo describing `op` if it is a resource operation +// supported by tf2xla, otherwise returns null (i.e. if this returns null then +// `op` is either not a resource operation or is unsupported by XLA). +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op); + +namespace resource_op_table_internal { +// NB! Implementation detail exposed for unit testing, do not use. +// +// Returns the set of resource operations known by this module. +std::vector GetKnownResourceOps(); +} // namespace resource_op_table_internal + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/resource_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/resource_util.h new file mode 100644 index 00000000..e4bdb511 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/resource_util.h @@ -0,0 +1,96 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
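// Usage sketch (editorial addition, not part of the vendored header): querying
// the resource operation table for a well-known resource op; the op name is
// illustrative.
void ResourceOpQuerySketch() {
  const tensorflow::XlaResourceOpInfo* info =
      tensorflow::GetResourceOpInfoForOp("AssignVariableOp");
  if (info != nullptr &&
      info->kind() == tensorflow::XlaResourceOpKind::kWrite &&
      info->resource_kind() == tensorflow::XlaResourceKind::kVariable) {
    // AssignVariableOp writes to a resource variable.
  }
}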
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +class ResourceUsageAnalysis { + public: + // NodeInfo is a triple of function_name:node_name:op to uniquely identity a + // node in graph. ResourceUsageAnalysis uses it to represent resource sources + // and users. + class NodeInfo { + public: + std::optional function_name_; + std::string node_name_; + std::string op_; + + NodeInfo() {} + + NodeInfo(const std::optional& function_name, + std::string node_name, std::string op) + : function_name_(function_name), + node_name_(std::move(node_name)), + op_(std::move(op)) {} + + std::string DebugString() const { + return absl::StrJoin({function_name_.value_or(""), node_name_, op_}, ":"); + } + + bool operator==(const NodeInfo& o) const { + return function_name_ == o.function_name_ && node_name_ == o.node_name_ && + op_ == o.op_; + } + + template + friend H AbslHashValue(H h, const NodeInfo& o) { + return H::combine(std::move(h), o.function_name_, o.node_name_, o.op_); + } + }; + + // This method analyzes a Tensorflow graph and finds all operations that + // create Stack/TensorArray resources and all the operations that consume + // resource created by them. + // + // Note that _Arg nodes that introduce resources are not considered sources. + // Note again that Control Flow v1 nodes + // (Enter/Exit/Switch/Merge/NextIteration) are not supported. Graphs contain + // these nodes cause analysis failures. However Control Flow v2 nodes + // (While/If) will be supported. + // + // TODO(b/135628319): Support analyzing functional while/if as pass-through + // ops. + // + // For example, consider following subgraph: + // + // TensorArrayOp -> Identity -> TensorArrayWriteOp + // + // It should be able to tell that TensorArrayWriteOp actually operates on the + // resource created by TensorArrayOp even though there might be + // non-resource-specific operations like Identity (or other pass-through + // operations). + // + // source_to_path maps the nodes that creates resources to all nodes that + // operate on the corresponding resource, not including sources themselves. It + // is cleared upon calling this method. 
+ static absl::Status Analyze( + const Graph* graph, FunctionLibraryRuntime* lib_runtime, + absl::flat_hash_map>* + source_to_path); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/shape_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/shape_util.h new file mode 100644 index 00000000..018ab191 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/shape_util.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA shapes. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ + +#include + +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Convert an XLA Shape into the equivalent TensorFlow shape. May fail since +// not all XLA shapes can be represented as TensorShapes. +absl::Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape); + +// Convert a TensorShape into the equivalent XLA Shape proto. Unlike Tensorflow, +// XLA shapes include the type. Not all `dtype` values can be represented by +// XLA, so this conversion may fail. +absl::Status TensorShapeToXLAShape(DataType dtype, + const TensorShape& tensor_shape, + xla::Shape* shape); + +absl::StatusOr TensorShapeToXLAShape( + DataType dtype, const TensorShape& tensor_shape); + +// Converts a TensorShape into the equivalent XLA Shape proto, taking an +// xla::PrimitiveType to specify the element type. This never fails. +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape); + +// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape +// with unknown rank is represented by an r1 with empty dimension. +absl::Status TensorShapeToXLAShape(DataType dtype, + const PartialTensorShape& tensor_shape, + xla::Shape* shape); + +// Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape +// with unknown rank is represented by an r1 with empty dimension. +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const PartialTensorShape& tensor_shape); + +absl::Status TensorShapeToBoundedXLAShape( + DataType dtype, const PartialTensorShape& tensor_shape, + const TensorShape& bound, xla::Shape* shape); + +// Given an XLA shape with layouts, builds a layout vector in the form able to +// be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/.... +// THe returned vector is a linearized sequence of the minor-to-major values of +// the layouts held within the input shape. +// In case the input shape is a tuple, the minor-to-major values will be in the +// order of the tuple elements within the tuple shape. 
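// Usage sketch (editorial addition, not part of the vendored header):
// converting a TensorFlow shape to an XLA shape and back; the dtype and
// dimensions are illustrative.
absl::Status ShapeRoundTripSketch() {
  tensorflow::TensorShape tf_shape({2, 3, 4});
  TF_ASSIGN_OR_RETURN(
      xla::Shape xla_shape,
      tensorflow::TensorShapeToXLAShape(tensorflow::DT_FLOAT, tf_shape));
  tensorflow::TensorShape back;
  return tensorflow::XLAShapeToTensorShape(xla_shape, &back);
}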
+// If a shape (or a subshape of a tuple shape) has missing layout, a rank long +// sequence of -1 values will be emitted. +absl::StatusOr> GetShapeLayoutVector(const xla::Shape& shape); + +// Given the input shape and a linearized sequence of the minor-to-major values +// of the layouts, create the output shape by rewriting the input shape layouts. +// If a layout is missing (has -1 values) for a matching tuple subshape, the +// layout_func will be called, if not nullptr. +absl::Status GetShapeWithLayout( + const xla::Shape& input_shape, absl::Span minor_to_major, + const std::function& layout_func, + xla::Shape* output_shape); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/sharding_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/sharding_util.h new file mode 100644 index 00000000..473ad1dd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/sharding_util.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ + +#include + +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/status_macros.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Parses the op sharding from the 'replicated core' device_name . +// Returns an error: +// - if the device name is invalid. +// - the core is parsed and is out of the range [0, num_cores_per_replica). +// +// Otherwise, returns either: +// - explicit_sharding if explicit_sharding.has_value() +// - a non-value if there is no assigned core or +// - a sharding set as per xla::sharding_builder::AssignDevice. +absl::StatusOr> ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + std::optional explicit_sharding = std::nullopt, + std::optional metadata = std::nullopt); + +absl::StatusOr> ParseShardingFromDevice( + const Node& node, int num_cores_per_replica, bool add_metadata); + +absl::StatusOr> ParseShardingFromDevice( + const NodeDef& node_def, int num_cores_per_replica, bool add_metadata); + +absl::StatusOr> ParseShardingFromEdgeSource( + const Edge& edge, int num_cores_per_replica, bool add_metadata); + +void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); + +// Get sharding inforamtion from node. 
+absl::StatusOr> GetShardingFromNodeDef( + const NodeDef& node_def, bool add_metadata); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SHARDING_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/side_effect_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/side_effect_util.h new file mode 100644 index 00000000..34f30eb7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/side_effect_util.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ + +#include + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Side-effecting nodes will have this attribute set. Its value is the list of +// node names which this node has side-effect dependencies on. +// +// Nodes like HostCompute, SendToHost, RecvFromHost always have this attribute, +// because they always have side-effect. +// If and While nodes may or may not have this attribute, depending on whether +// their bodies have side-effecting nodes. +extern const char kXlaTokenInputNodesAttrName[]; + +// This node name is used in kXlaTokenInputNodesAttrName attr to signal that a +// node has side-effect dependency on current graph's token input. +extern const char kXlaTokenArgNodeName[]; + +// This node have XlaRecvAtHost/XlaSendFromHost in its associated functions. +extern const char kXlaHasHostTransferAttrName[]; + +// This attribute is the replica id for an outside compilation node node. +extern const char kXlaReplicaIdAttrName[]; + +// This node is a Placeholder node added for tail outside compilation. +extern const char kXlaIsPlaceholderForTailOcAttrName[]; + +// This attribute is the original node name for this node. +extern const char kXlaOriginalOutsideCompilationNodeName[]; + +// Sets device ordinal attribute for nodes with attribute +// `kXlaHasHostTransferAttrName`. +absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); + +// Calculates side-effect dependencies for the graph's token output. +// Returns a set of node names representing these dependencies. +std::set CalculateTokenInputsForOutputToken(const Graph& g); + +// Returns whether a graph contains side-effecting nodes. +bool HasSideEffectingNodes(const Graph& g); + +// Parse the mapping from outside_compilation_subgraph name to core number, +// which is specified in an attr as a list of strings +// :. 
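// Usage sketch (editorial addition, not part of the vendored header): parsing
// a sharding assignment from a replicated-core device string, assuming the
// returned optional holds an xla::OpSharding; the device name format and core
// count are illustrative.
absl::Status ShardingFromDeviceSketch() {
  TF_ASSIGN_OR_RETURN(
      std::optional<xla::OpSharding> sharding,
      tensorflow::ParseShardingFromDevice(
          /*device_name=*/"/device:TPU_REPLICATED_CORE:0",
          /*num_cores_per_replica=*/8));
  if (sharding.has_value()) {
    // The node is assigned to replicated core 0.
  }
  return absl::OkStatus();
}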
+absl::Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_SIDE_EFFECT_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/test_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/test_util.h new file mode 100644 index 00000000..2b2eb4f5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/test_util.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for tests. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { + +// Same as InstantiationResult, but has a GraphDef instead of just nodes. +struct InstantiationResultForTest { + DataTypeVector arg_types; + DataTypeVector ret_types; + GraphDef gdef; +}; + +// Instantiates a function, producing a GraphDef to compare against the +// expected graph. +absl::Status InstantiateFunctionForTest( + const string& name, const FunctionLibraryDefinition& library, + InstantiationResultForTest* result); + +} // namespace tensorflow + +// Variant of TF_EXPECT_GRAPH_EQ that also compares internal attributes for +// equality. +#define TF_EXPECT_GRAPH_EQ_INTERNAL(expected, actual) \ + do { \ + string diff; \ + EqualGraphDefOptions eq_options; \ + eq_options.ignore_internal_attrs = false; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff, eq_options)) \ + << diff << "\nActual: " << SummarizeGraphDef(actual); \ + } while (false) + +#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla.h new file mode 100644 index 00000000..095ad49a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "xla/client/client.h" +#include "xla/hlo/builder/xla_computation.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Converts a tensorflow::GraphDef into an xla::XlaComputation. The given +// `config` specifies the portion of the graph to convert, via feeds and +// fetches. Each feed is a positional input argument for the generated +// computation, while each fetch is a positional output argument. +// +// The computation is built in the context of the given `client`, which may +// subsequently be used to compile or execute the computation. +absl::Status ConvertGraphDefToXla(GraphDef graph_def, + const tf2xla::Config& config, + xla::Client* client, + xla::XlaComputation* computation); + +// Similar to ConvertGraphDefToXla, but uses MLIR and handle debug information. +// +// debug_info_filename: the file for the debug information proto. +// debug_info_path_begin_marker: if not empty, file pathes in the debug +// information are trimmed from the beginning to the first appearance of the +// marker. +absl::Status ConvertGraphDefToXlaViaMlir( + GraphDef graph_def, const tf2xla::Config& config, + xla::XlaComputation* computation, absl::string_view debug_info_filename, + absl::string_view debug_info_path_begin_marker); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_defs.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_defs.h new file mode 100644 index 00000000..2f81d2dd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_defs.h @@ -0,0 +1,65 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_DEFS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_DEFS_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace tensorflow { + +// Marks a node for XLA compilation. The attribute value indicates the +// compilation device type. +inline constexpr absl::string_view kCompileDeviceTypeAttr = + "_xla_compile_device_type"; +// Marks a node for XLA compilation. +inline constexpr absl::string_view kMustCompileAttr = "_XlaMustCompile"; +// Marks a node for replication. The attribute value indicates the replication +// metadata op. +inline constexpr absl::string_view kReplicationInfoAttr = "_replication_info"; +// Marks a node for XLA-TPU compilation. The attribute value indicates the +// associated compilation cluster and replication metadata op. 
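// Usage sketch (editorial addition, not part of the vendored header): building
// a minimal feed/fetch config and converting a GraphDef; the node names are
// illustrative and the client is assumed to come from the XLA client library.
absl::Status ConvertGraphDefSketch(tensorflow::GraphDef graph_def,
                                   xla::Client* client) {
  tensorflow::tf2xla::Config config;
  config.add_feed()->mutable_id()->set_node_name("input");
  config.add_fetch()->mutable_id()->set_node_name("output");
  xla::XlaComputation computation;
  return tensorflow::ConvertGraphDefToXla(std::move(graph_def), config, client,
                                          &computation);
}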
+inline constexpr absl::string_view kTpuReplicateAttr = "_tpu_replicate";
+// Marks a node inside of an XLA compilation cluster to be placed outside of the
+// cluster.
+inline constexpr absl::string_view kXlaOutsideCompilationAttr =
+    "_xla_outside_compilation";
+// Frontend attributes ID.
+inline constexpr absl::string_view kXlaFrontendAttributesAttrName =
+    "_XlaFrontendAttributes";
+// Device types.
+inline constexpr absl::string_view kDeviceAttr = "device";
+inline constexpr absl::string_view kCpuDevice = "CPU";
+inline constexpr absl::string_view kGpuDevice = "GPU";
+inline constexpr absl::string_view kTpuDevice = "TPU";
+inline constexpr absl::string_view kEmptyDevice = "";
+// Device type may be empty in ops such as TF.PartitionedCall.
+inline constexpr std::array kValidDeviceTypes = {
+    kCpuDevice, kGpuDevice, kTpuDevice, kEmptyDevice};
+// Attributes that need to be propagated during rewrites (e.g., in
+// functionalization).
+inline constexpr std::array kAttrsToPropagate = {
+    kCompileDeviceTypeAttr,
+    kReplicationInfoAttr,
+    kXlaFrontendAttributesAttrName,
+    kXlaOutsideCompilationAttr,
+    kTpuReplicateAttr,
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_DEFS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_opset.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_opset.h
new file mode 100644
index 00000000..37fa8f39
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_opset.h
@@ -0,0 +1,30 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_OPSET_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_OPSET_H_
+
+#include
+#include
+
+#include "absl/status/statusor.h"
+
+namespace tensorflow {
+
+absl::StatusOr> GetRegisteredXlaOpsForDevice(
+    absl::string_view device_name);
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_OPSET_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h
new file mode 100644
index 00000000..1b45fb4c
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
+
+namespace tensorflow {
+namespace tf2xla {
+
+// The implementation of a main function for a binary that prints a table of
+// supported tf2xla operators for a given device, along with their type
+// constraints, to stdout.
+//
+// Pass the argc and argv from main, unmodified. Use regen_run to specify the
+// command used to regenerate the table.
+void SupportedOpsMain(int argc, char** argv, const char* regen_run);
+
+}  // namespace tf2xla
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_util.h
new file mode 100644
index 00000000..f2ce3944
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -0,0 +1,226 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
+
+#include
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
+#include "xla/status_macros.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// ValidateConfig returns OK iff config is valid.
+absl::Status ValidateConfig(const tf2xla::Config& config);
+
+// Modifies `graph_def` to include placeholders for each fed tensor, and
+// updates references to the fed tensors to refer to the placeholders.
+// The existing nodes referenced by the feeds are not removed or modified
+// (except where their input edges are modified by the replacement of other
+// feeds).
+absl::Status AddPlaceholdersForFeeds(
+    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
+    std::unordered_map* feed_remapping, GraphDef* graph_def);
+
+// Returns in `out` a copy of `in`, pruned to only include fetches from
+// `config`.
+absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
+                               GraphDef* out);
+
+// Returns node:port for the given `id`.
+string TensorIdToString(const tf2xla::TensorId& id);
+
+// Updates the sharding of `n` based on the sharding of its neighbors.
+// If `out_edges` is true, outgoing edges from `n` are considered; else incoming
+// edges are considered.
+absl::Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);
+
+// Adds an allowed data type to the AttrConstraint with the given name.
+void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, + KernelDef* kdef); + +// Returns the next random seed to use for seeding xla rng. +uint32 GetXLARandomSeed(); + +// Indicates how a FunctionDef is associated with a graph node (e.g. the node is +// a function call, or the node has function attrs). +class AssociatedFunctionInfo { + public: + enum AssociatedFunctionType { + kFunctionAttr = 0, + kFunctionCallNode = 1, + kSymbolicGradient = 2, + }; + + // The function is an attr of the node. + static AssociatedFunctionInfo FunctionAttr(const string& func_name, + const AttrValueMap& attrs, + const string& attr_name) { + return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name); + } + + // The node is a function call. + static AssociatedFunctionInfo FunctionCall(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs, + /*attr_name=*/""); + } + + // The node is a SymbolicGradient op. + static AssociatedFunctionInfo SymbolicGradient(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs, + /*attr_name=*/""); + } + + AssociatedFunctionType type() const { return type_; } + + const string& func_name() const { return func_name_; } + + const string& attr_name() const { return attr_name_; } + + const AttrValueMap& attrs() const { return attrs_; } + + private: + AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name, + const AttrValueMap& attrs, const string& attr_name) + : type_(type), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + + // Available for all instances. + AssociatedFunctionType type_; + string func_name_; + AttrValueMap attrs_; + + // Only available if the function is defined in an attr. + string attr_name_; +}; + +// Returns if the NodeDef has associated function. +bool HasAssociatedFunction(const NodeDef& node_def, + const FunctionLibraryDefinition* fld); + +// Gets functions associated with the node. Current cases: +// 1. For function call node, its function name; +// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient", +// and returned attrs will be this node's attributes; +// 3. For nodes like XlaWhile/XlaIf, all their function attributes. +std::vector GetAssociatedFunctions( + const Node& node, const FunctionLibraryDefinition* fld); + +// Changes associated functions for the node. Current cases: +// 1. For function call node, creates a new node with the new function name and +// remove the old node; +// 2. For SymbolicGradient op, add or replace GradientDef in +// FunctionLibraryDefinition; +// 3. For nodes like XlaWhile/XlaIf, modify their function attributes. +absl::Status RewriteAssociatedFunction( + Graph* graph, Node* node, FunctionLibraryDefinition* fld, + const AssociatedFunctionInfo& associated_function, + const string& rewritten_function_name); + +// Class to act as cache for FunctionLibraryRuntime::Handle objects. +class CachedFunctionHandles { + public: + CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {} + + // Populates `handle` for requested function and attributes. If we have + // instantiated the function with the same attributes before, `handle` will be + // cached handle; otherwise instantiate the function and populate `handle`. 
+ absl::Status GetOrInstantiate(const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle); + + // Releases all handles in the cache. Returns first non-OK status if any; + // returns OK otherwise. + absl::Status ReleaseAllHandles(); + + ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); } + + private: + FunctionLibraryRuntime* flr_; + std::map handles_; + + CachedFunctionHandles(const CachedFunctionHandles&) = delete; + void operator=(const CachedFunctionHandles&) = delete; +}; + +// Struct for node's output edge info. +struct OutEdgeInfo { + Node* dst; + int src_output, dst_input; +}; + +// Replaces node `n` with a new node whose NodeDef is `node_def`. +absl::StatusOr ReplaceNode(Graph* g, Node* n, const NodeDef& node_def); + +// Helper function that builds an Identity node. +absl::StatusOr BuildIdentityNode(Graph* graph, const string& node_name, + DataType dtype, const Node* input, + std::optional requested_device); + +// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrite +// body functions to use the Const nodes instead of original _Arg nodes. +// +// For example, say we have the following computation: +// shape = constant_op.constant([1]) +// return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape)) +// If we do not rewrite then/else function, they will use _Arg node as shape +// input for tf.ones/tf.zeros. But XLA requires that shape input to be compile +// time constant, so XLA compilation will fail. This rewriting process will +// change the shape input to Const node. +absl::Status PropagateConstIntoFunctionalNodes( + Graph* g, const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld); + +// Prunes unreachable FunctionDefs from FunctionLibraryDefinition. +absl::Status PruneUnreachableFunctionsFromGraph(const Graph& g, + FunctionLibraryDefinition* fld); + +// Finds the following pattern in the graph: +// 1) EmptyTensorList -> forward While op -> backward While op, +// 2) in forward While op, a Const node is pushed, +// 3) in backward While op, data is popped from the tensor list. +// And rewrites backward While op to use Const node instead of TensorListPopBack +// result. +// TODO(b/128633174) remove the TensorList and related TensorList ops. +absl::Status RewriteTensorListWithConstElement(Graph* g, + FunctionLibraryDefinition* fld); + +inline bool IsConstTraversableOpType(const Node* node) { + return node->type_string() == "Identity" || + node->type_string() == "IdentityN" || node->IsWhileNode(); +} + +// Determines whether a loop body is invariant for the given argument index. +absl::StatusOr IsLoopInvariant( + const FunctionBody* loop_body, int index, + const FunctionLibraryDefinition* lookup_fld); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/type_util.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/type_util.h new file mode 100644 index 00000000..a3027a5f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/type_util.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ + +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Converts a Tensorflow DataType to an XLA PrimitiveType. +absl::Status DataTypeToPrimitiveType(DataType data_type, + xla::PrimitiveType* type); + +// Converts an XLA PrimitiveType to a TensorFlow DataType. +// Caution: The mapping from TF types to XLA types is not one-to-one: for +// example, both DT_INT8 and DT_QINT8 map to xla::S8. So the inverse is not a +// uniquely defined function. This is fine if you want a way to encode an XLA +// object as a TensorFlow object (e.g., in XRT); whereas if you started with a +// TensorFlow object in the first place, you most likely should preserve the +// original TensorFlow type, rather than trying to convert an XLA type back into +// a TensorFlow type. +absl::StatusOr EncodePrimitiveTypeAsDataType(xla::PrimitiveType type); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_argument.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_argument.h new file mode 100644 index 00000000..9e2eccd2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_argument.h @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Describes how to derive the value of each _Arg node in the graph/function +// being compiled. There must be one Argument for each _Arg index. +struct XlaArgument { + enum Kind { + // Default value; not a valid kind. + kInvalid, + + // Argument is a compile-time constant. No associated runtime parameter. + kConstant, + + // Argument is a Variable, TensorArray, or Stack resource. Has an + // associated runtime parameter iff `initialized` is true. 
+ kResource, + + // A resource variable with a constant value known at compile time. + kConstantResource, + + // Argument is a run-time parameter. + kParameter, + + // Argument is an XLA token. + kToken, + + // Argument is a TensorList. + kTensorList, + }; + + Kind kind = kInvalid; + + // The type of the argument. If the argument is a resource, this + // is the type of the variable's value, not DT_RESOURCE. + DataType type = DT_INVALID; + + // The shape of the argument. For: + // * a parameter: the shape of the parameter. We allow setting the xla shape + // if known. This helps avoid conversions to and from TensorShape. + // * a constant: ignored; the shape given by constant_value is used + // instead. + // * an uninitialized resource: ignored. We don't yet know the shape of an + // uninitialized resource (otherwise we would have initialized it!) + // * an initialized variable: the shape of the variable's value. + // * an initialized TensorArray or Stack resource: the shape of an entry in + // the TensorArray/Stack. Note this is the size of a single entry, not the + // XLA data structure that represents the complete stack/array. + absl::variant shape; + + // The value of the argument, if it is a compile-time constant. Must be a + // host-memory tensor. + Tensor constant_value; + + // The upper bounds of the value. + std::optional value_bound; + + // Indicates whether each value is dynamic or constant. + std::optional value_dynamism; + + // The name of this argument, used for debugging. + string name; + + // The name of TensorFlow _Arg node, used for debugging. + string node_name; + + // For a kResource, what kind of resource is it? + XlaResource::Kind resource_kind = XlaResource::kInvalid; + + // For a kResource, has this resource been initialized? + bool initialized = false; + + // For a kResource, is this resource on Fast Memory. + bool fast_mem = false; + + // For a TensorArray or Stack resource, what is the array's declared size? + // (Used for lazy initialization.) + int64_t max_array_size = -1; + + // TensorArray resource parameters are passed as (array, gradient array 0, + // ..., gradient array k), where the gradient arrays are in the same order + // as `tensor_array_gradients`. + std::set tensor_array_gradients; + + // Whether this argument will receive the same data across all replicas. + bool is_same_data_across_replicas = false; + + bool operator==(const XlaArgument& other) const; + + // Returns a human-readable summary of the argument. + string HumanString() const; + + // Returns the dimension sizes for either TensorShape or xla::Shape. + std::vector DimensionSizes() const; + absl::InlinedVector DimensionSizesAsInlinedVector() const; + + // Returns the human-readable string for either TensorShape or xla::Shape. + string ShapeHumanString() const; + + // Whether to broadcast this parameter to all replicas before use. + // When true, xla_compiler should input/output alias this arg to prevent + // unnecessary HBM usage. + bool requires_broadcast = false; + std::optional definition_stack_trace; +}; + +// Returns true if any of `args` is an uninitialized resource variable. 
+bool AnyUninitializedResourceArg(absl::Span args); + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compilation_device.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compilation_device.h new file mode 100644 index 00000000..e3f6571c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ + +#include + +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// Class is defined in xla_compilation_device.cc, reference +// included here only so the XlaCompilationDevice allocator_ member can be +// declared. +class XlaCompilationAllocator; + +// This is a 'dummy' TensorFlow device that is only used to execute a +// subgraph of XLA compilation Ops to construct a compiled version +// of the subgraph's computation. It has a 'dummy' allocator that +// backs each Tensor with an XlaExpression. The shape of the Tensor +// matches the shape of XlaExpression. +// +// We deliberately don't register a device factory because we *never* +// want placement to put Ops on a compilation device. The device is created +// manually, not using a factory. +// +// XLA compilation is not thread-safe. OpKernels registered on the +// XlaCompilationDevice must not use threads or concurrency. +class XlaCompilationDevice : public LocalDevice { + public: + XlaCompilationDevice(const SessionOptions& options, DeviceType type); + + ~XlaCompilationDevice() override; + + Allocator* GetAllocator(AllocatorAttributes attr) override; + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + + absl::Status Sync() override; + + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + private: + std::unique_ptr allocator_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h new file mode 100644 index 00000000..db280e23 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -0,0 +1,470 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ + +#include +#include +#include + +#include "xla/cpu_function_runtime.h" +#include "xla/executable_run_options.h" +#include "xla/service/cpu/buffer_desc.h" +#include "xla/service/custom_call_status_internal.h" +#include "tensorflow/core/platform/types.h" + +// Forward-declare, rather than include, to reduce code size for users that +// never use this functionality. +namespace xla { +class ProgramShapeProto; +class HloProfilePrinterData; + +namespace cpu { +class CpuExecutable; +} // namespace cpu +} // namespace xla + +namespace tensorflow { + +// Represents a function compiled by XLA, produced via either JIT or AOT. +// +// The Run method invokes the actual computation, with inputs read from arg +// buffers, and outputs written to result buffers. Each Run call may also use a +// set of temporary buffers for the computation. +// +// By default each instance of this class manages its own arg, result and temp +// buffers. The AllocMode constructor parameter may be used to modify the buffer +// allocation strategy. +// +// Under the default allocation strategy, this class is thread-compatible: +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. +class XlaCompiledCpuFunction { + public: + // Type of the raw XLA Classic function, produced by either JIT or AOT. + using RawFunction = void (*)(void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps, + XlaCustomCallStatus*, int64_t* profile_counters); + + // Simple struct to describe a tensor's shape. + // Note: this is a poor man's substitute for xla::ShapeProto, but we cannot + // depend on protobuf's in this library. + // TODO(ecg): extend ShapeInfo to support tuples, if needed. + struct ShapeInfo { + const int32_t* dimensions = nullptr; + int32_t num_dimensions = 0; + }; + + // StaticData represents the state necessary to run an XLA-compiled + // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for + // AOT this is backed by data compiled into the object file. + // + // The contents of StaticData are XLA-internal implementation details and + // should not be relied on by clients (and therefore are private). + class StaticData { + private: + // The raw function to call. + RawFunction raw_function_; + + // Contains information about the buffers used by the XLA computation. + const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; + int32_t num_buffers_ = 0; + + // Result parameter i is described by + // buffer_infos[result_index_table[i]]. + const int32* result_index_table_ = nullptr; + + // There are num_results result parameters. 
+ int64_t num_results_ = 0; + + // Entry parameter i is described by + // buffer_infos[arg_index_table[i]]. + const int32* arg_index_table_ = nullptr; + + // There are num_args entry parameters. + int64_t num_args_ = 0; + + // There are num_variables variables. + int64_t num_variables_ = 0; + + // The 0-based index of the result tuple, in the temp buffers. + size_t result_index_ = 0; + + const ShapeInfo* arg_shape_infos_ = nullptr; + const ShapeInfo* result_shape_infos_ = nullptr; + + // [Optional] Arrays of arg and result names. These are arrays of C-style + // strings, where the array is terminated by nullptr. + const char** arg_names_ = nullptr; + const char** variable_names_ = nullptr; + const char** result_names_ = nullptr; + + // [Optional] Arg and result shapes. + const xla::ProgramShapeProto* program_shape_ = nullptr; + + // [Optional] Profile printer data. Null if profiling is disabled. + const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + + // [Optional] The number of profile counters expected in the profile counter + // buffer by the generated code and hlo_profile_printer. 0 if profiling is + // disabled. This information is already present in + // hlo_profile_printer_data but xla::HloProfilePrinterData is forward + // declared so we don't have access to that information here. + int64_t profile_counters_size_ = 0; + + // Only XlaCompiledCpuFunction is allowed to read and write the above + // fields. + friend class XlaCompiledCpuFunction; + }; + + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + // Allocate all buffers - args, results, profile and temps. + ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS, + + // Only allocate result, profile and temp buffers. + // Use set_arg_data to set argument buffers before Run is called. + RESULTS_PROFILES_AND_TEMPS_ONLY, + }; + + explicit XlaCompiledCpuFunction( + const StaticData& static_data, + AllocMode alloc_mode = + AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS); + virtual ~XlaCompiledCpuFunction(); + + XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; + XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; + XlaCompiledCpuFunction(XlaCompiledCpuFunction&&) = default; + XlaCompiledCpuFunction& operator=(XlaCompiledCpuFunction&&) = default; + + // Sets the intra-op thread pool used to run individual ops concurrently. + void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { + run_options_.set_intra_op_thread_pool(pool); + } + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run(); + + // Returns the error message from the previous failed Run call. + // + // TODO(fschneider): For now this always returns an empty string because there + // is no support for error reporting in XLA. Remove this once all callers are + // updated. + string error_msg() const { return {}; } + + // ------------------------------ + // Arg methods for managing input buffers. Buffers are in row-major order. + + // Returns the buffer for the positional argument at the given `index`. + void* arg_data(size_t index) { + return buffer_table_[arg_index_table_[index]]; + } + const void* arg_data(size_t index) const { + return buffer_table_[arg_index_table_[index]]; + } + + int num_results() const { return num_results_; } + + int num_args() const { return num_args_; } + + int num_variables() const { return num_variables_; } + + // Returns the size of entry parameter `idx`. 
+ // + // There is a static version of this method on tfcompile generated subclasses + // of XlaCompiledCpuFunction, but try to prefer this when possible since it + // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses. + int arg_size(int idx) const { + assert(idx < num_args()); + return buffer_infos_[arg_index_table_[idx]].size(); + } + + // Sets the buffer for the positional argument at the given `index` to `data`. + // Must be called before Run to have an effect. May be called under any + // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be + // called for each positional argument, in order to set the argument buffers. + // + // Allocated memory must be aligned to the size specified by + // xla::cpu_function_runtime::MinAlign(). If possible, use the functions in + // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct + // alignment. + // + // Aliasing of argument and result buffers is not allowed, and results in + // undefined behavior. + void set_arg_data(size_t index, const void* data) { + assert((arg_size(index) < xla::cpu_function_runtime::MinAlign() || + (uintptr_t)data % xla::cpu_function_runtime::MinAlign() == 0) && + "Underaligned pointer!"); + // The const_cast is safe because the generated code does not write to arg + // buffers. + // + // buffer_table_ contains pointers to buffers that _will_ be written to by + // generated code so it would be misleading to make buffer_table_ a `const + // void**`. + buffer_table_[arg_index_table_[index]] = const_cast(data); + } + + // ------------------------------ + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. Unlike the arg methods, + // there is no set_resultN_data method. The result buffers are managed + // internally, and may change after each call to Run. + + // Returns the underlying array of result buffers, where results()[I] is the + // buffer for the positional result at index I. + void** results() { return static_cast(buffer_table_[result_index_]); } + const void* const* results() const { + return static_cast(buffer_table_[result_index_]); + } + + // Profile counters for this XLA computation. + // + // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in + // this case) these counters are non-null and are automatically populated by + // `Run`. The counters can then be pretty-printed using + // `hlo_profile_printer()`. + // + // When Hlo profiling is disabled, this accessor returns null. + const int64_t* profile_counters() const { return profile_counters_; } + + // Returns the buffer for the positional result at the given `index`. + void* result_data(size_t index) { return results()[index]; } + const void* result_data(size_t index) const { return results()[index]; } + + // ------------------------------ + // Methods for extracting optional metadata. + + // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index + // methods. E.g. the data might not be compiled into the binary for AOT. + bool HasNameIndices() const { + return arg_names_ != nullptr && variable_names_ != nullptr && + result_names_ != nullptr; + } + + // Returns the 0-based index for the argument with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. 
+ // Recommended usage is to capture this in a variable for re-use. + int LookupArgIndex(const string& name) const; + + // Returns the 0-based index for the variable with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupVariableIndex(const string& name) const; + + // Returns the 0-based index for the result with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupResultIndex(const string& name) const; + + // Returns the name of the argument at `index`. + // Returns nullptr if `HasNameIndices() == false` or `index` is out of range. + const char* GetArgName(int index) const; + + // Returns the name of the variable at `index`. + // Returns nullptr if `HasNameIndices() == false` or `index` is out of range. + const char* GetVariableName(int index) const; + + // Returns the name of the result at `index`. + // Returns nullptr if `HasNameIndices() == false` or `index` is out of range. + const char* GetResultName(int index) const; + + // Returns the shape of the args and results. May return nullptr if the + // program shape isn't available. + const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } + + bool hlo_profiling_enabled() const { + return hlo_profile_printer_data_ != nullptr; + } + const xla::HloProfilePrinterData& hlo_profile_printer_data() const { + assert(hlo_profiling_enabled()); + return *hlo_profile_printer_data_; + } + + protected: + // --------------------------------------------------------------------------- + // Accessors for reading from and writing to instances of `StaticData`. + // + // Classes generated by tfcompile can call these because the generated classes + // inherit from `XlaCompiledCpuFunction`. `XlaJitCompiledCpuFunction` can + // call these because it is explicitly added as a friend. 
+ + static void set_static_data_raw_function(StaticData* static_data, + RawFunction raw_function) { + static_data->raw_function_ = raw_function; + } + + static void set_static_data_buffer_infos( + StaticData* static_data, + const xla::cpu_function_runtime::BufferInfo* buffer_infos) { + static_data->buffer_infos_ = buffer_infos; + } + + static void set_static_data_num_buffers(StaticData* static_data, + size_t num_buffers) { + static_data->num_buffers_ = num_buffers; + } + + static void set_static_data_result_index_table( + StaticData* static_data, const int32* result_index_table) { + static_data->result_index_table_ = result_index_table; + } + + static void set_static_data_num_results(StaticData* static_data, + int64_t num_results) { + static_data->num_results_ = num_results; + } + + static void set_static_data_arg_index_table(StaticData* static_data, + const int32* arg_index_table) { + static_data->arg_index_table_ = arg_index_table; + } + + static void set_static_data_num_args(StaticData* static_data, + int64_t num_args) { + static_data->num_args_ = num_args; + } + + static void set_static_data_num_variables(StaticData* static_data, + int64_t num_variables) { + static_data->num_variables_ = num_variables; + } + + static void set_static_data_result_index(StaticData* static_data, + size_t result_index) { + static_data->result_index_ = result_index; + } + + static void set_static_data_arg_shape_infos(StaticData* static_data, + const ShapeInfo* shape_infos) { + static_data->arg_shape_infos_ = shape_infos; + } + + static void set_static_data_result_shape_infos(StaticData* static_data, + const ShapeInfo* shape_infos) { + static_data->result_shape_infos_ = shape_infos; + } + + static void set_static_data_arg_names(StaticData* static_data, + const char** arg_names) { + static_data->arg_names_ = arg_names; + } + + static void set_static_data_variable_names(StaticData* static_data, + const char** variable_names) { + static_data->variable_names_ = variable_names; + } + + static void set_static_data_result_names(StaticData* static_data, + const char** result_names) { + static_data->result_names_ = result_names; + } + + static void set_static_data_program_shape( + StaticData* static_data, const xla::ProgramShapeProto* program_shape) { + static_data->program_shape_ = program_shape; + } + + static void set_static_data_hlo_profile_printer_data( + StaticData* static_data, + const xla::HloProfilePrinterData* hlo_profile_printer_data) { + static_data->hlo_profile_printer_data_ = hlo_profile_printer_data; + } + + static const xla::HloProfilePrinterData* + get_static_data_hlo_profile_printer_data(StaticData* static_data) { + return static_data->hlo_profile_printer_data_; + } + + static void set_static_data_profile_counters_size( + StaticData* static_data, int64_t profile_counters_size) { + static_data->profile_counters_size_ = profile_counters_size; + } + + // TODO(ezhulenev): This is a no-op after removing xla runtime, however it is + // still required for building some targets. Figure out why and delete! + static void set_static_data_use_xla_runtime(StaticData* static_data, bool) {} + + private: + const RawFunction raw_function_; + + const size_t result_index_; + + // Array containing pointers to argument and temp buffers (slots corresponding + // to constant and on-stack buffers are null). + void** const buffer_table_; + + // Describes the buffers used by the XLA computation. 
+ const xla::cpu_function_runtime::BufferInfo* const buffer_infos_; + const int32 num_buffers_; + + // Indices of expanded result tuple. + const int32 num_results_; + const int32* const result_index_table_; + + // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] + // for XLA generated code to be able to find it. + const int32* const arg_index_table_; + + // The number of incoming arguments. + const int32 num_args_; + + // The number of incoming variables. + const int32 num_variables_; + + // Shapes of the input arguments. + const ShapeInfo* const arg_shape_infos_; + + // Shapes of the results. + const ShapeInfo* const result_shape_infos_; + + // Backing memory for buffer_table_ and args_, the latter depending on + // AllocMode. + void* alloc_buffer_table_ = nullptr; + + // Backing memory for profiling counters. + int64_t* profile_counters_ = nullptr; + + // Options and context passed to the compiled function. + xla::ExecutableRunOptions run_options_; + + // Optional metadata. + const char** arg_names_ = nullptr; + const char** variable_names_ = nullptr; + const char** result_names_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; + const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + + // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the + // `set_static_data_*` static methods above. + friend class XlaJitCompiledCpuFunction; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compiler.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compiler.h new file mode 100644 index 00000000..cbb57f38 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_compiler.h @@ -0,0 +1,403 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+
+#include
+
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "absl/types/variant.h"
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
+#include "tensorflow/compiler/tf2xla/layout_util.h"
+#include "tensorflow/compiler/tf2xla/xla_argument.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "xla/client/local_client.h"
+#include "xla/hlo/builder/xla_builder.h"
+#include "xla/hlo/builder/xla_computation.h"
+#include "xla/status_macros.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+class XlaContext;
+
+// The XlaCompiler class is responsible for compilation of a self-contained
+// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
+// It does a symbolic execution of the graph starting from specific input
+// shapes, using a JIT device to convert operators into XLA computations.
+//
+// XlaCompiler is typically invoked from an `XlaLaunch` operator once the
+// shapes of all input parameters to the computation are known. This is
+// because the symbolic execution requires known shapes for all operations.
+//
+// XlaCompiler compiles TensorFlow graphs that receive inputs via _Arg nodes
+// and return outputs via _Retval nodes.
+//
+// The XlaCompiler requires one Argument struct for each _Arg index, that
+// describes each argument. Arguments can be compile-time constants
+// (kind kConstant), run-time parameters (kind kParameter), or resources
+// (kind kResource).
+//
+// Only kParameter and initialized kResource arguments become runtime parameters
+// to the generated XLA computation.
+//
+// The run-time outputs of the XLA computation are arranged in the following
+// order:
+//   +------------------+-----------------------------------------+
+//   |  _Retval values  |  Updated values of kResource arguments  |
+//   +------------------+-----------------------------------------+
+// _Retval values are ordered by _Retval index, whereas kResource values are
+// ordered by the original _Arg position of the variable.
+//
+// If a shape representation function is provided as part of
+// XlaCompiler::CompileOptions, kParameter arguments and return values to an
+// entry computation will be reshaped in accordance with the shape function.
+// Arguments and return values to a non-entry computation are not reshaped.
+// Variable resource arguments are passed and returned in reshaped form, even
+// for non-entry computations. This feature allows TensorFlow to keep on-device
+// tensors with a different shape from their representation inside the XLA
+// computation.
+//
+// In computation outputs, updated kResource values are placed at the end.
+// When emitting While loop bodies, we must ensure that the loop body has
+// identical input and output signatures. By passing variable values
+// at the end of the argument list and using the
+// `return_updated_values_for_all_variables` option, we can ensure that the
+// input and output values of resources appear at the same positions.
+//
+// Resources are passed as parameters or returned as resource updates in
+// "packed" form.
+// kStack resources are packed as (array, size of stack) XLA tuples.
+// kTensorArray resources without gradients are packed as the array that
+// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
+// the packed representation is a (array, gradient0, gradient1, ...) tuple,
+// where gradient_k is the value of the k-th gradient in the
+// `tensor_array_gradients` ordered set.
+class XlaCompiler {
+ public:
+  // TODO(b/255826209): Remove this alias. Depending on XlaCompiler just to use
+  // XlaArgument seems weird and can cause circular dependencies.
+  using Argument = ::tensorflow::XlaArgument;
+
+  // Options pertaining to an individual call to CompileGraph() or
+  // CompileFunction().
+  struct CompileOptions {
+    // If `use_tuple_arg` is true, a single tuple parameter will be used for all
+    // arguments; if false, each argument gets its own parameter.
+    bool use_tuple_arg = false;
+
+    // If 'return_updated_values_for_all_resources' is true, then updated
+    // values of all resource arguments will be included in the
+    // 'resource_updates' of the computation, even if the resource was not
+    // modified by the computation. Used when compiling loop bodies to ensure
+    // the input and output signatures match.
+    bool return_updated_values_for_all_resources = false;
+
+    // If 'always_return_tuple' is true, then the output of a computation will
+    // always be a tuple. Otherwise, a single-element output will not be wrapped
+    // in a tuple.
+    bool always_return_tuple = true;
+
+    // True when compiling the entry computation, false for subcomputations
+    // (while, call, etc.)
+    bool is_entry_computation = true;
+
+    // True when we should add XLA input & output to the graph/function.
+    bool add_token_input_output = false;
+
+    // Resource updates are converted into input / output of xla. The two
+    // buffers are aliased with each other if this option is true.
+    bool alias_resource_update = false;
+  };
+
+  using OutputDescription = ::tensorflow::XlaOutputDescription;
+
+  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;
+
+  using CompilationResult = ::tensorflow::XlaCompilationResult;
+
+  struct Options {
+    // Name of the compilation device to use. It must be set by the caller.
+    // The default empty value is invalid.
+    DeviceType device_type = DeviceType("");
+
+    // The device to use during compilation to execute instructions on, for
+    // example for auto-tuning.
+    // Valid values are defined by `xla::Backend::devices_ordinal_supported()`.
+    // -1 indicates the default device should be used.
+    int device_ordinal = -1;
+
+    xla::Client* client = nullptr;
+
+    // Function library in which to find function definitions. Must be non-null.
+    const FunctionLibraryDefinition* flib_def = nullptr;
+
+    // The graph def version to be compiled.
+    int graph_def_version = TF_GRAPH_DEF_VERSION;
+
+    // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
+    // for CPU.
+    bool allow_cpu_custom_calls = false;
+
+    // A ShapeDeterminationFns (i.e., a bundle of LayoutSelectionFn and
+    // ShapeRepresentationFn).
Each bundle describes the XLA representation of + // arguments represented to XLA as the shape given by this shape function. + // Arguments are input activations or weights to an XLA entry computation. + // Variables are reshaped to this shape on write, and reshaped to their + // original shape on read. + XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns; + + // If not nullptr, populate_resource_manager is called with the + // compilation device's resource manager when the compilation + // device is created, and can be used to create metadata objects + // that can be accessed by XLA op kernels. + std::function* populate_resource_manager = + nullptr; + + // If not nullptr, this memory allocator can be used by the compiler for + // temporary allocations it might want to make during compilation. + // + // For example, the compiler may want to try out different algorithms and + // choose the fastest one, and it might run those algorithms over buffers + // created using this allocator. + // + // The compiler can function correctly without an explicit allocator given + // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly + // allocate most or all available memory on the device, leaving none for the + // compiler to access, unless it can use TensorFlow's allocator. + // This must be a shared_ptr, as this is passed all the way down to the + // cluster compilation. This allows asynchronous compilation to hold a + // reference until the compilation is finished. + std::shared_ptr device_allocator; + + // Alias input and output buffers for parameters that are passed-through XLA + // modules without being changed. + bool alias_passthrough_params = false; + + // Enable detailed logging of compilation metadata. + bool detailed_logging = true; + }; + + // Argument for compiling a single op. + struct SingleOpCompileArgument { + // Data type of the output tensors. This is used to create _Retval node. + std::vector output_dtypes; + + // The NodeDef representing the op. + NodeDef node_def; + + // This is currently only used to obtain MLIR TPU bridge rollout state. + // Can be removed once full rollout is complete. + ConfigProto config_proto; + + SingleOpCompileArgument() = default; + + explicit SingleOpCompileArgument(const OpKernelContext& ctx); + }; + + explicit XlaCompiler(Options options); + + ~XlaCompiler(); + + // Helper function to populate an XlaCompiler::Argument from XlaResource. + static void PopulateArgumentFromResource(const XlaResource& resource, + Argument* arg); + + absl::Status CompileFunction(const CompileOptions& options, + const NameAttrList& fn_name_attrs, + absl::Span args, + CompilationResult* result); + + absl::Status CompileSingleOp( + const CompileOptions& options, + const SingleOpCompileArgument& single_op_compile_argument, + absl::Span args, CompilationResult* result); + + // Compiles a tensorflow::Graph into an xla::XlaComputation. + // Similar to CompileFunction, but takes a Graph as input rather than a + // function. + absl::Status CompileGraph(const CompileOptions& options, string const& name, + std::unique_ptr graph, + absl::Span args, + CompilationResult* result); + + // Returns the shape of the XLA parameter for an argument 'arg'. + // See the class comment for more details about the argument passing + // convention. + absl::Status XLAShapeForArgument( + const Argument& arg, bool is_entry_computation, + const std::optional& arg_sharding, + xla::Shape* xla_shape) const; + + // Retrieves the channel handle associated with `key`. 
Allocates + // a new channel handle if none exists. + // Channel handles can be used to communicate between different + // computations. Computations that communicate should be compiled with the + // same XlaCompiler. + absl::Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + + // Retrieves the host-to-device channel handle associated with `key`. + // Allocates a new channel handle if none exists. + absl::Status GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel); + + // Retrieves the device-to-host channel handle associated with `key`. + // Allocates a new channel handle if none exists. + absl::Status GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel); + + // Sets the shapes and types for the device to host transfer associated with + // 'key'. + absl::Status SetDeviceToHostMetadata(const string& key, + absl::Span types, + absl::Span shapes); + + // Gets the shapes the device to host transfer associated with 'key'. + absl::Status GetDeviceToHostShapes(const string& key, + std::vector* shapes) const; + + // Sets the shapes and types for the host to device transfer associated with + // 'key'. + absl::Status SetHostToDeviceMetadata(const string& key, + absl::Span types, + absl::Span shapes); + + // In order to avoid deadlocks from dependencies in host computations, it can + // be necessary to enforce a partial order on the execution of HostCompute + // Ops. In particular it may be necessary to constrain the SendToHost for one + // HostCompute to run before blocking on the RecvAtHost for another + // HostCompute. The compiler maintains a mapping from 'host_compute_name' to + // handle, where the handle is an 'output' of the HostCompute Op corresponding + // to 'host_compute_name'. Another HostCompute Op that needs to be sequenced + // later can add the handle as an 'input' to enforce the constraints. + // 'host_compute_name' can be any string the client wishes to use to identify + // a given HostCompute Op as long as the names are unique within the + // compilation. + absl::Status GetHostComputeControlDependency(const string& host_compute_name, + xla::XlaOp* handle); + absl::Status SetHostComputeControlDependency(const string& host_compute_name, + xla::XlaOp handle); + + const Options& options() const { return options_; } + xla::Client* client() const { return options_.client; } + FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } + + void PushNodeTokenMapping(); + absl::Status PopNodeTokenMapping(); + absl::Status SetNodeToken(const string& node_name, xla::XlaOp op); + absl::StatusOr GetNodeToken(const string& node_name); + + // Sets the function body `fbody` to the one registered as `function`. + absl::Status FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody, + const ConfigProto** config_proto = nullptr); + + private: + absl::Mutex channel_mutex_; + // Returns the optimized graph object in this function body. + std::unique_ptr GetGraph(const FunctionBody* fbody); + + // Builds XLA computations for each of the arguments to the computation. + // `args` are the arguments to the computation. 
+ absl::Status BuildArguments( + const Graph& graph, const std::vector& args, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, + const std::map& arg_shardings, + std::vector* arg_expressions, + std::vector* input_to_args, std::vector* input_shapes, + bool is_entry_computation); + + xla::ChannelHandle NewChannel(xla::ChannelHandle::ChannelType type); + + // Graph compiler needs to know how to get an optimized graph from a function + // body. + friend class GraphCompiler; + friend class XlaCompilerTest; + + Options options_; + + // Status set to non-OK in the constructor if initialization fails. + absl::Status initialization_status_; + + // Returns the next step sequence number. + int64_t NextStepId(); + + // Internal sequence number for steps executed on the compilation device. + int64_t next_step_id_; + + XlaCompilationDevice* device_; // Owned by device_mgr_ + StaticDeviceMgr device_mgr_; + + // The next sequence number to assign to a channel. + int64_t next_channel_ ABSL_GUARDED_BY(channel_mutex_) = 1; + + // To avoid copying the client's function library, use a local function + // library and runtime for functions created as part of the functionalize + // control flow transformation. + std::unique_ptr local_flib_def_; + std::unique_ptr pflr_; + std::unique_ptr local_pflr_; + + FunctionLibraryRuntime* local_flib_runtime_; // owned by local_pflr_. + FunctionLibraryRuntime* flib_runtime_; // owned by pflr_. + + struct SignatureHash { + uint64 operator()( + const std::pair>& signature) const; + }; + + std::unordered_map>, + CompilationResult, SignatureHash> + cache_; + + std::unordered_map channels_; + + std::unordered_map host_compute_sends_; + std::unordered_map host_compute_recvs_; + + std::unordered_map host_compute_control_output_; + + // This is used to store mapping. Side-effecting + // ops call SetNodeToken() to record its token output, so later side-effecting + // ops can use GetNodeToken() to get it and use it as token input. + // + // It's a stack because we need a mapping like this for each level of nested + // CompileGraph() call. In CompileGraph(), we will push a new mapping to the + // stack, and pop the mapping before returning. + std::stack> node_token_mapping_stack_; + + XlaCompiler(const XlaCompiler&) = delete; + void operator=(const XlaCompiler&) = delete; +}; + + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_context.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_context.h new file mode 100644 index 00000000..9184fb43 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_context.h @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the contexts used during XLA compilation. 
+ +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/status_macros.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class XlaOpKernelContext; +class XlaCompiler; + +// The XlaContext is the data structure that holds the state of an XLA +// compilation, that is accessible from OpKernelContexts when compiling a +// subgraph of Ops using XLA. +class XlaContext : public ResourceBase { + public: + // Retrieves the XlaContext of the current compilation. + static XlaContext& Get(const OpKernelContext* ctx); + + // Creates a new XlaContext. See the documentation on the class data fields + // for descriptions of the arguments. + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, + const Graph* graph); + + // Virtual method defined by ResourceBase. + string DebugString() const override; + + XlaCompiler* compiler() const { return compiler_; } + + const AbstractStackTrace* StackTraceForNodeName(const std::string& name) { + const auto& it = stack_traces_.find(name); + if (it != stack_traces_.end()) { + return it->second.get(); + } + return nullptr; + } + + // Returns the XlaBuilder that Ops use for compiling new expressions. + xla::XlaBuilder* builder() { return builder_; } + + const std::vector& args() const { return args_; } + void set_args(std::vector args); + + const std::vector& retvals() { return retvals_; } + + // Sets a return value. + // Since we do not always know in advance how many return values there are, + // grows the return values vector to size index+1 if it is smaller. + void SetRetval(int index, const XlaExpression& expression); + + // Adds 'resource' to the set of resources owned by the context. + XlaResource* AddResource(std::unique_ptr resource); + + const std::vector>& resources() { + return resources_; + } + + // Get an XLA lambda to compute Max. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateMax(const DataType type); + + // Get an XLA lambda to compute Min. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateMin(const DataType type); + + // Get an XLA lambda to compute Add. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateAdd(const DataType type); + + // Get an XLA lambda to compute LogAddExp. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateLogAddExp(const DataType type); + + // Get an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. 
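  // Illustrative sketch of how these cached computations are commonly
  // consumed (assumptions: `context` is an XlaContext&, `input` is an
  // xla::XlaOp, `builder` is the active xla::XlaBuilder*, and dimension 0 is
  // the one being reduced):
  //
  //   const xla::XlaComputation* add = context.GetOrCreateAdd(DT_FLOAT);
  //   xla::XlaOp sum = xla::Reduce(input, XlaHelpers::Zero(builder, DT_FLOAT),
  //                                *add, /*dimensions_to_reduce=*/{0});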
+ const xla::XlaComputation* GetOrCreateMul(const DataType type); + + // The name of the XlaContext resource during symbolic graph execution. + static const char kXlaContextResourceName[]; + + // Records the collective information from the nested compilation `result`. + absl::Status RecordCollectiveInfoFromNestedCompilationResult( + const XlaCompilationResult& result); + + // Records the collective configurations for all the collectives in the XLA + // cluster and returns the channel_id to be used for the next collective. + absl::StatusOr RecordCollectiveInfo(int group_key, int group_size); + + const std::optional& + GetCollectiveInfo() { + return collective_info_; + } + + private: + XlaCompiler* const compiler_; + + // The XlaBuilder used to construct the subgraph's compiled representation. + xla::XlaBuilder* builder_; + + // Stack traces for the graph used for compilation. + StackTracesMap stack_traces_; + + // Arguments to the Tensorflow graph, indexed by _Arg index. + // Includes both compile-time constant arguments and runtime parameters. + std::vector args_; + + // Return values of the Tensorflow graph, indexed by _Retval index. + std::vector retvals_; + + // Holds ownership of resources. The resources are not ordered. + std::vector> resources_; + + // Information about encountered collective ops. We allow only a + // single configuration per cluster. + std::optional collective_info_; + + // Cache of prebuilt computations indexed by their type. + using ComputationMap = std::map; + + // Finds the value for the given type in out map if it already + // exists or makes a new value with create function and keeps it the + // map. The returned value != nullptr and is owned by the map. + const xla::XlaComputation* LookupOrCreate( + DataType type, ComputationMap* out, + const std::function& create); + + // Cached computation to compute Max of two elements, specialized by type. + ComputationMap max_func_; + + // Cached computation to compute Min of two elements, specialized by type. + ComputationMap min_func_; + + // Cached computation to compute Sum of two elements, specialized by type. + ComputationMap add_func_; + + // Cached computation to compute Mul of two elements, specialized by type. + ComputationMap mul_func_; + + // Cached computation to compute Log(Add(Exp())) of two elements, specialized + // by type. + ComputationMap log_add_exp_func_; + + // Cached computation to compute Sigmoid of an element, specialized by type. + ComputationMap sigmoid_func_; + + XlaContext(const XlaContext&) = delete; + void operator=(const XlaContext&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_expression.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_expression.h new file mode 100644 index 00000000..d410b79a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_expression.h @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "xla/client/client.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA +// compilation. +// An expression is one of: +// * a constant tensor. +// * an xla::XlaOp, representing a symbolic XLA value. +// * a resource, e.g., a variable, represented as an XlaResource pointer. +// * a tensor list, represented by a tuple of tensors and the list length. +// +// Constant tensors are mostly an optimization to avoid passing large constants +// to XLA, but are also sometimes used to represent tensors that have no XLA +// representation, for example, DT_STRING tensors. A canonical use case might be +// an error message string. +// +// Tensor lists are very similar to xla::XlaOp, however they require some +// specific logic around shape management since the tuples are not supported by +// TensorFlow. +class XlaExpression { + public: + enum class Kind { + kInvalid, + kConstant, + kXlaOp, + kResource, + kTensorList, + }; + + XlaExpression(); + XlaExpression(const XlaExpression&) = default; + XlaExpression& operator=(const XlaExpression&) = default; + + // Builds an invalid expression. (Same as the default constructor, but makes + // the intent clearer.) + static XlaExpression Invalid(); + + // Builds a constant XLA expression. + static XlaExpression Constant(Tensor value); + + // Builds a XlaOp expression. Since the mapping from TF data types to XLA + // types is not 1-1, the TF type must also be provided; in general it cannot + // be derived from the XLA type. + static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + + // Builds a tensor list expression. + static XlaExpression TensorList(xla::XlaOp tensor_list); + + // Builds a resource expression. + static XlaExpression Resource(XlaResource* resource); + + // Builds a resource whose value is known at a compile time. + static XlaExpression ConstantResource(Tensor value, XlaResource* resource); + + Kind kind() const { return kind_; } + + DataType dtype() const { return dtype_; } + + // handle() returns the XlaOp that backs a kXlaOp expression. + const xla::XlaOp& handle() const { return handle_; } + + // Return a constant value associated with this expression. Always set for + // constants, might be set for resources. + std::optional constant_value() const { + if (kind_ == Kind::kResource && resource_->IsOverwritten()) { + // The constant is no longer available if the value was overwritten. + return std::nullopt; + } + return constant_value_; + } + + // Set the bound of the expression. + void set_value_bound(Tensor tensor) { + value_bound_.emplace(std::move(tensor)); + } + + // Return the bound of the expression, if available. + std::optional value_bound() const { return value_bound_; } + + // Set the dynamism of the expression, indicating whether or not each value in + // this expression is dynamic. 
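  // Illustrative sketch of constructing expressions with the factory
  // functions above (the xla::XlaOp `op` and the Tensor `t` are assumed to
  // already exist):
  //
  //   XlaExpression symbolic = XlaExpression::XlaOp(op, DT_FLOAT);
  //   XlaExpression constant = XlaExpression::Constant(t);
  //   CHECK(symbolic.kind() == XlaExpression::Kind::kXlaOp);
  //   CHECK(constant.constant_value().has_value());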
+ void set_value_dynamism(Tensor tensor) { + value_dynamism_.emplace(std::move(tensor)); + } + + // Return the dynamism of the expression, if available. + std::optional value_dynamism() const { return value_dynamism_; } + + XlaResource* resource() const { return resource_; } + + // Returns a human-readable summary of the expression. + string HumanString() const; + + // Returns the value of a kValue or kXlaOp as an xla::XlaOp. Returns + // an erroneous XlaOp if the expression is not a constant or an expression. + xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; + + // If a kXlaOp or kValue expression can be resolved to a compile-time + // constant, returns the value as a host-memory Tensor. Returns an empty + // optional if it cannot be resolved. Returns an error if passed a resource + // expression. + absl::StatusOr> ResolveConstant( + xla::Client* client, bool dynamic_dimension_is_minus_one = false, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue) const; + + // ResolveDynamism computes where a value inside this op is dynamic or can be + // inferred at compile time. + absl::StatusOr ResolveDynamism() const; + + // Returns the shape of the tensor. + // The shape of a resource is the shape of a resource handle (i.e., a scalar), + // not the shape of the resource's value. + absl::StatusOr GetShape() const; + absl::StatusOr GetXlaShape() const; + + // Retrieves an XlaExpression that was allocated by a previous Op. + static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); + + // Assigns an XlaExpression to a tensor on an XLA compilation device. + static void AssignExpressionToTensor(const XlaExpression& value, + Tensor* tensor); + + private: + Kind kind_ = Kind::kInvalid; + + DataType dtype_ = DT_INVALID; + + // The XLA handle of the expression's computation, if kind_ == kXlaOp or + // a tuple expression if kind_ == kTensorList. + xla::XlaOp handle_; + + // The value of the constant, if available. + std::optional constant_value_; + + // The bound of the expression, if available. + std::optional value_bound_; + + // Indicate whether each value inside a tensor is dynamic or not. + std::optional value_dynamism_; + + // The resource, if kind_ == kResource. Not owned. + XlaResource* resource_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_helpers.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_helpers.h new file mode 100644 index 00000000..38f01c83 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_helpers.h @@ -0,0 +1,214 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines helper routines for the XLA device. 
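// A small sketch of the scalar-literal helpers declared below (illustrative
// only; `builder` is assumed to be a valid xla::XlaBuilder*):
//
//   xla::XlaOp zero = XlaHelpers::Zero(builder, DT_FLOAT);
//   xla::XlaOp forty_two = XlaHelpers::IntegerLiteral(builder, DT_INT32, 42);
//   xla::XlaOp half = XlaHelpers::FloatLiteral(builder, DT_FLOAT, 0.5);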
+ +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ + +#include + +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/service/computation_placer.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +using XlaLayoutPreference = mlir::XlaLayoutPreference; + +inline std::string GetDeviceToHostChannelName(absl::string_view channel_key, + int index) { + return absl::StrCat(channel_key, "_dtoh_", index); +} +inline std::string GetHostToDeviceChannelName(absl::string_view channel_key, + int index) { + return absl::StrCat(channel_key, "_htod_", index); +} + +// Helper methods for building XLA computations. +class XlaHelpers { + public: + // Returns a handle representing the zero value of a scalar + // element of data_type. + static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); + + // Returns a handle representing the one value of a scalar + // element of data_type. + static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); + + // Returns a handle representing the given value of an integer scalar + // element of data_type. + // Note that unlike One and Zero, does not work on boolean types. + static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type, + int64_t value); + + // Returns a handle representing the given value of a floating-point scalar + // element of data_type. + static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type, + double value); + + // Reshapes literal 'input' to have 'shape'. Both the original shape and + // 'shape' must contain the same number of elements. + static absl::Status ReshapeLiteral(const xla::Literal& input, + absl::Span shape, + xla::Literal* output); + + // Converts `indices` into a one-hot representation. `depth` is the size + // of the new axis to add. `axis` is the position at which to add the new + // axis. `indices_shape` is the shape of `indices`. `on_value` and + // `off_value` represent the values to use for the on and off positions, + // respectively. + static absl::Status OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, + DataType index_type, + const TensorShape& indices_shape, + xla::XlaOp indices, xla::XlaOp on_value, + xla::XlaOp off_value, xla::XlaOp* one_hot); + + // Certain DataTypes should use increased precision DataTypes when performing + // reductions. This function remaps a given DataType to a higher precision + // DataType if needed. + static DataType SumAccumulationType(const DataType& dtype); + + // A helper for creating a ConvertElementType xla op given a DataType rather + // than the xla::PrimitiveType. + static xla::XlaOp ConvertElementType(xla::XlaOp operand, + const DataType new_element_type); + + typedef std::function(const TensorShape&, DataType, + bool, XlaLayoutPreference)> + ShapeRepresentationFn; +}; + +// Creates an identity shape representation function. +XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn(); + +struct XlaOutputDescription { + // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. 
+ DataType type; + TensorShape shape; + + // Constant output value, if known to be constant at JIT compilation time. + // 'Tensor' is in host memory. + bool is_constant = false; + Tensor constant_value; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int input_index; + + // Whether this output is a TensorList. + bool is_tensor_list = false; +}; + +// Describes a variable write side effect of the computation. +struct XlaResourceUpdate { + // Index of the input that contains the variable resource to write to. + int input_index; + + // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. + DataType type; + TensorShape shape; + + // Was the value of the variable modified by the computation? + // (Always true, unless `return_updated_values_for_all_resources` is true.) + bool modified; + + // If the resource is a TensorArray, the set of gradients read or written. + std::set tensor_array_gradients_accessed; +}; + +struct XlaCompilationResult { + // Vector that maps from the parameters of the XLA computation to their + // original argument positions. To handle compile-time constant inputs, the + // parameters to the XLA computation may be a subset of the original + // arguments. The relative ordering of parameters are maintained. + std::vector input_mapping; + + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. + std::vector xla_input_shapes; + + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. + xla::Shape xla_output_shape; + + // TensorFlow shapes of outputs, together with the values of any + // constant arguments. Vector indexed by Tensorflow _Retval number, + // containing both constant and non-constant results. + std::vector outputs; + + // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their + // matching RecvAtHost/SendFromHost Ops in the outer graph. + tf2xla::HostComputeMetadata host_compute_metadata; + + // Resources whose values were updated by the computation, ordered + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant + // results in the outputs of XLA computation. + std::vector resource_updates; + + // The XLA computation built from the tensorflow subgraph. + std::shared_ptr computation; + + // Meta-info about encountered collective ops. + struct CollectiveInfo { + int group_key; + int group_size; + int next_id; + + template + friend H AbslHashValue(H h, const CollectiveInfo& info) { + return H::combine(std::move(h), info.group_key, info.group_size, + info.next_id); + } + + friend bool operator==(const CollectiveInfo& lhs, + const CollectiveInfo& rhs) { + return lhs.group_key == rhs.group_key && + lhs.group_size == rhs.group_size && lhs.next_id == rhs.next_id; + } + }; + + // Information of the collectives encountered during the translation. + std::optional collective_info; +}; + +// Resolves the device assignment based on CollectiveInfo. +// CollectiveInfo records collective ops in the cluster. Note that +// this relies on a rendezvous and blocks until all replicas are there. +// +// Takes several extra configuration objects by reference since +// xla::ExecutableRunOptions does not take ownership; these are configured and +// bundled into `run_options` if applicable. 
+absl::Status ResolveDeviceAssignment( + OpKernelContext* ctx, + const XlaCompilationResult::CollectiveInfo& collective_info, + xla::ExecutableRunOptions& run_options, + xla::DeviceAssignment& device_assignment, + xla::gpu::GpuExecutableRunOptions& gpu_options); + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h new file mode 100644 index 00000000..c3982bb5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -0,0 +1,100 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ + +#include +#include + +#include "absl/log/check.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "xla/client/local_client.h" +#include "xla/cpu_function_runtime.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the result of JIT compilation by XLA down to a function. This +// class holds the state necessary to create XlaCompiledCpuFunction instances, +// which are used to actually invoke the compiled computation. +// +// XlaJitCompiledCpuFunction must outlive the XlaCompiledCpuFunctions that are +// created from it. It holds state shared by all of the functions, including the +// JIT-compiled function itself, along with buffer sizes and other metadata +// necessary for execution. +class XlaJitCompiledCpuFunction { + public: + // Compile a tensorflow::GraphDef into an XlaJitCompiledCpuFunction. The given + // `config` specifies the portion of the graph to compile, via feeds and + // fetches. Each feed is a positional input argument for the compiled + // function, while each fetch is a positional output argument. + static absl::StatusOr> Compile( + const GraphDef& graph_def, const tf2xla::Config& config, + const xla::ExecutableBuildOptions& build_options); + + XlaJitCompiledCpuFunction(const XlaJitCompiledCpuFunction&) = delete; + XlaJitCompiledCpuFunction& operator=(const XlaJitCompiledCpuFunction&) = + delete; + + // Returns static data used to create an XlaCompiledCpuFunction instance, + // which represents the JIT-compiled function. The static data is unchanging + // across each instance. 
+ const XlaCompiledCpuFunction::StaticData& StaticData() const { + return static_data_; + } + + const xla::LocalExecutable& LocalExecutable() const { + CHECK(executable_); // Crash ok + return *executable_; + } + + private: + XlaJitCompiledCpuFunction() {} + + // The executable holds the underlying function. + std::unique_ptr executable_; + + // The static data is backed by the rest of the state in this class. + XlaCompiledCpuFunction::StaticData static_data_; + + // The backing array for buffer infos. + std::vector buffer_infos_; + + // The backing array for the arg index table. + std::vector arg_index_table_; + + // The backing arrays of arg and result names. We hold the actual strings in + // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static + // data to refer to. + std::vector nonempty_arg_names_; + std::vector nonempty_variable_names_; + std::vector nonempty_result_names_; + std::vector arg_names_; + std::vector variable_names_; + std::vector result_names_; + + // The backing data for the program shape. The proto form of program shape is + // used because the program shape is serialized and embedded in the object + // file. + std::unique_ptr program_shape_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_op_kernel.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_op_kernel.h new file mode 100644 index 00000000..b0830d07 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -0,0 +1,390 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ + +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class XlaOpKernelContext; + +// Implementations of operators that generate XLA code should usually subclass +// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel, +// an XlaOpKernel produces and consumes symbolic values during compilation. 
+// +// See the comments in xla_context.h for more details. +class XlaOpKernel : public OpKernel { + public: + explicit XlaOpKernel(OpKernelConstruction* construction); + + // Subclasses should implement Compile(), much as standard OpKernels implement + // Compute(). + virtual void Compile(XlaOpKernelContext* context) = 0; + + private: + void Compute(OpKernelContext* context) final; +}; + +// The context passed to the Compile() method of XlaOpKernel. An +// XlaOpKernelContext is a variant of the standard OpKernel class, tailored for +// implementing operators that perform symbolic execution as part of the XLA +// compiler. The key difference is that XlaOpKernelContext produces and consumes +// data as XLA computations, rather than as standard Tensors. +// +// Under the hood, symbolic execution communicates using special Tensors that +// wrap XlaExpression objects, however this is an implementation detail that +// this class hides. The *only* correct way to allocate a Tensor during +// compilation is using the XlaOpKernelContext methods, since they ensure there +// is a valid XlaExpression backing the tensor. No Op should ever call +// allocate_output or allocate_temp directly on the underlying OpKernelContext. +class XlaOpKernelContext { + public: + explicit XlaOpKernelContext(OpKernelContext* context); + + XlaContext* xla_context() const; + + // Returns the XLA XlaBuilder containing the output of compilation. + xla::XlaBuilder* builder() const; + + xla::ValueInference& value_inference(); + + // Inputs + + // Returns the number of inputs to the operator. + int num_inputs() const { return context_->num_inputs(); } + + // Returns the type of input `index`. + DataType input_type(int index) const; + + // Returns the type of input `name`. + DataType InputType(absl::string_view name); + + // Returns the type of input `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType input_xla_type(int index); + + // Returns the type of input `name` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType InputXlaType(absl::string_view name); + + // Returns the shape of input at `index` or input the given `name`. Note that + // in case the shape of the input is not static, then the returned shape has + // bounds as the dimension size instead of having unknown dimensions. Use + // InputXlaShape instead that provides shapes with dynamism information. + // + ABSL_DEPRECATED( + "Prefer InputXlaShape which handles dynamic shapes accurately.") + TensorShape InputShape(int index); + ABSL_DEPRECATED( + "Prefer InputXlaShape which handles dynamic shapes accurately.") + TensorShape InputShape(absl::string_view name); + + // Returns input `index` as a XlaOp. Unlike + // OpKernelContext::Input returns a symbolic value rather than a concrete + // Tensor. + xla::XlaOp Input(int index); + // Returns input `name` as a XlaOp. + xla::XlaOp Input(absl::string_view name); + + // Returns the xla input shape for a given index. + absl::StatusOr InputXlaShape(int index); + absl::StatusOr InputXlaShape(absl::string_view name); + + // Returns true if all inputs are the same shape, otherwise sets the + // status to a non-OK value and returns false. 
+ // Usage: if (!context->ValidateInputsAreSameShape(this)) return; + bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT; + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named output is not list-valued, + // returns a one-element list. + absl::Status InputList(absl::string_view name, + std::vector* handles, + std::vector* shapes); + // Evaluates input and returns their dynamism vector in a vector of + // predicates. + absl::Status ResolveInputDynamismIntoPredVector(int index, + std::vector* out); + absl::Status ResolveInputDynamismIntoPred(int index, bool* out); + absl::Status ResolveInputDynamismIntoPredVector(absl::string_view name, + std::vector* out); + absl::Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out); + + absl::Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal); + absl::Status ResolveInputDynamism(absl::string_view name, + xla::Literal* dynamism_literal); + + absl::Status ResolveInputDynamismReshaped(int index, + absl::Span new_dims, + xla::Literal* dynamism_literal); + // Helper methods for constant inputs. + + // Evaluates input `index` and stores it in `*constant_literal`. If the + // expression cannot be evaluated, e.g., because it depends on unbound + // parameters, returns a non-OK status. This function can also be used to + // infer constant input upper or lower bounds, by changing the `mode` + // parameter. + absl::Status ConstantInput( + int index, xla::Literal* constant_literal, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + absl::Status ConstantInput( + absl::string_view name, xla::Literal* constant_literal, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Converts a constant scalar int32 or int64 tensor into an int64. + absl::Status ConstantInputAsIntScalar( + int index, int64_t* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + absl::Status ConstantInputAsIntScalar( + absl::string_view name, int64_t* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + absl::StatusOr ConstantInputAsIntScalar( + absl::string_view name, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Converts a constant scalar float32 or float64 tensor into a float64. + absl::Status ConstantInputAsFloatScalar( + int index, double* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Converts a constant 1D int32 or int64 tensor into a vector of int64s. + absl::Status ConstantInputAsIntVector( + int index, std::vector* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + absl::Status ConstantInputAsIntVector( + absl::string_view name, std::vector* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Reshapes and converts a constant int32 or int64 tensor into a vector of + // int64s. + absl::Status ConstantInputReshapedToIntVector( + int index, std::vector* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + absl::Status ConstantInputReshapedToIntVector( + absl::string_view name, std::vector* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Converts a constant int32 or int64 Tensor into an xla int64 Literal. 
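  // Usage sketch for the constant-input helpers above, inside a kernel's
  // Compile() (illustrative; input 1 is assumed to carry a shape vector):
  //
  //   std::vector<int64_t> dims;
  //   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &dims));
  //   // Or request only an upper bound when the exact value may be dynamic:
  //   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(
  //                           1, &dims, xla::ValueInferenceMode::kUpperBound));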
+ absl::Status ConstantInputAsInt64Literal( + int index, xla::Literal* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + absl::Status ConstantInputAsInt64Literal( + absl::string_view name, xla::Literal* out, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Converts a constant 1D int32 or int64 tensor into a TensorShape. + absl::Status ConstantInputAsShape( + int index, TensorShape* shape, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 + // into a PartialTensorShape. + absl::Status ConstantInputAsPartialShape(int index, + PartialTensorShape* shape); + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named output is not list-valued, + // returns a one-element list. + absl::Status ConstantInputList( + absl::string_view name, std::vector* outputs, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Returns the Tensor representation of the constant input. + absl::StatusOr ConstantInputTensor( + int index, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + // Returns an XlaExpression describing the value of 'index'. + const XlaExpression& InputExpression(int index); + const XlaExpression& InputExpression(absl::string_view name); + + // Outputs + + int num_outputs() const { return context_->num_outputs(); } + DataType expected_output_dtype(int index) const { + return context_->expected_output_dtype(index); + } + + // Returns the type of output `index` as an xla::PrimitiveType. If the type + // is not representable as an XLA type, sets an error status and returns + // xla::PRIMITIVE_TYPE_INVALID. + xla::PrimitiveType output_xla_type(int index); + + // Sets output `index` to the XlaOp `handle`. + // All outputs should be set using SetOutput and SetConstantOutput, not + // via the underlying OpKernelContext. + void SetOutput(int index, const xla::XlaOp& handle); + + // Sets output `index` to compile-time constant `host_tensor`, where + // `host_tensor` is a tensor in host memory. It is preferable to use + // SetConstantOutput where possible. + void SetConstantOutput(int index, const Tensor& host_tensor); + + // Returns an XlaExpression describing the value of 'index'. + void SetOutputExpression(int index, const XlaExpression& expression); + + // Sets output `index` to the Tensor List `handle`. + void SetTensorListOutput(int index, const xla::XlaOp& handle); + + // Status handling. + void SetStatus(const absl::Status& status) { context_->SetStatus(status); } + absl::Status status() { return context_->status(); } + + // Variables + + // Sets `*resource` to the resource associated with input `index`. + absl::Status GetResourceInput(int index, XlaResource** resource); + + // Sets output `index` to be a reference to resource `resource`. + void SetResourceOutput(int index, XlaResource* resource); + + // Sets `*type` and `*shape` to the current type and shape of a variable's + // value. + absl::Status GetVariableTypeAndShape(int index, DataType* type, + TensorShape* shape) const; + + // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension + // returns "-1", this is useful when the underlying ops expect explicit + // dynamic index like reshape. 
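  // Sketch of the variable helpers in this section (illustrative; input 0 is
  // assumed to be an initialized DT_FLOAT resource variable):
  //
  //   TensorShape shape;
  //   xla::XlaOp value;
  //   OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, DT_FLOAT, &shape, &value));
  //   xla::XlaOp updated =
  //       xla::Add(value, XlaHelpers::One(ctx->builder(), DT_FLOAT));
  //   OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, DT_FLOAT, updated));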
+ void set_dynamic_dimension_is_minus_one(bool value) { + dynamic_dimension_is_minus_one_ = value; + } + + bool dynamic_dimension_is_minus_one() const { + return dynamic_dimension_is_minus_one_; + } + + bool is_dynamic_dimension(int64_t dim_size) { return dim_size == -1; } + + // Reads the current value of the resource variable referred to by input + // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the + // variable. Returns an error if the variable has not been initialized, or if + // its type does not match `type`. + absl::Status ReadVariableInput(int index, DataType type, TensorShape* shape, + xla::XlaOp* value); + // Reads the current value of the resource variable referred to by input + // `name`. + absl::Status ReadVariableInput(absl::string_view name, DataType type, + TensorShape* shape, xla::XlaOp* value); + + // Assigns the value `handle` to the variable referenced by input + // `input_index`. The variable must be of `type`. Returns an error if the + // variable has been initialized with a different type or with a + // different shape. + absl::Status AssignVariable(int input_index, DataType type, + xla::XlaOp handle); + // Assigns the value `handle` to the variable referenced by input `name`. + absl::Status AssignVariable(absl::string_view name, DataType type, + xla::XlaOp handle); + + // Helper routines for the OP_REQUIRES macros + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); + + // If this kernel invocation is within a function execution, + // call_frame() returns the call frame for the function call. + CallFrameInterface* call_frame() const { return context_->call_frame(); } + + FunctionLibraryRuntime* function_library() const { + return context_->function_library(); + } + + const OpKernel& op_kernel() const { return context_->op_kernel(); } + + // Returns the underlying OpKernelContext. Use rarely. + OpKernelContext* op_kernel_context() const { return context_; } + + // Returns the XlaCompiler that is performing the compilation. Used for, e.g., + // While to compile nested computations. + XlaCompiler* compiler() const; + + // TODO(phawkins): find a better home for these helpers. + + // Gets an XLA lambda to compute Max. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateMax(const DataType type); + + // Gets an XLA lambda to compute Min. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateMin(const DataType type); + + // Gets an XLA lambda to compute Add. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateAdd(const DataType type); + + // Gets an XLA lambda to compute LogAddExp. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateLogAddExp(const DataType type); + + // Gets an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. 
There is a + // separate specialization of the computation for each DataType. + const xla::XlaComputation* GetOrCreateMul(const DataType type); + + // Returns stack trace encoded as a string at a given module, or an empty + // string if none found. + std::string StackTrace() const; + + private: + // Returns the tensor of input `name`. + const Tensor& GetInputTensorByName(absl::string_view name); + // Evaluates input `index`, reshapes it to `new_shape` if new_shape != + // InputShape(index), and stores it in `*constant_literal`. If the input + // cannot be evaluated, e.g., because it depends on unbound parameters, + // returns a non-Ok status. If InputShape(index).num_elements() != + // new_shape.num_elements(), returns an error status. + absl::Status ConstantInputReshaped( + int index, absl::Span new_dims, + xla::Literal* constant_literal, + xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); + + OpKernelContext* const context_; + bool dynamic_dimension_is_minus_one_; + xla::ValueInference value_inference_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_op_registry.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_op_registry.h new file mode 100644 index 00000000..11bbbf2b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -0,0 +1,440 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" +#include "tsl/platform/errors.h" + +namespace tensorflow { + +// Names of the XLA compilation devices. These are not user-visible, and are +// used internally by the Tensorflow/XLA bridge to perform symbolic execution of +// a Tensorflow graph. 
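// Sketch of how a TensorFlow device is wired up for XLA compilation with the
// registry declared below (illustrative; the device name "/device:MY_ACCEL"
// is hypothetical):
//
//   XlaOpRegistry::DeviceRegistration registration;
//   registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
//   registration.autoclustering_policy =
//       XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested;
//   XlaOpRegistry::RegisterCompilationDevice("/device:MY_ACCEL", registration);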
+ +extern const char* const DEVICE_CPU_XLA_JIT; // "CPU_XLA_JIT" +extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" + +extern const char* const DEVICE_XLA_CPU; +extern const char* const DEVICE_XLA_GPU; + +// Do not include DT_FLOAT8_* as float or numeric types since they are only +// supported in a very limited set of ops. +constexpr std::array kFloatTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; +constexpr std::array kFloatAndComplexTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128}}; +constexpr std::array kNumericTypes = { + {DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32, + DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, + DT_BFLOAT16}}; + +constexpr std::array kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, + DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, + DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, + DT_INT4, DT_UINT4}}; + +constexpr std::array kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, + DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32, + DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, + DT_INT4, DT_UINT4}}; + +// Class that manages registrations of operators and devices for the XLA JIT. +// Not thread-safe. +class XlaOpRegistry { + public: + typedef OpKernel* (*Factory)(OpKernelConstruction*); + + enum class AutoclusteringPolicy { + // Enable autoclustering if the user requests it, e.g., via + // experimental_jit_scope. Does not autocluster if the JIT is enabled + // globally (e.g., via the OptimizerOptions in the TF session + // configuration.) + kIfExplicitlyRequested, + // Enable autoclustering if explicitly requested, or if the JIT is enabled + // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N. + kIfEnabledGlobally, + // Always try to autocluster ops placed on this device. + kAlways, + }; + + // Describes how to compile operators assigned to a device. + struct DeviceRegistration { + // The name of the an XLA compilation device to use to compile code. + string compilation_device_name; + + // When should we autocluster operators assigned to this device? + AutoclusteringPolicy autoclustering_policy; + + // If we should ignore the resource variable memory model when clustering + // resource variable reads and writes placed on this device. + bool cluster_resource_variable_ops_unsafely = false; + + // If we should auto-cluster Stack operations placed on this device. + bool cluster_stack_ops = false; + + // If we should auto-cluster TensorArray operations placed on this device. + bool cluster_tensor_array_ops = false; + + // If we should auto-cluster stateful RNG operations placed on this device. + // Stateful RNG semantics are not properly supported by XLA so it is not + // necessarily correct to auto-cluster stateful RNG ops in general. + bool cluster_stateful_rng_ops = false; + + // If we should auto-cluster ControlTrigger operations placed on this + // device. ControlTrigger operations are not necessarily safe to cluster + // since they affect deadness (a dead ControlTrigger produces a live + // output). + bool cluster_control_trigger = false; + + // If we should cluster Assert and CheckNumerics by eliding them (XLA does + // not natively support Assert or CheckNumerics). 
+ bool elide_assert_and_checknumerics = false; + + // If we should cluster operations returning DT_VARIANT. + bool cluster_variant_ops = false; + + // Whether ops known to be slow should be auto-clustered. + bool cluster_slow_ops = false; + + // Whether ops known to have numerical accuracy issues should be + // auto-clustered. + bool cluster_inaccurate_ops = false; + }; + + // Registers an XLA backend. `compilation_device_name` is the name of the + // device used for symbolic execution during compilation. `supported_types` + // is the list of non-resource types supported by the device. Each operators + // will be registered for the intersection of the operator's supported types + // and the device's supported types. `backend_op_filter` is a function used + // to exclude or modify operator registrations on the device; it may be + // nullptr, in which case all ops are included. + // `backend_op_filter` should return true if the op should be registered on + // the device; it may optionally modify the KernelDef. + typedef bool (*BackendOpFilter)(KernelDef* kdef); + static void RegisterBackend(const string& compilation_device_name, + absl::Span supported_types, + BackendOpFilter op_filter); + + // Returns the names of the registered backends. + static std::vector BackendNames(); + + // Returns true iff a backend with the given name is registered. + static bool IsBackendRegistered(const string& name); + + // Registers `device_name` for XLA compilation, using information from + // `registration`. + // Does nothing if a registration for `device_name` already exists. + static void RegisterCompilationDevice(const string& device_name, + const DeviceRegistration& registration); + + // Returns whether the device name is for the JIT device used exclusively for + // TF2XLA conversion. + static bool IsCompilationDevice(const string& device_name); + + // Returns the JIT device name associated with 'device_name', setting + // 'jit_device_name', 'requires_jit', and 'enabled_jit_by_default', if they + // are not null. Returns false and leaves the outputs unchanged if no matching + // JIT device is registered. + // '*enable_jit_by_default' is set to true if we should try to JIT using this + // device when the JIT is enabled via the Session OptimizerOptions. + static bool GetCompilationDevice(const string& device_name, + const DeviceRegistration** registration); + + // Registers all JIT kernels on JIT devices, if not already registered. + // Does nothing otherwise. + static void RegisterCompilationKernels(); + + // Returns KernelDefs for compilation ops registered on + // 'compilation_device_name'. Does not include kernels registered as + // CompilationOnly, iff include_compilation_only_kernels=false. + static std::vector DeviceKernels( + const string& compilation_device_name, + bool include_compilation_only_kernels); + + // Returns all operations for which there are XLA kernels on any device. + static std::vector GetAllRegisteredOps(); + + // Returns (via `result`) the indices of inputs to `node_def` that must be + // compile-time constants. Returns an empty vector if the op is not + // registered. + // + // `result` is sorted. 
+ static absl::Status CompileTimeConstantInputs(const NodeDef& node_def, + const OpDef& op_def, + std::vector* result) { + return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def, + result); + } + + static absl::StatusOr> CompileTimeConstantInputs( + const NodeDef& node_def, const OpDef& op_def) { + std::vector out; + TF_RETURN_IF_ERROR(CompileTimeConstantInputs(node_def, op_def, &out)); + return out; + } + + // Returns (via `result`) the indices of inputs to `op_kernel` that must be + // compile-time constants. + // + // `result` is sorted. + static absl::Status CompileTimeConstantInputs(const OpKernel& op_kernel, + std::vector* result) { + return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel, + /*op_def=*/nullptr, result); + } + + // Return names of arguments for a given op which are supposed to be + // constants. + static const std::unordered_set* + CompileTimeConstantInputArgNames(const string& op); + + // Returns true if `op` is a "metadata" op, one that only looks at the shapes + // of its operands and not their values. + static bool IsMetadataOp(const string& op); + + private: + friend class XlaBackendRegistrar; + friend class XlaOpRegistrar; + friend class XlaOpRegistrationBuilder; + + static XlaOpRegistry& Instance(); + + XlaOpRegistry(); + ~XlaOpRegistry(); + + mutex mutex_; + + // Describes an XLA backend. + struct Backend { + // Which types are supported by this device? + std::set supported_types; + + // The per-backend operator filter function. See the comment on + // RegisterBackend() for details. + BackendOpFilter op_filter; + + // KernelDefs built by RegisterCompilationKernels() for each op supported + // by the device. + std::vector> kernel_defs; + }; + + // Map from compilation device names to a description of the backend. + std::unordered_map backends_ TF_GUARDED_BY(mutex_); + + // Map from Tensorflow device names to the corresponding JIT device metadata. + std::unordered_map compilation_devices_ + TF_GUARDED_BY(mutex_); + + // A description of a Tensorflow operator that can be compiled to XLA. + struct OpRegistration { + string name; + + // Should this operator be registered only on compilation devices, without a + // dummy kernel registered on the corresponding XLA device? + bool compilation_only = false; + + // Should we allow resource types for type attributes? Used by _Arg to + // allow DT_RESOURCE. + bool allow_resource_types = false; + + // Should we allow variant types for type attributes? Used by While to + // allow TensorList which is of type DT_VARIANT. + bool allow_variant_types = false; + + // Should we allow string type for type attributes? Used by PartitionedCall + // to allow DT_STRING. + bool allow_string_type = false; + + // Mapping from attribute name to a list of supported types. + std::unordered_map> type_constraints; + + // An optional allowlist of devices. If there is no allowlist, all devices + // are permitted. + bool has_device_allowlist = false; + std::unordered_set device_allowlist; + + // Names of arguments that must be compile-time constants. + std::unordered_set compile_time_constant_inputs; + + // True if this is a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + bool is_metadata_op = false; + + std::string label; + + // Factory used to build OpKernels that perform symbolic execution. + Factory factory; + }; + + // Returns true if registrations x and y can both be added to the registry. + // This is always the case if they refer to different ops. 
If they refer to + // the same op name, they must: have the same values for compilation_only, + // allow_resource_types and allow_variant_types; use a device_allowlist; and + // their allowlists must not intersect. + static bool IsCompatible(const OpRegistration& x, const OpRegistration& y); + + static absl::Status CompileTimeConstantInputs(const NodeDef& node_def, + const OpKernel* op_kernel, + const OpDef* op_def, + std::vector* result); + + // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. + // Registrations present under the same key must satisfy IsCompatible above, + // and this is checked during registration. + std::unordered_map>> ops_ + TF_GUARDED_BY(mutex_); + + // Have we already registered the JIT kernels on the JIT devices? + bool jit_kernels_registered_ = false; + + // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel + // registrations created by RegisterCompilationKernels() and + // RegisterDeviceKernels(). + std::vector> + kernel_registrars_ TF_GUARDED_BY(mutex_); +}; + +// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example: +// REGISTER_XLA_OP(Name("Add"), AddOp); +// where 'AddOp' is the name of a JIT OpKernel class that implements "Add". +// +// We don't use a variadic macro here because we don't expect JIT operators to +// be templated. + +#define REGISTER_XLA_OP(NAME, OP) \ + REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP) + +#define REGISTER_XLA_CONV_OP(BUILDER, OP) \ + REGISTER_XLA_OP(BUILDER.TypeConstraint("T", GetXlaConvTypesForNonGpu()), OP) \ + REGISTER_XLA_OP(BUILDER.TypeConstraint("T", GetXlaConvTypesForGpu()) \ + .Device(DEVICE_GPU_XLA_JIT), \ + OP) + +class XlaOpRegistrationBuilder { + public: + // Starts an operator registration chain. + static XlaOpRegistrationBuilder Name(absl::string_view name); + + // Specifies a allowlist of devices on which the operator may run. + XlaOpRegistrationBuilder& Device(absl::string_view devices); + XlaOpRegistrationBuilder& Device(absl::Span devices); + + // Specifies a type constraint for a type variable attribute. Each constraint + // specifies the set of types that the type variable may assume. + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, + DataType allowed); + + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, + absl::Span allowed); + + // Specifies that a dummy copy of this operator should not be registered on + // XLA_* devices, but may be used during compilation. + XlaOpRegistrationBuilder& CompilationOnly(); + + // Allow DT_RESOURCE types for type parameters. + XlaOpRegistrationBuilder& AllowResourceTypes(); + + // Allow DT_VARIANT types for type parameters. + XlaOpRegistrationBuilder& AllowVariantTypes(); + + // Allow DT_STRING type for type parameters. + XlaOpRegistrationBuilder& AllowStringType(); + + // Mark 'input_name' as an argument whose value must be known at compile-time. + XlaOpRegistrationBuilder& CompileTimeConstantInput( + absl::string_view input_name); + + // Mark this op as a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + XlaOpRegistrationBuilder& IsMetadataOp(); + + // Specifies a particular value for the "_kernel" attr. + XlaOpRegistrationBuilder& Label(std::string label); + + std::unique_ptr Build( + XlaOpRegistry::Factory factory); + + private: + XlaOpRegistrationBuilder(absl::string_view name); + + std::unique_ptr registration_; +}; + +// REGISTER_XLA_BACKEND() registers an XLA backend. 
Example usage: +// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter); +#define REGISTER_XLA_BACKEND(NAME, ...) \ + REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__) + +// Implementation details. + +class XlaOpRegistrar { + public: + XlaOpRegistrar(std::unique_ptr registration); +}; + +#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \ + REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP) + +#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP) \ + static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \ + ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build( \ + [](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { return new OP(context); })); + +class XlaBackendRegistrar { + public: + XlaBackendRegistrar(absl::string_view name, absl::Span types, + XlaOpRegistry::BackendOpFilter op_filter = nullptr); +}; + +#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \ + REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__) + +#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \ + static ::tensorflow::XlaBackendRegistrar \ + xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_resource.h b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_resource.h new file mode 100644 index 00000000..d4c8f7c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/compiler/tf2xla/xla_resource.h @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/managed_stack_trace.h" + +namespace tensorflow { + +// Represents a resource, such as a Variable or TensorArray. +class XlaResource { + public: + enum Kind { + kInvalid, + kVariable, + kTensorArray, + kStack, + }; + static absl::string_view KindToString(Kind kind); + + // Creates a new Stack resource. + static std::unique_ptr CreateStack(string name, DataType type, + int64_t max_size); + + // Creates a new TensorArray resource. 
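A minimal sketch of how the registration macros and builder chain from xla_op_registry.h above are typically combined. `FooOp`, the op name "Foo", and its "axis" input are hypothetical, and `XlaOpKernel`/`XlaOpKernelContext` come from the companion xla_op_kernel.h header rather than from this file:

```cpp
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"    // XlaOpKernel, XlaOpKernelContext
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"  // REGISTER_XLA_OP, builder

namespace tensorflow {
namespace {

// Hypothetical kernel that performs symbolic (XLA) execution of op "Foo".
class FooOp : public XlaOpKernel {
 public:
  explicit FooOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    // Emits XLA ops instead of computing values; identity, for illustration.
    ctx->SetOutput(0, ctx->Input(0));
  }
};

// The chain resolves to XlaOpRegistrationBuilder::Name("Foo")... as described
// above; "axis" is marked as a compile-time constant input.
REGISTER_XLA_OP(Name("Foo")
                    .TypeConstraint("T", DT_FLOAT)
                    .CompileTimeConstantInput("axis"),
                FooOp);

}  // namespace
}  // namespace tensorflow
```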
+ static std::unique_ptr CreateTensorArray( + string name, DataType type, TensorShape shape, xla::XlaOp initial_value, + int64_t max_array_size); + + XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, xla::XlaOp initial_value, + int64_t max_array_size, + const std::set& tensor_array_gradients, + bool tensor_array_multiple_writes_aggregate, + const std::optional& definition_stack_trace = + std::nullopt); + + XlaResource(const XlaResource&) = delete; + XlaResource(XlaResource&&) = delete; + XlaResource& operator=(const XlaResource&) = delete; + XlaResource& operator=(XlaResource&&) = delete; + + Kind kind() const { return kind_; } + + // If this resource is visible externally to the computation, what was its + // argument number? + // < 0 means "not visible externally". + int arg_num() const { return arg_num_; } + + // A descriptive name for the resource, used in error messages. + const string& name() const { return name_; } + + // Current type and value of the resource. Uninitialized resources are + // represented by a default (zero) handle and type DT_INVALID. + // While the type of a resource is notionally fixed during execution, when + // a resource is first initialized we do not yet know its type, so we keep + // track of its type dynamically. + DataType type() const { return type_; } + + // Shape of the resource. For an uninitialized resource, this is ignored. + // For a Variable, this is the shape of the value. For a TensorArray or Stack + // this is the shape of each entry in the TensorArray/Stack. + const TensorShape& shape() const { return shape_; } + + const xla::XlaOp& value() const { return value_; } + + // Value of the resource at computation entry. Used to detect which + // variables have new values that need to be written back. + const xla::XlaOp& initial_value() const { return initial_value_; } + + // An xla shape that indicates how this resource variable is represented on + // device. + const std::optional& representation_shape() const { + return representation_shape_; + } + + // A variable is initialized if it has a value. + bool initialized() const { return value_.valid(); } + + // Sets the type and shape of the resource. The type and shape of a resource + // must not change once the variable has been initialized. + absl::Status SetTypeAndShape(DataType type, const TensorShape& shape); + + // Sets the current value of the resource. Returns an error if the type is not + // set to a valid value. + absl::Status SetValue(xla::XlaOp value); + + // Sets the current value of the resource to an all-zero value. + absl::Status SetZeroValue(xla::XlaBuilder* builder); + + // Sets the representational shape of the resource on device. + void SetRepresentationShape(const xla::Shape& shape) { + representation_shape_ = absl::make_optional(shape); + } + + // Looks up the gradient for `source`, or creates it if it does not already + // exist. The call target must be an initialized TensorArray resource. A + // TensorArray can have multiple named gradients; see the operator + // documentation for TensorArrayGradV3 for details. + absl::Status GetOrCreateTensorArrayGradient(const string& source, + xla::XlaBuilder* builder, + XlaResource** gradient_out); + + // Packs a resource into a single XLA value `pack`, suitable for use as + // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without + // gradients, sets `*pack` to `value`. 
+ // For TensorArrays with gradients, packs the value and its gradient values in + // a tuple; the gradients values are packed in order by source name. + absl::Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; + + // Updates the resource with values from `pack`. If `gradient_sources` is + // non-empty, treats `pack` as a tuple that represents a TensorArray and + // its gradients, and unpacks and updates the gradient resources. + // If `reset_initial_values` is true, sets the initial_values as well as the + // values. + // Opposite of Pack(). + absl::Status SetFromPack(const std::set& gradient_sources, + xla::XlaOp pack, xla::XlaBuilder* builder); + + bool IsOverwritten() { return is_overwritten_; } + + // TensorArray and Stack specific fields + // TODO(phawkins): refactor this code to use subclasses, rather than putting + // kind-specific fields in XlaResource. + + // 'max_array_size' stores the expected size of the TensorArray or Stack. + // We need to store this since sometimes TensorArrays must be initialized + // lazily since we do not know the element shape at construction time. + // Used by both TensorArrays and Stacks. + int64_t max_array_size() const { return max_array_size_; } + void set_max_array_size(int64_t size) { max_array_size_ = size; } + + bool tensor_array_multiple_writes_aggregate() const { + return tensor_array_multiple_writes_aggregate_; + } + + // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes + // to an XlaResource containing the gradient TensorArrays. We store a pointer + // here since there should only be one gradient TensorArray per 'source' + // string, irrespective of the number of calls to TensorArrayGrad. The map + // is ordered since values are packed into tuples by Pack() sorted by name + // order. + const std::map>& tensor_array_gradients() + const { + return tensor_array_gradients_; + } + + private: + const Kind kind_; + const int arg_num_; + const string name_; + + DataType type_; + TensorShape shape_; + xla::XlaOp value_; + xla::XlaOp initial_value_; + + // An xla shape that indicates how this resource variable is represented on + // device. + std::optional representation_shape_; + + int64_t max_array_size_ = -1; + bool tensor_array_multiple_writes_aggregate_ = false; + + std::map> tensor_array_gradients_; + bool is_overwritten_ = false; + + std::optional definition_stack_trace_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/activity_watcher/activity.h b/third_party/tflite-hdrs/tensorflow/core/activity_watcher/activity.h new file mode 100644 index 00000000..eecd207a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/activity_watcher/activity.h @@ -0,0 +1,186 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_ACTIVITY_WATCHER_ACTIVITY_H_ +#define TENSORFLOW_CORE_ACTIVITY_WATCHER_ACTIVITY_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/tsl/platform/macros.h" +#include "xla/tsl/platform/types.h" + +namespace tsl { +class CoordinationServiceAgent; +} + +namespace tensorflow { + +namespace activity_watcher { + +using ActivityId = tsl::uint64; +constexpr ActivityId kActivityNotRecorded = 0; +constexpr int kWatcherDisabled = 0; + +enum ActivityCategory { + kCollective = 0, + kRemoteFunction = 1, + kMisc = 2, + kDatasetOp = 3, + kTpuOp = 4, + kRendezvous = 5, +}; + +static tsl::string ToString(ActivityCategory category) { + switch (category) { + case ActivityCategory::kCollective: + return "Collective"; + case ActivityCategory::kRemoteFunction: + return "Remote Function"; + case ActivityCategory::kMisc: + return "Miscellaneous"; + case ActivityCategory::kDatasetOp: + return "Dataset Op"; + case ActivityCategory::kTpuOp: + return "TPU Op"; + case ActivityCategory::kRendezvous: + return "Rendezvous"; + } +} + +// An activity to be recorded. +struct Activity { + using Attributes = absl::flat_hash_map; + // A human readable title of the activity. + tsl::string title; + // The category of the activity. + ActivityCategory category = ActivityCategory::kMisc; + // Key/value pairs that are attached to the activity. + Attributes attributes; + Activity() = default; + Activity(tsl::string title, ActivityCategory category) + : title(std::move(title)), category(category) {} + Activity(tsl::string title, ActivityCategory category, Attributes attributes) + : title(std::move(title)), + category(category), + attributes(std::move(attributes)) {} +}; + +// Enable activity wathcer to send own workers activities to coordination +// service and also fetch all workers' activities. +void MaybeEnableMultiWorkersWatching(tsl::CoordinationServiceAgent* agent); + +namespace tfw_internal { + +#if defined(TF_ENABLE_ACTIVITY_WATCHER) + +// Records an activity start without checking whether the watcher is enabled. +ActivityId RecordActivityStart(std::unique_ptr activity); +// Records an activity end without checking whether the activity_id is valid. +void RecordActivityEnd(ActivityId activity_id); + +TF_EXPORT extern std::atomic g_watcher_level; + +// Returns whether the activitity watcher is enabled. +inline bool WatcherEnabled(int level = 1) { + return g_watcher_level.load(std::memory_order_acquire) >= level; +} + +#endif + +// NOTE: Borrowed from boost C++ libraries because std::is_invocable_r is not +// available in Android NDK. +template +struct is_invocable_r + : std::is_constructible< + std::function, + std::reference_wrapper::type>> {}; + +} // namespace tfw_internal + +template +constexpr bool is_activity_generator = + tfw_internal::is_invocable_r, F>::value; + +// Records an activity explicitly. Useful when the start and end of an activity +// happen in different threads. Generates the Activity only if activity +// watching is enabled, useful for avoiding expensive operations when activity +// watching is disabled. 
+// Example Usage: +// auto aid = ActivityStart([&]() { +// return std::make_unique( +// op_name, category, +// Activity::Attributes{{"key1", value1}, {"key2", value2}}); +// }, /*level=*/2); +// DoSomething(); +// ActivityEnd(aid); +template < + typename ActivityGenerator, + std::enable_if_t, bool> = true> +inline ActivityId ActivityStart(ActivityGenerator&& gen, int level = 1) { +#if defined(TF_ENABLE_ACTIVITY_WATCHER) + if (TF_PREDICT_FALSE(tfw_internal::WatcherEnabled(level))) { + return tfw_internal::RecordActivityStart( + std::forward(gen)()); + } +#endif + return kActivityNotRecorded; +} + +inline void ActivityEnd(ActivityId id) { +#if defined(TF_ENABLE_ACTIVITY_WATCHER) + if (TF_PREDICT_FALSE(id != kActivityNotRecorded)) { + tfw_internal::RecordActivityEnd(id); + } +#endif +} + +// ActivityScope marks a scope as an activity and record it with a global +// ActivityRecorder. +// Example Usage: +// { +// ActivityScope activity_scope([&]() { +// return std::make_unique( +// op_name, ActivityCategory::kMisc, +// Activity::Attributes{{"key1", value1}, {"key2", value2}}); +// }, /*level=*/2); +// DoSomething(); +// } +class ActivityScope { + public: + template < + typename ActivityGenerator, + std::enable_if_t, bool> = true> + explicit ActivityScope(ActivityGenerator&& gen, int level = 1) { + activity_id_ = ActivityStart(std::forward(gen), level); + } + ActivityScope(ActivityScope&& activity) { + activity_id_ = activity.activity_id_; + activity.activity_id_ = kActivityNotRecorded; + } + ~ActivityScope() { ActivityEnd(activity_id_); } + + private: + ActivityId activity_id_; + ActivityScope(const ActivityScope&) = delete; + void operator=(const ActivityScope&) = delete; +}; + +} // namespace activity_watcher +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_ACTIVITY_WATCHER_ACTIVITY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/activity_watcher/activity_utils.h b/third_party/tflite-hdrs/tensorflow/core/activity_watcher/activity_utils.h new file mode 100644 index 00000000..64958cd5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/activity_watcher/activity_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_ACTIVITY_WATCHER_ACTIVITY_UTILS_H_ +#define TENSORFLOW_CORE_ACTIVITY_WATCHER_ACTIVITY_UTILS_H_ + +#include + +#include "xla/tsl/platform/types.h" +#include "tensorflow/core/activity_watcher/activity.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace activity_watcher { + +// A convenient way to create an activity. Writes OpKernelContext information +// and given attributes to a new activity and returns. 
+std::unique_ptr ActivityFromContext( + OpKernelContext* context, tsl::string name, ActivityCategory category, + Activity::Attributes additional_attributes = Activity::Attributes()); + +} // namespace activity_watcher +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_ACTIVITY_WATCHER_ACTIVITY_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/api_def/excluded_ops.h b/third_party/tflite-hdrs/tensorflow/core/api_def/excluded_ops.h new file mode 100644 index 00000000..409e5d32 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/api_def/excluded_ops.h @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_ +#define TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_ + +#include +#include + +namespace tensorflow { + +// Returns a list of ops excluded from ApiDef. +// TODO(annarev): figure out if we should keep ApiDefs for these ops as well +const std::unordered_set* GetExcludedOps(); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_API_DEF_EXCLUDED_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/api_def/update_api_def.h b/third_party/tflite-hdrs/tensorflow/core/api_def/update_api_def.h new file mode 100644 index 00000000..1e285c06 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/api_def/update_api_def.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ +#define TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ +// Functions for updating ApiDef when new ops are added. + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Returns ApiDefs text representation in multi-line format +// constructed based on the given op. +string CreateApiDef(const OpDef& op); + +// Removes .Doc call for the given op. +// If unsuccessful, returns original file_contents and prints an error. +// start_location - We search for .Doc call starting at this location +// in file_contents. +string RemoveDoc(const OpDef& op, const string& file_contents, + size_t start_location); + +// Creates api_def_*.pbtxt files for any new ops (i.e. ops that don't have an +// api_def_*.pbtxt file yet). 
+// If op_file_pattern is non-empty, then this method will also +// look for a REGISTER_OP call for the new ops and removes corresponding +// .Doc() calls since the newly generated api_def_*.pbtxt files will +// store the doc strings. +void CreateApiDefs(const OpList& ops, const string& api_def_dir, + const string& op_file_pattern); + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_API_DEF_UPDATE_API_DEF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/all_to_all.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/all_to_all.h new file mode 100644 index 00000000..f0fb1651 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/all_to_all.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ALL_TO_ALL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_ALL_TO_ALL_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device.h" + +namespace tensorflow { + +// Implementation of collective all-to-all. +class AllToAll : public CollectiveImplementationInterface { + public: + AllToAll(); + + void Run(StatusCallback done) override; + + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override { + return absl::OkStatus(); + } + + // Initializes members of CollectiveContext not yet initialized, i.e. device + // and device_locality. Also saves the CollectiveContext in this object. + absl::Status InitializeCollectiveContext( + std::shared_ptr col_ctx) override; + + private: + std::shared_ptr col_ctx_; + const CollectiveParams* col_params_; // Not owned + std::vector input_chunks_; + Tensor output_buffer_; + std::vector output_chunks_; + StatusCallback done_; + mutex mu_; + absl::Status status_ TF_GUARDED_BY(mu_); + int counter_ TF_GUARDED_BY(mu_); + + void DispatchSend(int src_rank, int target_rank, const Tensor* tensor, + const StatusCallback& done); + + void DispatchRecv(int src_rank, int target_rank, Tensor* tensor, + const StatusCallback& done); + + // Atomically increments counter_ by one for sending, one for receiving. + // Invokes done when counter_ reaches 2. + // The purpose of checking counter_ is to ensure that done_ is called once. + StatusCallback CheckCounterAndCallDone(); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ALL_TO_ALL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/allocator_retry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/allocator_retry.h new file mode 100644 index 00000000..842b82db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/allocator_retry.h @@ -0,0 +1,28 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ALLOCATOR_RETRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_ALLOCATOR_RETRY_H_ + +#include "xla/tsl/framework/allocator_retry.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::AllocatorRetry; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ALLOCATOR_RETRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/arg_ret_placement.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/arg_ret_placement.h new file mode 100644 index 00000000..e0b40182 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/arg_ret_placement.h @@ -0,0 +1,158 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ARG_RET_PLACEMENT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_ARG_RET_PLACEMENT_H_ + +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow::full_type { + +// Set the contents of memory_types for args (inputs to functions, "_Arg" ops) +// based on dtype. Raises an error if an int32 arg does not have +// expected full_type information. If an error raised about bad full +// time information causes a breakage, changing `SetMemoryTypeForArgs` to +// `WeakSetMemoryTypeForArgs` is a possible work around. +absl::Status SetMemoryTypeForArgs(const absl::InlinedVector& nodes, + const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); + +// TODO(b/258849883) Delete the `Weak...` versions of these functions once +// everything is working with the version without `Weak`. + +// Set the contents of memory_types for args (inputs to functions, "_Arg" ops) +// based on dtype. Logging of warnings if an int32 arg does not have +// expected full_type information can be enabled. +absl::Status WeakSetMemoryTypeForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); + +// Set the contents of memory_types for rets (outputs from functions, "_Retval" +// ops) based on dtype. 
Raises an error if an int32 ret does not have +// expected full_type information (i.e. if the source of the input to the ret +// does not have expected full type information). If an error raised about bad +// full time information causes a breakage, changing `SetMemoryTypeForRets` to +// `WeakSetMemoryTypeForRets` is a possible work around. +absl::Status SetMemoryTypeForRets(const absl::InlinedVector& nodes, + const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); + +// Set the contents of memory_types for rets (outputs from functions, "_Retval" +// ops) based on dtype. Logging of warnings if an int32 ret does not have +// expected full_type information (i.e. if the source of the input to the ret +// does not have expected full type information) can be enabled. +absl::Status WeakSetMemoryTypeForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); + +// Set the contents of alloc_attrs for args (inputs to functions, "_Arg" ops) +// based on dtype. Raises an error if an int32 arg does not have +// expected full_type information. If an error raised about bad full +// time information causes a breakage, changing `SetAllocAttrsForArgs` to +// `WeakSetAllocAttrsForArgs` is a possible work around. +absl::Status SetAllocAttrsForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); + +// Set the contents of alloc_attrs for args (inputs to functions, "_Arg" ops) +// based on dtype. Logging of warnings if an int32 arg does not have +// expected full_type information can be enabled. +absl::Status WeakSetAllocAttrsForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); + +// Set the contents of alloc_attrs for rets (outputs from functions, "_Retval" +// ops) based on dtype. Raises an error if an int32 ret does not have +// expected full_type information (i.e. if the source of the input to the ret +// does not have expected full type information). If an error raised about bad +// full time information causes a breakage, changing `SetAllocAttrsForRets` to +// `WeakSetAllocAttrsForRets` is a possible work around. +absl::Status SetAllocAttrsForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); + +// Set the contents of alloc_attrs for rets (outputs from functions, "_Retval" +// ops) based on dtype. Logging of warnings if an int32 ret does not have +// expected full_type information (i.e. if the source of the input to the ret +// does not have expected full type information) can be enabled. +absl::Status WeakSetAllocAttrsForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); + +// Set the contents of alloc_attrs for args (inputs to functions, "_Arg" ops) +// for a single device funtion based on dtype. Raises an error if an int32 arg +// does not have expected full_type information. If an error raised about bad +// full time information causes a breakage, changing +// `SingleDeviceSetAllocAttrsForArgs` to `WeakSingleDeviceSetAllocAttrsForArgs` +// is a possible work around. The DataType specified by the "T" attr of input +// nodes is used. +absl::Status SingleDeviceSetAllocAttrsForArgs( + std::vector> arg_nodes, + bool ints_on_device, std::vector& alloc_attrs); + +// Set the contents of alloc_attrs for args (inputs to functions, "_Arg" ops) +// for a single device based on dtype. 
Logging of warnings if an int32 arg does +// not have expected full_type information can be enabled. The DataType +// specified by the "T" attr of input nodes is used. +absl::Status WeakSingleDeviceSetAllocAttrsForArgs( + std::vector> arg_nodes, + bool ints_on_device, std::vector& alloc_attrs); + +// Set the contents of alloc_attrs for rets (outputs from functions, "_Retval" +// ops) for a single device based on dtype. Raises an error if an int32 ret does +// not have expected full_type information (i.e. if the source of the input to +// the ret does not have expected full type information). If an error raised +// about bad full time information causes a breakage, changing +// `SingleDeviceSetAllocAttrsForRets` to `WeakSingleDeviceSetAllocAttrsForRets` +// is a possible work around. The DataType specified by the "T" attr of input +// nodes is used. +absl::Status SingleDeviceSetAllocAttrsForRets( + std::vector> ret_nodes, bool ints_on_device, + std::vector& alloc_attrs); + +// Set the contents of alloc_attrs for rets (outputs from functions, "_Retval" +// ops) for a single device based on dtype. Logging of warnings if an int32 ret +// does not have expected full_type information (i.e. if the source of the input +// to the ret does not have expected full type information) can be enabled. The +// DataType specified by the "T" attr of input nodes is used. +absl::Status WeakSingleDeviceSetAllocAttrsForRets( + std::vector> ret_nodes, bool ints_on_device, + std::vector& alloc_attrs); + +// Given a FullTypeId, return the corresponding MemoryTypes (i.e. return +// HOST_MEMORY for TFT_SHAPE_TENSOR, DEVICE_MEMORY othersize). +MemoryType MemoryTypeFromFullTypeId(FullTypeId id); + +// Check that use_host_memory is true iff FT has type_id TFT_SHAPE_TENSOR +// and logging of a warning if not can be enabled. Returns true if check passes. +// Note the FT is expected to be the full type information for a tensor, not for +// the whole ouput of an op, i.e. it should not have an outer TFT_PRODUCT. +bool LogMemoryTypeMismatch(bool use_host_memory, const FullTypeDef& ft); + +// Check that use_host_memory is true iff FT has type_id TFT_SHAPE_TENSOR +// and raise an error if not. Note the FT is expected to be the full type +// information for a tensor, not for the whole ouput of an op, i.e. it should +// not have an outer TFT_PRODUCT. +absl::Status CheckMemoryType(bool use_host_memory, const FullTypeDef& ft); + +} // namespace tensorflow::full_type + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ARG_RET_PLACEMENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/base_collective_executor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/base_collective_executor.h new file mode 100644 index 00000000..0c4689bc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/base_collective_executor.h @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/buf_rendezvous.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" + +namespace tensorflow { +class CollectiveImplementation; +class DeviceMgr; +class Device; + +// Helper interface that aliases regular subfields of a Tensor as separate +// Tensors for in-place update. +class CollectiveAdapter { + public: + virtual ~CollectiveAdapter() {} + + // Move the backing tensor to 'output' with its original storage and + // shape. After this call this CollectiveAdapter object should be + // deleted immediately without calling any of its other methods. + virtual void ConsumeFinalValue(Tensor* output) = 0; + + // const access to entire intermediate value for debugging + virtual const Tensor& Value() const = 0; + + // Returns tensor for chunk i which aliases the backing buffer. + virtual Tensor ChunkAlias(int i) = 0; + + // Returns tensor allocated on the same device but with its own + // separate backing buffer. Will have same type and size as + // chunk i. + virtual Tensor TempChunk(int i) const = 0; + + // Bytes in chunk i + virtual int64_t ChunkBytes(int i) const = 0; + + // Generate a CPU RAM scalar tensor of the same DataType as the + // backing tensor with the given integer value. + virtual Tensor Scalar(int v) const = 0; + + // Generate a scalar tensor of same DataType and on the same device + // as the backing tensor. + virtual Tensor Scalar(Allocator* a, + const AllocationAttributes& attr) const = 0; + + // Debugging string describing buffer location + virtual string TBounds(const Tensor& t) const = 0; + + virtual string DebugString() const = 0; + + // Computes the number of elements per alias chunk tensor. + // + // A CHECK in tensor.cc expects that the memory buffer backing a + // Tensor will be aligned according to EIGEN_MAX_ALIGN_BYTES. To + // ensure that all chunk aliasing Tensors maintain this alignment we + // need to pick a chunk size that preserves it. Note than in extreme + // cases (impractical, but possible with very small tensors) one or + // more tail chunks can end up emptby. + static int64_t AlignedChunkElts(int64_t elt_bytes, int64_t total_elts, + int64_t num_chunks); +}; + +// Create a CollectiveAdaptor wrapping 'output', specialized to its +// data-type and shape. If align_chunks == true then chunk size may +// be larger than output->NumElements() / num_chunks and one or more +// of the suffix chunks may be empty. Chunks will be arranged to start +// and end on alignment boundaries. If align_chunks == false then +// output->NumElements() % num_chunks must be 0 and all chunks will +// have exactly the same size, ignoring alignment issues. +CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks, + Allocator* allocator, + bool align_chunks = true); + +// Default implementation of CollectiveExecutor. Delegates the actual +// work of moving data to a class specialized for the operation type, +// arguments and device+interconnect topology. 
+class BaseCollectiveExecutor : public CollectiveExecutor { + public: + BaseCollectiveExecutor(CollectiveExecutorMgrInterface* cem, + CollectiveRemoteAccess* remote_access, int64_t step_id, + const DeviceMgr* dev_mgr, + std::shared_ptr work_queue) + : CollectiveExecutor(cem), + step_id_(step_id), + dev_mgr_(dev_mgr), + remote_access_(remote_access), + work_queue_(std::move(work_queue)) {} + + ~BaseCollectiveExecutor() override; + + void StartAbort(const absl::Status& s) override TF_LOCKS_EXCLUDED(status_mu_); + + void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params, + const string& exec_key, StatusCallback done) override; + + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, + StatusCallback done) override; + + CollectiveRemoteAccess* remote_access() override { + return remote_access_.get(); + } + + void RunClosure(std::function closure) override { + work_queue_->Schedule(std::move(closure)); + } + + // If we need to enforce an ordering on any portion of collective + // implementation, and the ordering is encoded via attribute on the collective + // op, this function will block until all dependencies for this collective + // have completed. + void WaitForDependencies(const CollectiveParams& col_params) override; + // Record that this collective has completed the portion of the implementation + // that needs to be ordered wrt other collectives, to unblock any of its + // dependent ops. + void UnblockDependencies(const CollectiveParams& col_params) override; + + protected: + const int64_t step_id_; + const DeviceMgr* dev_mgr_; // Not owned. + std::unique_ptr remote_access_; + // Ownership of `work_queue_` is shared between `this` and + // `CollectiveExecutorMgr`. + std::shared_ptr work_queue_; + mutex launch_mu_; + condition_variable launch_cv_; + // collective instance key -> number of local devices for which NCCL ops have + // been launched. + std::unordered_map launched_ TF_GUARDED_BY(launch_mu_); + mutex status_mu_; + absl::Status status_ TF_GUARDED_BY(status_mu_); + + private: + absl::Status CreateCollective(const CollectiveParams& col_params, + CollectiveImplementationInterface** col_impl); + // Check if all ops on which this collective depends on have launched. + bool CheckDependencies(const CollectiveParams& col_params) + TF_EXCLUSIVE_LOCKS_REQUIRED(launch_mu_); + // Tries to return the status that is the original error. It returns the + // aborted status if the collective executor is aborted. + absl::Status GetStatus(const absl::Status& s) TF_LOCKS_EXCLUDED(status_mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/bfc_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/bfc_allocator.h new file mode 100644 index 00000000..c8becd4c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/bfc_allocator.h @@ -0,0 +1,45 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "xla/tsl/framework/bfc_allocator.h" +#include "tensorflow/core/common_runtime/allocator_retry.h" +#include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class MemoryDump; // NOLINT +using tsl::BFCAllocator; // NOLINT + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/buf_rendezvous.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/buf_rendezvous.h new file mode 100644 index 00000000..8c2d201e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/buf_rendezvous.h @@ -0,0 +1,134 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +class Device; +class DeviceContext; +class DeviceMgr; +class Tensor; + +// EXPERIMENTAL: RDMA oriented producer/consumer rendezvous on a local +// Tensor value for which DMAHelper::CanUseDMA() is true, i.e. dense +// numeric types. Similar to Rendezvous but never owns a Ref on the +// tensor, instead it uses an explicit callback to the producer when +// the consumer side is finished with the value. This allows the +// producer to perform in-place updates on the source buffer or to take +// other actions that depend on knowing the consumer has passed a certain +// execution point. +class BufRendezvous { + public: + explicit BufRendezvous(uint64 step_id, const DeviceMgr* dev_mgr) + : step_id_(step_id), dev_mgr_(dev_mgr) {} + + virtual ~BufRendezvous(); + + // Inform all waiting parties that this BufRendezvous is defunct because of + // an error Status interrupting the Step. 
+ void StartAbort(const absl::Status& s); + + struct Hook; + // Provided by the consumer to be called when access to the buffer + // is available. If the Status arg is not OK, then hook will not + // be populated. Ownership of Hook passes to consumer with the + // callback. + typedef std::function ConsumerCallback; + // Provided by the producer to be called when the consumer has finished + // reading the buffer and will no longer access it. + typedef std::function ProducerCallback; + + struct Hook { + Device* prod_dev; + DeviceContext* prod_ctx; + const Tensor* prod_value; + AllocatorAttributes prod_attr; + ProducerCallback prod_cb; + ConsumerCallback cons_cb; + CancellationManager* cancellation_manager; + CancellationToken cancellation_token; + explicit Hook(CancellationManager* cancellation_manager, + CancellationToken cancellation_token) + : prod_dev(nullptr), + prod_ctx(nullptr), + prod_value(nullptr), + prod_cb(nullptr), + cons_cb(nullptr), + cancellation_manager(cancellation_manager), + cancellation_token(cancellation_token) {} + string DebugString() const; + }; + + // Called to advertise availability of a Tensor value corresponding + // to key. That value must stay valid until done is called. + // + // If a non-null cancellation manager is provided, this function registers a + // callback to delete the hook and invoke provider/consumer callbacks with + // cancelled error. + void ProvideBuf(const string& key, Device* dev, DeviceContext* dev_ctx, + const Tensor* v, const AllocatorAttributes& attr, + const ProducerCallback& done, + CancellationManager* cancellation_manager); + + // Called to request access to a Tensor value corresponding to key. + // Consumer is provided with a Hook as soon as available. + // + // This function also checks that the current incarnation number of the + // `device` that produced this value matches the `incarnation` expected by the + // consumer, and invokes `done` with `FailedPrecondition` status and + // `nullptr` hook if it does not match. + // + // If a non-null cancellation manager is provided, this function registers a + // callback to delete the hook and invoke provider/consumer callbacks with + // cancelled error. + virtual void ConsumeBuf(const string& key, const string& device, + const uint64 incarnation, + const ConsumerCallback& done, + CancellationManager* cancellation_manager); + + // Cancel the rendezvous entry corresponding to `key`. Triggered by the + // cancellation manager. No-op if the rendezvous was already successful. + void CancelHook(const string& key); + + // Consumer must call this function when it's done reading the Hook provided + // by the ConsumerCallback. This function will invoke the producer callback + // and then delete h. + static void DoneWithHook(Hook* h); + + // Write the current contents of the table to the INFO log. + void LogContents(); + + protected: + const uint64 step_id_; + const DeviceMgr* const dev_mgr_; // Not owned. 
+ mutex mu_; + absl::Status status_ TF_GUARDED_BY(mu_); + typedef absl::flat_hash_map HookTable; + HookTable hook_table_ TF_GUARDED_BY(mu_); + + void PurgeTable(const absl::Status& s, HookTable* table); +}; +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/build_graph_options.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/build_graph_options.h new file mode 100644 index 00000000..f33d43fb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/build_graph_options.h @@ -0,0 +1,48 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BUILD_GRAPH_OPTIONS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_BUILD_GRAPH_OPTIONS_H_ + +#include + +#include "tensorflow/core/graph/collective_order.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +struct BuildGraphOptions { + CallableOptions callable_options; + + // If `true`, uses Arg/Retval to implement feeds/fetches; otherwise + // uses Recv/Send to implement feeds/fetches. + // TODO(mrry): Remove this when the distributed runtime supports Arg/Retval. + bool use_function_convention = false; + + static constexpr int64_t kNoCollectiveGraphKey = 0; + int64_t collective_graph_key = kNoCollectiveGraphKey; + + // If not `kNone`, order all CollectiveReduce operations statically and + // deterministically. If `kEdges`, encode dependencies as explicit control + // edges, if `kAttrs` encode as attribute on collective op. + GraphCollectiveOrder collective_order = GraphCollectiveOrder::kNone; + + string DebugString() const; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BUILD_GRAPH_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_executor_mgr.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_executor_mgr.h new file mode 100644 index 00000000..dddaa7ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_executor_mgr.h @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" + +namespace tensorflow { +class ConfigProto; +class DeviceMgr; + +class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface { + public: + CollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver, + std::unique_ptr nccl_communicator); + + virtual ~CollectiveExecutorMgr(); + + CollectiveExecutor* FindOrCreate(int64_t step_id) override; + + void Cleanup(int64_t step_id) override; + + void CleanupAll() override; + + ParamResolverInterface* GetParamResolver() const override { + return param_resolver_.get(); + } + + DeviceResolverInterface* GetDeviceResolver() const override { + return dev_resolver_.get(); + } + + NcclCommunicatorInterface* GetNcclCommunicator() const override { + return nccl_communicator_.get(); + } + + void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + const StatusCallback& done) override; + + void RefreshStepIdSequenceAsync(int64_t graph_key, + const StatusCallback& done) override; + + int64_t NextStepId(int64_t graph_key) override { + return CollectiveExecutor::kInvalidId; + } + + void RetireStepId(int64_t graph_key, int64_t step_id) override {} + + protected: + // Called by FindOrCreate when table entry does not yet exist. + virtual CollectiveExecutor* Create(int64_t step_id); + + const DeviceMgr* dev_mgr_; + std::unique_ptr dev_resolver_; + std::unique_ptr param_resolver_; + string gpu_ring_order_; + std::unique_ptr nccl_communicator_; + // Unbounded work queue for scheduling potentially-blocking work during + // collective op execution. Ownership is shared between `this` and + // `CollectiveRemoteAccessLocal`. + std::shared_ptr work_queue_; + + private: + mutex exec_mu_; + // Map from step_id to CollectiveExecutor + gtl::FlatMap executor_table_ + TF_GUARDED_BY(exec_mu_); +}; + +// Creates a local CollectiveExecutorMgr with production implementations of each +// components. Cases that need to inject other implementations of these +// components should call CollectiveExecutorMgr constructor directly. This only +// supports a single host. For distributed use case, use +// CreateProdRpcCollectiveExecutorMgr() instead. +std::unique_ptr CreateProdLocalCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* device_mgr, + std::unique_ptr nccl_communicator); + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_param_resolver_local.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_param_resolver_local.h new file mode 100644 index 00000000..88813b0e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -0,0 +1,215 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +class CompleteGroupRequest; +class CompleteGroupResponse; +class CompleteInstanceRequest; +class CompleteInstanceResponse; +class ConfigProto; +class DeviceMgr; + +// Implements ParamResolverInterface for a single-task context. +// It also implements the functionality necessary to serve as the +// group leader for param resolution in a multi-task context. +class CollectiveParamResolverLocal : public ParamResolverInterface { + public: + CollectiveParamResolverLocal(const ConfigProto& config, + const DeviceMgr* dev_mgr, + DeviceResolverInterface* dev_resolver, + NcclCommunicatorInterface* nccl_communicator, + const string& task_name); + + ~CollectiveParamResolverLocal() override {} + + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, + const StatusCallback& done) override; + + void CompleteGroupAsync(const DeviceAttributes& device, + CollGroupParams* group_params, + CancellationManager* cancel_mgr, + const StatusCallback& done) override; + + void CompleteInstanceAsync(const CompleteInstanceRequest* request, + CompleteInstanceResponse* response, + CancellationManager* cancel_mgr, + const StatusCallback& done) override; + + absl::Status LookupGroup(int32_t group_key, CollGroupParams* group) override; + + void StartAbort(const absl::Status& s) override; + + protected: + // For access to InstanceRec and CompleteDefaultRanking. + friend class CollectiveParamResolverLocalTest; + + // Used to complete/verify CollGroup. + struct GroupRec { + mutable mutex mu; + CollGroupParams group TF_GUARDED_BY(mu); + absl::Status status TF_GUARDED_BY(mu); + std::unordered_map incarnations_by_device_name + TF_GUARDED_BY(mu); + std::vector pending_params TF_GUARDED_BY(mu); + std::vector pending_done TF_GUARDED_BY(mu); + }; + + // Finds the GroupRec that corresponds to group_params->group_key. + // Also populates group_params from that group_rec. + // Will wait until GroupRec is fully populated or an error arises before + // calling done. Callback GroupRec* arg is only valid if status is ok. + // Ownership of GroupRec stays with this object and does not pass to the + // callback. + void CompleteGroupLocal(const DeviceAttributes& device, + CollGroupParams* group_params, + CancellationManager* cancel_mgr, StatusCallback done) + TF_LOCKS_EXCLUDED(group_mu_); + + // Finishes the group parameters once all members of the group are there. + void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu); + + // Cancels the group if it's still pending. 
+ void CancelGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); + + // Lookup and populate parameters from an already initialized group. + absl::Status LookupAndPopulateGroupParams(CollGroupParams* group_params); + + // Used to complete/verify CollInstance. + struct InstanceRec; + + typedef std::function IRConsumer; + struct InstanceRec { + mutex mu; + // Values to be shared by all instances, constant after initialization. + CollectiveParams* shared; + // If an error occurs during initialization this structure stays in the + // table with a non-OK status. Purging the table and restarting needs to be + // done at a higher level. + absl::Status status TF_GUARDED_BY(mu); + + // These fields are used to count the instances that have called + // in and become known while resolving broadcast source identity and + // communicator key. + int source_rank TF_GUARDED_BY(mu); + string communicator_key TF_GUARDED_BY(mu); + int known_count TF_GUARDED_BY(mu); + std::vector known TF_GUARDED_BY(mu); + std::vector known_waiters TF_GUARDED_BY(mu); + + InstanceRec() + : shared(new CollectiveParams()), source_rank(-1), known_count(0) {} + ~InstanceRec() { shared->Unref(); } + }; + + // Find the InstanceRec with the same instance_key as cp. If it doesn't + // already exist, create and initialize from gr and cp. + // created is set to true if a new IRec is created, false otherwise. + // + // Precondition: *gr must be a complete GroupRec, i.e. the value set + // by CompleteGroupLocal. *cp must be populated with all the fields + // required by InitInstanceSharedParams. Ownership of InstanceRec stays + // with this object and does not pass to the callback. + InstanceRec* GetOrCreateInstanceRec(CollectiveParams* cp, bool* created) + TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); + + // Populate *ir with device membership from gr, then initialize to be specific + // to cp->instance_key, i.e. order the devices and tasks. + // + // Preconditions: + // cp is populated with all DeviceLocalities + void InitInstanceSharedParams(const CollectiveParams* cp, InstanceRec* ir); + + // Establishes the final order of gp->device_names and gp->task_names by + // considering localities of all devices. + void CompleteDefaultRanking(CollGroupParams* gp); + + // Finish populating *cp. + // Precondition: *gr has been fully populated by CompleteGroupLocal. + void CompleteInstanceLocal(const string& device, CollectiveParams* cp, + const StatusCallback& done) + TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); + + // Finish populating *cp from fully initialized *ir. + // Precondition: *gr and *ir are fully populated. + void CompleteInstanceFromInitializedIRec(const string& device, + CollectiveParams* cp, + InstanceRec* ir, + const StatusCallback& done) + TF_LOCKS_EXCLUDED(ir->mu); + + // Complete instance params after waiting for group. + // Precondition: *cp has complete group data and default_rank. + void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, const IRConsumer& f) + TF_LOCKS_EXCLUDED(ir->mu); + + // If cp.device_names contains only devices local to this process + // populates *localities, else returns an error. + absl::Status GetLocalDeviceLocalities( + const CollectiveParams& cp, std::vector* localities); + + // Sets cp->instance_default_rank according to location of device in + // current ordering of cp->instance.device_names. + void SetDefaultRank(const string& device, CollectiveParams* cp); + + // Sets cp->instance.type based on collective op type, and attempts to assign + // best implementation. 
+ void AssignCollectiveType(CollectiveParams* cp); + + void StartAbortLocal(const absl::Status& s) + TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_); + + const bool nccl_; + const DeviceMgr* dev_mgr_; + DeviceResolverInterface* dev_resolver_; // Not owned. + NcclCommunicatorInterface* nccl_communicator_; // Not owned. + string task_name_; + string gpu_ring_order_; + mutex group_mu_; + gtl::FlatMap> group_table_ + TF_GUARDED_BY(group_mu_); + struct TupleHash { + std::size_t operator()(const std::tuple x) const { + // The hash does not need to be unique and a value of 20 is picked + // arbitrarily as an effort to reduce probability of conflicts. + return (std::get<0>(x) << 20) + std::get<1>(x); + } + }; + mutex instance_mu_; + gtl::FlatMap, + std::unique_ptr, TupleHash>> + instance_table_ TF_GUARDED_BY(instance_mu_); + mutex status_mu_; + absl::Status status_ TF_GUARDED_BY(status_mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_rma_local.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_rma_local.h new file mode 100644 index 00000000..2c51b87a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_rma_local.h @@ -0,0 +1,82 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_ + +#include "tensorflow/core/common_runtime/buf_rendezvous.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/rendezvous.h" + +namespace tensorflow { + +// Basic implementation of PerStepCollectiveRemoteAccess. 
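A hedged sketch of driving the resolver declared above: construct it for a single task and resolve the parameters of one collective instance. `config`, `dev_mgr`, `dev_resolver`, and `dev_attributes` are assumed to exist and are not defined in this header:

tensorflow::CollectiveParamResolverLocal resolver(
    config, dev_mgr, dev_resolver, /*nccl_communicator=*/nullptr,
    /*task_name=*/"/job:localhost/replica:0/task:0");

auto* col_params = new tensorflow::CollectiveParams();  // ref-counted
col_params->group.group_key = 1;
col_params->group.group_size = 2;

tensorflow::CancellationManager cancel_mgr;
resolver.CompleteParamsAsync(
    dev_attributes, col_params, &cancel_mgr,
    [col_params](const absl::Status& s) {
      // On s.ok(), group and instance parameters are fully resolved.
      col_params->Unref();
    });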
+class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess { + public: + CollectiveRemoteAccessLocal(const DeviceMgr* dev_mgr, + DeviceResolverInterface* dev_resolver, + int64_t step_id) + : dev_mgr_(dev_mgr), + dev_resolver_(dev_resolver), + buf_rendezvous_(step_id, dev_mgr), + step_id_(step_id) {} + + ~CollectiveRemoteAccessLocal() override = default; + + void StartAbort(const absl::Status& s) override; + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, + const StatusCallback& done) override; + + void PostToPeer(const string& peer_device, const string& peer_task, + const string& key, Device* from_device, + DeviceContext* from_device_ctx, + const AllocatorAttributes& from_alloc_attr, + const Tensor* from_tensor, + const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, + const StatusCallback& done) override; + + void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms, + const StatusCallback& done) override; + + BufRendezvous* buf_rendezvous() override { return &buf_rendezvous_; } + + // Copy utility that always copies bytes from src to dst even if + // they are on the same device, unlike CopyTensor::ViaDMA which will + // just change the dst buffer pointer in that case. + static void MemCpyAsync(DeviceContext* src_dev_ctx, + DeviceContext* dst_dev_ctx, Device* src_dev, + Device* dst_dev, const AllocatorAttributes& src_attr, + const AllocatorAttributes& dst_attr, + const Tensor* src, Tensor* dst, + int dev_to_dev_stream_index, + const StatusCallback& done); + + protected: + const DeviceMgr* dev_mgr_; // not owned + DeviceResolverInterface* dev_resolver_; // not owned + BufRendezvous buf_rendezvous_; + int64_t step_id_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_test_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_test_util.h new file mode 100644 index 00000000..492097c5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_test_util.h @@ -0,0 +1,109 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_TEST_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_TEST_UTIL_H_ + +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" + +namespace tensorflow { + +// Wraps CollectiveRemoteAccessLocal with the ability to return an +// error status to the N'th action. +class FailTestRMA : public CollectiveRemoteAccessLocal { + public: + FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver, + int64_t step_id); + + // Sets when it should fail. Setting to zero disables the failure. + void set_fail_after(int fail_after) { + mutex_lock l(mu_); + fail_after_ = fail_after; + } + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, + const StatusCallback& done) override; + + void PostToPeer(const string& peer_device, const string& peer_task, + const string& key, Device* from_device, + DeviceContext* from_device_ctx, + const AllocatorAttributes& from_alloc_attr, + const Tensor* from_tensor, + const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, + const StatusCallback& done) override; + + private: + bool MaybeFail(const StatusCallback& done); + + mutex mu_; + int fail_after_ TF_GUARDED_BY(mu_); +}; + +struct CollectiveTestEnv { + int num_workers; + int num_devices_per_worker; + DeviceType device_type; + std::unique_ptr param_resolver; + std::unique_ptr col_exec_mgr; + std::shared_ptr work_queue; + std::unique_ptr device_mgr; + std::unique_ptr device_resolver; + std::unique_ptr nccl_communicator; + core::RefCountPtr col_exec; + FailTestRMA* remote_access; + + CollectiveTestEnv() : device_type(DEVICE_DEFAULT) {} +}; + +std::unique_ptr CreateCollectiveTestEnv( + int num_workers, int num_devices_per_worker, DeviceType device_type, + bool use_nccl = false); + +core::RefCountPtr CreateCollectiveParams( + const CollectiveTestEnv& test_env, int rank, const string& collective_name, + CollectiveType collective_type, DataType dtype, const TensorShape& shape, + const std::vector> user_specified_rank_per_worker = {{}}); + +std::vector GenerateEvenSubdivOffsets(int num_devices_per_worker, + int num_subdivs); + +// Runs a collective. input and output should be on the host. 
+absl::Status RunCollective(CollectiveTestEnv* test_env, + CollectiveParams* col_params, Device* device, + Tensor* input, Tensor* output); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_util.h new file mode 100644 index 00000000..79cd5d50 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/collective_util.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace collective_util { + +absl::Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, + const string& device_name, + Device** device, + DeviceLocality* device_locality); +string SubdivPermDebugString(const CollectiveParams& col_params); + +// Used for executing a sub-operation, e.g. a merge_op instance, with +// an OpKernelContext based on the one passed into this Op. +class SubContext { + public: + OpKernelContext::Params sub_params_; + absl::InlinedVector sub_inputs_; + absl::InlinedVector sub_input_attr_; + absl::InlinedVector sub_input_dc_; + // Used only for Binary and Unary Ops for which we require + // the calculation to be in-place on the first input. + int forward_from_ = 0; + std::unique_ptr sub_ctx_; + SubContext(OpKernelContext* ctx, OpKernelContext::Params* params, + OpKernel* op, Tensor* output, Tensor* input); + ~SubContext() = default; +}; + +absl::Status ComputeBinOp(OpKernelContext* op_ctx, + OpKernelContext::Params* params, Device* device, + OpKernel* op, Tensor* output, Tensor* input); + +} // namespace collective_util +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h new file mode 100644 index 00000000..b1c1eea6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h @@ -0,0 +1,138 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
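The test utilities above are typically combined as follows. This is a sketch only; the collective name, type enum value, and tensor shapes are illustrative rather than taken from this header:

auto test_env = tensorflow::CreateCollectiveTestEnv(
    /*num_workers=*/1, /*num_devices_per_worker=*/2, tensorflow::DEVICE_CPU);
auto col_params = tensorflow::CreateCollectiveParams(
    *test_env, /*rank=*/0, "RingReduce", tensorflow::REDUCTION_COLLECTIVE,
    tensorflow::DT_FLOAT, tensorflow::TensorShape({8}));

tensorflow::Tensor input(tensorflow::DT_FLOAT, tensorflow::TensorShape({8}));
tensorflow::Tensor output(tensorflow::DT_FLOAT, tensorflow::TensorShape({8}));
tensorflow::Device* device = test_env->device_mgr->ListDevices()[0];

absl::Status status = tensorflow::RunCollective(
    test_env.get(), col_params.get(), device, &input, &output);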
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATE_PREDECESSOR_TREES_PASS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATE_PREDECESSOR_TREES_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +// TODO(b/344910755): Use the marker in Fill op to find the identity op. This +// makes the heuristic more straightforward. +// Colocate a tree of unplaced nodes with its placed Identity node. Identify a +// dangling tree of ops whose Identify nodes are assigned but rest of ops are +// not assigned. Then it should colocate the rest of the ops. +// +// For example, the graph before pass is: +// +// node { +// name: "const0" +// op: "Const" +// } +// node { +// name: "const1" +// op: "Const" +// } +// node { +// name: "fill0" +// op: "Fill" +// input: "const1" +// input: "const0" +// } +// node { +// name: "id0" +// op: "Identity" +// input: "fill0" +// device: "/job:worker/replica:0/task:2/device:CPU:0" +// } +// node { +// name: "id1" +// op: "Identity" +// input: "fill0" +// device: "/job:worker/replica:0/task:2/device:CPU:0" +// } +// +// The graph after pass is: +// +// node { +// name: "const0" +// op: "Const" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } +// node { +// name: "const1" +// op: "Const" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } +// node { +// name: "fill0" +// op: "Fill" +// input: "const1" +// input: "const0" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } +// node { +// name: "id0" +// op: "Identity" +// input: "fill0" +// device: "/job:worker/replica:0/task:2/device:CPU:0" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } +// node { +// name: "id1" +// op: "Identity" +// input: "fill0" +// device: "/job:worker/replica:0/task:2/device:CPU:0" +// attr { +// key: "_class" +// value { +// list { +// s: "loc:@id0" +// } +// } +// } +// } + +namespace tensorflow { + +// This pass can place each tree of unassigned nodes with its Identity nodes, +// when the Identity nodes are already assigned to a device. Placement is +// instructed here with the colocation class attribute _class. This is a good +// heuristic because it reduces number of cut edges and tends to load balance. +class ColocatePredecessorTreesPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATE_PREDECESSOR_TREES_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/colocation_graph.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/colocation_graph.h new file mode 100644 index 00000000..a31a2aad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/colocation_graph.h @@ -0,0 +1,394 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
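Passes like ColocatePredecessorTreesPass above are normally wired into the runtime through the optimization registry from optimization_registry.h. A sketch of that registration; the macro spelling, grouping, and phase value are assumptions to verify against that header:

namespace tensorflow {

// Assumes the REGISTER_OPTIMIZATION(grouping, phase, pass) macro and the
// grouping enum provided by optimization_registry.h; the phase value is
// illustrative.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 50,
                      ColocatePredecessorTreesPass);

}  // namespace tensorflow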
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ + +#include +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/inspecting_placer.h" +#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/port.h" + +namespace tensorflow { + +// Represents a node in the disjoint node forest and the +// accumulated constraints on the device used by that node. +class Member { + public: + Member() = default; + + absl::Status SetParentAndSupportedDevices( + const Node& node, const std::vector& types, + const DeviceNameUtils::ParsedName* local_address_spec); + + const DeviceNameUtils::ParsedName& requested_device_name() const { + return requested_device_name_; + } + + absl::Status SetAssignedDeviceName(const string& device_name); + absl::Status SetResourceDeviceName(const Node& node); + absl::Status SetRequestedDeviceName(const Node& node); + + absl::Status FillPossibleDevices(PossibleDevices* possible_device) const; + + // Returns whether `src_root` is assigned to a CompositeDevice and `this` is + // assigned to a physical device. + bool IsEdgeFromCompositeDeviceToPhysicalDevice(const Member& src_root) const; + + absl::Status EnsureCompatibilityAcrossResourceEdge( + const Node& src, const Member& src_root, + const Node& dst, /*dst_root is this*/ + bool log_device_placement); + + const PrioritizedDeviceTypeVector& supported_device_types() const { + return supported_device_types_; + } + + // If `dry_run` is true, just sets `new_root` and `old_root` and does not + // actually modify anything in the `tree`. + static void Merge(std::vector* tree, int x_root, int y_root, + Member** new_root, Member** old_root, bool dry_run); + + // Returns the root node of the disjoint tree to which the node with the + // given id is connected. + // FindRoot should be called only for debugging or after the members have + // been updated with direct root pointers because it does not update + // root pointers and can traverse many links. It exists to have + // a const version of FindAndUpdateRoot + static int FindRoot(const std::vector& tree, int node_id); + static int FindAndUpdateRoot(std::vector* tree, int node_id); + + absl::Status MergeDeviceNames(const Member& other, bool allow_soft_placement); + + // Updates this to contain the intersection of the device types in + // this and "other". If the intersection is empty, returns false and does + // not update this. Else returns true and updates this. 
+ bool MergeSupportedDevices(const Member& other); + + absl::Status AssignDevice(const Node& node); + + // If user does not explicitly request XLA device and non-XLA device is + // supported for this node, use only the non-XLA device. See b/140896502. + void MaybeExcludeXlaDevices(); + + // Limit the possible devices of this (should be a root) to the device + // specifications in `devices`. + absl::Status LimitToPossibleDevices(const PossibleDevices& devices, + bool allow_soft_placement); + + void set_possible_devices(std::vector&& devices) { + possible_devices_ = devices; + } + const std::vector& possible_devices() { return possible_devices_; } + + // Returns a (parsed) device name that is based on requested_device_name() + // but with potentially cleared device type and ID fields. A field is cleared + // if the assigned_device_name does not specify it. If it does, the field + // is not cleared because soft placement cannot violate assigned device names. + DeviceNameUtils::ParsedName GetSoftDeviceName() const; + + // Same as GetSoftDeviceName but device type and device ID fields are not + // cleared if resource device has them set. + DeviceNameUtils::ParsedName GetPreferredSoftDeviceName() const; + + string DebugString() const; + + bool has_assigned_device_name() const { return assigned_device_name_.has_id; } + + private: + // Updates this to contain the intersection of the device types in + // this and `other_devices`. + bool MergeSupportedDevices(const PrioritizedDeviceTypeVector& other_devices); + + // The id of the node that is the parent of this one, or its own + // id if it is a root. parent <= 0 indicates that this member is invalid. + int parent_ = -1; + + // A proxy for the depth of the tree that is used to prefer + // connecting smaller trees to larger trees when merging disjoint + // sets. + int rank_ = 0; + + // Once colocation groups have been formed, the Placer starts actually + // choosing devices. All nodes in a group must be assigned to the same + // device. Once we assigned the first device to some node in this group, + // we set assigned_device_name_index to this device name's index in the + // graph. + // The `*_device_name_` fields will contain the parsed name of this device + // and `possible_devices`, if computed, will contain just this device. + // `assigned_device_name_index` is an optimization to avoid parsing and + // comparing device names. The value of -1 signals that a single device + // has not been chosen yet. + int assigned_device_name_index_ = -1; + + // The merged form of the device requested for this node, with those of all of + // its children. requested_device_name_ is always kept a specialization (i.e. + // DeviceNameUtils::IsSpecification) of assigned_device_name_. When no device + // is requested, this field is set to assigned_device_name_. As a + // specialization of assigned_device_name_, requested_device_name_ represents + // the most specific form of all assigned and requested devices of this node + // and its children, if this node is a root. requested_device_name_ is used + // to finally select devices for nodes. We can override requested devices due + // to resource colocation constraints but not assigned devices (unless soft + // placement is on). + // INVARIANT: requested_device_name_ is always kept a + // DeviceNameUtils::IsSpecification of assigned_device_name_ and + // resource_device_name_. This makes requested_device_name_ the "accumulation + // of all wishes" about the device. 
+ DeviceNameUtils::ParsedName requested_device_name_; + + // The merged form of the device assigned for this node, with + // those of all of its children. + // This field is used to raise errors due to unsatisfiable constraints. + // Can be a partial specification. + DeviceNameUtils::ParsedName assigned_device_name_; + + // The merged form of the requested resource device assigned for this node, + // with those of all of its children. + // This field is used to raise errors due to unsatisfiable constraints. + // Can be a partial specification. + // resource_device_name_ is initialized with user-requested device on nodes + // producing resources, e.g. VarHandleOp. + // For historical reasons, with soft placement enabled, Placer can "move" + // resources (place resource producing ops on a device different from what + // the user explicitly requested) when the colocation group of a resource + // producing op contains ops that are not supported on the user-requested + // resource device. A classic example of this is a sparse optimizer (only + // supported on CPU) used on a GPU variable. In this case, the whole group + // will be assigned to some device supported by all ops in the colocation + // group. This is a surprising and unfortunate behavior because: + // 1. Since soft_placement is on by default, users don't know that their + // variables are created on a different device than what they requested. + // Among other things, this can lead to surprising poor performance. + // 2. Eager runtime cannot "move" resources. The same code can "work" when + // wrapped in tf.function but will fail when run eagerly. + // 3. Extra complexity here to preserve these resource moving capabilities. + DeviceNameUtils::ParsedName resource_device_name_; + + // The intersection of all device types supported by this node, + // and those of all of its children, in priority order + // of the preferred device. + // It is possible that supported_device_types_ has an empty intersection with + // requested/assigned/resource devices. We could have detected such cases + // as soon as they happen and raise an error. Instead, for historical reasons, + // we leave such error detection to the final device picking stage. + PrioritizedDeviceTypeVector supported_device_types_; + + // If this node is a root, stores a list of Devices to which this node + // and all of its children can be assigned. + // `possible_devices` is empty if they have not yet been computed. + std::vector possible_devices_; +}; + +// This class maintains the connected components of a colocation +// constraint graph, and uses this information to assign a satisfying +// device placement to the nodes of the graph. +// +// This implementation uses the Union-Find algorithm to efficiently maintain the +// connected components and incrementally adds edges via +// ColocationGraph::ColocateNodes() invocations. +// +// ColocationGraph does not assign any devices to graph nodes. The +// `log_device_placement` argument is used to log messages when requested +// device is ignored. +class ColocationGraph { + public: + // graph, flib_def, and device_set must not be null and must outlive + // this ColocationGraph. default_local_device can be null. If not, must + // outlive this. 
+ ColocationGraph(const Graph* graph, const FunctionStack& stack, + const FunctionLibraryDefinition* flib_def, + const DeviceSet* device_set, + const Device* default_local_device, bool allow_soft_placement, + bool log_device_placement); + + absl::Status Initialize(); + + const std::vector& members() const { return members_; } + + // Limit the group containing `node` to the device specifications in + // `devices`. + absl::Status LimitToPossibleDevices(const Node& node, + const PossibleDevices& devices); + + // Limits the possible devices of `node`'s colocation group to the device + // to which `node` is assigned. This makes sure that all nodes in this + // colocation group will be assigned to the same device. Without this + // explicit restriction, heuristics can choose a different possible device + // for other nodes in the group. + absl::Status LimitToAssignedDevice(const Node& node); + + // Returns the root node of the disjoint tree to which the node with the + // given id is connected. + // Updates the internal pointers so that future calls will returns faster. + int FindAndUpdateRoot(int node_id) { + return Member::FindAndUpdateRoot(&members_, node_id); + } + + // For the given node, subject to the constraints previously given + // to this ColocationGraph, set its assigned_device_name. Returns OK + // if a satisfying device can be found, otherwise an error. + // + // Note: This method returns a pointer to a field within members_. + // The caller must not use the returned pointer after there is any possibility + // that the members_[i].possible_devices field has been modified. + absl::Status GetDevicesForNode(Node* node, + const std::vector** possible_devices); + + // Returns debugging info for the node referred to by 'node_root'. + string DebugInfo(const int node_root) const; + + string DebugString() const; + + // Returns a list of devices having type in supported_device_types. The + // returned list is sorted by preferred type (higher numeric type is + // preferred). + static std::vector FilterSupportedDevices( + const std::vector& devices, + const PrioritizedDeviceTypeVector& supported_device_types, + const Device* default_local_device); + + private: + // Adds each node of the Graph to this ColocationGraph as a singleton. + // + // NOTE: The implementation assumes that the ids of nodes passed to + // this method are dense and zero-based; the memory used will be linear in + // the largest node ID. + // NOTE: If this method returns an error, *this is left in an undefined + // state. + absl::Status ColocateAllNodes(); + + absl::Status ColocateResourceOrRefEdge(const Node* src, const Node* dst); + + // Adds colocation constraints to data types known not to support copying. + absl::Status ColocateUncopiableTypeEdges( + std::unordered_set* inspection_required); + + // Updates this ColocationGraph by making sure that all nodes + // touching resource and/or ref tensors are colocated. + // As it iterates over the edges, fills the `inspection_required` set with + // the nodes that + // PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired + // deems as requiring deep inspection by placer. This is an optimization. + // TODO(mdan): Deprecate in favor of ColocateUncopiableTypeEdges. + absl::Status ColocateResourceAndRefEdges( + std::unordered_set* inspection_required); + + // Updates this ColocationGraph by making sure that all nodes having inputs of + // a DT_VARIANT data type with a host-only underlying types (e.g. strings) can + // be placed only on CPU device. 
We do that by reverse-DFS traversal from all + // nodes that take variant inputs to the node that produces that variant. + // TODO(ezhulenev): This function does not yet support "deep op" inspection, + // that we have for DT_RESOURCE edges. + absl::Status AddHostOnlyDataTypesConstraints(); + + absl::Status AddInspectionConstraints( + const std::unordered_set& inspection_required); + + // Applies colocation groups for `node`'s inputs and outputs to this + // ColocationGraph. + // `groups` are the colocation groups to which `nodes`'s inputs and outputs + // belong. + // `node` is a node requiring deep inspection (e.g. a node calling + // a function) + // + // For example, consider a `node` taking two inputs and producing one output + // a b + // | | + // v v + // node + // | + // v + // c + // + // `groups` can tell us that `a` and `c` must be colocated and their device + // must be a GPU. `b` might be in a group by itself without any device + // restrictions. + // + // ApplyIOColocationGroups will have an effect of calling + // ColocateNodes(a, c) and LimitToPossibleDevices(`a`, "GPU"). The colocation + // group of the `node` itself is not directly impacted. + // + absl::Status ApplyIOColocationGroups(const IOColocationGroups& groups, + const Node& node); + + absl::Status ColocateNodeToGroup( + std::unordered_map* + colocation_group_root, + const Node* node, absl::string_view colocation_group); + + // Merge the (possibly disjoint) sets containing nodes "x" and + // "y". Returns OK if the all nodes in the union of these sets can + // be placed on the same device type. + // + // If this method returns an error, *this is unchanged. + absl::Status ColocateNodes(const Node& x, const Node& y); + + // This overload of ColocateNodes() allows a caller to provide the root node + // ids for the two nodes. For large graphs, this noticeably reduces the + // graph load time. + // If this method returns an error, *this is unchanged. + absl::Status ColocateNodes(const Node& x, int x_root, const Node& y, + int y_root); + + void GetSoftDeviceCandidates(const Node& node, const Member& root_member, + int root_id, + std::vector* possible_devices); + + absl::Status InitializeMembers(); + + absl::Status InitializeMemberWithAssignedDevice( + const string& assigned_device_name, const string& node_type, + Member* member); + + absl::Status InitializeMember(const Node& node, Member* member); + + // Returns the root node of the disjoint tree to which the node with the + // given id is connected. + // FindRoot should be called only for debugging or after the members have + // been updated with direct root pointers because it does not update + // root pointers and can traverse many links. 
It exists to have + // a const version of FindAndUpdateRoot + int FindRoot(int node_id) const { + return Member::FindRoot(members_, node_id); + } + + const Graph& graph_; + const FunctionStack stack_; + std::vector members_; + InspectingPlacer inspecting_placer_; + PlacerInspectionRequiredOpChecker inspection_required_checker_; + const DeviceSet& device_set_; + const std::vector device_types_; + const DeviceNameUtils::ParsedName local_address_spec_; + const Device* default_local_device_; + const bool allow_soft_placement_; + const bool log_device_placement_; + + ColocationGraph(const ColocationGraph&) = delete; + void operator=(const ColocationGraph&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/composite_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/composite_device.h new file mode 100644 index 00000000..6e79542a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/composite_device.h @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COMPOSITE_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COMPOSITE_DEVICE_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +extern const char* const kCompositeDeviceType; + +// A virtual device which represents a set of devices. We don't execute any +// op on this virtial device. +class CompositeDevice : public Device { + public: + absl::Status Sync() override { + return errors::Internal( + "Sync() should never been invoked on CompositeDevice."); + } + + Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } + + const std::vector* underlying_devices() const { + return &underlying_devices_; + } + + // Helper for creating a CompositeDevice on the same task as the given host + // CPU. + static std::unique_ptr MakeDevice( + const std::vector& underlying_devices, const int unique_device_id, + const DeviceNameUtils::ParsedName& host_name, absl::Status* status); + + // Helper for creating a CompositeDevice with the given device name. 
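The parent/rank fields of Member above describe a classic disjoint-set forest. The following standalone sketch (not TensorFlow code) shows the find-with-path-compression and union-by-rank scheme that FindAndUpdateRoot and Merge are documented as using:

#include <utility>
#include <vector>

struct DisjointSet {
  std::vector<int> parent;
  std::vector<int> rank;
  explicit DisjointSet(int n) : parent(n), rank(n, 0) {
    for (int i = 0; i < n; ++i) parent[i] = i;
  }
  int Find(int x) {
    // Path compression: point every visited node directly at the root,
    // mirroring what FindAndUpdateRoot does for Member entries.
    if (parent[x] != x) parent[x] = Find(parent[x]);
    return parent[x];
  }
  void Merge(int x, int y) {
    x = Find(x);
    y = Find(y);
    if (x == y) return;
    if (rank[x] < rank[y]) std::swap(x, y);  // union by rank
    parent[y] = x;
    if (rank[x] == rank[y]) ++rank[x];
  }
};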
+ static std::unique_ptr MakeDevice( + const std::vector& underlying_devices, const string& device_name, + absl::Status* status); + + bool IsRemoteCallAllowed() const override { return false; } + + private: + CompositeDevice(const DeviceAttributes& device_attributes, + const std::vector& underlying_devices) + : Device(/*env=*/nullptr, device_attributes), + underlying_devices_(underlying_devices) {} + + const std::vector underlying_devices_; + + CompositeDevice(const CompositeDevice&) = delete; + void operator=(const CompositeDevice&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COMPOSITE_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/constant_folding.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/constant_folding.h new file mode 100644 index 00000000..fd74a554 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/constant_folding.h @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" + +// TODO(skyewm): can this be combined with EvaluateConstantTensor? + +namespace tensorflow { + +// This generator type is used to generate a name for the newly folded node +// based on the node's old name. +using ConstantFoldNameGenerator = + std::function; + +// Options specific to constant folding optimizations. +struct ConstantFoldingOptions { + // If "consider" is not a nullptr, then only constant fold a node "n" if + // consider(n) returns true. + std::function consider = nullptr; + // If shape_map is not a nullptr, it is a map from node n to a + // vector of the (potentially partially-known) shapes of its + // outputs. + const std::unordered_map>* shape_map = + nullptr; // not owned + // The maximum size of each constant created during constant folding + // optimization. + int64_t max_constant_size_in_bytes = 10 * 1024 * 1024; + + // A generator for the name suffix of constant folded nodes. A + // default id generator that monotonically increases is used if nullptr is + // passed. + ConstantFoldNameGenerator generate_new_name = nullptr; +}; + +// Perform constant folding optimization on "graph". +// Looks for nodes in "graph" that can be completely evaluated statically, i.e., +// that are only dependent on constants. Evaluates those nodes on a CPU device +// and replaces those nodes with the result of the evaluation. +// "partition_device", if non-null, is the device where all the graph nodes are +// assumed to execute. +// Sets `was_mutated` to true if and only if "graph" has been mutated. 
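A minimal sketch of building a CompositeDevice over two local CPU devices with the second MakeDevice overload above; the device names are made up:

std::vector<std::string> underlying = {
    "/job:localhost/replica:0/task:0/device:CPU:0",
    "/job:localhost/replica:0/task:0/device:CPU:1"};
absl::Status status;
std::unique_ptr<tensorflow::CompositeDevice> composite =
    tensorflow::CompositeDevice::MakeDevice(
        underlying, "/job:localhost/replica:0/task:0/device:COMPOSITE:0",
        &status);
if (!status.ok()) {
  // Creation failed; `composite` should not be used.
}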
+// The status is only set to a non-OK state if an unexpected error is hit +// running the graph. +absl::Status ConstantFold(const ConstantFoldingOptions& opts, + FunctionLibraryRuntime* function_library, Env* env, + const Device* partition_device, Graph* graph, + bool* was_mutated); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/copy_tensor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/copy_tensor.h new file mode 100644 index 00000000..0f621603 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/copy_tensor.h @@ -0,0 +1,80 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class CopyTensor { + public: + typedef void (*CopyFunction)( + DeviceContext* send_dev_context, DeviceContext* recv_dev_context, + Device* src, Device* dst, const AllocatorAttributes src_alloc_attr, + const AllocatorAttributes dst_alloc_attr, const Tensor* input, + Tensor* output, int dev_to_dev_stream_index, StatusCallback done); + + // Copies "input" to "output" between devices accessible to the + // local process via some DMA-like method. "edge_name" is the name + // of the tensor being copied, for debugging purposes. Depending on + // the type of devices and memory in use, the copy may be performed + // synchronously or asynchronously. 'done' will be invoked only + // after the copy is actually complete. + static void ViaDMA(absl::string_view edge_name, + DeviceContext* send_dev_context, + DeviceContext* recv_dev_context, Device* src, Device* dst, + const AllocatorAttributes src_alloc_attr, + const AllocatorAttributes dst_alloc_attr, + const Tensor* input, Tensor* output, + int dev_to_dev_stream_index, StatusCallback done, + bool sync_dst_compute = true); + + // Object used to call Register() at static-initialization time. + // Note: This should only ever be used as a global-static object; no stack + // or heap instances. + class Registration { + public: + Registration(DeviceType sender_device_type, DeviceType receiver_device_type, + CopyFunction copy_function) { + TF_QCHECK_OK(Register(sender_device_type, receiver_device_type, + copy_function, /*is_pluggable_device=*/false)); + } + }; + + // Register a function for copying between two specific DeviceTypes. 
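A sketch of calling ConstantFold with a restricted `consider` predicate; `flr`, `env`, and `graph` are assumed to exist in the caller and are not defined here:

tensorflow::ConstantFoldingOptions opts;
opts.consider = [](const tensorflow::Node* n) {
  // Fold only shape-introspection nodes in this example.
  return n->type_string() == "Shape" || n->type_string() == "Size" ||
         n->type_string() == "Rank";
};
opts.max_constant_size_in_bytes = 1 << 20;  // cap folded constants at 1 MiB

bool was_mutated = false;
absl::Status s = tensorflow::ConstantFold(
    opts, flr, env, /*partition_device=*/nullptr, graph, &was_mutated);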
+ // Note: This should only be called via the constructor of + // CopyTensor::Registration or from PluggableDevice implementation. + static absl::Status Register(DeviceType sender_device_type, + DeviceType receiver_device_type, + CopyFunction copy_function, + bool is_pluggable_device); +}; + +void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, + Allocator* out_allocator, absl::string_view edge_name, + Device* src, Tensor* output, + DeviceContext* send_dev_context, StatusCallback done); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COPY_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_constants.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_constants.h new file mode 100644 index 00000000..df01bf53 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_constants.h @@ -0,0 +1,57 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COST_CONSTANTS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COST_CONSTANTS_H_ + +namespace tensorflow { + +// Types of per-request cost. +inline constexpr char kGpuCostName[] = "gpu"; +inline constexpr char kTpuCostName[] = "tpu"; +inline constexpr char kGcuCostName[] = "gcu"; +inline constexpr char kNoOpCostName[] = "no_op"; + +// Each type of per-request cost could have the following versions. +// +// A server may have costs that cannot be directly attributed to a specific +// query. Each request will be assigned a portion of it, and the cost ends with +// '_with_smear" includes this part. +inline constexpr char kWithSmearSuffix[] = "_with_smear"; +inline constexpr char kNoSmearSuffix[] = "_no_smear"; +inline constexpr char kNonBatchingSuffix[] = "_non_batching"; + +// Full names of per-request cost. 
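A sketch of the static registration pattern described above for a hypothetical device type "MYDEV"; the copy function body is a placeholder rather than a real DMA implementation:

void MyDevToCpuCopy(tensorflow::DeviceContext* send_dev_context,
                    tensorflow::DeviceContext* recv_dev_context,
                    tensorflow::Device* src, tensorflow::Device* dst,
                    const tensorflow::AllocatorAttributes src_alloc_attr,
                    const tensorflow::AllocatorAttributes dst_alloc_attr,
                    const tensorflow::Tensor* input, tensorflow::Tensor* output,
                    int dev_to_dev_stream_index,
                    tensorflow::StatusCallback done) {
  *output = *input;  // placeholder: a real implementation issues a device copy
  done(absl::OkStatus());
}

// Global static object so registration runs at static-initialization time.
static tensorflow::CopyTensor::Registration my_dev_copy_registration(
    tensorflow::DeviceType("MYDEV"),
    tensorflow::DeviceType(tensorflow::DEVICE_CPU), MyDevToCpuCopy);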
+inline constexpr char kTpuWithSmearCostName[] = "tpu_with_smear"; +inline constexpr char kTpuNoSmearCostName[] = "tpu_no_smear"; +inline constexpr char kTpuDecodeWithSmearCostName[] = "tpu_decode_with_smear"; +inline constexpr char kTpuDecodeNoSmearCostName[] = "tpu_decode_no_smear"; +inline constexpr char kTpuPrefillWithSmearCostName[] = "tpu_prefill_with_smear"; +inline constexpr char kTpuPrefillNoSmearCostName[] = "tpu_prefill_no_smear"; +inline constexpr char kTpuNonBatchingCostName[] = "tpu_non_batching"; +inline constexpr char kGpuWithSmearCostName[] = "gpu_with_smear"; +inline constexpr char kGpuNoSmearCostName[] = "gpu_no_smear"; +inline constexpr char kGpuDecodeWithSmearCostName[] = "gpu_decode_with_smear"; +inline constexpr char kGpuDecodeNoSmearCostName[] = "gpu_decode_no_smear"; +inline constexpr char kGpuPrefillWithSmearCostName[] = "gpu_prefill_with_smear"; +inline constexpr char kGpuPrefillNoSmearCostName[] = "gpu_prefill_no_smear"; +inline constexpr char kGpuNonBatchingCostName[] = "gpu_non_batching"; +inline constexpr char kGcuWithSmearCostName[] = "gcu_with_smear"; +inline constexpr char kGcuNoSmearCostName[] = "gcu_no_smear"; +inline constexpr char kGcuNonBatchingCostName[] = "gcu_non_batching"; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COST_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_measurement.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_measurement.h new file mode 100644 index 00000000..3da322e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_measurement.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COST_MEASUREMENT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COST_MEASUREMENT_H_ + +#include "absl/strings/string_view.h" +#include "absl/time/time.h" + +namespace tensorflow { + +// An interface for cost measurement. +class CostMeasurement { + public: + // Context of the CostMeasurement. + struct Context { + // Whether this CostMeasurement is running within a per-query context (e.g. + // rpc handler) or not (e.g. batching). + bool is_per_query = false; + }; + + explicit CostMeasurement(const Context& context) {} + + virtual ~CostMeasurement() {} + + virtual absl::Duration GetTotalCost() = 0; + + virtual absl::string_view GetCostType() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COST_MEASUREMENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_measurement_registry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_measurement_registry.h new file mode 100644 index 00000000..b2f17273 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_measurement_registry.h @@ -0,0 +1,72 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
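The full names above are simply the base cost type concatenated with one of the suffixes, e.g. (illustrative check, not part of the header):

#include "absl/strings/str_cat.h"

std::string tpu_no_smear =
    absl::StrCat(tensorflow::kTpuCostName, tensorflow::kNoSmearSuffix);
// tpu_no_smear == "tpu_no_smear" == tensorflow::kTpuNoSmearCostName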
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COST_MEASUREMENT_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COST_MEASUREMENT_REGISTRY_H_ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/cost_measurement.h" + +namespace tensorflow { + +// CostMeasurementRegistry allows to +// - register a CostMeasurement type to the global map +// - create an instance of registered CostMeasurement. +class CostMeasurementRegistry { + public: + // Creates an instance of registered CostMeasurement by name. If the named + // CostMeasurement is not registered yet, returns nullptr. Any returned + // std::unique_ptr should not be moved. + // TODO(b/185852990): create a non-moveable wrapper class for the returned + // unique_ptr. + static std::unique_ptr CreateByNameOrNull( + const std::string& name, const CostMeasurement::Context& context); + + using Creator = std::function( + const CostMeasurement::Context&)>; + + // Registers a CostMeasurement type to the global map. Registering different + // types of CostMeasurement with the same name is prohibited. + static void RegisterCostMeasurement(absl::string_view name, Creator creator); +}; + +// Registers a CostMeasurement type to the global map. Registering different +// types of CostMeasurement with the same name is prohibited. +class CostMeasurementRegistrar { + public: + explicit CostMeasurementRegistrar(absl::string_view name, + CostMeasurementRegistry::Creator creator) { + CostMeasurementRegistry::RegisterCostMeasurement(name, std::move(creator)); + } +}; + +#define REGISTER_COST_MEASUREMENT(name, MyCostMeasurementClass) \ + namespace { \ + static ::tensorflow::CostMeasurementRegistrar \ + MyCostMeasurementClass##_registrar( \ + (name), [](const CostMeasurement::Context& context) { \ + return std::make_unique(context); \ + }); \ + } // namespace + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COST_MEASUREMENT_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_util.h new file mode 100644 index 00000000..aa1102c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/cost_util.h @@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
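A hedged sketch of defining, registering, and looking up a measurement with the registry above; the class and the "no_op" name are illustrative:

namespace tensorflow {

class NoOpCostMeasurement : public CostMeasurement {
 public:
  using CostMeasurement::CostMeasurement;
  absl::Duration GetTotalCost() override { return absl::ZeroDuration(); }
  absl::string_view GetCostType() const override { return "no_op"; }
};

REGISTER_COST_MEASUREMENT("no_op", NoOpCostMeasurement);

}  // namespace tensorflow

// Elsewhere, e.g. in a request handler:
// auto m = tensorflow::CostMeasurementRegistry::CreateByNameOrNull(
//     "no_op", tensorflow::CostMeasurement::Context{/*is_per_query=*/true});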
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COST_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COST_UTIL_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/cost_measurement.h" +#include "tensorflow/core/common_runtime/request_cost_accessor.h" + +namespace tensorflow { + +// Creates instances of CostMeasurement. The types to create are determined by +// env. +std::vector> CreateCostMeasurements( + const CostMeasurement::Context& context); + +// Creates an instance of RequestCostAccessor. The type to create is determined +// by env. Returns nullptr if the type is not specified in env, or the type of +// CostMeasurement is unregistered.. +std::unique_ptr CreateRequestCostAccessor(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/costmodel_manager.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/costmodel_manager.h new file mode 100644 index 00000000..8ea8a137 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/costmodel_manager.h @@ -0,0 +1,55 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COSTMODEL_MANAGER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_COSTMODEL_MANAGER_H_ + +#include + +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" + +namespace tensorflow { + +// Used to manage all the cost models for a session. +class CostModelManager { + public: + ~CostModelManager(); + + typedef std::unordered_map CostModelMap; + typedef CostModelMap::iterator CostModelMapIter; + + void ExportCostModels(CostModelMap* cost_models) { + mutex_lock l(mu_); + *cost_models = cost_models_; + } + + CostModel* FindOrCreateCostModel(const Graph* graph); + + bool RemoveCostModelForGraph(const Graph* graph); + + absl::Status AddToCostGraphDef(const Graph* graph, CostGraphDef* cost_graph); + + private: + mutex mu_; + CostModelMap cost_models_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COSTMODEL_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/debugger_state_interface.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/debugger_state_interface.h new file mode 100644 index 00000000..1b9f190e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/debugger_state_interface.h @@ -0,0 +1,123 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
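A sketch of the intended call pattern for the cost_util helpers above: create whatever measurements the environment enables, run the request, then read the costs back:

tensorflow::CostMeasurement::Context ctx{/*is_per_query=*/true};
std::vector<std::unique_ptr<tensorflow::CostMeasurement>> measurements =
    tensorflow::CreateCostMeasurements(ctx);

// ... handle the request ...

for (const auto& m : measurements) {
  absl::Duration cost = m->GetTotalCost();
  absl::string_view cost_type = m->GetCostType();
  // e.g. record `cost` under `cost_type` via a RequestCostAccessor, if one
  // was created by CreateRequestCostAccessor().
}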
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/protobuf/debug.pb.h" + +namespace tensorflow { + +// Returns a summary string for the list of debug tensor watches. +const string SummarizeDebugTensorWatches( + const protobuf::RepeatedPtrField& watches); + +// An abstract interface for storing and retrieving debugging information. +class DebuggerStateInterface { + public: + virtual ~DebuggerStateInterface() {} + + // Publish metadata about the debugged Session::Run() call. + // + // Args: + // global_step: A global step count supplied by the caller of + // Session::Run(). + // session_run_index: A chronologically sorted index for calls to the Run() + // method of the Session object. + // executor_step_index: A chronologically sorted index of invocations of the + // executor charged to serve this Session::Run() call. + // input_names: Name of the input Tensors (feed keys). + // output_names: Names of the fetched Tensors. + // target_names: Names of the target nodes. + virtual absl::Status PublishDebugMetadata( + const int64_t global_step, const int64_t session_run_index, + const int64_t executor_step_index, const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes) = 0; +}; + +class DebugGraphDecoratorInterface { + public: + virtual ~DebugGraphDecoratorInterface() {} + + // Insert special-purpose debug nodes to graph and dump the graph for + // record. See the documentation of DebugNodeInserter::InsertNodes() for + // details. + virtual absl::Status DecorateGraph(Graph* graph, Device* device) = 0; + + // Publish Graph to debug URLs. + virtual absl::Status PublishGraph(const Graph& graph, + const string& device_name) = 0; +}; + +typedef std::function( + const DebugOptions& options)> + DebuggerStateFactory; + +// Contains only static methods for registering DebuggerStateFactory. +// We don't expect to create any instances of this class. +// Call DebuggerStateRegistry::RegisterFactory() at initialization time to +// define a global factory that creates instances of DebuggerState, then call +// DebuggerStateRegistry::CreateState() to create a single instance. +class DebuggerStateRegistry { + public: + // Registers a function that creates a concrete DebuggerStateInterface + // implementation based on DebugOptions. + static void RegisterFactory(const DebuggerStateFactory& factory); + + // If RegisterFactory() has been called, creates and supplies a concrete + // DebuggerStateInterface implementation using the registered factory, + // owned by the caller and return an OK Status. Otherwise returns an error + // Status. 
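  //
  // Illustrative registration/creation sketch (not part of the vendored
  // header; `MyDebuggerState` and the factory lambda are assumed names, and
  // the DebuggerStateInterface subclass is defined elsewhere):
  //
  //   DebuggerStateRegistry::RegisterFactory(
  //       [](const DebugOptions& options) {
  //         return std::unique_ptr<DebuggerStateInterface>(
  //             new MyDebuggerState(options));
  //       });
  //
  //   std::unique_ptr<DebuggerStateInterface> state;
  //   absl::Status s =
  //       DebuggerStateRegistry::CreateState(debug_options, &state);
  //   // `s` is an error if RegisterFactory() was never called.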
+ static absl::Status CreateState( + const DebugOptions& debug_options, + std::unique_ptr* state); + + private: + static DebuggerStateFactory* factory_; + + DebuggerStateRegistry(const DebuggerStateRegistry&) = delete; + void operator=(const DebuggerStateRegistry&) = delete; +}; + +typedef std::function( + const DebugOptions& options)> + DebugGraphDecoratorFactory; + +class DebugGraphDecoratorRegistry { + public: + static void RegisterFactory(const DebugGraphDecoratorFactory& factory); + + static absl::Status CreateDecorator( + const DebugOptions& options, + std::unique_ptr* decorator); + + private: + static DebugGraphDecoratorFactory* factory_; + + DebugGraphDecoratorRegistry(const DebugGraphDecoratorRegistry&) = delete; + void operator=(const DebugGraphDecoratorRegistry&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device.h new file mode 100644 index 00000000..83785e33 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device.h @@ -0,0 +1,20 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_ + +#include "tensorflow/core/framework/device.h" + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_event_mgr.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_event_mgr.h new file mode 100644 index 00000000..7725a941 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_event_mgr.h @@ -0,0 +1,160 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_EVENT_MGR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_EVENT_MGR_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// TODO(annarev): Check if we can use a more general option representation here +// that could work for other device types as well. +class GPUOptions; + +// The callback provided to EventMgr::ThenExecute must not block or take a long +// time. If it does, performance may be impacted and device memory may be +// exhausted. This macro is for checking that an EventMgr thread is not +// accidentally entering blocking parts of the code, e.g. the RPC subsystem. +// +// Intended use is something like +// +// void RespondToAnRPC(Params* params) { +// WARN_IF_IN_EVENT_MGR_THREAD; +// if (params->status.ok()) { ... +// +namespace device_event_mgr { +// Logs a stack trace if current execution thread belongs to this EventMgr +// object. If f is not nullptr, executes instead of logging the stack trace. +// trace. +void WarnIfInCallback(std::function f); +} // namespace device_event_mgr +#define WARN_IF_IN_EVENT_MGR_THREAD \ + ::tensorflow::device_event_mgr::WarnIfInCallback(nullptr) + +// EventMgr lets you register a callback to be executed when a given +// StreamExecutor stream completes all the work that's thus-far been enqueued on +// the stream. +class EventMgr { + public: + virtual ~EventMgr(); + + // Execute `func` when all pending stream actions have completed. func must + // be brief and non-blocking since it executes in the one thread used for all + // such callbacks and also buffer deletions. + void ThenExecute(se::Stream* stream, std::function func) { + ToFreeVector to_free; + { + mutex_lock l(mu_); + EnqueueCallback(stream, std::move(func)); + PollEvents(stream, &to_free); + } + FreeMemory(to_free); + } + + private: + friend class TEST_EventMgr; + friend class TEST_EventMgrHelper; + friend class EventMgrFactory; + + se::StreamExecutor* const exec_; + const int32 polling_active_delay_usecs_; + mutex mu_; + condition_variable events_pending_ TF_GUARDED_BY(mu_); + + struct InUse { + se::Event* event; + std::function func; + }; + + typedef absl::InlinedVector ToFreeVector; + + EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options); + + void FreeMemory(const ToFreeVector& to_free) { + for (const auto& iu : to_free) { + // The function must be called in another thread. + if (iu.func != nullptr) threadpool_.Schedule(iu.func); + } + } + + // Set up `func` to be called once `stream` completes all its outstanding + // work. + void EnqueueCallback(se::Stream* stream, std::function func) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // This function should be called at roughly the same tempo as QueueTensors() + // to check whether pending events have recorded, and then retire them. + // + // If `stream` is not null, we only poll events for that stream. Otherwise we + // poll events for all streams. 
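  //
  // Caller-side sketch (illustrative, not part of the vendored header): the
  // callbacks retired here were registered through ThenExecute(), typically
  // right after work is enqueued on a stream. `executor`, `gpu_options` and
  // `stream` are assumed to exist in the caller:
  //
  //   EventMgr* em =
  //       EventMgrFactory::Singleton()->GetEventMgr(executor, gpu_options);
  //   em->ThenExecute(stream, []() {
  //     // Must be brief and non-blocking; runs on the EventMgr thread.
  //     VLOG(1) << "stream work completed";
  //   });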
+ void PollEvents(se::Stream* stream, ToFreeVector* to_free) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // An internal polling loop that runs at a low frequency to clear straggler + // Events. + void PollLoop(); + + // Setup/Teardown functions for the polling loop. + void StartPollingLoop(); + void StopPollingLoop(); + + // A stack of unused events + std::vector> free_events_ TF_GUARDED_BY(mu_); + + // Callbacks waiting on their events to complete. + absl::flat_hash_map< + se::Stream*, + std::deque, std::function>>> + callbacks_ TF_GUARDED_BY(mu_); + + bool stop_polling_ TF_GUARDED_BY(mu_); + std::unique_ptr polling_stopped_; + + // The main PollLoop for the event manager runs in this threadpool. + thread::ThreadPool threadpool_; +}; + +// Manages all the EventMgr instances. +class EventMgrFactory { + public: + static EventMgrFactory* Singleton(); + + EventMgr* GetEventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options); + + private: + mutex mu_; + + // Maintain one EventMgr per physical device (StreamExecutor is + // per-physical-device). + absl::flat_hash_map event_mgr_map_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_EVENT_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_host_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_host_allocator.h new file mode 100644 index 00000000..0ed688fc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_host_allocator.h @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_HOST_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_HOST_ALLOCATOR_H_ + +#include "xla/stream_executor/integrations/device_host_allocator.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { +using stream_executor::DeviceHostAllocator; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_HOST_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_id.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_id.h new file mode 100644 index 00000000..d64a83cb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_id.h @@ -0,0 +1,91 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_ID_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_ID_H_ + +#include "xla/tsl/framework/device_id.h" +#include "tensorflow/core/lib/gtl/int_type.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// There are three types of device ids: +// - *physical* device id: this is the integer index of a device in the +// physical machine, it can be filtered (for e.g. using environment variable +// CUDA_VISIBLE_DEVICES when using CUDA). Note that this id is not visible to +// Tensorflow, but result after filtering is visible to TF and is called +// platform device id as below. +// For CUDA, see +// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars +// for more details. +// - *platform* device id (also called *visible* device id in +// third_party/tensorflow/core/protobuf/config.proto): this is the id that is +// visible to Tensorflow after filtering (for e.g. by CUDA_VISIBLE_DEVICES). +// For CUDA, this id is generated by the CUDA GPU driver. It starts from 0 +// and is used for CUDA API calls like cuDeviceGet(). +// - TF device id (also called *virtual* device id in +// third_party/tensorflow/core/protobuf/config.proto): this is the id that +// Tensorflow generates and exposes to its users. It is the id in the +// field of the device name "/device:GPU:", and is also the identifier of +// a BaseGPUDevice. Note that the configuration allows us to create multiple +// BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the +// hardware, so the mapping between TF GPU id and platform GPU id is not a 1:1 +// mapping, see the example below. +// +// For example, assuming that in the machine we have GPU device with index 0, 1, +// 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create +// the following mapping between platform GPU id and physical GPU id: +// +// platform GPU id -> physical GPU id +// 0 -> 1 +// 1 -> 2 +// 2 -> 3 +// +// Note that physical GPU id 0 is invisible to TF so there is no mapping entry +// for it. +// +// Assuming we configure the Session to create one BaseGPUDevice per GPU +// hardware, then setting GPUOptions::visible_device_list to "2,0" will create +// the following mapping between TF device id and platform device id: +// +// TF GPU id -> platform GPU ID +// 0 (i.e. /device:GPU:0) -> 2 +// 1 (i.e. /device:GPU:1) -> 0 +// +// Note that platform device id 1 is filtered out by +// GPUOptions::visible_device_list, so it won't be used by the TF process. +// +// On the other hand, if we configure it to create 2 BaseGPUDevice per GPU +// hardware, then setting GPUOptions::visible_device_list to "2,0" will create +// the following mapping between TF device id and platform device id: +// +// TF GPU id -> platform GPU ID +// 0 (i.e. /device:GPU:0) -> 2 +// 1 (i.e. /device:GPU:1) -> 2 +// 2 (i.e. /device:GPU:2) -> 0 +// 3 (i.e. 
/device:GPU:3) -> 0 +// +// We create strong-typed integer classes for both TF device id and platform +// device id to minimize programming errors and improve code readability. Except +// for the StreamExecutor interface (as we don't change its API), whenever we +// need a TF device id (or platform device id) we should use TfDeviceId (or +// PlatformDeviceId) instead of a raw integer. +using tsl::PlatformDeviceId; // NOLINT +using tsl::TfDeviceId; // NOLINT + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_ID_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_id_manager.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_id_manager.h new file mode 100644 index 00000000..058e94fb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_id_manager.h @@ -0,0 +1,28 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_ID_MANAGER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_ID_MANAGER_H_ + +#include "xla/tsl/framework/device_id_manager.h" +#include "tensorflow/core/common_runtime/device/device_id.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +using tsl::DeviceIdManager; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_ID_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_mem_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_mem_allocator.h new file mode 100644 index 00000000..44e516b9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_mem_allocator.h @@ -0,0 +1,28 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_MEM_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_MEM_ALLOCATOR_H_ + +#include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "tensorflow/core/common_runtime/device/device_id.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { +using stream_executor::DeviceMemAllocator; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_MEM_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_utils.h new file mode 100644 index 00000000..5447c729 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device/device_utils.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_UTILS_H_ + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace device_utils { + +// Validate device type. Device type must start with a capital letter and +// consist of capital letters and underscores. Reasoning behind this decision: +// * At the minimum we want to disallow '/' and ':' since +// these characters are used in device spec, for e.g. +// /job:foo/replica:12/device:GPU:1. +// * Underscores seem useful, for e.g. XLA_GPU uses underscores. +// * Allowing lowercase might get confusing. For example, say someone +// registers a new type called "Gpu". It might be confusing for users that +// "Gpu" is not the same device type as "GPU". +// Note that lowercase "cpu" and "gpu" are currently supported only for +// legacy reasons: +// https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd +absl::Status ValidateDeviceType(absl::string_view type); + +} // namespace device_utils +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_DEVICE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_factory.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_factory.h new file mode 100644 index 00000000..1b5a6626 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_factory.h @@ -0,0 +1,20 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_ + +#include "tensorflow/core/framework/device_factory.h" + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_id_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_id_utils.h new file mode 100644 index 00000000..f0cab86b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_id_utils.h @@ -0,0 +1,42 @@ + +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_ID_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_ID_UTILS_H_ + +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/framework/device_id.h" +#include "xla/tsl/framework/device_id_manager.h" + +namespace tensorflow { + +// Utility method for getting the associated executor given a TfDeviceId. +class DeviceIdUtil { + public: + static absl::StatusOr ExecutorForTfDeviceId( + const tsl::DeviceType& type, stream_executor::Platform* device_manager, + tsl::TfDeviceId tf_device_id) { + tsl::PlatformDeviceId platform_device_id; + TF_RETURN_IF_ERROR(tsl::DeviceIdManager::TfToPlatformDeviceId( + type, tf_device_id, &platform_device_id)); + return device_manager->ExecutorForDevice(platform_device_id.value()); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_ID_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_mgr.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_mgr.h new file mode 100644 index 00000000..3e0abb14 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_mgr.h @@ -0,0 +1,180 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/lib/core/arena.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class DeviceAttributes; + +// Represents a set of devices. +class DeviceMgr { + public: + DeviceMgr() = default; + virtual ~DeviceMgr(); + + // Returns attributes of all devices. + virtual void ListDeviceAttributes( + std::vector* devices) const = 0; + + // Returns raw pointers to the underlying devices. + virtual std::vector ListDevices() const = 0; + + // Returns a string listing all devices. + virtual string DebugString() const = 0; + + // Returns a string of all the device mapping. + virtual string DeviceMappingString() const = 0; + + // Assigns *device with pointer to Device of the given name. + // Accepts either a full device name, or just the replica-local suffix. + virtual absl::Status LookupDevice(absl::string_view name, + Device** device) const = 0; + + // Check if the current device manager contains device with the given + // incarnation ID. Looking up by incarnation IDs because they are randomly + // generated and not intentionally reused (unlike device pointers). + virtual bool ContainsDevice(int64_t device_incarnation) const = 0; + + // Clears given containers of all devices if 'container' is + // non-empty. Otherwise, clears default containers of all devices. + virtual void ClearContainers(absl::Span containers) const = 0; + + virtual int NumDeviceType(const string& type) const = 0; + + virtual int NumDevices() const = 0; + + // Returns an arbitrary CPU device if one is present, otherwise return + // nullptr. + virtual Device* HostCPU() const = 0; + + DeviceMgr(const DeviceMgr&) = delete; + void operator=(const DeviceMgr&) = delete; +}; + + +// Size of stale device buffer for temporary storage of removed devices. +static const size_t kStaleDeviceBufferSize = 8192; + +// Represents a dynamic set of devices +class DynamicDeviceMgr : public DeviceMgr { + public: + // Constructs an empty DynamicDeviceMgr. + DynamicDeviceMgr(); + + // Constructs a DynamicDeviceMgr from a list of devices. + explicit DynamicDeviceMgr(std::vector>&& devices); + explicit DynamicDeviceMgr(std::unique_ptr&& device); + + ~DynamicDeviceMgr() override; + + void ListDeviceAttributes( + std::vector* devices) const override; + std::vector ListDevices() const override; + string DebugString() const override; + string DeviceMappingString() const override; + absl::Status LookupDevice(absl::string_view name, + Device** device) const override; + bool ContainsDevice(int64_t device_incarnation) const override; + void ClearContainers(absl::Span containers) const override; + int NumDeviceType(const string& type) const override; + int NumDevices() const override; + Device* HostCPU() const override; + + // Add devices to device manager. Returns error for repeated device names. + absl::Status AddDevices(std::vector> devices); + + // Remove devices from device manager. + // Returns error for non-existing devices or if the HostCPU() device is in the + // input list. 
If an error is returned, the device list is not modified. + absl::Status RemoveDevices(const std::vector& devices); + + // Remove devices from device manager by their names. Returns error for + // non-existing devices or if the HostCPU() device is given in the input list. + // If an error is returned, the device list is not modified. + absl::Status RemoveDevicesByName(const std::vector& device_names); + + private: + mutable mutex devices_mu_; + + // Using an ordered map to ensure deterministic ordering of devices. + // Not a set, because we need to do find(Device*) and own the devices + // at the same time. + // We still have to override C++'s default pointer ordering. + struct DereferenceDevicePtrLess { + bool operator()(const Device* a, const Device* b) const { + return Device::LessByParsedName(*a, *b); + } + }; + std::map, DereferenceDevicePtrLess> + dynamic_devices_ TF_GUARDED_BY(devices_mu_); + + absl::flat_hash_set device_incarnation_set_ + TF_GUARDED_BY(devices_mu_); + std::unordered_map device_map_ TF_GUARDED_BY(devices_mu_); + + std::unordered_map device_type_counts_ + TF_GUARDED_BY(devices_mu_); + + mutable std::atomic cpu_device_; // memoize `HostCPU` result + + class DeviceCircularBuffer { + public: + DeviceCircularBuffer() : index_(0) { + devices_.resize(kStaleDeviceBufferSize); + } + void add(std::unique_ptr device) { + devices_[index_] = std::move(device); + index_ = (index_ + 1) % kStaleDeviceBufferSize; + } + + private: + int index_; + std::vector> devices_; + }; + + // Buffer to temporarily store the removed devices. Raw device pointers are + // accessible to DeviceSet, and if the function instantiation process directly + // access fields through the device set, the underlying device object must + // still be available to avoid segmentation fault. We keep the devices in this + // buffer only for that purpose. + DeviceCircularBuffer stale_devices_ TF_GUARDED_BY(devices_mu_); + + DynamicDeviceMgr(const DynamicDeviceMgr&) = delete; + void operator=(const DynamicDeviceMgr&) = delete; +}; + +// TODO(b/183966398): Remove StaticDeviceMgr since there's no usage. +using StaticDeviceMgr = DynamicDeviceMgr; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_propagation.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_propagation.h new file mode 100644 index 00000000..20f5f916 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_propagation.h @@ -0,0 +1,49 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_PROPAGATION_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_PROPAGATION_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { + +namespace device_propagation { + +typedef std::function DeviceFilter; +typedef std::function NodeFilter; +} // namespace device_propagation + +// Propagates device assignments from a certain types of nodes to their outputs +// to avoid unnecessary D2H or H2D copies. +// If an node satisfies the following conditions, it will be placed on the same +// device as its inputs: +// (1) The node can accept device update (`node_filter` returns true). +// (2) The node itself has no requested or assigned devices. +// (3) The source nodes of this node's input edges, except for edges that are +// "LoopCond->Switch" or "Enter->Merge", are all placed on the same device. +// (4) The device can be propagated (`device_filter` returns true) +void PropagateDevices(const device_propagation::NodeFilter& node_filter, + const device_propagation::DeviceFilter& device_filter, + Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_PROPAGATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_resolver_local.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_resolver_local.h new file mode 100644 index 00000000..814bea88 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_resolver_local.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_ + +#include +#include + +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +class DeviceMgr; + +// Implements DeviceResolverInterface in a single-task context. 
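//
// Minimal usage sketch (illustrative; `dev_mgr` is an already-constructed
// DeviceMgr owned by the caller, and the device name is just an example):
//
//   DeviceResolverLocal resolver(dev_mgr);
//   DeviceAttributes attr;
//   absl::Status s = resolver.GetDeviceAttributes(
//       "/job:localhost/replica:0/task:0/device:CPU:0", &attr);
//   if (s.ok()) { /* attr now describes the local CPU device */ }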
+class DeviceResolverLocal : public DeviceResolverInterface { + public: + explicit DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {} + + absl::Status GetDeviceAttributes(const string& device, + DeviceAttributes* attributes) override; + + absl::Status GetAllDeviceAttributes( + const string& task, std::vector* attributes) override; + + absl::Status UpdateDeviceAttributes( + const std::vector& attributes) override; + + protected: + const DeviceMgr* dev_mgr_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_set.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_set.h new file mode 100644 index 00000000..16dcd0ad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/device_set.h @@ -0,0 +1,139 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +typedef std::vector> PrioritizedDeviceVector; + +// DeviceSet is a container class for managing the various types of +// devices used by a model. +class DeviceSet { + public: + DeviceSet(); + ~DeviceSet(); + + // Does not take ownership of 'device'. + void AddDevice(Device* device) TF_LOCKS_EXCLUDED(devices_mu_); + + // Set the device designated as the "client". This device + // must also be registered via AddDevice(). + void set_client_device(Device* device) { + DCHECK(client_device_ == nullptr); + client_device_ = device; + } + + // Returns a pointer to the device designated as the "client". + Device* client_device() const { return client_device_; } + + // Return the list of devices in this set. + const std::vector& devices() const { return devices_; } + + // Given a DeviceNameUtils::ParsedName (which may have some + // wildcards for different components), fills "*devices" with all + // devices in "*this" that match "spec". + void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec, + std::vector* devices) const; + + // Finds the device with the given "fullname". Returns nullptr if + // not found. + Device* FindDeviceByName(const string& fullname) const; + + // Return the list of unique device types in this set, ordered + // with more preferable devices earlier. + std::vector PrioritizedDeviceTypeList() const; + + // Return the prioritized list of devices in this set. + // Devices are prioritized first by `DeviceTypeOrder`, then by name. 
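  //
  // Consumption sketch (illustrative only; `cpu0` and `gpu0` are Device*
  // owned elsewhere, e.g. by a DeviceMgr; each entry pairs a Device* with
  // its priority):
  //
  //   DeviceSet device_set;
  //   device_set.AddDevice(cpu0);
  //   device_set.AddDevice(gpu0);
  //   const PrioritizedDeviceVector& ranked = device_set.prioritized_devices();
  //   Device* preferred = ranked.empty() ? nullptr : ranked.front().first;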
+ const PrioritizedDeviceVector& prioritized_devices() const + TF_LOCKS_EXCLUDED(devices_mu_); + + // Return the prioritized list of unique device types in this set. + // + // The list will be ordered by decreasing priority. The priorities (the second + // element in the list's `std::pair`) will be initialized + // to the value of `DeviceTypeOrder` for the device types. + const PrioritizedDeviceTypeVector& prioritized_device_types() const + TF_LOCKS_EXCLUDED(devices_mu_); + + // An order to sort by device types according to system-determined + // priority. + // + // Higher result implies higher priority. + static int DeviceTypeOrder(const DeviceType& d); + + // Sorts a PrioritizedDeviceVector according to devices and explicit + // priorities. + // + // After a call to this function, the argument vector will be sorted by + // explicit priority (the second element in the `std::pair`), then by `DeviceTypeOrder` of the device type, then by device + // locality, and lastly by device name. + static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector); + + // Sorts a PrioritizedDeviceTypeVector according to types and explicit + // priorities. + // + // After a call to this function, the argument vector will be sorted by + // explicit priority (the second element in the `std::pair`), then by `DeviceTypeOrder` of the device type. + static void SortPrioritizedDeviceTypeVector( + PrioritizedDeviceTypeVector* vector); + + private: + mutable mutex devices_mu_; + + mutable absl::flat_hash_map> + matching_device_cache_; + + // Not owned. + std::vector devices_; + + // Cached prioritized vector, created on-the-fly when + // prioritized_devices() is called. + mutable PrioritizedDeviceVector prioritized_devices_ + TF_GUARDED_BY(devices_mu_); + + // Cached prioritized vector, created on-the-fly when + // prioritized_device_types() is called. + mutable PrioritizedDeviceTypeVector prioritized_device_types_ + TF_GUARDED_BY(devices_mu_); + + // Fullname -> device* for device in devices_. + std::unordered_map device_by_name_; + + // client_device_ points to an element of devices_ that we consider + // to be the client device (in this local process). + Device* client_device_ = nullptr; + + DeviceSet(const DeviceSet&) = delete; + void operator=(const DeviceSet&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/direct_session.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/direct_session.h new file mode 100644 index 00000000..c43827ee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/direct_session.h @@ -0,0 +1,449 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/costmodel_manager.h" +#include "tensorflow/core/common_runtime/debugger_state_interface.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/session_state.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { + +class CostModel; +class DebugGateway; +class Device; +class DirectSessionFactory; + +class DirectSession : public Session { + public: + typedef std::function CloseCallback; + + // Takes ownership of 'device_mgr'. + // 'factory' is used to unregister the DirectSession with 'factory' when its + // closed. This ensures that Reset requests from the 'factory' don't get sent + // to sessions that are already closed. + DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, + DirectSessionFactory* factory); + ~DirectSession() override; + + typedef std::vector> NamedTensorList; + typedef std::unordered_map + NameNodeMap; + + absl::Status Create(const GraphDef& graph) override; + absl::Status Create(GraphDef&& graph) override; + absl::Status Extend(const GraphDef& graph) override; + absl::Status Extend(GraphDef&& graph) override; + absl::Status Run(const NamedTensorList& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs) override; + + // NOTE: Experimental and subject to change. + absl::Status Run(const ::tensorflow::RunOptions& run_options, + const NamedTensorList& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs, + RunMetadata* run_metadata) override; + + // NOTE: Experimental and subject to change. + absl::Status Run( + const ::tensorflow::RunOptions& run_options, + const NamedTensorList& inputs, const std::vector& output_names, + const std::vector& target_nodes, std::vector* outputs, + RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options) override; + + // NOTE: PRunSetup and PRun are added to support partial execution. This + // feature is experimental and subject to change. 
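  //
  // Partial-run sketch (illustrative; `session` is a Session pointer, and the
  // feed/fetch names plus the assumption that "out1:0" depends only on "a:0"
  // are made up for this example):
  //
  //   string handle;
  //   absl::Status s = session->PRunSetup({"a:0", "b:0"},
  //                                       {"out1:0", "out2:0"}, {}, &handle);
  //   std::vector<Tensor> outputs;
  //   // Feed "a" and fetch the part of the graph that only needs "a".
  //   s.Update(session->PRun(handle, {{"a:0", tensor_a}}, {"out1:0"}, &outputs));
  //   // Later, feed "b" and fetch the rest.
  //   s.Update(session->PRun(handle, {{"b:0", tensor_b}}, {"out2:0"}, &outputs));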
+ absl::Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) override; + absl::Status PRun(const string& handle, const NamedTensorList& inputs, + const std::vector& output_names, + std::vector* outputs) override; + + // Reset clears 'containers' from the device_mgr of the DirectSession. + // If 'containers' is empty, then Reset clears the default container. + absl::Status Reset(const std::vector& containers); + + absl::Status ListDevices(std::vector* response) override; + absl::Status Close() override; + absl::Status LocalDeviceManager(const DeviceMgr** output) override { + *output = device_mgr_.get(); + return absl::OkStatus(); + } + + void ExportCostModels(CostModelManager::CostModelMap* cost_models) { + cost_model_manager_.ExportCostModels(cost_models); + } + + absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override; + + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) override; + + absl::Status RunCallable( + CallableHandle handle, const std::vector& feed_tensors, + std::vector* fetch_tensors, RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options) override; + + absl::Status ReleaseCallable(CallableHandle handle) override; + + absl::Status Finalize() override; + + const SessionOptions& options() const { return options_; } + + private: + // For access to collective_graph_key_. + friend class DirectSessionCollectiveTest; + + // We create one executor and its dependent library runtime for + // every partition. + struct PerPartitionExecutorsAndLib { + std::unique_ptr graph = nullptr; + Device* device = nullptr; // not owned. + FunctionLibraryRuntime* flib = nullptr; // not owned. + std::unique_ptr executor; + }; + + // An ExecutorsAndKeys is created for a given set of feeds/fetches. + // 'step_count' is the number of times this graph is executed. + // 'graph' is the entire graph being executed. 'name_to_node' + // maps node name to node. We keep 'graph' and 'name_to_node' only in + // the case of partial runs. Each item in 'items' is the executor for + // a partition of the graph bundled with its dependent library runtime. + // 'input_keys' are the rendezvous keys for the feeds and 'output_keys' + // are rendezvous keys for the fetches. + struct ExecutorsAndKeys { + ExecutorsAndKeys() : step_count(0) {} + + std::atomic_int_fast64_t step_count; + std::unique_ptr graph; + NameNodeMap name_to_node; + std::vector items; + std::unordered_map input_name_to_index; + std::unordered_map input_name_to_rendezvous_key; + std::unordered_map output_name_to_index; + std::unordered_map output_name_to_rendezvous_key; + + DataTypeVector input_types; + DataTypeVector output_types; + + CallableOptions callable_options; + + int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; + }; + + // A FunctionInfo object is created for every unique set of feeds/fetches. + // This info could be folded into the ExecutorsAndKeys object but we would + // like to maintain a deletion order in which the OpKernels (owned by the + // executor) should be destroyed first, followed by the resources in the + // device and then followed by the function stuff. + // TODO(rohanj): Consolidate function library definitions so that we can + // instantiate only one ProcFLR and lib_def and make this just a member + // variable and not a vector. 
+ // 'flib_def' is the function library used. + // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per + // device. + struct FunctionInfo { + std::unique_ptr flib_def; + std::unique_ptr proc_flr; + }; + + // For each live Run() call, the session maintains a RunState. + // 'status' is the current status of the execution. + struct RunState { + mutex mu; + absl::Status status TF_GUARDED_BY(mu); + std::unique_ptr collective_executor; + std::unique_ptr collector; + TensorStore tensor_store; + ScopedStepContainer step_container; + + RunState(int64_t step_id, const std::vector* devices); + }; + + // For each live partial execution, the session maintains a PartialRunState. + // 'executor_done' is "notified" when all executors are done. 'pending_inputs' + // are the set of pending feeds and 'pending_outputs' are the set of pending + // fetches. + struct PartialRunState : public RunState { + Notification executors_done; + std::unordered_map pending_inputs; // true if fed + std::unordered_map pending_outputs; // true if fetched + core::RefCountPtr rendez = nullptr; + + PartialRunState(const std::vector& pending_input_names, + const std::vector& pending_output_names, + int64_t step_id, const std::vector* devices); + + // Returns true if all pending inputs and outputs have been completed. + bool PendingDone() const; + + ~PartialRunState(); + }; + + struct RunStateArgs { + explicit RunStateArgs(const DebugOptions& options) + : debug_options(options) {} + + bool is_partial_run = false; + string handle; + std::unique_ptr graph; + const DebugOptions& debug_options; + int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; + }; + + // Retrieves an already existing set of executors to run 'inputs' and + // 'outputs', or creates and caches them for future use. + absl::Status GetOrCreateExecutors(absl::Span inputs, + absl::Span outputs, + absl::Span target_nodes, + ExecutorsAndKeys** executors_and_keys, + RunStateArgs* run_state_args); + + // Creates a set of executors to run the subgraph defined by + // `callable_options`. + absl::Status CreateExecutors( + const CallableOptions& callable_options, + std::unique_ptr* out_executors_and_keys, + std::unique_ptr* out_func_info, + RunStateArgs* run_state_args); + + // Creates several graphs given the existing graph_def_ and the + // input feeds and fetches, given 'devices'. The graphs share a common + // function library 'flib_def'. + absl::Status CreateGraphs( + const BuildGraphOptions& options, + std::unordered_map>* outputs, + std::unique_ptr* flib_def, + RunStateArgs* run_state_args, DataTypeVector* input_types, + DataTypeVector* output_types, int64_t* collective_graph_key); + + absl::Status RunInternal(int64_t step_id, const RunOptions& run_options, + CallFrameInterface* call_frame, + ExecutorsAndKeys* executors_and_keys, + RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options); + + // Returns whether inter-op execution uses a global pool or the input + // `run_options` requests being run on inter_op_thread_pool = 0 in case + // multiple pools are configured. + bool ShouldUseRunHandlerPool(const RunOptions& run_options) const; + + absl::Status ExtendLocked(GraphDef&& graph) + TF_EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_); + + absl::Status ResourceHandleToInputTensor(const Tensor& resource_tensor, + Tensor* retrieved_tensor); + + // Feeds more inputs to the executors, triggering further execution. 
+ absl::Status SendPRunInputs( + const std::vector>& inputs, + const ExecutorsAndKeys* executors_and_keys, + IntraProcessRendezvous* rendez); + + // Fetches more outputs from the executors. It waits until the output + // tensors are computed. + absl::Status RecvPRunOutputs(const std::vector& output_names, + const ExecutorsAndKeys* executors_and_keys, + PartialRunState* run_state, + std::vector* outputs); + + // Check if the specified fetches can be computed from the feeds + // that we have already provided. + absl::Status CheckFetch(const std::vector>& feeds, + const std::vector& fetches, + const ExecutorsAndKeys* executors_and_keys, + const PartialRunState* run_state); + + // Use the appropriate WaitForNotification function based on whether + // operation_timeout_in_ms is greater than 0. + // + // If the timeout expires, the `cm->StartCancel()` will be called. + absl::Status WaitForNotification(Notification* n, int64_t timeout_in_ms); + void WaitForNotification(Notification* n, RunState* run_state, + CancellationManager* cm, int64_t timeout_in_ms); + + absl::Status CheckNotClosed() { + mutex_lock l(closed_lock_); + if (closed_) return errors::Cancelled("Session has been closed."); + return absl::OkStatus(); + } + + absl::Status CheckGraphCreated(const char* method) { + mutex_lock l(graph_state_lock_); + if (!graph_created_) { + return errors::InvalidArgument( + "Session was not created with a graph before ", method, "!"); + } + return absl::OkStatus(); + } + + absl::Status CreateDebuggerState( + const CallableOptions& options, int64_t global_step, + int64_t session_run_index, int64_t executor_step_index, + std::unique_ptr* debugger_state); + + absl::Status DecorateAndPublishGraphForDebug( + const DebugOptions& debug_options, Graph* graph, Device* device); + + const SessionOptions options_; + + // Device structures. + const std::unique_ptr device_mgr_; + std::vector devices_; // not owned + DeviceSet device_set_; + + // Unique session identifier. + string session_handle_; + mutex graph_state_lock_; + bool graph_created_ TF_GUARDED_BY(graph_state_lock_) = false; + bool finalized_ TF_GUARDED_BY(graph_state_lock_) = false; + + // The thread-pools to use for running ops, with a bool indicating if the pool + // is owned. + std::vector> thread_pools_; + + absl::Status init_error_; // Set to an error if construction failed. + + // If true, blocks until device has finished all queued operations in a step. + bool sync_on_finish_ = true; + + std::vector> functions_ + TF_GUARDED_BY(executor_lock_); + + mutex executor_lock_; // protects executors_ + // Holds mappings from signature to the executors that process + // it. The reason for a level of indirection around mapped_type is + // to guarantee address stability. + // The map value is a shared_ptr since multiple map keys can point to the + // same ExecutorsAndKey object. + std::unordered_map> executors_ + TF_GUARDED_BY(executor_lock_); + + class RunCallableCallFrame; + struct Callable { + std::shared_ptr executors_and_keys; + std::shared_ptr function_info; + ~Callable(); + }; + mutex callables_lock_; + int64_t next_callable_handle_ TF_GUARDED_BY(callables_lock_) = 0; + std::unordered_map callables_ + TF_GUARDED_BY(callables_lock_); + + // Holds mappings from handle to partial run state. + std::unordered_map> partial_runs_ + TF_GUARDED_BY(executor_lock_); + + // This holds all the tensors that are currently alive in the session. 
+ SessionState session_state_; + + DirectSessionFactory* const factory_; // not owned + CancellationManager* cancellation_manager_; + std::unique_ptr collective_executor_mgr_; + + // Map of placed stateful nodes, i.e. nodes for which is_stateful() + // is true, such as "params" and "queue" nodes. Once placed these + // nodes can not be moved to a different device. Maps node names to + // device names. + std::unordered_map stateful_placements_ + TF_GUARDED_BY(graph_state_lock_); + + // Execution_state; used when placing the entire graph. + std::unique_ptr execution_state_ + TF_GUARDED_BY(graph_state_lock_); + + // The function library, before any rewrites or optimizations have been + // performed. In particular, CreateGraphs() may need to modify the function + // library; it copies and modifies the function library. + std::unique_ptr flib_def_; + + // true if the Session has been Closed. + mutex closed_lock_; + bool closed_ TF_GUARDED_BY(closed_lock_) = false; + + // For generating unique names for this session instance. + std::atomic edge_name_counter_ = {0}; + std::atomic handle_name_counter_ = {0}; + + // For generating step ids that are unique among all sessions. + static std::atomic_int_fast64_t step_id_counter_; + + // Global timeout for all blocking operations in this session. + const int64_t operation_timeout_in_ms_ = 0; + + // Manages all the cost models for the graphs executed in this session. + CostModelManager cost_model_manager_; + + // For testing collective graph key generation. + mutex collective_graph_key_lock_; + int64_t collective_graph_key_ TF_GUARDED_BY(collective_graph_key_lock_) = -1; + + // Run in caller's thread if RunOptions.inter_op_thread_pool is negative or + // all of following conditions are met: + // 1. This session doesn't own any thread pool. + // 2. RunOptions.inter_op_thread_pool is unspecified or 0. + // 3. This session has a single executor. + // 4. config.inter_op_parallelism_threads is specified to negative explicitly + // or through environment variable TF_NUM_INTEROP_THREADS. + // 5. RunOptions.experimental.use_run_handler_pool is unspecified or false. + // Otherwise run in global thread pool, session owned thread pool or handler + // pool according to other specifications of RunOptions and ConfigProto. + bool run_in_caller_thread_ = false; + + DirectSession(const DirectSession&) = delete; + void operator=(const DirectSession&) = delete; + + // EXPERIMENTAL: debugger (tfdbg) related + friend class DebugGateway; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/dma_helper.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/dma_helper.h new file mode 100644 index 00000000..4a76cff1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/dma_helper.h @@ -0,0 +1,38 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_ + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// For TensorFlow internal use only. +class DMAHelper { + public: + static bool CanUseDMA(const Tensor* t) { return t->CanUseDMA(); } + static const void* base(const Tensor* t) { return t->base(); } + static void* base(Tensor* t) { return t->base(); } + static TensorBuffer* buffer(Tensor* t) { return t->buf_; } + static const TensorBuffer* buffer(const Tensor* t) { return t->buf_; } + static void UnsafeSetShape(Tensor* t, const TensorShape& s) { + t->set_shape(s); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/attr_builder.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/attr_builder.h new file mode 100644 index 00000000..9dc480d8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/attr_builder.h @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ + +// Support for eager execution of TensorFlow kernels. + +#include +#include +#include + +#include "tensorflow/c/eager/abstract_op_attrs.h" +#include "tensorflow/c/tf_attrtype.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +namespace tensorflow { + +// Maps attribute name to an encoding of the type of the attribute value. +// If the type is not a list type, the value is the same as the TF_AttrType type +// of the value. Else, the highest order bit is on, and the rest of the bits +// represent the TF_AttrType type of the values in the list. +typedef std::unordered_map AttrTypeMap; + +// Look up OpDef for `op_name`. +absl::Status OpDefForOp(const string& op_name, const OpDef** op_def); + +// Returns the AttrTypeMap for the TensorFlow operation named op_name. +// If op_name is not registered in global op registry, AttrTypeMapForOp assumes +// the op to be a function and returns the default attributes for a function. +// `is_function` is set to true in this case. +absl::Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, + bool* is_function); + +// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. 
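DMAHelper above is the sanctioned path for transports to reach a Tensor's underlying buffer without going through the public Tensor API. A hedged usage sketch (it assumes a TensorFlow build that provides the headers in this diff; the free function name is invented):

    #include "tensorflow/core/common_runtime/dma_helper.h"
    #include "tensorflow/core/framework/tensor.h"

    // Returns the raw byte pointer when the tensor's element type is safe to
    // copy bit-for-bit (CanUseDMA), and nullptr otherwise.
    const void* RawBytesForTransport(const tensorflow::Tensor& t) {
      if (!tensorflow::DMAHelper::CanUseDMA(&t)) return nullptr;
      return tensorflow::DMAHelper::base(&t);
    }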
+absl::Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list); + +// KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. +// An AttrBuilder is a convenience class to help with that - providing a smaller +// interface than NodeDefBuilder and avoiding expensive (unnecessary?) sanity +// checks (like number of inputs matching the OpDef - we only care about +// attributes here). +// +// TODO(ashankar): Take a closer look at checks in NodeDefBuilder and see which +// ones make sense to replicate. + +// This is a helper class for creating a NodeDef. Additionally, this class +// allows computing a cache key based on fingerprinting the attributes of this +// NodeDef. +// +// Example usage: +// AttrBuilder a; +// a.NumInputs(2); +// a.Set("T", TF_FLOAT); +// tensorflow::Fprint128 cache_key = a.CacheKey("cpu:0"); +// const NodeDef& n = a.BuildNodeDef(); +// +// Calls to NumInputs or Set between multiple invocations to CacheKey may cause +// different values to be returned by CacheKey. +// +// If NumInputs or Set is called, BuildNodeDef should be called again to update +// the NodeDef. +// +// For performance reasons, the class internally delays the actual construction +// of the NodeDef till BuildNodeDef is called, or Set is called with certain +// uncommon types (see template specializations of Set to see which types +// trigger a NodeDef creation). +// +// Setting attributes via `Set` may cause arena-allocated protocol buffer +// messages to be destructed, which is not thread safe. This means that it is +// currently not safe to set attributes on *different* AttrBuilder objects from +// multiple threads. This does not apply to `CopyAttributes`. +class AttrBuilder : public AbstractOpAttrs { + public: + AttrBuilder() + : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) {} + + ~AttrBuilder() override = default; + explicit AttrBuilder(const char* op) + : AbstractOpAttrs(AbstractOpAttrs::AbstractOpAttrsKind::kEager) { + Reset(op); + } + + void Reset(const char* op) { + op_name_ = op; + num_inputs_ = 0; + encoded_attrs_.clear(); + node_def_finalized_ = false; + cached_cache_key_ = std::nullopt; + device_for_cached_cache_key_.clear(); + } + + const string& op_name() const { return op_name_; } + void set_op_name(const string& name) { op_name_ = name; } + + // Needed to work around call to ValidateNodeDef in CreateOpKernel. + AttrBuilder& NumInputs(int n); + + template + AttrBuilder& Set(absl::string_view attr_name, T&& value) { + SetAttrValue(value, &attr_tmp_); + AddAttrIfNotPresent(attr_name, attr_tmp_); + node_def_finalized_ = false; + cached_cache_key_ = std::nullopt; + return *this; + } + + size_t NumAttributes() const { return encoded_attrs_.size(); } + + AttrBuilder& Set(absl::string_view attr_name, const AttrValue& value) { + AddAttrIfNotPresent(attr_name, value); + cached_cache_key_ = std::nullopt; + return *this; + } + + // Retrieves the attribute value. + // Note that Get() can involve a linear scan of all attributes with the same + // value type in this Node. This is not an issue, because Get is used rarely + // and nodes have a small number of attributes. + template + absl::Status Get(absl::string_view attr_name, T* value) const { + // Common attributes are stored in AttrVecs. This Get() template + // is specialized for them below. If we end up here, the type must be + // among those that we store in the node_def_. 
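Extending the usage example in the comment above: because the cache key is a fingerprint of the currently set attributes, any later `Set` or `NumInputs` call invalidates a previously computed key. A hedged sketch (assumes a TensorFlow build; the op name, attribute values and device string are illustrative only):

    #include "tensorflow/core/common_runtime/eager/attr_builder.h"

    void CacheKeySketch() {
      tensorflow::AttrBuilder a("MatMul");
      a.NumInputs(2);
      a.Set("T", tensorflow::DT_FLOAT);
      tensorflow::Fprint128 key_before = a.CacheKey("/device:CPU:0");
      a.Set("transpose_a", true);  // mutates the attribute set
      tensorflow::Fprint128 key_after = a.CacheKey("/device:CPU:0");
      // key_before and key_after generally differ; only key_after matches the
      // NodeDef produced by a.BuildNodeDef() from here on.
      (void)key_before;
      (void)key_after;
    }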
+ if (!node_def_finalized_) { + return errors::NotFound("No attr named'", attr_name, + "' found in AttrBuilder for ", op_name_); + } + return GetNodeAttr(AttrSlice(node_def_), attr_name, value); + } + + tensorflow::Fprint128 CacheKey(absl::string_view device); + + // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as + // well as any default attr-value pairs from the associated op_def, if there + // is one. + void FillAttrValueMap(AttrValueMap* m) const; + + // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far except + // when the value matches the default for this attr. + // More precisely, if the global op registry contains an OpDef for this op + // and if an attribute value is the same as the default (according to the + // OpDef), this attr-value pair is not added to `m`. + void FillAttrValueMapWithoutDefaults(AttrValueMap* m) const; + const NodeDef& BuildNodeDef(); + + // Transfers the attributes from `other` to this AttrBuilder. Does not + // overwrite existing attributes. Since it does not require deserializing and + // re-serializing attributes, it is much more efficient than going through an + // AttrValueMap. + void CopyAttributes(const AttrBuilder& other); + + void GetNameAttrList(tensorflow::NameAttrList* name_and_attrs) const override; + + bool GetInt(absl::string_view attr_name, int64_t* result) const override; + bool GetFloat(absl::string_view attr_name, float* result) const override; + bool GetBool(absl::string_view attr_name, bool* result) const override; + bool GetType(absl::string_view attr_name, + tensorflow::DataType* result) const override; + absl::Status GetTypeList( + absl::string_view attr_name, + absl::InlinedVector* type_list) const override; + + private: + tensorflow::Fprint128 BuildCacheKeyForDevice(absl::string_view device) const; + + template + void SetInAttrValueMap(AttrValueMap* m, const string& attr_name, + T&& value) const { + DCHECK(!node_def_finalized_) + << "Calling SetInAttrValueMap after BuildNodeDef."; + // If attribute is set more than once, its first value prevails + m->insert({attr_name, value}); + } + + void AddAttrIfNotPresent(absl::string_view attr_name, const AttrValue& value); + + gtl::FlatMap encoded_attrs_; + mutable AttrValue attr_tmp_; // For encoding + + string op_name_; + int num_inputs_; + NodeDef node_def_; + bool node_def_initialized_; + bool node_def_finalized_; + + std::optional cached_cache_key_; + string device_for_cached_cache_key_; +}; + +template <> +absl::Status AttrBuilder::Get(absl::string_view attr_name, int* value) const; +template <> +absl::Status AttrBuilder::Get(absl::string_view attr_name, float* value) const; +template <> +absl::Status AttrBuilder::Get(absl::string_view attr_name, bool* value) const; +template <> +absl::Status AttrBuilder::Get(absl::string_view attr_name, + tensorflow::DataType* value) const; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/context.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/context.h new file mode 100644 index 00000000..8440e298 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/context.h @@ -0,0 +1,968 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/common_runtime/composite_device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/rendezvous_cache.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/log_memory.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/random.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/refcount.h" + +// "tensorflow/core/platform/platform.h" must be included first before using +// IS_MOBILE_PLATFORM. +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#endif // !IS_MOBILE_PLATFORM + +namespace tensorflow { + +namespace eager { +// We need this forward declaration because we have circular dependency: +// Context -> RemoteMgr -> TensorHandle -> Context. +// TODO(fishx): Remove this once we remove Context dependency in TensorHandle. +class RemoteMgr; +} // namespace eager + +// Check the value of the environment variable, +// `TF_REMOTE_HANDLE_SKIP_WAIT_FOR_READY` from its cached copy in memory and if +// not cached, reads from the environment variable. 
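The `SkipRemoteHandleWaitReady` helper declared just below reads `TF_REMOTE_HANDLE_SKIP_WAIT_FOR_READY` once and then serves a cached copy. A self-contained plain-C++ illustration of that read-once pattern (the accepted values "1"/"true" are an assumption, not taken from this header):

    #include <cstdlib>
    #include <string>

    bool SkipWaitReadyFromEnvOnce() {
      static const bool cached = [] {
        const char* v = std::getenv("TF_REMOTE_HANDLE_SKIP_WAIT_FOR_READY");
        return v != nullptr &&
               (std::string(v) == "1" || std::string(v) == "true");
      }();
      return cached;  // subsequent calls never touch the environment again
    }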
+bool SkipRemoteHandleWaitReady(); + +class EagerContext : public ImmediateExecutionContext, public core::RefCounted { + public: + static constexpr uint64 kInvalidContextId = 0; + + static uint64 NewContextId() { + uint64 context_id = random::New64(); + while (context_id == kInvalidContextId) { + context_id = random::New64(); + } + return context_id; + } + + EagerContext( + const SessionOptions& opts, + ContextDevicePlacementPolicy default_device_placement_policy, bool async, + /*const*/ DeviceMgr* device_mgr, bool device_mgr_owned, + /*const*/ tsl::core::RefCountPtr rendezvous, + DistributedFunctionLibraryRuntime* cluster_flr = nullptr, + CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr, + bool run_eager_op_as_function = false, bool jit_compile_rewrite = false); + + void Release() override { Unref(); } + + AbstractTensorInterface* CreateInt64Scalar(int64_t value) override; + AbstractTensorInterface* CreateUint64Scalar(uint64 value) override; + AbstractTensorInterface* CreateInt32Scalar(int32_t value) override; + AbstractTensorInterface* CreateFloatScalar(float value) override; + AbstractTensorInterface* CreateDoubleScalar(double value) override; + AbstractTensorInterface* CreateHalfScalar(Eigen::half value) override; + AbstractTensorInterface* CreateStringScalar( + tensorflow::tstring value) override; + AbstractTensorInterface* CreateComplex128Scalar( + tensorflow::complex128 value) override; + AbstractTensorInterface* CreateBoolScalar(bool value) override; + + AbstractTensorInterface* CreateTensor( + DataType dtype, absl::Span dim_sizes) override; + AbstractTensorInterface* CreateTensor(DataType dtype, const int64_t* dims, + int num_dims, void* data, size_t len, + MemoryReleaser memory_releaser, + void* memory_releaser_arg) override; + + ImmediateExecutionTensorHandle* CreateLocalHandle( + AbstractTensorInterface* t) override; + // Create an abstract tensor handle from tensorflow::Tensor. + ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor( + tensorflow::Tensor& t, const char* d_name) override; + ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( + ImmediateExecutionTensorHandle* handle, const char* device_name, + absl::Status* status) override; + ImmediateExecutionOperation* CreateOperation() override; + + // This is a virtual helper function to convert TFRT TensorHandle to + // tensorflow::TensorHandle. In current runtime EagerContext, just forward + // the input since the input tensor handle is already a + // tensorflow::TensorHandle. + ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( + ImmediateExecutionTensorHandle* handle) override; + + absl::Status RegisterFunction(AbstractFunction* f) override; + + bool UsesTFRT() override; + + bool RunEagerOpAsFunction() const; + + void SetRunEagerOpAsFunction(bool enable) override; + + bool JitCompileRewrite() const; + + void SetJitCompileRewrite(bool enable) override; + + void ListDevices(std::vector* device_attributes) override; + + absl::Status AddDevices( + std::vector> devices) override; + + thread::ThreadPool* GetThreadPool() { return thread_pool_.get(); } + + // Returns the function library runtime for the given device. + FunctionLibraryRuntime* func_lib(const Device* d) const { + return pflr_->GetFLR(d->name()); + } + + ProcessFunctionLibraryRuntime* pflr() const { return pflr_.get(); } + + std::function)>* runner() { return &runner_; } + + // Specify a executor for this thread. 
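The `Create*Scalar` and `CreateLocalHandle` factories above are the lowest-level way to materialize an eager tensor handle. A hedged sketch (assumes a fully constructed `EagerContext*`; whether the handle takes ownership of the intermediate tensor interface is not specified here):

    tensorflow::ImmediateExecutionTensorHandle* MakeScalarHandle(
        tensorflow::EagerContext* ctx) {
      tensorflow::AbstractTensorInterface* scalar = ctx->CreateFloatScalar(1.0f);
      return ctx->CreateLocalHandle(scalar);
    }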
+ void SetExecutorForThread(EagerExecutor* executor) override; + + std::shared_ptr> prioritized_device_type_list() + const { + mutex_lock l(device_type_list_mu_); + return prioritized_device_type_list_; + } + + // Clear pending nodes in thread executors and kernel caches. + void ClearCachesAndThreadExecutors() override; + // Clear pending nodes in default executor and kernel caches. + void ClearCachesAndDefaultExecutor(); + + // Sets the device placement policy for the current thread. + void SetThreadLocalDevicePlacementPolicy( + ContextDevicePlacementPolicy policy) override; + + // Returns the device placement policy for the current thread. + ContextDevicePlacementPolicy GetDevicePlacementPolicy() const override; + + // Select an appropriate device for an operation. + // + // Given the preferred device for the operation, and the node_def, finds the + // best suitable device for the operation in this context. + // + // The preferred device is specified as a `ParsedName` containing the elements + // (details) that the resulting device should match. If there are no such + // devices, and the context currently allows soft device placement, a suitable + // device not matching `preferred` will be chosen. + // + // The chosen device is stored in the `device` argument. The argument is not + // modified unless this method returns `OkStatus()`. + absl::Status SelectDevice(DeviceNameUtils::ParsedName preferred, + const NodeDef& ndef, Device** out) const; + + // TODO(mdan): Rename to ContainsFunction. + bool FindFunctionByName(const string& name) const; + + absl::Status FindFunctionOpData( + const string& name, const tensorflow::OpRegistrationData** op_data); + + const FunctionDef* FindFunctionDef(const string& name) const override; + core::RefCountPtr FindRecord( + const string& name) const override; + + Device* HostCPU() const { return host_cpu_device_; } + Device* CanonicalDevice(Device* d) const { + return HostCPU() == d ? nullptr : d; + } + const DeviceNameUtils::ParsedName& HostCPUParsedName() const override { + return HostCPU()->parsed_name(); + } + + const string& HostCPUName() const override { return HostCPU()->name(); } + + GraphCollector* GetGraphCollector() { return &graph_collector_; } + + EagerExecutor& Executor() override; + + // Add the given `fdef` to the local FunctionLibraryDefinition. And add an + // entry to the KernelAndDevice cache for it if it's not exist. + absl::Status AddFunctionDef(const FunctionDef& fdef) override; + + absl::Status AddFunctionDefWithStackTraces( + const FunctionDef& fdef, const StackTracesMap& stack_traces) override; + + // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add + // it to the local FunctionLibraryDefinition as well, but no need to add it + // to the KernelAndDevice cache since they won't be executed as + // KernelAndDevices. + absl::Status AddFunctionDef(const FunctionDef& fdef, + const FunctionDefLibrary& library, + bool add_to_local_only = false, + const StackTracesMap& stack_traces = {}); + + // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add + // it to the local FunctionLibraryDefinition as well, but no need to add it + // to the KernelAndDevice cache since they won't be executed as + // KernelAndDevices. + absl::Status AddFunctionRecord(core::RefCountPtr func_record, + const FunctionDefLibrary& library, + bool add_to_local_only = false); + + // Adds a component function (i.e. containing a subgraph of a multi-process + // function) implemented as `fdef`. 
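`SelectDevice` above only writes its output on success and falls back to soft placement when nothing matches the preference. A hedged caller sketch (assumes a TensorFlow build; the device string and function name are illustrative):

    absl::Status PickDeviceSketch(const tensorflow::EagerContext& ctx,
                                  const tensorflow::NodeDef& ndef,
                                  tensorflow::Device** out) {
      tensorflow::DeviceNameUtils::ParsedName preferred;
      if (!tensorflow::DeviceNameUtils::ParseFullName("/device:CPU:0",
                                                      &preferred)) {
        return tensorflow::errors::InvalidArgument("unparsable device string");
      }
      // `*out` is left untouched unless this returns OkStatus().
      return ctx.SelectDevice(preferred, ndef, out);
    }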
+ // + // REQUIRES: `library` must contain all functions reachable from `fdef`. It + // should not contain `fdef` itself. + absl::Status AddComponentFunction(const FunctionDef& fdef, + const FunctionDefLibrary& library); + + const FunctionDef* GetFunctionDef(const string& function_name); + + std::vector ListFunctionNames() override; + tensorflow::ImmediateExecutionContext::CacheStats GetCacheStats() override; + + absl::Status RemoveFunction(const string& func) override; + absl::Status AddRemoveFunctionNotifier( + const string& func, std::function notifier) override; + + // Wait for pending nodes to be finished in local executors (including context + // default executor and thread executors) and executors on remote workers. + // Return combined status of remote executors. If there are multiple errors, + // the Status code will be the same as the first remote executor that has + // errors, and the error message will be combined from all executors. + absl::Status SyncExecutors(); + + absl::Status AsyncWait() override { return SyncExecutors(); } + + core::RefCountPtr GetCachedKernel(Fprint128 cache_key); + Device* GetCachedDevice(Fprint128 device_cache_key); + + core::RefCountPtr AddKernelToCache( + Fprint128 cache_key, core::RefCountPtr kernel); + void AddDeviceToCache(Fprint128 device_cache_key, Device* device); + + bool LogDevicePlacement() const { return log_device_placement_; } + void SetLogDevicePlacement(bool enable) override { + log_device_placement_ = enable; + } + + bool AllowSoftPlacement() const { return allow_soft_placement_; } + void SetAllowSoftPlacement(bool enable) override { + allow_soft_placement_ = enable; + } + bool LogMemory() const { return log_memory_; } + + // Returns a borrowed pointer to the global rendezvous. The rendezvous may + // become invalid if this Context is destroyed. + Rendezvous* GetRendezvous() const { return rendezvous_.get(); } + + void ResetGlobalRendezvousForFunction() override { + mutex_lock l(global_rendezvous_mu_); + // Remove the global rendezvous instance from the local rendezvous table + // if it uses local rendezvous type, which forces EagerContext to create a + // new local rendezvous instance in the table. + // TODO(b/274683676) Why can't we abort the old rendezvous here? + local_rendezvous_cache_.Remove(-1); + TF_CHECK_OK(CreateRendezvousFactory()(-1, nullptr, + &global_rendezvous_for_functions_)); + } + + // Returns the global_rendezvous_for_functions' underlying LocalRendezvous' + // status. If the underlying Rendezvous is not in the local_rendezvous_cache_ + // returns OK. + absl::Status GetGlobalRendezvousForFunctionLocalRendezvousStatus(); + + // Returns a factory which maps from step_id to rendezvous. + // + // When tensor transfer across functions/eager executions using send/recv ops + // are required, `reuse_rendezvous_for_functions` can be set to true so that + // function executions and eager executions use the same rendezvous instance, + // instead of creating new instance per function calls. + // + // The caller of the returned function owns a reference to the resulting + // Rendezvous. + Rendezvous::Factory RendezvousFactory( + bool reuse_rendezvous_for_functions = false) { + // There is an implicit assumption that the global_rendezvous_for_functions_ + // is always an IntraProcessRendezvous to match the behaviour of the + // EagerContext's rendezvous. + // Ref: tensorflow/c/eager/c_api.cc;l=143;rcl=396387348 + // If a cross process kernel needs a rendezvous a new InterProcessRendezvous + // should be created. 
+ if (reuse_rendezvous_for_functions && rendezvous_creator_ == nullptr && +#if !defined(IS_MOBILE_PLATFORM) + worker_env_ == nullptr && +#endif + remote_device_mgr() == nullptr) { + return Rendezvous::Factory{[this](const int64_t step_id, + const DeviceMgr* device_mgr, + tsl::core::RefCountPtr* r) { + mutex_lock l(global_rendezvous_mu_); + *r = global_rendezvous_for_functions_.GetNewRef(); + return absl::OkStatus(); + }}; + } else { + return CreateRendezvousFactory(); + } + } + + CollectiveExecutorMgrInterface* collective_executor_mgr() { + return collective_executor_mgr_.Get(); + } + std::unique_ptr GetCollectiveExecutorHandle() { + return std::make_unique( + + collective_executor_mgr()->FindOrCreate(0), true /*inherit_ref*/); + } + + void SetCollectiveExecutorMgr(CollectiveExecutorMgrInterface* mgr) { + collective_executor_mgr_.Reset(mgr); + } + tensorflow::DeviceMgr* local_device_mgr() const { + return local_device_manager_.Get(); + } + const tensorflow::DynamicDeviceMgr* remote_device_mgr() const { + return remote_device_manager_.Get(); + } + + tensorflow::DynamicDeviceMgr* GetOwnedRemoteDeviceMgr() { + return remote_device_manager_.GetOwned(); + } + + std::vector ListLocalTfDevices() override { + return local_device_mgr()->ListDevices(); + } + + std::vector ListAllTfDevices() override; + + // TODO(apassos) clean up RunMetadata storage. + mutex* MetadataMu() TF_LOCK_RETURNED(metadata_mu_) { return &metadata_mu_; } + bool ShouldStoreGraphs() TF_LOCKS_EXCLUDED(metadata_mu_); + void SetShouldStoreGraphs(bool value) override; + RunMetadata* RunMetadataProto() TF_EXCLUSIVE_LOCKS_REQUIRED(metadata_mu_) { + return run_metadata_.get(); + } + std::unique_ptr ExportRunMetadata() override + TF_LOCKS_EXCLUDED(metadata_mu_); + + void StartStep() override; + void EndStep() override; + ScopedStepContainer* StepContainer(); + + FunctionLibraryDefinition* FuncLibDef() override { return &func_lib_def_; } + + FunctionLibraryDefinition* GetComponentFunctionFunctionLibraryDefinition( + const string& function_name) { + tf_shared_lock lock(cache_mu_); + auto iter = component_function_libraries_.find(function_name); + if (iter != component_function_libraries_.end()) { + return iter->second.get(); + } + return nullptr; + } + +#if !defined(IS_MOBILE_PLATFORM) + // Assign the EagerClient pointer to `client` based on the given device / task + // name, and increment the refcount of the client. The reference ownership is + // transferred to the caller, and the unref should automatically happen when + // destructing the RefCountPtr object at the caller's side. + // `client` must not be initialized or holding a reference of another object + // before calling this method. + absl::Status GetClient(Device* device, + core::RefCountPtr* client); + absl::Status GetClient(const DeviceNameUtils::ParsedName& device_name, + core::RefCountPtr* client); + absl::Status GetClient(const string& remote_task, + core::RefCountPtr* client); + + uint64 GetContextId() const; + uint64 GetContextViewId() const; + void IncrementContextViewId(); + + absl::Status EnableCollectiveOps(const ServerDef& server_def) override; + + // TODO(nareshmodi): Encapsulate remote state into a separate + // class/struct. + // + // Enables the eager context to communicate with remote devices. When + // initializing with this method, this context will be the primary context, + // which will kill all its remote contexts in shutdown. + // + // - server: A ServerInterface that exports the tensorflow.WorkerService. 
+ // Note that this class expects the server to already have been started. + // - remote_eager_workers: A cache from which we can get "EagerClient"s to + // communicate with remote eager services. + // - remote_device_mgr: A DeviceMgr* which contains all remote devices + // (should contain no local devices). + // - remote_contexts: A vector containing task names. + // TODO(b/184375824): clean up parameter order for better readability. + absl::Status InitializeRemoteMaster( + std::unique_ptr server, WorkerEnv* worker_env, + std::shared_ptr worker_session, + std::unique_ptr remote_eager_workers, + std::unique_ptr remote_device_manager, + const std::vector& remote_contexts, uint64 context_id, + tsl::core::RefCountPtr r, + /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs, + DistributedFunctionLibraryRuntime* cluster_flr, + std::unique_ptr> + remote_mgr); + + // Update an existing master context with a new set of remote workers (i.e., a + // new "view" of cluster membership. Similar to InitializeRemoteMaster but + // this will keep the current context_id and increment a context_view_id, will + // keep the current resource manager so that resources from the previous view + // can still be accessed, and will automatically register existing functions + // if there are newly added hosts. + absl::Status UpdateRemoteMaster( + uint64 context_id, + std::unique_ptr remote_eager_workers, + const std::vector& add_remote_contexts, + const std::vector& remove_remote_contexts); + + // Similar with InitializeRemoteMaster but this context will not kill remote + // contexts in shutdown. + absl::Status InitializeRemoteWorker( + std::unique_ptr remote_eager_workers, + DynamicDeviceMgr* remote_device_mgr, + const std::vector& remote_contexts, uint64 context_id, + uint64 context_view_id, + std::function(const int64_t)> + rendezvous_creator, + DistributedFunctionLibraryRuntime* cluster_flr, + std::unique_ptr> + remote_mgr, + std::function resource_deallocator); + + // Similar with InitializeRemoteWorker but will reuse existing context and + // increment context_view_id. + absl::Status UpdateRemoteWorker( + std::unique_ptr remote_eager_workers, + const std::vector& remote_contexts, uint64 context_id); + + absl::Status StoreCollectiveOpsServer( + std::unique_ptr new_server, DeviceMgr* device_mgr, + CollectiveExecutorMgrInterface* rpc_collective_executor_mgr); + + // For the specified remote worker, preprocess and set its device filters. + absl::Status SetRemoteDeviceFilters( + const string& remote_worker, const std::vector& device_filters); + + // For the specified remote worker, apply the stored device filters to the + // list of device attributes following these rules: + // (1) if the remote worker does not have device filters, all devices are + // visible to the worker; + // (2) if the device is on the remote worker, then it is visible; + // (3) if the device matches at least one device filter, then it is visible. + // The result is saved as a boolean vector of the same length (i.e., + // filtered_device_mask) indicating whether each of the devices is visible to + // the remote worker. + void FilterDevicesForRemoteWorkers( + const string& remote_worker, + const protobuf::RepeatedPtrField& device_attrs, + std::vector* filtered_device_mask); + + // TODO(fishx): Remove the custom deleter once we remove forward declaration. + const std::unique_ptr>& + RemoteMgr() { + return remote_mgr_; + } + + // If true, then tensors should be shipped across processes via the + // EagerService.Enqueue(SendTensorOp). 
If false, _Send/_Recv ops should be + // used instead (which in-turn use WorkerService.RecvTensor RPCs). + bool UseSendTensorRPC() { return use_send_tensor_rpc_; } + + tensorflow::ServerInterface* GetServer() { return server_.get(); } + + // For LLVM style RTTI. + static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kEager; + } + + // Function to support distributed C API. + void SetDistributedManager( + std::unique_ptr distributed) + override { + distributed_manager_ = std::move(distributed); + } + ImmediateExecutionDistributedManager* GetDistributedManager() override { + return distributed_manager_.get(); + } + + // May only be used during multi-client setup so that a RemoteRendezvous + // can be initialized instead of defaulting to the IntraProcessRendezvous. + void SetWorkerEnv(WorkerEnv* worker_env, + std::shared_ptr worker_session); +#endif // IS_MOBILE_PLATFORM + + // Closes remote eager contexts, waits for all RPCs to finish, and + // destroys the EagerClientCache. No RPCs can be made through this context + // after this method has been called. + // This method exists to aid a clean shutdown. It causes all RPCs to finish + // and remote TensorHandles to release their references to this context. + // To avoid deadlocks, this method must not be called on the thread + // processing RPCs because it makes RPCs and waits for their completion. + // + // On mobile, it just cleans the caches. + void WaitForAndCloseRemoteContexts(); + + bool PinSmallOpsToCPU() const { return pin_small_ops_to_cpu_; } + + tensorflow::Env* TFEnv() const { return env_; } + + absl::Status FindDeviceFromName(const char* device_name, + Device** device) const; + + absl::Status FindCompositeDeviceFromName(absl::string_view device_name, + CompositeDevice** device) const; + + bool IsCustomDevice(const string& device_name) override; + + absl::Status RegisterCustomDevice( + const string& name, std::unique_ptr device) override; + + CustomDeviceOpHandler& GetCustomDeviceOpHandler() override { + return custom_device_op_handler_; + }; + + // Find or create a composite device with the given `underlying_devices` and + // `device_name` (if not empty). + absl::Status FindOrCreateCompositeDevice( + const std::vector& underlying_devices, const string& device_name, + CompositeDevice** composite_device); + + bool OnSameTask(const Device* first, const Device* second) const; + // Gets the CPU device on the task of device. + absl::Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; + + const SessionOptions& session_options() const { return opts_; } + void InitPrioritizedDeviceTypeList(); + + // Re-assign cluster-FLR and re-initialize devices and FLR in process-FLR + void UpdateClusterFLRAndInitDevices( + DistributedFunctionLibraryRuntime* cluster_flr); + + // A constant representing the step id used for the global rendezvous. + // This is used to distibguish whether a user-specified step id should be set. + // Step id value of kGlobalRendezvous is reserved and should not be specified + // by the user. + static const int64_t kGlobalRendezvousId; + + private: + // The class for caching Rendezvous instances per step_id. + // If the Rendezvous object is destroyed for the step, a new one will be + // created on demand. 
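`RendezvousFactory()` earlier in this class hands back a callable that maps a step id to a rendezvous, with the `LocalRendezvousCache` described above backing the purely local case. A hedged sketch of driving it (assumes a TensorFlow build, and assumes the factory object is invocable with the same `(step_id, device_mgr, &r)` shape as the lambdas it wraps in this header):

    absl::Status StepRendezvousSketch(tensorflow::EagerContext* ctx,
                                      int64_t step_id) {
      tensorflow::Rendezvous::Factory factory =
          ctx->RendezvousFactory(/*reuse_rendezvous_for_functions=*/true);
      tsl::core::RefCountPtr<tensorflow::Rendezvous> r;
      absl::Status s = factory(step_id, ctx->local_device_mgr(), &r);
      if (!s.ok()) return s;
      // `r` holds one reference and releases it when it goes out of scope.
      return absl::OkStatus();
    }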
+ class LocalRendezvousCache { + public: + LocalRendezvousCache() + : cache_(new RendezvousCache) {} + + tsl::core::RefCountPtr FindOrCreate( + int64_t step_id, DeviceMgr* device_mgr); + + tsl::core::RefCountPtr Find(int64_t step_id) const { + return cache_->Find(step_id); + } + + std::vector GetActiveStepIds() const { + return cache_->GetActiveStepIds(); + } + + void Remove(int64_t step_id) { cache_->Remove(step_id); } + + private: + tsl::core::RefCountPtr> cache_; + }; + + Rendezvous::Factory CreateRendezvousFactory() { + if (rendezvous_creator_ != nullptr) { + return Rendezvous::Factory{[this](const int64_t step_id, + const DeviceMgr* device_mgr, + tsl::core::RefCountPtr* r) { + VLOG(6) << "Creating rendezvous using the rendezvous_creator_."; + *r = rendezvous_creator_(step_id); + return absl::OkStatus(); + }}; + } + +#if !defined(IS_MOBILE_PLATFORM) + if (worker_env_ != nullptr && worker_env_->rendezvous_mgr != nullptr) { + return Rendezvous::Factory{[this](const int64_t step_id, + const DeviceMgr* device_mgr, + tsl::core::RefCountPtr* r) { + VLOG(6) << "Creating rendezvous using the worker_env's rendezvous_mgr."; + // TODO(hhb): Add a Create method and use it here. + auto remote_r = worker_env_->rendezvous_mgr->Find(step_id); + remote_r->Initialize(worker_session_.get()).IgnoreError(); + *r = std::move(remote_r); + return absl::OkStatus(); + }}; + } +#endif + + if (remote_device_mgr() == nullptr) { + return Rendezvous::Factory{[this](const int64_t step_id, + const DeviceMgr* device_mgr, + tsl::core::RefCountPtr* r) { + VLOG(6) << "Creating rendezvous using local_device_mgr."; + *r = local_rendezvous_cache_.FindOrCreate(step_id, local_device_mgr()); + return absl::OkStatus(); + }}; + } + + return Rendezvous::Factory(); + } + + ~EagerContext() override; + + absl::Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); + absl::Status MaybeRemoveFunctionRemotely(const string& function_name); + absl::Status RegisterExistingFunctionsOnRemoteWorkers( + const std::vector& remote_workers); + + void ResetPFLR(const DeviceMgr* device_mgr, Env* env, + const ConfigProto* config, int graph_def_version, + const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + thread::ThreadPool* thread_pool = nullptr, + DistributedFunctionLibraryRuntime* cluster_flr = nullptr); + + void ResetClusterFLR(DistributedFunctionLibraryRuntime* cluster_flr); + void UpdateGlobalRendezvousDeviceManager(tensorflow::DeviceMgr* device_mgr); + + void ClearResourceContainer(const string& name); + + template + struct OwnedOrUnownedHelper { + public: + OwnedOrUnownedHelper() = default; + explicit OwnedOrUnownedHelper(T* object, const bool owned = false) { + Reset(object, owned); + } + + void Reset(std::unique_ptr object) { + owned_object = std::move(object); + unowned_object_ptr = nullptr; + } + + void Reset(T* object, const bool owned = false) { + if (owned) { + owned_object.reset(object); + unowned_object_ptr = nullptr; + } else { + owned_object.reset(nullptr); + unowned_object_ptr = object; + } + } + + bool Owned() const { return owned_object != nullptr; } + + T* GetOwned() const { return owned_object.get(); } + T* Get() const { + return owned_object ? 
owned_object.get() : unowned_object_ptr; + } + + std::unique_ptr owned_object = nullptr; + T* unowned_object_ptr = nullptr; + }; + + SessionOptions opts_; + const ContextDevicePlacementPolicy default_device_placement_policy_; + + // Note: we cannot use C++11 thread_local here as there is no concept of a + // thread-local-object-local variable in C++11. + mutable mutex policy_map_mu_; + std::unordered_map + device_placement_policy_ TF_GUARDED_BY(policy_map_mu_); + + // This device manager maintains only the local devices on this worker. + OwnedOrUnownedHelper local_device_manager_; + // Maintain copy of all previously created local device managers. + std::vector> old_local_device_managers_; + + // Unowned DynamicDeviceMgr is set on remote worker to allow running + // multi-device function on remote worker. + // This device manager maintains all the devices (including both local and + // remote to this worker) in the cluster. + OwnedOrUnownedHelper remote_device_manager_; + + Device* host_cpu_device_; // Owned by device_manager + mutable mutex device_type_list_mu_; + std::shared_ptr> prioritized_device_type_list_ + TF_GUARDED_BY(device_type_list_mu_); + tsl::core::RefCountPtr rendezvous_; + std::function(const int64_t)> + rendezvous_creator_; + CustomDeviceOpHandler custom_device_op_handler_; + + mutable mutex composite_devices_mu_; + // Maps from the fingerprint of a set of device names to a virtual + // CompositeDevice. + // TODO(b/145922293): Consider taking device names as keys. + absl::flat_hash_map> + composite_devices_ ABSL_GUARDED_BY(composite_devices_mu_); + + FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), + FunctionDefLibrary()}; + + std::unique_ptr thread_pool_; + + // EagerContext owns the DistributedFunctionLibraryRuntime( + // EagerClusterFunctionLibraryRuntime) if using EagerService for remote + // function execution (lazy_copy_function_remote_inputs_=true). + OwnedOrUnownedHelper cluster_flr_; + // One FunctionLibraryRuntime per device. + // func_libs[i] is the FunctionLibraryRuntime corresponding to + // session->devices[i]. + std::unique_ptr pflr_; + + std::function)> runner_; + + mutex cache_mu_; + mutex device_cache_mu_; + mutex remove_function_notifiers_mu_; + struct RegisteredFunction : public core::RefCounted { + ~RegisteredFunction() override = default; + + std::unique_ptr> cached_kernel_keys; + }; + std::unordered_map, + Fprint128Hasher> + kernel_cache_ TF_GUARDED_BY(cache_mu_); + std::unordered_map registered_functions_ + TF_GUARDED_BY(cache_mu_); + + std::unordered_map> + component_function_libraries_ TF_GUARDED_BY(cache_mu_); + absl::flat_hash_map device_cache_ + TF_GUARDED_BY(device_cache_mu_); + std::unordered_map>> + remove_function_notifiers_ TF_GUARDED_BY(remove_function_notifiers_mu_); + + // Whether we should compute RunMetadata. + std::atomic should_store_graphs_{false}; + mutex metadata_mu_; + std::unique_ptr run_metadata_ TF_GUARDED_BY(metadata_mu_); + GraphCollector graph_collector_; + std::atomic log_device_placement_; + std::atomic allow_soft_placement_; + + // Information related to step containers. + std::atomic num_active_steps_; + std::unique_ptr step_container_ + TF_GUARDED_BY(metadata_mu_); + + EagerExecutor default_executor_; + mutable mutex executor_map_mu_; + // Not owned. + std::unordered_map thread_local_executor_ + TF_GUARDED_BY(executor_map_mu_); + std::unordered_map> + has_cleanup_ TF_GUARDED_BY(executor_map_mu_); + + const bool log_memory_; + + // The table of local rendezvous instances for intra-process communication. 
+ // This make sures only one local rendezvous instance exists per step id. + LocalRendezvousCache local_rendezvous_cache_; + + // Whether to use same rendezvous instance across function/eager executions. + std::atomic reuse_rendezvous_for_functions_{false}; + mutable mutex global_rendezvous_mu_; + + // Keeps alive the global rendezvous object. + core::RefCountPtr global_rendezvous_for_functions_ + TF_GUARDED_BY(global_rendezvous_mu_); + + Env* const env_; + + OwnedOrUnownedHelper collective_executor_mgr_; + +#if !defined(IS_MOBILE_PLATFORM) + std::vector GetRemoteContexts() TF_LOCKS_EXCLUDED(remote_state_mu_); + bool IsRemoteContextsEmpty() TF_LOCKS_EXCLUDED(remote_state_mu_); + void CloseAndClearAllRemoteContexts(); + void CloseRemoteContexts(const std::vector& remote_contexts, + uint64 context_id, uint64 context_view_id); + + // TODO(b/184375824): clean up parameter order for better readability. + absl::Status SetMasterContextState( + std::unique_ptr server, WorkerEnv* worker_env, + std::shared_ptr worker_session, + std::unique_ptr remote_eager_workers, + std::unique_ptr remote_device_manager, + uint64 context_id, uint64 context_view_id, + tsl::core::RefCountPtr r, + /*const*/ DeviceMgr* local_device_mgr, int keep_alive_secs, + DistributedFunctionLibraryRuntime* cluster_flr, + std::unique_ptr> + remote_mgr); + + // The server_ is not const since we release it when the context is destroyed. + // Therefore the server_ object is not marked as const (even though it should + // be). + std::unique_ptr server_; + WorkerEnv* worker_env_ = nullptr; + std::shared_ptr worker_session_; + + mutable mutex remote_state_mu_; + + uint64 context_id_ TF_GUARDED_BY(remote_state_mu_); + // The view id of an eager context should be set to 0 when context is created, + // and continuously incremented when context with the same context_id gets + // updated. The view id should be consistent between master and workers. + uint64 context_view_id_ TF_GUARDED_BY(remote_state_mu_); + std::vector remote_contexts_ TF_GUARDED_BY(remote_state_mu_); + std::unique_ptr remote_eager_workers_ + TF_GUARDED_BY(remote_state_mu_); + + int keep_alive_secs_ TF_GUARDED_BY(remote_state_mu_); + std::atomic sleep_for_secs_; + + std::unique_ptr keep_alive_thread_; + mutex keep_alive_thread_shutdown_mu_; + condition_variable keep_alive_thread_cv_; + bool shutting_down_ TF_GUARDED_BY(keep_alive_thread_shutdown_mu_) = false; + + std::unique_ptr> + remote_mgr_; + bool is_master_ TF_GUARDED_BY(remote_state_mu_); + + // Maps from a remote worker to a list of parsed device filters. + std::unordered_map> + cluster_device_filters_ TF_GUARDED_BY(remote_state_mu_); + + // A distributed manager that helps setup, update, and check liveness of + // member tasks in the cluster. + std::unique_ptr distributed_manager_; + +#endif // IS_MOBILE_PLATFORM + + // For a multi device function, the target device of each input is unknown + // until the function is instantiated on the default function device. + // If false, eagerly copy all remote inputs to the default function device; + // if true, lazily copy remote inputs to their target devices to avoid + // redundant copies. + bool lazy_copy_function_remote_inputs_ = false; + bool use_send_tensor_rpc_; + const bool pin_small_ops_to_cpu_; + + // Function that will be invoked in destructor to deallocate resources related + // to this context. 
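The `keep_alive_thread_`, `keep_alive_thread_cv_` and `shutting_down_` members above implement the usual interruptible background loop. A self-contained plain-C++ illustration of that shutdown pattern (not TensorFlow code; the class name and one-second period are invented):

    #include <chrono>
    #include <condition_variable>
    #include <mutex>
    #include <thread>

    class KeepAlive {
     public:
      KeepAlive() : worker_([this] { Loop(); }) {}
      ~KeepAlive() {
        {
          std::lock_guard<std::mutex> l(mu_);
          shutting_down_ = true;
        }
        cv_.notify_one();  // wake the sleeping worker immediately
        worker_.join();
      }

     private:
      void Loop() {
        std::unique_lock<std::mutex> l(mu_);
        while (!shutting_down_) {
          // Send a keep-alive ping here (elided), then sleep until the next
          // period or until the destructor requests shutdown.
          cv_.wait_for(l, std::chrono::seconds(1));
        }
      }

      std::mutex mu_;
      std::condition_variable cv_;
      bool shutting_down_ = false;
      std::thread worker_;  // declared last so the members above exist first
    };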
+ std::function resource_deallocator_ = nullptr; + bool run_eager_op_as_function_; + bool jit_compile_rewrite_; + + // Controls the behavior of + // `EagerContext::RegisterFunction(AbstractFunction*)` in distributed + // settings. + // + // By default, each abstract function will be registered on all workers in + // a cluster. If the environment variable + // `TF_EAGER_REGISTER_ABSTRACT_FUNCTIONS_LOCAL_ONLY=1` is set, each abstract + // function will be registered on the local worker only. + // + // In the common case that all functions are initially dispatched to + // a local device, the `ProcessFunctionLibraryRuntime` + // will ensure that the precise dependencies of that function are shipped to + // the remote device. Since PFLR instantiation often involves optimization, + // passes such as lowering control flow and inlining function calls, this will + // result in (1) sending a substantially smaller set of functions to each + // worker, and (2) the unoptimized functions never being called. + // + // Therefore setting `TF_EAGER_REGISTER_ABSTRACT_FUNCTIONS_LOCAL_ONLY=1` can + // significantly reduce both the startup time and the memory footprint on + // remote workers by avoiding the shipping of unneeded functions. + // + // TODO(b/326251557): Infer automatically when it is necessary to register a + // function or its dependencies on remote hosts; then remove the environment + // variable. + bool register_abstract_functions_local_only_; +}; + +inline EagerContext* ContextFromInterface(ImmediateExecutionContext* context) { + return down_cast(context); +} + +namespace internal { +struct EagerContextDeleter { + void operator()(EagerContext* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using EagerContextPtr = + std::unique_ptr; + +// Sets the EagerContext owned by the current Python eager Context (see +// TFE_Py_SetEagerContext in python/eager/pywrap_tfe.h). This is always called +// in tandem with TFE_Py_SetEagerContext (but not called by it, because its +// py_context argument is opaque). +// +// Do not use this function in production. It is only intended for testing. +// (see _reset_context in context.py). +// +// Not thread-safe. +void SetCEagerContext(EagerContext* ctx); + +// Returns the EagerContext owned by the current Python eager Context (see +// TFE_Py_SetEagerContext in pywrap_tfe.h). +// +// Not thread-safe. +EagerContext* GetCEagerContext(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/context_distributed_manager.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/context_distributed_manager.h new file mode 100644 index 00000000..9db43d9e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/context_distributed_manager.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_ + +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_distributed_manager.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/platform/status.h" + +#if !defined(IS_MOBILE_PLATFORM) +#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" +#endif // !IS_MOBILE_PLATFORM + +namespace tensorflow { +#if !defined(IS_MOBILE_PLATFORM) +class EagerContext; +class ServerDef; + +class EagerContextDistributedManager + : public ImmediateExecutionDistributedManager { + public: + explicit EagerContextDistributedManager(EagerContext* context) + : context_(context) {} + + // When running in a distributed context, `init_timeout_in_ms` requests the + // amount of time to wait for remote workers to respond. + + absl::Status SetOrUpdateServerDef( + const ServerDef& server_def, bool reset_context, int keep_alive_secs, + int64_t init_timeout_in_ms, int retries, + bool clear_existing_contexts = false) override; + + absl::Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) override; + + absl::Status EnableCollectiveOps(const ServerDef& server_def) override; + + absl::Status CheckRemoteAlive(const std::string& remote_task_name, + bool* is_alive) override; + + tsl::CoordinationServiceAgent* GetCoordinationServiceAgent() override { + return coordination_service_agent_; + } + void SetCoordinationServiceAgent(tsl::CoordinationServiceAgent* agent) { + coordination_service_agent_ = agent; + } + void SetPreemptionNotifier( + std::unique_ptr notifier) { + preemption_notifier_ = std::move(notifier); + } + + private: + EagerContext* context_; + // Owned by context_->GetServer()->worker_env()->session_mgr. + tsl::CoordinationServiceAgent* coordination_service_agent_ = nullptr; + std::unique_ptr preemption_notifier_; +}; +#endif // !IS_MOBILE_PLATFORM +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CONTEXT_DISTRIBUTED_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/copy_to_device_node.h new file mode 100644 index 00000000..37d943b2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_ + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h" + +namespace tensorflow { + +class CopyToDeviceNode : public EagerNode { + public: + CopyToDeviceNode(TensorHandle* src, TensorHandle* dst, Device* dstd, + const EagerContext& ctx, bool async, bool mirror) + : EagerNode(), + src_(src), + dst_(dst), + dstd_(dstd), + ctx_(ctx), + async_(async), + mirror_(mirror) { + if (async_) { + src_->Ref(); + dst_->Ref(); + } + } + + ~CopyToDeviceNode() override { + if (async_) { + src_->Unref(); + dst_->Unref(); + } + } + + absl::Status Run() override { + tensorflow::Tensor tensor; + tsl::profiler::ScopedMemoryDebugAnnotation op_annotation( + "eager::CopyToDeviceNode", "dynamic", tensor.dtype(), + [&tensor]() { return tensor.shape().DebugString(); }); + TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor)); + if (!async_ && mirror_) { + absl::Status s = dst_->AddLocalMirror(std::move(tensor), dstd_); + // If a mirror was added since we called HasLocalMirror then just return + // and ignore the error. + if (s.ok() || (s.code() == error::Code::ALREADY_EXISTS)) { + return absl::OkStatus(); + } + return s; + } else { + return dst_->SetTensor(std::move(tensor), dstd_); + } + } + + void Abort(absl::Status status) override { dst_->Poison(status, dstd_); } + + string DebugString() const override { + string out = "[CopyToDeviceNode]"; + strings::StrAppend(&out, " src_tensor: ", src_->DebugString()); + strings::StrAppend(&out, ", dst_tensor: ", dst_->DebugString()); + strings::StrAppend(&out, ", dst_device: ", dstd_ ? dstd_->name() : "[]"); + return out; + } + + TensorHandle* dst() { return dst_; } + + private: + TensorHandle* src_; + TensorHandle* dst_; + Device* dstd_; + const EagerContext& ctx_; + bool async_; + bool mirror_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_COPY_TO_DEVICE_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/custom_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/custom_device.h new file mode 100644 index 00000000..2f4f5acc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/custom_device.h @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_ + +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class TensorHandle; +class EagerOperation; +class CustomDeviceTensorHandle; + +// Custom devices intercept the execution of operations (the `Execute` method), +// typically implemented with one or more of the custom device's own executions. +class CustomDevice { + public: + virtual ~CustomDevice() = default; + virtual const string& name() = 0; + virtual absl::Status CopyTensorToDevice( + ImmediateExecutionTensorHandle* tensor, + ImmediateExecutionTensorHandle** result) = 0; + + virtual absl::Status CopyTensorFromDevice( + ImmediateExecutionTensorHandle* tensor, const string& target_device_name, + ImmediateExecutionTensorHandle** result) = 0; + + virtual absl::Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals) = 0; + + // Creates a packed TensorHandle from a group of custom device TensorHandles, + // one of which is on this custom device. + virtual absl::Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) = 0; + + // Returns true signifying to pin to the current custom device. + // Returns false to pin to the physical device. + virtual absl::StatusOr ShallPinToThisDevice( + const ImmediateExecutionOperation* op) = 0; +}; + +// Custom devices do many of the same things as physical Devices, but have a +// much more restricted interface. We pass around ambiguous pointers since +// operations may be placed either on custom or physical devices. +using VariantDevice = std::variant; + +// Indicates either HostCPU or an unset physical device. We never set a null +// CustomDevice*. +const VariantDevice kVariantDeviceNull = static_cast(nullptr); + +// A tensor handle produced by a custom device. Generally they can only be +// consumed by executing an operation on the same custom device that produced it +// originally, or by attempting to copy the handle off the custom device. +// +// TODO(allenl): Currently custom devices are tied to the eager C API. They +// should be renamed op handlers and subclass AbstractTensorHandle instead so +// they are eager/graph agnostic. +// +// full_type_ is not set by the constructor (because it is not currently +// needed). If full type information is needed in the future, the constructor +// could use map_dtype_to_child_of_tensor() from core/framework/types.h to set +// it based on dtype. Update test CustomDevice.TestTensorHandle in +// custom_device_test.cc if this changes. +class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle { + public: + CustomDeviceTensorHandle(ImmediateExecutionContext* context, + CustomDevice* device, tensorflow::DataType dtype) + : ImmediateExecutionTensorHandle(kCustomDevice), + context_(context), + device_(device), + dtype_(dtype) {} + + // TODO(allenl): Should this be a generic method of + // ImmediateExecutionTensorHandle to support TFE_TensorHandleDevicePointer? 
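The pure-virtual `CustomDevice` interface above is what op handlers implement. A hedged skeleton (assumes a TensorFlow build; `LoggingDevice` is an invented name, every override simply reports Unimplemented to keep the sketch short, and where this diff's extraction dropped template arguments, e.g. the Span element type, the types used here are assumptions):

    class LoggingDevice : public tensorflow::CustomDevice {
     public:
      explicit LoggingDevice(tensorflow::string name) : name_(std::move(name)) {}

      const tensorflow::string& name() override { return name_; }

      absl::Status CopyTensorToDevice(
          tensorflow::ImmediateExecutionTensorHandle* tensor,
          tensorflow::ImmediateExecutionTensorHandle** result) override {
        return tensorflow::errors::Unimplemented("copy onto LoggingDevice");
      }

      absl::Status CopyTensorFromDevice(
          tensorflow::ImmediateExecutionTensorHandle* tensor,
          const tensorflow::string& target_device_name,
          tensorflow::ImmediateExecutionTensorHandle** result) override {
        return tensorflow::errors::Unimplemented("copy off LoggingDevice");
      }

      absl::Status Execute(const tensorflow::ImmediateExecutionOperation* op,
                           tensorflow::ImmediateExecutionTensorHandle** retvals,
                           int* num_retvals) override {
        return tensorflow::errors::Unimplemented("execute on LoggingDevice");
      }

      absl::Status Pack(
          absl::Span<tensorflow::ImmediateExecutionTensorHandle*> handles,
          tensorflow::ImmediateExecutionTensorHandle** result) override {
        return tensorflow::errors::Unimplemented("pack on LoggingDevice");
      }

      absl::StatusOr<bool> ShallPinToThisDevice(
          const tensorflow::ImmediateExecutionOperation* op) override {
        return false;  // defer to the physical device
      }

     private:
      tensorflow::string name_;
    };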
+ virtual void* DevicePointer() const = 0; + + tensorflow::DataType DataType() const override { return dtype_; } + tensorflow::FullTypeDef FullType() const override { return full_type_; } + absl::Status Shape(PartialTensorShape* shape) const override; + absl::Status NumElements(int64_t* num_elements) const override; + + const char* DeviceName(absl::Status* status) const override { + return device_->name().c_str(); + } + const char* BackingDeviceName(absl::Status* status) const override { + return device_->name().c_str(); + } + CustomDevice* device() const { return device_; } + const char* DeviceType(absl::Status* status) const override; + int DeviceId(absl::Status* status) const override; + + AbstractTensorInterface* Resolve(absl::Status* status) override; + + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kCustomDevice; + } + + protected: + const DeviceNameUtils::ParsedName* ParsedName(absl::Status* status) const; + + ImmediateExecutionContext* const context_; + CustomDevice* const device_; + const tensorflow::DataType dtype_; + tensorflow::FullTypeDef full_type_; + + mutable std::optional parsed_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/custom_device_op_handler.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/custom_device_op_handler.h new file mode 100644 index 00000000..6c38e50d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/custom_device_op_handler.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ + +#include +#include + +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/lib/core/status.h" +namespace tensorflow { + +// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT. +class CustomDeviceOpHandler { + public: + ~CustomDeviceOpHandler() = default; + // Register a new custom device. + absl::Status RegisterCustomDevice(const string& device_name, + std::unique_ptr device); + + // Find the custom device from given name. Return true if it finds one. 
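[Editor's aside] The `classof()` hook above follows the LLVM-style RTTI convention: each handle stores a kind tag at construction, and `classof()` lets a `dyn_cast`-like helper check it without enabling C++ RTTI. A minimal self-contained illustration of that convention (the types, `Kind` values, and `dyn_cast` helper are invented for the sketch; the `CustomDeviceOpHandler` declaration continues below).

```cpp
#include <iostream>

// A tag-based hierarchy in the LLVM-RTTI style used by AbstractTensorHandle
// and CustomDeviceTensorHandle above.
class Handle {
 public:
  enum Kind { kEager, kCustomDevice };
  explicit Handle(Kind k) : kind_(k) {}
  virtual ~Handle() = default;
  Kind getKind() const { return kind_; }

 private:
  const Kind kind_;
};

class CustomHandle : public Handle {
 public:
  CustomHandle() : Handle(kCustomDevice) {}
  // The hook a dyn_cast-style helper consults.
  static bool classof(const Handle* h) { return h->getKind() == kCustomDevice; }
};

// Minimal dyn_cast: returns nullptr when the runtime kind does not match.
template <typename To, typename From>
To* dyn_cast(From* p) {
  return To::classof(p) ? static_cast<To*>(p) : nullptr;
}

int main() {
  CustomHandle c;
  Handle plain(Handle::kEager);
  std::cout << (dyn_cast<CustomHandle>(&c) != nullptr) << "\n";      // 1
  std::cout << (dyn_cast<CustomHandle>(&plain) != nullptr) << "\n";  // 0
}
```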
+ bool FindCustomDeviceFromName(const string& name, + CustomDevice** device) const; + + absl::Status Execute(ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals); + + ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( + ImmediateExecutionContext* context, + ImmediateExecutionTensorHandle* handle, const char* device_name, + absl::Status* status); + + // Determine whether to place an op on a custom device. This method is + // exposed as public for test only. + absl::Status MaybePinToCustomDevice( + CustomDevice** device, const ImmediateExecutionOperation& op) const; + + void Clear(); + + private: + std::unordered_map> custom_devices_; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_executor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_executor.h new file mode 100644 index 00000000..cec897b3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_executor.h @@ -0,0 +1,291 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +class AsyncEagerNode; +class AsyncRemoteExecuteNode; +namespace eager { +class EagerClient; +} + +// A unit of execution for the EagerExecutor class below. Example subclasses +// encapsulate execution of a TFE_Op, or copying a TFE_TensorHandle from one +// device to another. +class EagerNode { + public: + EagerNode() = default; + + virtual ~EagerNode() = default; + + // Prepares the node when adding it into EagerExecutor. If any errors happens, + // EagerExecutor will abort the node immediately. + virtual absl::Status Prepare() { return absl::OkStatus(); } + + // Runs the computation corresponding to this node and blocks till the + // execution is done. + virtual absl::Status Run() = 0; + + // Called when this node will not be run due to some error contained in + // `status`. `status` must not be OK. 
+ // For example, if the node would have computed some tensors in the Run(), + // it should poison the corresponding tensor handles in this method. + virtual void Abort(absl::Status status) = 0; + + // Returns nullptr iff this Eager node is synchronous. + virtual AsyncEagerNode* AsAsync() { return nullptr; } + virtual AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() { return nullptr; } + + virtual string DebugString() const = 0; + + // Indicates whether a node failure should make the executor unusable. + virtual bool Fatal() const { return true; } +}; + +class AsyncEagerNode : public EagerNode { + public: + using EagerNode::EagerNode; // Lift EagerNode constructors. + + // This node will be cleaned up once the done callback is called. + virtual void RunAsync(StatusCallback done) = 0; + + AsyncEagerNode* AsAsync() final { return this; } + + absl::Status Run() final { + return errors::Unimplemented("Don't call AsyncEagerNode::Run()."); + } +}; + +class AsyncRemoteExecuteNode : public AsyncEagerNode { + public: + AsyncRemoteExecuteNode* AsAsyncRemoteExecuteNode() final { return this; } + + virtual const eager::EagerClient* eager_client() const = 0; + virtual bool needs_remote_inputs() const = 0; + virtual bool allow_multiple_pending_requests() const = 0; + virtual absl::Status SyncExecutors() = 0; +}; + +// A class for handling async execution (see TFE_ContextSetAsync). +// Note that this class is thread-safe. +// TODO(agarwal): TFE_OpAddInput may currently block if it tries to access the +// device of the input handle. Fix that. +// TODO(agarwal): Implement support for control dependencies. +// TODO(agarwal): Support out-of-order execution and dispatching multiple +// EagerNode in parallel. +// TODO(agarwal): Implement optimizations over EagerNode traces. +class EagerExecutor { + public: + explicit EagerExecutor(bool async, bool enable_streaming_enqueue = true, + int in_flight_nodes_limit = 0); + + ~EagerExecutor(); + + // Puts this in a shutdown state. In this state, AddOrExecute() will return an + // error and not add new EagerNodes. After putting this in the shutdown state, + // blocks until all pendings nodes have finished running. + // Returns the status of executing pending nodes. + // If async was not enabled, aborts and destroys all pending nodes. + absl::Status ShutDown(); + + bool Async() const; + + bool StreamingEnqueue() const; + + // Inline execute node if executor is in sync mode. + absl::Status SyncExecute(EagerNode* node); + + // - Async Mode: schedules `node` for execution. + // - Sync Mode: inline execute the 'node' directly. + // If an error occurs (e.g. EagerExecutor has already been shut down), the + // `node` is not added to this executor and its Abort() method is called. + absl::Status AddOrExecute(std::unique_ptr node); + + // Blocks till all currently pending ops are done. + // In particular, if EnableAsync() has not beed called, it will not return + // until that happens (and pendings, at the time of call, nodes finish + // running). If this executor has already been shut down, its final status is + // returned. + absl::Status WaitForAllPendingNodes(); + + // Clears all currently set errors which re-enables async execution. + void ClearError(); + + // Returns Status based on any errors that occurred during async execution. + absl::Status status() const { + if (ok()) return absl::OkStatus(); + + tf_shared_lock l(node_queue_mutex_); + return status_; + } + + bool ok() const TF_NO_THREAD_SAFETY_ANALYSIS { return ok_; } + + // On destruction, runs `callback`. 
Used by the EagerContext for clearing + // thread-local executors. + void AddCleanup(intptr_t key, std::function callback); + // If `key` (e.g. a context) is destroyed before the executor, the associated + // callbacks are no longer safe to run. + void RemoveCleanups(intptr_t key); + + private: + // Possible states for this executor. + // Executor starts in kActive state. When Shutdown() is called, Executor + // is put in the kShuttingDown state. In this state, the executor thread + // continues to run, but no new nodes are accepted. Finally, when all nodes + // are drained, the executor is put in the kShutDown state, which causes the + // thread to exit. + // If this executor is destroyed without calling shutdown first, it + // transitions to kShutDown state immediately which causes the thread to exit + // without running pending nodes. + enum class ExecutorState { + kActive, + kShuttingDown, + kShutDown, + }; + + enum class NodeState { + kPENDING, + kSCHEDULED, + kDONE, + }; + + struct NodeItem : core::RefCounted { + // Unique id generated in EagerExecutor::Add(). If item1.id < item2.id, it + // means item1.node is added before item2.node. + uint64 id; + std::unique_ptr node; + NodeState state; + }; + + const char* StateStringLocked() + TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_); + + void NodeDone(const core::RefCountPtr& item, + const absl::Status& status, bool from_queue); + void NotifyWaiters(uint64 id) TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_); + + // Starts execution of pending EagerNodes. This function loops till executor + // state_ is set to kShutDown. If any errors are encountered, these are set + // inside `status_`. The loop blocks anytime there are no pending nodes, or if + // `status_` is not ok. + void Run(); + + absl::Status RunItem(core::RefCountPtr item, bool from_queue); + absl::Status MoveToUnfinished(core::RefCountPtr item, + bool from_queue); + + // The impl of WaitForAllPendingNodes + // `lock` is the lock that holds node_queue_mutex_. + absl::Status WaitForAllPendingNodesLocked(mutex_lock* lock) + TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_); + + absl::Status WaitImpl(bool wait_all, uint64 node_id); + + std::atomic next_node_id_; + + mutable mutex node_queue_mutex_; + + // Used to signal that some EagerNodes are pending execution. + condition_variable nodes_pending_ TF_GUARDED_BY(node_queue_mutex_); + // Used to signal that some EagerNodes are done. + condition_variable nodes_done_ TF_GUARDED_BY(node_queue_mutex_); + + // Queue of pending NodeItems. Ordered by NodeItem::id. + std::queue> node_queue_ + TF_GUARDED_BY(node_queue_mutex_); + + // Ordered by NodeItem::id. + std::map, std::less> + unfinished_nodes_ TF_GUARDED_BY(node_queue_mutex_); + + // `status_` is set based on any errors raised during execution of a + // EagerNode. It remains set until ClearError is called. + absl::Status status_ TF_GUARDED_BY(node_queue_mutex_); + std::atomic ok_ TF_GUARDED_BY(node_queue_mutex_); + + // Map from id of a EagerNode to condition_variables (not owned by the map). + // These condition_variables are notified and removed when that EagerNode is + // done executing, or if an error is found in execution of any EagerNode. + // The map is ordered by id. + std::multimap> + node_done_notifications_ TF_GUARDED_BY(node_queue_mutex_); + + // thread_exited_notification_ is notified by the `thread_` right before it + // exits. 
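[Editor's aside] A deliberately simplified, single-threaded illustration of the executor behavior documented above: failures latch in `status_`, anything submitted afterwards gets `Abort()`ed instead of run, and `ClearError()` re-enables execution. The real `EagerExecutor` does this across a worker thread in async mode; the types below are toys and Abseil's status library is the only dependency.

```cpp
#include <iostream>
#include <memory>

#include "absl/status/status.h"

// Minimal stand-ins for EagerNode / EagerExecutor.
class ToyNode {
 public:
  virtual ~ToyNode() = default;
  virtual absl::Status Run() = 0;
  virtual void Abort(absl::Status s) {
    std::cout << "aborted: " << s.ToString() << "\n";
  }
};

class ToySyncExecutor {
 public:
  // Simplified AddOrExecute(): inline execution; once an error is latched,
  // later nodes are aborted instead of run.
  absl::Status AddOrExecute(std::unique_ptr<ToyNode> node) {
    if (!status_.ok()) {
      node->Abort(status_);
      return status_;
    }
    status_ = node->Run();
    return status_;
  }

  // Simplified ClearError(): re-enables execution after a failure.
  void ClearError() { status_ = absl::OkStatus(); }
  absl::Status status() const { return status_; }

 private:
  absl::Status status_ = absl::OkStatus();
};

class OkNode : public ToyNode {
 public:
  absl::Status Run() override { std::cout << "ran\n"; return absl::OkStatus(); }
};

class FailingNode : public ToyNode {
 public:
  absl::Status Run() override { return absl::InternalError("boom"); }
};

int main() {
  ToySyncExecutor exec;
  exec.AddOrExecute(std::make_unique<OkNode>());       // ran
  exec.AddOrExecute(std::make_unique<FailingNode>());  // latches the error
  exec.AddOrExecute(std::make_unique<OkNode>());       // aborted: INTERNAL: boom
  exec.ClearError();
  exec.AddOrExecute(std::make_unique<OkNode>());       // ran
}
```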
+ Notification thread_exited_notification_; + + // When state_ is set to kShutDown, it indicates that `thread_` should stop as + // soon as it is done executing the current EagerNode. + ExecutorState state_ TF_GUARDED_BY(node_queue_mutex_) = + ExecutorState::kActive; + + // Thread object that calls the `Run` method in async mode.This thread runs + // until state_ is set to kShuttingDown. It is `nullptr` in sync mode. + const std::unique_ptr thread_; + + // Last device where remote function with remote inputs was executed. + const eager::EagerClient* last_eager_client_; + + const bool enable_async_wait_for_remote_function_; + + // Enable sending remote executions through streaming enqueue. + const bool enable_streaming_enqueue_; + + // Callbacks to run on destruction. + absl::flat_hash_map>> cleanups_; + + // Limit the number of in-flight nodes. When the number of in-flight eager + // async nodes reach this number, enqueuing to the eager async queue is + // blocked. + const int64_t in_flight_nodes_limit_; +}; + +inline bool EagerExecutor::Async() const { return thread_ != nullptr; } + +inline bool EagerExecutor::StreamingEnqueue() const { + return enable_streaming_enqueue_; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h new file mode 100644 index 00000000..bd709847 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h @@ -0,0 +1,110 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/eager/eager_operation.h" + +namespace tensorflow { + +// Eager op rewrites should inherit from this class and +// implement the Run method. +class EagerOpRewrite { + public: + EagerOpRewrite(string name, string file, string line) { + debug_info_.name = name; + debug_info_.file = file; + debug_info_.line = line; + } + + virtual ~EagerOpRewrite() = default; + + // To be implemented by an Eager op rewrite pass. + virtual absl::Status Run( + EagerOperation* orig_op, + std::unique_ptr* out_op) = 0; + + // Holds information about the rewrite registration. + struct DebugInfo { + string name, file, line; + }; + + // Returns information about the registered Eager op rewrite. + DebugInfo GetDebugInfo() const { return debug_info_; } + + private: + DebugInfo debug_info_; +}; + +class EagerOpRewriteRegistry { + public: + // Phases at which the Eager op rewrite pass should run. 
+ enum Phase { + PRE_EXECUTION = 0, // right before executing an eager op + POST_PLACEMENT = 1 // after device placement + }; + + // Add a rewrite pass to the registry. + void Register(Phase phase, int32_t ordinal, + std::unique_ptr pass); + + // Run the rewrite pass registered for a given phase. + absl::Status RunRewrite(Phase phase, EagerOperation* orig_op, + std::unique_ptr* out_op); + + // Returns the global registry of rewrite passes. + static EagerOpRewriteRegistry* Global(); + + private: + static constexpr int32_t kNumPhases = 2; + // Holds all the registered Eager op rewrites and their ordinal numbers. + std::array, int32>>, + kNumPhases> + rewrites_; +}; + +namespace eager_rewrite_registration { + +// This class is used to register a new Eager Op rewrite. +class EagerRewriteRegistration { + public: + EagerRewriteRegistration(EagerOpRewriteRegistry::Phase phase, int32_t ordinal, + std::unique_ptr pass) { + EagerOpRewriteRegistry::Global()->Register(phase, ordinal, std::move(pass)); + } +}; + +} // namespace eager_rewrite_registration + +#define REGISTER_REWRITE(phase, ordinal, rewrite) \ + REGISTER_REWRITE_UNIQ_HELPER(__COUNTER__, __FILE__, __LINE__, phase, \ + ordinal, rewrite) + +#define REGISTER_REWRITE_UNIQ_HELPER(ctr, file, line, phase, ordinal, rewrite) \ + REGISTER_REWRITE_UNIQ(ctr, file, line, phase, ordinal, rewrite) + +#define REGISTER_REWRITE_UNIQ(ctr, file, line, phase, ordinal, rewrite) \ + static ::tensorflow::eager_rewrite_registration::EagerRewriteRegistration \ + register_rewrite_##ctr(phase, ordinal, \ + ::std::unique_ptr<::tensorflow::EagerOpRewrite>( \ + new rewrite(#rewrite, file, #line))) + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OP_REWRITE_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_operation.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_operation.h new file mode 100644 index 00000000..b81b0fc7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/eager_operation.h @@ -0,0 +1,347 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
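[Editor's aside] The three-level `REGISTER_REWRITE` macro chain above exists so that `__COUNTER__` (and `__FILE__`/`__LINE__`) expand before being token-pasted into a unique static-variable name; each pass then registers itself at static-initialization time. A stripped-down, self-contained version of that idiom follows; the registry and macro names are invented for the sketch.

```cpp
#include <iostream>
#include <string>
#include <vector>

// A tiny global registry, standing in for EagerOpRewriteRegistry::Global().
struct ToyRegistry {
  static ToyRegistry* Global() {
    static ToyRegistry* r = new ToyRegistry;
    return r;
  }
  void Register(std::string name) { names.push_back(std::move(name)); }
  std::vector<std::string> names;
};

// Registrar whose constructor runs during static initialization.
struct ToyRegistration {
  explicit ToyRegistration(std::string name) {
    ToyRegistry::Global()->Register(std::move(name));
  }
};

// Two helper levels force __COUNTER__ to expand before token pasting,
// mirroring REGISTER_REWRITE -> REGISTER_REWRITE_UNIQ_HELPER -> ..._UNIQ.
#define TOY_REGISTER(name) TOY_REGISTER_UNIQ_HELPER(__COUNTER__, name)
#define TOY_REGISTER_UNIQ_HELPER(ctr, name) TOY_REGISTER_UNIQ(ctr, name)
#define TOY_REGISTER_UNIQ(ctr, name) \
  static ::ToyRegistration toy_registration_##ctr(name)

TOY_REGISTER("arithmetic_optimizer");
TOY_REGISTER("layout_rewriter");

int main() {
  for (const auto& n : ToyRegistry::Global()->names) std::cout << n << "\n";
}
```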
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/managed_stack_trace.h" + +namespace tensorflow { + +class EagerOperation : public ImmediateExecutionOperation { + public: + explicit EagerOperation(tensorflow::EagerContext* ctx) + : ImmediateExecutionOperation(kEager), ctx_(*ctx), is_function_(false) {} + ~EagerOperation() override { + for (ImmediateExecutionTensorHandle* h : inputs_) { + h->Unref(); + } + } + + void Release() override { delete this; } + + void Clear() override; + absl::Status Reset(const char* op, const char* raw_device_name) override { + return Reset(op, raw_device_name, false, nullptr); + } + + const string& Name() const override { return attrs_.op_name(); } + + const string& DeviceName() const override { return device_name_; } + + ImmediateExecutionContext* GetContext() const override { return &ctx_; } + + const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { + return device_parsed_name_; + } + + // Replaces the previous device name with the given one (see + // AbstractOperation::SetDeviceName for more details). + // + // This also resets the internal device pointer, unless the given name refers + // to a known custom device, in which case the internal device pointer is + // updated to that device. + absl::Status SetDeviceName(const char* name) override; + + void SetDevice(VariantDevice device) { + device_ = device; + device_name_ = std::visit( + [](auto* device) { return device == nullptr ? "" : device->name(); }, + device); + DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_); + // TODO(b/154133594): Due to intricacies of external logic, we can not + // set this do device_name_ as it would be natural, because we need the + // next call to SetDeviceName to reset the device pointer. 
+ last_set_device_name_ = "\177"; // DEL (an invalid value) + } + + absl::Status SetAttrValue(const char* attr_name, const AttrValue& value); + + absl::Status AddInput(AbstractTensorHandle* input) override; + absl::Status AddInputList( + absl::Span inputs) override; + absl::Status SetInput(size_t index, + ImmediateExecutionTensorHandle* input) override; + absl::Span GetInputs() const override; + bool HasCustomDeviceInput() const override { + return custom_device_tensor_handles_count_ > 0; + } + absl::Status Execute(absl::Span retvals, + int* num_retvals) override; + const tensorflow::OpDef* OpDef() const override { return op_def_; }; + + absl::Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + absl::Status SetAttrInt(const char* attr_name, int64_t value) override; + absl::Status SetAttrFloat(const char* attr_name, float value) override; + absl::Status SetAttrBool(const char* attr_name, bool value) override; + absl::Status SetAttrType(const char* attr_name, DataType value) override; + absl::Status SetAttrShape(const char* attr_name, const int64_t* dims, + int num_dims) override; + absl::Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + absl::Status SetAttrFunctionName(const char* attr_name, const char* data, + size_t length) override; + absl::Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + absl::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) override; + absl::Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + absl::Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + absl::Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + absl::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) override; + absl::Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + absl::Status SetAttrFunctionList( + const char* attr_name, + absl::Span values) override; + + absl::Status InputLength(const char* input_name, int* length) override; + absl::Status OutputLength(const char* output_name, int* length) override; + + const AbstractOpAttrs* GetOpAttrs() const override; + void AddAttrs(const AbstractOpAttrs* op_attrs) override; + + void SetStackTrace(ManagedStackTrace stack_trace) override { + stack_trace_ = stack_trace; + } + + std::optional GetStackTrace() override { + return stack_trace_; + } + + absl::Status Reset( + const char* op, const char* device_name, bool remote, + EagerExecutor* executor, + absl::optional eager_func_params = std::nullopt); + + bool is_function() const { return is_function_; } + bool colocation_exempt() const { return colocation_exempt_; } + + tensorflow::EagerContext& EagerContext() const { return ctx_; } + + const FunctionLibraryDefinition* FuncLibDef() const { + if (eager_func_params_.has_value() && + eager_func_params_.value().func_lib_def_override) { + return eager_func_params_.value().func_lib_def_override; + } else { + return ctx_.FuncLibDef(); + } + } + + const FunctionDef* GetFunctionDef() const { + if (is_function_) { + return FuncLibDef()->Find(attrs_.op_name()); + } else { + return nullptr; + } + } + + AttrBuilder* MutableAttrs() { return &attrs_; } + const AttrBuilder& Attrs() const { return attrs_; } + + // TensorHandleInputs and 
MutableTensorHandleInputs first check that all + // inputs are TensorHandles, i.e. that there are no custom device inputs. They + // return a bad status otherwise. + absl::Status TensorHandleInputs( + const absl::InlinedVector** inputs) const; + absl::Status MutableTensorHandleInputs( + absl::InlinedVector** inputs); + + const absl::InlinedVector& Inputs() + const { + return inputs_; + } + + void UpdateInput(int i, TensorHandle* h); + + // This is useful if we want the EagerOperation to point to a different + // function. + void UpdateName(const string& name) { + op_name_ = name.c_str(); + attrs_.set_op_name(name); + } + + // Like TensorHandles, EagerOperations may be placed either on a virtual + // CustomDevice or on a physical Device. + VariantDevice Device() const { return device_; } + + // Indicates whether the op is assigned to a device that is local to the + // current host. + bool IsLocal() const; + + CancellationManager* GetCancellationManager() const { + return cancellation_manager_; + } + void SetCancellationManager( + CancellationManager* cancellation_manager) override { + cancellation_manager_ = cancellation_manager; + } + + // Assign step_id value only if op has valid step id. + // When eager_func_params.has_value() returns true, we can directly overwrite + // its step id according to Op's step id (if not default value). However, when + // eager_func_params.has_value() returns false, we need to first create a new + // EagerFuncParams object for it before assigning step_id; otherwise, + // directly assigning step_id in this case leaves eager_func_params to be + // in a weird state where: + // (1) eager_func_params.has_value() returns false, but + // (2) eager_func_params->step_id.has_value() returns true. + void SetStepId(int64_t step_id) override { + assert(is_function()); + if (step_id != EagerContext::kGlobalRendezvousId) { + if (eager_func_params_.has_value()) { + eager_func_params_->step_id = step_id; + } else { + eager_func_params_ = EagerFunctionParams{ + kInvalidOpId, /*is_component_function=*/false, step_id}; + } + } else { + LOG(WARNING) << "SetStepId() should not receive a gloabl rendezvous id."; + } + } + + EagerExecutor& Executor() { return *executor_; } + + string DebugString() const; + + const absl::optional& eager_func_params() const { + return eager_func_params_; + } + + // Op name recorded for memory debugging purpose. + const char* op_name() const { return op_name_; } + + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kEager; + } + + private: + void AddTensorHandle(ImmediateExecutionTensorHandle* h); + + const tensorflow::OpDef* GetOpDef(absl::Status* status); + + void ClearInferenceState() { + op_def_ = nullptr; + inference_arg_idx_ = 0; + inference_attrs_.clear_no_resize(); + } + + absl::Status MaybeInferSingleInputAttrs( + ImmediateExecutionTensorHandle* handle); + absl::Status InferInputListAttrs(int num_inputs); + + void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, + DataType dtype, int num_inputs); + void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, + const std::vector& dtypes); + + tensorflow::EagerContext& ctx_; + const char* op_name_ = nullptr; + AttrBuilder attrs_; + const AttrTypeMap* attr_types_; + + // The number of custom device TensorHandle inputs. These inputs need to be + // processed by CustomDeviceOpHandler first. + int custom_device_tensor_handles_count_ = 0; + absl::InlinedVector inputs_; + + // The last device name given to SetDeviceName. 
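[Editor's aside] The `SetStepId()` comment above describes a `std::optional` pitfall: assigning through a disengaged optional's member would leave an object whose `has_value()` is false while a field "inside" it was supposedly set. A small hedged sketch of the guarded-update pattern the comment prescribes; the `ToyFuncParams` struct and function name are illustrative only.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>

// Illustrative stand-in for EagerFunctionParams.
struct ToyFuncParams {
  int64_t op_id = -1;
  bool is_component_function = false;
  std::optional<int64_t> step_id;
};

// Guarded update: modify the existing params if present, otherwise construct
// a fresh object so the optional is engaged before any of its fields is set.
void SetStepId(std::optional<ToyFuncParams>& params, int64_t step_id) {
  if (params.has_value()) {
    params->step_id = step_id;
  } else {
    params = ToyFuncParams{/*op_id=*/-1, /*is_component_function=*/false,
                           step_id};
  }
}

int main() {
  std::optional<ToyFuncParams> params;  // disengaged, like a plain eager op
  SetStepId(params, 42);
  std::cout << params.has_value() << " " << *params->step_id << "\n";  // 1 42
}
```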
+ // This is used to avoid having to re-process the same device in repeated + // calls to SetDeviceName. + string last_set_device_name_; + + // The operation's device name. + // This contains the named passed to SetDeviceName until device_ is set, + // at which point it contains the device_ name. + string device_name_; + + // The parsed device name. + // This will always contain the result of + // DeviceNameUtils::ParseFullName(device_name_). + DeviceNameUtils::ParsedName device_parsed_name_; + + // The operation's device. + // This is set by the execution device placement logic, and should conform + // with the contents of device_name_. Once it is set, the device_name_ is + // updated accordingly. + VariantDevice device_; + + std::optional stack_trace_; + bool is_function_; // Conceptually const, but can't be because of Reset + bool colocation_exempt_; + CancellationManager* cancellation_manager_ = nullptr; // Not owned. + EagerExecutor* executor_; // Not owned. + + std::optional eager_func_params_; + + // Inference information + const tensorflow::OpDef* op_def_; // op definition from protobuf + int inference_arg_idx_; // arg definition index for the next input to be + // added + gtl::FlatSet inference_attrs_; // attributes inferred so far +}; + +inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { + ImmediateExecutionTensorHandle** slot = &inputs_[i]; + ImmediateExecutionTensorHandle* existing = *slot; + if (existing != h) { + h->Ref(); + existing->Unref(); + *slot = h; // Update inputs_[i] to h + } +} + +inline EagerOperation* OperationFromInterface( + ImmediateExecutionOperation* operation) { + return down_cast(operation); +} + +inline const EagerOperation* OperationFromInterface( + const ImmediateExecutionOperation* operation) { + return down_cast(operation); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EAGER_OPERATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/execute.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/execute.h new file mode 100644 index 00000000..cbd1e0c9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/execute.h @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
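[Editor's aside] `EagerOperation::UpdateInput()` above is the standard manual-refcount slot swap: skip the work when the handle is unchanged, and `Ref()` the incoming handle before `Unref()`-ing the old one so the slot never transiently drops its only reference. A standalone sketch with a toy refcounted type; the names are illustrative.

```cpp
#include <cassert>
#include <iostream>

// Minimal intrusive refcount, standing in for a core::RefCounted handle.
class ToyHandle {
 public:
  void Ref() { ++refs_; }
  void Unref() {
    if (--refs_ == 0) delete this;
  }
  int refs() const { return refs_; }

 private:
  int refs_ = 1;  // Starts owned by its creator.
};

// Mirrors the shape of UpdateInput(): no-op on identical handles, and the new
// handle is Ref()'d before the old one is Unref()'d.
void UpdateSlot(ToyHandle** slot, ToyHandle* h) {
  ToyHandle* existing = *slot;
  if (existing != h) {
    h->Ref();
    existing->Unref();
    *slot = h;
  }
}

int main() {
  ToyHandle* a = new ToyHandle;  // refcount 1 (creator)
  ToyHandle* b = new ToyHandle;  // refcount 1 (creator)
  ToyHandle* slot = a;
  a->Ref();              // the slot holds its own reference to a
  UpdateSlot(&slot, b);  // slot now references b; a drops back to 1
  assert(slot == b && a->refs() == 1 && b->refs() == 2);
  std::cout << "a=" << a->refs() << " b=" << b->refs() << "\n";
  slot->Unref();  // release the slot's reference to b
  a->Unref();     // a deleted
  b->Unref();     // b deleted
}
```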
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_ + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Utility function that executes a fully constructed EagerOperation. +// There are a few possible different combinations of how things can be +// executed: +// - Async (the op context is configured to schedule asynchronously) +// Eager execute should return quickly after scheduling this operation to +// execute. +// - Remote (the op device is on a remote task) +// Eager execute will send an RPC to execute the op on a remote device. +// Note that in the Async + Remote case, EagerExecute should still return +// quickly, but it will schedule the op to be executed remotely. +// +// 'retvals' must point to a pre-allocated array of TensorHandle* and +// '*num_retvals' should be set to the size of this array. It is an error if +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. +absl::Status EagerExecute(EagerOperation* op, TensorHandle** retvals, + int* num_retvals); + +// Low-level utility to execute the kernel specified by `kernel` on +// `kernel->device()`, with the inputs op_inputs, in the context 'ctx'. +absl::Status EagerKernelExecute( + EagerContext* ctx, const absl::InlinedVector& op_inputs, + const absl::optional& eager_func_params, + const core::RefCountPtr& kernel, + GraphCollector* graph_collector, CancellationManager* cancellation_manager, + absl::Span retvals, + const absl::optional& stack_trace = {}); + +// Low-level utility to copy a tensor handle from one device to another. If +// successful, result TensorHandle will be populated. If the caller requests for +// the mirror flag, EagerCopyToDevice will attempt to add a mirror to the +// original handle and update *result to point to h. Since this is not +// guaranteed, callers should always use the value in *result. +absl::Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, + EagerExecutor* executor, Device* device, + bool mirror, TensorHandle** result); + +// Utility function that executes a fully constructed EagerOperation +// asynchronously on the local task. This function works differently from +// EagerExecute in several ways: +// - It supports local execution only. +// - It returns after launching the eager operation to run asynchronously. +// Different from EagerExecute with async context that apends the operation +// to the end of the eager executor schedule queue, this call bypasses the +// executor logic and directly launches op execution. Ops running through +// this call does NOT have an ordering and can be executed in parallel. +// - It takes a StatusCallback which will be triggered after execution with the +// execution status. +// +// Does not support custom device. +// +// 'retvals' must point to a pre-allocated array of TensorHandle* and +// '*num_retvals' should be set to the size of this array. 
It is an error if +// the size of 'retvals' is less than the number of outputs. This call sets +// *num_retvals to the number of outputs. +void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, + int* num_retvals, StatusCallback done); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/execute_node.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/execute_node.h new file mode 100644 index 00000000..52bf1ecf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/execute_node.h @@ -0,0 +1,252 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_ + +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include +#include +#include +#include +#include +#include +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/platform.h" +// clang-format on + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" +#endif // IS_MOBILE_PLATFORM + +namespace tensorflow { + +class ExecuteNodeArgs : public EagerKernelArgs { + public: + explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {} + + absl::Status Init(EagerContext* ctx, + const absl::InlinedVector& op_inputs, + const core::RefCountPtr& kernel); + + absl::Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const override; + + bool HasRemoteOrPackedInputs() const override { + return has_remote_inputs_ || has_packed_inputs_; + }; + +#if !defined(IS_MOBILE_PLATFORM) + absl::Status GetRemoteArg(const FunctionArgIndex& index, + eager::RemoteTensorHandle* val) const override { + return serialize_remote_handle_(index, val); + } +#endif // IS_MOBILE_PLATFORM + + private: +#if !defined(IS_MOBILE_PLATFORM) + // Returns whether `handle` is a remote handle or has a remote mirror on + // `input_device` + bool 
IsRemote(EagerContext* ctx, Device* input_device, TensorHandle* handle); +#endif // IS_MOBILE_PLATFORM + + // Initialize a packed TensorHandle which is the `index`-th argument. + absl::Status InitPackedHandle(int index, EagerContext* ctx, + Device* input_device, + TensorHandle* packed_handle); + + bool has_remote_inputs_ = false; + bool has_packed_inputs_ = false; + // Maps from the index of a packed arg to a list of sub-args. + absl::flat_hash_map> packed_args_; +#if !defined(IS_MOBILE_PLATFORM) + std::function + serialize_remote_handle_; +#endif // IS_MOBILE_PLATFORM +}; + +class ExecuteNode : public EagerNode { + public: + ExecuteNode(EagerContext* ctx, + const absl::InlinedVector& inputs, + const absl::optional& eager_func_params, + const core::RefCountPtr& kernel, + GraphCollector* graph_collector, + CancellationManager* cancellation_manager, + absl::Span retvals, + std::optional stack_trace) + : EagerNode(), + ctx_(ctx), + inputs_(inputs), + eager_func_params_(eager_func_params), + kernel_(kernel), + graph_collector_(graph_collector), + cancellation_manager_(cancellation_manager), + retvals_(retvals), + stack_trace_(stack_trace) {} + + absl::Status Run() override { + int i = 0; + for (TensorHandle* h : inputs_) { + if (h->RefCountIsOne()) { + const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i)); + absl::Status s = h->Unprotect(d); + if (!s.ok()) { + VLOG(1) << "Unable to unprotect tensor: " << s; + } + } + ++i; + } + return EagerKernelExecute(ctx_, inputs_, eager_func_params_, kernel_, + graph_collector_, cancellation_manager_, retvals_, + stack_trace_); + } + + void Abort(absl::Status status) override {} + + std::string DebugString() const override { + std::string out = "[ExecuteNode]"; + strings::StrAppend(&out, " kernel: ", kernel_->name()); + return out; + } + + private: + EagerContext* ctx_; + const absl::InlinedVector& inputs_; + const absl::optional& eager_func_params_; + const core::RefCountPtr& kernel_; + GraphCollector* graph_collector_; + CancellationManager* const cancellation_manager_; + absl::Span retvals_; + std::optional stack_trace_; +}; + +class AsyncExecuteNode : public EagerNode { + public: + AsyncExecuteNode(EagerContext* ctx, + const absl::InlinedVector& inputs, + const absl::optional& eager_func_params, + core::RefCountPtr kernel, + GraphCollector* graph_collector, + CancellationManager* cancellation_manager, + absl::Span retvals, + std::optional stack_trace) + : EagerNode(), + ctx_(ctx), + inputs_(inputs), + eager_func_params_(eager_func_params), + kernel_(std::move(kernel)), + graph_collector_(graph_collector), + cancellation_manager_(cancellation_manager), + stack_trace_(stack_trace) { + // Copy the output handles, since the container for them might get + // destroyed. + for (auto handle : retvals) { + handle->Ref(); + retvals_.push_back(handle); + } + + // This is required to ensure that the tensor handles stay alive across + // the execution. 
+ for (auto handle : inputs_) { + handle->Ref(); + } + } + + ~AsyncExecuteNode() override { + for (auto handle : retvals_) { + handle->Unref(); + } + + for (auto handle : inputs_) { + handle->Unref(); + } + } + + absl::Status Run() override { + int i = 0; + for (TensorHandle* h : inputs_) { + if (h->RefCountIsOne()) { + const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i)); + absl::Status s = h->Unprotect(d); + if (!s.ok()) { + VLOG(1) << "Unable to unprotect tensor: " << s; + } + } + ++i; + } + absl::Status status = EagerKernelExecute( + ctx_, inputs_, eager_func_params_, kernel_, graph_collector_, + cancellation_manager_, absl::MakeSpan(retvals_), stack_trace_); + if (!status.ok()) { + if (stack_trace_.has_value()) { + errors::SetStackTrace( + status, stack_trace_->ToStackFrames( + {}, {}, /*reverse_traversal=*/false, /*limit=*/-1)); + } + Abort(status); + return status; + } + // If status is ok, EagerKernelExecute would have called SetTensor on + // all the output handles. + return absl::OkStatus(); + } + + void Abort(absl::Status status) override { + int i = 0; + for (auto handle : retvals_) { + handle->Poison(status, ctx_->CanonicalDevice(kernel_->OutputDevice(i))); + ++i; + } + } + + std::string DebugString() const override { + std::string out = "[AsyncExecuteNode]"; + strings::StrAppend(&out, " kernel: ", kernel_->name()); + return out; + } + + private: + EagerContext* ctx_; + absl::InlinedVector inputs_; + const absl::optional eager_func_params_; + core::RefCountPtr kernel_; + GraphCollector* graph_collector_; + CancellationManager* const cancellation_manager_; + std::optional stack_trace_; + absl::InlinedVector retvals_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_EXECUTE_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/kernel_and_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/kernel_and_device.h new file mode 100644 index 00000000..c13e1524 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -0,0 +1,426 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ + +// Support for eager execution of TensorFlow kernels. 
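[Editor's aside] `AsyncExecuteNode` above pins its inputs and output handles by `Ref()`-ing them in the constructor and `Unref()`-ing them in the destructor, and on `Abort()` it poisons every output so downstream readers observe the failure. A condensed standalone sketch of that ownership-and-poisoning shape; the types are toys and Abseil's status library is the only real dependency.

```cpp
#include <iostream>
#include <vector>

#include "absl/status/status.h"

// Toy handle: refcounted and "poisonable", loosely like TensorHandle.
struct ToyHandle {
  int refs = 1;
  absl::Status status = absl::OkStatus();
  void Ref() { ++refs; }
  void Unref() { --refs; }  // No deletion; good enough for a sketch.
  void Poison(absl::Status s) { status = s; }
};

// Pin inputs/outputs for the node's lifetime; propagate failure into every
// output on Abort().
class ToyAsyncNode {
 public:
  ToyAsyncNode(std::vector<ToyHandle*> inputs, std::vector<ToyHandle*> retvals)
      : inputs_(std::move(inputs)), retvals_(std::move(retvals)) {
    for (ToyHandle* h : inputs_) h->Ref();
    for (ToyHandle* h : retvals_) h->Ref();
  }
  ~ToyAsyncNode() {
    for (ToyHandle* h : inputs_) h->Unref();
    for (ToyHandle* h : retvals_) h->Unref();
  }
  void Abort(absl::Status s) {
    for (ToyHandle* h : retvals_) h->Poison(s);
  }

 private:
  std::vector<ToyHandle*> inputs_;
  std::vector<ToyHandle*> retvals_;
};

int main() {
  ToyHandle in, out;
  {
    ToyAsyncNode node({&in}, {&out});
    node.Abort(absl::CancelledError("executor shut down"));
  }  // The node released its references here.
  std::cout << out.status.ToString() << " refs=" << out.refs << "\n";
}
```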
+ +#include +#include +#include +#include +#include +#include +#include + +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "absl/memory/memory.h" +#include "tensorflow/core/platform/platform.h" +// clang-format on + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/util/managed_stack_trace.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" +#endif // IS_MOBILE_PLATFORM + +namespace tensorflow { + +static constexpr const char* const kOutputsOnOpDevice = "_OutputsOnOpDevice"; + +class ProcessFunctionLibraryRuntime; +class FunctionLibraryRuntime; + +const int64_t kInvalidOpId = -1; + +// This struct is used for: +// 1. Setting `op_id` and `step_id`, `is_component_function` for single-client +// remote function scenario, +// 2. Setting `step_id` for multi-client parallel_device scenario. +// 3. Supplying an overriding, private `FunctionLibraryDefinition` for component +// functions. +struct EagerFunctionParams { + int64_t op_id = kInvalidOpId; + bool is_component_function; + std::optional step_id = std::nullopt; + FunctionLibraryDefinition* func_lib_def_override = + nullptr; // Not owned (owned by `EagerContext`). If not null, functions + // called by the function will be looked up in this library. +}; + +class EagerKernelArgs : public FunctionArgsInterface { + public: + EagerKernelArgs() = default; + + explicit EagerKernelArgs(int count) : tensor_args_(count) {} + + explicit EagerKernelArgs(absl::InlinedVector&& tensor_args) + : tensor_args_(std::move(tensor_args)) {} + + ~EagerKernelArgs() override = default; + + bool HasRemoteOrPackedInputs() const override { return false; }; + TensorValue* MutableInput(int i) { return &tensor_args_[i]; } + + absl::Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const override; + + std::vector GetLocalTensors() const override; + + const absl::InlinedVector* GetTensorValues() const { + return &tensor_args_; + } + + protected: + absl::InlinedVector tensor_args_; +}; + +typedef std::variant EagerKernelRet; + +// KernelAndDevice encapsulates the logic needed to run a computation eagerly. +// The computation can be a single instantiated kernel (implemented by +// KernelAndDeviceOp below) or a multi-device function (implemented by +// KernelAndDeviceFunc below). +// +// Also see: +// https://www.tensorflow.org/code/tensorflow/core/common_runtime/kernel_benchmark_testlib.h +// and +// https://www.tensorflow.org/code/tensorflow/core/kernels/ops_testutil.h +class KernelAndDevice : public core::RefCounted { + public: + // Populates this with a kernel appropriate for 'ndef'. + // + // The provided FunctionLibraryRuntime MUST outlive all calls to + // Run() on the returned KernelAndDevice. 
+ virtual absl::Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) = 0; + + // Non-multi-device functions are run using regular CallOp and look like + // primitive operations from KernelAndDevice perspective. + // `flr` can be nullptr if the operation is not run on any specific device + // (currently can happen only for multi-device functions). + KernelAndDevice( + FunctionLibraryRuntime* flr, + std::function)>* runner, + std::unique_ptr collective_executor, + Device* host_cpu_device) + : device_(flr == nullptr ? nullptr : flr->device()), + host_cpu_device_(host_cpu_device), + flr_(flr), + collective_executor_(std::move(collective_executor)), + runner_(runner) {} + + // Not thread safe. + ~KernelAndDevice() override = default; + + virtual bool IsFunction() { return false; } + + virtual bool IsCrossProcess() { return false; } + + // TODO(ashankar): Handle list-valued inputs. + virtual absl::Status Run( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, + CancellationManager* cancellation_manager, + const absl::optional& eager_func_params, + const absl::optional& stack_trace, + tsl::CoordinationServiceAgent* coordination_service_agent) = 0; + + // Execute kernel asynchronously when applicable. Different from `Run` which + // blocks the caller thread and waits for the execution of the op/function, + // `RunAsync` could return before finishing the execution. The `done` callback + // will be triggered once the op/function execution finishes. + // Currently, calling RunAsync on ops might not honor the asynchronicity when + // it is called on an instance with only sync implementation, execute the + // kernel synchronously and then call the callback with the return status + // from sync execution. + virtual void RunAsync( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, + CancellationManager* cancellation_manager, + const absl::optional& eager_func_params, + tsl::CoordinationServiceAgent* coordination_service_agent, + StatusCallback done) = 0; + + virtual Device* InputDevice(int i) const = 0; + virtual Device* OutputDevice(int idx) const = 0; + // If idx'th output is a resource, returns the device backing the resource. + // Else, returns nullptr. + virtual Device* OutputResourceDevice(int idx) const = 0; + + // Returns the kernel that will be used to run this. + // Returns nullptr if this will be run using function library runtime. + virtual const OpKernel* kernel() const = 0; + + // Returns the device on which this kernel will run. In the case of + // multi-device functions, this is the default device that is passed to the + // placer but actual computation can happen on a different set of devices. + // Also, outputs can be produced on devices different from what this method + // returns. 
+ Device* device() const { return device_; } + + virtual const DataTypeVector& input_dtypes() const = 0; + virtual const DataTypeVector& output_dtypes() const = 0; + + virtual int num_inputs() const = 0; + virtual int num_outputs() const = 0; + virtual const string& name() const = 0; + + protected: + std::function)>* get_runner() const; + + Device* const device_; // can be null + Device* const host_cpu_device_; // non-null + FunctionLibraryRuntime* const flr_; // can be null + const std::unique_ptr collective_executor_; + + private: + std::function)>* const runner_; // can be null +}; + +// Represents an op kernel and the device it will be run on. +class KernelAndDeviceOp final : public KernelAndDevice { + public: + KernelAndDeviceOp( + tensorflow::Rendezvous* rendezvous, bool log_memory, + FunctionLibraryRuntime* flr, + std::function)>* runner, + std::unique_ptr collective_executor, + Device* host_cpu_device) + : KernelAndDevice(flr, runner, std::move(collective_executor), + host_cpu_device), + rendezvous_(rendezvous), + log_memory_(log_memory) {} + + ~KernelAndDeviceOp() override = default; + + absl::Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) override; + + absl::Status Run( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, + CancellationManager* cancellation_manager, + const absl::optional& eager_func_params, + const absl::optional& stack_trace, + tsl::CoordinationServiceAgent* coordination_service_agent) override; + + void RunAsync(ScopedStepContainer* step_container, + const EagerKernelArgs& inputs, + std::vector* outputs, + CancellationManager* cancellation_manager, + const absl::optional& eager_func_params, + tsl::CoordinationServiceAgent* coordination_service_agent, + StatusCallback done) override { + // Trivial async implementation on top of the sync version + done(Run(step_container, inputs, outputs, cancellation_manager, + eager_func_params, {}, coordination_service_agent)); + } + + const OpKernel* kernel() const override { return kernel_.get(); } + + Device* InputDevice(int i) const override; + Device* OutputDevice(int idx) const override; + Device* OutputResourceDevice(int idx) const override; + + const DataTypeVector& input_dtypes() const override { + return kernel_->input_types(); + } + const DataTypeVector& output_dtypes() const override { + return kernel_->output_types(); + } + int num_inputs() const override { return kernel_->num_inputs(); } + int num_outputs() const override { return kernel_->num_outputs(); } + const string& name() const override { return kernel_->name(); } + + private: + std::unique_ptr kernel_; + bool is_distributed_communication_op_; + absl::InlinedVector input_alloc_attrs_; + std::vector input_devices_; + absl::InlinedVector output_alloc_attrs_; + Rendezvous* const rendezvous_; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; + const bool log_memory_; +}; + +// Represents a multi-device function. Functions can also be run using +// various function-calling kernels including CallOp and PartitionedCallOp. +// In such cases, KernelAndDeviceOp is used. +class KernelAndDeviceFunc : public KernelAndDevice { + public: + // `flr` can be nullptr. + // `pflr` must not be nullptr. + // `host_cpu_device` must not be nullptr. 
+ KernelAndDeviceFunc( + FunctionLibraryRuntime* flr, ProcessFunctionLibraryRuntime* pflr, + std::vector input_devices, + absl::flat_hash_map*> composite_devices, + std::unordered_map + input_resource_dtypes_and_shapes, + std::function)>* runner, + std::unique_ptr collective_executor, + Device* host_cpu_device, const string& name, + const bool outputs_on_op_device, + const bool allow_small_function_optimizations, + const bool allow_control_flow_sync_execution, + const bool shape_inference_on_tfe_dialect_import, + const bool int_args_and_retvals_on_device, + std::optional xla_compile_device_type, + const bool allow_soft_placement, Rendezvous::Factory rendezvous_factory, + std::function get_op_id) + : KernelAndDevice(flr, runner, std::move(collective_executor), + host_cpu_device), + pflr_(pflr), + handle_(kInvalidHandle), + outputs_on_op_device_(outputs_on_op_device), + allow_small_function_optimizations_(allow_small_function_optimizations), + allow_control_flow_sync_execution_(allow_control_flow_sync_execution), + shape_inference_on_tfe_dialect_import_( + shape_inference_on_tfe_dialect_import), + int_args_and_retvals_on_device_(int_args_and_retvals_on_device), + xla_compile_device_type_(xla_compile_device_type), + allow_soft_placement_(allow_soft_placement), + input_devices_(std::move(input_devices)), + composite_devices_(std::move(composite_devices)), + input_resource_dtypes_and_shapes_( + std::move(input_resource_dtypes_and_shapes)), + name_(name), + rendezvous_factory_(std::move(rendezvous_factory)), + get_op_id_(std::move(get_op_id)) {} + + ~KernelAndDeviceFunc() override; + + bool IsFunction() override { return true; }; + + bool IsCrossProcess() override { return is_cross_process_; } + + absl::Status InstantiateFunc( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params); + + absl::Status Init( + bool log_device_placement, const NodeDef& ndef, + GraphCollector* graph_collector, + const absl::optional& eager_func_params) override; + + absl::Status Run( + ScopedStepContainer* step_container, const EagerKernelArgs& inputs, + std::vector* outputs, + CancellationManager* cancellation_manager, + const absl::optional& eager_func_params, + const absl::optional& stack_trace, + tsl::CoordinationServiceAgent* coordination_service_agent) override; + + void RunAsync(ScopedStepContainer* step_container, + const EagerKernelArgs& inputs, + std::vector* outputs, + CancellationManager* cancellation_manager, + const absl::optional& eager_func_params, + tsl::CoordinationServiceAgent* coordination_service_agent, + StatusCallback done) override; + + const OpKernel* kernel() const override { return nullptr; } + + Device* InputDevice(int i) const override; + Device* OutputDevice(int idx) const override; + Device* OutputResourceDevice(int idx) const override; + + const DataTypeVector& input_dtypes() const override { return input_dtypes_; } + const DataTypeVector& output_dtypes() const override { + return output_dtypes_; + } + int num_inputs() const override { return input_dtypes_.size(); } + int num_outputs() const override { return output_dtypes_.size(); } + const string& name() const override { return name_; }; + + private: + std::shared_ptr PrepareForRun( + ScopedStepContainer* step_container, std::vector* outputs, + CancellationManager* cancellation_manager, + const absl::optional& eager_func_params, + const absl::optional& stack_trace, + tsl::CoordinationServiceAgent* coordination_service_agent, + tsl::core::RefCountPtr* rendezvous); 
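[Editor's aside] `KernelAndDeviceOp::RunAsync()` above is the documented "trivial async on top of sync" adapter: it invokes the blocking `Run()` and hands the resulting status to the `done` callback, without introducing any real concurrency. A standalone sketch of that adapter shape, with `std::function` standing in for `StatusCallback` and a toy kernel in place of the real class.

```cpp
#include <functional>
#include <iostream>

#include "absl/status/status.h"

using ToyStatusCallback = std::function<void(absl::Status)>;

class ToyKernel {
 public:
  // Blocking execution.
  absl::Status Run(int x) {
    return x >= 0 ? absl::OkStatus()
                  : absl::InvalidArgumentError("negative input");
  }

  // Callback-style entry point layered on the synchronous Run(): the caller
  // gets the async signature, but execution completes before RunAsync returns.
  void RunAsync(int x, ToyStatusCallback done) { done(Run(x)); }
};

int main() {
  ToyKernel k;
  k.RunAsync(3, [](absl::Status s) { std::cout << s.ToString() << "\n"; });
  k.RunAsync(-1, [](absl::Status s) { std::cout << s.ToString() << "\n"; });
}
```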
+ + ProcessFunctionLibraryRuntime* const pflr_; // non-null + FunctionLibraryRuntime::Handle handle_; + // Indicates whether the function needs to execute cross process. + bool is_cross_process_; + + // If true, function outputs are explicitly assigned to the default device; + // if false, the output devices are inferred by pflr_. + bool outputs_on_op_device_; + + // If True, allow optimizations which should be targeted at a limited + // set of small functions. (For example, running kernels synchronously can + // be faster under some conditions.) + const bool allow_small_function_optimizations_; + + // If True, allows control nodes to run on the single threaded executor. + const bool allow_control_flow_sync_execution_; + + // TODO(b/176491312): Remove this if shape inference on import flag is + // removed. If True, allows mlir roundtrip to run shape inference on import. + const bool shape_inference_on_tfe_dialect_import_; + + const bool int_args_and_retvals_on_device_; + + const absl::optional xla_compile_device_type_; + + const bool allow_soft_placement_; + + // CPU devices are null. Resource handles' devices are actual backing + // devices. + std::vector output_devices_; + // CPU devices are not null. Resource handles' devices are actual backing + // devices. + std::vector input_devices_; + // Maps from a CompositeDevice name to a list of physical device names. + absl::flat_hash_map*> composite_devices_; + std::unordered_map + input_resource_dtypes_and_shapes_; + + DataTypeVector input_dtypes_; + DataTypeVector output_dtypes_; + string name_; + + Rendezvous::Factory rendezvous_factory_; + std::function get_op_id_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_KERNEL_AND_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/placement_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/placement_utils.h new file mode 100644 index 00000000..fa51f198 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/placement_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ + +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace eager { + +bool IsColocationExempt(absl::string_view op_name); + +bool IsFunction(absl::string_view op_name); + +// TODO(b/154234908): Unify placement logic. + +// Pin the op to cpu if all op inputs are on the CPU, small (<64 elements) and +// integers (int32/int64). This can be disabled by setting the environment +// variable "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false". 
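As a concrete illustration of the opt-out mentioned in the comment above (a sketch only, not part of this header; `setenv` is the standard POSIX call, and "false" would work just as well as "0" per the comment):

#include <cstdlib>

int main() {
  // Disable small-tensor CPU pinning for this process before the eager
  // runtime reads the environment variable (POSIX setenv; sketch only).
  setenv("TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", "0", /*overwrite=*/1);
  // ... create the eager context and run ops afterwards ...
  return 0;
}

The declaration of MaybePinSmallOpsToCpu that this comment documents follows next.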
+absl::Status MaybePinSmallOpsToCpu( + bool* result, absl::string_view op_name, + absl::Span args, + absl::string_view cpu_device_name); + +// If a resource touching input is specified, all resource-touching ops run in +// the device the resource is, regardless of anything else that has been +// specified. This is identical to the graph mode behavior. +absl::Status MaybePinToResourceDevice(Device** device, + const EagerOperation& op); +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/rendezvous_cache.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/rendezvous_cache.h new file mode 100644 index 00000000..e79171f8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/rendezvous_cache.h @@ -0,0 +1,146 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_RENDEZVOUS_CACHE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_RENDEZVOUS_CACHE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/refcount.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { + +// The class for caching Rendezvous instances per step_id. +// If the Rendezvous object is destroyed for the step, a new one will be +// created on demand. +template +class RendezvousCache : public tsl::core::WeakRefCounted { + public: + RendezvousCache() = default; + ~RendezvousCache() override { + for (auto& p : table_) { + auto rendez = p.second.GetNewRef(); + if (rendez) { + rendez->StartAbort(tsl::errors::Aborted("Shutdown")); + } + } + } + + // Returns a new Reference. + template + tsl::core::RefCountPtr FindOrCreate(int64_t step_id, + RendezvousCreator create_fn) { + tsl::mutex_lock l(table_lock_); + tsl::core::RefCountPtr rendz = nullptr; + auto iter = table_.find(step_id); + if (iter != table_.end()) { + rendz = iter->second.GetNewRef(); + VLOG(5) << "step_id:" << step_id << " " + << "WeakPtr returned:" << rendz.get(); + if (!rendz) { + table_.erase(iter); + } + } + if (!rendz) { // Deleted or not found + rendz = create_fn(); + VLOG(5) << "step_id:" << step_id << " " + << "Rendezvous not found, inserting a new one." << rendz.get(); + auto cleanup_fn = [weak_cache = tsl::core::WeakPtr(this), + step_id]() { + tsl::core::RefCountPtr cache = weak_cache.GetNewRef(); + if (cache != nullptr) { + // If the rendezvous is released, Find() will clean it up from the + // map. + cache->Find(step_id); + } + }; + table_.insert({step_id, tsl::core::WeakPtr{rendz.get(), cleanup_fn}}); + } + return rendz; + } + + // Returns a new Reference. 
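A minimal sketch of how a caller might use FindOrCreate above. `MyRendezvous` and `device_mgr` are hypothetical stand-ins (any type deriving from tsl::core::WeakRefCounted would do); the cache hands back a new strong reference each call and recreates the entry on demand once the previous rendezvous has been released:

// Hypothetical types for illustration only.
tsl::core::RefCountPtr<tensorflow::RendezvousCache<MyRendezvous>> cache(
    new tensorflow::RendezvousCache<MyRendezvous>());

tsl::core::RefCountPtr<MyRendezvous> rendez =
    cache->FindOrCreate(/*step_id=*/42, [&]() {
      // Invoked only if no live rendezvous exists for this step.
      return tsl::core::RefCountPtr<MyRendezvous>(new MyRendezvous(device_mgr));
    });

// While any strong reference is alive, Find(42) and a repeated FindOrCreate(42)
// return the same underlying rendezvous; once all strong references are gone,
// the next FindOrCreate(42) creates a fresh one via the creator callback.

The Find() accessor documented by the comment above is declared next.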
+ tsl::core::RefCountPtr Find(int64_t step_id) { + tsl::mutex_lock l(table_lock_); + auto iter = table_.find(step_id); + if (iter == table_.end()) return nullptr; + tsl::core::RefCountPtr res = iter->second.GetNewRef(); + // Cleans the record if the rendezvous is already destroyed. + if (res == nullptr) { + table_.erase(iter); + } + return res; + } + + // Removes a Rendezvous weak reference from table. + void Remove(int64_t step_id) { + tsl::mutex_lock l(table_lock_); + table_.erase(step_id); + } + + // Removes a Rendezvous weak reference from table, and abort the rendezvous. + void RemoveAndAbort(int64_t step_id) { + tsl::core::RefCountPtr rendez = nullptr; + { + tsl::mutex_lock l(table_lock_); + auto iter = table_.find(step_id); + if (iter != table_.end()) { + rendez = iter->second.GetNewRef(); + table_.erase(iter); + } + } + if (rendez) { + rendez->StartAbort(tsl::errors::Aborted("Cleanup ", step_id)); + } + } + + void RemoveAll() { + tsl::mutex_lock l(table_lock_); + table_.clear(); + } + + // Returns a list of active step ids. This result is only informative + // at time of the call. The returned vector may contain step ids that have + // been invalidated after the call. + std::vector GetActiveStepIds() { + std::vector list; + tsl::mutex_lock l(table_lock_); + list.reserve(table_.size()); + for (const auto& iter : table_) { + list.push_back(iter.first); + } + return list; + } + + size_t Size() const { + tsl::mutex_lock l(table_lock_); + return table_.size(); + } + + private: + mutable tsl::mutex table_lock_; + absl::flat_hash_map> table_ + TF_GUARDED_BY(table_lock_); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_RENDEZVOUS_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/shape_inference.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/shape_inference.h new file mode 100644 index 00000000..be386f97 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/shape_inference.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_ + +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace eager { + +absl::Status RunShapeInference( + const NodeDef& ndef, const FunctionLibraryDefinition& lib_def, + const absl::InlinedVector& inputs, + const absl::InlinedVector& retvals); + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SHAPE_INFERENCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/small_constants_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/small_constants_optimizer.h new file mode 100644 index 00000000..cb70fb99 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/small_constants_optimizer.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SMALL_CONSTANTS_OPTIMIZER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SMALL_CONSTANTS_OPTIMIZER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" + +namespace tensorflow::small_constants_optimizer { + +// Checks whether small constant optimization is enabled for a tf.function. +bool IsSmallConstantOptimizationEnabled(const FunctionDef& fdef); + +// Generates new FunctionDefs with the boolean input tensors folded as +// constants into the FunctionDef. +std::vector FoldInputTensors( + const FunctionDef& fdef, const FunctionLibraryDefinition& flib); + +// Generates the FunctionDef name for the folded function. +std::string FoldedFunctionName(absl::string_view fname, + absl::string_view input_name, bool input_value); + +} // namespace tensorflow::small_constants_optimizer + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SMALL_CONSTANTS_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/summary_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/summary_optimizer.h new file mode 100644 index 00000000..0b337e04 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/summary_optimizer.h @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SUMMARY_OPTIMIZER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SUMMARY_OPTIMIZER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" + +namespace tensorflow::summary_optimizer { +namespace internal { + +// Normalizes an edge's name to match the names stored in a NodeDef. +std::string NormalizeEdgeName(absl::string_view name); + +} // namespace internal + +// Returns the name of the input_arg and the bool value that determines whether +// or not to disable summaries. If no such arg exists returns an empty string. +std::pair GetDisableSummariesInputArg( + const FunctionDef& fdef); + +// Generates new FunctionDef(s) with the summaries stripped out. +// This function will traverse all the nested functions and generate a version +// of the nested functions with summaries stripped out. +std::vector StripSummaries(const FunctionDef& fdef, + const FunctionLibraryDefinition& flib); + +// Generates a new function name for the stripped function. +std::string StrippedFunctionName(absl::string_view fname); + +} // namespace tensorflow::summary_optimizer + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_SUMMARY_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/tensor_handle.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/tensor_handle.h new file mode 100644 index 00000000..ca60815d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -0,0 +1,419 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/platform.h" +// clang-format on + +#include "absl/types/variant.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" +#include "tensorflow/core/common_runtime/function.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h" +#endif // IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/tensor.h" + +#include "tensorflow/core/lib/core/stringpiece.h" + +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +class EagerContext; + +// Associates a Tensor and a Device, used in the eager runtime. Internal version +// of the TFE_TensorHandle struct and the python EagerTensor class +// (unrelated to python TensorHandle). +class TensorHandle : public ImmediateExecutionTensorHandle { + // TensorHandle for dtype != DT_RESOURCE + TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, + Device* resource_device, EagerContext* ctx); + // TensorHandle for dtype == DT_RESOURCE + TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, + EagerContext* ctx); + TensorHandle(Device* d, Device* op_device, Device* resource_device, + tensorflow::DataType dtype, EagerContext* ctx); + +#if !defined(IS_MOBILE_PLATFORM) + TensorHandle(int64_t op_id, int32_t output_num, const string& remote_task, + tensorflow::DataType dtype, Device* device, EagerContext* ctx, + bool unknown_device); + TensorHandle(int64_t op_id, int32_t output_num, tensorflow::DataType dtype, + Device* device, bool is_ready, EagerContext* ctx); +#endif // IS_MOBILE_PLATFORM + + public: + // TensorHandle with no assigned device + static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t); + static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, EagerContext* ctx); + static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, + Device* resource_device, + EagerContext* ctx); + static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device, + Device* resource_device, + tensorflow::DataType dtype, + EagerContext* ctx); + + // Create a handle which packs the given handles of the same dtype and shape. + // If handles are on different devices, assign the packed handle to a + // CompositeDevice. + // + // The new tensor handle shares ownership of the given handle: their reference + // count will be increased by one after a call to `CreatePackedHandle`. + // TODO(b/170414377): Use `TensorHandlePtr` instead. 
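A short usage sketch for the local-handle factories declared above (illustrative only; the no-device overload wraps a host tensor, and the accessors used below are declared further down in this class):

// Wrap a host tensor in a ref-counted TensorHandle (sketch).
tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 2}));
tensorflow::TensorHandle* h = tensorflow::TensorHandle::CreateLocalHandle(t);

const tensorflow::Tensor* view = nullptr;
absl::Status s = h->Tensor(&view);  // Borrow the backing tensor.
if (s.ok()) {
  // ... read *view ...
}
h->Unref();  // Handles are reference counted; release when done.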
+ static absl::Status CreatePackedHandle(std::vector&& handles, + tensorflow::DataType dtype, + const tensorflow::TensorShape& shape, + const string& device_name, + EagerContext* ctx, + TensorHandle** packed_handle); + static absl::Status CreatePackedHandle(std::vector&& handles, + EagerContext* ctx, + TensorHandle** packed_handle); + +#if !defined(IS_MOBILE_PLATFORM) + // An unshaped remote handle refers to a tensor on a remote worker. It's not + // ready until the shape is set. It controls the lifetime of the remote + // tensor. + static TensorHandle* CreateUnshapedRemoteHandle(int64_t op_id, + int32_t output_num, + const string& remote_task, + tensorflow::DataType dtype, + Device* d, EagerContext* ctx, + bool unknown_device = false); + // A lazy remote handle refers to a tensor on a remote worker. The lifetime of + // the remote tensor is controlled by the remote worker, but not by the lazy + // remote handle. Lazy handles are normally created on a default function + // device. + static TensorHandle* CreateLazyRemoteHandle(int64_t op_id, int32_t output_num, + tensorflow::DataType dtype, + Device* d, bool is_ready, + EagerContext* ctx); +#endif // IS_MOBILE_PLATFORM + + // Templated struct `AutoReleaser` in + // core/runtime_fallback/runtime/kernel_utils.h needs a Release() method + // defined. + void Release(); + + tensorflow::DataType DataType() const override; + absl::Status Shape(tensorflow::PartialTensorShape* shape) const override; + absl::Status NumDims(int* num_dims) const override; + absl::Status NumElements(int64_t* num_elements) const override; + absl::Status Dim(int dim_index, int64_t* dim) const override; + + const char* DeviceName(absl::Status* status) const override; + const char* BackingDeviceName(absl::Status* status) const override; + const char* DeviceType(absl::Status* status) const override; + int DeviceId(absl::Status* status) const override; + AbstractTensorInterface* Resolve(absl::Status* status) override; + + // Subclasses may return True to instruct the string formatter + // to use SummarizeValue instead of the NumPy formatter. + bool PreferCustomSummarizer() const override { + return dtype == DT_VARIANT || dtype == DT_RESOURCE; + } + + // Return the Tensor from the default device. + absl::Status Tensor(const tensorflow::Tensor** t) const; + // Return the Tensor from the specified device which could be either the + // default device or a local mirror. The device pointer should be nullptr if + // requesting the HostCPU. + absl::Status TensorFromDevice(const Device* d, + const tensorflow::Tensor** t) const; + + // Return the TensorValue from the specified device which could be either the + // default device or a local mirror. The device pointer should be nullptr if + // requesting the HostCPU. + absl::Status TensorValue(const Device* d, tensorflow::TensorValue* t); + + Device* device() const { return device_; } + Device* op_device() const { return op_device_; } + Device* resource_device() const { return resource_device_; } + int64_t resource_remote_device_incarnation() const { + return resource_remote_device_incarnation_; + } + + // If the devices are unknown at creation time, block until the actual devices + // are set (data is ready). + absl::Status WaitUnknownDevice() const; + + Device* DeviceOrHostCPU(const EagerContext& ctx) const; + + absl::Status Shape(tensorflow::TensorShape* shape); + + absl::Status Unprotect(const Device* d); + + // Checks if a mirror tensor exists for the specified device. 
Mirrors are only + // maintained for local devices, like CPUs & GPUs. Note a mirror may be empty, + // as it is still to be set by an async operation. + bool HasLocalMirror(const Device* d) const; + // Add an empty mirror placeholder for the specified device. The expectation + // is this will be populated by a call to SetTensor. + absl::Status AddEmptyLocalMirror(const Device* d); + // Add a local mirror. This will fail if an empty local mirror was previously + // added. For that case, SetTensor should be used instead. + absl::Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d); + +#if !defined(IS_MOBILE_PLATFORM) + bool HasRemoteMirror(const Device* d, uint64 context_view_id) const; + bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const; + + absl::Status AddUnshapedRemoteMirror(const Device* d, int64_t op_id, + int output_num, + const string& remote_task, + EagerContext* ctx); + absl::Status AddResourceShapeMirror(const Device* d, int64_t op_id, + int output_num, EagerContext* ctx); + + // Return the op_id and output num if the handle refers to a remote tensor. + // If wait_until_ready is true, block until the remote tensor is ready on the + // given remote worker. + absl::Status RemoteAddress(const Device* d, bool wait_until_ready, + int64_t* op_id, int32* output_num) const; + + // Called on an async remote tensor once it's shape has been determined. This + // transitions the tensor handle from a non-ready to a ready state by + // replacing the backing data abstraction to allow for the shape to be + // queried. + // creating a TensorHandle (e.g. a remote output of a remote function). + // This method or Poison must be called exactly once for remote tensors that + // were created without a known shape. + absl::Status SetRemoteShape(const TensorShape& shape, const Device* d, + uint64 context_view_id); + // If op_device is not empty, reset the devices of a remote tensor which is + // created without known devices (e.g. function outputs). + absl::Status SetRemoteShapeAndDevice(const TensorShape& shape, + const Device* d, uint64 context_view_id, + string op_device); + + // Poisons either this handle or a remote mirror with error `status`. + // Poisoning means that the handle will become ready and methods trying + // to access the remote shape will return this error `status`. + // Exactly one of SetRemoteShape or PoisonRemote methods must be called on a + // unshaped handle on a remote device. + void PoisonRemote(absl::Status status, const Device* d, + uint64 context_view_id); +#endif + + // Sets the `tensor` for this async non-ready handle making it ready. + // This method or Poison must be called exactly once for non-ready async + // handles to make them ready. + absl::Status SetTensor(tensorflow::Tensor&& tensor, const Device* d); + + // Poisons either this handle or a local mirror with error `status`. + // Poisoning means that the handle will become ready and methods trying + // to access the actual tensor or shape will return this error `status`. + // Exactly one of SetTensor or Poison methods must be called on a non-ready + // tensor for a specific device. + void Poison(absl::Status status, const Device* d); + + // TODO(b/154282629): Consider moving it to EagerContext. + // Copies to the tensor on the given device `d`, or to host iff `d` is null. 
+ absl::Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d, + tensorflow::Tensor* output) const; + + absl::Status InferenceShape( + shape_inference::InferenceContext* inference_context, + shape_inference::ShapeHandle* shape_handle); + void SetInferenceShape(shape_inference::InferenceContext* inference_context, + const shape_inference::ShapeHandle& shape_handle); + absl::Status CopyInferenceShape(TensorHandle* other); + + // dtype for the handle. It must be the same as t.dtype() once the handle is + // ready. + const tensorflow::DataType dtype; + + enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 }; + + HandleType Type() const; + string TypeString() const; + + void SetResourceHandleDtypeAndShape( + std::vector dtypes_and_shapes); + + // If this TensorHandle is 1) a local tensor, and 2) a resource handle, + // return data types and shapes of the underlying resource. + absl::Status GetResourceHandleDtypesAndShapes( + std::vector* result); + + // Returns the number of packed handles. 0 if the handle type is not PACKED. + int NumPackedHandles() const; + // It's called on a packed TensorHandle. Extract a handle with the given + // index. + absl::Status ExtractPackedHandle(int index, TensorHandle** handle) const; + + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kEager; + } + + tensorflow::FullTypeDef FullType() const override { return full_type_; } + + void SetFullType(FullTypeDef& full_type) { full_type_ = full_type; } + + private: + friend class PackedTensorHandleTest; + + TensorHandle(std::vector&& handles, Device* device, + tensorflow::DataType dtype, const tensorflow::TensorShape& shape, + EagerContext* ctx); + + ~TensorHandle() override; + + // The TensorHandleData can either represent a local or remote tensor handle. + // Further, it can be in a non-ready state. It would become ready with a call + // to either SetTensor or SetRemoteShape which replaces the underlying data + // with a ready version of the tensor handle data. + bool IsReady() const; + absl::Status WaitReady(const char* caller) const; + + tensorflow::Device* device_; + + // Device in which the op producing this tensor was executed. Equals to + // device_ for constant tensors. + // Can be nullptr if the op producing this tensor was a function executed + // with function library runtime. + tensorflow::Device* op_device_; + + // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device + // backing the resource. Else resource_device_ is nullptr. + tensorflow::Device* resource_device_; + // Incarnation ID of the resource device if it locates on a remote device, or + // 0 if it locates on a local device. + int64_t resource_remote_device_incarnation_; + + // If true, the handle refers to a remote tensor which is created without + // known devices. The actual devices are set by SetRemoteShape. The devices + // should be accessed once the handle is ready. + const bool unknown_device_ = false; + + mutable mutex mu_; + + // Map of local mirrors. This can include both ready and non-ready mirrors. + std::unordered_map + local_mirrors_ TF_GUARDED_BY(mu_); +#if !defined(IS_MOBILE_PLATFORM) + // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica + // variable is ready, since we could get the shape locally without remote copy + // then. 
+ std::unordered_map resource_shape_mirrors_ + TF_GUARDED_BY(mu_); + std::unordered_map remote_mirrors_ + TF_GUARDED_BY(mu_); +#endif + + // `ctx` is only guaranteed to be set if the handle is not "ready". This is + // typically true when the handle was produced during async execution. + // `ctx` object is not owned and should outlive this handle. + // + // TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of + // a TensorHandle does not outlive the EagerContext from which it came? + EagerContext* const ctx_; + + // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or + // refers to a remote resource handle, we store data types and shapes for + // the underlying resource. + std::vector handle_dtypes_and_shapes_; + + // A handle data which refers to multiple TensorHandles of the same dtype and + // shape. + class PackedTensorHandleData { + public: + // Initialize handle data from list of tensor handles. + // Ownership of the tensor handles is shared between the + // `PackedTensorHandleData` and the caller (the reference count for the + // given handles is incremented). + // TODO(b/170414377): Use `TensorHandlePtr` instead. + PackedTensorHandleData(std::vector&& handles, + const TensorShape& shape); + + ~PackedTensorHandleData(); + + absl::Status Shape(TensorShape* shape) const; + absl::Status NumDims(int* num_dims) const; + absl::Status Dim(int dim_index, int64_t* dim) const; + absl::Status NumElements(int64_t* num_elements) const; + absl::Status Unprotect(); + bool IsReady() const; + absl::Status WaitReady(const char* caller) const; + void Poison(absl::Status status); + string DebugString() const; + + // Number of packed handles. + int NumPackedHandles() const; + // Extract a handle on the given index. + absl::Status ExtractPackedHandle(int index, TensorHandle** handle) const; + + private: + // TODO(b/170414377): Use `TensorHandlePtr` instead. + const std::vector handles_; + const TensorShape shape_; + + mutable mutex mu_; + absl::Status is_poisoned_ TF_GUARDED_BY(mu_); + }; + + // Does not need synchronization because it can be accessed only after + // WaitReady() has returned. At that point, data_ is immutable. +#if !defined(IS_MOBILE_PLATFORM) + std::variant + data_; +#else + absl::variant data_; +#endif + + PartialTensorShape inference_shape_; + + FullTypeDef full_type_; +}; + +// Returns the device backing the resource. Else, returns nullptr. +Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx); + +class TensorHandleInterface : public ImmediateExecutionTensorHandle { + public: +}; + +template +inline TensorHandle* TensorHandleFromInterface(T* handle) { + return down_cast(handle); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/tensor_handle_data.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/tensor_handle_data.h new file mode 100644 index 00000000..ed58e83a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eager/tensor_handle_data.h @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ + +#include +#include + +#include "absl/types/variant.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Local Tensor Handle: Handle to a Tensor present on the local host. +class LocalTensorHandleData { + public: + LocalTensorHandleData() : ctrl_(absl::in_place_type) {} + explicit LocalTensorHandleData(tensorflow::Tensor&& t) + : tensor_(std::move(t)), + forwarding_protection_tensor_(tensor_), + ctrl_(absl::in_place_type) {} + + // A local tensor handle should be able to satisfy all of these requests. + absl::Status Tensor(const tensorflow::Tensor** t) const; + absl::Status TensorValue(tensorflow::TensorValue* t); + absl::Status Shape(TensorShape* shape) const; + absl::Status NumDims(int* num_dims) const; + absl::Status Dim(int dim_index, int64_t* dim) const; + absl::Status NumElements(int64_t* num_elements) const; + absl::Status Unprotect(); + + bool IsReady() const { + return std::visit([](auto& data) { return data.IsReady(); }, ctrl_); + } + + absl::Status WaitReady(const char* caller) const { + return std::visit([caller](auto& data) { return data.WaitReady(caller); }, + ctrl_); + } + void Poison(absl::Status status) { + return std::visit([status](auto& data) { data.Poison(status); }, ctrl_); + } + absl::Status IsPoisoned() const { + return std::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_); + } + + absl::Status SetTensor(tensorflow::Tensor&& t); + + string DebugString() const; + + private: + tensorflow::Tensor tensor_; + // TensorHandle has its own reference counting which is distinct from the + // backing Tensor. As a result, if the Tensor reference count is 1 while + // executing an op, the TensorBuffer could be reused for the output. We avoid + // this behavior maintaining another reference count with the + // forwarding_protection_tensor_ Tensor. When Unprotect() is called, we + // release this Tensor to allow forwarding. + tensorflow::Tensor forwarding_protection_tensor_; + + // We distinguish between ready and empty tensors with the ctrl_ variant. + // which contains 2 implementations of the waiting logic. The + // NonBlockingControl is a simple no-op class whereas the BlockingControl + // actually uses a mutex. By using a variant we avoid the overhead of + // constructing and destructing the mutex for ready local tensors. 
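The comment above describes why readiness control lives in a variant; a self-contained sketch of the same variant-dispatch pattern (the names here are illustrative, not the NonBlockingControl/BlockingControl classes that follow):

#include <variant>

// Cheap control for tensors that are ready at construction time: no mutex.
struct AlwaysReadyControl {
  bool IsReady() const { return true; }
};

// Control for async tensors: readiness is guarded by real synchronization
// (a plain bool stands in for the mutex-protected flag of the real class).
struct GatedControl {
  bool ready = false;
  bool IsReady() const { return ready; }
};

using Control = std::variant<AlwaysReadyControl, GatedControl>;

inline bool IsReady(const Control& ctrl) {
  // Same shape as the std::visit dispatch in LocalTensorHandleData above.
  return std::visit([](const auto& impl) { return impl.IsReady(); }, ctrl);
}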
+ class NonBlockingControl { + public: + bool IsReady() const { return true; } + absl::Status WaitReady(const char* caller) const { + return absl::OkStatus(); + } + void Poison(absl::Status status) {} + absl::Status IsPoisoned() const { return absl::OkStatus(); } + }; + + class BlockingControl { + public: + bool IsReady() const { + tf_shared_lock l(mu_); + return is_ready_; + } + void SetReady(); + absl::Status WaitReady(const char* caller) const; + void Poison(absl::Status status); + absl::Status IsPoisoned() const { + tf_shared_lock l(mu_); + return is_poisoned_; + } + + private: + mutable mutex mu_; + bool is_ready_ TF_GUARDED_BY(mu_); + absl::Status is_poisoned_ TF_GUARDED_BY(mu_); + }; + + std::variant ctrl_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/entry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/entry.h new file mode 100644 index 00000000..82bf44ea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/entry.h @@ -0,0 +1,141 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/manual_constructor.h" + +namespace tensorflow { + +class Tensor; + +// An Entry store a single input value for an individual kernel invocation in +// an executor. +// +// Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). +struct Entry { + enum class State { + NO_VALUE = 0, // The default state for a newly-created Entry. + HAS_VALUE, // `this->val` is valid. + HAS_CONST_TENSOR, // `this->const_tensor` is valid. + HAS_REF_TENSOR, // `this->ref_tensor` is valid. 
+ }; + + Entry() : state(State::NO_VALUE) {} + Entry(const Entry& other) : state(other.state), alloc_attr(other.alloc_attr) { + switch (state) { + case State::NO_VALUE: + break; + case State::HAS_VALUE: + val.Init(*other.val); + break; + case State::HAS_CONST_TENSOR: + const_tensor = other.const_tensor; + break; + case State::HAS_REF_TENSOR: + ref_tensor = other.ref_tensor; + break; + } + } + + ~Entry() { + if (state == State::HAS_VALUE) val.Destroy(); + } + + Entry& operator=(const Entry& other) { + if (state == State::HAS_VALUE) { + val.Destroy(); + } + state = other.state; + alloc_attr = other.alloc_attr; + switch (state) { + case State::NO_VALUE: + break; + case State::HAS_VALUE: + val.Init(*other.val); + break; + case State::HAS_CONST_TENSOR: + const_tensor = other.const_tensor; + break; + case State::HAS_REF_TENSOR: + ref_tensor = other.ref_tensor; + break; + } + return *this; + } + + Entry& operator=(Entry&& other) { + if (state == State::HAS_VALUE) { + val.Destroy(); + } + state = other.state; + alloc_attr = other.alloc_attr; + switch (state) { + case State::NO_VALUE: + break; + case State::HAS_VALUE: + val.Init(std::move(*other.val)); + break; + case State::HAS_CONST_TENSOR: + const_tensor = other.const_tensor; + break; + case State::HAS_REF_TENSOR: + ref_tensor = other.ref_tensor; + break; + } + return *this; + } + + // Clears the field, and sets this entry to the `NO_VALUE` state. + void ClearVal() { + if (state == State::HAS_VALUE) { + val.Destroy(); + } + state = State::NO_VALUE; + } + + union { + // A tensor value. Valid iff `state_ == HAS_VALUE`. + ManualConstructor val; + + // A pointer to a constant tensor value. Valid iff `state_ == + // HAS_CONST_TENSOR`. + const Tensor* const_tensor; + + // A tensor reference and associated mutex. Valid iff `state_ == + // HAS_REF_TENSOR`. + struct { + Tensor* tensor; + mutex* mu; + } ref_tensor; + }; + + // The current state of this entry, indicating which member of the above + // union is active. + State state; + + // The attributes of the allocator that creates the tensor. + AllocatorAttributes alloc_attr; +}; + +// TODO(b/152925936): Re-evaluate this constant with current usage patterns. +typedef absl::InlinedVector EntryVector; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ENTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/eval_const_tensor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eval_const_tensor.h new file mode 100644 index 00000000..049a3e9f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/eval_const_tensor.h @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ + +#include +#include + +#include "absl/functional/function_ref.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +class GraphRunner; +class Node; +class OpRegistryInterface; +class ShapeRefiner; +class Tensor; + +// Configuration of the graph runner for constant folding. +struct EvaluateConstantTensorRunner { + // Op registry for temporary graphs. By default, the global registry will + // be used. + const OpRegistryInterface* op_registry = nullptr; + // Version of the graph API to use. + int32_t graph_def_version = 0; + // Graph runner for constant folding. By default, a temporary graph runner + // will be created. + GraphRunner* graph_runner = nullptr; +}; + +// Attempts to evaluate an output of the given node. This will only be possible +// if it doesn't depend on any graph inputs (this function is safe to call +// if this isn't the case though). +// +// When the evaluation is successful, the function returns a tensor, otherwise +// it returns std::nullopt. +absl::StatusOr> EvaluateConstantTensor( + // The tensor to be evaluated. + const Node& node, int node_output, + // Used to fetch inference contexts for nodes in the graph. + const ShapeRefiner& refiner, + // Used to both lookup cached results and request function arguments. + absl::FunctionRef(const Node&, int)> lookup, + // Configuration of the graph runner. If not set, no attempt to fold a + // constant subgraph will be made. + std::optional runner); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EVAL_CONST_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/executor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/executor.h new file mode 100644 index 00000000..2a13ff0c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/executor.h @@ -0,0 +1,265 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_ + +#include + +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/local_executor_params.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/session_state.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool_interface.h" +#include "tensorflow/core/platform/error_logging.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/managed_stack_trace.h" + +namespace tensorflow { + +class StepStatsCollector; + +// Executor runs a graph computation. +// Example: +// Graph* graph = ...; +// ... construct graph ... +// Executor* executor; +// TF_CHECK_OK(NewSimpleExecutor(my_device, graph, &executor)); +// Rendezvous* rendezvous = NewNaiveRendezvous(); +// TF_CHECK_OK(rendezvous->Send("input", some_input_tensor)); +// TF_CHECK_OK(executor->Run({ExecutorOpts, rendezvous, nullptr})); +// TF_CHECK_OK(rendezvous->Recv("output", &output_tensor)); +// ... ... +// +// Multiple threads can call Executor::Run concurrently. +class Executor { + public: + virtual ~Executor() {} + + // RunAsync() executes the graph computation. "done" is run when the + // graph computation completes. If any error happens during the + // computation, "done" is run and the error is passed to "done". + // + // RunAsync() is given a few arguments in Args. The caller must + // ensure objects passed in Args (rendezvous, stats_collector, etc.) + // are alive at least until done is invoked. All pointers to the + // argument objects can be nullptr. + // + // "step_id" is a process-wide unique identifier for the step being + // run. Executors on different devices may receive the same step_id + // in the case that a step runs Ops on more than one device. The + // step_id is used for tracking resource usage of a given step. + // + // RunAsync() uses the given "rendezvous", if not null, as the + // mechanism to communicate inputs and outputs of the underlying + // graph computation. + // + // RunAsync() calls "stats_collector", if not null, to keep track of + // stats. This allows us to collect statistics and traces on demand. + // + // RunAsync() is provided a "call_frame", if the executor is used + // for executing a function, is used to pass arguments and return + // values between the caller and the callee. + // + // RunAsync() uses "cancellation_manager", if not nullptr, to + // register callbacks that should be called if the graph computation + // is canceled. Note that the callbacks merely unblock any + // long-running computation, and a canceled step will terminate by + // returning/calling the DoneCallback as usual. + // + // RunAsync() dispatches closures to "runner". Typically, "runner" + // is backed up by a bounded threadpool. + // + // "start_time_usecs" is a timestamp for the start of RunAsync() + // execution. Used for system-wide latency metrics. + struct Args { + int64_t step_id = 0; + // Used only by tracer/profiler, applicable only when running under + // FunctionRuntimeLibrary, unique per invocation. 
+ std::optional function_trace_id; + RendezvousInterface* rendezvous = nullptr; + StepStatsCollectorInterface* stats_collector = nullptr; + CallFrameInterface* call_frame = nullptr; + CancellationManager* cancellation_manager = nullptr; + const ConfigProto* session_config = nullptr; + SessionState* session_state = nullptr; + // Unique session identifier. Can be empty. + string session_handle; + TensorStore* tensor_store = nullptr; + ScopedStepContainer* step_container = nullptr; + CollectiveExecutor* collective_executor = nullptr; + thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr; + tsl::CoordinationServiceAgent* coordination_service_agent = nullptr; + int64_t start_time_usecs = 0; + // The deadline for the kernel to complete by. Empty if unspecified. + absl::optional deadline; + absl::optional stack_trace = absl::nullopt; + + // If true, calls Sync() on the device. + bool sync_on_finish = false; + + typedef std::function Closure; + typedef std::function Runner; + Runner runner = nullptr; + + // If true, all kernels will be treated as "inexpensive", and hence executed + // on the scheduling thread. + bool run_all_kernels_inline = false; + }; + typedef std::function DoneCallback; + + void RunAsync(const Args& args, DoneCallback done) { + RunAsyncInternal(args, [done = std::move(done)](const absl::Status& s) { + if (!s.ok()) Log("TFExecutor", "Run", s.message()).IgnoreError(); + done(s); + }); + } + + // Synchronous wrapper for RunAsync(). + virtual absl::Status Run(const Args& args) { + absl::Status ret; + Notification n; + RunAsync(args, [&ret, &n](const absl::Status& s) { + ret = s; + n.Notify(); + }); + n.WaitForNotification(); + return ret; + } + + private: + virtual void RunAsyncInternal(const Args& args, DoneCallback done) = 0; +}; + +// Creates an Executor that computes the given "graph". +// +// If successful, returns the constructed executor in "*executor". Otherwise, +// returns an error status. +// +// "params" provides a set of context for the executor. We expect that +// different context would provide different implementations. +absl::Status NewLocalExecutor(const LocalExecutorParams& params, + const Graph& graph, Executor** executor); + +// A class to help run multiple executors in parallel and wait until +// all of them are complete. +// +// ExecutorBarrier deletes itself after the function returned by Get() +// is called. +class ExecutorBarrier { + public: + typedef std::function StatusCallback; + + // Create an ExecutorBarrier for 'num' different executors. + // + // 'r' is the shared Rendezvous object that is used to communicate + // state. If any of the executors experiences an error, the + // rendezvous object will be aborted exactly once. + // + // 'done' is called after the last executor completes, and + // ExecutorBarrier is deleted. + ExecutorBarrier(size_t num, Rendezvous* r, StatusCallback done) + : rendez_(r), done_cb_(done), pending_(num) {} + + ~ExecutorBarrier() {} + + // Returns a closure that Executors must call when they are done + // computing, passing the status of their execution as an argument. 
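A sketch of how the barrier is typically driven (illustrative; `executors`, the shared rendezvous `r`, and `common_args` are assumed to be set up by the caller, and Get() is the accessor declared immediately below):

// Run N executors that share the rendezvous `r`; `done` fires exactly once.
Notification all_done;
absl::Status final_status;

ExecutorBarrier* barrier =
    new ExecutorBarrier(executors.size(), r, [&](const absl::Status& s) {
      final_status = s;  // Summary status across all executors.
      all_done.Notify();
    });

for (Executor* exec : executors) {
  Executor::Args args = common_args;  // Per-executor copy of shared args.
  exec->RunAsync(args, barrier->Get());
}

all_done.WaitForNotification();
// The barrier has deleted itself by now; do not touch it after the callback.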
+ StatusCallback Get() { + return std::bind(&ExecutorBarrier::WhenDone, this, std::placeholders::_1); + } + + private: + Rendezvous* rendez_ = nullptr; + StatusCallback done_cb_ = nullptr; + + mutable mutex mu_; + int pending_ TF_GUARDED_BY(mu_) = 0; + StatusGroup status_group_ TF_GUARDED_BY(mu_); + + void WhenDone(const absl::Status& s) { + Rendezvous* error_rendez = nullptr; + StatusCallback done = nullptr; + absl::Status status; + + { + mutex_lock l(mu_); + + // If we are the first error encountered, trigger an abort of the + // Rendezvous object by this thread only. + if (status_group_.ok() && !s.ok()) { + error_rendez = rendez_; + error_rendez->Ref(); + } + + if (!s.ok() && !StatusGroup::IsDerived(s) && + !status_group_.HasLogMessages()) { + status_group_.AttachLogMessages(); + } + + status_group_.Update(s); + + // If this is the last call to WhenDone, call the final callback + // below. + if (--pending_ == 0) { + CHECK(done_cb_ != nullptr); + std::swap(done, done_cb_); + status = status_group_.as_summary_status(); + } + } + + if (error_rendez != nullptr) { + error_rendez->StartAbort( + errors::Aborted("Stopping remaining executors.")); + error_rendez->Unref(); + } + + if (done != nullptr) { + delete this; + if (!status.ok()) { + VLOG(1) << "ExecutorBarrier finished with bad status: " << status; + } + done(status); + } + } + + ExecutorBarrier(const ExecutorBarrier&) = delete; + void operator=(const ExecutorBarrier&) = delete; +}; + +// A few helpers to facilitate create/delete kernels. + +// Creates a kernel based on "props" on device "device". The kernel can +// access the functions in the "flib". The caller takes ownership of +// returned "*kernel". +absl::Status CreateNonCachedKernel( + Device* device, FunctionLibraryRuntime* flib, + const std::shared_ptr& props, int graph_def_version, + OpKernel** kernel); + +// Deletes "kernel" returned by CreateKernel. +void DeleteNonCachedKernel(OpKernel* kernel); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/executor_factory.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/executor_factory.h new file mode 100644 index 00000000..14a8d277 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/executor_factory.h @@ -0,0 +1,50 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_ + +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Executor; +class Graph; +struct LocalExecutorParams; + +class ExecutorFactory { + public: + virtual absl::Status NewExecutor(const LocalExecutorParams& params, + const Graph& graph, + std::unique_ptr* out_executor) = 0; + virtual ~ExecutorFactory() {} + + static void Register(const string& executor_type, ExecutorFactory* factory); + static absl::Status GetFactory(const string& executor_type, + ExecutorFactory** out_factory); +}; + +absl::Status NewExecutor(const string& executor_type, + const LocalExecutorParams& params, const Graph& graph, + std::unique_ptr* out_executor); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/function.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function.h new file mode 100644 index 00000000..f86732b2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function.h @@ -0,0 +1,81 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" +#include "tensorflow/core/common_runtime/function_utils.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +// Get default customizable kernel creator if set +const CustomKernelCreator* GetDefaultCustomKernelCreator(); + +// Registers a default customizable kernel creator for a function call. +// +// If c->CanCreateKernel returns false, we still fall back to an executor-based +// interpreter op kernel to execute a function. Else c->CreateKernel() can be +// used to create a kernel that will compile the function with XLA and run the +// resulting program. 
+void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c); + +// Creates a FunctionLibraryRuntime, which instantiates functions +// defined in "lib_def" and executes functions on the "device". +// "device_mgr" must contain the "device". +// +// The returned object does not take ownerships of "device" or +// "lib_def". The caller must ensure "device" and "lib_def" outlives +// the returned object. +// +// The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that +// typically owns the created FunctionLibraryRuntime object. The parent pointer +// is not owned by the FunctionLibraryRuntime object. +core::RefCountPtr NewFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, + Device* device, int graph_def_version, + const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool, + const OptimizerOptions& optimizer_options, + const SessionMetadata* session_metadata, + ProcessFunctionLibraryRuntime* parent); + +// Given a numerical function "f", returns another numerical function +// "g", such that if "f" takes N inputs and produces M outputs, "g" +// takes N + M inputs and produces N outputs. I.e., if +// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), +// g is a function which is +// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, +// dL/dy1, dL/dy2, ..., dL/dy_M), +// where L is a scalar-value function of (...x_i...). +// +// TODO(zhifengc): Asks math expert to say the comment again. +std::unique_ptr SymbolicGradient(const FunctionBody& f); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_body.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_body.h new file mode 100644 index 00000000..959f9803 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_body.h @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/refcount.h" + +namespace tensorflow { + +class FunctionRecord; +class Graph; +class Node; + +// FunctionLibraryRuntime::GetFunctionBody returns a description of an +// instantiated function that is represented as a Graph with arg/ret +// nodes annotated. +struct FunctionBody { + core::RefCountPtr record; + Graph* graph = nullptr; // owned. + DataTypeVector arg_types; + DataTypeVector ret_types; + // arg_nodes[i] contains the i'th function input. In other words, + // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i. + absl::InlinedVector arg_nodes; + // ret_nodes[i] contains the i'th function output. 
In other words, + // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i. + absl::InlinedVector ret_nodes; + absl::InlinedVector control_ret_nodes; + + FunctionBody() {} + FunctionBody(core::RefCountPtr&& record, + DataTypeSlice arg_types, DataTypeSlice ret_types, Graph* g); + ~FunctionBody(); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_BODY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_def_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_def_utils.h new file mode 100644 index 00000000..cd3b021e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_def_utils.h @@ -0,0 +1,73 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_ + +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/refcount.h" + +namespace tensorflow { + +class AttrSlice; +struct FunctionBody; +class FunctionDef; +class FunctionLibraryDefinition; +class FunctionRecord; +class OpDef; + +// Instantiates FunctionDef into a graph. Set *fbody to point to the +// FunctionBody that holds the instantiated FunctionDef. +absl::Status FunctionDefToBodyHelper(core::RefCountPtr&& record, + const AttrSlice& attrs, + const FunctionLibraryDefinition* lib_def, + std::unique_ptr* fbody); + +// Instantiates FunctionDef into a graph. Set *fbody to point to the +// FunctionBody that holds the instantiated FunctionDef. +// +// NOTE(mrry): This implementation incurs a copy of `fdef`. If possible, use +// the overload that takes a `core::RefCountPtr`. +absl::Status FunctionDefToBodyHelper(const FunctionDef& fdef, + const AttrSlice& attrs, + const FunctionLibraryDefinition* lib_def, + std::unique_ptr* fbody); + +// Instantiates FunctionDef into a graph. Set *fbody to point to the +// FunctionBody that holds the instantiated FunctionDef. Use custom function +// signature lookup, in case instantiated function is not in the 'lib_def'. +absl::Status FunctionDefToBodyHelper( + core::RefCountPtr&& record, const AttrSlice& attrs, + const FunctionLibraryDefinition* lib_def, + const std::function& + get_func_sig, + std::unique_ptr* fbody); + +// Removes all stateless nodes that do not contribute to a return +// value from the function body. Unlike `RemoveDeadNodes()`, which is +// triggered by `OptimizerOptions.do_function_inlining`, this pass +// ignores the SINK node, from which (by definition) all nodes are +// reverse reachable, and preserves all nodes that are reachable from +// control output nodes. 
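// --- Editor's illustration (not part of this patch or of the TensorFlow headers) ---
// The pruning pass described above reduces to a reverse-reachability walk: keep every
// node from which some return (or control-output) node can be reached, drop the rest.
// A standalone sketch on a toy edge-list graph; ReachingNodes is an invented name and
// this is not the TensorFlow implementation.
#include <iostream>
#include <set>
#include <utility>
#include <vector>

// Returns the set of nodes that reach at least one root, walking edges backwards
// (edges are pairs src -> dst).
std::set<int> ReachingNodes(int num_nodes,
                            const std::vector<std::pair<int, int>>& edges,
                            const std::vector<int>& roots) {
  std::vector<std::vector<int>> rev(num_nodes);
  for (const auto& e : edges) rev[e.second].push_back(e.first);
  std::set<int> keep(roots.begin(), roots.end());
  std::vector<int> stack(roots.begin(), roots.end());
  while (!stack.empty()) {
    int n = stack.back();
    stack.pop_back();
    for (int pred : rev[n]) {
      if (keep.insert(pred).second) stack.push_back(pred);
    }
  }
  return keep;
}

int main() {
  // 0 -> 1 -> 3 (a return value), 2 -> 4 (a dead branch).
  auto keep = ReachingNodes(5, {{0, 1}, {1, 3}, {2, 4}}, /*roots=*/{3});
  for (int n = 0; n < 5; ++n)
    std::cout << "node " << n << (keep.count(n) ? " kept" : " pruned") << "\n";
}
// --- end of editor's illustration ---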
+void PruneFunctionBody(const FunctionDef& fdef, Graph* g, + absl::Span additional_root_nodes = {}); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_DEF_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_optimization_registry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_optimization_registry.h new file mode 100644 index 00000000..ba501d3e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_optimization_registry.h @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" + +// Classes to maintain a static registry of Graph based passes to be applied to +// a function graph. + +namespace tensorflow { + +// A pass to be registered with the FunctionOptimizationPassRegistry. This pass +// takes in a DeviceSet (available devices for executing the Graph), ConfigProto +// (session configuration parameters), an optional target device for XLA +// compilation, Graph (computation), +// FunctionLibraryDefinition (mapping between function names and function +// definitions of the Graph), control ret/target node names (names of nodes that +// must execute but their data outputs, if they have any, are irrelevant), and +// whether control ret nodes (via thier name) were updated. Mutations to the +// Graph and other associated arguments are performed inplace by the pass. +class FunctionOptimizationPass { + public: + // Grouped Options for the optimized function. + struct FunctionOptions { + // Specifies the compilation device type(CPU, GPU, etc) + // that should be used for entire function. + std::string xla_compile_device_type = ""; + // Whether soft placement and outside compilation + // are enabled for the function. + bool allow_soft_placement = false; + }; + + virtual ~FunctionOptimizationPass() {} + virtual absl::Status Run(const std::string& function_name, + const DeviceSet& device_set, + const ConfigProto& config_proto, + const FunctionOptions& function_options, + std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) = 0; +}; + +// A global function optimization pass registry that is used to hold one +// FunctionOptimizationPass. Passes registered to this registry will run before +// passes registered in OptimizationPassRegistry. +class FunctionOptimizationPassRegistry { + public: + // Initializes registry with a pass. Only one pass should be set. 
An assertion + // will be triggered if the registry already has a pass set and is being + // initialized with another pass. + void Init(std::unique_ptr pass); + + // Runs a pass if the registry contains one. + absl::Status Run( + const std::string& function_name, const DeviceSet& device_set, + const ConfigProto& config_proto, + const FunctionOptimizationPass::FunctionOptions& function_options, + std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated); + + // Returns the global registry of function graph passes. + static FunctionOptimizationPassRegistry& Global(); + + private: + std::unique_ptr pass_; +}; + +namespace function_optimization_registration { + +class FunctionOptimizationPassRegistration { + public: + explicit FunctionOptimizationPassRegistration( + std::unique_ptr pass) { + FunctionOptimizationPassRegistry::Global().Init(std::move(pass)); + } +}; + +} // namespace function_optimization_registration + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_OPTIMIZATION_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_testlib.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_testlib.h new file mode 100644 index 00000000..9618c408 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_testlib.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { +namespace test { +namespace function { + +// {} -> y:DT_STRING (device where this op runs). +FunctionDef FindDevice(); +FunctionDef FindDeviceWithUuid(); + +class BlockingOpState { + public: + void AwaitState(int awaiting_state); + + void MoveToState(int expected_current, int next); + + private: + mutex mu_; + condition_variable cv_; + int state_ = 0; +}; + +extern BlockingOpState* blocking_op_state; + +FunctionDef BlockingOpFn(); + +// Adds a function call to the given scope and returns the output for the node. +// TODO(phawkins): replace with C++ API for calling functions, when that exists. +Output Call(Scope* scope, const string& op_name, const string& fn_name, + absl::Span inputs); + +} // namespace function +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_utils.h new file mode 100644 index 00000000..cfbfe869 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/function_utils.h @@ -0,0 +1,105 @@ +/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ + +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class AttrSlice; +class Graph; +class GraphDef; +class NameAttrList; +class Node; +class NodeDef; +class OpDef; + +// Debugging facility. Returns a debug string for a graph +// representing an instantiated function. +string DebugString(const Graph* g); + +// Dump the contents of the "graph" to log files if the logging level is +// sufficiently high. +void DumpGraph(absl::string_view label, const Graph* g); + +// Convert the Graph of a function to a GraphDef. +// +// Handles renaming of nodes to avoid duplicate names which may +// be present after various rewriting operations. +void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); + +// Extracts function name and attributes from `call_def` +// `call_def` can be a native function call (where the op type is the function +// name) or a call through PartitionedCall/StatefulPartitionedCall. +absl::Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, + NameAttrList* function); + +// A few hand-crafted optimization on the instantiated function body +// (a Graph*). + +// Removes nodes that are +// 1. not stateful; and +// 2. not _Arg; and +// 3. not reachable from _Retval. +// +// This function is triggered by function inlining, unlike 'PruneFunctionBody' +// it doesn't preserve nodes that are reachable from control returns. Function +// inlining is responsible for connecting control return nodes with the nodes +// that have input control edges from the inlined function call node. +// +// Assuming that automatic control dependency tracking is correct, absence of +// outgoing control edge from the function call node means that no one needs to +// observe side-effect that might have been generated by the function (see +// documentation in common_runtime/function.cc for details). +// +// Returns true iff any node is removed from "g". +bool RemoveDeadNodes(Graph* g); + +// Find a pattern: +// src -(in)-> node -(out)-> dst, where +// 1) node is an identity node; +// 2) in is the only incoming data edge; +// 3) out is the only outgoing data edge; +// +// Rewrites the above pattern with src->dst and relevant data +// dependencies updated. Repeat the process until no such pattern +// left. +bool RemoveIdentityNodes(Graph* g); + +// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes. +bool RemoveListArrayConverter(Graph* g); + +// Extracts function name and attributes from `call_def` and invokes +// flr->Instantiate(name, attrs, handle). +// `call_def` can be a native function call (where the op type is the function +// name) or a call through PartitionedCall/StatefulPartitionedCall. 
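// --- Editor's illustration (not part of this patch or of the TensorFlow headers) ---
// The RemoveIdentityNodes rewrite described above -- bypass an identity node that has
// exactly one data input and one data output, then repeat until no such pattern
// remains -- can be sketched on a toy edge list. ElideIdentities is an invented name;
// this is a simplification, not the library routine.
#include <iostream>
#include <set>
#include <utility>
#include <vector>

using Edge = std::pair<int, int>;  // src -> dst

// Repeatedly rewrites src -> id -> dst into src -> dst for every node in
// `identity_nodes` that has exactly one incoming and one outgoing edge.
std::vector<Edge> ElideIdentities(std::vector<Edge> edges,
                                  const std::set<int>& identity_nodes) {
  bool changed = true;
  while (changed) {
    changed = false;
    for (int id : identity_nodes) {
      int in = -1, out = -1, in_count = 0, out_count = 0;
      for (const Edge& e : edges) {
        if (e.second == id) { in = e.first; ++in_count; }
        if (e.first == id) { out = e.second; ++out_count; }
      }
      if (in_count == 1 && out_count == 1) {
        std::vector<Edge> next;
        for (const Edge& e : edges)
          if (e.first != id && e.second != id) next.push_back(e);
        next.emplace_back(in, out);
        edges = std::move(next);
        changed = true;
      }
    }
  }
  return edges;
}

int main() {
  // 0 -> 1(identity) -> 2 becomes 0 -> 2.
  for (const Edge& e : ElideIdentities({{0, 1}, {1, 2}}, {1}))
    std::cout << e.first << " -> " << e.second << "\n";
}
// --- end of editor's illustration ---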
+absl::Status InstantiateFunctionCall(const NodeDef& call_def, + FunctionLibraryRuntime* flr, + FunctionLibraryRuntime::Handle* handle); + +// Returns true iff `n` represents a function call. `n` can be a native +// function call (n.type_string() is the function name), +// a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which +// has been deprecated for a while). +bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n); +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h new file mode 100644 index 00000000..9d025cc0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ + +#include +#include +#include + +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/framework/bfc_allocator.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { + +// A GPU memory allocator that implements a 'best-fit with coalescing' +// algorithm. +class GPUBFCAllocator : public tsl::BFCAllocator { + public: + // See BFCAllocator::Options. + struct Options { + // Overridden by TF_FORCE_GPU_ALLOW_GROWTH if that envvar is set. + bool allow_growth = false; + + // If nullopt, defaults to TF_ENABLE_GPU_GARBAGE_COLLECTION, or true if that + // envvar is not present. + // + // Note: + // + // - BFCAllocator defaults garbage_collection to false, not true. + // - this is not the same override behavior as TF_FORCE_GPU_ALLOW_GROWTH. + std::optional garbage_collection; + + double fragmentation_fraction = 0; + bool allow_retry_on_failure = true; + }; + + GPUBFCAllocator(std::unique_ptr sub_allocator, + size_t total_memory, const std::string& name, + const Options& opts); + + ~GPUBFCAllocator() override {} + + GPUBFCAllocator(const GPUBFCAllocator&) = delete; + void operator=(const GPUBFCAllocator&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h new file mode 100644 index 00000000..ba08f096 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_ + +#include +#include + +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/framework/device_id.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { + +// An allocator which directly uses cuMemAlloc and cuMemFree to allocate and +// free memory. +class GPUcudaMallocAllocator : public tsl::Allocator { + public: + explicit GPUcudaMallocAllocator(tsl::PlatformDeviceId platform_device_id); + std::string Name() override { return "gpu_debug"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + bool TracksAllocationSizes() const override; + + tsl::AllocatorMemoryType GetMemoryType() const override { + return tsl::AllocatorMemoryType::kDevice; + } + + private: + se::StreamExecutor* stream_exec_; // Not owned. + + GPUcudaMallocAllocator(const GPUcudaMallocAllocator&) = delete; + void operator=(const GPUcudaMallocAllocator&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h new file mode 100644 index 00000000..13f10007 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h @@ -0,0 +1,93 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/framework/device_id.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { + +// An allocator that wraps a GPU allocator and adds debugging +// functionality that verifies that users do not write outside their +// allocated memory. 
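// --- Editor's illustration (not part of this patch or of the TensorFlow headers) ---
// The debug allocator described above brackets every allocation with known guard words
// and verifies them on deallocation to catch out-of-bounds writes. A standalone
// host-memory sketch of that technique over malloc/free; GuardedAlloc/GuardedFree are
// invented names, and unlike the real class the caller passes the size back on free.
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>

constexpr std::uint64_t kGuard = 0xDEADBEEFCAFEF00DULL;
constexpr std::size_t kGuardBytes = sizeof(kGuard);

void* GuardedAlloc(std::size_t n) {
  auto* base = static_cast<unsigned char*>(std::malloc(n + 2 * kGuardBytes));
  std::memcpy(base, &kGuard, kGuardBytes);                    // header guard
  std::memcpy(base + kGuardBytes + n, &kGuard, kGuardBytes);  // footer guard
  return base + kGuardBytes;  // the caller only sees the middle region
}

// Returns false if either guard word was overwritten.
bool GuardedFree(void* ptr, std::size_t n) {
  auto* base = static_cast<unsigned char*>(ptr) - kGuardBytes;
  std::uint64_t header, footer;
  std::memcpy(&header, base, kGuardBytes);
  std::memcpy(&footer, base + kGuardBytes + n, kGuardBytes);
  std::free(base);
  return header == kGuard && footer == kGuard;
}

int main() {
  char* p = static_cast<char*>(GuardedAlloc(8));
  p[8] = 'X';  // one byte past the end: corrupts the footer guard
  std::cout << (GuardedFree(p, 8) ? "clean" : "out-of-bounds write detected")
            << "\n";
}
// --- end of editor's illustration ---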
+class GPUDebugAllocator : public tsl::Allocator { + public: + explicit GPUDebugAllocator(tsl::Allocator* allocator, + tsl::PlatformDeviceId platform_device_id); + ~GPUDebugAllocator() override; + std::string Name() override { return "gpu_debug"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + bool TracksAllocationSizes() const override; + size_t RequestedSize(const void* ptr) const override; + size_t AllocatedSize(const void* ptr) const override; + int64_t AllocationId(const void* ptr) const override; + std::optional GetStats() override; + bool ClearStats() override; + + // For testing. + bool CheckHeader(void* ptr); + bool CheckFooter(void* ptr); + + private: + tsl::Allocator* base_allocator_ = nullptr; // owned + + se::StreamExecutor* stream_exec_; // Not owned. + + GPUDebugAllocator(const GPUDebugAllocator&) = delete; + void operator=(const GPUDebugAllocator&) = delete; +}; + +// An allocator that wraps a GPU allocator and resets the memory on +// allocation and free to 'NaN', helping to identify cases where the +// user forgets to initialize the memory. +class GPUNanResetAllocator : public tsl::Allocator { + public: + explicit GPUNanResetAllocator(tsl::Allocator* allocator, + tsl::PlatformDeviceId platform_device_id); + ~GPUNanResetAllocator() override; + std::string Name() override { return "gpu_nan_reset"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + size_t RequestedSize(const void* ptr) const override; + size_t AllocatedSize(const void* ptr) const override; + std::optional GetStats() override; + bool ClearStats() override; + + tsl::AllocatorMemoryType GetMemoryType() const override { + return base_allocator_->GetMemoryType(); + } + + private: + tsl::Allocator* base_allocator_ = nullptr; // owned + + se::StreamExecutor* stream_exec_; // Not owned. + + GPUNanResetAllocator(const GPUNanResetAllocator&) = delete; + void operator=(const GPUNanResetAllocator&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_device.h new file mode 100644 index 00000000..d09cdc2f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_device.h @@ -0,0 +1,477 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support +#endif + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_H_ + +// TODO(b/282059652): Merge google internal and open-source code path once TF +// dependency issue is resolved. 
+#if (defined(PLATFORM_GOOGLE) && defined(TF_PLATFORM_LINUX_X86_64)) +#define TF_GPU_USE_PJRT +#endif // PLATFORM_GOOGLE && TF_PLATFORM_LINUX_X86_64 + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#ifdef TF_GPU_USE_PJRT +#include "tensorflow/compiler/jit/pjrt_device_context.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "xla/pjrt/local_device_state.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#endif // TF_GPU_USE_PJRT +#include "xla/tsl/framework/device_id.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/common_runtime/node_file_writer.h" +#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" +#include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace Eigen { +class StreamInterface; +} + +namespace tensorflow { +class GPUKernelTracker; + +class ConcretePerOpGpuDevice : public PerOpGpuDevice { + public: + ConcretePerOpGpuDevice(); + + void Reinitialize(OpKernelContext* context, void* gpu_stream, + tsl::TfDeviceId tf_device_id, Allocator* base_allocator, + char* scratch); + + void Reinitialize(OpKernelContext* context, void* gpu_stream, + tsl::PlatformDeviceId platform_device_id, + Allocator* base_allocator, char* scratch); + + const Eigen::GpuDevice& device() const override; + + private: + std::unique_ptr<::Eigen::StreamInterface> stream_device_; +}; + +class BaseGPUDevice : public LocalDevice { + public: + BaseGPUDevice(const SessionOptions& options, const std::string& name, + Bytes memory_limit, const DeviceLocality& locality, + tsl::TfDeviceId tf_device_id, + const std::string& physical_device_desc, + Allocator* gpu_allocator, Allocator* cpu_allocator, + bool sync_every_op); + + ~BaseGPUDevice() override; + + struct StreamGroup { + se::Stream* compute = nullptr; +#if TENSORFLOW_USE_ROCM + se::Stream* nccl = nullptr; +#endif + se::Stream* host_to_device = nullptr; + se::Stream* device_to_host = nullptr; + gtl::InlinedVector device_to_device; + int priority = 0; + }; + + // Initialize the device and return the status of initialization. 
+#ifdef TF_GPU_USE_PJRT + Status Init(const SessionOptions& options, + xla::LocalDeviceState* xla_local_device_state); +#else + Status Init(const SessionOptions& options); +#endif // TF_GPU_USE_PJRT + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + + Status Sync() override; + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + Status MakeTensorFromProto(const TensorProto& tensor_proto, + AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, + const DeviceContext* device_context, + StatusCallback done) override; + + // The caller owns the returned device. + PerOpGpuDevice* MakeGpuDevice() override; + + Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) override; + + // Returns the platform GPU id of this device within the native driver system; + // e.g., for CUDA and ROCm this is the ordinal of the GPU within the system. + int gpu_id() const { + tsl::PlatformDeviceId platform_device_id; + TF_CHECK_OK( + GpuIdManager::TfToPlatformDeviceId(tf_device_id_, &platform_device_id)); + return platform_device_id.value(); + } + + // The executor that provides control for the device; e.g., for CUDA this + // corresponds to the cuda context. + se::StreamExecutor* executor() const { return executor_; } + + Allocator* GetScopedAllocator(AllocatorAttributes attr, + int64_t step_id) override; + + ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { + return scoped_allocator_mgr_.get(); + } + + // The following two functions always return 0 unless one of the + // related experimental config options has been specified. + + // If returned value is > 0 then GPU Memory chunks freed before this count + // are guaranteed not to be in use by any kernel pending on this device. + uint64 SafeAllocFrontier(uint64 old_value) override; + + // Returns the number of kernels that have been queued for execution on + // the compute stream and are not yet known to have completed. + int PendingKernels(); + + int priority() const { return stream_->priority; } + + // Helper method for unit tests to reset the streams. Never use in production. + static void TestOnlyReset(); + + se::Stream* compute_stream() { return stream_->compute; } + + // Given the compute stream for a GPU or virtual GPU, return the TfDeviceId + // for the GPU or vGPU. 
+ static std::optional FindTfDeviceId(se::Stream* compute); + + bool merge_host_to_device_stream() const override { + return stream_merge_options_.merge_host_to_device_stream(); + } + + bool merge_device_to_host_stream() const override { + return stream_merge_options_.merge_device_to_host_stream(); + } + + bool merge_device_to_device_stream() const override { + return stream_merge_options_.merge_device_to_device_stream(); + } + + protected: + Allocator* gpu_allocator_; // not owned + Allocator* cpu_allocator_; // not owned + + se::StreamExecutor* executor_; // not owned + std::unique_ptr scoped_allocator_mgr_; + + private: + friend class GPUDeviceTestHelper; + class StreamGroupFactory; + + core::RefCountPtr pjrt_device_context_; + StreamGroup* stream_; + mutex scratch_init_mutex_; + char* scratch_ = nullptr; + GPUDeviceContext* device_context_; + DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_ = nullptr; + mutex trace_mu_; + tsl::TfDeviceId tf_device_id_; + const bool sync_every_op_ = false; + EventMgr* em_ = nullptr; + std::unique_ptr thread_pool_; + std::unique_ptr kernel_tracker_; + int32 pending_cap_ = 0; + bool timestamped_allocator_ = false; + NodeFileWriter* node_file_writer_ = nullptr; // not owned + const GPUOptions::Experimental::StreamMergeOptions stream_merge_options_; + + // Initialize scratch buffers used by Eigen. + Status InitScratchBuffers(); + + void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device, + int stream_id, Allocator* allocator); + + std::string ComputeOpKernelDebugString(const OpKernel& op_kernel, + const int& stream_id); + + // This method returns an initialization status, in addition to + // calling the "done" StatusCallback, if there is a failure to + // allocate memory or if the tensor "from" is not DMA-copyable. + // If there is no error prior to enqueueing the copy, an OK status + // is returned. + Status MaybeCopyTensorToGPU(const AllocatorAttributes& alloc_attrs, + const Tensor& from, Tensor* to, + StatusCallback done); + + Tensor CopyGpuTensorToHostDebugOnly(const Tensor& gpu_tensor); + void LogInputs(OpKernel* op_kernel, OpKernelContext* context); + void LogOutputs(OpKernel* op_kernel, OpKernelContext* context); +}; + +// A per-compute-stream utility that keeps track of kernels that have been +// queued for execution but may not yet have terminated and also the queued +// time of the most recently terminated kernel. +class GPUKernelTracker { + public: + // Controls the strategy for inserting tracking events after GPU kernels. + // If max_interval >= 0, then insert an event after this many kernels + // if an event has not been inserted for another reason. + // If max_bytes > 0, then insert an event after kernels allocating this + // many bytes have been queued since the last event. + // If max_pending > 0, then track up to this many events at once. If + // this limit is reached the GPU::Compute() method will delay starting + // additional ops until some event completes. If 0 and one of the other + // fields is non-zero, then a reasonable default will be selected. + struct Params { + int max_interval = 0; + int max_bytes = 0; + int max_pending = 0; + Params(int mi, int mb, int mp) + : max_interval(mi), max_bytes(mb), max_pending(mp) {} + }; + + // If we're going to share a SharedCounter with an allocator, it's owned + // by the allocator because allocators are initialized once per process. + // Devices are per-session. 
+ explicit GPUKernelTracker(const Params& params, Env* env, + se::Stream* compute_stream, + SharedCounter* timing_counter, Allocator* allocator, + EventMgr* event_manager) + : params_(params), + env_(env), + stream_(compute_stream), + timing_counter_(timing_counter), + allocator_(allocator), + em_(event_manager), + pending_kernels_( + params.max_pending > 0 ? std::max(8, 2 * params.max_pending) : 64) { + mem_since_last_ = 0; + if (!timing_counter_) { + // There's not a preexisting counter owned by GPUProcessState, i.e. + // pending_cap > 0 but timestamped_allocator == false. + owned_counter_ = std::make_unique(); + timing_counter_ = owned_counter_.get(); + } + } + + // Determine whether a GPU kernel should have a recording event queued + // immediately afterwards. If so, advance the counter and return the new + // counter value after enqueuing. + uint64 MaybeQueue(OpKernelContext* ctx); + + // Record that a GPU kernel has just been enqueued on the compute stream. + // Inserts the supplied counter value in a new PendingKernel record appended + // to the end of the ring buffer then returns that same count. + // Caller is responsible for ensuring that RecordTerminate() is eventually + // called with the same counter value. + void RecordQueued(uint64 queued_count, int weight) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Takes a count value returned by RecordQueued and finds the corresponding + // PendingKernel record in the ring buffer. Marks the kernel as completed and + // advances the completion frontier accordingly. + void RecordTerminated(uint64 queued_count); + + // Returns the largest timing count such that all kernels queued no + // later than that count are known to have terminated. + inline uint64 LastTerminatedCount(uint64 old_value) { + uint64 new_value = last_terminated_count_.load(std::memory_order_relaxed); + if (new_value == old_value) { + MaybeQueueProgressEvent(); + } + return new_value; + } + + // Returns the number of kernels enqueued that are not yet known to + // have terminated. + int NumPending() { + mutex_lock l(mu_); + return num_pending_; + } + + // Yield current thread until number of pending kernels no longer + // exceeds the cap. + void PauseWhilePendingExceeds(int cap) TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + while (num_pending_ > cap) { + VLOG(1) << "num_pending_=" << num_pending_ << " cap=" << cap; + pending_decreased_.wait(l); + } + } + + private: + friend class GPUKernelTrackerTest; + Params params_; + Env* env_; + se::Stream* stream_; + SharedCounter* timing_counter_; + std::unique_ptr owned_counter_; + Allocator* allocator_ = nullptr; + EventMgr* em_ = nullptr; + std::atomic last_terminated_count_ = {1}; + + void MaybeQueueProgressEvent(); + + // Records when a kernel was queued for execution. Kernel launches are + // identified by a unique count value from a per-GPU device timing counter. + struct PendingKernel { + uint64 queued_count; + int weight; + bool terminated; + PendingKernel(const PendingKernel& pk) = default; + PendingKernel() : queued_count(0), weight(0), terminated(false) {} + }; + mutex mu_; + int32 mem_since_last_ TF_GUARDED_BY(mu_); + int32 ops_since_last_ TF_GUARDED_BY(mu_); + // Ring buffer of PendingKernel records. + std::vector pending_kernels_ TF_GUARDED_BY(mu_); + // Next unused slot in pending_kernels_. + int first_available_ TF_GUARDED_BY(mu_) = 0; + // Last completed PendingKernel such that all prior PendingKernels are + // also completed. 
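// --- Editor's illustration (not part of this patch or of the TensorFlow headers) ---
// The kernel tracker above keeps queued launches in an ordered buffer and advances a
// "completion frontier": the largest count such that everything queued at or before it
// has terminated, even when completions arrive out of order. A standalone sketch of
// that bookkeeping under the invented name CompletionFrontier:
#include <cstdint>
#include <deque>
#include <iostream>

class CompletionFrontier {
 public:
  // Record a launch identified by a strictly increasing count.
  void RecordQueued(std::uint64_t count) { pending_.push_back({count, false}); }

  // Mark one launch as finished, then advance the frontier past any prefix of
  // finished launches.
  void RecordTerminated(std::uint64_t count) {
    for (auto& e : pending_)
      if (e.count == count) e.done = true;
    while (!pending_.empty() && pending_.front().done) {
      frontier_ = pending_.front().count;
      pending_.pop_front();
    }
  }

  std::uint64_t LastTerminatedCount() const { return frontier_; }

 private:
  struct Entry { std::uint64_t count; bool done; };
  std::deque<Entry> pending_;
  std::uint64_t frontier_ = 0;
};

int main() {
  CompletionFrontier t;
  for (std::uint64_t c : {1, 2, 3}) t.RecordQueued(c);
  t.RecordTerminated(2);  // out of order: frontier cannot move yet
  std::cout << t.LastTerminatedCount() << "\n";  // 0
  t.RecordTerminated(1);  // now 1 and 2 are both done
  std::cout << t.LastTerminatedCount() << "\n";  // 2
}
// --- end of editor's illustration ---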
With out-of-order completion there may be a mixture + // of completed and uncompleted entries between last_completed_ and + // first_available_. + int last_completed_ TF_GUARDED_BY(mu_) = -1; + // Sum of weights of the outstanding events marking tracked kernels. + int num_pending_ TF_GUARDED_BY(mu_) = 0; + condition_variable pending_decreased_ TF_GUARDED_BY(mu_); +}; + +class BaseGPUDeviceFactory : public DeviceFactory { + public: + Status ListPhysicalDevices(std::vector* devices) override; + Status CreateDevices(const SessionOptions& options, + const std::string& name_prefix, + std::vector>* devices) override; + Status GetDeviceDetails(int device_index, + std::unordered_map* details) override; + + struct InterconnectMap { + // Name of interconnect technology, if known. + std::string name; + // If possible, strength should approximate Gb/sec bandwidth rate. + // Where architecture-specific subclassing is not done that won't + // always be possible. The minimum expectation is that + // faster links should have a higher value than slower links. + int32 strength; + static const int kSameDeviceStrength; + static const int kStreamExecutorStrength; + std::set> + directed_links; + }; + + protected: + // Populates *maps with interconnect maps for all local direct access + // pathways between GPUs. + virtual Status GetInterconnectMaps( + const std::vector& visible_gpu_order, + se::Platform* gpu_manager, std::vector* maps); + + struct TfDeviceIdHash { + std::size_t operator()(const tsl::TfDeviceId& id) const noexcept { + return std::hash{}(id.value()); + } + }; + typedef std::unordered_map + LocalityMap; + // Populates *localities with the DeviceLocality descriptor for + // every TfDeviceId. + virtual Status GetDeviceLocalities( + int num_tf_gpus, const std::vector& interconnects, + LocalityMap* localities); + + private: + // Creates a BaseGPUDevice associated with 'tf_device_id', and adds it to the + // 'devices' vector. The 'gpu_allocator' is created by the caller and usually + // preallocates a set amount of GPU memory. +#ifdef TF_GPU_USE_PJRT + Status CreateGPUDevice(const SessionOptions& options, + const std::string& name_prefix, + tsl::TfDeviceId tf_device_id, + const DeviceLocality& dev_locality, + xla::LocalDeviceState* xla_local_device_state, + Allocator* gpu_allocator, + std::vector>* devices); +#else + Status CreateGPUDevice(const SessionOptions& options, + const std::string& name_prefix, + tsl::TfDeviceId tf_device_id, + const DeviceLocality& dev_locality, + Allocator* gpu_allocator, + std::vector>* devices); +#endif // TF_GPU_USE_PJRT + + virtual std::unique_ptr CreateGPUDevice( + const SessionOptions& options, const string& name, Bytes memory_limit, + const DeviceLocality& dev_locality, tsl::TfDeviceId tf_device_id, + const string& physical_device_desc, Allocator* gpu_allocator, + Allocator* cpu_allocator) = 0; + + Status EnablePeerAccess( + const std::vector& visible_gpu_order); + + // Returns into 'ids' the list of valid platform GPU ids, in the order that + // they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc, + // based upon 'visible_gpu_order' which was generated by parsing + // GPUOptions::visible_device_list which is a comma-separated list of CUDA or + // ROCm GPU ids. + Status GetValidDeviceIds( + const std::vector& visible_gpu_order, + std::vector* ids); + + // Cache the valid device IDs if not already cached. Cached IDs are stored in + // field cached_device_ids_. 
Passes {0, 1, ..., num_devices-1} to + // GetValidDeviceIds, so this should only be used in functions where all + // devices should be treated as visible, like ListPhysicalDevices. + Status CacheDeviceIds(); + + // visible_gpu_initialized_[platform_device_id] is true if visible GPU + // platform_device_id has been initialized by the process. + std::unordered_map visible_gpu_initialized_; + + // Cached device IDs, as returned by GetValidDeviceIds when every physical + // device is visible. Cache should not be used if some devices are not + // visible. + std::vector cached_device_ids_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h new file mode 100644 index 00000000..601119fb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h @@ -0,0 +1,23 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// TODO(annarev): remove this file once all includes are updated to +// include device_event_mgr.h instead. + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_ + +#include "tensorflow/core/common_runtime/device/device_event_mgr.h" + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_id.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_id.h new file mode 100644 index 00000000..c2849d2d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_id.h @@ -0,0 +1,22 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ + +#include "tensorflow/core/common_runtime/device/device_id.h" + +// TODO(sanjoy): Delete the header and forward the references. 
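// --- Editor's illustration (not part of this patch or of the TensorFlow headers) ---
// GetValidDeviceIds above parses a comma-separated visible_device_list of platform
// ordinals, and the resulting order determines which physical GPU backs
// "/device:GPU:0", "/device:GPU:1", and so on. A standalone sketch of just that
// parsing/mapping step; ParseVisibleDeviceList is an invented name.
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Parses e.g. "2,0,1" into {2, 0, 1}; index i of the result is the platform ordinal
// backing TF device "/device:GPU:i".
std::vector<int> ParseVisibleDeviceList(const std::string& list) {
  std::vector<int> order;
  std::stringstream ss(list);
  std::string item;
  while (std::getline(ss, item, ',')) order.push_back(std::stoi(item));
  return order;
}

int main() {
  auto order = ParseVisibleDeviceList("2,0,1");
  for (std::size_t tf_id = 0; tf_id < order.size(); ++tf_id)
    std::cout << "/device:GPU:" << tf_id << " -> platform GPU " << order[tf_id]
              << "\n";
}
// --- end of editor's illustration ---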
+ +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_id_manager.h new file mode 100644 index 00000000..aa8553f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_id_manager.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_ + +#include "xla/tsl/framework/device_id.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Class that maintains a map from TfDeviceId to PlatformDeviceId, and manages +// the translation between them. +class GpuIdManager { + public: + // Adds a mapping from tf_device_id to platform_device_id. + static absl::Status InsertTfPlatformDeviceIdPair( + tsl::TfDeviceId tf_device_id, tsl::PlatformDeviceId platform_device_id); + + // Gets the platform_device_id associated with tf_device_id. Returns OK if + // found. + static absl::Status TfToPlatformDeviceId( + tsl::TfDeviceId tf_device_id, tsl::PlatformDeviceId* platform_device_id); + + // Clears the map. Used in unit tests only. + static void TestOnlyReset(); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h new file mode 100644 index 00000000..78c57ca2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ + +#include + +#include "xla/tsl/framework/allocator.h" + +namespace tensorflow { + +// An allocator for CUDA unified memory. Memory allocated with this allocator +// can be accessed from both host and device. CUDA transparently migrates dirty +// pages, which can be slow. 
Therefore, this allocator is intended for +// convenience in functional tests only. +class GpuManagedAllocator : public tsl::Allocator { + public: + std::string Name() override { return "GpuManagedAllocator"; } + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_MANAGED_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_process_state.h new file mode 100644 index 00000000..19f8448e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_process_state.h @@ -0,0 +1,187 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_PROCESS_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_PROCESS_STATE_H_ + +// TODO(b/282059652): Merge google internal and open-source code path once TF +// dependency issue is resolved. +#if (defined(PLATFORM_GOOGLE) && defined(TF_PLATFORM_LINUX_X86_64)) +#define TF_GPU_USE_PJRT +#endif // PLATFORM_GOOGLE && TF_PLATFORM_LINUX_X86_64 + +#include +#include +#include +#include +#include + +#include "xla/tsl/framework/device_id.h" +#include "tensorflow/core/common_runtime/process_state.h" +#include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class GPUBFCAllocator; +class PoolAllocator; + +// Singleton that manages per-process state when GPUs are present. +class GPUProcessState { + public: + // If ps == nullptr, returns pointer to the single instance of this class to + // be used within this process. + // + // If ps != nullptrs, accepts a value to be returned by all subsequent calls. + // A non-null ps may ONLY be provided during program static storage + // initialization. Must not be called more than once with a non-null ps. + // + // If a derived class of GPUProcessState is ever used in a process, it must + // always be used in place of this class. In order to ensure that existing + // calls to GPUProcessState::singleton() all resolve to the derived instance + // instead, this function must be called once during startup, supplying the + // derived instance value, prior to any accessor call to this function. + static GPUProcessState* singleton(GPUProcessState* ps = nullptr); + + // Query whether any GPU device has been created so far. + // Disable thread safety analysis since a race is benign here. 
+ bool HasGPUDevice() const TF_NO_THREAD_SAFETY_ANALYSIS { + return gpu_device_enabled_; + } + + // Set the flag to indicate a GPU device has been created. + // Disable thread safety analysis since a race is benign here. + void EnableGPUDevice() TF_NO_THREAD_SAFETY_ANALYSIS { + gpu_device_enabled_ = true; + } + + // Returns the one GPU allocator used for the indexed GPU. + // Note that this is a system GPU index, not (necessarily) a brain + // device index. + // + // 'total_bytes' is the total number of bytes that should be made + // available to the allocator. The first call to this function for + // a given tf_device_id creates the allocator, so only the total_bytes + // used on that first call is used. + // + // "Allocator type" describes the type of algorithm to use for the + // underlying allocator. REQUIRES: Must be a valid type (see + // config.proto for the list of supported strings.). + // + // `options` is read on the very first call to this function in the process. + // After that if you pass in a set of options, they will be ignored. + // + // REQUIRES: tf_device_id must be a valid id for a BaseGPUDevice available in + // the current system environment. Otherwise returns nullptr. + virtual Allocator* GetGPUAllocator( + const GPUOptions& options, tsl::TfDeviceId tf_device_id, + size_t total_bytes, const std::vector& peer_gpu_ids); + + Allocator* GetGPUAllocator(tsl::TfDeviceId tf_device_id) { + return GetGPUAllocator(/*options=*/{}, tf_device_id, /*total_bytes=*/0, + /*peer_gpu_ids=*/{}); + } + + int NumGPUAllocators() { + mutex_lock l(mu_); + return gpu_allocators_.size(); + } + + // `options` is read on the very first call to this function in the process, + // e.g. to set the memory limit on this allocator. After that if you pass in + // a different set of options, they will be ignored. + virtual Allocator* GetGpuHostAllocator(const GPUOptions& options, + int numa_node); + + // Registers a Visitor to be invoked on new chunks of memory allocated by the + // SubAllocator of every GPU proximate to the specified bus. The AllocVisitor + // is provided with a memory pointer, a GPU id, and the size of the area it + // identifies. The pointer is not guaranteed to be valid after the call + // terminates. The intention is for this interface to be used for network + // device memory registration. "bus_id" is platform-specific. On many + // platforms it should be 0. On machines with multiple PCIe buses, it should + // be the index of one of the PCIe buses (maybe the NUMA node at which the + // PCIe is rooted). If the bus_id is invalid, results are undefined. + virtual void AddGPUAllocVisitor(int bus_id, + const SubAllocator::Visitor& visitor); + + // Registers a Visitor to be invoked on new chunks of memory allocated by + // the SubAllocator of the GpuHostAllocator for the given numa_node. + virtual void AddGpuHostAllocVisitor(int numa_node, + const SubAllocator::Visitor& visitor); + + // Registers a Visitor to be invoked on each chunk handed back for freeing to + // the SubAllocator of the GpuHostAllocator for the given numa_node. + virtual void AddGpuHostFreeVisitor(int numa_node, + const SubAllocator::Visitor& visitor); + + // Returns bus_id for the given GPU id. + virtual int BusIdForGPU(tsl::TfDeviceId tf_device_id); + + SharedCounter* GPUAllocatorCounter(tsl::TfDeviceId tf_device_id); + + protected: + // GPUProcessState is a singleton that should not normally be deleted except + // at process shutdown. 
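// --- Editor's illustration (not part of this patch or of the TensorFlow headers) ---
// GetGPUAllocator above is documented as "create on first use, ignore the options and
// byte limit afterwards". A standalone sketch of that lazy, per-device singleton
// pattern; ProcessState and FakeAllocator are invented names and this omits the real
// class's suballocators, visitors, and counters.
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>

struct FakeAllocator {
  explicit FakeAllocator(std::size_t limit) : byte_limit(limit) {}
  std::size_t byte_limit;
};

class ProcessState {
 public:
  static ProcessState& Singleton() {
    static ProcessState instance;  // one per process
    return instance;
  }

  // Creates the allocator for `device_id` on the first call; the byte limit passed on
  // any later call is ignored, mirroring the comment above.
  FakeAllocator* GetAllocator(int device_id, std::size_t total_bytes) {
    std::lock_guard<std::mutex> l(mu_);
    auto& slot = allocators_[device_id];
    if (!slot) slot = std::make_unique<FakeAllocator>(total_bytes);
    return slot.get();
  }

 private:
  std::mutex mu_;
  std::map<int, std::unique_ptr<FakeAllocator>> allocators_;
};

int main() {
  auto* a = ProcessState::Singleton().GetAllocator(0, 1 << 20);
  auto* b = ProcessState::Singleton().GetAllocator(0, 1 << 30);  // limit ignored
  std::cout << (a == b) << " limit=" << a->byte_limit << "\n";   // 1 limit=1048576
}
// --- end of editor's illustration ---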
+ GPUProcessState(); + virtual ~GPUProcessState() {} + friend class GPUDeviceTest; + + // Helper method for unit tests to reset the ProcessState singleton by + // cleaning up everything. Never use in production. + virtual void TestOnlyReset(); + + ProcessState::MDMap* mem_desc_map() { + if (process_state_) return &process_state_->mem_desc_map_; + return nullptr; + } + + static GPUProcessState* instance_; + ProcessState* process_state_; // Not owned. + bool gpu_device_enabled_; + + mutex mu_; + + struct AllocatorParts { + std::unique_ptr allocator; + std::unique_ptr counter; + GPUBFCAllocator* bfc_allocator; + SubAllocator* sub_allocator; // owned by allocator + std::unique_ptr recording_allocator; + +#ifdef TF_GPU_USE_PJRT + // Not owning GPU allocator. The allocator is owned by PJRT. If + // `allocator_not_owned` is set, `allocator` owned by AllocatorParts won't + // be set. + Allocator* allocator_not_owned; +#endif // TF_GPU_USE_PJRT + }; + std::vector gpu_allocators_ TF_GUARDED_BY(mu_); + std::vector> gpu_visitors_ + TF_GUARDED_BY(mu_); + + std::vector gpu_host_allocators_ TF_GUARDED_BY(mu_); + std::vector> gpu_host_alloc_visitors_ + TF_GUARDED_BY(mu_); + std::vector> gpu_host_free_visitors_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_PROCESS_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_scheduling_metrics_storage.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_scheduling_metrics_storage.h new file mode 100644 index 00000000..5e665414 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_scheduling_metrics_storage.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SCHEDULING_METRICS_STORAGE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SCHEDULING_METRICS_STORAGE_H_ + +#include +#include +#include + +#include "xla/tsl/framework/real_time_in_memory_metric.h" + +namespace tensorflow { + +// Storage class that holds all the exported in memory metrics exported by GPU +// runtime. +class GpuSchedulingMetricsStorage { + public: + static GpuSchedulingMetricsStorage& GetGlobalStorage(); + + // Gets the metrics for estimated total GPU load. 
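// Illustrative usage sketch: reading the exported load metric through the
// global storage singleton. Only the accessors declared in this class are
// assumed.
//
//   const auto& load_metric =
//       GpuSchedulingMetricsStorage::GetGlobalStorage().TotalGpuLoadNs();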
+ tsl::RealTimeInMemoryMetric& TotalGpuLoadNs() { + return total_gpu_load_ns_; + } + + const tsl::RealTimeInMemoryMetric& TotalGpuLoadNs() const { + return total_gpu_load_ns_; + } + + private: + tsl::RealTimeInMemoryMetric total_gpu_load_ns_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SCHEDULING_METRICS_STORAGE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h new file mode 100644 index 00000000..51a342fb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h @@ -0,0 +1,94 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SERVING_DEVICE_SELECTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SERVING_DEVICE_SELECTOR_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/fixed_array.h" +#include "absl/container/node_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/tsl/framework/serving_device_selector.h" +#include "tensorflow/core/framework/resource_base.h" + +namespace tensorflow { +namespace gpu { +class GpuServingDeviceSelector; +const char kGpuServingDeviceSelectorResourceName[] = + "gpu_serving_device_selector"; + +class GpuServingDeviceSelectorResource : public ResourceBase { + public: + explicit GpuServingDeviceSelectorResource( + int num_devices, std::unique_ptr + device_selector_policy) + : selector_(std::make_unique( + num_devices, std::move(device_selector_policy))) {} + + std::string DebugString() const override { + return "GpuServingDeviceSelectorResource"; + }; + + GpuServingDeviceSelector* selector() const { return selector_.get(); } + + private: + std::unique_ptr selector_; +}; + +class GpuServingDeviceSelector : public tsl::ServingDeviceSelector { + public: + GpuServingDeviceSelector( + int num_devices, + std::unique_ptr device_selector_policy); + + tsl::DeviceReservation ReserveDevice( + absl::string_view program_fingerprint) override; + + // Enqueues the program on the stream of index `index_on_host`. + void Enqueue(int32_t index_on_host, absl::string_view fingerprint) override; + + // Marks the completion of a program on the given stream. + // If `had_error` is true, this function doesn't update program's execution + // time stats to avoid incorrect estimates. + void Completed(int32_t index_on_host, bool had_error) override; + + private: + friend class ServingDeviceSelectorTestHelper; + static void OverwriteNowNsFunctionForTest(int64_t (*now_ns)()); + + void FreeDeviceReservation( + const tsl::DeviceReservation& reservation) override; + + // Only for metrics reporting purposes. 
+ int64_t TotalEstimatedTimeTillIdleNs() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + absl::Mutex mu_; + absl::FixedArray device_states_ ABSL_GUARDED_BY(mu_); + std::unique_ptr device_selector_policy_; + int64_t req_id_counter_ ABSL_GUARDED_BY(mu_); + // Map from program fingerprint to execution info. + absl::node_hash_map execution_info_ + ABSL_GUARDED_BY(mu_); + std::optional min_exec_time_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace gpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_SERVING_DEVICE_SELECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_util.h new file mode 100644 index 00000000..0b650ad9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu/gpu_util.h @@ -0,0 +1,111 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { + +class RecvTensorResponse; +class TensorProto; + +class GPUUtil { + public: + // "tensor" is GPU-local. "dev" is the hosting GPU. + // "device_context" should be the context of the GPU "_Send" op + // which provides the Tensor. + // Sets all necessary fields of "proto" by transferring value + // bytes from GPU to CPU RAM. "is_dead" indicates that the + // tensor is dead with an uninit value. + static void SetProtoFromGPU(const Tensor& tensor, Device* dev, + const DeviceContext* device_context, + TensorProto* proto, bool is_dead, + StatusCallback done); + + // Copies the data in 'gpu_tensor' into 'cpu_tensor'. + // 'gpu_tensor''s backing memory must be on 'gpu_device' and + // 'cpu_tensor' must be allocated to be of the same size as + // 'gpu_tensor'. Synchronous: may block. + static void CopyGPUTensorToCPU(Device* gpu_device, + const DeviceContext* device_context, + const Tensor* gpu_tensor, Tensor* cpu_tensor, + StatusCallback done); + + // Blocks until all operations queued on the stream associated with + // "gpu_device" at the time of the call have completed. Returns any + // error pending on the stream at completion. + static absl::Status Sync(Device* gpu_device); + + // Blocks until all operations queued on all streams associated with the + // corresponding GPU device at the time of call have completed. + // Returns any error pending on the stream at completion. + static absl::Status SyncAll(Device* gpu_device); + + // For debugging purpose, given a "device" and a "tensor" allocated + // on the device, return a string printing each byte in the tensor + // (up to a limit). 
"device" can be either a CPU or a GPU device. + static string MemoryDebugString(const Device* device, Tensor* tensor); + + // Map a Tensor as a DeviceMemory object wrapping the given typed + // buffer. + // + // NOTE: will be removed soon, see StreamExecutorUtil::AsDeviceMemory + // instead. + template + static se::DeviceMemory AsDeviceMemory(const Tensor& t) { + T* ptr = reinterpret_cast(const_cast(DMAHelper::base(&t))); + return se::DeviceMemory(se::DeviceMemoryBase(ptr, t.TotalBytes())); + } + + // Computes a checksum over the contents of "tensor", which is allocated + // on "gpu_device". + static uint64 Checksum(Device* gpu_device, + const DeviceContext* device_context, + const Tensor& tensor); + + // Computes a checksum over the contents of "tensor", which is allocated + // in local CPU RAM. + static uint64 Checksum(const Tensor& tensor); + + static void CopyCPUTensorToGPU(const Tensor* cpu_tensor, + const DeviceContext* device_context, + Device* gpu_device, Tensor* gpu_tensor, + StatusCallback done, bool sync_dst_compute); + + static void DeviceToDeviceCopy( + DeviceContext* send_dev_context, DeviceContext* recv_dev_context, + Device* src, Device* dst, AllocatorAttributes src_alloc_attr, + AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output, + int dev_to_dev_stream_index, StatusCallback done); + + // Deep-copying of GPU tensor on the same device. + // 'src_gpu_tensor''s and 'dst_gpu_tensor''s backing memory must be on + // 'gpu_device' and 'dst_cpu_tensor' must be allocated to be of the same + // size as 'src_gpu_tensor'. + static void CopyGPUTensorToSameGPU(Device* gpu_device, + const DeviceContext* device_context, + const Tensor* src_gpu_tensor, + Tensor* dst_gpu_tensor, + StatusCallback done); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu_device_context.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu_device_context.h new file mode 100644 index 00000000..e7486e97 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gpu_device_context.h @@ -0,0 +1,107 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace stream_executor { +class Stream; +} // namespace stream_executor + +namespace tensorflow { + +class GPUDeviceContext : public DeviceContext { + public: + // Does not take ownership of streams. 
+ GPUDeviceContext( + int stream_id, se::Stream* stream, +#if TENSORFLOW_USE_ROCM + se::Stream* nccl_stream, +#endif + se::Stream* host_to_device_stream, se::Stream* device_to_host_stream, + absl::InlinedVector device_to_device_stream, + Allocator* host_memory_allocator) + : stream_id_(stream_id), + stream_(stream), +#if TENSORFLOW_USE_ROCM + nccl_stream_(nccl_stream), +#endif + host_to_device_stream_(host_to_device_stream), + device_to_host_stream_(device_to_host_stream), + device_to_device_stream_(device_to_device_stream), + host_memory_allocator_(host_memory_allocator) { + } + + ~GPUDeviceContext() override {} + + se::Stream* stream() const override { return stream_; } +#if TENSORFLOW_USE_ROCM + se::Stream* nccl_stream() const { return nccl_stream_; } +#endif + se::Stream* host_to_device_stream() const { return host_to_device_stream_; } + se::Stream* device_to_host_stream() const { return device_to_host_stream_; } + se::Stream* device_to_device_stream(int index) const { + return device_to_device_stream_[index % device_to_device_stream_.size()]; + } + int stream_id() const { return stream_id_; } + Allocator* host_memory_allocator() const override { + return host_memory_allocator_; + } + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view edge_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; + + void MaintainLifetimeOnStream(const Tensor* t, + se::Stream* stream) const override {} + + absl::Status ThenExecute(Device* device, se::Stream* stream, + std::function func) override; + + private: + int stream_id_; + // The default primary stream to use for this context. + // All the memory belongs to this stream. + se::Stream* stream_; +#if TENSORFLOW_USE_ROCM + // The stream to use for nccl operations. + se::Stream* nccl_stream_; +#endif + // The stream to use for copying data from host into GPU. + se::Stream* host_to_device_stream_; + // The stream to use for copying data from GPU to host. + se::Stream* device_to_host_stream_; + // Streams to use for copying data between GPUs. + absl::InlinedVector device_to_device_stream_; + // The allocator to use for allocating pinned host memory. + // Not owned. + Allocator* host_memory_allocator_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/gradients.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gradients.h new file mode 100644 index 00000000..aaa9cad8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/gradients.h @@ -0,0 +1,58 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_ + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Represents the output of 'node' at 'index'. +struct NodeOut { + Node* node; + int index; + + // Returns the string name that represents the output of this node. + string name() const; + // Returns the data type of the output of this node. + DataType dtype() const; +}; + +// NOTE: This API is a work in progress and will likely be changing frequently. +// +// Given initial gradient-node outputs 'y_grad_node_outputs' (which compute the +// symbolic partial derivatives of some loss function 'L' w.r.t the node outputs +// 'y_node_outputs'), adds gradient nodes to 'graph' that compute the symbolic +// partial derivatives of 'L' w.r.t the node outputs 'x_node_outputs'. +// +// REQUIRES: Each node in 'x_node_outputs' to be unique, and so to have a single +// output (this restriction will be removed in a subsequent change). + +// TODO(andydavis) Add symbolic gradient support for general graphs (the current +// implementation only supports gradients for functions). In particular, +// the nodes in 'x_nodes' are currently restricted to have one output. + +absl::Status AddSymbolicGradients(absl::Span y_node_outputs, + absl::Span x_node_outputs, + absl::Span y_grad_node_outputs, + std::vector* x_grad_node_outputs, + Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRADIENTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_constructor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_constructor.h new file mode 100644 index 00000000..5f97f387 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_constructor.h @@ -0,0 +1,210 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class ShapeRefiner; + +// Construct a Graph *g out of a GraphDef gdef. Returns non-OK on +// error, in which case *g is left in an incomplete state. +// +// *g is expected to be an empty graph (with no more than a source and sink +// nodes) when provided to ConvertGraphDefToGraph. To enhance an existing Graph, +// see ImportGraphDef. +struct GraphConstructorOptions { + GraphConstructorOptions() = default; + + // If true, allows internal ops in the GraphDef. 
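// Illustrative usage sketch, assuming a GraphDef `graph_def` and the usual
// op-registry headers are available:
//
//   GraphConstructorOptions opts;
//   opts.allow_internal_ops = true;
//   Graph graph(OpRegistry::Global());
//   absl::Status s = ConvertGraphDefToGraph(opts, graph_def, &graph);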
+ bool allow_internal_ops = false; + + // If true, the graph def is expected to have fully specified + // devices for all nodes. A node in the resulting graph "g" has the + // device name set accordingly. + // + // TODO(zhifengc): if possible, consider removing this option. + bool expect_device_spec = false; + + // If true, validates that nodes being converted have all expected attrs + // set and no unknown attrs set by calling ValidateNodeDef(). + // Setting validate_nodes without add_default_attributes, will fail if + // the GraphDef does not have all required attributes set. + bool validate_nodes = false; + + // If true, GraphConstructor will add attributes with their default + // value to the Node when they are missing from the NodeDef. + bool add_default_attributes = true; +}; +extern absl::Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + const GraphDef& gdef, Graph* g); +extern absl::Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + GraphDef&& gdef, Graph* g); + +// Same as ConvertGraphDefToGraph, but takes just nodes. Used by function +// instantiation. +// TODO(irving): This will turn into std::vector soon. +extern absl::Status ConvertNodeDefsToGraph( + const GraphConstructorOptions& opts, absl::Span nodes, + Graph* g, const GraphDebugInfo* debug_info = nullptr); + +// Options for calling ImportGraphDef(). +struct ImportGraphDefOptions { + ImportGraphDefOptions() + : uniquify_names(false), + uniquify_prefix(false), + skip_mapped_nodes(false), + validate_shape(true), + propagate_device_spec(false) {} + + // Name prefix to use for nodes imported from the GraphDef. For example, if + // prefix="animals" and GraphDef contains a node "bunny" then the node will be + // named "animals/bunny" in *g. Must not be already used as a node name or + // prefix in the graph. + string prefix; + + // If true, imported node names will be modified if their name already exists + // in the graph. If false, conflicting names will be treated as an error. Note + // that this option has no effect if `prefix` is specified, since `prefix` + // will guarantee all node names are unique. + bool uniquify_names; + + // If true, `prefix` will be modified if it already exists as a node name or + // prefix in the graph. If false, a conflicting prefix will be treated as an + // error. This option has no effect if `prefix` isn't specified. + bool uniquify_prefix; + + // Maps tensors in `gdef` to existing tensors in `g`. Inputs in `gdef` + // corresponding to `input_map` keys will be remapped to the nodes in `g` + // corresponding to the values. + // + // Keys should not include `prefix`, i.e., a key ID's name should be the name + // as it originally appears in `gdef`. + // + // If this is non-empty, ImportGraphDef must be called with the shape refiner + // used to create the existing nodes referenced in `input_map`. + // TODO(skyewm): can we remove this requirement? How do we access the original + // shape refiner? + std::map input_map; + + // If true, nodes that will have all output edges removed because of + // overrides in `input_map` will not be imported. + bool skip_mapped_nodes; + + // The names of existing nodes in `g` that the imported graph should have + // control dependencies on. + // + // Note that to avoid creating many redundant control edges, ImportGraphDef() + // won't add control edges to nodes that will inherit the dependencies from + // other nodes in `gdef`. 
+ std::vector control_dependencies; + + // Tensors in `gdef` that will be returned via the ImportGraphDefResults + // output parameter of `ImportGraphDef()`. If this list is non-empty, the + // caller must pass a results object to `ImportGraphDef()`. The + // `return_tensors` field will be populated with the imported nodes in `g`. + // + // Entries should not include `prefix`, i.e., each ID's name should be the + // name as it originally appears in `gdef`. + // + // If this contains a tensor that's also being remapped via `input_map`, the + // corresponding existing tensor in `g` will be returned. + std::vector return_tensors; + + // The names of nodes in `gdef` that will be returned via the + // ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list + // is non-empty, the caller must pass a results object to + // `ImportGraphDef()`. The `return_nodes` field will be populated with the + // imported nodes in `g`. + // + // Entries should not include `prefix`, i.e., each node's name should be the + // name as it originally appears in `gdef`. + // + // Unlike `return_tensors`, `input_map` has no effect on the nodes + // returned. `return_nodes` must be empty if `skip_mapped_nodes` is true. + // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need. + std::vector return_nodes; + + // If true, checks that all colocation constraints are nodes in the GraphDef. + bool validate_colocation_constraints = true; + + // If false skips shape validation. + bool validate_shape; + + // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries + // with ops that are not defined in the binary calling ImportGraphDef. + // Similar to the producer_op_list argument to import_graph_def in the + // python API. + + // Try to set default execution device for this grapth. + string default_device; + + // If true, propagates a node's assigned device. By default the runtime + // will recompute the assigned device every time. + bool propagate_device_spec; +}; + +// Optional results that may be returned by ImportGraphDef. +struct ImportGraphDefResults { + // The requested tensors associated with + // ImportGraphDefOptions::return_tensors. Note that the index may be different + // than the requested index if the returned tensor has been remapped according + // to `input_map`. + typedef int Index; + std::vector> return_tensors; + + // The requested nodes associated with ImportGraphDefOptions::return_nodes. + std::vector return_nodes; + + // Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and + // weren't used as an input to any node in `gdef`. These keys are likely due + // to typos, and callers may wish to treat their existence as an error. + std::vector missing_unused_input_map_keys; +}; + +// Adds the graph in GraphDef `gdef` into an existing Graph `*g`. +// +// On error, returns non-OK and leaves `*g` unmodified. +// +// `refiner` can be null. It should be non-null if the caller +// intends to add additional nodes to the graph after the import. This +// allows the caller to validate shapes of those nodes (since +// ShapeRefiner::AddNode must be called in topological order). +// +// `results` must be non-null if `opts.return_tensors` or `opts.result_nodes` is +// non-empty. It can also be set to fetch the unused input map keys. If it's +// non-null, all the vector fields must be empty. +// +// TODO(ashankar): Push this mechanism and get rid of Session::Extend() +// as a means of enhancing an existing Graph. 
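// Illustrative usage sketch: importing a GraphDef `gdef` into an existing
// Graph `graph` under a prefix, with no shape refiner and no requested
// results.
//
//   ImportGraphDefOptions import_opts;
//   import_opts.prefix = "imported";
//   absl::Status s =
//       ImportGraphDef(import_opts, gdef, &graph, /*refiner=*/nullptr);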
+extern absl::Status ImportGraphDef(const ImportGraphDefOptions& opts, + const GraphDef& gdef, Graph* g, + ShapeRefiner* refiner, + ImportGraphDefResults* results = nullptr); + +// Make a copy of "src" into "*dest". +// +// REQUIRES: "*dest" is a freshly allocated graph without any nodes or edges +// other than the implicit Source/Sink nodes. +extern void CopyGraph(const Graph& src, Graph* dest); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_CONSTRUCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_def_builder_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_def_builder_util.h new file mode 100644 index 00000000..8fb53997 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_def_builder_util.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_DEF_BUILDER_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_DEF_BUILDER_UTIL_H_ + +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Graph; + +// Converts the `GraphDef` being built by `builder` to a `Graph` and +// stores it in `*graph`. +// TODO(josh11b): Make this faster; right now it converts +// Graph->GraphDef->Graph. This cleans up the graph (e.g. adds +// edges from the source and to the sink node, resolves back edges +// by name), and makes sure the resulting graph is valid. +absl::Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, + Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_DEF_BUILDER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_execution_state.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_execution_state.h new file mode 100644 index 00000000..4f713ae9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_execution_state.h @@ -0,0 +1,243 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/build_graph_options.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +struct SessionOptions; + +namespace subgraph { +struct RewriteGraphMetadata; +} + +struct GraphExecutionStateOptions { + const DeviceSet* device_set = nullptr; + const SessionOptions* session_options = nullptr; + // Unique session identifier. Can be empty. + string session_handle; + // A map from node name to device name, representing the unchangeable + // placement of stateful nodes. + std::unordered_map stateful_placements; + // Whether to run Placer on the graph. + bool run_placer = true; + + // Whether to enable tf2xla mlir bridge. The default is true and intends to + // work for almost all models. Non default values should only applied to + // selective models. + bool enable_tf2xla_mlir_bridge = true; +}; + +// A ClientGraph is simply a sub-graph of the full graph as induced by +// BuildGraphOptions. +struct ClientGraph { + explicit ClientGraph(std::unique_ptr flib, + DataTypeVector feed_types, DataTypeVector fetch_types, + int64_t collective_graph_key) + : flib_def(std::move(flib)), + graph(flib_def.get()), + feed_types(std::move(feed_types)), + fetch_types(std::move(fetch_types)), + collective_graph_key(collective_graph_key) {} + // Each client-graph gets its own function library since optimization passes + // post rewrite for execution might want to introduce new functions. + std::unique_ptr flib_def; + Graph graph; + DataTypeVector feed_types; + DataTypeVector fetch_types; + int64_t collective_graph_key; +}; + +// GraphExecutionState is responsible for generating an +// executable ClientGraph from the original GraphDef that specifies +// the complete graph and from BuildGraphOptions which specifies +// input/output nodes. +// +// An executable Graph differs from a GraphDef by being Placed, +// meaning that each Node is assigned to a single Device in the +// available set. +// +// When GraphExecutionState is first constructed it instantiates +// a full Graph from the provided GraphDef, and places it, using only +// the static device assignments from the GraphDef. Nodes without are +// currently placed in a very naive way. Since stateful Nodes cannot +// be moved after initial placement, it is important that stateful +// Nodes get sensible initial device assignments in the graph +// definition. +// +// Subsequently, GraphExecutionState generates a SimpleClientGraph on +// demand, which is a sub-graph of the latest placement of the full +// Graph. MasterSession uses such a ClientGraph to execute one or +// more similar client requests. +// +// GraphExecutionState is thread-safe. + +class GraphExecutionState { + public: + virtual ~GraphExecutionState(); + + // Creates a new `GraphExecutionState` for the given + // `graph_def`, which represents the entire graph for a session. 
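// Illustrative usage sketch: building an execution state for a whole GraphDef
// and then extracting a ClientGraph. `device_set`, `session_options`,
// `graph_def` and `build_graph_options` are assumed to be set up by the
// caller.
//
//   GraphExecutionStateOptions state_opts;
//   state_opts.device_set = &device_set;
//   state_opts.session_options = &session_options;
//   std::unique_ptr<GraphExecutionState> state;
//   absl::Status s = GraphExecutionState::MakeForBaseGraph(
//       std::move(graph_def), state_opts, &state);
//   std::unique_ptr<ClientGraph> client_graph;
//   if (s.ok()) s = state->BuildGraph(build_graph_options, &client_graph);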
+ static absl::Status MakeForBaseGraph( + GraphDef&& graph_def, const GraphExecutionStateOptions& options, + std::unique_ptr* out_state); + + // Creates a new `GraphExecutionState` and `SimpleClientGraph` + // for the subgraph of `original_graph_def` defined by + // `subgraph_options`. + static absl::Status MakeForPrunedGraph( + const GraphExecutionState& base_execution_state, + const GraphExecutionStateOptions& options, + const BuildGraphOptions& subgraph_options, + std::unique_ptr* out_state, + std::unique_ptr* out_client_graph); + + // Creates a new GraphExecutionState representing the + // concatenation of this graph, and the graph defined by + // "extension_def". The same name may not be used to define a node + // in both this graph and "extension_def". + // + // If successful, returns OK and the caller takes ownership of "*out". + // Otherwise returns an error and does not modify "*out". + // + // After calling `old_state->Extend()`, `old_state` may no longer be + // used. + // + // NOTE(mrry): This method respects the placement of stateful nodes in + // in *this, but currently does not transfer any other placement + // or cost model information to the new graph. + // + // Note that using this interface requires setting the value of + // config.experimental().disable_optimize_for_static_graph() in the state + // options to `true`, otherwise it will return an error. + absl::Status Extend(const GraphDef& extension_def, + std::unique_ptr* out) const; + + // Builds a ClientGraph (a sub-graph of the full graph as induced by + // the Node set specified in "options"). If successful, returns OK + // and the caller takes the ownership of "*out". Otherwise, returns + // an error. + absl::Status BuildGraph(const BuildGraphOptions& options, + std::unique_ptr* out); + + // Optimize the graph with the node set specified in `options`. + absl::Status OptimizeGraph( + const BuildGraphOptions& options, const Graph& graph, + const FunctionLibraryDefinition* flib_def, + std::unique_ptr* optimized_graph, + std::unique_ptr* optimized_flib); + + // The graph returned by BuildGraph may contain only the pruned + // graph, whereas some clients may want access to the full graph. + const Graph* full_graph() { return graph_; } + + // The original graph. + GraphDef* original_graph_def() { return original_graph_def_.get(); } + + // The original function library of this graph. + const FunctionLibraryDefinition& flib_def() const { return *flib_def_; } + + // Returns the node with the given name, or null if it does not exist. + const Node* get_node_by_name(const string& name) const { + NodeNameToCostIdMap::const_iterator iter = + node_name_to_cost_id_map_.find(name); + if (iter != node_name_to_cost_id_map_.end()) { + return graph_->FindNodeId(iter->second); + } else { + return nullptr; + } + } + + // Returns the map of stateful placements as a map of + // node name to placement string. + std::unordered_map GetStatefulPlacements() const { + return stateful_placements_; + } + + private: + GraphExecutionState(std::unique_ptr&& graph_def, + std::unique_ptr&& flib_def, + const GraphExecutionStateOptions& options); + + absl::Status InitBaseGraph(std::unique_ptr&& graph, + bool enable_tf2xla_mlir_bridge = true); + + // Map of placed stateful nodes, i.e. nodes for which is_stateful() + // is true, such as "params" and "queue" nodes. Once placed these + // nodes can not be moved to a different device. Maps node names to + // device names. + std::unordered_map stateful_placements_; // Immutable after + // ctor. 
+ void SaveStatefulNodes(Graph* graph); + void RestoreStatefulNodes(Graph* graph); + + // Extract the subset of the graph that needs to be run, adding feed/fetch + // ops as needed. + absl::Status PruneGraph(const BuildGraphOptions& options, Graph* graph, + subgraph::RewriteGraphMetadata* out_rewrite_metadata); + + // The GraphExecutionState must store a copy of the original GraphDef if + // either of the following conditions holds: + // + // * `session_options_.config.graph_options().place_pruned_graph()` is true. + // * `session_options_.config.experimental().optimize_for_static_graph()` is + // false. + const std::unique_ptr original_graph_def_; + + const DeviceSet* device_set_; // Not owned + const SessionOptions* session_options_; // Not owned + // Unique session identifier. Can be empty. + string session_handle_; + + // Map from name to Node for the full graph in placed_. + NodeNameToCostIdMap node_name_to_cost_id_map_; + + // 'flib_def_' is initialized from the initial graph def's library, + // and may be updated by a graph optimization pass. + std::unique_ptr flib_def_; + + // `rewrite_metadata_` is only set for GraphExecutionState + // objects created by `MakeForPrunedGraph()`. + std::unique_ptr rewrite_metadata_; + + // The dataflow graph owned by this object. + Graph* graph_; + + // Whether to run Placer. + bool run_placer_; + + GraphExecutionState(const GraphExecutionState&) = delete; + void operator=(const GraphExecutionState&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_EXECUTION_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_optimizer.h new file mode 100644 index 00000000..f8322cfe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_optimizer.h @@ -0,0 +1,100 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class GraphOptimizer { + public: + using NodePredicate = std::function; + + struct Options { + // If not null it maps from nodes in graph to partially-known + // shapes of their outputs, and may be used, e.g., in the constant folding + // pass. The use of shape_map implies that the mapping from node name to the + // vector of partial shapes of its outputs is stable, i.e., no optimization + // pass may replace a node with a different node of the same name that has a + // different number of outputs, or outputs with different known shapes. 
+ // TODO(b/65453533) introduce a unique way to name nodes in a graph. + std::unordered_map>* shape_map = + nullptr; + + // If not null then only nodes for which cse_consider_fn returns true will + // be considered for CSE. + NodePredicate cse_consider_fn = nullptr; + + // If not null then only nodes for which cf_consider_fn returns true will be + // considered for CF. + NodePredicate cf_consider_fn = nullptr; + + // If true, multi-device functions will be inlined if + // opts_.do_function_inlining() is true. + bool inline_multi_device_functions = false; + + // If true, functions in implementation selection group will be inlined if + // opts_.do_function_inlining() is true. + bool inline_impl_selection_group_functions = false; + + // If true all functions will be inlined with a single device function + // body placer strategy. + bool inline_with_single_device_body_placer = false; + + // If true, the _noinline attribute on functions and callers is ignored. + bool ignore_noinline = false; + }; + + explicit GraphOptimizer(const OptimizerOptions& opts); + ~GraphOptimizer(); + + // Applies optimization passes specified in 'opts' to 'graph'. + // Maybe replace *graph with a new graph object. 'device' is device + // on which the 'graph' will execute. It's passed to the optimizers + // so that they can respect constraints if any, that should be + // respected. + void Optimize(FunctionLibraryRuntime* runtime, Env* env, const Device* device, + std::unique_ptr* graph, + const Options& graph_optimizer_options); + + const OptimizerOptions& options() { return opts_; } + + private: + OptimizerOptions opts_; + + GraphOptimizer(const GraphOptimizer&) = delete; + void operator=(const GraphOptimizer&) = delete; +}; + +// Applies graph rewrite optimization such as inlining, dead code +// removal, etc. +// +// **g is a graph constructed based on the runtime library 'lib'. +// OptimizeGraph mutates **g extensively and replaces '*g' with a +// complete copy. Therefore, the caller should not keep any references +// to nodes *g. +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g, + const GraphOptimizer::Options& graph_optimizer_options); +void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr* g); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_runner.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_runner.h new file mode 100644 index 00000000..a40d17b8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_runner.h @@ -0,0 +1,74 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tsl { +class Env; +} // namespace tsl +namespace tensorflow { +using Env = tsl::Env; + +class Device; +class Graph; + +// GraphRunner takes a Graph, some inputs to feed, and some outputs +// to fetch and executes the graph required to feed and fetch the +// inputs and outputs. +// +// This class is only meant for internal use where one needs to +// partially evaluate inexpensive nodes in a graph, such as for shape +// inference or for constant folding. Because of its limited, simple +// use-cases, it executes all computation on the given device (CPU by default) +// and is not meant to be particularly lightweight, fast, or efficient. +class GraphRunner { + public: + // REQUIRES: `env` is not nullptr. + GraphRunner(Env* env); + // REQUIRES: 'device' is not nullptr. Not owned. + GraphRunner(Device* device); + ~GraphRunner(); + + // Function semantics for `inputs`, `output_names` and `outputs` + // matches those from Session::Run(). + // + // NOTE: The output tensors share lifetime with the GraphRunner, and could + // be destroyed once the GraphRunner is destroyed. + // + // REQUIRES: `graph`, `env`, and `outputs` are not nullptr. + // `function_library` may be nullptr. + typedef std::vector> NamedTensorList; + absl::Status Run(Graph* graph, FunctionLibraryRuntime* function_library, + const NamedTensorList& inputs, + const std::vector& output_names, + std::vector* outputs); + + private: + std::unique_ptr device_deleter_; + Device* const device_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_RUNNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_view.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_view.h new file mode 100644 index 00000000..d1fe278a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/graph_view.h @@ -0,0 +1,258 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ + +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Device; +class Graph; +class Node; +class OpKernel; +class Tensor; + +// Represents a single data edge in a `NodeItem`. 
+struct EdgeInfo { + // The node ID of the destination in the containing `GraphView`. + int dst_id; + // The index of the output that produces values on this edge. + int output_slot : 31; + // true if this is the last info for output_slot in the EdgeInfo list. + bool is_last : 1; + // The index of the input that consumes values on this edge. + int input_slot; +}; + +// Represents a single control edge in a `NodeItem`. +struct ControlEdgeInfo { + // The node ID of the destination in the containing `GraphView`. + int dst_id; +}; + +// Compact structure representing a graph node and its associated kernel. +// +// Each NodeItem is an element of exactly one GraphView. +struct NodeItem { + // The index of this node's item in its GraphView. + int node_id = -1; + + // Cached attributes of this node for fast lookup. + bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr + bool is_merge : 1; // True iff IsMerge(node) + bool is_enter : 1; // True iff IsEnter(node) + bool is_constant_enter : 1; // True iff IsEnter(node) and + // node->GetAttr("is_constant") == true. + bool is_exit : 1; // True iff IsExit(node) + bool is_control_trigger : 1; // True iff IsControlTrigger(node) + bool is_source : 1; // True iff IsSource(node) + // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) + bool is_enter_exit_or_next_iter : 1; + bool is_transfer_node : 1; // True iff IsTransferNode(node) + bool is_initialization_op : 1; // True iff IsInitializationOp(node) + bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) + bool is_next_iteration : 1; // True iff IsNextIteration(node) + bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp") + bool + is_any_consumer_merge_or_control_trigger : 1; // True iff the destination + // of any output edge is a + // merge or control trigger + // node. + bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this + // node's input types. + bool is_distributed_communication : 1; // True iff the op is registered to + // use distributed communication. + + // The kernel for this node. + OpKernel* kernel = nullptr; + + // If the kernel is a Const op, this containts points to the constant tensor. + const Tensor* const_tensor = nullptr; + + // Cached values of node->num_inputs() and node->num_outputs(), to + // avoid levels of indirection. + int num_inputs; + int num_outputs; + + // ExecutorImpl::tensors_[input_start] is the 1st positional input + // for this node. + int input_start = 0; + + // Number of output edges, excluding control edges. + int32 num_output_edges; + + // Number of output control edges. + int32 num_output_control_edges; + + // If non-null, contains an array of num_outputs bools, where the ith bool + // is true if and only if the ith output is consumed by another node. + std::unique_ptr outputs_required; + + absl::Span mutable_output_edges() { + return absl::Span(output_edge_base(), num_output_edges); + } + + gtl::ArraySlice output_edges() const { + return gtl::ArraySlice(output_edge_base(), num_output_edges); + } + + gtl::ArraySlice output_control_edges() const { + return gtl::ArraySlice(output_control_edge_base(), + num_output_control_edges); + } + + DataType input_type(int i) const { + DCHECK_LT(i, num_inputs); + return static_cast(input_type_base()[i]); + } + DataType output_type(int i) const { + DCHECK_LT(i, num_outputs); + return static_cast(output_type_base()[i]); + } + + // Return array of per-output allocator attributes. 
+ const AllocatorAttributes* output_attrs() const { return output_attr_base(); } + + // Return array of expected input index from which each output should + // be forwarded: + // kNeverForward (-2) for DO NOT FORWARD (must allocate). + // kNoReservation (-1) for no expected forwarding. + // 0... for forward from that input. + const int* forward_from() const { return forward_from_base(); } + + string DebugString() const; + + private: + friend class GraphView; + + NodeItem() {} + + // Variable length section starts immediately after *this + // (uint8 is enough for DataType). + // EdgeInfo out_edges[num_output_edges]; + // ControlEdgeInfo out_control_edges[num_output_control_edges]; + // AllocatorAttributes output_attr[num_outputs]; + // int forward_from[num_outputs]; + // uint8 input_type[num_inputs]; + // uint8 output_type[num_outputs]; + + // Return pointer to variable length section. + char* var() const { + return const_cast(reinterpret_cast(this) + + sizeof(NodeItem)); + } + + EdgeInfo* output_edge_base() const { + return reinterpret_cast(var()); + } + + ControlEdgeInfo* output_control_edge_base() const { + return reinterpret_cast(var() + sizeof(EdgeInfo) * + num_output_edges); + } + + AllocatorAttributes* output_attr_base() const { + return reinterpret_cast( + var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * num_output_control_edges); + } + int* forward_from_base() const { + return reinterpret_cast(var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * + num_output_control_edges + + sizeof(AllocatorAttributes) * num_outputs); + } + uint8* input_type_base() const { + return reinterpret_cast( + var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * num_output_control_edges + + sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs); + } + uint8* output_type_base() const { + return reinterpret_cast( + var() + sizeof(EdgeInfo) * num_output_edges + + sizeof(ControlEdgeInfo) * num_output_control_edges + + sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs + + sizeof(uint8) * num_inputs); + } + + NodeItem(const NodeItem&) = delete; + void operator=(const NodeItem&) = delete; +}; + +// Immutable view of a Graph organized for efficient execution. +// +// TODO(b/152651962): Add independent unit tests for this class. +class GraphView { + public: + GraphView() : space_(nullptr) {} + ~GraphView(); + + absl::Status Initialize(const Graph* g); + absl::Status SetAllocAttrs(const Graph* g, const Device* device); + void SetScopedAllocatorAttrs(const std::vector& sa_nodes); + + // Returns a mutable pointer to the `NodeItem` with the given `id` if it + // exists in the graph, or `nullptr` if it does not. + NodeItem* node(int32_t id) const { + DCHECK_GE(id, 0); + DCHECK_LT(id, num_nodes_); + uint32 offset = node_offsets_[id]; + return ((offset == kuint32max) + ? nullptr + : reinterpret_cast(space_ + node_offsets_[id])); + } + + // Returns the `NodeItem` with the given `id`. + // + // REQUIRES: `id` must be the ID of a valid node in the graph. 
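// Illustrative usage sketch: iterating over every NodeItem of an initialized
// view. `graph` is assumed to be an already-constructed Graph.
//
//   GraphView gview;
//   if (gview.Initialize(&graph).ok()) {
//     for (int32_t id = 0; id < gview.num_nodes(); ++id) {
//       if (const NodeItem* item = gview.node(id)) {
//         VLOG(2) << "node " << id << ": " << item->num_inputs << " inputs, "
//                 << item->num_outputs << " outputs";
//       }
//     }
//   }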
+ const NodeItem& node_ref(int32_t id) const { + DCHECK_GE(id, 0); + DCHECK_LT(id, num_nodes_); + uint32 offset = node_offsets_[id]; + DCHECK_NE(offset, kuint32max); + return *reinterpret_cast(space_ + node_offsets_[id]); + } + + int32 num_nodes() const { return num_nodes_; } + + private: + char* InitializeNode(char* ptr, const Node* n); + size_t NodeItemBytes(const Node* n); + + int32 num_nodes_ = 0; + uint32* node_offsets_ = nullptr; // array of size "num_nodes_" + // node_offsets_[id] holds the byte offset for node w/ "id" in space_ + + char* space_; // NodeItem objects are allocated here + + GraphView(const GraphView&) = delete; + void operator=(const GraphView&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h new file mode 100644 index 00000000..fd5ee985 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ + +#include + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { + +// Hierarchical tree-algorithm implementation of collective broadcast. +class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface { + public: + HierarchicalTreeBroadcaster(); + ~HierarchicalTreeBroadcaster() override = default; + + // Establishes the subdiv permutations needed for a hierarchical broadcast. + // If all devices are local, establishes a single subdiv comprising all + // devices. If any devices are on a different task, establishes n+1 subdivs + // for n tasks. + // The first subdiv comprises one device per task which gets the tensor on + // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task + // i. + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; + + // Initializes members of CollectiveContext not yet initialized, i.e. device + // and device_locality. Also saves the CollectiveContext in this object. + absl::Status InitializeCollectiveContext( + std::shared_ptr col_ctx) override; + + // Begins async execution of the hierarchical tree broadcast. + // Must be called in a blockable thread. + // TODO(b/80529858): remove the previous warning when we have a dedicated + // collective threadpool. + void Run(StatusCallback done) override; + + // Returns the rank of the device from which this device should receive + // its value, -1 if no value should be received. 
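// Illustrative usage sketch: querying the tree topology for subdivision 0 of
// an already-initialized CollectiveParams `cp`. The target container is
// assumed to be a std::vector<int> of device ranks.
//
//   int recv_rank = HierarchicalTreeBroadcaster::TreeRecvFrom(cp, /*subdiv=*/0);
//   std::vector<int> targets;
//   HierarchicalTreeBroadcaster::TreeSendTo(cp, /*subdiv=*/0, &targets);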
+ static int TreeRecvFrom(const CollectiveParams& cp, int subdiv); + + // Populates targets with the ranks of the devices to which this device + // should forward the value. + static void TreeSendTo(const CollectiveParams& cp, int subdiv, + std::vector* targets); + + private: + // Get the task to which the device at `device_rank` belongs. + int GetDeviceTask(int device_rank, const std::vector& dev_per_task); + + // Sends `src_tensor` asynchronously from this device to device at `dst_rank` + // in `subdiv`. Calls `done` upon completion. + void DispatchSend(int subdiv, int dst_rank, int src_rank, + const Tensor* src_tensor, const StatusCallback& done); + + // Receives a tensor into the memory buffer owned by `dst_tensor` at this + // device from device at `src_rank` in `subdiv`. Calls `done` upon + // completion. + void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor, + const StatusCallback& done); + + // Executes the hierarchical broadcast defined by this op. + void RunTree(); + + std::shared_ptr col_ctx_; + const CollectiveParams* col_params_; // Not owned + StatusCallback done_; + absl::Status status_; + bool is_source_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/immutable_executor_state.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/immutable_executor_state.h new file mode 100644 index 00000000..6a12bc1f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/immutable_executor_state.h @@ -0,0 +1,163 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/common_runtime/graph_view.h" +#include "tensorflow/core/common_runtime/local_executor_params.h" +#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Graph; + +// Represents the state of an executor (graph and control flow information) +// that is immutable throughout execution. +// +// TODO(b/152651962): Add independent unit tests for this class. +class ImmutableExecutorState { + public: + struct FrameInfo { + explicit FrameInfo(string name) + : name(std::move(name)), + input_count(0), + total_inputs(0), + pending_counts(nullptr), + nodes(nullptr), + parallel_iterations(-1) {} + + // The name of the frame. + string name; + + // The total number of inputs to a frame. 
+ int input_count; + + // The total number of input tensors of a frame. + // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame. + int total_inputs; + + // Used to determine the next place to allocate space in the + // pending_counts data structure we'll eventually construct + PendingCounts::Layout pending_counts_layout; + + // Each frame has its own PendingCounts only for the nodes in the frame. + std::unique_ptr pending_counts; + + // The nodes in a frame. Used only for debugging. + std::unique_ptr> nodes; + + // The number of iterations of this frame that can execute concurrently. + int32 parallel_iterations; + }; + + explicit ImmutableExecutorState(const LocalExecutorParams& p) + : params_(p), gview_() {} + ~ImmutableExecutorState(); + + absl::Status Initialize(const Graph& graph); + + // Process all Nodes in the current graph, attempting to infer the + // memory allocation attributes to be used wherever they may allocate + // a tensor buffer. + absl::Status SetAllocAttrs(); + + const LocalExecutorParams& params() const { return params_; } + const GraphView& graph_view() const { return gview_; } + const std::vector& pending_ids() const { + return pending_ids_; + } + const std::vector& root_nodes() const { return root_nodes_; } + + const FrameInfo& get_root_frame_info() const { return *root_frame_info_; } + + const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const { + DCHECK(node_item.is_enter); + return *enter_frame_info_[node_item.node_id]; + } + + bool requires_control_flow_support() const { return requires_control_flow_; } + + // Copies the pending counts for nodes in this graph to the given array. + // + // This method provides a more efficient way of initializing + // `SimplePropagatorState` than individually accessing the pending counts from + // `get_root_frame_info().counts`. + // + // REQUIRES: `!requires_control_flow_support && len(dest) == + // graph_view().num_nodes()`. + void copy_pending_counts(std::atomic* dest) const { + DCHECK(!requires_control_flow_); + memcpy(dest, atomic_pending_counts_.get(), + graph_view().num_nodes() * sizeof(std::atomic)); + std::atomic_thread_fence(std::memory_order_release); + } + + private: + struct ControlFlowInfo { + gtl::FlatSet unique_frame_names; + std::vector frame_names; + }; + + static absl::Status BuildControlFlowInfo(const Graph* graph, + ControlFlowInfo* cf_info); + void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info); + + FrameInfo* EnsureFrameInfo(const string& fname); + + // Owned. + LocalExecutorParams params_; + GraphView gview_; + bool requires_control_flow_; + std::vector pending_ids_; + + // Root nodes (with no in edges) that should form the initial ready queue + std::vector root_nodes_; + + // Mapping from frame name to static information about the frame. + // TODO(yuanbyu): We could cache it along with the graph so to avoid + // the overhead of constructing it for each executor instance. + absl::flat_hash_map> + frame_info_; + const FrameInfo* root_frame_info_; // Not owned. + + // If the graph contains any "Enter" or "RefEnter" nodes, this vector maps + // dense node IDs to the corresponding FrameInfo. + std::vector enter_frame_info_; + + // If `requires_control_flow_` is false, this points to an array of initial + // pending counts for the nodes in the graph, indexed by node ID. + std::unique_ptr[]> atomic_pending_counts_; + + // Shallow copies of the constant tensors used in the graph. 
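Editor's aside: the pending counts captured by FrameInfo (and, in the control-flow-free case, by atomic_pending_counts_) drive execution: each node starts with a count equal to its outstanding inputs and becomes ready exactly when that count drops to zero, with root_nodes() seeding the initial ready queue. A simplified sketch of that propagation step follows; it is not the real PendingCounts/PropagatorState machinery, and MarkInputDone is a made-up helper name.

#include <atomic>
#include <cstdint>
#include <vector>

// Record that one input of `node_id` finished and queue the node once its last
// outstanding input completes. `pending[i]` is assumed to have been initialized
// to the number of inputs of node i (cf. copy_pending_counts()).
void MarkInputDone(std::vector<std::atomic<int32_t>>& pending, int node_id,
                   std::vector<int>* ready) {
  // fetch_sub returns the previous value, so 1 means this was the last input.
  if (pending[node_id].fetch_sub(1, std::memory_order_acq_rel) == 1) {
    ready->push_back(node_id);
  }
}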
+ std::vector const_tensors_; + + ImmutableExecutorState(const ImmutableExecutorState&) = delete; + void operator=(const ImmutableExecutorState&) = delete; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/inline_function_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/inline_function_utils.h new file mode 100644 index 00000000..94c118fe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/inline_function_utils.h @@ -0,0 +1,241 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +static constexpr const char* const kNoInlineAttr = "_noinline"; + +// Optionally override device assignment for nodes added to the graph for +// inlined functions: +// (1) Identity nodes added in place of function input arguments. +// (2) Identity nodes added in place of function return values. +// (3) Special NoOp nodes that enforce side-effects execution order. +// (4) All nodes inside function body specified in FunctionDef. +class InlinedFunctionBodyPlacer { + public: + virtual ~InlinedFunctionBodyPlacer() = default; + + virtual absl::optional InputNodeDevice(int input_index) const = 0; + virtual absl::optional OutputNodeDevice(int output_index) const = 0; + // Returns true if the added input/output identity nodes should be colocated + // with the corresponding input/output from the function body. + virtual bool ColocateInputOutputIdentities() const = 0; + virtual absl::optional ControlNodeDevice() const = 0; + virtual absl::optional BodyNodeDevice(const NodeDef& ndef) const = 0; + + // LINT.IfChange + // Place input nodes on the same device as the corresponding caller input + // node. Do not specify any placement for all other nodes. + static std::unique_ptr DefaultPlacer( + const Graph& graph, const Node& caller); + + // Place all nodes on the same device as caller node. + static std::unique_ptr SingleDevicePlacer( + const Graph& graph, const Node& caller); + + // Place input nodes on the same device as the corresponding caller input + // node. Do not place output node. Place control nodes on the same device as + // caller node. For all function body nodes set job, replica and task + // parts of the device assignment to match function caller node where those + // are unspecified. 
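Editor's aside: the placer interface above is small enough to implement directly. A hypothetical placer that pins every node added for the inlined body onto one fixed device could look like the sketch below; PinToDevicePlacer is not a real TensorFlow class, and the MultiDevicePlacer declared just after this is subtler about jobs, replicas and tasks.

// Sketch of a custom InlinedFunctionBodyPlacer that places everything on `device_`.
class PinToDevicePlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit PinToDevicePlacer(string device) : device_(std::move(device)) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return device_;
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return device_;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override { return device_; }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return device_;
  }

 private:
  string device_;
};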
+ static std::unique_ptr MultiDevicePlacer( + const Graph& graph, const Node& caller); + // LINT.ThenChange(lower_function_call_inline_policy.h) + + using Factory = std::function( + const Graph&, const Node&)>; + + struct Config { + string name; + Factory get; + }; + + static Config Default() { return {"default", DefaultPlacer}; } + static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; } + static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; } +}; + +struct InlineFunctionBodyOptions { + // All nodes that have incoming control edge *from* the function call node, + // will be forwarded to the "output control node". There are two options for + // choosing which nodes will have a control edge *to* the "output control + // node": + // a) control returns (`control_ret` field in FunctionDef) + // b) data returns (`ret` field in FunctionDef) + enum class OutputControlSource { kDataOutputs, kControlOutputs }; + + // Keep a node in a graph with the same name as the function call node: + // + // a) DoNotKeep: Function call node is fully inlined, and there is no node in + // a graph with the same name. + // + // b) Fetchable: Add an IdentityN node to the graph in place of the inlined + // function call node. It will have a control edge from inlined + // 'output_control_node' and data edges from function output nodes. + // The IdentityN node will be placed on the same device as the caller node. + // + // This is mostly for compatibility with Tensorflow v1 and sessions. + // When we prepare a graph for execution in + // GraphExecutionState::MakeForBaseGraph we don't know what nodes will be + // fetched, so we can't safely remove any of them. When graph executed as a + // function it has 'Retval' nodes for all fetched tensors, and we can + // safely inline function calls. + // + // c) Targetable: Add a NoOp node to the graph in place of the inlined + // function call node. It will have a control edge from inline + // 'output_control_node' and no data edges. NoOp node will be placed on the + // same device as the caller node. This will keep the inlined function call + // node a valid 'session.run' target, and also will keep it a valid control + // output node. + enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable }; + + // If 'true' function inlining is completely disabled. This allows to control + // function inlining for different types of function calls (see + // 'ExpandInlineFunctionsOptions' below). + bool disable_inlining = false; + // Ignore '_noinline' function attribute. + bool ignore_noinline = false; + // If 'true' function inlining will inline functions in implementation + // selection group. Normally those functions should not be inlined; they will + // be handled by Grappler. + bool inline_impl_selection_group_functions = false; + // Controls if we want to keep a node with the name as the function call node + // in a graph after function inlining. + KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep; + // For compatibility with Tensorflow v1 by default we will use data outputs. + // Control returns were added to Tensorflow v2 with automatic control + // dependencies tracking in Eager mode. + OutputControlSource output_control_src = OutputControlSource::kDataOutputs; + // Inlined function body placer decides what requested device assignments + // should be added to the nodes added to the graph. See documentation above + // for available strategies. 
+ InlinedFunctionBodyPlacer::Config inlined_function_body_placer = + InlinedFunctionBodyPlacer::Default(); + // If true, frame names in the function body will be + // made unique in the resulting graph (e.g. by prepending a unique prefix). + // NOTE(mrry): Only set this option to false when there is a single function + // call in the graph (e.g. when making a remote function call via + // ClusterFunctionLibraryRuntime). This option is provided because the graph + // partitioner generates frame names that must remain unmodified across all + // partitions of a multi-device function. + bool uniquify_frame_names = true; + + // A human-readable debug string for this options. + string DebugString() const; +}; + +// Returns 'OkStatus()' iff the function '*fbody' can be inlined at 'node' +// based on the type signature of 'node' and 'fbody': +// +// (1) Caller node has the same number of inputs and outputs as the function. +// (2) Caller node inputs and outputs have the same data types as function +// inputs and returns. +// (3) Validation rules defined in InlineFunctionBodyOptions. +// +// If function can't be safely inlined, returns error message with details why +// inlining is not possible or safe. +absl::Status ValidateInlining(const Node* node, const FunctionBody* fbody, + const InlineFunctionBodyOptions& options); + +// Given a "caller" in graph "g", which is a function call of a function +// to "fbody". Replaces the "caller" with fbody->graph and connects +// edges properly. "override_device" specifies whether inlining should replace +// explicitly specified devices inside fbody with the callee's device. +// +// Returns 'OkStatus()' if function was successfully inlined into the graph. +// If function inlining is not possible returns an error with a reason, and +// leaves the graph in unmodified state. +absl::Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, + Graph* g, Node* caller, + const FunctionBody* fbody, + const InlineFunctionBodyOptions& options); + +// There are three types of function calls that could be invoked during +// *Tensorflow graph execution*: +// +// 1) Native function call (node.type_string() is the function name). These +// functions are always executed on a single-device, which is the device of +// the function call node. +// +// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall +// ops) can execute on multiple devices and accept DT_RESOURCE inputs that +// belong to different devices. This type of functions was added in +// Tensorflow 2.0 Eager mode, and it has control outputs to represent +// side-effects that must always execute (see `control_ret` in FunctionDef). +// +// 3) SymbolicGradient has been deprecated for a while, but we still keep it and +// use `native` options for inlining for compatibility. +// +// We need to have distinct inlining rules for compatibility with Tensorflow v1. +// +// There are few other places in Tensorflow that could execute functions: +// +// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level" +// functions directly via function library runtime, without going through +// the graph. +// 2) tf.data pipelines - also execute functions directly via function library +// runtime with custom executors. 
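Editor's aside: putting the pieces above together, here is a hedged usage sketch of the inlining entry points. The helper name InlineOneCall is made up, and flib_def, g, caller and fbody are assumed to be supplied by the caller; error handling is abbreviated.

absl::Status InlineOneCall(const FunctionLibraryDefinition& flib_def, Graph* g,
                           Node* caller, const FunctionBody* fbody) {
  InlineFunctionBodyOptions opts;
  // Respect control returns for side effects, keep a fetchable stand-in node,
  // and let the multi-device placer assign devices inside the inlined body.
  opts.output_control_src =
      InlineFunctionBodyOptions::OutputControlSource::kControlOutputs;
  opts.keep_caller_node = InlineFunctionBodyOptions::KeepCallerNode::kFetchable;
  opts.inlined_function_body_placer = InlinedFunctionBodyPlacer::MultiDevice();

  absl::Status s = ValidateInlining(caller, fbody, opts);
  if (!s.ok()) return s;  // not safe to inline; the graph is left untouched
  return InlineFunctionBody(flib_def, g, caller, fbody, opts);
}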
+struct ExpandInlineFunctionsOptions { + ExpandInlineFunctionsOptions() : native_options(), multi_device_options() { + using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; + multi_device_options.output_control_src = OutputControlSrc::kControlOutputs; + } + + InlineFunctionBodyOptions native_options; + InlineFunctionBodyOptions multi_device_options; +}; + +// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary +// workaround that will be enabled only during the function inlining unification +// (b/126811947). Contact ezhulenev@ if you think you need it. +// TODO(ezhulenev): Delete this function. +bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, + const ExpandInlineFunctionsOptions& options); + +// For each node in "graph", if "lib" indicates that the node is a +// function call, inline the function body. Returns true if at least +// one node is inlined. +// +// This routine goes through "graph" nodes once and applies the +// inlining. The caller may decide to apply the inlining on "graph" +// multiple times by calling ExpandInlineFunctions a few times. +// +// Function calls that can't be safely inlined into the graph (ValidateInlining +// returns error), are ignored. +// +// TODO(ezhulenev): We do not FunctionLibraryRuntime for this. We need just the +// FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see +// lower_function_call.cc). +inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) { + return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions()); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INLINE_FUNCTION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/input_colocation_exemption_registry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/input_colocation_exemption_registry.h new file mode 100644 index 00000000..c393fe74 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/input_colocation_exemption_registry.h @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. Al Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_ + +#include + +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// TensorFlow runtime (both eager and graph) will aim to colocate ops with +// their resource inputs so that the ops can access the resource state. In some +// cases, such as tf.data ops, this is not desirable as the ops themselves might +// not have a kernel registered for the device on which the resource is placed +// and instead use a mechanism, such as a multi-device function, to access the +// resource state. 
+// +// This registry can be used to register and list ops that should be exempt from +// the input colocation described above. +// +// Example usage: +// REGISTER_INPUT_COLOCATION_EXEMPTION("MapDataset"); +class InputColocationExemptionRegistry { + public: + // Returns a pointer to a global InputColocationExemptionRegistry object. + static InputColocationExemptionRegistry* Global(); + + // Returns the set of ops exempt from the input colocation constraints. + const gtl::FlatSet& Get() { return ops_; } + + // Registers an op to be excluded from the input colocation constraints. + void Register(const string& op); + + private: + gtl::FlatSet ops_; +}; + +namespace input_colocation_exemption_registration { + +class InputColocationExemptionRegistration { + public: + explicit InputColocationExemptionRegistration(const string& op) { + InputColocationExemptionRegistry::Global()->Register(op); + } +}; + +} // namespace input_colocation_exemption_registration + +#define REGISTER_INPUT_COLOCATION_EXEMPTION(op) \ + REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(__COUNTER__, op) + +#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ_HELPER(ctr, op) \ + REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op) + +#define REGISTER_INPUT_COLOCATION_EXEMPTION_UNIQ(ctr, op) \ + static input_colocation_exemption_registration:: \ + InputColocationExemptionRegistration \ + input_colocation_exemption_registration_fn_##ctr(op) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INPUT_COLOCATION_EXEMPTION_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/inspecting_placer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/inspecting_placer.h new file mode 100644 index 00000000..90df36c5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/inspecting_placer.h @@ -0,0 +1,96 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_ + +#include + +#include "absl/strings/str_join.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/port.h" + +namespace tensorflow { + +// TODO(iga): Convert this struct into a class to ensure invariants between +// device names, i.e. +// DeviceNameUtils::IsSpecification(resource_device_name, +// requested_device_name) +// PossibleDevices does not contain assigned_device_name because we don't +// assign devices to nested functions. +struct PossibleDevices { + // The same as Member::requested_device_name_ in colocation_graph.cc. 
+ DeviceNameUtils::ParsedName requested_device_name; + + // The same as Member::resource_device_name_ in colocation_graph.cc. + DeviceNameUtils::ParsedName resource_device_name; + + // A device type outside of this set will not be supported by some + // internal op. + PrioritizedDeviceTypeVector device_types; +}; + +// A struct for communicating constraints on devices that can +// be chosen for inputs and outputs of an op requiring deep placer inspection. +struct IOColocationGroups { + // input_groups[i] contains the group id that i'th input belongs to. + // List inputs are not supported. + std::vector input_groups; + // output_groups[i] contains the group id that i'th output belongs to. + // List inputs are not supported. + std::vector output_groups; + // group_devices[i] contains possible devices for group with id i. + std::vector group_devices; + + string DebugString() const; +}; + +class InspectingPlacer { + public: + // graph and device_set must not be null and must outlive this + // InspectingPlacer. default_device can be null. If not, must outlive this. + // TODO(iga): Add a "stack trace" to detect recursion and improve log + // messages. Currently, we will enter an infinite loop for recursive + // functions. + InspectingPlacer(const FunctionStack& stack, + const FunctionLibraryDefinition* flib_def, + const DeviceSet* device_set, const Device* default_device, + bool allow_soft_placement, bool log_device_placement); + + // `node` must be + // PlacerInspectionRequiredOpsChecker::IsPlacerInspectionRequired. + absl::Status ComputeIOColocationGroups(const Node& node, + IOColocationGroups* groups); + + private: + const FunctionStack stack_; + const FunctionLibraryDefinition& flib_def_; + const DeviceSet& device_set_; + const Device* default_device_; + const bool allow_soft_placement_; + const bool log_device_placement_; + + InspectingPlacer(const InspectingPlacer&) = delete; + void operator=(const InspectingPlacer&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INSPECTING_PLACER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/int32_fulltype.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/int32_fulltype.h new file mode 100644 index 00000000..1a55e0bc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/int32_fulltype.h @@ -0,0 +1,65 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_INT32_FULLTYPE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_INT32_FULLTYPE_H_ + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// An optimization (graph rewrite) pass to automatically set TFT_SHAPE_TENSOR +// full type information annotations for all int32 tensors, creating or +// modifying existing full type information as needed. 
This allows placement +// mechanisms using full type information to always place int32 on host. +class Int32FulltypePass { + public: + Int32FulltypePass() = default; + explicit Int32FulltypePass(string debug_location) + : debug_location_(debug_location) {} + + // For each node in this graph that outputs int32 tensors, set full + // type information such that the int32 tensors use TFT_SHAPE_TENSOR + // (or TFT_TENSOR if ints_on_device is true, which is only for single + // device functions including the functions with just one op used for + // eager execution). + // + // This method is not thread-safe. + absl::Status ProcessGraph(Graph* graph, bool ints_on_device); + + // Update full type information for int32 tensors that are in HOST_MEMORY + // to use TFT_SHAPE_TENSOR. The type_id of TENSOR_T is expected to be + // TFT_UNSET, TFT_TENSOR or TFT_SHAPE_TENSOR on input and will be updated + // to TFT_SHAPE_TENSOR on output for int32 tensors if it is not + // TFT_SHAPE_TENSOR already. For tensors that are not int32, if the input full + // type information is TFT_UNSET, it will only be updated if SET_ONLY_INT32 is + // false. Note that TENSOR_T is not the full type information for the outputs + // of a node, so it does have an outer TFT_PRODUCT. NODE and OUTPUT_IDX are + // optional and only used in an error message to say that the tensor is output + // OUTPUT_IDX of node NODE. + absl::Status Int32FullTypeForTensor(DataType dtype, FullTypeDef* tensor_t, + bool set_only_int32, Node* node = nullptr, + int output_idx = 0); + + private: + // Location of where annotations were added for debug messages. + string debug_location_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_INT32_FULLTYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h new file mode 100644 index 00000000..1bcdc001 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ISOLATE_PLACER_INSPECTION_REQUIRED_OPS_PASS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_ISOLATE_PLACER_INSPECTION_REQUIRED_OPS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { +// Adds Identities for each input/output of function-calling ops. 
+// +// For example, the following graph calling a function on inputs `a` and `b` +// and producing output `y` will be rewritted to include identities on all +// edges: +// +// a b +// | | +// v v +// f (PartitionedCallOp) +// | +// v +// y +// +// is transformed to +// +// a b +// | | +// a_f (Identity) a_f (Identity) +// | | +// v v +// f (PartitionedCallOp) +// | +// f_y (Identity) +// | +// v +// y +// +// This pass is currently needed to simplify correctly placing the nodes +// producing inputs for as well as consuming output from function-calling ops. +// +// This pass should also help to implement replacing PartitionedCallOp with +// component function calls (to avoid copying input/output tensors), if we get +// to it. +class IsolatePlacerInspectionRequiredOpsPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ISOLATE_PLACER_INSPECTION_REQUIRED_OPS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/kernel_benchmark_testlib.h new file mode 100644 index 00000000..fcab9a65 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/kernel_benchmark_testlib.h @@ -0,0 +1,86 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Device; +class FunctionLibraryRuntime; +class ProcessFunctionLibraryRuntime; +struct SessionOptions; +class DynamicDeviceMgr; + +namespace test { + +class Benchmark { + public: + // "device" must be either "cpu" or "gpu". Takes ownership of "g", + // "init", and one reference on "rendez" (if not null). + // + // old_benchmark_api: If true, the benchmark is running with older API + // * In the old API, the timer needs to be stopped/restarted + // by users. + // * In the new API, the timer starts automatically at the first + // iteration of the loop and stops after the last iteration. + // TODO(vyng) Remove this once we have migrated all code to newer API. 
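Editor's aside: a hedged usage sketch of the benchmark helper documented above, using the newer API where the timer is managed automatically. The BM_IdentityCPU name and the elided graph-building code are placeholders, and the usual TensorFlow test and graph headers are assumed to be available.

static void BM_IdentityCPU(::testing::benchmark::State& state) {
  Graph* g = new Graph(OpRegistry::Global());  // ownership passes to Benchmark
  // ... add a few test nodes to `g` here ...
  test::Benchmark("cpu", g, /*old_benchmark_api=*/false).Run(state);
}
BENCHMARK(BM_IdentityCPU);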
+ Benchmark(const string& device, Graph* g, + const SessionOptions* options = nullptr, Graph* init = nullptr, + Rendezvous* rendez = nullptr, const char* executor_type = "", + bool old_benchmark_api = false); + + Benchmark(const string& device, Graph* g, bool old_benchmark_api); + + ~Benchmark(); + + void Run(benchmark::State& state); + + void RunWithRendezvousArgs( + const std::vector>& inputs, + const std::vector& outputs, benchmark::State& state); + + private: + thread::ThreadPool* pool_ = nullptr; // Not owned. + Device* device_ = nullptr; // Not owned. + Rendezvous* rendez_ = nullptr; + std::unique_ptr device_mgr_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + FunctionLibraryRuntime* flr_; // Not owned. + std::unique_ptr exec_; + + Benchmark(const Benchmark&) = delete; + void operator=(const Benchmark&) = delete; +}; + +// Returns the rendezvous key associated with the given Send/Recv node. +string GetRendezvousKey(const Node* node); + +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/layout_pass_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/layout_pass_util.h new file mode 100644 index 00000000..909ff86f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/layout_pass_util.h @@ -0,0 +1,82 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LAYOUT_PASS_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LAYOUT_PASS_UTIL_H_ + +#if defined(INTEL_MKL) || defined(AMD_ZENDNN) + +#include +#include +#include + +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Temporarily wrapping these helper functions in the zendnn namespace +// to avoid crashing with similar functions in mkl_layout_pass.cc. +// TODO(penporn): Delete the functions in mkl_layout_pass and use the functions +// here after TF 2.12 branch cut. +namespace zendnn { + +// Is OpDef::ArgDef a list type? It could be N * T or list(type). +// Refer to opdef.proto for details of list type. +inline bool ArgIsList(const OpDef::ArgDef &arg); + +// Get length of a list in 'n' if 'arg' is of list type. Refer to +// description of ArgIsList for definition of list type. +inline int GetTensorListLength(const OpDef::ArgDef &arg, const Node *n); + +// Can op represented by node 'n' run on DEVICE_CPU? +// Op can run on CPU with ZenDNN if the runtime assigned device or the +// user requested device contains device CPU, or both are empty. +bool CanOpRunOnCPUDevice(const Node *n); + +// Get nodes that will feed a list of TF tensors to the new +// node that we are constructing. 
+// +// @input inputs - inputs to old node that we are using for constructing +// new inputs, +// @input input_idx - the index in the 'inputs' vector pointing to the +// current input that we have processed so far +// @output input_idx - index will be incremented by the number of nodes +// from 'inputs' that are processed +// @input list_length - The expected length of list of TF tensors +// @output output_nodes - the list of new nodes creating TF tensors +// +// @return None +void GetNodesProducingTFTensorList( + const gtl::InlinedVector, 4> &inputs, int *input_idx, + int list_length, std::vector *output_nodes); + +// Create new inputs by copying old inputs 'inputs' for the rewritten node +// in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly +// used in the context of rewrite for just operator name change in which +// inputs of old operator and new operator are same. +// +// Returns OkStatus() if setting up inputs is successful, otherwise +// returns appropriate status code. +Status CopyInputs( + const Node *old_node, + const gtl::InlinedVector, 4> &old_node_inputs, + NodeBuilder *nb); + +} // namespace zendnn +} // namespace tensorflow + +#endif // INTEL_MKL || AMD_ZENDNN +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LAYOUT_PASS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_device.h new file mode 100644 index 00000000..595d3b88 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_device.h @@ -0,0 +1,58 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +namespace test { +class Benchmark; +} +struct SessionOptions; + +// This class is shared by ThreadPoolDevice and GPUDevice and +// initializes a shared Eigen compute device used by both. This +// should eventually be removed once we refactor ThreadPoolDevice and +// GPUDevice into more 'process-wide' abstractions. 
+class LocalDevice : public Device { + public: + LocalDevice(const SessionOptions& options, + const DeviceAttributes& attributes); + ~LocalDevice() override; + + private: + static bool use_global_threadpool_; + + static void set_use_global_threadpool(bool use_global_threadpool) { + use_global_threadpool_ = use_global_threadpool; + } + + struct EigenThreadPoolInfo; + std::unique_ptr owned_tp_info_; + + friend class test::Benchmark; + + LocalDevice(const LocalDevice&) = delete; + void operator=(const LocalDevice&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_executor_params.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_executor_params.h new file mode 100644 index 00000000..a363f113 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_executor_params.h @@ -0,0 +1,57 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_EXECUTOR_PARAMS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_EXECUTOR_PARAMS_H_ + +#include +#include + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class Device; +class StepStatsCollector; +class SessionMetadata; +class FunctionLibraryRuntime; +class NodeProperties; +class OpKernel; + +// LocalExecutorParams provides arguments that will be shared by all invocations +// of an executor. We expect that different contexts would provide different +// implementations (e.g. local versus distributed). +struct LocalExecutorParams { + Device* device; + + const SessionMetadata* session_metadata = nullptr; + + // The library runtime support. + FunctionLibraryRuntime* function_library = nullptr; + + // create_kernel returns an instance of op kernel based on NodeDef. + // delete_kernel is called for every kernel used by the executor + // when the executor is deleted. + std::function&, + OpKernel**)> + create_kernel; + std::function delete_kernel; + + // Whether control flow nodes are allowed to be executed synchronously. + bool allow_control_flow_sync_execution = false; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_EXECUTOR_PARAMS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_session_selection.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_session_selection.h new file mode 100644 index 00000000..9e21c8d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/local_session_selection.h @@ -0,0 +1,33 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_SESSION_SELECTION_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_SESSION_SELECTION_H_ + +namespace tensorflow { + +// The TF Session implementations that can be used to run local sessions, i.e. +// when session_target in SessionOptions is empty. +enum class LocalSessionImpl { + kDirectSession, + kTfrtSession, +}; + +void SetDefaultLocalSessionImpl(LocalSessionImpl impl); +LocalSessionImpl GetDefaultLocalSessionImpl(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_SESSION_SELECTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_case_op.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_case_op.h new file mode 100644 index 00000000..65b56e51 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_case_op.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Graph; +class Node; + +// Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes. +absl::Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_CASE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_function_call_inline_policy.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_function_call_inline_policy.h new file mode 100644 index 00000000..6dc48f8e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_function_call_inline_policy.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_INLINE_POLICY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_INLINE_POLICY_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// LINT.IfChange +enum class FunctionCallInlinePolicy { + // Place input nodes on the same device as the corresponding caller input + // node. Do not specify any placement for all other nodes. + kDefaultPlacer, + + // Place all nodes on the same device as caller node. + kSingleDevicePlacer, + + // Place input nodes on the same device as the corresponding caller input + // node. Do not place output node. Place control nodes on the same device as + // caller node. For all function body nodes overrides job, replica and task + // parts of the device assignment to match function caller node. + kMultiDevicePlacer +}; +// LINT.ThenChange(inline_function_utils.h,\ +// ../../compiler/mlir/tensorflow/ir/tf_ops.cc) + +struct LowerFunctionalOpsConstants { + static constexpr const char* const kLowerUsingSwitchMergeAttr = + "_lower_using_switch_merge"; + static constexpr const char* const kLowerAsMultiDeviceFunctionAttr = + "_lower_as_multi_device_function"; +}; + +// Inliner policy used in common runtime's lower function call op. + +// Returns the function call inline policy to use for a given call. +FunctionCallInlinePolicy GetFunctionCallInlinePolicy(const Node* n); + +// Overload of GetFunctionCallInlinePolicy that doesn't require an op but only +// the features required. +FunctionCallInlinePolicy GetFunctionCallInlinePolicy( + bool is_partioned_call, bool has_lower_as_multi_device_function_attr); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_INLINE_POLICY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_function_call_op.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_function_call_op.h new file mode 100644 index 00000000..71d5e807 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_function_call_op.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryDefinition; +class Graph; +class Node; + +// Replaces function call node `n` with its function body. Uses +// InlineFunctionBody from `common_runtime/function.{h,cc}`. 
If function +// inlining is not possible or safe (see ValidateInlining), leaves the graph in +// unmodified state and returns OkStatus(); +absl::Status RewriteFunctionCallNode(Node* n, Graph* g, + const FunctionLibraryDefinition& flib_def, + bool keep_caller_fetchable); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTION_CALL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_functional_ops.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_functional_ops.h new file mode 100644 index 00000000..a849550a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_functional_ops.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_ + +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/inline_function_utils.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Rewrite functional ops into low level primitives: +// - If/While ops lowered into low level control flow primitives: Switch, Merge, +// Enter, Exit, NextIteration +// - Function calls inlined into the main graph +// +// IMPORTANT: Although SymbolicGradient is a function call, we currently do not +// lower it, because it has been deprecated for a while. +class LowerFunctionalOpsPass : public GraphOptimizationPass { + public: + LowerFunctionalOpsPass() = default; + + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + static constexpr const char* const kLowerUsingSwitchMergeAttr = + LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr; + static constexpr const char* const kLowerAsMultiDeviceFunctionAttr = + LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_FUNCTIONAL_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_if_op.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_if_op.h new file mode 100644 index 00000000..c125a197 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_if_op.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Graph; +class Node; + +// Replaces If node `n` with its lowered form that uses Switch and Merge nodes. +absl::Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_IF_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_while_op.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_while_op.h new file mode 100644 index 00000000..98095dee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/lower_while_op.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class Graph; +class Node; +class FunctionLibraryDefinition; + +// Replaces While node `n` with its lowered form that uses Enter, Exit, Switch, +// Merge, NextIteration and LoopCond nodes. +absl::Status RewriteWhileNode(Node* n, Graph* g, + const FunctionLibraryDefinition* flib_def, + bool keep_node_fetchable); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/memory_types.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/memory_types.h new file mode 100644 index 00000000..46a943c0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/memory_types.h @@ -0,0 +1,49 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ + +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Returns an error iff *g running on a single device of 'device_type' +// has memory type mismatch for any edge's source and destination. +absl::Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g); + +// Updates '*g' so that every edge's source and destination has +// compatible memory types by inserting proper HostSend/Recv and +// Send/HostRecv nodes. 'device_type' specifies the type of device on +// which '*g' is going to run on and that device has the name +// 'device_name'. +// +// Returns OK if '*g' is updated properly (ValidateMemoryTypes(g) must +// be OK). Otherwise, returns an error and '*g' may be in an +// invalidate state and the caller should discard it. +absl::Status EnsureMemoryTypes(const DeviceType& device_type, + const string& device_name, Graph* g); + +// Get the memory type for 'index'th output of node 'n' in graph 'g', when +// running on 'device_type'. +absl::Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, + const Node* n, int index, + MemoryType* memory_type); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_MEMORY_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/mkl_cpu_allocator.h new file mode 100644 index 00000000..54d60fdc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -0,0 +1,331 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// A simple CPU allocator that intercepts malloc/free calls from MKL library +// and redirects them to Tensorflow allocator + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MKL_CPU_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_MKL_CPU_ALLOCATOR_H_ + +#ifdef INTEL_MKL + +#include + +#include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/core/common_runtime/pool_allocator.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/numa.h" +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/onednn_env_vars.h" +#ifdef _WIN32 +typedef unsigned int uint; +#endif + +namespace tensorflow { + +static bool mkl_small_allocator_collect_stats = false; + +class MklSubAllocator : public BasicCPUAllocator { + public: + MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {} + ~MklSubAllocator() override {} +}; + +// CPU allocator that handles small-size allocations by calling +// suballocator directly. Mostly, it is just a wrapper around a suballocator +// (that calls malloc and free directly) with support for bookkeeping. +class MklSmallSizeAllocator : public Allocator { + public: + MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, + const string& name) + : sub_allocator_(sub_allocator), name_(name) { + stats_.bytes_limit = total_memory; + } + ~MklSmallSizeAllocator() override {} + + MklSmallSizeAllocator(const MklSmallSizeAllocator&) = delete; + void operator=(const MklSmallSizeAllocator&) = delete; + + inline string Name() override { return name_; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* ptr = port::AlignedMalloc(num_bytes, alignment); + if (mkl_small_allocator_collect_stats) IncrementStats(num_bytes); + return ptr; + } + + void DeallocateRaw(void* ptr) override { + if (ptr == nullptr) { + LOG(ERROR) << "tried to deallocate nullptr"; + return; + } + + if (mkl_small_allocator_collect_stats) { + const size_t alloc_size = port::MallocExtension_GetAllocatedSize(ptr); + DecrementStats(alloc_size); + } + port::AlignedFree(ptr); + } + + absl::optional GetStats() override { + mutex_lock l(mutex_); + return stats_; + } + + bool ClearStats() override { + mutex_lock l(mutex_); + stats_.num_allocs = 0; + stats_.peak_bytes_in_use = 0; + stats_.largest_alloc_size = 0; + stats_.bytes_in_use = 0; + stats_.bytes_limit = 0; + return true; + } + + private: + // Increment statistics for the allocator handling small allocations. + inline void IncrementStats(size_t alloc_size) TF_LOCKS_EXCLUDED(mutex_) { + mutex_lock l(mutex_); + ++stats_.num_allocs; + stats_.bytes_in_use += alloc_size; + stats_.peak_bytes_in_use = + std::max(stats_.peak_bytes_in_use, stats_.bytes_in_use); + stats_.largest_alloc_size = + std::max(alloc_size, static_cast(stats_.largest_alloc_size)); + } + + // Decrement statistics for the allocator handling small allocations. + inline void DecrementStats(size_t dealloc_size) TF_LOCKS_EXCLUDED(mutex_) { + mutex_lock l(mutex_); + stats_.bytes_in_use -= dealloc_size; + } + + SubAllocator* sub_allocator_; // Not owned by this class. + + // Mutex for protecting updates to map of allocations. 
+ mutable mutex mutex_; + + // Allocator name + string name_; + + // Allocator stats for small allocs + AllocatorStats stats_ TF_GUARDED_BY(mutex_); +}; + +/// CPU allocator for MKL that wraps BFC allocator and intercepts +/// and redirects memory allocation calls from MKL. +class MklCPUAllocator : public Allocator { + public: + // Constructor and other standard functions + + /// Environment variable that user can set to upper bound on memory allocation + static constexpr const char* kMaxLimitStr = "TF_MKL_ALLOC_MAX_BYTES"; + + /// Default upper limit on allocator size - 64GB + static constexpr size_t kDefaultMaxLimit = 64LL << 30; + + MklCPUAllocator() { TF_CHECK_OK(Initialize()); } + + ~MklCPUAllocator() override { + delete small_size_allocator_; + delete large_size_allocator_; + } + + Status Initialize() { + VLOG(2) << "MklCPUAllocator: In MklCPUAllocator"; + + // Set upper bound on memory allocation to physical RAM available on the + // CPU unless explicitly specified by user + uint64 max_mem_bytes = kDefaultMaxLimit; +#if defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE) + max_mem_bytes = + (uint64)sysconf(_SC_PHYS_PAGES) * (uint64)sysconf(_SC_PAGESIZE); +#endif + char* user_mem_bytes = getenv(kMaxLimitStr); + + if (user_mem_bytes != NULL) { + uint64 user_val = 0; + if (!strings::safe_strtou64(user_mem_bytes, &user_val)) { + return errors::InvalidArgument("Invalid memory limit (", user_mem_bytes, + ") specified for MKL allocator through ", + kMaxLimitStr); + } +#if defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE) + if (user_val > max_mem_bytes) { + LOG(WARNING) << "The user specified a memory limit " << kMaxLimitStr + << "=" << user_val + << " greater than available physical memory: " + << max_mem_bytes + << ". This could significantly reduce performance!"; + } +#endif + max_mem_bytes = user_val; + } + + VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes; + + sub_allocator_ = new MklSubAllocator(); + + // SubAllocator is owned by BFCAllocator, so we do not need to deallocate + // it in MklSmallSizeAllocator. + small_size_allocator_ = + new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); + + BFCAllocator::Options large_allocator_opts; + large_allocator_opts.allow_growth = kAllowGrowth; + large_size_allocator_ = + new BFCAllocator(absl::WrapUnique(sub_allocator_), max_mem_bytes, kName, + large_allocator_opts); + return OkStatus(); + } + + inline string Name() override { return kName; } + inline bool IsSmallSizeAllocation(const void* ptr) const + TF_LOCKS_EXCLUDED(mutex_) { + mutex_lock l(mutex_); + return large_allocations_map_.find(ptr) == large_allocations_map_.end(); + } + // AddLargeAllocMap and RemoveLargeAllocMap are always called with a lock held + inline void AddLargeAllocMap(void* ptr, size_t num_bytes) + TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + if (ptr != nullptr) { + std::pair map_val(ptr, num_bytes); + large_allocations_map_.insert(map_val); + } + } + inline void RemoveLargeAllocMap(void* ptr) + TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + auto map_iter = large_allocations_map_.find(ptr); + if (map_iter != large_allocations_map_.end()) { + large_allocations_map_.erase(map_iter); + } else { + LOG(ERROR) << "tried to deallocate invalid pointer"; + } + return; + } + + inline void* AllocateRaw(size_t alignment, size_t num_bytes) override { + // If the allocation size is less than threshold, call small allocator, + // otherwise call large-size allocator (BFC). 
We found that BFC allocator + // does not deliver good performance for small allocations when + // inter_op_parallelism_threads is high. + if (UseSystemAlloc() || num_bytes < kSmallAllocationsThreshold) { + return small_size_allocator_->AllocateRaw(alignment, num_bytes); + } else { + mutex_lock l(mutex_); + void* ptr = large_size_allocator_->AllocateRaw(alignment, num_bytes); + AddLargeAllocMap(ptr, num_bytes); + return ptr; + } + } + inline void DeallocateRaw(void* ptr) override { + // Check if ptr is for "small" allocation. If it is, then call Free + // directly. Otherwise, call BFC to handle free. + if (UseSystemAlloc() || IsSmallSizeAllocation(ptr)) { + small_size_allocator_->DeallocateRaw(ptr); + } else { + mutex_lock l(mutex_); + RemoveLargeAllocMap(ptr); + large_size_allocator_->DeallocateRaw(ptr); + } + } + absl::optional GetStats() override { + auto s_stats = small_size_allocator_->GetStats(); + auto l_stats = large_size_allocator_->GetStats(); + + // Combine statistics from small-size and large-size allocator. + mutex_lock l(mutex_); + stats_.num_allocs = l_stats->num_allocs + s_stats->num_allocs; + stats_.bytes_in_use = l_stats->bytes_in_use + s_stats->bytes_in_use; + stats_.peak_bytes_in_use = + l_stats->peak_bytes_in_use + s_stats->peak_bytes_in_use; + + // Since small-size allocations go to MklSmallSizeAllocator, + // max_alloc_size from large_size_allocator would be the maximum + // size allocated by MklCPUAllocator. + stats_.largest_alloc_size = l_stats->largest_alloc_size; + stats_.bytes_limit = std::max(s_stats->bytes_limit, l_stats->bytes_limit); + return stats_; + } + + bool ClearStats() override { + bool stats_cleared = small_size_allocator_->ClearStats(); + stats_cleared &= large_size_allocator_->ClearStats(); + return stats_cleared; + } + + private: + // Hooks provided by this allocator for memory allocation routines from MKL + static inline void* MallocHook(size_t size) { + VLOG(3) << "MklCPUAllocator: In MallocHook"; + return cpu_allocator()->AllocateRaw(kAlignment, size); + } + + static inline void FreeHook(void* ptr) { + VLOG(3) << "MklCPUAllocator: In FreeHook"; + cpu_allocator()->DeallocateRaw(ptr); + } + + static inline void* CallocHook(size_t num, size_t size) { + Status s = Status(absl::StatusCode::kUnimplemented, + "Unimplemented case for hooking MKL function."); + TF_CHECK_OK(s); // way to assert with an error message + return nullptr; // return a value and make static code analyzers happy + } + + static inline void* ReallocHook(void* ptr, size_t size) { + Status s = Status(absl::StatusCode::kUnimplemented, + "Unimplemented case for hooking MKL function."); + TF_CHECK_OK(s); // way to assert with an error message + return nullptr; // return a value and make static code analyzers happy + } + + // Do we allow growth in BFC Allocator + static const bool kAllowGrowth = true; + + // Name + static constexpr const char* kName = "mklcpu"; + + // The alignment that we need for the allocations + static constexpr const size_t kAlignment = 64; + + Allocator* large_size_allocator_; // owned by this class + MklSmallSizeAllocator* small_size_allocator_; // owned by this class. + + SubAllocator* sub_allocator_; // not owned by this class + mutable mutex mutex_; + AllocatorStats stats_ TF_GUARDED_BY(mutex_); + + // Hash map to keep track of "BFC" allocations + // We do not use BFC allocator for small allocations. + std::unordered_map large_allocations_map_ + TF_GUARDED_BY(mutex_); + + // Size in bytes that defines the upper-bound for "small" allocations. 
+  // Any allocation below this threshold is "small" allocation.
+  static constexpr const size_t kSmallAllocationsThreshold = 262144;
+
+  // Prevent copying and assignment
+  MklCPUAllocator(const MklCPUAllocator&) = delete;
+  void operator=(const MklCPUAllocator&) = delete;
+};
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_MKL_CPU_ALLOCATOR_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/mkl_layout_pass.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/mkl_layout_pass.h
new file mode 100644
index 00000000..6b5c586c
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/mkl_layout_pass.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A graph pass that rewrites graph for propagating MKL layout as a tensor
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_MKL_LAYOUT_PASS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_MKL_LAYOUT_PASS_H_
+
+#ifdef INTEL_MKL
+
+#include <sys/types.h>
+#include <memory>
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+// Interface to invoke the pass for unit test
+//
+// Returns true if and only if 'g' is mutated.
+extern bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g);
+}  // namespace tensorflow
+
+#endif
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_MKL_LAYOUT_PASS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.h
new file mode 100644
index 00000000..baeebef6
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/example_plugin.h
@@ -0,0 +1,49 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_EXAMPLE_PLUGIN_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_EXAMPLE_PLUGIN_H_
+
+#include "tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h"
+#include "tfrt/host_context/host_context.h"  // from @tf_runtime
+
+// This is an example plugin that implements several basic APIs for events.
+// This is for testing only.
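+//
+// Usage sketch (illustrative only; `host` and `status` are assumed to be a
+// tfrt::HostContext* and a TF_Status* owned by the caller). A test can drive
+// this plugin through the TFNPD event API declared in plugin_c_api.h roughly
+// as follows:
+//
+//   const TFNPD_Api* api = GetExamplePluginApi();
+//   TFNPD_DeviceEvent* event =
+//       example_plugin::CreateDeviceEventAndSetAvailable(host);
+//   api->TFNPD_DeviceEventAwait(event, status);  // returns once the event is ready
+//   bool ready = api->TFNPD_DeviceEventIsReady(event);
+//   api->TFNPD_DeviceEventDelete(event);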
+ +#ifdef __cplusplus +extern "C" { +#endif + +struct TFNPD_DeviceEvent { + tfrt::RCReference event; +}; + +// Does not pass ownership of returned TFNPD_Api* to caller. +const TFNPD_Api* GetExamplePluginApi(); + +#ifdef __cplusplus +} +#endif + +namespace example_plugin { + +// A helper method that generates a TFNPD_DeviceEvent, and makes the event +// available (or ready) in two seconds. +TFNPD_DeviceEvent* CreateDeviceEventAndSetAvailable(tfrt::HostContext* host, + bool set_as_error = false); + +} // namespace example_plugin + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_EXAMPLE_PLUGIN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h new file mode 100644 index 00000000..21b17104 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/outside_compilation_params.h @@ -0,0 +1,37 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_OUTSIDE_COMPILATION_PARAMS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_OUTSIDE_COMPILATION_PARAMS_H_ + +#include "xla/stream_executor/tpu/c_api_decl.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct SE_OutsideCompilationParams { + char* device_name; + char* rendezvous_key; + TF_RendezvousThunk* rendezvous; + TpuSerializedProto host_transfers; +}; + +#ifdef __cplusplus +} +#endif + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_OUTSIDE_COMPILATION_PARAMS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h new file mode 100644 index 00000000..e44e5f3f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h @@ -0,0 +1,176 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_C_API_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_C_API_H_ + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" +#include "xla/c/c_api_decl.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/stream_executor/tpu/c_api_decl.h" + +#define TFNPD_MAJOR 0 +#define TFNPD_MINOR 0 +#define TFNPD_PATCH 1 + +// Experimental C API for TensorFlow Next Pluggable device (TFNPD). + +#ifdef __cplusplus +extern "C" { +#endif + +// ---------------------------- Event ---------------------------------------- +typedef struct TFNPD_DeviceEvent TFNPD_DeviceEvent; + +typedef TFNPD_DeviceEvent* TFNPD_NewDeviceEvent(); + +typedef void TFNPD_DeviceEventAwait(TFNPD_DeviceEvent* event, + TF_Status* status); + +typedef bool TFNPD_DeviceEventIsReady(TFNPD_DeviceEvent* event); + +// Invokes the callback after event becomes ready. +typedef void TFNPD_DeviceEventAndThen(TFNPD_DeviceEvent* event, + void (*callback)(void*), + void* callback_arg); + +typedef void TFNPD_DeviceEventDelete(TFNPD_DeviceEvent* event); + +// -------------------------- Allocator -------------------------------------- +typedef struct TFNPD_DeviceAllocator TFNPD_DeviceAllocator; + +typedef TFNPD_DeviceAllocator* TFNPD_DeviceAllocatorCreate(int device_ordinal); + +typedef void* TFNPD_DeviceAllocateRaw(TFNPD_DeviceAllocator* allocator, + size_t alignment, size_t num_bytes); + +typedef void TFNPD_DeviceDeallocateRaw(TFNPD_DeviceAllocator* allocator, + void* ptr); + +typedef TF_StringView TFNPD_DeviceAllocatorName( + TFNPD_DeviceAllocator* allocator); + +typedef bool TFNPD_DeviceAllocatorAllocatesOpaqueHandle( + TFNPD_DeviceAllocator* allocator); + +typedef void TFNPD_DeviceAllocatorDelete(TFNPD_DeviceAllocator* allocator); + +// ------------------------ Tensor Transfers --------------------------------- +typedef struct TFNPD_DeviceContext TFNPD_DeviceContext; + +// TODO(chuanhao): use an option struct to create context. Plugin can define the +// option so that we support more features in the DeviceContext, e.g. +// shape_determination_fns. +typedef TFNPD_DeviceContext* TFNPD_DeviceContextCreate(int device_ordinal); + +typedef TFNPD_DeviceEvent* TFNPD_DeviceTensorToHostTensor( + TFNPD_DeviceContext* device_context, const TF_Tensor* device_tensor, + TF_Tensor* cpu_tensor, TF_Status* status); + +typedef TFNPD_DeviceEvent* TFNPD_HostTensorToDeviceTensor( + TFNPD_DeviceContext* device_context, const TF_Tensor* cpu_tensor, + TF_Tensor* device_tensor, TF_Status* status); + +typedef TFNPD_DeviceEvent* TFNPD_SameDeviceTensorCopy( + TFNPD_DeviceContext* context); + +typedef PJRT_Buffer* TFNPD_SameDevicePjRtBufferCopy(PJRT_Buffer* src_buffer, + PJRT_Client* c_client, + TF_Status* status); + +typedef void TFNPD_DeviceContextDelete(TFNPD_DeviceContext* context); + +// ------------------------------ TF2XLA ------------------------------------- +// TODO(b/254484247): either separate XLA_Shape to its own file, or use PJRT +// solution when it is ready. 
+typedef void TFNPD_XlaShapeToDeviceShapeRepresentation( + XLA_Shape* serialized_xla_shape, int data_type, bool use_fast_memory, + XLA_LayoutPreference layout_preference, XLA_Shape* serialized_device_shape, + TF_Status* tf_status); + +// ----------------------- Plugin System related ----------------------------- +typedef int32_t TFNPD_GetDeviceCount(TF_Status* status); + +// Initialize any per-device states or resources that are internal to plugin. +typedef void TFNPD_InitPluginInternalDeviceStates(TF_Status* status); + +// --------------------------- C API access ------------------------------------ +#define TFNPD_API_STRUCT_FN(fn_type) fn_type* fn_type + +typedef struct { + size_t struct_size; + void* priv; + + TFNPD_API_STRUCT_FN(TFNPD_NewDeviceEvent); + TFNPD_API_STRUCT_FN(TFNPD_DeviceEventAwait); + TFNPD_API_STRUCT_FN(TFNPD_DeviceEventIsReady); + TFNPD_API_STRUCT_FN(TFNPD_DeviceEventAndThen); + TFNPD_API_STRUCT_FN(TFNPD_DeviceEventDelete); + + TFNPD_API_STRUCT_FN(TFNPD_DeviceAllocatorCreate); + TFNPD_API_STRUCT_FN(TFNPD_DeviceAllocateRaw); + TFNPD_API_STRUCT_FN(TFNPD_DeviceDeallocateRaw); + TFNPD_API_STRUCT_FN(TFNPD_DeviceAllocatorName); + TFNPD_API_STRUCT_FN(TFNPD_DeviceAllocatorAllocatesOpaqueHandle); + TFNPD_API_STRUCT_FN(TFNPD_DeviceAllocatorDelete); + + TFNPD_API_STRUCT_FN(TFNPD_DeviceContextCreate); + TFNPD_API_STRUCT_FN(TFNPD_DeviceContextDelete); + + // TODO(chuanhao): Deprecate the tensor transfer C APIs when PJRT API + // development is ready since we plan to adopt PJRT as Device API. + TFNPD_API_STRUCT_FN(TFNPD_DeviceTensorToHostTensor); + TFNPD_API_STRUCT_FN(TFNPD_HostTensorToDeviceTensor); + TFNPD_API_STRUCT_FN(TFNPD_SameDeviceTensorCopy); + TFNPD_API_STRUCT_FN(TFNPD_SameDevicePjRtBufferCopy); + + TFNPD_API_STRUCT_FN(TFNPD_XlaShapeToDeviceShapeRepresentation); + + TFNPD_API_STRUCT_FN(TFNPD_GetDeviceCount); + TFNPD_API_STRUCT_FN(TFNPD_InitPluginInternalDeviceStates); +} TFNPD_Api; + +const size_t TFNPD_Api_STRUCT_SIZE = + TF_OFFSET_OF_END(TFNPD_Api, TFNPD_InitPluginInternalDeviceStates); + +#undef TFNPD_API_STRUCT_FN + +typedef struct TFNPD_PluginParams { + size_t struct_size; + void* ext; // reserved for future use + + const char* device_type; // output, set by plugin + const char* compilation_device_name; // output, set by plugin + int32_t priority; // output, set by plugin + // Certain devices may set this one to false to avoid using device copy logic + // implemented for legacy PluggableDevice. + bool is_pluggable_device; // output, set by plugin + bool use_pjrt_on_demand_compile; // output, set by plugin +} TFNPD_PluginParams; +const size_t TFNPD_PLUGIN_PARAMS_STRUCT_SIZE = + TF_OFFSET_OF_END(TFNPD_PluginParams, is_pluggable_device); +const TFNPD_Api* TFNPD_InitPlugin(TFNPD_PluginParams* params, + TF_Status* tf_status); + +#if defined(__cplusplus) +} // extern "C" +#endif // defined(__cplusplus) + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_C_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api.h new file mode 100644 index 00000000..507faaf4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api.h @@ -0,0 +1,87 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_Tensor TF_Tensor; +typedef struct TSL_Status TF_Status; + +// Structs for TF_StatusCallback. + +typedef void (*TF_StatusCallback_Function)(void*, TF_Status*); +typedef struct TF_StatusCallback { + void* context; + TF_StatusCallback_Function callback; +} TF_StatusCallback; + +// Structs for CopyCPUTensorToDevice API. +typedef struct TF_DeviceContext_CopyCPUTensorToDevice_Params { + TF_Tensor* cpu_tensor; + // API for `Device` is not available. + // Device* device; + TF_Tensor* device_tensor; // out + TF_StatusCallback* done; + bool sync_dst_compute; +} TF_DeviceContext_CopyCPUTensorToDevice_Params; + +typedef void (*TF_DeviceContext_CopyCPUTensorToDevice_Function)( + void*, TF_DeviceContext_CopyCPUTensorToDevice_Params*); + +// Structs for CopyDeviceTensorToCPU API. +typedef struct TF_DeviceContext_CopyDeviceTensorToCPU_Params { + TF_Tensor* device_tensor; + char* tensor_name; + // API for `Device` is not available. + // Device* device; + uint32_t tensor_name_len; + TF_Tensor* cpu_tensor; // out + TF_StatusCallback* done; +} TF_DeviceContext_CopyDeviceTensorToCPU_Params; + +typedef void (*TF_DeviceContext_CopyDeviceTensorToCPU_Function)( + void*, TF_DeviceContext_CopyDeviceTensorToCPU_Params*); + +// Structs for CopyTensorInSameDevice API. +typedef struct TF_DeviceContext_CopyTensorInSameDevice_Params { + TF_Tensor* input_tensor; + // API for `Device` is not available. + // Device* device; + TF_Tensor* output_tensor; // out + TF_StatusCallback* done; +} TF_DeviceContext_CopyTensorInSameDevice_Params; + +typedef void (*TF_DeviceContext_CopyTensorInSameDevice_Function)( + void*, TF_DeviceContext_CopyTensorInSameDevice_Params*); + +/* DeviceContext */ +typedef struct TF_DeviceContext { + void* device_context; + TF_DeviceContext_CopyCPUTensorToDevice_Function cpu_to_device_func; + TF_DeviceContext_CopyDeviceTensorToCPU_Function device_to_cpu_func; + TF_DeviceContext_CopyTensorInSameDevice_Function same_device_func; +} TF_DeviceContext; + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_helper.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_helper.h new file mode 100644 index 00000000..c037f48a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_helper.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_HELPER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_HELPER_H_ + +#include + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/rendezvous.h" + +namespace tensorflow { + +TF_DeviceContext* DeviceContext_ToC(DeviceContext* device_context); + +void DeviceContext_Destroy(TF_DeviceContext* c_device_context); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_internal.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_internal.h new file mode 100644 index 00000000..52bf1ead --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api_internal.h @@ -0,0 +1,31 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_INTERNAL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_INTERNAL_H_ + +#include + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_device_context_c_api.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/rendezvous.h" + +namespace tensorflow { + +DeviceContext* DeviceContext_FromC(TF_DeviceContext* c_device_context); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_DEVICE_CONTEXT_C_API_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h new file mode 100644 index 00000000..706efe42 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h @@ -0,0 +1,102 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_H_ + +#include + +#include "tensorflow/c/c_api_macros.h" // IWYU pragma: export +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct TF_DeviceContext TF_DeviceContext; + +typedef struct TFDevice_AllocatorAttributes { + uint32_t value; + int32_t scope_id; +} TFDevice_AllocatorAttributes; + +typedef struct TFE_CancellationManager TFE_CancellationManager; + +typedef struct TF_RendezvousArgsStruct { + TF_DeviceContext* device_context; + TFDevice_AllocatorAttributes alloc_attrs; + TFE_CancellationManager* cancellation_manager; +} TF_RendezvousArgsStruct; + +typedef struct TF_RendezvousParsedKey { + char* full_key; + uint32_t full_key_size; +} TF_RendezvousParsedKey; + +typedef struct TF_RendezvousSend_Params { + const TF_RendezvousParsedKey* key; + const TF_RendezvousArgsStruct* args; + TF_Tensor* tensor; + bool is_dead; + + TF_Status* status; // out +} TF_RendezvousSend_Params; + +typedef void (*TF_RendezvousSend_Function)(void*, TF_RendezvousSend_Params*); + +typedef struct TF_RendezvousDoneCallback_Params { + void* context; + const TF_Status* status; + // TODO: Pass args through. 
+ // const TF_RendezvousArgsStruct* sender_args; + // const TF_RendezvousArgsStruct* recver_args; + const TF_Tensor* tensor; + bool is_dead; +} TF_RendezvousDoneCallback_Params; + +typedef void (*TF_RendezvousDoneCallback_Function)( + void*, TF_RendezvousDoneCallback_Params*); + +typedef struct TF_RendezvousDoneCallbackImpl { + void* context; + TF_RendezvousDoneCallback_Function callback; +} TF_RendezvousDoneCallbackImpl; + +typedef struct TF_RendezvousAsyncRecv_Params { + void* context; + const TF_RendezvousParsedKey* key; + const TF_RendezvousArgsStruct* args; + TF_RendezvousDoneCallbackImpl on_done; +} TF_RendezvousAsyncRecv_Params; + +typedef void (*TF_RendezvousAsyncRecv_Function)(void*, + TF_RendezvousAsyncRecv_Params*); + +typedef void (*TF_RendezvousStartAbort_Function)(void* context, + const TF_Status*); + +typedef struct TF_RendezvousThunk { + void* rendezvous; + TF_RendezvousSend_Function send_func; + TF_RendezvousAsyncRecv_Function async_recv_func; + TF_RendezvousStartAbort_Function start_abort_func; +} TF_RendezvousThunk; + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h new file mode 100644 index 00000000..9e9cbccd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_defn.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_DEFN_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_DEFN_H_ + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" + +struct TF_CancellationManager { + tensorflow::CancellationManager* cancellation_manager; // not owned +}; + +struct TF_TensorWrapper { + tensorflow::Tensor tensor; +}; + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_DEFN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.h new file mode 100644 index 00000000..e55b8583 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.h @@ -0,0 +1,31 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_HELPER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_HELPER_H_ + +#include + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#include "tensorflow/core/framework/rendezvous.h" + +namespace tensorflow { + +std::unique_ptr FromC( + const TF_RendezvousThunk* thunk); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.h new file mode 100644 index 00000000..30cc3ec0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.h @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_INTERNAL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_INTERNAL_H_ + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api.h" +#include "tensorflow/core/framework/rendezvous.h" + +namespace tensorflow { + +TF_RendezvousThunk* ToC(tensorflow::RendezvousInterface* rendezvous); +void Destroy(TF_RendezvousThunk* thunk); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_RENDEZVOUS_C_API_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_tensor_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_tensor_utils.h new file mode 100644 index 00000000..f1a35ffc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c/tf_tensor_utils.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_TENSOR_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_TENSOR_UTILS_H_ + +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +void CopyTF_TensorToTensor(const TF_Tensor* src, Tensor* dst); + +TF_Tensor* CopyTensorToTF_Tensor(const Tensor& src); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_TF_TENSOR_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h new file mode 100644 index 00000000..8d9d3268 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h @@ -0,0 +1,60 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_COORDINATION_SERVICE_AGENT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_COORDINATION_SERVICE_AGENT_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "tensorflow/c/experimental/next_pluggable_device/c_api.h" +#include "tensorflow/c/kernels_experimental.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +class CPluginCoordinationServiceAgent : public PluginCoordinationServiceAgent { + public: + explicit CPluginCoordinationServiceAgent(void* agent) + : agent_(reinterpret_cast(agent)) {} + + bool IsInitialized() const override { + if (agent_ == nullptr) return false; + return TF_CoordinationServiceIsInitialized(agent_); + } + + absl::Status InsertKeyValue(std::string_view key, + std::string_view value) override; + + absl::StatusOr GetKeyValue(std::string_view key) override; + absl::StatusOr GetKeyValue(std::string_view key, + absl::Duration timeout) override; + absl::StatusOr TryGetKeyValue(std::string_view key) override; + + absl::Status DeleteKeyValue(std::string_view key) override; + + private: + TF_CoordinationServiceAgent* agent_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_COORDINATION_SERVICE_AGENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h new file mode 100644 index 00000000..fa7206c2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h @@ -0,0 +1,177 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_OP_KERNEL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_OP_KERNEL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/c/kernels.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { + +class CPluginOpKernelConstruction : public PluginOpKernelConstruction { + public: + explicit CPluginOpKernelConstruction(void* ctx) + : ctx_(reinterpret_cast(ctx)) {} + + absl::Status GetBoolAttr(std::string_view attr_name, + bool* value) const override; + absl::Status GetInt32Attr(std::string_view attr_name, + int* value) const override; + absl::Status GetInt32AttrList(std::string_view attr_name, + std::vector* value) const override; + absl::Status GetInt64Attr(std::string_view attr_name, + int64_t* value) const override; + absl::Status GetStringAttr(std::string_view attr_name, + std::string* value) const override; + absl::Status GetFunctionAttr(std::string_view attr_name, + NameAttrList* function) const override; + + void CtxFailure(const absl::Status& status) override; + void CtxFailure(const char* file, int line, + const absl::Status& status) override; + + void* GetContext() const override { return ctx_; } + + private: + TF_OpKernelConstruction* ctx_; // not owned. 
+}; + +class CPluginOpKernelContext : public PluginOpKernelContext { + public: + explicit CPluginOpKernelContext(void* ctx) + : ctx_(reinterpret_cast(ctx)) {} + + std::string_view GetResourceMgrDefaultContainerName() override; + + absl::Status LookupOrCreateResource(std::string_view container_name, + std::string_view plugin_resource_name, + void** result_plugin_resource, + void* (*create_func)(void*), + void* create_func_args, + void (*delete_func)(void*)) override; + + std::unique_ptr + GetPluginCoordinationServiceAgent() const override; + + absl::Status CreatePluginVariable(int index, + PluginVariable** variable) const override; + + absl::Status AllocateTempForPluginVariable(PluginVariable* variable) override; + + int NumInputs() const override { return TF_NumInputs(ctx_); } + + absl::Status GetInput(int index, const Tensor** tensor) const override; + + absl::Status GetInput(const char* name, const Tensor** tensor) const override; + + absl::Status GetInputRange(std::string_view name, + std::pair* range) const override; + + DataType GetInputDataType(int index) const override; + + std::string_view GetOpKernelRequestedInput(int index) const override; + + std::string_view GetOpKernelName() const override; + + uint64_t GetFrameId() const override { return TF_GetFrameId(ctx_); } + + int64_t GetIterId() const override { return TF_GetIterId(ctx_); } + + int64_t GetStepId() const override { return TF_GetStepId(ctx_); } + + int GetDeviceId() const override { return TF_GetDeviceId(ctx_); } + + std::string_view GetDeviceName() const override; + + std::string GetSessionName() const override { + // TODO(haoyuzhang): Implement with ctx_->session_metadata() if needed. + return ""; + } + + absl::Status GetConfigProto(const ConfigProto** config_proto) const override; + + // Note: this function is only meant to clear up `config_proto` created by the + // above `CPluginOpKernelContext::GetConfigProto()`. + void MaybeDeleteConfigProto(const ConfigProto* config_proto) const override { + delete config_proto; + } + + absl::Status GetFunctionLibraryDefinition( + const FunctionLibraryDefinition** flib_def) const override; + + // Note: this function is only meant to clear up `flib_def` created by the + // above `CPluginOpKernelContext::GetFunctionLibraryDefinition()`. + void MaybeDeleteFunctionLibraryDefinition( + const FunctionLibraryDefinition* flib_def) const override { + delete flib_def; + } + + absl::Status GetResourceHandle(int index, + const ResourceHandle** handle) const override; + + // Note: this function is only meant to clear up `handle` created by the above + // `CPluginOpKernelContext::GetResourceHandle()`. + void MaybeDeleteResourceHandle(const ResourceHandle* handle) const override { + delete handle; + } + + int GetGraphDefVersion() const override { + return TF_GetGraphDefVersion(ctx_); + } + + absl::Status AllocateOutput(int index, const TensorShape& shape, + Tensor** out) override; + + absl::Status SetOutput(int index, const Tensor& tensor) override; + + void CtxFailure(const absl::Status& status) override; + void CtxFailure(const char* file, int line, + const absl::Status& status) override; + + void* GetContext() const override { return ctx_; } + + private: + mutable mutex mu_; + + // A cache for tensors obtained from the ctx_. This is needed to extend the + // lifetime of the c++ tensorflow::Tensor created from `TF_TensorToTensor`. + // Use std::deque here to make sure elements in the container are pointer + // stable. 
+ // "insertion and deletion at either end of a deque never invalidates pointers + // or references to the rest of the elements." + mutable std::deque obtained_tensors_ TF_GUARDED_BY(mu_); + TF_OpKernelContext* ctx_; // not owned. +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h new file mode 100644 index 00000000..157c5b45 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_variable.h @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_VARIABLE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_VARIABLE_H_ + +#include "absl/status/status.h" +#include "tensorflow/c/experimental/next_pluggable_device/c_api.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +class CPluginOpKernelContext; + +class CPluginVariable : public PluginVariable { + public: + ~CPluginVariable() override; + explicit CPluginVariable(TF_VariableInfo* var_info) : var_info_(var_info) {} + + absl::Status GetTensor(const Tensor** result_tensor) override; + + absl::Status GetMutableTensor(Tensor** result_tensor) override; + + TF_VariableInfo* GetVariableInfo() { return var_info_; } + + friend class CPluginOpKernelContext; + + private: + absl::Status GetTensorInternal(); + + TF_VariableInfo* var_info_; // Owned. Cleared by destructor. + bool tensor_obtained_ = false; + tensorflow::Tensor tensor_; // Tensor obtained from variable. +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_C_PLUGIN_VARIABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h new file mode 100644 index 00000000..930efed4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h @@ -0,0 +1,69 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_COORDINATION_SERVICE_AGENT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_COORDINATION_SERVICE_AGENT_H_
+
+#include <string>
+#include <string_view>
+
+#include "absl/time/time.h"
+#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h"
+#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/statusor.h"
+
+namespace tensorflow {
+
+class DirectPluginCoordinationServiceAgent
+    : public PluginCoordinationServiceAgent {
+ public:
+  explicit DirectPluginCoordinationServiceAgent(void* agent)
+      : agent_(reinterpret_cast<tsl::CoordinationServiceAgent*>(agent)) {}
+
+  bool IsInitialized() const override {
+    if (agent_ == nullptr) return false;
+    return agent_->IsInitialized();
+  }
+
+  absl::Status InsertKeyValue(std::string_view key,
+                              std::string_view value) override {
+    return agent_->InsertKeyValue(key, value);
+  }
+
+  absl::StatusOr<std::string> GetKeyValue(std::string_view key) override {
+    return agent_->GetKeyValue(key);
+  }
+
+  absl::StatusOr<std::string> GetKeyValue(std::string_view key,
+                                          absl::Duration timeout) override {
+    return agent_->GetKeyValue(key, timeout);
+  }
+
+  absl::StatusOr<std::string> TryGetKeyValue(std::string_view key) override {
+    return agent_->TryGetKeyValue(key);
+  }
+
+  absl::Status DeleteKeyValue(std::string_view key) override {
+    return agent_->DeleteKeyValue(key);
+  }
+
+ private:
+  tsl::CoordinationServiceAgent* agent_;  // Not owned.
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_COORDINATION_SERVICE_AGENT_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h
new file mode 100644
index 00000000..3df3543b
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h
@@ -0,0 +1,196 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_OP_KERNEL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_OP_KERNEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +class DirectPluginOpKernelConstruction : public PluginOpKernelConstruction { + public: + explicit DirectPluginOpKernelConstruction(void* ctx) + : ctx_(reinterpret_cast(ctx)) {} + + absl::Status GetBoolAttr(std::string_view attr_name, + bool* value) const override; + absl::Status GetInt32Attr(std::string_view attr_name, + int* value) const override; + absl::Status GetInt32AttrList(std::string_view attr_name, + std::vector* value) const override; + absl::Status GetInt64Attr(std::string_view attr_name, + int64_t* value) const override; + absl::Status GetStringAttr(std::string_view attr_name, + std::string* value) const override; + absl::Status GetFunctionAttr(std::string_view attr_name, + NameAttrList* function) const override; + + void CtxFailure(const absl::Status& status) override { + ctx_->CtxFailure(status); + } + + void CtxFailure(const char* file, int line, + const absl::Status& status) override { + ctx_->CtxFailure(file, line, status); + } + + void* GetContext() const override { return ctx_; } + + private: + OpKernelConstruction* ctx_; // not owned. 
+}; + +class DirectPluginOpKernelContext : public PluginOpKernelContext { + public: + explicit DirectPluginOpKernelContext(OpKernelContext* ctx) : ctx_(ctx) {} + + std::string_view GetResourceMgrDefaultContainerName() override; + + absl::Status LookupOrCreateResource(std::string_view container_name, + std::string_view plugin_resource_name, + void** result_plugin_resource, + void* (*create_func)(void*), + void* create_func_args, + void (*delete_func)(void*)) override; + + std::unique_ptr + GetPluginCoordinationServiceAgent() const override { + return CreatePluginCoordinationServiceAgent( + ctx_->coordination_service_agent()); + } + + absl::Status CreatePluginVariable(int index, + PluginVariable** variable) const override; + + absl::Status AllocateTempForPluginVariable(PluginVariable* variable) override; + + int NumInputs() const override { return ctx_->num_inputs(); } + + absl::Status GetInput(int index, const Tensor** tensor) const override; + + absl::Status GetInput(const char* name, const Tensor** tensor) const override; + + absl::Status GetInputRange(std::string_view name, + std::pair* range) const override; + + DataType GetInputDataType(int index) const override { + return ctx_->input_dtype(index); + } + + std::string_view GetOpKernelRequestedInput(int index) const override { + return ctx_->op_kernel().requested_input(index); + } + + std::string_view GetOpKernelName() const override { + return ctx_->op_kernel().name(); + } + + uint64_t GetFrameId() const override { return ctx_->frame_iter().frame_id; } + + int64_t GetIterId() const override { return ctx_->frame_iter().iter_id; } + + int64_t GetStepId() const override { return ctx_->step_id(); } + + int GetDeviceId() const override; + + std::string_view GetDeviceName() const override; + + std::string GetSessionName() const override { + return ctx_->session_metadata() ? ctx_->session_metadata()->name() : ""; + } + + absl::Status GetConfigProto(const ConfigProto** config_proto) const override { + *config_proto = ctx_->function_library()->config_proto(); + return absl::OkStatus(); + } + + void MaybeDeleteConfigProto(const ConfigProto* config_proto) const override { + // We don't need to specifically delete ConfigProto since it is obtained + // from FunctionLibraryRuntime in `ctx_`. + } + + absl::Status GetFunctionLibraryDefinition( + const FunctionLibraryDefinition** flib_def) const override { + *flib_def = ctx_->function_library()->GetFunctionLibraryDefinition(); + return absl::OkStatus(); + } + + void MaybeDeleteFunctionLibraryDefinition( + const FunctionLibraryDefinition* flib_def) const override { + // We don't need to specifically delete FunctionLibraryDefinition since it + // is obtained from FunctionLibraryRuntime in `ctx_`. + } + + absl::Status GetResourceHandle(int index, + const ResourceHandle** handle) const override { + *handle = &HandleFromInput(ctx_, index); + return absl::OkStatus(); + } + + void MaybeDeleteResourceHandle(const ResourceHandle* handle) const override { + // We don't need to specifically delete ResourceHandle since it is obtained + // from `ctx_`. 
+ } + + int GetGraphDefVersion() const override { + return ctx_->function_library()->graph_def_version(); + } + + absl::Status AllocateOutput(int index, const TensorShape& shape, + Tensor** out) override { + return ctx_->allocate_output(index, shape, out); + } + + absl::Status SetOutput(int index, const Tensor& tensor) override { + ctx_->set_output(index, tensor); + return absl::OkStatus(); + } + + void CtxFailure(const absl::Status& status) override { + ctx_->CtxFailure(status); + } + + void CtxFailure(const char* file, int line, + const absl::Status& status) override { + LOG(WARNING) << "Plugin OP_REQUIRES failed at " << file << ": " << line + << ": " << status; + ctx_->CtxFailure(file, line, status); + } + + void* GetContext() const override { return ctx_; } + + private: + OpKernelContext* ctx_; // not owned. +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h new file mode 100644 index 00000000..bbbcfee6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_variable.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_VARIABLE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_VARIABLE_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/jit/variable_info.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h" +#include "tsl/platform/status.h" + +namespace tensorflow { + +class DirectPluginOpKernelContext; + +class DirectPluginVariable : public PluginVariable { + public: + DirectPluginVariable(int index, const std::string& name, Var* var); + absl::Status GetTensor(const Tensor** result_tensor) override { + *result_tensor = var_info_.var()->tensor(); + return absl::OkStatus(); + } + + absl::Status GetMutableTensor(Tensor** result_tensor) override { + *result_tensor = var_info_.var()->tensor(); + return absl::OkStatus(); + } + + VariableInfo* GetVariableInfo() { return &var_info_; } + + friend DirectPluginOpKernelContext; + + private: + VariableInfo var_info_{0, "", nullptr}; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_DIRECT_PLUGIN_VARIABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/flags.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/flags.h new file mode 100644 index 00000000..681155e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/flags.h @@ -0,0 +1,23 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_FLAGS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_FLAGS_H_ + +#include "absl/flags/declare.h" + +ABSL_DECLARE_FLAG(bool, next_pluggable_device_use_c_api); + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h new file mode 100644 index 00000000..cb8ecf51 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h @@ -0,0 +1,96 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/jit/pjrt_base_device.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/tfrt/common/async_value_tensor.h" + +namespace tensorflow { + +class NextPluggableDeviceAllocator; + +class NextPluggableDevice : public PjRtBaseDevice { + public: + struct Options { + // The device name's prefix (e.g., "/task:7") + string device_name_prefix; + + // The name of the device (e.g., "GPU") + string device_name; + + // The name of the compilation device (e.g., "XLA_TPU_JIT"); + string compilation_device_name; + + // The TfDeviceId. + int device_ordinal = -1; + + // A vector of ShapeDeterminationFn (i.e., a bundle of LayoutSelectionFn, + // ShapeRepresentationFn). Each bundle describes how the on-host shapes of + // a) argument and return value, for entry computations b) variables, for + // all computations, should be represented in XLA. Parameters/return values + // will be shaped according to the function pair, and reshaped back to/from + // their declared shapes for computations. Must be non-empty. + std::vector + shape_determination_fns; + }; + + NextPluggableDevice(const SessionOptions& session_options, + const Options& options); + + ~NextPluggableDevice() override; + + Allocator* GetAllocator(AllocatorAttributes attr) override; + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + absl::Status Sync() override; + + void Sync(const DoneCallback& done) override; + + absl::Status TryGetDeviceContext(DeviceContext** out_context) override; + + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + int GetDeviceOrdinal() const { return device_ordinal_; } + + private: + int device_ordinal_; + // Need to use RefCountPtr since DeviceContext is a ref counted object. + core::RefCountPtr device_context_; + std::unique_ptr tfnpd_allocator_; + std::unique_ptr pjrt_allocator_; + Allocator* allocator_ = nullptr; // Not owned. + std::unique_ptr accelerator_device_info_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_allocator.h new file mode 100644 index 00000000..15cb583b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_allocator.h @@ -0,0 +1,55 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_ALLOCATOR_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h" +#include "tensorflow/core/framework/allocator.h" + +class TFNPD_DeviceAllocator; + +namespace tensorflow { + +class NextPluggableDeviceAllocator : public Allocator { + public: + explicit NextPluggableDeviceAllocator(int device_ordinal); + + ~NextPluggableDeviceAllocator() override; + + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + + void DeallocateRaw(void* ptr) override; + + std::string Name() override { return device_allocator_name_; } + + bool AllocatesOpaqueHandle() const override { + return allocates_opaque_handle_; + } + + private: + const TFNPD_Api* api_; + int device_ordinal_; + std::string device_allocator_name_; + bool allocates_opaque_handle_; + TFNPD_DeviceAllocator* device_allocator_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_api.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_api.h new file mode 100644 index 00000000..026febe2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_api.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_API_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_API_H_ + +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { + +// Global TFNPD_Api* singleton. 
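+// Illustrative initialization flow (a sketch, not part of the upstream header;
+// the plugin-exported symbol name `TFNPD_InitPlugin` and the exact StatusOr
+// payload type are assumptions):
+//
+//   absl::StatusOr<const TFNPD_Api*> api =
+//       InitNextPluggableDevicePlugin(TFNPD_InitPlugin);
+//   if (api.ok()) SetTfnpdApi(*api);  // Subsequent TfnpdApi() calls return it.
+//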
+const TFNPD_Api* TfnpdApi(); +void SetTfnpdApi(const TFNPD_Api* api); + +typedef const TFNPD_Api* (*TFNPDInitPluginFn)(TFNPD_PluginParams*, TF_Status*); +absl::StatusOr InitNextPluggableDevicePlugin( + TFNPDInitPluginFn init_fn); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_API_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.h new file mode 100644 index 00000000..185e5f5e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.h @@ -0,0 +1,53 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_CONTEXT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_CONTEXT_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/platform/status.h" + +class TFNPD_DeviceContext; + +namespace tensorflow { + +// Helper class for managing data transfers between host and accelerator +// devices. +class NextPluggableDeviceContext : public DeviceContext { + public: + explicit NextPluggableDeviceContext(int device_ordinal); + + ~NextPluggableDeviceContext() override; + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; + + private: + const TFNPD_Api* api_; + TFNPD_DeviceContext* context_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h new file mode 100644 index 00000000..5ccfb6dd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_FACTORY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_FACTORY_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_api.h" +#include "tensorflow/core/framework/device_factory.h" + +namespace tensorflow { + +class NextPluggableDeviceFactory : public DeviceFactory { + public: + explicit NextPluggableDeviceFactory( + const std::string& device_type, + const std::string& compilation_device_name) + : api_(TfnpdApi()), + device_type_(device_type), + compilation_device_name_(compilation_device_name) {} + + absl::Status ListPhysicalDevices(std::vector* devices) override; + + absl::Status CreateDevices( + const SessionOptions& session_options, const std::string& name_prefix, + std::vector>* devices) override; + + const std::string& compilation_device_name() const { + return compilation_device_name_; + } + + private: + const TFNPD_Api* api_; + const std::string device_type_; + const std::string compilation_device_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_NEXT_PLUGGABLE_DEVICE_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h new file mode 100644 index 00000000..4d3a1734 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_COORDINATION_SERVICE_AGENT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_COORDINATION_SERVICE_AGENT_H_ + +#include +#include + +#include "absl/time/time.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +class PluginCoordinationServiceAgent { + public: + PluginCoordinationServiceAgent() = default; + virtual ~PluginCoordinationServiceAgent() = default; + + virtual bool IsInitialized() const = 0; + + virtual absl::Status InsertKeyValue(std::string_view key, + std::string_view value) = 0; + + virtual absl::StatusOr GetKeyValue(std::string_view key) = 0; + virtual absl::StatusOr GetKeyValue(std::string_view key, + absl::Duration timeout) = 0; + virtual absl::StatusOr TryGetKeyValue(std::string_view key) = 0; + + virtual absl::Status DeleteKeyValue(std::string_view key) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_COORDINATION_SERVICE_AGENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h new file mode 100644 index 00000000..a5adfa50 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent_helper.h @@ -0,0 +1,42 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_COORDINATION_SERVICE_AGENT_HELPER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_COORDINATION_SERVICE_AGENT_HELPER_H_ + +#include + +#include "absl/flags/flag.h" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/flags.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h" + +namespace tensorflow { + +inline std::unique_ptr<PluginCoordinationServiceAgent> +CreatePluginCoordinationServiceAgent(void* agent) { + if (!absl::GetFlag(FLAGS_next_pluggable_device_use_c_api)) { + return std::make_unique<DirectPluginCoordinationServiceAgent>(agent); + } else { + return std::make_unique<CPluginCoordinationServiceAgent>(agent); + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_COORDINATION_SERVICE_AGENT_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h new file mode 100644 index 00000000..b0123999 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h @@ -0,0 +1,174 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_OP_KERNEL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_OP_KERNEL_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class ConfigProto; +class FunctionLibraryDefinition; +class OpInputList; +class PluginCoordinationServiceAgent; +class PluginVariable; +class Tensor; +class TensorShape; + +// A wrapper base class that provides convenience for developers to implement +// plugin OpKernels that suit internal and external requirements, without +// duplicating code. +// +// Internal build: Plugin and TF are built together and statically linked. In +// this case, we can directly cast between `TF_OpKernelContext*` and +// `OpKernelContext*`, and directly call the C++ API. This way we don't need to +// pay the potential performance penalty (e.g. proto +// serialization/deserialization) brought by the C API. + +// External build: Plugin and TF are built separately (potentially on different +// platforms and by different compilers).
Plugin is dynamically loaded by TF. +// In this case, we need to call the C API to ensure binary compatibility. + +// `DirectPluginOpKernel*` and `CPluginOpKernel*` implement `PluginOpKernel*` +// to support the above-mentioned internal and external build cases. OpKernel +// developers can conveniently use the `Wrapper` C++ API to implement `Create` +// and `Compute` functions, and use the helper macro to register the functions +// as a Plugin OpKernel. This method benefits kernel developers in two ways: 1) +// Plugin OpKernel developers don't have to deal with the C API directly. 2) If +// the OpKernels are performance critical and developers want to introduce an +// internal version of the same OpKernels, they don't have to implement them +// again with mostly duplicated code. +class PluginOpKernelConstruction { + public: + PluginOpKernelConstruction() = default; + virtual ~PluginOpKernelConstruction() = default; + + virtual absl::Status GetBoolAttr(std::string_view attr_name, + bool* value) const = 0; + virtual absl::Status GetInt32Attr(std::string_view attr_name, + int* value) const = 0; + virtual absl::Status GetInt32AttrList(std::string_view attr_name, + std::vector<int32_t>* value) const = 0; + virtual absl::Status GetInt64Attr(std::string_view attr_name, + int64_t* value) const = 0; + virtual absl::Status GetStringAttr(std::string_view attr_name, + std::string* value) const = 0; + virtual absl::Status GetFunctionAttr(std::string_view attr_name, + NameAttrList* function) const = 0; + + virtual void CtxFailure(const absl::Status& status) = 0; + virtual void CtxFailure(const char* file, int line, + const absl::Status& status) = 0; + + virtual void* GetContext() const = 0; +}; + +class PluginOpKernelContext { + public: + PluginOpKernelContext() = default; + virtual ~PluginOpKernelContext() = default; + + virtual std::string_view GetResourceMgrDefaultContainerName() = 0; + + virtual absl::Status LookupOrCreateResource( + std::string_view container_name, std::string_view plugin_resource_name, + void** result_plugin_resource, void* (*create_func)(void*), + void* create_func_args, void (*delete_func)(void*)) = 0; + + virtual std::unique_ptr<PluginCoordinationServiceAgent> + GetPluginCoordinationServiceAgent() const = 0; + + // This method allocates a new `PluginVariable`. The caller is responsible + // for managing its lifetime.
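+  // Illustrative call pattern from a plugin kernel's Compute function (a
+  // sketch, not part of the upstream header; the input index and deletion via
+  // `delete` are assumptions):
+  //
+  //   PluginVariable* var = nullptr;
+  //   PLUGIN_OP_REQUIRES_OK(ctx, ctx->CreatePluginVariable(/*index=*/0, &var));
+  //   const Tensor* t = nullptr;
+  //   PLUGIN_OP_REQUIRES_OK(ctx, var->GetTensor(&t));
+  //   delete var;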
+ virtual absl::Status CreatePluginVariable( + int index, PluginVariable** variable) const = 0; + + virtual absl::Status AllocateTempForPluginVariable( + PluginVariable* variable) = 0; + + virtual int NumInputs() const = 0; + + virtual absl::Status GetInput(int index, const Tensor** tensor) const = 0; + + virtual absl::Status GetInput(const char* name, + const Tensor** tensor) const = 0; + + virtual absl::Status GetInputRange(std::string_view name, + std::pair* range) const = 0; + + virtual DataType GetInputDataType(int index) const = 0; + + virtual std::string_view GetOpKernelRequestedInput(int index) const = 0; + + virtual std::string_view GetOpKernelName() const = 0; + + virtual uint64_t GetFrameId() const = 0; + + virtual int64_t GetIterId() const = 0; + + virtual int64_t GetStepId() const = 0; + + virtual int GetDeviceId() const = 0; + + virtual std::string_view GetDeviceName() const = 0; + + virtual std::string GetSessionName() const = 0; + + virtual absl::Status GetConfigProto( + const ConfigProto** config_proto) const = 0; + + virtual void MaybeDeleteConfigProto( + const ConfigProto* config_proto) const = 0; + + virtual absl::Status GetFunctionLibraryDefinition( + const FunctionLibraryDefinition** flib_def) const = 0; + + virtual void MaybeDeleteFunctionLibraryDefinition( + const FunctionLibraryDefinition* flib_def) const = 0; + + virtual absl::Status GetResourceHandle( + int index, const ResourceHandle** handle) const = 0; + + virtual void MaybeDeleteResourceHandle( + const ResourceHandle* handle) const = 0; + + virtual int GetGraphDefVersion() const = 0; + + virtual absl::Status AllocateOutput(int index, const TensorShape& shape, + Tensor** out) = 0; + + virtual absl::Status SetOutput(int index, const Tensor& tensor) = 0; + + virtual void CtxFailure(const absl::Status& status) = 0; + virtual void CtxFailure(const char* file, int line, + const absl::Status& status) = 0; + + virtual void* GetContext() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h new file mode 100644 index 00000000..1f51f7c4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel_helper.h @@ -0,0 +1,124 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_OP_KERNEL_HELPER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_OP_KERNEL_HELPER_H_ + +#include "absl/flags/flag.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/flags.h" +#include "tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { + +inline PluginOpKernelConstruction* CreatePluginOpKernelConstruction(void* ctx) { + if (!absl::GetFlag(FLAGS_next_pluggable_device_use_c_api)) { + return new DirectPluginOpKernelConstruction(ctx); + } else { + return new CPluginOpKernelConstruction(ctx); + } +} + +inline void DeletePluginOpKernelConstruction( + PluginOpKernelConstruction* wrapper) { + delete wrapper; +} + +inline PluginOpKernelContext* CreatePluginOpKernelContext(void* ctx) { + if (!absl::GetFlag(FLAGS_next_pluggable_device_use_c_api)) { + return new DirectPluginOpKernelContext( + reinterpret_cast(ctx)); + } else { + return new CPluginOpKernelContext(ctx); + } +} + +inline void DeletePluginOpKernelContext(PluginOpKernelContext* wrapper) { + delete wrapper; +} + +#define PLUGIN_OP_REQUIRES_OK(CTX, ...) \ + do { \ + absl::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailure(__FILE__, __LINE__, _s); \ + return; \ + } \ + } while (0) + +// A helper to register C OpKernel. CREATE_FN, COMPUTE_FN, and DELETE_FN are +// expected to be defined in the same file where this macro is used. +// +// HOST_MEMORY_ARGS a string containing names of args to be placed on host +// memory. Names are expected to be comma separated. +// +// TODO(chuanhao): simplify the registration macro. reference: +// REGISTER_KERNEL_BUILDER +#define REGISTER_WRAPPED_C_OPKERNEL_HOST_MEM_ARGS( \ + KERNEL_NAME, CREATE_FN, COMPUTE_FN, DELETE_FN, DEVICE, PRIORITY, \ + HOST_MEMORY_ARGS) \ + { \ + typedef void* (*wrapped_create_func)(TF_OpKernelConstruction*); \ + typedef void (*wrapped_compute_func)(void*, TF_OpKernelContext*); \ + \ + TF_StatusPtr status_ptr(TF_NewStatus()); \ + \ + wrapped_create_func create_func = \ + [](TF_OpKernelConstruction* ctx) -> void* { \ + PluginOpKernelConstruction* ctx_wrapper = \ + CreatePluginOpKernelConstruction(ctx); \ + void* kernel = CREATE_FN(ctx_wrapper); \ + delete ctx_wrapper; \ + return kernel; \ + }; \ + \ + wrapped_compute_func compute_func = [](void* kernel, \ + TF_OpKernelContext* ctx) -> void { \ + PluginOpKernelContext* ctx_wrapper = CreatePluginOpKernelContext(ctx); \ + COMPUTE_FN(kernel, ctx_wrapper); \ + delete ctx_wrapper; \ + }; \ + \ + auto* builder = TF_NewKernelBuilder(KERNEL_NAME, DEVICE, create_func, \ + compute_func, &DELETE_FN); \ + \ + /* NOTE: We explicitly set the priority to 1 to overwrite the */ \ + /* StreamExecutor based OpKernel of the same op. 
*/ \ + TF_KernelBuilder_Priority(builder, PRIORITY); \ + \ + std::stringstream s_stream(HOST_MEMORY_ARGS); \ + while (s_stream.good()) { \ + std::string host_mem_arg; \ + std::getline(s_stream, host_mem_arg, ','); \ + if (host_mem_arg.empty()) break; \ + TF_KernelBuilder_HostMemory(builder, host_mem_arg.c_str()); \ + } \ + \ + TF_RegisterKernelBuilder(KERNEL_NAME, builder, status_ptr.get()); \ + CHECK_EQ(TF_OK, TF_GetCode(status_ptr.get())) \ + << "Error while registering " << KERNEL_NAME << " kernel."; \ + } + +#define REGISTER_WRAPPED_C_OPKERNEL(KERNEL_NAME, CREATE_FN, COMPUTE_FN, \ + DELETE_FN, DEVICE, PRIORITY) \ + REGISTER_WRAPPED_C_OPKERNEL_HOST_MEM_ARGS( \ + KERNEL_NAME, CREATE_FN, COMPUTE_FN, DELETE_FN, DEVICE, PRIORITY, "") + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_OP_KERNEL_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h new file mode 100644 index 00000000..c72fe952 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_resource.h @@ -0,0 +1,56 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_RESOURCE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_RESOURCE_H_ + +#include +#include + +#include "tensorflow/core/framework/resource_base.h" + +namespace tensorflow { + +// A wrapper class for plugin to create resources to the ResourceMgr managed by +// TensorFlow. The main motivation is to make resources in plugin have the same +// lifetime as TensorFlow ResourceMgr. +// +// Usage: +// Plugin uses a TensorFlow C API `TF_CreatePluginResource()`, +// to register the `PluginResource` to the ResourceMgr managed by TensorFlow. +// `PluginResource` holds a opaque pointer and a deleter function. The deleter +// will be called at `PluginResource`'s destruction. 
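+//
+// Illustrative lifetime (a sketch; `MyPluginState` is a hypothetical plugin
+// type, and the exact C-API call shape is deliberately omitted):
+//   1. The plugin allocates a `MyPluginState` and hands TensorFlow the raw
+//      pointer together with a deleter such as
+//      `[](void* p) { delete static_cast<MyPluginState*>(p); }`.
+//   2. TensorFlow wraps both in a `PluginResource` owned by its ResourceMgr.
+//   3. When the ResourceMgr releases the resource, `~PluginResource()` invokes
+//      the deleter, freeing the `MyPluginState`.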
+class PluginResource : public ResourceBase { + public: + PluginResource(void* plugin_resource, std::string_view plugin_resource_name, + void (*delete_func)(void* plugin_resource)) + : resource_(plugin_resource), + resource_name_(plugin_resource_name), + delete_func_(delete_func) {} + ~PluginResource() override; + + void* GetOpaquePluginResource() { return resource_; } + + std::string DebugString() const override { return resource_name_; } + + private: + void* resource_; + std::string resource_name_; + void (*delete_func_)(void* plugin_resource); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_RESOURCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h new file mode 100644 index 00000000..ab2ec9a2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/plugin_variable.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_VARIABLE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_VARIABLE_H_ + +#include "tsl/platform/status.h" + +namespace tensorflow { + +class Tensor; + +// A helper base class that wraps tensorflow::VariableInfo for the convenience +// of passing it between the plugin and TensorFlow. Similar to +// `PluginOpKernelContext`, the implementations can accommodate the "Internal +// build" and "External build" cases, meaning the plugin is built either +// together with or separately from TensorFlow. In the respective build modes, +// the implementations can either include tensorflow::VariableInfo and use the +// C++ API directly, or include the C structure `TF_VariableInfo` and use the +// corresponding C API. +class PluginVariable { + public: + PluginVariable() = default; + virtual ~PluginVariable() = default; + + // `result_tensor` will point to the tensor held by the variable if the + // returned status is OK. + virtual absl::Status GetTensor(const Tensor** result_tensor) = 0; + + virtual absl::Status GetMutableTensor(Tensor** result_tensor) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_PLUGIN_VARIABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/utils.h new file mode 100644 index 00000000..9739c009 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/next_pluggable_device/utils.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_UTILS_H_ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/c/c_api_decl.h" + +namespace tensorflow { + +XLA_LayoutPreference ConvertToCXlaLayoutPreference(XlaLayoutPreference input); +XlaLayoutPreference ConvertFromCXlaLayoutPreference(XLA_LayoutPreference input); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NEXT_PLUGGABLE_DEVICE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/no_op_cost_measurement.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/no_op_cost_measurement.h new file mode 100644 index 00000000..6c2cc659 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/no_op_cost_measurement.h @@ -0,0 +1,39 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NO_OP_COST_MEASUREMENT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NO_OP_COST_MEASUREMENT_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/cost_measurement.h" +#include "tensorflow/core/common_runtime/cost_measurement_registry.h" + +namespace tensorflow { + +// This class does not do the real cost measurement. It will always return zero +// Duration as the total cost. It's created to allow callers to skip collecting +// costs. +class NoOpCostMeasurement : public CostMeasurement { + public: + using CostMeasurement::CostMeasurement; + + // Always returns zero Duration as the total cost. + absl::Duration GetTotalCost() override; + absl::string_view GetCostType() const override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NO_OP_COST_MEASUREMENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/node_file_writer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/node_file_writer.h new file mode 100644 index 00000000..4b92453b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/node_file_writer.h @@ -0,0 +1,72 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NODE_FILE_WRITER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NODE_FILE_WRITER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { + +// Writes out the NodeDef and the input shapes/dtypes for an executed node to a +// file. This allows the set of executed nodes for a model or test to be +// examined and processed. Currently this is used by an internal tool which +// checks that ops executed by tests are deterministic. +class NodeFileWriter { + public: + // Creates or reuses a NodeFileWriter if environmental variable + // TF_NODE_FILE_WRITER_DIRECTORY is set, which specifies the directory where + // the node file will be created in. Otherwise, returns nullptr. When called + // with the same device_name, the same NodeFileWriter will be returned. + static absl::StatusOr GetNodeFileWriterIfEnabled( + const std::string& device_name, Env* env); + + // Records the execution of a node, if eligible, by writing the node to the + // file. Only writes the node if the exact node with the given input + // shapes/dtypes hasn't already been written. Should be called once every time + // a node is run. + absl::Status RecordNodeExecution(OpKernel* op_kernel, + OpKernelContext* context); + + const std::string& filename() { return filename_; } + + private: + explicit NodeFileWriter(std::string filename) + : filename_{std::move(filename)} {} + + absl::Status Init(Env* env) { + return env->NewWritableFile(filename_, &node_def_file_); + } + + // Writes the NodeDef to a file, if it hasn't already been written yet. + absl::Status MaybeWriteNodeDefToFile(const NodeDef& def); + + const std::string filename_; + mutex mu_; + // Hashes of the NodeDefs already written to the file + absl::flat_hash_set written_hashes_ TF_GUARDED_BY(mu_); + + std::unique_ptr node_def_file_ TF_PT_GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NODE_FILE_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/null_request_cost_accessor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/null_request_cost_accessor.h new file mode 100644 index 00000000..daae603f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/null_request_cost_accessor.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_NULL_REQUEST_COST_ACCESSOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_NULL_REQUEST_COST_ACCESSOR_H_ + +#include "tensorflow/core/common_runtime/request_cost_accessor_registry.h" + +namespace tensorflow { + +// NullRequestCostAccessor always returns nullptr as the RequestCost of current +// rpc. It's created to allow callers to skip collecting the request cost. +class NullRequestCostAccessor : public RequestCostAccessor { + public: + // Always returns nullptr as the RequestCost of current rpc. + RequestCost* GetRequestCost() const override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_NULL_REQUEST_COST_ACCESSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimization_registry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimization_registry.h new file mode 100644 index 00000000..9de93a6b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimization_registry.h @@ -0,0 +1,191 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Classes to maintain a static registry of whole-graph optimization +// passes to be applied by the Session when it initializes a graph. +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/composite_device.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +struct SessionOptions; + +// All the parameters used by an optimization pass are packaged in +// this struct. They should be enough for the optimization pass to use +// as a key into a state dictionary if it wants to keep state across +// calls. +struct GraphOptimizationPassOptions { + // Filled in by DirectSession for PRE_PLACEMENT optimizations. Can be empty. + string session_handle; + const SessionOptions* session_options = nullptr; + const CostModel* cost_model = nullptr; + + FunctionLibraryDefinition* flib_def = nullptr; // Not owned. + // The DeviceSet contains all the devices known to the system and is + // filled in for optimizations run by the session master, i.e., + // PRE_PLACEMENT, POST_PLACEMENT, and POST_REWRITE_FOR_EXEC. It is + // nullptr for POST_PARTITIONING optimizations which are run at the + // workers. + const DeviceSet* device_set = nullptr; // Not owned. + + // Maps from a CompositeDevice name to a list of underlying physical + // devices. + const std::vector* composite_devices = + nullptr; // Not owned. 
+ + // The graph to optimize, for optimization passes that run before + // partitioning. Null for post-partitioning passes. + // An optimization pass may replace *graph with a new graph object. + std::unique_ptr* graph = nullptr; + + // Graphs for each partition, if running post-partitioning. Optimization + // passes may alter the graphs, but must not add or remove partitions. + // Null for pre-partitioning passes. + std::unordered_map>* partition_graphs = + nullptr; + + // Indicator of whether or not the graph was derived from a function. + bool is_function_graph = false; + // Set when is_function_graph is true. The default device where the function + // runs. If nullptr, it runs on the local host. + const Device* default_function_device = nullptr; + // Set when is_function_graph is true. The function where the graph was + // derived. `graph` doesn't contain all the information in the function_def, + // e.g. function attributes. + const FunctionDef* function_def = nullptr; + + // TODO(b/176491312): Remove this if shape inference on import flag is + // removed. If True, allows mlir roundtrip to run shape inference on import. + bool shape_inference_on_tfe_dialect_import = true; + + // A unique filename prefix (using hostname, process ID, thread ID and + // timestamp) for graph dumps. + string debug_filename_prefix; + + // Whether to enable tf2xla mlir bridge in compiling SavedModel. + bool enable_tf2xla_mlir_bridge = true; +}; + +// Optimization passes are implemented by inheriting from +// GraphOptimizationPass. +class GraphOptimizationPass { + public: + virtual ~GraphOptimizationPass() {} + virtual absl::Status Run(const GraphOptimizationPassOptions& options) = 0; + void set_name(const string& name) { name_ = name; } + string name() const { return name_; } + + private: + // The name of the optimization pass, which is the same as the inherited + // class name. + string name_; +}; + +// The key is a 'phase' number. Phases are executed in increasing +// order. Within each phase the order of passes is undefined. +typedef std::map>> + GraphOptimizationPasses; + +// A global OptimizationPassRegistry is used to hold all passes. +class OptimizationPassRegistry { + public: + // Groups of passes are run at different points in initialization. + enum Grouping { + PRE_PLACEMENT, // after cost model assignment, before placement. + POST_PLACEMENT, // after placement. + POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints. + POST_PARTITIONING, // after partitioning + }; + + // Add an optimization pass to the registry. + void Register(Grouping grouping, int phase, + std::unique_ptr pass); + + const std::map& groups() { + return groups_; + } + + // Run all passes in grouping, ordered by phase, with the same + // options. + absl::Status RunGrouping(Grouping grouping, + const GraphOptimizationPassOptions& options); + + // Returns the global registry of optimization passes. + static OptimizationPassRegistry* Global(); + + // Prints registered optimization passes for debugging. 
+ void LogGrouping(Grouping grouping, int vlog_level); + void LogAllGroupings(int vlog_level); + + private: + std::map groups_; + + const char* GetGroupingName(Grouping grouping) const { + switch (grouping) { + case PRE_PLACEMENT: + return "pre_placement"; + case POST_PLACEMENT: + return "post_placement"; + case POST_REWRITE_FOR_EXEC: + return "post_rewrite_for_exec"; + case POST_PARTITIONING: + return "post_partitioning"; + } + return "unknown"; + } +}; + +namespace optimization_registration { + +class OptimizationPassRegistration { + public: + OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping, + int phase, + std::unique_ptr pass, + string optimization_pass_name) { + pass->set_name(optimization_pass_name); + OptimizationPassRegistry::Global()->Register(grouping, phase, + std::move(pass)); + } +}; + +} // namespace optimization_registration + +#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \ + REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization) + +#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \ + REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) + +#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \ + static ::tensorflow::optimization_registration::OptimizationPassRegistration \ + register_optimization_##ctr( \ + grouping, phase, \ + ::std::unique_ptr<::tensorflow::GraphOptimizationPass>( \ + new optimization()), \ + #optimization) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimize_cross_host_control_deps.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimize_cross_host_control_deps.h new file mode 100644 index 00000000..dde9d3e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimize_cross_host_control_deps.h @@ -0,0 +1,50 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZE_CROSS_HOST_CONTROL_DEPS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZE_CROSS_HOST_CONTROL_DEPS_H_ + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Optimize the graph by reducing cross-host control output edges. +// Once we find any nodes in the graph having not less than +// `cross_host_edges_threshold` control output edges in one host, we create +// a `NoOp` node in the destination host to proxy the control edges between the +// oringal node and the destination control output nodes. +absl::Status OptimizeCrossHostControlOutputEdges( + Graph* graph, int cross_host_edges_threshold); + +// Optimize the graph by reducing cross-host data output edges. 
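Usage sketch for the registry above (not part of the vendored header): a pass derives from GraphOptimizationPass, overrides Run(), and registers itself through REGISTER_OPTIMIZATION, which expands to a static OptimizationPassRegistration. The pass name MyNoopPass and the POST_PLACEMENT/phase-0 choice are illustrative assumptions.

    // Sketch only: a do-nothing pass registered at POST_PLACEMENT, phase 0.
    #include "tensorflow/core/common_runtime/optimization_registry.h"

    namespace tensorflow {

    class MyNoopPass : public GraphOptimizationPass {
     public:
      absl::Status Run(const GraphOptimizationPassOptions& options) override {
        // Runs post-placement, so options.graph is expected to be set;
        // this sketch inspects nothing and leaves the graph unchanged.
        return absl::OkStatus();
      }
    };

    // Names the pass "MyNoopPass" and adds it to the global registry at load time.
    REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, MyNoopPass);

    }  // namespace tensorflow
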
+// Once we find any nodes in the graph having not less than +// `cross_host_edges_threshold` data output edges in one host, we create +// a `IdentityN` node in the destination host to proxy the data edges between +// the original node and the destination output nodes. +absl::Status OptimizeCrossHostDataOutputEdges(Graph* graph, + int cross_host_edges_threshold); + +// Optimize the graph by reducing cross-host control input edges. +// Once we find any nodes in the graph having not less than +// `cross_host_edges_threshold` control input edges in one host, we create +// a `NoOp` node in the source host to proxy the control edges between the +// source control input nodes and oringal node. +absl::Status OptimizeCrossHostControlInputEdges(Graph* graph, + int cross_host_edges_threshold); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZE_CROSS_HOST_CONTROL_DEPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimize_function_graph_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimize_function_graph_utils.h new file mode 100644 index 00000000..d5cd2159 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimize_function_graph_utils.h @@ -0,0 +1,94 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file contains util functions related to function graph instantiation and +// optimizations. +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZE_FUNCTION_GRAPH_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZE_FUNCTION_GRAPH_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/time/time.h" +#include "tensorflow/core/common_runtime/composite_device.h" +#include "tensorflow/core/common_runtime/optimized_function_graph_info.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +// TODO(b/246646753): add more tests. + +// The name of the env variable for the caching location of graph optimization. +// Note: if the caching location retrieved by the env variable is empty it means +// no caching would be performed. +static const char kGraphCachingEnvVariableName[] = "TF_GRAPH_CACHING"; +// The threshold of the graph optimization duration to be cached. +// Note: setting this threshold to 0 means to cache for every function. +constexpr absl::Duration kCachingThresholdDuration = absl::Seconds(3); + +// TODO(iga): Reword +// Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the +// corresponding resource lives. This ensures that the Placer assigns ops that +// access these resources to the appropriate devices. 
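Usage sketch for the three cross-host edge passes declared in this header (not part of the vendored header); the ReduceCrossHostEdges wrapper and the threshold of 8 edges are illustrative assumptions.

    // Sketch only: run all three cross-host edge reductions on a placed graph.
    #include "tensorflow/core/common_runtime/optimize_cross_host_control_deps.h"
    #include "tensorflow/core/platform/errors.h"  // TF_RETURN_IF_ERROR

    absl::Status ReduceCrossHostEdges(tensorflow::Graph* graph) {
      constexpr int kCrossHostEdgesThreshold = 8;  // Illustrative value.
      TF_RETURN_IF_ERROR(tensorflow::OptimizeCrossHostControlOutputEdges(
          graph, kCrossHostEdgesThreshold));
      TF_RETURN_IF_ERROR(tensorflow::OptimizeCrossHostDataOutputEdges(
          graph, kCrossHostEdgesThreshold));
      TF_RETURN_IF_ERROR(tensorflow::OptimizeCrossHostControlInputEdges(
          graph, kCrossHostEdgesThreshold));
      return absl::OkStatus();
    }
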
+absl::Status PinArgsAndRets(const std::vector& input_devices, + const std::vector& output_devices, + const DeviceSet& device_set, + const std::vector& arg_nodes, + const std::vector& ret_nodes, + const FunctionLibraryDefinition* lib_def, + Device* default_device); + +// Outputs graph optimization result after all the graph optimization (up till +// before graph partitioning); returns error if optimization fails. Note that +// the `input_lib_def` will be used only if the lib_def in `options` is nullptr. +absl::StatusOr OptimizeFunctionGraph( + const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + const DeviceSet& dev_set, const FunctionLibraryDefinition* input_lib_def, + const std::vector& composite_devices, Device* cpu_device, + Device* default_device, Env* env, + OptimizedFunctionGraph::OptimizationSource optimization_source); + +// Outputs graph optimization results (as OptimizedFunctionGraphInfo proto), +// either by running the actual graph optimization passes, or by reloading from +// the file cache if existent. If cache loading fails, it goes ahead and runs +// the graph optimization passes. Returns error if running the optimization +// passes fails. +absl::StatusOr +OptimizeFunctionGraphOrReadFromFileCache( + const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + const DeviceSet& dev_set, const FunctionLibraryDefinition* input_lib_def, + const std::vector& composite_devices, Device* cpu_device, + Device* default_device, Env* env, + absl::Duration caching_threshold_duration = kCachingThresholdDuration); + +// Pre-processes, partitions and post-optimizes the input graph; returns +// subgraph result (maps from device name to the subgraph); returns error if any +// optimization or partitioning step fails. +absl::StatusOr< + std::unique_ptr>>> +PreprocessAndPartitionGraph( + const std::string& function_name, + OptimizedFunctionGraphInfo& input_optimized_graph, + const FunctionLibraryRuntime::InstantiateOptions& options, + const DeviceSet& dev_set, const FunctionLibraryDefinition* input_lib_def, + const std::vector& composite_devices, Device* cpu_device, + Env* env); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZE_FUNCTION_GRAPH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimized_function_graph_info.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimized_function_graph_info.h new file mode 100644 index 00000000..c23d7221 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/optimized_function_graph_info.h @@ -0,0 +1,90 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
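The caching constants above imply a simple gate: results are written to the file cache only when TF_GRAPH_CACHING names a non-empty location and the optimization ran for at least the threshold (a zero threshold caches every function). A sketch of that gating logic follows; ShouldCacheOptimizedGraph is a hypothetical helper, not part of this header, and the real decision is internal to OptimizeFunctionGraphOrReadFromFileCache.

    // Sketch only: the caching decision implied by the constants above.
    #include <cstdlib>
    #include <string>

    #include "absl/time/time.h"

    bool ShouldCacheOptimizedGraph(absl::Duration optimization_duration,
                                   absl::Duration caching_threshold) {
      const char* cache_location = std::getenv("TF_GRAPH_CACHING");
      // An unset or empty caching location disables caching entirely.
      if (cache_location == nullptr || std::string(cache_location).empty()) {
        return false;
      }
      // A threshold of zero caches every function; otherwise only slow
      // optimizations (>= threshold) are worth writing to the cache.
      return optimization_duration >= caching_threshold;
    }
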
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZED_FUNCTION_GRAPH_INFO_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZED_FUNCTION_GRAPH_INFO_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/optimized_function_graph.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// Function graph related information after optimizations. This struct can be +// converted to and from +// third_party/tensorflow/core/framework/optimized_function_graph.proto. +struct OptimizedFunctionGraphInfo { + // Function name. + string name; + // Optimized function graph. + std::unique_ptr function_graph; + // Optimized function library. + FunctionLibraryDefinition lib_def; + // Map from original node names to control return names. + std::unordered_map node_name_to_control_ret; + // Return node types of the function. + DataTypeVector ret_types; + // Number of return nodes. + size_t num_return_nodes; + // Time (in microseconds) spent on running the graph optimization passes for + // this function. + uint64_t optimization_duration_usecs; + // Indicates the source environment where the optimization is created. + OptimizedFunctionGraph::OptimizationSource optimization_source; + + ~OptimizedFunctionGraphInfo() = default; + OptimizedFunctionGraphInfo() : lib_def(OpRegistry::Global()) {} + OptimizedFunctionGraphInfo( + const std::string& name, std::unique_ptr&& graph, + FunctionLibraryDefinition&& lib_def, + const std::unordered_map& node_name_to_control_ret, + const DataTypeVector& ret_types, size_t num_return_nodes, + uint64_t optimization_duration_usecs, + OptimizedFunctionGraph::OptimizationSource optimization_source) + : name(name), + function_graph(std::move(graph)), + lib_def(std::move(lib_def)), + node_name_to_control_ret(node_name_to_control_ret), + ret_types(ret_types), + num_return_nodes(num_return_nodes), + optimization_duration_usecs(optimization_duration_usecs), + optimization_source(optimization_source) {} + + OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo& info) = delete; + OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo& info) = + delete; + OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) = + default; // NOLINT + OptimizedFunctionGraphInfo& operator=( + OptimizedFunctionGraphInfo&& info) noexcept = default; // NOLINT + + // Converts from the struct to OptimizedFunctionGraph proto. + static OptimizedFunctionGraph ToProto(const OptimizedFunctionGraphInfo& info); + + // Converts from the proto to struct OptimizedFunctionGraphInfo. Returns error + // if the conversion fails. + static absl::StatusOr FromProto( + OptimizedFunctionGraph&& proto); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZED_FUNCTION_GRAPH_INFO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/partitioning_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/partitioning_utils.h new file mode 100644 index 00000000..6bc9befb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/partitioning_utils.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
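Usage sketch of the proto round-trip supported by OptimizedFunctionGraphInfo (not part of the vendored header); the RoundTrip wrapper is a hypothetical caller.

    // Sketch only: convert the struct to its proto form and back.
    #include <utility>

    #include "tensorflow/core/common_runtime/optimized_function_graph_info.h"

    absl::Status RoundTrip(const tensorflow::OptimizedFunctionGraphInfo& info) {
      // To proto, e.g. before writing to the graph-optimization file cache.
      tensorflow::OptimizedFunctionGraph proto =
          tensorflow::OptimizedFunctionGraphInfo::ToProto(info);

      // Back to the in-memory form; FromProto consumes the proto and may fail.
      absl::StatusOr<tensorflow::OptimizedFunctionGraphInfo> restored =
          tensorflow::OptimizedFunctionGraphInfo::FromProto(std::move(proto));
      return restored.status();
    }
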
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Given a `device_set` and a `graph`, partitions the `graph` into +// `subgraphs`. `subgraphs` maps device names to the graph assigned to that +// device. `graph` must have been placed (e.g. by running Placer), +// i.e. all nodes must have an assigned_device set. +// `graph` is non-const because the underlying Partition() function transforms +// the graph to correctly partition distributed control flow. +// `get_tensor_name_attr` computes the "tensor_name" attr value of Send/Recv ops +// inserted during partitioning. Use the default one if not set. It needs to be +// thread safe if it's shared in multple threads. +absl::Status PartitionFunctionGraph( + const DeviceSet& device_set, std::unique_ptr graph, + std::unordered_map>* subgraphs, + std::function get_tensor_name_attr = nullptr); + +// Inserts send/recv ops to `graph` if nodes are assigned to multiple devices. +// Returns the new graph with the added nodes. Moreover, the dependency between +// a send/recv pair is made explicit by adding a control dependency between +// them. +// Note that, the returned graph is intended to be used by TF MLIR importer. +// The dependencies between send/recv pairs ensure the importer will generate TF +// MLIR ops in a valid order. +absl::StatusOr> InsertTransferOps( + const DeviceSet& device_set, std::unique_ptr graph); + +// This function performs bookkeeping to track which `Arg` and `Retval` nodes +// were placed on a particular device / graph. +// +// More specifically, this function +// +// (1) rewrites the indices of the `Arg` and `Retval` nodes in `graph` to be +// consecutive. +// +// These indices might not be consecutive after grappler's pruning +// optimization (e.g. removing redundant Args), or graph partitioning. In +// the latter case, the nodes in `graph` are placed on `device_type`, and +// each such graph partition gets a subset of the arguments and return +// values. The `index` attributes of these _Arg and _Retval nodes reflect +// the indices of these parameters in the original function. To convert +// `subgraph` to a function, we need to replace there original indices with +// 0, 1, 2, ... . +// +// The argument and return value order in `graph` is determined by the +// argument and return value order in the original function. This stability +// is important because it enables us to treat a single-partition function +// as having the same signature as the subgraph. +// +// (2) records the subsets of `Arg` and `Retval` nodes assigned to the +// device in `*_indices`, and +// (3) records which `Arg` and `Retval` nodes live in host memory in +// `*_alloc_attrs`. If these vectors are NULL, do nothing here. 
If +// `ints_on_device` is false, int32 `Arg` and `Retval` nodes are placed on +// host else not. This is needed because in certain special cases e.g. +// when graph is placed on TPU/XLA device or when the `Retval` is an output +// of an iterator, int32 tensors live on device. +absl::Status UpdateArgAndRetvalMetadata( + Graph* graph, std::vector* arg_indices, + std::vector* ret_indices, + std::vector* arg_alloc_attrs, + std::vector* ret_alloc_attrs, bool ints_on_device); + +// Utility for generating function names not present in `flib_def`, using +// given `name` as the base for the name. +class FunctionNameGenerator { + public: + // `flib_def` must outlive this. + FunctionNameGenerator(const FunctionLibraryDefinition* flib_def, + const string& name) + : flib_def_(flib_def), name_(name), counter_(0) {} + + // Returns a function name not present in `flib_def` using `name` as + // the base and appending a numeric suffix. + string GetName(); + + private: + const FunctionLibraryDefinition* flib_def_; + const string name_; + uint32 counter_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PARTITIONING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pending_counts.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pending_counts.h new file mode 100644 index 00000000..cff837ec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pending_counts.h @@ -0,0 +1,573 @@ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ + +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/port.h" + +namespace tensorflow { + +// PendingCounts is an internal helper class to keep track of pending and +// dead counts for nodes, for use in the ExecutorState module. It +// holds a map from Handles to various counts for that handle. This +// information is needed per frame iteration. The amount of memory +// needed for an iteration is the same across all executions of the +// iteration. The memory amount and handles are precomputed at startup +// using a Layout object. +// +// PendingCounts::Layout layout; +// std::vector h(C); +// for (int id = 0; id < C; id++) { +// h[id] = r.AddHandle(max_pending[id], max_dead[id]); +// } +// +// When we actually want to start an iteration we first create a +// PendingCounts object and then index into it using the precomputed +// handles: + +// PendingCounts counts(layout); +// ... +// counts.decrement_pending(h[id], 1); +class PendingCounts { + public: + // The state machine for a node's execution. + enum NodeState { + // The pending count for the node > 0. 
+ PENDING_NOTREADY, + // The pending count for the node == 0, but the node has not + // started executing. + PENDING_READY, + // The node has started executing. + STARTED, + // The node has finished executing. + COMPLETED + }; + + // An opaque handle indicating where in the PendingCounts data structure + // the appropriate count information can be found. + class Handle; + // Given a node that needs to represent counts no larger than the + // specified "max_pending_count" and "max_dead_count", create a + // handle that can be passed to various PendingCounts routines + // to retrieve the count data for this node. + class Layout { + public: + Handle CreateHandle(size_t max_pending_count, size_t max_dead_count); + + private: + friend class PendingCounts; + int next_offset_ = 0; // Next byte offset to allocate + }; + + // Create a new PendingCounts object that can hold the state of + // all the Handles allocated from "final_allocator". + explicit PendingCounts(Layout layout) + : num_bytes_(layout.next_offset_), bytes_(new char[num_bytes_]()) { + if (num_bytes_ >= sizeof(LargeCounts)) { + CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0); + } + } + + // Create a new PendingCounts object with the same layout and counts + // as "other". + explicit PendingCounts(const PendingCounts& other) + : num_bytes_(other.num_bytes_), bytes_(new char[num_bytes_]) { + if (num_bytes_ >= sizeof(LargeCounts)) { + CHECK_EQ(uintptr_t(bytes_) % alignof(LargeCounts), 0); + } + memcpy(bytes_, other.bytes_, other.num_bytes_); + } + + ~PendingCounts() { delete[] bytes_; } + + void set_initial_count(Handle h, size_t pending_count) { + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending = pending_count; + c.dead_count = 0; + c.has_started = 0; + c_ptr->store(c, std::memory_order_relaxed); + } else { + DCHECK_LE(pending_count, kMaxCountForPackedCounts); + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending = pending_count; + c.dead_count = 0; + c.has_started = 0; + c_ptr->store(c, std::memory_order_relaxed); + } + } + + NodeState node_state(Handle h) { + if (h.is_large_) { + return NodeStateForStruct(Large(h)->load(std::memory_order_relaxed)); + } else { + return NodeStateForStruct(Packed(h)->load(std::memory_order_relaxed)); + } + } + void mark_started(Handle h) { + DCHECK_EQ(pending(h), 0); + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 0); + c.has_started = 1; + c_ptr->store(c, std::memory_order_relaxed); + } else { + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 0); + c.has_started = 1; + c_ptr->store(c, std::memory_order_relaxed); + } + } + void mark_completed(Handle h) { + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 1); + c.pending = 1; + c_ptr->store(c, std::memory_order_relaxed); + } else { + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + DCHECK_EQ(c.has_started, 1); + c.pending = 1; + c_ptr->store(c, std::memory_order_relaxed); + } + } + int pending(Handle h) { + if (h.is_large_) { + LargeCounts c = Large(h)->load(std::memory_order_relaxed); + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + return c.pending; + } else { + // The pending count encodes the state once the node has + // started, so just return 0. 
+ return 0; + } + } else { + PackedCounts c = Packed(h)->load(std::memory_order_relaxed); + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + return c.pending; + } else { + // The pending count encodes the state once the node has + // started, so just return 0. + return 0; + } + } + } + struct AdjustResult { + int dead_count; + int pending_count; + + AdjustResult(int dead_count, int pending_count) + : dead_count(dead_count), pending_count(pending_count) {} + }; + int decrement_pending(Handle h, int v) { + DCHECK_GE(pending(h), v); + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending -= v; + c_ptr->store(c, std::memory_order_relaxed); + return c.pending; + } else { + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + c.pending -= v; + c_ptr->store(c, std::memory_order_relaxed); + return c.pending; + } + } + + // Mark a merge node as live + // REQUIRES: Node corresponding to "h" is a merge node + void mark_live(Handle h) { + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + // Only do anything if the node hasn't already started executing. + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + c.pending &= ~static_cast(0x1); + c_ptr->store(c, std::memory_order_relaxed); + } + } else { + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + // Only do anything if the node hasn't already started executing. + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + static_assert(7 == kMaxCountForPackedCounts, + "Live flag incorrect for max packed count"); + c.pending &= 0x6; + c_ptr->store(c, std::memory_order_relaxed); + } + } + } + + int dead_count(Handle h) { + int r = h.is_large_ ? Large(h)->load(std::memory_order_relaxed).dead_count + : Packed(h)->load(std::memory_order_relaxed).dead_count; + return r; + } + void increment_dead_count(Handle h) { + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + c.dead_count++; + c_ptr->store(c, std::memory_order_relaxed); + } + } else { + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + DCHECK_LT(c.dead_count, kMaxCountForPackedCounts); + c.dead_count++; + c_ptr->store(c, std::memory_order_relaxed); + } + } + } + + // Mark a merge node as live. Please note that the pending count it returns + // is before the update. + AdjustResult adjust_for_mark_live(Handle h) { + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto c = c_ptr->load(std::memory_order_relaxed); + auto ret_pending = 0; + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + ret_pending = c.pending; + c.pending &= ~static_cast(0x1); + c_ptr->store(c, std::memory_order_relaxed); + } + return AdjustResult(c.dead_count, ret_pending); + } else { + std::atomic* c_ptr = Packed(h); + auto c = c_ptr->load(std::memory_order_relaxed); + auto ret_pending = 0; + if (PENDING_NOTREADY == NodeStateForStruct(c)) { + static_assert(7 == kMaxCountForPackedCounts, + "Live flag incorrect for max packed count"); + ret_pending = c.pending; + c.pending &= 0x6; + c_ptr->store(c, std::memory_order_relaxed); + } + return AdjustResult(c.dead_count, ret_pending); + } + } + + // The same as the above, but performs the operation atomically. This + // is thread-safe to run concurrently with other threads. 
+ AdjustResult adjust_for_mark_live_atomic(Handle h) { + if (h.is_large_) { + std::atomic* c_ptr = Large(h); + auto old_val = c_ptr->load(std::memory_order_relaxed); + while (true) { + auto new_val = old_val; + auto ret_pending = 0; + // Only do anything if the node hasn't already started executing. + if (PENDING_NOTREADY == NodeStateForStruct(new_val)) { + ret_pending = old_val.pending; + new_val.pending &= ~static_cast(0x1); + } + AdjustResult ret(old_val.dead_count, ret_pending); + if (TF_PREDICT_TRUE(c_ptr->compare_exchange_weak(old_val, new_val))) + return ret; + } + } else { + std::atomic* c_ptr = Packed(h); + auto old_val = c_ptr->load(std::memory_order_relaxed); + while (true) { + auto new_val = old_val; + auto ret_pending = 0; + // Only do anything if the node hasn't already started executing. + if (PENDING_NOTREADY == NodeStateForStruct(new_val)) { + static_assert(7 == kMaxCountForPackedCounts, + "Live flag incorrect for max packed count"); + ret_pending = old_val.pending; + new_val.pending &= 0x6; + } + AdjustResult ret(old_val.dead_count, ret_pending); + if (TF_PREDICT_TRUE(c_ptr->compare_exchange_weak(old_val, new_val))) + return ret; + } + } + } + + // A streamlined routine that does several pieces of bookkeeping at + // once. Equivalent to: + // increment_dead_count(h); + // return {dead_count(h) pending(h)}; + AdjustResult adjust_for_increment_dead(Handle h) { + if (h.is_large_) { + return adjust_for_increment_dead_shared(Large(h)); + } else { + return adjust_for_increment_dead_shared(Packed(h)); + } + } + + // The same as the above, but performs the operation atomically. This + // is thread-safe to run concurrently with other threads. + AdjustResult adjust_for_increment_dead_atomic(Handle h) { + if (h.is_large_) { + return adjust_for_increment_dead_shared_atomic(Large(h)); + } else { + return adjust_for_increment_dead_shared_atomic(Packed(h)); + } + } + + // A streamlined routine that does several pieces of bookkeeping at + // once. Equivalent to: + // decrement_pending(h, decrement_pending); + // return {dead_count(h) pending(h)}; + AdjustResult adjust_for_decrement_pending(Handle h, int decrement_pending) { + DCHECK_GE(pending(h), decrement_pending); + if (h.is_large_) { + return adjust_for_decrement_pending_shared(Large(h), decrement_pending); + } else { + return adjust_for_decrement_pending_shared(Packed(h), decrement_pending); + } + } + + // The same as the above, but performs the operation atomically. This + // is thread-safe to run concurrently with other threads. + AdjustResult adjust_for_decrement_pending_atomic(Handle h, + int decrement_pending) { + DCHECK_GE(pending(h), decrement_pending); + if (h.is_large_) { + return adjust_for_decrement_pending_shared_atomic(Large(h), + decrement_pending); + } else { + return adjust_for_decrement_pending_shared_atomic(Packed(h), + decrement_pending); + } + } + + // A streamlined routine that does several pieces of bookkeeping at + // once. Equivalent to: + // if (increment_dead) increment_dead_count(h); + // decrement_pending(h, 1); + // return {dead_count(h), pending(h)}; + AdjustResult adjust_for_activation(Handle h, bool increment_dead) { + DCHECK_GE(pending(h), 1); + if (h.is_large_) { + return adjust_for_activation_shared(Large(h), increment_dead); + } else { + return adjust_for_activation_shared(Packed(h), increment_dead); + } + } + + // The same as the above, but performs the operation atomically. This + // is thread-safe to run concurrently with other threads. 
+ AdjustResult adjust_for_activation_atomic(Handle h, bool increment_dead) { + DCHECK_GE(pending(h), 1); + if (h.is_large_) { + return adjust_for_activation_shared_atomic(Large(h), increment_dead); + } else { + return adjust_for_activation_shared_atomic(Packed(h), increment_dead); + } + } + + class Handle { + public: + Handle() : byte_offset_(0), is_large_(0) {} + + private: + friend class PendingCounts; + int byte_offset_ : 31; // Byte offset of the rep in PendingCounts object + bool is_large_ : 1; // If true, rep is LargeCounts; otherwise PackedCounts + }; + + private: + template + inline AdjustResult adjust_for_increment_dead_shared(std::atomic* c) { + T val = c->load(std::memory_order_relaxed); + auto ret_pending = 0; + // Only do anything if the node hasn't already started executing. + if (PENDING_NOTREADY == NodeStateForStruct(val)) { + val.dead_count++; + ret_pending = val.pending; + c->store(val, std::memory_order_relaxed); + } + return AdjustResult(val.dead_count, ret_pending); + } + + template + inline AdjustResult adjust_for_increment_dead_shared_atomic( + std::atomic* c) { + T old_val = c->load(std::memory_order_relaxed); + while (true) { + auto new_val = old_val; + auto ret_pending = 0; + // Only do anything if the node hasn't already started executing. + if (PENDING_NOTREADY == NodeStateForStruct(new_val)) { + ret_pending = new_val.pending; + new_val.dead_count++; + } + AdjustResult ret(new_val.dead_count, ret_pending); + if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val))) + return ret; + } + } + + template + inline AdjustResult adjust_for_decrement_pending_shared( + std::atomic* c, int decrement_pending) { + T val = c->load(std::memory_order_relaxed); + DCHECK_GE(val.pending, decrement_pending); + val.pending -= decrement_pending; + c->store(val, std::memory_order_relaxed); + return AdjustResult(val.dead_count, val.pending); + } + + template + inline AdjustResult adjust_for_decrement_pending_shared_atomic( + std::atomic* c, int decrement_pending) { + T old_val = c->load(std::memory_order_relaxed); + while (true) { + T new_val = old_val; + DCHECK_GE(new_val.pending, decrement_pending); + new_val.pending -= decrement_pending; + AdjustResult ret(new_val.dead_count, new_val.pending); + if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val))) + return ret; + } + } + + template + inline AdjustResult adjust_for_activation_shared(std::atomic* c, + bool increment_dead) { + T val = c->load(std::memory_order_relaxed); + if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(val)) { + val.dead_count++; + } + DCHECK_GE(val.pending, 1); + val.pending--; + c->store(val, std::memory_order_relaxed); + return AdjustResult(val.dead_count, val.pending); + } + + template + inline AdjustResult adjust_for_activation_shared_atomic(std::atomic* c, + bool increment_dead) { + T old_val = c->load(std::memory_order_relaxed); + while (true) { + T new_val = old_val; + if (increment_dead && PENDING_NOTREADY == NodeStateForStruct(new_val)) { + new_val.dead_count++; + } + DCHECK_GE(new_val.pending, 1); + new_val.pending--; + AdjustResult ret(new_val.dead_count, new_val.pending); + if (TF_PREDICT_TRUE(c->compare_exchange_weak(old_val, new_val))) + return ret; + } + } + + // We keep track of the pending count and dead input count for each + // graph node. The representation used here is designed to be cache + // efficient for graphs with large numbers of nodes, where most + // nodes have relatively small maximum pending counts (e.g. 
for one + // LSTM model, 99% of 5000+ nodes had in-degrees of 3 or less). We + // use one byte to hold both the pending and dead count for a node + // where these together can fit in one byte, and we use a hash table + // to handle the rare node ids that need larger counts than this. + // Each frame in this subgraph has its own PendingCounts. + + // We use 3 bits each for dead_count and pending. + static constexpr int kMaxCountForPackedCounts = 7; + + // Most counts are small, so we pack a pending count and a dead + // count into 3 bits each, use 1 bit to indicate that the node has + // started computing. + struct PackedCounts { + uint8 pending : 3; + uint8 dead_count : 3; + uint8 has_started : 1; + }; + + // NOTE: alignas(8) is critical to implement efficient atomic + // on MSVC. + struct alignas(8) LargeCounts { + uint32 pending; + uint32 dead_count : 31; + // NOTE(tlipcon): MSVC won't pack this struct into 8 bytes unless + // all of the member types are uint32. + uint32 has_started : 1; + }; + + template + NodeState NodeStateForStruct(const T& c) const { + if (c.has_started) { + return (c.pending == 0) ? STARTED : COMPLETED; + } else { + return (c.pending == 0) ? PENDING_READY : PENDING_NOTREADY; + } + } + inline std::atomic* Large(Handle h) { + DCHECK(h.is_large_); + DCHECK_LE(h.byte_offset_ + sizeof(std::atomic), num_bytes_); + DCHECK_EQ(h.byte_offset_ % alignof(std::atomic), 0); + return reinterpret_cast*>(bytes_ + h.byte_offset_); + } + inline std::atomic* Packed(Handle h) { + DCHECK(!h.is_large_); + DCHECK_LE(h.byte_offset_ + sizeof(PackedCounts), num_bytes_); + return reinterpret_cast*>(bytes_ + + h.byte_offset_); + } + + const int num_bytes_; // Just for bounds checking in debug mode + char* bytes_; // Array of num_bytes_ bytes + + void operator=(const PendingCounts&) = delete; +}; + +inline PendingCounts::Handle PendingCounts::Layout::CreateHandle( + size_t max_pending_count, size_t max_dead_count) { + Handle result; + if ((max_pending_count > kMaxCountForPackedCounts) || + (max_dead_count > kMaxCountForPackedCounts)) { + constexpr int B = sizeof(std::atomic); + // Round byte offset to proper alignment + static_assert( + sizeof(std::atomic) >= alignof(std::atomic), + "std::atomic must be packed"); + int64_t offset = ((static_cast(next_offset_) + B - 1) / B) * B; + result.byte_offset_ = offset; + result.is_large_ = true; + next_offset_ = result.byte_offset_ + B; + } else { + result.byte_offset_ = next_offset_; + result.is_large_ = false; + static_assert(sizeof(std::atomic) == 1, + "std::atomic should be a single byte"); + next_offset_ += sizeof(std::atomic); + } + return result; +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PENDING_COUNTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/permuter.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/permuter.h new file mode 100644 index 00000000..57704dd1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/permuter.h @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
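Expanding the usage outline in the PendingCounts class comment into a compilable sketch (the comment's r.AddHandle corresponds to the declared Layout::CreateHandle); the node counts below are illustrative, and this is not part of the vendored header.

    // Sketch only: precompute handles with a Layout, then drive per-iteration counts.
    #include <vector>

    #include "tensorflow/core/common_runtime/pending_counts.h"

    void PendingCountsSketch() {
      using tensorflow::PendingCounts;

      constexpr int kNumNodes = 3;
      const int max_pending[kNumNodes] = {1, 2, 9};  // 9 forces a LargeCounts slot.
      const int max_dead[kNumNodes] = {0, 1, 0};

      PendingCounts::Layout layout;
      std::vector<PendingCounts::Handle> h(kNumNodes);
      for (int id = 0; id < kNumNodes; ++id) {
        h[id] = layout.CreateHandle(max_pending[id], max_dead[id]);
      }

      // One PendingCounts per frame iteration, all sharing the same layout.
      PendingCounts counts(layout);
      for (int id = 0; id < kNumNodes; ++id) {
        counts.set_initial_count(h[id], max_pending[id]);
      }
      // When an input of node 0 completes, its pending count drops toward zero.
      counts.decrement_pending(h[0], 1);
    }
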
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PERMUTER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PERMUTER_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class Device; + +// Implementation of collective permute. +// +// Permute takes +// - a list of devices participating in the collective +// - a permutation as a list of integers. +// - a tensor +// +// The list of devices replaces the need for group_key and group_size. The +// number of inputs only scales with the number of devices within one group. +// +// The integers in the permutation are based on indices of the list of devices. +// E.g. devices = {"GPU:0", "GPU:1"} and permutation = {1,0} means +// - devices[0] sends to devices[permutation[0]] and +// - devices[1] sends to devices[permutation[1]]. +// +// Each device sends exactly one tensor and receives exactly one tensor. +class Permuter : public CollectiveImplementationInterface { + public: + Permuter(); + ~Permuter() override = default; + + void Run(StatusCallback done) override; + + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override { + return absl::OkStatus(); + } + + // Initializes members of CollectiveContext not yet initialized, i.e. device + // and device_locality. Also saves the CollectiveContext in this object. + absl::Status InitializeCollectiveContext( + std::shared_ptr col_ctx) override; + + private: + std::shared_ptr col_ctx_; + const CollectiveParams* col_params_; // Not owned + StatusCallback done_; + mutex mu_; + absl::Status status_ TF_GUARDED_BY(mu_); + int counter_ TF_GUARDED_BY(mu_); + + void DispatchSend(int src_rank, int target_rank, const Tensor* tensor, + const StatusCallback& done); + + void DispatchRecv(int src_rank, int target_rank, Tensor* tensor, + const StatusCallback& done); + + // Atomically increments counter_ by one for sending, one for receiving. + // Invokes done when counter_ reaches 2. + // The purpose of checking counter_ is to ensure that done_ is called once. + StatusCallback CheckCounterAndCallDone(); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PERMUTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/placer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/placer.h new file mode 100644 index 00000000..d7b89fd3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/placer.h @@ -0,0 +1,112 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
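A standalone sketch of the permutation convention documented for Permuter (devices[i] sends to devices[permutation[i]]); the device names and ring permutation are illustrative, and this is not part of the vendored header.

    // Sketch only: resolving send targets under the documented permutation rule.
    #include <cstdio>
    #include <string>
    #include <vector>

    int main() {
      const std::vector<std::string> devices = {"GPU:0", "GPU:1", "GPU:2"};
      const std::vector<int> permutation = {1, 2, 0};  // A ring: 0->1, 1->2, 2->0.

      for (int i = 0; i < static_cast<int>(devices.size()); ++i) {
        // Each rank sends exactly one tensor and receives exactly one tensor.
        std::printf("%s sends to %s\n", devices[i].c_str(),
                    devices[permutation[i]].c_str());
      }
      return 0;
    }
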
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_ + +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +// A placement algorithm that assigns the nodes of the given Graph to +// devices the given DeviceSet, respecting the following constraints: +// +// 1. Existing device assignments remain unchanged. +// 2. Requested (partial or complete) device specifications given by device name +// for each node are granted. +// 3. Nodes connected by edges of a reference type are colocated on +// the same device. +// 4. Given nodes "A" and "B", if node "B" has a colocation group +// "@loc:A", nodes "A" and "B" will be colocated on the same device. +// +// The implementation builds a constraint graph with the same set of +// nodes, and edges that represent colocation constraints between +// nodes. Each connected component in the resulting constraint graph +// is then assigned to a set of valid devices. +// +// Run() will finally assign the device to each node given the list of +// possible devices. +// +// TODO(mrry): "Soft" constraints, such as "place node 'x' as close as +// possible to node 'y' while respecting the other constraints"? +// TODO(mrry): Create a common interface for this and the other +// placement algorithms so that they may be injected into the graph +// builder. +class Placer { + public: + // Creates an instance of the Placer algorithm for the given + // Graph "graph" (nodes in which may or may not be assigned) on the + // given DeviceSet "devices". + // "function_name" should be set to the name of the function whose body is + // represented by "graph". If "graph" is not representing a function body, + // "function_name" should be empty. + // + // If non-null, default_local_device is used where possible as a placement for + // nodes which do not have a device specified, ahead of other devices which + // would otherwise be higher priority. default_local_device should be on the + // local host so that its FLR is directly accessible by the current process. + // + // The "graph", "devices", and "default_local_device" pointer arguments are + // borrowed by this Placer, and must outlive it. + Placer(Graph* graph, const string& function_name, + const FunctionLibraryDefinition* flib_def, const DeviceSet* devices, + const Device* default_local_device, bool allow_soft_placement, + bool log_device_placement); + Placer(Graph* graph, const string& function_name, + const FunctionLibraryDefinition* flib_def, const DeviceSet* devices); + Placer(Graph* graph, const string& function_name, + const FunctionLibraryDefinition* flib_def, const DeviceSet* devices, + const Device* default_local_device); + + ~Placer(); + + // Assigns each node in this Placer's graph to a device in its + // set of devices. + // + // This method is not thread-safe. + // Run() may be invoked at most once. + absl::Status Run(); + absl::Status Run(const GraphOptimizationPassOptions& options); + + private: + // Returns true if the device type of 'candidate_device_name' is + // found in 'devices'. 
+ bool CanAssignToDevice(const string& candidate_device_name, + const std::vector& devices) const; + + Graph* const graph_; // Not owned. + const string function_name_; + const FunctionLibraryDefinition* const flib_def_; // Not owned. + const DeviceSet* const devices_; // Not owned. + const Device* default_local_device_; // Not owned. + const bool allow_soft_placement_; + const bool log_device_placement_; + + Placer(const Placer&) = delete; + void operator=(const Placer&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h new file mode 100644 index 00000000..4f8982d6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h @@ -0,0 +1,157 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ + +// Operations calling functions are becoming ubiquitous in TF 2.0. +// Examples include PartitionedCallOp, functional If/While, and Dataset ops. +// Such operations might require deep inspection - looking at the body of the +// called function - to place them and surrounding ops correctly. + +// This file contains some utilities for placer to correctly place such ops +// including: +// - PlacerInspectionRequiredOpChecker: A simple class with a single +// IsPlacerInspectionRequired method. +// - IsolatePlacerInspectionRequiredOps: This function adds Identity ops for +// each input/output of ops requiring placer inspection. It greatly simplifies +// the implementation of placing such ops. + +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// PlacerInspectionRequiredOpChecker allows one to check if Placer needs to +// look deeply into the op to place ops consuming the outputs correctly. +// +// It is a class instead of a standalone method because checking whether +// a function returns a resource takes non-trivial time and we cache the +// results. +class PlacerInspectionRequiredOpChecker { + public: + // Constructs a PlacerInspectionRequiredOpChecker for nodes of `graph`. + // The functions referenced by nodes in `graph` will be looked up in + // `flib_def` + PlacerInspectionRequiredOpChecker(const Graph* graph, + const FunctionLibraryDefinition* flib_def); + + // If `node` is considered a deep op, sets `*is_deep` to true and returns + // OkStatus(). If an error occurs, returns that error, and the value of + // `*is_deep` is undefined. 
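Usage sketch for Placer (not part of the vendored header), using its minimal constructor; graph, flib_def, and devices are assumed to come from the surrounding runtime and must outlive the Placer.

    // Sketch only: place every node of an already-built graph.
    #include "tensorflow/core/common_runtime/placer.h"

    absl::Status PlaceGraph(tensorflow::Graph* graph,
                            const tensorflow::FunctionLibraryDefinition* flib_def,
                            const tensorflow::DeviceSet* devices) {
      // Empty function name: this graph does not represent a function body.
      tensorflow::Placer placer(graph, /*function_name=*/"", flib_def, devices);
      // Run() may be invoked at most once per Placer instance.
      return placer.Run();
    }
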
+ // Currently, an op is considered deep, if it is a calling a function + // returning a resource. This definition is driven by Placer's need to + // look inside the op. + // REQUIRES: `node` is part of `graph` passed into constructor. + absl::Status IsPlacerInspectionRequired(const Node& node, bool* is_deep); + + private: + const Graph& graph_; + const FunctionLibraryDefinition& flib_def_; + // Indexed by the node id. + // If cache_[node_id] is empty, the deepness of the node with id `node_id` has + // not been computed yet. Else, it contains the value already computed. + std::vector> cache_; +}; + +// Extracts `fdef` and `func` from `flib_def` for the function identified +// in "f" attribute of `node`. +absl::Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def, + const Node& node, + core::RefCountPtr* fdef, + NameAttrList* func); + +// The "call" stack of functions. +// Useful for better error messages as well as for detecting recursion. +// Stores references to graph nodes. These references must outlive this. +class FunctionStack { + public: + explicit FunctionStack(const string& function_name); + + // `node_in_current_function` must outlive this. + FunctionStack Push(const Node* node_in_current_function, + const string& new_current_function) const; + + // Returns true iff this stack already includes `function_name`. + bool HasFunction(const string& function_name) const; + + const string& current_function_name() const { return current_function_name_; } + + // Format's this suitable for error interpolation that retrieves + // Python files and line numbers. + string FormatForError() const; + + private: + struct Frame { + Frame(const string& function, const Node* node) + : function_name(function), node(node) {} + + string function_name; + const Node* node; + }; + + // The function at the top of the stack. In other words, the function + // that is currently being inspected for placement. + string current_function_name_; + + // The stack of frames that got the placement to the current_function_name_. + // frames_[0].function_name is the top function that Placer was constructed + // with. frames_[0].function_name can be empty if placer was constructed with + // a nameless graph, not a function. frames_[0].node_name is a name of a node + // in frames_[0].function_name that required deep inspection (e.g. a + // PartitionedCallOp). The function that this node invoked is + // frames_[1].function_name, if frames_.size() > 1. Else, the function that + // this node invoked is current_function_name_. 
+ std::vector frames_; +}; + +// Adds Identities for each input and output of function-calling ops in `graph` +// +// For example, the following graph calling a function on inputs `a` and `b` +// and producing output `y` will be rewritten to include identities on all +// edges: +// +// a b +// | | +// v v +// f (PartitionedCallOp) +// | +// v +// y +// +// is transformed to +// +// a b +// | | +// a_f (Identity) b_f (Identity) +// | | +// v v +// f (PartitionedCallOp) +// | +// f_y (Identity) +// | +// v +// y +// +absl::Status IsolatePlacerInspectionRequiredOps( + const FunctionLibraryDefinition& flib_def, Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h new file mode 100644 index 00000000..bfcbc16d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h @@ -0,0 +1,122 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/stream_executor/stream_executor.h" +#include "tensorflow/core/common_runtime/device/device_event_mgr.h" +#include "tensorflow/core/common_runtime/device/device_id.h" +#include "tensorflow/core/common_runtime/device/device_id_manager.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" +#include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class PluggableDevice : public LocalDevice { + public: + PluggableDevice(const SessionOptions& options, const std::string& name, + const string& device_type, const string& platform_name, + Bytes memory_limit, const DeviceLocality& locality, + TfDeviceId tf_device_id, + const std::string& physical_device_desc, + Allocator* device_allocator, Allocator* cpu_allocator, + bool sync_every_op); + 
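Usage sketch for the placer-inspection utilities above (not part of the vendored header): isolate function-calling ops first, then query the checker node by node; the CheckDeepOps wrapper is hypothetical.

    // Sketch only: isolating placer-inspection-required ops, then checking nodes.
    #include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
    #include "tensorflow/core/platform/errors.h"  // TF_RETURN_IF_ERROR

    absl::Status CheckDeepOps(tensorflow::Graph* graph,
                              const tensorflow::FunctionLibraryDefinition& flib_def) {
      // Wrap inputs/outputs of function-calling ops in Identity nodes first.
      TF_RETURN_IF_ERROR(
          tensorflow::IsolatePlacerInspectionRequiredOps(flib_def, graph));

      tensorflow::PlacerInspectionRequiredOpChecker checker(graph, &flib_def);
      for (tensorflow::Node* node : graph->nodes()) {
        bool is_deep = false;
        TF_RETURN_IF_ERROR(checker.IsPlacerInspectionRequired(*node, &is_deep));
        if (is_deep) {
          // Placer will need to look inside the called function for this node.
        }
      }
      return absl::OkStatus();
    }
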
+ ~PluggableDevice() override; + + // Initialize the device and return the status of initialization. + absl::Status Init(const SessionOptions& options); + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + + absl::Status Sync() override; + + Allocator* GetAllocator(AllocatorAttributes attr) override; + + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, + const DeviceContext* device_context, + StatusCallback done) override; + + // The executor that provides control for the pluggable device; + se::StreamExecutor* executor() const { return executor_; } + + private: + Allocator* device_allocator_; + Allocator* cpu_allocator_; + + se::StreamExecutor* executor_ = nullptr; + struct StreamGroup { + se::Stream* compute = nullptr; + se::Stream* host_to_device = nullptr; + se::Stream* device_to_host = nullptr; + absl::InlinedVector device_to_device; + }; + + class StreamGroupFactory; + + StreamGroup* stream_; + PluggableDeviceContext* device_context_; + // TODO(penpornk): Investigate renaming `GpuDeviceInfo` to `DeviceInfo`. + DeviceBase::AcceleratorDeviceInfo* pluggable_device_info_ = nullptr; + TfDeviceId tf_device_id_; + const string platform_name_; + const bool sync_every_op_ = false; + EventMgr* em_ = nullptr; + std::unique_ptr thread_pool_; + bool force_gpu_compatible_ = false; + std::string ComputeOpKernelDebugString(const OpKernel& op_kernel, + int stream_id); + + // This method returns an initialization status, in addition to + // calling the "done" StatusCallback, if there is a failure to + // allocate memory or if the tensor "from" is not DMA-copyable. + // If there is no error prior to enqueueing the copy, an OK status + // is returned. + absl::Status MaybeCopyTensorToPluggableDevice( + const AllocatorAttributes& alloc_attrs, const Tensor& from, Tensor* to, + StatusCallback done); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h new file mode 100644 index 00000000..898e3834 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h @@ -0,0 +1,57 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_BFC_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_BFC_ALLOCATOR_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/bfc_allocator.h" +#include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +// A PluggableDevice memory allocator that implements a 'best-fit with +// coalescing' algorithm +class PluggableDeviceBFCAllocator : public BFCAllocator { + public: + PluggableDeviceBFCAllocator(DeviceMemAllocator* sub_allocator, + size_t total_memory, const string& name, + bool force_memory_growth_requested); + PluggableDeviceBFCAllocator(DeviceMemAllocator* sub_allocator, + size_t total_memory, + const GPUOptions& gpu_options, const string& name, + bool force_memory_growth_requested); + ~PluggableDeviceBFCAllocator() override = default; + + PluggableDeviceBFCAllocator(const PluggableDeviceBFCAllocator&) = delete; + void operator=(const PluggableDeviceBFCAllocator&) = delete; + + private: + static bool GetAllowGrowthValue(const GPUOptions& gpu_options, + bool force_memory_growth_requested); + static bool GetGarbageCollectionValue(); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_BFC_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h new file mode 100644 index 00000000..596341fd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h @@ -0,0 +1,93 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_CONTEXT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_CONTEXT_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace stream_executor { +class Stream; +} // namespace stream_executor + +namespace tensorflow { + +class PluggableDeviceContext : public DeviceContext { + public: + // Does not take ownership of streams. 
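+  // Construction sketch (the stream pointers are assumed to come from the
+  // owning PluggableDevice's stream group; the stream_id of 0 is an
+  // illustrative choice):
+  //
+  //   auto* ctx = new PluggableDeviceContext(
+  //       /*stream_id=*/0, compute_stream, host_to_device_stream,
+  //       device_to_host_stream, device_to_device_streams);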
+ PluggableDeviceContext( + int stream_id, se::Stream* stream, se::Stream* host_to_device_stream, + se::Stream* device_to_host_stream, + absl::InlinedVector device_to_device_stream) + : stream_id_(stream_id), + stream_(stream), + host_to_device_stream_(host_to_device_stream), + device_to_host_stream_(device_to_host_stream), + device_to_device_stream_(device_to_device_stream) {} + + ~PluggableDeviceContext() override = default; + + se::Stream* stream() const override { return stream_; } + se::Stream* host_to_device_stream() const { return host_to_device_stream_; } + se::Stream* device_to_host_stream() const { return device_to_host_stream_; } + se::Stream* device_to_device_stream(int index) const { + return device_to_device_stream_[index % device_to_device_stream_.size()]; + } + int stream_id() const { return stream_id_; } + + void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute) const override; + + void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, Device* device, + Tensor* cpu_tensor, StatusCallback done) override; + + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; + + void MaintainLifetimeOnStream(const Tensor* t, + se::Stream* stream) const override {} + + absl::Status ThenExecute(Device* device, se::Stream* stream, + std::function func) override; + + bool IsPluggableDevice() override; + + private: + int stream_id_; + // The default primary stream to use for this context. + // All the memory belongs to this stream. + se::Stream* stream_; + // The stream to use for copying data from host into PluggableDevice. + se::Stream* host_to_device_stream_; + // The stream to use for copying data from PluggableDevice to host. + se::Stream* device_to_host_stream_; + // Streams to use for copying data between PluggableDevices. + absl::InlinedVector device_to_device_stream_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h new file mode 100644 index 00000000..3f6ab10f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h @@ -0,0 +1,66 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_FACTORY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_FACTORY_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/device/device_id.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +class PluggableDeviceFactory : public DeviceFactory { + public: + PluggableDeviceFactory(const string& device_type, + const string& platform_name); + absl::Status ListPhysicalDevices(std::vector* devices) override; + absl::Status CreateDevices( + const SessionOptions& options, const std::string& name_prefix, + std::vector>* devices) override; + absl::Status GetDeviceDetails( + int device_index, std::unordered_map* details) override; + + private: + // Populates *device_localities with the DeviceLocality descriptor for + // every TfDeviceId. + absl::Status GetDeviceLocalities( + int num_tf_devices, std::vector* device_localities); + // Create a PluggableDevice associated with 'tf_device_id', allocates + // (strictly) 'memory_limit' bytes of PluggableDevice memory to it, and adds + // it to the 'devices' vector. + absl::Status CreatePluggableDevice( + const SessionOptions& options, const std::string& name_prefix, + TfDeviceId tf_device_id, int64_t memory_limit, + const DeviceLocality& dev_locality, + std::vector>* devices); + + const string device_type_; + const string platform_name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h new file mode 100644 index 00000000..b77917d1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_INIT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_INIT_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace stream_executor { +class Platform; +} // namespace stream_executor + +namespace tensorflow { + +// Initializes the PluggableDevice platform and returns OK if the +// PluggableDevice platform could be initialized. +absl::Status ValidatePluggableDeviceMachineManager(const string& platform_name); + +// Returns the PluggableDevice machine manager singleton, creating it and +// initializing the PluggableDevices on the machine if needed the first time it +// is called. Must only be called when there is a valid PluggableDevice +// environment in the process (e.g., ValidatePluggableDeviceMachineManager() +// returns OK). +stream_executor::Platform* PluggableDeviceMachineManager( + const string& platform_name); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_INIT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h new file mode 100644 index 00000000..9676a706 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PLUGIN_INIT_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PLUGIN_INIT_H_ + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +absl::Status RegisterPluggableDevicePlugin(void* library_filename); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PLUGIN_INIT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h new file mode 100644 index 00000000..0c396588 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h @@ -0,0 +1,128 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device/device_id.h" +#include "tensorflow/core/common_runtime/process_state.h" +#include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class PluggableDeviceBFCAllocator; +class PluggableDeviceSimpleAllocator; +class PoolAllocator; + +// Singleton that manages per-process state when PluggableDevices are present. +class PluggableDeviceProcessState { + public: + // Singleton that manages each platform's per-process state. e.g. allocation + // of shared resource. + static PluggableDeviceProcessState* singleton(const string& device_type, + const string& platform_name); + + // Query whether any PluggableDevice has been created so far. + // Disable thread safety analysis since a race is benign here. + bool HasPluggableDevice() const TF_NO_THREAD_SAFETY_ANALYSIS { + return pluggable_device_enabled_; + } + + // Set the flag to indicate a PluggableDevice has been created. + // Disable thread safety analysis since a race is benign here. + void EnablePluggableDevice() TF_NO_THREAD_SAFETY_ANALYSIS { + pluggable_device_enabled_ = true; + } + + // Returns the one PluggableDevice allocator used for the indexed + // PluggableDevice. Note that this is a system PluggableDevice index. + // + // 'total_bytes' is the total number of bytes that should be made + // available to the allocator. The first call to this function for + // a given tf_device_id creates the allocator, so only the + // total_bytes used on that first call is used. + // + // 'allocator_type' describes the type of algorithm to use for the + // underlying allocator. REQUIRES: Must be a valid type (see + // config.proto for the list of supported strings.). + // + // REQUIRES: tf_device_id must be a valid id for a PluggableDevice + // available in the current system environment. Otherwise returns nullptr. + virtual Allocator* GetPluggableDeviceAllocator(const GPUOptions& options, + TfDeviceId tf_device_id, + size_t total_bytes); + + int NumPluggableDeviceAllocators() { + mutex_lock l(mu_); + return pluggable_device_allocators_.size(); + } + + virtual Allocator* GetPluggableDeviceHostAllocator(int numa_node); + + // Returns bus_id for the given PluggableDevice id. + virtual int BusIdForPluggableDevice(TfDeviceId tf_device_id); + + protected: + // PluggableDeviceProcessState is a singleton that should not normally be + // deleted except at process shutdown. 
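+  // Typical lookup sketch (the device/platform names and the 1 GiB byte
+  // budget are illustrative assumptions):
+  //
+  //   auto* state = PluggableDeviceProcessState::singleton("MY_DEVICE",
+  //                                                        "MY_PLATFORM");
+  //   Allocator* allocator = state->GetPluggableDeviceAllocator(
+  //       GPUOptions(), TfDeviceId(0), /*total_bytes=*/1ULL << 30);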
+ PluggableDeviceProcessState(const string& device_type, + const string& platform_name); + virtual ~PluggableDeviceProcessState() = default; + + ProcessState::MDMap* mem_desc_map() { + if (process_state_) return &process_state_->mem_desc_map_; + return nullptr; + } + + static PluggableDeviceProcessState* instance_; + ProcessState* process_state_; // Not owned. + bool pluggable_device_enabled_; + const string device_type_; + const string platform_name_; + mutex mu_; + + struct AllocatorParts { + std::unique_ptr allocator; + Allocator* device_allocator; + SubAllocator* sub_allocator; // owned by allocator + }; + + std::vector pluggable_device_allocators_ TF_GUARDED_BY(mu_); + std::vector> pluggable_device_visitors_ + TF_GUARDED_BY(mu_); + + std::vector pluggable_device_host_allocators_ + TF_GUARDED_BY(mu_); + std::vector> + pluggable_device_host_alloc_visitors_ TF_GUARDED_BY(mu_); + std::vector> + pluggable_device_host_free_visitors_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h new file mode 100644 index 00000000..7cddbfb6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h @@ -0,0 +1,58 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_SIMPLE_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_SIMPLE_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class PluggableDeviceSimpleAllocator : public Allocator { + public: + explicit PluggableDeviceSimpleAllocator(DeviceMemAllocator* sub_allocator); + ~PluggableDeviceSimpleAllocator() override = default; + + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + + bool TracksAllocationSizes() const override { return false; } + string Name() override { return "Simple allocator"; } + std::optional GetStats() override; + + AllocatorMemoryType GetMemoryType() const override { + return sub_allocator_->GetMemoryType(); + } + + private: + PluggableDeviceSimpleAllocator(const PluggableDeviceSimpleAllocator&) = + delete; + void operator=(const PluggableDeviceSimpleAllocator&) = delete; + std::unique_ptr sub_allocator_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_SIMPLE_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h new file mode 100644 index 00000000..7d5f1e2a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h @@ -0,0 +1,76 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_UTIL_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { + +class RecvTensorResponse; +class TensorProto; + +class PluggableDeviceUtil { + public: + // Copies the data in 'device_tensor' into 'cpu_tensor'. + // 'device_tensor''s backing memory must be on 'device' and + // 'cpu_tensor' must be allocated to be of the same size as + // 'device_tensor'. 
Synchronous: may block. + static void CopyPluggableDeviceTensorToCPU( + Device* device, const DeviceContext* device_context, + const Tensor* device_tensor, Tensor* cpu_tensor, StatusCallback done); + // Blocks until all operations queued on the stream associated with + // 'device' at the time of the call have completed. Returns any + // error pending on the stream at completion. + static absl::Status Sync(Device* device); + + // Blocks until all operations queued on all streams associated with the + // corresponding 'device' at the time of call have completed. + // Returns any error pending on the stream at completion. + static absl::Status SyncAll(Device* device); + + static void CopyCPUTensorToPluggableDevice( + const Tensor* cpu_tensor, const DeviceContext* device_context, + Device* device, Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute); + + static void DeviceToDeviceCopy( + DeviceContext* send_dev_context, DeviceContext* recv_dev_context, + Device* src, Device* dst, AllocatorAttributes src_alloc_attr, + AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output, + int dev_to_dev_stream_index, StatusCallback done); + + // Deep-copying of PluggableDevice tensor on the same device. + // 'src_device_tensor''s and 'dst_device_tensor''s backing memory must be on + // 'device' and 'dst_cpu_tensor' must be allocated to be of the same + // size as 'src_device_tensor'. + static void CopyPluggableDeviceTensorToSameDevice( + Device* device, const DeviceContext* device_context, + const Tensor* src_device_tensor, Tensor* dst_device_tensor, + StatusCallback done); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/pool_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pool_allocator.h new file mode 100644 index 00000000..6ce3b788 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/pool_allocator.h @@ -0,0 +1,181 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_ + +// Simple LRU pool allocators for various flavors of CPU RAM. + +#include +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Interface of an object that rounds up integers. +class RoundUpInterface { + public: + virtual ~RoundUpInterface() {} + virtual size_t RoundUp(size_t num_bytes) = 0; +}; + +// Size-limited pool of memory buffers obtained from a SubAllocator +// instance. Pool eviction policy is LRU. 
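+// Assembly sketch using the BasicCPUAllocator and Pow2Rounder declared below
+// (the pool size limit of 100 and the pool name are illustrative choices):
+//
+//   PoolAllocator pool(/*pool_size_limit=*/100, /*auto_resize=*/true,
+//                      new BasicCPUAllocator(/*numa_node=*/0, {}, {}),
+//                      new Pow2Rounder, "cpu_pool");
+//   void* buf = pool.AllocateRaw(/*alignment=*/64, /*num_bytes=*/1 << 20);
+//   pool.DeallocateRaw(buf);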
+class PoolAllocator : public Allocator { + public: + // "pool_size_limit" is the maximum number of returned, re-usable + // memory buffers to keep in the pool. If pool_size_limit == 0, the + // pool is effectively a thin wrapper around the allocator. + // If "auto_resize" is true, then the pool_size_limit will gradually + // be raised so that deallocations happen very rarely, if at all. + // Transitory start-up objects may deallocate, but the long-term + // working-set should not. Auto-resizing can raise pool_size_limit + // but will never lower it. + // "allocator" is the object that performs the underlying memory + // malloc/free operations. This object takes ownership of allocator. + PoolAllocator(size_t pool_size_limit, bool auto_resize, + SubAllocator* allocator, RoundUpInterface* size_rounder, + string name); + ~PoolAllocator() override; + + string Name() override { return name_; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + + void DeallocateRaw(void* ptr) override; + + // Allocate an unused memory region of size "num_bytes". Fetch from + // the pool if available, otherwise call allocator_. + void* Get(size_t num_bytes); + + // Return a no-longer needed memory region to the pool. It is an error + // to deference "ptr" after this call. If the pool is full, the least + // recently used region will be deallocated. + void Put(void* ptr, size_t num_bytes); + + // Reset the pool to empty. + void Clear(); + + // The following accessors permit monitoring the effectiveness of + // the pool at avoiding repeated malloc/frees on the underlying + // allocator. Read locks are not taken on the theory that value + // consistency with other threads is not important. + + // Number of Get() requests satisfied from pool. + int64_t get_from_pool_count() const TF_NO_THREAD_SAFETY_ANALYSIS { + return get_from_pool_count_; + } + // Number of Put() requests. + int64_t put_count() const TF_NO_THREAD_SAFETY_ANALYSIS { return put_count_; } + // Number of Get() requests requiring a fresh allocation. + int64_t allocated_count() const TF_NO_THREAD_SAFETY_ANALYSIS { + return allocated_count_; + } + // Number of pool evictions. + int64_t evicted_count() const TF_NO_THREAD_SAFETY_ANALYSIS { + return evicted_count_; + } + // Current size limit. + size_t size_limit() const TF_NO_THREAD_SAFETY_ANALYSIS { + return pool_size_limit_; + } + + AllocatorMemoryType GetMemoryType() const override { + return allocator_->GetMemoryType(); + } + + private: + struct PtrRecord { + void* ptr; + size_t num_bytes; + PtrRecord* prev; + PtrRecord* next; + }; + + // Remove "pr" from the double-linked LRU list. + void RemoveFromList(PtrRecord* pr) TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Add "pr" to the head of the double-linked LRU list. + void AddToList(PtrRecord* pr) TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Delete the least recently used record. + void EvictOne() TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + const string name_; + const bool has_size_limit_; + const bool auto_resize_; + size_t pool_size_limit_; + std::unique_ptr allocator_; + std::unique_ptr size_rounder_; + mutex mutex_; + std::multimap pool_ TF_GUARDED_BY(mutex_); + PtrRecord* lru_head_ TF_GUARDED_BY(mutex_) = nullptr; + PtrRecord* lru_tail_ TF_GUARDED_BY(mutex_) = nullptr; + int64_t get_from_pool_count_ TF_GUARDED_BY(mutex_) = 0; + int64_t put_count_ TF_GUARDED_BY(mutex_) = 0; + int64_t allocated_count_ TF_GUARDED_BY(mutex_) = 0; + int64_t evicted_count_ TF_GUARDED_BY(mutex_) = 0; +}; + +// Do-nothing rounder. Passes through sizes unchanged. 
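+// (For contrast, the Pow2Rounder below maps a 600-byte request up to 1024
+// bytes, the next power of two.)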
+class NoopRounder : public RoundUpInterface { + public: + size_t RoundUp(size_t num_bytes) override { return num_bytes; } +}; + +// Power of 2 rounder: rounds up to nearest power of 2 size. +class Pow2Rounder : public RoundUpInterface { + public: + size_t RoundUp(size_t num_bytes) override { + return 1uLL << Log2Ceiling64(num_bytes); + } +}; + +class BasicCPUAllocator : public SubAllocator { + public: + BasicCPUAllocator(int numa_node, const std::vector& alloc_visitors, + const std::vector& free_visitors) + : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {} + + ~BasicCPUAllocator() override {} + + void* Alloc(size_t alignment, size_t num_bytes, + size_t* bytes_received) override; + + void Free(void* ptr, size_t num_bytes) override; + + bool SupportsCoalescing() const override { return false; } + + AllocatorMemoryType GetMemoryType() const override { + return AllocatorMemoryType::kHostPageable; + } + + private: + int numa_node_; + + BasicCPUAllocator(const BasicCPUAllocator&) = delete; + void operator=(const BasicCPUAllocator&) = delete; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_function_library_runtime.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_function_library_runtime.h new file mode 100644 index 00000000..0b3b9dc0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -0,0 +1,545 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/composite_device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/stats_publisher_interface.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tsl/platform/thread_annotations.h" + +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" +#endif // !IS_MOBILE_PLATFORM + +namespace tensorflow { + +class FunctionArgsInterface { + public: + virtual ~FunctionArgsInterface() {} + + virtual bool HasRemoteOrPackedInputs() const = 0; + + virtual absl::Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const = 0; + + virtual std::vector GetLocalTensors() const = 0; + +#if !defined(IS_MOBILE_PLATFORM) + virtual absl::Status GetRemoteArg(const FunctionArgIndex& index, + eager::RemoteTensorHandle* val) const { + return errors::Unimplemented( + "Serializing a remote argument is not implemented."); + } +#endif // IS_MOBILE_PLATFORM +}; + +// A class that stores all the FunctionLibraryRuntime objects, one per device. +class ProcessFunctionLibraryRuntime { + public: + // Creates FunctionLibraryRuntime objects for each device in the provided + // DeviceMgr. Caller needs to make sure that device_mgr, lib_def and parent + // (if provided) outlive this object. + ProcessFunctionLibraryRuntime( + const DeviceMgr* device_mgr, Env* env, const ConfigProto* config, + int graph_def_version, const FunctionLibraryDefinition* lib_def, + const OptimizerOptions& optimizer_options, + thread::ThreadPool* thread_pool = nullptr, + DistributedFunctionLibraryRuntime* parent = nullptr, + const SessionMetadata* session_metadata = nullptr, + Rendezvous::Factory rendezvous_factory = Rendezvous::Factory(), + StatsPublisherFactory stats_publisher_factory = CreateNoOpStatsPublisher); + + ~ProcessFunctionLibraryRuntime() { + // Deleting the FunctionLibraryRuntime map will delete the function handles + // registered in it, which may call ReleaseHandle in this class again to + // release their sub-function. These circular calls may cause segfault + // since the flr_map_ may have already been deleted. Explicitly releasing + // flr_map_ here and checking flr_map_ in ReleaseHandle to avoid this. + flr_map_.reset(); + } + + // Sends `tensors_to_send` from `source_device` to `target_device` using + // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the + // Rendezvous. `device_context` should be the DeviceContext of the device + // doing the sending. `alloc_attrs` should either be empty or be the size of + // `tensors_to_send` and indicates how the input tensors are allocated. Method + // takes references on each of the `tensors_to_send`. Method doesn't block. 
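+  // Send-side sketch (the device names, key prefix, and the variables
+  // `tensors_to_send`, `src_incarnation`, `device_context`, `alloc_attrs`,
+  // and `rendezvous` are illustrative assumptions already in scope):
+  //
+  //   TF_RETURN_IF_ERROR(ProcessFunctionLibraryRuntime::SendTensors(
+  //       "/device:CPU:0", "/device:MY_DEVICE:0", "edge_1", src_incarnation,
+  //       tensors_to_send, device_context, alloc_attrs, rendezvous));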
+ static absl::Status SendTensors( + const string& source_device, const string& target_device, + const string& key_prefix, int64_t src_incarnation, + absl::Span tensors_to_send, DeviceContext* device_context, + const std::vector& alloc_attrs, + RendezvousInterface* rendezvous); + + // Receives `received_tensors` from `target_device` (originally sent from + // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the + // keys to be retrieved. `device_context` should be for the device receiving + // the tensors. `alloc_attrs` indicates how to allocate the received + // tensors and should either be empty or `num_tensors` in size. Method doesn't + // block and calls `done` when `num_tensors` are fetched. + static void ReceiveTensorsAsync( + const string& source_device, const string& target_device, + const string& key_prefix, int64_t src_incarnation, int64_t num_tensors, + DeviceContext* device_context, + const std::vector& alloc_attrs, + RendezvousInterface* rendezvous, std::vector* received_tensors, + StatusCallback done); + + static const char kDefaultFLRDevice[]; + // Returns the FunctionLibraryRuntime for the corresponding device_name. + FunctionLibraryRuntime* GetFLR(const string& device_name) const; + + // Returns the return types for the function identified by handle `h`. + absl::Status GetRetTypes(FunctionLibraryRuntime::Handle h, + DataTypeVector* ret_types); + + // Returns the device incarnation for the given device_name. + absl::Status GetDeviceIncarnation(const string& device_name, + int64_t* incarnation) const; + + // For a given canonicalized key signature of the function instantiated + // on device `device_name` and a `local_handle`, creates a handle and returns + // that value. Uses core/common_runtime/framework/function.h::Canonicalize + // to canonicalize the function signature. + FunctionLibraryRuntime::Handle AddHandle( + const string& function_key, const string& device_name, + FunctionLibraryRuntime::LocalHandle local_handle); + + // Returns a handle if found for the given key, else returns kInvalidHandle. + FunctionLibraryRuntime::Handle GetHandle(const string& function_key) const; + + // For the given handle instantiated on device `device_name` returns the local + // index of instantiation of that function. If the function was not + // instantiated on `device_name` or the function is multi-device, + // returns kInvalidLocalHandle. + // + // If `include_multi_device` is true and `handle` is a multi-device function + // with a single component that is placed on `device_name`, then this method + // will return the local handle for that component. + FunctionLibraryRuntime::LocalHandle GetHandleOnDevice( + const string& device_name, FunctionLibraryRuntime::Handle handle, + bool include_multi_device = false) const; + + // Fills `output_devices` with the devices on which the results will + // be produced. If some output is produced on CPU, the corresponding Device* + // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device* + // is set to the device backing the resource. + // REQUIRES: `handle` identifies a multi-device function. + absl::Status GetOutputDevices(FunctionLibraryRuntime::Handle handle, + std::vector* output_devices) const; + + // Instantiates the function. See framework/function.h for more details. + // Allows for function_name to be instantiated on different devices + // as specified in attrs. 
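+  // End-to-end lifecycle sketch (the function name "MyFn", the empty
+  // options, and the in-scope `args`/`rets` variables are illustrative
+  // assumptions):
+  //
+  //   FunctionLibraryRuntime::Handle handle;
+  //   TF_RETURN_IF_ERROR(pflr->Instantiate(
+  //       "MyFn", AttrSlice(), FunctionLibraryRuntime::InstantiateOptions(),
+  //       &handle));
+  //   TF_RETURN_IF_ERROR(pflr->RunSync(FunctionLibraryRuntime::Options(),
+  //                                    handle, args, &rets));
+  //   TF_RETURN_IF_ERROR(pflr->ReleaseHandle(handle));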
+ absl::Status Instantiate( + const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::Handle* handle); + + // Returns whether the function represented by the given handle needs to + // execute cross process. + absl::Status IsCrossProcess(FunctionLibraryRuntime::Handle handle, + bool* is_cross_process) const; + + // Delegates to the local FLR that owns state corresponding to `handle` and + // tells it to release it. If the `handle` isn't needed at all, the local FLR + // might call RemoveHandle on this to get rid of the state owned by the Proc + // FLR. + // For multi-device functions, calls ReleaseHandle on local FLRs for each + // component function that is part of this multi-device function. + // Each local FLR might call RemoveHandle on this. + absl::Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); + + // Runs the function with given `handle`. Function could have been + // instantiated on any device. More details in framework/function.h + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, absl::Span args, + std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) const; + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame, + FunctionLibraryRuntime::DoneCallback done) const; + + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + const FunctionArgsInterface& args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) const; + + absl::Status RunSync(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + absl::Span args, + std::vector* rets) const; + absl::Status RunSync(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + CallFrameInterface* frame) const; + + const DeviceMgr* device_mgr() { return device_mgr_; } + + const std::shared_ptr device_set() const { + tf_shared_lock l(mu_); + return device_set_; + } + + // Initialize the set of local and remote devices and corresponding flr for op + // device selection. + void InitializeDeviceAndFlr(); + + const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; } + + const FunctionLibraryDefinition* GetFunctionLibraryDefinition() const { + return lib_def_; + } + + // Add a CompositeDevice to `device_set_` + void AddCompositeDevice(CompositeDevice* d) TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + device_set_->AddDevice(d); + composite_devices_.push_back(d); + } + + protected: + friend class FunctionLibraryRuntimeImpl; + + struct InternalArgs { + std::vector args; +#if !defined(IS_MOBILE_PLATFORM) + // Holds the RemoteTensorHandles referred by args. + std::vector> remote_args; +#endif // IS_MOBILE_PLATFORM + }; + + // Structure detailing the asynchronous assumptions of a component function, + // such as whether it can support synchronous execution and any information + // needed to execute in proper order to resolve inter-subgraph dependencies. 
+ class AsyncAttributes { + public: + enum Summary { kSafeForSync = 0, kSendOnly, kRecvOnly, kAsyncRequired }; + + AsyncAttributes() + : allow_control_flow_sync_execution_(false), summary_(kSafeForSync) {} + explicit AsyncAttributes(const Graph* graph, + bool allow_control_flow_sync_execution) + : allow_control_flow_sync_execution_(allow_control_flow_sync_execution), + summary_(Summarize(graph)) {} + Summary summary() const { return summary_; } + bool allow_control_flow_sync_execution() const { + return allow_control_flow_sync_execution_; + } + + private: + // This data member should be initialized before the summary_. + bool allow_control_flow_sync_execution_; + Summary summary_; + Summary Summarize(const Graph* graph); + }; + + // Structure to keep track of how a component function (a single-device + // piece of a multi-device function) fits into the multi-device function. + struct ComponentFunctionData { + // The handle for the instantiated component function. + FunctionLibraryRuntime::Handle handle; + // The name for the component function. + string name; + // arg_indices.size() is the number of arguments to the component function. + // The i-th argument of the component function comes from the + // `arg_indices[i]`-th argument of the multi-device function. + std::vector arg_indices; + // ret_indices.size() is the number of return values of the component + // function. The i-th return value of the component function goes to the + // `ret_indices[i]`-th return value of the multi-device function. + std::vector ret_indices; + // arg_alloc_attrs[i] are the allocator attributes of the i-th argument to + // the component function. + std::vector arg_alloc_attrs; + // ret_alloc_attrs[i] are the allocator attributes of the i-th return value + // of the component function. + std::vector ret_alloc_attrs; + + AsyncAttributes async_attributes; + }; + + // Data structure holding information for a single instantiated multi-device + // function. + // The fields are filled in during instantiation. Once the object is + // added to mdevice_data_, all fields are constant. + struct MultiDeviceFunctionData { + MultiDeviceFunctionData(const string& function_name, + const string& function_key, int num_outputs, + DataTypeVector ret_types) + : function_name_(function_name), + function_key_(function_key), + instantiation_counter_(1), + num_outputs_(num_outputs), + ret_types_(std::move(ret_types)), + is_cross_process_(false), + has_remote_outputs(false) {} + + const string function_name_; + const string function_key_; + uint64 instantiation_counter_; + // Stored here to resize the output tensor vector when function is run. + const int num_outputs_; + DataTypeVector ret_types_; + + // Indicates whether this function needs to execute cross process. + bool is_cross_process_; + // Indicates whether this function has remote outputs. + bool has_remote_outputs; + + // Indicates if running this function synchronously is both allowed + safe. + bool enable_sync_execution; + + // Maps the device name to the information about the component function + // be run on this device. + std::unordered_map glue_; + }; + + struct CleanUpItem { + string device; + uint64 step_id; + FunctionLibraryRuntime::Handle local_handle; + }; + + // If `handle` represents a multi-device function, returns the multi-device + // data associated with `handle`. Else, nullptr. 
+ MultiDeviceFunctionData* IsMultiDevice( + FunctionLibraryRuntime::Handle handle) const; + + DistributedFunctionLibraryRuntime* const parent_; + + private: + FunctionLibraryRuntime::Handle AddHandleLocked( + const string& function_key, const string& device_name, + FunctionLibraryRuntime::LocalHandle local_handle) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // For a given device_name, returns a DeviceContext for copying + // tensors to/from the device. + absl::Status GetDeviceContext(const string& device_name, + DeviceContext** device_context) const; + + // Looks up the information for the given `handle` and returns the name + // of the device where the function is registered. + string GetDeviceName(FunctionLibraryRuntime::Handle handle) const; + + // Removes handle from the state owned by this object. + absl::Status RemoveHandle(FunctionLibraryRuntime::Handle handle); + + // Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition + // (transferring ownership of both to the caller). Note that the + // ProcessFunctionLibraryRuntime borrows a pointer to the + // FunctionLibraryDefinition and so the FunctionLibraryDefinition should + // outlive the ProcessFunctionLibraryRuntime. + // + // The `skip_flib_def` argument controls whether the method should clone the + // FunctionLibraryDefinition (default behavior) or return an empty function + // library. The latter is used by tf.data, which manages + // FunctionLibraryDefinitions for its functions independently (and passes + // these into the FunctionLibraryRuntime through an overlay), to avoid linear + // runtime w.r.t. to number of functions in the current function library. + absl::Status Clone(Env* env, int graph_def_version, + const OptimizerOptions& optimizer_options, + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + bool skip_flib_def = false) const; + + absl::Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle); + + absl::Status InstantiateMultiDevice( + const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::Handle* handle); + + void InstantiateRemote( + const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::Handle* handle, + FunctionLibraryRuntime::DoneCallback done); + + FunctionLibraryRuntime::Handle AddMultiDeviceHandle( + const std::unique_ptr data, + const string& function_key); + + bool HasMultiDeviceHandle(FunctionLibraryRuntime::Handle handle) const; + + void RunInternal(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + absl::Span args, + std::vector* rets, + std::vector>* cleanup_items, + FunctionLibraryRuntime::DoneCallback done) const; + + absl::Status CreateRendezvous( + FunctionLibraryRuntime::Options& opts, + tsl::core::RefCountPtr* created_rendezvous) const; + + FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback( + std::vector>* items, + FunctionLibraryRuntime::DoneCallback done, + const FunctionLibraryRuntime::Options& opts, + tsl::core::RefCountPtr rendezvous) const; + + void CleanUp(std::vector>* items, + FunctionLibraryRuntime::DoneCallback done) const; + + static absl::Status GetComponentArgs(absl::Span args, + const ComponentFunctionData& comp_data, + InternalArgs* comp_args); + +#if !defined(IS_MOBILE_PLATFORM) + static absl::Status GetComponentArgs(const FunctionArgsInterface& args, + const ComponentFunctionData& comp_data, + InternalArgs* comp_args); +#endif // 
IS_MOBILE_PLATFORM + + std::vector GetOrderedSubgraphs( + const MultiDeviceFunctionData* data) const; + + absl::Status PrepareRunMultiDevice( + const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + const MultiDeviceFunctionData** data) const; + + absl::Status RunMultiDeviceSync( + const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, std::vector* rets, + std::function + get_component_args) const; + + void RunMultiDeviceAsync( + const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, std::vector* rets, + std::vector>* cleanup_items, + FunctionLibraryRuntime::DoneCallback done, + std::function + get_component_args) const; + + void PublishSubgraphs( + const std::string& function_name, + std::vector>&& function_records); + + // Data structure holding information for a single instantiated remote + // (to be executed on `target_device`) function. + class FunctionData { + public: + FunctionData(const string& target_device, + FunctionLibraryRuntime::LocalHandle local_handle, + const string& function_key) + : target_device_(target_device), + local_handle_(local_handle), + function_key_(function_key) {} + + const string& target_device() { return target_device_; } + const string& function_key() { return function_key_; } + + FunctionLibraryRuntime::LocalHandle local_handle() { + mutex_lock l(mu_); + return local_handle_; + } + + // Initializes the FunctionData object by potentially making an Initialize + // call to the DistributedFunctionLibraryRuntime. + void DistributedInit( + DistributedFunctionLibraryRuntime* parent, const string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::DoneCallback done); + + bool is_cross_process() { + mutex_lock l(mu_); + return is_cross_process_; + } + + private: + mutex mu_; + + const string target_device_; + FunctionLibraryRuntime::LocalHandle local_handle_ TF_GUARDED_BY(mu_); + const string function_key_; + bool is_cross_process_ TF_GUARDED_BY(mu_) = false; + bool init_started_ TF_GUARDED_BY(mu_) = false; + absl::Status init_result_ TF_GUARDED_BY(mu_); + Notification init_done_; + }; + + mutable mutex mu_; + + Env* const env_; + const std::optional config_; + const DeviceMgr* const device_mgr_; + const FunctionLibraryDefinition* lib_def_; + thread::ThreadPool* default_thread_pool_; + + // Cluster update can reinitialize the device_set_ due to remote device + // changes. At the same time, InstantiateMultiDevice can use the cached + // devices to instantiate multi-worker functions. Function instantiation would + // fail if it spans the changed remote devices. + std::shared_ptr device_set_ TF_GUARDED_BY(mu_); + + // Composite devices owned by a EagerContext. + std::vector composite_devices_ TF_GUARDED_BY(mu_); + + // Holds all the function instantiations. Maps function_keys to handles. + std::unordered_map table_ + TF_GUARDED_BY(mu_); + + // Function data for instantiated remote functions. + std::unordered_map> + function_data_ TF_GUARDED_BY(mu_); + + // Function data for instantiated multi-device functions. 
+ std::unordered_map> + mdevice_data_ TF_GUARDED_BY(mu_); + + std::unique_ptr< + std::unordered_map>> + flr_map_; + int next_handle_ TF_GUARDED_BY(mu_); + const SessionMetadata* const session_metadata_; + const Rendezvous::Factory rendezvous_factory_; + + const OptimizerOptions optimizer_options_; + const int graph_def_version_; + + StatsPublisherFactory stats_publisher_factory_; + // Holds all stats publishers, one for publishing subgraphs of each + // instantiated function. + std::vector> stats_publishers_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_FUNCTION_LIBRARY_RUNTIME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_state.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_state.h new file mode 100644 index 00000000..dd667cc2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_state.h @@ -0,0 +1,161 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_STATE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/allocator_registry.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class PoolAllocator; + +// Singleton that manages per-process state, e.g. allocation of +// shared resources. +class ProcessState : public ProcessStateInterface { + public: + static ProcessState* singleton(); + + // Descriptor for memory allocation attributes, used by optional + // runtime correctness analysis logic. + struct MemDesc { + enum MemLoc { CPU, GPU }; + MemLoc loc; + int dev_index; + bool gpu_registered; + bool nic_registered; + MemDesc() + : loc(CPU), + dev_index(0), + gpu_registered(false), + nic_registered(false) {} + string DebugString(); + }; + + // If NUMA Allocators are desired, call this before calling any + // Allocator accessor. + void EnableNUMA() { numa_enabled_ = true; } + + // Returns what we know about the memory at ptr. + // If we know nothing, it's called CPU 0 with no other attributes. + MemDesc PtrType(const void* ptr); + + // Returns the one CPUAllocator used for the given numa_node. + // Treats numa_node == kNUMANoAffinity as numa_node == 0. + Allocator* GetCPUAllocator(int numa_node) override; + + // Registers alloc visitor for the CPU allocator(s). + // REQUIRES: must be called before GetCPUAllocator. + void AddCPUAllocVisitor(SubAllocator::Visitor v); + + // Registers free visitor for the CPU allocator(s). + // REQUIRES: must be called before GetCPUAllocator. 
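+  // Registration-order sketch: visitors are installed before the first
+  // allocator lookup, per the REQUIRES above (the no-op visitor body is an
+  // illustrative assumption):
+  //
+  //   ProcessState* ps = ProcessState::singleton();
+  //   ps->EnableNUMA();
+  //   ps->AddCPUAllocVisitor([](void* ptr, int index, size_t num_bytes) {});
+  //   Allocator* cpu = ps->GetCPUAllocator(/*numa_node=*/0);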
+ void AddCPUFreeVisitor(SubAllocator::Visitor v); + + typedef std::unordered_map MDMap; + + protected: + ProcessState(); + virtual ~ProcessState() {} + friend class GPUProcessState; + friend class PluggableDeviceProcessState; + + // If these flags need to be runtime configurable consider adding + // them to ConfigProto. + static constexpr bool FLAGS_brain_mem_reg_gpu_dma = true; + static constexpr bool FLAGS_brain_gpu_record_mem_types = false; + + // Helper method for unit tests to reset the ProcessState singleton by + // cleaning up everything. Never use in production. + void TestOnlyReset(); + + static ProcessState* instance_; + bool numa_enabled_; + + mutex mu_; + + // Indexed by numa_node. If we want numa-specific allocators AND a + // non-specific allocator, maybe should index by numa_node+1. + std::vector cpu_allocators_ TF_GUARDED_BY(mu_); + std::vector cpu_alloc_visitors_ TF_GUARDED_BY(mu_); + std::vector cpu_free_visitors_ TF_GUARDED_BY(mu_); + + // A cache of cpu allocators indexed by a numa node. Used as a fast path to + // get CPU allocator by numa node id without locking the mutex. We can't use + // `cpu_allocators_` storage in the lock-free path because concurrent + // operation can deallocate the vector storage. + std::atomic cpu_allocators_cached_; + std::array cpu_allocators_cache_; + + // Optional RecordingAllocators that wrap the corresponding + // Allocators for runtime attribute use analysis. + MDMap mem_desc_map_; + std::vector cpu_al_ TF_GUARDED_BY(mu_); +}; + +namespace internal { +class RecordingAllocator : public Allocator { + public: + RecordingAllocator(ProcessState::MDMap* mm, Allocator* a, + ProcessState::MemDesc md, mutex* mu) + : mm_(mm), a_(a), md_(md), mu_(mu) {} + + string Name() override { return a_->Name(); } + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* p = a_->AllocateRaw(alignment, num_bytes); + mutex_lock l(*mu_); + (*mm_)[p] = md_; + return p; + } + void DeallocateRaw(void* p) override { + mutex_lock l(*mu_); + auto iter = mm_->find(p); + mm_->erase(iter); + a_->DeallocateRaw(p); + } + bool TracksAllocationSizes() const override { + return a_->TracksAllocationSizes(); + } + size_t RequestedSize(const void* p) const override { + return a_->RequestedSize(p); + } + size_t AllocatedSize(const void* p) const override { + return a_->AllocatedSize(p); + } + absl::optional GetStats() override { return a_->GetStats(); } + bool ClearStats() override { return a_->ClearStats(); } + + AllocatorMemoryType GetMemoryType() const override { + return a_->GetMemoryType(); + } + + ProcessState::MDMap* mm_; // not owned + Allocator* a_; // not owned + ProcessState::MemDesc md_; + mutex* mu_; +}; +} // namespace internal +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_util.h new file mode 100644 index 00000000..cc2bc439 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/process_util.h @@ -0,0 +1,64 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_UTIL_H_ + +#include + +#include "absl/functional/any_invocable.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/public/session_options.h" + +// TODO(vrv, mrry): Remove this library: its interface circumvents the +// callers' Env and calls Env::Default() directly. + +namespace tensorflow { + +// Returns a process-wide ThreadPool for scheduling compute operations +// using 'options'. Caller does not take ownership over threadpool. +thread::ThreadPool* ComputePool(const SessionOptions& options); + +// Returns the TF_NUM_INTEROP_THREADS environment value, or 0 if not specified. +int32 NumInterOpThreadsFromEnvironment(); + +// Returns the TF_NUM_INTRAOP_THREADS environment value, or 0 if not specified. +int32 NumIntraOpThreadsFromEnvironment(); + +// Returns the number of inter op threads specified in `options` or a default. +// If no value or a negative value is specified in the provided options, then +// the function returns the value defined in the TF_NUM_INTEROP_THREADS +// environment variable. If neither a value is specified in the options or in +// the environment, this function will return a reasonable default value based +// on the number of schedulable CPUs, and any MKL and OpenMP configurations. +int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options); + +// Creates a thread pool with number of inter op threads. +// The number is set if `num_threads` > 0, otherwise it will be configured by +// SessionOptions. +thread::ThreadPool* NewThreadPoolFromSessionOptions( + const SessionOptions& options, int32_t num_threads = 0); + +// Schedule "closure" in the default thread queue. +void SchedClosure(absl::AnyInvocable closure); + +// Schedule "closure" after the given number of microseconds in the +// fixed-size ThreadPool used for non-blocking compute tasks. +void SchedNonBlockingClosureAfter(int64_t micros, + absl::AnyInvocable closure); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/profile_handler.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/profile_handler.h new file mode 100644 index 00000000..71aac10b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/profile_handler.h @@ -0,0 +1,68 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
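A hedged usage sketch for the thread-pool helpers declared in process_util.h above (not part of the header; the assumption that the caller owns the pool returned by NewThreadPoolFromSessionOptions is mine):

```cpp
#include <memory>

#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/public/session_options.h"

void RunOnInterOpPool() {
  tensorflow::SessionOptions options;
  // num_threads == 0 defers to options / TF_NUM_INTEROP_THREADS / CPU count.
  std::unique_ptr<tensorflow::thread::ThreadPool> pool(
      tensorflow::NewThreadPoolFromSessionOptions(options, /*num_threads=*/0));
  pool->Schedule([] { /* heavy compute work scheduled on the new pool */ });

  // The process-wide default queue, by contrast, is reached via SchedClosure.
  tensorflow::SchedClosure([] { /* fire-and-forget background work */ });
}
```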
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ + +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +// A profile handler collects event stats from a running step. +class ProfileHandler { + public: + ProfileHandler() {} + virtual ~ProfileHandler() {} + + // Records that a single Op was executed in the current step. + // + // Implementations of this method must be thread-safe. + // + // Args: + // - device: Device on which the Op was executed. + // - stats: Statistics of node execution timing. + // - is_copy: True if the op was a copy, send or recv. + // - label: Extra content for timeline click text. + // - op_type: String name of the Op. + // - details: Main content for timeline click text. + virtual void RecordOneOp(const string& device, const NodeExecStats& stats, + bool is_copy, absl::string_view label, + absl::string_view op_type, + absl::string_view details) = 0; + + // Records that the current step finished. + // + // Implementations of this method need not be thread-safe. + // + // Args: + // - start_time: The time at which the step started. + // - finish_time: The time at which the step finished. + // - cleanup_time: The time at which cleanup for the step finished. + // - total_runops: The number of ops that ran during this step. + // - final_status: The status that this step finished with. + virtual void StepDone(Microseconds start_time, Microseconds finish_time, + Microseconds cleanup_time, int total_runops, + absl::Status final_status) = 0; + + // Returns true if the caller should collect rpc activity. + virtual bool should_collect_rpcs() = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROFILE_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/propagator_debug_utils.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/propagator_debug_utils.h new file mode 100644 index 00000000..2e837104 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/propagator_debug_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ + +namespace tensorflow { + +struct Entry; +struct NodeItem; +class Tensor; + +// Returns a pointer to the tensor in `input` if one exists, or `nullptr`. +const Tensor* GetTensorValueForDump(const Entry& input); + +// Writes a LOG(WARNING) message describing the state of the given pending node +// in the graph described by `immutable_state`. 
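A minimal sketch of a ProfileHandler implementation for the interface above (illustrative only; LoggingProfileHandler and its logging behaviour are assumptions, not part of the header):

```cpp
#include "tensorflow/core/common_runtime/profile_handler.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

class LoggingProfileHandler : public ProfileHandler {
 public:
  void RecordOneOp(const string& device, const NodeExecStats& stats,
                   bool is_copy, absl::string_view label,
                   absl::string_view op_type,
                   absl::string_view details) override {
    // Must be thread-safe; LOG is internally synchronized.
    LOG(INFO) << device << " ran " << op_type << " (" << label << ")";
  }

  void StepDone(Microseconds start_time, Microseconds finish_time,
                Microseconds cleanup_time, int total_runops,
                absl::Status final_status) override {
    LOG(INFO) << "Step finished: " << total_runops << " ops, status "
              << final_status;
  }

  bool should_collect_rpcs() override { return false; }
};

}  // namespace tensorflow
```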
+void DumpPendingNodeState(const NodeItem& node_item, const Entry* input_vector, + const bool show_nodes_with_no_ready_inputs); + +// Writes a LOG(WARNING) message describing the state of the given active node +// in the graph described by `immutable_state`. +void DumpActiveNodeState(const NodeItem& node_item, const Entry* input_vector); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_DEBUG_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/propagator_state.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/propagator_state.h new file mode 100644 index 00000000..e5f4fd6b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/propagator_state.h @@ -0,0 +1,598 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/entry.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" +#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +typedef absl::InlinedVector AllocatorAttributeVec; + +// Represents the ephemeral "edge state" associated with one invocation of +// `Executor::Run()`. +// +// `PropagatorState` is responsible for propagating values along dataflow +// edges in a TensorFlow graph and determining which nodes are runnable. The +// executor primarily updates `PropagatorState` by calling `PropagateOutputs()` +// after processing a node, and `PropagatorState` dispatches `TaggedNode`s by +// adding them to a `TaggedNodeSeq`. +class PropagatorState { + public: + PropagatorState(const ImmutableExecutorState& immutable_state, + int64_t step_id, bool vlog); + ~PropagatorState(); + + private: + // Forward declaration so that `TaggedNode` can include a `FrameState*` and an + // `IterationState*`. + struct FrameState; + struct IterationState; + + public: + // A `TaggedNode` corresponds to a single invocation of a node's kernel, + // and it is created when the kernel becomes runnable (in a particular + // iteration of a particular frame). 
+ struct TaggedNode { + const NodeItem* node_item; + FrameState* input_frame; + IterationState* input_iter; + bool is_dead; + + TaggedNode() = default; + TaggedNode(const NodeItem* node_item, FrameState* in_frame, + IterationState* in_iter, bool dead) + : node_item(node_item), + input_frame(in_frame), + input_iter(in_iter), + is_dead(dead) {} + + const NodeItem& get_node_item() const { return *node_item; } + + bool get_is_dead() const { return is_dead; } + int64_t get_iter_num() const; + }; + + // A drop-in replacement for std::deque. We typically don't + // have that many nodes in the ready queue, so we just use a vector and + // don't free up memory from the queue as we consume nodes. + class TaggedNodeReadyQueue { + public: + TaggedNodeReadyQueue() : front_index_(0) {} + + void push_back(const TaggedNode& node) { ready_.push_back(node); } + + TaggedNode front() const { + DCHECK_LT(front_index_, ready_.size()); + return ready_[front_index_]; + } + + void pop_front() { + DCHECK_LT(front_index_, ready_.size()); + front_index_++; + if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { + if (front_index_ == ready_.size()) { + ready_.clear(); + } else { + // Lots of unused entries at beginning of vector: move everything + // down to start of vector. + ready_.erase(ready_.begin(), ready_.begin() + front_index_); + } + front_index_ = 0; + } + } + bool empty() const { return ready_.empty(); } + int size() const { return ready_.size() - front_index_; } + + private: + // TODO(b/152925936): Re-evaluate these constants with current usage + // patterns. + static constexpr int kSpillThreshold = 16384; + absl::InlinedVector ready_; + int front_index_; + }; + + // TODO(b/152925936): Re-evaluate this constant with current usage patterns. + typedef absl::InlinedVector TaggedNodeSeq; + + private: + // The state of an iteration in a particular frame. + struct IterationState { + explicit IterationState(int64_t iter_num, + const PendingCounts* pending_counts, + int total_input_tensors) + : iter_num(iter_num), + input_tensors(new Entry[total_input_tensors]), + outstanding_ops(0), + outstanding_frame_count(0), + counts(*pending_counts) { // Initialize with copy of *pending_counts + } + + const int64_t + iter_num; // The index of this iteration in the enclosing loop. + + // One copy per iteration. For iteration k, i-th node's j-th input is in + // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is + // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). + // + // NOTE: No need to protect input_tensors[i] by any locks because it + // is resized once. Each element of tensors_ is written once by the + // source node of an edge and is cleared by the destination of the same + // edge. The latter node is never run concurrently with the former node. + Entry* input_tensors; + + // The number of outstanding ops for each iteration. + std::atomic outstanding_ops; + + // The number of outstanding frames for each iteration. + int outstanding_frame_count; + int pending(PendingCounts::Handle h) { return counts.pending(h); } + int decrement_pending(PendingCounts::Handle h, int v) { + return counts.decrement_pending(h, v); + } + // Mark a merge node as live + // REQUIRES: Node corresponding to "h" is a merge node + void mark_live(PendingCounts::Handle h) { counts.mark_live(h); } + // Mark a node to show that processing has started. + void mark_started(PendingCounts::Handle h) { counts.mark_started(h); } + // Mark a node to show that processing has completed. 
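The TaggedNodeReadyQueue above returns front() by value, so the usual drain pattern copies the node out before calling pop_front(). A small sketch, not executor code; ProcessNode is a hypothetical callback standing in for the per-node work:

```cpp
#include "tensorflow/core/common_runtime/propagator_state.h"

void DrainReadyQueue(
    tensorflow::PropagatorState::TaggedNodeReadyQueue* inline_ready,
    void (*ProcessNode)(const tensorflow::PropagatorState::TaggedNode&)) {
  while (!inline_ready->empty()) {
    // Copy the node out first: pop_front() may compact the underlying vector.
    tensorflow::PropagatorState::TaggedNode tagged_node = inline_ready->front();
    inline_ready->pop_front();
    ProcessNode(tagged_node);  // may push newly-ready nodes back onto the queue
  }
}
```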
+ void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); } + PendingCounts::NodeState node_state(PendingCounts::Handle h) { + return counts.node_state(h); + } + + int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); } + void increment_dead_count(PendingCounts::Handle h) { + counts.increment_dead_count(h); + } + // REQUIRES: Node corresponding to "h" is a merge node + PendingCounts::AdjustResult adjust_for_mark_live(PendingCounts::Handle h) { + return counts.adjust_for_mark_live(h); + } + // REQUIRES: Node corresponding to "h" is a merge node + PendingCounts::AdjustResult adjust_for_mark_live_atomic( + PendingCounts::Handle h) { + return counts.adjust_for_mark_live_atomic(h); + } + PendingCounts::AdjustResult adjust_for_decrement_pending( + PendingCounts::Handle h, int decrement_pending) { + return counts.adjust_for_decrement_pending(h, decrement_pending); + } + PendingCounts::AdjustResult adjust_for_decrement_pending_atomic( + PendingCounts::Handle h, int decrement_pending) { + return counts.adjust_for_decrement_pending_atomic(h, decrement_pending); + } + PendingCounts::AdjustResult adjust_for_increment_dead( + PendingCounts::Handle h) { + return counts.adjust_for_increment_dead(h); + } + PendingCounts::AdjustResult adjust_for_increment_dead_atomic( + PendingCounts::Handle h) { + return counts.adjust_for_increment_dead_atomic(h); + } + PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h, + bool increment_dead) { + return counts.adjust_for_activation(h, increment_dead); + } + PendingCounts::AdjustResult adjust_for_activation_atomic( + PendingCounts::Handle h, bool increment_dead) { + return counts.adjust_for_activation_atomic(h, increment_dead); + } + + ~IterationState() { delete[] input_tensors; } + + private: + PendingCounts counts; + }; + + struct FrameState { + explicit FrameState(const ImmutableExecutorState& immutable_state, + int parallel_iters) + : immutable_state(immutable_state), + max_parallel_iterations(parallel_iters), + num_outstanding_iterations(1), + iterations(parallel_iters + 1), + iterations_raw(iterations.data()) {} + + // A new frame is created for each loop. Execution starts at iteration 0. + // When a value at iteration 0 passes through a NextIteration node, + // iteration 1 is created and starts running. Note that iteration 0 may + // still be running so multiple iterations may run in parallel. The + // frame maintains the state of iterations in several data structures + // such as pending_count and input_tensors. When iteration 0 completes, + // we garbage collect the state of iteration 0. + // + // A frame instance is considered "done" and can be garbage collected + // if all its inputs have entered and all its iterations are "done". + // + // A frame manages the live iterations of an iterative computation. + // Iteration i is considered "done" when there are no outstanding ops, + // frames at iteration i are done, all recvs for this iteration are + // completed, and iteration i-1 is done. For iteration 0, we instead + // wait for there to be no more pending inputs of the frame. + // + // Frames and iterations are garbage collected once they are done. + // The state we need to keep around is highly dependent on the + // parallelism enabled by the scheduler. We may want to have the + // scheduler dynamically control the outstanding number of live + // parallel frames and iterations. To reduce the state space, the + // scheduler might want to schedule ops in inner frames first and + // lower iterations first. 
+ // + // This frame state is mostly initialized lazily on demand so we + // don't introduce unnecessary overhead. + + // The immutable state of the executor the frame is in. + const ImmutableExecutorState& immutable_state; + + // The name of this frame, which is the concatenation of its parent + // frame name, the iteration of the parent frame when this frame was + // created, and the value of the attr 'frame_name'. + string frame_name; + + // The unique id for this frame. Generated by fingerprinting + // frame_name. + uint64 frame_id; + + // The iteration state of its parent frame when this frame is created. + // nullptr if there is no parent frame. The frame_name/parent_iter pair + // uniquely identifies this FrameState. + IterationState* parent_iter = nullptr; + + // The FrameState of its parent frame. + FrameState* parent_frame = nullptr; + + // The maximum allowed number of parallel iterations. + const int max_parallel_iterations; + + // The number of inputs this frame is still waiting. + int num_pending_inputs = 0; + + // The highest iteration number we have reached so far in this frame. + int64_t iteration_count TF_GUARDED_BY(mu) = 0; + + // The number of outstanding iterations. + int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; + + private: + // The active iteration states of this frame. + absl::InlinedVector iterations; + IterationState** const iterations_raw TF_GUARDED_BY(mu); + IterationState* iterations_first TF_GUARDED_BY(mu); + + public: + // The NextIteration nodes to enter a new iteration. If the number of + // outstanding iterations reaches the limit, we will defer the start of + // the next iteration until the number of outstanding iterations falls + // below the limit. + std::vector> next_iter_roots + TF_GUARDED_BY(mu); + + // The values of the loop invariants for this loop. They are added into + // this list as they "enter" the frame. When a loop invariant enters, + // we make it available to all active iterations. When the frame starts + // a new iteration, we make all the current loop invariants available + // to the new iteration. + std::vector> inv_values + TF_GUARDED_BY(iter_mu); + + // The list of dead exit node items for the current highest iteration. We + // will only "execute" the dead exits of the final iteration. + std::vector dead_exits TF_GUARDED_BY(iter_mu); + + // Static information specific to this frame. + PendingCounts* pending_counts = nullptr; + int total_input_tensors = 0; + std::vector* nodes = nullptr; + + // Lock ordering: ExecutorState.mu_ < mu < iter_mu; + // during structured traversal: parent_frame->mu < mu. + mutex mu; + + // This mutex lock should only be held when entering next iteration. + mutex iter_mu; + + void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo); + + inline IterationState* GetIteration(int64_t iter) + TF_SHARED_LOCKS_REQUIRED(mu) { + if (TF_PREDICT_TRUE(iter == 0)) { + return iterations_first; + } else { + size_t index = iter % (max_parallel_iterations + 1); + return iterations_raw[index]; + } + } + + void SetIteration(int64_t iter, IterationState* state); + + // Adjust the outstanding op count by 'delta' and clean up the iterations in + // the frame if no more ops are oustanding. Return true iff the execution of + // the frame is done. + // + // Avoids acquiring the lock in the common case that the frame is not done. 
+ bool AdjustOutstandingOps(IterationState* iter_state, int delta, + TaggedNodeSeq* ready); + + bool AdjustOutstandingOpsLocked(IterationState* iter_state, int delta, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + bool AdjustOutstandingOpsFastPath(IterationState* iter_state, int delta) + TF_SHARED_LOCKS_REQUIRED(mu); + + // Convenience methods for the above 'Adjust' calls where delta takes the + // common value of -1. + bool DecrementOutstandingOps(IterationState* iter_state, + TaggedNodeSeq* ready); + + bool DecrementOutstandingOpsLocked(IterationState* iter_state, + TaggedNodeSeq* ready); + + // Returns true if the computation in the frame is completed. + bool IsFrameDone(); + + // Returns true if the iteration of the frame is completed. + bool IsIterationDone(IterationState* iter_state) + TF_SHARED_LOCKS_REQUIRED(mu); + + // Increments the iteration id. If this is a new iteration, initialize it. + // + // Returns a pointer to the new iteration. + IterationState* IncrementIteration(TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Activate all the deferred NextIteration nodes in a new iteration. + void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Activate all the current loop invariants in a new iteration. + void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Add a new loop invariant and make it available to all active + // iterations. + void AddLoopInv(const NodeItem* item, const Entry& entry, + TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Activate the successors of a node. Contents of *outputs are left in an + // indeterminate state after returning from this method. + // + // In the case that 'item' is a simple node (no merge/control outputs) this + // will acquire a shared lock and can run concurrently with other + // invocations. + // + // Return true if the frame is done after activation. + bool ActivateNodesAndAdjustOutstanding( + const NodeItem* item, const bool is_dead, IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready, int decrement_activation); + + // Same as the above, but requires 'mu' already held in exclusive mode. + int ActivateNodesLocked(const NodeItem* item, const bool is_dead, + IterationState* iter_state, EntryVector* outputs, + TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // Cleanup iterations of this frame starting from the given iteration. + bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + void DumpIterationState(PropagatorState* parent) { + mutex_lock l(mu); + for (IterationState* iteration : iterations) { + if (iteration) { + LOG(WARNING) << " Iteration:"; + parent->DumpIterationState(this, iteration); + } + } + } + + ~FrameState() { + for (size_t i = 0; i < iterations.size(); ++i) { + delete iterations[i]; + iterations[i] = nullptr; + } + } + + private: + // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. + // This variant does not use atomic operations to modify the pending counts + // and thus must hold the exclusive lock. + int ActivateNodesFastPathLocked(const NodeItem* item, bool is_dead, + IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. + // This variant uses atomic operations to modify the pending counts. 
+ int ActivateNodesFastPathShared(const NodeItem* item, bool is_dead, + IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) + TF_SHARED_LOCKS_REQUIRED(mu); + + int ActivateNodesSlowPathLocked(const NodeItem* item, bool is_dead, + IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + + int ActivateNodesSlowPathShared(const NodeItem* item, bool is_dead, + IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) + TF_SHARED_LOCKS_REQUIRED(mu); + + // Implementation templates. Not for public use. + template + int ActivateNodesFastPathInternal(const NodeItem* item, bool is_dead, + IterationState* iter_state, + EntryVector* outputs, + TaggedNodeSeq* ready); + template + int ActivateNodesSlowPathInternal(const NodeItem* item, bool is_dead, + IterationState* iter_state, + EntryVector* outputs, + TaggedNodeSeq* ready); + }; + + public: + // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. + void ActivateRoots(gtl::ArraySlice roots, + TaggedNodeSeq* ready); + + // After processing the outputs, propagates the outputs to their dsts. + // Contents of *outputs are left in an indeterminate state after + // returning from this method. + void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, + TaggedNodeSeq* ready); + + // Returns an array of `Entry` objects corresponding to the inputs of + // `tagged_node`. + // + // NOTE: Thread safety analysis is disabled on this method, because the + // underlying `IterationState` and its array of `input_tensors` retain the + // same address while the iteration is live. + Entry* GetInputTensors(const TaggedNode& tagged_node) const + TF_NO_THREAD_SAFETY_ANALYSIS { + return tagged_node.input_iter->input_tensors + + tagged_node.node_item->input_start; + } + + FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { + return {tagged_node.input_frame->frame_id, + tagged_node.input_iter->iter_num}; + } + + // Provide debugging output of the state of the executor. + void DumpState(); + + // For debugging/logging only. + void MaybeMarkStarted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(tagged_node.input_frame->mu); + tagged_node.input_iter->mark_started( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + } + } + + void MaybeMarkCompleted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(tagged_node.input_frame->mu); + tagged_node.input_iter->mark_completed( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + } + } + + private: + // Find an existing or create a new child frame in the frame 'frame' at + // iteration 'iter'. + void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state, + const NodeItem& node_item, FrameState** child); + + // Delete a frame. Called when the frame is done. + void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready); + + // Cleanup frames and iterations starting from frame/iter. Called when + // a child frame is done. + void CleanupFramesIterations(FrameState* frame, IterationState* iter_state, + TaggedNodeSeq* ready); + + // Provide debugging output about an outstanding iteration in the executor. 
+ void DumpIterationState(const FrameState* frame, IterationState* iteration); + + const ImmutableExecutorState& immutable_state_; + const int64_t step_id_; + const bool vlog_; + + mutex mu_; + + // The root frame in which the execution of this step is started. + FrameState* root_frame_; + + // Mapping from frame ID to outstanding frames. A new frame is created + // at some iteration of an active frame. So the unique key for the new + // child frame is a hash composed of the ID of the parent frame, the iteration + // number at which the parent frame is creating the new frame, and the + // name of the new frame from nodedef. + absl::flat_hash_map outstanding_frames_ + TF_GUARDED_BY(mu_); + + PropagatorState(const PropagatorState&) = delete; + void operator=(const PropagatorState&) = delete; +}; + +inline int64_t PropagatorState::TaggedNode::get_iter_num() const { + return input_iter->iter_num; +} + +// `OrderedPropagatorState` replaces `PropagatorState`s `TaggedNodeReadyQueue` +// with a priority queue. This ensures that the order in which we dequeue +// `TaggedNode&`s is stable with respect to ASLR. +// +// This is not always needed, as in a multithreaded environment, executions are +// expected to happen nondeterministically, but this nondeteminism can be a +// problem: For example, In usecases that are running close to the RAM limit of +// a device, reordering ops can cause an increase in memory fragmenenation, +// causing an OOM. +// This codepath is enabled using TF_DETERMINISTIC_ORDER=1 in executor.cc +class OrderedPropagatorState : public PropagatorState { + using PropagatorState::PropagatorState; + + public: + class TaggedNodeReadyQueue : PropagatorState::TaggedNodeReadyQueue { + public: + TaggedNodeReadyQueue() : readyp_(compare) {} + void push_back(const TaggedNode& node) { readyp_.push(node); } + TaggedNode front() const { return readyp_.top(); } + void pop_front() { readyp_.pop(); } + bool empty() const { return readyp_.empty(); } + int size() const { return readyp_.size(); } + + private: + static bool compare(TaggedNode const& lhs, TaggedNode const& rhs) { + std::tuple lhs_prio{lhs.node_item->node_id, + lhs.input_frame->frame_id, + lhs.input_iter->iter_num}; + std::tuple rhs_prio{rhs.node_item->node_id, + rhs.input_frame->frame_id, + rhs.input_iter->iter_num}; + return lhs_prio < rhs_prio; + } + + std::priority_queue, decltype(&compare)> + readyp_; + }; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/quantize_training.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/quantize_training.h new file mode 100644 index 00000000..de3ed6b4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/quantize_training.h @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_QUANTIZE_TRAINING_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_QUANTIZE_TRAINING_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +// Rewrites graph for quantized training. +// Rewrites the forward pass to include the precision loss with quantization so +// the model can learn to deal with such loss and achieve better accuracy when +// it is quantized later for inference. +// Note that the num_bits should be in [1, 63] and 'g' must be not null. +// quant_op_type specifies which quantization op should be used. +// Current ops supported: +// - QuantizeAndDequantizeV2. +// - FakeQuantWithMinMaxVars. +// +// On success, returns OK. +// +// On failure, returns the error status. Possible errors include: +// - num_bits out of range. +// - g is null. +// - More than 1 unknown ops encountered. +absl::Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, + Graph* g); + +// Converts the input serialized GraphDef and returns a rewritten serialized +// GraphDef for quantized training. +absl::Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph, + int32_t num_bits, + const string& quant_op_type, + string* result_graph); + +// Converts the input GraphDef and returns a rewritten GraphDef for quantized +// training. +absl::Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, + int32_t num_bits, + const string& quant_op_type, + GraphDef* result_graphdef); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_QUANTIZE_TRAINING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/renamed_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/renamed_device.h new file mode 100644 index 00000000..e4b4b8ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/renamed_device.h @@ -0,0 +1,173 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/lib/core/threadpool_interface.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// Wraps a device with a new name, delegating work to the wrapped device. +// +// This class is used to wrap local devices when using clusterspec propagation +// where the name of a particular device may change in the context of a given +// session. 
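Returning to quantize_training.h above: a hedged example of calling the GraphDef-level entry point (the wrapper name and parameter choices are illustrative, not prescribed by the header):

```cpp
#include "tensorflow/core/common_runtime/quantize_training.h"
#include "tensorflow/core/framework/graph.pb.h"

absl::Status QuantizeForTraining(const tensorflow::GraphDef& in,
                                 tensorflow::GraphDef* out) {
  // num_bits must be in [1, 63]; FakeQuantWithMinMaxVars is one of the two
  // supported quantization op types listed in the header comment.
  return tensorflow::DoQuantizeTrainingOnGraphDef(
      in, /*num_bits=*/8, "FakeQuantWithMinMaxVars", out);
}
```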
+class RenamedDevice : public Device { + public: + static std::unique_ptr NewRenamedDevice( + const string& new_base, Device* underlying, bool owns_underlying, + bool isolate_session_state, + thread::ThreadPoolInterface* underlying_threadpool = nullptr); + + ~RenamedDevice() override; + + const DeviceBase* UnderlyingDevice() const override { + return underlying_device_->UnderlyingDevice(); + } + DeviceBase* UnderlyingDevice() override { + return underlying_device_->UnderlyingDevice(); + } + + const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override { + if (underlying_threadpool_) { + return Device::tensorflow_cpu_worker_threads(); + } + return underlying_device_->tensorflow_cpu_worker_threads(); + } + + const DeviceBase::AcceleratorDeviceInfo* tensorflow_accelerator_device_info() + const override { + return underlying_device_->tensorflow_accelerator_device_info(); + } + + Allocator* GetAllocator(AllocatorAttributes attr) override { + return underlying_device_->GetAllocator(attr); + } + + Allocator* GetScopedAllocator(AllocatorAttributes attr, + int64_t step_id) override { + return underlying_device_->GetScopedAllocator(attr, step_id); + } + + ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { + return underlying_device_->GetScopedAllocatorMgr(); + } + + const Eigen::ThreadPoolDevice* eigen_cpu_device() override { + // Use the underlying threadpool only if the underlying device supports + // eigen_cpu_device. + if (underlying_threadpool_ && underlying_device_->has_eigen_cpu_device()) { + return Device::eigen_cpu_device(); + } + return underlying_device_->eigen_cpu_device(); + } + + thread::ThreadPool* tensorflow_device_thread_pool() override { + // Use the underlying threadpool instead of tensorflow_device_thread_pool + // of the underlying device only if tensorflow_device_thread_pool is defined + // for the underlying device. 
+ if (underlying_threadpool_ && + underlying_device_->tensorflow_device_thread_pool() != nullptr) { + return Device::tensorflow_device_thread_pool(); + } + return underlying_device_->tensorflow_device_thread_pool(); + } + + bool has_eigen_cpu_device() const override { + return underlying_device_->has_eigen_cpu_device(); + } + + + PerOpGpuDevice* MakeGpuDevice() override { + return underlying_device_->MakeGpuDevice(); + } + + absl::Status ReinitializeGpuDevice(OpKernelContext* context, + PerOpGpuDevice* device, DeviceContext* dc, + Allocator* allocator) override { + return underlying_device_->ReinitializeGpuDevice(context, device, dc, + allocator); + } + + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override { + return underlying_device_->MakeTensorFromProto(tensor_proto, alloc_attrs, + tensor); + } + + void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, + const DeviceContext* device_context, + StatusCallback done) override { + underlying_device_->CopyTensorInSameDevice(input_tensor, output_tensor, + device_context, std::move(done)); + } + + // Below are virtual methods defined on Device + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override { + underlying_device_->Compute(op_kernel, context); + } + + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override { + underlying_device_->ComputeAsync(op_kernel, context, std::move(done)); + } + + absl::Status Sync() override { return underlying_device_->Sync(); } + + absl::Status MaybeRewriteGraph(std::unique_ptr* graph) override { + return underlying_device_->MaybeRewriteGraph(graph); + } + + absl::Status TryGetDeviceContext(DeviceContext** out_context) override { + return underlying_device_->TryGetDeviceContext(out_context); + } + + // Returns the resource manager associated w/ this device. + ResourceMgr* resource_manager() override { + if (isolate_session_state_) { + return Device::resource_manager(); + } else { + return underlying_device_->resource_manager(); + } + } + + bool IsLocal() const override { return underlying_device_->IsLocal(); } + + bool IsRemoteCallAllowed() const override { + return underlying_device_->IsRemoteCallAllowed(); + } + + private: + RenamedDevice(Device* underlying, const DeviceAttributes& attributes, + bool owns_underlying, bool isolate_session_state, + thread::ThreadPoolInterface* underlying_threadpool); + Device* const underlying_device_; + const bool owns_underlying_device_; + const bool isolate_session_state_; + + std::unique_ptr underlying_threadpool_; + // eigen_worker_threads_ is stored here so that we can pass the pointer + // of eigen_worker_threads_.workers to the parent class. + DeviceBase::CpuWorkerThreads eigen_worker_threads_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/rendezvous_mgr.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/rendezvous_mgr.h new file mode 100644 index 00000000..23c07b3d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -0,0 +1,106 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
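An illustrative sketch of the RenamedDevice::NewRenamedDevice factory declared above (the wrapper function, the device pointer, and the session-name string are placeholders):

```cpp
#include <memory>

#include "tensorflow/core/common_runtime/renamed_device.h"

// Wraps `underlying` under a session-specific name. The wrapper delegates all
// work to `underlying`, which must outlive it because owns_underlying=false.
std::unique_ptr<tensorflow::Device> WrapForSession(
    tensorflow::Device* underlying) {
  return tensorflow::RenamedDevice::NewRenamedDevice(
      "/job:worker/replica:0/task:0", underlying,
      /*owns_underlying=*/false, /*isolate_session_state=*/true);
}
```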
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/local_rendezvous.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// The IntraProcessRendezvous classes are implementations of a Rendezvous that +// expects all producers and consumers to be devices immediately accessible +// within the process. That is, it will never be necessary to perform an RPC to +// communicate with either. +// +// Buffering of Tensor values is delegated to a `LocalRendezvous`. An +// IntraProcessRendezvous. just adds functionality to coordinate multiple +// process-local devices. + +// Reference-counted implementation that may be shared between multiple threads. +class RefCountedIntraProcessRendezvous : public Rendezvous { + public: + explicit RefCountedIntraProcessRendezvous(const DeviceMgr* device_mgr); + + // Implementation of RendezvousInterface methods. + // NOTE: The methods may clear the Item list and destroy 'this' if there are + // no other references to the RefCountedIntraProcessRendezvous object. + // If the caller intend to keep a longer life time then it shall keep its own + // reference to the RefCountedIntraProcessRendezvous. + absl::Status Send(const ParsedKey& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; + void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, + DoneCallback done) override; + void StartAbort(const absl::Status& status) override; + + // Returns the member LocalRendezvous' status. + absl::Status GetLocalRendezvousStatus(); + + inline void UpdateDeviceManager(DeviceMgr* device_mgr) { + device_mgr_ = device_mgr; + } + + private: + const DeviceMgr* device_mgr_; // Not owned. + LocalRendezvous local_; + + ~RefCountedIntraProcessRendezvous() override; + + RefCountedIntraProcessRendezvous(const RefCountedIntraProcessRendezvous&) = + delete; + void operator=(const RefCountedIntraProcessRendezvous&) = delete; +}; + +// RefCountedIntraProcessRendezvous is aliased to IntraProcessRendezvous for +// backwards compatibility with existing users. +using IntraProcessRendezvous = RefCountedIntraProcessRendezvous; + +// Non-reference-counted implementation that may be stack-allocated for +// performance. +// +// Prefer to use PrivateIntraProcessRendezvous in new code. +class PrivateIntraProcessRendezvous : public RendezvousInterface { + public: + explicit PrivateIntraProcessRendezvous(const DeviceMgr* device_mgr); + ~PrivateIntraProcessRendezvous() override; + + // Implementation of RendezvousInterface methods. 
+ absl::Status Send(const ParsedKey& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; + void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, + DoneCallback done) override; + void StartAbort(const absl::Status& status) override; + + private: + const DeviceMgr* device_mgr_; + LocalRendezvous local_; + + PrivateIntraProcessRendezvous(const PrivateIntraProcessRendezvous&) = delete; + void operator=(const PrivateIntraProcessRendezvous&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/rendezvous_util.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/rendezvous_util.h new file mode 100644 index 00000000..8ed1dd7a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/rendezvous_util.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ + +#include + +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +typedef std::map NamedTensors; +typedef std::function StatusCallback; + +// Uses `rendezvous` to send tensors in `tensors_to_send`. `device_context` +// should be the DeviceContext associated with the source of the tensors. +// `alloc_attrs` contains information about how the `tensors_to_send` are +// allocated. `alloc_attrs` should either be {} or should match the length of +// `keys`. +absl::Status SendTensorsToRendezvous( + RendezvousInterface* rendezvous, DeviceContext* device_context, + const std::vector& alloc_attrs, + const std::vector& keys, absl::Span tensors_to_send); + +// Uses `rendezvous` to obtain tensors. `device_context` should be the +// DeviceContext associated with the receiving device. `alloc_attrs` contains +// information as how to store the received tensors. Should be {} or match the +// length of `keys`. +void RecvOutputsFromRendezvousAsync( + RendezvousInterface* rendezvous, DeviceContext* device_context, + const std::vector& alloc_attrs, + const std::vector& keys, std::vector* received_tensors, + StatusCallback done); + +absl::Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, + NamedTensors* out, + const Rendezvous::Args& args); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/replicate_constants_pass.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/replicate_constants_pass.h new file mode 100644 index 00000000..b215d301 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/replicate_constants_pass.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
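A rough usage sketch for the intra-process rendezvous above (illustrative only: device names, the incarnation value, and the DeviceMgr pointer are placeholders, and Rendezvous::CreateKey/ParseKey are assumed from the base rendezvous API rather than declared in this header):

```cpp
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"

void SendRecvLocally(const tensorflow::DeviceMgr* device_mgr,
                     const tensorflow::Tensor& value) {
  tensorflow::PrivateIntraProcessRendezvous rendez(device_mgr);

  const std::string key_str = tensorflow::Rendezvous::CreateKey(
      "/job:localhost/replica:0/task:0/device:CPU:0", /*src_incarnation=*/1,
      "/job:localhost/replica:0/task:0/device:CPU:0", "edge_0",
      tensorflow::FrameAndIter(0, 0));
  tensorflow::Rendezvous::ParsedKey key;
  TF_CHECK_OK(tensorflow::Rendezvous::ParseKey(key_str, &key));

  tensorflow::Rendezvous::Args args;  // default device context / alloc attrs
  TF_CHECK_OK(rendez.Send(key, args, value, /*is_dead=*/false));
  rendez.RecvAsync(key, args,
                   [](const absl::Status& s,
                      const tensorflow::Rendezvous::Args& /*send_args*/,
                      const tensorflow::Rendezvous::Args& /*recv_args*/,
                      const tensorflow::Tensor& /*t*/, bool /*is_dead*/) {
                     TF_CHECK_OK(s);
                   });
}
```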
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_REPLICATE_CONSTANTS_PASS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_REPLICATE_CONSTANTS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +// Small constants are replicated to the hosts of their successors. This pass +// only applies when there are multiple successors. +// +// For example, the graph: +// C -> {Op0, Op1, Op2, Op3} +// C's assigned_device is /job:tpu_host_worker/replica:0/task:0/device:CPU:0 +// Op0's assigned_device is /job:tpu_host_worker/replica:0/task:0/device:TPU:0 +// Op1's assigned_device is /job:tpu_host_worker/replica:0/task:0/device:TPU:1 +// Op2's assigned_device is /job:tpu_host_worker/replica:0/task:1/device:TPU:0 +// Op3's assigned_device is /job:tpu_host_worker/replica:0/task:1/device:TPU:1 +// is rewritten to: +// C0 -> {Op0, Op1} +// C1 -> {Op2, Op3} +// C0's assigned_device is /job:tpu_host_worker/replica:0/task:0/device:CPU:0 +// C1's assigned_device is /job:tpu_host_worker/replica:0/task:1/device:CPU:0 +// Op0's assigned_device is /job:tpu_host_worker/replica:0/task:0/device:TPU:0 +// Op1's assigned_device is /job:tpu_host_worker/replica:0/task:0/device:TPU:1 +// Op2's assigned_device is /job:tpu_host_worker/replica:0/task:1/device:TPU:0 +// Op3's assigned_device is /job:tpu_host_worker/replica:0/task:1/device:TPU:1 + +namespace tensorflow { + +class ReplicateConstantsPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_REPLICATE_CONSTANTS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/replicate_per_replica_nodes.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/replicate_per_replica_nodes.h new file mode 100644 index 00000000..4be95ea3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/replicate_per_replica_nodes.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
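As a hedged illustration of how a pass like ReplicateConstantsPass above gets wired in (the grouping and the phase number here are arbitrary examples, not the values TensorFlow actually registers it with):

```cpp
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/replicate_constants_pass.h"

namespace tensorflow {
// Registers the pass with the global optimization-pass registry; phase 3 is an
// illustrative placeholder.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 3,
                      ReplicateConstantsPass);
}  // namespace tensorflow
```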
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_REPLICATE_PER_REPLICA_NODES_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_REPLICATE_PER_REPLICA_NODES_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// `composite_device` maps from a virtual device to a set of devices. +// In a function graph, for each node assigned to a composite device +// (representing N devices), replace it with N replicated nodes (one per +// device). +// REQUIREMENTS: +// 1) Each node has been assigned to a device (including composite device). +// 2) Each cluster of nodes assigned to a composite device should include at +// least one "_Arg" node. +// composite device. +// 3) Clusters assigned to different composite devices should have no data +// dependency. +// TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass. +absl::Status ReplicatePerReplicaNodesInFunctionGraph( + const absl::flat_hash_map*>& + composite_devices, + Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_REPLICATE_PER_REPLICA_NODES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost.h new file mode 100644 index 00000000..3cb40ec8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost.h @@ -0,0 +1,104 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" + +namespace tensorflow { + +// RequestCost collects the costs and metrics for processing an rpc request. +class RequestCost { + public: + // Records costs. The inputs should be pairs of cost type and cost. + // It's thread-safe, and can be called from different threads. + void RecordCost( + const std::vector>& costs); + + // Scales all types of costs for processing an rpc request. + // It's thread-safe. It's expected to be called at the end of processing an + // rpc request, when all the costs have been collected. + void ScaleCosts(int scale_factor); + + // Gets all types of costs for processing an rpc request. + // It's thread-safe. It's expected to be called at the end of processing an + // rpc request, when all the costs have been collected. + absl::flat_hash_map GetCosts() const; + + // Records metrics. The inputs should be pairs of metric name and value. + // It's thread-safe, and can be called from different threads. 
Unlike + // RecordCosts where costs are summed up if recorded with the same key, + // metrics are replaced. + void RecordMetrics( + const std::vector>& metrics); + + // Gets all types of metrics for processing an rpc request. + // It's thread-safe. It's expected to be called at the end of processing an + // rpc request, when all the metrics have been collected. + absl::flat_hash_map GetMetrics() const; + + // Metrics of each batch that processes this rpc request. + struct BatchMetrics { + // Size of the batch. + int64_t processed_size = 0; + // In this batch, input size from this rpc request. + int64_t input_size = 0; + // In this batch, the padding amount. + int64_t padding_size = 0; + // Costs for processing this batch. + absl::flat_hash_map batch_costs; + }; + + // Records the metrics of a batch. + // It's thread-safe, and can be called from different threads. It may be + // called multiple times if a request is processed by more than one batches. + void RecordBatchMetrics(const BatchMetrics& batch_metrics); + + // Scales costs of all the batches that process this rpc request. + // It's thread-safe. It's expected to be called at the end of processing an + // rpc request, when all batch processing has completed. + void ScaleBatchCosts(int scale_factor); + + // Get metrics of all the batches that process this rpc request. + // It's thread-safe. It's expected to be called at the end of processing an + // rpc request, when all batch processing has completed. + std::vector GetBatchMetrics() const; + + private: + mutable absl::Mutex mutex_; + + // Query costs. Map from cost type to cost. + absl::flat_hash_map cost_map_ + ABSL_GUARDED_BY(mutex_); + // Query metrics. Map from metric name to value. + absl::flat_hash_map metric_map_ ABSL_GUARDED_BY(mutex_); + + // Metrics of batches that process this rpc request. + std::vector batch_metrics_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost_accessor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost_accessor.h new file mode 100644 index 00000000..ba64da4b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost_accessor.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_ACCESSOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_ACCESSOR_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/time/time.h" +#include "tensorflow/core/common_runtime/request_cost.h" + +namespace tensorflow { + +// An interface for accessing the RequestCost associated with the current rpc +// request. 
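A short sketch of recording per-request costs with the RequestCost API above (the cost and metric names are arbitrary, and the batch_costs value type is assumed to be an absl::Duration keyed by cost name):

```cpp
#include "absl/time/time.h"
#include "tensorflow/core/common_runtime/request_cost.h"

void RecordExampleCosts(tensorflow::RequestCost* request_cost) {
  // Costs with the same key are summed across calls.
  request_cost->RecordCost({{"tpu_compute", absl::Milliseconds(7)},
                            {"host_compute", absl::Milliseconds(2)}});
  // Metrics with the same key are replaced, not summed.
  request_cost->RecordMetrics({{"batch_size", 32.0}});

  tensorflow::RequestCost::BatchMetrics batch;
  batch.processed_size = 64;
  batch.input_size = 32;
  batch.padding_size = 0;
  batch.batch_costs["tpu_compute"] = absl::Milliseconds(7);
  request_cost->RecordBatchMetrics(batch);
}
```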
+class RequestCostAccessor { + public: + virtual ~RequestCostAccessor() {} + virtual RequestCost* GetRequestCost() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_ACCESSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost_accessor_registry.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost_accessor_registry.h new file mode 100644 index 00000000..6e91678f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/request_cost_accessor_registry.h @@ -0,0 +1,71 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_ACCESSOR_REGISTRY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_ACCESSOR_REGISTRY_H_ + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/request_cost_accessor.h" + +namespace tensorflow { + +// TODO(b/185852990): Create a template Registry that allows registering +// different types (e.g RequestCostAccessor, CostMeasurement). +// +// RequestCostAccessorRegistry allows to +// - register a RequestCostAccessor type to the global map +// - create an instance of registered RequestCostAccessor. +class RequestCostAccessorRegistry { + public: + // Creates an instance of registered RequestCostAccessor by name. If the named + // RequestCostAccessor is not registered yet, returns nullptr. + static std::unique_ptr CreateByNameOrNull( + absl::string_view name); + + using Creator = std::function()>; + + // Registers a RequestCostAccessor type to the global map. Registering + // different types of RequestCostAccessor with the same name is prohibited. + static void RegisterRequestCostAccessor(absl::string_view name, + Creator creator); +}; + +// Registers a RequestCostAccessor type to the global map. Registering different +// types of RequestCostAccessor with the same name is prohibited. 
+class RequestCostAccessorRegistrar { + public: + explicit RequestCostAccessorRegistrar( + absl::string_view name, RequestCostAccessorRegistry::Creator creator) { + RequestCostAccessorRegistry::RegisterRequestCostAccessor( + name, std::move(creator)); + } +}; + +#define REGISTER_REQUEST_COST_ACCESSOR(name, MyRequestCostAccessorClass) \ + namespace { \ + static ::tensorflow::RequestCostAccessorRegistrar \ + MyRequestCostAccessorClass##_registrar((name), [] { \ + return std::make_unique(); \ + }); \ + } // namespace + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_REQUEST_COST_ACCESSOR_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_alg.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_alg.h new file mode 100644 index 00000000..df907258 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_alg.h @@ -0,0 +1,121 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class Device; + +// Basic ring-algorithm implementation to be further specialized +// for specific collective functions. +class RingAlg : public CollectiveImplementationInterface { + public: + explicit RingAlg(CollectiveType type, const string& name); + ~RingAlg() override {} + + // Establishes the requested number of subdivision permutations based on the + // ring order implicit in the device order. + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; + + // Initializes members of CollectiveContext not yet initialized, i.e. device + // and device_locality. Also saves the CollectiveContext in this object. + absl::Status InitializeCollectiveContext( + std::shared_ptr col_ctx) override; + + protected: + // Called when a bad status is received that implies we should terminate + // execution and return a bad status. + void StartAbort(const absl::Status& s); + void Finish(bool ok); + + // Current status of a RingField + enum RingFieldAction { + RF_INIT = 0, // Just initialized for a pass + RF_RECV, // Recv pending + RF_REDUCE, // Reduce pending + RF_FINALIZE, // FinalOp pending + RF_SEND_READY, // Ready to send + RF_SEND, // Send pending + RF_DONE, // No more work + }; + + // Tracks progress of actions on a single subfield of the entire tensor. 
+ struct RingField { + int16 chunk_idx; // major division index + int16 subdiv_idx; // minor division index + int16 sc_idx; // subchunk index + int16 rank; // rank within subdiv permutation + int16 recv_dev_idx; // dev from which value should be recv'd + RingFieldAction action; + bool second_pass; + bool recv_is_remote = false; + bool send_is_remote = false; + bool do_send = false; // is the value sent in this pass? + bool do_recv = false; // is the value recv'd in this pass? + bool is_final = false; // is the last field in the pass for this rank + Tensor chunk; // alias to field values + Tensor tmp_chunk; + absl::Status status; + string DebugString() const; + }; + virtual void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, + int field_idx); + void AdvanceToSecondPass(RingField* rf); + void DispatchSend(RingField* rf, const StatusCallback& done); + void DispatchRecv(RingField* rf, const StatusCallback& done); + + // For constructing log messages for debugging. + string FieldState(); + string TensorDebugString(const Tensor& tensor); + + // Producer/Consumer Queue of RingField structs. + class PCQueue { + public: + void Enqueue(RingField* rf); + RingField* Dequeue(); + + private: + mutex pcq_mu_; + condition_variable cv_; + int waiter_count_ TF_GUARDED_BY(pcq_mu_) = 0; + std::deque deque_ TF_GUARDED_BY(pcq_mu_); + }; + + const CollectiveType type_; + const string name_; + std::shared_ptr col_ctx_; + const CollectiveParams* col_params_; // Not owned + StatusCallback done_; + int group_size_; + int num_subdivs_; + Tensor group_size_tensor_; + Notification group_size_tensor_ready_; + std::unique_ptr ca_; + mutex status_mu_; + absl::Status status_ TF_GUARDED_BY(status_mu_); + std::vector rfv_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RING_ALG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_gatherer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_gatherer.h new file mode 100644 index 00000000..ac894a38 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_gatherer.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_GATHERER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_GATHERER_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/ring_alg.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class Device; + +// Ring-algorithm implementation of collective all-gather. +class RingGatherer : public RingAlg { + public: + RingGatherer() : RingAlg(GATHER_COLLECTIVE, "Gather") {} + ~RingGatherer() override {} + + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; + + // Begins async execution of the ring gather algorithm. 
+ // Must be called in a blockable thread. + // TODO(b/80529858): remove the previous warning when we have a dedicated + // collective threadpool. + void Run(StatusCallback done) override; + + private: + bool RunAsyncParts(); + + friend class RingGathererTest; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RING_GATHERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_reducer.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_reducer.h new file mode 100644 index 00000000..77317235 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/ring_reducer.h @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/base_collective_executor.h" +#include "tensorflow/core/common_runtime/ring_alg.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class Device; + +// Ring-algorithm implementation of collective all-reduce. +class RingReducer : public RingAlg { + public: + RingReducer() : RingAlg(REDUCTION_COLLECTIVE, "Reduce") {} + ~RingReducer() override; + + // Begins async execution of the ring reduce algorithm. + // Must be called in a blockable thread. + // TODO(b/80529858): remove the previous warning when we have a dedicated + // collective threadpool. + void Run(StatusCallback done) override; + + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; + + protected: + void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, + int field_idx) override; + + private: + void ContinueAfterInputCopy(); + bool RunAsyncParts(); + + Tensor group_size_tensor_; + Notification group_size_tensor_ready_; + + friend class RingReducerTest; + friend class RingReducerInitParamsTest; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/scoped_allocator.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/scoped_allocator.h new file mode 100644 index 00000000..5b22deb2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/scoped_allocator.h @@ -0,0 +1,127 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SCOPED_ALLOCATOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SCOPED_ALLOCATOR_H_ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +class ScopedAllocatorContainer; +class ScopedAllocatorInstance; + +// Manages a single backing tensor and a collection of aliases. +class ScopedAllocator { + public: + static constexpr int32_t kInvalidId = 0; + static constexpr size_t kMaxAlignment = 64; + + // A subrange of the TensorBuffer associated with this object that + // will be the backing memory for one aliased tensor. + struct Field { + int32 scope_id; + size_t offset; + size_t bytes_requested; + size_t bytes_allocated; + }; + // Field index that refers to backing tensor, not any aliased field. + static constexpr int32_t kBackingIndex = -1; + + // backing_tensor is expected to be newly allocated by a ScopedAllocatorOp + // instance. It must be large enough to back all of the specified + // (offset, byte) ranges of the fields. + ScopedAllocator(const Tensor& backing_tensor, int32_t scope_id, + const std::string& name, const absl::Span fields, + int32_t expected_call_count, + ScopedAllocatorContainer* container); + + // Automatically deletes when last use expires, or when + // ScopedAllocatorContainer decides to delete. + ~ScopedAllocator() TF_LOCKS_EXCLUDED(mu_); + + // For debugging: returns true iff p is a pointer that could have + // been returned by AllocateRaw. + bool VerifyPointer(const void* p); + bool VerifyTensor(const Tensor* t); + + const Tensor& tensor() const { return backing_tensor_; } + + const std::string& name() const { return name_; } + + private: + friend class ScopedAllocatorInstance; + // Only ScopedAllocatorInstances can call AllocateRaw and DeallocateRaw on a + // ScopedAllocator + void* AllocateRaw(int32_t field_index, size_t num_bytes) + TF_LOCKS_EXCLUDED(mu_); + void DeallocateRaw(void* p) TF_LOCKS_EXCLUDED(mu_); + Tensor backing_tensor_; + TensorBuffer* tbuf_; + int32 id_; + std::string name_; + ScopedAllocatorContainer* container_; + std::vector fields_; + mutex mu_; + int32 expected_call_count_ TF_GUARDED_BY(mu_); + int32 live_alloc_count_ TF_GUARDED_BY(mu_); +}; + +// An Allocator that will return a pointer into the backing buffer of +// a previously allocated tensor, allowing creation of an alias +// tensor. There is a one-to-one mapping between the fields of a +// ScopedAllocator and ScopedAllocatorInstances. There is also a one-to-one +// mapping between scope_ids and ScopedAllocatorInstances. It should be +// discarded immediately after a single use. +class ScopedAllocatorInstance : public Allocator { + public: + explicit ScopedAllocatorInstance(ScopedAllocator* sa, int32_t field_index); + + private: + ~ScopedAllocatorInstance() override { + VLOG(1) << "~ScopedAllocatorInstance " << this; + } + + public: + // When a ScopedAllocatorContainer "Drops" a scope_id, it calls DropFromTable + // on the underlying ScopedAllocatorInstance. If this instance has already + // deallocated the tensor slice, we can safely delete this. 
+ void DropFromTable() TF_LOCKS_EXCLUDED(mu_); + void* AllocateRaw(size_t alignment, size_t num_bytes) + TF_LOCKS_EXCLUDED(mu_) override; + void* AllocateRaw(size_t alignment, size_t num_bytes, + const AllocationAttributes& allocator_attr) override { + return AllocateRaw(alignment, num_bytes); + } + void DeallocateRaw(void* p) TF_LOCKS_EXCLUDED(mu_) override; + bool TracksAllocationSizes() const override { return false; } + size_t RequestedSize(const void* ptr) const override { return 0; } + size_t AllocatedSize(const void* ptr) const override { return 0; } + int64_t AllocationId(const void* ptr) const override { return 0; } + size_t AllocatedSizeSlow(const void* ptr) const override { return 0; } + std::string Name() override; + + private: + mutex mu_; + ScopedAllocator* scoped_allocator_; + int32 field_index_; + bool allocated_ TF_GUARDED_BY(mu_); + bool deallocated_ TF_GUARDED_BY(mu_); + bool in_table_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SCOPED_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/scoped_allocator_mgr.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/scoped_allocator_mgr.h new file mode 100644 index 00000000..dbbf7c32 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/scoped_allocator_mgr.h @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SCOPED_ALLOCATOR_MGR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SCOPED_ALLOCATOR_MGR_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/scoped_allocator.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +class ScopedAllocatorMgr; + +// At most one of these exists per pair. +// A Ref is held by every ScopedAllocator and also by the ScopedAllocatorMgr. +class ScopedAllocatorContainer : public core::RefCounted { + public: + // Establishes a reachable ScopedAllocator. + absl::Status AddScopedAllocator( + const Tensor& backing_tensor, int32_t scope_id, + const std::string& scope_name, + const absl::Span& fields, + int32_t expected_call_count); + + ScopedAllocatorInstance* GetInstance(int32_t scope_id); + ScopedAllocator* GetAllocator(int32_t scope_id); + + // Retire the scope_id. 
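// An illustrative sketch, not from the upstream header, of the per-step
// container flow. It assumes `container` came from
// ScopedAllocatorMgr::GetContainer(step_id), `backing` is a freshly allocated
// backing tensor, and `fields` is a non-empty alias layout as produced by
// ScopedAllocatorMgr::PopulateFields; the scope id and name are hypothetical.
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/platform/errors.h"

absl::Status SetUpAliases(
    tensorflow::ScopedAllocatorContainer* container,
    const tensorflow::Tensor& backing,
    absl::Span<const tensorflow::ScopedAllocator::Field> fields) {
  constexpr int32_t kScopeId = 100;  // id of the backing ScopedAllocator
  TF_RETURN_IF_ERROR(container->AddScopedAllocator(
      backing, kScopeId, "example_scope", fields,
      /*expected_call_count=*/static_cast<int32_t>(fields.size())));
  // Each field's scope_id names a single-use allocator aliasing a slice of
  // `backing`; it would normally be handed to a Tensor constructor.
  tensorflow::ScopedAllocatorInstance* instance =
      container->GetInstance(fields[0].scope_id);
  (void)instance;
  return absl::OkStatus();
}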
+ void Drop(int32_t scope_id, ScopedAllocator* sa); + + protected: + friend class ScopedAllocatorMgr; + ScopedAllocatorContainer(const ScopedAllocatorMgr* mgr, int64_t step_id) + : mgr_(mgr), step_id_(step_id) {} + ~ScopedAllocatorContainer(); + + private: + const ScopedAllocatorMgr* mgr_; + int64_t step_id_; + mutex mu_; + struct SAField { + int32 field_index; + union { + ScopedAllocator* scoped_allocator; + ScopedAllocatorInstance* instance; + }; + SAField(int32_t fi, ScopedAllocatorInstance* sai) + : field_index(fi), instance(sai) {} + SAField(int32_t fi, ScopedAllocator* sa) + : field_index(fi), scoped_allocator(sa) {} + SAField() + : field_index(ScopedAllocator::kBackingIndex), + scoped_allocator(nullptr) {} + }; + std::unordered_map allocators_ TF_GUARDED_BY(mu_); +}; + +// At most one of these exists per device. +class ScopedAllocatorMgr { + public: + explicit ScopedAllocatorMgr(const std::string& device_name) + : device_name_(device_name) {} + ~ScopedAllocatorMgr(); + + ScopedAllocatorContainer* GetContainer(int64_t step_id); + + // Establishes a reachable ScopedAllocator. + absl::Status AddScopedAllocator( + const Tensor& backing_tensor, int64_t step_id, int32_t scope_id, + const std::string& scope_name, + const absl::Span& fields, + int32_t expected_call_count); + + void Cleanup(int64_t step_id); + + // Populate the bytes and offset members of Field. Instance allocaters get + // consecutive scope_id values following that of the base ScopedAllocator. + // Returns the total number of bytes required to be allocated in the + // backing tensor, for convenience. (The same value can be obtained + // by summing offset and bytes in the last field.) + static size_t PopulateFields(int32_t scope_id, + const absl::Span& shapes, + const DataType dtype, + std::vector* fields); + + const std::string& device_name() const { return device_name_; } + + private: + std::string device_name_; + mutex mu_; + std::unordered_map per_step_map_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SCOPED_ALLOCATOR_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/session_factory.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/session_factory.h new file mode 100644 index 00000000..ffadb29a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/session_factory.h @@ -0,0 +1,76 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_ + +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Session; +struct SessionOptions; + +class SessionFactory { + public: + // Creates a new session and stores it in *out_session, or fails with an error + // status if the Session could not be created. Caller takes ownership of + // *out_session if this returns OkStatus(). + virtual absl::Status NewSession(const SessionOptions& options, + Session** out_session) = 0; + + virtual bool AcceptsOptions(const SessionOptions& options) = 0; + + // Abort and close all existing sessions, disconnecting their resources from + // future sessions. + // + // Reset() allows misbehaving or slow sessions to be aborted and closed, and + // causes their resources eventually to be released. Reset() does not wait + // for the computations in old sessions to cease; it merely starts the + // process of tearing them down. However, if a new session is started after + // a Reset(), the new session is isolated from changes that old sessions + // (started prior to the Reset()) may continue to make to resources, provided + // all those resources are in containers listed in "containers". + // + // Old sessions may continue to have side-effects on resources not in + // containers listed in "containers", and thus may affect future + // sessions' results in ways that are hard to predict. Thus, if well-defined + // behavior is desired, is it recommended that all containers be listed in + // "containers". + // + // If the "containers" vector is empty, the default container is assumed. + // If the "containers" vector is non-empty, the default container should be + // listed explicitly. + // + // Sessions that support resource containers should override this function. + virtual absl::Status Reset(const SessionOptions& options, + const std::vector& containers) { + return errors::Unimplemented("Reset()"); + } + + virtual ~SessionFactory() {} + static void Register(const string& runtime_type, SessionFactory* factory); + static absl::Status GetFactory(const SessionOptions& options, + SessionFactory** out_factory); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/shape_refiner.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/shape_refiner.h new file mode 100644 index 00000000..580dafb0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/shape_refiner.h @@ -0,0 +1,293 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace grappler { +class GraphProperties; +} + +// ShapeRefiner performs shape inference for TensorFlow Graphs. It is +// responsible for instantiating InferenceContext objects for each +// Node in the Graph, and providing/storing the 'input_tensor' Tensors +// used by Shape Inference functions, when available at graph +// construction time. +class ShapeRefiner { + public: + ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops); + + // Same as ShapeRefiner(versions.producer(), ops) + ShapeRefiner(const VersionDef& versions, const OpRegistryInterface* ops); + + ~ShapeRefiner(); + + // Performs validation of 'node' and runs 'node's shape function, + // storing its shape outputs. + // + // All inputs of 'node' must be added to ShapeRefiner prior to + // adding 'node'. + // + // Returns an error if: + // - the shape function for 'node' was not registered. + // - 'node' was added before its inputs. + // - The shape inference function returns an error. + absl::Status AddNode(const Node* node); + + // Sets 'node's 'output_port' output to have shape 'shape'. + // + // Returns an error if 'node' was not previously added to this + // object, if 'output_port' is invalid, or if 'shape' is + // not compatible with the existing shape of the output. + absl::Status SetShape(const Node* node, int output_port, + shape_inference::ShapeHandle shape); + + // Update the input shapes of node in case the shapes of the fan-ins of 'node' + // have themselves been modified (For example, in case of incremental shape + // refinement). If 'relax' is true, a new shape with the broadest set of + // information will be set as the new input (see InferenceContext::RelaxInput + // for full details and examples). Sets refined to true if any shapes have + // changed (in their string representations). Note that shapes may have been + // updated to newer versions (but with identical string representations) even + // if <*refined> is set to false. + absl::Status UpdateNode(const Node* node, bool relax, bool* refined); + + // Returns the InferenceContext for 'node', if present. + shape_inference::InferenceContext* GetContext(const Node* node) const { + auto it = node_to_context_.find(node); + if (it == node_to_context_.end()) { + return nullptr; + } + return it->second.get(); + } + + // Getters and setters for graph_def_version_. + int32 graph_def_version() const { return graph_def_version_; } + void set_graph_def_version(int32_t version) { graph_def_version_ = version; } + + void set_require_shape_inference_fns(bool require_shape_inference_fns) { + require_shape_inference_fns_ = require_shape_inference_fns; + } + void set_disable_constant_propagation(bool disable) { + disable_constant_propagation_ = disable; + } + + // Set function library to enable function shape inference. + // Without function library, function inference always yields unknown shapes. 
+ // With this enabled, shape inference can take more time since it descends + // into all function calls. It doesn't do inference once for each function + // definition, but once for each function call. + // The function library must outlive the shape refiner. + void set_function_library_for_shape_inference( + const tensorflow::FunctionLibraryDefinition* lib) { + function_library_ = lib; + } + + bool function_shape_inference_supported() const { + return function_library_ != nullptr; + } + + private: + friend class ShapeRefinerTest; + friend class ::tensorflow::grappler::GraphProperties; + + // Returns true if the ranks and all dimensions of and are either + // equal in value or both unknown. + static bool SameDefinedShape(shape_inference::InferenceContext* c, + shape_inference::ShapeHandle s0, + shape_inference::ShapeHandle s1); + + // Returns true if the shapes and types stored in <*existing> are identical in + // value to the shapes and types in <*updated>. + static bool IsUpdatedShapesOrTypes( + shape_inference::InferenceContext* c, + const std::vector& existing, + const std::vector& updated); + + // Performs shape inference for the given function_def within the + // given outer_context. Internally it instantiates the function as a graph + // and runs shape inference recursively on it with the input shapes provided + // by the outer_context. + // + // Returns an error if: + // - number of inputs/outputs on outer_context doesn't match the function_def + // + // On success: + // - outer_context will contain output shapes inferred from input shapes + absl::Status InferShapesForFunction( + const FunctionDef* function_def, AttrSlice attributes, + shape_inference::InferenceContext* outer_context); + + // Performs shape inference for a node inside a function. + // + // 'outer_context' is the 'InferenceContext' for the function's call op. + absl::Status InferShapesForFunctionSubNode( + const Node* node, shape_inference::InferenceContext* outer_context); + + // Performs validation of 'node' and runs 'node's shape function, + // storing its shape outputs. + // + // All inputs of 'node' must be added to ShapeRefiner prior to + // adding 'node'. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + // + // Returns an error if: + // - the shape function for 'node' was not registered. + // - 'node' was added before its inputs. + // - The shape inference function returns an error. + absl::Status AddNodeInternal( + const Node* node, shape_inference::InferenceContext* outer_context); + + // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge + // value can be evaluated, 'evaluated' is set to true and the value returned + // in 'result'. Otherwise 'evaluated' is set to false. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. 
+ absl::Status EvaluateConstantTensorForEdge( + const Node* node, int dst_idx, bool* evaluated, Tensor* result, + shape_inference::InferenceContext* outer_context); + + // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input + // tensors. The caller is responsible for checking that the specified edge is + // scalar and int32 or int64. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + absl::Status EvaluateConstantIntScalarEdge( + const Node* node, int dst_idx, bool* evaluated, int64_t* result, + shape_inference::InferenceContext* outer_context); + + // This function tries to materialize as much information about the 'node''s + // dst_idx input as a statically computable shape, and the result may be + // partially known, depending on what is statically inferable. + // + // This is called when node.input[dst_idx] is a tensor that is used to define + // the shape of some other tensor (e.g., the second argument to Reshape is a + // tensor, where each element of the shape tensor is a dimension of + // the target tensor). It returns in a shape for that input. + // + // Unlike simply resolving node.input[dst_idx] to a constant and then + // converting that to a shape, this function can return a partial shape. This + // is useful for cases where the shape tensor is only partially defined, such + // as with calls for: reshape(x, shape(y)) where shape(y) is partially + // defined. + // + // The implementation has op implementations for ops commonly called on shape + // tensors, and the implementations are specialized to shape tensors (namely, + // the output is a vector). + // + // is used when creating new DimensionHandle and ShapeHandle + // objects. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + absl::Status ConstantPartialShape( + shape_inference::InferenceContext* target_context, const Node* node, + int dst_idx, shape_inference::ShapeHandle* result, + shape_inference::InferenceContext* outer_context); + + // Implementation of ConstantPartialShape for StridedSlice nodes. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. + absl::Status PartialStridedSliceShape( + Node* slice_node, shape_inference::InferenceContext* ctx, + shape_inference::ShapeHandle* result, + shape_inference::InferenceContext* outer_context); + + // Runs the shape function registered for the node's op type. + // + // Optionally, if 'node' is in a nested function, the 'InferenceContext' for + // the call op of the function can be passed as 'outer_context' (pass nullptr + // otherwise). This gets used to perform constant propagation across Arg nodes + // by requesting the constant of value of the incoming tensor from the + // 'outer_context'. 
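// An illustrative sketch, not from the upstream header, of typical public-API
// use while building a graph: every node is added after its inputs, then its
// inferred output shapes can be inspected. `order` is assumed to be a
// topologically sorted node list; TF_GRAPH_DEF_VERSION and OpRegistry::Global()
// are the usual constructor arguments.
#include <vector>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

absl::Status InferGraphShapes(const std::vector<tensorflow::Node*>& order) {
  tensorflow::ShapeRefiner refiner(TF_GRAPH_DEF_VERSION,
                                   tensorflow::OpRegistry::Global());
  for (tensorflow::Node* node : order) {
    TF_RETURN_IF_ERROR(refiner.AddNode(node));
    tensorflow::shape_inference::InferenceContext* ctx = refiner.GetContext(node);
    if (ctx != nullptr && ctx->num_outputs() > 0) {
      VLOG(2) << node->name() << " output 0: " << ctx->DebugString(ctx->output(0));
    }
  }
  return absl::OkStatus();
}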
+ absl::Status RunShapeFn( + const Node* node, const OpRegistrationData* op_reg_data, + shape_inference::InferenceContext* context, + shape_inference::InferenceContext* outer_context = nullptr); + + int32 graph_def_version_; + const OpRegistryInterface* const ops_registry_; + + // The lifetime of the tensors are bound to the runner, so it should be the + // deleted after the tensors. + GraphRunner graph_runner_; + + // Stores a map from a node to its InferenceContext. + absl::flat_hash_map, + hash> + node_to_context_; + + // Holds a cache from tensor id (node id:node output) to the tensor that + // is evaluable as a constant expression. This reduces repeated execution + // of the entire constant subgraph as a graph is being built up. This could + // be changed to some kind of size-based LRU cache to avoid consuming too much + // memory, if that eventually becomes a concern. + // + // Only tensors less than 1KiB are currently stored in the cache. + static constexpr int64_t kMaxTensorSize = 1024; + absl::flat_hash_map, Tensor> const_tensor_map_; + + bool require_shape_inference_fns_ = true; + bool disable_constant_propagation_ = false; + + // Function library is optional, but has to be set to enable function + // shape inference. + const tensorflow::FunctionLibraryDefinition* function_library_ = nullptr; + + // Cache the graph corresponding to each function definition for which shapes + // are refined. + absl::flat_hash_map> functions_; + + ShapeRefiner(const ShapeRefiner&) = delete; + void operator=(const ShapeRefiner&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHAPE_REFINER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/shared_counter.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/shared_counter.h new file mode 100644 index 00000000..d40f24f9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/shared_counter.h @@ -0,0 +1,26 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHARED_COUNTER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SHARED_COUNTER_H_ + +#include + +#include "xla/tsl/framework/shared_counter.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::SharedCounter; // NOLINT +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHARED_COUNTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/simple_propagator_state.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/simple_propagator_state.h new file mode 100644 index 00000000..9f465ef1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/simple_propagator_state.h @@ -0,0 +1,190 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ + +#include + +#include "tensorflow/core/common_runtime/entry.h" +#include "tensorflow/core/common_runtime/immutable_executor_state.h" +#include "tensorflow/core/common_runtime/pending_counts.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the ephemeral "edge state" associated with one invocation of +// `Executor::Run()`. +// +// NOTE: `SimplePropagatorState` does not support "v1-style" control flow, +// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the +// graph. Use `PropagatorState` for graphs with those features. +// `SimplePropagatorState` *does* support "v2-style" or "functional" control +// flow. +// +// `SimplePropagatorState` is responsible for propagating values along dataflow +// edges in a TensorFlow graph and determining which nodes are runnable. The +// executor primarily updates `SimplePropagatorState` by calling +// `PropagateOutputs()` after processing a node, and `SimplePropagatorState` +// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`. +class SimplePropagatorState { + public: + SimplePropagatorState(const ImmutableExecutorState& immutable_state, + int64_t step_id, bool vlog); + ~SimplePropagatorState(); + + // A `TaggedNode` corresponds to a single invocation of a node's kernel, + // and it is created when the kernel becomes runnable. + struct TaggedNode { + const NodeItem* node_item; + + explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {} + + const NodeItem& get_node_item() const { return *node_item; } + + bool get_is_dead() const { return false; } + int64_t get_iter_num() const { return 0; } + }; + + // A drop-in replacement for std::deque. We typically don't + // have that many nodes in the ready queue, so we just use a vector and + // don't free up memory from the queue as we consume nodes. + // TODO(mrry): Extract this and share it with the version in + // `PropagatorState`. The correct constants might be different, since + // sizeof(TaggedNode) is smaller in this version. + class TaggedNodeReadyQueue { + public: + TaggedNodeReadyQueue() : front_index_(0) {} + + void push_back(const TaggedNode& node) { ready_.push_back(node); } + TaggedNode front() const { + DCHECK_LT(front_index_, ready_.size()); + return ready_[front_index_]; + } + void pop_front() { + DCHECK_LT(front_index_, ready_.size()); + front_index_++; + if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { + if (front_index_ == ready_.size()) { + ready_.clear(); + } else { + // Lots of unused entries at beginning of vector: move everything + // down to start of vector. 
+ ready_.erase(ready_.begin(), ready_.begin() + front_index_); + } + front_index_ = 0; + } + } + bool empty() const { return ready_.empty(); } + int size() const { return ready_.size() - front_index_; } + + private: + // TODO(b/152925936): Re-evaluate these constants with current usage + // patterns. + static constexpr int kSpillThreshold = 16384; + absl::InlinedVector ready_; + int front_index_; + }; + + // TODO(b/152925936): Re-evaluate this constant with current usage patterns. + typedef absl::InlinedVector TaggedNodeSeq; + + // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. + void ActivateRoots(gtl::ArraySlice roots, + TaggedNodeSeq* ready); + + // After processing the outputs, propagates the outputs to their dsts. + // Contents of *outputs are left in an indeterminate state after + // returning from this method. + void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, + TaggedNodeSeq* ready); + + // Returns an array of `Entry` objects corresponding to the inputs of + // `tagged_node`. + Entry* GetInputTensors(const TaggedNode& tagged_node) { +#if defined(THREAD_SANITIZER) || defined(DEBUG) + // NOTE: This read of `pending_[...]` works around a limitation in TSAN. + // To avoid false positive data race reports, we need to perform an atomic + // object access that will establish the happens-before relation between + // the write to input_tensors_ in `PropagateOutputs()` and the read in + // `PrepareInputs()`. + CHECK_EQ(pending_[tagged_node.node_item->node_id], 0); +#endif // defined(THREAD_SANITIZER) || defined(DEBUG) + return input_tensors_.data() + tagged_node.node_item->input_start; + } + + FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { + return {0, 0}; + } + + // Provide debugging output of the state of the executor. + void DumpState(); + + // For debugging/logging only. + void MaybeMarkStarted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(mu_); + (*active_)[tagged_node.node_item->node_id] = true; + } + } + void MaybeMarkCompleted(const TaggedNode& tagged_node) { + // TODO(misard) Replace with a finer-grain enabling flag once we add better + // optional debugging support. + if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { + mutex_lock l(mu_); + (*active_)[tagged_node.node_item->node_id] = false; + } + } + + private: + SimplePropagatorState(const ImmutableExecutorState& immutable_state_, + int64_t step_id, + const ImmutableExecutorState::FrameInfo& finfo, + bool vlog); + + const ImmutableExecutorState& immutable_state_; + const int64_t step_id_; + const bool vlog_; + + // The i-th node's j-th input is stored at + // `input_tensors[impl_->nodes[i].input_start + j]`. + // + // NOTE: No need to protect input_tensors[i] by any locks because it + // is resized once. Each element of input_tensors is written once by the + // source node of an edge and is cleared by the destination of the same + // edge. The destination node always runs after the source node, so there + // is never concurrent access to the same entry. + std::vector input_tensors_; + + std::unique_ptr[]> pending_; + + // If `vlog_` is true, this stores a bit vector of active nodes, indexed by + // node ID. 
+ mutex mu_; + std::unique_ptr> active_ TF_GUARDED_BY(mu_); + + const std::vector* const nodes_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h new file mode 100644 index 00000000..553e298f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h @@ -0,0 +1,109 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLIFY_ICI_DUMMY_VARIABLES_PASS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLIFY_ICI_DUMMY_VARIABLES_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/platform/status.h" + +// Create new dummy zero variables to TPUExecute Op for ICI +// weight distribution, which is a critical feature in TF2/Min. The new dummy +// zero variables will be put on the same task as the TPUExecute Op. The old +// dummy zero variables will be removed afterwards. +// +// For example, in the following graph, the inputs to TPUExecute Op are on +// task:0, after the pass, the dummy zero variables will be put on task:2. +// which is the same as the TPUExecute. 
+// +// The graph before pass is: +// +// node {name: "const0", op: "Const"} +// node {name: "const1", op: "Const"} +// node {name: "fill0", op: "Fill", input: "const1", input: "const0"} +// node {name: "Identity0", op: "Identity", input: "fill0", +// device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" +// attr { +// key: "_ici_weight_distribution_mlir_bridge_marker", value {b: true} +// } +// } +// node {name: "const2", op: "Const"} +// node {name: "const3", op: "Const"} +// node {name: "fill1", op: "Fill", input: "const2", input: "const3"} +// node {name: "identity1", op: "Identity", input: "fill1" +// device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" +// attr { +// key: "_ici_weight_distribution_mlir_bridge_marker", value {b: true} +// } +// } +// node {name: "const4", op: "Const"} +// node {name: "split0", op: "Split", input: "const4", input: "identity1" +// attr { +// key: "_ici_weight_distribution_mlir_bridge_marker" +// value {b: true} +// } +// } +// node {name: "TPUExecute0", op: "TPUExecute" +// input: "identity0", input: "split0:1" +// device: "/job:worker/replica:0/task:2/device:TPU:0" +// attr { +// key: "_parallel_execution_ids" +// value {s: "r0:1,p0:2"} +// } +// } +// +// The graph after pass is: +// +// node {name: "const0_dummy", op: "Const", +// device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" +// } +// node {name: "const1_dummy", op: "Const", +// device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" +// } +// node {name: "fill0_dummy", op: "Fill", +// input: "const1_dummy", input: "const0_dummy", +// device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" +// } +// node {name: "const2_dummy", op: "Const", +// device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" +// } +// node {name: "const3_dummy", op: "Const", +// device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" +// } +// node {name: "fill1_dummy", op: "Fill", +// input: "const2_dummy", input: "const3_dummy", +// device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" +// } +// node {name: "TPUExecute0", op: "TPUExecute" +// input: "fill0_dummy", input: "fill1_dummy" +// device: "/job:worker/replica:0/task:2/device:TPU:0" +// attr { +// key: "_parallel_execution_ids" +// value {s: "r0:1,p0:2"} +// } +// } + +namespace tensorflow { + +// This pass will simplify the dummy variables for ICI weight distribution. +// The dummy variables will be put on the same task as the TPUExecute Op. +class SimplifyIciDummyVariablesPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLIFY_ICI_DUMMY_VARIABLES_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/single_threaded_cpu_device.h new file mode 100644 index 00000000..3498e4aa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/single_threaded_cpu_device.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_CPU_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_CPU_DEVICE_H_ + +namespace tsl { +class Env; +} // namespace tsl +namespace tensorflow { +using Env = tsl::Env; + +class Device; + +// Returns a simple single-threaded CPU device. This can be used to run +// inexpensive computations. In particular, using this avoids initializing the +// global thread pools in LocalDevice. +// +// The returned pointer is owned by the caller. +Device* NewSingleThreadedCpuDevice(Env* env); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_CPU_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/single_threaded_executor.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/single_threaded_executor.h new file mode 100644 index 00000000..55749ed6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/single_threaded_executor.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_EXECUTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_EXECUTOR_H_ + +#include "tensorflow/core/common_runtime/executor.h" + +namespace tensorflow { + +// Creates a new `Executor` for executing `graph` synchronously on the caller +// thread. +// +// NOTE(mrry): The returned executor is optimized to impose low overhead on +// graphs that perform a small amount of work (e.g. <15us of work per graph on +// present architectures). It eschews concurrency, because issuing work to +// multiple threads can dominate the cost of executing small ops synchronously, +// and because contention in the executor data structures can reduce throughput +// (in terms of ops executed per unit time). +// +// However, the current implementation has the following limitations: +// +// 1. Reference-typed tensors are not supported and will not be supported in +// future. +// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not +// currently supported. The current plan is to extend support to "functional" +// control flow after the TensorFlow APIs transition to building graphs in +// that form (e.g. `tf.cond_v2()`). +// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported. 
+// The present implementation executes kernels one at a time in topological +// order, and cannot currently distinguish between disconnected subgraphs +// that are logically connected by subgraphs on a different device. +// 4. Memory logging is not currently supported. +// 5. Allocation forwarding is not currently supported. +// 6. Non-default device contexts are not currently supported. In effect, this +// limits the executor to CPU devices. +// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null +// are not currently supported. +// +// The single-threaded executor is primarily suitable for executing simple +// TensorFlow functions, such as one might find in a `tf.data` pipeline. +absl::Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + const Graph& graph, Executor** executor); + +// Returns OkStatus() for ops which are compatible with synchronous execution, +// and otherwise returns an error message appropriate for propagation if needed. +// If `allow_control_flow_sync_execution` is set to `true` control +// nodes are marked as safe for execution on the SingleThreadedExecutor. +absl::Status ValidateOpIsSafeForSyncExecution( + const Node& n, bool allow_control_flow_sync_execution); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/stats_publisher_interface.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/stats_publisher_interface.h new file mode 100644 index 00000000..450683e6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/stats_publisher_interface.h @@ -0,0 +1,85 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/build_graph_options.h" +#include "tensorflow/core/common_runtime/profile_handler.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class StatsPublisherInterface; + +typedef std::function( + const std::string&, const BuildGraphOptions&, const SessionOptions&)> + StatsPublisherFactory; + +// StatsPublisherInterface describes objects that publish information exported +// by Sessions. +// NOTE: This interface is experimental and subject to change. +// Implementations must be thread-safe. +class StatsPublisherInterface { + public: + // PublishStatsProto publishes step_stats. + // When PublishStatsProto is called multiple times, only the step_stats + // corresponding to the latest call will be published. 
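// An illustrative sketch, not from the upstream header, of building and owning
// the executor returned by NewSingleThreadedExecutor(). The LocalExecutorParams
// wiring (device, function library runtime, kernel create/delete callbacks) is
// assumed to be prepared by the caller.
#include <memory>
#include "tensorflow/core/common_runtime/single_threaded_executor.h"
#include "tensorflow/core/platform/errors.h"

absl::Status BuildSingleThreadedExecutor(
    const tensorflow::LocalExecutorParams& params,
    const tensorflow::Graph& graph,
    std::unique_ptr<tensorflow::Executor>* out_executor) {
  tensorflow::Executor* executor = nullptr;
  TF_RETURN_IF_ERROR(
      tensorflow::NewSingleThreadedExecutor(params, graph, &executor));
  out_executor->reset(executor);  // the caller owns the returned executor
  return absl::OkStatus();
}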
+ virtual void PublishStatsProto(const StepStats& step_stats) = 0; + + // PublishGraphProto publishes the graph_defs corresponding to each partition + // in the session. + // When PublishGraphProto is called multiple times, only the graph_defs + // corresponding to the latest call will be published. + virtual void PublishGraphProto( + const std::vector& graph_defs) = 0; + virtual void PublishGraphProto(std::vector graph_defs) = 0; + virtual void PublishGraphProto( + std::vector>&& function_records) = 0; + + // Returns a profile handler for the given step based on the execution_count + // and RunOptions. + // + // This method may return a null pointer, if no handler was created. + virtual std::unique_ptr GetProfileHandler( + uint64 step, int64_t execution_count, const RunOptions& ropts) = 0; + + virtual ~StatsPublisherInterface() {} + + static void RegisterStatsPublisher(StatsPublisherFactory factory_fn); + + static StatsPublisherFactory GetStatsPublisherFactory(); + + private: + static StatsPublisherFactory** GetStatsPublisherFactoryPtr() { + static StatsPublisherFactory* stats_publisher_factory = nullptr; + return &stats_publisher_factory; + } +}; + +std::unique_ptr CreateNoOpStatsPublisher( + const string& session, const BuildGraphOptions& bopts, + const SessionOptions& sopts); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STATS_PUBLISHER_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/step_stats_collector.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/step_stats_collector.h new file mode 100644 index 00000000..277630cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/step_stats_collector.h @@ -0,0 +1,208 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class AllocatorMemoryUsed; +class CostModelManager; +class Graph; +class NodeDef; +class NodeExecStats; +class OpKernelContext; +class StepStats; +class StepStatsCollector; +class Tensor; + +// Statistics collection interface for individual node execution. +// +// See `NodeExecStatsWrapper` for a concrete implementation of this interface +// that interfaces with the `Session` layer. +class NodeExecStatsInterface { + public: + virtual ~NodeExecStatsInterface() {} + + // Called when the statistics collection for the node has finished. 
Once this + // method is called, the caller should not make assumptions about the validity + // of this object. + virtual void Done(const string& device) = 0; + + // Called immediately after this node starts being processed by the executor. + virtual void RecordExecutorStarted() = 0; + + // Called immediately before this node's `Compute()` or `ComputeAsync()` + // method is called. + virtual void RecordComputeStarted() = 0; + + // Called immediately after this node's `Compute()` method returned (or, for + // asynchronous operations, the callback passed to its `ComputeAsync()` method + // was called). + virtual void RecordComputeEnded() = 0; + + // Called immediately after this executor finishes processing this node. + virtual void RecordExecutorEnded() = 0; + + // Returns `true` if this object should track memory allocations. + virtual bool TrackAllocations() const = 0; + + // Records information about the memory allocated during the execution of this + // node. + // + // Takes ownership of any `TrackingAllocator` objects stored in `ctx`. + virtual void SetMemory(OpKernelContext* ctx) = 0; + + // Records information about the tensor produced by this node at the given + // output slot. + virtual void SetOutput(int slot, const Tensor* tensor) = 0; + + // Records the absolute time in nanoseconds at which this node became + // runnable (i.e. was scheduled for execution). + virtual void SetScheduled(int64_t nanos) = 0; +}; + +// Wraps NodeExecStats and adds allocation to it. +class NodeExecStatsWrapper : public NodeExecStatsInterface { + public: + // Does not take ownership of `node` or `step_stats_collector`. + NodeExecStatsWrapper(const NodeDef* node, + StepStatsCollector* step_stats_collector); + + // Takes ownership of 'stats' but not `node` or `step_stats_collector`. + NodeExecStatsWrapper(std::unique_ptr stats, + const NodeDef* node, + StepStatsCollector* step_stats_collector); + + // Destructor calls Finalize() to release the TrackingAllocators. + ~NodeExecStatsWrapper() override { Finalize(); } + + void Done(const string& device) override; + void RecordExecutorStarted() override; + void RecordComputeStarted() override; + void RecordComputeEnded() override; + void RecordExecutorEnded() override; + bool TrackAllocations() const override { return true; } + void SetMemory(OpKernelContext* ctx) override; + void SetOutput(int slot, const Tensor* tensor) override; + void SetScheduled(int64_t nanos) override; + + private: + friend class StepStatsCollector; + + NodeExecStats* stats() { return stats_.get(); } + + // Populates stats_ and releases TrackingAllocator. + void Finalize(); + + // Does not take ownership of the `allocator`. + // Takes ownership of `tracking_allocator`. + void AddAllocation(Allocator* allocator, + TrackingAllocator* tracking_allocator); + + absl::InlinedVector, 2UL> + allocations_; + std::unique_ptr stats_; + const NodeDef* const node_; // Not owned. + StepStatsCollector* const step_stats_collector_; // Not owned. +}; + +// Statistics collection interface for step execution. +// +// See `StepStatsCollector` for a concrete implementation of this interface +// that interfaces with the `Session` layer. +class StepStatsCollectorInterface { + public: + virtual ~StepStatsCollectorInterface() {} + + // Creates an instance of `NodeExecStatsInterface` that should be used for + // collecting statistics about individual node execution. 
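// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It shows the call sequence an executor is expected to
// drive on a NodeExecStatsInterface obtained from CreateNodeExecStats();
// RunNodeWithStats, the commented-out RunKernel call, and the device name
// are hypothetical.
void RunNodeWithStats(tensorflow::StepStatsCollectorInterface* collector,
                      const tensorflow::NodeDef* node) {
  tensorflow::NodeExecStatsInterface* stats =
      collector ? collector->CreateNodeExecStats(node) : nullptr;
  if (stats) stats->RecordExecutorStarted();
  if (stats) stats->RecordComputeStarted();
  // RunKernel(node);  // hypothetical: the kernel's Compute() runs here.
  if (stats) stats->RecordComputeEnded();
  if (stats) stats->RecordExecutorEnded();
  // Done() hands the stats back to the collector; the object must not be
  // touched afterwards.
  if (stats) stats->Done("/job:localhost/replica:0/task:0/device:CPU:0");
}
// ---------------------------------------------------------------------------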
+ virtual NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) = 0; + + // Generates a string reporting the currently used memory based + // on ResourceExhausted OOM `err` message. + // `err` message needs to contain device name and allocator name, e.g.: + // "ResourceExhaustedError: OOM when allocating tensor ... + // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc" + virtual string ReportAllocsOnResourceExhausted(absl::string_view err) = 0; +}; + +// StepStatsCollector manages the collection of a StepStats object. +// The StepStats object holds multiple DeviceStats. +// Each DeviceStats object holds multiple NodeExecStats. +class StepStatsCollector : public StepStatsCollectorInterface { + public: + // Does not take ownership of `step_stats`. + explicit StepStatsCollector(StepStats* step_stats); + + // BuildCostModel builds or updates a CostModel managed by cost_model_manager, + // using the currently collected DeviceStats associated with the devices in + // device_map. + void BuildCostModel( + CostModelManager* cost_model_manager, + const std::unordered_map& device_map); + + // Saves node statistics to the DeviceStats object associated with device. + // Should be called before Finalize. + void Save(const string& device, NodeExecStats* node_stats_pb); + void Save(const string& device, NodeExecStatsWrapper* node_stats); + + // Saves thread name. + void SaveThreadName(const string& device, const uint32 thread_id, + const string& thread_name); + + NodeExecStatsInterface* CreateNodeExecStats(const NodeDef* node) override; + string ReportAllocsOnResourceExhausted(absl::string_view err) override; + + // The following 2 Finalize methods populate the StepStats passed + // from the constructor. Calling it more than once won't have any effect. + // User shouldn't call Save() methods after Finalize. + void Finalize(); + // swaps the content of StepStats* from constructor with 'ss'. + void FinalizeAndSwap(StepStats* step_stats); + + private: + // TODO(suharshs): Make this configurable if its not possible to find a value + // that works for all cases. + static constexpr uint64 kMaxCollectedNodes = 1 << 20; + + typedef std::vector> NodeStatsVector; + typedef std::unordered_map ThreadNamesMap; + + void FinalizeInternal() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; + bool finalized_ TF_GUARDED_BY(mu_); + std::unordered_map dev_stats_ TF_GUARDED_BY(mu_); + std::unordered_map thread_names_ TF_GUARDED_BY(mu_); + StepStats* step_stats_ TF_GUARDED_BY(mu_); + uint64 collected_nodes_ TF_GUARDED_BY(mu_) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/test_collective_executor_mgr.h new file mode 100644 index 00000000..0d0b190a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/test_collective_executor_mgr.h @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { + +// Mock objects that can't actually execute a Collective, but satisfy +// general infrastructure expectations within tests that don't require +// full functionality. + +class TestCollectiveExecutor : public CollectiveExecutor { + public: + explicit TestCollectiveExecutor(CollectiveExecutorMgrInterface* cem, + CollectiveRemoteAccess* rma = nullptr) + : CollectiveExecutor(cem), rma_(rma) {} + + void RunClosure(std::function fn) override { fn(); } + + CollectiveRemoteAccess* remote_access() override { return rma_; } + + private: + CollectiveRemoteAccess* rma_; +}; + +class TestParamResolver : public ParamResolverInterface { + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, + const StatusCallback& done) override { + done(errors::Internal("Unimplemented")); + } + + void CompleteGroupAsync(const DeviceAttributes& device, + CollGroupParams* group_params, + CancellationManager* cancel_mgr, + const StatusCallback& done) override { + done(errors::Internal("Unimplemented")); + } + + void CompleteInstanceAsync(const CompleteInstanceRequest* request, + CompleteInstanceResponse* response, + CancellationManager* cancel_mgr, + const StatusCallback& done) override { + done(errors::Internal("Unimplemented")); + } + + absl::Status LookupGroup(int32_t group_key, CollGroupParams* group) override { + return errors::Internal("Unimplemented"); + } + + void StartAbort(const absl::Status& s) override {} +}; + +class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface { + public: + explicit TestCollectiveExecutorMgr(ParamResolverInterface* param_resolver, + CollectiveRemoteAccess* rma) + : param_resolver_(param_resolver), rma_(rma) {} + + TestCollectiveExecutorMgr() : param_resolver_(nullptr), rma_(nullptr) {} + + ~TestCollectiveExecutorMgr() override { + for (auto& iter : table_) { + iter.second->Unref(); + } + } + + CollectiveExecutor* FindOrCreate(int64_t step_id) override { + mutex_lock l(mu_); + CollectiveExecutor* ce = nullptr; + auto iter = table_.find(step_id); + if (iter != table_.end()) { + ce = iter->second; + } else { + ce = new TestCollectiveExecutor(this, rma_); + table_[step_id] = ce; + } + ce->Ref(); + return ce; + } + + void Cleanup(int64_t step_id) override { + mutex_lock l(mu_); + auto iter = table_.find(step_id); + if (iter != table_.end()) { + iter->second->Unref(); + table_.erase(iter); + } + } + + void CleanupAll() override { + mutex_lock l(mu_); + for (auto& iter : table_) { + iter.second->Unref(); + } + table_.clear(); + } + + ParamResolverInterface* GetParamResolver() const override { + return param_resolver_; + } + + DeviceResolverInterface* GetDeviceResolver() const override { + LOG(FATAL); + return nullptr; + } + + NcclCommunicatorInterface* GetNcclCommunicator() const override { + return nullptr; + } + + void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + const StatusCallback& done) override { + done(errors::Internal("unimplemented")); + } + + void 
RefreshStepIdSequenceAsync(int64_t graph_key, + const StatusCallback& done) override { + done(errors::Internal("unimplemented")); + } + + int64_t NextStepId(int64_t graph_key) override { + return CollectiveExecutor::kInvalidId; + } + + void RetireStepId(int64_t graph_key, int64_t step_id) override {} + + protected: + mutex mu_; + gtl::FlatMap table_ TF_GUARDED_BY(mu_); + ParamResolverInterface* param_resolver_; + CollectiveRemoteAccess* rma_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TEST_COLLECTIVE_EXECUTOR_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/threadpool_device.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/threadpool_device.h new file mode 100644 index 00000000..08175ccb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/threadpool_device.h @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_THREADPOOL_DEVICE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_THREADPOOL_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/local_device.h" +#include "tensorflow/core/common_runtime/node_file_writer.h" + +namespace tensorflow { + +// CPU device implementation. 
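// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It shows how a test might drive the
// TestCollectiveExecutorMgr declared above: FindOrCreate() hands back a
// Ref()'d executor, so the caller owes one Unref(), and Cleanup() drops the
// manager's own reference. ExerciseTestCollectiveExecutorMgr is hypothetical.
void ExerciseTestCollectiveExecutorMgr() {
  tensorflow::TestCollectiveExecutorMgr mgr;  // no param resolver / remote access
  tensorflow::CollectiveExecutor* ce = mgr.FindOrCreate(/*step_id=*/1);
  ce->RunClosure([] { /* TestCollectiveExecutor runs the closure inline */ });
  ce->Unref();                 // balance the Ref() taken by FindOrCreate()
  mgr.Cleanup(/*step_id=*/1);  // releases the manager's reference
}
// ---------------------------------------------------------------------------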
+class ThreadPoolDevice : public LocalDevice { + public: + ThreadPoolDevice(const SessionOptions& options, const string& name, + Bytes memory_limit, const DeviceLocality& locality, + Allocator* allocator); + ~ThreadPoolDevice() override; + + Allocator* GetAllocator(AllocatorAttributes attr) override; + Allocator* GetScopedAllocator(AllocatorAttributes attr, + int64_t step_id) override; + ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { + return scoped_allocator_mgr_.get(); + } + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, + const DeviceContext* device_context, + StatusCallback done) override; + + absl::Status Sync() override { return absl::OkStatus(); } + + void Compute(OpKernel* op_kernel, OpKernelContext* context) override; + void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) override; + + private: + void LogInputs(OpKernel* op_kernel, OpKernelContext* context); + void LogOutputs(OpKernel* op_kernel, OpKernelContext* context); + + Allocator* allocator_; // Not owned + std::unique_ptr scoped_allocator_mgr_; + NodeFileWriter* node_file_writer_ = nullptr; // not owned +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_THREADPOOL_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/common_runtime/type_inference.h b/third_party/tflite-hdrs/tensorflow/core/common_runtime/type_inference.h new file mode 100644 index 00000000..fdbf6e27 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/common_runtime/type_inference.h @@ -0,0 +1,57 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TYPE_INFERENCE_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_TYPE_INFERENCE_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Run a very basic type inference on the graph. It simply propagates type +// information along edges, until reaching stability. +// +// The pass is designed to run as a graph diffusion process, refining type +// information until it reaches a fixed point. However, the current +// implementation is a simplification that only ensures that: +// 1. each node is visited at least once +// 2. a successful update of a node's type ID prevents future visits +// 3. each node is visited at most a fixed number of times +// +// If needed, we can drop rule #3 and change rule #2 to consider an update to +// be any deep type change (rather than just the type ID). +// +// The state of the diffusion process is the NodeDef.experimental_full_type +// field, while the diffusion function is the node's corresponding +// OpRegistrationData.fwd_type_fn function. +// +// TODO(mdan): Use a regular union-based algorithm instead? 
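// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. A pass of this kind is normally wired into the runtime
// through the optimization registry; the grouping, phase number, and pass
// name below are illustrative assumptions, not the values TensorFlow
// actually registers its type inference passes with.
namespace tensorflow {
class MyTypePropagationPass : public GraphOptimizationPass {
 public:
  absl::Status Run(const GraphOptimizationPassOptions& options) override {
    // Walk options.graph here, refining NodeDef.experimental_full_type until
    // a fixed point (or the visit limit) is reached.
    return absl::OkStatus();
  }
};
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, /*phase=*/99,
                      MyTypePropagationPass);
}  // namespace tensorflow
// ---------------------------------------------------------------------------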
+class TypeInferencePass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +// A version of TypeInferencePass that prints a warning on error, instead +// of returning error status. This is done because there are a few graphs +// currently in the wild which don't actually type check. +// TODO(mdan): Turn this into an error, once all offenders are clean. +class WeakTypeInferencePass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TYPE_INFERENCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/config/flag_defs.h b/third_party/tflite-hdrs/tensorflow/core/config/flag_defs.h new file mode 100644 index 00000000..d6bc4d95 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/config/flag_defs.h @@ -0,0 +1,80 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_CONFIG_FLAG_DEFS_H_ +#define TENSORFLOW_CORE_CONFIG_FLAG_DEFS_H_ + +#include "tensorflow/core/config/flags.h" + +namespace tensorflow { +namespace flags { + +class Flags { + public: + // Test only flags. See flags_test.cc for example usage. + TF_DECLARE_FLAG(test_only_experiment_1, true, "Test only experiment 1."); + TF_DECLARE_FLAG(test_only_experiment_2, false, "Test only experiment 2."); + + // Declare flags below here. + // LINT.IfChange + TF_DECLARE_FLAG(enable_nested_function_shape_inference, false, + "Allow ops such as tf.cond to invoke the ShapeRefiner on " + "their nested functions."); + TF_DECLARE_FLAG(enable_quantized_dtypes_training, false, + "Set quantized dtypes, like tf.qint8, to be trainable."); + TF_DECLARE_FLAG(graph_building_optimization, false, + "Optimize graph building for faster tf.function tracing."); + TF_DECLARE_FLAG( + op_building_optimization, true, + "Optimize tf.Operation building for faster tf.function tracing."); + TF_DECLARE_FLAG(saved_model_fingerprinting, true, + "Add fingerprint to SavedModels."); + TF_DECLARE_FLAG( + tf_shape_default_int64, false, + "The default output of tf.shape (i.e. when out_type is not specified) is " + "int64 when this flag is true and int32 otherwise. Setting this to true " + "is an unsupported, experimental setting that causes known breakages."); + TF_DECLARE_FLAG(more_stack_traces, false, + "Enable experimental code that preserves and propagates " + "graph node stack traces in C++."); + TF_DECLARE_FLAG(publish_function_graphs, true, + "Enables the publication of partitioned function graphs " + "via StatsPublisherInterface. 
Disabling this flag can " + "reduce memory consumption."); + TF_DECLARE_FLAG(enable_aggressive_constant_replication, true, + "Replicate constants across CPU devices and even for local " + "CPUs within the same task if available.") + TF_DECLARE_FLAG(enable_colocation_key_propagation_in_while_op_lowering, false, + "If true, colocation key attributes for the ops will be " + "propagated during while op lowering to switch/merge ops.") + TF_DECLARE_FLAG(enable_tf2min_ici_weight, false, + "If true, ici weight optimization will be used in tf2/min.") + // TODO(b/341325107): Make this behavior the default and remove the flag. + TF_DECLARE_FLAG(enable_function_pruning_before_inlining, false, + "If true, functions will be pruned before inlining.") + TF_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs, false, + "If true, TF2XLA encapsulation will be skipped for non-TPU " + "graphs.") + TF_DECLARE_FLAG(enable_graph_debug_info_caching_for_stack_frames, true, + "If true, graph debug info will cache the stack frames.") + // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc) +}; + +Flags& Global(); + +} // namespace flags +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_CONFIG_FLAG_DEFS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/config/flags.h b/third_party/tflite-hdrs/tensorflow/core/config/flags.h new file mode 100644 index 00000000..c882cd39 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/config/flags.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_CONFIG_FLAGS_H_ +#define TENSORFLOW_CORE_CONFIG_FLAGS_H_ + +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace config { + +// Container class for a single feature flag. +// Note: this class is not thread safe. +class Flag { + public: + explicit Flag(absl::string_view flag_name, bool default_value); + bool value() { return value_; } + void reset(bool value) { value_ = value; } + + private: + bool value_; +}; + +// Macro to declare new flags. Declare all flags in core/config/flag_defs.h +// These flags can be overridden by setting the associated environment variable +// TF_FLAG_* flag to true or false. E.g. setting TF_FLAG_MY_FLAG=false will +// override the default value for a flag named `my_flag` to false. +#define TF_DECLARE_FLAG(flag_name, default_value, doc) \ + ::tensorflow::config::Flag flag_name = \ + ::tensorflow::config::Flag("TF_FLAG_" #flag_name, default_value); + +} // namespace config +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_CONFIG_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/captured_function.h b/third_party/tflite-hdrs/tensorflow/core/data/captured_function.h new file mode 100644 index 00000000..553f09b5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/captured_function.h @@ -0,0 +1,340 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ +#define TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class Device; +class OpKernelContext; +class ResourceMgr; + +namespace data { + +class CapturedFunction; +class InstantiatedCapturedFunction; + +// Creates an iterator for a dataset which is created by applying the given +// function to the given input element. +absl::Status MakeIteratorFromInputElement( + IteratorContext* ctx, const DatasetBaseIterator* parent, + const std::vector& input_element, int64_t thread_index, + const InstantiatedCapturedFunction& inst_captured_func, + absl::string_view prefix, std::unique_ptr* out_iterator); + +// Creates an iterator for a dataset which is created by applying the given +// function to the given input element. Pass non-null `node` to record +// processing time for modeling Iterator's GetNext() resource usage. +absl::Status MakeIteratorFromInputElement( + IteratorContext* ctx, const DatasetBaseIterator* parent, + const std::vector& input_element, int64_t thread_index, + const InstantiatedCapturedFunction& inst_captured_func, + absl::string_view prefix, std::unique_ptr* out_iterator, + const std::shared_ptr& node); + +struct ShortCircuitInfo { + std::vector indices; + std::vector can_move; +}; + +// Metadata shared across all captures of the same function. +class FunctionMetadata { + public: + struct Params { + bool use_inter_op_parallelism = true; + bool use_default_device = true; + }; + + // Creates a new instance of the `FunctionMetadata` class, fetching function + // from a context argument. + static absl::Status Create(tensorflow::OpKernelConstruction* ctx, + const string& func_name, Params params, + std::shared_ptr* out_metadata); + + // Creates a new instance of the `FunctionMetadata` class, using the provided + // function. + static absl::Status Create(tensorflow::OpKernelConstruction* ctx, + NameAttrList&& func, Params params, + std::shared_ptr* out_metadata); + + // Returns the named list of function arguments. + const NameAttrList& func() const { return func_; } + + // Returns a borrowed pointer to the function library that contains the + // transitive closure of definitions used by the function. 
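// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It shows the usual way a dataset kernel builds the shared
// FunctionMetadata from an op attr in its constructor; BuildFunctionMetadata
// and the attr name "f" are hypothetical.
absl::Status BuildFunctionMetadata(
    tensorflow::OpKernelConstruction* ctx,
    std::shared_ptr<tensorflow::data::FunctionMetadata>* out_metadata) {
  tensorflow::data::FunctionMetadata::Params params;
  params.use_inter_op_parallelism = true;  // default, shown explicitly
  // "f" names the NameAttrList attr that holds the user-defined function.
  return tensorflow::data::FunctionMetadata::Create(ctx, /*func_name=*/"f",
                                                    params, out_metadata);
}
// ---------------------------------------------------------------------------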
+ const FunctionLibraryDefinition* lib_def() const { return lib_def_.get(); } + + // Returns short-circuit information. + const ShortCircuitInfo& short_circuit_info() const { + return short_circuit_info_; + } + + // Indicates whether a default device should be used for executing function + // ops. + bool use_default_device() const { return use_default_device_; } + + // Indicates whether to use inter-op parallelism for execution of the + // function. + bool use_inter_op_parallelism() const { return use_inter_op_parallelism_; } + + // Indicates whether the function should a multi-device function backend. + bool use_multi_device_function() const { return use_multi_device_function_; } + + private: + FunctionMetadata(NameAttrList&& func, Params params) + : func_(std::move(func)), + use_default_device_(params.use_default_device), + use_inter_op_parallelism_(params.use_inter_op_parallelism) {} + + NameAttrList func_; + std::unique_ptr lib_def_ = nullptr; + ShortCircuitInfo short_circuit_info_; + bool use_default_device_ = true; + bool use_inter_op_parallelism_ = true; + bool use_multi_device_function_ = true; +}; + +// Constructs and stores the parameters for the CapturedFunction Instantiate +// function. +struct InstantiateCapturedFunctionParams { + explicit InstantiateCapturedFunctionParams(IteratorContext* ctx) { + flr = ctx->flr(); + function_handle_cache = ctx->function_handle_cache(); + runner = ctx->runner(); + } + + explicit InstantiateCapturedFunctionParams(OpKernelContext* ctx) { + flr = ctx->function_library(); + function_handle_cache = nullptr; + runner = ctx->runner(); + } + + FunctionLibraryRuntime* flr; + FunctionHandleCache* function_handle_cache; + std::function)>* runner; +}; + +// A `CapturedFunction` encapsulates a TensorFlow function, plus any "captured" +// arguments that it closed over in the user program. +class CapturedFunction { + public: + // Creates a new instance using a list of named attributes, fetching captured + // inputs from a context argument. + static absl::Status Create(OpKernelContext* ctx, + std::shared_ptr metadata, + const string& argument_name, + std::unique_ptr* out_function); + + // Creates a new instance using a list of named attributes, using provided + // captured inputs. + static absl::Status Create(OpKernelContext* ctx, + std::shared_ptr metadata, + std::vector&& captured_inputs, + std::unique_ptr* out_function); + + // Adds the definition of this captured function into the given graph, + // returning its captured inputs and types through the respective output + // arguments. + absl::Status AddToGraph(SerializationContext* ctx, + DatasetBase::DatasetGraphDefBuilder* b, + std::vector* other_arguments, + DataTypeVector* other_arguments_types) const; + + // Instantiates this function for use in the given context, providing an + // InstantiatedCapturedFunction that can be used to execute functions. + absl::Status Instantiate(IteratorContext* ctx, + std::unique_ptr* + instantiated_captured_function); + + absl::Status Instantiate(InstantiateCapturedFunctionParams params, + std::unique_ptr* + instantiated_captured_function); + + // Determines whether the captured function is stateful. + absl::Status CheckExternalState() const; + + // Returns the additional captured inputs that will be passed to the function. + const std::vector& captured_inputs() const { + return captured_inputs_; + } + + // Returns the named list of function arguments. 
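// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It sketches the capture-then-instantiate flow described
// above; CaptureAndInstantiate and the input name "other_arguments" are
// hypothetical, and the usual TF_RETURN_IF_ERROR macro is assumed.
absl::Status CaptureAndInstantiate(
    tensorflow::OpKernelContext* ctx,
    std::shared_ptr<tensorflow::data::FunctionMetadata> metadata,
    tensorflow::data::IteratorContext* iter_ctx) {
  std::unique_ptr<tensorflow::data::CapturedFunction> captured;
  TF_RETURN_IF_ERROR(tensorflow::data::CapturedFunction::Create(
      ctx, std::move(metadata), /*argument_name=*/"other_arguments",
      &captured));
  std::unique_ptr<tensorflow::data::InstantiatedCapturedFunction> instantiated;
  TF_RETURN_IF_ERROR(captured->Instantiate(iter_ctx, &instantiated));
  // The InstantiatedCapturedFunction only borrows the CapturedFunction, so a
  // real iterator keeps both alive as members for as long as it runs them.
  return absl::OkStatus();
}
// ---------------------------------------------------------------------------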
+ const NameAttrList& func() const { return metadata_->func(); } + + // Returns the transitive set of function definition required to instantiate + // this function. + const FunctionLibraryDefinition* lib_def() const { + return metadata_->lib_def(); + } + + // If every function output corresponds to one of its inputs, the method + // returns the mapping from output indices to input indices. Otherwise, it + // returns an empty list. + const ShortCircuitInfo& short_circuit_info() const { + return metadata_->short_circuit_info(); + } + + // Indicates whether the function should use inter op parallelism. + bool use_inter_op_parallelism() const { + return metadata_->use_inter_op_parallelism(); + } + + private: + CapturedFunction(std::shared_ptr metadata, + std::vector captured_inputs); + + absl::Status IsMultiDevice(FunctionLibraryRuntime* flr, + bool* is_multi_device) const; + + const std::shared_ptr metadata_; + const std::vector captured_inputs_; + + CapturedFunction(const CapturedFunction&) = delete; + void operator=(const CapturedFunction&) = delete; +}; + +// `InstantiatedCapturedFunction` encapsulates all the runtime support needed +// to execute a tensorflow function. +// +// While `CapturedFunction` encapsulates constant attributes of the function, +// such as its name and captured arguments, `InstantiatedCapturedFunction` +// encapsulates runtime aspects, such as `FunctionLibraryRuntime` and function +// handle. +// +// The `Iterator` related classes use `InstantiatedCapturedFunction` to execute +// functions outside of the normal `OpKernel::Compute()` context. +class InstantiatedCapturedFunction { + public: + // Runs the instantiated captured function. This method takes ownership of + // the tensors in `args`, in order to be able to deallocate them as early as + // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain + // ownership of the `args`. + absl::Status Run(IteratorContext* ctx, std::vector&& args, + std::vector* rets) const; + + // Runs the instantiated captured function. This method takes ownership of + // the tensors in `args`, in order to be able to deallocate them as early as + // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain + // ownership of the `args`. Pass non-null `node` to record processing time + // for modeling Iterator's GetNext() resource usage. When non-null node is + // provided, the pre-requisite is that the calling thread has previously + // called `DatasetBaseIterator::RecordStart(). + absl::Status Run(IteratorContext* ctx, std::vector&& args, + std::vector* rets, + const std::shared_ptr& node) const; + + // Synchronously runs the captured function on the given `args`, and stores + // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when + // possible. + absl::Status RunWithBorrowedArgs(IteratorContext* ctx, + const std::vector& args, + std::vector* rets) const; + + // Synchronously runs the captured function on the given `args`, and stores + // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when + // possible. Pass non-null `node` to record processing time for modeling + // Iterator's GetNext() resource usage. When non-null node is provided, the + // pre-requisite is that the calling thread has previously called + // `DatasetBaseIterator::RecordStart(). 
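// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It contrasts the two ownership contracts documented above:
// Run() consumes the argument tensors, RunWithBorrowedArgs() leaves them with
// the caller. CallBoth is hypothetical; TF_RETURN_IF_ERROR is assumed.
absl::Status CallBoth(
    const tensorflow::data::InstantiatedCapturedFunction& fn,
    tensorflow::data::IteratorContext* ctx,
    std::vector<tensorflow::Tensor> args) {
  std::vector<tensorflow::Tensor> rets;
  // Borrowing: `args` remains valid and reusable after the call.
  TF_RETURN_IF_ERROR(fn.RunWithBorrowedArgs(ctx, args, &rets));
  rets.clear();
  // Consuming: `args` is moved from and must not be used again.
  TF_RETURN_IF_ERROR(fn.Run(ctx, std::move(args), &rets));
  return absl::OkStatus();
}
// ---------------------------------------------------------------------------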
+ absl::Status RunWithBorrowedArgs( + IteratorContext* ctx, const std::vector& args, + std::vector* rets, + const std::shared_ptr& node) const; + + // Synchronously runs the captured function on the given `args`, and stores + // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when + // possible. This can be useful for calling a captured function in cases where + // an `IteratorContext*` is not available (such as a destructor). + // + // TODO(b/144278100): Avoid running functions without IteratorContext. + absl::Status RunInstantiated(const std::vector& args, + std::vector* rets); + + // Asynchronously runs the captured function on the given `args`, stores the + // results in `*rets`, and calls the given `done` callback when the function + // returns. This method takes ownership of the tensors in `args`, in order to + // be able to deallocate them as early as possible. Pass non-null `node` to + // record processing time for modeling Iterator's GetNext() resource usage. + // When non-null node is provided, the pre-requisite is that the calling + // thread has previously called `DatasetBaseIterator::RecordStart(). + void RunAsync(IteratorContext* ctx, std::vector&& args, + std::vector* rets, + FunctionLibraryRuntime::DoneCallback done, + const std::shared_ptr& node) const { + RunAsync(*(ctx->runner()), ctx->cancellation_manager(), + ctx->collective_executor(), std::move(args), rets, done, node); + } + + // A version of `RunAsync` that does not take an `IteratorContext` but a + // runner, a cancellation manager, and a collective executor. + void RunAsync(std::function)> runner, + CancellationManager* parent_cancellation_manager, + CollectiveExecutor* collective_executor, + std::vector&& args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done, + const std::shared_ptr& node) const; + + std::string func_name() const { return captured_func_->func().name(); } + + private: + friend class CapturedFunction; + + InstantiatedCapturedFunction( + FunctionLibraryRuntime* lib, FunctionLibraryRuntime::Handle f_handle, + DataTypeVector ret_types, + std::function)> runner, + CapturedFunction* captured_func, bool is_multi_device); + + // Determines whether a rendezvous object should be created when running the + // instantiated function. + bool ShouldCreateRendezvous() const; + + FunctionLibraryRuntime* const lib_; // Not owned. + const FunctionLibraryRuntime::Handle f_handle_; + const DataTypeVector ret_types_; + // Note: We capture the runner at function instantiation time to be able to + // run the function without `IteratorContext` via `RunInstantiated`. + std::function)> captured_runner_; + CapturedFunction* const captured_func_; // Not owned. + const bool is_multi_device_; + + InstantiatedCapturedFunction(const InstantiatedCapturedFunction&) = delete; + void operator=(const InstantiatedCapturedFunction&) = delete; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_CAPTURED_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/compression_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/compression_utils.h new file mode 100644 index 00000000..8b4d5179 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/compression_utils.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_COMPRESSION_UTILS_H_ +#define TENSORFLOW_CORE_DATA_COMPRESSION_UTILS_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/dataset.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { + +// Compresses the components of `element` into the `CompressedElement` proto. +// +// In addition to writing the actual compressed bytes, `Compress` fills +// out the per-component metadata for the `CompressedElement`. +// +// Returns an error if the uncompressed size of the element exceeds 4GB. +absl::Status CompressElement(const std::vector& element, + CompressedElement* out); + +// Uncompresses a `CompressedElement` into a vector of tensor components. +absl::Status UncompressElement(const CompressedElement& compressed, + std::vector* out); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_COMPRESSION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/dataset_test_base.h b/third_party/tflite-hdrs/tensorflow/core/data/dataset_test_base.h new file mode 100644 index 00000000..0ef63825 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/dataset_test_base.h @@ -0,0 +1,1128 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_ +#define TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +namespace tensorflow { +namespace data { + +typedef std::vector< + std::pair> + AttributeVector; + +constexpr int kDefaultCPUNum = 2; +constexpr int kDefaultThreadNum = 2; + +// Creates a tensor with the specified dtype, shape, and value. +template +static Tensor CreateTensor(const TensorShape& input_shape, + gtl::ArraySlice input_data) { + Tensor tensor(DataTypeToEnum::value, input_shape); + test::FillValues(&tensor, input_data); + return tensor; +} + +// Creates a tensor with the specified dtype and shape, with values 0, 1, 2, ... +template +static Tensor CreateTensor(const TensorShape& input_shape) { + Tensor tensor(DataTypeToEnum::value, input_shape); + test::FillIota(&tensor, 0); + return tensor; +} + +// Creates a vector of tensors with the specified dtype, shape, and values. +template +std::vector CreateTensors( + const TensorShape& shape, const std::vector>& values) { + std::vector result; + result.reserve(values.size()); + for (auto& value : values) { + result.emplace_back(CreateTensor(shape, value)); + } + return result; +} + +enum class CompressionType { ZLIB = 0, GZIP = 1, RAW = 2, UNCOMPRESSED = 3 }; + +// Returns a string representation for the given compression type. +string ToString(CompressionType compression_type); + +// Gets the specified zlib compression options according to the compression +// type. Note that `CompressionType::UNCOMPRESSED` is not supported because +// `ZlibCompressionOptions` does not have an option. 
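// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It shows what the CreateTensor/CreateTensors helpers above
// are typically used for when building expected outputs in dataset tests;
// MakeExpectedOutputs is hypothetical.
std::vector<tensorflow::Tensor> MakeExpectedOutputs() {
  using tensorflow::data::CreateTensor;
  using tensorflow::data::CreateTensors;
  // Three scalar string tensors, one per expected dataset element.
  std::vector<tensorflow::Tensor> expected =
      CreateTensors<tensorflow::tstring>(tensorflow::TensorShape({}),
                                         {{"a"}, {"b"}, {"c"}});
  // One 2x2 int64 tensor holding the values 0..3.
  expected.push_back(
      CreateTensor<int64_t>(tensorflow::TensorShape({2, 2}), {0, 1, 2, 3}));
  return expected;
}
// ---------------------------------------------------------------------------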
+io::ZlibCompressionOptions GetZlibCompressionOptions( + CompressionType compression_type); + +// Used to specify parameters when writing data into files with compression. +// `input_buffer_size` and `output_buffer_size` specify the input and output +// buffer size when ZLIB and GZIP compression is used. +struct CompressionParams { + CompressionType compression_type = CompressionType::UNCOMPRESSED; + int32 input_buffer_size = 0; + int32 output_buffer_size = 0; +}; + +// Writes the input data into the file without compression. +absl::Status WriteDataToFile(const string& filename, const char* data); + +// Writes the input data into the file with the specified compression. +absl::Status WriteDataToFile(const string& filename, const char* data, + const CompressionParams& params); + +// Writes the input data into the TFRecord file with the specified compression. +absl::Status WriteDataToTFRecordFile( + const string& filename, const std::vector& records, + const CompressionParams& params); + +// Provides the parameters for running the dataset op. +class DatasetParams { + public: + DatasetParams(DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name); + + virtual ~DatasetParams() = default; + + // Returns the inputs (except the input datasets) as a tensor vector. + virtual std::vector GetInputTensors() const = 0; + + // Returns the dataset input names as a string vector. + virtual absl::Status GetInputNames( + std::vector* input_names) const = 0; + + // Returns the dataset attributes as a vector. + virtual absl::Status GetAttributes(AttributeVector* attributes) const = 0; + + // Checks if the tensor is a dataset variant tensor. + static bool IsDatasetTensor(const Tensor& tensor); + + string node_name() const { return node_name_; } + + DataTypeVector output_dtypes() const { return output_dtypes_; } + + std::vector output_shapes() const { + return output_shapes_; + } + + string iterator_prefix() const { return iterator_prefix_; } + + const std::vector>& input_dataset_params() + const { + return input_dataset_params_; + } + + // Returns the functions that will be used when running the dataset op. + virtual std::vector func_lib() const { return {}; } + + // Returns the dataset type for the op represented by these parameters. This + // type usually needs to match the constant called `kDatasetType` defined in + // the dataset kernel. + virtual string dataset_type() const = 0; + + // Returns the dataset op name. By default, it returns the Op::kDatasetType + // concatenated with "Dataset". For ops that do not have "Dataset" suffix, + // this method can be overriden to return a different name. + virtual string op_name() const { + name_utils::OpNameParams params; + params.op_version = op_version(); + return name_utils::OpName(dataset_type(), params); + } + + virtual int op_version() const { return op_version_; } + + protected: + std::vector> input_dataset_params_; + DataTypeVector output_dtypes_; + std::vector output_shapes_; + string node_name_; + string iterator_prefix_ = "Iterator"; + int op_version_ = 1; +}; + +// `RangeDatasetParams` is a common dataset parameter type that are used in +// testing. 
+class RangeDatasetParams : public DatasetParams { + public: + RangeDatasetParams(int64_t start, int64_t stop, int64_t step, + DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name); + + RangeDatasetParams(int64_t start, int64_t stop, int64_t step); + + RangeDatasetParams(int64_t start, int64_t stop, int64_t step, + DataTypeVector output_dtypes); + + std::vector GetInputTensors() const override; + + absl::Status GetInputNames(std::vector* input_names) const override; + + absl::Status GetAttributes(AttributeVector* attr_vector) const override; + + string dataset_type() const override; + + private: + int64_t start_; + int64_t stop_; + int64_t step_; +}; + +// `BatchDatasetParams` is a common dataset parameter type that are used in +// testing. +class BatchDatasetParams : public DatasetParams { + public: + template + BatchDatasetParams(T input_dataset_params, int64_t batch_size, + bool drop_remainder, bool parallel_copy, + DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)), + batch_size_(batch_size), + drop_remainder_(drop_remainder), + parallel_copy_(parallel_copy) { + input_dataset_params_.push_back(std::make_unique(input_dataset_params)); + op_version_ = 2; + iterator_prefix_ = + name_utils::IteratorPrefix(input_dataset_params.dataset_type(), + input_dataset_params.iterator_prefix()); + } + + std::vector GetInputTensors() const override; + + absl::Status GetInputNames(std::vector* input_names) const override; + + absl::Status GetAttributes(AttributeVector* attr_vector) const override; + + string dataset_type() const override; + + private: + int64_t batch_size_; + bool drop_remainder_; + bool parallel_copy_; +}; + +// `MapDatasetParams` is a common dataset parameter type that are used in +// testing. +class MapDatasetParams : public DatasetParams { + public: + template + MapDatasetParams(T input_dataset_params, std::vector other_arguments, + FunctionDefHelper::AttrValueWrapper func, + std::vector func_lib, + DataTypeVector type_arguments, DataTypeVector output_dtypes, + std::vector output_shapes, + bool use_inter_op_parallelism, bool preserve_cardinality, + string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)), + other_arguments_(std::move(other_arguments)), + func_(std::move(func)), + func_lib_(std::move(func_lib)), + type_arguments_(std::move(type_arguments)), + use_inter_op_parallelism_(use_inter_op_parallelism), + preserve_cardinality_(preserve_cardinality) { + input_dataset_params_.push_back(std::make_unique(input_dataset_params)); + iterator_prefix_ = + name_utils::IteratorPrefix(input_dataset_params.dataset_type(), + input_dataset_params.iterator_prefix()); + } + + std::vector GetInputTensors() const override; + + absl::Status GetInputNames(std::vector* input_names) const override; + + absl::Status GetAttributes(AttributeVector* attr_vector) const override; + + string dataset_type() const override; + + std::vector func_lib() const override; + + private: + std::vector other_arguments_; + FunctionDefHelper::AttrValueWrapper func_; + std::vector func_lib_; + DataTypeVector type_arguments_; + bool use_inter_op_parallelism_; + bool preserve_cardinality_; +}; + +// `TensorSliceDatasetParams` is a common dataset parameter type that are used +// in testing. 
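// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It shows how the parameter helpers above compose, feeding
// a Range dataset into a Batch dataset the way the dataset op tests do;
// MakeBatchOfRangeParams and the node name are hypothetical.
tensorflow::data::BatchDatasetParams MakeBatchOfRangeParams() {
  return tensorflow::data::BatchDatasetParams(
      tensorflow::data::RangeDatasetParams(/*start=*/0, /*stop=*/10,
                                           /*step=*/1),
      /*batch_size=*/3, /*drop_remainder=*/true, /*parallel_copy=*/false,
      /*output_dtypes=*/{tensorflow::DT_INT64},
      /*output_shapes=*/{tensorflow::PartialTensorShape({3})},
      /*node_name=*/"batch_dataset");
}
// ---------------------------------------------------------------------------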
+class TensorSliceDatasetParams : public DatasetParams { + public: + TensorSliceDatasetParams(std::vector components, string node_name, + bool is_files = false); + + std::vector GetInputTensors() const override; + + absl::Status GetInputNames(std::vector* input_names) const override; + + absl::Status GetAttributes(AttributeVector* attr_vector) const override; + + string dataset_type() const override; + + int64_t num_slices() const { return components_[0].dim_size(0); } + + size_t num_tensors_per_slice() const { return components_.size(); } + + private: + DataTypeVector TensorSliceDtypes(const std::vector& input_components); + + std::vector TensorSliceShapes( + const std::vector& input_components); + + public: + std::vector components_; + bool is_files_; +}; + +// `TakeDatasetParams` is a common dataset parameter type that are used in +// testing. +class TakeDatasetParams : public DatasetParams { + public: + template + TakeDatasetParams(T input_dataset_params, int count, + DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)), + count_(count) { + input_dataset_params_.push_back(std::make_unique(input_dataset_params)); + iterator_prefix_ = + name_utils::IteratorPrefix(input_dataset_params.dataset_type(), + input_dataset_params.iterator_prefix()); + } + + std::vector GetInputTensors() const override; + + absl::Status GetInputNames(std::vector* input_names) const override; + + absl::Status GetAttributes(AttributeVector* attr_vector) const override; + + string dataset_type() const override; + + private: + int64_t count_; +}; + +// `ConcatenateDatasetParams` is a common dataset parameter type that are used +// in testing. +class ConcatenateDatasetParams : public DatasetParams { + public: + template + ConcatenateDatasetParams(T input_dataset_params_0, P input_dataset_params_1, + DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)) { + input_dataset_params_.push_back( + std::make_unique(input_dataset_params_0)); + input_dataset_params_.push_back( + std::make_unique(input_dataset_params_1)); + iterator_prefix_ = + name_utils::IteratorPrefix(input_dataset_params_0.dataset_type(), + input_dataset_params_0.iterator_prefix()); + } + + std::vector GetInputTensors() const override; + + absl::Status GetInputNames(std::vector* input_names) const override; + + absl::Status GetAttributes(AttributeVector* attr_vector) const override; + + string dataset_type() const override; +}; + +// `OptionsDatasetParams` is a common dataset parameter type that is used in +// testing. 
+class OptionsDatasetParams : public DatasetParams { + public: + template + OptionsDatasetParams(T input_dataset_params, const string& serialized_options, + DataTypeVector output_dtypes, + std::vector output_shapes, + string node_name) + : DatasetParams(std::move(output_dtypes), std::move(output_shapes), + std::move(node_name)), + serialized_options_(serialized_options) { + input_dataset_params_.push_back(std::make_unique(input_dataset_params)); + } + + std::vector GetInputTensors() const override; + + absl::Status GetInputNames(std::vector* input_names) const override; + + absl::Status GetAttributes(AttributeVector* attr_vector) const override; + + string dataset_type() const override; + + private: + string serialized_options_; +}; + +template +struct GetNextTestCase { + GetNextTestCase(T dataset_params, std::vector expected_outputs, + bool compare_order = true) + : dataset_params(std::move(dataset_params)), + expected_outputs(std::move(expected_outputs)), + compare_order(compare_order) {} + + T dataset_params; + std::vector expected_outputs; + bool compare_order; +}; + +template +struct SkipTestCase { + SkipTestCase(T dataset_params, int num_to_skip, int expected_num_skipped, + bool get_next = false, std::vector expected_outputs = {}, + bool compare_order = true) + : dataset_params(std::move(dataset_params)), + num_to_skip(num_to_skip), + expected_num_skipped(expected_num_skipped), + get_next(get_next), + expected_outputs(std::move(expected_outputs)), + compare_order(compare_order) {} + + T dataset_params; + int num_to_skip; + int expected_num_skipped; + bool get_next; + std::vector expected_outputs; + bool compare_order; +}; + +template +struct DatasetNodeNameTestCase { + T dataset_params; + string expected_node_name; +}; + +template +struct DatasetTypeStringTestCase { + T dataset_params; + string expected_dataset_type_string; +}; + +template +struct DatasetOutputDtypesTestCase { + T dataset_params; + DataTypeVector expected_output_dtypes; +}; + +template +struct DatasetOutputShapesTestCase { + T dataset_params; + std::vector expected_output_shapes; +}; + +template +struct CardinalityTestCase { + T dataset_params; + int64_t expected_cardinality; +}; + +template +struct DatasetSaveTestCase { + T dataset_params; +}; + +template +struct IteratorOutputDtypesTestCase { + T dataset_params; + DataTypeVector expected_output_dtypes; +}; + +template +struct IteratorOutputShapesTestCase { + T dataset_params; + std::vector expected_output_shapes; +}; + +template +struct IteratorPrefixTestCase { + T dataset_params; + string expected_iterator_prefix; +}; + +template +struct IteratorSaveAndRestoreTestCase { + IteratorSaveAndRestoreTestCase(T dataset_params, std::vector breakpoints, + std::vector expected_outputs, + bool compare_order = true) + : dataset_params(std::move(dataset_params)), + breakpoints(std::move(breakpoints)), + expected_outputs(std::move(expected_outputs)), + compare_order(compare_order) {} + + T dataset_params; + std::vector breakpoints; + std::vector expected_outputs; + bool compare_order; +}; + +// Class composing a dataset with its dependencies. +class TestDataset { + public: + // TestDataset expects that the caller has Ref'd the wrapped dataset. When + // TestDataset is destroyed, it will Unref the dataset. 
+ TestDataset(std::unique_ptr kernel_, + std::unique_ptr ctx_params, + std::unique_ptr ctx, + std::vector> input_tensors, + DatasetBase* dataset) + : kernel_(std::move(kernel_)), + ctx_params_(std::move(ctx_params)), + ctx_(std::move(ctx)), + input_tensors_(std::move(input_tensors)), + dataset_(dataset), + scoped_unref_(dataset) {} + + DatasetBase* dataset() const { return dataset_; } + + OpKernelContext* op_kernel_context() const { return ctx_.get(); } + + protected: + std::unique_ptr kernel_; + std::unique_ptr ctx_params_; + std::unique_ptr ctx_; + // The input tensors that this dataset depends on. They must outlive the + // dataset. + std::vector> input_tensors_; + DatasetBase* dataset_; + core::ScopedUnref scoped_unref_; +}; + +// Class composing a dataset iterator with its dependencies. +class TestIterator { + public: + TestIterator(std::unique_ptr ctx, + std::unique_ptr iterator) + : iterator_(std::move(iterator)), ctx_(std::move(ctx)) {} + + IteratorBase* iterator() const { return iterator_.get(); } + + IteratorContext* ctx() const { return ctx_.get(); } + + absl::Status GetNext(std::vector* out_tensors, + bool* end_of_sequence) { + return iterator_->GetNext(ctx(), out_tensors, end_of_sequence); + } + + protected: + std::unique_ptr iterator_; + std::unique_ptr ctx_; +}; + +// Helpful functions to test Dataset op kernels. +class DatasetOpsTestBase : public ::testing::Test { + public: + DatasetOpsTestBase(); + + // Initializes the runtime and creates a dataset and iterator. + absl::Status Initialize(const DatasetParams& dataset_params); + + // Initializes the parts of the runtime needed to run dataset ops. + absl::Status InitializeRuntime(const DatasetParams& dataset_params); + + // Creates a dataset. + absl::Status MakeDataset(const DatasetParams& dataset_params, + std::unique_ptr* dataset); + + // Creates an iterator for the given dataset, using the specified split + // providers. + absl::Status MakeIterator( + const DatasetParams& dataset_params, const TestDataset& dataset, + std::vector> split_providers, + std::unique_ptr* iterator); + // Creates an iterator for the given dataset. + absl::Status MakeIterator(const DatasetParams& dataset_params, + const TestDataset& dataset, + std::unique_ptr* iterator); + + // Runs the dataset operation according to the predefined dataset params and + // produces outputs. Different from `MakeDataset()` which returns a Dataset + // object, `RunDatasetOp()` executes the dataset kernel based on the input + // DatasetParams and returns the produced outputs as a tensor vector. It can + // be used to run some dataset operations that do not have an internal + // customized `Dataset` class (e.g. `ReduceDatasetOp`). + absl::Status RunDatasetOp(const DatasetParams& dataset_params, + std::vector* outputs); + + // The method validates whether the two tensors have the same shape, dtype, + // and value. + static absl::Status ExpectEqual(const Tensor& a, const Tensor& b); + + // The method validates whether the two tensor vectors have the same tensors. + // If `compare_order` is false, the method will only evaluate whether the two + // vectors have the same elements regardless of order. + static absl::Status ExpectEqual(std::vector produced_tensors, + std::vector expected_tensors, + bool compare_order); + + // Checks `IteratorBase::GetNext()`. + absl::Status CheckIteratorGetNext(const std::vector& expected_outputs, + bool compare_order); + + // Checks `IteratorBase::GetNext()`. 
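// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch only; not part of the vendored header
// or of this diff. It shows the typical shape of a test built on
// DatasetOpsTestBase: initialize from params, then compare the iterator's
// output against expected tensors. The fixture name is hypothetical, and
// gtest plus the usual TF_ASSERT_OK macro are assumed.
class RangeLikeDatasetOpTest : public tensorflow::data::DatasetOpsTestBase {};

TEST_F(RangeLikeDatasetOpTest, GetNext) {
  tensorflow::data::RangeDatasetParams params(/*start=*/0, /*stop=*/4,
                                              /*step=*/1);
  TF_ASSERT_OK(Initialize(params));
  TF_ASSERT_OK(CheckIteratorGetNext(
      tensorflow::data::CreateTensors<int64_t>(tensorflow::TensorShape({}),
                                               {{0}, {1}, {2}, {3}}),
      /*compare_order=*/true));
}
// ---------------------------------------------------------------------------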
+ absl::Status CheckIteratorGetNext(TestIterator* iterator, + const std::vector& expected_outputs, + bool compare_order); + + // Checks `IteratorBase::GetNext()`. + absl::Status CheckIteratorGetNext(IteratorBase* iterator, + IteratorContext* ctx, + const std::vector& expected_outputs, + bool compare_order); + + // Checks `IteratorBase::Skip()` + absl::Status CheckIteratorSkip(int num_to_skip, int expected_num_skipped, + bool get_next, + const std::vector& expected_outputs, + bool compare_order); + + // Checks that iterating through the dataset using a split provider produces + // the expected outputs. + absl::Status CheckSplitProviderFullIteration( + const DatasetParams& params, const std::vector& expected_outputs); + + // Checks that iterating through the dataset using a sharded split provider + // with the given `num_shards` and `shard_index` produces the expected + // outputs. + absl::Status CheckSplitProviderShardedIteration( + const DatasetParams& params, int64_t num_shards, int64_t shard_index, + const std::vector& expected_outputs); + + // Checks `DatasetBase::node_name()`. + absl::Status CheckDatasetNodeName(const string& expected_dataset_node_name); + + // Checks `DatasetBase::type_string()`. + absl::Status CheckDatasetTypeString(const string& expected_type_str); + + // Checks `DatasetBase::output_dtypes()`. + absl::Status CheckDatasetOutputDtypes( + const DataTypeVector& expected_output_dtypes); + + // Checks `DatasetBase::output_shapes()`. + absl::Status CheckDatasetOutputShapes( + const std::vector& expected_output_shapes); + + // Checks `DatasetBase::Cardinality()`. + absl::Status CheckDatasetCardinality(int expected_cardinality); + + // Checks `DatasetBase::options()`. + absl::Status CheckDatasetOptions(const Options& expected_options); + + // Checks `IteratorBase::output_dtypes()`. + absl::Status CheckIteratorOutputDtypes( + const DataTypeVector& expected_output_dtypes); + + // Checks `IteratorBase::output_shapes()`. + absl::Status CheckIteratorOutputShapes( + const std::vector& expected_output_shapes); + + // Checks `IteratorBase::prefix()`. + absl::Status CheckIteratorPrefix(const string& expected_iterator_prefix); + + absl::Status CheckIteratorSaveAndRestore( + DatasetBase* dataset, IteratorContext* iterator_ctx, + const std::string& iterator_prefix, + const std::vector& expected_outputs, + const std::vector& breakpoints, bool compare_order); + + absl::Status CheckIteratorSaveAndRestore( + const std::string& iterator_prefix, + const std::vector& expected_outputs, + const std::vector& breakpoints, bool compare_order); + + // A class for testing variant tensors. + class TestVariant { + public: + TestVariant() = default; + explicit TestVariant(const std::vector& tensors) + : tensors_(tensors) {} + + bool operator!=(const TestVariant& rhs) const { + return !ExpectEqual(tensors_, rhs.tensors_, /*compare_order=*/true).ok(); + } + + constexpr static const char kTypeName[] = "tensorflow::data::TestVariant"; + + string TypeName() const { return kTypeName; } + + // Encodes the contents of this object into `data`. This function signature + // is required for objects to be stored in `tensorflow::Variant`s. See the + // docs for `tensorflow::Variant` for more information and see + // `tensorflow::Variant::Encode` for how this is used. + void Encode(VariantTensorData* data) const { + data->set_type_name(TypeName()); + for (const auto& tensor : tensors_) { + data->add_tensor(tensor); + } + } + + // Decodes `data` and updates the contents of this object. 
This function + // signature is required for objects to be stored in `tensorflow::Variant`s. + // See the docs for `tensorflow::Variant` for more information and see + // `tensorflow::Variant::Decode` for how this is used. + bool Decode(VariantTensorData data) { + tensors_ = data.tensors(); + return true; + } + + string DebugString() const { + string result = "TestVariant(["; + for (const auto& tensor : tensors_) { + if (&tensor != &tensors_[0]) result += ", "; + result += tensor.DebugString(); + } + result += "])"; + return result; + } + + private: + std::vector tensors_; + }; + + // Returns a scalar variant tensor containing a `TestVariant` object + // containing `tensors`. + static Tensor CreateTestVariantTensor(const std::vector& tensors) { + Tensor tensor{DT_VARIANT, TensorShape({})}; + TestVariant test_variant{tensors}; + tensor.scalar()() = test_variant; + return tensor; + } + + protected: + // Make destructor protected so that DatasetOpsTestBase objects cannot + // be instantiated directly. Only subclasses can be instantiated. + ~DatasetOpsTestBase() override; + + // Creates a thread pool for parallel tasks. + absl::Status InitThreadPool(int thread_num); + + // Initializes the runtime for computing the dataset operation and registers + // the input function definitions. `InitThreadPool()' needs to be called + // before this method if we want to run the tasks in parallel. + absl::Status InitFunctionLibraryRuntime(const std::vector& flib, + int cpu_num); + + // Creates a new op kernel based on the node definition. + absl::Status CreateOpKernel(const NodeDef& node_def, + std::unique_ptr* op_kernel); + + // Creates a new op kernel context. + absl::Status CreateDatasetContext( + OpKernel* dateset_kernel, absl::InlinedVector* inputs, + std::unique_ptr* dataset_context_params, + std::unique_ptr* dataset_context); + + // Creates a new dataset. + absl::Status CreateDataset(OpKernel* kernel, OpKernelContext* context, + DatasetBase** dataset); + + // Restores the state of the input iterator. It resets the iterator before + // restoring it to make sure the input iterator does not hold any + // resources or tasks. Otherwise, restoring an existing iterator may cause + // the timeout issue or duplicated elements. + absl::Status RestoreIterator(IteratorContext* ctx, + IteratorStateReader* reader, + const string& output_prefix, + const DatasetBase& dataset, + std::unique_ptr* iterator); + + // Fetches the dataset from the operation context. + absl::Status GetDatasetFromContext(OpKernelContext* context, int output_index, + DatasetBase** dataset); + + // Runs an operation producing outputs. + absl::Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context); + + // Executes a function producing outputs. + absl::Status RunFunction(const FunctionDef& fdef, test::function::Attrs attrs, + const std::vector& args, + const GraphConstructorOptions& graph_options, + std::vector rets); + + // Checks that the size of `inputs` matches the requirement of the op kernel. + absl::Status CheckOpKernelInput( + const OpKernel& kernel, + const absl::InlinedVector& inputs); + + // Creates a new context for running the dataset operation. + absl::Status CreateOpKernelContext( + OpKernel* kernel, absl::InlinedVector* inputs, + std::unique_ptr* context); + + // Creates a new context for running the dataset operation. 
+ absl::Status CreateOpKernelContext( + OpKernel* kernel, absl::InlinedVector* inputs, + std::unique_ptr* params, + std::unique_ptr* context); + + // Creates a new iterator context for iterating the dataset. + absl::Status CreateIteratorContext( + OpKernelContext* op_context, + std::unique_ptr* iterator_context); + + // Creates a new iterator context for iterating the dataset. + // Creates a new serialization context for serializing the dataset and + // iterator. + absl::Status CreateSerializationContext( + std::unique_ptr* context); + + // Creates the dataset op kernel. + absl::Status MakeGetOptionsOpKernel(const DatasetParams& dataset_params, + std::unique_ptr* op_kernel); + + private: + // Runs the dataset operation according to the predefined dataset params and + // the produced outputs will be stored in `dataset_ctx`. + absl::Status RunDatasetOp( + const DatasetParams& dataset_params, + std::unique_ptr* dataset_kernel, + std::unique_ptr* dataset_ctx_params, + std::vector>* created_tensors, + std::unique_ptr* dataset_ctx); + + absl::Status MakeDataset( + const DatasetParams& dataset_params, + std::unique_ptr* dataset_kernel, + std::unique_ptr* dataset_ctx_params, + std::unique_ptr* dataset_ctx, + std::vector>* created_tensors, + DatasetBase** dataset); + + // Creates the dataset op kernel. + absl::Status MakeDatasetOpKernel(const DatasetParams& dataset_params, + std::unique_ptr* dataset_kernel); + + // Creates a dataset tensor according to the input dataset params. + absl::Status MakeDatasetTensor( + const DatasetParams& dataset_params, + std::vector>* created_tensors, + std::unique_ptr* dataset); + + // Adds an empty tensor with the specified dtype and shape to the input + // vector. + absl::Status AddDatasetInput(absl::InlinedVector* inputs, + DataTypeVector input_types, DataType dtype, + const TensorShape& shape); + + protected: + std::unique_ptr device_; + DeviceType device_type_; + int cpu_num_; + int thread_num_; + Allocator* allocator_; // Owned by `AllocatorFactoryRegistry`. + std::vector allocator_attrs_; + std::unique_ptr step_container_; + + // Device manager is used by function handle cache and needs to outlive it. + std::unique_ptr device_mgr_; + std::unique_ptr pflr_; + FunctionLibraryRuntime* flr_; // Owned by `pflr_`. + std::unique_ptr function_handle_cache_; + std::function)> runner_; + std::unique_ptr lib_def_; + std::unique_ptr resource_mgr_; + std::unique_ptr + slice_reader_cache_; + std::unique_ptr thread_pool_; + std::vector> tensors_; // Owns tensors. + mutex lock_for_refs_; // Used as the Mutex for inputs added as refs. + std::unique_ptr cancellation_manager_; + + // Indicates if the below fields have been initialized. 
+ bool initialized_ = false; + std::unique_ptr dataset_kernel_; + std::unique_ptr params_; + std::unique_ptr dataset_ctx_; + DatasetBase* dataset_ = nullptr; + std::unique_ptr iterator_ctx_; + std::unique_ptr iterator_; +}; + +#define ITERATOR_GET_NEXT_TEST_P(dataset_op_test_class, dataset_params_class, \ + test_cases) \ + class ParameterizedGetNextTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + GetNextTestCase> {}; \ + \ + TEST_P(ParameterizedGetNextTest, GetNext) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK( \ + CheckIteratorGetNext(test_case.expected_outputs, \ + /*compare_order=*/test_case.compare_order)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedGetNextTest, \ + ::testing::ValuesIn( \ + std::vector>(test_cases))); + +#define ITERATOR_SKIP_TEST_P(dataset_op_test_class, dataset_params_class, \ + test_cases) \ + class ParameterizedSkipTest : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + SkipTestCase> {}; \ + \ + TEST_P(ParameterizedSkipTest, Skip) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorSkip( \ + test_case.num_to_skip, test_case.expected_num_skipped, \ + test_case.get_next, test_case.expected_outputs, \ + /*compare_order=*/test_case.compare_order)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedSkipTest, \ + ::testing::ValuesIn( \ + std::vector>(test_cases))); + +#define DATASET_NODE_NAME_TEST_P(dataset_op_test_class, dataset_params_class, \ + test_cases) \ + class ParameterizedDatasetNodeNameTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetNodeNameTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetNodeNameTest, DatasetNodeName) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetNodeName(test_case.expected_node_name)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetNodeNameTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +#define DATASET_TYPE_STRING_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_cases) \ + class ParameterizedDatasetTypeStringTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetTypeStringTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetTypeStringTest, DatasetTypeString) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK( \ + CheckDatasetTypeString(test_case.expected_dataset_type_string)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetTypeStringTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +#define DATASET_OUTPUT_DTYPES_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_cases) \ + \ + class ParameterizedDatasetOutputDtypesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetOutputDtypesTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetOutputDtypesTest, DatasetOutputDtypes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetOutputDtypesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + 
test_cases))); + +#define DATASET_OUTPUT_SHAPES_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_cases) \ + \ + class ParameterizedDatasetOutputShapesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + DatasetOutputShapesTestCase> {}; \ + \ + TEST_P(ParameterizedDatasetOutputShapesTest, DatasetOutputShapes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetOutputShapes(test_case.expected_output_shapes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedDatasetOutputShapesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +#define DATASET_CARDINALITY_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_cases) \ + \ + class ParameterizedCardinalityTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + CardinalityTestCase> {}; \ + \ + TEST_P(ParameterizedCardinalityTest, Cardinality) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetCardinality(test_case.expected_cardinality)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedCardinalityTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +#define ITERATOR_OUTPUT_DTYPES_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_cases) \ + class ParameterizedIteratorOutputDtypesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + IteratorOutputDtypesTestCase> {}; \ + \ + TEST_P(ParameterizedIteratorOutputDtypesTest, IteratorOutputDtypes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckDatasetOutputDtypes(test_case.expected_output_dtypes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorOutputDtypesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +#define ITERATOR_OUTPUT_SHAPES_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_cases) \ + class ParameterizedIteratorOutputShapesTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + IteratorOutputShapesTestCase> {}; \ + \ + TEST_P(ParameterizedIteratorOutputShapesTest, IteratorOutputShapes) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorOutputShapes(test_case.expected_output_shapes)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorOutputShapesTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +#define ITERATOR_PREFIX_TEST_P(dataset_op_test_class, dataset_params_class, \ + test_cases) \ + class ParameterizedIteratorPrefixTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< \ + IteratorPrefixTestCase> {}; \ + \ + TEST_P(ParameterizedIteratorPrefixTest, IteratorPrefix) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorPrefix(test_case.expected_iterator_prefix)); \ + } \ + \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorPrefixTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +#define ITERATOR_SAVE_AND_RESTORE_TEST_P(dataset_op_test_class, \ + dataset_params_class, test_cases) \ + class ParameterizedIteratorSaveAndRestoreTest \ + : public dataset_op_test_class, \ + public ::testing::WithParamInterface< 
\ + IteratorSaveAndRestoreTestCase> {}; \ + TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) { \ + auto test_case = GetParam(); \ + TF_ASSERT_OK(Initialize(test_case.dataset_params)); \ + TF_ASSERT_OK(CheckIteratorSaveAndRestore( \ + test_case.dataset_params.iterator_prefix(), \ + test_case.expected_outputs, test_case.breakpoints, \ + test_case.compare_order)); \ + } \ + INSTANTIATE_TEST_SUITE_P( \ + dataset_op_test_class, ParameterizedIteratorSaveAndRestoreTest, \ + ::testing::ValuesIn( \ + std::vector>( \ + test_cases))); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_DATASET_TEST_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/dataset_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/dataset_utils.h new file mode 100644 index 00000000..929af873 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/dataset_utils.h @@ -0,0 +1,429 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_DATASET_UTILS_H_ +#define TENSORFLOW_CORE_DATA_DATASET_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { + +// Constant used for indicating that the argument of tf.data.Dataset.shard +// should be supplied by the auto-sharding rewrite. +constexpr int kShardHint = -1; + +// Creates a resource handle with a unique name for the given resource where +// the resource is managed by the Resource Manager. +template +absl::Status CreateWeakHandle(OpKernelContext* ctx, T* resource, + const string& container_name, + ResourceHandle* handle) { + static std::atomic resource_id_counter(0); + string unique_name = + strings::StrCat(container_name, resource_id_counter.fetch_add(1)); + ResourceMgr* mgr = ctx->resource_manager(); + TF_RETURN_IF_ERROR(mgr->Create(container_name, unique_name, resource)); + + *handle = MakeResourceHandle(container_name, unique_name, *ctx->device(), + TypeIndex::Make()); + return absl::OkStatus(); +} + +// Creates a ref-counting resource handle for the given resource, where the +// resource is owned by the handle. 
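+//
+// A call sketch (hedged; `MyResource` is a hypothetical ResourceBase subclass
+// and `resource` a pointer the handle should take ownership of):
+//
+//   ResourceHandle handle;
+//   TF_RETURN_IF_ERROR(CreateHandle(ctx, resource, &handle));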
+template +absl::Status CreateHandle(OpKernelContext* ctx, T* resource, + ResourceHandle* handle) { + ResourceMgr* mgr = ctx->resource_manager(); + *handle = + ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name()); + TF_RETURN_IF_ERROR( + mgr->CreateUnowned(handle->container(), handle->name(), resource)); + return absl::OkStatus(); +} + +// TODO(b/198162355): Merge this class with ResourceOpKernel. +template +class AnonymousResourceOp : public OpKernel { + public: + // Creates an AnonymousResourceOp. + // ref_counting: Determines if the Op returns a ref-counting ResourceHandle. + // ResourceHandle. See go/tf-resource-handle-ref-count. + // return_deleter: Determines if the Op outputs a deleter tensor in addition + // to the resource handle tensor. + // If the resource handle is ref-counting, a no-op deleter is returned. + explicit AnonymousResourceOp(OpKernelConstruction* context, bool ref_counting, + bool return_deleter) + : OpKernel(context), + ref_counting_(ref_counting), + return_deleter_(return_deleter) {} + + void Compute(OpKernelContext* ctx) override { + FunctionLibraryRuntime* lib; + std::unique_ptr flib_def(nullptr); + std::unique_ptr pflr(nullptr); + OP_REQUIRES_OK( + ctx, ctx->function_library()->Clone(&flib_def, &pflr, &lib, true)); + T* resource; + OP_REQUIRES_OK(ctx, CreateResource(ctx, std::move(flib_def), + std::move(pflr), lib, &resource)); + + ResourceHandle handle; + if (ref_counting_) { + OP_REQUIRES_OK(ctx, CreateHandle(ctx, resource, &handle)); + } else { + OP_REQUIRES_OK(ctx, CreateWeakHandle(ctx, resource, name(), &handle)); + } + Tensor* handle_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t)); + handle_t->scalar()() = handle; + + if (return_deleter_) { + Tensor* deleter_t; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK( + ctx, ctx->allocate_output(1, TensorShape({}), &deleter_t, attr)); + // TODO(feyu): Consider returning an OptionalVariant. + if (!ref_counting_) { + // A deleter output that deletes the resource when destroyed. + deleter_t->scalar()() = + ResourceDeleter(handle, ctx->resource_manager()); + } + } + } + + protected: + virtual string name() = 0; + + virtual absl::Status CreateResource( + OpKernelContext* ctx, std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, T** resource) = 0; + + private: + const bool ref_counting_; + const bool return_deleter_; +}; + +// Returns OkStatus() if `expected` and `received` types match, +// errors::InvalidArgument otherwise. +absl::Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received); + +absl::Status VerifyTypesMatch(const DataTypeVector& expected, + const std::vector& received); + +// Returns OkStatus() if `expected` and `received` shapes are compatible, +// errors::InvalidArgument otherwise. +absl::Status VerifyShapesCompatible( + const std::vector& expected, + const std::vector& received); + +absl::Status VerifyShapesCompatible( + const std::vector& expected, + const std::vector& received); + +// Dataset op level determinism policy. +class DeterminismPolicy { + public: + enum class Type : int { + // The op must produce elements deterministically. + kDeterministic, + // The op may relax determinism to improve performance. + kNondeterministic, + // The determinism policy is not specified at the op level. In this case we + // use the experimental_deterministic dataset option to determine the + // determinism policy. 
+ kDefault, + }; + static constexpr const char* const kDeterministic = "true"; + static constexpr const char* const kNondeterministic = "false"; + static constexpr const char* const kDefault = "default"; + + DeterminismPolicy() : determinism_(Type::kDefault) {} + explicit DeterminismPolicy(Type determinism) : determinism_(determinism) {} + // Creates a DeterminismPolicy with Type kDeterministic or + // kNondeterministic, depending on the values of `is_deterministic`. + explicit DeterminismPolicy(bool is_deterministic); + + static absl::Status FromString(const std::string& s, DeterminismPolicy* out); + + // Returns the string representing the determinism policy. This will be one of + // the string constants defined above. + std::string String() const; + + /// Convenience methods for checking the DeterminismPolicy::Type. + bool IsDeterministic() const { return determinism_ == Type::kDeterministic; } + bool IsNondeterministic() const { + return determinism_ == Type::kNondeterministic; + } + bool IsDefault() const { return determinism_ == Type::kDefault; } + + private: + Type determinism_; +}; + +// Resolves non-deterministic seeds if necessary, returning either the original +// seeds or the resolved seeds. +// +// By TensorFlow convention, if both seeds are 0, they should be replaced with +// non-deterministically chosen seeds. +std::pair MaybeOverrideSeeds( + std::pair seeds); + +// Adds the functions in `to_add` to `base`. If a function with a matching +// signature already exists in `base`, replaces it with the function from +// `to_add`. +absl::Status AddToFunctionLibrary(FunctionLibraryDefinition* base, + const FunctionLibraryDefinition& to_add); +absl::Status AddToFunctionLibrary(FunctionLibraryDefinition* base, + const FunctionDefLibrary& to_add); + +// Determines whether the given function is stateful. +absl::Status IsFunctionStateful(const FunctionLibraryDefinition& library, + const FunctionDef& function_def); + +// Determines whether the given node is stateful. +absl::Status IsNodeStateful(const FunctionLibraryDefinition& library, + const NodeDef& node); + +// Creates a runner that runs functions with limited parallelism. +std::function)> RunnerWithMaxParallelism( + std::function)> runner, int max_parallelism); + +// Op for creating a typed dummy resource. +// +// This op is used to provide a resource "placeholder" for ops such as +// `CacheDatasetV2` or `ShuffleDatasetV2` that expects a resource input. +// Originally, the lifetime of the resources passed into these ops was managed +// externally. After the implementation changed to manage the lifetime of the +// resources (including creation) by the ops themselves, the resource input is +// only needed to pass a resource handle through graph rewrites. When they are +// invoked from user code, the implementation passes in a dummy resource. +template +class DummyResourceOp : public OpKernel { + public: + explicit DummyResourceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Tensor* tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &tensor)); + tensor->scalar()() = MakeResourceHandle( + ctx, /*container=*/"", /*name=*/"dummy_resource"); + } +}; + +// Given an op prefix and an op to match, returns whether the op to match +// is a match for any version of the op prefix. 
For example, +// MatchesAnyVersion("BatchDataset", "BatchDataset") == true +// MatchesAnyVersion("BatchDataset", "BatchDatasetV2") == true +// MatchesAnyVersion("BatchDataset", "BatchDatasetV3") == true +// MatchesAnyVersion("PaddedBatchDataset", "BatchDataset") == false +bool MatchesAnyVersion(absl::string_view op_prefix, + absl::string_view op_to_match); + +// Returns the index-th slice of a given tensor. If the index-th slice of +// the tensor is not aligned, returns a deep copy of the tensor. +Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index); + +// Removes device placements from the ops of all functions in `library`. +void StripDevicePlacement(FunctionDefLibrary* library); + +// Copies partial of the batch output. +absl::Status CopyPartialBatch(int64_t num_elements, const Tensor& value, + Tensor* output); + +// Reads a batch when restoring the iterator. +absl::Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, + int64_t batch_size, const string& iterator_prefix, + const string& batch_prefix, std::vector* batch); + +// Writes a batch when saving the iterator. +absl::Status WriteBatch(int64_t batch_size, int64_t num_elements, + const string& iterator_prefix, + const string& batch_prefix, IteratorStateWriter* writer, + std::vector* batch); + +// Reads a status when restoring the iterator. +absl::Status ReadStatus(const string& iterator_prefix, const string& prefix, + IteratorStateReader* reader, absl::Status* status); + +// Writes a status when saving the iterator. +absl::Status WriteStatus(const string& iterator_prefix, const string& prefix, + const absl::Status& status, + IteratorStateWriter* writer); + +// Processes a batch to output. In the case a partial batch is encountered, copy +// only partial of the batch. +absl::Status ProcessBatch(int64_t batch_size, int64_t num_elements, + bool drop_remainder, const absl::Status& status, + IteratorContext* ctx, std::vector* output, + bool* end_of_sequence, std::vector* batch); + +// Copies the input elements to a batch. +// +// The `batch_elements` argument contains the individual elements to copy into a +// batch. The `parallel_copy` argument indicates whether to parallelize the +// copy. +// The `out_tensors` argument will be used to store the resulting batch (one for +// each component of the input). +absl::Status CopyBatch(AnyContext ctx, + std::vector>&& batch_elements, + bool parallel_copy, std::vector* out_tensors); + +// Computes the set of experiments to apply based on the job name, task id, +// rollout percentage of registered experiments, and the +// TF_DATA_EXPERIMENT_OPT_IN and TF_DATA_EXPERIMENT_OPT_OUT environment +// variables. +absl::flat_hash_set GetExperiments(); +absl::flat_hash_set GetExperiments( + const std::string& job_name, int64_t task_id, + std::function hash_func); + +// Logs and records the experiments that will be applied. +void LogAndRecordExperiments(const absl::flat_hash_set& experiments); + +// Computes the set of enabled, disabled, and default optimizations based on the +// given options. An optimization must be a graph optimizer name that has been +// registered with Grappler. +void GetOptimizations(const Options& options, + absl::flat_hash_set* optimizations_enabled, + absl::flat_hash_set* optimizations_disabled, + absl::flat_hash_set* optimizations_default); + +// Creates graph rewrite configs based on the given options. The configs will +// only be used if their corresponding optimizers registered with Grappler are +// enabled. 
+// A config is a string with the following format: +// :: +absl::flat_hash_set CreateGraphRewriteConfigs(const Options& options); + +// Determines whether max intra-op parallelism should be configured. +bool ShouldConfigureMaxIntraOpParallelism(const Options& options); + +// Determines whether private threadpool should be used. +bool ShouldUsePrivateThreadPool(const Options& options); + +// Determines whether autotuning should be used. +bool ShouldUseAutotuning(const Options& options); + +// Determines whether optimizations should be applied. +bool ShouldApplyOptimizations( + const Options& options, + const absl::flat_hash_set& optimizations_enabled, + const absl::flat_hash_set& optimizations_default); + +// Returns the default CPU budget. +inline int GetCpuBudget() { + static bool in_experiment = GetExperiments().contains("tune_cpu_budget"); + return (in_experiment ? 1.2 : 1.0) * port::NumSchedulableCPUs(); +} + +// Returns the initial value for parallelism parameter before the first Autotune +// optimization. +int64 GetAutotuneDefaultParallelism(IteratorContext* ctx); + +// Creates an iterator context appropriate for a nested dataset's iterator. A +// nested dataset is a dataset created within another dataset, e.g. by the +// function passed to `interleave` or `flat_map`. +IteratorContext MakeNestedIteratorContext(IteratorContext* ctx); + +// A `DatasetExperimentRegistry::JobSelector` that randomly selects +// `rollout_pct` percent of all jobs. `name_hash` is a hash of the experiment +// and job names. +template +bool RandomJobSamplePercentage(uint64_t name_hash) { + return name_hash % 100 < rollout_pct; +} + +// A `DatasetExperimentRegistry::TaskSelector` that selects all tasks. +bool AllTasks(int64_t unused_task_id, bool unused_evens); + +// A `DatasetExperimentRegistry::TaskSelector` that selects the tasks for half +// of all hosts. Typically, one or two consecutive tasks run on a single host. +// If `evens` is `true`, selects tasks 0,1,4,5,8,9,..., otherwise selects tasks +// 2,3,6,7,10,11,... +bool IndependentHostTasks(int64_t task_id, bool evens); + +// Registry of tf.data experiments. +class DatasetExperimentRegistry { + public: + using JobSelector = std::function; + using TaskSelector = std::function; + + struct ExperimentSelector { + JobSelector job_selector; + TaskSelector task_selector; + }; + + // Registers the experiment. + static void Register(const string& experiment, JobSelector job_selector, + TaskSelector task_selector); + + // Returns all registered experiments. + static absl::flat_hash_map Experiments(); +}; + +// Helper class to register a dataset experiment. +class DatasetExperimentRegistrar { + public: + explicit DatasetExperimentRegistrar( + const string& experiment, + DatasetExperimentRegistry::JobSelector job_selector, + DatasetExperimentRegistry::TaskSelector task_selector) { + DatasetExperimentRegistry::Register(experiment, job_selector, + task_selector); + } +}; + +// Macro that can be used to register a dataset experiment. 
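+//
+// For example (sketch; the experiment name is illustrative, the selectors are
+// the helpers declared above):
+//
+//   REGISTER_DATASET_EXPERIMENT("my_experiment",
+//                               RandomJobSamplePercentage<50>,
+//                               IndependentHostTasks);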
+#define REGISTER_DATASET_EXPERIMENT(experiment, job_selector, task_selector) \ + REGISTER_DATASET_OP_NAME_UNIQ_HELPER(__COUNTER__, experiment, job_selector, \ + task_selector) + +#define REGISTER_DATASET_OP_NAME_UNIQ_HELPER(ctr, experiment, job_selector, \ + task_selector) \ + REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, job_selector, task_selector) + +#define REGISTER_DATASET_OP_NAME_UNIQ(ctr, experiment, job_selector, \ + task_selector) \ + static ::tensorflow::data::DatasetExperimentRegistrar \ + registrar__body__##ctr##__object(experiment, job_selector, \ + task_selector) + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_DATASET_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/finalization_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/finalization_utils.h new file mode 100644 index 00000000..07e1d75b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/finalization_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_FINALIZATION_UTILS_H_ +#define TENSORFLOW_CORE_DATA_FINALIZATION_UTILS_H_ + +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { + +// Returns the finalized version of the dataset. The returned DatasetBase is +// unowned and lives for as long as this dataset. +absl::StatusOr GetFinalizedDataset(OpKernelContext* ctx, + const DatasetBase* dataset); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_FINALIZATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/flat_map_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/flat_map_utils.h new file mode 100644 index 00000000..658f6855 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/flat_map_utils.h @@ -0,0 +1,112 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_FLAT_MAP_UTILS_H_ +#define TENSORFLOW_CORE_DATA_FLAT_MAP_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tsl/platform/refcount.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace data { + +// Utility class for computing the cardinality of a flat map dataset. +class FlatMapRandomAccessHandler { + public: + // Initializes the counter. This will save necessary information from `ctx`. + // `input_dataset` is the input dataset passed to `flat_map` (not the flat_map + // dataset). `captured_map_func` is the captured map function. + FlatMapRandomAccessHandler(OpKernelContext* ctx, + const DatasetBase* input_dataset, + CapturedFunction& captured_map_func); + virtual ~FlatMapRandomAccessHandler(); + FlatMapRandomAccessHandler(const FlatMapRandomAccessHandler&) = delete; + FlatMapRandomAccessHandler& operator=(const FlatMapRandomAccessHandler&) = + delete; + + // Returns the dataset cardinality. + absl::StatusOr Cardinality(); + + // Returns the cumulative cardinality at the index-th dataset. + absl::StatusOr CumulativeCardinality(size_t index); + + // Given the flattened element position `element_position`, returns the index + // of the dataset to which the element belongs. + absl::StatusOr GetDatasetIndex(size_t element_position); + + // Creates the dataset iterators. + absl::StatusOr>> MakeInputIterators( + IteratorContext* ctx, const DatasetBaseIterator* parent, + const std::string& prefix); + + private: + // Computes the cumulative cardinalities. + absl::StatusOr> ComputeCardinalities(); + + // Creates the input datasets. Each dataset is the result of applying the map + // function to one element from the input iterator. + absl::StatusOr> MakeInputDatasets() const; + absl::StatusOr MakeInputDataset( + std::vector input_tensors, + const InstantiatedCapturedFunction& map_func) const; + + const DatasetBase* input_dataset_; + CapturedFunction& captured_map_func_; + + // The iterator context which bundles together the necessary runtime support + // to create and get elements from the input dataset. + std::unique_ptr ctx_; + FunctionLibraryRuntime* flr_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + std::unique_ptr interop_threadpool_; + std::unique_ptr function_handle_cache_; + std::function)> runner_; + ResourceMgr resource_mgr_; + CancellationManager cancellation_manager_; + UnboundedThreadPool unbounded_thread_pool_; + + // Input datasets generated by running the map function. Each dataset is the + // result of applying the map function to one element from the input iterator. + std::deque input_datasets_; + + // Cumulative cardinalities. Before `ComputeCardinalities` is called, this is + // an empty vector. After `ComputeCardinalities` is called, the last element + // is the dataset cardinality. 
+ absl::StatusOr> cumulative_cardinalities_ = + std::vector{}; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_FLAT_MAP_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/global_shuffle_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/global_shuffle_utils.h new file mode 100644 index 00000000..66e2ff1a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/global_shuffle_utils.h @@ -0,0 +1,100 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_GLOBAL_SHUFFLE_UTILS_H_ +#define TENSORFLOW_CORE_DATA_GLOBAL_SHUFFLE_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { + +// Builds and selects the `IteratorContext` to use based on whether the dataset +// is globally shuffled. +// +// Example usage in `Iterator::GetNextInternal`: +// +// ``` +// IteratorContextWithIndexMapper ctx_with_index_mapper(ctx, this); +// TF_RETURN_IF_ERROR(input_impl_->GetNext( +// ctx_with_index_mapper.Get(), out_tensors, end_of_sequence)); +// ctx_with_index_mapper.MergeCheckpoint(); +// ``` +// +// The iterator should also implement `GetIndexMapper` if it needs to customize +// the index mapping behavior. +class IteratorContextWithIndexMapper { + public: + // Caller keeps ownership of both pointers. + explicit IteratorContextWithIndexMapper(IteratorContext* ctx, + const IteratorBase* iterator); + virtual ~IteratorContextWithIndexMapper() = default; + IteratorContextWithIndexMapper(const IteratorContextWithIndexMapper&) = + delete; + IteratorContextWithIndexMapper& operator=( + const IteratorContextWithIndexMapper&) = delete; + + IteratorContext* Get(); + void MergeCheckpoint(); + + private: + IteratorContext* ctx_; + std::optional ctx_with_index_mapper_; +}; + +// For source datasets that support random access, this class adapts the dataset +// random access API to support globally shuffled iterators. +class GlobalShuffleIterator { + public: + // The dataset is expected to support random access by implementing the + // absl::Status Get(int64_t index, std::vector* out_tensors) const. + explicit GlobalShuffleIterator(const DatasetBase* dataset) + : dataset_(dataset) {} + + // Returns the next shuffled element. + // REQUIRES: ctx->index_mapper() != nullptr. + absl::Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence); + + absl::Status Save(const std::string& parent_iterator_prefix, + SerializationContext* ctx, IteratorStateWriter* writer); + + // Restores the element count. + // REQUIRES: ctx->restored_element_count() != nullopt. 
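+  //
+  // Sketch of how a source iterator might delegate in its RestoreInternal
+  // (the member name `global_shuffle_iterator_` is assumed):
+  //
+  //   if (ctx->restored_element_count().has_value()) {
+  //     return global_shuffle_iterator_.Restore(prefix(), ctx, reader);
+  //   }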
+  absl::Status Restore(const std::string& parent_iterator_prefix,
+                       IteratorContext* ctx, IteratorStateReader* reader);
+
+ private:
+  const DatasetBase* const dataset_;
+
+  mutable absl::Mutex mu_;
+
+  // Count of elements produced by this iterator when it runs in the random
+  // access mode.
+  int64_t element_count_ ABSL_GUARDED_BY(mu_) = 0;
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_DATA_GLOBAL_SHUFFLE_UTILS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/data/hash_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/hash_utils.h
new file mode 100644
index 00000000..2effd416
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/data/hash_utils.h
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DATA_HASH_UTILS_H_
+#define TENSORFLOW_CORE_DATA_HASH_UTILS_H_
+
+#include "absl/status/status.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+
+// Returns a stable hash of the subgraph rooted at the given node.
+//
+// NOTE: There is currently no guarantee that the hash of a subgraph will stay
+// the same between TensorFlow builds.
+absl::Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash);
+absl::Status HashNode(const GraphDef& graph, const NodeDef& node,
+                      const FunctionLibraryDefinition& flib_def, uint64* hash);
+
+// Returns a stable hash of the given tensor.
+//
+// NOTE: There is currently no guarantee that the hash of a subgraph will stay
+// the same between TensorFlow builds.
+absl::Status HashTensor(const Tensor& tensor, uint64* hash);
+
+// Returns a stable hash of the given graph.
+//
+// NOTE: There is currently no guarantee that the hash of a subgraph will stay
+// the same between TensorFlow builds.
+absl::Status HashGraph(const GraphDef& graph, uint64* hash);
+
+// Determines whether the given graphs are equal, following the same logic used
+// for HashGraph. Returns OK if the graphs can be determined to be equal,
+// otherwise returns an error message explaining why the graphs couldn't be
+// determined to be equal.
+absl::Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b);
+
+// Determines whether the subgraphs rooted at the given nodes are equal
+// following the same logic used for HashGraph. Returns OK if the graphs can be
+// determined to be equal, otherwise returns an error message explaining why the
+// graphs couldn't be determined to be equal.
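+//
+// For example (sketch; the graphs and node pointers are hypothetical):
+//
+//   TF_RETURN_IF_ERROR(
+//       CheckSubgraphsEqual(graph_a, &node_a, graph_b, &node_b));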
+absl::Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a, + const GraphDef& b, const NodeDef* node_b); +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_HASH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/metric_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/metric_utils.h new file mode 100644 index 00000000..7d67cb92 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/metric_utils.h @@ -0,0 +1,87 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_METRIC_UTILS_H_ +#define TENSORFLOW_CORE_DATA_METRIC_UTILS_H_ + +#include +#include +#include + +#include "absl/time/time.h" +#include "tensorflow/core/data/tfdataz_metrics.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// Exports the metrics for `GetNext` calls by tf.data iterators. When the user +// calls `RecordStart` and `RecordStop`, it will export a latency sample. It +// also exports throughput, tf.data iterator life time, etc. This class is +// thread-safe. Example usage: +// +// ``` +// IteratorMetricsCollector metrics_collector(DEVICE_CPU, env); +// absl::Time start_time = metrics_collector.RecordStart(); +// auto status = iterator_->GetNext(IteratorContext(std::move(params)), +// out_tensors, end_of_sequence); +// metrics_collector.RecordStop(start_time, *out_tensors); +// ``` +class IteratorMetricsCollector { + public: + // Constructs a `IteratorMetricsCollector`. `device_type` is one of the + // devices defined in `types.h` (DEVICE_CPU, DEVICE_GPU, DEVICE_TPU, etc). + // We only collect metrics for CPU devices. This is a heuristic to avoid + // collecting metrics for device-side iterators created by the multi-device + // iterator mechanism. + IteratorMetricsCollector(const std::string& device_type, const Env& env); + + // Starts the timer for the next `GetNext` call. Returns the start time. + absl::Time RecordStart(); + + // Records metrics for the most recent `GetNext` call, including the latency, + // bytes fetched, iterator life time, etc. `start_time` is the start time + // returned by `RecordStart`. `output` is the output of the `GetNext` call. + void RecordStop(absl::Time start_time, const std::vector& output); + + private: + // We only collect metrics for CPU devices. + bool ShouldCollectMetrics() const; + + // One of the devices defined in `types.h` + // (DEVICE_CPU, DEVICE_GPU, DEVICE_TPU, etc). + const std::string device_type_; + const Env& env_; + + mutex mu_; + + // Records the number of currently active `GetNext` calls. 
+ uint64_t num_active_calls_ TF_GUARDED_BY(mu_) = 0; + + // Records the start time (in microseconds) of the first `RecordStart()` call + // that followed the last period of inactivity. + uint64_t first_start_time_us_ TF_GUARDED_BY(mu_) = 0; + + // Records the end time (in microseconds) of the most recent `RecordStop()` + // call. + uint64_t end_time_us_ TF_GUARDED_BY(mu_) = 0; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_METRIC_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/name_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/name_utils.h new file mode 100644 index 00000000..72e870a1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/name_utils.h @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_NAME_UTILS_H_ +#define TENSORFLOW_CORE_DATA_NAME_UTILS_H_ + +#include + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace data { +namespace name_utils { + +extern const char kDelimiter[]; +extern const char kDefaultDatasetDebugStringPrefix[]; + +struct OpNameParams { + int op_version = 1; +}; + +struct DatasetDebugStringParams { + template + void set_args(T... input_args) { + args = {static_cast(input_args).data()...}; + } + + int op_version = 1; + string dataset_prefix = ""; + std::vector args; +}; + +struct IteratorPrefixParams { + int op_version = 1; + string dataset_prefix = ""; +}; + +// Merge the given args in the format of "(arg1, arg2, ..., argn)". +// +// e.g. ArgsToString({"1", "2", "3"}) -> "(1, 2, 3)"; ArgsToString({}) -> "". +string ArgsToString(const std::vector& args); + +// Returns the dataset op name. +// +// e.g. OpName("Map") -> "MapDataset". +string OpName(const string& dataset_type); + +// Returns the dataset op names. +// +// e.g. OpName(ConcatenateDatasetOp::kDatasetType, OpNameParams()) +// -> "ConcatenateDataset" +// +// OpNameParams params; +// params.op_version = 2; +// OpName(ParallelInterleaveDatasetOp::kDatasetType, params) +// -> "ParallelInterleaveDatasetV2" +string OpName(const string& dataset_type, const OpNameParams& params); + +// Returns a human-readable debug string for this dataset in the format of +// "FooDatasetOp(arg1, arg2, ...)::Dataset". +// +// e.g. DatasetDebugString("Map") -> "MapDatasetOp::Dataset"; +string DatasetDebugString(const string& dataset_type); + +// Returns a human-readable debug string for this dataset in the format of +// "FooDatasetOp(arg1, arg2, ...)::Dataset". +// +// e.g. 
+// DatasetDebugStringParams range_params; +// range_params.set_args(0, 10, 3); +// DatasetDebugString(RangeDatasetOp::kDatasetType, range_params) +// -> "RangeDatasetOp(0, 10, 3)::Dataset"); +string DatasetDebugString(const string& dataset_type, + const DatasetDebugStringParams& params); + +// Returns a string that identifies the sequence of iterators leading up to +// the iterator of this dataset. +// +// e.g. IteratorPrefix("Map", "Iterator::Range") -> "Iterator::Range::Map". +string IteratorPrefix(const string& dataset_type, const string& prefix); + +// Returns a string that identifies the sequence of iterators leading up to +// the iterator of this dataset. +// +// e.g. +// IteratorPrefixParams params; +// params.op_version = 2; +// IteratorPrefix(BatchDatasetOp::KDatasetType, "Iterator::Range", params) -> +// "Iterator::Range::BatchV2". +string IteratorPrefix(const string& dataset_type, const string& prefix, + const IteratorPrefixParams& params); + +} // namespace name_utils +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_NAME_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/rewrite_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/rewrite_utils.h new file mode 100644 index 00000000..addd6f20 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/rewrite_utils.h @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_REWRITE_UTILS_H_ +#define TENSORFLOW_CORE_DATA_REWRITE_UTILS_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide this functionality because not all of its +// dependencies are available there. +#if !defined(IS_MOBILE_PLATFORM) + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace data { + +RewriterConfig CreateRewriterConfig( + const absl::flat_hash_set& optimizations, + const absl::flat_hash_set& optimizations_configs); + +// Rewrites the input dataset using the given config. The rewritten_input +// stored in the core::RefCountPtr* output parameter is owned. 
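+//
+// A call sketch (hedged; `input` and the optimization name are placeholders,
+// `CreateRewriterConfig` is declared above):
+//
+//   core::RefCountPtr<DatasetBase> rewritten;
+//   TF_RETURN_IF_ERROR(RewriteDataset(
+//       ctx, input,
+//       [] { return CreateRewriterConfig({"noop_elimination"}, {}); },
+//       /*record_fingerprint=*/false, &rewritten));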
+absl::Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, + std::function config_factory, + bool record_fingerprint, + core::RefCountPtr* rewritten_input); + +// Creates a grappler item for `graph_def`, which is required for graph +// optimization. +// `dataset_node` is the name of the node corresponding to the dataset. +// If `add_fake_sinks` is true, it adds fake sink node to graph and functions to +// allow rewriting the actual sink nodes. +// If `apply_optimizations` is true, general grappler optimizations at level +// `tensorflow::OptimizerOptions::L1` are applied to the graph. +// TODO(b/118820916): When MetaOptimizer adds provisions for function retvals to +// be optimizable, we will no longer need to add fake nodes. +std::unique_ptr GetGrapplerItem( + GraphDef* graph_def, std::string* dataset_node, bool add_fake_sinks, + bool apply_optimizations = true); + +// Returns the name of the node corresponding to the dataset. It is indicated by +// the symbolic `_Retval` node. +absl::StatusOr GetDatasetNode(const GraphDef& graph_def); + +// Like `GetDatasetNode` above, but returns the entire node object. +absl::StatusOr GetDatasetNodeDef(const GraphDef& graph_def); + +// Determines which optimizations should be applied. +// +// The result will contain any optimizations that are explicitly enabled, any +// default optimization that are not explicitly disabled, and any experiment +// that corresponds to an optimization as long as the optimization is not +// explicitly disabled. +absl::flat_hash_set SelectOptimizations( + const absl::flat_hash_set& experiments, + const absl::flat_hash_set& optimizations_enabled, + const absl::flat_hash_set& optimizations_disabled, + const absl::flat_hash_set& optimizations_default); + +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM + +#endif // TENSORFLOW_CORE_DATA_REWRITE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/root_dataset.h b/third_party/tflite-hdrs/tensorflow/core/data/root_dataset.h new file mode 100644 index 00000000..e5b8f8db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/root_dataset.h @@ -0,0 +1,108 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_ROOT_DATASET_H_ +#define TENSORFLOW_CORE_DATA_ROOT_DATASET_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/model.pb.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/refcount.h" + +namespace tensorflow { +namespace data { + +// Dataset transformation responsible for internal tf.data logic such as +// autotuning, applying threading configuration. 
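+//
+// Callers normally reach this through `FinalizeDataset` (declared at the
+// bottom of this header) rather than constructing RootDataset directly; a
+// minimal sketch, with ownership handling elided:
+//
+//   DatasetBase* finalized = nullptr;
+//   TF_RETURN_IF_ERROR(FinalizeDataset(ctx, input, &finalized));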
+class RootDataset : public DatasetBase { + public: + struct Params { + bool autotune = true; + model::AutotuneAlgorithm autotune_algorithm; + std::function autotune_cpu_budget_func; + double ram_budget_share; + int64_t autotune_ram_budget_from_options; + int64_t max_intra_op_parallelism = 1; + int64_t private_threadpool_size = 0; + + int64_t ComputeInitialAutotuneRamBudget() const { + if (autotune_ram_budget_from_options > 0) { + return autotune_ram_budget_from_options; + } else { + return ram_budget_share * port::AvailableRam(); + } + } + }; + + static absl::Status FromOptions(const DatasetBase* input, + DatasetBase** output); + static absl::Status FromOptions(core::RefCountPtr input, + DatasetBase** output); + + ~RootDataset() override; + + const DataTypeVector& output_dtypes() const override; + const std::vector& output_shapes() const override; + + int64_t CardinalityInternal(CardinalityOptions options) const override; + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override; + absl::Status CheckExternalState() const override; + string DebugString() const override; + absl::Status InputDatasets( + std::vector* inputs) const override; + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override; + absl::Status RandomIndexingCompatible() const override { + return random_indexing_compatible_; + } + + protected: + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; + + private: + class Iterator; + + RootDataset(const DatasetBase* input, const Params& params); + + RootDataset(core::RefCountPtr input, const Params& params); + + const DatasetBase* input_; + core::RefCountPtr owned_input_; + const Params params_; + TraceMeMetadata traceme_metadata_; + absl::Status random_indexing_compatible_; +}; + +// Finalizes the `input` dataset, which is expected to be called before the +// dataset is about to be iterated. This can for instance apply static graph +// optimizations or inject internal tf.data transformations responsible for +// autotuning or threading configuration. The caller must ensure that the +// input dataset to be finalized outlives the output. +absl::Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, + DatasetBase** output); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_ROOT_DATASET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/serialization_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/serialization_utils.h new file mode 100644 index 00000000..e59ac959 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/serialization_utils.h @@ -0,0 +1,244 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERIALIZATION_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/status.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace data { + +inline constexpr absl::string_view kRetvalOp = "_Retval"; + +// Reads dataset elements from the checkpoint reader using the given key prefix. +absl::Status ReadElementsFromCheckpoint( + IteratorContext* ctx, IteratorStateReader* reader, + absl::string_view key_prefix, std::vector>* elements); + +// Writes dataset elements to the checkpoint writer using the given key prefix. +// The elements can be read back by passing the same key prefix to +// ReadElementsFromCheckpoint. Only one list of elements can be written under +// the same key_prefix. +absl::Status WriteElementsToCheckpoint( + IteratorStateWriter* writer, absl::string_view key_prefix, + const std::vector>& elements); + +// Updates the dataset elements in the checkpoint for given `checkpoint_indices` +// using the given key prefix, assuming that vector of elements have +// checkpointed these before. The elements can be read back by passing the same +// key prefix to ReadElementsFromCheckpoint. +absl::Status UpdateCheckpointElements( + IteratorStateWriter* writer, absl::string_view key_prefix, + const std::vector>& elements, + const absl::flat_hash_set& checkpoint_indices); + +// Helper class for reading data from a vector of VariantTensorData objects. +class VariantTensorDataReader : public IteratorStateReader { + public: + explicit VariantTensorDataReader( + const std::vector& data); + + bool Contains(absl::string_view key) const override; + bool Contains(absl::string_view name, absl::string_view key) const override; + + absl::Status ReadScalar(absl::string_view key, int64_t* val) const override; + absl::Status ReadScalar(absl::string_view name, absl::string_view key, + int64_t* val) const override; + absl::Status ReadScalar(absl::string_view key, tstring* val) const override; + absl::Status ReadScalar(absl::string_view name, absl::string_view key, + tstring* val) const override; + absl::Status ReadTensor(absl::string_view key, Tensor* val) const override; + absl::Status ReadTensor(FunctionLibraryRuntime* flr, absl::string_view key, + Tensor* val) const override; + absl::Status ReadTensor(absl::string_view name, absl::string_view key, + Tensor* val) const override; + absl::Status ReadTensor(FunctionLibraryRuntime* flr, absl::string_view name, + absl::string_view key, Tensor* val) const override; + + private: + template + absl::Status ReadScalarInternal(absl::string_view name, absl::string_view key, + T* val) const; + absl::Status ReadTensorInternal(FunctionLibraryRuntime* flr, + absl::string_view name, absl::string_view key, + Tensor* val) const; + absl::Status ReadDatasetInternal(FunctionLibraryRuntime* flr, + absl::string_view name, + absl::string_view key, Tensor* val) const; + // Produces all key/value pairs stored in this reader. 
Useful for debugging. + std::map ReadAllTensors(); + + // For access to ReadAllTensors() + friend absl::StatusOr> + CheckpointStats(const std::string& checkpoint_bytes); + + std::map> map_; + std::map data_; // Not owned. +}; + +// Helper class used to build a list of VariantTensorData objects, one for each +// iterator which is determined from the key supplied from the Write* calls. +// Sample usage: +// VariantTensorDataWriter writer; +// writer.WriteScalar(full_name("buffer_size"), buffer_.size()); +// writer.WriteScalar(full_name("num_threads"), threadpool_.size()); +// .... +// std::vector> variants; +// writer.ReleaseData(&variants); +// Now the VariantTensorData objects can be used to serialize. +class VariantTensorDataWriter : public IteratorStateWriter { + public: + absl::Status WriteScalar(absl::string_view key, int64_t val) override; + absl::Status WriteScalar(absl::string_view name, absl::string_view key, + int64_t val) override; + + absl::Status WriteScalar(absl::string_view key, const tstring& val) override; + absl::Status WriteScalar(absl::string_view name, absl::string_view key, + const tstring& val) override; + + absl::Status WriteTensor(absl::string_view key, const Tensor& val) override; + absl::Status WriteTensor(absl::string_view name, absl::string_view key, + const Tensor& val) override; + + // Releases the built VariantTensorData's to `variants`. Clears out all + // class state. + void ReleaseData(std::vector>* variants); + + // Obtains a read-only version of the VariantTensorData's built. + void GetData(std::vector* variants); + + private: + void MaybeFlush(); + void Reset(); + + template + absl::Status WriteScalarInternal(absl::string_view name, + absl::string_view key, const T& val); + absl::Status WriteTensorInternal(absl::string_view name, + absl::string_view key, const Tensor& val); + absl::Status WriteDatasetInternal(absl::string_view name, + absl::string_view key, + const DatasetBase* dataset); + + bool is_flushed_ = false; + std::map> data_; + std::map> keys_; +}; + +// Wrapper for encoding/decoding the iterator state stored in a Variant tensor. +// The `GetData()` method returns an VariantTensorData object which contains all +// the state needed to restore a single iterator. +// +// Usage example: +// +// Encoding: +// +// Tensor t(DT_VARIANT, TensorShape({})); +// t->scalar()() = IteratorStateVariant(); +// +// Encode() sets the type_name of the VariantTensorData object to +// IteratorStateVariant::TypeName(). +// +// Decoding: +// +// Variant v = ; +// DecodeUnaryVariant(&v); +// IteratorStateVariant* wrapper = v.get(); +// IteratorStateReader reader({wrapper->GetData()}); +// iterator_resource->Restore(ctx, &reader); +// +// The type_name of the VariantTensorData object to be decoded must match +// IteratorStateVariant::TypeName(). +class IteratorStateVariant { + public: + IteratorStateVariant() = default; + IteratorStateVariant(const IteratorStateVariant& other); + IteratorStateVariant& operator=(IteratorStateVariant&& other) = default; + IteratorStateVariant& operator=(const IteratorStateVariant& other) = delete; + + static std::string TypeName(); + + // Initializes `this` from a VariantTensorData object. + absl::Status InitializeFromVariantData( + std::unique_ptr data); + + // Returns a borrowed pointer to the underlying VariantTensorData. + const VariantTensorData* GetData() const { return data_.get(); } + + // Encodes this `IteratorStateVariant` into `*data`. 
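
A hedged round-trip sketch of the save/restore flow described above, assuming a TensorFlow build; the "Iterator::Range" name and "buffer_size" key are hypothetical, and the element type handed to GetData() is assumed to be const VariantTensorData*:

absl::Status SaveAndRestoreSketch() {
  tensorflow::data::VariantTensorDataWriter writer;
  // Hypothetical iterator name and key.
  TF_RETURN_IF_ERROR(
      writer.WriteScalar("Iterator::Range", "buffer_size", int64_t{16}));

  // Borrow a read-only view of the serialized state (one entry per iterator).
  std::vector<const tensorflow::VariantTensorData*> state;  // assumed element type
  writer.GetData(&state);

  tensorflow::data::VariantTensorDataReader reader(state);
  int64_t buffer_size = 0;
  TF_RETURN_IF_ERROR(
      reader.ReadScalar("Iterator::Range", "buffer_size", &buffer_size));
  return absl::OkStatus();
}
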
Data will be compressed + // and stored as a scalar `CompressedElement` tensor, or left uncompressed if + // compression fails. + void Encode(VariantTensorData* data) const; + + // Decodes from `data`. If `data` contains a single scalar `CompressedElement` + // tensor, it is assumed to be compressed by `Encode`, and will be + // uncompressed as part of `Decode`. + bool Decode(VariantTensorData data); + + std::string DebugString() const; + + private: + // Returns the compressed element in `data`. If `data` does not contain a + // compressed element, returns nullptr. + static const CompressedElement* GetCompressedElement( + const VariantTensorData& data); + + std::unique_ptr data_; +}; + +// Returns a GraphDef representation of the given dataset. +absl::Status AsGraphDef(const DatasetBase* dataset, + SerializationContext&& serialization_ctx, + GraphDef* graph_def); + +// Returns a GraphDef representation of the given dataset suitable for +// optimization rewrites. It sets serialization parameters to export a minimum +// graph with additional information for optimization (i.e. ignoring external +// state, not serializing data tensors, not failing if there are datasets which +// do not have AsGraphDef implemented). Sets the `dataset_node` parameter to the +// dataset's node name in the resulting GraphDef. +absl::Status AsGraphDefForRewrite( + OpKernelContext* ctx, const DatasetBase* input, + std::vector>* input_list, GraphDef* result, + string* dataset_node); + +// Analyzes the bytes of a tf.data iterator checkpoint to identify all of the +// keys in the checkpoint along with their sizes in bytes. +absl::StatusOr> CheckpointStats( + const std::string& checkpoint_bytes); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SERIALIZATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/auto_scaler.h b/third_party/tflite-hdrs/tensorflow/core/data/service/auto_scaler.h new file mode 100644 index 00000000..edd09863 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/auto_scaler.h @@ -0,0 +1,180 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_AUTO_SCALER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_AUTO_SCALER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/status.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// Estimates the optimal number of tf.data service workers for an Iteration +// based on the current workload. +// Note: It is assumed that all reported times correspond to the same Iteration. +// +// Glossary: +// * Consumer: A client that consumes elements from tf.data service. +// * Worker: A tf.data service worker. 
+// * Processing time (PT): The estimated time it takes a worker to process and +// produce an element. +// * Target processing time (TPT): From the perspective of a consumer, +// it is the maximum time a tf.data input pipeline can take to produce an +// element such that the downstream processor wait time is 0. In other words, +// this is the ideal time the tf.data pipeline should take to produce an element +// so that training doesn't slow down due to waiting for elements. This means +// that we want processing time <= target processing time, so that when an +// element is requested, the pipeline has processed it already. +// * Worker throughput (WT): It is the multiplicative inverse of processing time +// (1 / PT). This refers to the number of elements produced by a worker per +// second. +// * Consumption rate (CR): It is the multiplicative inverse of target +// processing time (1 / TPT). This refers to the number of elements requested by +// a consumer per second. +// +// **AutoScaler overview** +// +// 1. It keeps track of the most recent worker throughputs reported by each +// worker in the data service cluster, as well as the most recent consumption +// rates reported by each consumer. WTs and CRs are derived from reporting PTs +// and TPTs, respectively. +// 2. Having this information, it estimates the optimal number of workers N as +// follows: +// N = (Sum of CRs reported by all consumers) / +// (Average of WTs reported by all workers) +// +// AutoScaler is thread-safe. +class AutoScaler { + public: + AutoScaler() = default; + // Returns the estimated optimal number of workers according to the current + // observed workload. If there are no previously reported processing and + // target processing times, returns nullopt. + std::optional GetOptimalNumberOfWorkers() const + TF_LOCKS_EXCLUDED(mu_); + // Reports the latest observed processing time from the worker with + // `worker_address`. Returns an error if `processing_time` is ZeroDuration or + // negative. + absl::Status ReportProcessingTime(const std::string& worker_address, + absl::Duration processing_time) + TF_LOCKS_EXCLUDED(mu_); + // Reports the latest observed target processing time from the consumer + // identified by `consumer_id`. Returns an error if `target_processing_time` + // is ZeroDuration or negative. + absl::Status ReportTargetProcessingTime(int64_t consumer_id, + absl::Duration target_processing_time) + TF_LOCKS_EXCLUDED(mu_); + // Unregisters the worker with `worker_address`, removing its reported + // processing time from consideration of the current workload estimation. + // Returns an error if the specified worker does not exist. + absl::Status RemoveWorker(const std::string& worker_address) + TF_LOCKS_EXCLUDED(mu_); + // Unregisters the consumer identified by `consumer_id`, removing its reported + // target processing time from consideration of the current workload + // estimation. Returns an error if the specified consumer does not exist. + absl::Status RemoveConsumer(int64_t consumer_id) TF_LOCKS_EXCLUDED(mu_); + + private: + mutable tsl::mutex mu_; + // Map from worker address to worker throughput. + absl::flat_hash_map worker_throughputs_ + TF_GUARDED_BY(mu_); + // Map from consumer id to consumption rate. + absl::flat_hash_map consumption_rates_ TF_GUARDED_BY(mu_); +}; + +// Exports a metric (/tensorflow/data/service/optimal_number_of_workers) with +// the estimated optimal number of tf.data service workers, according to +// the observed cluster workload. 
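
The estimate described above reduces to simple arithmetic. The standalone sketch below applies the documented formula to concrete numbers; it is an illustration, not the TensorFlow implementation:

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Illustration only: the documented estimate
//   N = (sum of consumption rates) / (average worker throughput),
// with throughput = 1 / processing time and rate = 1 / target processing time.
double EstimateOptimalWorkers(
    const std::map<std::string, double>& processing_time_sec,       // per worker
    const std::map<int64_t, double>& target_processing_time_sec) {  // per consumer
  if (processing_time_sec.empty() || target_processing_time_sec.empty()) return 0;
  double throughput_sum = 0;
  for (const auto& kv : processing_time_sec) throughput_sum += 1.0 / kv.second;
  const double average_throughput = throughput_sum / processing_time_sec.size();
  double consumption_rate_sum = 0;
  for (const auto& kv : target_processing_time_sec) {
    consumption_rate_sum += 1.0 / kv.second;
  }
  return consumption_rate_sum / average_throughput;
}

int main() {
  // Two workers at 0.2 s/element (5 elements/s each, average 5/s); three
  // consumers each requesting an element every 0.1 s (10/s each, 30/s total).
  std::cout << EstimateOptimalWorkers({{"w0", 0.2}, {"w1", 0.2}},
                                      {{0, 0.1}, {1, 0.1}, {2, 0.1}})
            << "\n";  // Prints 6: roughly six such workers keep consumers fed.
  return 0;
}
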
+// +// It estimates the number of workers as the maximum of the estimated optimal +// number of workers for all Iterations running in the tf.data service cluster. +// +// MultipleIterationsAutoScaler is thread-safe. +class MultipleIterationsAutoScaler { + public: + MultipleIterationsAutoScaler() = default; + // Unregisters iteration with `iteration_id`, removing its reported + // times from consideration of the current workload estimation. + // Returns an error if the specified iteration does not exist. + absl::Status UnregisterIteration(int64_t iteration_id) TF_LOCKS_EXCLUDED(mu_); + // Updates the metric value with the current estimated optimal number of + // workers. The estimate is limited to min(4 * `current_number_of_workers`, + // `current_number_of_workers` + 500). Returns an error if there are no + // previously reported processing and target processing times for at least one + // iteration, or `current_number_of_workers` is not positive. + absl::Status UpdateOptimalNumberOfWorkersMetric( + int64_t current_number_of_workers) TF_LOCKS_EXCLUDED(mu_); + // Returns the estimated optimal number of workers according to the current + // observed workload. If there are no previously reported processing and + // target processing times for at least one iteration, returns nullopt. + std::optional GetOptimalNumberOfWorkers() const + TF_LOCKS_EXCLUDED(mu_); + // Reports the latest observed processing time from the worker with + // `worker_address` for iteration with `iteration_id`. Returns an error if + // `processing_time` is ZeroDuration or negative. + absl::Status ReportProcessingTime(int64_t iteration_id, + const std::string& worker_address, + absl::Duration processing_time) + TF_LOCKS_EXCLUDED(mu_); + // Reports the latest observed target processing time from the consumer + // identified by `consumer_id` for iteration with `iteration_id`. Returns an + // error if `target_processing_time` is ZeroDuration or negative. + absl::Status ReportTargetProcessingTime(int64_t iteration_id, + int64_t consumer_id, + absl::Duration target_processing_time) + TF_LOCKS_EXCLUDED(mu_); + // Unregisters the worker with `worker_address` for iteration with + // `iteration_id`, removing its reported processing time from consideration of + // the current workload estimation. Returns an error if there are no + // previously reported processing times for iteration with `iteration_id` and + // the specified worker. + absl::Status RemoveWorker(int64_t iteration_id, + const std::string& worker_address) + TF_LOCKS_EXCLUDED(mu_); + // Unregisters the consumer identified by `consumer_id` for iteration with + // `iteration_id`, removing its reported target processing time from + // consideration of the current workload estimation. Returns an error if there + // are no previously reported processing times for iteration with + // `iteration_id` and the specified consumer. + absl::Status RemoveConsumer(int64_t iteration_id, int64_t consumer_id) + TF_LOCKS_EXCLUDED(mu_); + + private: + // Registers iteration with `iteration_id` if it does not exist already, + // allowing its future reported times to be considered for the current + // workload estimation. + void EnsureIterationIsRegistered(int64_t iteration_id) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + mutable tsl::mutex mu_; + // Map from iteration id to AutoScaler. 
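
The cap on the reported estimate, min(4 * current, current + 500), is easy to misread, so here is a small standalone illustration (not the TensorFlow implementation):

#include <algorithm>
#include <cstdint>

// Illustration only: the metric never reports more than
// min(4 * current_number_of_workers, current_number_of_workers + 500).
int64_t CapWorkerEstimateSketch(int64_t estimate, int64_t current_workers) {
  const int64_t limit = std::min(4 * current_workers, current_workers + 500);
  return std::min(estimate, limit);
}
// e.g. CapWorkerEstimateSketch(1000, /*current_workers=*/100) == 400
//      CapWorkerEstimateSketch(1000, /*current_workers=*/200) == 700
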
+ absl::flat_hash_map> auto_scalers_ + TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_AUTO_SCALER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/byte_size.h b/third_party/tflite-hdrs/tensorflow/core/data/service/byte_size.h new file mode 100644 index 00000000..84d16533 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/byte_size.h @@ -0,0 +1,198 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_BYTE_SIZE_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_BYTE_SIZE_H_ + +#include +#include +#include + +namespace tensorflow { +namespace data { + +// A `ByteSize` represents data space usage measured in bytes. It is constructed +// using Bytes, KB, MB, GB, or TB. Supports common arithmetic operations. Uses +// `size_t` in its internal representation. Thus, it only supports non-negative +// sizes, and the maximum byte size is std::numeric_limits::max(). +// +// Usage example: +// +// constexpr ByteSize kAllocatedMemoryLimit = ByteSize::MB(64); +// +// Tensor data = ... +// ByteSize tensor_size = ByteSize::Bytes(data.AllocatedBytes()); +// if (tensor_size > 0.95 * kAllocatedMemoryLimit) { +// LOG(WARNING) << "Tensor memory usage is " << tensor_size << ". This is " +// << "close to the limit " << kAllocatedMemoryLimit << "."; +// } +class ByteSize final { + public: + // The default is 0 bytes. + constexpr ByteSize() = default; + constexpr ByteSize(const ByteSize&) = default; + ByteSize& operator=(const ByteSize&) = default; + + // Constructs byte sizes of bytes, KB, MB, GB, and TB. + constexpr static ByteSize Bytes(size_t n); + + // In this and following templates, `T` should be a numeric type, + // e.g.: size_t, double, etc. + template + constexpr static ByteSize KB(T n); + + template + constexpr static ByteSize MB(T n); + + template + constexpr static ByteSize GB(T n); + + template + constexpr static ByteSize TB(T n); + + // Compound assignment operators. + ByteSize& operator+=(ByteSize rhs); + + // Does not support negative bytes. If *this < rhs, returns 0 bytes. + ByteSize& operator-=(ByteSize rhs); + + template + ByteSize& operator*=(T rhs); + + template + ByteSize& operator/=(T rhs); + + // Converts the measurement into the specified unit. + size_t ToUnsignedBytes() const; + double ToDoubleBytes() const; + double ToDoubleKB() const; + double ToDoubleMB() const; + double ToDoubleGB() const; + double ToDoubleTB() const; + + // Returns a human-readable string of the byte size. For example, "5KB", + // "1GB", etc. 
+ std::string DebugString() const; + + private: + constexpr explicit ByteSize(double bytes) : bytes_(bytes) {} + + size_t bytes_ = 0; +}; + +constexpr ByteSize ByteSize::Bytes(size_t n) { return ByteSize(n); }; + +template +constexpr ByteSize ByteSize::KB(T n) { + return ByteSize::Bytes(n * (size_t{1} << 10)); +} + +template +constexpr ByteSize ByteSize::MB(T n) { + return ByteSize::Bytes(n * (size_t{1} << 20)); +} + +template +constexpr ByteSize ByteSize::GB(T n) { + return ByteSize::Bytes(n * (size_t{1} << 30)); +} + +template +constexpr ByteSize ByteSize::TB(T n) { + return ByteSize::Bytes(n * (size_t{1} << 40)); +} + +// Compound assignments. +inline ByteSize& ByteSize::operator+=(ByteSize rhs) { + bytes_ += rhs.ToUnsignedBytes(); + return *this; +} + +inline ByteSize& ByteSize::operator-=(ByteSize rhs) { + if (bytes_ < rhs.ToUnsignedBytes()) { + bytes_ = 0; + return *this; + } + bytes_ -= rhs.ToUnsignedBytes(); + return *this; +} + +template +inline ByteSize& ByteSize::operator*=(T rhs) { + bytes_ *= rhs; + return *this; +} + +template +inline ByteSize& ByteSize::operator/=(T rhs) { + bytes_ /= rhs; + return *this; +} + +// Binary arithmetic operators. +inline ByteSize operator+(ByteSize lhs, ByteSize rhs) { + return lhs += rhs; +} + +inline ByteSize operator-(ByteSize lhs, ByteSize rhs) { + return lhs -= rhs; +} + +template +inline ByteSize operator*(ByteSize lhs, T rhs) { return lhs *= rhs; } + +template +inline ByteSize operator*(T lhs, ByteSize rhs) { return rhs *= lhs; } + +template +inline ByteSize operator/(ByteSize lhs, T rhs) { return lhs /= rhs; } + +inline double operator/(ByteSize lhs, ByteSize rhs) { + return lhs.ToDoubleBytes() / rhs.ToDoubleBytes(); +} + +// Comparison operators. +inline bool operator<(ByteSize lhs, ByteSize rhs) { + return lhs.ToUnsignedBytes() < rhs.ToUnsignedBytes(); +} + +inline bool operator>(ByteSize lhs, ByteSize rhs) { + return rhs < lhs; +} + +inline bool operator>=(ByteSize lhs, ByteSize rhs) { + return !(lhs < rhs); +} + +inline bool operator<=(ByteSize lhs, ByteSize rhs) { + return !(rhs < lhs); +} + +inline bool operator==(ByteSize lhs, ByteSize rhs) { + return lhs.ToUnsignedBytes() == rhs.ToUnsignedBytes(); +} + +inline bool operator!=(ByteSize lhs, ByteSize rhs) { + return !(lhs == rhs); +} + +// Output operator, which supports logging with LOG(*). +inline std::ostream& operator<<(std::ostream& os, ByteSize byte_size) { + return os << byte_size.DebugString(); +} + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_BYTE_SIZE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/client/common.h b/third_party/tflite-hdrs/tensorflow/core/data/service/client/common.h new file mode 100644 index 00000000..58c0f0a2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/client/common.h @@ -0,0 +1,50 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
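
A short usage sketch that exercises the ByteSize operators defined above, assuming the byte_size.h header is available on the include path:

#include <iostream>

#include "tensorflow/core/data/service/byte_size.h"  // assumed include path

using tensorflow::data::ByteSize;

int main() {
  constexpr ByteSize kBudget = ByteSize::GB(1);
  const ByteSize used = ByteSize::MB(512) + ByteSize::KB(256);

  // operator/(ByteSize, ByteSize) yields a dimensionless double ratio.
  std::cout << used << " is " << (used / kBudget) * 100 << "% of " << kBudget
            << "\n";

  // Subtraction saturates at zero rather than going negative.
  const ByteSize clamped = ByteSize::MB(1) - ByteSize::GB(1);
  std::cout << clamped.ToUnsignedBytes() << "\n";  // Prints 0.
  return 0;
}
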
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_CLIENT_COMMON_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_CLIENT_COMMON_H_ + +#include +#include +#include + +#include "absl/time/time.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/protobuf/data_service.pb.h" + +namespace tensorflow { +namespace data { + +// tf.data service parameters. +struct DataServiceParams final { + std::string dataset_id; + ProcessingModeDef processing_mode; + std::string address; + std::string protocol; + std::string data_transfer_protocol; + std::string job_name; + int64_t repetition = 0; + std::optional num_consumers; + std::optional consumer_index; + int64_t max_outstanding_requests = 0; + absl::Duration task_refresh_interval; + TargetWorkers target_workers = TargetWorkers::TARGET_WORKERS_UNSPECIFIED; + DataServiceMetadata metadata; + std::optional cross_trainer_cache_options; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_CLIENT_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/client/data_service_client.h b/third_party/tflite-hdrs/tensorflow/core/data/service/client/data_service_client.h new file mode 100644 index 00000000..7c211d55 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/client/data_service_client.h @@ -0,0 +1,274 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_CLIENT_DATA_SERVICE_CLIENT_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_CLIENT_DATA_SERVICE_CLIENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/data/service/client/common.h" +#include "tensorflow/core/data/service/common.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/data/service/dispatcher_client.h" +#include "tensorflow/core/data/service/worker_client.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// Interface for interacting with the tf.data service iterator context. 
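
A hedged configuration sketch for the DataServiceParams struct above, assuming a TensorFlow build; the dataset id, dispatcher address, and job name are hypothetical, and TARGET_WORKERS_AUTO is assumed from the TargetWorkers proto enum:

tensorflow::data::DataServiceParams MakeParamsSketch() {
  tensorflow::data::DataServiceParams params;
  params.dataset_id = "dataset_1";            // hypothetical registered dataset id
  params.address = "localhost:5000";          // hypothetical dispatcher address
  params.protocol = "grpc";
  params.data_transfer_protocol = "grpc";
  params.job_name = "shared_job";             // hypothetical shared job name
  params.max_outstanding_requests = 16;
  params.task_refresh_interval = absl::Seconds(1);
  // Assumed enum value; the struct default is TARGET_WORKERS_UNSPECIFIED.
  params.target_workers = tensorflow::data::TargetWorkers::TARGET_WORKERS_AUTO;
  return params;
}
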
+class DataServiceContext { + public: + virtual ~DataServiceContext() = default; + virtual std::unique_ptr StartThread(const string& name, + std::function fn) = 0; + virtual void RecordBufferEnqueue(const std::vector& element) = 0; + virtual void RecordBufferDequeue(const std::vector& element) = 0; + // Returns the time in nanoseconds a tf.data input pipeline can take to + // produce an element such that the downstream processor wait time is 0. + // Returns 0 if there are not sufficient recorded iterator gap times to + // produce a good estimate, or the tf.data Model instance is null. + virtual double GetTargetProcessingTimeNsec() const = 0; + // Updates the `max_outstanding_requests` with + // `requested_outstanding_requests`. + // Returns the new max outstanding requests which may be different from the + // requested one depending on available ram. + virtual int64_t UpdateMaxOutstandingRequests( + int64_t max_outstanding_requests, + int64_t requested_outstanding_requests) = 0; +}; + +using DataServiceContextFactory = + std::function()>; + +// API for reading data from tf.data service. +// +// The client works by reading from tf.data workers in parallel and interleaving +// the dataset elements. It periodically queries the dispatcher to decide which +// workers to read from (in case workers are added or removed). The data reading +// is non-deterministic. This class is thread-safe. +class DataServiceClient { + public: + explicit DataServiceClient(const DataServiceParams& params); + virtual ~DataServiceClient(); + DataServiceClient(const DataServiceClient&) = delete; + DataServiceClient& operator=(const DataServiceClient&) = delete; + + // Initializes the client. + absl::Status Initialize( + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator); + + // Reads the next element from tf.data workers. Blocks if the next element is + // not ready. + virtual absl::StatusOr GetNext( + DataServiceContextFactory context_factory); + + // Cancels the client. + void Cancel(); + + TraceMeMetadata GetTraceMeMetadata() const; + + private: + struct Task { + Task(const TaskInfo& info, std::unique_ptr worker) + : info(info), worker(std::move(worker)) {} + + const TaskInfo info; + // Client for fetching task elements from the tf.data service worker. + std::unique_ptr worker; + // The next round to read from the task. + int64_t round = 0; + // Whether the task has been removed. The task will eventually be + // deleted from `tasks_` on the next dispatcher heartbeat. + bool removed = false; + bool skipped_previous_round = false; + // Indicates whether a worker thread is currently processing the task. + bool in_use TF_GUARDED_BY(&DataServiceClient::mu_) = false; + // Indicates whether the worker has returned end_of_sequence for the task. + bool end_of_sequence TF_GUARDED_BY(&DataServiceClient::mu_) = false; + // Number of retries. The more it is retried, the longer it should wait + // before the next retry. + int64_t num_retries = 0; + }; + + struct Result { + Result() = default; + Result(Result&&) = default; + Result& operator=(Result&&) = default; + Result(const Result&) = delete; + Result& operator=(const Result&) = delete; + + // Whether the result has been computed yet. GetNext needs to block + // until the next result is ready. + bool ready TF_GUARDED_BY(&DataServiceClient::mu_) = false; + std::vector element TF_GUARDED_BY(&DataServiceClient::mu_); + // The element's index within the tf.data worker it came from. Used for + // debugging. 
+ int64_t element_index TF_GUARDED_BY(&DataServiceClient::mu_) = -1; + // The id of the task that generated the result. + int64_t task_id TF_GUARDED_BY(&DataServiceClient::mu_) = -1; + bool end_of_sequence TF_GUARDED_BY(&DataServiceClient::mu_) = false; + bool skip TF_GUARDED_BY(&DataServiceClient::mu_) = false; + }; + + void EnsureThreadsStarted(); + void CancelThreads(); + // Returns whether the client has finished and should return. + bool Finished() const; + // Returns whether the job has more data. + bool ShouldWaitForNext() const; + void DeleteLocalWorkerTasks(); + bool ShouldDeleteLocalTask(const TaskInfo& task) const; + // Periodically refresh the task list. + // Maintain one thread fetching elements for each task. + // TODO(aaudibert): Instead of polling, have dispatcher send updates when + // the list of tasks changes. + void TaskThreadManager(); + void TryBlockRound(int64_t round) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void UpdateIterationFinished(bool iteration_finished); + absl::Status AddTask(const TaskInfo& task_info); + absl::StatusOr> CreateWorkerClient( + const TaskInfo& task_info); + absl::StatusOr> CreateWorkerClient( + const std::string& protocol, const TaskInfo& task_info); + absl::StatusOr> + CreateGrpcWorkerClient(const TaskInfo& task_info); + absl::StatusOr> + CreateAlternativeWorkerClientMaybeWithGrpcFallback( + const DataTransferServerInfo& transfer_server, const TaskInfo& task_info); + void Heartbeat(); + void UpdateTasks(const ClientHeartbeatResponse& resp); + bool ShouldReadFromTask(const TaskInfo& task) const; + void RecordTFMetrics(const ClientHeartbeatResponse& resp); + void UpdateBufferSize(); + void UpdateWorkerThreads(); + void RunWorkerThread(std::function done); + // Reports whether we can request another element without violating + // `max_outstanding_requests_`. + bool ShouldProcessTask(); + // Searches for a task to process, visiting tasks in-order and giving every + // task a chance to proceed. + std::shared_ptr GetTaskToProcess(); + void AdvanceTaskIndex(); + absl::Status TryGetElement(const Task& task, bool allow_skip, + GetElementResult& result); + void ProcessGetElementResponse(bool enqueue_result, + GetElementResult& get_element_result, + std::shared_ptr result, Task& task); + absl::Status GetElementTraced(Task* task, int64_t deadline_micros, + bool enqueue_result, bool allow_skip, + std::shared_ptr result); + absl::Status MaybeRemoveTask(Task& task, int64_t deadline_micros, + Result& result); + absl::Status GetElement(Task* task, int64_t deadline_micros, + bool enqueue_result, bool allow_skip, + std::shared_ptr result); + bool ResultReady() const; + std::shared_ptr PopNextResult(); + bool IsCoordinatedRead() const; + std::string DebugString() const; + + const DataServiceParams params_; + + mutable mutex mu_; + condition_variable get_next_cv_ TF_GUARDED_BY(mu_); + condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_); + condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_); + + bool cancelled_ TF_GUARDED_BY(mu_) = false; + + // Number of outstanding requests. + int64_t outstanding_requests_ TF_GUARDED_BY(mu_) = 0; + + // max_outstanding_requests controls how many elements may be held in memory + // at the same time. This count includes both in-progress requests for + // elements as well as completed requests which haven't yet been produced. + int64_t max_outstanding_requests_ TF_GUARDED_BY(mu_); + + // The number of threads in `worker_threads_` which are still running. 
+ int64_t num_running_worker_threads_ TF_GUARDED_BY(mu_) = 0; + + // The index of the next task in `tasks_` to read from. + int64_t next_task_index_ TF_GUARDED_BY(mu_) = 0; + + // The number tasks in the `tasks_` list that have reached end_of_sequence. + int64_t finished_tasks_ TF_GUARDED_BY(mu_) = 0; + + // List of tasks to read from. + std::vector> tasks_ TF_GUARDED_BY(mu_); + + // The current round robin round we are engaged in. A round involves reading + // from each task once. + int64_t current_round_ TF_GUARDED_BY(mu_) = 0; + + // Maximum round robin round to read up to before blocking, not inclusive. + // INVARIANT: current_round_ <= round_robin_round_limit_. + // If current_round_ == round_robin_round_limit_, + // next_task_index_ must be 0. + std::optional round_robin_round_limit_ TF_GUARDED_BY(mu_); + + // A status to be returned from the next call to `GetNext`. This is set by + // asynchronous threads when they encounter errors. + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + // A queue of results for `GetElement` requests to read from. When doing + // strict round robin reads, the queue will contain placeholder results with + // their `Result::ready` field false until their data has been retrieved + // from a worker. When not doing round-robin reads, results are only added + // to the queue after they are ready, to avoid head-of-line blocking. + std::queue> results_ TF_GUARDED_BY(mu_); + + bool initialized_ = false; + std::unique_ptr ctx_ TF_GUARDED_BY(mu_); + + // Set once in Initialize(). + int64_t job_id_; + int64_t iteration_client_id_; + std::unique_ptr dispatcher_; + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_; + Allocator* allocator_; + + int64_t get_next_index_ TF_GUARDED_BY(mu_) = 0; + + bool iteration_finished_ TF_GUARDED_BY(mu_) = false; + bool should_finish_iteration_ TF_GUARDED_BY(mu_) = true; + + // The set of worker UIDs that we have already recorded metrics for. + absl::flat_hash_set worker_uids_ TF_GUARDED_BY(mu_); + + std::vector> worker_threads_ TF_GUARDED_BY(mu_); + std::unique_ptr task_thread_manager_ TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_CLIENT_DATA_SERVICE_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/client/utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/client/utils.h new file mode 100644 index 00000000..2d2a0b77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/client/utils.h @@ -0,0 +1,58 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_CLIENT_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_CLIENT_UTILS_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/data_service.pb.h" + +namespace tensorflow { +namespace data { + +// Gets the `DataServiceMetadata` for `dataset_id`. +absl::StatusOr GetDataServiceMetadata( + const std::string& dataset_id, const std::string& address, + const std::string& protocol); + +// Gets the `DisableCompressAtRuntimeResponse.compression_disabled_at_runtime` +// for the given dataset. +absl::StatusOr CompressionDisabledAtRuntime( + const std::string& dataset_id, const std::string& address, + const std::string& protocol, bool disable_compression_at_runtime); + +// Gets the `DataServiceConfig` for the data service running at `address`. +absl::StatusOr GetDataServiceConfig( + const std::string& address, const std::string& protocol); + +// Gets the compression from `metadata`. If `metadata` specifies no valid +// compression, returns an internal error. +absl::StatusOr GetValidatedCompression( + const std::string& dataset_id, const DataServiceMetadata& metadata); + +// Estimates the cardinality of a data service dataset. +int64_t EstimateCardinality(const ProcessingModeDef& processing_mode, + const DataServiceMetadata& metadata, + bool is_coordinated_read); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_CLIENT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/client/validate_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/client/validate_utils.h new file mode 100644 index 00000000..07645004 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/client/validate_utils.h @@ -0,0 +1,32 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_CLIENT_VALIDATE_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_CLIENT_VALIDATE_UTILS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/data/service/client/common.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { + +// Validates data service dataset parameters. +absl::Status ValidateDataServiceParams( + const DataServiceParams& data_service_params); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_CLIENT_VALIDATE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/common.h b/third_party/tflite-hdrs/tensorflow/core/data/service/common.h new file mode 100644 index 00000000..e9760e56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/common.h @@ -0,0 +1,120 @@ +/* Copyright 2021 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/data_service.pb.h" + +namespace tensorflow { +namespace data { + +// Increment this when making backwards-incompatible changes to communication +// between tf.data clients and servers. +constexpr int kDataServiceVersion = 9; + +// If the user starts a colocated tf.data worker on each TF host, the worker +// will be applied a "COLOCATED" tag. This is used to avoid reading from tf.data +// workers on other TF hosts when the host runs a local tf.data service worker. +constexpr absl::string_view kColocatedWorkerTag = "COLOCATED"; + +// Container to hold the result of a `GetNext` call. +struct GetNextResult final { + explicit GetNextResult() = default; + GetNextResult(const GetNextResult&) = delete; + GetNextResult& operator=(const GetNextResult&) = delete; + GetNextResult(GetNextResult&&) = default; + GetNextResult& operator=(GetNextResult&&) = delete; + + static GetNextResult EndOfSequence() { + GetNextResult result; + result.end_of_sequence = true; + return result; + } + + std::vector tensors; + bool end_of_sequence = false; +}; + +// Returns true if `processing_mode` specifies no sharding policy. +bool IsNoShard(const ProcessingModeDef& processing_mode); + +// Returns true if `processing_mode` is dynamic sharding. +bool IsDynamicShard(const ProcessingModeDef& processing_mode); + +// Returns true if `processing_mode` is static sharding. +bool IsStaticShard(const ProcessingModeDef& processing_mode); + +// Returns an internal error if `processing_mode` is invalid. +absl::Status ValidateProcessingMode(const ProcessingModeDef& processing_mode); + +// Converts tf.data service `sharding_policy` to `AutoShardPolicy`. Returns an +// internal error if `sharding_policy` is not supported. +absl::StatusOr ToAutoShardPolicy( + ProcessingModeDef::ShardingPolicy sharding_policy); + +// Parses a string representing a `TargetWorkers` (case-insensitive). +// Returns InvalidArgument if the string is not recognized. +absl::StatusOr ParseTargetWorkers(absl::string_view s); + +// Converts a `TargetWorkers` enum to string. +std::string TargetWorkersToString(TargetWorkers target_workers); + +// Parses a string representing a `DeploymentMode` (case-insensitive). +// Returns InvalidArgument if the string is not recognized. +absl::StatusOr ParseDeploymentMode(absl::string_view s); + +// Returns true if `status` is a retriable error that indicates preemption. 
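
Two hedged sketches using the helpers above, assuming a TensorFlow build; "auto" as an accepted spelling for ParseTargetWorkers is an assumption based on the case-insensitive parsing described in its comment:

// End-of-sequence results carry no tensors; consumers check the flag before
// touching `tensors`.
bool IsDoneSketch() {
  tensorflow::data::GetNextResult result =
      tensorflow::data::GetNextResult::EndOfSequence();
  return result.end_of_sequence;  // true
}

// Round-trips a target-workers spec through parse and to-string.
absl::StatusOr<std::string> DescribeTargetWorkersSketch(absl::string_view spec) {
  // e.g. spec = "auto" (assumed spelling; parsing is case-insensitive).
  TF_ASSIGN_OR_RETURN(tensorflow::data::TargetWorkers target,
                      tensorflow::data::ParseTargetWorkers(spec));
  return tensorflow::data::TargetWorkersToString(target);
}
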
+bool IsPreemptedError(const absl::Status& status); + +// Base class for data service clients. Data service clients are +// threadsafe. +class DataServiceClientBase { + public: + DataServiceClientBase(const std::string& address, const std::string& protocol) + : address_(address), protocol_(protocol) {} + + virtual ~DataServiceClientBase() = default; + // Not copyable or movable. + DataServiceClientBase(const DataServiceClientBase&) = delete; + DataServiceClientBase& operator=(const DataServiceClientBase&) = delete; + + // Initializes the client. Calling `Initialize()` is not required since the + // first RPC will perform any necessary initialization. However, it can be + // useful to call `Initialize()` proactively so that any errors that happen + // during initialization can be surfaced earlier. + virtual absl::Status Initialize() { return EnsureInitialized(); } + + protected: + // Initializes the client if it isn't already initialized. + virtual absl::Status EnsureInitialized() = 0; + + const std::string address_; + const std::string protocol_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/credentials_factory.h b/third_party/tflite-hdrs/tensorflow/core/data/service/credentials_factory.h new file mode 100644 index 00000000..d6a3bff5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/credentials_factory.h @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_ + +#include +#include + +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace data { + +// Credential factory implementations should be threadsafe since all callers +// to `GetCredentials` will get the same instance of `CredentialsFactory`. +class CredentialsFactory { + public: + virtual ~CredentialsFactory() = default; + + // Returns a protocol name for the credentials factory. This is the string to + // look up with `GetCredentials` to find the registered credentials factory. + virtual std::string Protocol() = 0; + + // Stores server credentials to `*out`. + virtual absl::Status CreateServerCredentials( + std::shared_ptr<::grpc::ServerCredentials>* out) = 0; + + // Stores client credentials to `*out`. + virtual absl::Status CreateClientCredentials( + std::shared_ptr<::grpc::ChannelCredentials>* out) = 0; + + // Registers a credentials factory. + static void Register(CredentialsFactory* factory); + + // Creates server credentials using the credentials factory registered as + // `protocol`, and stores them to `*out`. 
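
A hedged sketch of a custom factory built against the interface above, assuming gRPC's insecure credential helpers are available; the "grpc+insecure-demo" protocol name is hypothetical:

class InsecureDemoCredentialsFactory
    : public tensorflow::data::CredentialsFactory {
 public:
  // Hypothetical protocol name clients would use to select this factory.
  std::string Protocol() override { return "grpc+insecure-demo"; }

  absl::Status CreateServerCredentials(
      std::shared_ptr<::grpc::ServerCredentials>* out) override {
    *out = ::grpc::InsecureServerCredentials();
    return absl::OkStatus();
  }

  absl::Status CreateClientCredentials(
      std::shared_ptr<::grpc::ChannelCredentials>* out) override {
    *out = ::grpc::InsecureChannelCredentials();
    return absl::OkStatus();
  }
};

A common registration pattern is a static initializer whose factory outlives all lookups, e.g. tensorflow::data::CredentialsFactory::Register(new InsecureDemoCredentialsFactory());
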
+ static absl::Status CreateServerCredentials( + absl::string_view protocol, + std::shared_ptr<::grpc::ServerCredentials>* out); + + // Creates client credentials using the credentials factory registered as + // `protocol`, and stores them to `*out`. + static absl::Status CreateClientCredentials( + absl::string_view protocol, + std::shared_ptr<::grpc::ChannelCredentials>* out); + + // Returns whether a factory has been registered under the given protocol + // name. + static bool Exists(absl::string_view protocol); + + private: + // Gets the credentials factory registered via `Register` for the specified + // protocol, and stores it to `*out`. + static absl::Status Get(const absl::string_view protocol, + CredentialsFactory** out); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_CREDENTIALS_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/cross_trainer_cache.h b/third_party/tflite-hdrs/tensorflow/core/data/service/cross_trainer_cache.h new file mode 100644 index 00000000..3ef48fe4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/cross_trainer_cache.h @@ -0,0 +1,355 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_CROSS_TRAINER_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/data/service/byte_size.h" +#include "tensorflow/core/framework/metrics.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// Sliding-window cache shared across concurrent trainers. Readers call `Get` to +// read elements they haven't read. After a trainer reads an element, it remains +// in the cache and the data is shared with other trainers. This is useful for +// datasets involving expensive computation, and multiple models use the same +// data for training. For example, for hyperparameter tuning. +// +// The cache progresses when a trainer that has consumed all elements in the +// cache requests additional data. It has a bounded size. Elements are garbage +// collected when the cache becomes full. Consequently, trainers read from a +// sliding window through the dataset and may not read the full dataset. +// +// The `CrossTrainerCache` class is thread-safe. +// +// Example usage: +// +// // `InfiniteRange` returns 1, 2, 3, ... in the `GetNext` calls. 
+// class InfiniteRange : public CachableSequence { +// public: +// StatusOr GetNext() override { +// return next_++; +// } +// +// size_t GetElementSizeBytes(const int64_t& element) const override { +// return sizeof(element); +// } +// +// private: +// int64_t next_ = 1; +// }; +// +// CrossTrainerCache cache( +// /*max_cache_size_bytes=*/10 * (size_t{1} << 30), // 10GB +// std::make_unique()); +// +// std::shared_ptr next; +// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1")); // Returns 1 +// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2")); // Returns 1 +// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 1")); // Returns 2 +// TF_ASSIGN_OR_RETURN(next, cache.Get("Trainer 2")); // Returns 2 + +// To use the cache, the user needs to define a `CachableSequence` to generate +// an infinite sequence of data. It should implement a `GetNext` method to +// produce elements, and a `GetElementSizeBytes` method to estimate the element +// size in bytes. +template +class CachableSequence { + public: + virtual ~CachableSequence() = default; + + // Returns the next element to be cached. + virtual StatusOr GetNext() = 0; + + // Returns the estimated size of the element in bytes. + virtual size_t GetElementSizeBytes(const ElementType&) const = 0; +}; + +// Sliding-window cache shared across concurrent trainers. +template +class CrossTrainerCache { + public: + // Creates a `CrossTrainerCache` with `max_cache_size_bytes` of memory budget. + // The cache should be able to hold at least one element, i.e.: + // REQUIRES: `max_cache_size_bytes >= max(GetElementSizeBytes(*))` + explicit CrossTrainerCache( + size_t max_cache_size_bytes, + std::unique_ptr> cachable_sequence); + virtual ~CrossTrainerCache() = default; + CrossTrainerCache(const CrossTrainerCache&) = delete; + CrossTrainerCache& operator=(const CrossTrainerCache&) = delete; + + // Gets the next element for a trainer. A `trainer_id` identifies the trainer + // reading from the cache. A trainer reads the next element it hasn't read + // before. After a trainer reads data, the data is cached and reused by other + // trainers. + StatusOr> Get( + const std::string& trainer_id); + + // Cancels the cache with `status` and notifies the readers. After cancelling, + // all `Get` calls will return `status`. + // REQUIRES: !status.ok() + void Cancel(absl::Status status); + + // Returns true if the cache has been cancelled. + bool IsCancelled() const; + + private: + struct CacheQueryResult { + std::shared_ptr element; + bool cache_hit; + }; + + // Returns the next element and metrics about this query. + StatusOr GetCacheQueryResult(const std::string& trainer_id); + + // Returns true if element is ready for `trainer_id`. An element is ready if + // other trainers have read the data and the data remains in the cache. If the + // data is not ready, one of the trainers need to extend the cache. + bool IsElementReady(const std::string& trainer_id); + + // Returns the absolute element index relative to the dataset (not relative to + // the cached elements). + size_t GetElementIndex(const std::string& trainer_id); + + // Returns the next element for `trainer_id`. + StatusOr> GetElement( + const std::string& trainer_id); + + // Reads a new element and writes it into the cache. + absl::Status ExtendCache(); + + // Frees old elements to keep the cache size below `max_cache_size_bytes_`. + // `new_element_size_bytes` is the size of the new element being inserted. + void FreeSpace(size_t new_element_size_bytes); + + // Records the cache hit rate and cache size. 
+ void RecordMetrics(const CacheQueryResult& result); + + // Maximum cache size in bytes. + const size_t max_cache_size_bytes_; + + // The element sequence over which the sliding window cache operates. + std::unique_ptr> cachable_sequence_; + + mutable mutex mu_; + mutable condition_variable cv_; + + // If `status_` is non-OK, the cache is cancelled, and all method calls will + // return this status. + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + + // `cache_` stores the cached elements. + std::deque> cache_ TF_GUARDED_BY(mu_); + size_t cache_size_bytes_ TF_GUARDED_BY(mu_) = 0; + size_t cache_start_index_ TF_GUARDED_BY(mu_) = 0; + + // True if one thread is extending the cache. + bool extending_cache_ TF_GUARDED_BY(mu_) = false; + + // Maps trainer IDs to element indices. The indices are absolute indices + // within the dataset. The actual index to use with `cache_` would be + // `trainer_to_element_index_map_[trainer_id] - cache_start_index_`. + absl::flat_hash_map trainer_to_element_index_map_ + TF_GUARDED_BY(mu_); +}; + +template +CrossTrainerCache::CrossTrainerCache( + size_t max_cache_size_bytes, + std::unique_ptr> cachable_sequence) + : max_cache_size_bytes_(max_cache_size_bytes), + cachable_sequence_(std::move(cachable_sequence)) { + DCHECK_GT(max_cache_size_bytes, 0) + << "CrossTrainerCache size must be greater than 0."; + VLOG(2) << "Initialized tf.data service cross-trainer cache with " + << ByteSize::Bytes(max_cache_size_bytes) << " of memory."; +} + +template +StatusOr> +CrossTrainerCache::Get(const std::string& trainer_id) + TF_LOCKS_EXCLUDED(mu_) { + if (trainer_id.empty()) { + return errors::InvalidArgument( + "tf.data service cross-trainer cache requires a non-empty trainer ID."); + } + + TF_ASSIGN_OR_RETURN(CacheQueryResult result, GetCacheQueryResult(trainer_id)); + RecordMetrics(result); + return result.element; +} + +template +StatusOr::CacheQueryResult> +CrossTrainerCache::GetCacheQueryResult( + const std::string& trainer_id) { + bool should_extend_cache = false; + while (true) { + { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(status_); + if (IsElementReady(trainer_id)) { + TF_ASSIGN_OR_RETURN(std::shared_ptr element, + GetElement(trainer_id)); + return CacheQueryResult{element, + /*is_cache_hit=*/!should_extend_cache}; + } + + // Extends the cache or waits for another thread to extend the cache. When + // concurrent trainers wait for the next element, only one of them should + // extend the cache. + if (extending_cache_) { + should_extend_cache = false; + cv_.wait(l); + } else { + should_extend_cache = true; + extending_cache_ = true; + } + } + + if (should_extend_cache) { + absl::Status s = ExtendCache(); + mutex_lock l(mu_); + extending_cache_ = false; + cv_.notify_all(); + TF_RETURN_IF_ERROR(s); + } + } +} + +template +bool CrossTrainerCache::IsElementReady( + const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return GetElementIndex(trainer_id) < cache_start_index_ + cache_.size(); +} + +template +StatusOr> +CrossTrainerCache::GetElement(const std::string& trainer_id) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + size_t element_index = GetElementIndex(trainer_id); + if (element_index >= std::numeric_limits::max()) { + return errors::Internal( + "tf.data service caching element index exceeds integer limit. 
Got ", + element_index); + } + + std::shared_ptr result = + cache_[element_index - cache_start_index_]; + trainer_to_element_index_map_[trainer_id] = element_index + 1; + return result; +} + +template +size_t CrossTrainerCache::GetElementIndex( + const std::string& trainer_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + size_t element_index = trainer_to_element_index_map_[trainer_id]; + if (element_index < cache_start_index_) { + element_index = cache_start_index_; + } + return element_index; +} + +template +absl::Status CrossTrainerCache::ExtendCache() + TF_LOCKS_EXCLUDED(mu_) { + TF_ASSIGN_OR_RETURN(ElementType element, cachable_sequence_->GetNext()); + size_t new_element_size_bytes = + cachable_sequence_->GetElementSizeBytes(element); + if (new_element_size_bytes > max_cache_size_bytes_) { + return errors::InvalidArgument( + "tf.data service element size is larger than cache size in bytes. Got ", + "element size: ", new_element_size_bytes, + " and cache size: ", max_cache_size_bytes_); + } + + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(status_); + FreeSpace(new_element_size_bytes); + cache_.push_back(std::make_shared(std::move(element))); + cache_size_bytes_ += new_element_size_bytes; + return absl::OkStatus(); +} + +template +void CrossTrainerCache::FreeSpace(size_t new_element_size_bytes) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + size_t num_elements_discarded = 0; + while (!cache_.empty() && + cache_size_bytes_ + new_element_size_bytes > max_cache_size_bytes_) { + size_t free_bytes = + cachable_sequence_->GetElementSizeBytes(*cache_.front()); + cache_.pop_front(); + cache_size_bytes_ -= free_bytes; + ++cache_start_index_; + ++num_elements_discarded; + } + + VLOG(3) << "Freed " << num_elements_discarded << " element(s) from " + << "tf.data service cross-trainer cache. Memory usage: " + << ByteSize::Bytes(cache_size_bytes_) << "."; +} + +template +void CrossTrainerCache::Cancel(absl::Status status) + TF_LOCKS_EXCLUDED(mu_) { + DCHECK(!status.ok()) + << "Cancelling CrossTrainerCache requires a non-OK status. Got " + << status; + VLOG(2) << "Cancel tf.data service cross-trainer cache with status " + << status; + mutex_lock l(mu_); + status_ = std::move(status); + cv_.notify_all(); +} + +template +bool CrossTrainerCache::IsCancelled() const + TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return !status_.ok(); +} + +template +void CrossTrainerCache::RecordMetrics( + const CacheQueryResult& result) { + metrics::RecordTFDataServiceCrossTrainerCacheQuery(result.cache_hit); + size_t cache_size_bytes = 0; + { + mutex_lock l(mu_); + cache_size_bytes = cache_size_bytes_; + } + metrics::RecordTFDataServiceCrossTrainerCacheSizeBytes(cache_size_bytes); +} + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_CROSS_CLIENT_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/data_transfer.h b/third_party/tflite-hdrs/tensorflow/core/data/service/data_transfer.h new file mode 100644 index 00000000..23c8247d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/data_transfer.h @@ -0,0 +1,152 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// The result of a GetElement request. Exactly one of the following will be +// true: (1) `components` is nonempty (2) `end_of_sequence` is true (3) `skip` +// is true. +struct GetElementResult { + GetElementResult() = default; + GetElementResult(const GetElementResult&) = delete; + GetElementResult& operator=(const GetElementResult&) = delete; + GetElementResult(GetElementResult&&) = default; + GetElementResult& operator=(GetElementResult&&) = default; + + // Creates a copy of this result. This is used to create multiple copies of + // the same cached value. + GetElementResult Copy() const; + + // Estimated memory used by this object, measured in bytes. + size_t EstimatedMemoryUsageBytes() const; + + // A dataset element produced by a GetElement request. + std::vector components; + // The element's index within the task it came from. + int64_t element_index = 0; + // If true, indicates that there is no more data to read. + bool end_of_sequence = false; + // If true, indicates that there is still data, but the caller should skip + // reading from the worker. This is used for load balancing when doing round + // robin reads. + bool skip = false; +}; + +// Client for communicating with the tf.data service transfer server. +class DataTransferClient { + public: + struct Config { + absl::string_view protocol; + std::string address; + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info; + Allocator* allocator; + }; + using ClientFactoryT = + std::function*)>; + virtual ~DataTransferClient() = default; + + // Fetches the next element. + virtual absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result) = 0; + + // Makes a best effort to cancel all outstanding calls in progress for the + // client, and causes further calls to return Cancelled status. + virtual void TryCancel() = 0; + + // Registers a DataTransferClient factory under `name`. + static void Register(std::string name, ClientFactoryT factory); + + // Builds a DataTransferClient from the factory registered under `name`. + static absl::Status Build(std::string name, Config config, + std::unique_ptr* out); + + // Returns a string describing properties of the client relevant for checking + // compatibility with a server for a given protocol. 
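+  // For example, a caller might check a client against a server roughly as
+  // follows (illustrative sketch, not part of the upstream comment; `server`
+  // and `client` are placeholder pointers to a DataTransferServer and a
+  // DataTransferClient):
+  //
+  //   TF_ASSIGN_OR_RETURN(std::string info, server->GetCompatibilityInfo());
+  //   TF_RETURN_IF_ERROR(client->CheckCompatibility(info));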
+ virtual absl::StatusOr GetCompatibilityInfo() const { + return std::string(); + } + + // Returns an error if the client is incompatible with a server which has the + // properties described in `server_compatibility_info`. + virtual absl::Status CheckCompatibility( + const std::string& server_compatibility_info) const { + return absl::OkStatus(); + } + + protected: + Env* const env_ = Env::Default(); +}; + +// Server for communicating with the tf.data service transfer client. +class DataTransferServer { + public: + using GetElementT = + std::function; + using ServerFactoryT = std::function*)>; + virtual ~DataTransferServer() = default; + + // Starts DataTransferServer, it should be available for requests afterwards. + virtual absl::Status Start(const experimental::WorkerConfig& config) = 0; + + // Return the port that this server is listening on. + virtual int Port() const = 0; + + // Register a DataTransferServer factory under `name`. + static void Register(std::string name, ServerFactoryT factory); + + // Builds a DataTransferServer from the factory registered with `name`. + static absl::Status Build(std::string name, GetElementT get_element, + std::shared_ptr* out); + + // Returns a string describing properties of the server relevant for checking + // compatibility with a client for a given protocol. + virtual absl::StatusOr GetCompatibilityInfo() const { + return std::string(); + } + + // If `true`, data service clients should fall back to gRPC for this server if + // they fail to create a data transfer client for it. + virtual bool FallBackToGrpcAtClientCreationTime() const { return true; } + + // If `true`, data service clients should fall back to gRPC for this server if + // it nonretryably fails to transfer an element. + virtual bool FallBackToGrpcAtGetElementTime() const { return true; } +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/dataset_store.h b/third_party/tflite-hdrs/tensorflow/core/data/service/dataset_store.h new file mode 100644 index 00000000..f79120bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/dataset_store.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_DATASET_STORE_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_DATASET_STORE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/data/service/dispatcher_state.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +namespace data { + +// An interface for storing and getting dataset definitions. 
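+// For example, a caller might persist and later re-load a dataset definition
+// roughly as follows (illustrative sketch; the directory, the key, and the
+// `dataset_def` value are placeholders):
+//
+//   FileSystemDatasetStore store("/tmp/tf_data_datasets");
+//   TF_RETURN_IF_ERROR(store.Put("dataset_1000", dataset_def));
+//
+//   std::shared_ptr<const DatasetDef> restored;
+//   TF_RETURN_IF_ERROR(store.Get("dataset_1000", restored));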
+class DatasetStore { + public: + virtual ~DatasetStore() = default; + + // Stores the given dataset under the given key. Overwrites a dataset if it + // already exists. + virtual absl::Status Put(const std::string& key, + const DatasetDef& dataset) = 0; + // Gets the dataset for the given key, storing the dataset in `dataset_def`. + virtual absl::Status Get(const std::string& key, + std::shared_ptr& dataset_def) = 0; +}; + +// Dataset store which reads and writes datasets within a directory. +// The dataset with key `key` is stored at the path "datasets_dir/key". +class FileSystemDatasetStore : public DatasetStore { + public: + explicit FileSystemDatasetStore(const std::string& datasets_dir); + FileSystemDatasetStore(const FileSystemDatasetStore&) = delete; + FileSystemDatasetStore& operator=(const FileSystemDatasetStore&) = delete; + + absl::Status Put(const std::string& key, const DatasetDef& dataset) override; + absl::Status Get(const std::string& key, + std::shared_ptr& dataset_def) override; + + private: + const std::string datasets_dir_; +}; + +// DatasetStore which stores all datasets in memory. This is useful when the +// dispatcher doesn't have a work directory configured. +class MemoryDatasetStore : public DatasetStore { + public: + MemoryDatasetStore() = default; + MemoryDatasetStore(const MemoryDatasetStore&) = delete; + MemoryDatasetStore& operator=(const MemoryDatasetStore&) = delete; + + absl::Status Put(const std::string& key, const DatasetDef& dataset) override; + absl::Status Get(const std::string& key, + std::shared_ptr& dataset_def) override; + + private: + // Mapping from key to dataset definition. + absl::flat_hash_map> datasets_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_DATASET_STORE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_client.h b/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_client.h new file mode 100644 index 00000000..253d8ec0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_client.h @@ -0,0 +1,153 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/data/service/common.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/dispatcher.grpc.pb.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/data_service.pb.h" +#include "tensorflow/core/protobuf/service_config.pb.h" +#include "tensorflow/core/protobuf/snapshot.pb.h" + +namespace tensorflow { +namespace data { + +// Client for communicating with the tf.data service dispatcher. +class DataServiceDispatcherClient : public DataServiceClientBase { + public: + DataServiceDispatcherClient(const std::string& address, + const std::string& protocol) + : DataServiceClientBase(address, protocol) {} + + absl::Status Initialize() override; + + // Sends a heartbeat to the dispatcher. If the worker wasn't already + // registered with the dispatcher, this will register the worker. The + // dispatcher will report which new tasks the worker should run, and which + // tasks it should delete. + absl::StatusOr WorkerHeartbeat( + const WorkerHeartbeatRequest& request); + + // Updates the dispatcher with information about the worker's state. + absl::Status WorkerUpdate(const std::string& worker_address, + std::vector& task_progress); + + // Gets a dataset definition for the given dataset id, and stores the + // definition in `dataset_def`. + absl::Status GetDatasetDef(const std::string& dataset_id, + DatasetDef& dataset_def); + + // Gets the next split for the specified iteration id, repetition, and split + // provider index. + absl::Status GetSplit(int64_t iteration_id, int64_t repetition, + int64_t split_provider_index, Tensor& split, + bool& end_of_splits); + + // Gets the next split for the specified source of a stream of the snapshot in + // `base_path`. If `end_of_splits` returns true, then there are no more splits + // to be processed for the specified stream source. + virtual absl::Status GetSnapshotSplit( + const std::string& worker_address, const std::string& base_path, + int64_t stream_index, int64_t source_index, int64_t repetition_index, + Tensor& split, int64_t& local_split_index, bool& end_of_splits); + + // Initiates the process of materializing `dataset`'s output to `path`. + absl::Status Snapshot( + const DatasetDef& dataset, const std::string& path, + const experimental::DistributedSnapshotMetadata& metadata); + + // Registers a dataset with the tf.data service, and stores the generated + // dataset id in `dataset_id`. + absl::Status RegisterDataset( + const DatasetDef& dataset, const DataServiceMetadata& metadata, + const std::optional& requested_dataset_id, + std::string& dataset_id); + + // If `job_name` is set, looks up a job matching `job_name`. + // If `job_name` is absent or no matching job is found, creates a + // new job. The resulting job id is stored in `job_id`. 
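+  //
+  // A training client typically drives the dispatcher roughly in this order
+  // (illustrative sketch; the address is a placeholder, and `dataset_def`,
+  // `metadata`, `processing_mode`, and `target_workers` are assumed to be
+  // defined elsewhere):
+  //
+  //   DataServiceDispatcherClient client("localhost:5000", "grpc");
+  //   std::string dataset_id;
+  //   TF_RETURN_IF_ERROR(client.RegisterDataset(
+  //       dataset_def, metadata, /*requested_dataset_id=*/std::nullopt,
+  //       dataset_id));
+  //
+  //   int64_t job_id = 0;
+  //   TF_RETURN_IF_ERROR(client.GetOrCreateJob(
+  //       dataset_id, processing_mode, /*job_name=*/std::nullopt,
+  //       /*num_consumers=*/std::nullopt, /*use_cross_trainer_cache=*/false,
+  //       target_workers, job_id));
+  //
+  //   int64_t iteration_client_id = 0;
+  //   TF_RETURN_IF_ERROR(client.GetOrCreateIteration(
+  //       job_id, /*repetition=*/0, iteration_client_id));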
+ absl::Status GetOrCreateJob(const std::string& dataset_id, + const ProcessingModeDef& processing_mode, + const std::optional& job_name, + std::optional num_consumers, + bool use_cross_trainer_cache, + TargetWorkers target_workers, int64_t& job_id); + + // Looks up an iteration of a job, creating an iteration if one doesn't + // already exist. The returned `iteration_client_id` can be used to query + // information about the iteration. The client should call + // `ReleaseIterationClient` when finished with the iteration, so that + // resources can be reclaimed. + absl::Status GetOrCreateIteration(int64_t job_id, int64_t repetition, + int64_t& iteration_client_id); + + // Releases a iteration client id, indicating that the id will no longer be + // used to read from the iteration. + absl::Status ReleaseIterationClient(int64_t iteration_client_id); + + // Attempts to remove a task. The task is removed if all consumers try to + // remove the task in the same round. + absl::Status MaybeRemoveTask(int64_t task_id, int64_t consumer_index, + int64_t round, bool& removed); + + // Heartbeats to the dispatcher, getting back the tasks that should be + // running, and whether the iteration is finished. + absl::Status ClientHeartbeat(ClientHeartbeatRequest& req, + ClientHeartbeatResponse& resp); + + // Queries the dispatcher for its registered workers. The worker info will be + // stored in `workers`. + absl::Status GetWorkers(std::vector& workers); + + // Returns data service metadata for the registered dataset. + absl::Status GetDataServiceMetadata(const std::string& dataset_id, + DataServiceMetadata& metadata); + + // Returns data service config of the data service cluster. + absl::Status GetDataServiceConfig(DataServiceConfig& config); + + // Returns information about the decision to disable compression at runtime + // for a given dataset. + absl::Status DisableCompressionAtRuntime( + const std::string& dataset_id, bool disable_compression_at_runtime, + DisableCompressionAtRuntimeResponse& response); + + protected: + absl::Status EnsureInitialized() override; + + private: + mutex mu_; + // Initialization is guarded by `mu_`, but using the stub does not require + // holding `mu_` + std::unique_ptr stub_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_impl.h b/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_impl.h new file mode 100644 index 00000000..6fa299dc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_impl.h @@ -0,0 +1,412 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "tensorflow/core/data/service/auto_scaler.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/dataset_store.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/data/service/dispatcher_state.h" +#include "tensorflow/core/data/service/export.pb.h" +#include "tensorflow/core/data/service/snapshot/snapshot_manager.h" +#include "tensorflow/core/data/service/task_remover.h" +#include "tensorflow/core/data/service/worker.grpc.pb.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/data_service.pb.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// A service which coordinates a pool of workers to serve dataset elements over +// RPC. +// +// Glossary: +// * Dataset: A definition of how to generate a potentially large collection of +// elements. +// * Iteration: A coordinated phase of reading from the tf.data service. An +// iteration produces some amount of data, and (potentially multiple) +// consumers consume the data from the iteration until there is no data left. +// Each iteration has a ProcessingModeDef which determines what data it +// produces. +// * Task: An iteration is broken into multiple tasks, which each represent +// iterating over all of or part of the dataset. Workers process tasks. +// * Consumer: A process reading from the tf.data service. +// +// **Adding workers** +// +// tf.data service supports adding workers mid-iteration. When a new worker +// connects to the dispatcher, the dispatcher creates a new task for the worker, +// one task for each outstanding iteration. Consumers periodically heartbeat to +// the dispatcher to learn about new tasks. +// +// For non-round-robin-reads, there is no coordination among consumers. Each +// consumer will start reading from the new task as soon as it learns about the +// task from its heartbeat. Round robin reads, on the other hand, require +// consumers to read from the same task at each step. This requires coordination +// to ensure that all consumers start reading from the new task in the same +// round. +// +// The protocol for adding round robin tasks works as follows: +// +// - The dispatcher keeps track of which round each round-robin iteration is on. +// This +// information is reported by consumers in their heartbeats. +// - When a new worker joins and there is an outstanding round-robin iteration, +// we create a new task for the iteration and assign it to the worker. +// However, we don't yet report the task in consumer heartbeats. +// We call the task a "pending task" and add it to its iteration's "pending +// tasks" queue. +// - When we create a pending task, we choose a "target round" to try adding +// the task to. 
The target round is chosen by adding a "target round delta" to +// the latest reported round for the iteration. +// - When a consumer heartbeats for an iteration and there is a pending task for +// that iteration, the dispatcher sends a heartbeat response telling the +// consumer to block before reading from the target round. +// - When a consumer receives a heartbeat response telling it to block +// (before reading) a round, the consumer try to block the round. If the +// consumer has already started the round, it will too late to block the +// round. +// - When consumers heartbeat, they tell the dispatcher their current round and +// whether they have blocked themselves from reading past a certain round. If +// a consumer reports a current round exceeding the target round, the target +// round has failed and needs to be increased. We choose a new target round by +// doubling the previous target round delta. If the consumer reports that it +// has blocked before the target round, we record that the consumer is ready +// to add the new task. Once all consumers are ready to add the new task, we +// remove the task from the pending tasks list and begin reporting the task to +// consumers. We set the "starting_round" field of the task to indicate the +// target round where all consumers should start reading from the task. +// - If a new worker joins while there are already pending tasks, a pending +// task for the new worker is created and queued behind the existing tasks. +// The new task won't be considered until all previous pending tasks have been +// successfully added. +// +// An example of executing this protocol with two consumers could go as follows: +// 1. Consumers read up to round 50 and heartbeat that they are on round 50. +// 2. A new worker joins. Dispatcher chooses round 51 as the target round. +// 3. Consumer 1 heartbeats that its current round is 50. Dispatcher tells it to +// block round 51. +// 4. Consumer 2 heartbeats that its current round is 51. Dispatcher realizes +// that it is too late to block round 51 and chooses round 53 as the new +// target round. Dispatcher tells consumer 2 to block round 53. +// 5. Consumer 1 heartbeats that its current round is 50 and that it has blocked +// round 51. Dispatcher tells it to block round 53 instead. Dispatcher +// records that consumer 1 is ready to add a task in round 53. +// 6. Consumer 2 heartbeats that its current round is 52 and it has blocked +// round 53. Dispatcher realizes that all consumers are blocked on round 53 +// or earlier and promotes the task from pending to regular. Dispatcher sends +// consumer 2 a task list containing the new task, and tells consumer 2 that +// it no longer needs to block. +// 7. Consumer 1 heartbeats. Dispatcher sends consumer 1 the task list +// containing the new task, and tells it that it no longer needs to block. +// +class DataServiceDispatcherImpl { + public: + explicit DataServiceDispatcherImpl( + const experimental::DispatcherConfig& config); + + ~DataServiceDispatcherImpl(); + + // Starts the dispatcher. If there is a journal, this will read from the + // journal to restore the dispatcher's state. + absl::Status Start(); + + // Stops the dispatcher. After stopping, RPCs should return without blocking. + void Stop(); + + // Returns the number of active iterations. + size_t NumActiveIterations() TF_LOCKS_EXCLUDED(mu_); + + // See dispatcher.proto for API documentation. + + /// Worker-facing API. 
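+  //
+  // For example, an in-process caller (such as a test) might exercise this
+  // API roughly as follows (illustrative sketch; the proto field name
+  // `worker_address` is assumed from dispatcher.proto, and the address is a
+  // placeholder):
+  //
+  //   experimental::DispatcherConfig config;
+  //   DataServiceDispatcherImpl dispatcher(config);
+  //   TF_RETURN_IF_ERROR(dispatcher.Start());
+  //
+  //   WorkerHeartbeatRequest request;
+  //   request.set_worker_address("localhost:6000");
+  //   WorkerHeartbeatResponse response;
+  //   TF_RETURN_IF_ERROR(dispatcher.WorkerHeartbeat(&request, &response));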
+ absl::Status WorkerHeartbeat(const WorkerHeartbeatRequest* request, + WorkerHeartbeatResponse* response); + absl::Status WorkerUpdate(const WorkerUpdateRequest* request, + WorkerUpdateResponse* response); + absl::Status GetDatasetDef(const GetDatasetDefRequest* request, + GetDatasetDefResponse* response); + absl::Status GetSplit(const GetSplitRequest* request, + GetSplitResponse* response); + + /// Client-facing API. + absl::Status GetVersion(const GetVersionRequest* request, + GetVersionResponse* response); + absl::Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request, + GetOrRegisterDatasetResponse* response); + absl::Status GetDataServiceMetadata( + const GetDataServiceMetadataRequest* request, + GetDataServiceMetadataResponse* response); + absl::Status GetDataServiceConfig(const GetDataServiceConfigRequest* request, + GetDataServiceConfigResponse* response); + absl::Status GetOrCreateJob(const GetOrCreateJobRequest* request, + GetOrCreateJobResponse* response); + absl::Status GetOrCreateIteration(const GetOrCreateIterationRequest* request, + GetOrCreateIterationResponse* response); + absl::Status ReleaseIterationClient( + const ReleaseIterationClientRequest* request, + ReleaseIterationClientResponse* response); + absl::Status MaybeRemoveTask(const MaybeRemoveTaskRequest* request, + MaybeRemoveTaskResponse* response); + absl::Status ClientHeartbeat(const ClientHeartbeatRequest* request, + ClientHeartbeatResponse* response); + absl::Status GetWorkers(const GetWorkersRequest* request, + GetWorkersResponse* response); + absl::Status Snapshot(const SnapshotRequest* request, + SnapshotResponse* response); + absl::Status GetSnapshotSplit(const GetSnapshotSplitRequest* request, + GetSnapshotSplitResponse* response); + absl::Status GetSnapshotStreams(const GetSnapshotStreamsRequest* request, + GetSnapshotStreamsResponse* response); + absl::Status DisableCompressionAtRuntime( + const DisableCompressionAtRuntimeRequest* request, + DisableCompressionAtRuntimeResponse* response); + + // Exports the dispatcher state for debugging. + DispatcherStateExport ExportState() const; + + private: + // A thread which periodically checks for iterations to clean up, clients to + // release, workers to consider missing, and snapshot streams to reassign. + void MaintenanceThread(); + + // Restores split providers from the state in `iteration` and stores them in + // `restored`. + absl::Status RestoreSplitProviders( + const DispatcherState::Iteration& iteration, + std::vector>& restored) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Makes split providers for the specified `dataset_id`, and stores them in + // `split_providers`. + absl::Status MakeSplitProviders( + const std::string& dataset_id, + std::vector>& split_providers) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Registers a dataset, storing the new dataset's id in `dataset_id`. + absl::Status RegisterDataset(const DatasetDef& dataset, + const DataServiceMetadata& metadata, + const std::string& requested_dataset_id, + std::string& dataset_id) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Finds the dataset ID with the requested dataset ID. + // Returns nullptr if no such dataset exists. + absl::StatusOr> FindDataset( + const GetOrRegisterDatasetRequest& request); + // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a + // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is + // stored in `out_stub`. 
+ absl::Status GetOrCreateWorkerStub(const std::string& worker_address, + WorkerService::Stub*& out_stub) + TF_LOCKS_EXCLUDED(mu_); + // Creates a job and stores it in `job`. + absl::Status CreateJob(const std::string& job_name, + const GetOrCreateJobRequest& request, + std::shared_ptr& job) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Creates an iteration and stores it in `iteration`. This method updates the + // dispatcher state with the new iteration, but does not assign tasks to + // workers. + absl::Status CreateIteration( + const GetOrCreateIterationRequest& request, + std::shared_ptr& iteration) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Creates tasks for the specified worker, one task for every unfinished + // iteration. + absl::Status CreateTasksForWorker(const std::string& worker_address); + // Finds tasks that should be deleted from a worker, updating the heartbeat + // response. + absl::Status FindTasksToDelete( + const absl::flat_hash_set& current_tasks, + const std::vector>& + assigned_tasks, + WorkerHeartbeatResponse* response); + // Finds new tasks that should be assigned to a worker and adds them to + // the heartbeat response. + absl::Status FindNewTasks( + const std::string& worker_address, + const absl::flat_hash_set& current_tasks, + std::vector>& assigned_tasks, + WorkerHeartbeatResponse* response); + // Reports the processing time of each active task to `auto_scaler_`. + void ReportProcessingTimesFromActiveTasks( + const std::vector& active_tasks, + const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Acquires an iteration client id to read from the given iteration and sets + // `iteration_client_id`. + absl::Status AcquireIterationClientId( + const std::shared_ptr& iteration, + int64_t& iteration_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Creates one task for each worker, for the given iteration. The created + // tasks are stored in `tasks`. This method only updates dispatcher metadata + // with the new tasks, but doesn't assign the tasks to the workers. + absl::Status CreateTasksForIteration( + std::shared_ptr iteration, + std::vector>& tasks) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Creates a new task for an iteration. The created task may be either + // pending or active. + absl::Status CreateTask( + std::shared_ptr iteration, + const std::string& worker_address, + std::shared_ptr& task) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Creates a pending task for a round robin iteration. All consumers need to + // agree on which round to add the task in before the pending task can be + // promoted to a regular task. + absl::Status CreatePendingTask( + std::shared_ptr iteration, + const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Creates a new active task for an iteration, storing the created task in + // `task`. + absl::Status CreateActiveTask( + std::shared_ptr iteration, + const std::string& worker_address, + std::shared_ptr& task); + // Assigns the list of tasks to the workers indicated by their + // `worker_address` fields. + absl::Status AssignTasks( + std::vector> tasks) + TF_LOCKS_EXCLUDED(mu_); + // Assigns a task to the worker indicated by its `worker_address` field. + absl::Status AssignTask(std::shared_ptr task) + TF_LOCKS_EXCLUDED(mu_); + // Validates that an existing job matches a given request. + // Returns an error status describing any difference. 
+ absl::Status ValidateMatchingJob( + std::shared_ptr job, + const GetOrCreateJobRequest& request) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Fills out a TaskDef with information about a task. + absl::Status PopulateTaskDef( + std::shared_ptr task, + TaskDef* task_def) const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Checks that the dispatcher has started, returning UNAVAILABLE if it hasn't. + absl::Status CheckStarted() TF_LOCKS_EXCLUDED(mu_); + // Restores ongoing tf.data snapshots. + absl::Status RestoreSnapshots(); + // Records that a split was produced by a call to `GetSplit`. + absl::Status RecordSplitProduced(int64_t iteration_id, int64_t repetition, + int64_t split_provider_index, bool finished) + TF_LOCKS_EXCLUDED(mu_); + // Applies a state update, updating both the journal and the in-memory state. + absl::Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Applies a state update, but doesn't update the journal. Only meant to be + // used when recovering state when the dispatcher starts. + absl::Status ApplyWithoutJournaling(const Update& update) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Removes the client with `client_id` from `auto_scaler_` + void RemoveClientFromAutoScaler(int64_t client_id) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Releases iteration clients that haven't heartbeated recently. + absl::Status ReleaseMissingClients() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Removes the worker with `worker_address` from `auto_scaler_`, which is + // potentially associated with multiple iterations. + void RemoveWorkerFromAutoScaler(const std::string& worker_address) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Checks for workers that haven't heartbeated recently and alerts the + // snapshot managers. + void DetectMissingWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Scans for old iterations and marks them as finished. + absl::Status GcOldIterations() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Returns true if an iteration should be garbage collected. + bool ShouldGcIteration(const DispatcherState::Iteration& iteration, + int64_t now_us) const; + // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and + // stores it in `dataset_def`. + absl::Status GetDatasetDef(const std::string& dataset_id, + std::shared_ptr& dataset_def) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and + // stores it in `dataset_def`. + absl::Status GetDatasetDef(const DispatcherState::Dataset& dataset, + std::shared_ptr& dataset_def) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + const experimental::DispatcherConfig config_; + Env* env_; + + mutable mutex mu_; + // Uses a separate mutex for `GetSplit` requests. `GetSplit` may be blocking. + // Locking `mu_` in `GetSplit` could block all other RPCs. + mutable mutex get_split_mu_; + bool started_ TF_GUARDED_BY(mu_) = false; + bool cancelled_ TF_GUARDED_BY(mu_) = false; + + // Cached worker stubs for communicating with workers. + absl::flat_hash_map> + worker_stubs_ TF_GUARDED_BY(mu_); + // Store of dataset definitions. + std::unique_ptr dataset_store_ TF_GUARDED_BY(mu_); + // Mapping from iteration id to the split providers for the iteration. + absl::flat_hash_map>> + split_providers_ TF_GUARDED_BY(mu_); + // Mapping from round robin iteration id to the round the iteration is + // currently on. This is based on the data provided by client heartbeats, + // and may be stale. 
+ absl::flat_hash_map round_robin_rounds_ TF_GUARDED_BY(mu_); + // Map from task id to a TaskRemover which determines when to remove the task. + absl::flat_hash_map> + remove_task_requests_ TF_GUARDED_BY(mu_); + // Map from client id to the time of the client's last heartbeat. + absl::flat_hash_map latest_client_heartbeats_time_ + TF_GUARDED_BY(mu_); + // Map from worker address to the time of the worker's last heartbeat. + absl::flat_hash_map latest_worker_heartbeats_time_ + TF_GUARDED_BY(mu_); + + // A manager for each snapshot resumed or started during the lifetime of this + // dispatcher instance. Note that these are *not* garbage collected; managers + // for completed snapshots will remain here for the lifetime of the dispatcher + // instance. They will even be recovered if the dispatcher is restarted. + absl::flat_hash_map> snapshots_ + TF_GUARDED_BY(mu_); + // A single stream assignment manager shared by all managers in `snapshots_`. + SnapshotAssignmentManager snapshot_assignment_manager_; + + std::optional> journal_writer_ + TF_GUARDED_BY(mu_); + DispatcherState state_ TF_GUARDED_BY(mu_); + // Condition variable for waking up the gc thread. + condition_variable maintenance_thread_cv_; + std::unique_ptr maintenance_thread_; + MultipleIterationsAutoScaler auto_scaler_; + + DataServiceDispatcherImpl(const DataServiceDispatcherImpl&) = delete; + void operator=(const DataServiceDispatcherImpl&) = delete; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_state.h b/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_state.h new file mode 100644 index 00000000..054c3203 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/dispatcher_state.h @@ -0,0 +1,381 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/data/service/common.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/graph_rewriters.h" +#include "tensorflow/core/data/service/journal.h" +#include "tensorflow/core/data/service/journal.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/data_service.pb.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// A class encapsulating the journaled state of the dispatcher. All state +// modifications must be done via `Apply`. This helps to ensure that +// replaying the journal will allow us to restore the exact same state. 
+// +// The following usage pattern will keep the journal in sync with the state of +// the dispatcher: +// { +// mutex_lock l(mu_); +// Update update = ... // create an update +// dispatcher_state.Apply(update); +// journal_writer.write(Update); +// // Unlock mu_ +// } +// +// The division of functionality between DispatcherImpl and DispatcherState is +// as follows: +// - DispatcherImpl is responsible for handling RPC requests, reading from +// DispatcherState, and deciding what updates to apply to DispatcherState. +// DispatcherImpl handles all synchronization. +// - DispatcherState is responsible for making the state changes requested by +// DispatcherImpl and for providing DispatcherImpl with read-only access to +// the state. +// +// DispatcherState is thread-compatible but not thread-safe. +class DispatcherState { + public: + DispatcherState(); + explicit DispatcherState( + const experimental::DispatcherConfig& dispatcher_config); + DispatcherState(const DispatcherState&) = delete; + DispatcherState& operator=(const DispatcherState&) = delete; + + // Applies the given update to the dispatcher's state. + absl::Status Apply(const Update& update); + + // A dataset registered with the dispatcher. + struct Dataset { + explicit Dataset(const std::string& dataset_id, + const DataServiceMetadata& metadata) + : dataset_id(dataset_id), metadata(metadata) {} + + const std::string dataset_id; + const DataServiceMetadata metadata; + }; + + // A worker registered with the dispatcher. + struct Worker { + explicit Worker(const RegisterWorkerUpdate& register_worker) + : address(register_worker.worker_address()), + transfer_servers({register_worker.transfer_servers().begin(), + register_worker.transfer_servers().end()}), + tags(register_worker.worker_tags().begin(), + register_worker.worker_tags().end()), + uid(register_worker.worker_uid()) {} + + const std::string address; + const std::vector transfer_servers; + const std::vector tags; + const int64_t uid; + }; + + // A key for identifying an iteration. The key contains a job name, + // as well as a repetition number describing which repetition of the job + // we are on. + struct IterationKey { + explicit IterationKey(absl::string_view name, int64_t repetition) + : name(name), repetition(repetition) {} + + friend bool operator==(const IterationKey& lhs, const IterationKey& rhs) { + return lhs.name == rhs.name && lhs.repetition == rhs.repetition; + } + + template + friend H AbslHashValue(H h, const IterationKey& k) { + return H::combine(std::move(h), k.name, k.repetition); + } + + std::string DebugString() const { + return absl::StrCat(name, "/", repetition); + } + + const std::string name; + const int64_t repetition; + }; + + struct DistributedEpochState { + explicit DistributedEpochState(int64_t num_split_providers) + : repetitions(num_split_providers), indices(num_split_providers) {} + + // The current repetition for each split provider. + std::vector repetitions; + // Number of splits produced so far by each split provider. + std::vector indices; + }; + + struct Task; + + struct PendingTask { + explicit PendingTask(std::shared_ptr task, int64_t target_round) + : task(std::move(task)), target_round(target_round) {} + + std::shared_ptr task; + // The target round where we want to insert the task. + int64_t target_round; + // Which consumers have responded that they have successfully blocked + // before the target round. + absl::flat_hash_set ready_consumers; + // How many times we have failed to add the task. 
+ int64_t failures = 0; + }; + + struct Job { + explicit Job(int64_t id, const std::string& dataset_id, + const ProcessingModeDef& processing_mode, std::string job_name, + std::optional num_consumers, + bool use_cross_trainer_cache, TargetWorkers target_workers) + : id(id), + dataset_id(dataset_id), + processing_mode(processing_mode), + job_name(job_name), + num_consumers(num_consumers), + use_cross_trainer_cache(use_cross_trainer_cache), + target_workers(target_workers) {} + + const int64_t id; + const std::string dataset_id; + const ProcessingModeDef processing_mode; + const std::string job_name; + const std::optional num_consumers; + const bool use_cross_trainer_cache; + const TargetWorkers target_workers; + }; + + // An iteration for processing a dataset. + struct Iteration { + explicit Iteration(int64_t iteration_id, IterationKey iteration_key, + int64_t num_split_providers, std::shared_ptr job) + : iteration_id(iteration_id), iteration_key(iteration_key), job(job) { + if (IsDynamicShard(job->processing_mode)) { + distributed_epoch_state = DistributedEpochState(num_split_providers); + } + } + + bool IsRoundRobin() const { return job->num_consumers.has_value(); } + + std::string DebugString() const { + return absl::StrCat(iteration_key.name, "_", iteration_key.repetition); + } + + const int64_t iteration_id; + const IterationKey iteration_key; + const std::shared_ptr job; + std::optional distributed_epoch_state; + std::queue pending_tasks; + int64_t num_clients = 0; + int64_t last_client_released_micros = -1; + bool finished = false; + // Indicates whether the iteration was garbage collected. + bool garbage_collected = false; + }; + + struct Task { + template + explicit Task(const T& create_task_update, + const std::shared_ptr& iteration) + : task_id(create_task_update.task_id()), + iteration(iteration), + worker_address(create_task_update.worker_address()), + transfer_servers(create_task_update.transfer_servers().begin(), + create_task_update.transfer_servers().end()), + worker_tags(create_task_update.worker_tags().begin(), + create_task_update.worker_tags().end()), + worker_uid(create_task_update.worker_uid()) {} + + const int64_t task_id; + const std::shared_ptr iteration; + const std::string worker_address; + const std::vector transfer_servers; + const std::vector worker_tags; + const int64_t worker_uid; + int64_t starting_round = 0; + bool finished = false; + bool removed = false; + }; + + using TasksById = absl::flat_hash_map>; + + // Returns the next available dataset ID. + std::string NextAvailableDatasetId() const; + + // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset. + absl::Status DatasetFromId(const std::string& id, + std::shared_ptr& dataset) const; + + // Gets a worker by address. Returns NOT_FOUND if there is no such worker. + absl::Status WorkerFromAddress(const std::string& address, + std::shared_ptr& worker) const; + // Lists all workers registered with the dispatcher. + std::vector> ListWorkers() const; + + // Returns the next available job id. + int64_t NextAvailableJobId() const; + // Gets a job by id. Returns NOT_FOUND if there is no such job. + absl::Status JobFromId(int64_t job_id, std::shared_ptr& job) const; + // Gets a job by name. Returns NOT_FOUND if there is no such job. + absl::Status JobByName(const std::string& job_name, + std::shared_ptr& job) const; + + // Returns the next available iteration id. + int64_t NextAvailableIterationId() const; + // Returns a list of all iterations. 
+ std::vector> ListIterations() const; + // Gets an iteration by id. Returns NOT_FOUND if there is no such iteration. + absl::Status IterationFromId( + int64_t id, std::shared_ptr& iteration) const; + // Gets an iteration by key. Returns NOT_FOUND if there is no such iteration. + absl::Status IterationByKey( + IterationKey key, std::shared_ptr& iteration) const; + + // Returns the iteration associated with the given iteration client id. + // Returns NOT_FOUND if the iteration_client_id is unknown or has been + // released. + absl::Status IterationForIterationClientId( + int64_t iteration_client_id, std::shared_ptr& iteration); + // Returns a list of all active client ids. + std::vector ListActiveClientIds(); + // Returns the next available iteration client id. + int64_t NextAvailableIterationClientId() const; + + // Returns the next available task id. + int64_t NextAvailableTaskId() const; + // Gets a task by id. Returns NOT_FOUND if there is no such task. + absl::Status TaskFromId(int64_t id, std::shared_ptr& task) const; + // Stores a list of all tasks for the given iteration to `tasks`. Returns + // NOT_FOUND if there is no such iteration. + absl::Status TasksForIteration( + int64_t iteration_id, + std::vector>& tasks) const; + // Stores a list of all tasks for the given worker to `tasks`. Returns + // NOT_FOUND if there is no such worker. + absl::Status TasksForWorker( + const absl::string_view worker_address, + std::vector>& tasks) const; + + // If the dispatcher config explicitly specifies a list of workers, validates + // `worker_address` is in the list. + absl::Status ValidateWorker(absl::string_view worker_address) const; + + // If the dispatcher config specifies worker addresses, `GetWorkerIndex` + // returns the worker index according to the list. This is useful for + // deterministically sharding a dataset among a fixed set of workers. + absl::StatusOr GetWorkerIndex( + absl::string_view worker_address) const; + + // Returns the paths of all snapshots initiated during the lifetime of this + // journal. + const absl::flat_hash_set& ListSnapshotPaths() const { + return snapshot_paths_; + } + + // Returns a bool describing whether or not compression was disabled at + // runtime for the given dataset, if such a decision has been made. + std::optional CompressionDisabledAtRuntime( + const std::string& dataset_id) const; + + // Returns the current number of registered workers. 
+ int64_t GetNumberOfRegisteredWorkers() const { return workers_.size(); } + + private: + void RegisterDataset(const RegisterDatasetUpdate& register_dataset); + void RegisterWorker(const RegisterWorkerUpdate& register_worker); + void CreateJob(const CreateJobUpdate& create_job); + void CreateIteration(const CreateIterationUpdate& create_iteration); + void ProduceSplit(const ProduceSplitUpdate& produce_split); + void AcquireIterationClient( + const AcquireIterationClientUpdate& acquire_iteration_client); + void ReleaseIterationClient( + const ReleaseIterationClientUpdate& release_iteration_client); + void GarbageCollectIteration( + const GarbageCollectIterationUpdate& garbage_collect_iteration); + void RemoveTask(const RemoveTaskUpdate& remove_task); + void CreatePendingTask(const CreatePendingTaskUpdate& create_pending_task); + void ClientHeartbeat(const ClientHeartbeatUpdate& client_heartbeat); + void CreateTask(const CreateTaskUpdate& create_task); + void FinishTask(const FinishTaskUpdate& finish_task); + void Snapshot(const SnapshotUpdate& snapshot); + void CompressionDisabledAtRuntime(const CompressionDisabledAtRuntimeUpdate& + compression_disabled_at_runtime); + + // Updates the next available dataset ID. + void UpdateNextAvailableDatasetId(); + + int64_t next_available_dataset_id_ = 1000; + // Registered datasets, keyed by dataset ids. + absl::flat_hash_map> datasets_by_id_; + + // Registered workers, keyed by address. + absl::flat_hash_map> workers_; + + // Assigns an index to each worker according to worker addresses list + // specified in the dispatcher config. + WorkerIndexResolver worker_index_resolver_; + + int64_t next_available_job_id_ = 5000; + // Jobs, keyed by job ids. + absl::flat_hash_map> jobs_by_id_; + // Jobs, keyed by job names. + absl::flat_hash_map> jobs_by_name_; + + int64_t next_available_iteration_id_ = 2000; + // Iterations, keyed by iteration ids. + absl::flat_hash_map> iterations_; + // Iterations, keyed by their iteration keys. + absl::flat_hash_map> + iterations_by_key_; + + int64_t next_available_iteration_client_id_ = 3000; + // Mapping from client ids to the iterations they are associated with. + absl::flat_hash_map> + iterations_for_client_ids_; + + int64_t next_available_task_id_ = 4000; + // Tasks, keyed by task ids. + TasksById tasks_; + // List of tasks associated with each iteration. + absl::flat_hash_map>> + tasks_by_iteration_; + // Tasks, keyed by worker addresses. The values are a map from task id to + // task. + absl::flat_hash_map tasks_by_worker_; + // Paths for all snapshots initiated during the lifetime of this journal. + absl::flat_hash_set snapshot_paths_; + // A mapping of dataset id to a boolean describing whether or not compression + // was disabled at runtime for that dataset. + absl::flat_hash_map compression_disabled_at_runtime_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/graph_rewriters.h b/third_party/tflite-hdrs/tensorflow/core/data/service/graph_rewriters.h new file mode 100644 index 00000000..e1244fd5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/graph_rewriters.h @@ -0,0 +1,108 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRAPH_REWRITERS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_GRAPH_REWRITERS_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace data { + +// Rewrites the dataset graph by removing the compression map. +class RemoveCompressionMapRewriter { + public: + // Returns `graph_def` with the compression map removed. + absl::StatusOr ApplyRemoveCompressionMapRewrite( + const GraphDef& graph_def); + + private: + tensorflow::RewriterConfig::CustomGraphOptimizer GetRewriteConfig() const; +}; + +// Rewrites the dataset graph by applying an auto-shard policy. +class AutoShardRewriter { + public: + // Creates an `AutoShardRewriter` according to `task_def`. Returns an error if + // the sharding policy is not a valid auto-shard policy. + static absl::StatusOr Create(const TaskDef& task_def); + + // Applies auto-sharding to `graph_def`. If auto-shard policy is OFF, returns + // the same graph as `graph_def`. Otherwise, returns the re-written graph. + absl::StatusOr ApplyAutoShardRewrite(const GraphDef& graph_def); + + private: + AutoShardRewriter(AutoShardPolicy auto_shard_policy, int64_t num_workers, + int64_t worker_index); + + // Creates a rewrite config based on the auto-shard policy. + tensorflow::RewriterConfig::CustomGraphOptimizer GetRewriteConfig() const; + + const AutoShardPolicy auto_shard_policy_; + const int64_t num_workers_; + const int64_t worker_index_; +}; + +// Maps a worker to its index, given a list of workers. For example, suppose +// `worker_addresses` contains +// /worker/task/0:worker, /worker/task/1:worker, /worker/task/2:worker, +// then +// /worker/task/0:worker maps to index 0, +// /worker/task/1:worker maps to index 1, +// /worker/task/2:worker maps to index 2. +// This is useful for deterministically sharding a dataset among a fixed set of +// tf.data service workers. +class WorkerIndexResolver { + public: + // Constructs a `WorkerIndexResolver` to generate worker indexes according to + // the specified worker addresses. The worker addresses can be "host" or + // "host:port", where "port" is a number, named port, or "%port%" to be + // replaced with the actual port. + template + explicit WorkerIndexResolver(const T& worker_addresses) + : worker_addresses_(worker_addresses.cbegin(), worker_addresses.cend()) {} + + // Validates `worker_address`. Returns an error if the `worker_addresses` list + // is non-empty and `worker_address` is not specified in the worker addresses + // list (with optional port replacement). 
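+  //
+  // For example (illustrative sketch; the addresses are placeholders):
+  //
+  //   std::vector<std::string> configured = {"host0:%port%", "host1:%port%"};
+  //   WorkerIndexResolver resolver(configured);
+  //
+  //   TF_RETURN_IF_ERROR(resolver.ValidateWorker("host1:5000"));
+  //   resolver.AddWorker("host1:5000");
+  //   TF_ASSIGN_OR_RETURN(int64_t index,
+  //                       resolver.GetWorkerIndex("host1:5000"));
+  //   // `index` is 1, matching the worker's position in `configured`.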
+ absl::Status ValidateWorker(absl::string_view worker_address) const; + + // Processes a worker at address `worker_address`. Its index can be retrieved + // by calling `GetWorkerIndex`. + void AddWorker(absl::string_view worker_address); + + // Returns the worker index for the worker at `worker_address`. Returns a + // NotFound error if the worker is not registered. + absl::StatusOr GetWorkerIndex( + absl::string_view worker_address) const; + + private: + std::vector worker_addresses_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_GRAPH_REWRITERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_dispatcher_impl.h b/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_dispatcher_impl.h new file mode 100644 index 00000000..50d5e2c3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_dispatcher_impl.h @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_ + +#include "grpcpp/server_builder.h" +#include "tensorflow/core/data/service/dispatcher.grpc.pb.h" +#include "tensorflow/core/data/service/dispatcher_impl.h" +#include "tensorflow/core/data/service/export.pb.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// This class is a wrapper that handles communication for gRPC. +class GrpcDispatcherImpl : public DispatcherService::Service { + public: + // Constructs a GrpcDispatcherImpl with the given config, and registers it + // with `server_builder`. 
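A small sketch of the WorkerIndexResolver declared above in graph_rewriters.h, assuming the elided return type of `GetWorkerIndex` is `absl::StatusOr<int64_t>` as the comments suggest. The worker addresses and port value are purely illustrative.

    #include <cstdint>
    #include <string>
    #include <vector>

    #include "tensorflow/core/data/service/graph_rewriters.h"
    #include "tensorflow/core/platform/errors.h"
    #include "tensorflow/core/platform/statusor.h"

    namespace tensorflow {
    namespace data {

    absl::StatusOr<int64_t> ResolveExampleWorkerIndex() {
      // Addresses as they might appear in the dispatcher config; "%port%" is
      // replaced once the actual port is known.
      const std::vector<std::string> configured_workers = {
          "/worker/task/0:%port%", "/worker/task/1:%port%"};
      WorkerIndexResolver resolver(configured_workers);

      const std::string worker_address = "/worker/task/1:20000";
      TF_RETURN_IF_ERROR(resolver.ValidateWorker(worker_address));
      resolver.AddWorker(worker_address);
      // Expected to yield index 1, the position in the configured list.
      return resolver.GetWorkerIndex(worker_address);
    }

    }  // namespace data
    }  // namespace tensorflow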
+ explicit GrpcDispatcherImpl(const experimental::DispatcherConfig& config, + ::grpc::ServerBuilder& server_builder); + ~GrpcDispatcherImpl() override { Stop(); } + + absl::Status Start(); + void Stop(); + + size_t NumActiveIterations(); + + DispatcherStateExport ExportState() const; + +#define HANDLER(method) \ + ::grpc::Status method(::grpc::ServerContext* context, \ + const method##Request* request, \ + method##Response* response) override; + HANDLER(WorkerHeartbeat); + HANDLER(WorkerUpdate); + HANDLER(GetDatasetDef); + HANDLER(GetSplit); + HANDLER(GetVersion); + HANDLER(GetOrRegisterDataset); + HANDLER(ReleaseIterationClient); + HANDLER(MaybeRemoveTask); + HANDLER(GetOrCreateJob); + HANDLER(GetOrCreateIteration); + HANDLER(ClientHeartbeat); + HANDLER(GetWorkers); + HANDLER(GetDataServiceMetadata); + HANDLER(GetDataServiceConfig); + HANDLER(Snapshot); + HANDLER(GetSnapshotSplit); + HANDLER(GetSnapshotStreams); + HANDLER(DisableCompressionAtRuntime); +#undef HANDLER + + private: + DataServiceDispatcherImpl impl_; + + GrpcDispatcherImpl(const GrpcDispatcherImpl&) = delete; + void operator=(const GrpcDispatcherImpl&) = delete; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_DISPATCHER_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_util.h b/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_util.h new file mode 100644 index 00000000..8fff6312 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_util.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_UTIL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_UTIL_H_ + +#include +#include + +#include "grpcpp/grpcpp.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { +namespace grpc_util { + +// Wraps a grpc::Status in a tensorflow::Status with the given message. +absl::Status WrapError(const std::string& message, + const ::grpc::Status& status); + +// Retries the given function if the function produces UNAVAILABLE, ABORTED, or +// CANCELLED status codes. We retry these codes because they can all indicate +// preemption of a server. The retries continue until the deadline is exceeded +// or the `should_retry` callback returns false. `description` may be used to +// log that retries are happening. It should contain a description of the action +// being retried, e.g. "register dataset" The retry loop uses exponential +// backoff between retries. `deadline_micros` is interpreted as microseconds +// since the epoch. +absl::Status Retry(const std::function& f, + const std::function& should_retry, + const std::string& description, int64_t deadline_micros); + +// Same as `Retry` above, but with a `should_retry` callback that always returns +// `true`. 
+absl::Status Retry(const std::function& f, + const std::string& description, int64_t deadline_micros); + +} // namespace grpc_util +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_worker_impl.h b/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_worker_impl.h new file mode 100644 index 00000000..4513c0ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/grpc_worker_impl.h @@ -0,0 +1,81 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_GRPC_WORKER_IMPL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_GRPC_WORKER_IMPL_H_ + +#include +#include +#include +#include + +#include "grpcpp/server_builder.h" +#include "tensorflow/core/data/service/export.pb.h" +#include "tensorflow/core/data/service/worker.grpc.pb.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/data/service/worker_impl.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// This class is a wrapper that handles communication for gRPC. +class GrpcWorkerImpl : public WorkerService::Service { + public: + // Constructs a GrpcWorkerImpl with the given config, and registers it with + // `server_builder`. + explicit GrpcWorkerImpl(const experimental::WorkerConfig& config, + ::grpc::ServerBuilder& server_builder); + ~GrpcWorkerImpl() override { Stop(); } + + absl::Status Start( + const std::string& worker_address, + const std::vector& transfer_servers); + void Stop(); + + std::function + get_element_getter() { + return [this](const GetElementRequest* request, GetElementResult* result) { + return impl_->GetElementResult(request, result); + }; + } + + WorkerStateExport ExportState() const; + +#define HANDLER(method) \ + ::grpc::Status method(::grpc::ServerContext* context, \ + const method##Request* request, \ + method##Response* response) override; + HANDLER(ProcessTask); + HANDLER(GetElement); + HANDLER(GetWorkerTasks); + HANDLER(GetSnapshotTaskProgresses); +#undef HANDLER + + private: + std::string worker_address_; + // A std::shared_ptr allows clients to access local servers and directly call + // the servers' methods to avoid RPC calls and data copy. 
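A sketch of the simpler `grpc_util::Retry` overload from grpc_util.h above, assuming the elided callable type is `std::function<absl::Status()>` as the surrounding comments imply. The one-minute deadline and the description string are illustrative.

    #include <functional>

    #include "tensorflow/core/data/service/grpc_util.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/status.h"

    namespace tensorflow {
    namespace data {

    // Retries `rpc_call` with exponential backoff for up to 60 seconds.
    absl::Status CallWithRetry(const std::function<absl::Status()>& rpc_call) {
      const int64_t deadline_micros =
          Env::Default()->NowMicros() + 60LL * 1000 * 1000;  // now + 60s
      return grpc_util::Retry(rpc_call, /*description=*/"example rpc call",
                              deadline_micros);
    }

    }  // namespace data
    }  // namespace tensorflow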
+ std::shared_ptr impl_; + + GrpcWorkerImpl(const GrpcWorkerImpl&) = delete; + void operator=(const GrpcWorkerImpl&) = delete; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_GRPC_WORKER_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/journal.h b/third_party/tflite-hdrs/tensorflow/core/data/service/journal.h new file mode 100644 index 00000000..0c15856b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/journal.h @@ -0,0 +1,118 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_JOURNAL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_JOURNAL_H_ + +#include +#include + +#include "tensorflow/core/data/service/journal.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +namespace data { + +// Returns the location of the journal file within the journal directory. +std::string DataServiceJournalFile(const std::string& journal_dir, + int64_t sequence_number); + +// Interface for writing to a journal. +class JournalWriter { + public: + virtual ~JournalWriter() = default; + // Writes and syncs an update to the journal. + virtual absl::Status Write(const Update& update) = 0; + // Initializes the writer if it is not yet initialized. + virtual absl::Status EnsureInitialized() = 0; +}; + +// FileJournalWriter is not thread-safe, requiring external synchronization when +// used by multiple threads. +// +// FileJournalWriter writes journal files to a configured journal directory. The +// directory is laid out in the following format: +// +// journal_dir/ +// journal_0 +// journal_1 +// ... +// +// When the writer is created, it lists the directory to find the next available +// journal file name. For example, if the journal directory contains +// "journal_0", "journal_1", and "journal_2", the writer will write to +// "journal_3". The writer will flush updates as they are written, so that they +// can be stored durably in case of machine failure. +class FileJournalWriter : public JournalWriter { + public: + // Creates a journal writer to write to the given journal directory. + // If there is already journal data there, the journal writer will append to + // the existing journal. + explicit FileJournalWriter(Env* env, const std::string& journal_dir); + FileJournalWriter(const FileJournalWriter&) = delete; + FileJournalWriter& operator=(const FileJournalWriter&) = delete; + + absl::Status Write(const Update& update) override; + absl::Status EnsureInitialized() override; + + private: + Env* env_; + const std::string journal_dir_; + std::unique_ptr file_; + std::unique_ptr writer_; +}; + +// Interface for reading from a journal. 
+class JournalReader { + public: + virtual ~JournalReader() = default; + // Reads the next update from the journal. Sets `end_of_journal=true` if + // there are no more updates left in the journal. + virtual absl::Status Read(Update& update, bool& end_of_journal) = 0; +}; + +// JournalReader is not thread-safe, requiring external synchronization when +// used by multiple threads. +// +// The journal reader reads through all journal files in the configured journal +// directory, in order of their sequence numbers. See FileJournalWriter above. +class FileJournalReader : public JournalReader { + public: + explicit FileJournalReader(Env* env, absl::string_view journal_dir); + FileJournalReader(const FileJournalReader&) = delete; + FileJournalReader& operator=(const FileJournalReader&) = delete; + + absl::Status Read(Update& update, bool& end_of_journal) override; + + private: + // Initializes the reader if it is not yet initialized. + absl::Status EnsureInitialized(); + // Updates the `FileJournalReader` to read from a new file. + absl::Status UpdateFile(const std::string& filename); + + Env* env_; + const std::string journal_dir_; + // Sequence number of current journal file. + int64_t sequence_number_ = 0; + std::unique_ptr file_; + std::unique_ptr reader_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_JOURNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/py_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/py_utils.h new file mode 100644 index 00000000..b0ea8928 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/py_utils.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_PY_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_PY_UTILS_H_ + +#include + +// Utilities called from the Python API through pybind. We define this file +// separately from other utils to keep the transitive closure of dependencies +// minimal, avoiding linking conflicts. +namespace tensorflow { +namespace data { + +// Returns the default protocol to use for tf.data service control flow. +std::string DefaultProtocol(); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_PY_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/server_lib.h b/third_party/tflite-hdrs/tensorflow/core/data/service/server_lib.h new file mode 100644 index 00000000..56a8f8d9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/server_lib.h @@ -0,0 +1,189 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
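A round-trip sketch for the FileJournalWriter and FileJournalReader declared in journal.h above. The `Update` proto is left unpopulated here; a real caller would set one of its update fields before writing.

    #include <string>

    #include "tensorflow/core/data/service/journal.h"
    #include "tensorflow/core/data/service/journal.pb.h"
    #include "tensorflow/core/platform/env.h"
    #include "tensorflow/core/platform/errors.h"

    namespace tensorflow {
    namespace data {

    // Appends one update to the journal in `journal_dir`, then replays the
    // whole journal from the beginning.
    absl::Status RoundTripJournal(Env* env, const std::string& journal_dir,
                                  const Update& update) {
      FileJournalWriter writer(env, journal_dir);
      TF_RETURN_IF_ERROR(writer.Write(update));

      FileJournalReader reader(env, journal_dir);
      while (true) {
        Update replayed;
        bool end_of_journal = false;
        TF_RETURN_IF_ERROR(reader.Read(replayed, end_of_journal));
        if (end_of_journal) break;
        // Apply `replayed` to the in-memory state here.
      }
      return absl::OkStatus();
    }

    }  // namespace data
    }  // namespace tensorflow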
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ + +#include +#include +#include + +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/data_transfer.h" +#include "tensorflow/core/data/service/export.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/profiler/rpc/profiler_service_impl.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// Forward declared because transitively depending on .grpc.pb.h files causes +// issues in the pywrap build. +class GrpcDispatcherImpl; +class GrpcWorkerImpl; + +// A grpc server for the tf.data service. +class GrpcDataServerBase { + public: + // Constructs a tf.data server with the specified port. If the port is 0, the + // server will find an available port in `Start()`. The chosen port can be + // found by calling `BoundPort()`. + GrpcDataServerBase( + int requested_port, const std::string& protocol, + const std::string& server_type, + std::vector> options = {}); + virtual ~GrpcDataServerBase() = default; + + // Starts the server running asynchronously. + absl::Status Start(); + + // Stops the server. This will block until all outstanding requests complete. + void Stop(); + + // Blocks until the server stops. + void Join(); + + // Returns the port bound by the server. Only valid after calling Start(). + int BoundPort(); + + // Exports the server state to improve debuggability. + virtual ServerStateExport ExportState() const = 0; + + protected: + virtual void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) = 0; + void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder); + // Starts the service. This will be called after building the service, so + // bound_port() will return the actual bound port. + virtual absl::Status StartServiceInternal() = 0; + virtual void StopServiceInternal() {} + + int bound_port() { return bound_port_; } + + const int requested_port_; + const std::string protocol_; + const std::string server_type_; + + private: + int bound_port_; + bool started_ = false; + bool stopped_ = false; + + std::unique_ptr<::grpc::Server> server_; + // TensorFlow profiler service implementation. + std::unique_ptr profiler_service_ = nullptr; + std::vector> server_options_; +}; + +// A wrapper for `SnapshotStreamInfo` for use with pybind. +struct SnapshotStreamInfoWrapper { + SnapshotStreamInfoWrapper() = default; + explicit SnapshotStreamInfoWrapper(const SnapshotStreamInfo& info) + : index(info.index()), state(info.state()) {} + int64_t index; + int64_t state; +}; + +class DispatchGrpcDataServer : public GrpcDataServerBase { + public: + explicit DispatchGrpcDataServer( + const experimental::DispatcherConfig& config, + std::vector> options = {}); + ~DispatchGrpcDataServer() override; + + // Returns the number of workers registered with the dispatcher. 
+ absl::Status NumWorkers(int* num_workers); + // Returns the number of active (non-finished) iterations running on the + // dispatcher. + size_t NumActiveIterations(); + // Returns information about all the streams for the snapshot at `path`. + absl::Status SnapshotStreams(const std::string& path, + std::vector* streams); + + ServerStateExport ExportState() const override; + + protected: + void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; + absl::Status StartServiceInternal() override; + void StopServiceInternal() override; + + private: + const experimental::DispatcherConfig config_; + // Owned. We use a raw pointer because GrpcDispatcherImpl is forward-declared. + GrpcDispatcherImpl* service_; +}; + +// A wrapper for `SnapshotTaskProgress` for use with pybind. +struct SnapshotTaskProgressWrapper { + SnapshotTaskProgressWrapper() = default; + explicit SnapshotTaskProgressWrapper(const SnapshotTaskProgress& progress) + : snapshot_task_base_path(progress.snapshot_task().base_path()), + snapshot_task_stream_index(progress.snapshot_task().stream_index()), + completed(progress.completed()) {} + std::string snapshot_task_base_path; + int64_t snapshot_task_stream_index; + bool completed; +}; + +class WorkerGrpcDataServer : public GrpcDataServerBase { + public: + explicit WorkerGrpcDataServer( + const experimental::WorkerConfig& config, + std::vector> options = {}); + ~WorkerGrpcDataServer() override; + + // Returns the number of tasks currently being executed by the worker. + absl::Status NumTasks(int* num_tasks); + + // Returns the progresses of the snapshot tasks currently being executed by + // the worker. + absl::Status SnapshotTaskProgresses( + std::vector* snapshot_task_progresses); + + ServerStateExport ExportState() const override; + + protected: + void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; + absl::Status StartServiceInternal() override; + void StopServiceInternal() override; + + private: + // If an alternative data transfer protocol is configured, tries to start a + // transfer server for it, adding an entry to `transfer_servers` if + // successful. + void MaybeStartAlternativeDataTransferServer( + std::vector& transfer_servers); + + const experimental::WorkerConfig config_; + // Owned. We use a raw pointer because GrpcWorkerImpl is forward-declared. + GrpcWorkerImpl* service_; + std::shared_ptr transfer_server_; +}; + +// Creates a dispatch tf.data server and stores it in `out_server`. +absl::Status NewDispatchServer( + const experimental::DispatcherConfig& config, + std::unique_ptr& out_server); + +// Creates a worker tf.data server and stores it in `out_server`. +absl::Status NewWorkerServer(const experimental::WorkerConfig& config, + std::unique_ptr& out_server); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SERVER_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/file_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/file_utils.h new file mode 100644 index 00000000..2a6ca60a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/file_utils.h @@ -0,0 +1,74 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
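A lifecycle sketch for the server factories declared at the end of server_lib.h above. The element type of the `out_server` smart pointer is elided in the rendered header; `std::unique_ptr<DispatchGrpcDataServer>` is assumed from the accompanying comment.

    #include <memory>

    #include "tensorflow/core/data/service/server_lib.h"
    #include "tensorflow/core/platform/errors.h"
    #include "tensorflow/core/platform/logging.h"
    #include "tensorflow/core/protobuf/service_config.pb.h"

    namespace tensorflow {
    namespace data {

    // Brings up an in-process dispatcher and blocks until it shuts down.
    absl::Status RunDispatcher(const experimental::DispatcherConfig& config) {
      std::unique_ptr<DispatchGrpcDataServer> dispatcher;
      TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher));
      TF_RETURN_IF_ERROR(dispatcher->Start());
      // BoundPort() is only valid after Start(); useful when port 0 was
      // requested and the server picked a free port.
      LOG(INFO) << "Dispatcher bound to port " << dispatcher->BoundPort();
      dispatcher->Join();
      return absl::OkStatus();
    }

    }  // namespace data
    }  // namespace tensorflow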
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_FILE_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_FILE_UTILS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/protobuf.h" + +namespace tensorflow { +namespace data { + +// Atomically writes `str` to `filename`. Overwrites existing contents if the +// file already exists. +absl::Status AtomicallyWriteStringToFile(absl::string_view filename, + absl::string_view str, tsl::Env* env); + +// Atomically writes the binary representation of `proto` to `filename`. +// Overwrites existing contents if the file already exists. +absl::Status AtomicallyWriteBinaryProto(absl::string_view filename, + const tsl::protobuf::Message& proto, + tsl::Env* env); + +// Atomically writes the text representation of `proto` to `filename`. +// Overwrites existing contents if the file already exists. +absl::Status AtomicallyWriteTextProto(absl::string_view filename, + const tsl::protobuf::Message& proto, + tsl::Env* env); + +// Atomically writes `tensor` to `filename` in TFRecord format. Overwrites +// existing contents if the file already exists. +absl::Status AtomicallyWriteTFRecords(absl::string_view filename, + const std::vector& tensors, + absl::string_view compression, + tsl::Env* env); + +// Returns the relative paths of the children of `directory`, ignoring temporary +// files. Returns an empty vector if the directory does not have any children. +absl::StatusOr> GetChildren( + absl::string_view directory, tsl::Env* env); + +// Returns true if `filename` is a temporary file and should be ignored in +// normal data processing. +bool IsTemporaryFile(absl::string_view filename); + +// Returns the total number of chunks for a distributed snapshot: +// - If the snapshot is finished, returns the number of committed chunks. +// - If the snapshot is unfinished or has failed, returns kUnknownCardinality. +int64_t SnapshotChunksCardinality(absl::string_view snapshot_path, + tsl::Env* env); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_FILE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h new file mode 100644 index 00000000..db6cd182 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h @@ -0,0 +1,144 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
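A sketch of the atomic-write helpers from snapshot/file_utils.h above. The "DONE" marker name and the directory layout are illustrative only; `GetChildren` is assumed to return `absl::StatusOr<std::vector<std::string>>`, matching its comment.

    #include <string>
    #include <vector>

    #include "tensorflow/core/data/service/snapshot/file_utils.h"
    #include "tensorflow/core/platform/errors.h"
    #include "tensorflow/core/platform/statusor.h"
    #include "tsl/platform/env.h"

    namespace tensorflow {
    namespace data {

    // Atomically writes an empty marker file, then lists the directory.
    absl::Status WriteMarkerAndList(const std::string& directory,
                                    tsl::Env* env) {
      const std::string marker = directory + "/DONE";
      TF_RETURN_IF_ERROR(AtomicallyWriteStringToFile(marker, /*str=*/"", env));
      TF_ASSIGN_OR_RETURN(std::vector<std::string> children,
                          GetChildren(directory, env));
      // `children` holds relative paths; temporary files are skipped.
      return absl::OkStatus();
    }

    }  // namespace data
    }  // namespace tensorflow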
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PARALLEL_TFRECORD_WRITER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PARALLEL_TFRECORD_WRITER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/data/service/byte_size.h" +#include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace data { + +// Uses multiple threads to write TFRecords in parallel. Users add data without +// waiting for the file writes, and it writes one shard of file per thread. +// Returns the file names when writes are finished. This class is thread-safe. +// +// Usage example: +// +// ParallelTFRecordWriter writer( +// "/path/to/file", tsl::io::compression::kSnappy, Env::Default()); +// +// std::vector record; +// bool end_of_sequence = false; +// TF_RETURN_IF_ERROR(iterator.GetNext(record, end_of_sequence)); +// while (!end_of_sequence) { +// TF_RETURN_IF_ERROR(writer.Write(record)); +// TF_RETURN_IF_ERROR(iterator.GetNext(record, end_of_sequence)); +// } +// TF_ASSIGN_OR_RETURN(ParallelTFRecordWriter::FileToStatsMap file_stats, +// writer.Finalize()); +class ParallelTFRecordWriter { + public: + explicit ParallelTFRecordWriter(const std::string& file_prefix, + const std::string& compression, tsl::Env* env, + ByteSize max_file_size = ByteSize::GB(6), + int64_t num_write_threads = 2, + int64_t buffer_size = 1); + virtual ~ParallelTFRecordWriter(); + ParallelTFRecordWriter(const ParallelTFRecordWriter&) = delete; + ParallelTFRecordWriter& operator=(const ParallelTFRecordWriter&) = delete; + + // Writes `record`. If there is sufficient buffer space, it returns without + // waiting for the record to be written to the file. If the buffer is full, + // blocks until there is enough space to buffer the record. + absl::Status Write(std::vector record); + + // File stats: number of records in a file and the estimated size of the file. + struct FileStats { + int64_t num_records = 0; + ByteSize estimated_size; + }; + using FileToStatsMap = absl::flat_hash_map; + + // Flushes the writer and finalizes the files. Returns a map from absolute + // paths to the file stats. After the writer is finalized, `Write` will return + // `FailedPreconditionErrors`. The caller should make sure all `Write` calls + // have finished before calling `Finalize`. Will block until the writer is + // finalized or an error occurs. + absl::StatusOr Finalize(); + + private: + // Run by a thread to write buffered records to sharded files. + void WriteFiles(); + + // Whether there are more records to be written. + bool HasNext() const; + + // Writes a new file. + absl::Status WriteFile(); + + // Whether the file can hold more records without exceeding `max_file_size_`. 
+ bool ShouldWriteFile(const std::string& filename) const; + + // Writes one record to file. + absl::Status WriteRecord(const std::string& filename, + snapshot_util::TFRecordWriter& writer); + + // Gets the next record from the buffer to write. Returns `std::nullopt` if + // there are no more records to write. + absl::StatusOr>> GetNextRecord( + const std::string& filename); + + // Deletes the file if it's empty. + absl::Status DeleteEmptyFile(const std::string& filename); + + // Generates a unique file name in the requested directory. + absl::StatusOr GetUniqueFile() const; + + // Updates the status of the writer and notifies waiters. + void UpdateStatus(absl::Status status); + + tsl::Env* const env_; + const std::string file_prefix_; + const std::string compression_; + const ByteSize max_file_size_; + const int64_t buffer_size_; + + mutable absl::Mutex mu_; + mutable absl::CondVar ready_to_push_; + mutable absl::CondVar ready_to_pop_; + + bool finalized_ ABSL_GUARDED_BY(mu_) = false; + absl::Status status_ ABSL_GUARDED_BY(mu_); + + // A map from absolute paths to the number of records in the files. + FileToStatsMap file_stats_ ABSL_GUARDED_BY(mu_); + + // Buffer to hold the records to be written. The size should be bounded by + // `buffer_size_`. + std::deque> buffer_ ABSL_GUARDED_BY(mu_); + + std::unique_ptr thread_pool_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PARALLEL_TFRECORD_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/path_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/path_utils.h new file mode 100644 index 00000000..63c88556 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/path_utils.h @@ -0,0 +1,134 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PATH_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PATH_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace tensorflow { +namespace data { + +// Returns the directory path for the assigned streams of a snapshot. +std::string StreamsDirectory(absl::string_view snapshot_path); + +// Returns the directory path for a worker writing one stream of the snapshot. +std::string StreamDirectory(absl::string_view snapshot_path, + int64_t stream_index); + +// Returns the directory path for the assigned splits for a worker writing one +// stream of a snapshot. +std::string SplitsDirectory(absl::string_view snapshot_path, + int64_t stream_index); + +// Returns the directory path for the assigned splits for one source, for a +// worker writing one stream of a snapshot. 
+std::string SourceDirectory(absl::string_view snapshot_path, + int64_t stream_index, int64_t source_index); + +// Returns the directory path for one repetition of a split provider. +std::string RepetitionDirectory(absl::string_view snapshot_path, + int64_t stream_index, int64_t source_index, + int64_t repetition_index); + +// Returns the file path for an assigned split for a worker writing one stream +// of a snapshot. +std::string SplitPath(absl::string_view snapshot_path, int64_t stream_index, + int64_t source_index, int64_t repetition_index, + int64_t local_index, int64_t global_index); + +// Returns the index of the stream. The expected format of +// `stream_directory_name` is: +// stream_ +absl::StatusOr ParseStreamDirectoryName( + absl::string_view stream_directory_name); + +// Returns the index of the source. The expected format of +// `source_directory_name` is: +// source_ +absl::StatusOr ParseSourceDirectoryName( + absl::string_view source_directory_name); + +// Returns the index of the repetition. The expected format of +// `repetition_directory_name` is: +// repetition_ +absl::StatusOr ParseRepetitionDirectoryName( + absl::string_view repetition_directory_name); + +// Returns a pair of {local_split_index, global_split_index} of the split. The +// expected format of `split_filename` is: +// split__ +absl::StatusOr> ParseSplitFilename( + absl::string_view split_filename); + +// Returns a pair of {checkpoint_index, checkpoint_num_elements} of the +// checkpoint. The expected format of `checkpoint_filename` is: +// checkpoint__ +absl::StatusOr> ParseCheckpointFilename( + absl::string_view checkpoint_filename); + +// Returns a tuple of {stream_index, stream_chunk_index, chunk_num_elements} of +// the chunk. The expected format of `chunk_filename` is: +// chunk___ +absl::StatusOr> ParseChunkFilename( + absl::string_view chunk_filename); + +// Returns the path of the DONE file of a snapshot stream. +std::string StreamDoneFilePath(absl::string_view snapshot_path, + int64_t stream_index); + +// Returns the path of the owner_worker file of a snapshot stream. +std::string StreamWorkerFilePath(absl::string_view snapshot_path, + int64_t stream_index); + +// Returns the path of the owner_worker file of a snapshot stream. +std::string StreamWorkerFilePath(absl::string_view stream_path); + +// Returns the path of the DONE file of a snapshot. +std::string SnapshotDoneFilePath(absl::string_view snapshot_path); + +// Returns the path of the ERROR file of a snapshot. +std::string SnapshotErrorFilePath(absl::string_view snapshot_path); + +// Returns the path of the serialized metadata for a snapshot. +std::string SnapshotMetadataFilePath(absl::string_view snapshot_path); + +// Returns the path of the serialized graph of the dataset for a snapshot. +std::string DatasetDefFilePath(absl::string_view snapshot_path); + +// Returns the path of the serialized element spec of the dataset for a +// snapshot. +std::string DatasetSpecFilePath(absl::string_view snapshot_path); + +// Returns the directory path for snapshot checkpoints. +std::string CheckpointsDirectory(absl::string_view snapshot_path, + int64_t stream_index); + +// Returns the directory path for committed chunks. +std::string CommittedChunksDirectory(absl::string_view snapshot_path); + +// Returns the directory path for uncommitted chunks. 
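A compose-and-parse sketch for the snapshot path helpers in path_utils.h above. It assumes `ParseStreamDirectoryName` returns `absl::StatusOr<int64_t>` (the element type is elided in the rendered header) and that `tsl::io::Basename` from tsl/platform/path.h is available for extracting the directory name.

    #include <cstdint>
    #include <string>

    #include "tensorflow/core/data/service/snapshot/path_utils.h"
    #include "tensorflow/core/platform/statusor.h"
    #include "tsl/platform/path.h"

    namespace tensorflow {
    namespace data {

    // Builds the directory for stream 0 and parses the index back out of its
    // "stream_<index>" name; expected to return 0.
    absl::StatusOr<int64_t> StreamIndexRoundTrip(
        const std::string& snapshot_path) {
      const std::string stream_dir =
          StreamDirectory(snapshot_path, /*stream_index=*/0);
      return ParseStreamDirectoryName(tsl::io::Basename(stream_dir));
    }

    }  // namespace data
    }  // namespace tensorflow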
+std::string UncommittedChunksDirectory(absl::string_view snapshot_path, + int64_t stream_index); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PATH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/prefetched_split_provider.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/prefetched_split_provider.h new file mode 100644 index 00000000..2ec9472c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/prefetched_split_provider.h @@ -0,0 +1,158 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PREFETCHED_SPLIT_PROVIDER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PREFETCHED_SPLIT_PROVIDER_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace data { + +// Uses multiple threads to prefetch splits and write them to temporary files. +// Used to speed up tf.data snapshot manager where splits should be persisted +// before returning to the users. This class is thread-safe. +// +// Usage example: +// +// std::unique_ptr split_provider = ... +// PrefetchedSplitProvider prefetched_split_provider( +// std::move(split_provider), "/tmp/directory", Env::Default()); +// TF_ASSIGN_OR_RETURN(std::optional split, +// prefetched_split_provider.GetSplit(SplitPath(...))); +// if (split.has_value) { +// return *split; +// } +class PrefetchedSplitProvider { + public: + // Creates a prefetched split provider by prefetching given `split_provider`. + // `directory` is where to write temporary splits. The splits will be moved to + // a target file when returned to the client (see the comment for `GetSplit`). + // `num_write_threads` is the number of threads to prefetch and write splits. + // `buffer_size_per_thread` is the size of the buffer holding the prefetched + // but unread splits. For every prefetched split, we keep: (1) an in-memory + // Tensor in the buffer, and (2) an on-disk file representing the same split. + explicit PrefetchedSplitProvider( + std::unique_ptr split_provider, + const std::string& directory, tsl::Env* env, + size_t num_write_threads = 20, size_t buffer_size_per_thread = 5); + virtual ~PrefetchedSplitProvider(); + PrefetchedSplitProvider(const PrefetchedSplitProvider&) = delete; + PrefetchedSplitProvider& operator=(const PrefetchedSplitProvider&) = delete; + + // Writes the split to `target_split_path` and returns the split. 
Returns + // `std::nullopt` if no more splits are available. If there are more available + // splits but not currently ready for reading, blocks until they are ready. + absl::StatusOr> GetNext(const std::string& split_path); + + // Resets the split provider. + absl::Status Reset(); + + // Cancels the split provider. After cancelling, concurrent `GetNext` calls + // will return a Cancelled error. + void Cancel(); + + private: + // Prefetched split and its split index. + struct SplitAndIndex { + Tensor split; + size_t index = 0; + + // Returns the absolute path of the prefetched split. + std::string SplitPath(const std::string& directory) const { + return tsl::io::JoinPath(directory, + absl::StrCat("split_", index, ".tfrecord")); + } + + friend bool operator<(const SplitAndIndex& lhs, const SplitAndIndex& rhs) { + return lhs.index < rhs.index; + } + }; + + // Initializes directories for writing. This cleans up all existing files in + // `directory_`. + absl::Status InitDirs(); + + // Runs the prefetch threads. + std::unique_ptr RunPrefetchThreads(); + + // The prefetching threads run this method to prefetch the splits. + void PrefetchLoop(); + + // Whether the prefetching thread should try to fetch more splits. + bool ShouldPrefetchSplit() const; + + // If there is enough buffer space, prefetches one split and writes it to a + // temporary file. If the buffer is full, blocks until there is buffer space. + absl::StatusOr PrefetchSplit(); + + // Gets the next split from the split provider. + absl::StatusOr> GetSplitFromProvider(); + + // Updates the status and notifies waiters. + void UpdateStatus(absl::Status status); + + tsl::Env* const env_; + const std::string directory_; + const size_t num_write_threads_; + const size_t buffer_size_; + + mutable absl::Mutex mu_; + mutable absl::CondVar ready_to_push_; + mutable absl::CondVar ready_to_pop_; + + std::unique_ptr split_provider_; + + absl::Status status_ ABSL_GUARDED_BY(mu_); + + // Whether the split provider is being reset. + bool reset_ ABSL_GUARDED_BY(mu_) = false; + + // The indices ensure the splits are returned in order. When prefetching a + // split, associates each split with the `split_index_to_write_`. The buffer + // is sorted by the split index. When reading, waits for the split with index + // `split_index_to_read_`. + size_t split_index_to_read_ ABSL_GUARDED_BY(mu_) = 0; + size_t split_index_to_write_ ABSL_GUARDED_BY(mu_) = 0; + + // Number of finished threads. If `finished_threads_ >= num_write_threads_`, + // then all the splits have been pushed to the buffer. Otherwise, the split + // provider has not produced all the splits, or some thread is still writing + // splits to the files. + size_t finished_threads_ ABSL_GUARDED_BY(mu_) = 0; + + // Buffer to hold the splits. The size should be bounded by `buffer_size_`. + absl::btree_set buffer_ ABSL_GUARDED_BY(mu_); + + std::unique_ptr thread_pool_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_PREFETCHED_SPLIT_PROVIDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h new file mode 100644 index 00000000..fefc4998 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h @@ -0,0 +1,124 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_CHUNK_PROVIDER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_CHUNK_PROVIDER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/btree_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/env.h" + +namespace tensorflow { +namespace data { + +// Provides the next chunk to read. Blocks until the next chunk is unavailable, +// or all the chunks have been read. This class is thread-safe. +class SnapshotChunkProvider : public SplitProvider { + public: + SnapshotChunkProvider(absl::string_view snapshot_path, tsl::Env* env); + ~SnapshotChunkProvider() override = default; + SnapshotChunkProvider(const SnapshotChunkProvider&) = delete; + SnapshotChunkProvider& operator=(const SnapshotChunkProvider&) = delete; + + // Returns the absolute file path of next snapshot chunk to read. If there is + // no available chunk, blocks until the next chunk is unavailable, or all the + // chunks are read. Sets `end_of_splits` to true if all chunks have been read. + absl::Status GetNext(Tensor* split, bool* end_of_splits) override; + + absl::Status Reset() override; + + // Supports checkpointing. + absl::Status Save(std::function full_name, + IteratorStateWriter* writer) override; + absl::Status Restore(std::function full_name, + IteratorStateReader* reader) override; + + // If the snapshot is finished, returns the number of committed chunks. + // If the snapshot is unfinished or has failed, returns kUnknownCardinality. + int64_t Cardinality() const override; + + // Cancels the provider. After cancelling, if the snapshot is unfinished, + // in-flight `GetNext` calls will return Cancelled status. + void Cancel() override; + + private: + // State of the snapshot. + struct SnapshotState { + SnapshotState() = default; + explicit SnapshotState(bool snapshot_is_done) + : snapshot_is_done(snapshot_is_done) {} + explicit SnapshotState(absl::Status status) : status(std::move(status)) {} + + // True if the snapshot is done without errors. + bool snapshot_is_done = false; + + // Non-OK status if writing the snapshot fails. + absl::Status status = absl::OkStatus(); + }; + + // Used to sort chunks by chunk indexes so that chunks are read evenly across + // streams and chunks of early repetitions are read first. + struct ChunkOrder { + bool operator()(const std::string& chunk1, const std::string& chunk2) const; + }; + using OrderedChunkSet = absl::btree_set; + + // String conversions to support `Save` and `Restore`. 
+ static std::string SetToString(const OrderedChunkSet& s); + static OrderedChunkSet SetFromString(absl::string_view s); + + // Updates the snapshot state and available chunks. + absl::Status UpdateSnapshot(); + + // Reads the DONE or ERROR file and returns a SnapshotState indicating whether + // the snapshot is complete. + absl::StatusOr GetSnapshotState(); + + // Reads the available chunks from disk and returns a vector of chunk file + // names. + absl::StatusOr> GetAvailableChunks(); + + const std::string snapshot_path_; + tsl::Env* const env_; + + mutable absl::Mutex mu_; + + // The set of read chunks. + OrderedChunkSet chunks_read_ ABSL_GUARDED_BY(mu_); + + // The set of unread chunks. Uses an ordered set to make sure repeated reads + // produce data in a deterministic order. + OrderedChunkSet chunks_unread_ ABSL_GUARDED_BY(mu_); + + // State of the snapshot. + SnapshotState snapshot_state_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_CHUNK_PROVIDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_manager.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_manager.h new file mode 100644 index 00000000..dd3a76d6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_manager.h @@ -0,0 +1,378 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_MANAGER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_MANAGER_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/time/time.h" +#include "xla/tsl/protobuf/status.pb.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/data/service/snapshot/prefetched_split_provider.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/protobuf/snapshot.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// A helper shared among `SnapshotManager`s to limit workers' stream assignments +// across ongoing snapshots. This class is thread-safe. +class SnapshotAssignmentManager { + public: + explicit SnapshotAssignmentManager(int64_t worker_max_concurrent_snapshots) + : worker_max_concurrent_snapshots_(worker_max_concurrent_snapshots) {} + + // Tries to record the event of a worker being assigned a stream. Returns + // `false` if the worker has too many assignments. 
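A sketch of reading a finished snapshot with the SnapshotChunkProvider declared above in snapshot_chunk_provider.h. Per its comments, each returned split tensor carries the absolute path of a committed chunk file, and `GetNext` blocks while the snapshot is still being written.

    #include <cstdint>

    #include "absl/strings/string_view.h"
    #include "tensorflow/core/data/service/snapshot/snapshot_chunk_provider.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/platform/errors.h"
    #include "tensorflow/core/platform/statusor.h"
    #include "tsl/platform/env.h"

    namespace tensorflow {
    namespace data {

    // Drains the provider and counts the chunks it yields.
    absl::StatusOr<int64_t> CountChunks(absl::string_view snapshot_path,
                                        tsl::Env* env) {
      SnapshotChunkProvider provider(snapshot_path, env);
      int64_t num_chunks = 0;
      while (true) {
        Tensor split;
        bool end_of_splits = false;
        TF_RETURN_IF_ERROR(provider.GetNext(&split, &end_of_splits));
        if (end_of_splits) break;
        ++num_chunks;  // `split` holds the absolute path of a chunk file.
      }
      return num_chunks;
    }

    }  // namespace data
    }  // namespace tensorflow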
Returns an error if the + // worker is already known to have been assigned this stream. + absl::StatusOr TryAddAssignment(absl::string_view snapshot_path, + absl::string_view worker_address, + int64_t stream_index); + + // Records the event of a worker stopping work on a stream. + void RemoveAssignment(absl::string_view snapshot_path, + absl::string_view worker_address, int64_t stream_index); + + // Adds a new snapshot. + void AddSnapshot(absl::string_view snapshot_path); + + // Load balances snapshots by the number of assigned streams. Given a worker, + // returns snapshots in the following order: + // - Snapshots already assigned to this worker. + // - Snapshots with the fewest assignments. + std::vector LoadBalanceSnapshots( + absl::string_view worker_address); + + // Returns the maximum concurrent snapshots processed by each worker. + int64_t worker_max_concurrent_snapshots() const { + return worker_max_concurrent_snapshots_; + } + + private: + struct Assignment { + std::string snapshot_path; + int64_t stream_index; + + template + friend H AbslHashValue(H h, const Assignment& a) { + return H::combine(std::move(h), a.snapshot_path, a.stream_index); + } + + friend bool operator==(const Assignment& lhs, const Assignment& rhs) { + return lhs.snapshot_path == rhs.snapshot_path && + lhs.stream_index == rhs.stream_index; + } + + std::string DebugString() const { + return absl::Substitute( + "Assignment { snapshot_path: $0, stream_index: $1 }", snapshot_path, + stream_index); + } + }; + + // A mapping of worker address to ongoing assignments. + absl::flat_hash_map> assignments_ + TF_GUARDED_BY(mu_); + + // A mapping from snapshot to the number of assigned workers. + absl::flat_hash_map snapshot_assignment_counts_ + TF_GUARDED_BY(mu_); + + // The maximum number of snapshots that a worker can concurrently process at a + // given point in time. This is a tradeoff between worker resource usage and + // snapshot wall time. A value of 0 indicates that the decision should be left + // up to the runtime. + const int64_t worker_max_concurrent_snapshots_; + + mutable tsl::mutex mu_; +}; + +// A helper used by `DataServiceDispatcherImpl` to manage a call to `Snapshot`. +// +// Two mirrored states are maintained: +// - An in-memory state (objects in the `SnapshotManager` instance). +// - An on-disk state (files in the `SnapshotManager::path_`). +// +// The on-disk state has this structure: +// - snapshot_path +// - DONE +// - ERROR +// - snapshot.metadata +// - dataset_def.proto +// - dataset_spec.pb +// - chunks +// - chunk___ +// - streams +// - stream_0 +// - DONE +// - ERROR +// - splits +// - source_0 +// - split__ +// - uncommitted_chunks +// - chunk_ +// - checkpoints +// - checkpoint__ +// +class SnapshotManager { + public: + // Initiates a new snapshot process, creating a fresh in-memory state and + // writing an on-disk state to `path`. Returns an error if `path` already + // exists in the filesystem. + static absl::StatusOr> Start( + const SnapshotRequest& request, + SnapshotAssignmentManager& assignment_manager, Env* env); + // Resumes an existing snapshot process, reading from the on-disk state in + // `path` to derive an in-memory state. Returns an error if `path` is in a bad + // state. + static absl::StatusOr> Resume( + absl::string_view path, SnapshotAssignmentManager& assignment_manager, + Env* env); + + // Handles the work pertaining to this snapshot process for the respective + // `DispatcherService` API calls: + // - `WorkerHeartbeat`: Returns a stream assignment for the worker. 
+ // - `GetSnapshotSplit`: Returns a split assignment for the worker. + // - `GetSnapshotStreams`: Returns information about all streams. + absl::Status WorkerHeartbeat(const WorkerHeartbeatRequest& request, + WorkerHeartbeatResponse& response); + absl::Status GetSnapshotSplit(const GetSnapshotSplitRequest& request, + GetSnapshotSplitResponse& response); + absl::Status GetSnapshotStreams(GetSnapshotStreamsResponse& response); + + // Cancels the SnapshotManager and finishes in-progress threads. + void Cancel(); + + private: + SnapshotManager(absl::string_view path, + SnapshotAssignmentManager& assignment_manager, Env* env) + : path_(path), + env_(env), + last_progress_log_time_(absl::FromUnixMicros(env->NowMicros())), + assignment_manager_(assignment_manager) {} + + // Helpers for `Start` above. These update the on-disk state. + absl::Status Start(const SnapshotRequest& request); + absl::Status WriteOnDiskSkeleton(); + absl::Status WriteOnDiskMetadata(const SnapshotRequest& request); + + // Helpers for `Resume` above. These update the in-memory state. + absl::Status Resume(); + absl::Status ReadOnDiskMetadata(); + absl::Status ReadOnDiskStreams(); + + // Helpers for `WorkerHeartbeat` above. These may update the in-memory and + // on-disk states. + // Gets or creates a new stream. Returns the stream index and a bool value + // indicating whether a new stream has been created. Returns `std::nullopt` + // if there are no more streams to write or there is an error. + absl::StatusOr>> + MaybeGetOrCreateStreamAssignment( + absl::string_view worker_address, + const SnapshotTaskProgress* snapshot_progress); + absl::Status HandleStreamCompletion(int64_t stream_index, + absl::string_view worker_address); + void ReassignPreviouslyAssignedStream(int64_t stream_index, + absl::string_view worker_address); + std::optional MaybeAssignOrphanStream( + absl::string_view worker_address); + absl::StatusOr> MaybeCreateAndAssignNewStream( + absl::string_view worker_address); + absl::Status HandleStreamError(absl::string_view worker_address, + const StatusProto& status_proto); + + mutable tsl::mutex mu_; + // Uses a separate mutex for `GetSnapshotSplit` RPCs. `GetSnapshotSplit` uses + // file IO and may be slow, which may slow down `WorkerHeartbeat` RPCs if they + // share one mutex. + mutable tsl::mutex get_split_mu_; + + // The filepath of the on-disk state. + const std::string path_; + // A tensorflow environment interface used to write to and read from `path_`. + tsl::Env* const env_; + // Distributed snapshot metadata. + experimental::DistributedSnapshotMetadata metadata_ TF_GUARDED_BY(mu_); + // The last time progress was logged. + absl::Time last_progress_log_time_ TF_GUARDED_BY(mu_); + + // The addresses of all workers considered to be dead based on heartbeat + // timeout. + absl::flat_hash_set dead_workers_ TF_GUARDED_BY(mu_); + + struct Stream { + explicit Stream(int64_t num_sources) + : num_assigned_splits_per_source(num_sources) {} + + enum class State { + // The stream is not finished and the worker is heartbeating. + kActive, + // The stream is finished. + kDone, + }; + + // A counter of assigned splits for each source. 
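A sketch of the SnapshotAssignmentManager declared at the top of snapshot_manager.h above, assuming `TryAddAssignment` returns `absl::StatusOr<bool>` as its comment describes. The snapshot path, worker address, and concurrency limit are illustrative.

    #include "tensorflow/core/data/service/snapshot/snapshot_manager.h"
    #include "tensorflow/core/platform/statusor.h"

    namespace tensorflow {
    namespace data {

    // Registers a snapshot and tries to hand its first stream to a worker.
    // A `false` result means the worker already holds its maximum number of
    // concurrent snapshot streams.
    absl::StatusOr<bool> TryAssignFirstStream() {
      SnapshotAssignmentManager assignments(
          /*worker_max_concurrent_snapshots=*/3);
      assignments.AddSnapshot("/tmp/snapshots/example");
      return assignments.TryAddAssignment("/tmp/snapshots/example",
                                          "/worker/task/0:20000",
                                          /*stream_index=*/0);
    }

    }  // namespace data
    }  // namespace tensorflow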
+ std::vector num_assigned_splits_per_source; + + int64_t num_assigned_splits() const { + return absl::c_accumulate(num_assigned_splits_per_source, 0); + } + + State state = State::kActive; + }; + + struct Source { + Source(std::unique_ptr split_provider, + int64_t repetition_index, int64_t cardinality) + : split_provider(std::move(split_provider)), + repetition_index(repetition_index), + cardinality(cardinality) {} + + // A split provider for each input source of the dataset being snapshotted. + std::unique_ptr split_provider; + // The number of times the split provider has repeated. + int64_t repetition_index = 0; + // The number of splits in `split_provider`. + const int64_t cardinality; + }; + + // Helper class to restore a stream. Multiple stream restorers are safe to run + // in parallel. After it reads the on-disk stream, the client is responsible + // to apply the data to actually restore its internal states. + class StreamRestorer { + public: + explicit StreamRestorer(tsl::Env* env, absl::string_view path, + int64_t stream_index, int64_t num_sources, + SnapshotAssignmentManager& assignment_manager) + : env_(env), + path_(path), + stream_index_(stream_index), + num_sources_(num_sources), + assignment_manager_(assignment_manager) {} + + // Reads snapshot stream from the files and collects data for restoration. + absl::Status ReadOnDiskStream(); + + // Accessors for collected data. Should be called *after* `ReadOnDiskStream` + // is called. + const std::optional& GetStream() const { return restored_stream_; } + int64_t StreamIndex() const { return stream_index_; } + const std::string& WorkerAddress() const { return worker_address_; } + const absl::flat_hash_set& GlobalSplitIndices() const { + return global_split_indices_; + } + + private: + absl::StatusOr OwnerWorkerAddress() const; + absl::Status ReadOnDiskSource(int64_t source_index); + absl::Status ReadOnDiskSplit(int64_t source_index, + const std::vector& split_files, + const std::string& split_file); + absl::Status SkipSplit(SplitProvider& split_provider); + + tsl::Env* const env_; + const std::string path_; + const int64_t stream_index_; + const int64_t num_sources_; + SnapshotAssignmentManager& assignment_manager_; + + std::string worker_address_; + std::optional restored_stream_; + absl::flat_hash_set global_split_indices_; + }; + + // Applies the data collected by `stream_restorer` to actually restore the + // snapshot manager. + absl::Status RestoreFrom( + const StreamRestorer& stream_restorer, + const std::vector& stream_directories, + std::vector>& split_providers, + std::vector& repetition_indices, + absl::flat_hash_set& global_split_indices); + + // Gets the snapshot stream. + Stream& GetStream(int64_t stream_index); + // Initializes the stream directory. + absl::Status InitStreamDirectory( + int64_t stream_index, const std::string& worker_address, + const std::vector& repetitions_per_source); + + std::vector sources_ TF_GUARDED_BY(mu_); + // Creates sources for the specified dataset. + absl::StatusOr> CreateSources( + const DatasetDef& dataset_def) const; + // Returns the total number of splits. + absl::StatusOr GetSplitsCardinality(); + // Resets a source when it runs out of splits, to support repetitions. + absl::Status ResetSource(Source& source, int64_t source_index); + int64_t num_sources() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return sources_.size(); + } + + // All streams for this snapshot. + absl::btree_map streams_ TF_GUARDED_BY(mu_); + // A counter of completed streams for this snapshot. 
+ int64_t num_completed_streams_ TF_GUARDED_BY(mu_) = 0; + + // A mapping of worker to assigned stream index for this snapshot. + absl::flat_hash_map assignments_ TF_GUARDED_BY(mu_); + // A mapping of worker to assigned streams for all snapshots. + SnapshotAssignmentManager& assignment_manager_ TF_GUARDED_BY(mu_); + + // A counter of assigned splits for this snapshot. + int64_t num_assigned_splits_ TF_GUARDED_BY(mu_) = 0; + // The number of splits in a single repetition of the data in `sources_`. + int64_t num_total_splits_ TF_GUARDED_BY(mu_) = 0; + + enum class Mode { + // No streams are done. + kActive, + // At least one source is fully processed, but not all streams are done. + kWindingDown, + // All streams are done. + kDone, + // If any stream fails, the snapshot is in an error state. `status_` will + // contain the error status. + kError, + }; + + // If not `kActive`, at least one source has finished processing and no new + // streams are created or assigned. + Mode mode_ TF_GUARDED_BY(mu_) = Mode::kActive; + + // If `mode_` is in an error state, `status_` will contain the error status. + absl::Status status_ TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_split_provider.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_split_provider.h new file mode 100644 index 00000000..b5ca603e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_split_provider.h @@ -0,0 +1,106 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_SPLIT_PROVIDER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_SPLIT_PROVIDER_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/data/service/dispatcher_client.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// Split provider that supports writing distributed snapshots. 
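+//
+// For example, a worker-side sketch (the snapshot task, addresses, and timeout
+// below are illustrative placeholders, and the dispatcher-client construction
+// is an assumption; the split-provider arguments follow the constructor
+// declared below):
+//
+//   auto dispatcher = std::make_unique<DataServiceDispatcherClient>(
+//       /*address=*/"localhost:5000", /*protocol=*/"grpc");
+//   SnapshotSplitProvider split_provider(
+//       /*worker_address=*/"localhost:5001", snapshot_task,
+//       /*source_index=*/0, /*timeout=*/absl::Minutes(1),
+//       std::move(dispatcher), Env::Default());
+//   Tensor split;
+//   bool end_of_splits = false;
+//   absl::Status status = split_provider.GetNext(&split, &end_of_splits);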
+class SnapshotSplitProvider : public SplitProvider { + public: + SnapshotSplitProvider(const std::string& worker_address, + const SnapshotTaskDef& snapshot_task, + int64_t source_index, absl::Duration timeout, + std::unique_ptr dispatcher, + Env* env); + + absl::Status GetNext(Tensor* split, bool* end_of_splits) override; + absl::Status Reset() override; + absl::Status Save(std::function full_name, + IteratorStateWriter* writer) override; + absl::Status Restore(std::function full_name, + IteratorStateReader* reader) override; + + private: + const std::string worker_address_; + const SnapshotTaskDef snapshot_task_; + const int64_t source_index_; + const absl::Duration timeout_; + Env* const env_; + + // Gets the next split from file or dispatcher and validates it. + absl::Status GetAndValidateSplit(Tensor* split, bool* end_of_splits); + + // Gets the next split by reading from the splits directory. + absl::Status GetSplitFromFile(const std::string& split_file, Tensor* split, + bool* end_of_splits); + + // Gets the next split by sending an RPC to the dispatcher. Returns the local + // split index from the dispatcher. + absl::StatusOr GetSplitFromDispatcher(Tensor* split, + bool* end_of_splits); + + // Reads from the split directory and returns a map of split index to absolute + // file path of the split, starting at `start_index`. + absl::StatusOr> GetSplitsFiles( + int64_t start_index) const; + + // Verifies `split_files` contains consecutive splits starting at + // `start_index`. + absl::Status ValidateSplitFiles( + const absl::btree_map& split_files, + int64_t start_index) const; + + // Verifies `split_files` contains consecutive splits starting at + // `start_index` and ending at `end_index`. + absl::Status ValidateSplitFiles( + const absl::btree_map& split_files, + int64_t start_index, int64_t end_index, bool end_of_splits) const; + + mutable mutex mu_; + std::unique_ptr dispatcher_ TF_GUARDED_BY(mu_); + + // The next split to read. + int64_t next_split_index_ TF_GUARDED_BY(mu_) = 0; + + // Number of times the dataset has repeated. + int64_t repetition_index_ TF_GUARDED_BY(mu_) = 0; + + // Maps the local split index to the absolute split file path. + absl::btree_map split_to_file_map_ TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_SPLIT_PROVIDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h new file mode 100644 index 00000000..09d72d86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/snapshot_stream_writer.h @@ -0,0 +1,245 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_STREAM_WRITER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_STREAM_WRITER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/substitute.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "tensorflow/core/data/service/byte_size.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/snapshot/parallel_tfrecord_writer.h" +#include "tensorflow/core/data/service/snapshot/path_utils.h" +#include "tensorflow/core/data/service/task_runner.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/protobuf/service_config.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +constexpr ByteSize kDefaultMaxChunkSize = ByteSize::GB(6); +constexpr absl::Duration kDefaultCheckpointInterval = absl::Minutes(30); + +struct SnapshotWriterParams { + // The directory path of the snapshot. See the comment on SnapshotStreamWriter + // for how the directory is structured. + std::string snapshot_path; + + // The index of the snapshot stream. A stream is one shard of the snapshot + // processed by a worker. + int64_t stream_index = 0; + + // Compression method as defined in tsl/lib/io/compression.h. + std::string compression; + + // The Tensorflow environment. + Env* env = nullptr; + + // The maximum number of bytes in each chunk. + ByteSize max_chunk_size = kDefaultMaxChunkSize; + + // How often should checkpoints be written at the steady state. We write + // checkpoints (and committing chunks) more frequently at the startup time to + // avoid starving training jobs during startup. + absl::Duration checkpoint_interval = kDefaultCheckpointInterval; + + // If true, keep temporary files (e.g., checkpoints) after completing the + // snapshot. Used only for unit testing. + bool test_only_keep_temp_files = false; + + std::string StreamDirectory() const { + return tensorflow::data::StreamDirectory(snapshot_path, stream_index); + } + + std::string CommittedChunksDirectory() const { + return tensorflow::data::CommittedChunksDirectory(snapshot_path); + } + + std::string UncommittedChunksDirectory() const { + return tensorflow::data::UncommittedChunksDirectory(snapshot_path, + stream_index); + } + + std::string CheckpointsDirectory() const { + return tensorflow::data::CheckpointsDirectory(snapshot_path, stream_index); + } + + std::string DebugString() const { + return absl::Substitute( + "SnapshotWriterParams { base_path: $0, stream: $1, compression: $2 }", + snapshot_path, stream_index, compression); + } +}; + +// Responsible for writing one snapshot stream, which is organized as following: +// +// - snapshot +// - DONE +// - ERROR +// - snapshot.metadata +// - dataset_def.proto +// - chunks +// - chunk___ +// - streams +// - stream_0 +// - DONE +// - ERROR +// - splits +// - split__ +// - uncommitted chunks +// - chunk_ +// - checkpoints +// - checkpoint__ +// +// This class is thread-safe. +class SnapshotStreamWriter { + public: + // Creates a SnapshotStreamWriter. Once created, it will start writing the + // snapshot stream. Users can call `Wait` to wait for it to finish. 
+ explicit SnapshotStreamWriter(const SnapshotWriterParams& params, + std::unique_ptr iterator); + virtual ~SnapshotStreamWriter() = default; + SnapshotStreamWriter(const SnapshotStreamWriter&) = delete; + SnapshotStreamWriter& operator=(const SnapshotStreamWriter&) = delete; + + // Returns true if the snapshot stream has completed. A snapshot stream is + // completed if the dataset has reached the end of sequence and a DONE file is + // written. Returns an error if the snapshot has failed. This does not block + // the caller. + absl::StatusOr Completed() const; + + // Waits for the writer to finish writing the snapshot stream and returns the + // final status. + absl::StatusOr Wait(); + + // Cancels the writer. If cancelled, `Wait` will return a Cancelled error. + void Cancel(); + + private: + // Writes the snapshot and any debugging log when necessary. + void WriteSnapshotAndLog(); + + // Writes the snapshot. Returns an error if writing fails or the task has been + // cancelled. + absl::Status WriteSnapshot(); + + // Returns true if the stream is already completed and there is no additional + // work to perform. + bool StreamAlreadyCompleted() const; + + // Creates directories to store uncommitted chunks and checkpoints. + absl::Status InitializeDirectories(); + + // Returns true until the snapshot stream writer is finished, which may be due + // to reaching the end of its iterator, encountering an error, or being + // cancelled. + bool ShouldWriteChunks() const; + + // Writes the chunk files. + absl::Status WriteChunks(); + + // Returns true if it should write more records to the current chunks. Returns + // false if it should checkpoint and commit the current chunks, there are no + // more records to write, or there is an error. + bool ShouldWriteRecord() const; + + // Writes the next record to the current chunks. + absl::Status WriteRecord(ParallelTFRecordWriter& writer); + + // Commits the chunks since the last commit. + absl::Status Commit(const ParallelTFRecordWriter::FileToStatsMap& file_stats); + + // Writes a DONE file when the stream is finished. Writes an ERROR file if it + // failed. + absl::Status FinalizeStream(absl::Status status); + absl::Status WriteDoneFile(); + absl::Status WriteErrorFile(const absl::Status& status); + + // Saves an iterator checkpoint. + absl::Status Save(const ParallelTFRecordWriter::FileToStatsMap& file_stats); + + // After committing a checkpoint, deletes the previous checkpoints. + absl::Status DeleteOutdatedCheckpoints(int64_t checkpoint_index); + + // Deletes all checkpoints. + absl::Status DeleteCheckpoints(); + + // Restores from the last checkpoint. + absl::Status Restore(); + + // Returns the filename of the most recent checkpoint. + absl::StatusOr LastCheckpointName() const; + + // Synchronizes the checkpoint with the committed chunks. This is called when + // the worker restores the snapshot in case the worker fails after writing the + // checkpoint but before committing a chunk file. If no checkpoint has been + // written, `checkpoint_index` is nullopt. + absl::Status SyncCheckpointWithChunks(std::optional checkpoint_index, + int64_t checkpoint_num_elements); + + // Index of the last committed chunk. + absl::StatusOr LastCommittedChunkIndex(); + + // Returns the path of the checkpoint for `chunk_index` with + // `chunk_num_elements`. + std::string CheckpointPath(int64_t chunk_index, + int64_t chunk_num_elements) const; + + // Returns the path of the checkpoint for `checkpoint_name`. 
+ std::string CheckpointPath(const std::string& checkpoint_name) const; + + const SnapshotWriterParams params_; + + // The dataset iterator that produces the dataset elements. + std::unique_ptr iterator_; + + // Index of the next chunk to write. + int64_t chunk_index_ = 0; + // Timestamp when the last chunks are committed. + absl::Time last_commit_time_ = absl::Now(); + + // True if the dataset is exhausted. + bool end_of_sequence_ = false; + + mutable mutex mu_; + + // Whether the writer is completed: + // - If the snapshot is successful, this is true. + // - If any error happens during the snapshot write, it is the error status. + // - If the snapshot has not finished, this is false. + absl::StatusOr completed_ TF_GUARDED_BY(mu_) = false; + + std::unique_ptr snapshot_thread_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_SNAPSHOT_STREAM_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/test_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/test_utils.h new file mode 100644 index 00000000..efa31121 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/test_utils.h @@ -0,0 +1,125 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_TEST_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_TEST_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "tensorflow/core/data/service/byte_size.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/snapshot/file_utils.h" +#include "tensorflow/core/data/service/snapshot/path_utils.h" +#include "tensorflow/core/data/service/task_runner.h" +#include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tsl/platform/env.h" +#include "tsl/platform/path.h" + +namespace tensorflow { +namespace data { +namespace testing { + +// Reads the records from a distributed tf.data snapshot written at `base_path`. 
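+//
+// For example, in a test (the snapshot path is an illustrative placeholder and
+// an uncompressed snapshot is assumed):
+//
+//   absl::StatusOr<std::vector<int64_t>> records =
+//       ReadSnapshot<int64_t>("/tmp/test_snapshot", /*compression=*/"");
+//   if (records.ok()) {
+//     // `*records` holds the int64 elements read from all committed chunks.
+//   }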
+template +absl::StatusOr> ReadSnapshot(const std::string& base_path, + const std::string& compression) { + std::vector result; + std::string chunks_directory = CommittedChunksDirectory(base_path); + TF_ASSIGN_OR_RETURN(std::vector chunk_files, + GetChildren(chunks_directory, Env::Default())); + for (const std::string& chunk_file : chunk_files) { + std::string chunk_file_path = + tsl::io::JoinPath(chunks_directory, chunk_file); + snapshot_util::TFRecordReader tfrecord_reader(chunk_file_path, compression, + DataTypeVector{DT_INT64}); + TF_RETURN_IF_ERROR(tfrecord_reader.Initialize(Env::Default())); + + while (true) { + std::vector tensors; + absl::Status status = tfrecord_reader.ReadTensors(&tensors); + if (absl::IsOutOfRange(status)) { + break; + } + TF_RETURN_IF_ERROR(status); + result.push_back(tensors[0].unaligned_flat().data()[0]); + } + } + return result; +} + +// Writes a partial snapshot to test checkpointing and recovering. It can be +// used to write the specified committed chunks, uncommitted chunks, and +// checkpoints. +class PartialSnapshotWriter { + public: + static absl::StatusOr Create( + const DatasetDef& dataset, const std::string& snapshot_path, + int64_t stream_index, const std::string& compression, + ByteSize max_chunk_size = ByteSize::Bytes(1), + absl::Duration checkpoint_interval = absl::Microseconds(1)); + virtual ~PartialSnapshotWriter() = default; + PartialSnapshotWriter(const PartialSnapshotWriter&) = delete; + PartialSnapshotWriter& operator=(const PartialSnapshotWriter&) = delete; + PartialSnapshotWriter(PartialSnapshotWriter&&) = default; + PartialSnapshotWriter& operator=(PartialSnapshotWriter&&) = delete; + + // Writes the specified chunks. + absl::Status WriteCommittedChunks( + const absl::flat_hash_set& committed_chunk_indexes) const; + + // Writes the specified uncommitted chunks. + absl::Status WriteUncommittedChunks( + const absl::flat_hash_set& uncommitted_chunk_indexes) const; + + // Writes the specified checkpoints. + absl::Status WriteCheckpoints( + const absl::flat_hash_set& checkpoint_indexes) const; + + private: + PartialSnapshotWriter(const DatasetDef& dataset, + const std::string& snapshot_path, int64_t stream_index, + const std::string& compression, ByteSize max_chunk_size, + absl::Duration checkpoint_interval); + + absl::Status Initialize(); + + const DatasetDef dataset_; + const std::string snapshot_path_; + const int64_t stream_index_; + const std::string compression_; + const ByteSize max_chunk_size_; + const absl::Duration checkpoint_interval_; + + std::string tmp_snapshot_path_; +}; + +// Creates a test iterator for the input dataset. The iterator will generate all +// elements of the dataset. +absl::StatusOr> TestIterator( + const DatasetDef& dataset_def); + +} // namespace testing +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/utils.h new file mode 100644 index 00000000..1ea4d80b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/snapshot/utils.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_UTILS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/data/service/byte_size.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/status.h" + +namespace tensorflow { +namespace data { + +// Estimates the size of the Tensors when serialized as TensorProtos. +ByteSize EstimatedSize(const std::vector& tensors); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SNAPSHOT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/split_provider.h b/third_party/tflite-hdrs/tensorflow/core/data/service/split_provider.h new file mode 100644 index 00000000..c426fe1a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/split_provider.h @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/dispatcher_client.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// SplitProvider which reads splits from a tf.data service dispatcher over RPC. 
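+//
+// For example (the dispatcher address, iteration ID, and timeout below are
+// illustrative placeholders):
+//
+//   DataServiceSplitProvider split_provider(
+//       /*address=*/"localhost:5000", /*protocol=*/"grpc",
+//       /*iteration_id=*/42, /*split_provider_index=*/0,
+//       /*timeout_ms=*/60 * 1000);
+//   Tensor split;
+//   bool end_of_splits = false;
+//   while (split_provider.GetNext(&split, &end_of_splits).ok() &&
+//          !end_of_splits) {
+//     // Process `split`.
+//   }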
+class DataServiceSplitProvider : public SplitProvider { + public: + DataServiceSplitProvider(const std::string& address, + const std::string& protocol, int64_t iteration_id, + int64_t split_provider_index, int64_t timeout_ms) + : address_(address), + protocol_(protocol), + iteration_id_(iteration_id), + split_provider_index_(split_provider_index), + timeout_ms_(timeout_ms) {} + + absl::Status GetNext(Tensor* split, bool* end_of_splits) override; + absl::Status Reset() override; + absl::Status Save(std::function full_name, + IteratorStateWriter* writer) override; + absl::Status Restore(std::function full_name, + IteratorStateReader* reader) override; + + private: + const std::string address_; + const std::string protocol_; + const int64_t iteration_id_; + const int64_t split_provider_index_; + const int64_t timeout_ms_; + + mutex mu_; + int64_t repetition_ TF_GUARDED_BY(mu_) = 0; + std::unique_ptr dispatcher_ TF_GUARDED_BY(mu_); +}; + +// Makes split providers for `dataset_def` and stores them in `split_providers`. +absl::Status CreateSplitProviders( + const DatasetDef& dataset_def, + std::vector>& split_providers); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_SPLIT_PROVIDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/task_remover.h b/third_party/tflite-hdrs/tensorflow/core/data/service/task_remover.h new file mode 100644 index 00000000..1daf6306 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/task_remover.h @@ -0,0 +1,54 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_REMOVER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_TASK_REMOVER_H_ + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace data { + +// A `TaskRemover` maintains state about a single task and decides whether the +// task should be removed. +class TaskRemover { + public: + explicit TaskRemover(int64_t num_consumers); + + // Attempts to remove the task. The task is removed when all consumers + // concurrently reach a barrier in this method. + // Returns true if the task is successfully removed. + // Returns false if either: + // - There is a timeout waiting for other consumers to request task removal. + // This timeout is hardcoded into the implementation. + // - Another consumer requests removal at a different round. + bool RequestRemoval(int64_t consumer_index, int64_t round); + + private: + const int64_t num_consumers_; + mutex mu_; + condition_variable cv_; + // The round we are considering removing the task in. + int64_t round_ TF_GUARDED_BY(mu_); + bool removed_ TF_GUARDED_BY(mu_) = false; + // Consumers currently blocked in RequestRemoval. 
+ absl::flat_hash_set consumers_waiting_ TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_TASK_REMOVER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/task_runner.h b/third_party/tflite-hdrs/tensorflow/core/data/service/task_runner.h new file mode 100644 index 00000000..79d698f9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/task_runner.h @@ -0,0 +1,307 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ + +#include +#include +#include + +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/cross_trainer_cache.h" +#include "tensorflow/core/data/service/data_transfer.h" +#include "tensorflow/core/data/service/thread_safe_buffer.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/data/standalone.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// Iterator over a task's elements. +class TaskIterator { + public: + virtual ~TaskIterator() = default; + // If the iterator is not yet exhausted, `GetNext` stores the next element in + // `element` and sets `end_of_sequence` to `false`. Otherwise, sets + // `end_of_sequence to `true`. + virtual absl::Status GetNext(std::vector& element, + bool& end_of_sequence) = 0; + // Reports the cardinality of the dataset that created this iterator. + virtual int64_t Cardinality() const = 0; + + // Saves a checkpoint of the iterator. Returns Tensors that can be called with + // `Restore()`. + virtual absl::StatusOr> Save() { + return errors::Unimplemented( + "Serializing a tf.data service task iterator is unsupported."); + } + + // Restores the iterator from a checkpoint. `saved_iterator` is the serialized + // iterator saved by calling `Save()`. + virtual absl::Status Restore(const std::vector& saved_iterator) { + return errors::Unimplemented( + "Restoring from a tf.data service task iterator is unsupported."); + } + + // Returns the dataset model for performance analysis. + virtual std::shared_ptr model() const { return nullptr; } +}; + +// Implementation of TaskIterator wrapping a standalone iterator. +class StandaloneTaskIterator : public TaskIterator { + public: + // `dataset` should be the dataset that created `iterator`. + // StandaloneTaskIterator takes ownership of the dataset to ensures it + // lives as long as `iterator`. 
+ StandaloneTaskIterator(std::unique_ptr dataset, + std::unique_ptr iterator); + absl::Status GetNext(std::vector& element, + bool& end_of_sequence) override; + int64_t Cardinality() const override; + absl::StatusOr> Save() override; + absl::Status Restore(const std::vector& saved_iterator) override; + std::shared_ptr model() const override; + + private: + std::unique_ptr dataset_; + std::unique_ptr iterator_; +}; + +// Interface for providing elements to task consumers. +class TaskRunner { + public: + // Creates a `TaskRunner` and stores it in `out`. + static absl::Status Create(const experimental::WorkerConfig& worker_config, + const TaskDef& task_def, + std::unique_ptr iterator, + std::unique_ptr& out); + virtual ~TaskRunner() = default; + // Gets the next element for the given request. + virtual absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) = 0; + // Cancels in-progress `GetNext` requests. + virtual void Cancel() = 0; + // Returns the dataset model for performance analysis. + virtual std::shared_ptr model() const = 0; +}; + +// A task runner which provides elements on a first-come first-served basis. +// It does not consider which consumer is making the request. +class FirstComeFirstServedTaskRunner : public TaskRunner { + public: + explicit FirstComeFirstServedTaskRunner( + std::unique_ptr iterator); + ~FirstComeFirstServedTaskRunner() override; + + // Gets the next element. It may block if the element is not ready yet. + absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) override; + absl::Status GetNext(GetElementResult& result); + + void Cancel() override; + + std::shared_ptr model() const override; + + private: + // Function to continually prefetch the next element. Returns an error if the + // task has been cancelled. + absl::Status PrefetchFn(); + + // Runs `PrefetchFn` on a dedicated thread. + void RunPrefetchThread(); + + // Gets the next element from the input iterator. + absl::StatusOr GetNextFromInputIterator() + TF_LOCKS_EXCLUDED(mu_); + + const std::shared_ptr model_; + mutex mu_; + std::unique_ptr iterator_ TF_GUARDED_BY(mu_); + int64_t element_index_ TF_GUARDED_BY(mu_) = 0; + + ThreadSafeBuffer buffer_; + std::unique_ptr prefetch_thread_; + + FirstComeFirstServedTaskRunner(const FirstComeFirstServedTaskRunner&) = + delete; + void operator=(const FirstComeFirstServedTaskRunner&) = delete; +}; + +// A task runner which prefetches elements on a first-come first-served basis +// and caches elements in a sliding-window `CrossTrainerCache`. The cache has a +// bounded size and progresses when a trainer that has consumed all elements in +// the cache. Trainers read from a sliding window of the dataset and may not +// read the full dataset. +class CachingTaskRunner : public TaskRunner { + public: + explicit CachingTaskRunner(std::unique_ptr iterator, + size_t max_cache_size_bytes); + ~CachingTaskRunner() override; + + // Gets the next element from the cross-trainer cache, blocking if the data is + // not ready. + // REQUIRES: !req.trainer_id().empty() + absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) override; + + // Cancel the task runner. After cancelling, all the `GetNext` calls will + // return a Cancelled status. + void Cancel() override; + + // Returns the dataset model for performance analysis. + std::shared_ptr model() const override; + + private: + // The `GetElementResultSequence` generates a sequence of elements from the + // `FirstComeFirstServedTaskRunner`. 
It is used for the `CrossTrainerCache` to + // generate cached elements. + class GetElementResultSequence : public CachableSequence { + public: + explicit GetElementResultSequence( + FirstComeFirstServedTaskRunner& fcfs_task_runner); + absl::StatusOr GetNext() override; + size_t GetElementSizeBytes(const GetElementResult& element) const override; + + private: + FirstComeFirstServedTaskRunner& fcfs_task_runner_; + }; + + FirstComeFirstServedTaskRunner fcfs_task_runner_; + CrossTrainerCache cache_; + + CachingTaskRunner(const CachingTaskRunner&) = delete; + void operator=(const CachingTaskRunner&) = delete; +}; + +// An element produced by a task. +struct Element { + explicit Element(std::vector&& components, int64_t index) + : components(components), index(index) {} + // The components of the element. + std::vector components; + // The element's index within the task, e.g. 0 for the first element produced + // by the task, 1 for the second element, etc. + int64_t index; +}; + +// Thread for prefetching a round worth of elements. +class PrefetchThread { + public: + explicit PrefetchThread(std::unique_ptr iterator, + int64_t round_size); + ~PrefetchThread(); + // Runs the prefetch thread. It runs until an error is encountered or the + // destructor is called. + void Run(); + // Fills `out` with a round of data. Waits for up to `wait_us` microseconds + // before giving up and returning with `out` empty. A negative `wait_us` + // signals to wait indefinitely. + absl::Status FillBuffer(int64_t wait_us, + std::vector>& out); + // Returns the status for any failures encountered by the prefetch thread. + absl::Status GetStatus(); + // Returns the dataset model for performance analysis. + std::shared_ptr model() const; + + private: + const std::unique_ptr iterator_; + const int64_t round_size_; + mutex mu_; + int64_t index_ TF_GUARDED_BY(mu_) = 0; + // Buffered results for the next round. + std::vector> buffer_ TF_GUARDED_BY(mu_); + // The status if the prefetch thread fails. + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + // Condition variable notified when elements are added to or removed from + // `buffer_`, or when `status_` is changed. + condition_variable cv_; + bool cancelled_ TF_GUARDED_BY(mu_) = false; + // Thread which constantly tries to fill `buffer_` up with + // `num_consumers` elements. + std::unique_ptr thread_; +}; + +// A task runner which enforces round-robin order for consuming a task's +// elements. `RoundRobinTaskRunner` provides elements in a series of "rounds". +// In each successive round, the runner waits to receive requests from all +// consumers. These requests are blocked until all requests arrive. Once all +// requests arrive, the runner hands out elements to consumers in order of their +// consumer indices. +// +// Consumers are expected to successively request consecutive element indices, +// starting at 0. The same element can be requested multiple times by the same +// consumer, as long as the consumer hasn't yet requested the next element (at +// the start of each round we discard elements from the previous round). +// +// If the worker restarts mid-round, a situation arises where some consumers +// are requesting element index `n` while others are requesting element index +// `n + 1`. To remedy this, the first round after restart may be a partial +// round, where we only serve elements to consumers requesting data for element +// index `n`, blocking other consumers until the second round. 
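+//
+// For example, a sketch of one consumer driving the protocol above (it assumes
+// `GetElementRequest` exposes `consumer_index` and `round_index` setters; the
+// iterator and worker address are placeholders):
+//
+//   RoundRobinTaskRunner runner(std::move(iterator), /*num_consumers=*/2,
+//                               /*worker_address=*/"localhost:5001");
+//   GetElementRequest req;
+//   req.set_consumer_index(0);
+//   req.set_round_index(0);
+//   GetElementResult result;
+//   // Blocks until every consumer has requested an element for round 0.
+//   absl::Status status = runner.GetNext(req, result);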
+class RoundRobinTaskRunner : public TaskRunner { + public: + RoundRobinTaskRunner(std::unique_ptr iterator, + int64_t num_consumers, string worker_address); + + absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) override; + void Cancel() override; + std::shared_ptr model() const override; + + private: + // Prepares a full round of data. `wait_us` indicates how long to wait before + // skipping if a full round of data is not yet ready. + absl::Status PrepareFullRound(int64_t wait_us) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Prepares a partial round to get consumers back in sync. + absl::Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status ValidateRequest(const GetElementRequest& req); + // Prepares data for the next round, blocking until the round is ready to + // start. + absl::Status PrepareRound(const GetElementRequest& req); + const int64_t num_consumers_; + const string worker_address_; + mutex mu_; + bool cancelled_ TF_GUARDED_BY(mu_) = false; + // Condition variable notified whenever we start a new round of round-robin. + condition_variable new_round_cv_; + // Outstanding requests, indexed by round number and then consumer index. + absl::flat_hash_map> + requests_ TF_GUARDED_BY(mu_); + // Index of the first round we plan to serve. At startup, this is the minimum + // of all requested element indices. + int64_t first_round_ TF_GUARDED_BY(mu_) = kint64max; + int64_t current_round_ TF_GUARDED_BY(mu_) = -1; + bool round_skipped_ TF_GUARDED_BY(mu_) = false; + // Buffered results for the current round. + std::vector> buffer_ TF_GUARDED_BY(mu_); + // Thread which constantly tries to prepare `num_consumers` elements for the + // next round. + PrefetchThread prefetch_thread_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/test_cluster.h b/third_party/tflite-hdrs/tensorflow/core/data/service/test_cluster.h new file mode 100644 index 00000000..b1d242fe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/test_cluster.h @@ -0,0 +1,288 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/data_transfer.h" +#include "tensorflow/core/data/service/dispatcher.pb.h" +#include "tensorflow/core/data/service/dispatcher_client.h" +#include "tensorflow/core/data/service/export.pb.h" +#include "tensorflow/core/data/service/server_lib.h" +#include "tensorflow/core/data/service/test_util.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/data/service/worker_client.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/data_service.pb.h" + +namespace tensorflow { +namespace data { + +// Helper class for unit testing a tf.data service cluster. +class TestCluster { + public: + struct Config { + public: + int num_workers = 3; + int64_t client_timeout_ms = 0; + int64_t worker_heartbeat_interval_ms = 0; + int64_t job_gc_check_interval_ms = 0; + int64_t job_gc_timeout_ms = 0; + int64_t worker_max_concurrent_snapshots = 0; + std::string work_dir; + }; + + // Creates a new test cluster with a dispatcher and `num_workers` workers. + explicit TestCluster( + int num_workers, + std::optional data_transfer_protocol = std::nullopt); + explicit TestCluster(const Config& config); + virtual ~TestCluster(); + + // Initializes the test cluster. This must be called before interacting with + // the cluster. Initialize should be called only once. + absl::Status Initialize(); + // Adds a new worker to the cluster. + absl::Status AddWorker( + std::optional port = std::nullopt, + std::optional data_transfer_protocol = std::nullopt); + // Returns the number of workers in this cluster. + size_t NumWorkers() const { return workers_.size(); } + // Returns the port number of a worker. + int WorkerBoundPort(size_t worker_index) const { + return workers_[worker_index]->BoundPort(); + } + // Returns the number of active iterations. + absl::StatusOr NumActiveIterations() const { + return dispatcher_->NumActiveIterations(); + } + // Returns the dispatcher address in the form "hostname:port". + std::string DispatcherAddress() const; + // Returns the address of the worker at the specified index, in the form + // "hostname:port". The index must be non-negative and less than the number of + // workers in the cluster. + std::string WorkerAddress(int index) const; + + // Stops one worker. + void StopWorker(size_t index); + // Stops all workers. + void StopWorkers(); + + // Returns the server state exports. + ServerStateExport ExportDispatcherState() const; + ServerStateExport ExportWorkerState(size_t index) const; + + private: + bool initialized_ = false; + int num_workers_; + std::optional data_transfer_protocol_; + Config config_; + std::unique_ptr dispatcher_; + std::string dispatcher_address_; + std::vector> workers_; + std::vector worker_addresses_; +}; + +// A test utility to provide a `DatasetDef` to a `TestCluster` and generate data +// from each worker for verification. 
For example: +// +// TestCluster cluster(/*num_workers=*/2); +// TF_ASSERT_OK(cluster.Initialize()); +// DatasetClient dataset_reader(cluster); +// +// EXPECT_THAT( +// dataset_reader.Read(RangeDataset(4), ProcessingModeDef::DATA, +// TARGET_WORKERS_LOCAL), +// IsOkAndHolds(UnorderedElementsAre( +// Pair(cluster.WorkerAddress(0), ElementsAre(0, 2)), +// Pair(cluster.WorkerAddress(1), ElementsAre(1, 3))))); +template +class DatasetClient { + public: + // Creates a dataset client. It will process datasets in `cluster`. + explicit DatasetClient(const TestCluster& cluster); + + // Registers the dataset and returns the dataset ID. + absl::StatusOr RegisterDataset(const DatasetDef& dataset); + + // Maps a worker address to the data it produces when calling `Read`. + using WorkerResultMap = absl::flat_hash_map>; + + // Processes `dataset` and retrieves the data from workers. Returns the data + // produced by each worker, keyed by the worker address. + StatusOr Read( + const DatasetDef& dataset, + ProcessingModeDef::ShardingPolicy sharding_policy, + TargetWorkers target_workers); + // Creates an iteration and returns the iteration client ID. + absl::StatusOr CreateIteration(const DatasetDef& dataset); + // Gets the tasks for iteration `iteration_client_id`. The iteration has one + // task processed by every worker. + absl::StatusOr> GetTasks(int64_t iteration_client_id); + + private: + // Creates an iteration and returns the iteration client ID. + absl::StatusOr CreateIteration( + const std::string& dataset_id, + ProcessingModeDef::ShardingPolicy sharding_policy, + TargetWorkers target_workers); + // Reads values from `tasks`, one task at a time, until all tasks have + // finished. + StatusOr ReadFromTasks(const std::vector& tasks); + // Reads the next element from the specified task. 
+ absl::StatusOr ReadFromTask(const TaskInfo& task_info); + + const TestCluster& cluster_; + std::unique_ptr dispatcher_client_; + absl::flat_hash_map> + worker_clients_; +}; + +template +DatasetClient::DatasetClient(const TestCluster& cluster) + : cluster_(cluster) { + dispatcher_client_ = std::make_unique( + cluster_.DispatcherAddress(), "grpc"); + + for (size_t i = 0; i < cluster.NumWorkers(); ++i) { + worker_clients_[cluster_.WorkerAddress(i)] = + std::make_unique( + cluster_.WorkerAddress(i), /*protocol=*/"grpc", + /*transfer_protocol=*/"grpc", + /*fall_back_to_grpc_at_get_element_time=*/true, + /*accelerator_device_info=*/nullptr, /*allocator=*/nullptr); + } +} + +template +StatusOr::WorkerResultMap> DatasetClient::Read( + const DatasetDef& dataset, + ProcessingModeDef::ShardingPolicy sharding_policy, + TargetWorkers target_workers) { + TF_ASSIGN_OR_RETURN(const std::string dataset_id, RegisterDataset(dataset)); + TF_ASSIGN_OR_RETURN( + const int64_t iteration_client_id, + CreateIteration(dataset_id, sharding_policy, target_workers)); + TF_ASSIGN_OR_RETURN(const std::vector tasks, + GetTasks(iteration_client_id)); + return ReadFromTasks(tasks); +} + +template +absl::StatusOr DatasetClient::RegisterDataset( + const DatasetDef& dataset) { + std::string dataset_id; + TF_RETURN_IF_ERROR(dispatcher_client_->RegisterDataset( + dataset, DataServiceMetadata(), /*requested_dataset_id=*/std::nullopt, + dataset_id)); + return dataset_id; +} + +template +absl::StatusOr DatasetClient::CreateIteration( + const std::string& dataset_id, + ProcessingModeDef::ShardingPolicy sharding_policy, + TargetWorkers target_workers) { + ProcessingModeDef processing_mode_def; + processing_mode_def.set_sharding_policy(sharding_policy); + int64_t job_id; + TF_RETURN_IF_ERROR(dispatcher_client_->GetOrCreateJob( + dataset_id, processing_mode_def, /*job_name=*/std::nullopt, + /*num_consumers=*/std::nullopt, /*use_cross_trainer_cache=*/false, + target_workers, job_id)); + int64_t iteration_client_id; + TF_RETURN_IF_ERROR(dispatcher_client_->GetOrCreateIteration( + job_id, /*repetition=*/0, iteration_client_id)); + return iteration_client_id; +} + +template +absl::StatusOr DatasetClient::CreateIteration( + const DatasetDef& dataset) { + TF_ASSIGN_OR_RETURN(const std::string dataset_id, RegisterDataset(dataset)); + return CreateIteration(dataset_id, ProcessingModeDef::OFF, + TARGET_WORKERS_ANY); +} + +template +absl::StatusOr> DatasetClient::GetTasks( + const int64_t iteration_client_id) { + ClientHeartbeatRequest request; + ClientHeartbeatResponse response; + request.set_iteration_client_id(iteration_client_id); + TF_RETURN_IF_ERROR(dispatcher_client_->ClientHeartbeat(request, response)); + if (response.task_info().empty()) { + return errors::NotFound("No task found for iteration ", iteration_client_id, + "."); + } + return std::vector(response.task_info().begin(), + response.task_info().end()); +} + +template +StatusOr::WorkerResultMap> +DatasetClient::ReadFromTasks(const std::vector& tasks) { + WorkerResultMap result; + bool all_workers_finished = false; + while (!all_workers_finished) { + all_workers_finished = true; + for (const TaskInfo& task : tasks) { + absl::StatusOr element_result = ReadFromTask(task); + // A task may be cancelled when it has finished but other workers are + // still producing data. 
+ if (absl::IsCancelled(element_result.status())) { + continue; + } + TF_RETURN_IF_ERROR(element_result.status()); + if (element_result->end_of_sequence) { + continue; + } + all_workers_finished = false; + result[task.worker_address()].push_back( + element_result->components[0].unaligned_flat().data()[0]); + } + } + return result; +} + +template +absl::StatusOr DatasetClient::ReadFromTask( + const TaskInfo& task_info) { + GetElementRequest request; + GetElementResult element_result; + request.set_task_id(task_info.task_id()); + TF_RETURN_IF_ERROR(worker_clients_[task_info.worker_address()]->GetElement( + request, element_result)); + return element_result; +} + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/test_util.h b/third_party/tflite-hdrs/tensorflow/core/data/service/test_util.h new file mode 100644 index 00000000..2180675b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/test_util.h @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/snapshot.pb.h" + +namespace tensorflow { +namespace data { +namespace testing { + +// Creates a local tempfile and returns the path. +std::string LocalTempFilename(); + +// Creates a dataset graph for testing. `dataset_name` is one of the filenames +// defined in `testdata` (without `.pbtxt`). `args` specifies arguments passed +// to the dataset. These args appear as `$0`, `$1`, etc, in the dataset +// definition and will be replaced with the specified args. +absl::StatusOr GetTestDataset( + absl::string_view dataset_name, const std::vector& args = {}); + +// Returns a test dataset representing +// tf.data.Dataset.range(range). Useful for testing dataset graph execution. +DatasetDef RangeDataset(int64_t range); + +// Returns a test dataset representing +// tf.data.Dataset.range(range).map(lambda x: x*x). +DatasetDef RangeSquareDataset(int64_t range); + +// Returns a test dataset representing +// tf.data.Dataset.range(range).shard(SHARD_HINT, SHARD_HINT). +DatasetDef RangeDatasetWithShardHint(int64_t range); + +// Returns a test dataset representing +// tf.data.Dataset.range(100000000).repeat(). +DatasetDef InfiniteDataset(); + +// Returns a distributed snapshot metadata for a dummy dataset. 
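+//
+// For example, a test can pair the dummy metadata with one of the range
+// datasets above (how the request is ultimately wired to the dispatcher is not
+// shown here):
+//
+//   DatasetDef dataset = RangeDataset(10);
+//   experimental::DistributedSnapshotMetadata metadata =
+//       CreateDummyDistributedSnapshotMetadata();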
+experimental::DistributedSnapshotMetadata +CreateDummyDistributedSnapshotMetadata(); + +// Returns a test dataset representing +// tf.data.Dataset.from_tensor_slices(["filenames"]).interleave( +// lambda filepath: tf.data.TextLineDataset(filepath), +// cycle_length=10) +absl::StatusOr InterleaveTextlineDataset( + const std::vector& filenames, + const std::vector& contents); + +// Repeatedly calls `f()`, blocking until `f()` returns `false`. +// +// Returns an error if `f()` returns an error. +absl::Status WaitWhile(std::function()> f); + +// TODO(b/229726259): Make EqualsProto available in Googletest +// (Public feature request: https://github.com/google/googletest/issues/1761). +class ProtoStringMatcher { + public: + explicit ProtoStringMatcher(const tensorflow::protobuf::Message& expected) + : expected_(expected.ShortDebugString()) {} + + template + bool MatchAndExplain(const Message& p, + ::testing::MatchResultListener*) const { + return p.ShortDebugString() == expected_; + } + + void DescribeTo(::std::ostream* os) const { *os << expected_; } + void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_; + } + + private: + const std::string expected_; +}; + +inline ::testing::PolymorphicMatcher EqualsProto( + const tensorflow::protobuf::Message& x) { + return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); +} +} // namespace testing +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/thread_safe_buffer.h b/third_party/tflite-hdrs/tensorflow/core/data/service/thread_safe_buffer.h new file mode 100644 index 00000000..570fb5ce --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/thread_safe_buffer.h @@ -0,0 +1,122 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ + +#include +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { +namespace data { + +// A thread-safe bounded buffer with cancellation support. +template +class ThreadSafeBuffer final { + public: + // Creates a buffer with the specified `buffer_size`. + // REQUIRES: buffer_size > 0 + explicit ThreadSafeBuffer(size_t buffer_size); + + // Gets the next element. Blocks if the buffer is empty. Returns an error if + // a non-OK status was pushed or the buffer has been cancelled. + StatusOr Pop(); + + // Writes the next element. Blocks if the buffer is full. Returns an error if + // the buffer has been cancelled. + absl::Status Push(StatusOr value); + + // Cancels the buffer with `status` and notifies waiting threads. 
After + // cancelling, all `Push` and `Pop` calls will return `status`. + // REQUIRES: !status.ok() + void Cancel(absl::Status status); + + // Returns whether the buffer is empty. + bool Empty() const; + + private: + const size_t buffer_size_; + + mutable mutex mu_; + condition_variable ready_to_pop_; + condition_variable ready_to_push_; + std::deque> results_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + + ThreadSafeBuffer(const ThreadSafeBuffer&) = delete; + void operator=(const ThreadSafeBuffer&) = delete; +}; + +template +ThreadSafeBuffer::ThreadSafeBuffer(size_t buffer_size) + : buffer_size_(buffer_size) { + DCHECK_GT(buffer_size, 0) + << "ThreadSafeBuffer must have a positive buffer size. Got " + << buffer_size << "."; +} + +template +bool ThreadSafeBuffer::Empty() const { + tf_shared_lock l(mu_); + return results_.empty(); +} + +template +StatusOr ThreadSafeBuffer::Pop() { + mutex_lock l(mu_); + while (status_.ok() && results_.empty()) { + ready_to_pop_.wait(l); + } + if (!status_.ok()) { + return status_; + } + StatusOr result = std::move(results_.front()); + results_.pop_front(); + ready_to_push_.notify_one(); + return result; +} + +template +absl::Status ThreadSafeBuffer::Push(StatusOr value) { + mutex_lock l(mu_); + while (status_.ok() && results_.size() >= buffer_size_) { + ready_to_push_.wait(l); + } + if (!status_.ok()) { + return status_; + } + results_.push_back(std::move(value)); + ready_to_pop_.notify_one(); + return absl::OkStatus(); +} + +template +void ThreadSafeBuffer::Cancel(absl::Status status) { + DCHECK(!status.ok()) + << "Cancelling ThreadSafeBuffer requires a non-OK status. Got " << status; + mutex_lock l(mu_); + status_ = std::move(status); + ready_to_push_.notify_all(); + ready_to_pop_.notify_all(); +} + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/url.h b/third_party/tflite-hdrs/tensorflow/core/data/service/url.h new file mode 100644 index 00000000..84afa162 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/url.h @@ -0,0 +1,52 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_URL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_URL_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace tensorflow { +namespace data { + +// Parses URLs of form host[:port] and provides methods to retrieve its +// components. The port can be a number, named port, or dynamic port +// (i.e.: %port_name%). 
For example: +// +// URL url("/worker/task/0:worker"); +// url.has_protocol() == false; +// url.host() == "/worker/task/0"; +// url.has_port() == true; +// url.port() == "worker"; +class URL { + public: + explicit URL(absl::string_view url); + + absl::string_view host() const { return host_; } + bool has_port() const { return !port_.empty(); } + absl::string_view port() const { return port_; } + + private: + void Parse(absl::string_view url); + + std::string host_; + std::string port_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_URL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/utils.h new file mode 100644 index 00000000..482d306e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/utils.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_UTILS_H_ + +#include + +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/platform/env.h" + +// Utilities shared between the dispatcher and worker servers. +namespace tensorflow { +namespace data { + +// Writes a dataset definition to the specified path. If the file already +// exists, it will be overwritten. +absl::Status WriteDatasetDef(const std::string& path, + const DatasetDef& dataset_def); + +// Reads a dataset definition from specified path, and stores it in +// `dataset_def`. Returns NOT_FOUND if the path cannot be found. +absl::Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/validate_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/service/validate_utils.h new file mode 100644 index 00000000..c4278023 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/validate_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_VALIDATE_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_VALIDATE_UTILS_H_ + +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/data_service.pb.h" + +namespace tensorflow { +namespace data { + +// Verifies the datasets with the same ID have the same metadata. If the +// metadata differs, returns an invalid argument error. +absl::Status ValidateMatchingDataset(const std::string& dataset_id, + const DataServiceMetadata& metadata1, + const DataServiceMetadata& metadata2); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_VALIDATE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/worker_client.h b/third_party/tflite-hdrs/tensorflow/core/data/service/worker_client.h new file mode 100644 index 00000000..64ac446b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/worker_client.h @@ -0,0 +1,106 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_CLIENT_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_WORKER_CLIENT_H_ + +#include +#include + +#include "tensorflow/core/data/service/common.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/data_transfer.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { + +constexpr const char kLocalTransferProtocol[] = "local"; +constexpr const char kGrpcTransferProtocol[] = "grpc"; + +// Client for communicating with the tf.data service worker. +class DataServiceWorkerClient : public DataServiceClientBase { + public: + DataServiceWorkerClient( + const std::string& address, const std::string& protocol, + const std::string& transfer_protocol, + bool fall_back_to_grpc_at_get_element_time, + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator) + : DataServiceClientBase(address, protocol), + transfer_protocol_(transfer_protocol), + fall_back_to_grpc_at_get_element_time_( + fall_back_to_grpc_at_get_element_time), + accelerator_device_info_(accelerator_device_info), + allocator_(allocator) {} + + // Fetches an element from the worker. + absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result); + + // Makes a best effort to cancel all outstanding calls in progress for the + // client, and causes further calls to return Cancelled status. + void TryCancel(); + + // Returns an error if the client is incompatible with a server which has the + // properties described in `compatibility_info`. 
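// A hedged sketch of gating element reads on CheckCompatibility(), using only
// methods declared in this class; the helper name and the idea that the
// compatibility string comes from the dispatcher are assumptions for
// illustration.
#include <string>

#include "absl/status/status.h"
#include "tensorflow/core/data/service/worker_client.h"

absl::Status ReadIfCompatible(
    tensorflow::data::DataServiceWorkerClient& client,
    const std::string& server_compatibility_info,
    const tensorflow::data::GetElementRequest& req,
    tensorflow::data::GetElementResult& result) {
  // Refuse to read from a server whose compatibility info does not match.
  absl::Status compat = client.CheckCompatibility(server_compatibility_info);
  if (!compat.ok()) return compat;
  // Fetch one element over the configured data transfer protocol.
  return client.GetElement(req, result);
}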
+ absl::Status CheckCompatibility( + const std::string& server_compatibility_info) const { + return client_->CheckCompatibility(server_compatibility_info); + } + + // If `true`, data service clients should fall back to gRPC for this worker + // client if it nonretryably fails to transfer an element using an alternative + // data transfer protocol. + bool FallBackToGrpcAtGetElementTime() const { + return fall_back_to_grpc_at_get_element_time_; + } + + // Returns the data transfer protocol, preferring to use the local transfer + // protocol if a local tf.data worker exists. + std::string GetDataTransferProtocol() const; + + protected: + absl::Status EnsureInitialized() override; + + private: + std::string transfer_protocol_; + bool fall_back_to_grpc_at_get_element_time_; + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_; + Allocator* allocator_; + + mutex mu_; + // Initialization is guarded by `mu_`, but using the stub does not require + // holding `mu_` + std::unique_ptr client_; +}; + +// Creates and initializes a new tf.data service worker client to read +// from the data transfer server specified in `info`. +absl::StatusOr> +CreateDataServiceWorkerClient( + const std::string& dispatcher_protocol, const DataTransferServerInfo& info, + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, + Allocator* allocator); + +// If true, clients should use local protocol for data transfer (disregarding +// any other user-specified or runtime-defaulted protocol). +bool ForceLocalProtocol(const std::string& worker_address); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/service/worker_impl.h b/third_party/tflite-hdrs/tensorflow/core/data/service/worker_impl.h new file mode 100644 index 00000000..c256c88c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/service/worker_impl.h @@ -0,0 +1,251 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ +#define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/data_transfer.h" +#include "tensorflow/core/data/service/dispatcher_client.h" +#include "tensorflow/core/data/service/export.pb.h" +#include "tensorflow/core/data/service/snapshot/snapshot_stream_writer.h" +#include "tensorflow/core/data/service/task_runner.h" +#include "tensorflow/core/data/service/worker.pb.h" +#include "tensorflow/core/data/standalone.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/service_config.pb.h" + +namespace tensorflow { +namespace data { + +// A TensorFlow DataService serves dataset elements over RPC. +class DataServiceWorkerImpl { + public: + explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config); + ~DataServiceWorkerImpl(); + + // Starts the worker. The worker needs to know its own address so that it can + // register with the dispatcher. This is set in `Start` instead of in the + // constructor because the worker may be binding to port `0`, in which case + // the address isn't known until the worker has started and decided which port + // to bind to. + absl::Status Start( + const std::string& worker_address, + const std::vector& transfer_servers); + // Stops the worker, attempting a clean shutdown by rejecting new requests + // and waiting for outstanding requests to complete. + void Stop(); + + // Serves a GetElement request, storing the result in `*result`. See + // worker.proto for GetElement API documentation. + absl::Status GetElementResult(const GetElementRequest* request, + GetElementResult* result); + + // Deletes the local task and iterator. Only called by local clients to delete + // unused task iterators assuming the task is not read by remote clients. This + // method is not visible to gRPC clients. + void DeleteLocalTask(const TaskInfo& task_info); + + // See worker.proto for API documentation. + + /// Dispatcher-facing API. + absl::Status ProcessTask(const ProcessTaskRequest* request, + ProcessTaskResponse* response); + + /// Client-facing API. + absl::Status GetElement(const GetElementRequest* request, + GetElementResponse* response); + absl::Status GetWorkerTasks(const GetWorkerTasksRequest* request, + GetWorkerTasksResponse* response); + absl::Status GetSnapshotTaskProgresses( + const GetSnapshotTaskProgressesRequest* request, + GetSnapshotTaskProgressesResponse* response); + + // Exports the worker state for debugging. + WorkerStateExport ExportState() const; + + private: + struct Task { + explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {} + + TaskDef task_def; + mutex mu; + bool initialized TF_GUARDED_BY(mu) = false; + int64_t outstanding_requests TF_GUARDED_BY(&DataServiceWorkerImpl::mu_) = 0; + std::unique_ptr task_runner; + }; + + struct SnapshotTask { + // Base directory of the snapshot. 
+ std::string base_path; + + // Index of the snapshot stream written by this worker. + int64_t stream_index = 0; + + // This is required to use it as a `flat_hash_map` key. + template + friend H AbslHashValue(H h, const SnapshotTask& task) { + return H::combine(std::move(h), task.base_path, task.stream_index); + } + + friend bool operator==(const SnapshotTask& task1, + const SnapshotTask& task2) { + return task1.base_path == task2.base_path && + task1.stream_index == task2.stream_index; + } + }; + + // Validates the worker config. + absl::Status ValidateWorkerConfig() const; + // Creates and initializes a dispatcher client. + absl::StatusOr> + CreateDispatcherClient() const TF_LOCKS_EXCLUDED(mu_); + // Sends task status to the dispatcher and checks for dispatcher commands. + absl::Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_); + // Creates an iterator to process a task. + absl::Status ProcessTaskInternal(const TaskDef& task) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status EnsureTaskInitialized(Task& task); + // Stops a task, cancelling the task's outstanding requests and waiting for + // them to finish. + void StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_); + // A thread for notifying the dispatcher when tasks complete. + void TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_); + // A thread for doing periodic heartbeats to the dispatcher. + void HeartbeatThread() TF_LOCKS_EXCLUDED(mu_); + // Performs a heartbeat to the dispatcher. + absl::Status Heartbeat(); + // Check with the dispatcher to see whether or not to disable compression. + absl::StatusOr DisableCompressionAtRuntime( + const std::string& dataset_id) const; + // Returns the active tasks of this worker. + std::vector GetActiveTasks() const TF_LOCKS_EXCLUDED(mu_); + // Returns the task IDs of `active_tasks`. + std::vector GetTaskIds( + const std::vector& active_tasks) const; + // Builds a heartbeat request. + WorkerHeartbeatRequest BuildWorkerHeartbeatRequest() const + TF_LOCKS_EXCLUDED(mu_); + // Updates the tasks according to the heartbeat response. + void UpdateTasks(const WorkerHeartbeatResponse& response) + TF_LOCKS_EXCLUDED(mu_); + // Updates the distributed snapshot tasks according to the heartbeat response. + absl::Status UpdateSnapshotWriters(const WorkerHeartbeatResponse& response) + TF_LOCKS_EXCLUDED(mu_); + // Creates an dataset iterator for snapshot writers. + absl::StatusOr> + MakeSnapshotTaskIterator(const SnapshotTaskDef& snapshot_task, + const DatasetDef& dataset_def) const; + // Gets the snapshot task progress from the snapshot writers. + std::vector GetSnapshotTaskProgress() const; + // Gets the DatasetDef for `task_def`. + absl::StatusOr GetDatasetDef(const TaskDef& task_def) const; + // Creates a dataset from `dataset_def`. + absl::StatusOr> MakeDataset( + const DatasetDef& dataset_def, const TaskDef& task_def) const; + // Creates an iterator for `dataset`. + absl::StatusOr> MakeDatasetIterator( + standalone::Dataset& dataset, const TaskDef& task_def) const; + + const experimental::WorkerConfig config_; + // Worker Borg job UID for telemetry. -1 if not supported. + const int64_t worker_uid_; + + // The worker's own address. + std::string worker_address_; + // The data transfer servers available to worker clients. + std::vector transfer_servers_; + std::unique_ptr dispatcher_; + + mutable mutex mu_; + condition_variable cv_; + // Information about tasks, keyed by task ids. The tasks are updated based on + // the heartbeat responses from the dispatcher. 
+ absl::flat_hash_map> tasks_ TF_GUARDED_BY(mu_); + // Ids of tasks that have finished. + absl::flat_hash_set finished_tasks_ TF_GUARDED_BY(mu_); + // Completed tasks which haven't yet been communicated to the dispatcher. + absl::flat_hash_set pending_completed_tasks_ TF_GUARDED_BY(mu_); + // Tasks deleted by the local client. If the client tries to read from them + // again, the worker will return a non-retriable FailedPrecondition error. + absl::flat_hash_set deleted_tasks_ TF_GUARDED_BY(mu_); + bool cancelled_ TF_GUARDED_BY(mu_) = false; + // Whether the worker has registered with the dispatcher yet. + bool registered_ TF_GUARDED_BY(mu_) = false; + condition_variable task_completion_cv_ TF_GUARDED_BY(mu_); + condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_); + CancellationManager cancellation_manager_; + + absl::flat_hash_map, + absl::Hash> + snapshot_writers_ TF_GUARDED_BY(mu_); + + // A thread for notifying the dispatcher when tasks complete. + std::unique_ptr task_completion_thread_; + // A thread for performing regular heartbeats to the dispatcher. + std::unique_ptr heartbeat_thread_; + + DataServiceWorkerImpl(const DataServiceWorkerImpl&) = delete; + void operator=(const DataServiceWorkerImpl&) = delete; +}; + +// Local in-process workers shared among clients and servers. If clients and +// workers colocate in the same process, clients can read from local workers to +// reduce RPC calls and data copy. +class LocalWorkers { + public: + // Adds a `worker` at `worker_address`. If a worker already exists at the + // address, it will be updated to the new `worker`. + // REQUIRES: worker != nullptr. + static void Add(absl::string_view worker_address, + std::shared_ptr worker); + + // Gets a local worker at `worker_address`. Returns nullptr if a worker is not + // found. + static std::shared_ptr Get( + absl::string_view worker_address); + + // Returns if there are any local workers in the process. + static bool Empty(); + + // Removes a worker at `worker_address`. It is no-op if a worker is not found + // at the address. + static void Remove(absl::string_view worker_address); + + private: + using AddressToWorkerMap = + absl::flat_hash_map>; + static mutex mu_; + static AddressToWorkerMap* local_workers_ TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/snapshot_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/snapshot_utils.h new file mode 100644 index 00000000..f083cbe4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/snapshot_utils.h @@ -0,0 +1,459 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/io/compression.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/snapshot.pb.h" + +namespace tensorflow { + +class GraphDef; + +namespace data { + +namespace experimental { + +class SnapshotMetadataRecord; +class SnapshotTensorMetadata; + +} // namespace experimental + +namespace snapshot_util { + +constexpr char kMetadataFilename[] = "snapshot.metadata"; + +constexpr char kModeAuto[] = "auto"; +constexpr char kModeWrite[] = "write"; +constexpr char kModeRead[] = "read"; +constexpr char kModePassthrough[] = "passthrough"; +constexpr char kShardDirectorySuffix[] = ".shard"; + +enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 }; + +// Returns the name of the "hash" directory for the given base path and hash ID. +std::string HashDirectory(const std::string& path, uint64 hash); + +// Returns the name of the "run" directory for the given base path and run ID. +std::string RunDirectory(const std::string& hash_directory, uint64 run_id); +std::string RunDirectory(const std::string& hash_directory, + const std::string& run_id); + +// Returns the name of the "shard" directory for the given base path and shard +// ID. +std::string ShardDirectory(const std::string& run_directory, int64_t shard_id); + +// Returns the checkpoint file name for the given directory and checkpoint ID. +std::string GetCheckpointFileName(const std::string& shard_directory, + uint64 checkpoint_id); + +// This is a interface class that exposes snapshot writing functionality. +class Writer { + public: + // Creates a new writer object. + static absl::Status Create(Env* env, const std::string& filename, + const std::string& compression_type, int version, + const DataTypeVector& dtypes, + std::unique_ptr* out_writer); + + // Writes a vector of tensors to the snapshot writer file. + virtual absl::Status WriteTensors(const std::vector& tensors) = 0; + + // Flushes any in-memory buffers to disk. + virtual absl::Status Sync() = 0; + + // Closes and finalizes the snapshot file. All calls to any other method will + // be invalid after this call. + virtual absl::Status Close() = 0; + + virtual ~Writer() = default; + + protected: + virtual absl::Status Initialize(tensorflow::Env* env) = 0; +}; + +// Writes snapshots with the standard TFRecord file format. 
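// A minimal sketch of the Writer factory declared above, producing a
// TFRecord-format snapshot file like the TFRecordWriter defined below. The
// compression constant and the version value are assumptions; real callers
// derive both from the snapshot metadata.
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "tensorflow/core/data/snapshot_utils.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"

absl::Status WriteOneElement(const std::string& filename,
                             const std::vector<tensorflow::Tensor>& element) {
  using tensorflow::data::snapshot_util::Writer;
  tensorflow::DataTypeVector dtypes;
  for (const tensorflow::Tensor& t : element) dtypes.push_back(t.dtype());

  std::unique_ptr<Writer> writer;
  TF_RETURN_IF_ERROR(Writer::Create(tensorflow::Env::Default(), filename,
                                    tensorflow::io::compression::kSnappy,
                                    /*version=*/2, dtypes, &writer));
  TF_RETURN_IF_ERROR(writer->WriteTensors(element));  // One dataset element.
  TF_RETURN_IF_ERROR(writer->Sync());
  return writer->Close();  // Finalizes the file; the writer is unusable after.
}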
+class TFRecordWriter : public Writer { + public: + TFRecordWriter(const std::string& filename, + const std::string& compression_type); + + absl::Status Initialize(tensorflow::Env* env) override; + + absl::Status WriteTensors(const std::vector& tensors) override; + + absl::Status Sync() override; + + absl::Status Close() override; + + ~TFRecordWriter() override; + + private: + const std::string filename_; + const std::string compression_type_; + + std::unique_ptr dest_; + std::unique_ptr record_writer_; +}; + +// Writes snapshot with a custom (legacy) file format. +class CustomWriter : public Writer { + public: + static constexpr const size_t kHeaderSize = sizeof(uint64); + + static constexpr const char* const kClassName = "SnapshotWriter"; + static constexpr const char* const kWriteStringPiece = "WriteStringPiece"; + static constexpr const char* const kWriteCord = "WriteCord"; + static constexpr const char* const kSeparator = "::"; + + CustomWriter(const std::string& filename, const std::string& compression_type, + const DataTypeVector& dtypes); + + absl::Status WriteTensors(const std::vector& tensors) override; + + absl::Status Sync() override; + + absl::Status Close() override; + + ~CustomWriter() override; + + protected: + absl::Status Initialize(tensorflow::Env* env) override; + + private: + absl::Status WriteRecord(const absl::string_view& data); + +#if defined(TF_CORD_SUPPORT) + absl::Status WriteRecord(const absl::Cord& data); +#endif // TF_CORD_SUPPORT + + std::unique_ptr dest_; + const std::string filename_; + const std::string compression_type_; + const DataTypeVector dtypes_; + // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that + // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original + // dest_ and so we need somewhere to store the original one. + std::unique_ptr zlib_underlying_dest_; + std::vector simple_tensor_mask_; // true for simple, false for complex. + int num_simple_ = 0; + int num_complex_ = 0; +}; + +// Interface class for reading snapshot files previous written with Writer. +class Reader { + public: + // Op kernel that creates an instance of `Reader::Dataset` needed to support + // serialization and deserialization of `Reader::Dataset`. + class DatasetOp : public DatasetOpKernel { + public: + explicit DatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + DataTypeVector output_types_; + std::vector output_shapes_; + std::string compression_; + int64_t version_; + }; + + // Op kernel that creates an instance of `Reader::NestedDataset` needed to + // support serialization and deserialization of `Reader::NestedDataset`. + class NestedDatasetOp : public DatasetOpKernel { + public: + explicit NestedDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + DataTypeVector output_types_; + std::vector output_shapes_; + }; + + // Creates a new Reader object that reads data from `filename`. Note that + // the `version`, `compression_type`, and `dtypes` arguments passed into + // `Writer` and `Reader` must be the same for the reading to succeed. + static absl::Status Create(Env* env, const std::string& filename, + const string& compression_type, int version, + const DataTypeVector& dtypes, + std::unique_ptr* out_reader); + + // Returns a nested dataset for a set of given snapshot file names. 
+ // + // This function takes a vector of snapshot files, and returns a nested + // dataset. Each element within the nested dataset is itself a dataset, and + // contains all the elements written out to each individual snapshot file. + static absl::Status MakeNestedDataset( + Env* env, const std::vector& shard_dirs, + const string& compression_type, int version, const DataTypeVector& dtypes, + const std::vector& shapes, int64_t start_index, + DatasetBase** output); + + // Returns a nested dataset for the given datasets. + static void MakeNestedDataset(const std::vector& datasets, + DatasetBase** output); + + // Reads a vector of Tensors from the snapshot file. + virtual absl::Status ReadTensors(std::vector* read_tensors) = 0; + + // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records` + // times then discarding the results. + virtual absl::Status SkipRecords(int64_t num_records); + + virtual ~Reader() = default; + + protected: + virtual absl::Status Initialize(Env* env) = 0; + + class Dataset; + class NestedDataset; +}; + +class TFRecordReaderImpl { + public: + // Constructs a `TFRecordReaderImpl`. + // `filename` is the file to read from. + // `compression_type` is the compression method, as defined in + // tensorflow/compiler/xla/tsl/lib/io/compression.h. + // `output_buffer_size` specifies the buffer size required by Snappy/Zlib + // compression algorithms. Ignored if compression is not enabled. + TFRecordReaderImpl(const std::string& filename, const string& compression, + std::optional output_buffer_size = std::nullopt); + + // Initializes the reader. Callers must initialize the reader before calling + // `GetNext` or `GetTensors`. + absl::Status Initialize(Env* env); + + // Reads the next Tensor in the input file. + absl::StatusOr GetNext(); + + // Reads all Tensors in the input file. + absl::StatusOr> GetTensors(); + + // Returns the number of bytes read. + uint64_t BytesRead() const { return bytes_read_; } + + private: + // Parses `record` into a Tensor. + absl::StatusOr Parse(const tstring& record); + + std::string filename_; + std::unique_ptr file_; + std::unique_ptr record_reader_; + uint64_t offset_ = 0; + uint64_t bytes_read_ = 0; + + const string compression_; + const std::optional output_buffer_size_; +}; + +// Reads snapshots previously written with `TFRecordWriter`. +class TFRecordReader : public Reader { + public: + TFRecordReader(const std::string& filename, const string& compression, + const DataTypeVector& dtypes, + std::optional output_buffer_size = std::nullopt) + : reader_impl_(filename, compression, output_buffer_size), + dtypes_(dtypes) {} + + // Initializes the reader. Callers must initialize the reader before calling + // `ReadTensors`. + absl::Status Initialize(Env* env) override { + return reader_impl_.Initialize(env); + } + + // Reads Tensors into `read_tensors`. Returns OK on success, OutOfRange for + // end of file, or an error status if there is an error. + absl::Status ReadTensors(std::vector* read_tensors) override; + + // Returns the number of bytes read. + uint64_t BytesRead() const { return reader_impl_.BytesRead(); } + + private: + TFRecordReaderImpl reader_impl_; + const DataTypeVector dtypes_; +}; + +// Reads snapshots previously written with `CustomWriter`. +class CustomReader : public Reader { + public: + // The reader input buffer size is deliberately large because the input reader + // will throw an error if the compressed block length cannot fit in the input + // buffer. 
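// A hedged sketch of reading a shard back with the TFRecordReaderImpl declared
// above; it assumes the file was written by TFRecordWriter with the same
// compression, and treats OutOfRange as end-of-file, mirroring the
// ReadTensors() contract documented for TFRecordReader.
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "tensorflow/core/data/snapshot_utils.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"

absl::Status ReadAllRecords(const std::string& filename,
                            const std::string& compression,
                            std::vector<tensorflow::Tensor>* out) {
  tensorflow::data::snapshot_util::TFRecordReaderImpl reader(filename,
                                                             compression);
  TF_RETURN_IF_ERROR(reader.Initialize(tensorflow::Env::Default()));
  while (true) {
    auto tensor = reader.GetNext();
    if (absl::IsOutOfRange(tensor.status())) return absl::OkStatus();  // EOF.
    if (!tensor.ok()) return tensor.status();
    out->push_back(*std::move(tensor));
  }
}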
+ static constexpr const int64_t kSnappyReaderInputBufferSizeBytes = + 1 << 30; // 1 GiB + // TODO(b/148804377): Set this in a smarter fashion. + static constexpr const int64_t kSnappyReaderOutputBufferSizeBytes = + 32 << 20; // 32 MiB + static constexpr const size_t kHeaderSize = sizeof(uint64); + + static constexpr const char* const kClassName = "SnapshotReader"; + static constexpr const char* const kReadString = "ReadString"; + static constexpr const char* const kReadCord = "ReadCord"; + static constexpr const char* const kSeparator = "::"; + + CustomReader(const std::string& filename, const string& compression_type, + int version, const DataTypeVector& dtypes); + + absl::Status ReadTensors(std::vector* read_tensors) override; + + ~CustomReader() override = default; + + protected: + absl::Status Initialize(Env* env) override; + + private: + absl::Status ReadTensorsV0(std::vector* read_tensors); + + absl::Status SnappyUncompress( + const experimental::SnapshotTensorMetadata* metadata, + std::vector* simple_tensors, + std::vector, size_t>>* + tensor_proto_strs); + + absl::Status ReadRecord(tstring* record); + +#if defined(TF_CORD_SUPPORT) + absl::Status ReadRecord(absl::Cord* record); +#endif + + std::string filename_; + std::unique_ptr file_; + std::unique_ptr input_stream_; + const string compression_type_; + const int version_; + const DataTypeVector dtypes_; + int num_simple_ = 0; + int num_complex_ = 0; + std::vector simple_tensor_mask_; // true for simple, false for complex. +}; + +// Writes snapshot metadata to the given directory. +absl::Status WriteMetadataFile( + Env* env, const string& dir, + const experimental::SnapshotMetadataRecord* metadata); + +// Writes distributed snapshot metadata to the given directory. An error is +// returned if `dir` is unable to be created or if `metadata` is unable to be +// written. +absl::Status WriteMetadataFile( + Env* env, const string& dir, + const experimental::DistributedSnapshotMetadata* metadata); + +// Reads snapshot metadata from the given directory. +absl::Status ReadMetadataFile(Env* env, const string& dir, + experimental::SnapshotMetadataRecord* metadata, + bool* file_exists); + +// Reads distributed snapshot metadata from the given directory. If the file +// doesn't exist in `dir`, `file_exists` is set to true and an ok status is +// returned. If the file exists in `dir` but is unable to be opened, an error +// is returned. +absl::Status ReadMetadataFile( + Env* env, const string& dir, + experimental::DistributedSnapshotMetadata* metadata, bool* file_exists); + +// Writes a dataset graph to the given directory. +absl::Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, + const GraphDef* graph); + +absl::Status DetermineOpState( + const std::string& mode_string, bool file_exists, + const experimental::SnapshotMetadataRecord* metadata, + uint64 pending_snapshot_expiry_seconds, Mode* mode); + +// Represents a dataset element or EOF. +struct ElementOrEOF { + std::vector value; + bool end_of_sequence = false; +}; + +// AsyncWriter provides API for asynchronously writing dataset elements +// (each represented as a vector of tensors) to a file. +// +// The expected use of this API is: +// +// std::unique_ptr writer = absl_make_unique(...); +// +// while (data_available()) { +// std::vector data = read_data() +// writer->Write(data); +// } +// writer->SignalEOF(); +// writer = nullptr; // This will block until writes are flushed. 
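// Expanding the comment sketch above with the constructor arguments and the
// completion callback of AsyncWriter (declared below). The callback signature,
// the compression constant, and the version value are reconstructions and
// should be treated as assumptions.
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "tensorflow/core/data/snapshot_utils.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"

void AsyncWriterSketch(
    const std::string& shard_directory,
    const tensorflow::DataTypeVector& output_types,
    const std::vector<std::vector<tensorflow::Tensor>>& elements) {
  using tensorflow::data::snapshot_util::AsyncWriter;
  auto writer = std::make_unique<AsyncWriter>(
      tensorflow::Env::Default(), /*file_index=*/0, shard_directory,
      /*checkpoint_id=*/0, tensorflow::io::compression::kSnappy,
      /*version=*/2, output_types,
      /*done=*/[](absl::Status s) {  // Runs once the writer thread finishes.
        if (!s.ok()) LOG(ERROR) << "Snapshot write failed: " << s;
      });
  for (const auto& element : elements) writer->Write(element);  // Non-blocking.
  writer->SignalEOF();
  writer = nullptr;  // Blocks until buffered elements are flushed to disk.
}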
+class AsyncWriter { + public: + explicit AsyncWriter(Env* env, int64_t file_index, + const std::string& shard_directory, uint64 checkpoint_id, + const std::string& compression, int64_t version, + const DataTypeVector& output_types, + std::function done); + + // Writes the given tensors. The method is non-blocking and returns without + // waiting for the element to be written. + void Write(const std::vector& tensors) TF_LOCKS_EXCLUDED(mu_); + + // Signals the end of input. The method is non-blocking and returns without + // waiting for the writer to be closed. + void SignalEOF() TF_LOCKS_EXCLUDED(mu_); + + private: + void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_); + bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status WriterThread(Env* env, const std::string& shard_directory, + uint64 checkpoint_id, + const std::string& compression, int64_t version, + DataTypeVector output_types); + + mutex mu_; + std::deque deque_ TF_GUARDED_BY(mu_); + + // This has to be last. During destruction, we need to make sure that the + // Thread object is destroyed first as its destructor blocks on thread + // completion. If there are other member variables after this, they may get + // destroyed first before the thread finishes, potentially causing the + // thread to access invalid memory. + std::unique_ptr thread_; +}; + +} // namespace snapshot_util +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/split_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/split_utils.h new file mode 100644 index 00000000..a0fdef8d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/split_utils.h @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_SPLIT_UTILS_H_ +#define TENSORFLOW_CORE_DATA_SPLIT_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { + +// A class which produces splits for a dataset of size N that can be indexed +// into. 
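// A short sketch of the IndexSplitProvider declared below: for n == 3 it hands
// out index splits 0, 1, 2 and then sets end_of_splits. This illustrates the
// declared contract; the helper name is ours, not part of the vendored header.
#include "absl/status/status.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"

absl::Status EnumerateSplits() {
  tensorflow::data::IndexSplitProvider provider(/*n=*/3);
  while (true) {
    tensorflow::Tensor split;
    bool end_of_splits = false;
    TF_RETURN_IF_ERROR(provider.GetNext(&split, &end_of_splits));
    if (end_of_splits) break;
    // `split` holds the next index in [0, n).
  }
  return provider.Reset();  // Rewinds so the splits can be handed out again.
}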
+class IndexSplitProvider : public SplitProvider { + public: + explicit IndexSplitProvider(int64_t n); + absl::Status GetNext(Tensor* split, bool* end_of_splits) override; + absl::Status Reset() override; + absl::Status Save(std::function full_name, + IteratorStateWriter* writer) override; + absl::Status Restore(std::function full_name, + IteratorStateReader* reader) override; + int64_t Cardinality() const override; + + private: + tsl::mutex mu_; + int64_t i_ TF_GUARDED_BY(mu_); + const int64_t n_; +}; + +// A SplitProvider which wraps another split provider, but drops all splits +// where `index != shard_index % num_shards` +class ShardingSplitProvider : public SplitProvider { + public: + ShardingSplitProvider(int64_t num_shards, int64_t shard_index, + std::shared_ptr split_provider); + + absl::Status GetNext(Tensor* split, bool* end_of_splits) override; + absl::Status Reset() override; + absl::Status Save(std::function full_name, + IteratorStateWriter* writer) override; + absl::Status Restore(std::function full_name, + IteratorStateReader* reader) override; + + private: + const int64_t num_shards_; + const int64_t shard_index_; + tsl::mutex mu_; + std::shared_ptr split_provider_ TF_GUARDED_BY(mu_); + int64_t num_to_skip_ TF_GUARDED_BY(mu_); +}; + +// Returns split providers for all sources of the given dataset. +absl::StatusOr>> GetSplitProviders( + const DatasetBase* dataset); + +// Gets the single split provider from the context, or returns an error if the +// context has zero or multiple split providers. The `dataset` argument is used +// to produce a more useful error message. +absl::StatusOr> GetSingleSplitProvider( + IteratorContext* ctx, const DatasetBase* dataset); + +// Creates iterator contexts for datasets inputs. The split providers +// in `ctx` will be divided among the inputs of `dataset`, so that each input +// gets a number of split providers that matches its number of source datasets. +// If no split providers are defined, the contexts will be the same as `ctx`. +absl::StatusOr> CreateInputIteratorContexts( + IteratorContext* ctx, const DatasetBase* dataset); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_SPLIT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/standalone.h b/third_party/tflite-hdrs/tensorflow/core/data/standalone.h new file mode 100644 index 00000000..5b2b2b2c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/standalone.h @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_STANDALONE_H_ +#define TENSORFLOW_CORE_DATA_STANDALONE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/data/tfdataz_metrics.h" +#include "tensorflow/core/data/unbounded_thread_pool.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/public/session_options.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace data { +namespace standalone { + +// The purpose of the API in this file is to facilitate standalone execution of +// a tf.data input pipeline graph. +// +// The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which +// encapsulate TensorFlow runtime. +// +// The `Dataset` abstraction represents an input pipeline as a collection +// of data sources and a logical plan of transformations that operate over the +// data. +// +// The `Iterator` abstraction represents an execution of an input pipeline that +// can be used to enumerate its elements. +// +// Example usage: +// +// // Create a `Dataset` by running the `graph_def` graph. +// tensorflow::data:standalone::Dataset::Params params; +// std::unique_ptr dataset; +// Status s = tensorflow::data::standalone::Dataset::FromGraph( +// params, graph_def, &dataset); +// if (!s.ok()) { /* error handling */ } +// +// std::unique_ptr iterator; +// s = dataset->MakeIterator(&iterator); +// if (!s.ok()) { /* error handling */ } +// +// bool end_of_input = false; +// while (!end_of_input) { +// std::vector outputs; +// s = iterator->GetNext(&outputs, &end_of_input); +// if (!s.ok()) { /* error handling */ } +// if (!end_of_input) { /* output handling */ } +// } + +class Dataset; + +// Represents an execution of an input pipeline that can be used to enumerate +// its elements. +class Iterator { + public: + virtual ~Iterator(); + + // Returns the next element of the input pipeline (if there is one) and an + // indication of whether the end of the input pipeline has been reached. + absl::Status GetNext(std::vector* outputs, bool* end_of_input); + + // Saves a checkpoint of the iterator. Returns Tensors that can be called with + // `Restore()`. + absl::StatusOr> Save(); + + // Restores the iterator from a checkpoint. `saved_iterator` is the serialized + // iterator saved by calling `Save()`. + absl::Status Restore(const std::vector& saved_iterator); + + // Returns the dataset model for performance analysis. + std::shared_ptr model() const; + + private: + friend class Dataset; + + Iterator(IteratorBase* iterator, IteratorContext* ctx, + SerializationContext* serialization_ctx); + + std::unique_ptr iterator_; + std::unique_ptr ctx_; + std::unique_ptr serialization_ctx_; + std::shared_ptr tf_dataz_metrics_collector_; +}; + +// Represents an input pipeline as a collection of data sources and a logical +// plan of transformations that operate over the data. +class Dataset { + public: + // Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration). 
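// A hedged sketch combining the Params struct defined below with the
// Save()/Restore() checkpointing methods of Iterator declared above. The
// intra-op threading knob on SessionOptions::config is an illustrative
// assumption, not something this header requires.
#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"

absl::Status RunWithCheckpoint(const tensorflow::GraphDef& graph_def) {
  namespace standalone = tensorflow::data::standalone;

  standalone::Dataset::Params params;
  params.session_options.config.set_intra_op_parallelism_threads(4);

  std::unique_ptr<standalone::Dataset> dataset;
  TF_RETURN_IF_ERROR(
      standalone::Dataset::FromGraph(params, graph_def, &dataset));

  std::unique_ptr<standalone::Iterator> iterator;
  TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));

  std::vector<tensorflow::Tensor> outputs;
  bool end_of_input = false;
  TF_RETURN_IF_ERROR(iterator->GetNext(&outputs, &end_of_input));

  // Checkpoint the iterator state, then restore it into a fresh iterator.
  auto checkpoint = iterator->Save();
  if (!checkpoint.ok()) return checkpoint.status();
  TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
  return iterator->Restore(*checkpoint);
}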
+ struct Params { + SessionOptions session_options; + }; + + // Creates a new `Dataset` instance by running the given dataset graph. + static absl::Status FromGraph(Params params, const GraphDef& graph_def, + std::unique_ptr* result); + + ~Dataset(); + + // Creates an iterator for this dataset. + absl::Status MakeIterator(std::unique_ptr* result); + // Creates an iterator, optionally with a split provider. + absl::Status MakeIterator( + std::vector> split_providers, + std::unique_ptr* result); + + // Creates split providers for this dataset. + absl::Status MakeSplitProviders( + std::vector>* result); + // Returns a pointer to the underlying dataset. + const DatasetBase* Get() const; + + private: + Dataset(DatasetBase* finalized_dataset, DatasetBase* original_dataset, + DeviceMgr* device_mgr, ProcessFunctionLibraryRuntime* pflr, + FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool, + std::function)> runner); + + DatasetBase* finalized_dataset_; // owned + DatasetBase* original_dataset_; // owned + std::unique_ptr device_mgr_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + std::unique_ptr interop_threadpool_; + std::unique_ptr function_handle_cache_; + std::function)> runner_; + ResourceMgr resource_mgr_; + CancellationManager cancellation_manager_; + UnboundedThreadPool unbounded_thread_pool_; +}; + +} // namespace standalone +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_STANDALONE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/stats_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/stats_utils.h new file mode 100644 index 00000000..5fa1eae3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/stats_utils.h @@ -0,0 +1,68 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_STATS_UTILS_H_ +#define TENSORFLOW_CORE_DATA_STATS_UTILS_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace data { +namespace stats_utils { +extern const char kDelimiter[]; +extern const char kExecutionTime[]; +extern const char kThreadUtilization[]; +extern const char kBufferSize[]; +extern const char kBufferCapacity[]; +extern const char kBufferUtilization[]; +extern const char kFilteredElements[]; +extern const char kDroppedElements[]; +extern const char kFeaturesCount[]; +extern const char kFeatureValuesCount[]; +extern const char kExamplesCount[]; + +// Name for tf.data function execution time (in ns) histogram metrics. +string ExecutionTimeHistogramName(const string& prefix); + +// Name for thread utilization (ratio of threads being used and maximum number +// of threads allocated) scalar metrics. +string ThreadUtilizationScalarName(const string& prefix); + +// Name for buffer size scalar metrics. +string BufferSizeScalarName(const string& prefix); + +// Name for buffer capacity (maximum allocated buffer size) scalar metrics. 
+string BufferCapacityScalarName(const string& prefix); + +// Name for buffer utilization (ratio of buffer size and maximum allocated +// buffer size.) histogram metrics. +string BufferUtilizationHistogramName(const string& prefix); + +// Name for filtered elements scalar metrics. +string FilterdElementsScalarName(const string& prefix); + +// Name for dropped elements scalar mereics. +string DroppedElementsScalarName(const string& prefix); + +// Name for features count histogram metrics. +string FeatureHistogramName(const string& prefix); + +// Name for feature-values count histogram metrics. +string FeatureValueHistogramName(const string& prefix); + +} // namespace stats_utils +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_STATS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/test_utils.h b/third_party/tflite-hdrs/tensorflow/core/data/test_utils.h new file mode 100644 index 00000000..61da1807 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/test_utils.h @@ -0,0 +1,54 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_TEST_UTILS_H_ +#define TENSORFLOW_CORE_DATA_TEST_UTILS_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace data { + +class TestContext { + public: + static absl::StatusOr> Create(); + virtual ~TestContext() = default; + + OpKernelContext* op_ctx() const { return op_ctx_.get(); } + IteratorContext* iter_ctx() const { return iter_ctx_.get(); } + + private: + TestContext() = default; + + std::unique_ptr device_mgr_; + std::unique_ptr lib_def_; + std::unique_ptr pflr_; + std::function)> runner_; + OpKernelContext::Params params_; + std::unique_ptr op_ctx_; + std::unique_ptr iter_ctx_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/tf_data_memory_logger.h b/third_party/tflite-hdrs/tensorflow/core/data/tf_data_memory_logger.h new file mode 100644 index 00000000..7978fefc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/tf_data_memory_logger.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_TF_DATA_MEMORY_LOGGER_H_ +#define TENSORFLOW_CORE_DATA_TF_DATA_MEMORY_LOGGER_H_ + +namespace tensorflow { +namespace data { + +// Starts the iterator memory logger if it is not already started. The logger is +// only active at VLOG level 4. +void EnsureIteratorMemoryLoggerStarted(); +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_TF_DATA_MEMORY_LOGGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/tfdataz_metrics.h b/third_party/tflite-hdrs/tensorflow/core/data/tfdataz_metrics.h new file mode 100644 index 00000000..e37daf89 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/tfdataz_metrics.h @@ -0,0 +1,150 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_TFDATAZ_METRICS_H_ +#define TENSORFLOW_CORE_DATA_TFDATAZ_METRICS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/time/time.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace data { + +// Calculates the approximate average latency for past 1, 5 and 60 minutes. +// The implementation uses ring buffers to maintain the cumulative latency +// values and count for the past 60 minutes. +class ApproximateLatencyEstimator { + public: + enum class Duration { + kMinute = 1, + kFiveMinutes = 5, + kSixtyMinutes = 60, + }; + + explicit ApproximateLatencyEstimator(const Env& env); + + // Records the latency with the current timestamp. + void AddLatency(int64_t latency_usec); + + // Returns the average latency for the duration (1,5 and 60 minutes) + // specified. + absl::Duration GetAverageLatency(Duration duration); + + private: + static constexpr int64_t kSecondsPerMinute = 60; + static constexpr int64_t kMinutesPerHour = 60; + static constexpr int64_t kSlots = kMinutesPerHour; + + // Updates the latency value and count ring buffers with the latest cumulative + // value and count. Resets the entire ring buffer with the last cumulative + // values stored if the elapsed time duration is greater than 60 minutes. 
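// A small sketch of the public surface declared above: record per-call
// latencies and query the rolling averages. Passing *Env::Default() to satisfy
// the const Env& constructor parameter is an assumption about typical usage.
#include <cstdint>

#include "absl/time/time.h"
#include "tensorflow/core/data/tfdataz_metrics.h"
#include "tensorflow/core/platform/env.h"

void LatencyEstimatorSketch() {
  using tensorflow::data::ApproximateLatencyEstimator;
  ApproximateLatencyEstimator estimator(*tensorflow::Env::Default());

  estimator.AddLatency(/*latency_usec=*/1200);  // e.g. one GetNext() call
  estimator.AddLatency(/*latency_usec=*/800);

  absl::Duration last_minute = estimator.GetAverageLatency(
      ApproximateLatencyEstimator::Duration::kMinute);
  (void)last_minute;  // TfDatazMetricsCollector exports these to /tfdataz.
}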
+ void UpdateRingBuffer() TF_LOCKS_EXCLUDED(mu_); + // Moves the `next_slot_` to the next index in the ring buffer. + void IncrementNextSlot() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Returns the slot index which is behind the current slot in ring buffer by + // `steps` indices. + int PrevSlot(int steps) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + const Env& env_; + + // The time when the ring buffer was last updated. + int64_t last_updated_time_mins_ TF_GUARDED_BY(mu_); + + mutex mu_; + + // Counters storing the cumulative sums of latency values and counts recorded + // so far. + int64_t latency_value_counter_ TF_GUARDED_BY(mu_); + int64_t latency_count_counter_ TF_GUARDED_BY(mu_); + + // Next slot in the ring buffer. + int next_slot_ TF_GUARDED_BY(mu_); + + // Ring buffer storing the cumulative sum of latency values and counts for the + // last 60 minutes. + int64_t latency_value_[kSlots] TF_GUARDED_BY(mu_); + int64_t latency_count_[kSlots] TF_GUARDED_BY(mu_); +}; + +// Collects and exports the tf.data performance metrics to /tfdataz. +class TfDatazMetricsCollector { + public: + // Constructs a `TfDatazMetricsCollector`. + // We only collect metrics for CPU devices. This is a heuristic to avoid + // collecting metrics for device-side iterators created by the multi-device + // iterator mechanism. + TfDatazMetricsCollector(const Env& env, DatasetBaseIterator* iterator, + std::shared_ptr model); + + // Records `GetNext` call latency. + void RecordGetNextLatency(int64_t get_next_latency_usec); + + // Returns the average `GetNext` latency for past 1 minute. + absl::Duration GetAverageLatencyForLastOneMinute(); + + // Returns the average `GetNext` latency for past 5 minutes. + absl::Duration GetAverageLatencyForLastFiveMinutes(); + + // Returns the average `GetNext` latency for past 60 minutes. + absl::Duration GetAverageLatencyForLastSixtyMinutes(); + + // Returns the dataset name if one was set. + std::optional DatasetName(); + + // Returns the total memory (in bytes) used by the iterator. + // Total memory used by the iterator includes the total number of bytes + // buffered in all nodes in the subtree. + int64_t GetIteratorTotalMemoryUsage(); + + std::shared_ptr GetModel(); + + private: + DatasetBaseIterator* iterator_; // not owned + std::shared_ptr model_; + ApproximateLatencyEstimator latency_estimator_; +}; + +// Thread-safe global registry for the /tfdataz metrics. All callers to +// `TfDatazMetricsRegistry` use the same instance to register and deregister +// iterator's `TfDatazMetricsCollector`. +class TfDatazMetricsRegistry { + public: + // Registers the iterator specific `TfDatazMetricsCollector` in the global + // TfDatazMetricsRegistry. + static void Register(std::shared_ptr collector); + + // Deregisters the iterator specific `TfDatazMetricsCollector` from the global + // TfDatazMetricsRegistry. + static void Deregister(std::shared_ptr collector); + + // Returns all the registered `TfDatazMetricsCollector`s. + static absl::flat_hash_set> + GetIteratorMetricCollectors(); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_TFDATAZ_METRICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/unbounded_thread_pool.h b/third_party/tflite-hdrs/tensorflow/core/data/unbounded_thread_pool.h new file mode 100644 index 00000000..f790c938 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/unbounded_thread_pool.h @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_UNBOUNDED_THREAD_POOL_H_ +#define TENSORFLOW_CORE_DATA_UNBOUNDED_THREAD_POOL_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/thread_factory.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool_interface.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" + +namespace tensorflow { +namespace data { + +// An `UnboundedThreadPool` provides a mechanism for temporally multiplexing a +// potentially large number of "logical" threads onto a smaller number of +// "physical" threads. The multiplexing is achieved by using an +// `UnboundedWorkQueue`. +class UnboundedThreadPool : public thread::ThreadPoolInterface { + public: + UnboundedThreadPool(Env* env, const string& thread_name) + : unbounded_work_queue_(env, thread_name) {} + UnboundedThreadPool(Env* env, const string& thread_name, + const ThreadOptions& thread_options) + : unbounded_work_queue_(env, thread_name, thread_options) {} + ~UnboundedThreadPool() override = default; + + // Returns an implementation of `ThreadFactory` that can be used to create + // logical threads in this pool. + std::shared_ptr get_thread_factory(); + + void Schedule(std::function fn) override; + int NumThreads() const override; + int CurrentThreadId() const override; + + private: + class LogicalThreadFactory; + class LogicalThreadWrapper; + + void ScheduleOnWorkQueue(std::function fn, + std::shared_ptr done); + + UnboundedWorkQueue unbounded_work_queue_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_UNBOUNDED_THREAD_POOL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/data/utils.h b/third_party/tflite-hdrs/tensorflow/core/data/utils.h new file mode 100644 index 00000000..64b8e6f9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/data/utils.h @@ -0,0 +1,62 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DATA_UTILS_H_ +#define TENSORFLOW_CORE_DATA_UTILS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/protobuf/data_service.pb.h" + +namespace tensorflow { +namespace data { + +// Records latency of fetching data from tf.data iterator. +void AddLatencySample(int64_t microseconds); + +// Records bytes produced by a tf.data iterator. +void IncrementThroughput(int64_t bytes); + +// Returns a modified file name that can be used to do implementation specific +// file name manipulation/optimization. +std::string TranslateFileName(const std::string& fname); + +// Returns the data transfer protocol to use if one is not specified by the +// user. +std::string DefaultDataTransferProtocol(); + +// Returns a path pointing to the same file as `path` with a potential locality +// optimization. +std::string LocalityOptimizedPath(const std::string& path); + +// Returns `true` if tf.data service compression should be disabled at runtime +// based on (1) the inputs or (2) the properties of the calling trainer. +absl::StatusOr DisableCompressionAtRuntime( + const std::string& data_transfer_protocol, DeploymentMode deployment_mode, + DataServiceMetadata::Compression compression); + +// Log filenames into TfDataLogger. Uses the same TfDataFileLoggerClient at +// every call. Thread safe. +// TODO (shushanik) Implement streamz error reporting in case the logging is not +// successful +void LogFilenames(const std::vector& files); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DATA_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/debug/debug_callback_registry.h b/third_party/tflite-hdrs/tensorflow/core/debug/debug_callback_registry.h new file mode 100644 index 00000000..94b57401 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/debug/debug_callback_registry.h @@ -0,0 +1,71 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_ +#define TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/debug/debug_node_key.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// Supports exporting observed debug events to clients using registered +// callbacks. Users can register a callback for each debug_url stored using +// DebugTensorWatch. The callback key be equivalent to what follows +// "memcbk:///". +// +// All events generated for a watched node will be sent to the call back in the +// order that they are observed. +// +// This callback router should not be used in production or training steps. 
It +// is optimized for deep inspection of graph state rather than performance. +class DebugCallbackRegistry { + public: + using EventCallback = std::function; + + // Provides singleton access to the in memory event store. + static DebugCallbackRegistry* singleton(); + + // Returns the registered callback, or nullptr, for key. + EventCallback* GetCallback(const string& key); + + // Associates callback with key. This must be called by clients observing + // nodes to be exported by this callback router before running a session. + void RegisterCallback(const string& key, EventCallback callback); + + // Removes the callback associated with key. + void UnregisterCallback(const string& key); + + private: + DebugCallbackRegistry(); + + // Mutex to ensure that keyed events are never updated in parallel. + mutex mu_; + + // Maps debug_url keys to callbacks for routing observed tensors. + std::map keyed_callback_ TF_GUARDED_BY(mu_); + + static DebugCallbackRegistry* instance_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/debug/debug_graph_utils.h b/third_party/tflite-hdrs/tensorflow/core/debug/debug_graph_utils.h new file mode 100644 index 00000000..27cfb357 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/debug/debug_graph_utils.h @@ -0,0 +1,124 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_ +#define TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/debugger_state_interface.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/protobuf/debug.pb.h" + +namespace tensorflow { + +class DebugNodeInserter { + public: + // EXPERIMENTAL: Insert special debug ops (e.g., DebugIdentity) to graph for + // debugging. Currently, such ops need to take exactly one input and has the + // string attribute "tensor_name" to indicate what tensor it watches. + // For example, before the node insertion, the graph may look like: + // + // A:0 -----------1----------> B + // | + // ---------2-----------> C + // + // wherein the output slot 0 of node A feeds as the input to nodes B through + // edge 1 and to node C through edge 2. 
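The DebugCallbackRegistry interface above amounts to a mutex-guarded, process-wide map from debug_url keys to callbacks. A minimal standalone sketch of that pattern follows; the callback signature and class name are simplified stand-ins rather than the actual TensorFlow types.

#include <functional>
#include <map>
#include <mutex>
#include <string>
#include <utility>

class CallbackRegistrySketch {
 public:
  // Simplified payload: the real registry passes a DebugNodeKey and a Tensor.
  using EventCallback = std::function<void(const std::string& tensor_dump)>;

  // Process-wide singleton access, mirroring the singleton() accessor above.
  static CallbackRegistrySketch* singleton() {
    static CallbackRegistrySketch* instance = new CallbackRegistrySketch();
    return instance;
  }

  // Returns the registered callback for `key`, or nullptr if none exists.
  // (Like the original interface, the returned pointer stays valid only while
  // the entry is not unregistered.)
  EventCallback* GetCallback(const std::string& key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = keyed_callback_.find(key);
    return it == keyed_callback_.end() ? nullptr : &it->second;
  }

  // Associates `callback` with `key`; intended to happen before events flow.
  void RegisterCallback(const std::string& key, EventCallback callback) {
    std::lock_guard<std::mutex> lock(mu_);
    keyed_callback_[key] = std::move(callback);
  }

  void UnregisterCallback(const std::string& key) {
    std::lock_guard<std::mutex> lock(mu_);
    keyed_callback_.erase(key);
  }

 private:
  CallbackRegistrySketch() = default;

  std::mutex mu_;
  std::map<std::string, EventCallback> keyed_callback_;
};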
+ // After the node insertion, assuming both B and C have non-Ref input, the + // graph becomes: + // A:0 ---3---> Copy -----------4----------> B + // | + // ---------5--------> C + // | + // ---------6--------> X + // + // If a node (e.g., B) has Ref input, the graph becomes: + // + // --------------------------------> B + // | + // A:0 ---3-----> Copy -----------4----------> C + // | + // -----------5--------> X + // + // In other words, we do not feed Refs to deep-copies to downstream nodes. + // + // Copy is the inserted deep-copy node that copies the input tensor on-device + // (e.g., CPU-to-CPU or GPU-to-GPU deep copy) that reduces the likelihood of + // racy updates during the debug watches. X is the newly created debug node + // that transforms the input (copy of the watched tensor) into a debug signal. + // + // DebugIdentity is the simplest debugging paradigm, in which the debug signal + // (i.e., X:0) equals the tensor itself. More sophisticated debug ops can be + // used to transform the tensor into other debug signals. An example is the + // DebugNanCounter op. + // + // If the nodes (A, B and C) are located on GPU and the edges from A to B or C + // is HOST_MEMORY, then the CopyHost op will be used instead of the Copy op. + static absl::Status InsertNodes( + const protobuf::RepeatedPtrField& watches, Graph* graph, + Device* device); + + // Set the parallel_iterations attribute of TensorFlow while loops + // (specifically the nodes for which IsEnter() returns true) to 1 to prevent + // any node from being executed multiple times concurrently and + // generating temporally-overlapping debug Tensor dumps. + static void DeparallelizeWhileLoops(Graph* graph, Device* device); + + // Get canonical name of a copy node. + static const string GetCopyNodeName(const string& node_name, + const int output_slot); + + // Get canonical name of a debug node. + static const string GetDebugNodeName(const string& tensor_name, + const int debug_op_num, + const string& debug_op_name); + + private: + static absl::Status CreateCopyNode( + Graph* graph, const DeviceType device_type, const bool is_host_memory, + const string& src_node_name, const int src_output, const DataType src_dt, + const string& tensor_name, const std::vector& debug_ops, + const std::vector& debug_urls, Node** copy_node); + + // Parse the debug_op_name string to extract proper op name and attributes. + // debug_op_name can be the proper op name only, e.g., "DebugNumericSummary". + // It can also contain customizable keys and values. Each key-value pair is + // connected with an equal sign ("="). Multiple key-value pairs are separated + // with semicolons (";"), which optional whitespace in between, e.g., + // "DebugNumericSummary(mute_if_healthy=true, lower_bound=-100.0)". + static absl::Status ParseDebugOpName( + const string& debug_op_name, string* debug_op_name_proper, + std::unordered_map* attributes); + + static absl::Status SetDebugNodeAttributes( + Node* debug_node, const std::unordered_map& attributes); + + static absl::Status CreateDebugNode( + Graph* graph, const Device& device, const string& src_copy_node_name, + const DataType src_dt, const string& tensor_name, + const std::vector& debug_urls, const int debug_op_num, + const string& debug_op_name, Node** debug_node); + // TODO(cais): Cut down the number of args to this method. 
+ + friend class DebugGraphUtilsTest; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/debug/debug_grpc_testlib.h b/third_party/tflite-hdrs/tensorflow/core/debug/debug_grpc_testlib.h new file mode 100644 index 00000000..2a57df8d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/debug/debug_grpc_testlib.h @@ -0,0 +1,87 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_ +#define TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_ + +#include +#include + +#include "grpcpp/grpcpp.h" +#include "tensorflow/core/debug/debug_io_utils.h" +#include "tensorflow/core/debug/debug_service.grpc.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +namespace test { + +class TestEventListenerImpl final : public grpc::EventListener::Service { + public: + TestEventListenerImpl() : stop_requested_(false), stopped_(false) {} + + void RunServer(const int server_port); + void StopServer(); + + ::grpc::Status SendEvents( + ::grpc::ServerContext* context, + ::grpc::ServerReaderWriter< ::tensorflow::EventReply, + ::tensorflow::Event>* stream) override; + + // Clear debug data (e.g., Tensors) received so far. + void ClearReceivedDebugData(); + + void RequestDebugOpStateChangeAtNextStream( + const EventReply::DebugOpStateChange::State new_state, + const DebugNodeKey& debug_node_key); + + std::vector debug_metadata_strings; + std::vector encoded_graph_defs; + std::vector device_names; + std::vector node_names; + std::vector output_slots; + std::vector debug_ops; + std::vector debug_tensors; + + private: + std::atomic_bool stop_requested_; + std::atomic_bool stopped_; + + std::vector debug_node_keys_ TF_GUARDED_BY(states_mu_); + std::vector new_states_ + TF_GUARDED_BY(states_mu_); + + std::unordered_set write_enabled_debug_node_keys_; + + mutex states_mu_; +}; + +// Poll a gRPC debug server by sending a small tensor repeatedly till success. +// +// Args: +// server_url: gRPC URL of the server to poll, e.g., "grpc://foo:3333". +// max_attempts: Maximum number of attempts. +// +// Returns: +// Whether the polling succeeded within max_attempts. +bool PollTillFirstRequestSucceeds(const string& server_url, + const size_t max_attempts); + +} // namespace test + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/debug/debug_io_utils.h b/third_party/tflite-hdrs/tensorflow/core/debug/debug_io_utils.h new file mode 100644 index 00000000..95864c71 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/debug/debug_io_utils.h @@ -0,0 +1,446 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
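ParseDebugOpName, documented above, accepts either a bare op name or an op name followed by parenthesized key=value attributes separated by semicolons, with optional whitespace between pairs. The standalone sketch below parses that documented format; it is an illustration, not the TensorFlow parser, and the function and struct names are invented for the example.

#include <cstddef>
#include <iostream>
#include <map>
#include <string>

struct ParsedDebugOp {
  std::string op_name;
  std::map<std::string, std::string> attributes;
};

// Parses "OpName" or "OpName(key1=val1; key2=val2)" into name + attributes.
bool ParseDebugOpNameSketch(const std::string& spec, ParsedDebugOp* out) {
  const std::size_t open = spec.find('(');
  if (open == std::string::npos) {
    out->op_name = spec;  // Bare op name, no attributes.
    return true;
  }
  if (spec.back() != ')') return false;
  out->op_name = spec.substr(0, open);
  const std::string body = spec.substr(open + 1, spec.size() - open - 2);
  std::size_t start = 0;
  while (start < body.size()) {
    std::size_t end = body.find(';', start);
    if (end == std::string::npos) end = body.size();
    std::string pair = body.substr(start, end - start);
    // The documented format allows optional whitespace around the pairs.
    const std::size_t first = pair.find_first_not_of(" \t");
    const std::size_t last = pair.find_last_not_of(" \t");
    if (first != std::string::npos) {
      pair = pair.substr(first, last - first + 1);
      const std::size_t eq = pair.find('=');
      if (eq == std::string::npos) return false;
      out->attributes[pair.substr(0, eq)] = pair.substr(eq + 1);
    }
    start = end + 1;
  }
  return true;
}

int main() {
  ParsedDebugOp parsed;
  if (ParseDebugOpNameSketch(
          "DebugNumericSummary(mute_if_healthy=true; lower_bound=-100.0)",
          &parsed)) {
    std::cout << parsed.op_name << " parsed with " << parsed.attributes.size()
              << " attributes\n";
  }
  return 0;
}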
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ +#define TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/debug/debug_node_key.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { + +absl::Status ReadEventFromFile(const string& dump_file_path, Event* event); + +struct DebugWatchAndURLSpec { + DebugWatchAndURLSpec(const string& watch_key, const string& url, + const bool gated_grpc) + : watch_key(watch_key), url(url), gated_grpc(gated_grpc) {} + + const string watch_key; + const string url; + const bool gated_grpc; +}; + +// TODO(cais): Put static functions and members in a namespace, not a class. +class DebugIO { + public: + static const char* const kDebuggerPluginName; + + static const char* const kCoreMetadataTag; + static const char* const kGraphTag; + static const char* const kHashTag; + + static const char* const kFileURLScheme; + static const char* const kGrpcURLScheme; + static const char* const kMemoryURLScheme; + + static absl::Status PublishDebugMetadata( + const int64_t global_step, const int64_t session_run_index, + const int64_t executor_step_index, const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + const std::unordered_set& debug_urls); + + // Publishes a tensor to a debug target URL. + // + // Args: + // debug_node_key: A DebugNodeKey identifying the debug node. If + // `debug_node_key.io_of_node` is non-empty, publish for node + // inputs/outputs dumping feature. + // tensor: The Tensor object being published. + // wall_time_us: Time stamp for the Tensor. Unit: microseconds (us). + // debug_urls: An array of debug target URLs, e.g., + // "file:///foo/tfdbg_dump", "grpc://localhost:11011" + // gated_grpc: Whether this call is subject to gRPC gating. + // step_id: Step ID associated with the tensor. + static absl::Status PublishDebugTensor( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const absl::Span debug_urls, + bool gated_grpc, int64_t step_id = -1); + + // Convenience overload of the method above for no gated_grpc by default. + static absl::Status PublishDebugTensor( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const absl::Span debug_urls); + + // Publishes a graph to a set of debug URLs. + // + // Args: + // graph: The graph to be published. + // debug_urls: The set of debug URLs to publish the graph to. 
+ static absl::Status PublishGraph( + const Graph& graph, const string& device_name, + const std::unordered_set& debug_urls); + + // Determines whether a copy node needs to perform deep-copy of input tensor. + // + // The input arguments contain sufficient information about the attached + // downstream debug ops for this method to determine whether all the said + // ops are disabled given the current status of the gRPC gating. + // + // Args: + // specs: A vector of DebugWatchAndURLSpec carrying information about the + // debug ops attached to the Copy node, their debug URLs and whether + // they have the attribute value gated_grpc == True. + // + // Returns: + // Whether any of the attached downstream debug ops is enabled given the + // current status of the gRPC gating. + static bool IsCopyNodeGateOpen( + const std::vector& specs); + + // Determines whether a debug node needs to proceed given the current gRPC + // gating status. + // + // Args: + // watch_key: debug tensor watch key, in the format of + // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". + // debug_urls: the debug URLs of the debug node. + // + // Returns: + // Whether this debug op should proceed. + static bool IsDebugNodeGateOpen(const string& watch_key, + const std::vector& debug_urls); + + // Determines whether debug information should be sent through a grpc:// + // debug URL given the current gRPC gating status. + // + // Args: + // watch_key: debug tensor watch key, in the format of + // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". + // debug_url: the debug URL, e.g., "grpc://localhost:3333", + // "file:///tmp/tfdbg_1". + // + // Returns: + // Whether the sending of debug data to the debug_url should + // proceed. + static bool IsDebugURLGateOpen(const string& watch_key, + const string& debug_url); + + static absl::Status CloseDebugURL(const string& debug_url); +}; + +// Helper class for debug ops. +class DebugFileIO { + public: + // Encapsulates the Tensor in an Event protobuf and write it to a directory. + // The actual path of the dump file will be a contactenation of + // dump_root_dir, tensor_name, along with the wall_time. + // + // For example: + // let dump_root_dir = "/tmp/tfdbg_dump", + // node_name = "foo/bar", + // output_slot = 0, + // debug_op = DebugIdentity, + // and wall_time_us = 1467891234512345, + // the dump file will be generated at path: + // /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345. + // + // Args: + // debug_node_key: A DebugNodeKey identifying the debug node. + // wall_time_us: Wall time at which the Tensor is generated during graph + // execution. Unit: microseconds (us). + // dump_root_dir: Root directory for dumping the tensor. + // dump_file_path: The actual dump file path (passed as reference). + static absl::Status DumpTensorToDir(const DebugNodeKey& debug_node_key, + const Tensor& tensor, + const uint64 wall_time_us, + const string& dump_root_dir, + string* dump_file_path); + + // Similar to the above, but for node inputs/outputs dumping feature. + static absl::Status DumpTensorToDirForNodeDumping( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + uint64 wall_time_us, const string& dump_root_dir, string* dump_file_path, + int64_t step_id); + + // Get the full path to the dump file. + // + // Args: + // dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump + // node_name: Name of the node from which the dumped tensor is generated, + // e.g., foo/bar/node_a + // output_slot: Output slot index of the said node, e.g., 0. 
+ // debug_op: Name of the debug op, e.g., DebugIdentity. + // wall_time_us: Time stamp of the dumped tensor, in microseconds (us). + static string GetDumpFilePath(const string& dump_root_dir, + const DebugNodeKey& debug_node_key, + const uint64 wall_time_us); + + // Similar to the above, but for node inputs/outputs dumping feature. + static string GetDumpFilePathForNodeDumping( + const string& dump_root_dir, const DebugNodeKey& debug_node_key, + uint64 wall_time_us, int64_t step_id); + + // Dumps an Event proto to a file. + // + // Args: + // event_prot: The Event proto to be dumped. + // dir_name: Directory path. + // file_name: Base file name. + static absl::Status DumpEventProtoToFile(const Event& event_proto, + const string& dir_name, + const string& file_name); + + // Request additional bytes to be dumped to the file system. + // + // Does not actually dump the bytes, but instead just performs the + // bookkeeping necessary to prevent the total dumped amount of data from + // exceeding the limit (default 100 GBytes or set customly through the + // environment variable TFDBG_DISK_BYTES_LIMIT). + // + // Args: + // bytes: Number of bytes to request. + // + // Returns: + // Whether the request is approved given the total dumping + // limit. + static bool requestDiskByteUsage(uint64 bytes); + + // Reset the disk byte usage to zero. + static void resetDiskByteUsage(); + + static uint64 global_disk_bytes_limit_; + + private: + // Encapsulates the Tensor in an Event protobuf and write it to file. + static absl::Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key, + const Tensor& tensor, + const uint64 wall_time_us, + const string& file_path); + + // Implemented ad hoc here for now. + // TODO(cais): Replace with shared implementation once http://b/30497715 is + // fixed. + static absl::Status RecursiveCreateDir(Env* env, const string& dir); + + // Tracks how much disk has been used so far. + static uint64 disk_bytes_used_; + // Mutex for thread-safe access to disk_bytes_used_. + static mutex bytes_mu_; + // Default limit for the disk space. + static const uint64 kDefaultGlobalDiskBytesLimit; + + friend class DiskUsageLimitTest; +}; + +} // namespace tensorflow + +namespace std { + +template <> +struct hash<::tensorflow::DebugNodeKey> { + size_t operator()(const ::tensorflow::DebugNodeKey& k) const { + return ::tensorflow::Hash64( + ::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":", + k.output_slot, ":", k.debug_op, ":")); + } +}; + +} // namespace std + +// TODO(cais): Support grpc:// debug URLs in open source once Python grpc +// genrule becomes available. See b/23796275. +#ifndef PLATFORM_WINDOWS +#include "grpcpp/channel.h" +#include "tensorflow/core/debug/debug_service.grpc.pb.h" + +namespace tensorflow { + +class DebugGrpcChannel { + public: + // Constructor of DebugGrpcChannel. + // + // Args: + // server_stream_addr: Address (host name and port) of the debug stream + // server implementing the EventListener service (see + // debug_service.proto). E.g., "127.0.0.1:12345". + explicit DebugGrpcChannel(const string& server_stream_addr); + + virtual ~DebugGrpcChannel() {} + + // Attempt to establish connection with server. + // + // Args: + // timeout_micros: Timeout (in microseconds) for the attempt to establish + // the connection. + // + // Returns: + // OK Status iff connection is successfully established before timeout, + // otherwise return an error Status. 
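requestDiskByteUsage, described above, only performs budget bookkeeping: a request is approved if it keeps the running total of dumped bytes under a global limit, which defaults to 100 GB and can be overridden through the TFDBG_DISK_BYTES_LIMIT environment variable. A minimal standalone sketch of that bookkeeping follows; the class and member names are assumptions, and details of the real implementation may differ.

#include <cstdint>
#include <cstdlib>
#include <mutex>

class DiskByteBudgetSketch {
 public:
  // Returns true and records the usage iff `bytes` still fits in the budget.
  bool RequestDiskByteUsage(uint64_t bytes) {
    std::lock_guard<std::mutex> lock(mu_);
    if (limit_ == 0) {
      // Lazily resolve the limit: environment override or the default 100 GB.
      const char* env = std::getenv("TFDBG_DISK_BYTES_LIMIT");
      limit_ = (env != nullptr) ? std::strtoull(env, nullptr, 10)
                                : kDefaultLimitBytes;
    }
    if (bytes_used_ + bytes > limit_) return false;
    bytes_used_ += bytes;
    return true;
  }

  // Resets the accumulated usage back to zero.
  void ResetDiskByteUsage() {
    std::lock_guard<std::mutex> lock(mu_);
    bytes_used_ = 0;
  }

 private:
  static constexpr uint64_t kDefaultLimitBytes = 100ULL * 1024 * 1024 * 1024;

  std::mutex mu_;
  uint64_t limit_ = 0;  // 0 means "not resolved yet".
  uint64_t bytes_used_ = 0;
};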
+ absl::Status Connect(const int64_t timeout_micros); + + // Write an Event proto to the debug gRPC stream. + // + // Thread-safety: Safe with respect to other calls to the same method and + // calls to ReadEventReply() and Close(). + // + // Args: + // event: The event proto to be written to the stream. + // + // Returns: + // True iff the write is successful. + bool WriteEvent(const Event& event); + + // Read an EventReply proto from the debug gRPC stream. + // + // This method blocks and waits for an EventReply from the server. + // Thread-safety: Safe with respect to other calls to the same method and + // calls to WriteEvent() and Close(). + // + // Args: + // event_reply: the to-be-modified EventReply proto passed as reference. + // + // Returns: + // True iff the read is successful. + bool ReadEventReply(EventReply* event_reply); + + // Receive and process EventReply protos from the gRPC debug server. + // + // The processing includes setting debug watch key states using the + // DebugOpStateChange fields of the EventReply. + // + // Args: + // max_replies: Maximum number of replies to receive. Will receive all + // remaining replies iff max_replies == 0. + void ReceiveAndProcessEventReplies(size_t max_replies); + + // Receive EventReplies from server (if any) and close the stream and the + // channel. + absl::Status ReceiveServerRepliesAndClose(); + + private: + string server_stream_addr_; + string url_; + ::grpc::ClientContext ctx_; + std::shared_ptr<::grpc::Channel> channel_; + std::unique_ptr stub_; + std::unique_ptr<::grpc::ClientReaderWriterInterface> + reader_writer_; + + mutex mu_; +}; + +class DebugGrpcIO { + public: + static const size_t kGrpcMessageSizeLimitBytes; + static const size_t kGrpcMaxVarintLengthSize; + + // Sends a tensor through a debug gRPC stream. + static absl::Status SendTensorThroughGrpcStream( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const string& grpc_stream_url, + const bool gated); + + // Sends an Event proto through a debug gRPC stream. + // Thread-safety: Safe with respect to other calls to the same method and + // calls to CloseGrpcStream(). + // + // Args: + // event_proto: The Event proto to be sent. + // grpc_stream_url: The grpc:// URL of the stream to use, e.g., + // "grpc://localhost:11011", "localhost:22022". + // receive_reply: Whether an EventReply proto will be read after event_proto + // is sent and before the function returns. + // + // Returns: + // The Status of the operation. + static absl::Status SendEventProtoThroughGrpcStream( + const Event& event_proto, const string& grpc_stream_url, + const bool receive_reply = false); + + // Receive an EventReply proto through a debug gRPC stream. + static absl::Status ReceiveEventReplyProtoThroughGrpcStream( + EventReply* event_reply, const string& grpc_stream_url); + + // Check whether a debug watch key is read-activated at a given gRPC URL. + static bool IsReadGateOpen(const string& grpc_debug_url, + const string& watch_key); + + // Check whether a debug watch key is write-activated (i.e., read- and + // write-activated) at a given gRPC URL. + static bool IsWriteGateOpen(const string& grpc_debug_url, + const string& watch_key); + + // Closes a gRPC stream to the given address, if it exists. + // Thread-safety: Safe with respect to other calls to the same method and + // calls to SendTensorThroughGrpcStream(). + static absl::Status CloseGrpcStream(const string& grpc_stream_url); + + // Set the gRPC state of a debug node key. 
+ // TODO(cais): Include device information in watch_key. + static void SetDebugNodeKeyGrpcState( + const string& grpc_debug_url, const string& watch_key, + const EventReply::DebugOpStateChange::State new_state); + + private: + using DebugNodeName2State = + std::unordered_map; + + // Returns a global map from grpc debug URLs to the corresponding + // DebugGrpcChannels. + static std::unordered_map>* + GetStreamChannels(); + + // Get a DebugGrpcChannel object at a given URL, creating one if necessary. + // + // Args: + // grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064" + // debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as a + // a pointer to the pointer. The DebugGrpcChannel object is owned + // statically elsewhere, not by the caller of this function. + // + // Returns: + // Status of this operation. + static absl::Status GetOrCreateDebugGrpcChannel( + const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel); + + // Returns a map from debug URL to a map from debug op name to enabled state. + static std::unordered_map* + GetEnabledDebugOpStates(); + + // Returns a map from debug op names to enabled state, for a given debug URL. + static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl( + const string& grpc_debug_url); + + // Clear enabled debug op state from all debug URLs (if any). + static void ClearEnabledWatchKeys(); + + static mutex streams_mu_; + static int64_t channel_connection_timeout_micros_; + + friend class GrpcDebugTest; + friend class DebugNumericSummaryOpTest; +}; + +} // namespace tensorflow +#endif // #ifndef(PLATFORM_WINDOWS) + +#endif // TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/debug/debug_node_key.h b/third_party/tflite-hdrs/tensorflow/core/debug/debug_node_key.h new file mode 100644 index 00000000..5decb5cc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/debug/debug_node_key.h @@ -0,0 +1,56 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_ +#define TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Encapsulates debug information for a node that was observed. +struct DebugNodeKey { + static const char* const kMetadataFilePrefix; + static const char* const kDeviceTag; + + DebugNodeKey(const string& device_name, const string& node_name, + int32_t output_slot, const string& debug_op, + const string& io_of_node = "", bool is_input = false, + int32_t io_index = -1); + + // Converts a device name string to a device path string. + // E.g., /job:localhost/replica:0/task:0/cpu:0 will be converted to + // ,job_localhost,replica_0,task_0,cpu_0. 
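The IsReadGateOpen / IsWriteGateOpen / SetDebugNodeKeyGrpcState methods above imply a per-URL map from watch keys (e.g. "Weights:0:DebugIdentity") to an activation state, where the read gate opens for read- or read/write-activated keys and the write gate only for read/write-activated ones. The sketch below illustrates that gating logic with an assumed three-value state enum; it is not tied to the actual EventReply::DebugOpStateChange proto.

#include <string>
#include <unordered_map>

enum class GateState { kDisabled, kReadOnly, kReadWrite };  // assumed values

class GrpcGateSketch {
 public:
  // Records the state of `watch_key` for the given grpc:// debug URL.
  void SetState(const std::string& grpc_debug_url,
                const std::string& watch_key, GateState state) {
    states_[grpc_debug_url][watch_key] = state;
  }

  // Read gate: open if the watch key is read- or read/write-activated.
  bool IsReadGateOpen(const std::string& grpc_debug_url,
                      const std::string& watch_key) const {
    const GateState s = Lookup(grpc_debug_url, watch_key);
    return s == GateState::kReadOnly || s == GateState::kReadWrite;
  }

  // Write gate: open only if the watch key is read- and write-activated.
  bool IsWriteGateOpen(const std::string& grpc_debug_url,
                       const std::string& watch_key) const {
    return Lookup(grpc_debug_url, watch_key) == GateState::kReadWrite;
  }

 private:
  GateState Lookup(const std::string& url, const std::string& key) const {
    auto url_it = states_.find(url);
    if (url_it == states_.end()) return GateState::kDisabled;
    auto key_it = url_it->second.find(key);
    return key_it == url_it->second.end() ? GateState::kDisabled
                                          : key_it->second;
  }

  std::unordered_map<std::string,
                     std::unordered_map<std::string, GateState>>
      states_;
};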
+ static const string DeviceNameToDevicePath(const string& device_name); + + bool operator==(const DebugNodeKey& other) const; + bool operator!=(const DebugNodeKey& other) const; + + const string device_name; + const string node_name; + const int32 output_slot; + const string debug_op; + const string debug_node_name; + const string device_path; + const string io_of_node; + const bool is_input; + const int32 io_index; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/debug/debugger_state_impl.h b/third_party/tflite-hdrs/tensorflow/core/debug/debugger_state_impl.h new file mode 100644 index 00000000..c34aa8bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/debug/debugger_state_impl.h @@ -0,0 +1,61 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_ +#define TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_ + +#include "tensorflow/core/common_runtime/debugger_state_interface.h" + +#include +#include + +namespace tensorflow { + +class DebuggerState : public DebuggerStateInterface { + public: + DebuggerState(const DebugOptions& debug_options); + ~DebuggerState() override; + + // Publish metadata about the debugged Session::Run() call. + // + // See the doc string of DebuggerStateInterface::PublishDebugMetadata() for + // details. + absl::Status PublishDebugMetadata( + const int64_t global_step, const int64_t session_run_count, + const int64_t executor_step_count, const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_names) override; + + private: + std::unordered_set debug_urls_; +}; + +class DebugGraphDecorator : public DebugGraphDecoratorInterface { + public: + DebugGraphDecorator(const DebugOptions& debug_options) + : debug_options_(debug_options) {} + ~DebugGraphDecorator() override {} + + absl::Status DecorateGraph(Graph* graph, Device* device) override; + absl::Status PublishGraph(const Graph& graph, + const string& device_name) override; + + private: + DebugOptions debug_options_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h new file mode 100644 index 00000000..4713f3be --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -0,0 +1,297 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/common_runtime/eager/rendezvous_cache.h" +#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/distributed_runtime/worker_session.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/local_rendezvous.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/refcount.h" + +namespace tensorflow { + +class BaseRemoteRendezvous; +class BaseRecvTensorCall; + +// RendezvousMgr keeps track of a set of local rendezvous instances. +// All tensors sent by this worker are buffered in a RendezvousMgr +// until the tensor is received. Each global unique "step_id" +// corresponds to one local rendezvous instance managed by a +// RendezvousMgr. +// RendezvousMgr holds weak references to rendezvous. When a rendezvous is +// destructed, it will create a new instance to fulfill the Find. +// +// E.g., +// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); +// fork execution of a graph executor using "rendez" on thread 1; +// fork execution of another graph executor using "rendez" on thread 2; +// ... +// join threads 1 and 2; +// +// In the example above, execution in thread 1 and 2 communicates with +// each other by send/recv operations through `rendez`. +// +// Tensors sent and received through a rendezvous managed by this +// RendezvousMgr must have keys generated by Rendezvous::CreateKey(). +class BaseRendezvousMgr : public RendezvousMgrInterface { + public: + explicit BaseRendezvousMgr(const WorkerEnv* worker_env); + + ~BaseRendezvousMgr() override; + + // Returns Rendezvous supporting send and recv among workers in the + // "step_id". The caller takes ownership of one reference on the + // returned Rendezvous instance. + // + // Note: the caller must guarantee to eventually call Initialize on the + // returned RemoteRendezvous + tsl::core::RefCountPtr Find(int64_t step_id) override; + + // Finds the local rendezvous instance for the "step_id". Runs + // "done" when the tensor for "key" is produced or an error occurs. + // + // This method is used by the rpc handler of RecvTensor. + void RecvLocalAsync(int64_t step_id, const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) override; + + // Synchronous wrapper for RecvLocalAsync. 
+ absl::Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed, + Tensor* val, bool* is_dead) override; + + // Removes rendezvous for "step_id". + void Cleanup(int64_t step_id) override { cache_->RemoveAndAbort(step_id); } + + // Remove all rendezvous instances owned by the rendezvous_mgr. + void CleanupAll() override { cache_->RemoveAll(); } + + protected: + virtual tsl::core::RefCountPtr Create( + int64_t step_id, const WorkerEnv* worker_env) = 0; + + private: + tsl::core::RefCountPtr> cache_; + + // Not owned. + const WorkerEnv* const worker_env_; + + tsl::core::RefCountPtr FindOrCreate(int64_t step_id); + + BaseRendezvousMgr(const BaseRendezvousMgr&) = delete; + void operator=(const BaseRendezvousMgr&) = delete; +}; + +// RemoteRendezvous is a Rendezvous which can handle either +// the producer or consumer being in a remote process. +// +// Buffering of Tensor values is delegated to a "local" Rendezvous +// obtained from NewLocalRendezvous(). This class just adds +// functionality to coordinate with remote workers. +class BaseRemoteRendezvous : public RemoteRendezvous { + public: + BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id); + + // Upgrades the BaseRemoteRendezvous to full initialization. + absl::Status Initialize(WorkerSession* session) override; + + void SetRemoteEagerContextDefault() override { + remote_eager_context_default_ = true; + } + bool IsRemoteEagerContextDefault() override { + return remote_eager_context_default_; + } + + // Forwards to local_, where the Tensor "val" will be buffered and + // any waiting callback stored. + absl::Status Send(const ParsedKey& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; + + // This method is called only by the RecvOp. It tests to see + // whether the value will be produced by a local or remote device + // and handles accordingly. In the local case it forwards to + // local_, in the remote case it initiates an RPC request. + void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, + DoneCallback done) override; + + void StartAbort(const absl::Status& status) override; + + // This method is called only by the local Worker, forwarded through + // the same method on RendezvousMgr. This occurs when the Worker + // has received a RecvTensor request, either locally or over the + // network. In either case it needs to retrieve a locally buffered + // value from local_, and give it to its caller. + // + // Runs "done" as soon as the tensor for "parsed" is available or an error + // is detected. + // + // REQUIRES: "parsed" is one that will be Saved into the local rendezvous. + void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done); + + protected: + virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, + DoneCallback done) = 0; + + // Returns true if "src" and "dst" are located in the same worker, + // and hence may use a local rendezvous. + virtual bool IsSameWorker(DeviceNameUtils::ParsedName src, + DeviceNameUtils::ParsedName dst); + + // If aborted, aborts "call". Otherwise, adds "call" into calls_. + void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); + + // Removes "call" from calls_ if "call" is in calls_. + void DeregisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args); + + WorkerSession* session(); + + bool is_initialized(); + + ~BaseRemoteRendezvous() override; + + const WorkerEnv* const env_; // Not owned. 
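BaseRendezvousMgr, as described above, keys rendezvous instances by step_id, holds only weak references, and creates a fresh instance to fulfill Find once the previous one has been destructed. The standalone sketch below shows that FindOrCreate pattern with std::weak_ptr; the types are stand-ins, and Cleanup here merely drops the entry rather than aborting the rendezvous as the real RemoveAndAbort does.

#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>

struct RendezvousSketch {
  explicit RendezvousSketch(int64_t step_id) : step_id(step_id) {}
  int64_t step_id;
};

class RendezvousMgrSketch {
 public:
  // Returns a shared reference for `step_id`, creating a fresh instance if no
  // live one exists (mirroring "holds weak references to rendezvous").
  std::shared_ptr<RendezvousSketch> FindOrCreate(int64_t step_id) {
    std::lock_guard<std::mutex> lock(mu_);
    if (auto live = cache_[step_id].lock()) return live;
    auto created = std::make_shared<RendezvousSketch>(step_id);
    cache_[step_id] = created;
    return created;
  }

  // Drops the manager's entry for `step_id`; callers holding shared_ptrs keep
  // their instance alive until they release it.
  void Cleanup(int64_t step_id) {
    std::lock_guard<std::mutex> lock(mu_);
    cache_.erase(step_id);
  }

  void CleanupAll() {
    std::lock_guard<std::mutex> lock(mu_);
    cache_.clear();
  }

 private:
  std::mutex mu_;
  std::unordered_map<int64_t, std::weak_ptr<RendezvousSketch>> cache_;
};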
+ const int64_t step_id_; + + private: + int num_shards_; + LocalRendezvous local_; + // Indicates whether this remote rendezvous instance is used as the default + // rendezvous for remote eager op-by-op execution. Errors in eager op-by-op + // execution should not abort the rendezvous since it is a context-wide + // instance and needs to be reused; instead, the errors are propagated through + // eager executors. + bool remote_eager_context_default_ = false; + + mutable mutex mu_; + mutable mutex calls_mu_; + + // Status given by StartAbort() if any. + absl::Status status_ TF_GUARDED_BY(mu_); + + WorkerSession* session_ TF_GUARDED_BY(mu_); // Not owned. + + // Data structures to handle calls when partially initialized. + struct DeferredCall { + const ParsedKey parsed; + DoneCallback done; + + // Keeps a reference to the rendezvous, to keep it alive. + tsl::core::RefCountPtr rendezvous; + + DeferredCall(const ParsedKey& parsed, DoneCallback done, + tsl::core::RefCountPtr rendez); + }; + std::vector deferred_calls_ TF_GUARDED_BY(mu_); + + struct CallBucket { + mutex mu; + + absl::flat_hash_set calls TF_GUARDED_BY(mu); + }; + + struct PendingCalls { + PendingCalls(CancellationToken token, int num_calls, int num_buckets, + tsl::core::RefCountPtr rendez) + : token(token), + num_calls(num_calls), + buckets(num_buckets), + rendezvous(std::move(rendez)) {} + CancellationToken token = CancellationManager::kInvalidToken; + std::atomic num_calls = 0; + std::vector buckets; + + // Keeps a reference to the rendezvous, to keep it alive. + tsl::core::RefCountPtr rendezvous; + }; + + // "CancellationToken" is stored here so that when there's no active + // RecvTensorCalls, we can de-register the callback in the cancellation + // manager. RecvTensorCalls are managed in multiple buckets since in large + // scaled distributed training, lots of Send/Recv may be triggered + // concurrently. + // + // Note: pointer to CancellationManager can be nullptr in certain use cases. + absl::flat_hash_map> + calls_ TF_GUARDED_BY(calls_mu_); + + // Callback for CancellationManager. + void CancelledByManager(CancellationManager* cm); + + bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) { + return session_ != nullptr; + } + + // If "is_src" is true, checks that the rendezvous key "parsed"'s + // source is in this process. If "is_src" is false, checks that the + // rendezvous key "parsed"'s destination is in this process. + absl::Status ValidateDevices(const Rendezvous::ParsedKey& parsed, + bool is_src); + + // Callback handling the case when a rendezvous has been + // accomplished in local_ and the consumer is local to this process. + // Tensor "in" will be copied into "out". The key "parsed" encodes + // the src and dst devices. + void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& in_args, + const Rendezvous::Args& out_args, const Tensor& in, + Tensor* out, StatusCallback done); + + // Must be called only if fully initialized. 
+ void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done); + + BaseRemoteRendezvous(const BaseRemoteRendezvous&) = delete; + void operator=(const BaseRemoteRendezvous&) = delete; +}; + +class BaseRecvTensorCall { + public: + BaseRecvTensorCall() {} + virtual ~BaseRecvTensorCall() {} + + virtual void Start(std::function recv_done) = 0; + + virtual void StartAbort(const absl::Status& s) = 0; + + virtual absl::Status status() const = 0; + + private: + BaseRecvTensorCall(const BaseRecvTensorCall&) = delete; + void operator=(const BaseRecvTensorCall&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/call_options.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/call_options.h new file mode 100644 index 00000000..a845bcdc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/call_options.h @@ -0,0 +1,27 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_ + +#include "xla/tsl/distributed_runtime/call_options.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::CallOptions; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/cancellable_call.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/cancellable_call.h new file mode 100644 index 00000000..7311c8e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/cancellable_call.h @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_ + +#include +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// Supports client side cancellation of WorkerInterface calls via +// registration with a CancellationManager. +class CancellableCall { + public: + CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker, + WorkerCacheInterface* wc) + : is_cancelled_(false), + cancel_mgr_(cancel_mgr), + remote_worker_(remote_worker), + wc_(wc), + wi_(wc_->GetOrCreateWorker(remote_worker_)) {} + + virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); } + + virtual void IssueCall(const StatusCallback& done) = 0; + + void Start(const StatusCallback& done); + + // Cancels the RPC if it's not cancelled yet. This must be called after + // Start(). This is normally used if there's a needed to cancel the RPC from a + // sideband. If appliable, pass a cancellation manager to the constructor + // instead of using this method. + void Cancel() TF_LOCKS_EXCLUDED(mu_); + + protected: + mutex mu_; + bool is_cancelled_; + CancellationManager* const cancel_mgr_; // Not owned + const string remote_worker_; + WorkerCacheInterface* const wc_; // Not owned + WorkerInterface* const wi_; // Owned by wc_, must be released. + CallOptions opts_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h new file mode 100644 index 00000000..a016a5ee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { + +class WorkerSession; + +// ClusterFunctionLibraryRuntime contains methods to Instantiate and Run +// functions across processes by making RPCs through worker service. 
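CancellableCall, defined above, wires an RPC into a CancellationManager so it can be cancelled either through the manager or through a sideband Cancel() after Start(). The standalone sketch below shows that wiring with a minimal stand-in manager; both classes and their members are illustrative assumptions, not TensorFlow's API.

#include <functional>
#include <mutex>
#include <utility>
#include <vector>

// Minimal stand-in for a cancellation manager: callers register callbacks that
// are invoked once when cancellation is started.
class CancellationManagerSketch {
 public:
  void RegisterCallback(std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mu_);
    callbacks_.push_back(std::move(cb));
  }

  void StartCancel() {
    std::vector<std::function<void()>> cbs;
    {
      std::lock_guard<std::mutex> lock(mu_);
      cbs.swap(callbacks_);
    }
    for (auto& cb : cbs) cb();
  }

 private:
  std::mutex mu_;
  std::vector<std::function<void()>> callbacks_;
};

// RPC-like call that can be cancelled either via the manager or via a direct
// sideband Cancel(), but aborts at most once.
class CancellableCallSketch {
 public:
  CancellableCallSketch(CancellationManagerSketch* cancel_mgr,
                        std::function<void()> issue_call,
                        std::function<void()> abort_call)
      : cancel_mgr_(cancel_mgr),
        issue_call_(std::move(issue_call)),
        abort_call_(std::move(abort_call)) {}

  // Hooks up cancellation (if a manager was provided), then issues the call.
  void Start() {
    if (cancel_mgr_ != nullptr) {
      cancel_mgr_->RegisterCallback([this] { Cancel(); });
    }
    issue_call_();
  }

  // Sideband cancellation; intended to be called after Start().
  void Cancel() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (is_cancelled_) return;
      is_cancelled_ = true;
    }
    abort_call_();
  }

 private:
  CancellationManagerSketch* const cancel_mgr_;  // not owned
  std::function<void()> issue_call_;
  std::function<void()> abort_call_;
  std::mutex mu_;
  bool is_cancelled_ = false;
};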
+class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { + public: + ClusterFunctionLibraryRuntime(WorkerSession* worker_session, + bool create_worker_session_called, + DeviceMgr* remote_device_mgr) + : worker_session_(worker_session), + create_worker_session_called_(create_worker_session_called), + remote_device_mgr_(remote_device_mgr) {} + + ~ClusterFunctionLibraryRuntime() override; + + void Instantiate(const string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle, + FunctionLibraryRuntime::DoneCallback done) override; + + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + absl::Span args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) override; + + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + absl::Span args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) override; + + void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, + FunctionLibraryRuntime::DoneCallback done) override; + + DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; } + + private: + static absl::Status ConstructFunctionGraph( + const OpDef& sig, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + const FunctionLibraryDefinition& flib_def, GraphDef* g, + std::vector* send_keys, std::vector* recv_keys); + friend class ClusterFunctionLibraryRuntimeTest; + + mutable mutex mu_; + WorkerSession* const worker_session_ = nullptr; // not owned. + const bool create_worker_session_called_; + + DeviceMgr* remote_device_mgr_; // not owned. + + struct FunctionData { + const string graph_handle; + const string target; + // Hold a shared pointer to the underlying worker cache to avoid it being + // deleted in potential cluster update. + const std::shared_ptr worker_cache; + WorkerInterface* wi = nullptr; + const std::vector send_keys; + const std::vector recv_keys; + + FunctionData(const string& graph_handle, const string& target, + std::shared_ptr worker_cache, + WorkerInterface* wi, const std::vector& send_keys, + const std::vector& recv_keys) + : graph_handle(graph_handle), + target(target), + worker_cache(std::move(worker_cache)), + wi(wi), + send_keys(send_keys), + recv_keys(recv_keys) {} + }; + + std::vector function_data_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h new file mode 100644 index 00000000..63006c12 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h @@ -0,0 +1,96 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ + +#include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +class ConfigProto; +class WorkerCacheInterface; +class DeviceResolverDistributed; +class DeviceMgr; + +class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { + public: + CollectiveParamResolverDistributed( + const ConfigProto& config, const DeviceMgr* dev_mgr, + DeviceResolverDistributed* dev_resolver, + NcclCommunicatorInterface* nccl_communicator, + WorkerCacheInterface* worker_cache, const string& task_name); + + void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, + const StatusCallback& done) override; + + void CompleteGroupAsync(const DeviceAttributes& device, + CollGroupParams* group_params, + CancellationManager* cancel_mgr, + const StatusCallback& done) override; + + void CompleteInstanceAsync(const CompleteInstanceRequest* request, + CompleteInstanceResponse* response, + CancellationManager* cancel_mgr, + const StatusCallback& done) override; + + void StartAbort(const absl::Status& s) override; + + protected: + // Returns the cached group iff there's an entry for this group_key in the + // local group_table_; returns nullptr otherwise. + GroupRec* GetCachedGroup(int32_t group_key) TF_LOCKS_EXCLUDED(group_mu_); + + // Updates group_table_ with contents of resp. + absl::Status UpdateGroupCache(const CompleteGroupResponse& resp) + TF_LOCKS_EXCLUDED(group_mu_); + + // Finds the GroupRec that corresponds to cp->group_key and also + // populates cp->group from that GroupRec. + // + // Semantics are like those of CompleteGroupLocal but will make a + // remote call to the group leader if necessary. + void CompleteGroupDistributed(const DeviceAttributes& device, + CollGroupParams* group_params, + CancellationManager* cancel_mgr, + const StatusCallback& done); + + // Returns true iff there's an entry for this instance_key in the + // local instance_table_. + bool InstanceIsCached(int32_t group_key, const CollInstanceParams& instance) + TF_LOCKS_EXCLUDED(instance_mu_); + + // Updates instance_table_ with contents of resp. + absl::Status UpdateInstanceCache(CollectiveParams* cp, + const CompleteInstanceResponse& resp) + TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); + + // Finish populating *cp. Semantics are like those of + // CompleteInstanceLocal but will make a remote call to the group + // leader if necessary. 
+ void CompleteInstanceDistributed(const string& device, CollectiveParams* cp, + CancellationManager* cancel_mgr, + const StatusCallback& done) + TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); + + WorkerCacheInterface* worker_cache_; // Not owned + const string group_leader_; + CancellationManager abortion_cancel_mgr_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_PARAM_RESOLVER_DISTRIBUTED_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/collective_rma_distributed.h new file mode 100644 index 00000000..22d4d6f5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/collective_rma_distributed.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_ + +#include "tensorflow/core/common_runtime/collective_rma_local.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" + +namespace tensorflow { +class WorkerCacheInterface; + +// Extend CollectiveRemoteAccessLocal with access to remote peers. +class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { + public: + CollectiveRemoteAccessDistributed( + const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver, + std::shared_ptr work_queue, + WorkerCacheInterface* worker_cache, int64_t step_id, string task_name) + : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id), + worker_cache_(worker_cache), + work_queue_(std::move(work_queue)), + task_name_(std::move(task_name)) {} + + ~CollectiveRemoteAccessDistributed() override {} + + void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, Device* to_device, + DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor, + const DeviceLocality& client_locality, + int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, + const StatusCallback& done) override; + + void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms, + const StatusCallback& done) override; + + void StartAbort(const absl::Status& s) override; + + protected: + WorkerCacheInterface* worker_cache_; // Not owned + // Ownership of `work_queue_` is shared between `this` and + // `CollectiveExecutorMgr`. 
+ std::shared_ptr work_queue_; + CancellationManager abortion_cancel_mgr_; + string task_name_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COLLECTIVE_RMA_DISTRIBUTED_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_client.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_client.h new file mode 100644 index 00000000..0901d56b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_client.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_CLIENT_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_CLIENT_H_ + +#include +#include + +#include "xla/tsl/distributed_runtime/coordination/coordination_client.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::CoordinationClient; +using tsl::CoordinationClientCache; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h new file mode 100644 index 00000000..3e0243ab --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h @@ -0,0 +1,126 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_BARRIER_PROXY_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_BARRIER_PROXY_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// A local proxy connecting the coordination service's barrier. +// The barrier provided by coordination service can only block at tasks (i.e., +// TPU workers), but sometimes we need a barrier that can block at different +// threads. The proxy first waits at threads on a participating +// task and then issues a barrier wait to the coordination service once all the +// threads at that task have arrived. +// Usage: +// // Main thread creates a `BarrierProxy`: +// barrier = new BarrierProxy(agent, tasks, key, num_local_threads); +// +// // Each participating thread could then call: +// auto [status, last_exit] = barrier.Wait(); +// // The last exited thread is responsible for deleting the barrier. +// if (last_exit) { +// delete barrier; +// } +class BarrierProxy { + public: + BarrierProxy(const BarrierProxy&) = delete; + void operator=(const BarrierProxy&) = delete; + // Construct a BarrierProxy connected to the coordination service via `agent`. + // `tasks` specifies all participating coordinated tasks and + // `num_local_threads` specifies the number of threads in this task to + // particiate. If no tasks are specified, the barrier will block for all the + // connected tasks. + BarrierProxy(tsl::CoordinationServiceAgent* agent, + std::vector tasks, int num_local_threads, + absl::string_view key, absl::Duration timeout) + : key_(key), + agent_(agent), + tasks_(std::move(tasks)), + timeout_(timeout), + num_local_threads_(num_local_threads) {} + + ~BarrierProxy() = default; + + // Waits at the barrier. The first return value is the status when exiting the + // barrier and the second returns `true` for precisely one caller, which may + // then destroy the barrier. + std::pair Wait(); + + private: + const std::string key_; + tsl::CoordinationServiceAgent* agent_; + const std::vector tasks_; + absl::Duration timeout_; + + mutex mu_; + condition_variable cv_ TF_GUARDED_BY(mu_); + const int num_local_threads_; + int num_entered_ TF_GUARDED_BY(mu_) = 0; + int num_to_exit_ TF_GUARDED_BY(mu_) = 0; + absl::Status status_ TF_GUARDED_BY(mu_); + bool status_set_ TF_GUARDED_BY(mu_) = false; +}; + +// Manages the life cycle of BarrierProxies automatically. +// Usage: +// // Main thread creates a `BarrierProxy`: +// BarrierProxyManager barrier_mgr; +// +// // Exactly `num_local_threads` threads call: +// Status s = barrier_mgr.Wait(agent, task, num_local_threads, key, timeout); +class BarrierProxyManager { + public: + BarrierProxyManager(const BarrierProxyManager&) = delete; + void operator=(const BarrierProxyManager&) = delete; + BarrierProxyManager() = default; + ~BarrierProxyManager() = default; + + // Waits at the barrier backed by the coord service `agent` and keyed by + // `key`. 
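A quick sketch of the usage pattern documented above: every participating thread on this task calls Wait(), exactly one caller observes last_exit == true and deletes the proxy. The element type of `tasks` and the exact pair returned by Wait() are elided in this vendored copy; they are assumed here to be tensorflow::CoordinatedTask and std::pair<absl::Status, bool>. The key and timeout values are arbitrary.

#include <thread>
#include <utility>
#include <vector>

#include "absl/time/time.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h"
#include "tensorflow/core/platform/logging.h"

void WaitAtBarrierSketch(tsl::CoordinationServiceAgent* agent,
                         std::vector<tensorflow::CoordinatedTask> tasks,
                         int num_local_threads) {
  // Heap-allocate: the proxy must outlive every thread still inside Wait().
  auto* barrier = new tensorflow::BarrierProxy(
      agent, std::move(tasks), num_local_threads,
      /*key=*/"train_step_barrier", /*timeout=*/absl::Minutes(5));
  std::vector<std::thread> threads;
  threads.reserve(num_local_threads);
  for (int i = 0; i < num_local_threads; ++i) {
    threads.emplace_back([barrier] {
      auto [status, last_exit] = barrier->Wait();
      if (!status.ok()) {
        LOG(WARNING) << "Barrier wait failed: " << status;
      }
      // Exactly one thread sees last_exit == true and owns the cleanup.
      if (last_exit) delete barrier;
    });
  }
  for (auto& t : threads) t.join();
}

BarrierProxyManager::Wait (declared just below) packages exactly this lifetime handling, so most call sites would go through the manager rather than deleting the proxy by hand.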
`tasks` specifies all participating coordinated tasks and + // `num_local_threads` specifies the number of threads in this task to + // participate. If no tasks are specified, the barrier will block for all the + // connected tasks. + absl::Status Wait(tsl::CoordinationServiceAgent* agent, + const std::vector& tasks, + int num_local_threads, absl::string_view key, + absl::Duration timeout); + // The number of active BarrierProxies. + size_t size() const; + + private: + mutable mutex mu_; + absl::flat_hash_map> barriers_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_BARRIER_PROXY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h new file mode 100644 index 00000000..aa6dfa41 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_ERROR_UTIL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_ERROR_UTIL_H_ + +#include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::CoordinationErrorPayloadKey; +using ::tsl::MakeCoordinationError; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_ERROR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_rpc_handler.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_rpc_handler.h new file mode 100644 index 00000000..d378684d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/coordination/coordination_service_rpc_handler.h @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_RPC_HANDLER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_RPC_HANDLER_H_ + +#include "xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::CoordinationServiceRpcHandler; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_RPC_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/device_resolver_distributed.h new file mode 100644 index 00000000..b46c288c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/device_resolver_distributed.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +class DeviceMgr; +class WorkerCacheInterface; + +class DeviceResolverDistributed : public DeviceResolverInterface { + public: + explicit DeviceResolverDistributed(const DeviceMgr* dev_mgr); + + absl::Status GetDeviceAttributes(const string& device, + DeviceAttributes* attributes) override; + + absl::Status GetAllDeviceAttributes( + const string& task, std::vector* attributes) override; + + absl::Status UpdateDeviceAttributes( + const std::vector& attributes) override; + + protected: + const string task_name_; + mutex mu_; + absl::flat_hash_map attr_table_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h new file mode 100644 index 00000000..58af5ed9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h @@ -0,0 +1,115 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ + +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" + +namespace tensorflow { + +class WorkerSession; + +namespace eager { + +// EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run +// functions across processes by making RPCs through eager service. +class EagerClusterFunctionLibraryRuntime + : public DistributedFunctionLibraryRuntime { + public: + EagerClusterFunctionLibraryRuntime(const uint64 context_id, EagerContext* ctx, + DeviceMgr* remote_device_mgr) + : context_id_(context_id), + ctx_(ctx), + remote_device_mgr_(remote_device_mgr) {} + + ~EagerClusterFunctionLibraryRuntime() override{}; + + // Register a partition (i.e., component function) of a multi-device function + // on the remote target specified in `options.target`. This should be + // triggered as part of instantiating a multi-device function in + // ProcessFunctionLibraryRuntime. + void Instantiate(const string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle, + FunctionLibraryRuntime::DoneCallback done) override; + + // Execute the component function specified by `handle` on its instantiated + // remote target. This should be triggered as part of driving a multi-device + // function execution in ProcessFunctionLibraryRuntime. Running the component + // function remotely is purely asynchronous, and multiple component functions + // with the same remote target are not executed in any particular ordering. + // The main function side must wait for all component functions to finish + // (i.e., the done callbacks triggered) before finishing its execution. + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + absl::Span args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) override; + + // The component function inputs `args` and outputs `rets` may refer to remote + // tensors on a remote device, which will be lazily resolved remotely where + // the inputs/outputs are actually consumed. 
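A minimal sketch of the contract spelled out above: the caller fans out one Run() per component function and must not finish until every done callback has fired. BlockingCounter comes from tensorflow/core/platform/blocking_counter.h; the element types of `args`/`rets` are elided in this vendored copy and are assumed to be Tensor here, and all caller-side inputs (handles, per-component args) are assumed to exist.

#include <vector>

#include "absl/status/status.h"
#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"

void RunComponentsAndWaitSketch(
    tensorflow::eager::EagerClusterFunctionLibraryRuntime* flr,
    const tensorflow::FunctionLibraryRuntime::Options& opts,
    const std::vector<tensorflow::FunctionLibraryRuntime::LocalHandle>& handles,
    const std::vector<std::vector<tensorflow::Tensor>>& args,
    std::vector<std::vector<tensorflow::Tensor>>* rets) {
  tensorflow::BlockingCounter pending(static_cast<int>(handles.size()));
  tensorflow::mutex mu;
  absl::Status overall;
  for (size_t i = 0; i < handles.size(); ++i) {
    // Each Run() is purely asynchronous; completion is signalled only through
    // its done callback, in no particular order across components.
    flr->Run(opts, handles[i], args[i], &(*rets)[i],
             [&pending, &mu, &overall](const absl::Status& s) {
               if (!s.ok()) {
                 tensorflow::mutex_lock l(mu);
                 overall.Update(s);
               }
               pending.DecrementCount();
             });
  }
  // The "main function" side may only finish after every component function
  // has reported back.
  pending.Wait();
  TF_CHECK_OK(overall);  // a real caller would propagate instead of CHECK-ing
}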
+ void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + absl::Span args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) override; + + void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, + FunctionLibraryRuntime::DoneCallback done) override; + + DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; } + + private: + const uint64 context_id_; + EagerContext* ctx_; + DeviceMgr* remote_device_mgr_; // not owned. + + struct FunctionData { + const string target; + const absl::optional> ret_indices; + core::RefCountPtr eager_client; + std::unique_ptr op; + + FunctionData(const string& target, + const absl::optional>& ret_indices, + EagerClient* eager_client, std::unique_ptr op) + : target(target), + ret_indices(ret_indices), + eager_client(core::RefCountPtr(eager_client)), + op(std::move(op)) { + eager_client->Ref(); + } + }; + + mutable mutex mu_; + std::vector function_data_ TF_GUARDED_BY(mu_); +}; + +DistributedFunctionLibraryRuntime* CreateClusterFLR( + const uint64 context_id, EagerContext* ctx, WorkerSession* worker_session); + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h new file mode 100644 index 00000000..a9b9ead8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h @@ -0,0 +1,90 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_DESTROY_TENSOR_HANDLE_NODE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_DESTROY_TENSOR_HANDLE_NODE_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" + +namespace tensorflow { +namespace eager { + +// DestroyTensorHandleNode is an implementation of EagerNode which enqueues a +// request to destroy a remote tensor handle. +class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { + public: + DestroyTensorHandleNode(std::unique_ptr request, + core::RefCountPtr eager_client, + bool ready) + : tensorflow::AsyncEagerNode(), + request_(std::move(request)), + eager_client_(std::move(eager_client)), + ready_(ready) {} + + ~DestroyTensorHandleNode() override {} + + void RunAsync(StatusCallback done) override { + EnqueueResponse* response = new EnqueueResponse; + bool ready = ready_; + // NOTE(fishx): Don't use StreamingEnqueueAsync here. 
When a + // StreamingEnqueueAsync request fails all following requests will fail as + // well. We don't want this request poison following requests since it is + // safe to ignore a failing destroy tensor handle request. + eager_client_->EnqueueAsync( + /*call_opts=*/nullptr, request_.get(), response, + [response, ready, done](const absl::Status& s) { + // Omit the warning if: + // 1. The remote tensor isn't ready. + // 2. Lost connection to remote worker. In this case client will + // crash. We don't want to spam user with redundant warning logs. + if (!s.ok() && ready && !absl::IsUnavailable(s)) { + LOG_EVERY_N_SEC(WARNING, 60) + << "Ignoring an error encountered when deleting " + "remote tensors handles: " + << s.ToString(); + } + done(absl::OkStatus()); + delete response; + }); + } + + void Abort(absl::Status status) override {} + + // Remote node deletions are best effort + bool Fatal() const override { return false; } + + string DebugString() const override { + string out = "[DestroyTensorHandleNode]"; + strings::StrAppend(&out, " request: ", request_->DebugString()); + return out; + } + + private: + std::unique_ptr request_; + core::RefCountPtr eager_client_; + const string remote_task_; + bool ready_; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_DESTROY_TENSOR_HANDLE_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/eager_client.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/eager_client.h new file mode 100644 index 00000000..6fc95601 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_CLIENT_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_CLIENT_H_ + +#include + +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" + +namespace tensorflow { +namespace eager { + +// This is a base class that can be implemented by a variety of +// transports (e.g. gRPC which for each of the client methods makes an RPC). 
+class EagerClient : public core::RefCounted { + public: + ~EagerClient() override {} +#define CLIENT_METHOD(method) \ + virtual void method##Async(const method##Request* request, \ + method##Response* response, \ + StatusCallback done) = 0; + + CLIENT_METHOD(CreateContext); + CLIENT_METHOD(UpdateContext); + CLIENT_METHOD(WaitQueueDone); + CLIENT_METHOD(KeepAlive); + CLIENT_METHOD(CloseContext); + +#undef CLIENT_METHOD + +#define CLIENT_METHOD_WITH_TIMEOUT_AND_RETRIES(method) \ + virtual void method##Async(const method##Request* request, \ + method##Response* response, StatusCallback done, \ + int64_t init_timeout_in_ms, int retries) = 0; + + CLIENT_METHOD_WITH_TIMEOUT_AND_RETRIES(CreateContext); + +#undef CLIENT_METHOD_WITH_TIMEOUT_AND_RETRIES + +#define CLIENT_CANCELABLE_METHOD(method) \ + virtual void method##Async( \ + CallOptions* call_opts, const method##Request* request, \ + method##Response* response, StatusCallback done) = 0; + + CLIENT_CANCELABLE_METHOD(Enqueue); + CLIENT_CANCELABLE_METHOD(RunComponentFunction); + +#undef CLIENT_CANCELABLE_METHOD + + // Feeds `request` into the request stream of EagerService::StreamingEnqueue. + // `response` will be filled with the response for this `request`. The + // 1-to-1 correspondence between requests and responses is a property + // of the current service implementation. When the response is received, + // `done` is invoked with the current status of the StreamingEnqueue call. + // The status can contain an error because of an earlier request in the + // current streaming call. + // The client initiates a streaming call the first time StreamingEnqueueAsync + // is invoked and keeps it open until some error condition. + // Similarly to the methods above, the request can be deleted as soon as + // StreamingEnqueueAsync returns. + virtual void StreamingEnqueueAsync(bool enable_streaming_enqueue, + CallOptions* call_opts, + const EnqueueRequest* request, + EnqueueResponse* response, + StatusCallback done) = 0; + + virtual bool allow_multiple_pending_requests() const = 0; +}; + +// Simple wrapper class that can be used to retrieve EagerClients. +class EagerClientCache { + public: + virtual ~EagerClientCache() {} + + // If the `target` exists, assign the EagerClient pointer to `client` and + // increment the refcount of the client. The reference ownership is + // transferred to the caller, and the unref should automatically happen when + // destructing the RefCountPtr object from the caller's side. + virtual absl::Status GetClient(const string& target, + core::RefCountPtr* client) = 0; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/eager_service_impl.h new file mode 100644 index 00000000..924a99dd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -0,0 +1,243 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
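For reference, CLIENT_METHOD(KeepAlive) above expands to the pure virtual KeepAliveAsync(const KeepAliveRequest*, KeepAliveResponse*, StatusCallback). Below is a small sketch of driving that method through an EagerClientCache, leaning on the ownership rule documented above (GetClient hands back an owned reference that RefCountPtr releases automatically). The context id is a made-up value.

#include <string>

#include "absl/status/status.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/refcount.h"

void PingRemoteWorkerSketch(tensorflow::eager::EagerClientCache* cache,
                            const std::string& target) {
  tensorflow::core::RefCountPtr<tensorflow::eager::EagerClient> client;
  absl::Status s = cache->GetClient(target, &client);
  if (!s.ok()) {
    LOG(ERROR) << "No eager client for " << target << ": " << s;
    return;
  }
  // Request/response must stay alive until the callback runs, so they are
  // heap-allocated and released inside the callback.
  auto* request = new tensorflow::eager::KeepAliveRequest;
  auto* response = new tensorflow::eager::KeepAliveResponse;
  request->set_context_id(42);  // hypothetical context id
  client->KeepAliveAsync(request, response,
                         [request, response](const absl::Status& status) {
                           LOG(INFO) << "KeepAlive finished: " << status;
                           delete request;
                           delete response;
                         });
  // `client` goes out of scope here; the Unref happens in RefCountPtr's
  // destructor, as described in the GetClient comment above.
}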
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" +#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" + +namespace tensorflow { +namespace eager { + +// A TensorFlow Eager Worker runs ops and supports worker to worker +// Tensor transfer. +// +// See eager_service.proto for more details about each method. +// This class can be wrapped by specific classes that implement rpc transports +// over this (e.g. gRPC). +class EagerServiceImpl { + public: + explicit EagerServiceImpl(WorkerEnv* env) : env_(env) { + gc_thread_.reset( + env_->env->StartThread({}, "EagerServiceContextGC", [this]() { + while (true) { + { + mutex_lock l(gc_thread_shutdown_mu_); + gc_thread_cv_.wait_for(l, std::chrono::seconds(1)); + + if (shutting_down_) { + return; + } + } + { + mutex_lock l(contexts_mu_); + for (auto it = contexts_.begin(); it != contexts_.end();) { + if (it->second->IsStale()) { + it->second->Unref(); + it = contexts_.erase(it); + } else { + it++; + } + } + } + } + })); + } + virtual ~EagerServiceImpl() { + { + mutex_lock l(gc_thread_shutdown_mu_); + shutting_down_ = true; + gc_thread_cv_.notify_all(); + } + gc_thread_.reset(); + + mutex_lock l(contexts_mu_); + for (auto& entry : contexts_) { + entry.second->Unref(); + } + } + + absl::Status CreateContext(const CreateContextRequest* request, + CreateContextResponse* response); + + absl::Status UpdateContext(const UpdateContextRequest* request, + UpdateContextResponse* response); + + // Create a ServerContext for master eager context. + absl::Status CreateMasterContext(const tensorflow::uint64 context_id, + EagerContext* context); + + static constexpr uint64 kInvalidStreamId = 0; + + // Used by both Enqueue and StreamingEnqueue RPCs. + absl::Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request, + EnqueueResponse* response, + uint64 stream_id = kInvalidStreamId); + + absl::Status WaitQueueDone(const WaitQueueDoneRequest* request, + WaitQueueDoneResponse* response); + + void RunComponentFunction(CallOptions* call_opts, + const RunComponentFunctionRequest* request, + RunComponentFunctionResponse* response, + StatusCallback done); + + absl::Status KeepAlive(const KeepAliveRequest* request, + KeepAliveResponse* response); + + absl::Status CloseContext(const CloseContextRequest* request, + CloseContextResponse* response); + + protected: + // This is the server-side execution context. All state regarding execution of + // a client's ops is held in this server-side context (all generated tensors, + // and the EagerContext). + class ServerContext : public core::RefCounted { + public: + // Create a ServerContext for local master. 
+ static ServerContext* CreateMasterContext(tensorflow::EagerContext* ctx, + const WorkerEnv* env) { + return new ServerContext(ctx, -1, env, /* is_master= */ true); + } + + explicit ServerContext(tensorflow::EagerContext* ctx, + int64_t destroy_after_secs, const WorkerEnv* env, + const bool is_master = false) + : ctx_(ctx), env_(env), is_master_(is_master) { + ctx->Ref(); + destroy_after_micros_ = + destroy_after_secs * tensorflow::EnvTime::kSecondsToMicros; + RecordAccess(); + } + + ~ServerContext() override { + // TFE_Context is responsible for shutting down master eager context. + if (!is_master_) { + ctx_->WaitForAndCloseRemoteContexts(); + } + // ctx_->RefCountIsOne() should be true here when is_master_ = false. + // TODO(iga): Remove EagerContext refcounting. + ctx_->Unref(); + } + + tensorflow::EagerContext* Context() const { return ctx_; } + + void RecordAccess() { + mutex_lock l(last_accessed_mu_); + last_accessed_micros_ = env_->env->NowMicros(); + } + + bool IsStale() { + mutex_lock l(last_accessed_mu_); + const int64_t time_passed = + env_->env->NowMicros() - last_accessed_micros_; + return (destroy_after_micros_ > 0 && time_passed > destroy_after_micros_); + } + + private: + // The context for this execution. + tensorflow::EagerContext* ctx_; + + const WorkerEnv* const env_; // Not owned. + + mutex last_accessed_mu_; + int64_t last_accessed_micros_ TF_GUARDED_BY(last_accessed_mu_); + int64_t destroy_after_micros_; + + const bool is_master_; + }; + // The returned ServerContext will need to be Unrefed. + absl::Status GetServerContext(uint64, ServerContext**); + + class ClientTensorHandleDeleteNode : public EagerNode { + public: + ClientTensorHandleDeleteNode( + ServerContext* context, + std::unique_ptr handle_to_delete) + : tensorflow::EagerNode(), + context_(context), + handle_to_delete_(std::move(handle_to_delete)) { + context_->Ref(); + } + + ~ClientTensorHandleDeleteNode() override { context_->Unref(); } + + absl::Status Run() override { + VLOG(3) << "ServerContext: Deleting tensor handle " + << handle_to_delete_->op_id << ":" + << handle_to_delete_->output_num; + return context_->Context()->RemoteMgr()->DeleteTensorHandle( + *handle_to_delete_); + } + + void Abort(absl::Status status) override {} + + // Remote node deletions are best effort + bool Fatal() const override { return false; } + + string DebugString() const override { + string out = "[ClientTensorHandleDeleteNode]"; + strings::StrAppend(&out, " op_id: ", handle_to_delete_->op_id); + strings::StrAppend(&out, ", output_num: ", handle_to_delete_->output_num); + return out; + } + + private: + // Owns one reference. + ServerContext* const context_; + const std::unique_ptr handle_to_delete_; + }; + + private: + absl::Status ExecuteOp(CallOptions* call_opts, const Operation& operation, + EagerContext* eager_context, + EagerExecutor* eager_executor, + QueueResponse* queue_response); + absl::Status SendTensor(const SendTensorOp& send_tensor, + EagerContext* eager_context); + absl::Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle, + EagerContext* eager_context); + absl::Status RegisterFunction(const RegisterFunctionOp& register_function, + EagerContext* eager_context); + absl::Status RemoveFunction(const RemoveFunctionOp& remove_function, + EagerContext* eager_context); + absl::Status CleanupFunction(const CleanupFunctionOp& cleanup_function); + + WorkerEnv* const env_; // Not owned. 
+ + mutex contexts_mu_; + std::unordered_map contexts_ + TF_GUARDED_BY(contexts_mu_); + + std::unique_ptr gc_thread_; + mutex gc_thread_shutdown_mu_; + condition_variable gc_thread_cv_; + bool shutting_down_ TF_GUARDED_BY(gc_thread_shutdown_mu_) = false; + + EagerServiceImpl(const EagerServiceImpl&) = delete; + void operator=(const EagerServiceImpl&) = delete; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_EAGER_SERVICE_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_copy_node.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_copy_node.h new file mode 100644 index 00000000..32f3befd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_copy_node.h @@ -0,0 +1,179 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace eager { + +// This node supports copying a tensor in the following way: +// - Remote -> Local: +// We don't block on the remote _Send op and start executing the local +// _Recv immediately after issuing the remote _Send. The local _Recv +// kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run) +// blocks until the tensor is received. If the remote _Send (or some op +// before it) fails, the local callback we give to EnqueueAsync will run +// and call CancellationManager.StartCancel(). The blocked local _Recv will +// get this notification and return with a cancelled error. +// +// - Local -> Remote: +// The local _Send op is synchronous and non-blocking, thus it should complete +// quickly. We issue remote _Recv RPC only after local _Send completes +// successfully. At this point, the tensor to be sent is in the local +// Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor +// to appear. +// When ctx->UseSendTensorRPC() is true, we use EagerService::Enqueue +// SendTensor instead of _Send/_Recv. +// +// - Remote -> Remote: +// We could issue both remote ops asynchronously, but if remote _Send (or some +// op before it) fails, we don't have a good way of cancelling the remote +// _Recv. The remote _Recv will deadlock in this case. The current approach +// to deal with this issue is to wait for remote _Send to complete before +// issuing remote _Recv RPC. Another option is to close the whole streaming +// RPC that contains the deadlocked remote _Recv. 
This would not unblock the +// deadlocked RPC on the remote machine without some extra code. Luckily, the +// remote -> remote case seems to be fairly rare at this point. So, the +// current partially synchronous approach seems fine. +// +// To copy a tensor within a host, please use copy_to_device_node instead. +class RemoteCopyNode : public AsyncEagerNode { + public: + RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src, + TensorHandle* dst, Device* recv_device, uint64 recv_op_id); + + ~RemoteCopyNode() override; + + absl::Status Prepare() override; + + void RunAsync(StatusCallback done) override; + + void Abort(absl::Status status) override; + + string DebugString() const override { + string out = "[RemoteCopyNode]"; + strings::StrAppend(&out, " send_device: ", send_device_->name()); + strings::StrAppend(&out, ", recv_device: ", recv_device_->name()); + strings::StrAppend(&out, ", send_tensor: ", src_->DebugString()); + strings::StrAppend( + &out, ", recv_tensor: ", captured_state_->dst()->DebugString()); + return out; + } + + private: + // Runs the _Send operation locally or remotely. + // StartSend() makes sure that captured_state_->send_status_ is set to the + // final _Send status after captured_state->send_done_.WaitForNotification() + // returns. + void StartSend(); + + // Synchronously runs local send `op` and returns its status. + absl::Status RunLocalSend(EagerOperation* op); + + // Runs the _Recv operation locally or remotely. + // An error return value indicates that _Recv did not run successfully. It + // does not indicate that _Send op has completed since StartRecv could have + // encountered an error before waiting for _Send's completion. + // An OK return value does NOT necessarily indicate that _Recv has completed + // successfully (it does now, but won't when streaming RPCs are turned on). + // StartRecv() makes sure that dst_ tensor handle is handled correctly + // (potentially after this methods returns); a tensor is set in the local + // case, a remote shape is set in the remote case, the dst_ handle is + // poisoned in either case if there is an error. + void StartRecv(StatusCallback done); + + // Synchronously runs local receive `op` and returns its status. + // Does not wait for the send to complete before running receive. + absl::Status RunLocalRecv(EagerOperation* op, std::vector* outputs); + + // Waits for send to complete, then issues remote receive `op` and + // returns its status. + void RunRemoteRecv(EagerOperation* op, StatusCallback done); + + // When !ctx->UseSendTensorRPC(), then tensors are shipped between remote + // devices by the receiver invoking the WorkerService.RecvTensor RPC *on the + // sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel). + // + // However, in some configurations the node that has the tensor to be copied + // isn't running a server (WorkerService RPC interface). For such cases, + // this function enables sending tensors using the EagerService.Enqueue + // SendTensor RPC *on the receiver*. + void StartRemoteSendTensor(StatusCallback done); + + // Send a local packed TensorHandle to a remote device. + void StartSendPackedHandle(StatusCallback done); + + // State that is captured by Send and/or Recv callbacks (depending on which + // one(s) is remote) and outlives this node in the case of remote->remote + // copy. 
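A stand-alone sketch of the remote -> remote ordering described above, with absl primitives standing in for the CapturedSharedState declared just below: the send callback publishes its final status, and the receive path blocks on that status before issuing the remote _Recv, so a failed send can never strand the _Recv. The names here are illustrative, not the node's real helpers.

#include <utility>

#include "absl/status/status.h"
#include "absl/synchronization/notification.h"

struct SendStateSketch {  // plays the role of CapturedSharedState's send half
  absl::Notification send_done;
  absl::Status send_status;
};

// Called from the _Send enqueue callback (compare SetSendStatus below).
inline void OnRemoteSendFinishedSketch(SendStateSketch* state,
                                       const absl::Status& s) {
  state->send_status = s;
  state->send_done.Notify();
}

// Called on the receive path (compare GetSendStatus below).
template <typename IssueRemoteRecvFn, typename DoneFn>
void StartRecvAfterSendSketch(SendStateSketch* state,
                              IssueRemoteRecvFn issue_remote_recv,
                              DoneFn done) {
  state->send_done.WaitForNotification();  // wait for the remote _Send result
  if (!state->send_status.ok()) {
    // Never issue the remote _Recv for a failed send: it would block forever
    // waiting for a tensor that is never produced.
    done(state->send_status);
    return;
  }
  issue_remote_recv(std::move(done));
}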
+ class CapturedSharedState { + public: + explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); } + ~CapturedSharedState() { dst_->Unref(); } + + void SetSendStatus(absl::Status status) { + send_status_.Update(status); + send_done_.Notify(); + } + + absl::Status GetSendStatus() { + send_done_.WaitForNotification(); + return send_status_; + } + + // src_shape_ is not thread-safe. It should only be set in one thread. + void SetSrcShape(const TensorShape& shape) { src_shape_ = shape; } + + const TensorShape& GetSrcShape() { return src_shape_; } + + TensorHandle* dst() { return dst_; } + CancellationManager* recv_cancellation() { return &recv_cancellation_; } + + private: + TensorHandle* const dst_; + CancellationManager recv_cancellation_; + // send_status_ is safe to read only after send_done_.WaitForNotification() + // has returned. + absl::Status send_status_; + Notification send_done_; + TensorShape src_shape_; + }; + + TensorHandle* const src_; + EagerContext* const ctx_; + EagerExecutor* const executor_; + Device* const send_device_; + Device* const recv_device_; + const string wire_id_; + const uint64 recv_op_id_; + + std::shared_ptr captured_state_; + bool started_; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_execute_node.h new file mode 100644 index 00000000..d1c5359d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/shape_inference.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" + +namespace tensorflow { +namespace eager { + +// RemoteExecuteNode is an implementation of EagerNode which enqueues +// an operation via RPC in a remote EagerService. 
+class RemoteExecuteNode : public AsyncRemoteExecuteNode { + public: + RemoteExecuteNode(EagerContext* eager_context, + std::unique_ptr request, Device* device, + uint64 context_view_id, EagerClient* eager_client, + CancellationManager* cancellation_manager, + const NodeDef& ndef, + const FunctionLibraryDefinition* lib_def, + const absl::InlinedVector& inputs, + absl::Span retvals) + : AsyncRemoteExecuteNode(), + eager_context_(eager_context), + request_(std::move(request)), + device_(device), + context_view_id_(context_view_id), + eager_client_(eager_client), + cancellation_manager_(cancellation_manager), + ndef_(ndef), + lib_def_(lib_def), + inputs_(inputs) { + // Copy the output handles, since the container for them might get + // destroyed. + for (auto handle : retvals) { + handle->Ref(); + retvals_.push_back(handle); + } + + // This is required to ensure that the tensor handles stay alive across the + // execution. + for (auto handle : inputs_) { + handle->Ref(); + } + eager_client_->Ref(); + + needs_remote_inputs_ = false; + for (const TensorHandle* input : inputs_) { + // TODO(bramandia): Should this be op_device() instead? + if (input->resource_device() != nullptr && + input->resource_device() != device_) { + needs_remote_inputs_ = true; + break; + } + } + } + + ~RemoteExecuteNode() override { + for (auto handle : retvals_) { + handle->Unref(); + } + + for (auto handle : inputs_) { + handle->Unref(); + } + eager_client_->Unref(); + } + + absl::Status Prepare() override { + return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_); + } + + void RunAsync(StatusCallback done) override; + + absl::Status SyncExecutors() override { + return eager_context_->SyncExecutors(); + } + + void Abort(absl::Status status) override { + int i = 0; + for (auto handle : retvals_) { + handle->PoisonRemote(status, device_, context_view_id_); + ++i; + } + } + + const EagerClient* eager_client() const override { return eager_client_; } + + bool needs_remote_inputs() const override { return needs_remote_inputs_; } + + bool allow_multiple_pending_requests() const override { + return eager_client_->allow_multiple_pending_requests(); + } + + string DebugString() const override { + string out = "[RemoteExecuteNode]"; + strings::StrAppend(&out, " request: ", request_->DebugString()); + strings::StrAppend(&out, ", target_device: ", device_->name()); + return out; + } + + private: + EagerContext* eager_context_; // Not owned, and must outlive this node. + std::unique_ptr request_; + Device* device_; // Not owned + uint64 context_view_id_; + bool needs_remote_inputs_; + EagerClient* eager_client_; // Not owned, and must outlive this node. + CancellationManager* cancellation_manager_; + const NodeDef ndef_; + const FunctionLibraryDefinition* lib_def_; + absl::InlinedVector inputs_; + absl::InlinedVector retvals_; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_mgr.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_mgr.h new file mode 100644 index 00000000..b62134cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_mgr.h @@ -0,0 +1,139 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_MGR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace eager { + +// This class manages the states required to setup an eager cluster. +// TODO(fishx): Move remote state from context to this class. +class RemoteMgr { + public: + RemoteMgr(bool is_master, EagerContext* ctx) + : is_master_(is_master), parent_(ctx) {} + + ~RemoteMgr() { + for (const auto& entry : remote_tensor_handle_map_) { + entry.second->Unref(); + } + } + + bool IsMaster() { return is_master_; } + + void AddOperationOutputs( + const absl::Span handles, + int64_t operation_id); + + void AddOperationOutput(tensorflow::TensorHandle* handles, + int64_t operation_id, int32_t output_num); + + absl::Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle, + tensorflow::TensorHandle** handle); + + absl::Status DeleteTensorHandle( + const RemoteTensorHandleInternal& remote_handle); + + // Helper function to create monotonically increasing ids unique to this + // context. + uint64 NextOpId() { + DCHECK(is_master_); + mutex_lock l(next_id_mutex_); + return next_op_id_++; + } + + // Serialize a remote TensorHandle to a RemoteTensorHandle. + // If wait_until_ready is true, block until the remote handle is ready on a + // remote worker. + absl::Status SerializeRemoteTensorHandle( + TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out, + Device* device, absl::string_view device_name = "", + const bool serialize_resource_dtype_and_shape = false); + + // Deserialize a RemoteTensorHandle to a TensorHandle(local/remote). + // The output holds a reference to the TensorHandle. + absl::Status DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, + TensorHandle** out); + + EagerExecutor& GetOrCreateExecutorForStream(uint64 stream_id); + + void DeleteExecutorForStream(uint64 stream_id); + + protected: + mutex next_id_mutex_; + uint64 next_op_id_ TF_GUARDED_BY(next_id_mutex_) = 1; + + private: + // Returns the op_id and output_num if the given local TensorHandle exists in + // remote_tensor_handle_map_. 
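A sketch of the master-side bookkeeping this API implies: mint a unique op id, record which TensorHandle that op produced, then serialize it into the proto form carried by eager RPCs. The span/element types elided in this vendored copy are assumed to be TensorHandle*, and `handle`/`device` are assumed to come from the surrounding eager execution.

#include "absl/status/status.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"

absl::Status ExportHandleSketch(tensorflow::eager::RemoteMgr* remote_mgr,
                                tensorflow::TensorHandle* handle,
                                tensorflow::Device* device,
                                tensorflow::eager::RemoteTensorHandle* out) {
  // Ids are only minted on the master (NextOpId DCHECKs is_master_).
  const tensorflow::uint64 op_id = remote_mgr->NextOpId();

  // Make the handle addressable by (op_id, output_num) for remote workers;
  // per the comment above, the map owns a reference on the handle.
  remote_mgr->AddOperationOutput(handle, op_id, /*output_num=*/0);

  // Produce the wire form. wait_until_ready=true blocks until the tensor is
  // actually ready on its worker, as documented on this method above.
  return remote_mgr->SerializeRemoteTensorHandle(
      handle, /*wait_until_ready=*/true, out, device);
}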
+ absl::Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, + const bool wait_until_ready, + int64_t* op_id, int32* output_num) + TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_); + + absl::Status GetTensorHandleImpl( + const RemoteTensorHandleInternal& remote_handle, + tensorflow::TensorHandle** handle) + TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_); + + absl::Status GetMirroredResourceShape( + const RemoteTensorHandleInternal& remote_handle, + std::vector* handle); + + bool is_master_; + + using RemoteTensorHandleMap = + gtl::FlatMap; + using MirroredResourceShapeMap = gtl::FlatMap< + RemoteTensorHandleInternal, std::vector, + RemoteTensorHandleInternalHash, RemoteTensorHandleInternalEquals>; + + mutex remote_tensor_handle_mu_; + // This map maintains the TensorHandles that are required by remote workers + // in the cluster. Each map key is generated by the master, so it should be + // globally unique. This map owns references on the handles it contains. + RemoteTensorHandleMap remote_tensor_handle_map_ + TF_GUARDED_BY(remote_tensor_handle_mu_); + + mutex mirrored_resource_shape_mu_; + // This map maintains the data types and shapes of resource variables required + // by remote workers in the cluster. Each map key is generated by the master, + // so it should be globally unique. + MirroredResourceShapeMap mirrored_resource_shape_map_ + TF_GUARDED_BY(mirrored_resource_shape_mu_); + + EagerContext* parent_; // not owned. + + mutex executor_map_mu_; + std::unordered_map executor_map_ + TF_GUARDED_BY(executor_map_mu_); +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h new file mode 100644 index 00000000..903d0191 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_tensor_handle.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_H_ + +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" + +namespace tensorflow { +namespace eager { + +struct RemoteTensorHandleInternal { + explicit RemoteTensorHandleInternal(const RemoteTensorHandle& tensor_handle) + : op_id(tensor_handle.op_id()), output_num(tensor_handle.output_num()) {} + RemoteTensorHandleInternal(int64_t op_id, int32_t output_num) + : op_id(op_id), output_num(output_num) {} + int64_t op_id; + int32 output_num; +}; + +struct RemoteTensorHandleInternalHash { + std::size_t operator()(const RemoteTensorHandleInternal& handle) const { + return FingerprintCat64(handle.op_id, handle.output_num); + } +}; + +struct RemoteTensorHandleInternalEquals { + bool operator()(const RemoteTensorHandleInternal& first, + const RemoteTensorHandleInternal& second) const { + return first.op_id == second.op_id && first.output_num == second.output_num; + } +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h new file mode 100644 index 00000000..892d82bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h @@ -0,0 +1,84 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ + +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only +// the shape is known. +class RemoteTensorHandleData { + public: + // Constructor for lazy remote handles. A lazy remote handle is created on + // a remote worker with an op_id and an output_num. It doesn't control the + // lifetime of a remote handle that it refers to. If it refers to a remote + // function input, it's sent by a client which won't serialize it until + // the corresponding remote tensor is ready. So the remote tensor should be + // ready when we create a lazy remote handle. If it refers to a remote output, + // it's not ready until the shape is set. + RemoteTensorHandleData(int64_t op_id, int output_num, uint64 context_view_id, + bool is_ready); + // Constructor for unshaped remote handles. It controls the lifetime of a + // remote handle that it refers to. 
+ RemoteTensorHandleData(int64_t op_id, int output_num, + const string& remote_task, EagerContext* ctx); + ~RemoteTensorHandleData(); + + // A remote tensor handle does not have a Tensor object, hence it can only + // support the shape requests. + absl::Status Shape(TensorShape* shape) const; + absl::Status NumDims(int* num_dims) const; + absl::Status Dim(int dim_index, int64_t* dim) const; + absl::Status NumElements(int64_t* num_elements) const; + absl::Status Unprotect() { return absl::OkStatus(); } + + bool IsReady() const; + absl::Status WaitReady(const char* caller) const; + absl::Status SetShape(const TensorShape& shape); + absl::Status SetShapeAndRemoteTask(const TensorShape& shape, + const string& remote_task); + void Poison(absl::Status status); + absl::Status IsPoisoned() const; + + string DebugString() const; + + // Return the op id and output num. If wait_until_ready is true, block until + // the remote tensor is ready on a remote worker. + absl::Status OpIdAndOutputNum(bool wait_until_ready, int64_t* op_id, + int32* output_num) const; + + uint64 context_view_id() const { return context_view_id_; } + + private: + mutable mutex mu_; + bool is_ready_ TF_GUARDED_BY(mu_); + absl::Status is_poisoned_ TF_GUARDED_BY(mu_); + TensorShape shape_ TF_GUARDED_BY(mu_); + + // IDs required when this class is representing a remote tensor handle. + const int64_t op_id_; + const int32 output_num_; + string remote_task_ TF_GUARDED_BY(mu_); + uint64 context_id_; + uint64 context_view_id_; + EagerContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/error_payloads.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/error_payloads.h new file mode 100644 index 00000000..ae3b3e5e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/error_payloads.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_ERROR_PAYLOADS_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_ERROR_PAYLOADS_H_ + +// This file lists the proto payloads that may be inserted by the code within +// `tensorflow/core/distributed_runtime/` into Status instances. + +namespace tensorflow { +// Proto: tensorflow::distributed_runtime::WorkerPossiblyRestarted +// Location: tensorflow/core/protobuf/distributed_runtime_payloads.proto +// Usage: Flags the Status to be a possible outcome of a worker restart. 
+constexpr char kWorkerPossiblyRestarted[] = + "type.googleapis.com/" + "tensorflow.distributed_runtime.WorkerPossiblyRestarted"; + +constexpr char kWorkerPreemption[] = + "type.googleapis.com/tensorflow.distributed_runtime.WorkerPreemption"; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_ERROR_PAYLOADS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/graph_mgr.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/graph_mgr.h new file mode 100644 index 00000000..5c8c7ce0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/graph_mgr.h @@ -0,0 +1,214 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/costmodel_manager.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/debug.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" +#include "tsl/platform/thread_annotations.h" + +namespace tsl { +class CoordinationServiceAgent; +} + +namespace tensorflow { + +class ExecutorOpts; +class StepStatsCollector; +class RendezvousMgrInterface; +class DeviceMgr; +class WorkerSession; + +// GraphMgr keeps track of a set of graphs that are registered with a +// TensorFlow worker. Each registered graph is identified by a handle +// that is generated by GraphMgr and returned to the caller. +// +// After a successful registration, the caller executes a graph using +// the graph handle. Each execution is distinguished from others by a +// caller generated global unique id "step_id". Multiple executions +// can use the same graph concurrently and independently as long as +// "step_id" used are different. +// +// Multiple threads can call GraphMgr methods concurrently. 
+//
+// E.g.,
+//   GraphMgr gmgr(worker_env);
+//   string handle;
+//   TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
+//                             &handle));
+//   GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
+//                                 { "b", Tensor({3, 4}) } };
+//   GraphMgr::NamedTensors out = { { "c", Tensor() } };
+//   TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
+//   EXPECT_EQ(out["c"], Tensor({4, 6}));
+class GraphMgr {
+ public:
+  explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
+  ~GraphMgr();
+
+  // Registers a graph. Fills in "handle". The registered graph retains a
+  // reference to cluster_flr to do cross process function calls.
+  absl::Status Register(const string& handle, const GraphDef& gdef,
+                        const GraphOptions& graph_options,
+                        const DebugOptions& debug_options,
+                        const ConfigProto& config_proto,
+                        int64_t collective_graph_key, WorkerSession* session,
+                        DistributedFunctionLibraryRuntime* cluster_flr,
+                        string* graph_handle);
+
+  // Executes one step of a registered graph "handle".
+  //
+  // If "out" is not nullptr, "out" specifies all keys the execution
+  // should receive upon finish.
+  typedef std::map<string, Tensor> NamedTensors;
+  typedef std::function<void(const absl::Status&)> StatusCallback;
+  void ExecuteAsync(const string& handle, const int64_t step_id,
+                    const ExecutorOpts& opts, const NamedTensors& in,
+                    WorkerSession* session, StepStatsCollector* collector,
+                    MutableRunGraphResponseWrapper* response,
+                    CancellationManager* cancellation_manager,
+                    tsl::CoordinationServiceAgent* coordination_service_agent,
+                    StatusCallback done);
+
+  absl::Status SendInputs(const int64_t step_id, const NamedTensors& in);
+  absl::Status RecvOutputs(const int64_t step_id, NamedTensors* out);
+  void RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
+                        StatusCallback done);
+
+  // Deregisters a graph.
+  absl::Status Deregister(const string& handle);
+
+  // Deregister all graphs.
+  absl::Status DeregisterAll();
+
+ private:
+  typedef GraphMgr ME;
+
+  struct ExecutionUnit {
+    std::unique_ptr<Graph> graph = nullptr;
+    Device* device = nullptr;               // not owned.
+    Executor* root = nullptr;               // not owned.
+    FunctionLibraryRuntime* lib = nullptr;  // not owned.
+    // Build the cost model if this value is strictly positive.
+    int64_t build_cost_model = 0;
+  };
+
+  struct Item : public core::RefCounted {
+    // TODO(zhifengc): Keeps a copy of the original graph if the need arises.
+    // TODO(zhifengc): Stats, updated by multiple runs potentially.
+    // TODO(zhifengc): Dup-detection. Ensure step_id only run once.
+    ~Item() override;
+
+    // Session handle.
+    string session;
+
+    // Graph handle.
+    string handle;
+
+    // Session configuration options for the graph.
+    ConfigProto session_config;
+
+    std::unique_ptr<FunctionLibraryDefinition> lib_def;
+    // Owns the FunctionLibraryRuntime objects needed to execute functions, one
+    // per device.
+    std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
+    // A graph is partitioned over multiple devices. Each partition
+    // has a root executor which may call into the runtime library.
+    std::vector<ExecutionUnit> units;
+
+    // Used to deregister a cost model when cost model is required in graph
+    // manager.
+    GraphMgr* graph_mgr;
+
+    int64_t collective_graph_key;
+  };
+
+  const WorkerEnv* worker_env_;  // Not owned.
+  const DeviceMgr* device_mgr_;
+
+  CostModelManager cost_model_manager_;
+
+  // Owned.
+  mutex mu_;
+  int64_t next_id_ TF_GUARDED_BY(mu_) = 0;
+
+  // If true, blocks until device has finished all queued operations in a step.
+  bool sync_on_finish_ = true;
+
+  // Table mapping graph handles to registered graphs.
+ // + // TODO(zhifengc): If the client does not call Deregister, we'll + // lose memory over time. We should implement a timeout-based + // mechanism to gc these graphs. + std::unordered_map table_; + + void StartParallelExecutors( + const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous, + CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector, + CostGraphDef* cost_graph, CancellationManager* cancellation_manager, + WorkerSession* session, int64_t start_time_usecs, + tsl::CoordinationServiceAgent* coordination_service_agent, + StatusCallback done); + + // Don't attempt to process cost models unless explicitly requested for at + // least one of the items. + bool skip_cost_models_ = true; + + void BuildCostModel(Item* item, StepStatsCollector* collector, + CostGraphDef* cost_graph); + + absl::Status InitItem(const string& handle, const GraphDef& gdef, + const GraphOptions& graph_options, + const DebugOptions& debug_options, + const ConfigProto& config_proto, + int64_t collective_graph_key, WorkerSession* session, + DistributedFunctionLibraryRuntime* cluster_flr, + Item* item); + + absl::Status DecorateAndPublishGraphForDebug( + const DebugOptions& debug_options, Graph* graph, Device* device); + + GraphMgr(const GraphMgr&) = delete; + void operator=(const GraphMgr&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/local_master.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/local_master.h new file mode 100644 index 00000000..e4fc37e4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/local_master.h @@ -0,0 +1,113 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ + +#include + +#include "tensorflow/core/distributed_runtime/master_interface.h" + +namespace tensorflow { + +class Master; + +// An implementation of the TensorFlow master interface that enables direct +// intraprocess communication between the client and the master implementation. +// +// This master implementation is intended to provide more efficient access to +// a master service that has been created in the same process as the client. +// +// TODO(mrry): Add methods that avoid protobuf encoding the request/response +// objects where this affects performance. +// TODO(mrry): Avoid closure creation/context switch overhead for synchronous +// invocation of Master methods. +// TODO(mrry): Make all potentially blocking Master methods take CallOptions +// for cancellation. 
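+//
+// A minimal usage sketch, reading only the declarations below (the target
+// string, ownership notes, and control flow here are illustrative
+// placeholders, not upstream documentation):
+//
+//   Master* master = ...;  // owned by the server; must outlive all wrappers
+//   LocalMaster::Register("grpc://localhost:2222", master,
+//                         /*default_timeout_in_ms=*/0);
+//
+//   auto local = LocalMaster::Lookup("grpc://localhost:2222");
+//   if (local != nullptr) {
+//     // Issue CreateSession/RunStep/CloseSession calls without an RPC
+//     // round trip.
+//   } else {
+//     // No master was registered for this target; fall back to an
+//     // RPC-based MasterInterface implementation.
+//   }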
+class LocalMaster : public MasterInterface { + public: + ~LocalMaster() override {} + + absl::Status CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, + CreateSessionResponse* response) override; + + absl::Status ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, + ExtendSessionResponse* response) override; + + absl::Status PartialRunSetup(CallOptions* call_options, + const PartialRunSetupRequest* request, + PartialRunSetupResponse* response) override; + + absl::Status RunStep(CallOptions* call_options, + RunStepRequestWrapper* request, + MutableRunStepResponseWrapper* response) override; + + MutableRunStepRequestWrapper* CreateRunStepRequest() override; + + MutableRunStepResponseWrapper* CreateRunStepResponse() override; + + absl::Status CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, + CloseSessionResponse* response) override; + + absl::Status ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, + ListDevicesResponse* response) override; + + // See tensorflow::Reset() and the comment on ResetRequest. + absl::Status Reset(CallOptions* call_options, const ResetRequest* request, + ResetResponse* response) override; + + absl::Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) override; + absl::Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) override; + absl::Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) override; + + // Registers the mapping from the given `target` to the given `master`. + // + // WARNING: The `master` pointer remains owned by the caller. It is + // the responsibility of the caller to ensure that `master` outlives + // any LocalMaster objects that may wrap this master. There is no + // corresponding deregister method, since clean server shutdown is + // not currently implemented for any server type. + static void Register(const string& target, Master* master, + int64_t default_timeout_in_ms); + + // Returns a pointer to the local master associated with the given + // `target`, or nullptr if none exists. + static std::unique_ptr Lookup(const string& target); + + private: + Master* master_impl_; // Not owned. + const int64_t default_timeout_in_ms_; + + // See `LocalMaster::Lookup` for the factory function that creates + // objects of this type. + LocalMaster(Master* master_impl, const int64_t default_timeout_in_ms); + + LocalMaster(const LocalMaster&) = delete; + void operator=(const LocalMaster&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_LOCAL_MASTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master.h new file mode 100644 index 00000000..a3930249 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master.h @@ -0,0 +1,118 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_ + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/master_session.h" +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +class Master { + public: + explicit Master(MasterEnv* env, double session_gc_seconds); + virtual ~Master(); + + // Convenient typedef for a closure passing a Status. + typedef std::function MyClosure; + + void CreateSession(const CreateSessionRequest* req, + CreateSessionResponse* resp, MyClosure done); + + void ExtendSession(const ExtendSessionRequest* req, + ExtendSessionResponse* resp, MyClosure done); + + void PartialRunSetup(const PartialRunSetupRequest* req, + PartialRunSetupResponse* resp, MyClosure done); + + void RunStep(CallOptions* opts, const RunStepRequestWrapper* req, + MutableRunStepResponseWrapper* resp, MyClosure done); + + void CloseSession(const CloseSessionRequest* req, CloseSessionResponse* resp, + MyClosure done); + + void ListDevices(const ListDevicesRequest* req, ListDevicesResponse* resp, + MyClosure done); + + // See tensorflow::Reset() and the comment on ResetRequest. + void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done); + + void MakeCallable(const MakeCallableRequest* req, MakeCallableResponse* resp, + MyClosure done); + void RunCallable(CallOptions* opts, const RunCallableRequest* req, + RunCallableResponse* resp, MyClosure done); + void ReleaseCallable(const ReleaseCallableRequest* req, + ReleaseCallableResponse* resp, MyClosure done); + + private: + typedef Master ME; + + // Not owned. + MasterEnv* env_ = nullptr; + + // Owned. + mutex mu_; + + // shutdown_ is set to true by the dtor. + condition_variable shutdown_cv_; + bool shutdown_ TF_GUARDED_BY(mu_) = false; + Thread* gc_thread_; + + // Maps session handles to sessions. + std::unordered_map sessions_ TF_GUARDED_BY(mu_); + + // Moving average of step times. + MovingAverage last_1000_steps_ TF_GUARDED_BY(mu_); + + // Cumulative number of steps executed. + int64_t step_count_ TF_GUARDED_BY(mu_); + + // If a session is not active for this many seconds, it will be + // closed automatically. + const double session_gc_seconds_; + + // Used to track ids for incoming requests so we can detect duplicates. + RecentRequestIds recent_request_ids_; + + // Call CleanupAll on all workers. + void CleanupWorkers(const ResetRequest& reset); + + // Cleanup unused session. 
+ void GC(); + + // Find master session by session handle, and increments the reference count + // on the returned MasterSession if not null. + MasterSession* FindMasterSession(const string& handle); + + Master(const Master&) = delete; + void operator=(const Master&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_env.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_env.h new file mode 100644 index 00000000..b8dcf196 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_env.h @@ -0,0 +1,113 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ + +#include +#include + +#include "xla/tsl/protobuf/rpc_options.pb.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/protobuf/cluster.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tsl { +class Env; +} // namespace tsl +namespace tensorflow { +using Env = tsl::Env; + +class CollectiveExecutorMgrInterface; +class Device; +class DeviceSet; +class MasterSession; +class OpRegistryInterface; + +// Options passed to the worker_cache_factory function. +struct WorkerCacheFactoryOptions { + ClusterDef cluster_def; + string job_name; + int task_index; + int replica_index = 0; + RPCOptions rpc_options; + + explicit WorkerCacheFactoryOptions() = default; + + // Construct from a ServerDef proto. + explicit WorkerCacheFactoryOptions(const ServerDef& server_def) { + if (server_def.has_cluster() && !server_def.job_name().empty()) { + cluster_def = server_def.cluster(); + job_name = server_def.job_name(); + task_index = server_def.task_index(); + rpc_options = server_def.default_session_config().rpc_options(); + replica_index = server_def.replica(); + } + } +}; + +// The master environment class, which holds a bag of pointers to +// per-master state. +// +// MasterEnv does not own its member pointers. +struct MasterEnv { + Env* env = nullptr; + + // Object from which WorkerInterface instances can be obtained. Not owned. + WorkerCacheInterface* worker_cache = nullptr; + + // The operation definitions to use. Must be filled before use. + const OpRegistryInterface* ops = nullptr; + + // Local devices co-located with this master. Devices are not owned + // by the master service. + // + // REQUIRES: !local_devices.empty(). + std::vector local_devices; + + // In large scaled distributed training, many singleton components (e.g. + // Rendezvous) can becomes the bottleneck of the system. This field allows + // us to shard the single components. 
This number will scale up with number + // of tasks in this cluster. It is always greater than 1. + int experimental_num_shards = 1; + + // Factory for creating master sessions, given session options and a + // vector of devices. + // + // The caller of the function takes ownership of the returned + // `MasterSession`, which may not be null. Ownership of the + // `MasterEnv*` is retained by the caller. + std::function>>, + std::unique_ptr, + std::unique_ptr device_set, + std::vector filtered_worker_list)> + master_session_factory; + + std::function + worker_cache_factory; + + // Generates per-step CollectiveExecutors and has access to utilities + // supporting collective operations. Not owned. + CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_interface.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_interface.h new file mode 100644 index 00000000..df9894f7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_interface.h @@ -0,0 +1,118 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ + +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/distributed_runtime/request_id.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/master.pb.h" + +namespace tensorflow { + +// Abstract interface for communicating with the TensorFlow Master service. +// +// This interface supports both RPC-based master implementations, and +// in-process master implementations that do not require an RPC +// roundtrip. 
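+//
+// A minimal calling-convention sketch based on the declarations below (the
+// session handle, feed tensor, and tensor names are illustrative
+// placeholders): wrappers returned by CreateRunStepRequest() and
+// CreateRunStepResponse() must be used with the same MasterInterface
+// instance that created them.
+//
+//   MasterInterface* master = ...;
+//   CallOptions call_opts;
+//   std::unique_ptr<MutableRunStepRequestWrapper> req(
+//       master->CreateRunStepRequest());
+//   req->set_session_handle(session_handle);
+//   req->add_feed("a:0", a_tensor);
+//   req->add_fetch("c:0");
+//   std::unique_ptr<MutableRunStepResponseWrapper> resp(
+//       master->CreateRunStepResponse());
+//   absl::Status s = master->RunStep(&call_opts, req.get(), resp.get());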
+class MasterInterface { + public: + virtual ~MasterInterface() {} + virtual absl::Status CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, + CreateSessionResponse* response) = 0; + + virtual absl::Status ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, + ExtendSessionResponse* response) = 0; + + virtual absl::Status PartialRunSetup(CallOptions* call_options, + const PartialRunSetupRequest* request, + PartialRunSetupResponse* response) { + return errors::Unimplemented("Partial run not implemented for this master"); + } + + virtual absl::Status RunStep(CallOptions* call_options, + RunStepRequestWrapper* request, + MutableRunStepResponseWrapper* response) = 0; + + virtual absl::Status RunStep(CallOptions* call_options, + const RunStepRequest* request, + RunStepResponse* response) { + std::unique_ptr wrapped_request( + new ProtoRunStepRequest(request)); + std::unique_ptr wrapped_response( + new NonOwnedProtoRunStepResponse(response)); + return RunStep(call_options, wrapped_request.get(), wrapped_response.get()); + } + + // Returns a request object for use in calls to + // `RunStep()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunStep()` call on the same `MasterInterface` instance. + virtual MutableRunStepRequestWrapper* CreateRunStepRequest() { + MutableProtoRunStepRequest* ret = new MutableProtoRunStepRequest; + ret->request_.set_request_id(GetUniqueRequestId()); + return ret; + } + + // Returns a response object for use in calls to + // `RunStep()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunStep()` call on the same `MasterInterface` instance. + virtual MutableRunStepResponseWrapper* CreateRunStepResponse() { + return new OwnedProtoRunStepResponse; + } + + virtual absl::Status CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, + CloseSessionResponse* response) = 0; + + virtual absl::Status ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, + ListDevicesResponse* response) = 0; + + virtual absl::Status Reset(CallOptions* call_options, + const ResetRequest* request, + ResetResponse* response) = 0; + + virtual absl::Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) = 0; + virtual absl::Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) = 0; + virtual absl::Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) = 0; + + protected: + // NOTE: This should only be called by implementations of this + // interface whose CreateRunStepResponse() method returns a + // proto-based wrappers for the RunStepResponse message. + RunStepResponse* get_proto_from_wrapper( + MutableRunStepResponseWrapper* wrapper) { + return wrapper->get_proto(); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_session.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_session.h new file mode 100644 index 00000000..f7016518 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/master_session.h @@ -0,0 +1,265 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ + +#include +#include + +#include "tensorflow/core/common_runtime/debugger_state_interface.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" +#include "tensorflow/core/common_runtime/stats_publisher_interface.h" +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class Device; +struct MasterEnv; + +// A session encapsulates a graph computation (resource allocation, +// placement, execution, etc.). +class MasterSession : public core::RefCounted { + public: + // This session encapsulates the graph computation for a graph. + // + // The session places nodes on devices in "remote_devs" and executes + // operations on these devices. + // + // The caller takes ownership of all remote devices. + MasterSession( + const SessionOptions& options, const MasterEnv* env, + std::unique_ptr>> remote_devs, + std::unique_ptr worker_cache, + std::unique_ptr device_set, + std::vector filtered_worker_list, + StatsPublisherFactory stats_publisher_factory); + + // Initialize the MasterSession for "def". Must be called before Extend(), + // Run(), or Close(). + absl::Status Create(GraphDef&& def, const ClusterDef& cluster_def); + + // Returns the session handle. + const string& handle() const { return handle_; } + + // Returns the last access time (the number of micro-seconds since + // some fixed point in time) of this session. + uint64 last_access_time_usec() const { return last_access_time_usec_.load(); } + + // Attempt to extend the graph according to the given "req". + // (See master.proto for details of valid extensions.) + // + // PRECONDITION: The current version of this session's graph + // is "req->current_graph_version". + // + // POSTCONDITION: The current version of this session's graph + // is "resp->new_graph_version". + // + // Extend() may block the caller thread for a long time. + absl::Status Extend(const ExtendSessionRequest* req, + ExtendSessionResponse* resp); + + // Setup a partial run call. + absl::Status PartialRunSetup(const PartialRunSetupRequest* req, + PartialRunSetupResponse* resp); + + // Run one step. 
+  absl::Status Run(CallOptions* opts, const RunStepRequestWrapper& req,
+                   MutableRunStepResponseWrapper* resp);
+
+  absl::Status ListDevices(ListDevicesResponse* resp) const;
+
+  absl::Status MakeCallable(const MakeCallableRequest& req,
+                            MakeCallableResponse* resp);
+
+  absl::Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
+                           RunCallableResponse* resp);
+
+  absl::Status ReleaseCallable(const ReleaseCallableRequest& req,
+                               ReleaseCallableResponse* resp);
+
+  // Close this session and delete "*this". Returns OK if all known
+  // states are cleaned up successfully.
+  //
+  // Close() may block the caller thread for a long time.
+  absl::Status Close();
+
+  // Close this session and release a reference on "*this".
+  //
+  // Note that, unlike Close(), this method does not block on the
+  // completion of all work.
+  void GarbageCollect();
+
+ private:
+  SessionOptions session_opts_;
+
+  // Not owned.
+  const MasterEnv* env_;
+
+  // The opaque session handle.
+  const string handle_;
+
+  std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
+
+  // The optional session-specific worker cluster.
+  // TODO(saeta): Convert to std::optional when available.
+  const std::unique_ptr<WorkerCacheInterface> worker_cache_;
+  // Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
+  WorkerCacheInterface* get_worker_cache() const;
+
+  // The device set used by this session.
+  std::unique_ptr<DeviceSet> devices_;
+
+  // The (partial device) names of remote worker tasks that this
+  // session will contact.
+  const std::vector<string> filtered_worker_list_;
+
+  StatsPublisherFactory stats_publisher_factory_;
+
+  std::atomic_ulong last_access_time_usec_;
+
+  std::atomic<int64_t> partial_run_handle_counter_ = {0};
+
+  uint64 NewStepId(int64_t graph_key);
+
+  mutex mu_;
+  std::unique_ptr<GraphExecutionState> execution_state_ TF_GUARDED_BY(mu_);
+  int64_t graph_version_;
+
+  // We keep a map from a signature of a run request to the
+  // ReffedClientGraph that can execute it. We keep up to one old copy
+  // of each ReffedClientGraph around because if it gets deallocated
+  // before a new substitute has been created, Variables can go out of
+  // scope and lose their state.
+  class ReffedClientGraph;
+  typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
+  RCGMap run_graphs_ TF_GUARDED_BY(mu_);
+  RCGMap partial_run_graphs_ TF_GUARDED_BY(mu_);
+  int64_t next_callable_handle_ TF_GUARDED_BY(mu_) = 0;
+  RCGMap callables_ TF_GUARDED_BY(mu_);
+
+  struct PerStepState {
+    bool collect_costs = false;
+    bool collect_timeline = false;
+    bool collect_rpcs = false;
+    bool collect_partition_graphs = false;
+    bool report_tensor_allocations_upon_oom = false;
+    Microseconds start_micros = Microseconds(0);
+    Microseconds end_micros = Microseconds(0);
+    std::vector<StepStats> step_stats;  // per partition
+    StepStats rpc_stats;                // for RPC layer
+    CostGraphDef cost_graph;
+  };
+
+  struct RunState {
+    std::unordered_map<string, bool> pending_inputs;   // true if fed
+    std::unordered_map<string, bool> pending_outputs;  // true if fetched
+    ReffedClientGraph* rcg = nullptr;
+    uint64 step_id;
+    int64_t collective_graph_key;
+    int64_t count = 0;
+    PerStepState pss;
+    std::unique_ptr<ProfileHandler> ph;
+    bool step_started = false;
+
+    RunState(const std::vector<string>& input_names,
+             const std::vector<string>& output_names, ReffedClientGraph* rcg,
+             const uint64 step_id, const int64_t count);
+
+    bool PendingDone() const;
+
+    ~RunState();
+  };
+  std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
+      TF_GUARDED_BY(mu_);
+
+  // Active RunStep calls.
+ condition_variable num_running_is_zero_; + int32 num_running_ TF_GUARDED_BY(mu_) = 0; + + bool closed_ TF_GUARDED_BY(mu_) = false; + bool garbage_collected_ TF_GUARDED_BY(mu_) = false; + + std::unordered_map subgraph_execution_counts_ + TF_GUARDED_BY(mu_); + + // We need to ensure that certain nodes added (e.g., send and recv + // nodes) are unique across all sub-graphs within this session. + int64_t next_node_id_ TF_GUARDED_BY(mu_) = 0; + + // Used to cancel running steps on Close(). + CancellationManager cancellation_manager_; + + // Private dtor. The client must call Close(). + ~MasterSession() override; + + // Creates sessions on all workers. + // + // If this session is operating using the new ClusterSpec propagation behavior + // call this method in order to propagate the cluster membership to all + // workers. + absl::Status CreateWorkerSessions(const ClusterDef& cluster_def); + + bool should_delete_worker_sessions_ = false; + absl::Status DeleteWorkerSessions(); + + absl::Status StartStep(const BuildGraphOptions& opts, bool is_partial, + ReffedClientGraph** out_rcg, int64_t* out_count); + void ClearRunsTable(std::vector* to_unref, + RCGMap* rcg_map) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void FillPerStepState(MasterSession::ReffedClientGraph* rcg, + const RunOptions& run_options, uint64 step_id, + int64_t count, PerStepState* out_pss, + std::unique_ptr* out_ph); + absl::Status DoRunWithLocalExecution(CallOptions* opts, + const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp); + absl::Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp); + absl::Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, + const RunCallableRequest& req, + RunCallableResponse* resp); + absl::Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, + uint64 step_id, const RunOptions& run_options, + PerStepState* pss, + const std::unique_ptr& ph, + const absl::Status& run_status, + RunMetadata* out_run_metadata); + + void MarkRunCompletion(); + void UpdateLastAccessTime(); + + absl::Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); + + absl::Status CreateDebuggerState( + const DebugOptions& debug_options, const RunStepRequestWrapper& req, + int64_t rcg_execution_count, + std::unique_ptr* debugger_state); + + MasterSession(const MasterSession&) = delete; + void operator=(const MasterSession&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/message_wrappers.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/message_wrappers.h new file mode 100644 index 00000000..d4b07fb5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/message_wrappers.h @@ -0,0 +1,746 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +//////////////////////////////////////////////////////////////////////////////// +// +// Wrapper classes for the `MasterService.RunStep` request message. +// +// The `RunStepRequest` message can contain potentially large tensor +// data as part of its `feed` submessages. Here we provide specialized +// wrappers that avoid copying the tensor data wherever possible. +// +// See `RunStepRequest` in tensorflow/core/protobuf/master.proto for the +// protocol buffer definition. +// +//////////////////////////////////////////////////////////////////////////////// + +// Abstract interface for an immutable RunStepRequest message. +// +// This interface is typically used by server-side components in the +// TensorFlow master. +class RunStepRequestWrapper { + public: + virtual ~RunStepRequestWrapper() {} + + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. + virtual const string& session_handle() const = 0; + + // Partial run handle (optional). If specified, this will be a partial run + // execution, run up to the specified fetches. + virtual const string& partial_run_handle() const = 0; + + // Tensors to be fed in the step. Each feed is a named tensor. + virtual size_t num_feeds() const = 0; + virtual const string& feed_name(size_t i) const = 0; + + // Stores the content of the feed value at index `i` in `tensor`. + virtual absl::Status FeedValue(size_t i, Tensor* out_tensor) const = 0; + virtual absl::Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; + + // Fetches. A list of tensor names. The caller expects a tensor to + // be returned for each fetch[i] (see RunStepResponse.tensor). The + // order of specified fetches does not change the execution order. + virtual size_t num_fetches() const = 0; + virtual const string& fetch_name(size_t i) const = 0; + + // Target Nodes. A list of node names. The named nodes will be run + // to but their outputs will not be fetched. + virtual size_t num_targets() const = 0; + virtual const string& target_name(size_t i) const = 0; + + // Options for the run call. + virtual const RunOptions& options() const = 0; + + // If true then some errors, e.g., execution errors that have long + // error messages, may return an OK RunStepResponse with the actual + // error saved in the status_code/status_error_message fields of the + // response body. This is a workaround since the RPC subsystem may + // truncate long metadata messages. + virtual bool store_errors_in_response_body() const = 0; + + // Unique identifier for this request. Every RunGraphRequest must have a + // unique request_id, and retried RunGraphRequests must have the same + // request_id. If request_id is zero, retry detection is disabled. 
+  virtual int64_t request_id() const = 0;
+
+  // Returns a human-readable representation of this message for debugging.
+  virtual string DebugString() const = 0;
+
+  // Returns the wrapped data as a protocol buffer message.
+  virtual const RunStepRequest& ToProto() const = 0;
+};
+
+// Abstract interface for a mutable RunStepRequest message.
+//
+// See `RunStepRequestWrapper` above for a description of the fields.
+class MutableRunStepRequestWrapper : public RunStepRequestWrapper {
+ public:
+  virtual void set_session_handle(const string& handle) = 0;
+  virtual void set_partial_run_handle(const string& handle) = 0;
+  virtual void add_feed(const string& name, const Tensor& value) = 0;
+  virtual void add_fetch(const string& name) = 0;
+  virtual void add_target(const string& name) = 0;
+  virtual RunOptions* mutable_options() = 0;
+  virtual void set_store_errors_in_response_body(bool store_errors) = 0;
+};
+
+// Specialized (and mutable) wrapper for RunStep requests between a client and
+// master in the same address space.
+class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
+ public:
+  // RunStepRequestWrapper methods.
+  const string& session_handle() const override;
+  const string& partial_run_handle() const override;
+  size_t num_feeds() const override;
+  const string& feed_name(size_t i) const override;
+  absl::Status FeedValue(size_t i, Tensor* out_tensor) const override;
+  absl::Status FeedValue(size_t i, TensorProto* out_tensor) const override;
+  size_t num_fetches() const override;
+  const string& fetch_name(size_t i) const override;
+  size_t num_targets() const override;
+  const string& target_name(size_t i) const override;
+  const RunOptions& options() const override;
+  string DebugString() const override;
+  const RunStepRequest& ToProto() const override;
+  bool store_errors_in_response_body() const override;
+  int64_t request_id() const override;
+
+  // MutableRunStepRequestWrapper methods.
+  void set_session_handle(const string& handle) override;
+  void set_partial_run_handle(const string& handle) override;
+  void add_feed(const string& name, const Tensor& value) override;
+  void add_fetch(const string& name) override;
+  void add_target(const string& name) override;
+  RunOptions* mutable_options() override;
+  void set_store_errors_in_response_body(bool store_errors) override;
+
+ private:
+  string session_handle_;
+  string partial_run_handle_;
+  absl::InlinedVector<std::pair<string, Tensor>, 4UL> feeds_;
+  absl::InlinedVector<string, 4UL> fetches_;
+  absl::InlinedVector<string, 4UL> targets_;
+  RunOptions options_;
+  bool store_errors_in_response_body_ = false;
+
+  // Holds a cached and owned representation of the proto
+  // representation of this request, if needed, so that `ToProto()`
+  // can return a const RunStepRequest&.
+  // NOTE(mrry): Although calls to `ToProto()` on this class are
+  // expected to be rare, retaining ownership of the returned message
+  // makes it easier to return a reference from the proto-backed
+  // representations.
+  mutable std::unique_ptr<RunStepRequest> proto_version_;
+};
+
+// Wrapper for mutable RunStep requests that uses a protobuf message.
+//
+// This wrapper class should be used for RunStep requests between a
+// client and master in different address spaces.
+class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
+ public:
+  // RunStepRequestWrapper methods.
+ const string& session_handle() const override; + const string& partial_run_handle() const override; + size_t num_feeds() const override; + const string& feed_name(size_t i) const override; + absl::Status FeedValue(size_t i, Tensor* out_tensor) const override; + absl::Status FeedValue(size_t i, TensorProto* out_tensor) const override; + size_t num_fetches() const override; + const string& fetch_name(size_t i) const override; + size_t num_targets() const override; + const string& target_name(size_t i) const override; + const RunOptions& options() const override; + string DebugString() const override; + const RunStepRequest& ToProto() const override; + bool store_errors_in_response_body() const override; + int64_t request_id() const override; + + // MutableRunStepRequestWrapper methods. + void set_session_handle(const string& handle) override; + void set_partial_run_handle(const string& handle) override; + void add_feed(const string& name, const Tensor& value) override; + void add_fetch(const string& name) override; + void add_target(const string& name) override; + RunOptions* mutable_options() override; + void set_store_errors_in_response_body(bool store_errors) override; + + private: + RunStepRequest request_; + friend class MasterInterface; +}; + +// Wrapper for immutable RunStep requests that use a non-owned +// protobuf message. +// +// This interface is typically used by server-side components in the +// TensorFlow master, where the incoming message is a (possibly const) +// `RunStepRequest*`. +class ProtoRunStepRequest : public RunStepRequestWrapper { + public: + ProtoRunStepRequest(const RunStepRequest* request); + + // RunStepRequestWrapper methods. + const string& session_handle() const override; + const string& partial_run_handle() const override; + size_t num_feeds() const override; + const string& feed_name(size_t i) const override; + absl::Status FeedValue(size_t i, Tensor* out_tensor) const override; + absl::Status FeedValue(size_t i, TensorProto* out_tensor) const override; + size_t num_fetches() const override; + const string& fetch_name(size_t i) const override; + size_t num_targets() const override; + const string& target_name(size_t i) const override; + const RunOptions& options() const override; + string DebugString() const override; + const RunStepRequest& ToProto() const override; + bool store_errors_in_response_body() const override; + int64_t request_id() const override; + + private: + const RunStepRequest* const request_; // Not owned. +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Wrapper classes for the `WorkerService.RunGraph` request message. +// +// The `RunGraphRequest` message can contain potentially large tensor +// data as part of its `send` submessages. Here we provide specialized +// wrappers that avoid copying the tensor data wherever possible. +// +// See `RunGraphRequest` in tensorflow/core/protobuf/worker.proto for the +// protocol buffer definition. +// +//////////////////////////////////////////////////////////////////////////////// + +// Abstract interface for an immutable RunGraphRequest message. +// +// This interface is typically used by server-side components in the +// TensorFlow worker. +class RunGraphRequestWrapper { + public: + virtual ~RunGraphRequestWrapper() {} + + // The session handle used to register the graph. If empty, a single global + // namespace is used. 
+ virtual const string& session_handle() const = 0; + + // Set to true if `CreateWorkerSession` was called for `session_handle`. + virtual bool create_worker_session_called() const = 0; + + // REQUIRED: graph_handle must be returned by a RegisterGraph call + // to the same WorkerService. + virtual const string& graph_handle() const = 0; + + // A unique ID to distinguish different runs of the same graph. + // + // The master generates a global unique `step_id` to distinguish + // different runs of the graph computation. Subgraphs communicate + // (e.g., send/recv ops) with each other using `step_id` to + // distinguish tensors generated by different runs. + virtual int64_t step_id() const = 0; + + // Options for this step. + virtual const ExecutorOpts& exec_opts() const = 0; + + // Sends the tensors in "send" into the graph before the run. + virtual size_t num_sends() const = 0; + virtual const string& send_key(size_t i) const = 0; + virtual absl::Status SendValue(size_t i, Tensor* out_tensor) const = 0; + + // Fetches the keys into `RunGraphResponse.recv` after the run. + virtual size_t num_recvs() const = 0; + virtual const string& recv_key(size_t i) const = 0; + + // True if the RunGraphRequest is a partial run request. + virtual bool is_partial() const = 0; + + // True if this is the last partial run request in a sequence of requests. + virtual bool is_last_partial_run() const = 0; + + // If true then some errors, e.g., execution errors that have long + // error messages, may return an OK RunStepResponse with the actual + // error saved in the status_code/status_error_message fields of the + // response body. This is a workaround since the RPC subsystem may + // truncate long metadata messages. + virtual bool store_errors_in_response_body() const = 0; + + virtual int64_t request_id() const = 0; + + // Returns the wrapped data as a protocol buffer message. + virtual const RunGraphRequest& ToProto() const = 0; +}; + +// Abstract interface for a mutable RunGraphRequest message. +// +// See `RunGraphRequestWrapper` above for a description of the fields. +class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { + public: + virtual void set_session_handle(const string& handle) = 0; + virtual void set_create_worker_session_called(bool called) = 0; + virtual void set_graph_handle(const string& handle) = 0; + virtual void set_step_id(int64_t step_id) = 0; + virtual ExecutorOpts* mutable_exec_opts() = 0; + + // Stores the i^{th} feed value in `run_step_request` in this + // request with the given `send_key`. + virtual absl::Status AddSendFromRunStepRequest( + const RunStepRequestWrapper& run_step_request, size_t i, + const string& send_key) = 0; + virtual absl::Status AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) = 0; + + virtual void add_recv_key(const string& recv_key) = 0; + virtual void set_is_partial(bool is_partial) = 0; + virtual void set_is_last_partial_run(bool is_last_partial_run) = 0; + virtual void set_store_errors_in_response_body(bool store_errors) = 0; + virtual void set_request_id(int64_t request_id) = 0; +}; + +class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { + public: + // RunGraphRequestWrapper methods. 
+ const string& session_handle() const override; + const string& graph_handle() const override; + bool create_worker_session_called() const override; + int64_t step_id() const override; + const ExecutorOpts& exec_opts() const override; + size_t num_sends() const override; + const string& send_key(size_t i) const override; + absl::Status SendValue(size_t i, Tensor* out_tensor) const override; + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + bool is_partial() const override; + bool is_last_partial_run() const override; + const RunGraphRequest& ToProto() const override; + bool store_errors_in_response_body() const override; + int64_t request_id() const override; + + // MutableRunGraphRequestWrapper methods. + void set_session_handle(const string& handle) override; + void set_create_worker_session_called(bool called) override; + void set_graph_handle(const string& handle) override; + void set_step_id(int64_t step_id) override; + ExecutorOpts* mutable_exec_opts() override; + absl::Status AddSendFromRunStepRequest( + const RunStepRequestWrapper& run_step_request, size_t i, + const string& send_key) override; + absl::Status AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) override; + void add_recv_key(const string& recv_key) override; + void set_is_partial(bool is_partial) override; + void set_is_last_partial_run(bool is_last_partial_run) override; + void set_store_errors_in_response_body(bool store_errors) override; + void set_request_id(int64_t request_id) override; + + private: + string session_handle_; + bool create_worker_session_called_ = false; + string graph_handle_; + int64_t step_id_; + ExecutorOpts exec_opts_; + absl::InlinedVector, 4UL> sends_; + absl::InlinedVector recvs_; + bool is_partial_ = false; + bool is_last_partial_run_ = false; + bool store_errors_in_response_body_ = false; + int64_t request_id_ = 0; + + // Holds a cached and owned representation of the proto + // representation of this request, if needed, so that `ToProto()` + // can return a const RunGraphRequest&. + // NOTE(mrry): Although calls to `ToProto()` on this class are + // expected to be rare, retaining ownership of the returned message + // makes it easier to return a reference from the proto-backed + // representations. + mutable std::unique_ptr proto_version_; +}; + +class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { + public: + // RunGraphRequestWrapper methods. + const string& session_handle() const override; + bool create_worker_session_called() const override; + const string& graph_handle() const override; + int64_t step_id() const override; + const ExecutorOpts& exec_opts() const override; + size_t num_sends() const override; + const string& send_key(size_t i) const override; + absl::Status SendValue(size_t i, Tensor* out_tensor) const override; + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + bool is_partial() const override; + bool is_last_partial_run() const override; + bool store_errors_in_response_body() const override; + int64_t request_id() const override; + const RunGraphRequest& ToProto() const override; + + // MutableRunGraphRequestWrapper methods. 
+ void set_session_handle(const string& handle) override; + void set_create_worker_session_called(bool called) override; + void set_graph_handle(const string& handle) override; + void set_step_id(int64_t step_id) override; + ExecutorOpts* mutable_exec_opts() override; + absl::Status AddSendFromRunStepRequest( + const RunStepRequestWrapper& run_step_request, size_t i, + const string& send_key) override; + absl::Status AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) override; + void add_recv_key(const string& recv_key) override; + void set_is_partial(bool is_partial) override; + void set_is_last_partial_run(bool is_last_partial_run) override; + void set_store_errors_in_response_body(bool store_errors) override; + void set_request_id(int64_t request_id) override; + + private: + RunGraphRequest request_; +}; + +class ProtoRunGraphRequest : public RunGraphRequestWrapper { + public: + ProtoRunGraphRequest(const RunGraphRequest* request); + + // RunGraphRequestWrapper methods. + const string& session_handle() const override; + bool create_worker_session_called() const override; + const string& graph_handle() const override; + int64_t step_id() const override; + const ExecutorOpts& exec_opts() const override; + size_t num_sends() const override; + const string& send_key(size_t i) const override; + absl::Status SendValue(size_t i, Tensor* out_tensor) const override; + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + bool is_partial() const override; + bool is_last_partial_run() const override; + bool store_errors_in_response_body() const override; + int64_t request_id() const override; + const RunGraphRequest& ToProto() const override; + + private: + const RunGraphRequest* const request_; // Not owned. +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Wrapper classes for the `WorkerService.RunGraph` response message. +// +// The `RunGraphResponse` message can contain potentially large tensor +// data as part of its `recv` submessages. Here we provide specialized +// wrappers that avoid copying the tensor data wherever possible. +// +// See `RunGraphResponse` in tensorflow/core/protobuf/worker.proto for the +// protocol buffer definition. +// +//////////////////////////////////////////////////////////////////////////////// + +// Abstract interface for a mutable RunGraphResponse message. +// +// Note that there is no corresponding (immutable) +// RunGraphResponseWrapper class, because the RunGraphResponse object +// is always used as a mutable pointer. +class MutableRunGraphResponseWrapper { + public: + virtual ~MutableRunGraphResponseWrapper() {} + + // A list of tensors corresponding to those requested by + // `RunGraphRequest.recv_key`. + virtual size_t num_recvs() const = 0; + virtual const string& recv_key(size_t i) const = 0; + // NOTE: The following methods may perform a destructive read, for + // efficiency. + virtual absl::Status RecvValue(size_t i, TensorProto* out_tensor) = 0; + virtual absl::Status RecvValue(size_t i, Tensor* out_tensor) = 0; + virtual void AddRecv(const string& key, const Tensor& value) = 0; + + // Submessages that store performance statistics about the subgraph + // execution, if necessary. 
+ virtual StepStats* mutable_step_stats() = 0; + virtual CostGraphDef* mutable_cost_graph() = 0; + virtual size_t num_partition_graphs() const = 0; + virtual GraphDef* mutable_partition_graph(size_t i) = 0; + virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; + + // Returned status if requested. + virtual absl::Status status() const = 0; + virtual absl::StatusCode status_code() const = 0; + virtual void set_status(const absl::Status& status) = 0; + + protected: + // Returns a mutable protobuf message that represents the contents of + // this wrapper, for passing to an RPC subsystem that will populate + // the message. + // + // NOTE: Only `WorkerInterface` subclasses may call this method. The + // `InMemoryRunGraphResponse` subclass does not implement this + // method, and attempts to call it will fail with a fatal + // error. However, as long as callers always call + // `WorkerInterface::RunGraphAsync()` with a wrapper object returned + // from `WorkerInterface::CreateRunGraphResponse()` called on the + // *same* WorkerInterface object, this error will never trigger. + virtual RunGraphResponse* get_proto() = 0; + friend class WorkerInterface; +}; + +class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { + public: + // MutableRunGraphResponseWrapper methods. + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + absl::Status RecvValue(size_t i, TensorProto* out_tensor) override; + absl::Status RecvValue(size_t i, Tensor* out_tensor) override; + void AddRecv(const string& key, const Tensor& value) override; + StepStats* mutable_step_stats() override; + CostGraphDef* mutable_cost_graph() override; + size_t num_partition_graphs() const override; + GraphDef* mutable_partition_graph(size_t i) override; + void AddPartitionGraph(const GraphDef& partition_graph) override; + absl::Status status() const override; + absl::StatusCode status_code() const override; + void set_status(const absl::Status& status) override; + + protected: + // NOTE: This method is not implemented. See + // MutableRunGraphResponseWrapper for an explanation. + RunGraphResponse* get_proto() override; + + private: + absl::InlinedVector, 4UL> recvs_; + StepStats step_stats_; + CostGraphDef cost_graph_; + std::vector partition_graphs_; + // Store the code and message separately so that they can be updated + // independently by setters. + absl::Status status_; +}; + +// Proto-based message wrapper for use on the client side of the RunGraph RPC. +class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { + public: + // MutableRunGraphResponseWrapper methods. + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + absl::Status RecvValue(size_t i, TensorProto* out_tensor) override; + absl::Status RecvValue(size_t i, Tensor* out_tensor) override; + void AddRecv(const string& key, const Tensor& value) override; + StepStats* mutable_step_stats() override; + CostGraphDef* mutable_cost_graph() override; + size_t num_partition_graphs() const override; + GraphDef* mutable_partition_graph(size_t i) override; + void AddPartitionGraph(const GraphDef& partition_graph) override; + absl::Status status() const override; + absl::StatusCode status_code() const override; + void set_status(const absl::Status& status) override; + + protected: + RunGraphResponse* get_proto() override; + + private: + RunGraphResponse response_; +}; + +// Proto-based message wrapper for use on the server side of the RunGraph RPC. 
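// Illustrative sketch (not part of this header): because `RecvValue()` may
// perform a destructive read, a caller can drain received tensors from any
// response wrapper without forcing a copy of the underlying buffers; the
// wrapper decides whether that is a move (in-memory case) or a proto decode
// (proto-backed cases). The helper name and the std::pair output format are
// assumptions for illustration only, and <utility>/<vector> are assumed to be
// available.
inline absl::Status DrainReceivedTensors(
    MutableRunGraphResponseWrapper* response,
    std::vector<std::pair<string, Tensor>>* out) {
  out->reserve(out->size() + response->num_recvs());
  for (size_t i = 0; i < response->num_recvs(); ++i) {
    Tensor value;
    // May consume the wrapper's stored tensor; do not call RecvValue twice
    // for the same index.
    absl::Status s = response->RecvValue(i, &value);
    if (!s.ok()) return s;
    out->emplace_back(response->recv_key(i), std::move(value));
  }
  return absl::OkStatus();
}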
+class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { + public: + NonOwnedProtoRunGraphResponse(RunGraphResponse* response); + + // MutableRunGraphResponseWrapper methods. + size_t num_recvs() const override; + const string& recv_key(size_t i) const override; + absl::Status RecvValue(size_t i, TensorProto* out_tensor) override; + absl::Status RecvValue(size_t i, Tensor* out_tensor) override; + void AddRecv(const string& key, const Tensor& value) override; + StepStats* mutable_step_stats() override; + CostGraphDef* mutable_cost_graph() override; + size_t num_partition_graphs() const override; + GraphDef* mutable_partition_graph(size_t i) override; + void AddPartitionGraph(const GraphDef& partition_graph) override; + absl::Status status() const override; + absl::StatusCode status_code() const override; + void set_status(const absl::Status& status) override; + + protected: + RunGraphResponse* get_proto() override; + + private: + RunGraphResponse* const response_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Wrapper classes for the `MasterService.RunStep` response message. +// +// The `RunStepResponse` message can contain potentially large tensor +// data as part of its `tensor` submessages. Here we provide specialized +// wrappers that avoid copying the tensor data wherever possible. +// +// See `RunStepResponse` in tensorflow/core/protobuf/master.proto for the +// protocol buffer definition. +// +//////////////////////////////////////////////////////////////////////////////// + +// Abstract interface for a mutable RunStepResponse message. +// +// Note that there is no corresponding (immutable) +// RunStepResponseWrapper class, because the RunStepResponse object is +// always used as a mutable pointer. +class MutableRunStepResponseWrapper { + public: + virtual ~MutableRunStepResponseWrapper(); + + // The values of the tensors whose fetching was requested in the + // RunStep call. + // + // NOTE: The order of the returned tensors may or may not match + // the fetch order specified in RunStepRequest. + virtual size_t num_tensors() const = 0; + virtual const string& tensor_name(size_t i) const = 0; + virtual absl::Status TensorValue(size_t i, Tensor* out_tensor) const = 0; + + // Stores the i^{th} recv value in `run_graph_response` in this + // response with the given `name`. + virtual absl::Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) = 0; + + // Returned metadata if requested in the options. + virtual const RunMetadata& metadata() const = 0; + virtual RunMetadata* mutable_metadata() = 0; + + // Returned status if requested. + virtual absl::Status status() const = 0; + virtual absl::StatusCode status_code() const = 0; + virtual void set_status(const absl::Status& status) = 0; + + protected: + // Returns a mutable protobuf message that represents the contents of + // this wrapper, for passing to an RPC subsystem that will populate + // the message. + // + // NOTE: Only `MasterInterface` subclasses may call this method. The + // `InMemoryRunStepResponse` subclass does not implement this + // method, and attempts to call it will fail with a fatal + // error. However, as long as callers always call + // `MasterInterface::RunStep()` with a wrapper object returned + // from `MasterInterface::CreateRunStepResponse()` called on the + // *same* MasterInterface object, this error will never trigger. 
+ virtual RunStepResponse* get_proto() = 0; + friend class MasterInterface; +}; + +class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { + public: + // MutableRunStepResponseWrapper methods. + size_t num_tensors() const override; + const string& tensor_name(size_t i) const override; + absl::Status TensorValue(size_t i, Tensor* out_tensor) const override; + absl::Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) override; + const RunMetadata& metadata() const override; + RunMetadata* mutable_metadata() override; + absl::Status status() const override; + absl::StatusCode status_code() const override; + void set_status(const absl::Status& status) override; + + protected: + // NOTE: This method is not implemented. See + // MutableRunGraphResponseWrapper for an explanation. + RunStepResponse* get_proto() override; + + private: + absl::InlinedVector, 4UL> tensors_; + RunMetadata metadata_; + // Store the code and message separately so that they can be updated + // independently by setters. + absl::Status status_; +}; + +// Proto-based message wrapper for use on the client side of the RunStep RPC. +class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { + public: + // MutableRunStepResponseWrapper methods. + size_t num_tensors() const override; + const string& tensor_name(size_t i) const override; + absl::Status TensorValue(size_t i, Tensor* out_tensor) const override; + absl::Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) override; + const RunMetadata& metadata() const override; + RunMetadata* mutable_metadata() override; + absl::Status status() const override; + absl::StatusCode status_code() const override; + void set_status(const absl::Status& status) override; + + protected: + RunStepResponse* get_proto() override; + + private: + RunStepResponse response_; +}; + +// Proto-based message wrapper for use on the server side of the RunStep RPC. +class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { + public: + NonOwnedProtoRunStepResponse(RunStepResponse* response); + + // MutableRunStepResponseWrapper methods. + size_t num_tensors() const override; + const string& tensor_name(size_t i) const override; + absl::Status TensorValue(size_t i, Tensor* out_tensor) const override; + absl::Status AddTensorFromRunGraphResponse( + const string& name, MutableRunGraphResponseWrapper* run_graph_response, + size_t i) override; + const RunMetadata& metadata() const override; + RunMetadata* mutable_metadata() override; + absl::Status status() const override; + absl::StatusCode status_code() const override; + void set_status(const absl::Status& status) override; + + protected: + RunStepResponse* get_proto() override; + + private: + RunStepResponse* response_; // Not owned. +}; + +bool ParseTensorProtoToTensor(const TensorProto& tensor_proto, + Tensor* out_tensor); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/partial_run_mgr.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/partial_run_mgr.h new file mode 100644 index 00000000..bf2b2b1a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/partial_run_mgr.h @@ -0,0 +1,88 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ + +#include + +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// PartialRunMgr keeps track of pending partial run requests, and ensures that +// the partial run is only marked complete when the corresponding executor is +// run to completion. +// +// In tensorflow workers, the executor runs operations asynchronously until +// specified fetches (operations that return tensors) or targets (operations +// that don't return tensors) are reached. A PartialRun has two components: a +// setup which specifies all desired fetches and targets, and run calls that +// specify fetch values (from the setup calls) to retrieve. +// On the last partial run call, it is possible to satisfy the +// required fetches before the executor has completed running the graph to all +// the desired targets. +// PartialRunMgr is used to ensure that we don't complete and return the final +// partial run call to the user until both the partial run and executor have +// completed. +// +// PartialRunMgr is thread-safe. +class PartialRunMgr { + public: + // Find or create the CancellationManager associated with step_id. + // The PartialRunMgr owns the cancellation_manager. + // Returns true if a new CancellationManager was created + // (i.e this is a new partial run). + bool FindOrCreate(int step_id, CancellationManager** cancellation_manager); + + // Calls the final callback if the PartialRunRequest has already completed. + // Otherwise stores the executor_status to be propagated when the + // PartialRunRequest completes (PartialRunDone has been called). + void ExecutorDone(int step_id, const absl::Status& executor_status); + + // Calls done if the executor has already completed (ExecutorDone has been + // called). Otherwise, stores the status and done callback, calling them when + // ExecutorDone is called. The callback will either be called by the calling + // thread of either PartialRunDone or ExecutorDone. + // If executor_status in ExecutorDone is not OK, it takes precedence over + // status and is passed to the done callback. + void PartialRunDone(int step_id, StatusCallback done, + const absl::Status& status); + + private: + // PartialRunState stores state associated with a pending partial run request. + // This is protected by the mutex in PartialRunMgr. 
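  // Illustrative sketch (not part of this class): the expected interplay of the
  // three public methods above, as seen from a worker handling partial runs.
  // `partial_run_mgr`, `executor_args` and `done` are hypothetical locals, and
  // the executor callback shape is an assumption for illustration.
  //
  //   CancellationManager* cm = nullptr;
  //   if (partial_run_mgr->FindOrCreate(step_id, &cm)) {
  //     // First RunGraph call for this step: start the executor exactly once.
  //     RunExecutorAsync(executor_args, cm,
  //                      [partial_run_mgr, step_id](const absl::Status& s) {
  //                        partial_run_mgr->ExecutorDone(step_id, s);
  //                      });
  //   }
  //   // ... on the request marked is_last_partial_run():
  //   // Completion is deferred until ExecutorDone has also fired; a non-OK
  //   // executor status takes precedence over `status`.
  //   partial_run_mgr->PartialRunDone(step_id, std::move(done), status);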
+ struct PartialRunState { + std::unique_ptr cancellation_manager; + + bool executor_done = false; + StatusCallback final_callback = nullptr; + absl::Status final_status; + }; + + mutex mu_; + + std::unordered_map> + step_id_to_partial_run_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PARTIAL_RUN_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.h new file mode 100644 index 00000000..cbe03db6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/preemption/preemption_sync_manager.h @@ -0,0 +1,27 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_ + +#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::CreatePreemptionSyncManager; +using tsl::PreemptionSyncManager; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PREEMPTION_PREEMPTION_SYNC_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/recent_request_ids.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/recent_request_ids.h new file mode 100644 index 00000000..2eb35ac7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/recent_request_ids.h @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
+
+#include
+#include
+#include
+
+#include "absl/container/flat_hash_set.h"
+#include "tensorflow/core/distributed_runtime/message_wrappers.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+// RecentRequestIds tracks recent 64-bit request_ids. When maximum capacity is
+// reached, the oldest request_id is evicted. Thread safe.
+//
+// Some RPCs like RecvTensor are unsafe to retry. For example, RecvTensor pairs
+// one sender and one receiver, and the receiver waits for the sender's tensor.
+// Retried RecvTensor requests are problematic, because the original RecvTensor
+// request may have consumed the sender's tensor, so a retried request might
+// block forever. RecentRequestIds identifies retried requests, so we can fail
+// them instead of blocking forever.
+//
+// Internally, recent request_ids are stored in two data structures: a set and a
+// circular buffer. The set is used for efficient lookups, and the circular
+// buffer tracks the oldest request_id. When the buffer is full, the new
+// request_id replaces the oldest request_id in the circular buffer, and the
+// oldest request_id is removed from the set.
+class RecentRequestIds {
+ public:
+  // num_tracked_request_ids should be much larger than the number of RPCs that
+  // can be received in a small time window. For example, we observed a peak RPC
+  // rate of ~700 RecvTensor RPC/s when training inception v3 on TPUs, so we
+  // currently set num_tracked_request_ids to 100,000 for RecvTensor.
+  // Using a large `num_shards` helps avoid lock contention in this class.
+  explicit RecentRequestIds(int num_tracked_request_ids, int num_shards = 1);
+
+  // Returns OK iff request_id has not been seen in the last
+  // num_tracked_request_ids insertions. For backwards compatibility, this
+  // always returns OK for request_id 0. The method_name and the request's
+  // ShortDebugString are added to returned errors.
+  absl::Status TrackUnique(int64_t request_id, const string& method_name,
+                           const protobuf::Message& request);
+  // Overloaded version of the above function for wrapped protos.
+  template <typename RequestWrapper>
+  absl::Status TrackUnique(int64_t request_id, const string& method_name,
+                           const RequestWrapper* wrapper);
+
+ private:
+  bool Insert(int64_t request_id);
+
+  struct IndexBucket {
+    mutex mu;
+    // next_index indexes into circular_buffer_, and points to the next storage
+    // space to use. When the buffer is full, next_index_ points at the oldest
+    // request_id.
+    int next_index TF_GUARDED_BY(mu) = 0;
+    std::vector<int64_t> circular_buffer TF_GUARDED_BY(mu);
+    absl::flat_hash_set<int64_t> set TF_GUARDED_BY(mu);
+  };
+
+  // This vector is immutable so we don't need to use a mutex to protect it.
+  std::vector<IndexBucket> index_buckets_;
+};
+
+// Implementation details
+
+template <typename RequestWrapper>
+absl::Status RecentRequestIds::TrackUnique(int64_t request_id,
+                                           const string& method_name,
+                                           const RequestWrapper* wrapper) {
+  if (Insert(request_id)) {
+    return absl::OkStatus();
+  } else {
+    return errors::Aborted("The same ", method_name,
+                           " request was received twice. 
", + wrapper->ToProto().ShortDebugString()); + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/remote_device.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/remote_device.h new file mode 100644 index 00000000..591531f9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/remote_device.h @@ -0,0 +1,72 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tsl { +class Env; +} // namespace tsl +namespace tensorflow { +using Env = tsl::Env; +class DeviceAttributes; +class Device; +class WorkerCacheInterface; + +// This callback should have the same definition as DeviceMgr::LookupDevice +// It assigns *device with pointer to Device of the given 'name', where 'name' +// is either a full device name, or just the replica-local suffix. +typedef std::function + LookupLocalDevice; + +// Creates Remote Devices for the provided device attributes. Helpful when the +// list of attributes is known, and doesn't need to be discovered via RPC. +void AsRemoteDevices( + Env* env, + const protobuf::RepeatedPtrField& device_attributes, + LookupLocalDevice lookup_local_device, + std::vector>* remote_devices); + +// NewRemoteDevices discovers available devices on the +// 'worker_name'. The implementation uses 'channel_cache' to +// discover how to communicate with the 'worker_name' (via gRPC, for +// example). +// +// NewRemoteDevices does not block. +// +// On success, the 'done' callback is given the OK status and a vector +// of Device*. The caller should take ownership of these devices. +// +// Otherwise, the 'done' callback is given an error status and the +// vector is empty. +typedef std::function*)> + NewRemoteDevicesDone; +void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, + const string& worker_name, NewRemoteDevicesDone done); + +// Create Remote Device based on the given attributes. +std::unique_ptr NewRemoteDevice(Env* env, + DeviceAttributes device_attribute); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h new file mode 100644 index 00000000..6ec759d4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -0,0 +1,110 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
+
+#include
+
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class WorkerSession;
+
+// RemoteRendezvous follows a two-part initialization. First the objects are
+// constructed. Eventually, they will be initialized. Clients of the
+// RendezvousMgrInterface must guarantee to call Initialize on the returned
+// RemoteRendezvous eventually.
+//
+// A partially initialized RemoteRendezvous must respect the Rendezvous
+// interface (i.e. Send() must never block); however, implementations are not
+// expected to actually perform the underlying operations until after the
+// RemoteRendezvous has been Initialize'd.
+class RemoteRendezvous : public Rendezvous {
+ public:
+  // Fully construct the RemoteRendezvous.
+  virtual absl::Status Initialize(WorkerSession* session) = 0;
+
+  // In remote eager, set the current instance as the context-default
+  // rendezvous, which will be used for eager op-by-op execution.
+  virtual void SetRemoteEagerContextDefault() = 0;
+  // In remote eager, check whether the current instance is the context-default
+  // rendezvous.
+  virtual bool IsRemoteEagerContextDefault() = 0;
+
+ protected:
+  bool is_cross_process() override { return true; }
+};
+
+// RendezvousMgr keeps track of a set of local rendezvous instances.
+// All tensors sent by this worker are buffered in a RendezvousMgr
+// until the tensor is received. Each globally unique "step_id"
+// corresponds to one local rendezvous instance managed by a
+// RendezvousMgr.
+//
+// E.g.,
+//   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
+//   fork execution of a graph executor using "rendez" on thread 1;
+//   fork execution of another graph executor using "rendez" on thread 2;
+//   ...
+//   join threads 1 and 2;
+//
+// In the example above, execution in threads 1 and 2 communicates
+// via send/recv operations through "rendez".
+//
+// Tensors sent and received through a rendezvous managed by this
+// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
+class RendezvousMgrInterface {
+ public:
+  RendezvousMgrInterface() = default;
+  virtual ~RendezvousMgrInterface() {}
+
+  // Returns a Rendezvous supporting send and recv among workers in the
+  // "step_id". The caller takes ownership of one reference on the
+  // returned Rendezvous instance.
+  //
+  // Note: the caller must guarantee to eventually call Initialize on the
+  // returned RemoteRendezvous.
+  virtual tsl::core::RefCountPtr<RemoteRendezvous> Find(int64_t step_id) = 0;
+
+  // Finds the local rendezvous instance for the "step_id". Runs
+  // "done" when the tensor for "key" is produced or an error occurs.
+ // + // This method is used by the rpc handler of RecvTensor. + virtual void RecvLocalAsync(int64_t step_id, + const Rendezvous::ParsedKey& parsed, + Rendezvous::DoneCallback done) = 0; + + // Synchronous wrapper for RecvLocalAsync. + virtual absl::Status RecvLocal(int64_t step_id, + const Rendezvous::ParsedKey& parsed, + Tensor* val, bool* is_dead) = 0; + + // Removes rendezvous for "step_id". + // + // TODO(zhifengc): Have a background thread in worker that + // periodically calls CleanupAll(). + virtual void Cleanup(int64_t step_id) = 0; + + // Remove all rendezvous instances owned by the rendezvous_mgr. + virtual void CleanupAll() = 0; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/request_id.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/request_id.h new file mode 100644 index 00000000..2f7b3b46 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/request_id.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_ + +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/random.h" + +namespace tensorflow { + +// Returns a request_id for use with RecentRequestIds. This number will not be +// zero, and must be unique over RecentRequestIds' window of +// num_tracked_request_ids. See recent_request_ids.h for more details. +int64_t GetUniqueRequestId(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h new file mode 100644 index 00000000..b692ce70 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COORDINATION_GRPC_COORDINATION_CLIENT_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COORDINATION_GRPC_COORDINATION_CLIENT_H_ + +#include + +#include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::NewGrpcCoordinationClient; +using tsl::NewGrpcCoordinationClientCache; +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COORDINATION_GRPC_COORDINATION_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h new file mode 100644 index 00000000..9e0a218a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h @@ -0,0 +1,27 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COORDINATION_GRPC_COORDINATION_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COORDINATION_GRPC_COORDINATION_SERVICE_IMPL_H_ + +#include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::GrpcCoordinationServiceImpl; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COORDINATION_GRPC_COORDINATION_SERVICE_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h new file mode 100644 index 00000000..2eb41b8a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_CLIENT_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_CLIENT_H_ + +#include + +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" + +namespace tensorflow { +namespace eager { +// The GrpcChannelCache is not owned. +EagerClientCache* NewGrpcEagerClientCache( + std::shared_ptr channel); +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h new file mode 100644 index 00000000..24cd17a4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h @@ -0,0 +1,21 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_ + +#include "tensorflow/core/protobuf/eager_service.grpc.pb.h" + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h new file mode 100644 index 00000000..7acc2955 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -0,0 +1,175 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ + +#include + +#include "grpcpp/alarm.h" +#include "grpcpp/completion_queue.h" +#include "grpcpp/server_builder.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" +#include "xla/tsl/distributed_runtime/rpc/grpc_call.h" +#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/protobuf/eager_service.pb.h" + +namespace tensorflow { +namespace eager { + +// This class is a wrapper that handles communication for gRPC. +class GrpcEagerServiceImpl : public tsl::AsyncServiceInterface { + public: + template + using EagerCall = + tsl::Call; + template + using StreamingCall = + tsl::ServerBidirectionalStreamingCall; + + GrpcEagerServiceImpl(WorkerEnv* env, ::grpc::ServerBuilder* server_builder); + virtual ~GrpcEagerServiceImpl() {} + + // Create a master context in eager service. + absl::Status CreateMasterContext(tensorflow::uint64 context_id, + EagerContext* context); + + void HandleRPCsLoop() override; + void Shutdown() override; + + private: +#define HANDLER(method) \ + void method##Handler(EagerCall* call) { \ + env_->compute_pool->Schedule([this, call]() { \ + call->SendResponse( \ + ToGrpcStatus(local_impl_.method(&call->request, &call->response))); \ + }); \ + tsl::Call:: \ + EnqueueRequest(&service_, cq_.get(), \ + &grpc::EagerService::AsyncService::Request##method, \ + &GrpcEagerServiceImpl::method##Handler, false); \ + } + HANDLER(CreateContext); + HANDLER(UpdateContext); + HANDLER(WaitQueueDone); + HANDLER(KeepAlive); + HANDLER(CloseContext); +#undef HANDLER + + void EnqueueHandler(EagerCall* call) { + env_->compute_pool->Schedule([this, call]() { + auto call_opts = std::make_shared(); + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + call->SendResponse(ToGrpcStatus(local_impl_.Enqueue( + call_opts.get(), &call->request, &call->response))); + }); + tsl::Call:: + EnqueueRequest(&service_, cq_.get(), + &grpc::EagerService::AsyncService::RequestEnqueue, + &GrpcEagerServiceImpl::EnqueueHandler, + /*supports_cancel=*/true); + } + + void RunComponentFunctionHandler( + EagerCall* + call) { + env_->compute_pool->Schedule([this, call]() { + auto call_opts = std::make_shared(); + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + local_impl_.RunComponentFunction( + call_opts.get(), &call->request, &call->response, + [call, call_opts](const absl::Status& s) { + call->ClearCancelCallback(); + call->SendResponse(ToGrpcStatus(s)); + }); + }); + tsl::Call:: + EnqueueRequest( + &service_, cq_.get(), + &grpc::EagerService::AsyncService::RequestRunComponentFunction, + &GrpcEagerServiceImpl::RunComponentFunctionHandler, + /*supports_cancel=*/true); + } + + // Called when a new request has been received as part of a StreamingEnqueue + // call. + // StreamingEnqueueHandler gets the request from the `call` and fills the + // response (also found in `call`) by invoking the local EagerServiceImpl. + // The local EagerServiceImpl is invoked in a single-threaded thread pool. We + // do this to preserve request order. 
The local service can parallelize based + // on context_id in request if necessary. Remote contexts are created in async + // mode by default, so the local service impl just puts the request on eager + // executor queue. + void StreamingEnqueueHandler( + StreamingCall* call) { + call->Ref(); + enqueue_streaming_thread_.Schedule([this, call]() { + if (call->RefCountIsOne()) { + // This StreamingCall has already been shutdown. Don't need to anything. + call->Unref(); + return; + } + // NOTE(fishx): Use the address of StreamingCall as the stream_id since we + // reuse the same StreamingCall for multiple requests in the same + // streaming connection. + absl::Status status = local_impl_.Enqueue( + /*call_opts=*/nullptr, &call->request(), call->mutable_response(), + reinterpret_cast(static_cast(call))); + + if (status.ok()) { + VLOG(1) << "local_impl_.Enqueue completed successfully"; + call->SendResponse(); + } else { + VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString() + << " on request " << call->request().DebugString(); + call->Finish(ToGrpcStatus(status)); + } + call->Unref(); + + // We do not tell gRPC to accept a new StreamingEnqueue request because + // this method can be called multiple times for a given streaming call. + // The StreamingCall does this per call instead, after a call has been + // opened. + }); + } + + WorkerEnv* const env_; // Not owned. + EagerServiceImpl local_impl_; + + // A single-threaded thread pool to handle streaming enqueue rpc request. + thread::ThreadPool enqueue_streaming_thread_; + std::unique_ptr<::grpc::Alarm> shutdown_alarm_; + + std::unique_ptr<::grpc::ServerCompletionQueue> cq_; + grpc::EagerService::AsyncService service_; + + GrpcEagerServiceImpl(const GrpcEagerServiceImpl&) = delete; + void operator=(const GrpcEagerServiceImpl&) = delete; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_channel.h new file mode 100644 index 00000000..b9bc118e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_channel.h @@ -0,0 +1,33 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ + +#include "xla/tsl/distributed_runtime/rpc/grpc_channel.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::ChannelCreationFunction; +using tsl::ConvertToChannelCreationFunction; +using tsl::GetChannelArguments; +using tsl::GrpcChannelCache; +using tsl::GrpcChannelSpec; +using tsl::NewGrpcChannelCache; +using tsl::NewHostPortGrpcChannel; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h new file mode 100644 index 00000000..30822036 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h @@ -0,0 +1,27 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ + +#include "xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::GrpcClientCQTag; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h new file mode 100644 index 00000000..bd203163 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h @@ -0,0 +1,36 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ + +#include +#include "grpcpp/server_builder.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/master.pb.h" + +namespace tsl { +class AsyncServiceInterface; +} +namespace tensorflow { +class Master; + +tsl::AsyncServiceInterface* NewGrpcMasterService( + Master* master, const ConfigProto& default_session_config, + ::grpc::ServerBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h new file mode 100644 index 00000000..bdf683fd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -0,0 +1,218 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ + +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/client_context.h" +#include "grpcpp/impl/codegen/completion_queue.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/server_context.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" +#include "tensorflow/core/protobuf/master.pb.h" + +namespace tensorflow { + +namespace grpc { + +// Implementation of `tensorflow.MasterService`, based on the +// definition in "//tensorflow/core/protobuf/master_service.proto", +// and the gRPC generated stub and service classes. +// See that file for the definition of methods and messages. 
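// Illustrative sketch (not part of this header): the hand-written AsyncService
// below follows the standard gRPC completion-queue protocol. The server
// registers interest in a call via Request*() with a tag, and the tag comes
// back from the completion queue when a client invocation arrives. The local
// variable names and the single-tag handling are assumptions for illustration.
//
//   MasterService::AsyncService service;
//   ::grpc::ServerBuilder builder;
//   builder.RegisterService(&service);
//   std::unique_ptr<::grpc::ServerCompletionQueue> cq =
//       builder.AddCompletionQueue();
//   std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
//
//   ::grpc::ServerContext ctx;
//   CreateSessionRequest request;
//   ::grpc::ServerAsyncResponseWriter<CreateSessionResponse> responder(&ctx);
//   void* tag = reinterpret_cast<void*>(1);
//   service.RequestCreateSession(&ctx, &request, &responder, cq.get(),
//                                cq.get(), tag);
//
//   void* got_tag = nullptr;
//   bool ok = false;
//   while (cq->Next(&got_tag, &ok)) {
//     if (ok && got_tag == tag) {
//       // Handle the CreateSession call, then:
//       // responder.Finish(response, ::grpc::Status::OK, finish_tag);
//       break;
//     }
//   }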
+class MasterService final { + public: + class StubInterface { + public: + virtual ~StubInterface() {} + virtual ::grpc::Status CreateSession(::grpc::ClientContext* context, + const CreateSessionRequest& request, + CreateSessionResponse* response) = 0; + virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context, + const ExtendSessionRequest& request, + ExtendSessionResponse* response) = 0; + virtual ::grpc::Status PartialRunSetup( + ::grpc::ClientContext* context, const PartialRunSetupRequest& request, + PartialRunSetupResponse* response) = 0; + virtual ::grpc::Status RunStep(::grpc::ClientContext* context, + const RunStepRequest& request, + RunStepResponse* response) = 0; + virtual ::grpc::Status CloseSession(::grpc::ClientContext* context, + const CloseSessionRequest& request, + CloseSessionResponse* response) = 0; + virtual ::grpc::Status ListDevices(::grpc::ClientContext* context, + const ListDevicesRequest& request, + ListDevicesResponse* response) = 0; + virtual ::grpc::Status Reset(::grpc::ClientContext* context, + const ResetRequest& request, + ResetResponse* response) = 0; + virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context, + const MakeCallableRequest& request, + MakeCallableResponse* response) = 0; + virtual ::grpc::Status RunCallable(::grpc::ClientContext* context, + const RunCallableRequest& request, + RunCallableResponse* response) = 0; + virtual ::grpc::Status ReleaseCallable( + ::grpc::ClientContext* context, const ReleaseCallableRequest& request, + ReleaseCallableResponse* response) = 0; + }; + class Stub final : public StubInterface { + public: + Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); + ::grpc::Status CreateSession(::grpc::ClientContext* context, + const CreateSessionRequest& request, + CreateSessionResponse* response) override; + ::grpc::Status ExtendSession(::grpc::ClientContext* context, + const ExtendSessionRequest& request, + ExtendSessionResponse* response) override; + ::grpc::Status PartialRunSetup(::grpc::ClientContext* context, + const PartialRunSetupRequest& request, + PartialRunSetupResponse* response) override; + ::grpc::Status RunStep(::grpc::ClientContext* context, + const RunStepRequest& request, + RunStepResponse* response) override; + ::grpc::Status CloseSession(::grpc::ClientContext* context, + const CloseSessionRequest& request, + CloseSessionResponse* response) override; + ::grpc::Status ListDevices(::grpc::ClientContext* context, + const ListDevicesRequest& request, + ListDevicesResponse* response) override; + ::grpc::Status Reset(::grpc::ClientContext* context, + const ResetRequest& request, + ResetResponse* response) override; + ::grpc::Status MakeCallable(::grpc::ClientContext* context, + const MakeCallableRequest& request, + MakeCallableResponse* response) override; + ::grpc::Status RunCallable(::grpc::ClientContext* context, + const RunCallableRequest& request, + RunCallableResponse* response) override; + ::grpc::Status ReleaseCallable(::grpc::ClientContext* context, + const ReleaseCallableRequest& request, + ReleaseCallableResponse* response) override; + + private: + std::shared_ptr< ::grpc::ChannelInterface> channel_; + const ::grpc::internal::RpcMethod rpcmethod_CreateSession_; + const ::grpc::internal::RpcMethod rpcmethod_ExtendSession_; + const ::grpc::internal::RpcMethod rpcmethod_PartialRunSetup_; + const ::grpc::internal::RpcMethod rpcmethod_RunStep_; + const ::grpc::internal::RpcMethod rpcmethod_CloseSession_; + const ::grpc::internal::RpcMethod rpcmethod_ListDevices_; + const 
::grpc::internal::RpcMethod rpcmethod_Reset_; + const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_; + const ::grpc::internal::RpcMethod rpcmethod_RunCallable_; + const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_; + }; + static std::unique_ptr NewStub( + const std::shared_ptr< ::grpc::ChannelInterface>& channel, + const ::grpc::StubOptions& options = ::grpc::StubOptions()); + + class AsyncService : public ::grpc::Service { + public: + AsyncService(); + virtual ~AsyncService(); + void RequestCreateSession( + ::grpc::ServerContext* context, CreateSessionRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestExtendSession( + ::grpc::ServerContext* context, ExtendSessionRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(1, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestPartialRunSetup( + ::grpc::ServerContext* context, PartialRunSetupRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(2, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestRunStep( + ::grpc::ServerContext* context, RunStepRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(3, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestCloseSession( + ::grpc::ServerContext* context, CloseSessionRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(4, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestListDevices( + ::grpc::ServerContext* context, ListDevicesRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(5, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestReset( + ::grpc::ServerContext* context, ResetRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(6, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestMakeCallable( + ::grpc::ServerContext* context, MakeCallableRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(7, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestRunCallable( + ::grpc::ServerContext* context, RunCallableRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + 
::grpc::Service::RequestAsyncUnary(8, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestReleaseCallable( + ::grpc::ServerContext* context, ReleaseCallableRequest* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(9, context, request, response, + new_call_cq, notification_cq, tag); + } + }; +}; + +} // namespace grpc + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h new file mode 100644 index 00000000..c80668e8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h @@ -0,0 +1,27 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ + +#include "tensorflow/core/distributed_runtime/master_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" + +namespace tensorflow { +// Returns a MasterInterface wrapped around the gRPC channel `channel`. +MasterInterface* NewGrpcMaster(const SharedGrpcChannelPtr& channel); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h new file mode 100644 index 00000000..97e590e0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ + +#include + +#include "grpcpp/completion_queue.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +class WorkerCacheLogger; +class WorkerInterface; + +WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel, + ::grpc::CompletionQueue* completion_queue, + thread::ThreadPool* callback_threadpool, + WorkerCacheLogger* logger, + const string& target); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h new file mode 100644 index 00000000..ca162c19 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -0,0 +1,242 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ + +// GrpcServer manages the lifecycle of an Eager, Worker and Master service. + +#include +#include +#include + +#include "grpcpp/grpcpp.h" +#include "grpcpp/security/credentials.h" +#include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/stats_publisher_interface.h" +#include "tensorflow/core/distributed_runtime/master_env.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/platform/env.h" +#include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" + +namespace tensorflow { + +class GrpcWorker; +class Master; + +// function that creates a RendezvousMgr. +typedef std::function + RendezvousMgrCreationFunction; + +// function that creates a CollectiveExecutorMgr. +typedef std::function + CollectiveMgrCreationFunction; + +// function that registers a service to the server. The service needs to +// be registered before builder.BuildAndStart(). +typedef std::function + ServiceInitFunction; + +// function that creates a grpc based worker implementation. 
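The std::function typedefs above (their template parameters appear to have been stripped in transit) are the hooks an embedder uses to swap in custom components through GrpcServerOptions, defined just below. A minimal sketch, assuming the upstream signature in which the rendezvous factory takes a const WorkerEnv* and returns a newly allocated manager; RpcRendezvousMgr is the type declared in rpc_rendezvous_mgr.h later in this patch.

#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"

// Builds server options that plug in the RPC rendezvous manager; ownership of
// the returned pointer follows the factory convention described above.
tensorflow::GrpcServerOptions MakeServerOptions() {
  tensorflow::GrpcServerOptions opts;
  opts.rendezvous_mgr_func = [](const tensorflow::WorkerEnv* env) {
    return new tensorflow::RpcRendezvousMgr(env);
  };
  return opts;
}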
+typedef std::function(WorkerEnv*, + const ConfigProto& config)> + WorkerCreationFunction; + +struct GrpcServerOptions { + ServiceInitFunction service_func = nullptr; + RendezvousMgrCreationFunction rendezvous_mgr_func = nullptr; + CollectiveMgrCreationFunction collective_mgr_func = nullptr; + WorkerCreationFunction worker_func = nullptr; + StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; + GrpcWorkerServiceOptions worker_service_options; + DeviceMgr* local_device_mgr = nullptr; +}; + +class GrpcServer : public ServerInterface { + protected: + GrpcServer(const ServerDef& server_def, Env* env); + GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr, + Env* env); + // Allow children classes to override this and provide custom args to the + // server before it is constructed. Default behavior is to do nothing. + // requested_port provides the port requested by caller as bound_port() is + // not available till BuildAndStart has been called. + virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder, + int requested_port) {} + + public: + static absl::Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server); + static absl::Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server); + // Reuse the local_device_mgr. + static absl::Status Create(const ServerDef& server_def, Env* env, + DeviceMgr* local_device_mgr, + std::unique_ptr* out_server); + + // Destruction is only supported in the factory method. Clean + // shutdown is not currently implemented for this server type. + virtual ~GrpcServer(); + + // Implementations of ServerInterface methods. + absl::Status Start() override; + absl::Status Stop() override; + absl::Status Join() override; + const string target() const override; + + WorkerEnv* worker_env() override { return &worker_env_; } + MasterEnv* master_env() override { return &master_env_; } + + // Add master eager context to local eager service in order to handle enqueue + // requests from remote workers. + absl::Status AddMasterEagerContextToEagerService( + const tensorflow::uint64 context_id, + tensorflow::EagerContext* context) override; + // Update the set of workers that can be reached by the GRPC server + absl::Status UpdateServerDef(const ServerDef& server_def) override; + // Pass coordination service agent instance to server's RPC handler + absl::Status SetCoordinationServiceAgentInstance( + tsl::CoordinationServiceAgent* agent) override; + // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is + // supported. + absl::Status StopCoordinationService() override; + + protected: + virtual absl::Status GetHostAndPort(const ServerDef& server_def, + string* host_name, int* port) const; + absl::Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); + + // A subclass can override this method to support secure credentials. + virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( + const ServerDef& server_def) const; + + virtual ChannelCreationFunction GetChannelCreationFunction() const; + + virtual std::unique_ptr CreateMaster(MasterEnv* master_env); + + // Creates a WorkerCacheInterface for a session. + virtual absl::Status WorkerCacheFactory( + const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache); + + // Override to return extra services to be brought up and managed along with + // the standard {master, worker, eager} services. 
The map key is an aribtrary + // string and the value is a pointer to the service to be brought up. + // Ownership of the pointer is transferred to GrpcServer after this call + // returns, and the service will be destroyed during the destruction of + // GrpcServer. Each service will have its HandleRPCsLoop called in a separate + // thread. An example usage would be to add a RDMA based partial worker + // service to offload tensor and data buffer transfers. + virtual std::map ExtraServices( + ::grpc::ServerBuilder*) { + return {}; + } + + virtual std::map + GetExtraServices() { + return extra_services_; + } + + // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. + absl::Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, + GrpcChannelSpec* channel_spec); + + // Returns the port to which this server is bound. + // This method may only be called after `this->Init()` returns successfully. + int bound_port() const { return bound_port_; } + + // Returns hostname. + const string& host_name() const { return host_name_; } + + const ServerDef& server_def() const { return server_def_; } + GrpcWorker* worker_impl() const { return worker_impl_.get(); } + GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } + + absl::Status SetCoordinationServiceInstance( + tsl::CoordinationServiceInterface* service); + + private: + Env* env_; + + // The port to which this server is bound. + int bound_port_ = 0; + + // The host name of this server + string host_name_; + + // Guards server configuration, server, and state. + mutex mu_; + + // Represents the current state of the server, which changes as follows: + // + // Join() Join() + // ___ ___ + // Start() \ / Stop() \ / + // NEW ---------> STARTED --------> STOPPED + // \ / + // \________________________/ + // Stop(), Join() + enum State { NEW, STARTED, STOPPED }; + State state_ TF_GUARDED_BY(mu_); + + // Implementation of a TensorFlow master, and RPC polling thread. + MasterEnv master_env_; + std::unique_ptr master_impl_; + tsl::AsyncServiceInterface* master_service_ = nullptr; + std::unique_ptr master_thread_ TF_GUARDED_BY(mu_); + + std::map extra_services_; + std::vector> extra_service_threads_ + TF_GUARDED_BY(mu_); + + // Implementation of a TensorFlow worker, and RPC polling thread. + WorkerEnv worker_env_; + std::unique_ptr owned_device_manager_; + std::unique_ptr worker_impl_; + tsl::AsyncServiceInterface* worker_service_ = nullptr; + std::unique_ptr worker_thread_ TF_GUARDED_BY(mu_); + std::unique_ptr grpc_worker_env_; + + // TensorFlow Eager implementation, and RPC polling thread. + tsl::AsyncServiceInterface* eager_service_ = nullptr; + std::unique_ptr eager_thread_ TF_GUARDED_BY(mu_); + std::shared_ptr worker_session_; + + // Experimental coordination service implementation, and RPC polling thread. + tsl::AsyncServiceInterface* coordination_service_ = nullptr; + std::unique_ptr coordination_thread_ TF_GUARDED_BY(mu_); + + // TensorFlow profiler service implementation. + std::unique_ptr profiler_service_ = nullptr; + + // The overall server configuration. 
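A minimal sketch of the lifecycle implied by the state diagram above (NEW -> STARTED -> STOPPED), assuming the Create() overload that fills a std::unique_ptr of the server interface; the template arguments in the declarations above appear to have been lost in transit.

#include <memory>

// Creates, starts and joins a gRPC server; Join() blocks until the server
// transitions to STOPPED. TF_RETURN_IF_ERROR is TensorFlow's usual
// early-return macro for absl::Status.
absl::Status RunGrpcServer(const tensorflow::ServerDef& server_def) {
  std::unique_ptr<tensorflow::ServerInterface> server;
  TF_RETURN_IF_ERROR(tensorflow::GrpcServer::Create(
      server_def, tensorflow::Env::Default(), &server));
  TF_RETURN_IF_ERROR(server->Start());  // NEW -> STARTED
  return server->Join();                // returns once the server is STOPPED
}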
+ ServerDef server_def_ TF_GUARDED_BY(mu_); + + std::unique_ptr<::grpc::Server> server_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_session.h new file mode 100644 index 00000000..fe92f7c0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -0,0 +1,156 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ + +#include +#include +#include + +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/master.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class MasterInterface; + +// A Session instance lets the caller drive a TensorFlow graph +// computation on potentially remote sets of devices. This is a thin +// wrapper around tensorflow::grpc::MasterService. +// +// Multiple threads must synchronize their accesses to a single +// session. +class GrpcSession : public Session { + protected: + explicit GrpcSession(const SessionOptions& options); + + public: + static absl::Status Create(const SessionOptions& options, + std::unique_ptr* out_session); + // Resets the resource containers. + static absl::Status Reset(const SessionOptions& options, + const std::vector& containers); + + ~GrpcSession() override; + + // Creates a session with the "target". The session carries out + // the graph computation defined by "graph", and will have version + // number "initial_version". + absl::Status Create(const GraphDef& graph) override; + absl::Status Create(const RunOptions& run_options, + const GraphDef& graph) override; + absl::Status Create(GraphDef&& graph) override; + absl::Status Create(const RunOptions& run_options, GraphDef&& graph) override; + + // Runs with and without RunOptions. 
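A hedged usage sketch for the session API declared here, assuming Create(options, &session) fills a std::unique_ptr with the new session (the template arguments in the declarations above appear stripped). The target address and fetch name are placeholders.

#include <memory>
#include <vector>

// Connects to a remote master, installs a graph, runs one fetch and closes.
absl::Status RunGraphOnce(const tensorflow::GraphDef& graph_def,
                          std::vector<tensorflow::Tensor>* outputs) {
  tensorflow::SessionOptions options;
  options.target = "grpc://localhost:2222";  // placeholder master address
  std::unique_ptr<tensorflow::GrpcSession> session;
  TF_RETURN_IF_ERROR(tensorflow::GrpcSession::Create(options, &session));
  TF_RETURN_IF_ERROR(session->Create(graph_def));
  TF_RETURN_IF_ERROR(session->Run(/*inputs=*/{},
                                  /*output_tensor_names=*/{"y:0"},
                                  /*target_node_names=*/{}, outputs));
  return session->Close();
}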
+ absl::Status Run(const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) override; + absl::Status Run(const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + RunMetadata* run_metadata) override; + + absl::Status Extend(const GraphDef& graph) override; + absl::Status Extend(const RunOptions& run_options, + const GraphDef& graph) override; + absl::Status Extend(GraphDef&& graph) override; + absl::Status Extend(const RunOptions& run_options, GraphDef&& graph) override; + + absl::Status Close() override; + + // NOTE: This API is still experimental and may change. + absl::Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) override; + + // NOTE: This API is still experimental and may change. + absl::Status PRun(const string& handle, + const std::vector >& inputs, + const std::vector& output_names, + std::vector* outputs) override; + + absl::Status ListDevices(std::vector* response) override; + + absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override; + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) override; + absl::Status ReleaseCallable(CallableHandle handle) override; + + protected: + // Takes ownership of `*master`. + void SetRemoteMaster(std::unique_ptr master); + // Allows subclasses to customize Session creation. + void SetHandleAndGraphVersion(string handle, int64_t graph_version) + TF_LOCKS_EXCLUDED(mu_); + + private: + const SessionOptions options_; + std::unique_ptr master_; + mutex mu_; + + // handle_ returned by the master to identify this session. + string handle_ TF_GUARDED_BY(mu_); + + // The current version of the graph. + int64_t current_graph_version_ TF_GUARDED_BY(mu_); + + bool is_local_ = false; + + absl::Status Handle(string* out_handle) TF_LOCKS_EXCLUDED(mu_); + + absl::Status RunHelper(const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + RunMetadata* run_metadata, const string& prun_handle); + + absl::Status RunProto(CallOptions* call_options, + MutableRunStepRequestWrapper* req, + MutableRunStepResponseWrapper* resp); + + // Implementations for all the public interfaces. + absl::Status CreateImpl(CallOptions* call_options, GraphDef graph); + absl::Status ExtendImpl(CallOptions* call_options, GraphDef graph); + + GrpcSession(const GrpcSession&) = delete; + void operator=(const GrpcSession&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_state.h new file mode 100644 index 00000000..4c5f560e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -0,0 +1,541 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ + +#include +#include + +#include "grpcpp/generic/generic_stub.h" +#include "grpcpp/grpcpp.h" +#include "xla/tsl/distributed_runtime/rpc/grpc_state.h" +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/distributed_runtime/tensor_coding.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::RPCState; +// NOLINTEND(misc-unused-using-decls) + +// Represents state associated with one streaming RPC call. +// Similarly to above, we extract the methods of StreamingRPCState that don't +// need to be templated into this abstract class. +// Currently, *StreamingRPCState does not support client closing the call as +// there is no use case for it - current clients keep the streaming call open +// as long as possible. If/when the need arises, support can be added +// by calling GenericClientAsyncReaderWriter::WritesDone with a new tag +// TagType::kClientFinished and handling the completion in a new callback. +class UntypedStreamingRPCState : public core::RefCounted { + public: + virtual void CallStarted(bool ok) = 0; + virtual void RequestWriteCompleted(bool ok) = 0; + virtual void ResponseReadCompleted(bool ok) = 0; + virtual void CallFinished(bool ok) = 0; + + virtual string DebugString() const = 0; + + class Tag : public GrpcClientCQTag { + public: + // One enum value per supported callback. + enum class TagType { + kCallStarted, + kRequestWriteCompleted, + kResponseReadCompleted, + kCallFinished, + }; + + Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type); + + // Calls the callback associated with this tag and Unrefs + // `this->streaming_state_`. + void OnCompleted(bool ok) override; + + private: + // OnCompleted() consumes on reference each time it is called. + UntypedStreamingRPCState* const streaming_state_; + const Tag::TagType type_; + }; +}; + +const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type); + +// Represents a single request/response exchange between client and the server. +// A single streaming call contains a sequence of exchanges. Besides the +// messages, exchange contains: +// - the user callback to invoke when exchange completes (response is received +// or an error occurs). +// - The current state of the exchange. 
+class Exchange { + public: + enum class State { + kExchangeCreated, + kRequestWriteIssued, + kRequestWriteCompleted, + kResponseReadIssued, + }; + + Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response, + StatusCallback cb, string debug_string) + : state_(State::kExchangeCreated), + request_buf_(request_buf), + response_(response), + cb_(std::move(cb)), + debug_string_(std::move(debug_string)) {} + + const ::grpc::ByteBuffer& request_buf() { return request_buf_; } + ::grpc::ByteBuffer* response_buf() { return &response_buf_; } + + void MarkRequestWriteIssued() { + DCHECK(state_ == State::kExchangeCreated); + state_ = State::kRequestWriteIssued; + } + void MarkRequestWriteCompleted() { + DCHECK(state_ == State::kRequestWriteIssued); + state_ = State::kRequestWriteCompleted; + } + void MarkResponseReadIssued() { + DCHECK(state_ == State::kRequestWriteCompleted); + state_ = State::kResponseReadIssued; + } + + // If `status` is success, completes this exchange by parsing the + // response_buf_ and invoking cb_ with OkStatus(). Else, invokes the + // callback with `status`. + void Complete(absl::Status status); + + const State& state() const { return state_; } + + string DebugString() const; + + private: + State state_; + ::grpc::ByteBuffer request_buf_; + ::grpc::ByteBuffer response_buf_; + protobuf::Message* response_; + StatusCallback cb_; + string debug_string_; +}; + +const char* ToString(Exchange::State s); + +std::ostream& operator<<(std::ostream& os, const Exchange::State& state); + +// Represents a queue of exchanges. +// When a client sends a new request a new exchange is created and added to the +// end of the queue. Completed exchanges are popped from the front of the queue. +// An explicit exchange queue is needed to brdige the client, which can send new +// requests at any time, with gRPC infrastructure, which can handle a single +// read and a single write request at a time. +// +// As the exchange progresses (request sending initiated, request sending +// completed, response reading initiated) the queue helps to make sure that the +// right operation is issued on the right exchange at the right time. +// +// To satisfy gRPC constraints, the states of exchanges must be as follows +// starting from the front of the queue: +// - 0 or 1 exchange in kResponseReadIssued state +// - 0 or more exchanges in kRequestWriteCompleted state +// - 0 or 1 exchange in kRequestWriteIssued state +// - 0 or more exchanges in kExchangeCreated state +// +// Thread-compatible. +class ExchangeQueue { + public: + // Creates a new exchange and adds it to the end of the queue. + void Emplace(const ::grpc::ByteBuffer& request_buf, + protobuf::Message* response, StatusCallback cb, + std::string debug_string); + + // Returns an exchange for which we can initiate request writing, if any. + // Returns nullptr if there is no such exchange. + Exchange* GetReadyForRequestWriting(); + + // Returns an exchange for which we can initiate response reading, if any. + // Returns nullptr if there is no such exchange. + Exchange* GetReadyForResponseReading(); + + // Changes the state of the exchange that is current in kRequestWriteIssued + // state to kRequestWriteCompleted state. + // REQUIRES: There is an exchange in kRequestWriteIssued state. + void MarkRequestWriteCompleted(); + + // Returns the exchange at the front of the queue. + // REQUIRES: ExchangeQueue is not empty. + Exchange& GetFront(); + + // Removes the exchange at the front of the queue. + // REQUIRES: ExchangeQueue is not empty. 
+ void PopFront(); + + // Returns a string containing addresses and states of all exchanges in this + // queue. + string DebugString() const; + + // Swaps the contents of this and `other`. + void Swap(ExchangeQueue* other); + + // Completes all exchanges in this with `status`. + void CompleteAll(absl::Status status); + + void CallStarted() { call_started_ = true; } + + private: + // Does nothing by default. Turn on VLOG(5) to enable. + // Checks that this ExchangeQueue is in a valid state. + // Kills the process if not. + void CheckInvariants(); + + // We can't process any exchanges until the call has started. + bool call_started_ = false; + + // std::queue is based on std::deque by default. std::deque provides + // fairly strong iterator stability. + std::deque exchanges_; +}; // namespace tensorflow + +// Represents state associated with one streaming RPC call. +// Thread-safe +template +class StreamingRPCState : public UntypedStreamingRPCState { + public: + // Default behavior is to set fail_fast = False and handle timeouts + // manually. + StreamingRPCState( + std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call, + const std::shared_ptr<::grpc::ClientContext>& context) + : context_(context), call_(std::move(call)), call_state_(State::kActive) { + Ref(); + VLOG(3) << "Created new StreamingRPCState " << this; + VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall"; + call_->StartCall(&call_started_tag_); + } + + ~StreamingRPCState() override { + VLOG(3) << "Destructing StreamingRPCState " << this; + } + + // Attempts to send the next request. `done` is invoked when + // `response` has been filled with the data from the server, or if there + // is an error. `done` can be invoked before SendNextRequest returns. + // Return `true` if the call is alive and the `done` callback has or + // will be invoked. If the call is dead, returns `false`. `done` callback + // will not be invoked in this case. + // REQUIRES: The call has been started, i.e. WaitForCallStarted() has + // returned. + bool SendNextRequest(const protobuf::Message& request, Response* response, + const StatusCallback& done) { + ::grpc::ByteBuffer request_buf; + ::grpc::Status s = tsl::GrpcMaybeUnparseProto(request, &request_buf); + if (!s.ok()) { + absl::Status status = FromGrpcStatus(s); + LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " + << status.ToString(); + done(status); + return true; + } + + mutex_lock l(mu_); + if (call_state_ != State::kActive) { + // `done` is not invoked intentionally. + return false; + } + if (VLOG_IS_ON(3)) { + // If vlog 3 is enabled, include first 100 chars of request as debug + // string. + exchanges_.Emplace(request_buf, response, done, + request.ShortDebugString().substr(0, 100)); + } else { + exchanges_.Emplace(request_buf, response, done, ""); + } + MaybeIssueRequestWriteLocked(); + return true; + } + + void CallStarted(bool ok) override { + VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok + << ")"; + mutex_lock l(mu_); + if (!ok) { + call_state_ = State::kDone; + return; + } + exchanges_.CallStarted(); + // Now that the call has started, we can write our first request, if any. 
+ MaybeIssueRequestWriteLocked(); + } + + void RequestWriteCompleted(bool ok) override { + VLOG(3) << "StreamingRPCState(" << this + << ")::RequestWriteCompleted(ok=" << ok << ")"; + mu_.lock(); + if (call_state_ != State::kActive) { + mu_.unlock(); + return; + } + exchanges_.MarkRequestWriteCompleted(); + // Issue ResponseRead regardless of OK status on completing RequestWrite. + // If the underlying completion queue is in Not-OK status due to previous + // request failuress (i.e., `ok` from `Next` call on completion queue is + // False), delay the error in ResponseRead so we can get the remote error + // message from response buffer. + MaybeIssueResponseReadLocked(); + + if (ok) { + MaybeIssueRequestWriteLocked(); + } + mu_.unlock(); + } + + void ResponseReadCompleted(bool ok) override { + VLOG(3) << "StreamingRPCState(" << this + << ")::ResponseReadCompleted(ok=" << ok << ")"; + mu_.lock(); + if (call_state_ != State::kActive) { + mu_.unlock(); + return; + } + if (!ok) { + IssueCallFinishLocked(); + mu_.unlock(); + return; + } + + // Complete the exchange without holding the lock because user's + // callback can call back into this RPC code resulting in a deadlock. + // No other thread can pop this exchange while we release the lock because + // this is the only method that pops exchanges and it is called from a + // single thread that waits on completion queue events. + Exchange* e; + e = &exchanges_.GetFront(); + mu_.unlock(); + + e->Complete(absl::OkStatus()); + + { + mutex_lock l(mu_); + exchanges_.PopFront(); + MaybeIssueResponseReadLocked(); + } + } + + void CallFinished(bool ok) override { + VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok + << ")"; + mu_.lock(); + DCHECK(call_state_ != State::kActive); + if (call_state_ != State::kFinishing) { + mu_.unlock(); + return; + } + + absl::Status s = FromGrpcStatus(call_status_); + if (s.ok() && !ok) { + s.Update( + errors::Internal("GRPC status is okay but CompletionQueueStatus is " + "not. This should never happen.", + context_->debug_error_string())); + } + // unlocks mu_ + MarkDoneAndCompleteExchanges(s); + } + + string DebugString() const override { + mutex_lock l(mu_); + return exchanges_.DebugString(); + } + + private: + enum class State { + kActive, + kFinishing, + kDone, + }; + + void MarkDoneAndCompleteExchanges(absl::Status status) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) { + call_state_ = State::kDone; + VLOG(2) << "Ending gRPC streaming call on the client side due to " + << status.ToString(); + // Swap the exchanges_ into a temporary ExchangeQueue so that we can + // complete all exchanges without holding mu_ in case user callback + // reach back into this. This should be impossible now, but safer for + // the future. + ExchangeQueue queue; + exchanges_.Swap(&queue); + mu_.unlock(); + queue.CompleteAll(status); + } + + void MaybeIssueRequestWriteLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Exchange* exchange = exchanges_.GetReadyForRequestWriting(); + if (exchange == nullptr) { + // There are no queued exchanges, there is already an outstanding write, + // or there are no just created exchanges. 
+ return; + } + exchange->MarkRequestWriteIssued(); + Ref(); + VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write"; + call_->Write(exchange->request_buf(), &request_write_completed_tag_); + } + + void MaybeIssueResponseReadLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Exchange* exchange = exchanges_.GetReadyForResponseReading(); + if (exchange == nullptr) { + return; + } + exchange->MarkResponseReadIssued(); + Ref(); + VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read"; + call_->Read(exchange->response_buf(), &response_read_completed_tag_); + } + + void IssueCallFinishLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + call_state_ = State::kFinishing; + Ref(); + VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish"; + // We call finish in response to completed (with error) response reading tag + // on some exchange. We let this exchange hang in ResponseReadIssued state. + // ExchangeQueue makes sure that there is at most one exchange in this + // state. So, no new reads will be issued. + call_->Finish(&call_status_, &finished_tag_); + } + + // Holds state for a single request/response exchange between the client + // and the server. + typedef typename UntypedStreamingRPCState::Tag Tag; + + // Order of context_ and call_ is important because context_ must outlive + // call_. + const std::shared_ptr context_; + std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call_; + + mutable mutex mu_; + ExchangeQueue exchanges_ TF_GUARDED_BY(mu_); + State call_state_ TF_GUARDED_BY(mu_); + ::grpc::Status call_status_ TF_GUARDED_BY(mu_); + + // We can get away with having single instances of these tags per + // StreamingRPCState because we make sure (as gRPC requires) that + // there is at most one outstanding Read and at most one outstanding Write + // in the completion queue. + // Tags are immutable. No need to guard them. + Tag call_started_tag_{this, Tag::TagType::kCallStarted}; + Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted}; + Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted}; + Tag finished_tag_{this, Tag::TagType::kCallFinished}; +}; + +// Creates streaming calls and dispatches requests to them. +// In the common case, the client would create a StreamingRPCDispatcher for +// each bidirectional streaming RPC it might want to make. The first time, it +// calls SendNextRequest, a streaming call is initiated and the request is +// sent within this call. Initiation of the call blocks the client. If there are +// no errors, subsequent calls to SendNextRequest would use the already active +// call. If there was an error, the call object will be destroyed after all +// the callbacks for outstanding requests have been invoked. The next call to +// SendNextRequest will initiate a new call. +// +// Callbacks that are part of the same call, are invoked in the order they were +// provided, but callbacks across calls (a failed and a new one) can be invoked +// in any order. +// +// Thread-safe. +template +class StreamingRPCDispatcher { + public: + StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, + const ::grpc::string& method) + : stub_(stub), cq_(cq), method_(method) {} + + // Attempts to send the next request. If there is no active streaming call, + // starts one and sends the request on top of it. `done` is invoked when + // `response` has been filled with the data from the server, or if there + // is an error. `done` can be invoked before SendNextRequest returns. 
+ void SendNextRequest(const protobuf::Message& request, Response* response, + StatusCallback done) { + mutex_lock l(mu_); + if (state_ == nullptr) { + CreateStreamingState(); + } + + bool is_call_alive = state_->SendNextRequest(request, response, done); + if (is_call_alive) { + return; + } + + // The attempt to send failed because the call was dead, create a new + // call and try again. When the call is dead SendNextRequest does not call + // `done`. + CreateStreamingState(); + + is_call_alive = state_->SendNextRequest(request, response, done); + if (!is_call_alive) { + // Consider retrying to create and start a call few more times. + done(errors::Unknown("gRPC call failed right after it was created")); + } + } + + // Request to cancel the current streaming call. Non-blocking. + void CancelCall() { + mutex_lock l(mu_); + if (state_ == nullptr) { + return; + } + context_->TryCancel(); + state_ = nullptr; + } + + private: + void CreateStreamingState() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // ClientContext cannot be reused across calls. + context_ = std::make_shared<::grpc::ClientContext>(); + // Don't immediately fail StartCall if the channel is not ready. Wait for + // the channel to become ready. + context_->set_wait_for_ready(true); + + std::unique_ptr<::grpc::GenericClientAsyncReaderWriter> call = + stub_->PrepareCall(context_.get(), method_, cq_); + + state_.reset(new StreamingRPCState(std::move(call), context_)); + } + + mutable mutex mu_; + + // Both are thread-safe + ::grpc::GenericStub* const stub_; + ::grpc::CompletionQueue* const cq_; + + // Does not need synchronization since it is constant. + const ::grpc::string method_; + + std::shared_ptr<::grpc::ClientContext> context_ TF_GUARDED_BY(mu_); + core::RefCountPtr> state_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h new file mode 100644 index 00000000..393ef2a7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TENSOR_CODING_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TENSOR_CODING_H_ + +#include "grpcpp/impl/codegen/byte_buffer.h" +#include "absl/status/status.h" + +namespace tensorflow { +class Tensor; +class RecvTensorResponse; + +// TODO(jeff,sanjay): this should not be grpc specific. Instead of +// grpc::ByteBuffer*, it should accept an object of an interface type +// to which owned byte-arrays can be added. 
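A hedged sketch of driving the StreamingRPCDispatcher declared above; its template parameter, the response proto type, appears to have been dropped in transit, and the pairing with RecvBuf messages here is purely illustrative.

// Sends one request over the (possibly re-created) streaming call; the
// callback fires once `response` has been filled or the call has failed.
void SendOneStreamingRequest(
    tensorflow::StreamingRPCDispatcher<tensorflow::RecvBufResponse>* dispatcher,
    const tensorflow::RecvBufRequest& request,
    tensorflow::RecvBufResponse* response) {
  dispatcher->SendNextRequest(request, response, [](const absl::Status& s) {
    if (!s.ok()) {
      LOG(WARNING) << "streaming request failed: " << s.ToString();
    }
  });
}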
+namespace grpc { + +// Encode a RecvTensorResponse protocol buffer into a byte buffer in a +// format that is parseable as a RecvTensorResponse protocol buffer +// holding "proto". +// +// Discards original contents of *result. +void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto, + ::grpc::ByteBuffer* result); + +// Encode a Tensor into a byte buffer in a format that is parseable +// as a RecvTensorResponse protocol buffer holding "val". +// +// "is_dead" is the value to encode for "RecvTensorResponse::is_dead" +// (tensor is the output of a dead node and content is invalid because +// control flow operations elsewhere caused the path on which this +// Tensor exists to not be taken). +// +// "val" holds the tensor value to be encoded. +// +// Discards original contents of *result. +absl::Status EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, + bool require_ack, + ::grpc::ByteBuffer* result); + +} // namespace grpc +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TENSOR_CODING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h new file mode 100644 index 00000000..9101ca92 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h @@ -0,0 +1,100 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class Device; + +namespace test { + +struct TestJob { + std::string name; + int num_tasks; + int num_replicas = 1; +}; + +struct TestClusterConfig { + std::string binary_path; + SessionOptions options; + std::vector jobs; + + TestClusterConfig& Options(const SessionOptions& options) { + this->options = options; + return *this; + } + TestClusterConfig& Jobs(const std::vector& jobs) { + this->jobs = jobs; + return *this; + } +}; + +// Provides a handle to a set of TensorFlow servers (masters and +// workers) for testing purposes. +// +// This class currently runs the servers in separate processes; the +// lifetime of this object is coterminous with the lifetimes of those +// processes. +class TestCluster { + public: + // Creates a new test cluster based on the given `options` (which + // configure the number of devices of each type) and a count of + // processes `n`. 
On success, the test cluster is stored in + // *out_cluster, and this function returns OK. Otherwise an error is + // returned. + static absl::Status MakeTestCluster( + const TestClusterConfig& config, + std::unique_ptr* out_cluster); + ~TestCluster(); + + // Returns a vector of string ":" pairs that may be + // used as targets to construct a GrpcSession. + const std::vector& targets(std::string job_name = "localhost") { + return targets_.at(job_name); + } + + // Returns a vector of devices available in this test cluster. + const std::vector& devices() const { return devices_; } + + private: + TestCluster() = default; + + std::vector> subprocesses_; + absl::flat_hash_map> targets_; + std::vector devices_; + + TestCluster(const TestCluster&) = delete; + void operator=(const TestCluster&) = delete; +}; + +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_util.h new file mode 100644 index 00000000..0db18382 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_util.h @@ -0,0 +1,72 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ + +#include +#include + +#include "grpcpp/grpcpp.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/support/byte_buffer.h" +#include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/distributed_runtime/tensor_coding.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::FromGrpcStatus; +using tsl::SharedGrpcChannelPtr; +using tsl::ToGrpcStatus; +// NOLINTEND(misc-unused-using-decls) + +// Thin wrapper around ::grpc::ProtoBufferReader to give TensorResponse +// an efficient byte reader from which to decode a RecvTensorResponse. +class GrpcByteSource : public TensorResponse::Source { + public: + explicit GrpcByteSource(::grpc::ByteBuffer* buffer) : buffer_(buffer) {} + ~GrpcByteSource() override { DeleteStream(); } + + typedef ::grpc::ProtoBufferReader Reader; + + protobuf::io::ZeroCopyInputStream* contents() override { + DeleteStream(); + stream_ = new (&space_) Reader(buffer_); + return stream_; + } + + private: + void DeleteStream() { + if (stream_) { + stream_->~Reader(); + } + } + + ::grpc::ByteBuffer* buffer_; // Not owned + Reader* stream_ = nullptr; // Points into space_ if non-nullptr + char space_[sizeof(Reader)]; +}; + +inline string GrpcIdKey() { return "tf-rpc"; } + +// Decode a TensorResponse without extra copying. 
This function is an optimized +// variant of tsl::GrpcMaybeParseProto. +bool GrpcMaybeParseTensorResponse(::grpc::ByteBuffer* src, TensorResponse* dst); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h new file mode 100644 index 00000000..2dfbc79a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h @@ -0,0 +1,76 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/threadpool.h" + +namespace tensorflow { + +class GrpcWorkerEnv { + public: + GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads); + + ~GrpcWorkerEnv(); + + thread::ThreadPool* GetThreadPool() const { return threadpool_.get(); } + + size_t CompletionQueueSize() const { return threads_.size(); } + + ::grpc::CompletionQueue* GetCompletionQueue(size_t index) const { + return threads_.at(index).completion_queue(); + } + + private: + // Thread wrapping class that drives work over a single gRPC + // CompletionQueue. + class GrpcWorkerCacheThread { + public: + GrpcWorkerCacheThread(); + + ~GrpcWorkerCacheThread(); + + ::grpc::CompletionQueue* completion_queue() const { + return &completion_queue_; + } + + private: + mutable ::grpc::CompletionQueue completion_queue_; + std::unique_ptr thread_; + }; + + std::unique_ptr threadpool_; + std::vector threads_; +}; + +// Create a GrpcWorkerEnv instance that can be used as argument to create +// gRPC worker cache. Caller should take the ownership of the returned instance. +GrpcWorkerEnv* CreateGrpcWorkerEnv(); + +// The returned WorkerCacheInterface object takes the ownership of "cc". 
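A hedged sketch tying together CreateGrpcWorkerEnv() above and the cache factory declared just below, assuming the stripped std::shared_ptr element type is the channel cache from grpc_channel.h (included at the top of this header).

#include <memory>
#include <utility>

// Builds a worker cache; the returned cache takes ownership of the channel
// cache per the comment above, while the GrpcWorkerEnv must outlive it.
std::unique_ptr<tensorflow::WorkerCacheInterface> MakeWorkerCache(
    std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache,
    tensorflow::GrpcWorkerEnv* worker_env) {
  return std::unique_ptr<tensorflow::WorkerCacheInterface>(
      tensorflow::NewGrpcWorkerCache(std::move(channel_cache), worker_env));
}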
+WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr cc, + GrpcWorkerEnv* worker_env); + +WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker( + std::shared_ptr cc, GrpcWorkerEnv* worker_env, + WorkerInterface* local_worker, const string& local_target); + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h new file mode 100644 index 00000000..ebb1ac91 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h @@ -0,0 +1,92 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ + +#include +#include + +#include "grpcpp/server_builder.h" +#include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h" +#include "tensorflow/core/distributed_runtime/worker.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace grpc { +class ByteBuffer; +} // namespace grpc + +namespace tsl { +class AsyncServiceInterface; +} + +namespace tensorflow { + +class ConfigProto; +struct WorkerEnv; +class WorkerSession; +class RpcResponseCache; + +class GrpcWorker : public Worker { + public: + GrpcWorker(WorkerEnv* env, const ConfigProto& config); + + // Specialized version of RecvTensor for gRPC, which avoids a copy. + virtual void GrpcRecvTensorAsync(CallOptions* opts, + const RecvTensorRequest* request, + ::grpc::ByteBuffer* response, + StatusCallback done); + + void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, + StatusCallback done) override; + + void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) override; + + void CleanupGraphAsync(const CleanupGraphRequest* request, + CleanupGraphResponse* response, + StatusCallback done) override; + + WorkerEnv* env(); + + void EnableResponseCache(); + + void RemoveCacheEntryForId(int64_t request_id); + + private: + std::unique_ptr response_cache_; + const int32 recv_buf_max_chunk_; +}; + +std::unique_ptr NewGrpcWorker(WorkerEnv* worker_env, + const ConfigProto& config); + +struct GrpcWorkerServiceOptions { + // Map from GrpcWorkerMethod id to queue depth. If set this overrides the + // default queue depth for a method. + std::unordered_map queue_depth; + int num_serving_threads = 8; +}; + +// Returns an implementation of WorkerService rpc service. 
+std::unique_ptr NewGrpcWorkerService( + GrpcWorker* worker, ::grpc::ServerBuilder* builder, + GrpcWorkerServiceOptions options = GrpcWorkerServiceOptions()); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h new file mode 100644 index 00000000..25f5ec97 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h @@ -0,0 +1,118 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ + +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" +#include "grpcpp/support/byte_buffer.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/distributed_runtime/tensor_coding.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace grpc { + +// Support parsing/unparsing of tensorflow::TensorResponse. +// Wire-format is identical to RecvTensorResponse. +// This is specializing an existing template, so it's okay to do this in a +// namespace that we don't own. +template <> +class SerializationTraits { + public: + static Status Serialize(const tensorflow::TensorResponse& msg, ByteBuffer* bp, + bool* own_buffer) { + LOG(FATAL) << "TODO(sanjay,jeff): Implement"; + return Status(); + } + static Status Deserialize(ByteBuffer* buffer, + tensorflow::TensorResponse* msg) { + if (buffer == nullptr) { + return Status(StatusCode::INTERNAL, "No payload"); + } + Status result = Status::OK; + if (result.ok()) { + ::tensorflow::GrpcByteSource source(buffer); + auto s = msg->ParseFrom(&source); + if (!s.ok()) { + result = Status(StatusCode::INTERNAL, + ::tensorflow::strings::StrCat( + "TensorResponse parse error", s.message())); + } + } + buffer->Clear(); + return result; + } +}; + +} // namespace grpc + +namespace tensorflow { + +// Names of worker methods. 
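A hedged sketch of configuring the worker service declared above. The queue_depth map's key and value types were stripped above; the sketch assumes depths are keyed by the integer value of the GrpcWorkerMethod ids listed just below, and that NewGrpcWorkerService returns a std::unique_ptr of the tsl::AsyncServiceInterface forward-declared earlier in this header. The numeric values are illustrative only.

#include <memory>

// Builds the async worker service with a deeper RecvTensor queue and more
// serving threads than the defaults.
std::unique_ptr<tsl::AsyncServiceInterface> BuildWorkerService(
    tensorflow::GrpcWorker* worker, ::grpc::ServerBuilder* builder) {
  tensorflow::GrpcWorkerServiceOptions options;
  options.num_serving_threads = 16;
  options.queue_depth[static_cast<int>(
      tensorflow::GrpcWorkerMethod::kRecvTensor)] = 512;
  return tensorflow::NewGrpcWorkerService(worker, builder, options);
}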
+enum class GrpcWorkerMethod { + kGetStatus, + kCreateWorkerSession, + kDeleteWorkerSession, + kRegisterGraph, + kDeregisterGraph, + kRunGraph, + kCleanupGraph, + kCleanupAll, + kRecvTensor, + kRecvBuf, + kLogging, + kTracing, + kCompleteGroup, + kCompleteInstance, + kGetStepSequence, + kMarkRecvFinished, +}; + +static const int kGrpcNumWorkerMethods = + static_cast(GrpcWorkerMethod::kMarkRecvFinished) + 1; + +const char* GrpcWorkerMethodName(GrpcWorkerMethod id); + +namespace grpc { + +// Implementation of `tensorflow.WorkerService`, based on the +// definition in "//tensorflow/core/protobuf/worker_service.proto", +// and the gRPC generated stub and service classes. +// See the proto file for the definition of methods and messages. +class WorkerService final { + public: + class AsyncService : public ::grpc::Service { + public: + AsyncService(); + virtual ~AsyncService(); + + // Make RequestAsyncUnary public for grpc_call.h + using ::grpc::Service::RequestAsyncUnary; + }; +}; + +} // namespace grpc + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h new file mode 100644 index 00000000..42eda4ea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h @@ -0,0 +1,60 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ + +#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class DeviceMgr; + +// RendezvousMgr keeps track of a set of local rendezvous instances. +// All tensors sent by this worker are buffered in a RendezvousMgr +// until the tensor is received. Each global unique "step_id" +// corresponds to one local rendezvous instance managed by a +// RendezvousMgr. +// +// E.g., +// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935); +// fork execution of an graph executor using "rendez" on thread 1; +// fork execution of another graph executor using "rendez" on thread 2; +// ... +// join threads 1 and 2; +// +// In the example above, execution in thread 1 and 2 communicates with +// each other by send/recv operations through the "rend". +// +// Tensors sent and recved through rendezvous managed by this +// RendezvousMgr must have keys generated by Rendezvous::CreateKey. 
+class RpcRendezvousMgr : public BaseRendezvousMgr { + public: + explicit RpcRendezvousMgr(const WorkerEnv* env); + + protected: + tsl::core::RefCountPtr Create( + int64_t step_id, const WorkerEnv* worker_env) override; + + private: + RpcRendezvousMgr(const RpcRendezvousMgr&) = delete; + void operator=(const RpcRendezvousMgr&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h new file mode 100644 index 00000000..0f31ddaf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RESPONSE_CACHE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RESPONSE_CACHE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/mutex.h" + +// gRPC response caching. Most WorkerService methods cannot be retried directly +// as they will fail or deadlock. To enable retrying, we can instead cache +// responses and reply to duplicate requests from the cache. The cache will be +// cleaned when the MarkRecvFinishedRequest is received from the receiver or the +// session step is completed. +namespace tensorflow { + +// Track and cache the state of worker service RPCs. An RPC can be in 3 states: +// +// * PENDING: this is the first call of the RPC, and it will transition to +// * ACTIVE: another thread is active processing this RPC +// * FINISHED: the worker has finished processing the method + +class RpcResponseCache { + public: + using FinishResponseCB = std::function; + + // Add the given request to the cache. + // If the request is in the cache, + // If it is finished, invoke `cb` immediately + // If active, cb will be invoked when the current call completes. + // In either case, return true. + // Otherwise, store the request and cb in the cache, and return false. + // Note FinishResponseCB is assumed to be thread-safe. + bool QueueRequest(int64_t request_id, int64_t step_id, + const FinishResponseCB& cb); + + // Fill the response cache for the given request_id and respond to all + // pending request. 
+ void RequestFinished(int64_t request_id, const Tensor& tensor, bool is_dead, + const absl::Status& status); + + // Erase the cache entry with the given request_id + void EraseRequestId(int64_t request_id); + + // Erase cache entries with the given step_id + void CleanEntriesForStep(int64_t step_id); + + int64_t size(); + + private: + struct ResponseCacheEntry { + enum class State { + PENDING = 0, + ACTIVE = 1, + FINISHED = 2, + }; + + State state = State::PENDING; + int64_t step_id = -1; + Tensor tensor; + bool is_dead = false; + absl::Status response_status; + + void FinishResponse(const FinishResponseCB& cb) const { + cb(tensor, is_dead, response_status); + } + std::vector callbacks; + }; + + mutex mu_; + // response_cache_ is expected to be small, as entries are cleared immediately + // on ack from the receiver. + gtl::FlatMap response_cache_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RESPONSE_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h new file mode 100644 index 00000000..6836204c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ + +#include "tensorflow/core/common_runtime/collective_executor_mgr.h" +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +class CollectiveParamResolverDistributed; +class ConfigProto; +class DeviceMgr; +class DeviceResolverDistributed; +class WorkerCacheInterface; +class StepSequenceRequest; +class StepSequenceResponse; + +// An implementation of CollectiveExecutorMgr for a distributed environment +// that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs. +// +// In some execution environments it may be possible to implement a +// higher-performance solution and use it in place of this class. +class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { + public: + RpcCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* dev_mgr, + std::unique_ptr dev_resolver, + std::unique_ptr param_resolver, + std::unique_ptr nccl_communicator, + WorkerCacheInterface* worker_cache, const string& task_name); + + virtual ~RpcCollectiveExecutorMgr(); + + // This function should only be called at the group_leader, by an RPC. + // Other needs for StepIds should be satisfied by NextStepId. 
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + const StatusCallback& done) override; + + void RefreshStepIdSequenceAsync(int64_t graph_key, + const StatusCallback& done) override; + + int64_t NextStepId(int64_t graph_key) override; + + void RetireStepId(int64_t graph_key, int64_t step_id) override; + + protected: + virtual CollectiveExecutor* Create(int64_t step_id) override; + + WorkerCacheInterface* const worker_cache_; // Not owned. + const string task_name_; + string group_leader_; + friend class RpcCollectiveExecutorMgrTest; + + private: + absl::Status UpdateStepSequences(const GetStepSequenceResponse& resp); + + // This class maintains the step_id sequencing for a single + // collective_graph_key. + struct GraphKeySequence { + explicit GraphKeySequence(int64_t k) + : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {} + + const int64_t graph_key_; + int64_t next_step_id_; + }; + + mutex sequence_mu_; + gtl::FlatMap sequence_table_ + TF_GUARDED_BY(sequence_mu_); +}; + +// Creates a distributed CollectiveExecutorMgr with production implementations +// of each components. Cases that need to inject other implementations of these +// components should call CollectiveExecutorMgr constructor directly. +std::unique_ptr CreateProdRpcCollectiveExecutorMgr( + const ConfigProto& config, const DeviceMgr* device_mgr, + std::unique_ptr nccl_communicator, + WorkerCacheInterface* worker_cache, const string& default_worker_name); + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/scheduler.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/scheduler.h new file mode 100644 index 00000000..4385db78 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/scheduler.h @@ -0,0 +1,121 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/graph/costmodel.h" + +namespace tensorflow { + +class SlackAnalysis { + public: + SlackAnalysis(const Graph* g, const CostModel* cost_model); + + ~SlackAnalysis() {} + + // Compute the earliest possible start time for each node, based on + // a given cost model. 'asap_time' is indexed by node id. + Microseconds ComputeAsap(std::vector* asap_times); + + // Compute the latest possible start time for each node, based on + // a given cost model. 'alap_time' is indexed by node id. + Microseconds ComputeAlap(std::vector* alap_times); + + // Compute the "slack" of each node. 'slacks' is indexed by node id. 
+ void ComputeSlack(std::vector* slacks); + + private: + const Graph* graph_; + const CostModel* cost_model_; + + SlackAnalysis(const SlackAnalysis&) = delete; + void operator=(const SlackAnalysis&) = delete; +}; + +class GreedyScheduler { + public: + struct Sim { + int degree_parallelism; + int num_running; + std::vector ready_nodes; + }; + + struct Event { + const Node* node; + Microseconds time; + bool is_completion; + + bool operator<(const Event& other) const { return time < other.time; } + }; + + GreedyScheduler(const DeviceSet* devices, const CostModel* cost_model, + const Graph* g, std::vector* priority); + + ~GreedyScheduler(); + + // Computes the start time of each node given the priorities of + // the nodes. + Microseconds ComputeSchedule(std::vector* start_times); + + private: + // Returns the ready node with the highest priority for a sim. + const Node* GetNodeWithHighestPriority(const std::vector& nodes); + + const DeviceSet* devices_; + const CostModel* cost_model_; + const Graph* graph_; + std::vector* priority_; + std::unordered_map device_states_; + + GreedyScheduler(const GreedyScheduler&) = delete; + void operator=(const GreedyScheduler&) = delete; +}; + +class PriorityScheduler { + public: + PriorityScheduler(const DeviceSet* devices, const CostModel* cost_model, + const Graph* g); + + ~PriorityScheduler() {} + + // Computes a schedule of the ideal start time for each node. + // Returns the makespan (the total running time). + Microseconds ComputeSchedule(std::vector* start_times); + + // Computes a schedule and assigns priorities to the nodes based on + // the schedule. Returns the makespan. + Microseconds AssignPriorities(std::vector* priorities); + + private: + const DeviceSet* devices_; + const CostModel* cost_model_; + const Graph* graph_; + + PriorityScheduler(const PriorityScheduler&) = delete; + void operator=(const PriorityScheduler&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SCHEDULER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/server_lib.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/server_lib.h new file mode 100644 index 00000000..cc92d0ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/server_lib.h @@ -0,0 +1,135 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" + +namespace tsl { +class CoordinationServiceAgent; +} // namespace tsl + +namespace tensorflow { + +class DeviceMgr; +class EagerContext; +class WorkerEnv; +class MasterEnv; + +// This library supports a registration/factory-based mechanism for +// creating TensorFlow server objects. Each server implementation must +// have an accompanying implementation of ServerFactory, and create a +// static "registrar" object that calls `ServerFactory::Register()` +// with an instance of the factory class. See "rpc/grpc_server_lib.cc" +// for an example. + +// Represents a single TensorFlow server that exports Master and Worker +// services. +class ServerInterface { + public: + ServerInterface() {} + virtual ~ServerInterface() {} + + // Starts the server running asynchronously. Returns OK on success, otherwise + // returns an error. + virtual absl::Status Start() = 0; + + // Stops the server asynchronously. Returns OK on success, otherwise returns + // an error. + // + // After calling `Stop()`, the caller may call `Join()` to block until the + // server has stopped. + virtual absl::Status Stop() = 0; + + // Blocks until the server has stopped. Returns OK on success, otherwise + // returns an error. + virtual absl::Status Join() = 0; + + // Returns a target string that can be used to connect to this server using + // `tensorflow::NewSession()`. + virtual const string target() const = 0; + + virtual WorkerEnv* worker_env() = 0; + virtual MasterEnv* master_env() = 0; + + // Update the set of workers that can be reached by the server + virtual absl::Status UpdateServerDef(const ServerDef& server_def) = 0; + + // Functions to operate on service-specific properties. + // + // Add master eager context to local eager service in order to handle enqueue + // requests from remote workers. + virtual absl::Status AddMasterEagerContextToEagerService( + const tensorflow::uint64 context_id, EagerContext* context) = 0; + // Set coordination service agent instance to coordination service RPC handler + virtual absl::Status SetCoordinationServiceAgentInstance( + tsl::CoordinationServiceAgent* agent) = 0; + // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is + // supported. + virtual absl::Status StopCoordinationService() = 0; + + private: + ServerInterface(const ServerInterface&) = delete; + void operator=(const ServerInterface&) = delete; +}; + +class ServerFactory { + public: + struct Options { + // Local DeviceMgr to use. + tensorflow::DeviceMgr* local_device_mgr; + }; + // Creates a new server based on the given `server_def`, and stores + // it in `*out_server`. Returns OK on success, otherwise returns an + // error. + virtual absl::Status NewServer( + const ServerDef& server_def, const Options& options, + std::unique_ptr* out_server) = 0; + + // Returns true if and only if this factory can create a server + // based on the given `server_def`. + virtual bool AcceptsOptions(const ServerDef& server_def) = 0; + + virtual ~ServerFactory() {} + + // For each `ServerFactory` subclass, an instance of that class must + // be registered by calling this method. + // + // The `server_type` must be unique to the server factory. 
+ static void Register(const string& server_type, ServerFactory* factory); + + // Looks up a factory that can create a server based on the given + // `server_def`, and stores it in `*out_factory`. Returns OK on + // success, otherwise returns an error. + static absl::Status GetFactory(const ServerDef& server_def, + ServerFactory** out_factory); +}; + +// Creates a server based on the given `server_def`, and stores it in +// `*out_server`. Returns OK on success, otherwise returns an error. +absl::Status NewServer(const ServerDef& server_def, + std::unique_ptr* out_server); +absl::Status NewServerWithOptions(const ServerDef& server_def, + const ServerFactory::Options& options, + std::unique_ptr* out_server); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SERVER_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/session_mgr.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/session_mgr.h new file mode 100644 index 00000000..55c64f45 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/session_mgr.h @@ -0,0 +1,169 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ + +#include +#include + +#include "xla/tsl/distributed_runtime/coordination/coordination_service.h" +#include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h" +#include "tensorflow/core/distributed_runtime/worker_session.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +class WorkerCacheInterface; +struct WorkerEnv; + +// SessionMgr keeps track of information related to a given session. +// +// SessionMgr runs on the workers. +// +// SessionMgr is threadsafe. +class SessionMgr { + public: + typedef std::function + WorkerCacheFactory; + + explicit SessionMgr( + WorkerEnv* worker_env, const std::string& default_worker_name, + std::unique_ptr default_worker_cache, + WorkerCacheFactory worker_cache_factory, + tsl::CoordinationServiceRpcHandler* coordination_handler); + ~SessionMgr() {} + + // Allocates state for a new session. 
+ absl::Status CreateSession( + const std::string& session, const ServerDef& server_def, + bool isolate_session_state, + StatusCallback coordination_error_callback = [](absl::Status s) { + LOG(ERROR) << "Coordination agent is set to error: " << s; + }); + absl::Status CreateSession( + const std::string& session, const ServerDef& server_def, + const protobuf::RepeatedPtrField& device_attributes, + bool isolate_session_state); + + // Create WorkerSession from the master with the given `master_task` and + // `master_incarnation`. We first look for existing WorkerSessions associated + // with the specified master task. If there are sessions created by the same + // master but with a different incarnation, it indicates that the remote + // master has restarted before deleting the sessions on worker. When it + // happens, old sessions associated with the master will be automatically + // removed before the new session is created. + absl::Status CreateSession( + const std::string& session, const ServerDef& server_def, + const protobuf::RepeatedPtrField& device_attributes, + bool isolate_session_state, std::string master_task, + int64_t master_incarnation, + StatusCallback coordination_error_callback = [](absl::Status s) { + LOG(ERROR) << "Coordination agent is set to error: " << s; + }); + + void ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache); + + // Updates state (worker cache, devices) of worker session identified by + // session name (`session`) based on a new server_def and set of devices. + absl::Status UpdateSession(const std::string& session, + const ServerDef& server_def, + const protobuf::RepeatedPtrField& + cluster_device_attributes); + + // Locates the worker session for a given session handle + absl::Status WorkerSessionForSession( + const std::string& session_handle, + std::shared_ptr* out_session); + std::shared_ptr LegacySession(); + + absl::Status DeleteSession(const std::string& session); + + // Deletes all existing sessions. + absl::Status DeleteAllSessions(); + + // Provides access to the coordination service agent. This method should only + // be called after the agent has been initialized during session creation, or + // an invalid nullptr is returned. Note: the agent is thread-safe and mutable. + tsl::CoordinationServiceAgent* GetCoordinationServiceAgent(); + + static std::string WorkerNameFromServerDef(const ServerDef& server_def); + + void SetLogging(bool active); + + void RetrieveLogs(int64_t step_id, LoggingResponse* response); + + void ClearLogs(); + + // Agent should be torn down before service as it needs to disconnect first. + void TeardownCoordinationServiceAgent(); + void TeardownCoordinationService(); + + private: + WorkerEnv* const worker_env_; // Not owned. + + // A note about destruction: + // We must delete graph_mgr before device_mgr, due to shared + // ownership of OpKernels in the executors. (The graph_mgr will + // free all stateless OpKernels, and pass over borrowed stateful + // OpKernels, which are also held in their respective devices' + // OpSegments.) + // + // legacy_session_ owns the worker_env_.device_mgr, and so we must ensure + // that sessions_'s WorkerSessions are deleted (which do not own the + // underlying devices, but instead own RenamedDevices) before + // legacy_session_ is deleted. Further, we must ensure that WorkerSession's + // device_mgr is deleted after WorkerSession's graph_mgr. 
+ + std::unique_ptr default_worker_cache_; + std::shared_ptr legacy_session_; + std::unique_ptr coordination_service_; + std::unique_ptr coordination_service_agent_; + + bool is_logging_active_ = false; + + const WorkerCacheFactory worker_cache_factory_; + + // Not owned. And should only be used for setting the coordination service. + tsl::CoordinationServiceRpcHandler* coordination_handler_ = nullptr; + + absl::Status WorkerSessionForSessionLocked( + const std::string& session_handle, + std::shared_ptr* out_session) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutex mu_; + // A map from session identifier to internal session structure. + std::map> sessions_ + TF_GUARDED_BY(mu_); + + // Incarnation and WorkerSession handle associated with a master task. + struct MasterAssociatedSession { + const int64_t master_incarnation; + const std::string session_handle; + }; + // A map from master task name to its associated worker sessions. + std::unordered_multimap + master_to_associated_sessions_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SESSION_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/tensor_coding.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/tensor_coding.h new file mode 100644 index 00000000..1fd40d95 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/tensor_coding.h @@ -0,0 +1,110 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +class DeviceBase; +class TensorProto; + +// TensorResponse can be used as the destination of an RPC that returns +// a RecvTensorResponse. It efficiently decodes the incoming data +// into Tensor contents as well as associated metadata. +class TensorResponse { + public: + TensorResponse() {} + + // Reset to initial state. + void Clear(); + + // Clear just tensor_ and meta_ members without setting allocation + // related members. + void ClearTensor(); + + // Initialize memory allocation related members. + void InitAlloc(DeviceBase* d, const AllocatorAttributes& aa); + + // Source provides a way for a particular RPC implementation to provide + // received data to ParseFrom. + class Source { + public: + virtual ~Source(); + + // Return the stream that contains the data to be parsed. + // Note that this method might be invoked more than once if + // ParseFrom needs to fall back to a more expensive parsing method. 
+ // Every call must return a stream pointing at the beginning of + // the serialized RecvTensorResponse. + // + // Note that a subsequent call to contents() invalidates previous + // results of contents(). + // + // Ownership of the returned stream is retained by the Source and + // should not be deleted by the caller. + virtual ::tensorflow::protobuf::io::ZeroCopyInputStream* contents() = 0; + }; + + // Parse the RecvTensorResponse encoded in the data yielded by + // source->contents() into *this. + absl::Status ParseFrom(Source* source); + + // Initialize tensor from *response. + // Leaves *response with unspecified contents. + absl::Status InitFrom(RecvTensorResponse* response); + + // Initialize tensor metadata from response and allocate + // uninitialized backing storage for actual contents. + void InitPartial(const RecvTensorResponse& response, + const AllocationAttributes& allocation_attr); + + // Return a reference to the parsed tensor. The tensor will remain + // live only until *this is destroyed or modified. + const Tensor& tensor() const { return tensor_; } + + // Return a reference to the parsed tensor metadata (no contents). + // The result will remain live only until *this is destroyed or + // modified. + const RecvTensorResponse& metadata() const { return meta_; } + + // Return pointer to the device hosting the tensor. + DeviceBase* device() const { return device_; } + + private: + bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input, + TensorProto* tensor_meta); + bool ParseFast(Source* source); + bool ParseSlow(Source* source); + + bool on_host_ = false; + DeviceBase* device_ = nullptr; + AllocatorAttributes alloc_attrs_; + Allocator* allocator_ = nullptr; + bool already_used_ = false; + Tensor tensor_; + RecvTensorResponse meta_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/test_utils.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/test_utils.h new file mode 100644 index 00000000..e7ad1041 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/test_utils.h @@ -0,0 +1,202 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ + +#include +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// Some utilities for testing distributed-mode components in a single process +// without RPCs. + +// Implements the worker interface with methods that just respond with +// "unimplemented" status. Override just the methods needed for +// testing. 
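As an example of the pattern this comment describes, a test might subclass the TestWorkerInterface declared below and override only the single method it exercises; this is a sketch, and the FakeStatusWorker name and the reported device are hypothetical.

// Hypothetical test double: only GetStatusAsync is overridden, every other
// method keeps the "unimplemented" default from TestWorkerInterface.
class FakeStatusWorker : public TestWorkerInterface {
 public:
  void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                      GetStatusResponse* response, bool fail_fast,
                      StatusCallback done) override {
    // Report a single fake CPU device and complete the call successfully.
    response->add_device_attributes()->set_name(
        "/job:worker/replica:0/task:0/device:CPU:0");
    done(absl::OkStatus());
  }
};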
+class TestWorkerInterface : public WorkerInterface { + public: + void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, + GetStatusResponse* response, bool fail_fast, + StatusCallback done) override { + done(errors::Unimplemented("GetStatusAsync")); + } + + void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("CreateWorkerSessionAsync")); + } + + void DeleteWorkerSessionAsync(CallOptions* opts, + const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("DeleteWorkerSessionAsync")); + } + + void RegisterGraphAsync(const RegisterGraphRequest* request, + RegisterGraphResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("RegisterGraphAsync")); + } + + void DeregisterGraphAsync(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("DeregisterGraphAsync")); + } + + void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, + MutableRunGraphResponseWrapper* response, + StatusCallback done) override { + done(errors::Unimplemented("RunGraphAsync")); + } + + void CleanupGraphAsync(const CleanupGraphRequest* request, + CleanupGraphResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("CleanupGraphAsync")); + } + + void CleanupAllAsync(const CleanupAllRequest* request, + CleanupAllResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("CleanupAllAsync")); + } + + void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, + TensorResponse* response, StatusCallback done) override { + done(errors::Unimplemented("RecvTensorAsync")); + } + + void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("LoggingAsync")); + } + + void TracingAsync(const TracingRequest* request, TracingResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("TracingAsync")); + } + + void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) override { + done(errors::Unimplemented("RecvBufAsync")); + } + + void CompleteGroupAsync(CallOptions* opts, + const CompleteGroupRequest* request, + CompleteGroupResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("CompleteGroupAsync")); + } + + void CompleteInstanceAsync(CallOptions* ops, + const CompleteInstanceRequest* request, + CompleteInstanceResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("CompleteInstanceAsync")); + } + + void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("GetStepSequenceAsync")); + } +}; + +class TestWorkerCache : public WorkerCacheInterface { + public: + virtual ~TestWorkerCache() {} + + void AddWorker(const string& target, WorkerInterface* wi) { + workers_[target] = wi; + } + + void AddDevice(const string& device_name, const DeviceLocality& dev_loc) { + localities_[device_name] = dev_loc; + } + + void ListWorkers(std::vector* workers) const override { + workers->clear(); + for (auto it : workers_) { + workers->push_back(it.first); + } + } + + void ListWorkersInJob(const string& job_name, + std::vector* 
workers) const override { + workers->clear(); + for (auto it : workers_) { + DeviceNameUtils::ParsedName device_name; + CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name)); + CHECK(device_name.has_job); + if (job_name == device_name.job) { + workers->push_back(it.first); + } + } + } + + WorkerInterface* GetOrCreateWorker(const string& target) override { + auto it = workers_.find(target); + if (it != workers_.end()) { + return it->second; + } + return nullptr; + } + + void ReleaseWorker(const string& target, WorkerInterface* worker) override {} + + absl::Status GetEagerClientCache( + std::unique_ptr* eager_client_cache) override { + return errors::Unimplemented("Unimplemented."); + } + + absl::Status GetCoordinationClientCache( + std::unique_ptr* coord_client_cache) override { + return errors::Unimplemented("Unimplemented."); + } + + bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override { + auto it = localities_.find(device); + if (it != localities_.end()) { + *locality = it->second; + return true; + } + return false; + } + + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback done) override { + auto it = localities_.find(device); + if (it != localities_.end()) { + *locality = it->second; + done(absl::OkStatus()); + return; + } + done(errors::Internal("Device not found: ", device)); + } + + protected: + std::unordered_map workers_; + std::unordered_map localities_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker.h new file mode 100644 index 00000000..4c55e1b9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker.h @@ -0,0 +1,143 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ + +#include + +#include "tensorflow/core/distributed_runtime/graph_mgr.h" +#include "tensorflow/core/distributed_runtime/partial_run_mgr.h" +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" +#include "tensorflow/core/distributed_runtime/session_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/cancellation.h" + +namespace tensorflow { + +class Device; +struct WorkerEnv; +class WorkerSession; + +// A TensorFlow Worker runs registered graphs and supports worker-to-worker +// Tensor transfer. +// +// See `../protobuf/worker_service.proto` for more details about each method. +// +// This class may be subclassed to provide specialized implementations of +// particular methods for different transport mechanism. 
For example, +// `GrpcWorker` specializes the `RecvTensorAsync()` method to support a more +// efficient gRPC data structure for handling large binary data. +class Worker : public WorkerInterface { + public: + Worker(WorkerEnv* env); + virtual ~Worker() {} + + void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, + GetStatusResponse* response, bool fail_fast, + StatusCallback done) override; + + void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response, + StatusCallback done) override; + + void DeleteWorkerSessionAsync(CallOptions* opts, + const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, + StatusCallback done) override; + + void RegisterGraphAsync(const RegisterGraphRequest* request, + RegisterGraphResponse* response, + StatusCallback done) override; + + void DeregisterGraphAsync(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response, + StatusCallback done) override; + + void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, + MutableRunGraphResponseWrapper* response, + StatusCallback done) override; + + MutableRunGraphRequestWrapper* CreateRunGraphRequest() override; + + MutableRunGraphResponseWrapper* CreateRunGraphResponse() override; + + void CleanupGraphAsync(const CleanupGraphRequest* request, + CleanupGraphResponse* response, + StatusCallback done) override; + + void CleanupAllAsync(const CleanupAllRequest* request, + CleanupAllResponse* response, + StatusCallback done) override; + + void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, + TensorResponse* response, StatusCallback done) override; + + void LoggingAsync(const LoggingRequest* request, LoggingResponse* response, + StatusCallback done) override; + + void TracingAsync(const TracingRequest* request, TracingResponse* response, + StatusCallback done) override; + + void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) override; + + void CompleteGroupAsync(CallOptions* opts, + const CompleteGroupRequest* request, + CompleteGroupResponse* response, + StatusCallback done) override; + + void CompleteInstanceAsync(CallOptions* opts, + const CompleteInstanceRequest* request, + CompleteInstanceResponse* response, + StatusCallback done) override; + + void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + StatusCallback done) override; + + protected: + WorkerEnv* const env_; // Not owned. 
+ RecentRequestIds recent_request_ids_; + + absl::Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, + Device** src_dev); + + void AbortStep(int64_t); + + private: + PartialRunMgr partial_run_mgr_; + + CancellationManager cancellation_manager_; + + absl::Status PrepareRunGraph(RunGraphRequestWrapper* req, + GraphMgr::NamedTensors* in, + GraphMgr::NamedTensors* out); + + void DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, + MutableRunGraphResponseWrapper* response, + StatusCallback done); + + void DoPartialRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, + MutableRunGraphResponseWrapper* response, + StatusCallback done); + + Worker(const Worker&) = delete; + void operator=(const Worker&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache.h new file mode 100644 index 00000000..1ac4de35 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache.h @@ -0,0 +1,96 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ + +#include +#include + +#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/framework/device_attributes.pb.h" // for DeviceLocality +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +typedef std::function StatusCallback; + +class ChannelCache; +class StepStats; + +class WorkerCacheInterface { + public: + virtual ~WorkerCacheInterface() {} + + // Updates *workers with strings naming the remote worker tasks to + // which open channels have been established. + virtual void ListWorkers(std::vector* workers) const = 0; + virtual void ListWorkersInJob(const string& job_name, + std::vector* workers) const = 0; + + // If "target" names a remote task for which an RPC channel exists + // or can be constructed, returns a pointer to a WorkerInterface object + // wrapping that channel. The returned value must be destroyed by + // calling `this->ReleaseWorker(target, ret)` + virtual WorkerInterface* GetOrCreateWorker(const string& target) = 0; + + // Release a worker previously returned by this->GetOrCreateWorker(target). + // + // TODO(jeff,sanjay): Consider moving target into WorkerInterface. + // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a + // per-rpc-subsystem WorkerInterface creator. + virtual void ReleaseWorker(const string& target, WorkerInterface* worker) { + // Subclasses may override to reuse worker objects. 
+ delete worker; + } + + // Set *locality with the DeviceLocality of the specified remote device + // within its local environment. Returns true if *locality + // was set, using only locally cached data. Returns false + // if status data for that device was not available. Never blocks. + virtual bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) = 0; + + // Set *locality with the DeviceLocality of the specified remote device + // within its local environment. Callback gets Status::OK if *locality + // was set. + virtual void GetDeviceLocalityAsync(const string& device, + DeviceLocality* locality, + StatusCallback done) = 0; + + // TODO(b/189159585): Define a general client cache maker function to + // construct client cache of different types sharing the same underling RPC + // channels, to replace the eager and coordination cache function. + // Build and return a EagerClientCache object wrapping that channel. + virtual absl::Status GetEagerClientCache( + std::unique_ptr* eager_client_cache) = 0; + + // Build and return a CoordinationClientCache object wrapping that channel. + virtual absl::Status GetCoordinationClientCache( + std::unique_ptr* coordination_client_cache) = 0; + + // Start/stop logging activity. + virtual void SetLogging(bool active) {} + + // Discard any saved log data. + virtual void ClearLogs() {} + + // Return logs for the identified step in *ss. Any returned data will no + // longer be stored. + virtual bool RetrieveLogs(int64_t step_id, StepStats* ss) { return false; } +}; +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_logger.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_logger.h new file mode 100644 index 00000000..f5ef19bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_logger.h @@ -0,0 +1,89 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_ + +#include +#include + +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class StepStatsCollector; + +// WorkerCacheLogger is a thread-safe utility for use by a WorkerCache +// to optionally log some selected RPC activity. A single instance +// should be owned by a WorkerCache, for use by its RemoteWorker +// instances. + +class WorkerCacheLogger { + public: + // Start/Stop logging activity. 
This function increments/decrements + // a counter so that if two separate steps turn logging on/off, + // logging should be on for the union of the durations of both, + // regardless of relative timing. + void SetLogging(bool v); + + // Discard any saved log data. + void ClearLogs(); + + // Return logs for the identified step in *ss. Any returned data will no + // longer be stored. Returns true iff *ss was modified. + bool RetrieveLogs(int64_t step_id, StepStats* ss); + + // Return true if there is any outstanding request for logging on + // the RPC channels. + bool LoggingActive() { + mutex_lock l(count_mu_); + return want_logging_count_ > 0; + } + + // Generates a NodeExecStats record with the given data, and saves for + // later retrieval by RetrieveLogs(). + void RecordRecvTensor(int64_t step_id, int64_t start_usecs, int64_t end_usecs, + const string& tensor_name, const string& src_device, + const string& dst_device, int64_t bytes); + + // Generates a NodeExecStats record with the given data, and saves for + // later retrieval by RetrieveLogs(). + void RecordDataTransfer(int64_t step_id, int64_t start_usecs, + int64_t end_usecs, const string& tensor_name, + const string& src_device, const string& dst_device, + int64_t bytes, const string& details, + const string& transfer_method_name); + + private: + mutex count_mu_; + int32 want_logging_count_ TF_GUARDED_BY(count_mu_) = 0; + + struct StepLog { + StepStats step_stats; + StepStatsCollector* collector; + }; + typedef std::unordered_map LogMap; + mutex mu_; + LogMap log_map_ TF_GUARDED_BY(mu_); + + // Records "ns" in log_map_ under the given device and step. + void Save(const string& device, int64_t step_id, NodeExecStats* ns); + + void ClearLogsWithLock() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); +}; +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_partial.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_partial.h new file mode 100644 index 00000000..b5a500b8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_partial.h @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_ + +#include +#include + +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +// Implements the part of the interface that caches and returns remote +// device status attributes. 
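Referring back to the WorkerCacheLogger declared above in worker_cache_logger.h, its intended call pattern is roughly the following sketch; the step id, timestamps, tensor name, byte count, and device names are placeholders.

WorkerCacheLogger logger;
logger.SetLogging(true);  // some step requested logging
logger.RecordRecvTensor(/*step_id=*/42, /*start_usecs=*/1000, /*end_usecs=*/1500,
                        "edge_17_tensor_a",
                        "/job:worker/replica:0/task:0/device:CPU:0",
                        "/job:worker/replica:0/task:1/device:CPU:0",
                        /*bytes=*/4096);
StepStats ss;
if (logger.RetrieveLogs(/*step_id=*/42, &ss)) {
  // ss now holds the generated NodeExecStats; the entry has been removed
  // from the logger's internal map.
}
logger.SetLogging(false);  // balance the earlier SetLogging(true)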
+class WorkerCachePartial : public WorkerCacheInterface { + public: + bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override; + + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback) override; + + ~WorkerCachePartial() override {} + + // Clear all entries from the DeviceStatus cache. + void FlushStatusCache(); + + private: + mutex mu_; + + // Initiate a GetStatusAsync to the remote task named by "task", and + // update the cache with all the DeviceAttributes reported. + absl::Status RefreshDeviceStatus(const string& device_name); + + typedef std::unordered_map StatusMap; + StatusMap device_status_cache_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_wrapper.h new file mode 100644 index 00000000..7f709b4f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_cache_wrapper.h @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_ + +#include +#include + +#include "tensorflow/core/distributed_runtime/worker_cache.h" + +namespace tensorflow { + +class WorkerCacheWrapper : public WorkerCacheInterface { + public: + WorkerCacheWrapper(WorkerCacheInterface* wrapped) : wrapped_(wrapped) {} + + // Updates *workers with strings naming the remote worker tasks to + // which open channels have been established. + void ListWorkers(std::vector* workers) const override { + return wrapped_->ListWorkers(workers); + } + void ListWorkersInJob(const string& job_name, + std::vector* workers) const override { + return wrapped_->ListWorkersInJob(job_name, workers); + } + + // If "target" names a remote task for which an RPC channel exists + // or can be constructed, returns a pointer to a WorkerInterface object + // wrapping that channel. The returned value must be destroyed by + // calling `this->ReleaseWorker(target, ret)` + WorkerInterface* GetOrCreateWorker(const string& target) override { + return wrapped_->GetOrCreateWorker(target); + } + + // Release a worker previously returned by this->GetOrCreateWorker(target). + // + // TODO(jeff,sanjay): Consider moving target into WorkerInterface. + // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a + // per-rpc-subsystem WorkerInterface creator. 
+ void ReleaseWorker(const string& target, WorkerInterface* worker) override { + return wrapped_->ReleaseWorker(target, worker); + } + + absl::Status GetEagerClientCache( + std::unique_ptr* eager_client_cache) override { + return wrapped_->GetEagerClientCache(eager_client_cache); + } + + absl::Status GetCoordinationClientCache( + std::unique_ptr* coordination_client_cache) + override { + return wrapped_->GetCoordinationClientCache(coordination_client_cache); + } + + // Set *locality with the DeviceLocality of the specified remote device + // within its local environment. Returns true if *locality + // was set, using only locally cached data. Returns false + // if status data for that device was not available. Never blocks. + bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override { + return wrapped_->GetDeviceLocalityNonBlocking(device, locality); + } + + // Set *locality with the DeviceLocality of the specified remote device + // within its local environment. Callback gets Status::OK if *locality + // was set. + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback done) override { + return wrapped_->GetDeviceLocalityAsync(device, locality, std::move(done)); + } + + // Start/stop logging activity. + void SetLogging(bool active) override { wrapped_->SetLogging(active); } + + // Discard any saved log data. + void ClearLogs() override { wrapped_->ClearLogs(); } + + // Return logs for the identified step in *ss. Any returned data will no + // longer be stored. + bool RetrieveLogs(int64_t step_id, StepStats* ss) override { + return wrapped_->RetrieveLogs(step_id, ss); + } + + private: + WorkerCacheInterface* wrapped_; // Not owned. +}; +} // namespace tensorflow +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_WRAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_env.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_env.h new file mode 100644 index 00000000..350c3e5f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_env.h @@ -0,0 +1,79 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ + +#include + +#include "tensorflow/core/platform/types.h" + +namespace tsl { +class Env; +namespace thread { +class ThreadPool; +} // namespace thread +} // namespace tsl +namespace tensorflow { +using Env = tsl::Env; + +namespace thread { +using tsl::thread::ThreadPool; +} // namespace thread + +class CollectiveExecutorMgrInterface; +class Device; +class DeviceMgr; +class RendezvousMgrInterface; +class SessionMgr; + +// The worker environment class, which holds a bag of pointers to +// per-worker singletons. +// +// WorkerEnv does not own its member pointers. 
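Because the WorkerCacheWrapper above is a plain delegating base, a cache decorator only has to override the calls it wants to observe. A hypothetical sketch (the CountingWorkerCache name is made up; only the wrapper interface above is assumed):

// Counts worker lookups while delegating everything else to the wrapped cache.
class CountingWorkerCache : public WorkerCacheWrapper {
 public:
  explicit CountingWorkerCache(WorkerCacheInterface* wrapped)
      : WorkerCacheWrapper(wrapped) {}

  WorkerInterface* GetOrCreateWorker(const string& target) override {
    ++lookups_;
    return WorkerCacheWrapper::GetOrCreateWorker(target);
  }

  int64_t lookups() const { return lookups_; }

 private:
  int64_t lookups_ = 0;
};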
+struct WorkerEnv { + Env* env = nullptr; + + // session_mgr encapsulates state for each session. + SessionMgr* session_mgr = nullptr; + + // In large scaled distributed training, many singleton components (e.g. + // Rendezvous) can becomes the bottleneck of the system. This field allows + // us to shard the single components. This number will scale up with number + // of tasks in this cluster. It is always greater than 1. + int experimental_num_shards = 1; + + // device_mgr manages local devices (cpu and gpu). The WorkerService + // is the network interface for managed devices. + // + // Note: Please use the device_mgr associated with your session if appropriate + // instead of this one. Using this device_mgr does not support ClusterSpec + // propagated sessions. + DeviceMgr* device_mgr = nullptr; + + // A set of rendezvous keyed by step ids. + RendezvousMgrInterface* rendezvous_mgr = nullptr; + + // Generates per-step CollectiveExecutors and has access to utilities + // supporting collective operations. + std::unique_ptr collective_executor_mgr; + + // A pool of threads for scheduling compute work. + thread::ThreadPool* compute_pool = nullptr; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_interface.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_interface.h new file mode 100644 index 00000000..382425bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_interface.h @@ -0,0 +1,236 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ + +#include + +#include "tensorflow/core/distributed_runtime/call_options.h" +#include "tensorflow/core/distributed_runtime/message_wrappers.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/worker.pb.h" + +namespace tensorflow { + +// Status callback. +typedef std::function StatusCallback; + +// Custom decoder for a response to RecvTensorAsync. +class TensorResponse; + +// Interface for talking with the TensorFlow Worker service. 
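A rough sketch of how the WorkerEnv struct above gets wired together; the helper name is hypothetical, and the device manager, session manager, and thread pool are assumed to be created elsewhere and to outlive the env, since WorkerEnv does not own its members.

void InitWorkerEnv(WorkerEnv* worker_env, DeviceMgr* device_mgr,
                   SessionMgr* session_mgr, thread::ThreadPool* compute_pool) {
  worker_env->env = Env::Default();
  worker_env->device_mgr = device_mgr;
  worker_env->session_mgr = session_mgr;
  // RpcRendezvousMgr (declared in rpc/rpc_rendezvous_mgr.h above) keeps a
  // pointer back to the env; ownership/cleanup is omitted in this sketch.
  worker_env->rendezvous_mgr = new RpcRendezvousMgr(worker_env);
  worker_env->compute_pool = compute_pool;
}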
+class WorkerInterface { + public: + virtual void GetStatusAsync(CallOptions* opts, + const GetStatusRequest* request, + GetStatusResponse* response, bool fail_fast, + StatusCallback done) = 0; + + virtual void CreateWorkerSessionAsync( + const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response, StatusCallback done) = 0; + + virtual void DeleteWorkerSessionAsync( + CallOptions* opts, const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response, StatusCallback done) = 0; + + virtual void RegisterGraphAsync(const RegisterGraphRequest* request, + RegisterGraphResponse* response, + StatusCallback done) = 0; + + virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response, + StatusCallback done) = 0; + + virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, + MutableRunGraphResponseWrapper* response, + StatusCallback done) = 0; + + virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request, + RunGraphResponse* response, StatusCallback done) { + RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request); + MutableRunGraphResponseWrapper* wrapped_response = + new NonOwnedProtoRunGraphResponse(response); + RunGraphAsync(opts, wrapped_request, wrapped_response, + [wrapped_request, wrapped_response, + done = std::move(done)](const absl::Status& s) { + done(s); + delete wrapped_request; + delete wrapped_response; + }); + } + + // Returns a request object for use in calls to + // `RunGraphAsync()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunGraph()` call on the same `WorkerInterface` instance. + virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() { + return new MutableProtoRunGraphRequest; + } + + // Returns a response object for use in calls to + // `RunGraphAsync()`. Ownership is transferred to the caller. + // + // The message returned from this method must only be used in a + // `RunGraph()` call on the same `WorkerInterface` instance. 
+ virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() { + return new OwnedProtoRunGraphResponse; + } + + virtual void CleanupGraphAsync(const CleanupGraphRequest* request, + CleanupGraphResponse* response, + StatusCallback done) = 0; + + virtual void CleanupAllAsync(const CleanupAllRequest* request, + CleanupAllResponse* response, + StatusCallback done) = 0; + + virtual void RecvTensorAsync(CallOptions* opts, + const RecvTensorRequest* request, + TensorResponse* response, + StatusCallback done) = 0; + + virtual void LoggingAsync(const LoggingRequest* request, + LoggingResponse* response, StatusCallback done) = 0; + + virtual void TracingAsync(const TracingRequest* request, + TracingResponse* response, StatusCallback done) = 0; + + virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, + RecvBufResponse* response, StatusCallback done) = 0; + + virtual void CompleteGroupAsync(CallOptions* opts, + const CompleteGroupRequest* request, + CompleteGroupResponse* response, + StatusCallback done) = 0; + + virtual void CompleteInstanceAsync(CallOptions* ops, + const CompleteInstanceRequest* request, + CompleteInstanceResponse* response, + StatusCallback done) = 0; + + virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + StatusCallback done) = 0; + + absl::Status GetStatus(const GetStatusRequest* request, + GetStatusResponse* response) { + absl::Status ret; + Notification n; + GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true, + [&ret, &n](const absl::Status& s) { + ret = s; + n.Notify(); + }); + n.WaitForNotification(); + return ret; + } + + absl::Status CreateWorkerSession(const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response) { + return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); + } + + absl::Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response) { + return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request, + response); + } + + absl::Status RegisterGraph(const RegisterGraphRequest* request, + RegisterGraphResponse* response) { + return CallAndWait(&ME::RegisterGraphAsync, request, response); + } + + absl::Status DeregisterGraph(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response) { + return CallAndWait(&ME::DeregisterGraphAsync, request, response); + } + + absl::Status CleanupGraph(const CleanupGraphRequest* request, + CleanupGraphResponse* response) { + return CallAndWait(&ME::CleanupGraphAsync, request, response); + } + + absl::Status CleanupAll(const CleanupAllRequest* request, + CleanupAllResponse* response) { + return CallAndWait(&ME::CleanupAllAsync, request, response); + } + + absl::Status Logging(const LoggingRequest* request, + LoggingResponse* response) { + return CallAndWait(&ME::LoggingAsync, request, response); + } + + absl::Status Tracing(const TracingRequest* request, + TracingResponse* response) { + return CallAndWait(&ME::TracingAsync, request, response); + } + + absl::Status GetStepSequence(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response) { + return CallAndWait(&ME::GetStepSequenceAsync, request, response); + } + + protected: + // Instances of WorkerInterface must be deleted by a call to + // WorkerCacheInterface::ReleaseWorker(). 
+ virtual ~WorkerInterface() {} + friend class WorkerCacheInterface; + + // NOTE: This should only be called by implementations of this + // interface whose CreateRunGraphResponse() method returns a + // proto-based wrappers for the RunGraphResponse message. + RunGraphResponse* get_proto_from_wrapper( + MutableRunGraphResponseWrapper* wrapper) { + return wrapper->get_proto(); + } + + private: + typedef WorkerInterface ME; + + template + absl::Status CallAndWait(Method func, const Req* req, Resp* resp) { + absl::Status ret; + Notification n; + (this->*func)(req, resp, [&ret, &n](const absl::Status& s) { + ret = s; + n.Notify(); + }); + n.WaitForNotification(); + return ret; + } + + template + absl::Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) { + CallOptions call_opts; + absl::Status ret; + Notification n; + (this->*func)(&call_opts, req, resp, [&ret, &n](const absl::Status& s) { + ret = s; + n.Notify(); + }); + n.WaitForNotification(); + return ret; + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_session.h b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_session.h new file mode 100644 index 00000000..e366accf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/distributed_runtime/worker_session.h @@ -0,0 +1,134 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/distributed_runtime/graph_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { + +class ClusterFunctionLibraryRuntime; +class GraphMgr; +class WorkerCacheInterface; + +// WorkerSession encapsulates all of the state relating to a given session. +class WorkerSession { + public: + using DistributedFunctionLibraryRuntimeCreator = + std::function( + WorkerSession* worker_session, bool create_worker_session_called, + DeviceMgr* remote_device_mgr)>; + + // Collection of local devices. These devices are typically + // RenamedDevices in all except the SessionMgr.legacy_session_ and + // sessions created with `isolate_session_state == false`. In the + // those cases, this method returns a pointer to a borrowed + // DeviceMgr (typically the `worker_env.device_mgr`). + DeviceMgr* device_mgr() { + return device_mgr_ ? 
device_mgr_.get() : borrowed_device_mgr_; + } + + DynamicDeviceMgr* remote_device_mgr() { return remote_device_mgr_.get(); } + + const string& session_name() const { return session_name_; } + const string& worker_name() const { return worker_name_; } + + WorkerCacheInterface* worker_cache() const { + tf_shared_lock l(worker_session_state_mu_); + return worker_cache_.get(); + } + GraphMgr* graph_mgr() const { return graph_mgr_.get(); } + + DistributedFunctionLibraryRuntime* cluster_flr() const { + return cluster_flr_.get(); + } + + WorkerSession(const string& session_name, const string& worker_name, + std::unique_ptr worker_cache, + std::unique_ptr device_mgr, + std::unique_ptr graph_mgr, + std::unique_ptr remote_device_mgr, + DistributedFunctionLibraryRuntimeCreator cluster_flr_creator); + + static std::shared_ptr CreateWithBorrowedDeviceMgr( + const string& session_name, const string& worker_name, + std::unique_ptr worker_cache, + DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, + std::unique_ptr remote_device_mgr, + DistributedFunctionLibraryRuntimeCreator cluster_flr_creator); + + // In the eager runtime we allow WorkerSession to be updated, where the + // worker cache will be recreated. If WorkerSession upate is expected and a + // worker in the cache is used in RPCs, the caller should hold a shared + // pointer to avoid the workers getting deleted. + std::shared_ptr GetSharedWorkerCache() { + tf_shared_lock l(worker_session_state_mu_); + return worker_cache_; + } + + // Update an existing worker session with new set of remote workers and + // devices. Added devices will be owned by the worker session, and removed + // devices will be freed by their names. + absl::Status UpdateWorkerCacheAndDevices( + std::unique_ptr new_worker_cache, + std::vector> added_remote_devices, + const std::vector& removed_remote_devices); + + ~WorkerSession(); + + private: + WorkerSession(const string& session_name, const string& worker_name, + std::unique_ptr worker_cache, + DeviceMgr* borrowed_device_mgr, + std::unique_ptr graph_mgr, + std::unique_ptr remote_device_mgr, + DistributedFunctionLibraryRuntimeCreator cluster_flr_creator); + + // The name of the session. + const string session_name_; + + // The name of the worker. E.g., /job:mnist/replica:0/task:1. + const string worker_name_; + + mutable mutex worker_session_state_mu_; + // Object from which WorkerInterface instances can be obtained. + std::shared_ptr worker_cache_ + TF_GUARDED_BY(worker_session_state_mu_); + + // graph_mgr keeps track of the registered graphs of this session. + // + // Note: graph_mgr must be deleted before rendezvous_mgr! + // Note: graph_mgr must be deleted before device_mgr! + const std::unique_ptr graph_mgr_; + + std::unique_ptr cluster_flr_; + + const std::unique_ptr device_mgr_; + DeviceMgr* const borrowed_device_mgr_; // Not owned. + std::unique_ptr remote_device_mgr_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/example/example_parser_configuration.h b/third_party/tflite-hdrs/tensorflow/core/example/example_parser_configuration.h new file mode 100644 index 00000000..dd2aacae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/example/example_parser_configuration.h @@ -0,0 +1,56 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_ +#define TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_ + +#include +#include + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/example_parser_configuration.pb.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/example_proto_helper.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +// This is a set of helper methods that will make it possible to share +// tensorflow::Example proto Tensor conversion code inside the ExampleParserOp +// OpKernel as well as in external code. +namespace tensorflow { + +// Given a graph and the node_name of a ParseExample op, +// extract the FixedLenFeature/VarLenFeature configurations. +absl::Status ExtractExampleParserConfiguration( + const tensorflow::GraphDef& graph, const string& node_name, + tensorflow::Session* session, + std::vector* fixed_len_features, + std::vector* var_len_features); + +// Given a config proto, ostensibly extracted via python, +// fill a vector of C++ structs suitable for calling +// the tensorflow.Example -> Tensor conversion code. +absl::Status ExampleParserConfigurationProtoToFeatureVectors( + const ExampleParserConfiguration& config_proto, + std::vector* fixed_len_features, + std::vector* var_len_features); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/example/feature_util.h b/third_party/tflite-hdrs/tensorflow/core/example/feature_util.h new file mode 100644 index 00000000..092fabe6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/example/feature_util.h @@ -0,0 +1,644 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A set of lightweight wrappers which simplify access to Feature protos. +// +// TensorFlow Example proto uses associative maps on top of oneof fields. +// SequenceExample proto uses associative map of FeatureList. +// So accessing feature values is not very convenient. 
+// +// For example, to read a first value of integer feature "tag": +// int id = example.features().feature().at("tag").int64_list().value(0); +// +// to add a value: +// auto features = example->mutable_features(); +// (*features->mutable_feature())["tag"].mutable_int64_list()->add_value(id); +// +// For float features you have to use float_list, for string - bytes_list. +// +// To do the same with this library: +// int id = GetFeatureValues("tag", example).Get(0); +// GetFeatureValues("tag", &example)->Add(id); +// +// Modification of bytes features is slightly different: +// auto tag = GetFeatureValues("tag", &example); +// *tag->Add() = "lorem ipsum"; +// +// To copy multiple values into a feature: +// AppendFeatureValues({1,2,3}, "tag", &example); +// +// GetFeatureValues gives you access to underlying data - RepeatedField object +// (RepeatedPtrField for byte list). So refer to its documentation of +// RepeatedField for full list of supported methods. +// +// NOTE: Due to the nature of oneof proto fields setting a feature of one type +// automatically clears all values stored as another type with the same feature +// key. +// +// This library also has tools to work with SequenceExample protos. +// +// To get a value from SequenceExample.context: +// int id = GetFeatureValues("tag", se.context()).Get(0); +// To add a value to the context: +// GetFeatureValues("tag", se.mutable_context())->Add(42); +// +// To add values to feature_lists: +// AppendFeatureValues({4.0}, +// GetFeatureList("images", &se)->Add()); +// AppendFeatureValues({5.0, 3.0}, +// GetFeatureList("images", &se)->Add()); +// This will create a feature list keyed as "images" with two features: +// feature_lists { +// feature_list { +// key: "images" +// value { +// feature { float_list { value: [4.0] } } +// feature { float_list { value: [5.0, 3.0] } } +// } +// } +// } +// For string-valued features, note that the Append... and Set... functions +// support absl::string_view containers. This allows you to copy existing +// buffers into a Feature with only one copy: +// std::vector image; +// image.push_back(image_buffer); // No copy. +// SetFeatureValues(image, "image", &example); // Copy. +// +// Functions exposed by this library: +// HasFeature<[FeatureType]>(key, proto) -> bool +// Returns true if a feature with the specified key, and optionally +// FeatureType, belongs to the Features or Example proto. +// HasFeatureList(key, sequence_example) -> bool +// Returns true if SequenceExample has a feature_list with the key. +// +// GetFeatureValues(key, proto) -> RepeatedField +// Returns values for the specified key and the FeatureType. +// Supported types for the proto: Example, Features. +// GetFeatureList(key, sequence_example) -> RepeatedPtrField +// Returns Feature protos associated with a key. +// +// AppendFeatureValues(begin, end, feature) +// AppendFeatureValues(container or initializer_list, feature) +// Copies values into a Feature. +// AppendFeatureValues(begin, end, key, proto) +// AppendFeatureValues(container or initializer_list, key, proto) +// Copies values into Features and Example protos with the specified key. +// +// ClearFeatureValues(feature) +// Clears the feature's repeated field of the given type. +// +// SetFeatureValues(begin, end, feature) +// SetFeatureValues(container or initializer_list, feature) +// Clears a Feature, then copies values into it. 
+// SetFeatureValues(begin, end, key, proto) +// SetFeatureValues(container or initializer_list, key, proto) +// Clears Features or Example protos with the specified key, +// then copies values into them. +// +// Auxiliary functions, it is unlikely you'll need to use them directly: +// GetFeatures(proto) -> Features +// A convenience function to get Features proto. +// Supported types for the proto: Example, Features. +// GetFeature(key, proto) -> Feature +// Returns a Feature proto for the specified key. +// Supported types for the proto: Example, Features. +// GetFeatureValues(feature) -> RepeatedField +// Returns values of the feature for the FeatureType. + +#ifndef TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_ +#define TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/stringpiece.h" + +// Must come after the import for absl::string_view. +#ifdef ABSL_HAVE_STD_STRING_VIEW +#include +#endif + +namespace tensorflow { +namespace internal { + +// TODO(gorban): Update all clients in a followup CL. +// Returns a reference to a feature corresponding to the name. +// Note: it will create a new Feature if it is missing in the example. +ABSL_DEPRECATED("Use GetFeature instead.") +Feature& ExampleFeature(absl::string_view name, Example* example); + +// Specializations of RepeatedFieldTrait define a type of RepeatedField +// corresponding to a selected feature type. +template +struct RepeatedFieldTrait; + +template <> +struct RepeatedFieldTrait { + using Type = protobuf::RepeatedField; +}; + +template <> +struct RepeatedFieldTrait { + using Type = protobuf::RepeatedField; +}; + +template <> +struct RepeatedFieldTrait { + using Type = protobuf::RepeatedPtrField; +}; + +template <> +struct RepeatedFieldTrait { + using Type = protobuf::RepeatedPtrField; +}; + +// Specializations of FeatureTrait define a type of feature corresponding to a +// selected value type. +template +struct FeatureTrait; + +template +struct FeatureTrait::value>::type> { + using Type = protobuf_int64; +}; + +template +struct FeatureTrait< + ValueType, + typename std::enable_if::value>::type> { + using Type = float; +}; + +template +struct is_string + : public std::integral_constant< + bool, + std::is_same::type>::value || + std::is_same::type>::value> { +}; + +template <> +struct is_string : std::true_type {}; + +template <> +struct is_string : std::true_type {}; + +template <> +struct is_string : std::true_type {}; + +template +struct FeatureTrait< + ValueType, typename std::enable_if::value>::type> { + using Type = std::string; +}; + +// Port of the C++20 `requires` expressions. +template +constexpr bool Requires(F) { + return std::is_invocable::value; +} + +struct NoneSuch {}; + +// True if the Feature map in a tf.Example supports heterogenous lookup. +// See https://abseil.io/tips/144. +// TODO(b/365531379): this cannot be replaced by a lambda because it exposes a +// Clang bug when used in modules. +struct CheckFindFunctor { + template + auto operator()(Container&& c) -> decltype(c.find(NoneSuch{})) {} +}; +inline constexpr bool kFeatureMapHasHeterogeneousLookup = + Requires( + CheckFindFunctor()); + +// Converts an `absl::string_view` into a string-type compatible for use in the +// protobuf library (e.g. 
as lookup keys in `proto2::Map` or as elements addable +// to a `proto2::RepeatedPtrField`) depending on the BUILD mode. +// +// NOTE: While the newest versions of `proto2::Map` support heterogenous lookup, +// it does so through `std::string_view`. If the type is just an alias (as noted +// by `ABSL_USES_STD_STRING_VIEW`) then nothing more needs to be done; however, +// when the type is not an alias an explicit conversion to is necessary. +// +// NOTE: This conversion is only necessary until the migration for protobuf to +// take a dependency on ABSL is complete. +inline auto ProtoMapKey(absl::string_view str) { + if constexpr (kFeatureMapHasHeterogeneousLookup) { +#ifdef ABSL_USES_STD_STRING_VIEW + return str; +#else +#ifdef ABSL_HAVE_STD_STRING_VIEW + return std::string_view(str.data(), str.size()); +#else + return std::string(str); +#endif +#endif + } else { + return std::string(str); + } +} + +} // namespace internal + +// Returns true if sequence_example has a feature_list with the specified key. +bool HasFeatureList(absl::string_view key, + const SequenceExample& sequence_example); + +template +struct TypeHasFeatures : std::false_type {}; + +template <> +struct TypeHasFeatures : std::true_type {}; + +template <> +struct TypeHasFeatures : std::true_type {}; + +template <> +struct TypeHasFeatures : std::true_type {}; + +// A family of template functions to return mutable Features proto from a +// container proto. Supported ProtoTypes: SequenceExample, Example, Features. +template +typename std::enable_if::value, Features*>::type +GetFeatures(ProtoType* proto); + +template <> +Features* GetFeatures(Features* proto); +template <> +Features* GetFeatures(Example* proto); +template <> +Features* GetFeatures(SequenceExample* proto); + +template +typename std::enable_if::value, + const Features&>::type +GetFeatures(const ProtoType& proto); + +template <> +const Features& GetFeatures(const Features& proto); +template <> +const Features& GetFeatures(const Example& proto); +template <> +const Features& GetFeatures(const SequenceExample& proto); + +// Base declaration of a family of template functions to return a read only +// repeated field of feature values. +template +const typename internal::RepeatedFieldTrait::Type& +GetFeatureValues(const Feature& feature); + +template <> +const protobuf::RepeatedField& GetFeatureValues( + const Feature& feature); +template <> +const protobuf::RepeatedField& GetFeatureValues( + const Feature& feature); +template <> +const protobuf::RepeatedPtrField& GetFeatureValues( + const Feature& feature); +template <> +const protobuf::RepeatedPtrField& GetFeatureValues( + const Feature& feature); + +// Returns a read only repeated field corresponding to a feature with the +// specified name and FeatureType. Supported ProtoTypes: SequenceExample, +// Example, Features. +template +const typename internal::RepeatedFieldTrait::Type& +GetFeatureValues(absl::string_view key, const ProtoType& proto) { + return GetFeatureValues( + GetFeatures(proto).feature().at(internal::ProtoMapKey(key))); +} + +// Returns a mutable repeated field of a feature values. 
+template +typename internal::RepeatedFieldTrait::Type* GetFeatureValues( + Feature* feature); + +template <> +protobuf::RepeatedField* GetFeatureValues( + Feature* feature); +template <> +protobuf::RepeatedField* GetFeatureValues(Feature* feature); +template <> +protobuf::RepeatedPtrField* GetFeatureValues( + Feature* feature); +template <> +protobuf::RepeatedPtrField* GetFeatureValues( + Feature* feature); + +// Returns a mutable repeated field corresponding to a feature with the +// specified name and FeatureType. Supported ProtoTypes: SequenceExample, +// Example, Features. +template +typename internal::RepeatedFieldTrait::Type* GetFeatureValues( + absl::string_view key, ProtoType* proto) { + ::tensorflow::Feature& feature = + (*GetFeatures(proto)->mutable_feature())[internal::ProtoMapKey(key)]; + return GetFeatureValues(&feature); +} + +// Returns a read-only Feature proto for the specified key, throws +// std::out_of_range if the key is not found. Supported types for the proto: +// SequenceExample, Example, Features. +template +const Feature& GetFeature(absl::string_view key, const ProtoType& proto) { + return GetFeatures(proto).feature().at(internal::ProtoMapKey(key)); +} + +// Returns a read-only Feature proto for the specified key, returns nullptr +// if the key is not found. Supported types for the proto: SequenceExample, +// Example, Features. +template +const Feature* MaybeGetFeature(absl::string_view key, const ProtoType& proto) { + const protobuf::Map& feature_map = + GetFeatures(proto).feature(); + auto it = feature_map.find(internal::ProtoMapKey(key)); + + if (it == feature_map.end()) { + return nullptr; + } + + return &it->second; +} + +// Base declaration of a family of template functions to return a read only +// repeated field of feature values or nullptr. +template +const typename internal::RepeatedFieldTrait::Type* +MaybeGetFeatureValues(const Feature& feature); + +template <> +const protobuf::RepeatedField* +MaybeGetFeatureValues(const Feature& feature); +template <> +const protobuf::RepeatedField* MaybeGetFeatureValues( + const Feature& feature); +template <> +const protobuf::RepeatedPtrField* MaybeGetFeatureValues( + const Feature& feature); +template <> +const protobuf::RepeatedPtrField* +MaybeGetFeatureValues(const Feature& feature); + +// Returns a read only repeated field corresponding to a feature with the +// specified name and FeatureType. Supported ProtoTypes: SequenceExample, +// Example, Features. +template +const typename internal::RepeatedFieldTrait::Type* +MaybeGetFeatureValues(absl::string_view key, const ProtoType& proto) { + const Feature* feature = MaybeGetFeature(key, proto); + if (feature == nullptr) { + return nullptr; + } + return &GetFeatureValues(*feature); +} + +// Returns a mutable Feature proto for the specified key, creates a new if +// necessary. Supported types for the proto: SequenceExample, Example, Features. +template +Feature* GetFeature(absl::string_view key, ProtoType* proto) { + return &(*GetFeatures(proto)->mutable_feature())[internal::ProtoMapKey(key)]; +} + +// Returns a repeated field with features corresponding to a feature_list key. +const protobuf::RepeatedPtrField& GetFeatureList( + absl::string_view key, const SequenceExample& sequence_example); + +// Returns a mutable repeated field with features corresponding to a +// feature_list key. It will create a new FeatureList if necessary. 
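The angle-bracketed template arguments in the usage examples of the file comment above appear to have been dropped by the diff rendering. As a reminder of what a typical call site looks like, here is a small sketch; `FillExample` is a hypothetical function, and the explicit `tensorflow::protobuf_int64` argument restores what those examples presumably showed.

#include <string>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature_util.h"

namespace example {

void FillExample(tensorflow::Example* ex) {
  // Append three int64 values under the key "tag"; the value type selects
  // the int64_list branch of the Feature oneof.
  tensorflow::AppendFeatureValues({1, 2, 3}, "tag", ex);

  // String values go through the bytes_list branch.
  tensorflow::AppendFeatureValues({"lorem", "ipsum"}, "label", ex);

  // Read the first value back; GetFeatureValues returns the underlying
  // RepeatedField, so the usual protobuf accessors apply.
  tensorflow::protobuf_int64 first =
      tensorflow::GetFeatureValues<tensorflow::protobuf_int64>("tag", *ex)
          .Get(0);
  (void)first;
}

}  // namespace example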
+protobuf::RepeatedPtrField* GetFeatureList( + absl::string_view feature_list_key, SequenceExample* sequence_example); + +template +void AppendFeatureValues(IteratorType first, IteratorType last, + Feature* feature) { + using FeatureType = typename internal::FeatureTrait< + typename std::iterator_traits::value_type>::Type; + auto& values = *GetFeatureValues(feature); + values.Reserve(std::distance(first, last)); + for (auto it = first; it != last; ++it) { + *values.Add() = *it; + } +} + +template +void AppendFeatureValues(std::initializer_list container, + Feature* feature) { + using FeatureType = typename internal::FeatureTrait::Type; + auto& values = *GetFeatureValues(feature); + values.Reserve(container.size()); + for (auto& elt : container) { + *values.Add() = std::move(elt); + } +} + +namespace internal { + +// HasSize::value is true_type if T has a size() member. +template +struct HasSize : std::false_type {}; + +template +struct HasSize().size())>> + : std::true_type {}; + +// Reserves the container's size, if a container.size() method exists. +template +auto ReserveIfSizeAvailable(const ContainerType& container, + RepeatedFieldType& values) -> + typename std::enable_if_t::value, void> { + values.Reserve(container.size()); +} + +template +auto ReserveIfSizeAvailable(const ContainerType& container, + RepeatedFieldType& values) -> + typename std::enable_if_t::value, void> {} + +} // namespace internal + +template +void AppendFeatureValues(const ContainerType& container, Feature* feature) { + using IteratorType = typename ContainerType::const_iterator; + using FeatureType = typename internal::FeatureTrait< + typename std::iterator_traits::value_type>::Type; + auto* values = GetFeatureValues(feature); + internal::ReserveIfSizeAvailable(container, *values); + // This is equivalent to std::copy into `values` with a + // RepeatedFieldBackInserter, the difference is RFBI isn't compatible with + // types that we want to convert (e.g. absl::string_view -> std::string). + for (const auto& elt : container) { + if constexpr (internal::is_string::value) { + *values->Add() = std::string(elt); + } else { + *values->Add() = elt; + } + } +} + +// Copies elements from the range, defined by [first, last) into the feature +// obtainable from the (proto, key) combination. +template +void AppendFeatureValues(IteratorType first, IteratorType last, + absl::string_view key, ProtoType* proto) { + AppendFeatureValues(first, last, GetFeature(key, GetFeatures(proto))); +} + +// Copies all elements from the container into a feature. +template +void AppendFeatureValues(const ContainerType& container, absl::string_view key, + ProtoType* proto) { + AppendFeatureValues(container, + GetFeature(key, GetFeatures(proto))); +} + +// Copies all elements from the initializer list into a Feature contained by +// Features or Example proto. +template +void AppendFeatureValues(std::initializer_list container, + absl::string_view key, ProtoType* proto) { + AppendFeatureValues(container, + GetFeature(key, GetFeatures(proto))); +} + +// Clears the feature's repeated field (int64, float, or string). +template +void ClearFeatureValues(Feature* feature); + +template <> +void ClearFeatureValues(Feature* feature); +template <> +void ClearFeatureValues(Feature* feature); +template <> +void ClearFeatureValues(Feature* feature); +template <> +void ClearFeatureValues(Feature* feature); + +// Clears the feature's repeated field (int64, float, or string). 
Copies +// elements from the range, defined by [first, last) into the feature's repeated +// field. +template +void SetFeatureValues(IteratorType first, IteratorType last, Feature* feature) { + using FeatureType = typename internal::FeatureTrait< + typename std::iterator_traits::value_type>::Type; + ClearFeatureValues(feature); + AppendFeatureValues(first, last, feature); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the initializer list into the feature's repeated field. +template +void SetFeatureValues(std::initializer_list container, + Feature* feature) { + using FeatureType = typename internal::FeatureTrait::Type; + ClearFeatureValues(feature); + AppendFeatureValues(container, feature); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the container into the feature's repeated field. +template +void SetFeatureValues(const ContainerType& container, Feature* feature) { + using IteratorType = typename ContainerType::const_iterator; + using FeatureType = typename internal::FeatureTrait< + typename std::iterator_traits::value_type>::Type; + ClearFeatureValues(feature); + AppendFeatureValues(container, feature); +} + +// Clears the feature's repeated field (int64, float, or string). Copies +// elements from the range, defined by [first, last) into the feature's repeated +// field. +template +void SetFeatureValues(IteratorType first, IteratorType last, + absl::string_view key, ProtoType* proto) { + SetFeatureValues(first, last, GetFeature(key, GetFeatures(proto))); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the container into the feature's repeated field. +template +void SetFeatureValues(const ContainerType& container, absl::string_view key, + ProtoType* proto) { + SetFeatureValues(container, + GetFeature(key, GetFeatures(proto))); +} + +// Clears the feature's repeated field (int64, float, or string). Copies all +// elements from the initializer list into the feature's repeated field. +template +void SetFeatureValues(std::initializer_list container, + absl::string_view key, ProtoType* proto) { + SetFeatureValues(container, GetFeature(key, GetFeatures(proto))); +} + +// Returns true if a feature with the specified key belongs to the Features. +// The template parameter pack accepts zero or one template argument - which +// is FeatureType. If the FeatureType not specified (zero template arguments) +// the function will not check the feature type. Otherwise it will return false +// if the feature has a wrong type. +template +bool HasFeature(absl::string_view key, const Features& features); + +template <> +bool HasFeature<>(absl::string_view key, const Features& features); +template <> +bool HasFeature(absl::string_view key, + const Features& features); +template <> +bool HasFeature(absl::string_view key, const Features& features); +template <> +bool HasFeature(absl::string_view key, const Features& features); +template <> +bool HasFeature(absl::string_view key, const Features& features); + +// Returns true if a feature with the specified key belongs to the Example. +// Doesn't check feature type if used without FeatureType, otherwise the +// specialized versions return false if the feature has a wrong type. +template +bool HasFeature(absl::string_view key, const Example& example) { + return HasFeature(key, GetFeatures(example)); +} + +// Returns true if a feature with the specified key belongs to the +// SequenceExample. 
Doesn't check feature type if used without FeatureType, +// otherwise the specialized versions return false if the feature has a wrong +// type. +template +bool HasFeature(absl::string_view key, + const SequenceExample& sequence_example) { + return HasFeature(key, GetFeatures(sequence_example)); +} + +// TODO(gorban): update all clients in a followup CL. +template +ABSL_DEPRECATED("Use HasFeature instead.") +bool ExampleHasFeature(absl::string_view key, const Example& example) { + return HasFeature(key, example); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/allocator.h b/third_party/tflite-hdrs/tensorflow/core/framework/allocator.h new file mode 100644 index 00000000..dbf2c29f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/allocator.h @@ -0,0 +1,55 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_H_ + +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "xla/tsl/framework/allocator.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/numa.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::AllocationAttributes; +using tsl::Allocator; +using tsl::AllocatorAttributes; +using tsl::AllocatorMemoryType; +using tsl::AllocatorStats; +using tsl::AllocatorWrapper; +using tsl::cpu_allocator; +using tsl::cpu_allocator_base; +using tsl::CPUAllocatorFullStatsEnabled; +using tsl::CPUAllocatorStatsEnabled; +using tsl::DisableCPUAllocatorStats; +using tsl::EnableCPUAllocatorFullStats; +using tsl::EnableCPUAllocatorStats; +using tsl::SubAllocator; +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/allocator_registry.h b/third_party/tflite-hdrs/tensorflow/core/framework/allocator_registry.h new file mode 100644 index 00000000..7bc03241 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/allocator_registry.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Classes to maintain a static registry of memory allocator factories. +#ifndef TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_ + +#include +#include + +#include "xla/tsl/framework/allocator_registry.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/numa.h" + +namespace tensorflow { + +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::AllocatorFactory; +using tsl::AllocatorFactoryRegistration; +using tsl::AllocatorFactoryRegistry; +using tsl::ProcessStateInterface; +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_ALLOCATOR_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/attr_value_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/attr_value_util.h new file mode 100644 index 00000000..b6f7c972 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/attr_value_util.h @@ -0,0 +1,142 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +namespace attr_value_util_internal { +// Return the size of the tensor represented by this TensorProto. If shape is +// not fully defined return -1. +int64_t TensorByteSize(const TensorProto& t); +} // namespace attr_value_util_internal + +// Forward declare protos so their symbols can be removed from .so exports +class AttrValue; +class NameAttrList; + +// A human-readable rendering of attr_value, that is more concise than a +// text-format proto. +std::string SummarizeAttrValue(const AttrValue& attr_value); + +// Generates an error if attr_value doesn't have the indicated attr type. 
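A brief sketch of how the attr_value_util helpers declared below compose; `Demo` is a hypothetical function, and the `"int"` type string follows the header's own `ParseAttrValue` example.

#include <cstdint>
#include <iostream>

#include "absl/status/status.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"

namespace example {

void Demo() {
  tensorflow::AttrValue attr;
  // Store an int attribute; SetAttrValue picks the matching oneof field.
  tensorflow::SetAttrValue(static_cast<int64_t>(-14), &attr);

  // Human-readable rendering, more concise than text-format proto.
  std::cout << tensorflow::SummarizeAttrValue(attr) << std::endl;

  // Verify the stored value against an expected attr type string.
  absl::Status ok = tensorflow::AttrValueHasType(attr, "int");
  (void)ok;
}

}  // namespace example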
+absl::Status AttrValueHasType(const AttrValue& attr_value, + absl::string_view type); + +// Converts a text proto value from "text" into the field of *out +// indicated by "type" (e.g. from the type field of an AttrDef). +// Examples: +// * If type:"int" and text:"-14", then *out is set to "i: -14" +// * If type:"list(string)" and text:"['foo', 'bar']", +// then *out is set to "list { s: ['foo', 'bar'] }" +// Returns true on success. +bool ParseAttrValue(absl::string_view type, absl::string_view text, + AttrValue* out); + +// Sets *out based on the type of value. +void SetAttrValue(const std::string& value, AttrValue* out); +void SetAttrValue(const tstring& value, AttrValue* out); +void SetAttrValue(const char* value, AttrValue* out); +void SetAttrValue(absl::string_view value, AttrValue* out); +void SetAttrValue(int64_t value, AttrValue* out); +void SetAttrValue(int32_t value, AttrValue* out); +void SetAttrValue(float value, AttrValue* out); +void SetAttrValue(double value, AttrValue* out); +void SetAttrValue(bool value, AttrValue* out); +void SetAttrValue(DataType value, AttrValue* out); +void SetAttrValue(const TensorShape& value, AttrValue* out); +void SetAttrValue(const TensorShapeProto& value, AttrValue* out); +void SetAttrValue(const PartialTensorShape& value, AttrValue* out); +void SetAttrValue(const Tensor& value, AttrValue* out); +void SetAttrValue(const TensorProto& value, AttrValue* out); +void SetAttrValue(const NameAttrList& value, AttrValue* out); + +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(const std::vector& value, AttrValue* out); +void SetAttrValue(std::initializer_list value, AttrValue* out); +void SetAttrValue(DataTypeSlice value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); +void SetAttrValue(absl::Span value, AttrValue* out); + +void SetAttrValue(const AttrValue& value, AttrValue* out); + +void MoveAttrValue(std::vector&& value, AttrValue* out); + +// Returns a hash of `a` that is consistent with AreAttrValuesEqual. In other +// words, if two AttrValues compare equal according to AreAttrValuesEqual, +// they will have the same hash value. +// Similarly to protobuf deterministic serialization, hash value is +// guaranteed to be stable only for a given binary. In particular, one should +// probably not persist the returned value. +uint64 AttrValueHash(const AttrValue& a); + +// WARNING: Equality check might return false-negative for large (> 32mb) +// tensors defined with different TensorProto representations. +// +// A pair of consistent hash and equals functions that are guaranteed to be fast +// with AttrValues that potentially can have very large Tensors (larger than +// 32mb) defined by TensorProto. If large identical Tensors are defined using +// different representations (e.g. 
one with tensor content, and second with +// bool_val), they will have different hash code and equals will return false. +// Small (less than 32mb) tensors with different TensorProto representations +// hashed/compared by their tensor content. +uint64 FastAttrValueHash(const AttrValue& a); +// Returns true if a and b have the same value. If false negatives are allowed, +// then compares proto representation to avoid construction of large (> 32mb) +// tensors. +bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b, + bool allow_false_negatives = false); + +// Returns true if "val" has a placeholder. +bool HasPlaceHolder(const AttrValue& val); + +// SubstitutePlaceholders recursively replaces placeholders in 'value' +// with an attr value by calling SubstituteFunc. Returns true iff all +// placeholders in "value" are replaced with a value. +// +// SubstituteFunc is given a placeholder string. If the placeholder is +// unknown, SubstituteFunc returns false. Otherwise, overwrites the +// attr value and returns true. +using SubstituteFunc = std::function; +bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/bfloat16.h b/third_party/tflite-hdrs/tensorflow/core/framework/bfloat16.h new file mode 100644 index 00000000..4f13039d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/bfloat16.h @@ -0,0 +1,61 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_ +#define TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_ + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/platform/types.h" + +// Compact 16-bit encoding of floating point numbers. This representation uses +// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. It +// is assumed that floats are in IEEE 754 format so the representation is just +// bits 16-31 of a single precision float. +// +// NOTE: The IEEE floating point standard defines a float16 format that +// is different than this format (it has fewer bits of exponent and more +// bits of mantissa). We don't use that format here because conversion +// to/from 32-bit floats is more complex for that format, and the +// conversion for this format is very simple. +// +// Because of the existing IEEE float16 type, we do not name our representation +// "float16" but just use "uint16". +// +// <-----our 16bits float-------> +// s e e e e e e e e f f f f f f f f f f f f f f f f f f f f f f f +// <------------------------------float--------------------------> +// 3 3 2 2 1 1 0 +// 1 0 3 2 5 4 0 +// +// +// This type only supports conversion back and forth with float. +// +// This file must be compilable by nvcc. +// +// The type is defined in framework/numeric_types.h. 
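As an illustration of the bit layout described above (and only that; the library's declarations follow below), a bfloat16 value is simply the upper 16 bits of an IEEE-754 binary32 float. The stand-alone helpers in this sketch are hypothetical names showing the truncating conversion; the header's `RoundFloatToBFloat16` additionally rounds to nearest even rather than truncating.

#include <cstdint>
#include <cstring>

// Truncating float -> bfloat16: keep bits 16..31 of the binary32 encoding.
static inline uint16_t TruncateToBFloat16Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));  // type-pun safely via memcpy
  return static_cast<uint16_t>(bits >> 16);
}

// Lossless bfloat16 -> float: re-extend with zeroed low mantissa bits.
static inline float BFloat16BitsToFloat(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}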
+ +namespace tensorflow { + +// Convert from float to bfloat16 with rounding-to-nearest-even. +void RoundFloatToBFloat16(const float* src, bfloat16* dst, int64_t size); +// Convert from float to bfloat16 with truncation. Notice this conversion is +// lossy since it truncates the float to 7 mantissa bits without rounding. +void FloatToBFloat16(const float* src, bfloat16* dst, int64_t size); +// Convert from bfloat16 to float. This conversion is lossless. +void BFloat16ToFloat(const bfloat16* src, float* dst, int64_t size); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/bounds_check.h b/third_party/tflite-hdrs/tensorflow/core/framework/bounds_check.h new file mode 100644 index 00000000..76e6e6dd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/bounds_check.h @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_BOUNDS_CHECK_H_ +#define TENSORFLOW_CORE_FRAMEWORK_BOUNDS_CHECK_H_ + +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// Check that 0 <= index < limit using a single comparison, assuming +// that 0 <= limit if Index is signed. Intended for use in performance +// critical contexts where 0 <= index < limit is almost always true. +template +EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool FastBoundsCheck(const Ta index, + const Tb limit) { + static_assert(std::is_integral::value && std::is_integral::value, + "FastBoundsCheck can only be used on integer types."); + typedef typename std::make_unsigned::type UIndex; + return TF_PREDICT_TRUE(static_cast(index) < + static_cast(limit)); +} + +namespace internal { +// Ensure that the compiler cannot elide a copy into a local, for +// bounds checking on source tensors that might be updated asynchronously. +// This function may only be used on primitive integral types (int32, int64, +// etc). It does not guarantee any atomicity or barriers. +template +EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC const T SubtleMustCopy(const T &x) { + static_assert(std::is_integral::value, + "SubtleMustCopy can only be used on integer types."); + auto *to_x = reinterpret_cast(&x); + return *to_x; +} +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_BOUNDS_CHECK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/cancellation.h b/third_party/tflite-hdrs/tensorflow/core/framework/cancellation.h new file mode 100644 index 00000000..522de22c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/cancellation.h @@ -0,0 +1,39 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_ +#define TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_ + +#include "xla/tsl/framework/cancellation.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::CancelCallback; +using tsl::CancellationManager; +using tsl::CancellationToken; +using tsl::RegisterCancellationCallback; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/collective.h b/third_party/tflite-hdrs/tensorflow/core/framework/collective.h new file mode 100644 index 00000000..8fca00f0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/collective.h @@ -0,0 +1,522 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/intrusive_ptr.h" + +namespace tensorflow { + +class BufRendezvous; +class CompleteGroupRequest; +class CompleteGroupResponse; +class CompleteInstanceRequest; +class CompleteInstanceResponse; +class Device; +class DeviceMgr; +class GetStepSequenceRequest; +class GetStepSequenceResponse; +class NcclManager; +class Tensor; + +// Types of supported collective operations. +enum CollectiveType { + REDUCTION_COLLECTIVE = 0, + BROADCAST_COLLECTIVE, + GATHER_COLLECTIVE, + PERMUTE_COLLECTIVE, + ALL_TO_ALL_COLLECTIVE, + REDUCE_SCATTER_COLLECTIVE, + UNDEFINED_COLLECTIVE, +}; + +// Some collective op implementations require runtime group configuration from +// the OpKernel. 
Currently, this struct is used to set communicator key for +// NCCL-based collective implementation. +struct CollGroupRuntimeDetails { + string communicator_key; // for communicator-based techniques e.g. NCCL + string ToString() const; +}; + +struct CollGroupMember { + DeviceAttributes device; + string task; + bool is_local; + // User provided rank + int32 rank = -1; +}; + +// Data common to all members of a device group. +// All members share the same device set but its order is +// particular to an instance so it is stored there. +struct CollGroupParams { + // Inputs from Collective ops: + int32 group_key; + int32 group_size; + DeviceType device_type; + int user_specified_rank = -1; // rank provided by the user. + // Generated from Collective Group Resolver: + // Members in this group, in default rank order. + std::vector members; + // True if every task has the same number of devices. + bool same_num_devices_per_task = false; + // Task -> number of devices on that task. + std::unordered_map num_devices_per_task; + int32 num_tasks; // number of distinct tasks in group + CollGroupRuntimeDetails runtime_details; + string ToString() const; + CollGroupParams() + : group_key(0), group_size(0), device_type(DEVICE_CPU), num_tasks(0) {} +}; + +// The best implementation of a collective op depends on many factors +// including the number of devices involved, the topology of +// interconnects between them and the sizes of inputs. This structure +// is used in generating and representing data movement choreography +// for each specific algorithm, hence it does not have a single, fixed +// interpretation. On first execution the runtime will update this +// structure with decisions that will guide all subsequent executions. +struct CollImplDetails { + string collective_name; + std::vector> subdiv_permutations; + // subdiv_offsets and max_subdivs_per_device are used together as follows: + // When subdiv_offsets is provided (non-empty) it is used as is. When + // subdiv_offsets is not provided subdivisons are generated dynamically + // constrained by max_subdivs_per_device. When subdiv_offsets is empty AND + // max_subdivs_per_device = 0 an internal default kMaxSubdivsPerDeviceDefault + // is used. When max_subdivs_per_device = -1, no subivision is done. + int max_subdivs_per_device = -1; // Upper bound on subdivisions per device. + std::vector subdiv_offsets; + std::vector subdiv_source_rank; // rank of source in each subdiv + std::vector + dependencies; // collective instances on which this node depends + string communication_hint; // user-supplied hint for implementation choice, + // e.g. ring or nccl + float timeout_seconds; // If non zero, set a completion timeout for the + // collective op to detect staleness. +}; + +// Data common to all members of a collective instance. +// TODO(b/163171014) Refactor this struct to not be a union of all fields. +struct CollInstanceParams { + // Identifies all participating graph nodes. + int32 instance_key = -1; + // The full identifier includes both instance_key and step_id. + int64_t step_id = 0; + CollectiveType type = UNDEFINED_COLLECTIVE; + DataType data_type = DT_FLOAT; + TensorShape shape = {0}; + CollImplDetails impl_details; + string ToString() const; + CollInstanceParams& operator=(const struct CollInstanceParams& other); + std::vector devices; // permuter only + + // For permuter only + // Each rank in the permutation is a receiver. + // Indices of each rank means a sender to that rank. 
+ // Example: permutation = {2,0,1} means + // rank 0 sends to rank 2 + // rank 1 sends to rank 0 + // rank 2 sends to rank 1 + std::vector permutation; +}; + +// Unique to a single CollectiveOp node. +struct CollectiveParams : public core::RefCounted { + CollGroupParams group; + CollInstanceParams instance; + + string name = ""; // node name used only for log or error messages + int default_rank = -1; // index of this op within device_names + bool is_source = false; // broadcast only + int source_rank = -1; // broadcast only + // Rank of this device in each subdivision permutation. + std::vector subdiv_rank; + OpKernel* merge_op = nullptr; // reduction only + OpKernel* final_op = nullptr; // reduction only + string ToString() const; + bool run_group_initialization = true; + bool is_stateless = false; +}; + +class CollectiveExecutor; + +// Interface that provides resolution of device localities. +class DeviceResolverInterface { + public: + virtual ~DeviceResolverInterface() {} + + // Populates *attributes with the DeviceAttributes of the specified device. + virtual absl::Status GetDeviceAttributes(const string& device, + DeviceAttributes* attributes) = 0; + + // Returns all device attributes of a task. + virtual absl::Status GetAllDeviceAttributes( + const string& task, std::vector* attributes) = 0; + + // Updates device attributes. It returns error if any device already + // exists in the DeviceResolver and has a different incarnation. + virtual absl::Status UpdateDeviceAttributes( + const std::vector& attributes) = 0; +}; + +// Interface that provides resolution of shared CollectiveParams fields. +class ParamResolverInterface { + public: + virtual ~ParamResolverInterface() {} + + // Called by each collective op at first execution in order to fill out + // the CollectiveParams structure with data gathered from the full + // (maybe distributed) collection of peer nodes. + virtual void CompleteParamsAsync(const DeviceAttributes& device, + CollectiveParams* cp, + CancellationManager* cancel_mgr, + const StatusCallback& done) = 0; + + // Completes group_params with data gathered from all devices in the group. + // This blocks until all devices are there. + virtual void CompleteGroupAsync(const DeviceAttributes& device, + CollGroupParams* group_params, + CancellationManager* cancel_mgr, + const StatusCallback& done) = 0; + + // Used within a distributed implementation to discover/verify data + // shared across an instance group. + // Note: this works differently from CompleteGroupAsync as a refactor is in + // progress. + virtual void CompleteInstanceAsync(const CompleteInstanceRequest* request, + CompleteInstanceResponse* response, + CancellationManager* cancel_mgr, + const StatusCallback& done) = 0; + + // Looks up a group. It returns an error if the group is not ready or not + // found. + virtual absl::Status LookupGroup(int32_t group_key, + CollGroupParams* group) = 0; + + // Aborts the resolver. After abortion the resolver can no longer be used. + virtual void StartAbort(const absl::Status& s) = 0; +}; + +// Graphs which utilize Collective Ops in a common instance must +// execute with identical step_ids even if they are disjoint graphs +// run by otherwise independent tasks. This interface supplies +// coordinated step_ids to use in such cases. +class StepSequenceInterface { + public: + virtual ~StepSequenceInterface() {} + + // Used with a distributed implementation to coordinate step_id + // sequences across tasks. 
+ virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + const StatusCallback& done) = 0; + + // Refresh the local per-graph_key step_id sequence from collective + // group leader, if applicable. + virtual void RefreshStepIdSequenceAsync(int64_t graph_key, + const StatusCallback& done) = 0; + + // Returns the step_id that should be used for initiating a new execution + // on the specified graph. May return the same step_id multiple times if + // RetireStepId or RefreshStepIdReservation is not called. + virtual int64_t NextStepId(int64_t graph_key) = 0; + + // Reports that execution of the given step has completed successfully. + // Should be called immediately after a step completes with OK status, + // prior to calling NextStepId(). If the step fails, don't call. + virtual void RetireStepId(int64_t graph_key, int64_t step_id) = 0; +}; + +class NcclCommunicatorInterface; + +// Interface that provides access to per-step CollectiveExecutor +// instances and various distributed resolution capabilities. +class CollectiveExecutorMgrInterface : public StepSequenceInterface { + public: + ~CollectiveExecutorMgrInterface() override {} + + // Returns the step-specific CollectiveExecutor, creating if one does not + // already exist. The caller assumes ownership of one Ref on the object. + virtual CollectiveExecutor* FindOrCreate(int64_t step_id) = 0; + + // If there is a CollectiveExecutor for step_id, remove it from the + // table. + virtual void Cleanup(int64_t step_id) = 0; + + // Cleanup the entire table, removing all entries for step_ids. + virtual void CleanupAll() = 0; + + virtual ParamResolverInterface* GetParamResolver() const = 0; + + virtual DeviceResolverInterface* GetDeviceResolver() const = 0; + + virtual NcclCommunicatorInterface* GetNcclCommunicator() const = 0; +}; + +// Interface that a Collective Op implementation uses to exchange data +// with peers. Note that data exchange is currently limited to types +// for which DMAHelper::CanUseDMA() returns true, i.e. dense numeric +// types. +class CollectiveRemoteAccess { + public: + virtual ~CollectiveRemoteAccess() {} + + virtual void RecvFromPeer(const string& peer_device, const string& peer_task, + bool peer_is_local, const string& key, + Device* to_device, DeviceContext* to_device_ctx, + const AllocatorAttributes& to_alloc_attr, + Tensor* to_tensor, + const DeviceLocality& client_locality, + int dev_to_dev_stream_index, + CancellationManager* cancellation_manager, + const StatusCallback& done) = 0; + + virtual void PostToPeer(const string& peer_device, const string& peer_task, + const string& key, Device* from_device, + DeviceContext* from_device_ctx, + const AllocatorAttributes& from_alloc_attr, + const Tensor* from_tensor, + const DeviceLocality& client_locality, + CancellationManager* cancellation_manager, + const StatusCallback& done) = 0; + + // Checks the health of a collective peer. It probes the peer to see if it is + // alive. Note that if a peer has restarted, it's considered a different one, + // so CheckPeerHealth fails. + virtual void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms, + const StatusCallback& done) = 0; + + virtual BufRendezvous* buf_rendezvous() = 0; + + virtual void StartAbort(const absl::Status& s) = 0; +}; + +// A step-specific object that can execute a collective operation completely +// described by a CollectiveParams object. 
+class CollectiveExecutor : public core::RefCounted { + public: + virtual void StartAbort(const absl::Status& s) {} + + virtual void ExecuteAsync(OpKernelContext* ctx, + const CollectiveParams* col_params, + const string& exec_key, StatusCallback done) { + done(errors::Internal( + "A collective Op has been called in a context in which " + "a CollectiveExecutor has not been provided.")); + } + + virtual void CompleteParamsAsync(const DeviceAttributes& device, + CollectiveParams* cp, + CancellationManager* cancel_mgr, + StatusCallback done) { + done(errors::Internal( + "A collective Op has been called in a context in which " + "a CollectiveExecutor has not been provided.")); + } + + virtual void CompleteGroupAsync(const DeviceAttributes& device, + CollGroupParams* group_params, + CancellationManager* cancel_mgr, + StatusCallback done) { + return cem_->GetParamResolver()->CompleteGroupAsync(device, group_params, + cancel_mgr, done); + } + + virtual absl::Status LookupGroup(int32_t group_key, CollGroupParams* group) { + return cem_->GetParamResolver()->LookupGroup(group_key, group); + } + + // Runs the potentially-blocking closure/expensive callback. + virtual void RunClosure(std::function closure) = 0; + + virtual CollectiveRemoteAccess* remote_access() { return nullptr; } + + // `WaitForDependencies` and `Launched` are used for fine-grained control of + // execution order between collective instances. These functions are intended + // to be called in `Run` function of collective implementations, and may be + // used to make part, or whole, of the collective execution ordered with + // respect to other collective instances. + // + // `WaitForDependencies` will block until it is safe to continue the callee's + // execution, where safety is defined as: ordered with respect to the + // collective instances defined in the callee's `wait_for` attribute. + virtual void WaitForDependencies(const CollectiveParams& col_params) {} + // `UnblockDependencies` unblocks the dependent collective instances by + // recording that this caller's device has completed the critical portion of + // the collective execution. + virtual void UnblockDependencies(const CollectiveParams& col_params) {} + + // Used to designate an invalid group or instance key. + static int64_t kInvalidId; + + // Lexically scoped handle for Ref. 
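+  // As a usage sketch (the `cem` pointer below is illustrative, not part of
+  // this header): an executor returned by
+  // CollectiveExecutorMgrInterface::FindOrCreate() already carries one Ref,
+  // so wrapping it in a Handle with inherit_ref=true releases that Ref when
+  // the Handle leaves scope:
+  //
+  //   CollectiveExecutor::Handle handle(cem->FindOrCreate(step_id),
+  //                                     /*inherit_ref=*/true);
+  //   handle.get()->ExecuteAsync(...);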
+ class Handle { + public: + explicit Handle(CollectiveExecutor* ce, bool inherit_ref) : ce_(ce) { + if (!inherit_ref) ce->Ref(); + } + ~Handle() { ce_->Unref(); } + CollectiveExecutor* get() const { return ce_; } + + private: + CollectiveExecutor* ce_; + }; + + protected: + explicit CollectiveExecutor(CollectiveExecutorMgrInterface* cem) + : cem_(cem) {} + + // For use only by derived classes + static OpKernelContext::Params* CtxParams(OpKernelContext* ctx); + CollectiveExecutorMgrInterface* cem_; + + CollectiveExecutor(const CollectiveExecutor&) = delete; + void operator=(const CollectiveExecutor&) = delete; +}; + +struct CollectiveContext { + CollectiveExecutor* col_exec; // Not owned + NcclCommunicatorInterface* nccl_communicator; // Not owned + const DeviceMgr* dev_mgr; // Not owned + OpKernelContext* op_ctx; // Not owned + OpKernelContext::Params* op_params; // Not owned + core::IntrusivePtr col_params; + const string exec_key; + const int64_t step_id; + const Tensor* input; // Not owned + Tensor* output; // Not owned + Device* device; // The device for which this instance labors + const string device_name; + DeviceLocality device_locality; + + CollectiveContext(CollectiveExecutor* col_exec, + NcclCommunicatorInterface* nccl_communicator, + const DeviceMgr* dev_mgr, OpKernelContext* ctx, + OpKernelContext::Params* op_params, + const CollectiveParams* col_params, const string& exec_key, + int64_t step_id, const Tensor* input, Tensor* output); +}; + +class NcclCommunicatorInterface { + public: + virtual ~NcclCommunicatorInterface() = default; + + virtual string GenerateCommunicatorKey() = 0; + + virtual void Enqueue(std::shared_ptr col_ctx, + StatusCallback done) = 0; + + virtual void StartAbort(const absl::Status& s) = 0; +}; + +// Interface of a Collective Op implementation. Each specific CollectiveOp will +// implement this interface and register the implementation via the +// CollectiveRegistry detailed below. See common_runtime/ring_reducer and +// common_runtime/hierarchical_tree_broadcaster for examples. +class CollectiveImplementationInterface : public core::RefCounted { + public: + ~CollectiveImplementationInterface() override = default; + + // Initializes the portions of `col_params` specific to this + // implementation. Called exactly once for every Collective instance during + // the CollectiveParams resolution process when the graph is first executed, + // at the end of `CompleteInstanceLocal()`. + // NOTE(ayushd): This is effectively a static function because it modifies the + // `col_params` passed in and should not manipulate any data members. However + // because it is virtual and needs to be implemented by every derived class we + // do not mark it as static. + virtual absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) = 0; + + // Prepares the CollectiveContext for executing this CollectiveImplementation. + // Called from CollectiveExecutor right before calling Run(). The + // CollectiveContext passed in must outlive the CollectiveImplementation + // object. + virtual absl::Status InitializeCollectiveContext( + std::shared_ptr col_ctx) = 0; + + // Processes and moves data according to the logic of this Collective + // implementation. Relies on appropriate initialization of op-specific + // CollectiveParams in InitializeCollectiveParams(), as well as appropriate + // context initialization in InitializeCollectiveContext(). 
+ virtual void Run(StatusCallback done) = 0; +}; + +// Static-methods only class for registering and looking up collective +// implementations. +class CollectiveRegistry { + public: + using Factory = std::function; + // Looks up a previously registered CollectiveImplementation under + // `collective_name`. If found, creates an instance of the implementation and + // assign to `implementation`. + static absl::Status Lookup( + const string& collective_name, + CollectiveImplementationInterface** implementation); + + // Looks up a previously registered CollectiveImplementation under + // `collective_name`. If found, returns the static instance of this + // implementation via `implementation`. This instance should only be used to + // call InitializateCollectiveParams. + static absl::Status LookupParamResolverInstance( + const string& collective_name, + CollectiveImplementationInterface** implementation); + + // Returns all registered collective implementations. + static void GetAll( + std::vector* implementations); + + private: + friend class CollectiveRegistration; + // Registers a CollectiveImplementation with name `collective_name` and + // factory `factory`. The latter is a function used to create instances of + // the CollectiveImplementation. Also creates a static instance of the + // implementation - this instance is used during param resolution and should + // only be used to call InitializeCollectiveParams. + static absl::Status Register(const string& collective_name, Factory factory); + + static absl::Status LookupHelper( + const string& collective_name, + CollectiveImplementationInterface** implementation, bool param_resolver); +}; + +// Class used to call CollectiveRegistry::Register. This should only be used to +// create a global static object. +class CollectiveRegistration { + public: + CollectiveRegistration(const string& collective_name, + CollectiveRegistry::Factory factory) { + TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory)); + } +}; + +#define REGISTER_COLLECTIVE(name, implementation) \ + static CollectiveRegistration register_##name##_collective( \ + #name, []() { return new implementation; }); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/common_shape_fns.h b/third_party/tflite-hdrs/tensorflow/core/framework/common_shape_fns.h new file mode 100644 index 00000000..1be1633f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/common_shape_fns.h @@ -0,0 +1,313 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_ + +#include + +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace shape_inference { + +// Like GetWindowedOutputSize, but deals with DimensionHandles. Does not support +// EXPLICIT padding. +absl::Status GetWindowedOutputSizeFromDims(InferenceContext* c, + DimensionHandle input_size, + DimensionOrConstant filter_size, + int64_t stride, Padding padding_type, + DimensionHandle* output_size); + +// The V2 version computes the same outputs with arbitrary dilation_rate, and +// supports EXPLICIT padding. For detailed equations, refer to the comments +// for GetWindowedOutputSize(). The 'padding_before' and 'padding_after' +// parameters are only used if padding_type == EXPLICIT. +absl::Status GetWindowedOutputSizeFromDimsV2( + InferenceContext* c, DimensionHandle input_size, + DimensionOrConstant filter_size, int64_t dilation_rate, int64_t stride, + Padding padding_type, int64_t padding_before, int64_t padding_after, + DimensionHandle* output_size); + +// Transfers shape of input(0) to output(0). +absl::Status UnchangedShape(shape_inference::InferenceContext* c); + +// Transfers shape of input(0) to output(0), after asserting its rank is . +inline absl::Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, + int32_t rank) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out)); + c->set_output(0, out); + return absl::OkStatus(); +} + +// Transfers shape of input(0) to output(0), after asserting its rank >= . +inline absl::Status UnchangedShapeWithRankAtLeast( + shape_inference::InferenceContext* c, int32_t rank) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); + c->set_output(0, out); + return absl::OkStatus(); +} + +// Transfers shape of input(0) to output(0), after asserting its rank <= . +inline absl::Status UnchangedShapeWithRankAtMost( + shape_inference::InferenceContext* c, int32_t rank) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out)); + c->set_output(0, out); + return absl::OkStatus(); +} + +// Shape function for use with ops no outputs. +inline absl::Status NoOutputs(shape_inference::InferenceContext* c) { + return absl::OkStatus(); +} + +// Shape function for ops that output a single scalar value. +inline absl::Status ScalarShape(shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return absl::OkStatus(); +} + +// Shape function for binary ops where both inputs and the output match. +inline absl::Status MergeBothInputsShapeFn(InferenceContext* c) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out)); + c->set_output(0, out); + return absl::OkStatus(); +} + +// Shape function for dataset iterators. +absl::Status DatasetIteratorShape(shape_inference::InferenceContext* c); + +// Returns a new shape with the specified dims arranged in the specified +// format. The returned value is owned by this context. +// Note: if format = "FORMAT_NCHW_VECT_C" then C represents the outer_depth. 
+absl::Status MakeShapeFromFormat( + TensorFormat format, DimensionOrConstant N, + const std::vector& spatial, DimensionOrConstant C, + ShapeHandle* out, shape_inference::InferenceContext* context); + +// Shape function for MatMul-like operations. +absl::Status MatMulShape(shape_inference::InferenceContext* c); + +// Shape function for Batched MatMul-like operations with broadcasting across +// batch dimensions. +absl::Status BatchMatMulV2Shape(shape_inference::InferenceContext* c); + +// Shape function for BatchMatMul-like operations +absl::Status BatchMatMulShape(shape_inference::InferenceContext* c); + +// Shape function for Einsum. +absl::Status EinsumShape(shape_inference::InferenceContext* c); + +// Shape function for BiasAdd-like operations. +absl::Status BiasAddShape(shape_inference::InferenceContext* c); + +// Shape function for BiasAddGrad-like operations. +absl::Status BiasAddGradShape(shape_inference::InferenceContext* c); + +// Shape function for general Convolution operation +absl::Status ConvShape(shape_inference::InferenceContext* c); + +// Shape function for Conv2D-like operations that support explicit padding. +absl::Status Conv2DShapeWithExplicitPadding( + shape_inference::InferenceContext* c); + +// Shape function for Conv2D-like operations that do not support explicit +// padding. +absl::Status Conv2DShape(shape_inference::InferenceContext* c); + +// Shape function for Conv3D-like operations. +absl::Status Conv3DShape(shape_inference::InferenceContext* c); + +// Shape function for DepthwiseConv2D-like operations that support explicit +// padding. +absl::Status DepthwiseConv2DNativeShapeWithExplicitPadding( + shape_inference::InferenceContext* c); + +// Shape function for DepthwiseConv2D-like operations that do not support +// explicit padding. +absl::Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); + +// Shape function for Conv2DBackpropInput. +absl::Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c); + +// Shape function for Conv2DBackpropFilterWithBias. +absl::Status Conv2DBackpropFilterWithBiasShape( + shape_inference::InferenceContext* c); + +// Shape function for AvgPool-like operations. +absl::Status AvgPoolShape(shape_inference::InferenceContext* c); + +// Shape function for AvgPoolGrad-like operations. +absl::Status AvgPoolGradShape(shape_inference::InferenceContext* c); + +// Shape function for FusedBatchNorm and FusedBatchNormV2 operations. +absl::Status FusedBatchNormShape(shape_inference::InferenceContext* c); + +// Shape function for FusedBatchNormV3 operations. +absl::Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c); + +// Shape function for _FusedBatchNormEx operations. +absl::Status FusedBatchNormExShape(shape_inference::InferenceContext* c); + +// Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations. +absl::Status FusedBatchNormGradShape(shape_inference::InferenceContext* c); + +// Shape function for _FusedBatchNormGradEx operations. +absl::Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c); + +// Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations. +absl::Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c); + +// Shape function for MatrixDiagV2 and MatrixDiagV3 operations. +absl::Status MatrixDiagV2Shape(shape_inference::InferenceContext* c); + +// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations. 
+absl::Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c); + +// Shape function for MaxPool-like operations that support explicit padding. +absl::Status MaxPoolShapeWithExplicitPadding( + shape_inference::InferenceContext* c); + +// Shape function for MaxPool-like operations that do not support explicit +// padding. +absl::Status MaxPoolShape(shape_inference::InferenceContext* c); + +// Shape function for MaxPoolV2-like operations. +absl::Status MaxPoolV2Shape(shape_inference::InferenceContext* c, + int num_inputs); + +// Shape function for MaxPoolGrad-like operations. +absl::Status MaxPoolGradShape(shape_inference::InferenceContext* c); + +// Shape function for 3D Pooling operations. +absl::Status Pool3DShape(shape_inference::InferenceContext* c); + +// Shape function for MaxPool3DGrad-like operations. +absl::Status MaxPool3DGradShape(shape_inference::InferenceContext* c); + +// Shape function for AvgPool3DGrad-like operations. +absl::Status AvgPool3DGradShape(shape_inference::InferenceContext* c); + +// Shape function for use with ops whose output shapes are unknown. +absl::Status UnknownShape(shape_inference::InferenceContext* c); + +// Shape function for reduction operations. +absl::Status ReductionShape(shape_inference::InferenceContext* c); + +// Shape function for unsorted segment operations. +absl::Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c); + +// Shape function for concat operations. +// is the number of inputs to concatenate and are taken +// from inputs +// [1,num_inputs_to_concat] of the op. Input 0 is the concat_dim input. +absl::Status ConcatShape(shape_inference::InferenceContext* c, + int num_inputs_to_concat); + +// Shape function for concat operations. +absl::Status ConcatV2Shape(shape_inference::InferenceContext* c); + +absl::Status QuantizedConcatV2Shape(InferenceContext* c, + int num_inputs_to_concat); + +// Shape function for binary operators that broadcast their inputs +// and with output to output_index. +// Note: out cannot be NULL. +absl::Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, + ShapeHandle shape_x, + ShapeHandle shape_y, + bool incompatible_shape_error, + ShapeHandle* out); + +// Shape function for binary operators that broadcast their inputs +// and with output to output_index. +inline absl::Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, + int output_index) { + ShapeHandle out; + TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( + c, c->input(0), c->input(1), true, &out)); + c->set_output(output_index, out); + return absl::OkStatus(); +} + +// Shape function for binary operators that broadcast their inputs. +// Tested by ops/math_ops_test.cc. +inline absl::Status BroadcastBinaryOpShapeFn(InferenceContext* c) { + return BroadcastBinaryOpOutputShapeFn(c, 0); +} + +// Shape function for random operations. +absl::Status RandomShape(shape_inference::InferenceContext* c); + +// Shape function for Slice operations. +absl::Status SliceShape(shape_inference::InferenceContext* c); + +// Validates the 3 component tensors of a sparse tensor have the proper +// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. +absl::Status ValidateSparseTensor(InferenceContext* c, + ShapeHandle indices_shape, + ShapeHandle values_shape, + ShapeHandle shape_shape); + +absl::Status ValidateVariableResourceHandle( + InferenceContext* c, std::vector* shape_and_type); + +// Shape function for GatherNd operations. 
+absl::Status GatherNdShape(InferenceContext* c); + +// Helper shape function for ScatterNd.../TensorScatter... operations. +absl::Status ScatterNdShapeHelper(InferenceContext* c, + ShapeHandle indices_shape, + ShapeHandle updates_shape, + ShapeHandle input_shape); + +// Shape function for ops with an explicit "shape" attribute. +absl::Status ExplicitShape(InferenceContext* c); + +// Shape function for multiple-output ops with an explicit "shapes" attribute. +absl::Status ExplicitShapes(InferenceContext* c); + +// Shape function for SparseReduceMax and SparseReduceSum. +absl::Status SparseReduceShapeFn(InferenceContext* c); + +// Shape function for QuantizedConv2D op. +absl::Status QuantizedConv2DShape(InferenceContext* c); + +// Shape function for _QuantizedConv2D op/fusion. +absl::Status FusedQuantizedConv2DShape(InferenceContext* c); + +// Shape function for _QuantizedDepthwiseConv2D op/fusion. +absl::Status FusedQuantizedDepthwiseConv2D(InferenceContext* c); + +// Shape function for QuantizedAvgPool op +absl::Status QuantizedAvgPoolShape(InferenceContext* c); + +// Shape function for QuantizeV2 op +absl::Status QuantizeV2Shape(InferenceContext* c); + +// Shape function for ReduceScatter ops +absl::Status ReduceScatterShape(shape_inference::InferenceContext* c); + +} // namespace shape_inference + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/control_flow.h b/third_party/tflite-hdrs/tensorflow/core/framework/control_flow.h new file mode 100644 index 00000000..3cc270b3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/control_flow.h @@ -0,0 +1,58 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_ +#define TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_ + +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +const uint64 kIllegalFrameId = ~0uLL; +const int64_t kIllegalIterId = -1; + +// For the purpose of control flow, every tensor produced by TensorFlow is +// conceptually tagged by a 'FrameAndIter'. FrameAndIter consists of a +// 'frame_id' and an 'iter_id'. The tensor value it represents is produced +// in the frame with frame_id at the iteration of iter_id. 
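+//
+// For illustration: a tensor produced in frame 3 at iteration 7 would be
+// tagged FrameAndIter(3, 7), while a default-constructed FrameAndIter
+// carries the sentinel values kIllegalFrameId and kIllegalIterId.
+// FrameAndIterHash below allows FrameAndIter to be used as a key in
+// hash-based containers.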
+struct FrameAndIter { + uint64 frame_id = kIllegalFrameId; + int64_t iter_id = kIllegalIterId; + + FrameAndIter() {} + + FrameAndIter(uint64 frame, int64_t iter) { + frame_id = frame; + iter_id = iter; + } + + bool operator==(const FrameAndIter& other) const { + return (frame_id == other.frame_id && iter_id == other.iter_id); + } +}; + +struct FrameAndIterHash { + size_t operator()(const FrameAndIter& key) const { + // Make sure there are no padding bytes that we don't want + CHECK_EQ(sizeof(uint64) + sizeof(int64_t), sizeof(FrameAndIter)); + return Hash64(reinterpret_cast(&key), sizeof(FrameAndIter)); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/dataset.h b/third_party/tflite-hdrs/tensorflow/core/framework/dataset.h new file mode 100644 index 00000000..70ebc12a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/dataset.h @@ -0,0 +1,1846 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/tsl/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/dataset_metadata.pb.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/dataset_stateful_op_allowlist.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/thread_factory.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/core/threadpool_interface.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include 
"tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/thread_annotations.h" + +// Polymorphic datasets should support all primitive TensorFlow +// types. Use this macro to expand `m(T)` once for each primitive type +// `T`, e.g. to build a `switch` statement. +#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m) + +namespace tensorflow { + +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class Node; + +namespace data { + +namespace internal { +// Merges Options from source to destination. If there is a conflict on a field, +// the field value from the source takes precedence. +void MergeOptions(const protobuf::Message& source, + protobuf::Message* destination); +void MergeOptions(const protobuf::MessageLite& source, + protobuf::MessageLite* destination); +} // namespace internal + +using TraceMeMetadata = std::vector>; + +// Maps the index of dataset elements to a globally shuffled index. See the +// comment for IteratorContext::Params::index_mapper for more details. +// Notes: +// * `absl::OutOfRangeError` indicates the input index argument exceeds +// the cardinality of the dataset. +// * `absl::NotFoundError` indicates we should skip this element. +// This happens in the case we mix multiple datasets into one. For example, +// `dataset1.concatenate(dataset2)`. +// See go/tf-data-random-access-iterator and +// go/tf-data-random-access-iterator-for-concatenate for more info. +using IndexMapperFn = std::function(size_t)>; + +constexpr char kTFDataFunction[] = "_tf_data_function"; + +constexpr int kInfiniteCardinality = -1; +constexpr int kUnknownCardinality = -2; + +// This constant is a magic number that is used (as a prefix) to identify keys +// used for serialization of iterator state. +constexpr char kFullNameRandomHex[] = "60d899aa0d8ce4351e7c3b419e92d25b"; +constexpr int kFullNameRandomHexLen = std::size(kFullNameRandomHex) - 1; +constexpr char kPipe[] = "|"; +constexpr char kColon[] = ":"; + +constexpr char kTFDataResourceTag[] = "tfdata"; +constexpr char kTraceInfoUnavailable[] = "unavailable"; +constexpr char kMetadata[] = "metadata"; + +constexpr char kCardinalityAttrForRewrite[] = "_cardinality"; + +class DatasetBase; +class IteratorContext; +class SerializationContext; + +inline bool IsTFDataFunction(const FunctionDef& func) { + auto iter = func.attr().find(data::kTFDataFunction); + return (iter != func.attr().end() && iter->second.b()); +} + +// Interface for reading values from a key-value store. +// Used for restoring iterator state. This class is thread safe. +// Please see comment on IteratorStateWriter for guidance around using the +// Read*(key, val) vs Read*(name, key, val). +class IteratorStateReader { + public: + // Determines whether the iterator state contains the given key. + virtual bool Contains(absl::string_view key) const = 0; + virtual bool Contains(absl::string_view name, + absl::string_view key) const = 0; + + // Reads an integer for the given key. + virtual absl::Status ReadScalar(absl::string_view key, + int64_t* val) const = 0; + virtual absl::Status ReadScalar(absl::string_view name, absl::string_view key, + int64_t* val) const = 0; + + // Reads a string for the given key. 
+ virtual absl::Status ReadScalar(absl::string_view key, + tstring* val) const = 0; + virtual absl::Status ReadScalar(absl::string_view name, absl::string_view key, + tstring* val) const = 0; + + // Reads a tensor for the given key. + // TODO(jsimsa): Remove non-FLR overrides once all callers are updated. + virtual absl::Status ReadTensor(absl::string_view key, Tensor* val) const = 0; + virtual absl::Status ReadTensor(FunctionLibraryRuntime* flr, + absl::string_view key, Tensor* val) const = 0; + virtual absl::Status ReadTensor(absl::string_view name, absl::string_view key, + Tensor* val) const = 0; + virtual absl::Status ReadTensor(FunctionLibraryRuntime* flr, + absl::string_view name, absl::string_view key, + Tensor* val) const = 0; + + virtual ~IteratorStateReader() {} +}; + +// Interface for writing values to a key-value store. +// Used for saving iterator state. Not thread safe. +// The IteratorStateWriter creates a tensor for each unique iterator name it +// sees. For the Write*(key, val) API's the key is expected to encode this +// name as keys are required to be produced using the full_name() method. +// Each tensor has an upper limit of 2 GB and so if the state for an iterator +// might exceed the 2 GB limit, you can pass an explicit name in via the +// Write*(name, key, val) APIs allowing you to further split up the state +// into more manageable chunks. +class IteratorStateWriter { + public: + // Writes an integer for the given key. + virtual absl::Status WriteScalar(absl::string_view key, + const int64_t val) = 0; + virtual absl::Status WriteScalar(absl::string_view name, + absl::string_view key, + const int64_t val) = 0; + + // Writes a string for the given key. + virtual absl::Status WriteScalar(absl::string_view key, + const tstring& val) = 0; + virtual absl::Status WriteScalar(absl::string_view name, + absl::string_view key, + const tstring& val) = 0; + + // Writes a tensor for the given key. + virtual absl::Status WriteTensor(absl::string_view key, + const Tensor& val) = 0; + virtual absl::Status WriteTensor(absl::string_view name, + absl::string_view key, + const Tensor& val) = 0; + + virtual ~IteratorStateWriter() {} + + protected: + // Accessible only through derived concrete class's copy/move constructors + IteratorStateWriter() = default; + IteratorStateWriter(const IteratorStateWriter&) = default; + IteratorStateWriter(IteratorStateWriter&&) = default; +}; + +// Generates a full name key for iterator checkpointing. All keys generated for +// iterator checkpoints should go through this function. +std::string FullName(const std::string& prefix, const std::string& name); + +// Extracts iterator prefix from key generated by `FullName`. +absl::Status ExtractIteratorPrefix(absl::string_view key, string* prefix); + +// Interface for objects that can be checkpointed. +class Checkpointable { + public: + Checkpointable() = default; + virtual ~Checkpointable() = default; + + virtual absl::Status Save(SerializationContext* ctx, + IteratorStateWriter* writer) = 0; + virtual absl::Status Restore(IteratorContext* ctx, + IteratorStateReader* reader) = 0; +}; + +// Wrapper around GraphDefBuilder. Used to serialize Dataset graph. +class GraphDefBuilderWrapper { + public: + explicit GraphDefBuilderWrapper(GraphDefBuilder* b) : b_(b) {} + + // Adds a Const node with scalar value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. 
+ // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + template + absl::Status AddScalar(const T& val, Node** output) { + Tensor val_t = Tensor(DataTypeToEnum::v(), TensorShape({})); + val_t.scalar()() = val; + AddTensorInternal(val_t, output); + if (*output == nullptr) { + return errors::Internal("AddScalar: Failed to build Const op."); + } + return absl::OkStatus(); + } + + // Adds a Const node with vector value to the Graph. + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. + // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + // TODO(shivaniagrawal): Consider changing to gtl::ArraySlice? + template + absl::Status AddVector(const std::vector& val, Node** output) { + Tensor val_t = Tensor(DataTypeToEnum::v(), + TensorShape({static_cast(val.size())})); + for (size_t i = 0; i < val.size(); i++) { + val_t.flat()(i) = val[i]; + } + AddTensorInternal(val_t, output); + if (*output == nullptr) { + return errors::Internal("AddVector: Failed to build Const op."); + } + return absl::OkStatus(); + } + + absl::Status AddVector(const std::vector& val, Node** output) { + Tensor val_t = Tensor(DataTypeToEnum::v(), + TensorShape({static_cast(val.size())})); + for (size_t i = 0; i < val.size(); i++) { + val_t.flat()(i) = val[i]; + } + AddTensorInternal(val_t, output); + if (*output == nullptr) { + return errors::Internal("AddVector: Failed to build Const op."); + } + return absl::OkStatus(); + } + + // Adds a `Const` node for the given tensor value to the graph. + // + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. The returned `Node` + // pointer is owned by the backing graph of `GraphDefBuilder`. + absl::Status AddTensor(const Tensor& val, Node** output) { + AddTensorInternal(val, output); + if (*output == nullptr) { + return errors::Internal("AddTensor: Failed to build Const op."); + } + return absl::OkStatus(); + } + + // Adds a `Placeholder` node for the given tensor value to the graph. + // + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. The returned `Node` + // pointer is owned by the backing graph of `GraphDefBuilder`. + absl::Status AddPlaceholder(const Tensor& val, Node** output) { + AddPlaceholderInternal(val, output); + if (*output == nullptr) { + return errors::Internal( + "AddPlaceholder: Failed to build Placeholder op."); + } + return absl::OkStatus(); + } + + // Adds a node for the given dataset to the `Graph`. The value of + // `DatasetBase::type_string()` is used as the op type for the node. Values + // for the `output_types` and `output_shapes` node attributes are also written + // if those attributes are defined in the `OpDef`. + // + // If `use_dataset_name` is set, the value of `DatasetBase::node_name()` is + // used as the op name for the node. This argument should only be set when + // serializing `DatasetBase` instances which might not have been created + // through op kernel execution to make sure the dataset op name is preserved + // across serialization boundaries, which is in turn needed to make sure + // iterator checkpoints are valid across serialization boundaries. When + // `use_dataset_name` is set, the caller is responsible for making sure that + // the op name is unique across the graph. + // + // `*output` contains a pointer to the output `Node`. 
It is guaranteed to be + // non-null if the method returns with an OK status. The returned `Node` + // pointer is owned by the backing `Graph` of `GraphDefBuilder`. + absl::Status AddDataset(const DatasetBase* dataset, + const std::vector& inputs, Node** output); + absl::Status AddDataset( + const DatasetBase* dataset, const std::vector& inputs, + const std::vector>& attrs, + Node** output); + absl::Status AddDataset( + const DatasetBase* dataset, + const std::vector>& inputs, + const std::vector>>& + list_inputs, + const std::vector>& attrs, + Node** output); + absl::Status AddDataset( + const DatasetBase* dataset, + const std::vector>& inputs, + const std::vector>>& + list_inputs, + const std::vector>& attrs, + bool use_dataset_name, Node** output); + + // Adds a user-defined function with name `function_name` to the graph and + // recursively adds all functions it references. If a function with a matching + // name has already been added, returns with OK status. If a user-defined with + // name `function_name` is not found in the context's function library, + // returns an InvalidArgumentError. If the function with name `function_name` + // or any of its dependent functions are stateful, and the context does not + // explicitly permit stateful functions, returns an InvalidArgument error. + absl::Status AddFunction(SerializationContext* ctx, + const string& function_name, + const FunctionLibraryDefinition& lib_def); + + template + void BuildAttrValue(const T& value, AttrValue* attr) { + SetAttrValue(value, attr); + } + + template + AttrValue BuildAttrValue(const T& value) { + AttrValue attr; + SetAttrValue(value, &attr); + return attr; + } + + protected: + GraphDefBuilder* builder() { return b_; } + + private: + void AddPlaceholderInternal(const Tensor& val, Node** output); + void AddTensorInternal(const Tensor& val, Node** output); + bool HasAttr(const string& op_type_name, const string& attr_name) const; + + bool HasAttr(const OpDef* op_def, const string& attr_name) const { + for (const auto& attr : op_def->attr()) { + if (attr.name() == attr_name) { + return true; + } + } + return false; + } + + absl::Status AddAttrFunctions(SerializationContext* ctx, + const AttrValue& attr_value, + const FunctionLibraryDefinition& lib_def) { + if (attr_value.has_func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name(), lib_def)); + } else if (attr_value.has_list()) { + for (const NameAttrList& name_attr_list : attr_value.list().func()) { + TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def)); + } + } + return absl::OkStatus(); + } + + GraphDefBuilder* b_; +}; + +class StatsAggregator; + +// A utility class for running a function and ensuring that there is always a +// `tensorflow::data` symbol on the stack. +class Runner { + public: + virtual ~Runner() {} + + // Runs the given function. + virtual void Run(const std::function& f) = 0; + + // Returns a global singleton Runner. + static Runner* get(); +}; + +// A class which provides a sequence of splits. Splits represent subdivisions of +// a dataset, e.g. filenames or ranges within files. We use splitting to +// partition input data into smaller pieces for distributed processing (see +// go/tf-data-splitting-design). The SplitProvider subclasses are expected to be +// thread-safe. +// +// Datasets provide a `MakeSplitProvider` method to expose a listing of their +// splits. +// +// Iterators created with a split provider will only iterate over the splits +// provided by the split provider. 
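+//
+// A minimal consumption sketch, assuming `provider` was obtained from a
+// dataset's MakeSplitProvider method (the variable names are illustrative
+// only):
+//
+//   Tensor split;
+//   bool end_of_splits = false;
+//   while (provider->GetNext(&split, &end_of_splits).ok() && !end_of_splits) {
+//     // Process `split`, e.g. a filename or a range within a file.
+//   }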
+class SplitProvider { + public: + virtual ~SplitProvider() {} + // Stores the next split in `*split`, setting `*end_of_splits` to indicate + // whether there were any splits left. + virtual absl::Status GetNext(Tensor* split, bool* end_of_splits) = 0; + // Resets the split provider to its beginning. + virtual absl::Status Reset() = 0; + // Saves the state of this split provider. + virtual absl::Status Save(std::function full_name, + IteratorStateWriter* writer) = 0; + // Restores the state of this split provider. + virtual absl::Status Restore( + std::function full_name, + IteratorStateReader* reader) = 0; + // Returns the number of splits: + // - If there are a finite number of splits, returns a non-negative count. + // - If there are an infinite number of splits, returns kInfiniteCardinality. + // - If the number of splits is unknown or can't be efficiently computed, + // returns kUnknownCardinality. + virtual int64_t Cardinality() const { return kUnknownCardinality; } + // Cancels the split provider. After cancelling, all other existing and future + // calls should return quickly without blocking. + virtual void Cancel() {} + // Used to determine if the split provider is dynamic. Dynamic split providers + // are expected to be non-deterministic and may return different splits upon + // reinitialization. + virtual bool IsDynamic() const { return false; } +}; + +// Returns the runner threadpool size from an OpKernelContext. +int32_t GetRunnerThreadpoolSizeFromOpKernelContext(OpKernelContext* ctx); + +// In-memory representation of a checkpoint. The checkpoint is represented as a +// collection of key-value pairs and are expected to be written using the +// `IteratorStateWriter` interface. +// +// The implementation is not thread-safe. +class MemoryCheckpoint final : public IteratorStateWriter { + public: + // IdRegistry maintains a bi-directional mapping between string and integer + // representations of checkpoint keys. + // + // The reason we need both is that integer ids are used for fast lookups and + // comparisons, while string ids are used for prefix matching. + class IdRegistry { + public: + IdRegistry() = default; + + // Adds the given string id to the registry, generating a unique integer id + // for it. If the string id already exists, its integer id is returned. + int64_t Add(const std::string& prefix, const std::string& key); + + // Gets all integer ids for string ids matching the given prefix. + std::vector GetMatchingIds(const std::string& prefix_to_match); + + // Gets the string id for the given integer id. + std::pair Get(int64_t id); + + // Removes the entries matching the given integer ids from the registry. 
+ void RemoveIds(const std::vector& ids); + + private: + mutex mu_; + int64_t next_id_ TF_GUARDED_BY(mu_) = 0; + absl::flat_hash_map> + int_to_string_ TF_GUARDED_BY(mu_); + absl::flat_hash_map, int64_t> + string_to_int_ TF_GUARDED_BY(mu_); + }; + + MemoryCheckpoint() = delete; + explicit MemoryCheckpoint(std::shared_ptr registry) + : id_registry_(registry) {} + + MemoryCheckpoint(MemoryCheckpoint&& other) = default; + MemoryCheckpoint(const MemoryCheckpoint& other) = default; + + static MemoryCheckpoint CreateRootCheckpoint( + std::shared_ptr registry) { + return MemoryCheckpoint(/*id_registry*/ registry, /*is_root=*/true); + } + + // BEGIN implementation of `IteratorStateWriter` interface + absl::Status WriteScalar(absl::string_view key, int64_t val) override { + string prefix; + TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); + return WriteScalar(prefix, key, val); + } + absl::Status WriteScalar(absl::string_view name, absl::string_view key, + int64_t val) override { + auto id = id_registry_->Add(string(name), string(key)); + int_values_[id] = val; + return absl::OkStatus(); + } + absl::Status WriteScalar(absl::string_view key, const tstring& val) override { + string prefix; + TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); + return WriteScalar(prefix, key, val); + } + absl::Status WriteScalar(absl::string_view name, absl::string_view key, + const tstring& val) override { + auto id = id_registry_->Add(string(name), string(key)); + str_values_[id] = val; + return absl::OkStatus(); + } + absl::Status WriteTensor(absl::string_view key, const Tensor& val) override { + string prefix; + TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); + return WriteTensor(prefix, key, val); + } + absl::Status WriteTensor(absl::string_view name, absl::string_view key, + const Tensor& val) override { + auto id = id_registry_->Add(string(name), string(key)); + tensor_values_[id] = val; + return absl::OkStatus(); + } + // END implementation of `IteratorStateWriter` interface + + // String representation for the in-memory checkpoint suitable for debugging. + std::string DebugString() const; + + // Returns the status of the in-memory checkpoint. + absl::Status GetStatus() const { return status_; } + + // Merges state of another checkpoint into this checkpoint, overwriting + // existing state (if applicable). + // + // Merge also garbage collects state that is no longer needed. + void Merge(MemoryCheckpoint* other); + + // Purge removes all keys with given prefix from checkpoint. It also adds the + // prefix for tracking unless it is the root checkpoint. + void Purge(const std::string& prefix); + + // Stores the in-memory checkpoint to the given writer. + absl::Status Save(IteratorStateWriter* writer) const; + + // Updates the status of the in-memory checkpoint with the given status. + void UpdateStatus(absl::Status status) { status_.Update(status); } + + private: + explicit MemoryCheckpoint(std::shared_ptr registry, bool is_root) + : is_root_(is_root), id_registry_(registry) {} + void operator=(const MemoryCheckpoint&) = delete; + + absl::Status status_ = absl::OkStatus(); + // Only set to true for the checkpoint in IteratorResource. + // Root checkpoint does not track expired prefixes. + const bool is_root_ = false; + absl::flat_hash_map int_values_; + absl::flat_hash_map str_values_; + absl::flat_hash_map tensor_values_; + + // Keeps track of expired prefixes for propagation. Cleaned after it's merged. 
+ absl::flat_hash_set expired_prefixes_; + + std::shared_ptr id_registry_; +}; + +// Aggregates runtime support needed for dataset and iterator serialization. +class SerializationContext { + public: + // Handles the external state according to the external state policy. + absl::Status HandleCheckExternalStateStatus(absl::Status s) { + if (s.ok()) { + return s; + } + switch (params_.external_state_policy) { + case ExternalStatePolicy::POLICY_WARN: + LOG(WARNING) << s.ToString(); + return absl::OkStatus(); + case ExternalStatePolicy::POLICY_IGNORE: + VLOG(2) << "Ignoring error status: " << s.ToString(); + return absl::OkStatus(); + case ExternalStatePolicy::POLICY_FAIL: + return s; + default: + return errors::InvalidArgument("Unexpected value of external policy: ", + params_.external_state_policy); + } + } + + struct Params { + explicit Params() = default; + + explicit Params(OpKernelContext* ctx) + : resource_mgr(ctx->resource_manager()), + device_name(ctx->device()->attributes().name()) {} + + std::vector>* input_list = nullptr; // Not owned. + + // Indicates what to do if the dataset depends on external state. + ExternalStatePolicy external_state_policy = + ExternalStatePolicy::POLICY_WARN; + + // Indicates whether the serialization is for rewrites. + // + // If true: + // * A dataset that doesn't implement serialization is replaced with a + // placeholder returned in `input_list`. + // * Data tensors are replaced with a placeholder returned in + // `input_list`. + // * Datasets that use random seeds should not serialize the random seeds. + // This doesn't affect datasets that use fixed seeds; fixed seeds will + // always be preserved. + // * Cardinality is serialized as an unregistered attribute + // `_cardinality`. + // If false: + // * A dataset that doesn't implement serialization should result in an + // error. + // * Data tensors (potentially large) should be serialized. + // * Datasets that use random seeds should serialize the random seeds. + bool is_graph_rewrite = false; + + // A resource manager for looking up resources during serialization. + ResourceMgr* resource_mgr; + + // The name of the device doing the serialization. + std::string device_name; + + // Determines whether checkpointing should represent input pipeline state + // symbolically, using cursors into source iterators, or explicitly, by + // storing internal state of each iterator. + bool symbolic_checkpoint = false; + }; + + explicit SerializationContext(Params params) : params_(params) {} + + std::vector>* input_list() { + return params_.input_list; + } + + ExternalStatePolicy external_state_policy() const { + return params_.external_state_policy; + } + + bool is_graph_rewrite() const { return params_.is_graph_rewrite; } + + const ResourceMgr* resource_mgr() const { return params_.resource_mgr; } + + const std::string& device_name() const { return params_.device_name; } + + bool symbolic_checkpoint() const { return params_.symbolic_checkpoint; } + + private: + Params params_; + + SerializationContext(const SerializationContext&) = delete; + void operator=(const SerializationContext&) = delete; +}; + +// Specifies the tf.data pipeline run mode. +enum RunMode { DEFAULT, STANDALONE }; + +// A cut-down version of `OpKernelContext` for running computations in +// iterators. Note that we cannot simply use `OpKernelContext` here because we +// might run computation in an iterator whose lifetime is not nested within the +// lifetime of a single `OpKernelContext` (e.g. asynchronous prefetching). 
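+//
+// As an illustrative sketch grounded in the constructors below, an
+// IteratorContext is typically created either directly from an
+// OpKernelContext or by copying the Params of a parent context:
+//
+//   IteratorContext iter_ctx(ctx);  // ctx is an OpKernelContext*.
+//   IteratorContext child(IteratorContext::Params{&iter_ctx});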
+// +// TODO(mrry): We're making some daring assumptions about the lifetime of the +// runner passed in here. A runner will be deleted when the original step ends, +// but all existing runners only close over session-lifetime (or longer-lived) +// state, so we can make a copy of the function. There's nothing in the +// definition of the API from which we took the runner to guarantee that what we +// are doing is safe. We should formalize the properties here. +class IteratorContext { + public: + struct Params { + explicit Params(IteratorContext* ctx) + : accelerator_device_info(ctx->accelerator_device_info()), + allocator_getter(ctx->allocator_getter()), + cancellation_manager(ctx->cancellation_manager()), + collective_executor(ctx->collective_executor()), + env(ctx->env()), + flr(ctx->flr()), + function_handle_cache(ctx->function_handle_cache()), + interleave_depth(ctx->interleave_depth()), + is_restoring(ctx->is_restoring()), + model(ctx->model()), + options(ctx->options()), + ram_budget_manager(ctx->ram_budget_manager()), + resource_mgr(ctx->resource_mgr()), + runner(*(ctx->runner())), + runner_threadpool_size(ctx->runner_threadpool_size()), + split_providers(ctx->split_providers()), + stats_aggregator(ctx->stats_aggregator()), + symbolic_checkpoint(ctx->symbolic_checkpoint()), + thread_factory(ctx->thread_factory()), + thread_pool(ctx->thread_pool()), + id_registry(ctx->id_registry()), + warm_start(ctx->warm_start()), + index_mapper(ctx->index_mapper()) {} + + explicit Params(OpKernelContext* ctx) + : collective_executor(ctx->collective_executor()), + env(ctx->env()), + flr(ctx->function_library()) { + // NOTE: need reinterpret_cast because function.h forward-declares Device. + DeviceBase* device = + reinterpret_cast(ctx->function_library()->device()); + accelerator_device_info = device->tensorflow_accelerator_device_info(); + allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + + runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx); + + // NOTE: Wrap every runner invocation in a call to Runner()->Run(), so + // that a symbol in the tensorflow::data namespace is always on the stack + // when executing a function inside a Dataset. + runner = std::bind( + []( + // Note: `runner` is a const reference to avoid copying it. + const std::function)>& ctx_runner, + std::function fn) { + std::function wrapped_fn = std::bind( + [](const std::function& fn) { Runner::get()->Run(fn); }, + std::move(fn)); + ctx_runner(std::move(wrapped_fn)); + }, + *ctx->runner(), std::placeholders::_1); + } + + // If non-null, information about the GPU or TPU on which the op is placed. + const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = nullptr; + + // The Allocator to be used to allocate the output of an iterator. + std::function allocator_getter = nullptr; + + // The CancellationManager to be used to cancel execution of ops. + CancellationManager* cancellation_manager = nullptr; + + // Collective support. + CollectiveExecutor* collective_executor = nullptr; + + // Interface to operating system functionality. + Env* env = nullptr; + + // The FunctionLibraryRuntime object to be used to make function calls. + FunctionLibraryRuntime* flr = nullptr; + + // A FunctionHandleCache that owns all the function handles. Not owned. 
+ FunctionHandleCache* function_handle_cache = nullptr; + + // Records the number of ParallelInterleave operations in the path from the + // root node to this node (not including this node) in the input pipeline + // tree. + int64 interleave_depth = 0; + + // Marks whether the iterator is restored from a checkpoint. + bool is_restoring = false; + + // If non-null, identifies the object used for performance modeling. + std::shared_ptr model = nullptr; + + // The input pipeline options. + const Options* options = nullptr; + + // Manager for the ram budget when using autotune. + std::shared_ptr ram_budget_manager = nullptr; + + // A resource manager for storing dataset-related state, e.g. random + // seeds or cached tensors. Not owned. + ResourceMgr* resource_mgr = nullptr; + + // Function call support. + std::function)> runner = nullptr; + + // Number of threads used for executing user-defined functions. + int32 runner_threadpool_size = 0; + + // Split providers indicating which splits to process. May be empty, + // indicating that the iterator should process all splits. + std::vector> split_providers; + + // The `StatsAggregator` object to record statistics about the iterator. + // + // TODO(b/147325552): Remove this API and any of its uses after we switch to + // using C++ based implementation for tf.data options (on 4/12/2021). + std::shared_ptr stats_aggregator = nullptr; + + // Indicates whether to use symbolic checkpointing. + bool symbolic_checkpoint = false; + + // A factory for creating threads to perform blocking work. + std::shared_ptr thread_factory = nullptr; + + // A shared thread pool to schedule computation into. + thread::ThreadPoolInterface* thread_pool = nullptr; + + std::shared_ptr id_registry = + std::make_shared(); + + // If `true` background threads of asynchronous operations are started when + // the iterator is created. Otherwise, they are started upon first `GetNext` + // request. Default value is set to false to ensure backward compatibility. + bool warm_start = false; + + // Specifies the tf.data pipeline run mode. + RunMode run_mode = RunMode::DEFAULT; + + // Maps the index of dataset elements to a shuffled index. In other words, + // given an index i, returns the permuted index p(i) for the iterator. Used + // to support global shuffling of datasets that support random access. + IndexMapperFn index_mapper = nullptr; + + // Records the number of elements that have been produced prior to a + // checkpoint. This is set by globally shuffled iterators so that upstream + // iterators can restore the element counts in the random access mode. + std::optional restored_element_count = std::nullopt; + }; + + explicit IteratorContext(IteratorContext* ctx) + : IteratorContext(Params{ctx}) {} + + explicit IteratorContext(OpKernelContext* ctx) + : IteratorContext(Params{ctx}) {} + + explicit IteratorContext(Params params) + : params_(std::move(params)), + checkpoint_(MemoryCheckpoint{params_.id_registry}) {} + + IteratorContext(const IteratorContext& other) + : IteratorContext(Params{other.params_}) { + // MemoryCheckpoint should not be copied over as the child context should + // not care what's in the checkpoint of parent context. 
+  }
+
+  std::shared_ptr<MemoryCheckpoint::IdRegistry> id_registry() {
+    return params_.id_registry;
+  }
+
+  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info() {
+    return params_.accelerator_device_info;
+  }
+
+  Allocator* allocator(AllocatorAttributes attrs) {
+    return params_.allocator_getter(attrs);
+  }
+
+  std::function<Allocator*(AllocatorAttributes)> allocator_getter() {
+    return params_.allocator_getter;
+  }
+
+  CancellationManager* cancellation_manager() {
+    return params_.cancellation_manager;
+  }
+
+  CollectiveExecutor* collective_executor() {
+    return params_.collective_executor;
+  }
+
+  Env* env() const { return params_.env; }
+
+  FunctionLibraryRuntime* flr() { return params_.flr; }
+
+  FunctionHandleCache* function_handle_cache() {
+    return params_.function_handle_cache;
+  }
+
+  MemoryCheckpoint* checkpoint() { return &checkpoint_; }
+
+  int64 interleave_depth() { return params_.interleave_depth; }
+
+  bool is_restoring() { return params_.is_restoring; }
+
+  const std::shared_ptr<model::Model>& model() const { return params_.model; }
+
+  const Options* options() const { return params_.options; }
+
+  const std::shared_ptr<model::RamBudgetManager>& ram_budget_manager() {
+    return params_.ram_budget_manager;
+  }
+
+  ResourceMgr* resource_mgr() { return params_.resource_mgr; }
+
+  std::function<void(std::function<void()>)>* runner() {
+    return &params_.runner;
+  }
+
+  int32 runner_threadpool_size() { return params_.runner_threadpool_size; }
+
+  std::vector<std::shared_ptr<SplitProvider>> split_providers() const {
+    return params_.split_providers;
+  }
+
+  std::shared_ptr<StatsAggregator> stats_aggregator() {
+    return params_.stats_aggregator;
+  }
+
+  bool symbolic_checkpoint() { return params_.symbolic_checkpoint; }
+
+  const std::shared_ptr<ThreadFactory>& thread_factory() {
+    return params_.thread_factory;
+  }
+
+  thread::ThreadPoolInterface* thread_pool() { return params_.thread_pool; }
+
+  bool warm_start() { return params_.warm_start; }
+
+  RunMode run_mode() { return params_.run_mode; }
+
+  IndexMapperFn index_mapper() const { return params_.index_mapper; }
+
+  void set_restored_element_count(size_t element_count) {
+    params_.restored_element_count.emplace(element_count);
+  }
+
+  std::optional<int64_t> restored_element_count() const {
+    return params_.restored_element_count;
+  }
+
+  void SetModel(std::shared_ptr<model::Model> model) { params_.model = model; }
+
+  void SetIndexMapper(const IndexMapperFn& index_mapper) {
+    params_.index_mapper = index_mapper;
+  };
+
+  std::unique_ptr<thread::ThreadPool> CreateThreadPool(const string& name,
+                                                       int num_threads) {
+    if (params_.thread_pool) {
+      // Create a `ThreadPool` instance by wrapping `params_.thread_pool` (which
+      // is an instance of `thread::ThreadPoolInterface`). Notably, the
+      // ownership of `params_.thread_pool` is *not* transferred onto the newly
+      // created `ThreadPool` instance.
+      return absl::make_unique<thread::ThreadPool>(params_.thread_pool);
+    } else {
+      return absl::make_unique<thread::ThreadPool>(params_.env, ThreadOptions(),
+                                                   name, num_threads,
+                                                   /*low_latency_hint=*/false);
+    }
+  }
+
+  // Merges the given checkpoint with the checkpoint of this context.
+  //
+  // The intended use of this API is that methods such as
+  // `IteratorBase::Initialize`, `IteratorBase::GetNextInternal`, or
+  // `IteratorBase::RestoreInternal` that store data in the in-memory
+  // checkpoint use a separate instance of `IteratorContext` for a nested call;
+  // the checkpoint collected by the `IteratorContext` instance passed into
+  // the callee should then be merged into the `IteratorContext` of the caller:
+  //
+  // ```
+  // Status GetNextInternal(IteratorContext* ctx, ...) {
+  //   ...
+ // IteratorContext nested_ctx(...); + // TF_RETURN_IF_ERROR(input_impl_->GetNext(&nested_ctx, ...)); + // ctx->MergeCheckpoint(nested_ctx->checkpoint()); + // ... + // } + // ``` + void MergeCheckpoint(MemoryCheckpoint* checkpoint) { + if (symbolic_checkpoint()) { + checkpoint_.Merge(checkpoint); + } + } + + // Removes any keys with the given prefix from the checkpoint. + // + // The intended use for this API is to clean the stale state in checkpoint, + // e.g. when a pipeline created by `flat_map` is exhausted, the state + // associated with the iterator of that pipeline is no longer needed and + // should be removed. + void PurgeCheckpoint(const std::string& prefix) { + if (symbolic_checkpoint()) { + checkpoint_.Purge(prefix); + } + } + + // Saves the state of the given iterator into the checkpoint. + void SaveCheckpoint(Checkpointable* iterator) { + if (symbolic_checkpoint()) { + SerializationContext::Params params; + params.symbolic_checkpoint = true; + SerializationContext ctx(std::move(params)); + checkpoint_.UpdateStatus(iterator->Save(&ctx, &checkpoint_)); + } + } + + std::unique_ptr StartThread(const string& name, + std::function fn) { + if (params_.thread_factory) { + return params_.thread_factory->StartThread(name, std::move(fn)); + } else { + return absl::WrapUnique( + Env::Default()->StartThread({}, name, std::move(fn))); + } + } + + // Updates the status of the checkpoint with the given status. + void UpdateCheckpointStatus(std::function status_fn) { + if (symbolic_checkpoint()) { + checkpoint_.UpdateStatus(status_fn()); + } + } + + private: + Params params_; + MemoryCheckpoint checkpoint_; +}; + +// Generic context that can be constructed with either an `OpKernelContext` or +// `IteratorContext`. +struct AnyContext { + Allocator* allocator; + std::function)>* runner; + int64_t runner_threadpool_size; + + explicit AnyContext(IteratorContext* ctx) { + allocator = ctx->allocator({}); + runner = ctx->runner(); + runner_threadpool_size = ctx->runner_threadpool_size(); + } + + explicit AnyContext(OpKernelContext* ctx) { + allocator = ctx->get_allocator({}); + runner = ctx->runner(); + runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx); + } +}; + +// Represents the current position in a range of outputs, where the +// range of outputs is typically represented by an `DatasetBase`, +// defined below. +class IteratorBase : public Checkpointable { + public: + ~IteratorBase() override { + for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) { + (*rit)(); + } + } + + // Gets the next output from the range that this iterator is traversing. + // + // If at least one output remains in this iterator's range, that + // output will be stored in `*out_tensors` and `false` will be + // stored in `*end_of_sequence`. + // + // If no more outputs remain in this iterator's range, `true` will be stored + // in `*end_of_sequence`, and `*out_tensors` will be empty. + // + // Implementations should never return `OutOfRange` error. If at end of + // sequence, set `*end_of_sequence = true` and return `OkStatus()`. + // Internally raised `OutOfRange` errors that do not imply end of sequence + // should be converted to a different error type before being propagated to + // the caller. + // + // Implementations must explicitly set `*end_of_sequence = false` if an + // `OkStatus()` status is returned and the iterator is not at the end of the + // sequence. + // + // `out_tensors` and `end_of_sequence` are output parameters. 
`*out_tensors` + // and `*end_of_sequence` should not be read by implementations of `GetNext` + // before they are assigned. + // + // This method is thread-safe. + // + // TODO(mrry): Define `GetNextAsync()` or `GetNextManyAsync()`, and + // potentially remove this method. + virtual absl::Status GetNext(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) = 0; + + absl::Status GetNext(IteratorContext&& ctx, std::vector* out_tensors, + bool* end_of_sequence) { + return GetNext(&ctx, out_tensors, end_of_sequence); + } + + // If a dataset needs to provide its own index mapper behavior to support + // global shuffling, implement this method. + virtual IndexMapperFn GetIndexMapper( + IndexMapperFn parent_index_mapper) const { + return parent_index_mapper; + } + + // Skips the next `num_to_skip` outputs from the range that this iterator + // is traversing. + // + // If there are not enough outputs to skip, it will set + // `*end_of_sequence = true` and return `OkStatus()`. `*num_skipped` will + // store the number of outputs that are skipped. When `*end_of_sequence` is + // `false`, `*num_skipped` should equal to `num_to_skip`. + virtual absl::Status Skip(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped) = 0; + + virtual absl::Status Skip(IteratorContext&& ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped) { + return Skip(&ctx, num_to_skip, end_of_sequence, num_skipped); + } + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this + // iterator. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this iterator. + virtual const std::vector& output_shapes() const = 0; + + // Returns a string that identifies the sequence of iterators leading up to + // this iterator. + virtual const string& prefix() const = 0; + + // Indicates whether the iterator is compatible with symbolic checkpointing. + virtual bool SymbolicCheckpointCompatible() const { return false; } + + // Performs initialization that needs to happen outside of a constructor to + // properly propagate errors. + virtual absl::Status Initialize(IteratorContext* ctx) { + return absl::OkStatus(); + } + + // Performs initialization of the base iterator. + absl::Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent); + + // Saves the state of this iterator. + absl::Status Save(SerializationContext* ctx, + IteratorStateWriter* writer) override { + int64_t start_us = EnvTime::NowMicros(); + TF_RETURN_IF_ERROR(SaveInternal(ctx, writer)); + VLOG(1) << "Saved " << prefix() << " in " + << (EnvTime::NowMicros() - start_us) << "us"; + return absl::OkStatus(); + } + + // Restores the state of this iterator. + absl::Status Restore(IteratorContext* ctx, + IteratorStateReader* reader) override { + int64_t start_us = EnvTime::NowMicros(); + TF_RETURN_IF_ERROR(RestoreInternal(ctx, reader)); + ctx->SaveCheckpoint(this); + VLOG(1) << "Restored " << prefix() << " in " + << (EnvTime::NowMicros() - start_us) << "us"; + return absl::OkStatus(); + } + + // Returns the total number of bytes buffered by the iterator across all nodes + // in the subtree for which autotuning is enabled. 
+ int64_t TotalBufferedBytes() const { + if (node_) return node_->TotalBufferedBytes(); + return 0; + } + + protected: + // Returns a node that models this iterator. + virtual std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const = 0; + + // This is needed so that sub-classes of IteratorBase can call + // `SaveInternal` on their input iterators. + absl::Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer, + const std::unique_ptr& input) { + if (ctx->symbolic_checkpoint()) { + return absl::OkStatus(); + } + return input->Save(ctx, writer); + } + + // This is needed so that sub-classes of IteratorBase can call + // `RestoreInternal` on their input iterators. + absl::Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader, + const std::unique_ptr& input) { + return input->Restore(ctx, reader); + } + + absl::Status RestoreInput(IteratorContext&& ctx, IteratorStateReader* reader, + const std::unique_ptr& input) { + return RestoreInput(&ctx, reader, input); + } + + // Saves the state of this iterator. + // + // This method is used to store the state of the iterator in a checkpoint. + // implementations have an override. + virtual absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) = 0; + + // Restores the state of this iterator. + // + // This method is used to restore the state of the iterator from a checkpoint. + // + // Implementations may assume that the iterator is in a clean state. That is, + // its `Initialize` method has been called, but its `GetNext` method has + // never been called. + // implementations have an override. + virtual absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) = 0; + + // Returns a pointer to the node representing this iterator in the performance + // model. It may be null, if performance modeling is not enabled for this + // iterator. + std::shared_ptr model_node() const { return node_; } + + // Returns the number of elements produced by this iterator. + int64_t num_elements() const { + if (node_) return node_->num_elements(); + return 0; + } + + std::shared_ptr node_ = nullptr; + + private: + // For access to `AddCleanupFunction` and `Restore`. + friend class DatasetBase; + friend class DatasetBaseIterator; // for access to `node_` + + std::vector> cleanup_fns_; + const IteratorBase* parent_ = nullptr; // Not owned. + uint64_t id_ = 0; + uint64_t parent_id_ = 0; +}; + +// Represents runtime information needed to construct a dataset. +class DatasetContext { + public: + struct Params { + string type_string; // op type name of this dataset. + string node_name; // graph node name of this dataset op, uniquely + // identifying the dataset in the graph. + }; + + explicit DatasetContext(Params params) : params_(std::move(params)) {} + + explicit DatasetContext(OpKernelContext* ctx) { + params_.type_string = ctx->op_kernel().type_string(); + params_.node_name = ctx->op_kernel().name(); + } + + const string& type_string() const { return params_.type_string; } + const string& node_name() const { return params_.node_name; } + + private: + Params params_; +}; + +// Returns the number of bytes allocated for the given tensor. +int64_t GetAllocatedBytes(const std::vector& element); + +// Returns the estimated memory usage in bytes of the given tensor. +int64_t GetTotalBytes(const std::vector& element); + +// Validates and extracts a `DatasetBase` object from `tensor`. +// +// `tensor` must have been written by a call to SetVariantTensorToDataset(). 
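+//
+// Illustrative sketch (not from the upstream header; the surrounding kernel
+// code and variable names are assumptions): a kernel consuming a dataset
+// input typically does
+//
+// ```
+// DatasetBase* dataset;
+// OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+// dataset->Ref();  // Take a reference if `dataset` must outlive the input.
+// ```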
+// +// The retrieved pointer is a borrowed reference to the dataset, which is owned +// by the tensor. The consumer must either acquire its own reference to the +// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not +// destroyed or mutated while the retrieved pointer is in use. +absl::Status GetDatasetFromVariantTensor(const Tensor& tensor, + DatasetBase** out_dataset); + +// Stores a `DatasetBase` object in `tensor`. +// +// The ownership of `dataset` is transferred to `tensor`. +absl::Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor); + +// Represents a (potentially infinite) range of outputs, where each +// output is a tuple of tensors. +class DatasetBase : public core::RefCounted { + public: + // Key for storing the Dataset graph in the serialized format. + TF_EXPORT static const char kDatasetGraphKey[]; + + // Key for storing the output node of the Dataset graph in the serialized + // format. + TF_EXPORT static const char kDatasetGraphOutputNodeKey[]; + + explicit DatasetBase(DatasetContext&& ctx) + : type_string_(ctx.type_string()), node_name_(ctx.node_name()) {} + + // Op type name of this dataset. + const string& type_string() const { return type_string_; } + + // Graph node name of this dataset op, uniquely identifying the dataset in + // the graph. + const string& node_name() const { return node_name_; } + + const Metadata& metadata() const { return metadata_; } + + const Options& options() const { return options_; } + + int64_t num_sources() const { return num_sources_; } + + // Initializes the dataset using the given metadata. + void Initialize(const Metadata& metadata); + + // Returns a new iterator for iterating over the range of elements in + // this dataset. + // + // This method may be called multiple times on the same instance, + // and the resulting iterators will have distinct state. Each + // iterator will traverse all elements in this dataset from the + // start. + // + // The prefix identifies the sequence of iterators leading up to the newly + // created iterator. + absl::Status MakeIterator(IteratorContext* ctx, const IteratorBase* parent, + const string& output_prefix, + std::unique_ptr* iterator) const; + + absl::Status MakeIterator(IteratorContext&& ctx, const IteratorBase* parent, + const string& output_prefix, + std::unique_ptr* iterator) const { + return MakeIterator(&ctx, parent, output_prefix, iterator); + } + + // Returns a new iterator restored from the checkpoint data in `reader`. + absl::Status MakeIteratorFromCheckpoint( + IteratorContext* ctx, const string& output_prefix, + IteratorStateReader* reader, + std::unique_ptr* iterator) const { + std::unique_ptr it; + IteratorContext::Params params(ctx); + params.is_restoring = true; + IteratorContext restore_ctx(std::move(params)); + TF_RETURN_IF_ERROR(MakeIterator(&restore_ctx, + /*parent=*/nullptr, output_prefix, &it)); + TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader)); + ctx->MergeCheckpoint(restore_ctx.checkpoint()); + *iterator = std::move(it); + return absl::OkStatus(); + } + + absl::Status MakeIteratorFromCheckpoint( + IteratorContext&& ctx, const string& output_prefix, + IteratorStateReader* reader, + std::unique_ptr* iterator) const { + return MakeIteratorFromCheckpoint(&ctx, output_prefix, reader, iterator); + } + + // Returns a split provider which partitions the dataset's data into splits + // and provides them in a sequence. The split provider is stored in + // `*split_provider`. 
+ virtual absl::Status MakeSplitProviders( + std::vector>* split_providers) const; + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this + // dataset. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this dataset. + virtual const std::vector& output_shapes() const = 0; + + // Returns the number of bytes allocated for tensors of this dataset. + virtual int64_t AllocatedBytes() const { return 0; } + + // Returns the estimated element size based on `output_shapes()` and + // `output_dtypes()`. + virtual std::optional GetEstimatedElementSize() const; + + // Returns the estimated number of bytes used for tensors of this dataset. + virtual int64_t TotalBytes() const { return 0; } + + // Returns the cardinality of this dataset. + // TODO(shilpakrish): Remove this overload once all callers are migrated + // to the API which passes in the options parameter. + ABSL_DEPRECATED("Use the overload that passes in the options parameter.") + int64_t Cardinality() const; + + // Returns the cardinality of this dataset based on the options. + int64_t Cardinality(CardinalityOptions options) const; + + // Internal implementation of cardinality for a dataset based on the options. + virtual int64_t CardinalityInternal(CardinalityOptions options) const + TF_EXCLUSIVE_LOCKS_REQUIRED(cardinality_mu_) { + return kUnknownCardinality; + } + + // A human-readable debug string for this dataset. + virtual string DebugString() const = 0; + + // Stores the dataset's input datasets in `*inputs`. The pointers stored in + // `*inputs` are borrowed. The only valid non-ok return status is + // UNIMPLEMENTED in case `InputDatasets` is not implemented by a dataset + // subclass. Implementing `InputDatasets` enables `DatasetBase` to provide a + // default implementation of `MakeSplitProvider` when there is a single input + // dataset. + virtual absl::Status InputDatasets( + std::vector* inputs) const; + + // Indicates whether the dataset depends on any external state which would + // prevent it from being serializable. If so, the method returns + // `errors::FailedPrecondition` with a message that identifies the external + // state. Otherwise, the method returns `OkStatus()`. + virtual absl::Status CheckExternalState() const = 0; + + // Indicates whether the dataset is compatible with random access. + absl::Status CheckRandomAccessCompatible(const int64 index) const; + + // Return the element at a particular index for a randomly accessible dataset. + virtual absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const; + + // Same as above, but with an `AnyContext`, which can be constructed from + // either an `OpKernelContext` or `IteratorContext`. Used to support datasets + // that provide random access through both the dataset and iterator APIs. + virtual absl::Status Get(AnyContext ctx, int64 index, + std::vector* out_tensors) const; + + // Returns true if the dataset and its inputs support random access. + virtual absl::Status RandomIndexingCompatible() const { + return absl::FailedPreconditionError( + absl::StrCat(type_string(), " does not support random access.")); + } + + // Return a finalized version of the dataset. The returned DatasetBase is + // unowned and lives for as long as this dataset. 
+ virtual absl::StatusOr Finalize( + OpKernelContext* ctx, + std::function>()> + make_finalized_dataset) const; + + // Wrapper around a GraphDefBuilder which provides support for serializing + // Datasets as GraphDefs. + class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { + public: + explicit DatasetGraphDefBuilder(GraphDefBuilder* b) + : GraphDefBuilderWrapper(b) {} + absl::Status AddInputDataset(SerializationContext* ctx, + const DatasetBase* dataset, Node** output); + absl::Status AddDatasetOrTensor(SerializationContext* ctx, + const Tensor& val, Node** output); + absl::Status AddIdentity(SerializationContext* ctx, + const std::string& name_prefix, Node** input, + Node** output); + + private: + absl::Status AddDatasetOrTensorHelper(SerializationContext* ctx, + const Tensor& val, Node** output); + absl::Status AddResourceHelper(SerializationContext* ctx, const Tensor& val, + Node** output); + }; + + protected: + friend class CapturedFunction; + + // Serializes the dataset into a `GraphDef`, which has two uses: + // + // 1) To perform static input pipeline optimizations, tf.data serializes the + // dataset graph, applies graph rewrites, and then deserializes the graph. + // If a subclass of `DatasetBase` does not implement this method, then it will + // be excluded from static optimizations (and so will any upstream datasets). + // + // 2) To save the dataset so that it can restore at a later point (possibly in + // different environment). If a subclass of `DatasetBase` does not implement + // this method, then this migration will not be possible. + virtual absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const = 0; + + virtual std::unique_ptr MakeIteratorInternal( + const string& prefix) const = 0; + + void set_options(const Options& options) { options_ = options; } + + private: + // Computes and stores the cardinality of a given dataset. + absl::Status ComputeCardinality(); + + // Computes the number of source datasets feeding into this dataset. A source + // dataset is a leaf in the subtree of dataset inputs. + absl::Status ComputeNumSources(); + + // Merges options from inputs to this dataset. If there is a conflict in a + // field value, the options set on this dataset takes precedence over those in + // the inputs. The order of precedence on the inputs is in the same order as + // how they appear for this dataset. + absl::Status MergeOptionsFromInputs(); + + const string type_string_; + const string node_name_; + Metadata metadata_; + Options options_; + mutable mutex mu_; + mutable mutex cardinality_mu_; + mutable core::RefCountPtr finalized_dataset_; + // The number of source datasets feeding into the dataset. A source dataset + // is a leaf in the subtree of dataset inputs. + int64_t num_sources_ = -1; + mutable int64_t cardinality_ TF_GUARDED_BY(cardinality_mu_) = + kUnknownCardinality; +}; + +// Represents an iterator that is associated with a particular dataset. +class DatasetBaseIterator : public IteratorBase { + public: + struct BaseParams { + // Owns one reference on the shared dataset object. + const DatasetBase* dataset; + + // Identifies the sequence of iterators leading up to this iterator. 
+ const string prefix; + }; + + explicit DatasetBaseIterator(const BaseParams& params); + + ~DatasetBaseIterator() override; + + virtual const DatasetBase* dataset() const { return params_.dataset; } + + const DataTypeVector& output_dtypes() const override { + return params_.dataset->output_dtypes(); + } + + const std::vector& output_shapes() const override { + return params_.dataset->output_shapes(); + } + + const string& prefix() const override { return params_.prefix; } + + // Returns a name to be used for the TraceMe event. + // + // NOTE: TraceMe supports passing key-value pairs of "arguments" using the + // following format "name#arg_1=value_,...,arg_n=value_n". + string BuildTraceMeName(); + + absl::Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) final; + + absl::Status GetNext(IteratorContext&& ctx, std::vector* out_tensors, + bool* end_of_sequence) { + return GetNext(&ctx, out_tensors, end_of_sequence); + } + + absl::Status Skip(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped) final; + + absl::Status Save(SerializationContext* ctx, + IteratorStateWriter* writer) final { + VLOG(2) << "Attempting to save checkpoints on iterator (prefix: " + << prefix() << ") from " << dataset()->DebugString(); + return IteratorBase::Save(ctx, writer); + } + + // Returns a copy of the `status` where the error message is prepended with + // dataset name and the iterator prefix. + absl::Status AddErrorContext(const absl::Status& status) const { + return absl::Status( + status.code(), + strings::StrCat("Error in user-defined function passed to ", + dataset()->metadata().name(), + " transformation with iterator: ", prefix(), ": ", + status.message())); + } + + protected: + absl::Status Restore(IteratorContext* ctx, + IteratorStateReader* reader) final { + VLOG(2) << "Attempting to restore checkpoints on iterator (prefix: " + << prefix() << ") from " << dataset()->DebugString(); + return IteratorBase::Restore(ctx, reader); + } + + // Internal implementation of GetNext that is wrapped in tracing logic. + // + // See the docstring of `GetNext` method regaring the contract for + // `out_tensors` and `end_of_sequence`. Implementations may assume that + // `*out_tensors` is empty. + virtual absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) = 0; + + // Internal implementation of Skip that is wrapped in tracing logic + virtual absl::Status SkipInternal(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, int* num_skipped); + + string full_name(const string& name) const { + return FullName(params_.prefix, name); + } + + // Returns a map of key-value pairs to included in the TraceMe string. + virtual TraceMeMetadata GetTraceMeMetadata() const { return {}; } + + // By default we model iterators using an unknown node, which acts as + // pass-through with respect to performance modeling. + std::shared_ptr CreateNode( + IteratorContext* ctx, model::Node::Args args) const override { + return model::MakeUnknownNode(std::move(args)); + } + + // When modeling is enabled, this method disables autotuning for the given + // iterator (and the transitive closure of its inputs). + void DisableAutotune(IteratorContext* ctx, IteratorBase* iterator) { + if (iterator->node_) { + iterator->node_->set_autotune(false); + } + } + + // When modeling is enabled, this method enables autotuning for the given + // iterator (and the transitive closure of its inputs). 
+ void EnableAutotune(IteratorContext* ctx, IteratorBase* iterator) { + if (iterator->node_) { + iterator->node_->set_autotune(true); + } + } + + // When modeling is enabled, this method records the fact that this iterator + // has dequeued an element from an internal buffer. + void RecordBufferDequeue(IteratorContext* ctx, + const std::vector& element) { + if (collect_resource_usage(ctx)) { + node_->record_buffer_event(-GetAllocatedBytes(element), -1); + DCHECK_GE(node_->buffered_elements(), 0); + } + } + + // When modeling is enabled, this method records the fact that this iterator + // has enqueued an element in an internal buffer. + void RecordBufferEnqueue(IteratorContext* ctx, + const std::vector& element) { + if (collect_resource_usage(ctx)) { + node_->record_buffer_event(GetAllocatedBytes(element), 1); + } + } + + // When modeling is enabled, this method records the fact that this iterator + // has produced an element and its size in bytes. + void RecordElement(IteratorContext* ctx, std::vector* out_tensors) { + if (collect_resource_usage(ctx)) { + int64_t num_bytes = GetAllocatedBytes(*out_tensors); + node_->record_element(); + node_->record_bytes_produced(num_bytes); + if (node_->output()) { + node_->output()->record_bytes_consumed(num_bytes); + } + } + } + + // When modeling is enabled, this method records the fact that a thread of + // this iterator has started work. + void RecordStart(IteratorContext* ctx) { + if (collect_resource_usage(ctx)) { + int64_t now_nanos = EnvTime::NowNanos(); + node_->record_start(now_nanos); + } + } + + // When modeling is enabled, this method records the fact that a thread of + // this iterator has stopped work. + void RecordStop(IteratorContext* ctx) { + if (collect_resource_usage(ctx)) { + int64_t now_nanos = EnvTime::NowNanos(); + node_->record_stop(now_nanos); + } + } + + // Returns whether work is currently being recorded, i.e. whether we are + // currently between a `RecordStart` and a `RecordStop`. + bool IsRecording(IteratorContext* ctx) { + return node_ && node_->is_recording(); + } + + private: + bool collect_resource_usage(IteratorContext* ctx) { + return ctx->model() && node_; + } + + string traceme_metadata_; + BaseParams params_; +}; + +// Represents an iterator that is associated with a particular dataset +// with a particular type. +template +class DatasetIterator : public DatasetBaseIterator { + public: + struct Params { + // Borrowed pointer to the dataset. + const DatasetType* dataset; + + // Identifies the sequence of iterators leading up to this iterator. + const string prefix; + }; + + explicit DatasetIterator(const Params& params) + : DatasetBaseIterator({params.dataset, params.prefix}), + typed_dataset_(params.dataset) {} + + // The dataset from which this iterator was created. + const DatasetType* dataset() const final { return typed_dataset_; } + + private: + const DatasetType* const typed_dataset_; // Not owned. 
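+
+  // Illustrative sketch (not from the upstream header; `MyDataset` is a
+  // hypothetical dataset type): a concrete dataset usually nests an iterator
+  // deriving from this template, e.g.
+  //
+  // ```
+  // class MyDataset::Iterator : public DatasetIterator<MyDataset> {
+  //  public:
+  //   explicit Iterator(const Params& params)
+  //       : DatasetIterator<MyDataset>(params) {}
+  //
+  //  protected:
+  //   absl::Status GetNextInternal(IteratorContext* ctx,
+  //                                std::vector<Tensor>* out_tensors,
+  //                                bool* end_of_sequence) override;
+  //   // ... SaveInternal/RestoreInternal overrides, etc.
+  // };
+  // ```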
+}; + +template +absl::Status ParseScalarArgument(OpKernelContext* ctx, + const absl::string_view& argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return absl::OkStatus(); +} + +template +absl::Status ParseVectorArgument(OpKernelContext* ctx, + const absl::string_view& argument_name, + std::vector* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsVector(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a vector"); + } + int size = argument_t->vec().size(); + output->reserve(size); + for (int i = 0; i < size; ++i) { + output->push_back(argument_t->vec()(i)); + } + return absl::OkStatus(); +} + +// Encapsulates the work required to plug a DatasetBase into the core TensorFlow +// graph execution engine. +class DatasetOpKernel : public OpKernel { + public: + explicit DatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) { + if (ctx->HasAttr(kMetadata)) { + std::string serialized_metadata; + OP_REQUIRES_OK(ctx, ctx->GetAttr(kMetadata, &serialized_metadata)); + OP_REQUIRES(ctx, metadata_.ParseFromString(serialized_metadata), + errors::InvalidArgument(absl::StrCat( + "Could not parse the 'metadata' attribute."))); + } + } + + void Compute(OpKernelContext* ctx) final; + + // Checks whether the given op is a tf.data operation. + // + // NOTE: The check uses a heuristic and can produce both false positives and + // false negatives. In particular, tf.data operations are expected to use + // names that end with "Dataset" or "DatasetV[0-9]+". + static bool IsDatasetOp(const OpDef& op_def); + + string TraceString(const OpKernelContext& ctx, bool verbose) const override; + + protected: + // Subclasses should implement this method. It will be called during Compute + // execution. + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) = 0; + + private: + Metadata metadata_; +}; + +// Encapsulates the work required to plug unary Datasets into the core +// TensorFlow graph execution engine. +class UnaryDatasetOpKernel : public DatasetOpKernel { + public: + explicit UnaryDatasetOpKernel(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) = 0; +}; + +// Encapsulates the work required to plug binary Datasets into the core +// TensorFlow graph execution engine. +class BinaryDatasetOpKernel : public DatasetOpKernel { + public: + explicit BinaryDatasetOpKernel(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) final; + virtual void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase* another_input, + DatasetBase** output) = 0; +}; + +// A simple background worker that executes closures asynchronously and without +// blocking. +// +// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel` +// to avoid blocking an executor thread that may be required by the blocking +// work. 
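+//
+// Illustrative sketch (not from the upstream header; the kernel and member
+// names are assumptions): an `AsyncOpKernel` typically owns a
+// `BackgroundWorker` and schedules its blocking body onto it:
+//
+// ```
+// class MyBlockingOp : public AsyncOpKernel {
+//  public:
+//   explicit MyBlockingOp(OpKernelConstruction* ctx)
+//       : AsyncOpKernel(ctx),
+//         background_worker_(Env::Default(), "my_blocking_op") {}
+//
+//   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+//     background_worker_.Schedule([ctx, done]() {
+//       // ... do the blocking work, then signal completion.
+//       done();
+//     });
+//   }
+//
+//  private:
+//   BackgroundWorker background_worker_;
+// };
+// ```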
+// +// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this +// purpose because its current implementation (in Eigen) uses a finite-length +// queue and will block the caller when full. This can lead to deadlock under +// heavy load. Since the number of concurrent work items in each user of a +// `BackgroundWorker` is at most one per op invocation, the dynamic allocation +// overhead is tolerable. +class BackgroundWorker { + public: + BackgroundWorker(Env* env, const char* name); + + ~BackgroundWorker(); + + void Schedule(std::function work_item); + + private: + void WorkerLoop(); + + Env* const env_; + const char* const name_; + + std::unique_ptr thread_; + mutex mu_; + condition_variable cond_var_; + bool cancelled_ TF_GUARDED_BY(mu_) = false; + std::deque> work_queue_ TF_GUARDED_BY(mu_); +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/dataset_stateful_op_allowlist.h b/third_party/tflite-hdrs/tensorflow/core/framework/dataset_stateful_op_allowlist.h new file mode 100644 index 00000000..cc25c801 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/dataset_stateful_op_allowlist.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_ + +#include +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace data { +// Registry for stateful ops that need to be used in dataset functions. +// See below macro for usage details. +class AllowlistedStatefulOpRegistry { + public: + absl::Status Add(string op_name) { + op_names_.insert(std::move(op_name)); + return absl::OkStatus(); + } + + absl::Status Remove(string op_name) { + op_names_.erase(op_name); + return absl::OkStatus(); + } + + bool Contains(const string& op_name) { return op_names_.count(op_name); } + + static AllowlistedStatefulOpRegistry* Global() { + static auto* reg = new AllowlistedStatefulOpRegistry; + return reg; + } + + private: + AllowlistedStatefulOpRegistry() = default; + AllowlistedStatefulOpRegistry(AllowlistedStatefulOpRegistry const& copy) = + delete; + AllowlistedStatefulOpRegistry operator=( + AllowlistedStatefulOpRegistry const& copy) = delete; + + std::unordered_set op_names_; +}; + +} // namespace data + +// Use this macro to allowlist an op that is marked stateful but needs to be +// used inside a map_fn in an input pipeline. This is only needed if you wish +// to be able to checkpoint the state of the input pipeline. We currently +// do not allow stateful ops to be defined inside of map_fns since it is not +// possible to save their state. 
+// Note that the state of the allowlisted ops inside functions will not be +// saved during checkpointing, hence this should only be used if the op is +// marked stateful for reasons like to avoid constant folding during graph +// optimization but is not stateful. +// If possible, try to remove the stateful flag on the op first. +// Example usage: +// +// ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LegacyStatefulReader"); +// +#define ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS(name) \ + ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) +#define ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ + ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) +#define ALLOW_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ + static ::tensorflow::Status allowlist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::data::AllowlistedStatefulOpRegistry::Global()->Add(name) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/device.h b/third_party/tflite-hdrs/tensorflow/core/framework/device.h new file mode 100644 index 00000000..7b5bfcb1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/device.h @@ -0,0 +1,230 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A Device is a something that can perform computations as part of a +// model. Devices can be local (runs computation on this machine), or +// remote (contacts a device local to another machine using an RPC to +// do the work). Devices are registered in a DeviceSet, which is also +// responsible for the Device <-> id mapping. +// +// Device names +// * Every Device should have a unique name with the format: +// /job:___/replica:___/task:___/(gpu|cpu):___ +// An example name would be "/job:train/replica:0/task:3/device:GPU:2". +// * Task numbers are within the specified replica, so there are as +// many "task zeros" as replicas. 
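+//
+// Illustrative sketch (not from the upstream header): names in the format
+// above can be parsed with `DeviceNameUtils` (see the include below), e.g.
+//
+//   DeviceNameUtils::ParsedName parsed;
+//   bool ok = DeviceNameUtils::ParseFullName(
+//       "/job:train/replica:0/task:3/device:GPU:2", &parsed);
+//   // On success: parsed.job == "train", parsed.type == "GPU", parsed.id == 2.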
+ +#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ + +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class Device : public DeviceBase { + public: + // Callback type that takes a Status and returns void. + typedef std::function DoneCallback; + + Device(Env* env, const DeviceAttributes& device_attributes); + ~Device() override; + + // A compare function that orders devices by their parsed name. + static bool LessByParsedName(const Device& a, const Device& b) { + return a.parsed_name() < b.parsed_name(); + } + + // Full name of this device (see top comment). + const std::string& name() const override { return device_attributes_.name(); } + + // Parsed name of this device + const DeviceNameUtils::ParsedName& parsed_name() const override { + return parsed_name_; + } + + // Describes what kind of device this is. This is intended to be + // human-readable and not computer-parsed, except that two devices + // with the same device_type() are expected to perform similarly + // (both from a computation and communication perspective). + const std::string& device_type() const override { + return device_attributes_.device_type(); + } + + // Returns an aggregation of device attributes. + const DeviceAttributes& attributes() const override { + return device_attributes_; + } + + // Performs the actual compute function. + // + // Subclasses may override this function if they wish to perform + // some initialization before each compute. + virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) { + op_kernel->Compute(context); + } + + // Asynchronous kernel's compute. + virtual void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, + AsyncOpKernel::DoneCallback done) { + op_kernel->ComputeAsync(context, std::move(done)); + } + + // Blocks until all operations queued on the device at the time of + // the call have completed. Returns any error pending on the device + // at completion. + virtual absl::Status Sync() = 0; + + // Calls the given callback when all operations queued on the device at the + // time of the call have completed. The callback is passed any error pending + // on the device at completion. + // TODO(b/112409994): Consolidate these two APIs, removing the synchronous + // version. + virtual void Sync(const DoneCallback& done); + + // On session completion, the executor may call Device::Sync() depending on + // flag settings. Override this to return false for devices that don't allow + // such calls. Instead, these devices must use other mechanisms (such as + // num_deferred_ops) to ensure the device has finished processing necessary + // work at session completion. 
In addition, for these devices, RefreshStatus + // must be called at session completion to retrieve execution result status. + // + // Devices that override this function must also implement RefreshStatus. + virtual bool AllowsSyncOnCompletion() const { return true; } + + // This is used in conjunction with AllowsSyncOnCompletion to allow the + // executor to get execution result status at session completion. + // + // For supported devices, this call returns the underlying device stream's + // current status in a non-blocking way, without using blocking calls such as + // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device + // status is also updated with the retrieved stream status. + virtual absl::Status RefreshStatus() { + return errors::Unimplemented( + "RefreshStatus is not supported on this device."); + } + + // Optionally modify the device's GraphDef before execution. + // + // This method should be considered experimental and is supplied to enable + // prototyping of TensorFlow device implementations that need to modify + // the GraphDef before execution. + // + // 'graph' supplies the partition of the graph assigned to this + // device. + virtual absl::Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { + return absl::OkStatus(); + } + + // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr + // if the device does not support contexts. Returns an error status if any + // error occurred while trying to create a context, otherwise OK. + // + // The caller takes ownership of one reference on the output DeviceContext*, + // and should call Unref(). + virtual absl::Status TryGetDeviceContext(DeviceContext** out_context) { + *out_context = nullptr; + return absl::OkStatus(); + } + + // Returns the op segment of this device. The caller can reuse op + // kernels registered for the same session running on this device. + OpSegment* op_segment() { return &op_seg_; } + + // Returns the resource manager associated w/ this device. + virtual ResourceMgr* resource_manager() { return rmgr_; } + + // Summarizes the status of this Device, for debugging. + std::string DebugString() const { return device_attributes_.DebugString(); } + + // Assembles the parameter components into a complete DeviceAttributes value. + static DeviceAttributes BuildDeviceAttributes( + const std::string& name, DeviceType device, Bytes memory_limit, + const DeviceLocality& locality, const std::string& physical_device_desc); + + static DeviceAttributes BuildDeviceAttributes( + const std::string& name, DeviceType device, Bytes memory_limit, + const DeviceLocality& locality) { + // Pass in an empty string as physical device name. + return BuildDeviceAttributes(name, device, memory_limit, locality, ""); + } + + // Updates `attributes()`, indicating the XLA global ID associated with this + // device. This ID is unique across clients in a multi-client setup. For TPUs + // this does not happen until the TPU system has been initialized. + void set_xla_global_id(int64_t id) override { + device_attributes_.set_xla_global_id(id); + } + + // Clears the resource manager associated with this device. + void ClearResourceMgr() { rmgr_->Clear(); } + + virtual bool IsLocal() const { return true; } + + // Informs if this Device can be used as a caller in RemoteCall operation. + virtual bool IsRemoteCallAllowed() const; + + // Whether to merge the host_to_device copy stream with the compute stream. + // Only useful for GPU devices. 
+ virtual bool merge_host_to_device_stream() const { return false; } + + // Whether to merge the device_to_host copy stream with the compute stream. + // Only useful for GPU devices. + virtual bool merge_device_to_host_stream() const { return false; } + + // Whether to merge the device_to_device copy streams with the compute stream. + // Only useful for GPU devices. + virtual bool merge_device_to_device_stream() const { return false; } + + protected: + void DeleteResourceMgr() { + delete rmgr_; + rmgr_ = nullptr; + } + + private: + DeviceAttributes device_attributes_; + DeviceNameUtils::ParsedName parsed_name_; + + // op_seg_ maps session handle and op name to OpKernel objects. + OpSegment op_seg_; + + // Resources associated w/ this device. E.g., shared variables, etc. + ResourceMgr* rmgr_ = nullptr; + + Device(const Device&) = delete; + void operator=(const Device&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/device_base.h b/third_party/tflite-hdrs/tensorflow/core/framework/device_base.h new file mode 100644 index 00000000..fe5099fa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/device_base.h @@ -0,0 +1,313 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_BASE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_BASE_H_ + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace Eigen { +struct ThreadPoolDevice; +} // end namespace Eigen + +namespace stream_executor { +class Stream; +} // namespace stream_executor + +namespace tsl { +class Env; +namespace thread { +class ThreadPool; +} // namespace thread +} // namespace tsl +namespace tensorflow { + +class Device; +class DeviceAttributes; +class EventMgr; +class OpKernelContext; +class ResourceMgr; +class ScopedAllocatorMgr; +class TensorProto; + +// A wrapper for an Eigen Gpu Device that includes per-op state. The +// class is defined even for non-GPU devices since the +// OpKernelContext::Params structure wants to fill it in. +class PerOpGpuDevice { + public: + virtual ~PerOpGpuDevice() {} + virtual const Eigen::GpuDevice& device() const = 0; +}; + +// A class that devices can subclass to pass around +// Device-specific context to OpKernels. 
+class DeviceContext : public core::RefCounted { + public: + ~DeviceContext() override {} + virtual stream_executor::Stream* stream() const { return nullptr; } + virtual void MaintainLifetimeOnStream(const Tensor* t, + stream_executor::Stream* stream) const { + } + + // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into + // "device_tensor" which is on a non-CPU device "device". "device_tensor" + // must be allocated to be of the same size as "cpu_tensor". + virtual void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, + Tensor* device_tensor, StatusCallback done, + bool sync_dst_compute = true) const { + done(errors::Internal("Unrecognized device type in CPU-to-device Copy")); + } + + // Same as CopyCPUTensorToDevice, but in a synchronous way. + absl::Status CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, + Device* device, + Tensor* device_tensor) const; + + // Copies a tensor in this device. + virtual void CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("Copy in same device not implemented.")); + } + + // "device_tensor" is a tensor on a non-CPU device. Copies + // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated + // to be of the same size as "device_tensor". + virtual void CopyDeviceTensorToCPU(const Tensor* device_tensor, + absl::string_view tensor_name, + Device* device, Tensor* cpu_tensor, + StatusCallback done) { + done(errors::Internal("Unrecognized device type in device-to-CPU Copy")); + } + + // Same as `CopyDeviceTensorToCPU`, but blocks until the copy is done. + absl::Status CopyDeviceTensorToCPUSync(const Tensor* device_tensor, + absl::string_view tensor_name, + Device* device, Tensor* cpu_tensor); + + // If possible, wait for all events on *stream to complete then execute func. + // A non-OK Status is returned otherwise. The stream argument should be the + // one provided by AcceleratorDeviceInfo. This function is not applicable to + // devices that don't provide such a value. + virtual absl::Status ThenExecute(Device* device, + stream_executor::Stream* stream, + std::function func) { + return errors::Internal("ThenExecute not supported by device"); + } + + // check if device is a pluggable device + virtual bool IsPluggableDevice() { return false; } + + // Returns the pinned host memory allocator for the device. + virtual Allocator* host_memory_allocator() const { return nullptr; } +}; + +class DeviceBase { + public: + explicit DeviceBase(tsl::Env* env) : env_(env) {} + virtual ~DeviceBase(); + + tsl::Env* env() const { return env_; } + + struct CpuWorkerThreads { + int num_threads = 0; + tsl::thread::ThreadPool* workers = nullptr; + }; + + // Does not take ownership. + void set_tensorflow_cpu_worker_threads(CpuWorkerThreads* t) { + cpu_worker_threads_ = t; + } + + virtual const CpuWorkerThreads* tensorflow_cpu_worker_threads() const { + CHECK(cpu_worker_threads_ != nullptr); + return cpu_worker_threads_; + } + + // "stream" is used in special circumstances (such as the + // constructors of Ops) where there is no available OpKernelContext. + // "default_context" is used by OpKernelContext whenever a device does not + // supply a DeviceContext for an op in TryGetDeviceContext() (e.g. when only + // using a single stream.) + // "event_mgr" is used to delay deallocation of temporary GPU buffers. + // TODO(pbar) Work out how to move this out of DeviceBase. 
+ struct AcceleratorDeviceInfo { + // Make sure all the defaults are NULL, so we can spot missing assignments. + stream_executor::Stream* stream = nullptr; + DeviceContext* default_context = nullptr; + DeviceContext* pjrt_context = nullptr; + bool use_pjrt_tensor_buffer = false; + EventMgr* event_mgr = nullptr; + int gpu_id = -1; + }; + + // Does not take ownership. + void set_tensorflow_accelerator_device_info( + AcceleratorDeviceInfo* device_info) { + accelerator_device_info_ = device_info; + } + + virtual const AcceleratorDeviceInfo* tensorflow_accelerator_device_info() + const { + return accelerator_device_info_; + } + + // The preferred thread pool for this device. If it is nullptr, the system + // automatically assigns a thread pool for execution. + virtual tsl::thread::ThreadPool* tensorflow_device_thread_pool() { + return device_thread_pool_; + } + + // Does not take ownership. + void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d); + + // Return the Allocator implementation to use based on the allocator + // attributes requested. See allocator.h for more details. + virtual Allocator* GetAllocator(AllocatorAttributes /*attr*/) { + LOG(FATAL) << "GetAllocator() is not implemented."; + return nullptr; + } + + // This method is provided for backwards compatibility, and will be removed + // in a future release. + ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.") + Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) { + return GetAllocator(attr); + } + + // Return an Allocator prepared for use in particular places by graph + // optimization + virtual Allocator* GetScopedAllocator(AllocatorAttributes attr, + int64_t step_id) { + LOG(FATAL) << "Device does not implement GetScopedAllocator()"; + return nullptr; + } + + virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; } + + virtual bool has_eigen_cpu_device() const { + return !eigen_cpu_devices_.empty(); + } + + virtual const Eigen::ThreadPoolDevice* eigen_cpu_device(); + + // Caller owns the return value. The OpKernelContext calls this even + // for devices that do not implement an eigen_gpu_device. Overridden + // by GPU devices to return a derived type. + virtual PerOpGpuDevice* MakeGpuDevice() { return nullptr; } + + virtual DeviceBase* UnderlyingDevice() { return this; } + virtual const DeviceBase* UnderlyingDevice() const { return this; } + + // This is overridden by GPU devices to reinitialize the derived + // type returned by MakeGpuDevice. + virtual absl::Status ReinitializeGpuDevice(OpKernelContext* /*context*/, + PerOpGpuDevice* /*device*/, + DeviceContext* /*dc*/, + Allocator* /*allocator*/) { + return absl::OkStatus(); + } + + // Unimplemented by default + virtual const DeviceAttributes& attributes() const; + virtual int NumaNode() const { return attributes().locality().numa_node(); } + virtual const std::string& name() const; + virtual const DeviceNameUtils::ParsedName& parsed_name() const; + virtual const std::string& device_type() const; + + // Updates `attributes()`, indicating the XLA global ID associated with this + // device. This ID is unique across clients in a multi-client setup. For TPUs + // this does not happen until the TPU system has been initialized. + // + // Implemented in Device. + virtual void set_xla_global_id(int64_t id) {} + + // Materializes the given TensorProto into 'tensor' stored in Device + // memory. Most devices will want to override this. 
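A rough usage sketch (not part of the vendored header) of the blocking DeviceContext copy helpers declared above, CopyCPUTensorToDeviceSync and CopyDeviceTensorToCPUSync; the wrapper function and tensor names are illustrative, and the device-side tensor is assumed to be pre-allocated with the same size as the host tensor, as the method comments require.

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Sketch only: round-trip a host tensor through a non-CPU device with the
// blocking helpers; the asynchronous CopyCPUTensorToDevice and
// CopyDeviceTensorToCPU overloads take a StatusCallback instead.
absl::Status RoundTripThroughDevice(DeviceContext* dc, Device* device,
                                    const Tensor& host_in, Tensor* device_tmp,
                                    Tensor* host_out) {
  // "device_tmp" must already be allocated on "device" with host_in's size,
  // per the contract documented on the copy methods above.
  absl::Status s = dc->CopyCPUTensorToDeviceSync(&host_in, device, device_tmp);
  if (!s.ok()) return s;
  return dc->CopyDeviceTensorToCPUSync(device_tmp, /*tensor_name=*/"round_trip",
                                       device, host_out);
}

}  // namespace tensorflow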
+ // + // TODO(vrv): We should be able to put this function into + // OpKernelContext and handle the copies from device memory via send + // and receive nodes, instead of requiring that each device handle + // the copies here as well as in copy ops. + virtual absl::Status MakeTensorFromProto( + const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, + Tensor* tensor) { + return errors::Internal("Device does not implement MakeTensorFromProto()"); + } + + // Some devices (i.e. GPUs) may free device memory prior to its actual use + // being completed on the assumption that subsequent allocations can only be + // used serially with respect to pending uses. If this function returns a + // non-zero value it is the value of a device-specific counter such that any + // device memory tagged with an earlier freed-at count is really unencumbered + // by pending uses. For this to be useful the device memory allocator must + // be tagging deallocated memory chunks using the same counter. + virtual uint64 SafeAllocFrontier(uint64 old_value) { return 0; } + + // Copies `input_tensor` to `output_tensor`, where both tensors are on this + // device. This function assumes that `output_tensor` has already been + // allocated with a buffer that is large enough to hold `input_tensor`'s data. + // Calls `done` from a device-specific thread after copy is finished, which + // may be the same as calling thread. + // + // NOTE(ayushd): This function is for TensorFlow internal use only. Deep copy + // is discouraged and should not be used in OpKernels. + virtual void CopyTensorInSameDevice(const Tensor* input_tensor, + Tensor* output_tensor, + const DeviceContext* device_context, + StatusCallback done) { + done(errors::Internal("Device ", name(), " does not implement ", + "CopyTensorInSameDevice")); + } + + protected: + // Does not take ownership. + void set_tensorflow_device_thread_pool(tsl::thread::ThreadPool* thread_pool) { + device_thread_pool_ = thread_pool; + } + + private: + tsl::Env* const env_; + CpuWorkerThreads* cpu_worker_threads_ = nullptr; + // Set by GPUs as well as by TPU devices. + AcceleratorDeviceInfo* accelerator_device_info_ = nullptr; + tsl::thread::ThreadPool* device_thread_pool_ = nullptr; + std::vector eigen_cpu_devices_; +}; + +// Methods to create and check for Symbolic execution devices. +// Such devices are mostly used for TF-XLA bridge. TF should not treat these as +// normal devices. +void AddSymbolicExecutionDevice(absl::string_view device_name); +bool IsSymbolicExecutionDevice(absl::string_view device_name); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/device_factory.h b/third_party/tflite-hdrs/tensorflow/core/framework/device_factory.h new file mode 100644 index 00000000..8b07d15c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/device_factory.h @@ -0,0 +1,173 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ + +#include +#include + +#include "absl/base/attributes.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Device; +struct SessionOptions; + +class DeviceFactory { + public: + virtual ~DeviceFactory() {} + static void Register(const std::string& device_type, + std::unique_ptr factory, int priority, + bool is_pluggable_device); + ABSL_DEPRECATED("Use the `Register` function above instead") + static void Register(const std::string& device_type, DeviceFactory* factory, + int priority, bool is_pluggable_device) { + Register(device_type, std::unique_ptr(factory), priority, + is_pluggable_device); + } + static DeviceFactory* GetFactory(const std::string& device_type); + + // Append to "*devices" CPU devices. + static absl::Status AddCpuDevices( + const SessionOptions& options, const std::string& name_prefix, + std::vector>* devices); + + // Append to "*devices" all suitable devices, respecting + // any device type specific properties/counts listed in "options". + // + // CPU devices are added first. + static absl::Status AddDevices(const SessionOptions& options, + const std::string& name_prefix, + std::vector>* devices); + + // Helper for tests. Create a single device of type "type". The + // returned device is always numbered zero, so if creating multiple + // devices of the same type, supply distinct name_prefix arguments. + static std::unique_ptr NewDevice(const string& type, + const SessionOptions& options, + const string& name_prefix); + + // Iterate through all device factories and build a list of all of the + // possible physical devices. + // + // CPU is are added first. + static absl::Status ListAllPhysicalDevices(std::vector* devices); + + // Iterate through all device factories and build a list of all of the + // possible pluggable physical devices. + static absl::Status ListPluggablePhysicalDevices( + std::vector* devices); + + // Get details for a specific device among all device factories. + // 'device_index' indexes into devices from ListAllPhysicalDevices. + static absl::Status GetAnyDeviceDetails( + int device_index, std::unordered_map* details); + + // For a specific device factory list all possible physical devices. + virtual absl::Status ListPhysicalDevices(std::vector* devices) = 0; + + // Get details for a specific device for a specific factory. Subclasses + // can store arbitrary device information in the map. 'device_index' indexes + // into devices from ListPhysicalDevices. + virtual absl::Status GetDeviceDetails( + int device_index, std::unordered_map* details) { + return absl::OkStatus(); + } + + // Most clients should call AddDevices() instead. + virtual absl::Status CreateDevices( + const SessionOptions& options, const std::string& name_prefix, + std::vector>* devices) = 0; + + // Return the device priority number for a "device_type" string. + // + // Higher number implies higher priority. + // + // In standard TensorFlow distributions, GPU device types are + // preferred over CPU, and by default, custom devices that don't set + // a custom priority during registration will be prioritized lower + // than CPU. 
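A rough sketch of the DeviceFactory entry points declared above, in the style of TensorFlow test code; the "/job:localhost/replica:0/task:0" name prefix and the wrapper function names are illustrative assumptions, not part of this header.

#include <memory>
#include <vector>

#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Sketch only: the single-device helper intended for tests, plus the general
// AddDevices() path that honors per-type device settings in SessionOptions.
std::unique_ptr<Device> MakeTestCpuDevice() {
  SessionOptions options;
  return DeviceFactory::NewDevice("CPU", options,
                                  "/job:localhost/replica:0/task:0");
}

absl::Status MakeAllLocalDevices(
    std::vector<std::unique_ptr<Device>>* devices) {
  SessionOptions options;
  return DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
                                   devices);
}

}  // namespace tensorflow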
Custom devices that want a higher priority can set the + // 'priority' field when registering their device to something + // higher than the packaged devices. See calls to + // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used + // for built-in devices. + static int32 DevicePriority(const std::string& device_type); + + // Returns true if 'device_type' is registered from plugin. Returns false if + // 'device_type' is a first-party device. + static bool IsPluggableDevice(const std::string& device_type); +}; + +namespace dfactory { + +template +class Registrar { + public: + // Multiple registrations for the same device type with different priorities + // are allowed. Priorities are used in two different ways: + // + // 1) When choosing which factory (that is, which device + // implementation) to use for a specific 'device_type', the + // factory registered with the highest priority will be chosen. + // For example, if there are two registrations: + // + // Registrar("CPU", 125); + // Registrar("CPU", 150); + // + // then CPUFactory2 will be chosen when + // DeviceFactory::GetFactory("CPU") is called. + // + // 2) When choosing which 'device_type' is preferred over other + // DeviceTypes in a DeviceSet, the ordering is determined + // by the 'priority' set during registration. For example, if there + // are two registrations: + // + // Registrar("CPU", 100); + // Registrar("GPU", 200); + // + // then DeviceType("GPU") will be prioritized higher than + // DeviceType("CPU"). + // + // The default priority values for built-in devices is: + // GPU: 210 + // GPUCompatibleCPU: 70 + // ThreadPoolDevice: 60 + // Default: 50 + explicit Registrar(const std::string& device_type, int priority = 50) { + DeviceFactory::Register(device_type, std::make_unique(), priority, + /*is_pluggable_device*/ false); + } +}; + +} // namespace dfactory + +#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + __COUNTER__, ##__VA_ARGS__) + +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \ + ctr, ...) \ + static ::tensorflow::dfactory::Registrar \ + INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type, ##__VA_ARGS__) + +// __COUNTER__ must go through another macro to be properly expanded +#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_ + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/fake_input.h b/third_party/tflite-hdrs/tensorflow/core/framework/fake_input.h new file mode 100644 index 00000000..c3062762 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/fake_input.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_ +#define TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// These functions return values that may be passed to +// NodeDefBuilder::Input() to add an input for a test. Use them when +// you don't care about the node names/output indices providing the +// input. They also allow you to omit the input types and/or +// list length when they may be inferred. +FakeInputFunctor FakeInput(); // Infer everything +FakeInputFunctor FakeInput(DataType dt); +FakeInputFunctor FakeInput(int n); // List of length n +FakeInputFunctor FakeInput(int n, DataType dt); +FakeInputFunctor FakeInput(DataTypeSlice dts); +inline FakeInputFunctor FakeInput(std::initializer_list dts) { + return FakeInput(DataTypeSlice(dts)); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/full_type_inference_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/full_type_inference_util.h new file mode 100644 index 00000000..3117613b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/full_type_inference_util.h @@ -0,0 +1,159 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +namespace full_type { + +// TODO(mdan): Specific helpers won't get too far. Use a parser instead. + +// Helpers that allow shorthand expression for the more common kinds of type +// inference functions. +// TODO(mdan): Break into separate header if it grows. +// Note: The information contained in these functions is also expressed to some +// extent by opdef attributes of the kind "input: T, output T". But in that +// context, T has strong DType semantics (i.e. T is DT_VARIANT for most +// interesting cases). The logic here extends to the op's FullType, so it's best +// to keep them separate, even though it leads to some redundancy. The +// same can be said about the shape inference function. + +// Note: Unlike type constructors, which describe op definitions, type inference +// functions are meant to modify the type information of specific nodes (i.e. +// NodeDef proto). + +// Helper for a no-op type inference function that indicates type inference +// should never alter the node's existing type. +// This is the same as not defining a type inference function at all, but +// explicitly communicates that intent. 
+TypeInferenceFn KeepExisting(); + +// A helper for a type inference function that indicates a single output that +// is a tensor of type t. This is the equivalent of a type construtor since it +// does not depend on inputs. This can be used with Tuple. +TypeInferenceFn Tensor(FullTypeId t); + +// Helper for a type inference function which has the same type as the i'th +// input. +// The n arg allows multiple outputs, e.g. (T -> Product[T, T]). +// TODO(mdan): Drop defaults for readability if more non-(0, 1) cases appear. +// TODO(mdan): Rename to just Replicate. +TypeInferenceFn ReplicateInput(int i = 0, int n = 1); + +// Helper for a type inference function which has the same type as a variadic +// number of inputs, e.g. (T, T -> Product[T]), (T, T, T -> Product[T]), etc. +// Infers the meet of the input types, in the sense of type meets (see +// https://en.wikipedia.org/wiki/Join_and_meet). This implementation is +// simplified to require the two inputs are a subtype of another. +TypeInferenceFn Merge(); + +// Helper for ops with semantics of encoding an input, that is, +// `T -> Encoded[T, ]`, where is the encoded type. +TypeInferenceFn Encode(FullTypeId t, int i); + +// Helper for ops with semantics of encoding an input, that is, +// `Encoded[T, ] -> T`, where is the encoded type. +TypeInferenceFn Decode(FullTypeId t, int i); + +// Helper for the type inference counterpart of Unary, that is (U -> +// PRODUCT[[U]]), where is parameterized by this factory, and U is the +// type of the input specified by element_idx. +// Note: when we migrate to a more formal type definition of an op, these two +// functions will naturally merge. +TypeInferenceFn UnaryContainerCreate(FullTypeId t, int element_idx); + +// Helper for ops with semantics of adding an element to a container ([T]), +// that is ([U], V -> PRODUCT[[Union[U, V]]]), where is parameterized +// by this factory, U is the type of the input specified by container_idx, and V +// is the type of the input specified by element_idx. The homogeneous arg allows +// for constraints which guarantee that U and V must have a subtyping +// relationship, case in which either V or U is selected, whichever is the +// supertype. +TypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx, + int element_idx, bool homogeneous); + +// Helper for ops with semantics of unstacking multiple inputs into a container +// `[T1, ..., Tn]`, that is `T1, ..., Tn -> [PRODUCT[U1, ..., Un]]` +// where Ui is obtained from an "unstack" mapping T -> U. Both and the +// "unstack" mapping are parameterized by this factory. +// Note that when the "unstack" function is the identity function, this becomes +// equivalent to ContainerCreate. +TypeInferenceFn MultiaryUnstack( + FullTypeId t, std::function unstack); + +// Helper for ops with semantics of applying some transformation to the +// elements of a container: +// `[PRODUCT[T1, ..., Tn]] -> [PRODUCT[U1, ..., Un]]`, +// where Ui is obtained by applying a map T -> U. Both and the "map" +// function are parameterized by this factory. See BatchTensor and ShardTensor +// for examples of "map". +TypeInferenceFn ContainerMap( + FullTypeId t, int input_idx, + std::function map); + +// Helper for ops with semantics of repacking some element from a container to +// another ` -> `, in a covariant way, that is, `[T] -> [T]`. +// and are parameterized by this factory. The input type is specified by +// element_idx. 
+TypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx); + +// Helper for ops with semantics of calling a function. The function is +// specified indirectly, as the name of an attribute that holds the actual +// function name. +TypeInferenceFn FunctionCall(const string& func_attr_name); + +// Compose the type of a function by concatenating the outputs of multiple +// type inference functions. If func_list is {type inference function 1, type +// inference function 2} which return PRODUCT[T1], PRODUCT[T2] resprectively, +// the result is PRODUCT[T1, T2], This supports the Merge op that has an index +// output in addition to the result of the Merge type inference function. +TypeInferenceFn Tuple(const std::vector& func_list); + +// Auxiliary constructs to help creation of type inference functions. +// TODO(mdan): define these as type inference functions as well. + +// Mapping function representing the type function for unstacking of +// Tensor (or Tensor-like) types. Note that this is a helper to use with +// other type inference functions; it's not a function itself. +// TODO(mdan): Replace with a trait, when available. +FullTypeDef UnstackTensor(const FullTypeDef& t); + +// Mapping function representing the type function for an op that changes the +// batch size of dataset. Note that this is a helper to use with other type +// inference functions; it's not a function itself. +// TODO(mdan): Replace with a trait, when available. +FullTypeDef BatchTensor(const FullTypeDef& t); + +// Mapping function representing the type function for an op that creates a +// fixed (given) number of tensors of a size calculated based on the input. Note +// that this is a helper to use with other type inference functions; it's not a +// function itself. +// TODO(mdan): Replace with a trait, when available. +FullTypeDef ShardTensor(const FullTypeDef& t); +} // namespace full_type + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_INFERENCE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/full_type_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/full_type_util.h new file mode 100644 index 00000000..4039f3c8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/full_type_util.h @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_UTIL_H_ + +#include +#include + +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +namespace full_type { + +// TODO(mdan): Specific helpers won't get too far. Use a parser instead. 
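As a rough illustration of how the inference helpers above are consumed, the sketch below registers a hypothetical pass-through op whose output FullType is a copy of its input's, assuming the SetForwardTypeFn hook on OpDefBuilder (op_def_builder.h is already included by this header); the op name is made up for illustration.

#include "tensorflow/core/framework/full_type_inference_util.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Sketch only: output "y" gets the same FullType as input "x" via
// ReplicateInput() declared above.
REGISTER_OP("ExamplePassThrough")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: type")
    .SetForwardTypeFn(full_type::ReplicateInput());

}  // namespace tensorflow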
+// TODO(mdan): Move constructors into a separate file. + +// Helpers that allow shorthand expression for the more common kinds of type +// constructors. +// Note: The arity below refers to the number of arguments of parametric types, +// not to the number of return values from a particular op. +// Note: Type constructors are meant to create static type definitions in the +// op definition (i.e. the OpDef proto). + +// Helper for a no-op type constructor that indicates that the node's type +// should be set by external means (typically by the user). +OpTypeConstructor NoOp(); + +// Helper for a trivial type constructor that indicates a node has no +// outputs (that is, its output type is an empty TFT_PRODUCT). +OpTypeConstructor NoOutputs(); + +// Helper for a type constructor of [] (with no parameters). +OpTypeConstructor Nullary(FullTypeId t); + +// Helper for a type constructor of [FT_VAR[]]. +OpTypeConstructor Unary(FullTypeId t, const string& var_name); + +// Helper for a type constructor of [FT_ANY]. +OpTypeConstructor UnaryGeneric(FullTypeId t); + +// Helper for a type constructor of [FT_TENSOR[]]. +OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype); + +// Helper for a type constructor of [FT_VAR[]]. +OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name); + +// Helper for a type constructor of +// [FT_FOR_EACH[ +// FT_PRODUCT, +// FT_TENSOR[FT_VAR[]], +// FT_VAR[]]. +// Multi-valued type variables will expand the template (see full_type.proto). +OpTypeConstructor VariadicTensorContainer(FullTypeId t, const string& var_name); + +// Type specialization and inference logic. This function narrows the type +// specified in an op definition. Such types are usually generic and dependent +// on input types. This function resolves the output types based on the input +// types specified in a given node def. +absl::Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, + FullTypeDef& target); + +const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i); +const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i); + +bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs); + +bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs, + bool covariant = true); + +uint64_t Hash(const FullTypeDef& arg); + +// Determine if the given fulltype is a host memory type. +// While it is prefered that Placer (placer.cc and colocation_graph.cc) make +// all host memory type placement decisions, any decision made elsewhere +// should use this function (e.g. instead of assuming that all variants never +// contain host memory types). 
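A small sketch of the comparison and hashing helpers declared above, applied to hand-built FullTypeDef protos; the function name is illustrative and the TFT_* values are the standard full_type.proto enumerators.

#include <cstdint>

#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_util.h"

namespace tensorflow {

// Sketch only: build TFT_TENSOR[TFT_FLOAT] by hand, hash it, and compare it
// against a copy with the predicates declared above.
bool ExampleFullTypeChecks() {
  FullTypeDef tensor_of_float;
  tensor_of_float.set_type_id(TFT_TENSOR);
  tensor_of_float.add_args()->set_type_id(TFT_FLOAT);

  FullTypeDef copy = tensor_of_float;
  const uint64_t hash = full_type::Hash(tensor_of_float);  // e.g. a cache key
  return full_type::IsEqual(tensor_of_float, copy) &&
         hash == full_type::Hash(copy);
}

}  // namespace tensorflow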
+inline bool IsHostMemoryType(const FullTypeDef& t) { + switch (t.type_id()) { + case TFT_TENSOR: + return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0)); + case TFT_ARRAY: + return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0)); + case TFT_DATASET: + return true; + case TFT_MUTEX_LOCK: + return true; + case TFT_RAGGED: + return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0)); + case TFT_STRING: + return true; + case TFT_ITERATOR: + return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0)); + case TFT_OPTIONAL: + return IsHostMemoryType(full_type::GetArgDefaultAny(t, 0)); + case TFT_PRODUCT: + for (int i = 0; i < t.args_size(); i++) { + if (IsHostMemoryType(full_type::GetArgDefaultAny(t, i))) { + return true; + } + } + return false; + default: + return false; + } +} + +} // namespace full_type + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_FULL_TYPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/function.h b/third_party/tflite-hdrs/tensorflow/core/framework/function.h new file mode 100644 index 00000000..8c77af38 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/function.h @@ -0,0 +1,1260 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_ +#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_ + +#include +#include +#include +#include +#include + +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/platform/platform.h" +// clang-format on + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/optimized_function_graph.pb.h" +#include "tensorflow/core/framework/registration/registration.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/random.h" +#include "tensorflow/core/platform/stack_frame.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tensorflow/core/protobuf/config.pb.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" +#endif // IS_MOBILE_PLATFORM + +namespace tensorflow { + +class CollectiveExecutor; +class DeviceSet; +class Graph; +class GraphDef; +class OpKernel; +class ProcessFunctionLibraryRuntime; +class ResourceMgr; +class Rendezvous; +class ScopedStepContainer; +class StepStatsCollectorInterface; +class Node; + +// FunctionDefHelper::Create is a convenient helper to construct a +// FunctionDef proto. +// E.g., +// FunctionDef my_func = FunctionDefHelper::Create( +// "my_func_name", +// {"x:T", "y:T" /* one string per argument */}, +// {"z:T" /* one string per return value */}, +// {"T: {float, double}" /* one string per attribute */}, +// { +// {{"o"}, "Mul", {"x", "y"}, {{"T", "$T"}}} +// /* one entry per function node */ +// }, +// /* Mapping between function returns and function node outputs. */ +// {{"z", "o:z"}}); +// +// For the old Function::Node approach, use FunctionDefHelper::Define() +// E.g., +// FunctionDef my_func = FunctionDefHelper::Define( +// "my_func_name", +// {"x:T", "y:T" /* one string per argument */}, +// {"z:T" /* one string per return value */}, +// {"T: {float, double}" /* one string per attribute */}, +// { +// {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}} +// /* one entry per function node */ +// }); +class FunctionDefHelper { + public: + // AttrValueWrapper has copy constructors for the type T so that + // it's easy to construct a simple AttrValue proto. + // + // If T is a string type (const char*, string, or StringPiece), and + // it starts with "$", we construct a AttrValue of "placeholder". + // + // E.g., + // std:: x = {"T", "$T"} + // is a named attr value placeholder. 
+ struct AttrValueWrapper { + AttrValue proto; + + AttrValueWrapper() {} + + template + AttrValueWrapper(T val) { // NOLINT(runtime/explicit) + SetAttrValue(val, &proto); + } + + private: + void InitFromString(absl::string_view val); + }; + + // Constructs an AttrValue.func given the "name" and "attrs". + static AttrValueWrapper FunctionRef( + const std::string& name, + absl::Span> attrs); + static AttrValueWrapper FunctionRef(const std::string& name) { + return FunctionRef(name, {}); + } + + // Node is used to construct FunctionDef.Node using initialization + // lists. E.g., + // Node n = {{"z"}, "Mul", {"x", "y"}, {{"T", "$T"}}}; // z = x * y + // + // If the op has no inputs, then name is be specified. + // Node n = {{}, "AssignVariable", {"resource", "val"}, {{"dtype", + // "DT_FLOAT"}, + // {"update0"}, "CPU:0", "update1"}} + struct Node { + // When constructing a NodeDef, the first entry in ret is used as + // the node name, the remaining values are ignored. + std::vector ret; + std::string op; + std::vector arg; + std::vector> attr; + std::vector dep; + std::string device; + + // Required if the op has zero outputs. Otherwise, ret[0] used as name if + // name is left empty. + std::string name; + + std::string GetName() const { + if (!name.empty()) return name; + CHECK(!ret.empty()); + return ret[0]; + } + std::vector original_node_names; + std::vector original_func_names; + + NodeDef ToNodeDef() const; + }; + + // Creates a FunctionDef from the given parameters. Node inputs must use + // function encoding (node_name:output_name[:output_index]). + // - `ret_def` holds a mapping from the function output names from `out_def` + // to the node outputs from `node_def`. + // - `control_ret_def` holds a mapping from the function control + // output names to the nodes from `node_def`. + static FunctionDef Create( + const std::string& function_name, absl::Span in_def, + absl::Span out_def, absl::Span attr_def, + absl::Span node_def, + absl::Span> ret_def, + absl::Span> control_ret_def); + + // Creates a FunctionDef from the given parameters. Node inputs must use + // function encoding (node_name:output_name[:output_index]). + // - `ret_def` holds a mapping from the function output names from `out_def` + // to the node outputs from `node_def`. + static FunctionDef Create( + const std::string& function_name, absl::Span in_def, + absl::Span out_def, absl::Span attr_def, + absl::Span node_def, + absl::Span> ret_def); + + // TODO(josh11b): Get rid of these and transition to the one above. + static FunctionDef Define(const std::string& function_name, + absl::Span arg_def, + absl::Span ret_def, + absl::Span attr_def, + absl::Span node_def); + + // Defines an anonymous function. I.e., its name is not relevant. + static FunctionDef Define(absl::Span arg_def, + absl::Span ret_def, + absl::Span attr_def, + absl::Span node_def); + + // Helpers to construct a constant scalar. 
+ template + static Node Const(const std::string& name, const T& val) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum::value; + n.attr.push_back({"dtype", dtype}); + Tensor t(dtype, TensorShape({})); + t.scalar()() = val; + n.attr.push_back({"value", t}); + return n; + } + + template + static Node Const(const std::string& name, gtl::ArraySlice vals) { + Node n = {{name}, "Const"}; + const DataType dtype = DataTypeToEnum::value; + n.attr.push_back({"dtype", dtype}); + int64_t num = vals.size(); + Tensor t(dtype, TensorShape({num})); + for (size_t i = 0; i < vals.size(); ++i) { + t.flat()(i) = vals[i]; + } + n.attr.push_back({"value", t}); + return n; + } +}; + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(const char* val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( + const std::string& val) { + InitFromString(val); +} + +template <> +inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper( + absl::string_view val) { + InitFromString(val); +} + +// Instantiate a function. +// +// "fdef" encodes a TF function with some attrs in fdef.signature.attr +// containing placeholders. InstantiateFunction binds these +// placeholders and produces an instantiated function encoded in +// "result.gdef". The value to substitute a placeholder is given by +// "attr_values", which is a map from a placeholder name to an attr +// value. +// +// InstantiateFunction calls "get_function" to find signatures of other +// functions and primitive ops. + +// GetFunctionSignature(func name, opdef) returns OK if the func name is found +// and opdef is filled with a pointer to the corresponding signature +// (a OpDef proto). Otherwise, returns an error. +typedef std::function + GetFunctionSignature; + +struct InstantiationResult { + DataTypeVector arg_types; + DataTypeVector ret_types; + std::vector nodes; +}; +absl::Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); + +// Returns a debug string for a function definition. +// +// The returned text is multiple-line. It is intended to be +// human-readable rather than being friendly to parsers. It is _NOT_ +// intended to be the canonical string representation of "func_def". +// Particularly, it may not include all information presented in +// "func_def" (e.g., comments, description of the function arguments, +// etc.) +std::string DebugString(const FunctionDef& func_def); +std::string DebugString(const GraphDef& instantiated_func_def); +std::string DebugString(absl::Span instantiated_func_nodes); + +// Returns a debug string for a top level graph (the main program and +// its supporting functions defined in its library). +std::string DebugStringWhole(const GraphDef& gdef); + +// Returns true if f1 == f2. Compares all fields, including descriptions. Order +// of NodeDefs doesn't matter. +bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2); + +// Return a hash of `fdef` that is consistent with FunctionDefsEqual method. +// In other words, if two fdefs compare equal, their hash values will be the +// same. 
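The comment block at the top of FunctionDefHelper already shows the initializer-list shape; as a slightly more concrete sketch, the function below builds z = x * 2 with Create() and the scalar Const helper declared above. The function name and node names are illustrative only.

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {

// Sketch only: z = x * 2, assembled with Create() and the scalar Const
// helper. Node outputs use the "node_name:output_name:index" encoding
// described in the Create() comment above.
FunctionDef ExampleTimesTwo() {
  return FunctionDefHelper::Create(
      "ExampleTimesTwo",
      /*in_def=*/{"x: float"},
      /*out_def=*/{"z: float"},
      /*attr_def=*/{},
      /*node_def=*/
      {
          FunctionDefHelper::Const<float>("two", 2.0f),
          {{"mul"}, "Mul", {"x", "two:output:0"}, {{"T", DT_FLOAT}}},
      },
      /*ret_def=*/{{"z", "mul:z:0"}});
}

}  // namespace tensorflow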
+uint64 FunctionDefHash(const FunctionDef& fdef); + +class CallFrameInterface { + public: + virtual ~CallFrameInterface() {} + + virtual size_t num_args() const = 0; + virtual size_t num_retvals() const = 0; + + virtual absl::Status GetArg(int index, const Tensor** val) = 0; + + // Optimized implementation of `GetArg()` that allows the caller to take + // ownership of the tensor. This method may only be called once per + // value of `index` and `CallFrameInterface` instance. + // + // REQUIRES: `this->CanConsumeArg(index) == true`. + virtual void ConsumeArg(int index, Tensor* val) { + LOG(ERROR) << "This `CallFrameInterface` implementation does not support " + "consuming arguments."; + } + virtual bool CanConsumeArg(int index) const { return false; } + + virtual absl::Status SetRetval(int index, const Tensor& val) = 0; +}; + +// Represents a function call frame. I.e., the data structure used to +// pass arguments to a function and retrieve its results. +// +// Runtime must arrange accesses to one FunctionCallFrame s.t. +// 1. SetArgs() happens before any GetArg(); +// 2. GetRetvals happens after all SetRetval(); +class FunctionCallFrame : public CallFrameInterface { + public: + FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types); + ~FunctionCallFrame() override; + + // Caller methods. + absl::Status SetArgs(absl::Span args); + absl::Status GetRetvals(std::vector* rets) const; + + // Moves the return values from the frame to rets. If allow_dead_tensors is + // false it will fail if any of the retvals do not have a value. + absl::Status ConsumeRetvals(std::vector* rets, + bool allow_dead_tensors); + + size_t num_args() const override { return arg_types_.size(); } + size_t num_retvals() const override { return ret_types_.size(); } + + // Callee methods. + absl::Status GetArg(int index, const Tensor** val) override; + absl::Status SetRetval(int index, const Tensor& val) override; + + private: + DataTypeVector arg_types_; + DataTypeVector ret_types_; + absl::InlinedVector args_; + struct Retval { + bool has_val = false; + Tensor val; + }; + absl::InlinedVector rets_; + + FunctionCallFrame(const FunctionCallFrame&) = delete; + void operator=(const FunctionCallFrame&) = delete; +}; + +// Map of function names to StackTracesMaps. +using FunctionDefLibraryStackTraces = + absl::flat_hash_map; + +// Holds Function information that can be shared in multiple places. +// FunctionRecord must be explicitly finalized before being saved in +// FunctionLibraryDefinition or any other place that expects immutability. +class FunctionRecord : public core::RefCounted { + public: + FunctionRecord(const FunctionDef& fdef, const StackTracesMap& stack_traces, + bool finalized); + FunctionRecord(FunctionDef&& fdef, StackTracesMap&& stack_traces, + bool finalized); + + // Mark FunctionRecord as finalized (disable mutation). + void finalize(); + + // Get a mutable reference to the FunctionDef owned by the record. + // Will fail if record is finalized. + absl::StatusOr mutable_fdef(); + + // Get an immutable access to FunctionRecord properties. + const FunctionDef& fdef() const; + const StackTracesMap& stack_traces() const; + const OpRegistrationData& op_registration_data() const; + const bool finalized() const; + + private: + bool finalized_ = false; + + FunctionDef fdef_; + const StackTracesMap stack_traces_; + const OpRegistrationData op_registration_data_; +}; + +// Helper to maintain a map between function names in a given +// FunctionDefLibrary and function definitions. 
+// +// This class is thread-safe. +class FunctionLibraryDefinition : public OpRegistryInterface { + public: + // Ops created for function arguments bear the name given by `kArgOp`; those + // created for return values bear the name given by `kRetOp`. + static constexpr const char* const kArgOp = "_Arg"; + static constexpr const char* const kDeviceArgOp = "_DeviceArg"; + static constexpr const char* const kRetOp = "_Retval"; + static constexpr const char* const kDeviceRetOp = "_DeviceRetval"; + static constexpr const char* const kIntsOnDeviceAttr = + "experimental_ints_on_device"; + static constexpr const char* const kSharedRendezvousAttr = + "shared_rendezvous"; + + static constexpr const char* const kGradientOp = "SymbolicGradient"; + static constexpr const char* const kFuncAttr = "f"; + + // Note: This constructor grabs `lib_def`'s lock in shared mode. + FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def); + explicit FunctionLibraryDefinition( + const OpRegistryInterface* default_registry, + const FunctionDefLibrary& lib_def = {}, + const FunctionDefLibraryStackTraces& library_traces = {}); + FunctionLibraryDefinition(const OpRegistryInterface* default_registry, + const GraphDef& graph_def); + ~FunctionLibraryDefinition() override; + + FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) = + delete; + FunctionLibraryDefinition& operator=(FunctionLibraryDefinition&& other); + + // Returns True if the library contains `func`, False otherwise. + bool Contains(const std::string& func) const TF_LOCKS_EXCLUDED(mu_); + + // Returns nullptr if "func" is not defined in "lib_def". Otherwise, + // returns its definition proto. + // + // NB: This function returns a borrowed pointer, which can be invalidated by a + // subsequent call to `ReplaceFunction()` with the given name. + const FunctionDef* Find(const std::string& func) const TF_LOCKS_EXCLUDED(mu_); + + // Returns nullptr if "func" is not defined in "lib_def". Otherwise, + // returns a strong reference pointer to the FunctionRecord in the library. + core::RefCountPtr FindRecord(const std::string& func) const + TF_LOCKS_EXCLUDED(mu_); + + // Adds function definition 'fdef' to this function library. + // Returns status 'ok' on success, or error otherwise. This is a no-op if + // 'fdef' already exists in this function library. + // If 'fdef' is successfully added to the library, it will be accessible + // from 'LookUp' and included in the proto returned by 'ToProto'. + // This operation is atomic. + // + // Associates `graph` with a function `func_name`. Lifetime assumption: + // `graph` has to outlive all instantiated graphs. + absl::Status AddFunctionDef(const FunctionDef& fdef, + const StackTracesMap& stack_traces = {}) + TF_LOCKS_EXCLUDED(mu_); + absl::Status AddFunctionDef(FunctionDef&& fdef, + StackTracesMap&& stack_traces = {}) + TF_LOCKS_EXCLUDED(mu_); + absl::Status AddFunctionRecord(core::RefCountPtr record) + TF_LOCKS_EXCLUDED(mu_); + + // Adds gradient definition 'grad' to this function library. + // This is a no-op if 'grad' already exists in this function library. + // If 'grad' is successfully added, it will be accessible via 'FindGradient' + // and included in the proto returned by 'ToProto'. + // This operation is atomic. + absl::Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_); + + // Replaces the function corresponding to `func` with `fdef`. Returns + // a non-OK status if "func" was not found in the library, OK otherwise. 
+ // Please be careful when replacing function: make sure all previous pointers + // returned by `Find()` are no longer in use. + absl::Status ReplaceFunction(const std::string& func, const FunctionDef& fdef, + const StackTracesMap& stack_traces = {}) + TF_LOCKS_EXCLUDED(mu_); + + // Replaces the gradient corresponding to `grad.function_name()`. Returns + // a non-OK status if "grad.function_name()" was not found in the library, OK + // otherwise. + absl::Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_); + + // Removes the function corresponding to 'func'. Returns a non-OK status if + // 'func' was not found in the library, OK otherwise. + // Please be careful when removing function: make sure there are no other + // nodes using the function, and all previous pointers returned by `Find()` + // are no longer in use. + absl::Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_); + + // Removes all the functions and gradient functions. + void Clear() TF_LOCKS_EXCLUDED(mu_); + + // Adds the functions and gradients in 'other' to this function library. + // Duplicate functions and gradients are ignored. + // This operation is atomic. + absl::Status AddLibrary(const FunctionLibraryDefinition& other) + TF_LOCKS_EXCLUDED(mu_); + absl::Status AddLibrary(FunctionLibraryDefinition&& other) + TF_LOCKS_EXCLUDED(mu_); + + // Adds the functions and gradients in 'lib_def' to this function library. + // Duplicate functions and gradients are ignored. This overload adds the + // functions with no stack traces. This operation is atomic. + absl::Status AddLibrary(const FunctionDefLibrary& lib_def) + TF_LOCKS_EXCLUDED(mu_); + absl::Status AddLibrary(FunctionDefLibrary&& lib_def) TF_LOCKS_EXCLUDED(mu_); + + // Adds the functions and gradients in 'lib_def' to this function library. + // Duplicate functions and gradients are ignored. + // This operation is atomic. + absl::Status AddLibrary(const FunctionDefLibrary& lib_def, + const FunctionDefLibraryStackTraces& library_traces) + TF_LOCKS_EXCLUDED(mu_); + absl::Status AddLibrary(FunctionDefLibrary&& lib_def, + const FunctionDefLibraryStackTraces& library_traces) + TF_LOCKS_EXCLUDED(mu_); + + // If the gradient function for 'func' is specified explicitly in + // the library, returns the gradient function name. Otherwise, + // returns an empty string. + std::string FindGradient(const std::string& func) const + TF_LOCKS_EXCLUDED(mu_); + + // OpRegistryInterface method. Useful for constructing a Graph. + // + // If "op" is defined in the library, returns its signature. + // Otherwise, assume "op" is a primitive op and returns its op + // signature and shape inference function. + // + // NB: This function outputs a borrowed pointer, which can be invalidated by a + // subsequent call to `ReplaceFunction()` with the given name. + absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const override + TF_LOCKS_EXCLUDED(mu_); + + // Generates new function name with the specified prefix that is unique + // across this library. + std::string UniqueFunctionName(absl::string_view prefix) const + TF_LOCKS_EXCLUDED(mu_); + + // Given a node def 'ndef', inspects attributes of the callee + // function to derive the attribute 'value' for 'attr'. Returns OK + // iff the attribute is given by the function's definition. + // TODO(irving): Remove; keep only the const Node& version. 
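A rough sketch of the add-then-look-up flow documented above: construct a FunctionLibraryDefinition over the global op registry, add one FunctionDef (for instance the ExampleTimesTwo sketch earlier, or any real one), and fetch it back with Find(). The wrapper function name is illustrative.

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Sketch only: Find() returns a borrowed pointer that a later
// ReplaceFunction()/RemoveFunction() call can invalidate, per the notes above.
bool ExampleRegisterAndFind(const FunctionDef& fdef) {
  FunctionLibraryDefinition lib_def(OpRegistry::Global());
  if (!lib_def.AddFunctionDef(fdef).ok()) return false;
  return lib_def.Find(fdef.signature().name()) != nullptr;
}

}  // namespace tensorflow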
+ template + absl::Status GetAttr(const NodeDef& ndef, const std::string& attr, + T* value) const; + + // Given a node, inspects attributes of the callee function to derive the + // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the + // function's definition. + template + absl::Status GetAttr(const Node& node, const std::string& attr, + T* value) const; + + // Returns a proto representation of the state of this function library. + FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_); + + size_t num_functions() const TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return records_.size(); + } + + // Returns all the function names in the FunctionLibraryDefinition. + std::vector ListFunctionNames() const TF_LOCKS_EXCLUDED(mu_); + + const OpRegistryInterface* default_registry() const { + return default_registry_; + } + void set_default_registry(const OpRegistryInterface* registry) { + default_registry_ = registry; + } + + // Returns a copy of `*this` with only the subset of functions that are + // reachable from the nodes of `graph` or `func`. + FunctionLibraryDefinition ReachableDefinitions(const GraphDef& graph) const; + FunctionLibraryDefinition ReachableDefinitions(const FunctionDef& func) const; + FunctionLibraryDefinition ReachableDefinitions(const Graph& graph) const; + absl::StatusOr ReachableDefinitions( + const std::string& function_name) const; + + // Copies the function named `func` from `other` to this + // FunctionLibraryDefinition. + // REQUIRES: `this->default_registry() == other.default_registry()`. + // Returns OK on success, or error otherwise. This is a no-op if a function + // name `func` already exists in this function library, and has the same + // implementation as in `other`. If the implementations conflict, an invalid + // argument error is returned. + absl::Status CopyFunctionDefFrom(const std::string& name, + const FunctionLibraryDefinition& other); + + // Returns graph with debug stack traces for the given function, or `nullptr` + // if none found. + const StackTracesMap* GetStackTraces(const std::string& func_name) const { + core::RefCountPtr entry = FindRecord(func_name); + if (entry.get() != nullptr) { + return &entry->stack_traces(); + } + return nullptr; + } + + // Adds or updates an OptimizedFunctionGraph. Key is `function_name`. + // + // NOTE: This overload will lead to a copy of a potentially large graph + // being stored in memory for the lifetime of the library. Using the lazy + // `creator` function overload is recommended in new code. + ABSL_DEPRECATED("Use the lazy `creator` function overload in new code.") + void AddOptimizedFunctionGraph(const std::string& function_name, + const OptimizedFunctionGraph& graph) + TF_LOCKS_EXCLUDED(mu_) { + std::function()> creator = + [graph]() { return graph; }; + AddOptimizedFunctionGraph(function_name, std::move(creator)); + } + + // Adds or updates an OptimizedFunctionGraph, using a `creator` that can + // lazily build or load the graph on demand. Key is `function_name`. + void AddOptimizedFunctionGraph( + const std::string& function_name, + std::function()> creator) + TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + optimized_function_graph_creator_map_.emplace(function_name, + std::move(creator)); + } + + // Look up for OptimizedFunctionGraph given `function_name`. Returns nullopt + // if not found. 
+ std::optional> + FindOptimizedFunctionGraph(const std::string& function_name) const + TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + if (auto it = optimized_function_graph_creator_map_.find(function_name); + it != optimized_function_graph_creator_map_.end()) { + return it->second(); + } + return std::nullopt; + } + + // Creates a map of function names to stack traces for a FunctionDefLibrary. + static FunctionDefLibraryStackTraces CreateStackTracesForFunctionDefLibrary( + const FunctionDefLibrary& library, const GraphDebugInfo& debug_info); + + private: + void Initialize(const FunctionDefLibrary& library, + const FunctionDefLibraryStackTraces& library_traces); + + core::RefCountPtr FindHelper(const string& func) const + TF_SHARED_LOCKS_REQUIRED(mu_); + std::string FindGradientHelper(const std::string& func) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + absl::Status AddHelper(FunctionRecord* registration, bool* added) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Same as AddFunctionDef/AddGradientDef except these methods set + // `added` to true if the `fdef`/`grad` were actually added to this. + absl::Status AddFunctionDefHelper(FunctionDef&& fdef, + StackTracesMap&& stack_traces, bool* added) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status AddGradientDefHelper(const GradientDef& grad, bool* added) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Helper function for GetAttr. Returns the FunctionDef* to get the + // attr from. + const FunctionDef* GetAttrImpl(const NodeDef& ndef) const + TF_LOCKS_EXCLUDED(mu_); + + // Remove all functions in `funcs` and all gradients of functions in + // `funcs_with_grads` from this library. + absl::Status Remove(const std::vector& funcs, + const std::vector& funcs_with_grads) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Remove `func` from the library. Returns non-OK Status unless `func` is in + // the library. This should only be called when there is a guarantee that the + // function being removed hasn't been retrieved with `Find`. + absl::Status RemoveFunctionHelper(const std::string& func) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Remove gradient of function `func` from the library. Returns non-OK Status + // unless `func` has a gradient. + absl::Status RemoveGradient(const std::string& func) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutable mutex mu_; + const OpRegistryInterface* default_registry_; + gtl::FlatMap records_ TF_GUARDED_BY(mu_); + gtl::FlatMap func_grad_ TF_GUARDED_BY(mu_); + // Maps from function name to optimized function graph. + gtl::FlatMap()>> + optimized_function_graph_creator_map_ TF_GUARDED_BY(mu_); +}; + +// Forward declare. Defined in common_runtime/function.h +struct FunctionBody; + +// Forward declare. Defined in common_runtime/device.h +class Device; +// Forward declare. Defined in common_runtime/device_mgr.h +class DeviceMgr; + +// Index of an _Arg node. +struct FunctionArgIndex { + explicit FunctionArgIndex(const int index) : index(index) {} + FunctionArgIndex(const int index, const int sub_index) + : index(index), sub_index(sub_index) {} + + // The value of the attribute "Index" of the _Arg node. + int index; + // Set only when the _Arg node represents multiple arguments (e.g. an _Arg + // node is replicated to multiple devices/subgraphs). Use sub-index to + // distinguish arguments with the same index. + int sub_index = -1; +}; + +class FunctionLibraryRuntime : public core::WeakRefCounted { + public: + ~FunctionLibraryRuntime() override {} + + // Instantiate a function with the given "attrs". 
+ // + // Returns OK and fills in "handle" if the instantiation succeeds. + // Otherwise returns an error and "handle" is undefined. + struct InstantiateOptions { + // The canonical device name of the device on which the function + // should be instantiated. If empty, the function will be + // instantiated on the local device. + std::string target; + + // Should the function be instantiated as a multi-device function? + bool is_multi_device_function = false; + + // If true, graph passes will be skipped when instantiating the function + // since they have already run on the main function side. + bool is_component_function = false; + + // For multi-device functions, a vector of canonical device names for + // function's inputs. The device of resource inputs must be the device + // backing the resource, not the CPU device backing the resource handle. + // Must have the same length as number of inputs to the function. + std::vector input_devices; + + // For multi-device functions, a vector of canonical device names for + // function's outputs. + // + // (a) If specified (must have the same length as number of outputs): + // + // Specified devices will be assigned to Retval nodes inserted into the + // function body graph in place of function outputs. It is allowed to + // specify output device as empty string, in this case Retval device + // assignment will be inferred later when function graph will be placed + // before partitioning (this is required for resource outputs). Placer will + // respect colocation constraints. + // + // (b) If not specified: + // + // Function runtime will infer Retval device by following input edges, until + // it will reach a node with a device specification. This device + // specification must identify a unique device, i.e. a general specification + // like "job:foo" matching multiple devices will result in an error. + // + // IMPORTANT: Resource outputs + // + // Multi device functions might return resources on a devices different from + // the function call device. If output device is not specified for the + // resource output, and node producing that resource is a function call, + // runtime will leave device specification empty and will rely on Placer to + // infer correct device. + std::vector output_devices; + + // If set, it indicates the original output indices of a component function. + absl::optional> ret_indices = absl::nullopt; + + // Maps from a CompositeDevice name to a list of underlying physical + // devices. + absl::flat_hash_map*> composite_devices; + + // This interface is EXPERIMENTAL and subject to change. + // + // For multi-device functions, a mapping from _Arg node index to type and + // shape for input resources. + // REQUIRES: if input_resource_dtypes_and_shapes.count(i) > 0 then i-th + // argument type must be DT_RESOURCE. + std::unordered_map + input_resource_dtypes_and_shapes; + + // This interface is EXPERIMENTAL and subject to change. + // + // If non-null, the runtime will use `lib_def` to resolve function(s) named + // in `function_name` and `attrs`. Otherwise, the runtime will use its + // internal library. + // + // NOTE(mrry): If provided, all functions defined in `lib_def` must be + // self-contained, and cannot refer to functions defined in other libraries. + const FunctionLibraryDefinition* lib_def = nullptr; + + // This interface is EXPERIMENTAL and subject to change. + // + // If non-empty, the runtime will use `state_handle` to identify + // cached state related the instantiated function. 
Two functions + // of the same name and attrs, instantiated with the same + // `state_handle` will have the same handle and share the same + // state (in stateful kernels); and two functions with different + // values for `state_handle` will have independent state. + std::string state_handle; + + // This interface is EXPERIMENTAL and subject to change. + // + // Instantiates the function using an executor of the given type. If empty, + // the default TensorFlow executor will be used. + std::string executor_type; + + // If true, the runtime will attempt to create kernels for the function at + // instantiation time, rather than on the first run. This can be used to + // surface errors earlier. + bool create_kernels_eagerly = false; + + // This interface is EXPERIMENTAL and subject to change. + // + // Instantiates the function with the provided config_proto. + ConfigProto config_proto; + + // If provided, this optimization function will be invoked before + // the placer for multi-device functions. + std::function /*ret_node_names*/, + std::vector /*keep_node_names*/, + FunctionLibraryDefinition*, const DeviceSet&, + Device* /*cpu_device*/, std::unique_ptr*)> + optimize_graph_fn; + + // If set, partitioned functions will be added to `graph_collector`. + // `graph_collector` must be alive during the call to Instantiate. + GraphCollector* graph_collector = nullptr; + + // Indicates whether the multi-device function backend should default the + // placement of ops without request device to `target`. + bool default_device_to_target = true; + + // If true, the optimized Graph will be stored so that + // `FunctionLibraryRuntime::DebugString(handle)` contains the optimized + // Graph. Otherwise, the unoptimized function Graph will be returned. + bool include_optimized_graph_in_debug_string = false; + + // If true, the function library runtime cache the function instantiation. + bool use_function_cache = false; + + // This interface is EXPERIMENTAL and subject to change. + // + // If True, allow optimizations which should be targeted at a limited + // set of small functions. For example, running kernels synchronously can + // be faster under some conditions. + bool allow_small_function_optimizations = false; + + // This interface is EXPERIMENTAL and subject to change. + // + // If True, allow graphs containing control flow nodes to be run on the + // single threaded executor. + bool allow_control_flow_sync_execution = false; + + // TODO(b/176491312): Remove this if shape inference on import flag is + // removed. If True, allows mlir roundtrip to run shape inference on import. + bool shape_inference_on_tfe_dialect_import = true; + + // Force int32 _Arg and _Retvals nodes to be left on device instead of + // pinning to host. + // + // Note that we do not pin int32 nodes to host for subgraphs running in + // TPU/XLA devices. So this is mainly used to handle the case of multi-CPU + // and GPU (non-XLA) graphs. + bool int_args_and_retvals_on_device = false; + + // This interface is EXPERIMENTAL and subject to change. + // + // Instantiates the function for XLA compilation on device_type. If empty, + // function is not compiled. + std::string xla_compile_device_type; + + // This interface is EXPERIMENTAL and subject to change. + // + // Instantiates the function enabling soft placement or outside compilation. 
+ bool allow_soft_placement = false; + }; + typedef uint64 Handle; + virtual absl::Status Instantiate(const std::string& function_name, + AttrSlice attrs, + const InstantiateOptions& options, + Handle* handle) = 0; + absl::Status Instantiate(const std::string& function_name, AttrSlice attrs, + Handle* handle) { + auto opts = absl::make_unique(); + return Instantiate(function_name, attrs, *opts, handle); + } + + // Releases state associated with the handle. + virtual absl::Status ReleaseHandle(Handle handle) = 0; + + // Returns the function body for the instantiated function given its + // handle 'h'. Returns nullptr if "h" is not found. + // + // *this keeps the ownership of the returned object, which remains alive + // as long as *this. + virtual const FunctionBody* GetFunctionBody(Handle h) = 0; + + // Returns the return types for the function identified by handle `h`. + virtual absl::Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0; + + // Asynchronously invokes the instantiated function identified by + // "handle". + // + // If function execution succeeds, "done" is called with OK and + // "*rets" is filled with the function's return values. Otherwise, + // "done" is called with an error status. + // + // Does not take ownership of "rets". + // In the cross-process scenario, runner isn't used for making the Async + // RPC calls. + struct Options { + Options() {} + explicit Options(const int64_t step_id) : step_id(step_id) {} + + // Choose a step ID that is guaranteed not to clash with any + // Session-generated step ID. DirectSession only generates + // non-negative step IDs (contiguous, starting from 0), and + // MasterSession generates 56-bit random step IDs whose MSB is + // always 0, so a negative random step ID should suffice. + const int64_t step_id = -std::abs(static_cast(random::New64())); + + // op_id of the function running in eager mode. Set when we want to copy + // remote outputs lazily. All components of a remote multi-device function + // should use the same op_id, in order to correctly map remote output + // tensors to the remote TensorHandles in the default device. + absl::optional op_id = absl::nullopt; + + // Not owned. Caller makes sure that the rendezvous outlives this Options. + RendezvousInterface* rendezvous = nullptr; + CancellationManager* cancellation_manager = nullptr; + CollectiveExecutor* collective_executor = nullptr; + ScopedStepContainer* step_container = nullptr; + StepStatsCollectorInterface* stats_collector = nullptr; + tsl::CoordinationServiceAgent* coordination_service_agent = nullptr; + + absl::optional stack_trace = absl::nullopt; + + std::function)>* runner = nullptr; + + // Parameters for remote function execution. + bool remote_execution = false; + std::string source_device = ""; // Fully specified device name. + + // Allocator attributes specifying where the args are / rets should be put. + // These should either be {} or match the length of args / retvals. If {}, + // the default allocator attributes will be assumed for all args / retvals. + std::vector args_alloc_attrs; + std::vector rets_alloc_attrs; + + // If true, we create a new IntraProcessRendezvous, else use the existing + // one. + bool create_rendezvous = false; + + // If True, allow returning dead tensors. + bool allow_dead_tensors = false; + + // If True, hint that all kernels should be treated as "inexpensive", and + // hence executed on the scheduling thread. + bool run_all_kernels_inline = false; + + // If not null, use this thread pool for intra op scheduling. 
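A rough usage sketch of the interface declared above. The function name "MyFunction" and the wrapper are assumptions; Instantiate, RunSync, and ReleaseHandle are the methods of this class, and the Options fields left at their defaults (runner, rendezvous, and the thread pool that continues just below) are documented inline:

// Sketch: instantiate a function by name on the local device and run it once.
absl::Status RunMyFunctionOnce(FunctionLibraryRuntime* flr, const Tensor& arg,
                               std::vector<Tensor>* rets) {
  FunctionLibraryRuntime::InstantiateOptions i_opts;
  i_opts.create_kernels_eagerly = true;  // surface kernel errors at instantiation
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(flr->Instantiate("MyFunction", AttrSlice(), i_opts, &handle));

  FunctionLibraryRuntime::Options run_opts;  // step_id defaults to a random negative id
  std::vector<Tensor> args = {arg};
  absl::Status run_status = flr->RunSync(run_opts, handle, args, rets);
  TF_RETURN_IF_ERROR(flr->ReleaseHandle(handle));
  return run_status;
}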
+ thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr; + + // Returns a human readable representation of this. + std::string DebugString() const; + }; + typedef std::function DoneCallback; + virtual void Run(const Options& opts, Handle handle, + absl::Span args, std::vector* rets, + DoneCallback done) = 0; + virtual void Run(const Options& opts, Handle handle, + CallFrameInterface* call_frame, DoneCallback done) = 0; + + virtual absl::Status RunSync(Options opts, Handle handle, + absl::Span args, + std::vector* rets) = 0; + virtual absl::Status RunSync(Options opts, Handle handle, + CallFrameInterface* call_frame) = 0; + + // Creates a "kernel" for the given NodeProperties "props". + // + // If succeeds, returns OK and the caller takes the ownership of the + // returned "*kernel". Otherwise, returns an error. + virtual absl::Status CreateKernel( + const std::shared_ptr& props, + OpKernel** kernel) = 0; + + // Returns true iff the function named `function_name` is stateful. + // + // NOTE(mrry): This method assumes that the runtime is associated with a + // default function library, and looks up `function_name` in that library. + // It does not support overriding the function library. + virtual bool IsStateful(const std::string& function_name) const = 0; + + // Returns the device on which the function executes. + virtual Device* device() = 0; + virtual const Device* device() const = 0; + + // Returns the default runner in which the ops should be launched. If the + // device on which the function executes has a private thread pool, return + // runner on the device local thread pool. + virtual std::function)>* runner() = 0; + + // Get the DeviceMgr from which the device was obtained. + virtual const DeviceMgr* device_mgr() const = 0; + + // Returns the function library definition that backs this runtime. + // + // NOTE(mrry): The returned library definition is the default function library + // for this runtime. The caller may override the function library used by the + // runtime to instantiate functions, which will not be reflected in the return + // value of this function. + virtual const FunctionLibraryDefinition* GetFunctionLibraryDefinition() + const = 0; + + // Returns the environment on which the function executes. + virtual Env* env() = 0; + + // Returns the ConfigProto passed to the session used to create the function. + virtual const ConfigProto* const config_proto() = 0; + + // Returns a debug string showing the definition of the function of + // 'handle'. + virtual std::string DebugString(Handle handle) = 0; + + // Returns the graph version number. + virtual int graph_def_version() const = 0; + + typedef uint64 LocalHandle; + + // Creates a copy of ProcessFunctionLibraryRuntime (transferring ownership to + // the caller), FunctionLibraryRuntime (owned by the returned + // ProcessFunctionLibraryRuntime), FunctionLibraryDefinition (transferring + // ownership to the caller). Note that both the ProcessFunctionLibraryRuntime + // and FunctionLibraryRuntime borrow a pointer to the + // FunctionLibraryDefinition and so the FunctionLibraryDefinition should + // outlive both. + // + // The `skip_flib_def` argument controls whether the method should clone the + // FunctionLibraryDefinition (default behavior) or return an empty function + // library. The latter is used by tf.data, which manages + // FunctionLibraryDefinitions for its functions independently (and passes + // these into the FunctionLibraryRuntime through an overlay), to avoid linear + // runtime w.r.t. 
to number of functions in the current function library. + virtual absl::Status Clone( + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr, bool skip_flib_def = false) = 0; + + // Returns the name of the executor class (in the sense of + // `ExecutorFactory::GetFactory()`) that will be used based on the given + // dynamic `options` and static `attrs`. If none is specified, this method + // will return an empty string, which leaves the decision up to the runtime. + static std::string ExecutorType(const InstantiateOptions& options, + AttrSlice attrs); +}; + +// Returns the device of the `arg_index`-th function input. Update +// `composite_devices` if the input device is a composite device. +std::string GetFunctionResourceInputDevice( + const Tensor& input, const int arg_index, const FunctionDef& function_def, + absl::flat_hash_map>* composite_devices); + +// Returns a canonicalized string for the instantiation of the function of the +// given "name", attributes "attrs", and "options". +// +// The returned string is guaranteed to be stable within one address space. But +// it may be change as the implementation evolves. Therefore, it should not be +// persisted or compared across address spaces. +std::string Canonicalize( + const std::string& funcname, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options); +std::string Canonicalize(const std::string& funcname, AttrSlice attrs); + +const FunctionLibraryRuntime::Handle kInvalidHandle = -1; +const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1; + +class CustomKernelCreator { + public: + virtual ~CustomKernelCreator() {} + + // Given a NodeDef 'node_def' and the function library runtime 'flr', + // validate if the class supports creating such a kernel. + virtual bool CanCreateKernel( + const FunctionLibraryRuntime& flr, + const std::shared_ptr& props) const = 0; + + // Given a supported NodeDef, returns a kernel that computes the node. + virtual absl::Status CreateKernel( + FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const = 0; +}; + +typedef +#if !defined(IS_MOBILE_PLATFORM) + absl::variant + FunctionArg; +#else + absl::variant + FunctionArg; +#endif + +// Either a local tensor or the shape of a remote tensor. +typedef absl::variant FunctionRet; + +// Used to instantiate and run functions in a distributed system. +class DistributedFunctionLibraryRuntime { + public: + virtual ~DistributedFunctionLibraryRuntime() {} + + // Instantiate a function on a remote target specified in `options.target`, by + // sending the name and definition of the function to the remote worker. The + // local `handle` is filled for the instantiated function data and can be used + // for subsequent run function calls on the remote target. + virtual void Instantiate( + const std::string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle, + FunctionLibraryRuntime::DoneCallback done) = 0; + + // Run an instantiated remote function (specified by `handle`) with a list of + // input Tensors in `args` and get its output Tensors in `rets`. The input + // tensor data will be sent with the function execution request, and must be + // available on the current caller side. + // opts.runner isn't used for execution. 
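The Canonicalize helpers declared above collapse (name, attrs, options) into a stable string key so repeated instantiations can be de-duplicated in a map. A self-contained sketch of the idea only, not the actual implementation (which folds in far more of InstantiateOptions):

#include <map>
#include <sstream>
#include <string>

// Builds a deterministic key from a function name, its attributes and a target
// device. Iterating a std::map keeps the key independent of insertion order.
std::string MakeInstantiationKey(const std::string& func_name,
                                 const std::map<std::string, std::string>& attrs,
                                 const std::string& target) {
  std::ostringstream key;
  key << func_name;
  for (const auto& [name, value] : attrs) {
    key << '[' << name << '=' << value << ']';
  }
  key << "@device:" << target;
  return key.str();
}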
+ virtual void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + absl::Span args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) = 0; + + // Run an instantiated remote function (specified by `handle`) with a list of + // input Tensors or RemoteTensorHandles as `args` and get its output Tensors + // or TensorShapes in `rets`. When using RemoteTensorHandles as function + // inputs or TensorShapes as outputs, the corresponding tensor data will be + // resolved on the remote worker, so it is not required to be locally + // available on the caller side. Using RemoteTensorHandle inputs is not + // supported in TensorFlow v1 runtime. + virtual void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + absl::Span args, + std::vector* rets, + FunctionLibraryRuntime::DoneCallback done) = 0; + + // Clean up a previously instantiated function on remote worker. + virtual void CleanUp(uint64 step_id, + FunctionLibraryRuntime::LocalHandle handle, + FunctionLibraryRuntime::DoneCallback done) = 0; + + // DeviceMgr with *all* available devices (i.e., local and remote). + virtual DeviceMgr* remote_device_mgr() const = 0; +}; + +// Extracts the actual type from "attr_values" based on its definition +// "arg_def". +// +// If "arg_def" is a N*T type, *is_type_list is set to false, and +// *dtypes is set to be a vector of size N and each element is T. +// +// If "arg_def" is a list(type), *is_type_list is set to true, and +// *dtypes is set to be a vector of types specified in attrs for +// arg_def. +// +// Otherwise (arg_def is a simple type T), *is_type_list is set to +// false, and *dtypes is set to a single element vector, whose only +// element is T. +absl::Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, + bool* is_type_list, DataTypeVector* dtypes); + +// To register a gradient function for a builtin op, one should use +// REGISTER_OP_GRADIENT(, ); +// +// Typically, the c++ grad factory is a plan function that can be +// converted into ::tensorflow::gradient::Creator, which is +// std::function. +// +// A ::tensorflow::gradient::Creator should populate in FunctionDef* with a +// definition of a brain function which compute the gradient for the +// when the is instantiated with the given attrs. +// +// E.g., +// +// Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { +// bool transpose_a; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_a", &transpose_a)); +// bool transpose_b; +// TF_RETURN_IF_ERROR(attrs.Get("transpose_b", &transpose_b)); +// DataType dtype; +// TF_RETURN_IF_ERROR(attrs.Get("dtype", &dtype)); +// if (!transpose_a && !transpose_b) { +// *g = FunctionDefHelper::Define( +// "MatMulGrad", +// {"x:T ", "y:T", "dz:T"}, // Inputs to this function +// {"dx:T", "dy:T"}, // Outputs from this function +// {"T: {float, double}"}, // Attributes needed by this function +// { +// {{"x_t"}, "Transpose", {"x"}, {{"T", "$T"}}}, +// {{"y_t"}, "Transpose", {"y"}, {{"T", "$T"}}}, +// {{"dx"}, "MatMul", {"dz", "y_t"}, {{"T", "$T"}}}, +// {{"dy"}, "MatMul", {"x_", "dz"}, {{"T", "$T"}}}, +// }); +// } else { +// ... ... +// } +// return OkStatus(); +// } +// +// NOTE: $T is substituted with the type variable "T" when the +// gradient function MatMul is instantiated. +// +// TODO(zhifengc): Better documentation somewhere. + +// Macros to define a gradient function factory for a primitive +// operation. 
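For example, a gradient factory for a hypothetical op can be registered with the macros that follow; this is a sketch in the style of the MatMulGrad example above ("MyIdentity" and its gradient body are made up):

// Sketch: gradient factory for a hypothetical op "MyIdentity" whose gradient
// simply forwards the incoming gradient dz.
absl::Status MyIdentityGrad(const AttrSlice& attrs, FunctionDef* g) {
  *g = FunctionDefHelper::Define(
      "MyIdentityGrad",
      {"x:T", "dz:T"},         // inputs: forward input and incoming gradient
      {"dx:T"},                // output: gradient w.r.t. x
      {"T: {float, double}"},  // attributes
      {
          {{"dx"}, "Identity", {"dz"}, {{"T", "$T"}}},
      });
  return absl::OkStatus();
}
REGISTER_OP_GRADIENT("MyIdentity", MyIdentityGrad);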
+#define REGISTER_OP_GRADIENT(name, fn) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, fn) + +#define REGISTER_OP_NO_GRADIENT(name) \ + REGISTER_OP_GRADIENT_UNIQ_HELPER(__COUNTER__, name, nullptr) + +#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \ + REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) + +#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \ + static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \ + SHOULD_REGISTER_OP_GRADIENT && \ + ::tensorflow::gradient::RegisterOp(name, fn) + +namespace gradient { +// Register a gradient creator for the "op". +typedef std::function + Creator; +bool RegisterOp(const std::string& op, Creator func); + +// Returns OK the gradient creator for the "op" is found (may be +// nullptr if REGISTER_OP_NO_GRADIENT is used. +absl::Status GetOpGradientCreator(const std::string& op, Creator* creator); +}; // namespace gradient + +// Declare explicit instantiations of GetAttr +#define GET_ATTR(T) \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const Node&, const string&, T*) const; \ + extern template Status FunctionLibraryDefinition::GetAttr( \ + const NodeDef&, const string&, T*) const; +GET_ATTR(string) +GET_ATTR(bool) +#undef GET_ATTR + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/function_handle_cache.h b/third_party/tflite-hdrs/tensorflow/core/framework/function_handle_cache.h new file mode 100644 index 00000000..1bd67138 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/function_handle_cache.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_HANDLE_CACHE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_HANDLE_CACHE_H_ + +#include + +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { + +// Thread-safe data structure for caching function instantiations. +class FunctionHandleCache { + public: + explicit FunctionHandleCache(FunctionLibraryRuntime* lib); + + ~FunctionHandleCache(); + + // Looks up the function to be instantiated in the cache first. If present, + // returns handle from there. Otherwise, instantiates a new function + // and stores handle in the cache. + // + // The cache retains the ownership of the handle. In particular, the caller + // should not invoke `ReleaseHandle`. + absl::Status Instantiate(const string& function_name, AttrSlice attrs, + FunctionLibraryRuntime::InstantiateOptions options, + FunctionLibraryRuntime::Handle* handle); + + // Releases all the handles in the cache, clearing out the state for all + // functions involved. 
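An illustrative use of the cache, assuming an existing FunctionLibraryRuntime* flr; Clear(), declared just below, releases every cached handle at once:

// Sketch: two identical Instantiate calls hit the cache and yield one handle.
absl::Status UseCache(FunctionLibraryRuntime* flr) {
  FunctionHandleCache cache(flr);
  FunctionLibraryRuntime::InstantiateOptions opts;
  FunctionLibraryRuntime::Handle h1, h2;
  TF_RETURN_IF_ERROR(cache.Instantiate("MyFunction", AttrSlice(), opts, &h1));
  TF_RETURN_IF_ERROR(cache.Instantiate("MyFunction", AttrSlice(), opts, &h2));
  // h1 == h2; the cache owns both handles, so the caller must not call
  // flr->ReleaseHandle() on them.
  return cache.Clear();
}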
+ absl::Status Clear(); + + private: + mutex mu_; + FunctionLibraryRuntime* lib_ = nullptr; // not owned + const string state_handle_; + std::unordered_map handles_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_HANDLE_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/function_testlib.h b/third_party/tflite-hdrs/tensorflow/core/framework/function_testlib.h new file mode 100644 index 00000000..93cae697 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/function_testlib.h @@ -0,0 +1,187 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ +#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ + +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace test { +namespace function { + +// A helper class to make AttrSlice from initializer lists +class Attrs { + public: + Attrs(const std::initializer_list< // NOLINT(runtime/explicit) + std::pair>& attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + Attrs( + const std::vector>& + attrs) { + for (const auto& aval : attrs) { + map_.insert({aval.first, aval.second.proto}); + } + } + + operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) + + private: + AttrValueMap map_; +}; + +// Helper to construct a NodeDef. +NodeDef NDef( + absl::string_view name, absl::string_view op, + absl::Span inputs, + absl::Span> + attrs = {}, + const string& device = ""); + +// Helper to construct a GraphDef proto. +GraphDef GDef(absl::Span nodes, + absl::Span funcs = {}); + +// For testing convenience, we provide a few simple functions that can +// be easily executed and tested. + +// x: T -> x * 2. +FunctionDef XTimesTwo(); +// Same as `XTimesTwo` above, but with the `x` input as a control dependency. +FunctionDef XTimesTwoWithControlInput(); +// Same as `XTimesTwo` above, but with a `dummy` control output node. +FunctionDef XTimesTwoWithControlOutput(); +// Same as `XTimesTwo` above, but with a dangling `FloorDiv` node. +FunctionDef XTimesTwoWithDanglingFloorDivNode(); + +// x: T -> cpu(x * 2) + cpu(x * 3). +FunctionDef TwoDeviceTimesFive(); + +// x: T -> cpu(x * 2), gpu(x * 3). +FunctionDef TwoDeviceMult(); + +// cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3). +FunctionDef TwoDeviceInputOutput(); + +// Function taking a list of Tensors as input. +FunctionDef FuncWithListInput(); + +// Function returning a list of Tensors as output. +FunctionDef FuncWithListOutput(); + +// x: T -> x + x. 
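As a quick illustration of the NDef/GDef helpers above, a tiny test graph calling one of the canned functions might be assembled as follows (a sketch mirroring common TensorFlow test usage; the remaining factory declarations such as XAddX() continue below):

using test::function::GDef;
using test::function::NDef;

// A two-node graph: a float placeholder feeding a call to the canned
// XTimesTwo function, which is attached to the GraphDef's function library.
GraphDef gdef = GDef(
    {
        NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}),
        NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}),
    },
    /*funcs=*/{test::function::XTimesTwo()});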
+FunctionDef XAddX(); + +// x: T, y: T -> x + y. +FunctionDef XAddY(); + +// x: T -> x * 2, where x is int32. +FunctionDef XTimesTwoInt32(); + +// x: T -> (x * 2) * 2. +FunctionDef XTimesFour(); + +// x: T -> (x * 2) * 2, where x is int32 +FunctionDef XTimesFourInt32(); + +// x: T -> ((x * 2) * 2) * 2. +FunctionDef XTimes16(); + +// w: T, x: T, b: T -> MatMul(w, x) + b +FunctionDef WXPlusB(); + +// x: T -> x: T, T is a type which we automatically converts to a bool. +FunctionDef NonZero(); + +// x: T -> bool. +FunctionDef IsZero(); + +// x: T -> int64 +FunctionDef RandomUniform(); + +// x: T, y:T -> y: T, x: T +FunctionDef Swap(); + +// x: T, y: T -> y: T, x: T, the body has no nodes. +FunctionDef EmptyBodySwap(); + +// x: float, y: resource -> y: resource, 2*x: float. +FunctionDef ResourceOutput(); + +// x: resource -> x: resource +FunctionDef ResourceIdentity(); + +// x: resource -> y: float. +FunctionDef ReadResourceVariable(); + +// Contains simple control flow returning the input via an Enter op. +FunctionDef ControlFlow(); + +// Contains malformed control flow which can't be run by the executor. +FunctionDef InvalidControlFlow(); + +// x: T -> x <= N. +FunctionDef LessThanOrEqualToN(int64_t N); + +// x: T, y: T -> x + 1, x * y +FunctionDef XPlusOneXTimesY(); + +// x: T, y: T -> x <= N +FunctionDef XYXLessThanOrEqualToN(int64_t N); + +// x: T -> bool +FunctionDef RandomUniformLess(); + +// start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset +FunctionDef MakeRangeDataset(); + +// input_dataset: variant, batch_size: int64, drop_remainder: bool +// -> y: BatchDatasetV2::Dataset +FunctionDef MakeBatchDataset(); + +// input_dataset: variant, other_arguments: Targuments, f: func, +// Targuments: list(type), output_types: list(type), output_shapes: list(shape), +// use_inter_op_parallelism: bool, preserve_cardinality: bool +// -> y: MapDatasetOp::Dataset +FunctionDef MakeMapDataset(bool has_other_args); + +// input_dataset: variant, count: int64 -> y: TakeDataset::Dataset +FunctionDef MakeTakeDataset(); + +// x: T -> y: TensorSliceDatasetOp::Dataset +FunctionDef MakeTensorSliceDataset(); + +// x: T -> y: T, idx: out_idx +FunctionDef Unique(); + +void FunctionTestSchedClosure(std::function fn); + +} // end namespace function +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/graph_def_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/graph_def_util.h new file mode 100644 index 00000000..a164ac31 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/graph_def_util.h @@ -0,0 +1,135 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_ + +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Forward declare proto so that it's symbols can be removed from .so exports +class GraphDef; +class NodeDef; + +// Produce a human-readable version of a GraphDef that is more concise +// than a text-format proto. +string SummarizeGraphDef(const GraphDef& graph_def); + +// Validates the syntax of a GraphDef provided externally. +// +// The following is an EBNF-style syntax for GraphDef objects. Note that +// Node objects are actually specified as tensorflow::NodeDef protocol buffers, +// which contain many other fields that are not (currently) validated. +// +// Graph = Node * +// Node = NodeName, Inputs +// Inputs = ( DataInput * ), ( ControlInput * ) +// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? +// ControlInput = "^", NodeName +// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * +absl::Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def); + +// Adds default attributes to NodeDefs in 'graph_def' starting +// from the 'node_offset' node in 'graph_def'. +// +// Default attributes are defined by 'op_registry'. +// +// Returns OK on success, an error if 'graph_def' has a NodeDef +// that cannot be found in 'op_registry'. +// +// REQUIRES: 'graph_def' and 'op_registry' are not nullptr. +absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset); + +// Same as above, except for the fact that it skips nodes that aren't found in +// op_registry if skip_unknown_ops is true. +absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset, bool skip_unknown_ops); + +// Remove attrs from 'graph_def' that have the default value according +// to 'producer_op_registry', but don't exist according to +// 'consumer_op_registry'. This can allow 'graph_def' to run on the +// consumer even if consumer was built at an earlier CL (before an +// attr with a default was added). Note that this will not affect +// attrs with non-default values, so you must run a +// ValidateGraphDef...() function to see if the result is in fact +// compatible. If not nullptr, the op/attr pairs that were removed +// are added to '*op_attr_removed'. +// +// Expected usage, for a producer that wants to prepare a graph for +// a consumer: +// // For each consumer, update 'graph_def': +// OpListOpRegistry consumer_op_registry(consumer_server_op_list); +// std::unordered_set> op_attr_removed; +// TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef( +// &graph_def, consumer_op_registry, *OpRegistry::Global(), +// &op_attr_removed)); +// // Validate that each consumer can understand the resulting 'graph_def' +// TF_RETURN_IF_ERROR(graph::ValidateGraphDefAgainstOpRegistry( +// graph_def, consumer_op_registry)); +// // Consumer can use 'graph_def', and 'op_attr_removed' summarizes +// // what changes had to be made to 'graph_def' for it to work. 
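A small sketch tying together the two helpers declared above (graph_def is assumed to be an externally supplied GraphDef; the documentation for RemoveNewDefaultAttrsFromGraphDef continues below):

// Sketch: validate an externally supplied GraphDef, then fill in default attrs
// for every node, starting from node 0, against the process-global registry.
absl::Status PrepareGraph(GraphDef* graph_def) {
  TF_RETURN_IF_ERROR(ValidateExternalGraphDefSyntax(*graph_def));
  return AddDefaultAttrsToGraphDef(graph_def, *OpRegistry::Global(),
                                   /*node_offset=*/0);
}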
+// +// Expected usage, for a consumer that has a graph and a +// (optionally-stripped) op_list from a producer (say from a call to +// StrippedOpListForGraph(), or in the MetaGraphDef): +// OpListOpRegistry producer_op_registry(producer_stripped_op_list); +// TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef( +// &graph_def, *OpRegistry::Global(), producer_op_registry, nullptr)); +absl::Status RemoveNewDefaultAttrsFromGraphDef( + GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, + const OpRegistryInterface& producer_op_registry, + std::set>* op_attr_removed); + +// Goes over the `nodes` and removes attributes that are set to their +// default values according to op_registry. +// If some node's definition is not found in the `op_registry`, this node is +// simply skipped. In most cases, these nodes would be function calls. +// If a stricter behavior is desired, one can add FunctionLibraryDefinition +// argument to check for functions and their attributes. +// This is obvious from signature, but as a warning, if `nodes` contain +// nodes calling functions, e.g. PartitionCallOp or FunctionalIf, this +// function does not "recurse" into them. +void StripDefaultAttributes(const OpRegistryInterface& op_registry, + protobuf::RepeatedPtrField* nodes); + +// Two functions that collect the ops used by a graph. +// +// This returns the ops used as a set of strings. +void OpsUsedByGraph(const GraphDef& graph_def, + std::set* ops_used_in_graph); + +// This function computes the stripped_op_list field of MetaGraphDef +// and similar protos. The op_registry should contain the ops used to +// produce graph_def. The resulting stripped_op_list can be +// communicated from the producer to the consumer, which can use +// RemoveNewDefaultAttrsFromGraphDef() to improve forwards compatibility +// (using an OpListOpRegistry as indicated in the example above). +// +// Most users will pass *OpRegistry::Global() for op_registry to strip against +// the list of ops registered in this process. +absl::Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/graph_to_functiondef.h b/third_party/tflite-hdrs/tensorflow/core/framework/graph_to_functiondef.h new file mode 100644 index 00000000..369b86ec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/graph_to_functiondef.h @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_ +#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Graph to FunctionDef conversion. This code is closely modeled on the Python +// function graph_to_function_def(), which is located in +// tensorflow/python/framework/graph_to_function_def.py. +absl::Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, + bool append_hash_to_fn_name, + bool set_stateful_from_nodes, + bool copy_placeholder_attrs_from_nodes, + const std::vector& body_nodes, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& output_names, + const std::vector& control_outputs, + const std::vector& control_output_names, + const char* description, FunctionDef* fdef); + +// Converts 'graph' to a FunctionDef 'fdef', with name 'name': +// +// (1) 'node->IsArg()' nodes converted to function inputs. +// (2) 'node->IsRetval()' nodes converted to function output. +// (3) 'control_ret' returns an optional with a control output name, that will +// be added to the function `control_ret` map (see FunctionDef) and +// `control_output` in Op definition (see OpDef). Control output name must +// be unique for all control output nodes. +absl::Status GraphToFunctionDef( + const Graph& graph, const string& name, + const std::function(const Node*)>& control_ret, + FunctionDef* fdef); + +absl::Status GraphToFunctionDef(const Graph& graph, const string& name, + FunctionDef* fdef); + +absl::Status GraphToFunctionDef(const Graph& graph, const string& name, + const std::vector& output_names, + FunctionDef* fdef); + +absl::Status GraphToFunctionDef( + std::unique_ptr graph, const string& name, + const std::function(const Node*)>& control_ret, + FunctionDef* fdef); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/kernel_def_builder.h b/third_party/tflite-hdrs/tensorflow/core/framework/kernel_def_builder.h new file mode 100644 index 00000000..b7629c8d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/kernel_def_builder.h @@ -0,0 +1,102 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_ +#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_ + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Forward declare proto so that kernels don't need to depend on it +class KernelDef; + +// Builder class passed to the REGISTER_KERNEL_BUILDER() macro. +class KernelDefBuilder { + public: + // Starts with just the name field set. + // Caller MUST call Build() and take ownership of the result. + explicit KernelDefBuilder(const char* op_name); + ~KernelDefBuilder(); + + // Required: specify the type of device this kernel supports. + // Returns *this. + KernelDefBuilder& Device(const char* device_type); + + // Specify that this kernel supports a limited set of values for a + // particular type or list(type) attr (a further restriction than + // what the Op allows). + // Returns *this. + template + KernelDefBuilder& AttrConstraint(const char* attr_name, + gtl::ArraySlice allowed); + + // Like AttrConstraint above but supports just a single value. + template + KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed); + + // Specify that this kernel supports a limited set of values for a + // particular type or list(type) attr (a further restriction than + // what the Op allows). + // Returns *this. + KernelDefBuilder& TypeConstraint(const char* attr_name, + absl::Span allowed); + + // Like TypeConstraint but supports just a single type. + KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed); + + // Like TypeConstraint, but (a) gets the type from a template parameter + // and (b) only supports a constraint to a single type. + template + KernelDefBuilder& TypeConstraint(const char* attr_name) TF_ATTRIBUTE_NOINLINE; + // TODO(josh11b): Support other types of attr constraints as needed. + + // Specify that this kernel requires/provides an input/output arg + // in host memory (instead of the default, device memory). + // Returns *this. + KernelDefBuilder& HostMemory(const char* arg_name); + + // Specify that this kernel requires a particular value for the + // "_kernel" attr. May only be specified once. Returns *this. + KernelDefBuilder& Label(const char* label); + + // Specify a priority number for this kernel. + KernelDefBuilder& Priority(int32_t priority); + + // Returns a pointer to a KernelDef with fields set based on the + // above calls to this instance. + // Caller takes ownership of the result. + const KernelDef* Build(); + + private: + KernelDef* kernel_def_; + + KernelDefBuilder(const KernelDefBuilder&) = delete; + void operator=(const KernelDefBuilder&) = delete; +}; + +// IMPLEMENTATION + +template +KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) { + return this->TypeConstraint(attr_name, DataTypeToEnum::v()); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/kernel_def_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/kernel_def_util.h new file mode 100644 index 00000000..b60b3b2c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/kernel_def_util.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
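In practice the builder above is rarely invoked directly; kernels are registered through the REGISTER_KERNEL_BUILDER macro and the Name helper from op_kernel.h (not part of this header). A sketch, where MyOpKernel and the "axis" argument are made up:

// Sketch: register a CPU kernel for op "MyOp", restricted to float "T",
// with one argument pinned to host memory.
REGISTER_KERNEL_BUILDER(Name("MyOp")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("axis"),
                        MyOpKernel);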
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_ + +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" + +namespace tensorflow { + +// Returns whether the attrs satisfy the constraints in the kernel_def. Returns +// an error if attrs in kernel_def are not found, or have a mismatching type. +absl::Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, + bool* match); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/kernel_shape_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/kernel_shape_util.h new file mode 100644 index 00000000..6d444e18 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/kernel_shape_util.h @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_SHAPE_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_SHAPE_UTIL_H_ + +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { +// GetWindowedOutputSize(): Given an input tensor, kernel, stride and padding +// type, the function computes the output and padding dimensions. +// +// For example, ignoring batches or multiple features, a 1D convolution +// takes as input a 1D tensor of shape (H), and convolves it with a filter of +// shape (K). +// +// It also takes in a few additional parameters: +// +// Stride (S): the stride with which we apply the filters. This is the offset +// between locations where we apply the filters. A larger stride +// means that the output will be spatially smaller. +// +// Padding (P): the padding we apply to the input tensor along each +// dimension. This is usually used to make sure that the spatial dimensions +// do not shrink when we progress with convolutions. This function supports two +// types of padding. +// SAME: the pad value is computed so that the output will have size H/S. +// VALID: no padding is carried out. +// If you want to use EXPLICIT padding, GetWindowedOutputSizeVerbose must be +// called instead. Note the padded area is zero-filled. 
+// +// The output dimensions are computed as follows: +// - When adding dilation_rate (D), we compute an effective filter size (K'): +// K' = (K - 1) * D + 1 +// - When Padding = SAME: the output size is (H'), where +// H' = ceil(float(H) / float(S)) +// where ceil is the ceiling function. The number of padded cells +// is computed as: +// Pc = ((H' - 1) * S + K' - H) / 2 +// When the stride is 1, the expression simplifies to +// H' = H, Pc = (K'-1)/2. +// This is where SAME comes from - the output has the same size as the input +// has. +// +// - When Padding = VALID: the output size is computed as +// H' = ceil(float(H - K' + 1) / float(S)) +// and the number of padded cells is always zero. +// When the stride is 1, the expression simplifies to +// H' = H-K'+1. +// +// For convolution, mathematically, the output value at location (r') +// is the inner product of two vectors: the chunk of input at +// ((r'*S-Pr) : (r'*S-Pr+K)), +// and the filter. +// +// For 2D and 3D convolutions, the spatial dimensions are orthogonal, so the +// size and padding of each spatial dimension can be computed by calling +// GetWindowedOutputSize separately for each dimension. +// +absl::Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, + int dilation_rate, int64_t stride, + Padding padding_type, int64_t* output_size, + int64_t* padding_size); + +// Returns the same output dimensions as in GetWindowedOutputSize, but returns +// verbose padding dimensions (before/after), and EXPLICIT padding is supported. +// When padding_type is EXPLICIT, *padding_before and *padding_after must +// already point to initialized integers with the padding amounts. Otherwise, +// *padding_before and *padding_after are set by this function, and any +// excess padding (caused by an odd padding size value) is added to the +// 'padding_after' dimension. +absl::Status GetWindowedOutputSizeVerbose( + int64_t input_size, int64_t filter_size, int64_t dilation_rate, + int64_t stride, Padding padding_type, int64_t* output_size, + int64_t* padding_before, int64_t* padding_after); + +// Given an input tensor, kernel, stride and padding type, populates the 3D size +// of the output tensor and padding to be applied to the input tensor at the +// lower end of every dimension. Use for 3D convolutions, where the input data +// is padded with zeros, as well as for 3D avg/max pooling, where the input data +// is padded with invalid values that are not considered for pooling. EXPLICIT +// padding is not supported. +// The V2 version computes the same outputs with arbitrary dilation_rate. For +// detailed equations, refer to the comments for GetWindowedOutputSize(). +absl::Status Get3dOutputSizeV2(const std::array& input, + const std::array& window, + const std::array& dilations, + const std::array& strides, + Padding padding_type, + std::array* output_ptr, + std::array* padding_ptr); + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_SHAPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/local_rendezvous.h b/third_party/tflite-hdrs/tensorflow/core/framework/local_rendezvous.h new file mode 100644 index 00000000..332daaa6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/local_rendezvous.h @@ -0,0 +1,121 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
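To make the SAME/VALID formulas above concrete, here is a small self-contained check that reproduces the arithmetic in plain C++ (it does not call the TensorFlow helpers):

#include <cassert>
#include <cstdint>

// Effective filter size with dilation: K' = (K - 1) * D + 1.
int64_t EffectiveFilterSize(int64_t k, int64_t d) { return (k - 1) * d + 1; }

// H' = ceil(H / S) for SAME, H' = ceil((H - K' + 1) / S) for VALID.
int64_t SameOutputSize(int64_t h, int64_t s) { return (h + s - 1) / s; }
int64_t ValidOutputSize(int64_t h, int64_t k_eff, int64_t s) {
  return (h - k_eff + 1 + s - 1) / s;
}

int main() {
  const int64_t H = 10, K = 3, D = 1, S = 1;
  const int64_t Keff = EffectiveFilterSize(K, D);  // 3
  assert(SameOutputSize(H, S) == 10);              // SAME keeps the spatial size
  assert(ValidOutputSize(H, Keff, S) == 8);        // VALID: H - K' + 1
  // Padded cells for SAME with stride 1: Pc = (K' - 1) / 2 = 1.
  const int64_t Pc = ((SameOutputSize(H, S) - 1) * S + Keff - H) / 2;
  assert(Pc == 1);
  return 0;
}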
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Implements the basic logic of matching Send and Recv operations. See +// RendezvousInterface for more details. +// +// NOTE: Most users will use a class that wraps LocalRendezvous, such as +// IntraProcessRendezvous or RemoteRendezvous. This class does not implement +// RendezvousInterface because virtual dispatch to LocalRendezvous methods +// is not expected to be needed. +class LocalRendezvous { + public: + // If the class wrapping LocalRendezvous is refcounted (i.e., extending + // Rendezvous), pass in its pointer in constructor so the LocalRendezvous + // can make sure it outlives the async recv requests. + // Pass in nullptr if the wrapping class is not refcounted. + explicit LocalRendezvous(Rendezvous* owner, int num_shards) + : num_buckets_(num_shards > 0 ? num_shards : 1), + rc_owner_(owner), + table_buckets_(std::make_unique(num_buckets_)) {} + ~LocalRendezvous(); + + absl::Status Send(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& send_args, const Tensor& val, + bool is_dead); + void RecvAsync(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& recv_args, + Rendezvous::DoneCallback done); + void StartAbort(const absl::Status& status); + absl::Status status(); + + // Releases all the references to the aborted rendezvous. Used in unit tests. + static void ReleaseAbortedRendezvous() { + mutex_lock l(aborted_rendezs_mu_); + aborted_rendezs_.clear(); + } + + private: + void DoAbort(const absl::Status& status); + + tsl::core::RefCountPtr GetOwnerRefCountPtr(); + + struct Item; + + // By invariant, the item queue under each key is of the form + // [item.type == kSend]* meaning each item is a sent message. + // or + // [item.type == kRecv]* meaning each item is a waiter. + struct ItemQueue { + void push_back(Item* item); + + Item* head = nullptr; + Item* tail = nullptr; + }; + + typedef gtl::FlatMap Table; + + const int num_buckets_; + // Pointer to the owner class of this LocalRendezvous if it is refcounted, + // nullptr otherwise. + Rendezvous* rc_owner_; + + struct TableBucket { + mutex mu; + Table table TF_GUARDED_BY(mu); + + // Track the number of pening callbacks using a counter. + int pending_callback_counter TF_GUARDED_BY(mu) = 0; + condition_variable pending_callback_cond_var TF_GUARDED_BY(mu); + }; + + // Immutable set of buckets. This uses less memory than std::vector. 
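A self-contained sketch of that per-key matching invariant (illustrative only, with no locking, abort handling, or dead tensors; table_buckets_, declared just below, shards many such queues by key hash):

#include <deque>
#include <functional>
#include <string>

// Per-key queue: holds either unmatched sends or unmatched receivers, never both.
struct KeyQueue {
  std::deque<std::string> pending_sends;                       // buffered values
  std::deque<std::function<void(std::string)>> pending_recvs;  // waiting callbacks

  void Send(std::string value) {
    if (!pending_recvs.empty()) {  // a receiver is already waiting
      auto done = std::move(pending_recvs.front());
      pending_recvs.pop_front();
      done(std::move(value));
    } else {
      pending_sends.push_back(std::move(value));
    }
  }

  void RecvAsync(std::function<void(std::string)> done) {
    if (!pending_sends.empty()) {  // a value is already buffered
      done(std::move(pending_sends.front()));
      pending_sends.pop_front();
    } else {
      pending_recvs.push_back(std::move(done));
    }
  }
};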
+ const std::unique_ptr table_buckets_; + mutex mu_; + absl::Status status_ TF_GUARDED_BY(mu_); + + // We deliberately leak one reference of the aborted rendezvous here, so that + // they won't be destructed, and lose the status_. + // This is necessary because subsequent calls to RendezvousMgr::Find() will + // return the aborted rendezvous, and proper errors will be propagated. + // TODO(hhb): find a better way to manage rendezvous lifespan. + static mutex& aborted_rendezs_mu_; + static std::vector >& aborted_rendezs_ + TF_GUARDED_BY(aborted_rendezs_mu_); + + LocalRendezvous(const LocalRendezvous&) = delete; + void operator=(const LocalRendezvous&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_LOCAL_RENDEZVOUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/log_memory.h b/third_party/tflite-hdrs/tensorflow/core/framework/log_memory.h new file mode 100644 index 00000000..f6c2b07d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/log_memory.h @@ -0,0 +1,112 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// LogMemory contains methods for recording memory allocations and +// frees, associating each allocation with a step identified by a +// process-wide id. For now, logging is enabled whenever VLOG_IS_ON(1) +// for the log_memory module. +// +// Limitations: We don't log memory allocations by Eigen on the CPU +// since that would require major changes to plumb through to the +// Eigen::{DefaultDevice,ThreadPoolDevice} allocate and deallocate +// methods. We do log Eigen allocations on GPU since the plumbing was +// already in place. +class LogMemory { + public: + // Allocations sometimes happen outside any computation step, and + // SpecialStepIds lists the ids used for those steps. + enum SpecialStepIds { + // Used when performing a just-in-time constant folding optimization. + CONSTANT_FOLDING_STEP_ID = -1, + // Used when constructing an Op kernel before executing a step. + OP_KERNEL_CONSTRUCTION_STEP_ID = -2, + // Used when allocating a tensor buffer from external code, e.g., + // the C API. + EXTERNAL_TENSOR_ALLOCATION_STEP_ID = -3, + // Used when allocating a buffer for network transfer. + NETWORK_BUFFER_STEP_ID = -4, + // Used when allocating a buffer to fill a Proto from the GPU. + PROTO_BUFFER_STEP_ID = -5, + // Used when allocating a Tensor where the caller has not indicated + // the step. + UNKNOWN_STEP_ID = -6, + }; + + static const std::string kLogMemoryLabel; + + // Test to see if memory logging is enabled. For now, logging is + // enabled whenever VLOG_IS_ON(2) for the log_memory module. + static bool IsEnabled(); + + // Log the beginning of a step. 
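Call sites normally guard the recording calls on IsEnabled() so the serialization cost is only paid when memory logging is active; roughly as sketched here (kernel_name, step_id, and tensor are assumed to be in scope; RecordStep and RecordTensorAllocation are declared below):

if (LogMemory::IsEnabled()) {
  LogMemory::RecordStep(step_id, /*handle=*/"train_step");
  LogMemory::RecordTensorAllocation(kernel_name, step_id, tensor);
}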
+ static void RecordStep(int64_t step_id, const std::string& handle); + + // Log a tensor buffer allocation. The name indicates which kernel + // made the allocation. If the allocation is made through an + // OpKernelContext the step_id indicates which step is executing, + // otherwise step_id is one of the SpecialStepIds defined in + // op_kernel.h, e.g. Op Kernel construction or an optimization pass + // such as constant folding. + static void RecordTensorAllocation(const std::string& kernel_name, + int64_t step_id, const Tensor& tensor); + + // Log a tensor buffer deallocation. The deallocation is triggered + // when the buffer's refcount falls to zero, and the tracking + // mechanism does not associate it with a particular step or + // kernel. The allocation_id/allocator_name should match a + // corresponding tensor previously passed in to + // RecordTensorAllocation. + static void RecordTensorDeallocation(int64_t allocation_id, + const std::string& allocator_name); + + // Log the use of a tensor as an output from a kernel. + static void RecordTensorOutput(const std::string& kernel_name, + int64_t step_id, int index, + const Tensor& tensor); + + // Log a "raw" allocation, which is just a buffer sized in + // bytes. The Eigen allocator, and memory copies, record their + // allocations this way, since they do not allocate TensorFlow + // tensors. The operation is set to the OpKernel name if this is + // called from within an Op execution, otherwise it indicates an + // operation such as memcpy. The step_id if >=0 indicates which step + // is executing, otherwise step_id is one of the SpecialStepIds + // defined in op_kernel.h, e.g. Op Kernel construction or an + // optimization pass such as constant folding. + static void RecordRawAllocation(const std::string& operation, int64_t step_id, + size_t num_bytes, void* ptr, + Allocator* allocator); + + // Log a "raw" deallocation of a buffer. When deferred is true, the + // buffer won't be used again, but a GPU kernel may still be + // enqueued using the buffer. A deferred deallocation should always + // be followed by a matching non-deferred deallocation when the + // buffer is actually returned and can be reused. + static void RecordRawDeallocation(const std::string& operation, + int64_t step_id, void* ptr, + Allocator* allocator, bool deferred); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/logging.h b/third_party/tflite-hdrs/tensorflow/core/framework/logging.h new file mode 100644 index 00000000..9bde3d51 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/logging.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_LOGGING_H_ +#define TENSORFLOW_CORE_FRAMEWORK_LOGGING_H_ + +#include + +namespace tensorflow { + +namespace logging { + +// Register a listener method to call on any printed messages. +// Returns true if it is successfully registered. +bool RegisterListener(void (*listener)(const char*)); + +// Log string to active listeners. Returns true if any listeners were +// registered. +bool LogToListeners(std::string msg, std::string end = "\n"); + +} // namespace logging + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_LOGGING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/lookup_interface.h b/third_party/tflite-hdrs/tensorflow/core/framework/lookup_interface.h new file mode 100644 index 00000000..9d673fbc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/lookup_interface.h @@ -0,0 +1,164 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace lookup { + +// Forward declaration so we can define GetInitializableLookupTable() in +// LookupInterface. +class InitializableLookupTable; + +// Lookup interface for batch lookups used by table lookup ops. +class LookupInterface : public ResourceBase { + public: + // Performs batch lookups, for every element in the key tensor, Find returns + // the corresponding value into the values tensor. + // If an element is not present in the table, the given default value is used. + + // For tables that require initialization, Find is available once the table + // is marked as initialized. + + // Returns the following statuses: + // - OK: when the find finishes successfully. + // - FailedPrecondition: if the table is not initialized. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + virtual absl::Status Find(OpKernelContext* ctx, const Tensor& keys, + Tensor* values, const Tensor& default_value) = 0; + + // Inserts elements into the table. Each element of the key tensor is + // associated with the corresponding element in the value tensor. + // This method is only implemented in mutable tables that can be updated over + // the execution of the graph. It returns Status::NotImplemented for read-only + // tables that are initialized once before they can be looked up. + + // Returns the following statuses: + // - OK: when the insert finishes successfully. 
+ // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - Unimplemented: if the table does not support insertions. + virtual absl::Status Insert(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) = 0; + + // Removes elements from the table. + // This method is only implemented in mutable tables that can be updated over + // the execution of the graph. It returns Status::NotImplemented for read-only + // tables that are initialized once before they can be looked up. + + // Returns the following statuses: + // - OK: when the remove finishes successfully. + // - InvalidArgument: if any of the preconditions on the lookup key fails. + // - Unimplemented: if the table does not support removals. + virtual absl::Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0; + + // Returns the number of elements in the table. + virtual size_t size() const = 0; + + // Exports the values of the table to two tensors named keys and values. + // Note that the shape of the tensors is completely up to the implementation + // of the table and can be different than the tensors used for the Insert + // function above. + virtual absl::Status ExportValues(OpKernelContext* ctx) = 0; + + // Imports previously exported keys and values. + // As mentioned above, the shape of the keys and values tensors are determined + // by the ExportValues function above and can be different than for the + // Insert function. + virtual absl::Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) = 0; + + // Returns the data type of the key. + virtual DataType key_dtype() const = 0; + + // Returns the data type of the value. + virtual DataType value_dtype() const = 0; + + // Returns the shape of a key in the table. + virtual TensorShape key_shape() const = 0; + + // Returns the shape of a value in the table. + virtual TensorShape value_shape() const = 0; + + // Check format of the key and value tensors for the Insert function. + // Returns OK if all the following requirements are satisfied, otherwise it + // returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + // - DataType of the tensor values equals to the table value_dtype + // - the values tensor has the required shape given keys and the tables's + // value shape. + virtual absl::Status CheckKeyAndValueTensorsForInsert(const Tensor& keys, + const Tensor& values); + + // Similar to the function above but instead checks eligibility for the Import + // function. + virtual absl::Status CheckKeyAndValueTensorsForImport(const Tensor& keys, + const Tensor& values); + + // Check format of the key tensor for the Remove function. + // Returns OK if all the following requirements are satisfied, otherwise it + // returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + virtual absl::Status CheckKeyTensorForRemove(const Tensor& keys); + + // Check the arguments of a find operation. Returns OK if all the following + // requirements are satisfied, otherwise it returns InvalidArgument: + // - DataType of the tensor keys equals to the table key_dtype + // - DataType of the tensor default_value equals to the table value_dtype + // - the default_value tensor has the required shape given keys and the + // tables's value shape. 
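// --- Editorial usage sketch (not part of the upstream header). A batched
// lookup as kernel code might issue it; `table`, `ctx`, and the tensors are
// assumed to come from the surrounding op, and TF_RETURN_IF_ERROR is the usual
// TensorFlow status-propagation macro. Only Find() and CheckFindArguments()
// are taken from the interface declared here.
inline absl::Status LookupBatch(tensorflow::lookup::LookupInterface* table,
                                tensorflow::OpKernelContext* ctx,
                                const tensorflow::Tensor& keys,
                                const tensorflow::Tensor& default_value,
                                tensorflow::Tensor* values) {
  // Validate dtypes and shapes up front so callers see InvalidArgument rather
  // than a failure deep inside the table implementation.
  TF_RETURN_IF_ERROR(table->CheckFindArguments(keys, default_value));
  // For initializable tables this returns FailedPrecondition until the table
  // has been marked as initialized.
  return table->Find(ctx, keys, values, default_value);
}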
+ absl::Status CheckFindArguments(const Tensor& keys, + const Tensor& default_value); + + string DebugString() const override { + return strings::StrCat("A lookup table of size: ", size()); + } + + // Returns an InitializableLookupTable, a subclass of LookupInterface, if the + // current object is an InitializableLookupTable. Otherwise, returns nullptr. + virtual InitializableLookupTable* GetInitializableLookupTable() { + return nullptr; + } + + protected: + ~LookupInterface() override = default; + + // Makes sure that the key and value tensor DataType's match the table + // key_dtype and value_dtype. + absl::Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values); + + // Makes sure that the provided shape is consistent with the table keys shape. + absl::Status CheckKeyShape(const TensorShape& shape); + + private: + absl::Status CheckKeyAndValueTensorsHelper(const Tensor& keys, + const Tensor& values); +}; + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/memory_types.h b/third_party/tflite-hdrs/tensorflow/core/framework/memory_types.h new file mode 100644 index 00000000..e1247222 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/memory_types.h @@ -0,0 +1,39 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +class NodeDef; + +// Returns into *{input,output}_memory_types the memory type of each +// {input,output} tensor. +// +// REQUIRES: * '*_memory_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +absl::Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + const DeviceType& device_type, + const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/metrics.h b/third_party/tflite-hdrs/tensorflow/core/framework/metrics.h new file mode 100644 index 00000000..18b52c49 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/metrics.h @@ -0,0 +1,550 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_METRICS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_METRICS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/monitoring/gauge.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/data_service.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace metrics { +enum class GraphOptimizationSource { + kUnknown, + kJit, + kAot, +}; + +// Records when a data-fetching tf.data operation is executed. +// +// The `name` argument identifies the operation type (e.g. "ToSingleElementOp"). +void RecordTFDataFetchOp(const string& name); + +// Records that a tf.data.Dataset executed by the program used autotuning. +// +// The `name` argument identifies the Dataset type (e.g. "ParallelMap"). +void RecordTFDataAutotune(const string& name); + +// Returns a counter that can be used to record the number of bytes produced by +// a tf.data.Dataset. +// +// The `name` argument identifies the Dataset type (e.g. "Batch" or "Map"). +monitoring::CounterCell* GetTFDataBytesConsumedCounter(const string& name); + +// Returns a counter that can be used to record the number of bytes produced by +// a tf.data.Dataset. +// +// The `name` argument identifies the Dataset type (e.g. "Batch" or "Map"). +monitoring::CounterCell* GetTFDataBytesProducedCounter(const string& name); + +// Returns a counter than can be used to record the number of bytes read from +// the filesystem by a tf.data.Dataset source. +// +// The `name` argument identifies the Dataset type (e.g. "TFRecordDataset"). +// +// TODO(jsimsa): Remove this now that we have GetTFDataBytesConsumedCounter? +monitoring::CounterCell* GetTFDataBytesReadCounter(const string& name); + +// Returns a counter than can be used to record the number of elements produced +// by a tf.data.Dataset. +// +// The `name` argument identifies the Dataset type (e.g. "Batch" or "Map"). +monitoring::CounterCell* GetTFDataElementsCounter(const string& name); + +// Returns a gauge than can be used to record the performance model information. +// +// The `id` argument represents the (unique) model ID. +monitoring::GaugeCell>* GetTFDataModelGauge( + const string& id); + +// Records the number of bytes fetched from tf.data.Dataset iterator. +void RecordTFDataBytesFetched(int64_t num_bytes); + +// Records the number of times a tf.data experiment was applied. +void RecordTFDataExperiment(const string& name); + +// Records the number of times a tf.data experiment could have been applied. +void RecordTFDataExperimentLive(const string& name); + +// Records the number of times a tf.data experiment was opted into. +void RecordTFDataExperimentOptIn(const string& experiment_name); + +// Records the number of times a tf.data experiment was opted out of. +void RecordTFDataExperimentOptOut(const string& experiment_name); + +// Records the time (in microseconds) spent generating an element and +// transferring it over the network for the given protocol. 
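// --- Editorial usage sketch (not part of the upstream header). How a
// file-backed dataset source might feed the counters declared above, assuming
// the code sits inside namespace tensorflow::metrics (or qualifies the calls).
// "TFRecordDataset" is one of the example names from the comments, and
// CounterCell::IncrementBy() is the same call ScopedCounter uses further down;
// caching the cells in statics simply avoids a registry lookup per batch.
inline void RecordTFRecordRead(int64_t bytes_read, int64_t elements_produced) {
  static monitoring::CounterCell* bytes_cell =
      GetTFDataBytesReadCounter("TFRecordDataset");
  static monitoring::CounterCell* elements_cell =
      GetTFDataElementsCounter("TFRecordDataset");
  bytes_cell->IncrementBy(bytes_read);
  elements_cell->IncrementBy(elements_produced);
}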
+void RecordTFDataServiceGetElementDuration(const string& data_transfer_protocol, + uint64 duration_us); + +// Records the time (in microseconds) spent in a single invocation of +// `ItertatorResource::GetNext()`. +void RecordTFDataGetNextDuration(uint64 duration_us); + +// Records the histogram of ratios of tf.data autotune algorithm used RAM over +// the ram budget. +void RecordTFDataAutotuneUsedRamBudgetRatio(const double ratio); + +// Records the histogram of ratios of tf.data autotune algorithm max buffer +// bytes over the ram budget. +void RecordTFDataAutotuneMaxBufferBudgetRatio(const double ratio); + +// Records the number of times each tf.data fingerprint is used +// to measure duplicate pre-processing. +// +// The `name` argument identifies the Dataset graph fingerprint, +// created using GraphHash(). +void RecordTFDataFingerprint(const string& name); + +// Records the event of a tf.data service pipeline getting a runtime +// compression decision. +void RecordTFDataServiceRuntimeCompressionDecision(bool compression_decision); + +// Records the event of a tf.data service pipeline making the compression +// related action. +void RecordTFDataServiceCompressionAction(const string& action); + +// Records the time (in microseconds) during which `IteratorResource` was busy +// processing at least one `GetNext()` request. +void RecordTFDataIteratorBusy(uint64 duration_us); + +// Records the time (in microseconds) between `IteratorResource` receiving the +// first `GetNext()` request and responding to the last `GetNext()` request. +void RecordTFDataIteratorLifetime(uint64 duration_us); + +// Records the time histogram (in microseconds) between `IteratorResource` +// responding to a `GetNext()` request and receiving the next `GetNext()` +// request. +void RecordTFDataIteratorGap(uint64 duration_us); + +// Records the number of independent graph changes resulting from the +// application of a tf.data optimization. +// +// The `name` argument identifies the optimization (e.g. "noop_elimination"). +void RecordTFDataOptimization(const string& name, int64_t num_changes); + +// Records that a tf.data service worker has been created. +void RecordTFDataServiceWorkerCreated(); + +// Records that a tf.data service job has been created. +void RecordTFDataServiceJobsCreated( + const data::ProcessingModeDef& processing_mode, bool is_coordinated_read); + +// Records tf.data service iterators created by clients. +void RecordTFDataServiceClientIterators( + int64_t worker_uid, data::DeploymentMode deployment_mode, + const data::ProcessingModeDef& processing_mode, bool is_coordinated_read); + +// Records that a tf.data service worker client has been created that will use +// `data_transfer_protocol` to get data from the worker server and whether or +// not the user explicitly specified the protocol. +void RecordTFDataServiceDataTransferProtocolUsed( + const string& data_transfer_protocol, bool user_specified); + +// Records that a tf.data service worker client fell back to gRPC rather than +// use `data_transfer_protocol` because of an error of type `code` with message +// `error_message`. +void RecordTFDataServiceDataTransferProtocolFallback( + const string& data_transfer_protocol, error::Code code, + const string& error_message); + +// Records that a tf.data service worker client got an error of non-retriable +// type `code` with message `error_message` when trying to transfer data over +// `data_transfer_protocol`. 
+void RecordTFDataServiceDataTransferProtocolError( + const string& data_transfer_protocol, error::Code code, + const string& error_message); + +// Records tf.data service cross-trainer cache queries. +void RecordTFDataServiceCrossTrainerCacheQuery(bool cache_hit); + +// Records tf.data service cross-trainer cache memory usage in bytes. +void RecordTFDataServiceCrossTrainerCacheSizeBytes(size_t bytes); + +// Records tf.data distributed snapshot bytes committed. +void RecordTFDataServiceSnapshotBytesCommitted(int64_t bytes); + +// Records tf.data distributed snapshot save/load ops. +void RecordTFDataServiceSnapshotOp(const std::string& path, + const std::string& op); + +// Records the current estimated optimal number of tf.data service workers. +void RecordTFDataServiceOptimalNumberOfWorkers(int64_t number_of_workers); + +// Records the file name read by a tf.data Dataset. +// +// The `name` argument identifies the Dataset type (e.g. "TFRecordDataset"). +void RecordTFDataFilename(const string& name, const string& filename); + +// Records the total attempts made by file logger. +void RecordTFDataFileLoggerAttempts(); + +// Records an error of type `code` with message `error_message` encountered by +// file logger. +void RecordTFDataFileLoggerErrors(error::Code code, + const string& error_message); + +// Records the total number of files attempted to be logged by file logger. +void RecordTFDataFileLoggerAttemptedNumFiles(size_t num_files); + +// Records the number of files that encountered an error of type +// `code` with message `error_message` during logging by file logger with this +// error code. +void RecordTFDataFileLoggerErrorsNumFiles(size_t num_files, error::Code code, + const string& error_message); + +// Records statistics of tf.data auto sharding. +// +// The `id` is a unique identifier of the input pipeline. The `policy` +// identifies the auto-sharding policy used, the `num_workers` identifies the +// number of workers, and `num_replicas` identifies the number of replicas. +void RecordTFDataAutoShard(const string& id, data::AutoShardPolicy policy, + int64 num_workers, int64 num_replicas); + +// Records statistics of whether we can rewrite batch size in tf.data auto +// sharding. +// +// The `id` is a unique identifier of the input pipeline. The `eligible` +// indicates whether the input pipeline is eligible for the rewrite. The +// `ineligible_reason` is the reason if the input pipeline is ineligible. +void RecordTFDataAutoShardRewriteBatchSize( + bool eligible, const std::vector& ineligible_reason); + +// Records the number of times each tf.data autotuning algorithm stopping +// criterion is met. +void RecordTFDataAutotuneStoppingCriteria(const string& name); + +// Records the number of times this event occured, for debugging. +void RecordTFDataDebug(const string& event); + +// Records the number of times an error of this type occurred with this status +// code. +void RecordTFDataError(const string& error_type, const string& error_code); + +// Records the framework type used to build the tf.data.Dataset. +void RecordTFDataFrameworkType(const std::string& framework_type); + +// Records the number of times tf.data file logger encountered an error of this +// type occurred with this status code. +void RecordTFDataFileLoggerError(const string& error_type, + const string& error_code); + +// Records parsing of dense tensor features. +void RecordParseDenseFeature(int64_t num_features); + +// Records parsing of sparse tensor features. 
+void RecordParseSparseFeature(int64_t num_features); + +// Records parsing of ragged tensor features. +void RecordParseRaggedFeature(int64_t num_features); + +// Records the size of input/output tensors in bytes. +void RecordGraphInputTensors(const size_t size); +void RecordGraphOutputTensors(const size_t size); + +// Records the number of cores requested by graphs with XLA SPMD enabled. +void RecordTPUXlaSpmdCoresPerReplica(int64_t cores_per_replica); + +void UpdateGraphExecTime(const uint64 running_time_usecs); +void UpdateGraphPendingQueueLength(uint64 len); + +// Records that one output of an op of type `op_name` was unused. +void RecordUnusedOutput(const string& op_name); + +// Records the pipeline processing time in microseconds +void RecordPipelineProcessingTime(const string& id, + double pipeline_processing_time_usec); + +// Increments the count of binaries loaded from the persistent cache. +void UpdatePersistentCacheLoadCount(); + +// Increments the count of BEF and MLIR deserialized. +void UpdateAotBefMlirLoadCount(); + +// Updates the metrics stored about time spent building graphs. +// +// By "GraphBuild", we refer to building a client graph, which is a sub-graph of +// the full graph, induced by a set of options. In particular, these options +// include the feeds and fetches requested. +// +// This includes time spent: +// * optimizing the graphs with Grappler +// * pruning the sub-graph (unless the place_pruned_graph option is set) +// +// When executing eagerly, this will not record any activity. +// +// TODO(jtkeeling): Should we record building/optimizing tf.functions? +void UpdateGraphBuildTime(const uint64 running_time_usecs); + +// Updates the metric stored for time spent optimizing function graphs. +void UpdateFunctionGraphOptimizationTime(const uint64 running_time_usecs); + +// Updates the metric stored for time saved by caching graph optimization. +void UpdateFunctionGraphOptimizationSavingTime(uint64 saving_time_usec, + GraphOptimizationSource source); + +// Retrieves the total time saved by the graph optimization caching. +uint64 GetFunctionGraphOptimizationSavingTimeUsecs( + GraphOptimizationSource source); + +// Increments the hit count for the graph optimization cache. +void IncrementFunctionGraphOptimizationCacheHitCount( + int count, GraphOptimizationSource source); + +// Gets the hit count for the graph optimization cache. +int64_t GetFunctionGraphOptimizationCacheHitCount( + GraphOptimizationSource source); + +// Increments the failure count for the graph optimization cache restoring. +void IncrementFunctionGraphOptimizationCacheFailureCount( + int count, GraphOptimizationSource source); + +// Gets the failure count for the graph optimization cache. +int64_t GetFunctionGraphOptimizationCacheFailureCount( + GraphOptimizationSource source); + +// Increments the miss count for the graph optimization cache. +void IncrementFunctionGraphOptimizationCacheMissCount( + int count, GraphOptimizationSource source); + +// Gets the miss count for the graph optimization cache. +int64_t GetFunctionGraphOptimizationCacheMissCount( + GraphOptimizationSource source); + +// Increments the number of restoring function graph optimization cache. +void IncrementFunctionGraphOptimizationCacheLoadCount( + int count, GraphOptimizationSource source); + +int64_t GetFunctionGraphOptimizationCacheLoadCount( + GraphOptimizationSource source); + +// Records the activity of the first phase of the mlir bridge using the +// tf_metadata.tf_mlir_bridge_first_phase_v2_count metric. 
+// bridge_type: replicated, nonreplicated, etc. +// bridge_version: v1 compat, v2, etc. +// device_type: tpu, cpu, gpu, etc. +// fallback_enabled: true if fallback will happen, false if not +// result: outcome of bridge (success, failure, disabled, invalid_graph, etc.) +void UpdateTfMlirBridgeFirstPhaseCounter(const std::string& bridge_type, + const std::string& bridge_version, + const std::string& device_type, + bool fallback_enabled, + const std::string& result); + +enum class Phase2XlaCompilerMetric { + // Bridge phase 2 CompileSingleOp Xla Builder (old version) was successful + kCompileSingleOpXlaBuilderSuccess, + // Bridge phase 2 CompileSingleOp Xla Builder (old version) failed + kCompileSingleOpXlaBuilderFailure, + // Bridge phase 2 CompileSingleOp MLIR version was successful + kCompileSingleOpMlirSuccess, + // Bridge phase 2 CompileSingleOp MLIR version failed + kCompileSingleOpMlirFailure, + // Bridge phase 2 CompileFunction Xla Builder (old version) was successful + kCompileFunctionXlaBuilderSuccess, + // Bridge phase 2 CompileFunction Xla Builder (old version) failed + kCompileFunctionXlaBuilderFailure, + // Bridge phase 2 CompileFunction MLIR version was successful + kCompileFunctionMlirSuccess, + // Bridge phase 2 CompileFunction MLIR version failed + kCompileFunctionMlirFailure, +}; + +// Records the activity of the XlaCompiler entry points. +void IncrementPhase2XlaCompilerCounter(Phase2XlaCompilerMetric metric); + +enum class MlirBridgeSecondPhaseMetric { + // MLIR bridge phase 2 was executed and the graph was processed successfully + // (fallback enabled). + kMlirWithFallbackModeSuccess, + // MLIR bridge phase 2 compilation was failure (fallback enabled). + kMlirWithFallbackModeFailure, + // MLIR bridge phase 2 compilation was successful (manually enabled). + kMlirModeSuccess, + // MLIR bridge phase 2 compilation fails (manually enabled) + kMlirModeFailure, + // Old bridge compilation was run successfully (was run because MLIR bridge + // could not process the graph). + kOldBridgeMlirFilteredSuccess, + // Old bridge failed (was run b/c MLIR bridge could not process the graph). + kOldBridgeMlirFilteredFailure, + // Old bridge compilation was successfully run after MLIR bridge ran and + // failed. + kOldBridgeWithFallbackModeSuccess, + // Old Bridge failed in fallback (was run because MLIR bridge failed first). + kOldBridgeWithFallbackModeFailure, + // MLIR bridge phase 2 Combined Bridge MLIR was successful + kMlirCombinedMlirSuccess, + // MLIR bridge phase 2 Combined Bridge MLIR failed + kMlirCombinedMlirFailure, + // MLIR bridge phase 2 Combined Bridge Old bridge was successful + kMlirCombinedOldSuccess, + // MLIR bridge phase 2 Combined Bridge Old bridge was successful + kMlirCombinedOldFailure, +}; + +// Records the activity of the second phase of the mlir bridge. +void IncrementTfMlirBridgeSecondPhaseCounter( + MlirBridgeSecondPhaseMetric metric); + +// Records the activity per op using the +// tf_metadata.tf_mlir_bridge_graph_analysis_per_op. +// op_name: the name of op. +// construction_context: eager, session, Not tracked. +// is_single_core_inference_mode: true, false. +// unsupported_reason: the reason why the graph is not supported in MLIR-based +// bridge, like invalid graph, has unsupported ops, etc. +// has_unsupported_features: true indicates MLIR-based bridge is disabled, +// false indicates MLIR-based bridge is enabled. 
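// --- Editorial usage sketch (not part of the upstream header). Recording one
// phase-1 outcome and one phase-2 outcome; the label strings follow the
// examples listed in the comment above ("nonreplicated", "v2", "tpu",
// "success"), and the call site itself is hypothetical.
inline void RecordBridgeSuccessExample() {
  UpdateTfMlirBridgeFirstPhaseCounter(/*bridge_type=*/"nonreplicated",
                                      /*bridge_version=*/"v2",
                                      /*device_type=*/"tpu",
                                      /*fallback_enabled=*/true,
                                      /*result=*/"success");
  IncrementTfMlirBridgeSecondPhaseCounter(
      MlirBridgeSecondPhaseMetric::kMlirWithFallbackModeSuccess);
}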
+ +void UpdateTfMlirBridgeGraphAnalysisPerOp( + const std::string& op_name, const std::string& construction_context, + bool is_single_core_inference_mode, const std::string& num_replicas, + const std::string& num_cores_per_replica, const std::string& use_tpu, + const std::string& allow_soft_placement, + const std::string& use_spmd_for_xla_partitioning, + const std::string& unsupported_reason, bool has_unsupported_features); + +// Records whether a graph contains any of the TF1 features +void RecordTFVersionByGraphFeatures(const std::string& device, + const std::string& context, + bool hasControlFlowV1, + bool hasReferenceVariables, + bool hasManualControlDeps); + +// Convenience class allowing RAII style of reporting for a monitoring::Counter. +template +class ScopedCounter final { + public: + ScopedCounter(monitoring::Counter* const counter, + const std::array& labels) + : counter_(counter), labels_(labels) { + Init(); + } + + // Report counter and stop it. Counter needs to be reset to perform + // next measurement. + void ReportAndStop() { + if (started_) { + started_ = false; + ReportInternal(std::make_index_sequence()); + } + } + + // Start the measurement with the new set of labels. + void Reset(const std::array& labels) { + labels_ = labels; + Init(); + } + + // Start the measurement with the existing set of labels. + void Reset() { Init(); } + + // Returns duration of the current interval in case the timer has started. + // Returns nullopt otherwise. + std::optional DurationMicroSec() const { + return started_ ? std::optional(accumulated_time_ + + Env::Default()->NowMicros() - + start_time_) + : std::nullopt; + } + + // Temporarily stop the timer, but keep accumulated time. + void AccumulateAndStop() { + if (started_) { + accumulated_time_ = Env::Default()->NowMicros() - start_time_; + started_ = false; + } + } + + // Start previously stopped timer. + void Start() { + if (started_) return; + + // Keep previously accumulated time if any. + start_time_ = Env::Default()->NowMicros(); + started_ = true; + } + + ~ScopedCounter() { ReportAndStop(); } + + private: + template + void ReportInternal(std::index_sequence) { + uint64 time_interval = Env::Default()->NowMicros() - start_time_; + time_interval += accumulated_time_; + if (time_interval > 0) { + counter_->GetCell(labels_[S]...)->IncrementBy(time_interval); + } + } + + void Init() { + start_time_ = Env::Default()->NowMicros(); + started_ = true; + accumulated_time_ = 0; + } + + monitoring::Counter* counter_; + std::array labels_; + bool started_{false}; + uint64 start_time_; + uint64 accumulated_time_; +}; + +// Returns a counter used to capture timing metrics for graph optimization +// passes. +monitoring::Counter<2>* GetGraphOptimizationCounter(); + +// Updates metrics for time to distribute variables to all TPU hosts. +void UpdateTpuVariableDistributionTime(const uint64 distribution_time_usecs); + +// Updates the metrics stored about time XLA spents compiling graphs. +void UpdateXlaCompilationTime(const uint64 compilation_time_usecs); + +// Increments (by 1) a simple integer counter that is exposed for testing. +void IncrementTestCounter(const string& name, const string& label); + +// Read-only access to a counter for testing. +const monitoring::CounterCell* TestCounter(const string& name, + const string& label); + +// Read-only wrapper for a TestCounter to track increments between calls. 
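// --- Editorial usage sketch (not part of the upstream header). RAII timing
// with ScopedCounter: the template arguments were stripped in this vendored
// copy, so the sketch assumes the upstream form in which ScopedCounter is
// parameterized by the number of labels and GetGraphOptimizationCounter()
// returns a two-label counter. The label values are illustrative.
inline void TimeOptimizationPassExample() {
  ScopedCounter<2> timer(GetGraphOptimizationCounter(),
                         {"GraphOptimizationPass", "MyExamplePass"});
  // ... run the pass being timed ...
  timer.AccumulateAndStop();  // Pause the timer around untimed bookkeeping.
  timer.Start();              // Resume; previously accumulated time is kept.
  // ... more timed work ...
  timer.ReportAndStop();      // Or simply rely on the destructor to report.
}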
+class TestDelta { + public: + TestDelta(const string& name, const string& label); + void Reset(); + int64 Get(); + + private: + const monitoring::CounterCell* cell_; + int64 last_value_; +}; +void UpdateTpuErrorCounter(const string& op, const string& error_type); +void UpdateEagerClientErrorCounter(const string& error_source, + const string& error_type); + +} // namespace metrics +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_METRICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/model.h b/third_party/tflite-hdrs/tensorflow/core/framework/model.h new file mode 100644 index 00000000..4c78ec7a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/model.h @@ -0,0 +1,1294 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +// TODO(b/114492873): Move this include into core/platform. +#include +#include // NOLINT +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/optional.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/metrics.h" +#include "tensorflow/core/framework/model.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringprintf.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { +namespace model { + +// A constant that can be used to enable auto-tuning. +constexpr int64_t kAutotune = -1; +constexpr char kParallelism[] = "parallelism"; +constexpr char kBufferSize[] = "buffer_size"; +constexpr char kCycleLength[] = "cycle_length"; +constexpr char kDeterministic[] = "deterministic"; +constexpr char kMaxBufferedElements[] = "max_buffered_elements"; + +// A key used to identify the input time of the model. +constexpr char kModelInputTimeKey[] = "model_input_time"; + +// Default share of available RAM that can be used by model's internal buffers. +constexpr double kRamBudgetShare = 0.5; + +// Weight of the latest processing time used in computing the exponential moving +// average of processing time per element. 
+constexpr double kProcessingTimeEmaWeight = 0.1; + +enum class TraversalOrder { + BFS = 0, + REVERSE_BFS = 1, +}; + +// Represents thread-safe state that can be shared between an input pipeline and +// the performance model. +struct SharedState { + public: + SharedState(int64_t value, std::shared_ptr mu, + std::shared_ptr cond_var) + : value(value), + mu(std::move(mu)), + cond_var(std::move(cond_var)), + tunable(value == kAutotune) {} + + double value; + const std::shared_ptr mu; + const std::shared_ptr cond_var; + const bool tunable; +}; + +// Represents a parameter. +struct Parameter { + Parameter(const string& name, std::shared_ptr state, double min, + double max) + : name(name), + // Sometimes non-autotune nodes (with `autotune_=false`) may contain + // parameters (for example inputs of parallel interleave dataset which + // are not in the current cycle). To avoid unrealistic situation + // (say `buffer_size=-1` or `parallelism=-1`) in the optimization + // computation, if the state value is `kAutotune=-1` (just to indicate + // the `SharedState` is tunable), we initialize the parameter value to + // be the minimal value of the state. + value(state == nullptr || state->value == kAutotune ? min + : state->value), + min(min), + max(max), + state(std::move(state)) {} + + explicit Parameter(const std::shared_ptr parameter) + : name(parameter->name), + value(parameter->value), + min(parameter->min), + max(parameter->max), + state(parameter->state) {} + + // Human-readable name of the parameter. + const string name; + + // Identifies the model value of the parameter. This can be different from + // the actual value (e.g. during optimization search). + double value; + + // Identifies the minimum value of the parameter. + const double min; + + // Identifies the maximum value of the parameter. + const double max; + + // Shared state of the parameter. + std::shared_ptr state; +}; + +// Returns a new tunable parameter with the value set to `min`. +std::shared_ptr MakeParameter(const string& name, + std::shared_ptr state, + double min, double max); + +// Returns a new tunable parameter with the value set to `value` instead +// of `min`. +std::shared_ptr MakeParameter(const string& name, + std::shared_ptr state, + double min, double max, double value); + +// Returns a new non-tunable parameter. +std::shared_ptr MakeNonTunableParameter(const string& name, + double value); + +// Class for managing the ram budget of an iterator. This is necessary for +// coordinating ram usage between the model-based autotuner and the legacy +// prefetch autotuner. Once the legacy autotuner is retired we can remove this +// class and move all ram budget management to the model autotuner. +class RamBudgetManager { + public: + explicit RamBudgetManager(int64_t budget) : budget_(budget) { + if (budget <= 0) { + LOG(WARNING) << "RAM budget is " << budget + << " which could prevent autotuner from properly adjusting " + "buffer sizes."; + } + } + + // Requests a new total memory allocation for the parts of the dataset + // tuned by the model. + // + // The autotuner is expected to follow a pattern like + // + // int64_t budget = ram_budget_manager.AvailableModelRam(); + // NewModel potential_new_params = OptimizeModel(budget); + // int64_t new_ram_used = potential_new_params.RamUsed(); + // if (ram_budget_manager.RequestModelAllocation(new_ram_used)) { + // ApplyModel(potential_new_params); + // } + // + // Returns whether the request succeeded. 
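// --- Editorial usage sketch (not part of the upstream header; it would live
// in autotuner code rather than inside this class). It walks through how the
// budget is shared between the legacy prefetch autotuner and the model
// autotuner; the byte values are arbitrary.
inline void RamBudgetManagerExample() {
  RamBudgetManager mgr(/*budget=*/100 << 20);            // 100 MiB in total.
  mgr.RequestLegacyPrefetchBytes(30 << 20);              // Prefetch reserves 30 MiB.
  int64_t model_budget = mgr.AvailableModelRam();        // 70 MiB left for the model.
  bool fits = mgr.RequestModelAllocation(80LL << 20);    // false: exceeds the 70 MiB left.
  fits = mgr.RequestModelAllocation(60LL << 20);         // true: 60 MiB is reserved.
  (void)model_budget;
  (void)fits;
}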
+ bool RequestModelAllocation(int64_t total_bytes) { + mutex_lock l(mu_); + if (total_bytes > budget_ - legacy_prefetch_allocated_) { + return false; + } + model_allocated_ = total_bytes; + return true; + } + + // Requests `delta_elements` allocated to the model where each element is of + // size `element_size` bytes. `delta_elements` can be negative. + // Returns the actual allocated delta elements. + int64_t RequestModelBytes(int64_t delta_elements, double element_size) { + if (delta_elements == 0) { + return 0; + } + int64_t allocated_delta_elements = delta_elements; + mutex_lock l(mu_); + // If `delta_elements` is positive, allocate only up to the available + // memory. + if (delta_elements > 0) { + int64_t max_delta_elements = static_cast( + (budget_ - legacy_prefetch_allocated_ - model_allocated_) / + element_size); + if (max_delta_elements < 0) { + return 0; + } + allocated_delta_elements = std::min(max_delta_elements, delta_elements); + } + model_allocated_ += + static_cast(allocated_delta_elements * element_size); + return allocated_delta_elements; + } + + // Requests `bytes` additional bytes for the purpose of legacy prefetch + // autotuning. + // + // Unlike RequestModelAllocation, we use a delta number of bytes, since there + // can only be one model per iterator but there may be multiple legacy + // prefetch autotuners. + // + // Returns whether there were enough bytes left in the budget to serve the + // request. If not, no bytes are allocated. + bool RequestLegacyPrefetchBytes(int64_t delta_bytes) { + mutex_lock l(mu_); + if (delta_bytes > budget_ - legacy_prefetch_allocated_ - model_allocated_) { + return false; + } + legacy_prefetch_allocated_ += delta_bytes; + return true; + } + + // The total number of bytes that the model could potentially use. + int64_t AvailableModelRam() const { + tf_shared_lock l(mu_); + return budget_ - legacy_prefetch_allocated_; + } + + void UpdateBudget(int64_t budget) { + mutex_lock l(mu_); + budget_ = budget; + VLOG(2) << "Updated ram budget to " << budget; + } + + std::string DebugString() { + mutex_lock l(mu_); + return absl::StrCat("RamBudgetManager: budget_: ", budget_, + " prefetch allocated: ", legacy_prefetch_allocated_, + " model allocated: ", model_allocated_); + } + + private: + mutable mutex mu_; + int64_t budget_ TF_GUARDED_BY(mu_) = 0; + // Number of bytes allocated by legacy prefetch autotuner. + int64_t legacy_prefetch_allocated_ TF_GUARDED_BY(mu_) = 0; + // Number of bytes allocated by the model. + int64_t model_allocated_ TF_GUARDED_BY(mu_) = 0; +}; + +// Abstract representation of a TensorFlow input pipeline node. It collects +// information about inputs to this node, processing time spent executing the +// node logic, number of elements produced by the node, various other +// information (e.g. batch size or execution parallelism). +// +// Developers of tf.data transformations are not expected to interact with +// this class directly. Boiler plate code for creating the abstract +// representation of the input pipeline and collecting common information has +// been added to the implementation of `DatasetBase` and `DatasetBaseIterator` +// respectively. +// +// In addition, `DatasetBaseIterator` provides wrappers that can be used for +// transformation-specific information collection. 
The `SetMetadata` wrapper +// can be used to pass arbitrary metadata to the modeling framework, while the +// `StartWork` and `StopWork` wrappers should be used to correctly account for +// processing time of multi-threaded transformation that yield the CPU; such +// transformations should invoke `StartWork()` when a transformation thread +// starts executing (e.g. when created or woken up) and `StopWork()` when a +// transformation thread stops executing (e.g. when returning or waiting). +class Node { + public: + // Arguments for `Node` constructor. + struct Args { + int64_t id; + string name; + std::shared_ptr output; + }; + + using Factory = std::function(Args)>; + using NodeVector = std::vector>; + using NodePairList = + std::list, std::shared_ptr>>; + using ModelParameters = + std::vector>>; + using NodeValues = absl::flat_hash_map; + using ParameterGradients = + absl::flat_hash_map, double>; + + explicit Node(Args args) + : id_(args.id), + name_(std::move(args.name)), + autotune_(true), + buffered_bytes_(0), + peak_buffered_bytes_(0), + buffered_elements_(0), + buffered_elements_low_(std::numeric_limits::max()), + buffered_elements_high_(std::numeric_limits::min()), + bytes_consumed_(0), + bytes_produced_(0), + num_elements_(0), + processing_time_(0), + record_metrics_(true), + metrics_(name_), + output_(args.output.get()), + output_weak_ptr_(args.output) {} + + virtual ~Node() { + // Clear the sub-nodes instead of relying on implicit shared pointer + // destructor to avoid potential stack overflow when the tree is deep. + std::deque> queue; + { + mutex_lock l(mu_); + while (!inputs_.empty()) { + queue.push_back(inputs_.front()); + inputs_.pop_front(); + } + } + while (!queue.empty()) { + auto node = queue.back(); + queue.pop_back(); + { + mutex_lock l(node->mu_); + while (!node->inputs_.empty()) { + queue.push_back(node->inputs_.front()); + node->inputs_.pop_front(); + } + } + } + + FlushMetrics(); + } + + // Adds an input. + void add_input(std::shared_ptr node) TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.push_back(node); + } + + // Increments the aggregate processing time by the given delta. + void add_processing_time(int64_t delta) TF_LOCKS_EXCLUDED(mu_) { + processing_time_ += delta; + } + + // Returns an indication whether autotuning is enabled for this node. + bool autotune() const TF_LOCKS_EXCLUDED(mu_) { return autotune_; } + + // Returns the number of bytes stored in this node's buffer. + int64_t buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) { + return buffered_bytes_; + } + + // Returns the peak number of bytes stored in this node's buffer. + int64_t peak_buffered_bytes() const TF_LOCKS_EXCLUDED(mu_) { + return peak_buffered_bytes_; + } + + // Returns the number of elements stored in this node's buffer. + int64_t buffered_elements() const TF_LOCKS_EXCLUDED(mu_) { + return buffered_elements_; + } + + // Returns the low watermark of the number of elements stored in this node's + // buffer. The watermarks are reset at the beginning of the execution time and + // each time the buffer is upsized or downsized. + int64_t buffered_elements_low() const TF_LOCKS_EXCLUDED(mu_) { + return buffered_elements_low_; + } + + // Returns the high watermark of the number of elements stored in this node's + // buffer. The watermarks are reset at the beginning of the execution time and + // each time the buffer is upsized or downsized. 
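// --- Editorial usage sketch (not part of the upstream header; this accounting
// is normally done by DatasetBaseIterator boilerplate). An asynchronous
// iterator would pair calls like these around its internal buffer so that the
// byte/element counters and the watermarks above reflect real occupancy;
// `element_bytes` and the helper names are illustrative.
inline void OnElementBuffered(const std::shared_ptr<Node>& node,
                              int64_t element_bytes) {
  node->record_buffer_event(element_bytes, /*elements_delta=*/1);
}

inline void OnElementDequeued(const std::shared_ptr<Node>& node,
                              int64_t element_bytes) {
  node->record_buffer_event(-element_bytes, /*elements_delta=*/-1);
}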
+ int64_t buffered_elements_high() const TF_LOCKS_EXCLUDED(mu_) { + return buffered_elements_high_; + } + + // Returns the number of bytes consumed by the node. + int64_t bytes_consumed() const TF_LOCKS_EXCLUDED(mu_) { + return bytes_consumed_; + } + + // Returns the number of bytes produced by the node. + int64_t bytes_produced() const TF_LOCKS_EXCLUDED(mu_) { + return bytes_produced_; + } + + // Indicates whether the node has tunable parameters. + bool has_tunable_parameters() const TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + for (const auto& pair : parameters_) { + if (pair.second->state->tunable) return true; + } + return false; + } + + // Returns the unique node ID. + int64_t id() const TF_LOCKS_EXCLUDED(mu_) { return id_; } + + // Returns the node inputs. + std::list> inputs() const TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return inputs_; + } + + // Returns a longer node name that is guaranteed to be unique. + string long_name() const { return strings::StrCat(name_, "(id:", id_, ")"); } + + // Returns the node name. + const string& name() const { return name_; } + + // Returns the number of elements produced by the node. + int64_t num_elements() const TF_LOCKS_EXCLUDED(mu_) { return num_elements_; } + + // Returns the node output. + Node* output() const { return output_; } + std::shared_ptr output_shared() { return output_weak_ptr_.lock(); } + + // Returns the parameter value. + double parameter_value(const string& name) const TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return parameters_.at(name)->state->value; + } + + // Returns the aggregate processing time. + int64_t processing_time() const TF_LOCKS_EXCLUDED(mu_) { + return processing_time_; + } + + // Records that the node consumed the given number of bytes. + void record_bytes_consumed(int64_t num_bytes) { + bytes_consumed_ += num_bytes; + } + + // Records that the node produced the given number of bytes. + void record_bytes_produced(int64_t num_bytes) { + bytes_produced_ += num_bytes; + } + + // Records the change in this node's buffer. + void record_buffer_event(int64_t bytes_delta, int64_t elements_delta) { + buffered_bytes_ += bytes_delta; + peak_buffered_bytes_.store(std::max(peak_buffered_bytes_, buffered_bytes_)); + buffered_elements_ += elements_delta; + // There is no need to maintain watermarks for synchronous ops because we + // will not upsize or downsize the buffers of synchronous ops. + if (IsAsync()) { + int64_t low_watermark = + std::min(buffered_elements_low_, buffered_elements_); + buffered_elements_low_ = low_watermark; + int64_t high_watermark = + std::max(buffered_elements_high_, buffered_elements_); + buffered_elements_high_ = high_watermark; + } + } + + // Records that the node produced an element. + void record_element() TF_LOCKS_EXCLUDED(mu_) { + num_elements_++; + { + mutex_lock l(mu_); + UpdateProcessingTimeEma(); + } + } + + // Records that a node thread has started executing. + void record_start(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) { + DCHECK_EQ(work_start_, 0); + work_start_ = time_nanos; + } + + // Records that a node thread has stopped executing. + void record_stop(int64_t time_nanos) TF_LOCKS_EXCLUDED(mu_) { + // TODO(jsimsa): Use DCHECK_NE(work_start_, 0) here. + if (work_start_ != 0) { + processing_time_ += time_nanos - work_start_; + work_start_ = 0; + } else { + VLOG(1) << "Encountered a stop event without a matching start event."; + } + } + + // Returns whether work is currently being recorded, i.e. 
whether we are + // currently between a `record_start` and a `record_stop`. + bool is_recording() TF_LOCKS_EXCLUDED(mu_) { return work_start_ > 0; } + + // Removes an input. + void remove_input(std::shared_ptr input) TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.remove(input); + } + + // Sets the value that determines whether autotuning is enabled for this node. + void set_autotune(bool autotune) TF_LOCKS_EXCLUDED(mu_) { + autotune_.store(autotune); + } + + // Resets buffer watermarks to the current buffered elements. + void ResetBufferWatermarks() { + if (!IsAsync()) { + return; + } + int64_t current_buffer_size = buffered_elements_; + buffered_elements_low_ = current_buffer_size; + buffered_elements_high_ = current_buffer_size; + } + + // Returns true for asynchronous nodes; false otherwise. + virtual bool IsAsync() const { return false; } + + // Returns the ratio of the node, which is defined as the number of elements + // per input needed by the node to produce an element, e.g. batch size of a + // `Batch`. It can be 0 if the ratio is unknown. + virtual double Ratio() const { return 1.0; } + + // Computes the self time in nanoseconds of the node to produce one element. + virtual double ComputeSelfTime() const; + + // Returns the parameter value if it exists, not ok status otherwise. + absl::StatusOr ParameterValue(const std::string& parameter_name) const + TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + if (parameters_.contains(parameter_name)) { + return parameters_.at(parameter_name)->value; + } + return errors::NotFound("Parameter ", parameter_name, + " was not found in model node ", long_name()); + } + + // Given the average time between events when the elements in the buffer are + // produced (`producer_time`), the average time between events when elements + // in the buffer are consumed (`consumer_time`) and the buffer size, the + // method computes the expected time a consumer event will have to wait. + // + // The wait time is approximated as the product of the probability the buffer + // will be empty and the time it takes to produce an element into the buffer. + // + // The formula used for computing the probability is derived by modeling the + // problem as an M/M/1/K queue + // (https://en.wikipedia.org/wiki/Birth%E2%80%93death_process#M/M/1/K_queue). + // + // Collects derivatives of `ComputeWaitTime` w.r.t `producer_time`, + // `consumer_time' and `buffer_size` if the corresponding pointers are not + // `nullptr`. + static double ComputeWaitTime(double producer_time, double consumer_time, + double buffer_size, + double* producer_time_derivative, + double* consumer_time_derivative, + double* buffer_size_derivative); + + // Collects tunable parameters in the subtree rooted in this node. + ModelParameters CollectTunableParameters() const TF_LOCKS_EXCLUDED(mu_); + + // Collects tunable parameters in this node. + ModelParameters CollectNodeTunableParameters() const TF_LOCKS_EXCLUDED(mu_); + + // Returns a human-readable representation of this node. + string DebugString() const TF_LOCKS_EXCLUDED(mu_); + + // Flushes the metrics recorded by this node. + void FlushMetrics() TF_LOCKS_EXCLUDED(mu_); + + // Returns the per-element output time for this node and if `gradients` is not + // `nullptr`, collects the output time gradient w.r.t. tunable parameters of + // the subtree rooted in this node. 
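// --- Editorial usage sketch (not part of the upstream header). Querying the
// expected consumer wait time for a hypothetical prefetch buffer; passing
// nullptr for the derivative outputs skips gradient collection. Times are in
// nanoseconds and the numbers are made up.
inline double ExpectedWaitTimeExample() {
  // The producer takes ~2ms per element, the consumer asks every ~1ms, and
  // the buffer holds 8 elements.
  return Node::ComputeWaitTime(/*producer_time=*/2.0e6,
                               /*consumer_time=*/1.0e6,
                               /*buffer_size=*/8.0,
                               /*producer_time_derivative=*/nullptr,
                               /*consumer_time_derivative=*/nullptr,
                               /*buffer_size_derivative=*/nullptr);
}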
+ double OutputTime(NodeValues* input_times, + ParameterGradients* gradients) const TF_LOCKS_EXCLUDED(mu_); + + // Returns a copy of this node, making a deep copy of its inputs and a + // shallow copy of its tunable parameters. + // + // The purpose for this method is to allow the model optimization logic to + // operate over immutable state while allowing concurrent model updates. + std::shared_ptr Snapshot() const TF_LOCKS_EXCLUDED(mu_); + + // Returns the per-element processing time in nanoseconds spent in this node. + double SelfProcessingTime() const TF_LOCKS_EXCLUDED(mu_); + + // Returns the total number of bytes buffered in all nodes in the subtree for + // which autotuning is enabled. + double TotalBufferedBytes() const TF_LOCKS_EXCLUDED(mu_); + + // Collects the total buffer limit of all nodes in the subtree for which + // autotuning is enabled. This number represents the amount of memory that + // would be used by the subtree nodes if all of their buffers were full. + double TotalMaximumBufferedBytes() const TF_LOCKS_EXCLUDED(mu_); + + // Returns the per-element CPU time in nanoseconds spent in the subtree rooted + // in this node. If `processing_times` is not `nullptr`, collects the + // per-element CPU time spent in each node of the subtree. + double TotalProcessingTime(NodeValues* processing_times) + TF_LOCKS_EXCLUDED(mu_); + + // Produces a proto for this node. Does not produce a proto for input nodes. + virtual absl::Status ToProto(ModelProto::Node* node_proto) const; + + // Restores a node from the proto. Does not restore input nodes. + static absl::Status FromProto(ModelProto::Node node_proto, + std::shared_ptr output, + std::shared_ptr* node); + + // Returns a vector of nodes of the subtree rooted in this node. The nodes are + // either in breadth-first search or reverse breadth-first search order + // depending on the `order` argument. The nodes are collected based on the + // results of the `collect_node` predicate: if the predicate returns `false` + // for a given node, then the subtree rooted in this node is excluded. The + // root node itself is not collected. + NodeVector CollectNodes(TraversalOrder order, + bool collect_node(const std::shared_ptr)) const + TF_LOCKS_EXCLUDED(mu_); + + // Downsizes buffer parameters of this node. Returns true if any buffer is + // downsized. + bool TryDownsizeBuffer(); + + // Collects buffer parameters of this node that should be upsized. + void CollectBufferParametersToUpsize( + absl::flat_hash_map& node_parameters); + + // Returns the average size of an element buffered in this node. + double AverageBufferedElementSize() const { + tf_shared_lock l(mu_); + return AverageBufferedElementSizeLocked(); + } + + // Copies node's parameter state value to parameter value if the parameter + // name matches `parameter_name`. + void SyncStateValuesToParameterValues(const std::string& parameter_name); + + void SetEstimatedElementSize(std::optional estimated_element_size) { + mutex_lock l(mu_); + estimated_element_size_ = estimated_element_size; + } + + protected: + // Used for (incrementally) recording metrics. The class is thread-safe. 
+ class Metrics { + public: + explicit Metrics(const string& name) + : bytes_consumed_counter_(metrics::GetTFDataBytesConsumedCounter(name)), + bytes_produced_counter_(metrics::GetTFDataBytesProducedCounter(name)), + num_elements_counter_(metrics::GetTFDataElementsCounter(name)), + recorded_bytes_consumed_(0), + recorded_bytes_produced_(0), + recorded_num_elements_(0) {} + + // Expects the total number of bytes consumed and records the delta since + // last invocation. + void record_bytes_consumed(int64_t total_bytes) { + int64_t delta = + total_bytes - recorded_bytes_consumed_.exchange(total_bytes); + bytes_consumed_counter_->IncrementBy(delta); + } + + // Expects the total number of bytes produced and records the delta since + // last invocation. + void record_bytes_produced(int64_t total_bytes) { + int64_t delta = + total_bytes - recorded_bytes_produced_.exchange(total_bytes); + bytes_produced_counter_->IncrementBy(delta); + } + + // Expects the total number of elements produced and records the delta since + // last invocation. + void record_num_elements(int64_t total_elements) { + int64_t delta = + total_elements - recorded_num_elements_.exchange(total_elements); + num_elements_counter_->IncrementBy(delta); + } + + private: + monitoring::CounterCell* const bytes_consumed_counter_; + monitoring::CounterCell* const bytes_produced_counter_; + monitoring::CounterCell* const num_elements_counter_; + std::atomic recorded_bytes_consumed_; + std::atomic recorded_bytes_produced_; + std::atomic recorded_num_elements_; + }; + + // Computes the exponential moving average of processing time per element. + void UpdateProcessingTimeEma() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (previous_processing_time_ == 0) { + if (num_elements_ > 0) { + processing_time_ema_ = + static_cast(processing_time_) / + static_cast(num_elements_ + buffered_elements_); + } else { + processing_time_ema_ = static_cast(processing_time_); + } + } else { + processing_time_ema_ = + (1.0 - kProcessingTimeEmaWeight) * processing_time_ema_ + + kProcessingTimeEmaWeight * + static_cast(processing_time_ - previous_processing_time_); + } + previous_processing_time_ = processing_time_; + } + + // Returns the number of inputs. + int64_t num_inputs() const TF_SHARED_LOCKS_REQUIRED(mu_) { + int64_t num_inputs = 0; + for (auto& input : inputs_) { + // Inputs for which autotuning is disabled are excluded. + if (input->autotune()) { + ++num_inputs; + } + } + return num_inputs; + } + + // Creates a clone of this node. + virtual std::shared_ptr Clone(std::shared_ptr output) const + TF_SHARED_LOCKS_REQUIRED(mu_) = 0; + + // Returns the average size of an element buffered in this node. + double AverageBufferedElementSizeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_); + + // Returns the sum of per-element output time for the tunable inputs of this + // node. + double OutputTimeForInputs(const NodeValues& output_times) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + // Returns the sum of output time gradient w.r.t. input time for the tunable + // inputs of this node. + double OutputTimeGradientsForInputs(const NodeValues& output_time_gradients) + const TF_SHARED_LOCKS_REQUIRED(mu_); + + // Computes the input time for this node and stores it in `input_times`. + virtual void InputTimeLocked(NodeValues* input_times) const + TF_SHARED_LOCKS_REQUIRED(mu_) = 0; + + // Computes the per-element output time for this node and stores it in + // `output_times`. If `gradients` is not `nullptr`, computes the output time + // gradient w.r.t. 
tunable parameters of the subtree rooted in this node and + // stores it in `gradients`, also computes the output time gradient w.r.t. + // input time and stores it in `output_time_gradients`. + virtual void OutputTimeLocked(const NodeValues& input_times, + ParameterGradients* gradients, + NodeValues* output_times, + NodeValues* output_time_gradients) const + TF_SHARED_LOCKS_REQUIRED(mu_) = 0; + + // Returns the sum of per-element processing time for the inputs of this node + // by adding values for input nodes in `total_processing_times`. Processing + // time for a given input is a weighted combination of a statistic based on + // history of input processing time and the actual time. This is done to + // improve accuracy of processing time estimation for newly created inputs. + // + // Uniform distribution of per-element processing times across different + // inputs is assumed. + double TotalProcessingTimeForInputs(const NodeValues& total_processing_times) + TF_SHARED_LOCKS_REQUIRED(mu_); + + // Returns the per-element processing time spent in this node. + double SelfProcessingTimeLocked() const TF_SHARED_LOCKS_REQUIRED(mu_); + + // Computes the per-element CPU time spent in the subtree rooted in this node + // and stores it in `total_processing_times`. If `processing_times` is not + // `nullptr`, collects the per-element CPU time spent in each node of the + // subtree. + virtual void TotalProcessingTimeLocked(NodeValues* processing_times, + NodeValues* total_processing_times) + TF_SHARED_LOCKS_REQUIRED(mu_) = 0; + + // This is the locked version of the public `CollectNodes`. + NodeVector CollectNodesLocked(TraversalOrder order, + bool collect_node(const std::shared_ptr)) + const TF_SHARED_LOCKS_REQUIRED(mu_); + + // Collects tunable parameters in the subtree rooted in this node assuming + // mutex locked. + ModelParameters CollectTunableParametersLocked() const + TF_SHARED_LOCKS_REQUIRED(mu_); + + // Collect tunable parameters on the nodes which have recorded + // elements. + void CollectTunableParametersHelper(ModelParameters* parameters) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + // Build up debug string for the node and store in the debug strings map. + void DebugStringHelper(absl::flat_hash_map* debug_strings) + const TF_SHARED_LOCKS_REQUIRED(mu_); + + // Copy the node and add the (input, copy) pairs to the NodePairList. + std::shared_ptr SnapshotHelper(std::shared_ptr cloned_output, + NodePairList* node_pairs) const; + + // Compute total buffered bytes for the node and store in the total bytes map. + void TotalBufferedBytesHelper(NodeValues* total_bytes) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + // Compute total maximum buffered bytes for the node and store in the total + // bytes map. + void TotalMaximumBufferedBytesHelper(NodeValues* total_bytes) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + // Compute and return the maximum buffered bytes on the node itself. By + // default non-tunable nodes are assumed not to buffer any bytes, so the + // tunable nodes as subclasses are expected to override this method to ensure + // that the optimization algorithm respects the memory budget. + virtual double MaximumBufferedBytes() const TF_SHARED_LOCKS_REQUIRED(mu_); + + // Restores node from the proto. Note that this is not done recursively, i.e. + // input nodes are not restored. + static absl::Status FromProtoHelper(ModelProto::Node node_proto, + std::shared_ptr node); + + // Stores the time passed to the last call to `Node::record_start()` on the + // current thread. 
+ // + // NOTE: This thread-local variable is shared between all instances of `Node` + // on which the same thread calls `record_start()` or `record_stop()`. It + // relies on the invariant that at most one `Node` can be "active" on a + // particular thread at any time. Therefore if `n->record_start()` is called + // on thread `t`, then `n->record_stop()` must be called before another call + // to `Node::record_start()` (for any node). + static thread_local int64_t work_start_; // Will be initialized to zero. + + mutable mutex mu_; + const int64_t id_; + const string name_; + + // Indicates whether the subtree rooted in this node should be included in + // autotuning. In particular, if this is `false`, then the subtree is excluded + // from computation of output time and processing time. + std::atomic autotune_; + std::atomic buffered_bytes_; + std::atomic peak_buffered_bytes_; + std::atomic buffered_elements_; + std::atomic buffered_elements_low_; + std::atomic buffered_elements_high_; + std::atomic bytes_consumed_; + std::atomic bytes_produced_; + std::atomic num_elements_; + std::atomic processing_time_; + std::atomic record_metrics_; + Metrics metrics_; + absl::flat_hash_map> parameters_ + TF_GUARDED_BY(mu_); + + // Statistic of inputs processing time history. + double input_processing_time_sum_ = 0.0L; + int64_t input_processing_time_count_ = 0; + + // Holds the previous processing time and the per element processing time + // exponential moving average. + int64_t previous_processing_time_ TF_GUARDED_BY(mu_) = 0; + double processing_time_ema_ TF_GUARDED_BY(mu_) = 0.0; + + // Inputs of this node. These can represent an iterator created from the input + // dataset but also other input iterators (e.g. created by the user-defined + // functions of `flat_map` or `interleave`). + std::list> inputs_ TF_GUARDED_BY(mu_); + + // The reference to the output node is not owned so that deletion of a + // node results in recursive deletion of the subtree rooted in the node. + Node* const output_; + std::weak_ptr output_weak_ptr_; + std::optional estimated_element_size_ TF_GUARDED_BY(mu_) = + std::nullopt; +}; + +// InterleaveMany is used to model datasets whose inputs are used to create +// datasets whose elements are then interleaved. +std::shared_ptr MakeInterleaveManyNode( + Node::Args args, std::vector> parameters); + +// AsyncInterleaveMany nodes are the asynchronous version of InterleaveMany +// nodes. +std::shared_ptr MakeAsyncInterleaveManyNode( + Node::Args args, std::vector> parameters); + +// KnownMany nodes model datasets that synchronously consume known number of +// input element per output element. +std::shared_ptr MakeKnownRatioNode(Node::Args args, double ratio); + +// AsyncKnownRatio nodes are the asynchronous version of KnownRate nodes. +std::shared_ptr MakeAsyncKnownRatioNode( + Node::Args args, double ratio, double memory_ratio, + std::vector> parameters, + bool is_legacy_prefetch_autotuned = false); + +// Makes an AsyncKnownRatioNode. If `estimated_element_size` is provided, +// it will be used during the estimation of maximum buffered bytes. +std::shared_ptr MakeAsyncKnownRatioNode( + Node::Args args, double ratio, + std::vector> parameters, + bool is_legacy_prefetch_autotuned = false, + std::optional estimated_element_size = std::nullopt); + +// Source nodes represent data sources. +std::shared_ptr MakeSourceNode(Node::Args args); + +// UnknownMany nodes represent datasets that synchronously consume an +// unknown number of input elements per output. 
+// +// Unlike KnownRatio nodes which expect the ratio between inputs and outputs is +// specified as a parameter, UnknownRatio estimates the ratio empirically. +std::shared_ptr MakeUnknownRatioNode(Node::Args args); + +// AsyncUnknownRatio nodes are the asynchronous version of unknown ratio nodes. +std::shared_ptr MakeAsyncUnknownRatioNode( + Node::Args args, std::vector> parameters); + +// Unknown nodes represent datasets for which we do not have a model. It acts +// as pass-through between inputs and output. +std::shared_ptr MakeUnknownNode(Node::Args args); + +// Abstract representation of a TensorFlow input pipeline that can be used +// for collecting runtime information and optimizing performance. It collects +// runtime information about execution of the input pipeline that is used to +// create a performance model, which is in turn used to identify optimal values +// of tunable parameters. +// +// Developers of tf.data transformations are not expected to interact with this +// class directly. Boiler plate code for creating the abstract representation of +// the input pipeline and collecting runtime information has been added to the +// implementation of `DatasetBase` and `DatasetBaseIterator` respectively. +// +// The order of locks acquired is SharedState lock, Model lock, Node lock. +// SharedState lock is acquired first because it shares the same lock as the +// dataset iterator that contains it. +class Model { + public: + using OptimizationParams = ModelProto::OptimizationParams; + using ModelParameters = Node::ModelParameters; + using NodeValues = Node::NodeValues; + using ParameterGradients = Node::ParameterGradients; + + explicit Model(std::optional dataset_name); + explicit Model() : Model(std::nullopt) {} + ~Model(); + + // Returns a pointer to the model's output node. + std::shared_ptr output() const { + mutex_lock l(mu_); + return output_; + } + + // Set the experiment that this job is part of. + void AddExperiment(const std::string& experiment) { + experiments_.insert(experiment); + } + + // Adds a node with the given name and given parent. + void AddNode(Node::Factory factory, const string& name, + std::shared_ptr parent, std::shared_ptr* out_node) + TF_LOCKS_EXCLUDED(mu_); + + // Returns a human-readable string representation of the model. This method + // can be invoked automatically by monitoring gauges and to avoid frequent + // recomputation, the implementation caches the result. + std::string DebugString(); + + // Uses the given algorithm and resource budgets to periodically perform the + // autotuning optimization. + // + // `cpu_budget_func` can be used to provide the optimizer with up-to-date + // values in cases where CPUs budgets may be changed by the runtime + // dynamically. + // + // `ram_budget_func` is similar to `cpu_budget_func`. This lambda takes a + // parameter that is the total number of bytes currently buffered by the + // model. + // + // To terminate the execution of the optimization loop, the caller needs to + // invoke `cancellation_mgr->StartCancel()`. + absl::Status OptimizeLoop(AutotuneAlgorithm algorithm, + std::function cpu_budget_func, + double ram_budget_share, + std::optional fixed_ram_budget, + RamBudgetManager& ram_budget_manager, + CancellationManager* cancellation_manager); + + // Uses the given algorithm and resource budgets to perform the autotuning + // optimization. 
+  void Optimize(AutotuneAlgorithm algorithm,
+                std::function cpu_budget_func,
+                double ram_budget_share,
+                std::optional fixed_ram_budget,
+                double model_input_time, RamBudgetManager& ram_budget_manager,
+                CancellationManager* cancellation_manager);
+
+  // Optimizes buffers in the pipeline rooted at `snapshot`. It downsizes
+  // buffers that are too large and upsizes buffers that are too small while
+  // respecting the ram budget. If any node is downsized or upsized, the
+  // watermarks of all nodes are reset to the buffered elements.
+  void OptimizeBuffers(std::shared_ptr snapshot, int64_t ram_budget);
+
+  // Collects the output time and if `gradients` is not `nullptr`, the output
+  // time gradient w.r.t. tunable parameters of the subtree rooted in the given
+  // node.
+  double OutputTime(std::shared_ptr node, double model_input_time,
+                    ParameterGradients* gradients);
+
+  // Removes the given node.
+  void RemoveNode(std::shared_ptr node) TF_LOCKS_EXCLUDED(mu_);
+
+  // Produces a proto for this model.
+  absl::Status ToProto(ModelProto* model_proto);
+
+  // Restores a model from the proto.
+  static absl::Status FromProto(ModelProto model_proto,
+                                std::unique_ptr* model);
+
+  // Saves this model with a given snapshot and its optimization parameters to a
+  // file. Note that the file directory must already exist.
+  absl::Status Save(const string& fname, std::shared_ptr snapshot,
+                    const OptimizationParams& optimization_params);
+
+  // Loads a model and its optimization parameters from a file with the given
+  // name.
+  static absl::Status Load(const string& fname, std::unique_ptr* model,
+                           OptimizationParams* optimization_params);
+
+  // Records gap time between consecutive `GetNext()` calls.
+  void RecordIteratorGapTime(uint64_t duration_usec);
+
+  // Computes the target time in nsecs to use for `STAGE_BASED` autotune
+  // algorithm. Returns 0 if there are not sufficient recorded iterator
+  // gap times to produce a good estimate.
+  double ComputeTargetTimeNsec();
+
+  // Computes the target time in nsecs to use for estimating input bottlenecks.
+  // Returns 0 if there are not sufficient recorded iterator gap times to
+  // produce a good estimate.
+  double ComputeExperimentalTargetTimeNsec();
+
+  // Returns the time in nanoseconds it takes the pipeline to produce an
+  // element, according to the latest model snapshot obtained from optimization.
+  // Returns 0 if the model snapshot is empty or null. This may be caused by not
+  // having executed an optimization round before.
+  double ComputeSnapshotProcessingTimeNsec() const;
+
+ private:
+  // Determines whether optimization should stop given total processing time,
+  // estimated output time, and estimated number of buffered bytes.
+  using StopPredicate =
+      std::function;
+
+  static constexpr int64_t kOptimizationPeriodMinMs = 10;
+  static constexpr int64_t kOptimizationPeriodMaxMs =
+      60 * EnvTime::kSecondsToMillis;
+
+  // Collects tunable parameters in the tree rooted in the given node, returning
+  // a vector which contains pairs of node names and tunable parameters.
+  ModelParameters CollectTunableParameters(std::shared_ptr node);
+
+  // Copy parameter state values to parameter values if necessary. For some
+  // nodes, the parameter state values are not tuned by Autotune and hence the
+  // parameter values can be stale. We do not sync all parameters because it may
+  // increase mutex contention with `GetNext()`.
+  void MaybeSyncStateValuesToValues(std::shared_ptr snapshot);
+
+  // Downsizes buffers that are too large for all nodes rooted at `snapshot`.
+  // Returns true if any buffer is downsized.
+  bool DownsizeBuffers(std::shared_ptr snapshot);
+
+  // Upsizes buffers that are too small for all nodes rooted at `snapshot` while
+  // respecting the ram budget. Returns true if any buffer is upsized.
+  bool UpsizeBuffers(std::shared_ptr snapshot, int64_t ram_budget);
+
+  // Reset buffer watermarks of all asynchronous nodes to their buffered
+  // elements.
+  void ResetBufferWatermarks();
+
+  // Collects buffer parameters of all nodes in the model that should be
+  // upsized.
+  absl::flat_hash_map CollectBufferParametersToUpsize(
+      std::shared_ptr snapshot);
+
+  // Flushes metrics recorded by the model.
+  void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
+
+  // This optimization algorithm starts by setting all tunable parallelism
+  // parameters to the minimum value. It then improves current parameters by
+  // making a step in the direction opposite to the gradient of `OutputTime` and
+  // projecting resulting values on the feasible intervals. Improvement step is
+  // repeated until either the output time improvement is smaller than threshold
+  // value or the output time is less than the processing time needed to produce
+  // an element divided by CPU budget.
+  void OptimizeGradientDescent(std::shared_ptr snapshot,
+                               const OptimizationParams& optimization_params,
+                               CancellationManager* cancellation_manager);
+
+  // Helper method for implementing hill-climb optimization that can be
+  // parametrized by a predicate to use for stopping the optimization.
+  void OptimizeHillClimbHelper(std::shared_ptr snapshot,
+                               const OptimizationParams& optimization_params,
+                               CancellationManager* cancellation_manager,
+                               int64_t ram_budget,
+                               RamBudgetManager& ram_budget_manager,
+                               StopPredicate should_stop);
+
+  // This optimization algorithm starts by setting all tunable parallelism
+  // parameters to the minimum value. It then repeatedly identifies the
+  // parameter whose increase in parallelism decreases the output time the most.
+  // This process is repeated until all parameters reach their maximum values or
+  // the projected output time is less than or equal to the processing time
+  // needed to produce an element divided by CPU budget.
+  void OptimizeHillClimb(std::shared_ptr snapshot,
+                         const OptimizationParams& optimization_params,
+                         CancellationManager* cancellation_manager,
+                         RamBudgetManager& ram_budget_manager);
+
+  // This optimization behaves similarly to the hill climb optimization but uses
+  // a relaxed stopping condition, allowing the optimization to oversubscribe
+  // CPU.
+  void OptimizeMaxParallelism(std::shared_ptr snapshot,
+                              const OptimizationParams& optimization_params,
+                              CancellationManager* cancellation_manager,
+                              RamBudgetManager& ram_budget_manager);
+
+  // This optimization starts by setting all tunable parallelism parameters to
+  // their minimum values. It then repeatedly increases the parallelism
+  // parameter of the longest stage by 1 until either the longest stage is
+  // faster than the target time or the memory or CPU budget is fully utilized.
+  // TODO(b/226910071): The second part of this algorithm optimizes the buffer
+  // sizes of parallel ops.
+ void OptimizeStageBased(std::shared_ptr snapshot, + const OptimizationParams& optimization_params, + CancellationManager* cancellation_manager, + RamBudgetManager& ram_budget_manager); + + // This is the first part of the stage-based optimization that optimizes + // tunable parallelism parameters for async interleave many nodes only. We + // separately optimize async interleave many nodes more aggressively because + // the variance of IO is difficult to predict. + void OptimizeStageBasedAsyncInterleaveManyNodes( + std::shared_ptr snapshot, + const OptimizationParams& optimization_params, + CancellationManager* cancellation_manager, + RamBudgetManager& ram_budget_manager); + + // This is the second part of the stage-based optimization that optimizes + // tunable parallelism parameters for all nodes other than async interleave + // many nodes. + void OptimizeStageBasedNonAsyncInterleaveManyNodes( + std::shared_ptr snapshot, double target_time_nsec, + const OptimizationParams& optimization_params, + CancellationManager* cancellation_manager, + RamBudgetManager& ram_budget_manager); + + // Determines if we should stop the gradient descent optimization iterations + // based on number of increasable parameters, CPU budget, RAM budget and + // current resource usage. + bool ShouldStop(int64_t cpu_budget, int64_t ram_budget, + const ModelParameters& parameters, + const ModelParameters& parallelism_parameters, + const ModelParameters& buffer_size_parameters, + std::shared_ptr snapshot, bool* cpu_budget_reached); + + // Collects the processing time for the given node. + double TotalProcessingTime(std::shared_ptr node); + + // Collects the total number of bytes buffered in all nodes in the subtree + // rooted in the given node for which autotuning is enabled. + double TotalBufferedBytes(std::shared_ptr node); + + // Collects the total buffer limit of all nodes in the subtree rooted in the + // given node for which autotuning is enabled. This number represents the + // amount of memory that would be used by the subtree nodes if all of their + // buffers were full. + double TotalMaximumBufferedBytes(std::shared_ptr node); + + std::optional dataset_name_; + // Used for coordination between different input pipeline threads. Exclusive + // access is required only when adding or removing nodes. Concurrent access to + // existing nodes is protected by a node mutex. + mutable mutex mu_; + // Used for coordinating the optimization loop and model modifications. + condition_variable optimize_cond_var_; + int64_t id_counter_ TF_GUARDED_BY(mu_) = 1; + std::shared_ptr output_ TF_GUARDED_BY(mu_) = nullptr; + + // Determines the time the optimization loop should wait between + // running optimizations. + int64_t optimization_period_ms_ TF_GUARDED_BY(mu_); + + // Gauge cell that can be used to collect the state of the model. + monitoring::GaugeCell>* model_gauge_cell_ = + nullptr; + // Used to synchronize metrics collection attempts against the model's + // destruction. + struct GuardedBool { + explicit GuardedBool(bool val) : val(val) {} + bool val TF_GUARDED_BY(mu); + mutex mu; + }; + std::shared_ptr safe_to_collect_metrics_; + + // Time use for rate limiting the recomputation of human-readable string + // representation of the model. + absl::Time cache_until_ = absl::InfinitePast(); + // Cached result of the `DebugString()` invocation used to implement rate + // limiting of the computation. + std::string cached_debug_string_ = ""; + // Used to coordinate gap time updates between different threads. 
Gap time is + // the time between the completion of the previous `GetNext()` and the start + // of the next `GetNext()`. + mutable mutex gap_mu_; + // Stores the latest gap times between consecutive `GetNext()`. + std::deque gap_times_usec_ TF_GUARDED_BY(gap_mu_); + // The experiment that this job is part of. + absl::flat_hash_set experiments_; + // Stores the optimization snapshot of the Model. + std::shared_ptr snapshot_ TF_GUARDED_BY(mu_); + // Stores the optimization parameters used by autotune. + OptimizationParams optimization_params_ TF_GUARDED_BY(mu_); + // Stores the model id in the string format + std::string model_id_; +}; + +// Class to compute timing information for a model. +class ModelTiming { + public: + struct NodeTiming { + // Pipeline ratio is the number of elements this node needs to produce in + // order to produce an element at the root of the pipeline. + double pipeline_ratio = 0.0; + // The self time it takes this node to produce the elements needed to + // produce one element of the root of the pipeline. + double self_time_nsec = 0.0; + // The total time it takes this node and the subtree rooted at this node to + // produce the elements needed to produce one element at the root of the + // pipeline. + double total_time_nsec = 0.0; + }; + + explicit ModelTiming(std::shared_ptr root); + + // Returns the timing data for `node`. + const NodeTiming* GetTiming(const Node* node) const; + + // Returns the root nodes of all stages. + std::vector> GetStageRoots() const; + + // Returns all the nodes of a stage given the stage root. + std::vector> GetStageNodes( + std::shared_ptr stage_root) const; + + // Computes the total time for a node. + void ComputeNodeTotalTime(const Node& node); + + private: + // Computes the pipeline ratios of all nodes. + void ComputePipelineRatios(const Node::NodeVector& bfs_nodes); + + // Computes the total time for all nodes. The `reverse_bfs_nodes` are assumed + // to be a vector of model nodes in reversed BFS manner. + void ComputeTotalTimes(const Node::NodeVector& reverse_bfs_nodes); + + // Computes the first input total time of an interleave node. + double ComputeInterleaveManyFirstInputTotalTime(const Node& node); + + // Computes the total time of a node of any type other than async interleave. + void ComputeNonAsyncInterleaveManyTotalTime(const Node& node); + + // Computes the total time of an async interleave node. + void ComputeAsyncInterleaveManyTotalTime(const Node& node); + // Computes the interleaved inputs' total time of an async interleave node. + double ComputeAsyncInterleaveManyInterleavedInputsTotalTime(const Node& node); + + // Returns a vector of all nodes in the model. The nodes are either in + // breadth-first search or reverse breadth-first search order depending on the + // `order` argument. The nodes are collected based on the results of the + // `collect_node` predicate: if the predicate returns `false` for a given + // node, then the subtree rooted in this node is excluded. The root node + // itself is not collected. + Node::NodeVector CollectNodes( + std::shared_ptr root, TraversalOrder order, + bool collect_node(const std::shared_ptr)) const; + + // Stores a pointer to the root of a model. + std::shared_ptr root_; + + // Holds a mapping from node to its timing node. 
+ absl::flat_hash_map timing_nodes_; +}; + +} // namespace model +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/node_def_builder.h b/third_party/tflite-hdrs/tensorflow/core/framework/node_def_builder.h new file mode 100644 index 00000000..47b14f18 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/node_def_builder.h @@ -0,0 +1,198 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_ +#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_ + +#include +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_node_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +class NodeDefBuilder; +typedef std::function + FakeInputFunctor; + +// This is a helper for creating a NodeDef. Automatically sets attrs +// that can be inferred from the inputs, and uses default values +// (where they exist) for unspecified attrs. Example usage: +// +// NodeDef node_def; +// Status status = NodeDefBuilder(node_name, op_name) +// .Input(...) +// .Attr(...) +// .Finalize(&node_def); +// if (!status.ok()) return status; +// // Use node_def here. +class NodeDefBuilder { + public: + // To specify an output to be consumed by one of the Input() methods below. + struct NodeOut { + NodeOut(absl::string_view n, int i, DataType dt); + NodeOut(); // uninitialized, call Reset() before use. + void Reset(absl::string_view n, int i, DataType dt); + string node; + int index; + DataType data_type; + }; + + // Specify the name and the Op (either via an OpDef or the name of + // the Op plus a registry) for the NodeDef. Other fields are + // specified by calling the methods below. + // REQUIRES: The OpDef must satisfy ValidateOpDef(). + NodeDefBuilder(absl::string_view name, absl::string_view op_name, + const OpRegistryInterface* op_registry = OpRegistry::Global(), + const NodeDebugInfo* debug = nullptr); + NodeDefBuilder(absl::string_view name, absl::string_view op_name, + const NodeDebugInfo& debug); + // REQUIRES: in addition, *op_def must outlive *this. + NodeDefBuilder(absl::string_view name, const OpDef* op_def); + + // You must call one Input() function per input_arg in the Op, + // *and in the same order as the input_args appear in the OpDef.* + + // For inputs that take a single tensor. 
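A minimal sketch of the per-input_arg calling pattern described above, using the single-tensor Input() overload declared next; the op name "MyMatMul" and its attr names are hypothetical placeholders, not ops defined by this patch:

    NodeDef node_def;
    absl::Status status = NodeDefBuilder("matmul_node", "MyMatMul")
                              .Input("a", 0, DT_FLOAT)   // first input_arg
                              .Input("b", 0, DT_FLOAT)   // second input_arg
                              .Attr("transpose_a", false)
                              .Device("/device:CPU:0")
                              .Finalize(&node_def);
    if (!status.ok()) {
      // Handle the error; node_def is not usable here.
    }

As the comment above notes, the Input() calls must appear in the same order as the op's input_args; Finalize() reports any mismatch via the returned status.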
+ NodeDefBuilder& Input(absl::string_view src_node, int src_index, DataType dt); + NodeDefBuilder& Input(const NodeOut& src); + + // For inputs that take a list of tensors. + NodeDefBuilder& Input(absl::Span src_list); + + // To create inputs in tests, see fake_input.h. + NodeDefBuilder& Input(FakeInputFunctor fake_input); + + // Specify that this node must only run after src_node. + NodeDefBuilder& ControlInput(absl::string_view src_node); + + // Constrains what devices this node may be scheduled on. + NodeDefBuilder& Device(absl::string_view device_spec); + + // Sets the attr, if not already set. If already set with a different + // value, an error will be returned from Finalize(). + NodeDefBuilder& Attr(absl::string_view name, const AttrValue& value); + NodeDefBuilder& Attr(absl::string_view name, AttrValue&& value); + NodeDefBuilder& Attr(absl::string_view name, absl::string_view value); + NodeDefBuilder& Attr(absl::string_view name, const char* value); + NodeDefBuilder& Attr(absl::string_view name, int32_t value); + NodeDefBuilder& Attr(absl::string_view name, int64_t value); + NodeDefBuilder& Attr(absl::string_view name, float value); + NodeDefBuilder& Attr(absl::string_view name, double value); + NodeDefBuilder& Attr(absl::string_view name, bool value); + NodeDefBuilder& Attr(absl::string_view name, DataType value); + NodeDefBuilder& Attr(absl::string_view name, const PartialTensorShape& value); + NodeDefBuilder& Attr(absl::string_view name, const Tensor& value); + NodeDefBuilder& Attr(absl::string_view name, const TensorProto& value); + NodeDefBuilder& Attr(absl::string_view name, const NameAttrList& value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, const std::vector& value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, absl::Span value); + NodeDefBuilder& Attr(absl::string_view name, + absl::Span value); + + template + NodeDefBuilder& Attr(absl::string_view name, std::initializer_list value) { + return Attr(name, gtl::ArraySlice(value)); + } + + // Finish building the NodeDef, returning any errors or setting + // *node_def if none. + // If `consume` is true, the builder state will be moved into `node_def`, + // and the builder will be left in an undefined state. + // WARNING: Not all problems are detected! The resulting NodeDef may + // not be valid! Call ValidateNodeDef() from node_def_utils to be sure. + absl::Status Finalize(NodeDef* node_def, bool consume = false); + + // Accessors for the values set in the constructor. + const string& node_name() const { return node_def_.name(); } + const OpDef& op_def() const { return *op_def_; } + + private: + // Called in the constructors. + void Initialize(); + + // Get the current ArgDef and advance to the next one. 
Returns nullptr + // if no more inputs are available. + const OpDef::ArgDef* NextArgDef(); + + // Returns true if there is still an input_arg available in *op_def_, + // otherwise adds to error_ and returns false. + bool NextArgAvailable(); + + // These do the main work of the Input() methods. + void SingleInput(const OpDef::ArgDef* input_arg, absl::string_view src_node, + int src_index, DataType dt); + void ListInput(const OpDef::ArgDef* input_arg, + absl::Span src_list); + + // Add "src_node:src_index" to the list of inputs in the node_def_. + void AddInput(absl::string_view src_node, int src_index); + + // Generate an error if you can't pass dt when expected is expected. + void VerifyInputType(const OpDef::ArgDef* input_arg, DataType expected, + DataType dt); + + // If input_arg->is_ref() is true, generate an error if dt is not a ref. + void VerifyInputRef(const OpDef::ArgDef* input_arg, DataType dt); + + // Makes dt a ref type if that is what the input_arg specifies. + DataType MaybeAddRef(const OpDef::ArgDef* input_arg, DataType dt) { + return input_arg->is_ref() ? MakeRefType(dt) : dt; + } + + // Returns true if an attr named `name` is already present in the node_def_. + // If such an attr is already present and `value` is not equal to the present + // value, an error is generated. + bool AttrValueAlreadyPresent(absl::string_view name, const AttrValue& value); + + const OpDef* op_def_; + NodeDef node_def_; + int inputs_specified_; + std::vector control_inputs_; + std::vector errors_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/node_def_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/node_def_util.h new file mode 100644 index 00000000..2b82c596 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/node_def_util.h @@ -0,0 +1,462 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/hash.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { + +class AttrSlice; +// We forward declare protos so that kernels don't need to depend on them +class OpDef; +class AttrValue; +class NameAttrList; +class TensorProto; +class TensorShapeProto; + +// Name of the attribute used to encode node colocation constraints. +// +// Nodes can be co-located on the same device. Desire for explicit co-location +// is described by list(string) attribute containing the name of colocation +// groups. +extern const char* const kColocationAttrName; + +// String prefix applied to the operation name for colocation constraints. +extern const char* const kColocationGroupPrefix; + +// Constants for host CPU staging op for TPUExecute. +extern const char* const kTpuExecuteStagingOp; +extern const char* const kTpuExecuteStagingNodeName; + +// Produce a human-readable version of a Node or NodeDef that is more concise +// than a text-format proto. +// +// The parameter `max_inputs_in_summary` specifies how many inputs at most to +// serialize in the output (in order not to get a string which is overly large). +// The value `-1` specifies that all inputs will be shown. +std::string SummarizeNodeDef(const NodeDef& node_def, + int max_inputs_in_summary = -1); +std::string SummarizeAttrs(const NodeDef& node_def); +std::string SummarizeAttrsHelper(AttrSlice attrs, absl::string_view device); + +// Produces a formatted string pattern from the node which can uniquely identify +// this node upstream to produce an informative error message. The pattern +// followed is: {{node }} +std::string FormatNodeDefForError(const NodeDef& node_def); +std::string FormatNodeDefForError( + absl::string_view node_name, bool has_experimental_debug_info, + const NodeDef_ExperimentalDebugInfo& experimental_debug_info); + +typedef protobuf::Map AttrValueMap; + +// Adds an attr with name and value to *node_def. +// The type of the attr is based on the type of value. 
+void AddNodeAttr(absl::string_view name, const AttrValue& value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, AttrValue&& value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::string_view value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, const char* value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, int32_t value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, int64_t value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, float value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, double value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, bool value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, DataType value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, const PartialTensorShape& value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, const Tensor& value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, const TensorProto& value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, const NameAttrList& value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, + absl::Span value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, const std::vector& value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, + absl::Span value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, + absl::Span value, NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); +void AddNodeAttr(absl::string_view name, absl::Span value, + NodeDef* node_def); + +// Version to workaround C++'s "perfect" forwarding not being able to +// forward {...} initialization. +template +void AddNodeAttr(absl::string_view name, std::initializer_list value, + NodeDef* node_def) { + AddNodeAttr(name, gtl::ArraySlice(value), node_def); +} + +// Adds an attr to an attr value map. +void AddAttr(absl::string_view name, const AttrValue& value, AttrValueMap* map); +void AddAttr(absl::string_view name, bool value, AttrValueMap* map); + +class AttrSlice { + public: + AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit) + + AttrSlice(); // Empty + explicit AttrSlice(const AttrValueMap* a); + + int size() const { return attrs()->size(); } + + // Returns the attr with attr_name if found. Otherwise, returns + // nullptr. + const AttrValue* Find(absl::string_view attr_name) const; + const AttrValue* FindByString(const std::string& attr_name) const; + + // Returns the attr_value for attr_name if found. Otherwise, returns a + // NotFound status. + absl::Status Find(absl::string_view attr_name, + const AttrValue** attr_value) const; + absl::Status FindByString(const std::string& attr_name, + const AttrValue** attr_value) const; + + // Helper class to avoid allocations in EqualAttrs. 
+ // TODO(irving): Will go away once NodeInfo is used. + struct Scratch { + std::string a; + std::string b; + }; + + // Check if all attrs and attr values match. Does not take defaults into + // account. + // + // TODO(irving): There is a bug in this routine inherited from its + // OptimizerCSE::EqualAttrs predecessor. The same tensor attr can be + // represented in more than one way as an AttrValue, since TensorProto is + // not 1-1. This bug will go away once I replace everything with NodeInfo, + // which stores a Tensor object directly. The Scratch object will also go + // away. + bool EqualAttrs(AttrSlice other, Scratch* scratch) const; + + // If this AttrSlice has an attached NodeDef, summarize it. This is for + // error messages only: we intentionally do not provide direct access to the + // NodeDef, since it is not always there. + std::string SummarizeNode() const; + + // Iteration over all attrs + AttrValueMap::const_iterator begin() const { return attrs()->begin(); } + AttrValueMap::const_iterator end() const { return attrs()->end(); } + + std::string DebugString() const; + + private: + const AttrValueMap* attrs() const { + return ndef_ != nullptr ? &ndef_->attr() : attrs_; + } + + absl::Status CheckFind(absl::string_view attr_name, + const AttrValue* attr_value) const; + + const NodeDef* ndef_; + const AttrValueMap* attrs_; +}; + +// Return true if the attr with the name attr_name is defined in node_def. +bool HasNodeAttr(const NodeDef& node_def, absl::string_view attr_name); + +// Look up the attr with name attr_name and set *value to its value. If no +// attr with attr_name is found in node_def, or the attr does not have +// a matching type, a non-ok status will be returned. +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::string* value); // type: "string" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + tstring* value); // type: "tstring" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + int64_t* value); // type: "int" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + int32* value); // type: "int" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + float* value); // type: "float" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + bool* value); // type: "bool" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + DataType* value); // type: "type" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + TensorShapeProto* value); // type: "shape" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + TensorShape* value); // type: "shape" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + PartialTensorShape* value); // type: "shape" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + Tensor* value); // type: "tensor" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(string)" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(tstring)" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(int)" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(int)" +absl::Status GetNodeAttr(const AttrSlice& attrs, 
absl::string_view attr_name, + std::vector* value); // type "list(float)" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(bool)" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(type)" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + DataTypeVector* value); // type "list(type)" +absl::Status GetNodeAttr( + const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(shape)" +absl::Status GetNodeAttr( + const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(shape)" +absl::Status GetNodeAttr( + const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type "list(shape)" +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(tensor)" + +template +StatusOr GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) { + T val; + TF_RETURN_IF_ERROR(GetNodeAttr(ndef, attr_name, &val)); + return val; +} + +// This version avoids copying the TensorProto. +// REQUIRES: Must not use *value beyond the lifetime of node_def. +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + const TensorProto** value); // type: "tensor" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + const TensorProto** value); // type: "tensor" + +// This version avoids copying the NameAttrList. +// REQUIRES: Must not use *value beyond the lifetime of node_def. +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + const NameAttrList** value); // type: "func" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + const NameAttrList** value); // type: "func" + +// These versions copies the NameAttrList(s). +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + NameAttrList* value); // type: "func" +absl::Status GetNodeAttr( + const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(func)" + +// Look up the attr with name attr_name and set *value to its value. If no +// attr with attr_name is found in node_def, or the attr does not have +// a matching type, false is returned. 
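A small illustration of how the status-returning and bool-returning lookups compare, assuming a `NodeDef node_def` in scope; the attr names "N" and "shared_name" are hypothetical examples, not attrs defined by this patch:

    int64_t n;
    // Status form: a non-ok status is returned for a missing or mismatched attr.
    absl::Status s = GetNodeAttr(AttrSlice(node_def), "N", &n);

    std::string shared_name;
    // Try form: returns false instead of an error status.
    if (!TryGetNodeAttr(AttrSlice(node_def), "shared_name", &shared_name)) {
      shared_name = "";  // fall back to a default
    }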
+bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::string* value); // type: "string" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + int64_t* value); // type: "int" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "int" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + int32* value); // type: "int" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + float* value); // type: "float" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + bool* value); // type: "bool" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + DataType* value); // type: "type" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + TensorShape* value); // type: "shape" + +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(string)" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(tstring)" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(int)" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(float)" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(bool)" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(type)" +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector value); // type: "shape" + +// Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute +// values. +bool TryGetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(string)" +bool TryGetNodeAttr( + const AttrSlice& attrs, absl::string_view attr_name, + std::vector* value); // type: "list(shape)" + +// Look up the attr with name attr_name and return a reference to its value. +// If no attr with attr_name is found in node_def, or the attr does not have +// a matching type, a reference to an empty string is returned. +// REQUIRES: Must not use the returned value beyond the lifetime of node_def. +const std::string& GetNodeAttrString(const AttrSlice& attrs, + absl::string_view attr_name); + +// Specialization to parse an attribute directly into a Padding enum. +absl::Status GetNodeAttr(const AttrSlice& attrs, absl::string_view attr_name, + Padding* value); + +// Computes the input type for a specific node input. +// REQUIRES: ValidateOpDef(op_def).ok() +absl::Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, + int input_port, DataType* input_type); +// Computes the input types for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +absl::Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs); +// Computes the output type for a specific node output. +// REQUIRES: ValidateOpDef(op_def).ok() +absl::Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, + int output_port, DataType* output_type); +// Computes the output types for a specific node. 
+// REQUIRES: ValidateOpDef(op_def).ok() +absl::Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* outputs); +absl::Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, + DataTypeVector* outputs); + +// Computes the input and output types for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +absl::Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, DataTypeVector* outputs); +// Computes the number of outputs for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +absl::Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, + int* num_outputs); + +// Map a node/op's input/output port_id to arg_id. +// +// The port_id refers to the n-th tensor of the node, while the arg_id refers to +// the n-th arg of the op. These two can be different if an op's arg is a list +// of tensors. +// +// We return -1 for any invalid port_id (i.e., no corresponding arg_id). +int OpPortIdToArgId(const NodeDef& node, + const protobuf::RepeatedPtrField& args, + int port_id); + +// Validates that the NodeDef: +// * Defines all expected attrs from the OpDef. +// * All attrs satisfies constraints from the OpDef. +// * Has a signature matching SignatureForNode(). +// etc. +absl::Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); + +// Computes the mapping from input/output argument name to the +// corresponding input/output index range. For example, +// input "foo" corresponds to input indices +// [ (*inputs)["foo"].first, (*inputs)["foo"].second ). +// NOTE(mrry): To reduce allocations when the map is used and save +// space, the returned `NameRangeMap` objects borrow the input/output +// argument names from `op_def`. The `op_def` must outlive the +// returned `NameRangeMap` objects. +typedef gtl::FlatMap, + hash> + NameRangeMap; +absl::Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); +// Adds default values to *node_def for unspecified attrs from op_def. +void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def); + +// Remove attributes from node_def when the value is the default from the +// op_def. +void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def); + +// Validates the syntax of a NodeDef provided externally. +// +// The following is an EBNF-style syntax for NodeDef objects. Note that +// Node objects are actually specified as tensorflow::NodeDef protocol buffers, +// which contain many other fields that are not (currently) validated. +// +// Node = NodeName, Inputs +// Inputs = ( DataInput * ), ( ControlInput * ) +// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? +// ControlInput = "^", NodeName +// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * +absl::Status ValidateExternalNodeDefSyntax(const NodeDef& node_def); + +// Returns "status" with formatted NodeDef attached as additional text +// in the error message. If 'allow_multiple_formatted_node' is false and there +// is already a formatted NodeDef present in 'status', we simply attach the name +// of the NodeDef instead of the formatted string. +absl::Status AttachDef(const absl::Status& status, const NodeDef& node_def, + bool allow_multiple_formatted_node = false); +// Appends the given prefix and suffix to the original node name in order to +// make the name unique. If it's an "Enter" node and uniquify_frame_name is +// true, use the same way to reset attribute "frame_name". 
+absl::Status AddPrefixAndSuffixToNode(absl::string_view prefix, + absl::string_view suffix, + NodeDef* node_def, + bool uniquify_frame_name = true); + +// Appends the given prefix to the colocation group name if the name exists +// in `to_match`. +absl::Status MaybeAddPrefixToColocationConstraints( + const std::unordered_set& match, absl::string_view prefix, + NodeDef* node_def); + +// Updates the colocation constraint name with the one provided in the map (if +// it exists in the map) for node_def. +absl::Status MaybeUpdateColocationConstraintsWithMap( + const std::map& node_name_map, + NodeDef* node_def); + +// For replacing a existing node with a NoOp, change the op and clear full type +// information (since a NoOp has no output). Note that (duplicate control or +// all) inputs, (regular, output or all) attributes and output properperties are +// NOT cleared (and should be cleared if appropriate elsewhere). +void ChangeToNoOp(NodeDef* node_def); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/node_properties.h b/third_party/tflite-hdrs/tensorflow/core/framework/node_properties.h new file mode 100644 index 00000000..91c495bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/node_properties.h @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class OpRegistryInterface; + +struct NodeProperties { + public: + NodeProperties(const OpDef* op_def, NodeDef node_def, + const DataTypeSlice inputs, const DataTypeSlice outputs) + : NodeProperties(op_def, std::move(node_def), + DataTypeVector(inputs.begin(), inputs.end()), + DataTypeVector(outputs.begin(), outputs.end())) {} + + NodeProperties(const OpDef* _op_def, NodeDef&& _node_def, + DataTypeVector inputs, DataTypeVector outputs) + : op_def(_op_def), + node_def(std::move(_node_def)), + input_types(std::move(inputs)), + input_types_slice(input_types), + output_types(std::move(outputs)), + output_types_slice(output_types) {} + + // Resets the 'props' shared pointer to point to a new NodeProperties created + // from the given NodeDef. 'op_registry' is used to look up the OpDef + // corresponding to node_def.op(). Returns an error if OpDef lookup or + // creation failed. + static absl::Status CreateFromNodeDef( + NodeDef node_def, const OpRegistryInterface* op_registry, + std::shared_ptr* props); + + const OpDef* op_def; // not owned. 
+ NodeDef node_def; + DataTypeVector input_types; + DataTypeSlice input_types_slice; + DataTypeVector output_types; + DataTypeSlice output_types_slice; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_PROPERTIES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/numeric_op.h b/third_party/tflite-hdrs/tensorflow/core/framework/numeric_op.h new file mode 100644 index 00000000..0167e21f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/numeric_op.h @@ -0,0 +1,113 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_ +#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// One input and one output, both the same type. +template +class UnaryOp : public OpKernel { + public: + explicit UnaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt}, {dt})); + } +}; + +// Two inputs and one output, all the same type. +template +class BinaryOp : public OpKernel { + public: + explicit BinaryOp(OpKernelConstruction* context) : OpKernel(context) { + const DataType dt = DataTypeToEnum::v(); + OP_REQUIRES_OK(context, context->MatchSignature({dt, dt}, {dt})); + } +}; + +// For operations where the input and output are the same shape. +// +// For usage, see ../framework/elementwise_ops.cc. +template +class UnaryElementWiseOp : public UnaryOp { + public: + using UnaryOp::UnaryOp; + + void Compute(OpKernelContext* context) override { + // Output shape is the same as input shape. + const Tensor& input = context->input(0); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0}, 0, input.shape(), &output)); + static_cast(this)->Operate(context, input, output); + } +}; + +// For binary elementwise operations. +template +class BinaryElementWiseOp : public BinaryOp { + public: + using BinaryOp::BinaryOp; + + void Compute(OpKernelContext* context) override { + const Tensor& a = context->input(0); + const Tensor& b = context->input(1); + + if (!context->ValidateInputsAreSameShape(this)) { + return; + } + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0, 1}, 0, a.shape(), &output)); + + // Dispatch to the descendant's Operate() function. 
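For context, a sketch of the CRTP contract this dispatch relies on: a descendant supplies a rank-templated Operate() that the switch below instantiates for each supported number of dimensions. The class "MyAddOp" is a hypothetical example for illustration only, not part of this header:

    class MyAddOp : public BinaryElementWiseOp<float, MyAddOp> {
     public:
      using BinaryElementWiseOp<float, MyAddOp>::BinaryElementWiseOp;

      template <int NDIMS>
      void Operate(OpKernelContext* context, const Tensor& a, const Tensor& b,
                   Tensor* output) {
        // Element-wise sum; input shapes were already validated by Compute().
        output->flat<float>() = a.flat<float>() + b.flat<float>();
      }
    };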
+ switch (a.dims()) { +#define NDIM_CASE(NDIMS) \ + case NDIMS: { \ + static_cast(this)->template Operate(context, a, b, output); \ + break; \ + } + + NDIM_CASE(0); + NDIM_CASE(1); + NDIM_CASE(2); + NDIM_CASE(3); + NDIM_CASE(4); + NDIM_CASE(5); + NDIM_CASE(6); + NDIM_CASE(7); + NDIM_CASE(8); +#undef NDIM_CASE + + default: + context->SetStatus(errors::InvalidArgument( + "We only handle up to Tensor::dims() up to 8, not ", a.dims())); + break; + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/numeric_types.h b/third_party/tflite-hdrs/tensorflow/core/framework/numeric_types.h new file mode 100644 index 00000000..0b22dbaf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/numeric_types.h @@ -0,0 +1,44 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_ + +#include + +// clang-format off +// This include order is required to avoid instantiating templates +// quantized types in the Eigen namespace before their specialization. +#include "xla/tsl/framework/numeric_types.h" +#include "tensorflow/core/platform/types.h" +// clang-format on + +namespace tensorflow { + +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::complex128; +using tsl::complex64; + +// We use Eigen's QInt implementations for our quantized int types. +using tsl::qint16; +using tsl::qint32; +using tsl::qint8; +using tsl::quint16; +using tsl::quint8; +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op.h b/third_party/tflite-hdrs/tensorflow/core/framework/op.h new file mode 100644 index 00000000..41b39fc2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op.h @@ -0,0 +1,330 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/full_type_inference_util.h" // IWYU pragma: export +#include "tensorflow/core/framework/full_type_util.h" // IWYU pragma: export +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" // IWYU pragma: export +#include "tensorflow/core/framework/registration/registration.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Users that want to look up an OpDef by type name should take an +// OpRegistryInterface. Functions accepting a +// (const) OpRegistryInterface* may call LookUp() from multiple threads. +class OpRegistryInterface { + public: + virtual ~OpRegistryInterface() = default; + + // Returns an error status and sets *op_reg_data to nullptr if no OpDef is + // registered under that name, otherwise returns the registered OpDef. + // Caller must not delete the returned pointer. + virtual absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const = 0; + + // Shorthand for calling LookUp to get the OpDef. + absl::Status LookUpOpDef(const std::string& op_type_name, + const OpDef** op_def) const; +}; + +// The standard implementation of OpRegistryInterface, along with a +// global singleton used for registering ops via the REGISTER +// macros below. Thread-safe. +// +// Example registration: +// OpRegistry::Global()->Register( +// [](OpRegistrationData* op_reg_data)->Status { +// // Populate *op_reg_data here. +// return OkStatus(); +// }); +class OpRegistry : public OpRegistryInterface { + public: + typedef std::function + OpRegistrationDataFactory; + + OpRegistry(); + + void Register(const OpRegistrationDataFactory& op_data_factory); + + absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const override; + + // Returns OpRegistrationData* of registered op type, else returns nullptr. + const OpRegistrationData* LookUp(const std::string& op_type_name) const; + + // Fills *ops with all registered OpDefs (except those with names + // starting with '_' if include_internal == false) sorted in + // ascending alphabetical order. + void Export(bool include_internal, OpList* ops) const; + + // Returns ASCII-format OpList for all registered OpDefs (except + // those with names starting with '_' if include_internal == false). + std::string DebugString(bool include_internal) const; + + // A singleton available at startup. + static OpRegistry* Global(); + + // Get all registered ops. + void GetRegisteredOps(std::vector* op_defs); + + // Get all `OpRegistrationData`s. + void GetOpRegistrationData(std::vector* op_data); + + // Registers a function that validates op registry. + void RegisterValidator( + std::function validator) { + op_registry_validator_ = std::move(validator); + } + + // Watcher, a function object. + // The watcher, if set by SetWatcher(), is called every time an op is + // registered via the Register function. 
The watcher is passed the Status + // obtained from building and adding the OpDef to the registry, and the OpDef + // itself if it was successfully built. A watcher returns a Status which is in + // turn returned as the final registration status. + typedef std::function + Watcher; + + // An OpRegistry object has only one watcher. This interface is not thread + // safe, as different clients are free to set the watcher any time. + // Clients are expected to atomically perform the following sequence of + // operations : + // SetWatcher(a_watcher); + // Register some ops; + // op_registry->ProcessRegistrations(); + // SetWatcher(nullptr); + // Returns a non-OK status if a non-null watcher is over-written by another + // non-null watcher. + absl::Status SetWatcher(const Watcher& watcher); + + // Process the current list of deferred registrations. Note that calls to + // Export, LookUp and DebugString would also implicitly process the deferred + // registrations. Returns the status of the first failed op registration or + // OkStatus() otherwise. + absl::Status ProcessRegistrations() const; + + // Defer the registrations until a later call to a function that processes + // deferred registrations are made. Normally, registrations that happen after + // calls to Export, LookUp, ProcessRegistrations and DebugString are processed + // immediately. Call this to defer future registrations. + void DeferRegistrations(); + + // Clear the registrations that have been deferred. + void ClearDeferredRegistrations(); + + private: + // Ensures that all the functions in deferred_ get called, their OpDef's + // registered, and returns with deferred_ empty. Returns true the first + // time it is called. Prints a fatal log if any op registration fails. + bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Calls the functions in deferred_ and registers their OpDef's + // It returns the Status of the first failed op registration or OkStatus() + // otherwise. + absl::Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Add 'def' to the registry with additional data 'data'. On failure, or if + // there is already an OpDef with that name registered, returns a non-okay + // status. + absl::Status RegisterAlreadyLocked( + const OpRegistrationDataFactory& op_data_factory) const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const; + + mutable mutex mu_; + // Functions in deferred_ may only be called with mu_ held. + mutable std::vector deferred_ TF_GUARDED_BY(mu_); + // Values are owned. + mutable absl::flat_hash_map> + registry_ TF_GUARDED_BY(mu_); + mutable bool initialized_ TF_GUARDED_BY(mu_); + + // Registry watcher. + mutable Watcher watcher_ TF_GUARDED_BY(mu_); + + std::function + op_registry_validator_; +}; + +// An adapter to allow an OpList to be used as an OpRegistryInterface. +// +// Note that shape inference functions are not passed in to OpListOpRegistry, so +// it will return an unusable shape inference function for every op it supports; +// therefore, it should only be used in contexts where this is okay. +class OpListOpRegistry : public OpRegistryInterface { + public: + // Does not take ownership of op_list, *op_list must outlive *this. + explicit OpListOpRegistry(const OpList* op_list); + absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const override; + + // Returns OpRegistrationData* of op type in list, else returns nullptr. 
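// Illustrative sketch (not part of the vendored TensorFlow header): the
// typical read path through the registry described above. "Identity" is just
// a convenient built-in op name; the wrapper function is hypothetical.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"

namespace example {

absl::Status LookUpIdentity(const tensorflow::OpDef** op_def) {
  // The registry keeps ownership of the returned OpDef; callers must not
  // delete it (see LookUp()/LookUpOpDef() above).
  return tensorflow::OpRegistry::Global()->LookUpOpDef("Identity", op_def);
}

}  // namespace example
// (end of sketch)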
+ const OpRegistrationData* LookUp(const std::string& op_type_name) const; + + private: + // Values are owned. + absl::flat_hash_map> index_; +}; + +// Support for defining the OpDef (specifying the semantics of the Op and how +// it should be created) and registering it in the OpRegistry::Global() +// registry. Usage: +// +// REGISTER_OP("my_op_name") +// .Attr(":") +// .Attr(":=") +// .Input(":") +// .Input(":Ref()") +// .Output(":") +// .Doc(R"( +// <1-line summary> +// +// : +// : +// )"); +// +// Note: .Doc() should be last. +// For details, see the OpDefBuilder class in op_def_builder.h. + +namespace register_op { + +class OpDefBuilderWrapper { + public: + explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {} + OpDefBuilderWrapper& Attr(std::string spec) { + builder_.Attr(std::move(spec)); + return *this; + } + OpDefBuilderWrapper& Attr(const char* spec) TF_ATTRIBUTE_NOINLINE { + return Attr(std::string(spec)); + } + OpDefBuilderWrapper& Input(std::string spec) { + builder_.Input(std::move(spec)); + return *this; + } + OpDefBuilderWrapper& Input(const char* spec) TF_ATTRIBUTE_NOINLINE { + return Input(std::string(spec)); + } + OpDefBuilderWrapper& Output(std::string spec) { + builder_.Output(std::move(spec)); + return *this; + } + OpDefBuilderWrapper& Output(const char* spec) TF_ATTRIBUTE_NOINLINE { + return Output(std::string(spec)); + } + OpDefBuilderWrapper& SetIsCommutative() { + builder_.SetIsCommutative(); + return *this; + } + OpDefBuilderWrapper& SetIsAggregate() { + builder_.SetIsAggregate(); + return *this; + } + OpDefBuilderWrapper& SetIsStateful() { + builder_.SetIsStateful(); + return *this; + } + OpDefBuilderWrapper& SetDoNotOptimize() { + // We don't have a separate flag to disable optimizations such as constant + // folding and CSE so we reuse the stateful flag. 
+ builder_.SetIsStateful(); + return *this; + } + OpDefBuilderWrapper& SetAllowsUninitializedInput() { + builder_.SetAllowsUninitializedInput(); + return *this; + } + OpDefBuilderWrapper& Deprecated(int version, std::string explanation) { + builder_.Deprecated(version, std::move(explanation)); + return *this; + } + OpDefBuilderWrapper& Doc(std::string text) { + builder_.Doc(std::move(text)); + return *this; + } + OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) { + builder_.SetShapeFn(std::move(fn)); + return *this; + } + OpDefBuilderWrapper& SetIsDistributedCommunication() { + builder_.SetIsDistributedCommunication(); + return *this; + } + + OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) { + builder_.SetTypeConstructor(std::move(fn)); + return *this; + } + + OpDefBuilderWrapper& SetForwardTypeFn(TypeInferenceFn fn) { + builder_.SetForwardTypeFn(std::move(fn)); + return *this; + } + + OpDefBuilderWrapper& SetReverseTypeFn(int input_number, TypeInferenceFn fn) { + builder_.SetReverseTypeFn(input_number, std::move(fn)); + return *this; + } + + const ::tensorflow::OpDefBuilder& builder() const { return builder_; } + + InitOnStartupMarker operator()(); + + private: + mutable ::tensorflow::OpDefBuilder builder_; +}; + +} // namespace register_op + +#define REGISTER_OP_IMPL(ctr, name, is_system_op) \ + static ::tensorflow::InitOnStartupMarker const register_op##ctr \ + TF_ATTRIBUTE_UNUSED = \ + TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \ + << ::tensorflow::register_op::OpDefBuilderWrapper(name) + +#define REGISTER_OP(name) \ + TF_ATTRIBUTE_ANNOTATE("tf:op") \ + TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false) + +// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except +// that the op is registered unconditionally even when selective +// registration is used. +#define REGISTER_SYSTEM_OP(name) \ + TF_ATTRIBUTE_ANNOTATE("tf:op") \ + TF_ATTRIBUTE_ANNOTATE("tf:op:system") \ + TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op_def_builder.h b/third_party/tflite-hdrs/tensorflow/core/framework/op_def_builder.h new file mode 100644 index 00000000..8009135d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op_def_builder.h @@ -0,0 +1,280 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Class and associated machinery for specifying an Op's OpDef and shape +// inference function for Op registration. 
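// Illustrative sketch (not part of the vendored TensorFlow header): an op
// registration in the style documented for REGISTER_OP above, which funnels
// into the OpDefBuilder defined in this file. The op name "ExampleScale" and
// its signature are made up; the shape function simply propagates the input
// shape.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

REGISTER_OP("ExampleScale")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, double}")
    .Attr("scale: float = 1.0")
    .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));  // output shape == input shape
      return absl::OkStatus();
    })
    .Doc(R"doc(
Scales the input tensor by a constant factor.

x: The tensor to scale.
scale: The multiplier applied to every element.
y: The scaled tensor.
)doc");
// (end of sketch)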
+ +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// TODO(b/62899350): Refactor without proto dependencies. +typedef std::function OpTypeConstructor; + +typedef std::vector> TypeRefVector; + +// A callback into the type inference process, allowing type inference functions +// to request inferring the type of some function (assumed to exist in the +// runtime). The function is specified by name. +typedef std::function(const string&, + const TypeRefVector&)> + FunctionTypeInferrer; + +// A type inference function, called for each node during type inference +// (possibly multiple times). +// The first argument (input_types) will hold the type of each of the node's +// inputs. The second argument (type_vars) will hold the return type of +// each function referred from any type variable (e.g. `FuncVar`) present +// in the node's corresponding op definition. +// +// TODO(mdan): Consider a vector-in, vector-out contract. +typedef std::function(const TypeRefVector&, + const FunctionTypeInferrer&)> + TypeInferenceFn; + +class FunctionDefHelper; + +namespace shape_inference { +class InferenceContext; +} +typedef std::function + OpShapeInferenceFn; + +struct OpRegistrationData { + public: + OpRegistrationData() {} + OpRegistrationData(const OpDef& def) : op_def(def) {} + OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn, + bool is_function = false) + : op_def(def), shape_inference_fn(fn), is_function_op(is_function) {} + + OpDef op_def; + OpShapeInferenceFn shape_inference_fn; + + // Type constructor. This callable initializes the type of this op. + // It is provided as a programmatic mechanism for defining an op's + // type, as part of its registration. It is to be eventually replaced by a + // textual language. + // + // Important: historically, op registrations only contained partial + // input/output type information in non-standardized attribute declarations + // (e.g. typically, input types were held in a `dtype` attribute). The type + // constructor currently duplicates such attribute information, with the aim + // of entirely subsuming it, and eventually deprecating all type-related + // attributes. + // + // Since ops are typically parametrized, the type created by this constructor + // is also parametric. + // + // Example: for an op `Foo(x: T) -> Bar[T]`: + // + // * typically, its op registration included a single attribute `T: type`; + // then the respective input was defined as `x: T`; the output type `Bar` + // was implied by the op name. + // * the type constructor creates a FullType object containing `Bar[T]`; this + // still relies on the `T` attribute which it references. + // * in the future, the type constructor will create a FullType containing + // `Callable[(x: T), Bar[T]]`, and the attribute `T` will be deprecated. + OpTypeConstructor type_ctor; + + // Forward type inference function. This callable infers the return type of an + // op based on its input types. 
+ // + // Note that the type constructor and forward inference functions need not be + // mutually exclusive: if there is some static information that can be set + // based on attributes, then that should be set in the constructor. If more + // information can be extracted from inputs, that should be done in the + // forward inference function. + // + // This is similar to the shape function, but is more general, and applied + // directly to NodeDefs, rather than working on the ShapeAndType structures. + // Note that the op input/output declarations may specify some implicit type + // constraints through attribute references (i.e. two inputs pointing to the + // same type attribute). Those constraints may duplicate what this function + // specifies in its body. That's intended, for a gradual transition to a more + // formal type system. + // + // These type inference functions are intermediate solutions as well: once the + // op registration has a complete, formal type definition, along with + // a solver-based type inference, it will replace these functions. + // + // TODO(mdan): Merge with shape inference. + // TODO(mdan): Replace with a union-based type inference algorithm. + TypeInferenceFn fwd_type_fn; + + // Reverse type inference function. This callable infers some input types + // based on the return type. + // + // TODO(mdan): Replace with a union-based type inference algorithm. + TypeInferenceFn rev_type_fn; + + // The input number affected by reverse type inference. Only one input may be + // updated in this manner. + // TODO(mdan): Encode in a manner more consistent with the forward version. + int rev_type_input; + + bool is_function_op = false; +}; + +// Builder class passed to the REGISTER_OP() macro. +class OpDefBuilder { + public: + // Constructs an OpDef with just the name field set. + explicit OpDefBuilder(std::string op_name); + + // Adds an attr to this OpDefBuilder (and returns *this). The spec has + // format ":" or ":=" + // where matches regexp [a-zA-Z][a-zA-Z0-9_]* + // (by convention only using capital letters for attrs that can be inferred) + // can be: + // "string", "int", "float", "bool", "type", "shape", or "tensor" + // "numbertype", "realnumbertype", "quantizedtype" + // (meaning "type" with a restriction on valid values) + // "{int32,int64}" or {realnumbertype,quantizedtype,string}" + // (meaning "type" with a restriction containing unions of value types) + // "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}" + // (meaning "string" with a restriction on valid values) + // "list(string)", ..., "list(tensor)", "list(numbertype)", ... + // (meaning lists of the above types) + // "int >= 2" (meaning "int" with a restriction on valid values) + // "list(string) >= 2", "list(int) >= 2" + // (meaning "list(string)" / "list(int)" with length at least 2) + // , if included, should use the Proto text format + // of . For lists use [a, b, c] format. + // + // Note that any attr specifying the length of an input or output will + // get a default minimum of 1 unless the >= # syntax is used. + // + // TODO(josh11b): Perhaps support restrictions and defaults as optional + // extra arguments to Attr() instead of encoding them in the spec string. + // TODO(josh11b): Would like to have better dtype handling for tensor attrs: + // * Ability to say the type of an input/output matches the type of + // the tensor. + // * Ability to restrict the type of the tensor like the existing + // restrictions for type attrs. 
+ // Perhaps by linking the type of the tensor to a type attr? + OpDefBuilder& Attr(std::string spec); + + // Adds an input or output to this OpDefBuilder (and returns *this). + // The spec has form ":" or ":Ref()" + // where matches regexp [a-z][a-z0-9_]* and can be: + // * For a single tensor: + // * For a sequence of tensors with the same type: * + // * For a sequence of tensors with different types: + // Where: + // is either one of "float", "int32", "string", ... + // or the name of an attr (see above) with type "type". + // is the name of an attr with type "int". + // is the name of an attr with type "list(type)". + // TODO(josh11b): Indicate Ref() via an optional argument instead of + // in the spec? + // TODO(josh11b): SparseInput() and SparseOutput() matching the Python + // handling? + OpDefBuilder& Input(std::string spec); + OpDefBuilder& Output(std::string spec); + + // Turns on the indicated boolean flag in this OpDefBuilder (and + // returns *this). + OpDefBuilder& SetIsCommutative(); + OpDefBuilder& SetIsAggregate(); + OpDefBuilder& SetIsStateful(); + OpDefBuilder& SetAllowsUninitializedInput(); + OpDefBuilder& SetIsDistributedCommunication(); + + // Deprecate the op at a certain GraphDef version. + OpDefBuilder& Deprecated(int version, std::string explanation); + + // Adds docs to this OpDefBuilder (and returns *this). + // Docs have the format: + // <1-line summary> + // + // : + // : + // + // Where is the name of an attr, input, or output. Please + // wrap docs at 72 columns so that it may be indented in the + // generated output. For tensor inputs or outputs (not attrs), you + // may start the description with an "=" (like name:= ) + // to suppress the automatically-generated type documentation in + // generated output. + OpDefBuilder& Doc(std::string text); + + // Sets the function to be used as type constructor. + // See OpRegistrationData::type_ctor. + OpDefBuilder& SetTypeConstructor(OpTypeConstructor c); + + // Sets the function to be used for forward type inference. + // See OpRegistrationData::fwd_type_fn. + OpDefBuilder& SetForwardTypeFn(TypeInferenceFn f); + + // Sets the function to be used for reverse type inference. + // See OpRegistrationData::rew_type_fn. + OpDefBuilder& SetReverseTypeFn(int input_number, TypeInferenceFn f); + + // Sets the shape function to be used for shape inference. + // + // Note that currently (October 2016), python code still requires a + // RegisterShape call to invoke this; see call_cpp_shape_fn in + // python/framework/common_shapes.py + OpDefBuilder& SetShapeFn(OpShapeInferenceFn fn); + + // Allows the `` in calls to `Attr()` to be "any". + // This is used by PythonAPIWrapper for pass-through parameters. + OpDefBuilder& AllowAttrTypeAny(); + + // Sets op_reg_data->op_def to the requested OpDef and + // op_reg_data->shape_inference_fn to the requested shape inference function, + // or returns an error. + // Must be called after all of the above methods. + // + // Note that OpDefBuilder only reports parsing errors. You should also + // call ValidateOpDef() to detect other problems. + absl::Status Finalize(OpRegistrationData* op_reg_data) const; + + private: + friend class FunctionDefHelper; + + // Adds control output to this OpDefBuilder (and returns *this). + // The must be a valid node name (matches regexp + // [a-zA-Z][a-zA-Z0-9_]*). Named control output can only exist for functions. 
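// Illustrative sketch (not part of the vendored TensorFlow header): exercising
// the attr/input spec grammar described above directly through OpDefBuilder,
// outside the REGISTER_OP macro. The op name and specs are made up; Finalize()
// (declared above) reports any parse errors through the returned status.
#include "tensorflow/core/framework/op_def_builder.h"

namespace example {

absl::Status BuildExampleConcatOpDef(tensorflow::OpRegistrationData* out) {
  return tensorflow::OpDefBuilder("ExampleConcat")
      .Attr("N: int >= 2")        // length attr with a minimum of 2
      .Attr("T: {float, int32}")  // type attr restricted to a union
      .Attr("axis: int = 0")      // attr with a default value
      .Input("values: N * T")     // sequence of N tensors, all of type T
      .Output("output: T")
      .Finalize(out);
}

}  // namespace example
// (end of sketch)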
+ OpDefBuilder& ControlOutput(std::string name); + + OpDef* op_def() { return &op_reg_data_.op_def; } + + OpRegistrationData op_reg_data_; + std::vector attrs_; + std::vector inputs_; + std::vector outputs_; + std::vector control_outputs_; + std::string doc_; + std::vector errors_; + bool allow_attr_type_any_ = false; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op_def_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/op_def_util.h new file mode 100644 index 00000000..be1f0822 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op_def_util.h @@ -0,0 +1,110 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't +// need to be as publicly accessible as other files in framework/. + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_ + +#include + +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Performs a consistency check across the fields of the op_def. +absl::Status ValidateOpDef(const OpDef& op_def); + +// Check if an op is deprecated at the given GraphDef version. If the op is +// deprecated at a future version, a warning will be logged. +absl::Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version); + +// Validates that attr_value satisfies the type and constraints from attr. +// REQUIRES: attr has already been validated. +absl::Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr); + +// The following search through op_def for an attr with the indicated name. +// Returns nullptr if no such attr is found. +const OpDef::AttrDef* FindAttr(absl::string_view name, const OpDef& op_def); +OpDef::AttrDef* FindAttrMutable(absl::string_view name, OpDef* op_def); + +// Searches op_def for input argument with the indicated name. +// Returns nullptr if no such attr is found. +const OpDef::ArgDef* FindInputArg(absl::string_view name, const OpDef& op_def); + +// Searches api_def for input argument with the indicated name. +// Returns nullptr if no such attr is found. +const ApiDef::Arg* FindInputArg(absl::string_view name, const ApiDef& api_def); + +// Produce a human-readable version of an op_def that is more concise +// than a text-format proto. Excludes descriptions. +std::string SummarizeOpDef(const OpDef& op_def); + +// Returns an error if new_op is not backwards-compatible with (more +// accepting than) old_op. +// REQUIRES: old_op and new_op must pass validation. 
+absl::Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op); + +// Returns an error if any attr in penultimate_op that is not in old_op +// has a different default value in new_op. In general it is not safe +// to change the default for an attr that has been added to an op. +absl::Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, + const OpDef& penultimate_op, + const OpDef& new_op); + +// Returns an error if the default value for any attr is removed or modified +// in new_op compared to old_op. Adding new default values is safe, and does +// not raise an error. +absl::Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, + const OpDef& new_op); + +// Remove all docs from *op_def / *op_list. +void RemoveDescriptionsFromOpDef(OpDef* op_def); +void RemoveDescriptionsFromOpList(OpList* op_list); + +// Remove docs from *op_def but leave explanations of deprecations. +void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def); + +// Returns true if `a1` is equal to `a2`. +// Equality includes all the fields. +bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2); + +// Returns hash of `a` that is consistent with AttrDefEqual. +uint64 AttrDefHash(const OpDef::AttrDef& a); + +// Returns true if all AttrDefs in `a1` equal corresponding AttrDefs in +// `a2`. Correspondence is established by name. +bool RepeatedAttrDefEqual(const protobuf::RepeatedPtrField& a1, + const protobuf::RepeatedPtrField& a2); + +// Returns hash of `a` that is consistent with RepeatedAttrDefEqual +uint64 RepeatedAttrDefHash(const protobuf::RepeatedPtrField& a); + +// Returns true if `o1` is equal to `o2`. +// Equality includes all the fields. OpDef.attr field is treated as a set. +bool OpDefEqual(const OpDef& o1, const OpDef& o2); + +// Returns hash of `o` that is consistent with AttrDefEqual. +uint64 OpDefHash(const OpDef& o); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op_gen_lib.h b/third_party/tflite-hdrs/tensorflow/core/framework/op_gen_lib.h new file mode 100644 index 00000000..27ffe522 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op_gen_lib.h @@ -0,0 +1,100 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_ + +#include +#include +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +// Forward declare protos so their symbols can be removed from .so exports +class OpDef; + +inline string Spaces(int n) { return string(n, ' '); } + +// Wrap prefix + str to be at most width characters, indenting every line +// after the first by prefix.size() spaces. Intended use case is something +// like prefix = " Foo(" and str is a list of arguments (terminated by a ")"). +// TODO(josh11b): Option to wrap on ", " instead of " " when possible. +string WordWrap(absl::string_view prefix, absl::string_view str, int width); + +// Looks for an "=" at the beginning of *description. If found, strips it off +// (and any following spaces) from *description and return true. Otherwise +// returns false. +bool ConsumeEquals(absl::string_view* description); + +// Convert text-serialized protobufs to/from multiline format. +string PBTxtToMultiline(absl::string_view pbtxt, + const std::vector& multi_line_fields); +string PBTxtFromMultiline(absl::string_view multiline_pbtxt); + +// Takes a list of files with ApiDefs text protos, and allows you to +// look up the specific ApiDef for any given op. +class ApiDefMap { + public: + // OpList must be a superset of ops of any subsequently loaded + // ApiDef. + explicit ApiDefMap(const OpList& op_list); + ~ApiDefMap(); + + // You can call this method multiple times to load multiple + // sets of files. Api definitions are merged if the same + // op definition is loaded multiple times. Later-loaded + // definitions take precedence. + // ApiDefs loaded from files must contain a subset of ops defined + // in the OpList passed to the constructor. + absl::Status LoadFileList(Env* env, const std::vector& filenames); + + // Load a single file. Api definitions are merged if the same + // op definition is loaded multiple times. Later-loaded + // definitions take precedence. + // ApiDefs loaded from file must contain a subset of ops defined + // in the OpList passed to the constructor. + absl::Status LoadFile(Env* env, const string& filename); + + // Load ApiDefs from string containing ApiDefs text proto. + // api_def_file_contents is expected to be in "multiline format". + // ApiDefs must contain a subset of ops defined in OpsList + // passed to the constructor. + absl::Status LoadApiDef(const string& api_def_file_contents); + + // Updates ApiDef docs. For example, if ApiDef renames an argument + // or attribute, applies these renames to descriptions as well. + // UpdateDocs should only be called once after all ApiDefs are loaded + // since it replaces original op names. + void UpdateDocs(); + + // Look up ApiDef proto based on the given graph op name. + // If graph op name is not in this ApiDefMap, returns nullptr. + // + // Note: Returned ApiDef pointer should stay valid even after calling + // Load* functions defined above. Subsequent calls to Load* might modify + // returned ApiDef contents, but should never remove the ApiDef itself. 
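// Illustrative sketch (not part of the vendored TensorFlow header): the
// typical ApiDefMap flow described above: export the registered ops, overlay
// an ApiDef text proto from disk, then query by graph op name via GetApiDef()
// (declared just below). The file path is a placeholder and the wrapper
// function is hypothetical.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"

namespace example {

absl::Status LookUpIdentityApiDef() {
  tensorflow::OpList ops;
  tensorflow::OpRegistry::Global()->Export(/*include_internal=*/false, &ops);

  tensorflow::ApiDefMap api_map(ops);
  TF_RETURN_IF_ERROR(api_map.LoadFile(tensorflow::Env::Default(),
                                      "/tmp/api_def_Identity.pbtxt"));
  api_map.UpdateDocs();  // apply argument/attr renames to the docs once

  const tensorflow::ApiDef* api_def = api_map.GetApiDef("Identity");
  if (api_def == nullptr) {
    return tensorflow::errors::NotFound("No ApiDef for Identity");
  }
  return absl::OkStatus();
}

}  // namespace example
// (end of sketch)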
+ const ApiDef* GetApiDef(const string& name) const; + + private: + std::unordered_map map_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op_kernel.h b/third_party/tflite-hdrs/tensorflow/core/framework/op_kernel.h new file mode 100644 index 00000000..d925bc21 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op_kernel.h @@ -0,0 +1,1736 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/node_properties.h" +#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/registration/registration.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/session_state.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" // TODO(b/62899350): Remove +#include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/manual_constructor.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/profile_utils/cpu_utils.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/util/managed_stack_trace.h" + +// Used to match ops to kernel sources (and eventually to kernel targets) +#ifdef TF_LOG_KERNEL_SOURCES +#define LOG_KERNEL_SOURCES(name) \ + LOG(INFO) << "Kernel found: " << name << " " << __FILE__ << "\n"; +#else +#define LOG_KERNEL_SOURCES(name) +#endif + +namespace Eigen { +struct ThreadPoolDevice; +struct GpuDevice; +} // end 
namespace Eigen + +namespace tsl { +class CoordinationServiceAgent; +} + +namespace tensorflow { + +namespace checkpoint { +class TensorSliceReaderCacheWrapper; +} // namespace checkpoint + +class AsyncOpKernel; +class CallFrameInterface; +class DeviceMgr; +class FunctionLibraryRuntime; +class OpKernelConstruction; // declared below +class OpKernelContext; // declared below, +class OpRegistryInterface; +class ResourceMgr; +class ScopedStepContainer; +class CollectiveExecutor; +class StepStatsCollectorInterface; + +// A label that is added to kernels that are JIT compiled. These labels will be +// removed before kernels are looked up, so they can be used without specifying +// the label. This label is a temporary measure to allow JIT kernels to be +// disabled if needed. +extern const char* kJitKernelLabel; +extern const char* kDisableJitKernelsEnvVar; + +class OpKernel { + public: + // OpKernel won't be instantiated by the scheduler, so you may perform + // expensive initialization in the descendant's constructor. + explicit OpKernel(OpKernelConstruction* context); + + // Specialized constructor that allows a kernel implementation to mark itself + // as a "deferred" op. If true, the executor will provide access to the + // `OpKernelContext::inc_num_deferred_ops_function()` and + // `OpKernelContext::dec_num_deferred_ops_function()` methods at run-time. + OpKernel(OpKernelConstruction* context, bool is_deferred); + + // Specialized constructor that enables the descendant to provide a custom + // `NodeDef` value. For example, this constructor can be used to provide a + // stripped-down `NodeDef` that does not contain the full set of attrs (such + // as tensor values) if the descendant stores them in a different form. + OpKernel(OpKernelConstruction* context, NodeDef&& custom_def, + bool is_deferred); + + virtual ~OpKernel(); + + // An OpKernel's computation can be either synchronous or + // asynchronous. All OpKernel Compute() methods must be thread-safe as they + // may be called concurrently (e.g. by multiple executions of the same graph + // concurrently). + // + // Most OpKernels should compute synchronously. They should + // subclass OpKernel and override the Compute() method and have it + // return after completing the supplied work. + // + // A synchronous OpKernel *MUST NOT* block the calling thread on a + // synchronization mechanism (condition variable, Notification, etc.) that + // will be unblocked by the execution of another OpKernel. Execution may + // deadlock in that case, because the executor may use a bounded number of + // threads. + // + // If an OpKernel must block on the execution of another OpKernel (e.g. a + // RecvOp, or a DequeueOp), the implementation *MUST* subclass AsyncOpKernel, + // and override `AsyncOpKernel::ComputeAsync()`. In addition, because the + // unblocking kernel may never run (due to an error or cancellation), in most + // cases the AsyncOpKernel should implement cancellation support via + // `ctx->cancellation_manager()`. + // + // In both cases, implementations of Compute() and ComputeAsync() + // get inputs and write outputs through the given OpKernelContext + // and returns a status via context->SetStatus(). They must be + // thread-safe. + + // Synchronous compute. + // + // "context" is guaranteed to be alive until Compute() returns. + virtual void Compute(OpKernelContext* context) = 0; + + // Returns nullptr iff this op kernel is synchronous. 
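// Illustrative sketch (not part of the vendored TensorFlow header): a minimal
// synchronous kernel following the Compute() contract above; it does all of
// its work inline, reports failures through OP_REQUIRES_OK, and never blocks
// on other kernels. The op name "ExampleAddOne" is made up, and
// allocate_output()/REGISTER_KERNEL_BUILDER are declared further down in this
// header and in kernel_def_builder.h.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

namespace example {

class AddOneOp : public tensorflow::OpKernel {
 public:
  explicit AddOneOp(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(tensorflow::OpKernelContext* context) override {
    const tensorflow::Tensor& input = context->input(0);
    tensorflow::Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    auto in = input.flat<float>();
    auto out = output->flat<float>();
    for (int64_t i = 0; i < in.size(); ++i) {
      out(i) = in(i) + 1.0f;
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("ExampleAddOne").Device(tensorflow::DEVICE_CPU),
                        AddOneOp);

}  // namespace example
// (end of sketch)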
+ virtual AsyncOpKernel* AsAsync() { return nullptr; } + + // Returns true iff this op kernel is considered "expensive". The + // runtime may use this flag to optimize graph execution for example + // to "inline" inexpensive kernels. + virtual bool IsExpensive() { return expensive_; } + + // Returns a pointer to the tensor stored inside constant ops. + virtual const Tensor* const_tensor() const { return nullptr; } + + // Accessors. + const NodeDef& def() const { return props_->node_def; } + const std::string& name() const { return props_->node_def.name(); } + absl::string_view name_view() const { return name_view_; } + const std::string& type_string() const { return props_->node_def.op(); } + absl::string_view type_string_view() const { return type_string_view_; } + const std::string& requested_input(int i) const { + return props_->node_def.input(i); + } + const std::string& requested_device() const { + return props_->node_def.device(); + } + + int num_inputs() const { return props_->input_types.size(); } + DataType input_type(int i) const { return props_->input_types[i]; } + const DataTypeVector& input_types() const { return props_->input_types; } + const MemoryTypeVector& input_memory_types() const { + return input_memory_types_; + } + + int num_outputs() const { return props_->output_types.size(); } + DataType output_type(int o) const { return props_->output_types[o]; } + const DataTypeVector& output_types() const { return props_->output_types; } + const MemoryTypeVector& output_memory_types() const { + return output_memory_types_; + } + + absl::Status InputRange(StringPiece input_name, int* start, int* stop) const; + absl::Status OutputRange(StringPiece output_name, int* start, + int* stop) const; + + // Returns `true` if and only if this kernel uses deferred execution. + bool is_deferred() const { return is_deferred_; } + + // Returns a trace string for current computation, op name/type and input + // tensor shape/dtype are encoded for profiler cost analysis. Most OpKernel + // should use the default implementation. + virtual std::string TraceString(const OpKernelContext& ctx, + bool verbose) const; + + protected: + std::string ShapeTraceString(const OpKernelContext& ctx) const; + + private: + const std::shared_ptr props_; + const MemoryTypeVector input_memory_types_; + const MemoryTypeVector output_memory_types_; + NameRangeMap input_name_map_; + NameRangeMap output_name_map_; + const absl::string_view name_view_; + const absl::string_view type_string_view_; + const int graph_def_version_; + const bool is_deferred_; + bool expensive_; + + OpKernel(const OpKernel&) = delete; + void operator=(const OpKernel&) = delete; +}; + +class AsyncOpKernel : public OpKernel { + public: + using OpKernel::OpKernel; // Lift OpKernel constructors. + + // Asynchronous compute. + // + // Implementations of ComputeAsync() must ensure that `done` is (eventually) + // called exactly once to signal the completion of the computation. The + // implementation of ComputeAsync() must not block on the execution of another + // OpKernel. `done` may be called by the current thread, or by another thread. + // `context` is guaranteed to stay alive until the `done` callback starts. + // + // Since it is possible that the unblocking kernel may never run (due to an + // error or cancellation), in most cases the AsyncOpKernel should implement + // cancellation support via `context->cancellation_manager()`. + // + // WARNING: As soon as the `done` callback starts, `context` and `this` may be + // deleted. 
No code depending on these objects should execute after the call + // to `done`. + typedef std::function DoneCallback; + virtual void ComputeAsync(OpKernelContext* context, DoneCallback done) = 0; + + AsyncOpKernel* AsAsync() override { return this; } + + void Compute(OpKernelContext* context) override; +}; + +class OpKernelConstruction { + public: + OpKernelConstruction(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, + const MemoryTypeSlice& input_memory_types, + const MemoryTypeSlice& output_memory_types, + int graph_def_version, absl::Status* status); + + Env* env() const { return device_->env(); } + + // Allocation of tensors during kernel construction: + // + // It is legal to temporarily allocate scratch tensor storage during + // Op kernel construction. Scratch tensors should be allocated using + // allocate_temp below. Some kernels need to keep tensors in between + // invocations. If such a Tensor is allocated during kernel + // construction this also must be done using allocate_temp, and the + // Op may only store the returned Tensor object. + + // Allocates a temporary Tensor of the specified type and shape. The + // Tensor must not be used after kernel construction is + // complete. See comment above. + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes allocator_attr); + + // User-supplied configuration of this operation. + const NodeDef& def() const { return props_->node_def; } + + // For inspecting the inputs to this operation. + int num_inputs() const { return props_->input_types.size(); } + DataType input_type(int i) const { return props_->input_types[i]; } + const DataTypeSlice& input_types() const { return props_->input_types_slice; } + const MemoryTypeSlice& input_memory_types() const { + return input_memory_types_; + } + + // For inspecting the outputs expected from this operation. + int num_outputs() const { return props_->output_types.size(); } + DataType output_type(int i) const { return props_->output_types[i]; } + const DataTypeSlice& output_types() const { + return props_->output_types_slice; + } + const MemoryTypeSlice& output_memory_types() const { + return output_memory_types_; + } + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures. + absl::Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // For recording configuration errors during construction. + void SetStatus(const absl::Status& status); + const absl::Status& status() const { return *status_; } + + // Look up the attr with name attr_name and set *value to its value. If no + // attr with attr_name is found in def(), or the attr does not have + // a matching type, a non-ok status will be returned. + template + absl::Status GetAttr(StringPiece attr_name, + T* value) const TF_ATTRIBUTE_NOINLINE; + + // Return true if the attr_name is defined in def(). + bool HasAttr(StringPiece attr_name) const; + + // Return the device type. + const DeviceType& device_type() const { return device_type_; } + + // If not nullptr, the kernel can instantiate functions defined in + // the library. E.g., + // CHECK_NOTNULL(function_library())->Instantiate("Foo", ...). 
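// Illustrative sketch (not part of the vendored TensorFlow header): reading an
// attr during kernel construction with GetAttr() as described above. The attr
// name "scale" and the kernel are hypothetical; a missing or mistyped attr
// surfaces as a construction-time error via OP_REQUIRES_OK.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

namespace example {

class ScaleOp : public tensorflow::OpKernel {
 public:
  explicit ScaleOp(tensorflow::OpKernelConstruction* context)
      : OpKernel(context) {
    // Fails kernel construction if the op has no float attr named "scale".
    OP_REQUIRES_OK(context, context->GetAttr("scale", &scale_));
  }

  void Compute(tensorflow::OpKernelContext* context) override {
    const tensorflow::Tensor& input = context->input(0);
    tensorflow::Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    output->flat<float>() = input.flat<float>() * scale_;
  }

 private:
  float scale_ = 1.0f;
};

}  // namespace example
// (end of sketch)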
+ FunctionLibraryRuntime* function_library() const { return flib_; } + + // Shared resources accessible to this kernel. + ResourceMgr* resource_manager() const { return resource_mgr_; } + + // The GraphDef version whose behavior we should follow. + int graph_def_version() const { return graph_def_version_; } + + // Helper routines for the OP_REQUIRES macros + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); + + // Unrecommended functions: these are functions that have some + // current uses but are not recommended for use, and may go away at + // some future major version release. + + // May be used, e.g., to get GPU handles, etc. + // + // Currently only used to call MakeTensorFromProto() for + // implementing ConstantOp for every device. See comments + // on Device::MakeTensorFromProto for longer-term replacement + // ideas. + DeviceBase* device() const { return device_; } + + private: + const DeviceType device_type_; + DeviceBase* const device_; + Allocator* allocator_; + FunctionLibraryRuntime* flib_; + ResourceMgr* const resource_mgr_; + std::shared_ptr props_; + MemoryTypeSlice input_memory_types_; + MemoryTypeSlice output_memory_types_; + const int graph_def_version_; + absl::Status* status_; + + // Allow access from OpKernel ctor. + friend class OpKernel; + + OpKernelConstruction(const OpKernelConstruction&) = delete; + void operator=(const OpKernelConstruction&) = delete; +}; + +// TODO(mrry): Consider converting to a random_access_iterator, and upgrading +// tensorflow::gtl::iterator_range to make the below container classes +// unnecessary. +template +class OpArgIterator { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = ElementType; + using pointer = ElementType*; + using const_pointer = const ElementType*; + using reference = ElementType&; + using const_reference = const ElementType&; + using difference_type = ptrdiff_t; + + OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} + + bool operator==(const OpArgIterator& rhs) { + DCHECK(list_ == rhs.list_); + return i_ == rhs.i_; + } + + bool operator!=(const OpArgIterator& rhs) { + DCHECK(list_ == rhs.list_); + return i_ != rhs.i_; + } + + OpArgIterator operator++() { // prefix ++it + ++i_; + return *this; + } + + OpArgIterator operator++(int) { // postfix it++ + OpArgIterator old_value = *this; + ++i_; + return old_value; + } + + reference operator*() { return (*list_)[i_]; } + pointer operator->() { return &(*list_)[i_]; } + + const_reference operator*() const { return (*list_)[i_]; } + const_pointer operator->() const { return &(*list_)[i_]; } + + private: + const ListType* const list_; + int i_; +}; + +// Utility class for representing a list of immutable input tensors +// that are passed to the op as a single named argument. 
+class OpInputList { + public: + typedef OpArgIterator Iterator; + OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpInputList(OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpInputList& operator=(const OpInputList& other) = default; + const Tensor& operator[](int i) const; + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Utility class for representing a list of mutable ("ref") input tensors +// that are passed to the op as a single named argument. +class OpMutableInputList { + public: + typedef OpArgIterator Iterator; + OpMutableInputList(OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpMutableInputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpMutableInputList& operator=(const OpMutableInputList& other) = default; + Tensor at(int i, bool lock_held); + mutex* ref_mutex(int i); + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Utility class for representing a list of output tensors that are +// grouped as a single named output. +class OpOutputList { + public: + typedef OpArgIterator Iterator; + OpOutputList() : ctx_(nullptr), start_(0), stop_(0) {} + OpOutputList(OpKernelContext* ctx, int start, int stop) + : ctx_(ctx), start_(start), stop_(stop) {} + OpOutputList& operator=(const OpOutputList& other) = default; + Tensor* operator[](int i); + bool required(int i) const; + DataType expected_output_dtype(int i) const; + absl::Status allocate(int i, const TensorShape& shape, Tensor** output); + void set(int i, const Tensor& tensor); + void set(int i, Tensor&& tensor); + void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); + int size() const { return stop_ - start_; } + Iterator begin() const { return Iterator(this, 0); } + Iterator end() const { return Iterator(this, size()); } + + private: + OpKernelContext* ctx_; // not owned + int start_; + int stop_; +}; + +// Holds a tensor or tensor reference. For tensor references, we need +// a mutex to prevent concurrent access to the tensor. +struct TensorValue { + TensorValue() : mutex_if_ref(nullptr), tensor(nullptr) {} + explicit TensorValue(Tensor* t) : mutex_if_ref(nullptr), tensor(t) {} + TensorValue(mutex* mu, Tensor* t) : mutex_if_ref(mu), tensor(t) {} + Tensor* operator->() const { return tensor; } + bool is_ref() const { return mutex_if_ref != nullptr; } + + // Return the dtype of the Tensor. For references, return the underlying type. + DataType dtype() const { + if (is_ref()) { + return MakeRefType(tensor->dtype()); + } else { + return tensor->dtype(); + } + } + + // Return the dtype of the Tensor. For references, return the underlying type. + // This variation on the dtype() acquires the lock for references. + // + // TODO(b/133843385): Disallow dtype modifications + DataType dtype_safe() const { + if (is_ref()) { + tf_shared_lock ml(*mutex_if_ref); + return MakeRefType(tensor->dtype()); + } else { + return tensor->dtype(); + } + } + + mutex* mutex_if_ref; // nullptr if not a ref, != nullptr if a ref + Tensor* tensor; +}; + +// Used to store partitioned graphs from function-calling ops. 
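// Illustrative sketch (not part of the vendored TensorFlow header): using the
// list wrappers above from inside a kernel's Compute(). The argument names
// ("values", "outputs") are assumptions and must match the op's registered
// signature; input_list()/output_list() are declared further down in
// OpKernelContext, and the helper function itself is hypothetical.
#include "tensorflow/core/framework/op_kernel.h"

namespace example {

void PassThroughList(tensorflow::OpKernelContext* context) {
  tensorflow::OpInputList values;
  OP_REQUIRES_OK(context, context->input_list("values", &values));

  tensorflow::OpOutputList outputs;
  OP_REQUIRES_OK(context, context->output_list("outputs", &outputs));

  for (int i = 0; i < values.size(); ++i) {
    outputs.set(i, values[i]);  // forward each input tensor unchanged
  }
}

}  // namespace example
// (end of sketch)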
+struct GraphCollector { + mutex mu; + std::vector partitioned_graphs TF_GUARDED_BY(mu); + GraphDef raw_graph TF_GUARDED_BY(mu); + GraphDef optimized_graph TF_GUARDED_BY(mu); + + bool dirty TF_GUARDED_BY(mu); + + GraphCollector() : dirty(false) {} + + void CollectRawGraph(const GraphDef& graph) { + mutex_lock ml(mu); + raw_graph.MergeFrom(graph); + dirty = true; + } + + void CollectOptimizedGraph(const GraphDef& graph) { + mutex_lock ml(mu); + optimized_graph.MergeFrom(graph); + dirty = true; + } + + void CollectPartitionedGraph(const GraphDef& graph) { + mutex_lock ml(mu); + partitioned_graphs.push_back(graph); + dirty = true; + } + + void ClearGraphs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + raw_graph.Clear(); + optimized_graph.Clear(); + partitioned_graphs.clear(); + dirty = false; + } + + bool HasUpdatedGraphs() { + mutex_lock ml(mu); + return dirty; + } +}; + +class OpKernelContext { + public: + // The first element of a WrappedAllocator is a "base" Allocator and + // the second element is that Allocator wrapped by a + // TrackingAllocator + typedef std::pair WrappedAllocator; + + // TODO(zhifengc): Do some cleanup of Params. + // The Params struct is passed in to initialize an OpKernelContext, + // and must outlive the OpKernelContext. + struct Params { + ~Params() { delete eigen_gpu_device; } + + // The step being executed. + int64_t step_id = 0; + + // Timestamp for the start of graph execution. Used for latency metrics. + int64_t start_time_usecs = 0; + + // The deadline for the session to complete by. Empty if unspecified. + std::optional deadline; + + // The op kernel being computed. + OpKernel* op_kernel = nullptr; + + // The device on which the kernel is running. + DeviceBase* device = nullptr; + + // The Eigen GPU device wrapper, which may include a per-op + // wrapped allocator. The concrete type of this object depends on + // the type of this->device, so eigen_gpu_device can't be an + // inline member and must be heap allocated. However, we don't + // want to allocate a new eigen_gpu_device for every Op that is + // executed. Instead this member is allocated on first use using + // ensure_eigen_gpu_device, and then if the Params structure is + // re-used for subsequent Ops, the eigen_gpu_device is + // ReInitialized in the OpKernelContext constructor. Unlike the + // other pointers in Params, this one is owned by Params. + PerOpGpuDevice* eigen_gpu_device = nullptr; + + inline void ensure_eigen_gpu_device() { + DCHECK(device); + if (nullptr == eigen_gpu_device) { + // Surprisingly, MakeGpuDevice will return nullptr if the + // device is not a GPU device. This is ok, since those devices + // will never use eigen_gpu_device. It seems better to have + // ensure_eigen_gpu_device fall through and regenerate the + // nullptr every time an OpKernelContext is instantiated, than + // to do an unnecessary allocation of a dummy eigen GPU + // device for CPU device Ops. + eigen_gpu_device = device->MakeGpuDevice(); + } + } + + bool track_allocations = false; + bool log_memory = false; + + // Array indexed by output number for this node + const AllocatorAttributes* output_attr_array = nullptr; + + // Shared resources accessible by this op kernel invocation. + ResourceMgr* resource_manager = nullptr; + + // Per-step resources accessible by this op kernel invocation should be + // stored in this container.. + ScopedStepContainer* step_container = nullptr; + + // Mechanism used by this op kernel invocation to communicate with + // computations running on other devices. 
+ RendezvousInterface* rendezvous = nullptr; + + // Mechanism for executing a collective op that needs to coordinate + // with parallel instances running on other devices. + CollectiveExecutor* collective_executor = nullptr; + + // Session configuration parameters. Can be nullptr. + const ConfigProto* session_config = nullptr; + + // The session state for this op. + SessionState* session_state = nullptr; + + // Unique session identifier. Can be empty. + std::string session_handle; + + // Metadata about the session. Can be nullptr. + const SessionMetadata* session_metadata = nullptr; + + // The tensor store for this op. + TensorStore* tensor_store = nullptr; + + // Mechanism used by this op kernel invocation to register a callback + // for its cancellation. + CancellationManager* cancellation_manager = nullptr; + + // Inputs to this op kernel. + absl::Span inputs; + bool is_input_dead = false; + + absl::Span input_alloc_attrs; + + // Device context. + DeviceContext* op_device_context = nullptr; + + // Control-flow op supports. + FrameAndIter frame_iter; + + // Function call supports. + CallFrameInterface* call_frame = nullptr; + FunctionLibraryRuntime* function_library = nullptr; + std::function)>* runner = nullptr; + StepStatsCollectorInterface* stats_collector = nullptr; + GraphCollector* graph_collector = nullptr; + bool run_all_kernels_inline = false; + const std::string* executor_type = nullptr; + + // TensorSliceReaderCache support. + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr; + + // Support for forwarding reservations (used by ScopedAllocator). + static constexpr int kNeverForward = -2; + static constexpr int kNoReservation = -1; + // Values in [0,...) represent reservations for the indexed output. + const int* forward_from_array = nullptr; + + // For tracking actively running deferred ops. + std::function inc_num_deferred_ops_function; + std::function dec_num_deferred_ops_function; + + std::optional stack_trace = {}; + + // For implementing `OpKernelContext::output_required()`. If null, all + // outputs are required. + bool* outputs_required_array = nullptr; + + // For access to distributed coordination service. + tsl::CoordinationServiceAgent* coordination_service_agent = nullptr; + }; + + // params must outlive the OpKernelContext. + explicit OpKernelContext(Params* params); + OpKernelContext(Params* params, int num_outputs); + ~OpKernelContext(); + + Env* env() const { return params_->device->env(); } + + int64_t step_id() const { return params_->step_id; } + + int64_t start_time_usecs() const { return params_->start_time_usecs; } + + const ConfigProto* session_config() const { return params_->session_config; } + + // The deadline for the session to complete by. Empty if unspecified in + // RunOptions. + std::optional deadline() const { return params_->deadline; } + + const OpKernel& op_kernel() const { return *params_->op_kernel; } + + // Stack trace of where the op was defined (if defined in eager mode). + const absl::optional& stack_trace() const { + return params_->stack_trace; + } + + // Input/output signature. + + int num_inputs() const { return params_->inputs.size(); } + DataType input_dtype(int index) const; + absl::Status input_dtype(StringPiece name, DataType* dtype) const; + MemoryType input_memory_type(int index) const; + + int num_outputs() const { return outputs_.size(); } + DataType expected_output_dtype(int index) const; + MemoryType output_memory_type(int index) const; + + // Input + + // Returns an immutable input tensor by index. 
May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + // TODO(mrry): Convert this to return Status. + const Tensor& input(int index) const; + + // Returns an immutable input tensor in "tensor" by index. May only be used + // for non-Ref inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + absl::StatusOr get_input(int index) const; + + // Returns the named immutable input tensor in "tensor", as defined + // in the OpDef. May only be used for non-Ref inputs. For Ref inputs + // use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + // REQUIRES: the named input must not be a list. + absl::Status input(StringPiece name, const Tensor** tensor); + + // Returns the named list-valued immutable input in "list", as + // defined in the OpDef. If the named output is not list-valued, + // returns a one-element list. May only be used for non-Ref + // inputs. For Ref inputs use mutable_input below. + // REQUIRES: !IsRefType(input_dtype(index)) + absl::Status input_list(StringPiece name, OpInputList* list); + + // For mutable inputs, use the following together to make sure there + // is no concurrent access to mutable_input(), e.g.: + // { + // Tensor& t = context->mutable_input(index); + // mutex_lock lock(*context->input_ref_mutex(index)); + // // modify the values in t + // } + // REQUIRES: IsRefType(input_dtype(index)) + absl::Status input_ref_mutex(StringPiece name, mutex** out_mutex); + + // Returns a mutable input tensor. Must be used to access Ref + // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may + // modify the values stored in the Tensor buffer, and modifications + // will be visible to other Ops reading the same ref tensor. If + // !lock_held the input mutex will be acquired before returning the + // Tensor. + // TODO(mrry): Convert this to return Status. + Tensor mutable_input(int index, bool lock_held); + + // Returns the named mutable input tensor in "tensor", as defined in + // the OpDef. Must be used to access Ref inputs. The values stored + // in the Tensor buffer may be modified, and modifications will be + // visible to other Ops reading the same ref tensor. If !lock_held + // the input mutex will be acquired before returning the Tensor. + // REQUIRES: the named input must not be a list. + // REQUIRES: the named input must be a ref tensor. + absl::Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held); + + // Returns the named list-valued mutable input in "list", as defined + // in the OpDef. If the named input is not list-valued, returns a + // one-element list. Must be used to access Ref inputs. The values + // stored in the Tensor buffer may be modified, and modifications + // will be visible to other Ops reading the same ref tensor. + // REQUIRES: the named input must be a ref tensor. + absl::Status mutable_input_list(StringPiece name, OpMutableInputList* list); + + // Replace the corresponding Ref Input to use the storage buffer + // used by tensor. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). + void replace_ref_input(int index, const Tensor& tensor, bool lock_held); + + // Replace the corresponding named Ref Input to use the storage + // buffer used by tensor. If !lock_held the input mutex will be + // acquired before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(index)). 
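To make the ref-input locking pattern described above concrete, here is a minimal sketch (not part of the header, and not the library's own code) of a hypothetical kernel whose input 0 is a ref (variable-like) float tensor and whose ref output 0 aliases the same buffer; the class name and index choices are assumptions for the example.

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical kernel: adds 1.0f in place to a ref input and forwards it.
class IncrementVariableOp : public OpKernel {
 public:
  explicit IncrementVariableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Hold the ref mutex for the whole update and tell mutable_input()
    // not to re-acquire it.
    mutex_lock l(*ctx->input_ref_mutex(0));
    Tensor var = ctx->mutable_input(0, /*lock_held=*/true);
    auto flat = var.flat<float>();
    for (int64_t i = 0; i < flat.size(); ++i) flat(i) += 1.0f;
    // The ref output shares storage with the ref input.
    ctx->forward_ref_input_to_ref_output(0, 0);
  }
};

}  // namespace tensorflow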
+ absl::Status replace_ref_input(StringPiece name, const Tensor& tensor, + bool lock_held); + + // Deletes the Tensor object used as the Ref Input at + // input_index. This is not usually necessary and should be used + // with caution. If !lock_held the input mutex will be acquired + // before returning the Tensor. + // REQUIRES: IsRefType(input_dtype(input_index)). + void delete_ref_input(int input_index, bool lock_held); + + // Return true if there is input at the given index. An operator has no + // input at index if its tensor is null. This is primarily used by the + // merge operator. + // TODO(mrry): Convert this to return Status. + bool has_input(int index) const; + + // Returns true if all inputs are the same shape, otherwise sets the + // status to a non-OK value and returns false. + // Usage: if (!context->ValidateInputsAreSameShape(this)) return; + bool ValidateInputsAreSameShape(OpKernel* op); + + // If non-null, kernels should populate with any partition subgraphs created. + GraphCollector* graph_collector() { return params_->graph_collector; } + + // If True, hint that all kernels in functions called by this kernel, should + // be treated as "inexpensive", and hence executed on the scheduling thread. + bool run_all_kernels_inline() const { + return params_->run_all_kernels_inline; + } + + // Returns the registered name for the executor type that is executing the + // current kernel. If empty, the default executor is used. + const std::string& executor_type() const; + + // Input to output forwarding. + + // Set the output Ref Tensor at output_index to be an alias of the + // input Ref Tensor at input_index. + // REQUIRES: IsRefType(input_dtype(input_index)). + // REQUIRES: IsRefType(output_dtype(output_index)). + void forward_ref_input_to_ref_output(int input_index, int output_index); + + // Returns true when an alias to input[input_index], reshaped to output_shape, + // which is safe to use for in-place computation was written to *output. + // Returns false if input[input_index] has a refcount greater than one, or if + // its type does not match the expected output type of output[output_index], + // or the number of elements in input[input_index] does not equal the number + // of elements in output_shape. + bool forward_input_to_output_with_shape(int input_index, int output_index, + const TensorShape& output_shape, + Tensor** output) TF_MUST_USE_RESULT; + absl::Status forward_input_to_output_with_shape( + StringPiece input_name, StringPiece output_name, + const TensorShape& output_shape, Tensor** output); + + // Returns a pointer to a Tensor aliasing the underlying buffer backing + // input[input_index] iff + // * input[input_index] is not a ref, + // * the data type, shape, memory type, and allocator attributes of + // input[input_index] are compatible with those given in dtype, shape, + // memory_type, and attr, + // * refcount on the underlying buffer is one. + // * Either there is no forwarding reservation for either input_index + // or output_index or the specified input is reserved for the specified + // output. 
More precisely: + // + // These cases mean neither input nor output has a reservation: + // forward_from_array = nullptr + // OR (input_index is not in forward_from_array AND + // (output_index == kNoReservation OR + // forward_from_array[output_index] == kNoReservation)) + // + // This case means that input_index is reserved for output_index: + // forward_from_array[output_index] == input_index + // + // This case means the output is reserved to always be allocated, + // never assigned a forwarded input: + // forward_from_array[output_index] == kNeverForward + // + // Otherwise returns nullptr. + // NOTE: For Cuda kernels that read inputs using the __ldg() intrinsic, + // forwarding is only safe if there are no reads via __ldg() after writes + // to the same address. + std::unique_ptr forward_input( + int input_index, int output_index, DataType output_dtype, + const TensorShape& output_shape, MemoryType output_memory_type, + const AllocatorAttributes& output_attr) TF_MUST_USE_RESULT; + + // Tries to forward one of the inputs given in input_indices to + // output[output_index]. If none of the given inputs can be forwarded, calls + // allocate_output() to allocate a new output buffer. The index of the + // forwarded input will be assign to output argument forwarded_input (if it's + // not nullptr). If no inputs are forwarded, forwarded_input will be assigned + // -1. + absl::Status forward_input_or_allocate_output( + absl::Span candidate_input_indices, int output_index, + const TensorShape& output_shape, Tensor** output, + int* forwarded_input = nullptr); + absl::Status forward_input_or_allocate_output( + absl::Span candidate_input_names, + StringPiece output_name, const TensorShape& output_shape, + Tensor** output); + + // Tries to reuse one of the inputs given in input_indices as a temporary. + // If none of the given inputs can be forwarded, calls + // allocate_temp() to allocate a new temporary buffer. + absl::Status forward_input_or_allocate_temp( + absl::Span candidate_input_indices, DataType type, + const TensorShape& shape, const AllocatorAttributes& allocator_attr, + Tensor* out_temp); + + absl::Status forward_input_or_allocate_temp( + absl::Span candidate_input_indices, DataType type, + const TensorShape& shape, Tensor* out_temp) { + return forward_input_or_allocate_temp(candidate_input_indices, type, shape, + AllocatorAttributes(), out_temp); + } + + // Output + + // Returns the named list-valued output in "list", as defined in the OpDef. + // If the named output is not list-valued, returns a one-element list. + absl::Status output_list(StringPiece name, OpOutputList* list); + + // If output_required(index) returns true, the OpKernel's Compute() method + // should call allocate_output(index, ...), set_output(index, ...), + // set_output_ref(index, ...), or set the status to a non-ok value. + // If it returns false, it may output, but is not required to do so. + bool output_required(int index) const { + return !params_->outputs_required_array || + params_->outputs_required_array[index]; + } + + // If output_expects_forwarding returns true, the OpKernel's Compute() method + // should not allocate the output with allocate_output but instead needs to + // use forward_input. 
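As a small illustration of the forwarding path documented above, the following Compute() sketch assumes a hypothetical unary float op (CPU-style, expression evaluated inline) whose output 0 has the same shape and dtype as input 0.

void Compute(OpKernelContext* ctx) override {
  const Tensor& input = ctx->input(0);
  Tensor* output = nullptr;
  // Reuses input 0's buffer for output 0 when dtype/shape match, the buffer's
  // refcount is one and no forwarding reservation forbids it; otherwise a
  // fresh output is allocated.
  OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
                          {0}, 0, input.shape(), &output));
  output->flat<float>() = input.flat<float>().abs();
}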
+ bool output_expects_forwarding(int index) const { + return params_->forward_from_array != nullptr && + params_->forward_from_array[index] >= 0; + } + + // Allocation of tensors during kernel execution inside the Compute + // method: + // + // There are two methods to allocate Tensors when an Op kernel + // executes. + // + // 1) allocate_output. This should be used to allocate any tensor + // that is going to be used as an output from the Op at the end of + // the current execution. The caller indicates which output the + // Tensor will be assigned to, and the call returns the + // newly-allocated Tensor. The Tensor can subsequently be assigned + // to during kernel execution, and will be used as the designated + // output when the kernel execution completes. + // + // 2) allocate_temp. This should be used to allocate any scratch + // storage that is needed while the kernel is executing, and will + // not be retained by the Op. + // + // In some cases a Tensor needs to be used as an output even though + // it was previously allocated elsewhere. The Tensor may have been + // passed as an input, or stored in a Tensor during a + // previous kernel execution, or allocated earlier in the kernel + // execution at a time when it was not known which output it would + // be assigned to. In this case the kernel can use set_output or + // set_output_ref to indicate that the tensor should be used as the + // designated output. It is legal to use any previously-allocated + // Tensor as an argument to set_output or set_output_ref, including + // Tensors allocated via allocate_temp. There may be a performance + // penalty to using a Tensor that was not allocated using + // allocate_output. This is because allocate_output uses the + // AllocatorAttributes stored in output_attr_array for the + // designated output. In some cases, using the wrong attributes may + // cause an extra copy of the Tensor's buffer. + + // Allocates output for the specified output index with shape. + // OpKernelContext retains ownership of the returned pointer. See + // comment above. + // + // If memory allocation fails, returns an error status. + // + // REQUIRES: !IsRefType(expected_output_dtype(index)) + absl::Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor); + absl::Status allocate_output(StringPiece name, const TensorShape& shape, + Tensor** tensor); + // The following methods use the supplied attributes instead of + // those in output_attr_array. The caller is responsible for + // ensuring that the attributes are "compatible" with the + // output_attr_array, e.g. the tensor is allocated on the correct + // device. See comment above. + absl::Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor, AllocatorAttributes attr); + absl::Status allocate_output(StringPiece name, const TensorShape& shape, + Tensor** tensor, AllocatorAttributes attr); + + // Allocates a temporary Tensor of the specified type and + // shape. Devices such as GPUs that enqueue Ops for lazy execution + // may retain references to the temporary tensors after the Op's + // Compute method has run. See comment above. 
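To make the allocate_output / allocate_temp distinction above concrete, here is a minimal Compute() sketch for a hypothetical float op that allocates one real output and one scratch tensor; the computation itself is a placeholder.

void Compute(OpKernelContext* ctx) override {
  const Tensor& input = ctx->input(0);

  // 1) Real output: allocated with the AllocatorAttributes the runtime
  //    selected for output 0 (output_attr_array).
  Tensor* output = nullptr;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));

  // 2) Scratch space: not retained after Compute() returns.
  Tensor scratch;
  OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, input.shape(), &scratch));

  scratch.flat<float>() = input.flat<float>() * 2.0f;
  output->flat<float>() = scratch.flat<float>() + 1.0f;
}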
+ absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes allocator_attr, + const AllocationAttributes& allocation_attr); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes allocator_attr); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + + // Copies a tensor (allocated by the caller) to the specified output + // index. REQUIRES: !IsRefType(expected_output_dtype(index)) + // REQUIRES: 'tensor' must have the same MemoryType as + // output_memory_types[index]. See comment above. + absl::Status set_output(StringPiece name, const Tensor& tensor); + absl::Status set_output(StringPiece name, Tensor&& tensor); + void set_output(int index, const Tensor& tensor); + void set_output(int index, Tensor&& tensor); + + // To output a reference. Caller retains ownership of mu and tensor_for_ref, + // and they must outlive all uses within the step. See comment above. + // REQUIRES: IsRefType(expected_output_dtype(index)) + absl::Status set_output_ref(StringPiece name, mutex* mu, + Tensor* tensor_for_ref); + + // Returns nullptr if allocate_output() or set_output() have not been called. + absl::Status mutable_output(StringPiece name, Tensor** tensor); + + // Return the DeviceContext that should be used for this Op. + // + // If using the templated function, the type must be a subclass + // of DeviceContext. + // + // Returns nullptr if the device did not provide one. + template + T* op_device_context(); + DeviceContext* op_device_context() { + DeviceContext* ret = params_->op_device_context; + if (ret == nullptr) { + auto* dev_info = device()->tensorflow_accelerator_device_info(); + if (dev_info) ret = dev_info->default_context; + } + return ret; + } + + AllocatorAttributes input_alloc_attr(int index) const { + if (params_->input_alloc_attrs.empty()) { + return AllocatorAttributes(); + } else { + DCHECK_GE(index, 0); + DCHECK_LT(index, params_->input_alloc_attrs.size()); + return params_->input_alloc_attrs[index]; + } + } + + AllocatorAttributes output_alloc_attr(int index) const { + return params_->output_attr_array[index]; + } + + absl::InlinedVector ConsumeWrappedAllocators() { + absl::InlinedVector retrieved; + if (tracking_state_) { + mutex_lock lock(tracking_state_->mu); + retrieved.swap(tracking_state_->wrapped_allocators); + } + return retrieved; + } + + // Communication. + // + // An op kernel communicates with outside environment through + // Rendezvous Send() and Recv(). + RendezvousInterface* rendezvous() const { return params_->rendezvous; } + + CollectiveExecutor* collective_executor() const { + return params_->collective_executor; + } + + // An op kernel can access the session state it belongs to. + SessionState* session_state() const { return params_->session_state; } + + // Unique identifier of the session it belongs to. Can be empty. + std::string session_handle() const { return params_->session_handle; } + + // Metadata about the session. Can be nullptr. + const SessionMetadata* session_metadata() const { + return params_->session_metadata; + } + + // An op kernel can access the tensor store of the run it belongs to. + TensorStore* tensor_store() const { return params_->tensor_store; } + + // Function call support. + // + // If this kernel invocation is within a function execution, + // call_frame() returns the call frame for the function call. 
+ CallFrameInterface* call_frame() const { return params_->call_frame; } + + // If not nullptr, the kernel invoke functions defined in the + // library. E.g., CHECK_NOTNULL(function_library())->Run("Foo", ...). + FunctionLibraryRuntime* function_library() const { + return params_->function_library; + } + + std::function)>* runner() const { + return params_->runner; + } + StepStatsCollectorInterface* stats_collector() const { + return params_->stats_collector; + } + + // Shared resources accessible to this kernel. + ResourceMgr* resource_manager() const { return params_->resource_manager; } + + checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache() const { + return params_->slice_reader_cache; + } + + // Execution. + // + // OpKernels can use these eigen devices to carry out their + // numerical computation. + const Eigen::ThreadPoolDevice& eigen_cpu_device() const { + return *device()->eigen_cpu_device(); + } + const Eigen::GpuDevice& eigen_gpu_device() const { + return params_->eigen_gpu_device->device(); + } + template + const EigenDeviceType& eigen_device() const; + + // Error handling. + + // If expected_inputs == inputs() and expected_outputs == output_types(), + // returns OK, else returns INVALID_ARGUMENT with an error message. + // Recommended for Ops with dynamic signatures, where validation can only + // be performed at runtime. + absl::Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); + + // An OpKernel should call SetStatus() if Compute() encounters an + // error. + void SetStatus(const absl::Status& status); + const absl::Status& status() const { return status_; } + + // Cancellation. + // + // EXPERIMENTAL. See the implementation in tensorflow::FIFOQueue for an + // example of how to use this API. + CancellationManager* cancellation_manager() const { + return params_->cancellation_manager; + } + + // Other accessors. + + // For control flow. + FrameAndIter frame_iter() const { return params_->frame_iter; } + bool is_input_dead() const { return params_->is_input_dead; } + + // May be used, e.g., to get GPU handles, etc. + // TODO(tucker): Add example usage. + DeviceBase* device() const { return params_->device; } + + // Per-step container for use by white-listed internal ops. + ScopedStepContainer* step_container() const { + return params_->step_container; + } + + // Access to distributed coordination service. + tsl::CoordinationServiceAgent* coordination_service_agent() const { + return params_->coordination_service_agent; + } + + // Helper routines for the OP_REQUIRES macros + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); + + // Unrecommended functions: these are functions that have some + // current uses but are not recommended for use, and may go away at + // some future major version release. + // + // The following functions all have versions that return Status + // to capture error conditions, and are strongly preferred. + Tensor* mutable_output(int index); + mutex* input_ref_mutex(int index); + void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref); + TensorValue release_output(int index); + + bool track_allocations() const { return params_->track_allocations; } + + // Records temp memory allocation. Tensor object is recorded to identify the + // case where temp memory is used as output memory. 
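A typical way kernels use the Eigen devices exposed above, sketched for a hypothetical element-wise op templated on the device type (Eigen::ThreadPoolDevice on CPU, Eigen::GpuDevice on GPU); the class name is an assumption for the example.

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

template <typename Device>  // Eigen::ThreadPoolDevice or Eigen::GpuDevice
class ScaleByTwoOp : public OpKernel {
 public:
  explicit ScaleByTwoOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    // Evaluate the expression on the kernel's Eigen device: the shared
    // thread pool on CPU, the stream-backed GpuDevice on GPU.
    output->flat<float>().device(ctx->eigen_device<Device>()) =
        input.flat<float>() * 2.0f;
  }
};

}  // namespace tensorflow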
+ void record_temp_memory_allocation(int64_t size, const Tensor& t) + TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); + + // Returns recorded size of temporary memory; + int64_t temp_memory_allocated() const + TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); + + // Records persistent memory allocation, size can be negative indicating + // deallocation. + void record_persistent_memory_allocation(int64_t size, int64_t alloc_id = -1) + TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); + + // Returns recorded size and ids of persistent memory. + int64_t persistent_memory_allocated() const + TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); + + std::vector persistent_alloc_ids() const + TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); + + // Resets counters for temp and persistent memory and recorded ids. + void clear_recorded_memory() TF_LOCKS_EXCLUDED(tracking_state_->stats_mu); + + bool input_is_ref(int index) const; + + void set_record_memory_consumption(bool v); + + // Used by OpKernel implementations to track actively running deferred ops. + // + // A deferred op is one whose Compute method returns (or whose ComputeAsync + // method invokes the callback) when work is scheduled onto a device. At that + // point, we don't know when the work will actually complete (or if it has + // already completed) on the device. These functions allow the executor to + // track the status of deferred ops and act accordingly. + // + // Deferred OpKernel implementations must use these methods to get two + // functions. It then must call these two functions in pairs, before and after + // device execution, respectively. + TF_MUST_USE_RESULT std::function inc_num_deferred_ops_function() { + DCHECK(params_->op_kernel->is_deferred()); + return params_->inc_num_deferred_ops_function + ? params_->inc_num_deferred_ops_function + : []() {}; + } + TF_MUST_USE_RESULT std::function dec_num_deferred_ops_function() { + DCHECK(params_->op_kernel->is_deferred()); + return params_->dec_num_deferred_ops_function + ? params_->dec_num_deferred_ops_function + : []() {}; + } + + Allocator* get_allocator(AllocatorAttributes attr); + + Params* params() const { return params_; } + void set_params(Params* params) { params_ = params; } + + void ResetOutputs(int num_outputs = 0) { + for (TensorValue& value : outputs_) { + DCHECK(!value.is_ref()); + delete value.tensor; + value.tensor = nullptr; + } + outputs_.resize(num_outputs); + } + + private: + bool record_memory_consumption_ = false; + + // Internal common method used when allocating tensor memory + absl::Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, + AllocatorAttributes allocator_attr) { + return allocate_tensor(type, shape, out_tensor, allocator_attr, + AllocationAttributes()); + } + + absl::Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, + AllocatorAttributes allocator_attr, + const AllocationAttributes& allocation_attr); + + // Helpers for `set_output()`. + + // Returns `true` if the tensor was copied into an allocated output. + bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor); + + void maybe_track_allocations_for_set_output(const Tensor& tensor); + + absl::Status get_input_index(StringPiece name, int* out_index) const; + absl::Status get_output_index(StringPiece name, int* out_index) const; + + // Initialize the allocated_scope_ids_ set the first time this method is + // called. 
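A sketch of how a deferred kernel (an AsyncOpKernel subclass that the runtime recognizes as deferred) might pair the two functions described above around work it enqueues on a device. EnqueueOnDevice is a placeholder, not a real API, standing in for whatever mechanism actually schedules the work and invokes the completion callback.

void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
  // Grab both functions up front; they must be invoked as a pair around
  // the deferred device work.
  auto inc = ctx->inc_num_deferred_ops_function();
  auto dec = ctx->dec_num_deferred_ops_function();
  inc();
  // Placeholder: schedules work on the device and runs the callback when
  // that work has completed.
  EnqueueOnDevice(ctx, [dec, done]() {
    dec();
    done();
  });
}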
+ void maybe_initialize_scope_id_set(); + + absl::Status status_; + friend class CollectiveExecutor; // for access to params_ + Params* params_; // not owned + absl::InlinedVector outputs_; + + // Keep track of calls to ScopedAllocator. + // TODO(ayushd): change to absl::flat_hash_set. + std::unique_ptr> allocated_scope_ids_; + + // The following data members are only used when allocation tracking is + // enabled, memory consumption is being recorded, or tensor access is being + // recorded. + struct TrackingState { + mutable mutex mu; + absl::InlinedVector wrapped_allocators + TF_GUARDED_BY(mu); + + mutable mutex stats_mu; + int64_t temp_memory_allocated TF_GUARDED_BY(stats_mu) = 0; + + int64_t persistent_memory_allocated TF_GUARDED_BY(stats_mu) = 0; + absl::InlinedVector, 2UL> + temp_tensor_buffer_and_size TF_GUARDED_BY(stats_mu); + absl::InlinedVector persistent_alloc_ids + TF_GUARDED_BY(stats_mu); + }; + std::unique_ptr tracking_state_; + + // For access to `params_->op_kernel`. + friend void CheckNotInComputeAsync(OpKernelContext* ctx, + const char* correct_macro_name); + + OpKernelContext(const OpKernelContext&) = delete; + void operator=(const OpKernelContext&) = delete; +}; + +template <> +const Eigen::ThreadPoolDevice& OpKernelContext::eigen_device() const; + +template <> +const Eigen::GpuDevice& OpKernelContext::eigen_device() const; + +// Register your OpKernel by specifying the Op's name, the device the +// kernel runs on, any type attr constraints for this kernel, any +// host-memory args, and the class to instantiate. Examples: +// +// // A kernel that supports all types. +// REGISTER_KERNEL_BUILDER(Name("Save").Device(DEVICE_CPU), SaveOp); +// +// // The following are equivalent ways of specifying that the kernel only +// // works if the "T" type attr is set to DT_FLOAT. +// REGISTER_KERNEL_BUILDER( +// Name("Sub").Device(DEVICE_CPU).TypeConstraint("T"), +// SubOp); +// // (You would then repeat this for every type supported by "Sub".) +// +// // This form allows you to specify a list of types as the constraint. +// REGISTER_KERNEL_BUILDER(Name("Sub") +// .Device(DEVICE_CPU) +// .TypeConstraint("T", {DT_FLOAT}), +// SubOp); +// +// // A kernel that expects one of the input tensors in host memory. +// REGISTER_KERNEL_BUILDER( +// Name("Reshape").Device(DEVICE_GPU).HostMemory("shape"), ReshapeOp); +// +// // A kernel that works on any device. Kernels using DEVICE_DEFAULT +// // must aways run on host and all inputs and outputs must use `HostMemory`. +// // Kernels for data management, control-flow primitives or working with +// // tensor shapes for various devices (including `PluggableDevices`) are +// // typical uses. +// REGISTER_KERNEL_BUILDER( +// Name("TensorListLength").Device(DEVICE_DEFAULT).HostMemory("length"), +// TensorListLength); +// +// See kernel_def_builder for details. + +// Instantiate an OpKernel that has been registered. Returns nullptr +// if no operation for that type of device / input signature combination +// (and a NOT_FOUND *status), or there is an error in construction (and +// an INVALID_ARGUMENT *status). Otherwise, the caller takes ownership +// of the returned pointer. +// EXPECTED USAGE: unique_ptr op = CreateOpKernel(...); +// REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). 
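For reference, a complete if minimal registration example tying the pieces above together. It assumes a hypothetical op named "ZeroOut" with a "T" type attr has been registered elsewhere with REGISTER_OP; the registration macro must appear at namespace scope.

#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    output->flat<float>().setZero();
  }
};

// Registers the kernel for the float instantiation of the op's "T" attr
// on CPU.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<float>("T"), ZeroOutOp);

}  // namespace tensorflow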
+std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const NodeDef& node_def, int graph_def_version, absl::Status* status); + +std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const std::shared_ptr& props, int graph_def_version, + absl::Status* status); + +absl::Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel); + +absl::Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel); + +// Returns into 'device_types' the subset of prioritized_types that this +// binary has registered for the given NodeDef. +// +// REQUIRES: * 'device_types' is not nullptr. +// * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). +absl::Status SupportedDeviceTypesForNode( + const std::vector& prioritized_types, const NodeDef& def, + PrioritizedDeviceTypeVector* device_types, + const DeviceNameUtils::ParsedName* local_address_spec = nullptr); + +// Returns a message with a description of the kernels registered for op +// `op_name`. +std::string KernelsRegisteredForOp(StringPiece op_name); + +// Call once after Op registration has completed. +absl::Status ValidateKernelRegistrations( + const OpRegistryInterface& op_registry); + +// ----------------------------------------------------------------------------- +// OpKernel registration implementation follows, please ignore. + +// Allow the REGISTER_KERNEL_BUILDER(Name("op_name").Device(...)...) syntax. +namespace register_kernel { + +class Name : public KernelDefBuilder { + public: + explicit Name(const char* op); +}; + +} // namespace register_kernel + +// Kernel registration appears as: +// REGISTER_KERNEL_BUILDER(Name("OpName").Device(DEVICE_CPU)..., OpImpl) +// We'd like to have "OpName" as a constant-expression, without requiring that +// of the overall KernelDefBuilder expression (beginning with the +// register_kernel::Name constructor above). +// +// So, we pull the "OpName" part to a separate macro-level argument. This +// involves treating Name("OpName") as a macro call, via token-pasting (e.g. +// M_## => M_Name("OpName")), and having it expand to '"OpName", +// Name("OpName")' which is then usable as two arguments. +#define TF_EXTRACT_KERNEL_NAME_Name(name_str) \ + name_str, ::tensorflow::register_kernel::Name(name_str) +#define TF_EXTRACT_KERNEL_NAME_IMPL(m, ...) m(__VA_ARGS__) +#define TF_EXTRACT_KERNEL_NAME(m, kernel_builder, ...) \ + TF_EXTRACT_KERNEL_NAME_IMPL(m, TF_EXTRACT_KERNEL_NAME_##kernel_builder, \ + __VA_ARGS__) + +// REGISTER_KERNEL_BUILDER_IMPL_2, with a unique 'ctr' as the first argument. +// TODO(dodgen): There are some uses of this macro inside functions, where +// kernel_builder refers to (non-const) locals (they should be fixed). To +// accommodate those, kernel_builder.Build() appears as an argument to an +// immediately-called lambda (not in the lambda itself). +#define REGISTER_KERNEL_BUILDER_IMPL_3(ctr, op_name, kernel_builder_expr, \ + is_system_kernel, ...) 
\ + static ::tensorflow::InitOnStartupMarker const register_kernel_##ctr \ + TF_ATTRIBUTE_UNUSED = \ + TF_INIT_ON_STARTUP_IF(is_system_kernel || \ + (SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) && \ + SHOULD_REGISTER_OP(op_name))) \ + << ([](::tensorflow::KernelDef const* kernel_def) { \ + ::tensorflow::kernel_factory::OpKernelRegistrar registrar( \ + kernel_def, #__VA_ARGS__, \ + [](::tensorflow::OpKernelConstruction* context) \ + -> ::tensorflow::OpKernel* { \ + return new __VA_ARGS__(context); \ + }); \ + (void)registrar; \ + LOG_KERNEL_SOURCES(op_name) \ + return ::tensorflow::InitOnStartupMarker{}; \ + })(kernel_builder_expr.Build()); + +// REGISTER_KERNEL_BUILDER_IMPL, but with kernel_builder split to op_name, +// kernel_builder_expr. +#define REGISTER_KERNEL_BUILDER_IMPL_2(op_name, kernel_builder_expr, \ + is_system_kernel, ...) \ + TF_NEW_ID_FOR_INIT(REGISTER_KERNEL_BUILDER_IMPL_3, op_name, \ + kernel_builder_expr, is_system_kernel, __VA_ARGS__) + +// REGISTER_KERNEL_BUILDER, but with is_system_kernel bound. +#define REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, is_system_kernel, ...) \ + TF_EXTRACT_KERNEL_NAME(REGISTER_KERNEL_BUILDER_IMPL_2, kernel_builder, \ + is_system_kernel, __VA_ARGS__) + +#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ + TF_ATTRIBUTE_ANNOTATE("tf:kernel") \ + REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, false, __VA_ARGS__) + +// The `REGISTER_SYSTEM_KERNEL_BUILDER()` macro acts as +// `REGISTER_KERNEL_BUILDER()` except that the kernel is registered +// unconditionally even when selective registration is used. +#define REGISTER_SYSTEM_KERNEL_BUILDER(kernel_builder, ...) \ + TF_ATTRIBUTE_ANNOTATE("tf:kernel") \ + TF_ATTRIBUTE_ANNOTATE("tf:kernel:system") \ + REGISTER_KERNEL_BUILDER_IMPL(kernel_builder, true, __VA_ARGS__) + +// Checks whether a given kernel is registered on device_type. +bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def); + +// If node of node_name, experimental_debug_info, node_op, node_device and +// node_attrs has a corresponding kernel registered on device_type, returns OK +// and fill in the kernel def and kernel_class_name. and +// may be null. +absl::Status FindKernelDef( + const DeviceType& device_type, StringPiece node_name, + bool has_experimental_debug_info, + const NodeDef_ExperimentalDebugInfo& experimental_debug_info, + StringPiece node_op, StringPiece node_device, AttrSlice node_attrs, + const KernelDef** def, std::string* kernel_class_name); + +// If node_def has a corresponding kernel registered on device_type, +// returns OK and fill in the kernel def and kernel_class_name. and +// may be null. +absl::Status FindKernelDef(const DeviceType& device_type, + const NodeDef& node_def, const KernelDef** def, + std::string* kernel_class_name); + +// Writes a list of all registered kernels to LOG(INFO), to help users debug +// missing kernel errors. +void LogAllRegisteredKernels(); + +// Gets a list of all registered kernels. +KernelList GetAllRegisteredKernels(); + +// Gets a list of all registered kernels for which predicate returns true +KernelList GetFilteredRegisteredKernels( + const std::function& predicate); + +// Gets a list of all registered kernels for a given op +KernelList GetRegisteredKernelsForOp(StringPiece op_name); + +namespace kernel_factory { + +// OpKernelFactory is responsible for creating OpKernels when TensorFlow needs +// them. You register factories with the TensorFlow core by constructing an +// OpKernelRegistrar and passing the factory as a constructor parameter. 
+class OpKernelFactory { + public: + virtual OpKernel* Create(OpKernelConstruction* context) = 0; + virtual ~OpKernelFactory() = default; +}; + +class OpKernelRegistrar { + public: + // Registers the given kernel factory with TensorFlow. TF will call the + // factory Create() method when it determines that a kernel matching the given + // KernelDef is required. + OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, + std::unique_ptr factory) + TF_ATTRIBUTE_NOINLINE { + InitInternal(kernel_def, kernel_class_name, std::move(factory)); + } + + // Registers the given factory function with TensorFlow. This is equivalent + // to registering a factory whose Create function invokes `create_fn`. + OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, + OpKernel* (*create_fn)(OpKernelConstruction*)) + TF_ATTRIBUTE_NOINLINE { + InitInternal(kernel_def, kernel_class_name, + std::make_unique(create_fn)); + } + + private: + struct PtrOpKernelFactory : public OpKernelFactory { + explicit PtrOpKernelFactory(OpKernel* (*create_func)(OpKernelConstruction*)) + : create_func_(create_func) {} + + OpKernel* Create(OpKernelConstruction* context) override; + + OpKernel* (*create_func_)(OpKernelConstruction*); + }; + + void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, + std::unique_ptr factory); +}; + +} // namespace kernel_factory + +// ----------------------------------------------------------------------------- +// Template and inline method implementations, please ignore + +template +absl::Status OpKernelConstruction::GetAttr(StringPiece attr_name, + T* value) const { + return GetNodeAttr(def(), attr_name, value); +} + +inline DataType OpKernelContext::input_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + const TensorValue& value(params_->inputs[index]); + return value.dtype(); +} + +inline MemoryType OpKernelContext::input_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + return op_kernel().input_memory_types()[index]; +} + +inline DataType OpKernelContext::expected_output_dtype(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + return params_->op_kernel->output_type(index); +} + +inline MemoryType OpKernelContext::output_memory_type(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + return op_kernel().output_memory_types()[index]; +} + +inline bool OpKernelContext::input_is_ref(int index) const { + const TensorValue& value(params_->inputs[index]); + return value.is_ref(); +} + +// no input if tensor == nullptr. 
+inline bool OpKernelContext::has_input(int index) const { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + return params_->inputs[index].tensor != nullptr; +} + +inline mutex* OpKernelContext::input_ref_mutex(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_inputs()); + DCHECK(input_is_ref(index)); + return params_->inputs[index].mutex_if_ref; +} + +inline Tensor* OpKernelContext::mutable_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + return outputs_[index].tensor; +} + +inline TensorValue OpKernelContext::release_output(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, num_outputs()); + TensorValue value = outputs_[index]; + outputs_[index] = TensorValue(); + return value; +} + +template +T* OpKernelContext::op_device_context() { + static_assert(std::is_base_of::value, + "T is not a subclass of DeviceContext"); + return static_cast(op_device_context()); +} + +inline const Tensor& OpInputList::operator[](int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input(start_ + i); +} + +inline mutex* OpMutableInputList::ref_mutex(int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->input_ref_mutex(start_ + i); +} + +inline Tensor OpMutableInputList::at(int i, bool lock_held) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_input(start_ + i, lock_held); +} + +inline Tensor* OpOutputList::operator[](int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->mutable_output(start_ + i); +} + +inline bool OpOutputList::required(int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->output_required(start_ + i); +} + +inline DataType OpOutputList::expected_output_dtype(int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->expected_output_dtype(start_ + i); +} + +inline absl::Status OpOutputList::allocate(int i, const TensorShape& shape, + Tensor** output) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + return ctx_->allocate_output(start_ + i, shape, output); +} + +inline void OpOutputList::set(int i, const Tensor& tensor) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output(start_ + i, tensor); +} + +inline void OpOutputList::set(int i, Tensor&& tensor) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output(start_ + i, std::move(tensor)); +} + +inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) { + DCHECK_GE(i, 0); + DCHECK_LT(i, stop_ - start_); + ctx_->set_output_ref(i, mu, tensor_for_ref); +} + +// Generate a fatal error if OP_REQUIRES or OP_REQUIRES_OK are used in +// AsyncOpKernel implementations. If these macros are used and the condition +// does not hold, the `done` callback will never be called and the system will +// deadlock, so a crash failure is preferable. Since the OP_REQUIRES[_OK] macros +// are legal to use in AsyncOpKernel constructors, we use overload resolution +// to distinguish between OpKernelConstruction* and OpKernelContext* context +// types. 
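A sketch of how a kernel might consume a list-valued input and produce a list-valued output through the OpInputList/OpOutputList accessors implemented above; the argument names "values" and "doubled" are placeholders for names defined by the op's OpDef.

void Compute(OpKernelContext* ctx) override {
  OpInputList values;
  OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
  OpOutputList doubled;
  OP_REQUIRES_OK(ctx, ctx->output_list("doubled", &doubled));
  for (int i = 0; i < values.size(); ++i) {
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, doubled.allocate(i, values[i].shape(), &out));
    out->flat<float>() = values[i].flat<float>() * 2.0f;
  }
}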
+class XlaOpKernelContext; +inline void CheckNotInComputeAsync(XlaOpKernelContext*, const char*) {} +inline void CheckNotInComputeAsync(OpKernelConstruction*, const char*) {} +void CheckNotInComputeAsync(OpKernelContext* ctx, + const char* correct_macro_name); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op_kernel_test_base.h b/third_party/tflite-hdrs/tensorflow/core/framework/op_kernel_test_base.h new file mode 100644 index 00000000..7b3951e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op_kernel_test_base.h @@ -0,0 +1,177 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_TEST_BASE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_TEST_BASE_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +static std::vector DeviceTypes() { + return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}; +} + +class OpKernelBuilderTest : public ::testing::Test { + protected: + // Each attr is described by a "name|type|value". 
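For example, a test built on the fixture declared below might describe attrs with such "name|type|value" strings and use the ExpectSuccess helper defined further down; the op name "MyOp" is a placeholder for an op registered in the test binary.

TEST_F(OpKernelBuilderTest, BuildsFloatKernel) {
  // "T|type|DT_FLOAT" sets the type attr T; "N|int|3" would set an int attr.
  ExpectSuccess("MyOp", DeviceType(DEVICE_CPU), {"T|type|DT_FLOAT"});
}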
+ NodeDef CreateNodeDef(const string& op_type, + const std::vector& attrs) { + NodeDef node_def; + node_def.set_name(op_type + "-op"); + node_def.set_op(op_type); + for (const string& attr_desc : attrs) { + std::vector parts = str_util::Split(attr_desc, '|'); + CHECK_EQ(parts.size(), 3); + AttrValue attr_value; + CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc; + node_def.mutable_attr()->insert( + AttrValueMap::value_type(parts[0], attr_value)); + } + return node_def; + } + + std::unique_ptr ExpectSuccess(const string& op_type, + const DeviceType& device_type, + const std::vector& attrs, + DataTypeSlice input_types = {}) { + absl::Status status; + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel() + std::unique_ptr op(CreateOpKernel(device_type, &device, + cpu_allocator(), def, + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + EXPECT_EQ(input_types.size(), op->num_inputs()); + EXPECT_EQ(0, op->num_outputs()); + } + + // Test SupportedDeviceTypesForNode() + PrioritizedDeviceTypeVector devices; + TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + bool found = false; + for (const auto& dt : devices) { + if (dt.first == device_type) { + found = true; + } + } + EXPECT_TRUE(found) << "Missing " << device_type << " from " + << devices.size() << " devices."; + + // In case the caller wants to use the OpKernel + return op; + } + + void ExpectFailure(const string& op_type, const DeviceType& device_type, + const std::vector& attrs, error::Code code) { + absl::Status status; + const NodeDef def = CreateNodeDef(op_type, attrs); + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel(). + std::unique_ptr op(CreateOpKernel(device_type, &device, + cpu_allocator(), def, + TF_GRAPH_DEF_VERSION, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.message(); + EXPECT_EQ(code, status.code()); + + // Test SupportedDeviceTypesForNode(). 
+ PrioritizedDeviceTypeVector devices; + if (absl::IsNotFound(status)) { + TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + for (const auto& dt : devices) { + EXPECT_NE(dt.first, device_type); + } + } else { + absl::Status status2 = + SupportedDeviceTypesForNode(DeviceTypes(), def, &devices); + EXPECT_EQ(status.code(), status2.code()); + } + } + } + + string GetKernelClassName(const string& op_type, + const DeviceType& device_type, + const std::vector& attrs, + DataTypeSlice input_types = {}) { + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + const KernelDef* kernel_def = nullptr; + string kernel_class_name; + const absl::Status status = + FindKernelDef(device_type, def, &kernel_def, &kernel_class_name); + if (status.ok()) { + return kernel_class_name; + } else if (absl::IsNotFound(status)) { + return "not found"; + } else { + return status.ToString(); + } + } +}; + +class BaseKernel : public ::tensorflow::OpKernel { + public: + explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(::tensorflow::OpKernelContext* context) override {} + virtual int Which() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_TEST_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op_requires.h b/third_party/tflite-hdrs/tensorflow/core/framework/op_requires.h new file mode 100644 index 00000000..d9a7e35c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op_requires.h @@ -0,0 +1,159 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_REQUIRES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_REQUIRES_H_ + +#include + +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// Convenience macros for asserting and handling exceptional conditions. +// Analogous to the CHECK* macros provided by logging.h. +// +// Example use: +// void Compute(OperationContext* context) { +// OP_REQUIRES(context, context->num_inputs() == 2, +// errors::InvalidArgument("FooOp requires 2 arguments")); +// ... +// absl::Status status = SomeUncertainMethod(); +// OP_REQUIRES_OK(context, status); +// +// // Or in one go: +// OP_REQUIRES_OK(context, SomeUncertainMethod()); +// ... +// } +// +// The *_ASYNC versions take a CALLBACK macro argument which is called just +// before the return in the failure case; the expression in the macro itself +// is evaluated only in the failure case, and can therefore be expensive or +// have side effects that must not occur in the successful case. For example: +// +// auto done = MakeCleanup([&]() { /* necessary continuation */ }); +// OP_REQUIRES_OK_ASYNC(context, SomeUncertainMethod(), done.release()); +// // `done` is still engaged if and only if control reaches here. 
+// +// These macros depend on CheckNotInComputeAsync and on absl::Status, both +// of which must be defined before invoking the macros. We specifically don't +// include op_kernel.h or the Abseil headers from this header to reduce this +// header's dependencies. These macros may be used with alternative +// implementations of OpKernelContext with fewer dependencies. + +#define OP_REQUIRES(CTX, EXP, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) { \ + CheckNotInComputeAsync((CTX), "OP_REQUIRES_ASYNC"); \ + (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ + return; \ + } \ + } while (0) + +// The macro arguements passed to the ellipsis must combine to a single +// expression that is convertible to absl::Status. We accept a variable +// number of macro arguments only so as to support interior commas. +#define OP_REQUIRES_OK(CTX, ...) \ + do { \ + if (!TF_PREDICT_TRUE( \ + ::tensorflow::op_requires_internal::OkImpl<::absl::Status>( \ + (CTX), __FILE__, __LINE__, \ + static_cast(__VA_ARGS__)))) { \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_OK_OR_SET_PAYLOAD(CTX, PAYLOAD_KEY, PAYLOAD_VALUE, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(STATUS.ok())) { \ + CheckNotInComputeAsync((CTX), "OP_REQUIRES_OK_ASYNC"); \ + if (!PAYLOAD_VALUE.empty()) { \ + STATUS.SetPayload(PAYLOAD_KEY, absl::Cord(PAYLOAD_VALUE)); \ + } \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, STATUS); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_ASYNC(CTX, EXP, STATUS, CALLBACK) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) { \ + (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \ + (CALLBACK)(); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_OK_ASYNC(CTX, STATUS, CALLBACK) \ + do { \ + if (!TF_PREDICT_TRUE( \ + ::tensorflow::op_requires_internal::OkAsyncImpl<::absl::Status>( \ + (CTX), __FILE__, __LINE__, (STATUS)))) { \ + (CALLBACK)(); \ + return; \ + } \ + } while (0) + +#define OP_REQUIRES_VALUE(lhs, ctx, rexpr) \ + OP_REQUIRES_VALUE_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, ctx, \ + rexpr) + +#define OP_REQUIRES_VALUE_IMPL(statusor, lhs, ctx, rexpr) \ + auto statusor = (rexpr); \ + OP_REQUIRES_OK(ctx, statusor.status()); \ + lhs = std::move(statusor.value()) + +// The "Impl" functions are implementation details for the above macros. They +// accept values constructed by the macros, and the values are guaranteed to +// be alive for the duration of the function call. Passing the macro arguments +// through a function call is important to support macro arguments that expand +// to short-lived values (which could not be bound to a reference directly). +// +// We use a template parameter S instead of the concrete type absl::Status +// so as to not require the inclusion of the Abseil header in this file. +// The header must be included before the macros are used. + +namespace op_requires_internal { + +// ctx is usually a plain pointer, but could be a smart pointer, so we accept it +// by const ref. +template +bool OkImpl(const Ctx& ctx, const char* file, int line, const S& s) { + if (!TF_PREDICT_TRUE(s.ok())) { + CheckNotInComputeAsync(ctx, "OP_REQUIRES_OK_ASYNC"); + ctx->CtxFailureWithWarning(file, line, s); + return false; + } else { + return true; + } +} + +// ctx is usually a plain pointer, but could be a smart pointer, so we accept it +// by const ref. 
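A small usage sketch combining the macros above in a synchronous Compute(); it assumes OpKernelContext::get_input() yields absl::StatusOr<const Tensor*>, as its documentation in op_kernel.h suggests.

void Compute(OpKernelContext* ctx) override {
  OP_REQUIRES(ctx, ctx->num_inputs() == 2,
              errors::InvalidArgument("expected exactly 2 inputs"));

  // OP_REQUIRES_VALUE unwraps a StatusOr, returning early on error.
  OP_REQUIRES_VALUE(const Tensor* a, ctx, ctx->get_input(0));

  Tensor* out = nullptr;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, a->shape(), &out));
  out->flat<float>() = a->flat<float>();
}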
+template +bool OkAsyncImpl(const Ctx& ctx, const char* file, int line, const S& s) { + if (!TF_PREDICT_TRUE(s.ok())) { + ctx->CtxFailureWithWarning(file, line, s); + return false; + } else { + return true; + } +} + +} // namespace op_requires_internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_REQUIRES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/op_segment.h b/third_party/tflite-hdrs/tensorflow/core/framework/op_segment.h new file mode 100644 index 00000000..10c4fa46 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/op_segment.h @@ -0,0 +1,90 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_SEGMENT_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OP_SEGMENT_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// OpSegment keeps track of OpKernels registered for sessions running +// on a device. +// +// The implementation maintains a two-level map. The 1st level maps +// session handle to the map of registered OpKernels. The 2nd level +// map maps node names to instantiated OpKernel objects. +// +// Each 2-nd level map is reference-counted and the caller can call +// AddHold to obtain a reference on all kernels of a session and +// ensure these kernels are alive until a corresponding RemoveHold is +// called on the same session. +class OpSegment { + public: + OpSegment(); + ~OpSegment(); + + // A hold can be placed on a session, preventing all its kernels + // from being deleted. + void AddHold(const std::string& session_handle); + void RemoveHold(const std::string& session_handle); + + // If the kernel for "node_name" has been created in the + // "session_handle", returns the existing op kernel in "*kernel". + // Otherwise, creates the kernel by calling create_fn(), cache it, + // and returns it in "*kernel". If create_fn() fails, returns the + // error. + // + // OpSegment keeps the ownership of the returned "*kernel". + typedef std::function CreateKernelFn; + absl::Status FindOrCreate(const std::string& session_handle, + const std::string& node_name, OpKernel** kernel, + CreateKernelFn create_fn); + + // Returns true if OpSegment should own the kernel. + static bool ShouldOwnKernel(FunctionLibraryRuntime* lib, + const std::string& node_op); + + private: + // op name -> OpKernel + typedef std::unordered_map KernelMap; + struct Item { + int num_holds = 1; // Num of holds put on the session. + KernelMap name_kernel; // op name -> kernel. + ~Item(); + }; + + // session handle -> item. 
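A sketch of the intended FindOrCreate call pattern for the OpSegment declared above. The session_handle, node_name, device_type, device, allocator, flib, props and graph_def_version variables are assumed to be in scope (they are placeholders here), and the create callback is assumed to take an OpKernel** and return absl::Status, consistent with the comment on CreateKernelFn.

OpSegment segment;
segment.AddHold(session_handle);  // keep this session's kernels alive

OpKernel* kernel = nullptr;
absl::Status s = segment.FindOrCreate(
    session_handle, node_name, &kernel,
    [&](OpKernel** out) {
      // Invoked only when no kernel is cached for node_name in this session.
      return CreateOpKernel(device_type, device, allocator, flib, props,
                            graph_def_version, out);
    });
// ... *kernel remains owned by the OpSegment ...

segment.RemoveHold(session_handle);  // cached kernels may now be deleted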
+ // Session handles are produced by strings::FpToString() + typedef std::unordered_map SessionMap; + + mutable mutex mu_; + SessionMap sessions_ TF_GUARDED_BY(mu_); + + OpSegment(const OpSegment&) = delete; + void operator=(const OpSegment&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OP_SEGMENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/ops_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/ops_util.h new file mode 100644 index 00000000..ae73a562 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/ops_util.h @@ -0,0 +1,116 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_OPS_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_OPS_UTIL_H_ + +// This file contains utilities for various operations. + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { + +// Calculates broadcast starting index and size. For SAME padding, addition +// padding could be applied to right, left, top and bottom. Depending on the +// current index, input size, kernel size, stride, padding size, the starting +// index and size for broadcast for that dimension are different from the +// current index and kernel size. +// This is mainly used by gradient algorithms for pooling operations. +absl::Status GetBroadcastSize(const int index, const int in_size, + const int ksize, const int stride, + const int pad_size, int* bindex, int* bsize); + +// Converts Brain's Padding to Eigen's PaddingType. +Eigen::PaddingType BrainPadding2EigenPadding(Padding padding); + +// Given a shape 's' of a tensor of type T. Returns true iff the +// number of bytes occupied by each dim 0 (i.e., &tensor(i + 1, ...) - +// &tensor(i, ...)) is multiple of EIGEN_MAX_ALIGN_BYTES. +template +bool IsInnerDimsSizeAligned(const TensorShape& s) { + if (s.dims() == 0) return false; + const int64_t dim0_size = s.dim_size(0); + if (dim0_size == 0) return false; +#if EIGEN_MAX_ALIGN_BYTES == 0 + return true; +#else + const int64_t bytes_per_dim0 = (s.num_elements() / dim0_size) * sizeof(T); + return bytes_per_dim0 % EIGEN_MAX_ALIGN_BYTES == 0; +#endif +} + +// Given a shape 's' of a tensor of type T and the `start` and `end` index of a +// dim 0 slice, returns true iff slice is aligned with respect to original +// tensor. Here aligned implies the address is a multiple of +// EIGEN_MAX_ALIGN_BYTES. 
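A small worked example of the IsInnerDimsSizeAligned check above, assuming a build where EIGEN_MAX_ALIGN_BYTES is 64 (IsDim0SliceAligned, whose comment appears just above, reuses this check for rank > 1):

TensorShape aligned({8, 16});   // 16 floats * 4 bytes = 64 bytes per dim-0 slice
TensorShape unaligned({8, 3});  //  3 floats * 4 bytes = 12 bytes per dim-0 slice
bool ok1 = IsInnerDimsSizeAligned<float>(aligned);    // true  (64 % 64 == 0)
bool ok2 = IsInnerDimsSizeAligned<float>(unaligned);  // false (12 % 64 != 0)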
+template +bool IsDim0SliceAligned(const TensorShape& s, int64_t start, + int64_t end_or_size) { + if (s.dims() == 1) { +#if EIGEN_MAX_ALIGN_BYTES == 0 + return true; +#else + bool start_aligned = (start * sizeof(T)) % EIGEN_MAX_ALIGN_BYTES == 0; + // End is aligned if either the explicit end index is passed and is a + // a multiple of EIGEN_MAX_ALIGN_BYTES, or the start index is aligned and + // the size is aligned. So for convenience we can either pass start and + // index, or start and size. + bool end_aligned = (end_or_size * sizeof(T)) % EIGEN_MAX_ALIGN_BYTES == 0; + return start_aligned && end_aligned; +#endif + } else { + return IsInnerDimsSizeAligned(s); + } +} + +// Returns sanitized to have only [a-zA-Z0-9-_]. +std::string SanitizeThreadSuffix(std::string suffix); + +// Helper to compute 'strides' given a tensor 'shape'. I.e., +// strides[i] = prod(shape.dim_size[(i+1):]) +template +gtl::InlinedVector ComputeStride(const TensorShape& shape) { + const int ndims = shape.dims(); + gtl::InlinedVector strides(ndims); + T stride = 1; + for (int i = ndims - 1; i >= 0; --i) { + strides[i] = stride; + stride *= static_cast(shape.dim_size(i)); + } + return strides; +} + +// Helper to compute 'strides' given an Eigen TensorDimensions +template +gtl::InlinedVector ComputeEigenStrides(const EigenDimensions& shape) { + const int ndims = shape.rank(); + gtl::InlinedVector strides(ndims); + T stride = 1; + for (int i = ndims - 1; i >= 0; --i) { + strides[i] = stride; + stride *= static_cast(shape[i]); + } + return strides; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/partial_tensor_shape.h b/third_party/tflite-hdrs/tensorflow/core/framework/partial_tensor_shape.h new file mode 100644 index 00000000..fa1ce07d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/partial_tensor_shape.h @@ -0,0 +1,22 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_ + +// TODO(irving): Remove this forwarding header +#include "tensorflow/core/framework/tensor_shape.h" + +#endif // TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/queue_interface.h b/third_party/tflite-hdrs/tensorflow/core/framework/queue_interface.h new file mode 100644 index 00000000..e916b506 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/queue_interface.h @@ -0,0 +1,102 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// All implementations must be thread-safe. +class QueueInterface : public ResourceBase { + public: + typedef std::vector Tuple; + typedef AsyncOpKernel::DoneCallback DoneCallback; + typedef std::function CallbackWithTuple; + + virtual absl::Status ValidateTuple(const Tuple& tuple) = 0; + virtual absl::Status ValidateManyTuple(const Tuple& tuple) = 0; + + // Stashes a function object for future execution, that will eventually + // enqueue the tuple of tensors into the queue, and returns immediately. The + // function object is guaranteed to call 'callback'. + virtual void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Same as above, but the component tensors are sliced along the 0th dimension + // to make multiple queue-element components. + virtual void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) = 0; + + // Stashes a function object for future execution, that will eventually + // dequeue an element from the queue and call 'callback' with that tuple + // element as argument. + virtual void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) = 0; + + // Same as above, but the stashed function object will attempt to dequeue + // num_elements items. If allow_small_batch is true, and the Queue is + // closed but at least 1 element is available, there is no blocking + // and between 1 and num_elements items are immediately returned. + // If the queue does not support the allow_small_batch flag will + // return an Unimplemented error. + virtual void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, + CallbackWithTuple callback) = 0; + + // Signals that no more elements will be enqueued, and optionally + // cancels pending Enqueue(Many) operations. + // + // After calling this function, subsequent calls to Enqueue(Many) + // will fail. If `cancel_pending_enqueues` is true, all pending + // calls to Enqueue(Many) will fail as well. + // + // After calling this function, all current and subsequent calls to + // Dequeue(Many) will fail instead of blocking (though they may + // succeed if they can be satisfied by the elements in the queue at + // the time it was closed). + virtual void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) = 0; + + // Returns true if a given queue is closed and false if it is open. + virtual bool is_closed() const = 0; + + // Assuming *this represents a shared queue, verify that it matches + // another instantiation indicated by node_def. 
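To illustrate the callback style of this interface, a helper one might call from an AsyncOpKernel's ComputeAsync is sketched below; the output wiring is illustrative and assumes the queue's element components map one-to-one onto the kernel's outputs:

// Forwards one dequeued element to the kernel's outputs, then signals done.
void DequeueOneAndForward(QueueInterface* queue, OpKernelContext* ctx,
                          AsyncOpKernel::DoneCallback done) {
  queue->TryDequeue(ctx, [ctx, done](const QueueInterface::Tuple& tuple) {
    if (!ctx->status().ok()) {  // e.g. queue closed and empty, or cancelled
      done();
      return;
    }
    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(static_cast<int>(i), tuple[i]);
    }
    done();
  });
}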
+ virtual absl::Status MatchesNodeDef(const NodeDef& node_def) = 0; + + // Returns the number of elements in the queue. + virtual int32 size() const = 0; + + virtual const DataTypeVector& component_dtypes() const = 0; + + string DebugString() const override { + return strings::StrCat("A Queue of size: ", size()); + } + + protected: + ~QueueInterface() override {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/reader_base.h b/third_party/tflite-hdrs/tensorflow/core/framework/reader_base.h new file mode 100644 index 00000000..73842644 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/reader_base.h @@ -0,0 +1,139 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_ + +#include +#include +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/reader_interface.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +class ReaderBaseState; + +// Default implementation of ReaderInterface. +class ReaderBase : public ReaderInterface { + public: + // name: For use in error messages, should mention both the name of + // the op and the node. + explicit ReaderBase(const string& name); + + // Note that methods with names ending in "Locked" are called while + // the ReaderBase's mutex is held. + + // Implement this function in descendants ----------------------------------- + + // Produce the next key/value pair from the current work item. + // This is called "Locked" since it is executed under a mutex + // that serializes all Reader calls. + // Usage: + // a) If a record was successfully produced, set *produced = true, + // and fill in *key and *value. + // b) If no more records will be produced for this work item, set + // *at_end = true. + // c) If a record was produced, but no more will be produced, you + // may either do both (a) and (b), or do (a) in this call and do (b) in + // the next call to ReadLocked(). + // d) If there was an error producing (e.g. an error reading the file, + // data corruption), return a non-OK() status. ReadLocked may be + // called again if the user reruns this part of the graph. + virtual absl::Status ReadLocked(tstring* key, tstring* value, bool* produced, + bool* at_end) = 0; + + // Descendants may optionally implement these ------------------------------- + + // Produce up to num_records next key/value pairs from the current + // work item, in the same manner of ReadLocked. + virtual absl::Status ReadUpToLocked(int64_t num_records, + std::vector* keys, + std::vector* values, + int64_t* num_read, bool* at_end); + + // Called when work starts / finishes. 
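To make the (a)/(b) contract of ReadLocked concrete, here is a hypothetical reader that emits exactly one record per work item; only the required override and one optional hook (which the declarations right after this sketch correspond to) are shown:

class OneShotReader : public ReaderBase {
 public:
  explicit OneShotReader(const string& node_name)
      : ReaderBase(strings::StrCat("OneShotReader '", node_name, "'")) {}

  // Case (a) on the first call for a work item, case (b) afterwards.
  absl::Status ReadLocked(tstring* key, tstring* value, bool* produced,
                          bool* at_end) override {
    if (emitted_) {
      *at_end = true;                  // (b): no more records for this work item
      return absl::OkStatus();
    }
    *key = KeyName(current_work());    // (a): produce one key/value record
    *value = current_work();
    *produced = true;
    emitted_ = true;
    return absl::OkStatus();
  }

  absl::Status OnWorkStartedLocked() override {
    emitted_ = false;                  // new work item -> allow one record again
    return absl::OkStatus();
  }

 private:
  bool emitted_ = false;
};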
+ virtual absl::Status OnWorkStartedLocked() { return absl::OkStatus(); } + virtual absl::Status OnWorkFinishedLocked() { return absl::OkStatus(); } + + // Called to reset the Reader to a newly constructed state. + virtual absl::Status ResetLocked(); + + // Default implementation generates an Unimplemented error. + // See the protected helper methods below. + virtual absl::Status SerializeStateLocked(tstring* state); + virtual absl::Status RestoreStateLocked(const tstring& state); + + // Accessors ---------------------------------------------------------------- + + // Always true during a call to ReadLocked(). + bool work_in_progress() const { return work_finished_ < work_started_; } + + // Returns the name of the current work item (valid if + // work_in_progress() returns true). May change between calls to + // ReadLocked(). + const tstring& current_work() const { return work_; } + + // What was passed to the constructor. + const string& name() const { return name_; } + + // Produce the key name (from current_work and the actual key). + tstring KeyName(const tstring& key) const; + + protected: + // For descendants wishing to implement serialize & restore state. + + // Writes ReaderBase state to *state. + void SaveBaseState(ReaderBaseState* state) const; + + // Restores ReaderBase state from state. Assumes state was filled + // using SaveBaseState() above. + absl::Status RestoreBaseState(const ReaderBaseState& state); + + private: + // For descendants that wish to obtain the next work item in a different way. + // For implementing Read(). Dequeues the next work item from + // *queue, and if successful returns "work" (a string). May block. + virtual string GetNextWorkLocked(QueueInterface* queue, + OpKernelContext* context) const; + + // Implementations of ReaderInterface methods. These ensure thread-safety + // and call the methods above to do the work. + void Read(QueueInterface* queue, tstring* key, tstring* value, + OpKernelContext* context) override; + + // Produces up to num_records. + // In this implementation all the records come from the same work unit. + int64_t ReadUpTo(const int64_t num_records, QueueInterface* queue, + std::vector* keys, std::vector* value, + OpKernelContext* context) override; + + absl::Status Reset() override; + int64_t NumRecordsProduced() override; + int64_t NumWorkUnitsCompleted() override; + absl::Status SerializeState(tstring* state) override; + absl::Status RestoreState(const tstring& state) override; + + mutable mutex mu_; + const string name_; + int64_t work_started_ = 0; + int64_t work_finished_ = 0; + int64_t num_records_produced_ = 0; + tstring work_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/reader_interface.h b/third_party/tflite-hdrs/tensorflow/core/framework/reader_interface.h new file mode 100644 index 00000000..6210b68f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/reader_interface.h @@ -0,0 +1,88 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_ + +#include +#include +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class QueueInterface; +class ReaderInterface; + +// Readers are the mechanism for reading records from files in +// TensorFlow graphs. Each supported file format has a corresponding +// ReaderInterface descendant and a corresponding Op & OpKernel +// (implemented using ReaderOpKernel from reader_op_kernel.h). +// +// To use a Reader, you first encode "work" (some string, typically a +// filename) in the Reader's "work queue". It then processes the +// "work" (reading records from the file), to produce key/value +// strings. The methods of this class are called by ReaderFoo ops, +// so see ../ops/io_ops.cc for detailed descriptions. +// +// All descendants of this class must be thread-safe. +class ReaderInterface : public ResourceBase { + public: + // Read a single record into *key / *value. May get more work from + // *queue if the current work is complete. Sets the status on + // *context with an OutOfRange Status if the current work is + // complete and the queue is done (closed and empty). + // This method may block. + virtual void Read(QueueInterface* queue, tstring* key, tstring* value, + OpKernelContext* context) = 0; + + // Read up to num_records records into keys / values. May get more work from + // *queue if the current work is complete. Sets the status on + // *context with an OutOfRange Status if the current work is + // complete and the queue is done (closed and empty). + // This method may block. + // The std::vector keys/value pointers are assumed to point to empty + // structures (that have most likely been reserve(num_records)). + // Returns how many records were actually read. + virtual int64_t ReadUpTo(const int64_t num_records, QueueInterface* queue, + std::vector* keys, + std::vector* value, + OpKernelContext* context) = 0; + + // Restore this reader to its newly-constructed state. + virtual absl::Status Reset() = 0; + + // Accessors + virtual int64_t NumRecordsProduced() = 0; + virtual int64_t NumWorkUnitsCompleted() = 0; + + // -- Serialization/Restoration support -- + // Not all readers will support saving and restoring state. + virtual absl::Status SerializeState(tstring* state) = 0; + // Note: Must Reset on error. + virtual absl::Status RestoreState(const tstring& state) = 0; + + string DebugString() const override { return "a reader"; } + + protected: + ~ReaderInterface() override {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/reader_op_kernel.h b/third_party/tflite-hdrs/tensorflow/core/framework/reader_op_kernel.h new file mode 100644 index 00000000..bc1a7629 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/reader_op_kernel.h @@ -0,0 +1,87 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/reader_interface.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// NOTE: This is now a very thin layer over ResourceOpKernel. +// TODO(sjhwang): Remove dependencies to this class, then delete this. + +// Implementation for ops providing a Reader. +class ReaderOpKernel : public ResourceOpKernel { + public: + using ResourceOpKernel::ResourceOpKernel; + + // Must be called by descendants before the first call to Compute() (typically + // called during construction). factory must return a ReaderInterface + // descendant allocated with new that ReaderOpKernel will take ownership of. + void SetReaderFactory(std::function factory) + TF_LOCKS_EXCLUDED(mu_) { + DCHECK(get_resource() == nullptr); + mutex_lock l(mu_); + factory_ = factory; + } + + void Compute(OpKernelContext* context) override { + if (!IsCancellable()) { + ResourceOpKernel::Compute(context); + } else { + // Install cancellation + CancellationManager* cm = context->cancellation_manager(); + CancellationToken token = cm->get_cancellation_token(); + bool already_cancelled = + !cm->RegisterCallback(token, [this]() { this->Cancel(); }); + + if (!already_cancelled) { + ResourceOpKernel::Compute(context); + } else { + context->SetStatus(errors::Cancelled("read operation was cancelled")); + } + } + } + + private: + virtual bool IsCancellable() const { return false; } + virtual void Cancel() {} + + absl::Status CreateResource(ReaderInterface** reader) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override { + *reader = factory_(); + if (*reader == nullptr) { + return errors::ResourceExhausted("Failed to allocate reader"); + } + std::function temp = nullptr; + factory_.swap(temp); + return absl::OkStatus(); + } + + std::function factory_ TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/ref_var.h b/third_party/tflite-hdrs/tensorflow/core/framework/ref_var.h new file mode 100644 index 00000000..8e423e81 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/ref_var.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
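Connecting ReaderOpKernel (above) to the hypothetical OneShotReader sketched earlier: a reader op would install its factory from the constructor, before any call to Compute(); the op class name is illustrative and ReaderOpKernel takes ownership of the returned reader:

class OneShotReaderOp : public ReaderOpKernel {
 public:
  explicit OneShotReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    // OneShotReader is the hypothetical ReaderBase subclass sketched above.
    SetReaderFactory([this]() { return new OneShotReader(name()); });
  }
};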
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_REF_VAR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_REF_VAR_H_ + +#include + +namespace tensorflow { +class OpKernelContext; + +void AssignRefVariable( + OpKernelContext* context, int input_ref_index, int output_ref_index, + int value_index, bool use_locking, bool validate_shape, + bool relax_constraints, + std::function copy); +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_REF_VAR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/register_types.h b/third_party/tflite-hdrs/tensorflow/core/framework/register_types.h new file mode 100644 index 00000000..eba2ae88 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/register_types.h @@ -0,0 +1,233 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_ +// This file is used by cuda code and must remain compilable by nvcc. + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/platform/types.h" + +// Two sets of macros: +// - TF_CALL_float, TF_CALL_double, etc. which call the given macro with +// the type name as the only parameter - except on platforms for which +// the type should not be included. +// - Macros to apply another macro to lists of supported types. These also call +// into TF_CALL_float, TF_CALL_double, etc. so they filter by target platform +// as well. +// If you change the lists of types, please also update the list in types.cc. +// +// See example uses of these macros in core/ops. +// +// +// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple +// times by passing each invocation a data type supported by TensorFlow. +// +// The different variations pass different subsets of the types. +// TF_CALL_ALL_TYPES(m) applied "m" to all types supported by TensorFlow. +// The set of types depends on the compilation platform. +//. +// This can be used to register a different template instantiation of +// an OpKernel for different signatures, e.g.: +/* + #define REGISTER_PARTITION(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Partition").Device(DEVICE_CPU).TypeConstraint("T"), \ + PartitionOp); + TF_CALL_ALL_TYPES(REGISTER_PARTITION) + #undef REGISTER_PARTITION +*/ + +#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) || \ + defined(ANDROID_TEGRA) + +// All types are supported, so all macros are invoked. +// +// Note: macros are defined in same order as types in types.proto, for +// readability. 
+#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) m(double) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint32(m) m(::tensorflow::uint32) +#define TF_CALL_uint8(m) m(::tensorflow::uint8) +#define TF_CALL_int16(m) m(::tensorflow::int16) + +#define TF_CALL_int8(m) m(::tensorflow::int8) +#define TF_CALL_string(m) m(::tensorflow::tstring) +#define TF_CALL_tstring(m) m(::tensorflow::tstring) +#define TF_CALL_resource(m) m(::tensorflow::ResourceHandle) +#define TF_CALL_variant(m) m(::tensorflow::Variant) +#define TF_CALL_complex64(m) m(::tensorflow::complex64) +#define TF_CALL_int64(m) m(::int64_t) +#define TF_CALL_uint64(m) m(::tensorflow::uint64) +#define TF_CALL_bool(m) m(bool) + +#define TF_CALL_qint8(m) m(::tensorflow::qint8) +#define TF_CALL_quint8(m) m(::tensorflow::quint8) +#define TF_CALL_qint32(m) m(::tensorflow::qint32) +#define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16) +#define TF_CALL_qint16(m) m(::tensorflow::qint16) + +#define TF_CALL_quint16(m) m(::tensorflow::quint16) +#define TF_CALL_uint16(m) m(::tensorflow::uint16) +#define TF_CALL_complex128(m) m(::tensorflow::complex128) +#define TF_CALL_half(m) m(Eigen::half) + +#define TF_CALL_float8_e5m2(m) m(::tensorflow::float8_e5m2) +#define TF_CALL_float8_e4m3fn(m) m(::tensorflow::float8_e4m3fn) + +#define TF_CALL_int4(m) m(::tensorflow::int4) +#define TF_CALL_uint4(m) m(::tensorflow::uint4) + +#elif defined(__ANDROID_TYPES_FULL__) + +// Only string, half, float, int32, int64, bool, and quantized types +// supported. +#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint32(m) +#define TF_CALL_uint8(m) +#define TF_CALL_int16(m) + +#define TF_CALL_int8(m) +#define TF_CALL_string(m) m(::tensorflow::tstring) +#define TF_CALL_tstring(m) m(::tensorflow::tstring) +#define TF_CALL_resource(m) +#define TF_CALL_variant(m) +#define TF_CALL_complex64(m) +#define TF_CALL_int64(m) m(::int64_t) +#define TF_CALL_uint64(m) +#define TF_CALL_bool(m) m(bool) + +#define TF_CALL_qint8(m) m(::tensorflow::qint8) +#define TF_CALL_quint8(m) m(::tensorflow::quint8) +#define TF_CALL_qint32(m) m(::tensorflow::qint32) +#define TF_CALL_bfloat16(m) +#define TF_CALL_qint16(m) m(::tensorflow::qint16) + +#define TF_CALL_quint16(m) m(::tensorflow::quint16) +#define TF_CALL_uint16(m) +#define TF_CALL_complex128(m) +#define TF_CALL_half(m) m(Eigen::half) + +#define TF_CALL_float8_e5m2(m) +#define TF_CALL_float8_e4m3fn(m) + +#define TF_CALL_int4(m) +#define TF_CALL_uint4(m) + +#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) + +// Only float, int32, and bool are supported. 
+#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint32(m) +#define TF_CALL_uint8(m) +#define TF_CALL_int16(m) + +#define TF_CALL_int8(m) +#define TF_CALL_string(m) +#define TF_CALL_tstring(m) +#define TF_CALL_resource(m) +#define TF_CALL_variant(m) +#define TF_CALL_complex64(m) +#define TF_CALL_int64(m) +#define TF_CALL_uint64(m) +#define TF_CALL_bool(m) m(bool) + +#define TF_CALL_qint8(m) +#define TF_CALL_quint8(m) +#define TF_CALL_qint32(m) +#define TF_CALL_bfloat16(m) +#define TF_CALL_qint16(m) + +#define TF_CALL_quint16(m) +#define TF_CALL_uint16(m) +#define TF_CALL_complex128(m) +#define TF_CALL_half(m) + +#define TF_CALL_float8_e5m2(m) +#define TF_CALL_float8_e4m3fn(m) + +#define TF_CALL_int4(m) +#define TF_CALL_uint4(m) + +#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines + +// Defines for sets of types. +#define TF_CALL_INTEGRAL_TYPES_NO_INT32(m) \ + TF_CALL_uint64(m) TF_CALL_int64(m) TF_CALL_uint32(m) TF_CALL_uint16(m) \ + TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m) + +#define TF_CALL_INTEGRAL_TYPES(m) \ + TF_CALL_INTEGRAL_TYPES_NO_INT32(m) TF_CALL_int32(m) + +#define TF_CALL_FLOAT_TYPES(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) + +#define TF_CALL_REAL_NUMBER_TYPES(m) \ + TF_CALL_INTEGRAL_TYPES(m) TF_CALL_FLOAT_TYPES(m) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \ + TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) \ + TF_CALL_INTEGRAL_TYPES_NO_INT32(m) + +#define TF_CALL_COMPLEX_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m) + +// Call "m" for all number types, including complex types +#define TF_CALL_NUMBER_TYPES(m) \ + TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_COMPLEX_TYPES(m) + +#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) TF_CALL_COMPLEX_TYPES(m) + +#define TF_CALL_POD_TYPES(m) TF_CALL_NUMBER_TYPES(m) TF_CALL_bool(m) + +// Call "m" on all types. +#define TF_CALL_ALL_TYPES(m) \ + TF_CALL_POD_TYPES(m) TF_CALL_tstring(m) TF_CALL_resource(m) TF_CALL_variant(m) + +// Call "m" on POD and string types. +#define TF_CALL_POD_STRING_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_tstring(m) + +// Call "m" on all number types supported on GPU. +#define TF_CALL_GPU_NUMBER_TYPES(m) \ + TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m) + +// Call "m" on all types supported on GPU. +#define TF_CALL_GPU_ALL_TYPES(m) \ + TF_CALL_GPU_NUMBER_TYPES(m) TF_CALL_COMPLEX_TYPES(m) TF_CALL_bool(m) + +#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m) + +// Call "m" on all quantized types. +// TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m) +#define TF_CALL_QUANTIZED_TYPES(m) \ + TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m) + +// Types used for save and restore ops. 
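Beyond the REGISTER_PARTITION example in the header comment, these aggregate lists are also commonly used to stamp out explicit template instantiations; MyFunctor here is a hypothetical functor template (sketch):

#define INSTANTIATE_GPU(T) template struct MyFunctor<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(INSTANTIATE_GPU)
TF_CALL_COMPLEX_TYPES(INSTANTIATE_GPU)
#undef INSTANTIATE_GPU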
+#define TF_CALL_SAVE_RESTORE_TYPES(m) \ + TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \ + TF_CALL_COMPLEX_TYPES(m) \ + TF_CALL_QUANTIZED_TYPES(m) TF_CALL_bool(m) TF_CALL_tstring(m) + +#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/register_types_traits.h b/third_party/tflite-hdrs/tensorflow/core/framework/register_types_traits.h new file mode 100644 index 00000000..b2847d84 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/register_types_traits.h @@ -0,0 +1,93 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ +// This file is used by cuda code and must remain compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Remap POD types by size to equivalent proxy types. This works +// since all we are doing is copying data around. +struct UnusableProxyType; +template +struct proxy_type_pod { + typedef UnusableProxyType type; +}; +template <> +struct proxy_type_pod { + typedef ::tensorflow::complex128 type; +}; +template <> +struct proxy_type_pod { + typedef ::int64_t type; +}; +template <> +struct proxy_type_pod { + typedef ::tensorflow::int32 type; +}; +template <> +struct proxy_type_pod { + typedef ::tensorflow::int16 type; +}; +template <> +struct proxy_type_pod { + typedef ::tensorflow::int8 type; +}; +template <> +struct proxy_type_pod { + typedef double type; +}; +template <> +struct proxy_type_pod { + typedef float type; +}; +template <> +struct proxy_type_pod { + typedef Eigen::half type; +}; +template <> +struct proxy_type_pod { + typedef ::tensorflow::int8 type; +}; + + +/// If POD we use proxy_type_pod, otherwise this maps to identity. 
+template +struct proxy_type { + typedef typename std::conditional< + std::is_arithmetic::value, + typename proxy_type_pod::type, T>::type type; + static_assert(sizeof(type) == sizeof(T), "proxy_type_pod is not valid"); +}; + +/// The active proxy types +#define TF_CALL_CPU_PROXY_TYPES(m) \ + TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \ + TF_CALL_int8(m) TF_CALL_complex128(m) +#define TF_CALL_GPU_PROXY_TYPES(m) \ + TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_bfloat16(m) \ + TF_CALL_int32(m) TF_CALL_int8(m) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/registration/registration.h b/third_party/tflite-hdrs/tensorflow/core/framework/registration/registration.h new file mode 100644 index 00000000..27f2ec2b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/registration/registration.h @@ -0,0 +1,152 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file provides some common support for 'registration' of e.g. ops and +// kernels. In particular, it relates to the REGISTER_OP (op registration) and +// REGISTER_KERNEL_BUILDER (kernel registration) macros. +// +// Note that there are two sides to 'registration': +// - Definition (compile-time): making op and kernel definitions _available_. +// - Usage (run-time): adding particular (available) definitions of ops and +// kernels to the global OpRegistry / KernelRegistry, to be found when +// constructing and executing graphs. +// +// Currently, definition and usage happen to be coupled together: all +// 'available' definitions (from the REGISTER_*' macros) are added to the global +// registries on startup / library load. + +#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTRATION_REGISTRATION_H_ +#define TENSORFLOW_CORE_FRAMEWORK_REGISTRATION_REGISTRATION_H_ + +#include + +#include +#include + +#include "tensorflow/core/framework/registration/options.h" + +#if !TF_OPTION_REGISTRATION_V2() + +#ifdef SELECTIVE_REGISTRATION + +// Experimental selective registration support to reduce binary size. +// +// To use selective registration, when building: +// 1. define SELECTIVE_REGISTRATION, e.g. in gcc by passing +// -DSELECTIVE_REGISTRATION to compilation. +// 2. Provide ops_to_register.h. This file is not included in the repo and must +// be placed by the user or a tool where the compiler can find it. It must +// define the constants and functions used in the macros below. The +// functions should be defined as valid constexpr functions, so that they are +// evaluated at compile time: this is needed to make symbols referenced by +// un-registered objects unused, and therefore allow the linker to strip them +// out. See python/tools/print_selective_registration_header.py for a tool +// that can be used to generate ops_to_register.h. 
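As a rough illustration of the macros that ops_to_register.h must provide (they are listed just below), a hand-rolled version keeping only two ops might look like the following; the constexpr string helper and the op names are purely illustrative, not the format emitted by the generator tool:

constexpr bool OpNameIs(const char* a, const char* b) {
  return (*a == *b) && (*a == '\0' || OpNameIs(a + 1, b + 1));
}
#define SHOULD_REGISTER_OP(op) (OpNameIs(op, "NoOp") || OpNameIs(op, "Identity"))
#define SHOULD_REGISTER_OP_GRADIENT false
#define SHOULD_REGISTER_OP_KERNEL(clz) true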
+// +// ops_to_register.h should define macros for: +// // Ops for which this is false will not be registered. +// SHOULD_REGISTER_OP(op) +// // If this is false, then no gradient ops are registered. +// SHOULD_REGISTER_OP_GRADIENT +// // Op kernel classes where this is false won't be registered. +// SHOULD_REGISTER_OP_KERNEL(clz) +// The macros should be defined using constexprs. + +#include "ops_to_register.h" + +#if (!defined(SHOULD_REGISTER_OP) || !defined(SHOULD_REGISTER_OP_GRADIENT) || \ + !defined(SHOULD_REGISTER_OP_KERNEL)) +static_assert(false, "ops_to_register.h must define SHOULD_REGISTER macros"); +#endif +#else // SELECTIVE_REGISTRATION +#define SHOULD_REGISTER_OP(op) true +#define SHOULD_REGISTER_OP_GRADIENT true +#define SHOULD_REGISTER_OP_KERNEL(clz) true +#endif // SELECTIVE_REGISTRATION + +#else // ! TF_OPTION_REGISTRATION_V2() + +#ifdef SELECTIVE_REGISTRATION +#error TF_OPTION_REGISTRATION_V2(): Compile-time selective registration is not supported +#endif + +#endif // ! TF_OPTION_REGISTRATION_V2() + +namespace tensorflow { + +// An InitOnStartupMarker is 'initialized' on program startup, purely for the +// side-effects of that initialization - the struct itself is empty. (The type +// is expected to be used to define globals.) +// +// The '<<' operator should be used in initializer expressions to specify what +// to run on startup. The following values are accepted: +// - An InitOnStartupMarker. Example: +// InitOnStartupMarker F(); +// InitOnStartupMarker const kInitF = +// InitOnStartupMarker{} << F(); +// - Something to call, which returns an InitOnStartupMarker. Example: +// InitOnStartupMarker const kInit = +// InitOnStartupMarker{} << []() { G(); return +// +// See also: TF_INIT_ON_STARTUP_IF +struct InitOnStartupMarker { + constexpr InitOnStartupMarker operator<<(InitOnStartupMarker) const { + return *this; + } + + template + constexpr InitOnStartupMarker operator<<(T&& v) const { + return std::forward(v)(); + } +}; + +// Conditional initializer expressions for InitOnStartupMarker: +// TF_INIT_ON_STARTUP_IF(cond) << f +// If 'cond' is true, 'f' is evaluated (and called, if applicable) on startup. +// Otherwise, 'f' is *not evaluated*. Note that 'cond' is required to be a +// constant-expression, and so this approximates #ifdef. +// +// The implementation uses the ?: operator (!cond prevents evaluation of 'f'). +// The relative precedence of ?: and << is significant; this effectively expands +// to (see extra parens): +// !cond ? InitOnStartupMarker{} : (InitOnStartupMarker{} << f) +// +// Note that although forcing 'cond' to be a constant-expression should not +// affect binary size (i.e. the same optimizations should apply if it 'happens' +// to be one), it was found to be necessary (for a recent version of clang; +// perhaps an optimizer bug). +// +// The parens are necessary to hide the ',' from the preprocessor; it could +// otherwise act as a macro argument separator. +#define TF_INIT_ON_STARTUP_IF(cond) \ + (::std::integral_constant::value) \ + ? ::tensorflow::InitOnStartupMarker{} \ + : ::tensorflow::InitOnStartupMarker {} + +// Wrapper for generating unique IDs (for 'anonymous' InitOnStartup definitions) +// using __COUNTER__. The new ID (__COUNTER__ already expanded) is provided as a +// macro argument. +// +// Usage: +// #define M_IMPL(id, a, b) ... +// #define M(a, b) TF_NEW_ID_FOR_INIT(M_IMPL, a, b) +#define TF_NEW_ID_FOR_INIT_2(m, c, ...) m(c, __VA_ARGS__) +#define TF_NEW_ID_FOR_INIT_1(m, c, ...) 
TF_NEW_ID_FOR_INIT_2(m, c, __VA_ARGS__) +#define TF_NEW_ID_FOR_INIT(m, ...) \ + TF_NEW_ID_FOR_INIT_1(m, __COUNTER__, __VA_ARGS__) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTRATION_REGISTRATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/rendezvous.h b/third_party/tflite-hdrs/tensorflow/core/framework/rendezvous.h new file mode 100644 index 00000000..97a5daff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/rendezvous.h @@ -0,0 +1,177 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ + +#include +#include + +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/control_flow.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class DeviceMgr; + +// A Rendezvous is an abstraction for passing tensors from producers +// to consumers. A rendezvous is a table of channels. Each channel is +// keyed by a rendezvous key. The key encodes a pair of , where the producer and the consumer are tensorflow +// devices. +// +// The producer calls the Send() method to send one tensor over one +// named channel. The consumer calls the Recv() method to receive one +// tensor from a named channel. A sequence of tensors can be passed +// from the producer to the consumer. The consumer receives them in +// the order as the producer sends them. +// +// A consumer may safely request the tensor before or after it has +// been produced. A consumer has the choice of making a blocking call +// or providing a callback: in either case, the consumer receives the +// Tensor as soon as it is available. A producer never blocks. +class RendezvousInterface { + public: + struct Args { + DeviceContext* device_context = nullptr; + AllocatorAttributes alloc_attrs; + CancellationManager* cancellation_manager = nullptr; // not owned. + }; + + // Parses the key constructed by CreateKey and parse src/dst device + // names into structures respectively. + struct ParsedKey { + absl::string_view src_device; + DeviceNameUtils::ParsedName src; + uint64 src_incarnation = 0; + absl::string_view dst_device; + DeviceNameUtils::ParsedName dst; + absl::string_view edge_name; + + ParsedKey() {} + ParsedKey(const ParsedKey& b) { *this = b; } + + ParsedKey& operator=(const ParsedKey& b); + absl::string_view FullKey() const { return buf_; } + + private: + friend class Rendezvous; + friend class SendOp; + friend class RecvOp; + std::string buf_; + }; + + // The caller is a tensor producer and it sends a message (a tensor + // "val" and a bool "is_dead") under the given "key". 
+ // + // {val, is_dead} is bundled as a message sent and received. + // Typically, is_dead is set by some control flow nodes + // (e.g., a not-taken branch). args is passed by Send to the + // Recv function to communicate any information that the Recv + // function might need. This is typically only necessary for + // Send/Recv on the same worker. + // + // Send() never blocks. + virtual absl::Status Send(const ParsedKey& key, const Args& args, + const Tensor& val, const bool is_dead) = 0; + + // Callback provided by a tensor consumer waiting on the rendezvous. + // It will be invoked when the tensor is available, or when a non-OK + // status arises in the production of that tensor. It also gets + // two Rendezvous::Args, one provided by the sender, the other by the + // receiver, which may be needed when a non-CPU device is in use + // by either side. + typedef std::function + DoneCallback; + + virtual void RecvAsync(const ParsedKey& key, const Args& args, + DoneCallback done) = 0; + + // Synchronous wrapper for RecvAsync. + absl::Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead, int64_t timeout_ms); + absl::Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead); + + // Aborts all pending and future Send/Recv with the given "status". + // + // StartAbort() does not wait for ongoing calls to finish. + // REQUIRES: !status.ok() + virtual void StartAbort(const absl::Status& status) = 0; + + virtual ~RendezvousInterface(); + + protected: + virtual bool is_cross_process() { return false; } + friend class ProcessFunctionLibraryRuntime; +}; + +// A reference-counted implementation of RendezvousInterface. +// +// This class is used in cases where a rendezvous may be shared between multiple +// threads with no clear owner. +class Rendezvous : public RendezvousInterface, public core::WeakRefCounted { + public: + class Factory { + public: + // Default to a factory that evaluates to false. + Factory() : valid_(false) {} + + explicit Factory( + std::function*)> + create_fn) + : valid_(true), create_fn_(std::move(create_fn)) {} + + explicit operator bool() const { return valid_; } + + absl::Status operator()(const int64_t step_id, const DeviceMgr* device_mgr, + tsl::core::RefCountPtr* rendez) const { + return create_fn_(step_id, device_mgr, rendez); + } + + private: + bool valid_; + std::function*)> + create_fn_; + }; + + // Constructs a rendezvous key for the tensor of "name" sent from + // "src_device" to "dst_device". The tensor is generated in the frame + // and iteration specified by "frame_iter". + static std::string CreateKey(const std::string& src_device, + uint64 src_incarnation, + const std::string& dst_device, + const std::string& name, + const FrameAndIter& frame_iter); + + static absl::Status ParseKey(absl::string_view key, ParsedKey* out); +}; + +// Returns a Rendezvous instance that is limited to use only by +// producers and consumers in the local process. The caller assumes +// ownership of one Ref() on the returned object. +Rendezvous* NewLocalRendezvous(int num_shards = 1); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_RENDEZVOUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/resource_base.h b/third_party/tflite-hdrs/tensorflow/core/framework/resource_base.h new file mode 100644 index 00000000..c22adb55 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/resource_base.h @@ -0,0 +1,62 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
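To give the key/Send/Recv flow of the rendezvous above a concrete shape, a purely local round trip might look like this; the device names, edge name, and scalar value are illustrative:

void LocalRendezvousRoundTrip() {
  Rendezvous* rendez = NewLocalRendezvous();

  const std::string key = Rendezvous::CreateKey(
      "/job:localhost/replica:0/task:0/device:CPU:0", /*src_incarnation=*/1,
      "/job:localhost/replica:0/task:0/device:CPU:0", "edge_0",
      FrameAndIter(0, 0));
  Rendezvous::ParsedKey parsed;
  TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));

  Tensor val(DT_FLOAT, TensorShape({}));
  val.scalar<float>()() = 42.0f;
  TF_CHECK_OK(rendez->Send(parsed, Rendezvous::Args(), val, /*is_dead=*/false));

  Tensor out;
  bool is_dead = false;
  TF_CHECK_OK(rendez->Recv(parsed, Rendezvous::Args(), &out, &is_dead));

  rendez->Unref();
}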
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_BASE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_BASE_H_ + +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +// Forward declaration to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class Node; + +// This is the base class of all resource classes. Each resource must be +// represented as a sub-class of ResourceBase (which is reference counted) to be +// able to work with resource facilities such ResourceHandle and ResourceMgr. +class ResourceBase : public core::WeakRefCounted { + public: + // Returns a debug string for *this. + virtual std::string DebugString() const = 0; + + // Returns a name for ref-counting handles. + virtual std::string MakeRefCountingHandleName(int64_t resource_id) const { + return absl::StrFormat("Resource-%d-at-%p", resource_id, this); + } + + // Returns memory used by this resource. + virtual int64_t MemoryUsed() const { return 0; } + + // Writes a representation of this resource into `builder`, so that executing + // `*out` will recreate this resource. The lifetime of the created resource + // should not be tied to the graph that created it, since the graph may be + // destroyed before the resource is used. To avoid this lifetime issue, you + // can usually set a unique `shared_name` attribute for the resource. + virtual absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const { + return errors::Unimplemented("AsGraphDef not implemented for resource ", + DebugString()); + } +}; +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/resource_handle.h b/third_party/tflite-hdrs/tensorflow/core/framework/resource_handle.h new file mode 100644 index 00000000..393a8998 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/resource_handle.h @@ -0,0 +1,206 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_HANDLE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_HANDLE_H_ + +#include +#include + +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/intrusive_ptr.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/tensor_coding.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/managed_stack_trace.h" + +namespace tensorflow { + +class ResourceHandleProto; + +// Class representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run (except for those created from MakeRefCountingHandle i.e. whose +// resource_ field is not empty). +// +// This is the native C++ class equivalent of ResourceHandleProto. They are +// separate so that kernels do not need to depend on protos. +class ResourceHandle { + public: + ResourceHandle(); + ResourceHandle(const ResourceHandleProto& proto); + ~ResourceHandle(); + + // Use this factory method if the `proto` comes from user controlled input, to + // prevent a denial of service. + static absl::Status BuildResourceHandle(const ResourceHandleProto& proto, + ResourceHandle* out); + + // Unique name for the device containing the resource. + const std::string& device() const { return device_; } + + void set_device(const std::string& device) { device_ = device; } + + // Container in which this resource is placed. + const std::string& container() const { return container_; } + void set_container(const std::string& container) { container_ = container; } + + // Unique name of this resource. + const std::string& name() const { return name_; } + void set_name(const std::string& name) { name_ = name; } + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code() const { return hash_code_; } + void set_hash_code(uint64 hash_code) { hash_code_ = hash_code; } + + // For debug-only, the name of the type pointed to by this handle, if + // available. + const std::string& maybe_type_name() const { return maybe_type_name_; } + void set_maybe_type_name(const std::string& value) { + maybe_type_name_ = value; + } + + // Data types and shapes for the underlying resource. + std::vector dtypes_and_shapes() const { + return dtypes_and_shapes_; + } + void set_dtypes_and_shapes( + const std::vector& dtypes_and_shapes) { + dtypes_and_shapes_ = dtypes_and_shapes; + } + + void set_definition_stack_trace( + const absl::optional& definition_stack_trace) { + definition_stack_trace_ = definition_stack_trace; + } + + const absl::optional& definition_stack_trace() const { + return definition_stack_trace_; + } + + // Conversion to and from ResourceHandleProto + void AsProto(ResourceHandleProto* proto) const; + absl::Status FromProto(const ResourceHandleProto& proto); + + // Serialization via ResourceHandleProto + std::string SerializeAsString() const; + bool ParseFromString(const std::string& s); + + std::string DebugString() const; + + std::string SummarizeValue() const; + + // GUID for anonymous resources. 
Resources with this shared_name will have + // their shared_name replaced with a GUID at creation time + static constexpr const char* ANONYMOUS_NAME = + "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; + + // Creates a `ResourceHandle` that holds a pointer to a resource and takes + // ownership of it. Normally a `ResourceHandle` only contains the name (and + // some other metadata) of the resource. When created via this function, + // the handle will own the resource, in the sense that it will destroy the + // resource automatically when the resource is no longer needed. It does this + // via automatic ref-counting on the resource: when the handle is copied, it + // will call `Ref` on the resource (remember that all resources inherit from + // `ResourceBase` which inherits from `RefCounted`), and when the handle is + // destroyed, it will call `Unref` on the resource. When the last handle goes + // out of scope, the resource's ref-count will go down to zero and the + // resource will be destroyed. When calling this function, the `resource` + // argument should have a ref-count of one (which is the case when the + // resource is newly created). + // + // For those familiar with `ResourceMgr`, when you create a handle by the + // `MakeResourceHandle` function in resource_mgr.h, the handle doesn't hold a + // strong reference to the resource, and the resource is owned by the + // resource manager whose strong reference must be manually deleted by + // calling `ResourceMgr::Delete`. In contrast, a handle created by this + // function holds a strong reference to the resource. The resource manager + // does not hold a strong reference to the resource. + template + static ResourceHandle MakeRefCountingHandle( + T* resource, const string& device_name, + const std::vector& dtypes_and_shapes = {}, + const absl::optional& definition_stack_trace = {}) { + return MakeRefCountingHandle(resource, device_name, TypeIndex::Make(), + dtypes_and_shapes, definition_stack_trace); + } + + static ResourceHandle MakeRefCountingHandle( + ResourceBase* resource, const string& device_name, + const TypeIndex& type_index, + const std::vector& dtypes_and_shapes = {}, + const absl::optional& definition_stack_trace = {}); + + // Pointer to the resource. + const core::IntrusivePtr& resource() const { return resource_; } + + // Gets the resource pointer in `handle` as `T*`, or an error if the actual + // resource type is not `T`. + template + StatusOr GetResource() const { + TF_RETURN_IF_ERROR(ValidateType()); + return down_cast(resource_.get()); + } + + // Returns True if the resource handle is ref-counting. + // See MakeRefCountingHandle. + bool IsRefCounting() const { return resource_.get() != nullptr; } + + // Validates that the resource type in `handle` is `T`. + template + absl::Status ValidateType() const { + return ValidateType(TypeIndex::Make()); + } + + absl::Status ValidateType(const TypeIndex& type_index) const; + + // Generates unique IDs (e.g. for names of anonymous variables) + static int64_t GenerateUniqueId(); + + private: + std::string device_; + std::string container_; + std::string name_; + uint64 hash_code_ = 0; + std::string maybe_type_name_; + std::vector dtypes_and_shapes_; + std::optional definition_stack_trace_; + // A smart pointer to the actual resource. When this field is not empty, the + // handle is in a "ref-counting" mode, owning the resource; otherwise it's in + // a "weak-ref" mode, only containing the name of the resource (conceptually a + // weak reference). 
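A sketch of the ref-counting mode described above, with a hypothetical resource type and an illustrative device name:

class MyCounter : public ResourceBase {
 public:
  std::string DebugString() const override { return "MyCounter"; }
  int64_t value = 0;
};

void RefCountingHandleExample() {
  MyCounter* counter = new MyCounter;  // ref-count is 1 (newly created)
  ResourceHandle handle = ResourceHandle::MakeRefCountingHandle(
      counter, "/job:localhost/replica:0/task:0/device:CPU:0");
  // The handle now owns the resource; copies Ref()/Unref() it automatically.
  auto got = handle.GetResource<MyCounter>();  // StatusOr<MyCounter*>
  if (got.ok()) {
    (*got)->value += 1;
  }
}  // last handle destroyed -> resource Unref'd and deleted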
+ core::IntrusivePtr resource_; + static std::atomic current_id_; +}; + +// For backwards compatibility for when this was a proto +std::string ProtoDebugString(const ResourceHandle& handle); + +// Encodes a list of ResourceHandle protos in the given StringListEncoder. +void EncodeResourceHandleList(const ResourceHandle* p, int64_t n, + std::unique_ptr e); + +// Decodes a list of ResourceHandle protos from the given StringListDecoder. +bool DecodeResourceHandleList(std::unique_ptr d, + ResourceHandle* ps, int64_t n); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_HANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/resource_mgr.h b/third_party/tflite-hdrs/tensorflow/core/framework/resource_mgr.h new file mode 100644 index 00000000..74e26b43 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/resource_mgr.h @@ -0,0 +1,1042 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/variant.h" +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// A ResourceMgr instance keeps track of named and typed resources +// grouped into containers. +// +// Each named resource is +// registered with ResourceMgr under a named "container" name. At any +// time, there is at most one instance of a resource given the container +// name, the resource type and the resource name. +// +// All resources for a given container can be dropped by one call of +// Cleanup(). +// +// E.g., +// struct MyVar : public ResourceBase { +// mutex mu; +// Tensor val; +// } +// +// ResourceMgr rm; +// +// // Create a var. +// MyVar* my_var = new MyVar; +// my_var->val = Tensor(DT_FLOAT, my_shape); +// my_var->val.flat().setZeros(); // 0 initialized. +// ctx->SetStatus(rm.Create("my_container", "my_name", my_var)); +// +// // += a variable. 
+// MyVar* my_var = nullptr; +// Status s = rm.Lookup("my_container", "my_name", &my_var); +// if (s.ok()) { +// my_var->val.flat() += grad; +// } +// my_var->Unref(); // Or use ScopedUnref(). +// ctx->SetStatus(s); + +// Container used for per-step resources. +class ScopedStepContainer { + public: + // step_id: the unique ID of this step. Doesn't have to be sequential, just + // has to be unique. + // cleanup: callback to delete a container of this name. + // prefix: optional string prefix to disambiguate step containers. + ScopedStepContainer(const int64_t step_id, + std::function cleanup) + : step_id_(step_id), + container_(strings::StrCat("__per_step_", step_id)), + cleanup_(cleanup), + dirty_(false) {} + + ScopedStepContainer(const int64_t step_id, + std::function cleanup, + const std::string& prefix) + : step_id_(step_id), + container_(strings::StrCat("__", prefix, "_per_step_", step_id)), + cleanup_(cleanup), + dirty_(false) {} + + ~ScopedStepContainer() { CleanUp(); } + + void CleanUp() TF_NO_THREAD_SAFETY_ANALYSIS { + // NOTE(mrry): Avoid acquiring the mutex in the case that the container is + // clean. + if (dirty_) { + mutex_lock ml(mu_); + cleanup_(container_); + dirty_ = false; + } + } + + // Pass through functions for resource lookup and creation. We do this to + // ensure that we can appropriately set the dirty_ bit in the + // ScopedStepContainer if the name of the container is used to create + // resources. + + // Pass through to MakeResourceHandle with the container name + template + ResourceHandle MakeResourceHandle( + const std::string& name, const DeviceBase& device) TF_MUST_USE_RESULT; + // Pass through to ResourceMgr::Create with the container name + template + absl::Status Create(ResourceMgr* rm, const std::string& name, T* resource); + // Pass through to ResourceMgr::Delete with the container name + template + absl::Status Delete(ResourceMgr* rm, const std::string& name); + // Pass through to ResourceMgr::Lookup with the container name + template + absl::Status Lookup(ResourceMgr* rm, const std::string& name, + T** resource) const; + // Pass through to ResourceMgr::LookupOrCreate with the container name + template + absl::Status LookupOrCreate(ResourceMgr* rm, const std::string& name, + T** resource, + std::function creator); + int64_t StepId() const { return step_id_; } + + private: + const int64_t step_id_; + const std::string container_; + const std::function cleanup_; + mutex mu_; + mutable std::atomic dirty_ TF_GUARDED_BY(mu_); +}; + +class ResourceMgr { + public: + ResourceMgr(); + explicit ResourceMgr(const std::string& default_container); + ~ResourceMgr(); + + // Returns the default container name for *this. + const std::string& default_container() const { return default_container_; } + + // Creates a resource "name" in the "container". The caller transfers + // the ownership of one ref on "resource" to *this, regardless of whether this + // operation succeeds or fails. + // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr. + template + absl::Status Create(const std::string& container, const std::string& name, + T* resource); + + // Creates a unowned resource "name" in the "container". The caller does NOT + // transfer the ownership of any ref on "resource" to *this, regardless of + // whether this operation succeeds or fails. + // + // After the resource is destroyed, lookups from the manager fail. + // The caller must call this->Delete() on the name to free up the memory + // entry of the name. 
+ // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr. + template + absl::Status CreateUnowned(const std::string& container, + const std::string& name, T* resource); + + // If "container" has a resource "name", returns it in "*resource" and + // the caller takes the ownership of one ref on "*resource". + // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr + template + absl::Status Lookup(const std::string& container, const std::string& name, + T** resource) const; + + // If the resource manager has a resource matching "handle", returns it in + // "*resource" and the caller takes the ownership of one ref on "*resource". + // + // REQUIRES: resource != nullptr + absl::Status Lookup(const ResourceHandle& handle, + ResourceBase** resource) const; + + // Similar to Lookup, but looks up multiple resources at once, with only a + // single lock acquisition. If containers_and_names[i] is uninitialized + // then this function does not modify resources[i]. + template + absl::Status LookupMany( + absl::Span const> + containers_and_names, + std::vector>* resources) const; + + // If "container" has a resource "name", returns it in + // "*resource". Otherwise, invokes creator() to create the resource. + // The caller takes the ownership of one ref on "*resource". + // + // WARNING: creator() must not call any methods on ResourceMgr during its + // execution, because a non-reentrant lock is held during the creator() call + // in order to guarantee atomicity of LookupOrCreate(). + // + // REQUIRES: std::is_base_of + // REQUIRES: resource != nullptr + template + absl::Status LookupOrCreate(const std::string& container, + const std::string& name, T** resource, + std::function creator); + + // Deletes the resource "name" from the "container". + // + // REQUIRES: std::is_base_of + template + absl::Status Delete(const std::string& container, const std::string& name); + + // Deletes the resource pointed by "handle". + absl::Status Delete(const ResourceHandle& handle); + + // Deletes all resources from the "container" and removes the container. + absl::Status Cleanup(const std::string& container); + + // Deletes all resources in all containers. + void Clear(); + + // Returns a text description for all resources. + std::string DebugString() const; + + private: + typedef std::pair Key; + struct KeyHash { + std::size_t operator()(const Key& k) const { + return Hash64(k.second.data(), k.second.size(), k.first); + } + }; + struct KeyEqual { + bool operator()(const Key& x, const Key& y) const { + return (x.second == y.second) && (x.first == y.first); + } + }; + struct ResourceAndName { + std::variant, core::WeakPtr> + resource; + std::unique_ptr name; + + ResourceAndName(); + explicit ResourceAndName(const string& name); + ResourceAndName(ResourceAndName&& other) noexcept; + ~ResourceAndName(); + + ResourceAndName& operator=(ResourceAndName&&) noexcept; + + // Returns a strong reference to resource, or nullptr if the resource is + // no longer valid. 
+ core::RefCountPtr GetResource() const; + + private: + ResourceAndName(const ResourceAndName&) = delete; + void operator=(const ResourceAndName&) = delete; + }; + typedef absl::flat_hash_map + Container; + + const std::string default_container_; + mutable mutex mu_; + absl::flat_hash_map containers_ TF_GUARDED_BY(mu_); + + template + absl::Status LookupInternal(const std::string& container, + const std::string& name, T** resource) const + TF_SHARED_LOCKS_REQUIRED(mu_); + absl::Status LookupInternal(const std::string& container, + uint64 type_hash_code, const std::string& name, + ResourceBase** resource) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + absl::Status DoCreate(const std::string& container, TypeIndex type, + const std::string& name, ResourceBase* resource, + bool owns_resource) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + absl::Status DoLookup(const std::string& container, TypeIndex type, + const std::string& name, ResourceBase** resource) const + TF_SHARED_LOCKS_REQUIRED(mu_); + absl::Status DoLookup(const std::string& container, uint64 type_hash_code, + const std::string& type_name, + const std::string& resource_name, + ResourceBase** resource) const + TF_SHARED_LOCKS_REQUIRED(mu_); + + absl::Status DoDelete(const std::string& container, uint64 type_hash_code, + const std::string& resource_name, + const std::string& type_name); + absl::Status DoDelete(const std::string& container, TypeIndex type, + const std::string& resource_name); + + // Pops the ResourceAndName entry. The entry is moved from the list to + // the output argument `resource_and_name`. + absl::Status PopResourceAndName(const std::string& container, + uint64 type_hash_code, + const std::string& resource_name, + const std::string& type_name, + ResourceAndName& resource_and_name); + // Inserts the type name for 'hash_code' into the hash_code to type name map. + absl::Status InsertDebugTypeName(uint64 hash_code, + const std::string& type_name) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the type name for the 'hash_code'. + // Returns "" if a resource with such a type was never inserted into + // the container. + const char* DebugTypeName(uint64 hash_code) const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Map from type hash_code to type name. + std::unordered_map debug_type_names_ TF_GUARDED_BY(mu_); + + ResourceMgr(const ResourceMgr&) = delete; + void operator=(const ResourceMgr&) = delete; +}; + +// Makes a resource handle with the specified type for a given container / +// name. +ResourceHandle MakeResourceHandle( + const std::string& container, const std::string& name, + const DeviceBase& device, const TypeIndex& type_index, + const std::vector& dtypes_and_shapes = {}, + const absl::optional& definition_stack_trace = {}) + TF_MUST_USE_RESULT; + +template +ResourceHandle MakeResourceHandle( + OpKernelContext* ctx, const std::string& container, const std::string& name, + const std::vector& dtypes_and_shapes = {}, + const absl::optional& definition_stack_trace = {}) { + return MakeResourceHandle(container.empty() + ? ctx->resource_manager()->default_container() + : container, + name, *ctx->device(), TypeIndex::Make(), + dtypes_and_shapes, definition_stack_trace); +} + +template +ResourceHandle MakeResourceHandle( + OpKernelConstruction* ctx, const std::string& container, + const std::string& name, + const std::vector& dtypes_and_shapes = {}, + const absl::optional& definition_stack_trace = {}) { + return MakeResourceHandle(container.empty() + ? 
ctx->resource_manager()->default_container() + : container, + name, *ctx->device(), TypeIndex::Make(), + dtypes_and_shapes, definition_stack_trace); +} + +absl::Status MakeResourceHandleToOutput(OpKernelContext* context, + int output_index, + const std::string& container, + const std::string& name, + const TypeIndex& type_index); + +// Returns a resource handle from a numbered op input. +const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input); + +// Safely returns a resource handle from a numbered op input. +// Prevents segfault by checking for empty resource handle. +absl::Status HandleFromInput(OpKernelContext* ctx, int input, + ResourceHandle* handle); +// Returns a resource handle by name, as defined in the OpDef. +// Also prevents segfault by checking for empty resource handle. +absl::Status HandleFromInput(OpKernelContext* ctx, absl::string_view input, + ResourceHandle* handle); + +// Create a resource pointed by a given resource handle. +// +// If successful, the caller transfers the ownership of one ref on `resource` to +// `ctx->resource_mgr()`. +template +absl::Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T* value); + +// Looks up a resource pointed by a given resource handle. +// +// If the lookup is successful, the caller takes the ownership of one ref on +// `*value`, and must call its `Unref()` method when it has finished using it. +template +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value); + +// Looks up a resource pointed by a given resource handle. +// +// Prefer usage of LookupResource taking `core::RefCountPtr` to avoid +// requiring the caller to explicitly call `Unref()`. +template +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr* value); + +// Looks up multiple resources pointed by a sequence of resource handles. If +// p[i] is uninitialized then values[i] is unmodified. +template +absl::Status LookupResources(OpKernelContext* ctx, + absl::Span p, + std::vector>* values); + +// Looks up or creates a resource. +// +// If successful, the caller takes the ownership of one ref on `*value`, and +// must call its `Unref()` method when it has finished using it. If the +// `creator` is invoked, its reference on the created resource is transferred +// to `ctx->resource_mgr()`. +// +// Prefer usage of LookupOrCreateResource taking `core::RefCountPtr` to avoid +// requiring the caller to explicitly call `Unref()`. +template +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, T** value, + std::function creator); + +// Looks up or creates a resource. +template +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, + core::RefCountPtr* value, + std::function creator); + +// Destroys a resource pointed by a given resource handle. +template +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); + +// Same as above, but uses the hash code of the type directly. +// The type name information will be missing in the debug output when the +// resource is not present in the container. +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); + +// Policy helper to decide which container/shared_name to use for a +// stateful kernel that accesses shared resource. +class ContainerInfo { + public: + // Analyze the node attribute of 'ndef' and decides the container and + // resource name the kernel should use for accessing the shared + // resource. 
+ // + // 'ndef' is expected to have node attribute "container" and + // "shared_name". Returns non-OK if they are not provided or they are + // invalid. + // + // The policy is as following: + // * If the attribute "container" is non-empty, it is used as is. + // Otherwise, uses the resource manager's default container. + // * If the attribute "shared_name" is non-empty, it is used as is. + // Otherwise, if "use_node_name_as_default" is true, the kernel's + // node name is used as the resource name. Otherwise, a string + // unique to this process is used. + absl::Status Init(ResourceMgr* rmgr, const NodeDef& ndef, + bool use_node_name_as_default); + absl::Status Init(ResourceMgr* rmgr, const NodeDef& ndef) { + return Init(rmgr, ndef, false); + } + + // The policy decides that the kernel should access the resource in + // resource_manager(), the resource is in the container() and its + // name is name(). If resource_is_private_to_kernel() is true, the + // kernel should delete the resource when the kernel is deleted. + ResourceMgr* resource_manager() const { return rmgr_; } + const std::string& container() const { return container_; } + const std::string& name() const { return name_; } + bool resource_is_private_to_kernel() const { + return resource_is_private_to_kernel_; + } + + // Returns a readable string for *this. + std::string DebugString() const; + + private: + ResourceMgr* rmgr_ = nullptr; + std::string container_; + std::string name_; + bool resource_is_private_to_kernel_ = false; +}; + +// Helper for kernels to obtain 'resource' from the +// ctx->resource_manager(). +// +// "input_name" specifies the kernel's ref input which gives a string +// tensor with two elements, which specifies the container and +// resource name. +// +// Returns OK if the resource is found and transfers one ref of +// *resource to the caller. Otherwise, returns an error. +template +absl::Status GetResourceFromContext(OpKernelContext* ctx, + const std::string& input_name, + T** resource); + +// Utility op kernel to check if a handle to resource type T is initialized. +template +class IsResourceInitialized : public OpKernel { + public: + explicit IsResourceInitialized(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* ctx) override; +}; + +// Registers an op which produces just a resource handle to a resource of the +// specified type. The type will be a part of the generated op name. +// TODO(apassos): figure out how to get non-cpu-allocated tensors to work +// through constant folding so this doesn't have to be marked as stateful. +#define REGISTER_RESOURCE_HANDLE_OP(Type) \ + REGISTER_OP(#Type "HandleOp") \ + .Attr("container: string = ''") \ + .Attr("shared_name: string = ''") \ + .Output("resource: resource") \ + .SetIsStateful() \ + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + +// Utility op kernel to produce a handle to a resource of type T. +template +class ResourceHandleOp : public OpKernel { + public: + explicit ResourceHandleOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + std::string container_; + std::string name_; + mutex mutex_; + Tensor resource_; + std::atomic initialized_{false}; +}; + +// Utility op kernel to produce a handle to a resource of type T. 
+template +class ResourceHandlesOp : public OpKernel { + public: + explicit ResourceHandlesOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + std::vector containers_; + std::vector names_; + mutex mutex_; + std::vector resources_; + std::atomic initialized_{false}; +}; + +// Registers a kernel for an op which produces a handle to a resource of the +// specified type. +#define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \ + REGISTER_KERNEL_BUILDER(Name(#Type "HandleOp").Device(DEVICE_CPU), \ + ResourceHandleOp) + +// This class is used to guarantee that an anonymous resource is deleted +// (irrespective of whether a resource deleter op is called explicitly or +// the execution encounters an error before the op runs). +// +// This is achieved by wrapping an instance of this class into a variant +// tensor which is passed as an input to a resource deleter op. If the +// execution encounters an error before the op runs, the tensor will be +// destroyed, essentially triggering the iterator deletion. +// NOTE: This is not a feature-complete implementation of the DT_VARIANT +// specification. In particular, we cannot serialize the `ResourceMgr` +// object, so the `Encode()` and `Decode()` methods are not implemented. +class ResourceDeleter { + public: + ResourceDeleter() : deleter_() {} + + ResourceDeleter(ResourceHandle handle, ResourceMgr* resource_manager) + : deleter_(std::make_shared(handle, resource_manager)) {} + + ResourceDeleter(ResourceDeleter&& rhs) : deleter_(std::move(rhs.deleter_)) { + VLOG(3) << "ResourceDeleter move constructor called."; + } + + ResourceDeleter(const ResourceDeleter& rhs) : deleter_(rhs.deleter_) { + VLOG(3) << "ResourceDeleter copy constructor called."; + } + + ResourceDeleter& operator=(const ResourceDeleter& rhs) = delete; + + ResourceDeleter& operator=(ResourceDeleter&& rhs) = default; + + virtual ~ResourceDeleter() { + VLOG(3) << "ResourceDeleter destructor called."; + } + + void Encode(VariantTensorData*) const { + LOG(ERROR) << "The Encode() method is not implemented for ResourceDeleter " + "objects."; + } + + bool Decode(const VariantTensorData&) { + LOG(ERROR) << "The Decode() method is not implemented for ResourceDeleter " + "objects"; + return false; // Not supported. + } + + private: + // Helper that performs reference counting for the parent class and deletes + // the iterator resource when the refcount goes to zero. + // + // NOTE: The object is borrowing a pointer to the resource manager. + // Consequently, the tensor containing this object should not escape the + // function in which was created (so that it is guaranteed that the resource + // manager will outlive it). + struct Helper { + Helper(ResourceHandle handle, ResourceMgr* resource_manager) + : handle(handle), resource_manager(resource_manager) {} + + Helper(const Helper& rhs) = delete; + Helper(Helper&& rhs) = delete; + + ~Helper() { + VLOG(3) << "Deleting Resource: " << handle.DebugString(); + resource_manager->Delete(handle).IgnoreError(); + } + + ResourceHandle handle; + ResourceMgr* resource_manager; // not owned + }; + + std::shared_ptr deleter_; +}; + +// Implementation details below. 
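+
+// A minimal usage sketch of the lookup helpers declared above (illustrative
+// only; `MyVar` is the example ResourceBase subclass from the comment at the
+// top of this header, and `grad` is assumed to be a compatible tensor
+// expression already available in the kernel):
+//
+//   // Inside an OpKernel::Compute() whose input 0 is a DT_RESOURCE scalar:
+//   core::RefCountPtr<MyVar> var;
+//   OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));
+//   mutex_lock l(var->mu);
+//   var->val.flat<float>() += grad;
+//   // No explicit Unref() is needed: RefCountPtr releases its reference
+//   // when `var` goes out of scope.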
+ +template +void CheckDeriveFromResourceBase() { + static_assert(std::is_base_of::value, + "T must derive from ResourceBase"); +} + +template +absl::Status ResourceMgr::Create(const std::string& container, + const std::string& name, T* resource) { + CheckDeriveFromResourceBase(); + CHECK(resource != nullptr); + mutex_lock l(mu_); + return DoCreate(container, TypeIndex::Make(), name, resource, + /* owns_resource */ true); +} + +template +absl::Status ResourceMgr::CreateUnowned(const std::string& container, + const std::string& name, T* resource) { + CheckDeriveFromResourceBase(); + mutex_lock l(mu_); + return DoCreate(container, TypeIndex::Make(), name, resource, + /* owns_resource */ false); +} + +template +absl::Status ResourceMgr::Lookup(const std::string& container, + const std::string& name, T** resource) const { + CheckDeriveFromResourceBase(); + tf_shared_lock l(mu_); + return LookupInternal(container, name, resource); +} + +template +absl::Status ResourceMgr::LookupMany( + absl::Span const> + containers_and_names, + std::vector>* resources) const { + CheckDeriveFromResourceBase(); + tf_shared_lock l(mu_); + resources->resize(containers_and_names.size()); + for (size_t i = 0; i < containers_and_names.size(); ++i) { + T* resource; + absl::Status s = LookupInternal( + *containers_and_names[i].first, *containers_and_names[i].second, + &resource); + if (s.ok()) { + (*resources)[i].reset(resource); + } + } + return absl::OkStatus(); +} + +// Simple wrapper to allow conditional dynamic / static casts. +template +struct TypeCastFunctor { + static T* Cast(ResourceBase* r) { return static_cast(r); } +}; + +template +struct TypeCastFunctor { + static T* Cast(ResourceBase* r) { return dynamic_cast(r); } +}; + +template +absl::Status ResourceMgr::LookupInternal(const std::string& container, + const std::string& name, + T** resource) const { + ResourceBase* found = nullptr; + absl::Status s = DoLookup(container, TypeIndex::Make(), name, &found); + if (s.ok()) { + // It's safe to down cast 'found' to T* since + // typeid(T).hash_code() is part of the map key. 
+ *resource = TypeCastFunctor::Cast(found); + } + return s; +} + +template +absl::Status ResourceMgr::LookupOrCreate( + const std::string& container, const std::string& name, T** resource, + std::function creator) { + CheckDeriveFromResourceBase(); + *resource = nullptr; + absl::Status s; + { + tf_shared_lock l(mu_); + s = LookupInternal(container, name, resource); + if (s.ok()) return s; + } + mutex_lock l(mu_); + s = LookupInternal(container, name, resource); + if (s.ok()) return s; + TF_RETURN_IF_ERROR(creator(resource)); + s = DoCreate(container, TypeIndex::Make(), name, *resource, + /* owns_resource */ true); + if (!s.ok()) { + return errors::Internal("LookupOrCreate failed unexpectedly"); + } + (*resource)->Ref(); + return s; +} + +template +absl::Status ResourceMgr::Delete(const std::string& container, + const std::string& name) { + CheckDeriveFromResourceBase(); + return DoDelete(container, TypeIndex::Make(), name); +} + +template +absl::Status GetResourceFromContext(OpKernelContext* ctx, + const std::string& input_name, + T** resource) { + DataType dtype; + TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &dtype)); + if (dtype == DT_RESOURCE) { + const Tensor* handle; + TF_RETURN_IF_ERROR(ctx->input(input_name, &handle)); + return LookupResource(ctx, handle->scalar()(), resource); + } + std::string container; + std::string shared_name; + { + mutex* mu; + TF_RETURN_IF_ERROR(ctx->input_ref_mutex(input_name, &mu)); + mutex_lock l(*mu); + Tensor tensor; + TF_RETURN_IF_ERROR(ctx->mutable_input(input_name, &tensor, true)); + if (tensor.NumElements() != 2) { + return errors::InvalidArgument( + "Resource handle must have 2 elements, but had shape: ", + tensor.shape().DebugString()); + } + container = tensor.flat()(0); + shared_name = tensor.flat()(1); + } + return ctx->resource_manager()->Lookup(container, shared_name, resource); +} + +namespace internal { + +absl::Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p); + +template +absl::Status ValidateDeviceAndType(OpKernelContext* ctx, + const ResourceHandle& p) { + TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); + TF_RETURN_IF_ERROR(p.ValidateType()); + return absl::OkStatus(); +} + +} // namespace internal + +// Creates the resource pointed at by "p". The caller transfers the ownership of +// one ref on "*value" to the resource manager in "ctx", regardless of whether +// this operation succeeds or fails. +template +absl::Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, + T* value) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + return ctx->resource_manager()->Create(p.container(), p.name(), value); +} + +// Finds the resource as "*value" from the handle. If the handle is +// ref-counting, returns the resource owned by the handle. Otherwise, looks up +// the resource matching "p" from resource manager associated with ctx. +// Always returns a new reference to the resource in "*value". The caller shall +// call (*value)->Unref(). +template +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + T** value) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + if (p.IsRefCounting()) { + TF_ASSIGN_OR_RETURN(*value, p.GetResource()); + // Transfers out a new reference. + (*value)->Ref(); + return absl::OkStatus(); + } + + return ctx->resource_manager()->Lookup(p.container(), + p.name(), value); +} + +// Finds the resource as "*value" from the handle. This is a type-erased +// variant of LookupResource above. 
+absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + ResourceBase** value); + +// If the resource manager in "ctx" has a resource matching "p", returns it in +// "*value". +template +absl::Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, + core::RefCountPtr* value) { + T* raw_ptr = nullptr; + TF_RETURN_IF_ERROR(LookupResource(ctx, p, &raw_ptr)); + value->reset(raw_ptr); + + return absl::OkStatus(); +} + +// Similar to Lookup, but looks up multiple resources at once, with only a +// single lock acquisition. +template +absl::Status LookupResources(OpKernelContext* ctx, + absl::Span p, + std::vector>* values) { + std::vector> containers_and_names( + p.size()); + for (size_t i = 0; i < p.size(); ++i) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, *p[i])); + containers_and_names[i] = {&p[i]->container(), &p[i]->name()}; + } + return ctx->resource_manager()->LookupMany(containers_and_names, values); +} + +// If the resource manager in "ctx" has a resource pointed at by "p", returns +// it in "*value". Otherwise, invokes creator() to create the resource. +// The caller takes the ownership of one ref on "*value". +// +// WARNING: creator() must not call any methods on the resource manager during +// its execution, because a non-reentrant lock is held during the creator() call +// in order to guarantee atomicity of LookupOrCreateResource(). +template +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, T** value, + std::function creator) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + return ctx->resource_manager()->LookupOrCreate(p.container(), p.name(), value, + creator); +} + +// If the resource manager in "ctx" has a resource pointed at by "p", returns +// it in "*value". Otherwise, invokes creator() to create the resource. +// +// WARNING: creator() must not call any methods on the resource manager during +// its execution, because a non-reentrant lock is held during the creator() call +// in order to guarantee atomicity of LookupOrCreateResource(). +template +absl::Status LookupOrCreateResource(OpKernelContext* ctx, + const ResourceHandle& p, + core::RefCountPtr* value, + std::function creator) { + T* raw_ptr = nullptr; + TF_RETURN_IF_ERROR(LookupOrCreateResource(ctx, p, &raw_ptr, creator)); + value->reset(raw_ptr); + + return absl::OkStatus(); +} + +// Deletes the resource pointed by "p", using the resource manager in "ctx". +template +absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType(ctx, p)); + // This is a noop because ResourceMgr does not hold a reference. + // NOTE(feyu): if we can convert all resources handle to ref-counting, then + // DeleteResource can be removed. + if (p.IsRefCounting()) { + return absl::OkStatus(); + } + return ctx->resource_manager()->Delete(p.container(), p.name()); +} + +// Deletes the resource pointed by "p", using the resource manager in "ctx". 
+absl::Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p); + +template +void IsResourceInitialized::Compute(OpKernelContext* ctx) { + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &output)); + T* object; + bool found; + if (LookupResource(ctx, HandleFromInput(ctx, 0), &object).ok()) { + found = true; + object->Unref(); + } else { + found = false; + } + + output->flat()(0) = found; +} + +template +ResourceHandleOp::ResourceHandleOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("container", &container_)); + OP_REQUIRES_OK(context, context->GetAttr("shared_name", &name_)); +} + +template +void ResourceHandleOp::Compute(OpKernelContext* ctx) { + if (name_ == ResourceHandle::ANONYMOUS_NAME) { + AllocatorAttributes attr; + attr.set_on_host(true); + Tensor handle; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr)); + handle.scalar()() = MakeResourceHandle( + ctx, container_, name_, /*dtypes_and_shapes=*/{}, ctx->stack_trace()); + ctx->set_output(0, handle); + } else { + if (!initialized_.load()) { + mutex_lock ml(mutex_); + // Checking again to see if another thread has initialized the resource. + if (!initialized_.load()) { + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), + &resource_, attr)); + resource_.scalar()() = + MakeResourceHandle(ctx, container_, name_, + /*dtypes_and_shapes=*/{}, ctx->stack_trace()); + initialized_.store(true); + } + } + ctx->set_output(0, resource_); + } +} + +template +ResourceHandlesOp::ResourceHandlesOp(OpKernelConstruction* context) + : OpKernel(context) { + int n; + OP_REQUIRES_OK(context, context->GetAttr("N", &n)); + OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_)); + OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_)); + OP_REQUIRES( + context, containers_.size() == n, + errors::InvalidArgument("Number of containers (", containers_.size(), + ") must be equal to N (", n, ")")); + OP_REQUIRES(context, names_.size() == n, + errors::InvalidArgument("Number of names (", containers_.size(), + ") must be equal to N (", n, ")")); + resources_.resize(n); +} + +template +void ResourceHandlesOp::Compute(OpKernelContext* ctx) { + if (!initialized_.load()) { + mutex_lock ml(mutex_); + // Checking again to see if another thread has initialized the resource. 
+ if (!initialized_.load()) { + AllocatorAttributes attr; + attr.set_on_host(true); + for (size_t i = 0; i < resources_.size(); ++i) { + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), + &resources_[i], attr)); + ResourceHandle h = + MakeResourceHandle(ctx, containers_[i], names_[i]); + resources_[i].template scalar()() = h; + } + initialized_.store(true); + } + } + for (size_t i = 0; i < resources_.size(); ++i) { + ctx->set_output(i, resources_[i]); + } +} + +template +ResourceHandle ScopedStepContainer::MakeResourceHandle( + const std::string& name, const DeviceBase& device) { + mutex_lock ml(mu_); + dirty_ = true; + return tensorflow::MakeResourceHandle(container_, name, device, + TypeIndex::Make(), {}); +} + +template +absl::Status ScopedStepContainer::Lookup(ResourceMgr* rm, + const std::string& name, + T** resource) const { + return rm->Lookup(container_, name, resource); +} + +template +absl::Status ScopedStepContainer::LookupOrCreate( + ResourceMgr* rm, const std::string& name, T** resource, + std::function creator) { + mutex_lock ml(mu_); + dirty_ = true; + return rm->LookupOrCreate(container_, name, resource, creator); +} + +template +absl::Status ScopedStepContainer::Create(ResourceMgr* rm, + const std::string& name, T* resource) { + mutex_lock ml(mu_); + dirty_ = true; + return rm->Create(container_, name, resource); +} + +template +absl::Status ScopedStepContainer::Delete(ResourceMgr* rm, + const std::string& name) { + return rm->Delete(container_, name); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/resource_op_kernel.h b/third_party/tflite-hdrs/tensorflow/core/framework/resource_op_kernel.h new file mode 100644 index 00000000..9982c02f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/resource_op_kernel.h @@ -0,0 +1,153 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// ResourceOpKernel is a virtual base class for resource op implementing +// interface type T. The inherited op looks up the resource name (determined by +// ContainerInfo), and creates a new resource if necessary. +// +// Requirements: +// - Op must be marked as stateful. +// - Op must have `container` and `shared_name` attributes. 
Empty `container` +// means using the default container. Empty `shared_name` means private +// resource. +// - Subclass must override CreateResource(). +// - Subclass is encouraged to override VerifyResource(). +template +class ResourceOpKernel : public OpKernel { + public: + explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) { + has_resource_type_ = (context->output_type(0) == DT_RESOURCE); + if (!has_resource_type_) { + // The resource variant of the op may be placed on non-CPU devices, but + // this allocation is always on the host. Fortunately we don't need it in + // the resource case. + OP_REQUIRES_OK(context, context->allocate_temp( + DT_STRING, TensorShape({2}), &tensor_)); + } + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~ResourceOpKernel() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete(cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + core::RefCountPtr resource_ref_ptr = weak_resource_.GetNewRef(); + if (resource_ref_ptr == nullptr) { + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + T* resource; + OP_REQUIRES_OK(context, + mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this](T** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status s = CreateResource(ret); + if (!s.ok() && *ret != nullptr) { + CHECK((*ret)->Unref()); + } + return s; + })); + // Here the code releases the reference to the resource created by this op + // and only holds a WeakPtr to the resource. This way the lifetime of the + // resource is owned by the container; otherwise the container may be + // cleared (e.g. a Session::Reset()) but the resource lives on inside this + // op, causing later lookups in the container by handle to fail. + core::ScopedUnref resource_unref(resource); + OP_REQUIRES_OK(context, VerifyResource(resource)); + weak_resource_ = core::WeakPtr(resource); + // TODO(b/243544755): delete after scam migrates ResourceKernelOp + // subclasses to get_resource() in TF 2.11. + resource_ = resource; + + if (!has_resource_type_) { + auto h = tensor_.template flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + } + } + if (has_resource_type_) { + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + TypeIndex::Make())); + } else { + context->set_output_ref(0, &mu_, &tensor_); + } + } + + protected: + // Variables accessible from subclasses. + mutex mu_; + ContainerInfo cinfo_ TF_GUARDED_BY(mu_); + // TODO(b/243544755): delete after scam migrates ResourceKernelOp subclasses + // to get_resource() in TF 2.11. + ABSL_DEPRECATED("Use get_resource() instead.") + T* resource_ TF_GUARDED_BY(mu_) = nullptr; + + core::RefCountPtr get_resource() TF_LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + return weak_resource_.GetNewRef(); + } + + private: + core::WeakPtr weak_resource_ TF_GUARDED_BY(mu_) = + core::WeakPtr(nullptr); + + // Must return a T descendant allocated with new that ResourceOpKernel will + // take ownership of. 
+ virtual absl::Status CreateResource(T** resource) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; + + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + virtual absl::Status VerifyResource(T* resource) { return absl::OkStatus(); } + + Tensor tensor_ TF_GUARDED_BY(mu_); + + // Is the output of the operator of type DT_RESOURCE? + bool has_resource_type_; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/resource_var.h b/third_party/tflite-hdrs/tensorflow/core/framework/resource_var.h new file mode 100644 index 00000000..6c0a8d96 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/resource_var.h @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_ + +#include + +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; + +namespace tensorflow { + +// Resource stored by variables in the resource manager (new, resource-style +// version). +// +// These variables have a mixed access mode: they can operate on copy-on-write +// mode (the default) or copy-on-read mode (used only for sparse access). +// +// When copy-on-write mode is enabled reading the value of the variable involves +// grabbing its mutex in shared mode and aliasing the internal tensor as the +// output of the read operation, increasing its reference count. Writing, +// conversely, works by, under an exclusive lock, detecting whether there are +// outstanding aliases of the tensor, using the reference count, copying the +// tensor if they exist, and writing to either the original or a copy with no +// outstanding aliases. Sparse operations are not supported in copy-on-write +// mode. +// +// When a variable is accessed sparsely it switches to copy-on-read mode. To +// switch we need to grab an exclusive lock and might (if there are aliases) +// need to copy the entire tensor. Once copy-on-read mode is enabled, no tensor +// is allowed to alias the variable's internal tensor. This means dense reads +// must return a copy of the variable, done while holding a shared lock. Dense +// writes do not need to check whether aliases exist, and can always write +// directly to the buffer without making a copy, while holding an exclusive +// lock. 
Sparse reads and sparse writes, on the other hand, can be done under a +// shared or exclusive mutex (the damage from writes under a shared mutex is +// limited since no other buffer is allowed to alias the variable's +// buffer). Using an exclusive mutex disallows concurrent writes and concurrent +// sparse reads, providing some extra safety at the expense of performance, +// while shared mutex allow for "hogwild" behavior. Doing sparse writes under a +// shared mutex prevents them from overlapping with dense writes, which is +// necessary as dense writes can change the shape the of the tensor. +// +// Transitioning a variable from copy-on-read mode to copy-on-write mode is +// currently not supported. To upgrade a variable from copy-on-write to +// copy-on-read use `EnsureSparseVariableAccess()`, and then grab the variable's +// mutex as desired. To access the variable in dense mode grab the mutex either +// directly or via `MaybeLockVariableInputMutexesInOrder` on all variables being +// modified and then call `PrepareToUpdateVariable` on them in any order. +class Var : public ResourceBase { + public: + explicit Var(DataType dtype) : tensor_(dtype) {} + explicit Var(DataType dtype, std::string& debug_name) : tensor_(dtype) { + debug_name_ = debug_name; + } + + // When locking multiple variables, the locks must be acquired in order of + // increasing mu() address. + // TODO(ebrevdo): Use LockSet instead of exposing mu. + mutex* mu() { return &mu_; } + Tensor* tensor() { return &tensor_; } + + // Uninitializes the variable, by reverting the state of the tensor to + // the state when the variable is first created. + void Uninitialize() { + // move frees the buffer of the tensor after unused goes out of scope. + Tensor unused = std::move(tensor_); + is_initialized = false; + } + + absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override; + + std::string DebugString() const override { + return strings::StrCat(DataTypeString(tensor_.dtype()), "/", + tensor_.shape().DebugString()); + } + + std::string MakeRefCountingHandleName(int64_t resource_id) const override; + + // Only used in the resource variable path. In resource variables, + // tensor.IsInitialized() can be true (i.e. have memory allocated to it) while + // there is not a good value there due to a race condition, and it's possible + // to stumble upon this during variable.initialized_value(). So it's best to + // just store directly whether the variable is initialized. + bool is_initialized = false; // TF_GUARDED_BY(mu_) but annotalysis doesn't + // like it. + + // Also fake-guarded by mu_. Should be set to True whenever any sparse + // operation uses the variable. Once this is true no tensor is allowed to + // alias the memory of the variable, and we always copy the variable on + // reads. This allows sparse operations to happen with only a shared lock if + // so desired. + std::atomic copy_on_read_mode{false}; + + private: + mutex mu_; + Tensor tensor_; + std::string debug_name_; + + ~Var() override {} + Var(const Var&) = delete; + void operator=(const Var&) = delete; +}; + +// Does unlock and unref automatically when going out of scope, and also +// supports early manual release. 
+class TF_SCOPED_LOCKABLE ScopedUnlockUnrefVar { + public: + explicit ScopedUnlockUnrefVar(Var* var) TF_EXCLUSIVE_LOCK_FUNCTION(var_->mu()) + : var_(var) { + if (var_) { + var_->mu()->lock(); + } + } + void Release() TF_UNLOCK_FUNCTION() { + if (var_) { + var_->mu()->unlock(); + var_->Unref(); + var_ = nullptr; + } + } + ~ScopedUnlockUnrefVar() TF_UNLOCK_FUNCTION() { Release(); } + + private: + Var* var_; + + ScopedUnlockUnrefVar(const ScopedUnlockUnrefVar&) = delete; + ScopedUnlockUnrefVar(ScopedUnlockUnrefVar&&) = delete; + ScopedUnlockUnrefVar& operator=(const ScopedUnlockUnrefVar&) = delete; + ScopedUnlockUnrefVar& operator=(ScopedUnlockUnrefVar&&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/rng_alg.h b/third_party/tflite-hdrs/tensorflow/core/framework/rng_alg.h new file mode 100644 index 00000000..fd756c87 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/rng_alg.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_ + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +enum Algorithm { + // The Philox algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + RNG_ALG_PHILOX = 1, + // The ThreeFry algorithm, as described in paper + // ['Parallel Random Numbers: As Easy as 1, 2, 3'] + // (https://www.thesalmons.org/john/random123/papers/random123sc11.pdf) + RNG_ALG_THREEFRY = 2, + // An algorithm auto-selected by the system according to device type. + RNG_ALG_AUTO_SELECT = 3 +}; + +// Same as `Algorithm`, but without AUTO_SELECT. We use C++ compiler's -Wswitch +// and -Werror to check that `switch` covers all cases. When the algorithm +// auto-selection has been resolved, we use this type so that +// we don't need to (unnecessarily) handle the AUTO_SELECT case. +enum class ConcreteRngAlgorithm { + RNG_ALG_PHILOX = 1, + RNG_ALG_THREEFRY = 2, +}; + +// Gets the counter size (in unit of uint64) for a counter-based RNG +// algorithm `alg`. Callers of this function must ensure that `alg` doesn't have +// non-enumerator values. 
+inline int GetCounterSize(ConcreteRngAlgorithm alg) { + switch (alg) { + case ConcreteRngAlgorithm::RNG_ALG_PHILOX: + return 2; + case ConcreteRngAlgorithm::RNG_ALG_THREEFRY: + return 1; + } + LOG(ERROR) << "This point shouldn't have been reached."; +} +static constexpr int RNG_MAX_COUNTER_SIZE = 2; + +static constexpr int RNG_KEY_SIZE = 1; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_RNG_ALG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/run_handler.h b/third_party/tflite-hdrs/tensorflow/core/framework/run_handler.h new file mode 100644 index 00000000..148378bc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/run_handler.h @@ -0,0 +1,315 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/context.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace Eigen { +struct ThreadPoolDevice; +} + +namespace tensorflow { + +class RunHandler; + +// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers +// that can be used for tracking inter-op work for a given Session::Run(). +// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes +// 'active' when its unique_ptr is returned by Get() and is being used by a +// client. It becomes 'inactive' once more when its unique_ptr gets destroyed. +// +// Expected usage: +// +// * Create a single RunHandlerPool (say run_handler_pool_). +// +// * When a Session::Run() is invoked, obtain a handler by: +// auto handler = run_handler_pool_->Get(); +// +// * Use handler for scheduling all inter-op work by: +// handler->ScheduleInterOpClosure(closure); +// +// This class is thread safe. +class RunHandlerPool { + public: + explicit RunHandlerPool(int num_inter_op_threads); + + RunHandlerPool(int num_inter_op_threads, int num_intra_op_threads); + ~RunHandlerPool(); + + // Returns an inactive RunHandler from the pool. + // + // RunHandlers in RunHandlerPool are initially 'inactive'. + // A RunHandler becomes 'active' when its unique_ptr its returned by Get() + // and is being used by a client. It becomes 'inactive' once more when the + // unique_ptr is destroyed. + // + // Will block unless there is an inactive handler. + std::unique_ptr Get( + int64_t step_id = 0, int64_t timeout_in_ms = 0, + const RunOptions::Experimental::RunHandlerPoolOptions& options = + RunOptions::Experimental::RunHandlerPoolOptions()); + + // Get the priorities for active handlers. The return result is with the same + // order of the active handler list. 
+ std::vector GetActiveHandlerPrioritiesForTesting() const; + + private: + class Impl; + friend class RunHandler; + + std::unique_ptr impl_; +}; + +// RunHandler can be used to schedule inter/intra-op closures to run on a global +// pool shared across all Session::Run(s). The closures are enqueued to a +// handler specific queue, from which the work is stolen in a priority order +// (time of the Get() call). +// +// It can only be created via RunHandlerPool::Get(). +// +// This class can be used instead of directly scheduling closures on a global +// pool since it maintains a global view across all sessions and optimizes pool +// scheduling to improve (median and tail) latency. +// +// This class is thread safe. +class RunHandler { + public: + void ScheduleInterOpClosure(std::function fn); + thread::ThreadPoolInterface* AsIntraThreadPoolInterface(); + + ~RunHandler(); + + private: + class Impl; + friend class RunHandlerPool::Impl; + + explicit RunHandler(Impl* impl); + + Impl* impl_; // NOT OWNED. +}; + +namespace internal { + +// TODO(azaks): Refactor with thread:ThreadPool +class RunHandlerEnvironment { + typedef Thread EnvThread; + struct TaskImpl { + std::function f; + Context context; + uint64 trace_id; + }; + Env* const env_; + const ThreadOptions thread_options_; + const string name_; + + public: + struct Task { + std::unique_ptr f; + }; + + RunHandlerEnvironment(Env* env, const ThreadOptions& thread_options, + const string& name); + + EnvThread* CreateThread(std::function f, + const std::string& thread_name); + + Task CreateTask(std::function f); + + void ExecuteTask(const Task& t); +}; + +typedef typename RunHandlerEnvironment::Task Task; +typedef Eigen::RunQueue Queue; + +// To reduce cache misses, we use a doubly-linked list of Waiter structs and +// queue them in LIFO order rather than the FIFO order used by a single +// condition variable. 
+struct Waiter { + Waiter() { + next = this; + prev = this; + } + condition_variable cv; + mutex mu; + Waiter* next; + Waiter* prev; +}; + +class ThreadWorkSource { + public: + ThreadWorkSource(); + + ~ThreadWorkSource(); + + Task EnqueueTask(Task t, bool is_blocking); + + Task PopBlockingTask(); + + Task PopNonBlockingTask(int start_index, bool search_from_all_queue); + + void WaitForWork(int max_sleep_micros); + + int TaskQueueSize(bool is_blocking); + + int64_t GetTracemeId(); + + void SetTracemeId(int64_t value); + + void SetWaiter(uint64 version, Waiter* waiter, mutex* mutex); + + int64_t GetInflightTaskCount(bool is_blocking); + + void IncrementInflightTaskCount(bool is_blocking); + + void DecrementInflightTaskCount(bool is_blocking); + + unsigned NonBlockingWorkShardingFactor(); + + std::string ToString(); + + private: + struct NonBlockingQueue { + mutex queue_op_mu; + char pad[128]; + Queue queue; + }; + + int32 non_blocking_work_sharding_factor_; + Eigen::MaxSizeVector non_blocking_work_queues_; + + std::atomic blocking_inflight_; + std::atomic non_blocking_inflight_; + + Queue blocking_work_queue_; + mutex blocking_queue_op_mu_; + char pad_[128]; + mutex waiters_mu_; + Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_); + std::atomic traceme_id_; + + mutex run_handler_waiter_mu_; + uint64 version_ TF_GUARDED_BY(run_handler_waiter_mu_); + mutex* sub_thread_pool_waiter_mu_ TF_GUARDED_BY(run_handler_waiter_mu_); + Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_); +}; + +class RunHandlerThreadPool { + public: + struct PerThread { + constexpr PerThread() : pool(nullptr), thread_id(-1) {} + RunHandlerThreadPool* pool; // Parent pool, or null for normal threads. + int thread_id; // Worker thread index in pool. + }; + + RunHandlerThreadPool(int num_blocking_threads, int num_non_blocking_threads, + Env* env, const ThreadOptions& thread_options, + const string& name, + Eigen::MaxSizeVector* waiters_mu, + Eigen::MaxSizeVector* queue_waiters); + + ~RunHandlerThreadPool(); + + void Start(); + + void StartOneThreadForTesting(); + + void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking, + std::function fn); + + // Set work queues from which the thread 'tid' can steal its work. + // The request with start_request_idx will be attempted first. Other requests + // will be attempted in FIFO order based on their arrival time. + void SetThreadWorkSources( + int tid, int start_request_idx, uint64 version, + const Eigen::MaxSizeVector& thread_work_sources); + + PerThread* GetPerThread(); + + int CurrentThreadId() const; + + int NumThreads() const; + + int NumBlockingThreads() const; + + int NumNonBlockingThreads() const; + + void WorkerLoop(int thread_id, bool may_steal_blocking_work); + + // Search tasks from Requets range searching_range_start to + // searching_range_end. If there is no tasks in the search range and + // may_steal_blocking_work is true, then search from all requests. 
+ Task FindTask( + int searching_range_start, int searching_range_end, int thread_id, + int sub_thread_pool_id, int max_blocking_inflight, + bool may_steal_blocking_work, + const Eigen::MaxSizeVector& thread_work_sources, + bool* task_from_blocking_queue, ThreadWorkSource** tws); + + void WaitForWork(bool is_blocking, int thread_id, + int32_t max_blocking_inflight); + + void WaitForWorkInSubThreadPool(bool is_blocking, int sub_thread_pool_id); + + private: + struct ThreadData { + ThreadData(); + mutex mu; + uint64 new_version; + condition_variable sources_not_empty; + std::unique_ptr thread; + int current_index; + std::unique_ptr> + new_thread_work_sources TF_GUARDED_BY(mu); + + uint64 current_version; + // Should only be accessed by one thread. + std::unique_ptr> + current_thread_work_sources; + + int sub_thread_pool_id; + }; + + const int num_threads_; + const int num_blocking_threads_; + const int num_non_blocking_threads_; + Eigen::MaxSizeVector thread_data_; + internal::RunHandlerEnvironment env_; + std::atomic cancelled_; + string name_; + Eigen::MaxSizeVector* waiters_mu_; + Eigen::MaxSizeVector* queue_waiters_; + + bool use_sub_thread_pool_; + std::vector num_threads_in_sub_thread_pool_; + + // Threads in each sub thread pool will search tasks from the given + // start_request_percentage to end_request_percentage in a round robin + // fashion. + std::vector sub_thread_pool_start_request_percentage_; + std::vector sub_thread_pool_end_request_percentage_; +}; + +} // namespace internal + +} // end namespace tensorflow. + +#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/run_handler_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/run_handler_util.h new file mode 100644 index 00000000..c63583da --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/run_handler_util.h @@ -0,0 +1,78 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ + +#include +#include +#include + +namespace tensorflow { + +// Assign thread ranges to requests. +// Requests are numbered 0...num_active_requests-1, and +// threads are numbered 0...num_threads-1. +// On return, the range [start_vec->at(i), end_vec->at(i)) +// indicates the subrange of the threads available to request i. +// The ranges given to different requests may overlap. +// Lower numbered requests will tend to be assigned more threads. +// Thus, a client might associate older requests with lower +// array indices so they receive access to more threads. +// However, the routine ensures that each request is given access +// to at least min(min_threads_per_request, num_threads) threads. +// Every thread will be assigned to at least one request range, +// assuming there is at least one request. 
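The comment just above states the contract of ComputeInterOpSchedulingRanges, whose declaration opens the next hunk. The self-contained toy below is not TensorFlow's algorithm; it only demonstrates one scheme that satisfies the stated properties: lower-numbered (older) requests see wider thread ranges, every request sees at least min(min_threads_per_request, num_threads) threads, and every thread belongs to at least one range.

```c++
// Toy scheme (not TensorFlow's implementation) satisfying the contract
// documented above.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

void ToySchedulingRanges(int num_active_requests, int num_threads,
                         int min_threads_per_request,
                         std::vector<int>* start_vec,
                         std::vector<int>* end_vec) {
  const int min_width = std::min(min_threads_per_request, num_threads);
  start_vec->assign(num_active_requests, 0);  // every range starts at thread 0
  end_vec->assign(num_active_requests, num_threads);
  for (int i = 0; i < num_active_requests; ++i) {
    // Newer requests (larger i) get narrower ranges, never below min_width.
    (*end_vec)[i] = std::max(min_width, num_threads - i);
  }
}

int main() {
  std::vector<int> start, end;
  ToySchedulingRanges(/*num_active_requests=*/4, /*num_threads=*/8,
                      /*min_threads_per_request=*/2, &start, &end);
  for (std::size_t i = 0; i < start.size(); ++i)
    std::printf("request %zu -> threads [%d, %d)\n", i, start[i], end[i]);
  return 0;
}
```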
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads, + int min_threads_per_request, + std::vector* start_vec, + std::vector* end_vec); + +// Assign thread steal ranges to threads.Threads are numbered 0...num_threads-1. +// On return, the range [start_vec->at(i), end_vec->at(i)) indicates the steal +// range of the thread i. The ranges given to different threads may overlap. +void ComputeInterOpStealingRanges(int num_threads, int min_threads_per_domain, + std::vector* start_vec, + std::vector* end_vec); + +// For each of the num_threads determine the index of the active_request whose +// work queue should be attempted first by that the thread. Return a vector of +// size num_threads which represents how threads should be distributed across +// requests. +std::vector ChooseRequestsWithExponentialDistribution( + int num_active_requests, int num_threads); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. Return 'default_value' otherwise. +double ParamFromEnvWithDefault(const char* var_name, double default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. The value must be in format val1,val2... Return +// 'default_value' otherwise. +std::vector ParamFromEnvWithDefault(const char* var_name, + std::vector default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. The value must be in format val1,val2... Return +// 'default_value' otherwise. +std::vector ParamFromEnvWithDefault(const char* var_name, + std::vector default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. Return 'default_value' otherwise. +bool ParamFromEnvBoolWithDefault(const char* var_name, bool default_value); + +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/session_state.h b/third_party/tflite-hdrs/tensorflow/core/framework/session_state.h new file mode 100644 index 00000000..d102e153 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/session_state.h @@ -0,0 +1,91 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// The session state remembers the tensors we choose to keep across +// multiple run calls. +class SessionState { + public: + // Get a tensor from the session state. + absl::Status GetTensor(const std::string& handle, Tensor* tensor); + + // Store a tensor in the session state. 
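For the environment-variable helpers in run_handler_util.h above, a hedged usage sketch: the TF_EXAMPLE_* variable names are made up for illustration, and the element types of the vector overloads are assumed to be double and int, mirroring the scalar variants.

```c++
// Usage sketch; TF_EXAMPLE_* names are hypothetical, and the vector
// overloads are assumed to be std::vector<double> / std::vector<int>.
#include <vector>

#include "tensorflow/core/framework/run_handler_util.h"

void ReadTuningParams() {
  // Scalar: falls back to 128.0 if the variable is unset or unparsable.
  double queue_depth =
      tensorflow::ParamFromEnvWithDefault("TF_EXAMPLE_QUEUE_DEPTH", 128.0);
  // Comma-separated list, e.g. "1,2,4".
  std::vector<int> shards = tensorflow::ParamFromEnvWithDefault(
      "TF_EXAMPLE_SHARDS", std::vector<int>({1, 2, 4}));
  // Boolean variant.
  bool enabled =
      tensorflow::ParamFromEnvBoolWithDefault("TF_EXAMPLE_ENABLED", true);
  (void)queue_depth;
  (void)shards;
  (void)enabled;
}
```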
+ absl::Status AddTensor(const std::string& handle, const Tensor& tensor); + + // Delete a tensor from the session state. + absl::Status DeleteTensor(const std::string& handle); + + int64_t GetNewId(); + + static const char* kTensorHandleResourceTypeName; + + private: + mutex state_lock_; + + // For generating unique ids for tensors stored in the session. + int64_t tensor_id_ = 0; + + // The live tensors in the session. A map from tensor handle to tensor. + std::unordered_map tensors_; +}; + +// The tensor store remembers the tensors we choose to keep for the +// current run call. It is available to every op kernel. +class TensorStore { + public: + struct TensorAndKey { + Tensor tensor; + int64_t id; + std::string device_name; + + std::string GetHandle(const std::string& tensor_name) { + return strings::StrCat(tensor_name, ";", id, ";", device_name); + } + }; + + // Add the named tensor to the tensor store for this run. + absl::Status AddTensor(const std::string& name, const TensorAndKey& tk); + + // Save the tensors in the tensor store of this run to the session. + absl::Status SaveTensors(const std::vector& output_names, + SessionState* session_state); + + // Returns true if no tensors have been added to this store. + bool empty() TF_NO_THREAD_SAFETY_ANALYSIS { return !dirty_; } + + private: + mutex lock_; + std::atomic dirty_ TF_GUARDED_BY(lock_){false}; + + // The tensors that will be saved to session state when this run completes. + // A map from tensor string name to tensor. + std::unordered_map tensors_ TF_GUARDED_BY(lock_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/shape_inference.h b/third_party/tflite-hdrs/tensorflow/core/framework/shape_inference.h new file mode 100644 index 00000000..8bfd301d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/shape_inference.h @@ -0,0 +1,924 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +namespace grappler { +class GraphProperties; +class SymbolicShapeManager; +} // namespace grappler + +namespace shape_inference { + +struct DimensionOrConstant; +class InferenceContext; + +// This header contains the InferenceContext that is used to infer the shape of +// the results of an operation or flag an operation with invalid inputs (e.g., +// mismatched shapes for elementwise operation) by ShapeRefiner. 
The shape of an +// operation is computed using the OpShapeInferenceFn set via SetShapeFn in op +// registration. The OpShapeInferenceFn uses a per op InferenceContext populated +// with input shapes to compute resultant shape (including resource shapes). +// +// The shapes created in the InferenceContext are bound to the lifetime of the +// InferenceContext in which it was created. E.g., in +// +// ```c++ +// InferenceContext c; +// // Below a ShapeHandle is returned by MakeShape, while UnknownDim returns a +// // DimensionHandle. +// ShapeHandle in0 = c.MakeShape({10, c.UnknownDim()}); +// ``` +// +// the ShapeHandle `in0` (and the nested unknown dim inside) is only valid while +// `c` is in scope, as ShapeHandle and DimensionHandle are effectively +// wrappers around pointers stored inside the context with the lifetime of the +// value pointed to managed by the context. The result from one operation's +// inference context will be passed as input to the inference of consumer +// operations. Hence it is possible for ShapeHandles produced by inference on a +// node to consist of ShapeHandles owned by different InferenceContexts. While +// inferring the shapes of a Graph, the InferenceContext of all nodes/operations +// in the Graph remain resident for the lifetime of the Graph (e.g, there is a +// map from each node to its InferenceContext, technically its +// ExtendedInferencContext which additionally stores the element types of inputs +// & outputs, which remains resident). +// +// For functions, the body of the function is instantiated as a Graph while +// inferring the result shapes of a function call node. The rules above apply +// while the function's shape is being inferred, but the contexts associated +// with nodes in the function body are released once the function call's +// resultant shapes are inferred. The shapes of results returned by a function +// are propagated to the InferenceContext of the function call's op (which is +// associated with a Graph of nodes whose shape is being inferred) as the return +// values of a function call node are the inputs of its consumer, but the return +// values are produced by nodes inside the function whose InferenceContexts +// (which owns the values pointed to by ShapeHandle and DimensionHandle) are +// reclaimed after inferring function result shapes. Recursive user-defined +// function are not supported hence inference of functions are fully nested with +// the InferenceContext's of function calls forming a stack. +// +// For example, consider the following call and function: +// +// ```python +// @tf.function +// def g(st): +// d = tf.add(st, st) +// return d +// +// @tf.function +// def f(): +// st = tf.A() +// result = g(st) +// return h(result) +// ``` +// +// During inference of f, the shape of `A` will be inferred and the results from +// its InferenceContext used as inputs to function call `g(st)`. The call node +// will have an InferenceContext created (call it outer context) and the graph +// corresponding to function `g` will be instantiated. The result shape of the +// Arg nodes of the function will be associated with input from outer context. +// During inference of `g` (for the callsite `g(st)` in `f`), the +// InferenceContext of all nodes inside `g` will remain alive. Thus, when shape +// of `tf.add` is computed it may rely on all inputs. 
Once the RetVal nodes of a +// function is reached, we know the shape of its input may correspond to a shape +// queried in the outer context and it is explicitly copied to outer context. In +// this case that means that the shape of `d` is copied to the InferenceContext +// of `g(st)` and so when `h(result)` is executed this shape may be queried. +// Furthermore, no shapes computed due to call `g(st)` can be queried post this +// point and, as the RetVal shapes have been coppied into outer context, all +// InferenceContexts associated with nodes in function `g` instantiated for +// `g(st)` may be and are released. + +// Dimension values are accessed through InferenceContext. +class Dimension { + private: + Dimension(); + Dimension(int64_t value); + ~Dimension() {} + + const int64_t value_; + + friend class InferenceContext; + friend class ShapeManager; + Dimension(const Dimension&) = delete; + void operator=(const Dimension&) = delete; +}; + +class DimensionHandle { + public: + DimensionHandle() {} + bool SameHandle(DimensionHandle d) const { return ptr_ == d.ptr_; } + std::size_t Handle() const { return reinterpret_cast(ptr_); } + + private: + DimensionHandle(const Dimension* dim) { ptr_ = dim; } + + const Dimension* operator->() const { return ptr_; } + bool IsSet() const { return ptr_ != nullptr; } + + const Dimension* ptr_ = nullptr; + + friend struct DimensionOrConstant; + friend class InferenceContext; + friend class ShapeInferenceTest; + friend class ShapeInferenceTestutil; + friend class ::tensorflow::grappler::GraphProperties; + friend class ::tensorflow::grappler::SymbolicShapeManager; + + // Intentionally copyable. +}; + +// Shape rank and dimensions are accessed through InferenceContext. +class Shape { + private: + Shape(); + Shape(const std::vector& dims); + ~Shape() {} + + const int32 rank_; + const std::vector dims_; + + friend class InferenceContext; + friend class ::tensorflow::grappler::SymbolicShapeManager; + + Shape(const Shape&) = delete; + void operator=(const Shape&) = delete; +}; + +class ShapeHandle { + public: + ShapeHandle() {} + bool SameHandle(ShapeHandle s) const { return ptr_ == s.ptr_; } + std::size_t Handle() const { return reinterpret_cast(ptr_); } + + private: + ShapeHandle(const Shape* shape) { ptr_ = shape; } + const Shape* operator->() const { return ptr_; } + bool IsSet() const { return ptr_ != nullptr; } + + const Shape* ptr_ = nullptr; + + friend class InferenceContext; + friend class ShapeInferenceTest; + friend class ShapeInferenceTestutil; + friend class ::tensorflow::grappler::SymbolicShapeManager; + + // Intentionally copyable. +}; + +// Struct used to allow functions to take DimensionHandle or a dimension value. +// Not meant to be constructed directly. +struct DimensionOrConstant { + public: + // Intentionally not explicit. + DimensionOrConstant(DimensionHandle dim); + + // val must be non-negative or InferenceContext::kUnknownDim. + DimensionOrConstant(int64_t val); + + // dim takes precedence. If dim != nullptr, val is ignored. + DimensionHandle dim; + int64_t val; + + private: + DimensionOrConstant(); +}; + +struct ShapeAndType { + ShapeAndType() {} + ShapeAndType(ShapeHandle s, DataType t) : shape(s), dtype(t) {} + // TODO(mdan): Remove dtype from constructor, and use type_ instead. + // dtype is kept here for backward compatibiity. 
Its information should + // be redundant to that in type; + ShapeAndType(ShapeHandle s, DataType t, FullTypeDef type_) + : shape(s), dtype(t), type(type_) {} + + ShapeHandle shape; + DataType dtype = DT_INVALID; + FullTypeDef type; +}; + +// Shape inference functions registered on ops in REGISTER_OP implement +// their shape functions in terms of this InferenceContext. An InferenceContext +// is created by the framework and passed to a shape inference function. The +// shape inference function calls functions on the context, and should call +// set_output() to set the shape on all outputs. +// +// To infer shapes for user-defined functions see ShapeRefiner. +// +// All Shape* and Dimension* returned by functions of InferenceContext are owned +// by the InferenceContext. +class InferenceContext { + public: + static constexpr int64_t kUnknownDim = -1; + static constexpr int32_t kUnknownRank = -1; + + // is NULL-padded to be the same size as . + // + // Elements of are used for when a shape function + // makes a call to MakeShapeFromShapeTensor; in particular, when the + // input_tensors[i] is nullptr but the shape represented by it is partially + // known from analysis of the graph. + // can have fewer elements than . + // Values of do not need to outlive the context. + InferenceContext(int graph_def_version, const AttrSlice& attrs, + const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + std::vector>> + input_handle_shapes_and_types); + + // is NULL-padded to be the same size as . + // + // Elements of are used for when a shape + // function makes a call to MakeShapeFromShapeTensor; in particular, when + // the input_tensors[i] is nullptr but the shape represented by it is + // partially known from analysis of the graph. + // can have fewer elements than . Values of + // do not need to outlive the context. + InferenceContext( + int graph_def_version, const AttrSlice& attrs, const OpDef& op_def, + const std::vector& input_shapes, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes, + const std::vector>>>& + input_handle_shapes_and_types); + + ~InferenceContext(); + + // Runs the shape inference function 'fn' with 'this' as the + // argument, returns the status of the inference. + // + // On error, additional context is provided in the error message. + absl::Status Run( + const std::function& + fn); + + // Merge the stored shape of the input in position idx with according + // to the following rules: + // + // - If the ShapeHandles are the same or is unknown, there will be no + // change. Otherwise if the stored shape is unknown, the new shape will be + // . + // - If both shapes are known, then they must have the same rank. + // - For any one dimension, if the values for that dimension in both shapes + // are known, then the values must match. + // - If one shape has equal or more information than the other shape in every + // dimension, the new shape will become the shape with more information. + // - Example: merging [2,?] and [?,2] results in [2,2] + // - Example: [2,2] cannot be merged with [1,2] + // + // This requires idx to be in the [0, num_inputs) range. If the merge is + // successful, return true. Return false otherwise. 
+ bool MergeInput(int idx, ShapeHandle shape) { + ShapeHandle new_shape; + if (!Merge(inputs_[idx], shape, &new_shape).ok()) return false; + inputs_[idx] = new_shape; + return true; + } + + // Relax the stored shape of the input in position idx with according + // to the following rules: + // + // - If the ShapeHandles are the same then the stored shape will be returned. + // - If either of the ShapeHandles are unknown, then a new UnknownShape will + // be returned. A new shape must be returned because we cannot claim that + // the resulting shape is necessarily the same as either of the input + // shapes. + // - If the shapes both have known ranks but their ranks are different, a new + // UnknownShape will be returned. + // - For any one dimension, if the value for that dimension in either of the + // shapes is unknown, a new shape will be returned with a new UnknownDim in + // that dimension. + // - For any one dimension, if the values for that dimension in both shapes + // are known but do not match, a new shape will be returned with a new + // UnknownDim in that dimension. + // - If both shapes have the same known rank and match in every dimension, + // the stored shape will be returned. + // - Example: relaxing [2,?] and [?,2] results in [?,?] + // - Example: relaxing [2,2] and [3,2] results in [?,2] + // - Example: relaxing [2,2] with [1,2,3] results in ? + // + // This requires idx to be in the [0, num_inputs) range. If the relax is + // successful and the new shape differs from the old one, store the new + // shape and return true. Return false otherwise. + bool RelaxInput(int idx, ShapeHandle shape) { + ShapeHandle new_shape; + Relax(inputs_[idx], shape, &new_shape); + if (inputs_[idx].SameHandle(new_shape)) { + return false; + } + inputs_[idx] = new_shape; + return true; + } + + void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; } + + ShapeHandle input(int64_t idx) const { return inputs_[idx]; } + absl::Status input(absl::string_view input_name, + std::vector* output) const; + int num_inputs() const { return inputs_.size(); } + + // Returns the input tensor at index , or nullptr if the input tensor is + // not available at the time of shape inference. + const Tensor* input_tensor(int idx) { + // Mark that this idx was requested. + request_input_tensor(idx); + return input_tensors_[idx]; + } + + // Notifies the shape refiner that the value of the tensor at index + // is needed. The shape refiner tries to statically compute this tensor, + // and if successful re-runs the shape function with this tensor available + // in the call to 'input_tensor(idx)'. + void request_input_tensor(int idx) { requested_input_tensor_[idx] = true; } + + // Returns true iff input_tensor(idx) was called by the shape function. + bool requested_input_tensor(int idx) const { + return requested_input_tensor_[idx]; + } + + // Notifies the shape refiner that the value of the tensor at index + // as a partial shape is needed. The shape refiner tries to statically compute + // this, and if successful re-runs the shape function with the + // computed PartialTensorShape available in the call to + // 'MakeShapeFromShapeTensor(idx, handle)' or + // 'MakeShapeFromShapeTensorTreatScalarAsUnknownShape(idx, handle)'. + void request_input_tensor_as_partial_shape(int idx) { + requested_input_tensor_as_partial_shape_[idx] = true; + } + + // Returns true if MakeShapeFromInputTensor was called but the constant + // input_tensor was not present. 
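The Merge and Relax rules spelled out above lend themselves to a small standalone illustration. The toy below is not InferenceContext's implementation; it only encodes the same rules for the simple case where both ranks are known, using -1 for an unknown dimension.

```c++
// Toy encoding of the Merge/Relax dimension rules documented above
// (both ranks known; -1 means "unknown dimension"). Not TF code.
#include <cstddef>
#include <cstdio>
#include <optional>
#include <vector>

using ToyShape = std::vector<long long>;  // -1 == unknown dim

// Merge: both-known values must match; unknown takes the known value.
std::optional<ToyShape> MergeShapes(const ToyShape& a, const ToyShape& b) {
  if (a.size() != b.size()) return std::nullopt;  // ranks must match
  ToyShape out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i] == -1) out[i] = b[i];
    else if (b[i] == -1 || a[i] == b[i]) out[i] = a[i];
    else return std::nullopt;  // incompatible known values
  }
  return out;
}

// Relax: any disagreement or unknown becomes unknown.
ToyShape RelaxShapes(const ToyShape& a, const ToyShape& b) {
  if (a.size() != b.size()) return {};  // rank mismatch -> unknown shape
  ToyShape out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i)
    out[i] = (a[i] == b[i]) ? a[i] : -1;
  return out;
}

int main() {
  auto merged = MergeShapes({2, -1}, {-1, 2});  // -> [2, 2]
  auto relaxed = RelaxShapes({2, 2}, {3, 2});   // -> [?, 2]
  if (merged)
    std::printf("merged:  [%lld, %lld]\n", (*merged)[0], (*merged)[1]);
  std::printf("relaxed: [%lld, %lld]\n", relaxed[0], relaxed[1]);
  return 0;
}
```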
+ bool requested_input_tensor_as_partial_shape(int idx) const { + return requested_input_tensor_as_partial_shape_[idx]; + } + + void set_input_tensors(const std::vector& input_tensors) { + input_tensors_ = input_tensors; + } + + void set_input_tensors_as_shapes( + const std::vector& input_tensors_as_shapes) { + input_tensors_as_shapes_ = input_tensors_as_shapes; + } + + const std::vector& input_tensors_as_shapes() const { + return input_tensors_as_shapes_; + } + + ShapeHandle output(int64_t idx) const { return outputs_.at(idx); } + void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; } + absl::Status set_output(absl::string_view output_name, + const std::vector& shapes); + + int num_outputs() const { return outputs_.size(); } + ShapeHandle output(int idx) const { return outputs_.at(idx); } + absl::Status output(absl::string_view output_name, + std::vector* output) const; + + // Returns the value for attribute named `attr_name`. + absl::Status GetAttr(absl::string_view attr_name, + const AttrValue** attr_value) const { + return attrs_.Find(attr_name, attr_value); + } + const AttrValue* GetAttr(absl::string_view attr_name) const { + return attrs_.Find(attr_name); + } + + const FullTypeDef& ret_types() const { return ret_types_; } + + // idx can be negative for an offset from end of dimensions. + // idx must be in the range [-1 * s.rank, s.rank). + DimensionHandle Dim(ShapeHandle s, int64_t idx) { + if (!s.Handle() || s->rank_ == kUnknownRank) { + return UnknownDim(); + } + return DimKnownRank(s, idx); + } + // As above, but asserts that the rank of the shape is known. + static DimensionHandle DimKnownRank(ShapeHandle s, int64_t idx) { + CHECK_NE(s->rank_, kUnknownRank); + if (idx < 0) { + return s->dims_[s->dims_.size() + idx]; + } + return s->dims_[idx]; + } + + static int32 Rank(ShapeHandle s) { + return s.IsSet() ? s->rank_ : kUnknownRank; + } + static bool RankKnown(ShapeHandle s) { + return (s.IsSet() && (Rank(s) != kUnknownRank)); + } + static inline int64_t Value(DimensionOrConstant d) { + return d.dim.IsSet() ? d.dim->value_ : d.val; + } + static inline bool ValueKnown(DimensionOrConstant d) { + return Value(d) != kUnknownDim; + } + + // Fills the output proto with the shape defined by the handle. + // "proto" is expected to be empty prior to the call. + void ShapeHandleToProto(ShapeHandle handle, TensorShapeProto* proto); + TensorShapeProto ShapeHandleToProto(ShapeHandle handle); + + // Returns true if the rank and all dimensions of the Shape are known. + bool FullyDefined(ShapeHandle s); + + // Returns the total number of elements, or an unknown dimension for an + // incomplete shape. + DimensionHandle NumElements(ShapeHandle s); + + std::string DebugString(ShapeHandle s); + std::string DebugString(DimensionHandle d); + std::string DebugString(const ShapeAndType& shape_and_type); + std::string DebugString(absl::Span shape_and_types); + + // Describes the whole context, for debugging purposes. + std::string DebugString() const; + + // If has rank , or its rank is unknown, return OK and return + // the shape with asserted rank in <*out>. Otherwise return an error. + // + // Note that <*out> may be set to . + absl::Status WithRank(ShapeHandle shape, int64_t rank, ShapeHandle* out); + absl::Status WithRankAtLeast(ShapeHandle shape, int64_t rank, + ShapeHandle* out); + absl::Status WithRankAtMost(ShapeHandle shape, int64_t rank, + ShapeHandle* out); + + // If has value , or its value is unknown, returns OK and returns + // the dimension with asserted value in <*out>. 
Otherwise returns an error. + // + // Note that <*out> may be set to . + absl::Status WithValue(DimensionHandle dim, int64_t value, + DimensionHandle* out); + + // Merges and and returns the merged shape in <*out>. See + // 'MergeInput' function for full details and examples. + absl::Status Merge(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out); + + // Asserts that 's rank >= 's rank, and the first + // dimensions of are compatible with the dimensions of + // . + // Returns the merged results in <*s_out> and <*prefix_out>. + absl::Status MergePrefix(ShapeHandle s, ShapeHandle prefix, + ShapeHandle* s_out, ShapeHandle* prefix_out); + + // Merges and and returns the merged dimension in <*out>. If + // and have incompatible values, returns an error. + // + // Note that <*out> may be set to or . + absl::Status Merge(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out); + + // Returns in <*out> a sub-shape of with dimensions [start:]. + // can be negative to index from the end of the shape. If > + // rank of , then an empty subshape is returned. + absl::Status Subshape(ShapeHandle s, int64_t start, ShapeHandle* out); + + // Returns in <*out> a sub-shape of , with dimensions [start:end]. + // and can be negative, to index from the end of the shape. + // and are set to the rank of if > rank of . + absl::Status Subshape(ShapeHandle s, int64_t start, int64_t end, + ShapeHandle* out); + + // Returns in <*out> a sub-shape of , with dimensions [start:end:stride]. + // and can be negative, to index from the end of the shape. + // and are set to the rank of if > rank of . + // can be negative, to reverse the . + absl::Status Subshape(ShapeHandle s, int64_t start, int64_t end, + int64_t stride, ShapeHandle* out); + + // Returns in <*out> the result of appending the dimensions of to those + // of . + absl::Status Concatenate(ShapeHandle s1, ShapeHandle s2, ShapeHandle* out); + + // Returns in the shape from replacing with + // . + absl::Status ReplaceDim(ShapeHandle s, int64_t dim_index, + DimensionHandle new_dim, ShapeHandle* out); + + // Returns a new shape with the given dims. The returned value is owned by + // this context. + ShapeHandle MakeShape(const std::vector& dims); + ShapeHandle MakeShape(std::initializer_list dims); + + // Returns a new unknown shape. + ShapeHandle UnknownShape(); + + // Returns a shape with specified rank but unknown dims. + ShapeHandle UnknownShapeOfRank(int64_t rank); + + // Returns a new shape of zero dimensions. + ShapeHandle Scalar(); + + // Returns a new shape of one dimension. + ShapeHandle Vector(DimensionOrConstant dim); + + // Returns a new shape of two dimensions. + ShapeHandle Matrix(DimensionOrConstant dim1, DimensionOrConstant dim2); + + // Returns in a new shape whose dimension sizes come from input tensor + // . The tensor must be a 1-dimensional int32 or int64 tensor. If + // the input tensor is NULL, then an unknown shape is returned. + absl::Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); + + // Like the function above, but treats scalar values as unknown + // shapes. **NOTE** If the scalar is statically known, its value + // must be -1 or an error is returned. + absl::Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + int input_idx, ShapeHandle* out); + + // Returns in a new shape corresponding to . + absl::Status MakeShapeFromShapeProto(const TensorShapeProto& proto, + ShapeHandle* out); + + // Returns in a new shape corresponding to . 
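Putting the pieces above together, a typical shape function combines WithRank, Dim, Merge, and the shape constructors, then calls set_output. The sketch below follows the familiar MatMul pattern (no transpose attributes) and assumes the usual TF_RETURN_IF_ERROR macro is available; it is an illustration, not the registered MatMul shape function.

```c++
// Sketch of a shape function written against InferenceContext, loosely
// following the common MatMul pattern; assumes the standard shape-fn
// signature used with REGISTER_OP's SetShapeFn.
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/platform/errors.h"  // TF_RETURN_IF_ERROR

namespace example {
using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;

absl::Status MatMulLikeShape(InferenceContext* c) {
  ShapeHandle a, b;
  // Both inputs must be rank 2 (or have unknown rank).
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b));
  // Inner dimensions must agree.
  DimensionHandle inner;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(a, 1), c->Dim(b, 0), &inner));
  // Output is [rows(a), cols(b)].
  c->set_output(0, c->Matrix(c->Dim(a, 0), c->Dim(b, 1)));
  return absl::OkStatus();
}
}  // namespace example
```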
+ absl::Status MakeShapeFromPartialTensorShape( + const PartialTensorShape& partial_shape, ShapeHandle* out); + + // Returns in a new shape corresponding to . + absl::Status MakeShapeFromTensorShape(const TensorShape& shape, + ShapeHandle* out); + absl::StatusOr MakeShapeFromShapeTensor( + const TensorShape& shape); + + // Returns a new dimension of the given size. The returned value is owned by + // this context. + inline DimensionHandle MakeDim(DimensionOrConstant d) { + return shape_manager_.MakeDim(d); + } + + inline DimensionHandle UnknownDim() { return MakeDim(kUnknownDim); } + + // Returns in a scalar value from an input tensor . The input tensor + // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the + // input tensor is not NULL. + absl::Status GetScalarFromTensor(const Tensor* t, int64_t* val); + + // Returns in a scalar value from a 1D input tensor with int32 or + // int64 elements. Caller must ensure that the input tensor is not NULL. + absl::Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64_t* val); + + // Returns a new dimension whose value is given by a scalar input tensor. + // The input tensor must be in host memory, since it is dereferenced to get + // the value. + absl::Status MakeDimForScalarInput(int idx, DimensionHandle* out); + + // Returns a new dimension whose value is given by a scalar input tensor. + // This allows for a negative input dimension given the rank of a separate + // tensor. This rank can be negative if unknown. + // The input tensor must be in host memory, since it is dereferenced to get + // the value. + absl::Status MakeDimForScalarInputWithNegativeIndexing(int idx, + int input_rank, + DimensionHandle* out); + + // Look up the attr being evaluated with name attr_name and set *value to its + // value. If no attr with attr_name is found in def(), or the attr does not + // have a matching type, a non-ok status will be returned. + template + absl::Status GetAttr(absl::string_view attr_name, T* value) const; + + // Returns in the result of dividing by . + // Returns an error if is not positive or if + // and does not evenly divide . + absl::Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, + bool evenly_divisible, DimensionHandle* out); + + // Returns in the sum of and . + absl::Status Add(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the dimension that is minus . + absl::Status Subtract(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the product of and . + absl::Status Multiply(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the minimum of and . If either or + // is zero the results is zero. Otherwise, if either or + // is unknown the results is unknown. + absl::Status Min(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + // Returns in the maximum of and . If either or + // is unknown the results is unknown. + absl::Status Max(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); + + absl::Status construction_status() const { return construction_status_; } + + // Methods to propagate shape and dtype on edges of handles. Handles are the + // dtype DT_RESOURCE which can be used to access state stored in a + // ResourceManager. 
When ops (such as variables) consume these handles to + // produce tensors they might need to know side-information about the shapes + // and dtypes of tensors which can be accessed via the handle. These methods + // propagate that information. Output handle dtypes and shapes are ignored if + // the output tensor is not of type DT_RESOURCE. + + // Merge the stored shapes and types corresponding to the input handle in + // position idx with the specified shapes and types. This requires idx to be + // in the [0, num_inputs) range. + // + // If the merge is successful and any of the new shapes differs from the old + // one, or any of the old dtypes was DT_INVALID, store the new shapes and + // return true. Return false otherwise. + // + // See 'MergeInput' function for full details and examples. + bool MergeInputHandleShapesAndTypes( + int idx, + const std::vector& shapes_and_types) TF_MUST_USE_RESULT; + + // As MergeInputHandleShapesAndTypes, but for an output. + bool MergeOutputHandleShapesAndTypes( + int idx, + const std::vector& shapes_and_types) TF_MUST_USE_RESULT; + + // Relaxes the stored shapes and types corresponding to the input handle in + // position idx with the specified shapes and types. This requires idx to be + // in the [0, num_inputs) range. + // + // If the relax is successful (sizes are the same, old dtypes match new ones + // or are DT_INVALID), then store the relaxed shapes and return true. + // Return false otherwise. + // + // See 'RelaxInput' function for full details and examples. + bool RelaxInputHandleShapesAndMergeTypes( + int idx, + const std::vector& shapes_and_types) TF_MUST_USE_RESULT; + + // As RelaxInputHandleShapesAndTypes, but for an output. + bool RelaxOutputHandleShapesAndMergeTypes( + int idx, + const std::vector& shapes_and_types) TF_MUST_USE_RESULT; + + void set_input_handle_shapes_and_types( + int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; + input_handle_shapes_and_types_[idx] = + absl::make_unique>(shapes_and_types); + } + + // Returns the output handle shapes and types, for the resource tensor output + // at index . Returns NULL if the shape and types were never set. + const std::vector* output_handle_shapes_and_types(int idx) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " outputs."; + return output_handle_shapes_and_types_[idx].get(); + } + + // Returns the inputs handle shapes and types, for the resource tensor input + // at index . Returns NULL if the shape and types were not available. + const std::vector* input_handle_shapes_and_types(int idx) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; + return input_handle_shapes_and_types_[idx].get(); + } + + void set_output_handle_shapes_and_types( + int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. 
Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " inputs."; + output_handle_shapes_and_types_[idx] = + absl::make_unique>(shapes_and_types); + } + + // Note that shape functions should usually call MakeShapeFromShapeTensor, + // as it does more analysis to provide partial shapes. + // + // Returns in a new shape whose dimension sizes come from tensor . + // The tensor must be a 1-dimensional int32 or int64 tensor. If is NULL, + // then an unknown shape is returned. + absl::Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, + ShapeHandle* out); + + int graph_def_version() const { return graph_def_version_; } + + const std::vector>& MergedShapes() const { + return merged_shapes_; + } + const std::vector>& MergedDims() + const { + return merged_dims_; + } + + // Adds new outputs; useful when mutating the graph. + absl::Status ExpandOutputs(int new_output_size); + + private: + // Creates and stores shapes for use in InferenceContext. + class ShapeManager { + public: + ShapeManager(); + ~ShapeManager(); + + // Returns a new shape with the given dims. The returned value is owned by + // this class. + ShapeHandle MakeShape(const std::vector& dims); + + // Returns a new unknown shape. + ShapeHandle UnknownShape(); + + // Returns a new dimension of the given size. The returned value + // is owned by this class. + inline DimensionHandle MakeDim(DimensionOrConstant d) { + if (d.dim.IsSet()) { + return d.dim; + } else { + all_dims_.push_back(new Dimension(d.val)); + return all_dims_.back(); + } + } + + private: + std::vector all_shapes_; // values are owned. + std::vector all_dims_; // values are owned. + }; + + friend class ::tensorflow::grappler::GraphProperties; + + friend class ShapeInferenceTest; // For testing Relax functions. + friend class ShapeInferenceTestutil; // For testing shapes. + + // Shared initialization across the two constructors. Remove + // once we get rid of one of them. + void PreInputInit(const OpDef& op_def, + const std::vector& input_tensors, + const std::vector& input_tensors_as_shapes); + void PostInputInit(std::vector>> + input_handle_data); + + absl::Status ReturnUnknownShape(ShapeHandle* out) { + *out = UnknownShape(); + return absl::OkStatus(); + } + absl::Status ReturnCreatedShape(const std::vector& dims, + ShapeHandle* out) { + *out = MakeShape(dims); + return absl::OkStatus(); + } + + // Adds additional context to the given status. + absl::Status AttachContext(const absl::Status& status); + + // Relaxes an existing value with a new value and returns the + // relaxed dimension in <*out>. If and have incompatible + // values, returns an error. + // + // Note that <*out> may be set to or . + void Relax(DimensionHandle d_old, DimensionHandle d_new, + DimensionHandle* out); + // Relaxes an existing shape with a new shape and returns the + // relaxed shape in <*out>. See 'RelaxInput' function for full details and + // examples. + void Relax(ShapeHandle s_old, ShapeHandle s_new, ShapeHandle* out); + + // Used to implement MergeInputHandleShapesAndTypes and + // MergeOutputHandleShapesAndTypes. + bool MergeHandleShapesAndTypes( + const std::vector& shapes_and_types, + std::vector* to_update) TF_MUST_USE_RESULT; + // Used to implement RelaxInputHandleShapesAndMergeTypes and + // RelaxOutputHandleShapesAndMergeTypes. 
+ bool RelaxHandleShapesAndMergeTypes( + const std::vector& shapes_and_types, + std::vector* to_update) TF_MUST_USE_RESULT; + + // Forget all the previous merged shapes and dims. + void ForgetMerges() { + merged_shapes_.clear(); + merged_dims_.clear(); + } + + // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor. + absl::Status InternalMakeShapeFromTensor( + bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t, + ShapeHandle tensor_shape, ShapeHandle* out); + + ShapeManager shape_manager_; + + // inputs_, outputs_, and input_tensors_as_shapes_ refer to values from + // `shape_manager_`. + std::vector inputs_; + std::vector input_tensors_; + std::vector requested_input_tensor_; + std::vector outputs_; + // Can have fewer elements than inputs_. + std::vector input_tensors_as_shapes_; + std::vector requested_input_tensor_as_partial_shape_; + + // input_handle_shapes_and_types_[i] is the list of shape/type pairs available + // through the resource handle passed along input i of the node. + // + // Values may be NULL. + std::vector>> + input_handle_shapes_and_types_; + + // output_handle_shapes_and_types_[i] is the list of shape/type pairs + // available through the resource handle passed along output i of the node. + // + // Values may be NULL. + std::vector>> + output_handle_shapes_and_types_; + + // Return types for the node this context is associated with. This information + // is to eventually consolidate all the dtype and shape info, allowing for + // output_handle_shapes_and_types_ to be removed. + FullTypeDef ret_types_; + + const int graph_def_version_; + AttrSlice attrs_; + NameRangeMap input_name_map_; + NameRangeMap output_name_map_; + + // An error set during construction. TODO(cwhipkey): remove when test + // constructor is removed. + absl::Status construction_status_; + + // Pair of shape or dim handles that are equivalent, ie that represent the + // same underlying shape of dimension. Note that for each pair at least one of + // the handles must contain an unknown shape, since we don't keep track of + // known shapes or dims here. 
+ std::vector> merged_shapes_; + std::vector> merged_dims_; + + InferenceContext(const InferenceContext&) = delete; + void operator=(const InferenceContext&) = delete; +}; + +// ----------------------------------------------------------------------------- +// Template and inline method implementations, please ignore + +inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} +inline Dimension::Dimension(int64_t value) : value_(value) { + DCHECK(value >= 0 || value == InferenceContext::kUnknownDim) + << "Dimension must be non-negative or equal to " + "InferenceContext::kUnknownDim but got " + << value; +} + +inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {} +inline Shape::Shape(const std::vector& dims) + : rank_(dims.size()), dims_(dims) {} + +inline DimensionOrConstant::DimensionOrConstant(DimensionHandle dim) + : dim(dim) { + DCHECK(dim.IsSet()) << "Internal error: Got nullptr for Dimension."; +} + +inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) { + DCHECK(val >= 0 || val == InferenceContext::kUnknownDim) + << "Dimension must be non-negative or equal to " + "InferenceContext::kUnknownDim but got " + << val; +} + +template +absl::Status InferenceContext::GetAttr(absl::string_view attr_name, + T* value) const { + return GetNodeAttr(attrs_, attr_name, value); +} + +} // namespace shape_inference +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/shape_inference_testutil.h b/third_party/tflite-hdrs/tensorflow/core/framework/shape_inference_testutil.h new file mode 100644 index 00000000..c9b9bd74 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/shape_inference_testutil.h @@ -0,0 +1,103 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ + +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/version.h" + +// Contains utilities for writing tests for shape inference functions. + +namespace tensorflow { + +class Tensor; + +struct ShapeInferenceTestOp { + typedef std::pair ShapeAndType; + explicit ShapeInferenceTestOp(absl::string_view name) : name(string(name)) {} + string name; + NodeDef node_def; + std::vector input_tensors; + std::vector*> + input_resource_handle_shapes_and_types; + int graph_def_version = TF_GRAPH_DEF_VERSION; +}; + +namespace shape_inference { + +class ShapeInferenceTestutil { + public: + // Run shape inference for , given inputs specified by + // and returns an error if the inferred shape does not match expected_outs. 
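Tests against InferShapes normally go through the INFER_OK / INFER_ERROR macros declared at the bottom of this test-utility header. A schematic example, assuming a registered Identity-style op whose output shape is its input shape and the usual googletest setup:

```c++
// Schematic only: assumes an "Identity"-style registered op and gtest.
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"

TEST(IdentityShapeTest, PassesInputShapeThrough) {
  tensorflow::ShapeInferenceTestOp op("Identity");
  // "in0" in the expected output means: exactly the first input's shape.
  INFER_OK(op, "[2,3]", "in0");
  // "?" as an input means an unknown shape; the output is still tied to it.
  INFER_OK(op, "?", "in0");
}
```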
+ // + // is a semicolon separated list of shapes. Each shape is formatted + // according to the formatting per + // shape_inference::InferenceContext::InferenceContext. + // + // is a semicolon separated list of shapes. Each shape is + // formatted as one of: + // * ? - an unknown shape, but not matching an input shape + // * in0|in2|... - output shape must be the same as one of these input shapes. + // * [1,?,d0_0|d0_1] - output shape is of known rank, with comma-separated + // dimension values. + // Each dimension value is one of: + // * a constant, which means that constant not equal to a specific input + // * ?, which means an unknown dim size not equal to a specific input + // * d0_0|d1_2, indicating that the dim size must be equal to one of + // the given input dimensions; the first number is the input # and + // the second is which dimension in that input it corresponds to. + // can be "e"; this is used to indicate that shape inference + // should have failed. + static absl::Status InferShapes(ShapeInferenceTestOp op, const string& ins, + const string& expected_outs); + + private: + ShapeInferenceTestutil() = default; + + // Makes a shape out of 'spec'. + static absl::Status MakeShapeFromString( + InferenceContext::ShapeManager* manager, const string& spec, + ShapeHandle* output); +}; + +} // namespace shape_inference + +#define INFER_OK(op, i, o) \ + EXPECT_EQ(tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ + op, i, o), \ + absl::OkStatus()) + +#define INFER_ERROR(error_substring, op, i) \ + { \ + absl::Status status = \ + (tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ + op, i, "e")); \ + std::string error_message = status.ToString(); \ + EXPECT_NE(status, absl::OkStatus()); \ + EXPECT_TRUE(absl::StrContains(error_message, error_substring)) \ + << "Expected to see '" << error_substring << "' in '" << error_message \ + << "'"; \ + } + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/shared_ptr_variant.h b/third_party/tflite-hdrs/tensorflow/core/framework/shared_ptr_variant.h new file mode 100644 index 00000000..337d51d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/shared_ptr_variant.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_ +#define TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_ + +#include + +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +template +struct SharedPtrVariant { + std::shared_ptr shared_ptr; + + SharedPtrVariant() : shared_ptr() {} + + explicit SharedPtrVariant(std::shared_ptr&& ptr) + : shared_ptr(std::forward(ptr)) { + VLOG(3) << "Creating shared_ptr of " << shared_ptr.get() + << " count is: " << shared_ptr.use_count(); + } + + SharedPtrVariant(SharedPtrVariant&& rhs) + : shared_ptr(std::move(rhs.shared_ptr)) { + VLOG(3) << "Moving SharedPtrVariant of " << shared_ptr.get() + << " count is: " << shared_ptr.use_count(); + } + + SharedPtrVariant& operator=(const SharedPtrVariant& rhs) = delete; + + SharedPtrVariant& operator=(SharedPtrVariant&& rhs) { + if (&rhs == this) return *this; + std::swap(shared_ptr, rhs.shared_ptr); + VLOG(3) << "Move-assign of SharedPtrVariant of " << shared_ptr.get() + << " count is: " << shared_ptr.use_count(); + return *this; + } + + SharedPtrVariant(const SharedPtrVariant& rhs) : shared_ptr(rhs.shared_ptr) { + VLOG(3) << "Copying SharedPtrVariant of " << shared_ptr.get() + << " count is: " << shared_ptr.use_count(); + } + + ~SharedPtrVariant() { + VLOG(3) << "Destroying SharedPtrVariant of " << shared_ptr.get() + << " count is: " << shared_ptr.use_count(); + } + + void Encode(VariantTensorData*) const { + // Not supported. + } + + bool Decode(const VariantTensorData&) { + return false; // Not supported. + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_SHARED_PTR_VARIANT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/stats_aggregator.h b/third_party/tflite-hdrs/tensorflow/core/framework/stats_aggregator.h new file mode 100644 index 00000000..5b89a82f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/stats_aggregator.h @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ + +#include +#include + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +class Summary; +class SummaryWriterInterface; +namespace data { + +// A `StatsAggregator` accumulates statistics incrementally. A +// `StatsAggregator` can accumulate multiple different statistics, distinguished +// by a string name. +// +// The class currently supports accumulating `Histogram`, `scalar` objects and +// tfstreamz metrics, and we expect to add other methods in future. 
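For the SharedPtrVariant wrapper defined just above, a short illustration (the payload type MyState is hypothetical). The wrapper only manages the shared_ptr's reference count; Encode() is a no-op and Decode() always fails, so values of this type cannot round-trip through VariantTensorData.

```c++
// Illustration only; MyState is a hypothetical payload type.
#include <memory>
#include <utility>

#include "tensorflow/core/framework/shared_ptr_variant.h"

struct MyState {
  int value = 0;
};

void SharedPtrVariantDemo() {
  tensorflow::SharedPtrVariant<MyState> wrapped(std::make_shared<MyState>());
  wrapped.shared_ptr->value = 42;

  // Copying shares ownership (use_count becomes 2) ...
  tensorflow::SharedPtrVariant<MyState> copy(wrapped);
  // ... while moving transfers it and leaves `wrapped` empty.
  tensorflow::SharedPtrVariant<MyState> moved(std::move(wrapped));

  (void)copy;
  (void)moved;
}
```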
+// +// NOTE(mrry): `StatsAggregator` is a virtual interface because we anticipate +// that many different implementations will have the same interface. For +// example, we have different implementations in "stats_aggregator_ops.cc" for +// simple in-memory implementation that integrates with the pull-based summary +// API, and for the push-based `SummaryWriterInterface`, and we may add +// implementations that work well with other custom monitoring services. +class StatsAggregator { + public: + virtual ~StatsAggregator() {} + + // Add the given `values` to the histogram with the given `name`. Each + // element of `values` will be treated as a separate sample in the histogram. + virtual void AddToHistogram(const string& name, + absl::Span values, + int64_t global_step) = 0; + + // TODO(shivaniagrawal): consistency in double and float usage. + // Add the given `value` as Scalar with the given `name`. + virtual void AddScalar(const string& name, float value, + int64_t global_step) = 0; + + // Stores a protocol buffer representation of the aggregator state in the + // given `out_summary`. + virtual void EncodeToProto(Summary* out_summary) = 0; + + // Sets a `summary_writer` with this stats_aggregator. + virtual absl::Status SetSummaryWriter( + SummaryWriterInterface* summary_writer) = 0; + + // Increment the `label` cell of metrics mapped with `name` by given `value`. + virtual void IncrementCounter(const string& name, const string& label, + int64_t val) = 0; +}; + +// A `StatsAggregatorResource` wraps a sharable `StatsAggregator` as a resource +// in the TensorFlow resource manager. +// +// NOTE(mrry): This class is separate from `StatsAggregator` in order to +// simplify the memory management of the shared object. Most users of +// `StatsAggregator` interact with a `std::shared_ptr` whereas +// the `ResourceBase` API requires explicit reference counting. +class StatsAggregatorResource : public ResourceBase { + public: + // Creates a new resource from the given `stats_aggregator`. + StatsAggregatorResource(std::unique_ptr stats_aggregator) + : stats_aggregator_(stats_aggregator.release()) {} + + // Returns the wrapped `StatsAggregator`. + std::shared_ptr stats_aggregator() const { + return stats_aggregator_; + } + + string DebugString() const override { return "StatsAggregatorResource"; } + + private: + const std::shared_ptr stats_aggregator_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor.h new file mode 100644 index 00000000..8f80ea7c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor.h @@ -0,0 +1,1104 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
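A usage sketch against the StatsAggregator interface above: the object behind the shared_ptr is whichever implementation the pipeline has installed, the metric names are examples, and AddToHistogram's span element type is assumed to be double, matching the scalar overloads.

```c++
// Sketch; assumes AddToHistogram accepts a span of doubles.
#include <cstdint>
#include <memory>

#include "tensorflow/core/framework/stats_aggregator.h"

void RecordIterationStats(
    const std::shared_ptr<tensorflow::data::StatsAggregator>& stats,
    int64_t global_step) {
  stats->AddScalar("batch_size", 32.0f, global_step);
  const double latencies_ms[] = {1.5, 2.0, 7.25};
  stats->AddToHistogram("latency_ms", latencies_ms, global_step);
  stats->IncrementCounter("records_produced", "train", 1);
}
```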
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ + +#include +#include +#include +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Forward declarations. In particular, we forward declare protos so that their +// symbols can be removed from .so exports. +class AllocationDescription; +class OpKernelContext; +class Tensor; +class TensorBuffer; +class TensorCApi; +class TensorInterface; +class TensorCord; +class TensorDescription; +class TensorProto; +class Var; + +namespace batch_util { +absl::Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index); +absl::Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64_t index); +absl::Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, + int64_t index); +absl::Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); +absl::Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); +} // namespace batch_util + +/// @ingroup core + +/// Interface to access the raw ref-counted data buffer. +class TensorBuffer : public core::RefCounted { + public: + explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {} + ~TensorBuffer() override {} + + /// \brief data() points to a memory region of size() bytes. + /// + /// NOTE(mrry): The `data()` method is not virtual for performance reasons. + /// It can be called multiple times when the contents of a `Tensor` are + /// accessed, and so making it non-virtual allows the body to be inlined. + void* data() const { return data_; } + + /// \brief Size (in bytes) of the buffer. + virtual size_t size() const = 0; + + /// \brief If this TensorBuffer is sub-buffer of another TensorBuffer, + /// returns that TensorBuffer. Otherwise, returns this. + virtual TensorBuffer* root_buffer() = 0; + + /// \brief Fills metadata about the allocation into the proto. + virtual void FillAllocationDescription( + AllocationDescription* proto) const = 0; + + virtual bool GetAllocatedBytes(size_t* out_bytes) const; + + /// \brief Helper method to reinterpret the buffer as an array of `T`. + template + T* base() const { + return reinterpret_cast(data()); + } + + /// \brief Whether this TensorBuffer owns the underlying memory. + virtual bool OwnsMemory() const { return true; } + + /// \brief The type of the underlying memory. + virtual AllocatorMemoryType GetMemoryType() const { + return AllocatorMemoryType::kUnknown; + } + + private: + void* const data_; +}; + +/// Represents an n-dimensional array of values. +class Tensor { + public: + /// \brief Creates a 1-dimensional, 0-element float tensor. + /// + /// The returned Tensor is not a scalar (shape {}), but is instead + /// an empty one-dimensional Tensor (shape {0}, NumElements() == + /// 0). 
Since it has no elements, it does not need to be assigned a + /// value and is initialized by default (IsInitialized() is + /// true). If this is undesirable, consider creating a one-element + /// scalar which does require initialization: + /// + /// ```c++ + /// + /// Tensor(DT_FLOAT, TensorShape({})) + /// + /// ``` + Tensor(); + + /// \brief Creates a Tensor of the given `type` and `shape`. If + /// LogMemory::IsEnabled() the allocation is logged as coming from + /// an unknown kernel and step. Calling the Tensor constructor + /// directly from within an Op is deprecated: use the + /// OpKernelConstruction/OpKernelContext allocate_* methods to + /// allocate a new tensor, which record the kernel and step. + /// + /// The underlying buffer is allocated using a `CPUAllocator`. + Tensor(DataType type, const TensorShape& shape); + + /// \brief Creates a tensor with the input `type` and `shape`, using + /// the allocator `a` to allocate the underlying buffer. If + /// LogMemory::IsEnabled() the allocation is logged as coming from + /// an unknown kernel and step. Calling the Tensor constructor + /// directly from within an Op is deprecated: use the + /// OpKernelConstruction/OpKernelContext allocate_* methods to + /// allocate a new tensor, which record the kernel and step. + /// + /// `a` must outlive the lifetime of this Tensor. + Tensor(Allocator* a, DataType type, const TensorShape& shape); + + /// \brief Creates a tensor with the input `type` and `shape`, using + /// the allocator `a` and the specified "allocation_attr" to + /// allocate the underlying buffer. If the kernel and step are known + /// allocation_attr.allocation_will_be_logged should be set to true + /// and LogMemory::RecordTensorAllocation should be called after the + /// tensor is constructed. Calling the Tensor constructor directly + /// from within an Op is deprecated: use the + /// OpKernelConstruction/OpKernelContext allocate_* methods to + /// allocate a new tensor, which record the kernel and step. + /// + /// `a` must outlive the lifetime of this Tensor. + Tensor(Allocator* a, DataType type, const TensorShape& shape, + const AllocationAttributes& allocation_attr); + + /// \brief Creates a tensor with the input datatype, shape and buf. + /// + /// Acquires a ref on buf that belongs to this Tensor. + Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf); + + /// \brief Creates a tensor with the input datatype, shape and buf. + /// + /// Takes an ownership of the bufffer from the reference counted pointer. + Tensor(DataType type, TensorShape shape, core::RefCountPtr buf); + + /// \brief Creates an empty Tensor of the given data type. + /// + /// Like Tensor(), returns a 1-dimensional, 0-element Tensor with + /// IsInitialized() returning True. See the Tensor() documentation + /// for details. + explicit Tensor(DataType type); + + /// \brief Initializes a tensor with the input `type` and `shape`, or returns + /// an error and leaves `out_tensor` unmodified. This factory method should be + /// used instead of the corresponding constructor if calling code cannot + /// validate that the `DataType` is valid and supported. + /// + /// The underlying buffer is allocated using a `CPUAllocator`. + static absl::Status BuildTensor(DataType type, const TensorShape& shape, + Tensor* out_tensor); + + private: + // A tag type for selecting the `Tensor` constructor overload that creates a + // scalar tensor in host memory. 
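
A quick sketch of the constructors documented above (the shape and dtype values are arbitrary): the direct constructor is fine when the arguments are known to be valid, while the `BuildTensor` factory is preferable when the dtype or shape comes from untrusted input.

  tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 3}));
  CHECK(t.IsInitialized());
  CHECK_EQ(t.NumElements(), 6);

  tensorflow::Tensor validated;
  absl::Status s = tensorflow::Tensor::BuildTensor(
      tensorflow::DT_FLOAT, tensorflow::TensorShape({2, 3}), &validated);
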
+ struct host_scalar_tag {}; + + class HostScalarTensorBufferBase; + template + struct ValueAndTensorBuffer; + + // Creates a tensor with the given scalar `value` in CPU memory. + template + Tensor(T value, host_scalar_tag tag); + + public: + // A series of specialized constructors for scalar tensors in host memory. + // + // NOTE: The `Variant` host-scalar constructor is not defined, because Variant + // is implicitly constructible from many different types, and this causes + // ambiguities with some compilers. + explicit Tensor(float scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(double scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(int32_t scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(uint32 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(uint16 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(uint8 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(int16_t scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(int8_t scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(tstring scalar_value) + : Tensor(std::move(scalar_value), host_scalar_tag{}) {} + explicit Tensor(complex64 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(complex128 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(int64_t scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(uint64 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(bool scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(qint8 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(quint8 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(qint16 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(quint16 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(qint32 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(bfloat16 scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(Eigen::half scalar_value) + : Tensor(scalar_value, host_scalar_tag{}) {} + explicit Tensor(ResourceHandle scalar_value) + : Tensor(std::move(scalar_value), host_scalar_tag{}) {} + + // NOTE: The `const char*` host-scalar constructor is provided as a + // convenience because otherwise passing a string literal would surprisingly + // construct a DT_BOOL tensor. + explicit Tensor(const char* scalar_value) + : Tensor(tstring(scalar_value), host_scalar_tag{}) {} + + /// Copy constructor. + Tensor(const Tensor& other); + + /// \brief Move constructor. After this call, is safely destructible + /// can be assigned to, and IsInitialized() can be called and will return + /// false. Other calls on (e.g. shape manipulation) are not valid. + Tensor(Tensor&& other); + + // Explicitly delete constructor that take a pointer (except char*) + // so that the pointer doesn't get implicitly cast to bool. + template ::value, + T>::type* = nullptr> + explicit Tensor(T* t) = delete; + + ~Tensor(); + + // I/O operators. + friend std::ostream& // NOLINT: iosfwd + operator<<(std::ostream& out, const Tensor& tensor); + + /// Returns the data type. + DataType dtype() const { return shape_.data_type(); } + + /// Returns the shape of the tensor. 
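
A short sketch of the host-scalar constructors listed above; each produces a 0-dimensional tensor in CPU memory, and a `const char*` argument is routed to `tstring` rather than `bool`.

  tensorflow::Tensor flag(true);     // DT_BOOL scalar
  tensorflow::Tensor rate(0.5f);     // DT_FLOAT scalar
  tensorflow::Tensor name("batch");  // DT_STRING scalar via the const char* overload
  CHECK_EQ(rate.dims(), 0);
  CHECK_EQ(rate.scalar<float>()(), 0.5f);
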
+ const TensorShape& shape() const { return shape_; } + + /// \brief Convenience accessor for the tensor shape. + /// + /// For all shape accessors, see comments for relevant methods of + /// `TensorShape` in `tensor_shape.h`. + int dims() const { return shape().dims(); } + + /// Convenience accessor for the tensor shape. + int64_t dim_size(int d) const { return shape().dim_size(d); } + + /// Convenience accessor for the tensor shape. + int64_t NumElements() const { return shape().num_elements(); } + + bool IsSameSize(const Tensor& b) const { + return shape().IsSameSize(b.shape()); + } + + // True iff the two tensors use the same underlying refcounted storage + bool SharesBufferWith(const Tensor& b) const; + + /// \brief If necessary, has this Tensor been initialized? + /// + /// Zero-element Tensors are always considered initialized, even if they + /// have never been assigned to and do not have any memory allocated. + bool IsInitialized() const; + + /// Returns the estimated memory usage of this tensor. + size_t TotalBytes() const; + + // Returns the size of allocated memory for this tensor. + size_t AllocatedBytes() const; + + /// Returns true iff this tensor is aligned. + bool IsAligned() const { +#if EIGEN_MAX_ALIGN_BYTES == 0 + return true; +#else + void* ptr = base(); + return dtype() == DT_STRING || NumElements() == 0 || + (reinterpret_cast(ptr) % EIGEN_MAX_ALIGN_BYTES == 0); +#endif + } + + /// Assign operator. This tensor shares other's underlying storage. + Tensor& operator=(const Tensor& other) { + CopyFromInternal(other, other.shape()); + return *this; + } + + /// Move operator. See move constructor for details. + Tensor& operator=(Tensor&& other); + + /// \brief Copy the other tensor into this tensor and reshape it. + /// + /// This tensor shares other's underlying storage. Returns `true` + /// iff `other.shape()` has the same number of elements of the given + /// `shape`. + bool CopyFrom(const Tensor& other, + const TensorShape& shape) TF_MUST_USE_RESULT { + if (other.NumElements() != shape.num_elements()) return false; + CopyFromInternal(other, shape); + return true; + } + + /// \brief Slice this tensor along the 1st dimension. + + /// I.e., the returned tensor satisfies + /// returned[i, ...] == this[dim0_start + i, ...]. + /// The returned tensor shares the underlying tensor buffer with this + /// tensor. + /// + /// NOTE: The returned tensor may not satisfy the same alignment + /// requirement as this tensor depending on the shape. The caller + /// must check the returned tensor's alignment before calling certain + /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). + /// + /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor + /// also with N dimensions. If you want to select a sub tensor, see SubSlice. + /// + /// REQUIRES: `dims()` >= 1 + /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)` + Tensor Slice(int64_t dim0_start, int64_t dim0_limit) const; + + /// \brief Select a subslice from this tensor along the 1st dimension. + /// + /// When fed with an N-dimensional tensor, this method returns a tensor with + /// N-1 dimensions, where the returned tensor is a subslice of the input + /// tensor along the first dimension. The N-1 dimensions of the returned + /// tensor are the last N-1 dimensions of the input tensor. + /// + /// NOTE: The returned tensor may not satisfy the same alignment + /// requirement as this tensor depending on the shape. 
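
A sketch of `CopyFrom` as documented above: the destination shares the source buffer and only the shape metadata changes, so the call returns false when the element counts differ.

  tensorflow::Tensor src(tensorflow::DT_INT32, tensorflow::TensorShape({2, 3}));
  tensorflow::Tensor dst;
  bool ok = dst.CopyFrom(src, tensorflow::TensorShape({3, 2}));  // same 6 elements
  CHECK(ok);
  CHECK(dst.SharesBufferWith(src));
  CHECK_EQ(dst.dim_size(0), 3);
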
The caller + /// must check the returned tensor's alignment before calling certain + /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). + /// + /// REQUIRES: `dims()` >= 1 + /// REQUIRES: `0 <= index < dim_size(0)` + Tensor SubSlice(int64_t index) const; + + /// \brief Parse `other` and construct the tensor. + + /// Returns `true` iff the parsing succeeds. If the parsing fails, + /// the state of `*this` is unchanged. + bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT; + bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT; + + /// \brief Fills in `proto` with `*this` tensor's content. + /// + /// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while + /// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()` + /// in a compact form. + void AsProtoField(TensorProto* proto) const; + void AsProtoTensorContent(TensorProto* proto) const; + + /// \brief Return the tensor data as an `Eigen::Tensor` with the type and + /// sizes of this `Tensor`. + /// + /// Use these methods when you know the data type and the number of + /// dimensions of the Tensor and you want an `Eigen::Tensor` + /// automatically sized to the `Tensor` sizes. The implementation check + /// fails if either type or sizes mismatch. + /// + /// Example: + /// + /// ```c++ + /// + /// typedef float T; + /// Tensor my_mat(...built with Shape{rows: 3, cols: 5}...); + /// auto mat = my_mat.matrix(); // 2D Eigen::Tensor, 3 x 5. + /// auto mat = my_mat.tensor(); // 2D Eigen::Tensor, 3 x 5. + /// auto vec = my_mat.vec(); // CHECK fails as my_mat is 2D. + /// auto vec = my_mat.tensor(); // CHECK fails as my_mat is 2D. + /// auto mat = my_mat.matrix();// CHECK fails as type mismatch. + /// + /// ``` + template + typename TTypes::Vec vec() { + return tensor(); + } + + template + typename TTypes::Matrix matrix() { + return tensor(); + } + + template + typename TTypes::Tensor tensor() TF_ATTRIBUTE_NOINLINE; + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// same size but a bitwise cast to the specified dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// NOTE: this is the same as `tensor()` except a bitcast is allowed. + template + typename TTypes::Tensor bit_casted_tensor(); + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// last dimension elements converted into single elements of a larger type. + /// + /// For example, this is useful for kernels that can treat NCHW_VECT_C int8 + /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of + /// the original element type * num elements in the original last dimension. + /// NDIMS should be 1 less than the original number of dimensions. + template + typename TTypes::Tensor reinterpret_last_dimension(); + + /// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a + /// specified shape. + /// + /// These methods allow you to access the data with the dimensions + /// and sizes of your choice. You do not need to know the number of + /// dimensions of the Tensor to call them. However, they `CHECK` that + /// the type matches and the dimensions requested creates an + /// `Eigen::Tensor` with the same number of elements as the tensor. 
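
A sketch of `Slice` and `SubSlice` on a hypothetical batch tensor; both share the original buffer, and the alignment caveat above is why the sketch checks `IsAligned()` before taking an Eigen view.

  tensorflow::Tensor batch(tensorflow::DT_FLOAT,
                           tensorflow::TensorShape({8, 28, 28}));
  tensorflow::Tensor rows = batch.Slice(2, 5);  // shape {3, 28, 28}
  tensorflow::Tensor one = batch.SubSlice(0);   // shape {28, 28}
  CHECK_EQ(rows.dim_size(0), 3);
  CHECK_EQ(one.dims(), 2);
  if (one.IsAligned()) {
    auto image = one.matrix<float>();  // safe: alignment was checked
  }
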
+ /// + /// Example: + /// + /// ```c++ + /// + /// typedef float T; + /// Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...); + /// // 1D Eigen::Tensor, size 60: + /// auto flat = my_ten.flat(); + /// // 2D Eigen::Tensor 12 x 5: + /// auto inner = my_ten.flat_inner_dims(); + /// // 2D Eigen::Tensor 4 x 15: + /// auto outer = my_ten.shaped({4, 15}); + /// // CHECK fails, bad num elements: + /// auto outer = my_ten.shaped({4, 8}); + /// // 3D Eigen::Tensor 6 x 5 x 2: + /// auto weird = my_ten.shaped({6, 5, 2}); + /// // CHECK fails, type mismatch: + /// auto bad = my_ten.flat(); + /// + /// ``` + template + typename TTypes::Flat flat(); + + template + typename TTypes::UnalignedFlat unaligned_flat() { + return unaligned_shaped({NumElements()}); + } + + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all + /// Tensor dimensions but the last NDIMS-1 into the first dimension of the + /// result. If NDIMS > dims() then leading dimensions of size 1 will be + /// added to make the output rank NDIMS. + template + typename TTypes::Tensor flat_inner_dims(); + + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all + /// Tensor dimensions but the first NDIMS-1 into the last dimension of the + /// result. If NDIMS > dims() then trailing dimensions of size 1 will be + /// added to make the output rank NDIMS. + template + typename TTypes::Tensor flat_outer_dims(); + + /// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing the + /// first 'begin' Tensor dimensions into the first dimension of the result and + /// the Tensor dimensions of the last dims() - 'begin' - NDIMS into the last + /// dimension of the result. If 'begin' < 0 then the |'begin'| leading + /// dimensions of size 1 will be added. If 'begin' + NDIMS > dims() then + /// 'begin' + NDIMS - dims() trailing dimensions of size 1 will be added. + template + typename TTypes::Tensor flat_inner_outer_dims(int64_t begin); + + template + typename TTypes::Tensor shaped(absl::Span new_sizes); + + /// \brief Return the tensor data to an `Eigen::Tensor` with the new + /// shape specified in `new_sizes` and cast to a new dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// The allowed bitcast is the only difference from `shaped()`. + template + typename TTypes::Tensor bit_casted_shaped( + absl::Span new_sizes); + + template + typename TTypes::UnalignedTensor unaligned_shaped( + absl::Span new_sizes); + + /// \brief Return the Tensor data as a `TensorMap` of fixed size 1: + /// `TensorMap>`. + + /// Using `scalar()` allows the compiler to perform optimizations as + /// the size of the tensor is known at compile time. + template + typename TTypes::Scalar scalar(); + + /// Const versions of all the methods above. + template + typename TTypes::ConstVec vec() const { + return tensor(); + } + + template + typename TTypes::ConstMatrix matrix() const { + return tensor(); + } + + template + typename TTypes::ConstTensor tensor() const TF_ATTRIBUTE_NOINLINE; + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// same size but a bitwise cast to the specified dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// NOTE: this is the same as `tensor()` except a bitcast is allowed. + template + typename TTypes::ConstTensor bit_casted_tensor() const; + + /// \brief Return the tensor data to an `Eigen::Tensor` with the + /// last dimension elements converted into single elements of a larger type. 
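
A sketch of the reshaping accessors documented above; all of them are views over the same buffer, so a write through one view is visible through the others.

  tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({4, 3, 5}));
  auto flat = t.flat<float>();                 // 1-D view, 60 elements
  auto outer = t.shaped<float, 2>({4, 15});    // 2-D view, 4 x 15
  auto inner = t.flat_inner_dims<float, 2>();  // 2-D view, 12 x 5
  flat(0) = 1.0f;
  CHECK_EQ(outer(0, 0), 1.0f);
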
+ /// + /// For example, this is useful for kernels that can treat NCHW_VECT_C int8 + /// tensors as NCHW int32 tensors. The sizeof(T) should equal the size of + /// the original element type * num elements in the original last dimension. + /// NDIMS should be 1 less than the original number of dimensions. + template + typename TTypes::ConstTensor reinterpret_last_dimension() const; + + template + typename TTypes::ConstFlat flat() const; + + template + typename TTypes::UnalignedConstFlat unaligned_flat() const { + return unaligned_shaped({NumElements()}); + } + + template + typename TTypes::ConstTensor shaped( + absl::Span new_sizes) const; + + /// \brief Return the tensor data to an `Eigen::Tensor` with the new + /// shape specified in `new_sizes` and cast to a new dtype `T`. + /// + /// Using a bitcast is useful for move and copy operations. + /// The allowed bitcast is the only difference from `shaped()`. + template + typename TTypes::ConstTensor bit_casted_shaped( + absl::Span new_sizes) const; + + template + typename TTypes::UnalignedConstTensor unaligned_shaped( + absl::Span new_sizes) const; + + template + typename TTypes::ConstScalar scalar() const; + + template + typename TTypes::ConstTensor flat_inner_dims() const; + + template + typename TTypes::ConstTensor flat_outer_dims() const; + + template + typename TTypes::ConstTensor flat_inner_outer_dims( + int64_t begin) const; + + /// Render the first `max_entries` values in `*this` into a string. + std::string SummarizeValue(int64_t max_entries, bool print_v2 = false) const; + + /// A human-readable summary of the tensor suitable for debugging. + // `num_values` is the number of actual data values in the tensor + // included in the message. If the tensor might be resident in + // GPU/TPU memory use DeviceSafeDebugString instead. + std::string DebugString(int num_values) const; + std::string DebugString() const { return DebugString(3); } + + // Variant of DebugString() that should be used for possibly non-CPU tensors. + // If the tensor is not resident on CPU, we can't read its values as + // DebugString() does. + std::string DeviceSafeDebugString() const; + + /// Fill in the `TensorDescription` proto with metadata about the + /// tensor that is useful for monitoring and debugging. + void FillDescription(TensorDescription* description) const; + + /// \brief Returns a `StringPiece` mapping the current tensor's buffer. + /// + /// The returned `StringPiece` may point to memory location on devices + /// that the CPU cannot address directly. + /// + /// NOTE: The underlying tensor buffer is refcounted, so the lifetime + /// of the contents mapped by the `StringPiece` matches the lifetime of + /// the buffer; callers should arrange to make sure the buffer does + /// not get destroyed while the `StringPiece` is still used. + /// + /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`. + absl::string_view tensor_data() const; + void* data() const; + + /// Copy the other tensor into this tensor, reshape it and reinterpret the + /// buffer's datatype. If an ok Status is returned, the two tensors now share + /// the same underlying storage. + /// + /// This call requires that the `other` tensor and the given type and shape + /// are "compatible" (i.e. they occupy the same number of bytes). 
+ /// + /// Specifically: + /// + /// shape.num_elements() * DataTypeSize(type) + /// + /// must equal + /// + /// other.num_elements() * DataTypeSize(other.dtype()) + /// + /// In addition, this function requires: + /// * DataTypeSize(other.dtype()) != 0 + /// * DataTypeSize(type) != 0 + /// + /// If any of the requirements are not met, errors::InvalidArgument is + /// returned. + absl::Status BitcastFrom(const Tensor& other, DataType dtype, + const TensorShape& shape); + + /// Like BitcastFrom, but CHECK fails if any preconditions are not met. + /// + /// Deprecated. Use BitcastFrom instead and check the returned Status. + void UnsafeCopyFromInternal(const Tensor& other, DataType dtype, + const TensorShape& shape) { + TF_CHECK_OK(BitcastFrom(other, dtype, shape)); + } + + // Returns true if the refcount on buf_ and any possible underlying root + // buffer is one. + bool RefCountIsOne() const; + + // Experimental. Returns the refcount on buf_ if it points to a regular + // TensorBuffer. If buf_ points to a SubBuffer, returns -1. + int RefCount() const; + + // Returns the type of the underlying memory. + AllocatorMemoryType GetMemoryType() const { return buf_->GetMemoryType(); } + + private: + void CheckType(DataType expected_dtype) const; + void CheckTypeAndIsAligned(DataType expected_dtype) const; + void CheckIsAlignedAndSingleElement() const; + void set_dtype(DataType t) { shape_.set_data_type(t); } + + // TensorShape's InlineVector. + static absl::InlinedVector ComputeFlatInnerDims( + absl::Span orig, int64_t num_out_dims); + static absl::InlinedVector ComputeFlatOuterDims( + absl::Span orig, int64_t num_out_dims); + + TensorShape shape_; + TensorBuffer* buf_; + + friend class DMAHelper; // For access to buf_. + friend class TensorCApi; // For access to buf_. + friend class TensorCord; // For access to buf_. + friend class TensorReference; // For access to buf_. + friend class VariableOp; // For access to set_shape. + friend class AutoReloadVariableOp; // For access to set_shape. + friend class TensorTestHelper; // For access to set_shape. + friend class TensorInterface; // For access to set_shape. + friend class CastOpBase; // For access to set_dtype. + friend class ScopedAllocator; // For access to buf_. + friend class PjRtTensorBufferUtil; // For access to buf_. + friend absl::Status batch_util::CopyElementToSlice( + Tensor element, Tensor* parent, + int64_t index); // For access to base(). + friend absl::Status batch_util::CopySliceToElement( + const Tensor& parent, Tensor* element, + int64_t index); // For access to base(). + friend absl::Status batch_util::MaybeMoveSliceToElement( + Tensor* parent, Tensor* element, + int64_t index); // For access to base(). + friend absl::Status batch_util::CopyContiguousSlices( + const Tensor& src, int64_t src_offset, int64_t dst_offset, + int64_t num_slices, + Tensor* dst); // For access to base(). + friend absl::Status batch_util::MaybeMoveContiguousSlices( + Tensor& src, int64_t src_offset, int64_t dst_offset, int64_t num_slices, + Tensor* dst); // For access to base(). + + bool CanUseDMA() const; + + // Only needed by variable op to set the shape of an uninitialized + // Tensor. + // TODO: Remove this when we have a better story for detecting + // uninitialized tensors. 
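
A sketch of `BitcastFrom` as documented above: the same 16 bytes are reinterpreted under a different dtype, and the call returns `InvalidArgument` instead of crashing when the byte counts do not line up.

  tensorflow::Tensor floats(tensorflow::DT_FLOAT, tensorflow::TensorShape({4}));
  tensorflow::Tensor ints;
  absl::Status s = ints.BitcastFrom(floats, tensorflow::DT_INT32,
                                    tensorflow::TensorShape({4}));
  CHECK(s.ok());
  CHECK(ints.SharesBufferWith(floats));
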
+ void set_shape(const TensorShape& shape) { + DataType dt = dtype(); + shape_ = shape; + set_dtype(dt); + } + + inline void CopyFromInternal(const Tensor& other, const TensorShape& shape) { + DCHECK_EQ(shape.num_elements(), other.NumElements()); + // Data type will be overwritten if this == &other, since dtype is part of + // shape. + DataType other_dtype = other.dtype(); + shape_ = shape; + set_dtype(other_dtype); + if (buf_ != other.buf_) { + if (buf_) buf_->Unref(); + buf_ = other.buf_; + if (buf_) buf_->Ref(); + } + } + + template + T* base() const; + + template + void FillDimsAndValidateCompatibleShape( + absl::Span new_sizes, + Eigen::array* dims) const; + + template + void FillDimsAndValidateCompatibleShape( + absl::Span new_sizes, + Eigen::array* dims) const; +}; + +// Implementation details + +// START_SKIP_DOXYGEN + +template +T* Tensor::base() const { + return buf_ == nullptr ? nullptr : buf_->base(); +} + +// This routine is defined out of line for code-space savings +template +typename TTypes::Tensor Tensor::tensor() { + CheckTypeAndIsAligned(DataTypeToEnum::v()); + return typename TTypes::Tensor(base(), + shape().AsEigenDSizes()); +} + +// This routine is defined out of line for code-space savings +template +typename TTypes::ConstTensor Tensor::tensor() const { + CheckTypeAndIsAligned(DataTypeToEnum::v()); + return typename TTypes::ConstTensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::Tensor Tensor::bit_casted_tensor() { + CHECK(IsAligned()); + return typename TTypes::Tensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::ConstTensor Tensor::bit_casted_tensor() const { + CHECK(IsAligned()); + return typename TTypes::ConstTensor(base(), + shape().AsEigenDSizes()); +} + +template +typename TTypes::Tensor Tensor::reinterpret_last_dimension() { + if (NDIMS == dims()) { + return tensor(); + } + CHECK(IsAligned()); + CHECK_EQ(static_cast(NDIMS), dims() - 1); + CHECK_EQ(static_cast(sizeof(T)), + shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype())); + Eigen::array dims; + for (int d = 0; d < NDIMS; ++d) { + dims[d] = shape_.dim_sizes()[d]; + } + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::ConstTensor Tensor::reinterpret_last_dimension() + const { + if (NDIMS == dims()) { + return tensor(); + } + CHECK(IsAligned()); + CHECK_EQ(static_cast(NDIMS), dims() - 1); + CHECK_EQ(static_cast(sizeof(T)), + shape_.dim_sizes()[NDIMS] * DataTypeSize(dtype())); + Eigen::array dims; + for (int d = 0; d < NDIMS; ++d) { + dims[d] = shape_.dim_sizes()[d]; + } + return typename TTypes::ConstTensor(base(), dims); +} + +template +void Tensor::FillDimsAndValidateCompatibleShape( + absl::Span new_sizes, + Eigen::array* dims) const { + CHECK_EQ(NDIMS, new_sizes.size()); + int64_t new_num_elements = 1; + for (size_t d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + (*dims)[d] = new_sizes[d]; + } + CHECK_EQ(new_num_elements, NumElements()); +} + +template +void Tensor::FillDimsAndValidateCompatibleShape( + absl::Span new_sizes, + Eigen::array* dims) const { + CHECK_EQ(NDIMS, new_sizes.size()); + int64_t new_num_elements = 1; + for (size_t d = 0; d < NDIMS; d++) { + new_num_elements *= new_sizes[d]; + (*dims)[d] = new_sizes[d]; + } + const int element_size = DataTypeSize(BaseType(dtype())); + if (element_size > 0) { + CHECK_EQ(new_num_elements * static_cast(sizeof(T)), + NumElements() * element_size); + } else { + // DataTypeSize() returns 0 for some data types. 
In this case, assume that T + // has the same size as the buffer type. + // NOTE: If we can be sure that DataTypeSize() does not return 0 for all POD + // types, then we should check DataTypeToEnum::v() == dtype(). Or simply + // check if `element_size > 0` to err when bit cast is attempted on Tensor + // of unknown data type size. + CHECK_EQ(new_num_elements, NumElements()); + } +} + +template +typename TTypes::Flat Tensor::flat() { + // Equivalent to 'return shaped({NumElements()});' + CheckTypeAndIsAligned(DataTypeToEnum::v()); + Eigen::array dims; + dims[0] = NumElements(); + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::ConstFlat Tensor::flat() const { + // Equuivalent to 'return shaped({NumElements()});' + CheckTypeAndIsAligned(DataTypeToEnum::v()); + Eigen::array dims; + dims[0] = NumElements(); + return typename TTypes::ConstTensor(base(), dims); +} + +template +typename TTypes::Tensor Tensor::shaped( + absl::Span new_sizes) { + CheckTypeAndIsAligned(DataTypeToEnum::v()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::Tensor Tensor::bit_casted_shaped( + absl::Span new_sizes) { + CHECK(IsAligned()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::Tensor(base(), dims); +} + +template +typename TTypes::UnalignedTensor Tensor::unaligned_shaped( + absl::Span new_sizes) { + CheckType(DataTypeToEnum::v()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::UnalignedTensor(base(), dims); +} + +template +typename TTypes::ConstTensor Tensor::shaped( + absl::Span new_sizes) const { + CheckType(DataTypeToEnum::v()); + CHECK(IsAligned()) << "ptr = " << base(); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::ConstTensor(base(), dims); +} + +template +typename TTypes::ConstTensor Tensor::bit_casted_shaped( + absl::Span new_sizes) const { + CHECK(IsAligned()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::ConstTensor(base(), dims); +} + +template +typename TTypes::UnalignedConstTensor Tensor::unaligned_shaped( + absl::Span new_sizes) const { + CheckType(DataTypeToEnum::v()); + Eigen::array dims; + FillDimsAndValidateCompatibleShape(new_sizes, &dims); + return typename TTypes::UnalignedConstTensor(base(), dims); +} + +template +typename TTypes::Scalar Tensor::scalar() { + static_assert( + !std::is_same::value, + "std::string is no longer a scalar type, use tensorflow::tstring"); + CheckIsAlignedAndSingleElement(); + return typename TTypes::Scalar(base()); +} + +template +typename TTypes::ConstScalar Tensor::scalar() const { + static_assert( + !std::is_same::value, + "std::string is no longer a scalar type, use tensorflow::tstring"); + CheckIsAlignedAndSingleElement(); + return typename TTypes::ConstScalar(base()); +} + +template +typename TTypes::Tensor Tensor::flat_inner_dims() { + return shaped(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::Tensor Tensor::flat_outer_dims() { + return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::Tensor Tensor::flat_inner_outer_dims(int64_t begin) { + absl::InlinedVector flat_outer = + ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS); + return shaped(ComputeFlatInnerDims(flat_outer, NDIMS)); +} + +template +typename 
TTypes::ConstTensor Tensor::flat_inner_dims() const { + return shaped(ComputeFlatInnerDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::ConstTensor Tensor::flat_outer_dims() const { + return shaped(ComputeFlatOuterDims(shape_.dim_sizes(), NDIMS)); +} + +template +typename TTypes::ConstTensor Tensor::flat_inner_outer_dims( + int64_t begin) const { + absl::InlinedVector flat_outer = + ComputeFlatOuterDims(shape_.dim_sizes(), begin + NDIMS); + return shaped(ComputeFlatInnerDims(flat_outer, NDIMS)); +} + +inline Tensor::Tensor(const Tensor& other) + : shape_(other.shape()), buf_(other.buf_) { + if (buf_) buf_->Ref(); +} + +inline Tensor::Tensor(Tensor&& other) + : shape_(std::move(other.shape_)), buf_(other.buf_) { + other.buf_ = nullptr; +} + +class Tensor::HostScalarTensorBufferBase : public TensorBuffer { + public: + using TensorBuffer::TensorBuffer; + bool GetAllocatedBytes(size_t* out_bytes) const final; + void FillAllocationDescription(AllocationDescription* proto) const final; +}; + +// A packed representation for a single scalar value of type `T`, and a +// `TensorBuffer` implementation that describes (and manages the lifetime of) +// that value. +template +struct Tensor::ValueAndTensorBuffer { + class HostScalarTensorBuffer : public Tensor::HostScalarTensorBufferBase { + public: + explicit HostScalarTensorBuffer(void* data) + : HostScalarTensorBufferBase(data) {} + size_t size() const final { return sizeof(T); } + TensorBuffer* root_buffer() final { return this; } + + // Override `operator delete` so that calling `delete this` in + // `core::Refcounted::Unref()` for an object of this type will free + // the enclosing `ValueAndTensorBuffer` for the tensor buffer. + // + // NOTE(mrry): The definition of this method must be outside the class + // definition in order to satisfy some compilers. + static void operator delete(void* ptr); + + static void operator delete(void*, void*) { + // Some compilers require an overridden class-specific deallocation + // function, which will be called if placement `new` throws an + // exception. + } + + private: + ~HostScalarTensorBuffer() override { static_cast(data())->~T(); } + }; + + T value; + HostScalarTensorBuffer tensor_buffer; +}; + +/* static */ +template +void Tensor::ValueAndTensorBuffer::HostScalarTensorBuffer::operator delete( + void* ptr) { + // Use a dummy object to compute to offset of + // `ValueAndTensorBuffer::tensor_buffer`, because `offsetof()` is not + // necessarily defined on this non-POD type (until C++17). + // + // NOTE(mrry): Using `sizeof(Tensor::ValueAndTensorBuffer)` here requires + // us to define this method outside the class definition, so that it is not + // considered an incomplete type. 
+ typename std::aligned_storage), + alignof(Tensor::ValueAndTensorBuffer)>::type + dummy_storage_; + Tensor::ValueAndTensorBuffer* dummy_object = + reinterpret_cast*>(&dummy_storage_); + intptr_t offset = reinterpret_cast(&dummy_object->tensor_buffer) - + reinterpret_cast(dummy_object); + + port::AlignedFree(static_cast(ptr) - offset); +} + +template +Tensor::Tensor(T value, host_scalar_tag tag) { + auto* value_and_buf = static_cast*>( + port::AlignedMalloc(sizeof(typename Tensor::ValueAndTensorBuffer), + EIGEN_MAX_ALIGN_BYTES)); + new (&value_and_buf->value) T(std::move(value)); + new (&value_and_buf->tensor_buffer) + typename Tensor::ValueAndTensorBuffer::HostScalarTensorBuffer( + value_and_buf); + buf_ = &value_and_buf->tensor_buffer; + set_dtype(DataTypeToEnum::value); +} + +inline Tensor& Tensor::operator=(Tensor&& other) { + // Avoid self-assignment, since we might destroy our underlying buffer. + if (&other != this) { + shape_ = std::move(other.shape_); + if (buf_) buf_->Unref(); + buf_ = other.buf_; + other.buf_ = nullptr; + } + return *this; +} + +// END_SKIP_DOXYGEN + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_key.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_key.h new file mode 100644 index 00000000..3bde6fce --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_key.h @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_KEY_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +class TensorKey : public Tensor { + public: + using Tensor::Tensor; + + TensorKey(const Tensor& t) : Tensor(t) {} + + // Equality operator. Needed for absl hashing. + friend bool operator==(const TensorKey& t1, const TensorKey& t2) { + if (t1.dtype() != t2.dtype() || t1.shape() != t2.shape()) { + return false; + } + if (DataTypeCanUseMemcpy(t1.dtype())) { + return t1.tensor_data() == t2.tensor_data(); + } else if (t1.dtype() == DT_STRING) { + const auto s1 = t1.unaligned_flat(); + const auto s2 = t2.unaligned_flat(); + for (int64_t i = 0, n = t1.NumElements(); i < n; ++i) { + if (TF_PREDICT_FALSE(s1(i) != s2(i))) { + return false; + } + } + return true; + } else { + DCHECK(false) << "Unimplemented dtype " << DataTypeString(t1.dtype()) + << std::endl; + } + return false; + } + + friend bool operator!=(const TensorKey& t1, const TensorKey& t2) { + return !(t1 == t2); + } + + // Needed for absl hash function. 
+ template + friend H AbslHashValue(H h, const TensorKey& k) { + if (DataTypeCanUseMemcpy(k.dtype())) { + return H::combine(std::move(h), k.tensor_data()); + } else if (k.dtype() == DT_STRING) { + const auto strs = k.unaligned_flat(); + for (int64_t i = 0, n = k.NumElements(); i < n; ++i) { + h = H::combine(std::move(h), strs(i)); + } + return h; + } else { + DCHECK(false) << "Unimplemented dtype " << DataTypeString(k.dtype()) + << std::endl; + } + return h; + } +}; + +} // namespace tensorflow + +#endif diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_matcher.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_matcher.h new file mode 100644 index 00000000..e89cfc15 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_matcher.h @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_MATCHER_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_MATCHER_H_ + +#include +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace test { + +// Matcher for tensorflow::Tensor instances. Two tensors match iff +// +// - their dtypes are equal, +// - their shapes are equal, +// - and their contents are equal. +// +// Their contents are matched by ::testing::Pointwise() after calling .flat() +// method where the type T satisfies: +// +// ::tensorflow::DataTypeToEnum::value == dtype +// +// Use this like: +// +// EXPECT_THAT(lhs, TensorEq(rhs)); +// +// All POD types and DT_STRING type tensors are supported. Note that this +// utility requires Tensors to point to CPU memory. +class TensorEq { + public: + explicit TensorEq(const tensorflow::Tensor& target) : target_(target) {} + + // Matchers depend on implicit casts. Do not make explicit. + operator ::testing::Matcher() const; // NOLINT + + private: + const tensorflow::Tensor& target_; +}; + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_MATCHER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_reference.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_reference.h new file mode 100644 index 00000000..59ccd281 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_reference.h @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
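
Because `TensorKey` defines both `operator==` and `AbslHashValue`, it can be used directly as the key of an `absl::flat_hash_map`. A small sketch, with an illustrative map and contents:

  #include "absl/container/flat_hash_map.h"

  absl::flat_hash_map<tensorflow::TensorKey, int> counts;
  tensorflow::Tensor key(tensorflow::DT_INT32, tensorflow::TensorShape({2}));
  key.flat<int32_t>()(0) = 1;
  key.flat<int32_t>()(1) = 2;
  counts[tensorflow::TensorKey(key)] += 1;
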
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_REFERENCE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_REFERENCE_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +// An opaque class that holds a reference to an underlying TensorBuffer. +// Unlike Tensor, it does not have any shape or type information, so +// it is cheaper to construct/move, but the only thing you can really do +// with it is Unref it, which releases one of the references to the underlying +// TensorBuffer. +// IMPORTANT: If you do not call Unref(), you will likely leak tensor memory. +class TensorReference { + public: + // Take the reference of the root buffer so the size will be more accurate + explicit TensorReference(const Tensor& tensor) + : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) { + if (buf_) buf_->Ref(); + } + + ~TensorReference() {} + + void Unref() const { + if (buf_) buf_->Unref(); + } + + void FillDescription(AllocationDescription* description) const { + if (buf_) buf_->FillAllocationDescription(description); + } + + private: + TensorBuffer* buf_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_REFERENCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_shape.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_shape.h new file mode 100644 index 00000000..0bcf1fc5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_shape.h @@ -0,0 +1,795 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { + +// START_SKIP_DOXYGEN +template +class TensorShapeIter; +class TensorShape; +class TensorShapeProto; +class PartialTensorShape; +// END_SKIP_DOXYGEN + +/// Internal representation for both TensorShape and PartialTensorShape. +class TensorShapeRep { + public: + ~TensorShapeRep(); + + /// Copy the specified shape + TensorShapeRep(const TensorShapeRep& b); + void operator=(const TensorShapeRep& b); + + /// Move the specified shape. 
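
A sketch of the intended `TensorReference` usage: keep the underlying buffer alive across asynchronous work without copying shape or dtype metadata, then release it explicitly.

  tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({1024}));
  tensorflow::TensorReference ref(t);  // takes a ref on the root buffer
  // ... hand t's buffer to asynchronous work (e.g. a device copy) ...
  ref.Unref();  // required; otherwise the buffer leaks
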
After moving, `b` is safe for destruction and + // can be reassigned into, but its dimensions and number of elements can be + // nonsensical (e.g., negative dimension sizes, or number of elements not + // properly recomputed). + TensorShapeRep(TensorShapeRep&& b); + void operator=(TensorShapeRep&& b); + + /// Clear a tensor shape, producing the scalar shape. + void Clear(); + + // Maximum number of dimensions in a tensor. + // It's 254 because 255 = kUnknownRank is used to represent unknown rank. + static constexpr int MaxDimensions() { return 254; } + + /// \brief Returns the number of elements in the tensor. + /// + /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor` + /// which uses `ptrdiff_t`. For PartialTensorShape, -1 means not fully + /// defined. + int64_t num_elements() const { return num_elements_; } + + /// For error messages. + std::string DebugString() const; + static std::string DebugString(const TensorShapeProto& proto); + + protected: + // Constructable only via TensorShapeBase + TensorShapeRep() = default; + + void ClearAllButDataType(); + + // We use 16 bytes to represent a TensorShape. Because we need to + // be able to support full 64-bit dimension sizes and an arbitrary + // number of dimensions for a Tensor, but most tensor dimensions are + // significantly smaller than 64 bits and most tensors are 1, 2, or 3 + // dimensions, we have several representations. + // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1 + // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1 + // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using + // an out of line vector. + // For PartialTensorShape, a dimension of static_cast(-1) is unknown. + // This value is not allowed in TensorShape either for format compatibility. + struct Rep16 { + uint16 dims_[6]; + }; + struct Rep32 { + uint32 dims_[3]; + }; + struct Rep64 { + absl::InlinedVector* dims_; + }; + + // We use the max value of uint16 or uint32 to represent unknown shapes, so + // the maximum representable valid shape in these representations is one less. + static constexpr int64_t kMaxRep16 = std::numeric_limits::max() - 1; + static constexpr int64_t kMaxRep32 = std::numeric_limits::max() - 1; + static constexpr uint16 kUnknownRep16 = std::numeric_limits::max(); + static constexpr uint32 kUnknownRep32 = std::numeric_limits::max(); + + Rep16* as16() { return reinterpret_cast(buf()); } + Rep32* as32() { return reinterpret_cast(buf()); } + Rep64* as64() { return reinterpret_cast(buf()); } + + const Rep16* as16() const { return reinterpret_cast(buf()); } + const Rep32* as32() const { return reinterpret_cast(buf()); } + const Rep64* as64() const { return reinterpret_cast(buf()); } + + enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 }; + + // Since we have a convenient extra byte available, we allow the + // Tensor class to store an 8-bit value in this extra storage. This + // allows it to store the Tensor's datatype enum value here and avoid + // an extra word of storage. + friend class Tensor; + friend class TensorShapeTestHelper; + DataType data_type() const { return static_cast(buf()[13]); } + void set_data_type(DataType dt) { + // We only have 8 bits available to store DataType, so make sure it fits + DCHECK_LT(static_cast(dt), 256u); + buf()[13] = static_cast(dt); + } + + // We store the number of dimensions in byte 14, and the RepTag in byte 15. + // Bytes [0..13] vary depending on the representation. 
+ // A value of 255 indicates unknown rank in the PartialTensorShape case. + static constexpr uint8 kUnknownRank = 255; + uint8 ndims_byte() const { return buf()[14]; } + void set_ndims_byte(uint8 nd) { buf()[14] = nd; } + + RepTag tag() const { return static_cast(buf()[15]); } + void set_tag(RepTag tag) { buf()[15] = static_cast(tag); } + + void set_num_elements(int64_t n) { num_elements_ = n; } + + private: + void DestructorOutOfLine(); + void SlowCopyFrom(const TensorShapeRep& b); + + uint8* buf() { return &u_.buf[0]; } + const uint8* buf() const { return &u_.buf[0]; } + + union { + uint8 buf[16]; + // Force data to be aligned enough for a pointer. + Rep64* unused_aligner; + } u_; + int64_t num_elements_; +}; + +/// Base class for TensorShape and PartialTensorShape. +/// The class is templatized by either TensorShape or PartialTensorShape to +/// allow skipping known/unknown checks in the TensorShape case, but the +/// representation is shared exactly for fast conversion. +template +class TensorShapeBase : public TensorShapeRep { + public: + /// \brief Construct a `TensorShapeBase` from the provided sizes. + /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape) + explicit TensorShapeBase(absl::Span dim_sizes); + TensorShapeBase(std::initializer_list dim_sizes) + : TensorShapeBase(absl::Span(dim_sizes)) {} + + /// Construct an empty TensorShape, or an unknown rank PartialTensorShape + TensorShapeBase(); + + // Cannot be made explicit because we rely on conversion between proto and + // `TensorShapeBase` throughtout the codebase (needs bigger cleanup) + TensorShapeBase(const TensorShapeProto& proto); + + // These factory methods should be used instead of the constructors that take + // an array of sizes if calling code cannot validate that the sizes specify a + // valid `TensorShape`. + // The value in `*out` is valid iff the returned value is `Status::OK`. + static absl::Status BuildTensorShapeBase(absl::Span dim_sizes, + TensorShapeBase* out); + static absl::Status BuildTensorShapeBase( + std::initializer_list dim_sizes, TensorShapeBase* out) { + return BuildTensorShapeBase(absl::Span(dim_sizes), out); + } + static absl::Status BuildTensorShapeBase(const TensorShapeProto& proto, + TensorShapeBase* out); + + /// Returns `true` iff `proto` is a valid tensor shape. + // For TensorShape, the proto shape must be fully defined. + static bool IsValid(const TensorShapeProto& proto); + + /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error + /// status otherwise. + static absl::Status IsValidShape(const TensorShapeProto& proto); + + /// Returns `true` iff this is a valid tensor shape. + bool IsValid(); + + /// \brief Add a dimension to the end ("inner-most"). + /// REQUIRES: `size >= 0` + void AddDim(int64_t size); + + /// Same as `AddDim` but returns a `Status`. + /// Use if unsure is `size >= 0`, to prevent `CHECK`-crashes. + absl::Status AddDimWithStatus(int64_t size); + + /// Appends all the dimensions from `shape`. + void AppendShape(const TensorShapeBase& shape); + + /// Same as `RemoveDim` but returns a `Status`. + /// Use if you cannot validate all invariants, to prevent `CHECK`-fail. + absl::Status AppendShapeWithStatus(const TensorShapeBase& shape); + + /// \brief Insert a dimension somewhere in the `TensorShape`. + /// REQUIRES: `0 <= d <= dims()` + /// REQUIRES: `size >= 0` + void InsertDim(int d, int64_t size); + + /// Same as `InsertDim` but returns a `Status`. 
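
A sketch of growing a shape with the `AddDim` family; the `WithStatus` variant reports invalid sizes as a `Status` instead of a `CHECK` failure, which is the right choice for sizes taken from untrusted input.

  tensorflow::TensorShape shape;  // scalar shape, 1 element
  shape.AddDim(32);
  shape.AddDim(128);              // now [32, 128]
  CHECK_EQ(shape.num_elements(), 32 * 128);

  absl::Status s = shape.AddDimWithStatus(-7);  // error, no crash
  CHECK(!s.ok());
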
+ /// Use if unsure if requirements in `InsertDim` are satistified, to prevent + /// `CHECK`-fail crashes. + absl::Status InsertDimWithStatus(int d, int64_t size); + + /// \brief Modifies the size of the dimension `d` to be `size` + /// REQUIRES: `0 <= d < dims()` + /// REQUIRES: `size >= 0` + void set_dim(int d, int64_t size); + + /// Same as `set_dim` but returns a `Status`. + /// Use if unsure if requirements in `set_dim` are satistified, to prevent + /// `CHECK`-fail crashes. + absl::Status SetDimWithStatus(int d, int64_t size); + + /// \brief Removes dimension `d` from the `TensorShape`. + /// REQUIRES: `0 <= d < dims()` + void RemoveDim(int d) { + CHECK_GE(d, 0); + RemoveDimRange(d, d + 1); + } + + /// Same as `RemoveDim` but returns a `Status`. + /// Use if unsure is `0 <= d < dims()`, to prevent `CHECK`-crashes. + absl::Status RemoveDimWithStatus(int64_t d) { + if (TF_PREDICT_FALSE(d < 0)) { + return errors::Internal( + "Expected dimension index to be non-negative, got ", d); + } + return RemoveDimRangeWithStatus(d, d + 1); + } + + /// \brief Removes last `n` dimensions from the `TensorShape`. + /// REQUIRES: `0 <= n <= dims()` + void RemoveLastDims(int n) { + CHECK_LE(n, dims()); + RemoveDimRange(dims() - n, dims()); + } + + /// Same as `RemoveLastDims` but returns a `Status`. + /// Use if unsure is `0 <= n <= dims()`, to prevent `CHECK`-crashes. + absl::Status RemoveLastDimsWithStatus(int64_t n) { + if (TF_PREDICT_FALSE(n > dims())) { + return errors::Internal("Expected dimension index to be at most ", dims(), + " got ", n); + } + return RemoveDimRangeWithStatus(dims() - n, dims()); + } + + /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`. + /// Negative values of `end` are interpreted as `dims() + end + 1` (as in + /// Python). The same is true for negative values of `begin`. + /// REQUIRES: `-(dims()+1) <= begin <= dims()` + /// REQUIRES: `-(dims()+1) <= end <= dims()` + void RemoveDimRange(int begin, int end); + + /// Same as `RemoveDimRange` but returns a `Status`. + /// Use if unsure if requirements in `RemoveDimRange` are satistified, to + /// prevent `CHECK`-fail crashes. + absl::Status RemoveDimRangeWithStatus(int begin, int end); + + /// Return whether the rank is unknown + bool unknown_rank() const { + return kIsPartial && ndims_byte() == kUnknownRank; + } + + /// Return the number of dimensions in the tensor. + /// Can be -1 meaning unknown rank for PartialTensorShape. + int dims() const { + uint8 dims = ndims_byte(); + return kIsPartial && dims == kUnknownRank ? -1 : dims; + } + + /// \brief Returns the number of elements in dimension `d`. + /// REQUIRES: `0 <= d < dims()` + // TODO(touts): Rename to `dimension()` to match + // `Eigen::Tensor::dimension()`? + int64_t dim_size(int d) const; + + /// Returns sizes of all dimensions. + // Returns an empty list for unknown rank PartialTensorShape. + absl::InlinedVector dim_sizes() const; + + /// Return true iff the rank and all of the dimensions are well defined + // TODO(irving): Rename to is_fully_defined now that it's fast. + bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; } + + /// Fill `*proto` from `*this`. + void AsProto(TensorShapeProto* proto) const; + TensorShapeProto AsProto() const; + + /// For iterating through the dimensions. + TensorShapeIter begin() const; + TensorShapeIter end() const; + + protected: + // Optimized constructor for a shape representing an empty vector. 
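
A sketch of the dimension-editing helpers and of iterating a shape; `TensorShapeDim::size` carries the size of each dimension.

  tensorflow::TensorShape shape({2, 3, 4, 5});
  shape.set_dim(0, 8);      // [8, 3, 4, 5]
  shape.RemoveLastDims(2);  // [8, 3]
  shape.RemoveDim(1);       // [8]
  int64_t product = 1;
  for (auto d : shape) product *= d.size;
  CHECK_EQ(product, shape.num_elements());
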
+ // + // This constructor is provided to optimize the default constructor for + // `Tensor`. + explicit TensorShapeBase(DataType dt); + + private: + absl::Status RecomputeNumElements(); + absl::Status InitDims(absl::Span dim_sizes); + + // True for PartialTensorShape, false for TensorShape + static constexpr bool kIsPartial = + std::is_same::value; + static_assert(kIsPartial || std::is_same::value, + "Shape is neither TensorShape nor PartialTensorShape"); + + // Used by AddDim and MakeShapeHelper. Does no error checking. + void UnsafeAddDim(int64_t size, int64_t new_num_elements); + + // For use by TensorShapeUtils::MakeShape + template + friend absl::Status MakeShapeHelper(const T*, int64_t, S*); +}; + +/// Outputs `TensorShapeBase` to `std::ostream`. +template +std::ostream& operator<<(std::ostream& os, const TensorShapeBase& tsb) { + return os << tsb.DebugString(); +} + +/// Represents the shape of a Tensor. +/// +/// A tensor's shape is denoted by its number of dimensions and a size for each +/// dimension. For example, a Tensor represented by a 3 x 4 matrix would have +/// a shape of 2-D, [3,4]. +/// +/// If you know the exact shape of your Tensor when you create the TensorShape +/// object, you can specify it then, or you can create a TensorShape with +/// zero dimensions and one element, and call AddDim() to add dimensions later. +class TensorShape : public TensorShapeBase { + public: + using TensorShapeBase::TensorShapeBase; + + // These factory methods should be used instead of the constructors that take + // an array of sizes if calling code cannot validate that the sizes specify a + // valid `TensorShape`. + // The value in `*out` is valid iff the returned value is `Status::OK`. + static absl::Status BuildTensorShape(absl::Span dim_sizes, + TensorShape* out) { + return BuildTensorShapeBase(dim_sizes, out); + } + static absl::Status BuildTensorShape(std::initializer_list dim_sizes, + TensorShape* out) { + return BuildTensorShape(absl::Span(dim_sizes), out); + } + static absl::Status BuildTensorShape(const TensorShapeProto& proto, + TensorShape* out) { + return BuildTensorShapeBase(proto, out); + } + + static absl::StatusOr BuildTensorShape( + const TensorShapeProto& proto) { + TensorShape out; + TF_RETURN_IF_ERROR(BuildTensorShape(proto, &out)); + return out; + } + + /// Allow a TensorShape to be used as a PartialTensorShape without copying + operator const PartialTensorShape&() const; // NOLINT(runtime/explicit) + + /// Returns true if `*this` and `b` have the same sizes. Ignores + /// dimension names. + bool IsSameSize(const TensorShape& b) const; + + /// Fill `*dsizes` from `*this`. + /// Notice: Using IndexType=int32 in combination with To32Bit() can + /// significantly improve performance on GPU. + template + Eigen::DSizes AsEigenDSizes() const; + + // Same as `AsEigenDSizes()` but returns a `Status` instead. + // Use this method to surface error to user instead of crashing if `NDMIS` is + // not equal to `dims()`. + // Caller must take ownership of `out`. + template + absl::Status AsEigenDSizesWithStatus( + Eigen::DSizes* out) const; + + /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in + /// which case we pad the rest of the sizes with 1. + /// Notice: Using IndexType=int32 in combination with To32Bit() can + /// significantly improve performance on GPU. + template + Eigen::DSizes AsEigenDSizesWithPadding() const; + + // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead. 
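
A sketch of converting a `TensorShape` to the Eigen sizes used by the accessors in `tensor.h`; the padded variant appends size-1 dimensions when the requested rank exceeds `dims()`.

  tensorflow::TensorShape shape({3, 4});
  auto sizes = shape.AsEigenDSizes<2>();              // {3, 4}
  auto padded = shape.AsEigenDSizesWithPadding<4>();  // {3, 4, 1, 1}
  CHECK_EQ(sizes[1], 4);
  CHECK_EQ(padded[3], 1);
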
+ // Use this method to surface error to user instead of crashing if `NDMIS` is + // not equal to `dims()`. + // Caller must take ownership of `out`. + template + absl::Status AsEigenDSizesWithPaddingWithStatus( + Eigen::DSizes* out) const; + + private: + // These CHECK fail to ease debugging. + // REQUIRES: dims() == NDIMS + void CheckDimsEqual(int NDIMS) const; + // REQUIRES: dims() <= NDIMS + void CheckDimsAtMost(int NDIMS) const; + + // Fill output from `*this`. + // Helper method for common code between `AsEigenDSize()` and + // `AsEigenDSizeWithStatus()`. + template + Eigen::DSizes AsEigenDSizesCopy() const; + + // Fill output from `*this`. + // Helper method for common code between `AsEigenDSizesWithPadding()` and + // `AsEigenDSizeWithPaddingWithStatus()`. + template + Eigen::DSizes AsEigenDSizesCopyAndPad() const; + + // For access to TensorShapeBase(DataType). + friend class Tensor; +}; + +inline bool operator==(const TensorShape& a, const TensorShape& b) { + return a.IsSameSize(b); +} +inline bool operator!=(const TensorShape& a, const TensorShape& b) { + return !(a == b); +} + +/// Outputs `TensorShapeBase` to `std::ostream`. +inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) { + return os << ts.DebugString(); +} + +/// Represents the value of one dimension in a TensorShape. +struct TensorShapeDim { + explicit TensorShapeDim(int64_t s) : size(s) {} + int64_t size; +}; + +// START_SKIP_DOXYGEN +template +class TensorShapeIter { + public: + TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {} + bool operator==(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ == rhs.d_; + } + bool operator!=(const TensorShapeIter& rhs) { + DCHECK(shape_ == rhs.shape_); + return d_ != rhs.d_; + } + void operator++() { ++d_; } + TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); } + + private: + const Shape* shape_; + int d_; +}; +// END_SKIP_DOXYGEN + +/// \brief Static helper routines for `TensorShape`. Includes a few common +/// predicates on a tensor shape. +class TensorShapeUtils { + public: + static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; } + + static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; } + + static bool IsVectorOrHigher(const TensorShape& shape) { + return shape.dims() >= 1; + } + + static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; } + + static bool IsSquareMatrix(const TensorShape& shape) { + return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1); + } + + static bool IsMatrixOrHigher(const TensorShape& shape) { + return shape.dims() >= 2; + } + + /// \brief Returns a `TensorShape` whose dimensions are + /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. + static absl::Status MakeShape(const int32* dims, int64_t n, TensorShape* out); + static absl::Status MakeShape(const int64_t* dims, int64_t n, + TensorShape* out); + static absl::Status MakeShape(absl::Span shape, + TensorShape* out); + static absl::Status MakeShape(absl::Span shape, + TensorShape* out); + static absl::Status MakeShape(const int32* dims, int64_t n, + PartialTensorShape* out); + static absl::Status MakeShape(const int64_t* dims, int64_t n, + PartialTensorShape* out); + static absl::Status MakeShape(absl::Span shape, + PartialTensorShape* out); + static absl::Status MakeShape(absl::Span shape, + PartialTensorShape* out); + + static std::string ShapeListString( + const absl::Span& shapes); + + /// \brief Returns true iff `shape` starts with `prefix`. 
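+  /// For example, a shape `[2, 3, 4]` starts with `[]`, `[2]`, and `[2, 3]`,
+  /// but not with `[3]` or `[2, 3, 4, 5]`.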
+ static bool StartsWith(const TensorShape& shape, const TensorShape& prefix); + + /// \brief Returns true iff `shape` ends with `suffix`. + static bool EndsWith(const TensorShape& shape, const TensorShape& suffix); + + /// \brief Returns the product of values in an int64 array, + /// or a failing Status if the array represents a value larger than + /// a `TensorShape` can hold. + static absl::Status NumElements(absl::Span shape, + int64_t* num_elements); +}; + +/// Manages the partially known dimensions of a Tensor and their sizes. +class PartialTensorShape : public TensorShapeBase { + public: + PartialTensorShape() {} + using TensorShapeBase::TensorShapeBase; + + // These factory methods should be used instead of the constructors that take + // an array of sizes if calling code cannot validate that the sizes specify a + // valid `PartialTensorShape`. + // The value in `*out` is valid iff the returned value is `Status::OK`. + static absl::Status BuildPartialTensorShape( + absl::Span dim_sizes, PartialTensorShape* out) { + return BuildTensorShapeBase(dim_sizes, out); + } + static absl::Status BuildPartialTensorShape( + std::initializer_list dim_sizes, PartialTensorShape* out) { + return BuildPartialTensorShape(absl::Span(dim_sizes), out); + } + static absl::Status BuildPartialTensorShape(const TensorShapeProto& proto, + PartialTensorShape* out) { + return BuildTensorShapeBase(proto, out); + } + + static absl::StatusOr BuildPartialTensorShape( + const TensorShapeProto& proto) { + PartialTensorShape out; + TF_RETURN_IF_ERROR(BuildTensorShapeBase(proto, &out)); + return out; + } + + /// Add a dimension to the end ("inner-most"), returns a new + /// PartialTensorShape. + /// REQUIRES: `size >= -1`, where -1 means unknown + PartialTensorShape Concatenate(int64_t size) const; + + /// Similar to `Concatenate` but returning `Status`. + /// Use if calling code cannot validate all requirements and if `CHECK`-fails + /// are to be avoided. + absl::Status ConcatenateWithStatus(int64_t size, + PartialTensorShape* out) const; + + /// Appends all the dimensions from `shape`. Returns a new + /// PartialTensorShape. + PartialTensorShape Concatenate(const PartialTensorShape& shape) const; + + /// Similar to `Concatenate` but returning `Status`. + /// Use if calling code cannot validate all requirements and if `CHECK`-fails + /// are to be avoided. + absl::Status ConcatenateWithStatus(const PartialTensorShape& shape, + PartialTensorShape* out) const; + + /// Merges all the dimensions from `shape`. Returns + /// `InvalidArgument` error if either `shape` has a different rank + /// or if any of the dimensions are incompatible. + absl::Status MergeWith(const PartialTensorShape& shape, + PartialTensorShape* result) const; + + /// Exact equality test. Returns true iff the ranks match (i.e., both are + /// unknown, or both are known and equal), and all dimensions are equal (i.e., + /// both dimensions are known, or both are known and equal). This is a + /// stronger condition that IsCompatibleWith. + bool IsIdenticalTo(const PartialTensorShape& shape) const; + + /// Return true iff the ranks match, and if the + /// dimensions all either match or one is unknown. + bool IsCompatibleWith(const PartialTensorShape& shape) const; + + // Fill `*shape` from `*this`. + // If `*this` is not fully defined, returns false and + // `*shape` is left in an intermediate state. Otherwise + // returns true. 
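+  // For example, a PartialTensorShape of `[2, ?, 3]` (unknown second
+  // dimension) is not fully defined, so the conversion returns false, whereas
+  // `[2, 4, 3]` is converted to the TensorShape `[2, 4, 3]`.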
+ bool AsTensorShape(TensorShape* shape) const; + + /// \brief Returns a `PartialTensorShape` whose dimensions are + /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are + /// considered "unknown". + template + static absl::Status MakePartialShape(const T* dims, int n, + PartialTensorShape* out) { + return TensorShapeUtils::MakeShape(dims, n, out); + } +}; + +inline bool operator==(const PartialTensorShape& a, + const PartialTensorShape& b) { + return a.IsIdenticalTo(b); +} + +/// \brief Static helper routines for `PartialTensorShape`. Includes a few +/// common predicates on a partially known tensor shape. +class PartialTensorShapeUtils { + public: + static std::string PartialShapeListString( + const absl::Span& shapes); + + static bool AreIdentical(const absl::Span& shapes0, + const absl::Span& shapes1); + + static bool AreCompatible( + const absl::Span& shapes0, + const absl::Span& shapes1); +}; + +// ---------------------------------------------------------------------------- +// Template method implementation details below +// ---------------------------------------------------------------------------- + +template +Eigen::DSizes TensorShape::AsEigenDSizesCopy() const { + Eigen::DSizes dsizes; + for (int d = 0; d < NDIMS; d++) { + dsizes[d] = static_cast(dim_size(d)); + } + return dsizes; +} + +template +Eigen::DSizes TensorShape::AsEigenDSizesCopyAndPad() const { + static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions"); + Eigen::DSizes dsizes; + for (int d = 0; d < dims(); d++) { + dsizes[d] = static_cast(dim_size(d)); + } + for (int d = dims(); d < NDIMS; d++) { + dsizes[d] = 1; + } + return dsizes; +} + +template +Eigen::DSizes TensorShape::AsEigenDSizes() const { + CheckDimsEqual(NDIMS); + return AsEigenDSizesCopy(); +} + +template +absl::Status TensorShape::AsEigenDSizesWithStatus( + Eigen::DSizes* out) const { + if (TF_PREDICT_FALSE(NDIMS != dims())) { + return errors::Internal("Asking for tensor of ", NDIMS, + " dimensions from a tensor of ", dims(), + " dimensions"); + } + *out = AsEigenDSizesCopy(); + return absl::OkStatus(); +} + +template +Eigen::DSizes TensorShape::AsEigenDSizesWithPadding() const { + CheckDimsAtMost(NDIMS); + return AsEigenDSizesCopyAndPad(); +} + +template +absl::Status TensorShape::AsEigenDSizesWithPaddingWithStatus( + Eigen::DSizes* out) const { + if (TF_PREDICT_FALSE(NDIMS < dims())) { + return errors::Internal("Asking for tensor of at most ", NDIMS, + " dimensions from a tensor of ", dims(), + " dimensions"); + } + *out = AsEigenDSizesCopyAndPad(); + return absl::OkStatus(); +} + +// ---------------------------------------------------------------------------- +// Inlining of some performance critical routines +// ---------------------------------------------------------------------------- + +inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) { + num_elements_ = b.num_elements_; + if (b.tag() != REP_OUT_OF_LINE) { + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above Implicitly does: + // set_ndims_byte(b.ndims_byte()); + // set_tag(b.tag()); + } else { + set_tag(REP16); // So that SlowCopyFrom does not try to deallocate + SlowCopyFrom(b); + } +} + +inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) { + num_elements_ = b.num_elements_; + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above Implicitly does: + // set_ndims_byte(b.ndims_byte()); + // set_tag(b.tag()); + b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. 
+} + +inline TensorShapeRep::~TensorShapeRep() { + if (tag() == REP_OUT_OF_LINE) { + DestructorOutOfLine(); + } +} + +inline void TensorShapeRep::operator=(const TensorShapeRep& b) { + num_elements_ = b.num_elements_; + if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) { + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above implicitly also does: + // set_tag(b.tag()); + // set_ndims_byte(b.ndims_byte()); + } else { + SlowCopyFrom(b); + } +} + +inline void TensorShapeRep::operator=(TensorShapeRep&& b) { + if (tag() == REP_OUT_OF_LINE) { + DestructorOutOfLine(); + } + num_elements_ = b.num_elements_; + memcpy(buf(), b.buf(), sizeof(u_.buf)); + // memcpy above Implicitly does: + // set_ndims_byte(b.ndims_byte()); + // set_tag(b.tag()); + b.set_tag(REP16); // other shape no longer owns out-of-line data, if any. +} + +inline TensorShape::operator const PartialTensorShape&() const { + // Downcast to the shared representation and upcast to PartialTensorShape + const TensorShapeRep* rep = this; + return *static_cast(rep); +} + +template +inline TensorShapeBase::TensorShapeBase(DataType dt) { + set_tag(REP16); + set_data_type(dt); + + // Optimized implementation of InitDims() where the shape is statically known + // to be {0}. + set_ndims_byte(1); + uint16* dst = as16()->dims_; + *dst = 0; + set_num_elements(0); +} + +// Declare explicit instantiations in .cc file +extern template class TensorShapeBase; +extern template class TensorShapeBase; + +// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's +// often used in shape inference. +struct DtypeAndPartialTensorShape { + DataType dtype; + PartialTensorShape shape; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_slice.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_slice.h new file mode 100644 index 00000000..4ada28d1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_slice.h @@ -0,0 +1,231 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_ + +#include +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// A tensor slice represents a slice of a given tensor. It is represented by a +// list of (start, length) pairs, where the size of the list is the rank of the +// tensor. 
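+//
+// For example, for a rank-3 tensor of shape {10, 4, 6}, the (start, length)
+// pairs {(0, 5), (0, -1), (2, 3)} describe a slice that keeps indices [0, 5)
+// of the first dimension, all of the second dimension (a length of -1,
+// kFullExtent, marks a full extent), and indices [2, 5) of the last one.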
+ +class TensorSlice { + public: + // Construct a tensor slice: you have a number of ways: + // -- creating an empty slice + // -- from just a dimension (in this case it will create a full slice) + // -- from an array of pairs of integers. + // -- from a TensorSliceProto protocol buffer + // -- from a string format of "start,length:start,length..." where each + // "start,length" pair represents the slice on one dimension. We allow a + // special "-" that means "everything for this dimension". One such example + // is: 0,10:-:14,1:-:- + TensorSlice() {} + explicit TensorSlice(int dim); + explicit TensorSlice(const TensorSliceProto& proto); + explicit TensorSlice( + std::initializer_list> extents); + + // This factory methods should be used instead of the constructor that takes a + // `TensorSliceProto` if calling code cannot validate that the sizes specify a + // valid `TensorSlice`. + static absl::Status BuildTensorSlice(const TensorSliceProto& proto, + TensorSlice* output); + + static absl::Status Parse(const string& str, TensorSlice* output); + static TensorSlice ParseOrDie(const string& str) { + TensorSlice ret; + absl::Status s = Parse(str, &ret); + if (!s.ok()) { + LOG(FATAL) << "Could not parse TensorSlice"; + } + return ret; + } + + void Clear(); + + // Accessors + int dims() const { return starts_.size(); } + + int64_t start(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return starts_[d]; + } + + int64_t length(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return lengths_[d]; + } + + int64_t end(int d) const { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + return start(d) + length(d); + } + + void set_start(int d, int64_t x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + DCHECK_GE(x, 0); + starts_[d] = x; + } + + void set_length(int d, int64_t x) { + DCHECK_GE(d, 0); + DCHECK_LT(d, dims()); + lengths_[d] = x; + } + + // If we have a full slice along dimension "d". + bool IsFullAt(int d) const { + return lengths_[d] == kFullExtent && starts_[d] == 0; + } + + // If this is a full slice, i.e. IsFullAt(d) for every d. + bool IsFull() const; + + // Set the slice to be a full slice of "dim" dimensions + void SetFullSlice(int dim); + + // Extend a slice to "dim" dimensions: all the added dimensions are full. + // Requires: dim >= dims(). + void Extend(int dim); + + // Conversion of a TensorSlice to other formats + void AsProto(TensorSliceProto* proto) const; + string DebugString() const; + + // Fill *indices and *sizes from *this (so that we can use the slice() + // function in eigen tensor). We need a tensor shape in case some of the + // slices are full slices. + // We allow NDIMS to be greater than dims(), in which case we will pad the + // higher dimensions with trivial dimensions. + template + void FillIndicesAndSizes( + const TensorShape& shape, + Eigen::DSizes* indices, + Eigen::DSizes* sizes) const; + + // Interaction with other TensorSlices. + + // Compute the intersection with another slice and if "result" is not + // nullptr, store the results in *result; returns true if there is any real + // intersection. + bool Intersect(const TensorSlice& other, TensorSlice* result) const; + // A short hand. + bool Overlaps(const TensorSlice& other) const { + return Intersect(other, nullptr); + } + + // Equals iff "*this" and "other" are logically equivalent. + bool operator==(const TensorSlice& other) const; + bool operator!=(const TensorSlice& other) const { return !(*this == other); } + + // Interaction with TensorShape. 
+ + // Slices a shape and stores the result into *result_shape. + // Requires that the shape and *this have the same rank. + // For example, given a tensor shape of {3, 4, 5}, and a slice of + // 1,2:-:0,2, the result shape is {2, 4, 2}. + absl::Status SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const; + + // Given slice "sub" where "sub" is fully contained in *this, + // (meaning that the intersection of "sub" and *this equals "sub"), computes + // the "relative" slice of "sub" with respect to *this. + // + // In other words, if we use A>S to denote slicing a shape S with a slice A, + // then the function is computing a slice X such that: + // X > (this > S) = sub > S + // for any shape S. + // + // In general, along every dimension, the start of the relative slice is the + // start of the "sub" slice minus the start of *this; the length of the + // relative slice is the length of the "sub" slice. + // + // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and + // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2. + // + // The caller needs to make sure that "sub" is indeed a sub-slice of *this; + // otherwise the result is undefined. + void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const; + + // Updates the slice in such a way that it fully covers "other" slice. + // Note, "other" slice should refer to the same tensor shape. + // Example: + // given a slice [2:4, :, 3:] and "other" slice [:, 1:4, 2:4] the + // updated slice would be [:, :, 2:]. Here is why: + // dim 0: "2:4" U ":" -> ":" + // dim 1: ":" U "1-4" -> ":" + // dim 2: "3:" U "2:4" -> "2:" + void UpdateToCover(const TensorSlice& other); + + // Returns true if the length field was specified in an Extent. + static bool HasExtentLength(const TensorSliceProto::Extent& extent); + + // Returns the value of the length field in an Extent, or -1 if it + // is not present. + static int64_t GetExtentLength(const TensorSliceProto::Extent& extent); + + private: + // a length value of kFullExtent (-1) means we have a full slice at this + // dimension. It's defined in tensor_slice.cc. + static const int64_t kFullExtent; + + // TODO(yangke): switch to Eigen once it supports variable size arrays. + // A value of + absl::InlinedVector starts_; + absl::InlinedVector lengths_; +}; + +template +void TensorSlice::FillIndicesAndSizes( + const TensorShape& shape, Eigen::DSizes* indices, + Eigen::DSizes* sizes) const { + CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape " + << "slices: shape = " << shape.DebugString() + << ", slice = " << DebugString(); + CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from " + << "a slice of dimension " << dims(); + for (int d = 0; d < dims(); ++d) { + if (IsFullAt(d)) { + (*indices)[d] = 0; + (*sizes)[d] = shape.dim_size(d); + } else { + (*indices)[d] = starts_[d]; + (*sizes)[d] = lengths_[d]; + } + } + for (int d = dims(); d < NDIMS; ++d) { + (*indices)[d] = 0; + (*sizes)[d] = 1; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_testutil.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_testutil.h new file mode 100644 index 00000000..53ad5969 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_testutil.h @@ -0,0 +1,162 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace test { + +// Constructs a scalar tensor with 'val'. +template +Tensor AsScalar(const T& val) { + Tensor ret(DataTypeToEnum::value, {}); + ret.scalar()() = val; + return ret; +} + +// Constructs a flat tensor with 'vals'. +template +Tensor AsTensor(gtl::ArraySlice vals) { + Tensor ret(DataTypeToEnum::value, {static_cast(vals.size())}); + std::copy_n(vals.data(), vals.size(), ret.flat().data()); + return ret; +} + +// Constructs a tensor of "shape" with values "vals". +template +Tensor AsTensor(gtl::ArraySlice vals, const TensorShape& shape) { + Tensor ret; + CHECK(ret.CopyFrom(AsTensor(vals), shape)); + return ret; +} + +// Fills in '*tensor' with 'vals'. E.g., +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillValues(&x, {11, 21, 21, 22}); +template +void FillValues(Tensor* tensor, gtl::ArraySlice vals) { + auto flat = tensor->flat(); + CHECK_EQ(flat.size(), vals.size()); + if (flat.size() > 0) { + std::copy_n(vals.data(), vals.size(), flat.data()); + } +} + +// Fills in '*tensor' with 'vals', converting the types as needed. +template +void FillValues(Tensor* tensor, std::initializer_list vals) { + auto flat = tensor->flat(); + CHECK_EQ(flat.size(), vals.size()); + if (flat.size() > 0) { + size_t i = 0; + for (auto itr = vals.begin(); itr != vals.end(); ++itr, ++i) { + flat(i) = T(*itr); + } + } +} + +// Fills in '*tensor' with a sequence of value of val, val+1, val+2, ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillIota(&x, 1.0); +template +void FillIota(Tensor* tensor, const T& val) { + auto flat = tensor->flat(); + std::iota(flat.data(), flat.data() + flat.size(), val); +} + +// Fills in '*tensor' with a sequence of value of fn(0), fn(1), ... +// Tensor x(&alloc, DT_FLOAT, TensorShape({2, 2})); +// test::FillFn(&x, [](int i)->float { return i*i; }); +template +void FillFn(Tensor* tensor, std::function fn) { + auto flat = tensor->flat(); + for (int i = 0; i < flat.size(); ++i) flat(i) = fn(i); +} + +// Expects "x" and "y" are tensors of the same type, same shape, and identical +// values (within 4 ULPs for floating point types unless explicitly disabled). +enum class Tolerance { + kNone, + kDefault, +}; +void ExpectEqual(const Tensor& x, const Tensor& y, + Tolerance t = Tolerance ::kDefault); + +// Expects "x" and "y" are tensors of the same (floating point) type, +// same shape and element-wise difference between x and y is no more +// than atol + rtol * abs(x). If atol or rtol is negative, the data type's +// epsilon * kSlackFactor is used. 
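+// For example, ExpectClose(x, y) compares with both tolerances derived from
+// the element type's epsilon, while ExpectClose(x, y, 1e-6, 0.0) enforces an
+// absolute tolerance of 1e-6 with no relative component.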
+void ExpectClose(const Tensor& x, const Tensor& y, double atol = -1.0, + double rtol = -1.0); + +// Expects "x" and "y" are tensors of the same type T, same shape, and +// equal values. Consider using ExpectEqual above instead. +template +void ExpectTensorEqual(const Tensor& x, const Tensor& y) { + EXPECT_EQ(x.dtype(), DataTypeToEnum::value); + ExpectEqual(x, y); +} + +::testing::AssertionResult IsSameType(const Tensor& x, const Tensor& y); +::testing::AssertionResult IsSameShape(const Tensor& x, const Tensor& y); + +template +void ExpectTensorEqual(const Tensor& x, const Tensor& y, + std::function is_equal) { + EXPECT_EQ(x.dtype(), DataTypeToEnum::value); + ASSERT_TRUE(IsSameType(x, y)); + ASSERT_TRUE(IsSameShape(x, y)); + + const T* Tx = x.unaligned_flat().data(); + const T* Ty = y.unaligned_flat().data(); + auto size = x.NumElements(); + int max_failures = 10; + int num_failures = 0; + for (decltype(size) i = 0; i < size; ++i) { + EXPECT_TRUE(is_equal(Tx[i], Ty[i])) << "i = " << (++num_failures, i); + ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up."; + } +} + +// Expects "x" and "y" are tensors of the same type T, same shape, and +// approximate equal values. Consider using ExpectClose above instead. +template +void ExpectTensorNear(const Tensor& x, const Tensor& y, double atol) { + EXPECT_EQ(x.dtype(), DataTypeToEnum::value); + ExpectClose(x, y, atol, /*rtol=*/0.0); +} + +// For tensor_testutil_test only. +namespace internal_test { +::testing::AssertionResult IsClose(Eigen::half x, Eigen::half y, + double atol = -1.0, double rtol = -1.0); +::testing::AssertionResult IsClose(float x, float y, double atol = -1.0, + double rtol = -1.0); +::testing::AssertionResult IsClose(double x, double y, double atol = -1.0, + double rtol = -1.0); +} // namespace internal_test + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_types.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_types.h new file mode 100644 index 00000000..2381d6b7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_types.h @@ -0,0 +1,199 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +// Helper to define Tensor types given that the scalar is of type T. +template +struct TTypes { + // Rank- tensor of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Tensor; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstTensor; + + // Unaligned Rank- tensor of scalar type T. 
+ typedef Eigen::TensorMap > + UnalignedTensor; + typedef Eigen::TensorMap< + Eigen::Tensor > + UnalignedConstTensor; + + typedef Eigen::TensorMap, + Eigen::Aligned> + Tensor32Bit; + + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, + Eigen::Aligned> + Scalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType>, + Eigen::Aligned> + ConstScalar; + + // Unaligned Scalar tensor of scalar type T. + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType> > + UnalignedScalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType> > + UnalignedConstScalar; + + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Flat; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstFlat; + typedef Eigen::TensorMap, + Eigen::Aligned> + Vec; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstVec; + + // Unaligned Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap > + UnalignedFlat; + typedef Eigen::TensorMap< + Eigen::Tensor > + UnalignedConstFlat; + typedef Eigen::TensorMap > + UnalignedVec; + typedef Eigen::TensorMap< + Eigen::Tensor > + UnalignedConstVec; + + // Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Matrix; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstMatrix; + + // Unaligned Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap > + UnalignedMatrix; + typedef Eigen::TensorMap< + Eigen::Tensor > + UnalignedConstMatrix; +}; + +typedef typename TTypes::Tensor32Bit::Index Index32; + +template +bool SafeFor32BitIndexing(const Eigen::DSizes& in) { + for (int i = 0; i < NumDims; ++i) { + if (in[i] > std::numeric_limits::max()) return false; + } + return true; +} + +template +bool SafeFor32BitIndexing(const Eigen::array& in) { + for (size_t i = 0; i < NumDims; ++i) { + if (in[i] > std::numeric_limits::max()) return false; + } + return true; +} + +template ::Tensor32Bit> +bool SafeFor32BitIndexing(TensorType in) { + return in.size() <= std::numeric_limits::max(); +} + +template +Eigen::DSizes To32Bit( + const Eigen::DSizes& in) { + DCHECK(SafeFor32BitIndexing(in)); + Eigen::DSizes out; + for (int i = 0; i < NumDims; ++i) { + out[i] = static_cast(in[i]); + } + return out; +} + +template +Eigen::array To32Bit(const Eigen::array& in) { + DCHECK(SafeFor32BitIndexing(in)); + Eigen::array out; + for (size_t i = 0; i < NumDims; ++i) { + out[i] = static_cast(in[i]); + } + return out; +} + +template +typename TTypes::Tensor32Bit +To32Bit(TensorType in) { + typedef typename TTypes::Tensor32Bit RetType; + DCHECK(SafeFor32BitIndexing(in)); + return RetType(in.data(), To32Bit(in.dimensions())); +} + +namespace internal { + +template +struct MaybeWith32BitIndexingImpl { + template + void operator()(Func func, Args&&... args) const { + func(std::forward(args)...); + } +}; + +template <> +struct MaybeWith32BitIndexingImpl { + template + void operator()(Func func, Args&&... args) const { + auto all = [](const auto&... bool_vals) { + for (bool b : {bool_vals...}) { + if (!b) return false; + } + return true; + }; + if (all(SafeFor32BitIndexing(std::forward(args))...)) { + func(To32Bit(std::forward(args))...); + } else { + func(std::forward(args)...); + } + } +}; + +} // namespace internal + +template +void MaybeWith32BitIndexing(Func func, Args&&... 
args) { + return internal::MaybeWith32BitIndexingImpl()( + func, std::forward(args)...); +} + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tensor_util.h b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_util.h new file mode 100644 index 00000000..eec2bd3f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tensor_util.h @@ -0,0 +1,358 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensor { + +// DeepCopy returns a tensor whose contents are a deep copy of the +// contents of 'other'. This function is intended only for +// convenience, not speed. +// +// REQUIRES: 'other' must point to data stored in CPU memory. +// REQUIRES: 'other' must be a Tensor of a copy-able type if +// 'other' is not appropriately memory-aligned. +Tensor DeepCopy(const Tensor& other); + +// Deep copies input to output. This function is similar to above, but assumes +// that the memory for the output has already been allocated. +void DeepCopy(const Tensor& input, Tensor* output); + +// Concatenates 'tensors' into a single tensor, along their 0th dimension. +// +// REQUIRES: All members of 'tensors' must have the same data type parameter. +// REQUIRES: Each member of 'tensors' must have at least one dimension. +// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory. +// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it +// is not appropriately memory-aligned. +absl::Status Concat(absl::Span tensors, Tensor* result); + +// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th +// dimension. The ith output tensor has 0th-dimension size 'sizes[i]'. +// +// REQUIRES: 'tensor' must have at least one dimension. +// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'. +// REQUIRES: 'tensor' must point to data stored in CPU memory. +// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not +// appropriately memory-aligned. +// +// Split() and Concat() are inverse operations. 
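+// For example, splitting a tensor of shape [6, 3] with sizes {2, 4} produces
+// two tensors of shapes [2, 3] and [4, 3]; concatenating those two tensors
+// again reproduces the original [6, 3] tensor.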
+absl::Status Split(const Tensor& tensor, absl::Span sizes, + std::vector* result); + +namespace internal { +void SetTensorProtoShape(absl::Span shape, + TensorShapeProto* shape_proto); + +template +class TensorProtoFieldHelper : public std::false_type {}; + +#define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME) \ + template <> \ + class TensorProtoFieldHelper : public std::true_type { \ + public: \ + typedef decltype( \ + std::declval().FIELDNAME##_val(0)) FieldType; \ + typedef decltype( \ + std::declval().FIELDNAME##_val()) RepeatedFieldType; \ + typedef decltype(std::declval().mutable_##FIELDNAME##_val()) \ + MutableRepeatedFieldType; \ + static MutableRepeatedFieldType GetMutableField(TensorProto* proto) { \ + return proto->mutable_##FIELDNAME##_val(); \ + } \ + static RepeatedFieldType& GetField(const TensorProto& proto) { \ + return proto.FIELDNAME##_val(); \ + } \ + } + +// The argument pairs in the following macro instantiations encode the +// mapping from C++ type ($1) to repeated field name "$2_val" used for storing +// values in TensorProto. See tensorflow/core/framework/tensor.proto. +DEFINE_PROTO_FIELD_HELPER(float, float); +DEFINE_PROTO_FIELD_HELPER(double, double); +DEFINE_PROTO_FIELD_HELPER(int8, int); +DEFINE_PROTO_FIELD_HELPER(uint8, int); +DEFINE_PROTO_FIELD_HELPER(int16, int); +DEFINE_PROTO_FIELD_HELPER(uint16, int); +DEFINE_PROTO_FIELD_HELPER(int32, int); +DEFINE_PROTO_FIELD_HELPER(uint32, uint32); +DEFINE_PROTO_FIELD_HELPER(int64_t, int64); +DEFINE_PROTO_FIELD_HELPER(uint64, uint64); +DEFINE_PROTO_FIELD_HELPER(bool, bool); +DEFINE_PROTO_FIELD_HELPER(qint8, int); +DEFINE_PROTO_FIELD_HELPER(quint8, int); +DEFINE_PROTO_FIELD_HELPER(qint16, int); +DEFINE_PROTO_FIELD_HELPER(quint16, int); +DEFINE_PROTO_FIELD_HELPER(qint32, int); +DEFINE_PROTO_FIELD_HELPER(Eigen::half, half); +DEFINE_PROTO_FIELD_HELPER(bfloat16, half); +DEFINE_PROTO_FIELD_HELPER(complex64, scomplex); +DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex); + +#undef DEFINE_PROTO_HELPER + +template +struct CopyHelper { + template + static void ToArray(SrcIter begin, SrcIter end, DstIter dst) { + using SrcType = typename std::iterator_traits::value_type; + using DstType = typename std::iterator_traits::value_type; + std::transform(begin, end, dst, [](const SrcType& x) -> DstType { + return static_cast(x); + }); + } + template + static void ToArray(SrcIter begin, SrcIter end, SrcIter dst) { + std::copy(begin, end, dst); + } + template + static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { + ToArray(begin, end, dst); + } +}; + +// Overloads for Eigen::half and bfloat16 that are 16 bits in size but are +// stored in an int32 field. 
+template <> +struct CopyHelper { + template + static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) { + std::transform(begin, end, dst, [](int x) -> Eigen::half { + return Eigen::numext::bit_cast(static_cast(x)); + }); + } + template + static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { + std::transform(begin, end, dst, [](Eigen::half h) -> int { + return static_cast(Eigen::numext::bit_cast(h)); + }); + } +}; + +template <> +struct CopyHelper { + template + static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) { + std::transform(begin, end, dst, [](int x) -> bfloat16 { + return Eigen::numext::bit_cast(static_cast(x)); + }); + } + template + static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { + std::transform(begin, end, dst, [](bfloat16 bf16) -> int { + return static_cast(Eigen::numext::bit_cast(bf16)); + }); + } +}; + +// Overloads for complex types that store real and imaginary parts +// at indices 2*i and 2*i+1 in float or double field. +template +struct CopyHelper> { + template + static void ToArray(SrcIter begin, SrcIter end, std::complex* dst) { + RealType* real_dst = reinterpret_cast(dst); + std::copy(begin, end, real_dst); + } + + template + static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { + size_t n = std::distance(begin, end); + const RealType* real_begin = reinterpret_cast(&(*begin)); + std::copy_n(real_begin, 2 * n, dst); + } +}; + +// Helper class to extract and insert values into TensorProto represented as +// repeated fields. +template +class TensorProtoHelper : public std::true_type { + public: + using FieldHelper = TensorProtoFieldHelper; + using FieldType = typename TensorProtoFieldHelper::FieldType; + + static DataType GetDataType() { return DataTypeToEnum::value; } + + // Returns the number of values of type T encoded in the proto. + static size_t NumValues(const TensorProto& proto) { + size_t raw_size = FieldHelper::GetField(proto).size(); + return is_complex::value ? raw_size / 2 : raw_size; + } + + static void AddValue(const T& value, TensorProto* proto) { + const T* val_ptr = &value; + AddValues(val_ptr, val_ptr + 1, proto); + } + + static T GetValue(size_t index, const TensorProto& proto) { + const size_t stride = is_complex::value ? 2 : 1; + T val; + CopyHelper::ToArray( + FieldHelper::GetField(proto).begin() + stride * index, + FieldHelper::GetField(proto).begin() + stride * (index + 1), &val); + return val; + } + + template + static void AddValues(IterType begin, IterType end, TensorProto* proto) { + size_t n = std::distance(begin, end); + FieldType* dst = AppendUninitialized(n, proto); + CopyHelper::FromArray(begin, end, dst); + } + + template + static void CopyValues(IterType dst, const TensorProto& proto) { + CopyHelper::ToArray(FieldHelper::GetField(proto).begin(), + FieldHelper::GetField(proto).end(), dst); + } + + static void Truncate(size_t new_size, TensorProto* proto) { + if (is_complex::value) new_size *= 2; + FieldHelper::GetMutableField(proto)->Truncate(new_size); + } + + static FieldType* AppendUninitialized(size_t n, TensorProto* proto) { + if (is_complex::value) n *= 2; + auto* field = FieldHelper::GetMutableField(proto); + field->Reserve(field->size() + n); + return reinterpret_cast(field->AddNAlreadyReserved(n)); + } +}; + +// Specialization for string. 
+template <> +class TensorProtoHelper : public std::true_type { + public: + static DataType GetDataType() { return DataType::DT_STRING; } + static void AddValue(const string& value, TensorProto* proto) { + *proto->mutable_string_val()->Add() = value; + } + template + static void AddValues(IterType begin, IterType end, TensorProto* proto) { + for (IterType it = begin; it != end; ++it) { + AddValue(*it, proto); + } + } + template + static void CopyToTensorContent(IterType begin, IterType end, + TensorProto* proto) { + AddValues(begin, end, proto); + } +}; + +template +typename std::enable_if::value, + TensorProto>::type +CreateTensorProto(IterType values_begin, IterType values_end, + const size_t values_size, + const absl::Span shape) { + TensorProto tensor; + TensorShapeProto tensor_shape_proto; + internal::SetTensorProtoShape(shape, &tensor_shape_proto); + if (TensorShape(tensor_shape_proto).num_elements() != values_size) { + LOG(ERROR) << "Shape and number of values (" << values_size + << ") are incompatible."; + return tensor; + } + using TypeHelper = internal::TensorProtoHelper; + tensor.set_dtype(TypeHelper::GetDataType()); + *tensor.mutable_tensor_shape() = std::move(tensor_shape_proto); + TypeHelper::AddValues(values_begin, values_end, &tensor); + return tensor; +} + +} // namespace internal + +// Creates a 'TensorProto' with the specified shape and values. The dtype and a +// field to represent data values of the returned 'TensorProto' are determined +// based on Type. Note that unless the argument provided to `values` is already +// an absl::Span, `Type` will need to be provided as a template parameter--the +// compiler can't infer it: +// auto proto = CreateTensorProtoSpan(my_array, shape); +template +typename std::enable_if::value, + TensorProto>::type +CreateTensorProtoSpan(const absl::Span values, + const absl::Span shape) { + return internal::CreateTensorProto(values.begin(), values.end(), + values.size(), shape); +} + +// Version of the above that's more convenient if `values` is an std::vector, in +// which case Type can automatically be inferred: +// auto proto = CreateTensorProto(my_vector, shape); +template +typename std::enable_if::value, + TensorProto>::type +CreateTensorProto(const std::vector& values, + const absl::Span shape) { + // This awkward iterator passing is essentially just to support vector, + // otherwise we could just represent the vector as a Span. + return internal::CreateTensorProto(values.begin(), values.end(), + values.size(), shape); +} + +// Converts values in tensor to run-length encoded compressed form. +// +// The elements of a tensor can be stored in a TensorProto in one of the +// following two forms: +// 1. As a raw byte string in the field `tensor_content` containing the +// serialized in-memory representation of the tensor. +// 2. As values of a repeated field depending on the datatype, e.g. that +// values of a DT_FLOAT tensor would be stored in the repeated field +// `float_val`. +// Storage scheme 2 may use a simple form of run-length encoding to compress +// data: If the values contains a tail of identical values, the repeated field +// will be truncated such that the number of values in the repeated field is +// less than the number of elements implied by the field`tensor_shape`. The +// original tensor can be recovered by repeating the final value in the repeated +// field. 
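+//
+// For example, a DT_FLOAT tensor of shape [6] with values
+// {1, 2, 5, 5, 5, 5} may be stored with `float_val` holding only {1, 2, 5};
+// the trailing 5 is implied for the remaining elements.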
+// +// The TensorProto will be compressed if a) the tensor contains at least +// min_num_elements elements and b) the compressed tensor proto is would be at +// most the size of the original tensor proto divided by min_compression_ratio. +// +// Returns true if the tensor was compressed. +bool CompressTensorProtoInPlace(int64_t min_num_elements, + float min_compression_ratio, + TensorProto* tensor); + +inline bool CompressTensorProtoInPlace(TensorProto* tensor) { + static const int64_t kDefaultMinNumElements = 64; + static const float kDefaultMinCompressionRatio = 2.0f; + return CompressTensorProtoInPlace(kDefaultMinNumElements, + kDefaultMinCompressionRatio, tensor); +} + +// Make a TensorShape from the contents of shape_t. Shape_t must be a +// 1-dimensional tensor of type int32 or int64. +absl::Status MakeShape(const Tensor& shape_t, TensorShape* out); + +} // namespace tensor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/thread_factory.h b/third_party/tflite-hdrs/tensorflow/core/framework/thread_factory.h new file mode 100644 index 00000000..769ada29 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/thread_factory.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_ + +#include +#include + +#include "tensorflow/core/platform/types.h" + +namespace tsl { +class Thread; +} // namespace tsl +namespace tensorflow { +using tsl::Thread; // NOLINT + +// Virtual interface for an object that creates threads. +class ThreadFactory { + public: + virtual ~ThreadFactory() {} + + // Runs `fn` asynchronously in a different thread. `fn` may block. + // + // NOTE: The caller is responsible for ensuring that this `ThreadFactory` + // outlives the returned `Thread`. + virtual std::unique_ptr StartThread(const string& name, + std::function fn) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_THREAD_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/tracking_allocator.h b/third_party/tflite-hdrs/tensorflow/core/framework/tracking_allocator.h new file mode 100644 index 00000000..ba54b2c5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/tracking_allocator.h @@ -0,0 +1,37 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_ + +#include + +#include "xla/tsl/framework/tracking_allocator.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// NOLINTEND(misc-unused-using-decls) +using tsl::AllocRecord; +using tsl::TrackingAllocator; +// NOLINTEND(misc-unused-using-decls) + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/type_index.h b/third_party/tflite-hdrs/tensorflow/core/framework/type_index.h new file mode 100644 index 00000000..d73ca527 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/type_index.h @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_ + +#include + +#if defined(__GXX_RTTI) || defined(_CPPRTTI) +#include +#endif // __GXX_RTTI + +#include "tensorflow/core/platform/hash.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// On some platforms, we would like to avoid using RTTI in order to have smaller +// binary sizes. This file provides a thin TypeIndex class that mimics +// std::type_index but does not use RTTI (with a minimal set of functions needed +// by the TensorFlow framework, and more can be added if necessary). In the +// absence of RTTI, it does not provide the actual name of the type, and only +// returns a pre-baked string specifying that RTTI is disabled. The hash code +// provided in this class is unique for each class. However, it is generated at +// runtime so this hash code should not be serialized - the value for the same +// type can change from run to run. +class TypeIndex { + public: + TypeIndex(const TypeIndex& src) : hash_(src.hash_), name_(src.name_) {} + TypeIndex& operator=(const TypeIndex& src) { + hash_ = src.hash_; + name_ = src.name_; + return *this; + } + bool operator==(const TypeIndex& rhs) const { return (hash_ == rhs.hash_); } + bool operator!=(const TypeIndex& rhs) const { return (hash_ != rhs.hash_); } + ~TypeIndex() {} + + const char* name() const { return name_; } + + uint64 hash_code() const { return hash_; } + + // Returns a TypeIndex object that corresponds to a typename. 
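+  // For example (`MyType` below stands for any C++ type):
+  //   TypeIndex idx = TypeIndex::Make<MyType>();
+  //   uint64 h = idx.hash_code();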
+ template + static TypeIndex Make() { +#ifdef PLATFORM_CLOUD_TPU + static bool hash_bit[1]; + return TypeIndex(static_cast(reinterpret_cast(hash_bit)), + typeid(T).name()); +#endif +#if defined(__GXX_RTTI) || defined(_CPPRTTI) + + // Use a hash based on the type name to avoid issues due to RTLD_LOCAL on + // MacOS (b/156979412). + return TypeIndex(Hash64(typeid(T).name()), typeid(T).name()); + +#else + static bool hash_bit[1]; +#if TARGET_OS_OSX + // Warn MacOS users that not using RTTI can cause problems (b/156979412). +#warning \ + "Compiling with RTTI disabled on MacOS can cause problems when comparing " \ + "types across shared libraries." +#endif // TARGET_OS_OSX + + // No type names available. + return TypeIndex(static_cast(reinterpret_cast(hash_bit)), + "[RTTI disabled]"); +#endif // __GXX_RTTI + } + + private: + // We hide the constructor of the TypeIndex class. Use the templated + // Make() function to create a TypeIndex object. + explicit TypeIndex(const uint64 hash, const char* name) + : hash_(hash), name_(name) {} + uint64 hash_; + const char* name_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/type_traits.h b/third_party/tflite-hdrs/tensorflow/core/framework/type_traits.h new file mode 100644 index 00000000..ac1c9e86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/type_traits.h @@ -0,0 +1,38 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_ + +#include +#include + +#include "xla/tsl/framework/type_traits.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::false_type; +using tsl::is_complex; +using tsl::is_quantized; +using tsl::is_simple_type; +using tsl::true_type; +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/typed_allocator.h b/third_party/tflite-hdrs/tensorflow/core/framework/typed_allocator.h new file mode 100644 index 00000000..6d89983b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/typed_allocator.h @@ -0,0 +1,135 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPED_ALLOCATOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TYPED_ALLOCATOR_H_ + +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Variant; + +// Convenience functions to do typed allocation. C++ constructors +// and destructors are invoked for complex types if necessary. +class TypedAllocator { + public: + // May return NULL if the tensor has too many elements to represent in a + // single allocation. + template + static T* Allocate(Allocator* raw_allocator, size_t num_elements, + const AllocationAttributes& allocation_attr) { + // TODO(jeff): Do we need to allow clients to pass in alignment + // requirements? + + if (num_elements > (std::numeric_limits::max() / sizeof(T))) { + return nullptr; + } + + void* p = + raw_allocator->AllocateRaw(Allocator::kAllocatorAlignment, + sizeof(T) * num_elements, allocation_attr); + T* typed_p = reinterpret_cast(p); + if (typed_p) RunCtor(raw_allocator, typed_p, num_elements); + return typed_p; + } + + template + static void Deallocate(Allocator* raw_allocator, T* ptr, + size_t num_elements) { + if (ptr) { + RunDtor(raw_allocator, ptr, num_elements); + raw_allocator->DeallocateRaw(ptr, Allocator::kAllocatorAlignment, + sizeof(T) * num_elements); + } + } + + private: + // No constructors or destructors are run for simple types + template + static void RunCtor(Allocator* raw_allocator, T* p, size_t n) { + static_assert(is_simple_type::value, "T is not a simple type."); + } + + template + static void RunDtor(Allocator* raw_allocator, T* p, size_t n) {} + + static void RunVariantCtor(Variant* p, size_t n); + + static void RunVariantDtor(Variant* p, size_t n); +}; + +template <> +/* static */ +inline void TypedAllocator::RunCtor(Allocator* raw_allocator, tstring* p, + size_t n) { + if (!raw_allocator->AllocatesOpaqueHandle()) { + for (size_t i = 0; i < n; ++p, ++i) new (p) tstring(); + } +} + +template <> +/* static */ +inline void TypedAllocator::RunDtor(Allocator* raw_allocator, tstring* p, + size_t n) { + if (!raw_allocator->AllocatesOpaqueHandle()) { + for (size_t i = 0; i < n; ++p, ++i) p->~tstring(); + } +} + +template <> +/* static */ +inline void TypedAllocator::RunCtor(Allocator* raw_allocator, ResourceHandle* p, + size_t n) { + if (!raw_allocator->AllocatesOpaqueHandle()) { + for (size_t i = 0; i < n; ++p, ++i) new (p) ResourceHandle(); + } +} + +template <> +/* static */ +inline void TypedAllocator::RunDtor(Allocator* raw_allocator, ResourceHandle* p, + size_t n) { + if (!raw_allocator->AllocatesOpaqueHandle()) { + for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); + } +} + +template <> +/* static */ +inline void TypedAllocator::RunCtor(Allocator* raw_allocator, Variant* p, + size_t n) { + if (!raw_allocator->AllocatesOpaqueHandle()) { + RunVariantCtor(p, n); + } +} + +template <> +/* static */ +inline void TypedAllocator::RunDtor(Allocator* raw_allocator, Variant* p, + size_t n) { + if (!raw_allocator->AllocatesOpaqueHandle()) { + RunVariantDtor(p, n); + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TYPED_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/types.h 
b/third_party/tflite-hdrs/tensorflow/core/framework/types.h new file mode 100644 index 00000000..c91e262c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/types.h @@ -0,0 +1,530 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_ +#define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_ + +#include +#include +#include +#include + +#include "absl/numeric/bits.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/framework/device_type.h" +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Variant; + +// MemoryType is used to describe whether input or output Tensors of +// an OpKernel should reside in "Host memory" (e.g., CPU memory) or +// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU +// devices). +enum MemoryType { + DEVICE_MEMORY = 0, + HOST_MEMORY = 1, +}; + +using tsl::DeviceType; // NOLINT + +// Convenient constants that can be passed to a DeviceType constructor. +// See comments for CreateOpKernel in op_kernel.h for uses of DEVICE_DEFAULT +// and other device types. 
+TF_EXPORT extern const char* const DEVICE_DEFAULT; // "DEFAULT" +TF_EXPORT extern const char* const DEVICE_CPU; // "CPU" +TF_EXPORT extern const char* const DEVICE_GPU; // "GPU" +TF_EXPORT extern const char* const DEVICE_TPU; // "TPU" +TF_EXPORT extern const char* const DEVICE_TPU_SYSTEM; // "TPU_SYSTEM" + +template +struct DeviceName {}; + +template <> +struct DeviceName { + static const std::string value; +}; + +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +template <> +struct DeviceName { + static const std::string value; +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +typedef absl::InlinedVector MemoryTypeVector; +typedef absl::Span MemoryTypeSlice; + +typedef absl::InlinedVector DataTypeVector; +typedef absl::Span DataTypeSlice; + +typedef absl::InlinedVector DeviceTypeVector; +typedef absl::InlinedVector, 4UL> + PrioritizedDeviceTypeVector; + +// Convert the enums to strings for errors: +std::string DataTypeString(DataType dtype); +std::string DeviceTypeString(const DeviceType& device_type); +std::string DataTypeSliceString(const DataTypeSlice dtypes); +inline std::string DataTypeVectorString(const DataTypeVector& dtypes) { + return DataTypeSliceString(dtypes); +} + +// DataTypeSet represents a set of DataType values as a simple and efficient +// bit mask. Note that DataTypeSet cannot represent all DataType values; it +// cannot represent any of the DT_*_REF values. +class DataTypeSet { + private: + const uint32 mask_; + + static constexpr uint32 kNumBits = 32; + + public: + constexpr DataTypeSet(const DataTypeSet& other) : mask_(other.mask_) {} + explicit constexpr DataTypeSet(uint32 mask) : mask_(mask) {} + + constexpr bool Contains(DataType dt) const { + return (static_cast(dt) < kNumBits) && + ((mask_ >> static_cast(dt)) & 1u) != 0u; + } + + class Iterator { + const DataTypeSet& set_; + uint32 pos_; + + public: + Iterator(const DataTypeSet& set, uint32 pos) : set_(set), pos_(pos) { + DCHECK_LE(pos, kNumBits); + } + DataType operator*() const { return static_cast(pos_); } + Iterator& operator++() { + ++pos_; + DCHECK_LE(pos_, kNumBits); + if (pos_ < kNumBits) { + uint32 remaining_mask = set_.mask_ >> pos_; + if (remaining_mask != 0u) { + pos_ += absl::countr_zero(remaining_mask); + } + } + DCHECK_LE(pos_, kNumBits); + return *this; + } + bool operator==(const Iterator& other) const { return pos_ == other.pos_; } + bool operator!=(const Iterator& other) const { return !(*this == other); } + size_t operator-(const Iterator& other) const { + return this->pos_ - other.pos_; + } + }; + + Iterator begin() const { + // The begin position is the index of the first bit set to 1 in the entire + // bit mask. If there are no bits set to 1, then the index is 0. + if (mask_ != 0) { + return Iterator(*this, absl::countr_zero(mask_)); + } + // The set is empty. + return Iterator(*this, 0); + } + + Iterator end() const { + // The end position is the index of the highest bit that is set, plus 1. + // If there are no bits set to 1, then the index is 0. + if (mask_ != 0) { + return Iterator(*this, kNumBits - absl::countl_zero(mask_)); + } + // The set is empty. + return Iterator(*this, 0); + } + + size_t size() const { return absl::popcount(mask_); } + + constexpr DataTypeSet operator|(const DataTypeSet& other) const { + return DataTypeSet(mask_ | other.mask_); + } +}; + +// If "sp" names a valid type, store it in "*dt" and return true. Otherwise, +// return false. 
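`DataTypeSet` above packs a set of `DataType` values into a 32-bit mask, and its iterator advances by jumping to the next set bit with `absl::countr_zero`. A standalone sketch of that bit-mask-as-set idea (the C++20 `<bit>` header replaces Abseil here, and the enum values are made up rather than the real `DT_*` constants):

```cpp
#include <bit>  // std::countr_zero, std::popcount (assumes C++20)
#include <cstdint>
#include <iostream>

// Bit i set means "enum value i is in the set".
constexpr std::uint32_t ToBit(int enum_value) { return 1u << enum_value; }

int main() {
  // Hypothetical enum values 1, 4 and 9 as stand-ins for DT_* constants.
  std::uint32_t mask = ToBit(1) | ToBit(4) | ToBit(9);

  std::cout << "size = " << std::popcount(mask) << "\n";

  // Iterate by repeatedly finding the lowest set bit, mirroring
  // Iterator::operator++ in the header above.
  for (std::uint32_t m = mask; m != 0; m &= m - 1) {
    std::cout << "contains enum value " << std::countr_zero(m) << "\n";
  }
  return 0;
}
```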
+bool DataTypeFromString(absl::string_view sp, DataType* dt); + +constexpr inline DataTypeSet ToSet(DataType dt) { + return DataTypeSet(1u << static_cast(dt)); +} + +// DT_FLOAT + kDataTypeRefOffset == DT_FLOAT_REF, etc. +enum { kDataTypeRefOffset = 100 }; +inline bool IsRefType(DataType dtype) { + return dtype > static_cast(kDataTypeRefOffset); +} +inline DataType MakeRefType(DataType dtype) { + DCHECK(!IsRefType(dtype)); + return static_cast(dtype + kDataTypeRefOffset); +} +inline DataType RemoveRefType(DataType dtype) { + DCHECK(IsRefType(dtype)); + return static_cast(dtype - kDataTypeRefOffset); +} +inline DataType BaseType(DataType dtype) { + return IsRefType(dtype) ? RemoveRefType(dtype) : dtype; +} + +// Returns true if the actual type is the same as or ref of the expected type. +inline bool TypesCompatible(DataType expected, DataType actual) { + return expected == actual || expected == BaseType(actual); +} + +// Does not include _ref types. +constexpr DataTypeSet kAllTypes = + ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT8) | + ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT8) | ToSet(DT_STRING) | + ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | + ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | + ToSet(DT_QUINT16) | ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_RESOURCE) | + ToSet(DT_VARIANT) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | + ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT8_E5M2) | ToSet(DT_FLOAT8_E4M3FN) | + ToSet(DT_INT4) | ToSet(DT_UINT4); + +inline const DataTypeSet& AllTypes() { return kAllTypes; } + +#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) + +// Types that support '<' and '>'. +constexpr DataTypeSet kRealNumberTypes = + ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | + ToSet(DT_UINT8) | ToSet(DT_INT16) | ToSet(DT_INT8) | ToSet(DT_UINT16) | + ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); +inline const DataTypeSet& RealNumberTypes() { return kRealNumberTypes; } + +// Return the list of all numeric types. +// Includes complex and quantized types. +// NOTE: On Android, we only include the float and int32 types for now. +const DataTypeSet kNumberTypes = + ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT64) | ToSet(DT_INT32) | + ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | + ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_QINT8) | + ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | + ToSet(DT_HALF) | ToSet(DT_UINT32) | ToSet(DT_UINT64) | ToSet(DT_BFLOAT16); +inline const DataTypeSet& NumberTypes() { return kNumberTypes; } + +constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | + ToSet(DT_QINT16) | ToSet(DT_QUINT16) | + ToSet(DT_QINT32); +inline const DataTypeSet& QuantizedTypes() { return kQuantizedTypes; } + +// Types that support '<' and '>', including quantized types. 
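The `kDataTypeRefOffset` scheme above turns reference types into plain integer arithmetic: `DT_FLOAT + 100 == DT_FLOAT_REF`, and "compatible" means the actual type equals the expected type or is a ref of it. A small sketch with the same offset and two representative enum values (the full numbering lives in the generated `types.pb.h`):

```cpp
#include <cassert>

// Representative stand-ins for the proto-generated enum; only the offset
// relationship matters for this sketch.
enum DataType { DT_FLOAT = 1, DT_FLOAT_REF = 101 };
enum { kDataTypeRefOffset = 100 };

constexpr bool IsRefType(int dt) { return dt > kDataTypeRefOffset; }
constexpr int MakeRefType(int dt) { return dt + kDataTypeRefOffset; }
constexpr int RemoveRefType(int dt) { return dt - kDataTypeRefOffset; }
constexpr int BaseType(int dt) { return IsRefType(dt) ? RemoveRefType(dt) : dt; }

// Equal, or the actual type is a ref of the expected base type.
constexpr bool TypesCompatible(int expected, int actual) {
  return expected == actual || expected == BaseType(actual);
}

int main() {
  static_assert(MakeRefType(DT_FLOAT) == DT_FLOAT_REF, "offset scheme");
  static_assert(BaseType(DT_FLOAT_REF) == DT_FLOAT, "round trip");
  assert(TypesCompatible(DT_FLOAT, DT_FLOAT_REF));   // ref satisfies base
  assert(!TypesCompatible(DT_FLOAT_REF, DT_FLOAT));  // base does not satisfy ref
  return 0;
}
```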
+const DataTypeSet kRealAndQuantizedTypes = + ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_INT64) | + ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | + ToSet(DT_QINT8) | ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | + ToSet(DT_QINT32) | ToSet(DT_HALF) | ToSet(DT_BFLOAT16); +inline const DataTypeSet& RealAndQuantizedTypes() { + return kRealAndQuantizedTypes; +} + +#elif defined(__ANDROID_TYPES_FULL__) + +constexpr DataTypeSet kRealNumberTypes = + ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_HALF); +inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } + +constexpr DataTypeSet kNumberTypes = + ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | + ToSet(DT_QUINT8) | ToSet(DT_QINT32) | ToSet(DT_HALF); +inline DataTypeSet NumberTypes() { return kNumberTypes; } + +constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | + ToSet(DT_QINT16) | ToSet(DT_QUINT16) | + ToSet(DT_QINT32); +inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } + +constexpr DataTypeSet kRealAndQuantizedTypes = + ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_INT64) | ToSet(DT_QINT8) | + ToSet(DT_QUINT8) | ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | + ToSet(DT_HALF); +inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } + +#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) + +constexpr DataTypeSet kRealNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32); +inline DataTypeSet RealNumberTypes() { return kRealNumberTypes; } + +constexpr DataTypeSet kNumberTypes = ToSet(DT_FLOAT) | ToSet(DT_INT32) | + ToSet(DT_QINT8) | ToSet(DT_QUINT8) | + ToSet(DT_QINT32); +inline DataTypeSet NumberTypes() { return kNumberTypes; } + +constexpr DataTypeSet kQuantizedTypes = ToSet(DT_QINT8) | ToSet(DT_QUINT8) | + ToSet(DT_QINT16) | ToSet(DT_QUINT16) | + ToSet(DT_QINT32); +inline DataTypeSet QuantizedTypes() { return kQuantizedTypes; } + +constexpr DataTypeSet kRealAndQuantizedTypes = + ToSet(DT_FLOAT) | ToSet(DT_INT32) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | + ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32); +inline DataTypeSet RealAndQuantizedTypes() { return kRealAndQuantizedTypes; } + +#endif // defined(IS_MOBILE_PLATFORM) + +// Validates type T for whether it is a supported DataType. +template +struct IsValidDataType; + +// DataTypeToEnum::v() and DataTypeToEnum::value are the DataType +// constants for T, e.g. DataTypeToEnum::v() is DT_FLOAT. +template +struct DataTypeToEnum { + static_assert(IsValidDataType::value, "Specified Data Type not supported"); +}; // Specializations below + +// EnumToDataType::Type is the type for DataType constant VALUE, e.g. +// EnumToDataType::Type is float. +template +struct EnumToDataType {}; // Specializations below + +// Template specialization for both DataTypeToEnum and EnumToDataType. 
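`DataTypeToEnum<T>` and `EnumToDataType<VALUE>` above form a two-way compile-time mapping between C++ types and `DataType` enum values. A self-contained miniature of that trait pair, using illustrative enum values and `int` in place of the real fixed-width typedefs:

```cpp
#include <type_traits>

enum DataType { DT_FLOAT = 1, DT_INT32 = 3 };  // illustrative values only

template <typename T> struct DataTypeToEnum;   // undefined for unsupported T
template <DataType V> struct EnumToDataType;   // undefined for unmapped values

template <> struct DataTypeToEnum<float> {
  static constexpr DataType value = DT_FLOAT;
};
template <> struct EnumToDataType<DT_FLOAT> { using Type = float; };

template <> struct DataTypeToEnum<int> {
  static constexpr DataType value = DT_INT32;
};
template <> struct EnumToDataType<DT_INT32> { using Type = int; };

// Typical use: pick the runtime enum from a template parameter, or the C++
// type from a compile-time enum value.
template <typename T>
constexpr DataType Dtype() { return DataTypeToEnum<T>::value; }

static_assert(Dtype<float>() == DT_FLOAT, "type -> enum");
static_assert(std::is_same<EnumToDataType<DT_INT32>::Type, int>::value,
              "enum -> type");

int main() { return 0; }
```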
+#define MATCH_TYPE_AND_ENUM(TYPE, ENUM) \ + template <> \ + struct DataTypeToEnum { \ + static DataType v() { return ENUM; } \ + static DataType ref() { return MakeRefType(ENUM); } \ + static constexpr DataType value = ENUM; \ + }; \ + template <> \ + struct IsValidDataType { \ + static constexpr bool value = true; \ + }; \ + template <> \ + struct EnumToDataType { \ + typedef TYPE Type; \ + } + +MATCH_TYPE_AND_ENUM(float, DT_FLOAT); +MATCH_TYPE_AND_ENUM(double, DT_DOUBLE); +MATCH_TYPE_AND_ENUM(int32, DT_INT32); +MATCH_TYPE_AND_ENUM(uint32, DT_UINT32); +MATCH_TYPE_AND_ENUM(uint16, DT_UINT16); +MATCH_TYPE_AND_ENUM(uint8, DT_UINT8); +MATCH_TYPE_AND_ENUM(int16, DT_INT16); +MATCH_TYPE_AND_ENUM(int8, DT_INT8); +MATCH_TYPE_AND_ENUM(tstring, DT_STRING); +MATCH_TYPE_AND_ENUM(complex64, DT_COMPLEX64); +MATCH_TYPE_AND_ENUM(complex128, DT_COMPLEX128); +MATCH_TYPE_AND_ENUM(bool, DT_BOOL); +MATCH_TYPE_AND_ENUM(qint8, DT_QINT8); +MATCH_TYPE_AND_ENUM(quint8, DT_QUINT8); +MATCH_TYPE_AND_ENUM(qint16, DT_QINT16); +MATCH_TYPE_AND_ENUM(quint16, DT_QUINT16); +MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); +MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); +MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); +MATCH_TYPE_AND_ENUM(float8_e5m2, DT_FLOAT8_E5M2); +MATCH_TYPE_AND_ENUM(float8_e4m3fn, DT_FLOAT8_E4M3FN); +MATCH_TYPE_AND_ENUM(int4, DT_INT4); +MATCH_TYPE_AND_ENUM(uint4, DT_UINT4); +MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); +MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT); + +template <> +struct DataTypeToEnum { + static DataType v() { return value; } + static DataType ref() { return MakeRefType(value); } + static constexpr DataType value = sizeof(long) == 4 ? DT_INT32 : DT_INT64; +}; +template <> +struct IsValidDataType { + static constexpr bool value = true; +}; +template <> +struct EnumToDataType { + typedef int64_t Type; +}; + +template <> +struct DataTypeToEnum { + static DataType v() { return value; } + static DataType ref() { return MakeRefType(value); } + static constexpr DataType value = + sizeof(unsigned long) == 4 ? DT_UINT32 : DT_UINT64; +}; +template <> +struct IsValidDataType { + static constexpr bool value = true; +}; +template <> +struct EnumToDataType { + typedef tensorflow::uint64 Type; +}; + +template <> +struct DataTypeToEnum { + static DataType v() { return DT_INT64; } + static DataType ref() { return MakeRefType(DT_INT64); } + static constexpr DataType value = DT_INT64; +}; +template <> +struct IsValidDataType { + static constexpr bool value = true; +}; + +template <> +struct DataTypeToEnum { + static DataType v() { return DT_UINT64; } + static DataType ref() { return MakeRefType(DT_UINT64); } + static constexpr DataType value = DT_UINT64; +}; +template <> +struct IsValidDataType { + static constexpr bool value = true; +}; + +#undef MATCH_TYPE_AND_ENUM + +// All types not specialized are marked invalid. +template +struct IsValidDataType { + static constexpr bool value = false; +}; + +// Extra validity checking; not part of public API. +static_assert(IsValidDataType::value, "Incorrect impl for int64"); +static_assert(IsValidDataType::value, "Incorrect impl for int32"); + +// TODO(jeff): Maybe unify this with Tensor::CanUseDMA, or the underlying +// is_simple in tensor.cc (and possible choose a more general name?) 
+constexpr DataTypeSet kDataTypesCanUseMemcpy = + ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | ToSet(DT_INT32) | ToSet(DT_UINT32) | + ToSet(DT_UINT8) | ToSet(DT_UINT16) | ToSet(DT_INT16) | ToSet(DT_INT8) | + ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128) | ToSet(DT_INT64) | + ToSet(DT_UINT64) | ToSet(DT_BOOL) | ToSet(DT_QINT8) | ToSet(DT_QUINT8) | + ToSet(DT_QINT16) | ToSet(DT_QUINT16) | ToSet(DT_QINT32) | + ToSet(DT_BFLOAT16) | ToSet(DT_HALF) | ToSet(DT_FLOAT8_E5M2) | + ToSet(DT_FLOAT8_E4M3FN) | ToSet(DT_INT4) | ToSet(DT_UINT4); +inline bool DataTypeCanUseMemcpy(DataType dt) { + return kDataTypesCanUseMemcpy.Contains(dt); +} + +// Returns true iff 'dt' is a real, non-quantized floating point type. +constexpr DataTypeSet kDataTypeIsFloating = + ToSet(DT_HALF) | ToSet(DT_BFLOAT16) | ToSet(DT_FLOAT) | ToSet(DT_DOUBLE) | + ToSet(DT_FLOAT8_E4M3FN) | ToSet(DT_FLOAT8_E5M2); +inline bool DataTypeIsFloating(DataType dt) { + return kDataTypeIsFloating.Contains(dt); +} + +// Returns true iff 'dt' is a numeric type. +inline bool DataTypeIsNumeric(DataType dt) { return kNumberTypes.Contains(dt); } + +// Returns true iff 'dt' is a complex type. +constexpr DataTypeSet kDataTypeIsComplex = + ToSet(DT_COMPLEX64) | ToSet(DT_COMPLEX128); +inline bool DataTypeIsComplex(DataType dt) { + return kDataTypeIsComplex.Contains(dt); +} + +inline bool DataTypeIsQuantized(DataType dt) { + return kQuantizedTypes.Contains(dt); +} + +// Is the dtype nonquantized integral? +constexpr DataTypeSet kDataTypeIsInteger = + ToSet(DT_INT4) | ToSet(DT_UINT4) | ToSet(DT_INT8) | ToSet(DT_UINT8) | + ToSet(DT_INT16) | ToSet(DT_UINT16) | ToSet(DT_INT32) | ToSet(DT_UINT32) | + ToSet(DT_INT64) | ToSet(DT_UINT64); +inline bool DataTypeIsInteger(DataType dt) { + return kDataTypeIsInteger.Contains(dt); +} + +// Is the dtype a signed integral type? +constexpr DataTypeSet kDataTypeIsSigned = ToSet(DT_INT4) | ToSet(DT_INT8) | + ToSet(DT_INT16) | ToSet(DT_INT32) | + ToSet(DT_INT64); +inline bool DataTypeIsSigned(DataType dt) { + return kDataTypeIsSigned.Contains(dt); +} + +// Is the dtype an unsigned integral type? +constexpr DataTypeSet kDataTypeIsUnsigned = ToSet(DT_UINT4) | ToSet(DT_UINT8) | + ToSet(DT_UINT16) | + ToSet(DT_UINT32) | ToSet(DT_UINT64); +inline bool DataTypeIsUnsigned(DataType dt) { + return kDataTypeIsUnsigned.Contains(dt); +} + +// Returns a 0 on failure +int DataTypeSize(DataType dt); + +// Returns HOST_MEMORY if `dtype` is always on host or is a DT_INT32, +// DEVICE_MEMORY otherwise. +MemoryType MTypeFromDType(const DataType dtype); + +// Returns HOST_MEMORY if `dtype` is always on host, DEVICE_MEMORY otherwise. +// The reason we have MTypeFromDType() and MTypeFromDTypeIntsOnDevice(): for +// GPUs, we would like to keep int operations on host for performance concerns. +// But for TPUs (and other devices), int operations are placed on device. +MemoryType MTypeFromDTypeIntsOnDevice(const DataType dtype); + +// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE. +// For DT_RESOURCE, the handle always sits on host (even if the underlying +// object has device-allocated resources). +bool DataTypeAlwaysOnHost(DataType dt); + +// FullType implementation. + +// Reference container for a type definition. These values are usually interned. +// These containers admit a notion of ordering for efficient access. The +// ordering has no semantic otherwise. +struct TypeRef { + std::shared_ptr full_type; + + bool operator==(const TypeRef& other) const { + // TODO(mdan): This should be more efficient. 
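Earlier in this hunk, `kDataTypesCanUseMemcpy` whitelists dtypes whose buffers can be copied byte-for-byte, while string, resource and variant payloads need real copy logic. The standard-library notion that whitelist approximates is trivial copyability; a hedged sketch of that check for ordinary value types:

```cpp
#include <cstring>
#include <string>
#include <type_traits>

// A complex64-like pair of floats, used as a stand-in for DT_COMPLEX64.
struct ComplexLike { float re, im; };

template <typename T>
constexpr bool CanUseMemcpySketch() {
  return std::is_trivially_copyable<T>::value;
}

static_assert(CanUseMemcpySketch<float>(), "float buffers can be memcpy'd");
static_assert(CanUseMemcpySketch<ComplexLike>(),
              "two packed floats are still trivially copyable");
static_assert(!CanUseMemcpySketch<std::string>(),
              "string-like payloads (DT_STRING, DT_VARIANT) need real copies");

int main() {
  float src[4] = {1, 2, 3, 4}, dst[4];
  std::memcpy(dst, src, sizeof(src));  // valid: trivially copyable elements
  return dst[3] == 4 ? 0 : 1;
}
```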
+ return full_type->SerializeAsString() == + other.full_type->SerializeAsString(); + } + bool operator<(const TypeRef& other) const { + return full_type->SerializeAsString() < + other.full_type->SerializeAsString(); + } +}; + +struct TypeHasher { + std::size_t operator()(const TypeRef& k) const { + return std::hash()(k.full_type->SerializeAsString()); + } +}; + +// Maps a legacy DType proto enum to an equivalent FullType ID, +// i.e. sets the type_id of t based on dtype. +void map_dtype_to_tensor(const DataType& dtype, FullTypeDef& t); + +// Set the type id_of t to TFT_TENSOR and add a child arg by mapping +// a legacy DType proto enun to an equivalent FullType ID, e.g. +// if dtype is DT_FLOAT, sets t to TFT_TENSOR[TFT_FLOAT]. +void map_dtype_to_child_of_tensor(const DataType& dtype, FullTypeDef& t); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/variant.h b/third_party/tflite-hdrs/tensorflow/core/framework/variant.h new file mode 100644 index 00000000..152e0538 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/variant.h @@ -0,0 +1,629 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_ +#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/strcat.h" + +namespace tensorflow { + +template +std::string TypeNameVariant(const T& value); + +template +std::string DebugStringVariant(const T& value); + +// Allows for specializations of Variant Decoding. `data` may be modified in +// the process of decoding to `value`. +template +bool DecodeVariant(VariantTensorData* data, T* value); + +template +bool DecodeVariant(std::string* buf, T* value); + +template +void EncodeVariant(const T& value, VariantTensorData* data); + +template +void EncodeVariant(const T& value, std::string* buf); + +// This is an implementation of a type-erased container that can store an +// object of any type. The implementation is very similar to std::any, but has +// restrictions on the types of objects that can be stored, and eschews some of +// the fancier constructors available for std::any. An object of +// tensorflow::Variant is intended to be used as the value that will be stored +// in a tensorflow::Tensor object when its type is DT_VARIANT. +// +// tensorflow::Variant can store an object of a class that satisfies the +// following constraints: +// +// * The class is CopyConstructible. +// * The class has a default constructor. 
+// * It's either a protocol buffer, a tensorflow::Tensor, or defines the +// following functions: +// +// string TypeName() const; +// void Encode(VariantTensorData* data) const; +// bool Decode(VariantTensorData data); +// +// Simple POD types can elide the Encode/Decode functions, they are provided by +// helper methods. +// Here are some typical usage patterns: +// +// Variant x = 10; +// EXPECT_EQ(*x.get(), 10); +// +// Tensor t(DT_FLOAT, TensorShape({})); +// t.flat()(0) = 42.0f; +// Variant x = t; +// EXPECT_EQ(x.get()->flat()(0), 42.0f); +// +// Accessing the stored object: +// +// The get function is the main mechanism to access the object +// stored in the container. It is type-safe, that is, calling +// get when the stored object's type is not T, returns a +// nullptr. A raw pointer to the stored object can be obtained by calling +// get(). +// +// Serializing/deserializing Variant object: +// +// The Variant class delegates serializing and deserializing operations to the +// contained object. Helper functions to do these operations are provided for +// POD data types, tensorflow::Tensor, and protocol buffer objects. However, +// other classes have to provide Encode/Decode functions to handle +// serialization. +// +// Objects stored in a Variant object often contain references to other +// tensorflow::Tensors of primitive types (Eg., a list of tensorflow::Tensors). +// To efficiently support those use cases, a structure is imposed on the +// serialization format. Namely, classes should serialize their contents into a +// VariantTensorData object: +// +// struct VariantTensorData { +// string type_name; +// string metadata; +// std::vector tensors; +// }; +// +// Objects with references to other Tensors can simply store those tensors in +// the `tensors` field, and serialize other metadata content in to the +// `metadata` field. +// +// Serialization example: +// +// Foo f = Foo {...}; +// Variant x = f; +// string serialized_f; +// x.Encode(&serialized_f); +// +// Variant y = Foo(); // default constructed Foo. +// y.Decode(std::move(serialized_f)); +// EXPECT_EQ(*x.get(), *y.get()); +// +// +// A Variant storing serialized Variant data (a value of type +// VariantTensorDataProto) has different behavior from a standard Variant. +// Namely, its TypeName matches the TypeName of the original Variant; +// and its non-const get method performs lazy deserialization. +// +// Decode and copy example: +// +// Foo f = Foo {...}; +// Variant x = f; +// +// VariantTensorData serialized_data_f; +// VariantTensorDataProto serialized_proto_f; +// x.Encode(&serialized_data_f); +// serialized_data_f.ToProto(&serialized_proto_f); +// +// Variant y_type_unknown = serialized_proto_f; // Store serialized Variant. +// +// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo. +// EXPECT_EQ(TypeIndex::Make(), +// y_type_unknown.TypeId()); +// +class Variant { + public: + // Constructs a Variant holding no value (aka `is_empty()`). + // + // This is done by pointing at nullptr via the heap value. + Variant() noexcept : heap_value_(/*pointer=*/nullptr), is_inline_(false) {} + + ~Variant(); + + Variant(const Variant& other); + Variant(Variant&& other) noexcept; + + // Make sure that the type is CopyConstructible and not a + // tensorflow::Variant object itself. We want the copy constructor to be + // chosen for the tensorflow::Variant case. 
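The comment block above spells out the contract for user types stored in a `Variant`: copy- and default-constructible, and either a Tensor, a protocol buffer, or a type providing `TypeName()`/`Encode()`/`Decode()`. A hedged sketch of one such conforming type; `MyCounter` and its field are hypothetical, and the snippet assumes these vendored headers are on the include path and that the translation unit links against a TensorFlow build:

```cpp
#include <cstdint>
#include <string>

#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"

struct MyCounter {
  int64_t count = 0;

  // The three members the Variant documentation above asks for.
  std::string TypeName() const { return "MyCounter"; }
  void Encode(tensorflow::VariantTensorData* data) const {
    data->set_metadata(count);
  }
  bool Decode(tensorflow::VariantTensorData data) {
    return data.get_metadata(&count);
  }
};

int main() {
  tensorflow::Variant v = MyCounter{42};

  // Type-safe access: get<T>() returns nullptr on a type mismatch.
  if (const MyCounter* c = v.get<MyCounter>()) {
    return c->count == 42 ? 0 : 1;
  }
  return 1;
}
```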
+ template ::type, + typename std::enable_if::value && + std::is_move_constructible::value, + void>::type* = nullptr> + Variant(T&& value); + + template ::type, + typename std::enable_if::value && + std::is_copy_constructible::value, + void>::type* = nullptr> + Variant(const T& value); + + template ::type, + typename std::enable_if::value && + std::is_copy_constructible::value, + void>::type* = nullptr> + Variant& operator=(const T& value); + + template ::type, + typename std::enable_if::value && + std::is_move_constructible::value, + void>::type* = nullptr> + Variant& operator=(T&& value); + + Variant& operator=(const Variant& rhs) { + if (&rhs == this) return *this; + Variant(rhs).swap(*this); + return *this; + } + + Variant& operator=(Variant&& rhs) noexcept { + if (&rhs == this) return *this; + Variant(std::move(rhs)).swap(*this); + return *this; + } + + // Constructs a value of type T with the given args in-place in this Variant. + // Returns a reference to the newly constructed value. + // The signature is based on std::variant::emplace() in C++17. + template + T& emplace(Args&&... args) { + ResetMemory(); + is_inline_ = CanInlineType(); + if (is_inline_) { + new (&inline_value_) + InlineValue(InlineValue::Tag{}, std::forward(args)...); + return static_cast*>(inline_value_.AsValueInterface()) + ->value; + } else { + new (&heap_value_) HeapValue( + absl::make_unique>(InPlace(), std::forward(args)...)); + return static_cast*>(heap_value_.get())->value; + } + } + + bool is_empty() const { return GetValue() == nullptr; } + + void clear() noexcept; + + void swap(Variant& other) noexcept; + + // Note, unlike TypeName(), TypeId() does not return the TypeIndex + // of the original type when a TensorValueDataProto is stored as the + // value. In this case, it returns the TypeIndex of TensorValueDataProto. + TypeIndex TypeId() const { + const TypeIndex VoidTypeIndex = TypeIndex::Make(); + if (is_empty()) { + return VoidTypeIndex; + } + return GetValue()->TypeId(); + } + + std::string DebugString() const { + return strings::StrCat("Variant"); + } + + std::string SummarizeValue() const { + return is_empty() ? "[empty]" : GetValue()->DebugString(); + } + + // Returns a pointer to the stored value if it is type T, or nullptr + // otherwise. + template + T* get() { + const TypeIndex TTypeIndex = TypeIndex::Make(); + if (is_empty() || (TTypeIndex != TypeId())) return nullptr; + return std::addressof(static_cast*>(GetValue())->value); + } + + // Returns a pointer to the stored value if it is type T, or nullptr + // otherwise. + template + const T* get() const { + const TypeIndex TTypeIndex = TypeIndex::Make(); + if (is_empty() || (TTypeIndex != TypeId())) return nullptr; + return std::addressof( + static_cast*>(GetValue())->value); + } + + // Returns TypeNameVariant(value). + // + // In the special case that a serialized Variant is stored (value + // is a VariantTensorDataProto), returns value.TypeName(), the + // TypeName field stored in the VariantTensorDataProto buffer. + std::string TypeName() const { + if (is_empty()) { + return ""; + } + return GetValue()->TypeName(); + } + + // Serialize the contents of the stored object into `data`. + void Encode(VariantTensorData* data) const { + if (!is_empty()) { + GetValue()->Encode(data); + } + } + + // Deserialize `data` and update the stored object. + bool Decode(VariantTensorData data); + + // Helper methods to directly serialize/deserialize from strings. 
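`emplace<T>()` above mirrors `std::variant::emplace`: it destroys the current contents, decides between inline and heap storage for `T`, constructs the value in place, and returns a reference to it. A hedged usage sketch against that documented signature (`Point` is a hypothetical aggregate payload, and the same build assumptions as the previous sketch apply):

```cpp
#include "tensorflow/core/framework/variant.h"

// Hypothetical payload: a plain aggregate, so the helpers for simple types
// cover Encode/Decode and no member functions are required.
struct Point {
  int x;
  int y;
};

int main() {
  tensorflow::Variant v;

  // emplace<T>() constructs the value in place and returns a reference to the
  // stored object, picking inline vs. heap storage based on T's size/alignment.
  Point& p = v.emplace<Point>(Point{3, 4});
  p.x = 5;  // mutates the value now owned by the Variant

  return (v.get<Point>() != nullptr && v.get<Point>()->x == 5) ? 0 : 1;
}
```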
+ void Encode(std::string* buf) const { + if (!is_empty()) { + GetValue()->Encode(buf); + } + } + bool Decode(std::string buf) { + if (!is_empty()) { + return GetValue()->Decode(std::move(buf)); + } + return true; + } + + template + static constexpr bool CanInlineType() { + return ((sizeof(Value) <= InlineValue::kMaxValueSize) && + (alignof(Value) <= kMaxInlineValueAlignSize)); + } + + private: + struct in_place_t {}; + static constexpr in_place_t InPlace() { return in_place_t{}; } + + struct ValueInterface { + virtual ~ValueInterface() = default; + virtual TypeIndex TypeId() const = 0; + virtual void* RawPtr() = 0; + virtual const void* RawPtr() const = 0; + virtual std::unique_ptr Clone() const = 0; + virtual void CloneInto(ValueInterface* memory) const = 0; + virtual void MoveAssign(ValueInterface* memory) = 0; + virtual void MoveInto(ValueInterface* memory) = 0; + virtual std::string TypeName() const = 0; + virtual std::string DebugString() const = 0; + virtual void Encode(VariantTensorData* data) const = 0; + virtual bool Decode(VariantTensorData data) = 0; + virtual void Encode(std::string* buf) const = 0; + virtual bool Decode(std::string data) = 0; + }; + + template + struct Value final : ValueInterface { + template + explicit Value(in_place_t /*tag*/, Args&&... args) + : value(std::forward(args)...) {} + + // NOTE(ebrevdo): Destructor must be explicitly defined for CUDA to happily + // build `alignof(Variant)`. + ~Value() final = default; + + TypeIndex TypeId() const final { + const TypeIndex value_type_index = + TypeIndex::Make::type>(); + return value_type_index; + } + + void* RawPtr() final { return &value; } + + const void* RawPtr() const final { return &value; } + + std::unique_ptr Clone() const final { + return absl::make_unique(InPlace(), value); + } + + void MoveAssign(ValueInterface* memory) final { + CHECK(TypeId() == memory->TypeId()) + << TypeId().name() << " vs. " << memory->TypeId().name(); + static_cast(memory)->value = std::move(value); + } + + void CloneInto(ValueInterface* memory) const final { + new (memory) Value(InPlace(), value); + } + + void MoveInto(ValueInterface* memory) final { + new (memory) Value(InPlace(), std::move(value)); + } + + std::string TypeName() const final { return TypeNameVariant(value); } + + std::string DebugString() const final { return DebugStringVariant(value); } + + void Encode(VariantTensorData* data) const final { + EncodeVariant(value, data); + } + + bool Decode(VariantTensorData data) final { + return DecodeVariant(&data, &value); + } + + void Encode(std::string* buf) const final { EncodeVariant(value, buf); } + + bool Decode(std::string buf) final { return DecodeVariant(&buf, &value); } + + T value; + }; + static constexpr int kMaxInlineValueAlignSize = alignof(Value); + + using HeapValue = std::unique_ptr; + + struct InlineValue { + // We try to size InlineValue so that sizeof(Variant) <= 64 and it can fit + // into the aligned space of a TensorBuffer. + static constexpr int kMaxValueSize = (64 - /*some extra padding=*/8); + + typedef char ValueDataArray[kMaxValueSize]; + alignas(kMaxInlineValueAlignSize) ValueDataArray value_data; + + // Tag is used for deducing the right type when constructing a Value in + // place. + template + struct Tag {}; + + template + explicit InlineValue(Tag /*tag*/, Args&&... 
args) noexcept { + Value* inline_value_data = reinterpret_cast*>(value_data); + new (inline_value_data) Value(InPlace(), std::forward(args)...); + } + + InlineValue(const InlineValue& other) noexcept { + other.AsValueInterface()->CloneInto(AsValueInterface()); + } + + InlineValue(InlineValue&& other) noexcept { + other.AsValueInterface()->MoveInto(AsValueInterface()); + } + + void ResetMemory() { AsValueInterface()->~ValueInterface(); } + + InlineValue& operator=(const InlineValue& other) { + if (&other == this) return *this; + ResetMemory(); + other.AsValueInterface()->CloneInto(AsValueInterface()); + return *this; + } + + InlineValue& operator=(InlineValue&& other) { + if (&other == this) return *this; + if (AsValueInterface()->TypeId() == other.AsValueInterface()->TypeId()) { + other.AsValueInterface()->MoveAssign(AsValueInterface()); + } else { + ResetMemory(); + other.AsValueInterface()->MoveInto(AsValueInterface()); + } + return *this; + } + + ValueInterface* AsValueInterface() { + return reinterpret_cast(value_data); + } + + const ValueInterface* AsValueInterface() const { + return reinterpret_cast(value_data); + } + + ~InlineValue() { ResetMemory(); } + }; + + union { + HeapValue heap_value_; + InlineValue inline_value_; + }; + // is_inline_ provides discrimination between which member of the prior union + // is currently within it's lifetime. To switch from one member to the other, + // the destructor must be called on the currently alive member before calling + // the constructor on the other member. In effect, a member is expected to be + // live at any given time and that member is tracked via this boolean. + bool is_inline_; + + bool IsInlineValue() const { return is_inline_; } + + // ResetMemory causes the destructor of the currently active member of the + // union to be run. This must be follwed with a placement new call on the + // member whose lifetime is to start. Additionally, is_inline_ needs to be set + // accordingly. ResetAndSetInline and ResetAndSetHeap are simple helper + // functions for performing the actions that are required to follow. + void ResetMemory() { + if (IsInlineValue()) { + inline_value_.~InlineValue(); + } else { + heap_value_.~HeapValue(); + } + } + + // ResetAndSetInline clears the current state and then constructs a new value + // inline with the provided arguments. + template + void ResetAndSetInline(Args&&... args) noexcept { + ResetMemory(); + new (&inline_value_) InlineValue(std::forward(args)...); + is_inline_ = true; + } + + // ResetAndSetHeap clears the current state then constructs a new value on the + // heap with the provided arguments. + template + void ResetAndSetHeap(Args&&... args) noexcept { + ResetMemory(); + new (&heap_value_) HeapValue(std::forward(args)...); + is_inline_ = false; + } + + ValueInterface* GetValue() { + if (IsInlineValue()) { + return inline_value_.AsValueInterface(); + } else { + return heap_value_.get(); + } + } + + const ValueInterface* GetValue() const { + if (IsInlineValue()) { + return inline_value_.AsValueInterface(); + } else { + return heap_value_.get(); + } + } + + // PRECONDITION: Called on construction or ResetMemory() has been called + // before this method. + template + void InsertValue(T&& value) { + if (IsInlineValue()) { + new (&inline_value_) + InlineValue(InlineValue::Tag{}, std::forward(value)); + } else { + new (&heap_value_) HeapValue( + absl::make_unique>(InPlace(), std::forward(value))); + } + } +}; + +// Make sure that a Variant object can reside in a 64-byte aligned Tensor +// buffer. 
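The union of `HeapValue` and `InlineValue` above, discriminated by `is_inline_`, is a small-buffer optimization: values whose `Value<T>` wrapper fits the 56-byte aligned buffer are stored in place, everything else goes behind a `unique_ptr`. A reduced standalone sketch of just the sizing decision (`alignof(std::max_align_t)` is an assumed stand-in for `kMaxInlineValueAlignSize`, and the real check measures `sizeof(Value<T>)`, which also carries a vtable pointer):

```cpp
#include <array>
#include <cstddef>
#include <iostream>

// 64 bytes total minus ~8 bytes of bookkeeping, as in the header above.
constexpr std::size_t kMaxValueSize = 64 - 8;
constexpr std::size_t kMaxAlign = alignof(std::max_align_t);  // assumed stand-in

template <typename T>
constexpr bool CanInline() {
  return sizeof(T) <= kMaxValueSize && alignof(T) <= kMaxAlign;
}

int main() {
  std::cout << std::boolalpha
            << "int inlines: " << CanInline<int>() << "\n"
            << "array<char,48> inlines: " << CanInline<std::array<char, 48>>()
            << "\n"
            << "array<char,128> inlines (no, goes to heap): "
            << CanInline<std::array<char, 128>>() << "\n";
  return 0;
}
```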
+static_assert(sizeof(Variant) <= 64, + "Expected internal representation to be 64 bytes."); + +inline Variant::Variant(const Variant& other) + : is_inline_(other.IsInlineValue()) { + if (IsInlineValue()) { + new (&inline_value_) InlineValue(other.inline_value_); + } else { + new (&heap_value_) + HeapValue(other.heap_value_ ? other.heap_value_->Clone() : nullptr); + } +} + +inline Variant::Variant(Variant&& other) noexcept + : is_inline_(other.IsInlineValue()) { + if (IsInlineValue()) { + new (&inline_value_) InlineValue(std::move(other.inline_value_)); + } else { + new (&heap_value_) HeapValue(std::move(other.heap_value_)); + } +} + +template ::value && + std::is_move_constructible::value, + void>::type*> +inline Variant::Variant(T&& value) : is_inline_(CanInlineType()) { + InsertValue(std::forward(value)); +} + +template ::value && + std::is_copy_constructible::value, + void>::type*> +inline Variant::Variant(const T& value) : is_inline_(CanInlineType()) { + InsertValue(value); +} + +template ::value && + std::is_move_constructible::value, + void>::type*> +inline Variant& Variant::operator=(T&& value) { + ResetMemory(); + is_inline_ = CanInlineType(); + InsertValue(std::forward(value)); + return *this; +} + +template ::value && + std::is_copy_constructible::value, + void>::type*> +inline Variant& Variant::operator=(const T& value) { + ResetMemory(); + is_inline_ = CanInlineType(); + InsertValue(value); + return *this; +} + +inline void Variant::clear() noexcept { + // We set the internal unique_ptr to nullptr so that we preserve the + // invariant that one of the two states must be set at all times. nullptr + // indicates that the variant is empty. + ResetAndSetHeap(/*pointer=*/nullptr); +} + +inline void Variant::swap(Variant& other) noexcept { + if (is_empty()) { + if (other.IsInlineValue()) { + ResetAndSetInline(std::move(other.inline_value_)); + } else { + ResetAndSetHeap(std::move(other.heap_value_)); + } + other.clear(); + } else if (other.is_empty()) { + if (IsInlineValue()) { + other.ResetAndSetInline(std::move(inline_value_)); + } else { + other.ResetAndSetHeap(std::move(heap_value_)); + } + clear(); + } else { // Both Variants have values. + if (other.IsInlineValue() && IsInlineValue()) { + std::swap(inline_value_, other.inline_value_); + } else if (!other.IsInlineValue() && !IsInlineValue()) { + std::swap(heap_value_, other.heap_value_); + } else if (other.IsInlineValue() && !IsInlineValue()) { + HeapValue v = std::move(heap_value_); + ResetAndSetInline(std::move(other.inline_value_)); + other.ResetAndSetHeap(std::move(v)); + } else { // !other.IsInlineValue() && IsInlineValue() + HeapValue v = std::move(other.heap_value_); + other.ResetAndSetInline(std::move(inline_value_)); + ResetAndSetHeap(std::move(v)); + } + } +} + +template <> +void* Variant::get(); + +template <> +const void* Variant::get() const; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/variant_encode_decode.h b/third_party/tflite-hdrs/tensorflow/core/framework/variant_encode_decode.h new file mode 100644 index 00000000..20ceeb93 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/variant_encode_decode.h @@ -0,0 +1,284 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ +#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/abi.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Type used for tag-dispatch of the Encode/Decode Variant implementations. This +// template can determine whether the first type parameter `T` is one of the +// following: +// +// * A POD type (TypeResolver) +// * A tensorflow::Tensor (TypeResolver) +// * A protocol buffer (TypeResolver) +// * None of the above (TypeResolver) +// +template ::type>::value, + bool = std::is_same::type, + ::tensorflow::Tensor>::value, + bool = std::is_base_of::type>::value> +struct TypeResolver {}; + +// Specialization for POD type +template +void EncodeVariantImpl(const T& value, TypeResolver, + VariantTensorData* data) { + data->set_metadata(value); +} + +// Specialization for tensorflow::Tensor +template +void EncodeVariantImpl(const T& value, + TypeResolver, + VariantTensorData* data) { + data->tensors_.clear(); + data->tensors_.push_back(value); +} + +// Specialization for protobuf +template +void EncodeVariantImpl(const T& value, + TypeResolver, + VariantTensorData* data) { + if (!value.SerializeToString(&data->metadata_)) { + data->metadata_.clear(); + LOG(ERROR) << "Failed to encode variant " << value.DebugString(); + } +} + +// Specialization for other types +template +void EncodeVariantImpl(const T& value, + TypeResolver, + VariantTensorData* data) { + value.Encode(data); +} + +// Specialization for POD type +template +bool DecodeVariantImpl(VariantTensorData data, + TypeResolver, + T* value) { + return data.get_metadata(value); +} + +// Specialization for tensorflow::Tensor +template +bool DecodeVariantImpl(VariantTensorData data, + TypeResolver, + T* value) { + *value = data.tensors(0); + return true; +} + +// Specialization for protobuf +template +bool DecodeVariantImpl(VariantTensorData data, + TypeResolver, + T* value) { + std::string metadata; + data.get_metadata(&metadata); + return value->ParseFromString(std::move(metadata)); +} + +// Specialization for other types +template +bool DecodeVariantImpl(VariantTensorData data, + TypeResolver, + T* value) { + return value->Decode(std::move(data)); +} + +template +struct has_type_name : std::false_type {}; + +template +struct has_type_name< + C, typename std::enable_if().TypeName()), string>::value>::type> + : std::true_type {}; + +template ::type>::value, + bool = std::is_same::type, + ::tensorflow::Tensor>::value, + bool = std::is_base_of::type>::value> +struct TypeNameResolver {}; + +template +std::string TypeNameVariantImpl(const T& value, + TypeNameResolver) { + return value.TypeName(); +} + +template +std::string TypeNameVariantImpl( + 
const T& value, + TypeNameResolver) { + return "tensorflow::Tensor"; +} + +template +std::string TypeNameVariantImpl( + const T& value, TypeNameResolver) { + return std::string(value.GetTypeName()); +} + +template +std::string TypeNameVariantImpl( + const T& value, + TypeNameResolver) { + return port::MaybeAbiDemangle(TypeIndex::Make().name()); +} + +template +std::string TypeNameVariant(const T& value) { + return TypeNameVariantImpl(value, TypeNameResolver()); +} + +template +struct has_debug_string : std::false_type {}; + +template +struct has_debug_string< + C, typename std::enable_if().DebugString()), string>::value>::type> + : std::true_type {}; + +template +struct can_strcat : std::false_type {}; + +template +struct can_strcat< + C, typename std::enable_if())), string>::value>::type> + : std::true_type {}; + +template ::type>::value, + bool = can_strcat::type>::value> +struct DebugStringResolver {}; + +// TODO(ebrevdo): Expand DebugStringResolver to return TypeString if +// there is no StrCat() constructor. +template +std::string DebugStringVariantImpl( + const T& value, DebugStringResolver) { + return value.DebugString(); +} + +template +std::string DebugStringVariantImpl( + const T& value, DebugStringResolver) { + return strings::StrCat(value); +} + +template +std::string DebugStringVariantImpl( + const T& value, DebugStringResolver) { + return "?"; +} + +template +std::string DebugStringVariant(const T& value) { + return DebugStringVariantImpl(value, DebugStringResolver()); +} + +template +void EncodeVariant(const T& value, VariantTensorData* data) { + EncodeVariantImpl(value, TypeResolver(), data); + data->set_type_name(TypeNameVariant(value)); +} + +template +bool DecodeVariant(VariantTensorData* data, T* value) { + return DecodeVariantImpl(std::move(*data), TypeResolver(), value); +} + +template +void EncodeVariant(const T& value, std::string* buf) { + VariantTensorData data; + EncodeVariantImpl(value, TypeResolver(), &data); + data.set_type_name(TypeNameVariant(value)); + DCHECK(buf != nullptr); + data.SerializeToString(buf); +} + +template +bool DecodeVariant(std::string* buf, T* value) { + VariantTensorData data; + if (!data.ParseFromString(*buf)) return false; + if (!DecodeVariantImpl(std::move(data), TypeResolver(), value)) { + return false; + } + return true; +} + +// Specializations for VariantTensorDataProto +template <> +std::string TypeNameVariant(const VariantTensorDataProto& value); + +template <> +void EncodeVariant(const VariantTensorDataProto& value, + VariantTensorData* data); + +template <> +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); + +template <> +void EncodeVariant(const VariantTensorDataProto& value, std::string* buf); + +template <> +bool DecodeVariant(std::string* buf, VariantTensorDataProto* value); + +// Encodes an array of Variant objects in to the given StringListEncoder. +// `variant_array` is assumed to point to an array of `n` Variant objects. +void EncodeVariantList(const Variant* variant_array, int64_t n, + std::unique_ptr e); + +// Decodes an array of Variant objects from the given StringListDecoder. +// `variant_array` is assumed to point to an array of `n` Variant objects. 
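`EncodeVariantImpl`/`DecodeVariantImpl` above select an overload at compile time by folding three predicates (POD, `tensorflow::Tensor`, protobuf subclass) into the bool parameters of an empty `TypeResolver` tag. A standalone miniature of that tag-dispatch technique, with two made-up categories standing in for the real predicates:

```cpp
#include <iostream>
#include <string>
#include <type_traits>

// Empty tag whose bool parameters encode which category T falls into,
// mirroring TypeResolver<T, is_pod, is_tensor, is_proto> above.
template <typename T,
          bool is_trivial = std::is_trivially_copyable<T>::value,
          bool is_string =
              std::is_same<typename std::decay<T>::type, std::string>::value>
struct ResolverSketch {};

// Chosen for trivially copyable payloads (the "POD" branch).
template <typename T>
std::string EncodeSketch(const T&, ResolverSketch<T, true, false>) {
  return "encoded as raw metadata bytes";
}

// Chosen for strings (stand-in for the Tensor/proto branches).
template <typename T>
std::string EncodeSketch(const T& value, ResolverSketch<T, false, true>) {
  return "encoded by copying the string: " + value;
}

// Fallback: delegate to a member Encode(), like the "other types" branch.
template <typename T>
std::string EncodeSketch(const T& value, ResolverSketch<T, false, false>) {
  return value.Encode();
}

struct CustomPayload {
  std::string name;  // non-trivial member, so the fallback branch is chosen
  std::string Encode() const { return "encoded by the type itself"; }
};

int main() {
  std::cout << EncodeSketch(42, ResolverSketch<int>{}) << "\n";
  std::cout << EncodeSketch(std::string("hi"), ResolverSketch<std::string>{})
            << "\n";
  std::cout << EncodeSketch(CustomPayload{}, ResolverSketch<CustomPayload>{})
            << "\n";
  return 0;
}
```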
+bool DecodeVariantList(std::unique_ptr d, + Variant* variant_array, int64_t n); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/variant_op_registry.h b/third_party/tflite-hdrs/tensorflow/core/framework/variant_op_registry.h new file mode 100644 index 00000000..c7d8680d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/variant_op_registry.h @@ -0,0 +1,596 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_ +#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_ + +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/abi.h" + +namespace tensorflow { + +class OpKernelContext; +// A global UnaryVariantOpRegistry is used to hold callback functions +// for different variant types. To be used by ShapeOp, RankOp, and +// SizeOp, decoding, etc. + +enum VariantUnaryOp { + INVALID_VARIANT_UNARY_OP = 0, + ZEROS_LIKE_VARIANT_UNARY_OP = 1, + CONJ_VARIANT_UNARY_OP = 2, +}; + +const char* VariantUnaryOpToString(VariantUnaryOp op); + +enum VariantBinaryOp { + INVALID_VARIANT_BINARY_OP = 0, + ADD_VARIANT_BINARY_OP = 1, +}; + +const char* VariantBinaryOpToString(VariantBinaryOp op); + +enum VariantDeviceCopyDirection { + INVALID_DEVICE_COPY_DIRECTION = 0, + HOST_TO_DEVICE = 1, + DEVICE_TO_HOST = 2, + DEVICE_TO_DEVICE = 3, +}; + +class UnaryVariantOpRegistry; +extern UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal(); + +class UnaryVariantOpRegistry { + public: + typedef std::function VariantDecodeFn; + typedef std::function + VariantUnaryOpFn; + typedef std::function + VariantBinaryOpFn; + + // An AsyncTensorDeviceCopyFn is a function provided to + // the user-provided DeviceCopyFn callback as the third argument ("copier"). + // + // Expected inputs: + // from: A Tensor on the host (if performing cpu->gpu copy), or + // device (if performing gpu->cpu or gpu->gpu copy). + // to: An empty/uninitialized tensor. It will be updated upon + // successful return of the function with the correct dtype and shape. + // However, the copied data will not be available until the compute + // stream has been synchronized. + // + // Returns: + // The status upon memory allocation / initialization of the + // "to" tensor, and enqueue of the copy onto the compute stream. + // Any failure of the copy itself will update the underlying + // stream status and propagate through the runtime independent + // of the caller. 
+ typedef std::function + AsyncTensorDeviceCopyFn; + + // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn' + // expected to be passed to the registration macro + // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION. + typedef std::function + AsyncVariantDeviceCopyFn; + + // Add a decode function to the registry. + void RegisterDecodeFn(const std::string& type_name, + const VariantDecodeFn& decode_fn); + + // Returns nullptr if no decode function was found for the given TypeName. + VariantDecodeFn* GetDecodeFn(absl::string_view type_name); + + // Add a copy-to-GPU function to the registry. + void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction, + const TypeIndex& type_index, + const AsyncVariantDeviceCopyFn& device_copy_fn) { + AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index); + CHECK_EQ(existing, nullptr) + << "UnaryVariantDeviceCopy for direction: " << direction + << " and type_index: " << port::MaybeAbiDemangle(type_index.name()) + << " already registered"; + device_copy_fns.insert( + std::pair, + AsyncVariantDeviceCopyFn>( + std::make_pair(direction, type_index), device_copy_fn)); + } + + // Returns nullptr if no copy function was found for the given + // TypeName and direction. + AsyncVariantDeviceCopyFn* GetDeviceCopyFn( + const VariantDeviceCopyDirection direction, const TypeIndex& type_index) { + auto found = device_copy_fns.find(std::make_pair(direction, type_index)); + if (found == device_copy_fns.end()) return nullptr; + return &found->second; + } + + // Add a unary op function to the registry. + void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device, + const TypeIndex& type_index, + const VariantUnaryOpFn& unary_op_fn) { + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index); + CHECK_EQ(existing, nullptr) + << "Unary VariantUnaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) + << " already registered for device type: " << device; + unary_op_fns.insert(std::pair, VariantUnaryOpFn>( + {op, GetPersistentStringPiece(device), type_index}, unary_op_fn)); + } + + // Returns nullptr if no unary op function was found for the given + // op, device, and TypeName. + VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, absl::string_view device, + const TypeIndex& type_index) { + auto found = unary_op_fns.find({op, device, type_index}); + if (found == unary_op_fns.end()) return nullptr; + return &found->second; + } + + // Add a binary op function to the registry. + void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device, + const TypeIndex& type_index, + const VariantBinaryOpFn& add_fn) { + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index); + CHECK_EQ(existing, nullptr) + << "Unary VariantBinaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) + << " already registered for device type: " << device; + binary_op_fns.insert( + std::pair, VariantBinaryOpFn>( + {op, GetPersistentStringPiece(device), type_index}, add_fn)); + } + + // Returns nullptr if no binary op function was found for the given + // op, device and TypeName. 
+ VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, absl::string_view device, + const TypeIndex& type_index) { + auto found = binary_op_fns.find({op, device, type_index}); + if (found == binary_op_fns.end()) return nullptr; + return &found->second; + } + + // Get a pointer to a global UnaryVariantOpRegistry object + static UnaryVariantOpRegistry* Global() { + return UnaryVariantOpRegistryGlobal(); + } + + // Get a pointer to a global persistent string storage object. + // ISO/IEC C++ working draft N4296 clarifies that insertion into an + // std::unordered_set does not invalidate memory locations of + // *values* inside the set (though it may invalidate existing + // iterators). In other words, one may safely point a StringPiece to + // a value in the set without that StringPiece being invalidated by + // future insertions. + static std::unordered_set* PersistentStringStorage(); + + private: + struct TypeIndexHash { + std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); } + }; + + gtl::FlatMap + decode_fns; + + // Map std::pair to function. + struct PairHash { + template + std::size_t operator()(const std::pair& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast(std::get<0>(x)); + ret = Hash64Combine(ret, std::get<1>(x).hash_code()); + return ret; + } + }; + + gtl::FlatMap, + AsyncVariantDeviceCopyFn, PairHash> + device_copy_fns; + + // Map std::tuple to function. + + // this breaks by falling victim to "too perfect forwarding" + // see https://stackoverflow.com/questions/44475317/variadic-template-issue + // and references therein + template + struct FuncTuple { + FuncTuple(const Op& op, const absl::string_view& dev, + const TypeIndex& type_index) + : op_type_(op), device_(dev), type_index_(type_index) {} + Op op_type_; + absl::string_view device_; + TypeIndex type_index_; + }; + // friend declaration for operator== + // needed for clang + template + friend bool operator==(const FuncTuple& l, const FuncTuple& r); + struct TupleHash { + template + std::size_t operator()( + const std::tuple& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast(std::get<0>(x)); + ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); + ret = Hash64Combine(ret, std::get<2>(x).hash_code()); + return ret; + } + + template + std::size_t operator()(const FuncTuple& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast(x.op_type_); + ret = Hash64Combine(ret, sp_hasher_(x.device_)); + ret = Hash64Combine(ret, x.type_index_.hash_code()); + return ret; + } + StringPieceHasher sp_hasher_; + }; + gtl::FlatMap, VariantUnaryOpFn, TupleHash> + unary_op_fns; + gtl::FlatMap, VariantBinaryOpFn, TupleHash> + binary_op_fns; + + // Find or insert a string into a persistent string storage + // container; return the StringPiece pointing to the permanent string + // location. 
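The `PersistentStringStorage()` comment above notes that inserting into a `std::unordered_set` never invalidates references to existing elements, which is what lets the interning helper that follows hand out stable `string_view`s into stored device names. A standalone sketch of that interning trick (`std::string_view` plays the role of `absl::string_view`):

```cpp
#include <cassert>
#include <string>
#include <string_view>  // assumes C++17
#include <unordered_set>

// Node-based containers never move their elements, so views into the stored
// strings stay valid even as more names are inserted later.
std::string_view Intern(std::unordered_set<std::string>& storage,
                        const std::string& s) {
  auto inserted = storage.insert(s);          // no-op if already present
  return std::string_view(*inserted.first);   // points at the stored copy
}

int main() {
  std::unordered_set<std::string> storage;
  std::string_view cpu = Intern(storage, "CPU");

  // Force rehashing / further insertions; `cpu` must remain valid.
  for (int i = 0; i < 1000; ++i) Intern(storage, "DEVICE_" + std::to_string(i));

  assert(cpu == "CPU");
  assert(Intern(storage, "CPU").data() == cpu.data());  // same interned buffer
  return 0;
}
```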
+ static absl::string_view GetPersistentStringPiece(const std::string& str) { + const auto string_storage = PersistentStringStorage(); + auto found = string_storage->find(str); + if (found == string_storage->end()) { + auto inserted = string_storage->insert(str); + return absl::string_view(*inserted.first); + } else { + return absl::string_view(*found); + } + } +}; +template +inline bool operator==(const UnaryVariantOpRegistry::FuncTuple& lhs, + const UnaryVariantOpRegistry::FuncTuple& rhs) { + return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) && + (lhs.type_index_ == rhs.type_index_); +} + +// Decodes the Variant whose data_type has a registered decode +// function. Returns an Internal error if the Variant does not have a +// registered decode function, or if the decoding function fails. +// +// REQUIRES: +// variant is not null. +// +bool DecodeUnaryVariant(Variant* variant); + +// Copies a variant between CPU<->GPU, or between GPU<->GPU. +// The variant 'from' must have a registered DeviceCopyFn for the +// given direction. The returned variant 'to' will have +// (some subset of its) tensors stored on destination according to the +// registered DeviceCopyFn function for the given direction. Returns +// an Internal error if the Variant does not have a registered +// DeviceCopyFn function for the given direction, or if initiating the +// copy fails. +// +// REQUIRES: +// 'to' is not null. +// +absl::Status VariantDeviceCopy( + const VariantDeviceCopyDirection direction, const Variant& from, + Variant* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn); + +// Sets *v_out = unary_op(v). The variant v must have a registered +// UnaryOp function for the given Device. Returns an Internal error +// if v does not have a registered unary_op function for this device, or if +// UnaryOp fails. +// +// REQUIRES: +// v_out is not null. +// +template +absl::Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, + const Variant& v, Variant* v_out) { + const std::string& device = DeviceName::value; + UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = + UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId()); + if (unary_op_fn == nullptr) { + return errors::Internal("No unary variant unary_op function found for op ", + VariantUnaryOpToString(op), + " Variant type_name: ", v.TypeName(), + " for device type: ", device); + } + return (*unary_op_fn)(ctx, v, v_out); +} + +// Sets *out = binary_op(a, b). The variants a and b must be the same type +// and have a registered binary_op function for the given Device. Returns an +// Internal error if a and b are not the same type_name or if +// if a does not have a registered op function for this device, or if +// BinaryOp fails. +// +// REQUIRES: +// out is not null. +// +template +absl::Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, + const Variant& a, const Variant& b, + Variant* out) { + if (a.TypeId() != b.TypeId()) { + return errors::Internal( + "BinaryOpVariants: Variants a and b have different " + "type ids. Type names: '", + a.TypeName(), "' vs. 
'", b.TypeName(), "'"); + } + const std::string& device = DeviceName::value; + UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = + UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId()); + if (binary_op_fn == nullptr) { + return errors::Internal("No unary variant binary_op function found for op ", + VariantBinaryOpToString(op), + " Variant type_name: '", a.TypeName(), + "' for device type: ", device); + } + return (*binary_op_fn)(ctx, a, b, out); +} + +namespace variant_op_registry_fn_registration { + +template +class UnaryVariantDecodeRegistration { + public: + UnaryVariantDecodeRegistration(const std::string& type_name) { + // The Variant is passed by pointer because it should be + // mutable: get below may Decode the variant, which + // is a self-mutating behavior. The variant is not modified in + // any other way. + UnaryVariantOpRegistry::Global()->RegisterDecodeFn( + type_name, [type_name](Variant* v) -> bool { + DCHECK_NE(v, nullptr); + VariantTensorDataProto* t = v->get(); + if (t == nullptr) { + return false; + } + Variant decoded = T(); + VariantTensorData data(std::move(*t)); + if (!decoded.Decode(std::move(data))) { + return false; + } + std::swap(decoded, *v); + return true; + }); + } +}; + +template +class UnaryVariantDeviceCopyRegistration { + public: + typedef std::function + LocalVariantDeviceCopyFn; + UnaryVariantDeviceCopyRegistration( + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, + const LocalVariantDeviceCopyFn& device_copy_fn) { + const std::string type_index_name = + port::MaybeAbiDemangle(type_index.name()); + UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn( + direction, type_index, + [type_index_name, device_copy_fn]( + const Variant& from, Variant* to, + UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn + device_copy_tensor_fn) -> absl::Status { + DCHECK_NE(to, nullptr); + *to = T(); + if (from.get() == nullptr) { + return errors::Internal( + "VariantCopyToGPUFn: Could not access object, type_index: ", + type_index_name); + } + const T& t = *from.get(); + T* t_out = to->get(); + return device_copy_fn(t, t_out, device_copy_tensor_fn); + }); + } +}; + +template +class UnaryVariantUnaryOpRegistration { + typedef std::function + LocalVariantUnaryOpFn; + + public: + UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device, + const TypeIndex& type_index, + const LocalVariantUnaryOpFn& unary_op_fn) { + const std::string type_index_name = + port::MaybeAbiDemangle(type_index.name()); + UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn( + op, device, type_index, + [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, + Variant* v_out) -> absl::Status { + DCHECK_NE(v_out, nullptr); + *v_out = T(); + if (v.get() == nullptr) { + return errors::Internal( + "VariantUnaryOpFn: Could not access object, type_index: ", + type_index_name); + } + const T& t = *v.get(); + T* t_out = v_out->get(); + return unary_op_fn(ctx, t, t_out); + }); + } +}; + +template +class UnaryVariantBinaryOpRegistration { + typedef std::function + LocalVariantBinaryOpFn; + + public: + UnaryVariantBinaryOpRegistration(VariantBinaryOp op, + const std::string& device, + const TypeIndex& type_index, + const LocalVariantBinaryOpFn& binary_op_fn) { + const std::string type_index_name = + port::MaybeAbiDemangle(type_index.name()); + UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn( + op, device, type_index, + [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, + const Variant& b, + Variant* out) 
-> absl::Status { + DCHECK_NE(out, nullptr); + *out = T(); + if (a.get() == nullptr) { + return errors::Internal( + "VariantBinaryOpFn: Could not access object 'a', type_index: ", + type_index_name); + } + if (b.get() == nullptr) { + return errors::Internal( + "VariantBinaryOpFn: Could not access object 'b', type_index: ", + type_index_name); + } + const T& t_a = *a.get(); + const T& t_b = *b.get(); + T* t_out = out->get(); + return binary_op_fn(ctx, t_a, t_b, t_out); + }); + } +}; + +}; // namespace variant_op_registry_fn_registration + +// Register a unary decode variant function for the given type. +#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \ + REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name) + +#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \ + REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) + +#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) \ + static ::tensorflow::variant_op_registry_fn_registration:: \ + UnaryVariantDecodeRegistration \ + register_unary_variant_op_decoder_fn_##ctr(type_name) + +// ****** NOTE ****** +// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. +// ****** NOTE ****** +// +// Register a device copy variant function for the given copy +// direction and type; where direction is the enum +// VariantDeviceCopyDirection, and the device_copy_fn has signature: +// +// Status device_copy_fn( +// const T& t, T* t_out, +// const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier); +// +// And device_copy_fn calls copier 0 or more times. For details on +// the behavior of the copier function, see the comments at the +// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn. +// +// Note, the device_copy_fn may choose to keep some tensors +// on host, e.g. by assigning to->tensor = from.tensor (assuming +// from.tensor is already on host); or by setting +// to->tensor = Tensor(cpu_allocator(), ...) +// and manually updating its values. +// +// If this is the case, the CopyFns for HOST_TO_DEVICE, +// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host +// copies in a consistent manner. For example, one must always +// manually copy any "always on host" tensors in all directions instead of e.g. +// - performing a host-to-host copy in one direction, +// - using the provided copier function in the reverse direction. +// Doing the latter will cause program failures. +// +// ****** NOTE ****** +// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. 
+// ****** NOTE ****** +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \ + device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, direction, TypeIndex::Make(), device_copy_fn) + +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + ctr, T, direction, type_index, device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_index, device_copy_fn) + +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_index, device_copy_fn) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantDeviceCopyRegistration \ + register_unary_variant_op_device_copy_fn_##ctr( \ + direction, type_index, device_copy_fn) + +// Register a unary unary_op variant function with the signature: +// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out); +// to Variants having TypeIndex type_index, for device string device, +// for UnaryVariantOp enum op. +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \ + unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, TypeIndex::Make(), unary_op_function) + +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_index, unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \ + type_index, unary_op_function) + +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_index, unary_op_function) \ + static ::tensorflow::variant_op_registry_fn_registration:: \ + UnaryVariantUnaryOpRegistration \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ + unary_op_function) + +// Register a binary_op variant function with the signature: +// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out); +// to Variants having TypeIndex type_index, for device string device, +// for BinaryVariantOp enum OP. +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, TypeIndex::Make(), binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_index, binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_index, binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_index, binary_op_function) \ + static ::tensorflow::variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ + binary_op_function) + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/variant_tensor_data.h b/third_party/tflite-hdrs/tensorflow/core/framework/variant_tensor_data.h new file mode 100644 index 00000000..bfe5899d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/variant_tensor_data.h @@ -0,0 +1,144 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
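[Editorial aside, not part of the patch] The decode-registration macro above is easiest to see with a concrete payload. The sketch below is illustrative only: `example::MyCounter` is a made-up type that follows the Variant contract (TypeName / Encode / Decode), so the lambda installed by REGISTER_UNARY_VARIANT_DECODE_FUNCTION can round-trip it via DecodeUnaryVariant().

// Illustrative sketch only -- not part of the vendored header.
#include <cstdint>
#include <string>

#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"

namespace example {

struct MyCounter {
  int64_t value = 0;

  std::string TypeName() const { return "example::MyCounter"; }

  void Encode(tensorflow::VariantTensorData* data) const {
    data->set_type_name(TypeName());
    data->set_metadata(value);  // POD path: raw bytes in the metadata string
  }

  bool Decode(tensorflow::VariantTensorData data) {
    return data.get_metadata(&value);
  }
};

}  // namespace example

// Keyed by the same string MyCounter reports from TypeName(), so
// DecodeUnaryVariant() can find the decode function after deserialization.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(example::MyCounter, "example::MyCounter");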
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_ +#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_ + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class VariantTensorDataProto; + +// The serialization format for Variant objects. Objects with references to +// other Tensors can simply store those tensors in the `tensors` field, and +// serialize other metadata content in to the `metadata` field. Objects can +// optionally set the `type_name` for type-checking before deserializing an +// object. +// +// This is the native C++ class equivalent of VariantTensorDataProto. They are +// separate so that kernels do not need to depend on protos. +class VariantTensorData { + public: + VariantTensorData() = default; + + // TODO(b/118823936): This silently returns if the proto is invalid. + // Consider calling FromProto explicitly instead. + VariantTensorData(VariantTensorDataProto proto); + + // Name of the type of objects being serialized. + const std::string& type_name() const { return type_name_; } + void set_type_name(const std::string& type_name) { type_name_ = type_name; } + + template ::type>::value> + struct PODResolver {}; + + // Portions of the object that are not Tensors. + // Directly supported types include string POD types. + template + void set_metadata(const T& value) { + SetMetadata(value, PODResolver()); + } + + template + bool get_metadata(T* value) const { + return GetMetadata(value, PODResolver()); + } + + std::string& metadata_string() { return metadata_; } + + const std::string& metadata_string() const { return metadata_; } + + // Tensors contained within objects being serialized. + int tensors_size() const; + const Tensor& tensors(int index) const; + const std::vector& tensors() const; + Tensor* add_tensors(); + + // A more general version of add_tensors. Parameters are perfectly forwarded + // to the constructor of the tensor added here. + template + Tensor* add_tensor(TensorConstructorArgs&&... args); + + // Conversion to and from VariantTensorDataProto + void ToProto(VariantTensorDataProto* proto) const; + // This allows optimizations via std::move. + bool FromProto(VariantTensorDataProto proto); + bool FromConstProto(const VariantTensorDataProto& proto); + + // Serialization via VariantTensorDataProto + std::string SerializeAsString() const; + bool SerializeToString(std::string* buf); + bool ParseFromString(std::string s); + + std::string DebugString() const; + + public: + std::string type_name_; + std::string metadata_; + std::vector tensors_; + + private: + void SetMetadata(const std::string& value, + PODResolver) { + metadata_ = value; + } + + bool GetMetadata(std::string* value, + PODResolver) const { + *value = metadata_; + return true; + } + + // Specialize for bool, it is undefined behvaior to assign a non 0/1 value to + // a bool. Now we coerce a non-zero value to true. 
+ bool GetMetadata(bool* value, PODResolver) const { + if (metadata_.size() != sizeof(bool)) return false; + *value = false; + for (size_t i = 0; i < sizeof(bool); ++i) + *value = *value || (metadata_.data()[i] != 0); + return true; + } + + template + void SetMetadata(const T& value, PODResolver) { + metadata_.assign(reinterpret_cast(&value), sizeof(T)); + } + + template + bool GetMetadata(T* value, PODResolver) const { + if (metadata_.size() != sizeof(T)) return false; + std::copy_n(metadata_.data(), sizeof(T), reinterpret_cast(value)); + return true; + } +}; + +// For backwards compatibility for when this was a proto +std::string ProtoDebugString(const VariantTensorData& object); + +template +Tensor* VariantTensorData::add_tensor(TensorConstructorArgs&&... args) { + tensors_.emplace_back(std::forward(args)...); + return &tensors_.back(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/framework/versions.h b/third_party/tflite-hdrs/tensorflow/core/framework/versions.h new file mode 100644 index 00000000..a63ff703 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/framework/versions.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_VERSIONS_H_ +#define TENSORFLOW_CORE_FRAMEWORK_VERSIONS_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class VersionDef; + +// Check whether data with the given versions is compatible with the given +// consumer and min producer. upper_name and lower_name are used to form +// error messages upon failure. Example usage: +// +// #include "tensorflow/core/public/version.h" +// +// TF_RETURN_IF_ERROR(CheckVersions(versions, TF_GRAPH_DEF_VERSION, +// TF_GRAPH_DEF_VERSION_MIN_PRODUCER, +// "GraphDef", "graph")); +absl::Status CheckVersions(const VersionDef& versions, int consumer, + int min_producer, const char* upper_name, + const char* lower_name); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_VERSIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/function/runtime_client/runtime_client.h b/third_party/tflite-hdrs/tensorflow/core/function/runtime_client/runtime_client.h new file mode 100644 index 00000000..789788fb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/function/runtime_client/runtime_client.h @@ -0,0 +1,100 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
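[Editorial aside, not part of the patch] A quick sketch of the POD metadata path declared above; the type name string and helper name are made up for illustration.

// Illustrative sketch only -- not part of the vendored header.
#include <cstdint>
#include <string>

#include "tensorflow/core/framework/variant_tensor_data.h"

std::string RoundTripPodMetadata() {
  tensorflow::VariantTensorData data;
  data.set_type_name("example::Payload");  // hypothetical type name
  data.set_metadata(int64_t{42});          // POD path: copied as sizeof(int64_t) raw bytes
  int64_t restored = 0;
  if (!data.get_metadata(&restored)) {     // false when the stored byte count mismatches
    return std::string();
  }
  return data.SerializeAsString();         // serializes via VariantTensorDataProto
}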
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_RUNTIME_CLIENT_H_ +#define TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_RUNTIME_CLIENT_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace core { +namespace function { + +// TODO(mdan): Get rid of this once pybind can depend on MLIR headers. +// This empty struct serves to hide a pointer to an actual MLIR TFG dialect +// FuncOp object. +struct OpaqueTfgGraphFuncOp; + +// TODO(xjun): Get rid of this once pybind can depend on MLIR headers. +// This empty struct serves to hide a pointer to an actual MLIR TF dialect +// FuncOp object. +struct OpaqueTfFuncOp; + +// This is the current global context managed by the Python API. For historical +// reasons, the Python runtime controls this context and all other clients must +// use it. See tensorflow/python/eager/pywrap_tfe.h and +// tensorflow/python/eager/context.py. +// +// This must always be called after the Python eager context was initialized. +// +// If the Python runtime isn't involved, or when writing code that exclusively +// relies on functions defined in this namespace, users are encouraged to +// maintain their own EagerContext or use GlobalEagerContext. +EagerContext& GlobalPythonEagerContext(); + +// This global context is available for testing and to be shared among various +// APIs. +EagerContext& GlobalEagerContext(); + +using ReturnValues = std::vector; + +// A public API for manipulating and executing functions in a TensorFlow +// runtime. +class Runtime { + public: + explicit Runtime(EagerContext& eager_ctx) : eager_ctx_(eager_ctx) {} + + enum class Dialect { + TFG, + TF, + }; + + absl::StatusOr GetFunctionProto(absl::string_view name); + + // TODO(mdan): Enforce creation or rename to SetFunction. + absl::Status CreateFunction(const FunctionDef& fdef); + // TODO(mdan): Change to mlir::tfg::GraphFuncOp once pybind can depend on it. + absl::Status CreateFunction(OpaqueTfgGraphFuncOp* fop); + // TODO(xjun): Change to mlir::func::FuncOp once pybind can depend on it. + absl::Status CreateFunction(OpaqueTfFuncOp* fop); + // Applies a MLIR pipeline to an existing function. + // The pipeline may rename the function. If it does so, the old function + // remains unchanged. If the new name specifies an existing function, it will + // be overwritten. 
+ absl::Status TransformFunction(absl::string_view name, + absl::string_view pipeline_name, + Dialect dialect = Dialect::TFG); + + absl::StatusOr CallFunction( + absl::string_view name, absl::Span args); + + private: + EagerContext& eager_ctx_; +}; + +} // namespace function +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FUNCTION_RUNTIME_CLIENT_RUNTIME_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/function/testing/test_pass.h b/third_party/tflite-hdrs/tensorflow/core/function/testing/test_pass.h new file mode 100644 index 00000000..93c2116f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/function/testing/test_pass.h @@ -0,0 +1,133 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FUNCTION_TESTING_TEST_PASS_H_ +#define TENSORFLOW_CORE_FUNCTION_TESTING_TEST_PASS_H_ + +#include + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/ir/dialect.h" +#include "tensorflow/core/ir/ops.h" +#include "tensorflow/core/ir/tf_op_wrapper.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace core { +namespace function { +namespace testing { + +// A simple testing pass for BinaryFunction that replaces an AddV2 node named +// `x_plus_y` with a Mul one. +struct TestPassTfgDialect + : public mlir::PassWrapper> { + TestPassTfgDialect() = default; + + llvm::StringRef getArgument() const final { return "test-pass"; } + + void runOnOperation() override { + auto module = getOperation(); + mlir::OpBuilder builder(module); + mlir::tfg::TFGraphDialect* dialect = + builder.getContext()->getOrLoadDialect(); + + mlir::Operation* target = nullptr; + module->walk([&target](mlir::tfg::TFOp op) { + if (op.nameAttr() == nullptr) { + return; + } + if (op.name() != "x_plus_y") { + return; + } + target = op.getOperation(); + }); + DCHECK(target != nullptr); + + builder.setInsertionPoint(target); + mlir::OperationState opstate(builder.getUnknownLoc(), "tfg.Mul"); + opstate.operands.append(target->getOperands().begin(), + target->getOperands().end()); + opstate.types.append(target->getResultTypes().begin(), + target->getResultTypes().end()); + opstate.addAttribute("T", target->getAttr("T")); + opstate.addAttribute(dialect->getNameAttrIdentifier(), + builder.getStringAttr("x_times_y")); + + mlir::Operation* replacement = builder.create(opstate); + target->replaceAllUsesWith(replacement->getResults()); + target->erase(); + } +}; + +// A simple testing pass that replaces the first Mul node in the module +// to a AddV2 node and names it `x_plus_y`. 
+struct TestPassTfDialect + : public mlir::PassWrapper> { + TestPassTfDialect() = default; + + llvm::StringRef getArgument() const final { return "test-pass-tf-dialect"; } + + void runOnOperation() override { + auto module = getOperation(); + mlir::OpBuilder builder(module); + + mlir::Operation* target = nullptr; + module->walk([&target](mlir::Operation* op) { + if (op->getName().getStringRef() == "tf.Mul") { + target = op; + return; + } + }); + DCHECK(target != nullptr); + + builder.setInsertionPoint(target); + auto replacement = builder.create( + mlir::NameLoc::get( + mlir::StringAttr::get(builder.getContext(), "x_plus_y")), + target->getResultTypes(), target->getOperand(0), target->getOperand(1)); + target->replaceAllUsesWith(replacement->getResults()); + target->erase(); + } +}; + +inline std::unique_ptr> +CreateTfgDialectTestPass() { + return std::make_unique(); +} + +inline std::unique_ptr> +CreateTfDialectTestPass() { + return std::make_unique(); +} + +inline void RegisterTestPass() { + mlir::registerPass([] { return CreateTfgDialectTestPass(); }); + mlir::registerPass([] { return CreateTfDialectTestPass(); }); +} + +} // namespace testing +} // namespace function +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FUNCTION_TESTING_TEST_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/algorithm.h b/third_party/tflite-hdrs/tensorflow/core/graph/algorithm.h new file mode 100644 index 00000000..e20d6823 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/algorithm.h @@ -0,0 +1,154 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ +#define TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ + +#include +#include +#include + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Comparator for two nodes. This is used in order to get a stable ording. +using NodeComparator = std::function; + +using EdgeFilter = std::function; + +// Compares two node based on their ids. +struct NodeComparatorID { + bool operator()(const Node* n1, const Node* n2) const { + return n1->id() < n2->id(); + } +}; + +// Compare two nodes based on their names. +struct NodeComparatorName { + bool operator()(const Node* n1, const Node* n2) const { + return n1->name() < n2->name(); + } +}; + +// Perform a depth-first-search on g starting at the source node. +// If enter is not empty, calls enter(n) before visiting any children of n. +// If leave is not empty, calls leave(n) after visiting all children of n. +// If stable_comparator is set, a stable ordering of visit is achieved by +// sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. 
+void DFS(const Graph& g, const std::function& enter, + const std::function& leave, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); + +// Perform a depth-first-search on g starting at the 'start' nodes. +// If enter is not empty, calls enter(n) before visiting any children of n. +// If leave is not empty, calls leave(n) after visiting all children of n. +// If stable_comparator is set, a stable ordering of visit is achieved by +// sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. +void DFSFrom(const Graph& g, absl::Span start, + const std::function& enter, + const std::function& leave, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); +void DFSFrom(const Graph& g, absl::Span start, + const std::function& enter, + const std::function& leave, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); + +// Perform a reverse depth-first-search on g starting at the sink node. +// If enter is not empty, calls enter(n) before visiting any parents of n. +// If leave is not empty, calls leave(n) after visiting all parents of n. +// If stable_comparator is set, a stable ordering of visit is achieved by +// sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. +void ReverseDFS(const Graph& g, const std::function& enter, + const std::function& leave, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); + +// Perform a reverse depth-first-search on g starting at the 'start' nodes. +// If enter is not empty, calls enter(n) before visiting any parents of n. +// If leave is not empty, calls leave(n) after visiting all parents of n. +// If stable_comparator is set, a stable ordering of visit is achieved by +// sorting a node's neighbors first before visiting them. +// If edge_filter is set then ignores edges for which edge_filter returns false. +void ReverseDFSFrom(const Graph& g, absl::Span start, + const std::function& enter, + const std::function& leave, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); +void ReverseDFSFrom(const Graph& g, absl::Span start, + const std::function& enter, + const std::function& leave, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); + +void BreadthFirstTraversal( + const Graph& g, absl::Span start, + const std::function& visit, + NodeComparator stable_comparator = NodeComparatorID()); + +void BreadthFirstTraversal( + Graph& g, absl::Span start, + const std::function& visit, + NodeComparator stable_comparator = NodeComparatorID()); + +// Stores in *order the post-order numbering of all nodes +// in graph found via a depth first search starting at the source node. +// +// Note that this is equivalent to reverse topological sorting when the +// graph does not have cycles. +// +// If stable_comparator is set, a stable ordering of visit is achieved by +// sorting a node's neighbors first before visiting them. +// +// If edge_filter is set then ignores edges for which edge_filter returns +// false. +// +// REQUIRES: order is not NULL. 
+void GetPostOrder(const Graph& g, std::vector* order, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); + +// Stores in *order the reverse post-order numbering of all nodes +// If stable_comparator is set, a stable ordering of visit is achieved by +// sorting a node's neighbors first before visiting them. +// +// If edge_filter is set then ignores edges for which edge_filter returns +// false. +void GetReversePostOrder(const Graph& g, std::vector* order, + const NodeComparator& stable_comparator = {}, + const EdgeFilter& edge_filter = {}); + +// Prune nodes in "g" that are not in some path from the source node +// to any node in 'nodes'. Returns true if changes were made to the graph. +// Does not fix up source and sink edges. +bool PruneForReverseReachability(Graph* g, + std::unordered_set nodes); + +// Connect all nodes with no incoming edges to source. +// Connect all nodes with no outgoing edges to sink. +// +// Returns true if and only if 'g' is mutated. +bool FixupSourceAndSinkEdges(Graph* g); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_ALGORITHM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/benchmark_testlib.h b/third_party/tflite-hdrs/tensorflow/core/graph/benchmark_testlib.h new file mode 100644 index 00000000..54716405 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/benchmark_testlib.h @@ -0,0 +1,191 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
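[Editorial aside, not part of the patch] A minimal sketch of the traversal helpers above, assuming a tensorflow::Graph built elsewhere; the function name is made up for illustration.

// Illustrative sketch only -- not part of the vendored header.
#include <vector>

#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"

// Reverse-topological visit of an acyclic graph with a deterministic order
// (NodeComparatorName sorts each node's neighbors by name before visiting).
void VisitInPostOrder(const tensorflow::Graph& graph) {
  std::vector<tensorflow::Node*> order;
  tensorflow::GetPostOrder(graph, &order, tensorflow::NodeComparatorName());
  for (tensorflow::Node* n : order) {
    VLOG(1) << n->name() << " (" << n->type_string() << ")";
  }
}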
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_BENCHMARK_TESTLIB_H_ +#define TENSORFLOW_CORE_GRAPH_BENCHMARK_TESTLIB_H_ + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +namespace test { + +REGISTER_OP("Input").Output("y: float"); +REGISTER_OP("Output") + .Input("x: N * float") + .Attr("N: int >= 1") + .Output("y: float"); +REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("y: float"); +REGISTER_OP("In4Out1") + .Input("a: float") + .Input("b: float") + .Input("c: float") + .Input("d: float") + .Output("y: float"); +REGISTER_OP("In8Out1") + .Input("a: float") + .Input("b: float") + .Input("c: float") + .Input("d: float") + .Input("e: float") + .Input("f: float") + .Input("g: float") + .Input("h: float") + .Output("y: float"); +REGISTER_OP("In16Out1") + .Input("a: float") + .Input("b: float") + .Input("c: float") + .Input("d: float") + .Input("e: float") + .Input("f: float") + .Input("g: float") + .Input("h: float") + .Input("i: float") + .Input("j: float") + .Input("k: float") + .Input("l: float") + .Input("m: float") + .Input("n: float") + .Input("o: float") + .Input("p: float") + .Output("y: float"); + +inline GraphDef CreateGraphDef(int num_nodes, int num_edges_per_node) { + const int kNumInNodes = 10 * num_edges_per_node; + GraphDef graph_def; + + auto create_node = [](const string& name, const string& op) { + NodeDef node; + node.set_name(name); + node.set_op(op); + return node; + }; + + NodeDef node; + for (int in = 0; in < kNumInNodes; ++in) { + node = create_node(/*name=*/absl::StrFormat("in%04d", in), /*op=*/"Input"); + *graph_def.add_node() = std::move(node); + } + + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int op = 0; op < num_nodes; ++op) { + node = create_node(/*name=*/absl::StrFormat("op%05d", op), + /*op=*/absl::StrFormat("In%dOut1", num_edges_per_node)); + for (int edge = 0; edge < num_edges_per_node; ++edge) { + node.add_input(absl::StrFormat("in%04d", rnd.Uniform(kNumInNodes))); + } + *graph_def.add_node() = std::move(node); + } + + // Add a single sink node. Otherwise a lot of time is spent in + // FixupSourceAndSinkEdges(). 
+ node = create_node(/*name=*/"out", /*op=*/"Output"); + for (int op = 0; op < num_nodes; ++op) { + node.add_input(absl::StrFormat("op%05d", op)); + } + AttrValue attr; + attr.set_i(num_nodes); + node.mutable_attr()->insert({"N", std::move(attr)}); + *graph_def.add_node() = std::move(node); + + return graph_def; +} + +inline GraphDef CreateRandomGraph(int size) { + random::PhiloxRandom philox(0x12345); + random::SimplePhilox rnd(&philox); + + string prefix = "long_node_name_prefix_to_measure_string_copy_overhead"; + + GraphDef graph; + for (int i = 0; i < size; ++i) { + const string name = absl::StrCat(prefix, i); + const uint32 num_inputs = rnd.Uniform(std::min(i, 5)); + + NodeDef node; + node.set_name(name); + for (int n = 0; n < num_inputs; ++n) { + const uint32 input_node = rnd.Uniform(i); + node.add_input(absl::StrCat(prefix, input_node)); + } + + *graph.add_node() = std::move(node); + } + + return graph; +} + +inline GraphDef CreateFaninFanoutNodeGraph(int num_regular_fanins, + int num_regular_fanouts, + int num_controlling_fanins, + int num_controlled_fanouts, + bool fanout_unique_index) { + GraphDef graph; + + auto create_node = [](const string& name) { + NodeDef node; + node.set_name(name); + return node; + }; + + NodeDef node = create_node(/*name=*/"node"); + + for (int i = 0; i < num_regular_fanins; ++i) { + const string input_node_name = absl::StrFormat("in%05d", i); + NodeDef input_node = create_node(/*name=*/input_node_name); + *graph.add_node() = std::move(input_node); + node.add_input(input_node_name); + } + + for (int i = 0; i < num_controlling_fanins; ++i) { + const string input_node_name = absl::StrFormat("control_in%05d", i); + NodeDef input_node = create_node(/*name=*/input_node_name); + *graph.add_node() = std::move(input_node); + node.add_input(absl::StrCat("^", input_node_name)); + } + + for (int i = 0; i < num_regular_fanouts; ++i) { + NodeDef output_node = create_node(/*name=*/absl::StrFormat("out%05d", i)); + const string input_node_index = + fanout_unique_index ? absl::StrCat(node.name(), ":", i) : node.name(); + output_node.add_input(input_node_index); + *graph.add_node() = std::move(output_node); + } + + const string controlled_fanout_input = absl::StrCat("^", node.name()); + for (int i = 0; i < num_controlled_fanouts; ++i) { + NodeDef output_node = + create_node(/*name=*/absl::StrFormat("control_out%05d", i)); + output_node.add_input(controlled_fanout_input); + *graph.add_node() = std::move(output_node); + } + + *graph.add_node() = std::move(node); + + return graph; +} + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_BENCHMARK_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/collective_order.h b/third_party/tflite-hdrs/tensorflow/core/graph/collective_order.h new file mode 100644 index 00000000..c62017bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/collective_order.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPH_COLLECTIVE_ORDER_H_ +#define TENSORFLOW_CORE_GRAPH_COLLECTIVE_ORDER_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +enum class GraphCollectiveOrder { kNone, kEdges, kAttrs }; + +// Introduces a deterministic execution order between potentially concurrent +// CollectiveOps. This may be used to execute collectives in the same order +// across all workers in a distributed execution, if all workers are executing +// the same graph. +// If `order_type` is `kEdges`, introduce the ordering in the form of explicit +// control edges between collective graph nodes. If `order_type` is `kAttrs`, +// add an attribute to the node which may be used by collective executor to +// ensure the required ordering. +absl::Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_COLLECTIVE_ORDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/colors.h b/third_party/tflite-hdrs/tensorflow/core/graph/colors.h new file mode 100644 index 00000000..43d22255 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/colors.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_COLORS_H_ +#define TENSORFLOW_CORE_GRAPH_COLORS_H_ + +namespace tensorflow { + +// Return a color drawn from a palette to represent an entity +// identified by "i". The return value has the form "#RRGGBB" Note +// that the palette has a limited set of colors and therefore colors +// will be reused eventually. +const char* ColorFor(int dindex); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_COLORS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/control_flow.h b/third_party/tflite-hdrs/tensorflow/core/graph/control_flow.h new file mode 100644 index 00000000..c1e2db33 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/control_flow.h @@ -0,0 +1,61 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_ +#define TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_ + +#include + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Control flow info for a graph node. +struct ControlFlowInfo { + // 'frame' and 'parent_frame' are pointers to: + // + // a) One of the Enter nodes corresponding to the loop body, if the node + // executes inside a loop. If multiple tensors enter the while loop, it's + // undefined which Enter node will be used. + // + // b) SOURCE node (node.id() == Graph::kSourceId), if the node is not inside + // any of the while loops. + + const Node* frame = nullptr; // frame of a node + const Node* parent_frame = nullptr; // parent frame of a node + string frame_name; // frame name of a node +}; + +// Clear and populate `info` with each node's frame and the level it belongs to. +// We check the well-formedness of the graph: +// 1) All inputs to a node must come from the same frame and have the same +// "static" iteration level. +// 2) Each frame has at most one LoopCond node. +// 3) Each frame has a single parent frame. +// If `unreachable_nodes` is set, return names of nodes unreachable from the +// source node. We cannot build ControlFlowInfo for such nodes. They might be +// pruned later. +// +// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level 0. +// This essentially means there can't be multiple serial Nexts in an iteration, +// which all sane front-ends should satisfy. +absl::Status BuildControlFlowInfo( + const Graph* g, std::vector* info, + std::vector* unreachable_nodes = nullptr); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/costmodel.h b/third_party/tflite-hdrs/tensorflow/core/graph/costmodel.h new file mode 100644 index 00000000..795d9472 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/costmodel.h @@ -0,0 +1,241 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
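[Editorial aside, not part of the patch] A small sketch of BuildControlFlowInfo from the header above; `graph` is assumed to be a valid tensorflow::Graph and the logging is illustrative only.

// Illustrative sketch only -- not part of the vendored header.
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/graph.h"

absl::Status LogFrameNames(const tensorflow::Graph& graph) {
  std::vector<tensorflow::ControlFlowInfo> info;      // indexed by node id
  std::vector<std::string> unreachable;
  absl::Status s = tensorflow::BuildControlFlowInfo(&graph, &info, &unreachable);
  if (!s.ok()) return s;
  for (const tensorflow::Node* n : graph.nodes()) {
    VLOG(1) << n->name() << " -> frame '" << info[n->id()].frame_name << "'";
  }
  VLOG(1) << unreachable.size() << " node(s) unreachable from SOURCE";
  return absl::OkStatus();
}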
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_COSTMODEL_H_ +#define TENSORFLOW_CORE_GRAPH_COSTMODEL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +typedef std::unordered_map + NodeNameToCostIdMap; + +class StepStats; + +// CostModel keeps track of the following runtime statistics for nodes +// of a single Graph: +// * The total number of times a node has executed. +// * The accumulated execution time (in microseconds) of a node. +// * The accumulated size (in bytes) of each node's output. +// +// This class is NOT thread-safe. +class CostModel { + public: + // If "global" is true, maintains costs based on Node::cost_id, otherwise + // maintains costs based on Node::id. + explicit CostModel(bool is_global) : is_global_(is_global) { + unknown_shape_.set_unknown_rank(true); + } + + // Assigns min_count_ as a function of the median count for a Node. + // This value is then used for suppressing the time/size costs of + // infrequent operations. + // NOTE(tucker): Maybe this should move to a subclass of CostModel. + void SuppressInfrequent(); + + bool is_global() const { return is_global_; } + + inline int Id(const Node* n) const { + if (is_global_) { + return n->cost_id(); + } else { + return n->id(); + } + } + + inline int GlobalId(const Node* n, int offset) const { + if (is_global_) { + return n->cost_id(); + } else { + return n->id() + offset; + } + } + + // Initializes cost model for 'g'. + void InitFromGraph(const Graph& g); + + // Merges costs from cm. + // REQUIRES: is_global_ is true for this and for "cm" + void MergeFromGlobal(const CostModel& cm); + + // Merges costs from "cm", which has been computed relative to "g". + // REQUIRES: is_global_ is true for this, and false for "cm". + void MergeFromLocal(const Graph& g, const CostModel& cm); + + void MergeFromStats(const NodeNameToCostIdMap& map, const StepStats& ss); + + // Sets the number of outputs of "node". + void SetNumOutputs(const Node* node, int num_outputs); + + // Records that "node" has executed "num_count" more times. + void RecordCount(const Node* node, int num_count); + + // Returns how many times "node" has been executed. + int32 TotalCount(const Node* node) const; + + // Records that "output_slot" of "node" has produced tensors of + // aggregated "bytes". + void RecordSize(const Node* node, int output_slot, Bytes bytes); + + // Returns total bytes of tensors produced by "node"s output slot. + Bytes TotalBytes(const Node* node, int output_slot) const; + + // Returns a prediction for the size of the tensor at the + // output_slot produced by one execution of "node". + Bytes SizeEstimate(const Node* node, int output_slot) const; + + // Records that Executions of "node" have taken "time" microseconds. + void RecordTime(const Node* node, Microseconds time); + + // Returns the total execution time for "node". + Microseconds TotalTime(const Node* node) const; + + // Returns a prediction for one execution of "node". 
+ Microseconds TimeEstimate(const Node* node) const; + + // Check that an estimate is available for every OP node in graph. + void CheckInitialized(const Graph& graph) const; + + // Records the maximum size in bytes and optionally the corresponding shape of + // the tensor generated by "output_slot" of "node". If + void RecordMaxMemorySize(const Node* node, int output_slot, Bytes bytes, + const TensorShapeProto& tensor_shape, + const DataType& dtype); + + // Returns the maximum size in bytes of the tensor generated by "output_slot" + // of "node". + Bytes MaxMemorySize(const Node* node, int output_slot) const; + + // Returns the shape corresponding to the largest memory size of the tensor + // generated by "output_slot" of "node". + const TensorShapeProto& MaxMemoryShape(const Node* node, + int output_slot) const; + + // Returns the shape corresponding to the largest memory size of the tensor + // generated by "output_slot" of "node". + DataType MaxMemoryType(const Node* node, int output_slot) const; + + // Returns the size in bytes of temporary memory consumed by "node". + Bytes TempMemorySize(const Node* node) const; + + // Returns the size of persistent memory allocated by "node". + Bytes PersistentMemorySize(const Node* node) const; + + // Records memory stats such as temp momory and persistent memory. + void RecordMemoryStats(const Node* node, const MemoryStats& memory_stats); + + // Records the maximum execution time (in microseconds) of "node". + void RecordMaxExecutionTime(const Node* node, Microseconds time); + + // Returns the maximum execution time (in microseconds) of "node". + Microseconds MaxExecutionTime(const Node* node) const; + + // Record the unique id of the tensor generated by "output_slot" of "node". + // Any other tensor sharing the same id will be an alias, i.e. it will share + // the same underlying memory storage area. + void RecordAllocationId(const Node* node, int output_slot, int64_t alloc_id); + + // Return the unique id of the tensor generated by "output_slot" of "node". + int64_t AllocationId(const Node* node, int output_slot) const; + + bool IsPersistentTensor(const Node* node, int64_t alloc_id) const; + + // Helper routines to encapsulate static estimation heuristics + + // Compute an estimate of the time to copy "b" bytes over the network, + // given a fixed cost of "network_latency_millis" milliseconds and + // an estimated bandwidth of "estimated_gbps" gigabits per second (note that + // this value is in gigabits, not gigabytes). + static Microseconds CopyTimeEstimate(Bytes b, double network_latency_millis, + double estimated_gbps); + static Microseconds ComputationTimeEstimate(int64_t mathops); + + // Add this CostModel into the CostGraphDef. + void AddToCostGraphDef(const Graph* graph, CostGraphDef* cost_graph) const; + + // Write the contents of the CostModel to the INFO log. + void WriteSummaryToLog() const; + + // Increment the times that the cost model is updated. + void IncrementUpdateTimes(); + + // Get the times that the cost model is updated. + int32 GetUpdateTimes() const; + + private: + static Bytes MinTensorMemoryUsage(const TensorShapeProto& tensor_shape, + const DataType& dtype); + + const bool is_global_; + + // Resizes vectors so that they are large enough for "id" and id's outputs. + void Ensure(int id, int num_outputs); + + // Nodes and Edges whose count is < this value + // get type/byte estimates of 0. + int32 min_count_ = 0; + + // The number of times the cost model is updated. 
+ int32 update_times_ = 0; + + // Number of times each Node has been executed. + std::vector count_; + // Cumulative execution time. + std::vector time_; + // Cumulative Bytes output on each channel. + std::vector> slot_bytes_; + + // Maximum execution time + std::vector max_exec_time_; + + // Maximum memory usage + struct MemUsage { + MemUsage() : temp_memory_size(0), persistent_memory_size(0) {} + + // TODO(yuefengz): temp_memory_size is not being used, remove it. + Bytes temp_memory_size; + Bytes persistent_memory_size; + + absl::InlinedVector output_port_mem; + absl::InlinedVector output_port_shape; + absl::InlinedVector output_port_type; + }; + std::vector max_mem_usage_; + + std::vector> output_port_alloc_ids_; + + std::set persistent_alloc_ids_; + + TensorShapeProto unknown_shape_; + + CostModel(const CostModel&) = delete; + void operator=(const CostModel&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_COSTMODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/default_device.h b/third_party/tflite-hdrs/tensorflow/core/graph/default_device.h new file mode 100644 index 00000000..011b7c11 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/default_device.h @@ -0,0 +1,41 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_ +#define TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" + +namespace tensorflow { +namespace graph { + +// Sets the default device for all nodes in graph_def to "device", +// only if not already set. +inline void SetDefaultDevice(const std::string& device, GraphDef* graph_def) { + for (int i = 0; i < graph_def->node_size(); ++i) { + auto node = graph_def->mutable_node(i); + if (node->device().empty()) { + node->set_device(device); + } + } +} + +} // namespace graph +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/edgeset.h b/third_party/tflite-hdrs/tensorflow/core/graph/edgeset.h new file mode 100644 index 00000000..6d6cb3ff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/edgeset.h @@ -0,0 +1,246 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
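[Editorial aside, not part of the patch] A hedged sketch of how the CostModel recording and estimation calls above fit together; the helper and its arguments are made up, and it assumes InitFromGraph() was called earlier for the graph that owns `node`.

// Illustrative sketch only -- not part of the vendored header.
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/types.h"

void RecordOneExecution(tensorflow::CostModel* cm, const tensorflow::Node* node,
                        tensorflow::Microseconds elapsed,
                        tensorflow::Bytes output_bytes) {
  cm->RecordCount(node, 1);                        // one more observed run
  cm->RecordTime(node, elapsed);                   // accumulated wall time
  cm->RecordSize(node, /*output_slot=*/0, output_bytes);
  VLOG(1) << "time estimate: " << cm->TimeEstimate(node).value() << "us, "
          << "size estimate: " << cm->SizeEstimate(node, 0).value() << "B";
}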
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_EDGESET_H_ +#define TENSORFLOW_CORE_GRAPH_EDGESET_H_ + +#include + +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +namespace tensorflow { + +class Edge; + +// An unordered set of edges. Uses very little memory for small sets. +// Unlike gtl::FlatSet, EdgeSet does NOT allow mutations during +// iteration. +class EdgeSet { + public: + EdgeSet(); + ~EdgeSet(); + + typedef const Edge* key_type; + typedef const Edge* value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + class const_iterator; + typedef const_iterator iterator; + + bool empty() const; + size_type size() const; + void clear(); + std::pair insert(value_type value); + size_type erase(key_type key); + void reserve(size_type new_size) { + if (new_size > kInline) { + auto s = new gtl::FlatSet(new_size); + s->insert(reinterpret_cast(std::begin(ptrs_)), + reinterpret_cast(&ptrs_[0] + size())); + ptrs_[0] = this; + ptrs_[1] = s; + } + } + + // Caller is not allowed to mutate the EdgeSet while iterating. + const_iterator begin() const; + const_iterator end() const; + + private: + // Up to kInline elements are stored directly in ptrs_ (nullptr means none). + // If ptrs_[0] == this then ptrs_[1] points to a set. + // kInline must be >= 2, and is chosen such that ptrs_ fills a 64 byte + // cacheline. + static constexpr int kInline = 64 / sizeof(const void*); + const void* ptrs_[kInline]; + + gtl::FlatSet* get_set() const { + if (ptrs_[0] == this) { + return static_cast*>( + const_cast(ptrs_[1])); + } else { + return nullptr; + } + } + +// To detect mutations while iterating. 
+#ifdef NDEBUG + void RegisterMutation() {} +#else + uint32 mutations_ = 0; + void RegisterMutation() { mutations_++; } +#endif + + EdgeSet(const EdgeSet&) = delete; + void operator=(const EdgeSet&) = delete; +}; + +class EdgeSet::const_iterator { + public: + typedef typename EdgeSet::value_type value_type; + typedef const typename EdgeSet::value_type& reference; + typedef const typename EdgeSet::value_type* pointer; + typedef typename EdgeSet::difference_type difference_type; + typedef std::forward_iterator_tag iterator_category; + + const_iterator() {} + + const_iterator& operator++(); + const_iterator operator++(int /*unused*/); + const value_type* operator->() const; + value_type operator*() const; + bool operator==(const const_iterator& other) const; + bool operator!=(const const_iterator& other) const { + return !(*this == other); + } + + private: + friend class EdgeSet; + + void const* const* array_iter_ = nullptr; + typename gtl::FlatSet::const_iterator tree_iter_; + +#ifdef NDEBUG + inline void Init(const EdgeSet* e) {} + inline void CheckNoMutations() const {} +#else + inline void Init(const EdgeSet* e) { + owner_ = e; + init_mutations_ = e->mutations_; + } + inline void CheckNoMutations() const { + CHECK_EQ(init_mutations_, owner_->mutations_); + } + const EdgeSet* owner_ = nullptr; + uint32 init_mutations_ = 0; +#endif +}; + +inline EdgeSet::EdgeSet() { + for (int i = 0; i < kInline; i++) { + ptrs_[i] = nullptr; + } +} + +inline EdgeSet::~EdgeSet() { delete get_set(); } + +inline bool EdgeSet::empty() const { return size() == 0; } + +inline EdgeSet::size_type EdgeSet::size() const { + auto s = get_set(); + if (s) { + return s->size(); + } else { + size_t result = 0; + for (int i = 0; i < kInline; i++) { + if (ptrs_[i]) result++; + } + return result; + } +} + +inline void EdgeSet::clear() { + RegisterMutation(); + delete get_set(); + for (int i = 0; i < kInline; i++) { + ptrs_[i] = nullptr; + } +} + +inline EdgeSet::const_iterator EdgeSet::begin() const { + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (s) { + ci.tree_iter_ = s->begin(); + } else { + ci.array_iter_ = &ptrs_[0]; + } + return ci; +} + +inline EdgeSet::const_iterator EdgeSet::end() const { + const_iterator ci; + ci.Init(this); + auto s = get_set(); + if (s) { + ci.tree_iter_ = s->end(); + } else { + ci.array_iter_ = &ptrs_[size()]; + } + return ci; +} + +inline EdgeSet::const_iterator& EdgeSet::const_iterator::operator++() { + CheckNoMutations(); + if (array_iter_ != nullptr) { + ++array_iter_; + } else { + ++tree_iter_; + } + return *this; +} + +inline EdgeSet::const_iterator EdgeSet::const_iterator::operator++( + int /*unused*/) { + CheckNoMutations(); + const_iterator tmp = *this; + operator++(); + return tmp; +} + +// gcc's set and multiset always use const_iterator since it will otherwise +// allow modification of keys. +inline const EdgeSet::const_iterator::value_type* EdgeSet::const_iterator:: +operator->() const { + CheckNoMutations(); + if (array_iter_ != nullptr) { + return reinterpret_cast(array_iter_); + } else { + return tree_iter_.operator->(); + } +} + +// gcc's set and multiset always use const_iterator since it will otherwise +// allow modification of keys. 
+inline EdgeSet::const_iterator::value_type EdgeSet::const_iterator::operator*() + const { + CheckNoMutations(); + if (array_iter_ != nullptr) { + return static_cast(*array_iter_); + } else { + return *tree_iter_; + } +} + +inline bool EdgeSet::const_iterator::operator==( + const const_iterator& other) const { + DCHECK((array_iter_ == nullptr) == (other.array_iter_ == nullptr)) + << "Iterators being compared must be from same set that has not " + << "been modified since the iterator was constructed"; + CheckNoMutations(); + if (array_iter_ != nullptr) { + return array_iter_ == other.array_iter_; + } else { + return other.array_iter_ == nullptr && tree_iter_ == other.tree_iter_; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_EDGESET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/graph.h b/third_party/tflite-hdrs/tensorflow/core/graph/graph.h new file mode 100644 index 00000000..6e70b0cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/graph.h @@ -0,0 +1,1116 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A Graph describes a set of computations that are to be +// performed, as well as the dependencies between those +// computations. The basic model is a DAG (directed acyclic graph) with +// * internal nodes representing computational operations to be performed; +// * edges represent dependencies, indicating the target may only be +// executed once the source has completed; and +// * predefined "source" (start) and "sink" (finish) nodes -- the source +// should be the only node that doesn't depend on anything, and the sink +// should be the only node that nothing depends on. +// +// Note: Node ids are intended to be relatively dense in the +// 0..max_id range, but there may be gaps since ids won't be reused. +// +// Note: Some dependencies between operations are due to one operation +// consuming the output of another. In fact operations can produce +// multiple outputs and consume multiple inputs, and some +// optimizations will care about which specific outputs are connected +// to which specific inputs. We therefore represent data dependency +// between output O of layer A and input I of layer B using +// "input index" and "output index" labels per edge. 
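+//
+// Illustrative sketch (an editorial addition, not part of the upstream
+// header): the per-edge output/input labels described above are exposed via
+// Edge::src_output() and Edge::dst_input(), so the data dependencies of a
+// Graph `g` can be walked as follows:
+//
+//   for (const Edge* e : g.edges()) {
+//     if (e->IsControlEdge()) continue;  // control edges carry no data
+//     // Output e->src_output() of node e->src() feeds input
+//     // e->dst_input() of node e->dst().
+//   }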
+ +#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_H_ +#define TENSORFLOW_CORE_GRAPH_GRAPH_H_ + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/edgeset.h" +#include "tensorflow/core/lib/core/arena.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Edge; +class EdgeSetTest; +class Graph; +class GraphDebugInfo; +class GraphDef; +class GraphTest; +class Node; +struct OutputTensor; +class VersionDef; +class WhileContext; + +class NeighborIter; // Declared below +class NodeIter; // Declared below + +// Indicates where the graph instance is originated from. +enum class ConstructionContext { + kNotTracked, // Not tracked. + kDirectSession, // From `tensorflow::DirectSession`, TF1 session API. + kEagerRuntime, // Registered from TF2 eager runtime. +}; + +class Node { + public: + std::string DebugString() const; + int id() const { return id_; } + int cost_id() const { return cost_id_; } + const std::string& name() const; + void set_name(std::string name); + const std::string& type_string() const; + + // def() provides the NodeDef the user supplied, but the specifics + // of this Node may have changed due to placement, optimization, etc. + // In particular: + // * def().name() will match name(); + // * def().op() will match type_string() and op_def().name(); + // * def().input() is not reliable, use "in_edges()" below instead; + // * def().device() is the "user's requested device" and may not match + // the actual assigned device, see assigned_device_name() below; + // * def().attr() is authoritative. + // TODO(irving): Replace with NodeInfo. + const NodeDef& def() const; + const OpDef& op_def() const; + + NodeDef* mutable_def(); + + // input and output types + int32 num_inputs() const; + DataType input_type(int32_t i) const; + const DataTypeVector& input_types() const; + + int32 num_outputs() const; + DataType output_type(int32_t o) const; + const DataTypeVector& output_types() const; + + // The device requested by the user. For the actual assigned device, + // use assigned_device_name() below. + const std::string& requested_device() const; + + // This changes the user requested device but not necessarily the device that + // on which the operation will run. + void set_requested_device(const std::string& device); + + // This gives the device the runtime has assigned this node to. If + // you want the device the user requested, use def().device() instead. + // TODO(josh11b): Validate that the assigned_device, if not empty: + // fully specifies a device, and satisfies def().device(). + // TODO(josh11b): Move assigned_device_name outside of Node into a + // NodeId->DeviceName map. 
+ const std::string& assigned_device_name() const; + void set_assigned_device_name(const std::string& device_name); + bool has_assigned_device_name() const { + return assigned_device_name_index_ > 0; + } + int assigned_device_name_index() const { return assigned_device_name_index_; } + void set_assigned_device_name_index(int index); + + // Sets 'original_node_names' field of this node's DebugInfo proto to + // 'names'. + void set_original_node_names(const std::vector& names); + void set_original_func_names(const std::vector& names); + + // Read only access to attributes + AttrSlice attrs() const; + + // Inputs requested by the NodeDef. For the actual inputs, use in_edges. + const protobuf::RepeatedPtrField& requested_inputs() const; + + // Get the neighboring nodes via edges either in or out of this node. This + // includes control edges. + gtl::iterator_range in_nodes() const; + gtl::iterator_range out_nodes() const; + const EdgeSet& in_edges() const { return in_edges_; } + const EdgeSet& out_edges() const { return out_edges_; } + + // Node type helpers. + bool IsSource() const { return id() == 0; } + bool IsSink() const { return id() == 1; } + // Anything other than the special Source & Sink nodes. + bool IsOp() const { return id() > 1; } + + // Node class helpers + bool IsSwitch() const { return class_ == NC_SWITCH; } + bool IsMerge() const { return class_ == NC_MERGE; } + bool IsEnter() const { return class_ == NC_ENTER; } + bool IsExit() const { return class_ == NC_EXIT; } + bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; } + bool IsLoopCond() const { return class_ == NC_LOOP_COND; } + bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; } + bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; } + bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; } + bool IsConstant() const { return class_ == NC_CONSTANT; } + bool IsVariable() const { return class_ == NC_VARIABLE; } + bool IsIdentity() const { return class_ == NC_IDENTITY; } + bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; } + bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; } + bool IsDeleteSessionTensor() const { + return class_ == NC_DELETE_SESSION_TENSOR; + } + bool IsControlFlow() const { + return (class_ != NC_OTHER) && // Fast path + (IsSwitch() || IsMerge() || IsEnter() || IsExit() || + IsNextIteration()); + } + bool IsHostSend() const { return class_ == NC_HOST_SEND; } + bool IsHostRecv() const { return class_ == NC_HOST_RECV; } + bool IsScopedAllocator() const { return class_ == NC_SCOPED_ALLOCATOR; } + bool IsCollective() const { return class_ == NC_COLLECTIVE; } + + bool IsMetadata() const { return class_ == NC_METADATA; } + bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; } + bool IsPartitionedCall() const { return class_ == NC_PARTITIONED_CALL; } + + // Returns true if this node is any kind of function call node. + // + // NOTE: "function call nodes" include partitioned call ops, symbolic gradient + // ops, and ops whose type_string is the name of a function ("function ops"). 
+ bool IsFunctionCall() const { + return class_ == NC_PARTITIONED_CALL || class_ == NC_FUNCTION_OP || + class_ == NC_SYMBOLIC_GRADIENT; + } + + bool IsIfNode() const { return class_ == NC_IF; } + bool IsWhileNode() const { return class_ == NC_WHILE; } + bool IsCaseNode() const { return class_ == NC_CASE; } + // Is this node a function input + bool IsArg() const { return class_ == NC_ARG; } + // Is this node a function output + bool IsRetval() const { return class_ == NC_RETVAL; } + + bool IsDistributedCommunication() const { + return op_def().is_distributed_communication(); + } + + template + void AddAttr(const std::string& name, const T& val) { + SetAttrValue(val, AddAttrHelper(name)); + UpdateProperties(); + } + + void AddAttr(const std::string& name, std::vector&& val) { + MoveAttrValue(std::move(val), AddAttrHelper(name)); + UpdateProperties(); + } + + void ClearAttr(const std::string& name); + + // Returns into '*e' the edge connecting to the 'idx' input of this Node. + absl::Status input_edge(int idx, const Edge** e) const; + + // Returns into '*edges' the input data edges of this Node, indexed by input + // number. Does not return control edges. + absl::Status input_edges(std::vector* edges) const; + + // Returns into '*n' the node that has an output connected to the + // 'idx' input of this Node. + absl::Status input_node(int idx, const Node** n) const; + absl::Status input_node(int idx, Node** n) const; + + // Returns into '*t' the idx-th input tensor of this node, represented as the + // output tensor of input_node(idx). + absl::Status input_tensor(int idx, OutputTensor* t) const; + + WhileContext* while_ctx() const { return while_ctx_; } + void set_while_ctx(WhileContext* while_ctx) { + DCHECK(IsExit()); + DCHECK(while_ctx_ == nullptr); + while_ctx_ = while_ctx; + } + + std::shared_ptr properties() const { return props_; } + + // Sets the stack trace for the node. Assumes that getting and setting the + // stack trace for a given node will not race. + void SetStackTrace(const std::shared_ptr& stack_trace) { + stack_trace_ = stack_trace; + } + + // Get the stack trace for when the node was instantiated. + const std::shared_ptr& GetStackTrace() const { + return stack_trace_; + } + + // Called after an attr has changed. Decides whether we need to update some + // property of the node (stored in props_). + void UpdateProperties(); + + // Erases type information from the node. + void ClearTypeInfo(); + + // Update type information for a node with a list of inputs and/or outputs + // described by its TYPE_ATTR_NAME attr when removing some of these. The keys + // of INDEX_MAPPING are the indexes of the inputs/outputs that are not + // removed. dtype information in the TYPE_ATTR_NAME attr is always updated. + // Use UPDATE_FULL_TYPE=true when this changes the node's outputs to also + // update the node's full type information (if present). + absl::Status ShrinkTypeInfo( + const absl::flat_hash_map& index_mapping, + const string& type_attr_name, bool update_full_type); + + // Called after an incident non-control edge has changed. Does nothing if not + // all input edges are defined. + void RunForwardTypeInference(); + + private: + // TODO(mdan): Drop this. + friend class Graph; + Node(); + + // Stack trace for the user code for node instantiation. Can be shared across + // multiple nodes (e.g. when inlining). + std::shared_ptr stack_trace_; + + // Releases memory from props_, in addition to restoring *this to its + // uninitialized state. 
+ void Clear(); + + // Make a copy of the Node's props_ if props_ is shared with + // other nodes. This must be called before mutating properties, + // e.g. in AddAttr. + void MaybeCopyOnWrite(); + + AttrValue* AddAttrHelper(const std::string& name); + + // A set of mutually exclusive classes for different kinds of nodes, + // class_ is initialized in the Node::Initialize routine based on the + // node's type_string(). + enum NodeClass { + NC_UNINITIALIZED, + NC_SWITCH, + NC_MERGE, + NC_ENTER, + NC_EXIT, + NC_NEXT_ITERATION, + NC_LOOP_COND, + NC_CONTROL_TRIGGER, + NC_SEND, + NC_HOST_SEND, + NC_RECV, + NC_HOST_RECV, + NC_CONSTANT, + NC_VARIABLE, + NC_IDENTITY, + NC_GET_SESSION_HANDLE, + NC_GET_SESSION_TENSOR, + NC_DELETE_SESSION_TENSOR, + NC_METADATA, + NC_SCOPED_ALLOCATOR, + NC_COLLECTIVE, + NC_FAKE_PARAM, + NC_PARTITIONED_CALL, + NC_FUNCTION_OP, + NC_SYMBOLIC_GRADIENT, + NC_IF, + NC_WHILE, + NC_CASE, + NC_ARG, + NC_RETVAL, + NC_OTHER // Not a special kind of node + }; + + void Initialize(int id, int cost_id, std::shared_ptr props, + NodeClass node_class); + + static NodeClass GetNodeClassForOp(const std::string& ts); + + int id_; // -1 until Initialize() is called + int cost_id_; // -1 if there is no corresponding cost accounting node + NodeClass class_; + + EdgeSet in_edges_; + EdgeSet out_edges_; + + // NOTE(skyewm): inheriting from core::RefCounted may have a slight + // performance benefit over using shared_ptr, at the cost of manual ref + // counting + std::shared_ptr props_; + + // Index within Graph::device_names_ of the name of device assigned + // to perform this computation. + int assigned_device_name_index_; + + // A back-pointer to the Graph that owns this node. Currently, this exists + // solely to allow Node::[set_]assigned_device_name() to work. However, if all + // callers of Node::[set_]assigned_device_name() are modified to use the + // equivalent methods defined directly on Graph, then we can remove this + // field and reclaim that memory. + Graph* graph_; + + // Set if this is an exit node of a while loop with an associated + // WhileContext. Otherwise null. (This is only set for exit nodes because + // they're the first nodes of a loop encountered while creating the gradient + // graph. Exit nodes that are part of while loop gradient graphs will not have + // this set.) + WhileContext* while_ctx_; + + Node(const Node&) = delete; + void operator=(const Node&) = delete; +}; + +// Stores debug information associated with the Node. +struct NodeDebugInfo { + const std::string name; + std::vector original_node_names; + std::vector original_func_names; + + NodeDebugInfo(const Node& n); + NodeDebugInfo(const NodeDef& ndef); + NodeDebugInfo(absl::string_view node_name, bool has_experimental_debug_info, + const NodeDef_ExperimentalDebugInfo& experimental_debug_info); +}; + +// Represents an input of a node, i.e., the `index`-th input to `node`. +struct InputTensor { + Node* node; + int index; + + InputTensor(Node* n, int i) : node(n), index(i) {} + InputTensor() : node(nullptr), index(0) {} + + // Returns true if this InputTensor is identical to 'other'. Nodes are + // compared using pointer equality. + bool operator==(const InputTensor& other) const; + + // A hash function for InputTensors. Nodes are hashed based on their pointer + // value. + struct Hash { + uint64 operator()(InputTensor const& s) const; + }; +}; + +// Represents an output of a node, i.e., the `index`-th output of `node`. 
Note +// that a single `OutputTensor` can correspond to multiple `Edge`s if the output +// is consumed by multiple destination nodes. +struct OutputTensor { + Node* node; + int index; + + OutputTensor(Node* n, int i) : node(n), index(i) {} + OutputTensor() : node(nullptr), index(0) {} + + // Returns true if this OutputTensor is identical to 'other'. Nodes are + // compared using pointer equality. + bool operator==(const OutputTensor& other) const; + + // A hash function for OutputTensors. Nodes are hashed based on their pointer + // value. + struct Hash { + uint64 operator()(OutputTensor const& s) const; + }; +}; + +class Edge { + public: + Node* src() const { return src_; } + Node* dst() const { return dst_; } + int id() const { return id_; } + + // Return the index of the source output that produces the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int src_output() const { return src_output_; } + + // Return the index of the destination input that consumes the data + // carried by this edge. The special value kControlSlot is used + // for control dependencies. + int dst_input() const { return dst_input_; } + + // Return true iff this is an edge that indicates a control-flow + // (as opposed to a data-flow) dependency. + bool IsControlEdge() const; + + std::string DebugString() const; + + private: + Edge() {} + + friend class EdgeSetTest; + friend class GraphTest; + friend class Graph; + Node* src_; + Node* dst_; + int id_; + int src_output_; + int dst_input_; +}; + +// Allows for iteration of the edges of a Graph, by iterating the underlying +// Graph.edges_ vector while skipping over null entries. +class GraphEdgesIterable { + private: + const std::vector& edges_; + + public: + explicit GraphEdgesIterable(const std::vector& edges) + : edges_(edges) {} + + typedef Edge* value_type; + + class const_iterator { + private: + // The underlying iterator. + std::vector::const_iterator iter_; + + // The end of the underlying iterator. + std::vector::const_iterator end_; + + // Advances iter_ until it reaches a non-null item, or reaches the end. + void apply_filter() { + while (iter_ != end_ && *iter_ == nullptr) { + ++iter_; + } + } + + public: + const_iterator(std::vector::const_iterator iter, + std::vector::const_iterator end) + : iter_(iter), end_(end) { + apply_filter(); + } + + bool operator==(const const_iterator& other) const { + return iter_ == other.iter_; + } + + bool operator!=(const const_iterator& other) const { + return iter_ != other.iter_; + } + + // This is the prefix increment operator (++x), which is the operator + // used by C++ range iteration (for (x : y) ...). We intentionally do not + // provide a postfix increment operator. + const_iterator& operator++() { + ++iter_; + apply_filter(); + return *this; + } + + value_type operator*() { return *iter_; } + }; + + const_iterator begin() { + return const_iterator(edges_.begin(), edges_.end()); + } + const_iterator end() { return const_iterator(edges_.end(), edges_.end()); } +}; + +// Thread compatible but not thread safe. +class Graph { + public: + // Constructs a graph with a single SOURCE (always id kSourceId) and a + // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. + // + // The graph can hold ops found in the registry. `ops`s lifetime must be at + // least that of the constructed graph's. 
+ explicit Graph(const OpRegistryInterface* ops); + + // Constructs a graph with a single SOURCE (always id kSourceId) and a + // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. + // + // The graph can hold ops found in `flib_def`. Unlike the constructor taking + // an OpRegistryInterface, this constructor copies the function definitions in + // `flib_def` so its lifetime may be shorter than that of the graph's. The + // OpRegistryInterface backing `flib_def` must still have the lifetime of the + // graph though. + explicit Graph(const FunctionLibraryDefinition& flib_def); + + ~Graph(); + + // Clone the current graph into a new one. + std::unique_ptr Clone(); + + static constexpr int kControlSlot = -1; + + // The GraphDef version range of this graph (see graph.proto). + const VersionDef& versions() const; + void set_versions(const VersionDef& versions); + + // Adds a new node to this graph, and returns it. Infers the Op and + // input/output types for the node. *this owns the returned instance. + // Returns nullptr and sets *status on error. + Node* AddNode(NodeDef node_def, absl::Status* status); + + // Same as above, but using StatusOr. This method is always preferred. + absl::StatusOr AddNode(NodeDef node_def); + + // Copies *node, which may belong to another graph, to a new node, + // which is returned. Does not copy any edges. *this owns the + // returned instance. + Node* CopyNode(const Node* node); + + // Removes a node from this graph, including all edges from or to it. + // *node should not be accessed after calling this function. + // REQUIRES: node->IsOp() + void RemoveNode(Node* node); + + void Copy(const Graph& src); + + // Removes all nodes from this graph, including all edges from or to them. + // No Node* references to the Graph are valid post. + void Clear(); + + // Adds an edge that connects the xth output of `source` to the yth input of + // `dest` and returns it. Does not update dest's NodeDef. + const Edge* AddEdge(Node* source, int x, Node* dest, int y); + + // Adds a control edge (no data flows along this edge) that connects `source` + // to `dest`. If `dest`s NodeDef is missing the corresponding control input, + // adds the control input. + // + // If such a control edge already exists and `allow_duplicates` is false, no + // edge is added and the function returns nullptr. Otherwise the edge is + // unconditionally created and returned. The NodeDef is not updated if + // `allow_duplicates` is true. + // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by + // graph_partition.cc. Figure out if we can do away with it. + const Edge* AddControlEdge(Node* source, Node* dest, + bool allow_duplicates = false); + + // Removes edge from the graph. Does not update the destination node's + // NodeDef. Does not update the full type information of the source node's + // NodeDef. (See ShrinkTypeInfo for an example of updating full type + // information when removing some outputs from a node.) + // REQUIRES: The edge must exist. + void RemoveEdge(const Edge* edge); + + // Removes control edge `edge` from the graph. Note that this also updates + // the corresponding NodeDef to reflect the change. + // REQUIRES: The control edge must exist. + void RemoveControlEdge(const Edge* e); + + // Updates the input to a node. The existing edge to `dst` is removed and an + // edge from `new_src` to `dst` is created. The NodeDef associated with `dst` + // is also updated. 
+ absl::Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, + int dst_index); + + // Add an input to dst that comes from the "src_slot" output of the + // node named by "src_name". + static void AddInput(NodeDef* dst, absl::string_view src_name, int src_slot); + + // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a + // "While" op during gradient construction, see AddInputWhileHack in + // python_api.h for more details. + absl::Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst); + + // Adds the function and gradient definitions in `fdef_lib` to this graph's op + // registry. Ignores duplicate functions, and returns a bad status if an + // imported function differs from an existing function or op with the same + // name. This overload adds the function definitions with no stack traces. + absl::Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib); + absl::Status AddFunctionLibrary(FunctionDefLibrary&& fdef_lib); + + // Adds the function and gradient definitions in `fdef_lib` to this graph's op + // registry. Ignores duplicate functions, and returns a bad status if an + // imported function differs from an existing function or op with the same + // name. + absl::Status AddFunctionLibrary( + const FunctionDefLibrary& fdef_lib, + const FunctionDefLibraryStackTraces& stack_traces); + absl::Status AddFunctionLibrary( + FunctionDefLibrary&& fdef_lib, + const FunctionDefLibraryStackTraces& stack_traces); + + // Adds the function definition and its stacktraces to this graph's op + // registry. Ignores duplicate functions, and returns a bad status if an + // imported function differs from an existing function or op with the same + // name. + absl::Status AddFunctionDef(const FunctionDef& fdef, + const StackTracesMap& stack_traces); + + // Adds the gradient definition to this graph's op registry. Ignores duplicate + // gradients of the same function, and returns a bad status if an imported + // gradient differs from an existing gradient of the same function name. + absl::Status AddGradientDef(const GradientDef& gdef); + + // The number of live nodes in the graph. + // + // Because nodes can be removed from the graph, num_nodes() is often + // smaller than num_node_ids(). If one needs to create an array of + // nodes indexed by node ids, num_node_ids() should be used as the + // array's size. + int num_nodes() const { return num_nodes_; } + + // The number of live nodes in the graph, excluding the Source and Sink nodes. + int num_op_nodes() const { + DCHECK_GE(num_nodes_, 2); + return num_nodes_ - 2; + } + + // The number of live edges in the graph. + // + // Because edges can be removed from the graph, num_edges() is often + // smaller than num_edge_ids(). If one needs to create an array of + // edges indexed by edge ids, num_edge_ids() should be used as the + // array's size. + int num_edges() const { return num_edges_; } + + // Serialize the nodes starting at `from_node_id` to a GraphDef. + // `include_flib_def` indicates whether the function library will be populated + // in the `graph_def`. `include_flib_def` should be usually set to true so + // that the populated `graph_def` will be complete. Setting `include_flib_def` + // to false would mean that the returned `graph_def` is incomplete and may + // contain references to functions whose definition is not included. It can + // make sense to do this in cases where the caller already has a copy of the + // function library. 
+ // If `include_debug_info` is true, the `debug_info` field of the GraphDef + // will be populated with stack traces from the nodes and the function + // library. Note that if `include_debug_info` is true and `include_flib_def` + // is false, then `debug_info` will contain stack traces for nodes in the + // function library, which will not itself be included in the GraphDef. + void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id, + bool include_flib_def = true, + bool include_debug_info = false) const; + + // Serialize to a GraphDef. `include_flib_def` indicates whether the function + // library will be populated in the `graph_def`. `include_flib_def` should be + // usually set to true so that the populated `graph_def` will be complete. + // Setting `include_flib_def` to false would mean that the returned + // `graph_def` is incomplete and may contain references to functions whose + // definition is not included. It can make sense to do this in cases where the + // caller already has a copy of the function library. + // If `include_debug_info` is true, the `debug_info` field of the GraphDef + // will be populated with stack traces from the nodes and the function + // library. Note that if `include_debug_info` is true and `include_flib_def` + // is false, then `debug_info` will contain stack traces for nodes in the + // function library, which will not itself be included in the GraphDef. + void ToGraphDef(GraphDef* graph_def, bool include_flib_def = true, + bool include_debug_info = false) const; + + // This version can be called from debugger to inspect the graph content. + // Use the previous version outside debug context for efficiency reasons. + // + // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is + // not defined in some TensorFlow builds. + GraphDef ToGraphDefDebug() const; + + // Generate new node name with the specified prefix that is unique + // across this graph. + std::string NewName(absl::string_view prefix); + + // Access to the list of all nodes. Example usage: + // for (Node* node : graph.nodes()) { ... } + gtl::iterator_range nodes() const; + + // Access to the list of all nodes, excluding the Source and Sink nodes. + gtl::iterator_range op_nodes() const; + + // Returns one more than the maximum id assigned to any node. + int num_node_ids() const { return nodes_.size(); } + + // Returns the node associated with an id, or nullptr if no node + // with that id (the node with that id was removed and the id has + // not yet been re-used). *this owns the returned instance. + // REQUIRES: 0 <= id < num_node_ids(). + Node* FindNodeId(int id) const { return nodes_[id]; } + + // Returns one more than the maximum id assigned to any edge. + int num_edge_ids() const { return edges_.size(); } + + // Returns the Edge associated with an id, or nullptr if no edge + // with that id (the edge with that id was removed and the id has + // not yet been re-used). *this owns the returned instance. + // REQUIRES: 0 <= id < num_edge_ids(). + const Edge* FindEdgeId(int id) const { return edges_[id]; } + + // Access to the set of all edges. Example usage: + // for (const Edge* e : graph.edges()) { ... } + GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); } + + // The pre-defined nodes. 
+ enum { kSourceId = 0, kSinkId = 1 }; + Node* source_node() const { return FindNodeId(kSourceId); } + Node* sink_node() const { return FindNodeId(kSinkId); } + + const OpRegistryInterface* op_registry() const { return &ops_; } + const FunctionLibraryDefinition& flib_def() const { return ops_; } + + FunctionLibraryDefinition* mutable_flib_def() { return &ops_; } + + void CheckDeviceNameIndex(int index) { + DCHECK_GE(index, 0); + DCHECK_LT(index, static_cast(device_names_.size())); + } + + int InternDeviceName(const std::string& device_name); + + const std::string& get_assigned_device_name(const Node& node) const { + return device_names_[node.assigned_device_name_index()]; + } + + void set_assigned_device_name_index(Node* node, int device_name_index) { + CheckDeviceNameIndex(device_name_index); + node->assigned_device_name_index_ = device_name_index; + } + + void set_assigned_device_name(Node* node, const std::string& device_name) { + node->assigned_device_name_index_ = InternDeviceName(device_name); + } + + // Returns OK if `node` is non-null and belongs to this graph + absl::Status IsValidNode(const Node* node) const; + + // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not + // accept control outputs. + absl::Status IsValidOutputTensor(const Node* node, int idx) const; + + // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept + // control inputs. + absl::Status IsValidInputTensor(const Node* node, int idx) const; + + // Create and return a new WhileContext owned by this graph. This is called + // when a new while loop is created. `frame_name` must be unique among + // WhileContexts in this graph. + absl::Status AddWhileContext(absl::string_view frame_name, + std::vector enter_nodes, + std::vector exit_nodes, + OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs, + WhileContext** result); + + // Builds a node name to node pointer index for all nodes in the graph. + std::unordered_map BuildNodeNameIndex() const; + + absl::optional>& GetConstArgIndicesCache() const { + return const_arg_indices_cache_; + } + + // TODO(kkb): Add to the constructor when it becomes managable. + // Sets the graph construction context. + void SetConstructionContext(ConstructionContext construction_context) { + construction_context_ = construction_context; + } + + // TODO(kkb): Rename to `GetConstructionContext` once we're comfortable + // making this stable and make it available widely. + // Returns the graph construction context. It's `kUnknown` if not set. + ConstructionContext GetConstructionContextInternal() const { + return construction_context_; + } + + // Set full type information for a node given its name. + // Note that if this is called in a loop iterating over all the nodes + // elsewhere it would be O(n^2) complexity. If this case was important in the + // future, an alternative method could be added that takes in a flat_hash_map + // of name: type and simply iterates through the graph once and annotates all + // nodes. + void SetNodeType(absl::string_view name, const FullTypeDef& type); + + // Get full type information for a node given its name. + // Note that if this is called in a loop iterating over all the nodes + // elsewhere it would be O(n^2) complexity. If this case was important in the + // future, an alternative method could be added that takes in flat_hash_map of + // name: type and simply iterates through the graph once and stores all the + // information in the map. 
+ void NodeType(absl::string_view name, const FullTypeDef** result); + + // Builds a GraphDebugInfo from the functions and nodes in this graph. Stack + // traces associated with function definitions will have a key of the form + // '@' . Stack traces associated with other Nodes + // will use the node name as the key. + GraphDebugInfo BuildDebugInfo() const; + + // TODO(josh11b): uint64 hash() const; + + private: + // If cost_node is non-null, then cost accounting (in CostModel) + // will be associated with that node rather than the new one being + // created. + // + // Ownership of the returned Node is not transferred to caller. + Node* AllocateNode(std::shared_ptr props, + const Node* cost_node, Node::NodeClass node_class); + void ReleaseNode(Node* node); + // Insert edge in free_edges_ for possible reuse. + void RecycleEdge(const Edge* edge); + // Registry of all known ops, including functions. + FunctionLibraryDefinition ops_; + + // GraphDef versions + const std::unique_ptr versions_; + + // Allocator which will give us good locality. + core::Arena arena_; + + // Map from node ids to allocated nodes. nodes_[id] may be nullptr if + // the node with that id was removed from the graph. + std::vector nodes_; + + // Number of nodes alive. + int64_t num_nodes_ = 0; + + // Map from edge ids to allocated edges. edges_[id] may be nullptr if + // the edge with that id was removed from the graph. + std::vector edges_; + + // The number of entries in edges_ that are not nullptr. + int num_edges_ = 0; + + // Allocated but free nodes and edges. + std::vector free_nodes_; + std::vector free_edges_; + + // For generating unique names. + int name_counter_ = 0; + + // In most graphs, the number of unique values used for the + // Node::assigned_device_name() property is quite small. If the graph is + // large, then this duplication of values can consume a significant amount of + // memory. Instead, we represent the same information using an interning + // table, which consists of a vector of unique strings (device_names_), as + // well a map (device_names_map_) from unique strings to indices within the + // unique string table. + // + // The InternDeviceName() method handles adding a new entry into the table, + // or locating the index of an existing entry. + // + // The fact that Node::assigned_device_name() is implemented using an + // interning table is intentionally public. This allows algorithms that + // frequently access this field to do so efficiently, especially for the case + // where the assigned_device_name of one Node is copied directly from that + // of another Node. + + // A table of the unique assigned device names. Indices do NOT correspond + // to node IDs. Index 0 is always the empty string. + std::vector device_names_; + + // Maps unique device names to indices within device_names_[i]. + std::unordered_map device_names_map_; + + // All the while contexts owned by this graph, keyed by frame name, + // corresponding to all the while loops contained in this graph (including + // nested loops). The stored contexts are usually accessed via + // AddWhileContext() or Node::while_ctx(), but this manages the lifetime. + std::map while_ctxs_; + + // Cache of the indices of the arguments which need to be constant for the XLA + // compilation. + mutable absl::optional> const_arg_indices_cache_; + + // Indicates the context that this Graph instance is constructed. 
+ ConstructionContext construction_context_ = ConstructionContext::kNotTracked; + + Graph(const Graph&) = delete; + void operator=(const Graph&) = delete; +}; + +// TODO(josh11b): We may want to support keeping an index on various +// node/edge attributes in a graph, particularly node names. + +// Helper routines + +inline bool IsSource(const Node* node) { return node->IsSource(); } +inline bool IsSink(const Node* node) { return node->IsSink(); } +inline bool IsSwitch(const Node* node) { return node->IsSwitch(); } +inline bool IsMerge(const Node* node) { return node->IsMerge(); } +inline bool IsEnter(const Node* node) { return node->IsEnter(); } +inline bool IsExit(const Node* node) { return node->IsExit(); } +inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); } +inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); } +inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); } +inline bool IsSend(const Node* node) { return node->IsSend(); } +inline bool IsRecv(const Node* node) { return node->IsRecv(); } +inline bool IsHostSend(const Node* node) { return node->IsHostSend(); } +inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); } + +// True for Nodes that mediate the transfer of values between processes. +inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); } + +inline bool IsConstant(const Node* node) { return node->IsConstant(); } +inline bool IsVariable(const Node* node) { return node->IsVariable(); } +inline bool IsIdentity(const Node* node) { return node->IsIdentity(); } + +// Returns true iff 'n' is a control flow node. +inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); } + +// Returns true if the node only depends on its input's metadata +// (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops. +inline bool IsMetadata(const Node* n) { return n->IsMetadata(); } + +inline bool IsScopedAllocator(const Node* n) { return n->IsScopedAllocator(); } + +inline bool IsHostMemoryPreserving(const Node* node) { + return IsIdentity(node) || IsControlFlow(node); +} + +inline bool IsDistributedCommunication(const Node* n) { + return n->IsDistributedCommunication(); +} + +// NOTE: We declare Reference type of NodeIter and NeighborIter as Node* (see +// https://en.cppreference.com/w/cpp/iterator/iterator). + +// Iterator for stepping through the nodes of a graph. +class NodeIter { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = Node; + using difference_type = std::ptrdiff_t; + using pointer = Node*; + using reference = Node*; + + NodeIter(const Graph* graph, int id); + bool operator==(const NodeIter& rhs) const; + bool operator!=(const NodeIter& rhs) const; + void operator++(); + reference operator*() const; + pointer operator->() const; + + private: + // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr + const Graph* graph_; + int id_; +}; + +// Iterator for stepping through the neighbors of a node. 
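+// Dereferencing yields the node on the other end of each edge: the source
+// node for incoming edges and the destination node for outgoing ones (see
+// NeighborIter::operator* below).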
+class NeighborIter { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = Node; + using difference_type = std::ptrdiff_t; + using pointer = Node*; + using reference = Node*; + + NeighborIter(EdgeSet::const_iterator iter, bool incoming); + bool operator==(const NeighborIter& rhs) const; + bool operator!=(const NeighborIter& rhs) const; + void operator++(); + reference operator*() const; + pointer operator->() const; + + private: + EdgeSet::const_iterator iter_; + bool incoming_; +}; + +// IMPLEMENTATION DETAILS, PLEASE IGNORE + +inline NodeIter::NodeIter(const Graph* graph, int id) + : graph_(graph), id_(id) {} + +inline bool NodeIter::operator==(const NodeIter& rhs) const { + DCHECK(graph_ == rhs.graph_); + return id_ == rhs.id_; +} + +inline bool NodeIter::operator!=(const NodeIter& rhs) const { + return !(*this == rhs); +} + +inline void NodeIter::operator++() { + while (true) { + DCHECK_LE(id_, graph_->num_node_ids()); + ++id_; + if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { + return; + } + } +} + +inline Node* NodeIter::operator*() const { return graph_->FindNodeId(id_); } + +inline Node* NodeIter::operator->() const { return graph_->FindNodeId(id_); } + +inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming) + : iter_(iter), incoming_(incoming) {} + +inline bool NeighborIter::operator==(const NeighborIter& rhs) const { + return iter_ == rhs.iter_ && incoming_ == rhs.incoming_; +} + +inline bool NeighborIter::operator!=(const NeighborIter& rhs) const { + return !(*this == rhs); +} + +inline void NeighborIter::operator++() { ++iter_; } + +inline Node* NeighborIter::operator*() const { + const Edge* e = *iter_; + return incoming_ ? e->src() : e->dst(); +} + +inline Node* NeighborIter::operator->() const { + const Edge* e = *iter_; + return incoming_ ? e->src() : e->dst(); +} + +inline bool Edge::IsControlEdge() const { + // Note that if either src_output_ or dst_input_ is kControlSlot, + // so is the other one (AddEdge checks this). + return src_output_ == Graph::kControlSlot; +} + +inline gtl::iterator_range Graph::nodes() const { + // Note that NodeId 0 is always valid since we don't let the source + // node be removed from the graph. + return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids())); +} + +inline gtl::iterator_range Graph::op_nodes() const { + // Note that NodeId 0 is always valid since we don't let the source + // node be removed from the graph. + // + // The current implementation of Graph maintains the invariant that the + // first two nodes are the source and sink nodes, and all other nodes are op + // nodes. This method (op_nodes()) relies on this invariant. 
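+  // (Editorial note: the two unconditional increments below skip the source
+  // node (id 0) and the sink node (id 1), which always occupy the first two
+  // ids, so iteration starts at the first op node.)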
+ NodeIter begin(this, 0); + NodeIter end(this, num_node_ids()); + if (begin != end) { + ++begin; + } + if (begin != end) { + ++begin; + } + return gtl::make_range(begin, end); +} + +inline void Node::set_assigned_device_name_index(int index) { + graph_->CheckDeviceNameIndex(index); + assigned_device_name_index_ = index; +} + +inline void Node::set_assigned_device_name(const std::string& device_name) { + graph_->set_assigned_device_name(this, device_name); +} + +inline const std::string& Node::assigned_device_name() const { + return graph_->get_assigned_device_name(*this); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/graph_debug_info_builder.h b/third_party/tflite-hdrs/tensorflow/core/graph/graph_debug_info_builder.h new file mode 100644 index 00000000..b1c8fcef --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/graph_debug_info_builder.h @@ -0,0 +1,210 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEBUG_INFO_BUILDER_H_ +#define TENSORFLOW_CORE_GRAPH_GRAPH_DEBUG_INFO_BUILDER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/platform/stack_frame.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { + +// Language agnostic stack traces. +class AbstractStackTrace { + public: + struct TracePrintingOptions { + // Show inline the contents of each stack line. + bool show_line_contents = false; + + // Drop the common largest prefix of all filenames in stack frames. + bool filter_common_prefix = false; + + // Do not show internal frames. + bool drop_internal_frames = false; + }; + + virtual ~AbstractStackTrace() = default; + + // The returned span is alive as long as the AbstractStackTrace is alive. + virtual absl::Span ToFrames() const = 0; + + // Returns the stack frames without caching any generated data. + virtual std::vector ToUncachedFrames() const = 0; + + // Returns the last stack frame from user code, attempting to ignore the + // framework code. Returns an empty frame if no such stack frame was found. + virtual StackFrame LastUserFrame() const = 0; + + // Returns stack trace from user code (instead of op creation ones returned in + // ToFrames). + virtual std::vector GetUserFrames(int limit) const = 0; + + virtual std::string ToString(const TracePrintingOptions& opts) const = 0; +}; + +// A frozen sequence of StackFrames; an adapter for a span of StackFrames that +// conforms to the AbstractStackTrace contract. 
+class FrozenStackTrace : public AbstractStackTrace { + public: + // Constructs a FrozenStackTrace from a span of StackFrames by making a copy + // of each stack frame. + explicit FrozenStackTrace(absl::Span frames, + absl::Span user_frames = {}); + + explicit FrozenStackTrace(std::vector&& frames) + : frames_(std::move(frames)), user_frames_({}) {} + + FrozenStackTrace(FrozenStackTrace&&) = default; + + // Constructs a FrozenStackTrace from serialized proto data. + FrozenStackTrace(const GraphDebugInfo::StackTrace& stack_trace, + const GraphDebugInfo& debug_info); + + ~FrozenStackTrace() override = default; + + absl::Span ToFrames() const override; + + std::vector ToUncachedFrames() const override; + + StackFrame LastUserFrame() const override; + + std::vector GetUserFrames(int limit) const override; + + std::string ToString(const TracePrintingOptions& opts) const override; + + private: + std::vector frames_; + std::vector user_frames_; +}; + +// Holder type to use `AbstractStackTrace` as a key. +struct StackTracePointer { + std::shared_ptr trace; + + template + friend H AbslHashValue(H h, const StackTracePointer& p) { + for (const auto& frame : p.trace->ToFrames()) { + h = H::combine(std::move(h), frame); + } + return h; + } + + bool operator==(const StackTracePointer& other) const { + absl::Span other_frames = other.trace->ToFrames(); + absl::Span frames = trace->ToFrames(); + return frames == other_frames; + } +}; + +using StackTracesMap = + absl::flat_hash_map>; + +// Load all stack traces from `debug_info`. +StackTracesMap LoadTracesFromDebugInfo(const GraphDebugInfo& debug_info); +absl::StatusOr LoadTracesFromDebugInfoStr( + absl::string_view debug_info_str); + +// Generates a GraphDebugInfo proto from a StackTracesMap object. Returns user +// frames by default. If `user_frames` is false, returns all frames. +GraphDebugInfo StackTracesMapToGraphDebugInfo(const StackTracesMap& map, + bool user_frames = true); + +// Builder for GraphDebugInfo protos from either an existing map of string keys +// to stack traces, or individual stack traces, or both. All stack traces in a +// GraphDebugInfo are stored with a string key in the `traces` field. In the +// case of an existing map, its keys are used, appended with a key suffix, +// which may be empty. If it is not empty, it is conventionally of the form +// "@function_name", although this class doesn't care. In the case of an +// individual stack trace, a key for `traces` must be provided. +// +// This builder will create a list of the unique file names across all stack +// traces and store it in the `files` field. When storing stack traces into the +// proto, file names are replaced by their index into `files`. +// +// Typical usage is to call one or both of the accumulate methods one or more +// times and then to call the Build(). +class GraphDebugInfoBuilder { + public: + struct Options { + // Call the AbstractTraceMap GetUserFrames method rather than ToFrames + bool user_frames; + // Value of `limit` to pass to GetUserFrames if `user_frames` is true, + // otherwise ignored + int user_frames_limit; + }; + + GraphDebugInfoBuilder(); + virtual ~GraphDebugInfoBuilder() = default; + + // Adds a map of stack traces to the GraphDebugInfo proto. For each key (node + // id) and stack traces entry in `stack_traces_map`, combine the key with + // `key_suffix` to form a new key and use that to add the stack traces to the + // `traces` field of the proto. 
If not empty, the suffix is typically of the + // form "@function_name", although this function doesn't care. + void AccumulateStackTracesMap(const StackTracesMap& stack_traces_map, + absl::string_view key_suffix = "", + const GraphDebugInfoBuilder::Options& options = + GraphDebugInfoBuilder::Options()); + + // Adds one stack trace to the GraphDebugInfo proto, using `traces_key` as the + // key for the `traces` field of the proto. + void AccumulateStackTrace(std::shared_ptr trace, + absl::string_view traces_key, + const GraphDebugInfoBuilder::Options& options = + GraphDebugInfoBuilder::Options()); + + void AppendGraphDebugInfo(absl::string_view prefix, + const GraphDebugInfo& new_info); + + // These string methods are used in the Python bindings to avoid symbol + // resolution errors with pybind on Windows. + absl::Status AppendGraphDebugInfoStr(absl::string_view prefix, + absl::string_view new_info_str); + + std::string ToGraphDebugInfoStr() const; + + // Returns the GraphDebugInfo proto. + GraphDebugInfo Build() const; + + private: + void AppendToStackTraceProto(const StackFrame& stack_frame, + GraphDebugInfo::StackTrace& stack_trace_proto); + + std::unique_ptr debug_info_; + absl::flat_hash_map file_name_to_index_; + + absl::flat_hash_map trace_to_index_; + absl::flat_hash_map frame_to_index_; + int new_name_index_ = 0; + + GraphDebugInfoBuilder(const GraphDebugInfoBuilder&) = delete; + void operator=(const GraphDebugInfoBuilder&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEBUG_INFO_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/graph_def_builder.h b/third_party/tflite-hdrs/tensorflow/core/graph/graph_def_builder.h new file mode 100644 index 00000000..b635ece0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/graph_def_builder.h @@ -0,0 +1,216 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ +#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ + +#include +#include + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Given a function like: +// namespace ops { +// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { +// if (opts.HaveError()) return nullptr; +// static const string kOpName = "Identity"; +// NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName, +// opts.op_registry()); +// node_builder.Input(input); +// return opts.FinalizeBuilder(&node_builder); +// } +// } // namespace ops +// +// // Or, alternatively: +// namespace ops { +// Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { +// static const string kOpName = "Identity"; +// return UnaryOp(kOpName, input, opts); +// } +// } // namespace ops +// +// You call it like: +// GraphDefBuilder b; +// using namespace ::tensorflow::ops; // NOLINT(build/namespaces) +// Node* na = Const(7, b.opts()); +// // Note: WithName() returns a copy, opts is unchanged. +// Node* nb = Const(5, b.opts().WithName("control-input")); +// Node* nc = Identity(na, b.opts().WithControlInput(nb)); +// GraphDef graph_def; +// Status status = b.ToGraphDef(&graph_def); +// if (!status.ok()) { /* Handle error */ } +// +// In tests you can skip the status handling via: +// GraphDefBuilder b(GraphDefBuilder::kFailImmediately); +// ... +// b.ToGraphDef(&graph_def); + +class GraphDefBuilder { + public: + // Options for adding a Node to a Graph. + class Options { + public: + // Sets the Graph (that Nodes will be added to) and the status. The + // status may be set to nullptr, in which case errors cause CHECK + // failures. The graph and status must outlive *this. + Options(Graph* graph, absl::Status* status); + ~Options(); + + // Methods for setting options. These are const methods: they + // return a copy of *this with the option set. + Options WithName(absl::string_view name) const; + Options WithDevice(absl::string_view device) const; + Options WithControlInput(Node* control_input) const; + Options WithControlInputs(absl::Span control_inputs) const; + + // Override the default value for an optional attr. + template + Options WithAttr(absl::string_view attr_name, T&& value) const { + return Options(*this).WithAttrImpl(attr_name, std::forward(value)); + } + // Note: overload needed to allow {...} expressions for value. + template + Options WithAttr(absl::string_view attr_name, + std::initializer_list value) const { + return WithAttr>(attr_name, std::move(value)); + } + + // Methods for using options from a function that creates a Node. + + // Returns true if the status associated with *this has an error. + // Use this to skip processing that may depend on prior results. + bool HaveError() const { return status_ != nullptr && !status_->ok(); } + + // Returns a string representation of the status associated with *this. + // Returns the string `"OK"` if the status doesn't have any error. + string StatusToString() const { + return status_->ok() ? 
"OK" : std::string(status_->message()); + } + + // Given the Op type name, return a name for a node of that type. + // Uses the value set in WithName() if that has been called. Otherwise, + // returns a name built out of the Op type name. + string GetNameForOp(absl::string_view op) const; + + // Sets the device, adds control inputs, adds attrs, and calls Finalize(). + // If Finalize returns an error, it is saved and this function returns + // nullptr. + Node* FinalizeBuilder(NodeBuilder* builder) const; + + // Updates the associated status, if any, or calls TF_CHECK_OK if none. + void UpdateStatus(const absl::Status& status) const; + + // Accessor + const OpRegistryInterface* op_registry() const { + return graph_->op_registry(); + } + + private: + Options WithNameImpl(absl::string_view name); + Options WithDeviceImpl(absl::string_view device); + Options WithControlInputImpl(Node* control_input); + Options WithControlInputsImpl(absl::Span control_inputs); + template + Options WithAttrImpl(absl::string_view name, T&& value) { + attrs_.emplace_back(string(name), AttrValue()); + SetAttrValue(std::forward(value), &attrs_.back().second); + return *this; + } + + Graph* const graph_; + absl::Status* const status_; + string name_; + string device_; + std::vector control_inputs_; + std::vector> attrs_; + }; + + // Start building a new graph. + explicit GraphDefBuilder( + const OpRegistryInterface* op_registry = OpRegistry::Global()) + : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, &status_) {} + + // For use in tests, where you want to fail immediately on error instead + // of checking the status at the end. + enum TestFailImmediatelyType { kFailImmediately }; + explicit GraphDefBuilder( + TestFailImmediatelyType, + const OpRegistryInterface* op_registry = OpRegistry::Global()) + : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, nullptr) {} + + // Gets the Options with the associated Graph and Status. + const Options& opts() const { return opts_; } + + // Once all the nodes have been added, call this to get whether it was + // successful, and if so fill *graph_def. + absl::Status ToGraphDef(GraphDef* graph_def) const; + + // Adds the function and gradient definitions in `fdef_lib` to this graph's op + // registry. Ignores duplicate functions, and returns a bad status if an + // imported function differs from an existing function or op with the same + // name. + absl::Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { + return flib_def_.AddLibrary(fdef_lib); + } + + // Returns whether a user-defined function with `name` already exists in the + // graph. + bool HasFunction(const string& name) { + return flib_def_.Find(name) != nullptr; + } + + private: + Graph graph_; + FunctionLibraryDefinition flib_def_; + absl::Status status_; + Options opts_; +}; + +namespace ops { + +// A NodeOut may either be a regular input or back input. Regular +// inputs are specified via either a Node* or a Node* and an output +// index. Back inputs are specified by a node name, output index, and +// output type. +typedef NodeBuilder::NodeOut NodeOut; + +// For adding an Op with no inputs to a GraphDefBuilder. +Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts); + +// For adding an Op with one input to a GraphDefBuilder. +Node* UnaryOp(const string& op_name, NodeOut input, + const GraphDefBuilder::Options& opts); + +// For adding an Op with two inputs to a GraphDefBuilder. 
+Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, + const GraphDefBuilder::Options& opts); + +// For adding an Op with three inputs to a GraphDefBuilder. +Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c, + const GraphDefBuilder::Options& opts); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/graph_node_util.h b/third_party/tflite-hdrs/tensorflow/core/graph/graph_node_util.h new file mode 100644 index 00000000..146c4c07 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/graph_node_util.h @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_NODE_UTIL_H_ +#define TENSORFLOW_CORE_GRAPH_GRAPH_NODE_UTIL_H_ + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +class Node; +struct NodeDebugInfo; + +// We forward declare protos so that kernels don't need to depend on them +class NodeDef; +class OpDef; + +// Produce a human-readable version of a Node or NodeDef that is more concise +// than a text-format proto. +string SummarizeNode(const Node& node); + +// Produces a formatted string pattern from the node which can uniquely identify +// this node upstream to produce an informative error message. The pattern +// followed is: {{node }} +string FormatNodeForError(const Node& node); + +// Merges the original node names from the debug information of 'from' to the +// debug information of 'to'. +void MergeDebugInfo(const NodeDebugInfo& from, Node* to); +void MergeDebugInfo(const NodeDebugInfo& from, NodeDef* to); +void MergeDebugInfo(const NodeDef& from, NodeDef* to); + +// Computes the mapping from input/output argument name to the +// corresponding input/output index range. For example, +// input "foo" corresponds to input indices +// [ (*inputs)["foo"].first, (*inputs)["foo"].second ). +// NOTE(mrry): To reduce allocations when the map is used and save +// space, the returned `NameRangeMap` objects borrow the input/output +// argument names from `op_def`. The `op_def` must outlive the +// returned `NameRangeMap` objects. +absl::Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); + +// Returns "status" with formatted Node attached as additional text +// in the error message. If 'allow_multiple_formatted_node' is false and there +// is already a formatted Node present in 'status', we simply attach the name +// of the Node instead of the formatted string. 
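// Hedged illustration (an added sketch, not upstream documentation; the node
// and error text are placeholders): FormatNodeForError() wraps the node name
// in the "{{node <name>}}" pattern described above, and AttachDef() appends
// that pattern to an existing status so tooling can trace the error back to
// the offending node:
//
//   Node* matmul = ...;  // a node named "my_matmul"
//   absl::Status s = errors::InvalidArgument("incompatible shapes");
//   s = AttachDef(s, *matmul);
//   // s.message() now also references {{node my_matmul}}.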
+absl::Status AttachDef(const absl::Status& status, const Node& node, + bool allow_multiple_formatted_node = false); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_GRAPH_NODE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/graph_partition.h b/third_party/tflite-hdrs/tensorflow/core/graph/graph_partition.h new file mode 100644 index 00000000..59e9fe0e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/graph_partition.h @@ -0,0 +1,109 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_ +#define TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/costmodel.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +struct PartitionOptions { + // A function that returns a location for the execution of a given + // Node. + typedef std::function NodeToLocFunc; + NodeToLocFunc node_to_loc = nullptr; + + // A function that returns a unique graph node name with the given + // prefix. + typedef std::function NewNameFunc; + NewNameFunc new_name = nullptr; + + // A function that returns the incarnation of a device given the + // device's fullname. If not found, GetIncarnationFunc should return + // kIllegalIncarnation. + static constexpr uint64 kIllegalIncarnation = 0; + typedef std::function GetIncarnationFunc; + GetIncarnationFunc get_incarnation = nullptr; + + // If specified, flib_def defines a function library that should be + // partitioned and replicated into each resulting partition graphs. + const FunctionLibraryDefinition* flib_def = nullptr; + + // True if all the control flow "code" has already been added. The + // control flow code needs to be added when we still have the entire + // graph before any partitioning. So this flag should be false for + // the first partitioning but true for all subsequent partitioning. + // + // TODO(yuanbyu): We could also make the addition of the control + // flow code incremental based on 'node_to_loc'. This makes the + // communication a broadcast tree, which could be more efficient when + // the number of participating devices is large. + bool control_flow_added = false; + + // A function that returns the data type into which the tensor + // should be cast before sent over the wire. + typedef std::function ShouldCastFunc; + ShouldCastFunc should_cast = nullptr; + + // Schedule the execution of the recvs based on their start times + // computed by some scheduling algorithm. The recvs are divided into + // epochs based on their start times. A recv is enabled only when + // execution reaches its epoch - N for some predefined N. + bool scheduling_for_recvs = false; + // The start time for each node in the graph computed by some scheduling + // algorithm. 
If 'need_to_record_start_times' is true, we record them + // in the graph as a node attribute. + bool need_to_record_start_times = false; + std::vector start_times; + + // Optional customized function to compute the "tensor_name" attr value of + // Send/Recv ops inserted during partitioning. + std::function get_tensor_name_attr = nullptr; + + // If true, the `Partition()` function can make destructive changes to the + // passed-in `Graph`. + // + // TODO(b/327983931): Add wrapper functions for partitioning that clearly + // signal this intent by taking a `Graph` or `Graph&&`. + bool can_make_destructive_changes = false; +}; + +// Partition "input" graph into a set of graphs, one per location. +// The location for node n is derived by calling opts.node_to_loc(n). +// New nodes added by Partition use "opts.new_name(old_name)" to +// generate node names. +// +// Stores the partitions in *partitions. +absl::Status Partition(const PartitionOptions& opts, Graph* input, + std::unordered_map* partitions); + +// Add control edges to the partitions to control the ordering +// and timing of the recv nodes based on the start times calculated +// using some scheduling algorithm. +absl::Status AddControlEdges(const PartitionOptions& opts, + std::unordered_map* partitions); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/mkl_graph_util.h b/third_party/tflite-hdrs/tensorflow/core/graph/mkl_graph_util.h new file mode 100644 index 00000000..00e2e74b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/mkl_graph_util.h @@ -0,0 +1,284 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_ +#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_ +#ifdef INTEL_MKL + +#include "absl/base/call_once.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { +// Since our ops are going to produce and also consume N addition tensors +// (Mkl) for N Tensorflow tensors, we can have following different +// orderings among these 2N tensors. +// +// E.g., for Tensorflow tensors A, B, and C, our ops will produce and +// consume A_m, B_m, and C_m additionally. +// +// INTERLEAVED: in this case 2N tensors are interleaved. So for above +// example, the ordering looks like: A, A_m, B, B_m, C, C_m. +// +// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed +// by N Mkl tensors. So for above example, the ordering looks +// like: A, B, C, A_m, B_m, C_m +// +// Following APIs map index of original Tensorflow tensors to their +// appropriate position based on selected ordering. 
For contiguous ordering, +// we need to know the total number of tensors (parameter total). +// +typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering; +// NOTE: Currently, we use contiguous ordering. If you change this, then you +// would need to change Mkl op definitions in nn_ops.cc. +static const MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS; + +// Get index of MetaData tensor from index 'n' of Data tensor. +inline int DataIndexToMetaDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + // For interleaved ordering, Mkl tensor follows immediately after + // Tensorflow tensor. + return n + 1; + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away. + return n + total_tensors / 2; + } +} + +int inline GetTensorDataIndex(int n, int total_tensors) { + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + return 2 * n; // index corresponding to nth input/output tensor + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + return n; + } +} + +int inline GetTensorMetaDataIndex(int n, int total_tensors) { + // Get index for TensorData first and then use mapping function + // to get TensorMetaData index from TensorData index. + int tidx = GetTensorDataIndex(n, total_tensors); + return DataIndexToMetaDataIndex(tidx, total_tensors); +} + +// check if the control between src and dst nodes already exists +bool inline DoesControlEdgeExist(const Node* src, const Node* dst) { + for (const Edge* edge : src->out_edges()) { + if (edge->IsControlEdge() && edge->dst() == dst) { + return true; + } + } + return false; +} + +// In TF 2.8, oneDNN blocked format will not be supported. +// TODO(intel_tf): Cleanup shall be done in future: +// (1) Remove this method; +// (2) Update related code wherever it is called. +bool inline NativeFormatEnabled() { return true; } + +// Check if the data_format attribute in the node def represents 5D tensor +bool inline Check5DFormat(const NodeDef& ndef) { + string data_format; + TF_CHECK_OK(GetNodeAttr(ndef, "data_format", &data_format)); + if (data_format.compare("NCDHW") == 0 || data_format.compare("NDHWC") == 0) { + return true; + } + return false; +} + +namespace mkl_op_registry { +// MKL operators whose kernels are registered with 'MklLayoutDependentOp' label +// (e.g., MklConv2D) understand input tensors in MKL layout. These operators +// get additional meta-tensors for actual input tensors. +static const char* kMklLayoutDependentOpLabel = "MklLayoutDependentOp"; +static const char* kMklLayoutDependentOpLabelPattern = + "label='MklLayoutDependentOp'"; +// MKL operators whose kernels are registered with 'MklNameChangeOp' label +// (e.g., MklMatMul, MklTranspose) do not understand input tensors in MKL +// layout. These operators do not get additional meta-tensors. The signatures of +// these operators are the same as the original TensorFlow operators that they +// correspond to. So these ops just go through a name change during graph +// rewrite pass. +static const char* kMklNameChangeOpLabel = "MklNameChangeOp"; +static const char* kMklNameChangeOpLabelPattern = "label='MklNameChangeOp'"; +static const char* kMklQuantizedOpLabel = "QuantizedMklOp"; +static const char* kMklQuantizedOpLabelPattern = "label='QuantizedMklOp'"; + +// Prefix that we add to Tensorflow op name to construct Mkl op name. 
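// Illustrative sketch (assumed usage; kernel and variable names are
// placeholders): with the contiguous ordering selected above and three
// TensorFlow tensors A, B, C, a kernel sees six inputs laid out as
// A, B, C, A_m, B_m, C_m. The index helpers above then map tensor 1 (B) to
// data index 1 and metadata index 1 + 6 / 2 == 4:
//
//   const int total = 6;  // 3 TF tensors + 3 Mkl metadata tensors
//   const Tensor& b_data = ctx->input(GetTensorDataIndex(1, total));      // B
//   const Tensor& b_meta = ctx->input(GetTensorMetaDataIndex(1, total));  // B_m
//
// where `ctx` is the OpKernelContext of a hypothetical Mkl kernel.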
+static const char* const kMklOpPrefix = "_Mkl"; +// TODO(intel-tf): PR review feedback (penpornk) +// Can we add eager_mode (or is_eager) as an op attribute instead? +// This way we don't need to rename the op just to pass eager_mode +// through template parameter. +static const char* const kMklEagerOpPrefix = "_MklEager"; + +// Prefix that we add to TF op name to construct MKL op that does not +// depend on layout propagation. It will be used in both Eager and graph +// modes unless there is a reason to have additional op name with +// _MklEager prefix. +static const char* const kMklNativeOpPrefix = "_MklNative"; + +// Get the name of Mkl Native (does not depend on layout propagation) op +// from original TensorFlow op. +inline string GetMklNativeOpName(const string& name) { + // There are few operators that don't depend on layout propagation but are + // prefixed with _Mkl instead of _MklNative. + bool result = + (0 == name.compare("ConjugateTranspose") || + 0 == name.compare("SparseTensorDenseMatMul") || + 0 == name.compare("BatchMatMul") || 0 == name.compare("BatchMatMulV2") || + 0 == name.compare("Einsum") || 0 == name.compare("MatMul") || + 0 == name.compare("Transpose") || 0 == name.compare("QuantizeV2") || + 0 == name.compare("Dequantize") || 0 == name.compare("Softmax") || + 0 == name.rfind("Quantized", 0)); + + if (result) { + return string(kMklOpPrefix) + name; + } else { + return string(kMklNativeOpPrefix) + name; + } +} + +// Get the name of Mkl op from original TensorFlow op +// We prefix the original op with _Mkl or _MklNative to get Mkl op. +inline string GetMklOpName(const string& name) { + if (!NativeFormatEnabled()) { + return string(kMklOpPrefix) + name; + } else { + return GetMklNativeOpName(name); + } +} + +// Get the name of Mkl Eager op from original TensorFlow op +// We prefix 'MklEager' to the original op to get Mkl Eager op. +inline string GetMklEagerOpName(const string& name) { + return string(kMklEagerOpPrefix) + name; +} + +// Check whether opname with type T is registered as MKL operator +// that will go through name change or layout change pass. +// +// @input: name of the op +// @input: T datatype to be used for checking op +// @return: true if opname is registered as MKL op that will go through name +// change or layout change pass; false otherwise +static inline bool IsMklOp(const string& op_name, DataType T, + bool is_native_op) { + string label = is_native_op ? 
kMklNameChangeOpLabelPattern + : kMklLayoutDependentOpLabelPattern; + string registered_kernels_key = op_name + label + std::to_string(T); + thread_local static auto registered_kernels_map = + std::make_unique>(); + auto kernel_element = registered_kernels_map->find(registered_kernels_key); + bool kernel_registered = false; + + if (kernel_element == registered_kernels_map->end()) { + string registered_kernels = KernelsRegisteredForOp(op_name); + // String returned by KernelsRegisteredForOp looks like below: + // + // Op = _MklMatMul, kernels = + // device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX128] + // device='CPU'; label='MklNameChangeOp'; T in [DT_COMPLEX64] + // device='CPU'; label='MklNameChangeOp'; T in [DT_DOUBLE] + // device='CPU'; label='MklNameChangeOp'; T in [DT_FLOAT] + + if (is_native_op && + registered_kernels.find(kMklQuantizedOpLabelPattern) != string::npos) { + // Restrict quantized ops to QUINT8, QINT8 and DT_QINT32 + kernel_registered = (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32); + } + + // Now we just construct a search string to match what we are looking for. + string search_string = + label + string("; T in [") + DataType_Name(T) + string("]"); + + if (registered_kernels.find(search_string) != string::npos) { + kernel_registered = is_native_op + ? (T == DT_COMPLEX128 || T == DT_COMPLEX64 || + T == DT_DOUBLE || T == DT_FLOAT) + : T == DT_FLOAT; + if (!kernel_registered) { + if ((T == DT_BFLOAT16 || T == DT_HALF) && + IsDataTypeSupportedByOneDNNOnThisCPU(T)) { + kernel_registered = true; + } else { + DataTypeUnsupportedWarning(T); + } + } + } + registered_kernels_map->insert( + std::make_pair(registered_kernels_key, kernel_registered)); + } else { + // Kernel is visited at least once. Return stored registration result. + kernel_registered = kernel_element->second; + } + return kernel_registered; +} + +// TODO(intel-tf): QuantizedConv2D is registered with input: QUINT8 +// filter:QINT8 for oneDNN integration. First a dummy kernel is created +// and then it is replaced by an actual kernel. +static inline bool IsMklQuantizedOp(const string& op_name, DataType Tinput, + DataType Tfilter) { + // Restrict quantized ops to QUINT8 and QINT8 for now + if (IsMklOp(op_name, Tinput, kMklQuantizedOpLabelPattern)) { + return (Tfilter == DT_QINT8); + } + return false; +} + +// Check if the operator with 'op_name' and type 'T' is an MKL operator that +// will either understand input tensors in MKL layout or will go through name +// rewrite that some operators go through. +static inline bool IsMklOp(const string& op_name, DataType T) { + return IsMklOp(op_name, T, true) || IsMklOp(op_name, T, false); +} + +static inline bool IsMklOp(const Node* n) { + DataType T; + return GetNodeAttr(n->def(), "T", &T).ok() && IsMklOp(n->type_string(), T); +} + +// Check whether opname with type T is registered as MKL-compliant and +// is element-wise. 
+// +// @input: name of the op +// @input: T datatype to be used for checking op +// @return: true if opname is registered as element-wise Mkl op; +// false otherwise +static inline bool IsMklElementWiseOp(const string& op_name, DataType T) { + if (!IsMklOp(op_name, T)) { + return false; + } + bool result = (0 == op_name.compare(GetMklOpName("Add")) || + 0 == op_name.compare(GetMklOpName("AddV2")) || + 0 == op_name.compare(GetMklOpName("Sub")) || + 0 == op_name.compare(GetMklOpName("Mul")) || + 0 == op_name.compare(GetMklOpName("Maximum")) || + 0 == op_name.compare(GetMklOpName("Sigmoid")) || + 0 == op_name.compare(GetMklOpName("SquaredDifference"))); + + return result; +} +} // namespace mkl_op_registry +} // namespace tensorflow +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/mkl_testlib.h b/third_party/tflite-hdrs/tensorflow/core/graph/mkl_testlib.h new file mode 100644 index 00000000..3dffded1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/mkl_testlib.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_MKL_TESTLIB_H_ +#define TENSORFLOW_CORE_GRAPH_MKL_TESTLIB_H_ + +#ifdef INTEL_MKL + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace test { +namespace graph { + +Node* oneDNNSoftmax(Graph* g, Node* input); + +#ifdef ENABLE_ONEDNN_V3 +Node* oneDNNSparseCSRMatmul(Graph* g, Node* csr_matrix_t, Node* b); +#endif // ENABLE_ONEDNN_V3 + +} // namespace graph +} // namespace test +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_GRAPH_MKL_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/node_builder.h b/third_party/tflite-hdrs/tensorflow/core/graph/node_builder.h new file mode 100644 index 00000000..6f249371 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/node_builder.h @@ -0,0 +1,181 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_ +#define TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_ + +#include +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// This is a helper for creating a Node and adding it to a Graph. +// Internally, it uses a NodeDefBuilder to automatically set attrs +// that can be inferred from the inputs, and use default values +// (where they exist) for unspecified attrs. Example usage: +// +// Node* node; +// Status status = NodeBuilder(node_name, op_name) +// .Input(...) +// .Attr(...) +// .Finalize(&graph, &node); +// if (!status.ok()) return status; +// // Use node here. +class NodeBuilder { + public: + // For specifying the output of a Node to provide to one of the Input() + // functions below. It supports both regular inputs (where you are + // connecting to an existing Node*), and inputs from outside the graph + // (or haven't been added to the graph yet, like back edges, where + // you don't have a Node*). Both types can be mixed, e.g. in an + // ArraySlice. + struct NodeOut { + // For referencing an existing Node. + NodeOut(Node* n, int32_t i = 0); + NodeOut(OutputTensor t); + + // For referencing Nodes not in the graph being built. It is + // useful when preparing a graph for ExtendSession or creating a + // back edge to a node that hasn't been added to the graph yet, + // but will be. + NodeOut(absl::string_view name, int32_t i, DataType t); + + // Default constructor for std::vector. + NodeOut(); + + Node* node; + // error is set to true if: + // * the NodeOut was default constructed and never overwritten, + // * a nullptr Node* was passed to the NodeOut constructor, or + // * an out-of-range index was passed to the NodeOut constructor. + bool error; + string name; + int32 index; + DataType dt; + }; + + // Specify the name and the Op (either via an OpDef or the name of + // the Op plus a registry) for the Node. Other fields are + // specified by calling the methods below. + // REQUIRES: The OpDef must satisfy ValidateOpDef(). + NodeBuilder(absl::string_view name, absl::string_view op_name, + const OpRegistryInterface* op_registry = OpRegistry::Global(), + const NodeDebugInfo* debug = nullptr); + NodeBuilder(absl::string_view name, const OpDef* op_def); + + // Create a NodeBuilder from an existing NodeDefBuilder. + NodeBuilder(const NodeDefBuilder& def_builder); + + // You must call one Input() function per input_arg in the Op, + // *and in the same order as the input_args appear in the OpDef.* + + // For inputs that take a single tensor. + NodeBuilder& Input(Node* src_node, int src_index = 0); + NodeBuilder& Input(NodeOut src); + + // For inputs that take a list of tensors. + NodeBuilder& Input(absl::Span src_list); + + // Require that this node run after src_node(s). + NodeBuilder& ControlInput(Node* src_node); + NodeBuilder& ControlInputs(absl::Span src_nodes); + + // Sets the "requested device spec" in the NodeDef (not the + // "assigned device" in the Node). + NodeBuilder& Device(absl::string_view device_spec); + + // Sets the device name in the "assigned device" field in tensorflow::Node. 
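  // Hedged sketch (assumed usage; node and op names are placeholders): the
  // (name, index, type) form of NodeOut can reference a node that has not
  // been added to the graph yet, e.g. the back edge of a while loop:
  //
  //   std::vector<NodeBuilder::NodeOut> merge_inputs;
  //   merge_inputs.emplace_back(enter_node, 0);                       // existing node
  //   merge_inputs.emplace_back("while/NextIteration", 0, DT_FLOAT);  // back edge
  //   Node* merge = nullptr;
  //   absl::Status s = NodeBuilder("while/Merge", "Merge")
  //                        .Input(merge_inputs)
  //                        .Finalize(graph, &merge);
  //
  // Finalize() records the back edge by name and adds graph edges only for
  // the regular inputs.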
+ NodeBuilder& AssignedDevice(absl::string_view device); + + // Sets the _XlaCluster attribute in created node to `xla_cluster`. + NodeBuilder& XlaCluster(absl::string_view xla_cluster); + + // Set the value of an attr. attr_name must match the name of one of + // attrs defined by the Op, and value must have the corresponding type + // (see SetAttrValue() in ../framework/attr_value_util.h for legal + // types for value). Note that attrs will be set automatically if + // they can be determined by the inputs. + template + NodeBuilder& Attr(absl::string_view attr_name, T&& value); + template + NodeBuilder& Attr(absl::string_view attr_name, + std::initializer_list value); + + // Validates the described node and adds it to *graph, adding edges + // for all (non-back) inputs. If created_node is not nullptr, + // *created_node will be set to the new node (or nullptr on error). + // If `consume` is true, the builder state will be moved into `node_def`, + // and the builder will be left in an undefined state. + absl::Status Finalize(Graph* graph, Node** created_node, + bool consume = false); + + // Same as `Finalize` above, but using StatusOr to return value. Preferred + // form. + absl::StatusOr Finalize(Graph* graph, bool consume = false); + + // Accessors for the values set in the constructor. + const string& node_name() const { return def_builder_.node_name(); } + const OpDef& op_def() const { return def_builder_.op_def(); } + + private: + static DataType SafeGetOutput(const Node* node, int i, bool* error) { + if (node != nullptr && i >= 0 && i < node->num_outputs()) { + *error = false; + return node->output_type(i); + } else { + *error = true; + return DT_FLOAT; + } + } + + // If SafeGetOutput indicates a range error, add it to errors_. + void AddIndexError(const Node* node, int i); + + // Set *dt and returns true if i is in range. Combines + // SafeGetOutput() and AddIndexError(). + bool GetOutputType(const Node* node, int i, DataType* dt); + + NodeDefBuilder def_builder_; + const OpRegistryInterface* op_registry_; + std::vector inputs_; + std::vector control_inputs_; + std::vector errors_; + string assigned_device_; +}; + +// IMPLEMENTATION ------------------------------------------------------------- + +template +NodeBuilder& NodeBuilder::Attr(absl::string_view attr_name, T&& value) { + def_builder_.Attr(attr_name, std::forward(value)); + return *this; +} + +template +NodeBuilder& NodeBuilder::Attr(absl::string_view attr_name, + std::initializer_list value) { + def_builder_.Attr(attr_name, value); + return *this; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/optimizer_cse.h b/third_party/tflite-hdrs/tensorflow/core/graph/optimizer_cse.h new file mode 100644 index 00000000..ef466fb7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/optimizer_cse.h @@ -0,0 +1,37 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// An optimization pass that performs common subexpression elimination. + +#ifndef TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_ +#define TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_ + +#include +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Perform common-subexpression elimination on the graph "*g". If +// "consider_fn" is not nullptr, then only nodes for which +// consider_fn(node) returns true will be considered for combining +// during the common subexpression elimination. +// +// Returns true if and only if 'g' is mutated. +extern bool OptimizeCSE(Graph* g, + const std::function& consider_fn); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/regularization/simple_delete.h b/third_party/tflite-hdrs/tensorflow/core/graph/regularization/simple_delete.h new file mode 100644 index 00000000..07ebd00e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/regularization/simple_delete.h @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_REGULARIZATION_SIMPLE_DELETE_H_ +#define TENSORFLOW_CORE_GRAPH_REGULARIZATION_SIMPLE_DELETE_H_ + +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow::graph_regularization { + +// Regularizes the graph_def by deleting non-deterministic sections. +void SimpleDelete(GraphDef& graph_def); + +} // namespace tensorflow::graph_regularization + +#endif // TENSORFLOW_CORE_GRAPH_REGULARIZATION_SIMPLE_DELETE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/regularization/util.h b/third_party/tflite-hdrs/tensorflow/core/graph/regularization/util.h new file mode 100644 index 00000000..2fff6452 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/regularization/util.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_REGULARIZATION_UTIL_H_ +#define TENSORFLOW_CORE_GRAPH_REGULARIZATION_UTIL_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow::graph_regularization { + +// Computes the Fingerprint64 hash of the GraphDef. +uint64 ComputeHash(const GraphDef& graph_def); + +// Returns the suffix UID of `function_name`, returns an error if there is none. +absl::StatusOr GetSuffixUID(absl::string_view function_name); + +} // namespace tensorflow::graph_regularization + +#endif // TENSORFLOW_CORE_GRAPH_REGULARIZATION_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/subgraph.h b/third_party/tflite-hdrs/tensorflow/core/graph/subgraph.h new file mode 100644 index 00000000..37013b8f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/subgraph.h @@ -0,0 +1,165 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ +#define TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ + +#include + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { +namespace subgraph { + +// Information about a graph rewritten by `RewriteGraphForExecution()`. +struct RewriteGraphMetadata { + // The element type of each tensor fed to this subgraph. The order + // of types corresponds to the order of tensor names in + // `fed_outputs` when calling `RewriteGraphForExecution()`. + DataTypeVector feed_types; + // The element type of each tensor fetched from this subgraph. The + // order of types corresponds to the order of tensor names in + // `fetch_outputs` when calling `RewriteGraphForExecution()`. + DataTypeVector fetch_types; +}; + +// Describes the action to take on a particular tensor endpoint (described by +// a ":" pair) when pruning the graph. +// +// The `AddNode()` method must be overridden to describe this action. The method +// will be invoked once during `RewriteGraphForExecution()` with tensor endpoint +// named by `endpoint_name`, and it may either create a single new node, or fail +// with an error if the resulting graph would be invalid. +class PruneRewrite { + public: + // `endpoint_name` and `device_info` must outlive this object. 
+ PruneRewrite(const string* endpoint_name, const DeviceAttributes* device_info) + : endpoint_name_(endpoint_name), device_info_(device_info) {} + virtual ~PruneRewrite() {} + + // Creates a new node whose output replaces the given `tensor` in graph `g`. + // The node will be assigned to the device named in `device_info`. + virtual absl::Status AddNode(Graph* g, NodeBuilder::NodeOut tensor, + Node** out_node) = 0; + + // Returns the name of the tensor to which this rewrite applies. + const string& endpoint_name() { return *endpoint_name_; } + + protected: + // The device on which the new node will be created. + const DeviceAttributes& device_info() { return *device_info_; } + + private: + const string* const endpoint_name_; // Not owned. + const DeviceAttributes* const device_info_; // Not owned. +}; + +// Rewrite the graph structure of "*g" to deal with feeding node +// outputs, fetching node outputs, and only running a subset of the +// graph. "fed_outputs" and "fetch_outputs" are both lists of +// output tensor identifiers in the form of +// "[:]", and "target_nodes_str" is a +// lists of target node names in "*g" "g". +// +// In the resulting graph "*g", output edges in "fed_outputs" have +// been redirected to special "_recv" nodes introduced into the graph. +// If these fed nodes are not needed in order to compute the effects +// of the nodes in "target_node_names" and "fetch_outputs", then these may +// be omitted from the graph. +// +// In the resulting graph "*g", additional "_send" nodes are connected +// to every output in "fetch_outputs". These "_send" nodes are set up +// to execute on the device described by device_info. +// +// On success, returns OK, and sets "*g" to a version of "*g" +// that represents the portions of the graph necessary for producing +// the output of all nodes listed in "target_node_names" and fetching the +// specific node outputs specified in "fetch_outputs". +// +// On failure, returns the error status. Possible errors include: +// - fed output "node:output_index" does not exist in "*g" +// - fetch output "node:output_index" does not exist in "*g" +// - target node "node" does not exist in "*g" +absl::Status RewriteGraphForExecution( + Graph* g, const absl::Span& fed_outputs, + const absl::Span& fetch_outputs, + const absl::Span& target_node_names, + const DeviceAttributes& device_info, bool use_function_convention, + RewriteGraphMetadata* out_metadata); + +// A more general version of the above function that supports +// customizable rewriting actions for each fed and fetched tensor. +absl::Status RewriteGraphForExecution( + Graph* g, const std::vector>& feed_rewrites, + const std::vector>& fetch_rewrites, + const absl::Span& target_node_names, + RewriteGraphMetadata* out_metadata); + +///////////////////////////////////////////////////////// +// Custom rewrite actions for fed and fetched tensors. // +///////////////////////////////////////////////////////// + +// A rewrite action that adds an _Arg node for a fed tensor. +class ArgFeedRewrite : public PruneRewrite { + public: + ArgFeedRewrite(const string* endpoint_name, + const DeviceAttributes* device_info, int32_t arg_index) + : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {} + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, + Node** out_node) override; + + private: + const int32 arg_index_; +}; + +// A rewrite action that adds a client-terminated _Recv node for a fed tensor. 
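// Hedged sketch (assumed usage; tensor names and the DeviceAttributes value
// are placeholders): pruning a graph so that only the nodes needed to feed
// "input:0" and fetch "logits:0" remain:
//
//   std::vector<string> feeds = {"input:0"};
//   std::vector<string> fetches = {"logits:0"};
//   std::vector<string> targets;  // no extra target nodes
//   RewriteGraphMetadata metadata;
//   absl::Status s = RewriteGraphForExecution(
//       &graph, feeds, fetches, targets, device_attributes,
//       /*use_function_convention=*/false, &metadata);
//   // On success, metadata.feed_types and metadata.fetch_types hold the
//   // element types of the fed and fetched tensors, in the order given.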
+class RecvFeedRewrite : public PruneRewrite { + public: + using PruneRewrite::PruneRewrite; + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, + Node** out_node) override; +}; + +// A rewrite action that adds a _Retval node for a fetched tensor. +class RetvalFetchRewrite : public PruneRewrite { + public: + RetvalFetchRewrite(const string* endpoint_name, + const DeviceAttributes* device_info, int32_t retval_index) + : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {} + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, + Node** out_node) override; + + private: + const int32 retval_index_; +}; + +// A rewrite action that adds a client-terminated _Send node for a +// fetched tensor. +class SendFetchRewrite : public PruneRewrite { + public: + using PruneRewrite::PruneRewrite; + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, + Node** out_node) override; +}; + +} // namespace subgraph +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/tensor_id.h b/third_party/tflite-hdrs/tensorflow/core/graph/tensor_id.h new file mode 100644 index 00000000..0cdfb7d9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/tensor_id.h @@ -0,0 +1,94 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_TENSOR_ID_H_ +#define TENSORFLOW_CORE_GRAPH_TENSOR_ID_H_ + +#include + +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { + +struct SafeTensorId; + +// Identifier for a tensor within a step. +// first == operation_name, second == output_index +// Note: does not own backing storage for name. +struct TensorId : public std::pair { + typedef std::pair Base; + + // Inherit the set of constructors. + using Base::pair; + + // NOTE(skyewm): this is required on some platforms. I'm not sure why the + // using statement above isn't always sufficient. + TensorId() : Base() {} + TensorId(const SafeTensorId& id); + + const absl::string_view node() const { return first; } + int index() const { return second; } + + string ToString() const { + if (second == Graph::kControlSlot) return strings::StrCat("^", first); + return strings::StrCat(first, ":", second); + } + + struct Hasher { + public: + std::size_t operator()(const TensorId& x) const { + return Hash32(x.first.data(), x.first.size(), x.second); + } + }; +}; + +TensorId ParseTensorName(const string& name); +TensorId ParseTensorName(absl::string_view name); + +bool IsTensorIdControl(const TensorId& tensor_id); + +// Same as TensorId, except owns the backing storage for the op name. This makes +// the memory management simpler at the expense of a copy. 
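// Illustrative behavior (assumed, with placeholder names):
//
//   string name = "foo:2";
//   TensorId id = ParseTensorName(name);
//   // id.node() == "foo", id.index() == 2, id.ToString() == "foo:2"
//
//   string ctrl_name = "^foo";
//   TensorId ctrl = ParseTensorName(ctrl_name);
//   // ctrl.index() == Graph::kControlSlot; IsTensorIdControl(ctrl) is true.
//
// Because TensorId does not own its name, the string passed to
// ParseTensorName() must outlive the returned id; SafeTensorId below copies
// the name instead.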
+struct SafeTensorId : public std::pair { + typedef std::pair Base; + + // NOTE(skyewm): this is required on some platforms. I'm not sure why the + // using "using Base::pair;" isn't always sufficient. + SafeTensorId() : Base() {} + SafeTensorId(const string& str, int idx) : Base(str, idx) {} + SafeTensorId(const TensorId& id); + + const string& node() const { return first; } + int index() const { return second; } + + string ToString() const { + if (second == Graph::kControlSlot) return strings::StrCat("^", first); + return strings::StrCat(first, ":", second); + } + + struct Hasher { + public: + std::size_t operator()(const TensorId& x) const { + return Hash32(x.first.data(), x.first.size(), x.second); + } + }; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_TENSOR_ID_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/testlib.h b/third_party/tflite-hdrs/tensorflow/core/graph/testlib.h new file mode 100644 index 00000000..b2d1a416 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/testlib.h @@ -0,0 +1,230 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// DEPRECATED: Use the C++ API defined in tensorflow/cc instead. + +#ifndef TENSORFLOW_CORE_GRAPH_TESTLIB_H_ +#define TENSORFLOW_CORE_GRAPH_TESTLIB_H_ + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace test { +namespace graph { + +// Converts "g" into its corresponding GraphDef "def". +ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.") +void ToGraphDef(Graph* g, GraphDef* def); + +// A few helpers to construct a graph. + +// Adds a node in "g" producing a constant "tensor". +Node* Constant(Graph* g, const Tensor& tensor); +Node* Constant(Graph* g, const Tensor& tensor, const string& name); + +// Adds a node in "g" producing a constant "tensor" on the host. +// The given node which, unlike the regular Constant above, always +// stores its output on the host. This is necessary for use +// in GPU tests where the test Op in question runs on the device +// but requires some arguments to be pinned to the host. +Node* HostConstant(Graph* g, const Tensor& tensor); +Node* HostConstant(Graph* g, const Tensor& tensor, const string& name); + +// Adds a variable in "g" of the given "shape" and "dtype". +Node* Var(Graph* g, const DataType dtype, const TensorShape& shape); +Node* Var(Graph* g, const DataType dtype, const TensorShape& shape, + const string& name); + +// Adds an assign node in "g" which assigns "val" into "var". +Node* Assign(Graph* g, Node* var, Node* val); + +// Adds a send node "g" sending "input" as a named "tensor" from +// "sender" to "receiver". 
+Node* Send(Graph* g, Node* input, const string& tensor, const string& sender, + const uint64 sender_incarnation, const string& receiver); + +// Adds a recv node in "g" receiving a named "tensor" from "sender" +// to "receiver". +Node* Recv(Graph* g, const string& tensor, const string& type, + const string& sender, const uint64 sender_incarnation, + const string& receiver); + +// Adds a cumsum "node" in "g" doing cumsum(data, axes). +Node* Cumsum(Graph* g, Node* data, Node* axes, bool exclusive = false, + bool reverse = false); + +// Adds a reduction "node" in "g" doing sum(data, axes). "reduce" is +// a reduction, e.g., Sum, Max, Min, Mean, etc. +Node* Reduce(Graph* g, const string& reduce, Node* data, Node* axes, + bool keep_dims = false); + +// Adds a Matmul node in g doing in0.contract(in1). +Node* Matmul(Graph* g, Node* in0, Node* in1, bool transpose_a, + bool transpose_b); + +// Adds a Matmul node in g doing in0.contract(in1). +Node* BatchMatmul(Graph* g, Node* in0, Node* in1, bool adj_x, bool adj_y); + +// Adds a Quantize node into g that quantize floats into QUINT8. The range of +// the input float tensor is assumed to be [-1, 1]. +Node* QuantizeToUINT8(Graph* g, Node* data); + +// Adds a unary function "func" "node" in "g" taking "input". +Node* Unary(Graph* g, const string& func, Node* input, int index = 0); + +// Adds an identity node in "g" taking "input" and producing an +// identity copy. +Node* Identity(Graph* g, Node* input, int index = 0); + +// Adds a binary function "func" node in "g" taking "in0" and "in1". +Node* Binary(Graph* g, const string& func, Node* in0, Node* in1); + +// Adds a function "func" node in "g" taking inputs "ins". +Node* Multi(Graph* g, const string& func, absl::Span ins); + +// Adds a binary add node in "g" doing in0 + in1. +Node* Add(Graph* g, Node* in0, Node* in1); + +// Reverses dimensions of > +Node* Reverse(Graph* g, Node* tensor, Node* axis); + +// Generates random unit uniform distribution of the input shape. +Node* RandomUniform(Graph* g, Node* input, DataType dtype); + +// Generates random unit normal distribution of the input shape. +Node* RandomGaussian(Graph* g, Node* input, DataType dtype); + +// Generates random gamma distribution with the given shape and alpha[s]. +// Output dtype determined by alpha. +Node* RandomGamma(Graph* g, Node* shape, Node* alpha); + +// Generates random poisson distribution with the given shape and lam[s]. +// Output dtype determined by lam. +Node* RandomPoisson(Graph* g, Node* shape, Node* lam); + +// Rolls tensor by an offset of along the corresponding +// dimensions. +Node* Roll(Graph* g, Node* input, Node* shift, Node* axis); + +// Generates random parameters from the truncated standard normal distribution +// of the input shape +Node* TruncatedNormal(Graph* g, Node* input, DataType dtype); + +// Adds an error node in "g". The node's computation always +// generates an error with the given error message "errmsg". +Node* Error(Graph* g, Node* input, const string& errmsg, + bool log_error = false); + +// Adds a node that generates a invalid ref output. +Node* InvalidRefType(Graph* g, DataType out_type, DataType invalid_type); + +// Adds a node in "g". Its Compute() sleeps a while and outputs the +// input (i.e., same as identity). +Node* Delay(Graph* g, Node* input, Microseconds delay_micros); + +// Adds a no-op "node" in "g", with control inputs from all nodes in +// control_inputs vector. +Node* NoOp(Graph* g, const std::vector& control_inputs); + +// Adds a Switch node in "g". 
If "in1" is true, it forwards "in0" to +// output 1. Otherwise, it forwards "in0" to output 0. +Node* Switch(Graph* g, Node* in0, Node* in1); + +// Adds an Enter node in "g", which enters a new frame. +Node* Enter(Graph* g, Node* input, const string& frame_name); + +// Adds an Exit node in "g", which exits a frame. +Node* Exit(Graph* g, Node* input); + +// Adds a Merge node in "g" with two inputs "in0" and "in1". +Node* Merge(Graph* g, Node* in0, Node* in1); + +// Adds a Merge node in "g". The first input is "in0", the remaining +// inputs are only given by their names in remaining_in. +Node* Merge(Graph* g, Node* in0, absl::Span remaining_in); + +// Adds a NextIteration node in "g", which makes its input available +// to the next iteration. +Node* Next(Graph* g, const string& name, Node* input); + +// Adds a LoopCond node in "g", representing the "pivot" termination +// condition of a loop. +Node* LoopCond(Graph* g, Node* input); + +// Adds a less node in "g", which returns true iff "in0" < "in1". +Node* Less(Graph* g, Node* in0, Node* in1); + +// Adds a select node in "g", which outputs either "inx" or "iny" +// depending on the boolean value of "c". +Node* Select(Graph* g, Node* c, Node* inx, Node* iny); + +// Casts "in" into data type "dst". +Node* Cast(Graph* g, Node* in, DataType dst); + +// Perform gather op on params "in0" with indices "in1" and axis "axis". +Node* Gather(Graph* g, Node* in0, Node* in1, Node* axis); + +// Gets a tensor stored in the session state. +Node* GetSessionTensor(Graph* g, Node* in); + +// Adds a Concat node in "g". The first input is "concat_dim", the +// dimension to concatenate on, and the tensors to concatenate are +// given in "tensors". +Node* Concat(Graph* g, Node* concat_dim, absl::Span tensors); + +// Adds a ConcatV2 node in "g". The last input is "concat_dim", the +// dimension to concatenate on, and the tensors to concatenate are +// given in "tensors". +Node* ConcatV2(Graph* g, absl::Span tensors, Node* concat_dim); + +// Add a Relu node in "g". +Node* Relu(Graph* g, Node* in); + +// Add a Relu6 node in "g". +Node* Relu6(Graph* g, Node* in); + +// Add a BiasAdd node in "g". +Node* BiasAdd(Graph* g, Node* value, Node* bias); + +// Add a Conv2D node in "g". +Node* Conv2D(Graph* g, Node* in0, Node* in1); + +// Add a Diag node in "g". +Node* Diag(Graph* g, Node* in, DataType type); + +// Add a DiagPart node in "g". +Node* DiagPart(Graph* g, Node* in, DataType type); + +// Add a CheckNumerics node in "g". +Node* CheckNumerics(Graph* g, Node* in, const string& message); + +// Add an _Arg node in "g". +Node* Arg(Graph* g, int64_t index, DataType type); + +// Add a _Retval node in "g". +Node* Retval(Graph* g, int64_t index, Node* in, int64_t in_index = 0); + +} // end namespace graph +} // end namespace test +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_TESTLIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/types.h b/third_party/tflite-hdrs/tensorflow/core/graph/types.h new file mode 100644 index 00000000..05dd03ab --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/types.h @@ -0,0 +1,35 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_TYPES_H_ +#define TENSORFLOW_CORE_GRAPH_TYPES_H_ + +#include "tensorflow/core/lib/gtl/int_type.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// We model running time in microseconds. +TSL_LIB_GTL_DEFINE_INT_TYPE(Microseconds, int64_t); + +// We can also model running time in nanoseconds for more accuracy. +TSL_LIB_GTL_DEFINE_INT_TYPE(Nanoseconds, int64_t); + +// We model size in bytes. +TSL_LIB_GTL_DEFINE_INT_TYPE(Bytes, int64_t); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/validate.h b/third_party/tflite-hdrs/tensorflow/core/graph/validate.h new file mode 100644 index 00000000..3d59219b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/validate.h @@ -0,0 +1,68 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_VALIDATE_H_ +#define TENSORFLOW_CORE_GRAPH_VALIDATE_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace graph { + +// Returns OK if every NodeDef in `graph_def` is valid with respect to +// its corresponding OpDef (as defined by ValidateNodeDef()) as +// registered in `op_registry`. Also checks for deprecated ops. +// +// REQUIRES: +// * `op_registry` is not nullptr. +// * `graph_def` has default attrs filled in (see AddDefaultAttrsToGraphDef()). +absl::Status ValidateGraphDef(const GraphDef& graph_def, + const OpRegistryInterface& op_registry); + +// Like ValidateGraphDef() except it makes a copy of `graph_def` and calls +// AddDefaultAttrsToGraphDef() on the copy, removing that requirement from the +// caller. +absl::Status ValidateGraphDefAgainstOpRegistry( + const GraphDef& graph_def, const OpRegistryInterface& op_registry); + +// Like ValidateGraphDefAgainstOpRegistry() except it takes an OpList +// instead of an OpRegistryInterface. Note that the OpList need not +// have descriptions, which can be a big space savings, see +// GetOpListForValidation() below. +absl::Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, + const OpList& op_list); + +// Get an OpList from `*op_registry` with all the descriptions removed. 
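// Hedged sketch (assumed usage): validating a GraphDef against a
// description-free OpList built by GetOpListForValidation() below:
//
//   OpList op_list;
//   GetOpListForValidation(&op_list);  // defaults to *OpRegistry::Global()
//   absl::Status s = ValidateGraphDefAgainstOpList(graph_def, op_list);
//   if (!s.ok()) { /* reject graph_def */ }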
+void GetOpListForValidation( + OpList* op_list, const OpRegistry& op_registry = *OpRegistry::Global()); + +// Validate that the graph has no cycle except for legal while loop cycles. +// This traverses the specified nodes in topological order to verify there are +// no cycles. Starting with inputless nodes, it visits nodes whose inputs have +// all been visited, and counts the total number of visited nodes. If there is a +// cycle, nodes in the cycle will never be visited, and the visited count will +// be less than the total node count. +absl::Status ValidateGraphHasNoCycle(const Graph& graph); + +// Returns OK if the graph has no duplicate node names. +absl::Status VerifyNoDuplicateNodeNames(const GraphDef& graph); + +} // namespace graph +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_VALIDATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/while_context.h b/third_party/tflite-hdrs/tensorflow/core/graph/while_context.h new file mode 100644 index 00000000..e23e9df9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/while_context.h @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_ +#define TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Information about a while loop. Every user-defined while loop has an +// associated WhileContext, i.e., there is a WhileContext for every execution +// frame. Created with the while loop and used during gradient +// construction. Note that the gradient graph of while loop contains while loops +// itself, but these do not generate separate WhileContexts. +// +// TODO(skyewm): this is currently insufficient to handle nested loops and +// conditionals (and possibly other requirements). This may change a lot in the +// future to support these features. +// +// TODO(skyewm): de/serialize in MetaGraphDef so imported while loops will be +// differentiable. Figure out backwards compatibility story. +class WhileContext { + public: + WhileContext(absl::string_view frame_name, std::vector enter_nodes, + std::vector exit_nodes, OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs); + + const string& frame_name() const { return frame_name_; } + const std::vector& enter_nodes() const { return enter_nodes_; } + const std::vector& exit_nodes() const { return exit_nodes_; } + const OutputTensor& cond_output() const { return cond_output_; } + const std::vector& body_inputs() const { return body_inputs_; } + const std::vector& body_outputs() const { + return body_outputs_; + } + + private: + // Each user-defined while loop defines a new execution frame, which is + // uniquely identified by its frame name. Frames are used by the executor to + // manage the iterations of a loop. 
See the FrameState comment in + // core/common_runtime/executor.cc for more details. + const string frame_name_; + + // The enter nodes defining the input loop variables to the while loop. This + // vector defines the order of the loop variables. + const std::vector enter_nodes_; + + // The exit nodes defining the outputs of the while loop. These are in loop + // variable order. + const std::vector exit_nodes_; + + // The boolean output of the loop predicate. + const OutputTensor cond_output_; + + // The inputs and outputs to the loop body. + const std::vector body_inputs_; + const std::vector body_outputs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/graph/zen_graph_util.h b/third_party/tflite-hdrs/tensorflow/core/graph/zen_graph_util.h new file mode 100644 index 00000000..7dc23fbc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/graph/zen_graph_util.h @@ -0,0 +1,83 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPH_ZEN_GRAPH_UTIL_H_ +#define TENSORFLOW_CORE_GRAPH_ZEN_GRAPH_UTIL_H_ +#ifdef AMD_ZENDNN + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { + +namespace zen_op_registry { + +// Prefix that we add to Tensorflow op name to construct Zen op name. +static const char* const kZenNodePrefix = "_Zen"; + +// Get the name of Zen op from original TensorFlow op. +// We prefix the original op with "Zen" to get Zen op. +inline string GetZenOpName(const string& name) { + return string(kZenNodePrefix) + name; +} + +// Check whether op name with type T is registered as Zen operator +// that will go through name change or layout change pass. +// +// @input op_name - name of the op. +// @input T - datatype to be used for checking op. +// @return true if op name is registered as Zen op that will go through name +// change or layout change pass; false otherwise. +static inline bool IsZenOpKernelRegistered(const string& op_name, DataType T) { + string registered_kernels_key = op_name + string(DataType_Name(T)); + thread_local static auto* registered_kernels_map = + new absl::flat_hash_map(); + auto kernel_element = registered_kernels_map->find(registered_kernels_key); + bool kernel_registered = false; + + if (kernel_element == registered_kernels_map->end()) { + string registered_kernels = KernelsRegisteredForOp(op_name); + // String returned by KernelsRegisteredForOp looks like below: + // + // Op = ZenMatMul, kernels = + // device='CPU'; T in [DT_FLOAT] + // device='CPU'; T in [DT_DOUBLE] + + // If we have multiple kernels registered for the op. 
We need to verify + // our datatype + if (registered_kernels.find(string(DataType_Name(T))) != string::npos) { + kernel_registered = true; + } + registered_kernels_map->insert( + std::make_pair(registered_kernels_key, kernel_registered)); + } else { + // Kernel is visited at least once. Return stored registration result. + kernel_registered = kernel_element->second; + } + return kernel_registered; +} + +} // namespace zen_op_registry +} // namespace tensorflow + +#endif // AMD_ZENDNN +#endif // TENSORFLOW_CORE_GRAPH_ZEN_GRAPH_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/cluster.h b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/cluster.h new file mode 100644 index 00000000..36aec54c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/cluster.h @@ -0,0 +1,149 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_ +#define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace grappler { + +// A cluster represents of collection of hardware resources available to run +// the TensorFlow model. +// A process can only create a single cluster at a time. +class Cluster { + public: + explicit Cluster(int timeout_s); + virtual ~Cluster(); + + // Returns a string that represent the type of cluster that was instantiated. + virtual string type() const = 0; + + // Provision the hardware resources needed to run TensorFlow and start a + // TensorFlow session that can take advantage of these resources. + // The actual resources that are leveraged depend on the type of cluster + // instantiated. + // Returns OK iff all the requested resources could be reserved and a + // TensorFlow session successfully created. Returns an error otherwise. + // There is no graceful degradation to handle the case where only a subset + // of the requested resources are available. + virtual absl::Status Provision() = 0; + + // Attempts to shutdown the cluster. + // Returns OK iff there are no pending calls to the Run() method and all the + // resources used by the cluster could be released. Returns an error + // otherwise. + virtual absl::Status Shutdown() { return absl::OkStatus(); } + + // Whether soft placement is allowed. 
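The kernel-registration check in zen_graph_util.h above avoids repeated registry lookups by memoizing results in a leaked thread_local map. A standalone sketch of the same caching idiom, with hypothetical names and a stubbed-in expensive query (nothing below comes from the patch itself):

```cpp
#include <string>
#include <unordered_map>

// Stand-in for a costly query such as KernelsRegisteredForOp().
bool ExpensiveLookup(const std::string& key) { return !key.empty(); }

bool CachedLookup(const std::string& key) {
  // One cache per thread: no locking is needed and entries live for the
  // lifetime of the thread (the map is deliberately never freed).
  thread_local static auto* cache =
      new std::unordered_map<std::string, bool>();
  auto it = cache->find(key);
  if (it != cache->end()) return it->second;  // hit: reuse the stored answer
  const bool result = ExpensiveLookup(key);
  cache->emplace(key, result);
  return result;
}
```

Trading a small per-thread leak for lock-free lookups is the same design choice the header makes.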
If allow_soft_placement is true, + // an op will be placed on CPU if there's no GPU implementation for the OP + // or if no GPU devices are known or registered or if we need to co-locate + // with reftype input(s) which are from CPU. + void AllowSoftPlacement(bool soft_placement_state); + + // Update the number of inter-op threads for each per-session threadpool + void SetNumInterOpThreads(int num_threads); + + // Set the number of steps required to warmup TensorFlow. Must be called + // before Provision(). + void SetNumWarmupSteps(int num_steps); + + // Set executor type to instantiate + void SetExecutorType(const string* executor_type); + + // Returns the number of warmup steps. + int NumWarmupSteps() const; + + // Disable the collection of detailed statistics. Must be called + // before Provision(). + void DisableDetailedStats(bool disable); + + // Returns true iff the collection of detailed statistics is enabled. + bool DetailedStatsEnabled() const; + + // Disable the TensorFlow optimizer. This ensures that the graph that TF + // executes is similar to the input graph. Must be called before Provision(). + void DisableOptimizer(bool disable); + + // Return the list of TensorFlow devices that are available to execute a + // graph. This is empty until provision() is called. + const std::unordered_map& GetDevices() const { + return devices_; + } + + // Convenience method that returns the set of device names. These names are + // sorted alphabetically. + const std::vector GetDeviceNames() const; + + // The DeviceSet is not always available, but when it is it contains a + // superset of the devices listed in GetDevices/GetDeviceNames(). + virtual const DeviceSet* GetDeviceSet() const { return nullptr; } + + // Enables collecting the allocator stats. If called, must be called before + // Provision(). + virtual absl::Status EnablePeakMemoryStats() { + return absl::UnimplementedError(strings ::StrCat( + "Peak Memory Stats are not supported on ", type(), " clusters")); + } + + // Returns peak memory of all devices during the session creation and session + // runs. + virtual absl::Status GetPeakMemoryUsage( + std::unordered_map* device_peak_memory) const { + return absl::UnimplementedError( + "GetPeakMemoryUsage is not implemented for this type of cluster."); + } + + // Prepare the session to run the specified grappler item. This include + // initializing all the model variables. + virtual absl::Status Initialize(const GrapplerItem& item) = 0; + + // Run the specified graph_def and return the corresponding metadata. + virtual absl::Status Run(const GraphDef& graph_def, + const std::vector>& feed, + const std::vector& fetch, + RunMetadata* metadata) = 0; + + // Run the specified GrapplerItem and return the corresponding metadata. + virtual absl::Status Run(const GrapplerItem& item, RunMetadata* metadata) { + return Run(item.graph, item.feed, item.fetch, metadata); + } + + protected: + std::unordered_map devices_; + const int timeout_s_; + SessionOptions options_; + RunOptions run_options_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_CLUSTERS_CLUSTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/single_machine.h b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/single_machine.h new file mode 100644 index 00000000..f3f36626 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/single_machine.h @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ +#define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/cc/training/coordinator.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/queue_runner.pb.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { + +// Create a simple cluster that makes available to grappler a subset of the +// nodes available on a single local computer. +class SingleMachine : public Cluster { + public: + SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus); + ~SingleMachine() override; + + string type() const override { return "single_machine"; } + + absl::Status Provision() override; + absl::Status Shutdown() override; + + absl::Status Initialize(const GrapplerItem& item) override; + absl::Status Run(const GraphDef& item, + const std::vector>& feed, + const std::vector& fetch, + RunMetadata* metadata) override; + + const DeviceSet* GetDeviceSet() const override { return device_set_.get(); } + + absl::Status EnablePeakMemoryStats() override; + + // It requires EnableAllocatorStats(true) be called before Provision(). 
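As a usage sketch, the Cluster interface above can be exercised end to end through SingleMachine. The timeout and core counts below are invented, error handling is reduced to early returns, and the helper name is hypothetical; this is illustrative rather than part of the patch:

```cpp
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
namespace grappler {

absl::Status RunOnLocalMachine(const GrapplerItem& item,
                               RunMetadata* metadata) {
  // 60s per-call timeout, 4 CPU cores, no GPUs; all three values are made up.
  SingleMachine cluster(/*timeout_s=*/60, /*num_cpu_cores=*/4, /*num_gpus=*/0);
  absl::Status status = cluster.Provision();
  if (!status.ok()) return status;
  status = cluster.Initialize(item);
  if (!status.ok()) return status;
  status = cluster.Run(item.graph, item.feed, item.fetch, metadata);
  // Always try to release the resources, but report the first failure.
  const absl::Status shutdown_status = cluster.Shutdown();
  return status.ok() ? shutdown_status : status;
}

}  // namespace grappler
}  // namespace tensorflow
```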
+ absl::Status GetPeakMemoryUsage( + std::unordered_map* device_peak_memory) const override; + + private: + absl::Status RunWithTimeout( + const std::vector>& feed, + const std::vector& fetch, RunMetadata* run_metadata); + absl::Status RunWithTimeout( + const std::vector>& feed, + const std::vector& fetch, RunMetadata* run_metadata, + int64_t timeout_s); + absl::Status ResetSession(); + absl::Status CloseSession(bool use_timeout); + absl::Status ShutdownSession(); + void MergeCosts(CostGraphDef* graph_costs, const CostGraphDef& init_costs, + const CostGraphDef& queue_costs); + + absl::Status ClearAllocatorStats() const; + + std::unique_ptr session_; + std::vector queue_runner_defs_; + string last_graph_id_; + mutex last_graph_mu_; + const GraphDef* last_graph_ TF_GUARDED_BY(last_graph_mu_) = nullptr; + std::vector init_ops_; + int64_t expected_init_time_s_; + std::unique_ptr coordinator_; + std::unique_ptr thread_pool_; + std::unique_ptr device_set_; + + RunMetadata init_metadata_; + + mutex close_mu_; + bool closing_ TF_GUARDED_BY(close_mu_); + + bool cpu_allocator_stats_enabled_ = false; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/utils.h new file mode 100644 index 00000000..8a597854 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/utils.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_UTILS_H_ + +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +// Returns the DeviceProperties of the CPU on which grappler is running. +DeviceProperties GetLocalCPUInfo(); + +// Returns the DeviceProperties for the specified GPU attached to the server on +// which grappler is running. +DeviceProperties GetLocalGPUInfo(PlatformDeviceId platform_device_id); + +// Returns the DeviceProperties of the specified device +DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_CLUSTERS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/virtual_cluster.h b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/virtual_cluster.h new file mode 100644 index 00000000..1204a34c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ +#define TENSORFLOW_CORE_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" + +namespace tensorflow { +namespace grappler { + +// Create a simple cluster that lists the devices (and their properties) +// available in a TensorFlow session. This cluster simulates the execution of +// actual graphs. +class VirtualCluster : public Cluster { + public: + explicit VirtualCluster( + const std::unordered_map& devices); + VirtualCluster(const std::unordered_map& devices, + std::unique_ptr node_estimator, + std::unique_ptr node_manager); + explicit VirtualCluster(const DeviceSet* device_set); + + ~VirtualCluster() override; + + string type() const override { return "virtual"; } + + absl::Status Provision() override; + absl::Status Initialize(const GrapplerItem& item) override; + absl::Status Run(const GraphDef& graph, + const std::vector>& feed, + const std::vector& fetch, + RunMetadata* metadata) override; + absl::Status Run(const GrapplerItem& item, RunMetadata* metadata) override; + const DeviceSet* GetDeviceSet() const override { return device_set_; } + + private: + std::unique_ptr estimator_; + const DeviceSet* device_set_ = nullptr; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_CLUSTERS_VIRTUAL_CLUSTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/analytical_cost_estimator.h new file mode 100644 index 00000000..b31ce39e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/analytical_cost_estimator.h @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_ + +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class CostGraphDef; +class GraphDef; +} // namespace tensorflow + +namespace tensorflow { +namespace grappler { + +class Cluster; +struct GrapplerItem; + +// Estimate the cost of running a Grappler item based on the theoretical +// performance of the hardware that will run the model. Note that this +// internally uses static shape inference. An option for aggressive shape +// inference is provided to minimize unknown shapes, and this is only applicable +// with static shape inference. +class AnalyticalCostEstimator : public CostEstimator { + public: + AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes, + bool use_aggressive_shape_inference); + AnalyticalCostEstimator(Cluster* cluster, + std::unique_ptr node_estimator, + std::unique_ptr node_manager, + bool use_static_shapes, + bool use_aggressive_shape_inference); + AnalyticalCostEstimator(Cluster* cluster, + std::unique_ptr node_estimator, + std::unique_ptr node_manager, + std::unique_ptr placer, + bool use_static_shapes, + bool use_aggressive_shape_inference); + ~AnalyticalCostEstimator() override {} + + // This implementation always returns OK. + absl::Status Initialize(const GrapplerItem& item) override; + + // Predict the performance of each node of the optimized graph and annotate + // the RunMetadata with the corresponding estimates. Also returns the + // expected cost for the whole graph. + absl::Status PredictCosts(const GraphDef& optimized_graph, + RunMetadata* run_metadata, + Costs* cost) const override; + + const VirtualScheduler* GetScheduler() const { return scheduler_.get(); } + + private: + const GrapplerItem* item_; + std::unique_ptr node_estimator_; + std::unique_ptr node_manager_; + std::unique_ptr scheduler_; + + bool use_static_shapes_; + bool use_aggressive_shape_inference_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/cost_estimator.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/cost_estimator.h new file mode 100644 index 00000000..b133b369 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/cost_estimator.h @@ -0,0 +1,259 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { +class GraphDef; +class CostGraphDef; + +namespace grappler { +struct GrapplerItem; + +constexpr uint64_t kMemoryUnknown = std::numeric_limits::max(); +constexpr uint64_t kZeroMemory = 0ULL; + +struct DeviceInfo { + // Billions of operations executed per second. + double gigaops; + + // Bandwidth to main memory in GB per second. + double gb_per_sec; + + // Read bandwidth to intermediate memory in GB per second. + double intermediate_read_gb_per_sec; + + // Write bandwidth to intermediate memory in GB per second. + double intermediate_write_gb_per_sec; + + DeviceInfo() + : gigaops(INFINITY), + gb_per_sec(INFINITY), + intermediate_read_gb_per_sec(INFINITY), + intermediate_write_gb_per_sec(INFINITY) {} + + DeviceInfo(const DeviceInfo& input) + : gigaops(input.gigaops), + gb_per_sec(input.gb_per_sec), + intermediate_read_gb_per_sec(input.intermediate_read_gb_per_sec), + intermediate_write_gb_per_sec(input.intermediate_write_gb_per_sec) {} + + DeviceInfo(double gigaops, double gb_per_sec, + double intermediate_read_gb_per_sec = INFINITY, + double intermediate_write_gb_per_sec = INFINITY) + : gigaops(gigaops), + gb_per_sec(gb_per_sec), + intermediate_read_gb_per_sec(intermediate_read_gb_per_sec), + intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {} +}; + +// Holds the set of things we might want to estimate or measure in Grappler. +// Always produce execution time. Other fields are optional depending on the +// estimator being used. +struct Costs { + // Returns a Costs structure with default values for all of the fields. + inline Costs(); + + // Builds a Costs structure with all zero values, rather than unknowns. 
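DeviceInfo reduces a device to a handful of throughput numbers, and an estimate falls out of simple roofline arithmetic: execution time is bounded by both compute and memory traffic. The standalone sketch below only mirrors the two main fields named above; it is not the estimator's actual formula:

```cpp
#include <algorithm>
#include <cstdio>

struct SimpleDeviceInfo {
  double gigaops;     // 10^9 operations per second
  double gb_per_sec;  // 10^9 bytes per second to main memory
};

// Assumes compute and memory transfers overlap perfectly, so the slower of
// the two dominates.
double EstimateSeconds(const SimpleDeviceInfo& dev, double ops, double bytes) {
  const double compute_s = ops / (dev.gigaops * 1e9);
  const double memory_s = bytes / (dev.gb_per_sec * 1e9);
  return std::max(compute_s, memory_s);
}

int main() {
  const SimpleDeviceInfo cpu{/*gigaops=*/100.0, /*gb_per_sec=*/50.0};
  // 2e9 floating-point ops touching 8e8 bytes: compute-bound on this device.
  std::printf("%.4f s\n", EstimateSeconds(cpu, 2e9, 8e8));
  return 0;
}
```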
+ static inline Costs ZeroCosts(bool inaccurate = false); + + struct MilliSeconds : std::chrono::milliseconds { + MilliSeconds() : std::chrono::milliseconds(0) {} + MilliSeconds(double d) + : std::chrono::milliseconds(static_cast(d)) {} + MilliSeconds(const std::chrono::milliseconds& d) + : std::chrono::milliseconds(d) {} + MilliSeconds& operator=(const std::chrono::milliseconds& d) { + std::chrono::milliseconds::operator=(d); + return *this; + } + }; + struct MicroSeconds : std::chrono::microseconds { + MicroSeconds() : std::chrono::microseconds(0) {} + MicroSeconds(double d) + : std::chrono::microseconds(static_cast(d)) {} + MicroSeconds(const std::chrono::microseconds& d) + : std::chrono::microseconds(d) {} + MicroSeconds& operator=(const std::chrono::microseconds& d) { + std::chrono::microseconds::operator=(d); + return *this; + } + MilliSeconds asMilliSeconds() const { + return std::chrono::duration_cast(*this); + } + }; + struct NanoSeconds : std::chrono::nanoseconds { + NanoSeconds() : std::chrono::nanoseconds(0) {} + NanoSeconds(double d) : std::chrono::nanoseconds(static_cast(d)) {} + NanoSeconds(const std::chrono::nanoseconds& d) + : std::chrono::nanoseconds(d) {} + NanoSeconds& operator=(const std::chrono::nanoseconds& d) { + std::chrono::nanoseconds::operator=(d); + return *this; + } + MicroSeconds asMicroSeconds() const { + return std::chrono::duration_cast(*this); + } + MilliSeconds asMilliSeconds() const { + return std::chrono::duration_cast(*this); + } + static NanoSeconds infinity() { + return NanoSeconds(std::chrono::nanoseconds::max()); + } + }; + + // We store all our times in nanoseconds. If needs be, we can always switch to + // picoseconds in the future by updating this typedef. + typedef NanoSeconds Duration; + + // Overall cost of running the graph; latency. + Duration execution_time; + + // Computation cost of running the graph. + Duration compute_time; + + // Memory access cost of running the graph. + Duration memory_time; + + // Intermediate memory access cost of running the graph + Duration intermediate_memory_time; + Duration intermediate_memory_read_time; // Intermediate memory read cost. + Duration intermediate_memory_write_time; // Intermediate memory write cost. + + // Network time (colelctived ops - all gather, all reduce, etc.) + Duration network_time; + + // This field can be a very pessimistic estimate of the main memory + // requirements of a graph. For example, it might assume that all activations + // are live for all of a graph's execution. + uint64_t max_memory; // Max main memory requirement in bytes over all ops. + uint64_t persistent_memory; + uint64_t temporary_memory; + + // Output memory usage per port. + absl::flat_hash_map output_tensor_size_bytes; + + // Track persistent versus temporary memory. + absl::flat_hash_set persistent_output_ports; + + // These fields are used for TPU-related estimations. They are per-op + // maximums, so each op is evaluated independently, but we want the maximum of + // the value over all ops. + int64_t max_per_op_buffers; // Sum of all buffers used by the ops. + int64_t max_per_op_streaming; // Ignore largest input buffer, assuming it + // streams from main memory. + + // Number of ops included in this Costs in total. + // Default initialized to be one. + int64_t num_ops_total = 1; + // If the time estimation is inaccurate. + bool inaccurate = false; + // Number of ops that are estimated with unknown shapes. 
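The MilliSeconds/MicroSeconds/NanoSeconds wrappers above are thin layers over std::chrono, with streaming operators defined a little further down in this header. A short sketch of the unit conversions, assuming the cost_estimator.h from this patch is available to compile against:

```cpp
#include <iostream>

#include "tensorflow/core/grappler/costs/cost_estimator.h"

int main() {
  using tensorflow::grappler::Costs;
  Costs::NanoSeconds t(1500000.0);              // 1.5 ms stored as nanoseconds
  Costs::MicroSeconds us = t.asMicroSeconds();  // 1500us
  Costs::MilliSeconds ms = t.asMilliSeconds();  // 1ms (the cast truncates)
  std::cout << t << " = " << us << " = " << ms << std::endl;
  return 0;
}
```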
+ int64_t num_ops_with_unknown_shapes = 0; + // TODO(pcma): include a counter for total inaccurate ops and counters for + // other reasons causing the inaccuracy + + // Max possible memory usage per device. + std::unordered_map estimated_max_memory_per_device; +}; + +inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) { + os << d.count() << "ms"; + return os; +} +inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) { + os << d.count() << "us"; + return os; +} +inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) { + os << d.count() << "ns"; + return os; +} + +Costs::Costs() { + execution_time = Duration::zero(); + compute_time = Duration::zero(); + memory_time = Duration::zero(); + intermediate_memory_time = Duration::zero(); + network_time = Duration::zero(); + max_memory = kMemoryUnknown; + persistent_memory = kMemoryUnknown; + temporary_memory = kMemoryUnknown; + max_per_op_buffers = kMemoryUnknown; + max_per_op_streaming = kMemoryUnknown; +} + +Costs Costs::ZeroCosts(bool inaccurate) { + Costs costs; + costs.execution_time = Duration::zero(); + costs.compute_time = Duration::zero(); + costs.memory_time = Duration::zero(); + costs.intermediate_memory_time = Duration::zero(); + costs.network_time = Duration::zero(); + costs.max_memory = kZeroMemory; + costs.persistent_memory = kZeroMemory; + costs.temporary_memory = kZeroMemory; + costs.max_per_op_buffers = kZeroMemory; + costs.max_per_op_streaming = kZeroMemory; + costs.inaccurate = inaccurate; + return costs; +} + +Costs CombineCosts(const Costs& left, const Costs& right); + +// Multiplies Costs by a scalar. +// Equivalent to applying CombineCosts "multiplier" times. +Costs MultiplyCosts(const Costs& costs, int multiplier); + +// Given a GrapperItem and an optimized implementation of the corresponding +// TensorFlow graph, the CostEstimator attempts to predicts the actual cost of +// running the graph. +class CostEstimator { + public: + virtual ~CostEstimator() {} + + // Initializes the estimator for the specified grappler item. + // The estimator shouldn't be used if this function returns any status other + // that OK. + virtual absl::Status Initialize(const GrapplerItem& item) = 0; + + // Predicts the cost of running the given optimized version of the grappler + // item. + // If a RunMetadata is passed, it will be populated with detailed information + // about the cost of running each operation of the optimized graph. + // if a double value is passed, it will be set to a value that reflects the + // overall cost of running the graph (e.g. the latency of the computation). + // Returns a status that indicate is the performance could be estimated or + // not. + virtual absl::Status PredictCosts(const GraphDef& optimized_graph, + RunMetadata* run_metadata, + Costs* cost) const = 0; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/graph_memory.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/graph_memory.h new file mode 100644 index 00000000..fcd9eaeb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/graph_memory.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_GRAPH_MEMORY_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_GRAPH_MEMORY_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + +// Infer the worst case memory usage for a given grappler item. +class GraphMemory { + public: + struct LiveTensor { + string node; + int output_id; + size_t memory_used; + Costs::Duration allocation_time; + Costs::Duration deallocation_time; + }; + struct MemoryUsage { + int64_t used_memory; + std::vector live_tensors; + }; + + explicit GraphMemory(const GrapplerItem& item) + : item_(item), unknown_usage_({-1, {}}) {} + + absl::Status InferStatically( + const std::unordered_map& devices); + absl::Status InferDynamically(Cluster* cluster); + + // Worst case memory usage in bytes, or -1 if the usage is unknown. If there + // are multiple devices, returns the highest per device memory usage. + int64_t GetWorstCaseMemoryUsage() const; + + // Returns the peak memory usage for the specified device. + const MemoryUsage& GetPeakMemoryUsage(const string& device) const { + auto it = peak_usage_.find(device); + if (it == peak_usage_.end()) { + return unknown_usage_; + } + return it->second; + } + + private: + void InferMemUsageForNodes(const std::vector& nodes, + GraphProperties* properties, int64_t* worst_case, + int64_t* best_case) const; + int64_t InferMemUsageForNeighbors( + const std::vector& props) const; + + void InferFromTrace(const StepStats& timeline); + + const GrapplerItem& item_; + std::unordered_map worst_case_memory_usage_; + std::unordered_map peak_usage_; + const MemoryUsage unknown_usage_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_GRAPH_MEMORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/graph_properties.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/graph_properties.h new file mode 100644 index 00000000..1d9575e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/graph_properties.h @@ -0,0 +1,226 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_GRAPH_PROPERTIES_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_GRAPH_PROPERTIES_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { + +namespace grappler { + +// Optional attributes that tell about node output information. +// We use these side information, if provided, for static shape inference +// and VirtualScheduler scheduling. + +// Switch op attribute as a vector of int that tells which branch the +// Switch output is taken on every round of execution. +// Used for scheduling ops after Switch correctly (e.g., While loop). +ABSL_CONST_INIT const char kOutputSlots[] = "_output_slot_vector"; + +// Example: +// Assume a node has two outputs and iterated for three times. Then it has: +// _execution_count = 3 +// _output_sizes_vector = [2, 2, 2] +// _output_dtype_vector.size = 6 +// _output_shape_vector.size = 6 + +// If all the iterations have same output shapes, then +// _execution_count = 3 +// _same_output_for_iterations = true +// _output_sizes_vector = [2] +// _output_dtype_vector.size = 2 +// _output_shape_vector.size = 2 + +// How many times this node has been executed. +ABSL_CONST_INIT const char kExecutionCount[] = "_execution_count"; + +// Records the output sizes for each round of execution. +ABSL_CONST_INIT const char kOutputSizes[] = "_output_sizes_vector"; + +// The node has been scheduled multiple times with outputs that have the same +// shape. +ABSL_CONST_INIT const char kOutputSame[] = "_same_output_for_iterations"; + +// Outputs DataType vector. +ABSL_CONST_INIT const char kOutputTypes[] = "_output_dtype_vector"; + +// Outputs TensorShapeProto vector. +ABSL_CONST_INIT const char kOutputShapes[] = "_output_shape_vector"; + +class SymbolicShapeRefiner; +class TopoQueue; + +// Infer OpInfo::TensorProperties for graph nodes inputs/outputs. +// +// Typical use case, is to infer tensor properties from a graph, before doing +// optimization pass. Nodes modified during optimization pass have to be +// invalidated, to prevent further incorrect optimizations based on wrong shape +// and data type properties. +class GraphProperties { + public: + // The item must outlive the properties + explicit GraphProperties(const GrapplerItem& item) : item_(item) {} + + // Infer the shapes through abstract interpretation. Feed information can be + // incorrect so it should be discarded to ensure correctness of the analysis. + // However, it can help infer shapes in the fanout of fed nodes (even though + // the correctness of these shapes can't be guaranteed), so in some cases + // (such as simulation or scheduling) it makes sense of keep these shapes. + // aggressive_shape_inference option executes nodes on the host to identify + // output values when possible and does other aggressive strategies. + // Similar to assuming_valid_feeds, this may cause incorrectness in graph + // analyses, but is useful for simulation or scheduling. + // If include_input_tensor_values is true, the values of constant tensors + // will included in the input properties. + // If include_output_tensor_values is true, the values of constant tensors + // will be included in the output properties. 
+ absl::Status InferStatically(bool assume_valid_feeds, + bool aggressive_shape_inference, + bool include_input_tensor_values, + bool include_output_tensor_values); + absl::Status InferStatically(bool assume_valid_feeds, + bool aggressive_shape_inference, + bool include_tensor_values) { + return InferStatically( + assume_valid_feeds, + /*aggressive_shape_inference=*/aggressive_shape_inference, + /*include_input_tensor_values=*/include_tensor_values, + /*include_output_tensor_values=*/include_tensor_values); + } + absl::Status InferStatically(bool assume_valid_feeds) { + return InferStatically(assume_valid_feeds, + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/true); + } + // Infer the shape by running the graph on the specified cluster and recording + // the shapes of the processed tensors. + absl::Status InferDynamically(Cluster* cluster); + // Extract the properties from a cost graph. For testing only since there is + // no way to ensure that the cost graph match the item. + absl::Status InferFromCostGraph(const CostGraphDef& cost_graph); + + // Stores `item_.graph` with the inferred output shapes to `output_graph_def`. + absl::Status AnnotateOutputShapes(GraphDef* output_graph_def) const; + + // Return the properties of node inputs/outputs, including data types and + // shapes. Note that the dimensions in the shapes can be negative. We use the + // -1 value to denote that we don't know anything about a dimension. We use + // values strictly less than -1 to encode symbolic dimensions: although we + // don't know the actual value of the symbolic dimension, we know that all the + // dimensions denoted by the same negative value are the equal. + bool HasInputProperties(const string& node_name) const; + bool HasOutputProperties(const string& node_name) const; + const std::vector& GetInputProperties( + const string& node_name) const; + const std::vector& GetOutputProperties( + const string& node_name) const; + + // Invalidate input/output properties for nodes modified during graph + // optimization pass, to prevent potential optimizations, based on incorrect + // shape information. + void ClearInputProperties(const string& node_name); + void ClearOutputProperties(const string& node_name); + // Returns true if we have *any* properties. + bool has_properties() const { + return !input_properties_.empty() || !output_properties_.empty(); + } + + bool CheckShapeIncompatible(const string& node_name) const { + return incompatible_shape_nodes_.find(node_name) != + incompatible_shape_nodes_.end(); + } + + // Clear all infered properties. + void Clear() { + input_properties_.clear(); + output_properties_.clear(); + } + + private: + // Relaxes shapes , determined from an EnqueueV2 node, into + // <*queue_shapes_and_types>. + static absl::Status RelaxEnqueueShapesAndMergeTypes( + SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode, + const std::vector& shapes_and_types, + std::vector* queue_shapes_and_types); + + // Update the shapes of the enqueue node, port them over to the corresponding + // queue, and schedule the reprocessing of the queue if needed. + static absl::Status UpdateEnqueue( + const NodeDef* enqueue_node, + const absl::flat_hash_map& + resource_handles, + SymbolicShapeRefiner* shape_refiner, bool* new_shapes); + + // Update the shapes and types of the Queue node, if not set by Enqueue node. 
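A sketch of the typical static-inference flow described above: build GraphProperties over a GrapplerItem, run InferStatically, then read the per-node output properties. It assumes the grappler headers from this patch; the rank reporting and function name are purely illustrative:

```cpp
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

absl::Status ReportOutputRanks(const GrapplerItem& item) {
  GraphProperties properties(item);  // the item must outlive `properties`
  absl::Status status =
      properties.InferStatically(/*assume_valid_feeds=*/false);
  if (!status.ok()) return status;
  for (const NodeDef& node : item.graph.node()) {
    if (!properties.HasOutputProperties(node.name())) continue;
    const auto& outputs = properties.GetOutputProperties(node.name());
    // dim_size() is the number of dimensions recorded for the shape.
    LOG(INFO) << node.name() << ": " << outputs.size() << " output(s), "
              << "rank of output 0 = "
              << (outputs.empty() ? -1 : outputs[0].shape().dim_size());
  }
  return absl::OkStatus();
}

}  // namespace grappler
}  // namespace tensorflow
```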
+ static absl::Status UpdateQueue(const NodeDef* queue_node, + SymbolicShapeRefiner* shape_refiner, + bool* new_shapes); + + // Update the output shapes of a Merge node, and enqueue its fanout in + // new_shapes if needed. + absl::Status UpdateMerge(SymbolicShapeRefiner* shape_refiner, + const NodeDef* node, bool* new_shapes) const; + // Process the Enter node, and enqueue its fanout in new_shapes if needed. + static absl::Status UpdateEnter(SymbolicShapeRefiner* shape_refiner, + const NodeDef* node, bool* new_shapes); + // Update the shapes for node 'n'. If output shapes for n have changed, + // enqueue its fanout in 'new_shapes'. + absl::Status UpdateShapes( + SymbolicShapeRefiner* shape_refiner, + const absl::flat_hash_map& + resource_handles, + const NodeDef* n, bool* new_shapes) const; + // Propagate the shapes for the nodes enqueued in new_shapes and their + // transitive fanout until a fixed point is reached. + absl::Status PropagateShapes( + SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes, + const absl::flat_hash_map& + resource_handles, + int num_loops) const; + + // Data members + const GrapplerItem& item_; + absl::flat_hash_map> + input_properties_; + absl::flat_hash_map> + output_properties_; + const std::vector missing_properties_; + + // Nodes with output shape incompatible between shape inference and + // annotation. + std::unordered_set incompatible_shape_nodes_; +}; + +// Helper function for GraphProperties. +bool IsShapeFullyDefinedIntegerVectorOrScalar( + shape_inference::InferenceContext* ic, + const shape_inference::ShapeHandle& shape, + const shape_inference::ShapeHandle& tensor_as_shape, const DataType& dtype); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_GRAPH_PROPERTIES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/measuring_cost_estimator.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/measuring_cost_estimator.h new file mode 100644 index 00000000..5da9bac9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/measuring_cost_estimator.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class CostGraphDef; +class GraphDef; +} // namespace tensorflow + +namespace tensorflow { +namespace grappler { + +class Cluster; +struct GrapplerItem; + +// Estimate the cost of running a Grappler item by actually running the +// corresponding TensorFlow graph on the specified cluster and measuring the +// runtimes. +class MeasuringCostEstimator : public CostEstimator { + public: + // Run the model for measurement_steps to measure its average cost. + // When measurement_threads is greater than 0, use a threadpool of as many + // threads to run the measurements; otherwise, run them serially. Does not + // take ownership of cluster. + explicit MeasuringCostEstimator(Cluster* cluster, int measurement_steps, + int measurement_threads); + ~MeasuringCostEstimator() override {} + + // Initializes the estimator for the specified grappler item. + // This implementation always returns OK. + absl::Status Initialize(const GrapplerItem& item) override; + + // Runs the optimized version of the graph on the cluster, measures + // the runtimes of each operation, and annotates the CostGraphDef of + // RunMetadata with the corresponding measurements. + // Returns the average latency for the whole graph. + absl::Status PredictCosts(const GraphDef& optimized_graph, + RunMetadata* run_metadata, + Costs* cost) const override; + + private: + Cluster* cluster_; // Not owned. + int measurement_steps_; + int measurement_threads_; + std::vector> feed_; + std::vector fetch_; + std::unique_ptr thread_pool_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_MEASURING_COST_ESTIMATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/op_context.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/op_context.h new file mode 100644 index 00000000..90063333 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/op_context.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" + +namespace tensorflow { +namespace grappler { + +// A structure to keep the context of op execution, including its shape, +// execution context, and other relevant information. +struct OpContext { + std::string name; + std::string device_name; + OpInfo op_info; + const FunctionDefLibrary* function_library; // Not owned. + // This map is used to stash meta attributes so that they may be + // communicated, for instance, from the scheduler that creates them to the + // CostEstimator or EventCostManager that uses them. + absl::flat_hash_map> + op_meta_attributes; + OpContext() { function_library = nullptr; } +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/op_level_cost_estimator.h new file mode 100644 index 00000000..cd160d6d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -0,0 +1,346 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_context.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { +namespace grappler { + +bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto, + TensorShapeProto* tensor_shape_proto); +std::vector MaybeGetMinimumShape( + const TensorShapeProto& original_shape, int rank, + bool* found_unknown_shapes); + +// Node costs; an intermediate structure used within op level cost estimator. +struct NodeCosts { + // If this FLAG is true, override calculated compute time with a minimum + // value, instead of calculating it from num_compute_ops and compute ops/sec. + // For example, PredictIdentity, PredictVariable, PredictMetadata set this + // FLAG. + bool minimum_cost_op = false; + + // Compute ops. + int64_t num_compute_ops = 0; + + // Memory bytes accessed; note that these may be different to the size of + // tensors. + std::vector num_input_bytes_accessed; // ordered by input tensors. 
+ std::vector num_output_bytes_accessed; // ordered by output ports. + int64_t internal_read_bytes = 0; + int64_t internal_write_bytes = 0; + + // Convenience functions. + int64_t num_total_input_bytes() const { + return std::accumulate(num_input_bytes_accessed.begin(), + num_input_bytes_accessed.end(), 0LL); + } + int64_t num_total_read_bytes() const { + return num_total_input_bytes() + internal_read_bytes; + } + int64_t num_total_output_bytes() const { + return std::accumulate(num_output_bytes_accessed.begin(), + num_output_bytes_accessed.end(), 0LL); + } + int64_t num_total_write_bytes() const { + return num_total_output_bytes() + internal_write_bytes; + } + int64_t num_bytes_accessed() const { + return num_total_read_bytes() + num_total_write_bytes(); + } + + // Memory usage. + int64_t max_memory = 0; + int64_t persistent_memory = 0; + int64_t temporary_memory = 0; + + // Stats. + int64_t num_nodes = 1; + int64_t num_nodes_with_unknown_shapes = 0; + int64_t num_nodes_with_unknown_op_type = 0; + int64_t num_nodes_with_pure_memory_op = 0; + bool inaccurate = false; + + // TODO(dyoon): this is added for compatibility; some old code is hard to + // migrate; hence, using these as a backup. Once we clean up, we'll delete + // these fields. New code should not use these. + bool has_costs = false; + Costs costs; +}; + +class OpLevelCostEstimator { + public: + OpLevelCostEstimator(); + virtual ~OpLevelCostEstimator() {} + + virtual Costs PredictCosts(const OpContext& op_context) const; + + // Returns basic device performance info. + virtual DeviceInfo GetDeviceInfo(const DeviceProperties& device) const; + + protected: + // TODO(dyoon): Consider to remove PredictOpCountBasedCosts() with OpInfo. + // Naive cost estimate based on the given operations count and total + // input/output tensor sizes of the given op_info combined. + Costs PredictOpCountBasedCost(double operations, const OpInfo& op_info) const; + + // Naive cost estimate based on the given operations count and the given total + // io size in bytes. Sizes of op_info inputs and outputs are not taken into + // consideration. + Costs PredictOpCountBasedCost(double operations, double input_io_bytes, + double output_io_bytes, + const OpInfo& op_info) const; + + // Top-level method cost function (PredictCosts calls this method to get + // NodeCosts, and then converts it to Costs). PredictNodeCosts() calls other + // Predict methods depending on op types. + absl::Status PredictNodeCosts(const OpContext& op_context, + NodeCosts* node_costs) const; + + // Predict cost of an op for which no accurate estimator is defined. + absl::Status PredictCostOfAnUnknownOp(const OpContext& op_context, + NodeCosts* node_costs) const; + + // This family of routines predicts the costs to + // perform the specified TensorFlow Op on the + // device represented by a subclass. The default + // implementation just divides the operations to + // perform the op (from the "Count" routines, + // above) by the device peak operations per + // second. + // Implementation of costs other than + // execution_time is optional, depending on the + // device. 
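Tying the pieces together, the estimator can be queried one node at a time through PredictCosts. The sketch below builds an OpContext for a single MatMul by hand; the 1024x1024 shapes and the CPU device are invented, and the headers are assumed to come from this patch:

```cpp
#include <iostream>

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"

namespace tensorflow {
namespace grappler {

void PredictMatMulCost() {
  OpContext op_context;
  op_context.op_info.set_op("MatMul");
  // Two 1024x1024 float inputs.
  for (int i = 0; i < 2; ++i) {
    auto* input = op_context.op_info.add_inputs();
    input->set_dtype(DT_FLOAT);
    input->mutable_shape()->add_dim()->set_size(1024);
    input->mutable_shape()->add_dim()->set_size(1024);
  }
  op_context.op_info.mutable_device()->set_type("CPU");

  OpLevelCostEstimator estimator;
  const Costs costs = estimator.PredictCosts(op_context);
  std::cout << "execution: " << costs.execution_time
            << ", compute: " << costs.compute_time
            << ", memory: " << costs.memory_time << std::endl;
}

}  // namespace grappler
}  // namespace tensorflow
```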
+ absl::Status PredictNaryOp(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictConv2D(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictCwiseOp(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictConv2DBackpropInput(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictConv2DBackpropFilter(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictFusedConv2DBiasActivation(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictMatMul(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictSparseTensorDenseMatMul(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictNoOp(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictIdentity(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictVariable(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictBatchMatMul(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictMetadata(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictGatherOrSlice(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictScatter(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictMaxPool(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictMaxPoolGrad(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictAvgPool(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictAvgPoolGrad(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictFusedBatchNorm(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictFusedBatchNormGrad(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictEinsum(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictAssignVariableOps(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictPureMemoryOp(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictSoftmax(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictResizeBilinear(const OpContext& op_context, + NodeCosts* node_costs) const; + absl::Status PredictCropAndResize(const OpContext& op_context, + NodeCosts* node_costs) const; + + int64_t GetSoftmaxComputeOps(const OpContext& op_context) const; + + // Generic cost prediction method for fused operations. + absl::Status PredictFusedOp(const OpContext& op_context, + const std::vector& fused_op_contexts, + NodeCosts* node_costs) const; + + // Utility function for safe division. Returns 0 + // if rhs is 0 or negative. + static double SafeDiv(const double lhs, const double rhs) { + if (rhs > 0) { + return lhs / rhs; + } else { + return 0.0; + } + } + + // This family of routines counts the number of operations to perform the + // specified TensorFlow Op. + struct MatMulDimensions { + int m; + int n; + int k; + }; + struct BatchMatMulDimensions { + std::vector batch_dims; + MatMulDimensions matmul_dims; + }; + struct ConvolutionDimensions { + int64_t batch; // Batch size. + int64_t ix; // Input size x. + int64_t iy; // Input size y. + int64_t iz; // Input depth. + int64_t kx; // Kernel x. + int64_t ky; // Kernel y. 
+ int64_t kz; // Kernel depth (in case of group convolution, this will be + // smaller than input depth). + int64_t oz; // Output depth. + int64_t ox; // Output size x. + int64_t oy; // Output size y. + int64_t sx; // Stride x. + int64_t sy; // Stride y. + Padding padding; // SAME or VALID. + }; + static int64_t CountConv2DOperations(const OpInfo& op_info, + bool* found_unknown_shapes); + static int64_t CountConv2DOperations(const OpInfo& op_info, + ConvolutionDimensions* conv_info, + bool* found_unknown_shapes); + static int64_t CountMatMulOperations(const OpInfo& op_info, + bool* found_unknown_shapes); + static int64_t CountMatMulOperations(const OpInfo& op_info, + MatMulDimensions* mat_mul, + bool* found_unknown_shapes); + static int64_t CountMatMulOperations(const OpInfo& op_info, bool transpose_a, + bool transpose_b, + MatMulDimensions* mat_mul, + bool* found_unknown_shapes); + bool GenerateBatchMatmulContextFromEinsum(const OpContext& einsum_context, + OpContext* batch_matmul_context, + bool* found_unknown_shapes) const; + static int64_t CountBatchMatMulOperations(const OpInfo& op_info, + bool* found_unknown_shapes); + static int64_t CountBatchMatMulOperations( + const OpInfo& op_info, BatchMatMulDimensions* batch_mat_mul, + bool* found_unknown_shapes); + static int64_t CountConv2DBackpropInputOperations( + const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims, + bool* found_unknown_shapes); + static int64_t CountConv2DBackpropFilterOperations( + const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims, + bool* found_unknown_shapes); + + // Calculate the element count of an input/output tensor. + static int64_t CalculateTensorElementCount( + const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes); + + // Calculate the total size in bytes of an input/output tensor. + static int64_t CalculateTensorSize(const OpInfo::TensorProperties& tensor, + bool* found_unknown_shapes); + + // Calculate the element count of the largest + // input of specified TensorFlow op. + static int64_t CalculateLargestInputCount(const OpInfo& op_info, + bool* found_unknown_shapes); + + // Calculate the total size in bytes of the all + // the inputs of specified TensorFlow op. + static int64_t CalculateInputSize(const OpInfo& op_info, + bool* found_unknown_shapes); + + // Same, but a vector format: one for each input. + static std::vector CalculateInputTensorSize( + const OpInfo& op_info, bool* found_unknown_shapes); + + // Calculate the total size in bytes of the all + // the outputs of specified TensorFlow op. + static int64_t CalculateOutputSize(const OpInfo& op_info, + bool* found_unknown_shapes); + + // Same, but a vector format: one for each output. + static std::vector CalculateOutputTensorSize( + const OpInfo& op_info, bool* found_unknown_shapes); + + // For convolution and its grad ops. + static ConvolutionDimensions ConvolutionDimensionsFromInputs( + const TensorShapeProto& original_image_shape, + const TensorShapeProto& original_filter_shape, const OpInfo& op_info, + bool* found_unknown_shapes); + + // For Pooling, FusedBatchNorm, and their grad ops. + static absl::StatusOr OpDimensionsFromInputs( + const TensorShapeProto& original_image_shape, const OpInfo& op_info, + bool* found_unknown_shapes); + + // Helper to construct child operation contexts for the component operations + // of fused ops. 
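As a rough illustration of what the Count*Operations and CalculateTensorSize helpers above compute, a standalone sketch follows. The 2*m*k*n convention (one multiply and one add per multiply-accumulate) and the treat-unknown-dimensions-as-one rule are stated assumptions; the vendored implementation may count differently.

// Standalone sketch of op-count and tensor-size arithmetic; conventions are
// assumptions, not the vendored code.
#include <cstdint>
#include <vector>

// FLOPs for C[m,n] = A[m,k] * B[k,n], counting one multiply and one add per
// multiply-accumulate.
int64_t MatMulFlops(int64_t m, int64_t k, int64_t n) { return 2 * m * k * n; }

// Bytes for a dense tensor: product of dimensions times element size.
// Unknown dimensions (reported as <= 0) are treated as size one, as the
// header's comments describe for CalculateTensorSize.
int64_t TensorBytes(const std::vector<int64_t>& dims,
                    int64_t bytes_per_element) {
  int64_t elements = 1;
  for (int64_t d : dims) elements *= (d > 0 ? d : 1);
  return elements * bytes_per_element;
}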
+ static OpContext FusedChildContext( + const OpContext& parent, const string& op_name, + const OpInfo::TensorProperties& output, + const std::vector& inputs); + + // Helper to construct tensor shapes. + static OpInfo::TensorProperties DescribeTensor( + DataType type, const std::vector& dims); + + // Helper method for building common case NodeCosts. + static absl::Status PredictDefaultNodeCosts(int64_t num_compute_ops, + const OpContext& op_context, + bool* found_unknown_shapes, + NodeCosts* node_costs); + + protected: + std::map elementwise_ops_; + typedef std::function + CostImpl; + std::map device_cost_impl_; + // If true, assume compute and memory overlap; hence, the op cost is max of + // compute_time and memory_time, instead of sum of those two. + bool compute_memory_overlap_; + std::set persistent_ops_; + + private: + friend class OpLevelCostEstimatorTest; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/robust_stats.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/robust_stats.h new file mode 100644 index 00000000..f11e608c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/robust_stats.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ROBUST_STATS_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_ROBUST_STATS_H_ + +#include +namespace tensorflow { +namespace grappler { +class RobustStats { + public: + explicit RobustStats(const std::vector& values); + explicit RobustStats(std::vector&& values); + + double lo() const { return lo_; } + double hi() const { return hi_; } + double mean() const { return mean_; } + + private: + void HuberMAD(const std::vector& values); + + double lo_; + double hi_; + double mean_; + double stddev_; +}; +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_ROBUST_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/utils.h new file mode 100644 index 00000000..94f5c240 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/utils.h @@ -0,0 +1,132 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_UTILS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" + +namespace tensorflow { +namespace grappler { + +// Returns a vector of InputProperties for 'node'. The vector will contain one +// entry for each input of 'node'. +// For each node in the graph, the 'name_to_cost' map stores a pointer to the +// corresponding cost graph node indexed by node name. The 'name_to_node' maps a +// node name to its node definition. +std::vector FindInputFeatures( + const NodeDef& node, + const std::unordered_map& name_to_cost, + const std::unordered_map& name_to_node); + +// Returns the size of tensor (unit: bytes). For tensor shape with unknown rank, +// it assumes the tensor to be scalar. For any unknown dimension, it assumes +// size one. +int64_t CalculateTensorSize(const OpInfo::TensorProperties& prop); + +// Returns the size of output at port_num (unit: bytes). A special case is +// port_num -1, which is for control dependency and assumed to be 4 bytes. +int64_t CalculateOutputSize( + const std::vector& output_properties, + int port_num); + +// Returns the DeviceProperties of the device on which 'node' runs. +DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node); +DeviceProperties GetDeviceInfo(const string& device_str); + +// Return a string describing a node given a nodeinfo. +string GetOpDescription(const OpInfo& op_info); + +// Builds the OpInfo for node without filling its device information, given all +// nodes in the graph and its input properties. +OpInfo BuildOpInfoWithoutDevice( + const NodeDef& node, + const std::unordered_map& name_to_node, + const std::vector& inputs); + +// Gather performance data from a cost graph. +OpPerformanceList CostGraphToOpPerformanceData(const CostGraphDef& cost_graph, + const GraphDef& graph); + +// Simple histogram for profiling Tensor size; histogram uses logarithmic +// buckets. +class TensorSizeHistogram { + public: + TensorSizeHistogram() : buckets_(kMaxBuckets, 0) {} + + void Add(const uint64 value); + void Merge(const TensorSizeHistogram& src); + double Average() const { + if (num_elem_ > 0) { + return static_cast(sum_elem_) / num_elem_; + } else { + return 0.0; + } + } + uint64 Min() const { return min_; } + uint64 Max() const { return max_; } + uint64 NumElem() const { return num_elem_; } + uint64 SumElem() const { return sum_elem_; } + string ToString() const; + + protected: + const int Index(const uint64 value) const; + const std::vector& GetBuckets() const { return buckets_; } + + private: + const int kMaxBuckets = 64; + uint64 num_elem_ = 0; + uint64 sum_elem_ = 0; + // min_ and max_ are initialized to a very large value and zero, respectively, + // so that any value added can replace the initial min_ and max_. + uint64 min_ = kuint64max; + uint64 max_ = 0; + // Buckets are logarithmic: + // 0B, 1B, 2-3B, 4-7B, 8-15B, ..., 2^N - 2^(N+1)-1B, ... 
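A standalone sketch of the logarithmic bucketing described in the comment above (value 0 in bucket 0, 1 in bucket 1, 2-3 in bucket 2, 4-7 in bucket 3, and so on). It illustrates the scheme only; it is not the vendored Index() implementation.

// Standalone sketch of logarithmic histogram bucketing.
#include <cstdint>
#include <vector>

int LogBucketIndex(uint64_t value) {
  int index = 0;
  while (value > 0) {  // counts the number of significant bits
    value >>= 1;
    ++index;
  }
  return index;  // 0 -> 0, 1 -> 1, 2..3 -> 2, 4..7 -> 3, ...
}

void AddToHistogram(std::vector<uint64_t>* buckets, uint64_t value) {
  size_t index = static_cast<size_t>(LogBucketIndex(value));
  if (index >= buckets->size()) buckets->resize(index + 1);
  ++(*buckets)[index];
}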
+ std::vector buckets_; +}; + +// Helper functions for aggregating per-device stats into per-device-class +// stats. +string GetDeviceClassForNonChannelDevice(const string& device_name); +string GetDeviceClass(const string& device_name); + +// Get stats in string format from RunMetadata. +string GetStatsStringFromRunMetadata(const RunMetadata& run_metadata, + bool verbosity); + +// This method calculates the execution time depending on whether IO can +// overlap with computation. It assumes the memory and the compute times have +// already been calculated. +void CombineCostsAndUpdateExecutionTime(bool compute_memory_overlap, + Costs* costs); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/virtual_placer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/virtual_placer.h new file mode 100644 index 00000000..5f6119ed --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/virtual_placer.h @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_ + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" + +namespace tensorflow { +class NodeDef; + +namespace grappler { +class Cluster; + +// The virtual placer emulates the behavior of the TF placer. +class VirtualPlacer { + public: + explicit VirtualPlacer( + const std::unordered_map& devices); + + const DeviceProperties& get_device(const NodeDef& node) const; + + // Returns device name from cluster, which best matches the node.device() + // specification. Returns default device if no match was found or the + // node.device() could not be parsed. + string get_canonical_device_name(const NodeDef& node) const; + + private: + // Converts given device name to Lowercase Fully-Qualified Name (LFQN) string. + // This helps us disambiguate device names internally and simplify matching. + // If device_name couldn't be parsed successfully, returns empty string. + string to_lfqn_or_empty(const string& device_name) const; + + // Map based on the cluster info: cluster device name -> device properties. + std::unordered_map devices_; + + // Maps LFQN to original device name as it was declared in cluster. 
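Referring back to CombineCostsAndUpdateExecutionTime in utils.h above (and the compute_memory_overlap_ flag in the cost estimator), here is a minimal sketch of the combination rule with durations simplified to microsecond integers. It illustrates the stated semantics, not the vendored function.

// Sketch: with overlap, the slower of compute and memory dominates; without
// overlap, the two times add.
#include <algorithm>
#include <cstdint>

int64_t CombineExecutionTime(bool compute_memory_overlap, int64_t compute_us,
                             int64_t memory_us) {
  return compute_memory_overlap ? std::max(compute_us, memory_us)
                                : compute_us + memory_us;
}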
+ std::unordered_map lfqn_map_; + + string default_device_name_; + string default_job_name_lowercase_; +}; + +} // namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/costs/virtual_scheduler.h b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/virtual_scheduler.h new file mode 100644 index 00000000..f574832b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -0,0 +1,543 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ +#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/costs/op_context.h" +#include "tensorflow/core/grappler/costs/virtual_placer.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + +ABSL_CONST_INIT extern const char kAttrInputSrc[]; +ABSL_CONST_INIT extern const char kAttrSrcDevice[]; +ABSL_CONST_INIT extern const char kAttrDstDevice[]; +ABSL_CONST_INIT extern const char kAttrTensorName[]; +ABSL_CONST_INIT extern const char kChannelDevice[]; +ABSL_CONST_INIT extern const char kStreaming[]; + +struct NodeState { + // A node (i.e., an op) takes a set of input:port pairs and produces + // a set of output ports. + + // Cross references to input and output nodes from graphdef. + std::vector> inputs; // Input, port pairs. + // List of output nodes (a list of nodes that takes this output port as input) + // keyed by port_num. Note that port_num -1 is used for control dependency. + std::unordered_map> outputs; + + // Info from GraphProperties. + std::vector input_properties; + std::vector output_properties; + + // Canonical device name used within VirtualScheduler. + string device_name; + + // States updated as scheduling nodes. + int num_inputs_ready; + std::unordered_map num_outputs_executed; + Costs::Duration time_ready; + Costs::Duration time_scheduled; + Costs::Duration time_finished; + // Time that all the consumers are executed (hence, no need to keep this + // output in memory), keyed by port_num. + std::unordered_map time_no_references; + + // Note that a node may have multiple output ports. The length of outputs, + // num_outputs_executed, and time_no_references should be + // identical when a NodeState is fully initialized. + // They should be 1 + output_properties.size() as we add [-1] for control + // dependency. 
+ + // Node will be ready to be executed at time_ready, scheduled at + // time_scheduled, and finishes execution at time_finished. + // Each output port uses up memory space from time_scheduled to its + // time_no_references. + + Costs node_costs; // Node costs per execution + Costs TotalNodeCosts() const { + return MultiplyCosts(node_costs, execution_count); + } + // How many times this node has been executed, e.g. in a while loop. + int execution_count; + + // Output shape incompatible between shape annotation and shape inference. + bool shape_incompatible; + + NodeState() { + num_inputs_ready = 0; + time_ready = Costs::Duration::max(); + time_scheduled = Costs::Duration::max(); + time_finished = Costs::Duration::max(); + execution_count = 0; + shape_incompatible = false; + // Note that num_outputs_executed and time_no_references are not initialized + // here, since we don't know the size (i.e., # outputs for this node). + } +}; + +struct DeviceState { + // Nodes executed on this device in execution order. + std::vector nodes_executed; + + struct NodePairHash { + public: + const std::size_t operator()( + const std::pair& element) const { + return std::hash()(element.first); + } + }; + + // Nodes currently allocated in memory: set of NodeDef* and port_num pairs + // so that we can track which output of the node is in memory. + std::unordered_set, NodePairHash> + nodes_in_memory; + + // Nodes allocated in memory persistently: e.g., Variables. + std::unordered_set, NodePairHash> + persistent_nodes; + + // Snapshot of nodes_in_memory, when memory usage is at peak. + // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs. + std::unordered_set, NodePairHash> + mem_usage_snapshot_at_peak; + + // Vector of temporary memory usage trace in execution order. + // Each pair represents the current node name and current (accumulated) + // temporary memory usage of the device when the node is scheduled. + // Only enabled when mem_usage_tracking is enabled. + // Note: CPU uses an inter-op threadpool, so the execution order on CPU may + // not be deterministic. + std::vector> temporary_memory_usage_trace; + + Costs device_costs; + std::map op_to_cost; // Per-op cost. + + int64_t memory_usage; // Current temporary memory usage + int64_t max_memory_usage; // Max temporary memory usage + + // Shape annotation statistics. + struct ShapeAnnotationStats { + // Number of ops with shape annotated. + int64_t num_ops_annotated = 0; + // Number of ops executed multiple times (e.g. in a loop). + int64_t num_ops_executed_more_than_once = 0; + // Number of ops executed: account for execution count. + int64_t num_ops_executed = 0; + // Number of ops with dynamic shapes (e.g. shape changes in a loop). + int64_t num_ops_with_dynamic_shapes = 0; + // Number of ops with incompatible shapes between annotation and shape + // inference. + int64_t num_ops_with_incompatible_shapes = 0; + } shape_annotation_stats; + + DeviceState() { + device_costs = Costs::ZeroCosts(); + device_costs.num_ops_total = 0; + memory_usage = 0; + max_memory_usage = 0; + } + + Costs::Duration GetCurrTime() const { return device_costs.execution_time; } +}; + +// ReadyNodeManager (abstract class): +// Keeps ready nodes and picks the best one to be scheduled. 
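A standalone sketch of the temporary-memory bookkeeping that DeviceState's memory_usage and max_memory_usage fields above describe: an allocation grows the current usage and updates the peak, and a release shrinks it once the last consumer has run. The type here is a simplified stand-in.

// Sketch of current vs. peak temporary memory tracking.
#include <algorithm>
#include <cstdint>

struct MemoryTracker {
  int64_t memory_usage = 0;      // current temporary memory
  int64_t max_memory_usage = 0;  // peak temporary memory

  void Allocate(int64_t bytes) {
    memory_usage += bytes;
    max_memory_usage = std::max(max_memory_usage, memory_usage);
  }
  void Release(int64_t bytes) { memory_usage -= bytes; }
};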
+class ReadyNodeManager { + public: + ReadyNodeManager() {} + virtual ~ReadyNodeManager() {} + virtual absl::Status Init( + const std::unordered_map* node_map) { + return absl::OkStatus(); + } + virtual void AddNode(const NodeDef* node) = 0; + virtual const NodeDef* GetCurrNode() = 0; + virtual void RemoveCurrNode() = 0; + virtual bool Empty() const = 0; +}; + +class FIFOManager : public ReadyNodeManager { + public: + FIFOManager() : ReadyNodeManager() {} + ~FIFOManager() override {} + void AddNode(const NodeDef* node) override { nodes_.push_back(node); } + const NodeDef* GetCurrNode() override { + CHECK(!nodes_.empty()) << "GetCurrNode(), but there's no ready node"; + return nodes_.front(); + } + void RemoveCurrNode() override { nodes_.pop_front(); } + bool Empty() const override { return nodes_.empty(); } + + private: + std::list nodes_; +}; + +// The LIFOManager schedules nodes by returning the last one added to the +// scheduler. A node is executed and then its ready outputs are newly added to +// the scheduler, so the LIFOManager will return outputs to a node following +// that node's execution. +class LIFOManager : public ReadyNodeManager { + public: + LIFOManager() : ReadyNodeManager() {} + ~LIFOManager() override {} + void AddNode(const NodeDef* node) override; + const NodeDef* GetCurrNode() override; + void RemoveCurrNode() override; + bool Empty() const override { return nodes_.empty(); } + + private: + std::list nodes_; + // Keep track of the current node being executed by saving its position. + // Necessary because nodes may be added to the end of the list while a node is + // executing, and we want to remove the correct node (the one that is + // executing) rather than the new ones being added. + std::list::iterator curr_pos_ = nodes_.end(); +}; + +// Abstract class that maintains a heap/priority queue for scheduling ready +// nodes. Derived class needs to implement the Greater() function which returns +// the comparator for the heap. +class HeapReadyManager : public ReadyNodeManager { + public: + HeapReadyManager(); + absl::Status Init( + const std::unordered_map* node_map) override; + ~HeapReadyManager() override {} + void AddNode(const NodeDef* node) override; + const NodeDef* GetCurrNode() override; + void RemoveCurrNode() override; + bool Empty() const override; + + protected: + virtual std::function Greater() = 0; + + // nodes_ is the main queue, where we construct heap, and the front is the + // current node. + std::vector nodes_; + + // Comparator functor for heap; stl heap is max heap, so we use "greater than" + // functor for keeping the smallest time_ready node at the front of heap. + std::function greater_; + + // NodeState structure from SchedulerState to get time_ready of ready nodes. + // Not owned by FirstReadyManager. + const std::unordered_map* node_map_; + + // Cached curr node. Set back to nullptr from RemoveCurrNode(). + const NodeDef* curr_node_; +}; + +// FirstReadyManager picks a node with the minimum time_ready value. +// Behavior is deterministic when there are more than one nodes with the minimum +// time_ready value with unique node names as the tie-breaker. +class FirstReadyManager : public HeapReadyManager { + public: + FirstReadyManager() : HeapReadyManager() {} + ~FirstReadyManager() override {} + + protected: + std::function Greater() override; +}; + +// PriorityReadyManager uses the given node priorities when picking up next node +// from all the ready nodes. 
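A standalone sketch of the heap pattern HeapReadyManager documents above: std::push_heap and std::pop_heap build a max-heap, so a greater-than comparator keeps the node with the smallest ready time at the front, with the node name as a deterministic tie-breaker. Node here is a stand-in type, not the TensorFlow NodeDef/NodeState pair.

// Sketch of a min-front heap built from the standard max-heap algorithms.
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

struct Node {
  std::string name;
  int64_t time_ready_us;
};

// "Greater than" comparator; ties broken by name for a deterministic order.
bool Greater(const Node* a, const Node* b) {
  if (a->time_ready_us != b->time_ready_us)
    return a->time_ready_us > b->time_ready_us;
  return a->name > b->name;
}

void AddNode(std::vector<const Node*>& heap, const Node* n) {
  heap.push_back(n);
  std::push_heap(heap.begin(), heap.end(), Greater);
}

const Node* PopEarliestReady(std::vector<const Node*>& heap) {
  std::pop_heap(heap.begin(), heap.end(), Greater);  // moves the min to the back
  const Node* n = heap.back();
  heap.pop_back();
  return n;
}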
+class PriorityReadyManager : public HeapReadyManager { + public: + PriorityReadyManager() : HeapReadyManager() {} + ~PriorityReadyManager() override {} + void AddNode(const NodeDef* node) override; + + // Note this should be called after Init(). + absl::Status SetPriority( + const std::unordered_map& node_priority); + + protected: + std::function Greater() override; + + private: + // A map from unique node name to priority. Lower number means higher + // priority. + std::unordered_map node_priority_; +}; + +// CompositeNodeManager has a few other NodeManagers: per-device LIFO for normal +// ops (neither _Send nor _Recv) and FirstReadyManagers for _Send ops and _Recv +// ops, and then it chooses FirstReady among the ops chosen from each +// internal NodeManagers. The objective is to maximize producer-consumer +// locality within device, while processing nodes across devices, including +// _Send and _Recv, fairly, in terms of their time_ready. +class CompositeNodeManager : public ReadyNodeManager { + public: + CompositeNodeManager(); + ~CompositeNodeManager() override {} + + absl::Status Init( + const std::unordered_map* node_map) override; + void AddNode(const NodeDef* node) override; + const NodeDef* GetCurrNode() override; + void RemoveCurrNode() override; + bool Empty() const override; + + private: + // Internal ready node managers: + // LIFO for normal ops to maximize producer consumer locality. + // One LIFO per device. + std::unordered_map ops_lifo_map_; + // FirstReady for send and recv. Handle send and recv separately ensures that + // send and recv do not block previously read ops with LIFO schedule. + FirstReadyManager send_manager_; + FirstReadyManager recv_manager_; + + // NodeState structure from SchedulerState to get time_ready of ready nodes. + // Not owned by CompositeReadyManager. + const std::unordered_map* node_map_; + + // Cached curr node. Set back to nullptr from RemoveCurrNode(). + const NodeDef* curr_node_; +}; + +// Constructs a ready node manager from the given string. +std::unique_ptr ReadyNodeManagerFactory( + const string& ready_node_manager); + +// Encapsulates all of the various pieces uses to track state of a scheduler; +// enables reuse of all scheduler state-related utilities across different +// scheduler implementations. +class SchedulerState { + public: + SchedulerState(const bool use_static_shapes, + const bool use_aggressive_shape_inference, Cluster* cluster, + std::unique_ptr placer); + // Move constructor. Explicitly defined because it otherwise gets implicitly + // deleted. SchedulerState is a move-only class, as we have a + // for it in VirtualScheduler. A derivative of VirtualScheduler can move a + // SchedulerState to VirtualScheduler when it is constructed, + // which is where this move constructor is needed. + SchedulerState(SchedulerState&& arg) = default; + // We explicitly delete assinment and copy operators, this is done implicitly, + // but we state it here explicitly for clarity. + SchedulerState& operator=(SchedulerState&& arg) = delete; + SchedulerState(const SchedulerState&) = delete; + SchedulerState& operator=(const SchedulerState&) = delete; + // Destructor. Must be defined such that a derivative class can override it + // and allow proper desctruction of the derivative class. If this is not done + // properly, memory leaks can occur. 
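A minimal illustration of the destructor comment above: deleting a derived object through a base pointer only runs the derived destructor when the base destructor is virtual; otherwise whatever the derived part owns leaks. The types are illustrative, not TensorFlow classes.

// Sketch: virtual base destructor so derived state is released through a
// base pointer.
#include <memory>

struct BaseState {
  virtual ~BaseState() = default;  // virtual: derived destructors will run
};

struct DerivedState : BaseState {
  std::unique_ptr<int[]> big_buffer = std::make_unique<int[]>(1 << 20);
  // ~DerivedState() releases big_buffer, but it is only reached through a
  // BaseState* because ~BaseState() is virtual.
};

void Demo() {
  std::unique_ptr<BaseState> s = std::make_unique<DerivedState>();
  // When s goes out of scope, ~BaseState() dispatches to ~DerivedState().
}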
+ virtual ~SchedulerState(); + // Sets up the graph while also performing some necessary transformations + // initial_nodes is the set of nodes (primary inputs) discovered by Init() + // which may be added by a ReadyNodeManager (or related/derivative scheduler) + // to begin node schedule and graph simulation. + absl::Status Init(const GrapplerItem* item, + std::vector* initial_nodes, + bool create_explicit_channel_device = true); + + virtual Costs Summary() const; + // Like the above, but writes detailed stats to RunMetadata. + // If metadata is nullptr, then just calls and return Summary(). + virtual Costs Summary(RunMetadata* metadata); + // Generates RunMetadata's step_stats and partition_graphs fields from results + // of the virtual execution of the graph. + // TODO(rdegruijl) See if we can make this function and caller Summary() + // const. + void GenerateRunMetadata(RunMetadata* metadata); + + // Returns per device memory usage. + const std::unordered_map GetPeakMemoryUsage() const; + const std::unordered_map GetPersistentMemoryUsage() const; + void enable_mem_usage_tracking() { track_mem_usage_snapshot_ = true; } + // Returns (read only) device and node states. + const std::unordered_map* GetDeviceStates() const { + return &device_; + } + + const std::unordered_map* GetNodeStates() const { + return &node_map_; + } + + virtual OpContext CreateOpContext(const NodeDef* node) const; + std::vector MarkNodeExecuted( + const NodeDef* node, const Costs& node_costs, const OpContext& op_context, + bool extract_execution_count_attr = true, + const std::string& override_device_name = ""); + + // Some getter functions. + const GrapplerItem* GetGrapplerItem() { return grappler_item_; } + Costs GetGraphCost() { return graph_costs_; } + Cluster* GetCluster() { return cluster_; } + bool GetUseStaticShape() { return use_static_shapes_; } + bool GetUseAggressiveShapeInference() { + return use_aggressive_shape_inference_; + } + const std::unordered_map& GetNodeMap() { + return node_map_; + } + + protected: + // Assigns the time_scheduled in the NodeState of node to the current + // execution_time of the device executing this node. + void SetNodeStateTimeScheduled(const NodeDef* node); + + // This method can be used by a class derived from SchedulerState to + // access the device state map. + std::unordered_map* GetMutableDeviceState() { + return &device_; + } + + private: + // Methods called from Init(). Fails if initialize_ is set. + + void MaybeUpdateInputOutput(const NodeDef* node); + NodeState& GetNodeStateOrCreateIt(const NodeDef* node); + // Creates a Send_ and Recv_ pair between from and to. The argument + // create_channel_device tells the function to create an explicit device for + // the channel. + std::pair CreateSendRecv( + const NodeDef* from, const NodeDef* to, const NodeDef* input_node, + const string& input_name, bool create_channel_device); + string DeviceName(const NodeDef* node) const; + string SanitizedDeviceName(const NodeDef* node) const; + string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const; + + // Helper methods. + void GetOutputNodes(const NodeDef* node, const Costs::Duration& curr_time, + std::vector* output_nodes); + // Retrieves output size from node_cost at a port_num. If the output size has + // not been set, defaults back to CalculateOutputSize. 
+ int64_t GetOrCalculateOutputSize(const NodeState& node_state, + int port_num) const; + + std::unordered_map node_map_; + std::unordered_map device_; + + // Pool of NodeDefs for SendRecv and Identity ops created. + std::vector> additional_nodes_; + + // Stats: + // Op counts with key with input shape. + // Example key: "[Op=AssignSub, input_shapes=[[7,1,160,160][7,1,160,160]]" + std::map op_counts_; + // Individual op costs with key with input shape. + // Integer field for execution time in micro seconds. + // Boolean field for whether the cost is accurate. + std::map> op_costs_; + + Costs graph_costs_; // Graph cost. + std::map op_to_cost_; // Per-op cost. + + // Auxiliary data structures for constructing NodeState and DeviceState. + std::unique_ptr graph_properties_; // Initialized in Init(). + Cluster* cluster_; // Not owned. + const GrapplerItem* grappler_item_; // Not owned. + bool use_static_shapes_; + bool initialized_; + bool track_mem_usage_snapshot_; + const bool use_aggressive_shape_inference_; + std::unique_ptr placer_; +}; + +// The virtual scheduler emulates execution of nodes in a graph, considering +// dependencies, device, etc. +class VirtualScheduler { + public: + // Does not take ownership of cluster or ready_nodes. + VirtualScheduler(const bool use_static_shapes, + const bool use_aggressive_shape_inference, Cluster* cluster, + ReadyNodeManager* ready_nodes, + std::unique_ptr placer); + // This constructor can be called by a derivative of VirtualScheduler to + // construct the base class. It lets VirtualScheduler take ownership of + // a new SchedulerState or a derivative thereof. + // Note that this constructor does not set a VirtualPlacer, in this + // constructor the VirtialPlacer is passed as a member of the SchedulerState + // that is passed as an argument. + VirtualScheduler(ReadyNodeManager* ready_nodes, + std::unique_ptr scheduler_state); + virtual ~VirtualScheduler(); + + // Initializes the scheduler for the specific grappler item. + // Should be called immediately after the c'tor or when the scheduler will be + // reused for a new grappler item. All internal states of the scheduler + // related to the previous grappler item will be reset/cleared. + // + // This function should be called at least once after the scheduler is + // constructed. An uninitialized or failed-to-initialize scheduler will cause + // undefined behavior. + virtual absl::Status Init(const GrapplerItem* item); + + // Gets the current scheduled node for execution; the caller of this function + // can accordingly simulate the execution of the current scheduled node. + virtual OpContext GetCurrNode(); + // Marks the current scheduled node as executed. Note that we should call this + // function only after the execution of the node has been simulated; + // node_costs_ capture the simulated costs of the node. + // Returns true if there is any node to be scheduled. + virtual bool MarkCurrNodeExecuted(const Costs& node_costs); + + // Prints out summary of execution (timing, memory usage, etc.) + Costs Summary() const { return scheduler_state_->Summary(); } + // Like the above, but writes detailed stats to RunMetadata. + // If metadata is nullptr, then just calls and return Summary(). + Costs Summary(RunMetadata* metadata) { + return scheduler_state_->Summary(metadata); + } + // Generates RunMetadata's step_stats and partition_graphs fields from results + // of the virtual execution of the graph. 
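A hedged usage sketch grounded in the declarations above: initialize the scheduler with a GrapplerItem, then repeatedly fetch the current node, cost it with an OpLevelCostEstimator, and mark it executed until MarkCurrNodeExecuted reports no more nodes. Construction of the scheduler (cluster, ready-node manager, placer) is assumed to have happened elsewhere, and the exact call sequence should be checked against the implementation.

// Usage sketch only; not part of the vendored headers.
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"

namespace tensorflow {
namespace grappler {

absl::Status SimulateGraph(const GrapplerItem& item,
                           const OpLevelCostEstimator& estimator,
                           VirtualScheduler* scheduler, Costs* total) {
  absl::Status status = scheduler->Init(&item);
  if (!status.ok()) return status;
  bool more_nodes = true;
  while (more_nodes) {
    OpContext op_context = scheduler->GetCurrNode();
    Costs node_costs = estimator.PredictCosts(op_context);
    more_nodes = scheduler->MarkCurrNodeExecuted(node_costs);
  }
  *total = scheduler->Summary();
  return absl::OkStatus();
}

}  // namespace grappler
}  // namespace tensorflow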
+ void GenerateRunMetadata(RunMetadata* metadata) { + scheduler_state_->GenerateRunMetadata(metadata); + } + // Returns per device memory usage. + const std::unordered_map GetPeakMemoryUsage() const { + return scheduler_state_->GetPeakMemoryUsage(); + } + const std::unordered_map GetPersistentMemoryUsage() const { + return scheduler_state_->GetPersistentMemoryUsage(); + } + // Returns VirtualScheduler (read only) device and node states. + const std::unordered_map* GetDeviceStates() const { + return scheduler_state_->GetDeviceStates(); + } + const std::unordered_map* GetNodeStates() const { + return scheduler_state_->GetNodeStates(); + } + void enable_mem_usage_tracking() { + scheduler_state_->enable_mem_usage_tracking(); + } + + protected: + // The state of the scheduler and the execution of the graph is encapsulated + // by the scheduler_state_ object. + std::unique_ptr scheduler_state_; + // ready_nodes_ is responsible for ordering the traversal of the graph. + ReadyNodeManager* ready_nodes_; // Not owned. +}; + +} // namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_SCHEDULER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/devices.h b/third_party/tflite-hdrs/tensorflow/core/grappler/devices.h new file mode 100644 index 00000000..8a27bfac --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/devices.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_DEVICES_H_ +#define TENSORFLOW_CORE_GRAPPLER_DEVICES_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +// Get the number of available GPUs whose number of multiprocessors is no less +// than 8 and whose CUDA compute capability is no less than +// min_cuda_compute_capability. +int GetNumAvailableGPUs( + const std::pair& min_cuda_compute_capability = {0, 0}); + +// Maximum amount of gpu memory available per gpu. gpu_id must be in the range +// [0, num_available_gpu) +int64_t AvailableGPUMemory(int gpu_id); + +// Get the number of logical CPU cores (aka hyperthreads) available. +int GetNumAvailableLogicalCPUCores(); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_DEVICES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/gen_node.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/gen_node.h new file mode 100644 index 00000000..e47e2d94 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/gen_node.h @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +class GenNode; + +// To find nodes by name. +using GenNodeMap = std::unordered_map>; + +// One node in the graph, in the form convenient for traversal and generation of +// subgraphs. It refers to the original NodeDef protobuf for most information +// and adds the extra enrichment. +// +// The graph building is 2-stage: first match a GenNode with each NodeDef and +// collect them into a map that finds them by name, then process the map, +// deep-parse the underlying NodeDefs and connect the GenNodes together. +class GenNode { + public: + // Will keep the pointer, so the underlying object must not be deleted while + // GenNode is alive. + explicit GenNode(const NodeDef* node); + + // Access wrappers. + const string& name() const { return node_->name(); } + const string& opcode() const { return node_->op(); } + const NodeDef* node_def() const { return node_; } + + // Parse the inputs of this node and update the map accordingly, creating the + // links (i.e. edges, connections between nodes) in itself and in the nodes + // it's linked to (the map itself is unchanged, only the nodes in it are + // updated). + absl::Status ParseInputs(const GenNodeMap* map); + + // Does the full 2-stage build of the graph. The map should be initially + // empty. The map keeps pointers to the nodes in source, so the source must + // not be destroyed before the map. + static absl::Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map); + + // The enrichment that constitutes the point of this class. + + // Representation of a connection on a node. + class Port { + public: + // A port may be inbound or outbound. + // Negative ids (canonically -1) mean a control port. + Port(bool inbound, int32_t id) : value_(id << 1) { + if (inbound) { + value_ |= 1; + } + } + Port(const Port&) = default; + Port& operator=(const Port&) = default; + + bool IsInbound() const { return (value_ & 0x1); } + + bool IsControl() const { return (value_ < 0); } + + int32_t Id() const { + // Arithmetic shift preserves the sign. + return (value_ >> 1); + } + + // Integer type used to represent the encoded port value. + using IntPort = int32_t; + + // Returns the encoded form of this port, so that it can be used + // as various map indexes. 
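A standalone sketch of the port encoding shown above: the id is shifted left one bit, the low bit stores the inbound flag, and arithmetic right shift restores negative (control) ids on decode. Like the Port class itself, it relies on two's-complement shift behavior.

// Sketch of the GenNode::Port bit encoding and its round trip.
#include <cassert>
#include <cstdint>

int32_t EncodePort(bool inbound, int32_t id) {
  int32_t value = id << 1;   // id in the high bits
  if (inbound) value |= 1;   // inbound flag in the low bit
  return value;
}

int main() {
  int32_t control_in = EncodePort(/*inbound=*/true, /*id=*/-1);
  assert((control_in & 1) == 1);    // inbound bit set
  assert(control_in < 0);           // control ports stay negative
  assert((control_in >> 1) == -1);  // arithmetic shift restores the id
  int32_t data_out = EncodePort(/*inbound=*/false, /*id=*/3);
  assert((data_out >> 1) == 3 && (data_out & 1) == 0);
  return 0;
}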
+ IntPort Encoded() const { return value_; } + + static Port Decode(IntPort encoded) { return Port(encoded); } + + bool operator==(const Port& other) const { return value_ == other.value_; } + bool operator<(const Port& other) const { return value_ < other.value_; } + + struct Hasher { + size_t operator()(const Port& port) const noexcept { + return hasher(port.Encoded()); + } + std::hash hasher; + }; + + // Convenient for printing. I've really wanted it to be implicit but + // ClangTidy insists on making it explicit. + explicit operator string() const; + + private: + explicit Port(IntPort value) : value_(value) {} + + IntPort value_; + }; + + struct LinkTarget { + GenNode* node; // Node where this link points. + Port port; // Port on the remote side of this link. + + LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {} + }; + // All the links that are connected to the same port of this node + // are collected in one vector. A link is an edge of the graph that connects + // 2 nodes. Each of the connected nodes has its own perspective on the link, + // seeing its local port, remote port and the remote node. The direction of + // the link is encoded in the ports, one port is always incoming and another + // one outgoing. + using LinkTargetVector = std::vector; + // Both inputs and outputs are stored in the same map. + using LinkMap = std::unordered_map; + + // Access to the link map. + const LinkMap& links() const { return links_; } + + // Check whether the port is an input (including the controls) with multiple + // connections. Such inputs get handled in a special way when building the + // subgraphs, in an "all or nothing" fashion. + bool IsMultiInput(Port port) const; + + // When building the subgraphs, must include either all non-control inputs of + // this node into the subgraph or none of them. This happens when at least one + // of the inputs is a multi-input (or if the opcode is commutative, thus + // treating all the inputs as one multi-input). + bool AllInputsOrNone() const { return all_inputs_or_none_; } + + private: + const NodeDef* node_; + // Becomes valid only after ParseInputs(). + const OpDef* op_; + + // The opcode has a complicated structure of input args, with multi-input args + // that are not commutative. This means that to make sense, the subgraphs that + // include this node must also include either all its inputs or none of them. + bool all_inputs_or_none_ = false; + + LinkMap links_; +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h new file mode 100644 index 00000000..56828ee1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h @@ -0,0 +1,154 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/graph_analyzer/map_tools.h" +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" +#include "tensorflow/core/grappler/graph_analyzer/subgraph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +namespace test { +class GraphAnalyzerTest; +} // end namespace test + +// Finds all the subgraphs of a given size and groups them by equivalence. +class GraphAnalyzer { + public: + // Makes a copy of the graph. + GraphAnalyzer(const GraphDef& graph, int subgraph_size); + + virtual ~GraphAnalyzer(); + + // Performs the analysis and collects the subgraphs. + absl::Status Run(); + + // Returns the subgraphs found in Run() printed to text. + std::vector DumpSubgraphs(); + + // Prints the subgraphs found in Run() to stdout. + absl::Status OutputSubgraphs(); + + // TODO(babkin): add a way to extract the subgraphs as direct data + // structures and as protobufs, and to write protobufs to a RecordIO. + + private: + GraphAnalyzer() = delete; + GraphAnalyzer(const GraphAnalyzer&) = delete; + void operator=(const GraphAnalyzer&) = delete; + + friend class tensorflow::grappler::graph_analyzer::test::GraphAnalyzerTest; + + // Builds the map of nodes from the original graph definition. + absl::Status BuildMap(); + + // Using nodes_, finds all the subgraphs of size subgraph_size_ and places + // them into result_. + void FindSubgraphs(); + + // Deletes from result_ the unacceptable subgraphs. Those include the + // subgraphs where not all the inputs at a multi-input port are included (this + // could happen if some of these inputs were reached and included through + // different paths). + void DropInvalidSubgraphs(); + + // Deletes from result_ duplicate entries of equivalent topology. + absl::Status CollateResult(); + + // Returns the raw subgraphs found in FindSubgraphs() printed to text. + std::vector DumpRawSubgraphs(); + + // Finds and adds appropriately to either partial_ or result_ all the + // subgraphs that can be created by extending the parent subgraph by one node. + // Ignores the duplicates. + void ExtendSubgraph(Subgraph* parent); + + // Extends the parent subgraph by adding another node (if it wasn't already + // added) and all its non-control inputs in the link map range at once. + // If the subgraph would grow over subgraph_size_, it gets ignored. + void ExtendSubgraphAllOrNone(Subgraph* parent, const GenNode* node); + // Same but adds one specific inbound port (even control) all-or-none. + void ExtendSubgraphPortAllOrNone(Subgraph* parent, const GenNode* node, + GenNode::Port port); + // The common final step called by ExtendSubgraph*AllOrNone() methods. + void AddExtendedSubgraph(Subgraph* parent, const Subgraph::Identity& id); + + // Returns true if this subgraph has any multi-inputs that aren't all-in or + // all-out. + bool HasInvalidMultiInputs(Subgraph* sg); + + // Graph to run the analysis on. + GraphDef graph_; + int subgraph_size_; + + // The enriched graph of parsed nodes and connections. + GenNodeMap nodes_; + // The resulting set of subgraphs. 
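A usage sketch grounded in the public interface above: analyze all subgraphs of a fixed size in a GraphDef and print the collated result. Error handling is simplified, and the subgraph size of 3 is only an example.

// Usage sketch only; not part of the vendored headers.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"

namespace tensorflow {
namespace grappler {
namespace graph_analyzer {

absl::Status AnalyzeSubgraphsOfSize3(const GraphDef& graph) {
  GraphAnalyzer analyzer(graph, /*subgraph_size=*/3);
  absl::Status status = analyzer.Run();
  if (!status.ok()) return status;
  return analyzer.OutputSubgraphs();  // prints collated subgraphs to stdout
}

}  // namespace graph_analyzer
}  // namespace grappler
}  // namespace tensorflow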
+ SubgraphPtrSet result_; + // The subgraphs of partial size, stored while finding the result. + SubgraphPtrSet partial_; + // The subgraphs of partial size (stored in partial_) that are still waiting + // to be extended. + // + // TODO(babkin): This is rather simple-minded, each subgraph is examined from + // scratch, which means that all its internal links get iterated too. But it's + // OK for the small subgraphs. This can be improved by keeping not just + // subgraphs but iterators on the list, each of them having the list not-yet + // examined nodes (and the link position of the next link to be examined for + // the first node). This would add extra constant overhead, so the break-even + // subgraph size is not clear yet. + std::deque todo_; + + // The collation map by signature is designed to allow the removal of entries + // and moving of the signature references from the keys of this map to the + // outside world. Must be careful at inserting and removal: make sure that + // when a new entry is inserted, its signature reference gets populated with + // the same data as the key of the map, and that if a reference is moved out, + // the map entry gets removed before that reference gets destroyed. + struct CollationEntry { + std::shared_ptr sig; + size_t count = 0; + }; + using CollationMap = + std::unordered_map, + EqAtPtr >; + CollationMap collation_map_; + + // The entries are owned by collation_map_, so must be removed from + // ordered_collation_ before removing them from collation_map_. + struct ReverseLessByCount { + bool operator()(CollationEntry* left, CollationEntry* right) const { + return left->count > right->count; // Reverse order. + } + }; + using CollationOrderByCount = + std::multiset; + CollationOrderByCount ordered_collation_; +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h new file mode 100644 index 00000000..5a91fe7d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_ + +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +void GraphAnalyzerTool(const string& file_name, int n); + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/hash_tools.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/hash_tools.h new file mode 100644 index 00000000..b0e79f9a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/hash_tools.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_ + +#include + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +// Unfortunately, std::hash provides no way to combine hashes, so everyone +// is copying boost::hash_combine. This is a version that follows Google's +// guidelines on the arguments, and contains only the combination, without +// hashing. +inline void CombineHash(size_t from, size_t* to) { + *to ^= from + 0x9e3779b9 + (*to << 6) + (*to >> 2); +} + +// Combine two hashes in such a way that the order of combination doesn't matter +// (so it's really both commutative and associative). The result is not a very +// high-quality hash but can be used in case if the order of sub-elements must +// not matter in the following comparison. An alternative would be to sort the +// hashes of the sub-elements and then combine them normally in the sorted +// order. +inline void CombineHashCommutative(size_t from, size_t* to) { + *to = *to + from + 0x9e3779b9; +} + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/map_tools.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/map_tools.h new file mode 100644 index 00000000..f380504a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/map_tools.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
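A standalone sketch using the two combiners from hash_tools.h above, restated locally so the example compiles on its own: CombineHash is order-sensitive, CombineHashCommutative is not.

// Sketch of order-sensitive vs. order-insensitive hash combination.
#include <cstddef>
#include <functional>

inline void CombineHash(size_t from, size_t* to) {
  *to ^= from + 0x9e3779b9 + (*to << 6) + (*to >> 2);
}
inline void CombineHashCommutative(size_t from, size_t* to) {
  *to = *to + from + 0x9e3779b9;
}

size_t HashPair(int local, int remote) {
  std::hash<int> h;
  size_t hval = h(local);
  CombineHash(h(remote), &hval);  // order matters: (a, b) != (b, a)
  return hval;
}

size_t HashUnorderedTriple(int a, int b, int c) {
  std::hash<int> h;
  size_t hval = 0;
  int vals[] = {a, b, c};
  for (int v : vals) CombineHashCommutative(h(v), &hval);
  return hval;  // same value for any permutation of a, b, c
}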
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_ + +#include + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +// Helpers for building maps of pointers. + +template +struct LessAtPtr : std::function { + bool operator()(const Ptr& x, const Ptr& y) const { return *x < *y; } +}; + +template +struct EqAtPtr : std::function { + bool operator()(const Ptr& x, const Ptr& y) const { return *x == *y; } +}; + +template +struct HashAtPtr : std::function { + size_t operator()(const Ptr& x) const { return x->Hash(); } +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/sig_node.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/sig_node.h new file mode 100644 index 00000000..6e6749b4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/sig_node.h @@ -0,0 +1,304 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" +#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +namespace test { +class SigBaseTest; +} // end namespace test + +class SigNode; + +// To find nodes by name. Having the map ordered makes the tests easier, +// and it isn't used in production code often enough to get any win from +// using an unordered map. +using SigNodeMap = std::map>; + +// One node in the graph, in the form convenient for generation of the signature +// of the graph, and comparison of two (sub)graphs for equivalence. It refers to +// the original NodeDef protobuf for most information and adds the extra +// enrichment. 
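A standalone sketch of the compare-through-the-pointer helpers in map_tools.h above (their template parameters appear clipped in the diff text): containers keyed by smart pointers normally order by pointer value, while these functors order by the pointee. LessAtPtr is restated locally so the example is self-contained.

// Sketch: ordering a set of unique_ptr by the pointed-to values.
#include <memory>
#include <set>
#include <string>

template <typename Ptr>
struct LessAtPtr {
  bool operator()(const Ptr& x, const Ptr& y) const { return *x < *y; }
};

int main() {
  // Ordered by string contents, not by heap addresses.
  std::set<std::unique_ptr<std::string>,
           LessAtPtr<std::unique_ptr<std::string>>>
      names;
  names.insert(std::make_unique<std::string>("relu"));
  names.insert(std::make_unique<std::string>("conv2d"));
  // Iteration order: "conv2d", then "relu".
  return 0;
}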
+// +// The graph building is 2-stage: first match a SigNode with each NodeDef and +// collect them into a map that finds them by name, then process the map, +// deep-parse the underlying NodeDefs and connect the SigNodes together. +class SigNode { + public: + friend struct Signature; + + // Will keep the pointer to the underlying NodeDef, so that + // underlying object must not be deleted while SigNode is alive. + explicit SigNode(const NodeDef* node); + + // Access wrappers. + const string& name() const { return node_->name(); } + const string& opcode() const { return node_->op(); } + const NodeDef* node_def() const { return node_; } + + // For extraction of subgraphs into a separate SigNodeMap, copies the links + // that point inside the subgraph from a full-graph SigNode to a subgraph + // SigNode. The translation map defines the subgraph and gives the mapping + // from the nodes in the full graph to the matching nodes in subgraph. + using TranslationMap = + std::unordered_map; + void CopyLinks(const GenNode& from, const TranslationMap& map); + + // A link is an edge of the graph that connects 2 nodes. Each of the connected + // nodes has its own perspective on the link, seeing its local port, remote + // port and the remote node. The direction of the link is encoded in the + // ports, one port is always incoming and another one outgoing. + // + // The link tag here contains both ports of the link viewed from the + // perspective of this node; consisting of both the local port (i.e. at this + // node) and remote port (i.e. on the other node), the local one going first. + struct LinkTag { + struct Hasher { + size_t operator()(const LinkTag& tag) const noexcept { + size_t hval = port_hasher(tag.local); + CombineHash(port_hasher(tag.remote), &hval); + return hval; + } + GenNode::Port::Hasher port_hasher; + }; + + LinkTag(GenNode::Port a_local, GenNode::Port a_remote) + : local(a_local), remote(a_remote) {} + + // The default constructor is used for the default values in maps. + // (false, 99) is an arbitrary value that makes the uninitialized + // links easy to tell when debugging (they should never happen). + LinkTag() : local(false, 99), remote(false, 99) {} + + // Port of the link on the local node. + GenNode::Port local; + // Port of the link on the remote node. + GenNode::Port remote; + + bool operator==(const LinkTag& other) const { + return local == other.local && remote == other.remote; + } + bool operator<(const LinkTag& other) const { + return local < other.local || + (local == other.local && remote < other.remote); + } + }; + + // Since the signature logic doesn't differentiate between the links + // with the same tag (other than by the "peer" nodes on their other ends), + // all the links with the same tag are grouped into a single structure. + struct Link { + LinkTag tag; + size_t unique_hash; // Hash of the tag after conflict resolution. + // The remote node(s) on the other side on the link(s). + using PeerVector = std::vector; + PeerVector peers; + }; + + // A way to look up the link description by its hash. + using LinkHashMap = std::map; + const LinkHashMap& hash_to_link() const { return hash_to_link_; } + + // The enumeration of all the peer nodes in a predictable order. + // Before the signature generation, only the link values determine the + // order, after the signature generation the entries at the same + // links get further sorted by their peer node ranks. 
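+  // For example (illustrative sketch; `node` is a hypothetical SigNode from a
+  // graph whose signature has already been computed):
+  //
+  //   for (const SigNode::HashedPeer& hp : node.hashed_peers()) {
+  //     const SigNode::Link& link = node.hash_to_link().at(hp.link_hash);
+  //     // link.tag holds the local/remote ports; hp.peer is the neighbor.
+  //   }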
+ struct HashedPeer { + HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {} + + struct LessByRank { + bool operator()(const SigNode::HashedPeer& left, + const SigNode::HashedPeer& right) { + return left.peer->unique_rank_ < right.peer->unique_rank_; + } + }; + + size_t link_hash; + SigNode* peer; + }; + using HashedPeerVector = std::vector; + const HashedPeerVector& hashed_peers() const { return hashed_peers_; } + + // Compares two nodes in two different graphs for equivalence (two nodes in + // the same graph would never be equivalent). Expects that the signatures of + // the graphs have already been computed, so unique_rank_ is filled in and + // the hashed_peers_ properly ordered. + bool operator==(const SigNode& other) const; + + bool operator!=(const SigNode& other) const { return !(*this == other); } + + private: + friend class test::SigBaseTest; + + // The CopyLinks code is split into 2 parts for testability. + // The first pass builds a map ordered by LinkTag for predictability. + void CopyLinksPass1(const GenNode& from, const TranslationMap& map, + std::map* link_map); + // The second pass converts to the map by hash value, + // resolves any hash conflicts, and builds the hashed peer vector. + void CopyLinksPass2(std::map* link_map); + + // Computes the topological hash at distance 0. Resets the topo_hash_ vector + // and hashed_nodes_; + void ComputeTopoHash0(); + + // Compute the topological has at the given distance. The hashes for all the + // lower distances must be already computed for all the nodes in the graph. + // Also computes next_hashed_nodes_ from last_hashed_nodes_. + void ComputeTopoHash(int distance); + + // Get the hash value for a particular distance. It must be previously + // computed. + size_t GetTopoHash(int distance) const; + + // The hash value for the highest computed distance. It must be previously + // computed. + size_t GetHighTopoHash() const { + CHECK(!topo_hash_.empty()); + return topo_hash_.back(); + } + + // Rehash the topmost hash, to avoid conflicts. + void ReHighTopoHash() { + CHECK(!topo_hash_.empty()); + CombineHash(1, &topo_hash_.back()); + } + + // Ordering by node order and highest available hash (it must be + // previously computed). + struct NodeOrderLess { + bool operator()(const SigNode* left, const SigNode* right) { + return left->topo_hash_.back() < right->topo_hash_.back(); + } + }; + + private: + const NodeDef* node_; + + // The bitmap mask with 1 bit set that represents this node in the set + // during the computation of the signature. + uint64_t node_mask_ = 0; + + // The code that populates this map makes sure that there are no hash + // conflicts, rehashing if necessary. + LinkHashMap hash_to_link_; + + // The enumeration of all the direct peers in the predictable order (which + // happens to be the order ot their link tags, but the order of the hashes + // would do too). It is used for the quick enumeration during the signature + // computation. After the signature building is completed, the entries that + // have the same link tag get further sorted in the order of the ranks of + // their nodes. + HashedPeerVector hashed_peers_; + + // The unique rank represents the order in which the node will be included + // into the signature. 
It gets assigned in order either when the topo_hash_ of + // this node becomes unique in the graph, or when the nodes are completely + // equivalent, one of them is picked at random to assign the next rank, and + // then the rest of the nodes attempt to disambiguate based on that + // information. + size_t unique_rank_ = ~0; + // When hash_is_final_ is set, the topo_has_ vector stops growing, and the + // last value from it is used for all the further hashes. + bool hash_is_final_ = false; + // The hashes that include the topology of the nodes up to the distance N. The + // hash for distance 0 is produced from the attributes of this node itself and + // its general connectivity properties but no information about the + // neighboring nodes. The hash for distance D+1 is build from hashes at level + // D of this node and of all its immediate neighbors. The neighbors that are + // connected by equivalent links are included in a commutative way. + std::vector topo_hash_; + // The set of nodes that got included into the computation of the + // last topo_hash_ entry. + uint64_t last_hashed_nodes_ = 0; + // The next set of nodes that gets used for the current topo_hash entry. + uint64_t next_hashed_nodes_ = 0; +}; + +// Signature of a graph. The computation is intertwined with the private methods +// of SigNode, so keeping both in the same file looks more convenient. +struct Signature { + friend class test::SigBaseTest; + + // Maximal size of the graphs for which the signature can be computed. + // Changing this constant won't magically add the support for a larger size, + // the rest of implementation would have to be extended. The value of 64 is + // driven by the size of a bitset in an uint64_t, and should be enough for our + // purposes, while having a high efficiency of implementation. + static constexpr int kMaxGraphSize = 64; + + // Using the map, computes the rest of the fields of a signature. + // Returns an error is the graph is too big. + absl::Status Compute(); + + // Convert the computed signature to a string representation. + string ToString() const; + + SigNodeMap map; // The nodes in the graph, accessible by name. + size_t sig_short = 0; // Hash of the signature, for the quick equality check. + // The full signature: hashes of the nodes in a predictable order. + std::vector sig_full; + // The nodes in the same order as they go in the signature. + std::vector nodes; + + // For building the unordered maps. + size_t Hash() const { return sig_short; } + + // Returns true if the graphs are equivalent. The signature must be already + // computed. + bool operator==(const Signature& other) const; + + private: + // Populates the nodes vector from the map and initializes the state of the + // nodes for the signature computation. + void PrepareNodes(); + + // Finds the nodes with the hashes that are unique and assigns the unique ids + // to them. If there are nodes with non-unique hashes, exactly one node from + // the first such sequence (in the order of hash values) will be picked and + // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have + // been already assigned the unique ids. Advances next_node_id by at least 1. + void FindUniqueHashes(size_t* next_node_id_p); + + // One round of the signature computation. Assumes that the + // nodes[0...(next_node_id-1)] have been already assigned the fixed + // positions, and thus computes the hashes only for the remaining nodes. 
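+  // For context, an illustrative sketch of the public API that drives these
+  // rounds (`subgraph` is a hypothetical, already built Subgraph):
+  //
+  //   Signature sig;
+  //   subgraph.ExtractForSignature(&sig.map);
+  //   absl::Status status = sig.Compute();  // fails if > kMaxGraphSize nodes
+  //   if (status.ok()) LOG(INFO) << sig.ToString();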
+ void ComputeOneRound(size_t next_node_id); + + // Additional ordering of the hashed_peers_ links in the nodes, so that they + // can be compared and printed in a predictable order. + void OrderLinks(); +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/subgraph.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/subgraph.h new file mode 100644 index 00000000..7d3494cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/subgraph.h @@ -0,0 +1,190 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ + +#include +#include +#include + +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" +#include "tensorflow/core/grappler/graph_analyzer/map_tools.h" +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { + +// The description of a single subgraph for processing. +class Subgraph { + public: + // Identity of a single subgraph as a set of nodes. + class Identity : public gtl::FlatSet { + public: + using InitializerList = std::initializer_list; + + Identity() = default; + Identity(InitializerList init); + bool operator<(const Identity& other) const; + bool operator==(const Identity& other) const; + + // Compute the hash. + size_t Hash() const; + }; + + explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {} + + // Construct by extending the parent identity with an extra node. + Subgraph(const Identity& parent_id, GenNode* add_node); + + Subgraph() = delete; + Subgraph(const Subgraph& other) = delete; + void operator=(const Subgraph& other) = delete; + + // Order for building sets of subgraphs. + bool operator<(const Subgraph& other) const { return this->id_ < other.id_; } + // Support for hashed sets. + bool operator==(const Subgraph& other) const { + return this->id_ == other.id_; + } + size_t Hash() const { return hash_; } + + // Dump the subgraph information to a string. + string Dump(); + + // Extract this subgraph into a separate graph representation for signature + // building, that includes only the links between the nodes in the subgraph + // and drops all the external links. The result map should be clear before the + // call. 
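+  // Illustrative sketch (`a` and `b` are hypothetical GenNode pointers owned
+  // by the enclosing GenNodeMap):
+  //
+  //   Subgraph sg(Subgraph::Identity({a, b}));
+  //   SigNodeMap sig_nodes;               // must start out empty
+  //   sg.ExtractForSignature(&sig_nodes);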
+ void ExtractForSignature(SigNodeMap* result); + + const Identity& id() const { return id_; } + bool specific() const { return specific_; } + void SetSpecific(bool value) { specific_ = value; } + int32_t collation_count() const { return collation_count_; } + void AddCollation(int32_t n = 1) { collation_count_ += n; } + void ResetCollation() { collation_count_ = 1; } + void MergeCollation(const Subgraph& other) { + collation_count_ += other.collation_count_; + } + + private: + // Identity also serves as the list of nodes. It never changes throughout the + // life of subgraph. + Identity id_; + size_t hash_; // Cached from the identity. + // Whether the dump should include the specific names of the nodes. The + // non-specific (i.e. generic) subgraphs represent a collation of multiple + // subgraphs. + bool specific_ = true; + // How many collated subgraphs are represented by this subgraph. + int32_t collation_count_ = 1; +}; + +// Iteration of all links in a subgraph. This is more like Java iterators than +// the normal C++ iterators. It's simpler this way and there seems to be no +// major reason to make it a proper C++ iterator. +class SubgraphIterator { + public: + // Obviously an iterator is valid only until the original object + // gets destroyed. + explicit SubgraphIterator(const Subgraph::Identity* id); + explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {} + + // Check whether the built-in iterator is at the end. + bool AtEnd() const { return id_it_ == id_->end(); } + + // Get the neighbor at the current iterator. + // MUST NOT be called when AtEnd(); + const GenNode::LinkTarget& GetNeighbor() const { + return link_map_it_->second[link_idx_]; + } + + // Get the node at the current iterator. + // MUST NOT be called when AtEnd(); + const GenNode* GetNode() const { return *id_it_; } + + // Get the port leading to the neighbor at the current iterator. + // MUST NOT be called when AtEnd(); + GenNode::Port GetPort() const { return link_map_it_->first; } + + // Increases the iterator. + // Returns true if NOT AtEnd() after increasing the iterator. + // Safe to call if already AtEnd(). + bool Next(); + + // If there are more links at the same port, increases the iterator and + // returns true. Otherwise leaves the iterator unchanged and returns false. + bool NextIfSamePort(); + + // Increases the iterator directly to the last position on the current port + // (or if already there then doesn't increase). Equivalent to calling + // NextIfSamePort() while it returns true, but faster. + // Safe to call if already AtEnd(). + void SkipPort(); + + // Increases the iterator directly to the last position on the current node. + // Safe to call if already AtEnd(). + void SkipNode(); + + // Returns true if the iterators are exactly the same. + bool operator==(const SubgraphIterator& other) const; + bool operator!=(const SubgraphIterator& other) const { + return !(*this == other); + } + + private: + // After link_idx_ has been increased, make sure that it points to the + // next valid element (or end) by increasing the higher levels of iteration if + // needed. + // Returns true if NOT AtEnd() after increasing the iterator. + // NOT safe to call if already AtEnd(). + bool PropagateNext(); + + // Identity of the subgraph being iterated over. + const Subgraph::Identity* id_; + + // The current position, allowing to iterate through the links (see the + // reasoning for it in the public section). + // + // (1) Iterator of the nodes in the subgraph. 
+ Subgraph::Identity::const_iterator id_it_; + // (2) Iterator in the link map of the node. + GenNode::LinkMap::const_iterator link_map_it_; + // (3) Index in the vector of the links. + int32_t link_idx_; +}; + +// A convenient way to store subgraphs: in a set of unique_ptrs. This way the +// addresses of subgraph objects will stay stable, and the objects themselves +// won't be copied. +class SubgraphPtrSet + : public std::unordered_set, + HashAtPtr>, + EqAtPtr>> { + public: + // Attempts to extend the set by adding a new subgraph that gets created by + // adding one node to the parent subgraph. If such a subgraph already exists, + // returns nullptr, otherwise returns the pointer to the new subgraph. + Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node); +}; + +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/test_tools.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/test_tools.h new file mode 100644 index 00000000..98e269d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_analyzer/test_tools.h @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_ + +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/graph_analyzer/gen_node.h" +#include "tensorflow/core/grappler/graph_analyzer/sig_node.h" +#include "tensorflow/core/grappler/op_types.h" + +namespace tensorflow { +namespace grappler { +namespace graph_analyzer { +namespace test { + +//=== Helper methods to construct the nodes. + +NodeDef MakeNodeConst(const string& name); + +NodeDef MakeNode2Arg(const string& name, const string& opcode, + const string& arg1, const string& arg2); + +NodeDef MakeNode4Arg(const string& name, const string& opcode, + const string& arg1, const string& arg2, const string& arg3, + const string& arg4); + +inline NodeDef MakeNodeMul(const string& name, const string& arg1, + const string& arg2) { + return MakeNode2Arg(name, "Mul", arg1, arg2); +} + +// Not really a 2-argument but convenient to construct. +inline NodeDef MakeNodeAddN(const string& name, const string& arg1, + const string& arg2) { + return MakeNode2Arg(name, "AddN", arg1, arg2); +} + +inline NodeDef MakeNodeSub(const string& name, const string& arg1, + const string& arg2) { + return MakeNode2Arg(name, "Sub", arg1, arg2); +} + +// Has 2 honest outputs. 
+inline NodeDef MakeNodeBroadcastGradientArgs(const string& name, + const string& arg1, + const string& arg2) { + return MakeNode2Arg(name, "BroadcastGradientArgs", arg1, arg2); +} + +NodeDef MakeNodeShapeN(const string& name, const string& arg1, + const string& arg2); + +NodeDef MakeNodeIdentityN(const string& name, const string& arg1, + const string& arg2); + +NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1, + const string& arg2, const string& arg3, + const string& arg4); + +//=== A container of pre-constructed graphs. + +class TestGraphs { + public: + TestGraphs(); + + // Graph with 3 nodes and a control link to self (which is not valid in + // reality but adds excitement to the tests). + GraphDef graph_3n_self_control_; + // Graph that has the multi-input links. + GraphDef graph_multi_input_; + // Graph that has the all-or-none nodes. + GraphDef graph_all_or_none_; + // All the nodes are connected in a circle that goes in one direction. + GraphDef graph_circular_onedir_; + // All the nodes are connected in a circle that goes in both directions. + GraphDef graph_circular_bidir_; + // The nodes are connected in a line. + GraphDef graph_linear_; + // The nodes are connected in a cross shape. + GraphDef graph_cross_; + GraphDef graph_small_cross_; + // For testing the ordering of links at the end of signature generation, + // a variation of a cross. + GraphDef graph_for_link_order_; + // Sun-shaped, a ring with "rays". + GraphDef graph_sun_; +}; + +//=== Helper methods for analysing the structures. + +std::vector DumpLinkMap(const GenNode::LinkMap& link_map); + +// Also checks for the consistency of hash values. +std::vector DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map); + +std::vector DumpHashedPeerVector( + const SigNode::HashedPeerVector& hashed_peers); + +} // end namespace test +} // end namespace graph_analyzer +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_topology_view.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_topology_view.h new file mode 100644 index 00000000..91cbfa2a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_topology_view.h @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/graph_view.h" + +namespace tensorflow { +namespace grappler { + +// GraphTopologyView is a helper class to simplify `node-to-node` connectivity +// traversals. 
Regular `GraphView` simplifies `tensor-to-tensor` traversals: +// connections between output tensors and inputs of a consumer nodes. For the +// topology view we are focused on nodes connected to nodes, and it's irrelevant +// if this connection is formed by one or multiple individual tensors. +// +// Example: +// a = Placeholder(..) +// b = Placeholder(..) +// c = AddN([a, a, b]) +// +// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2] +// GraphTopologyView edges: [a -> c, b -> c] +// +// GraphView is used for exploring single node fanins and fanouts, and +// GraphTopologyView is focused on efficient full graph traversals (computing +// graph node properties from transitive fanouts, etc...). +class GraphTopologyView { + public: + GraphTopologyView() = default; + explicit GraphTopologyView(bool skip_invalid_edges) + : skip_invalid_edges_(skip_invalid_edges) {} + + // Initialize graph topology view from the graph. It's possible to pass + // additional edges that do not exist in a graph, but must be respected when + // computing graph topology. Example: Tensorflow runtime allows concurrent + // execution of dequeue/enqueue ops from the same queue resource, but we might + // want to enforce ordering between them for the purpose of graph analysis. + absl::Status InitializeFromGraph( + const GraphDef& graph, absl::Span ephemeral_edges, + bool ignore_control_edges); + absl::Status InitializeFromGraph( + const GraphDef& graph, absl::Span ephemeral_edges); + absl::Status InitializeFromGraph(const GraphDef& graph, + bool ignore_control_edges); + absl::Status InitializeFromGraph(const GraphDef& graph); + + bool is_initialized() const { return graph_ != nullptr; } + int num_nodes() const { return num_nodes_; } + const GraphDef* graph() const { return graph_; } + + // Returns true iff the node exists in the underlying graph. + bool HasNode(absl::string_view node_name) const; + + // Finds a node by name or returns `nullptr` if it's not in the graph. + const NodeDef* GetNode(absl::string_view node_name) const; + // Returns a node corresponding to the given node index. + const NodeDef* GetNode(int node_idx) const; + + // Returns a node index for the given node name, if the name exists in the + // underlying graph. Otherwise returns empty optional. + const absl::optional GetNodeIndex(absl::string_view node_name) const; + // Returns a node index for the given node, if the node belongs to the + // underlying graph. Otherwise returns empty optional. + const absl::optional GetNodeIndex(const NodeDef& node) const; + + // Returns all the node indexes that are in the direct fanin of the given + // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. + const absl::InlinedVector& GetFanin(int node_idx) const; + // Returns all the node indexes that are in the direct fanout of the given + // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. + const absl::InlinedVector& GetFanout(int node_idx) const; + + private: + // If true, all invalid edges and inputs (srd, dst or input node not found in + // a graph) will be skipped, otherwise initialization will fail with error. + bool skip_invalid_edges_ = false; + + // WARN: `graph_` must outlive this object and graph nodes must not be + // destructed, because node names captured with absl::string_view. 
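+  // Illustrative usage sketch (`graph` is a hypothetical GraphDef that must
+  // outlive `topology`, per the warning above; "my_node" is an arbitrary
+  // node name):
+  //
+  //   GraphTopologyView topology;
+  //   absl::Status s = topology.InitializeFromGraph(graph);
+  //   const auto idx = topology.GetNodeIndex("my_node");
+  //   if (s.ok() && idx.has_value()) {
+  //     for (int fanout : topology.GetFanout(*idx)) {
+  //       const NodeDef* consumer = topology.GetNode(fanout);
+  //     }
+  //   }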
+ const GraphDef* graph_ = nullptr; // do not own + int num_nodes_ = 0; + std::vector index_to_node_name_; + absl::flat_hash_map node_name_to_index_; + std::vector> fanins_; // node_idx->input nodes + std::vector> fanouts_; // node_idx->output nodes + + // We need a valid reference to return from GetFanin/GetFanout if the + // `node_idx` argument is outside of the [0, num_nodes_) range. + absl::InlinedVector empty_fanin_; + absl::InlinedVector empty_fanout_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/graph_view.h b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_view.h new file mode 100644 index 00000000..4b7e8cfe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/graph_view.h @@ -0,0 +1,428 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +// Map a node/op's input/output port_id to arg_id. +// +// The port_id refers to the n-th tensor of the node, while the arg_id refers to +// the n-th arg of the op. These two can be different if an op's arg is a list +// of tensors. +// +// We return -1 for any invalid port_id (i.e., no corresponding arg_id). +int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); +int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); + +namespace internal { + +// GraphViewInternal is a helper class to simplify graph traversal. It creates +// an immutable view of the nodes and edges represented by a GraphDef protocol +// buffer. +// +// There are two public classes implementing GraphViewInternal: +// +// - GraphView: constructed from the `const GraphDef` and doesn't allow +// to mutate underlying graph via input/output ports lookup functions (ports +// have const pointers to nodes). +// +// - MutableGraphView: constructed from the 'GraphDef` and allows to mutate +// the graph via input/output ports lookup functions (ports have non-const +// pointers to nodes), and also have couple additional functions to +// add/remove/replace nodes in the graph. +// +// --------------------------- !!! WARNING !!! 
--------------------------------- +// Removing nodes from the graph outside of MutableGraphView will +// lead to segfaults! Guaranteed by absl::string_view! +// ----------------------------------------------------------------------------- +// +template +class GraphViewInternal { + public: + struct Port { + Port() : node(nullptr), port_id(0) {} + Port(NodeDefT* n, int port) : node(n), port_id(port) {} + + bool operator==(const Port& other) const { + return node == other.node && port_id == other.port_id; + } + + template + friend H AbslHashValue(H h, const Port& p) { + return H::combine(std::move(h), p.node, p.port_id); + } + + NodeDefT* node; + int port_id; + }; + + struct InputPort : public Port { + using Port::Port; + }; + + struct OutputPort : public Port { + using Port::Port; + }; + + struct Edge { + Edge(OutputPort s, InputPort d) : src(s), dst(d) {} + + bool operator==(const Edge& other) const { + return src == other.src && dst == other.dst; + } + + template + friend H AbslHashValue(H h, const Edge& e) { + return H::combine(std::move(h), e.src, e.dst); + } + + OutputPort src; + InputPort dst; + }; + + GraphDefT* graph() const { return graph_; } + + // Finds a node by name or return `nullptr` if it's not in the graph view. + NodeDefT* GetNode(absl::string_view node_name) const { + return gtl::FindWithDefault(nodes_, node_name, nullptr); + } + + // Checks if a node by name is in the graph view. + bool HasNode(absl::string_view node_name) const { + return GetNode(node_name) != nullptr; + } + + // Gets the specified input port. Note that the special '-1' port_id can be + // used to access the controlling nodes (i.e. the nodes connected to node_name + // through an incoming control dependency). + InputPort GetInputPort(absl::string_view node_name, int port_id) const { + return InputPort(GetNode(node_name), port_id); + } + + // Gets the specified output port. Note that the special '-1' port_id can be + // used to access the controlled nodes (i.e. the nodes connected to node_name + // through an outgoing control dependency). + OutputPort GetOutputPort(absl::string_view node_name, int port_id) const { + return OutputPort(GetNode(node_name), port_id); + } + + // Gets the input port(s) in the immediate fanout of an output port. + const absl::flat_hash_set& GetFanout( + const OutputPort& port) const { + return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_); + } + + // Gets the output port(s) in the immediate fanin of an input port. + absl::flat_hash_set GetFanin(const InputPort& port) const { + if (port.port_id >= 0) { + OutputPort regular_fanin = GetRegularFanin(port); + if (regular_fanin.node == nullptr) { + return {}; + } + return {regular_fanin}; + } + + // Collect fanin for the control input. + absl::flat_hash_set result; + const int first_control_port = + gtl::FindWithDefault(max_regular_input_port_, port.node, -1) + 1; + for (int i = first_control_port; i < port.node->input_size(); ++i) { + TensorId tensor_id = ParseTensorName(port.node->input(i)); + + auto it = nodes_.find(tensor_id.node()); + if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); + } + return result; + } + + // Special case: regular (i.e. non-control) input ports can only have one + // fanin. If port.port_id is out of range or is a control dependency, then an + // empty OutputPort is returned. 
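+  // Illustrative sketch (`view` is a hypothetical GraphView over a graph that
+  // contains a node named "add"):
+  //
+  //   GraphView::InputPort in = view.GetInputPort("add", /*port_id=*/0);
+  //   GraphView::OutputPort src = view.GetRegularFanin(in);
+  //   if (src.node != nullptr) {
+  //     // `src.node` produces the tensor consumed at add:0 via its output
+  //     // port `src.port_id`.
+  //   }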
+ const OutputPort GetRegularFanin(const InputPort& port) const { + if (port.port_id < 0 || + port.port_id > + gtl::FindWithDefault(max_regular_input_port_, port.node, -1)) { + return OutputPort(); + } + + TensorId tensor_id = ParseTensorName(port.node->input(port.port_id)); + return GetOutputPort(tensor_id.node(), tensor_id.index()); + } + + // Checks if a tensor id is a fanin of the node. + bool HasFanin(const NodeDefT& node, const TensorId& fanin) const { + int end = node.input_size(); + if (end == 0 || fanin.index() < -1) { + return false; + } + + const int num_regular_fanins = + gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1; + int start = 0; + if (fanin.index() > -1) { + end = num_regular_fanins; + } else { + start = num_regular_fanins; + } + for (int i = start; i < end; ++i) { + if (ParseTensorName(node.input(i)) == fanin) { + return true; + } + } + return false; + } + + // Gets all the input ports in the immediate fanout of a node. Include the + // controlled nodes iff include_controlled_nodes is true. + absl::flat_hash_set GetFanouts( + const NodeDefT& node, bool include_controlled_nodes) const { + absl::flat_hash_set result; + + OutputPort port; + port.node = const_cast(&node); + const int first_port_id = include_controlled_nodes ? -1 : 0; + const int last_port_id = + gtl::FindWithDefault(max_regular_output_port_, &node, -1); + + for (int i = first_port_id; i <= last_port_id; ++i) { + port.port_id = i; + auto it = fanouts_.find(port); + if (it != fanouts_.end()) { + result.insert(it->second.begin(), it->second.end()); + } + } + return result; + } + + // Gets all the output ports in the immediate fanin of a node. Include the + // controlling nodes iff include_controlling_nodes is true. + absl::flat_hash_set GetFanins( + const NodeDefT& node, bool include_controlling_nodes) const { + absl::flat_hash_set result; + const int max_input_port = + include_controlling_nodes + ? node.input_size() - 1 + : gtl::FindWithDefault(max_regular_input_port_, &node, -1); + for (int i = 0; i <= max_input_port; ++i) { + TensorId tensor_id = ParseTensorName(node.input(i)); + + auto it = nodes_.find(tensor_id.node()); + if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); + } + return result; + } + + // Gets the number of ports in the immediate fanin of a node. Count the + // controlling nodes iff include_controlling_nodes is true. + int NumFanins(const NodeDefT& node, bool include_controlling_nodes) const { + if (include_controlling_nodes) { + return node.input_size(); + } + return gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1; + } + + // Gets the number of ports in the immediate fanout of a node. Count the + // controlled nodes iff include_controlled_nodes is true. + int NumFanouts(const NodeDefT& node, bool include_controlled_nodes) const { + int count = 0; + + OutputPort port; + port.node = const_cast(&node); + const int first_port_id = include_controlled_nodes ? -1 : 0; + const int last_port_id = + gtl::FindWithDefault(max_regular_output_port_, &node, -1); + + for (int i = first_port_id; i <= last_port_id; ++i) { + port.port_id = i; + auto it = fanouts_.find(port); + if (it != fanouts_.end()) count += it->second.size(); + } + + return count; + } + + // Gets all the edges in the immediate fanout of a node. Include the + // controlled edges iff include_controlled_edges is true. 
+ absl::flat_hash_set GetFanoutEdges( + const NodeDefT& node, bool include_controlled_edges) const { + absl::flat_hash_set result; + + OutputPort port; + port.node = const_cast(&node); + const int first_port_id = include_controlled_edges ? -1 : 0; + const int last_port_id = + gtl::FindWithDefault(max_regular_output_port_, &node, -1); + + for (int i = first_port_id; i <= last_port_id; ++i) { + port.port_id = i; + auto it = fanouts_.find(port); + if (it != fanouts_.end()) { + for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { + result.emplace(/*src=*/port, /*dst=*/*itr); + } + } + } + return result; + } + + // Gets all the edges in the immediate fanin of a node. Include the + // controlling edges iff include_controlling_edges is true. + absl::flat_hash_set GetFaninEdges( + const NodeDefT& node, bool include_controlling_edges) const { + absl::flat_hash_set result; + const int max_input_port = + include_controlling_edges + ? node.input_size() - 1 + : gtl::FindWithDefault(max_regular_input_port_, &node, -1); + for (int i = 0; i <= max_input_port; ++i) { + TensorId tensor_id = ParseTensorName(node.input(i)); + + auto it = nodes_.find(tensor_id.node()); + if (it != nodes_.end()) { + result.emplace(/*src=*/OutputPort(it->second, tensor_id.index()), + /*dst=*/InputPort(const_cast(&node), i)); + } + } + return result; + } + + protected: + explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} + + absl::Status AddUniqueNode(NodeDefT* node) { + auto inserted = nodes_.emplace(node->name(), node); + return inserted.second + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::StrCat( + "Non unique node name detected: ", node->name())); + } + + // TODO(ezhulenev): Remove this function. + void AddUniqueNodeOrDie(NodeDefT* node) { + absl::Status st = AddUniqueNode(node); + CHECK(st.ok()) << st.message(); + } + + // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins + // exist, and all regular fanins come before controlling fanins. + void AddFanouts(NodeDefT* node) { + int max_input_port = -1; + for (int i = 0; i < node->input_size(); ++i) { + TensorId tensor_id = ParseTensorName(node->input(i)); + OutputPort output(nodes_[tensor_id.node()], tensor_id.index()); + + if (output.port_id < 0) { + fanouts_[output].emplace(node, -1); + } else { + max_input_port = i; + int& max_regular_output_port = max_regular_output_port_[output.node]; + max_regular_output_port = + std::max(max_regular_output_port, output.port_id); + fanouts_[output].emplace(node, i); + } + } + if (max_input_port > -1) { + max_regular_input_port_[node] = max_input_port; + } + } + + // Access to the mutable internal state for MutableGraphView. + absl::flat_hash_map& nodes() { return nodes_; } + + absl::flat_hash_map>& fanouts() { + return fanouts_; + } + + absl::flat_hash_map& max_regular_input_port() { + return max_regular_input_port_; + } + + absl::flat_hash_map& max_regular_output_port() { + return max_regular_output_port_; + } + + private: + GraphDefT* graph_; // must outlive the graph view + + // A mapping from the node name to the node itself. + absl::flat_hash_map nodes_; + + // A mapping from the output port to all inputs that read from it. + absl::flat_hash_map> fanouts_; + + // Keep a maximum index of input tensors of the node. + absl::flat_hash_map max_regular_input_port_; + + // Keep a maximum index of tensor fetched from the node. It doesn't guarantee + // that all tensors in the [0, max_regular_output_port] range are actually + // fetched by other nodes. 
+ absl::flat_hash_map max_regular_output_port_; + + // If the node has no fanouts at given output port (output tensor consumers) + // we return a reference to this set from `GetFanout` (we can't construct new + // empty set every time, because we need a non-dangling reference). + absl::flat_hash_set fanout_not_found_value_; +}; + +} // namespace internal + +// Immutable GraphView that keeps the constness of the GraphDef. If you need to +// mutate the graph or the nodes via the graph view lookup functions, see +// MutableGraphView. +class GraphView + : public internal::GraphViewInternal { + public: + explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) { + for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node); + for (const NodeDef& node : graph->node()) AddFanouts(&node); + } +}; + +// Returns true if node has one (or zero) fanout nodes at given output port. +bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node, + int port = 0); + +// Returns true if node has at least one fanout node at given output port. +bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port = 0); +// Returns true if the node has at least one input control dependency. +bool HasControlFanin(const GraphView& graph_view, const NodeDef* node); +// Returns true if the node has at least one output control dependency. +bool HasControlFanout(const GraphView& graph_view, const NodeDef* node); +// Returns true if the node has at least one input or output control dependency. +bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/grappler_item.h b/third_party/tflite-hdrs/tensorflow/core/grappler/grappler_item.h new file mode 100644 index 00000000..36bc4f15 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/grappler_item.h @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variable.pb.h" +#include "tensorflow/core/protobuf/queue_runner.pb.h" +#include "tsl/platform/cpu_info.h" + +namespace tensorflow { +namespace grappler { + +// A TensorFlow model to optimize. +// Models are represented by the combination of a graph, one of more fetch +// nodes, and potentially a set of nodes to feed. 
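+// Illustrative sketch (the graph, feed tensor and node names are
+// hypothetical; assumes `feed` holds (name, Tensor) pairs and `fetch` holds
+// node names, as in upstream Grappler):
+//
+//   GrapplerItem item;
+//   item.id = "example";
+//   item.graph = graph_def;                  // the GraphDef to optimize
+//   item.feed.emplace_back("input", input_tensor);
+//   item.fetch.push_back("softmax");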
+struct GrapplerItem { + GrapplerItem() = default; + GrapplerItem(const GrapplerItem& other) = default; + GrapplerItem(GrapplerItem&& other) = default; + GrapplerItem& operator=(const GrapplerItem& other) = default; + GrapplerItem& operator=(GrapplerItem&& other) = default; + virtual ~GrapplerItem() = default; + + // Create a copy of this GrapplerItem with graph swapped with the argument. + GrapplerItem WithGraph(GraphDef&& graph) const; + + string id; // A unique id for this item + + // Inputs + GraphDef graph; + std::vector> feed; + std::vector fetch; + + // Initialization op(s). + std::vector init_ops; + // Expected initialization time in seconds, or 0 if unknown + int64_t expected_init_time = 0; + + // Save/restore ops (if any) + string save_op; + string restore_op; + string save_restore_loc_tensor; + + // Queue runner(s) required to run the queue(s) of this model. + std::vector queue_runners; + + // List of op names to keep in the graph. This includes nodes that are + // referenced in various collections, and therefore must be preserved to + // ensure that the optimized metagraph can still be loaded. + std::vector keep_ops; + + // Return the set of node evaluated during a regular train/inference step. + std::vector MainOpsFanin() const; + // Return the set of node run to populate the queues (if any). + std::vector EnqueueOpsFanin() const; + // Return the set nodes used by TensorFlow to initialize the graph. + std::vector InitOpsFanin() const; + // Return the set of variables accessed during a regular train/inference step. + std::vector MainVariables() const; + // Return a set of node names that must be preserved. This includes feed and + // fetch nodes, keep_ops, init_ops. + std::unordered_set NodesToPreserve() const; + + struct OptimizationOptions { + // Is it allowed to add nodes to the graph that do not have registered + // gradient function. + bool allow_non_differentiable_rewrites = true; + + // Tensorflow function execution semantics is slightly different from the + // main Tensorflow graph, and we need to make sure that we do not change it + // by running Grappler optimizer passes. One main difference is that + // functions do not prune ops with side-effects and dataset-output ops (see + // PruneFunctionBody in common_runtime/function.cc). + bool allow_pruning_stateful_and_dataset_ops = true; + + // If true Grappler will optimize the main graph, and also all functions in + // the graph function library (function can't be polymorphic, it can't have + // undefined type parameters in the function signature, or placeholder + // attributes in the function body). + bool optimize_function_library = true; + + // Mark the grapper optimization run in eager mode or not. + bool is_eager_mode = false; + + // Number of intra threads used to run operation. + int intra_op_parallelism_threads = tsl::port::MaxParallelism(); + }; + + const std::unordered_set& devices() const; + // Adds a device to a set of available devices, only if it's a valid fully + // defined device name. Returns `OkStatus()` if successfully added a device, + // and an error otherwise. + absl::Status AddDevice(const string& device); + // Adds all valid devices from the other Grappler item to the device set. + absl::Status AddDevices(const GrapplerItem& other); + // Adds all valid devices from the nodes of the graph to the device set. + // Returns `OkStatus()` if all device annotations found in a graph are valid + // fully defined device names, and an error otherwise. 
+ absl::Status InferDevicesFromGraph(); + // Clears a set of available devices. + void ClearDevices(); + + const OptimizationOptions& optimization_options() const; + OptimizationOptions& optimization_options(); + + private: + // TODO(ezhulenev) Make GrapplerItem a class and hide all public data members. + // TODO(ezhulenev): Migrate all unordered collections to absl. + + // A set of fully defined device names that can be used to place the nodes of + // the `graph`. + // Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0" + std::unordered_set devices_; + + OptimizationOptions optimization_options_; +}; + +GrapplerItem::OptimizationOptions CreateOptOptionsForEager(); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/grappler_item_builder.h b/third_party/tflite-hdrs/tensorflow/core/grappler/grappler_item_builder.h new file mode 100644 index 00000000..00661da0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/grappler_item_builder.h @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_BUILDER_H_ +#define TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_BUILDER_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { + +class MetaGraphDef; + +namespace grappler { + +struct ItemConfig { + ItemConfig() {} + + // If true, ignore all user specified node placement. + bool ignore_user_placement = true; + // If true, ignore all user specified colocation attributes. + bool ignore_colocation = true; + // Dimension to use if a placeholder node has an _output_shapes attribute with + // a dimension of -1. + int placeholder_unknown_output_shape_dim = -1; + // If true, erases all "_noinline" attributes from user-defined functions. + // Has no effect if "inline_functions" is disabled. + bool erase_noinline_attributes = false; + // If non-empty, override the directory of asset paths. + string assets_directory_override; + // If true, runs ModelPruner on the graph. + bool prune_graph = false; + // Override feed nodes list. + std::set feed_nodes; + // Override fetch nodes list. + std::set fetch_nodes; + + // Configs for graph optimizations from common_runtime. This is NOT Grappler + // function optimizer. When Grappler is invoked at runtime, it is typically + // running after common_runtime pass. + // + // If true, does L1 optimizations. + bool apply_optimizations = false; + // If true, does function inlining. + bool inline_functions = false; +}; + +// Method for optimizing the graph def (including function inlining and other +// optimizations). This is optimizations from common_runtime, NOT Grappler +// function optimizer. 
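+// Illustrative sketch (`input_graph` is a hypothetical GraphDef):
+//
+//   ItemConfig cfg;
+//   cfg.inline_functions = true;
+//   GraphDef optimized_graph;
+//   absl::Status status =
+//       RuntimeGraphOptimizer(input_graph, &optimized_graph, cfg);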
+absl::Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg, + GraphDef* output_graph_def, + const ItemConfig& cfg); + +// Factory method for creating a GrapplerItem from a MetaGraphDef. +// Returns nullptr if the given meta_graph cannot be converted. +std::unique_ptr GrapplerItemFromMetaGraphDef( + const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg); + +// Factory method for creating a GrapplerItem from a file +// containing a MetaGraphDef in either binary or text format. +// Returns nullptr if the given meta_graph cannot be converted. +std::unique_ptr GrapplerItemFromMetaGraphDefFile( + const string& id, const string& meta_graph_file, const ItemConfig& cfg); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_GRAPPLER_ITEM_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/file_input_yielder.h b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/file_input_yielder.h new file mode 100644 index 00000000..f3e9ecb6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/file_input_yielder.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The file input provides a mechanism to feed grappler with existing TensorFlow +// graphs stored in TensorFlow checkpoints. Note that at this point the weights +// that may be stored in the checkpoint are not restored in order to speedup the +// initialization. + +#ifndef TENSORFLOW_CORE_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ +#define TENSORFLOW_CORE_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ + +#include +#include +#include +#include "tensorflow/core/grappler/inputs/input_yielder.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +class GrapplerItem; + +class FileInputYielder : public InputYielder { + public: + // Iterates over the files specified in the list of 'filename' up to + // 'max_iterations' times. + explicit FileInputYielder( + const std::vector& filenames, + size_t max_iterations = std::numeric_limits::max()); + bool NextItem(GrapplerItem* item) override; + + private: + const std::vector filenames_; + size_t current_file_; + size_t current_iteration_; + size_t max_iterations_; + + size_t bad_inputs_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_INPUTS_FILE_INPUT_YIELDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/input_yielder.h b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/input_yielder.h new file mode 100644 index 00000000..06f642c5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/input_yielder.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_INPUTS_INPUT_YIELDER_H_ +#define TENSORFLOW_CORE_GRAPPLER_INPUTS_INPUT_YIELDER_H_ + +namespace tensorflow { +namespace grappler { + +struct GrapplerItem; + +// Abstract interface for yielding graphs that we want to optimize. +class InputYielder { + public: + virtual ~InputYielder() {} + + virtual bool NextItem(GrapplerItem* item) = 0; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_INPUTS_INPUT_YIELDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h new file mode 100644 index 00000000..bf776bcd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_ +#define TENSORFLOW_CORE_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_ + +#include +#include +#include "tensorflow/core/grappler/inputs/input_yielder.h" + +namespace tensorflow { +namespace grappler { + +class Cluster; +struct GrapplerItem; + +class TrivialTestGraphInputYielder : public InputYielder { + public: + TrivialTestGraphInputYielder(int num_stages, int width, int tensor_size, + bool insert_queue, + const std::vector& device_names); + bool NextItem(GrapplerItem* item) override; + + private: + const int num_stages_; + const int width_; + const int tensor_size_; + const bool insert_queue_; + std::vector device_names_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/utils.h new file mode 100644 index 00000000..9caefcd8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/inputs/utils.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_INPUTS_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_INPUTS_UTILS_H_
+
+#include <set>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+
+bool FilesExist(const std::vector<string>& files,
+                std::vector<absl::Status>* status = nullptr);
+bool FilesExist(const std::set<string>& files);
+
+bool FileExists(const string& file, absl::Status* status);
+
+// Reads GraphDef from file in either text or raw serialized format.
+absl::Status ReadGraphDefFromFile(const string& graph_def_path,
+                                  GraphDef* result);
+
+// Reads MetaGraphDef from file in either text or raw serialized format.
+absl::Status ReadMetaGraphDefFromFile(const string& meta_graph_def_path,
+                                      MetaGraphDef* result);
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_INPUTS_UTILS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/mutable_graph_view.h b/third_party/tflite-hdrs/tensorflow/core/grappler/mutable_graph_view.h
new file mode 100644
index 00000000..fdd4fa32
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/grappler/mutable_graph_view.h
@@ -0,0 +1,336 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
+#define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
+
+#include <set>
+#include <string>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace grappler {
+
+const char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl";
+
+// A utility class to simplify the traversal of a GraphDef that, unlike
+// GraphView, supports updating the graph. Note that you should not modify the
+// graph separately, because the view will get out of sync.
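+//
+// Illustrative usage sketch (not part of the upstream TensorFlow header; it
+// assumes a GraphDef `graph` that already contains nodes named "bar" and
+// "new_bar"):
+//
+//   MutableGraphView view(&graph);
+//   // Point every consumer of "bar" (including control dependencies) at
+//   // "new_bar".
+//   TF_RETURN_IF_ERROR(view.UpdateFanouts("bar", "new_bar"));
+//   // "bar" no longer has any fanouts, so it can now be deleted safely.
+//   TF_RETURN_IF_ERROR(view.DeleteNodes({"bar"}));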
+
+class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
+ public:
+  explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
+    for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
+    for (NodeDef& node : *graph->mutable_node()) AddAndDedupFanouts(&node);
+  }
+
+  // Lookup fanouts/fanins using immutable ports.
+  using GraphViewInternal::GetFanout;
+  const absl::flat_hash_set<InputPort>& GetFanout(
+      const GraphView::OutputPort& port) const;
+
+  using GraphViewInternal::GetFanin;
+  absl::flat_hash_set<OutputPort> GetFanin(
+      const GraphView::InputPort& port) const;
+
+  using GraphViewInternal::GetRegularFanin;
+  const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
+
+  // Adds a new node to graph and updates the view. Returns a pointer to the
+  // node in graph.
+  NodeDef* AddNode(NodeDef&& node);
+
+  // Adds all nodes from the `subgraph` to the underlying graph and updates the
+  // view. `subgraph` doesn't have to be a valid graph definition on its own:
+  // it can have edges to nodes that are not in it; however, after adding it to
+  // the underlying graph, the final graph must be valid.
+  //
+  // If the subgraph function library is not empty, all new functions will be
+  // added to the graph. Functions that appear with the same name in both the
+  // subgraph and the graph represented by *this must have identical function
+  // definitions.
+  //
+  // IMPORTANT: All nodes and functions of the given subgraph are moved into
+  // the underlying graph, which leaves the subgraph in a valid but undefined
+  // state.
+  absl::Status AddSubgraph(GraphDef&& subgraph);
+
+  // Updates node `node_name` op, device, and attributes. This will clear any
+  // existing attributes. If it is not possible to update the node or if the
+  // node does not exist, an error will be returned and nothing will be
+  // modified in the graph.
+  absl::Status UpdateNode(absl::string_view node_name, absl::string_view op,
+                          absl::string_view device,
+                          absl::Span<const std::pair<string, AttrValue>> attrs);
+
+  // Updates node `from_node_name` name to `to_node_name`. If `to_node_name` is
+  // in use, node `from_node_name` does not exist, or node `from_node_name` has
+  // fanouts and `update_fanouts` is set to false, an error will be returned
+  // and nothing will be modified in the graph.
+  absl::Status UpdateNodeName(absl::string_view from_node_name,
+                              absl::string_view to_node_name,
+                              bool update_fanouts);
+
+  // Swaps node names `from_node_name` and `to_node_name`. Self loops of one
+  // node are removed by updating the inputs introducing self loops to use the
+  // other node's name. Setting `update_fanouts` to false will exclude other
+  // fanouts from having their inputs updated, but inputs introducing self
+  // loops will always be updated regardless of `update_fanouts`.
+  //
+  // Example:
+  //   1. foo(other:3, bar:2, ^bar)
+  //   2. bar(foo:3, other:1, foo:1, ^foo)
+  //   3. other(foo:5, bar:6)
+  //
+  // After calling SwapNodeNames("foo", "bar", false):
+  //   1. bar(other:3, foo:2, ^foo)
+  //   2. foo(bar:3, other:1, bar:1, ^bar)
+  //   3. other(foo:5, bar:6)
+  //
+  // After calling SwapNodeNames("foo", "bar", true):
+  //   1. bar(other:3, foo:2, ^foo)
+  //   2. foo(bar:3, other:1, bar:1, ^bar)
+  //   3. other(bar:5, foo:6)
+  //
+  // If it is not possible to swap node names (i.e. nodes do not exist or a
+  // Switch control dependency may be introduced), an error will be returned
+  // and nothing will be modified in the graph.
+ absl::Status SwapNodeNames(absl::string_view from_node_name, + absl::string_view to_node_name, + bool update_fanouts); + + // Updates all fanouts (input ports fetching output tensors) from + // `from_node_name` to the `to_node_name`, including control dependencies. + // + // Example: We have 3 nodes that use `bar` node output tensors as inputs: + // 1. foo1(bar:0, bar:1, other:0) + // 2. foo2(bar:1, other:1) + // 3. foo3(other:2, ^bar) + // + // After calling UpdateFanouts(bar, new_bar): + // 1. foo1(new_bar:0, new_bar:1, other:0) + // 2. foo2(new_bar:1, other:1) + // 3. foo3(other:2, ^new_bar) + absl::Status UpdateFanouts(absl::string_view from_node_name, + absl::string_view to_node_name); + + // Adds regular fanin `fanin` to node `node_name`. If the node or fanin do not + // exist in the graph, nothing will be modified in the graph. Otherwise fanin + // will be added after existing non control dependency fanins. Control + // dependencies will be deduped. To add control dependencies, use + // AddControllingFanin. + absl::Status AddRegularFanin(absl::string_view node_name, + const TensorId& fanin); + + // Adds regular fanin `fanin` to node `node_name` at port `port`. If the node + // or fanin do not exist in the graph, nothing will be modified in the graph. + // Otherwise fanin will be inserted at port `port`. Control dependencies will + // be deduped. To add control dependencies, use AddControllingFanin. + // + // If the port is not a valid port (less than 0 or greater than the number of + // regular fanins), this will result in an error and the node will not be + // modified. + absl::Status AddRegularFaninByPort(absl::string_view node_name, int port, + const TensorId& fanin); + + // Adds control dependency `fanin` to the target node named `node_name`. To + // add regular fanins, use AddRegularFanin. + // + // Case 1: If the fanin is not a Switch node, the control dependency is simply + // added to the target node: + // + // fanin -^> target node. + // + // Case 2: If the fanin is a Switch node, we cannot anchor a control + // dependency on it, because unlike other nodes, only one of its outputs will + // be generated when the node is activated. In this case, we try to find an + // Identity/IdentityN node in the fanout of the relevant port of the Switch + // and add it as a fanin to the target node. If no such Identity/IdentityN + // node can be found, a new Identity node will be created. In both cases, we + // end up with: + // + // fanin -> Identity{N} -^> target node. + // + // If the control dependency being added is redundant (control dependency + // already exists or control dependency can be deduped from regular fanins), + // this will not result in an error and the node will not be modified. + absl::Status AddControllingFanin(absl::string_view node_name, + const TensorId& fanin); + + // Removes regular fanin `fanin` from node `node_name`. If the node or fanin + // do not exist in the graph, nothing will be modified in the graph. If there + // are multiple inputs that match the fanin, all of them will be removed. To + // remove controlling fanins, use RemoveControllingFanin. + // + // If the fanin being removed doesn't exist in the node's inputs, this will + // not result in an error and the node will not be modified. + absl::Status RemoveRegularFanin(absl::string_view node_name, + const TensorId& fanin); + + // Removes regular fanin at port `port` from node `node_name`. If the node + // does not exist in the graph, nothing will be modified in the graph. 
+ // To remove controlling fanins, use RemoveControllingFanin. + // + // If the port is not a valid port (less than 0 or greater than the last index + // of the regular fanins), this will result in an error and the node will not + // be modified. + absl::Status RemoveRegularFaninByPort(absl::string_view node_name, int port); + + // Removes control dependency `fanin_node_name` from the target node named + // `node_name`. If the node or fanin do not exist in the graph, nothing will + // be modified in the graph. To remove regular fanins, use RemoveRegularFanin. + // + // If the fanin being removed doesn't exist in the node's inputs, this will + // not result in an error and the node will not be modified. + absl::Status RemoveControllingFanin(absl::string_view node_name, + absl::string_view fanin_node_name); + + // Removes all fanins from node `node_name`. Control dependencies will be + // retained if keep_controlling_fanins is true. + // + // If no fanins are removed, this will not result in an error and the node + // will not be modified. + absl::Status RemoveAllFanins(absl::string_view node_name, + bool keep_controlling_fanins); + + // Replaces all fanins `from_fanin` with `to_fanin` in node `node_name`. If + // the fanins or node do not exist, nothing will be modified in the graph. + // Control dependencies will be deduped. + // + // If the fanin being updated doesn't exist in the node's inputs, this will + // not result in an error and the node will not be modified. + absl::Status UpdateFanin(absl::string_view node_name, + const TensorId& from_fanin, + const TensorId& to_fanin); + + // Replaces fanin at port `port` in node `node_name` with fanin `fanin`. If + // the fanins or node do not exist, nothing will be modified in the graph. + // Control dependencies will be deduped. + // + // If the port is not a valid port (less than 0 or greater than the last index + // of the regular fanins), this will result in an error and the node will not + // be modified. + absl::Status UpdateRegularFaninByPort(absl::string_view node_name, int port, + const TensorId& fanin); + + // Swaps fanins at ports `from_port` and `to_port` in node `node_name`. If the + // node does not exist, nothing will be modified in the graph. + // + // If the ports are not a valid port (less than 0 or greater than the last + // index of the regular fanins), this will result in an error and the node + // will not be modified. + absl::Status SwapRegularFaninsByPorts(absl::string_view node_name, + int from_port, int to_port); + + // Updates all regular fanins to equivalent controlling fanins. If it is not + // possible, an error will be returned and nothing will be modified in the + // graph. + absl::Status UpdateAllRegularFaninsToControlling(absl::string_view node_name); + + // Deletes nodes from the graph. If a node can't be safely removed, + // specifically if a node still has fanouts, an error will be returned. Nodes + // that can't be found are ignored. + absl::Status DeleteNodes(const absl::flat_hash_set& nodes_to_delete); + + private: + // Adds fanouts for fanins of node to graph, while deduping control + // dependencies from existing control dependencies and regular fanins. Note, + // node inputs will be mutated if control dependencies can be deduped. + void AddAndDedupFanouts(NodeDef* node); + + // Finds next output port smaller than fanin.port_id and update. The + // max_regular_output_port is only updated if fanin.port_id is the same as the + // current max_regular_output_port and if the fanouts set is empty. 
If there + // are no regular outputs, max_regular_output_port will be erased. + void UpdateMaxRegularOutputPortForRemovedFanin( + const OutputPort& fanin, + const absl::flat_hash_set& fanin_fanouts); + + // Updates max regular output port for newly added fanin by checking the + // current max and updating if the newly added fanin is of a larger port. + void UpdateMaxRegularOutputPortForAddedFanin(const OutputPort& fanin); + + // Updates all fanouts (input ports fetching output tensors) from `from_node` + // to the `to_node`, including control dependencies. + // + // Example: We have 3 nodes that use `bar` node output tensors as inputs: + // 1. foo1(bar:0, bar:1, other:0) + // 2. foo2(bar:1, other:1) + // 3. foo3(other:2, ^bar) + // + // After calling UpdateFanouts(bar, new_bar): + // 1. foo1(new_bar:0, new_bar:1, other:0) + // 2. foo2(new_bar:1, other:1) + // 3. foo3(other:2, ^new_bar) + // + // IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the + // behavior is undefined. + absl::Status UpdateFanoutsInternal(NodeDef* from_node, NodeDef* to_node); + + // Adds fanin to node. If fanin is a control dependency, existing control + // dependencies will be checked first before adding. Otherwise fanin will be + // added after existing non control dependency inputs. + bool AddFaninInternal(NodeDef* node, const OutputPort& fanin); + + // Finds control dependency node to be used based on fanin. If fanin is not a + // Switch node, fanin.node is simply returned. Otherwise this will try to find + // a candidate Identity node consuming fanin, as the control dependency. If it + // is not possible or will introduce a self loop, an error message will be + // set. If nullptr is returned with no error + // GetOrCreateIdentityConsumingSwitch should be called to generate the new + // Identity node. + NodeDef* GetControllingFaninToAdd(absl::string_view node_name, + const OutputPort& fanin, string* error_msg); + + // Finds a generated Identity node consuming Switch node `fanin.node` at port + // `fanin.port_id`. If such a node does not exist, a new Identity node will be + // created. + NodeDef* GetOrCreateIdentityConsumingSwitch(const OutputPort& fanin); + + // Removes all instances of regular fanin `fanin` from node `node`. + bool RemoveRegularFaninInternal(NodeDef* node, const OutputPort& fanin); + + // Removes controlling fanin `fanin_node` from node if such controlling fanin + // exists. + bool RemoveControllingFaninInternal(NodeDef* node, NodeDef* fanin_node); + + // Checks if nodes to be deleted are missing or have any fanouts that will + // remain in the graph. If node is removed in either case, the graph will + // enter an invalid state. + absl::Status CheckNodesCanBeDeleted( + const absl::flat_hash_set& nodes_to_delete); + + // Removes fanins of the deleted node from internal state. Control + // dependencies are retained iff keep_controlling_fanins is true. + void RemoveFaninsInternal(NodeDef* deleted_node, + bool keep_controlling_fanins); + + // Removes fanouts of the deleted node from internal state. 
+ void RemoveFanoutsInternal(NodeDef* deleted_node); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/op_types.h b/third_party/tflite-hdrs/tensorflow/core/grappler/op_types.h new file mode 100644 index 00000000..719f12fa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/op_types.h @@ -0,0 +1,284 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OP_TYPES_H_ +#define TENSORFLOW_CORE_GRAPPLER_OP_TYPES_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +bool IsAdd(const NodeDef& node); +bool IsAddN(const NodeDef& node); +bool IsAll(const NodeDef& node); +bool IsAngle(const NodeDef& node); +bool IsAny(const NodeDef& node); +bool IsAnyDiv(const NodeDef& node); +bool IsAnyBatchMatMul(const NodeDef& node); +bool IsAnyMatMul(const NodeDef& node); +bool IsAnyMax(const NodeDef& node); +bool IsAnyMaxPool(const NodeDef& node); +bool IsAnyMin(const NodeDef& node); +bool IsAnyMul(const NodeDef& node); +bool IsAnySparseSegmentReduction(const NodeDef& node); +bool IsApproximateEqual(const NodeDef& node); +bool IsArg(const NodeDef& node); +bool IsArgMax(const NodeDef& node); +bool IsArgMin(const NodeDef& node); +bool IsAssert(const NodeDef& node); +bool IsAssign(const NodeDef& node); +bool IsAsString(const NodeDef& node); +bool IsAtan2(const NodeDef& node); +bool IsAvgPoolGrad(const NodeDef& node); +bool IsBetainc(const NodeDef& node); +bool IsBiasAdd(const NodeDef& node); +bool IsBiasAddV2(const NodeDef& node); +bool IsBiasAddGrad(const NodeDef& node); +bool IsBitcast(const NodeDef& node); +bool IsBroadcastTo(const NodeDef& node); +bool IsCast(const NodeDef& node); +bool IsCheckNumerics(const NodeDef& node); +bool IsCollective(const NodeDef& node); +bool IsComplex(const NodeDef& node); +bool IsComplexAbs(const NodeDef& node); +bool IsConcat(const NodeDef& node); +bool IsConcatOffset(const NodeDef& node); +bool IsConj(const NodeDef& node); +bool IsConjugateTranspose(const NodeDef& node); +bool IsConstant(const NodeDef& node); +bool IsControlFlow(const NodeDef& node); +bool IsConv2D(const NodeDef& node); +bool IsConv2DBackpropFilter(const NodeDef& node); +bool IsConv2DBackpropInput(const NodeDef& node); +bool IsConv3D(const NodeDef& node); +bool IsConv3DBackpropFilterV2(const NodeDef& node); +bool IsConv3DBackpropInputV2(const NodeDef& node); +bool IsDepthwiseConv2dNative(const NodeDef& node); +bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); +bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); +bool IsDequeueOp(const NodeDef& node); +bool IsDiv(const NodeDef& node); +bool IsDivNoNan(const NodeDef& node); +bool 
IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing); +bool IsElu(const NodeDef& node); +bool IsEluGrad(const NodeDef& node); +bool IsQuantizationEmulation(const NodeDef& node); +bool IsEnter(const NodeDef& node); +bool IsEqual(const NodeDef& node); +bool IsExit(const NodeDef& node); +bool IsExp(const NodeDef& node); +bool IsFakeParam(const NodeDef& node); +bool IsFill(const NodeDef& node); +bool IsFloorDiv(const NodeDef& node); +bool IsFloorMod(const NodeDef& node); +bool IsFusedBatchNorm(const NodeDef& node); +bool IsFusedBatchNormEx(const NodeDef& node); +bool IsFusedBatchNormGrad(const NodeDef& node); +bool IsGather(const NodeDef& node); +bool IsGreater(const NodeDef& node); +bool IsGreaterEqual(const NodeDef& node); +bool IsHistogramSummary(const NodeDef& node); +bool IsHostConstant(const NodeDef& node); +bool IsIdentity(const NodeDef& node); +bool IsIdentityN(const NodeDef& node); +bool IsIdentityNSingleInput(const NodeDef& node); +bool IsIf(const NodeDef& node); +bool IsIgamma(const NodeDef& node); +bool IsIgammac(const NodeDef& node); +bool IsImag(const NodeDef& node); +bool IsImmutableConst(const NodeDef& node); +bool IsInvGrad(const NodeDef& node); +bool IsLeakyRelu(const NodeDef& node); +bool IsLeakyReluGrad(const NodeDef& node); +bool IsLess(const NodeDef& node); +bool IsLessEqual(const NodeDef& node); +bool IsLog(const NodeDef& node); +bool IsLogicalAnd(const NodeDef& node); +bool IsLogicalNot(const NodeDef& node); +bool IsLogicalOr(const NodeDef& node); +bool IsLoopCond(const NodeDef& node); +bool IsMatMul(const NodeDef& node); +bool IsMax(const NodeDef& node); +bool IsMaxPoolGrad(const NodeDef& node); +bool IsMaximum(const NodeDef& node); +bool IsMean(const NodeDef& node); +bool IsMerge(const NodeDef& node); +bool IsMin(const NodeDef& node); +bool IsMinimum(const NodeDef& node); +bool IsMirrorPad(const NodeDef& node); +bool IsMirrorPadGrad(const NodeDef& node); +bool IsMklFusedMish(const NodeDef& node); +bool IsMod(const NodeDef& node); +bool IsMul(const NodeDef& node); +bool IsMulNoNan(const NodeDef& node); +bool IsNeg(const NodeDef& node); +bool IsNextIteration(const NodeDef& node); +bool IsNoOp(const NodeDef& node); +bool IsNotEqual(const NodeDef& node); +bool IsOnesLike(const NodeDef& node); +bool IsPack(const NodeDef& node); +bool IsPack(const NodeDef& node); +bool IsPad(const NodeDef& node); +bool IsPartitionedCall(const NodeDef& node); +bool IsPlaceholder(const NodeDef& node); +bool IsPolygamma(const NodeDef& node); +bool IsPow(const NodeDef& node); +bool IsPrint(const NodeDef& node); +bool IsProd(const NodeDef& node); +bool IsQuantizedMatMul(const NodeDef& node); +bool IsQueue(const NodeDef& node); +bool IsRandomShuffle(const NodeDef& node); +bool IsRank(const NodeDef& node); +bool IsReadVariableOp(const NodeDef& node); +bool IsReadVariablesOp(const NodeDef& node); +bool IsReal(const NodeDef& node); +bool IsRealDiv(const NodeDef& node); +bool IsReciprocalGrad(const NodeDef& node); +bool IsRecv(const NodeDef& node); +bool IsReduction(const NodeDef& node); +bool IsRelu(const NodeDef& node); +bool IsRelu6(const NodeDef& node); +bool IsRelu6Grad(const NodeDef& node); +bool IsReluGrad(const NodeDef& node); +bool IsReshape(const NodeDef& node); +bool IsRestore(const NodeDef& node); +bool IsRetval(const NodeDef& node); +bool IsReverse(const NodeDef& node); +bool IsReverseV2(const NodeDef& node); +bool IsRsqrt(const NodeDef& node); +bool IsRsqrtGrad(const NodeDef& node); +bool IsSelect(const NodeDef& node); +bool IsSeluGrad(const NodeDef& node); +bool 
IsSend(const NodeDef& node); +bool IsShape(const NodeDef& node); +bool IsShapeN(const NodeDef& node); +bool IsShuffle(const NodeDef& node); +bool IsSigmoid(const NodeDef& node); +bool IsSigmoidGrad(const NodeDef& node); +bool IsSize(const NodeDef& node); +bool IsSlice(const NodeDef& node); +bool IsSnapshot(const NodeDef& node); +bool IsSoftmax(const NodeDef& node); +bool IsSoftplusGrad(const NodeDef& node); +bool IsSoftsignGrad(const NodeDef& node); +bool IsSplit(const NodeDef& node); +bool IsSplitV(const NodeDef& node); +bool IsSqrt(const NodeDef& node); +bool IsSqrtGrad(const NodeDef& node); +bool IsSquare(const NodeDef& node); +bool IsSquaredDifference(const NodeDef& node); +bool IsSqueeze(const NodeDef& node); +bool IsStackCloseOp(const NodeDef& node); +bool IsStackOp(const NodeDef& node); +bool IsStackPopOp(const NodeDef& node); +bool IsStackPushOp(const NodeDef& node); +bool IsStatefulPartitionedCall(const NodeDef& node); +bool IsStopGradient(const NodeDef& node); +bool IsStridedSlice(const NodeDef& node); +bool IsStridedSliceGrad(const NodeDef& node); +bool IsStringToHashBucketFast(const NodeDef& node); +bool IsSub(const NodeDef& node); +bool IsSum(const NodeDef& node); +bool IsSwitch(const NodeDef& node); +bool IsSymbolicGradient(const NodeDef& node); +bool IsTanh(const NodeDef& node); +bool IsTanhGrad(const NodeDef& node); +bool IsTensorArray(const NodeDef& node); +bool IsTile(const NodeDef& node); +bool IsTranspose(const NodeDef& node); +bool IsTruncateDiv(const NodeDef& node); +bool IsTruncateMod(const NodeDef& node); +bool IsUnique(const NodeDef& node); +bool IsUnpack(const NodeDef& node); +bool IsVariable(const NodeDef& node); +bool IsWhile(const NodeDef& node); +bool IsXdivy(const NodeDef& node); +bool IsXlaLaunch(const NodeDef& node); +bool IsZerosLike(const NodeDef& node); +bool IsZeta(const NodeDef& node); + +// Return true if the op is an aggregation (e.g. Add, AddN). +// Returns false if it could not be determined to be so. +bool IsAggregate(const NodeDef& node); + +// Return true if the op is commutative (e.g. Mul, Add). +// Returns false if it could not be determined to be so. +bool IsCommutative(const NodeDef& node); + +// Returns true if the node is known to use persistent memory to store its +// value. +bool IsPersistent(const NodeDef& node); + +// Returns true if the node belongs to the NC_DATASET class (see graph/graph.h). +bool IsDataset(const NodeDef& node); + +// Returns true if the node op is marked as stateful, or if it was not found in +// op_registry. +bool IsStateful(const NodeDef& node, const OpRegistryInterface* op_registry); +bool IsStateful(const NodeDef& node); // use OpRegistry::Global() + +bool IsFreeOfSideEffect(const NodeDef& node, + const OpRegistryInterface* op_registry); +bool IsFreeOfSideEffect(const NodeDef& node); // use OpRegistry::Global() + +// Returns true if the takes a tensor reference as input. +// Returns false if the op type is unknown. +bool HasRefInput(const NodeDef& node); + +bool ModifiesFrameInfo(const NodeDef& node); + +// Returns true if the op is known to write to one or more of its inputs. +bool ModifiesInputsInPlace(const NodeDef& node); + +// Returns true if the op is an element-wise involution, i.e. if it is its +// own inverse such that f(f(x)) == x. +bool IsInvolution(const NodeDef& node); + +// Returns true if the op preserves the order and value of elements +// and shape of its first input tensor. 
+bool IsValueAndOrderAndShapePreserving(const NodeDef& node); + +// Returns true if the op preserves the order and value of elements in its +// first input tensor and possible changes its shape. +bool IsValueAndOrderPreserving(const NodeDef& node); + +// Returns true if the op in node only rearranges the order of elements in its +// first input tensor and possible changes its shape. More precisely, this +// function returns true if the op commutes with all element-wise operations. +bool IsValuePreserving(const NodeDef& node); + +// Returns true if node is idempotent w.r.t. its first input, i.e. if +// Op(Op(x, y, z), y, z) = Op(x, y, z). +bool IsIdempotent(const NodeDef& node); + +bool IsUnaryElementWise(const NodeDef& node); + +// Returns true if we can find an opdef corresponding to the op of the node. +bool HasOpDef(const NodeDef& node); + +// Returns true if the op changes the scalar type of its first input elements +// and preserves the number of elements. +bool IsCastLike(const NodeDef& node); + +// Returns true if this op never forwards any of its inputs, i.e. always +// allocates buffers for its inputs. +bool NeverForwardsInputs(const NodeDef& node); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OP_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h new file mode 100644 index 00000000..2d079a5c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_ + +#include + +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kArithmeticOptimizer[] = "ArithmeticOptimizer"; + +// Optimize TF computations by reducing the arithmetic complexity required to +// run a model. 
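+//
+// Illustrative usage sketch (not part of the upstream TensorFlow header; it
+// assumes a populated GrapplerItem `item`, and passes a null Cluster as the
+// test utilities further down in this patch do):
+//
+//   GraphDef optimized_graph;
+//   ArithmeticOptimizer optimizer(RewriterConfig::ON);
+//   TF_RETURN_IF_ERROR(
+//       optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));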
+class ArithmeticOptimizer : public GraphOptimizer { + public: + ArithmeticOptimizer() + : opt_level_(RewriterConfig::ON), + options_(ArithmeticOptimizerOptions::Default(RewriterConfig::ON)) {} + + explicit ArithmeticOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level), + options_(ArithmeticOptimizerOptions::Default(opt_level)) {} + + ~ArithmeticOptimizer() override {} + + string name() const override { return "arithmetic_optimizer"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + friend class ArithmeticOptimizerTest; + + // Granular control for arithmetic optimizer stages + struct ArithmeticOptimizerOptions { + bool combine_add_to_addn = true; + bool convert_sqrt_div_to_rsqrt_mul = true; + bool dedup_computations = true; + bool fold_conjugate_into_transpose = true; + bool fold_multiply_into_conv = true; + bool fold_transpose_into_matmul = true; + bool fuse_squared_diff = true; + bool hoist_common_factor_out_of_aggregation = true; + bool hoist_cwise_unary_chains = true; + bool minimize_broadcasts = true; + bool optimize_max_or_min_of_monotonic = true; + bool remove_idempotent = true; + bool remove_identity_transpose = true; + bool remove_involution = true; + bool remove_logical_not = true; + bool remove_negation = true; + bool remove_redundant_bitcast = true; + bool remove_redundant_cast = true; + bool remove_redundant_reshape = true; + bool reduce_upsampling_dims = true; + bool reorder_cast_like_and_value_preserving = true; + bool replace_mul_with_tile = true; + bool replace_mul_with_square = true; + bool replace_pack_with_tile_reshape = true; + bool convert_pow = true; + bool convert_log1p = true; + bool convert_log_softmax = true; + bool convert_expm1 = true; + bool unary_ops_composition = true; + bool remove_stack_slice_same_axis = true; + bool simplify_aggregation = true; + bool simplify_embedding_lookup = true; + bool remove_cast_into_segment_reduction = true; + + // Choose which arithmetic optimizer stages will be enabled for a given + // optimization level by default. + static ArithmeticOptimizerOptions Default( + RewriterConfig::Toggle opt_level) { + ArithmeticOptimizerOptions options; + return options; + } + }; + + // Returns true if it is safe to dedup node from the graph. + bool CanDedup(const NodeDef& node) const; + + // Dedup redundant nodes in the graph. + void DedupComputations(); + + // Forward the control dependencies anchored on src_nodes to the target_nodes. + void ForwardControlDependencies(NodeDef* target_node, + const std::vector& src_nodes); + + // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse + // transposes. + absl::Status SimplifyArithmeticOps(bool can_use_shapes); + // Tries to simplify the expression that roots at `node` and replaces the uses + // of `node` to the simplified expression. Returns the name of the simplified + // tensor (e.g. "split:1") or an empty string if no simplification is + // performed. + // + // `node_map` stores the mapping from node names to NodeDef*, and will be + // updated according to the rewrite. + // + // `new_nodes` will be populated with the new nodes this function creates and + // updates. The caller can push these nodes into the simplification queue to + // optimize them further. + // + // TODO(jingyue): This interface is not suitable for optimizing nodes with + // multiple output tensors. 
We should pass in a tensor name instead of a
+  // NodeDef.
+  string TrySimplifyAndReplaceUses(const NodeDef* node,
+                                   SetVector<NodeDef*>* nodes_to_simplify);
+
+  RewriterConfig::Toggle opt_level_;
+  ArithmeticOptimizerOptions options_;
+
+  bool fetch_nodes_known_ = false;
+  std::unordered_set<string> nodes_to_preserve_;
+  std::unique_ptr<NodeMap> node_map_;
+  std::unique_ptr<GraphProperties> graph_properties_;
+  GraphDef* optimized_graph_ = nullptr;  // Not owned.
+  gtl::FlatSet<string> feed_nodes_;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
new file mode 100644
index 00000000..7955db31
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h
@@ -0,0 +1,289 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
+
+#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class ArithmeticOptimizerTest : public GrapplerTest {
+ protected:
+  // Optimize a graph using optimizer and prune all the nodes that no
+  // longer have any output consumers.
+  void OptimizeAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
+                        GraphDef* output) {
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    output->Clear();
+    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
+  }
+
+  // Run optimizer twice to make sure the rewrite is idempotent.
+  void DedupAndOptimizeTwiceAndPrune(GraphOptimizer* optimizer,
+                                     GrapplerItem* item, GraphDef* output) {
+    TF_EXPECT_OK(CommonSubgraphElimination().Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    output->Clear();
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    output->Clear();
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    output->Clear();
+    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
+  }
+
+  // Run optimizer twice to make sure the rewrite is idempotent.
+ void OptimizeTwice(GraphOptimizer* optimizer, GrapplerItem* item, + GraphDef* output) { + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + } + + // Run optimizer twice to make sure the rewrite is idempotent. + // Optionally run a constant folding pass before pruning. + void OptimizeTwiceAndPrune(GraphOptimizer* optimizer, GrapplerItem* item, + GraphDef* output, bool const_folding = false) { + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + + if (const_folding) { + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr) + .Optimize(nullptr, *item, output)); + } + + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); + } + + void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) { + optimizer->options_.combine_add_to_addn = false; + } + + void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.combine_add_to_addn = true; + } + + void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fold_conjugate_into_transpose = true; + } + + void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fold_multiply_into_conv = true; + } + + void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fold_transpose_into_matmul = true; + } + + void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.hoist_common_factor_out_of_aggregation = true; + } + + void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.minimize_broadcasts = true; + } + + void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_identity_transpose = true; + } + + void EnableOnlyRemoveInvolution(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_involution = true; + } + + void EnableOnlyRemoveRedundantBitcast(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_bitcast = true; + } + + void EnableOnlyRemoveRedundantCast(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_cast = true; + } + + void EnableOnlyReduceUpsamplingDims(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.reduce_upsampling_dims = true; + } + + void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_redundant_reshape = true; + } + + void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_negation = true; + } + + void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.reorder_cast_like_and_value_preserving = true; + } + + void EnableOnlyReplaceMulWithBroadcastByTile(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + 
optimizer->options_.replace_mul_with_tile = true; + } + + void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.replace_mul_with_square = true; + } + + void EnableOnlyReplacePackWithTileReshape(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.replace_pack_with_tile_reshape = true; + } + + void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.hoist_cwise_unary_chains = true; + } + + void EnableOnlySqrtDivToRsqrtMul(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true; + } + + void EnableOnlyLogSoftmax(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_log_softmax = true; + } + + void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_pow = true; + } + + void EnableOnlyFuseSquaredDiff(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.fuse_squared_diff = true; + } + + void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_idempotent = true; + } + + void EnableOnlyRemoveLogicalNot(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_logical_not = true; + } + + void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.simplify_aggregation = true; + } + + void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_log1p = true; + } + + void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.optimize_max_or_min_of_monotonic = true; + } + + void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.convert_expm1 = true; + } + + void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.unary_ops_composition = true; + } + + void EnableOnlyRemoveStackSliceSameAxis(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_stack_slice_same_axis = true; + } + + void EnableOnlySimplifyEmbeddingLookup(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.simplify_embedding_lookup = true; + } + + void EnableOnlyRemoveCastIntoSegmentReduction( + ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_cast_into_segment_reduction = true; + } + + private: + void DisableAllStages(ArithmeticOptimizer* optimizer) { + ArithmeticOptimizer::ArithmeticOptimizerOptions options; + options.dedup_computations = false; + options.combine_add_to_addn = false; + options.convert_sqrt_div_to_rsqrt_mul = false; + options.convert_pow = false; + options.convert_log1p = false; + options.optimize_max_or_min_of_monotonic = false; + options.fold_conjugate_into_transpose = false; + options.fold_multiply_into_conv = false; + options.fold_transpose_into_matmul = false; + options.hoist_common_factor_out_of_aggregation = false; + options.hoist_cwise_unary_chains = false; + options.minimize_broadcasts = false; + options.remove_identity_transpose = false; + options.remove_involution = false; + options.remove_idempotent = false; + 
options.remove_redundant_bitcast = false; + options.remove_redundant_cast = false; + options.remove_redundant_reshape = false; + options.remove_negation = false; + options.remove_logical_not = false; + options.reorder_cast_like_and_value_preserving = false; + options.replace_mul_with_tile = false; + options.replace_mul_with_square = false; + options.simplify_aggregation = false; + options.unary_ops_composition = false; + options.simplify_embedding_lookup = false; + options.remove_cast_into_segment_reduction = false; + optimizer->options_ = options; + } +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_mixed_precision.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_mixed_precision.h new file mode 100644 index 00000000..d4be8476 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_mixed_precision.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// CUDA: convert to float16 on GPU +// BF16: convert to bfloat16 on CPU +// CPU: emulate float16 on CPU without changing operator kernel +// FP16_CPU : convert to float16 on CPU +enum class AutoMixedPrecisionMode { CUDA, BF16, CPU, FP16_CPU }; + +// Convert data types to float16 or bfloat16 where appropriate to improve +// performance on GPUs or CPUs. +class AutoMixedPrecision : public GraphOptimizer { + public: + // If 'mode' is CUDA, converts nodes to float16 on Nvidia GPUs. If BF16 or + // FP16_CPU, converts nodes to bfloat16/fp16 on CPUs in order to take + // advantage of oneDNN performance improvements with bfloat16/fp16. + explicit AutoMixedPrecision( + AutoMixedPrecisionMode mode = AutoMixedPrecisionMode::CUDA) + : mode_(mode) {} + + ~AutoMixedPrecision() override {} + + string name() const override { + switch (mode_) { + case AutoMixedPrecisionMode::CUDA: + return "auto_mixed_precision"; + case AutoMixedPrecisionMode::BF16: + return "auto_mixed_precision_onednn_bfloat16"; + case AutoMixedPrecisionMode::CPU: + return "auto_mixed_precision_cpu"; + case AutoMixedPrecisionMode::FP16_CPU: + // Note: using different name than GPU for ease of debugging. 
+ return "auto_mixed_precision_onednn_float16"; + default: + LOG(FATAL) << "Invalid value for AutoMixedPrecisionMode: " // Crash Ok + << static_cast(mode_); + } + }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + private: + const AutoMixedPrecisionMode mode_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h new file mode 100644 index 00000000..37f3714c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -0,0 +1,600 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ + +#include + +#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/env_var.h" + +namespace tensorflow { +namespace grappler { + +// Represents the four lists of ops: the allow list, infer list, deny list, and +// clear list. These lists determine which ops are converted to fp16/bf16 +// (referred to as 'f16' for short) and which ops stay as fp32. +class AutoMixedPrecisionLists { + public: + virtual ~AutoMixedPrecisionLists() {} + + // Returns the set of ops that are considered numerically-safe (for execution + // in f16), performance-critical, and can run in f16. These ops are always + // converted to f16. + virtual gtl::FlatSet AllowList() = 0; + // Returns the set of ops that can run in f16 and are considered numerically- + // safe (for execution in f16), but which may be made unsafe by an upstream + // denylist op. + virtual gtl::FlatSet InferList() = 0; + // Returns the set of ops that are considered numerically-dangerous (i.e., + // unsafe for execution in f16) and whose effects may also be observed in + // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to + // the Exp). + virtual gtl::FlatSet DenyList() = 0; + // Returns the set of ops that do not have numerically-significant effects + // (i.e., they are always considered safe for execution in f16 precision), and + // can run in f16. + virtual gtl::FlatSet ClearList() = 0; + + protected: + // Adds or removes ops from list if certain environmental variables are set. + static void UpdateList(const string& list_name, gtl::FlatSet* list) { + CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" || // Crash OK. 
+ list_name == "DENYLIST" || list_name == "CLEARLIST" || + // TODO(reedwm): for bkwds compat; remove when no longer necessary: + list_name == "WHITELIST" || list_name == "GRAYLIST" || + list_name == "BLACKLIST"); + string add_env_var = + "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD"; + string remove_env_var = + "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE"; + string to_add, to_remove; + TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add)); + TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove)); + for (const auto& x : str_util::Split(to_add, ",")) { + list->insert(x); + } + for (const auto& x : str_util::Split(to_remove, ",")) { + list->erase(x); + } + } + + // Subclasses should include these on the ClearList. + static void AddTensorListOps(gtl::FlatSet* list) { + // Note: if a data structure op (such as TensorListPopBack) is added here, + // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified + // LINT.IfChange + constexpr const char* tensor_list_ops[] = { + "TensorListConcat", "TensorListConcatLists", + "TensorListConcatV2", "TensorListGather", + "TensorListGetItem", "TensorListPopBack", + "TensorListPushBack", "TensorListPushBackBatch", + "TensorListFromTensor", "TensorListScatter", + "TensorListScatterV2", "TensorListScatterIntoExistingList", + "TensorListSetItem", "TensorListSplit", + "TensorListStack"}; + // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc) + for (auto op : tensor_list_ops) { + list->insert(op); + } + } +}; + +class AutoMixedPrecisionListsFp16 : public AutoMixedPrecisionLists { + private: + static bool IsPseudoFastMath() { + string optimization_level; + TF_CHECK_OK( + ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", + &optimization_level)); + optimization_level = absl::AsciiStrToUpper(optimization_level); + return optimization_level == "TENSOR_CORES_ONLY"; + } + + public: + AutoMixedPrecisionListsFp16( + int cuda_version, int cudnn_version, + AutoMixedPrecisionMode mode = AutoMixedPrecisionMode::CUDA) + : cuda_version_(cuda_version), cudnn_version_(cudnn_version) { + if (mode == AutoMixedPrecisionMode::CUDA || + mode == AutoMixedPrecisionMode::CPU) { + // Note: this is not a typo here. use_cuda_ is set to true for the CPU + // intentionally to make CPU and GPU have the same fp16 ops. 
+ use_cuda_ = true; + use_onednn_ = false; + } else if (mode == AutoMixedPrecisionMode::FP16_CPU) { + use_onednn_ = true; + use_cuda_ = false; + } + } + + gtl::FlatSet AllowList() override { + auto list = gtl::FlatSet{ + "Conv2D", "Conv2DBackpropFilter", "Conv2DBackpropInput", "Einsum", + "MatMul", + }; + if (use_cuda_) { + list.insert("BlockLSTM"); + list.insert("BlockLSTMV2"); + list.insert("BlockLSTMGrad"); + list.insert("BlockLSTMGradV2"); + list.insert("CudnnRNN"); + list.insert("CudnnRNNBackprop"); + list.insert("CudnnRNNBackpropV2"); + list.insert("CudnnRNNBackpropV3"); + list.insert("CudnnRNNV2"); + list.insert("CudnnRNNV3"); + list.insert("FusedConv2DBiasActivation"); + list.insert("FusedSparseConvGpuV2"); + list.insert("GRUBlockCell"); + list.insert("GRUBlockCellGrad"); + list.insert("LSTMBlockCell"); + list.insert("LSTMBlockCellGrad"); + list.insert("Mha"); + list.insert("MhaV2"); + list.insert("Tmlp"); + list.insert("TmlpV2"); + list.insert("TmlpV3"); + list.insert("Pmlp"); + list.insert("FastUnsortedSegmentMax"); + list.insert("VoxelMax"); + } +#if TENSORFLOW_USE_ROCM + if (true) { +#else + if ((use_cuda_ && cuda_version_ >= 9010) || use_onednn_) { + // Fp16 BatchMatMul is slow before CUDA 9.1. +#endif + list.insert("BatchMatMul"); + list.insert("BatchMatMulV2"); + } + if ((use_cuda_ && cudnn_version_ >= 7602) || use_onednn_) { + // Fp16 3D conv is slow before CUDNN 7.6.2. + list.insert("Conv3D"); + list.insert("Conv3DBackpropFilter"); + list.insert("Conv3DBackpropFilterV2"); + list.insert("Conv3DBackpropInput"); + list.insert("Conv3DBackpropInputV2"); + } + if ((use_cuda_ && cudnn_version_ >= 8000) || use_onednn_) { + list.insert("DepthwiseConv2dNative"); + list.insert("DepthwiseConv2dNativeBackpropFilter"); + list.insert("DepthwiseConv2dNativeBackpropInput"); + } + UpdateList("ALLOWLIST", &list); + // For backwards compatibility, keeping the original env variable here. + // TODO(reedwm): This should be removed if we don't have active users. + UpdateList("WHITELIST", &list); + + return list; + } + + gtl::FlatSet InferList() override { + if (IsPseudoFastMath() && use_cuda_) { + return gtl::FlatSet{}; + } + + auto list = gtl::FlatSet{ + "Add", + "AddN", + "AddV2", + "AvgPool", + "AvgPool3D", + "AvgPool3DGrad", + "AvgPoolGrad", + "BiasAdd", + "BiasAddGrad", + "BiasAddV1", + "Elu", + "EluGrad", + "Erf", + "Erfc", + "FloorDiv", + "FusedBatchNormV2", + "FusedBatchNormGradV2", + "FusedBatchNormV3", + "FusedBatchNormGradV3", + "_FusedBatchNormEx", + "Inv", + "LeakyRelu", + "LeakyReluGrad", + "Log", + "Log1p", + "LogSoftmax", + "Mul", + "Prod", + "RealDiv", + "Reciprocal", + "Selu", + "SeluGrad", + "Sigmoid", + "SigmoidGrad", + "Softmax", + "Softplus", + "SoftplusGrad", + "Softsign", + "SoftsignGrad", + "Sqrt", + "Sub", + "Tanh", + "TanhGrad", + }; + if (use_onednn_) { + list.insert("Rsqrt"); + list.insert("Square"); + list.insert("SquaredDifference"); + } + UpdateList("INFERLIST", &list); + // For backwards compatibility, keeping the original env variable here. + // TODO(reedwm): This should be removed if we don't have active users. + UpdateList("GRAYLIST", &list); + return list; + } + + gtl::FlatSet DenyList() override { + if (IsPseudoFastMath() && use_cuda_) { + return gtl::FlatSet{}; + } + + auto list = gtl::FlatSet{ + "Exp", + "Expm1", + "L2Loss", + "Mean", + "Pow", + "SaveV2", + "SoftmaxCrossEntropyWithLogits", + "SparseSoftmaxCrossEntropyWithLogits", + "Sum", + }; + UpdateList("DENYLIST", &list); + // For backwards compatibility, keeping the original env variable here. 
+ // TODO(reedwm): This should be removed if we don't have active users. + UpdateList("BLACKLIST", &list); + return list; + } + + gtl::FlatSet ClearList() override { + if (IsPseudoFastMath() && use_cuda_) { + return gtl::FlatSet{}; + } + + auto list = gtl::FlatSet{ + "Abs", + "ArgMax", + "ArgMin", + "BatchToSpace", + "BatchToSpaceND", + "BroadcastTo", + "Ceil", + "CheckNumerics", + "ClipByValue", + "Concat", + "ConcatV2", + "DepthToSpace", + "DynamicPartition", + "DynamicStitch", + "Enter", + "EnsureShape", + "Equal", + "Exit", + "ExpandDims", + "Fill", + "Floor", + "Gather", + "GatherNd", + "GatherV2", + "Greater", + "GreaterEqual", + "Identity", + "IdentityN", + "IsFinite", + "IsInf", + "IsNan", + "Less", + "LessEqual", + "Max", + "MaxPool", + "MaxPool3D", + "MaxPool3DGrad", + "MaxPool3DGradGrad", + "MaxPoolGrad", + "MaxPoolGradGrad", + "MaxPoolGradGradV2", + "MaxPoolGradV2", + "MaxPoolV2", + "Maximum", + "Merge", + "Min", + "Minimum", + "MirrorPad", + "MirrorPadGrad", + "Neg", + "NextIteration", + "NotEqual", + "OneHot", + "OnesLike", + "Pack", + "Pad", + "PadV2", + "PreventGradient", + "Rank", + "Relu", + "Relu6", + "Relu6Grad", + "ReluGrad", + "Reshape", + "ResizeNearestNeighbor", + "ResizeNearestNeighborGrad", + "Reverse", + "ReverseSequence", + "ReverseV2", + "Round", + "Select", + "SelectV2", + "Shape", + "ShapeN", + "Sign", + "Size", + "Slice", + "Snapshot", + "SpaceToBatch", + "SpaceToBatchND", + "SpaceToDepth", + "Split", + "SplitV", + "Squeeze", + "StopGradient", + "StridedSlice", + "StridedSliceGrad", + "Switch", + "Tile", + "TopK", + "TopKV2", + "Transpose", + "Unpack", + "Where", + "ZerosLike", + }; + AddTensorListOps(&list); + UpdateList("CLEARLIST", &list); + return list; + } + + private: + int cuda_version_; + int cudnn_version_; + bool use_cuda_; + bool use_onednn_; +}; + +// TODO(reedwm): Remove this alias. Some Google-internal code still uses the +// AutoMixedPrecisionListsCuda name. +using AutoMixedPrecisionListsCuda = AutoMixedPrecisionListsFp16; + +class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { + public: + AutoMixedPrecisionListsMkl() {} + + // Only ops which are supported by MKL in bfloat16 should be added to the + // allow list, infer list, or clear list. + gtl::FlatSet AllowList() override { + auto list = gtl::FlatSet{"Conv2D", + "Conv2DBackpropFilter", + "Conv2DBackpropInput", + "Conv3D", + "Conv3DBackpropFilterV2", + "Conv3DBackpropInputV2", + "DepthwiseConv2dNative", + "DepthwiseConv2dNativeBackpropFilter", + "DepthwiseConv2dNativeBackpropInput", + "MatMul", + "FusedPadConv2D", + "BatchMatMul", + "BatchMatMulV2", + "Einsum"}; + + UpdateList("ALLOWLIST", &list); + // For backwards compatibility, keeping the original env variable here. + // TODO(reedwm): This should be removed if we don't have active users. 
+ UpdateList("WHITELIST", &list); + return list; + } + + gtl::FlatSet InferList() override { + auto list = gtl::FlatSet{"Add", + "AddN", + "AddV2", + "AvgPool", + "AvgPool3D", + "AvgPool3DGrad", + "AvgPoolGrad", + "BiasAdd", + "BiasAddGrad", + "BiasAddV1", + "Erf", + "Erfc", + "FusedBatchNormV2", + "FusedBatchNormGradV2", + "FusedBatchNormV3", + "FusedBatchNormGradV3", + "LeakyRelu", + "LeakyReluGrad", + "Mul", + "Sub", + "Elu", + "EluGrad", + "FloorDiv", + "_FusedBatchNormEx", + "Inv", + "Log", + "Log1p", + "LogSoftmax", + "Mean", + "Prod", + "RealDiv", + "Reciprocal", + "Rsqrt", + "Selu", + "SeluGrad", + "Sigmoid", + "SigmoidGrad", + "Softmax", + "Softplus", + "SoftplusGrad", + "Softsign", + "SoftsignGrad", + "Sqrt", + "Square", + "SquaredDifference", + "Sum", + "Tanh", + "TanhGrad"}; + UpdateList("INFERLIST", &list); + // For backwards compatibility, keeping the original env variable here. + // TODO(reedwm): This should be removed if we don't have active users. + UpdateList("GRAYLIST", &list); + return list; + } + + gtl::FlatSet DenyList() override { + auto list = gtl::FlatSet{ + "Exp", + "Expm1", + "L2Loss", + "Pow", + "SaveV2", + "SoftmaxCrossEntropyWithLogits", + "SparseSoftmaxCrossEntropyWithLogits", + }; + UpdateList("DENYLIST", &list); + // For backwards compatibility, keeping the original env variable here. + // TODO(reedwm): This should be removed if we don't have active users. + UpdateList("BLACKLIST", &list); + return list; + } + + gtl::FlatSet ClearList() override { + auto list = gtl::FlatSet{ + "Abs", + "ArgMax", + "ArgMin", + "BatchToSpace", + "BatchToSpaceND", + "BroadcastTo", + "Ceil", + "CheckNumerics", + "ClipByValue", + "Concat", + "ConcatV2", + "DepthToSpace", + "DynamicPartition", + "DynamicStitch", + "EnsureShape", + "Enter", + "Equal", + "Exit", + "ExpandDims", + "Fill", + "Floor", + "Gather", + "GatherNd", + "GatherV2", + "Greater", + "GreaterEqual", + "Identity", + "IdentityN", + "IsFinite", + "IsInf", + "IsNan", + "Less", + "LessEqual", + "Max", + "Maximum", + "MaxPool", + "MaxPool3D", + "MaxPool3DGrad", + "MaxPoolGrad", + "MaxPoolGradGrad", + "MaxPoolGradGradV2", + "MaxPoolGradV2", + "MaxPoolV2", + "Merge", + "Min", + "Minimum", + "MirrorPad", + "MirrorPadGrad", + "Neg", + "NextIteration", + "NotEqual", + "OnesLike", + "Pack", + "Pad", + "PadV2", + "PreventGradient", + "Rank", + "Relu", + "Relu6", + "Relu6Grad", + "ReluGrad", + "Reshape", + "ResizeNearestNeighbor", + "ResizeNearestNeighborGrad", + "ResizeBilinear", + "Reverse", + "ReverseSequence", + "ReverseV2", + "Round", + "ScatterNd", + "Select", + "SelectV2", + "Shape", + "ShapeN", + "Sign", + "Slice", + "Snapshot", + "SpaceToBatch", + "SpaceToBatchND", + "SpaceToDepth", + "Split", + "SplitV", + "Squeeze", + "StatelessWhile", + "StopGradient", + "StridedSlice", + "StridedSliceGrad", + "Switch", + "Tile", + "TopK", + "TopKV2", + "Transpose", + "Where", + "While", + "Unpack", + "ZerosLike", + }; + AddTensorListOps(&list); + UpdateList("CLEARLIST", &list); + return list; + } +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_parallel.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_parallel.h new file mode 100644 index 00000000..ae063864 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/auto_parallel.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ + +#include "tensorflow/core/framework/variable.pb.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// Automatically parallelize a graph by splitting in the batch dimension. +class AutoParallel : public GraphOptimizer { + public: + AutoParallel(int num_replicas) : num_replicas_(num_replicas) { + CHECK(num_replicas_ >= 2); + } + ~AutoParallel() override {} + + string name() const override { return "autoparallel"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + private: + GraphDef graph_; + std::map all_nodes_; + std::set apply_gradients_nodes_; + std::set replica_nodes_; + std::set shared_nodes_; + const GrapplerItem* item_; + int num_replicas_; + int num_gpus_; + absl::Status Initialize(const GrapplerItem& item); + NodeDef* AddNodeDivConst(); + NodeDef* AddNodeDiv(const string& name, const string& input_a, + const string& input_b); + NodeDef* AddNodeControl(const string& name, const std::set& deps, + GraphDef* graph); + bool NotSharedNode(const string& name); + void AddSharedNodes(GraphDef* graph); + void AddOneReplica(GraphDef* graph, int number); + void BuildGraph(GraphDef* graph); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h new file mode 100644 index 00000000..2ec80e88 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/hash.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize TF computations by deduping equivalent subgraphs. +class Cluster; +struct GrapplerItem; + +class CommonSubgraphElimination : public GraphOptimizer { + public: + CommonSubgraphElimination() {} + + explicit CommonSubgraphElimination(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + + ~CommonSubgraphElimination() override {} + + string name() const override { return "common_subgraph_elimination"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + friend class CommonSubgraphEliminationTest; + + // Returns true if it is safe to dedup node from the graph. + bool CanDedup(const NodeDef& node) const; + + // Dedup redundant nodes in the graph. + absl::Status DedupComputations(GraphDef* optimized_graph); + + RewriterConfig::Toggle opt_level_; + + bool fetch_nodes_known_ = false; + std::unordered_set nodes_to_preserve_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/constant_folding.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/constant_folding.h new file mode 100644 index 00000000..9c58f81e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/constant_folding.h @@ -0,0 +1,360 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
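// Illustrative sketch, not part of the vendored header above: the core idea of
// CommonSubgraphElimination is to collapse nodes that have the same op and the
// same inputs onto one canonical node and rewire every duplicate. ToyNode and
// DedupNodes are invented names; the real pass works on NodeDef/GraphDef,
// hashes attributes and control inputs as well, and respects CanDedup().
#include <map>
#include <string>
#include <utility>
#include <vector>

struct ToyNode {
  std::string name;
  std::string op;
  std::vector<std::string> inputs;
};

// Maps each duplicate node name to the canonical node that replaces it.
static std::map<std::string, std::string> DedupNodes(
    const std::vector<ToyNode>& nodes) {
  std::map<std::pair<std::string, std::vector<std::string>>, std::string> seen;
  std::map<std::string, std::string> replacements;
  for (const ToyNode& node : nodes) {
    auto key = std::make_pair(node.op, node.inputs);
    auto [it, inserted] = seen.emplace(key, node.name);
    if (!inserted) replacements[node.name] = it->second;
  }
  return replacements;
}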
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +const char kConstantFoldingConst[] = "ConstantFolding"; +const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl"; +extern const int64_t kMaxConstantSize; + +// Constant folding optimization for a graph. +class ConstantFolding : public GraphOptimizer { + public: + // The size limit will only be considered if the newly created node is greater + // than original_size (optional). + static absl::Status CreateNodeDef(const string& name, + const TensorValue& tensor, NodeDef* node, + size_t original_size = 0); + static string AddControlDependency(const string& input_name, GraphDef* graph, + NodeMap* node_map); + + explicit ConstantFolding(DeviceBase* cpu_device, + bool disable_compressed_tensor_optimization = false, + bool fold_quantization_emulation = true); + ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device, + bool disable_compressed_tensor_optimization = false, + bool fold_quantization_emulation = true); + + ~ConstantFolding() override {} + + string name() const override { return "constant_folding"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + private: + bool ForwardInputs(NodeDef* node, absl::Span inputs_to_forward); + string OptimizedNodeName(const NodeDef& node, absl::string_view suffix) const; + bool OptimizedNodeExists(const NodeDef& node, absl::string_view suffix) const; + + bool IsReallyConstant(const NodeDef& node) const; + + bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor); + + absl::Status MaterializeShapes(const GraphProperties& properties); + + absl::Status MaterializeBroadcastGradientArgs( + const NodeDef& node, const GraphProperties& properties); + absl::Status MaterializeReductionIndices(NodeDef* node, + const GraphProperties& properties); + absl::Status MaterializeConstantValuedNode(NodeDef* node, + const GraphProperties& properties); + absl::Status MaterializeOutputValues(NodeDef* node, + const GraphProperties& properties); + absl::Status MaterializeConstants(const GraphProperties& properties); + + bool IsFoldable(const NodeDef& node, const GraphProperties* properties); + bool IsFoldableUncached(const NodeDef& node, + const GraphProperties* properties) const; + bool MaybeFoldable(const NodeDef& node, + const GraphProperties* properties) const; + + absl::Status EvaluateNode( + const NodeDef& node, const absl::InlinedVector& inputs, + absl::InlinedVector* output) const; + + absl::Status EvaluateOneFoldable(const NodeDef& node, + std::vector* outputs, + bool* result_too_large); + + absl::Status FoldMergeNode(NodeDef* node, GraphDef* output_graph); + absl::Status FoldNode(NodeDef* node, 
GraphDef* output_graph, + bool* result_too_large); + + bool IsOnes(const NodeDef& node) const; + bool IsZeros(const NodeDef& node) const; + bool ReplaceOperationWithBroadcastTo(int input_to_broadcast, + const GraphProperties& properties, + NodeDef* node, GraphDef* graph); + void ReplaceOperationWithIdentity(int input_to_forward, + const GraphProperties& properties, + NodeDef* node, GraphDef* graph); + void ReplaceOperationWithSnapshot(int input_to_forward, + const GraphProperties& properties, + NodeDef* node, GraphDef* graph); + void ReplaceOperationWithNoOp(NodeDef* node, GraphProperties* properties, + GraphDef* graph); + void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast, + const GraphProperties& properties, + NodeDef* node, GraphDef* graph); + void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); + absl::Status ReplaceOperationWithConstant(double value, + const GraphProperties& properties, + const TensorShapeProto& shape, + NodeDef* node, GraphDef* graph); + + // Notice: Destroys *value. + absl::Status ReplaceOperationWithConstantTensor(DataType dtype, + TensorProto* value, + NodeDef* node, + GraphDef* graph); + + void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); + absl::Status FoldGraph(const GraphProperties& properties, GraphDef* output, + absl::flat_hash_set* nodes_to_not_simplify); + + absl::Status IsSimplifiableReshape(const NodeDef& node, + const GraphProperties& properties) const; + absl::Status SimplifyGraph( + GraphDef* optimized_graph, GraphProperties* properties, + absl::flat_hash_set* nodes_to_not_simplify); + absl::Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph, + GraphProperties* properties); + + absl::Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item, + GraphProperties* properties, + GraphDef* optimized_graph); + + // Applies partial constant folding for Concat which is not commutative. + // Returns true if the transformation applied successfully. + bool PartialConcatConstFolding(GraphDef* optimized_graph, + GraphProperties* properties, NodeDef* node); + + // Applies partial constant folding for associative operators AddN and + // AccumulateNV2. Returns true if the transformation applied successfully. + bool PartialAssocOpConstFolding(GraphDef* optimized_graph, + GraphProperties* properties, NodeDef* node); + + // Applies partial constant propagation through IdentityN operator. + // Returns true if the transformation applied successfully. + bool PartialConstPropThroughIdentityN(NodeDef* node); + + struct ConstantPushDownContext { + NodeDef* op_child; + NodeDef* const_child; + bool left_child_is_const; + bool right_child_is_const; + NodeDef* left_leaf; + NodeDef* right_leaf; + bool left_leaf_is_const; + bool right_leaf_is_const; + + // Shape & type information. + const std::vector* parent_input_props; + const std::vector* op_child_input_props; + }; + + // Populates ctx with pointers to the nodes in expression tree for which + // constant pushdown optimization is being considered, corresponding to one of + // the following configurations: + // + // parent parent + // / \ / \ + // op_child const_child const_child op_child + // / \ / \ + // left_leaf right_leaf left_leaf right_leaf + // + // Returns true if the expression is possible amenable for optimization. + // Returns false if must_have_properties is true and input properties for + // parent and op_child are not known. 
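// Illustrative sketch, not part of the vendored header above: a toy version of
// the constant push-down rearrangement documented in the comment above. For an
// associative op, (x + c1) + c2 is rebuilt as x + (c1 + c2) so the constants
// end up adjacent and can later be folded into a single constant. Expr and
// PushDownConstant are invented names; the real pass operates on NodeDef trees
// and also handles '-', '*' and '/' with the appropriate sign and reciprocal
// bookkeeping.
#include <memory>
#include <optional>
#include <utility>

struct Expr {
  std::optional<double> constant;  // engaged when this node is a literal
  char op = 0;                     // '+' is the only op in this toy example
  std::unique_ptr<Expr> lhs, rhs;
};

// Rewrites parent = ((x + c1) + c2) into x + (c1 + c2); returns true if the
// pattern matched and the tree was modified.
static bool PushDownConstant(Expr* parent) {
  if (parent->op != '+' || !parent->rhs || !parent->rhs->constant) return false;
  Expr* op_child = parent->lhs.get();
  if (op_child == nullptr || op_child->op != '+' || !op_child->rhs ||
      !op_child->rhs->constant) {
    return false;
  }
  const double merged = *op_child->rhs->constant + *parent->rhs->constant;
  std::unique_ptr<Expr> x = std::move(op_child->lhs);  // detach the variable leaf
  parent->lhs = std::move(x);      // releases the old op_child subtree (and c1)
  parent->rhs->constant = merged;  // the former c2 node now holds c1 + c2
  return true;
}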
+ bool PrepareConstantPushDown(const NodeDef& parent, + const GraphProperties& properties, + bool must_have_properties, + ConstantPushDownContext* ctx) const; + + // Pushes down constants on '+', '-', '*', and '/' operators if applicable. + // Returns true if the transformation applied successfully. + bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph, + NodeDef* node); + + // Pushes down constants on '+' and 'BiasAdd' operators if applicable. + // Returns true if the graph was modified. + bool ConstantPushDownBiasAdd(GraphProperties* properties, + GraphDef* optimized_graph, NodeDef* node); + + // Aggregate constants present around a conv operator. Returns true if the + // transformation was applied successfully. + bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node, + const GraphProperties& properties); + + // Strength reduces floating point division by a constant Div(x, const) to + // multiplication by the reciprocal Mul(x, Reciprocal(const)). + bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node); + + // Simplifies arithmetic operations with ones or zeros. Returns the status, + // and updates the success input argument that denotes if any simplification + // was applied. + absl::Status SimplifyArithmeticOperations(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, + NodeDef* node); + + // Simplifies a Reshape operation to an Identity operation if applicable. + bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info, + NodeDef* node); + + // Returns true iff the node is a reduction and its reduction indices are + // constant. Sets *indices_is_empty to true if the set of dimensions to reduce + // along is empty (this happens often in the gradient graphs). + bool IsReductionWithConstantIndices(const NodeDef& node, + bool* indices_is_empty) const; + // Returns true if theres a possibility that a Reduce node could be simplified + // to an Identity/Reshape. + bool IsReductionCandidateForSimplification( + const NodeDef& node, const GraphProperties& properties, + TensorShapeProto* input_tensor_shape, + TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const; + // Returns true iff this reduction can be reduced to an identity (i.e if the + // input dimensions to reduce along are all of size 1 and keep_dims is true). + bool IsReductionSimplifiableToIdentity( + const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims, + const absl::InlinedVector& reduction_indices_vector) + const; + // Changes a reduction into an Identity op, returning true on success. + bool ReplaceReductionWithIdentity(NodeDef* node) const; + + // Simplifies a Reduction operation to an Identity/Reshape operation if + // applicable. + bool SimplifyReduction(GraphDef* optimized_graph, + const GraphProperties& properties, NodeDef* node); + + // Switch(x, x) will always feed false to its false branch and true to + // its true branch. By rewriting the graph a bit, we can propagate these + // constants down the two output branches, and just use control dependencies + // to trigger the selected one at runtime. For example, + // + // +------+ + // x-->|Switch|-->a (in practice there may be multiple consumers of each + // x-->| |-->b output branch.) 
+ // +------+ + // + // Is rewritten as + // + // +------+ + // x-->|Switch|-->Identity--^>Const(false)-->a + // x-->| |-->Identity--^>Const(true)-->b + // +------+ + bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node); + + // Moves constants past Enter node if applicable. + bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node); + + // Simplifies Pack operation if applicable. + bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node); + + // Simplifies a Squeeze operation to an Identity operation if applicable. + void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node); + + // Simplifies a Pad operation to an Identity operation if applicable. + absl::Status SimplifyPad(const GraphProperties& properties, + bool use_shape_info, GraphDef* optimized_graph, + NodeDef* node); + + // Simplifies a Tile operation to an Identity operation if applicable. + absl::Status SimplifyTile(const GraphProperties& properties, + bool use_shape_info, GraphDef* optimized_graph, + NodeDef* node); + + // Simplifies a StridedSlice operation to an Identity operation if applicable. + absl::Status SimplifyStridedSlice(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node); + + // Simplifies a Slice operation to an Identity operation if applicable. + absl::Status SimplifySlice(const GraphProperties& properties, + bool use_shape_info, GraphDef* optimized_graph, + NodeDef* node); + + // Simplify a Case operation where the output_idx is known. + bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node); + + // Simplify a Select operation where the predicates are all true or all false. + bool SimplifySelect(const GraphProperties& properties, + GraphDef* optimized_graph, NodeDef* node); + + // Replaces variable updates that are effectively no-ops with NoOp nodes. + void RemoveRedundantVariableUpdates(GraphProperties* properties, + GraphDef* optimized_graph, NodeDef* node); + + // Removes Reverse op over dimensions with size 1. + absl::Status RemoveReverse(const GraphProperties& properties, + bool use_shape_info, GraphDef* optimized_graph, + NodeDef* node); + + // Removes RandomShuffle op if it is scalar or first dimension is of size 1. + void RemoveRandomShuffle(const GraphProperties& properties, + bool use_shape_info, GraphDef* optimized_graph, + NodeDef* node); + + // Removes Shuffle or Transpose op over dimensions of size 1. + absl::Status RemoveShuffleOrTranspose(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, + NodeDef* node); + + // Removes Split or SplitV node if possible. + void RemoveSplitOrSplitV(const GraphProperties& properties, + GraphDef* optimized_graph, NodeDef* node); + + bool GetConcatAxis(const NodeDef& node, int* axis); + bool MergeConcat(bool use_shape_info, GraphProperties* properties, + GraphDef* optimized_graph, NodeDef* node); + + absl::Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node, + GraphDef* optimized_graph); + + // Points to an externally provided device or to owned_device_; + RewriterConfig::Toggle opt_level_; + DeviceBase* cpu_device_; + std::unique_ptr owned_device_; + + std::unique_ptr resource_mgr_; + GraphDef* graph_; + std::unique_ptr node_map_; + std::unordered_set nodes_to_preserve_; + // TODO(rmlarsen): Could these be keyed on absl::string_view? 
+ absl::flat_hash_set nodes_allowlist_; + absl::flat_hash_set feed_nodes_; + absl::flat_hash_map maybe_foldable_nodes_; + bool has_fetch_; + bool graph_modified_; + bool graph_contains_assign_or_inplace_op_; + bool disable_compressed_tensor_optimization_; + bool fold_quantization_emulation_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CONSTANT_FOLDING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h new file mode 100644 index 00000000..beb6bd09 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// A custom optimizer that can be registered. +class CustomGraphOptimizer : public GraphOptimizer { + public: + virtual ~CustomGraphOptimizer() {} + virtual absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config = + nullptr) = 0; + // Populates ConfigProto on which the Session is run prior to running Init. + absl::Status InitWithConfig( + const ConfigProto& config_proto, + const tensorflow::RewriterConfig_CustomGraphOptimizer* config = nullptr) { + config_proto_ = config_proto; + return this->Init(config); + } + + ConfigProto config_proto_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h new file mode 100644 index 00000000..67dff162 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// Contains plugin's configurations for each Grappler optimizer (on/off). +// See tensorflow/core/protobuf/rewriter_config.proto for optimizer description. +struct ConfigList { + ConfigList() {} + ConfigList(bool disable_model_pruning, + std::unordered_map config) + : disable_model_pruning(disable_model_pruning), + toggle_config(std::move(config)) {} + + bool operator==(const ConfigList& other) const { + return (disable_model_pruning == other.disable_model_pruning) && + (toggle_config == other.toggle_config); + } + bool disable_model_pruning; // Don't remove unnecessary ops from the graph. + std::unordered_map toggle_config; +}; + +class CustomGraphOptimizerRegistry { + public: + static std::unique_ptr CreateByNameOrNull( + const string& name); + + static std::vector GetRegisteredOptimizers(); + + typedef std::function Creator; + // Register graph optimizer which can be called during program initialization. + // This class is not thread-safe. + static void RegisterOptimizerOrDie(const Creator& optimizer_creator, + const string& name); +}; + +class CustomGraphOptimizerRegistrar { + public: + explicit CustomGraphOptimizerRegistrar( + const CustomGraphOptimizerRegistry::Creator& creator, + const string& name) { + CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(creator, name); + } +}; + +#define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \ + namespace { \ + static ::tensorflow::grappler::CustomGraphOptimizerRegistrar \ + MyCustomGraphOptimizerClass##_registrar( \ + []() { return new MyCustomGraphOptimizerClass; }, (name)); \ + } // namespace + +#define REGISTER_GRAPH_OPTIMIZER(MyCustomGraphOptimizerClass) \ + REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, \ + #MyCustomGraphOptimizerClass) + +// A separate registry to register all plug-in CustomGraphOptimizers. +class PluginGraphOptimizerRegistry { + public: + // Constructs a list of plug-in CustomGraphOptimizers from the global map + // `registered_plugin_optimizers`. + static std::vector> CreateOptimizers( + const std::set& device_types); + + typedef std::function Creator; + + // Returns plugin's config. If any of the config is turned off, the returned + // config will be turned off. + static ConfigList GetPluginConfigs(bool use_plugin_optimizers, + const std::set& device_types); + + // Registers plugin graph optimizer which can be called during program + // initialization. Dies if multiple plugins with the same `device_type` are + // registered. This class is not thread-safe. + static void RegisterPluginOptimizerOrDie(const Creator& optimizer_creator, + const std::string& device_type, + ConfigList& configs); + + // Prints plugin's configs if there are some conflicts. + static void PrintPluginConfigsIfConflict( + const std::set& device_types); + + // Returns true when `plugin_config` conflicts with `user_config`: + // - Plugin's `disable_model_pruning` is not equal to `user_config`'s, or + // - At least one of plugin's `toggle_config`s is on when it is set to off in + // `user_config`'s. 
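// Illustrative sketch, not part of the vendored header above: how a client
// would plug into the registry declared here. MyNoOpOptimizer is an invented
// example class; the overridden signatures are taken from the headers in this
// diff, and the registration name is the stringified class name per the
// REGISTER_GRAPH_OPTIMIZER macro above.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"

namespace tensorflow {
namespace grappler {

class MyNoOpOptimizer : public CustomGraphOptimizer {
 public:
  string name() const override { return "my_noop_optimizer"; }
  bool UsesFunctionLibrary() const override { return false; }
  absl::Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return absl::OkStatus();
  }
  absl::Status Optimize(Cluster* cluster, const GrapplerItem& item,
                        GraphDef* output) override {
    *output = item.graph;  // pass the graph through unchanged
    return absl::OkStatus();
  }
};

// Afterwards CreateByNameOrNull("MyNoOpOptimizer") returns a fresh instance.
REGISTER_GRAPH_OPTIMIZER(MyNoOpOptimizer);

}  // end namespace grappler
}  // end namespace tensorflow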
+ static bool IsConfigsConflict(ConfigList& user_config, + ConfigList& plugin_config); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/auto_shard.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/auto_shard.h new file mode 100644 index 00000000..400ace5f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/auto_shard.h @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_ + +#include +#include + +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +class AutoShard : public TFDataOptimizerBase { + public: + AutoShard() = default; + ~AutoShard() override = default; + + string name() const override { return "tf_auto_shard"; } + + bool UsesFunctionLibrary() const override { return true; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override; + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + int64_t num_workers_; + int64_t num_replicas_; + int64_t index_; + tensorflow::data::AutoShardPolicy auto_shard_policy_; +}; + +// For testing only +namespace internal { +bool IsEligibleRewriteBatchSize(const NodeDef& sink_node, + const MutableGraphView& graph, + std::vector* ineligible_reason); +} + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTO_SHARD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h new file mode 100644 index 00000000..0860ba50 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/autotune_buffer_sizes.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
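// Illustrative sketch, not part of the vendored header above: the element-level
// effect that the AutoShard rewrite arranges at graph level under data-based
// sharding. Worker `index` out of `num_workers` keeps every element whose
// position is congruent to `index` modulo `num_workers`. ShardPositions is an
// invented helper for illustration only.
#include <cstdint>
#include <vector>

static std::vector<int64_t> ShardPositions(int64_t num_elements,
                                           int64_t num_workers, int64_t index) {
  std::vector<int64_t> kept;
  for (int64_t i = 0; i < num_elements; ++i) {
    if (i % num_workers == index) kept.push_back(i);
  }
  return kept;
}
// Example: ShardPositions(10, 4, 1) keeps positions {1, 5, 9}.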
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTOTUNE_BUFFER_SIZES_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTOTUNE_BUFFER_SIZES_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization does the following: +// +// 1. Adds `prefetch(AUTOTUNE)` after all asynchronous tf.data transformations +// (e.g. parallel batch, parallel map, parallel interleave, and map + batch) if +// they are not followed by a `prefetch` yet. +// +// 2. If there exists any `prefetch(buffer_size=N)` for `N>=0`, it will replace +// the transformation with autotunable version of `prefetch` which uses N as +// the minimum size of the buffer. +class AutotuneBufferSizes : public TFDataOptimizerBase { + public: + AutotuneBufferSizes() = default; + ~AutotuneBufferSizes() override = default; + + string name() const override { return "autotune_buffer_sizes"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return absl::InvalidArgumentError( + absl::StrCat("Received an invalid value for parameter ", kAutotune, + ": ", autotune)); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_AUTOTUNE_BUFFER_SIZES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/batch_parallelization.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/batch_parallelization.h new file mode 100644 index 00000000..2e77dea0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/batch_parallelization.h @@ -0,0 +1,65 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
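// Illustrative sketch, not part of the vendored header above: how the
// "autotune" parameter read by Init() is typically supplied through the
// rewriter config proto. MakeAutotuneConfig is an invented helper; the
// accessors are the standard generated ones for
// RewriterConfig.CustomGraphOptimizer. Since Init() calls
// parameter_map().at(kAutotune), a config that is passed in is expected to
// carry the "autotune" key with the string "true" or "false".
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

tensorflow::RewriterConfig_CustomGraphOptimizer MakeAutotuneConfig(bool on) {
  tensorflow::RewriterConfig_CustomGraphOptimizer config;
  config.set_name("autotune_buffer_sizes");
  (*config.mutable_parameter_map())["autotune"].set_s(on ? "true" : "false");
  return config;
}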
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_BATCH_PARALLELIZATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_BATCH_PARALLELIZATION_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization parallelizes BatchDataset. +class BatchParallelization : public TFDataOptimizerBase { + public: + BatchParallelization() = default; + ~BatchParallelization() override = default; + + string name() const override { return "batch_parallelization"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_BATCH_PARALLELIZATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.h new file mode 100644 index 00000000..977b0c5d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/disable_intra_op_parallelism.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_INTRA_OP_PARALLELISM_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_INTRA_OP_PARALLELISM_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This optimization sets intra-op parallelism to be 1. 
+class DisableIntraOpParallelism : public TFDataOptimizerBase { + public: + DisableIntraOpParallelism() = default; + ~DisableIntraOpParallelism() override = default; + + string name() const override { return "disable_intra_op_parallelism"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_INTRA_OP_PARALLELISM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h new file mode 100644 index 00000000..3aded258 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/disable_prefetch_legacy_autotune.h @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization disables the lagacy autotune option for PrefetchDataset. 
+class DisablePrefetchLegacyAutotune : public TFDataOptimizerBase { + public: + DisablePrefetchLegacyAutotune() = default; + ~DisablePrefetchLegacyAutotune() override = default; + + string name() const override { return "disable_prefetch_legacy_autotune"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_DISABLE_PREFETCH_LEGACY_AUTOTUNE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.h new file mode 100644 index 00000000..35c333c0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/enable_gradient_descent.h @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_GRADIENT_DESCENT_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_GRADIENT_DESCENT_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization enables Gradient Descent Optimization in `ModelDataset`. 
+class EnableGradientDescent : public TFDataOptimizerBase { + public: + EnableGradientDescent() = default; + ~EnableGradientDescent() override = default; + + string name() const override { return "enable_gradient_descent"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_ENABLE_GRADIENT_DESCENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/filter_fusion.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/filter_fusion.h new file mode 100644 index 00000000..757f7557 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/filter_fusion.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This optimization fuses filter transformations. +class FilterFusion : public TFDataOptimizerBase { + public: + FilterFusion() = default; + ~FilterFusion() override = default; + + string name() const override { return "filter_fusion"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/filter_parallelization.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/filter_parallelization.h new file mode 100644 index 00000000..63f75907 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/filter_parallelization.h @@ -0,0 +1,65 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_PARALLELIZATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_PARALLELIZATION_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization parallelizes FilterDataset when function is stateless. +class FilterParallelization : public TFDataOptimizerBase { + public: + FilterParallelization() = default; + ~FilterParallelization() override = default; + + string name() const override { return "filter_parallelization"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_PARALLELIZATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/function_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/function_utils.h new file mode 100644 index 00000000..06034636 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/function_utils.h @@ -0,0 +1,132 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace grappler { +namespace function_utils { +// This namespace contains utility functions for querying and modifying +// FunctionDefs. + +// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings +// have the format node_name:node_output:position (if they derive from nodes), +// or input_name (if they derive from an argument). +struct FunctionDefTensorDesc { + FunctionDefTensorDesc() = default; + + FunctionDefTensorDesc(const string& node_name, const string& output, + int position); + + // Parses node_name:node_output:position string into its components. + explicit FunctionDefTensorDesc(const string& input); + + // TODO(rachelim): Add provisions to deal with special formats, like how + // GrapplerFunctionItem expands node output range if position is not defined + string full_str; + string node_name; + string node_output; + int position = -1; +}; + +// Replaces all references to `from` tensor in func's nodes' inputs and retvals +// to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`. +void ReplaceReferences(const string& from, const string& to, FunctionDef* func); + +// Adds a function output to the function def, ensuring that the output key +// is unique, and maps to output_tensor_name in the ret dict. +void AddFunctionOutputWithUniqueName(absl::string_view prefix, + absl::string_view output_tensor_name, + FunctionDef* fdef, DataType dtype); + +// Adds an input to a FunctionDef. +OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef, + DataType dtype); + +// Adds a node to a FunctionDef. +NodeDef* AddNode(absl::string_view name, absl::string_view op, + const std::vector& inputs, + const std::vector>& attributes, + FunctionDef* fd); + +// Checks whether the function contains a node with the given name. +bool ContainsFunctionNodeWithName(absl::string_view name, + const FunctionDef& function); + +// Checks whether the function contains a node with the given op. +bool ContainsFunctionNodeWithOp(absl::string_view op, + const FunctionDef& function); + +// Checks whether the function contains an output with the given name. +bool ContainsFunctionOutputWithName(absl::string_view name, + const FunctionDef& function); + +// Returns the index of the function input with the given name or -1 if the +// function node does not exist. +int FindFunctionInputWithName(absl::string_view name, + const FunctionDef& function); + +// Returns the index of the function output with the given name or -1 if the +// function node does not exist. +int FindFunctionOutputWithName(absl::string_view name, + const FunctionDef& function); + +// Returns the index of the function node with the given name or -1 if the +// function node does not exist. 
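// Illustrative sketch, not part of the vendored header above: parsing the
// "node_name:node_output:position" convention that FunctionDefTensorDesc
// describes near the top of this header. ParsedTensorDesc and
// ParseTensorString are invented names; the real struct also keeps the full
// string and is more defensive than std::stoi about malformed positions.
#include <string>

struct ParsedTensorDesc {
  std::string node_name;
  std::string node_output;
  int position = -1;  // -1 when the string is a bare function argument
};

static ParsedTensorDesc ParseTensorString(const std::string& input) {
  ParsedTensorDesc desc;
  const size_t first = input.find(':');
  if (first == std::string::npos) {
    desc.node_name = input;  // plain argument reference, e.g. "x"
    return desc;
  }
  desc.node_name = input.substr(0, first);
  const size_t second = input.find(':', first + 1);
  if (second == std::string::npos) {
    desc.node_output = input.substr(first + 1);
  } else {
    desc.node_output = input.substr(first + 1, second - first - 1);
    desc.position = std::stoi(input.substr(second + 1));
  }
  return desc;
}
// Example: "map_fn:output:0" parses to node_name="map_fn",
// node_output="output", position=0.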
+int FindFunctionNodeWithName(absl::string_view name,
+                             const FunctionDef& function);
+
+// Returns the index of the function node with the given op, or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(absl::string_view op, const FunctionDef& function);
+
+// Sets the function node name using `prefix` as a prefix while guaranteeing
+// the name is unique across the function's nodes.
+void SetUniqueFunctionNodeName(absl::string_view prefix, FunctionDef* function,
+                               NodeDef* node);
+
+// Checks if the function is stateful by checking the function graph for
+// stateful ops. Because the "If" and "While" ops are conservatively marked as
+// stateful, the check recurses into their graph to determine whether they are
+// actually stateful. The `skip_assert` argument determines whether the "Assert"
+// op should be treated as stateful or not.
+bool IsFunctionStateful(const FunctionLibraryDefinition& library,
+                        const FunctionDef& function_def,
+                        bool skip_assert = false);
+
+// Checks if the node is stateful. Because the "If" or "While" ops are
+// conservatively marked as stateful, the check recurses into their graph to
+// determine whether they are actually stateful. The `skip_assert` argument
+// determines whether the "Assert" op should be treated as stateful or not.
+bool IsNodeStateful(const FunctionLibraryDefinition& library,
+                    const NodeDef& node, bool skip_assert = false);
+
+}  // end namespace function_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/fusion_utils.h
new file mode 100644
index 00000000..d0b7ed7c
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/fusion_utils.h
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_
+
+#include <functional>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace fusion_utils {
+
+// These functions are invoked with the first and second function signatures
+// and should set the signature of the fused second_function.
+using SetFunctionSignatureFn = std::function<void(
+    const OpDef& first_signature, const OpDef& second_signature,
+    OpDef* fused_signature)>;
+
+using StringCollection = absl::InlinedVector<string, 2>;
+
+// These functions are invoked with nodes from the second function that were
+// previously taking arguments as input. The `arg_num` tells which
+// function argument the node was using as an input, e.g.:
+//   node(arg_1, other_node, arg_4)
+// would be called on the first and third input with arg_num equal to 1 and 4.
+// It should set up inputs based on the first function's inputs or outputs or
+// the second function's inputs.
+using SetInputFn =
+    std::function<string(const StringCollection& first_inputs,
+                         const StringCollection& second_inputs,
+                         const StringCollection& first_outputs, int arg_num)>;
+
+// This function is invoked with the first and second function rets. It is used
+// to set up the returns of the fused function.
+using SetOutputFn =
+    std::function<void(const protobuf::Map<string, string>& parent_ret,
+                       const protobuf::Map<string, string>& second_function_ret,
+                       protobuf::Map<string, string>* fused_ret)>;
+
+using SetNodesFn = std::function<void(
+    const FunctionDef& first_function, const FunctionDef& second_function,
+    FunctionDef* fused_function, FunctionDefLibrary* library)>;
+
+void MergeNodes(const FunctionDef& first_function,
+                const FunctionDef& second_function, FunctionDef* fused_function,
+                FunctionDefLibrary* library);
+
+// Returns true if functions can be composed.
+bool CanCompose(const OpDef& first_signature, const OpDef& second_signature);
+
+void ComposeSignature(const OpDef& first_signature,
+                      const OpDef& second_signature, OpDef* fused_signature);
+
+string ComposeInput(const StringCollection& first_inputs,
+                    const StringCollection& second_inputs,
+                    const StringCollection& first_outputs, int arg_num);
+
+// Sets output to the composition of the first and second function:
+// second_function(first_function(args...)).
+void ComposeOutput(const protobuf::Map<string, string>& first_ret,
+                   const protobuf::Map<string, string>& second_ret,
+                   protobuf::Map<string, string>* fused_ret);
+
+// Sets the input signature to `first_function_signature` and the output
+// signature to `first_function_signature` + `second_function_signature`.
+void CombineSignature(const OpDef& first_signature,
+                      const OpDef& second_signature, OpDef* fused_signature);
+
+// In addition to the first function's returns, returns values from the second
+// function as extra returns, like:
+//   return *first_function(...), *second_function(...)
+void CombineOutput(const protobuf::Map<string, string>& first_ret,
+                   const protobuf::Map<string, string>& second_ret,
+                   protobuf::Map<string, string>* fused_ret);
+
+// Returns true if both signatures have the same number of input and output
+// args.
+bool HasSameSignature(const OpDef& first_signature,
+                      const OpDef& second_signature);
+
+// Checks if both signatures are the same and copies it from `first_signature`.
+void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
+                   OpDef* fused_signature);
+
+// Takes the same input as the first function.
+string SameInput(const StringCollection& first_inputs,
+                 const StringCollection& second_inputs,
+                 const StringCollection& first_outputs, int arg_num);
+
+// Creates a fused function that computes the short-circuit logical AND of the
+// result of the first function and the result of the second function.
+void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
+                           const protobuf::Map<string, string>& second_ret,
+                           protobuf::Map<string, string>* fused_ret);
+
+void LazyConjunctionNodes(const FunctionDef& first_function,
+                          const FunctionDef& second_function,
+                          FunctionDef* fused_function,
+                          FunctionDefLibrary* library);
+
+// Fuses `first_function` with `second_function`, setting `fused_name_prefix`
+// as a name prefix. The nodes from `first_function` are copied unmodified. All
+// of the setup functions are called with a copy of the second function whose
+// names do not conflict with those of the first function. This means that
+// copied nodes from the second function can end up having different names. For
+// an explanation of the setup functions, see the documentation of the function
+// types above.
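+//
+// Illustrative usage sketch (names such as `first_fn`, `second_fn`, and
+// `graph_def` are assumed to come from the surrounding optimizer pass): fusing
+// two map functions into one that computes second_fn(first_fn(x)) can be
+// expressed with the composition helpers declared above, e.g.
+//   FunctionDef* fused = FuseFunctions(
+//       first_fn, second_fn, "fused_map", ComposeSignature, ComposeInput,
+//       ComposeOutput, MergeNodes, graph_def.mutable_library());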
+FunctionDef* FuseFunctions(const FunctionDef& first_function, + const FunctionDef& second_function, + absl::string_view fused_name_prefix, + const SetFunctionSignatureFn& set_signature, + const SetInputFn& set_input, + const SetOutputFn& set_output, + const SetNodesFn& set_nodes, + FunctionDefLibrary* library); + +} // namespace fusion_utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUSION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/graph_test_utils.h new file mode 100644 index 00000000..2b09eafc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -0,0 +1,141 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace grappler { +namespace graph_tests_utils { + +// Creates a test NodeDef for BatchDatasetV2. +NodeDef MakeBatchV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view drop_remainder_node_name, + bool parallel_copy); + +// Creates a test NodeDef for ParallelBatchDataset. +NodeDef MakeParallelBatchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view drop_remainder_node_name, + absl::string_view deterministic); + +// Creates a test NodeDef for ShuffleDatasetV2. +NodeDef MakeCacheV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view filename_node_name, + absl::string_view cache_node_name); + +// Creates a test NodeDef for FilterDataset. +NodeDef MakeFilterNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view function_name = "IsZero"); + +// Creates a test NodeDef for MapDataset. +NodeDef MakeMapNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view function_name = "XTimesTwo"); + +// Creates a test NodeDef for MapAndBatchDataset. +NodeDef MakeMapAndBatchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view batch_size_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view drop_remainder_node_name, + absl::string_view function_name = "XTimesTwo"); + +// Creates a test NodeDef for ParallelInterleaveDatasetV2. 
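+// For example (illustrative sketch; the input and Const node names are
+// hypothetical test fixtures), a test can build the node as:
+//   NodeDef interleave = graph_tests_utils::MakeParallelInterleaveV2Node(
+//       "interleave", /*input_node_name=*/"range",
+//       /*cycle_length_node_name=*/"cycle",
+//       /*block_length_node_name=*/"block",
+//       /*num_parallel_calls_node_name=*/"num_calls",
+//       /*function_name=*/"XTimesTwo", /*sloppy=*/false);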
+NodeDef MakeParallelInterleaveV2Node( + absl::string_view name, absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, bool sloppy); + +// Creates a test NodeDef for ParallelInterleaveDatasetV4. +NodeDef MakeParallelInterleaveV4Node( + absl::string_view name, absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, absl::string_view deterministic); + +// Creates a test NodeDef for InterleaveDataset. +NodeDef MakeInterleaveNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view cycle_length_node_name, + absl::string_view block_length_node_name, + absl::string_view function_name, + absl::string_view deterministic); + +// Creates a test NodeDef for ParallelMapDataset. +NodeDef MakeParallelMapNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, bool sloppy); + +// Creates a test NodeDef for ParallelMapDatasetV2. +NodeDef MakeParallelMapV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, + absl::string_view function_name, + absl::string_view deterministic, + bool use_unbounded_threadpool); + +// Creates a test NodeDef for ParseExampleDataset. +NodeDef MakeParseExampleNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view num_parallel_calls_node_name, + bool sloppy); + +// Creates a test NodeDef for ShuffleDatasetV2. +NodeDef MakeShuffleV2Node(absl::string_view name, + absl::string_view input_node_name, + absl::string_view buffer_size_node_name, + absl::string_view seed_generator_node_name); + +// Creates a test NodeDef for TakeDataset. +NodeDef MakeTakeNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view count_node_name); + +// Creates a test NodeDef for TensorSliceDataset. +NodeDef MakeTensorSliceNode(absl::string_view name, + absl::string_view tensor_node_name, + bool replicate_on_split); + +// Creates a test NodeDef for SkipDataset. +NodeDef MakeSkipNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view count_node_name); + +// Creates a test NodeDef for ShardDataset. +NodeDef MakeShardNode(absl::string_view name, absl::string_view input_node_name, + absl::string_view num_shards_node_name, + absl::string_view index_node_name); + +// Creates a test NodeDef for PrefetchDataset. +NodeDef MakePrefetchNode(absl::string_view name, + absl::string_view input_node_name, + absl::string_view buffer_size); + +} // namespace graph_tests_utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/graph_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/graph_utils.h new file mode 100644 index 00000000..70d0c480 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -0,0 +1,214 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_utils {
+
+// Returns the index of the first element in collection that fulfills predicate.
+// If no such element exists, returns -1.
+template <typename Predicate, typename Collection>
+int GetFirstElementIndexWithPredicate(const Predicate& predicate,
+                                      const Collection& collection) {
+  unsigned idx = 0;
+  for (auto&& element : collection) {
+    if (predicate(element)) {
+      return idx;
+    }
+    idx++;
+  }
+  return -1;
+}
+
+// Adds a node to the graph.
+NodeDef* AddNode(absl::string_view name, absl::string_view op,
+                 const std::vector<string>& inputs,
+                 const std::vector<std::pair<string, AttrValue>>& attributes,
+                 MutableGraphView* graph);
+
+// Adds a Placeholder node for the given type.
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
+
+// Adds a Const node with the given value to the graph.
+template <typename T>
+NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
+  // is_same is an idiomatic hack for making it compile if not instantiated.
+  // Replacing with false will result in a compile-time error.
+  static_assert(!std::is_same<T, T>::value,
+                "Invalid specialization of this method for type T.");
+  return {};
+}
+
+template <>
+NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph);
+template <>
+NodeDef* AddScalarConstNode(double v, MutableGraphView* graph);
+template <>
+NodeDef* AddScalarConstNode(float v, MutableGraphView* graph);
+template <>
+NodeDef* AddScalarConstNode(int v, MutableGraphView* graph);
+template <>
+NodeDef* AddScalarConstNode(int64_t v, MutableGraphView* graph);
+template <>
+NodeDef* AddScalarConstNode(absl::string_view v, MutableGraphView* graph);
+
+// Retrieves the value of a const node. Returns an error
+// if the node is not const, or its value is of a different type.
+template <typename T>
+absl::Status GetScalarConstNodeValue(const NodeDef& node, T* value) {
+  // is_same is an idiomatic hack for making it compile if not instantiated.
+  // Replacing with false will result in a compile-time error.
+  static_assert(!std::is_same<T, T>::value,
+                "Invalid specialization of this method for type T.");
+}
+
+template <>
+absl::Status GetScalarConstNodeValue(const NodeDef& node, int64_t* value);
+template <>
+absl::Status GetScalarConstNodeValue(const NodeDef& node, bool* value);
+
+// Checks whether the two graphs are the same.
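+// For example (illustrative sketch; `def`, `expected`, and `range_node` are
+// hypothetical values from a test), the helpers in this namespace compose as:
+//   MutableGraphView view(&def);
+//   NodeDef* count = AddScalarConstNode<int64_t>(10, &view);
+//   AddNode("take", "TakeDataset", {range_node->name(), count->name()},
+//           /*attributes=*/{}, &view);
+//   EXPECT_TRUE(Compare(def, expected));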
+bool Compare(const GraphDef& g1, const GraphDef& g2); + +// Checks whether the graph contains a node with the given name. +bool ContainsGraphNodeWithName(absl::string_view name, const GraphDef& graph); + +// Checks whether the library contains a function with the given name. +bool ContainsGraphFunctionWithName(absl::string_view name, + const FunctionDefLibrary& library); + +// Checks whether the graph contains a node with the given op. +bool ContainsNodeWithOp(absl::string_view op, const GraphDef& graph); + +// Returns the index of the node with the given name or -1 if the node does +// not exist. +int FindGraphNodeWithName(absl::string_view name, const GraphDef& graph); + +// Returns the index of the function with the given name or -1 if the function +// does not exist. +int FindGraphFunctionWithName(absl::string_view name, + const FunctionDefLibrary& library); + +// Returns the index of the first node with the given op or -1 if no such node +// exists. +int FindGraphNodeWithOp(absl::string_view op, const GraphDef& graph); + +// Gets the 0th input to a node in the graph. +NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph); + +// Gets the ith input to a node in the graph. +NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph, + int64_t i); + +// Gets the attr corresponding to a dataset node's output types, if it exists. +absl::Status GetDatasetOutputTypesAttr(const NodeDef& node, + DataTypeVector* output_types); + +// Returns the list of indices of all nodes with the given op or empty list if +// no such node exists. +std::vector FindAllGraphNodesWithOp(const string& op, + const GraphDef& graph); + +// Sets the node name using `prefix` as a prefix while guaranteeing the name +// is unique across the graph. +void SetUniqueGraphNodeName(absl::string_view prefix, GraphDef* graph, + NodeDef* node); + +// Sets the function name using the `prefix` name as a prefix while guaranteeing +// the name is unique across the function library. +void SetUniqueGraphFunctionName(absl::string_view prefix, + const FunctionDefLibrary* library, + FunctionDef* function); + +// Copies attribute having name `attribute_name` from node `from` to node +// `to_node`. +void CopyAttribute(const string& attribute_name, const NodeDef& from, + NodeDef* to_node); + +// Concatenates list attribute having name `attribute_name` from `first` and +// `second` node, setting it to `to_node`. +void ConcatAttributeList(const string& attribute_name, const NodeDef& first, + const NodeDef& second, NodeDef* to_node); + +// Checks that all nodes in the graphs have unique names, and sets their names +// to be unique if they are not already. This is necessary as Graph does not +// have the provisions to deduplicate names, and name deduplication elsewhere +// in tensorflow happens in other layers (for example, in the Scope class of the +// C++ API). Note that the nodes in the graph are identified by their id, +// and renaming nodes does not mutate any edges. +absl::Status EnsureNodeNamesUnique(Graph* g); + +// Returns the item's fetch node, if there is exactly one. Otherwise, returns an +// error. +absl::Status GetFetchNode(const MutableGraphView& graph, + const GrapplerItem& item, NodeDef** fetch_node); + +// Returns true if `item` is derived from a `FunctionDef`, false otherwise. +// Currently, we determine this heuristically: If we don't have any fetch nodes +// or all fetch nodes are `Retval` ops, then we consider this item as derived +// from a `FunctionDef`. 
+bool IsItemDerivedFromFunctionDef(const GrapplerItem& item, + const MutableGraphView& graph_view); + +// If both input nodes have the "metadata" attribute set, it populates the +// "metadata" attribute for the fused node. +void MaybeSetFusedMetadata(const NodeDef& node1, const NodeDef& node2, + NodeDef* fused_node); + +// Copies the attributes `output_shapes`, `output_types` from node `from` to +// node `to_node` if they exist. The method will return `true` if attributes +// copied successfully, otherwise it will return `false`. +// +// Some tf.data transformations set `Toutput_types` instead of `output_types` +// when the attribute describes type of tensor inputs (e.g. TensorDataset, +// TensorSliceDataset, and PaddedBatchDataset). In this case the method copies +// the attribute `Toutput_types` of node `from` to the attribute `output_types` +// of node `to_node`. +bool CopyShapesAndTypesAttrs(const NodeDef& from, NodeDef* to_node); + +// Checks whether the op has a "sloppy" attribute. +bool HasSloppyAttr(const string& op); + +// Checks whether the op has a "replicate_on_split" attribute. +bool HasReplicateOnSplitAttr(const string& op); + +// Checks whether the op has a "deterministic" attribute. +bool HasDeterministicAttr(const string& op); + +// Sets the `name` as the metadata name of the `node`. It returns an error if +// the `node` already has a metadata name. +absl::Status SetMetadataName(const std::string& name, NodeDef* node); + +} // namespace graph_utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/inject_io_prefetch.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/inject_io_prefetch.h new file mode 100644 index 00000000..444d49e7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/inject_io_prefetch.h @@ -0,0 +1,64 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_IO_PREFETCH_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_IO_PREFETCH_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +class InjectIoPrefetch : public TFDataOptimizerBase { + public: + InjectIoPrefetch() = default; + ~InjectIoPrefetch() override = default; + + std::string name() const override { return "inject_io_prefetch"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override; + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + protected: + bool autotune_ = true; +}; + +class InjectIoPrefetchEligible : public InjectIoPrefetch { + public: + std::string name() const override { return "inject_io_prefetch_eligible"; }; + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_IO_PREFETCH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/inject_prefetch.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/inject_prefetch.h new file mode 100644 index 00000000..f2ffda83 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/inject_prefetch.h @@ -0,0 +1,66 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_PREFETCH_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_PREFETCH_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// If autotune is ON and the last transformation in the input pipeline is not +// `prefetch()`, this optimization adds `prefetch(AUTOTUNE)` after it. 
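+//
+// The optimizer is configured through the custom-optimizer parameter map; an
+// illustrative sketch of enabling it with autotuning (where and how the config
+// is constructed is assumed to be handled by the caller):
+//   tensorflow::RewriterConfig_CustomGraphOptimizer config;
+//   config.set_name("inject_prefetch");
+//   (*config.mutable_parameter_map())[kAutotune].set_s("true");
+//   InjectPrefetch optimizer;
+//   TF_CHECK_OK(optimizer.Init(&config));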
+class InjectPrefetch : public TFDataOptimizerBase { + public: + InjectPrefetch() = default; + ~InjectPrefetch() override = default; + + std::string name() const override { return "inject_prefetch"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const std::string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + protected: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_INJECT_PREFETCH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/make_deterministic.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/make_deterministic.h new file mode 100644 index 00000000..30659c43 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/make_deterministic.h @@ -0,0 +1,77 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_DETERMINISTIC_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_DETERMINISTIC_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// Removes sources on nondeterminism from dataset ops. Nondeterminism can occur +// in the follow ways, each which this pass addresses: +// +// 1. The datasets ParallelInterleave, ParallelMap, and MapAndBatch can +// introduce nondeterminism by running a function multiple times in parallel. +// Specifically, if the function can mutate state, it is potentially +// nondeterministic. In such cases, this pass converts such dataset ops to a +// non-parallel version. As a performance optimization, in certain cases this +// pass will instead move nondeterministic ops to a separate non-parallel Map +// op, so that most of the ops can still run in parallel. +// +// 2. Certain datasets, such as Prefetch, can introduce asynchrony by running a +// dataset iterator in a background thread while ops outside the dataset are +// also running. This can introduce nondeterminism if the input pipeline has +// certain stateful ops. Other than Prefetch, datasets with a +// `num_parallel_calls` argument also introduce asynchrony, which includes +// the parallel datasets mentioned in (1) above. 
+// +// This pass modifies nodes to remove asynchrony when there are any datasets +// in the graph with problematic stateful ops. This is done by converting +// parallel ops into non-parallel versions, as in (1), and by removing +// Prefetch nodes. Unlike (1), legacy random ops such as RandomUniform are +// not problematic despite being stateful, as if the op is within a dataset's +// function, ops outside the dataset cannot access the state. Also unlike +// (1), nondeterministic ops are never moved to a separate Map op, since +// doing so would not remove asynchrony. +// +// 3. Nondeterminism occurs if an op has a "deterministic" attribute that is +// false or a "sloppy" attribute that is true. This pass changes such +// attributes to be deterministic. +class MakeDeterministic : public TFDataOptimizerBase { + public: + MakeDeterministic() = default; + ~MakeDeterministic() override = default; + + string name() const override { return "make_deterministic"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_DETERMINISTIC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/make_sloppy.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/make_sloppy.h new file mode 100644 index 00000000..b1046809 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/make_sloppy.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_SLOPPY_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_SLOPPY_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +class MakeSloppy : public TFDataOptimizerBase { + public: + MakeSloppy() = default; + ~MakeSloppy() override = default; + + string name() const override { return "make_sloppy"; } + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAKE_SLOPPY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h new file mode 100644 index 00000000..7e7e002b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +class MapAndBatchFusion : public TFDataOptimizerBase { + public: + MapAndBatchFusion() = default; + ~MapAndBatchFusion() override = default; + + string name() const override { return "map_and_batch_fusion"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_FUSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h new file mode 100644 index 00000000..018a8751 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_FILTER_FUSION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_FILTER_FUSION_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This transformation fuses map and filter operations by moving computation of +// filter predicate to MapDataset, which as a result produces an extra boolean +// component. We filter by the boolean component, then project it away. +// +// In symbols, we transform map(x -> f(x)).filter(f(x) -> p(f(x))) into +// map(x -> f(x), p(f(x))).filter(f(x), p(f(x)) -> p(f(x))).map(f(x), p(f(x)) +// -> f(x)). This is more efficient because the latter filter and map operations +// can be performed short-circuit, so only the first map requires an executor +// invocation. +class MapAndFilterFusion : public TFDataOptimizerBase { + public: + MapAndFilterFusion() = default; + ~MapAndFilterFusion() override = default; + + string name() const override { return "map_and_filter_fusion"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_FILTER_FUSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_fusion.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_fusion.h new file mode 100644 index 00000000..2512fc88 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_fusion.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_FUSION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_FUSION_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization fuses map transformations by merging their map functions. 
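+// For example (conceptual sketch), a pipeline fragment MapDataset(f) ->
+// MapDataset(g) is rewritten into a single MapDataset whose function computes
+// g(f(x)); the merged function is built with the fusion_utils helpers such as
+// fusion_utils::FuseFunctions.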
+class MapFusion : public TFDataOptimizerBase { + public: + MapFusion() = default; + ~MapFusion() override = default; + + string name() const override { return "map_fusion"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_FUSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_parallelization.h new file mode 100644 index 00000000..6ed70034 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/map_parallelization.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAutotune[] = "autotune"; + +// This optimization parallelizes MapDataset when function is stateless. 
+class MapParallelization : public TFDataOptimizerBase { + public: + MapParallelization() = default; + ~MapParallelization() override = default; + + string name() const override { return "map_parallelization"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return absl::OkStatus(); + + const string& autotune = config->parameter_map().at(kAutotune).s(); + if (autotune == "true") { + autotune_ = true; + } else if (autotune == "false") { + autotune_ = false; + } else { + return errors::InvalidArgument("Received an invalid value for parameter ", + kAutotune, ": ", autotune); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/meta_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/meta_optimizer.h new file mode 100644 index 00000000..e839389d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/meta_optimizer.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimizer performs tf.data-specific optimizations by invoking +// other optimizers. 
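+//
+// Illustrative sketch of invoking the meta-optimizer directly (the
+// GrapplerItem `item` and the custom-optimizer `config` are assumed to be
+// prepared by the caller):
+//   TFDataMetaOptimizer meta;
+//   TF_CHECK_OK(meta.Init(&config));
+//   GraphDef optimized_graph;
+//   TF_CHECK_OK(meta.Optimize(/*cluster=*/nullptr, item, &optimized_graph));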
+class TFDataMetaOptimizer : public CustomGraphOptimizer { + public: + TFDataMetaOptimizer() = default; + ~TFDataMetaOptimizer() override = default; + + string name() const override { return "tf_data_meta_optimizer"; }; + + bool UsesFunctionLibrary() const override { return true; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override; + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + private: + absl::flat_hash_map> + enabled_optimizers_; + + // Applies an optimization with the specified name on `item`, and stores + // the result in `item.graph` + absl::Status ApplyOptimization(const string& name, Cluster* cluster, + GrapplerItem* item) const; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_META_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/noop_elimination.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/noop_elimination.h new file mode 100644 index 00000000..389b112e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/noop_elimination.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_NOOP_ELIMINATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_NOOP_ELIMINATION_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This class eliminates tf.data transformations such as `take(n)` (for n < 0), +// `skip(0)`, `repeat(1)`, or `prefetch(0)`. +class NoOpElimination : public TFDataOptimizerBase { + public: + NoOpElimination() = default; + ~NoOpElimination() override = default; + + string name() const override { return "noop_elimination"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_NOOP_ELIMINATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/optimizer_base.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/optimizer_base.h new file mode 100644 index 00000000..7cd16fba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/optimizer_base.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_OPTIMIZER_BASE_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_OPTIMIZER_BASE_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// A base class for tf.data optimizers. +class TFDataOptimizerBase : public CustomGraphOptimizer { + public: + struct OptimizationStats { + // Identifies the number of independent graph changes for an optimization. + int64_t num_changes = 0; + }; + + TFDataOptimizerBase() = default; + ~TFDataOptimizerBase() override = default; + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) final; + + virtual absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) = 0; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_OPTIMIZER_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/parallel_batch.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/parallel_batch.h new file mode 100644 index 00000000..46b5ff9c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/parallel_batch.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_PARALLEL_BATCH_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_PARALLEL_BATCH_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +class ParallelBatch : public TFDataOptimizerBase { + public: + ParallelBatch() = default; + ~ParallelBatch() override = default; + + string name() const override { return "parallel_batch"; } + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_PARALLEL_BATCH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/remove_compression_map.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/remove_compression_map.h new file mode 100644 index 00000000..550436f4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/remove_compression_map.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REMOVE_COMPRESSION_MAP_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REMOVE_COMPRESSION_MAP_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +class RemoveCompressionMap : public TFDataOptimizerBase { + public: + RemoveCompressionMap() = default; + ~RemoveCompressionMap() override = default; + + string name() const override { return "remove_compression_map"; } + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REMOVE_COMPRESSION_MAP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/replicate_on_split.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/replicate_on_split.h new file mode 100644 index 00000000..cffcbd18 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/replicate_on_split.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REPLICATE_ON_SPLIT_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REPLICATE_ON_SPLIT_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +class ReplicateOnSplit : public TFDataOptimizerBase { + public: + ReplicateOnSplit() = default; + ~ReplicateOnSplit() override = default; + + string name() const override { return "replicate_on_split"; } + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REPLICATE_ON_SPLIT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h new file mode 100644 index 00000000..c881d9aa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SEQ_INTERLEAVE_PREFETCH_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SEQ_INTERLEAVE_PREFETCH_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This optimization replaces parallel interleave with sequential interleave and +// adds `prefetch(AUTOTUNE)` after the user defined map function in interleave. +class SeqInterleavePrefetch : public TFDataOptimizerBase { + public: + SeqInterleavePrefetch() = default; + ~SeqInterleavePrefetch() override = default; + + std::string name() const override { return "seq_interleave_prefetch"; }; + + // The SeqInterleavePrefetch optimizer requires access to the function + // library. 
+ bool UsesFunctionLibrary() const override { return true; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + protected: + bool autotune_ = true; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SEQ_INTERLEAVE_PREFETCH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h new file mode 100644 index 00000000..ba30ca63 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +class ShuffleAndRepeatFusion : public TFDataOptimizerBase { + public: + ShuffleAndRepeatFusion() = default; + ~ShuffleAndRepeatFusion() override = default; + + string name() const override { return "shuffle_and_repeat_fusion"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/slack.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/slack.h new file mode 100644 index 00000000..af70d314 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/slack.h @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SLACK_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SLACK_H_ + +#include "absl/strings/numbers.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This optimization sets the slack attr of the terminal PrefetchDataset node in +// an input pipeline. +class Slack : public TFDataOptimizerBase { + public: + Slack() = default; + ~Slack() override = default; + + string name() const override { return "slack"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + if (!config) return errors::InvalidArgument("Config parameter required."); + + const string& slack_period_param = + config->parameter_map().at("slack_period").s(); + if (!absl::SimpleAtoi(slack_period_param, &slack_period_)) { + return errors::InvalidArgument("Invalid `slack_period` parameter: ", + slack_period_param); + } + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; + + private: + int64_t slack_period_ = -1; + + absl::Status RecursivelyHandleOp(const MutableGraphView& graph, + NodeDef* dataset_node); +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SLACK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/split_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/split_utils.h new file mode 100644 index 00000000..df4c52b2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/split_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SPLIT_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SPLIT_UTILS_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { +namespace grappler { +namespace split_utils { + +// Return value of `SplitFunction`, which is described below. +struct SplitResults { + FunctionDef first_function; + FunctionDef second_function; + std::vector first_function_output_types; +}; + +// Splits a FunctionDef into two FunctionDefs, called `first` and `second`, such +// that calling `function(*args)` is equivalent to calling +// `second(first(*args))`. The set `nodes_in_first_function` specifies nodes +// that are copied to `first`, and the other nodes are copied to `second`. 
Any +// edges from `first` to `second` will be represented by an output of `first` +// and a corresponding input of `second`. The caller must pass +// `nodes_in_first_function` such that there will not be any edges from `second` +// to `first`. +// +// For example, if you have the following function (using Python syntax): +// +// def f(x): +// y = tf.math.add(x, 1., name='add') +// return tf.multiply(y, 2, name='mul') +// +// Calling SplitFunction(f, {'add'}) results in: +// +// def first_function(x): +// return tf.math.add(x, 1., name='add') +// def second_function(y): +// return tf.multiply(y, 2, name='mul') +// +// The `num_captured_inputs` argument controls which arguments of `function` +// will be arguments of `second`. If it is zero, the only arguments of `second` +// are the outputs of `first`. If it is above zero, the last +// `num_caputured_inputs` arguments of `function` will also be arguments of +// `second`. +// +// Splitting functions in certain cases is unimplemented, in which case an +// Unimplemented status will be returned. Grappler passes must gracefully handle +// Unimplemented statuses without returning the error to its caller. +absl::StatusOr SplitFunction( + const FunctionDef& function, + const absl::flat_hash_set& nodes_in_first_function, + int64_t num_captured_inputs, const FunctionLibraryDefinition& library); + +} // namespace split_utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SPLIT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h new file mode 100644 index 00000000..b886d36a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_USE_PRIVATE_THREAD_POOL_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_USE_PRIVATE_THREAD_POOL_H_ + +#include "tensorflow/core/grappler/optimizers/data/optimizer_base.h" + +namespace tensorflow { +namespace grappler { + +// This optimization creates private thread pool for the input pipeline. 
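// Editorial sketch (not part of the vendored headers): the UsePrivateThreadPool
// class declared just below, like every tf.data optimizer in this directory,
// follows the same TFDataOptimizerBase shape, so a hypothetical pass would look
// roughly like the class here. The class name and the match condition are
// illustrative only, and `stats->num_changes` reflects my recollection of the
// OptimizationStats field used to report rewrites.
class HypotheticalTfDataPass : public TFDataOptimizerBase {
 public:
  string name() const override { return "hypothetical_tf_data_pass"; }

  bool UsesFunctionLibrary() const override { return false; }

  absl::Status Init(
      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
    return absl::OkStatus();  // no configuration needed for this sketch
  }

  absl::Status OptimizeAndCollectStats(Cluster* cluster,
                                       const GrapplerItem& item,
                                       GraphDef* output,
                                       OptimizationStats* stats) override {
    *output = item.graph;  // start from a copy of the original graph
    for (NodeDef& node : *output->mutable_node()) {
      if (node.op() != "HypotheticalOp") continue;  // illustrative match condition
      // ... rewrite `node` in place; each successful rewrite would then do:
      //   ++stats->num_changes;
    }
    return absl::OkStatus();
  }
};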
+class UsePrivateThreadPool : public TFDataOptimizerBase { + public: + UsePrivateThreadPool() = default; + ~UsePrivateThreadPool() override = default; + + string name() const override { return "use_private_thread_pool"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_USE_PRIVATE_THREAD_POOL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/debug_stripper.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/debug_stripper.h new file mode 100644 index 00000000..c94257f5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/debug_stripper.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// DebugStripper strips off debug-related nodes (e.g. +// Assert, CheckNumerics, Print) from the graph. +class DebugStripper : public GraphOptimizer { + public: + DebugStripper() {} + ~DebugStripper() override {} + + string name() const override { return "debug_stripper"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEBUG_STRIPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/dependency_optimizer.h new file mode 100644 index 00000000..cc8d7043 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -0,0 +1,85 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize TF computations by removing control dependencies or re-arranging +// them to shorten the critical path for a model step or enable other +// optimizations, such as removing nodes that are effectively noops. +class DependencyOptimizer : public GraphOptimizer { + public: + DependencyOptimizer() {} + explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) {} + ~DependencyOptimizer() override {} + + string name() const override { return "dependency_optimizer"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + // Returns true if bypassing node does not increase the number of edges or + // number of edges crossing a device boundary. + bool BypassingNodeIsBeneficial( + const NodeDef& node, const std::vector& input_nodes, + const std::vector& output_nodes) const; + int NumEdgesIfBypassed(const NodeDef& node, + const std::vector& output_nodes) const; + // Returns true if node is not an Identity node or if it is an Identity + // that is safe to remove. + bool SafeToRemoveIdentity(const NodeDef& node) const; + // Returns true if it is safe to convert node to NoOp. + bool SafeToConvertToNoOp(const NodeDef& node) const; + // Removes all duplicate control dependencies. + void CleanControlInputs(); + // Builds a map from the &optimized_graph_->node(i) to i. + void BuildNodeToIdx(); + // Tries to optimize the node with the given index, possibly additional + // optimizations by inserting nodes in nodes_to_simplify, and pruning nodes by + // inserting them in nodes_to_delete. + void OptimizeNode(int node_idx, SetVector* nodes_to_simplify, + std::set* nodes_to_delete); + // Eliminates redundant control dependencies by computing the transitive + // reduction of the graph. + absl::Status TransitiveReduction(); + // Main driver of dependency optimizations. + absl::Status OptimizeDependencies(); + // Replaces multiple cross-device control edges from the same device with a + // single control edge. If `host_granularity` is true then group control + // edges from all devices on the same host. + void GroupCrossDeviceControlEdges(bool host_granularity); + + bool fetch_nodes_known_; + std::unordered_set nodes_to_preserve_; + std::unique_ptr node_map_; + std::unordered_map node_to_idx_; + GraphDef* optimized_graph_; // Not owned. +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DEPENDENCY_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/evaluation_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/evaluation_utils.h new file mode 100644 index 00000000..9ae5cb22 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/evaluation_utils.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EVALUATION_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EVALUATION_UTILS_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace Eigen { +class ThreadPoolInterface; +class ThreadPoolWrapper; +} // namespace Eigen + +namespace tensorflow { +namespace grappler { + +class DeviceSimple : public DeviceBase { + public: + DeviceSimple(); + ~DeviceSimple(); + + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + + Allocator* GetAllocator(AllocatorAttributes attr) override { + return cpu_allocator(); + } + + const std::string& device_type() const override { return device_type_; } + + private: + DeviceBase::CpuWorkerThreads eigen_worker_threads_; + std::unique_ptr eigen_device_; + const std::string device_type_ = DEVICE_CPU; +}; + +absl::Status EvaluateNode(const NodeDef& node, + const absl::InlinedVector& inputs, + DeviceBase* cpu_device, ResourceMgr* resource_mgr, + absl::InlinedVector* output); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EVALUATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/function_api_info.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/function_api_info.h new file mode 100644 index 00000000..e2ae234f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/function_api_info.h @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +class FunctionApiInfo { + public: + FunctionApiInfo(); + virtual ~FunctionApiInfo(); + + enum FunctionType { + INFERENCE, // Default type. + FORWARD, + BACKWARD, + }; + + absl::Status Init(const FunctionDef& function_def); + + const string& interface_name() const; + const string& preferred_device() const; + const FunctionType function_type() const; + const string& pairing_function_name() const; + const DataTypeVector& input_arg_dtypes() const; + const DataTypeVector& output_arg_dtypes() const; + + private: + string interface_name_; + string preferred_device_; + FunctionType function_type_; + // The pairing function is used to pair between forward and backward function, + // which will be useful during function swapping. Inference function won't + // have pairing function. + string pairing_function_name_; + // The following two attributes are useful for forward and backward functions. + DataTypeVector input_arg_dtypes_; + DataTypeVector output_arg_dtypes_; + + FunctionApiInfo(const FunctionApiInfo&) = delete; + void operator=(const FunctionApiInfo&) = delete; +}; + +// A collection of information for function and the interface it implements. +// A interface is a well defined math operation, eg I1 = 2 * x + y. Multiple +// functions could implement the same interface with different behavior based on +// different hardware condition and limits, +// eg F1 = math_ops.add(math_ops.add(x, x), y), or +// F2 = math_ops.add(math_ops.matmul(x, 2), y). +class FunctionLibraryApiInfo { + public: + FunctionLibraryApiInfo(); + virtual ~FunctionLibraryApiInfo(); + // Populate the internal field for the functions within the function_library. + absl::Status Init(const FunctionDefLibrary& function_library); + + absl::Status GetEquivalentImplementations( + const string& function_name, std::vector* other_functions) const; + + const FunctionApiInfo* GetApiInfo(const string& function_name) const; + bool empty() const { return func_info_.empty(); } + std::size_t size() const { return func_info_.size(); } + + private: + // Map between function name to function details. + std::unordered_map> func_info_; + + // Map between interface name to function names. + // Forward/backward function pair usually have different signatures between + // each other since forward function could produce extra internal state as + // output, and backward will take those extra state as inputs. 
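  // Editorial usage sketch (not part of the vendored header), using only the
  // public API declared above; `library` is assumed to be a FunctionDefLibrary
  // whose functions carry the interface annotations this class parses, and
  // "MatMulFn" is a hypothetical function name.
  //
  //   FunctionLibraryApiInfo api_info;
  //   TF_RETURN_IF_ERROR(api_info.Init(library));
  //   std::vector<string> alternatives;
  //   TF_RETURN_IF_ERROR(
  //       api_info.GetEquivalentImplementations("MatMulFn", &alternatives));
  //   for (const string& name : alternatives) {
  //     const FunctionApiInfo* info = api_info.GetApiInfo(name);
  //     // e.g. prefer an implementation whose preferred_device() matches the
  //     // node's assigned device.
  //   }
  //
  // The interface-to-implementation maps declared next are what back these
  // lookups.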
+ absl::flat_hash_map> intf_to_inference_funcs_; + absl::flat_hash_map> intf_to_forward_funcs_; + absl::flat_hash_map> intf_to_backward_funcs_; + + FunctionLibraryApiInfo(const FunctionLibraryApiInfo&) = delete; + void operator=(const FunctionLibraryApiInfo&) = delete; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/function_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/function_optimizer.h new file mode 100644 index 00000000..8f8eb732 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/function_optimizer.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Remap TensorFlow subgraphs onto alternative operations or collection of +// operations to make the overall graph more efficient. +class FunctionOptimizer : public GraphOptimizer { + public: + explicit FunctionOptimizer(RewriterConfig::Toggle opt_level, + bool lower_control_flow) + : opt_level_(opt_level), lower_control_flow_(lower_control_flow) {} + ~FunctionOptimizer() override = default; + + string name() const override { return "function_optimizer"; }; + + bool UsesFunctionLibrary() const override { return true; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + friend class FunctionOptimizerTest; + + // Runs a single function optimizer pass over the `graph`. All nodes that are + // not function calls will be copied from the `graph` to the + // `optimized_graph`. Function call nodes inlined or specialized, and + // instantiated function body or specialized function call nodes will be added + // to the `optimized_graph`. + absl::Status RunFunctionOptimizerPass(const GrapplerItem& item, + GraphDef* optimized_graph) const; + + RewriterConfig::Toggle opt_level_; + bool lower_control_flow_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h new file mode 100644 index 00000000..61a578fa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_H_ + +#include + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize the data layout for convolutional models. +class GenericLayoutOptimizer : public GraphOptimizer { + public: + explicit GenericLayoutOptimizer(string enforced_layout = "") + : GenericLayoutOptimizer(RewriterConfig::DEFAULT, + RewriterConfig::NO_CONVERSION_ON_CPU, + enforced_layout) {} + explicit GenericLayoutOptimizer(RewriterConfig::Toggle opt_level, + string enforced_layout = "") + : GenericLayoutOptimizer(opt_level, RewriterConfig::NO_CONVERSION_ON_CPU, + enforced_layout) {} + explicit GenericLayoutOptimizer(RewriterConfig::Toggle opt_level, + RewriterConfig::CpuLayout layout_conversion, + string enforced_layout = "") + : opt_level_(opt_level), + cpu_layout_conversion_(layout_conversion), + enforced_layout_(enforced_layout) {} + ~GenericLayoutOptimizer() override = default; + + string name() const override { return "layout"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + private: + RewriterConfig::Toggle opt_level_; + RewriterConfig::CpuLayout cpu_layout_conversion_; + const string enforced_layout_; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h new file mode 100644 index 00000000..1c0c0134 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -0,0 +1,676 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/grappler/utils/graph_view.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kAttrSrcFormat[] = "src_format"; +constexpr char kAttrDstFormat[] = "dst_format"; +constexpr char kAttrOutputShape[] = "_output_shapes"; +constexpr char kGPU[] = "GPU"; +constexpr char kCPU[] = "CPU"; + +// TransposeContext owns all data members. Must initialize GraphProperties, +// FrameView, GraphDef and MutableGraphView with the same graph. NodeDef +// pointers in FrameView, GraphDef and MutableGraphView must point to nodes in +// the same GraphDef instance. +struct TransposeContext { + // Initializes TransposeContext with given GrapplerItem. Because initializing + // FrameMap and GraphProperties may return error, we initialize + // TransposeContext outside constructor. + static absl::Status InitializeTransposeContext(bool assume_valid_feeds, + const GrapplerItem& item, + const Cluster* cluster, + TransposeContext* context); + + static absl::Status InitializeTransposeContext(const GrapplerItem& item, + const Cluster* cluster, + TransposeContext* context) { + return InitializeTransposeContext(false, item, cluster, context); + } + + // Sets data formats to convert from and to for specified device type. + void AssignDeviceAndDataFormats(absl::string_view target_device, + absl::string_view src_format, + absl::string_view dst_format); + + FrameView frames; + GraphDef graph; + // Number of nodes in the original graph. As new nodes are appended to the end + // of the graph, all new nodes should have a node index greater than or equal + // to this. + int num_nodes; + absl::flat_hash_set nodes_to_preserve; + std::unique_ptr graph_properties; + std::unique_ptr graph_view; + + string target_device; + string src_format; + string dst_format; + absl::flat_hash_map src_dim_indices; + absl::flat_hash_map dst_dim_indices; + std::vector src_to_dst; + std::vector dst_to_src; + + string enforced_layout; +}; + +class Transposer { + public: + explicit Transposer() {} + + Transposer(const Transposer&) = delete; + Transposer& operator=(const Transposer&) = delete; + + virtual ~Transposer() {} + + // Returns true iff the node should be processed by this transposer. + // NodeProcessors may perform additional oprand specific checks before + // processing if necessary. + // Following common conditions are checked: + // * node's device matches target device + // * node's source format matches config's source format + // * node has output + bool ShouldProcess(const TransposeContext& context, + const utils::MutableNodeView& node) const; + + // Transposes given node from src format to dst format. Also perform other + // necessary operations to guarantee the graph produce the same result. + // Eg. 
Add Transpose node sets before fanin ports and after fanout ports. + virtual absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) = 0; + + // Creates a Const node for permutation. If node with node_name already exits, + // return and reuse it. + absl::Status CreateConstPermNode(TransposeContext* context, + absl::string_view node_name, + absl::string_view device, + absl::Span permutation, + absl::string_view control_node_name, + utils::MutationNewNode* added_node); + + // Creates a TransposeNode with given properties. If node with node_name + // already exits, return and reuse it. + // A const perm node is also created and connected to the 2nd fanin. + // control_node_name is ignored if it is empty. + absl::Status CreateTransposeNode( + TransposeContext* context, absl::string_view name_format, + const DataType& data_type, absl::string_view device, + TensorShapeProto fanin_shape, absl::Span permutation, + absl::string_view control_node_name, utils::MutationNewNode* added_node, + string* transpose_node_name); + + // Update all edges between dst_node->fanin[dst_ports] and dst_node by + // inserting an op node. + absl::Status UpdateFaninEdgesWithOp(TransposeContext* context, + absl::Span dst_ports, + utils::MutableNodeView* dst_node, + absl::string_view op); + + // Update all edges between src_node:src_ports and nodes take + // src_node:src_ports as fanin. Also update attr _output_shape of src_node. + absl::Status UpdateFanoutEdgesWithOp(TransposeContext* context, + absl::Span src_ports, + utils::MutableNodeView* src_node, + absl::string_view op); + + // Creates a DataFromat node with given properties. + // DataFromat op is either DataFormatVecPermute or DataFormatDimMap. + absl::Status CreateDataFormatNode( + TransposeContext* context, absl::string_view node_name, + absl::string_view op, absl::string_view device, const DataType& data_type, + bool is_fanin_on_host, bool is_src_format_to_dst_format, + utils::MutationNewNode* added_node); + + protected: + int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const; + bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port, + int n) const; + bool IsFanoutPortsRankN(const utils::MutableNodeView& node, + absl::Span ports, int n) const; + int GetFaninPortRank(const utils::MutableNodeView& node, int port) const; + bool IsFaninPortRankN(const utils::MutableNodeView& node, int port, + int n) const; + + // Checks if fanin at specified port(s) has dimensions `dims` iff fanin is a + // Const. If fanin is not a Const, no dimensions will be checked and this will + // return true. + bool IsFaninPortDimsNIfConst(const utils::MutableNodeView& node, int port, + absl::Span dims) const; + bool IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node, + absl::Span ports, + absl::Span dims) const; + bool CanProcessNode(const TransposeContext& context, + const utils::MutableNodeView& node) const; + // Update all edges between dst_node->fanin[dst_ports] and dst_node. + // A node with op is created and inserted between all edges. + // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap. 
+ absl::Status UpdateEdge(TransposeContext* context, + absl::string_view name_format, absl::string_view op, + const AttrValue* input_shape, bool is_in_frame, + bool is_src_format_to_dst_format, const int src_port, + const int dst_port, utils::MutableNodeView* src_node, + utils::MutableNodeView* dst_node); + string GetFaninNameFormat(absl::string_view node_name, int port, + absl::string_view src_format, + absl::string_view dst_format); + string GetFanoutNameFormat(absl::string_view node_name, int port, int index, + absl::string_view src_format, + absl::string_view dst_format); + string LayoutOptimizerNode(absl::string_view node_name); + string GetReshapeNodeNameFormat(absl::string_view node_name, int index, + absl::string_view src_format, + absl::string_view dst_format); + string GetShapeConstNodeNameFormat(absl::string_view node_name, int index); +}; + +class LayoutSensitiveOpTransposer : public Transposer { + public: + explicit LayoutSensitiveOpTransposer() : Transposer() {} + + // Updates attrs data_format, ksize, strides of the given node to dst_format. + // _output_shape is updated during UpdateOutputEdges. + absl::Status UpdateNode(TransposeContext* context, + utils::MutableNodeView* node); +}; + +// Layout sensitive op transposers. + +class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer { + public: + explicit DefaultLayoutSensitiveOpTransposer() + : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class BiasAddTransposer : public LayoutSensitiveOpTransposer { + public: + explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer { + public: + explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class BiasAddGradTransposer : public LayoutSensitiveOpTransposer { + public: + explicit BiasAddGradTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class Conv2DBackpropFilterTransposer : public LayoutSensitiveOpTransposer { + public: + explicit Conv2DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class Conv2DBackpropInputTransposer : public LayoutSensitiveOpTransposer { + public: + explicit Conv2DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class Conv3DTransposer : public LayoutSensitiveOpTransposer { + public: + explicit Conv3DTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class Conv3DBackpropFilterTransposer : public LayoutSensitiveOpTransposer { + public: + explicit Conv3DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class Conv3DBackpropInputTransposer : public LayoutSensitiveOpTransposer { + public: + explicit Conv3DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status 
TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class FusedBatchNormExTransposer : public LayoutSensitiveOpTransposer { + public: + explicit FusedBatchNormExTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class FusedBatchNormGradTransposer : public LayoutSensitiveOpTransposer { + public: + explicit FusedBatchNormGradTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; + + private: + bool IsTraining(const utils::MutableNodeView& node) const; +}; + +class MaxPoolV2Transposer : public LayoutSensitiveOpTransposer { + public: + explicit MaxPoolV2Transposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class MaxPool3DTransposer : public LayoutSensitiveOpTransposer { + public: + explicit MaxPool3DTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class MaxPoolGradTransposer : public LayoutSensitiveOpTransposer { + public: + explicit MaxPoolGradTransposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class MaxPoolGradV2Transposer : public LayoutSensitiveOpTransposer { + public: + explicit MaxPoolGradV2Transposer() : LayoutSensitiveOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +// Layout agnostic op transposers. + +class LayoutAgnosticOpTransposer : public Transposer { + public: + explicit LayoutAgnosticOpTransposer() : Transposer() {} + + protected: + bool IsAfterDstToSrcTransform(const TransposeContext& context, + const utils::MutableNodeView& node) const; + + std::vector GetVariadicNDFaninPorts(const TransposeContext& context, + const utils::MutableNodeView& node, + int rank) const; +}; + +class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer { + public: + explicit DefaultLayoutAgnosticOpTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class AddNTransposer : public LayoutAgnosticOpTransposer { + public: + explicit AddNTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class BinaryOpTransposer : public LayoutAgnosticOpTransposer { + public: + explicit BinaryOpTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; + + private: + bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m); + bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank); + std::vector GetNDDataFaninPorts(const utils::MutableNodeView& node, + int rank); + absl::Status AddNodeShapeConst(utils::Mutation* mutation, + absl::string_view node_name, + absl::string_view node_device, + bool node_in_frame, int num_channels, + absl::string_view depended_node, int rank); + absl::Status AddNodeReshape(utils::Mutation* mutation, + absl::string_view node_name, + absl::string_view node_device, + absl::string_view input_name, + absl::string_view shape_const_node_name, + const 
DataType& data_type); + absl::Status MaybeReshapeVectorFanin(TransposeContext* context, + utils::MutableNodeView* node, int rank); +}; + +class ConcatOpTransposer : public LayoutAgnosticOpTransposer { + public: + explicit ConcatOpTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class FillOpTransposer : public LayoutAgnosticOpTransposer { + public: + explicit FillOpTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class IdentityNTransposer : public LayoutAgnosticOpTransposer { + public: + explicit IdentityNTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class MergeTransposer : public LayoutAgnosticOpTransposer { + public: + explicit MergeTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; + + private: + bool IsEveryFaninAfterDstToSrcTransform( + const TransposeContext& context, + const utils::MutableNodeView& node) const; +}; + +class PadTransposer : public LayoutAgnosticOpTransposer { + public: + explicit PadTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class ReduceTransposer : public LayoutAgnosticOpTransposer { + public: + explicit ReduceTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; + + private: + bool KeepDims(const utils::MutableNodeView& node); + bool IsAlongAxis(const Tensor& tensor, absl::Span axis, int rank); + bool IsReduceAxisSupported(const TransposeContext& context, + const utils::MutableNodeView& node, int rank); +}; + +class ReverseV2Transposer : public LayoutAgnosticOpTransposer { + public: + explicit ReverseV2Transposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class SelectTransposer : public LayoutAgnosticOpTransposer { + public: + explicit SelectTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; + + protected: + bool IsFaninScalarVector4D(const utils::MutableNodeView& fanin, int port); + std::vector GetFaninPorts(const utils::MutableNodeView& fanin, int port); +}; + +class ShapeTransposer : public LayoutAgnosticOpTransposer { + public: + explicit ShapeTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class ShapeNTransposer : public LayoutAgnosticOpTransposer { + public: + explicit ShapeNTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class SliceTransposer : public LayoutAgnosticOpTransposer { + public: + explicit SliceTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class SplitTransposer : public LayoutAgnosticOpTransposer { + public: + explicit SplitTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + 
utils::MutableNodeView* node) override; +}; + +class SplitVTransposer : public LayoutAgnosticOpTransposer { + public: + explicit SplitVTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class SqueezeTransposer : public LayoutAgnosticOpTransposer { + public: + explicit SqueezeTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; + + private: + bool IsInputConvertible(const TransposeContext& context, + const utils::MutableNodeView& node) const; + bool IsAlongAxis(const AttrValue& attr, absl::Span axis, + int rank) const; + bool IsDimsSupported(const TransposeContext& context, + const utils::MutableNodeView& node) const; + absl::Status UpdateSqueezeDims(TransposeContext* context, + utils::MutableNodeView* node); +}; + +class StridedSliceTransposer : public LayoutAgnosticOpTransposer { + public: + explicit StridedSliceTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; + + private: + bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask); + bool HasOnlyBeginEndMask(const utils::MutableNodeView& node); + absl::Status PermuteMask(TransposeContext* context, + utils::MutableNodeView* node, + absl::string_view mask); +}; + +class SwitchTransposer : public LayoutAgnosticOpTransposer { + public: + explicit SwitchTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class TernaryOpTransposer : public LayoutAgnosticOpTransposer { + public: + explicit TernaryOpTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class TileTransposer : public LayoutAgnosticOpTransposer { + public: + explicit TileTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +class UnaryGradTransposer : public LayoutAgnosticOpTransposer { + public: + explicit UnaryGradTransposer() : LayoutAgnosticOpTransposer() {} + + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; +}; + +// Utils. + +// Permutes elements according to permutation and replaces the original values. +// Permutation and values must have same size. +template +absl::Status PermuteSingle(absl::string_view location, + absl::Span permutation, T* values) { + DCHECK(values != nullptr); + int permutation_size = permutation.size(); + if (values->size() != permutation_size) { + return absl::Status(absl::StatusCode::kInvalidArgument, + absl::StrCat("Size of values ", values->size(), + " does not match size of permutation ", + permutation_size, " @ ", location)); + } + typedef typename T::value_type V; + std::vector elements(values->begin(), values->end()); + int index = 0; + for (V& element : *values) { + element = elements[permutation[index++]]; + } + return absl::OkStatus(); +} + +// Permutes two elements at a time according to permutation and replaces the +// original values. Values must be twice the size of permutation. 
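// Editorial example (not part of the vendored header): a minimal usage sketch
// of PermuteSingle above. The strides values and the NHWC -> NCHW permutation
// {0, 3, 1, 2} are illustrative; PermuteDouble, defined next, applies the same
// idea to values stored as consecutive pairs (e.g. paddings).
inline absl::Status PermuteStridesExample() {
  std::vector<int> strides = {1, 2, 2, 1};     // in N, H, W, C order
  const std::vector<int> perm = {0, 3, 1, 2};  // destination i reads source perm[i]
  absl::Status status = PermuteSingle("strides example", perm, &strides);
  // On success, strides == {1, 1, 2, 2}, i.e. the same values in N, C, H, W order.
  return status;
}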
+template +absl::Status PermuteDouble(absl::string_view location, + absl::Span permutation, T* values) { + DCHECK(values != nullptr); + int permutation_size = permutation.size(); + if (values->size() != permutation_size * 2) { + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("Size of values ", values->size(), + " does not match twice the size of permutation ", + permutation_size, " @ ", location)); + } + typedef typename T::value_type V; + std::vector elements(values->begin(), values->end()); + for (int i = 0; i < values->size(); i = i + 2) { + const int permutation_index = permutation[i / 2]; + (*values)[i] = elements[permutation_index * 2]; + (*values)[i + 1] = elements[permutation_index * 2 + 1]; + } + return absl::OkStatus(); +} + +string GetDeviceName(const NodeDef& node); + +bool IsDefaultLayoutSensitiveOp(const NodeDef& node); + +bool IsLayoutSensitiveOp(const NodeDef& node); + +bool IsDefaultLayoutAgnosticOp(const NodeDef& node); + +bool IsLayoutAgnosticOp(const NodeDef& node); + +bool IsTernaryOp(const NodeDef& node); + +bool IsUnaryGrad(const NodeDef& node); + +bool IsMaxPoolV2(const NodeDef& node); + +bool IsMaxPool3D(const NodeDef& node); + +bool IsMaxPoolGradV2(const NodeDef& node); + +bool IsMaxPoolGradGradV1(const NodeDef& node); + +bool IsMaxPoolGradGradV2(const NodeDef& node); + +bool IsBinaryOp(const NodeDef& node); + +bool IsReduceOp(const NodeDef& node); + +std::vector GetDataFaninPorts(const utils::MutableNodeView& node); + +std::vector GetDataFanoutPorts(const utils::MutableNodeView& node); + +// Returns a value of constant input to the `node` at `index`, iff `predicate` +// evaluated to true. Returns true if `tensor` was populated with data. +bool GetValueAttrFromConstInputNode( + const utils::MutableNodeView& node, + const std::function& predicate, int index, + Tensor* tensor); + +bool IsDataFormatOp(const utils::MutableNodeView& node); + +absl::flat_hash_map GetDimensionIndices( + absl::string_view data_format); + +std::vector GetPermutation( + const absl::flat_hash_map& src_dim_indices, + absl::string_view dst_format); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.h new file mode 100644 index 00000000..a31b1ca6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_FACTORY_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_FACTORY_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h" + +namespace tensorflow { +namespace grappler { + +class TransposerFactory { + public: + explicit TransposerFactory() {} + + std::shared_ptr GetTransposer(const NodeDef& node); + + protected: + template + std::shared_ptr GetOrCreateIfNotFound(const string& key) { + auto& transposer = transposer_map_[key]; + if (transposer == nullptr) { + transposer = std::make_shared(); + } + return transposer; + } + + absl::flat_hash_map> transposer_map_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/graph_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/graph_optimizer.h new file mode 100644 index 00000000..6b7ba893 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/graph_optimizer.h @@ -0,0 +1,93 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +class Cluster; +struct GrapplerItem; + +// An abstract interface for an algorithm for generating a candidate +// optimization of a GrapplerItem for running on a cluster. +class GraphOptimizer { + public: + GraphOptimizer() : deadline_usec_(0) {} + virtual ~GraphOptimizer() {} + + virtual string name() const = 0; + + // Returns true if the optimizer requires a valid function library to perform + // graph optimization. If false, optimized GrapplerItem will have a stub + // instead of real function library (all function signatures and attributes + // will be valid, but function body will be empty). Most of the optimizers + // that do not instantiate functions should return true. + virtual bool UsesFunctionLibrary() const = 0; + + // Routine called to allow an algorithm to propose a rewritten graph + // for the graph, feeds and fetches in "item" to run more efficiently + // on "cluster". If the returned status is OkStatus() then + // *optimized_graph contains the rewritten graph. + // Returns an error status if it failed to generate a solution. 
+ // + // A return value of error::Aborted() can be used signal early termination of + // the optimizer, e.g. if the optimization turned out to be a no-op. In this + // case the content of *optimized_graph is undefined. + virtual absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) = 0; + + // Subclasses may define a version of Optimize that consumes item. + virtual absl::Status Optimize(Cluster* cluster, GrapplerItem&& item, + GraphDef* optimized_graph) { + return Optimize(cluster, item, optimized_graph); + } + + // Set deadline in microseconds since epoch. A value of zero means no + // deadline. + void set_deadline_usec(uint64 deadline_usec) { + deadline_usec_ = deadline_usec; + } + uint64 deadline_usec() const { return deadline_usec_; } + bool DeadlineExceeded() const { + return deadline_usec_ > 0 && Env::Default()->NowMicros() > deadline_usec_; + } + + private: + uint64 deadline_usec_; +}; + +#define GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED() \ + do { \ + if (this->DeadlineExceeded()) { \ + return absl::DeadlineExceededError( \ + absl::StrCat(this->name(), " exceeded deadline.")); \ + } \ + } while (0) + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h new file mode 100644 index 00000000..ed5549ab --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -0,0 +1,315 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_ + +#include +#include + +#include "absl/strings/str_cat.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +struct NodeScopeAndName { + string scope; + string name; +}; + +// Parse scope and name: "a/b/c/Add_1" -> {"a/b/c", "Add_1"} +const NodeScopeAndName ParseNodeScopeAndName(const string& node_name); + +// Context owned by GraphOptimizer, and passed to every stage at construction +// time. Each optimizer stage is responsible for updating it according to the +// changes it made to the graph. +// +// If an optimizer needs access to some helper class that is not present in this +// context, consider creating an extension context, specific to that +// optimizer (see example of ArithmeticOptimizerContext). GraphOptimizerContext +// should only have members that are useful to almost all optimizers. 
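// Illustrative sketch (not part of the upstream TensorFlow header): an
// optimizer-specific extension context in the spirit of the
// ArithmeticOptimizerContext mentioned above. The name and members are
// hypothetical and only show the pattern of keeping optimizer-private state
// out of GraphOptimizerContext.
struct ExampleOptimizerContext {
  explicit ExampleOptimizerContext(std::vector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
  // Worklist that only this particular optimizer cares about.
  std::vector<NodeDef*>* nodes_to_simplify;
};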
+struct GraphOptimizerContext { + GraphOptimizerContext(const std::unordered_set* nodes_to_preserve, + GraphDef* optimized_graph, + GraphProperties* graph_properties, NodeMap* node_map, + gtl::FlatSet* feed_nodes, + RewriterConfig::Toggle opt_level) + : nodes_to_preserve(nodes_to_preserve), + optimized_graph(optimized_graph), + graph_properties(graph_properties), + node_map(node_map), + feed_nodes(feed_nodes), + opt_level(opt_level) {} + + const std::unordered_set* nodes_to_preserve; + GraphDef* optimized_graph; + GraphProperties* graph_properties; + NodeMap* node_map; + gtl::FlatSet* feed_nodes; + RewriterConfig::Toggle opt_level; +}; + +absl::Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, + NodeDef** node); +absl::Status GetTensorProperties(const GraphOptimizerContext& ctx, + const string& tensor, + const OpInfo::TensorProperties** properties); + +NodeDef* AddCopyNode(const GraphOptimizerContext& ctx, const string& name, + const NodeDef* node_to_copy); +NodeDef* AddEmptyNode(const GraphOptimizerContext& ctx, const string& name); + +// WARNING: +// Optimizer stage must try to re-use original nodes of a graph and +// make all updates in place. This helps to make robust node placement +// decisions. Create new nodes only if there is a reason for that. + +// Make a name for a new node obtained by optimizing a single node of the +// original graph. The optimized node is placed under the original node scope. +// +// Node name uniqueness is guaranteed by unique name of an original node in +// a same scope. +// +// Empty sub_scope or prefix ignored. At least one of them must be non-empty. +// +// Example: a/b/c/Add -> a/b/c/${sub_scope}/${prefix}_Add. +const string MakeOptimizedNodeName(const NodeScopeAndName& node, + const string& sub_scope, + const string& prefix); +// Make a name for a new node obtained by optimizing multiple nodes of the +// original graph, starting from "root". The optimized node is placed under +// the original scope of a "root" node. +// +// Example: [a/b/c/Add, x/y/z/Mul] -> a/b/c/${sub_scope}/${prefix}_Add_Mul +const string MakeOptimizedNodeName(const NodeScopeAndName& root, + const std::vector node_names, + const string& sub_scope, + const string& prefix); + +// Base class for multi-stage GraphOptimizers (ArithmeticOptimizer, etc...). +// +// If a graph optimizer consists of large number of small independent +// rewrites, each of them should be implemented as a separate stage. +// +// * Result: +// Each graph optimizer choose what result is reported by each stage +// (e.g. each stage can fill in the name of optimized nodes, or have more +// complex result). +template +class GraphOptimizerStage { + public: + explicit GraphOptimizerStage(const string& optimizer_name, + const string& stage_name, + const GraphOptimizerContext& ctx) + : optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {} + virtual ~GraphOptimizerStage() = default; + + const string& stage_name() const { return stage_name_; } + const string& optimizer_name() const { return optimizer_name_; } + + // Check if we should try to simplify node. Returning true doesn't + // guarantee that node will be simplified. + // + // Should implement just a basic sanity check, without any expensive graph + // traversals. + virtual bool IsSupported(const NodeDef* node) const = 0; + + // Try to simplify the given node. + // + // Return error status only if some precondition is failed, or got an + // incorrect graph. 
In every other case return OkStatus(), even if it didn't + // simplify anything. + // + // Report result using output argument. Each GraphOptimizer can choose its + // own Result type. + // TODO(ezhulenev): if it will appear that Result output parameter is not + // sufficiently useful (used with a reason by most optimizers), get rid of it, + // and remove template parameter. + virtual absl::Status TrySimplify(NodeDef* node, Result* result) = 0; + + // Return InvalidArgumentError if node is not supported by the optimizer + // stage. + // TODO(ezhulenev): make this check part of non-virtual public API + // (TrySimplify), and make virtual implementation protected. + absl::Status EnsureNodeIsSupported(const NodeDef* node) const { + return IsSupported(node) + ? absl::OkStatus() + : errors::InvalidArgument( + "Node ", node->name(), " is not supported by optimizer ", + optimizer_name_, " and stage ", stage_name_); + } + + // Get a name for a new node, created by this stage, based on one or multiple + // nodes of an original graph. + const string OptimizedNodeName(const NodeScopeAndName& node) const { + return MakeOptimizedNodeName(node, optimizer_name_, stage_name_); + } + const string OptimizedNodeName(const NodeScopeAndName& root, + const std::vector& nodes) const { + return MakeOptimizedNodeName(root, nodes, optimizer_name_, stage_name_); + } + const string OptimizedNodeName(const NodeScopeAndName& node, + const string& rewrite_rule) const { + const string prefix = strings::StrCat(stage_name_, "_", rewrite_rule); + return MakeOptimizedNodeName(node, optimizer_name_, prefix); + } + + const string UniqueOptimizedNodeName(const NodeScopeAndName& node) { + const string node_name = OptimizedNodeName(node); + return UniqueNodeName(node_name); + } + const string UniqueOptimizedNodeName(const NodeScopeAndName& node, + const string& rewrite_rule) { + const string node_name = OptimizedNodeName(node, rewrite_rule); + return UniqueNodeName(node_name); + } + + // Get a node by input name from a node map. Return an error if node was not + // found. + absl::Status GetInputNode(const string& input, NodeDef** node) const { + return ::tensorflow::grappler::GetInputNode(ctx_, input, node); + } + // Lookup tensor properties by name. Tensor name might have non-zero port + // number. Return an error if the tensor node doesn't exist in the graph, or it + // doesn't have properties defined for the requested port.
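// Illustrative usage (not part of the upstream header): a concrete stage's
// TrySimplify typically validates the node and resolves its fanin through the
// helpers declared here, e.g.:
//   TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
//   NodeDef* fanin;
//   TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &fanin));
//   NodeDef* copy = AddCopyNode(
//       UniqueOptimizedNodeName(ParseNodeScopeAndName(node->name())), fanin);
//   (void)copy;  // rewire inputs/outputs of `copy` as needed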
+ absl::Status GetTensorProperties( + const string& tensor, const OpInfo::TensorProperties** properties) const { + return ::tensorflow::grappler::GetTensorProperties(ctx_, tensor, + properties); + } + + NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) { + return ::tensorflow::grappler::AddCopyNode(ctx_, name, node_to_copy); + } + NodeDef* AddEmptyNode(const string& name) { + return ::tensorflow::grappler::AddEmptyNode(ctx_, name); + } + + protected: + const GraphOptimizerContext& ctx() const { return ctx_; } + + private: + const string UniqueNodeName(absl::string_view name) { + string node_name = string(name); + while (ctx_.node_map->NodeExists(node_name)) { + node_name = absl::StrCat(name, "_unique", + optimized_node_name_counter_.fetch_add(1)); + } + + return node_name; + } + + const string optimizer_name_; + const string stage_name_; + const GraphOptimizerContext ctx_; + std::atomic optimized_node_name_counter_ = {0}; +}; + +template +class GraphOptimizerStagePipeline { + public: + // Break predicate specifies if a pipeline should stop early, and not pass + // a node to the next registered optimizer stage, typically that should be the + // case when a stage successfully optimized a node, and it wants to yield + // control to the optimizer. + explicit GraphOptimizerStagePipeline( + const std::function break_predicate) + : break_predicate_(break_predicate) {} + + // Add a stage to the pipeline. It should be called with the arguments for the + // stage constructor: + // + // pipeline.AddStage(constructor_arg1, constructor_arg2); + // + // Returns a reference to the added stage. + template + T& AddStage(Args&&... args) { + auto stage = new T(std::forward(args)...); + stages_.push_back(std::unique_ptr(stage)); + return *stage; + } + + // Pass a node through all registered optimizer stages, until break predicate + // is true. + // + // Return true, if pipeline exited after a break predicate was evaluated as + // 'true', which typically means that a node was optimized by one of the + // registered stages. + // + // Return false, if node was not optimized by any of registered stages. + bool PassThroughAllStages(NodeDef* node, Result* result) { + for (auto& stage : stages_) { + if (stage->IsSupported(node)) { + const absl::Status stage_status = stage->TrySimplify(node, result); + // Each stage must be "error safe" (just like exception safe). In + // case of any error it must leave optimized graph unmodified. + if (!stage_status.ok()) { + VLOG(2) << "Failed to run optimizer " << stage->optimizer_name() + << ", stage " << stage->stage_name() << " node " + << node->name() << ". Error: " << stage_status.message(); + } + if (break_predicate_(*result)) return true; + } + } + return false; + } + + // Pass a node through all registered optimizer stages, until break predicate + // is true or a stage fails. + // + // Returns any stage failure status, or else OkStatus(). 
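// Illustrative sketch (not part of the upstream header) of how a stage and a
// pipeline are typically wired together; `ExampleResult` and `NoOpStage` are
// hypothetical names:
//   struct ExampleResult { std::vector<string> simplified_nodes; };
//
//   class NoOpStage : public GraphOptimizerStage<ExampleResult> {
//    public:
//     explicit NoOpStage(const GraphOptimizerContext& ctx)
//         : GraphOptimizerStage("ExampleOptimizer", "NoOpStage", ctx) {}
//     bool IsSupported(const NodeDef* node) const override { return true; }
//     absl::Status TrySimplify(NodeDef* node, ExampleResult* result) override {
//       // Record the node but leave the graph untouched.
//       result->simplified_nodes.push_back(node->name());
//       return absl::OkStatus();
//     }
//   };
//
//   // Stop after the first stage that reports a simplification.
//   GraphOptimizerStagePipeline<ExampleResult> pipeline(
//       [](const ExampleResult& r) { return !r.simplified_nodes.empty(); });
//   pipeline.AddStage<NoOpStage>(ctx);  // `ctx` is a GraphOptimizerContext
//   ExampleResult result;
//   pipeline.PassThroughAllStages(node, &result);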
+ absl::Status PassThroughAllStagesWithStatus(NodeDef* node, Result* result) { + for (auto& stage : stages_) { + if (!stage->IsSupported(node)) { + continue; + } + const absl::Status stage_status = stage->TrySimplify(node, result); + if (!stage_status.ok()) { + return stage_status; + } else if (break_predicate_(*result)) { + break; + } + } + return absl::OkStatus(); + } + + std::size_t NumStages() { return stages_.size(); } + + std::vector StageNames() { + std::vector names; + names.reserve(stages_.size()); + for (const auto& stage : stages_) { + names.push_back(stage->stage_name()); + } + return names; + } + + private: + std::vector>> stages_; + std::function break_predicate_; + + GraphOptimizerStagePipeline(const GraphOptimizerStagePipeline&) = delete; + void operator=(const GraphOptimizerStagePipeline&) = delete; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/implementation_selector.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/implementation_selector.h new file mode 100644 index 00000000..dc804fdc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/implementation_selector.h @@ -0,0 +1,203 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_IMPLEMENTATION_SELECTOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_IMPLEMENTATION_SELECTOR_H_ + +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/grappler/utils/graph_view.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +static constexpr const char* const kNoImplSelectionAttr = "_noimpl_selection"; + +// Motivation: To achieve the same high level functionality, the underlying +// implementations sometimes are different for various devices where the +// function runs. In order to achieve the correct result and best performance, +// the proper implementation needs to be picked dynamically. +// +// Currently there are two approaches to do this. +// (1) Utilize case op and dynamacically change the branch index. +// (2) Swap function implementation, it will be deprecated. +// +// Idea for approach 1. 
+// This transformation rewrites the DeviceIndex op with a Const op with value +// of the index of the device the associcated Case op runs. +// Example: +// def plus_one_gpu(x): return x + 1.0 +// def plus_one_reference_implementation(x): return x + 1.0 +// input = tf.constant(2.0, dtype=tf.float32) +// cpu_fn = lambda:plus_one_reference_implementation(input) +// gpu_fn = lambda:plus_one_gpu(input) +// control_flow_switch_case.execute_fn_for_device( +// {"CPU": cpu_fn, "GPU":gpu_fn)}, default_fn=cpu_fn) +// +// Idea for approach 2. +// This transformation replaces function calls by the appropriate function +// definition based on properties of the runtime system. For instance, +// we may choose one implementation over another if we have a GPU with +// enough memory available. +// +// It is a way for the programmer to specify alternative implementations +// of the same functionality in the graph, and let TensorFlow pick the +// most appropriate one at runtime. +// +// For instance, the python code might specify: +// @Defun(tf.float32, +// api_implements='plus_one', +// api_preferred_device='GPU') +// def plus_one_gpu(x): return x + 1.0 +// +// @Defun(tf.float32, +// api_implements='plus_one') +// def plus_one_reference_implementation(x): return x + 1.0 +// input = tf.constant(2.0, dtype=tf.float32) +// +// z = plus_one_reference_implementation(input) +// z = plus_one_gpu(input) +// print(sess.run(z)) +// + +// At runtime, we will select either `plus_one_gpu` or +// `plus_one_reference_implementation` based on the availability of the GPU. +// +// Available annotations: +// - api_implements(string): all functions mapping to the same +// string can be interchanged. For now, all functions must have the same +// signature and overloads are not allowed. Defuns within defuns are +// allowed. +// - api_preferred_device(string): sets which device is preferred. +class ImplementationSelector : public CustomGraphOptimizer { + public: + ImplementationSelector() = default; + ~ImplementationSelector() override = default; + absl::Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return absl::OkStatus(); + } + string name() const override { + return "implementation_selector"; + } + + bool UsesFunctionLibrary() const override { return false; } + + // This call is not thread-safe. + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + absl::Status LoadFunctions(const GraphDef& graph); + absl::Status MaybeOptimizeFunctionCall( + const Cluster* cluster, utils::MutableNodeView* node_view) const; + + // Finds all call sites for functions, then replace with the appropriate + // implementation. + // There are two ways of calling functions: + // 1. By specifying an op name as a function name, and + // 2. Via the functional interface, where the function name appears as an + // Attr. + // + // There may be multiple call sites for a given function. The function body + // may call into another function, so a function might have to be duplicated. + // For simplicity, we do not change function bodies. Also, we do not change + // gradients. + absl::Status SelectImplementation(const Cluster* cluster, + GraphDef* graph) const; + + // Rewrites the DeviceIndex op with a Const op with value of the index of the + // device the associcated Case op runs. + + // This function first looks up all the DeviceIndex ops. 
+ // Then for each of these ops, it finds the device of the + // associated Case op that takes the DeviceIndex op as the input, and + // caculates the index of the device in the device list of DeviceIndex op. + // Lastly, it rewrites the DeviceIndex op with a Const op and sets the value + // to be the index. + // + // Example input nodes: + // node { + // name: "x" + // op: "DeviceIndex" + // device: "/device:CPU:0" + // attr { + // key: "device_names" + // value { + // list { + // s: "CPU" + // s: "TPU_REPLICATED_CORE" + // s: "GPU" + // } + // } + // } + // } + // node { + // name: "case" + // op: "Case" + // input: "x" + // device: "/device:GPU:0" + // ... + // } + // Example output nodes: + // + // name: "x" + // op: "Const" + // device: "/device:CPU:0" + // attr { + // key: "dtype" + // value { + // type: DT_INT32 + // } + // } + // attr { + // key: "value" + // value { + // tensor { + // dtype: DT_INT32 + // int_val: 2 + // } + // } + // } + // node { + // name: "case" + // op: "Case" + // input: "x" + // device: "/device:GPU:0" + // ... + // } + absl::Status SelectDeviceIndex(GraphDef* graph) const; + + std::unique_ptr lib_info_; + + ImplementationSelector(const ImplementationSelector&) = delete; + void operator=(const ImplementationSelector&) = delete; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_IMPLEMENTATION_SELECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.h new file mode 100644 index 00000000..d15ff68b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.h @@ -0,0 +1,66 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_INFERENCE_BATCH_OP_REWRITER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_INFERENCE_BATCH_OP_REWRITER_H_ + +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.pb.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kEnableAdaptiveSchedulerAttr[] = "_enable_adaptive_scheduler"; +constexpr char kMinInflightBatchesAttr[] = "_min_inflight_batches"; +constexpr char kInitialInflightBatchesAttr[] = "_initial_inflight_batches"; +constexpr char kMaxInflightBatchesAttr[] = "_max_inflight_batches"; +constexpr char kBatchesToAverageOverAttr[] = "_batches_to_average_over"; +constexpr char kFullBatchSchedulingBoostMicros[] = + "_full_batch_scheduling_boost_micros"; // NOLINT(whitespace/line_length) + +constexpr int64_t kMinInflightBatches = 16; +constexpr int64_t kInitialInflightBatches = 16; +constexpr int64_t kBatchesToAverageOver = 10; +constexpr int64_t kMaxInflightBatches = 64; + +using ::tensorflow::serving::BatchOpRewriteConfig; + +// This optimization does the following: +// +// Rewrite `num_batch_threads` to zero in batch-op. In this way, graphs with +// batch op will use a shared thread pool to schedule batches, as opposed to +// allocating batch threads per batch-op. +class BatchOpRewriter : public ::tensorflow::grappler::CustomGraphOptimizer { + public: + absl::Status Init( + const ::tensorflow::RewriterConfig_CustomGraphOptimizer* config) override; + + std::string name() const override { return "batch_op_rewriter"; } + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(::tensorflow::grappler::Cluster* cluster, + const ::tensorflow::grappler::GrapplerItem& item, + ::tensorflow::GraphDef* optimized_graph) override; + + private: + BatchOpRewriteConfig config_; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_INFERENCE_BATCH_OP_REWRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/loop_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/loop_optimizer.h new file mode 100644 index 00000000..0b561876 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ + +#include + +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +constexpr char kLoopOptimizer[] = "LoopOptimizer"; + +class LoopOptimizer : public GraphOptimizer { + public: + LoopOptimizer(); + + explicit LoopOptimizer(RewriterConfig::Toggle opt_level, + DeviceBase* cpu_device); + + ~LoopOptimizer() override {} + + string name() const override { return "loop_optimizer"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + friend class LoopOptimizerTest; + + // Granular control for loop optimizer stages. + struct LoopOptimizerOptions { + bool enable_loop_invariant_node_motion = false; + bool enable_stack_push_removal = true; + bool enable_dead_branch_removal = true; + + static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) { + LoopOptimizerOptions options; + return options; + } + }; + + absl::Status RemoveDeadBranches( + const std::unordered_set& nodes_to_preserve, NodeMap& node_map, + const absl::flat_hash_set& feed_nodes, GraphDef* optimized_graph); + + RewriterConfig::Toggle opt_level_; + DeviceBase* cpu_device_; + LoopOptimizerOptions options_; + std::unique_ptr resource_mgr_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/memory_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/memory_optimizer.h new file mode 100644 index 00000000..e1274d93 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/memory_optimizer.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Swap tensors in and out of device memory. +class MemoryOptimizer : public GraphOptimizer { + public: + // optimization_level: Controls the level of autonomy for the memory + // optimizer. See RewriterConfig::memory_optimization. 
+ // recomputation_targets_name_scope: Name scope for potential outputs of + // recomputations. See + // RewriterConfig::memory_optimizer_target_node_name_scope. + explicit MemoryOptimizer( + RewriterConfig::MemOptType optimization_level, + const string& recomputation_targets_name_scope = "gradients/") + : optimization_level_(optimization_level), + recomputation_targets_name_scope_(recomputation_targets_name_scope) {} + ~MemoryOptimizer() override {} + + string name() const override { return "memory_optimizer"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* pruned_graph) override; + + private: + RewriterConfig::MemOptType optimization_level_; + string recomputation_targets_name_scope_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/meta_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/meta_optimizer.h new file mode 100644 index 00000000..74756553 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -0,0 +1,166 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_ + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/verifiers/graph_verifier.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" +#include "tensorflow/core/protobuf/verifier_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Run the other grappler optimizers based on the specified rewriter config. 
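// Illustrative sketch (not part of the upstream header): the rewriter config
// referred to above normally arrives inside the session ConfigProto, under
// ConfigProto.graph_options.rewrite_options. A hypothetical helper that builds
// one explicitly:
inline RewriterConfig ExampleRewriterConfig() {
  RewriterConfig rewriter;
  rewriter.set_constant_folding(RewriterConfig::ON);
  rewriter.set_arithmetic_optimization(RewriterConfig::ON);
  rewriter.set_memory_optimization(RewriterConfig::NO_MEM_OPT);
  return rewriter;
}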
+class MetaOptimizer : public GraphOptimizer { + public: + MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg); + ~MetaOptimizer() override = default; + + string name() const override { return "meta_optimizer"; }; + + bool UsesFunctionLibrary() const override { return true; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + GrapplerItem copy(item); + return OptimizeConsumeItem(cluster, std::move(copy), optimized_graph); + } + + absl::Status OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, + GraphDef* optimized_graph); + + string GetResultString() const; + + void PrintResult(); + + private: + std::unique_ptr MakeNewOptimizer( + const string& optimizer, const std::set& device_types) const; + + // When grappler should lower control flow to V1 switch/merge style nodes. + bool LowerControlFlow() const; + + // Initialize active optimizers from RewriterConfig toggles. + absl::Status InitializeOptimizers( + const std::set& device_types, + std::vector>* optimizers) const; + // Initialize active optimizers from RewriterConfig optimizer names. + absl::Status InitializeOptimizersByName( + const std::set& device_types, + std::vector>* optimizers) const; + // Initialize active optimizers from RewriterConfig.custom_optimizers. + absl::Status InitializeCustomGraphOptimizers( + const std::set& device_types, + const std::set& pre_initialized_optimizers, + std::vector>* optimizers) const; + absl::Status InitializePluginGraphOptimizers( + const std::set& device_types, + std::vector>* optimizers) const; + // Returns the config for a custom graph optimizer. Null if none was found. + const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig( + const string& name) const; + + // Initialize active verifiers from the RewriterConfig toggles. + void InitializeVerifiers( + std::vector>* inter_optimizer_verifiers, + std::vector>* post_optimization_verifiers) + const; + + void PrintUserAndPluginConfigs(const std::set& device_types) const; + + // Run optimization pass over a single GrapplerItem. Meta optimizer might run + // multiple such passes: 1) for the main graph 2) for the function library + absl::Status OptimizeGraph( + const std::vector>& optimizers, + Cluster* cluster, GrapplerItem&& item, GraphDef* optimized_graph); + absl::Status OptimizeGraph(Cluster* cluster, GrapplerItem&& item, + GraphDef* optimized_graph); + + DeviceBase* const cpu_device_; // may be NULL + ConfigProto config_proto_; + RewriterConfig& cfg_; + bool xla_auto_clustering_on_; + + struct OptimizerResult { + string optimizer_name; + string message; + absl::Status status; + }; + + struct GraphOptimizationResult { + explicit GraphOptimizationResult(const string& id) : id(id) {} + string id; + std::vector results; + }; + + absl::Status RunOptimizer(GraphOptimizer* optimizer, Cluster* cluster, + GrapplerItem* optimized_item, + GraphDef* optimized_graph, + GraphOptimizationResult* optimization_result); + + std::vector optimization_results_; +}; + +bool MetaOptimizerEnabled(const ConfigProto& cfg); + +// Run the meta optimizer. +// +// If is non-null, it is the device to be used for executing ops +// during constant folding; if NULL, a new device is created for doing constant +// folding. For performance, it is recommended to pass in an existing cpu_device +// when possible. 
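// Illustrative call (not part of the upstream header); `item` is a populated
// GrapplerItem, and the null cluster/cpu_device arguments are only for brevity
// in this sketch:
//   ConfigProto config;
//   config.mutable_graph_options()->mutable_rewrite_options()
//       ->set_constant_folding(RewriterConfig::ON);
//   GraphDef optimized;
//   absl::Status status = RunMetaOptimizer(
//       std::move(item), config, /*cpu_device=*/nullptr, /*cluster=*/nullptr,
//       &optimized);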
+absl::Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg, + DeviceBase* cpu_device, Cluster* cluster, + GraphDef* optimized_graph); + +// Convenience wrapper around RunMetaOptimizer for optimizing +// function graphs. +// +// Runs grappler optimizations on `g` based on `config_proto`. +// `ret_node_names`: a vector of node names whose outputs are returned, +// aka fetches. When `g` represents a function, these are _Retval nodes. +// `lib`: function library to use with `g`. +// `device_set`: the set of devices that graph can refer to. +// `cpu_device`: the CPU device. +// `config_proto`: Grappler configuration. +// `grappler_item_id`: Grappler item id (e.g. optimized function name). +// `optimization_options`: Grappler optimization constraints that are known only +// at runtime. +// +// **g is a graph constructed based on the runtime library 'lib'. +// OptimizeGraph mutates **g extensively and replaces '*g' with a +// complete copy. Therefore, the caller should not keep any references +// to nodes in *g. +absl::Status OptimizeGraph( + std::vector ret_node_names, std::vector keep_node_names, + FunctionLibraryDefinition* lib, const DeviceSet& device_set, + Device* cpu_device, const ConfigProto& config_proto, + const string& grappler_item_id, + const GrapplerItem::OptimizationOptions& optimization_options, + std::unique_ptr* g); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/model_pruner.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/model_pruner.h new file mode 100644 index 00000000..668bb442 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/model_pruner.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_MODEL_PRUNER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_MODEL_PRUNER_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// Prune a model to make it more efficient: +// * Remove unnecessary operations. +// * Optimize gradient computations.
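// Illustrative usage (not part of the upstream header); like every
// GraphOptimizer, the pruner is applied to a GrapplerItem and writes the
// result into a separate GraphDef. Passing a null cluster here is only for
// illustration:
//   ModelPruner pruner;
//   GraphDef pruned_graph;
//   absl::Status status =
//       pruner.Optimize(/*cluster=*/nullptr, item, &pruned_graph);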
+class ModelPruner : public GraphOptimizer { + public: + ModelPruner() {} + ~ModelPruner() override {} + + string name() const override { return "model_pruner"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_MODEL_PRUNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h new file mode 100644 index 00000000..3cd1db08 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { +namespace internal { +// Try and find an appropriate Host device in `devices` given `device`. +string TryFindHostDevice(const gtl::FlatSet& devices, + bool has_device_cpu, const string& device); +} // end namespace internal + +// Optimize TensorFlow ops that should be swapped into the CPU to avoid +// excessive cpu<->gpu memcpy/sync. +// +// TODO(williamchan): The current heuristic will swap any small integer Const to +// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of +// gpu->gpu->gpu may have been better/faster. We should probably fix this. +class PinToHostOptimizer : public GraphOptimizer { + public: + PinToHostOptimizer() {} + explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level) {} + + ~PinToHostOptimizer() override {} + + string name() const override { return "pin_to_host_optimizer"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/remapper.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/remapper.h new file mode 100644 index 00000000..51332eeb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/remapper.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_REMAPPER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_REMAPPER_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize TF computations by remapping subgraphs/nodes onto other subgraphs or +// nodes to decrease the amount of operations needed to perform a computation. +class Remapper : public GraphOptimizer { + public: + explicit Remapper(RewriterConfig::Toggle opt_level, + RewriterConfig::CpuLayout cpu_layout_conversion = + RewriterConfig::NO_CONVERSION_ON_CPU, + bool xla_auto_clustering_on = false) + : opt_level_(opt_level), + cpu_layout_conversion_(cpu_layout_conversion), + xla_auto_clustering_on_(xla_auto_clustering_on) {} + + ~Remapper() override {} + + string name() const override { return "remapper"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + private: + RewriterConfig::Toggle opt_level_; + RewriterConfig::CpuLayout cpu_layout_conversion_; + bool xla_auto_clustering_on_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_REMAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h new file mode 100644 index 00000000..1b50f148 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h @@ -0,0 +1,127 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SCOPED_ALLOCATOR_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SCOPED_ALLOCATOR_OPTIMIZER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +class Graph; + +namespace grappler { +class GraphProperties; +class NodeMap; +class ScopedAllocatorOptimizer; + +// An Optimizer that introduces ScopedAllocators in order to reduce data +// movement and consolidate some kinds of Ops. +class ScopedAllocatorOptimizer : public GraphOptimizer { + public: + ScopedAllocatorOptimizer(RewriterConfig::Toggle opt_level, + const ScopedAllocatorOptions& opts); + ~ScopedAllocatorOptimizer() override; + + string name() const override { return "scoped_allocator_optimizer"; } + + bool UsesFunctionLibrary() const override { return true; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + // Map from an Op name to a vector of Nodes with that Op. + typedef absl::flat_hash_map> DevOpOccurrences; + // Map from a device name to a DevOpOccurrences map. + typedef absl::flat_hash_map GraphOpOccurrences; + typedef absl::flat_hash_set OpNameSet; + + absl::Status ProcessGraphDef(GraphDef* graph, + const GraphProperties& graph_properties); + + // Populates *occs by grouping Nodes with common Ops, according to + // their assigned devices. + void FindOpOccurrences(GraphDef* graph, const OpNameSet& op_names, + GraphOpOccurrences* occs); + + // Returns a new, unused scope_id to be assigned to a ScopedAllocator that + // will allocate num_fields (> 0) separate tensors. + int NewScopedAllocatorId(int num_fields); + + // Returns a new, unused id to be assigned to an IdentityOp used in this graph + // rewrite. + absl::Status NewIdentityId(int* id); + + NodeMap* node_map() { return node_map_.get(); } + + const absl::flat_hash_set& repeated_outputs() { + return repeated_outputs_; + } + + // Appends values to the attr value under name in node_def, if present. + // If not present does an assignment. + static void ExtendNodeAttr(absl::string_view name, + const std::vector& values, + NodeDef* node_def); + + // Class that knows how to do graph rewriting for a particular kind of Op in + // order to take advantage of a ScopedAllocator. + class Rewriter { + public: + virtual ~Rewriter() {} + + virtual absl::Status Rewrite(ScopedAllocatorOptimizer* paopti, + int64_t invocation_count, GraphDef* graph, + const string& op_name, + const std::vector& nodes, + bool* applied) = 0; + + void SetGraphProperties(const GraphProperties& graph_properties) { + graph_properties_ = &graph_properties; + CHECK(graph_properties_); + } + + protected: + const GraphProperties* graph_properties_; + }; + + private: + Rewriter* GetRewriter(const string& op_name); + + absl::Status OrderNodeSet(std::vector* nodes) const; + + RewriterConfig::Toggle opt_level_; + std::unordered_set nodes_to_preserve_; + OpNameSet op_name_set_; + absl::flat_hash_map rewriters_; + std::vector to_delete_; + int next_sa_id_ = 1; + int next_identity_id_ = 1; + std::unique_ptr node_map_; + // Keeps track of outputs, i.e. 
a node and an output index, that are inputs to + // more than one op groups that are candidates for scoped allocator + // optimization. + absl::flat_hash_set repeated_outputs_; +}; + +} // namespace grappler +} // namespace tensorflow +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SCOPED_ALLOCATOR_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/shape_optimizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/shape_optimizer.h new file mode 100644 index 00000000..00679ca8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/shape_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SHAPE_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SHAPE_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/frame.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +// Optimize TensorFlow subgraphs that operate on shape and shape related +// information. +class ShapeOptimizer : public GraphOptimizer { + public: + ShapeOptimizer() {} + explicit ShapeOptimizer(RewriterConfig::Toggle opt_level) {} + + ~ShapeOptimizer() override {} + + string name() const override { return "shape_optimizer"; }; + + bool UsesFunctionLibrary() const override { return false; } + + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SHAPE_OPTIMIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/static_schedule.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/static_schedule.h new file mode 100644 index 00000000..b26ce381 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/static_schedule.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ + +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + +// Compute the earliest time at which the execution of each node in the graph +// can complete. +// In our estimation, we ensure that each node takes at least one nanosecond to +// execute: therefore the execution times can be used to derive a topological +// ordering of the graph (at least as long as there is no loop in the graph). +absl::Status EstimateEarliestExecutionTimes( + const GrapplerItem& item, const Cluster* cluster, + std::unordered_map* execution_times); + +// Compute the time by which the execution of each node must complete to ensure +// the subsequent nodes can still be executed by the times predicted by the +// EstimateEarliestExecutionTimes function. +absl::Status EstimateRequiredTimes( + const GrapplerItem& item, const Cluster* cluster, + const std::unordered_map& + execution_times, + std::unordered_map* required_times); + +} // namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_STATIC_SCHEDULE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/tfg_optimizer_hook.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/tfg_optimizer_hook.h new file mode 100644 index 00000000..58872497 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/tfg_optimizer_hook.h @@ -0,0 +1,65 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_TFG_OPTIMIZER_HOOK_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_TFG_OPTIMIZER_HOOK_H_ + +#include +#include + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" + +namespace mlir { +class PassManager; + +namespace tfg { + +// A function that builds the TFG pass pipeline. +using TFGPassPipelineBuilder = std::function; + +// This class implements a Grappler optimizer wrapping a pipeline of passes +// implemented with TFG. +class TFGGrapplerOptimizer : public tensorflow::grappler::GraphOptimizer { + public: + // Constructs a TFG optimizer using the provided pipeline builder. By default, + // the optimizer will not use multi-threading. If `num_tfg_threads` is + // non-zero, then TFG will use threading with the specified number of threads. + explicit TFGGrapplerOptimizer(TFGPassPipelineBuilder builder, + unsigned num_tfg_threads = 0); + // Explicit destructor to defer instantiation of Impl. 
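// Illustrative construction (not part of the upstream header), pairing the
// hook with the default pipeline declared in tfg_passes_builder.h and keeping
// TFG single-threaded; the variable name is hypothetical:
//   TFGGrapplerOptimizer tfg_optimizer(
//       [](mlir::PassManager& manager) { DefaultGrapplerPipeline(manager); },
//       /*num_tfg_threads=*/0);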
+ ~TFGGrapplerOptimizer() override; + + // Constructs a name for the optimizer using the registered passes. + std::string name() const override; + // The TFG optimizer requires access to the function library. + bool UsesFunctionLibrary() const override { return true; } + + // Runs the optimizer on the GraphDef. The optimizer converts the GraphDef to + // TFG using the importer, runs the passes on the MLIR, and exports back to + // GraphDef. The result is stored in `optimized_graph`. + absl::Status Optimize(tensorflow::grappler::Cluster* cluster, + const tensorflow::grappler::GrapplerItem& item, + tensorflow::GraphDef* optimized_graph) override; + + private: + // Hide the implementation details. + class Impl; + std::unique_ptr impl_; +}; + +} // end namespace tfg +} // end namespace mlir + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_TFG_OPTIMIZER_HOOK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/tfg_passes_builder.h b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/tfg_passes_builder.h new file mode 100644 index 00000000..4aee20b7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/optimizers/tfg_passes_builder.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_TFG_PASSES_BUILDER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_TFG_PASSES_BUILDER_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace mlir { +namespace tfg { + +// Constructs the default graph/function-level TFG pass pipeline. +void DefaultGrapplerPipeline(PassManager& manager); + +// Constructs the default module-level TFG pass pipeline. +void DefaultModuleGrapplerPipeline(PassManager& manager, + const tensorflow::RewriterConfig& config); + +// Constructs the Remapper pass pipeline. +void RemapperPassBuilder(PassManager& manager); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_TFG_PASSES_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils.h new file mode 100644 index 00000000..e437ebe0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils.h @@ -0,0 +1,440 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +// Utilities for manipulating node name and input strings. + +// Returns the trailing position number (or zero if no number is present) if +// NodeName(input_name) is equal to node_name. Returns -1 for control inputs. +// Returns -2 if input_name is empty or NodeName(input_name) is not equal to +// node_name. +inline int NodePositionIfSameNode(absl::string_view input_name, + absl::string_view node_name) { + bool is_control = absl::StartsWith(input_name, "^"); + if (is_control) input_name.remove_prefix(1); + if (input_name.empty() || node_name.empty() || + input_name.size() < node_name.size()) { + return -2; + } + TensorId id = ParseTensorName(input_name); + if (id.first != node_name) return -2; + if (is_control) return -1; + return id.second; +} + +// Returns the node name and position in a single call. +inline absl::string_view ParseNodeNameAsStringPiece(absl::string_view name, + int* position) { + const bool is_control = absl::StartsWith(name, "^"); + TensorId id = ParseTensorName(name); + if (position) { + *position = is_control ? -1 : id.second; + } + if (is_control && id.second >= 0) { + id.first.remove_prefix(1); + } + return id.first; +} + +// Returns the node name and position in a single call. +inline string ParseNodeName(const string& name, int* position) { + return string(ParseNodeNameAsStringPiece(name, position)); +} + +// Return the node name corresponding to 'name' if name is valid, or the empty +// string otherwise. +inline absl::string_view NodeNameAsStringPiece(const string& name) { + return ParseNodeNameAsStringPiece(name, nullptr); +} + +// Return the node name corresponding to 'name' if name is valid, or the empty +// string otherwise. +inline string NodeName(const string& name) { + return string(NodeNameAsStringPiece(name)); +} + +inline int NodePosition(const string& name) { + int position; + ParseNodeNameAsStringPiece(name, &position); + return position; +} + +namespace internal { +// Base template class for NodeMap and ImmutableNodeMap. +template +class NodeMapInternal { + public: + // Note: The NodeMap will store pointers to nodes in graph, which may become + // invalid if graph is changed. 
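The name-parsing helpers above encode Grappler's input-string conventions ("node", "node:port", "^node" for control inputs). A small test-style sketch of the expected results, not part of the header:

// Illustrative expectations for the parsing helpers.
#include <string>

#include "tensorflow/core/grappler/utils.h"

namespace example {

void NameParsingExamples() {
  using tensorflow::grappler::NodeName;
  using tensorflow::grappler::NodePosition;
  using tensorflow::grappler::NodePositionIfSameNode;

  std::string node = NodeName("add:2");                // "add"
  int port = NodePosition("add:2");                    // 2
  int ctrl = NodePosition("^add");                     // -1 (control input)
  int mismatch = NodePositionIfSameNode("mul:0", "add");  // -2 (other node)
  (void)node; (void)port; (void)ctrl; (void)mismatch;
}

}  // namespace example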
+ explicit NodeMapInternal(GraphDefT* graph) { + if (graph == nullptr) { + LOG(WARNING) << "NodeMapInternal constructor is called with a nullptr!"; + return; + } + nodes_.reserve(graph->node_size()); + outputs_.reserve(graph->node_size()); + for (int i = 0; i < graph->node_size(); i++) { + NodeDefT* node = GetNodeDefFromGraph(graph, i); + const string& node_name = node->name(); + auto rslt = nodes_.emplace(node_name, node); + // Check that the graph doesn't contain multiple nodes with the same name. + if (!rslt.second) { + // The first node found with a given name becomes the canonical. + LOG(WARNING) << "Duplicated node in the graph: " << node_name; + } + NodeDefT* canonical = rslt.second ? node : rslt.first->second; + for (const auto& input : node->input()) { + outputs_[NodeName(input)].insert(canonical); + } + } + } + + // Get unordered list of fanouts from node. Notice, that the order is + // non-deterministic. + const absl::flat_hash_set& GetOutputs( + const string& node_name) const { + auto it = outputs_.find(node_name); + if (it == outputs_.end()) { + return empty_set_; + } + return it->second; + } + + // Get fanouts ordered by name. + std::vector GetOutputsOrderedByNodeName( + const string& node_name) const { + std::vector result; + auto it = outputs_.find(node_name); + if (it != outputs_.end()) { + const absl::flat_hash_set& outputs = it->second; + result.reserve(outputs.size()); + result.assign(outputs.begin(), outputs.end()); + std::sort(result.begin(), result.end(), + [](const NodeDef* n1, const NodeDef* n2) { + return n1->name() < n2->name(); + }); + } + return result; + } + + // This method doesn't record the outputs of the added node; the outputs need + // to be explicitly added by the AddOutput method. + void AddNode(const string& node_name, NodeDefT* node) { + DCHECK(node != nullptr); + auto ret = nodes_.emplace(node_name, node); + DCHECK(ret.second) + << "Pair (" << node_name << "," << node + << ") is not inserted because the same key already exists."; + } + + void RemoveNode(const string& name) { + nodes_.erase(NodeName(name)); + outputs_.erase(NodeName(name)); + } + + NodeDefT* GetNode(const string& name) const { + const string node_name = NodeName(name); + auto it = nodes_.find(node_name); + if (it == nodes_.end()) { + VLOG(1) << "Node could not be found: " << name; + return nullptr; + } + return it->second; + } + + bool NodeExists(const string& name) const { + const string node_name = NodeName(name); + return nodes_.find(node_name) != nodes_.end(); + } + + void AddOutput(const string& node_name, const string& output_name) { + auto output_node = nodes_[NodeName(output_name)]; + DCHECK(output_node) << "Output node " << output_name + << " is missing in NodeMap."; + outputs_[node_name].insert(output_node); + } + + void RemoveOutput(const string& node_name, const string& output_name) { + outputs_[node_name].erase(nodes_[NodeName(output_name)]); + } + + void UpdateInput(const string& node_name, const string& old_input_name, + const string& new_input_name) { + RemoveOutput(NodeName(old_input_name), node_name); + AddOutput(NodeName(new_input_name), node_name); + } + + void RemoveInputs(const string& node_name) { + auto node = nodes_[node_name]; + for (const auto& input : node->input()) { + RemoveOutput(NodeName(input), node->name()); + } + } + + void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); } + + void UpdateOutput(const string& node_name, const string& old_output_name, + const string& new_output_name) { + absl::flat_hash_set& outputs = 
outputs_[node_name]; + outputs.erase(nodes_[NodeName(old_output_name)]); + outputs.insert(nodes_[NodeName(new_output_name)]); + } + + private: + // Helper method to get the NodeDef pointer of i-th node in a graph. + inline NodeDefT* GetNodeDefFromGraph(GraphDefT* graph, int64_t i) const; + + const absl::flat_hash_set empty_set_; + absl::node_hash_map nodes_; + absl::node_hash_map> outputs_; +}; + +// Specialized template class method GetNodeDefFromGraph. +template <> +inline NodeDef* NodeMapInternal::GetNodeDefFromGraph( + GraphDef* graph, int64_t i) const { + return graph->mutable_node(i); +} + +template <> +inline const NodeDef* +NodeMapInternal::GetNodeDefFromGraph( + const GraphDef* graph, int64_t i) const { + return &graph->node(i); +} +} // namespace internal + +// A utility class to lookup a node and its outputs by node name. +class NodeMap : public internal::NodeMapInternal { + public: + explicit NodeMap(GraphDef* graph) : NodeMapInternal(graph) {} +}; + +// Same to NodeMap, but uses const GraphDef. +class ImmutableNodeMap + : public internal::NodeMapInternal { + public: + explicit ImmutableNodeMap(const GraphDef* graph) : NodeMapInternal(graph) {} +}; + +// A vector with a set. The set stores the same elements as the vector, and +// quickly answers whether a value is in the vector. Duplicated elements are not +// allowed for now. +template > +class SetVector { + public: + // Returns false if value already existed in the set, true otherwise. + bool PushBack(const T& value) { + if (!set_.insert(value).second) { + return false; + } + vector_.push_back(value); + return true; + } + + T PopBack() { + T back = vector_.back(); + set_.erase(back); + vector_.pop_back(); + return back; + } + + bool Exists(const T& value) const { return set_.find(value) != set_.end(); } + + bool Empty() const { return vector_.empty(); } + + void Reserve(int64_t size) { vector_.reserve(size); } + + private: + gtl::FlatSet set_; + std::vector vector_; +}; + +// Returns formatted string from TensorId specific to grappler. Specifically, +// for the 0 port (first output), only the node name is returned. +string TensorIdToString(const TensorId& tensor_id); + +// Returns formatted string from SafeTensorId specific to grappler. +// Specifically, for the 0 port (first output), only the node name is returned. +string SafeTensorIdToString(const SafeTensorId& tensor_id); + +// True iff 'name' refers to a control inputs, i.e. a node name prefixed with +// the ^ character. +bool IsControlInput(absl::string_view name); + +// True iff tensor index refers to a control input. +bool IsControlInput(const TensorId& tensor_id); + +// True iff 'name1' and 'name2' refer to the same input. +bool IsSameInput(const string& name1, const string& name2); + + +// Add a prefix to a node name with a custom delimiter. +string AddPrefixToNodeName(const string& name, const string& prefix, + const string& delimiter); + +// Add a prefix to a node name. +string AddPrefixToNodeName(const string& name, const string& prefix); + +// Executes a 'fn' in the 'thread_pool'. The method waits for the configured +// timeout (in milliseconds) for 'fn' to complete, before returning false. +// +// If returning false, the 'fn' may still continue to execute in the +// thread-pool. It is the responsibility of the caller to reset the thread-pool +// as appropriate. 
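A brief usage sketch for NodeMap and SetVector declared above (not part of the header). It assumes `graph` is a live GraphDef owned by the caller; note that NodeMap stores raw NodeDef pointers that are invalidated if the graph is mutated.

// Illustrative sketch: look up a node, enumerate its consumers, and use
// SetVector as a duplicate-rejecting work list.
#include <string>

#include "absl/log/log.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"

namespace example {

void InspectNode(tensorflow::GraphDef* graph, const std::string& name) {
  tensorflow::grappler::NodeMap node_map(graph);
  tensorflow::NodeDef* node = node_map.GetNode(name);  // nullptr if absent.
  if (node == nullptr) return;

  // Fanouts are reported per consuming node (deduplicated), not per edge.
  for (tensorflow::NodeDef* fanout :
       node_map.GetOutputsOrderedByNodeName(node->name())) {
    LOG(INFO) << node->name() << " feeds " << fanout->name();
  }

  tensorflow::grappler::SetVector<tensorflow::NodeDef*> work_list;
  work_list.PushBack(node);  // true: newly inserted.
  work_list.PushBack(node);  // false: duplicates are rejected while queued.
}

}  // namespace example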
+bool ExecuteWithTimeout(std::function fn, int64_t timeout_in_ms, + thread::ThreadPool* thread_pool); + +// Returns the node name prefixed with conventional symbol '^' +// for control dependency, given a NodeDef. +string AsControlDependency(const NodeDef& node); + +// Returns the node name prefixed with conventional symbol '^' +// for control dependency, given a node name +string AsControlDependency(const string& node); + +// Returns true if the node is assigned to run on CPU device. +bool NodeIsOnCpu(const NodeDef* node); + +// Returns true if the node is assigned to run on GPU device. +bool NodeIsOnGpu(const NodeDef* node); + +// Returns the number of outputs of a node according to its OpDef. Note that +// some of the outputs may be unconnected. +int NumOutputs(const NodeDef& node, GraphDef* graph); + +// Returns true iff the node has at least one control input. +bool HasControlInputs(const NodeDef& node); + +// Returns true iff the node has at least one regular input. +bool HasRegularInputs(const NodeDef& node); + +// Returns true iff the node has at least one regular output. +bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map); + +// Returns true iff the node has at least one control output. +bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map); + +// Number of connected control inputs. +int NumControlInputs(const NodeDef& node); + +// Number of connected non-control inputs. +int NumNonControlInputs(const NodeDef& node); + +// Number of connected control outputs. +int NumControlOutputs(const NodeDef& node, const NodeMap& node_map); + +// Number of connected non-control outputs. +int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map); + +// Number of connected non-control data outputs (Ops that consume output tensor +// data, not just it's shape). +int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map); + +// Removes redundant control inputs from node. +void DedupControlInputs(NodeDef* node); + +// Returns an error if an attribute with the given key does not exist in node. +absl::Status CheckAttrExists(const NodeDef& node, const string& key); + +// Returns an error if attributes with the given keys do not exist in node. +absl::Status CheckAttrsExist(const NodeDef& node, + absl::Span keys); + +// Returns the data type in attribute `attr_name` of `node`. If that attribute +// doesn't exist, returns DT_INVALID. +DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr); + +// Returns the last node in the simple chain starting at source and traversing +// through the input(0) edge from each node as long as the next node satisfies +// the predicate given in pred_fn. If no nodes satisfy the predicate, &source +// will be returned. Example: For the chain +// source <- a <- b <- ... <- y <- z +// where +// pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true, +// pred_fn(z) = false, +// the return value will be a pointer to y. +NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, + bool follow_control_input, + const std::function& pred_fn); + +// Permute the nodes of graph in place according to the permutation. +void PermuteNodesInPlace(GraphDef* graph, std::vector* permutation, + bool invert_permutation); + +// Returns OkStatus() if a kernel is registered for node.op() on the device +// type corresponding to node.device(). 
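GetTailOfChain, declared above, is the usual way optimizers skip over trivial producer chains. A hedged sketch (not part of the header); the predicate type is assumed to be std::function<bool(const NodeDef&)> as in upstream TensorFlow:

// Illustrative sketch: walk input(0) edges past consecutive Identity nodes.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"

namespace example {

tensorflow::NodeDef* SkipIdentityChain(
    const tensorflow::NodeDef& start,
    const tensorflow::grappler::NodeMap& node_map) {
  // Follows input(0) while the predicate holds; if no producer satisfies it,
  // the address of `start` is returned (see the comment on GetTailOfChain).
  return tensorflow::grappler::GetTailOfChain(
      start, node_map, /*follow_control_input=*/false,
      [](const tensorflow::NodeDef& node) { return node.op() == "Identity"; });
}

}  // namespace example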
+absl::Status IsKernelRegisteredForNode( + absl::string_view node_name, bool has_experimental_debug_info, + const NodeDef_ExperimentalDebugInfo& experimental_debug_info, + absl::string_view node_op, absl::string_view node_device, + AttrSlice node_attrs); +absl::Status IsKernelRegisteredForNode(const NodeDef& node); + +absl::Status SetTensorValue(DataType dtype, int value, Tensor* tensor); + +void EraseNodesFromGraph(const std::set& nodes_to_delete, GraphDef* graph); + +void EraseNodesFromGraph(std::vector&& nodes_to_delete, GraphDef* graph); + +void EraseNodesFromGraph(const std::set& nodes_to_delete, + GraphDef* graph); + +// Erase all attributes without leading underscore. Returns the number of +// attributes erased. +int EraseRegularNodeAttributes(NodeDef* node); + +// Erase attribute "_xla_inferred_shapes" as well as all attributes starting in +// "_output_". +int EraseNodeOutputAttributes(NodeDef* node); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/canonicalizer.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/canonicalizer.h new file mode 100644 index 00000000..a913fc25 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/canonicalizer.h @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_CANONICALIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_CANONICALIZER_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// Canonicalizes node by performing the following steps +// - sorting control inputs, +// - sorting data inputs if the node represents a commutative op. +void CanonicalizeNode(NodeDef* node); + +// Canonicalizes all nodes in graph. +void CanonicalizeGraph(GraphDef* graph); + +// Compresses Const and HostConstant nodes in the graph to the smallest +// representation possible, either +// a) truncated repeated field representation, or +// b) raw serialized byte format. +// Each node is only modified if it is larger than 64 bytes and compression +// reduces its size by more than 50%. +void CompressConstants(GraphDef* graph); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_CANONICALIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/colocation.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/colocation.h new file mode 100644 index 00000000..6062db61 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/colocation.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
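The canonicalizer above is typically run before structural comparison or hashing of graphs. A minimal sketch, not part of the header; `graph` is a placeholder:

// Illustrative sketch: normalize input ordering and shrink large constants.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"

namespace example {

void PrepareForComparison(tensorflow::GraphDef* graph) {
  // Sorts control inputs (and data inputs of commutative ops) on every node.
  tensorflow::grappler::CanonicalizeGraph(graph);
  // Re-encodes Const/HostConst nodes >64 bytes when it saves more than 50%.
  tensorflow::grappler::CompressConstants(graph);
}

}  // namespace example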
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_ + +#include +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { +namespace grappler { + +// Evaluates the colocation relation in the graph and rewrites the new +// colocation relation in the graph. We scan the graph nodes sequentially, and +// builds a disjoint-sets of nodes (within each disjoint-set the nodes are +// colocated with each other). We then select the root node of each set as a +// representative node, and then colocate each node within the set (should also +// exist in graph) with the representative node. +// Note that there is current one situation this function can't handle: +// Node A colocates with X, node B colocates with Y, X colocates with Y but +// X, Y are removed from graph. In this case we can't know A colocates with B. +void ReassignColocation(GraphDef* graph); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_COLOCATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/frame.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/frame.h new file mode 100644 index 00000000..d66cfb58 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/frame.h @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/utils/graph_view.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// FrameView is a helper class that allows to find in what execution frames (if +// any) the given node can be running in. It's constructed from an immutable +// GraphView, and any modification of the underlying graph might invalidate it. +// +// All execution frames assigned a unique integer id, but they do not have any +// meaning whatsoever, it's just a sequence number. +// +// See the paper "Dynamic Control Flow in Large-Scale Machine Learning" for +// detailed explanation of execution frames (https://arxiv.org/abs/1805.01772). 
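ReassignColocation, declared above, collapses colocation classes onto a single surviving representative, which matters after passes delete nodes that other nodes were colocated with. A one-call sketch, not part of the header:

// Illustrative sketch: repair colocation attributes after pruning.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/utils/colocation.h"

namespace example {

void FixUpColocationAfterPruning(tensorflow::GraphDef* optimized_graph) {
  // Builds disjoint sets of colocated nodes and re-points each member at a
  // representative that still exists in the graph.
  tensorflow::grappler::ReassignColocation(optimized_graph);
}

}  // namespace example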
+class FrameView { + public: + FrameView() : is_inferred_(false), num_frames_(0) {} + + // Infers nodes execution frames from the GraphView. Returns an error if + // called multiple times. + absl::Status InferFromGraphView(const utils::GraphView& graph_view); + // Infers nodes execution frames from the MutableGraphView. Returns an error + // if called multiple times. + absl::Status InferFromGraphView(const utils::MutableGraphView& graph_view); + // Infers nodes execution by constructing temporary GraphView and passing it + // to InferFromGraphView. + absl::Status InferFromGraph(const GraphDef& graph); + + // Returns all frames of the given node (denoted by their frame ids) in + // outermost-to-innermost order. + const std::vector& Frames(const NodeDef& node) const; + + // Returns true iff the node is at least in one execution frame. + bool IsInFrame(const NodeDef& node) const; + + int num_frames() const { return num_frames_; } + bool is_inferred() const { return is_inferred_; } + + private: + template + inline absl::Status InferFromGraphViewT(const GraphViewT& graph_view); + + bool is_inferred_; // true if it was inferred from the graph + int num_frames_; // number of frames present in a graph + absl::flat_hash_map> node_to_frames_; + + // We return a reference to this vector if node has no frames. + const std::vector node_has_no_frames_; +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/functions.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/functions.h new file mode 100644 index 00000000..0006a260 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/functions.h @@ -0,0 +1,190 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace grappler { + +// Function input argument instantiated into an '_Arg' node in the function body +// graph, with an 'index' attribute corresponding to the input position. 
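A short usage sketch for FrameView (not part of the header), showing the common pattern of inferring frames once and then querying per node:

// Illustrative sketch: count nodes that execute inside control-flow frames.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/utils/frame.h"

namespace example {

int CountNodesInsideFrames(const tensorflow::GraphDef& graph) {
  tensorflow::grappler::FrameView frames;
  if (!frames.InferFromGraph(graph).ok()) return -1;
  int count = 0;
  for (const tensorflow::NodeDef& node : graph.node()) {
    if (frames.IsInFrame(node)) ++count;  // node is in at least one frame
  }
  return count;
}

}  // namespace example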
+struct InputArgInstantiation { + InputArgInstantiation(string node_name, DataType data_type) + : node_name(std::move(node_name)), data_type(data_type) {} + string node_name; + DataType data_type; +}; + +// Function output instantiated into a '_Retval' node in the function body +// graph, with an 'index' attribute corresponding to the output position. +struct OutputArgInstantiation { + OutputArgInstantiation(string node_name, DataType data_type) + : node_name(std::move(node_name)), data_type(data_type) {} + string node_name; + DataType data_type; +}; + +// A mapping from control output name to node name in function body graph. +struct ControlOutput { + string output_name; + string node_name; + bool operator<(const ControlOutput& a) const { + return output_name < a.output_name; + } +}; + +// A special case of GrapplerItem, constructed from a TensorFlow Function. +class GrapplerFunctionItem : public GrapplerItem { + public: + GrapplerFunctionItem() = default; + + const string& description() const; + + const std::vector& inputs() const; + const InputArgInstantiation& input(int i) const; + const std::size_t input_size() const; + + const std::vector& outputs() const; + const OutputArgInstantiation& output(int i) const; + const std::size_t output_size() const; + + const std::vector& control_outputs() const; + const std::size_t control_output_size() const; + + const AttrSlice& func_attr() const; + const std::vector& arg_attr() const; + const GraphDef& function_body() const; + GraphDef& mutable_function_body(); + + bool is_stateful() const; + + GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other); + + private: + friend absl::Status MakeGrapplerFunctionItem(const FunctionDef&, + const AttrSlice&, + const FunctionLibraryDefinition&, + int, GrapplerFunctionItem*); + friend absl::Status ReplaceInputWithConst(const NodeDef&, int, + GrapplerFunctionItem*); + friend absl::Status RemoveFunctionOutputs(const absl::flat_hash_set&, + GrapplerFunctionItem*, + std::vector>*); + + GrapplerFunctionItem(string func_name, string description, + AttrSlice func_attr, + std::vector arg_attr, + std::vector input_args, + std::vector output_args, + std::vector control_outputs, + int graph_def_version, bool is_stateful, + GraphDef&& function_body); + + string description_; + AttrSlice func_attr_; // Attributes specific to function definition that + // produced this item (FuncDef.attr field). + + // Attributes of function arguments + std::vector arg_attr_; + + std::vector input_args_; + std::vector output_args_; + std::vector control_outputs_; + + bool is_stateful_ = false; +}; + +// Check if function input/output types are fully defined only at instantiation +// time (parametrized by its instantiation node). +bool HasParametrizedType(const FunctionDef& func); + +// Check if a function body is parametrized by its instantiation node. Function +// body is parametrized, if it has at least one node with a 'placeholder' +// attribute. +bool HasParametrizedBody(const FunctionDef& func); + +// Check if function has parametrized type or body. +bool IsParametrized(const FunctionDef& func); + +// Resolve function instantiation type parameters from the attributes of the +// caller node. Return error if type can't be resolved. +absl::Status InstantiationTypeParameters( + const FunctionDef& func, const AttrSlice& func_instantiation_attr, + absl::flat_hash_map* type_parameters); + +// Resolve function instantiation body parameters (values for the function body +// attr placeholders) from the attributes of the caller node. 
Return error if +// type can't be resolved. +absl::Status InstantiationBodyParameters( + const FunctionDef& func, const AttrSlice& func_instantiation_attr, + absl::flat_hash_map* body_parameters); + +// Replace one of the function inputs with a constant. +absl::Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, + GrapplerFunctionItem* item); + +// Removes outputs from instantiated grappler function item. For all active +// function outputs that changed its output index, this function adds an output +// mapping (std::pair). +absl::Status RemoveFunctionOutputs( + const absl::flat_hash_set& remove_outputs, GrapplerFunctionItem* item, + std::vector>* output_mapping); + +// TODO(ezhulenev, b/120103818): Add RemoveFunctionInputs. + +// Make a GrapplerFunctionItem from the function definition and function +// instantiation attributes (caller node attributes). Returns error if the given +// function def cannot be converted (e.g. not all attributes are defined). +absl::Status MakeGrapplerFunctionItem(const FunctionDef& func, + const AttrSlice& func_instantiation_attr, + const FunctionLibraryDefinition& flib, + int graph_def_version, + GrapplerFunctionItem* item); + +// Make a GrapplerFunction item from the function definition. Function must be +// fully defined (no type or body parametrization). +// TODO(ezhulenev): Support parametrized functions without fully defined +// instantiation attributes? Do we ever want to optimize parametrized function +// without specializing it to its instantiation attributes (at least types)? +absl::Status MakeGrapplerFunctionItem(const FunctionDef& func, + const FunctionLibraryDefinition& flib, + int graph_def_version, + GrapplerFunctionItem* item); + +// Make a FunctionDef from the GrapplerFunctionItem. Use function library +// definition to lookup function body nodes output names and ranges. +absl::Status MakeFunctionDef(const GrapplerFunctionItem& item, + const FunctionLibraryDefinition& flib, + FunctionDef* func); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/graph_view.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/graph_view.h new file mode 100644 index 00000000..3398e338 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/graph_view.h @@ -0,0 +1,541 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
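The GrapplerFunctionItem machinery above is what lets function bodies be optimized as ordinary GraphDefs. A hedged round-trip sketch (not part of the header); error handling is deliberately minimal:

// Illustrative sketch: FunctionDef -> GrapplerFunctionItem -> FunctionDef.
#include "absl/status/status.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/utils/functions.h"

namespace example {

absl::Status RoundTripFunction(
    const tensorflow::FunctionDef& func,
    const tensorflow::FunctionLibraryDefinition& flib, int graph_def_version,
    tensorflow::FunctionDef* specialized) {
  // Only fully instantiated functions can use the flib-only overload.
  if (tensorflow::grappler::IsParametrized(func)) {
    return absl::InvalidArgumentError("function must be fully instantiated");
  }
  tensorflow::grappler::GrapplerFunctionItem item;
  absl::Status status = tensorflow::grappler::MakeGrapplerFunctionItem(
      func, flib, graph_def_version, &item);
  if (!status.ok()) return status;
  // item.mutable_function_body() is a plain GraphDef; a Grappler pass could
  // be run on it at this point before converting back.
  return tensorflow::grappler::MakeFunctionDef(item, flib, specialized);
}

}  // namespace example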
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/utils/graph_view_internal.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +namespace utils { + +class NodeView; + +class GraphView; + +// FaninView is a helper class to represent fanouts of a node. This holds a +// pointer to GraphView, the index of the node being represented from GraphView, +// and the input index (hence is labeled as Fanin). +class FaninView : public internal::NodeIndexAndPortIndex { + public: + FaninView() : NodeIndexAndPortIndex() {} + + FaninView(GraphView* graph_view, int node_index, int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + FaninView(NodeView* node_view, int index); + + private: + friend class NodeView; + friend class GraphView; +}; + +// FanoutView is a helper class to represent fanins of a node. This holds a +// pointer to GraphView, the index of the node being represented from GraphView, +// and the output index (hence is labeled as Fanout). +class FanoutView : public internal::NodeIndexAndPortIndex { + public: + FanoutView() : NodeIndexAndPortIndex() {} + + FanoutView(GraphView* graph_view, int node_index, int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + FanoutView(NodeView* node_view, int index); + + private: + friend class NodeView; + friend class GraphView; +}; + +// Immutable NodeView that keeps the constness of the NodeDef. This allows for +// lookups of fanins and fanouts, and traversals of the graph, but no mutations. +// No dedupping of fanins will be performed on the node to preserve it's +// constness. +class NodeView : public internal::NodeViewInternal { + public: + explicit NodeView(GraphView* graph_view, int node_index) + : NodeViewInternal(graph_view, node_index) {} + + NodeView() : NodeViewInternal() {} + + ~NodeView() override = default; + + NodeView(NodeView&&) = default; + NodeView& operator=(NodeView&&) = default; + + const NodeDef* node() const override; + + // Checks if a fanin exists for the node. + bool HasFanin(const FanoutView& fanin) const override; + + // Checks if a fanout exists for the node. + bool HasFanout(const FaninView& fanout) const override; + + private: + inline const FanoutView& GetMissingFanin() const override; + + inline const std::vector& GetMissingFanout() const override; + + absl::flat_hash_set fanins_set_; + + friend class FaninView; + friend class FanoutView; + friend class GraphView; +}; + +// Immutable GraphView that keeps the constness of the GraphDef. This allows +// for lookups and traversals of the graph, but no mutations. 
+class GraphView : public internal::GraphViewInternal { + public: + explicit GraphView(const GraphDef* graph, absl::Status* status); + ~GraphView() override = default; + + private: + bool AddUniqueNodeInternal(const NodeDef* node); + + absl::Status CheckAndAddFaninsInternal(NodeView* node_view); + + friend class NodeView; +}; + +class MutableNodeView; + +class MutableGraphView; + +class Mutation; + +// MutableFaninView is a helper class to represent fanouts of a node. This holds +// a pointer to MutableGraphView, the index of the node from MutableGraphView +// being mutated, and the input index (hence is labeled as Fanin). +class MutableFaninView + : public internal::NodeIndexAndPortIndex { + public: + MutableFaninView() : NodeIndexAndPortIndex() {} + + MutableFaninView(MutableGraphView* graph_view, int node_index, int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + explicit MutableFaninView(MutableGraphView* graph_view, int node_index, + int port_index, int fanin_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index), + fanin_index_(fanin_index) { + // TODO(lyandy): Remove once constructor is not public. + DCHECK(port_index < 0 || port_index == fanin_index); + } + + MutableFaninView(MutableNodeView* node_view, int index); + + private: + // Index of associated fanin in fanout's underlying MutableNodeView. For + // regular fanouts, this will be the same as port_index (index of the + // associated fanin in MutableNodeView::regular_fanins_). For controlled + // fanouts, this will be the index of the associated fanin in + // MutableNodeView::controlling_fanins_. + int fanin_index_ = internal::kMissingIndex; + + friend class MutableNodeView; + friend class MutableGraphView; + friend class Mutation; +}; + +// MutableFanoutView is a helper class to represent fanins of a node. This holds +// a pointer to MutableGraphView, the index of the node from MutableGraphView +// being mutated, and the output index (hence is labeled as Fanout). +class MutableFanoutView + : public internal::NodeIndexAndPortIndex { + public: + MutableFanoutView() : NodeIndexAndPortIndex() {} + + MutableFanoutView(MutableGraphView* graph_view, int node_index, + int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + explicit MutableFanoutView(MutableGraphView* graph_view, int node_index, + int port_index, int fanout_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index), + fanout_index_(fanout_index) {} + + MutableFanoutView(MutableNodeView* node_view, int index); + + private: + // Index of associated fanout in fanin's underlying MutableNodeView. For + // regular fanins, this will be the index of the associated fanout in + // MutableNodeView::regular_fanouts_by_port_[port_index]. For controlled + // fanins, this will be the index of the associated fanout in + // MutableNodeView::controlled_fanouts_. + int fanout_index_ = internal::kMissingIndex; + + friend class MutableNodeView; + friend class MutableGraphView; + friend class Mutation; +}; + +// Mutable NodeView that holds a mutable NodeDef. This allows for lookups of +// fanins and fanouts, and traversals of the graph. Control dependencies will be +// dedupped among other control dependencies on initialization via +// MutableGraphView. Mutations should be handled via MutableGraphView and not +// directly on the mutable NodeDef. 
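A read-only traversal sketch for utils::GraphView (not part of the header). It assumes the caller keeps the GraphDef alive and unmodified for the lifetime of the view:

// Illustrative sketch: construct an immutable view and query a node's fanouts.
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/utils/graph_view.h"

namespace example {

int CountRegularFanoutsOf(const tensorflow::GraphDef& graph,
                          absl::string_view node_name) {
  absl::Status status;
  tensorflow::grappler::utils::GraphView view(&graph, &status);
  if (!status.ok()) return -1;  // e.g. malformed fanins in the graph
  const tensorflow::grappler::utils::NodeView* node = view.GetNode(node_name);
  if (node == nullptr) return -1;
  return node->NumRegularFanouts();
}

}  // namespace example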
+class MutableNodeView + : public internal::NodeViewInternal { + public: + explicit MutableNodeView(MutableGraphView* graph_view, int node_index) + : NodeViewInternal(graph_view, node_index) {} + + MutableNodeView() : NodeViewInternal() {} + + ~MutableNodeView() override = default; + + MutableNodeView(MutableNodeView&&) = default; + MutableNodeView& operator=(MutableNodeView&&) = default; + + NodeDef* node() const override; + + // Checks if a fanin exists for the node. + bool HasFanin(const MutableFanoutView& fanin) const override; + + // Checks if a fanout exists for the node. + bool HasFanout(const MutableFaninView& fanout) const override; + + private: + inline const MutableFanoutView& GetMissingFanin() const override; + + inline const std::vector& GetMissingFanout() const override; + + absl::flat_hash_map fanins_count_; + absl::flat_hash_map controlling_fanins_index_; + // Index of associated MutableNodeViewDiff in Mutation::updated_nodes_. + // If this is -1, there exists no MutableNodeViewDiff for this node. + int update_index_ = internal::kMissingIndex; + + friend class MutableFaninView; + friend class MutableFanoutView; + friend class MutableGraphView; + friend class Mutation; +}; + +class MutationNewNode { + public: + MutationNewNode() {} + + private: + explicit MutationNewNode(Mutation* mutation, int mutation_counter, int index) + : mutation_(mutation), + mutation_counter_(mutation_counter), + index_(index) {} + + Mutation* mutation_ = nullptr; + int mutation_counter_ = internal::kMissingSlot; + int index_ = internal::kMissingIndex; + + friend class Mutation; +}; + +// Mutation is a helper class that allows rewrites of MutableGraphView. This +// should not be initialized or be used directly. +// Note, if a node is renamed to another node, or a new node is created with the +// same name as an existing node, the node with the same name originally in the +// graph will be overwritten. +class Mutation { + public: + // Create a new node to be added to the graph. If the node's fanins are not + // well formed (self loops, control dependencies between regular fanins), the + // `status` will be set. + MutationNewNode AddNode(NodeDef&& node, absl::Status* status); + + // Remove an existing node in the graph. + void RemoveNode(MutableNodeView* node); + + // Update the name of an existing node. + void UpdateNodeName(MutableNodeView* node, absl::string_view name); + + // Update the name of a new node. + void UpdateNodeName(const MutationNewNode& node, absl::string_view name); + + // Update the op of an existing node. + void UpdateNodeOp(MutableNodeView* node, absl::string_view op); + + // Update the op of a new node. + void UpdateNodeOp(const MutationNewNode& node, absl::string_view op); + + // Update the device of an existing node. + void UpdateNodeDevice(MutableNodeView* node, absl::string_view device); + + // Update the device of a new node. + void UpdateNodeDevice(const MutationNewNode& node, absl::string_view device); + + // Add or replace regular fanin `fanin` at `index` for an existing node. + void AddOrUpdateRegularFanin(MutableNodeView* node, int index, + const TensorId& fanin); + + // Add or replace regular fanin `fanin` at `index` for a new node. + void AddOrUpdateRegularFanin(const MutationNewNode& node, int index, + const TensorId& fanin); + + // Remove regular fanin at `index` for an existing node. + void RemoveRegularFanin(MutableNodeView* node, int index); + + // Remove regular fanin at `index` for a new node. 
+ void RemoveRegularFanin(const MutationNewNode& node, int index); + + // Add controlling fanin `fanin_node_name` for an existing node. + void AddControllingFanin(MutableNodeView* node, + absl::string_view fanin_node_name); + + // Add controlling fanin `fanin_node_name` for a new node. + void AddControllingFanin(const MutationNewNode& node, + absl::string_view fanin_node_name); + + // Remove controlling fanin `fanin_node_name` for an existing node. + void RemoveControllingFanin(MutableNodeView* node, + absl::string_view fanin_node_name); + + // Remove controlling fanin `fanin_node_name` for a new node. + void RemoveControllingFanin(const MutationNewNode& node, + absl::string_view fanin_node_name); + + // Add or replace attribute `attr_name` with `attr_value` for an existing + // node. + void AddOrUpdateNodeAttr(MutableNodeView* node, absl::string_view attr_name, + const AttrValue& attr_value); + + // Add or replace attribute `attr_name` with `attr_value` for a new node. + void AddOrUpdateNodeAttr(const MutationNewNode& node, + absl::string_view attr_name, + const AttrValue& attr_value); + + // Remove attribute `attr_name` for an existing node. + void RemoveNodeAttr(MutableNodeView* node, absl::string_view attr_name); + + // Remove attribute `attr_name` for a new node. + void RemoveNodeAttr(const MutationNewNode& node, absl::string_view attr_name); + + // Reset and clear mutation. + void Reset(); + + // Applies the Mutation to the graph. If the mutation is valid, the graph will + // be modified. Otherwise an error status will be returned and the graph will + // not be modified. + absl::Status Apply(); + + private: + explicit Mutation(MutableGraphView* graph_view); + + void ResetInternal(); + + using MutableNodeViewDiff = internal::NodeViewDiff; + + // Adds a mutation to the `node`. Mutation function `mutate_fn` must return + // `true` if it actually does any mutations. If it returns `false` mutation + // will be ignored. + void AddMutation(MutableNodeView* node, + std::function mutate_fn); + + MutableGraphView* graph_view_ = nullptr; + int mutation_counter_ = 0; + std::vector updated_nodes_; + absl::flat_hash_set removed_nodes_; + + using MutationNewNodeHolder = internal::NewNode; + std::vector new_nodes_; + + friend class MutableGraphView; +}; + +// Mutable GraphView that holds a mutable GraphDef. This allows for lookups and +// traversals of the graph. Control dependencies will be dedupped among other +// control dependencies on initialization. Mutations should be handled using +// this API instead of directly on the GraphDef/NodeDef. +// Note, after a mutation, pointers of MutableNodeView's from MutableGraphView +// may be invalidated. +class MutableGraphView + : public internal::GraphViewInternal { + public: + explicit MutableGraphView(GraphDef* graph, absl::Status* status); + ~MutableGraphView() override = default; + + // Returns a Mutation (builder) that can be used to modify MutableGraphView. + Mutation* GetMutationBuilder(); + + // Helper class representing an extra dependency for topological sorting. 
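The Mutation builder above batches edits and applies them atomically: nothing changes until Apply(), and a failed Apply() leaves the graph untouched. A hedged sketch (not part of the header); `graph`, `consumer`, and `new_producer` are placeholder names:

// Illustrative sketch: rewire a node's first regular fanin via a Mutation.
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils/graph_view.h"

namespace example {

absl::Status RewireFirstInput(tensorflow::GraphDef* graph,
                              absl::string_view consumer,
                              absl::string_view new_producer) {
  absl::Status status;
  tensorflow::grappler::utils::MutableGraphView view(graph, &status);
  if (!status.ok()) return status;
  tensorflow::grappler::utils::MutableNodeView* node = view.GetNode(consumer);
  if (node == nullptr) return absl::NotFoundError("consumer not found");

  tensorflow::grappler::utils::Mutation* mutation = view.GetMutationBuilder();
  // Replace regular fanin 0 with output 0 of `new_producer`.
  mutation->AddOrUpdateRegularFanin(node, /*index=*/0,
                                    tensorflow::TensorId(new_producer, 0));
  // Apply validates the whole batch; MutableNodeView pointers may be
  // invalidated afterwards (see the MutableGraphView class comment).
  return mutation->Apply();
}

}  // namespace example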
+ class TopologicalDependency { + public: + TopologicalDependency(const MutableNodeView* from_node, + const MutableNodeView* to_node) { + if (from_node->graph_view_ == to_node->graph_view_) { + graph_view_ = from_node->graph_view_; + from_ = from_node->node_index_; + to_ = to_node->node_index_; + } + } + + private: + MutableGraphView* graph_view_ = nullptr; + int from_ = internal::kMissingIndex; + int to_ = internal::kMissingIndex; + + friend class MutableGraphView; + }; + + // Sorts graph topologically in-place. If `ignore_cycles` is set, a + // topological like sorting will be performed when there are cycles. Otherwise + // if a cycle is detected or if the graph cannot be sorted, an error will be + // returned. + absl::Status SortTopologically( + bool ignore_cycles, + absl::Span extra_dependencies); + + private: + bool AddUniqueNodeInternal(NodeDef* node); + + absl::Status CheckFaninsInternal(std::vector>* fanins); + + void AddFaninsInternal(std::vector>* fanins); + + // RenamedOrOverwrittenNode holds a index to Mutation::updated_nodes_ for a + // renamed node, alongside a potential overwritten node index in the actual + // graph. If the renamed node is not overwriting any existing nodes, + // `overwritten_node_index_` will be set to `internal::kMissingIndex`. + class RenamedOrOverwrittenNode { + public: + RenamedOrOverwrittenNode(int renamed_update_index, + int overwritten_node_index) + : renamed_update_index_(renamed_update_index), + overwritten_node_index_(overwritten_node_index) {} + + private: + int renamed_update_index_; + int overwritten_node_index_; + + friend class MutableGraphView; + }; + + absl::Status GetNodeNamesAndPartitionUpdatedNodes( + absl::flat_hash_map* node_names, + std::vector* renamed_nodes, + std::vector* inplace_nodes, + std::vector* empty_diff_node_indices); + + absl::Status RemovedOrMissingNodeFanoutsWellFormed( + const absl::flat_hash_map& node_names, + const std::vector& renamed_nodes); + + absl::Status CheckNodeNamesAndFanins( + const absl::flat_hash_map& node_names, + const std::vector& renamed_nodes, + const std::vector& inplace_nodes); + + absl::Status CheckKernelRegisteredForNodes(); + + // Helper class to move fanouts around. 
+ class NodeViewFanouts { + public: + NodeViewFanouts( + std::vector>&& regular_fanouts_by_port, + int num_regular_fanouts, + std::vector controlled_fanouts) + : regular_fanouts_by_port_(std::move(regular_fanouts_by_port)), + num_regular_fanouts_(num_regular_fanouts), + controlled_fanouts_(std::move(controlled_fanouts)) {} + + private: + std::vector> regular_fanouts_by_port_; + int num_regular_fanouts_ = 0; + std::vector controlled_fanouts_; + + friend class MutableGraphView; + }; + + template + void ReplaceNodeFanouts(MutableNodeView* node, T* fanouts); + + void FixRenamedNodes( + std::vector* renamed_nodes, + absl::flat_hash_map* renamed_fanouts, + std::vector* overwritten_name_removed_nodes); + + void AddNewNodes( + absl::flat_hash_map* renamed_fanouts, + std::vector* new_node_indices); + + void FixRenamedFanouts( + const absl::flat_hash_map& renamed_fanouts); + + inline void RemoveRegularFaninFanoutInternal(MutableNodeView* node_view, + int i); + + inline void AddRegularFaninInternal(MutableNodeView* node_view, + const SafeTensorId& fanin_id); + + inline void UpdateRegularFaninInternal(MutableNodeView* node_view, + const int i, + const SafeTensorId& fanin_id); + + inline void RemoveControllingFaninFanoutInternal(MutableNodeView* node_view, + int i); + + inline void RemoveControllingFaninInternal( + MutableNodeView* node_view, const std::set& indices_to_remove); + + inline void AddControllingFaninInternal(MutableNodeView* node_view, + absl::string_view fanin_node_name); + + void ApplyNodeUpdates(); + + void SetNewNodesFanins(const std::vector& new_node_indices); + + inline void RemoveAllFaninFanoutInternal(MutableNodeView* node_view); + + void RemoveNodesInternal( + const std::vector& renamed_nodes, + const std::vector& overwritten_name_removed_nodes); + + inline absl::Status ValidateInternal( + absl::flat_hash_map* node_names, + std::vector* renamed_nodes, + std::vector* inplace_nodes, + std::vector* empty_diff_node_indices); + + absl::Status ApplyMutationInternal(); + + Mutation mutation_; + + friend class MutableNodeView; + friend class Mutation; +}; + +} // namespace utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/graph_view_internal.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/graph_view_internal.h new file mode 100644 index 00000000..d66b1ca0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/graph_view_internal.h @@ -0,0 +1,920 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { +namespace grappler { +namespace utils { +namespace internal { + +constexpr int kMissingSlot = -2; +constexpr int kMissingIndex = -1; +constexpr int kNodeNamePresent = -1; + +// NodeIndexAndPortIndex is a helper class that represents fanins and fanouts +// of a node. +template +class NodeIndexAndPortIndex { + public: + NodeIndexAndPortIndex() + : graph_view_(nullptr), + node_index_(kMissingIndex), + port_index_(kMissingSlot) {} + NodeIndexAndPortIndex(GraphViewT* graph_view, int node_index, int port_index) + : graph_view_(graph_view), + node_index_(node_index), + port_index_(port_index) {} + + bool operator==(const NodeIndexAndPortIndex& other) const { + return port_index_ == other.port_index_ && + node_index_ == other.node_index_ && graph_view_ == other.graph_view_; + } + + template + friend Hash AbslHashValue(Hash h, const NodeIndexAndPortIndex& n) { + return Hash::combine(std::move(h), n.node_index_, n.port_index_); + } + + // Returns NodeView from `graph_view_` at `node_index_`. + NodeViewT* node_view() const { + if (graph_view_ == nullptr) { + return nullptr; + } + return graph_view_->GetNode(node_index_); + } + + // Returns node index in graph. + int node_index() const { return node_index_; } + + // Returns input/output port index. + int index() const { return port_index_; } + + protected: + GraphViewT* graph_view_; + int node_index_; + int port_index_; +}; + +// NodeDefAndPortIndex is a helper class that represents fanins hashed with +// pointer stability using the fanin's NodeDef. +class NodeDefAndPortIndex { + public: + NodeDefAndPortIndex(const NodeDef* node_def, int port_index) + : node_def_(node_def), port_index_(port_index) {} + + bool operator==(const NodeDefAndPortIndex& other) const { + return node_def_ == other.node_def_ && port_index_ == other.port_index_; + } + + template + friend Hash AbslHashValue(Hash h, const NodeDefAndPortIndex& n) { + return Hash::combine(std::move(h), n.node_def_, n.port_index_); + } + + private: + const NodeDef* node_def_; + int port_index_; +}; + +// NodeViewInternal is a helper class to simplify graph traversal. It creates +// a view of a node and associated fanins and fanouts from the NodeDef +// protocol buffer. +// +// There are two public classes implementing NodeViewInternal: +// +// - NodeView: constructed from `const NodeDef` and doesn't allow mutating the +// underlying node. +// - MutableNodeView: constructed from `NodeDef` and allows mutating the +// underlying node. +// +// --------------------------- !!! WARNING !!! --------------------------------- +// Modifying the node outside of implementations of NodeViewInternal +// (i.e. modifying inputs of the NodeDef directly) may leave the NodeView +// in an inconsistent/invalid state. 
+// ----------------------------------------------------------------------------- +// +template +class NodeViewInternal { + private: + using NodeDefT = + typename std::conditional::type; + + public: + explicit NodeViewInternal(GraphViewT* graph_view, int node_index) + : graph_view_(graph_view), + node_index_(node_index), + attrs_(AttrSlice(graph_view->graph()->node(node_index))) {} + + NodeViewInternal() + : graph_view_(nullptr), node_index_(kMissingIndex), attrs_(AttrSlice()) {} + + virtual ~NodeViewInternal() {} + + NodeViewInternal(NodeViewInternal&&) = default; + NodeViewInternal& operator=(NodeViewInternal&&) = default; + + bool operator==(const NodeViewInternal& other) const { + return node_index_ == other.node_index_ && graph_view_ == other.graph_view_; + } + + template + friend Hash AbslHashValue(Hash h, const NodeViewInternal& n) { + return Hash::combine(std::move(h), n.node_index_); + } + + // Returns NodeDef of view. + virtual NodeDefT* node() const = 0; + + // Returns index of node in GraphDef/GraphView. + int node_index() const { return node_index_; } + + // Returns the name of the node. + const string& GetName() const { return node()->name(); } + + // Returns the op of the node. + const string& GetOp() const { return node()->op(); } + + // Returns the device set for the node. + const string& GetDevice() const { return node()->device(); } + + // Returns all regular fanins, based on ordering in the node. + const std::vector& GetRegularFanins() const { + return regular_fanins_; + } + + // Returns a regular fanin based on input index. If no such fanin exist, a + // missing fanin is returned, with no NodeView set and an index of -2. + const FanoutViewT& GetRegularFanin(int i) const { + int regular_fanins_size = regular_fanins_.size(); + if (i < 0 || i >= regular_fanins_size) { + return GetMissingFanin(); + } + return regular_fanins_[i]; + } + + // Returns all controlling fanins, based on ordering in the node. + const std::vector& GetControllingFanins() const { + return controlling_fanins_; + } + + // Returns all regular fanouts. + const std::vector>& GetRegularFanouts() const { + return regular_fanouts_by_port_; + } + + // Returns a regular fanout(s) based on output index. If no such output index + // exists, no fanouts will be returned. + const std::vector& GetRegularFanout(int i) const { + int regular_fanouts_by_port_size = regular_fanouts_by_port_.size(); + if (i < 0 || i >= regular_fanouts_by_port_size) { + return GetMissingFanout(); + } + return regular_fanouts_by_port_[i]; + } + + // Returns all controlled fanouts. + const std::vector& GetControlledFanouts() const { + return controlled_fanouts_; + } + + // Returns the number of regular fanins. + int NumRegularFanins() const { return regular_fanins_.size(); } + + // Returns the number of controlling fanins. + int NumControllingFanins() const { return controlling_fanins_.size(); } + + // Returns the number of regular fanouts. + int NumRegularFanouts() const { return num_regular_fanouts_; } + + // Returns the number of controlled fanouts. + int NumControlledFanouts() const { return controlled_fanouts_.size(); } + + // Checks if a fanin exists for the node. + virtual bool HasFanin(const FanoutViewT& fanin) const = 0; + + // Checks if a fanout exists for the node. + virtual bool HasFanout(const FaninViewT& fanout) const = 0; + + // Returns an attribute of the node by key. If no attribute for such key + // exists, a `nullptr` is returned. 
+ const AttrValue* GetAttr(absl::string_view attr_name) const { + return attrs_.Find(attr_name); + } + + // Returns all attributes of the node. + const AttrSlice& GetAttrs() const { return attrs_; } + + // Returns the number of attributes in the node. + int NumAttrs() const { return attrs_.size(); } + + // Checks if an attribute exist in the node. + bool HasAttr(absl::string_view attr_name) const { + return attrs_.Find(attr_name) != nullptr; + } + + protected: + virtual inline const FanoutViewT& GetMissingFanin() const = 0; + virtual inline const std::vector& GetMissingFanout() const = 0; + + std::vector regular_fanins_; + std::vector controlling_fanins_; + std::vector> regular_fanouts_by_port_; + int num_regular_fanouts_ = 0; + std::vector controlled_fanouts_; + + GraphViewT* graph_view_; + int node_index_; + AttrSlice attrs_; +}; + +// GraphViewInternal is a helper class to simplify graph traversal. It creates +// a view of the nodes and associated fanins and fanouts from the GraphDef +// protocol buffer. +// +// There are two public classes implementing GraphViewInternal: +// +// - GraphView: constructed from `const GraphDef` and doesn't allow mutating +// the underlying graph and its nodes. +// - MutableGraphView: constructed from `GraphDef` and allows mutating the +// underlying graph and its nodes. +// +// --------------------------- !!! WARNING !!! --------------------------------- +// Modifying the graph outside of implementations of GraphViewInternal +// (i.e. removing nodes from the GraphDef directly) may lead to +// segfaults! Guaranteed by absl::string_view! +// ----------------------------------------------------------------------------- +// +template +class GraphViewInternal { + private: + using GraphDefT = + typename std::conditional::type; + + public: + explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} + virtual ~GraphViewInternal() {} + + bool operator==(const GraphViewInternal& other) const { + return graph_ == other.graph_; + } + + GraphDefT* graph() const { return graph_; } + + // Finds node by index in the graph. If no such node exists in the graph, a + // `nullptr` is returned. + const NodeViewT* GetNode(int node_index) const { + int nodes_size = nodes_.size(); + if (node_index < 0 || node_index >= nodes_size) { + return nullptr; + } + return &nodes_[node_index]; + } + + NodeViewT* GetNode(int node_index) { + int nodes_size = nodes_.size(); + if (node_index < 0 || node_index >= nodes_size) { + return nullptr; + } + return &nodes_[node_index]; + } + + // Finds node by name. If no such node exists in the graph, a `nullptr` is + // returned. + const NodeViewT* GetNode(absl::string_view node_name) const { + auto it = node_index_by_name_.find(node_name); + if (it == node_index_by_name_.end()) { + return nullptr; + } + return &nodes_[it->second]; + } + + NodeViewT* GetNode(absl::string_view node_name) { + auto it = node_index_by_name_.find(node_name); + if (it == node_index_by_name_.end()) { + return nullptr; + } + return &nodes_[it->second]; + } + + // Returns all nodes (as NodeView) in the graph. + const std::vector& GetNodes() const { return nodes_; } + + // Checks if a node by name exists in the graph. + bool HasNode(absl::string_view node_name) const { + return node_index_by_name_.contains(node_name); + } + + // Returns the number of nodes in the graph. + int NumNodes() const { return nodes_.size(); } + + protected: + // Reset allocated node vector and node map in case of failure. 
+ void Reset() { + std::vector().swap(nodes_); + absl::flat_hash_map().swap(node_index_by_name_); + } + + // nodes_[i] is a view of graph_.{mutable_}node(i). + std::vector nodes_; + absl::flat_hash_map node_index_by_name_; + GraphDefT* graph_; + const FanoutViewT missing_fanin_; + const std::vector missing_fanout_; +}; + +inline SafeTensorId EmptyTensorId() { + return SafeTensorId("", internal::kMissingSlot); +} + +inline bool IsEmptyTensorId(const TensorId tensor_id) { + return tensor_id.node().empty() && + tensor_id.index() == internal::kMissingSlot; +} + +// NodeViewDiff is a helper struct holding changes to be made to an existing +// node in GraphViewT. This should not be initialized or be used directly. +template +struct NodeViewDiff { + explicit NodeViewDiff(GraphViewT* graph_view, int node_index) + : graph_view(graph_view), node_index(node_index) {} + + GraphViewT* graph_view; + int node_index; + string name; + bool update_name = false; + string op; + bool update_op = false; + string device; + bool update_device = false; + // Fanins to append after existing regular fanins. + std::vector regular_inputs_to_add; + // Number of fanins to be appended. This is used for a quick comparison with + // `regular_inputs_to_add` for if there will be any missing inputs in the + // updated node. + int num_regular_inputs_to_add = 0; + // Fanins to update inplace. + std::map regular_inputs_to_update; + // Fanins from end of regular fanins to remove. This keeps track of existing + // regular fanins in the original node to remove. + std::vector regular_inputs_to_remove; + // Number of fanins marked for removal. This is used for a quick comparison + // with `regular_inputs_to_remove` for if there will be any missing inputs + // in the updated node. + int num_regular_inputs_to_remove = 0; + absl::flat_hash_set controlling_inputs_to_add; + std::set controlling_inputs_to_remove; + absl::flat_hash_map attrs_to_add; + absl::flat_hash_set attrs_to_remove; + // AttrValueMap constructor and destructor are very expensive, we will + // initialize it lazily only if needed. + absl::optional processed_attrs; +}; + +// Updates node name. If `name` is the same as the name in the original node, +// the field will be cleared in the diff. +template +inline bool UpdateName(NodeViewDiff* diff, absl::string_view name) { + if (diff->graph_view->GetNode(diff->node_index)->GetName() == name) { + diff->name.clear(); + diff->update_name = false; + } else { + diff->name = string(name); + diff->update_name = true; + } + return true; +} + +// Updates node op. If `op` is the same as the op in the original node, the +// field will be cleared in the diff. +template +inline bool UpdateOp(NodeViewDiff* diff, absl::string_view op) { + if (diff->graph_view->GetNode(diff->node_index)->GetOp() == op) { + diff->op.clear(); + diff->update_op = false; + } else { + diff->op = string(op); + diff->update_op = true; + } + return true; +} + +// Updates node device. If `device` is the same as the device in the original +// node, the field will be cleared in the diff. +template +inline bool UpdateDevice(NodeViewDiff* diff, + absl::string_view device) { + if (diff->graph_view->GetNode(diff->node_index)->GetDevice() == device) { + diff->device.clear(); + diff->update_device = false; + } else { + diff->device = string(device); + diff->update_device = true; + } + return true; +} + +// Adds or updates value in vector `v` at index `i`. This will also resize the +// vector if index `i` is out of bounds, padding the vector with +// `default_value`. 
Returns true if a new value was appended or if an update +// occurred where an existing value was changed from `default_value`. +template +inline bool AddOrUpdateAtIndex(std::vector* v, int i, const U& value, + const T& default_value) { + int v_size = v->size(); + if (i > v_size) { + // Resize to include `value`, filling the newly introduced gap with + // `default_value` for later checks of validity (gaps in vector). + v->reserve(i + 1); + v->resize(i, default_value); + v->push_back({value}); + } else if (i == v_size) { + // Vector is large enough, simply append `value` to the end. + v->push_back({value}); + } else { + // Update existing value. + bool updated = (*v)[i] == default_value; + (*v)[i] = {value}; + return updated; + } + return true; +} + +// Checks if a node with name `node_name` will exist in the final mutated graph. +template +inline bool CheckNodeNameExists( + absl::string_view node_name, + const absl::flat_hash_map& updated_node_names, + const GraphViewT* graph_view) { + auto it = updated_node_names.find(node_name); + if (it != updated_node_names.end()) { + return it->second == kNodeNamePresent; + } + return graph_view->HasNode(node_name); +} + +// Adds or updates regular fanin at `index` of regular fanins. If `index` is +// less than the number of regular fanins in the original node, the fanin at +// `index` in the original node will be updated with `fanin` if the fanin +// differs. If `index` is greater than or equal to the number of regular fanins, +// `fanin` will be added beyond the end of regular fanins at `index`. +template +inline bool AddOrUpdateRegularFanin(NodeViewDiff* diff, int index, + const TensorId& fanin) { + if (index < 0) { + // Not a valid index for regular fanins. + return false; + } + auto* node_view = diff->graph_view->GetNode(diff->node_index); + const int num_regular_fanins = node_view->NumRegularFanins(); + if (index < num_regular_fanins) { // Updating existing fanins. + // Calculate (relative) index from end of regular fanins, from absolute + // index from beginning of regular fanins. + const int relative_removal_index = num_regular_fanins - index - 1; + // Check if at relative index fanin was already marked for removal. + int diff_regular_inputs_to_remove_size = + diff->regular_inputs_to_remove.size(); + if (relative_removal_index < diff_regular_inputs_to_remove_size && + diff->regular_inputs_to_remove[relative_removal_index]) { + // Unmark fanin for removal. + diff->regular_inputs_to_remove[relative_removal_index] = false; + --diff->num_regular_inputs_to_remove; + } + const auto& existing_fanin = node_view->GetRegularFanin(index); + if (existing_fanin.index() != fanin.index() || + existing_fanin.node_view()->GetName() != fanin.node()) { + // Update fanin if it is different from original fanin in node. + gtl::InsertOrUpdate(&diff->regular_inputs_to_update, index, + SafeTensorId(fanin)); + } + } else { + // Add fanin beyond current fanin range. + const int relative_add_index = index - num_regular_fanins; + if (AddOrUpdateAtIndex(&diff->regular_inputs_to_add, relative_add_index, + fanin, EmptyTensorId())) { + // New fanin was added. + ++diff->num_regular_inputs_to_add; + } + } + return true; +} + +// Remove regular fanin at `index` of regular fanins. This can remove existing +// fanins and updated/added fanins via AddOrUpdateRegularFanins. +template +inline bool RemoveRegularFanin(NodeViewDiff* diff, int index) { + if (index < 0) { + // Not a valid index for regular fanins. 
+ return false; + } + auto* node_view = diff->graph_view->GetNode(diff->node_index); + const int num_regular_fanins = node_view->NumRegularFanins(); + if (index < num_regular_fanins) { // Removing existing fanins. + // Remove updated fanin if it exists. + diff->regular_inputs_to_update.erase(index); + // Calculate (relative) index from end of regular fanins, from absolute + // index from beginning of regular fanins. + const int relative_removal_index = num_regular_fanins - index - 1; + if (AddOrUpdateAtIndex(&diff->regular_inputs_to_remove, + relative_removal_index, + /*value=*/true, /*default_value=*/false)) { + ++diff->num_regular_inputs_to_remove; + } + } else { + // Relative index from end of regular fanins. + const int relative_add_index = index - num_regular_fanins; + int diff_regular_inputs_to_add_size = diff->regular_inputs_to_add.size(); + if (relative_add_index >= diff_regular_inputs_to_add_size || + IsEmptyTensorId(diff->regular_inputs_to_add[relative_add_index])) { + // At relative index, appended regular fanin was already marked for + // removal. + return false; + } + // Remove added fanin. + diff->regular_inputs_to_add[relative_add_index] = EmptyTensorId(); + --diff->num_regular_inputs_to_add; + } + return true; +} + +// Adds controlling fanin. If the controlling fanin already exists in the +// original node, it will be dedupped. If the controlling fanin is marked for +// removal, this will reverse it. +template +inline bool AddControllingFanin(NodeViewDiff* diff, + int control_index, + absl::string_view fanin_node_name) { + if (control_index == kMissingIndex) { + diff->controlling_inputs_to_add.emplace(fanin_node_name); + } else { + diff->controlling_inputs_to_remove.erase(control_index); + } + return true; +} + +// Remove controlling fanin. If the controlling fanin does not exist in the +// original node and diff, nothing will happen. If the controlling fanin exists +// in the diff, it will be removed. Otherwise the controlling fanin will be +// marked for removal from the original node. +template +inline bool RemoveControllingFanin(NodeViewDiff* diff, + int control_index, + absl::string_view fanin_node_name) { + if (control_index == kMissingIndex) { + diff->controlling_inputs_to_add.erase(fanin_node_name); + } else { + diff->controlling_inputs_to_remove.emplace(control_index); + } + return true; +} + +// Adds or updates an attribute by name. If an attribute exist in the original +// node or diff (including those marked for removal), this will overwrite it. +template +inline bool AddOrUpdateAttribute(NodeViewDiff* diff, + absl::string_view attr_name, + const AttrValue& attr_value) { + diff->attrs_to_add.empty() ? 0 : diff->attrs_to_remove.erase(attr_name); + gtl::InsertOrUpdate(&diff->attrs_to_add, string(attr_name), attr_value); + return true; +} + +// Removes an attribute by name. If an attribute exist in the original node or +// diff, this will remove it. +template +inline bool RemoveAttribute(NodeViewDiff* diff, + absl::string_view attr_name) { + const size_t num_erased = + diff->attrs_to_add.empty() ? 0 : diff->attrs_to_add.erase(attr_name); + auto* node_view = diff->graph_view->GetNode(diff->node_index); + if (node_view->HasAttr(attr_name)) { + diff->attrs_to_remove.emplace(attr_name); + return true; + } + return num_erased > 0; +} + +// Removes trailing values in vector `v` for values equal to `value`. 
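Editorial aside, not part of the vendored header: the removal bookkeeping above records removed regular fanins by their index relative to the end of the fanin list, so a node stays well formed only when removals form a trailing suffix (no gaps among the remaining inputs). A small standalone sketch of that arithmetic, assuming a node with four regular inputs:

#include <cassert>
#include <initializer_list>
#include <vector>

// Standalone illustration (not the vendored code) of the mapping used by
// RemoveRegularFanin above: absolute input index -> index relative to the end.
int RelativeRemovalIndex(int num_regular_fanins, int absolute_index) {
  return num_regular_fanins - absolute_index - 1;
}

int main() {
  const int num_regular_fanins = 4;  // inputs 0..3
  std::vector<bool> to_remove;       // indexed from the end of the fanin list

  // Removing the last two inputs (indices 3 and 2) marks relative slots 0 and 1.
  for (int absolute : {3, 2}) {
    int rel = RelativeRemovalIndex(num_regular_fanins, absolute);
    if (rel >= static_cast<int>(to_remove.size())) to_remove.resize(rel + 1, false);
    to_remove[rel] = true;
  }
  assert(to_remove == std::vector<bool>({true, true}));  // a trailing suffix

  // Removing input 0 instead would leave inputs 1..3 dangling behind a gap,
  // which is exactly the situation the diff's IsWellFormed check rejects.
  assert(RelativeRemovalIndex(num_regular_fanins, 0) == 3);
  return 0;
}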
+template +inline void ResizeByTrimmingEndForValue(std::vector* v, const T& value) { + int curr_index = v->size(); + const int last_index = v->size() - 1; + for (int i = last_index; i >= 0; --i) { + if ((*v)[i] == value) { + curr_index = i; + } else { + break; + } + } + if (curr_index <= last_index) { + v->resize(curr_index); + } +} + +// Checks if any changes are set in the diff. +template +inline bool IsEmpty(NodeViewDiff* diff) { + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false); + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId()); + return !diff->update_name && !diff->update_op && !diff->update_device && + diff->regular_inputs_to_add.empty() && + diff->regular_inputs_to_update.empty() && + diff->regular_inputs_to_remove.empty() && + diff->controlling_inputs_to_add.empty() && + diff->controlling_inputs_to_remove.empty() && + diff->attrs_to_add.empty() && diff->attrs_to_remove.empty(); +} + +// Resets and clears existing diff. +template +inline void Reset(NodeViewDiff* diff) { + diff->name.clear(); + diff->update_name = false; + diff->op.clear(); + diff->update_op = false; + diff->device.clear(); + diff->update_device = false; + std::vector().swap(diff->regular_inputs_to_add); + diff->num_regular_inputs_to_add = false; + std::map().swap(diff->regular_inputs_to_update); + std::vector().swap(diff->regular_inputs_to_remove); + diff->num_regular_inputs_to_remove = 0; + absl::flat_hash_set().swap(diff->controlling_inputs_to_add); + std::set().swap(diff->controlling_inputs_to_remove); + absl::flat_hash_map().swap(diff->attrs_to_add); + absl::flat_hash_set().swap(diff->attrs_to_remove); +} + +// Checks if changes to node will result in a valid node. +template +inline bool IsWellFormed( + NodeViewDiff* diff, + const absl::flat_hash_map& updated_node_names) { + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false); + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId()); + int diff_regular_inputs_to_add_size = diff->regular_inputs_to_add.size(); + if (diff_regular_inputs_to_add_size != diff->num_regular_inputs_to_add) { + // Missing regular fanins in between appended fanins. + return false; + } else if (diff->num_regular_inputs_to_add > 0 && + !diff->regular_inputs_to_remove.empty()) { + // Appending new fanins while removing existing fanins, resulting in missing + // regular fanins in between. + return false; + } else if (static_cast(diff->regular_inputs_to_remove.size()) != + diff->num_regular_inputs_to_remove) { + // Regular fanins exist in between removed fanins. + return false; + } + auto* node_view = diff->graph_view->GetNode(diff->node_index); + const string& node_name = + diff->update_name ? diff->name : node_view->GetName(); + auto invalid_node_name = [&](absl::string_view fanin_node_name) -> bool { + return fanin_node_name == node_name || + !CheckNodeNameExists(fanin_node_name, updated_node_names, + diff->graph_view); + }; + + // Check if nodes of all updated and new fanins exist (from name) and if such + // fanins do not introduce self loops. Note, this will not check for if + // unmodified fanins exist. + if (diff->update_name) { + // If name of node was changed in node, check all fanins. Updated fanins are + // checked for existence and self loops. Unmodified fanins are checked for + // self loops. + // `regular_inputs_to_update`, `controlling_inputs_to_remove` are sorted, + // so iterators from these maps/sets can be incremented alongside iteration + // and be used for comparisons. 
+ const int last_index = + node_view->NumRegularFanins() - diff->num_regular_inputs_to_remove - 1; + auto regular_to_update_it = diff->regular_inputs_to_update.begin(); + for (int i = 0; i <= last_index; ++i) { + if (regular_to_update_it != diff->regular_inputs_to_update.end() && + regular_to_update_it->first < i) { + ++regular_to_update_it; + } + if (regular_to_update_it != diff->regular_inputs_to_update.end() && + regular_to_update_it->first == i) { + if (invalid_node_name(regular_to_update_it->second.node())) { + return false; + } + } else { + const string& regular_name = + node_view->GetRegularFanin(i).node_view()->GetName(); + if (regular_name == node_name) { + return false; + } + } + } + + auto& controls = node_view->GetControllingFanins(); + const int num_controls = controls.size(); + auto control_to_remove_it = diff->controlling_inputs_to_remove.begin(); + for (int i = 0; i < num_controls; ++i) { + if (control_to_remove_it != diff->controlling_inputs_to_remove.end() && + *control_to_remove_it < i) { + ++control_to_remove_it; + } + if (control_to_remove_it != diff->controlling_inputs_to_remove.end() && + *control_to_remove_it == i) { + // Control dependency marked for removal, can be ignored. + continue; + } else if (controls[i].node_view()->GetName() == node_name) { + return false; + } + } + } else { + // Name of node was not changed, check only updated fanins under the + // assumption prior fanins were valid. + for (const auto& updated : diff->regular_inputs_to_update) { + const string& fanin_name = updated.second.node(); + if (invalid_node_name(fanin_name)) { + return false; + } + } + } + // Check appended regular fanins. + for (const auto& regular : diff->regular_inputs_to_add) { + if (invalid_node_name(regular.node())) { + return false; + } + } + // Check new controlling fanins. + for (const auto& control : diff->controlling_inputs_to_add) { + if (invalid_node_name(control)) { + return false; + } + } + + return true; +} + +// NewNode is a helper struct holding a new node to be added to a GraphViewT. +// This should not be initialized or be used directly. +template +struct NewNode { + explicit NewNode(GraphViewT* graph_view, NodeDef&& node) + : graph_view(graph_view), node(std::move(node)) {} + + GraphViewT* graph_view; + NodeDef node; + std::vector regular_fanins; + int num_regular_fanins = 0; + absl::flat_hash_set controlling_fanins; +}; + +// Updates new node name. +template +inline void UpdateName(NewNode* new_node, absl::string_view name) { + if (name.empty()) { + new_node->node.clear_name(); + } else { + new_node->node.set_name(string(name)); + } +} + +// Updates new node op. +template +inline void UpdateOp(NewNode* new_node, absl::string_view op) { + if (op.empty()) { + new_node->node.clear_op(); + } else { + new_node->node.set_op(string(op)); + } +} + +// Updates new node device. +template +inline void UpdateDevice(NewNode* new_node, + absl::string_view device) { + if (device.empty()) { + new_node->node.clear_device(); + } else { + new_node->node.set_device(string(device)); + } +} + +// Adds or updates regular fanin at `index` of regular fanins in the new node. +// If another fanin already exists at `index`, it will be replaced with `fanin`. +template +inline void AddOrUpdateRegularFanin(NewNode* new_node, int index, + const TensorId& fanin) { + if (index < 0) { + // Not a valid index for regular fanins. 
+ return; + } else if (AddOrUpdateAtIndex(&new_node->regular_fanins, index, fanin, + EmptyTensorId())) { + ++new_node->num_regular_fanins; + } +} + +// Remove regular fanin at `index` of regular fanins in the new node. This can +// remove existing fanins and updated/added fanins via AddOrUpdateRegularFanins. +template +inline void RemoveRegularFanin(NewNode* new_node, int index) { + int new_node_regular_fanins_size = new_node->regular_fanins.size(); + if (index < 0 || index >= new_node_regular_fanins_size || + IsEmptyTensorId(new_node->regular_fanins[index])) { + return; + } + new_node->regular_fanins[index] = EmptyTensorId(); + --new_node->num_regular_fanins; +} + +// Adds controlling fanin to new node. +template +inline void AddControllingFanin(NewNode* new_node, + absl::string_view fanin_node_name) { + new_node->controlling_fanins.emplace(fanin_node_name); +} + +// Removes controlling fanin to new node. +template +inline void RemoveControllingFanin(NewNode* new_node, + absl::string_view fanin_node_name) { + new_node->controlling_fanins.erase(fanin_node_name); +} + +// Adds or updates an attribute by name to a new node. +template +inline void AddOrUpdateAttribute(NewNode* new_node, + absl::string_view attr_name, + const AttrValue& attr_value) { + gtl::InsertOrUpdate(new_node->node.mutable_attr(), string(attr_name), + attr_value); +} + +// Removes an attribute by name to a new node. +template +inline void RemoveAttribute(NewNode* new_node, + absl::string_view attr_name) { + new_node->node.mutable_attr()->erase(string(attr_name)); +} + +// Checks if current state of new node is a valid node. +template +inline bool IsWellFormed( + NewNode* new_node, + const absl::flat_hash_map& updated_node_names) { + ResizeByTrimmingEndForValue(&new_node->regular_fanins, EmptyTensorId()); + int new_node_regular_fanins_size = new_node->regular_fanins.size(); + if (new_node_regular_fanins_size != new_node->num_regular_fanins) { + return false; + } + + const string& node_name = new_node->node.name(); + auto invalid_node_name = [new_node, updated_node_names, + node_name](absl::string_view fanin_node_name) { + return fanin_node_name == node_name || + !CheckNodeNameExists(fanin_node_name, updated_node_names, + new_node->graph_view); + }; + // Check if nodes of all fanins exist (from name) and if fanins do not + // introduce self loops. + for (const auto& regular : new_node->regular_fanins) { + if (invalid_node_name(regular.node())) { + return false; + } + } + for (const auto& control : new_node->controlling_fanins) { + if (invalid_node_name(control)) { + return false; + } + } + + return true; +} + +} // namespace internal +} // namespace utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/grappler_test.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/grappler_test.h new file mode 100644 index 00000000..967cff28 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/grappler_test.h @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace grappler { + +class GrapplerTest : public ::testing::Test { + public: + GrapplerTest(); + + protected: + void DisableAllOptimizers(); + void EnableAllOptimizers(); + + std::vector EvaluateNodes( + const GraphDef& graph, const std::vector& node_names) const; + + std::vector EvaluateNodes( + const GraphDef& graph, const std::vector& node_names, + const std::vector>& inputs) const; + + std::vector EvaluateFetchNodes(const GrapplerItem& item) const; + + NodeDef* AddNode(const string& name, const string& op, + const std::vector& inputs, + const std::vector>& attributes, + GraphDef* graph) const; + + void DisableAllOptimizers(RewriterConfig* cfg); + + // Checks if two graphs are equal. Both graphs must have the same set of nodes + // with the same inputs and attributes. Nodes can be in different order. + // + // NOTE: This function uses EXPECT/ASSERT macros to check node properties + // equality, and adds all failures to the current test. + void CompareGraphs(GraphDef want, GraphDef got) const; + + // Checks if two nodes have the same name, op, inputs and attributes. + // + // NOTE: This function uses EXPECT/ASSERT macros to check node properties + // equality, and adds all failures to the current test. + void CompareNodes(const NodeDef& want, const NodeDef& got) const; + + // Checks if two functions are equal. Both functions must have the same set of + // nodes with the same inputs and attributes. Nodes can be in different order. + // + // NOTE: This function uses EXPECT/ASSERT macros to check node properties + // equality, and adds all failures to the current test. + void CompareFunctions(FunctionDef want, FunctionDef got) const; + + // Checks if node 'src' is directly connected to the input($position) of + // 'dst'. + bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src, + const string& dst, int position = 0); + + // Counts nodes of the given op-type in a graph. + int CountOpNodes(const GraphDef& graph, const string& op); + + // Get a random tensor with given shape. + template + Tensor GenerateRandomTensor(const TensorShape& shape) const { + typedef typename EnumToDataType::Type T; + Tensor tensor(DTYPE, shape); + for (auto i = 0; i < tensor.NumElements(); i++) + tensor.flat()(i) = i + random::New64() % 10; + return tensor; + } + + // Creates a random tensor with given shape using `setRandom`. 
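Editorial aside, not part of the vendored header: a hypothetical sketch of how the GrapplerTest fixture above is meant to be used, assuming the TensorFlow test targets from this diff are linked in. The fixture name, node names, and the omitted Const attributes (dtype/value) are made up for illustration and the graphs are never evaluated here.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {

class ExampleGrapplerTest : public GrapplerTest {};

TEST_F(ExampleGrapplerTest, CompareGraphsIgnoresNodeOrder) {
  GraphDef want;
  // AddNode(name, op, inputs, attributes, graph) builds nodes directly into
  // the GraphDef; real Const nodes would also need dtype/value attributes,
  // which are skipped because nothing is evaluated in this sketch.
  AddNode("a", "Const", {}, {}, &want);
  AddNode("b", "Const", {}, {}, &want);
  AddNode("add", "AddV2", {"a", "b"}, {}, &want);

  GraphDef got = want;  // In a real test, `got` would be an optimizer's output.
  CompareGraphs(want, got);
  EXPECT_EQ(CountOpNodes(got, "AddV2"), 1);
}

}  // namespace grappler
}  // namespace tensorflow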
+ template + Tensor GenerateTensorWithSetRandom(const TensorShape& shape) const { + typedef typename EnumToDataType::Type T; + Tensor tensor(DTYPE, shape); + tensor.flat().setRandom(); + return tensor; + } + + // Get a constant tensor with given shape. + template + Tensor GenerateConstantTensor( + const TensorShape& shape, + typename EnumToDataType::Type value) const { + typedef typename EnumToDataType::Type T; + Tensor tensor(DTYPE, shape); + for (auto i = 0; i < tensor.NumElements(); i++) tensor.flat()(i) = value; + return tensor; + } + + inline tensorflow::Scope CreateScopeWithDevice(absl::string_view device) { + return tensorflow::Scope::NewRootScope().WithDevice(string(device)); + } + + private: + SessionOptions options_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/pattern_utils.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/pattern_utils.h new file mode 100644 index 00000000..de4eecb8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/pattern_utils.h @@ -0,0 +1,245 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_UTILS_H_ + +#include "tensorflow/core/grappler/utils/graph_view.h" + +namespace tensorflow { +namespace grappler { +namespace utils { + +//------------------------------------------------------------------------------ +// A pattern can be defined by the following grammar. Here, op_type is any valid +// op name in the TensorFlow. +// +// leaf_pattern ::= `{` op_type `}` +// pattern ::= leaf_pattern | +// `{` op_type `,` `{` pattern `,` ... `,` pattern `}` `}` +// +// (1) For example, the following pattern syntax describes a pattern for +// _FusedConv2D (Conv2D + BiasAdd + Relu). Note that "*" means any type of op. +// +// {"Relu", +// { +// "BiasAdd", +// { +// {"Conv2D"}, +// {"*"} +// } +// } +// } +// +// The syntax above has a root ("Relu") and children (inputs), where each child +// is a sub-pattern. Graph pattern matcher finds a match for the given pattern +// syntax in a graph and returns a set of matched nodes. +// +// (2) In order to match a DAG with a given root, we extend pattern syntax with +// labels. For example, a frequently found pattern in Deep Learning models is a +// residual block like below. +// +// Placeholder Const +// | | +// +-----+-----+ | +// | | | +// | v v +// | Conv2D Const +// | | | +// | v v-----+ +// | BiasAdd +// | | +// v v----------+ +// AddV2 +// +// As shown above, it is the same input node (Placeholder) consumed by both +// AddV2 and and Conv2D. This constrained can be put as labels in the following +// augmented pattern syntax. 
+// +// {"AddV2", "my_add", +// { +// {"*", "my_residual_input"}, +// {"BiasAdd", "my_bias_add", +// { +// {"Conv2D", "my_conv", +// { +// {"*", "my_residual_input"}, +// {"*", "my_filter"} +// } +// }, +// {"*", my_bias"} +// } +// } +// } +// } +// +// Note that the same label "my_residual_input" is used to tell that it is a +// child of both "AddV2" and "Conv2D". Labels are arbitrary strings to associate +// with the nodes to be matched as well as to uniquely identify those nodes. +// +// (3) The motivatation for a grammar based pattern matching in grappler is to +// make easy for finding fusion pattern in the remapper. A subgraph that +// matches a given pattern, however, is not fusable if any of the matched node, +// that will be removed as a part of fusion, has a consumer outside the matched +// subgraph. In order to check for such type of external dependencies, we +// further extend pattern syntax by prospective action (NodeStatus) on the +// matched nodes as shown below. This helps cross checking the nodes to be +// removed with the nodes matched intially. +// +// {"AddV2", "my_add", NodeStatus::kReplace, +// { +// {"*", "my_residual_input", NodeStatus::kRemain}, +// {"BiasAdd", "my_bias_add", NodeStatus::kRemove, +// { +// {"Conv2D", "my_conv", NodeStatus::kRemove, +// { +// {"*", "my_residual_input", NodeStatus::kRemain}, +// {"*", "my_filter", NodeStatus::Remain} +// } +// }, +// {"*", my_bias", NodeStatus::kRemain} +// } +// } +// } +// } +//------------------------------------------------------------------------------ + +// Pattern matcher recursively matches child subpatterns. The direction +// for children could be toward node's input (fanins) or outputs (fanouts). +enum class MatchingDirection { kFollowInputs, kFollowOutputs }; + +// Action for each node in the set of matched nodes for a given pattern. +enum class NodeStatus { kRemain, kRemove, kReplace }; + +// TODO (intel-tf): Support multiple roots by making them children of a single +// virtual root. +struct OpTypePattern { + string op; + string label; + NodeStatus node_status; + std::vector children; + + string DebugString() const { + string result = "{(op: " + op + ", " + "label: " + label + "), {"; + for (const OpTypePattern& child : children) { + result += child.DebugString() + ","; + } + result += "}}"; + return result; + } +}; + +// This is a helpful recursive structure that keeps one-to-one mapping of +// pattern syntax to the matched nodes. User can call DebugString to see what +// has been matched so far and where is the failing point. +struct NodeViewMatch { + MutableNodeView* node_view = nullptr; + std::vector children; + + string DebugString() const { + string result = "{"; + if (node_view == nullptr) { + result += "Non-Matched-Node}"; + return result; + } else { + result += node_view->node()->DebugString(); + result += ", {"; + for (const NodeViewMatch& child : children) { + result += child.DebugString() + ","; + } + result += "}}"; + return result; + } + } + + void Clear() { + for (NodeViewMatch& child : children) { + child.Clear(); // child is an object. + } + children.clear(); // children is a vector. + if (node_view != nullptr) { + node_view = nullptr; + } + } +}; + +template +class SubGraphMatcher { + public: + SubGraphMatcher(MutableGraphView* graph_view) : graph_view_(graph_view){}; + + // If a given pattern is matched, this function returns true as well as the + // matched node and remove node info is populated. 
+ bool GetMatchedNodes(const OpTypePattern& pattern, + const std::unordered_set& nodes_to_preserve, + MutableNodeView* node_view, + std::map* matched_nodes_map, + std::set* remove_node_indices); + + private: + MutableGraphView* graph_view_; + std::map node_label_to_index_; + std::set matched_node_indices_; + std::set remove_node_indices_; + std::unique_ptr match_ = nullptr; + + bool DoesOpTypePatternMatch(const OpTypePattern& pattern, + MutableNodeView* node_view, NodeViewMatch* match); + + // This function should be called after the pattern matcher has found + // potential matched nodes (i.e. when DoesOpTypePatternMatch returns "true"). + // It performs a sanity check if the candidate nodes for removal in subgraph + // fusion is indeed safe to remove. + bool IsSafeNodesToRemove( + const std::unordered_set& nodes_to_preserve) { + for (const auto& node_idx : remove_node_indices_) { + auto node_view = graph_view_->GetNode(node_idx); + // Check if the node to be removed is in the nodes to be preserved. + string node_name = node_view->GetName(); + if (nodes_to_preserve.count(node_name) > 0) return false; + // Traverse all the Regular Fanouts. Fanouts are stored as vector of + // vector, std::vector>. Note that + // a MutableNodeView's fanouts are stored in a nested vector of + // MutableFaninView type. + auto fanouts_by_ports = node_view->GetRegularFanouts(); + for (const auto& fanouts : fanouts_by_ports) { + for (const auto& fanout : fanouts) { + if (!matched_node_indices_.count(fanout.node_index())) { + return false; + } + } + } + } + return true; + } +}; + +template <> +bool SubGraphMatcher::DoesOpTypePatternMatch( + const OpTypePattern& pattern, MutableNodeView* node_view, + NodeViewMatch* match); + +template <> +bool SubGraphMatcher::GetMatchedNodes( + const OpTypePattern& pattern, + const std::unordered_set& nodes_to_preserve, + MutableNodeView* node_view, std::map* matched_nodes_map, + std::set* remove_node_indices); + +} // namespace utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/scc.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/scc.h new file mode 100644 index 00000000..ceb9f5db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/scc.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ + +#include +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/inputs/utils.h" +#include "tensorflow/core/lib/io/path.h" + +namespace tensorflow { +namespace grappler { + +// Computes modified strongly connected components: +// All nodes that are not part of a loop are assigned the special -1 id. +// All nodes that are part of at least one loop are assigned a positive +// component id: if 2 nodes v and w are reachable from one another (i.e. if they +// belong to the same scc), they'll be assigned the same id, otherwise they'll +// be assigned distinct ids. *num_components is set to the number of distinct +// ids. +void StronglyConnectedComponents( + const GraphDef& graph, std::unordered_map* components, + int* num_components); + +// Returns the number of individual loops present in the graph, and populates +// the 'loops' argument with the collection of loops (denoted by their loop ids) +// a node is part of. Loop ids are arbitrary. +int IdentifyLoops(const GraphDef& graph, + std::unordered_map>* loops); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SCC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/symbolic_shapes.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/symbolic_shapes.h new file mode 100644 index 00000000..9da374ed --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/symbolic_shapes.h @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +bool IsKnown(const TensorShapeProto::Dim& dim); +bool IsKnownSymbolically(const TensorShapeProto::Dim& dim); +bool IsUnknown(const TensorShapeProto::Dim& dim); + +// A shape is symbolically defined if it has a known rank, and each dimension is +// known (dim_size >= 0) or is a symbolic dimension size (dim_size <= -2). +bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape); +bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties); + +// Returns the rank of the shape or -1 if unknown. +int Rank(const TensorShapeProto& shape); + +// Returns the number of coefficients in the shape or -1 if unknown. +// TODO(bsteiner) Add a function that computes the minimum size of the tensor, +// i.e. the size assuming all the symbolic dimensions take the value 1.
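Editorial aside, not part of the vendored header: a standalone restatement of the dimension convention documented above (known when >= 0, symbolic when <= -2, unknown when exactly -1). This is not the vendored implementation, only the convention spelled out as runnable checks.

#include <cassert>
#include <cstdint>

// Dimension convention from the comments above, as plain predicates.
bool DimIsKnown(int64_t dim_size) { return dim_size >= 0; }
bool DimIsKnownSymbolically(int64_t dim_size) { return dim_size <= -2; }
bool DimIsUnknown(int64_t dim_size) { return dim_size == -1; }

int main() {
  assert(DimIsKnown(32));
  assert(DimIsKnownSymbolically(-2));  // e.g. a symbolic batch dimension
  assert(DimIsUnknown(-1));
  // A shape is "symbolically defined" when it has a known rank and every
  // dimension is either known or symbolic, i.e. never -1.
  const int64_t dims[] = {-2, 224, 224, 3};
  for (int64_t d : dims) assert(!DimIsUnknown(d));
  return 0;
}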
+int64_t NumCoefficients(const TensorShapeProto& shape); + +// Shapes are symbolically equal, if they have the same rank, they are known or +// symbolically defined, and have matching dimensions. +bool ShapesSymbolicallyEqual(const TensorShapeProto& left, + const TensorShapeProto& right); +bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left, + const OpInfo::TensorProperties& right); + +// Check if two shapes can be broadcasted to each other. Both shapes must be at +// least symbolically defined, and the have valid BCast instance. +bool ShapesBroadcastable(const TensorShapeProto& left, + const TensorShapeProto& right); +bool ShapesBroadcastable(const OpInfo::TensorProperties& left, + const OpInfo::TensorProperties& right); +bool ShapeAfterBroadcast(const TensorShapeProto& left, + const TensorShapeProto& right, + TensorShapeProto* output_shape); + +// Return true if can prove, that tensor of size 'left' is smaller than tensor +// of size 'right'. Return false if it's larger or equal, or it's impossible to +// compare because of unknown dimensions, or mismatch in symbolic dimensions. +bool CompareSymbolicallyShapedTensorSizes(const TensorShapeProto& left, + const TensorShapeProto& right); +bool CompareSymbolicallyShapedTensorSizes( + const OpInfo::TensorProperties& left, + const OpInfo::TensorProperties& right); + +// Returns the ratio of the sizes of the 2 shapes if known statically, or -1 +// otherwise. +int64_t ComputeSizeRatio(const TensorShapeProto& numerator, + const TensorShapeProto& denominator); + +} // namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/topological_sort.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/topological_sort.h new file mode 100644 index 00000000..59ea41af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/topological_sort.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ + +#include "absl/types/span.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// TODO(ezhulenev, b/121379902): We should be consistent with GraphTopologyView +// and use `GraphView::Edge` to pass extra dependencies. +struct TopologicalDependency { + TopologicalDependency(const NodeDef* from, const NodeDef* to) + : from(from), to(to) {} + const NodeDef* from; + const NodeDef* to; +}; + +// Computes a topological ordering for the graph nodes and outputs nodes in the +// topological order to the `topo_order` output argument. 
+// +// It's possible to pass additional edges that do not exists in a graph, but +// must be respected when computing graph topological order. Example: Tensorflow +// runtime allows concurrent execution of dequeue/enqueue ops from the same +// queue resource, but we might want to enforce ordering between them. +absl::Status ComputeTopologicalOrder( + const GraphDef& graph, + absl::Span extra_dependencies, + std::vector* topo_order); +absl::Status ComputeTopologicalOrder(const GraphDef& graph, + std::vector* topo_order); + +// Sorts a graph in topological order. +absl::Status TopologicalSort(GraphDef* graph); + +// Sorts a graph in topological order and reverse it. +absl::Status ReversedTopologicalSort(GraphDef* graph); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TOPOLOGICAL_SORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/tpu.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/tpu.h new file mode 100644 index 00000000..f81ab93f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/tpu.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TPU_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TPU_H_ + +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { +namespace grappler { + +// Check if the graphdef contains nodes that indicate a graph processed by the +// legacy TPU bridge. +bool IsLegacyTPUBridgeGraphDef(const GraphDef& def); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/transitive_fanin.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/transitive_fanin.h new file mode 100644 index 00000000..dd9b0c46 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/transitive_fanin.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_ + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// Computes the transitive fanin of the graph based on reachability from the +// specified terminal nodes. Returns the set of nodes comprising the +// transitive fanin into fanin_nodes. Optionally returns a map of name->node +// for that graph into name_to_fanin_node if that is not set to nullptr. +absl::Status ComputeTransitiveFanin( + const GraphDef& graph, const std::vector& terminal_nodes, + std::unordered_map* name_to_fanin_node, + std::vector* fanin_nodes); + +absl::Status ComputeTransitiveFanin(const GraphDef& graph, + const std::vector& terminal_nodes, + std::vector* fanin_nodes); + +// Creates output_graph from input_graph using the transitive fanin from the +// specified terminal nodes. Returns error if the input_graph is deemed +// structurally invalid. +absl::Status SetTransitiveFaninGraph(const GraphDef& input_graph, + GraphDef* output_graph, + const std::vector& terminal_nodes); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/utils/traversal.h b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/traversal.h new file mode 100644 index 00000000..5c9dada4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/utils/traversal.h @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ + +#include + +#include "tensorflow/core/grappler/graph_topology_view.h" + +namespace tensorflow { +namespace grappler { + +enum class TraversalDirection { kFollowInputs, kFollowOutputs }; + +// Encapsulate DFS callbacks that will be called during the graph traversal. +// +// If non-empty, the `pre_order` and `post_order` functors will be called on +// each reachable node (including the `from` nodes) in pre and post order. If +// loops are found, the `on_back_edge` functor will be called on the +// corresponding back edges. Moreover, the pre and post order will assume that +// these back edges will be cut. 
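Editorial aside, not part of the vendored header: a self-contained sketch of the DFS contract described above, with pre-order, post-order, and back-edge callbacks, and back edges treated as cut. A plain adjacency list stands in for GraphTopologyView, and all names are illustrative.

#include <functional>
#include <iostream>
#include <vector>

enum class Color { kWhite, kGray, kBlack };

// Calls pre_order/post_order on each reachable node and on_back_edge on
// edges that close a cycle (which the traversal then ignores).
void Dfs(const std::vector<std::vector<int>>& adj, int root,
         const std::function<void(int)>& pre_order,
         const std::function<void(int)>& post_order,
         const std::function<void(int, int)>& on_back_edge) {
  std::vector<Color> color(adj.size(), Color::kWhite);
  std::function<void(int)> visit = [&](int node) {
    color[node] = Color::kGray;
    if (pre_order) pre_order(node);
    for (int next : adj[node]) {
      if (color[next] == Color::kGray) {
        // A gray successor means we found a loop; report it and cut the edge.
        if (on_back_edge) on_back_edge(node, next);
      } else if (color[next] == Color::kWhite) {
        visit(next);
      }
    }
    color[node] = Color::kBlack;
    if (post_order) post_order(node);
  };
  visit(root);
}

int main() {
  // 0 -> 1 -> 2 -> 0 forms a loop; 2 -> 3 leaves it.
  std::vector<std::vector<int>> adj = {{1}, {2}, {0, 3}, {}};
  Dfs(adj, /*root=*/0,
      [](int n) { std::cout << "pre " << n << "\n"; },
      [](int n) { std::cout << "post " << n << "\n"; },
      [](int from, int to) {
        std::cout << "back edge " << from << " -> " << to << "\n";
      });
  return 0;
}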
+struct DfsCallbacks { + DfsCallbacks() = default; + DfsCallbacks(std::function pre, + std::function post, + std::function back_edge) + : pre_order(std::move(pre)), + post_order(std::move(post)), + on_back_edge(std::move(back_edge)) {} + + static DfsCallbacks PreOrder(std::function pre) { + return DfsCallbacks(std::move(pre), nullptr, nullptr); + } + + static DfsCallbacks PostOrder(std::function post) { + return DfsCallbacks(nullptr, std::move(post), nullptr); + } + + std::function pre_order; + std::function post_order; + std::function on_back_edge; +}; + +// Encapsulate DFS predicates for traversing the graph. +// +// The `enter` predicate decides if traversal should enter the node, and the +// `advance` predicate decides if the traversal should follow inputs/outputs +// from the node. +// +// If predicates are empty (default initialized), it's assumed that we can enter +// into any node and advance from any node respectively. +struct DfsPredicates { + DfsPredicates() = default; + DfsPredicates(std::function enter, + std::function advance) + : enter(std::move(enter)), advance(std::move(advance)) {} + + static DfsPredicates Enter(std::function enter) { + return DfsPredicates(std::move(enter), nullptr); + } + + static DfsPredicates Advance(std::function advance) { + return DfsPredicates(nullptr, std::move(advance)); + } + + std::function enter; + std::function advance; +}; + +// Traverse the graph in DFS order in the given direction, starting from the +// list of nodes specified in the `from` argument. Use `predicates` to decide if +// traversal should enter/advance to/from the graph node. These predicates also +// applied to the `from` nodes. Call corresponding callbacks for each visited +// node. +void DfsTraversal(const GraphTopologyView& graph_view, + absl::Span from, + TraversalDirection direction, const DfsPredicates& predicates, + const DfsCallbacks& callbacks); + +// Traverse the graph in DFS order in the given direction, starting from the +// list of nodes specified in the `from` argument. Call corresponding callbacks +// for each visited node. +void DfsTraversal(const GraphTopologyView& graph_view, + absl::Span from, + TraversalDirection direction, const DfsCallbacks& callbacks); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRAVERSAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/verifiers/graph_verifier.h b/third_party/tflite-hdrs/tensorflow/core/grappler/verifiers/graph_verifier.h new file mode 100644 index 00000000..53d62e4c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/verifiers/graph_verifier.h @@ -0,0 +1,55 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_VERIFIERS_GRAPH_VERIFIER_H_ +#define TENSORFLOW_CORE_GRAPPLER_VERIFIERS_GRAPH_VERIFIER_H_ + +#include +#include +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// An abstract interface for verifying a graph. +// This will be used to implement specific verifiers to verify that a grappler +// transformed graph is valid. +// Some examples of specific verifiers are: +// 1. A general structural verifier that verifies that the specified graph has +// a valid structure that meets the specification of what it means to be +// a valid TensorFlow graph. +// 2. A backend specific verifier that verifies that the specified graph, +// generated after a grappler transformation to convert the input TensorFlow +// graph to a corresponding backend graph, is a valid graph in the +// specification of the backend. +class GraphVerifier { + public: + GraphVerifier() {} + virtual ~GraphVerifier() {} + + // A name for the verifier. + virtual string name() const = 0; + + // Implement an algorithm to verify the specified graph. + // The return value is a Status that represents a concatenation of Status of + // each verification step. + virtual absl::Status Verify(const GraphDef& graph) = 0; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_VERIFIERS_GRAPH_VERIFIER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/grappler/verifiers/structure_verifier.h b/third_party/tflite-hdrs/tensorflow/core/grappler/verifiers/structure_verifier.h new file mode 100644 index 00000000..de77933f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/grappler/verifiers/structure_verifier.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_VERIFIERS_STRUCTURE_VERIFIER_H_ +#define TENSORFLOW_CORE_GRAPPLER_VERIFIERS_STRUCTURE_VERIFIER_H_ + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/verifiers/graph_verifier.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// Verifies the structure of a graph to ensure it is valid. 
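Editorial aside, not part of the vendored header: a hypothetical implementation of the GraphVerifier interface declared above, assuming the TensorFlow headers added by this diff are available. It checks only that node names are unique, as one example of a structural check; the class name is made up.

#include <string>
#include <unordered_set>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/verifiers/graph_verifier.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace grappler {

class UniqueNodeNameVerifier : public GraphVerifier {
 public:
  string name() const override { return "unique_node_name_verifier"; }

  absl::Status Verify(const GraphDef& graph) override {
    std::unordered_set<string> seen;
    for (const NodeDef& node : graph.node()) {
      // Duplicate names make the graph ambiguous, so reject them.
      if (!seen.insert(node.name()).second) {
        return errors::InvalidArgument("Duplicate node name: ", node.name());
      }
    }
    return absl::OkStatus();
  }
};

}  // namespace grappler
}  // namespace tensorflow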
+class StructureVerifier : public GraphVerifier { + public: + StructureVerifier() {} + ~StructureVerifier() override {} + + string name() const override { return "structure_verifier"; }; + + absl::Status Verify(const GraphDef& graph) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_VERIFIERS_STRUCTURE_VERIFIER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/dialect.h b/third_party/tflite-hdrs/tensorflow/core/ir/dialect.h new file mode 100644 index 00000000..cba40b38 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/dialect.h @@ -0,0 +1,82 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_DIALECT_H_ +#define TENSORFLOW_CORE_IR_DIALECT_H_ + +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "tensorflow/core/ir/types/dialect.h" + +namespace mlir { +namespace tfg { +// Include the relevant TensorFlow attrs/types directly in the TFG namespace. 
+using mlir::tf_type::Bfloat16RefType; // NOLINT +using mlir::tf_type::BoolRefType; // NOLINT +using mlir::tf_type::Complex128RefType; // NOLINT +using mlir::tf_type::Complex64RefType; // NOLINT +using mlir::tf_type::ControlType; // NOLINT +using mlir::tf_type::DoubleRefType; // NOLINT +using mlir::tf_type::Float8E4M3FNRefType; // NOLINT +using mlir::tf_type::Float8E5M2RefType; // NOLINT +using mlir::tf_type::FloatRefType; // NOLINT +using mlir::tf_type::FuncAttr; // NOLINT +using mlir::tf_type::HalfRefType; // NOLINT +using mlir::tf_type::Int16RefType; // NOLINT +using mlir::tf_type::Int32RefType; // NOLINT +using mlir::tf_type::Int4RefType; // NOLINT +using mlir::tf_type::Int64RefType; // NOLINT +using mlir::tf_type::Int8RefType; // NOLINT +using mlir::tf_type::OpaqueTensorType; // NOLINT +using mlir::tf_type::PlaceholderAttr; // NOLINT +using mlir::tf_type::Qint16RefType; // NOLINT +using mlir::tf_type::Qint16Type; // NOLINT +using mlir::tf_type::Qint32RefType; // NOLINT +using mlir::tf_type::Qint32Type; // NOLINT +using mlir::tf_type::Qint8RefType; // NOLINT +using mlir::tf_type::Qint8Type; // NOLINT +using mlir::tf_type::Quint16RefType; // NOLINT +using mlir::tf_type::Quint16Type; // NOLINT +using mlir::tf_type::Quint8RefType; // NOLINT +using mlir::tf_type::Quint8Type; // NOLINT +using mlir::tf_type::ResourceRefType; // NOLINT +using mlir::tf_type::ResourceType; // NOLINT +using mlir::tf_type::ShapeAttr; // NOLINT +using mlir::tf_type::StringRefType; // NOLINT +using mlir::tf_type::StringType; // NOLINT +using mlir::tf_type::Uint16RefType; // NOLINT +using mlir::tf_type::Uint32RefType; // NOLINT +using mlir::tf_type::Uint4RefType; // NOLINT +using mlir::tf_type::Uint64RefType; // NOLINT +using mlir::tf_type::Uint8RefType; // NOLINT +using mlir::tf_type::VariantRefType; // NOLINT +using mlir::tf_type::VariantType; // NOLINT +using mlir::tf_type::VersionAttr; // NOLINT + +class TFGraphOpAsmInterface; +class TFOp; +} // namespace tfg +} // namespace mlir + +// Dialect main class is defined in ODS, we include it here. +#include "tensorflow/core/ir/dialect.h.inc" // IWYU pragma: export +// ODS-generated attribute classes. +#define GET_ATTRDEF_CLASSES +#include "tensorflow/core/ir/attributes.h.inc" + +#endif // TENSORFLOW_CORE_IR_DIALECT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_attributes.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_attributes.h new file mode 100644 index 00000000..250a32e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_attributes.h @@ -0,0 +1,86 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_ATTRIBUTES_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_ATTRIBUTES_H_ + +#include + +#include "absl/strings/string_view.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/resource_handle.pb.h" +#include "tensorflow/core/ir/dialect.h" +#include "tensorflow/core/platform/statusor.h" + +namespace mlir { +namespace tfg { + +// Convert the list of MLIR Attributes `attrs` to the `tensorflow::AttrValueMap` +// `values`. +absl::Status ConvertAttributes(ArrayRef<NamedAttribute> attrs, + ArrayRef<StringRef> attrs_to_ignore, + bool remove_ref_type, + tensorflow::AttrValueMap* values); + +// Convert the MLIR attribute `attr` and return a `tensorflow::AttrValue`. +absl::StatusOr<tensorflow::AttrValue> ConvertAttribute(Attribute attr); + +absl::Status SetShapeAttribute(absl::string_view name, ShapedType shaped_type, + tensorflow::AttrValueMap* values); + +// Converts an MLIR shaped type to a TensorFlow shape attribute. +ShapeAttr ConvertTypeToTensorShapeAttr(const Type& type); + +/// Import from TensorFlow to MLIR + +// Converts non func AttrValue proto into an MLIR attribute. Func attribute is +// excluded in this function because the function might be renamed when the +// function definition is imported. +absl::StatusOr<Attribute> ConvertNonFuncAttributeValue( + const tensorflow::AttrValue& value, Builder& builder); + +// Converts all kinds of AttrValue proto into an MLIR attribute. +absl::StatusOr<Attribute> ConvertAttributeValue( + const tensorflow::AttrValue& value, Builder& builder); + +// Convert the MLIR FullType attribute `attr` and return a +// `tensorflow::FullTypeDef`. +absl::StatusOr<tensorflow::FullTypeDef> ConvertAttribute( + tf_type::FullTypeAttr full_type); + +// Converts fulltype proto to attribute. +absl::StatusOr< ::mlir::tf_type::FullTypeAttr> ConvertAttribute( + const tensorflow::FullTypeDef& full_type, Builder& builder); + +// Convert an array of handle data (pairs of data types and shapes) to an array +// attribute of tensor types. +absl::StatusOr<ArrayAttr> ConvertHandleData( + Builder builder, + const tensorflow::protobuf::RepeatedPtrField< + tensorflow::ResourceHandleProto_DtypeAndShape>& handle_data); + +// Convert an array of handle data into the `handle_data` field of the provided +// ArgDef. Each entry of the array is expected to be a TensorType. +absl::Status ConvertHandleData(ArrayAttr handle_data_arr, + tensorflow::OpDef::ArgDef* arg); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_ATTRIBUTES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_tensor.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_tensor.h new file mode 100644 index 00000000..15bbe282 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_tensor.h @@ -0,0 +1,93 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
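For orientation, a minimal usage sketch (not part of the vendored header) showing how the attribute-export entry point above is typically driven. The helper name `ExportAttrs` and the ignored attribute name are hypothetical, and the template arguments follow the reconstruction above.

// Sketch only, assuming the signatures declared in convert_attributes.h above.
#include "mlir/IR/Operation.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/ir/importexport/convert_attributes.h"

// Export all attributes of a TFG operation into an AttrValueMap, skipping a
// (hypothetical) internal bookkeeping attribute.
absl::Status ExportAttrs(mlir::Operation *op, tensorflow::AttrValueMap *values) {
  llvm::SmallVector<llvm::StringRef, 1> ignore = {"_internal_only"};
  return mlir::tfg::ConvertAttributes(op->getAttrs(), ignore,
                                      /*remove_ref_type=*/false, values);
}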
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_TENSOR_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_TENSOR_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/ir/dialect.h" +#include "tensorflow/core/ir/types/dialect.h" +#include "tensorflow/core/platform/statusor.h" + +namespace mlir { +namespace tfg { + +// Converts an TensorFlow tensor proto into an MLIR elements attribute. +absl::StatusOr ConvertTensorProto( + const tensorflow::TensorProto& input_tensor, Builder builder); + +// Converts an TensorFlow tensor into an MLIR elements attribute. +absl::StatusOr ConvertTensor( + const tensorflow::Tensor& input_tensor, Builder builder); + +// Converts a shape from MLIR to a TensorFlow tensor shape proto. +void ConvertToTensorShapeProto(ArrayRef shape, + tensorflow::TensorShapeProto* output_shape); + +// Converts an MLIR type to a TensorFlow tensor shape. +tensorflow::PartialTensorShape ConvertTypeToTensorShape(const Type& type); + +// Converts a TensorFlow shape attribute to an MLIR shape attribute. +absl::StatusOr ConvertTensorShapeProto( + const tensorflow::TensorShapeProto& shape, MLIRContext* context); + +// Fill in the contents of TensorShapeProto for the given shape. +// ShapeContainerT is any type with the following methods: +// bool hasRank() +// ArrayRef getShape() +// This includes TF::ShapeAttr and ShapedType. +template +void SetTensorShapeProto(ShapeContainerT shape, + tensorflow::TensorShapeProto* proto) { + if (shape.hasRank()) { + for (int64_t dim : shape.getShape()) { + // TODO(hinsu): Use tensorflow::kTFDynamicSize instead of -1 without + // depending on tensorflow/compiler + proto->add_dim()->set_size(mlir::ShapedType::isDynamic(dim) ? -1 : dim); + } + } else { + proto->set_unknown_rank(true); + } +} + +// Converts an MLIR elements attribute to a TensorFlow tensor proto. +absl::Status ConvertToTensorProto(ElementsAttr attr, + tensorflow::TensorProto* output_tensor); + +// Converts an MLIR elements attribute to a TensorFlow tensor. +absl::Status ConvertToTensor(ElementsAttr attr, + tensorflow::Tensor* output_tensor); + +// Converts a TF shape to MLIR shape, i.e. -1 becomes kDynamicSize. +llvm::SmallVector ConvertTFShapeToMlir(llvm::ArrayRef shape); + +// Converts an MLIR shape to TF shape, i.e. kDynamicSize becomes -1. +llvm::SmallVector ConvertMlirShapeToTF(llvm::ArrayRef shape); + +// Creates a TF TensorShape using MLIR shape, element type and encoding. 
+mlir::RankedTensorType GetTypeFromTFTensorShape(llvm::ArrayRef shape, + mlir::Type elementType, + mlir::Attribute encoding = {}); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_types.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_types.h new file mode 100644 index 00000000..d3f1756c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/convert_types.h @@ -0,0 +1,56 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_TYPES_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_TYPES_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/statusor.h" + +namespace mlir { +namespace tfg { +// Converts the TensorFlow DataType 'dtype' into an MLIR (scalar) type. +absl::Status ConvertDataType(tensorflow::DataType dtype, Builder& builder, + Type* type); + +// Converts a scalar MLIR type to a TensorFlow Datatype. +absl::Status ConvertScalarTypeToDataType(Type type, + tensorflow::DataType* dtype); + +// Converts an MLIR type to TensorFlow DataType. If 'type' is a scalar type, it +// is converted directly. If it is a shaped type, the element type is converted. +absl::Status ConvertToDataType(Type type, tensorflow::DataType* dtype); + +// Converts an TensorFlow shape to the one used in MLIR. +void ConvertToMlirShape(const tensorflow::TensorShape& input_shape, + SmallVectorImpl* shape); + +// Converts an TensorFlow shape proto to the one used in MLIR. +absl::Status ConvertToMlirShape(const tensorflow::TensorShapeProto& input_shape, + SmallVectorImpl* shape); + +// Given a tensor shape and dtype, get the corresponding MLIR tensor type. +absl::StatusOr ConvertToMlirTensorType( + const tensorflow::TensorShapeProto& shape, tensorflow::DataType dtype, + Builder* builder); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_CONVERT_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/functiondef_export.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/functiondef_export.h new file mode 100644 index 00000000..1eec4282 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/functiondef_export.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
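A small sketch (not part of the vendored header) of driving the type-conversion helpers declared in convert_types.h above. The `absl::StatusOr<mlir::TensorType>` return type and the helper name are assumptions; the template argument is not visible in the header as vendored here.

// Sketch only, under the assumptions stated above.
#include "mlir/IR/Builders.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/ir/importexport/convert_types.h"

// Build the MLIR tensor type for a DT_FLOAT tensor of shape [2, ?]. A size of
// -1 in TensorShapeProto denotes an unknown dimension.
absl::StatusOr<mlir::TensorType> MakeFloat2xUnknownType(mlir::Builder &builder) {
  tensorflow::TensorShapeProto shape;
  shape.add_dim()->set_size(2);
  shape.add_dim()->set_size(-1);
  return mlir::tfg::ConvertToMlirTensorType(shape, tensorflow::DT_FLOAT,
                                            &builder);
}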
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_FUNCTIONDEF_EXPORT_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_FUNCTIONDEF_EXPORT_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/ir/ops.h" +#include "tensorflow/core/platform/statusor.h" + +namespace mlir { +namespace tfg { + +// Export a generic GraphFuncOp into a FunctionDef. This is intended to be a +// straight serialization, an error is returned in case of failure. +absl::StatusOr ConvertGenericFunctionToFunctionDef( + GraphFuncOp func); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_FUNCTIONDEF_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/functiondef_import.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/functiondef_import.h new file mode 100644 index 00000000..7e9aba69 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/functiondef_import.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_FUNCTIONDEF_IMPORT_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_FUNCTIONDEF_IMPORT_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/ir/ops.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +namespace tfg { + +// Import the FunctionDef `func` as a TFG generic function (see GraphFuncOp +// documentation). The function will be inserted using the provided `builder`. +absl::Status ConvertGenericFunction(GraphFuncOp func_op, + const tensorflow::FunctionDef& func, + OpBuilder& builder); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_FUNCTIONDEF_IMPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/graphdef_export.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/graphdef_export.h new file mode 100644 index 00000000..74af12fb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/graphdef_export.h @@ -0,0 +1,56 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
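As a rough illustration of how the two FunctionDef halves above fit together, a sketch (not part of the vendored headers) that imports a FunctionDef into an already-created generic `GraphFuncOp` and immediately re-exports it; the helper name `RoundTripFunctionDef` is hypothetical.

// Sketch only, using the import/export entry points declared above.
#include "mlir/IR/Builders.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/ir/importexport/functiondef_export.h"
#include "tensorflow/core/ir/importexport/functiondef_import.h"
#include "tensorflow/core/ir/ops.h"

// Import `fdef` into `func_op`, then serialize it back to a FunctionDef.
absl::StatusOr<tensorflow::FunctionDef> RoundTripFunctionDef(
    mlir::tfg::GraphFuncOp func_op, const tensorflow::FunctionDef &fdef,
    mlir::OpBuilder &builder) {
  absl::Status status = mlir::tfg::ConvertGenericFunction(func_op, fdef, builder);
  if (!status.ok()) return status;
  return mlir::tfg::ConvertGenericFunctionToFunctionDef(func_op);
}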
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_GRAPHDEF_EXPORT_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_GRAPHDEF_EXPORT_H_ + +#include <string> + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/ir/dialect.h" +#include "tensorflow/core/ir/ops.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" + +namespace mlir { +namespace tfg { + +// Get the name of a value as if it were an edge in a graph. +absl::StatusOr<std::string> GetValueName(Value value, TFGraphDialect *dialect); + +// Convert a TFG graph directly to GraphDef. Graph functions in the module are +// added to the GraphDef's function library. +absl::Status ConvertToGraphDef(ModuleOp module, tensorflow::GraphDef *graph); + +// Convert a single TFG op to NodeDef. This utility function requires a callback +// `get_value_name` that returns the edge name of the given operand. +absl::Status ConvertToNodeDef( + Operation *op, tensorflow::NodeDef *node, TFGraphDialect *dialect, + function_ref<absl::StatusOr<std::string>(Value)> get_value_name); + +// Convert a single TFG function to a FunctionDef and add it to the function +// library. If a function with the same name already exists, replace it. +absl::Status ConvertToFunctionDef( + GraphFuncOp func, tensorflow::FunctionLibraryDefinition &library); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_GRAPHDEF_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/graphdef_import.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/graphdef_import.h new file mode 100644 index 00000000..cda3a989 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/graphdef_import.h @@ -0,0 +1,45 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
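A minimal sketch (not part of the vendored header) of the typical export path through `ConvertToGraphDef` declared above; `ExportToGraphDef` is a hypothetical wrapper.

// Sketch only.
#include "mlir/IR/BuiltinOps.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/ir/importexport/graphdef_export.h"

// Serialize a TFG module to a GraphDef. Graph functions in the module end up
// in the GraphDef's function library, as documented above.
absl::StatusOr<tensorflow::GraphDef> ExportToGraphDef(mlir::ModuleOp module) {
  tensorflow::GraphDef graph;
  absl::Status status = mlir::tfg::ConvertToGraphDef(module, &graph);
  if (!status.ok()) return status;
  return graph;
}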
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_GRAPHDEF_IMPORT_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_GRAPHDEF_IMPORT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/statusor.h" + +namespace mlir { +namespace tfg { + +// Convert a GraphDef directly to TFG. +absl::StatusOr> ImportGraphDef( + MLIRContext *context, const tensorflow::GraphDebugInfo &debug_info, + const tensorflow::GraphDef &graph_def); + +// Converts a graph and function library to a TFG module. +absl::StatusOr> ImportGraphAndFunctionsToMlir( + MLIRContext *context, const tensorflow::GraphDebugInfo &debug_info, + const tensorflow::Graph &graph, + const tensorflow::FunctionLibraryDefinition &flib_def); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_GRAPHDEF_IMPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/load_proto.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/load_proto.h new file mode 100644 index 00000000..9644411c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/load_proto.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_LOAD_PROTO_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_LOAD_PROTO_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Reads text (.pbtext) or binary (.pb) format of a proto message from the given +// buffer. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::Message* proto); +absl::Status LoadProtoFromBuffer(absl::string_view input, + protobuf::MessageLite* proto); + +// Reads text (.pbtext) or binary (.pb) format of a proto message from the given +// file path. Returns error status of the file is not found or malformed proto. +// Note that text protos can only be parsed when full protobuf::Message protos +// are used, and will fail for protobuf::MessageLite protos. 
+absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::Message* proto); +absl::Status LoadProtoFromFile(absl::string_view input_filename, + protobuf::MessageLite* proto); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_LOAD_PROTO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/mangling.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/mangling.h new file mode 100644 index 00000000..a85be927 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/mangling.h @@ -0,0 +1,76 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_MANGLING_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_MANGLING_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace mlir { +namespace tfg { +namespace mangling_util { +// The type of a mangled string. +enum class MangledKind { kUnknown, kDataType, kTensorShape, kTensor }; + +// Print proto in TextFormat in the single-line mode. +std::string PrintShortTextProto(const ::tensorflow::protobuf::Message& message); +// The MessageLite interface does not support reflection so this API +// will only print a summary of the proto. This API is needed for code +// that may work with both Message and MessageLite. +std::string PrintShortTextProto( + const ::tensorflow::protobuf::MessageLite& message); + +// Mangles an attribute name, marking the attribute as a TensorFlow attribute. +std::string MangleAttributeName(absl::string_view str); + +// Returns true if 'str' was mangled with MangleAttributeName. +bool IsMangledAttributeName(absl::string_view str); + +// Demangles an attribute name that was manged with MangleAttributeName. +// REQUIRES: IsMangledAttributeName returns true. +absl::string_view DemangleAttributeName(absl::string_view str); + +// Returns the type of a mangled string, or kUnknown. +MangledKind GetMangledKind(absl::string_view str); + +// Return a TensorShapeProto mangled as a string. +std::string MangleShape(const tensorflow::TensorShapeProto& shape); +// Demangle a string mangled with MangleShape. +absl::Status DemangleShape(absl::string_view str, + tensorflow::TensorShapeProto* proto); + +// Return a TensorProto mangled as a string. +std::string MangleTensor(const tensorflow::TensorProto& tensor); +// Demangle a string mangled with MangleTensor. +absl::Status DemangleTensor(absl::string_view str, + tensorflow::TensorProto* proto); + +// Return a DataType mangled as a string. +std::string MangleDataType(const tensorflow::DataType& dtype); +// Demangle a string mangled with MangleDataType. 
+absl::Status DemangleDataType(absl::string_view str, + tensorflow::DataType* proto); + +} // namespace mangling_util +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_MANGLING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/parse_text_proto.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/parse_text_proto.h new file mode 100644 index 00000000..00a7d83e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/parse_text_proto.h @@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_PARSE_TEXT_PROTO_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_PARSE_TEXT_PROTO_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace mlir { +namespace tfg { + +// Sets output to the given input with `prefix` stripped, or returns an error if +// the prefix doesn't exist. +absl::Status ConsumePrefix(absl::string_view str, absl::string_view prefix, + absl::string_view* output); + +// Strips `prefix_to_strip` from `text_proto`, parses, and returns the parsed +// proto. +absl::Status ParseTextProto(absl::string_view text_proto, + absl::string_view prefix_to_strip, + tensorflow::protobuf::Message* parsed_proto); +inline absl::Status ParseTextProto( + absl::string_view /* text_proto */, absl::string_view /* prefix_to_strip */, + tensorflow::protobuf::MessageLite* /* parsed_proto */) { + return tensorflow::errors::Unavailable("Cannot parse text protos on mobile."); +} + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_PARSE_TEXT_PROTO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/savedmodel_export.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/savedmodel_export.h new file mode 100644 index 00000000..b270ce9c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/savedmodel_export.h @@ -0,0 +1,39 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
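A short sketch (not part of the vendored headers) exercising the shape mangling pair declared above; `RoundTripShape` is a hypothetical helper.

// Sketch only.
#include <string>

#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/ir/importexport/mangling.h"

// Mangle a TensorShapeProto into its string form and demangle it back,
// returning an error if the round trip fails to parse.
absl::Status RoundTripShape(const tensorflow::TensorShapeProto &shape) {
  std::string mangled = mlir::tfg::mangling_util::MangleShape(shape);
  tensorflow::TensorShapeProto recovered;
  return mlir::tfg::mangling_util::DemangleShape(mangled, &recovered);
}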
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_SAVEDMODEL_EXPORT_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_SAVEDMODEL_EXPORT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +namespace mlir { +namespace tfg { + +// Given an MLIR module, returns a `output_saved_model` SavedModel. +// The module must contain at most a single Graph operation and zero or more +// TFFunc operations. `original_saved_model` is used as only a GraphDef portion +// of a saved model represented in the MLIR module. +absl::Status ExportMlirToSavedModel( + mlir::ModuleOp module, const tensorflow::SavedModel &original_saved_model, + tensorflow::SavedModel *output_saved_model); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_SAVEDMODEL_EXPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/savedmodel_import.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/savedmodel_import.h new file mode 100644 index 00000000..787f2ae5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/savedmodel_import.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_SAVEDMODEL_IMPORT_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_SAVEDMODEL_IMPORT_H_ + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" + +namespace mlir { +namespace tfg { + +// Converts a saved model to a MLIR module expressed in TFG dialect. +// Only the root graph and function library of the saved model gets imported +// into MLIR TFG dialect. +// TODO(b/218882780): Consider importing SignatureDefs from the SavedModel. +absl::StatusOr> ImportSavedModelToMlir( + mlir::MLIRContext* context, const tensorflow::GraphDebugInfo& debug_info, + const tensorflow::SavedModel& saved_model); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_SAVEDMODEL_IMPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/importexport/tests/roundtrip/roundtrip.h b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/tests/roundtrip/roundtrip.h new file mode 100644 index 00000000..516ede67 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/importexport/tests/roundtrip/roundtrip.h @@ -0,0 +1,25 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
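A sketch (not part of the vendored headers) of the intended SavedModel round trip using the import/export entry points above; error handling is abbreviated and the helper name is hypothetical.

// Sketch only.
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/core/framework/graph_debug_info.pb.h"
#include "tensorflow/core/ir/importexport/savedmodel_export.h"
#include "tensorflow/core/ir/importexport/savedmodel_import.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"

// Import the GraphDef portion of `saved_model` into TFG, then write it back
// into `output`, reusing `saved_model` for everything outside the GraphDef.
absl::Status RoundTripSavedModel(mlir::MLIRContext *context,
                                 const tensorflow::SavedModel &saved_model,
                                 tensorflow::SavedModel *output) {
  tensorflow::GraphDebugInfo debug_info;  // No debug info available here.
  auto module = mlir::tfg::ImportSavedModelToMlir(context, debug_info,
                                                  saved_model);
  if (!module.ok()) return module.status();
  return mlir::tfg::ExportMlirToSavedModel(**module, saved_model, output);
}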
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_IMPORTEXPORT_TESTS_ROUNDTRIP_ROUNDTRIP_H_ +#define TENSORFLOW_CORE_IR_IMPORTEXPORT_TESTS_ROUNDTRIP_ROUNDTRIP_H_ + +#include "tensorflow/core/framework/graph.pb.h" + +namespace tensorflow { +void NormalizeTensorData(GraphDef& graphdef, bool add_fulltype); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_IR_IMPORTEXPORT_TESTS_ROUNDTRIP_ROUNDTRIP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/interfaces.h b/third_party/tflite-hdrs/tensorflow/core/ir/interfaces.h new file mode 100644 index 00000000..c6b07034 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/interfaces.h @@ -0,0 +1,75 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_INTERFACES_H_ +#define TENSORFLOW_CORE_IR_INTERFACES_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/DialectInterface.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/ir/dialect.h" + +// Include generated declarations. +#include "tensorflow/core/ir/interfaces.h.inc" + +namespace mlir { +namespace tfg { +// The dialect fallback model for the TensorFlow registry interface. +class TensorFlowRegistryInterfaceBase + : public TensorFlowRegistryInterface::FallbackModel< + TensorFlowRegistryInterfaceBase>, + public DialectInterface::Base { + public: + explicit TensorFlowRegistryInterfaceBase(Dialect *dialect) + : DialectInterface::Base(dialect) {} + + // Returns whether the operation is stateful. + virtual bool isStateful(Operation *op) const = 0; +}; + +// This dialect fallback model implements memory effects for TensorFlow +// operations. +class StatefulMemoryEffectInterface + : public MemoryEffectOpInterface::FallbackModel< + StatefulMemoryEffectInterface>, + public DialectInterface::Base { + public: + explicit StatefulMemoryEffectInterface(Dialect *dialect) + : DialectInterface::Base(dialect) {} + + // Get the memory effects of a TensorFlow operation. If the operation is known + // to be stateless, then it has no memory effects. 
Otherwise, statefulness is + // modelled as `MemoryWrite`. + void getEffects( + Operation *op, + SmallVectorImpl> + &effects) const; +}; +} // namespace tfg + +namespace OpTrait { +// This trait marks intrinsic TFG operations, e.g. terminators, functions, +// and region control-flow operations. Any TFG operation that has this trait +// exists only in MLIR. +template +class IntrinsicOperation + : public mlir::OpTrait::TraitBase {}; +} // namespace OpTrait +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_INTERFACES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/ops.h b/third_party/tflite-hdrs/tensorflow/core/ir/ops.h new file mode 100644 index 00000000..08e20991 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/ops.h @@ -0,0 +1,67 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_OPS_H_ +#define TENSORFLOW_CORE_IR_OPS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/RegionKindInterface.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "tensorflow/core/ir/dialect.h" +#include "tensorflow/core/ir/interfaces.h" +#include "tensorflow/core/ir/tf_op_wrapper.h" + +// Get the C++ declaration for all the ops defined in ODS for the dialect. + +#define GET_OP_CLASSES +#include "tensorflow/core/ir/ops.h.inc" + +namespace mlir { +namespace tfg { + +// Analysis that keeps track of all function names in a module. +struct FunctionTable { + explicit FunctionTable(ModuleOp module); + + // Returns whether there are no functions. + bool empty() const { return functions.empty(); } + + // Returns whether `op` may be a function call. + bool MayBeCall(Operation* op) const; + + // Returns whether `op` is a legacy function call. A "legacy" function call + // is when the operation name is the name of a function in the library. + bool IsLegacyCall(Operation* op) const; + + private: + // All the functions in the graph. 
+ DenseSet functions; +}; + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/tf_op_registry.h b/third_party/tflite-hdrs/tensorflow/core/ir/tf_op_registry.h new file mode 100644 index 00000000..fe0d82e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/tf_op_registry.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_TF_OP_REGISTRY_H_ +#define TENSORFLOW_CORE_IR_TF_OP_REGISTRY_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/core/ir/interfaces.h" + +// Forward declaration of TensorFlow types. +namespace tensorflow { +class OpRegistry; +} // namespace tensorflow + +namespace mlir { +namespace tfg { +class TensorFlowOpRegistryInterface : public TensorFlowRegistryInterfaceBase { + public: + // Create the interface model with a provided registry. + TensorFlowOpRegistryInterface(Dialect *dialect, + const tensorflow::OpRegistry *registry) + : TensorFlowRegistryInterfaceBase(dialect), registry_(registry) {} + // Create the interface model with the global registry. + explicit TensorFlowOpRegistryInterface(Dialect *dialect); + + // Returns true if the operation is stateful. + bool isStateful(Operation *op) const override; + + // Returns the current TensorFlow op registry. + const tensorflow::OpRegistry *GetRegistry() const { return registry_; } + + private: + // The TensorFlow op registry instance. + const tensorflow::OpRegistry *registry_; +}; +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_TF_OP_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/tf_op_wrapper.h b/third_party/tflite-hdrs/tensorflow/core/ir/tf_op_wrapper.h new file mode 100644 index 00000000..1c8183f5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/tf_op_wrapper.h @@ -0,0 +1,200 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
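To illustrate the `FunctionTable` analysis declared in ops.h above, a small sketch (not part of the vendored headers) that counts legacy calls in a module; the walker and the counter are illustrative only.

// Sketch only.
#include "mlir/IR/BuiltinOps.h"
#include "tensorflow/core/ir/ops.h"

// Count operations whose name matches a function in the module's library,
// i.e. "legacy" function calls as defined by FunctionTable::IsLegacyCall.
int CountLegacyCalls(mlir::ModuleOp module) {
  mlir::tfg::FunctionTable table(module);
  int count = 0;
  module.walk([&](mlir::Operation *op) {
    if (table.IsLegacyCall(op)) ++count;
  });
  return count;
}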
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ +#define TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/iterator_range.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/ir/dialect.h" +#include "tensorflow/core/ir/types/dialect.h" +#include "tensorflow/core/ir/utility.h" + +namespace mlir { +namespace detail { +// This class iterates over the control dependencies of the values. +template +class ControlRetIterator final + : public llvm::mapped_iterator_base, + ValueIteratorT, Value> { + public: + using llvm::mapped_iterator_base, + ValueIteratorT, Value>::mapped_iterator_base; + + Value mapElement(Value value) const { + return mlir::isa(value.getType()) + ? value + : tfg::LookupControlDependency(value); + } +}; +} // namespace detail + +namespace tfg { + +// Wrapper class exposing convenience methods to manipulate TensorFlow graph +// nodes uniformly. +class TFOp { + public: + // Wrap an operation. The operation can be null. The constructor must be + // marked as implicit to support `llvm::dyn_cast`. + TFOp(Operation *op = nullptr); // NOLINT + + explicit TFOp(Operation &op) : TFOp(&op) {} + + // Support LLVM-style RTTI. + static bool classof(Operation *op) { + return isa(op->getDialect()); + } + + // Get the wrapped operation. + Operation *getOperation() { return op_; } + + // Returns a pointer to the TensorFlow Graph Dialect. It nevers returns + // nullptr. + TFGraphDialect *getDialect() { + return cast(op_->getDialect()); + } + + // Split the operands into data and control operands. + std::pair splitOperands() { + ControlType ctl_type = getDialect()->getControlType(); + return SplitDataAndControlValues(op_->getOperands(), ctl_type); + } + + // Returns the regular operands, the control operands will be excluded. + OperandRange getNonControlOperands() { return splitOperands().first; } + + // The control operands are always after the regular inputs. + OperandRange getControlOperands() { return splitOperands().second; } + + // Returns the control token produced by this operation. + Value controlRet() { return op_->getResult(op_->getNumResults() - 1); } + + // Returns the non-control results produced by this operation. + ResultRange getNonControlResults() { + return op_->getResults().slice(0, op_->getNumResults() - 1); + } + + // Returns the node name for this operation. + StringAttr nameAttr(); + StringRef name(); + // Set a new node name for this operation. + void setName(const Twine &name); + void setName(StringAttr name); + + // Returns the requested device, which is also the "device" field in a + // GraphDef. + StringAttr requestedDeviceAttr(); + StringRef requestedDevice(); + // Set a new requested device for this operation. + void setRequestedDevice(const Twine &requested_device); + void setRequestedDevice(StringAttr requested_device); + + // Returns the assigned device, this field is set by placer in general. + StringAttr assignedDeviceAttr(); + StringRef assignedDevice(); + // Set a new assigned device for this operation. 
+ void setAssignedDevice(const Twine &assigned_device); + void setAssignedDevice(StringAttr assigned_device); + + // Returns the assigned TPU cluster name. + StringAttr tpuReplicate(); + // Set the assigned TPU cluster name. + void setTpuReplicate(StringAttr tpu_replicate); + + // Returns the device, preferring the assigned device if set, and the + // requested device otherwise. + StringAttr deviceAttr() { + StringAttr device = assignedDeviceAttr(); + if (device) { + assert(!device.getValue().empty()); + return device; + } + return requestedDeviceAttr(); + } + StringRef device() { + StringAttr device_attr = deviceAttr(); + if (device_attr) return device_attr.getValue(); + return ""; + } + + // Forward `->` to the underlying operation, exposing the `Operation` methods. + Operation *operator->() { return op_; } + Operation &operator*() { return *op_; } + + // Converts to true if there is a wrapped operation. + explicit operator bool() const { return op_; } + + private: + // The wrapped operation. + Operation *op_; +}; + +// A range iterator to get the control tokens associated with a value range. +// This range allows to wrap a ValueRange (or an OperandRange) and iterates on +// the control token associated to the producer of each value. For example, if +// you wrap the operands of an operation: +// OperandControlRetRange range = op->getOperands(); +// iterating this range will yield the control edges from each of the operations +// (or block arguments) producing these operands. +template +class ControlRetRange final + : public llvm::iterator_range< + ::mlir::detail::ControlRetIterator> { + public: + using Base = llvm::iterator_range< + ::mlir::detail::ControlRetIterator>; + explicit ControlRetRange(ValueRangeT c) : Base(c.begin(), c.end()) {} + + /// Return the value at the given index. + Value operator[](size_t index) const { + assert(index < size() && "invalid index into value range"); + return *(this->begin() + index); + } + + // Return the size of this range. + size_t size() const { return llvm::size(*this); } + + // Return first value in the range. + Value front() { return (*this)[0]; } + + // Compare this range with another. + template + bool operator==(const OtherT &other) const { + return llvm::size(*this) == llvm::size(other) && + std::equal(this->begin(), this->end(), other.begin()); + } + template + bool operator!=(const OtherT &other) const { + return !(*this == other); + } +}; + +using OperandControlRetRange = ControlRetRange; +using ValueControlRetRange = ControlRetRange; + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/types/dialect.h b/third_party/tflite-hdrs/tensorflow/core/ir/types/dialect.h new file mode 100644 index 00000000..b0b601e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/types/dialect.h @@ -0,0 +1,359 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
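A brief sketch (not part of the vendored headers) of the `TFOp` wrapper above in use: walking a graph and printing each node's name and effective device; the function name is hypothetical.

// Sketch only.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Operation.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"

// Print "<node name> -> <device>" for every TFG node nested under `root`,
// preferring the assigned device over the requested one (see deviceAttr()).
void DumpNodePlacement(mlir::Operation *root) {
  root->walk([](mlir::Operation *op) {
    if (!mlir::tfg::TFOp::classof(op)) return;
    mlir::tfg::TFOp node(op);
    llvm::errs() << node.name() << " -> " << node.device() << "\n";
  });
}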
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_TYPES_DIALECT_H_ +#define TENSORFLOW_CORE_IR_TYPES_DIALECT_H_ + +#include +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project + +// Include the dialect class generated from dialect.td. +// The constructor and the printing/parsing of dialect types are manually +// implemented (see ops.cpp). +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/ir/types/dialect.h.inc" + +// Include the Type classes declaration generated from types.td +#define GET_TYPEDEF_CLASSES +#include "tensorflow/core/ir/types/types.h.inc" + +namespace mlir { +namespace tf_type { + +//===----------------------------------------------------------------------===// +// TensorFlow types +//===----------------------------------------------------------------------===// + +// The base class in the TensorFlow type hierarchy. +class TensorFlowType : public Type { + public: + using Type::Type; + + // Support method to enable LLVM-style type casting. + static bool classof(Type type); +}; + +// Returns true if the specified type is a valid TensorFlow element type. +inline bool IsValidTFElementType(Type type) { + return mlir::isa(type); +} + +// Returns true if this is a valid TensorFlow tensor type. +inline bool IsValidTFTensorType(Type type) { + // TensorFlow types should be tensors of one of the valid TensorFlow element + // types. + if (auto tensor_ty = mlir::dyn_cast(type)) + return IsValidTFElementType(tensor_ty.getElementType()); + return false; +} + +namespace detail { +// Common implementation of TensorFlow types. The template argument indicates +// the concrete derived class per CRTP. +template +class TensorFlowTypeImpl + : public Type::TypeBase { + public: + using Base = typename Type::TypeBase; + using TFBase = TensorFlowTypeImpl; + using Base::Base; +}; +} // namespace detail + +// TensorFlowRefType class supports all the ref types in TensorFlow dialect. +class TensorFlowRefType : public TensorFlowType { + public: + using TensorFlowType::TensorFlowType; + + // Checks if a type is TensorFlow Ref type. + static bool classof(Type type); + + // Converts a type to the corresponding TensorFlowRef type. + static TensorFlowType get(Type type); + static TensorFlowType getChecked(Type type, MLIRContext* context, + Location loc) { + if (failed(verify(loc, type))) { + return TensorFlowRefType(); + } + return get(type); + } + + static LogicalResult verify(Location loc, Type type) { + // type should be a valid TensorFlow type. + if (!IsValidTFTensorType(type)) { + return emitError(loc) << "invalid TensorFlow type: " << type; + } + return success(); + } + + // Converts a TensorFlowRef type to the corresponding TensorFlow or standard + // type. + Type RemoveRef(); +}; + +// Define a class for each individual TensorFlow type (dtype), see types.def +// for the list. 
+#define HANDLE_TF_TYPE(tftype, enumerant, name_marg) \ + class tftype##Type : public detail::TensorFlowTypeImpl { \ + public: \ + using TFBase::TFBase; \ + static constexpr StringLiteral name = #name_marg; \ + }; +#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name_marg) +#include "tensorflow/core/ir/types/types.def" + +namespace detail { +// Storage type contains inferred subtypes for TypeWithSubtype. +class TypeWithSubtypeStorage : public TypeStorage { + public: + using KeyTy = ArrayRef; + + // NOLINTNEXTLINE + static TypeWithSubtypeStorage* construct(TypeStorageAllocator& allocator, + const KeyTy& key) { + ArrayRef subtypes = allocator.copyInto(key); + return new (allocator.allocate()) + TypeWithSubtypeStorage(subtypes); + } + + explicit TypeWithSubtypeStorage(const KeyTy& key) : subtypes_(key) {} + + bool operator==(const KeyTy& key) const { return key == subtypes_; } + + static llvm::hash_code hashKey(const KeyTy& key) { + return llvm::hash_combine_range(key.begin(), key.end()); + } + + KeyTy subtypes_; +}; + +// Common implementation of TensorFlow types with subtypes. These subtypes are +// opaque and their interpretation depends on the actual underlying type. +// The template argument indicates the concrete derived class per CRTP. Concrete +// classes must implement the following: +// - `static std::string getTypeName()` that returns the name of the type for +// verification logging. +template +class TypeWithSubtypeImpl + : public Type::TypeBase { + public: + using Base = Type::TypeBase; + using TFBase = TypeWithSubtypeImpl; + using Base::Base; + + static Derived get(ArrayRef subtypes, MLIRContext* context) { + return Base::get(context, subtypes); + } + + static Derived getChecked(ArrayRef subtypes, MLIRContext* context, + Location loc) { + return Base::getChecked(loc, subtypes); + } + static Derived getChecked(function_ref emitError, + MLIRContext* context, + ArrayRef subtypes) { + return Base::getChecked(emitError, context, subtypes); + } + + static Derived get(MLIRContext* context) { return get({}, context); } + + static LogicalResult verify(function_ref emitError, + ArrayRef subtypes) { + // Each of the subtypes should be a valid TensorFlow type. + for (TensorType subtype : subtypes) { + if (!IsValidTFTensorType(subtype)) { + return emitError() << "invalid " << Derived::getTypeName() + << " subtype: " << subtype; + } + } + return success(); + } + + ArrayRef getSubtypes() { return Base::getImpl()->subtypes_; } +}; +} // namespace detail + +// TensorFlowTypeWithSubtype class supports all the types with subtypes in +// TensorFlow dialect. +class TensorFlowTypeWithSubtype : public TensorFlowType { + public: + using TensorFlowType::TensorFlowType; + + // Checks if a type is TensorFlow type with subtypes. + static bool classof(Type type); + + // Converts a TypeWithSubtype type to the same type but without its subtypes. + Type RemoveSubtypes(); + + // Clone the current Type with new subtypes. + TensorFlowTypeWithSubtype clone(ArrayRef new_subtypes); + + // Returns the subtypes. + ArrayRef GetSubtypes(); +}; + +// Returns the corresponding TensorFlow type with subtypes but without its +// subtypes. +inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) { + return type.RemoveSubtypes(); +} + +// TensorFlow resource type is used to support TensorFlow resource variables, +// which represent shared, persistent state manipulated by a TensorFlow program. 
+// ResourceType stores shape and datatype for subtypes unlike most other data +// types that don't have any associated information. +class ResourceType : public detail::TypeWithSubtypeImpl { + public: + using TFBase::TFBase; + static constexpr ::mlir::StringLiteral name = "tf_type.resource"; + static std::string getTypeName() { return "ResourceType"; } +}; + +// TensorFlow variant type is used to support arbitrary custom C++ data types. +// VariantType stores inferred shape and datatype for subtypes unlike most other +// data types that don't have any associated information. For example, variants +// encoding TensorList type stores the common shape and dtype of the list +// elements as the only subtype. +class VariantType : public detail::TypeWithSubtypeImpl { + public: + using TFBase::TFBase; + static constexpr ::mlir::StringLiteral name = "tf_type.variant"; + static std::string getTypeName() { return "VariantType"; } +}; + +// Given two types `a` and `b`, returns a refined type which is cast compatible +// with both `a` and `b` and is equal to or more precise than both of them. It +// returns empty Type if the input types are not cast compatible. +// Provides option to ignore ref types on 'a'. This is useful for TF ops that +// might allow operands to either be same as result type or be a ref type +// corresponding to it. +Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a = false); + +// Returns whether two arrays of Type are broadcast compatible. +bool BroadcastCompatible(TypeRange lhs, TypeRange rhs); + +// Returns whether the two elemental types are compatible. Shapes are compatible +// if: +// - the types are statically equal +// - could be dynamically equal +// - considering dynamic shapes equal unless contradictory info known; +// - element types are equivalent, modulo subtypes possible be less exact +// (e.g., a resource type without subtype is considered compatible with +// resource type with known subtype). +// Provide option to ignore ref types on 'lhs'. +bool HasCompatibleElementTypes(Type lhs, Type rhs, + bool may_ignore_ref_type_lhs = false); + +// Returns true if all TensorFlow types can be cast to one +// another. In other words, a single run-time value is legal for all the types. +// For example, tensor<*xf32>, tensor and tensor<3xf32> are cast +// compatible. +bool AreCastCompatible(TypeRange types); + +// Returns true if corresponding elements of lhs and rhs AreCastCompatible and +// lhs and rhs are the same length. +bool ArraysAreCastCompatible(TypeRange lhs, TypeRange rhs); + +// If `ty` is a tensor type and its element type has subtypes, then returns a +// new type of same shape but dropped subtypes for the element type. +// Otherwise, if `ty` has subtypes, then returns corresponding type with dropped +// subtypes. +// Otherwise, returns the original type `ty`. +Type DropSubTypes(Type ty); + +// If `ty` is a tensor type and has elements of a ref type, then returns a new +// type of same shape but corresponding non-ref type as element type. +// Otherwise, if `ty` is a ref type, then returns corresponding non-ref type. +// Otherwise, returns the original type `ty`. +Type DropRefType(Type ty); + +// Convenience call for executing both `DropRefType` and `DropSubTypes`. 
+Type DropRefAndSubTypes(Type ty); + +//===----------------------------------------------------------------------===// +// Utility iterators +//===----------------------------------------------------------------------===// + +// An iterator for the tensor shapes of an op's operands of shaped types. +// Returns std::nullopt if a operand is unranked; returns ArrayRef as +// the shape otherwise. +class OperandShapeIterator final + : public llvm::mapped_iterator> (*)( + Value)> { + public: + using reference = std::optional>; + + /// Initializes the operand shape iterator to the specified operand iterator. + explicit OperandShapeIterator(Operation::operand_iterator it); +}; + +using OperandShapeRange = iterator_range; + +// An iterator for the tensor shapes of an op's results of shaped types. +// Returns std::nullopt if a result is unranked; returns ArrayRef as +// the shape otherwise. +class ResultShapeIterator final + : public llvm::mapped_iterator> (*)( + Value)> { + public: + using reference = std::optional>; + + /// Initializes the result shape iterator to the specified result iterator. + explicit ResultShapeIterator(Operation::result_iterator it); +}; + +using ResultShapeRange = iterator_range; + +// Returns a range with just resource type values from the input range +// preserved. +template +auto filter_resources(RangeT&& range) { + return llvm::make_filter_range(std::forward(range), [](Value val) { + return mlir::isa(getElementTypeOrSelf(val.getType())); + }); +} + +// Returns the element type if `type` is a `ShapedType` and the type itself +// otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or +// standard type if necessary. +inline Type GetElementTypeOrSelfResolveRef(Type type) { + Type element_type = getElementTypeOrSelf(type); + if (auto ref_type = mlir::dyn_cast(element_type)) { + element_type = ref_type.RemoveRef(); + } + return element_type; +} + +} // namespace tf_type +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Tablegen Attribute Declarations +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "tensorflow/core/ir/types/attributes.h.inc" +#include "tensorflow/core/ir/types/attributes_enum.h.inc" + +#endif // TENSORFLOW_CORE_IR_TYPES_DIALECT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/utility.h b/third_party/tflite-hdrs/tensorflow/core/ir/utility.h new file mode 100644 index 00000000..e234751e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/utility.h @@ -0,0 +1,87 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
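Two tiny sketches (not part of the vendored header) of the type utilities above: normalizing a type before comparison and checking cast compatibility. Both helper names are illustrative.

// Sketch only.
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Types.h"
#include "tensorflow/core/ir/types/dialect.h"

// Strip ref element types and any subtypes so that two types describing the
// same underlying tensor compare equal.
mlir::Type NormalizeForComparison(mlir::Type ty) {
  return mlir::tf_type::DropRefAndSubTypes(ty);
}

// Check whether a single runtime value could be legal for both types.
bool CouldAliasAtRuntime(mlir::Type a, mlir::Type b) {
  llvm::SmallVector<mlir::Type, 2> types = {a, b};
  return mlir::tf_type::AreCastCompatible(types);
}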
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_UTILITY_H_ +#define TENSORFLOW_CORE_IR_UTILITY_H_ + +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/ir/dialect.h" + +namespace mlir { +namespace tfg { + +// Region-based loop ops store control tokens all after the data values, unlike +// functions which store them as pairs. This is required by +// RegionBranchOpInterface's API which requires MutableOperandRange, i.e. the +// data operands need to be stored contiguously. + +// TODO(jeffniu): These functions aren't just for "loop regions" any more, but +// any region-based ops (if/case have explicit capture forms). + +// Given a region belonging to a region-based loop operation (e.g. a while +// loop), return the subrange of block arguments that are data values. +Block::BlockArgListType GetLoopRegionDataArgs(Region ®ion); +// Given a region belonging to a region-based loop operation (e.g. a while +// loop), return the subrange of block arguments that are control tokens. +Block::BlockArgListType GetLoopRegionControlTokens(Region ®ion); +// Given a data value block argument of a region belonging to a region-based +// loop operation (e.g. a while loop), return the block argument that +// corresponds to the control token. +BlockArgument GetLoopRegionControlOf(BlockArgument data); +// Given a control token block argument of a region belonging to a region-based +// loop operation (e.g. a while loop), return the block argument that +// corresponds to the data value. +BlockArgument GetLoopRegionDataOf(BlockArgument ctl); + +// Given a TFG value, lookup the associated control token. For op results, the +// token will be the last result of the op. For block arguments, the token will +// be the subsequent argument. A data value always has an associated control +// token. +Value LookupControlDependency(Value data); + +// Given a TFG control token, lookup the associated data value. Block arguments +// will always have an associated data value: the previous argument. For ops, +// if the only result is a control token, return None. Otherwise, returns the +// first result. +std::optional LookupDataValue(Value ctl); + +// Given a range of values, operands, or results, that contains data and control +// values, where all control tokens come after the data values, split the range +// between the two. +template +std::pair SplitDataAndControlValues(RangeT values, + ControlType ctl_type) { + unsigned num_ctl = 0; + for (Value value : llvm::reverse(values)) { + if (value.getType() == ctl_type) + ++num_ctl; + else + break; + } + unsigned split_idx = llvm::size(values) - num_ctl; + return std::make_pair(values.slice(0, split_idx), + values.slice(split_idx, num_ctl)); +} + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_UTILITY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ir/utils/shape_inference_utils.h b/third_party/tflite-hdrs/tensorflow/core/ir/utils/shape_inference_utils.h new file mode 100644 index 00000000..273f4cee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ir/utils/shape_inference_utils.h @@ -0,0 +1,94 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_IR_UTILS_SHAPE_INFERENCE_UTILS_H_ +#define TENSORFLOW_CORE_IR_UTILS_SHAPE_INFERENCE_UTILS_H_ + +#include + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +struct OpRegistrationData; +} // namespace tensorflow + +namespace mlir { +namespace tfg { + +// Function that takes in a value and extracts a constant from it, if available. +// If the value cannot be resolved as a constant, a nullptr will be returned. +// Certain shape functions require constant values as arguments. +using OperandAsConstantFn = llvm::function_ref; + +// Function that takes in an operation result and computes a shape (can be +// partial) value. Certain shape functions require shape values as arguments. +using OpResultAsShapeFn = + llvm::function_ref; + +// Function that takes a result index and returns the element type. Element +// types are necessary for handle types (resource, variant). +using ResultElementTypeFn = llvm::function_ref; + +// Extracts the attributes of a MLIR operation and populates the converted +// attributes in a proto map. This is used by operation +// defined in TF dialect which has different attributes format than TFG dialect. +using GetAttrValuesFn = llvm::function_ref; + +// Runs TensorFlow shape inference associated to the op type registered in the +// TensorFlow op registry based on the Graph version, operands, and attributes. +// Invoking this shape function will create conversions of parameters to the +// TensorFlow Graph equivalent data structures and back to MLIR equivalent data +// structures. This does not use a natively implemented shape inference in MLIR, +// and instead is temporary until shape functions are reimplemented/migrated to +// being in MLIR instead of the TensorFlow op registry. +// Note that the default way to get the attrs in the operation is using the API +// in TFG importer. For operations that has different format of attributes, they +// should give the `get_attr_values_fn` to read the attributes correctly. 
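For intuition about the SplitDataAndControlValues template declared in utility.h above, here is a self-contained analogue over plain integers; SplitDataAndControl and the use of -1 as a stand-in for the control-token type are assumptions of this sketch. It relies on the same invariant: control values always trail the data values, so they can be counted from the back.

#include <cstddef>
#include <utility>
#include <vector>

std::pair<std::vector<int>, std::vector<int>> SplitDataAndControl(
    const std::vector<int>& values) {
  std::size_t num_ctl = 0;
  for (auto it = values.rbegin(); it != values.rend() && *it == -1; ++it)
    ++num_ctl;                                   // trailing "control tokens"
  const std::size_t split = values.size() - num_ctl;
  return {std::vector<int>(values.begin(), values.begin() + split),
          std::vector<int>(values.begin() + split, values.end())};
}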
+LogicalResult InferReturnTypeComponentsForTFOp( + std::optional location, Operation* op, ValueRange operands, + int64_t graph_version, OperandAsConstantFn operand_as_constant_fn, + OpResultAsShapeFn op_result_as_shape_fn, + ResultElementTypeFn result_element_type_fn, + GetAttrValuesFn get_attr_values_fn, + SmallVectorImpl& inferred_return_shapes); + +// This one is almost the same as the above one, the difference is that we use +// ConvertOperationToNode to convert the operation to NodeDef to get the attr +// values. +LogicalResult InferReturnTypeComponentsForTFOp( + std::optional location, Operation* op, ValueRange operands, + int64_t graph_version, OperandAsConstantFn operand_as_constant_fn, + OpResultAsShapeFn op_result_as_shape_fn, + ResultElementTypeFn result_element_type_fn, + SmallVectorImpl& inferred_return_shapes); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_IR_UTILS_SHAPE_INFERENCE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/aggregate_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/aggregate_ops.h new file mode 100644 index 00000000..7f56e994 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/aggregate_ops.h @@ -0,0 +1,226 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor definitions for Aggregate ops, must be compilable by nvcc. 
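The fixed-arity functors that follow (Add2 through Add9, plus the accumulating Add8p) are typically combined to implement an AddN-style reduction: the first group of inputs assigns the output and later groups accumulate into it. The sketch below restates that combination pattern with plain vectors; AddN here is an illustrative helper, not the TensorFlow kernel.

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> AddN(const std::vector<std::vector<float>>& inputs) {
  if (inputs.empty()) return {};
  std::vector<float> out(inputs[0].size(), 0.0f);
  bool first_pass = true;
  for (std::size_t start = 0; start < inputs.size(); start += 8) {
    const std::size_t end = std::min(inputs.size(), start + 8);
    for (std::size_t j = 0; j < out.size(); ++j) {
      float partial = 0.0f;
      for (std::size_t i = start; i < end; ++i) partial += inputs[i][j];
      // First group assigns (Add2..Add8 style); later groups accumulate into
      // the existing output (Add8p style).
      out[j] = first_pass ? partial : out[j] + partial;
    }
    first_pass = false;
  }
  return out;
}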
+template +struct Add2Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2); +}; + +template +struct Add2EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2) { + out.device(d) = in1 + in2; + } +}; + +template +struct Add3Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3); +}; + +template +struct Add3EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3) { + out.device(d) = in1 + in2 + in3; + } +}; + +template +struct Add4Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4); +}; + +template +struct Add4EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4) { + out.device(d) = in1 + in2 + in3 + in4; + } +}; + +template +struct Add5Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5); +}; + +template +struct Add5EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5) { + out.device(d) = in1 + in2 + in3 + in4 + in5; + } +}; + +template +struct Add6Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6); +}; + +template +struct Add6EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6; + } +}; + +template +struct Add7Functor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7); +}; + +template +struct Add7EigenImpl { + static void Compute(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7; + } +}; + +template +struct Add8Functor { + void operator()( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename 
TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8); +}; + +template +struct Add8EigenImpl { + static void Compute( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; + } +}; + +// Add8p is like Add8 except the underlying implementation should += +// rather than assign to the output. +template +struct Add8pFunctor { + void operator()( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8); +}; + +template +struct Add8pEigenImpl { + static void Compute( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + out.device(d) += in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8; + } +}; + +template +struct Add9Functor { + void operator()( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, + typename TTypes::ConstFlat in9); +}; + +template +struct Add9EigenImpl { + static void Compute( + const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, + typename TTypes::ConstFlat in9) { + out.device(d) = in1 + in2 + in3 + in4 + in5 + in6 + in7 + in8 + in9; + } +}; +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/aggregate_ops_cpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/aggregate_ops_cpu.h new file mode 100644 index 00000000..f205d8d1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/aggregate_ops_cpu.h @@ -0,0 +1,142 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_ +#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +#include "tensorflow/core/kernels/aggregate_ops.h" + +typedef Eigen::ThreadPoolDevice CPUDevice; + + +namespace tensorflow { + +// Partial specializations for a CPUDevice, that uses the Eigen implementation +// from AddNEigenImpl. +namespace functor { +template +struct Add2Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2) { + Add2EigenImpl::Compute(d, out, in1, in2); + } +}; +template +struct Add3Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3) { + Add3EigenImpl::Compute(d, out, in1, in2, in3); + } +}; +template +struct Add4Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4) { + Add4EigenImpl::Compute(d, out, in1, in2, in3, in4); + } +}; +template +struct Add5Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5) { + Add5EigenImpl::Compute(d, out, in1, in2, in3, in4, in5); + } +}; +template +struct Add6Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6) { + Add6EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6); + } +}; +template +struct Add7Functor { + void operator()(const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, + typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, + typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, + typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7) { + Add7EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7); + } +}; + +template +struct Add8Functor { + void operator()( + const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + Add8EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template +struct Add8pFunctor { + void operator()( + const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8) { + Add8pEigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8); + } +}; + +template +struct Add9Functor { + void operator()( + const CPUDevice& d, typename TTypes::Flat out, + typename TTypes::ConstFlat in1, typename TTypes::ConstFlat in2, + typename 
TTypes::ConstFlat in3, typename TTypes::ConstFlat in4, + typename TTypes::ConstFlat in5, typename TTypes::ConstFlat in6, + typename TTypes::ConstFlat in7, typename TTypes::ConstFlat in8, + typename TTypes::ConstFlat in9) { + Add9EigenImpl::Compute(d, out, in1, in2, in3, in4, in5, in6, + in7, in8, in9); + } +}; + + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/argmax_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/argmax_op.h new file mode 100644 index 00000000..9b2325c3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/argmax_op.h @@ -0,0 +1,72 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_ +#define TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_ +// Generator definition for ArgMaxOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace functor { + +template +struct ArgMax { +#define DECLARE_COMPUTE_SPEC(Dims) \ + EIGEN_ALWAYS_INLINE static void Reduce##Dims( \ + const Device& d, typename TTypes::ConstTensor input, \ + const int32 dimension, typename TTypes::Tensor output) { \ + output.device(d) = input.argmax(dimension).template cast(); \ + } + + DECLARE_COMPUTE_SPEC(1); + DECLARE_COMPUTE_SPEC(2); + DECLARE_COMPUTE_SPEC(3); + DECLARE_COMPUTE_SPEC(4); + DECLARE_COMPUTE_SPEC(5); + DECLARE_COMPUTE_SPEC(6); + DECLARE_COMPUTE_SPEC(7); + +#undef DECLARE_COMPUTE_SPEC +}; + +template +struct ArgMin { +#define DECLARE_COMPUTE_SPEC(Dims) \ + EIGEN_ALWAYS_INLINE static void Reduce##Dims( \ + const Device& d, typename TTypes::ConstTensor input, \ + const int32 dimension, typename TTypes::Tensor output) { \ + output.device(d) = input.argmin(dimension).template cast(); \ + } + + DECLARE_COMPUTE_SPEC(1); + DECLARE_COMPUTE_SPEC(2); + DECLARE_COMPUTE_SPEC(3); + DECLARE_COMPUTE_SPEC(4); + DECLARE_COMPUTE_SPEC(5); + DECLARE_COMPUTE_SPEC(6); + DECLARE_COMPUTE_SPEC(7); + +#undef DECLARE_COMPUTE_SPEC +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/assign_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/assign_op.h new file mode 100644 index 00000000..063be3e4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/assign_op.h @@ -0,0 +1,71 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_ +#define TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_ + +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ref_var.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +// TODO(jeff): Get rid of use_exclusive_lock_ option + +// Computes *input[0] = input[1] +class AssignOp : public OpKernel { + public: + explicit AssignOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("use_locking", &use_exclusive_lock_)); + OP_REQUIRES_OK(context, + context->GetAttr("validate_shape", &validate_shape_)); + OP_REQUIRES(context, IsRefType(context->input_type(0)), + errors::InvalidArgument("lhs input needs to be a ref type")); + if (!context + ->GetAttr("_grappler_relax_allocator_constraints", + &relax_constraints_) + .ok()) { + relax_constraints_ = false; + } + } + + void Compute(OpKernelContext* context) override { + constexpr int input_ref_index = 0; + constexpr int output_ref_index = 0; + constexpr int value_index = 1; + + auto copy = [this](OpKernelContext* cc_ctx, Tensor* lhs, + const Tensor& rhs) { Copy(cc_ctx, lhs, rhs); }; + + AssignRefVariable(context, input_ref_index, output_ref_index, value_index, + use_exclusive_lock_, validate_shape_, relax_constraints_, + copy); + } + + virtual void Copy(OpKernelContext* context, Tensor* lhs, + const Tensor& rhs) = 0; + + bool use_exclusive_lock_; + bool validate_shape_; + bool relax_constraints_; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/autotune_conv_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/autotune_conv_impl.h new file mode 100644 index 00000000..63c6a64d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/autotune_conv_impl.h @@ -0,0 +1,97 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+------------------------------------------------------------------------------*/ + +#ifndef TENSORFLOW_CORE_KERNELS_AUTOTUNE_CONV_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_AUTOTUNE_CONV_IMPL_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_THREADS + +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/util/proto/proto_utils.h" + +namespace tensorflow::internal { + +template +StatusOr> AutotuneConvImpl( + OpKernelContext* ctx, + std::vector>>& runners, + bool actually_do_autotune, const LaunchFunc& launch_func, + size_t scratch_size_limit, const se::RedzoneAllocator& rz_allocator) { + auto* stream = ctx->op_device_context()->stream(); + + se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}), + stream); + + std::vector results; + // TODO(reedwm): Warn if determinism is enabled after autotune is run + for (auto& runner : runners) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. + se::RedzoneAllocator rz_scratch_allocator( + stream, &tf_allocator_adapter, + /*memory_limit=*/scratch_size_limit); + DnnScratchAllocator scratch_allocator(scratch_size_limit, ctx); + se::ScratchAllocator* allocator_used = + !RedzoneCheckDisabled() + ? static_cast(&rz_scratch_allocator) + : static_cast(&scratch_allocator); + + TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); + se::dnn::ProfileResult profile_result; + Status cudnn_launch_status = + actually_do_autotune + ? launch_func(allocator_used, runner, &profile_result) + : OkStatus(); + if (!actually_do_autotune) { + // Make the result valid according to `is_valid`. + profile_result.set_algorithm(desc); + profile_result.set_elapsed_time_in_ms(0); + } + + // We need to make sure the profiling results are one-to-one with the + // "runners". So, we insert dummy results when the execution fails. + results.emplace_back(); + auto& result = results.back(); + *result.mutable_algorithm() = desc.ToProto(); + if (cudnn_launch_status.ok() && profile_result.is_valid()) { + result.set_scratch_bytes( + !RedzoneCheckDisabled() + ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones() + : scratch_allocator.TotalByteSize()); + *result.mutable_run_time() = proto_utils::ToDurationProto( + absl::Milliseconds(profile_result.elapsed_time_in_ms())); + + CheckRedzones(rz_scratch_allocator, &result); + CheckRedzones(rz_allocator, &result); + } else { + result.mutable_failure()->set_kind(xla::AutotuneResult::UNKNOWN); + result.mutable_failure()->set_msg( + absl::StrCat("Profiling failure on CUDNN engine ", desc.ToString(), + ": ", cudnn_launch_status.ToString())); + } + } + + return results; +} + +} // namespace tensorflow::internal + +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_KERNELS_AUTOTUNE_CONV_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/avgpooling_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/avgpooling_op.h new file mode 100644 index 00000000..8008c3c4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/avgpooling_op.h @@ -0,0 +1,76 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_ +#define TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_ +// Functor definition for AvgPoolingOp, must be compilable by nvcc. + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_pooling.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +template +struct SpatialAvgPooling { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, int window_rows, + int window_cols, int row_stride, int col_stride, + const Eigen::PaddingType& padding) { + MaybeWith32BitIndexing( + [&](auto output32, auto input32) { + // Because we swap the layout, we swap the row/cols as well. + output32.swap_layout().device(d) = Eigen::SpatialAvgPooling( + input32.swap_layout(), window_cols, window_rows, col_stride, + row_stride, padding); + }, + output, input); + } +}; + +} // namespace functor + +typedef Eigen::GpuDevice GPUDevice; + +// Launch a custom GPU kernels from Yanqing for the avgpooling backward +// operation that works NHWC data formats. Arguments: +// top_diff: backprop to the output of the pooling layer +// num: number of input batches +// height: input height +// width: input width +// channels: number of input channels +// pooled_height: the height of the output to the pooling layer +// pooled_width: the width of the output to the pooling layer +// kernel_h: the height of the pooling kernel +// kernel_w: the width of the pooling kernel +// stride_h: the height of the vertical stride +// stride_w: the width of the horizontal stride +// pad_t: padding size to the top side +// pad_l: padding size to the left side +// bottom_diff: backprop to the input of the pooling layer. +template +bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, + const int pad_l, T* const bottom_diff, + const GPUDevice& d); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batch_kernel_test_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_kernel_test_util.h new file mode 100644 index 00000000..2495580a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_kernel_test_util.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCH_KERNEL_TEST_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_BATCH_KERNEL_TEST_UTIL_H_ + +#include +#include "tensorflow/core/kernels/batch_kernels.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace test_util { + +// A test util for accessing private members of `BatchFunctionKernel`. +class BatchFunctionKernelTestAccess { + public: + explicit BatchFunctionKernelTestAccess(const BatchFunctionKernel* kernel); + + bool enable_adaptive_batch_threads() const; + + private: + const BatchFunctionKernel* const kernel_; +}; + +class BatchFunctionKernelTestBase : public OpsTestBase, + public ::testing::WithParamInterface { + public: + // Init test fixture with a batch kernel instance. + absl::Status Init(bool enable_adaptive_scheduler); +}; + +} // namespace test_util +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCH_KERNEL_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batch_kernels.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_kernels.h new file mode 100644 index 00000000..73baea3a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_kernels.h @@ -0,0 +1,139 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/types.h" + +namespace tensorflow { + +// Per-model inflight batches parameters. +ABSL_CONST_INIT extern const int64_t kMinInflightBatches; +ABSL_CONST_INIT extern const int64_t kInitialInflightBatches; +ABSL_CONST_INIT extern const int64_t kBatchesToAverageOver; +ABSL_CONST_INIT extern const int64_t kMaxInflightBatches; + +namespace test_util { +class BatchFunctionKernelTestAccess; +} // namespace test_util + +// Records the usage of attribute `enable_large_batch_splitting`. +void RecordBatchSplitUsage( + std::optional maybe_enable_large_batch_splitting, + absl::string_view model_name); + +// Records the number of batch threads of a model. +void RecordBatchParamNumBatchThreads(int64_t num_batch_threads, + absl::string_view model_name); + +// Returns the model name from the context. +absl::string_view GetModelName(OpKernelContext* ctx); + +// `BatchFunctionKernel` is the implementation of op `BatchFunction`. 
+// +// `BatchFunctionKernel` will batch (tensor) inputs by concatenating them +// along the 0-th dimension, schedule a user-defined computation, and then +// splits the returned tensors as batch output. +// +// In particular, an instance of `BatchFunctionKernel` creates or re-uses a +// a batch scheduler instance based on op attributes, pre-processes and enqueues +// concatenated inputs to the scheduler which invokes user-defined function, +// and then splits function output as op output. +// +// User defined function is named by attribute `f` and defined in the graph. +class BatchFunctionKernel : public AsyncOpKernel { + public: + explicit BatchFunctionKernel(OpKernelConstruction* c); + + bool IsExpensive() override; + + void ComputeAsync(OpKernelContext* c, DoneCallback done) final; + + private: + friend class test_util::BatchFunctionKernelTestAccess; + + // Validates 'allowed_batch_sizes_'. The entries must increase monotonically. + // If large batch split is not enabled, the last one must equal + // `max_batch_size_`. otherwise the last element must be smaller than or equal + // to `max_batch_size_`. + absl::Status ValidateAllowedBatchSizes() const; + + // Creates the function handle if it isn't initialized yet; and re-use it + // afterwards. + absl::Status GetOrCreateFunctionHandle( + OpKernelContext* c, FunctionLibraryRuntime::Handle* handle); + + // Instantiate the user-defined function and emits `handle`. + absl::Status InstantiateFunction( + OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) const; + + // Initialize vars by reading from op-kernel-construction. + // Vars + // - enable_adaptive_batch_threads_ + // true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or + // if `num_batch_threads` is not positive. + // - adaptive_batch_scheduler_options_ + // Read from corresponding attributes as long as they are set. + void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c, + int32_t num_batch_threads); + string container_; + string shared_name_; + string batcher_queue_; + int32 num_batch_threads_; + int32 max_batch_size_; + int32 batch_timeout_micros_; + int32 max_enqueued_batches_; + std::vector allowed_batch_sizes_; + int32 low_priority_max_batch_size_; + int32 low_priority_batch_timeout_micros_; + int32 low_priority_max_enqueued_batches_; + std::vector low_priority_allowed_batch_sizes_; + std::string mixed_priority_policy_; + std::string batch_padding_policy_; + NameAttrList func_; + absl::optional fhandle_ TF_GUARDED_BY(mu_); + bool enable_large_batch_splitting_ = false; + bool has_attribute_enable_large_batch_splitting_ = false; + bool enable_adaptive_batch_threads_ = false; + + mutex mu_; + + // Parameters for adaptive batch scheduler only. + // Note 'num_batch_threads_' above is shared by two implementations of batch + // scheduler. 
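The rule documented for ValidateAllowedBatchSizes() above fits in a few lines; the helper below is a hypothetical restatement of that check for illustration, not the kernel's implementation.

#include <cstddef>
#include <vector>

bool AllowedBatchSizesValid(const std::vector<int>& sizes, int max_batch_size,
                            bool enable_large_batch_splitting) {
  for (std::size_t i = 1; i < sizes.size(); ++i) {
    if (sizes[i] <= sizes[i - 1]) return false;  // must increase monotonically
  }
  if (sizes.empty()) return true;
  // With large-batch splitting the last entry only has to bound the batch
  // size; without it, it must equal max_batch_size exactly.
  return enable_large_batch_splitting ? sizes.back() <= max_batch_size
                                      : sizes.back() == max_batch_size;
}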
+ struct AdaptiveBatchSchedulerOptions { + int32 min_in_flight_batches_limit = kMinInflightBatches; + int32 initial_in_flight_batches_limit = kInitialInflightBatches; + int32 max_in_flight_batches_limit = kMaxInflightBatches; + int32 batches_to_average_over = kBatchesToAverageOver; + int64 full_batch_scheduling_boost_micros = -1; + }; + absl::optional + adaptive_batch_scheduler_options_ = absl::nullopt; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCH_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batch_norm_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_norm_op.h new file mode 100644 index 00000000..7341833e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_norm_op.h @@ -0,0 +1,143 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_ +// Functor definition for BatchNormOp, must be compilable by nvcc. +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by BatchNormOp to do the computations. 
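Per element, the BatchNorm functor below evaluates the usual inference-time normalization, with gamma applied only when scale_after_normalization is true. A scalar restatement for reference (BatchNormScalar is an illustrative name, not part of the kernel):

#include <cmath>

float BatchNormScalar(float x, float mean, float var, float beta, float gamma,
                      float variance_epsilon, bool scale_after_normalization) {
  // y = (x - mean) * rsqrt(var + epsilon) [* gamma] + beta
  float inv_stddev = 1.0f / std::sqrt(var + variance_epsilon);
  if (scale_after_normalization) inv_stddev *= gamma;
  return (x - mean) * inv_stddev + beta;
}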
+template +struct BatchNorm { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstVec mean, + typename TTypes::ConstVec var, + typename TTypes::ConstVec beta, + typename TTypes::ConstVec gamma, T variance_epsilon, + bool scale_after_normalization, + typename TTypes::Tensor output) { + const int depth = mean.dimension(0); + const int rest_size = input.size() / depth; + + Eigen::DSizes rest_by_depth(rest_size, depth); + Eigen::IndexList > rest_by_one; + rest_by_one.set(0, rest_size); + Eigen::IndexList, int> one_by_depth; + one_by_depth.set(1, depth); + Eigen::IndexList > depth_by_one; + depth_by_one.set(0, depth); + if (scale_after_normalization) { + output.reshape(rest_by_depth).device(d) = + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one)) * + ((var + var.constant(variance_epsilon)).rsqrt() * gamma) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one) + + beta.reshape(one_by_depth).broadcast(rest_by_one); + } else { + output.reshape(rest_by_depth).device(d) = + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one)) * + ((var + var.constant(variance_epsilon)).rsqrt()) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one) + + beta.reshape(one_by_depth).broadcast(rest_by_one); + } + } +}; + +template +struct BatchNormGrad { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstVec mean, + typename TTypes::ConstVec var, + typename TTypes::ConstVec gamma, + typename TTypes::ConstTensor out_backprop, + T variance_epsilon, bool scale_after_normalization, + typename TTypes::Tensor dx, typename TTypes::Vec dm, + typename TTypes::Vec dv, typename TTypes::Vec db, + typename TTypes::Vec dg, typename TTypes::Vec scratch1, + typename TTypes::Vec scratch2) { + const int depth = mean.dimension(0); + const int rest_size = input.size() / depth; + + typedef typename TTypes::ConstVec::Index Index; + + Eigen::DSizes rest_by_depth(rest_size, depth); + Eigen::IndexList > rest_by_one; + rest_by_one.set(0, rest_size); + Eigen::IndexList, Index> one_by_depth; + one_by_depth.set(1, depth); + Eigen::IndexList > reduction_axis; + + // db = out_backprop + // + // dg = out_backprop * ((x - m) * rsqrt(v + epsilon)) + // + // dv = sum_over_rest(out_backprop * gamma * (x - m)) * + // (-1/2) * (v + epsilon) ^ (-3/2) + // + // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon)) + // + // dx = out_backprop * (gamma * rsqrt(v + epsilon)) + db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis); + + // scratch1 = rsqrt(v + epsilon) + scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt(); + + // scratch2 = sum_over_rest(out_backprop * (x - m)) + scratch2.device(d) = (out_backprop.reshape(rest_by_depth) * + (input.reshape(rest_by_depth) - + mean.reshape(one_by_depth).broadcast(rest_by_one))) + .sum(reduction_axis); + + if (scale_after_normalization) { + dx.reshape(rest_by_depth).device(d) = + out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma) + .eval() + .reshape(one_by_depth) + .broadcast(rest_by_one)); + dm.device(d) = -db * (scratch1 * gamma).eval(); + dg.device(d) = scratch2 * scratch1; + } else { + dx.reshape(rest_by_depth).device(d) = + out_backprop.reshape(rest_by_depth) * + scratch1.reshape(one_by_depth).broadcast(rest_by_one); + dm.device(d) = -db * scratch1; + dg.device(d) = dg.constant(static_cast(0.0)); // Gamma is not learned. 
+ } + + // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2) + scratch1.device(d) = scratch1 * scratch1.constant(static_cast(-0.5f)) / + (var + var.constant(variance_epsilon)); + + if (scale_after_normalization) { + dv.device(d) = scratch2 * (scratch1 * gamma).eval(); + } else { + dv.device(d) = scratch2 * scratch1; + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batch_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_util.h new file mode 100644 index 00000000..dad2ec4e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batch_util.h @@ -0,0 +1,23 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// NOTE(lespeholt): This file is deprecated. Use +// "tensorflow/core/util/batch_util.h" instead. + +#ifndef TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ + +#include "tensorflow/core/util/batch_util.h" + +#endif // TENSORFLOW_CORE_KERNELS_BATCH_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h new file mode 100644 index 00000000..8be441b2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h @@ -0,0 +1,871 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/periodic_function.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/byte_order.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/connected_traceme.h" + +namespace tensorflow { +namespace serving { +namespace internal { +template +class ASBSBatch; + +template +class ASBSQueue; +} // namespace internal + +// Shared batch scheduler designed to minimize latency. The scheduler keeps +// track of a number of queues (one per model or model version) which are +// continuously enqueuing requests. The scheduler groups the requests into +// batches which it periodically sends off for processing (see +// shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler +// (ASBS) prioritizes batches primarily by age (i.e. the batch's oldest request) +// along with a configurable preference for scheduling larger batches first. +// +// +// ASBS tries to keep the system busy by maintaining an adjustable number of +// concurrently processed batches. If a new batch is created, and the number of +// in flight batches is below the target, the next (i.e. oldest) batch is +// immediately scheduled. Similarly, when a batch finishes processing, the +// target is rechecked, and another batch may be scheduled. To avoid the need +// to carefully tune the target for workload, model type, platform, etc, it is +// dynamically adjusted in order to provide the lowest average latency. +// +// Some potential use cases: +// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing +// involves serial processing by a device, from a latency perspective it is +// desirable to keep the device evenly loaded, avoiding the need to wait for +// the device to process prior batches. +// CPU utilization - If the batch processing is cpu dominated, you can reap +// latency gains when underutilized by increasing the processing rate, but +// back the rate off when the load increases to avoid overload. + +template +class AdaptiveSharedBatchScheduler + : public std::enable_shared_from_this< + AdaptiveSharedBatchScheduler> { + public: + ~AdaptiveSharedBatchScheduler() { + // Finish processing batches before destroying other class members. + if (owned_batch_thread_pool_) { + delete batch_thread_pool_; + } + } + + struct Options { + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + // Number of batch processing threads - the maximum value of + // in_flight_batches_limit_. It is recommended that this value be set by + // running the system under load, observing the learned value for + // in_flight_batches_limit_, and setting this maximum to ~ 2x the value. 
+ // Under low load, in_flight_batches_limit_ has no substantial effect on + // latency and therefore undergoes a random walk. Unreasonably large values + // for num_batch_threads allows for large in_flight_batches_limit_, which + // will harm latency for some time once load increases again. + int64_t num_batch_threads = port::MaxParallelism(); + // You can pass a ThreadPool directly rather than the above two + // parameters. If given, the above two parameers are ignored. Ownership of + // the threadpool is not transferred. + thread::ThreadPool* thread_pool = nullptr; + + // Lower bound for in_flight_batches_limit_. As discussed above, can be used + // to minimize the damage caused by the random walk under low load. + int64_t min_in_flight_batches_limit = 1; + // Although batch selection is primarily based on age, this parameter + // specifies a preference for larger batches. A full batch will be + // scheduled before an older, nearly empty batch as long as the age gap is + // less than full_batch_scheduling_boost_micros. The optimal value for this + // parameter should be of order the batch processing latency, but must be + // chosen carefully, as too large a value will harm tail latency. + int64_t full_batch_scheduling_boost_micros = 0; + // The environment to use (typically only overridden by test code). + Env* env = Env::Default(); + // Initial limit for number of batches being concurrently processed. + // Non-integer values correspond to probabilistic limits - i.e. a value of + // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time. + double initial_in_flight_batches_limit = 3; + // Number of batches between adjustments of in_flight_batches_limit. Larger + // numbers will give less noisy latency measurements, but will be less + // responsive to changes in workload. + int64_t batches_to_average_over = 1000; + + // If true, schedule batches using FIFO policy. + // Requires that `full_batch_scheduling_boost_micros` is zero. + // NOTE: + // A new parameter is introduced (not re-using + // full_batch_scheduling_boost_micros==zero) for backward compatibility of + // API. + bool fifo_scheduling = false; + }; + + // Ownership is shared between the caller of Create() and any queues created + // via AddQueue(). + static absl::Status Create( + const Options& options, + std::shared_ptr>* scheduler); + + struct QueueOptions { + // Maximum size of a batch that's formed within + // `ASBSQueue::Schedule`. + int max_batch_size = 1000; + // Maximum size of input task, which is submitted to the queue by + // calling `ASBSQueue::Schedule` and used to form batches. + // + // If specified, it should be larger than or equal to 'max_batch_size'. + absl::optional max_input_task_size = absl::nullopt; + // Maximum number of tasks to add to a specific batch. + absl::optional max_tasks_per_batch = absl::nullopt; + // Maximum number of enqueued (i.e. non-scheduled) batches. + int max_enqueued_batches = 10; + // Amount of time non-full batches must wait before becoming schedulable. + // A non-zero value can improve performance by limiting the scheduling of + // nearly empty batches. + int64_t batch_timeout_micros = 0; + // If non nullptr, split_input_task_func should split input_task into + // multiple tasks, the first of which has size first_size and the remaining + // not exceeding max_size. This function may acquire ownership of input_task + // and should return a status indicating if the split was successful. Upon + // success, the caller can assume that all output_tasks will be scheduled. 
+ // Including this option allows the scheduler to pack batches better and + // should usually improve overall throughput. + std::function* input_task, int first_size, + int max_batch_size, + std::vector>* output_tasks)> + split_input_task_func; + + // If true, the padding will not be appended. + bool disable_padding = false; + }; + + using BatchProcessor = std::function>)>; + + // Adds queue (and its callback) to be managed by this scheduler. + absl::Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); + + double in_flight_batches_limit() { + mutex_lock l(mu_); + return in_flight_batches_limit_; + } + + private: + // access to AddBatch, MaybeScheduleClosedBatches, RemoveQueue, GetEnv. + friend class internal::ASBSQueue; + + explicit AdaptiveSharedBatchScheduler(const Options& options); + + // Tracks processing latency and adjusts in_flight_batches_limit to minimize. + void CallbackWrapper(const internal::ASBSBatch* batch, + BatchProcessor callback, bool is_express); + + // Schedules batch if in_flight_batches_limit_ is not met. + void MaybeScheduleNextBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Schedules batch using FIFO policy if in_flight_batches_limit_ is not met. + void MaybeScheduleNextBatchFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Schedules all closed batches in batches_ for which an idle thread is + // available in batch_thread_pool_. + // Batches scheduled this way are called express batches. + // Express batches are not limited by in_flight_batches_limit_, and + // their latencies will not affect in_flight_batches_limit_. + void MaybeScheduleClosedBatches(); + + void MaybeScheduleClosedBatchesLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + void MaybeScheduleClosedBatchesLockedFIFO() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + void MaybeAdjustInflightLimit() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Notifies scheduler of non-empty batch which is eligible for processing. + void AddBatch(const internal::ASBSBatch* batch); + + // Removes queue from scheduler. + void RemoveQueue(const internal::ASBSQueue* queue); + + Env* GetEnv() const { return options_.env; } + + const Options options_; + + // Collection of batches added by AddBatch, ordered by age. Owned by scheduler + // until they are released for processing. + std::vector*> batches_ TF_GUARDED_BY(mu_); + + // Collection of batches added by AddBatch, ordered by age. Owned by + // scheduler until they are released for processing. + std::deque*> fifo_batches_ + TF_GUARDED_BY(mu_); + + // Unowned queues and callbacks added by AddQueue. + std::unordered_map*, BatchProcessor> + queues_and_callbacks_ TF_GUARDED_BY(mu_); + + mutex mu_; + + // Responsible for running the batch processing callbacks. + thread::ThreadPool* batch_thread_pool_; + + bool owned_batch_thread_pool_ = false; + + // Limit on number of batches which can be concurrently processed. + // Non-integer values correspond to probabilistic limits - i.e. a value of 3.2 + // results in an actual cap of 3 80% of the time, and 4 20% of the time. + double in_flight_batches_limit_ TF_GUARDED_BY(mu_); + + // Number of regular batches currently being processed. + int64_t in_flight_batches_ TF_GUARDED_BY(mu_) = 0; + // Number of express batches currently being processed. + int64_t in_flight_express_batches_ TF_GUARDED_BY(mu_) = 0; + + // RNG engine and distribution. 
+ std::default_random_engine rand_engine_; + std::uniform_real_distribution rand_double_; + + // Fields controlling the dynamic adjustment of in_flight_batches_limit_. + // Number of batches since the last in_flight_batches_limit_ adjustment. + int64_t batch_count_ TF_GUARDED_BY(mu_) = 0; + + struct DelayStats { + // Sum of processing latency for batches counted by batch_count_. + int64_t batch_latency_sum = 0; + // Average batch latency for previous value of in_flight_batches_limit_. + double last_avg_latency_ms = 0; + // Did last_avg_latency_ms decrease from the previous last_avg_latency_ms? + bool last_latency_decreased = false; + // Current direction (+-) to adjust in_flight_batches_limit_ + int step_direction = 1; + }; + + // Delay stats between the creation of a batch and the completion of a + // batch. + DelayStats batch_delay_stats_ TF_GUARDED_BY(mu_); + + // Max adjustment size (as a fraction of in_flight_batches_limit_). + constexpr static double kMaxStepSizeMultiplier = 0.125; // 1/8; + // Min adjustment size (as a fraction of in_flight_batches_limit_). + constexpr static double kMinStepSizeMultiplier = 0.0078125; // 1/128 + // Current adjustment size (as a fraction of in_flight_batches_limit_). + double step_size_multiplier_ TF_GUARDED_BY(mu_) = kMaxStepSizeMultiplier; + + AdaptiveSharedBatchScheduler(const AdaptiveSharedBatchScheduler&) = delete; + void operator=(const AdaptiveSharedBatchScheduler&) = delete; +}; + +////////////////////////////////////////////////////////// +// Implementation details follow. API users need not read. + +namespace internal { +// Consolidates tasks into batches, passing them off to the +// AdaptiveSharedBatchScheduler for processing. +template +class ASBSQueue : public BatchScheduler { + public: + using QueueOptions = + typename AdaptiveSharedBatchScheduler::QueueOptions; + + ASBSQueue(std::shared_ptr> scheduler, + const QueueOptions& options); + + ~ASBSQueue() override; + + // Adds task to current batch. Fails if the task size is larger than the batch + // size or if the current batch is full and this queue's number of outstanding + // batches is at its maximum. + absl::Status Schedule(std::unique_ptr* task) override; + + // Number of tasks waiting to be scheduled. + size_t NumEnqueuedTasks() const override; + + // Number of size 1 tasks which could currently be scheduled without failing. + size_t SchedulingCapacity() const override; + + // Notifies queue that a batch is about to be scheduled; the queue should not + // place any more tasks in this batch. + void ReleaseBatch(const ASBSBatch* batch); + + size_t max_task_size() const override { return options_.max_batch_size; } + + private: + // Number of size 1 tasks which could currently be scheduled without failing. + size_t SchedulingCapacityLocked() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns uint64 one greater than was returned by the previous call. + // Context id is reused after std::numeric_limits::max is exhausted. + static uint64 NewTraceMeContextIdForBatch(); + + std::shared_ptr> scheduler_; + const QueueOptions options_; + // Owned by scheduler_. + ASBSBatch* current_batch_ TF_GUARDED_BY(mu_) = nullptr; + int64_t num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0; + int64_t num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0; + mutable mutex mu_; + ASBSQueue(const ASBSQueue&) = delete; + void operator=(const ASBSQueue&) = delete; +}; + +// Batch which remembers when and by whom it was created. 
+template +class ASBSBatch : public Batch { + public: + ASBSBatch(ASBSQueue* queue, int64_t creation_time_micros, + int64_t batch_timeout_micros, uint64 traceme_context_id) + : queue_(queue), + creation_time_micros_(creation_time_micros), + schedulable_time_micros_(creation_time_micros + batch_timeout_micros), + traceme_context_id_(traceme_context_id) {} + + ~ASBSBatch() override {} + + ASBSQueue* queue() const { return queue_; } + + int64_t creation_time_micros() const { return creation_time_micros_; } + + int64_t schedulable_time_micros() const { return schedulable_time_micros_; } + + uint64 traceme_context_id() const { return traceme_context_id_; } + + private: + ASBSQueue* queue_; + const int64_t creation_time_micros_; + const int64_t schedulable_time_micros_; + const uint64 traceme_context_id_; + ASBSBatch(const ASBSBatch&) = delete; + void operator=(const ASBSBatch&) = delete; +}; +} // namespace internal + +// ---------------- AdaptiveSharedBatchScheduler ---------------- + +template +constexpr double AdaptiveSharedBatchScheduler::kMaxStepSizeMultiplier; + +template +constexpr double AdaptiveSharedBatchScheduler::kMinStepSizeMultiplier; + +template +absl::Status AdaptiveSharedBatchScheduler::Create( + const Options& options, + std::shared_ptr>* scheduler) { + if (options.num_batch_threads < 1) { + return errors::InvalidArgument("num_batch_threads must be positive; was ", + options.num_batch_threads); + } + if (options.min_in_flight_batches_limit < 1) { + return errors::InvalidArgument( + "min_in_flight_batches_limit must be >= 1; was ", + options.min_in_flight_batches_limit); + } + if (options.min_in_flight_batches_limit > options.num_batch_threads) { + return errors::InvalidArgument( + "min_in_flight_batches_limit (", options.min_in_flight_batches_limit, + ") must be <= num_batch_threads (", options.num_batch_threads, ")"); + } + if (options.full_batch_scheduling_boost_micros < 0) { + return errors::InvalidArgument( + "full_batch_scheduling_boost_micros can't be negative; was ", + options.full_batch_scheduling_boost_micros); + } + if (options.initial_in_flight_batches_limit > options.num_batch_threads) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit (", + options.initial_in_flight_batches_limit, + ") should not be larger than num_batch_threads (", + options.num_batch_threads, ")"); + } + if (options.initial_in_flight_batches_limit < + options.min_in_flight_batches_limit) { + return errors::InvalidArgument("initial_in_flight_batches_limit (", + options.initial_in_flight_batches_limit, + "must be >= min_in_flight_batches_limit (", + options.min_in_flight_batches_limit, ")"); + } + if (options.batches_to_average_over < 1) { + return errors::InvalidArgument( + "batches_to_average_over should be " + "greater than or equal to 1; was ", + options.batches_to_average_over); + } + scheduler->reset(new AdaptiveSharedBatchScheduler(options)); + return absl::OkStatus(); +} + +template +AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( + const Options& options) + : options_(options), + in_flight_batches_limit_(options.initial_in_flight_batches_limit), + rand_double_(0.0, 1.0) { + std::random_device device; + rand_engine_.seed(device()); + if (options.thread_pool == nullptr) { + owned_batch_thread_pool_ = true; + batch_thread_pool_ = new thread::ThreadPool( + GetEnv(), options.thread_pool_name, options.num_batch_threads); + } else { + owned_batch_thread_pool_ = false; + batch_thread_pool_ = options.thread_pool; + } +} + +template +absl::Status 
AdaptiveSharedBatchScheduler::AddQueue( + const QueueOptions& options, BatchProcessor process_batch_callback, + std::unique_ptr>* queue) { + if (options.max_batch_size <= 0) { + return errors::InvalidArgument("max_batch_size must be positive; was ", + options.max_batch_size); + } + if (options.max_enqueued_batches <= 0) { + return errors::InvalidArgument( + "max_enqueued_batches must be positive; was ", + options.max_enqueued_batches); + } + if (options.max_input_task_size.has_value()) { + if (options.max_input_task_size.value() < options.max_batch_size) { + return errors::InvalidArgument( + "max_input_task_size must be larger than or equal to max_batch_size;" + "got max_input_task_size as ", + options.max_input_task_size.value(), " and max_batch_size as ", + options.max_batch_size); + } + } + internal::ASBSQueue* asbs_queue_raw; + queue->reset(asbs_queue_raw = new internal::ASBSQueue( + this->shared_from_this(), options)); + mutex_lock l(mu_); + queues_and_callbacks_[asbs_queue_raw] = process_batch_callback; + return absl::OkStatus(); +} + +template +void AdaptiveSharedBatchScheduler::AddBatch( + const internal::ASBSBatch* batch) { + mutex_lock l(mu_); + if (options_.fifo_scheduling) { + fifo_batches_.push_back(batch); + } else { + batches_.push_back(batch); + } + int64_t delay_micros = + batch->schedulable_time_micros() - GetEnv()->NowMicros(); + if (delay_micros <= 0) { + MaybeScheduleNextBatch(); + return; + } + // Try to schedule batch once it becomes schedulable. Although scheduler waits + // for all batches to finish processing before allowing itself to be deleted, + // MaybeScheduleNextBatch() is called in other places, and therefore it's + // possible the scheduler could be deleted by the time this closure runs. + // Grab a shared_ptr reference to prevent this from happening. + GetEnv()->SchedClosureAfter( + delay_micros, [this, lifetime_preserver = this->shared_from_this()] { + mutex_lock l(mu_); + MaybeScheduleNextBatch(); + }); +} + +template +void AdaptiveSharedBatchScheduler::RemoveQueue( + const internal::ASBSQueue* queue) { + mutex_lock l(mu_); + queues_and_callbacks_.erase(queue); +} + +template +void AdaptiveSharedBatchScheduler::MaybeScheduleNextBatchFIFO() { + const internal::ASBSBatch* batch = *fifo_batches_.begin(); + if (batch->schedulable_time_micros() > GetEnv()->NowMicros()) { + return; + } + fifo_batches_.pop_front(); + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule(std::bind( + &AdaptiveSharedBatchScheduler::CallbackWrapper, this, batch, + queues_and_callbacks_[batch->queue()], false /* is express */)); + in_flight_batches_++; +} + +template +void AdaptiveSharedBatchScheduler< + TaskType>::MaybeScheduleClosedBatchesLockedFIFO() { + // Only schedule closed batches if we have spare capacity. + int available_threads = + static_cast(options_.num_batch_threads - in_flight_batches_ - + in_flight_express_batches_); + for (auto it = fifo_batches_.begin(); + it != fifo_batches_.end() && available_threads > 0; + it = fifo_batches_.begin()) { + if ((*it)->IsClosed()) { + const internal::ASBSBatch* batch = *it; + fifo_batches_.pop_front(); + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule( + std::bind(&AdaptiveSharedBatchScheduler::CallbackWrapper, + this, batch, queues_and_callbacks_[batch->queue()], true)); + in_flight_express_batches_++; + available_threads--; + } else { + // Batches are FIFO, so stop iteration after finding the first non-closed + // batches. 
+ break; + } + } +} + +template +void AdaptiveSharedBatchScheduler::MaybeScheduleNextBatch() { + bool batch_empty = + options_.fifo_scheduling ? fifo_batches_.empty() : batches_.empty(); + if (batch_empty || in_flight_batches_ >= in_flight_batches_limit_) return; + // Non-integer limit handled probabilistically. + if (in_flight_batches_limit_ - in_flight_batches_ < 1 && + rand_double_(rand_engine_) > + in_flight_batches_limit_ - in_flight_batches_) { + return; + } + + if (options_.fifo_scheduling) { + MaybeScheduleNextBatchFIFO(); + return; + } + + auto best_it = batches_.end(); + double best_score = (std::numeric_limits::max)(); + int64_t now_micros = GetEnv()->NowMicros(); + for (auto it = batches_.begin(); it != batches_.end(); it++) { + if ((*it)->schedulable_time_micros() > now_micros) continue; + const double score = + (*it)->creation_time_micros() - + options_.full_batch_scheduling_boost_micros * (*it)->size() / + static_cast((*it)->queue()->max_task_size()); + if (best_it == batches_.end() || score < best_score) { + best_score = score; + best_it = it; + } + } + // No schedulable batches. + if (best_it == batches_.end()) return; + const internal::ASBSBatch* batch = *best_it; + batches_.erase(best_it); + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule( + std::bind(&AdaptiveSharedBatchScheduler::CallbackWrapper, this, + batch, queues_and_callbacks_[batch->queue()], false)); + in_flight_batches_++; +} + +template +void AdaptiveSharedBatchScheduler::MaybeScheduleClosedBatches() { + mutex_lock l(mu_); + MaybeScheduleClosedBatchesLocked(); +} + +template +void AdaptiveSharedBatchScheduler< + TaskType>::MaybeScheduleClosedBatchesLocked() { + if (options_.fifo_scheduling) { + MaybeScheduleClosedBatchesLockedFIFO(); + return; + } + // Only schedule closed batches if we have spare capacity. + int available_threads = + static_cast(options_.num_batch_threads - in_flight_batches_ - + in_flight_express_batches_); + for (auto it = batches_.begin(); + it != batches_.end() && available_threads > 0;) { + if ((*it)->IsClosed()) { + const internal::ASBSBatch* batch = *it; + it = batches_.erase(it); + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule( + std::bind(&AdaptiveSharedBatchScheduler::CallbackWrapper, + this, batch, queues_and_callbacks_[batch->queue()], true)); + in_flight_express_batches_++; + available_threads--; + } else { + ++it; + } + } +} + +template +void AdaptiveSharedBatchScheduler::CallbackWrapper( + const internal::ASBSBatch* batch, + AdaptiveSharedBatchScheduler::BatchProcessor callback, + bool is_express) { + tsl::profiler::TraceMeConsumer trace_me( + [&] { + return profiler::TraceMeEncode( + "ProcessBatch", {{"batch_size_before_padding", batch->size()}, + {"_r", 2} /*root_event*/}); + }, + tsl::profiler::ContextType::kAdaptiveSharedBatchScheduler, + batch->traceme_context_id()); + const int64_t start_time = batch->creation_time_micros(); + callback(std::unique_ptr>( + const_cast*>(batch))); + int64_t end_time = GetEnv()->NowMicros(); + mutex_lock l(mu_); + if (is_express) { + in_flight_express_batches_--; + MaybeScheduleClosedBatchesLocked(); + return; + } + in_flight_batches_--; + batch_count_++; + batch_delay_stats_.batch_latency_sum += end_time - start_time; + + MaybeAdjustInflightLimit(); + + MaybeScheduleNextBatch(); +} + +template +void AdaptiveSharedBatchScheduler::MaybeAdjustInflightLimit() { + // Occasionally adjust in_flight_batches_limit_ to minimize average latency. 
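+  // Numeric intuition (illustrative values): with an in_flight_batches_limit_
+  // of 4.0 and the initial step_size_multiplier_ of 1/8, each adjustment moves
+  // the limit by +/- 0.5. The multiplier doubles (capped at 1/8) when latency
+  // keeps improving while stepping in the same direction, halves (floored at
+  // 1/128) when the improvement only came after reversing direction, and the
+  // direction flips whenever the measured latency got worse.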
+ // Although the optimal value may depend on the workload, the latency should + // be a simple convex function of in_flight_batches_limit_, allowing us to + // locate the global minimum relatively quickly. + if (batch_count_ == options_.batches_to_average_over) { + double current_avg_latency_ms = + (batch_delay_stats_.batch_latency_sum / 1000.) / batch_count_; + bool current_latency_decreased = + current_avg_latency_ms < batch_delay_stats_.last_avg_latency_ms; + if (current_latency_decreased) { + // If latency improvement was because we're moving in the correct + // direction, increase step_size so that we can get to the minimum faster. + // If latency improvement was due to backtracking from a previous failure, + // decrease step_size in order to refine our location. + step_size_multiplier_ *= + (batch_delay_stats_.last_latency_decreased ? 2 : 0.5); + step_size_multiplier_ = + std::min(step_size_multiplier_, kMaxStepSizeMultiplier); + step_size_multiplier_ = + std::max(step_size_multiplier_, kMinStepSizeMultiplier); + } else { + // Return (nearly) to previous position and confirm that latency is better + // there before decreasing step size. + batch_delay_stats_.step_direction = -batch_delay_stats_.step_direction; + } + in_flight_batches_limit_ += batch_delay_stats_.step_direction * + in_flight_batches_limit_ * + step_size_multiplier_; + in_flight_batches_limit_ = + std::min(in_flight_batches_limit_, + static_cast(options_.num_batch_threads)); + in_flight_batches_limit_ = + std::max(in_flight_batches_limit_, + static_cast(options_.min_in_flight_batches_limit)); + batch_delay_stats_.last_avg_latency_ms = current_avg_latency_ms; + batch_delay_stats_.last_latency_decreased = current_latency_decreased; + batch_count_ = 0; + batch_delay_stats_.batch_latency_sum = 0; + } +} + +// ---------------- ASBSQueue ---------------- + +namespace internal { +template +ASBSQueue::ASBSQueue( + std::shared_ptr> scheduler, + const QueueOptions& options) + : scheduler_(scheduler), options_(options) {} + +template +ASBSQueue::~ASBSQueue() { + // Wait until last batch has been scheduled. + const int kSleepMicros = 1000; + for (;;) { + { + mutex_lock l(mu_); + if (num_enqueued_batches_ == 0) { + break; + } + } + scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros); + } + scheduler_->RemoveQueue(this); +} + +template +absl::Status ASBSQueue::Schedule(std::unique_ptr* task) { + size_t size = (*task)->size(); + if (options_.split_input_task_func == nullptr && + size > options_.max_batch_size) { + return errors::InvalidArgument("Task size ", size, + " is larger than maximum batch size ", + options_.max_batch_size); + } + if (options_.max_input_task_size.has_value() && + (size > options_.max_input_task_size.value())) { + return errors::InvalidArgument("Task size ", size, + " is larger than max input task size ", + options_.max_input_task_size.value()); + } + + std::vector> tasks_to_schedule; + std::vector*> new_batches; + bool closed_batch = false; + { + mutex_lock l(mu_); + if (size > SchedulingCapacityLocked()) { + return errors::Unavailable("The batch scheduling queue is full"); + } + + int remaining_batch_size = + current_batch_ == nullptr + ? options_.max_batch_size + : options_.max_batch_size - current_batch_->size(); + if (options_.split_input_task_func == nullptr || + size <= remaining_batch_size) { + // Either we don't allow task splitting or task fits within the current + // batch. + tasks_to_schedule.push_back(std::move(*task)); + } else { + // Split task in order to completely fill the current batch. 
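+      // For example (illustrative numbers): with max_batch_size = 100, a
+      // current batch holding 70, and an incoming task of size 50,
+      // split_input_task_func is asked for a first piece of size 30, which
+      // tops off (and thereby closes) the current batch, while the remaining
+      // 20 starts a new batch.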
+ // Beyond this point Schedule should not fail, as the caller has been + // promised that all of the split tasks will be scheduled. + TF_RETURN_IF_ERROR(options_.split_input_task_func( + task, remaining_batch_size, options_.max_batch_size, + &tasks_to_schedule)); + } + for (auto& task : tasks_to_schedule) { + // Can't fit within current batch, close it off and try to create another. + if (current_batch_ && + current_batch_->size() + task->size() > options_.max_batch_size) { + current_batch_->Close(); + closed_batch = true; + current_batch_ = nullptr; + } + if (!current_batch_) { + num_enqueued_batches_++; + // batch.traceme_context_id connects TraceMeProducer and + // TraceMeConsumer. + // When multiple calls to "ASBS::Schedule" accumulate to one batch, they + // are processed in the same batch and should share traceme_context_id. + current_batch_ = new ASBSBatch( + this, scheduler_->GetEnv()->NowMicros(), + options_.batch_timeout_micros, NewTraceMeContextIdForBatch()); + new_batches.push_back(current_batch_); + } + + // Annotate each task (corresponds to one call of schedule) with a + // TraceMeProducer. + tsl::profiler::TraceMeProducer trace_me( + [task_size = task->size()] { + return profiler::TraceMeEncode( + "ASBSQueue::Schedule", + {{"batching_input_task_size", task_size}}); + }, + tsl::profiler::ContextType::kAdaptiveSharedBatchScheduler, + this->current_batch_->traceme_context_id()); + current_batch_->AddTask(std::move(task)); + num_enqueued_tasks_++; + // If current_batch_ is now full, allow it to be processed immediately. + bool reached_max_tasks = + (options_.max_tasks_per_batch.has_value() && + current_batch_->num_tasks() >= options_.max_tasks_per_batch.value()); + if (current_batch_->size() == options_.max_batch_size || + reached_max_tasks) { + current_batch_->Close(); + closed_batch = true; + current_batch_ = nullptr; + } + } + } + // Scheduler functions must be called outside of lock, since they may call + // ReleaseBatch. + for (auto* batch : new_batches) { + scheduler_->AddBatch(batch); + } + if (closed_batch) { + scheduler_->MaybeScheduleClosedBatches(); + } + return absl::OkStatus(); +} + +template +void ASBSQueue::ReleaseBatch(const ASBSBatch* batch) { + mutex_lock l(mu_); + num_enqueued_batches_--; + num_enqueued_tasks_ -= batch->num_tasks(); + if (batch == current_batch_) { + current_batch_->Close(); + current_batch_ = nullptr; + } +} + +template +size_t ASBSQueue::NumEnqueuedTasks() const { + mutex_lock l(mu_); + return num_enqueued_tasks_; +} + +template +size_t ASBSQueue::SchedulingCapacity() const { + mutex_lock l(mu_); + return SchedulingCapacityLocked(); +} + +template +size_t ASBSQueue::SchedulingCapacityLocked() const { + const int current_batch_capacity = + current_batch_ ? 
options_.max_batch_size - current_batch_->size() : 0; + const int spare_batches = + options_.max_enqueued_batches - num_enqueued_batches_; + return spare_batches * options_.max_batch_size + current_batch_capacity; +} + +template +// static +uint64 ASBSQueue::NewTraceMeContextIdForBatch() { + static std::atomic traceme_context_id(0); + return traceme_context_id.fetch_add(1, std::memory_order_relaxed); +} +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h new file mode 100644 index 00000000..a0665503 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h @@ -0,0 +1,366 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ + +#include + +#include +#include +#include +#include + +#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" + +namespace tensorflow { +namespace serving { + +// A BatchScheduler implementation geared toward handling a single request type +// running on a specific set of hardware resources. A typical scenario is one in +// which all requests invoke the same machine-learned model on one GPU. +// +// If there are, say, two GPUs and two models each bound to one of the GPUs, one +// could use two BasicBatchScheduler instances to schedule the two model/GPU +// combinations independently. If multiple models must share a given GPU or +// other hardware resource, consider using SharedBatchScheduler instead. +// +// +// PARAMETERS AND BEHAVIOR: +// +// BasicBatchScheduler runs a fixed pool of threads, which it uses to process +// batches of tasks. It enforces a maximum batch size, and enqueues a bounded +// number of tasks. If the queue is nearly empty, such that a full batch cannot +// be formed, when a thread becomes free, it anyway schedules a batch +// immediately if a task has been in the queue for longer than a given timeout +// parameter. If the timeout parameter is set to 0, then the batch threads will +// always be kept busy (unless there are zero tasks waiting to be processed). +// +// For online serving, it is recommended to set the maximum number of enqueued +// batches worth of tasks equal to the number of batch threads, which allows +// enqueuing of enough tasks s.t. if every thread becomes available it can be +// kept busy, but no more. For bulk processing jobs and throughput-oriented +// benchmarks, you may want to set it much higher. 
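+//
+// A minimal usage sketch (illustrative; `CountTask` stands for any BatchTask
+// subclass whose size() reports an item count):
+//
+//   BasicBatchScheduler<CountTask>::Options options;
+//   options.max_batch_size = 32;
+//   options.max_execution_batch_size = 32;  // equal, since splitting is off
+//   options.batch_timeout_micros = 2000;
+//   options.num_batch_threads = 4;
+//   options.max_enqueued_batches = 4;
+//   std::unique_ptr<BasicBatchScheduler<CountTask>> scheduler;
+//   TF_CHECK_OK(BasicBatchScheduler<CountTask>::Create(
+//       options,
+//       [](std::unique_ptr<Batch<CountTask>> batch) {
+//         // `batch` arrives closed; process all of its tasks here.
+//       },
+//       &scheduler));
+//   auto task = std::make_unique<CountTask>(/*n=*/1);
+//   TF_CHECK_OK(scheduler->Schedule(&task));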
+// +// When Schedule() is called, if the queue is full the call will fail with an +// UNAVAILABLE error (after which the client may retry again later). If the call +// succeeds, the maximum time the task will spend in the queue before being +// placed in a batch and assigned to a thread for processing, is the greater of: +// - the maximum time to process ceil(max_enqueued_batches/num_batch_threads) +// (1 in the recommended configuration) batches of previously-submitted tasks +// - the configured timeout parameter (which can be 0, as mentioned above) +// +// Unlike StreamingBatchScheduler, when BasicBatchScheduler assigns a batch to a +// thread, it closes the batch. The process-batch callback may assume that every +// batch it receives is closed at the outset. +// +// +// RECOMMENDED USE-CASES: +// +// BasicBatchScheduler is suitable for use-cases that feature a single kind of +// request (e.g. a server performing inference with a single machine-learned +// model, possibly evolving over time), with loose versioning semantics. +// Concretely, the following conditions should hold: +// +// A. All requests batched onto a given resource (e.g. a hardware accelerator, +// or a pool accelerators) are of the same type. For example, they all +// invoke the same machine-learned model. +// +// These variations are permitted: +// - The model may reside in a single servable, or it may be spread across +// multiple servables that are used in unison (e.g. a vocabulary lookup +// table servable and a tensorflow session servable). +// - The model's servable(s) may be static, or they may evolve over time +// (successive servable versions). +// - Zero or more of the servables are used in the request thread; the rest +// are used in the batch thread. In our running example, the vocabulary +// lookups and tensorflow runs may both be performed in the batch thread, +// or alternatively the vocabulary lookup may occur in the request thread +// with only the tensorflow run performed in the batch thread. +// +// In contrast, BasicBatchScheduler is not a good fit if the server +// hosts multiple distinct models running on a pool accelerators, with each +// request specifying which model it wants to use. BasicBatchScheduler +// has no facility to time-multiplex the batch threads across multiple +// models in a principled way. More basically, it cannot ensure that a given +// batch doesn't contain a mixture of requests for different models. +// +// B. Requests do not specify a particular version of the servable(s) that must +// be used. Instead, each request is content to use the "latest" version. +// +// BasicBatchScheduler does not constrain which requests get grouped +// together into a batch, so using this scheduler there is no way to achieve +// cohesion of versioned requests to version-specific batches. +// +// C. No servable version coordination needs to be performed between the +// request threads and the batch threads. Often, servables are only used in +// the batch threads, in which case this condition trivially holds. If +// servables are used in both threads, then the use-case must tolerate +// version skew across the servables used in the two kinds of threads. +// +// +// EXAMPLE USE-CASE FLOW: +// +// For such use-cases, request processing via BasicBatchScheduler generally +// follows this flow (given for illustration; variations are possible): +// 1. Optionally perform some pre-processing on each request in the request +// threads. +// 2. 
Route the requests to the batch scheduler, as batching::Task objects. +// (Since all requests are of the same type and are not versioned, the +// scheduler is free to group them into batches arbitrarily.) +// 3. Merge the requests into a single batched representation B. +// 4. Obtain handles to the servable(s) needed to process B. The simplest +// approach is to obtain the latest version of each servable. Alternatively, +// if cross-servable consistency is required (e.g. the vocabulary lookup +// table's version number must match that of the tensorflow session), +// identify an appropriate version number and obtain the servable handles +// accordingly. +// 5. Process B using the obtained servable handles, and split the result into +// individual per-request units. +// 6. Perform any post-processing in the batch thread and/or request thread. +// +// +// PERFORMANCE TUNING: See README.md. +// +template +class BasicBatchScheduler : public BatchScheduler { + public: + // TODO(b/25089730): Tune defaults based on best practices as they develop. + // (Keep them mirrored to the ones in SharedBatchScheduler::QueueOptions and + // SharedBatchScheduler::Options.) + struct Options { + // Options related with (underlying) shared batch scheduler. + // 'thread_pool_name' and 'num_batch_threads' are used to initialize + // a shared batch scheduler underlyingly iff 'shared_batch_scheduler' is + // nullptr. + // + // There are two ways to specify threading: + // 1) Have each session create its own pool. + // 2) Have multiple sessions share the same pool. + // + // In general, the number of threads should be tied to roughly the number of + // compute resources (CPU cores or accelerator cores) backing the threads. + // Sharing a thread pool helps alleviate potential over allocation of + // threads to limited compute resources. + + // To have each session create its own thread pool (1) set + // thread_pool_name/num_batch_threads. + + // To share a thread pool (2) create a scheduler and pass it in. + + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + + // The number of threads to use to process batches. + // Must be >= 1, and should be tuned carefully. + int num_batch_threads = port::MaxParallelism(); + + // If specified, this scheduler will be used underlyingly to schedule + // batches. Note setting this means `thread_pool_name` and + // `num_batch_threads` are ignored. + std::shared_ptr> shared_batch_scheduler = + nullptr; + + // Options for queue. + // The maximum size of each batch. + // + // The scheduler may form batches of any size between 1 and this number + // (inclusive). If there is a need to quantize the batch sizes, i.e. only + // submit batches whose size is in a small set of allowed sizes, that can be + // done by adding padding in the process-batch callback. + int max_batch_size = 1000; + + // If a task has been enqueued for this amount of time (in microseconds), + // and a thread is available, the scheduler will immediately form a batch + // from enqueued tasks and assign the batch to the thread for processing, + // even if the batch's size is below 'max_batch_size'. + // + // This parameter offers a way to bound queue latency, so that a task isn't + // stuck in the queue indefinitely waiting for enough tasks to arrive to + // make a full batch. (The latency bound is given in the class documentation + // above.) + // + // The goal is to smooth out batch sizes under low request rates, and thus + // avoid latency spikes. 
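+    //
+    // For example (illustrative value): with batch_timeout_micros = 2000, a
+    // lone task waits at most roughly 2 ms for companions before being
+    // processed in a smaller (possibly size-1) batch, provided a batch thread
+    // is free.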
+ int64_t batch_timeout_micros = 0; + + // The maximum allowable number of enqueued (accepted by Schedule() but + // not yet being processed on a batch thread) tasks in terms of batches. + // If this limit is reached, Schedule() will return an UNAVAILABLE error. + // See the class documentation above for guidelines on how to tune this + // parameter. + int max_enqueued_batches = 10; + + // If true, an input task (i.e., input of `BasicBatchScheduler::Schedule`) + // with a large size (i.e., larger than the largest value of + // `allowed_batch_sizes`) will be split into multiple smaller batch tasks + // and possibly put into different batches for processing. If false, each + // input task is put into one batch as a whole for processing. + // + // API note: + // The value of this option doesn't affect processing output given the same + // input; it affects implementation details as stated below: + // 1. Improve batching efficiency by eliminating unnecessary padding in the + // following scenario: when an open batch has M slots while an input of size + // N is scheduled (M < N), the input can be split to fill remaining slots + // of an open batch as opposed to padding. + // 2.`max_batch_size` specifies the limit of input and + // `max_execution_batch_size` specifies the limit of a task to be processed. + // API user can give an input of size 128 when 'max_execution_batch_size' + // is 32 -> implementation can split input of 128 into 4 x 32, schedule + // concurrent processing, and then return concatenated results corresponding + // to 128. + bool enable_large_batch_splitting = false; + + // `split_input_task_func` specifies how to split `input_task` into + // `output_tasks`. + // + // `input_task`: a unit of task to be split. + // `first_output_task_size`: task size of first output. + // `max_batch_size`: Maximum size of each batch. + // `output_tasks`: A list of output tasks after split. + // + // REQUIRED: + // 1) All `output_tasks` should be non-empty tasks. + // 2) Sizes of `output_tasks` add up to size of `input_task`. + // + // NOTE: + // Instantiations of `TaskType` may vary, so it's up to caller to define + // how (e.g., which members to access) to split input tasks. + std::function* input_task, int first_output_task_size, + int input_batch_size_limit, + std::vector>* output_tasks)> + split_input_task_func; + + // The maximum size of each enqueued batch (i.e., in `batches_`). + // + // The scheduler may form batches of any size between 1 and this number + // (inclusive). If there is a need to quantize the batch sizes, i.e. only + // submit batches whose size is in a small set of allowed sizes, that can be + // done by adding padding in the process-batch callback. + // + // REQUIRES: + // - If enable_large_batch_splitting is true, `max_execution_batch_size` is + // less than or equal to `max_batch_size`. + // - If enable_large_batch_splitting is false, `max_execution_batch_size` is + // equal to `max_batch_size`. + int max_execution_batch_size = 10; + + // The following options are typically only overridden by test code. + + // The environment to use. 
+ Env* env = Env::Default(); + }; + static absl::Status Create( + const Options& options, + std::function>)> + process_batch_callback, + std::unique_ptr* scheduler); + + ~BasicBatchScheduler() override = default; + + absl::Status Schedule(std::unique_ptr* task) override; + size_t NumEnqueuedTasks() const override; + size_t SchedulingCapacity() const override; + + size_t max_task_size() const override { + return shared_scheduler_queue_->max_task_size(); + } + + private: + explicit BasicBatchScheduler( + std::unique_ptr> shared_scheduler_queue); + + // This class is merely a thin wrapper around a SharedBatchScheduler with a + // single queue. + std::unique_ptr> shared_scheduler_queue_; + + BasicBatchScheduler(const BasicBatchScheduler&) = delete; + void operator=(const BasicBatchScheduler&) = delete; +}; + +////////// +// Implementation details follow. API users need not read. + +template +absl::Status BasicBatchScheduler::Create( + const Options& options, + std::function>)> + process_batch_callback, + std::unique_ptr* scheduler) { + std::shared_ptr> shared_scheduler; + + if (options.shared_batch_scheduler == nullptr) { + typename SharedBatchScheduler::Options shared_scheduler_options; + shared_scheduler_options.thread_pool_name = options.thread_pool_name; + shared_scheduler_options.num_batch_threads = options.num_batch_threads; + shared_scheduler_options.env = options.env; + + TF_RETURN_IF_ERROR(SharedBatchScheduler::Create( + shared_scheduler_options, &shared_scheduler)); + } else { + shared_scheduler = options.shared_batch_scheduler; + } + + typename SharedBatchScheduler::QueueOptions + shared_scheduler_queue_options; + shared_scheduler_queue_options.input_batch_size_limit = + options.max_batch_size; + shared_scheduler_queue_options.batch_timeout_micros = + options.batch_timeout_micros; + shared_scheduler_queue_options.max_enqueued_batches = + options.max_enqueued_batches; + shared_scheduler_queue_options.enable_large_batch_splitting = + options.enable_large_batch_splitting; + shared_scheduler_queue_options.split_input_task_func = + options.split_input_task_func; + shared_scheduler_queue_options.max_execution_batch_size = + options.max_execution_batch_size; + std::unique_ptr> shared_scheduler_queue; + TF_RETURN_IF_ERROR(shared_scheduler->AddQueue(shared_scheduler_queue_options, + process_batch_callback, + &shared_scheduler_queue)); + + scheduler->reset( + new BasicBatchScheduler(std::move(shared_scheduler_queue))); + return absl::OkStatus(); +} + +template +absl::Status BasicBatchScheduler::Schedule( + std::unique_ptr* task) { + return shared_scheduler_queue_->Schedule(task); +} + +template +size_t BasicBatchScheduler::NumEnqueuedTasks() const { + return shared_scheduler_queue_->NumEnqueuedTasks(); +} + +template +size_t BasicBatchScheduler::SchedulingCapacity() const { + return shared_scheduler_queue_->SchedulingCapacity(); +} + +template +BasicBatchScheduler::BasicBatchScheduler( + std::unique_ptr> shared_scheduler_queue) + : shared_scheduler_queue_(std::move(shared_scheduler_queue)) {} + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BASIC_BATCH_SCHEDULER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_input_task.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_input_task.h new file mode 100644 index 00000000..4f50f1da --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_input_task.h @@ -0,0 +1,267 @@ +/* Copyright 2021 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/container/fixed_array.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/concat_split_util.h" +#include "tensorflow/core/kernels/batching_util/input_split_metadata.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/util/incremental_barrier.h" + +namespace tensorflow { +namespace serving { + +namespace internal { +template +class BatchInputTaskHandleTestAccess; + +template +class BatchInputTaskTestAccess; + +template +class BatchInputTask; + +// A RAII-style object that holds a ref-counted batch-input-task, and +// represents a slice of batch-input-task. + +// To be handed out to callers of `BatchInputTask::ToTaskHandles` quickly +// (i.e. not necessarily waiting for input split) +// +// `BatchInputTaskHandle::GetSplitTask` evaluates to the slice of task. +template +class BatchInputTaskHandle : public BatchTask { + public: + BatchInputTaskHandle( + std::shared_ptr> batch_input_task, int split_id, + size_t task_size); + + // Should be called once. Returns nullptr on subsequent calls. + std::unique_ptr GetSplitTask(); + + // Returns the size of this task. + size_t size() const override { return task_size_; } + + private: + template + friend class internal::BatchInputTaskHandleTestAccess; + + int split_id() const { return split_id_; } + + std::shared_ptr> batch_input_task_; + + // The handle evaluates to the N-th slice of original task, and + // N is `split_id_`. + const int split_id_; + + const size_t task_size_; + + std::atomic once_{false}; +}; + +// BatchInputTask encapsulates a input (`input_task`) to be batched and the +// information to get task splits after it's enqueued, so as to support lazy +// split of a task. +// +// Input split could reduce excessive padding for efficiency; lazy split +// moves task-split out of the critical path of enqueue and dequeue and reduces +// contention. +// +// BatchInputTask is thread safe. +// +// Usage +// +// ... a deque with frequent enqueue and dequeue operations ... +// ... Note, a deque of Batch of BatchInputTaskHandle is used to form batches +// at enqueue time (split is lazy at deque time); +// ... For use cases to form batches at dequeue time, we can use a deque of +// BatchInputTaskHandle directly, and "peek" metadata to form a batch by +// then. +// std::deque>>> deque_ +// TF_GUARDED_BY(mu_); +// +// std::unique_ptr input_task; +// +// ... Enqueue path ... 
+// +// { +// mutex_lock l(mu_); +// std::shared_ptr> batch_input_task = +// ConstructLazyBatchWithoutSplit(input_task); +// +// std::vector>> task_handles; +// input_batch->ToTaskHandles(&task_handles); +// for (int i = 0; i < task_handles.size(); ++i) { +// EnqueueTaskHandleIntoDeque(deque_); +// } +// +// ... Dequeue path ... +// std::unique_ptr>> handles_to_schedule; +// { +// mutex_lock l(mu_); +// ... HasBatchToSchedule could be customized or specialized +// ... (e.g., readiness depending on enqueue time) +// if (HasBatchToSchedule(deque_)) { +// handles_to_schedule = std::move(deque_.front()); +// deque_.pop_front(); +// } +// } +// ...... `mu_` is released ...... +// +// std::vector>> tasks_in_batch = +// RemoveAllTasksFromBatch(handles_to_schedule); +// +// std::unique_ptr> batch_to_schedule; +// for (int i = 0; i < tasks_in_batch.size(); i++) { +// batch_to_schedule->AddTask(std::move(tasks_in_batch[i]->GetSplitTask())); +// } +// batch_to_schedule->Close(); +// +// `batch_to_schedule` is ready for schedule. +template +class BatchInputTask + : public std::enable_shared_from_this> { + public: + using SplitInputFunc = std::function* input_task, int first_output_task_size, + int input_batch_size_limit, + std::vector>* output_tasks)>; + + BatchInputTask(std::unique_ptr input_task, + int open_batch_remaining_slot, int batch_size_limit, + SplitInputFunc split_input_func); + + // Outputs the task handles for the input task. + // Each task handle represents a slice of task after input task is split, and + // could evaluate to that slice. + // + // NOTE: + // Each task handle in `output_task_handles` takes ownership of a reference of + // this BatchInputTask. + void ToTaskHandles( + std::vector>>* + output_task_handles); + + private: + friend class BatchInputTaskHandle; + template + friend class internal::BatchInputTaskTestAccess; + + std::unique_ptr GetSplitTask(int split_id); + + absl::Status SplitBatches( + std::vector>* output_tasks); + + std::unique_ptr input_task_; + + const int input_task_size_ = 0; + const int open_batch_remaining_slot_; + + const int batch_size_limit_; + const SplitInputFunc split_func_; + + const InputSplitMetadata input_split_metadata_; + + mutable absl::once_flag once_; + + std::vector> task_splits_; + absl::Status split_status_; +}; + +// +// Implementation details. API readers may skip. 
+// + +template +BatchInputTaskHandle::BatchInputTaskHandle( + std::shared_ptr> batch_input_task, int split_id, + size_t task_size) + : batch_input_task_(batch_input_task), + split_id_(split_id), + task_size_(task_size) {} + +template +std::unique_ptr BatchInputTaskHandle::GetSplitTask() { + if (once_.load(std::memory_order_acquire)) { + return nullptr; + } + once_.store(true, std::memory_order_release); + return batch_input_task_->GetSplitTask(split_id_); +} + +template +BatchInputTask::BatchInputTask(std::unique_ptr input_task, + int open_batch_remaining_slot, + int batch_size_limit, + SplitInputFunc split_input_func) + : input_task_(std::move(input_task)), + input_task_size_(input_task_->size()), + open_batch_remaining_slot_(open_batch_remaining_slot), + batch_size_limit_(batch_size_limit), + split_func_(split_input_func), + input_split_metadata_(input_task_size_, open_batch_remaining_slot, + batch_size_limit) {} + +template +void BatchInputTask::ToTaskHandles( + std::vector>>* + task_handles) { + const absl::FixedArray& task_sizes = input_split_metadata_.task_sizes(); + task_handles->resize(task_sizes.size()); + for (int i = 0; i < task_handles->size(); i++) { + (*task_handles)[i] = std::make_unique>( + this->shared_from_this(), i, task_sizes[i]); + } +} + +template +std::unique_ptr BatchInputTask::GetSplitTask(int split_id) { + absl::call_once(once_, + [this]() { split_status_ = SplitBatches(&task_splits_); }); + if (!split_status_.ok()) { + LOG_EVERY_N_SEC(WARNING, 60 /* seconds */) + << "Split task with error: " << split_status_ << " split metadata is " + << input_split_metadata_.DebugString(); + return nullptr; + } + if (split_id >= 0 && split_id < task_splits_.size()) { + return std::move(task_splits_[split_id]); + } + return nullptr; +} + +template +absl::Status BatchInputTask::SplitBatches( + std::vector>* output_tasks) { + return split_func_(&input_task_, open_batch_remaining_slot_, + batch_size_limit_, output_tasks); +} + +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_INPUT_TASK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_resource_base.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_resource_base.h new file mode 100644 index 00000000..e853fc48 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -0,0 +1,380 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" +#include "tensorflow/core/common_runtime/cost_measurement_registry.h" +#include "tensorflow/core/common_runtime/request_cost.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/threadsafe_status.h" +#include "tensorflow/core/platform/context.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tsl/platform/criticality.h" + +namespace tensorflow { +namespace serving { + +// Options used to create a batch resource. +struct BatchResourceOptions { + int32_t num_batch_threads; + int32_t max_batch_size; + int32_t batch_timeout_micros; + int32_t max_enqueued_batches; + std::vector allowed_batch_sizes; + std::string batch_padding_policy{kPadUpPolicy}; + int32_t low_priority_max_batch_size; + int32_t low_priority_batch_timeout_micros; + int32_t low_priority_max_enqueued_batches; + std::vector low_priority_allowed_batch_sizes; + MixedPriorityBatchingPolicy mixed_priority_batching_policy; +}; + +// Base class for resource that encapsulating the state and logic for batching +// tensors. +class BatchResourceBase : public ResourceBase { + public: + // Given a BatchTask (from one op invocation) with 'num_outputs'== M and + // split into N sub tasks, TensorMatrix is a N X M matrix. + // Namely, TensorMatrix[i][j] indicates the i-th split tensor of j-th output; + // concatenating tensors along the 2nd dimension gives a output tensor. + typedef std::vector> TensorMatrix; + + // One task to be batched, corresponds to a `slice` of input from one batch-op + // invocation. + // + // Given input from one batch-op invocation, a `slice` of this input is: + // 1) Split each Tensor in `BatchTask::inputs` along the 0th dimension. + // 2) 'split_index' is calculated along the 0-th dimension. + // + // Note input from one batch-op invocation is valid and considered a + // specialized `slice`. + struct BatchTask : public tensorflow::serving::BatchTask { + BatchTask() : criticality_val(tsl::criticality::GetCriticality()){}; + + // A unique ID to identify this invocation of Batch. + int64_t guid; + + Context propagated_context; + + std::vector inputs; + std::vector captured_inputs; + OpKernelContext* context; + AsyncOpKernel::DoneCallback done_callback; + + // The index of this split, along the 0-th dimension of input from op + // invocation. 
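+    //
+    // For example (illustrative sizes): an op invocation whose inputs have a
+    // leading dimension of 10, split into tasks of sizes [4, 4, 2], yields
+    // three BatchTasks with split_index 0, 1 and 2 whose size() values are
+    // 4, 4 and 2 respectively.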
+ int split_index = 0; + + // Two-dimensional tensor matrix, ownership shared by: + // 1) each split of task (to fill one row in this matrix) + // and + // 2) callback that runs to merge output of individual splits for an op + // invocation, after all splits complete. + std::shared_ptr output; + + // 'status' records error (could be from any split) if at least one split + // returns error, OK otherwise. + // Ownership is shared by individual splits and callback. + std::shared_ptr status; + + bool is_partial = false; + + uint64 start_time; + + size_t size() const override { return inputs[0].shape().dim_size(0); } + + // Create a split task from this one. The caller needs to setup the inputs + // of the new task + std::unique_ptr CreateSplitTask( + int split_index, AsyncOpKernel::DoneCallback done_callback); + + // RequestCost is for collecting the cost and must outlive the batching + // processing. + // + // For example, to collect cost in rpc processing, `request_cost` is owned + // by rpc handler and points to the RequestCost of an rpc which provides + // the inputs to this BatchTask. + // + // After the batch processing, the request cost will be incremented with + // this task's processing costs. + RequestCost* request_cost = nullptr; + + // Returns the criticality associated with the task. + tsl::criticality::Criticality criticality() const override { + return criticality_val; + }; + + // If nonzero, make a batch of this size entirely out of padding. This + // batch is processed, but is not propagated to the kernel outputs. + int forced_warmup_batch_size = 0; + + protected: + virtual std::unique_ptr CreateDerivedTask() { + return std::make_unique(); + } + + private: + // Criticality associated with the task. + ::tsl::criticality::Criticality criticality_val; + }; + + // Appending a T suffix to make the type alias different to those in + // tensorflow::serving namespace, because some versions of compiler complain + // about changing meaning of the symbols. + using BatcherT = SharedBatchScheduler; + using AdaptiveBatcherT = + AdaptiveSharedBatchScheduler; + using BatcherQueueT = BatchScheduler; + using BatchT = Batch; + + BatchResourceBase(bool has_process_batch_function, + std::shared_ptr batcher, + const BatcherT::QueueOptions& batcher_queue_options, + std::vector allowed_batch_sizes) + : has_process_batch_function_(has_process_batch_function), + batcher_(std::move(batcher)), + batcher_queue_options_(batcher_queue_options), + allowed_batch_sizes_(std::move(allowed_batch_sizes)), + allowed_batch_sizes_str_(absl::StrJoin(allowed_batch_sizes_, ",")) {} + + BatchResourceBase(bool has_process_batch_function, + std::shared_ptr batcher, + const AdaptiveBatcherT::QueueOptions& batcher_queue_options, + std::vector allowed_batch_sizes) + : has_process_batch_function_(has_process_batch_function), + adaptive_batcher_(std::move(batcher)), + adaptive_batcher_queue_options_(batcher_queue_options), + allowed_batch_sizes_(std::move(allowed_batch_sizes)), + allowed_batch_sizes_str_(absl::StrJoin(allowed_batch_sizes_, ",")) {} + + void set_session_metadata(tensorflow::SessionMetadata session_metadata) { + session_metadata_ = std::move(session_metadata); + } + + const SessionMetadata& session_metadata() const { return session_metadata_; } + + using CreateBatchTaskFn = + std::function>()>; + + // Like `RegisterInput`, but extra "dummy" batches are processed for each + // batch size. Only the real request's outputs are propagated to the caller. 
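+  // For example (illustrative configuration): with allowed_batch_sizes =
+  // {2, 4, 8}, this would also run padding-only "dummy" batches of sizes 2, 4
+  // and 8 alongside the real request, and their outputs are discarded.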
+ Status RegisterWarmupInputs(int64_t guid, OpKernelContext* context, + const string& batcher_queue_name, + const CreateBatchTaskFn& create_batch_task_fn, + AsyncOpKernel::DoneCallback done); + // Ingests data from one invocation of the batch op. The data is enqueued to + // be combined with others into a batch, asynchronously. + // `CreateBatchTaskFn` should be used to instantiate fields added to a + // child class of `BatchTask` by the caller. + Status RegisterInput(int64_t guid, OpKernelContext* context, + const string& batcher_queue_name, + const CreateBatchTaskFn& create_batch_task_fn, + AsyncOpKernel::DoneCallback done_callback, + int forced_warmup_batch_size = 0); + + static BatcherT::QueueOptions GetBatcherQueueOptions( + int32_t num_batch_threads, int32_t max_batch_size, + int32_t batch_timeout_micros, int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + bool enable_large_batch_splitting, bool disable_padding); + + static BatcherT::QueueOptions GetBatcherQueueOptions( + int32_t num_batch_threads, int32_t max_batch_size, + int32_t batch_timeout_micros, int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + bool enable_large_batch_splitting, bool disable_padding, + absl::string_view batch_padding_policy, + int32_t low_priority_max_batch_size, + int32_t low_priority_batch_timeout_micros, + int32_t low_priority_max_enqueued_batches, + const std::vector& low_priority_allowed_batch_sizes, + MixedPriorityBatchingPolicy mixed_priority_batching_policy); + + static AdaptiveBatcherT::QueueOptions GetAdaptiveBatcherQueueOptions( + int32_t max_batch_size, int32_t batch_timeout_micros, + int32_t max_enqueued_batches, bool enable_large_batch_splitting, + const std::vector& allowed_batch_sizes, bool disable_padding); + + // Split 'input' of 'input_task_ptr' along 0th dimension, into a list of + // 'output_tasks'. + // Task sizes are determined by + // 1) open_batch_remaining_slot + // 2) max_batch_size + // 3) size-of-input-task + // in a way that + // 1) Task sizes add up to `size-of-input-task`. + // 2) Task sizes from left to right are like + // [open_batch_remaining_slot, max_batch_size, max_batch_size, ..., + // `size-of-input-task` - `sum-of-previous-elements`]. + // + // REQUIRES: + // Caller should make sure size-of-input-task is greater than + // open_batch_remaining_slot. + static Status SplitInputTask( + std::unique_ptr* input_task_ptr, int open_batch_remaining_slot, + int max_batch_size, + std::vector>* output_tasks); + + // Splits the batch costs to each task. + // + // Inputs: + // 1) batch_cost_measurements, which provides the total cost of each type; + // 2) processed_size, it's the batch size plus the padding amount; + // 3) batch, provides the batch size and input sizes. + // + // Outputs: + // The request_cost in each batch task will be updated. + // - This function will use two approaches to split the batch cost (if it's + // non-zero), thus two costs will be output. + // 1) smeared cost: batch cost is split proportionally to each task's size, + // and paddings do not share any cost; + // 2) non-smeared cost: batch cost is split proportionally to each task or + // padding's size. Here padding's cost is not assigned to any tasks. + // - This function will also record the metrics of this batch in each task, + // including: + // 1) the batch size; + // 2) the input size from this task; + // 3) the padding amount. 
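+  // Worked example (illustrative numbers): a batch holds two tasks of sizes
+  // 6 and 2 and is padded to processed_size 10. A cost type totalling 10 ms
+  // is then recorded as 7.5 ms / 2.5 ms per task under the smeared split
+  // (padding's share is redistributed to the tasks), and as 6 ms / 2 ms under
+  // the non-smeared split (the 2 ms attributable to padding is left
+  // unassigned).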
+ static void SplitBatchCostsAndRecordMetrics( + const std::string& model_name, const std::string& op_name, + const std::vector>& + batch_cost_measurements, + int64_t processed_size, BatchT& batch); + + private: + // Implementation of calling the process batch function. + virtual void ProcessFuncBatchImpl( + const BatchResourceBase::BatchTask& last_task, + absl::Span inputs, std::vector* combined_outputs, + std::function done) const = 0; + + // Validates that it's legal to combine the tasks in 'batch' into a batch. + // Assumes the batch is non-empty. + static Status ValidateBatch(const BatchT& batch); + + // Returns a boolean indicating whether a batch is formed from low priority + // tasks only or not. + bool IsLowPriorityBatch(const BatchT& batch) const; + + // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than + // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply + // returns 'batch_size'. + int RoundToLowestAllowedBatchSize(int batch_size, + bool is_low_priority_batch = false) const; + + // Helper function to propagate the status to the task's context and call the + // done callback on the task. + void CleanUpFunctionHelper(BatchTask& task, const Status& status) const; + + // Concatenates the input tensors of the tasks from the batch and the + // unbatched task vector. When padding is enabled in the batcher queue, they + // are padded with garbage value up to the nearest allowed batch size. + Status ConcatInputTensors( + const BatchT& batch, + const std::vector>& unbatched_tasks, + OpKernelContext* context, + std::vector* concatenated_tensors) const; + + Status SplitOutputTensors( + const std::vector& combined_outputs, BatchT* batch, + std::vector>& unbatched_tasks) const; + + void ProcessFuncBatch( + std::unique_ptr batch, + std::vector> unbatched_tasks = {}) const; + + // Processes a batch of one or more BatchTask entries. + void ProcessBatch(std::unique_ptr batch) const; + + // Callback function that wraps the Process*Batch functions above. The caller + // of the callback must guarantee that the unique pointers passed as argument + // are not null. + void ProcessBatchCallBack( + std::unique_ptr> batch, + std::vector> unbatched_tasks); + + // Emits an index tensor, which the Unbatch op will use to un-concatenate + // the tensor and attribute the pieces to the right batch keys. The index + // tensor contains, for each input: [batch_key, start_offset, end_offset] + // where start_offset and end_offset represent the range of entries in the + // concatenated tensors that belong to that input. + // + // Emits the result to the output at 'output_index' using 'context'. + static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch, + int output_index); + + // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, + // creates it. + // + // The model_name and op_name are the names of the current model and + // operation, respectively. + Status LookupOrCreateBatcherQueue(const string& queue_name, + const string& model_name, + const string& op_name, + BatcherQueueT** queue); + + SessionMetadata session_metadata_; + + absl::Mutex outstanding_batch_mu_; + int num_outstanding_batched_items_ TF_GUARDED_BY(outstanding_batch_mu_) = 0; + + // True if user specified a batch processing function for this resource. + const bool has_process_batch_function_; + // A batch scheduler, and options for creating queues. 
+ std::shared_ptr batcher_; + BatcherT::QueueOptions batcher_queue_options_; + + // A batch scheduler, and options for creating queues. + std::shared_ptr adaptive_batcher_; + AdaptiveBatcherT::QueueOptions adaptive_batcher_queue_options_; + + // A collection of batcher queues, keyed on queue name. + // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty + // ones (with a time delay?); it's okay if they get recreated later). + mutable mutex batcher_queues_mu_; + std::map> batcher_queues_ + TF_GUARDED_BY(batcher_queues_mu_); + + std::vector allowed_batch_sizes_; + // A concatenated string of , separated by ",". This is + // used to record batching parameter. + string allowed_batch_sizes_str_; +}; + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_RESOURCE_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_scheduler.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_scheduler.h new file mode 100644 index 00000000..ccb34412 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_scheduler.h @@ -0,0 +1,601 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Abstractions for processing small tasks in a batched fashion, to reduce +// processing times and costs that can be amortized across multiple tasks. +// +// The core class is BatchScheduler, which groups tasks into batches. +// +// BatchScheduler encapsulates logic for aggregating multiple tasks into a +// batch, and kicking off processing of a batch on a thread pool it manages. +// +// This file defines an abstract BatchScheduler class. 
+ +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/criticality.h" +#include "tsl/profiler/lib/traceme.h" + +namespace tensorflow { +namespace serving { + +const absl::string_view kLowPriorityPaddingWithMaxBatchSizeAttrValue = + "low_priority_padding_with_max_batch_size"; +const absl::string_view kLowPriorityPaddingWithNextAllowedBatchSizeAttrValue = + "low_priority_padding_with_next_allowed_batch_size"; +const absl::string_view kPriorityIsolationAttrValue = "priority_isolation"; +const absl::string_view kPriorityMergeAttrValue = "priority_merge"; + +enum class MixedPriorityBatchingPolicy { + kLowPriorityPaddingWithMaxBatchSize, + kLowPriorityPaddingWithNextAllowedBatchSize, + kPriorityIsolation, + kPriorityMerge, +}; + +absl::StatusOr GetMixedPriorityBatchingPolicy( + absl::string_view attr_value); + +// The abstract superclass for a unit of work to be done as part of a batch. +// +// An implementing subclass typically contains (or points to): +// (a) input data; +// (b) a thread-safe completion signal (e.g. a Notification); +// (c) a place to store the outcome (success, or some error), upon completion; +// (d) a place to store the output data, upon success. +// +// Items (b), (c) and (d) are typically non-owned pointers to data homed +// elsewhere, because a task's ownership gets transferred to a BatchScheduler +// (see below) and it may be deleted as soon as it is done executing. +class BatchTask { + public: + virtual ~BatchTask() = default; + + // Returns the size of the task, in terms of how much it contributes to the + // size of a batch. (A batch's size is the sum of its task sizes.) + virtual size_t size() const = 0; + + // Returns the criticality of associated with the task. It defaults to + // kCritical. + virtual tsl::criticality::Criticality criticality() const { + return tsl::criticality::Criticality::kCritical; + } +}; + +// A thread-safe collection of BatchTasks. Tasks can be either added or removed +// from the TaskQueue. It is mainly used to hold the registered tasks without +// forming batches, so that the batches can be formed more flexibly right before +// they get scheduled for execution. +// +// Type parameter TaskType must be a subclass of BatchTask. +template +class TaskQueue { + public: + TaskQueue() = default; + + struct TaskWrapper { + std::unique_ptr task; + uint64 start_time_micros; + + TaskWrapper(std::unique_ptr task, uint64 start_time_micros) + : task(std::move(task)), start_time_micros(start_time_micros) {} + }; + + // Appends a task to the end of the queue with the given start time. + void AddTask(std::unique_ptr task, uint64 start_time_micros); + + // Adds a task to the front of the queue with the given start time. + void PrependTask(std::unique_ptr task, uint64 start_time_micros); + + // Removes a task from the front of the queue, i.e., the oldest task in the + // queue. 
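Editor's note: every scheduler in this directory operates on BatchTask subclasses like the one declared above, so a minimal concrete task may help. This is an illustrative sketch only; the class name and namespace are hypothetical.

#include <cstddef>

#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"

namespace example {

// A task wrapping a fixed number of input rows. Completion signalling and
// output storage (items (b)-(d) in the comment above) would live in
// caller-owned state referenced by the task.
class RowsTask : public tensorflow::serving::BatchTask {
 public:
  explicit RowsTask(size_t num_rows) : num_rows_(num_rows) {}
  size_t size() const override { return num_rows_; }
  // criticality() is left at its default (kCritical).

 private:
  const size_t num_rows_;
};

}  // namespace example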
+ std::unique_ptr RemoveTask(); + + // Removes tasks from the front of the queue as many as possible as long as + // the sum of sizes of the removed tasks don't exceed the 'size' given as the + // argument. + std::vector> RemoveTask(int size); + + // Returns the start time of the earliest task in the queue. If the queue is + // empty, return the null value. + std::optional EarliestTaskStartTime() const; + + // Returns true iff the queue contains 0 tasks. + bool empty() const; + + // Returns the number of tasks in the queue. + int num_tasks() const; + + // Returns the sum of the task sizes. + int size() const; + + private: + mutable mutex mu_; + + // Tasks in the queue. + std::deque tasks_ TF_GUARDED_BY(mu_); + + // The sum of the sizes of the tasks in 'tasks_'. + int size_ TF_GUARDED_BY(mu_) = 0; + + // Whether the queue is empty. + std::atomic empty_ TF_GUARDED_BY(mu_){true}; + + // The copy constructor and the assign op are deleted. + TaskQueue(const TaskQueue&) = delete; + void operator=(const TaskQueue&) = delete; +}; + +template +void TaskQueue::AddTask(std::unique_ptr task, + uint64 start_time_micros) { + { + mutex_lock l(mu_); + size_ += task->size(); + tasks_.emplace_back(std::move(task), start_time_micros); + empty_.store(false); + } +} + +template +void TaskQueue::PrependTask(std::unique_ptr task, + uint64 start_time_micros) { + { + mutex_lock l(mu_); + size_ += task->size(); + tasks_.emplace_front(std::move(task), start_time_micros); + empty_.store(false); + } +} + +template +std::unique_ptr TaskQueue::RemoveTask() { + { + mutex_lock l(mu_); + if (tasks_.empty()) { + return nullptr; + } + std::unique_ptr task = std::move(tasks_.front().task); + size_ -= task->size(); + tasks_.pop_front(); + if (tasks_.empty()) { + empty_.store(true); + } + return task; + } +} + +template +std::vector> TaskQueue::RemoveTask( + int size) { + { + mutex_lock l(mu_); + if (tasks_.empty()) { + return {}; + } + + int size_lower_bound = size_ - size; + std::vector> remove_tasks; + while (!tasks_.empty() && + size_ - static_cast(tasks_.front().task->size()) >= + size_lower_bound) { + size_ -= static_cast(tasks_.front().task->size()); + remove_tasks.push_back(std::move(tasks_.front().task)); + tasks_.pop_front(); + if (tasks_.empty()) { + empty_.store(true); + } + } + return remove_tasks; + } +} + +template +bool TaskQueue::empty() const { + { + mutex_lock l(mu_); + return empty_.load(); + } +} + +template +std::optional TaskQueue::EarliestTaskStartTime() const { + { + mutex_lock l(mu_); + + if (tasks_.empty()) { + return std::nullopt; + } + + return tasks_.front().start_time_micros; + } +} + +template +int TaskQueue::num_tasks() const { + { + mutex_lock l(mu_); + return tasks_.size(); + } +} + +template +int TaskQueue::size() const { + { + mutex_lock l(mu_); + return size_; + } +} + +// A thread-safe collection of BatchTasks, to be executed together in some +// fashion. +// +// At a given time, a batch is either "open" or "closed": an open batch can +// accept new tasks; a closed one cannot. A batch is monotonic: initially it is +// open and tasks can be added to it; then it is closed and its set of tasks +// remains fixed for the remainder of its life. A closed batch cannot be re- +// opened. +// +// Type parameter TaskType must be a subclass of BatchTask. +template +class Batch { + public: + Batch(); + explicit Batch(uint64 traceme_context_id); + virtual ~Batch(); // Blocks until the batch is closed. + + // Appends 'task' to the batch. 
After calling AddTask(), the newly-added task + // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1). + // Dies if the batch is closed. + void AddTask(std::unique_ptr task, uint64 start_time_micros = 0); + + // Removes the most recently added task. Returns nullptr if the batch is + // empty. + std::unique_ptr RemoveTask(); + + // Caller takes ownership of returned tasks. + // Must be called after a batch is closed. + std::vector> RemoveAllTasks(); + + // Returns the number of tasks in the batch. + int num_tasks() const; + + // Returns true iff the batch contains 0 tasks. + bool empty() const; + + // Returns a reference to the ith task (in terms of insertion order). + const TaskType& task(int i) const; + + // Returns a pointer to the ith task (in terms of insertion order). + // + // Caller doesn't take ownership. + TaskType* mutable_task(int i); + + // Returns the sum of the task sizes. + size_t size() const; + + // Returns true iff the batch is currently closed. + bool IsClosed() const; + + // Blocks until the batch is closed. + void WaitUntilClosed() const; + + // Marks the batch as closed. Dies if called more than once. + void Close(); + + // Returns the TraceMe context id of this batch. + uint64 traceme_context_id() const; + + // Attempts to trim this batch to a new, smaller size (not to be confused with + // the number of tasks in the batch). On success, the trimmed tasks go into + // 'out_trimmed_tasks' in the same order the tasks were in this batch. + // + // The method might not succeed if it needs to split a large task to hit the + // correct size. + void TryTrimToNewSize( + int new_size, std::vector>& out_trimmed_tasks); + + // Returns the start time of the earliest task in the queue. If the queue is + // empty, return the null value. + std::optional EarliestTaskStartTime() const; + + private: + mutable mutex mu_; + + // The tasks in the batch. + std::vector> tasks_ TF_GUARDED_BY(mu_); + + // The sum of the sizes of the tasks in 'tasks_'. + size_t size_ TF_GUARDED_BY(mu_) = 0; + + std::atomic empty_ TF_GUARDED_BY(mu_){true}; + + // Whether the batch has been closed. + Notification closed_; + + // The TracMe context id. + const uint64 traceme_context_id_; + + // The minimum start time of all tasks in the batch. + // If the batch is empty, the value is undefined. + uint64 earliest_task_start_time_micros_ TF_GUARDED_BY(mu_); + + Batch(const Batch&) = delete; + void operator=(const Batch&) = delete; +}; + +// An abstract batch scheduler class. Collects individual tasks into batches, +// and processes each batch on a pool of "batch threads" that it manages. The +// actual logic for processing a batch is accomplished via a callback. +// +// Type parameter TaskType must be a subclass of BatchTask. +template +class BatchScheduler { + public: + virtual ~BatchScheduler() = default; + + // Submits a task to be processed as part of a batch. + // + // Ownership of '*task' is transferred to the callee iff the method returns + // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task' is + // left as-is. + // + // If no batch processing capacity is available to process this task at the + // present time, and any task queue maintained by the implementing subclass is + // full, this method returns an UNAVAILABLE error code. The client may retry + // later. + // + // Other problems, such as the task size being larger than the maximum batch + // size, yield other, permanent error types. 
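Editor's note: a short sketch of the Batch lifecycle described above (add, close, drain); the function name is hypothetical and the task type is left as a template parameter so the sketch stays independent of any concrete BatchTask subclass.

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"

template <typename TaskType>
std::vector<std::unique_ptr<TaskType>> FillAndDrain(std::unique_ptr<TaskType> t1,
                                                    std::unique_ptr<TaskType> t2) {
  tensorflow::serving::Batch<TaskType> batch;
  batch.AddTask(std::move(t1));
  batch.AddTask(std::move(t2));
  batch.Close();                  // a batch must be closed before RemoveAllTasks(),
  return batch.RemoveAllTasks();  // and before destruction, which blocks until closed
}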
+ // + // In all cases, this method returns "quickly" without blocking for any + // substantial amount of time. If the method returns Status::OK, the task is + // processed asynchronously, and any errors that occur during the processing + // of the batch that includes the task can be reported to 'task'. + virtual absl::Status Schedule(std::unique_ptr* task) = 0; + + // Returns the number of tasks that have been scheduled (i.e. accepted by + // Schedule()), but have yet to be handed to a thread for execution as part of + // a batch. Note that this returns the number of tasks, not the aggregate task + // size (so if there is one task of size 3 and one task of size 5, this method + // returns 2 rather than 8). + virtual size_t NumEnqueuedTasks() const = 0; + + // Returns a guaranteed number of size 1 tasks that can be Schedule()d without + // getting an UNAVAILABLE error. In a typical implementation, returns the + // available space on a queue. + // + // There are two important caveats: + // 1. The guarantee does not extend to varying-size tasks due to possible + // internal fragmentation of batches. + // 2. The guarantee only holds in a single-thread environment or critical + // section, i.e. if an intervening thread cannot call Schedule(). + // + // This method is useful for monitoring, or for guaranteeing a future slot in + // the schedule (but being mindful about the caveats listed above). + virtual size_t SchedulingCapacity() const = 0; + + // Returns the maximum allowed size of tasks submitted to the scheduler. (This + // is typically equal to a configured maximum batch size.) + virtual size_t max_task_size() const = 0; +}; + +////////// +// Implementation details follow. API users need not read. + +template +Batch::Batch() : Batch(0) {} + +template +Batch::Batch(uint64 traceme_context_id) + : traceme_context_id_(traceme_context_id) {} + +template +Batch::~Batch() { + WaitUntilClosed(); +} + +template +void Batch::AddTask(std::unique_ptr task, + uint64 start_time_micros) { + DCHECK(!IsClosed()); + { + mutex_lock l(mu_); + size_ += task->size(); + tasks_.push_back(std::move(task)); + empty_.store(false); + if (tasks_.size() == 1) { + earliest_task_start_time_micros_ = start_time_micros; + } else { + earliest_task_start_time_micros_ = + std::min(earliest_task_start_time_micros_, start_time_micros); + } + } +} + +template +std::optional Batch::EarliestTaskStartTime() const { + { + mutex_lock l(mu_); + if (tasks_.empty()) { + return std::nullopt; + } + return earliest_task_start_time_micros_; + } +} + +template +std::vector> Batch::RemoveAllTasks() { + DCHECK(IsClosed()); + { + mutex_lock l(mu_); + size_ = 0; + empty_.store(true); + std::vector> tasks_to_return; + + // Swapping vector takes constant time. + tasks_to_return.swap(tasks_); + return std::move(tasks_to_return); + } +} + +template +std::unique_ptr Batch::RemoveTask() { + { + mutex_lock l(mu_); + if (tasks_.empty()) { + return nullptr; + } + std::unique_ptr task = std::move(tasks_.back()); + size_ -= task->size(); + tasks_.pop_back(); + if (tasks_.empty()) { + empty_.store(true); + } + return task; + } +} + +template +int Batch::num_tasks() const { + { + mutex_lock l(mu_); + return tasks_.size(); + } +} + +template +bool Batch::empty() const TF_NO_THREAD_SAFETY_ANALYSIS { + // tracer is added to zoom in about this method. + // TODO(b/160249203): Remove tracer after evaluating a change to reduce + // lock contention and cpu usage (which is observed in profiler and + // very data-driven). 
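Editor's note: a caller-side sketch of the Schedule() ownership contract documented above. `SubmitOrReport` is a hypothetical wrapper, not part of this header; `scheduler` stands in for any concrete BatchScheduler implementation.

#include <memory>

#include "absl/status/status.h"
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"

template <typename TaskType>
absl::Status SubmitOrReport(tensorflow::serving::BatchScheduler<TaskType>& scheduler,
                            std::unique_ptr<TaskType> task) {
  absl::Status status = scheduler.Schedule(&task);
  if (status.ok()) {
    // Ownership moved into the scheduler; `task` is now nullptr and the batch
    // containing it will be processed asynchronously.
    return status;
  }
  // On UNAVAILABLE the queue was full and `task` is still owned here, so the
  // caller may retry later or fail the request.
  return status;
}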
+ tsl::profiler::TraceMe tracer("BatchTask::empty"); + return empty_.load(); +} + +template +const TaskType& Batch::task(int i) const { + DCHECK_GE(i, 0); + { + mutex_lock l(mu_); + DCHECK_LT(i, tasks_.size()); + return *tasks_[i].get(); + } +} + +template +TaskType* Batch::mutable_task(int i) { + DCHECK_GE(i, 0); + { + mutex_lock l(mu_); + DCHECK_LT(i, tasks_.size()); + return tasks_[i].get(); + } +} + +template +size_t Batch::size() const { + { + mutex_lock l(mu_); + return size_; + } +} + +template +bool Batch::IsClosed() const { + return const_cast(&closed_)->HasBeenNotified(); +} + +template +void Batch::WaitUntilClosed() const { + const_cast(&closed_)->WaitForNotification(); +} + +template +void Batch::Close() { + closed_.Notify(); +} + +template +uint64 Batch::traceme_context_id() const { + return traceme_context_id_; +} + +template +void Batch::TryTrimToNewSize( + int new_size, std::vector>& out_trimmed_tasks) { + mutex_lock l(mu_); + DCHECK_GT(new_size, 0); + DCHECK_LT(new_size, size_); + DCHECK(out_trimmed_tasks.empty()); + + // Index of the first task to trim away. It is possible that it is the index + // of a task of size larger than 1 that will have to be split in order to get + // to the target new_size. + int32 first_task_to_move = 0; + // The sum of sizes of tasks i, where i < first_task_to_move. + int32 size_of_previous_tasks = 0; + while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <= + new_size) { + size_of_previous_tasks += tasks_[first_task_to_move]->size(); + first_task_to_move++; + // The loop must always stop before this check is tripped because new_size + // must never be larger than the size of the batch. + DCHECK_LT(first_task_to_move, tasks_.size()); + } + + // Check whether task 'first_task_to_move' will have to be split. + if (size_of_previous_tasks < new_size) { + // TODO: b/325954758 - Consider supporting splitting large tasks and then + // drop 'Try' from the method name. + return; + } + DCHECK_EQ(size_of_previous_tasks, new_size); + + // Actually trim. + out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move); + std::move(tasks_.begin() + first_task_to_move, tasks_.end(), + std::back_inserter(out_trimmed_tasks)); + tasks_.resize(first_task_to_move); + size_ = new_size; +} + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h new file mode 100644 index 00000000..9a6deb1a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h @@ -0,0 +1,157 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { + +// Returns the next allowed batch size, which is the smallest allowed batch size +// greater than or equal to the given batch size. If allowed_batch_sizes, +// returns batch_size as is. +int GetNextAllowedBatchSize(int batch_size, + const std::vector& allowed_batch_sizes, + bool disable_padding); + +// Returns the largest allowed batch size that is smaller than or equal to +// batch_size. Returns batch_size if no such size exists. +int GetPrevAllowedBatchSize(int batch_size, + const std::vector& allowed_batch_sizes, + bool disable_padding); + +// Constants containing possible values for the batch_padding_policy argument +// of MaybeBatchDown. This argument specifies the policy that a batch scheduler +// is using when deciding what to do when, say, 18 requests need to be batched, +// but only 16 and 32 batch sizes are allowed. The following options are +// available. +// +// - PAD_UP: pad to size 32. +// - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the +// batch buffer. +// - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses +// to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per +// real request. In this case, it would compare (batch_16_cost / 16) and +// (batch_32_cost / 18). +// +inline constexpr absl::string_view kBatchDownPolicy = "BATCH_DOWN"; +inline constexpr absl::string_view kPadUpPolicy = "PAD_UP"; +inline constexpr absl::string_view kMinimizeTpuCostPerRequestPolicy = + "MINIMIZE_TPU_COST_PER_REQUEST"; + +// Trims the batch to the next allowed batch size when possible and when +// configured by batch_padding_policy. +// +// When trimming, this function puts the trimmed tasks go into the +// out_trimmed_tasks vector in the same order as they were in the batch. +template +void MaybeBatchDown(Batch& batch, + const std::vector& allowed_batch_sizes, + bool disable_padding, + absl::string_view batch_padding_policy, + ModelBatchStats* model_batch_stats, + std::vector>& out_trimmed_tasks) { + if (batch_padding_policy == kPadUpPolicy) { + // This is the default behavior of batch resource when it is given a batch + // size that doesn't match any of the allowed batch sizes. 
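Editor's note: an illustrative use of the two lookup helpers declared above, on the 18-requests / {16, 32} example from the comment. The expected results follow from the documented contract; the function name is hypothetical.

#include <vector>

#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h"
#include "tensorflow/core/platform/types.h"

void AllowedSizeLookupExample() {
  const std::vector<tensorflow::int32> allowed = {16, 32};
  int pad_up = tensorflow::serving::GetNextAllowedBatchSize(
      /*batch_size=*/18, allowed, /*disable_padding=*/false);  // -> 32
  int batch_down = tensorflow::serving::GetPrevAllowedBatchSize(
      /*batch_size=*/18, allowed, /*disable_padding=*/false);  // -> 16
  (void)pad_up;
  (void)batch_down;
}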
+ return; + } + bool minimize_tpu_cost_per_request; + if (batch_padding_policy == kBatchDownPolicy) { + minimize_tpu_cost_per_request = false; + } else if (batch_padding_policy == kMinimizeTpuCostPerRequestPolicy) { + if (model_batch_stats == nullptr) { + LOG_FIRST_N(ERROR, 1) + << kMinimizeTpuCostPerRequestPolicy + << " batch padding policy has been chosen " + "but no ModelBatchStats passed to the batch scheduler; will " + "fall back on the " + << kPadUpPolicy << " policy."; + return; + } + minimize_tpu_cost_per_request = true; + } else { + LOG_FIRST_N(ERROR, 1) << "Unsupported batch_padding_policy: " + << batch_padding_policy << ", falling back on the " + << kPadUpPolicy << " policy."; + return; + } + + int32 batch_size = batch.size(); + + int32 pad_up_size = + GetNextAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding); + if (pad_up_size == batch_size) { + return; // Good, no padding is necessary. + } + + int32 batch_down_size = + GetPrevAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding); + if (batch_down_size == batch_size) { + return; // Can't batch down (e.g. no smaller batch size available). + } + + if (minimize_tpu_cost_per_request) { + // TODO: b/325954758 - Consider logging a warning here or elsewhere if + // a larger batch doesn't cost meaningfully cheaper than a smaller batch. + // TODO: b/325954758 - Consider logging a warning here or elsewhere if a + // smaller batch costs unreasonably cheaper than a larger one (assuming + // a batch cost model = constant_cost + batch_size * per_element_cost). + // TODO: b/325954758 - Consider occasionally picking either batch size so + // that we learn fresh costs of each batch size. For this code, it is not a + // large priority though because if we are in between two allowed batch + // sizes (say, 16 and 32), chances are that will occasionally organically + // get batches of exact sizes 16 and 32 (and then we pick those + // unconditionally). But if we explicitly occasionally explored other batch + // sizes, we wouldn't have to rely on this "chances are". For other + // applications of batch costs, we might also want to occasionally explore + // all allowed batch sizes and not just 16 and 32 from this example. + std::optional down_batch_cost = + model_batch_stats->batch_size(batch_down_size).tpu_cost().mean(); + std::optional up_batch_cost = + model_batch_stats->batch_size(pad_up_size).tpu_cost().mean(); + if (!down_batch_cost.has_value() || !up_batch_cost.has_value()) { + // We have no data about batch costs, let's just do nothing. + return; + } + + auto batch_down_cost_per_request = *down_batch_cost / batch_down_size; + auto pad_up_cost_per_request = *up_batch_cost / batch_size; + + if (pad_up_cost_per_request < batch_down_cost_per_request) { + // Abort batching down because it's cheaper to pad up. + return; + } + } + + // Batch down. + batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks); +} + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_stats.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_stats.h new file mode 100644 index 00000000..87c36fca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/batch_stats.h @@ -0,0 +1,274 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
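Editor's note: the per-request comparison performed above, worked through with made-up mean costs for the 18-requests / {16, 32} example. The function and the cost values are purely illustrative.

#include "absl/time/time.h"

inline bool ShouldBatchDownExample() {
  const absl::Duration batch_16_cost = absl::Milliseconds(10);
  const absl::Duration batch_32_cost = absl::Milliseconds(16);
  const double down_per_request =
      absl::ToDoubleMilliseconds(batch_16_cost) / 16;  // 0.625 ms per request
  const double up_per_request =
      absl::ToDoubleMilliseconds(batch_32_cost) / 18;  // ~0.889 ms per real request
  // Mirrors the comparison in MaybeBatchDown: pad up only if it is cheaper per
  // real request than batching down; otherwise trim the batch to 16. If
  // batch_32_cost were 11 ms (~0.611 ms/request), padding up would win instead.
  return !(up_per_request < down_per_request);
}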
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The API for reporting and querying batch statistics such as the average batch +// costs for in-process use. +// +// All these statistics can also be retrieved from metrics reported by various +// modules (e.g., batch_resource_base), but it would be slow. This API, on the +// other hand, was designed to be queried on every request. +// +// The classes defined here are not supposed to be instantiated by the user. +// Instead, this file provides a single entry point: +// +// BatchStatsRegistry& GlobalBatchStatsRegistry(); +// +// For example, to register batch cost, do: +// +// GlobalBatchStatsRegistry() +// .model(/* model_name= */ "m", /* op_name= */ "o") +// .batch_size(4) +// .tpu_cost +// .Register(cost); +// +// To get the mean cost later, do: +// +// std::optional cost = +// .GlobalBatchStatsRegistry() +// .model(/* model_name= */ "m", /* op_name= */ "o") +// .batch_size(4) +// .tpu_cost +// .mean(); +// +// It is allowed and safe to store references to intermediate objects here +// because all intermediate objects are guaranteed to never be destroyed. +// +// All operations supported by this API are thread-safe. + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow::serving { + +// Default values for when there is no recorded statistic in ModelBatchStats. +constexpr int64_t kNumBatchThreadsUnknown = -1; +constexpr int64_t kBatchTimeoutMicrosUnknown = -1; + +// Tracks the average cost of registered samples. +// +// Thread-safe. +class CostTracker { + public: + // Registers a cost sample. + void Register(absl::Duration cost) { + DCHECK_GT(cost, absl::ZeroDuration()); + + mutex_lock l(mu_); + sample_count_++; + sample_sum_ += cost; + }; + + // Returns the average cost of all registered samples, giving each sample + // the same weight. + // + // Returns std::nullopt if no samples have been registered. + // + // TODO: b/325954758 - Switch this to an exponentially-decaying average. It's + // likely enough to set the half-life to the last 100-1000 samples. + std::optional mean() const { + int64_t count; + absl::Duration sum; + + { + // We only hold the lock to read the values and release it before later + // performing a relatively slow division operation. 
+ mutex_lock l(mu_); + count = sample_count_; + sum = sample_sum_; + } + + if (count == 0) return std::nullopt; + + return sum / count; + }; + + private: + mutable mutex mu_; + + int64_t sample_count_ TF_GUARDED_BY(mu_) = 0; + absl::Duration sample_sum_ TF_GUARDED_BY(mu_); +}; + +// Tracks statistics for a particular model and batch size. +// +// Thread-safe. +class BatchSizeStats { + public: + CostTracker& tpu_cost() { return tpu_cost_; }; + + private: + CostTracker tpu_cost_; +}; + +// Tracks statistics for a particular model. +// +// Here, "model" means a specific version of a model (we assume that version is +// encoded in the op_name). In rare cases, when a model version has multiple +// BatchFunction operation, we also treat each such operation as a separate +// model in this context (they should also have different op_names). +// +// Thread-safe. +class ModelBatchStats { + public: + // Returns a reference to the BatchSizeStats instance for the given batch + // size. + // + // The returned reference persist for as long as 'this' is alive. + BatchSizeStats& batch_size(int32 batch_size) { + mutex_lock l(mu_); + return batch_size_stats_by_batch_size_[batch_size]; + } + + // Registers that the model server has processed a batch of size `size` + // non-padding tasks for this model, updating the current cumulative + // processed size. + void RegisterProcessedSize(int64_t size) { + cumulative_processed_size_.fetch_add(size, std::memory_order_relaxed); + } + + // Returns the cumulative size processed by this model (the total + // count of individual unit-sized queries processed by the model). + int64_t cumulative_processed_size() const { + return cumulative_processed_size_.load(std::memory_order_relaxed); + } + + // Returns the list of batch sizes for which this model has statistics. + // + // The returned list is not guaranteed to be sorted. + std::vector BatchSizes() const { + std::vector result; + mutex_lock l(mu_); + result.reserve(batch_size_stats_by_batch_size_.size()); + for (const auto& [key, value] : batch_size_stats_by_batch_size_) { + result.push_back(key); + } + return result; + } + + void SetNumBatchThreads(int64_t num_batch_threads) { + num_batch_threads_.store(num_batch_threads, std::memory_order_relaxed); + } + + int64_t num_batch_threads() const { + return num_batch_threads_.load(std::memory_order_relaxed); + } + + void SetBatchTimeoutMicros(int64_t batch_timeout_micros) { + batch_timeout_micros_.store(batch_timeout_micros, + std::memory_order_relaxed); + } + + int64_t batch_timeout_micros() const { + return batch_timeout_micros_.load(std::memory_order_relaxed); + } + + private: + mutable mutex mu_; + + // The storage of all BatchSizeStats instances. + // + // The mutex only protects adding/finding element in the map. Access to + // elements themselves (after they were created) is not protected here. No + // element deletion is possible because we return references to items in this + // map and don't track their lifetime. We are using the node hash map so that + // elements, once created, are fixed in memory. + absl::node_hash_map batch_size_stats_by_batch_size_ + TF_GUARDED_BY(mu_); + + // The total count of individual unit-sized queries processed by this model. + // Can be used to generate an internal load metric per model. See + // RegisterQuerySize for more details. + std::atomic cumulative_processed_size_ = 0; + + // The number of batch threads assigned to this model. 
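Editor's note: a small read-side sketch combining BatchSizes(), batch_size() and CostTracker::mean() declared in this header. The model and op names are placeholders; the function itself is hypothetical.

#include <optional>

#include "absl/time/time.h"
#include "tensorflow/core/kernels/batching_util/batch_stats.h"
#include "tensorflow/core/platform/logging.h"

inline void LogMeanTpuCosts() {
  auto& model_stats = tensorflow::serving::GlobalBatchStatsRegistry().model(
      /* model_name= */ "m", /* op_name= */ "o");
  for (int batch_size : model_stats.BatchSizes()) {
    std::optional<absl::Duration> mean =
        model_stats.batch_size(batch_size).tpu_cost().mean();
    if (mean.has_value()) {
      LOG(INFO) << "batch_size=" << batch_size << " mean_tpu_cost=" << *mean;
    }
  }
}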
+ std::atomic num_batch_threads_ = kNumBatchThreadsUnknown; + + // The timeout in microseconds for this model (after which the current batch + // is sent to be processed by the TPU). + std::atomic batch_timeout_micros_ = kBatchTimeoutMicrosUnknown; +}; + +// Tracks batch statistics for all models. +// +// Thread-safe. +class BatchStatsRegistry { + public: + // Returns a reference to ModelBatchStats for the provided model_name and + // op_name. + // + // Upon invocation with a not-yet-seen arguments, creates an empty + // ModelBatchStats instance. + // + // The returned reference persist for as long as 'this' is alive. + ModelBatchStats& model(const std::string& model_name, + const std::string& op_name) { + std::tuple key(model_name, op_name); + mutex_lock l(mu_); + return model_batch_stats_by_model_and_op_names_[key]; + } + + // Returns a list of all model and op names. + // + // This is the set of model/op names tracked by this BatchStats instance. + // Note that the returned list is not guaranteed to be sorted. + std::vector> ModelAndOpNames() const { + std::vector> result; + mutex_lock l(mu_); + result.reserve(model_batch_stats_by_model_and_op_names_.size()); + for (const auto& [key, value] : model_batch_stats_by_model_and_op_names_) { + result.push_back(key); + } + return result; + } + + private: + mutable mutex mu_; + + // The storage of all ModelBatchStats instances. + // + // The mutex only protects adding/finding element in the map. Access to + // elements themselves (after they were created) is not protected here. No + // element deletion is possible because we return references to items in this + // map and don't track their lifetime. We are using the node hash map for + // element pointer stability. + absl::node_hash_map, ModelBatchStats> + model_batch_stats_by_model_and_op_names_ TF_GUARDED_BY(mu_); +}; + +// Returns the global instance of BatchStats, to use used for all production +// purposes (one should only instantiate individual classes from this file to +// test them). +inline BatchStatsRegistry& GlobalBatchStatsRegistry() { + static BatchStatsRegistry* instance = new BatchStatsRegistry(); + return *instance; +} + +} // namespace tensorflow::serving + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/bounded_executor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/bounded_executor.h new file mode 100644 index 00000000..804a3790 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/bounded_executor.h @@ -0,0 +1,80 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BOUNDED_EXECUTOR_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BOUNDED_EXECUTOR_H_ + +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/platform/threadpool_interface.h" + +namespace tensorflow { +namespace serving { +// BoundedExecutor has a bounded number of threads and unlimited queue length, +// scheduled tasks are executed in a FIFO way. +class BoundedExecutor : public thread::ThreadPoolInterface { + public: + struct Options { + Env* env = Env::Default(); + ThreadOptions thread_options; + std::string thread_name; + int num_threads = -1; + }; + + static absl::StatusOr> Create( + const Options& options); + + // Destructor. All threads will be joined. + ~BoundedExecutor() override; + + // Enqueue a function to be executed. + // + // Callers are responsible to guarantee `func` is not nullptr. + void Schedule(std::function func) override; + + // Returns the number of threads. + int NumThreads() const override; + + int CurrentThreadId() const override; + + private: + explicit BoundedExecutor(const Options& options); + + // Starts N workers (N == num_threads), polling tasks from `work_queue_`. + void InitWorker(); + + // A loop to fetch task from `work_queue_` and execute task. + void Run(); + + const Options& options_; + + mutex work_queue_mu_; + std::deque> work_queue_ TF_GUARDED_BY(work_queue_mu_); + condition_variable work_queue_cv_ TF_GUARDED_BY(work_queue_mu_); + + // A fixed number of threads. + std::vector> threads_; + BoundedExecutor(const BoundedExecutor&) = delete; + void operator=(const BoundedExecutor&) = delete; +}; + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BOUNDED_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/concat_split_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/concat_split_util.h new file mode 100644 index 00000000..b5354be3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/concat_split_util.h @@ -0,0 +1,253 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
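Editor's note: a minimal sketch of creating the executor declared above and scheduling one task; the thread name and thread count are arbitrary example values, and the wrapper function is hypothetical.

#include <memory>
#include <utility>

#include "tensorflow/core/kernels/batching_util/bounded_executor.h"
#include "tensorflow/core/platform/logging.h"

inline void BoundedExecutorExample() {
  tensorflow::serving::BoundedExecutor::Options options;
  options.thread_name = "example_pool";
  options.num_threads = 4;
  auto executor_or = tensorflow::serving::BoundedExecutor::Create(options);
  if (!executor_or.ok()) {
    LOG(ERROR) << executor_or.status();
    return;
  }
  std::unique_ptr<tensorflow::serving::BoundedExecutor> executor =
      std::move(executor_or).value();
  executor->Schedule([] { LOG(INFO) << "ran on a bounded executor thread"; });
  // Destruction joins all worker threads.
}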
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ops_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/kernels/split_lib.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace concat_split_util { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// Concatenates 'inputs' into a single tensor along the zeroth dimension. +// Requires that all elements of 'inputs' have element type T. Writes to +// 'output' using 'context' for the allocation to ensure proper device +// placement. +template +absl::Status Concat(OpKernelContext* context, + const absl::Span inputs, Tensor* output) { + const int input_dims = inputs[0].dims(); + const TensorShape& input_shape = inputs[0].shape(); + + // Note that we reduce the concat of k-dimensional tensors into a two + // dimensional concat. Assuming the dimensions of any input tensor are + // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi). + std::vector::ConstMatrix>> inputs_flat; + inputs_flat.reserve(inputs.size()); + int64_t output_dim0 = 0; + for (size_t i = 0; i < inputs.size(); ++i) { + const Tensor& input = inputs[i]; + if (input.dims() != input_dims) { + return errors::InvalidArgument( + "Ranks of all input tensors should match: shape[0] = ", + input_shape.DebugString(), " vs. shape[", i, + "] = ", input.shape().DebugString()); + } + for (int j = 1; j < input_dims; ++j) { + if (input.dim_size(j) != input_shape.dim_size(j)) { + return errors::InvalidArgument( + "Dimensions of inputs should match: shape[0] = ", + input_shape.DebugString(), " vs. shape[", i, + "] = ", input.shape().DebugString()); + } + } + if (input.NumElements() > 0) { + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + input.shaped({1, input.NumElements()}))); + } + output_dim0 += input.dim_size(0); + } + + TensorShape output_shape(input_shape); + output_shape.set_dim(0, output_dim0); + AllocatorAttributes attr; + attr.set_on_host(true); + TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum::value, + output_shape, output, attr)); + if (output->NumElements() > 0) { + auto output_flat = output->shaped({1, output->NumElements()}); +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) + if (std::is_same::value) { + ConcatGPU(context, inputs_flat, output, &output_flat); + return OkStatus(); + } +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + ConcatCPU(context->device(), inputs_flat, &output_flat); + } + + return absl::OkStatus(); +} + +// Same as 'Concat' above, but handles Tensor dtype deduction automatically. 
+inline absl::Status Concat(OpKernelContext* context, + const absl::Span inputs, + Tensor* output) { + const DataType type = inputs[0].dtype(); + absl::Status concat_status; + switch (type) { +#define CASE(type) \ + case DataTypeToEnum::value: \ + concat_status = Concat(context, inputs, output); \ + break; + TF_CALL_ALL_TYPES(CASE); +#undef CASE + default: + concat_status = errors::InvalidArgument("Unsupported data type: ", type); + break; + } + return concat_status; +} + +// The Split*() functions split 'input' with element type T into 'sizes.size()' +// tensors along the zeroth dimension, with the ith split having zeroth- +// dimension size 'sizes[i]'. They allocate the output tensors using 'context', +// for proper device placement. + +// Handles special cases that are cheap. Sets 'done==true' iff it found an +// applicable special case and wrote to the outputs. Otherwise acts as a no-op. +template +absl::Status SplitEasyCases(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs, bool* done) { + *done = false; + + int64_t total_size = 0; + for (const int64_t size : sizes) { + total_size += size; + } + if (total_size > input.shape().dim_size(0)) { + return errors::InvalidArgument( + "Sum of split sizes must not exceed dim0-size of input tensor"); + } + + // Special case 0: trivial 1-way split. + if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) { + outputs->push_back(input); + *done = true; + return absl::OkStatus(); + } + + // Special case 1: input is aligned. + if (IsInnerDimsSizeAligned(input.shape())) { + int64_t position = 0; + for (const int64_t size : sizes) { + outputs->emplace_back(input.Slice(position, position + size)); + position += size; + } + *done = true; + return absl::OkStatus(); + } + + return absl::OkStatus(); +} + +// Handles the general case, on CPU. +template +absl::Status SplitCPU(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs) { + int64_t suffix_dim_size = 1; + for (int i = 1; i < input.shape().dims(); ++i) { + suffix_dim_size *= input.shape().dim_size(i); + } + auto input_reshaped = + input.shaped({input.shape().dim_size(0), suffix_dim_size}); + + int64_t position = 0; + for (const int64_t size : sizes) { + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, size); + Tensor output; + AllocatorAttributes attr; + attr.set_on_host(true); + TF_RETURN_IF_ERROR( + context->allocate_temp(input.dtype(), output_shape, &output, attr)); + auto output_shaped = output.shaped({size, suffix_dim_size}); + + Eigen::DSizes slice_indices{ + static_cast(position), 0}; + Eigen::DSizes slice_sizes{ + static_cast(size), + static_cast(suffix_dim_size)}; + functor::Split()(context->eigen_device(), + output_shaped, input_reshaped, + slice_indices, slice_sizes); + + outputs->emplace_back(output); + + position += size; + } + + return absl::OkStatus(); +} + +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) + +// Handles the general case, on GPU. +template +Status SplitGPU(OpKernelContext* context, const Tensor& input, + const gtl::ArraySlice& sizes, + std::vector* outputs) { + // TODO(olston, apassos): Implement this. + LOG(FATAL) << "Not yet implemented"; // Crash ok +} + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// The outer function that dispatches to the various Split*() functions above. 
+template +absl::Status Split(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs) { + bool easy_cases_done; + TF_RETURN_IF_ERROR( + SplitEasyCases(context, input, sizes, outputs, &easy_cases_done)); + if (easy_cases_done) { + return absl::OkStatus(); + } + +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +// TODO(olston, apassos): Handle non-CPU cases. +// return SplitGPU(context, input, sizes, outputs); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + return SplitCPU(context, input, sizes, outputs); +} + +// Same as 'Split' above, but handles Tensor dtype automatically. +inline absl::Status Split(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs) { + const DataType type = input.dtype(); + absl::Status split_status; + switch (type) { +#define CASE(type) \ + case DataTypeToEnum::value: \ + split_status = Split(context, input, sizes, outputs); \ + break; + TF_CALL_ALL_TYPES(CASE); +#undef CASE + default: + split_status = errors::InvalidArgument("Unsupported data type: ", type); + break; + } + return split_status; +} + +} // namespace concat_split_util +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_CONCAT_SPLIT_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/fake_clock_env.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/fake_clock_env.h new file mode 100644 index 00000000..6fc8d9e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/fake_clock_env.h @@ -0,0 +1,77 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { +namespace test_util { + +// An Env implementation with a fake clock for NowMicros() and +// SleepForMicroseconds(). The clock doesn't advance on its own; it advances via +// an explicit Advance() method. +// All other Env virtual methods pass through to a wrapped Env. +class FakeClockEnv : public EnvWrapper { + public: + explicit FakeClockEnv(Env* wrapped); + ~FakeClockEnv() override = default; + + // Advance the clock by a certain number of microseconds. + void AdvanceByMicroseconds(int micros); + + // Blocks until there is a sleeping thread that is scheduled to wake up at + // the given (absolute) time. 
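Editor's note: a sketch of splitting a batched output back into per-task slices along dimension 0 via the dtype-dispatching Split() overload above. The caller is assumed to hold a live OpKernelContext; the wrapper and the example sizes are illustrative.

#include <cstdint>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/batching_util/concat_split_util.h"
#include "tensorflow/core/platform/status.h"

inline absl::Status SplitIntoTwo(tensorflow::OpKernelContext* context,
                                 const tensorflow::Tensor& batched,
                                 std::vector<tensorflow::Tensor>* per_task) {
  // E.g. a batch that concatenated tasks of sizes 6 and 10 along dim 0.
  const std::vector<int64_t> sizes = {6, 10};
  return tensorflow::concat_split_util::Split(context, batched, sizes, per_task);
}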
+ void BlockUntilSleepingThread(uint64 wake_time); + + // Blocks until there are at least num_threads sleeping. + void BlockUntilThreadsAsleep(int num_threads); + + // Methods that this class implements. + uint64 NowMicros() const override; + void SleepForMicroseconds(int64_t micros) override; + + private: + mutable mutex mu_; + + uint64 current_time_ TF_GUARDED_BY(mu_) = 0; + + struct SleepingThread { + uint64 wake_time; + Notification* wake_notification; + }; + std::vector sleeping_threads_ TF_GUARDED_BY(mu_); + + FakeClockEnv(const FakeClockEnv&) = delete; + void operator=(const FakeClockEnv&) = delete; +}; + +} // namespace test_util +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_FAKE_CLOCK_ENV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/input_split_metadata.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/input_split_metadata.h new file mode 100644 index 00000000..429858d3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/input_split_metadata.h @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INPUT_SPLIT_METADATA_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INPUT_SPLIT_METADATA_H_ + +#include + +#include "absl/container/fixed_array.h" + +namespace tensorflow { +namespace serving { +namespace internal { +// InputSplitMetadata represents the task sizes of an batch-task after it's +// tailored according to queue status (`open_batch_remaining_slot` and +// `batch_size_limit`). +// +// This is an internal helper class, and the implementation is shared +// shared across different instantiations of internal::Queue +// in input-split mode (QueueOptions.enable_large_batch_splitting is true). +class InputSplitMetadata { + public: + InputSplitMetadata(int input_task_size, int open_batch_remaining_slot, + int batch_size_limit); + + // Returns underlying task sizes. + const absl::FixedArray& task_sizes() const; + + // Serializes task split metadata into a string for debugging. + std::string DebugString() const; + + private: + absl::FixedArray generate_task_sizes(int input_task_size, + int open_batch_remaining_slot, + int batch_size_limit) const; + + const absl::FixedArray task_sizes_; +}; +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INPUT_SPLIT_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/periodic_function.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/periodic_function.h new file mode 100644 index 00000000..278cfac2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/periodic_function.h @@ -0,0 +1,130 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
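Editor's note: the intended test pattern for the fake clock declared above, as a short sketch; the function name is hypothetical and the code under test is elided.

#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
#include "tensorflow/core/platform/env.h"

inline void FakeClockExample() {
  tensorflow::serving::test_util::FakeClockEnv fake_env(tensorflow::Env::Default());
  // ... start the code under test with &fake_env as its Env ...
  fake_env.BlockUntilThreadsAsleep(1);        // wait for it to call SleepForMicroseconds()
  fake_env.AdvanceByMicroseconds(10 * 1000);  // advance the fake clock by 10 ms
}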
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// PeriodicFunction will periodically call the given function with a specified +// period in a background thread. After Start() returns, the thread is +// guaranteed to have started. The destruction of the class causes the +// background thread to be destroyed as well. Start() should not be called more +// than once. +// +// PeriodicFunction runs the function as soon as any previous run both is +// complete and was started more than "interval_micros" earlier. Thus, runs are +// both serialized, and normally have a period of "interval_micros" if no run +// exceeds the time. +// +// Note that, if the function takes longer than two interval_micross to finish, +// then PeriodicFunction will "skip" at least one call to the function. For +// instance, if the period is 50ms and the function starts runs at time 0 for +// 150ms, then the function will immediately start executing again at time 150, +// but there will be no function runs corresponding to times 50 or 100. This is +// especially important to remember when using an environment with a simulated +// clock: advancing simulated time atomically over N interval_micross will not +// cause the function to be called N times. +// +// This object is thread-safe. +// +// Example: +// +// class Foo { +// public: +// Foo() : periodic_function_([this]() { Bar(); }, +// 1000 /* 1000us == 1ms*/) { +// } +// +// private: +// void Bar() { ... } +// +// PeriodicFunction periodic_function_; +// }; + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { + +namespace internal { +class PeriodicFunctionTestAccess; +} + +class PeriodicFunction { + public: + // Provides the ability to customize several aspects of the PeriodicFunction. + // Passed to constructor of PeriodicFunction. + struct Options { + Options() {} + + // Any standard thread options, such as stack size, should + // be passed via "thread_options". + ThreadOptions thread_options; + + // Specifies the thread name prefix (see the description in class + // Thread). + string thread_name_prefix = "periodic_function"; + + // The environment to use. Does not take ownership, but must remain alive + // for as long as the PeriodicFunction exists. + Env* env = Env::Default(); + + // Specifies the length of sleep before the first invocation of the + // function. + // This can be used for adding a random jitter to avoid synchronous behavior + // across multiple periodic functions. + int64_t startup_delay_micros = 0; + }; + + // Also starts the background thread which will be calling the function. 
+ PeriodicFunction(absl::AnyInvocable function, int64_t interval_micros, + const Options& options = Options()); + + ~PeriodicFunction(); + + private: + friend class internal::PeriodicFunctionTestAccess; + + // Notifies the background thread to stop. + void NotifyStop(); + + // (Blocking.) Loops forever calling "function_" every "interval_micros_". + void RunLoop(int64_t start); + + absl::AnyInvocable function_; // Actual client function + const int64_t interval_micros_; // Interval between calls. + const Options options_; + + // Used to notify the thread to stop. + Notification stop_thread_; + + // Thread for running "function_" + std::unique_ptr thread_ = nullptr; + + PeriodicFunction(const PeriodicFunction&) = delete; + void operator=(const PeriodicFunction&) = delete; +}; + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_PERIODIC_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h new file mode 100644 index 00000000..a7285077 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h @@ -0,0 +1,552 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { +namespace internal { +template +class SDBSBatch; + +template +class SDBSQueue; +} // namespace internal + +// EXPERIMENTAL: API MAY BE SUBJECTED TO SUDDEN CHANGES. +// +// Shared batch scheduler designed for batches which are processed by a serial +// device (e.g. GPU, TPU). When batch processing involves a mix of +// parallelizable cpu work and non-parallelizable on-device work, overall +// latency can be minimized by producing batches at a (load dependent) rate +// which keeps the serial device uniformly busy. +// +// SerialDeviceBatchScheduler (SDBS) controls the batching rate by limiting the +// allowed number of concurrently processed batches. Too large a limit causes +// batches to pile up behind the serial device, adding to the overall batch +// latency. Too small a limit underutilizes the serial device and harms latency +// by forcing batches to wait longer to be processed. 
Feedback from the device +// (i.e. avg number of batches directly pending on the device) is used to set +// the correct limit. +// +// SDBS groups requests into per model batches which are processed when a batch +// processing thread becomes available. SDBS prioritizes batches primarily by +// age (i.e. the batch's oldest request) along with a configurable preference +// for scheduling larger batches first. + + +template +class SerialDeviceBatchScheduler : public std::enable_shared_from_this< + SerialDeviceBatchScheduler> { + public: + ~SerialDeviceBatchScheduler(); + + struct Options { + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + // Maximum number of batch processing threads. + int64_t num_batch_threads = port::NumSchedulableCPUs(); + // Although batch selection is primarily based on age, this parameter + // specifies a preference for larger batches. A full batch will be + // scheduled before an older, nearly empty batch as long as the age gap is + // less than full_batch_scheduling_boost_micros. The optimal value for this + // parameter should be of order the batch processing latency, but must be + // chosen carefully, as too large a value will harm tail latency. + int64_t full_batch_scheduling_boost_micros = 0; + // The environment to use (typically only overridden by test code). + Env* env = Env::Default(); + // Initial limit for number of batches being concurrently processed. + int64_t initial_in_flight_batches_limit = 3; + // Returns the current number of batches directly waiting to be processed + // by the serial device (i.e. GPU, TPU). + std::function get_pending_on_serial_device; + // Desired average number of batches directly waiting to be processed by the + // serial device. Small numbers of O(1) should deliver the best latency. + double target_pending = 2; + // Number of batches between potential adjustments of + // in_flight_batches_limit. Larger numbers will reduce noise, but will be + // less responsive to sudden changes in workload. + int64_t batches_to_average_over = 1000; + }; + + // Ownership is shared between the caller of Create() and any queues created + // via AddQueue(). + static absl::Status Create( + const Options& options, + std::shared_ptr>* scheduler); + + struct QueueOptions { + // Maximum size of each batch. + int max_batch_size = 1000; + // Maximum number of enqueued (i.e. non-scheduled) batches. + int max_enqueued_batches = 10; + }; + + using BatchProcessor = std::function>)>; + + // Adds queue (and its callback) to be managed by this scheduler. + absl::Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); + + double in_flight_batches_limit() { + mutex_lock l(mu_); + return in_flight_batches_limit_; + } + + double recent_low_traffic_ratio() { + mutex_lock l(mu_); + return recent_low_traffic_ratio_; + } + + private: + // access to AddBatch(), RemoveQueue(), env(). + friend class internal::SDBSQueue; + + explicit SerialDeviceBatchScheduler(const Options& options); + + // Continuously retrieves and processes batches. + void ProcessBatches(); + + // Notifies scheduler of non-empty batch which is eligible for processing. + void AddBatch(const internal::SDBSBatch* batch); + + // Removes queue from scheduler. + void RemoveQueue(const internal::SDBSQueue* queue); + + Env* env() const { return options_.env; } + + const Options options_; + + // Collection of batches added by AddBatch. Owned by scheduler until they are + // released for processing. 
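+  // For illustration, a sketch of wiring up the public API declared above.
+  // MyTask, PendingOnDevice and RunOnDevice are hypothetical names, and the
+  // template arguments shown are restored by assumption:
+  //
+  //   SerialDeviceBatchScheduler<MyTask>::Options opt;
+  //   opt.num_batch_threads = 4;
+  //   opt.initial_in_flight_batches_limit = 2;
+  //   opt.get_pending_on_serial_device = [] { return PendingOnDevice(); };
+  //   std::shared_ptr<SerialDeviceBatchScheduler<MyTask>> scheduler;
+  //   TF_RETURN_IF_ERROR(
+  //       SerialDeviceBatchScheduler<MyTask>::Create(opt, &scheduler));
+  //
+  //   std::unique_ptr<BatchScheduler<MyTask>> queue;
+  //   TF_RETURN_IF_ERROR(scheduler->AddQueue(
+  //       /*options=*/{},
+  //       /*process_batch_callback=*/
+  //       [](std::unique_ptr<Batch<MyTask>> batch) {
+  //         RunOnDevice(std::move(batch));
+  //       },
+  //       &queue));
+  //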
+ std::vector*> batches_ TF_GUARDED_BY(mu_); + + // Unowned queues and callbacks added by AddQueue. + std::unordered_map*, BatchProcessor> + queues_and_callbacks_ TF_GUARDED_BY(mu_); + + // Responsible for running the batch processing callbacks. + std::unique_ptr batch_thread_pool_; + + // Limit on number of batches which can be concurrently processed. + int64_t in_flight_batches_limit_ TF_GUARDED_BY(mu_); + + // Number of batch processing threads. + int64_t processing_threads_ TF_GUARDED_BY(mu_) = 0; + + // Number of batches processed since the last in_flight_batches_limit_ + // adjustment. + int64_t batch_count_ TF_GUARDED_BY(mu_) = 0; + + // Number of times since the last in_flight_batches_limit_ adjustment when a + // processing thread was available but there were no batches to process. + int64_t no_batch_count_ TF_GUARDED_BY(mu_) = 0; + + // Sum of batches pending on the serial device since the last + // in_flight_batches_limit_ adjustment. + int64_t pending_sum_ = 0; + + // Sum of batch latencies since the last in_flight_batches_limit_ adjustment. + int64_t batch_latency_sum_ = 0; + + // Average period between which two consecutive batches begin processing. + int64_t batch_period_micros_ = 0; + + // Moving average tracking the fraction of recent in_flight_batches_limit_ + // adjustments where the external traffic was not high enough to provide + // useful feedback for an adjustment. + double recent_low_traffic_ratio_ = 0; + + mutex mu_; + + SerialDeviceBatchScheduler(const SerialDeviceBatchScheduler&) = delete; + void operator=(const SerialDeviceBatchScheduler&) = delete; +}; + +////////////////////////////////////////////////////////// +// Implementation details follow. API users need not read. + +namespace internal { +// Consolidates tasks into batches, passing them off to the +// SerialDeviceBatchScheduler for processing. +template +class SDBSQueue : public BatchScheduler { + public: + using QueueOptions = + typename SerialDeviceBatchScheduler::QueueOptions; + + SDBSQueue(std::shared_ptr> scheduler, + const QueueOptions& options); + + ~SDBSQueue() override; + + // Adds task to current batch. Fails if the task size is larger than the batch + // size or if the current batch is full and this queue's number of outstanding + // batches is at its maximum. + absl::Status Schedule(std::unique_ptr* task) override; + + // Number of tasks waiting to be scheduled. + size_t NumEnqueuedTasks() const override; + + // Number of size 1 tasks which could currently be scheduled without failing. + size_t SchedulingCapacity() const override; + + // Notifies queue that a batch is about to be scheduled; the queue should not + // place any more tasks in this batch. + void ReleaseBatch(const SDBSBatch* batch); + + size_t max_task_size() const override { return options_.max_batch_size; } + + private: + std::shared_ptr> scheduler_; + const QueueOptions options_; + // Owned by scheduler_. + SDBSBatch* current_batch_ TF_GUARDED_BY(mu_) = nullptr; + int64_t num_enqueued_batches_ TF_GUARDED_BY(mu_) = 0; + int64_t num_enqueued_tasks_ TF_GUARDED_BY(mu_) = 0; + mutable mutex mu_; + SDBSQueue(const SDBSQueue&) = delete; + void operator=(const SDBSQueue&) = delete; +}; + +// Batch which remembers when and by whom it was created. 
+template +class SDBSBatch : public Batch { + public: + SDBSBatch(SDBSQueue* queue, int64_t creation_time_micros) + : queue_(queue), creation_time_micros_(creation_time_micros) {} + + ~SDBSBatch() override {} + + SDBSQueue* queue() const { return queue_; } + + int64_t creation_time_micros() const { return creation_time_micros_; } + + private: + SDBSQueue* queue_; + const int64_t creation_time_micros_; + SDBSBatch(const SDBSBatch&) = delete; + void operator=(const SDBSBatch&) = delete; +}; +} // namespace internal + +// ---------------- SerialDeviceBatchScheduler ---------------- + +template +absl::Status SerialDeviceBatchScheduler::Create( + const Options& options, + std::shared_ptr>* scheduler) { + if (options.num_batch_threads < 1) { + return errors::InvalidArgument("num_batch_threads must be positive; was ", + options.num_batch_threads); + } + if (options.initial_in_flight_batches_limit < 1) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit must be positive; was ", + options.initial_in_flight_batches_limit); + } + if (options.initial_in_flight_batches_limit > options.num_batch_threads) { + return errors::InvalidArgument( + "initial_in_flight_batches_limit (", + options.initial_in_flight_batches_limit, + ") should not be larger than num_batch_threads (", + options.num_batch_threads, ")"); + } + if (options.full_batch_scheduling_boost_micros < 0) { + return errors::InvalidArgument( + "full_batch_scheduling_boost_micros can't be negative; was ", + options.full_batch_scheduling_boost_micros); + } + if (options.batches_to_average_over < 1) { + return errors::InvalidArgument( + "batches_to_average_over should be " + "greater than or equal to 1; was ", + options.batches_to_average_over); + } + if (options.target_pending <= 0) { + return errors::InvalidArgument( + "target_pending should be larger than zero; was ", + options.target_pending); + } + if (!options.get_pending_on_serial_device) { + return errors::InvalidArgument( + "get_pending_on_serial_device must be " + "specified"); + } + scheduler->reset(new SerialDeviceBatchScheduler(options)); + return absl::OkStatus(); +} + +template +SerialDeviceBatchScheduler::SerialDeviceBatchScheduler( + const Options& options) + : options_(options), + in_flight_batches_limit_(options.initial_in_flight_batches_limit), + processing_threads_(options.initial_in_flight_batches_limit) { + batch_thread_pool_.reset(new thread::ThreadPool( + env(), options.thread_pool_name, options.num_batch_threads)); + for (int i = 0; i < processing_threads_; i++) { + batch_thread_pool_->Schedule( + std::bind(&SerialDeviceBatchScheduler::ProcessBatches, this)); + } +} + +template +SerialDeviceBatchScheduler::~SerialDeviceBatchScheduler() { + // Signal processing threads to exit. + { + mutex_lock l(mu_); + processing_threads_ = 0; + } + // Hangs until all threads finish. 
+ batch_thread_pool_.reset(); +} + +template +absl::Status SerialDeviceBatchScheduler::AddQueue( + const QueueOptions& options, BatchProcessor process_batch_callback, + std::unique_ptr>* queue) { + if (options.max_batch_size <= 0) { + return errors::InvalidArgument("max_batch_size must be positive; was ", + options.max_batch_size); + } + if (options.max_enqueued_batches <= 0) { + return errors::InvalidArgument( + "max_enqueued_batches must be positive; was ", + options.max_enqueued_batches); + } + internal::SDBSQueue* SDBS_queue_raw; + queue->reset(SDBS_queue_raw = new internal::SDBSQueue( + this->shared_from_this(), options)); + mutex_lock l(mu_); + queues_and_callbacks_[SDBS_queue_raw] = process_batch_callback; + return absl::OkStatus(); +} + +template +void SerialDeviceBatchScheduler::AddBatch( + const internal::SDBSBatch* batch) { + mutex_lock l(mu_); + batches_.push_back(batch); +} + +template +void SerialDeviceBatchScheduler::RemoveQueue( + const internal::SDBSQueue* queue) { + mutex_lock l(mu_); + queues_and_callbacks_.erase(queue); +} + +template +void SerialDeviceBatchScheduler::ProcessBatches() { + const int64_t kIdleThreadSleepTimeMicros = 1000; + const double kMaxNoBatchRatio = .1; + const double kLowTrafficMovingAverageFactor = .1; + for (;;) { + mu_.lock(); + if (processing_threads_ < 1 || + processing_threads_ > in_flight_batches_limit_) { + processing_threads_--; + mu_.unlock(); + break; + } + if (batches_.empty()) { + no_batch_count_++; + int64_t sleep_time = batch_period_micros_ ? batch_period_micros_ + : kIdleThreadSleepTimeMicros; + mu_.unlock(); + env()->SleepForMicroseconds(sleep_time); + continue; + } + auto best_it = batches_.begin(); + double best_score = + (*best_it)->creation_time_micros() - + options_.full_batch_scheduling_boost_micros * (*best_it)->size() / + static_cast((*best_it)->queue()->max_task_size()); + for (auto it = batches_.begin() + 1; it != batches_.end(); it++) { + const double score = + (*it)->creation_time_micros() - + options_.full_batch_scheduling_boost_micros * (*it)->size() / + static_cast((*it)->queue()->max_task_size()); + if (score < best_score) { + best_score = score; + best_it = it; + } + } + const internal::SDBSBatch* batch = *best_it; + batches_.erase(best_it); + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + auto callback = queues_and_callbacks_[batch->queue()]; + mu_.unlock(); + int64_t start_time = env()->NowMicros(); + callback(std::unique_ptr>( + const_cast*>(batch))); + int64_t end_time = env()->NowMicros(); + mu_.lock(); + batch_count_++; + batch_latency_sum_ += end_time - start_time; + pending_sum_ += options_.get_pending_on_serial_device(); + if (batch_count_ == options_.batches_to_average_over) { + recent_low_traffic_ratio_ *= (1 - kLowTrafficMovingAverageFactor); + // Only adjust in_flight_batches_limit_ if external load is large enough + // to consistently provide batches. Otherwise we would (mistakenly) assume + // that the device is underutilized because in_flight_batches_limit_ is + // too small. + if (no_batch_count_ < kMaxNoBatchRatio * batch_count_) { + double avg_pending = pending_sum_ / static_cast(batch_count_); + // Avg processing time / # of concurrent batches gives the avg period + // between which two consecutive batches begin processing. Used to set a + // reasonable sleep time for idle batch processing threads. 
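+        // Worked example (illustrative numbers): with
+        // batches_to_average_over = 1000, batch_latency_sum_ / batch_count_ =
+        // 8000us and in_flight_batches_limit_ = 4, batch_period_micros_
+        // becomes 2000us. If avg_pending = 3.4 and target_pending = 2, the
+        // limit is adjusted by round(2 - 3.4) = -1 and then clamped to
+        // [1, num_batch_threads] below.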
+ batch_period_micros_ = + batch_latency_sum_ / batch_count_ / in_flight_batches_limit_; + // When the processing pipeline is consistently busy, the average number + // of pending batches differs from in_flight_batches_limit_ by a + // load-dependent offset. Adjust in_flight_batches_limit_to maintain + // the desired target pending. + in_flight_batches_limit_ += + std::round(options_.target_pending - avg_pending); + in_flight_batches_limit_ = + std::max(in_flight_batches_limit_, int64_t{1}); + in_flight_batches_limit_ = + std::min(in_flight_batches_limit_, options_.num_batch_threads); + // Add extra processing threads if necessary. + if (processing_threads_ > 0 && + processing_threads_ < in_flight_batches_limit_) { + int extra_threads = in_flight_batches_limit_ - processing_threads_; + for (int i = 0; i < extra_threads; i++) { + batch_thread_pool_->Schedule(std::bind( + &SerialDeviceBatchScheduler::ProcessBatches, this)); + } + processing_threads_ = in_flight_batches_limit_; + } + } else { + recent_low_traffic_ratio_ += kLowTrafficMovingAverageFactor; + } + batch_count_ = 0; + no_batch_count_ = 0; + pending_sum_ = 0; + batch_latency_sum_ = 0; + } + mu_.unlock(); + } +} + +// ---------------- SDBSQueue ---------------- + +namespace internal { +template +SDBSQueue::SDBSQueue( + std::shared_ptr> scheduler, + const QueueOptions& options) + : scheduler_(scheduler), options_(options) {} + +template +SDBSQueue::~SDBSQueue() { + // Wait until last batch has been scheduled. + const int kSleepMicros = 1000; + for (;;) { + { + mutex_lock l(mu_); + if (num_enqueued_batches_ == 0) { + break; + } + } + scheduler_->env()->SleepForMicroseconds(kSleepMicros); + } + scheduler_->RemoveQueue(this); +} + +template +absl::Status SDBSQueue::Schedule(std::unique_ptr* task) { + SDBSBatch* new_batch = nullptr; + size_t size = (*task)->size(); + if (size > options_.max_batch_size) { + return errors::InvalidArgument("Task size ", size, + " is larger than maximum batch size ", + options_.max_batch_size); + } + { + mutex_lock l(mu_); + // Current batch is full, create another if allowed. + if (current_batch_ && + current_batch_->size() + size > options_.max_batch_size) { + if (num_enqueued_batches_ >= options_.max_enqueued_batches) { + return errors::Unavailable("The batch scheduling queue is full"); + } + current_batch_->Close(); + current_batch_ = nullptr; + } + if (!current_batch_) { + num_enqueued_batches_++; + current_batch_ = new_batch = + new SDBSBatch(this, scheduler_->env()->NowMicros()); + } + current_batch_->AddTask(std::move(*task)); + num_enqueued_tasks_++; + } + // AddBatch must be called outside of lock, since it may call ReleaseBatch. + if (new_batch != nullptr) scheduler_->AddBatch(new_batch); + return absl::OkStatus(); +} + +template +void SDBSQueue::ReleaseBatch(const SDBSBatch* batch) { + mutex_lock l(mu_); + num_enqueued_batches_--; + num_enqueued_tasks_ -= batch->num_tasks(); + if (batch == current_batch_) { + current_batch_->Close(); + current_batch_ = nullptr; + } +} + +template +size_t SDBSQueue::NumEnqueuedTasks() const { + mutex_lock l(mu_); + return num_enqueued_tasks_; +} + +template +size_t SDBSQueue::SchedulingCapacity() const { + mutex_lock l(mu_); + const int current_batch_capacity = + current_batch_ ? 
options_.max_batch_size - current_batch_->size() : 0; + const int spare_batches = + options_.max_enqueued_batches - num_enqueued_batches_; + return spare_batches * options_.max_batch_size + current_batch_capacity; +} +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h new file mode 100644 index 00000000..347f3008 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -0,0 +1,1548 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/time/clock.h" +#include "tensorflow/core/kernels/batching_util/batch_input_task.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" +#include "tensorflow/core/kernels/batching_util/periodic_function.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tsl/platform/criticality.h" +#include "tsl/platform/errors.h" +#include "tsl/profiler/lib/connected_traceme.h" +#include "tsl/profiler/lib/context_types.h" +#include "tsl/profiler/lib/traceme.h" + +namespace tensorflow { +namespace serving { +namespace internal { +template +class Queue; +} // namespace internal +} // namespace serving +} // namespace tensorflow + +namespace tensorflow { +namespace serving { + +// A batch scheduler for server instances that service multiple request types +// (e.g. multiple machine-learned models, or multiple versions of a model served +// concurrently), or even multiple distinct tasks for a given request. 
The +// scheduler multiplexes batches of different kinds of tasks onto a fixed-size +// thread pool (each batch contains tasks of a single type), in a carefully +// controlled manner. A common configuration is to set the number of threads +// equal to the number of hardware accelerator units, in which case the +// scheduler takes care of multiplexing the task types onto the shared hardware, +// in a manner that is both fair and efficient. +// +// Semantically, SharedBatchScheduler behaves like having N instances of +// BasicBatchScheduler (see basic_batch_scheduler.h), one per task type. The +// difference is that under the covers there is a single shared thread pool, +// instead of N independent ones, with their sharing deliberately coordinated. +// +// SharedBatchScheduler does not implement the BatchScheduler API; rather, it +// presents an abstraction of "queues", where each queue corresponds to one type +// of task. Tasks submitted to a given queue are placed in their own batches, +// and cannot be mixed with other tasks. Queues can be added and deleted +// dynamically, to accommodate e.g. versions of a model being brought up and +// down over the lifetime of a server. +// +// The batch thread pool round-robins through the queues, running one batch +// from a queue and then moving to the next queue. Each queue behaves like a +// BasicBatchScheduler instance, in the sense that it has maximum batch size and +// timeout parameters, which govern when a batch is eligible to be processed. +// +// Each queue is independently configured with a maximum size (in terms of the +// maximum number of batches worth of enqueued tasks). For online serving, it is +// recommended that the queue sizes be configured such that the sum of the sizes +// of the active queues roughly equal the number of batch threads. (The idea is +// that if all threads become available at roughly the same time, there will be +// enough enqueued work for them to take on, but no more.) +// +// If queue sizes are configured in the manner suggested above, the maximum time +// a task can spend in a queue before being placed in a batch and assigned to a +// thread for processing, is the greater of: +// - the maximum time to process one batch of tasks from any active queue +// - the configured timeout parameter for the task's queue (which can be 0) +// +// For bulk processing jobs and throughput-oriented benchmarks, you may want to +// set the maximum queue size to a large value. +// +// TODO(b/26539183): Support queue servicing policies other than round-robin. +// E.g. let each queue specify a "share" (an int >= 1), so e.g. with queues A +// and B having shares 1 and 2 respectively, the servicing pattern is ABBABB... +// +// +// PERFORMANCE TUNING: See README.md. +// +template +class SharedBatchScheduler + : public std::enable_shared_from_this> { + public: + using BatchTaskUniquePtr = std::unique_ptr>; + + using ProcessBatchCallback = + std::variant, + std::function>)>>; + // TODO(b/25089730): Tune defaults based on best practices as they develop. + struct Options { + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + + // The number of threads to use to process batches. + // Must be >= 1, and should be tuned carefully. + int num_batch_threads = port::MaxParallelism(); + + // The environment to use. + // (Typically only overridden by test code.) 
+ Env* env = Env::Default(); + + // If true, when multiple queues have available batches to process, they + // will be prioritized based on a (priority, arrival_time) key. + bool rank_queues = false; + + // If true, Create() will return a global instance of the scheduler. Only + // the options provided in the first Create() call will be used to + // initialize the global scheduler. + bool use_global_scheduler = false; + }; + // Ownership is shared between the caller of Create() and any queues created + // via AddQueue(). + static absl::Status Create( + const Options& options, + std::shared_ptr>* scheduler); + + virtual ~SharedBatchScheduler(); + + // Adds a queue to which tasks may be submitted. The returned queue implements + // the BatchScheduler API. Each queue has its own set of scheduling options, + // and its own callback to process batches of tasks submitted to the queue. + // + // The returned queue's destructor blocks until all tasks submitted to it have + // been processed. + struct QueueOptions { + // The size limit of an input batch to the queue. + // + // If `enable_large_batch_splitting` is True, 'input_batch_size_limit' + // should be greater or equal than `max_execution_batch_size`; otherwise + // `input_batch_size_limit` should be equal to `max_execution_batch_size`. + size_t input_batch_size_limit = 1000; + + // If a task has been enqueued for this amount of time (in microseconds), + // and a thread is available, the scheduler will immediately form a batch + // from enqueued tasks and assign the batch to the thread for processing, + // even if the batch's size is below 'input_batch_size_limit'. + // + // This parameter offers a way to bound queue latency, so that a task isn't + // stuck in the queue indefinitely waiting for enough tasks to arrive to + // make a full batch. (The latency bound is given in the class documentation + // above.) + // + // The goal is to smooth out batch sizes under low request rates, and thus + // avoid latency spikes. + int64_t batch_timeout_micros = 0; + + // The maximum allowable number of enqueued (accepted by Schedule() but + // not yet being processed on a batch thread) tasks in terms of batches. + // If this limit is reached, Schedule() will return an UNAVAILABLE error. + // See the class documentation above for guidelines on how to tune this + // parameter. + // + // Must be positive, or else invalid argument error will be returned at + // queue creation time. + size_t max_enqueued_batches = 10; + + // If true, queue implementation would split one input batch task into + // subtasks (as specified by `split_input_task_func` below) and fit subtasks + // into different batches. + // + // For usage of `split_input_task_func`, please see its comment. + bool enable_large_batch_splitting = false; + + // `input_task`: a unit of task to be split. + // `first_output_task_size`: task size of first output. + // `max_execution_batch_size`: Maximum size of each batch. + // `output_tasks`: A list of output tasks after split. + // + // REQUIRED: + // 1) All `output_tasks` should be non-empty tasks. + // 2) Sizes of `output_tasks` add up to size of `input_task`. + // + // NOTE: + // Instantiations of `TaskType` may vary, so it's up to caller to define + // how (e.g., which members to access) to split input tasks. + std::function* input_task, int first_output_task_size, + int input_batch_size_limit, + std::vector>* output_tasks)> + split_input_task_func; + + // The maximum size of each enqueued batch (i.e., in + // `high_priority_batches_`). 
+ // + // The scheduler may form batches of any size between 1 and this number + // (inclusive). If there is a need to quantize the batch sizes, i.e. only + // submit batches whose size is in a small set of allowed sizes, that can be + // done by adding padding in the process-batch callback. + size_t max_execution_batch_size = 1000; + + // If non-empty, contains configured batch sizes. + std::vector allowed_batch_sizes; + + // If true, the padding will not be appended. + bool disable_padding = false; + + // The padding policy to use. + // + // See the documentation for kPadUpPolicy for details. + string batch_padding_policy = string(kPadUpPolicy); + + // A pointer to a ModelBatchStats instance for this model. To be used for + // cost-based padding policy selection. + // + // If null, some other padding policy will be used if a cost-based one is + // requested. + ModelBatchStats* model_batch_stats = nullptr; + + // If true, queue implementation would split high priority and low priority + // inputs into two sub queues. + bool enable_priority_queue = false; + + // A separate set of queue options for different priority inputs. + // Use iff `enable_priority_queue` is true. + struct PriorityQueueOptions { + // See QueueOptions.max_execution_batch_size + size_t max_execution_batch_size = 0; + // See QueueOptions.batch_timeout_micros + int64_t batch_timeout_micros = 0; + // See QueueOptions.input_batch_size_limit + size_t input_batch_size_limit = 0; + // See QueueOptions.max_enqueued_batches + size_t max_enqueued_batches = 0; + // See QueueOptions.allowed_batch_sizes + std::vector allowed_batch_sizes; + }; + // A subset of queue options for high priority input. These options are + // currently not being used in favor of the equivalents options at the + // QueueOptions level. + PriorityQueueOptions high_priority_queue_options; + // A subset of queue options for low priority input. + PriorityQueueOptions low_priority_queue_options; + + // A policy that determines the mixed priority batching behavior. It is + // effective only when enable_priority_queue is true. + MixedPriorityBatchingPolicy mixed_priority_batching_policy = + MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize; + }; + // This method is marked virtual for testing purposes only. + virtual absl::Status AddQueue( + const QueueOptions& options, ProcessBatchCallback process_batch_callback, + std::unique_ptr>* queue); + + protected: + explicit SharedBatchScheduler(const Options& options); + + private: + void GetNextWorkItem_Locked(internal::Queue** queue_for_batch_out, + BatchTaskUniquePtr* batch_to_process_out) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // The code executed in 'batch_threads_'. Obtains a batch to process from the + // queue pointed to by 'next_queue_to_schedule_', and processes it. If that + // queue declines to provide a batch to process, moves onto the next queue. If + // no queues provide a batch to process, just sleeps briefly and exits. + void ThreadLogic(); + + // Called by `AddQueue`. + absl::Status AddQueueAfterRewritingOptions( + const QueueOptions& options, ProcessBatchCallback process_batch_callback, + std::unique_ptr>* queue); + + static bool BatchExists(const BatchTaskUniquePtr& batch_to_process); + + const Options options_; + + mutex mu_; + + // A list of queues. (We use std::list instead of std::vector to ensure that + // iterators are not invalidated by adding/removing elements. It also offers + // efficient removal of elements from the middle.) 
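+  // For illustration, a sketch of the intended call pattern for the API
+  // declared above. MyTask, ProcessBatch and MakeTask are hypothetical, and
+  // the template arguments are restored by assumption:
+  //
+  //   std::shared_ptr<SharedBatchScheduler<MyTask>> scheduler;
+  //   SharedBatchScheduler<MyTask>::Options opt;
+  //   opt.num_batch_threads = 4;  // e.g. one per accelerator
+  //   TF_RETURN_IF_ERROR(SharedBatchScheduler<MyTask>::Create(opt, &scheduler));
+  //
+  //   SharedBatchScheduler<MyTask>::QueueOptions qopt;
+  //   qopt.input_batch_size_limit = 64;
+  //   qopt.batch_timeout_micros = 1000;  // allow a partial batch after 1ms
+  //   qopt.max_enqueued_batches = 4;
+  //   std::unique_ptr<BatchScheduler<MyTask>> queue;
+  //   TF_RETURN_IF_ERROR(scheduler->AddQueue(
+  //       qopt,
+  //       [](std::unique_ptr<Batch<MyTask>> batch) {
+  //         ProcessBatch(std::move(batch));
+  //       },
+  //       &queue));
+  //
+  //   // Tasks submitted to `queue` are batched and handed to the callback on
+  //   // one of the shared batch threads.
+  //   std::unique_ptr<MyTask> task = MakeTask();  // hypothetical
+  //   TF_RETURN_IF_ERROR(queue->Schedule(&task));
+  //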
+ using QueueList = std::list>>; + + // All "active" queues, i.e. ones that either: + // - have not been removed, or + // - have been removed but are not yet empty. + QueueList queues_ TF_GUARDED_BY(mu_); + + // An iterator over 'queues_', pointing to the queue from which the next + // available batch thread should grab work. + typename QueueList::iterator next_queue_to_schedule_ TF_GUARDED_BY(mu_); + + // Used by idle batch threads to wait for work to enter the system. Notified + // whenever a batch becomes schedulable. + condition_variable schedulable_batch_cv_; + + // Threads that process batches obtained from the queues. + std::vector> batch_threads_; + + SharedBatchScheduler(const SharedBatchScheduler&) = delete; + void operator=(const SharedBatchScheduler&) = delete; +}; + +////////// +// Implementation details follow. API users need not read. + +namespace internal { + +// A task queue for SharedBatchScheduler. Accepts tasks and accumulates them +// into batches, and dispenses those batches to be processed via a "pull" +// interface. The queue's behavior is governed by maximum batch size, timeout +// and maximum queue length parameters; see their documentation in +// SharedBatchScheduler. +// +// The queue is implemented as a deque of batches, with these invariants: +// - The number of batches is between 1 and 'options_.max_enqueued_batches'. +// - The back-most batch is open; the rest are closed. +// +// Submitted tasks are added to the open batch. If that batch doesn't have room +// but the queue isn't full, then that batch is closed and a new open batch is +// started. +// +// Batch pull requests are handled by dequeuing the front-most batch if it is +// closed. If the front-most batch is open (i.e. the queue contains only one +// batch) and has reached the timeout, it is immediately closed and returned; +// otherwise no batch is returned for the request. +template +class Queue { + public: + using ProcessBatchCallbackWithoutPaddingTasks = + std::function>)>; + using ProcessBatchCallbackWithPaddingTasks = + std::function>, + std::vector>)>; + using ProcessBatchCallback = + std::variant; + + using SchedulableBatchCallback = std::function; + using SplitInputTaskIntoSubtasksCallback = std::function* input_task, int open_batch_remaining_slot, + int max_execution_batch_size, + std::vector>* output_tasks)>; + // Orderable key representing the priority of a batch. Higher priority + // batches will be prioritized for execution first (when using + // rank_queues=true). + // - A smaller key value is higher priority than a larger one. + // - This is a pair formed from . The exact values + // used are an implementation detail of PeekBatchPriority(). + using BatchPriorityKey = std::pair; + + Queue(const typename SharedBatchScheduler::QueueOptions& options, + Env* env, ProcessBatchCallback process_batch_callback, + SchedulableBatchCallback schedulable_batch_callback); + + // Illegal to destruct unless the queue is empty. + ~Queue(); + + // Submits a task to the queue, with the same semantics as + // BatchScheduler::Schedule(). + absl::Status Schedule(std::unique_ptr* task); + + // Returns the number of enqueued tasks, with the same semantics as + // BatchScheduler::NumEnqueuedTasks(). + size_t NumEnqueuedTasks() const; + + // Returns the queue capacity, with the same semantics as + // BatchScheduler::SchedulingCapacity(). + size_t SchedulingCapacity() const; + + // Returns the maximum allowed size of tasks submitted to the queue. 
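+  // Worked example of the deque invariants described above (no splitting,
+  // max_execution_batch_size = 4, illustrative): scheduling tasks of sizes
+  // 3, 3, 2 leaves the deque as [closed:{3}], [closed:{3}], [open:{2}]. A
+  // batch thread then pulls the front-most closed batch; the trailing open
+  // batch of size 2 is returned only once it is closed, either because a
+  // later task does not fit and starts a new batch behind it, or because
+  // batch_timeout_micros elapses while it is the only batch in the queue.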
+ size_t max_task_size() const { return options_.input_batch_size_limit; } + + // Returns the maximum allowed size of tasks to be executed. + // Returned value would be less than or equal to the maximum allowed input + // size that's provided by caller of batch scheduler. + size_t max_execution_batch_size() const { return max_execution_batch_size_; } + + // Called by a thread that is ready to process a batch, to request one from + // this queue. Either returns a batch that is ready to be processed, or + // nullptr if the queue declines to schedule a batch at this time. If it + // returns a batch, the batch is guaranteed to be closed. + typename SharedBatchScheduler::BatchTaskUniquePtr ScheduleBatch(); + + // Without mutating the queue, checks if ScheduleBatch() will return a valid + // batch and if so will return the priority of that batch. + std::optional PeekBatchPriority() const; + + // Retrieves the low priority tasks that can be padded to a high priority + // batch of the specified size. + std::vector> GetLowPriorityTasksForPadding( + size_t batch_size); + + // Processes a batch that has been returned earlier by ScheduleBatch(). + void ProcessBatch(std::unique_ptr> batch, + std::vector> padding_task); + + // Determines whether the queue is empty, i.e. has no tasks waiting or being + // processed. + bool IsEmpty() const; + + // Marks the queue closed, and waits until it is empty. + void CloseAndWaitUntilEmpty(); + + bool closed() const TF_NO_THREAD_SAFETY_ANALYSIS { return closed_.load(); } + + private: + // Computes the max_execution_batch_size of the queue based on queue options. + static size_t GetMaxExecutionBatchSize( + const typename SharedBatchScheduler::QueueOptions& options) { + // If `enable_large_batch_splitting`, returns `max_execution_batch_size` + // configured by user options directly; returns `input_batch_size_limit` + // otherwise. + // + // Note `input_batch_size_limit` is used for backward compatibitliy -> + // users may not specify `max_execution_batch_size` explicitly. + if (options.enable_large_batch_splitting) { + return options.max_execution_batch_size; + } else { + return options.input_batch_size_limit; + } + } + + // Same as IsEmpty(), but assumes the caller already holds a lock on 'mu_'. + bool IsEmptyInternal() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns true iff the task is a low priority task based on the queue option. + bool IsLowPriorityTask(std::unique_ptr* task); + + // Implementation of Schedule above. Enqueues `task` as it + // is or split it inline (eagerly) to form batches to be processed by + // `Queue::ProcessBatch` + absl::Status ScheduleWithoutOrEagerSplitImpl(std::unique_ptr* task) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Pads the open batch until it is full with low priority tasks. + void PadOpenBatchWithLowPriorityTasks() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Closes the open batch residing at the back of std::deque, and inserts a + // fresh open batch behind it. + void StartNewBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Split `input task` into `output_tasks` according to 'task_sizes'. + absl::Status SplitInputBatchIntoSubtasks( + std::unique_ptr* input_task, + std::vector>* output_tasks) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Determines whether the open batch residing at the back of + // 'high_priority_batches_' is currently schedulable. 
+ bool IsOpenBatchSchedulable() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + std::optional PeekBatchPriorityImpl() const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Determines whether the low priority tasks in `low_priority_tasks_` can form + // a batch on their own. If yes, returns a batch that is ready to be + // processed. Otherwise, returns an empty unique_ptr. + std::unique_ptr> ScheduleLowPriorityBatch() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Same as SchedulingCapacity(), but assumes the caller already holds a + // lock on 'mu_'. + size_t SchedulingCapacityInternal() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns an error if queue doesn't have capacity for this task. + // + // `task` must outlive this method. + absl::Status ValidateBatchTaskQueueCapacity(TaskType* task) const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns an error if the low priority task queue doesn't have capacity for + // this task using the low priority batch options. Since the low priority + // tasks are not batched until they get scheduled, it only checks that a + // single task does not it exceed input batch size limit and the total size of + // the tasks in the queue does not exceed the max batch size * max enqueued + // batch sizes. + absl::Status ValidateLowPriorityTaskQueueCapacity(const TaskType& task) const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // The task size of the last batch in the queue. + size_t tail_batch_task_size() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the number of enqueued batches. + int64 num_enqueued_batches() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Gets the appropriate batches. + std::deque>>& GetBatches() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Gets the appropriate batches (const version). + const std::deque>>& GetBatches() const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Gets the low priority task queue. + TaskQueue& GetLowPriorityTaskQueue() + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Retrieves the tasks up to the specified size from the low priority task + // queue. It will immediately return an empty vector when + // enable_priority_queue is false. + std::vector> GetLowPriorityTasks(size_t size); + + const typename SharedBatchScheduler::QueueOptions options_; + + // The environment to use. + Env* env_; + + // The maximum batch size to be executed by `Queue::ProcessBatch`. + // See the comment of QueueOptions and helper function + // `GetMaxExecutionBatchSize` for more details on what it means. + const size_t max_execution_batch_size_; + + // A callback invoked to processes a batch of work units. Always invoked + // from a batch thread. + ProcessBatchCallback process_batch_callback_; + + // A callback invoked to notify the scheduler that a new batch has become + // schedulable. + SchedulableBatchCallback schedulable_batch_callback_; + + mutable mutex mu_; + + // Whether this queue can accept new tasks. This variable is monotonic: it + // starts as false, and then at some point gets set to true and remains true + // for the duration of this object's life. + std::atomic closed_ TF_GUARDED_BY(mu_){false}; + + // The enqueued tasks for low priority inputs. + // Each element corresponds to a task to be dequeued. These tasks to be + // consumed by `Queue::ProcessBatch` to either pad the high priority + // batches below or form their own batch to be executed. + TaskQueue low_priority_tasks_ TF_GUARDED_BY(mu_); + + // The enqueued batches for high priority input. 
+ // Each element corresponds to a task to be dequeued and processed by + // `Queue::ProcessBatch`. + std::deque>> high_priority_batches_ + TF_GUARDED_BY(mu_); + + // The counter of the TraceMe context ids. + uint64 traceme_context_id_counter_ TF_GUARDED_BY(mu_) = 0; + + // The time at which the first task was added to the open (back-most) batch + // in 'high_priority_batches_'. Valid iff that batch contains at least one + // task. + // + // Note that when using a batch padding policy other than PAD_UP, this field + // might contain an approximate value. + uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_); + + // Whether this queue contains a batch that is eligible to be scheduled. + // Used to keep track of when to call 'schedulable_batch_callback_'. + bool schedulable_batch_ TF_GUARDED_BY(mu_) = false; + + // The number of batches currently being processed by batch threads. + // Incremented in ScheduleBatch() and decremented in ProcessBatch(). + int num_batches_being_processed_ TF_GUARDED_BY(mu_) = 0; + + // Used by CloseAndWaitUntilEmpty() to wait until the queue is empty, for + // the case in which the queue is not empty when CloseAndWaitUntilEmpty() + // starts. When ProcessBatch() dequeues the last batch and makes the queue + // empty, if 'empty_notification_' is non-null it calls + // 'empty_notification_->Notify()'. + Notification* empty_notification_ TF_GUARDED_BY(mu_) = nullptr; + + Queue(const Queue&) = delete; + void operator=(const Queue&) = delete; +}; + +// A RAII-style object that points to a Queue and implements +// the BatchScheduler API. To be handed out to clients who call AddQueue(). +template +class QueueHandle : public BatchScheduler { + public: + QueueHandle(std::shared_ptr> scheduler, + Queue* queue); + ~QueueHandle() override; + + absl::Status Schedule(std::unique_ptr* task) override; + size_t NumEnqueuedTasks() const override; + size_t SchedulingCapacity() const override; + + size_t max_task_size() const override { return queue_->max_task_size(); } + + private: + // The scheduler that owns 'queue_'. + std::shared_ptr> scheduler_; + + // The queue this handle wraps. Owned by 'scheduler_', which keeps it alive at + // least until this class's destructor closes it. + Queue* queue_; + + QueueHandle(const QueueHandle&) = delete; + void operator=(const QueueHandle&) = delete; +}; + +} // namespace internal + +template +absl::Status SharedBatchScheduler::Create( + const Options& options, + std::shared_ptr>* scheduler) { + if (options.num_batch_threads < 1) { + return errors::InvalidArgument("num_batch_threads must be positive; was ", + options.num_batch_threads); + } + + if (options.use_global_scheduler) { + static std::shared_ptr>* global_scheduler = + [&]() { + return new std::shared_ptr>( + new SharedBatchScheduler(options)); + }(); + *scheduler = *global_scheduler; + return absl::OkStatus(); + } + + scheduler->reset(new SharedBatchScheduler(options)); + return absl::OkStatus(); +} + +template +SharedBatchScheduler::~SharedBatchScheduler() { + // Wait until the batch threads finish clearing out and deleting the closed + // queues. + for (;;) { + { + mutex_lock l(mu_); + if (queues_.empty()) { + break; + } + } + const int64_t kSleepTimeMicros = 100; + options_.env->SleepForMicroseconds(kSleepTimeMicros); + } + // Delete the batch threads before allowing state the threads may access (e.g. + // 'mu_') to be deleted. 
+ batch_threads_.clear(); +} + +template +absl::Status SharedBatchScheduler::AddQueue( + const QueueOptions& options, ProcessBatchCallback process_batch_callback, + std::unique_ptr>* queue) { + QueueOptions rewrite_options = options; + if ((!rewrite_options.enable_large_batch_splitting) && + rewrite_options.max_enqueued_batches == 0) { + // Many existing models (with very low QPS) rely on this option to be >0. + // Rewrite and set this to one and retain old behavior to allow such models + // to continue to work. + // + // Note, technically an invalid-argument error should be returned, but + // that may break such models. + rewrite_options.max_enqueued_batches = 1; + } + return AddQueueAfterRewritingOptions(rewrite_options, process_batch_callback, + queue); +} + +template +absl::Status SharedBatchScheduler::AddQueueAfterRewritingOptions( + const QueueOptions& options, ProcessBatchCallback process_batch_callback, + std::unique_ptr>* queue) { + if (options.input_batch_size_limit == 0) { + return errors::InvalidArgument( + "input_batch_size_limit must be positive; was ", + options.input_batch_size_limit); + } + if (options.batch_timeout_micros < 0) { + return errors::InvalidArgument( + "batch_timeout_micros must be non-negative; was ", + options.batch_timeout_micros); + } + if (options.max_enqueued_batches == 0) { + return errors::InvalidArgument( + "max_enqueued_batches must be positive; was ", + options.max_enqueued_batches); + } + + if (options.enable_large_batch_splitting && + options.split_input_task_func == nullptr) { + return errors::InvalidArgument( + "split_input_task_func must be specified when split_input_task is " + "true: ", + options.enable_large_batch_splitting); + } + + if (options.enable_large_batch_splitting && + (options.input_batch_size_limit < options.max_execution_batch_size)) { + return errors::InvalidArgument( + "When enable_large_batch_splitting is true, input_batch_size_limit " + "must be " + "greater than or equal to max_execution_batch_size.", + options.enable_large_batch_splitting, options.input_batch_size_limit, + options.max_execution_batch_size); + } + + auto schedulable_batch_callback = [this] { + mutex_lock l(mu_); + schedulable_batch_cv_.notify_one(); + }; + auto internal_queue = + std::unique_ptr>(new internal::Queue( + options, options_.env, process_batch_callback, + schedulable_batch_callback)); + auto handle = std::unique_ptr>( + new internal::QueueHandle(this->shared_from_this(), + internal_queue.get())); + { + mutex_lock l(mu_); + queues_.push_back(std::move(internal_queue)); + if (next_queue_to_schedule_ == queues_.end()) { + next_queue_to_schedule_ = queues_.begin(); + } + } + *queue = std::move(handle); + return absl::OkStatus(); +} + +template +SharedBatchScheduler::SharedBatchScheduler(const Options& options) + : options_(options), next_queue_to_schedule_(queues_.end()) { + // Kick off the batch threads. 
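+  // Each batch thread is a PeriodicFunction with a zero interval, so
+  // ThreadLogic() is re-invoked as soon as the previous invocation returns;
+  // idle waiting happens inside ThreadLogic() (on schedulable_batch_cv_, with
+  // a short timeout) rather than in PeriodicFunction itself.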
+ PeriodicFunction::Options periodic_fn_options; + periodic_fn_options.thread_name_prefix = + strings::StrCat(options.thread_pool_name, "_"); + for (int i = 0; i < options.num_batch_threads; ++i) { + std::unique_ptr thread(new PeriodicFunction( + [this] { this->ThreadLogic(); }, + 0 /* function invocation interval time */, periodic_fn_options)); + batch_threads_.push_back(std::move(thread)); + } +} + +template +bool SharedBatchScheduler::BatchExists( + const BatchTaskUniquePtr& batch_to_process) { + return batch_to_process != nullptr; +} + +template +void SharedBatchScheduler::GetNextWorkItem_Locked( + internal::Queue** queue_for_batch_out, + BatchTaskUniquePtr* batch_to_process_out) { + BatchTaskUniquePtr batch_to_process; + internal::Queue* queue_for_batch = nullptr; + std::optional::BatchPriorityKey> + batch_priority_key; + const int num_queues = queues_.size(); + for (int num_queues_tried = 0; + !BatchExists(batch_to_process) && num_queues_tried < num_queues; + ++num_queues_tried) { + DCHECK(next_queue_to_schedule_ != queues_.end()); + + // If a closed queue responds to ScheduleBatch() with nullptr, the queue + // will never yield any further batches so we can drop it. To avoid a + // race, we take a snapshot of the queue's closedness state *before* + // calling ScheduleBatch(). + const bool queue_closed = (*next_queue_to_schedule_)->closed(); + + bool queue_has_work = false; + + if (options_.rank_queues) { + auto key = (*next_queue_to_schedule_)->PeekBatchPriority(); + queue_has_work = key.has_value(); + if (key.has_value() && (!batch_priority_key.has_value() || + key.value() < batch_priority_key.value())) { + batch_priority_key = key; + queue_for_batch = next_queue_to_schedule_->get(); + } + } else { + // Ask '*next_queue_to_schedule_' if it wants us to process a batch. + batch_to_process = (*next_queue_to_schedule_)->ScheduleBatch(); + queue_has_work = BatchExists(batch_to_process); + + if (queue_has_work) { + queue_for_batch = next_queue_to_schedule_->get(); + } + } + + // Advance 'next_queue_to_schedule_'. + if (queue_closed && (*next_queue_to_schedule_)->IsEmpty() && + !queue_has_work) { + // We've encountered a closed queue with no work to do. Drop it. + DCHECK_NE(queue_for_batch, next_queue_to_schedule_->get()); + next_queue_to_schedule_ = queues_.erase(next_queue_to_schedule_); + } else { + ++next_queue_to_schedule_; + } + if (next_queue_to_schedule_ == queues_.end() && !queues_.empty()) { + // We've hit the end. Wrap to the first queue. + next_queue_to_schedule_ = queues_.begin(); + } + } + + if (options_.rank_queues && batch_priority_key.has_value()) { + batch_to_process = queue_for_batch->ScheduleBatch(); + } + + *queue_for_batch_out = queue_for_batch; + *batch_to_process_out = std::move(batch_to_process); +} + +template +void SharedBatchScheduler::ThreadLogic() { + // A batch to process next (or nullptr if no work to do). + BatchTaskUniquePtr batch_to_process; + // The queue with which 'batch_to_process' is associated. + internal::Queue* queue_for_batch = nullptr; + { + mutex_lock l(mu_); + while (true) { + GetNextWorkItem_Locked(&queue_for_batch, &batch_to_process); + if (BatchExists(batch_to_process)) break; + // We couldn't find any work to do. Wait until a new batch becomes + // schedulable, or some time has elapsed, before checking again. + const int64_t kTimeoutMillis = + 1; // The smallest accepted granule of time. 
+ WaitForMilliseconds(&l, &schedulable_batch_cv_, kTimeoutMillis); + if (queues_.empty()) return; + } + } + + size_t batch_size_to_schedule = batch_to_process->size(); + queue_for_batch->ProcessBatch( + std::move(batch_to_process), + queue_for_batch->GetLowPriorityTasksForPadding(batch_size_to_schedule)); +} + +namespace internal { + +template +Queue::Queue( + const typename SharedBatchScheduler::QueueOptions& options, + Env* env, ProcessBatchCallback process_batch_callback, + SchedulableBatchCallback schedulable_batch_callback) + : options_(options), + env_(env), + max_execution_batch_size_(GetMaxExecutionBatchSize(options_)), + process_batch_callback_(process_batch_callback), + schedulable_batch_callback_(schedulable_batch_callback) { + // Set the higher 32 bits of traceme_context_id_counter_ to be the creation + // time of the queue. This prevents the batches in different queues to have + // the same traceme_context_id_counter_. + traceme_context_id_counter_ = (absl::GetCurrentTimeNanos() & 0xFFFFFFFF) + << 32; + GetBatches().emplace_back(new Batch); +} + +template +Queue::~Queue() { + mutex_lock l(mu_); + DCHECK(IsEmptyInternal()); + GetBatches().back()->Close(); +} + +template +bool Queue::IsLowPriorityTask(std::unique_ptr* task) { + if (!options_.enable_priority_queue) { + return false; + } + + // The criticality is defined only when the task is a derived class of + // BatchTask. + if constexpr (std::is_base_of_v) { + // TODO(b/316379576): Make the criticality and priority configurable. + return ((*task)->criticality() == + tsl::criticality::Criticality::kSheddablePlus || + (*task)->criticality() == + tsl::criticality::Criticality::kSheddable); + } + + // Otherwise, consider it a high priority task and return false. + return false; +} + +template +absl::Status Queue::ScheduleWithoutOrEagerSplitImpl( + std::unique_ptr* task) { + // TODO(b/161857471): + // Add test coverage when when concurrent incoming batches arrives and + // use up all queue capacity. + TF_RETURN_IF_ERROR(ValidateBatchTaskQueueCapacity((*task).get())); + + std::deque>>& batches = GetBatches(); + + const int64_t open_batch_remaining_slot = + max_execution_batch_size() - batches.back()->size(); + + const int64_t input_task_size = (*task)->size(); + + std::vector> output_tasks; + + if (input_task_size <= open_batch_remaining_slot || + !options_.enable_large_batch_splitting) { + // This is the fast path when input doesn't need to be split. 
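+    // Worked example (illustrative): with max_execution_batch_size() = 8 and
+    // an open batch already holding 6, open_batch_remaining_slot is 2. A task
+    // of size 2 takes this fast path unchanged; with splitting enabled, a
+    // task of size 5 is instead handed to split_input_task_func with
+    // first_output_task_size = 2, typically yielding subtasks of sizes
+    // {2, 3}: the 2 completes the open batch and the 3 starts a new one.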
+ output_tasks.push_back(std::move(*task)); + } else { + TF_RETURN_IF_ERROR(SplitInputBatchIntoSubtasks(task, &output_tasks)); + } + + for (int i = 0; i < output_tasks.size(); ++i) { + if (batches.back()->size() + output_tasks[i]->size() > + max_execution_batch_size()) { + StartNewBatch(); + } + if (batches.back()->empty()) { + open_batch_start_time_micros_ = env_->NowMicros(); + } + tsl::profiler::TraceMeProducer trace_me( + [&output_tasks, i] { + return profiler::TraceMeEncode("ScheduleOutputTask", + {{"size", output_tasks[i]->size()}}); + }, + tsl::profiler::ContextType::kSharedBatchScheduler, + batches.back()->traceme_context_id()); + batches.back()->AddTask(std::move(output_tasks[i]), env_->NowMicros()); + } + + return absl::OkStatus(); +} + +template +void Queue::PadOpenBatchWithLowPriorityTasks() { + std::deque>>& batches = GetBatches(); + + const bool should_pad = options_.enable_priority_queue && + options_.mixed_priority_batching_policy == + MixedPriorityBatchingPolicy::kPriorityMerge && + batches.size() == 1 && IsOpenBatchSchedulable(); + if (!should_pad) { + return; + } + + // If true, the next low priority task couldn't fit in the remaining space of + // the open batch. + bool out_of_space = false; + + while (!low_priority_tasks_.empty() && !out_of_space) { + const int64_t open_batch_remaining_slot = + max_execution_batch_size() - batches.back()->size(); + if (open_batch_remaining_slot <= 0) { + // Terminate early if the open batch is full. Remaining low priority tasks + // will be re-checked during the next batch formation opportunity. + return; + } + + uint64 task_time = low_priority_tasks_.EarliestTaskStartTime().value(); + std::unique_ptr task = low_priority_tasks_.RemoveTask(); + + const int64_t input_task_size = task->size(); + + std::vector> output_tasks; + + if (input_task_size <= open_batch_remaining_slot || + !options_.enable_large_batch_splitting) { + // This is the fast path when input doesn't need to be split. + output_tasks.push_back(std::move(task)); + } else { + absl::Status status = SplitInputBatchIntoSubtasks(&task, &output_tasks); + if (!status.ok()) { + LOG(ERROR) << "Failed to split low priority task: " << status; + continue; + } + } + + for (int i = 0; i < output_tasks.size(); ++i) { + if (batches.back()->size() + output_tasks[i]->size() > + max_execution_batch_size()) { + low_priority_tasks_.PrependTask(std::move(output_tasks[i]), task_time); + out_of_space = true; + // NOTE: Future iterations of this loop will also hit this case but are + // needed to re-add all the unused tasks to the low priority queue. + continue; + } + + if (batches.back()->empty()) { + open_batch_start_time_micros_ = task_time; + } else { + open_batch_start_time_micros_ = + std::min(open_batch_start_time_micros_, task_time); + } + + tsl::profiler::TraceMeProducer trace_me( + [&output_tasks, i] { + return profiler::TraceMeEncode("ScheduleOutputTask", + {{"size", output_tasks[i]->size()}}); + }, + tsl::profiler::ContextType::kSharedBatchScheduler, + batches.back()->traceme_context_id()); + + batches.back()->AddTask(std::move(output_tasks[i])); + } + } +} + +template +absl::Status Queue::Schedule(std::unique_ptr* task) { + const bool large_batch_splitting = options_.enable_large_batch_splitting; + tsl::profiler::TraceMe trace_me([task, large_batch_splitting] { + return profiler::TraceMeEncode( + large_batch_splitting ? 
"ScheduleWithEagerSplit" + : "ScheduleWithoutSplit", + {{"batching_input_task_size", (*task)->size()}}); + }); + + bool notify_of_schedulable_batch = false; + { + mutex_lock l(mu_); + + DCHECK(!closed_); + + if (IsLowPriorityTask(task)) { + // Insert the task to the low priority task queue instead of the high + // priority batch queue below. + TF_RETURN_IF_ERROR(ValidateLowPriorityTaskQueueCapacity(**task)); + low_priority_tasks_.AddTask(std::move(*task), env_->NowMicros()); + } else { + TF_RETURN_IF_ERROR(ScheduleWithoutOrEagerSplitImpl(task)); + } + + // Check if the batch queue has a schedulable batch and mark it schedulable + // if it not already marked. + if (!schedulable_batch_) { + if (GetBatches().size() > 1 || IsOpenBatchSchedulable()) { + schedulable_batch_ = true; + notify_of_schedulable_batch = true; + } + } + } + + if (notify_of_schedulable_batch) { + schedulable_batch_callback_(); + } + + return absl::OkStatus(); +} + +template +size_t Queue::NumEnqueuedTasks() const { + size_t num_enqueued_tasks = 0; + mutex_lock l(mu_); + for (const auto& batch : GetBatches()) { + num_enqueued_tasks += batch->num_tasks(); + } + return num_enqueued_tasks + low_priority_tasks_.num_tasks(); +} + +template +size_t Queue::SchedulingCapacity() const { + mutex_lock l(mu_); + return SchedulingCapacityInternal(); +} + +template +size_t Queue::SchedulingCapacityInternal() const { + const int64 num_new_batches_schedulable = + static_cast(options_.max_enqueued_batches) - + this->num_enqueued_batches(); + const int64 execution_batch_size_limit = max_execution_batch_size(); + const int64 open_batch_capacity = + execution_batch_size_limit - this->tail_batch_task_size(); + // Note the returned value is guaranteed to be not negative, since + // enqueue operation could only happen if queue has enough capacity. + return (num_new_batches_schedulable * execution_batch_size_limit) + + open_batch_capacity; +} + +template +absl::Status Queue::ValidateBatchTaskQueueCapacity( + TaskType* task) const { + // Check if the task size is larger than the batch size limit, regardless of + // the batch capacity. + if (task->size() > options_.input_batch_size_limit) { + return absl::InvalidArgumentError(absl::StrFormat( + "Task size %d is larger than maximum input batch size %d", task->size(), + options_.input_batch_size_limit)); + } + + if (options_.enable_large_batch_splitting) { + if (task->size() > SchedulingCapacityInternal()) { + return errors::Unavailable( + "The batch scheduling queue to which this task was submitted is " + "full; task size is ", + task->size(), " but scheduling capacity is only ", + SchedulingCapacityInternal(), + " (num_enqueued_batches=", num_enqueued_batches(), + ", max_enqueued_batches=", options_.max_enqueued_batches, + ", open_batch_size=", tail_batch_task_size(), + ", max_execution_batch_size=", max_execution_batch_size(), ")"); + } + return absl::OkStatus(); + } + + // NOTE, the capacity checking below is loose and is retained + // for backward compatibility that was broken due to the merge of no-split + // and eager split. + // There are existing clients/models that rely on the loose check + // and can get errors after the merge. Retaining the old behavior + // allows such models to continue to work. + // + // We need to revisit/remove this check after we fix model configs. 
+ const std::deque>>& batches = GetBatches(); + if (batches.back()->size() + task->size() > options_.input_batch_size_limit) { + if (batches.size() >= options_.max_enqueued_batches) { + return errors::Unavailable( + "The batch scheduling queue to which this task was submitted is " + "full; currently ", + batches.size(), " batches enqueued and max_enqueued_batches is ", + options_.max_enqueued_batches); + } + } + return absl::OkStatus(); +} + +template +absl::Status Queue::ValidateLowPriorityTaskQueueCapacity( + const TaskType& task) const { + // Unlike the high priority batch capacity validation where having only + // input_batch_size_limit without max_execution_batch_size is allowed, it + // doesn't have the backward compatibility check and always assume that + // max_execution_batch_size is present. + if (task.size() > + options_.low_priority_queue_options.max_execution_batch_size) { + return absl::UnavailableError(absl::StrFormat( + "The low priority task queue to which this task was submitted has " + "max_execution_batch_size=%d and the task size is %d", + options_.low_priority_queue_options.max_execution_batch_size, + task.size())); + } + if (low_priority_tasks_.size() + task.size() > + options_.low_priority_queue_options.max_enqueued_batches * + options_.low_priority_queue_options.max_execution_batch_size) { + return absl::UnavailableError(absl::StrFormat( + "The low priority task queue to which this task was submitted does not " + "have the capacity to handle this task; currently the low priority " + "queue has %d tasks enqueued and the submitted task size is %d while " + "max_enqueued_batches=%d and max_execution_batch_size=%d", + low_priority_tasks_.size(), task.size(), + options_.low_priority_queue_options.max_enqueued_batches, + options_.low_priority_queue_options.max_execution_batch_size)); + } + return absl::OkStatus(); +} + +template +typename SharedBatchScheduler::BatchTaskUniquePtr +Queue::ScheduleBatch() { + // The batch to schedule, which we may populate below. (If left as nullptr, + // that means we are electing not to schedule a batch at this time.) + std::unique_ptr> batch_to_schedule; + + { + mutex_lock l(mu_); + + std::deque>>& batches = GetBatches(); + + // Just in time merging of low priority tasks into the open batch. + PadOpenBatchWithLowPriorityTasks(); + + // Consider closing the open batch at this time, to schedule it. + if (batches.size() == 1 && IsOpenBatchSchedulable()) { + // Support BatchPaddingPolicy::kBatchDown and + // BatchPaddingPolicy::kMinimizeTpuCostPerRequest. We do this before + // starting a new batch because starting a new batch will close the old + // batch, making it read-only. + Batch& old_batch = *batches[0]; + uint64 old_batch_time = old_batch.EarliestTaskStartTime().value(); + std::vector> trimmed_tasks; + MaybeBatchDown( + /* batch= */ old_batch, + /* allowed_batch_sizes= */ options_.allowed_batch_sizes, + /* disable_padding= */ options_.disable_padding, + /* batch_padding_policy= */ options_.batch_padding_policy, + /* model_batch_stats= */ options_.model_batch_stats, + /* out_trimmed_tasks= */ trimmed_tasks); + + StartNewBatch(); + + // Move the trimmed tasks, if any, into the new batch. + Batch& new_batch = *batches[1]; + for (std::unique_ptr& task : trimmed_tasks) { + new_batch.AddTask(std::move(task), old_batch_time); + } + if (!new_batch.empty()) { + // TODO - b/325954758: Reconsider the starting time of a trimmed batch. 
+ // + // Ideally, we'd set open_batch_start_time_micros_ to time we received + // the first task in the open batch, but we don't have this information + // here. For now, we're trying as alternative solution that doesn't + // require adding time to each task: assume that requests arrived at a + // steady rate and therefore use a point between the old value of + // open_batch_start_time_micros_ and NOW. + // + // Let's say that originally, the batch had 10 requests, and we want to + // schedule a batch of size 8 and leave 2 requests in the open batch + // (new_batch). Then, variable `position` is 0.8, which means we have to + // set open_batch_start_time_micros_ to be at a position of 80% between + // open_batch_start_time_micros_ and now. + double position = static_cast(old_batch.size()) / + (old_batch.size() + new_batch.size()); + open_batch_start_time_micros_ += + (env_->NowMicros() - open_batch_start_time_micros_) * position; + } + } + + if (batches.size() >= 2) { + // There is at least one closed batch that is ready to be scheduled. + batch_to_schedule = std::move(batches.front()); + batches.pop_front(); + } + + if (batch_to_schedule == nullptr) { + // If there was no schedulable batch in the batch queue, try to schedule + // from the low priority task queue. + batch_to_schedule = ScheduleLowPriorityBatch(); + } + + if (batch_to_schedule == nullptr) { + // There is neither high nor low priority batch that can be scheduled, + // mark the condition false and return the nullptr. + schedulable_batch_ = false; + return batch_to_schedule; + } + + // Otherwise, increment the counter and return the batch. + ++num_batches_being_processed_; + } + return batch_to_schedule; +} + +template +std::vector> Queue::GetLowPriorityTasks( + size_t size) { + std::vector> low_priority_tasks_to_pad; + // If priority queue is not enabled, immediately return instead of attempting + // to acquire a lock. 
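// Self-contained sketch (illustrative only) of the start-time interpolation
// described in the TODO above: after trimming a batch, the new open batch is
// assumed to have started part-way between the old start time and "now".
#include <cstdint>
#include <iostream>

uint64_t InterpolatedOpenBatchStartTime(uint64_t old_start_time_micros,
                                        uint64_t now_micros,
                                        int64_t scheduled_batch_size,
                                        int64_t remaining_batch_size) {
  const double position =
      static_cast<double>(scheduled_batch_size) /
      static_cast<double>(scheduled_batch_size + remaining_batch_size);
  return old_start_time_micros +
         static_cast<uint64_t>((now_micros - old_start_time_micros) * position);
}

int main() {
  // Old start at t=1000us, now t=2000us, 8 of 10 tasks scheduled, 2 kept:
  // position = 0.8, so the new start time is 1000 + 1000 * 0.8 = 1800us.
  std::cout << InterpolatedOpenBatchStartTime(1000, 2000, 8, 2) << std::endl;
}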
+ if (!options_.enable_priority_queue || size == 0) + return low_priority_tasks_to_pad; + { + mutex_lock l(mu_); + low_priority_tasks_to_pad = GetLowPriorityTaskQueue().RemoveTask(size); + } + return low_priority_tasks_to_pad; +} + +template +std::vector> +Queue::GetLowPriorityTasksForPadding(size_t batch_size) { + size_t target_batch_size; + switch (options_.mixed_priority_batching_policy) { + case MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize: + target_batch_size = max_execution_batch_size(); + break; + case MixedPriorityBatchingPolicy:: + kLowPriorityPaddingWithNextAllowedBatchSize: + target_batch_size = GetNextAllowedBatchSize( + batch_size, options_.allowed_batch_sizes, options_.disable_padding); + break; + default: + target_batch_size = 0; + break; + } + + if (target_batch_size <= batch_size) { + return {}; + } + return GetLowPriorityTasks(target_batch_size - batch_size); +} + +template +void Queue::ProcessBatch( + std::unique_ptr> batch, + std::vector> padding_task) { + tsl::profiler::TraceMeConsumer trace_me( + [&] { + return profiler::TraceMeEncode( + "ProcessBatch", {{"batch_size_before_padding", batch->size()}, + {"_r", 2} /*root_event*/}); + }, + tsl::profiler::ContextType::kSharedBatchScheduler, + batch->traceme_context_id()); + + if (std::holds_alternative( + process_batch_callback_)) { + std::get(process_batch_callback_)( + std::move(batch)); + } else { + std::get(process_batch_callback_)( + std::move(batch), std::move(padding_task)); + } + + { + mutex_lock l(mu_); + --num_batches_being_processed_; + if (empty_notification_ != nullptr && IsEmptyInternal()) { + empty_notification_->Notify(); + } + } +} + +template +bool Queue::IsEmpty() const { + mutex_lock l(mu_); + return IsEmptyInternal(); +} + +template +void Queue::CloseAndWaitUntilEmpty() { + Notification empty; + { + mutex_lock l(mu_); + closed_ = true; + if (IsEmptyInternal()) { + empty.Notify(); + } else { + // Arrange for ProcessBatch() to notify when the queue becomes empty. 
+ empty_notification_ = ∅ + } + } + empty.WaitForNotification(); +} + +template +bool Queue::IsEmptyInternal() const { + const std::deque>>& batches = GetBatches(); + return num_batches_being_processed_ == 0 && batches.size() == 1 && + batches.back()->empty() && low_priority_tasks_.empty(); +} + +template +void Queue::StartNewBatch() { + std::deque>>& batches = GetBatches(); + batches.back()->Close(); + batches.emplace_back(new Batch(++traceme_context_id_counter_)); +} + +template +absl::Status Queue::SplitInputBatchIntoSubtasks( + std::unique_ptr* input_task, + std::vector>* output_tasks) { + const int open_batch_remaining_slot = + max_execution_batch_size() - this->tail_batch_task_size(); + return options_.split_input_task_func( + std::move(input_task), open_batch_remaining_slot, + max_execution_batch_size(), std::move(output_tasks)); +} + +template +bool Queue::IsOpenBatchSchedulable() const { + return PeekBatchPriorityImpl().has_value(); +} + +template +std::optional::BatchPriorityKey> +Queue::PeekBatchPriority() const { + { + mutex_lock l(mu_); + return PeekBatchPriorityImpl(); + } +} + +template +std::optional::BatchPriorityKey> +Queue::PeekBatchPriorityImpl() const { + const int kHighPriority = 1; + const int kLowPriority = 2; + + const std::deque>>& batches = GetBatches(); + + if (batches.size() >= 2) { + Batch* batch = batches.front().get(); + return std::make_pair(kHighPriority, + batch->EarliestTaskStartTime().value()); + } + + Batch* open_batch = batches.back().get(); + + size_t effective_batch_size = open_batch->size(); + uint64 effective_start_time_micros = open_batch_start_time_micros_; + int64_t effective_batch_timeout_micros = options_.batch_timeout_micros; + if (effective_batch_size == 0) { + // open_batch_start_time_micros_ is not valid for an empty batch. + effective_start_time_micros = env_->NowMicros(); + } + + if (options_.enable_priority_queue && + options_.mixed_priority_batching_policy == + MixedPriorityBatchingPolicy::kPriorityMerge) { + if (effective_batch_size == 0) { + effective_batch_timeout_micros = + options_.low_priority_queue_options.batch_timeout_micros; + } + + effective_batch_size += low_priority_tasks_.size(); + + auto low_priority_earliest_start_time = + low_priority_tasks_.EarliestTaskStartTime(); + if (low_priority_earliest_start_time.has_value()) { + effective_start_time_micros = std::min(effective_start_time_micros, + *low_priority_earliest_start_time); + } + } + + if (effective_batch_size == 0) { + return std::nullopt; + } + + bool schedulable = closed_ || + effective_batch_size >= max_execution_batch_size() || + env_->NowMicros() >= effective_start_time_micros + + effective_batch_timeout_micros; + + if (!schedulable) { + return std::nullopt; + } + + int priority = open_batch->empty() ? kLowPriority : kHighPriority; + return std::make_pair(priority, effective_start_time_micros); +} + +template +std::unique_ptr> Queue::ScheduleLowPriorityBatch() { + std::unique_ptr> batch_to_schedule; + if (!options_.enable_priority_queue || low_priority_tasks_.empty() || + options_.mixed_priority_batching_policy == + MixedPriorityBatchingPolicy::kPriorityMerge) { + // Return early if priority queue is disabled or there is no low priority + // task. Note that the priority_merge policy does all scheduling in + // ScheduleBatch(). 
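// Minimal, standalone restatement (with assumed parameter names) of the
// schedulability test computed in PeekBatchPriorityImpl() above: a non-empty
// open batch becomes schedulable once the queue closes, the batch is full,
// or its timeout has elapsed.
#include <cassert>
#include <cstdint>

bool IsBatchSchedulable(bool queue_closed, int64_t effective_batch_size,
                        int64_t max_execution_batch_size, uint64_t now_micros,
                        uint64_t effective_start_time_micros,
                        int64_t batch_timeout_micros) {
  if (effective_batch_size == 0) return false;  // nothing to schedule yet
  return queue_closed ||
         effective_batch_size >= max_execution_batch_size ||
         now_micros >= effective_start_time_micros + batch_timeout_micros;
}

int main() {
  // A half-full batch is not schedulable before its timeout elapses...
  assert(!IsBatchSchedulable(false, 8, 16, /*now_micros=*/1500,
                             /*effective_start_time_micros=*/1000,
                             /*batch_timeout_micros=*/1000));
  // ...but becomes schedulable afterwards.
  assert(IsBatchSchedulable(false, 8, 16, /*now_micros=*/2000,
                            /*effective_start_time_micros=*/1000,
                            /*batch_timeout_micros=*/1000));
}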
+ return batch_to_schedule; + } + if (env_->NowMicros() < + *low_priority_tasks_.EarliestTaskStartTime() + + options_.low_priority_queue_options.batch_timeout_micros && + low_priority_tasks_.size() < + options_.low_priority_queue_options.max_execution_batch_size) { + // Return early if the low priority tasks can't fill up the max batch size + // and the earliest task didn't time out. + return batch_to_schedule; + } + if (!GetBatches().empty() && !GetBatches().front()->empty()) { + // Return early if there is a non-empty high priority batch in the queue. + return batch_to_schedule; + } + + batch_to_schedule = std::make_unique>(); + for (std::unique_ptr& task : low_priority_tasks_.RemoveTask( + options_.low_priority_queue_options.max_execution_batch_size)) { + batch_to_schedule->AddTask(std::move(task), env_->NowMicros()); + } + batch_to_schedule->Close(); + + return batch_to_schedule; +} + +template +size_t Queue::tail_batch_task_size() const { + return GetBatches().back()->size(); +} + +template +int64 Queue::num_enqueued_batches() const { + return GetBatches().size(); +} + +template +std::deque>>& Queue::GetBatches() { + return high_priority_batches_; +} + +template +const std::deque>>& +Queue::GetBatches() const { + return high_priority_batches_; +} + +template +TaskQueue& Queue::GetLowPriorityTaskQueue() { + return low_priority_tasks_; +} + +template +QueueHandle::QueueHandle( + std::shared_ptr> scheduler, + Queue* queue) + : scheduler_(scheduler), queue_(queue) {} + +template +QueueHandle::~QueueHandle() { + queue_->CloseAndWaitUntilEmpty(); +} + +template +absl::Status QueueHandle::Schedule(std::unique_ptr* task) { + return queue_->Schedule(task); +} + +template +size_t QueueHandle::NumEnqueuedTasks() const { + return queue_->NumEnqueuedTasks(); +} + +template +size_t QueueHandle::SchedulingCapacity() const { + return queue_->SchedulingCapacity(); +} + +} // namespace internal + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SHARED_BATCH_SCHEDULER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/threadsafe_status.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/threadsafe_status.h new file mode 100644 index 00000000..68e94f70 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/threadsafe_status.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_THREADSAFE_STATUS_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_THREADSAFE_STATUS_H_ + +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +// Wrapper class to allow both lock-free construction and concurrent updates on +// a 'status'. 
+// +// Example Usage: +// std::thread threads[2]; +// ThreadSafeStatus thread_safe_status; +// threads[0] = std::thread([&]() { +// status.Update(errors::Internal("internal error")); +// }); +// threads[1] = std::thread([&]() { +// status.Update(errors::InvalidArgument("invalid argument")); +// }); +// threads[0].Join(); +// threads[1].Join(); +// +// NOTE: +// When updated in a multi-threading setup, only the first error is retained. +class ThreadSafeStatus { + public: + const absl::Status& status() const& TF_LOCKS_EXCLUDED(mutex_); + absl::Status status() && TF_LOCKS_EXCLUDED(mutex_); + + // Retains the first error status: replaces the current status with + // `new_status` if `new_status` is not OK and the previous status is OK. + void Update(const absl::Status& new_status) TF_LOCKS_EXCLUDED(mutex_); + void Update(absl::Status&& new_status) TF_LOCKS_EXCLUDED(mutex_); + + private: + mutable mutex mutex_; + absl::Status status_ TF_GUARDED_BY(mutex_); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_THREADSAFE_STATUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/warmup.h b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/warmup.h new file mode 100644 index 00000000..30e64795 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/batching_util/warmup.h @@ -0,0 +1,132 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_WARMUP_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_WARMUP_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tsl/platform/logging.h" + +namespace tensorflow { +namespace serving { + +// Global registry for model's warm-up states. Before a model executes warm-up +// requests, it is registered here so that the runtime can distinguish demand +// requests vs. warm-up requests and apply warm-up specific optimizations. +class WarmupStateRegistry { + public: + struct Key { + std::string name; + int64_t version; + + Key(std::string name, int64_t version) + : name(std::move(name)), version(version) {} + + template + friend H AbslHashValue(H state, const Key& key) { + return H::combine(std::move(state), key.name, key.version); + } + + friend bool operator==(const Key& x, const Key& y) { + return x.name == y.name && x.version == y.version; + } + }; + // Data stored per key. + struct PerModelData { + // If true, supported batch ops will execute the model on dummy batches + // for all `allowed_batch_sizes` of that batch op. This removes the + // need to issue separate warmup requests for each batch size. 
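// Corrected, self-contained version of the usage example in the comment
// above (the original snippet updates a variable named `status` that was
// never declared and calls Join() instead of join()). The include path is
// assumed from this header's location.
#include <thread>

#include "absl/status/status.h"
#include "tensorflow/core/kernels/batching_util/threadsafe_status.h"

void ThreadSafeStatusExample() {
  tensorflow::ThreadSafeStatus thread_safe_status;
  std::thread t0([&]() {
    thread_safe_status.Update(absl::InternalError("internal error"));
  });
  std::thread t1([&]() {
    thread_safe_status.Update(absl::InvalidArgumentError("invalid argument"));
  });
  t0.join();
  t1.join();
  // Whichever Update() ran first wins; the later non-OK status is dropped.
  const absl::Status& first_error = thread_safe_status.status();
  (void)first_error;
}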
+ bool warmup_all_batch_sizes = false; + }; + + // RAII handle for registered models. + class Handle { + public: + Handle() = default; + + Handle(const Handle& other) = delete; + Handle& operator=(const Handle& other) = delete; + Handle(Handle&& other) + : key_(std::move(other.key_)), registry_(other.registry_) { + other.key_.reset(); + } + Handle& operator=(Handle&& other) { + if (key_.has_value()) { + Release(); + } + + key_ = std::move(other.key_); + other.key_.reset(); + registry_ = other.registry_; + return *this; + } + + ~Handle() { Release(); } + + void Release(); + + private: + friend class WarmupStateRegistry; + + // Can only be constructed by `WarmupStateRegistry::Register()`. + Handle(const Key& key, WarmupStateRegistry* registry) + : key_(key), registry_(registry) { + DCHECK(registry_); + } + + std::optional key_; + WarmupStateRegistry* registry_ = nullptr; + }; + + // Registers the given model to be in a warm-up state and associates the given + // metadata with the model. Returns an RAII handle that unregisters the model + // at its destruction. + absl::StatusOr Register(const Key& model_key, + std::unique_ptr per_model_data); + + // Return model data. A nullptr indicates the key was not present. + const PerModelData* Lookup(const Key& model_key); + + private: + friend class Handle; + + void Unregister(const Key& model_key); + + absl::Mutex mu_; + // Map of model names/versions to miscellaneous data. + absl::flat_hash_map> states_ + ABSL_GUARDED_BY(&mu_); +}; + +WarmupStateRegistry& GetGlobalWarmupStateRegistry(); + +// Utility function that returns whether or not to warmup all batch sizes, +// based on the state of WarmupStateRegistry. +bool ShouldWarmupAllBatchSizes(const OpKernelContext* c); + +} // namespace serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_WARMUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/betainc_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/betainc_op.h new file mode 100644 index 00000000..c808e688 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/betainc_op.h @@ -0,0 +1,51 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_ +// Functor definition for BetaincOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by BetaincOp to do the computations. 
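// Hypothetical usage sketch for the WarmupStateRegistry declared above:
// register a model while its warm-up requests run and let the RAII Handle
// unregister it on scope exit. The model name/version are invented, and the
// include path is assumed from this header's location.
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tensorflow/core/kernels/batching_util/warmup.h"

absl::Status RunWarmup() {
  using tensorflow::serving::GetGlobalWarmupStateRegistry;
  using tensorflow::serving::WarmupStateRegistry;

  WarmupStateRegistry::Key key("my_model", /*version=*/1);
  auto data = std::make_unique<WarmupStateRegistry::PerModelData>();
  data->warmup_all_batch_sizes = true;

  absl::StatusOr<WarmupStateRegistry::Handle> handle =
      GetGlobalWarmupStateRegistry().Register(key, std::move(data));
  if (!handle.ok()) return handle.status();

  // ... issue warm-up requests here; batch ops can Lookup(key) to see that
  // "my_model" v1 is warming up and pad all allowed batch sizes ...

  return absl::OkStatus();  // `handle` unregisters the model on destruction
}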
+template +struct Betainc { + void operator()(const Device& d, typename TTypes::ConstTensor a, + typename TTypes::ConstTensor b, + typename TTypes::ConstTensor x, + typename TTypes::Tensor output) { + output.device(d) = Eigen::betainc(a, b, x); + } + + void BCast(const Device& d, typename TTypes::ConstTensor a, + const typename Eigen::array& bcast_a, + typename TTypes::ConstTensor b, + const typename Eigen::array& bcast_b, + typename TTypes::ConstTensor x, + const typename Eigen::array& bcast_x, + typename TTypes::Tensor output) { + output.device(d) = Eigen::betainc( + a.broadcast(bcast_a), b.broadcast(bcast_b), x.broadcast(bcast_x)); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/bias_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/bias_op.h new file mode 100644 index 00000000..d4a78804 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/bias_op.h @@ -0,0 +1,60 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BIAS_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BIAS_OP_H_ +// Functor definition for BiasOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by BiasOp to do the computations. +template +struct Bias { + // Add "bias" to "input", repeating "bias". + void operator()(const Device& d, typename TTypes::ConstFlat input, + typename TTypes::ConstVec bias, + typename TTypes::Flat output) { + const Eigen::Index rest_size = input.size() / bias.dimension(0); + Eigen::DSizes bcast(rest_size); + MaybeWith32BitIndexing( + [&](auto input32, auto bias32, auto output32, const auto& bcast32) { + output32.device(d) = input32 + bias32.broadcast(bcast32); + }, + input, bias, output, bcast); + } + + // NCHW layout, repeating on the first dimension, broadcasting on the last + // dimension. + void operator()(const Device& d, typename TTypes::ConstMatrix input, + typename TTypes::ConstMatrix bias1, // shape [C, 1]. 
+ typename TTypes::Matrix output) { + const Eigen::Index rest_size = input.dimension(0) / bias1.dimension(0); + Eigen::DSizes bcast(rest_size, input.dimension(1)); + MaybeWith32BitIndexing( + [&](auto input32, auto bias32, auto output32, const auto& bcast32) { + output32.device(d) = input32 + bias32.broadcast(bcast32); + }, + input, bias1, output, bcast); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BIAS_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/bias_op_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/bias_op_gpu.h new file mode 100644 index 00000000..0ece14a9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/bias_op_gpu.h @@ -0,0 +1,81 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BIAS_OP_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_BIAS_OP_GPU_H_ + +#define EIGEN_USE_GPU + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +template +struct BiasGPU { + static void compute(const GPUDevice& d, const T* input, const T* bias, + T* output, int32_t batch, int32_t height, int32_t width, + int32_t depth, int32_t channel, TensorFormat data_format); +}; + +template +struct BiasGradGPU { + static void compute(const GPUDevice& device, const T* output_backprop, + T* bias_backprop, int32_t batch, int32_t height, + int32_t width, int32_t depth, int32_t channel, + TensorFormat data_format); + + static void DoRowReduction(OpKernelContext* context, T* output, + const T* input, int rows, int cols); + + static void DoColReduction(OpKernelContext* context, T* output, + const T* input, int rows, int cols); +}; + +enum class BiasAddGradGPUMode { + kInvalid = 0, + kNative = 1, + kReduction = 2, +}; + +// Describe the BiasGradGPU result from a perf experiment. +// +// Arguments: +// algorithm: returns the method to use for bias add grad. +// elapsed_time; returns the measured elapsed time in microseconds. 
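// Plain-C++ sketch (no Eigen) of what the Bias functor above computes in the
// NHWC case: a bias vector of length C is broadcast across every other
// dimension and added element-wise.
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<float> BiasAddNHWC(const std::vector<float>& input,
                               const std::vector<float>& bias) {
  assert(!bias.empty() && input.size() % bias.size() == 0);
  std::vector<float> output(input.size());
  const std::size_t channels = bias.size();
  for (std::size_t i = 0; i < input.size(); ++i) {
    // The channel dimension is innermost, so it cycles fastest.
    output[i] = input[i] + bias[i % channels];
  }
  return output;
}

int main() {
  // Two "pixels" with three channels each, bias = {10, 20, 30}.
  const std::vector<float> out = BiasAddNHWC({1, 2, 3, 4, 5, 6}, {10, 20, 30});
  assert(out[0] == 11 && out[4] == 25 && out[5] == 36);
}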
+class BiasGradGPUProfileResult { + public: + bool is_valid() const { + return (algorithm_ != BiasAddGradGPUMode::kInvalid && + elapsed_time_ != std::numeric_limits::max()); + } + BiasAddGradGPUMode algorithm() const { return algorithm_; } + void set_algorithm(BiasAddGradGPUMode val) { algorithm_ = val; } + uint64 elapsed_time() const { return elapsed_time_; } + void set_elapsed_time(uint64 val) { elapsed_time_ = val; } + + private: + BiasAddGradGPUMode algorithm_ = BiasAddGradGPUMode::kInvalid; + uint64 elapsed_time_ = std::numeric_limits::max(); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BIAS_OP_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/bincount_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/bincount_op.h new file mode 100644 index 00000000..48847617 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/bincount_op.h @@ -0,0 +1,51 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace functor { + +template +struct BincountFunctor { + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& arr, + const typename TTypes::ConstTensor& weights, + typename TTypes::Tensor& output, + const Tidx num_bins); +}; + +template +struct BincountReduceFunctor { + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& in, + const typename TTypes::ConstTensor& weights, + typename TTypes::Tensor& out, + const Tidx num_bins); +}; + +} // end namespace functor + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/broadcast_to_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/broadcast_to_op.h new file mode 100644 index 00000000..083723e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/broadcast_to_op.h @@ -0,0 +1,91 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
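// Self-contained sketch of the computation behind the BincountFunctor
// declared above: output[v] accumulates the weight of every input element
// equal to v (or 1 when no weights are supplied). In this sketch, values
// outside [0, num_bins) are simply skipped.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<float> Bincount(const std::vector<int32_t>& arr,
                            const std::vector<float>& weights,
                            int32_t num_bins) {
  std::vector<float> output(num_bins, 0.0f);
  for (size_t i = 0; i < arr.size(); ++i) {
    const int32_t v = arr[i];
    if (v < 0 || v >= num_bins) continue;
    output[v] += weights.empty() ? 1.0f : weights[i];
  }
  return output;
}

int main() {
  // arr = {1, 1, 2}, 4 bins, no weights -> counts {0, 2, 1, 0}.
  const std::vector<float> counts = Bincount({1, 1, 2}, {}, 4);
  assert(counts[0] == 0 && counts[1] == 2 && counts[2] == 1 && counts[3] == 0);
}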
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +namespace functor { + +template +struct BroadcastTo { + template + void DoBCast( + const Device &device, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const typename Eigen::array &bcast) const { + MaybeWith32BitIndexing( + [&](auto out32, auto in32, const auto &bcast32) { + out32.device(device) = in32.broadcast(bcast32); + }, + out, in, bcast); + } + + template + void ReshapeAndBCast(const Device &device, Tensor &output_tensor, + const Tensor &input_tensor, const BCast &bcast) const { + DoBCast( + device, output_tensor.template shaped(bcast.result_shape()), + input_tensor.template shaped(bcast.x_reshape()), + BCast::ToIndexArrayType(bcast.x_bcast())); + } + + // PRECONDITION: rank(input_shape) > 0 && + // rank(input_shape) <= rank(output_shape) && + // output_shape.num_elements() > 0. + void operator()(const Device &device, OpKernelContext *ctx, + Tensor &output_tensor, const TensorShape &output_shape, + const Tensor &input_tensor, const TensorShape &input_shape, + const BCast &bcast) const { + const int ndims = bcast.y_reshape().size(); + switch (ndims) { + case 1: + ReshapeAndBCast<1>(device, output_tensor, input_tensor, bcast); + break; + case 2: + ReshapeAndBCast<2>(device, output_tensor, input_tensor, bcast); + break; + case 3: + ReshapeAndBCast<3>(device, output_tensor, input_tensor, bcast); + break; + case 4: + ReshapeAndBCast<4>(device, output_tensor, input_tensor, bcast); + break; + case 5: + ReshapeAndBCast<5>(device, output_tensor, input_tensor, bcast); + break; + default: + ctx->SetStatus(errors::Unimplemented( + "Broadcast between ", input_shape.DebugString(), " and ", + output_shape.DebugString(), " is not supported yet.")); + break; + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/bucketize_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/bucketize_op.h new file mode 100644 index 00000000..9fb59c77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/bucketize_op.h @@ -0,0 +1,41 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
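// Plain sketch of the broadcasting performed by the BroadcastTo functor
// above: an input of logical shape [3, 1] is tiled along its size-1 axis to
// produce an output of shape [3, 4] (bcast = {1, 4}).
#include <array>
#include <cassert>

int main() {
  const std::array<int, 3> input = {10, 20, 30};  // logical shape [3, 1]
  int output[3][4];                               // logical shape [3, 4]
  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < 4; ++c) {
      output[r][c] = input[r];  // repeat the size-1 column dimension
    }
  }
  assert(output[0][3] == 10 && output[2][0] == 30);
}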
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_ + +#include +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace functor { + +template +struct BucketizeFunctor { + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& input, + const std::vector& boundaries_vector, + typename TTypes::Tensor& output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cast_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cast_op.h new file mode 100644 index 00000000..0c955651 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cast_op.h @@ -0,0 +1,351 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_ +#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/byte_order.h" +#include "tensorflow/core/platform/types.h" + +// Note that the GPU cast functor templates need to be instantiated unlike the +// CPU ones, and hence their specializations are different than that for CPUs. 
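// Self-contained sketch of the bucketization rule behind the BucketizeFunctor
// declared above, mirroring the documented behavior of the Bucketize op:
// element i is assigned the number of boundaries that are <= input[i],
// i.e. an upper_bound search over the sorted boundaries.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int32_t> Bucketize(const std::vector<float>& input,
                               const std::vector<float>& boundaries) {
  std::vector<int32_t> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    output[i] = static_cast<int32_t>(
        std::upper_bound(boundaries.begin(), boundaries.end(), input[i]) -
        boundaries.begin());
  }
  return output;
}

int main() {
  // Boundaries {0, 10, 100}: -5 -> bucket 0, 7 -> 1, 10 -> 2, 250 -> 3.
  const std::vector<int32_t> out = Bucketize({-5, 7, 10, 250}, {0, 10, 100});
  assert(out[0] == 0 && out[1] == 1 && out[2] == 2 && out[3] == 3);
}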
+#ifdef SPECIALIZE_FOR_GPUS +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_TYPE) \ + template \ + struct CastFunctor { \ + void operator()(const Device& d, \ + typename TTypes::Flat out_tensor, \ + typename TTypes::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter()) \ + .template cast(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast(); \ + } \ + } \ + }; \ + template struct CastFunctor; +#else +#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_TYPE) \ + template <> \ + struct CastFunctor { \ + void operator()(const DEVICE& d, \ + typename TTypes::Flat out_tensor, \ + typename TTypes::ConstFlat in_tensor, \ + bool truncate = false) { \ + if (truncate) { \ + out_tensor.device(d) = \ + in_tensor.unaryExpr(LSBZeroSetter()) \ + .template cast(); \ + } else { \ + out_tensor.device(d) = in_tensor.template cast(); \ + } \ + } \ + }; +#endif + +#define CAST_FUNCTORS(devname) \ + SPECIALIZE_CAST(devname, float, double) \ + SPECIALIZE_CAST(devname, float, std::complex) \ + SPECIALIZE_CAST(devname, std::complex, std::complex) \ + SPECIALIZE_CAST(devname, std::complex, double) \ + SPECIALIZE_CAST(devname, Eigen::half, double) \ + SPECIALIZE_CAST(devname, Eigen::half, float) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex) \ + SPECIALIZE_CAST(devname, Eigen::half, std::complex) \ + SPECIALIZE_CAST(devname, bfloat16, float) \ + SPECIALIZE_CAST(devname, float8_e5m2, double) \ + SPECIALIZE_CAST(devname, float8_e5m2, float) \ + SPECIALIZE_CAST(devname, float8_e5m2, bfloat16) \ + SPECIALIZE_CAST(devname, float8_e5m2, Eigen::half) \ + SPECIALIZE_CAST(devname, float8_e5m2, float8_e4m3fn) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, double) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, float) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, bfloat16) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, Eigen::half) \ + template \ + struct CastFunctor { \ + void operator()(const devname& d, \ + typename TTypes::Flat out_tensor, \ + typename TTypes::ConstFlat in_tensor, \ + bool truncate = false) { \ + out_tensor.device(d) = in_tensor.template cast(); \ + } \ + }; + +#if defined(MLIR_GENERATED_GPU_KERNELS_ENABLED) +// If MLIR kernels are enabled, we don't need the specialized cast from float to +// double or from Eigen::half to double. We still need the specialized cast from +// Eigen::half to float, because it is used in depthwise_conv_grad_op.cc. We +// still need the specialized cast from float to double because it is used in +// resize_bilinear_op.cc. 
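// Standalone illustration of the "truncate" path in the cast functors above:
// LSBZeroSetter zeroes the low-order mantissa bits of the wider type before
// casting. For float -> bfloat16 that means keeping only the top 16 bits of
// the float representation (the real helper also leaves NaNs untouched).
#include <cassert>
#include <cstdint>
#include <cstring>

float ZeroLow16MantissaBits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits &= 0xFFFF0000u;  // keep sign, exponent, and the top 7 mantissa bits
  std::memcpy(&f, &bits, sizeof(bits));
  return f;
}

int main() {
  const float exact = 1.0078125f;     // 1 + 2^-7, representable in bfloat16
  const float inexact = 1.00390625f;  // 1 + 2^-8, needs an extra mantissa bit
  assert(ZeroLow16MantissaBits(exact) == exact);   // unchanged
  assert(ZeroLow16MantissaBits(inexact) == 1.0f);  // the 2^-8 bit is dropped
}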
+#define CAST_FUNCTORS_SUBSET(devname) \ + SPECIALIZE_CAST(devname, float, double) \ + SPECIALIZE_CAST(devname, Eigen::half, float) \ + SPECIALIZE_CAST(devname, bfloat16, float) \ + SPECIALIZE_CAST(devname, float8_e5m2, double) \ + SPECIALIZE_CAST(devname, float8_e5m2, float) \ + SPECIALIZE_CAST(devname, float8_e5m2, bfloat16) \ + SPECIALIZE_CAST(devname, float8_e5m2, Eigen::half) \ + SPECIALIZE_CAST(devname, float8_e5m2, float8_e4m3fn) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, double) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, float) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, bfloat16) \ + SPECIALIZE_CAST(devname, float8_e4m3fn, Eigen::half) \ + template \ + struct CastFunctor { \ + void operator()(const devname& d, \ + typename TTypes::Flat out_tensor, \ + typename TTypes::ConstFlat in_tensor, \ + bool truncate = false) { \ + out_tensor.device(d) = in_tensor.template cast(); \ + } \ + }; +#endif + +namespace tensorflow { + +typedef std::function + CastFunctorType; + +// Common base class of Cast kernels +class CastOpBase : public OpKernel { + public: + explicit CastOpBase(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + protected: + DataType src_dtype_; + DataType dst_dtype_; + DataType external_src_dtype_; + DataType external_dst_dtype_; + bool use_truncation_; + CastFunctorType work_ = nullptr; + absl::Status Unimplemented(); + + CastOpBase(const CastOpBase&) = delete; + void operator=(const CastOpBase&) = delete; +}; + +// CPU implementation of Cast +class CpuCastOp : public CastOpBase { + public: + explicit CpuCastOp(OpKernelConstruction* ctx); + + private: + absl::Status Prepare(); +}; + +namespace functor { + +template +constexpr int MantissaWidth() { + return std::numeric_limits::digits; +} + +template <> +constexpr int MantissaWidth() { + // Remember, there's 1 hidden bit + return 10 + 1; +} + +template <> +constexpr int MantissaWidth() { + // Remember, there's 1 hidden bit + return 7 + 1; +} + +template +void Cast(const Device& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i) { + o.device(d) = i.template cast(); +} + +template +struct CastFunctor { + void operator()(const Device& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i, bool truncate = false); +}; + +template +typename std::enable_if::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!Eigen::numext::isnan(t)) { + uint64_t* p = reinterpret_cast(&t); + *p &= (0xFFFFFFFFFFFFFFFF << n); + } +} + +template +typename std::enable_if::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!Eigen::numext::isnan(t)) { + uint32_t* p = reinterpret_cast(&t); + *p &= (0xFFFFFFFF << n); + } +} + +template +typename std::enable_if::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. + if (!Eigen::numext::isnan(t)) { + uint16_t* p = reinterpret_cast(&t); + *p &= (0xFFFF << n); + } +} + +template +typename std::enable_if::type EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) { + // Only zero the bits for non-NaNs. + // For NaNs, let the non-truncation version handle it. 
+ if (!Eigen::numext::isnan(t)) { + uint8_t* p = reinterpret_cast(&t); + *p &= (0xFF << n); + } +} + +// Set n least significant bits to 0 +template +struct LSBZeroSetter { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE I operator()(const I& a) const { + constexpr int bits = MantissaWidth() - MantissaWidth(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I t = a; + LSBZeroSetterHelper(t, bits); + return t; + } +}; + +template +struct LSBZeroSetter, std::complex> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator()( + const std::complex& a) const { + constexpr int bits = MantissaWidth() - MantissaWidth(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = Eigen::numext::real(a); + I img = Eigen::numext::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex toReturn(re, img); + return toReturn; + } +}; + +template +struct LSBZeroSetter, O> { + // Sets the 16 LSBits of the float to 0 + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator()( + const std::complex& a) const { + constexpr int bits = MantissaWidth() - MantissaWidth(); + static_assert( + bits > 0, + "The output type must have fewer mantissa bits than the input type\n"); + I re = Eigen::numext::real(a); + I img = Eigen::numext::imag(a); + LSBZeroSetterHelper(re, bits); + LSBZeroSetterHelper(img, bits); + std::complex toReturn(re, img); + return toReturn; + } +}; + +} // end namespace functor +} // end namespace tensorflow + +namespace Eigen { +namespace internal { + +// Eigen can't convert to/from complex numbers, because it is limited to cases +// that can be static_casted. But numpy is able to cast to/from complex, which +// we want to replicate. So we add specializations for complex here. 
+template +struct scalar_cast_op, To> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE To + operator()(const std::complex& a) const { + // Replicate numpy behavior of returning just the real part + return static_cast(a.real()); + } +}; + +template +struct scalar_cast_op, bool> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()( + const std::complex& a) const { + return static_cast(a.real()); + } +}; + +template +struct scalar_cast_op> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator()( + const From& a) const { + // Replicate numpy behavior of setting the imaginary part to 0 + return std::complex(static_cast(a), To(0)); + } +}; + +template +struct scalar_cast_op, std::complex> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::complex operator()( + const std::complex& a) const { + return std::complex(static_cast(a.real()), + static_cast(a.imag())); + } +}; + +template +struct functor_traits_complex_impl { + enum { Cost = NumTraits::AddCost, PacketAccess = false }; +}; + +template +struct functor_traits, bool>> + : functor_traits_complex_impl, bool> {}; + +template +struct functor_traits, To>> + : functor_traits_complex_impl, To> {}; +template +struct functor_traits>> + : functor_traits_complex_impl> {}; +// Needed to avoid ambiguous partial specialization +template +struct functor_traits, std::complex>> + : functor_traits_complex_impl, std::complex> {}; + +} // namespace internal +} // namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_CAST_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cast_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cast_op_impl.h new file mode 100644 index 00000000..6f0fe7eb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cast_op_impl.h @@ -0,0 +1,189 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ + +#include +#include + +#define EIGEN_USE_THREADS + +#include "absl/status/status.h" +#include "tensorflow/core/platform/errors.h" +#include "tsl/platform/status.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/cast_op.h" + +namespace tensorflow { + +namespace functor { + +template +struct OutOfRange { + bool operator()(const F f) const { + return f < std::numeric_limits::min() || + f > std::numeric_limits::max(); + } +}; + +#define VALIDATE_CAST(I, F) \ + template <> \ + struct CastFunctor { \ + void operator()(const Eigen::ThreadPoolDevice& d, \ + typename TTypes::Flat out_tensor, \ + typename TTypes::ConstFlat in_tensor, \ + bool truncate = false) const { \ + Eigen::Tensor out_of_range = \ + in_tensor.unaryExpr(OutOfRange{}).any(); \ + if (out_of_range()) { \ + LOG(ERROR) \ + << "IMPORTANT! The input tensor to Cast contains values out of " \ + "range for the target type. This is undefined behavior and " \ + "likely a bug in your model. 
A crash immediately after this " \ + "under ubsan is expected."; \ + } \ + out_tensor.device(d) = in_tensor.template cast(); \ + } \ + }; + +// Add additional logging for out of range inputs when running in debug mode. +#ifndef NDEBUG +VALIDATE_CAST(int32, float); +VALIDATE_CAST(int64, float); +VALIDATE_CAST(int32, double); +VALIDATE_CAST(int64, double); +#endif + +CAST_FUNCTORS(Eigen::ThreadPoolDevice); + + +} // namespace functor + +#define CURRY_TYPES3(FN, arg0, arg1) \ + FN(arg0, arg1, bool); \ + FN(arg0, arg1, uint8); \ + FN(arg0, arg1, uint16); \ + FN(arg0, arg1, uint32); \ + FN(arg0, arg1, uint64); \ + FN(arg0, arg1, int8); \ + FN(arg0, arg1, int16); \ + FN(arg0, arg1, int32); \ + FN(arg0, arg1, int64_t); \ + FN(arg0, arg1, float); \ + FN(arg0, arg1, double); \ + FN(arg0, arg1, std::complex); \ + FN(arg0, arg1, std::complex) \ + FN(arg0, arg1, Eigen::half); \ + FN(arg0, arg1, bfloat16); + +#define CAST_CASE(DEVICE, IN, OUT) \ + if (DataTypeToEnum::value == dst_dtype) { \ + return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \ + bool truncate) { \ + functor::CastFunctor func; \ + func(ctx->eigen_device(), out->flat(), inp.flat(), \ + truncate); \ + }; \ + } + +// The functions below are implemented in the cast_op_impl_*.cc files. +CastFunctorType GetCpuCastFromBool(DataType dst_dtype); + +CastFunctorType GetCpuCastFromUint8(DataType dst_dtype); + +CastFunctorType GetCpuCastFromUint16(DataType dst_dtype); + +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); + +CastFunctorType GetCpuCastFromUint32(DataType dst_dtype); + +CastFunctorType GetCpuCastFromUint64(DataType dst_dtype); + +CastFunctorType GetCpuCastFromInt8(DataType dst_dtype); + +CastFunctorType GetCpuCastFromInt16(DataType dst_dtype); + +CastFunctorType GetCpuCastFromInt32(DataType dst_dtype); + +CastFunctorType GetCpuCastFromInt64(DataType dst_dtype); + +CastFunctorType GetCpuCastFromHalf(DataType dst_dtype); + +CastFunctorType GetCpuCastFromFloat(DataType dst_dtype); + +CastFunctorType GetCpuCastFromDouble(DataType dst_dtype); + +CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype); + +CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype); + +CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype); + +CastFunctorType GetCpuCastFromFloat8e5m2(DataType dst_dtype); + +CastFunctorType GetCpuCastFromFloat8e4m3fn(DataType dst_dtype); + +CastFunctorType GetCpuCastFromInt4(DataType dst_dtype); + +CastFunctorType GetCpuCastFromUint4(DataType dst_dtype); + +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +// Same, for GPU. 
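// Standalone sketch of the range check that the VALIDATE_CAST
// specializations above add in debug builds: a float -> int32 cast is
// undefined behavior when the value cannot be represented, so the input is
// tested against the target type's limits first.
#include <cassert>
#include <cstdint>
#include <limits>

bool OutOfInt32Range(float f) {
  return f < static_cast<float>(std::numeric_limits<int32_t>::min()) ||
         f > static_cast<float>(std::numeric_limits<int32_t>::max());
}

int main() {
  assert(!OutOfInt32Range(123.0f));
  assert(OutOfInt32Range(3.0e9f));   // above INT32_MAX
  assert(OutOfInt32Range(-3.0e9f));  // below INT32_MIN
}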
+CastFunctorType GetGpuCastFromBool(DataType dst_dtype); + +CastFunctorType GetGpuCastFromUint8(DataType dst_dtype); + +CastFunctorType GetGpuCastFromUint16(DataType dst_dtype); + +CastFunctorType GetGpuCastFromInt8(DataType dst_dtype); + +CastFunctorType GetGpuCastFromUint32(DataType dst_dtype); + +CastFunctorType GetGpuCastFromUint64(DataType dst_dtype); + +CastFunctorType GetGpuCastFromInt16(DataType dst_dtype); + +CastFunctorType GetGpuCastFromInt32(DataType dst_dtype); + +CastFunctorType GetGpuCastFromInt64(DataType dst_dtype); + +CastFunctorType GetGpuCastFromHalf(DataType dst_dtype); + +CastFunctorType GetGpuCastFromFloat(DataType dst_dtype); + +CastFunctorType GetGpuCastFromDouble(DataType dst_dtype); + +CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype); + +CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype); + +CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype); + +CastFunctorType GetGpuCastFromFloat8e5m2(DataType dst_dtype); + +CastFunctorType GetGpuCastFromFloat8e4m3fn(DataType dst_dtype); + +CastFunctorType GetGpuCastFromInt4(DataType dst_dtype); + +CastFunctorType GetGpuCastFromUint4(DataType dst_dtype); + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/checkpoint_callback_manager.h b/third_party/tflite-hdrs/tensorflow/core/kernels/checkpoint_callback_manager.h new file mode 100644 index 00000000..7e0d9d8f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/checkpoint_callback_manager.h @@ -0,0 +1,113 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0(the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_CHECKPOINT_CALLBACK_MANAGER_H_ +#define TENSORFLOW_CORE_KERNELS_CHECKPOINT_CALLBACK_MANAGER_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace checkpoint { + +ABSL_CONST_INIT extern const absl::string_view + kCheckpointCallbackManagerResourceName; + +// StatusOr save_callback(absl::string_view checkpoint_id); +using SaveCallback = + std::function(absl::string_view)>; + +// Status restore_callback(absl::string_view checkpoint_id, +// absl::string_view content_from_checkpoint); +using RestoreCallback = + std::function; + +// A class to save and restore additional information for checkpointing. 
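// Hypothetical usage sketch for the CheckpointCallbackManager declared just
// below: register a save callback whose returned string is written alongside
// the checkpoint under the registered file extension. The extension name and
// callback body are invented for illustration.
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/kernels/checkpoint_callback_manager.h"

absl::Status RegisterMetricsCallback(
    tensorflow::checkpoint::CheckpointCallbackManager& manager) {
  return manager.RegisterSaveCallback(
      /*file_extension=*/"metrics",
      [](absl::string_view checkpoint_id) -> absl::StatusOr<std::string> {
        // The returned content is stored as a small side file associated with
        // the checkpoint identified by `checkpoint_id`.
        return std::string("checkpoint=") + std::string(checkpoint_id);
      });
}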
+class CheckpointCallbackManager : public ResourceBase { + public: + CheckpointCallbackManager() = default; + + // Not copyable or movable + CheckpointCallbackManager(const CheckpointCallbackManager&) = delete; + CheckpointCallbackManager& operator=(const CheckpointCallbackManager&) = + delete; + + std::string DebugString() const override { + return "CheckpointCallbackManager"; + } + + // Infers a checkpoint id and directory from a prefix + // passed to SaveV2 / RestoreV2 Ops + static absl::StatusOr> + GetCheckpointIdAndPathFromPrefix(absl::string_view prefix); + + // Register a save callback. + // The passed callback will be triggered with an identified checkpoint id. + // The callback should return a string content needs to be stored + // as a part of a checkpoint, and then the content is stored as a file + // with the registered the file_extension. + absl::Status RegisterSaveCallback(absl::string_view file_extension, + SaveCallback callback); + + // Checks if a registered save callback exists for an extension. + bool DoesSaveCallbackExist(absl::string_view file_extension); + + // Register a restore callback. + // The passed file_extension is used to generate a file name together with + // an identified checkpoint_id. If the file exists, the registered callback + // is triggered with the content of the file. + absl::Status RegisterRestoreCallback(absl::string_view file_extension, + RestoreCallback callback); + + // Checks if a registered restore callback exists for an extension. + bool DoesRestoreCallbackExist(absl::string_view file_extension); + + // Should be triggered from SaveV2()::Compute(). + void Save(absl::string_view prefix); + + // Should be triggered from RestoreV2()::Compute(). + void Restore(absl::string_view prefix); + + private: + mutable mutex mu_; + + absl::flat_hash_map save_callbacks_ + TF_GUARDED_BY(mu_); + absl::flat_hash_map restore_callbacks_ + TF_GUARDED_BY(mu_); + + // Checkpoint save and restore could happen before save / restore callbacks + // are registered. The last checkpoint information is kept in these variables + // to trigger the registered callback lazily. + std::pair last_restored_checkpoint_id_and_dir_ + TF_GUARDED_BY(mu_); + + std::pair last_saved_checkpoint_id_and_dir_ + TF_GUARDED_BY(mu_); +}; + +} // namespace checkpoint +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CHECKPOINT_CALLBACK_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl.h new file mode 100644 index 00000000..4fc4bebb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl.h @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_H_ +#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_H_ + +#include "tensorflow/core/framework/collective.h" + +namespace tensorflow { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +class NcclBase : public CollectiveImplementationInterface { + public: + explicit NcclBase(CollectiveType type, const string& name); + ~NcclBase() override = default; + + // No-op for this collective implementation. + Status InitializeCollectiveParams(CollectiveParams* col_params) override; + + // Initializes the device objects and device localities. + Status InitializeCollectiveContext( + std::shared_ptr col_ctx) override; + + protected: + const CollectiveType type_; + const string name_; + std::shared_ptr col_ctx_; + const CollectiveParams* col_params_; // Not owned +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_all_to_all.h b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_all_to_all.h new file mode 100644 index 00000000..4ba624c9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_all_to_all.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_ALL_TO_ALL_H_ +#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_ALL_TO_ALL_H_ + +#include "tensorflow/core/kernels/collective_nccl.h" + +namespace tensorflow { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +class NcclAllToAll : public NcclBase { + public: + NcclAllToAll() : NcclBase(ALL_TO_ALL_COLLECTIVE, "NcclAllToAll") {} + ~NcclAllToAll() override = default; + + // Hands off all-to-all to NcclManager. + void Run(StatusCallback done) override; +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_ALL_TO_ALL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_broadcaster.h b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_broadcaster.h new file mode 100644 index 00000000..9c1f6f4a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_broadcaster.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_ +#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_ + +#include "tensorflow/core/kernels/collective_nccl.h" + +namespace tensorflow { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +class NcclBroadcaster : public NcclBase { + public: + NcclBroadcaster() : NcclBase(BROADCAST_COLLECTIVE, "NcclBroadcast") {} + ~NcclBroadcaster() override = default; + + // Hands off broadcast to NcclManager. + void Run(StatusCallback done) override; +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_BROADCASTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_gatherer.h b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_gatherer.h new file mode 100644 index 00000000..97d41f77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_gatherer.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_GATHERER_H_ +#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_GATHERER_H_ + +#include "tensorflow/core/kernels/collective_nccl.h" + +namespace tensorflow { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +class NcclGatherer : public NcclBase { + public: + NcclGatherer() : NcclBase(GATHER_COLLECTIVE, "NcclGather") {} + ~NcclGatherer() override = default; + + // Hands off all-gather to NcclManager. + void Run(StatusCallback done) override; +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_GATHERER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_reducer.h b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_reducer.h new file mode 100644 index 00000000..b95d5720 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/collective_nccl_reducer.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_REDUCER_H_ +#define TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_REDUCER_H_ + +#include "tensorflow/core/kernels/collective_nccl.h" + +namespace tensorflow { +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +class NcclReducer : public NcclBase { + public: + NcclReducer() : NcclBase(REDUCTION_COLLECTIVE, "NcclReduce") {} + NcclReducer(CollectiveType type, const string& name) : NcclBase(type, name) {} + ~NcclReducer() override = default; + + // Hands off all reduce to NcclManager. + void Run(StatusCallback done) override; +}; + +class NcclReduceScatterer : public NcclReducer { + public: + NcclReduceScatterer() + : NcclReducer(REDUCE_SCATTER_COLLECTIVE, "NcclReduceScatter") {} + ~NcclReduceScatterer() override = default; + // Uses same Run() as NcclReducer. +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_COLLECTIVE_NCCL_REDUCER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/composite_tensor_variant.h b/third_party/tflite-hdrs/tensorflow/core/kernels/composite_tensor_variant.h new file mode 100644 index 00000000..fa98f795 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/composite_tensor_variant.h @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_ +#define TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_tensor_data.h" + +namespace tensorflow { + +class CompositeTensorVariantMetadata; + +// Encoding for a `tf.ExtensionType` value, that can be saved as a Variant. +// +// `tf.ExtensionType` (also known as `CompositeTensor`) is a Python base class +// used to Python types that are supported by TensorFlow APIs. Example +// ExtensionTypes include `tf.RaggedTensor` and `tf.SparseTensor`. +// +// `CompositeTensorVariant` decomposes the `ExtensionType` value into two +// parts: +// +// * `components`: A list of Tensors, which encodes the value's dynamic +// data -- i.e., data that may change for different executions of a graph. +// * `metadata`: A serialized TypeSpec, which encodes the value's +// static data -- i.e., data that is the same for all executions of a graph. +// +// CompositeTensorVariant can be stored in a Tensor with dtype=DT_VARIANT. +// Typically, extension type values are encoded with a scalar tensor containing +// a single CompositeTensorVariant value. 
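Editor's illustration: a minimal sketch of the encoding described above, wrapping an extension-type value's components and metadata into a scalar DT_VARIANT tensor via the class declared just below. Constructing CompositeTensorVariantMetadata is omitted (the type is only forward-declared here), the function name is hypothetical, and includes (e.g. the Variant header) are omitted.

// Illustration only, not part of the vendored header.
tensorflow::Tensor WrapAsCompositeVariant(
    const tensorflow::CompositeTensorVariantMetadata& metadata,
    absl::Span<tensorflow::Tensor> flat_components) {
  tensorflow::CompositeTensorVariant value(metadata, flat_components);
  tensorflow::Tensor result(tensorflow::DT_VARIANT,
                            tensorflow::TensorShape({}));
  result.scalar<tensorflow::Variant>()() = value;  // copies the wrapper
  return result;
}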
+class CompositeTensorVariant { + public: + CompositeTensorVariant(const CompositeTensorVariantMetadata& metadata, + absl::Span flat_components); + + CompositeTensorVariant(); + CompositeTensorVariant(const CompositeTensorVariant& other); + CompositeTensorVariant& operator=(CompositeTensorVariant&& other) = default; + CompositeTensorVariant& operator=(const CompositeTensorVariant& other) = + delete; + + // Returns the list of Tensor components that encode this value's dynamic + // data. + absl::Span flat_components() const { + return absl::MakeConstSpan(flat_components_); + } + + // Returns the serialized TypeSpec that encodes the value's static data. + const CompositeTensorVariantMetadata& metadata() const { return *metadata_; } + + // Variant methods. + string TypeName() const { return kTypeName; } + + // Updates `VariantTensorData` with an encoding for this value. + void Encode(VariantTensorData* data) const; + + // Updates this value to match the encoding in a given `VariantTensorData`. + bool Decode(const VariantTensorData& data); + + // Returns a string summary for this value. + string DebugString() const; + + // Name of this type (used for variant serialization). + static constexpr const char kTypeName[] = "CompositeTensorVariant"; + + private: + // Tensor components for this value. + std::vector flat_components_; + + // TypeSpec for this value. CompositeTensorVariantMetadata is a thin wrapper + // around a TypeSpecProto, which is used to retain flexibility to change the + // variant encoding. + // + // Note: we use a unique_ptr, because header files in the kernels/ directory + // are not allowed to import .pb.h files. + std::unique_ptr metadata_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_COMPOSITE_TENSOR_VARIANT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib.h b/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib.h new file mode 100644 index 00000000..ca30908c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib.h @@ -0,0 +1,75 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_ +#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +// Functors to concatenate tensors. These always take a rank-2 tensor (i.e a +// matrix) and concatenate it along the axis 1 ("putting them next to each +// other" as opposed to "putting them on top of one another"). +// +// Any concatenation of n-dimensional tensors across any axis can be reduced to +// a concatenation of two-dimensional tensors across the axis 1 by first +// partitioning the axes of the original tensors into those less than the axis +// to be concatenated across and the rest. 
Then reshape the tensors into a +// two-dimensional tensor by collapsing these two sets of axes and concatenate +// the resulting matrices across the axis 1, finally reshaping the result to +// have the proper shape. +// +// So, for example, when stacking N tensors, reshape each to have shape +// {1, Numelements} and reshape the result matrix to have shape +// {1, N * NumElements} before passing it to this functor. + +// Assumes all elements of inputs are nonempty. +// Assumes output is nonempty. +template +void ConcatCPU( + DeviceBase* d, + const std::vector::ConstMatrix>>& + inputs, + typename TTypes::Matrix* output); +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +template +void ConcatGPU( + OpKernelContext* c, + const std::vector::ConstMatrix>>& + inputs_flat, + Tensor* output, typename TTypes::Tensor* output_flat); + +// Explicit instantiations in concat_lib_gpu.cc. +#define REGISTER(T) \ + extern template void ConcatGPU( \ + OpKernelContext * c, \ + const std::vector::ConstMatrix>>& \ + inputs_flat, \ + Tensor* output, typename TTypes::Tensor* output_flat); + +TF_CALL_INTEGRAL_TYPES(REGISTER); // int32 Needed for TensorLists. +TF_CALL_GPU_ALL_TYPES(REGISTER); +#undef REGISTER +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib_cpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib_cpu.h new file mode 100644 index 00000000..45960772 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib_cpu.h @@ -0,0 +1,135 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_ +#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_ + +#define EIGEN_USE_THREADS + +#include +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +// ElementCopier must be a struct with a single Copy function, which is passed +// the output pointer, input pointer, input index, and number of elements to +// copy from input to output. +template +void ConcatCPUImpl( + DeviceBase* d, + const std::vector::ConstMatrix>>& + inputs, + int64_t cost_per_unit, ElementCopier copier, + typename TTypes::Matrix* output) { + size_t num_inputs = inputs.size(); + + std::vector sizes; + sizes.reserve(num_inputs); + int64_t row_size = 0; + for (const auto& input : inputs) { + sizes.push_back(input->dimension(1)); + row_size += sizes.back(); + } + + // cost_per_unit is estimated bytes to copy per output array element (for + // strings this includes an estimate of the number of bytes of the actual + // string data, as well). 
+ const int64_t estimated_total_cost = output->size() * cost_per_unit; + + auto worker_threads = d->tensorflow_cpu_worker_threads(); + int num_threads = std::min(4, worker_threads->num_threads); + num_threads = static_cast( + std::min(num_threads, estimated_total_cost / 16384)); + // Single threaded mode. + // TODO(dga): Deduplicate this code w.r.t. sharded code below. + if (num_threads == 0) { + T* out = &(*output)(0, 0); + std::vector inp; + inp.reserve(num_inputs); + for (const auto& input : inputs) { + inp.push_back(&(*input)(0, 0)); + } + const int64_t dim0 = output->dimension(0); + for (int64_t i = 0; i < dim0; ++i) { + for (int64_t j = 0; j < num_inputs; ++j) { + auto size = sizes[j]; + copier.Copy(out, inp[j], j, size); + out += size; + inp[j] += size; + } + } + return; + } + + // Sharded mode. + auto work = [&row_size, &sizes, &inputs, &output, &copier, &num_inputs]( + int64_t start, int64_t end) { + int64_t skipped_rows = start / row_size; + T* out = output->data() + skipped_rows * row_size; + T* out_start = output->data() + start; + T* out_end = output->data() + end; + + // Handle partial row at start + if (out < out_start) { + for (size_t j = 0; j < num_inputs; ++j) { + ptrdiff_t size = sizes[j]; + ptrdiff_t offset = out_start - out; + if (size <= offset) { + out += size; + continue; + } + const T* inp = &(*inputs[j])(skipped_rows, 0); + if (offset > 0) { + out += offset; + inp += offset; + size -= offset; + } + size = std::min(size, out_end - out); + if (size <= 0) break; + copier.Copy(out, inp, j, size); + out += size; + } + ++skipped_rows; + } + if (out == out_end) return; + CHECK(out >= out_start); + CHECK(out < out_end); + + // Copy remaining data. + std::vector inp; + inp.reserve(num_inputs); + for (const auto& input : inputs) { + inp.push_back(&(*input)(skipped_rows, 0)); + } + const int64_t dim0 = output->dimension(0); + for (int64_t i = skipped_rows; i < dim0; ++i) { + for (int64_t j = 0; j < num_inputs; ++j) { + ptrdiff_t size = std::min(sizes[j], out_end - out); + copier.Copy(out, inp[j], j, size); + out += size; + inp[j] += size; + if (out == out_end) return; + } + } + }; + Shard(worker_threads->num_threads, worker_threads->workers, output->size(), + cost_per_unit, work); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib_gpu.h new file mode 100644 index 00000000..8e42cc1c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/concat_lib_gpu.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
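Editor's illustration: ConcatCPUImpl above documents the ElementCopier contract only in a comment (a single Copy function taking the output pointer, input pointer, input index, and element count). A sketch of a copier satisfying that contract for trivially copyable element types follows; it is not TensorFlow code.

#include <cstring>

// Illustration only: matches the Copy(output pointer, input pointer,
// input index, element count) shape described above.
struct TrivialCopier {
  template <typename T>
  void Copy(T* dst, const T* src, int input_index, size_t n) {
    (void)input_index;  // unused here; a string copier might need it
    std::memcpy(dst, src, n * sizeof(T));
  }
};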
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_ + +#define EIGEN_USE_THREADS +#define EIGEN_USE_GPU + +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/kernels/gpu_device_array_gpu.h" + +namespace tensorflow { + +template +void ConcatGPUSlice( + const Eigen::GpuDevice& gpu_device, + const std::vector::ConstMatrix>>& + inputs_flat, + typename TTypes::Matrix* output); + +template +void ConcatGPUImpl(const Eigen::GpuDevice& d, + const GpuDeviceArrayStruct& input_ptrs, + const GpuDeviceArrayStruct& ptr_offsets, + bool same_size, int slice_size, + typename TTypes::Matrix* output); + +// Explicit instantiations in concat_lib_gpu_impl.cu.cc. +#define REGISTER(T) \ + extern template void ConcatGPUSlice( \ + const Eigen::GpuDevice& gpu_device, \ + const std::vector::ConstMatrix>>& \ + inputs_flat, \ + typename TTypes::Matrix* output); \ + extern template void ConcatGPUSlice( \ + const Eigen::GpuDevice& gpu_device, \ + const std::vector::ConstMatrix>>& \ + inputs_flat, \ + typename TTypes::Matrix* output); \ + extern template void ConcatGPUImpl( \ + const Eigen::GpuDevice& d, \ + const GpuDeviceArrayStruct& input_ptrs, \ + const GpuDeviceArrayStruct& ptr_offsets, bool fixed_size, \ + int split_size, typename TTypes::Matrix* output); \ + extern template void ConcatGPUImpl( \ + const Eigen::GpuDevice& d, \ + const GpuDeviceArrayStruct& input_ptrs, \ + const GpuDeviceArrayStruct& ptr_offsets, bool fixed_size, \ + int split_size, typename TTypes::Matrix* output); + +TF_CALL_INTEGRAL_TYPES(REGISTER); // int32 Needed for TensorLists. +TF_CALL_GPU_ALL_TYPES(REGISTER); +#undef REGISTER + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator.h new file mode 100644 index 00000000..d2578a55 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator.h @@ -0,0 +1,136 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_ +#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_ + +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h" + +namespace tensorflow { + +/** + * An aggregation object for adding dense gradients. + * + * The two main methods of this class are TryApplyGrad and TryTakeGrad. + * + * TryApplyGrad tries add a gradient to the accumulator. 
The attempt is + * successful if local_step >= global_step, i.e., if the gradient is not stale, + * having been computed using up-to-date information. Otherwise, the gradient is + * silently dropped. + * + * TryTakeGrad logs an attempt to read the average gradient. The attempt is + * blocked until the number of gradients accumulated (via TryApplyGrad) is equal + * or exceeds the number requested by TryTakeGrad. + * Once this condition is satisfied, the following actions are taken: + * (1) the value of the average gradient is returned + * (2) the count of accumulated gradients is reset to 0 + * (3) the internal global_step value (current_global_step_) is incremented by 1 + * + * ConditionalAccumulator is the datatype-dependent templated sub-class of + * ConditionalAccumulatorBase. It implements the virtual arithmetic methods that + * are used by for aggregating, averaging, allocating, returning dense Tensors. + */ +template +class ConditionalAccumulator + : public TypedConditionalAccumulatorBase { + public: + // Args: + // dtype: The datatype of the gradients to be accumulated. + // shape: The shape of the accumulated gradients. + // name: A name to use for the ConditionalAccumulator. + // reduction_type: The reduction type, i.e., MEAN or SUM + ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, + const string& name, const string& reduction_type) + : TypedConditionalAccumulatorBase(dtype, shape, name, + reduction_type) {} + ~ConditionalAccumulator() override{}; + + protected: + // accum_grad is the tensor that holds the aggregate gradient. + // It is initialized the first time ApplyGrad is called. + Tensor accum_grad_; + + functor::SetZeroFunctor set_zero_functor_; + + absl::Status ValidateShape(const Tensor* tensor) + TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { + // Must be compatible with accumulated gradient if available + if (counter_ > 0) { + if (!accum_grad_.shape().IsSameSize(tensor->shape())) { + return errors::InvalidArgument("Shape mismatch: expected ", + accum_grad_.shape().DebugString(), + ", got ", tensor->shape().DebugString()); + } + } + // Must also be compatible with given shape + if (!shape_.IsCompatibleWith(tensor->shape())) { + return errors::InvalidArgument("Shape mismatch: expected ", + shape_.DebugString(), ", got ", + tensor->shape().DebugString()); + } + return absl::OkStatus(); + } + + void AllocateAndAssignToAccumGradFunction(OpKernelContext* ctx, + const Tensor* grad) override { + // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! 
+ ctx->allocate_temp(dtype_, grad->shape(), &accum_grad_).IgnoreError(); + accum_grad_.flat().device(ctx->template eigen_device()) = + grad->flat(); + } + + void AddToAccumGradFunction(OpKernelContext* ctx, + const Tensor* grad) override { + accum_grad_.flat().device(ctx->template eigen_device()) += + grad->flat(); + } + + void DivideAccumGradByCounter(OpKernelContext* ctx) override + TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { + Tensor c(DataTypeToEnum::value, {}); + c.scalar()() = TypeConverter::ConvertUToT(this->counter_); + this->accum_grad_.template flat().device( + ctx->template eigen_device()) = + this->accum_grad_.template flat() / c.scalar()(); + } + + bool SetOutput(OpKernelContext* ctx) override { + ctx->set_output(0, accum_grad_); + return true; + } + + bool GetAndValidateTensorInputForApplyGrad(OpKernelContext* ctx, + const Tensor** tensor) override + TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { + // Get input gradient tensor + const Tensor* grad_tensor; + OP_REQUIRES_OK_BOOLEAN(ctx, ctx->input("gradient", &grad_tensor)); + *tensor = grad_tensor; + OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor)); + return true; + } + + void CleanUpGradTensor(const Tensor* tensor) override { + // do nothing + } + + ConditionalAccumulator(const ConditionalAccumulator&) = delete; + void operator=(const ConditionalAccumulator&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator_base.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator_base.h new file mode 100644 index 00000000..683e667e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator_base.h @@ -0,0 +1,201 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ +#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_op.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +/** + * ConditionalAccumulator/ConditionalAccumulatorBase implements an aggregation + * object for adding gradients. + * The two main methods of this class are TryApplyGrad and TryTakeGrad. + * + * TryApplyGrad tries add a gradient to the accumulator. The attempt is + * successful if local_step >= global_step, i.e., if the gradient is not stale, + * having been computed using up-to-date information. Otherwise, the gradient is + * silently dropped. + * + * TryTakeGrad logs an attempt to read the average gradient. The attempt is + * blocked until the number of gradients accumulated (via TryApplyGrad) is equal + * or exceeds the number requested by TryTakeGrad. 
+ * Once this condition is satisfied, the following actions are taken: + * (1) the value of the average gradient is returned + * (2) the count of accumulated gradients is reset to 0 + * (3) the internal global_step value (current_global_step_) is incremented by 1 + */ +class ConditionalAccumulatorBase : public ResourceBase { + public: + // Args: + // dtype: The datatype of the gradients to be accumulated. + // shape: The shape of the accumulated gradients. + // name: A name to use for the ConditionalAccumulator. + ConditionalAccumulatorBase(const DataType& dtype, + const PartialTensorShape& shape, + const string& name, const string& reduction_type); + + typedef AsyncOpKernel::DoneCallback DoneCallback; + + virtual void TryApplyGrad(int64_t local_step, OpKernelContext* ctx) = 0; + void TryTakeGrad(int num_required, OpKernelContext* ctx, + DoneCallback callback); + + // Accessor methods + uint32 num_accumulated() { + mutex_lock lock(mu_); + return counter_; + } + + const DataType& dtype() const { return dtype_; } + + string DebugString() const override { return "A conditional accumulator"; } + + // SetGlobalStep is a modifier method for current_global_step. + // It returns an InvalidArgument error if the new_global_step is less than + // current_global_step. + absl::Status SetGlobalStep(int64_t new_global_step); + + absl::Status MatchesNodeDef(const NodeDef& node_def); + + protected: + // Virtual methods to be implemented by sub-classes for different datatypes. + // Implements arithmetic operations specific to datatype. + virtual void DivideAccumGradByCounter(OpKernelContext* ctx) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; + virtual bool SetOutput(OpKernelContext* ctx) = 0; + + enum RunResult { kNoProgress, kComplete }; + + // Helper struct holding information about a TakeGrad attempt + struct Attempt; + typedef std::function RunCallback; + struct Attempt { + int elements_requested; + DoneCallback done_callback; // must be run outside mu_ + OpKernelContext* context; + CancellationManager* cancellation_manager; // not owned + CancellationToken cancellation_token; + RunCallback run_callback; // must be run while holding mu_ + bool is_cancelled; + + Attempt(int elements_requested, DoneCallback done_callback, + OpKernelContext* context, CancellationManager* cancellation_manager, + CancellationToken cancellation_token, RunCallback run_callback) + : elements_requested(elements_requested), + done_callback(std::move(done_callback)), + context(context), + cancellation_manager(cancellation_manager), + cancellation_token(cancellation_token), + run_callback(std::move(run_callback)), + is_cancelled(false) {} + }; + + // Helper struct for deregistration of a cancellation token and executing a + // DoneCallback after a TakeGrad attempt is complete. + struct CleanUp { + CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) + : finished(f), to_deregister(ct), cm(cm) {} + DoneCallback finished; + CancellationToken to_deregister; + CancellationManager* cm; + }; + + // Fields + + const DataType dtype_; + const PartialTensorShape shape_; + const string name_; + const string reduction_type_; + mutex mu_; + int counter_ TF_GUARDED_BY(mu_); + int64_t current_global_step_ TF_GUARDED_BY(mu_); + + std::deque takegrad_attempts_ TF_GUARDED_BY(mu_); + + // Methods + + // Helper function for creating cancellation callback + void Cancel(CancellationManager* cancellation_manager, + CancellationToken token); + + // Helper functions to process TakeGrad attempts. 
+  // FlushUnlocked is called at the end of each TryApplyGrad and TryTakeGrad
+  // call to try to clear the TakeGrad attempts. This in turn calls
+  // TryAttemptLocked, which then executes the RunCallback of the logged
+  // attempts.
+  // Both functions are modeled after core/kernels/queue_base.
+  // Note: ApplyGrad attempts never block -- unlike in a queue with limited
+  // capacity, we can always add the newest gradient to our accumulator
+  // (if it is not stale) or drop it silently (if it is stale).
+  void FlushUnlocked();
+  bool TryAttemptLocked(std::vector<CleanUp>* clean_up)
+      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Helper methods
+  // void DeepCopy(Tensor* dst);
+  bool TakeGradLockedHelper(OpKernelContext* ctx, DoneCallback callback)
+      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+};
+
+/*
+ * Modifications to convenience macros defined in core/framework/op_kernel.h.
+ * The below macros return a boolean if the test fails, so that the calling
+ * function can get an indication that a failure has occurred.
+ */
+#define OP_REQUIRES_BOOLEAN(CTX, EXP, STATUS)          \
+  do {                                                 \
+    if (!TF_PREDICT_TRUE(EXP)) {                       \
+      (CTX)->CtxFailure(__FILE__, __LINE__, (STATUS)); \
+      return false;                                    \
+    }                                                  \
+  } while (0)
+
+#define OP_REQUIRES_OK_BOOLEAN(CTX, STATUS)                 \
+  do {                                                      \
+    ::tensorflow::Status _s(STATUS);                        \
+    if (!TF_PREDICT_TRUE(_s.ok())) {                        \
+      (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+      return false;                                         \
+    }                                                       \
+  } while (0)
+
+/*
+ * Convenience classes for helping to convert between numeric types.
+ * The specialization for Eigen::half here simplifies specialization of
+ * ConditionalAccumulator classes later.
+ */
+template <typename T, typename U>
+class TypeConverter {
+ public:
+  static T ConvertUToT(U c) { return c; /* implicit conversion */ }
+};
+
+template <typename U>
+class TypeConverter<Eigen::half, U> {
+ public:
+  static Eigen::half ConvertUToT(U c) { return static_cast<Eigen::half>(c); }
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator_base_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator_base_op.h
new file mode 100644
index 00000000..c0d1c9a6
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -0,0 +1,262 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
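Editor's illustration: to make the accumulate/take protocol described in the ConditionalAccumulatorBase comments concrete, here is a deliberately simplified, single-threaded sketch. It is not TensorFlow code; the real implementation is asynchronous, lock-protected, and operates on Tensors rather than doubles.

#include <cstdint>

// Illustration only: the staleness rule and the three actions taken when a
// TakeGrad request is satisfied (return average, reset count, bump step).
struct ToyAccumulator {
  int64_t global_step = 0;
  int count = 0;
  double sum = 0.0;

  void ApplyGrad(int64_t local_step, double grad) {
    if (local_step >= global_step) {  // stale gradients are dropped silently
      sum += grad;
      ++count;
    }
  }

  // Returns true and outputs the average once num_required gradients have
  // been accumulated; the real op blocks asynchronously instead of polling.
  bool TryTakeGrad(int num_required, double* avg) {
    if (count < num_required) return false;
    *avg = sum / count;  // (1) return the average gradient
    count = 0;           // (2) reset the count of accumulated gradients
    sum = 0.0;
    ++global_step;       // (3) increment the internal global step
    return true;
  }
};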
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/conditional_accumulator_base.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +typedef Eigen::ThreadPoolDevice CPUDevice; + +typedef std::function DoneCallback; + +namespace tensorflow { + +/** + * Defines a ConditionalAccumulatorBaseOp, which constructs a + * ConditionalAccumulatorBase (via sub-class's Creator) and returns its handle. + */ +class ConditionalAccumulatorBaseOp : public OpKernel { + public: + explicit ConditionalAccumulatorBaseOp(OpKernelConstruction* context) + : OpKernel(context), accumulator_set_(false) { + OP_REQUIRES_OK(context, context->allocate_temp(DT_STRING, TensorShape({2}), + &accumulator_)); + OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); + OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(context, + context->GetAttr("reduction_type", &reduction_type_)); + } + + void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + if (!accumulator_set_) { + OP_REQUIRES_OK(ctx, SetAccumulatorHandle(ctx)); + } + SetHandleToOutput(ctx); + } + + protected: + ~ConditionalAccumulatorBaseOp() override { + // If the accumulator object was not shared, delete it. + if (accumulator_set_ && cinfo_.resource_is_private_to_kernel()) { + TF_CHECK_OK((cinfo_.resource_manager() + ->template Delete( + cinfo_.container(), cinfo_.name()))); + } + } + + protected: + virtual void SetHandleToOutput(OpKernelContext* ctx) + TF_SHARED_LOCKS_REQUIRED(mu_) = 0; + + virtual absl::Status CheckSignature(OpKernelContext* ctx) = 0; + + protected: + typedef std::function Creator; + + // Subclasses must override this + virtual Creator GetCreator() const = 0; + + // Variables required to construct ConditionalAccumulator + DataType dtype_; + PartialTensorShape shape_; + ContainerInfo cinfo_; + string reduction_type_; + mutex mu_; + Tensor accumulator_ TF_GUARDED_BY(mu_); + bool accumulator_set_ TF_GUARDED_BY(mu_); + + private: + absl::Status SetAccumulatorHandle(OpKernelContext* ctx) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); + + // Check input signature + TF_RETURN_IF_ERROR(CheckSignature(ctx)); + + Creator creator = GetCreator(); + ConditionalAccumulatorBase* accumulator; + TF_RETURN_IF_ERROR( + (cinfo_.resource_manager() + ->template LookupOrCreate( + cinfo_.container(), cinfo_.name(), &accumulator, creator))); + core::ScopedUnref unref_me(accumulator); + + // Verify that the shared accumulator is compatible + // with the requested arguments. 
+ TF_RETURN_IF_ERROR(accumulator->MatchesNodeDef(def())); + auto h = accumulator_.template flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + accumulator_set_ = true; + return absl::OkStatus(); + } +}; + +// ------------------Sync kernels ------------------------------------------ + +/** + * General OpKernel for ConditionalAccumulatorBase-related ops. + */ +class ConditionalAccumulatorBaseSyncOpKernel : public OpKernel { + public: + explicit ConditionalAccumulatorBaseSyncOpKernel(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) final { + ConditionalAccumulatorBase* accumulator; + OP_REQUIRES_OK(ctx, GetResourceFromContext(ctx, "handle", &accumulator)); + Compute(ctx, accumulator); + accumulator->Unref(); + } + + protected: + virtual void Compute(OpKernelContext* ctx, + ConditionalAccumulatorBase* accumulator) = 0; + + virtual DataTypeVector GetExpectedInputs( + ConditionalAccumulatorBase* accumulator) = 0; + + virtual void CheckSignature(OpKernelContext* ctx, + ConditionalAccumulatorBase* accumulator) { + // Check input signature + DataTypeVector expected_inputs = GetExpectedInputs(accumulator); + OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {})); + } +}; + +/** + * Defines a AccumulateGradientOp, the execution of which adds a gradient to the + * given ConditionalAccumulator. + */ +class ConditionalAccumulatorBaseApplyGradientOp + : public ConditionalAccumulatorBaseSyncOpKernel { + public: + explicit ConditionalAccumulatorBaseApplyGradientOp( + OpKernelConstruction* context) + : ConditionalAccumulatorBaseSyncOpKernel(context) {} + + protected: + void Compute(OpKernelContext* ctx, + ConditionalAccumulatorBase* accumulator) override { + // Check input signature + CheckSignature(ctx, accumulator); + + // Get input local_step + const Tensor* local_step_tensor; + OP_REQUIRES_OK(ctx, ctx->input("local_step", &local_step_tensor)); + if (!TensorShapeUtils::IsScalar(local_step_tensor->shape())) { + ctx->CtxFailureWithWarning(errors::InvalidArgument( + "Argument local_step must be scalar, but had bad shape ", + local_step_tensor->shape().DebugString())); + } + + // Actually try to apply gradient now + accumulator->TryApplyGrad(local_step_tensor->scalar()(), ctx); + } +}; + +// -------------------- Async kernels -------------------------------------- +/** + * General OpKernel for ConditionalAccumulatorBase-related ops. 
+ */ +class ConditionalAccumulatorBaseAsyncOpKernel : public AsyncOpKernel { + public: + explicit ConditionalAccumulatorBaseAsyncOpKernel( + OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { + ConditionalAccumulatorBase* accumulator; + OP_REQUIRES_OK_ASYNC( + ctx, GetResourceFromContext(ctx, "handle", &accumulator), callback); + ComputeAsync(ctx, accumulator, [callback, accumulator]() { + accumulator->Unref(); + callback(); + }); + } + + protected: + virtual void ComputeAsync(OpKernelContext* ctx, + ConditionalAccumulatorBase* accumulator, + DoneCallback callback) = 0; + + virtual DataTypeVector GetExpectedInputs( + ConditionalAccumulatorBase* accumulator) = 0; + + virtual void CheckSignature(OpKernelContext* ctx, + ConditionalAccumulatorBase* accumulator, + DoneCallback callback) { + // Check input signature + OP_REQUIRES_OK_ASYNC(ctx, + ctx->MatchSignature(GetExpectedInputs(accumulator), + {accumulator->dtype()}), + callback); + } +}; + +/** + * Defines a TakeAccumulatedGradientOp, the execution of which adds a gradient + * to the given ConditionalAccumulator. + */ +class ConditionalAccumulatorBaseTakeGradientOp + : public ConditionalAccumulatorBaseAsyncOpKernel { + public: + explicit ConditionalAccumulatorBaseTakeGradientOp( + OpKernelConstruction* context) + : ConditionalAccumulatorBaseAsyncOpKernel(context) {} + + protected: + void ComputeAsync(OpKernelContext* ctx, + ConditionalAccumulatorBase* accumulator, + DoneCallback callback) override { + // Check signature + CheckSignature(ctx, accumulator, callback); + + // Get input num_required + const Tensor* num_required_tensor; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_required", &num_required_tensor), + callback); + if (!TensorShapeUtils::IsScalar(num_required_tensor->shape())) { + ctx->CtxFailureWithWarning(errors::InvalidArgument( + "Argument num_required must be scalar, but had bad shape ", + num_required_tensor->shape().DebugString())); + callback(); + } + + // Actually try to take gradient now + accumulator->TryTakeGrad(num_required_tensor->scalar()(), ctx, + callback); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/constant_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/constant_op.h new file mode 100644 index 00000000..32f1ddb7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/constant_op.h @@ -0,0 +1,52 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// ConstantOp returns a tensor specified by ConstantOpDef. +class ConstantOp : public OpKernel { + public: + explicit ConstantOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + const Tensor* const_tensor() const override { return &tensor_; }; + ~ConstantOp() override; + + private: + Tensor tensor_; + ConstantOp(const ConstantOp&) = delete; + void operator=(const ConstantOp&) = delete; +}; + +class PlaceholderOp : public OpKernel { + public: + explicit PlaceholderOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + + private: + PartialTensorShape expected_shape_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/control_flow_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/control_flow_ops.h new file mode 100644 index 00000000..13869317 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/control_flow_ops.h @@ -0,0 +1,140 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +// A ControlTriggerOp is similar to a NoOp. However, it always treats the input +// control edges as Live edges. Its primary use so far is in the scheduling of +// recvs, where we add ControlTrigger nodes and use them to trigger recvs. We +// allow ControlTrigger nodes to be enabled by dead nodes. +class ControlTriggerOp : public OpKernel { + public: + explicit ControlTriggerOp(OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} + bool IsExpensive() override { return false; } +}; + +// A switch op has two inputs and two outputs. It forwards the value of +// Input:0 to the output specified by input:1. Input:1 is a boolean tensor. +// Input:0 is forwarded to output:0 if input:1 is false, otherwise to +// output:1. +class SwitchOp : public OpKernel { + public: + explicit SwitchOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~SwitchOp() override {} + + SwitchOp(const SwitchOp&) = delete; + void operator=(const SwitchOp&) = delete; +}; + +// An n-way switch op has two inputs and N outputs. 
It forwards the value of +// Input:0 to the output specified by Input:1. Input:1 is an integer tensor. +// Input:0 is forwarded to output:0 if Input:1 is 0, to output:1 if 1, and so +// forth. If Input:1 is <0 or >=num_outputs(), Input:0 is forwarded to +// output:num_outputs()-1. +class SwitchNOp : public OpKernel { + public: + explicit SwitchNOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~SwitchNOp() override {} + + SwitchNOp(const SwitchNOp&) = delete; + void operator=(const SwitchNOp&) = delete; +}; + +// A merge op has n inputs and two outputs. It forwards the value of the +// first input that becomes available to its first output, and the +// index of the first input to its second output. +class MergeOp : public OpKernel { + public: + explicit MergeOp(OpKernelConstruction* context); + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~MergeOp() override {} + + MergeOp(const MergeOp&) = delete; + void operator=(const MergeOp&) = delete; +}; + +// An enter op has one input and one output. It creates or finds +// the child frame that is uniquely identified by the frame_name, +// and makes its input available to the child frame. +class EnterOp : public OpKernel { + public: + explicit EnterOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~EnterOp() override {} + + EnterOp(const EnterOp&) = delete; + void operator=(const EnterOp&) = delete; +}; + +// An exit op has one input and one output. It exits the current +// frame to its parent frame, and makes its input available to the +// parent frame. +class ExitOp : public OpKernel { + public: + explicit ExitOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~ExitOp() override {} + + ExitOp(const ExitOp&) = delete; + void operator=(const ExitOp&) = delete; +}; + +// A next_iteration op has one input and one output. It makes its input +// available to the next iteration. +class NextIterationOp : public OpKernel { + public: + explicit NextIterationOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override; + bool IsExpensive() override { return false; } + ~NextIterationOp() override {} + + NextIterationOp(const NextIterationOp&) = delete; + void operator=(const NextIterationOp&) = delete; +}; + +// A LoopCond op has one input and one output. The input is a boolean +// scalar representing the taken branches of the "pivot" Switch that +// determines loop termination. As a contract, any high-level front-end +// should always use port '0' of the "pivot" switches for loop exit. 
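Editor's illustration: the SwitchN forwarding rule described earlier in this header (out-of-range selector values are routed to the last output) summarized as a small helper. This is not part of the TensorFlow sources.

// Illustration only: which output port receives Input:0 for a given value of
// the integer selector Input:1, per the SwitchNOp comment above.
int SwitchNOutputPort(int selector, int num_outputs) {
  if (selector < 0 || selector >= num_outputs) return num_outputs - 1;
  return selector;
}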
+class LoopCondOp : public OpKernel { + public: + explicit LoopCondOp(OpKernelConstruction* context); + ~LoopCondOp() override; + + void Compute(OpKernelContext* context) override; + + bool IsExpensive() override; + + LoopCondOp(const LoopCondOp&) = delete; + void operator=(const LoopCondOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_2d.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_2d.h new file mode 100644 index 00000000..1ddeec23 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_2d.h @@ -0,0 +1,585 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_2D_H_ + +#include "absl/strings/string_view.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h" +#include "tensorflow/core/util/tensor_format.h" + +// Returns true if TF_CONV2D_USE_FP16_ACCUMULATE == 1, false otherwise. +static bool Conv2dUseFp16Accumulate() { + static bool use_fp16_accumulate = []() { + const char* env = std::getenv("TF_CONV2D_USE_FP16_ACCUMULATE"); + return (env != nullptr) && (absl::string_view(env) == "1"); + }(); + return use_fp16_accumulate; +} + +namespace tensorflow { +namespace functor { + +template +void SpatialConvolutionFunc(const Device& d, Output output, Input input, + Filter filter, int row_stride, int col_stride, + int row_dilation, int col_dilation, + const Eigen::PaddingType& padding, + const OutputKernel& output_kernel, + int padding_top = 0, int padding_bottom = 0, + int padding_left = 0, int padding_right = 0) { + // Need to swap row/col, padding_top/padding_left, and + // padding_bottom/padding_right when calling Eigen. Eigen expects the tensor + // in NWHC format, but the tensor given is in NHWC. + output.device(d) = Eigen::SpatialConvolution( + input, filter, col_stride, row_stride, padding, col_dilation, + row_dilation, output_kernel, padding_left, padding_right, padding_top, + padding_bottom); +} + +// TODO(ezhulenev): Non-templated `operator()` are required by explicit template +// instantiations for the GPU device. However they are almost certainly not used +// in any of the kernel implementation. Check if they can be removed. 
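Editor's illustration: Conv2dUseFp16Accumulate() above reads TF_CONV2D_USE_FP16_ACCUMULATE once and caches it in a function-local static, so opting in to fp16 accumulation has to happen before the first half-precision convolution runs. A minimal sketch, assuming a POSIX environment and an otherwise unspecified TensorFlow program:

#include <cstdlib>

int main() {
  // Must be set before the first Conv2D with Eigen::half executes, because
  // the flag is read once and cached in a static initializer (see above).
  setenv("TF_CONV2D_USE_FP16_ACCUMULATE", "1", /*overwrite=*/1);
  // ... construct and run the TensorFlow program as usual ...
  return 0;
}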
+template +struct SpatialConvolution { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int row_stride, + int col_stride, int row_dilation, int col_dilation, + const Eigen::PaddingType& padding, + const OutputKernel& output_kernel = OutputKernel()) { + SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride, + row_dilation, col_dilation, padding, output_kernel); + } + + template + void operator()(const Device& d, Output output, Input input, Filter filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, const Eigen::PaddingType& padding, + const OutputKernel& output_kernel = OutputKernel()) { + SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride, + row_dilation, col_dilation, padding, output_kernel); + } + + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int row_stride, + int col_stride, int row_dilation, int col_dilation, + int padding_top, int padding_bottom, int padding_left, + int padding_right, + const OutputKernel& output_kernel = OutputKernel()) { + SpatialConvolutionFunc( + d, output, input, filter, row_stride, col_stride, row_dilation, + col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel, + padding_top, padding_bottom, padding_left, padding_right); + } + + template + void operator()(const Device& d, Output output, Input input, Filter filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, int padding_top, int padding_bottom, + int padding_left, int padding_right, + const OutputKernel& output_kernel = OutputKernel()) { + SpatialConvolutionFunc( + d, output, input, filter, row_stride, col_stride, row_dilation, + col_dilation, Eigen::PaddingType::PADDING_VALID, output_kernel, + padding_top, padding_bottom, padding_left, padding_right); + } +}; + +template +struct SpatialConvolution { + void operator()(const Device& d, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, const Eigen::PaddingType& padding, + const OutputKernel& output_kernel = OutputKernel()) { + if (Conv2dUseFp16Accumulate()) { + output.device(d) = Eigen::SpatialConvolution( + input, filter, col_stride, row_stride, padding, col_dilation, + row_dilation, output_kernel); + } else { + output.device(d) = + Eigen::SpatialConvolution(input.cast(), filter.cast(), + col_stride, row_stride, padding, + col_dilation, row_dilation, output_kernel) + .template cast(); + } + } + + template + void operator()(const Device& d, Output output, Input input, Filter filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, const Eigen::PaddingType& padding, + const OutputKernel& output_kernel = OutputKernel()) { + if (Conv2dUseFp16Accumulate()) { + output.device(d) = Eigen::SpatialConvolution( + input, filter, col_stride, row_stride, padding, col_dilation, + row_dilation, output_kernel); + } else { + output.device(d) = + Eigen::SpatialConvolution(input.template cast(), + filter.template cast(), col_stride, + row_stride, padding, col_dilation, + row_dilation, output_kernel) + .template cast(); + } + } + + void operator()(const Device& d, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + int row_stride, int col_stride, int row_dilation, + int 
col_dilation, int padding_top, int padding_bottom, + int padding_left, int padding_right, + const OutputKernel& output_kernel = OutputKernel()) { + if (Conv2dUseFp16Accumulate()) { + output.device(d) = Eigen::SpatialConvolution( + input, filter, col_stride, row_stride, + Eigen::PaddingType::PADDING_VALID, col_dilation, row_dilation, + output_kernel, padding_left, padding_right, padding_top, + padding_bottom); + } else { + output.device(d) = + Eigen::SpatialConvolution( + input.cast(), filter.cast(), col_stride, row_stride, + Eigen::PaddingType::PADDING_VALID, col_dilation, row_dilation, + output_kernel, padding_left, padding_right, padding_top, + padding_bottom) + .template cast(); + } + } + + template + void operator()(const Device& d, Output output, Input input, Filter filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, int padding_top, int padding_bottom, + int padding_left, int padding_right, + const OutputKernel& output_kernel = OutputKernel()) { + if (Conv2dUseFp16Accumulate()) { + output.device(d) = Eigen::SpatialConvolution( + input, filter, col_stride, row_stride, + Eigen::PaddingType::PADDING_VALID, col_dilation, row_dilation, + output_kernel, padding_left, padding_right, padding_top, + padding_bottom); + } else { + output.device(d) = + Eigen::SpatialConvolution( + input.template cast(), filter.template cast(), + col_stride, row_stride, Eigen::PaddingType::PADDING_VALID, + col_dilation, row_dilation, output_kernel, padding_left, + padding_right, padding_top, padding_bottom) + .template cast(); + } + } +}; + +// Use float32 accumulation for bfloat16 to deal with precision accumulation +// issues. +template +struct SpatialConvolution { + void operator()(const Device& d, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, const Eigen::PaddingType& padding, + const OutputKernel& output_kernel = OutputKernel()) { + output.device(d) = + Eigen::SpatialConvolution(input.cast(), filter.cast(), + col_stride, row_stride, padding, col_dilation, + row_dilation, output_kernel) + .template cast(); + } + + template + void operator()(const Device& d, Output output, Input input, Filter filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, const Eigen::PaddingType& padding, + const OutputKernel& output_kernel = OutputKernel()) { + output.device(d) = + Eigen::SpatialConvolution(input.template cast(), + filter.template cast(), col_stride, + row_stride, padding, col_dilation, + row_dilation, output_kernel) + .template cast(); + } + + void operator()(const Device& d, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, int padding_top, int padding_bottom, + int padding_left, int padding_right, + const OutputKernel& output_kernel = OutputKernel()) { + output.device(d) = + Eigen::SpatialConvolution( + input.cast(), filter.cast(), col_stride, row_stride, + Eigen::PaddingType::PADDING_VALID, col_dilation, row_dilation, + output_kernel, padding_left, padding_right, padding_top, + padding_bottom) + .template cast(); + } + + template + void operator()(const Device& d, Output output, Input input, Filter filter, + int row_stride, int col_stride, int row_dilation, + int col_dilation, int padding_top, int padding_bottom, + int padding_left, int padding_right, + const OutputKernel& output_kernel = 
OutputKernel()) { + output.device(d) = + Eigen::SpatialConvolution( + input.template cast(), filter.template cast(), + col_stride, row_stride, Eigen::PaddingType::PADDING_VALID, + col_dilation, row_dilation, output_kernel, padding_left, + padding_right, padding_top, padding_bottom) + .template cast(); + } +}; + +template +struct SpatialConvolutionBackwardInputFunc { + void operator()(const Device& d, typename TTypes::Tensor input_backward, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor output_backward, + Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride, + Eigen::DenseIndex col_dilation, + Eigen::DenseIndex row_dilation) { + input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput( + filter, output_backward, input_backward.dimension(2), + input_backward.dimension(1), col_stride, row_stride, col_dilation, + row_dilation); + } +}; + +// GPU version requires all tensors to be indexable by int32. +template +struct SpatialConvolutionBackwardInputFunc { + void operator()(const Eigen::GpuDevice& d, + typename TTypes::Tensor input_backward, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor output_backward, + Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride, + Eigen::DenseIndex col_dilation, + Eigen::DenseIndex row_dilation) { + To32Bit(input_backward).device(d) = Eigen::SpatialConvolutionBackwardInput( + To32Bit(filter), To32Bit(output_backward), input_backward.dimension(2), + input_backward.dimension(1), col_stride, row_stride, col_dilation, + row_dilation); + } +}; + +template +struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc { + void operator()(const Device& d, typename TTypes::Tensor input_backward, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor output_backward, + Eigen::DenseIndex padded_cols, Eigen::DenseIndex padded_rows, + Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride, + Eigen::DenseIndex col_dilation, + Eigen::DenseIndex row_dilation, Eigen::DenseIndex pad_left, + Eigen::DenseIndex pad_top) { + // We have to slice the result of a spatial convolution backward + // input, before assigning it to the `input_backward` to remove padding. + // + // TODO(ezhulenev): Pass explicit paddings to Eigen and do not materialize + // intermediate result in memory before slicing. + input_backward.device(d) = + Eigen::SpatialConvolutionBackwardInput( + filter, output_backward, padded_cols, padded_rows, col_stride, + row_stride, col_dilation, row_dilation) + .eval() + .slice(Eigen::DSizes{0, pad_left, pad_top, 0}, + input_backward.dimensions()); + } +}; + +// GPU version requires all tensors to be indexable by int32. +template +struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc { + void operator()(const Eigen::GpuDevice& d, + typename TTypes::Tensor input_backward, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor output_backward, + Eigen::DenseIndex padded_cols, Eigen::DenseIndex padded_rows, + Eigen::DenseIndex col_stride, Eigen::DenseIndex row_stride, + Eigen::DenseIndex col_dilation, + Eigen::DenseIndex row_dilation, Eigen::DenseIndex pad_left, + Eigen::DenseIndex pad_top) { + To32Bit(input_backward).device(d) = + Eigen::SpatialConvolutionBackwardInput( + To32Bit(filter), To32Bit(output_backward), padded_cols, padded_rows, + col_stride, row_stride, col_dilation, row_dilation) + .eval() + .slice(Eigen::DSizes{0, pad_left, pad_top, 0}, + input_backward.dimensions()); + } +}; + +// TODO(vrv): Figure out how to use the MatMulFunctor in matmul_op.h. 
+// My initial attempt to do this compiled but failed in the pytest +// due to a swigdeps error. +template +struct MatMulConvFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename TTypes::ConstTensor in1, + const Eigen::array, 1>& dim_pair, + const OutputKernel& output_kernel = OutputKernel()) { + out.device(d) = in0.contract(in1, dim_pair, output_kernel); + } +}; + +// Use float32 accumulation for float16 by default to deal with precision +// accumulation issues. To enable float16 accumulation, set the environment +// variable TF_CONV2D_USE_FP16_ACCUMULATE. +template +struct MatMulConvFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename TTypes::ConstTensor in1, + const Eigen::array, 1>& dim_pair, + const OutputKernel& output_kernel = OutputKernel()) { + if (Conv2dUseFp16Accumulate()) { + out.device(d) = in0.contract(in1, dim_pair, output_kernel); + } else { + out.device(d) = + in0.cast() + .contract(in1.template cast(), dim_pair, output_kernel) + .template cast(); + } + } +}; + +// Use float32 accumulation for bfloat16 to deal with precision accumulation +// issues. +template +struct MatMulConvFunctor { + void operator()( + const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename TTypes::ConstTensor in1, + const Eigen::array, 1>& dim_pair, + const OutputKernel& output_kernel = OutputKernel()) { + out.device(d) = in0.cast() + .contract(in1.cast(), dim_pair, output_kernel) + .template cast(); + } +}; + +// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format. +// +// Note: Currently supports OIHW and OHWI destination formats. +template +struct TransformFilter { + void operator()(const Device& d, FilterTensorFormat dst_filter_format, + typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + // NOTE: Source filter format is always HWIO. + Eigen::DSizes spatial_dims; + for (int i = 0; i < spatial_dims.rank(); ++i) { + spatial_dims[i] = in.dimension(i); + } + + // Merge the spatial dimensions together to speed up the shuffle operation. + Eigen::DSizes merged_dims; + merged_dims[0] = spatial_dims.TotalSize(); // product of spatial dims [H*W] + merged_dims[1] = in.dimension(NDIMS - 2); // input filters [I] + merged_dims[2] = in.dimension(NDIMS - 1); // output filters [O] + + // Shuffle tensor with merged spatial dimensions. + Eigen::DSizes shuffling_perm; + // Expand shuffled tensor into final dimensions. 
+ Eigen::DSizes expanded_dims; + + if (dst_filter_format == FORMAT_OIHW) { + shuffling_perm = Eigen::DSizes(2, 1, 0); + + expanded_dims[0] = merged_dims[2]; // [O] + expanded_dims[1] = merged_dims[1]; // [I] + for (int i = 0; i < spatial_dims.rank(); ++i) { + expanded_dims[2 + i] = spatial_dims[i]; + } + + } else if (dst_filter_format == FORMAT_OHWI) { + shuffling_perm = Eigen::DSizes(2, 0, 1); + + expanded_dims[0] = merged_dims[2]; // [O] + expanded_dims[NDIMS - 1] = merged_dims[1]; // [I] + for (int i = 0; i < spatial_dims.rank(); ++i) { + expanded_dims[1 + i] = spatial_dims[i]; + } + + } else { + DCHECK(false) << "Unsupported destination filter format: " + << ToString(dst_filter_format); + } + + out.device(d) = + in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims); + } +}; + +// TODO This functor is not used anywhere and should be removed, +// but it defines some eigen templates that are referenced in other kernels. +template +struct TransformDepth { + void operator()(const Device& d, + typename TTypes::ConstTensor in, + const Eigen::DSizes& shuffle, + typename TTypes::Tensor out) { + Eigen::DSizes merged_dims; + Eigen::DSizes expanded_dims; + Eigen::DSizes new_shuffle; + + // Merge dimensions that won't be shuffled together to speed things up. + if (shuffle[1] == 2 && shuffle[2] == 3) { + merged_dims[0] = in.dimension(0); + merged_dims[1] = in.dimension(1); + merged_dims[2] = in.dimension(2) * in.dimension(3); + new_shuffle[0] = shuffle[0]; + new_shuffle[1] = 2; + new_shuffle[2] = shuffle[3]; + expanded_dims[0] = in.dimension(shuffle[0]); + expanded_dims[1] = in.dimension(2); + expanded_dims[2] = in.dimension(3); + expanded_dims[3] = in.dimension(shuffle[3]); + } else if (shuffle[0] == 2 && shuffle[1] == 3) { + merged_dims[0] = in.dimension(0); + merged_dims[1] = in.dimension(1); + merged_dims[2] = in.dimension(2) * in.dimension(3); + new_shuffle[0] = 2; + new_shuffle[1] = shuffle[2]; + new_shuffle[2] = shuffle[3]; + expanded_dims[0] = in.dimension(2); + expanded_dims[1] = in.dimension(3); + expanded_dims[2] = in.dimension(shuffle[2]); + expanded_dims[3] = in.dimension(shuffle[3]); + } else if (shuffle[0] == 0 && shuffle[1] == 3 && shuffle[2] == 1 && + shuffle[3] == 2) { + merged_dims[0] = in.dimension(0); + merged_dims[1] = in.dimension(1) * in.dimension(2); + merged_dims[2] = in.dimension(3); + new_shuffle[0] = 0; + new_shuffle[1] = 2; + new_shuffle[2] = 1; + expanded_dims[0] = in.dimension(0); + expanded_dims[1] = in.dimension(3); + expanded_dims[2] = in.dimension(1); + expanded_dims[3] = in.dimension(2); + } else { + assert(false && "unexpected shuffle"); + } + + out.device(d) = + in.reshape(merged_dims).shuffle(new_shuffle).reshape(expanded_dims); + } +}; + +template +struct PadInput { + void operator()(const Device& d, + typename TTypes::ConstTensor in, + const std::array& padding_left, + const std::array& padding_right, + typename TTypes::Tensor out, + TensorFormat format, const T& padding_value) { + Eigen::array, NDIMS> padding; + padding[GetTensorDimIndex(format, 'N')] = {0, 0}; + for (int i = 0; i < NDIMS - 2; ++i) { + padding[GetTensorDimIndex(format, '0' + i)] = { + padding_left[i], padding_right[i]}; + } + padding[GetTensorDimIndex(format, 'C')] = {0, 0}; + out.device(d) = in.pad(padding, padding_value); + } +}; + +// Converts a tensor from: +// [batch, , filters] +// to: +// [batch, filters, ] +template +struct NHWCToNCHW { + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out); +}; + +// Converts a tensor 
from: +// [batch, filters, ] +// to: +// [batch, , filters] +template +struct NCHWToNHWC { + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out); +}; + +// Converts a tensor from: +// [dim0, dim1, dim2] +// to: +// [dim0, dim2, dim1] +template +struct SwapDimension1And2InTensor3 { + void operator()(const Device& d, const T* in, + const absl::Span& input_dims, T* out); +}; + +// Converts a tensor from: +// [dim0, dim1, dim2] +// to: +// [dim2, dim1, dim0] +template +struct SwapDimension0And2InTensor3 { + void operator()(const Device& d, const T* in, + const absl::Span& input_dims, T* out); +}; + +// Transforms back filter from OIHW or OHWI to HWOI format to reverse effect of +// TransformFilter above. +template +struct ReverseTransformFilter { + void operator()(const Device& d, FilterTensorFormat src_filter_format, + typename TTypes::ConstTensor in, + typename TTypes::Tensor out); +}; + +} // namespace functor + +template +class ConvAlgorithmMap; + +template <> +class ConvAlgorithmMap {}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_2D_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_2d_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_2d_gpu.h new file mode 100644 index 00000000..60d2e831 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_2d_gpu.h @@ -0,0 +1,1147 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include +#include +#include +#include + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#endif +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +struct maybe_conj { + __device__ static __inline__ T run(T x) { + if (conjugate) { + return Eigen::numext::conj(x); + } else { + return x; + } + } +}; + +// Partial specializations for Gpu types used to store complex numbers. +template +struct maybe_conj { + __device__ static __inline__ float2 run(float2 c) { + if (conjugate) { + float2 c_conj; + c_conj.x = c.x; + c_conj.y = -c.y; + return c_conj; + } else { + return c; + } + } +}; + +template +struct maybe_conj { + __device__ static __inline__ double2 run(double2 c) { + if (conjugate) { + double2 c_conj; + c_conj.x = c.x; + c_conj.y = -c.y; + return c_conj; + } else { + return c; + } + } +}; + +// TODO(mjanusz): Move this to a shared util file. +// A simple array that contains data that can be passed between CPU and GPU. 
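Closing out conv_2d.h above: the Eigen::half and bfloat16 specializations all share one trick, cast the operands to float, convolve or contract in float, and cast the result back to the narrow type. A minimal Eigen sketch of that cast-contract-cast pattern (illustrative shapes and function name; assumes the bundled Eigen provides Eigen::bfloat16; not code from the patch):

// Sketch only: float32 accumulation for a bfloat16 matrix product, mirroring
// the cast-contract-cast pattern used by the bfloat16 MatMulConvFunctor and
// SpatialConvolution specializations above. Shapes and names are illustrative.
#include <unsupported/Eigen/CXX11/Tensor>

Eigen::Tensor<Eigen::bfloat16, 2> MatMulWithFp32Accumulation(
    const Eigen::Tensor<Eigen::bfloat16, 2>& lhs,
    const Eigen::Tensor<Eigen::bfloat16, 2>& rhs) {
  // Contract lhs's second dimension against rhs's first dimension (lhs * rhs).
  const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dims = {
      Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};
  // Accumulate in float, then narrow the result back to bfloat16 for storage.
  const Eigen::Tensor<float, 2> acc =
      lhs.cast<float>().contract(rhs.cast<float>(), dims);
  return acc.cast<Eigen::bfloat16>();
}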
+template +struct Array { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const { + return data[index]; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) { + return data[index]; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() { + for (int i = 0; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) { + data[0] = a0; + for (int i = 1; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) { + data[0] = a0; + data[1] = a1; + for (int i = 2; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) { + data[0] = a0; + data[1] = a1; + data[2] = a2; + for (int i = 3; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_STRONG_INLINE Array(const std::array& array) { + for (int i = 0; i < IndexCount; i++) { + data[i] = array[i]; + } + } + T data[IndexCount]; +}; + +// A dimension type with compile-time known size. +template +struct Dimension : Array { + typedef Array Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1) + : Base(a0, a1) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2) + : Base(a0, a1, a2) {} + EIGEN_STRONG_INLINE Dimension(const std::array& array) + : Base(array) {} +}; + +// An index type with compile-time known size. +template +struct Index : Array { + typedef Array Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2) + : Base(a0, a1, a2) {} +}; + +// A helper function that converts a tensor index into a flat array index. +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat( + const Index& index, const Dimension& dims) { + int flat_index = index[0]; + for (int i = 1; i < IndexCount; i++) { + flat_index = flat_index * dims[i] + index[i]; + } + return flat_index; +} + +// A helper function that converts a flat array index into a tensor index. +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index FlatToTensorIndex( + int index, const Dimension& dims) { + Index tensor_index; + for (int i = IndexCount - 1; i >= 0; i--) { + int new_index = index / dims[i]; + tensor_index[i] = index - dims[i] * new_index; + index = new_index; + } + return tensor_index; +} + +// A simple CUDA custom kernel to shuffle dimensions of a 3D tensor according to +// the given shuffle permutation in template parameters. Shuffle permutation +// shuffles dimensions such that input dimension 0 goes to sp0, +// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1> +// will populate output so that input[x][y][z] is equal to (*output)[y][z][x]. +// +// Requires that nthreads is equal to the total number of elements in the input +// tensor. +template +__global__ void ShuffleInTensor3Simple(int nthreads, + const T* __restrict__ input, + Dimension<3> input_dims, + T* __restrict__ output) { + Dimension<3> output_dims; + output_dims[sp0] = input_dims[0]; + output_dims[sp1] = input_dims[1]; + output_dims[sp2] = input_dims[2]; + + // Iterate over output as opposed to iterating over input for better + // performance. 
Iterating over output will generate sequential writes and + // random reads that performs better compared to sequential reads and random + // writes. + GPU_1D_KERNEL_LOOP(output_index, nthreads) { + Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims); + + Index<3> input_tensor_index; + input_tensor_index[0] = output_tensor_index[sp0]; + input_tensor_index[1] = output_tensor_index[sp1]; + input_tensor_index[2] = output_tensor_index[sp2]; + + int input_index = TensorIndexToFlat(input_tensor_index, input_dims); + + output[output_index] = + maybe_conj::run(ldg(input + input_index)); + } +} + +static constexpr int kUnroll = 4; + +template +__global__ void ShuffleInTensor3SimpleVector(int nthreads, + const T* __restrict__ input, + Dimension<3> input_dims, + T* __restrict__ output) { + Dimension<3> output_dims; + output_dims[sp0] = input_dims[0]; + output_dims[sp1] = input_dims[1]; + output_dims[sp2] = input_dims[2]; + + const int stride = blockDim.x * gridDim.x * kUnroll; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + T buf[kUnroll]; + + int output_index; + for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads; + output_index += stride) { +#pragma unroll + for (int i = 0; i < kUnroll; i++) { + int output_index_i = output_index + i; + Index<3> output_tensor_index = + FlatToTensorIndex(output_index_i, output_dims); + Index<3> input_tensor_index; + input_tensor_index[0] = output_tensor_index[sp0]; + input_tensor_index[1] = output_tensor_index[sp1]; + input_tensor_index[2] = output_tensor_index[sp2]; + + int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims); + buf[i] = maybe_conj::run(ldg(input + input_index_i)); + } + float2* out = reinterpret_cast(output + output_index); + *out = *reinterpret_cast(buf); + } + + for (; output_index < nthreads; ++output_index) { + Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims); + + Index<3> input_tensor_index; + input_tensor_index[0] = output_tensor_index[sp0]; + input_tensor_index[1] = output_tensor_index[sp1]; + input_tensor_index[2] = output_tensor_index[sp2]; + + int input_index = TensorIndexToFlat(input_tensor_index, input_dims); + + output[output_index] = + maybe_conj::run(ldg(input + input_index)); + } +} + +// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor, +// where dimensions are zero-based: output[i][j][k] = input[i][k][j]. +// +// Each thread block operates on a single tile, a rectangle of dimensions +// TileSizeI x TileSizeJ. +// +// In general, for best performance, you should probably set TileSizeI, +// TileSizeJ equal to the number of threads in a warp (32 in nvidia GPUs). +// With a TileSizeI, TileSizeJ of 32, NumThreads of 128 or 256 seems to get +// the best performance on K40 GPUs. +template +__global__ void SwapDimension1And2InTensor3UsingTiles( + const T* __restrict__ input, Dimension<3> input_dims, + T* __restrict__ output) { + eigen_assert(blockDim.x == NumThreads); + eigen_assert(blockDim.y == 1); + eigen_assert(blockDim.z == 1); + eigen_assert(gridDim.y == 1); + eigen_assert(gridDim.z == 1); + + constexpr int ReadRowPerPass = NumThreads / TileSizeJ; + constexpr int WriteRowPerPass = NumThreads / TileSizeI; + // One extra line in the inner dimension to avoid share memory bank conflict. + // This is to mimic the following, but no constructor of T can be invoked. 
+ // __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1]; +#if GOOGLE_CUDA + __shared__ __align__( + alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)]; + typedef T(*SharedMemoryTile)[TileSizeJ + 1]; + SharedMemoryTile shared_memory_tile = + reinterpret_cast(shared_mem_raw); +#elif TENSORFLOW_USE_ROCM + __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1]; +#endif + + int x = threadIdx.x; + + Dimension<3> output_dims = { + input_dims[0], + input_dims[2], + input_dims[1], + }; + + Dimension<3> input_dims_in_tiles = { + input_dims[0], + (input_dims[1] + TileSizeI - 1) / TileSizeI, + (input_dims[2] + TileSizeJ - 1) / TileSizeJ, + }; + + Index<3> input_tile_index = + FlatToTensorIndex(blockIdx.x, input_dims_in_tiles); + + Index<3> input_tile_origin = { + input_tile_index[0], + input_tile_index[1] * TileSizeI, + input_tile_index[2] * TileSizeJ, + }; + + int input_origin_flat_index = + TensorIndexToFlat(input_tile_origin, input_dims); + + bool full_tile = true; + int tile_width = TileSizeJ; + + // Only the last row or column may not have the full size. + if (input_tile_index[2] == input_dims_in_tiles[2] - 1) { + tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ; + full_tile &= false; + } + + int tile_height = TileSizeI; + + if (input_tile_index[1] == input_dims_in_tiles[1] - 1) { + tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI; + full_tile &= false; + } + + // Calculate effective thread number. This ensures that we use the largest + // number of threads available to form a regular thread block with no + // trailing incomplete lines. + constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ; + + if (x < in_effective_thread_num) { + // Orient the logical thread block with respect to the input array. + // ie. align the contiguous dimension of thread blocks with the contiguous + // dimension of the input array. + int ti = x / TileSizeJ; + int tj = x % TileSizeJ; + int input_index = input_origin_flat_index + ti * input_dims[2] + tj; + int input_increment = ReadRowPerPass * input_dims[2]; + + if (full_tile) { +#pragma unroll + for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) { + shared_memory_tile[i_loc][tj] = + maybe_conj::run(input[input_index]); + input_index += input_increment; + } + } else { + if (tj < tile_width) { + for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) { + shared_memory_tile[i_loc][tj] = + maybe_conj::run(input[input_index]); + input_index += input_increment; + } + } + } + } + + __syncthreads(); + + Index<3> output_tile_index = { + input_tile_index[0], + input_tile_index[2], + input_tile_index[1], + }; + + Index<3> output_tile_origin = { + output_tile_index[0], + output_tile_index[1] * TileSizeJ, + output_tile_index[2] * TileSizeI, + }; + + int output_origin_flat_index = + TensorIndexToFlat(output_tile_origin, output_dims); + + constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI; + + if (x < out_effective_thread_num) { + // Re-orient the logical thread block with respect to the output array. + // ie. align the contiguous dimension of thread blocks with contiguous + // dimension of the output array. 
+ int ti = x / TileSizeI; + int tj = x % TileSizeI; + int output_index = output_origin_flat_index + ti * output_dims[2] + tj; + int output_increment = WriteRowPerPass * output_dims[2]; + + if (full_tile) { +#pragma unroll + for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) { + output[output_index] = shared_memory_tile[tj][i_loc]; + output_index += output_increment; + } + } else { + if (tj < tile_height) { + for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) { + output[output_index] = shared_memory_tile[tj][i_loc]; + output_index += output_increment; + } + } + } + } +} + +// A Gpu custom kernel that convert input to output, given proper padding on +// the left and the top. +template +__global__ void PadInputCustomKernelNHWC( + int nthreads, const T* __restrict__ input, Dimension input_dims, + T* __restrict__ output, Dimension output_dims, + Dimension padding_left, T padding_value) { + GPU_1D_KERNEL_LOOP(index, nthreads) { + int output_index = index; + Index output_tensor_index = + FlatToTensorIndex(output_index, output_dims); + + Index input_tensor_index; + input_tensor_index[0] = output_tensor_index[0]; // batch + bool ok = true; + for (int i = 1; i < NDIMS - 1; i++) { + input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1]; + ok &= + (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]); + } + input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1]; // channels + + if (ok) { + const int input_index = TensorIndexToFlat(input_tensor_index, input_dims); + output[output_index] = input[input_index]; + } else { + output[output_index] = padding_value; + } + } +} + +template +__global__ void PadInputCustomKernelNCHW( + int nthreads, const T* __restrict__ input, Dimension input_dims, + T* __restrict__ output, Dimension output_dims, + Dimension padding_left, T padding_value) { + GPU_1D_KERNEL_LOOP(index, nthreads) { + int output_index = index; + Index output_tensor_index = + FlatToTensorIndex(output_index, output_dims); + + Index input_tensor_index; + input_tensor_index[0] = output_tensor_index[0]; // batch + input_tensor_index[1] = output_tensor_index[1]; // channels + bool ok = true; + for (int i = 2; i < NDIMS; i++) { + input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2]; + ok &= + (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]); + } + + if (ok) { + const int input_index = TensorIndexToFlat(input_tensor_index, input_dims); + output[output_index] = input[input_index]; + } else { + output[output_index] = padding_value; + } + } +} + +// A GPU helper function that converts TensorFlow filter format to Cudnn filter +// format. 
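In scalar terms, the GPU TransformFilter that follows (like its CPU counterpart earlier in conv_2d.h) collapses the spatial dimensions and permutes the resulting 3-D tensor: HWIO stored as [H*W, I, O] becomes OIHW stored as [O, I, H*W]. A minimal reference loop for that permutation (illustrative helper, not part of the patch):

// Sketch only: the HWIO -> OIHW transform expressed as a plain 3-D permutation
// after collapsing the spatial dimensions, matching the <2, 1, 0> shuffle used
// by ShuffleInTensor3Simple below. Not code from the patch.
#include <cstdint>
#include <vector>

template <typename T>
std::vector<T> HwioToOihw(const std::vector<T>& hwio, int64_t spatial /* H*W */,
                          int64_t in_depth /* I */, int64_t out_depth /* O */) {
  // Input is laid out as [spatial, I, O]; output as [O, I, spatial].
  std::vector<T> oihw(hwio.size());
  for (int64_t s = 0; s < spatial; ++s) {
    for (int64_t i = 0; i < in_depth; ++i) {
      for (int64_t o = 0; o < out_depth; ++o) {
        const int64_t src = (s * in_depth + i) * out_depth + o;
        const int64_t dst = (o * in_depth + i) * spatial + s;
        oihw[dst] = hwio[src];
      }
    }
  }
  return oihw;
}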
+template +struct TransformFilter { + typedef GPUDevice Device; + void operator()(const Device& d, FilterTensorFormat dst_filter_format, + typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + Dimension<3> combined_dims; + combined_dims[0] = in.dimension(0); // spatial dimensions + for (int i = 1; i < NDIMS - 2; i++) { + combined_dims[0] *= in.dimension(i); + } + combined_dims[1] = in.dimension(NDIMS - 2); // input filters + combined_dims[2] = in.dimension(NDIMS - 1); // output filters + GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d); + + if (dst_filter_format == FORMAT_OIHW) { + TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple, + config.block_count, config.thread_per_block, + 0, d.stream(), config.virtual_thread_count, + in.data(), combined_dims, out.data())); + + } else if (dst_filter_format == FORMAT_OHWI) { + TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple, + config.block_count, config.thread_per_block, + 0, d.stream(), config.virtual_thread_count, + in.data(), combined_dims, out.data())); + + } else { + LOG(ERROR) << "Unsupported filter format: " + << ToString(dst_filter_format); + } + } +}; + +// Converts Cudnn filter format OIHW or OHWI back to TensorFlow filter format +// HWIO. +template +struct ReverseTransformFilter { + typedef GPUDevice Device; + void operator()(const Device& d, FilterTensorFormat src_filter_format, + typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + Dimension<3> combined_dims; + + if (src_filter_format == FORMAT_OIHW) { + combined_dims[0] = in.dimension(0); // output filters + combined_dims[1] = in.dimension(1); // input filters + combined_dims[2] = in.dimension(2); // spatial dimensions + for (int i = 3; i < NDIMS; ++i) { + combined_dims[2] *= in.dimension(i); + } + + GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d); + TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple, + config.block_count, config.thread_per_block, + 0, d.stream(), config.virtual_thread_count, + in.data(), combined_dims, out.data())); + + } else if (src_filter_format == FORMAT_OHWI) { + combined_dims[0] = in.dimension(0); // output filters + combined_dims[1] = in.dimension(1); // spatial dimensions + for (int i = 2; i < NDIMS - 1; i++) { + combined_dims[1] *= in.dimension(i); + } + combined_dims[2] = in.dimension(NDIMS - 1); // input filters + + GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d); + TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple, + config.block_count, config.thread_per_block, + 0, d.stream(), config.virtual_thread_count, + in.data(), combined_dims, out.data())); + + } else { + // TODO(ezhulenev): Set error status in OpKernelContext instead. + LOG(FATAL) << "Unsupported filter format: " + << ToString(src_filter_format); + } + } +}; + +// A GPU helper function that converts input tensor to a larger output tensor, +// given proper padding values. The padded value is zero. 
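The PadInput functor below only dispatches to PadInputCustomKernelNHWC/NCHW above; the heart of those kernels is the per-element index mapping, sketched here as a host-side loop for the NHWC case (illustrative helper and names, not part of the patch):

// CPU reference sketch of the index mapping used by PadInputCustomKernelNHWC:
// every flat output index is decomposed into an NHWC coordinate, shifted back
// by the left/top padding, and either copied from the input or filled with
// padding_value when it falls outside. Not code from the patch.
#include <array>
#include <cstdint>
#include <vector>

template <typename T>
void PadNHWCReference(const std::vector<T>& in,
                      const std::array<int64_t, 4>& in_dims,
                      std::vector<T>& out,
                      const std::array<int64_t, 4>& out_dims,
                      const std::array<int64_t, 2>& padding_left,
                      T padding_value) {
  const int64_t total = out_dims[0] * out_dims[1] * out_dims[2] * out_dims[3];
  for (int64_t out_index = 0; out_index < total; ++out_index) {
    // Flat index -> NHWC coordinate (row-major, innermost dimension last).
    int64_t rem = out_index;
    std::array<int64_t, 4> coord;
    for (int d = 3; d >= 0; --d) {
      coord[d] = rem % out_dims[d];
      rem /= out_dims[d];
    }
    // Shift the spatial coordinates (H, W) back by the top/left padding.
    std::array<int64_t, 4> in_coord = coord;
    bool inside = true;
    for (int d = 1; d <= 2; ++d) {
      in_coord[d] = coord[d] - padding_left[d - 1];
      inside &= in_coord[d] >= 0 && in_coord[d] < in_dims[d];
    }
    if (inside) {
      const int64_t in_index =
          ((in_coord[0] * in_dims[1] + in_coord[1]) * in_dims[2] +
           in_coord[2]) * in_dims[3] + in_coord[3];
      out[out_index] = in[in_index];
    } else {
      out[out_index] = padding_value;
    }
  }
}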
+template +struct PadInput { + typedef GPUDevice Device; + void operator()(const Device& d, + typename TTypes::ConstTensor in, + const std::array& padding_left, + const std::array& padding_right, + typename TTypes::Tensor out, + TensorFormat format, const T& padding_value) { + GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d); + Dimension input_dims; + for (int i = 0; i < NDIMS; ++i) { + input_dims[i] = in.dimension(i); + } + Dimension output_dims; + for (int i = 0; i < NDIMS; ++i) { + output_dims[i] = out.dimension(i); + } + + const Dimension padding_left_dim(padding_left); + + if (format == FORMAT_NHWC) { + TF_CHECK_OK(GpuLaunchKernel( + PadInputCustomKernelNHWC, config.block_count, + config.thread_per_block, 0, d.stream(), config.virtual_thread_count, + in.data(), input_dims, out.data(), output_dims, padding_left_dim, + padding_value)); + } else if (format == FORMAT_NCHW) { + TF_CHECK_OK(GpuLaunchKernel( + PadInputCustomKernelNCHW, config.block_count, + config.thread_per_block, 0, d.stream(), config.virtual_thread_count, + in.data(), input_dims, out.data(), output_dims, padding_left_dim, + padding_value)); + } else { + LOG(FATAL) << "Invalid data format: " << format; + } + } +}; + +// We want std::equal_to and std::greater, but they're not constexpr until +// C++14. +struct EqualTo { + constexpr bool operator()(int a, int b) const { return a == b; } +}; + +struct GreaterThan { + constexpr bool operator()(int a, int b) const { return a > b; } +}; + +// For each data type, the tile size possibility frontier denotes the tile size +// combinations that consume the most computational resources constrained by +// - number of threads per SM limit, +// - limit on size of the short dimension (<=15) due to the definition of +// narrow matrix, +// - shared memory limit and +// - some experimentally determined, type-specific constraint on the product of +// two side lengths to increase grid-level parallelism. +// +// A tile size combination lies on the frontier if and only if one or more +// constraint mentioned above is hit. Tile size combinations lying outside this +// frontier are either not possible, or are slower than the alternatives. +// +// It is instrumental to consider, for each data type, two subsets of the +// corresponding frontier: +// - long side frontier: the union of the biggest tile size combination for +// each legal long side len. +// - non long side frontier: the frontier set minus the long side frontier. +// +// TileSizePossibilityFrontierCheck defines the frontier using only the long +// side frontier tile size combinations (since one can easily extrapolate +// the entire frontier from this subset). It serves as a utility function +// to help us determine where a tile size combination of interest lies with +// resepect to the frontier. 
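EqualTo and GreaterThan exist because std::equal_to and std::greater are not constexpr until C++14; feeding one or the other into the same constexpr table lets TileSizePossibilityFrontierCheck below answer both "on the frontier" and "beyond the frontier". A reduced, self-contained illustration of the pattern (the two-entry table is made up, not the real frontier):

// Illustrative only: a tiny "frontier" evaluated at compile time with the same
// predicate-functor trick used by TileSizePossibilityFrontierCheck below.
struct EqualToOp {
  constexpr bool operator()(int a, int b) const { return a == b; }
};
struct GreaterThanOp {
  constexpr bool operator()(int a, int b) const { return a > b; }
};

template <typename Op>
constexpr bool FrontierCheck(int long_side, int short_side, Op op) {
  return (long_side == 32 && op(short_side, 4)) ||
         (long_side == 64 && op(short_side, 2));
}

static_assert(FrontierCheck(32, 4, EqualToOp()), "(32, 4) lies on the frontier");
static_assert(FrontierCheck(64, 3, GreaterThanOp()), "(64, 3) lies beyond it");
static_assert(!FrontierCheck(32, 3, GreaterThanOp()), "(32, 3) is strictly inside");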
+template +constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide, + int TileShortSide, + int size_of_t, Op op) { + // clang-format off + + return (size_of_t == 16 && ((TileLongSide == 32 && op(TileShortSide, 4)) || + (TileLongSide == 64 && op(TileShortSide, 4)) || + (TileLongSide == 128 && op(TileShortSide, 4)) || + (TileLongSide == 256 && op(TileShortSide, 2)))) || + (size_of_t == 8 && ((TileLongSide == 32 && op(TileShortSide, 15)) || + (TileLongSide == 64 && op(TileShortSide, 15)) || + (TileLongSide == 128 && op(TileShortSide, 8)) || + (TileLongSide == 256 && op(TileShortSide, 4)) || + (TileLongSide == 512 && op(TileShortSide, 2)))) || + (size_of_t == 4 && ((TileLongSide == 32 && op(TileShortSide, 15)) || + (TileLongSide == 64 && op(TileShortSide, 15)) || + (TileLongSide == 128 && op(TileShortSide, 15)) || + (TileLongSide == 256 && op(TileShortSide, 8)) || + (TileLongSide == 512 && op(TileShortSide, 4)) || + (TileLongSide == 1024 && op(TileShortSide, 2)))) || + (size_of_t == 2 && ((TileLongSide == 32 && op(TileShortSide, 15)) || + (TileLongSide == 64 && op(TileShortSide, 15)) || + (TileLongSide == 128 && op(TileShortSide, 15)) || + (TileLongSide == 256 && op(TileShortSide, 8)) || + (TileLongSide == 512 && op(TileShortSide, 4)) || + (TileLongSide == 1024 && op(TileShortSide, 2)))) || + (size_of_t == 1 && ((TileLongSide == 32 && op(TileShortSide, 15)) || + (TileLongSide == 64 && op(TileShortSide, 15)) || + (TileLongSide == 128 && op(TileShortSide, 15)) || + (TileLongSide == 256 && op(TileShortSide, 8)) || + (TileLongSide == 512 && op(TileShortSide, 4)) || + (TileLongSide == 1024 && op(TileShortSide, 2)))); + + // clang-format on +} + +constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide, + int size_of_t) { + return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide, + size_of_t, EqualTo()); +} +constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide, + int size_of_t) { + return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide, + size_of_t, GreaterThan()); +} +constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide, + int TileShortSide, int size_of_t) { + // For a tile size combination (longside, shortside), lying on the frontier + // implies that (longside, shortside) is on or within the frontier but + // (longside*2, shortside) or (longside, shortside+1) is not. With the above + // criterion, we simply need to use !TileSizeOnLongSideFrontier to ensure that + // it is not on the long side frontier. + return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) && + (TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) || + TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1, + size_of_t)) && + !TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t); +} + +// Helper function to launch a batch narrow matirx transpose kernel. 
+template +void LaunchBatchNarrowMatrixTransposeKernel( + const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count, + const T* input, const Dimension<3>& input_dims, T* output) { + constexpr int NumThreads = TileLongSide; + if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) { + TF_CHECK_OK(GpuLaunchKernel( + SwapDimension1And2InTensor3UsingTiles, + total_tiles_count, NumThreads, 0, d.stream(), input, input_dims, + output)); + } else { + TF_CHECK_OK(GpuLaunchKernel( + SwapDimension1And2InTensor3UsingTiles, + total_tiles_count, NumThreads, 0, d.stream(), input, input_dims, + output)); + } +} + +// Recursive template function to search, in a trial-and-error manner, for the +// minimum tile size configuration satisfying the requested tile side lengths. +// An important invariant of this search procedure is that for an unsatisfied +// request, we always try doubling the long side len first, and only after +// the request is satisfied for the long side len do we begin incrementing +// the short side len. +// +// We have three specializations of this search function depending on where the +// current tile size combination lies with respect to the frontier. +// - It lies within the frontier. If request is not satisfied, for the next tile +// size combination, we first try doubling the long side len and if that does +// not work, we then increment the short side len. +// - It lies on the non long side frontier. If the request is not satisfied, we +// can only increment the short side len. +// - It lies on the long side frontier. We launch the kernel without checking if +// the request is satisfied or not. +template +struct BatchNarrowMatrixTransposeDispatcher { + static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j, + int total_tiles_count, const T* input, + const Dimension<3>& input_dims, T* output) { + static_assert( + (TileLongSide & (TileLongSide - 1)) == 0, + "The length of the longer side of the tile is always a power of 2."); + bool request_satisfied = + std::max(tile_size_i, tile_size_j) <= TileLongSide && + std::min(tile_size_i, tile_size_j) <= TileShortSide; + + if (request_satisfied) { + LaunchBatchNarrowMatrixTransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + return; + } + + // If the execution reaches here, then the kernel was not launched; we then + // determine whether it is the long side or the short side that falls short + // of the request and increase that parameter accordingly. 
+ const bool long_side_request_not_satisfied = + std::max(tile_size_i, tile_size_j) > TileLongSide; + + if (long_side_request_not_satisfied) { + BatchNarrowMatrixTransposeDispatcher::DoIt(d, tile_size_i, + tile_size_j, + total_tiles_count, + input, input_dims, + output); + } else { + BatchNarrowMatrixTransposeDispatcher::DoIt(d, tile_size_i, + tile_size_j, + total_tiles_count, + input, input_dims, + output); + } + } +}; + +template +struct BatchNarrowMatrixTransposeDispatcher< + T, TileLongSide, TileShortSide, conjugate, + typename std::enable_if::type> { + static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j, + int total_tiles_count, const T* input, + const Dimension<3>& input_dims, T* output) { + static_assert( + (TileLongSide & (TileLongSide - 1)) == 0, + "The length of the longer side of the tile is always a power of 2."); + bool request_satisfied = + std::max(tile_size_i, tile_size_j) <= TileLongSide && + std::min(tile_size_i, tile_size_j) <= TileShortSide; + + if (request_satisfied) { + LaunchBatchNarrowMatrixTransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + return; + } + + // If the execution reaches here, then the kernel was not launched; since + // we are on the non long side frontier, we increment the short dimension + // and try again. + BatchNarrowMatrixTransposeDispatcher::DoIt(d, tile_size_i, + tile_size_j, + total_tiles_count, + input, input_dims, + output); + } +}; + +template +struct BatchNarrowMatrixTransposeDispatcher< + T, TileLongSide, TileShortSide, conjugate, + typename std::enable_if::type> { + static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j, + int total_tiles_count, const T* input, + const Dimension<3>& input_dims, T* output) { + static_assert( + (TileLongSide & (TileLongSide - 1)) == 0, + "The length of the longer side of the tile is always a power of 2."); + + LaunchBatchNarrowMatrixTransposeKernel( + d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims, + output); + } +}; + +// This function tries to recover, in a brute force way, the frontier defined in +// TileSizePossibilityFrontierCheck as a vector of tile size combinations lying +// on the long side frontier. This vector is sufficient to determine the entire +// frontier. +// +// Note that if one changes the frontier definition in +// TileSizePossibilityFrontierCheck and forgets to set the largest short +// side len of the largest legal long side len to 2, this function will fail +// and crash the program. +template +const std::vector>& GetTileSizesFrontier() { + static_assert( + SizeOfT <= 16, + "Currently, only data types of sizes 16 bytes or less are supported."); + static_assert((SizeOfT & (SizeOfT - 1)) == 0, + "Data types must have sizes that are powers of 2."); + + // Expensive work to populate sizes, lazily run in a thread-safe + // manner the first time GetTileSizesFrontier is called. + static auto* frontier = [] { + auto* frontier = new std::vector>(); + const int kMaxLongSideLen = 1024; + const int kMaxShortSideLen = 15; + for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) { + for (int short_side = 2; short_side <= kMaxShortSideLen; + short_side += 1) { + if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) { + // The current combination lies on the frontier, thus we + // add it to the frontier definition. 
+ frontier->push_back(std::make_pair(long_side, short_side)); + + // The long side length is the largest one allowed iff its + // corresponding short side length is 2. + if (short_side == 2) return frontier; + + // We have exhausted all the possibilities in the frontier + // with the given long side length. + break; + } + } + } + LOG(FATAL) + << "The corresponding short side length of the largest long side " + "length has to be 2."; + }(); + return *frontier; +} + +// Helper structs to help determine which data type to use given the size of +// the matrix data type. A transpose of elements of size N will use a kernel +// which operates on an array of TransposeElemType::type. +template +struct TransposeElemType; +template <> +struct TransposeElemType<1> { + using type = uint8; +}; +template <> +struct TransposeElemType<2> { + using type = uint16; +}; +template <> +struct TransposeElemType<4> { + using type = uint32; +}; +template <> +struct TransposeElemType<8> { + using type = float2; +}; +template <> +struct TransposeElemType<16> { + using type = double2; +}; + +// A helper function to make RunSwapDimension1And2InTensor3 concise. This +// helper function looks at the data type and input matrix sizes and decides +// the thread numbers and tile sizes to use. +template +void SwapDimension1And2InTensor3WithNarrowMatrices( + const GPUDevice& d, const T* input, const Dimension<3>& input_dims, + T* output, const int kMinDimensionToUseTiles) { + // Get available tile sizes here for the data type requested: + const auto& tile_spec = GetTileSizesFrontier(); + + int tile_long_side_len = 0; + int tile_short_side_len = 0; + float lowest_cost = std::numeric_limits::max(); + int data_long_side = std::max(input_dims[1], input_dims[2]); + + for (auto tile_size_pair : tile_spec) { + int proposed_tile_long_side_len = tile_size_pair.first; + + // Number of threads that will not be doing anything useful when reading + // the matrix because the thread block size is bigger than the data block + // size. + int num_wasted_threads = + data_long_side - MathUtil::FloorOfRatio( + data_long_side, proposed_tile_long_side_len) * + proposed_tile_long_side_len; + + int num_full_tiles = MathUtil::FloorOfRatio( + data_long_side, proposed_tile_long_side_len); + + float cost = 0; + + // However, if we can execute two or more full tiles, then we gladly + // accept any number of wasted threads and ignore its cost. + if (num_full_tiles <= 1) cost = num_wasted_threads; + + // Using less than or equal to here because given the same cost, we + // would like to launch as many threads as possible. + if (cost <= lowest_cost) { + tile_long_side_len = proposed_tile_long_side_len; + tile_short_side_len = tile_size_pair.second; + lowest_cost = cost; + } + } + + // Request tile sizes such that the longer side of threadblock aligns with + // the longer side of input data block to maximize read throughput. + // The ideal tile shape is one where the length of the shorter side of the + // tile is equal to the length of the shorter side of the input matrix. + int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles + ? tile_long_side_len + : input_dims[1]; + int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles + ? input_dims[2] + : tile_long_side_len; + + // Truncate the shorter size requested according to the manual limit set in + // tile_spec to make sure that we do not launch configurations violating + // hardware limits. + requested_tile_size_i = + requested_tile_size_i == tile_long_side_len + ? 
tile_long_side_len + : std::min(requested_tile_size_i, tile_short_side_len); + requested_tile_size_j = + requested_tile_size_j == tile_long_side_len + ? tile_long_side_len + : std::min(requested_tile_size_j, tile_short_side_len); + + Dimension<3> input_dims_in_tiles = { + input_dims[0], + MathUtil::CeilOfRatio(input_dims[1], requested_tile_size_i), + MathUtil::CeilOfRatio(input_dims[2], requested_tile_size_j), + }; + + int total_tiles_count = + input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2]; + + using ElemType = typename TransposeElemType::type; + static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment."); + BatchNarrowMatrixTransposeDispatcher::DoIt( + d, requested_tile_size_i, requested_tile_size_j, total_tiles_count, + reinterpret_cast(input), input_dims, + reinterpret_cast(output)); +} + +// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a +// 3D tensor. It looks at the shape of the incoming data, and decides the best +// strategy to launch. +template +void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, + const Dimension<3>& input_dims, T* output) { + // If both dimensions are not trivial, use tiles for the actual swapping. + // If one dimension is trivial, use SmallDim kernel for swapping. + // Otherwise, the trivial swapping relying on the ldg cache is more efficient. + static const int kMinDimensionToUseTiles = 16; + static const int kMinDimensionToUseRectTiles = 96; + + bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles && + input_dims[2] >= kMinDimensionToUseTiles; + bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles || + input_dims[2] >= kMinDimensionToUseRectTiles; + if (large_matrix) { + // We get best performance when kTileSize is the number of threads in a warp + // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256 + // threads. + constexpr int kTileSize = 32; + constexpr int kNumThreads = 256; + + Dimension<3> input_dims_in_tiles = { + input_dims[0], + MathUtil::CeilOfRatio(input_dims[1], kTileSize), + MathUtil::CeilOfRatio(input_dims[2], kTileSize), + }; + + int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] * + input_dims_in_tiles[2]; + TF_CHECK_OK(GpuLaunchKernel( + SwapDimension1And2InTensor3UsingTiles, + total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims, + output)); + + } else if (narrow_matrix) { + SwapDimension1And2InTensor3WithNarrowMatrices( + d, input, input_dims, output, kMinDimensionToUseTiles); + } else { + int total_element_count = input_dims[0] * input_dims[1] * input_dims[2]; + GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d); + TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple, + config.block_count, config.thread_per_block, 0, + d.stream(), config.virtual_thread_count, input, + input_dims, output)); + } +} + +// A GPU helper functor that does general dimension 1 and 2 switch for 3D +// tensor. +template +struct SwapDimension1And2InTensor3 { + typedef GPUDevice Device; + void operator()(const Device& d, const T* in, + const gtl::ArraySlice& combined_dims, T* out) { + Dimension<3> input_dims = {static_cast(combined_dims[0]), + static_cast(combined_dims[1]), + static_cast(combined_dims[2])}; + RunSwapDimension1And2InTensor3(d, in, input_dims, out); + } +}; + +// A GPU helper functor that does general dimension 0 and 2 switch for 3D +// tensor. 
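Scalar reference semantics for the two swap functors, the dimension 1-and-2 version above and the dimension 0-and-2 variant that follows (sketch only, not part of the patch):

// Reference semantics for the dimension-swap functors over row-major 3-D data.
// The GPU kernels above are optimized equivalents of these loops.
#include <cstdint>

// SwapDimension1And2InTensor3: out[i][k][j] = in[i][j][k].
// NHWC -> NCHW uses this with d1 = H*W, d2 = C; NCHW -> NHWC with d1 = C, d2 = H*W.
template <typename T>
void SwapDim1And2Reference(const T* in, int64_t d0, int64_t d1, int64_t d2, T* out) {
  for (int64_t i = 0; i < d0; ++i)
    for (int64_t j = 0; j < d1; ++j)
      for (int64_t k = 0; k < d2; ++k)
        out[(i * d2 + k) * d1 + j] = in[(i * d1 + j) * d2 + k];
}

// SwapDimension0And2InTensor3: out[k][j][i] = in[i][j][k].
template <typename T>
void SwapDim0And2Reference(const T* in, int64_t d0, int64_t d1, int64_t d2, T* out) {
  for (int64_t i = 0; i < d0; ++i)
    for (int64_t j = 0; j < d1; ++j)
      for (int64_t k = 0; k < d2; ++k)
        out[(k * d1 + j) * d0 + i] = in[(i * d1 + j) * d2 + k];
}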
+template +struct SwapDimension0And2InTensor3 { + typedef GPUDevice Device; + void operator()(const Device& d, const T* in, + const gtl::ArraySlice& combined_dims, T* out) { + Dimension<3> input_dims = {static_cast(combined_dims[0]), + static_cast(combined_dims[1]), + static_cast(combined_dims[2])}; + size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; + GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d); + + auto out_ptr = reinterpret_cast(out); + bool aligned = out_ptr % 16 == 0; + + bool use_vector = false; + bool use_custom_config = false; + if ((input_dims[0] <= 128 && input_dims[2] <= 128) || + input_dims[0] * input_dims[1] <= 128 || + input_dims[1] * input_dims[2] <= 8) { + use_vector = true; + use_custom_config = true; + } else if (input_dims[1] * input_dims[2] <= 16384) { + use_vector = true; + } + + if (sizeof(T) == 2 && aligned && use_vector) { + int block_count; + if (use_custom_config) { + block_count = (total_size + config.thread_per_block - 1) / + config.thread_per_block; + } else { + block_count = config.block_count; + } + + TF_CHECK_OK( + GpuLaunchKernel(ShuffleInTensor3SimpleVector, + block_count, config.thread_per_block / kUnroll, 0, + d.stream(), total_size, in, input_dims, out)); + } else { + TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple, + config.block_count, config.thread_per_block, + 0, d.stream(), config.virtual_thread_count, + in, input_dims, out)); + } + } +}; + +// A GPU helper functor that converts NHWC TensorFlow data format to +// NCHW format that is accepted by Cudnn. +template +struct NHWCToNCHW { + typedef GPUDevice Device; + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + Dimension<3> combined_dims; + combined_dims[0] = in.dimension(0); // N (batch) + combined_dims[1] = in.dimension(1); // spatial dimensions (HW) + for (int i = 2; i < NDIMS - 1; ++i) { + combined_dims[1] *= in.dimension(i); + } + combined_dims[2] = in.dimension(NDIMS - 1); // C (channels) + RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); + } +}; + +// A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow +// Format. +template +struct NCHWToNHWC { + typedef GPUDevice Device; + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out) { + Dimension<3> combined_dims; + combined_dims[0] = in.dimension(0); // N (batch) + combined_dims[1] = in.dimension(1); // C (channel) + combined_dims[2] = in.dimension(2); // spatial dimensions (HW) + for (int i = 3; i < NDIMS; ++i) { + combined_dims[2] *= in.dimension(i); + } + RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_3d.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_3d.h new file mode 100644 index 00000000..b4cdbd5b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_3d.h @@ -0,0 +1,128 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functors for 3d convolution. + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_3D_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/ops_util.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h" +#include "tensorflow/core/kernels/eigen_cuboid_convolution.h" + +namespace tensorflow { +namespace functor { + +// Applies a 3D convolution to a batch of multi-channel volumes. +template +struct CuboidConvolution; + +// Backward input pass for the cuboid convolution. +template +struct CuboidConvolutionBackwardInput; + +// Backward filter pass for the cuboid convolution. +template +struct CuboidConvolutionBackwardFilter; + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +struct CuboidConvolution { + void operator()(const CPUDevice& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride_planes, + int stride_rows, int stride_cols, + const Eigen::PaddingType& padding) { + output.device(d) = Eigen::CuboidConvolution( + input, filter, stride_planes, stride_rows, stride_cols, padding); + } +}; + +template +struct CuboidConvolutionBackwardInput { + void operator()(const CPUDevice& d, + typename TTypes::Tensor input_backward, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor output_backward, + int stride_planes, int stride_rows, int stride_cols) { + // Need to swap the order of plane/row/col strides when calling Eigen. + input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, + input_backward.dimension(3), // input_planes + input_backward.dimension(2), // input_rows + input_backward.dimension(1), // input_cols + stride_cols, stride_rows, stride_planes); + } +}; + +template +struct CuboidConvolutionBackwardFilter { + void operator()(const CPUDevice& d, + typename TTypes::Tensor filter_backward, + typename TTypes::ConstTensor input, + typename TTypes::ConstTensor output_backward, + int stride_planes, int stride_rows, int stride_cols) { + // Need to swap the order of plane/row/col strides when calling Eigen. 
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel( + input, output_backward, + filter_backward.dimension(2), // kernel_planes + filter_backward.dimension(1), // kernel_rows + filter_backward.dimension(0), // kernel_cols + stride_cols, stride_rows, stride_planes); + } +}; + +} // namespace functor + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +struct LaunchConv3DOp; + +template +struct LaunchConv3DOp { + static void launch(OpKernelContext* context, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, + const std::array& dilations, + const std::array& strides, const Padding padding, + TensorFormat data_format, Tensor* output) { + OP_REQUIRES(context, data_format == FORMAT_NHWC, + absl::InvalidArgumentError("CPU implementation of Conv3D " + "currently only supports the NHWC " + "tensor format.")); + OP_REQUIRES( + context, dilations[0] == 1 && dilations[1] == 1 && dilations[2] == 1, + absl::InvalidArgumentError("CPU implementation of Conv3D " + "currently only supports dilated rates " + "of 1.")); + OP_REQUIRES(context, filter.dim_size(3) == input.dim_size(input.dims() - 1), + absl::InvalidArgumentError(absl::StrCat( + "Number of channels in filter (", filter.dim_size(3), + ") must match last dimension of input (", + input.dim_size(input.dims() - 1), ")"))); + functor::CuboidConvolution()( + context->eigen_device(), output->tensor(), + input.tensor(), filter.tensor(), strides[2], strides[1], + strides[0], BrainPadding2EigenPadding(padding)); + } +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_3D_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_input_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_input_ops.h new file mode 100644 index 00000000..3dbecd51 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_input_ops.h @@ -0,0 +1,718 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. 
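The backprop-input file below reconstructs the forward padding with GetWindowedOutputSizeVerbose; the arithmetic follows the usual TensorFlow conventions, sketched here for reference (standard formulas, not code from the patch): with k_eff = (filter - 1) * dilation + 1, VALID gives out = ceil((in - k_eff + 1) / stride) and SAME gives out = ceil(in / stride) plus whatever total padding makes that window placement possible.

// Standard TF-style output-size arithmetic (sketch, not from the patch).
#include <algorithm>
#include <cstdint>

struct WindowedOutputSize {
  int64_t output;     // number of output positions along this dimension
  int64_t pad_total;  // total padding along this dimension (SAME only)
};

inline WindowedOutputSize ComputeWindowedOutputSize(int64_t in, int64_t filter,
                                                    int64_t dilation,
                                                    int64_t stride, bool same) {
  const int64_t k_eff = (filter - 1) * dilation + 1;
  if (!same) {
    // VALID: the effective window must fit entirely inside the input.
    return {(in - k_eff + stride) / stride, 0};
  }
  // SAME: one output per stride-th input position; pad the shortfall.
  const int64_t out = (in + stride - 1) / stride;
  const int64_t pad_total =
      std::max<int64_t>(0, (out - 1) * stride + k_eff - in);
  return {out, pad_total};
}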
+ +#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_ + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "absl/base/dynamic_annotations.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/conv_grad_shape_utils.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/work_sharder.h" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/proto/proto_utils.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// Returns in 'im_data' (assumes to be zero-initialized) image patch in storage +// order (height, width, depth), constructed from patches in 'col_data', which +// is required to be in storage order (out_height * out_width, filter_height, +// filter_width, in_depth). Implementation by Yangqing Jia (jiayq). +template +void Col2im(const T* col_data, const int depth, const int height, + const int width, const int filter_h, const int filter_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, const int stride_w, T* __restrict im_data) { + int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + T* im_patch_data = im_data + (h_pad * width + w_pad) * depth; + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ih >= 0 && ih < height && iw >= 0 && iw < width) { + for (int i = 0; i < depth; ++i) { + im_patch_data[i] += col_data[i]; + } + } + im_patch_data += depth; + col_data += depth; + } + // Jump over remaining number of depth. + im_patch_data += depth * (width - filter_w); + } + w_pad += stride_w; + } + h_pad += stride_h; + } +} + +// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU +// and GPU (for int32 only). 
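Col2im above is the scatter-add inverse of im2col: each column-buffer element is added back into the zero-initialized image location it was gathered from, so overlapping windows accumulate. For reference, a matching im2col gather over the same (out_height * out_width, filter_h, filter_w, depth) layout (sketch only, not part of the patch):

// im2col gather matching Col2im's layout above: patches are emitted in
// (out_row * out_col, filter_h, filter_w, depth) order, reading zeros where
// the window hangs over the padded border. Sketch only, not from the patch.
template <typename T>
void Im2colReference(const T* im_data, int depth, int height, int width,
                     int filter_h, int filter_w, int pad_t, int pad_l,
                     int pad_b, int pad_r, int stride_h, int stride_w,
                     T* col_data) {
  const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  for (int h = 0; h < height_col; ++h) {
    for (int w = 0; w < width_col; ++w) {
      const int h_pad = h * stride_h - pad_t;
      const int w_pad = w * stride_w - pad_l;
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          for (int i = 0; i < depth; ++i) {
            const bool inside = ih >= 0 && ih < height && iw >= 0 && iw < width;
            *col_data++ = inside ? im_data[(ih * width + iw) * depth + i] : T(0);
          }
        }
      }
    }
  }
}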
+template +struct LaunchConv2DBackpropInputOpImpl { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, + const std::vector& explicit_paddings, + Tensor* in_backprop, TensorFormat data_format) { + std::vector strides(4, 1); + std::vector dilations(4, 1); + + auto input_h = GetTensorDimIndex(data_format, 'H'); + auto input_w = GetTensorDimIndex(data_format, 'W'); + strides[input_h] = row_stride; + strides[input_w] = col_stride; + dilations[input_h] = row_dilation; + dilations[input_w] = col_dilation; + + const TensorShape& input_shape = in_backprop->shape(); + const TensorShape& filter_shape = filter.shape(); + + ConvBackpropDimensions dims; + OP_REQUIRES_OK( + ctx, ConvBackpropComputeDimensionsV2( + "Conv2DBackpropInput", /*num_spatial_dims=*/2, input_shape, + filter_shape, out_backprop.shape(), dilations, strides, + padding, explicit_paddings, data_format, &dims)); + + int64_t padding_top = -1, padding_bottom = -1; + int64_t padding_left = -1, padding_right = -1; + if (padding == EXPLICIT) { + GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', + &padding_top, &padding_bottom); + GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', + &padding_left, &padding_right); + } + + int64_t expected_out_rows, expected_out_cols; + // The function is guaranteed to succeed because we checked the output and + // padding was valid earlier. + TF_CHECK_OK(GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + row_dilation, row_stride, padding, &expected_out_rows, &padding_top, + &padding_bottom)); + DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows); + + TF_CHECK_OK(GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + col_dilation, col_stride, padding, &expected_out_cols, &padding_left, + &padding_right)); + DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols); + + if (std::is_same::value) { + int64_t size = 1; +#define REQUIRES_32BIT(x) \ + size *= x; \ + OP_REQUIRES(ctx, \ + FastBoundsCheck(x, std::numeric_limits::max()) && \ + FastBoundsCheck(size, std::numeric_limits::max()), \ + errors::InvalidArgument("Tensor too large")) + + REQUIRES_32BIT(in_backprop->dim_size(0)); + REQUIRES_32BIT(in_backprop->dim_size(1) + padding_top + padding_bottom); + REQUIRES_32BIT(in_backprop->dim_size(2) + padding_left + padding_right); + REQUIRES_32BIT(in_backprop->dim_size(3)); +#undef REQUIRES_32BIT + } + + auto in_backprop_t = in_backprop->tensor(); + auto out_backprop_t = out_backprop.tensor(); + auto filter_t = filter.tensor(); + + // WARNING: Need to swap row/col, padding_top/padding_left, and + // padding_bottom/padding_right when calling Eigen. Eigen expects tensors + // in NWHC format, but Tensorflow uses NHWC. + + if (padding != EXPLICIT) { + // If padding was not explicitly defined, Eigen spatial convolution + // backward input will infer correct forward paddings from input tensors. 
+ functor::SpatialConvolutionBackwardInputFunc()( + ctx->eigen_device(), in_backprop_t, filter_t, out_backprop_t, + col_stride, row_stride, col_dilation, row_dilation); + } else { + functor::SpatialConvolutionBackwardInputWithExplicitPaddingFunc()( + ctx->eigen_device(), in_backprop_t, filter_t, out_backprop_t, + in_backprop_t.dimension(2) + (padding_left + padding_right), + in_backprop_t.dimension(1) + (padding_top + padding_bottom), + col_stride, row_stride, col_dilation, row_dilation, padding_top, + padding_left); + } + } +}; + +// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU. +template +struct LaunchConv2DBackpropInputOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, + const std::vector& explicit_paddings, + Tensor* in_backprop, TensorFormat data_format) { + LaunchConv2DBackpropInputOpImpl launcher; + launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter, + row_dilation, col_dilation, row_stride, col_stride, padding, + explicit_paddings, in_backprop, data_format); + } +}; + +template +struct Conv2DCustomBackpropInputMatMulFunctor { + using MatrixMap = Eigen::Map< + Eigen::Matrix>; + using ConstMatrixMap = Eigen::Map< + const Eigen::Matrix>; + + void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data, + const int filter_total_size, const int output_image_size, + const int dims_out_depth, T* im2col_buf) { + // Compute gradient into 'im2col_buf'. + MatrixMap C(im2col_buf, output_image_size, filter_total_size); + + ConstMatrixMap A(out_data, output_image_size, dims_out_depth); + ConstMatrixMap B(filter_data, filter_total_size, dims_out_depth); + + C.noalias() = A * B.transpose(); + } +}; + +#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL) +template <> +struct Conv2DCustomBackpropInputMatMulFunctor { + using T = float; + + void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data, + const int filter_total_size, const int output_image_size, + const int dims_out_depth, T* im2col_buf) { + // Inputs are in RowMajor order. + // im2col = out_data * filter_data^T + // [ois x fts] = [ois x dod] * [fts x dod]^T + // + // Dimension names: + // out_image_size -> ois + // filter_total_size -> fts + // dims_out_depth -> dod + + const int m = output_image_size; + const int n = filter_total_size; + const int k = dims_out_depth; // contraction dim + + const char transposeA = 'N'; // sgemm(A) == filter_data + const char transposeB = 'T'; // sgemm(B) == out_data + + const int ldA = dims_out_depth; + const int ldB = dims_out_depth; + const int ldC = filter_total_size; + + const float alpha = 1.0; + const float beta = 0.0; + + // dnnl_sgemm code can't be instrumented with msan. + ANNOTATE_MEMORY_IS_INITIALIZED( + im2col_buf, filter_total_size * output_image_size * sizeof(T)); + + dnnl_status_t st = + dnnl_sgemm(transposeA, transposeB, m, n, k, alpha, out_data, ldA, + filter_data, ldB, beta, im2col_buf, ldC); + + OP_REQUIRES( + ctx, st == 0, + errors::Internal("Failed to call dnnl_sgemm. 
Error code: ", st)); + } +}; +#endif + +template +class Conv2DBackpropInputOp : public OpKernel { + public: + explicit Conv2DBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + int stride_n = GetTensorDim(strides_, data_format_, 'N'); + int stride_c = GetTensorDim(strides_, data_format_, 'C'); + int stride_h = GetTensorDim(strides_, data_format_, 'H'); + int stride_w = GetTensorDim(strides_, data_format_, 'W'); + OP_REQUIRES( + context, (stride_n == 1 && stride_c == 1), + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, stride_h > 0 && stride_w > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES( + context, (dilation_n == 1 && dilation_c == 1), + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES_OK(context, + context->GetAttr("explicit_paddings", &explicit_paddings_)); + OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, + /*num_dims=*/4, data_format_)); + + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); + cudnn_use_autotune_ = CudnnUseAutotune(); + + if (std::is_same::value || + std::is_same::value) { + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Conv2DBackpropInputOp [CPU or GPU(int32)] " + "only supports NHWC data format.")); + + // TODO(yangzihao): Add a CPU implementation for dilated convolution. + OP_REQUIRES( + context, (dilation_h == 1 && dilation_w == 1), + errors::InvalidArgument( + "Conv2DBackpropInputOp [CPU or GPU(int32)] not yet support " + "dilation rates larger than 1.")); + } + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_sizes = context->input(0); + const Tensor& filter = context->input(1); + const Tensor& out_backprop = context->input(2); + + OP_REQUIRES( + context, out_backprop.dims() == 4, + errors::InvalidArgument("input_sizes must be 4-dimensional, got: ", + out_backprop.dims())); + + TensorShape input_shape; + OP_REQUIRES_OK(context, + Conv2DBackpropComputeInputShape(input_sizes, filter.shape(), + out_backprop.shape(), + data_format_, &input_shape)); + + Tensor* in_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &in_backprop)); + + // If there is nothing to compute, return. 
+ if (input_shape.num_elements() == 0) { + return; + } + + // If shapes are valid but `out_backprop` is empty, in_backprop should be + // set to all zeros. Otherwise, cudnn/dnnl fail with an empty input. + if (out_backprop.NumElements() == 0) { + functor::SetZeroFunctor set_zero; + set_zero(context->eigen_device(), + in_backprop->template flat()); + return; + } + + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); + const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); + + VLOG(2) << "Conv2DBackpropInput:" + << " input: " << input_shape.DebugString() + << " filter:" << filter.shape().DebugString() + << " out_backprop: " << out_backprop.shape().DebugString() + << " strides: [" << stride_rows << ", " << stride_cols << "]" + << " dilations: [" << dilation_rows << ", " << dilation_cols << "]"; + + LaunchConv2DBackpropInputOp launch; + launch(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter, + dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, + explicit_paddings_, in_backprop, data_format_); + } + + private: + std::vector dilations_; + std::vector strides_; + TensorFormat data_format_; + Padding padding_; + std::vector explicit_paddings_; + + bool use_cudnn_ = false; + bool cudnn_use_autotune_ = false; + + Conv2DBackpropInputOp(const Conv2DBackpropInputOp&) = delete; + void operator=(const Conv2DBackpropInputOp&) = delete; +}; + +// Based on implementation written by Yangqing Jia (jiayq). +template +class Conv2DCustomBackpropInputOp : public OpKernel { + public: + explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES(context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Conv2DCustomBackpropInputOp only supports NHWC.")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + OP_REQUIRES( + context, (strides_[0] == 1 && strides_[3] == 1), + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + OP_REQUIRES(context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + OP_REQUIRES( + context, (dilations_[0] == 1 && dilations_[3] == 1), + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + // TODO(yangzihao): Add a CPU implementation for dilated convolution. 
+ OP_REQUIRES( + context, (dilations_[1] == 1 && dilations_[2] == 1), + errors::InvalidArgument("Current CPU implementations do not yet " + "support dilation rates larger than 1.")); + OP_REQUIRES_OK(context, + context->GetAttr("explicit_paddings", &explicit_paddings_)); + OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, + /*num_dims=*/4, data_format_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_sizes = context->input(0); + const Tensor& filter = context->input(1); + const Tensor& out_backprop = context->input(2); + OP_REQUIRES( + context, out_backprop.dims() == 4, + errors::InvalidArgument("input_sizes must be 4-dimensional, got: ", + out_backprop.dims())); + + TensorShape input_shape; + OP_REQUIRES_OK(context, + Conv2DBackpropComputeInputShape(input_sizes, filter.shape(), + out_backprop.shape(), + data_format_, &input_shape)); + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensionsV2( + "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2, + input_shape, filter.shape(), out_backprop.shape(), + /*dilations=*/{1, 1, 1, 1}, strides_, padding_, + explicit_paddings_, data_format_, &dims)); + + OP_REQUIRES(context, dims.in_depth == filter.shape().dim_size(2), + errors::InvalidArgument( + "Gradients for grouped convolutions are not " + "supported on CPU. Please file a feature request if you " + "run into this issue. Computed input depth ", + dims.in_depth, " doesn't match filter input depth ", + filter.shape().dim_size(2))); + OP_REQUIRES( + context, dims.out_depth == filter.shape().dim_size(3), + errors::InvalidArgument("Computed output depth ", dims.out_depth, + " doesn't match filter output depth ", + filter.shape().dim_size(3))); + + Tensor* in_backprop = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &in_backprop)); + + // If there is nothing to compute, return. + if (input_shape.num_elements() == 0) { + return; + } + + // If shapes are valid but `out_backprop` is empty, in_backprop should be + // set to all zeros. Otherwise, cudnn/dnnl fail with an empty input. + if (out_backprop.NumElements() == 0) { + functor::SetZeroFunctor set_zero; + set_zero(context->eigen_device(), + in_backprop->template flat()); + return; + } + + int64_t pad_top, pad_bottom; + int64_t pad_left, pad_right; + + if (padding_ == Padding::EXPLICIT) { + pad_top = explicit_paddings_[2]; + pad_bottom = explicit_paddings_[3]; + pad_left = explicit_paddings_[4]; + pad_right = explicit_paddings_[5]; + } + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom)); + OP_REQUIRES_OK( + context, + GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, + /*dilation_rate=*/1, dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); + + // The total dimension size of each kernel. + const int filter_total_size = dims.spatial_dims[0].filter_size * + dims.spatial_dims[1].filter_size * + dims.in_depth; + // The output image size is the spatial size of the output. + const int output_image_size = + dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size; + + // TODO(andydavis) Get L2/L3 cache sizes from device. 
+ const size_t l2_cache_size = 256LL << 10; + const size_t l3_cache_size = 30LL << 20; + + // Use L3 cache size as target working set size. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + // Calculate size of matrices involved in MatMul: C = A x B. + const size_t size_A = output_image_size * dims.out_depth; + + const size_t size_B = filter_total_size * dims.out_depth; + + const size_t size_C = output_image_size * filter_total_size; + + const size_t work_unit_size = size_A + size_B + size_C; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Calculate per-thread work unit size. + const size_t thread_work_unit_size = + work_unit_size / worker_threads.num_threads; + + // Set minimum per-thread work unit size to size of L2 cache. + const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T); + + // Use parallel tensor contractions if there is no batching, or if the + // minimum per-thread work unit size threshold has been exceeded. + // Otherwise, revert to multiple single-threaded matmul ops running in + // parallel to keep all threads busy. + // TODO(andydavis) Explore alternatives to branching the code in this way + // (i.e. run multiple, parallel tensor contractions in another thread pool). + const bool use_parallel_contraction = + dims.batch_size == 1 || + thread_work_unit_size >= min_thread_work_unit_size; + + OP_REQUIRES( + context, work_unit_size > 0, + errors::InvalidArgument("input, filter_sizes and out_backprop tensors " + "must all have at least 1 element")); + + const size_t shard_size = + use_parallel_contraction + ? 1 + : (target_working_set_size + work_unit_size - 1) / work_unit_size; + + Tensor col_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp( + DataTypeToEnum::value, + TensorShape({static_cast(shard_size), + static_cast(output_image_size), + static_cast(filter_total_size)}), + &col_buffer)); + + // The input offset corresponding to a single input image. + const int input_offset = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * dims.in_depth; + // The output offset corresponding to a single output image. + const int output_offset = dims.spatial_dims[0].output_size * + dims.spatial_dims[1].output_size * dims.out_depth; + + const T* filter_data = filter.template flat().data(); + T* col_buffer_data = col_buffer.template flat().data(); + const T* out_backprop_data = out_backprop.template flat().data(); + + auto in_backprop_flat = in_backprop->template flat(); + T* input_backprop_data = in_backprop_flat.data(); + in_backprop_flat.device(context->eigen_device()) = + in_backprop_flat.constant(T(0)); + + if (use_parallel_contraction) { + typedef Eigen::TensorMap, + Eigen::Unaligned> + TensorMap; + typedef Eigen::TensorMap, + Eigen::Unaligned> + ConstTensorMap; + + // Initialize contraction dims (we need to transpose 'B' below). + Eigen::array, 1> contract_dims; + contract_dims[0].first = 1; + contract_dims[0].second = 1; + + for (int image_id = 0; image_id < dims.batch_size; ++image_id) { + // Compute gradient into col_buffer. 
+ TensorMap C(col_buffer_data, output_image_size, filter_total_size); + + ConstTensorMap A(out_backprop_data + output_offset * image_id, + output_image_size, dims.out_depth); + ConstTensorMap B(filter_data, filter_total_size, dims.out_depth); + + C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); + + Col2im( + col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size, + dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size, + dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom, + pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride, + input_backprop_data); + + input_backprop_data += input_offset; + } + } else { + for (int image_id = 0; image_id < dims.batch_size; + image_id += shard_size) { + const int shard_limit = + std::min(static_cast(shard_size), + static_cast(dims.batch_size) - image_id); + + auto shard = [&context, &dims, &pad_top, &pad_left, &pad_bottom, + &pad_right, &output_image_size, &filter_total_size, + &input_backprop_data, &col_buffer_data, + &out_backprop_data, &filter_data, &input_offset, + &output_offset, &size_C](int64_t start, int64_t limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + T* im2col_buf = col_buffer_data + shard_id * size_C; + T* input_data = input_backprop_data + shard_id * input_offset; + const T* out_data = out_backprop_data + shard_id * output_offset; + + Conv2DCustomBackpropInputMatMulFunctor()( + context, out_data, filter_data, filter_total_size, + output_image_size, dims.out_depth, im2col_buf); + + Col2im(im2col_buf, dims.in_depth, + dims.spatial_dims[0].input_size, + dims.spatial_dims[1].input_size, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[1].filter_size, pad_top, pad_left, + pad_bottom, pad_right, dims.spatial_dims[0].stride, + dims.spatial_dims[1].stride, input_data); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + work_unit_size, shard); + + input_backprop_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } + } + } + + private: + std::vector dilations_; + std::vector strides_; + Padding padding_; + std::vector explicit_paddings_; + TensorFormat data_format_; + + Conv2DCustomBackpropInputOp(const Conv2DCustomBackpropInputOp&) = delete; + void operator=(const Conv2DCustomBackpropInputOp&) = delete; +}; + +// TODO(ezhulenev): Add a cost model to switch between custom/Eigen ops. +#define DEFAULT_CONV_2D_BACKPROP_CPU_OP Conv2DCustomBackpropInputOp + +#define REGISTER_CONV_2D_BACKPROP_CPU_KERNELS(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint("T"), \ + DEFAULT_CONV_2D_BACKPROP_CPU_OP); \ + REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint("T"), \ + Conv2DCustomBackpropInputOp); \ + REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint("T"), \ + Conv2DBackpropInputOp); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_ops.h new file mode 100644 index 00000000..40e03b2b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_ops.h @@ -0,0 +1,215 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is the common header for the input and filter backprop kernels.
+//
+// The operation to compute Conv2D gradients.
+//
+// To compute the gradients for Conv2D, we need three input tensors:
+//    input, filter, and backprop for output.
+// And we need to compute two backprops: one for input and one for filter. We
+// compute them in two different kernels.
+//
+// Both backprops can be computed as straightforward conv2d.
+//
+// Consider a case where the input is 3x3 and the filter is 2x1:
+//
+// INPUT = [ A B C ]
+//         [ D E F ]
+//         [ G H I ]
+//
+// where each "A", "B", etc is batch x in_depth
+//
+// FILTER = [ X Y ]
+//
+// where both "X" and "Y" are in_depth x out_depth
+//
+// With VALID padding, the output is 3x2:
+//
+// OUTPUT = [ a b ]
+//          [ c d ]
+//          [ e f ]
+//
+// where each "a", "b", etc is batch x out_depth
+//
+// So we have:
+//
+//   a = A * X + B * Y
+//   b = B * X + C * Y
+//   c = D * X + E * Y
+//   d = E * X + F * Y
+//   e = G * X + H * Y
+//   f = H * X + I * Y
+//
+// So when we have backprops for the outputs (we denote them by
+// a', b', ... ):
+//
+// The backprops for the input are:
+//
+//   A' = a' * X^t
+//   B' = a' * Y^t + b' * X^t
+//   C' = b' * Y^t
+//   ...
+//
+// This is essentially computing a 2d conv of
+//
+// INPUT = [ 0 a' b' 0 ]
+//         [ 0 c' d' 0 ]
+//         [ 0 e' f' 0 ]
+// and
+//
+// FILTER = [ Y^t X^t ]
+//
+// The backprops for the filter are:
+//
+//   X' = A^t * a' + B^t * b' + D^t * c' + E^t * d' + G^t * e' + H^t * f'
+//   Y' = B^t * a' + C^t * b' + E^t * c' + F^t * d' + H^t * e' + I^t * f'
+//
+// This is essentially computing a 2d conv of
+//
+// INPUT = [ A^t B^t C^t ]
+//         [ D^t E^t F^t ]
+//         [ G^t H^t I^t ]
+//
+// and
+//
+// FILTER = [ a' b' ]
+//          [ c' d' ]
+//          [ e' f' ]
+//
+//
+//////////////////////////////////////////////////////////
+//
+// With stride more than one, it's a bit more complicated (we will need to
+// create holes to the backprop).
+//
+// Consider the case where
+//
+// INPUT = [ A B C D E ]
+//         [ F G H I J ]
+//         [ K L M N O ]
+// and
+//
+// FILTER = [ X Y Z ]
+//
+// with stride 2.
+//
+// The output will be
+//
+// OUTPUT = [ a b ]
+//          [ c d ]
+//
+// where:
+//
+//   a = A * X + B * Y + C * Z
+//   b = C * X + D * Y + E * Z
+//   c = K * X + L * Y + M * Z
+//   d = M * X + N * Y + O * Z
+//
+//
+// To compute the backprop for INPUT, we need to convolve
+//
+// INPUT = [ 0 0 a' 0 b' 0 0 ]
+//         [ 0 0 0  0 0  0 0 ]
+//         [ 0 0 c' 0 d' 0 0 ]
+//
+// (notice the holes in INPUT)
+//
+// and
+//
+// FILTER = [ Z^t Y^t X^t ]
+//
+// with stride 1.
+// +// To compute the backprop for FILTER, we need to convolve + +// +// INPUT = [ A^t B^t C^t D^t E^t ] +// [ F^t G^t H^t I^t J^t ] +// [ K^t L^t M^t N^t O^t ] +// and +// +// FILTER = [ a' 0 b' ] +// [ 0 0 0 ] +// [ c' 0 d' ] +// +// (notice the holes in FILTER) +// +// +// with stride 1 +// +////////////////////////////////////////////////////////// +// +// +// The case for SAME padding is in fact very similar to VALID -- we just +// need to pad the input tensor a bit when computing the filter_backprop. + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ + +#include + +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// Forward declaration. +class OpKernelContext; + +template +struct LaunchConv2DBackpropInputOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& filter, + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, + const std::vector& explicit_paddings, + Tensor* in_backprop, TensorFormat data_format); +}; + +template +struct LaunchConv2DBackpropFilterOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, + const std::vector& explicit_paddings, + Tensor* filter_backprop, TensorFormat data_format); +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +struct LaunchConv2DBackpropInputOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, + const Padding& padding, + const std::vector& explicit_paddings, Tensor* output, + TensorFormat data_format); +}; + +template +struct LaunchConv2DBackpropFilterOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& out_backprop, const Tensor& input, + int row_dilation, int col_dilation, int row_stride, + int col_stride, const Padding& padding, + const std::vector& explicit_paddings, + Tensor* filter_backprop, TensorFormat data_format); +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_shape_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_shape_utils.h new file mode 100644 index 00000000..d83c1bb2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_grad_shape_utils.h @@ -0,0 +1,93 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +// Information about a single spatial dimension for a convolution +// backpropagation. +struct ConvBackpropSpatialDimension { + int64_t input_size; + int64_t filter_size; + int64_t output_size; + int64_t stride; + int64_t dilation; + + // Output size after scaling by the stride. + int64_t expanded_output_size; + + // Number of padding elements to be added before/after this dimension of + // the input when computing Conv?DBackpropInput. + int64_t pad_before, pad_after; +}; + +// Computed dimensions for a backwards convolution. +struct ConvBackpropDimensions { + // Information about each spatial dimension. + absl::InlinedVector spatial_dims; + + // Batch size. + int64_t batch_size; + + // Input and output feature depth. + int64_t in_depth, out_depth; + + // Convenience access methods for spatial dimensions properties. + int64_t input_size(int dim) const { return spatial_dims[dim].input_size; } + int64_t filter_size(int dim) const { return spatial_dims[dim].filter_size; } + int64_t output_size(int dim) const { return spatial_dims[dim].output_size; } + int64_t stride(int dim) const { return spatial_dims[dim].stride; } + int64_t dilation(int dim) const { return spatial_dims[dim].dilation; } + + // Compute padding for the given spatial dimension. + int SpatialPadding(const Padding& padding, int dim) const; +}; + +// Common code between implementations of Conv?DBackpropInput and +// Conv?DBackpropFilter. Verifies that the dimensions all match, and computes +// sizes/padding for the spatial dimensions. Does not support explicit padding. +absl::Status ConvBackpropComputeDimensions( + absl::string_view label, int num_spatial_dims, + const TensorShape& input_shape, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, const std::vector& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims); + +// The V2 version computes the same outputs with arbitrary dilation rate and +// supports explicit padding. +// TODO(b/67112639): Merge V2 versions and the original versions eventually. +absl::Status ConvBackpropComputeDimensionsV2( + absl::string_view label, int num_spatial_dims, + const TensorShape& input_shape, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, absl::Span dilations, + const std::vector& strides, Padding padding, + absl::Span explicit_paddings, TensorFormat data_format, + ConvBackpropDimensions* dims); + +// Computes the shape of the in_backprop. +absl::Status Conv2DBackpropComputeInputShape( + const Tensor& input_sizes, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, const TensorFormat& data_format, + TensorShape* input_shape); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops.h new file mode 100644 index 00000000..65c63fec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops.h @@ -0,0 +1,140 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/util/tensor_format.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { + +// Forward declaration. +class OpKernelContext; + +template +struct LaunchConv2DOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, + const Padding& padding, + const std::vector& explicit_paddings, Tensor* output, + TensorFormat data_format); +}; + +template +struct LaunchConvOp { + void operator()(OpKernelContext* context, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, + const std::vector& dilations, + const std::vector& strides, Padding padding, + const std::vector& explicit_paddings, + TensorFormat data_format, Tensor* output); +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +struct LaunchConv2DOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, + const Padding& padding, + const std::vector& explicit_paddings, Tensor* output, + TensorFormat data_format); +}; + +template +struct LaunchConvOp { + void operator()(OpKernelContext* context, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, + const std::vector& dilations, + const std::vector& strides, const Padding padding, + const std::vector& explicit_paddings, + TensorFormat data_format, Tensor* output); +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// Used to keep track of persistent memory buffers used within the op. +// It uses malloc and free to avoid the time cost of initializing the memory. +template +struct Im2ColBufferResource : public ResourceBase { + Im2ColBufferResource() { + data = static_cast(port::Malloc(size * sizeof(T))); + } + ~Im2ColBufferResource() { port::Free(data); } + // This mutex ensures that only a single operation at a time is able to use + // the buffer memory held by this resource. + mutex mu; + T* data; + string DebugString() const { return "Im2ColBufferResource"; } +}; + +// Convolution parameters specified by Op attributes. +struct Conv2DParameters { + std::vector dilations; + std::vector strides; + Padding padding; + TensorFormat data_format; + std::vector explicit_paddings; +}; + +// Convolution dimensions inferred from parameters, input and filter tensors. 
+struct Conv2DDimensions { + int batch; + int input_rows; + int input_cols; + int in_depth; + + int filter_rows; + int filter_cols; + int patch_depth; + int out_depth; + + int stride_rows; + int stride_cols; + + int dilation_rows; + int dilation_cols; + + int64_t out_rows; + int64_t out_cols; + int64_t pad_rows_before; + int64_t pad_rows_after; + int64_t pad_cols_before; + int64_t pad_cols_after; +}; + +// Initializes and validates Conv2D parameters configured by OpKernel +// attributes. +absl::Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params); + +// Computes and validates convolutions dimensions from Conv2D parameters. If +// parameters are valid, dimensions will be updated with derived convolution +// dimensions, otherwise an error will be returned. +absl::Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_fused_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_fused_impl.h new file mode 100644 index 00000000..5e35562b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_fused_impl.h @@ -0,0 +1,848 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implements convolution operations with other kernels baked into the +// processing, to optimize latency and memory usage: +// - Conv2D + BiasAdd + +// - Conv2D + FusedBatchNorm + +// +// Activation: Relu, Relu6, Elu, etc... +// +// Kernels for convolutions fused with image transformations (resize and mirror +// padding) defined in `conv_ops_fused_image_transform.cc`. +// +// For the CPU device we implement fusion with an Eigen tensor contraction +// output kernel. For the GPU device we rely on CuDNN primitives. +// +// NOTE: GPU only supports fusion of Conv2D + BiasAdd + . 
+ +#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_ + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include +#include +#include + +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_ops.h" +#include "tensorflow/core/kernels/fused_eigen_output_kernels.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/profiler/lib/scoped_annotation.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/use_cudnn.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cudnn/cudnn.h" +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h" +#include "tensorflow/core/util/autotune_maps/conv_parameters.h" +#include "tensorflow/core/util/proto/proto_utils.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct LaunchFusedConv2DOp { + void operator()(OpKernelContext* context, bool use_cudnn, + bool cudnn_use_autotune, const Tensor& input, + const Tensor& filter, FusedComputationType fusion, + const FusedComputationArgs& fusion_args, + const Conv2DParameters& params, + const Conv2DDimensions& dimensions, Tensor* output); +}; + +// This is CPU-only implementation that uses Eigen contraction output kernels. +// +// Dispatch 2D convolution to the appropriate primitive operation: +// (1) MatMul for the case of 1x1 convolution. +// (2) MatMul for the case when filter size equals to the input size. +// (3) General spatial 2D convolution for all other cases. +template +class LaunchFusedConv2DWithOutputKernel { + public: + LaunchFusedConv2DWithOutputKernel( + int row_stride, int col_stride, // + int row_dilation, int col_dilation, // + Padding padding, const std::vector& explicit_paddings) + : row_stride_(row_stride), + col_stride_(col_stride), + row_dilation_(row_dilation), + col_dilation_(col_dilation), + padding_(padding), + explicit_paddings_(explicit_paddings) {} + + template + void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx, + const Tensor& input, const Tensor& filter, Tensor* output) { + // Wrap output_kernel into type erased wrapper to reduce the number of + // unique template instantiations for Eigen Tensor contraction expressions. + OutputKernelWrapper output_kernel_wrapper( + [&output_kernel]( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, Eigen::Index i, + Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) { + output_kernel(output_mapper, params, i, j, num_rows, num_cols); + }); + + if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && + row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) { + int conv_width = 1; // Width for the convolution step. 
+ for (int i = 0; i < 3; ++i) { + conv_width *= output->dim_size(i); + } + + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + functor::MatMulConvFunctor()( + ctx->eigen_device(), + output->shaped({conv_width, filter.dim_size(3)}), + input.shaped({conv_width, filter.dim_size(2)}), + filter.shaped({filter.dim_size(2), filter.dim_size(3)}), + dim_pair, std::move(output_kernel_wrapper)); + + } else if (filter.dim_size(0) == input.dim_size(1) && + filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 && + col_dilation_ == 1 && padding_ == VALID) { + // If the input data and filter have the same height/width, + // reduce the 2D convolution to matrix multiplication. + const auto k = // Length of reduction dimension. + filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2); + + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + functor::MatMulConvFunctor()( + ctx->eigen_device(), + output->shaped({input.dim_size(0), filter.dim_size(3)}), + input.shaped({input.dim_size(0), k}), + filter.shaped({k, filter.dim_size(3)}), dim_pair, + std::move(output_kernel_wrapper)); + + } else { + if (padding_ == EXPLICIT) { + functor::SpatialConvolution()( + ctx->eigen_device(), output->tensor(), + input.tensor(), filter.tensor(), row_stride_, + col_stride_, row_dilation_, col_dilation_, + static_cast(explicit_paddings_[2]), + static_cast(explicit_paddings_[3]), + static_cast(explicit_paddings_[4]), + static_cast(explicit_paddings_[5]), + std::move(output_kernel_wrapper)); + } else { + functor::SpatialConvolution()( + ctx->eigen_device(), output->tensor(), + input.tensor(), filter.tensor(), row_stride_, + col_stride_, row_dilation_, col_dilation_, + BrainPadding2EigenPadding(padding_), + std::move(output_kernel_wrapper)); + } + } + } + + private: + // Wrap output_kernel into type erased struct to reduce the number of unique + // template instantiations for Eigen Tensor contraction expressions. + // + // We do not pass std::function directly as an output kernel because it blows + // up the binary size in debug mode with super long symbol names. 
+ struct OutputKernelWrapper { + using OutputKernelFn = + std::function&, + const Eigen::TensorContractionParams&, Eigen::Index, + Eigen::Index, Eigen::Index, Eigen::Index)>; + + explicit OutputKernelWrapper(OutputKernelFn fn) + : output_kernel_fn(std::move(fn)) {} + + void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, Eigen::Index i, + Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const { + output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols); + } + + OutputKernelFn output_kernel_fn; + }; + + int row_stride_; + int col_stride_; + int row_dilation_; + int col_dilation_; + const Padding padding_; + const std::vector& explicit_paddings_; +}; + +template +struct LaunchFusedConv2DOp { + void operator()(OpKernelContext* context, bool use_cudnn, + bool cudnn_use_autotune, const Tensor& input, + const Tensor& filter, const FusedComputationType fusion, + const FusedComputationArgs& fusion_args, + const Conv2DParameters& params, + const Conv2DDimensions& dimensions, Tensor* output) { + OP_REQUIRES(context, dimensions.in_depth == filter.dim_size(2), + errors::Unimplemented("Fused conv implementation does not " + "support grouped convolutions for now.")); + OP_REQUIRES(context, params.data_format == FORMAT_NHWC, + errors::Unimplemented("Fused conv implementation only supports " + "NHWC tensor format for now.")); + OP_REQUIRES(context, DataTypeToEnum::value != DT_HALF, + errors::Unimplemented("Fused conv implementation with half " + "precision is not supported on CPU.")); + + BiasAddArgs bias_add_args; + if (BiasAddArgs::IsSupported(fusion)) { + if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) { + OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args, + &fusion_args.leakyrelu_alpha)); + } else { + OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args)); + } + } + + FusedBatchNormArgs fused_batch_norm_args; + if (FusedBatchNormArgs::IsSupported(fusion)) { + if (fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu) { + OP_REQUIRES_OK(context, + InitFusedBatchNormArgs(context, fusion_args.epsilon, + &fused_batch_norm_args, + &fusion_args.leakyrelu_alpha)); + } else { + OP_REQUIRES_OK(context, + InitFusedBatchNormArgs(context, fusion_args.epsilon, + &fused_batch_norm_args)); + } + } + + LaunchFusedConv2DWithOutputKernel conv2d( + dimensions.stride_rows, dimensions.stride_cols, + dimensions.dilation_rows, dimensions.dilation_cols, params.padding, + params.explicit_paddings); + + switch (fusion) { + case FusedComputationType::kUndefined: + OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined")); + break; + case FusedComputationType::kBiasAdd: + conv2d(WithBiasAdd(bias_add_args), context, input, filter, output); + break; + case FusedComputationType::kBiasAddWithRelu: + conv2d(WithBiasAddAndRelu(bias_add_args), context, input, filter, + output); + break; + case FusedComputationType::kBiasAddWithRelu6: + conv2d(WithBiasAddAndRelu6(bias_add_args), context, input, filter, + output); + break; + case FusedComputationType::kBiasAddWithLeakyRelu: + conv2d(WithBiasAddAndLeakyRelu(bias_add_args), context, input, + filter, output); + break; + case FusedComputationType::kBiasAddWithElu: + conv2d(WithBiasAddAndElu(bias_add_args), context, input, filter, + output); + break; + case FusedComputationType::kFusedBatchNorm: + conv2d( + WithFusedBatchNorm(fusion_args.epsilon, fused_batch_norm_args), + context, input, filter, output); + break; + case 
FusedComputationType::kFusedBatchNormWithRelu: + conv2d(WithFusedBatchNormAndRelu(fusion_args.epsilon, + fused_batch_norm_args), + context, input, filter, output); + break; + case FusedComputationType::kFusedBatchNormWithRelu6: + conv2d(WithFusedBatchNormAndRelu6(fusion_args.epsilon, + fused_batch_norm_args), + context, input, filter, output); + break; + case FusedComputationType::kFusedBatchNormWithLeakyRelu: + conv2d(WithFusedBatchNormAndLeakyRelu(fusion_args.epsilon, + fused_batch_norm_args), + context, input, filter, output); + break; + case FusedComputationType::kFusedBatchNormWithElu: + conv2d(WithFusedBatchNormAndElu(fusion_args.epsilon, + fused_batch_norm_args), + context, input, filter, output); + break; + default: + OP_REQUIRES_OK(context, errors::Internal("Fusion type is unsupported")); + break; + } + } +}; + +template <> +struct LaunchFusedConv2DOp; + +template <> +struct LaunchFusedConv2DOp; + +#if GOOGLE_CUDA + +inline int64_t ConvolveScratchSize() { + static int64_t convolve_scratch_size = GetDnnWorkspaceLimit( + // default value is in bytes despite the name of the environment variable + "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB + ); + return convolve_scratch_size; +} + +template +struct LaunchFusedConv2DOp { + void operator()(OpKernelContext* context, bool use_cudnn, + bool cudnn_use_autotune, const Tensor& input_param, + const Tensor& filter, FusedComputationType fusion, + const FusedComputationArgs& fusion_args, + const Conv2DParameters& params, + const Conv2DDimensions& dimensions, Tensor* output) { + OP_REQUIRES( + context, + params.data_format == FORMAT_NHWC || params.data_format == FORMAT_NCHW, + errors::Unimplemented("Fused conv implementation only supports " + "NHWC and HCHW tensor formats for now.")); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + OP_REQUIRES( + context, use_cudnn, + errors::Unimplemented("FusedConv2D for GPU is not currently supported " + "without cudnn")); + + bool is_supported_activation = + fusion == FusedComputationType::kBiasAddWithRelu || + fusion == FusedComputationType::kBiasAddWithRelu6 || + fusion == FusedComputationType::kBiasAddWithElu || + fusion == FusedComputationType::kBiasAddWithLeakyRelu; + OP_REQUIRES( + context, is_supported_activation, + errors::Unimplemented("FusedConv2D implementation only supports " + "fusing with `BiasAdd + Relu|Relu6|Elu|LeakyRlue`" + " for now.")); + + Tensor input = input_param; + + const int64_t in_batch = GetTensorDim(input, params.data_format, 'N'); + int64_t in_rows = GetTensorDim(input, params.data_format, 'H'); + int64_t in_cols = GetTensorDim(input, params.data_format, 'W'); + const int64_t in_depths = GetTensorDim(input, params.data_format, 'C'); + + const int64_t patch_rows = filter.dim_size(0); + const int64_t patch_cols = filter.dim_size(1); + const int64_t patch_depths = filter.dim_size(2); + + const int64_t out_batch = GetTensorDim(*output, params.data_format, 'N'); + const int64_t out_rows = GetTensorDim(*output, params.data_format, 'H'); + const int64_t out_cols = GetTensorDim(*output, params.data_format, 'W'); + const int64_t out_depths = GetTensorDim(*output, params.data_format, 'C'); + + // Bias of the following dimensions: [ output_depth ] + const Tensor& bias = context->input(2); + OP_REQUIRES(context, bias.dims() == 1, + errors::InvalidArgument("bias must be 1-dimensional", + bias.shape().DebugString())); + OP_REQUIRES(context, bias.dim_size(0) == out_depths, + 
errors::InvalidArgument("bias depth must be equal to out depth", + bias.shape().DebugString())); + + const int64_t common_padding_rows = + std::min(dimensions.pad_rows_before, dimensions.pad_rows_after); + const int64_t common_padding_cols = + std::min(dimensions.pad_cols_before, dimensions.pad_cols_after); + if (dimensions.pad_rows_before != dimensions.pad_rows_after || + dimensions.pad_cols_before != dimensions.pad_cols_after) { + // cuDNN only supports padding the same amount on the left and right + // sides, and on the top and bottom sides. So we manually create a new + // padded input tensor such that we can pass it to cuDNN. + + // TODO(reedwm): In some cases, we can avoid an allocation even if the two + // padding sides are different. For example, if the input is 2x2, the + // filter is 1x1, the stride is 2, and the padding is (1, 0, 1, 0), the + // result is equivalent to as if the padding is (1, 1, 1, 1). Changing the + // padding in such a way would allow us to avoid the allocation. + Tensor transformed_input; + const int64_t padding_rows_diff = + std::abs(dimensions.pad_rows_after - dimensions.pad_rows_before); + const int64_t padding_cols_diff = + std::abs(dimensions.pad_cols_after - dimensions.pad_cols_before); + const int64_t new_in_rows = in_rows + padding_rows_diff; + const int64_t new_in_cols = in_cols + padding_cols_diff; + TensorShape transformed_input_shape; + OP_REQUIRES_OK(context, + ShapeFromFormatWithStatus( + params.data_format, in_batch, new_in_rows, new_in_cols, + in_depths, &transformed_input_shape)); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_input_shape, + &transformed_input)); + const int64_t input_pad_top = + dimensions.pad_rows_before - common_padding_rows; + const int64_t input_pad_bottom = + dimensions.pad_rows_after - common_padding_rows; + const int64_t input_pad_left = + dimensions.pad_cols_before - common_padding_cols; + const int64_t input_pad_right = + dimensions.pad_cols_after - common_padding_cols; + bool in_bounds = + FastBoundsCheck(input_pad_top, std::numeric_limits::max()) && + FastBoundsCheck(input_pad_bottom, std::numeric_limits::max()) && + FastBoundsCheck(input_pad_left, std::numeric_limits::max()) && + FastBoundsCheck(input_pad_right, std::numeric_limits::max()); + if (!in_bounds) { + context->SetStatus(errors::InvalidArgument("Padding is too large.")); + return; + } + functor::PadInput()( + context->eigen_device(), + To32Bit(input_param.tensor()), + {{static_cast(input_pad_top), static_cast(input_pad_left)}}, + {{static_cast(input_pad_bottom), + static_cast(input_pad_right)}}, + To32Bit(transformed_input.tensor()), params.data_format, T{}); + input = transformed_input; + in_rows = new_in_rows; + in_cols = new_in_cols; + } + + const bool compute_in_nhwc = DataTypeToEnum::value == DT_HALF && + stream->GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::VOLTA); + if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) { + // Convert the input tensor from NHWC to NCHW. + TensorShape nchw_shape; + OP_REQUIRES_OK( + context, ShapeFromFormatWithStatus(FORMAT_NCHW, in_batch, in_rows, + in_cols, in_depths, &nchw_shape)); + if (in_depths > 1) { + Tensor transformed_input; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + nchw_shape, &transformed_input)); + functor::NHWCToNCHW()( + context->eigen_device(), + const_cast(input).tensor(), + transformed_input.tensor()); + input = transformed_input; + } else { + // If depth <= 1, then just reshape. 
+ CHECK(input.CopyFrom(input, nchw_shape)); // Crash OK + } + } + + CHECK(common_padding_rows >= 0) << "Negative padding rows"; // Crash OK + CHECK(common_padding_rows >= 0) << "Negative padding cols"; // Crash OK + + se::dnn::ActivationMode dnn_activation_mode; + switch (fusion) { + case FusedComputationType::kBiasAddWithRelu: + dnn_activation_mode = se::dnn::ActivationMode::kRelu; + break; + case FusedComputationType::kBiasAddWithRelu6: + dnn_activation_mode = se::dnn::ActivationMode::kRelu6; + break; + case FusedComputationType::kBiasAddWithElu: + dnn_activation_mode = se::dnn::ActivationMode::kElu; + break; + case FusedComputationType::kBiasAddWithLeakyRelu: + dnn_activation_mode = se::dnn::ActivationMode::kLeakyRelu; + break; + default: + LOG(FATAL) << "Unsupported fusion type"; // Crash OK + } + + const TensorFormat compute_data_format = + compute_in_nhwc ? FORMAT_NHWC : FORMAT_NCHW; + constexpr auto kComputeInNHWC = + std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, + se::dnn::FilterLayout::kOutputYXInput); + constexpr auto kComputeInNCHW = + std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, + se::dnn::FilterLayout::kOutputInputYX); + se::dnn::DataLayout compute_data_layout; + se::dnn::FilterLayout filter_layout; + std::tie(compute_data_layout, filter_layout) = + compute_in_nhwc ? kComputeInNHWC : kComputeInNCHW; + + se::dnn::BatchDescriptor input_desc; + input_desc.set_count(in_batch) + .set_feature_map_count(in_depths) + .set_height(in_rows) + .set_width(in_cols) + .set_layout(compute_data_layout); + se::dnn::FilterDescriptor filter_desc; + filter_desc.set_input_filter_height(patch_rows) + .set_input_filter_width(patch_cols) + .set_input_feature_map_count(patch_depths) + .set_output_feature_map_count(filter.dim_size(3)) + .set_layout(filter_layout); + se::dnn::BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count(out_depths) + .set_layout(compute_data_layout); + se::dnn::ConvolutionDescriptor conv_desc; + conv_desc.set_vertical_dilation_rate(dimensions.dilation_rows) + .set_horizontal_dilation_rate(dimensions.dilation_cols) + .set_vertical_filter_stride(dimensions.stride_rows) + .set_horizontal_filter_stride(dimensions.stride_cols) + .set_zero_padding_height(common_padding_rows) + .set_zero_padding_width(common_padding_cols) + .set_group_count(in_depths / patch_depths); + se::dnn::BatchDescriptor output_desc; + output_desc.set_count(out_batch) + .set_height(out_rows) + .set_width(out_cols) + .set_feature_map_count(out_depths) + .set_layout(compute_data_layout); + + Tensor transformed_filter; + const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status { + VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO) + << " to " << ToString(dst_format); + + TensorShape dst_shape = + dst_format == FORMAT_OIHW + ? 
TensorShape({filter.dim_size(3), filter.dim_size(2), + filter.dim_size(0), filter.dim_size(1)}) + : TensorShape({filter.dim_size(3), filter.dim_size(0), + filter.dim_size(1), filter.dim_size(2)}); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, dst_shape, &transformed_filter)); + functor::TransformFilter()( + context->eigen_device(), dst_format, + To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + + return OkStatus(); + }; + + if (compute_in_nhwc) { + OP_REQUIRES_OK(context, transform_filter(FORMAT_OHWI)); + } else { + OP_REQUIRES_OK(context, transform_filter(FORMAT_OIHW)); + } + + Tensor transformed_output; + if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) { + // Only allocate temporary memory when a layout transformation is needed. + TensorShape transformed_output_shape; + OP_REQUIRES_OK(context, ShapeFromFormatWithStatus( + FORMAT_NCHW, out_batch, out_rows, out_cols, + out_depths, &transformed_output_shape)); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_output_shape, + &transformed_output)); + } else { + transformed_output = *output; + } + + const auto tensor_on_device = [](const Tensor& t) -> se::DeviceMemory { + return AsDeviceMemory(t.template flat().data(), + t.template flat().size()); + }; + + se::DeviceMemory input_ptr = tensor_on_device(input); + se::DeviceMemory filter_ptr = tensor_on_device(transformed_filter); + se::DeviceMemory bias_ptr = tensor_on_device(bias); + se::DeviceMemory output_ptr = tensor_on_device(transformed_output); + + // We do not use side inputs, so we can safely pass nullptr. + se::DeviceMemory side_input_ptr = + AsDeviceMemory(static_cast(nullptr), 0); + + constexpr double kConvScale = 1.0; + constexpr double kSideInputScale = 0.0; + double leakyrelu_alpha = fusion_args.leakyrelu_alpha; + + DataType dtype = input.dtype(); + ConvParameters conv_parameters = { + stream->parent(), + in_batch, // batch + in_depths, // in_depths + {{in_rows, // in_rows + in_cols}}, // in_cols + compute_data_format, // compute_data_format + out_depths, // out_depths + {{patch_rows, // filter_rows + patch_cols, // filter_cols + patch_depths}}, // filter_depths + {{dimensions.dilation_rows, // dilation_rows + dimensions.dilation_cols}}, // dilation_cols + {{dimensions.stride_rows, // stride_rows + dimensions.stride_cols}}, // stride_cols + {{common_padding_rows, // padding_rows + common_padding_cols}}, // padding_cols + dtype, // tensor datatype + conv_desc.group_count(), + ConvParameters::FusionInfo{kConvScale, kSideInputScale, leakyrelu_alpha, + dnn_activation_mode, // activation_mode + /*is_contrib=*/false}}; + + se::dnn::DataType element_type = se::dnn::ToDataType::value; + + auto entry_or = AutotuneFusedConv( + cudnn_use_autotune, FusedConvAutotuneMap::GetInstance(), + conv_parameters, context, input_desc, filter_desc, bias_desc, + output_desc, conv_desc, dnn_activation_mode, kConvScale, + kSideInputScale, leakyrelu_alpha, input_ptr, filter_ptr, output_ptr, + bias_ptr, side_input_ptr, ConvolveScratchSize()); + OP_REQUIRES_OK(context, entry_or.status()); + auto autotune_entry = std::move(entry_or).value(); + + DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context); + Status cudnn_launch_status; + if (!autotune_entry.is_algorithm_config()) { + auto& runners = autotune_entry.GetOpRunners(); + se::dnn::FusedConvOp::Config config{se::dnn::ConvolutionKind::FORWARD, + element_type, + element_type, + element_type, + kConvScale, + kSideInputScale, + leakyrelu_alpha, + 
input_desc, + filter_desc, + bias_desc, + output_desc, + conv_desc, + dnn_activation_mode}; + auto primary_or = runners.primary->GetOrCreateRunner(config, stream); + OP_REQUIRES_OK(context, primary_or.status()); + auto* primary = primary_or.value(); + + const se::dnn::FusedConvRunner* no_scratch_fallback = nullptr; + if (runners.no_scratch_fallback) { + auto no_scratch_fallback_or = + runners.no_scratch_fallback->GetOrCreateRunner(config, stream); + OP_REQUIRES_OK(context, no_scratch_fallback_or.status()); + no_scratch_fallback = no_scratch_fallback_or.value(); + } + + auto runner_and_scratch_or = + AllocateScratchOrFallback( + &scratch_allocator, primary, no_scratch_fallback); + OP_REQUIRES_OK(context, runner_and_scratch_or.status()); + auto runner_and_scratch = std::move(runner_and_scratch_or).value(); + auto& runner = + *std::get(runner_and_scratch); + cudnn_launch_status = runner( + stream, nullptr, std::get(runner_and_scratch), + input_ptr, filter_ptr, side_input_ptr, bias_ptr, output_ptr); + } else { + auto dnn = stream->parent()->AsDnn(); + OP_REQUIRES(context, dnn != nullptr, + absl::InternalError("No DNN for stream.")); + cudnn_launch_status = dnn->FusedConvolveWithAlgorithm( + stream, input_desc, input_ptr, // input + kConvScale, // input_scale + filter_desc, filter_ptr, // filter + conv_desc, // conv + side_input_ptr, kSideInputScale, // side_input + bias_desc, bias_ptr, // bias + dnn_activation_mode, // activation + output_desc, &output_ptr, // output + &scratch_allocator, autotune_entry.GetAlgorithmConfig(), nullptr); + } + + OP_REQUIRES_OK(context, cudnn_launch_status); + + // Convert the output tensor back from NCHW to NHWC. + if (!compute_in_nhwc && params.data_format == FORMAT_NHWC) { + functor::NCHWToNHWC()( + context->eigen_device(), + const_cast(transformed_output).tensor(), + output->tensor()); + } + } +}; + +template <> +struct LaunchFusedConv2DOp; + +template <> +struct LaunchFusedConv2DOp; + +#endif // GOOGLE_CUDA + +template +class FusedConv2DOp : public OpKernel { + public: + explicit FusedConv2DOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, InitConv2DParameters(context, ¶ms_)); + + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); + cudnn_use_autotune_ = CudnnUseAutotune(); + + using FCT = FusedComputationType; + + std::vector patterns; + if (std::is_same::value) { + patterns = { + {FCT::kBiasAdd, {"BiasAdd"}}, + {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}, + {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}}, + {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}, + {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}}, + {FCT::kFusedBatchNorm, {"FusedBatchNorm"}}, + {FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}}, + {FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}}, + {FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}}, + {FCT::kFusedBatchNormWithLeakyRelu, {"FusedBatchNorm", "LeakyRelu"}}, + }; + } + + // NOTE(ezhulenev): CuDNN `cudnnConvolutionBiasActivationForward` supports + // identity activation function, it in theory should allow to fuse + // convolution with BiasAdd, but in practice it doesn't work, cuDNN ignores + // this parameter and always does Relu activation. 
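To make the fusion pattern tables above and below easier to follow, here is a minimal standalone sketch of the matching idea: a list of fused op names is compared against a fixed set of supported sequences and mapped to an enum, and (as the note above explains) a bare BiasAdd has no usable cuDNN fusion on the GPU path. The names FusionKind and MatchFusion are illustrative stand-ins, not the TensorFlow API.

#include <optional>
#include <string>
#include <utility>
#include <vector>

enum class FusionKind {
  kBiasAddWithRelu,
  kBiasAddWithRelu6,
  kBiasAddWithElu,
  kBiasAddWithLeakyRelu,
};

// Returns the fusion whose op-name sequence exactly matches `fused_ops`,
// or std::nullopt when the sequence is not supported (e.g. {"BiasAdd"} alone).
std::optional<FusionKind> MatchFusion(const std::vector<std::string>& fused_ops) {
  static const std::vector<std::pair<FusionKind, std::vector<std::string>>>
      kPatterns = {
          {FusionKind::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FusionKind::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
          {FusionKind::kBiasAddWithElu, {"BiasAdd", "Elu"}},
          {FusionKind::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
      };
  for (const auto& [kind, pattern] : kPatterns) {
    if (pattern == fused_ops) return kind;
  }
  return std::nullopt;
}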
+ if (std::is_same::value) { + if (std::is_same::value || std::is_same::value) { + patterns = {{FCT::kBiasAdd, {"BiasAdd"}}, + {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}}; + } else { + patterns = { + {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}, + {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}}, + {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}, + {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}}, + }; + } + } + + OP_REQUIRES_OK(context, InitializeFusedComputation( + context, "Conv2D", patterns, + &fused_computation_, &fused_computation_args_)); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + const Tensor& input = context->input(0); + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, out_depth] + const Tensor& filter = context->input(1); + + Conv2DDimensions dimensions; + OP_REQUIRES_OK(context, + ComputeConv2DDimension(params_, input, filter, &dimensions)); + + TensorShape out_shape; + OP_REQUIRES_OK( + context, ShapeFromFormatWithStatus( + params_.data_format, dimensions.batch, dimensions.out_rows, + dimensions.out_cols, dimensions.out_depth, &out_shape)); + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ] + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + VLOG(2) << "FusedConv2D: in_depth = " << dimensions.in_depth + << ", patch_depth = " << dimensions.patch_depth + << ", input_cols = " << dimensions.input_cols + << ", filter_cols = " << dimensions.filter_cols + << ", input_rows = " << dimensions.input_rows + << ", filter_rows = " << dimensions.filter_rows + << ", stride_rows = " << dimensions.stride_rows + << ", stride_cols = " << dimensions.stride_cols + << ", dilation_rows = " << dimensions.dilation_rows + << ", dilation_cols = " << dimensions.dilation_cols + << ", out_depth = " << dimensions.out_depth; + + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } + + LaunchFusedConv2DOp()(context, use_cudnn_, cudnn_use_autotune_, + input, filter, fused_computation_, + fused_computation_args_, params_, + dimensions, output); + } + + private: + Conv2DParameters params_; + bool use_cudnn_; + bool cudnn_use_autotune_; + + FusedComputationType fused_computation_ = FusedComputationType::kUndefined; + FusedComputationArgs fused_computation_args_; + + FusedConv2DOp(const FusedConv2DOp&) = delete; + void operator=(const FusedConv2DOp&) = delete; +}; + +// Registration of the CPU implementations. +#define REGISTER_FUSED_CPU_CONV2D(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("_FusedConv2D").Device(DEVICE_CPU).TypeConstraint("T"), \ + FusedConv2DOp); + +#if GOOGLE_CUDA + +#define DECLARE_FUNCTOR_GPU_SPEC(T) \ + template <> \ + void TransformFilter::operator()( \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes::ConstTensor in, \ + typename TTypes::Tensor out); \ + extern template struct TransformFilter; \ + template <> \ + void PadInput::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor in, \ + const std::array& padding_left, \ + const std::array& padding_right, \ + typename TTypes::Tensor out, TensorFormat data_format, \ + const T& padding_value); \ + extern template struct PadInput + +// Registration of the GPU implementations. 
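The DECLARE_FUNCTOR_GPU_SPEC block above leans on extern/explicit template instantiation so the functor definitions built in the CUDA translation unit are reused rather than re-instantiated by every including file. A minimal self-contained illustration of that standard C++ mechanism follows; PadFunctorSketch is a made-up type, not one of the functors declared above.

// Header side (sketch): declare the template and suppress implicit
// instantiation in translation units that only include the declaration.
template <typename T>
struct PadFunctorSketch {
  void operator()(const T* in, T* out, int n, T pad_value);
};
extern template struct PadFunctorSketch<float>;  // definition lives in one TU

// Source side (sketch): the one translation unit that defines and explicitly
// instantiates the functor, analogous to the *.cu.cc files in this header set.
template <typename T>
void PadFunctorSketch<T>::operator()(const T* in, T* out, int n, T pad_value) {
  for (int i = 0; i < n; ++i) out[i] = (in != nullptr) ? in[i] : pad_value;
}
template struct PadFunctorSketch<float>;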
+#define REGISTER_FUSED_GPU_CONV2D(T) \ + REGISTER_KERNEL_BUILDER(Name("_FusedConv2D") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("host_args"), \ + FusedConv2DOp); + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_FUSED_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_gpu.h new file mode 100644 index 00000000..627450ef --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_gpu.h @@ -0,0 +1,213 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/util/autotune_maps/conv_parameters.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +bool ComputeInNhwcEnabled(DataType data_type, se::Stream* stream, + bool use_4d_tensor = true); + +// Get the Dnn workspace limit from the environment variable, which is in MB. +// Return the workspace memory limit in bytes. If no value is set, return the +// default value. +int64 GetDnnWorkspaceLimit(const string& envvar_in_mb, + int64_t default_value_in_bytes); + +// Call the Dnn workspace limit from TF_CUDNN_WORKSPACE_LIMIT_IN_MB or default. +int64 GetDnnWorkspaceLimitOrDefault(); + +// A class to provide scratch-space allocator for Stream-Executor Cudnn +// callback. TensorFlow is responsible for releasing the temporary buffers after +// the kernel finishes. 
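A simplified, standard-library-only sketch of the allocator contract described in the comment above: allocations are capped by a byte budget, and the buffers stay alive for the lifetime of the pool so the kernel can keep using them. BoundedScratchPool is a hypothetical name used only for illustration, not the class that follows.

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

class BoundedScratchPool {
 public:
  explicit BoundedScratchPool(std::size_t limit_bytes) : limit_(limit_bytes) {}

  // Returns a pointer to `n` bytes, or std::nullopt when the budget would be
  // exceeded. The backing buffer is owned by the pool until destruction.
  std::optional<std::uint8_t*> Allocate(std::size_t n) {
    if (n > limit_ - used_) return std::nullopt;
    buffers_.emplace_back(n);
    used_ += n;
    return buffers_.back().data();
  }

  std::size_t TotalBytes() const { return used_; }

 private:
  std::size_t limit_;
  std::size_t used_ = 0;
  std::vector<std::vector<std::uint8_t>> buffers_;  // owns all scratch memory
};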
+class DnnScratchAllocator : public se::ScratchAllocator { + public: + virtual ~DnnScratchAllocator() {} + DnnScratchAllocator(int64_t memory_limit, OpKernelContext* context) + : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} + int64 GetMemoryLimitInBytes() override { return memory_limit_; } + tsl::StatusOr> AllocateBytes( + int64_t byte_size) override { + Tensor temporary_memory; + if (byte_size < 0) { + return tsl::Status{absl::StatusCode::kInvalidArgument, + "Requested negative byte size!"}; + } + if (byte_size > memory_limit_) { + return tsl::Status{absl::StatusCode::kUnavailable, + absl::StrCat("Requested memory size (", byte_size, + ") exceeds the max memory limit (", + memory_limit_, ").")}; + } + AllocationAttributes allocation_attr; + allocation_attr.retry_on_failure = false; + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory, + AllocatorAttributes(), allocation_attr)); + if (!allocation_status.ok()) { + return tsl::Status{ + absl::StatusCode::kUnavailable, + absl::StrCat("Failed to allocate the requested memory size (", + byte_size, ").")}; + } + // Hold the reference of the allocated tensors until the end of the + // allocator. + allocated_tensors_.push_back(temporary_memory); + total_byte_size_ += byte_size; + return tsl::StatusOr>( + AsDeviceMemory(temporary_memory.flat().data(), + temporary_memory.flat().size())); + } + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64 memory_limit_; + int64 total_byte_size_; + OpKernelContext* context_; + std::vector allocated_tensors_; +}; + +typedef Eigen::GpuDevice GPUDevice; + +// Select an algorithm for the given convolution, either by running actual +// autotuning with a cache, or by falling back to a default if +// 'cudnn_use_autotune' is true and cuDNN is the statically-chosen DNN backend. +template +StatusOr> AutotuneFusedConv( + bool cudnn_use_autotune, + AutotuneMap>* + autotune_map, + const ConvParameters& params, OpKernelContext* ctx, + const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& bias_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, + const se::dnn::ActivationMode activation_mode, double conv_input_scale, + double side_input_scale, double leakyrelu_alpha, + se::DeviceMemory input_ptr, se::DeviceMemory filter_ptr, + se::DeviceMemory output_ptr, se::DeviceMemory bias_ptr, + se::DeviceMemory side_input_ptr, int64_t scratch_size); + +template +StatusOr> AutotuneUnfusedConv( + bool cudnn_use_autotune, + AutotuneMap>* autotune_map, + const ConvParameters& conv_parameters, OpKernelContext* ctx, + se::dnn::ConvolutionKind kind, const se::dnn::BatchDescriptor& input_desc, + se::DeviceMemory input_ptr, const se::dnn::FilterDescriptor& filter_desc, + se::DeviceMemory filter_ptr, + const se::dnn::ConvolutionDescriptor& conv_desc, + const se::dnn::BatchDescriptor& output_desc, se::DeviceMemory output_ptr, + int64_t scratch_size_limit); + +// Returns a pointer to the primary 'OpRunner' of 'runners' and allocated +// scratch memory if allocatable; else a pointer to its fallback +// no-scratch-space runner, and a null 'DeviceMemoryBase'. 
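The selection logic described in the comment above can be summarized in a standalone sketch: try to allocate the primary runner's workspace, and only if that fails fall back to a runner that needs no scratch at all. RunnerSketch, ScratchSketch, and PickRunner are toy stand-ins, not StreamExecutor types.

#include <cstddef>
#include <functional>
#include <optional>
#include <stdexcept>
#include <utility>

struct RunnerSketch {
  std::size_t workspace_bytes;  // scratch space this runner requires
};
struct ScratchSketch {
  void* ptr = nullptr;
  std::size_t size = 0;
};

// Prefer the primary runner; if its workspace cannot be allocated, use the
// no-scratch fallback (which must truly need zero bytes), else report failure.
std::pair<const RunnerSketch*, ScratchSketch> PickRunner(
    const RunnerSketch* primary, const RunnerSketch* no_scratch_fallback,
    const std::function<std::optional<ScratchSketch>(std::size_t)>& allocate) {
  if (primary->workspace_bytes == 0) return {primary, {}};
  if (auto scratch = allocate(primary->workspace_bytes)) {
    return {primary, *scratch};
  }
  if (no_scratch_fallback != nullptr &&
      no_scratch_fallback->workspace_bytes == 0) {
    return {no_scratch_fallback, {}};
  }
  throw std::runtime_error("no usable runner: scratch allocation failed");
}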
+template +StatusOr*, se::DeviceMemoryBase>> +AllocateScratchOrFallback(se::ScratchAllocator* scratch_allocator, + const se::dnn::OpRunner* primary, + const se::dnn::OpRunner* no_scratch_fallback) { + const se::dnn::OpRunner* selected_runner = primary; + + auto workspace_size = selected_runner->GetWorkspaceSize(); + + se::DeviceMemoryBase scratch_memory; + if (workspace_size > 0) { + auto scratch_or = scratch_allocator->AllocateBytes(workspace_size); + if (scratch_or.ok()) { + scratch_memory = scratch_or.value(); + } else if ((selected_runner = no_scratch_fallback)) { + if (selected_runner->GetWorkspaceSize() > 0) { + return errors::Internal( + "No-scratch fallback runner requires nonzero scratch space"); + } + } else { + return errors::Unknown( + "CUDNN failed to allocate the scratch space for the runner or to " + "find a working no-scratch runner."); + } + } + + return std::make_tuple(selected_runner, scratch_memory); +} + +template +Status LaunchAutotunedConv(const AutotuneEntry& autotune_entry, + DnnScratchAllocator* scratch_allocator, + se::dnn::ConvolutionKind kind, se::Stream* stream, + const se::dnn::BatchDescriptor& input_desc, + se::DeviceMemory in_ptr, + const se::dnn::FilterDescriptor& filter_desc, + se::DeviceMemory filter_ptr, + const se::dnn::ConvolutionDescriptor& conv_desc, + const se::dnn::BatchDescriptor& output_desc, + se::DeviceMemory out_ptr) { + if (!autotune_entry.is_algorithm_config()) { + const auto& runners = autotune_entry.GetOpRunners(); + se::dnn::DataType element_type = se::dnn::ToDataType::value; + se::dnn::ConvOp::Config config{kind, element_type, element_type, + input_desc, filter_desc, output_desc, + conv_desc}; + TF_ASSIGN_OR_RETURN(auto* primary, + runners.primary->GetOrCreateRunner(config, stream)); + + const se::dnn::ConvRunner* no_scratch_fallback = nullptr; + if (runners.no_scratch_fallback) { + TF_ASSIGN_OR_RETURN( + no_scratch_fallback, + runners.no_scratch_fallback->GetOrCreateRunner(config, stream)); + } + + TF_ASSIGN_OR_RETURN(auto runner_and_scratch, + AllocateScratchOrFallback( + scratch_allocator, primary, no_scratch_fallback)); + auto& runner = *std::get(runner_and_scratch); + return runner(stream, nullptr, + std::get(runner_and_scratch), in_ptr, + filter_ptr, out_ptr); + } else { + auto dnn = stream->parent()->AsDnn(); + if (dnn == nullptr) { + return absl::InternalError("No DNN for stream."); + } + return dnn->ConvolveWithAlgorithm( + stream, kind, input_desc, in_ptr, filter_desc, filter_ptr, output_desc, + out_ptr, conv_desc, scratch_allocator, + autotune_entry.GetAlgorithmConfig(), nullptr); + } +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_impl.h new file mode 100644 index 00000000..0d3fc798 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/conv_ops_impl.h @@ -0,0 +1,1284 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. + +#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_CONV_OPS_IMPL_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/op_requires.h" + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/synchronization/blocking_counter.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_3d.h" +#include "tensorflow/core/kernels/conv_ops.h" +#include "tensorflow/core/kernels/deep_conv2d.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/profiler/lib/scoped_annotation.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/use_cudnn.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/kernels/cast_op.h" +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/kernels/numeric_options_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h" +#include "tensorflow/core/util/autotune_maps/conv_parameters.h" +#include "tensorflow/core/util/proto/proto_utils.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA +#include "xla/stream_executor/gpu/gpu_asm_opts.h" +#include "xla/stream_executor/gpu/redzone_allocator.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct LaunchGeneric { + void operator()(OpKernelContext* ctx, const Tensor& input, + const Tensor& filter, int row_stride, int col_stride, + int row_dilation, int col_dilation, const Padding& padding, + const std::vector& explicit_paddings, Tensor* output, + TensorFormat data_format) { + DCHECK(data_format == FORMAT_NHWC) + << "Generic conv implementation only " + "supports NHWC tensor format for now."; + if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 && + col_stride == 1 && (padding == SAME || padding == VALID)) { + // For 1x1 kernel, the 2D convolution is reduced to matrix + // multiplication. 
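The 1x1 special case mentioned in the comment above amounts to a reshape plus a matrix multiply: the NHWC input [N, H, W, Cin] is viewed as an (N*H*W) x Cin matrix, the 1x1 filter as Cin x Cout, and their product is the output viewed as (N*H*W) x Cout. A naive standalone sketch of that equivalence, not the Eigen-based path used below:

#include <cstddef>
#include <vector>

// input is N*H*W*Cin floats in NHWC order, filter is Cin*Cout floats in HWIO
// order (with H = W = 1); the result has N*H*W*Cout floats in NHWC order.
std::vector<float> Conv1x1AsMatmul(const std::vector<float>& input,
                                   const std::vector<float>& filter,
                                   int nhw, int cin, int cout) {
  std::vector<float> output(static_cast<std::size_t>(nhw) * cout, 0.f);
  for (int row = 0; row < nhw; ++row) {
    for (int k = 0; k < cin; ++k) {
      const float x = input[row * cin + k];
      for (int col = 0; col < cout; ++col) {
        output[row * cout + col] += x * filter[k * cout + col];
      }
    }
  }
  return output;
}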
+ // + // TODO(vrv): We should be able to call SpatialConvolution + // and it will produce the same result, but doing so + // led to NaNs during training. Using matmul instead for now. + int conv_width = 1; // Width for the convolution step. + for (int i = 0; i < 3; ++i) { + conv_width *= output->dim_size(i); + } + + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + functor::MatMulConvFunctor()( + ctx->eigen_device(), + output->shaped({conv_width, filter.dim_size(3)}), + input.shaped({conv_width, filter.dim_size(2)}), + filter.shaped({filter.dim_size(2), filter.dim_size(3)}), + dim_pair); + } else if (filter.dim_size(0) == input.dim_size(1) && + filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 && + col_dilation == 1 && padding == VALID) { + // If the input data and filter have the same height/width, + // the 2D convolution is reduced to matrix multiplication. + const int k = // Length of reduction dimension. + filter.dim_size(0) * filter.dim_size(1) * filter.dim_size(2); + + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + functor::MatMulConvFunctor()( + ctx->eigen_device(), + output->shaped({input.dim_size(0), filter.dim_size(3)}), + input.shaped({input.dim_size(0), k}), + filter.shaped({k, filter.dim_size(3)}), dim_pair); + } else { + if (padding == EXPLICIT) { + functor::SpatialConvolution()( + ctx->eigen_device(), output->tensor(), + input.tensor(), filter.tensor(), row_stride, col_stride, + row_dilation, col_dilation, static_cast(explicit_paddings[2]), + static_cast(explicit_paddings[3]), + static_cast(explicit_paddings[4]), + static_cast(explicit_paddings[5])); + } else { + functor::SpatialConvolution()( + ctx->eigen_device(), output->tensor(), + input.tensor(), filter.tensor(), row_stride, col_stride, + row_dilation, col_dilation, BrainPadding2EigenPadding(padding)); + } + } + } +}; + +// Compute grouped 2D convolutions on CPU. Unlike grouped convolution +// implementation in cuDNN this is faaaaaar from optimal and needs more work +// to deliver competitive performance. Currently it exists to close the feature +// parity gap between convolution operations on different devices. +template +struct LaunchGrouped { + void operator()(OpKernelContext* ctx, const Tensor& input, + const Tensor& filter, int row_stride, int col_stride, + int row_dilation, int col_dilation, const Padding& padding, + const std::vector& explicit_paddings, Tensor* output, + TensorFormat data_format) { + DCHECK(data_format == FORMAT_NHWC) + << "Grouped conv implementation only " + "supports NHWC tensor format for now."; + + const int64_t in_depth = input.dim_size(3); + const int64_t patch_depth = filter.dim_size(2); + const int64_t num_groups = in_depth / patch_depth; + + // Shuffle input/filter tensors to have group as a leading dimension. + std::array shuffle({3, 0, 1, 2, 4}); + + // Compute pre shuffle dimemnsions. + auto pre_shuffle = [&](const Tensor& tensor) -> std::array { + return {tensor.dim_size(0), tensor.dim_size(1), tensor.dim_size(2), + num_groups, tensor.dim_size(3) / num_groups}; + }; + + // Compute post shuffle dimemnsions. + auto post_shuffle = [&](const Tensor& tensor) -> std::array { + return {num_groups, tensor.dim_size(0), tensor.dim_size(1), + tensor.dim_size(2), tensor.dim_size(3) / num_groups}; + }; + + auto& device = ctx->eigen_device(); + + absl::BlockingCounter shuffles_completed(2); + auto on_shuffled = [&]() { shuffles_completed.DecrementCount(); }; + + // Shuffle input into temporary tensor. 
+ Tensor input_shuffled; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(input.dtype(), TensorShape(post_shuffle(input)), + &input_shuffled)); + input_shuffled.tensor().device(device, on_shuffled) = + input.shaped(pre_shuffle(input)).shuffle(shuffle); + + // Shuffle filter into temporary tensor. + Tensor filter_shuffled; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(filter.dtype(), + TensorShape(post_shuffle(filter)), + &filter_shuffled)); + filter_shuffled.tensor().device(device, on_shuffled) = + filter.shaped(pre_shuffle(filter)).shuffle(shuffle); + + // Wait for the completion of input/filter shuffles. + shuffles_completed.Wait(); + + // Write group convolution results into temporary output tensor. + Tensor output_shuffled; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(output->dtype(), + TensorShape(post_shuffle(*output)), + &output_shuffled)); + + for (int64_t i = 0; i < num_groups; ++i) { + // TODO(ezhulenev): Run this loop using `parallelFor` (regular parallelFor + // will lead to deadlock, SpatialConvolution has to use async Eigen + // assignment). This requires small changes to Eigen to support async + // exeuction for tensor chipping operation. + + // TODO(ezhulenev): Grouped convolution should also support 1x1 filter + // optimization. + + auto input_slice = input_shuffled.tensor().template chip<0>(i); + auto filter_slice = filter_shuffled.tensor().template chip<0>(i); + auto output_slice = output_shuffled.tensor().template chip<0>(i); + + if (padding == EXPLICIT) { + functor::SpatialConvolution()( + ctx->eigen_device(), output_slice, input_slice, + filter_slice, row_stride, col_stride, row_dilation, col_dilation, + static_cast(explicit_paddings[2]), + static_cast(explicit_paddings[3]), + static_cast(explicit_paddings[4]), + static_cast(explicit_paddings[5])); + } else { + functor::SpatialConvolution()( + ctx->eigen_device(), output_slice, input_slice, + filter_slice, row_stride, col_stride, row_dilation, col_dilation, + BrainPadding2EigenPadding(padding)); + } + } + + // Shuffle temporary output back into pre-shuffled shape. + std::array rev_shuffle({1, 2, 3, 0, 4}); + output->shaped(pre_shuffle(*output)).device(device) = + output_shuffled.tensor().shuffle(rev_shuffle); + } +}; + +template +struct LaunchConvOp; + +template +struct LaunchConvOp { + void operator()(OpKernelContext* context, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, + const std::vector& dilations, + const std::vector& strides, const Padding padding, + const std::vector& explicit_paddings, + TensorFormat data_format, Tensor* output) { + // For now just calling existing launchers based on spatial dimensions. + int spatial_dims = input.dims() - 2; + + if (spatial_dims == 2) { + LaunchConv2DOp()(context, true, cudnn_use_autotune, input, + filter, dilations[1], dilations[2], + strides[1], strides[2], padding, + explicit_paddings, output, data_format); + } else { + LaunchConv3DOp().launch( + context, cudnn_use_autotune, input, filter, + {dilations[1], dilations[2], dilations[3]}, + {strides[1], strides[2], strides[3]}, padding, data_format, output); + } + } +}; + +template +class ConvOp : public BinaryOp { + public: + explicit ConvOp(OpKernelConstruction* context) : BinaryOp(context) { + // TODO(b/290223810) Add support for grouped and depthwise convolutions. 
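The constructor below maps the op's CHANNELS_LAST / CHANNELS_FIRST attribute onto a tensor layout and then reads strides, dilations, and padding against that layout. A tiny standalone sketch of that mapping, with hypothetical names (LayoutSketch, ParseLayout, ChannelDimIndex) rather than the TensorFormat utilities used in the real code:

#include <stdexcept>
#include <string>

enum class LayoutSketch { kChannelsLast, kChannelsFirst };

LayoutSketch ParseLayout(const std::string& s) {
  if (s == "CHANNELS_LAST") return LayoutSketch::kChannelsLast;
  if (s == "CHANNELS_FIRST") return LayoutSketch::kChannelsFirst;
  throw std::invalid_argument("Unknown data format: " + s);
}

// For a rank-`rank` tensor with batch at dimension 0, the channel dimension is
// the last one for channels-last layouts and dimension 1 for channels-first.
int ChannelDimIndex(LayoutSketch layout, int rank) {
  return layout == LayoutSketch::kChannelsLast ? rank - 1 : 1;
}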
+ OP_REQUIRES_OK(context, context->GetAttr("groups", &groups_)); + OP_REQUIRES(context, groups_ == 1, + absl::UnimplementedError( + "Grouped/Depthwise Convolutions are not supported yet.")); + string data_format_str; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(context, + data_format_str == "CHANNELS_LAST" || + data_format_str == "CHANNELS_FIRST", + absl::InvalidArgumentError( + absl::StrCat("Unknown data format: ", data_format_str))); + data_format_ = + data_format_str == "CHANNELS_LAST" ? FORMAT_NHWC : FORMAT_NCHW; + + // Always assume filter_format is HWIO / DHWIO. + filter_format_ = FilterTensorFormat::FORMAT_HWIO; + + // These parameters are checked against spatial dimensions on compute. + OP_REQUIRES_OK(context, context->GetAttr("batch_dims", &batch_dims_)); + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + if (context->HasAttr("explicit_paddings")) { + OP_REQUIRES_OK( + context, context->GetAttr("explicit_paddings", &explicit_paddings_)); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + cudnn_use_autotune_ = CudnnUseAutotune(); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, [spatial_dims], in_depth ]. + const Tensor& input = context->input(0); + size_t original_input_dims = context->input(0).dims(); + const TensorShape original_input_shape = context->input(0).shape(); + int spatial_dims = original_input_dims - 1 - batch_dims_; + + // Input filter is of the following dimensions: + // [ batch, [spatial dims], in_depth ]. + const Tensor& filter = context->input(1); + + OP_REQUIRES(context, (spatial_dims == 2 || spatial_dims == 3), + absl::InvalidArgumentError(absl::StrCat( + "The input must have 2 or 3 spatial dimensions but got ", + spatial_dims))); + + OP_REQUIRES( + context, filter.NumElements() > 0, + absl::InvalidArgumentError("filter must not have zero elements " + "(i.e. all dimensions must be non-zero)")); + + // Flatten tensor for computation. + Tensor input_flat; + if (batch_dims_ == 1) { + input_flat = input; + } else { + std::vector in_flat_shape_vec(1, 1); + for (int i = 0; i < batch_dims_; ++i) { + in_flat_shape_vec[0] *= original_input_shape.dim_size(i); + } + for (int i = batch_dims_; i < original_input_shape.dims(); ++i) { + in_flat_shape_vec.push_back(original_input_shape.dim_size(i)); + } + TensorShape in_flat_shape(in_flat_shape_vec); + if (!input_flat.CopyFrom(input, in_flat_shape)) { + // This should never happen, since the output sizes should always be the + // same after expanding batches. + context->SetStatus(absl::InternalError(absl::StrCat( + "Could not flatten input shape ", + original_input_shape.DebugString(), " and flat input shape ", + in_flat_shape.DebugString()))); + } + } + + OP_REQUIRES(context, filter.dims() == 4 || filter.dims() == 5, + absl::InvalidArgumentError(absl::StrCat( + "The filter must be rank 4 or 5 but got ", filter.dims()))); + for (int i = 0; i < spatial_dims; i++) { + OP_REQUIRES( + context, + FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), + absl::InvalidArgumentError("filter too large")); + } + + // Validate operation parameters based on inferred spatial dims. 
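The flattening step just above collapses the leading batch_dims dimensions into a single batch dimension so the 2-D/3-D launchers see a rank they understand. A minimal sketch of that shape arithmetic, assuming a plain vector of dimension sizes rather than TensorShape:

#include <cstddef>
#include <cstdint>
#include <vector>

// Collapses the first `batch_dims` entries of `shape` into one batch entry.
// e.g. {2, 3, 32, 32, 16} with batch_dims = 2 becomes {6, 32, 32, 16}.
std::vector<std::int64_t> FlattenBatchDims(
    const std::vector<std::int64_t>& shape, int batch_dims) {
  std::vector<std::int64_t> flat(1, 1);
  for (int i = 0; i < batch_dims; ++i) flat[0] *= shape[i];
  for (std::size_t i = batch_dims; i < shape.size(); ++i) {
    flat.push_back(shape[i]);
  }
  return flat;
}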
+ OP_REQUIRES(context, strides_.size() == spatial_dims + 2, + absl::InvalidArgumentError( + absl::StrCat("Sliding window strides field must specify ", + spatial_dims + 2, " dimensions"))); + + OP_REQUIRES(context, + (GetTensorDim(strides_, data_format_, 'C') == 1 && + GetTensorDim(strides_, data_format_, 'N') == 1), + absl::InvalidArgumentError( + "Current implementation does not support " + "strides in the batch and depth dimensions.")); + bool stride_valid = true; + for (int i = 0; i < spatial_dims; ++i) { + stride_valid = + stride_valid && (GetTensorDim(strides_, data_format_, + static_cast(i + '0')) > 0); + } + OP_REQUIRES( + context, stride_valid, + absl::InvalidArgumentError("Spatial strides should be larger than 0.")); + if (dilations_.empty()) { + dilations_ = std::vector(spatial_dims + 2, 1); + } else { + OP_REQUIRES(context, dilations_.size() == spatial_dims + 2, + absl::InvalidArgumentError( + absl::StrCat("Dilation rates field must specify", + spatial_dims + 2, "dimensions"))); + OP_REQUIRES(context, + (GetTensorDim(dilations_, data_format_, 'N') == 1 && + GetTensorDim(dilations_, data_format_, 'C') == 1), + absl::InvalidArgumentError( + "Current implementation does not support " + "dilation rates in the batch and depth dimensions.")); + bool dilation_valid = true; + for (int i = 0; i < spatial_dims; ++i) { + dilation_valid = + dilation_valid && (GetTensorDim(dilations_, data_format_, + static_cast(i + '0')) > 0); + } + OP_REQUIRES( + context, dilation_valid, + absl::InvalidArgumentError("Dilated rates should be larger than 0.")); + } + OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_, + spatial_dims + 2, data_format_)); + + const int64_t in_depth_raw = GetTensorDim(input_flat, data_format_, 'C'); + const int64_t patch_depth_raw = GetFilterDim(filter, filter_format_, 'I'); + OP_REQUIRES(context, + FastBoundsCheck(in_depth_raw, std::numeric_limits::max()), + absl::InvalidArgumentError("Input depth too large")); + OP_REQUIRES( + context, + FastBoundsCheck(patch_depth_raw, std::numeric_limits::max()), + absl::InvalidArgumentError("Patch depth too large")); + const int in_depth = static_cast(in_depth_raw); + const int patch_depth = static_cast(patch_depth_raw); + OP_REQUIRES( + context, patch_depth > 0, + absl::InvalidArgumentError(absl::StrCat( + "filter depth must be stricly positive, got ", patch_depth))); + OP_REQUIRES(context, in_depth == patch_depth, + absl::InvalidArgumentError(absl::StrCat( + "Input depth must be equal to filter depth: ", in_depth, + " vs ", patch_depth))); + + const int out_depth = + static_cast(GetFilterDim(filter, filter_format_, 'O')); + + std::vector input_dims_raw(spatial_dims); + std::vector input_dims(spatial_dims); + std::vector filter_dims(spatial_dims); + for (int i = 0; i < spatial_dims; ++i) { + input_dims_raw[i] = + GetTensorDim(input_flat, data_format_, static_cast(i + '0')); + OP_REQUIRES( + context, + FastBoundsCheck(input_dims_raw[i], std::numeric_limits::max()), + absl::InvalidArgumentError( + absl::StrCat("Input spatial dimension ", i, " too large"))); + input_dims[i] = static_cast(input_dims_raw[i]); + filter_dims[i] = static_cast( + GetFilterDim(filter, filter_format_, static_cast(i + '0'))); + } + // The first dimension for input is batch. 
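The checks above enforce the same rule for both strides and dilations: the attribute must list spatial_dims + 2 entries, the batch and channel entries must be 1, and every spatial entry must be positive. A compact standalone sketch of that validation for a channels-last layout; ValidateWindowAttr is a hypothetical helper, not part of the op:

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Returns an error message, or an empty string when the attribute is valid.
std::string ValidateWindowAttr(const std::vector<std::int64_t>& attr,
                               int spatial_dims, const std::string& name) {
  if (attr.size() != static_cast<std::size_t>(spatial_dims) + 2) {
    return name + " must specify " + std::to_string(spatial_dims + 2) +
           " dimensions";
  }
  if (attr.front() != 1 || attr.back() != 1) {
    return name + " must be 1 in the batch and depth dimensions";
  }
  for (int i = 1; i <= spatial_dims; ++i) {
    if (attr[i] <= 0) return name + " spatial entries must be larger than 0";
  }
  return "";
}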
+ const int64_t batch_raw = GetTensorDim(input_flat, data_format_, 'N'); + OP_REQUIRES(context, + FastBoundsCheck(batch_raw, std::numeric_limits::max()), + absl::InvalidArgumentError("Batch is too large")); + const int batch = static_cast(batch_raw); + + // Take the stride and dilation from the spatial dimensions only (we + // do not support striding or dilation on the batch or depth dimension). + std::vector stride_dims(spatial_dims); + std::vector dilation_dims(spatial_dims); + for (int i = 0; i < spatial_dims; ++i) { + stride_dims[i] = + GetTensorDim(strides_, data_format_, static_cast(i + '0')); + dilation_dims[i] = + GetTensorDim(dilations_, data_format_, static_cast(i + '0')); + } + std::vector pad_before(spatial_dims, -1); + std::vector pad_after(spatial_dims, -1); + if (padding_ == Padding::EXPLICIT) { + GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'H', + &pad_before[0], &pad_after[0]); + GetExplicitPaddingForDim(explicit_paddings_, data_format_, 'W', + &pad_before[1], &pad_after[1]); + } + + // Compute windowed output sizes for spatial dimensions. + std::vector out_dims(spatial_dims); + for (int i = 0; i < spatial_dims; ++i) { + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + input_dims[i], filter_dims[i], + dilation_dims[i], stride_dims[i], padding_, + &out_dims[i], &pad_before[i], &pad_after[i])); + } + TensorShape out_shape; + OP_REQUIRES_OK(context, + ShapeFromFormatWithStatus(data_format_, batch, out_dims, + out_depth, &out_shape)); + + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } + + // If the input is empty, result can only be due to padding. + if (input_flat.NumElements() == 0) { + // Zero-out output and return. + functor::SetZeroFunctor()(context->eigen_device(), + output->template flat()); + + return; + } + + launcher_(context, cudnn_use_autotune_, input_flat, filter, dilations_, + strides_, padding_, explicit_paddings_, data_format_, output); + + // Reshape the output to preserve original batch dimensions. + if (batch_dims_ != 1) { + std::vector reshape_vect(batch_dims_); + for (int i = 0; i < batch_dims_; ++i) { + reshape_vect[i] = original_input_shape.dim_size(i); + } + for (int i = 1; i < out_shape.dims(); ++i) { + reshape_vect.push_back(out_shape.dim_size(i)); + } + TensorShape expanded_out_shape(reshape_vect); + if (!output->CopyFrom(*output, expanded_out_shape)) { + // This should never happen, since the output sizes should always be the + // same after expanding batches. 
+ context->SetStatus(absl::InternalError( + absl::StrCat("Could not expand dimension with flat output shape ", + out_shape.DebugString(), " and expanded output shape ", + expanded_out_shape.DebugString()))); + } + } + } + + private: + std::vector strides_; + Padding padding_; + std::vector explicit_paddings_; + TensorFormat data_format_; + FilterTensorFormat filter_format_; + std::vector dilations_; + int batch_dims_; + int groups_; + bool cudnn_use_autotune_; + + LaunchConvOp launcher_; + + ConvOp(const ConvOp&) = delete; + void operator=(const ConvOp&) = delete; +}; + +template +struct LaunchConv2DOp { + void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, + const Padding& padding, + const std::vector& explicit_paddings, Tensor* output, + TensorFormat data_format) { + if (data_format != FORMAT_NHWC) { + ctx->SetStatus(errors::Unimplemented( + "The Conv2D op currently only supports the NHWC tensor format on the " + "CPU. The op was given the format: ", + ToString(data_format))); + return; + } + + for (int64_t explicit_padding : explicit_paddings) { + if (!FastBoundsCheck(explicit_padding, std::numeric_limits::max())) { + ctx->SetStatus(errors::InvalidArgument("filter too large")); + return; + } + } + + const int64_t in_depth = input.dim_size(3); + const int64_t out_depth = output->dim_size(3); + const int64_t patch_depth = filter.dim_size(2); + + if (patch_depth <= 0) { + ctx->SetStatus(errors::InvalidArgument( + "filter depth must be stricly positive, got ", patch_depth)); + return; + } + if (in_depth % patch_depth != 0) { + ctx->SetStatus(errors::InvalidArgument( + "input depth must be evenly divisible by filter depth: ", in_depth, + " vs ", patch_depth)); + return; + } + if (filter.NumElements() <= 0) { + ctx->SetStatus( + errors::InvalidArgument("filter must not have zero elements " + "(i.e. 
all dimensions must be non-zero)")); + return; + } + + const int64_t num_groups = in_depth / patch_depth; + if (num_groups <= 0) { + ctx->SetStatus(errors::InvalidArgument( + "number of groups must be stricly positive, got ", num_groups)); + return; + } + if (out_depth % num_groups != 0 || out_depth < num_groups) { + ctx->SetStatus(errors::InvalidArgument( + "output depth must be evenly divisible by number of groups: ", + out_depth, " vs ", num_groups)); + return; + } + + if (in_depth != patch_depth) { + LaunchGrouped()(ctx, input, filter, row_stride, col_stride, + row_dilation, col_dilation, padding, explicit_paddings, + output, data_format); + } else { + LaunchGeneric()(ctx, input, filter, row_stride, col_stride, + row_dilation, col_dilation, padding, + explicit_paddings, output, data_format); + } + } +}; +extern template struct LaunchConv2DOp; +extern template struct LaunchConv2DOp; +extern template struct LaunchConv2DOp; +extern template struct LaunchConv2DOp; +extern template struct LaunchConv2DOp; + +template +class LaunchDeepConvOp { + public: + static bool Run(OpKernelContext* ctx, const Tensor& input, + const Tensor& filter, int batch, int input_rows, + int input_cols, int in_depth, int filter_rows, + int filter_cols, int pad_rows, int pad_cols, int out_rows, + int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/, + int /*dilation_cols*/, int /*stride_rows*/, + int /*stride_cols*/, Tensor* /*output*/, + TensorFormat /*data_format*/) { + return false; + } +}; + +template +class Conv2DOp : public BinaryOp { + public: + explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp(context) { + OP_REQUIRES_OK(context, InitConv2DParameters(context, ¶ms_)); + + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); + cudnn_use_autotune_ = CudnnUseAutotune(); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + const Tensor& input = context->input(0); + + // Input filter is of the following dimensions: + // [ filter_rows, filter_cols, in_depth, out_depth] + const Tensor& filter = context->input(1); + + Conv2DDimensions dimensions; + OP_REQUIRES_OK(context, + ComputeConv2DDimension(params_, input, filter, &dimensions)); + + TensorShape out_shape; + OP_REQUIRES_OK( + context, ShapeFromFormatWithStatus( + params_.data_format, dimensions.batch, dimensions.out_rows, + dimensions.out_cols, dimensions.out_depth, &out_shape)); + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ] + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth + << ", patch_depth = " << dimensions.patch_depth + << ", input_cols = " << dimensions.input_cols + << ", filter_cols = " << dimensions.filter_cols + << ", input_rows = " << dimensions.input_rows + << ", filter_rows = " << dimensions.filter_rows + << ", stride_rows = " << dimensions.stride_rows + << ", stride_cols = " << dimensions.stride_cols + << ", dilation_rows = " << dimensions.dilation_rows + << ", dilation_cols = " << dimensions.dilation_cols + << ", out_depth = " << dimensions.out_depth; + + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } + + // If the input is empty, result can only be due to padding. + if (input.NumElements() == 0) { + // Zero-out output and return. 
+ functor::SetZeroFunctor()(context->eigen_device(), + output->template flat()); + + return; + } + + if (params_.padding != EXPLICIT && + LaunchDeepConvOp::Run( + context, input, filter, dimensions.batch, dimensions.input_rows, + dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows, + dimensions.filter_cols, dimensions.pad_rows_before, + dimensions.pad_cols_before, dimensions.out_rows, + dimensions.out_cols, dimensions.out_depth, dimensions.dilation_rows, + dimensions.dilation_cols, dimensions.stride_rows, + dimensions.stride_cols, output, params_.data_format)) { + return; + } + + launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, params_.padding, + params_.explicit_paddings, output, params_.data_format); + } + + private: + Conv2DParameters params_; + bool use_cudnn_; + bool cudnn_use_autotune_; + + LaunchConv2DOp launcher_; + + Conv2DOp(const Conv2DOp&) = delete; + void operator=(const Conv2DOp&) = delete; +}; +extern template struct Conv2DOp; +extern template struct Conv2DOp; +extern template struct Conv2DOp; +extern template struct Conv2DOp; +extern template struct Conv2DOp; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +void LaunchConvOpImpl(OpKernelContext* context, bool cudnn_use_autotune, + const Tensor& input_param, const Tensor& filter, + const gtl::InlinedVector& dilations, + const gtl::InlinedVector& strides, + const Padding& padding, + const std::vector& explicit_paddings, + TensorFormat data_format, Tensor* output) { + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, absl::InternalError("No GPU stream available.")); + + Tensor input = input_param; + + int spatial_dims = input.dims() - 2; + std::vector in_dims(spatial_dims); + + const int64_t in_batch = GetTensorDim(input, data_format, 'N'); + for (int i = 0; i < spatial_dims; ++i) { + in_dims[i] = GetTensorDim(input, data_format, static_cast('0' + i)); + } + const int64_t in_depth = GetTensorDim(input, data_format, 'C'); + + std::vector filter_dims(spatial_dims); + for (int i = 0; i < spatial_dims; ++i) { + filter_dims[i] = filter.dim_size(i); + } + const int64_t filter_depth = filter.dim_size(spatial_dims); + const int64_t out_depth = filter.dim_size(spatial_dims + 1); + + OP_REQUIRES( + context, filter.NumElements() > 0, + absl::InvalidArgumentError("filter must not have zero elements " + "(i.e. all dimensions must be non-zero)")); + + bool is_grouped_convolution = filter_depth != in_depth; + // check if filter is 1x1 and stride/dilation are all ones + bool one_filter = true; + bool one_dilations = true; + bool one_stride = true; + for (int i = 0; i < spatial_dims; ++i) { + one_filter = one_filter && (filter_dims[i] == 1); + one_dilations = one_dilations && (dilations[i] == 1); + one_stride = one_stride && (strides[i] == 1); + } + // check if filter is same spatial shape as input + bool filter_same_dims = true; + for (int i = 0; i < spatial_dims; ++i) { + if (filter_dims[i] != in_dims[i]) filter_same_dims = false; + } + + auto* blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No BLAS for stream.")); + if (!is_grouped_convolution && one_filter && one_dilations && one_stride && + data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) { + // 1x1 filter, so call cublas directly. 
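Both direct-GEMM fast paths below boil down to choosing the matrix dimensions m, k, n handed to the BLAS call. A standalone sketch of that arithmetic under the same assumptions (NHWC layout, no grouping); GemmDims and the two helpers are illustrative names only:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

struct GemmDims { std::uint64_t m, k, n; };

// 1x1 filter with unit strides and dilations: every spatial position is an
// independent row, so m = batch * prod(spatial), k = in_depth, n = out_depth.
GemmDims GemmDimsFor1x1(std::int64_t batch,
                        const std::vector<std::int64_t>& in_dims,
                        std::int64_t in_depth, std::int64_t out_depth) {
  const std::int64_t spatial = std::accumulate(
      in_dims.begin(), in_dims.end(), std::int64_t{1}, std::multiplies<>{});
  return {static_cast<std::uint64_t>(batch * spatial),
          static_cast<std::uint64_t>(in_depth),
          static_cast<std::uint64_t>(out_depth)};
}

// Filter spatially equal to the input under VALID padding: each image collapses
// to a single row, so m = batch, k = prod(spatial) * in_depth, n = out_depth.
GemmDims GemmDimsForFullSpatialFilter(std::int64_t batch,
                                      const std::vector<std::int64_t>& in_dims,
                                      std::int64_t in_depth,
                                      std::int64_t out_depth) {
  const std::int64_t spatial = std::accumulate(
      in_dims.begin(), in_dims.end(), std::int64_t{1}, std::multiplies<>{});
  return {static_cast<std::uint64_t>(batch),
          static_cast<std::uint64_t>(spatial * in_depth),
          static_cast<std::uint64_t>(out_depth)};
}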
+ const uint64 m = in_batch * std::accumulate(in_dims.begin(), in_dims.end(), + 1, std::multiplies<>{}); + const uint64 k = in_depth; + const uint64 n = out_depth; + + auto a_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(output->template flat().data(), + output->template flat().size()); + + auto no_transpose = se::blas::Transpose::kNoTranspose; + OP_REQUIRES_OK(context, blas->BlasGemm(stream, no_transpose, no_transpose, + n, m, k, b_ptr, n, a_ptr, k, &c_ptr, + n, GetNumericOptions(), + se::blas::CallContext::kNone)); + return; + } else if (!is_grouped_convolution && filter_same_dims && padding == VALID && + data_format == FORMAT_NHWC) { + // The input data and filter have the same spatial dimensions, so call + // cublas directly. + const uint64 m = in_batch; + const uint64 k = in_depth * std::accumulate(in_dims.begin(), in_dims.end(), + 1, std::multiplies<>{}); + const uint64 n = out_depth; + + auto a_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto b_ptr = AsDeviceMemory(filter.template flat().data(), + filter.template flat().size()); + auto c_ptr = AsDeviceMemory(output->template flat().data(), + output->template flat().size()); + + auto no_transpose = se::blas::Transpose::kNoTranspose; + OP_REQUIRES_OK(context, blas->BlasGemm(stream, no_transpose, no_transpose, + n, m, k, b_ptr, n, a_ptr, k, &c_ptr, + n, GetNumericOptions(), + se::blas::CallContext::kNone)); + return; + } + + const bool compute_in_nhwc = ComputeInNhwcEnabled( + DataTypeToEnum::value, stream, /*use_4d_tensor=*/(spatial_dims == 2)); + const TensorFormat compute_data_format = + (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC + : FORMAT_NCHW; + + VLOG(3) << "Compute Conv with cuDNN:" + << " data_format=" << ToString(data_format) + << " compute_data_format=" << ToString(compute_data_format); + + std::vector out_dims(output->dims()); + for (int i = 0; i < output->dims(); ++i) { + out_dims[i] = output->dim_size(i); + } + std::vector> paddings(spatial_dims, {-1, -1}); + // Explicit only on 2D case. + if (padding == EXPLICIT) { + GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', + &paddings[0].first, &paddings[0].second); + GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', + &paddings[1].first, &paddings[1].second); + } + + // Get padding values, output should be valid, since it was checked before. + std::vector out_dims_check(spatial_dims); + for (int i = 0; i < spatial_dims; ++i) { + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + in_dims[i], filter_dims[i], dilations[i], + strides[i], padding, &out_dims_check[i], + &paddings[i].first, &paddings[i].second)); + OP_REQUIRES(context, + (out_dims_check[i] == GetTensorDim(*output, data_format, + static_cast('0' + i))), + absl::InternalError("Output dimension doesn't match yo")); + } + + bool assymmetric_padding = false; + std::vector common_padding(spatial_dims); + for (int i = 0; i < spatial_dims; ++i) { + common_padding[i] = std::min(paddings[i].first, paddings[i].second); + assymmetric_padding = + assymmetric_padding || (paddings[i].first != paddings[i].second); + } + + if (assymmetric_padding) { + // cuDNN only supports padding the same amount on either side. So we + // manually create a new padded input tensor. 
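The asymmetric-padding workaround below splits each (before, after) pair into a symmetric part that cuDNN can handle plus a residual that is materialized by explicitly padding the input. A standalone sketch of that split, with PaddingSplit and SplitAsymmetricPadding as illustrative names:

#include <algorithm>
#include <cstdint>
#include <cstdlib>

struct PaddingSplit {
  std::int64_t common;        // symmetric padding handed to cuDNN
  std::int64_t extra_before;  // applied by explicitly padding the input
  std::int64_t extra_after;
  std::int64_t padded_dim;    // spatial size after the explicit padding
};

// cuDNN gets min(before, after) on both sides; the remaining |before - after|
// is baked into a padded copy of the input along this dimension.
PaddingSplit SplitAsymmetricPadding(std::int64_t dim, std::int64_t before,
                                    std::int64_t after) {
  const std::int64_t common = std::min(before, after);
  return {common, before - common, after - common,
          dim + std::abs(before - after)};
}
// e.g. dim = 32 with padding (2, 1): the input is first padded by 1 on the
// "before" side (padded_dim = 33), and cuDNN then pads 1 on each side.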
+ Tensor transformed_input; + std::vector new_in_dims(input.dims()); + new_in_dims[0] = in_batch; + for (int i = 0; i < spatial_dims; ++i) { + int index = GetTensorSpatialDimIndex(input.dims(), data_format, i); + new_in_dims[index] = + in_dims[i] + std::abs(paddings[i].first - paddings[i].second); + } + new_in_dims[GetTensorDimIndex(data_format, 'C', input.dims())] = in_depth; + TensorShape transformed_input_shape(new_in_dims); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_input_shape, + &transformed_input)); + + // Padding to add on transformed input. + std::vector> transformed_input_padding( + paddings); + for (int i = 0; i < spatial_dims; ++i) { + transformed_input_padding[i].first -= common_padding[i]; + transformed_input_padding[i].second -= common_padding[i]; + } + + // Check padding size. + bool padding_bounds_valid = true; + for (int i = 0; i < spatial_dims; ++i) { + padding_bounds_valid = + padding_bounds_valid && + FastBoundsCheck(transformed_input_padding[i].first, + std::numeric_limits::max()) && + FastBoundsCheck(transformed_input_padding[i].second, + std::numeric_limits::max()); + } + OP_REQUIRES(context, padding_bounds_valid, + absl::InvalidArgumentError("Padding is too large.")); + + // Pad new input. + if (input.dims() == 4) { + std::array pad_left{ + static_cast(transformed_input_padding[0].first), + static_cast(transformed_input_padding[1].first)}; + std::array pad_right{ + static_cast(transformed_input_padding[0].second), + static_cast(transformed_input_padding[1].second)}; + functor::PadInput()( + context->eigen_device(), + To32Bit(static_cast(input).tensor()), pad_left, + pad_right, To32Bit(transformed_input.tensor()), data_format, + T{}); + } else if (input.dims() == 5) { + std::array pad_left{ + static_cast(transformed_input_padding[0].first), + static_cast(transformed_input_padding[1].first), + static_cast(transformed_input_padding[2].first)}; + std::array pad_right{ + static_cast(transformed_input_padding[0].second), + static_cast(transformed_input_padding[1].second), + static_cast(transformed_input_padding[2].second)}; + functor::PadInput()( + context->eigen_device(), + To32Bit(static_cast(input).tensor()), pad_left, + pad_right, To32Bit(transformed_input.tensor()), data_format, + T{}); + } else { + context->SetStatus( + absl::InternalError("Failed to pad input, invalid dimensions.")); + } + + input = transformed_input; + for (int i = 0; i < spatial_dims; ++i) { + in_dims[i] = new_in_dims[GetTensorDimIndex( + data_format, static_cast('0' + i), input.dims())]; + } + } + + if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { + VLOG(4) << "Convert the input tensor from NHWC to NCHW."; + + TensorShape channels_first_shape; + OP_REQUIRES_OK(context, + ShapeFromFormatWithStatus(FORMAT_NCHW, in_batch, in_dims, + in_depth, &channels_first_shape)); + + if (in_depth > 1) { + Tensor transformed_input; + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + channels_first_shape, + &transformed_input)); + if (input.dims() == 4) { + functor::NHWCToNCHW()( + context->eigen_device(), + const_cast(input).tensor(), + transformed_input.tensor()); + } else if (input.dims() == 5) { + functor::NHWCToNCHW()( + context->eigen_device(), + const_cast(input).tensor(), + transformed_input.tensor()); + } else { + context->SetStatus( + absl::InternalError("Failed to reshape input to channels first " + "format, invalid dimensions.")); + } + input = transformed_input; + } else { + // Depth = 1, reshape. 
+ if (!input.CopyFrom(input, channels_first_shape)) { + context->SetStatus(absl::InternalError( + "Failed to reshape input to channels first format.")); + } + } + } else { + DCHECK(data_format == compute_data_format) // Crash OK. + << "Illegal data and compute format pair:" + << " data_format=" << ToString(data_format) + << " compute_data_format=" << ToString(compute_data_format); + } + + // Check paddings are not negative. + bool non_negative_paddings = true; + for (int i = 0; i < spatial_dims; ++i) { + non_negative_paddings = non_negative_paddings && common_padding[i] >= 0; + } + OP_REQUIRES(context, non_negative_paddings, + absl::InvalidArgumentError("Padding is negative.")); + + constexpr auto kComputeInNHWC = + std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, + se::dnn::FilterLayout::kOutputYXInput); + constexpr auto kComputeInNCHW = + std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, + se::dnn::FilterLayout::kOutputInputYX); + + se::dnn::DataLayout compute_data_layout; + se::dnn::FilterLayout filter_layout; + + std::tie(compute_data_layout, filter_layout) = + compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW; + + se::dnn::BatchDescriptor input_desc(spatial_dims); + input_desc.set_count(in_batch).set_feature_map_count(in_depth).set_layout( + compute_data_layout); + if (spatial_dims == 2) { + input_desc.set_spatial_dim(stream_executor::dnn::DimIndex::X, in_dims[1]) + .set_spatial_dim(stream_executor::dnn::DimIndex::Y, in_dims[0]); + } else if (spatial_dims == 3) { + input_desc.set_spatial_dim(stream_executor::dnn::DimIndex::X, in_dims[2]) + .set_spatial_dim(stream_executor::dnn::DimIndex::Y, in_dims[1]) + .set_spatial_dim(stream_executor::dnn::DimIndex::Z, in_dims[0]); + } else { + context->SetStatus( + absl::InternalError("Failed to set Input Descripitor:" + " invalid number of spatial dimensions")); + } + + se::dnn::BatchDescriptor output_desc(spatial_dims); + output_desc.set_count(GetTensorDim(*output, data_format, 'N')) + .set_feature_map_count(GetTensorDim(*output, data_format, 'C')) + .set_layout(compute_data_layout); + if (spatial_dims == 2) { + output_desc + .set_spatial_dim( + stream_executor::dnn::DimIndex::X, + GetTensorDim(*output, data_format, static_cast('1'))) + .set_spatial_dim( + stream_executor::dnn::DimIndex::Y, + GetTensorDim(*output, data_format, static_cast('0'))); + } else if (spatial_dims == 3) { + output_desc + .set_spatial_dim( + stream_executor::dnn::DimIndex::X, + GetTensorDim(*output, data_format, static_cast('2'))) + .set_spatial_dim( + stream_executor::dnn::DimIndex::Y, + GetTensorDim(*output, data_format, static_cast('1'))) + .set_spatial_dim( + stream_executor::dnn::DimIndex::Z, + GetTensorDim(*output, data_format, static_cast('0'))); + } else { + context->SetStatus( + absl::InternalError("Failed to set Output Descripitor: invalid " + "number of spatial dimensions")); + } + + se::dnn::FilterDescriptor filter_desc(spatial_dims); + filter_desc.set_input_feature_map_count(filter_depth) + .set_output_feature_map_count(out_depth) + .set_layout(filter_layout); + if (spatial_dims == 2) { + filter_desc + .set_spatial_dim(stream_executor::dnn::DimIndex::X, filter_dims[1]) + .set_spatial_dim(stream_executor::dnn::DimIndex::Y, filter_dims[0]); + } else if (spatial_dims == 3) { + filter_desc + .set_spatial_dim(stream_executor::dnn::DimIndex::X, filter_dims[2]) + .set_spatial_dim(stream_executor::dnn::DimIndex::Y, filter_dims[1]) + .set_spatial_dim(stream_executor::dnn::DimIndex::Z, filter_dims[0]); + } else { + context->SetStatus( + 
absl::InternalError("Failed to set Filter Descripitor: invalid " + "number of spatial dimensions")); + } + + se::dnn::ConvolutionDescriptor conv_desc(spatial_dims); + if (spatial_dims == 2) { + conv_desc.set_dilation_rate(stream_executor::dnn::DimIndex::X, dilations[1]) + .set_dilation_rate(stream_executor::dnn::DimIndex::Y, dilations[0]) + .set_filter_stride(stream_executor::dnn::DimIndex::X, strides[1]) + .set_filter_stride(stream_executor::dnn::DimIndex::Y, strides[0]) + .set_zero_padding(stream_executor::dnn::DimIndex::X, common_padding[1]) + .set_zero_padding(stream_executor::dnn::DimIndex::Y, common_padding[0]); + } else if (spatial_dims == 3) { + conv_desc.set_dilation_rate(stream_executor::dnn::DimIndex::X, dilations[2]) + .set_dilation_rate(stream_executor::dnn::DimIndex::Y, dilations[1]) + .set_dilation_rate(stream_executor::dnn::DimIndex::Z, dilations[0]) + .set_filter_stride(stream_executor::dnn::DimIndex::X, strides[2]) + .set_filter_stride(stream_executor::dnn::DimIndex::Y, strides[1]) + .set_filter_stride(stream_executor::dnn::DimIndex::Z, strides[0]) + .set_zero_padding(stream_executor::dnn::DimIndex::X, common_padding[2]) + .set_zero_padding(stream_executor::dnn::DimIndex::Y, common_padding[1]) + .set_zero_padding(stream_executor::dnn::DimIndex::Z, common_padding[0]); + } else { + context->SetStatus( + absl::InternalError("Failed to set Convolution Descripitor: invalid " + "number of spatial dimensions")); + } + conv_desc.set_group_count(1); + // TODO(b/290223810) Change group count when implementing group/depthwise. + Tensor transformed_filter; + auto dst_format = + compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; + VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO) << " to " + << ToString(dst_format); + std::vector dst_shape_vec(spatial_dims + 2); + dst_shape_vec[0] = out_depth; + if (dst_format == FORMAT_OIHW) { + dst_shape_vec[1] = filter_depth; + for (int i = 2; i < filter.dims(); ++i) { + dst_shape_vec[i] = filter_dims[i - 2]; + } + } else { + // Format OHWI + dst_shape_vec[filter.dims() - 1] = filter_depth; + for (int i = 1; i < filter.dims() - 1; ++i) { + dst_shape_vec[i] = filter_dims[i - 1]; + } + } + TensorShape dst_shape(dst_shape_vec); + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, dst_shape, + &transformed_filter)); + + // Filter: [(spatial_dims), in, out] (HWIO) + // T_filter: [out, in, (spatial_dims)] (OIHW) or + // T_filter: [out, (spatial_dims), in] (OHWI) + if (spatial_dims == 2) { + functor::TransformFilter()( + context->eigen_device(), dst_format, + To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + } else if (spatial_dims == 3) { + functor::TransformFilter()( + context->eigen_device(), dst_format, + To32Bit(filter.tensor()), + To32Bit(transformed_filter.tensor())); + } else { + context->SetStatus(absl::InternalError( + "Failed to reshape filter, invalid spatial dimensions.")); + } + + Tensor transformed_output; + if (data_format != compute_data_format) { + VLOG(4) << "Allocate temporary memory for output in compute data format"; + TensorShape transformed_output_shape; + OP_REQUIRES_OK(context, ShapeFromFormatWithStatus( + FORMAT_NCHW, in_batch, out_dims_check, + out_depth, &transformed_output_shape)); + OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, + transformed_output_shape, + &transformed_output)); + } else { + transformed_output = *output; + } + + auto input_ptr = AsDeviceMemory(input.template flat().data(), + input.template flat().size()); + auto 
filter_ptr = + AsDeviceMemory(transformed_filter.template flat().data(), + transformed_filter.template flat().size()); + auto output_ptr = + AsDeviceMemory(transformed_output.template flat().data(), + transformed_output.template flat().size()); + + static int64_t ConvolveScratchSize = GetDnnWorkspaceLimitOrDefault(); + + if (spatial_dims == 2) { + filter_dims.push_back(filter_depth); + } + ConvParameters conv_parameters = { + stream->parent(), + in_batch, // batch + in_depth, // in_depths + in_dims, // input spatial dims + compute_data_format, // compute_data_format + out_depth, // out_depths + filter_dims, // filter spatial dims + dilations, // dilations + strides, // strides + common_padding, // paddings (symmetrical) + input.dtype(), // tensor datatype + conv_desc.group_count(), + }; + + auto entry_or = AutotuneUnfusedConv( + cudnn_use_autotune, ConvAutotuneMap::GetInstance(), conv_parameters, + context, se::dnn::ConvolutionKind::FORWARD, input_desc, input_ptr, + filter_desc, filter_ptr, conv_desc, output_desc, output_ptr, + ConvolveScratchSize); + OP_REQUIRES_OK(context, entry_or.status()); + auto autotune_entry = std::move(entry_or).value(); + + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, context); + Status cudnn_launch_status = LaunchAutotunedConv( + autotune_entry, &scratch_allocator, se::dnn::ConvolutionKind::FORWARD, + stream, input_desc, input_ptr, filter_desc, filter_ptr, conv_desc, + output_desc, output_ptr); + if (!cudnn_launch_status.ok()) { + context->SetStatus(cudnn_launch_status); + return; + } + + if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { + VLOG(4) << "Convert the output tensor back from NCHW to NHWC."; + if (spatial_dims == 2) { + functor::NCHWToNHWC()( + context->eigen_device(), + const_cast(transformed_output).tensor(), + output->tensor()); + } else if (spatial_dims == 3) { + functor::NCHWToNHWC()( + context->eigen_device(), + const_cast(transformed_output).tensor(), + output->tensor()); + } else { + context->SetStatus(absl::InternalError( + "Failed to convert output data foramt, invalid spatial dimensions.")); + } + } +} + +template +void LaunchConvOp::operator()( + OpKernelContext* context, bool cudnn_use_autotune, const Tensor& input, + const Tensor& filter, const std::vector& dilations, + const std::vector& strides, const Padding padding, + const std::vector& explicit_paddings, TensorFormat data_format, + Tensor* output) { + // Get spatial dims for dilations and strides. + int spatial_dims = input.dims() - 2; + gtl::InlinedVector strides_spatial(spatial_dims); + gtl::InlinedVector dilations_spatial(spatial_dims); + for (int i = 0; i < spatial_dims; ++i) { + strides_spatial[i] = + GetTensorDim(strides, data_format, static_cast(i + '0')); + dilations_spatial[i] = + GetTensorDim(dilations, data_format, static_cast(i + '0')); + } + LaunchConvOpImpl(context, cudnn_use_autotune, input, filter, + dilations_spatial, strides_spatial, padding, + explicit_paddings, data_format, output); +} + +template +void LaunchConv2DOp::operator()( + OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input_param, const Tensor& filter, int row_dilation, + int col_dilation, int row_stride, int col_stride, const Padding& padding, + const std::vector& explicit_paddings, Tensor* output, + TensorFormat data_format) { + // Cast strides and dilations. 
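// This 2-D entry point only widens its scalar row/col stride and dilation
// arguments into the {row, col} vectors expected by the rank-generic
// LaunchConvOpImpl; the real work (padding, layout conversion, autotuning,
// launch) all happens there.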
+ gtl::InlinedVector casted_strides = {row_stride, col_stride}; + gtl::InlinedVector casted_dilations = {row_dilation, + col_dilation}; + LaunchConvOpImpl(ctx, cudnn_use_autotune, input_param, filter, + casted_dilations, casted_strides, padding, + explicit_paddings, data_format, output); +} + +// To be used inside depthwise_conv_op.cc. +extern template struct LaunchConv2DOp; +// extern template struct LaunchConv2DOp; +extern template struct LaunchConv2DOp; +extern template struct LaunchConv2DOp; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cross_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cross_op.h new file mode 100644 index 00000000..cf5956ac --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cross_op.h @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CROSS_OP_H_ +#define TENSORFLOW_CORE_KERNELS_CROSS_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +namespace functor { + +template +struct Cross { + void operator()(const Device &d, + typename TTypes::ConstTensor in0_data, + typename TTypes::ConstTensor in1_data, + typename TTypes::Tensor output_data) { + auto s1 = output_data.template chip<1>(0); + auto s2 = output_data.template chip<1>(1); + auto s3 = output_data.template chip<1>(2); + + auto u1 = in0_data.template chip<1>(0); + auto u2 = in0_data.template chip<1>(1); + auto u3 = in0_data.template chip<1>(2); + + auto v1 = in1_data.template chip<1>(0); + auto v2 = in1_data.template chip<1>(1); + auto v3 = in1_data.template chip<1>(2); + + s1.device(d) = u2 * v3 - u3 * v2; + s2.device(d) = u3 * v1 - u1 * v3; + s3.device(d) = u1 * v2 - u2 * v1; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CROSS_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cudnn_pooling_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cudnn_pooling_gpu.h new file mode 100644 index 00000000..970eb533 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cudnn_pooling_gpu.h @@ -0,0 +1,70 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions to run 3d pooling on GPU using CuDNN. + +#ifndef TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/stream_executor.h" +#endif + +#include "tensorflow/core/util/padding.h" + +namespace tensorflow { + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// Runs (avg/max)pooling on GPU. +// Dimension order for all array arguments is: x, y, z. +template +class DnnPooling3dOp { + public: + static void Compute(OpKernelContext* context, + se::dnn::PoolingMode pooling_mode, + const std::array& size, + const std::array& stride, + const std::array& padding, + TensorFormat data_format, const Tensor& tensor_in, + Tensor* output); +}; + +// Computes the gradient of (avg/max)pooling on GPU. +// Dimension order for all array arguments is: x, y, z. +template +class DnnPooling3dGradOp { + public: + static void Compute(OpKernelContext* context, + se::dnn::PoolingMode pooling_mode, + const std::array& window, + const std::array& stride, + const std::array& padding, + const std::array& output_size, + TensorFormat data_format, const Tensor& out_backprop, + const TensorShape& tensor_in_shape, + const Tensor* tensor_in, const Tensor* tensor_out, + Tensor* input_backprop); +}; + +#endif + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_op_clip.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_op_clip.h new file mode 100644 index 00000000..171b6932 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_op_clip.h @@ -0,0 +1,61 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_ +#define TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +namespace functor { +// Unary functor for clip [Tensor, Scalar, Scalar] +template +struct UnaryClipOp { + void operator()(const Device &d, typename TTypes::ConstFlat &in0_flat, + typename TTypes::ConstFlat &in1_flat, + typename TTypes::ConstFlat &in2_flat, + typename TTypes::Flat &out_flat) const; +}; + +// Binary functor for clip [Tensor, Scalar, Tensor] +template +struct BinaryRightClipOp { + void operator()(const Device &d, typename TTypes::ConstFlat &in0_flat, + typename TTypes::ConstFlat &in1_flat, + typename TTypes::ConstFlat &in2_flat, + typename TTypes::Flat &out_flat) const; +}; + +// Binary functor for clip [Tensor, Tensor, Scalar] +template +struct BinaryLeftClipOp { + void operator()(const Device &d, typename TTypes::ConstFlat &in0_flat, + typename TTypes::ConstFlat &in1_flat, + typename TTypes::ConstFlat &in2_flat, + typename TTypes::Flat &out_flat) const; +}; + +// Ternary functor for clip [Tensor, Tensor, Tensor] +template +struct TernaryClipOp { + void operator()(const Device &d, typename TTypes::ConstFlat &in0_flat, + typename TTypes::ConstFlat &in1_flat, + typename TTypes::ConstFlat &in2_flat, + typename TTypes::Flat &out_flat) const; +}; +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CWISE_OP_CLIP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops.h new file mode 100644 index 00000000..06d75372 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops.h @@ -0,0 +1,1340 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ + +#define _USE_MATH_DEFINES +#include +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace Eigen { +namespace internal { + +#if GOOGLE_CUDA +template <> +struct scalar_arg_op> { + typedef typename Eigen::NumTraits>::Real result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()( + const std::complex& a) const { + return ::atan2f(a.imag(), a.real()); + } +}; + +template <> +struct scalar_arg_op> { + typedef typename Eigen::NumTraits>::Real result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double operator()( + const std::complex& a) const { + return ::atan2(a.imag(), a.real()); + } +}; +#endif + +template +struct safe_scalar_binary_pow_op { + static_assert(std::is_integral::value, "Integer type expected"); + static_assert(std::is_integral::value && + std::is_signed::value, + "Signed integer type expected"); + + bool* const error; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) + : error(error) {} + + EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, + const Exponent& b) const { + const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); + if (TF_PREDICT_TRUE(safe_b >= 0)) { + return numext::pow(a, safe_b); + } else { + *error = true; + return 0; + } + } +}; + +template +struct functor_traits> { + enum { Cost = 5 * NumTraits::MulCost, PacketAccess = false }; +}; + +template +struct safe_div_or_mod_op { + static_assert(std::is_integral::value, "Integer type expected"); + + bool* const error; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error) + : error(error) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, + const T& b) const { + const T safe_b = tensorflow::internal::SubtleMustCopy(b); + if (TF_PREDICT_TRUE(safe_b != 0)) { + // Avoid FPE for INT_MIN/-1. + const T safe_a = tensorflow::internal::SubtleMustCopy(a); + if (TF_PREDICT_FALSE(std::is_signed::value && + safe_a == std::numeric_limits::min() && + safe_b == T(-1))) { + // Prefer to overflow 'a' instead of crashing. + return DivOrMod()(-safe_a, 1); + } + return DivOrMod()(safe_a, safe_b); + } else { + *error = true; + return 0; + } + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits::Cost + NumTraits::AddCost, + PacketAccess = false, + }; +}; + +template +struct no_nan_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, + const T& b) const { + if (b != T(0)) { + return Binary()(a, b); + } else { + return T(0); + } + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, + const Packet& b) const { + const Packet mask = pcmp_eq(b, pzero(b)); + const Packet quotient = Binary().packetOp(a, b); + return pandnot(quotient, mask); + } +}; + +template ::IsComplex> +struct div_no_nan_op; + +template +struct div_no_nan_op + : public no_nan_op> { +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + NumTraits::AddCost, + PacketAccess = true, + }; +}; + +// Whether or not complex division produces a NaN depends on the underlying +// implementation. Some compilers (e.g. 
gcc) use a simple method that divides +// by |b|^2, which may underflow to 0 for b != 0. +template +struct div_no_nan_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, + const T& b) const { + if (b == T(0)) { + return T(0); + } else { + // If the numerator is zero, then the result must be zero even if |b|^2 + // underflows to zero. + const T numerator = + scalar_product_op()(a, scalar_conjugate_op()(b)); + if (numerator == T(0)) { + return T(0); + } + } + return scalar_quotient_op()(a, b); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, + const Packet& b) const { + const Packet numerator = pmul(a, pconj(b)); + const Packet mask = por(pcmp_eq(b, pzero(a)), pcmp_eq(numerator, pzero(a))); + const Packet quotient = pdiv(a, b); + return pandnot(quotient, mask); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + NumTraits::MulCost, + PacketAccess = packet_traits::HasMul && packet_traits::HasDiv && + packet_traits::HasConj, + }; +}; + +template +struct mul_no_nan_op : public no_nan_op> { +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + NumTraits::AddCost, + PacketAccess = true, + }; +}; + +// scalar_left and scalar_right are template helpers to partially +// apply a binary function. +// +// Suppose Binary is a binary functor f(x, y), scalar_left<> is a +// unary functor g_x(y) = f(x, y), where x is provided via the +// constructor. Similarly, scalar_right<> is a unary functor g_y(x) = +// f(x, y). + +template +struct scalar_left : private Binary { + using result_type = Tout; + + const Tin* left; + + inline scalar_left(const scalar_left& other) = default; + + template + EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args) + : Binary(args...), left(c) {} + + EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const { + return Binary::operator()(*left, right); + } + + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const { + return Binary::packetOp(Eigen::internal::pset1(*left), + right_packet); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits::Cost, + PacketAccess = functor_traits::PacketAccess, + }; +}; + +template +struct scalar_right : private Binary { + using result_type = Tout; + + const Tin* right; + + inline scalar_right(const scalar_right& other) = default; + + template + EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... 
args) + : Binary(args...), right(c) {} + + EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const { + return Binary::operator()(left, *right); + } + + template + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const { + return Binary::packetOp(left_packet, + Eigen::internal::pset1(*right)); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits::Cost, + PacketAccess = functor_traits::PacketAccess, + }; +}; + +// similar to std::equal_to, but with the DEVICE_FUNC qualifier +template +struct equal_to : std::function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, + const T& y) const { + return x == y; + } +}; + +// similar to std::not_equal_to, but with the DEVICE_FUNC qualifier +template +struct not_equal_to : std::function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, + const T& y) const { + return x != y; + } +}; + +// similar to std::greater, but with the DEVICE_FUNC qualifier +template +struct greater : std::function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, + const T& y) const { + return x > y; + } +}; + +// similar to std::less, but with the DEVICE_FUNC qualifier +template +struct less : std::function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, + const T& y) const { + return x < y; + } +}; + +// similar to std::greater_equal, but with the DEVICE_FUNC qualifier +template +struct greater_equal : std::function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, + const T& y) const { + return x >= y; + } +}; + +// similar to std::less_equal, but with the DEVICE_FUNC qualifier +template +struct less_equal : std::function { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, + const T& y) const { + return x <= y; + } +}; + +// Functor that enables squared difference functor. +template +struct scalar_squared_difference_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& a, const Scalar& b) const { + const Scalar v = scalar_difference_op()(a, b); + return scalar_product_op()(v, scalar_conjugate_op()(v)); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, + const Packet& b) const { + const Packet v = scalar_difference_op().packetOp(a, b); + return scalar_product_op().packetOp( + v, scalar_conjugate_op().packetOp(v)); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + + functor_traits>::Cost + + functor_traits>::Cost, + PacketAccess = functor_traits>::PacketAccess && + functor_traits>::PacketAccess && + functor_traits>::PacketAccess + }; +}; + +// TODO(b/32239616): This kernel should be moved into Eigen and vectorized. +template +struct google_floor_div { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + const T z = x / y; + // Subtract one if there is a remainder and if the inputs have opposite + // signs. This approach avoids unnecessary overflows. + return z * y != x && (x < T(0) != y < T(0)) ? 
z - T(1) : z; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + Packet zeros = pzero(x); + Packet x_mask = pcmp_lt(x, zeros); + Packet y_mask = pcmp_lt(y, zeros); + Packet x_div_y = pdiv(x, y); + Packet x_div_y_times_y = pmul(x_div_y, y); + return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y, + psub(x_div_y, pones(x))); + } +}; + +template +struct google_floor_div< + T, typename std::enable_if::value>::type> { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + return x / y; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + return pdiv(x, y); + } +}; + +template +struct functor_traits> { + enum { + Cost = 2 * Eigen::internal::scalar_div_cost< + Scalar, packet_traits::HasDiv>::value + + NumTraits::AddCost, + PacketAccess = packet_traits::HasDiv + }; +}; + +template +struct google_floor_div_real { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + return Eigen::numext::floor(x / y); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + return pfloor(pdiv(x, y)); + } +}; + +template +struct functor_traits> { + enum { + Cost = 2 * Eigen::internal::scalar_div_cost< + Scalar, packet_traits::HasDiv>::value + + 2 * NumTraits::AddCost, + PacketAccess = + packet_traits::HasDiv && packet_traits::HasRound + }; +}; + +// TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. +template +struct google_floor_fmod { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL); + T trunc_mod = scalar_fmod_op()(x, y); + return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y + : trunc_mod; + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + + NumTraits::AddCost, + PacketAccess = false + }; +}; + +// TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. +template +struct google_floor_mod { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL); + T trunc_mod = Eigen::internal::scalar_mod2_op()(x, y); + return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? 
trunc_mod + y + : trunc_mod; + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + + NumTraits::AddCost, + PacketAccess = false + }; +}; + +template +struct google_truncate_div_real { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + EIGEN_USING_STD(trunc) + return static_cast(trunc(x / y)); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + const Packet z = pdiv(x, y); + return pselect(pcmp_lt(z, pzero(z)), pceil(z), pfloor(z)); + } +}; + +template +struct functor_traits> { + enum { + Cost = 2 * Eigen::internal::scalar_div_cost< + Scalar, packet_traits::HasDiv>::value + + 3 * NumTraits::AddCost, + PacketAccess = packet_traits::HasDiv && + packet_traits::HasRound && + packet_traits::HasCmp + }; +}; + +#if EIGEN_COMP_GNUC && __cplusplus > 199711L +#define DISABLE_FLOAT_EQUALITY_WARNING \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") +#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") +#else +#define DISABLE_FLOAT_EQUALITY_WARNING +#define ENABLE_FLOAT_EQUALITY_WARNING +#endif + +template ::IsInteger, + bool HasRint = packet_traits::HasRound> +struct scalar_round_half_to_even_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + EIGEN_STATIC_ASSERT((!NumTraits::IsComplex), + NUMERIC_TYPE_MUST_BE_REAL) + + const Scalar round_val = Eigen::numext::floor(x + Scalar(0.5)); + const Scalar fraction = round_val - x; + if (TF_PREDICT_FALSE(fraction == Scalar(.5))) { + return Scalar(2) * Eigen::numext::floor(Scalar(.5) * x + Scalar(0.5)); + } else { + return round_val; + } + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + Packet half = pset1(Scalar(0.5)); + Packet round_val = pfloor(padd(x, half)); + Packet fraction = psub(round_val, x); + Packet half_mask = pcmp_eq(fraction, half); + bool any_halves = predux_any(half_mask); + if (TF_PREDICT_FALSE(any_halves)) { + Packet two = pset1(Scalar(2)); + Packet nearest_even = pmul(two, pfloor(pmadd(half, x, half))); + return pselect(half_mask, nearest_even, round_val); + } else { + return round_val; + } + } +}; + +template +struct scalar_round_half_to_even_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + return x; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + return x; + } +}; + +template +struct scalar_round_half_to_even_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + return Eigen::numext::rint(x); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + return print(x); + } +}; + +template +struct functor_traits> { + enum { + Cost = Eigen::NumTraits::IsInteger ? 
0 + : 4 * NumTraits::AddCost, + PacketAccess = packet_traits::HasRound && + packet_traits::HasAdd && + packet_traits::HasMul, + }; +}; + +template ::IsInteger> +struct scalar_round_up_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + EIGEN_STATIC_ASSERT((!NumTraits::IsComplex), + NUMERIC_TYPE_MUST_BE_REAL) + return Eigen::numext::floor(x + Scalar(0.5)); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + return pfloor(padd(x, pset1(0.5))); + } +}; + +template +struct scalar_round_up_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + return x; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + return x; + } +}; + +template +struct functor_traits> { + enum { + Cost = IsInteger ? 0 : 4 * NumTraits::AddCost, + PacketAccess = IsInteger || packet_traits::HasRound + }; +}; + +#undef ENABLE_FLOAT_EQUALITY_WARNING +#undef DISABLE_FLOAT_EQUALITY_WARNING + +template +struct bitwise_xor_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x, const Scalar& y) const { + return x ^ y; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, + const Packet& b) const { + return Eigen::internal::pxor(a, b); + } +}; + +template +struct functor_traits> { + enum { Cost = Eigen::NumTraits::AddCost, PacketAccess = true }; +}; + +template +struct xlogy_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x * numext::log(y); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + Packet zeros = pzero(x); + Packet mask = pcmp_eq(x, zeros); + scalar_log_op log_op; + Packet log_y = log_op.packetOp(y); + Packet x_log_y = pmul(x, log_y); + return pselect(mask, x, x_log_y); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + + Eigen::NumTraits::MulCost, + PacketAccess = functor_traits>::PacketAccess + }; +}; + +template +struct xlog1py_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x * numext::log1p(y); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + Packet zeros = pzero(x); + Packet mask = pcmp_eq(x, zeros); + scalar_log1p_op log1p_op; + Packet log1p_y = log1p_op.packetOp(y); + Packet x_log1p_y = pmul(x, log1p_y); + return pselect(mask, x, x_log1p_y); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + + Eigen::NumTraits::MulCost, +#if TENSORFLOW_USE_ROCM + PacketAccess = false, +#else + PacketAccess = functor_traits>::PacketAccess +#endif + }; +}; + +template +struct xdivy_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x / y; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, + const Packet& y) const { + Packet zeros = pzero(x); + Packet mask = pcmp_eq(x, zeros); + Packet x_div_y = pdiv(x, y); + return pselect(mask, x, x_div_y); + } +}; + +template +struct functor_traits> { + enum { + Cost = + Eigen::NumTraits::AddCost + + Eigen::internal::scalar_div_cost::HasDiv>::value, + PacketAccess = 
packet_traits::HasDiv + }; +}; + +template +struct scalar_erfinv_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { + constexpr T half = T(0.5); + T y = numext::ndtri(half * x + half); + constexpr T half_sqrt = T(M_SQRT1_2); + return y * half_sqrt; + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { + Packet half = pset1(T(0.5)); + Packet y = pndtri(pmadd(half, x, half)); + Packet half_sqrt = pset1(T(M_SQRT1_2)); + return pmul(y, half_sqrt); + } +}; + +template +struct functor_traits> { + enum { + Cost = functor_traits>::Cost + NumTraits::AddCost, + PacketAccess = packet_traits::HasNdtri, + }; +}; + +} // end namespace internal +} // end namespace Eigen + +namespace tensorflow { +namespace functor { + +//////////////////////////////////////////////////////////////////////////////// +// Helpers +//////////////////////////////////////////////////////////////////////////////// + +// Base template for functors whose input scalar type is T and +// output scalar type is R. +template +struct base { + // func defines operator() and its vectorized version packetOp(). + typedef F func; + + // If true, the functor's corresponding binary op will instantiate + // specialized kernels to perform an optimized broadcast + // operation. Each functor for which this is enabled increases the + // code size, so by default this is disabled for binary functors and + // is enabled on a per-op basis as needed. + static constexpr bool use_bcast_optimization = false; + + // operator() has the signature: + // out_type operator()(in_type in0, in_type in1 ...) + typedef R out_type; + typedef T in_type; + + // TensorFlow provides tensor-ized version of "func". Roughly + // speaking, the tensorflow operation has the signature: + // tout_type op(tin_type in0) + // tout_type op(tin_type in0, tin_type in1) + // tout_type op(tin_type in0, in_type scalar) + typedef typename TTypes::Flat tout_type; + typedef typename TTypes::ConstFlat tin_type; + typedef typename TTypes::ConstScalar tscalar_type; + + // Whether the functor can error out. Currently applies only to integer + // div and mod. + static constexpr bool has_errors = false; +}; + +// For now, we only apply certain speed optimization for +// float/double's broadcast binary op. +template +struct use_bcast_optimization { + static constexpr bool value = false; +}; + +template <> +struct use_bcast_optimization { + static constexpr bool value = true; +}; + +template <> +struct use_bcast_optimization { + static constexpr bool value = true; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Unary functors +//////////////////////////////////////////////////////////////////////////////// + +// abs(x) = |x| +// neg(x) = - x +// inverse(x) = 1 / x +// square(x) = x^2 +// sqrt(x) = x^(1/2) +// rsqrt(x) = x^(-1/2) +// exp(x) = e^x +// expm1(x) = e^x - 1 +// log(x) = natural logarithm of x +// log1p(x) = natural logarithm of 1 + x +// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +// sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic +// +// NOTE: We may eventually implement common functions used in NN +// here. E.g., rectifier, softplus, derivatives of tanh, sigmod, etc. +// For reference, see speech/lstm/eigen_functors.h. 
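// Illustrative, self-contained sketch (not taken from this header) of the
// pattern the wrappers below build on: a scalar functor is lifted over a whole
// tensor with Eigen's unaryExpr, the same coefficient-wise application used by
// the Assign(..., in.unaryExpr(...)) calls in cwise_ops_common.h. SquareOp is
// a made-up stand-in for an Eigen::internal scalar functor.
#include <iostream>
#include "unsupported/Eigen/CXX11/Tensor"

struct SquareOp {
  float operator()(const float& x) const { return x * x; }
};

int main() {
  Eigen::Tensor<float, 1> in(4);
  in.setValues({1.f, 2.f, 3.f, 4.f});
  // Coefficient-wise application; evaluation happens on assignment.
  Eigen::Tensor<float, 1> out = in.unaryExpr(SquareOp());
  for (int i = 0; i < out.dimension(0); ++i) std::cout << out(i) << ' ';
  std::cout << '\n';  // prints: 1 4 9 16
  return 0;
}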
+ +template +struct abs : base, + typename Eigen::internal::scalar_abs_op::result_type> {}; + +template +struct neg : base> {}; + +template +struct inverse : base> {}; + +template +struct square : base> {}; + +template +struct sqrt : base> {}; + +template +struct rsqrt : base> {}; + +template +struct exp : base> {}; + +template +struct expm1 : base> {}; + +template +struct log : base> {}; + +template +struct log1p : base> {}; + +template +struct sign : base> {}; + +template +struct sinh : base> {}; + +template +struct cosh : base> {}; + +template +struct tanh : base> {}; + +template +struct asinh : base> {}; + +template +struct acosh : base> {}; + +template +struct atanh : base> {}; + +template +struct lgamma : base> {}; + +template +struct digamma : base> {}; + +template +struct erf : base> {}; + +template +struct erfc : base> {}; + +template +struct ndtri : base> {}; + +template +struct erfinv : base> {}; + +template +struct sigmoid : base> {}; + +template +struct sin : base> {}; + +template +struct cos : base> {}; + +template +struct tan : base> {}; + +template +struct asin : base> {}; + +template +struct acos : base> {}; + +template +struct atan : base> {}; + +struct logical_not : base> { +}; + +// Flip all bits. Named invert to be consistent with numpy. +template +struct invert_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { + return ~a; + } +}; + +template +struct invert : base> {}; + +// NOTE: std::isinf, std::isnan, std::isfinite are plain function. +// Therefore we need to wrap them in functors to be used with Eigen's +// type system. +template +struct isinf : base, bool> {}; + +template +struct isnan : base, bool> {}; + +template +struct isfinite : base, bool> {}; + +template +struct floor : base> {}; + +template +struct round : base> {}; + +template +struct ceil : base> {}; + +// Note: rint rounds half values to even, just like round_half_to_even_op. +template +struct rint : base> {}; + +//////////////////////////////////////////////////////////////////////////////// +// Binary functors +//////////////////////////////////////////////////////////////////////////////// + +// Binary functors: +// +// add(x, y) = x + y +// sub(x, y) = x - y +// mul(x, y) = x * y +// div(x, y) = x / y +// mod(x, y) = x % y (int32 and int64 only) +// fmod(x, y) = fmod(x, y) (float and double only) +// pow(x, y) = x ^ y +// maximum(x, y) = x > y ? x : y +// minimum(x, y) = x < y ? 
x : y +// squared_difference(x, y) = conj(x - y) * (x - y) + +template +struct add : base> { + static constexpr bool use_bcast_optimization = true; +}; + +template +struct sub : base> { + static constexpr bool use_bcast_optimization = true; +}; + +template +struct mul : base> { + static constexpr bool use_bcast_optimization = true; +}; + +template +struct mul_no_nan : base> {}; + +template +struct div : base> {}; + +template +struct safe_div : base>> { + static constexpr bool has_errors = true; +}; + +template +struct div_no_nan : base> {}; + +template +struct fmod : base> {}; + +template +struct mod : base> {}; + +template +struct safe_mod : base>> { + static constexpr bool has_errors = true; +}; + +template +struct floor_fmod : base> {}; + +template +struct safe_floor_mod : base>> { + static constexpr bool has_errors = true; +}; + +template +struct floor_div : base> {}; + +template +struct safe_floor_div : base>> { + static constexpr bool has_errors = true; +}; + +template +struct floor_div_real : base> {}; + +template +struct truncate_div_real + : base> {}; + +template +struct pow : base> {}; + +template +struct safe_pow : base> { + static constexpr bool has_errors = true; +}; + +// Version of safe_pow for integers which returns 0 if RHS is negative and LHS +// is not 1 or -1. For use on GPUs, where we cannot raise an error. +template +struct safe_pow_ignore_error_op { + static_assert(std::is_integral::value, "Integer type expected"); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + if (TF_PREDICT_FALSE(y < 0)) { + if (x == T(-1)) { + T trunc_mod = Eigen::internal::scalar_mod2_op()(y, T(2)); + return trunc_mod == T(-1) ? T(-1) : T(1); + } + return x == T(1) ? T(1) : T(0); + } + return Eigen::internal::scalar_pow_op{}(x, y); + } +}; + +template +struct safe_pow_ignore_error : base> {}; + +template +struct maximum + : base> {}; + +template +struct minimum + : base> {}; + +template +struct igamma : base> {}; + +template +struct random_gamma_grad + : base> {}; + +template +struct igammac : base> {}; + +template +struct zeta : base> {}; + +template +struct polygamma : base> {}; + +template +struct atan2 : base> {}; + +template +struct squared_difference + : base> {}; + +template +struct xdivy : base> {}; + +template +struct xlogy : base> {}; + +template +struct xlog1py : base> {}; + +template +struct less : base, bool> {}; + +template +struct less_equal : base, bool> {}; + +template +struct greater : base, bool> {}; + +template +struct greater_equal : base, bool> {}; + +template +struct equal_to : base, bool> {}; + +template +struct not_equal_to : base, bool> {}; + +struct logical_and : base> { +}; + +struct logical_or : base> {}; + +template +struct bitwise_and_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + return x & y; + } +}; + +template +struct bitwise_or_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + return x | y; + } +}; + +template +struct bitwise_and : base> {}; + +template +struct bitwise_or : base> {}; + +template +struct bitwise_xor : base> {}; + +template +struct left_shift_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + // Avoids UB: don't shift by larger than the bitwidth of T, and + // performs left shifts as unsigned shifts. 
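// Background for the clamp that follows: in C++, shifting by a negative count
// or by a count >= the bit width of the type is undefined behaviour, and
// left-shifting a negative signed value is undefined before C++20. Clamping y
// to [0, bitwidth - 1] and doing the shift in the unsigned domain sidesteps
// both issues.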
+ T y_clamped = y; + if (y_clamped < 0) { + y_clamped = 0; + } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { + y_clamped = sizeof(T) * CHAR_BIT - 1; + } + using U = typename std::make_unsigned::type; + return static_cast(static_cast(x) << static_cast(y_clamped)); + } +}; + +template +struct right_shift_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, + const T& y) const { + // Avoids UB: don't shift by larger than the bitwidth of T. + T y_clamped = y; + if (y_clamped < 0) { + y_clamped = 0; + } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { + y_clamped = sizeof(T) * CHAR_BIT - 1; + } + // Technically right shifts of signed integers are not necessarily + // arithmetic shifts according to the C++ standard. However in practice most + // implementations are arithmetic shifts. If this proves to be a problem in + // practice, we may need to use an alternative implementation. + return x >> y_clamped; + } +}; + +template +struct left_shift : base> {}; + +template +struct right_shift : base> {}; + +template +struct make_complex_func { + typedef std::complex result_type; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, + T imag) const { + return std::complex(real, imag); + } +}; + +template +struct make_complex : base, std::complex> {}; + +template +struct get_real + : base, typename T::value_type> {}; + +template +struct get_imag + : base, typename T::value_type> {}; + +template +struct get_angle + : base, typename T::value_type> {}; + +template +struct conj : base> {}; + +//////////////////////////////////////////////////////////////////////////////// +// Functors takes 1 or 2 tensors, computes the base functor on +// coefficient of the input tensors and puts the results in the output +// tensor. +//////////////////////////////////////////////////////////////////////////////// +template +struct UnaryFunctor { + // Computes on device "d": out[i] = Functor(in[i]) + void operator()(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in); +}; + +template +struct UnaryFunctorWithArg { + // Computes on device "d": out[i] = Functor(in[i]) + void operator()(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in, Targ val); +}; + +template +struct BinaryFunctor { + // Computes on device "d": out[i] = Functor(in0[i], in1[i]) + void operator()(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1, bool* error); + + // Computes on device "d": out[i] = Functor(scalar[0], in[i]) + void Left(const Device& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in, bool* error); + + // Computes on device "d": out[i] = Functor(in[i], scalar[0]) + void Right(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar, bool* error); + + // Computes on device "d": + // out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1)) + // + // TODO(zhifengc): makes BCast a template member function on NDIMS + // instead making BinaryFunctor templates on NDIMS. 
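// Self-contained sketch (not part of the header) of what the BCast declaration
// below describes: each operand is tiled with Eigen's broadcast() and the
// binary functor is then applied coefficient-wise, as in
// Assign(dev, out, lhs.binaryExpr(rhs, func)) in cwise_ops_common.h. The
// shapes and values here are illustrative only.
#include <iostream>
#include "unsupported/Eigen/CXX11/Tensor"

int main() {
  Eigen::Tensor<float, 2> row(1, 3), col(2, 1);
  row.setValues({{1.f, 2.f, 3.f}});
  col.setValues({{10.f}, {20.f}});
  // Broadcast factors: repeat the single row twice, the single column 3 times.
  Eigen::array<Eigen::Index, 2> bcast_row{2, 1};
  Eigen::array<Eigen::Index, 2> bcast_col{1, 3};
  Eigen::Tensor<float, 2> sum =
      row.broadcast(bcast_row) + col.broadcast(bcast_col);
  std::cout << sum << "\n";  // 2x3 result: 11 12 13 / 21 22 23
  return 0;
}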
+ void BCast(const Device& d, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1, + bool* error); +}; + +template +struct ApproximateEqual { + void operator()(const Device& d, typename TTypes::ConstFlat x, + typename TTypes::ConstFlat y, T tolerance, + typename TTypes::Flat z); +}; + +template +bool AllOne(const typename Eigen::array& a) { + for (size_t i = 0; i < a.size(); ++i) { + if (a[i] != 1) return false; + } + return true; +} + +template +struct SelectFunctor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstFlat cond_flat, + typename TTypes::ConstFlat then_flat, + typename TTypes::ConstFlat else_flat); +}; + +template +struct SelectScalarFunctor { + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstScalar cond, + typename TTypes::ConstFlat then_flat, + typename TTypes::ConstFlat else_flat); +}; + +template +struct BatchSelectFunctor { + void operator()(const Device& d, + typename TTypes::Matrix output_flat_outer_dims, + TTypes::ConstVec cond_vec, + typename TTypes::ConstMatrix then_flat_outer_dims, + typename TTypes::ConstMatrix else_flat_outer_dims); +}; + +template +struct BCastSelectFunctor { + void operator()(const Device& d, + typename TTypes::Tensor output_tensor, + typename TTypes::ConstTensor cond_tensor, + typename TTypes::ConstTensor then_tensor, + typename TTypes::ConstTensor else_tensor, + typename Eigen::array cond_bcast, + typename Eigen::array then_bcast, + typename Eigen::array else_bcast); +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_common.h new file mode 100644 index 00000000..fd7ee451 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_common.h @@ -0,0 +1,683 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_ + +// See docs in ../ops/math_ops.cc. 
+#define _USE_MATH_DEFINES +#include + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/platform/bfloat16.h" + + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/cwise_ops.h" +#include "tensorflow/core/kernels/cwise_ops_gradients.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +class BinaryOpShared : public OpKernel { + public: + explicit BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in); + + protected: + struct BinaryOpState { + // Sets up bcast with the shape of in0 and in1, ensures that the bcast + // is valid, and if so, set out, either by allocating a new buffer using + // ctx->output(...) or by creating an alias for an owned input buffer for + // in-place computation. + // Caller must check ctx->status() upon return for non-ok status. + // If ctx->status().ok() is true, then out is guaranteed to be allocated. + explicit BinaryOpState(OpKernelContext* ctx); + + const Tensor& in0; + const Tensor& in1; + + BCast bcast; + Tensor* out = nullptr; + int64_t out_num_elements; + + int64_t in0_num_elements; + int64_t in1_num_elements; + + int ndims; + bool result; + }; + + void SetUnimplementedError(OpKernelContext* ctx); + void SetComputeError(OpKernelContext* ctx); +}; + +// Coefficient-wise binary operations: +// Device: E.g., CPUDevice, GPUDevice. +// Functor: defined in cwise_ops.h. E.g., functor::add. +template +class BinaryOp : public BinaryOpShared { + public: + typedef typename Functor::in_type Tin; // Input scalar data type. + typedef typename Functor::out_type Tout; // Output scalar data type. + + explicit BinaryOp(OpKernelConstruction* ctx) + : BinaryOpShared(ctx, DataTypeToEnum::v(), + DataTypeToEnum::v()) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& input_0 = ctx->input(0); + OP_REQUIRES(ctx, input_0.dtype() == DataTypeToEnum::v(), + errors::InvalidArgument( + "Expected tensor of type ", + DataTypeString(DataTypeToEnum::v()), " but got type ", + DataTypeString(input_0.dtype()))); + const Tensor& input_1 = ctx->input(1); + OP_REQUIRES(ctx, input_1.dtype() == DataTypeToEnum::v(), + errors::InvalidArgument( + "Expected tensor of type ", + DataTypeString(DataTypeToEnum::v()), " but got type ", + DataTypeString(input_1.dtype()))); + const Device& eigen_device = ctx->eigen_device(); + bool error = false; + bool* const error_ptr = Functor::has_errors ? &error : nullptr; + + // NOTE: Handle three simple cases before building the BinaryOpState, which + // is relatively expensive for small operations. + if (input_0.shape() == input_1.shape()) { + // tensor op tensor with no broadcasting. + Tensor* out; + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {0, 1}, 0, input_0.shape(), &out)); + functor::BinaryFunctor()( + eigen_device, out->template flat(), + input_0.template flat(), input_1.template flat(), + error_ptr); + if (Functor::has_errors && error) { + SetComputeError(ctx); + } + return; + } else if (input_0.shape().dims() == 0) { + // scalar op tensor. 
+ Tensor* out; + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {1}, 0, input_1.shape(), &out)); + + functor::BinaryFunctor().Left( + eigen_device, out->template flat(), + input_0.template scalar(), input_1.template flat(), + error_ptr); + if (Functor::has_errors && error) { + SetComputeError(ctx); + } + return; + } else if (input_1.shape().dims() == 0) { + // tensor op scalar. + Tensor* out; + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {0}, 0, input_0.shape(), &out)); + functor::BinaryFunctor().Right( + eigen_device, out->template flat(), + input_0.template flat(), input_1.template scalar(), + error_ptr); + if (Functor::has_errors && error) { + SetComputeError(ctx); + } + return; + } + + // 'state': Shared helper not dependent on T to reduce code size + BinaryOpState state(ctx); + if (ctx->status().code() == error::RESOURCE_EXHAUSTED) { + // Stop when BinaryOpState's constructor failed due to OOM. + return; + } + auto& bcast = state.bcast; + Tensor* out = state.out; + if (!bcast.IsValid()) { + if (ctx->status().ok()) { + if (state.result) { + functor::SetOneFunctor()(eigen_device, + out->flat()); + } else { + functor::SetZeroFunctor()(eigen_device, + out->flat()); + } + } + return; + } + + auto& in0 = state.in0; + auto& in1 = state.in1; + if (state.out_num_elements == 0) { + return; + } + + const int ndims = state.ndims; + if (ndims <= 1) { + auto out_flat = out->flat(); + if (state.in1_num_elements == 1) { + // tensor op scalar + functor::BinaryFunctor().Right( + eigen_device, out_flat, in0.template flat(), + in1.template scalar(), error_ptr); + } else if (state.in0_num_elements == 1) { + // scalar op tensor + functor::BinaryFunctor().Left( + eigen_device, out_flat, in0.template scalar(), + in1.template flat(), error_ptr); + } else { + functor::BinaryFunctor()( + eigen_device, out_flat, in0.template flat(), + in1.template flat(), error_ptr); + } + } else if (ndims == 2) { + functor::BinaryFunctor().BCast( + eigen_device, out->shaped(bcast.result_shape()), + in0.template shaped(bcast.x_reshape()), + BCast::ToIndexArray<2>(bcast.x_bcast()), + in1.template shaped(bcast.y_reshape()), + BCast::ToIndexArray<2>(bcast.y_bcast()), error_ptr); + } else if (ndims == 3) { + functor::BinaryFunctor().BCast( + eigen_device, out->shaped(bcast.result_shape()), + in0.template shaped(bcast.x_reshape()), + BCast::ToIndexArray<3>(bcast.x_bcast()), + in1.template shaped(bcast.y_reshape()), + BCast::ToIndexArray<3>(bcast.y_bcast()), error_ptr); + } else if (ndims == 4) { + functor::BinaryFunctor().BCast( + eigen_device, out->shaped(bcast.result_shape()), + in0.template shaped(bcast.x_reshape()), + BCast::ToIndexArray<4>(bcast.x_bcast()), + in1.template shaped(bcast.y_reshape()), + BCast::ToIndexArray<4>(bcast.y_bcast()), error_ptr); + } else if (ndims == 5) { + functor::BinaryFunctor().BCast( + eigen_device, out->shaped(bcast.result_shape()), + in0.template shaped(bcast.x_reshape()), + BCast::ToIndexArray<5>(bcast.x_bcast()), + in1.template shaped(bcast.y_reshape()), + BCast::ToIndexArray<5>(bcast.y_bcast()), error_ptr); + } else { + SetUnimplementedError(ctx); + } + if (Functor::has_errors && error) { + SetComputeError(ctx); + } + } +}; + +template +class ApproximateEqualOp : public OpKernel { + public: + explicit ApproximateEqualOp(OpKernelConstruction* context) + : OpKernel(context) { + float tolerance; + OP_REQUIRES_OK(context, context->GetAttr("tolerance", &tolerance)); + tolerance_ = T(tolerance); + } + void Compute(OpKernelContext* context) override { + const Tensor& 
x_input = context->input(0); + const Tensor& y_input = context->input(1); + OP_REQUIRES( + context, x_input.shape() == y_input.shape(), + errors::InvalidArgument("x and y must be of the same shape. ", + "x shape: ", x_input.shape().DebugString(), + ". y shape: ", y_input.shape().DebugString())); + Tensor* z_output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, x_input.shape(), &z_output)); + const Device& d = context->eigen_device(); + typename TTypes::ConstFlat x(x_input.flat()); + typename TTypes::ConstFlat y(y_input.flat()); + typename TTypes::Flat z(z_output->flat()); + functor::ApproximateEqual()(d, x, y, tolerance_, z); + } + + private: + T tolerance_; +}; + +// Basic coefficient-wise binary operations that are known to not require +// any broadcasting. This is the case for example of the gradients of +// unary operations. +// Device: E.g., CPUDevice, GPUDevice. +// Functor: defined above. E.g., functor::tanh_grad. +template +class SimpleBinaryOp : public OpKernel { + public: + typedef typename Functor::in_type Tin; // Input scalar data type. + typedef typename Functor::out_type Tout; // Output scalar data type. + + explicit SimpleBinaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& in0 = ctx->input(0); + const Tensor& in1 = ctx->input(1); + OP_REQUIRES( + ctx, in0.NumElements() == in1.NumElements(), + errors::InvalidArgument("The two arguments to a cwise op must have " + "same number of elements, got ", + in0.NumElements(), " and ", in1.NumElements())); + auto in0_flat = in0.flat(); + auto in1_flat = in1.flat(); + const Device& eigen_device = ctx->eigen_device(); + + Tensor* out = nullptr; + if (std::is_same::value) { + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {0, 1}, 0, in0.shape(), &out)); + } else { + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in0.shape(), &out)); + } + auto out_flat = out->flat(); + functor::SimpleBinaryFunctor()(eigen_device, out_flat, + in0_flat, in1_flat); + } +}; + +// Coefficient-wise unary operations: +// Device: E.g., CPUDevice, GPUDevice. +// Functor: defined in cwise_ops.h. E.g., functor::sqrt. +template +class UnaryOp : public OpKernel { + public: + typedef typename Functor::in_type Tin; // Input scalar data type. + typedef typename Functor::out_type Tout; // Output scalar data type. + // Tin may be different from Tout. 
E.g., abs: complex64 -> float + + explicit UnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + auto in = DataTypeToEnum::v(); + auto out = DataTypeToEnum::v(); + OP_REQUIRES_OK(ctx, ctx->MatchSignature({in}, {out})); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& inp = ctx->input(0); + Tensor* out = nullptr; + if (std::is_same::value) { + OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output( + {0}, 0, inp.shape(), &out)); + } else { + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, inp.shape(), &out)); + } + functor::UnaryFunctor()( + ctx->eigen_device(), out->flat(), inp.flat()); + } +}; + +template +class UnaryVariantOp : public OpKernel { + public: + explicit UnaryVariantOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& inp = ctx->input(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(inp.shape()), + errors::InvalidArgument("Non-scalar variants are not supported.")); + const Variant& v = inp.scalar()(); + Variant v_out; + OP_REQUIRES_OK(ctx, UnaryOpVariant(ctx, OpEnum, v, &v_out)); + int numa_node = ctx->device()->NumaNode(); + Tensor out(cpu_allocator(numa_node), DT_VARIANT, TensorShape()); + out.scalar()() = std::move(v_out); + ctx->set_output(0, std::move(out)); + } +}; + +namespace functor { + +template +void Assign(const D& d, Out out, Rhs rhs) { + out.device(d) = rhs; +} + +// Partial specialization of BinaryFunctor +// for functors with no error checking. +template +struct BinaryFunctor { + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1, bool* error) { + Assign(d, out, in0.binaryExpr(in1, typename Functor::func())); + } + + void Left(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + + void Right(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + + void BCast(const CPUDevice& dev, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1, + bool* error) { + typename Functor::func func; + if (AllOne(bcast0) && AllOne(bcast1)) { + Assign(dev, out, in0.binaryExpr(in1, func)); + } else if (AllOne(bcast0)) { + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, in0.binaryExpr(rhs, func)); + } else if (AllOne(bcast1)) { + auto lhs = in0.broadcast(bcast0); + Assign(dev, out, lhs.binaryExpr(in1, func)); + } else { + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + } + } +}; + +// Partial specialization of BinaryFunctor +// for functors with no error checking. 
+template +struct BinaryFunctor { + enum { NDIMS = 2 }; + + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1, bool* error) { + Assign(d, out, in0.binaryExpr(in1, typename Functor::func())); + } + + void Left(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + + void Right(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data()))); + } + + inline Eigen::IndexList> NByOne( + Eigen::DenseIndex n) { + Eigen::IndexList> ret; + ret.set(0, n); + return ret; + } + inline Eigen::IndexList, Eigen::DenseIndex> OneByM( + Eigen::DenseIndex m) { + Eigen::IndexList, Eigen::DenseIndex> ret; + ret.set(1, m); + return ret; + } + + void BCast(const CPUDevice& dev, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1, + bool* error) { + typedef typename Functor::in_type T; + typename Functor::func func; + if (Functor::use_bcast_optimization && use_bcast_optimization::value) { + // Optimize for speed by using Eigen::type2index and avoid + // .broadcast() when we know it's a no-op. + // + // Here, we need to handle 6 cases depending on how many "1" + // exist in in0 and in1's shapes (4 numbers in total). It's not + // possible that two shapes have more than 2 1s because those + // are simplified to NDIMS==1 case. + // + // Because this optimization increases the binary size for each + // Functor (+, -, *, /, <, <=, etc.), type and ndim combination. + // we only apply such optimization for selected ops/types/ndims. + // + // Because NDIMS, Functor::use_broadcast_optimization and + // use_broadcast_optimization are compile-time constant, gcc + // does a decent job avoiding generating code when conditions + // are not met. 
+ const Eigen::DenseIndex a = in0.dimension(0); // in0 is shape [a, b] + const Eigen::DenseIndex b = in0.dimension(1); + const Eigen::DenseIndex c = in1.dimension(0); // in1 is shape [c, d] + const Eigen::DenseIndex d = in1.dimension(1); + if ((a == 1) && (d == 1)) { + auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c)); + auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if ((b == 1) && (c == 1)) { + auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d)); + auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (a == 1) { + auto lhs = in0.reshape(OneByM(b)).broadcast(NByOne(c)); + auto rhs = in1; + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (b == 1) { + auto lhs = in0.reshape(NByOne(a)).broadcast(OneByM(d)); + auto rhs = in1; + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (c == 1) { + auto lhs = in0; + auto rhs = in1.reshape(OneByM(d)).broadcast(NByOne(a)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + if (d == 1) { + auto lhs = in0; + auto rhs = in1.reshape(NByOne(c)).broadcast(OneByM(b)); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + + const bool bcast0_all_one = AllOne(bcast0); + const bool bcast1_all_one = AllOne(bcast1); + if (bcast0_all_one && !bcast1_all_one) { + auto lhs = in0; // No need to do broadcast for in0 + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + + if (!bcast0_all_one && bcast1_all_one) { + auto lhs = in0.broadcast(bcast0); + auto rhs = in1; // No need to do broadcast for in1 + Assign(dev, out, lhs.binaryExpr(rhs, func)); + return; + } + } + + // Fallback path. Always works and probably slower. + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + } +}; + +// Version of BinaryFunctor with error handling. +template +struct BinaryFunctor { + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1, bool* error) { + Assign(d, out, in0.binaryExpr(in1, typename Functor::func(error))); + } + + void Left(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data(), error))); + } + + void Right(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right Unary; + Assign(d, out, in.unaryExpr(Unary(scalar.data(), error))); + } + + void BCast(const CPUDevice& dev, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1, + bool* error) { + typename Functor::func func(error); + auto lhs = in0.broadcast(bcast0); + auto rhs = in1.broadcast(bcast1); + Assign(dev, out, lhs.binaryExpr(rhs, func)); + } +}; + +// Partial specialization of UnaryFunctor. 
+template +struct UnaryFunctor { + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in) { + Assign(d, out, in.unaryExpr(typename Functor::func())); + } +}; + +template +struct UnaryFunctorWithArg { + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, Targ val) { + Assign(d, out, in.unaryExpr(typename Functor::func(val))); + } +}; + +// Partial specialization of ApproximateEqual. +template +struct ApproximateEqual { + void operator()(const CPUDevice& d, typename TTypes::ConstFlat x, + typename TTypes::ConstFlat y, T tolerance, + typename TTypes::Flat z) { + auto diff = x - y; + z.device(d) = diff.abs() <= tolerance; + } +}; + +} // end namespace functor + +#define REGISTER(OP, D, N, F, T) \ + REGISTER_KERNEL_BUILDER(Name(N).Device(DEVICE_##D).TypeConstraint("T"), \ + OP>); + +#define REGISTER_VARIANT(OP, D, N, ENUM) \ + REGISTER_KERNEL_BUILDER( \ + Name(N).Device(DEVICE_##D).TypeConstraint("T"), \ + OP); + +// Macros to register kernels for multiple types (T0, T1, etc.) on +// device type "D" (CPU or GPU) for operation "N" (e.g., sqrt) using +// the functor "F" (e.g., functor::sqrt). + +#if defined(__ANDROID_TYPES_SLIM__) +// Note that __ANDROID_TYPES_SLIM__ is also checked in the cwise_ops*.cc files. +// Normally Android TensorFlow is built with a reduced number of types (float). +// Override on the command-line using "--copt=-D__ANDROID_TYPES_FULL__" +// to generate a library with full type support with a consequent increase in +// code size. +#define REGISTER2(OP, D, N, F, T0, T1) REGISTER(OP, D, N, F, T0) +#define REGISTER3(OP, D, N, F, T0, T1, T2) REGISTER(OP, D, N, F, T0) +#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) REGISTER(OP, D, N, F, T0) +#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) REGISTER(OP, D, N, F, T0) +#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) REGISTER(OP, D, N, F, T0) +#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ + REGISTER(OP, D, N, F, T0) +#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ + REGISTER(OP, D, N, F, T0) +#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \ + REGISTER(OP, D, N, F, T0) +#else // !defined(__ANDROID_TYPES_SLIM__) +#define REGISTER2(OP, D, N, F, T0, T1) \ + REGISTER(OP, D, N, F, T0) \ + REGISTER(OP, D, N, F, T1) +#define REGISTER3(OP, D, N, F, T0, T1, T2) \ + REGISTER2(OP, D, N, F, T0, T1) \ + REGISTER(OP, D, N, F, T2) +#define REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ + REGISTER2(OP, D, N, F, T0, T1) \ + REGISTER2(OP, D, N, F, T2, T3) +#define REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \ + REGISTER3(OP, D, N, F, T0, T1, T2) \ + REGISTER2(OP, D, N, F, T3, T4) +#define REGISTER6(OP, D, N, F, T0, T1, T2, T3, T4, T5) \ + REGISTER3(OP, D, N, F, T0, T1, T2) \ + REGISTER3(OP, D, N, F, T3, T4, T5) +#define REGISTER7(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6) \ + REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ + REGISTER3(OP, D, N, F, T4, T5, T6) +#define REGISTER8(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7) \ + REGISTER4(OP, D, N, F, T0, T1, T2, T3) \ + REGISTER4(OP, D, N, F, T4, T5, T6, T7) +#define REGISTER9(OP, D, N, F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \ + REGISTER5(OP, D, N, F, T0, T1, T2, T3, T4) \ + REGISTER4(OP, D, N, F, T5, T6, T7, T8) + +// Instead of adding REGISTER10, etc., shard the .cc files - see +// cwise_op_equal_to_*.cc for an example. 
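For orientation, a hedged sketch of how this registration macro family is typically consumed by the per-op shards mentioned in the comment above; the "Add"/functor::add pairing is an illustrative assumption rather than part of this header, and the expansion simply follows the REGISTER macro defined earlier in the file.

// Illustrative usage only (not part of this header). A shard such as a
// cwise_op_add_*.cc file registers several types at once:
//
//   REGISTER2(BinaryOp, CPU, "Add", functor::add, float, double);
//
// With full type support this expands to one kernel registration per type:
//
//   REGISTER_KERNEL_BUILDER(
//       Name("Add").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       BinaryOp<CPUDevice, functor::add<float>>);
//   REGISTER_KERNEL_BUILDER(
//       Name("Add").Device(DEVICE_CPU).TypeConstraint<double>("T"),
//       BinaryOp<CPUDevice, functor::add<double>>);
//
// Under __ANDROID_TYPES_SLIM__ the same invocation registers only the first
// listed type.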
+ +#endif // defined(__ANDROID_TYPES_SLIM__) + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h new file mode 100644 index 00000000..fdd61d03 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h @@ -0,0 +1,218 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support +#endif + +#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_ +#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_ + +#define _USE_MATH_DEFINES +#include +#include + +#define EIGEN_USE_GPU +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/cwise_ops.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; +typedef std::complex complex64; +typedef std::complex complex128; + +// Partial specialization of UnaryFunctor. +template +struct UnaryFunctor { + void operator()(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in) { + MaybeWith32BitIndexing( + [&](auto out32, auto in32) { + out32.device(d) = in32.unaryExpr(typename Functor::func()); + }, + out, in); + } +}; + +// Partial specialization of BinaryFunctor. 
+template +struct BinaryFunctor { + void operator()(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1, bool* error) { + MaybeWith32BitIndexing( + [&](auto out32, auto in0_32, auto in1_32) { + out32.device(d) = in0_32.binaryExpr(in1_32, typename Functor::func()); + }, + out, in0, in1); + } + + void Left(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tscalar_type scalar, + typename Functor::tin_type in, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_left Unary; + MaybeWith32BitIndexing( + [&](auto out32, auto in32) { + out32.device(d) = in32.unaryExpr(Unary(scalar.data())); + }, + out, in); + } + + void Right(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in, + typename Functor::tscalar_type scalar, bool* error) { + typedef typename Functor::out_type Tout; + typedef typename Functor::in_type Tin; + typedef typename Functor::func Binary; + typedef typename Eigen::internal::scalar_right Unary; + MaybeWith32BitIndexing( + [&](auto out32, auto in32) { + out32.device(d) = in32.unaryExpr(Unary(scalar.data())); + }, + out, in); + } + + void BCast(const GPUDevice& d, + typename TTypes::Tensor out, + typename TTypes::ConstTensor in0, + typename Eigen::array bcast0, + typename TTypes::ConstTensor in1, + typename Eigen::array bcast1, + bool* error) { + typedef typename Functor::in_type T; + typename Functor::func func; + if ((NDIMS == 2) && Functor::use_bcast_optimization && + use_bcast_optimization::value) { + const bool bcast0_all_one = AllOne(bcast0); + const bool bcast1_all_one = AllOne(bcast1); + if (bcast0_all_one && !bcast1_all_one) { + MaybeWith32BitIndexing( + [&](auto out32, auto in0_32, auto in1_32) { + out32.device(d) = + in0_32.binaryExpr(in1_32.broadcast(bcast1), func); + }, + out, in0, in1); + return; + } + if (!bcast0_all_one && bcast1_all_one) { + MaybeWith32BitIndexing( + [&](auto out32, auto in0_32, auto in1_32) { + out32.device(d) = + in0_32.broadcast(bcast0).binaryExpr(in1_32, func); + }, + out, in0, in1); + return; + } + } + MaybeWith32BitIndexing( + [&](auto out32, auto in0_32, auto in1_32) { + out32.device(d) = in0_32.broadcast(bcast0).binaryExpr( + in1_32.broadcast(bcast1), func); + }, + out, in0, in1); + } +}; + +// Partial specialization of ApproximateEqual. +template +struct ApproximateEqual { + void operator()(const GPUDevice& d, typename TTypes::ConstFlat x, + typename TTypes::ConstFlat y, T tolerance, + typename TTypes::Flat z) { + auto diff = x - y; + z.device(d) = diff.abs() <= tolerance; + } +}; + +// Macros to explicitly instantiate kernels on GPU for multiple types +// (T0, T1, etc.) for UnaryFunctor (e.g., functor::sqrt). 
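Before the macro definitions that follow, a hedged usage sketch: a GPU shard (for example, an assumed cwise_op_gpu_sqrt.cu.cc) would include this header and request explicit instantiations through the DEFINE_* macros defined below; the concrete functor/type pairings shown here are illustrative assumptions, not part of this header.

// Illustrative usage only (not part of this header):
//
//   DEFINE_UNARY3(sqrt, Eigen::half, float, double);
//
// expands to explicit instantiations of the form
//
//   template struct UnaryFunctor<GPUDevice, functor::sqrt<float>>;
//
// while the DEFINE_BINARYn macros additionally instantiate the broadcasting
// variants BinaryFunctor<GPUDevice, F<T>, NDIMS> for NDIMS = 1 through 5.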
+#define DEFINE_UNARY1(F, T) template struct UnaryFunctor > +#define DEFINE_UNARY2(F, T0, T1) \ + DEFINE_UNARY1(F, T0); \ + DEFINE_UNARY1(F, T1) +#define DEFINE_UNARY3(F, T0, T1, T2) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY1(F, T2) +#define DEFINE_UNARY4(F, T0, T1, T2, T3) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY2(F, T2, T3) +#define DEFINE_UNARY5(F, T0, T1, T2, T3, T4) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY3(F, T2, T3, T4) +#define DEFINE_UNARY6(F, T0, T1, T2, T3, T4, T5) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY4(F, T2, T3, T4, T5) +#define DEFINE_UNARY7(F, T0, T1, T2, T3, T4, T5, T6) \ + DEFINE_UNARY2(F, T0, T1); \ + DEFINE_UNARY5(F, T2, T3, T4, T5, T6) +#define DEFINE_UNARY8(F, T0, T1, T2, T3, T4, T5, T6, T7) \ + DEFINE_UNARY4(F, T0, T1, T2, T3); \ + DEFINE_UNARY4(F, T4, T5, T6, T7) + +// Macros to explicitly instantiate kernels on GPU for multiple types +// (T0, T1, etc.) for BinaryFunctor. +#define DEFINE_BINARY1(F, T) \ + template struct BinaryFunctor, 1>; \ + template struct BinaryFunctor, 2>; \ + template struct BinaryFunctor, 3>; \ + template struct BinaryFunctor, 4>; \ + template struct BinaryFunctor, 5> +#define DEFINE_BINARY2(F, T0, T1) \ + DEFINE_BINARY1(F, T0); \ + DEFINE_BINARY1(F, T1) +#define DEFINE_BINARY3(F, T0, T1, T2) \ + DEFINE_BINARY2(F, T0, T1); \ + DEFINE_BINARY1(F, T2) +#define DEFINE_BINARY4(F, T0, T1, T2, T3) \ + DEFINE_BINARY2(F, T0, T1); \ + DEFINE_BINARY2(F, T2, T3) +#define DEFINE_BINARY5(F, T0, T1, T2, T3, T4) \ + DEFINE_BINARY2(F, T0, T1); \ + DEFINE_BINARY3(F, T2, T3, T4) +#define DEFINE_BINARY6(F, T0, T1, T2, T3, T4, T5) \ + DEFINE_BINARY3(F, T0, T1, T2); \ + DEFINE_BINARY3(F, T3, T4, T5) +#define DEFINE_BINARY7(F, T0, T1, T2, T3, T4, T5, T6) \ + DEFINE_BINARY3(F, T0, T1, T2); \ + DEFINE_BINARY4(F, T3, T4, T5, T6) +#define DEFINE_BINARY8(F, T0, T1, T2, T3, T4, T5, T6, T7) \ + DEFINE_BINARY4(F, T0, T1, T2, T3); \ + DEFINE_BINARY4(F, T4, T5, T6, T7) +#define DEFINE_BINARY9(F, T0, T1, T2, T3, T4, T5, T6, T7, T8) \ + DEFINE_BINARY4(F, T0, T1, T2, T3); \ + DEFINE_BINARY5(F, T4, T5, T6, T7, T8) +#define DEFINE_BINARY10(F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) \ + DEFINE_BINARY5(F, T0, T1, T2, T3, T4); \ + DEFINE_BINARY5(F, T5, T6, T7, T8, T9) +#define DEFINE_BINARY11(F, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) \ + DEFINE_BINARY5(F, T0, T1, T2, T3, T4); \ + DEFINE_BINARY6(F, T5, T6, T7, T8, T9, T10) + +#define DEFINE_APPROXIMATE_EQUAL1(T) \ + template struct ApproximateEqual; +#define DEFINE_APPROXIMATE_EQUAL2(T0, T1) \ + DEFINE_APPROXIMATE_EQUAL1(T0); \ + DEFINE_APPROXIMATE_EQUAL1(T1); + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h new file mode 100644 index 00000000..dddce612 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h @@ -0,0 +1,74 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support +#endif + +#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_ +#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_ + +#define EIGEN_USE_GPU + +#include + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/cwise_ops.h" +#include "tensorflow/core/kernels/cwise_ops_gradients.h" +#include "tensorflow/core/platform/types.h" + +#include "tensorflow/core/platform/logging.h" +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; +typedef std::complex complex64; +typedef std::complex complex128; + +// Partial specialization of SimpleBinaryFunctor. +template +struct SimpleBinaryFunctor { + void operator()(const GPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in1, + typename Functor::tin_type in2) { + MaybeWith32BitIndexing( + [&](auto out32, auto in1_32) { + out32.device(d) = in1_32.binaryExpr(in2, typename Functor::func()); + }, + out, in1); + } +}; + +// Macros to explicitly instantiate kernels on GPU for multiple types +// (T0, T1, etc.) for SimpleBinaryFunctor (e.g., functor::tanh_grad). +#define DEFINE_SIMPLE_BINARY1(F, T) \ + template struct SimpleBinaryFunctor > +#define DEFINE_SIMPLE_BINARY2(F, T0, T1) \ + DEFINE_SIMPLE_BINARY1(F, T0); \ + DEFINE_SIMPLE_BINARY1(F, T1) +#define DEFINE_SIMPLE_BINARY3(F, T0, T1, T2) \ + DEFINE_SIMPLE_BINARY2(F, T0, T1); \ + DEFINE_SIMPLE_BINARY1(F, T2) +#define DEFINE_SIMPLE_BINARY4(F, T0, T1, T2, T3) \ + DEFINE_SIMPLE_BINARY2(F, T0, T1); \ + DEFINE_SIMPLE_BINARY2(F, T2, T3) +#define DEFINE_SIMPLE_BINARY5(F, T0, T1, T2, T3, T4) \ + DEFINE_SIMPLE_BINARY2(F, T0, T1); \ + DEFINE_SIMPLE_BINARY3(F, T2, T3, T4) + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gradients.h b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gradients.h new file mode 100644 index 00000000..0be3f788 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/cwise_ops_gradients.h @@ -0,0 +1,210 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ +#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ + +#define EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/kernels/cwise_ops.h" + +namespace Eigen { +namespace internal { + +// Gradient for the tanh function +template +struct scalar_tanh_gradient_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + return output_gradient * (T(1) - output * output); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + return pmul(output_gradient, + psub(pset1(T(1)), pmul(output, output))); + } +}; +template +struct functor_traits> { + enum { + Cost = NumTraits::AddCost + 2 * NumTraits::MulCost, + PacketAccess = packet_traits::HasSub && packet_traits::HasMul, + }; +}; + +// Gradient for the sigmoid function +template +struct scalar_sigmoid_gradient_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + return output_gradient * output * (T(1) - output); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + return pmul(output_gradient, + pmul(output, psub(pset1(T(1)), output))); + } +}; +template +struct functor_traits> { + enum { + Cost = NumTraits::AddCost + 2 * NumTraits::MulCost, + PacketAccess = packet_traits::HasSub && packet_traits::HasMul, + }; +}; + +// Gradient for the inverse function +template +struct scalar_inverse_gradient_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + if (output_gradient == T(0)) { + return T(0); + } else { + const T out_conj = numext::conj(output); + return -out_conj * out_conj * output_gradient; + } + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + const Packet out_conj = pconj(output); + return mul_no_nan_op().packetOp(pnegate(pmul(out_conj, out_conj)), + output_gradient); + } +}; +template +struct functor_traits> { + enum { + Cost = NumTraits::AddCost + 2 * NumTraits::MulCost, + PacketAccess = packet_traits::HasMul, + }; +}; + +// Gradient for the sqrt function +template +struct scalar_sqrt_gradient_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + if (output_gradient == T(0)) { + return T(0); + } else { + const T out_conj = numext::conj(output); + return (static_cast(0.5) * output_gradient) / out_conj; + } + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + const Packet const_half = pset1(static_cast(0.5)); + const Packet out_conj = pconj(output); + return mul_no_nan_op().packetOp(pdiv(const_half, out_conj), + output_gradient); + } +}; +template +struct functor_traits> { + enum { + PacketAccess = packet_traits::HasMul & packet_traits::HasDiv, + Cost = NumTraits::MulCost + scalar_div_cost::value, + }; +}; + +// Gradient for the rsqrt function +template +struct scalar_rsqrt_gradient_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T + operator()(const T& output, const T& output_gradient) const { + if (output_gradient == T(0)) { + return T(0); + } 
else { + const T out_conj = numext::conj(output); + return static_cast(-0.5) * (output_gradient * out_conj) * + (out_conj * out_conj); + } + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& output, const Packet& output_gradient) const { + const Packet const_half = pset1(static_cast(-0.5)); + const Packet out_conj = pconj(output); + auto safe_pmul = [](const Packet& a, const Packet& b) { + return mul_no_nan_op().packetOp(a, b); + }; + return safe_pmul(pmul(const_half, pmul(out_conj, out_conj)), + safe_pmul(out_conj, output_gradient)); + } +}; +template +struct functor_traits> { + enum { + Cost = 4 * NumTraits::MulCost, + PacketAccess = packet_traits::HasMul, + }; +}; + +} // end namespace internal +} // end namespace Eigen + +namespace tensorflow { + +namespace functor { + +template +struct SimpleBinaryFunctor { + void operator()(const Device& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1); +}; + +// Partial specialization of BinaryFunctor for CPU devices +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +struct SimpleBinaryFunctor { + void operator()(const CPUDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1) { + out.device(d) = in0.binaryExpr(in1, typename Functor::func()); + } +}; + + +template +struct tanh_grad : base> {}; + +template +struct sigmoid_grad : base> { +}; + +template +struct inverse_grad : base> { +}; + +template +struct sqrt_grad : base> {}; + +template +struct rsqrt_grad : base> {}; + +template +struct igamma_grad_a : base> {}; + +} // end namespace functor + +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/batch_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/batch_dataset_op.h new file mode 100644 index 00000000..4be07eff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/batch_dataset_op.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_BATCH_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_BATCH_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class BatchDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Batch"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kBatchSize = "batch_size"; + static constexpr const char* const kDropRemainder = "drop_remainder"; + static constexpr const char* const kParallelCopy = "parallel_copy"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit BatchDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + const int op_version_; + bool parallel_copy_ = false; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_BATCH_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/cache_dataset_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/cache_dataset_ops.h new file mode 100644 index 00000000..e0ceee2a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/cache_dataset_ops.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+namespace data {
+
+class CacheDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  class FileDatasetBase;
+  class MemoryDatasetBase;
+
+  static constexpr const char* const kDatasetType = "Cache";
+  static constexpr const char* const kInputDataset = "input_dataset";
+  static constexpr const char* const kFileName = "filename";
+  static constexpr const char* const kOutputTypes = "output_types";
+  static constexpr const char* const kOutputShapes = "output_shapes";
+
+  explicit CacheDatasetOp(OpKernelConstruction* ctx);
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override;
+
+ private:
+  class FileDataset;
+  class FileDatasetV2;
+  class MemoryDataset;
+  class MemoryDatasetV2;
+
+  const int op_version_;
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_CACHE_DATASET_OPS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/cache_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/cache_ops.h
new file mode 100644
index 00000000..e1e58ae9
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/cache_ops.h
@@ -0,0 +1,98 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_CACHE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_CACHE_OPS_H_
+
+#include "tensorflow/core/data/dataset_utils.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+namespace data {
+
+// A thread-safe data structure for caching dataset elements.
+//
+// The expected use is that a single `MemoryWriterIterator` populates the
+// cache with dataset elements. Once all elements are cached, the cache can
+// be used by one or more `MemoryReaderIterator`s.
+class MemoryCache {
+ public:
+  MemoryCache() = default;
+
+  // Marks the cache as completed.
+  void Complete(std::vector<std::vector<Tensor>>&& cache);
+
+  // Returns whether the cache is completed.
+  bool IsCompleted();
+
+  // Resets the cache.
+  void Reset();
+
+  // Returns the element at the given index.
+  const std::vector<Tensor>& at(int64_t index);
+
+  // Returns the size of the cache.
+  size_t size();
+
+  // Returns a reference to the cache's data. The returned reference will be
+  // invalidated by any call to Reset().
+  const std::vector<std::vector<Tensor>>& data();
+
+ private:
+  mutex mu_;
+  // Determines whether all elements of the dataset have been cached.
+  bool completed_ TF_GUARDED_BY(mu_) = false;
+  std::vector<std::vector<Tensor>> cache_ TF_GUARDED_BY(mu_);
+};
+
+// A resource wrapping a shared instance of a memory cache.
+class MemoryCacheManager : public ResourceBase {
+ public:
+  MemoryCacheManager() : cache_(std::make_shared<MemoryCache>()) {}
+
+  string DebugString() const override;
+
+  std::shared_ptr<MemoryCache> get() { return cache_; }
+
+ private:
+  std::shared_ptr<MemoryCache> cache_;
+};
+
+// Creates an instance of cache resource and transfers ownership to the caller.
+class AnonymousMemoryCacheHandleOp
+    : public AnonymousResourceOp<MemoryCacheManager> {
+ public:
+  explicit AnonymousMemoryCacheHandleOp(OpKernelConstruction* ctx);
+
+ private:
+  string name() override;
+  absl::Status CreateResource(
+      OpKernelContext* ctx, std::unique_ptr<FunctionLibraryDefinition> flib_def,
+      std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+      FunctionLibraryRuntime* lib, MemoryCacheManager** manager) override;
+};
+
+// Deletes an instance of cache resource.
+class DeleteMemoryCacheOp : public OpKernel {
+ public:
+  explicit DeleteMemoryCacheOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override;
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_CACHE_OPS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/concatenate_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/concatenate_dataset_op.h
new file mode 100644
index 00000000..a40e71fc
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/concatenate_dataset_op.h
@@ -0,0 +1,44 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_CONCATENATE_DATASET_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_CONCATENATE_DATASET_OP_H_
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+namespace data {
+
+class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
+ public:
+  static constexpr const char* const kDatasetType = "Concatenate";
+  static constexpr const char* const kInputDataset = "input_dataset";
+  static constexpr const char* const kAnotherDataset = "another_dataset";
+  static constexpr const char* const kOutputTypes = "output_types";
+  static constexpr const char* const kOutputShapes = "output_shapes";
+
+  explicit ConcatenateDatasetOp(OpKernelConstruction* ctx);
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase* to_concatenate, DatasetBase** output) override;
+
+ private:
+  class Dataset;
+};
+
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_CONCATENATE_DATASET_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/dataset_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/dataset_ops.h
new file mode 100644
index 00000000..fbbfb514
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/dataset_ops.h
@@ -0,0 +1,82 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_OPS_H_ + +#include + +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide this functionality because not all of its +// dependencies are available there. +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { + +class DatasetToGraphOp : public OpKernel { + public: + static constexpr const char* const kAllowStateful = "allow_stateful"; + static constexpr const char* const kStripDeviceAssignment = + "strip_device_assignment"; + static constexpr const char* const kExternalStatePolicy = + "external_state_policy"; + static constexpr const char* const kDatasetToGraph = "DatasetToGraph"; + + explicit DatasetToGraphOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + const int op_version_; + ExternalStatePolicy external_state_policy_ = ExternalStatePolicy::POLICY_WARN; + bool strip_device_assignment_ = false; +}; + +class DatasetCardinalityOp : public OpKernel { + public: + explicit DatasetCardinalityOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + std::unique_ptr cardinality_options_; +}; + +// An OpKernel that computes the fingerprint of a dataset. +class DatasetFingerprintOp : public OpKernel { + public: + explicit DatasetFingerprintOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; +}; + +class DatasetFromGraphOp : public OpKernel { + public: + static constexpr const char* const kGraphDef = "graph_def"; + static constexpr const char* const kHandle = "handle"; + + explicit DatasetFromGraphOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM + +#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.h new file mode 100644 index 00000000..098206a6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+namespace data {
+namespace experimental {
+
+class AssertCardinalityDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  static constexpr const char* const kDatasetType = "AssertCardinality";
+  static constexpr const char* const kInputDataset = "input_dataset";
+  static constexpr const char* const kCardinality = "cardinality";
+  static constexpr const char* const kOutputTypes = "output_types";
+  static constexpr const char* const kOutputShapes = "output_shapes";
+
+  explicit AssertCardinalityDatasetOp(OpKernelConstruction* ctx);
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override;
+
+ private:
+  class Dataset;
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+};
+
+}  // namespace experimental
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_CARDINALITY_DATASET_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h
new file mode 100644
index 00000000..6e86b5d8
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.h
@@ -0,0 +1,48 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class AssertNextDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "AssertNext"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kTransformations = "transformations"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit AssertNextDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_NEXT_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.h new file mode 100644 index 00000000..ed42b0c8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_PREV_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_PREV_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class AssertPrevDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr char kDatasetType[] = "AssertPrev"; + static constexpr char kInputDataset[] = "input_dataset"; + static constexpr char kTransformations[] = "transformations"; + static constexpr char kOutputTypes[] = "output_types"; + static constexpr char kOutputShapes[] = "output_shapes"; + + explicit AssertPrevDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_ASSERT_PREV_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h new file mode 100644 index 00000000..c1f71bd6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_AUTO_SHARD_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_AUTO_SHARD_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class AutoShardDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "AutoShard"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kNumWorkers = "num_workers"; + static constexpr const char* const kIndex = "index"; + static constexpr const char* const kAutoShardPolicy = "auto_shard_policy"; + static constexpr const char* const kNumReplicas = "num_replicas"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit AutoShardDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + static RewriterConfig CreateConfig(int64_t num_workers, int64_t index, + int64_t auto_shard_policy, + int64_t num_replicas); + int64_t auto_shard_policy_; + int64_t num_replicas_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_AUTO_SHARD_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/compression_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/compression_ops.h new file mode 100644 index 00000000..6dd89ea4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/compression_ops.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class CompressElementOp : public OpKernel { + public: + explicit CompressElementOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; +}; + +class UncompressElementOp : public OpKernel { + public: + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit UncompressElementOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_COMPRESSION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h new file mode 100644 index 00000000..5f23123a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/time/time.h" +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/data/service/common.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace data { + +// A resource which counts how many iterators have been created. This is used +// by the DataServiceDataset to coordinate jobs across multiple iterations. +class IterationCounter : public ResourceBase { + public: + IterationCounter() : counter_(0) {} + + std::string DebugString() const override { + mutex_lock l(mu_); + return absl::StrCat(counter_); + } + + int64_t GetAndIncrement() { + mutex_lock l(mu_); + return ++counter_; + } + + private: + mutable mutex mu_; + int64_t counter_ TF_GUARDED_BY(mu_) = 0; +}; + +// Creates a dataset for reading from the tf.data service. 
+class DataServiceDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "DataService"; + static constexpr const char* const kDatasetId = "dataset_id"; + static constexpr const char* const kProcessingMode = "processing_mode"; + static constexpr const char* const kAddress = "address"; + static constexpr const char* const kProtocol = "protocol"; + static constexpr const char* const kDataTransferProtocol = + "data_transfer_protocol"; + static constexpr const char* const kJobName = "job_name"; + static constexpr const char* const kConsumerIndex = "consumer_index"; + static constexpr const char* const kNumConsumers = "num_consumers"; + static constexpr const char* const kMaxOutstandingRequests = + "max_outstanding_requests"; + static constexpr const char* const kTaskRefreshIntervalHintMs = + "task_refresh_interval_hint_ms"; + static constexpr const char* const kTargetWorkers = "target_workers"; + static constexpr const char* const kIterationCounter = "iteration_counter"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kUncompress = "uncompress"; + static constexpr const char* const kUncompressFn = "uncompress_fn"; + static constexpr const char* const kCrossTrainerCacheOptions = + "cross_trainer_cache_options"; + + // Note: If a new constant is declared here, it *must* be defined in + // data_service_dataset_op.cc, otherwise it will not compile in debug mode. + + explicit DataServiceDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + int op_version_; + absl::Duration task_refresh_interval_hint_; + DataTypeVector output_types_; + std::vector output_shapes_; + std::string data_transfer_protocol_; + TargetWorkers target_workers_ = TARGET_WORKERS_AUTO; + bool uncompress_; + std::shared_ptr uncompress_fn_ = nullptr; + std::string seriazlied_cross_trainer_cache_options_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/data_service_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/data_service_ops.h new file mode 100644 index 00000000..b21a353d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/data_service_ops.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_ + +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { + +// Registers a dataset with the tf.data service. +// +// The address and protocol inputs are used to connect to the dispatcher. +// The external state policy attribute determines whether to ignore, warn, or +// error out when the dataset contains external state. +// The op produces a dataset id for identifying the registered dataset. +class RegisterDatasetOp : public OpKernel { + public: + static constexpr const char* const kAddress = "address"; + static constexpr const char* const kProtocol = "protocol"; + static constexpr const char* const kExternalStatePolicy = + "external_state_policy"; + static constexpr const char* const kElementSpec = "element_spec"; + static constexpr const char* const kMetadata = "metadata"; + static constexpr const char* const kRequestedDatasetId = + "requested_dataset_id"; + static constexpr const char* const kTimeoutMs = "timeout_ms"; + + explicit RegisterDatasetOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + int op_version_; + ExternalStatePolicy external_state_policy_; + std::string element_spec_; + std::string serialized_metadata_; + std::string requested_dataset_id_; +}; + +} // namespace data +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h new file mode 100644 index 00000000..25c0ef7a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class DirectedInterleaveDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "DirectedInterleave"; + static constexpr const char* const kSelectorInputDataset = + "selector_input_dataset"; + static constexpr const char* const kDataInputDatasets = "data_input_datasets"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kNumInputDatasets = "N"; + static constexpr const char* const kStopOnEmptyDataset = + "stop_on_empty_dataset"; + + explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + bool stop_on_empty_dataset_ = false; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DIRECTED_INTERLEAVE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/distributed_save_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/distributed_save_op.h new file mode 100644 index 00000000..d88642f2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/distributed_save_op.h @@ -0,0 +1,46 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DISTRIBUTED_SAVE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DISTRIBUTED_SAVE_OP_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// Initiates the process of distributedly saving a dataset to disk. 
+class DistributedSaveOp : public OpKernel { + public: + static constexpr const char* const kDirectory = "directory"; + static constexpr const char* const kAddress = "address"; + static constexpr const char* const kMetadata = "metadata"; + + explicit DistributedSaveOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + std::string serialized_metadata_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DISTRIBUTED_SAVE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/list_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/list_dataset_op.h new file mode 100644 index 00000000..ef921042 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/list_dataset_op.h @@ -0,0 +1,46 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LIST_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LIST_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ListDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "List"; + static constexpr const char* const kTensors = "tensors"; + static constexpr const char* const kTinputTypes = "Tinput_types"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit ListDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + DataTypeVector input_types_; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LIST_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h new file mode 100644 index 00000000..f58473a7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LMDB_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LMDB_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class LMDBDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "LMDB"; + static constexpr const char* const kFileNames = "filenames"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + using DatasetOpKernel::DatasetOpKernel; + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LMDB_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/load_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/load_dataset_op.h new file mode 100644 index 00000000..4a27d6aa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/load_dataset_op.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LOAD_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LOAD_DATASET_OP_H_ + +#include +#include +#include + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// An operation that can load a dataset from one or more files. 
+class LoadDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kCompression = "compression"; + static constexpr const char* const kDatasetType = "Load"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kPath = "path"; + static constexpr const char* const kReaderFunc = "reader_func"; + static constexpr const char* const kReaderFuncOtherArgs = + "reader_func_other_args"; + static constexpr const char* const kReaderFuncTarguments = + "Treader_func_args"; + + explicit LoadDatasetOp(OpKernelConstruction* ctx); + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + + std::string compression_; + DataTypeVector output_types_; + std::vector output_shapes_; + std::shared_ptr reader_func_metadata_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_LOAD_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h new file mode 100644 index 00000000..b3fec152 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_MAP_AND_BATCH_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_MAP_AND_BATCH_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level +// description of the following op. 
+ +class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "MapAndBatch"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kBatchSize = "batch_size"; + static constexpr const char* const kNumParallelCalls = "num_parallel_calls"; + static constexpr const char* const kDropRemainder = "drop_remainder"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kPreserveCardinality = + "preserve_cardinality"; + + explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + std::shared_ptr func_metadata_ = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; + bool preserve_cardinality_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_MAP_AND_BATCH_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h new file mode 100644 index 00000000..fc59b599 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_PARALLEL_INTERLEAVE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_PARALLEL_INTERLEAVE_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level +// description of the following op. 
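MapAndBatchDatasetOp above fuses two transformations: apply the function f to every input element and gather the results into batch_size-sized batches, optionally dropping a short final batch (drop_remainder). The real kernel does this in parallel over tensors (num_parallel_calls); the sequential sketch below only illustrates the shape of the fusion, with plain ints and invented names:

```cpp
// map_and_batch_sketch.cc -- sequential illustration of the map+batch fusion
// performed by MapAndBatchDatasetOp above (cf. batch_size / drop_remainder / f).
#include <functional>
#include <iostream>
#include <vector>

std::vector<std::vector<int>> MapAndBatch(const std::vector<int>& input,
                                          const std::function<int(int)>& fn,
                                          int batch_size, bool drop_remainder) {
  std::vector<std::vector<int>> batches;
  std::vector<int> current;
  for (int v : input) {
    current.push_back(fn(v));             // "map" step
    if (static_cast<int>(current.size()) == batch_size) {
      batches.push_back(current);         // "batch" step
      current.clear();
    }
  }
  if (!current.empty() && !drop_remainder) batches.push_back(current);
  return batches;
}

int main() {
  auto batches = MapAndBatch({1, 2, 3, 4, 5}, [](int v) { return v * v; },
                             /*batch_size=*/2, /*drop_remainder=*/true);
  for (const auto& batch : batches) {
    for (int v : batch) std::cout << v << " ";  // "1 4" then "9 16"; 25 dropped
    std::cout << "\n";
  }
  return 0;
}
```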
+ +class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "LegacyParallelInterleave"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kCycleLength = "cycle_length"; + static constexpr const char* const kBlockLength = "block_length"; + static constexpr const char* const kDeterministic = "deterministic"; + static constexpr const char* const kSloppy = "sloppy"; + static constexpr const char* const kBufferOutputElements = + "buffer_output_elements"; + static constexpr const char* const kPrefetchInputElements = + "prefetch_input_elements"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + const int op_version_; + + std::shared_ptr func_metadata_ = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; + DeterminismPolicy deterministic_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_PARALLEL_INTERLEAVE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/random_access_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/random_access_ops.h new file mode 100644 index 00000000..293cb99c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/random_access_ops.h @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_ACCESS_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_ACCESS_OPS_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/platform/platform.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// An operation that can get an element at a specified index in a dataset. 
+class GetElementAtIndexOp : public AsyncOpKernel { + public: + explicit GetElementAtIndexOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + unbounded_threadpool_(ctx->env(), "tf_data_get_element_at_index") { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + ~GetElementAtIndexOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + unbounded_threadpool_.Schedule([this, ctx, done = std::move(done)]() { + ctx->SetStatus(DoCompute(ctx)); + done(); + }); + } + + void Compute(OpKernelContext* ctx) override { + ctx->SetStatus(DoCompute(ctx)); + } + + protected: + absl::Status DoCompute(OpKernelContext* ctx); + + private: + UnboundedThreadPool unbounded_threadpool_; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_ACCESS_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/random_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/random_dataset_op.h new file mode 100644 index 00000000..2b3624fe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/random_dataset_op.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// See tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt for the +// API definition that corresponds to this kernel. +class RandomDatasetOp : public DatasetOpKernel { + public: + // Names of op parameters, public so that they can be accessed by test cases. 
+ // Make sure that these are kept in sync with the REGISTER_OP call in + // tensorflow/core/ops/experimental_dataset_ops.cc + static constexpr const char* const kDatasetType = "Random"; + static constexpr const char* const kSeed = "seed"; + static constexpr const char* const kSeed2 = "seed2"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kRerandomizeEachIteration = + "rerandomize_each_iteration"; + + explicit RandomDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + int32_t op_version_; + bool rerandomize_each_iteration_ = false; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_RANDOM_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sampling_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sampling_dataset_op.h new file mode 100644 index 00000000..9223c0e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sampling_dataset_op.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SAMPLING_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SAMPLING_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// See tensorflow/core/api_def/base_api/api_def_SamplingDataset.pbtxt for the +// API definition that corresponds to this kernel. +class SamplingDatasetOp : public UnaryDatasetOpKernel { + public: + // Names of op parameters, public so that they can be accessed by test cases. 
+ // Make sure that these are kept in sync with the REGISTER_OP call in + // tensorflow/core/ops/experimental_dataset_ops.cc + static constexpr const char* const kDatasetType = "Sampling"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kRate = "rate"; + static constexpr const char* const kSeed = "seed"; + static constexpr const char* const kSeed2 = "seed2"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit SamplingDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SAMPLING_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/save_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/save_dataset_op.h new file mode 100644 index 00000000..77478d4e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/save_dataset_op.h @@ -0,0 +1,114 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SAVE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SAVE_DATASET_OP_H_ + +#include +#include +#include + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +// An operation that can save a dataset to one or more files. 
+class SaveDatasetOp : public HybridAsyncOpKernel { + public: + static constexpr const char* const kCompression = "compression"; + static constexpr const char* const kPath = "path"; + static constexpr const char* const kShardFunc = "shard_func"; + static constexpr const char* const kShardFuncOtherArgs = + "shard_func_other_args"; + static constexpr const char* const kUseShardFunc = "use_shard_func"; + + explicit SaveDatasetOp(OpKernelConstruction* ctx); + + absl::Status DoCompute(OpKernelContext* ctx) override; + + private: + static constexpr const int kFileFormatVersion = 2; + + absl::Status ConsumeElement(); + + absl::Status GetShardIndex(IteratorContext* ctx, + InstantiatedCapturedFunction* function, + const std::vector& element, + int64_t* shard_index); + + absl::Status WriteData(OpKernelContext* ctx, DatasetBase* dataset, + std::unique_ptr captured_func, + const std::string& run_dir, uint64* num_elements); + + absl::Status WriteMetadataFile(Env* env, const std::string& path, + uint64 run_id, + const DataTypeVector& output_dtypes, + uint64 num_elements, bool finalized); + + bool use_shard_func_; + std::string compression_; + std::shared_ptr func_metadata_; +}; + +// An operation that can save a dataset to one or more files. This +// version of the implementation subclasses from UnaryDatasetOpKernel to align +// the implementation of save with that of the other tf.data transformations. +class SaveDatasetV2Op : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kPath = "path"; + static constexpr const char* const kCompression = "compression"; + + static constexpr const char* const kDatasetType = "SaveV2"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + static constexpr const char* const kShardFunc = "shard_func"; + static constexpr const char* const kShardFuncOtherArgs = + "shard_func_other_args"; + static constexpr const char* const kUseShardFunc = "use_shard_func"; + static constexpr const char* const kShardFuncTarguments = "Tshard_func_args"; + + explicit SaveDatasetV2Op(OpKernelConstruction* ctx); + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + + static constexpr const int kFileFormatVersion = 2; + + tstring path_; + std::string compression_; + std::unique_ptr shard_func_; + bool use_shard_func_; + DataTypeVector output_types_; + std::vector output_shapes_; + std::shared_ptr func_metadata_; + std::string writer_prefix_; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SAVE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h new file mode 100644 index 00000000..fb1fa875 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.h @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/data/snapshot_utils.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/random.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class SnapshotDatasetV2Op : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Snapshot"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kCompression = "compression"; + static constexpr const char* const kReaderPrefix = "reader_prefix"; + static constexpr const char* const kWriterPrefix = "writer_prefix"; + static constexpr const char* const kHashValid = "hash_valid"; + static constexpr const char* const kHash = "hash"; + static constexpr const char* const kCompressionAuto = "AUTO"; + static constexpr const char* const kReaderFunc = "reader_func"; + static constexpr const char* const kShardFunc = "shard_func"; + static constexpr const char* const kReaderFuncOtherArgs = + "reader_func_other_args"; + static constexpr const char* const kShardFuncOtherArgs = + "shard_func_other_args"; + static constexpr const char* const kReaderFuncTarguments = + "Treader_func_args"; + static constexpr const char* const kShardFuncTarguments = "Tshard_func_args"; + // Note: If a new constant is declared here, it *must* be defined in + // snapshot_dataset_op.cc, otherwise it will not compile in debug mode. 
+
+  explicit SnapshotDatasetV2Op(OpKernelConstruction* ctx);
+
+ protected:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override;
+
+ private:
+  static constexpr const int kFileFormatVersion = 2;
+
+  class Dataset;
+
+  const int graph_def_version_;
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+
+  std::string compression_;
+  std::string reader_prefix_;
+  std::string writer_prefix_;
+  bool hash_valid_;
+  uint64 hash_;
+
+  std::shared_ptr<FunctionMetadata> reader_func_metadata_;
+  std::shared_ptr<FunctionMetadata> shard_func_metadata_;
+};
+
+}  // namespace experimental
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SNAPSHOT_DATASET_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/driver_manager.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/driver_manager.h
new file mode 100644
index 00000000..7aa307e2
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/driver_manager.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_DRIVER_MANAGER_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_DRIVER_MANAGER_H_
+
+#include "tensorflow/core/kernels/data/experimental/sql/query_connection.h"
+
+namespace tensorflow {
+namespace data {
+namespace experimental {
+namespace sql {
+
+// A factory class for creating `QueryConnection` instances.
+class DriverManager {
+ public:
+  // A factory method for creating `QueryConnection` instances.
+  //
+  // `driver_name` is the database type (e.g. 'sqlite'). `driver_name`
+  // corresponds to a `QueryConnection` subclass. For example, if `driver_name`
+  // == `sqlite`, then `CreateQueryConnection` will create a
+  // `SqliteQueryConnection` instance.
+  static std::unique_ptr<QueryConnection> CreateQueryConnection(
+      const string& driver_name);
+};
+
+}  // namespace sql
+}  // namespace experimental
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_DRIVER_MANAGER_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/query_connection.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/query_connection.h
new file mode 100644
index 00000000..031a8725
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/query_connection.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_QUERY_CONNECTION_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_QUERY_CONNECTION_H_
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+
+class IteratorContext;
+
+namespace experimental {
+
+namespace sql {
+// This interface allows a user to connect to a database, execute a query, and
+// iterate over the result set, putting the results into an output tensor.
+// A subclass implementation is required for each type of database
+// (e.g. sqlite3, mysql, etc.)
+//
+// Presently, a `QueryConnection` instance can only handle one query at a time.
+// In a future extension, this class may be refactored so that it creates
+// instances of a new class (named, say, `Statement`) which could have a
+// one-to-one correspondence with queries. This would make `QueryConnection`
+// more consistent with `Connection` classes of other database APIs.
+// `QueryConnection` would then be renamed simply `Connection`.
+//
+// This class is not thread safe. Access to it is guarded by a mutex in
+// `SqlDatasetOp::Dataset::Iterator`.
+class QueryConnection {
+ public:
+  virtual ~QueryConnection() {}
+  // Opens a connection to the database named by `data_source_name`. Prepares to
+  // execute `query` against the database.
+  //
+  // The client must call `Close()` to release the connection resources, even
+  // if `Open()` fails. `Close()` must be called before making another call
+  // to `Open()`.
+  virtual absl::Status Open(const string& data_source_name, const string& query,
+                            const DataTypeVector& output_types) = 0;
+  // Closes an opened connection.
+  virtual absl::Status Close() = 0;
+  // Retrieves the next row of the result set of the query from the most recent
+  // call to `Open()`.
+  //
+  // If such a row exists, then the row will be stored in `*out_tensors`, and
+  // `false` will be stored in `*end_of_sequence`.
+  //
+  // If there are no more rows in the result set, then instead `true` will be
+  // stored in `*end_of_sequence`, and the content of `*out_tensors` will be
+  // undefined.
+  virtual absl::Status GetNext(IteratorContext* ctx,
+                               std::vector<Tensor>* out_tensors,
+                               bool* end_of_sequence) = 0;
+};
+
+}  // namespace sql
+}  // namespace experimental
+}  // namespace data
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_QUERY_CONNECTION_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h
new file mode 100644
index 00000000..4cf2608c
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_SQLITE_QUERY_CONNECTION_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_SQLITE_QUERY_CONNECTION_H_ + +#include + +#include "tensorflow/core/kernels/data/experimental/sql/query_connection.h" +#include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace data { +namespace experimental { +namespace sql { + +class SqliteQueryConnection : public QueryConnection { + public: + SqliteQueryConnection(); + ~SqliteQueryConnection() override; + absl::Status Open(const string& data_source_name, const string& query, + const DataTypeVector& output_types) override; + absl::Status Close() override; + absl::Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override; + + private: + // Prepares the query string `query_`. + absl::Status PrepareQuery(); + // Fills `tensor` with the column_index_th element of the current row of + // `stmt_`. + void FillTensorWithResultSetEntry(const DataType& data_type, int column_index, + Tensor* tensor); + Sqlite* db_ = nullptr; + SqliteStatement stmt_; + int column_count_ = 0; + string query_; + DataTypeVector output_types_; +}; + +} // namespace sql +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_SQL_SQLITE_QUERY_CONNECTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h new file mode 100644 index 00000000..1255365d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.h @@ -0,0 +1,76 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
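Taken together, DriverManager, the QueryConnection interface, and its SqliteQueryConnection implementation above define the contract the Sql dataset's iterator drives: create a connection for a driver name, Open it with a query, call GetNext until *end_of_sequence is set, and Close it even if Open failed. The standalone analogue below mirrors that call sequence with an in-memory fake; Row, FakeQueryConnection and the sample query are inventions for illustration, and std::vector<std::string> stands in for std::vector<Tensor>:

```cpp
// query_connection_sketch.cc -- standalone analogue of the Open/GetNext/Close
// contract documented above. There is no real database behind it.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

using Row = std::vector<std::string>;

class FakeQueryConnection {
 public:
  bool Open(const std::string& query) {
    // Pretend the query produced two rows.
    rows_ = {{"1", "alice"}, {"2", "bob"}};
    next_ = 0;
    return true;
  }
  bool Close() {
    rows_.clear();
    return true;
  }
  // Mirrors GetNext(): fills *out_row, or sets *end_of_sequence once the
  // result set is exhausted (contents of *out_row are then unspecified).
  bool GetNext(Row* out_row, bool* end_of_sequence) {
    if (next_ >= rows_.size()) {
      *end_of_sequence = true;
      return true;
    }
    *out_row = rows_[next_++];
    *end_of_sequence = false;
    return true;
  }

 private:
  std::vector<Row> rows_;
  size_t next_ = 0;
};

int main() {
  // cf. DriverManager::CreateQueryConnection("sqlite") in the header above.
  auto conn = std::make_unique<FakeQueryConnection>();
  conn->Open("SELECT id, name FROM users");
  bool end_of_sequence = false;
  while (true) {
    Row row;
    conn->GetNext(&row, &end_of_sequence);
    if (end_of_sequence) break;          // iterator-style termination
    std::cout << row[0] << ", " << row[1] << "\n";
  }
  conn->Close();                         // must be called even if Open() failed
  return 0;
}
```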
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_THREADPOOL_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_THREADPOOL_DATASET_OP_H_ + +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/platform/platform.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = + "MaxIntraOpParallelismDataset"; + static constexpr const char* const kDatasetOp = + "MaxIntraOpParallelismDatasetOp"; + + // Executes the logic of the MaxIntraOpParallelismDatasetOp directly (as + // opposed to through executing the MaxIntraOpParallelismDatasetOp op kernel). + static void MakeDatasetFromOptions(OpKernelContext* ctx, DatasetBase* input, + int32_t max_intra_op_parallelism, + DatasetBase** output); + + explicit MaxIntraOpParallelismDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "PrivateThreadPoolDataset"; + static constexpr const char* const kDatasetOp = "PrivateThreadPoolDatasetOp"; + + // Executes the logic of the PrivateThreadpoolDatasetOp directly (as + // opposed to through executing the PrivateThreadpoolDatasetOp op kernel). + static void MakeDatasetFromOptions(OpKernelContext* ctx, DatasetBase* input, + int32_t num_threads, DatasetBase** output); + + explicit PrivateThreadPoolDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_THREADPOOL_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/unique_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/unique_dataset_op.h new file mode 100644 index 00000000..2d415816 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/experimental/unique_dataset_op.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_UNIQUE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_UNIQUE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { +namespace experimental { + +class UniqueDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Unique"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit UniqueDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace experimental +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_UNIQUE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/filter_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/filter_dataset_op.h new file mode 100644 index 00000000..59c5bcc1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/filter_dataset_op.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
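UniqueDatasetOp above is the leanest instance of the structure that nearly every header in this patch repeats: attribute-name constants, a constructor taking OpKernelConstruction*, a MakeDataset override that turns the input DatasetBase into a new one, and a privately nested Dataset class. The framework-free sketch below mirrors that shape with invented types (the real base classes live in tensorflow/core/framework/dataset.h); only the structure, not the kernel machinery, is reproduced:

```cpp
// dataset_op_shape_sketch.cc -- framework-free rendering of the recurring
// UnaryDatasetOpKernel pattern. All names here are invented for illustration.
#include <iostream>
#include <memory>
#include <unordered_set>
#include <vector>

struct DatasetBase {                      // stand-in for tensorflow::data::DatasetBase
  virtual ~DatasetBase() = default;
  virtual std::vector<int> Elements() const = 0;
};

class UniqueLikeOp {                      // stand-in for UniqueDatasetOp
 public:
  static constexpr const char* const kDatasetType = "Unique";

  // Mirrors MakeDataset(ctx, input, output): wraps the input dataset in a
  // new dataset that yields only values not seen before.
  void MakeDataset(const DatasetBase* input,
                   std::unique_ptr<DatasetBase>* output) const {
    output->reset(new Dataset(input));
  }

 private:
  class Dataset : public DatasetBase {    // the privately nested Dataset class
   public:
    explicit Dataset(const DatasetBase* input) : input_(input) {}
    std::vector<int> Elements() const override {
      std::vector<int> out;
      std::unordered_set<int> seen;
      for (int v : input_->Elements())
        if (seen.insert(v).second) out.push_back(v);
      return out;
    }
   private:
    const DatasetBase* input_;
  };
};

struct SampleDataset : DatasetBase {
  std::vector<int> Elements() const override { return {1, 1, 2, 2, 3, 1}; }
};

int main() {
  SampleDataset input;
  std::unique_ptr<DatasetBase> output;
  UniqueLikeOp().MakeDataset(&input, &output);
  for (int v : output->Elements()) std::cout << v << " ";  // prints: 1 2 3
  std::cout << "\n";
  return 0;
}
```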
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_FILTER_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_FILTER_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class FilterDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Filter"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kPredicate = "predicate"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit FilterDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + std::shared_ptr func_metadata_ = nullptr; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_FILTER_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/finalize_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/finalize_dataset_op.h new file mode 100644 index 00000000..4b2ef22b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/finalize_dataset_op.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_FINALIZE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_FINALIZE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { + +class FinalizeDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Finalize"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kHasCapturedRef = "has_captured_ref"; + + explicit FinalizeDatasetOp(OpKernelConstruction* ctx); + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + bool has_captured_ref_; +}; + +class FinalizeDatasetNoopOp : public UnaryDatasetOpKernel { + public: + explicit FinalizeDatasetNoopOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + LOG(WARNING) << "FinalizeDataset is only supported on CPU. 
Using it on " + "devices other than CPU has no effect."; + input->Ref(); + *output = input; + } +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_FINALIZE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/fixed_length_record_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/fixed_length_record_dataset_op.h new file mode 100644 index 00000000..30b62031 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/fixed_length_record_dataset_op.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_FIXED_LENGTH_RECORD_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_FIXED_LENGTH_RECORD_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class FixedLengthRecordDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "FixedLengthRecord"; + static constexpr const char* const kFileNames = "filenames"; + static constexpr const char* const kHeaderBytes = "header_bytes"; + static constexpr const char* const kRecordBytes = "record_bytes"; + static constexpr const char* const kFooterBytes = "footer_bytes"; + static constexpr const char* const kBufferSize = "buffer_size"; + static constexpr const char* const kCompressionType = "compression_type"; + + explicit FixedLengthRecordDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + const int op_version_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_FIXED_LENGTH_RECORD_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/flat_map_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/flat_map_dataset_op.h new file mode 100644 index 00000000..6b370757 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/flat_map_dataset_op.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
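FinalizeDatasetNoopOp above shows the reference-counted pass-through used when FinalizeDataset runs on a non-CPU device: it logs the warning, takes a reference on the input (input->Ref()) and returns the same object as the output. The standalone analogue below reproduces that ownership handshake with a simplified RefCounted stand-in; none of it is TensorFlow API:

```cpp
// passthrough_sketch.cc -- standalone analogue of the no-op MakeDataset above
// (input->Ref(); *output = input;). RefCounted is a simplified stand-in for
// TensorFlow's reference-counted base class.
#include <cassert>
#include <iostream>

class RefCounted {
 public:
  virtual ~RefCounted() = default;
  void Ref() { ++refs_; }
  void Unref() {
    if (--refs_ == 0) delete this;
  }

 private:
  int refs_ = 1;  // the creator starts with one reference
};

class Dataset : public RefCounted {
 public:
  ~Dataset() override { std::cout << "dataset destroyed\n"; }
};

// The no-op kernel's MakeDataset: the output aliases the input, so a
// reference must be taken on the caller's behalf before handing it back.
void MakeDatasetNoop(Dataset* input, Dataset** output) {
  input->Ref();
  *output = input;
}

int main() {
  Dataset* input = new Dataset;      // refcount = 1
  Dataset* output = nullptr;
  MakeDatasetNoop(input, &output);   // refcount = 2; output aliases input
  assert(output == input);
  output->Unref();                   // consumer releases its view
  input->Unref();                    // original owner releases; object destroyed
  return 0;
}
```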
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_FLAT_MAP_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_FLAT_MAP_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class FlatMapDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "FlatMap"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit FlatMapDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + const int graph_def_version_; + DataTypeVector output_types_; + std::vector output_shapes_; + std::shared_ptr func_metadata_ = nullptr; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_FLAT_MAP_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/generator_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/generator_dataset_op.h new file mode 100644 index 00000000..b734e9a6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/generator_dataset_op.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class GeneratorDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Generator"; + static constexpr const char* const kInitFuncOtherArgs = + "init_func_other_args"; + static constexpr const char* const kNextFuncOtherArgs = + "next_func_other_args"; + static constexpr const char* const kFinalizeFuncOtherArgs = + "finalize_func_other_args"; + static constexpr const char* const kInitFunc = "init_func"; + static constexpr const char* const kNextFunc = "next_func"; + static constexpr const char* const kFinalizeFunc = "finalize_func"; + static constexpr const char* const kTinitFuncArgs = "Tinit_func_args"; + static constexpr const char* const kTnextFuncArgs = "Tnext_func_args"; + static constexpr const char* const kTfinalizeFuncArgs = "Tfinalize_func_args"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit GeneratorDatasetOp(OpKernelConstruction* ctx); + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + std::shared_ptr<FunctionMetadata> init_func_metadata_ = nullptr; + std::shared_ptr<FunctionMetadata> next_func_metadata_ = nullptr; + std::shared_ptr<FunctionMetadata> finalize_func_metadata_ = nullptr; +}; + +} // namespace data +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/get_options_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/get_options_op.h new file mode 100644 index 00000000..3e6611cb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/get_options_op.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_GET_OPTIONS_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_GET_OPTIONS_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { + +// TODO(jsimsa): Provide class-level documentation for this and the other ops.
+class GetOptionsOp : public OpKernel { + public: + explicit GetOptionsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) final; + + string TraceString(const OpKernelContext& ctx, bool verbose) const override; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_GET_OPTIONS_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/interleave_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/interleave_dataset_op.h new file mode 100644 index 00000000..a1300ddd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/interleave_dataset_op.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_INTERLEAVE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_INTERLEAVE_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class InterleaveDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Interleave"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kCycleLength = "cycle_length"; + static constexpr const char* const kBlockLength = "block_length"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit InterleaveDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + const int graph_def_version_; + DataTypeVector output_types_; + std::vector output_shapes_; + std::shared_ptr func_metadata_ = nullptr; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_INTERLEAVE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/iterator_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/iterator_ops.h new file mode 100644 index 00000000..a2b13411 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/iterator_ops.h @@ -0,0 +1,356 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ + +#include +#include +#include + +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/data/metric_utils.h" +#include "tensorflow/core/data/tfdataz_metrics.h" +#include "tensorflow/core/data/unbounded_thread_pool.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function_handle_cache.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/refcount.h" + +namespace tensorflow { +namespace data { + +class IteratorResource : public ResourceBase { + public: + IteratorResource(Env* env, const DataTypeVector& output_dtypes, + const std::vector& output_shapes, + std::unique_ptr device_mgr, + std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* flr); + + ~IteratorResource() override; + + // Gets the next output from the iterator managed by this iterator resource. + // + // If at least one output remains, that output will be stored in + // `*out_tensors` and `false` will be stored in `*end_of_sequence`. + // + // If no more outputs remain, `true` will be stored in `*end_of_sequence`, and + // the content of `*out_tensors` will be undefined. + absl::Status GetNext(OpKernelContext* ctx, std::vector* out_tensors, + bool* end_of_sequence); + + absl::Status GetModelProto(std::string& model_proto); + + // Saves a checkpoint of the state of the iterator through the given `writer`. + absl::Status Save(OpKernelContext* ctx, + ExternalStatePolicy external_state_policy, + IteratorStateWriter* writer); + + // Restores the state of the iterator from a checkpoint created by `Save`. + absl::Status Restore(OpKernelContext* ctx, IteratorStateReader* reader); + + // Creates an iterator for `dataset`, and associates the iterator with this + // iterator resource. + // + // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`, + // or `Restore`. 
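The comments above spell out the IteratorResource contract: bind an iterator with SetIteratorFromDataset() before calling GetNext(), Save(), or Restore(), and treat `*out_tensors` as undefined once `*end_of_sequence` is true. A minimal sketch of that call order follows; DrainIterator, `resource`, and `dataset` are hypothetical names used only for illustration, not part of iterator_ops.h.

// Sketch only: assumes a ready IteratorResource* created elsewhere and the
// vendored TensorFlow headers above.
absl::Status DrainIterator(OpKernelContext* ctx, IteratorResource* resource,
                           const DatasetBase* dataset) {
  // The iterator must be bound to a dataset before GetNext/Save/Restore.
  TF_RETURN_IF_ERROR(resource->SetIteratorFromDataset(ctx, dataset));
  std::vector<Tensor> out_tensors;
  bool end_of_sequence = false;
  while (!end_of_sequence) {
    out_tensors.clear();
    TF_RETURN_IF_ERROR(resource->GetNext(ctx, &out_tensors, &end_of_sequence));
    // When end_of_sequence is true, out_tensors is undefined and must not be
    // used; otherwise it holds the next element of the dataset.
  }
  return absl::OkStatus();
}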
+ absl::Status SetIteratorFromDataset(OpKernelContext* ctx, + const DatasetBase* dataset); + + string DebugString() const override { return "Iterator resource"; } + + const DataTypeVector& output_dtypes() const { return output_dtypes_; } + + const std::vector& output_shapes() const { + return output_shapes_; + } + + private: + class State { + public: + State(std::shared_ptr flib_def, + std::shared_ptr pflr, + FunctionLibraryRuntime* flr, + std::unique_ptr iterator) + : flib_def_(std::move(flib_def)), + flr_(flr), + pflr_(std::move(pflr)), + function_handle_cache_(std::make_unique(flr)), + iterator_(std::move(iterator)), + + id_registry_(std::make_shared()), + checkpoint_(MemoryCheckpoint::CreateRootCheckpoint(id_registry_)) {} + + ~State() { cancellation_manager_.StartCancel(); } + + std::shared_ptr flib_def() { return flib_def_; } + + FunctionLibraryRuntime* flr() { return flr_; } + + std::shared_ptr pflr() { return pflr_; } + + FunctionHandleCache* function_handle_cache() { + return function_handle_cache_.get(); + } + + ResourceMgr* resource_mgr() { return &resource_mgr_; } + + CancellationManager* cancellation_manager() { + return &cancellation_manager_; + } + + DatasetBaseIterator* iterator() { return iterator_.get(); } + + std::shared_ptr model() { return model_; } + + const MemoryCheckpoint& checkpoint() const { return checkpoint_; } + + DatasetBase* dataset() { return dataset_.get(); } + + // Downcasts the given `IteratorBase` to a `DatasetBaseIterator`, and uses + // it to set the `iterator` and the `dataset` field. + void DowncastAndSetIteratorAndDataset(std::unique_ptr it, + const DatasetBase* dataset); + + // Merges the given checkpoint with the checkpoint of this state. + void MergeCheckpoint(MemoryCheckpoint* other); + + void SetModel(std::shared_ptr model); + + std::shared_ptr id_registry() { + return id_registry_; + } + + private: + std::shared_ptr flib_def_; + FunctionLibraryRuntime* flr_ = nullptr; // not owned + std::shared_ptr pflr_; + std::unique_ptr function_handle_cache_; + ResourceMgr resource_mgr_; + CancellationManager cancellation_manager_; + std::unique_ptr iterator_; + core::RefCountPtr dataset_; + std::shared_ptr id_registry_; + MemoryCheckpoint checkpoint_; + std::shared_ptr model_; + }; + + IteratorMetricsCollector metrics_collector_; + std::shared_ptr tf_dataz_metrics_collector_; + UnboundedThreadPool unbounded_thread_pool_; + + mutex mu_; + const Env& env_; + const std::unique_ptr device_mgr_ TF_GUARDED_BY(mu_); + std::shared_ptr iterator_state_ TF_GUARDED_BY(mu_); + const DataTypeVector output_dtypes_; + const std::vector output_shapes_; +}; + +class IteratorHandleOp : public OpKernel { + public: + explicit IteratorHandleOp(OpKernelConstruction* ctx); + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~IteratorHandleOp() override; + + void Compute(OpKernelContext* context) override TF_LOCKS_EXCLUDED(mu_); + + private: + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. 
+ absl::Status VerifyResource(IteratorResource* resource); + + FunctionLibraryRuntime* CreatePrivateFLR( + OpKernelContext* ctx, std::unique_ptr* device_mgr, + std::unique_ptr* flib_def, + std::unique_ptr* pflr); + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. + IteratorResource* resource_ TF_GUARDED_BY(mu_) = nullptr; + DataTypeVector output_dtypes_; + std::vector output_shapes_; + const int graph_def_version_; + string name_; +}; + +// Like IteratorHandleOp, but creates handles which are never shared, and does +// not hold a reference to these handles. The latter is important for eager +// execution, since OpKernel instances generally live as long as the program +// running them. +class AnonymousIteratorHandleOp : public AnonymousResourceOp { + public: + explicit AnonymousIteratorHandleOp(OpKernelConstruction* context); + + private: + string name() override; + + absl::Status CreateResource( + OpKernelContext* ctx, std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, IteratorResource** resource) override; + + DataTypeVector output_dtypes_; + std::vector output_shapes_; + const int graph_def_version_; +}; + +// A hybrid asynchronous-and-synchronous OpKernel with efficient support for +// both modes. +// +// Inherit from this class when the application logic of the kernel (i) is +// implemented synchronously, (ii) must run on a background thread when the +// kernel executes in the inter-op threadpool (typically because it depends on +// inter-op threadpool threads, e.g. for function execution), and (iii) can run +// synchronously on the calling thread when the caller donates a thread +// (typically in eager execution). The implementation avoids a thread-hop in +// case (iii). +// +// NOTE: Unlike typical OpKernel subclasses, the application logic is +// implemented in a method (DoCompute()) that returns Status. Use +// TF_RETURN_IF_ERROR for error-related control flow rather than +// OP_REQUIRES_OK(). 
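As a purely illustrative reading of the note above, a subclass of the HybridAsyncOpKernel class declared just below keeps its logic in DoCompute() and propagates failures by returning a non-OK Status; the class name and background worker name here are invented for the sketch.

// Sketch only: a hypothetical subclass following the DoCompute() contract
// described above; not part of iterator_ops.h.
class ExampleHybridOp : public HybridAsyncOpKernel {
 public:
  explicit ExampleHybridOp(OpKernelConstruction* ctx)
      : HybridAsyncOpKernel(ctx, /*background_worker_name=*/"tf_data_example") {}

 protected:
  absl::Status DoCompute(OpKernelContext* ctx) override {
    // Errors are returned (explicit returns or TF_RETURN_IF_ERROR) rather than
    // reported through OP_REQUIRES_OK(), as the note above prescribes.
    if (ctx->num_inputs() < 1) {
      return errors::InvalidArgument("ExampleHybridOp expects at least one input");
    }
    return absl::OkStatus();
  }
};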
+class HybridAsyncOpKernel : public AsyncOpKernel { + public: + HybridAsyncOpKernel(OpKernelConstruction* ctx, + const char* background_worker_name); + + void Compute(OpKernelContext* ctx) final; + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) final; + + protected: + virtual absl::Status DoCompute(OpKernelContext* ctx) = 0; + + private: + BackgroundWorker background_worker_; +}; + +class MakeIteratorOp : public HybridAsyncOpKernel { + public: + explicit MakeIteratorOp(OpKernelConstruction* ctx) + : HybridAsyncOpKernel(ctx, "tf_data_make_iterator") {} + + protected: + absl::Status DoCompute(OpKernelContext* ctx) override; +}; + +class IteratorGetNextOp : public HybridAsyncOpKernel { + public: + explicit IteratorGetNextOp(OpKernelConstruction* ctx) + : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + AsyncOpKernel* AsAsync() override; + + protected: + absl::Status DoCompute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +class IteratorGetModelProtoOp : public HybridAsyncOpKernel { + public: + explicit IteratorGetModelProtoOp(OpKernelConstruction* ctx) + : HybridAsyncOpKernel( + ctx, + /*background_worker_name=*/"tf_data_iterator_get_model_proto") {} + + protected: + absl::Status DoCompute(OpKernelContext* ctx) override; +}; + +class DeleteIteratorOp : public HybridAsyncOpKernel { + public: + explicit DeleteIteratorOp(OpKernelConstruction* ctx) + : HybridAsyncOpKernel(ctx, "tf_data_delete_iterator") {} + + protected: + absl::Status DoCompute(OpKernelContext* ctx) override; +}; + +class IteratorGetNextAsOptionalOp : public HybridAsyncOpKernel { + public: + explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) + : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next_as_optional") { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + absl::Status DoCompute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +class IteratorToStringHandleOp : public OpKernel { + public: + explicit IteratorToStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +class IteratorFromStringHandleOp : public OpKernel { + public: + explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_dtypes_; + std::vector output_shapes_; +}; + +class SerializeIteratorOp : public OpKernel { + public: + static constexpr const char* const kExternalStatePolicy = + "external_state_policy"; + + explicit SerializeIteratorOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + ExternalStatePolicy external_state_policy_ = ExternalStatePolicy::POLICY_WARN; +}; + +class DeserializeIteratorOp : public OpKernel { + public: + explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/map_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/map_dataset_op.h new file mode 100644 index 
00000000..dff288d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/map_dataset_op.h @@ -0,0 +1,57 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_MAP_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_MAP_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class MapDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Map"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kUseInterOpParallelism = + "use_inter_op_parallelism"; + static constexpr const char* const kPreserveCardinality = + "preserve_cardinality"; + static constexpr const char* const kForceSynchronous = "force_synchronous"; + + explicit MapDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + std::shared_ptr func_metadata_ = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; + bool preserve_cardinality_; + bool force_synchronous_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_MAP_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/map_defun_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/map_defun_op.h new file mode 100644 index 00000000..fc4adde9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/map_defun_op.h @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +// This op runs a given defun on slices of the input arguments. The function +// given by "f" is assumed to be stateless, and is executed concurrently +// on all the slices; up to batch_size (i.e. the 0th dimension of each argument) +// functions will be scheduled at once. +// +// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to +// limit the intra op parallelism. To limit inter-op parallelism, a user +// can set a private threadpool on the dataset using `tf.data.Options`'s +// `ThreadingOptions`. +// +// Note that this op is not exposed to users directly, but is invoked in +// tf.data rewrites. +class MapDefunOp : public AsyncOpKernel { + public: + static constexpr const char* const kArguments = "arguments"; + static constexpr const char* const kCapturedInputs = "captured_inputs"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kTcaptured = "Tcaptured"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kMaxIntraOpParallelism = + "max_intra_op_parallelism"; + + explicit MapDefunOp(OpKernelConstruction* ctx); + + ~MapDefunOp() override = default; + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + struct ComputeOptions; + class MapFunctionCallFrame; + + void SetRunOptions(OpKernelContext* ctx, + FunctionLibraryRuntime::Options* opts, + ComputeOptions* compute_opts, bool always_collect_stats); + + // Get inputs to Compute and check that they are valid. + absl::Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts); + + absl::Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts); + + FunctionLibraryRuntime::Handle func_handle_; + std::vector output_shapes_; + // If this value is positive, limit the max intra op parallelism when the + // function is run on slices of the input. + int max_intra_op_parallelism_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/model_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/model_dataset_op.h new file mode 100644 index 00000000..a6198414 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/model_dataset_op.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
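The MapDefunOp comment above describes the op's semantics: the function f is applied independently to each slice along dimension 0 of the arguments, with up to batch_size invocations scheduled at once and intra-op parallelism per invocation capped by max_intra_op_parallelism. As a rough picture only, here is a plain-C++ analogy of "run f on slices of the arguments"; it is not TensorFlow API and the names are made up.

// Sketch only: each argument's outer vector plays the role of dimension 0.
#include <functional>
#include <vector>

template <typename In, typename Out>
std::vector<Out> MapOverDim0(const std::vector<In>& arguments,
                             const std::function<Out(const In&)>& f) {
  std::vector<Out> outputs;
  outputs.reserve(arguments.size());  // one output per slice (batch_size slices)
  for (const In& slice : arguments) {
    // MapDefunOp would schedule up to batch_size of these calls concurrently.
    outputs.push_back(f(slice));
  }
  return outputs;
}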
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_MODEL_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_MODEL_DATASET_OP_H_ + +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide model dataset op because not all of its +// dependencies are available there. The op is replaced with a no-op. +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" + +namespace tensorflow { +namespace data { + +class ModelDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "ModelDataset"; + static constexpr const char* const kDatasetOp = "ModelDatasetOp"; + static constexpr const char* const kAlgorithm = "algorithm"; + static constexpr const char* const kCpuBudget = "cpu_budget"; + static constexpr const char* const kRamBudget = "ram_budget"; + + // Executes the logic of the ModelDatasetOp directly (as opposed to through + // executing the ModelDatasetOp op kernel). + static void MakeDatasetFromOptions(OpKernelContext* ctx, DatasetBase* input, + model::AutotuneAlgorithm algorithm, + int64_t cpu_budget, int64_t ram_budget, + DatasetBase** output); + + explicit ModelDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + + model::AutotuneAlgorithm algorithm_; + int64_t cpu_budget_; + int64_t ram_budget_; +}; + +} // namespace data +} // namespace tensorflow +#else // !IS_MOBILE_PLATFORM +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ModelDatasetOp : public UnaryDatasetOpKernel { + public: + // Creates and returns a ModelDatasetOp::Dataset in output, given the + // input, algorithm, cpu_budget and ram_budget parameters. This method is used + // to create the dataset without explicitly using the ModelDatasetOp. + static void MakeDatasetFromOptions(OpKernelContext* ctx, DatasetBase* input, + model::AutotuneAlgorithm algorithm, + bool cpu_budget, bool ram_budget, + DatasetBase** output); + + explicit ModelDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; +}; + +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM + +#endif // TENSORFLOW_CORE_KERNELS_DATA_MODEL_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/optimize_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/optimize_dataset_op.h new file mode 100644 index 00000000..1824fc5a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/optimize_dataset_op.h @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_ + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/platform/platform.h" + +// On mobile we do not provide optimize dataset op because not all of its +// dependencies are available there. The op is replaced with a no-op. +#if !defined(IS_MOBILE_PLATFORM) +namespace tensorflow { +namespace data { + +class OptimizeDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Optimize"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOptimizations = "optimizations"; + static constexpr const char* const kOptimizationsEnabled = + "optimizations_enabled"; + static constexpr const char* const kOptimizationsDisabled = + "optimizations_disabled"; + static constexpr const char* const kOptimizationsDefault = + "optimizations_default"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kOptimizationConfigs = + "optimization_configs"; + static constexpr const char* const kOptimizeDatasetV1 = "OptimizeDataset"; + static constexpr const char* const kOptimizeDatasetV2 = "OptimizeDatasetV2"; + + // Creates and returns a OptimizeDatasetOp::Dataset in output, given the + // default optimizations and those that are enabled, disabled. This method is + // used to create the dataset without explicitly using the OptimizeDatasetOp. + static void MakeDatasetFromOptions( + OpKernelContext* ctx, DatasetBase* input, + const absl::flat_hash_set& optimizations_enabled, + const absl::flat_hash_set& optimizations_disabled, + const absl::flat_hash_set& optimizations_default, + const absl::flat_hash_set& optimization_configs, + DatasetBase** output); + + explicit OptimizeDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + absl::flat_hash_set optimization_configs_; + int op_version_ = 0; +}; + +} // namespace data +} // namespace tensorflow +#else // !IS_MOBILE_PLATFORM +namespace tensorflow { +namespace data { + +class OptimizeDatasetOp : public UnaryDatasetOpKernel { + public: + // Executes the logic of the OptimizeDatasetOp directly (as opposed to through + // executing the OptimizeDatasetOp op kernel). 
+ static void MakeDatasetFromOptions( + OpKernelContext* ctx, DatasetBase* input, + const absl::flat_hash_set& optimizations_enabled, + const absl::flat_hash_set& optimizations_disabled, + const absl::flat_hash_set& optimizations_default, + const absl::flat_hash_set& optimization_configs, + DatasetBase** output); + + explicit OptimizeDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; +}; + +} // namespace data +} // namespace tensorflow +#endif // !IS_MOBILE_PLATFORM + +#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/optional_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/optional_ops.h new file mode 100644 index 00000000..8006b00b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/optional_ops.h @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/kernels/data/optional_ops_util.h" +#include "tensorflow/core/util/tensor_ops_util.h" + +namespace tensorflow { +namespace data { + +// Stores a DT_VARIANT value representing an Optional with the given value +// in the `output_index`^th output of the given kernel execution context. +absl::Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, + int output_index, + std::vector value); + +// Stores a DT_VARIANT value representing an Optional with no value +// in the `output_index`^th output of the given kernel execution context. 
+absl::Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); + +template +absl::Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, + OptionalVariant* y) { + return OptionalZerosLike(ctx, x, y, ZerosLikeTensor); +} + +template +absl::Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a, + const OptionalVariant& b, OptionalVariant* out) { + return OptionalBinaryAdd(ctx, a, b, out, BinaryAddTensors); +} + +class OptionalNoneOp : public OpKernel { + public: + explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +class OptionalFromValueOp : public OpKernel { + public: + explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +class OptionalHasValueOp : public OpKernel { + public: + explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +class OptionalGetValueOp : public OpKernel { + public: + explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES( + ctx, output_shapes_.size() == output_types_.size(), + errors::InvalidArgument( + "output_types and output_shapes must be same length, got:\n", + "output_types: ", output_types_.size(), "\n", + "output_shapes: ", output_shapes_.size())); + } + + void Compute(OpKernelContext* ctx) override; + + private: + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/optional_ops_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/optional_ops_util.h new file mode 100644 index 00000000..3ee3742f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/optional_ops_util.h @@ -0,0 +1,117 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_UTIL_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/util/tensor_ops_util.h" + +namespace tensorflow { +namespace data { + +const char kOptionalVariantTypeName[] = "tensorflow::data::Optional"; + +// An `OptionalVariant` can represent either an "actual value" (a tuple of +// tensors) or "none", and may be stored in a DT_VARIANT tensor. +class OptionalVariant { + public: + // Create an `OptionalVariant` with no actual value. 
+ OptionalVariant() : values_(nullptr) {} + + // Create an `OptionalVariant` with the actual value given by the tuple of + // tensors in `values`. + explicit OptionalVariant(std::vector values) { + values_ = std::make_shared>(std::move(values)); + } + + OptionalVariant(const OptionalVariant& other) : values_(other.values_) {} + + // Returns true if `this` represents an actual value. + bool has_value() const { return values_ != nullptr; } + + // REQUIRES: `this->has_value()` must be true. + const std::vector& get_values() const { + DCHECK(values_) << "Tried to get values from an empty OptionalVariant"; + return *values_; + } + + // Implementations of the necessary methods for using `OptionalVariant` + // objects in DT_VARIANT tensors. + string TypeName() const { return kOptionalVariantTypeName; } + void Encode(VariantTensorData* data) const { + data->set_metadata(values_ != nullptr); + if (values_ != nullptr) { + for (const auto& t : *values_) { + *(data->add_tensors()) = t; + } + } + } + + bool Decode(const VariantTensorData& data) { + if (data.type_name() != TypeName()) { + return false; + } + bool has_value = false; + if (!data.get_metadata(&has_value)) { + return false; + } + if (has_value) { + values_ = std::make_shared>(data.tensors()); + } else { + values_.reset(); + } + return true; + } + + string DebugString() const { + if (values_) { + return strings::StrCat("OptionalVariant<", "values: (", + absl::StrJoin(*values_, ", ", + [](string* s, const Tensor& elem) { + *s = elem.DebugString(); + }), + ")>"); + } else { + return strings::StrCat("OptionalVariant"); + } + } + + private: + std::shared_ptr> values_; +}; + +absl::Status OptionalZerosLike( + OpKernelContext* ctx, const OptionalVariant& x, OptionalVariant* y, + std::function + zeros_like_func); + +absl::Status OptionalBinaryAdd( + OpKernelContext* ctx, const OptionalVariant& a, const OptionalVariant& b, + OptionalVariant* out, + std::function + binary_add_func); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/options_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/options_dataset_op.h new file mode 100644 index 00000000..024ae757 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/options_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONS_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_OPTIONS_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +// TODO(jsimsa): Provide class-level documentation for this and the other ops. 
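OptionalVariant, defined above, is built to live inside DT_VARIANT tensors via its TypeName()/Encode()/Decode() members. The following round-trip is a sketch under that assumption; the function name is invented and it presumes the enclosing tensorflow::data namespace plus the vendored headers above.

// Sketch only: store an OptionalVariant in a DT_VARIANT scalar tensor and
// read it back through the Variant accessor.
void OptionalVariantRoundTrip() {
  Tensor value(DT_FLOAT, TensorShape({}));
  value.scalar<float>()() = 1.0f;

  OptionalVariant with_value(std::vector<Tensor>{value});  // has_value() == true

  Tensor variant_tensor(DT_VARIANT, TensorShape({}));
  variant_tensor.scalar<Variant>()() = with_value;

  const OptionalVariant* read =
      variant_tensor.scalar<Variant>()().get<OptionalVariant>();
  if (read != nullptr && read->has_value()) {
    const Tensor& restored = read->get_values()[0];  // the float scalar above
    (void)restored;
  }
}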
+class OptionsDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Options"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kSerializedOptions = "serialized_options"; + + explicit OptionsDatasetOp(OpKernelConstruction* ctx); + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + tstring serialized_options_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONS_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/padded_batch_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/padded_batch_dataset_op.h new file mode 100644 index 00000000..474587db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/padded_batch_dataset_op.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PADDED_BATCH_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PADDED_BATCH_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "PaddedBatch"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kBatchSize = "batch_size"; + static constexpr const char* const kPaddedShapes = "padded_shapes"; + static constexpr const char* const kPaddingValues = "padding_values"; + static constexpr const char* const kDropRemainder = "drop_remainder"; + static constexpr const char* const kParallelCopy = "parallel_copy"; + static constexpr const char* const kToutputTypes = "Toutput_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kNumPaddedShapes = "N"; + + explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + const int op_version_; + bool parallel_copy_ = false; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PADDED_BATCH_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_batch_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_batch_dataset_op.h new file mode 100644 index 00000000..219dc73c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_batch_dataset_op.h @@ -0,0 +1,51 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_ + +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ParallelBatchDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "ParallelBatch"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kBatchSize = "batch_size"; + static constexpr const char* const kNumParallelCalls = "num_parallel_calls"; + static constexpr const char* const kDropRemainder = "drop_remainder"; + static constexpr const char* const kParallelCopy = "parallel_copy"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kDeterministic = "deterministic"; + + explicit ParallelBatchDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + DeterminismPolicy deterministic_; + bool parallel_copy_ = false; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_BATCH_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_filter_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_filter_dataset_op.h new file mode 100644 index 00000000..48b1bda1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_filter_dataset_op.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_FILTER_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_FILTER_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ParallelFilterDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "ParallelFilter"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kNumParallelCalls = "num_parallel_calls"; + static constexpr const char* const kPredicate = "predicate"; + static constexpr const char* const kDeterministic = "deterministic"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit ParallelFilterDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + DeterminismPolicy deterministic_; + std::shared_ptr func_metadata_ = nullptr; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_FILTER_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_interleave_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_interleave_dataset_op.h new file mode 100644 index 00000000..be46a360 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_interleave_dataset_op.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_INTERLEAVE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_INTERLEAVE_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "ParallelInterleave"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kCycleLength = "cycle_length"; + static constexpr const char* const kBlockLength = "block_length"; + static constexpr const char* const kBufferOutputElements = + "buffer_output_elements"; + static constexpr const char* const kPrefetchInputElements = + "prefetch_input_elements"; + static constexpr const char* const kNumParallelCalls = "num_parallel_calls"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kDeterministic = "deterministic"; + static constexpr const char* const kSloppy = "sloppy"; + + explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + const int op_version_; + std::shared_ptr func_metadata_ = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; + DeterminismPolicy deterministic_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_INTERLEAVE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_map_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_map_dataset_op.h new file mode 100644 index 00000000..efdf6339 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/parallel_map_dataset_op.h @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ParallelMapDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "ParallelMap"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kOtherArguments = "other_arguments"; + static constexpr const char* const kNumParallelCalls = "num_parallel_calls"; + static constexpr const char* const kFunc = "f"; + static constexpr const char* const kTarguments = "Targuments"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kUseInterOpParallelism = + "use_inter_op_parallelism"; + static constexpr const char* const kDeterministic = "deterministic"; + static constexpr const char* const kSloppy = "sloppy"; + static constexpr const char* const kPreserveCardinality = + "preserve_cardinality"; + static constexpr const char* const kUseUnboundedThreadpool = + "use_unbounded_threadpool"; + + explicit ParallelMapDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + const int op_version_; + std::shared_ptr func_metadata_ = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; + bool sloppy_; + bool preserve_cardinality_; + DeterminismPolicy deterministic_; + bool use_unbounded_threadpool_; + + friend std::unique_ptr MakeDataServiceUncompressDataset( + DatasetBase* input, std::unique_ptr captured_function, + const DataTypeVector& output_types, + const std::vector& output_shapes); +}; + +// Used by tf.data service to create a map dataset for uncompression. +std::unique_ptr MakeDataServiceUncompressDataset( + DatasetBase* input, std::unique_ptr captured_function, + const DataTypeVector& output_types, + const std::vector& output_shapes); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/prefetch_autotuner.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/prefetch_autotuner.h new file mode 100644 index 00000000..a06eb60f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/prefetch_autotuner.h @@ -0,0 +1,86 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace data { + +// PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator. +// +// PrefetchAutotuner attempts to find the minimum buffer size such that there is +// always at least 1 element in the prefetch queue every time the downstream +// iterator calls GetNext(). +// +// One common failure mode of input pipelines is being throughput bound. No +// amount of prefetching can address that performance mode. In order to guard +// against this condition, PrefetchAutotuner will only increase the buffer_limit +// if the prefetching thread is able to successfully fill the buffer at its +// current size. +// +// Note: in the current implementation, we never decrease the buffer_limit(). +// This should change in the future! +// +// PrefetchAutotuner is NOT thread safe. +class PrefetchAutotuner { + public: + explicit PrefetchAutotuner( + int64_t initial_buffer_size, int64_t buffer_size_min, + std::shared_ptr ram_budget_manager); + + int64_t buffer_limit() const { return buffer_limit_; } + + // Reports whether the element size has been set. + bool HasElementSize() const { return element_size_bytes_.has_value(); } + // Sets the element size to use for predicting memory usage. Element size must + // be set before the autotuner can increase the buffer size. + void SetElementSize(int64_t element_size_bytes); + void RecordConsumption(size_t current_buffer_size); + void RecordEmpty() { RecordConsumption(0); } + + private: + // PrefetchAutotuner operates as a state machine. + enum class Mode { + // Disables the autotuning. + kDisabled, + + // We have increased the size of the buffer, and will transition to + // kDownswing if we successfully fill the buffer. + kUpswing, + + // We have successfully filled a buffer of this size. If we ever block the + // downstream iterator, we should increase the buffer size. + kDownswing, + }; + + int64_t buffer_limit_; + // Estimated per-element size. + std::optional element_size_bytes_; + Mode mode_ = Mode::kDisabled; + std::shared_ptr ram_budget_manager_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/prefetch_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/prefetch_dataset_op.h new file mode 100644 index 00000000..e193e75e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/prefetch_dataset_op.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
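The comment block above frames PrefetchAutotuner as a state machine: kUpswing waits for the prefetch thread to fill the buffer at the current limit, and kDownswing grows the limit only once the consumer actually finds the buffer empty. The standalone caricature below is one plausible reading of that policy, not the real implementation; the growth factor and the omission of the RAM budget manager and element size are assumptions of the sketch.

// Sketch only: simplified upswing/downswing policy.
#include <cstddef>
#include <cstdint>

namespace sketch {

enum class Mode { kDisabled, kUpswing, kDownswing };

struct Autotuner {
  int64_t buffer_limit = 1;
  Mode mode = Mode::kUpswing;

  void RecordConsumption(size_t current_buffer_size) {
    switch (mode) {
      case Mode::kDisabled:
        return;
      case Mode::kUpswing:
        // The buffer reached the current limit: prefetching keeps up, so wait
        // for evidence that the consumer is being starved.
        if (static_cast<int64_t>(current_buffer_size) == buffer_limit) {
          mode = Mode::kDownswing;
        }
        return;
      case Mode::kDownswing:
        // The consumer saw an empty buffer: raise the limit and require the
        // larger buffer to be filled before raising it again.
        if (current_buffer_size == 0) {
          buffer_limit *= 2;  // growth policy is an assumption of this sketch
          mode = Mode::kUpswing;
        }
        return;
    }
  }
};

}  // namespace sketch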
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/kernels/data/prefetch_autotuner.h" + +namespace tensorflow { +namespace data { + +class PrefetchDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Prefetch"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kBufferSize = model::kBufferSize; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kSlackPeriod = "slack_period"; + static constexpr const char* const kLegacyAutotune = "legacy_autotune"; + static constexpr const char* const kBufferSizeMin = "buffer_size_min"; + + explicit PrefetchDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + int64_t slack_period_ = 0; + bool legacy_autotune_ = true; + int64_t buffer_size_min_ = 0; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/random_seed_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/random_seed_ops.h new file mode 100644 index 00000000..f0afa739 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/random_seed_ops.h @@ -0,0 +1,160 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_ + +#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { +namespace data { + +// Represents a pair of random seeds. By TensorFlow convention, if both seeds +// are 0, then pseudo-random values are used instead. +class RandomSeeds { + public: + RandomSeeds(int64_t seed, int64_t seed2) + : input_seed_(seed), + input_seed2_(seed2), + seed_((seed | seed2) == 0 ? random::New64() : seed), + seed2_((seed | seed2) == 0 ? 
random::New64() : seed2) {} + + int64_t input_seed() const { return input_seed_; } + int64_t input_seed2() const { return input_seed2_; } + int64_t seed() const { return seed_; } + int64_t seed2() const { return seed2_; } + + private: + const int64_t input_seed_; + const int64_t input_seed2_; + const int64_t seed_; + const int64_t seed2_; +}; + +// Base class for seed generator resources. Subclasses customize how seeds are +// generated. +class SeedGenerator { + public: + virtual ~SeedGenerator() {} + + virtual int64_t seed() const = 0; + virtual int64_t seed2() const = 0; + virtual bool reshuffle_each_iteration() const = 0; + + virtual void GenerateSeeds(int64_t* seed1, int64_t* seed2) = 0; + virtual void Reset() = 0; + + virtual int64_t num_random_samples() const { + tf_shared_lock l(mu_); + return num_random_samples_; + } + virtual void set_num_random_samples(int64_t num_random_samples) { + mutex_lock l(mu_); + num_random_samples_ = num_random_samples; + } + + protected: + mutable mutex mu_; + int64_t num_random_samples_ TF_GUARDED_BY(mu_) = 0; +}; + +// A resource wrapping a shared instance of a seed generator. +class SeedGeneratorManager : public ResourceBase { + public: + explicit SeedGeneratorManager(SeedGenerator* seed_generator) + : seed_generator_(seed_generator) {} + + std::string DebugString() const override; + + std::shared_ptr get() { return seed_generator_; } + + private: + std::shared_ptr seed_generator_; +}; + +// Always generates the specified seed values. +class FixedSeedGenerator : public SeedGenerator { + public: + explicit FixedSeedGenerator(RandomSeeds seeds) : seeds_(std::move(seeds)) {} + + int64_t seed() const override { return seeds_.seed(); } + int64_t seed2() const override { return seeds_.seed(); } + bool reshuffle_each_iteration() const override { return false; } + + void GenerateSeeds(int64_t* seed1, int64_t* seed2) override; + void Reset() override {} + + private: + const RandomSeeds seeds_; +}; + +// Generates different (but deterministically chosen) seed values. +class RandomSeedGenerator : public SeedGenerator { + public: + explicit RandomSeedGenerator(RandomSeeds seeds) + : seeds_(std::move(seeds)), + parent_generator_(seeds_.seed(), seeds_.seed2()), + generator_(&parent_generator_) {} + + int64_t seed() const override { return seeds_.seed(); } + int64_t seed2() const override { return seeds_.seed2(); } + bool reshuffle_each_iteration() const override { return true; } + + void GenerateSeeds(int64_t* seed1, int64_t* seed2) override; + void Reset() override; + + private: + const RandomSeeds seeds_; + random::PhiloxRandom parent_generator_ TF_GUARDED_BY(mu_); + random::SingleSampleAdapter generator_ + TF_GUARDED_BY(mu_); +}; + +// Creates an instance of seed generator resource and transfers ownership +// to the caller. +class AnonymousSeedGeneratorHandleOp + : public AnonymousResourceOp { + public: + explicit AnonymousSeedGeneratorHandleOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + + private: + string name() override; + absl::Status CreateResource( + OpKernelContext* ctx, std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, SeedGeneratorManager** manager) override; + + mutex mu_; + std::unique_ptr seeds_ TF_GUARDED_BY(mu_); + bool reshuffle_; +}; + +// Deletes an instance of seed generator resource. 
+class DeleteSeedGeneratorOp : public OpKernel { + public: + explicit DeleteSeedGeneratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_RANDOM_SEED_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/range_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/range_dataset_op.h new file mode 100644 index 00000000..687f2eb6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/range_dataset_op.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_RANGE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_RANGE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class RangeDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Range"; + static constexpr const char* const kStart = "start"; + static constexpr const char* const kStop = "stop"; + static constexpr const char* const kStep = "step"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kReplicateOnSplit = "replicate_on_split"; + + explicit RangeDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + class RangeSplitProvider; + DataTypeVector output_types_; + bool replicate_on_split_ = false; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_RANGE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/reduce_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/reduce_dataset_op.h new file mode 100644 index 00000000..73e18144 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/reduce_dataset_op.h @@ -0,0 +1,43 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_REDUCE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_REDUCE_DATASET_OP_H_ + +#include "tensorflow/core/data/captured_function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/data/iterator_ops.h" + +namespace tensorflow { +namespace data { + +class ReduceDatasetOp : public HybridAsyncOpKernel { + public: + explicit ReduceDatasetOp(OpKernelConstruction* ctx); + + protected: + absl::Status DoCompute(OpKernelContext* ctx) override; + + std::shared_ptr func_metadata_ = nullptr; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_REDUCE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/repeat_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/repeat_dataset_op.h new file mode 100644 index 00000000..81d534f7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/repeat_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_REPEAT_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_REPEAT_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class RepeatDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Repeat"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kCount = "count"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit RepeatDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_REPEAT_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/rewrite_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/rewrite_dataset_op.h new file mode 100644 index 00000000..cd9b34b4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/rewrite_dataset_op.h @@ -0,0 +1,41 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_REWRITE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_REWRITE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class RewriteDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Rewrite"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kRewriteName = "rewrite_name"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit RewriteDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_REWRITE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/shard_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/shard_dataset_op.h new file mode 100644 index 00000000..acdf171a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/shard_dataset_op.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SHARD_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SHARD_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ShardDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Shard"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kNumShards = "num_shards"; + static constexpr const char* const kIndex = "index"; + static constexpr const char* const kRequireNonEmpty = "require_non_empty"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit ShardDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + bool require_non_empty_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SHARD_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/shuffle_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/shuffle_dataset_op.h new file mode 100644 index 00000000..f33f75c8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/shuffle_dataset_op.h @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SHUFFLE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SHUFFLE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kBufferSize = "buffer_size"; + static constexpr const char* const kSeed = "seed"; + static constexpr const char* const kSeed2 = "seed2"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kReshuffleEachIteration = + "reshuffle_each_iteration"; + + explicit ShuffleDatasetOpBase(OpKernelConstruction* ctx); + + protected: + class ShuffleDatasetBase; +}; + +class ShuffleDatasetOp : public ShuffleDatasetOpBase { + public: + static constexpr const char* const kDatasetType = "Shuffle"; + + explicit ShuffleDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + class DatasetV2; + class DatasetV3; + int op_version_ = 0; + bool reshuffle_each_iteration_ = true; +}; + +class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase { + public: + static constexpr const char* const kDatasetType = "ShuffleAndRepeat"; + static constexpr const char* const kCount = "count"; + + explicit ShuffleAndRepeatDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; + class DatasetV2; + int op_version_ = 0; + bool reshuffle_each_iteration_ = true; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SHUFFLE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/skip_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/skip_dataset_op.h new file mode 100644 index 00000000..6e22d7af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/skip_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SKIP_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SKIP_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class SkipDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Skip"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kCount = "count"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit SkipDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SKIP_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/take_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/take_dataset_op.h new file mode 100644 index 00000000..de51d6a4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/take_dataset_op.h @@ -0,0 +1,90 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { + +class TakeDataset : public DatasetBase { + public: + TakeDataset(OpKernelContext* ctx, int64_t count, const DatasetBase* input); + + TakeDataset(DatasetContext::Params params, int64_t count, + const DatasetBase* input); + + ~TakeDataset() override; + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override; + + const DataTypeVector& output_dtypes() const override; + + const std::vector& output_shapes() const override; + + string DebugString() const override; + + int64_t CardinalityInternal(CardinalityOptions options) const override; + + absl::Status InputDatasets( + std::vector* inputs) const override; + + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override; + + absl::Status CheckExternalState() const override; + + absl::Status RandomIndexingCompatible() const override; + + protected: + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; + + private: + class EmptyIterator; + class FiniteIterator; + const int64_t count_; + const DatasetBase* const input_; + absl::Status random_indexing_compatible_; +}; + +class TakeDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Take"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kCount = "count"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit TakeDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_TAKE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/tensor_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/tensor_dataset_op.h new file mode 100644 index 00000000..dcd738e9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/tensor_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_TENSOR_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_TENSOR_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class TensorDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Tensor"; + static constexpr const char* const kComponents = "components"; + static constexpr const char* const kToutput_types = "Toutput_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit TensorDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + DataTypeVector output_types_; + std::vector output_shapes_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_TENSOR_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/tensor_slice_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/tensor_slice_dataset_op.h new file mode 100644 index 00000000..c2ddbaf1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/tensor_slice_dataset_op.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_TENSOR_SLICE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_TENSOR_SLICE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class TensorSliceDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "TensorSlice"; + static constexpr const char* const kComponents = "components"; + static constexpr const char* const kToutputTypes = "Toutput_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kIsFiles = "is_files"; + static constexpr const char* const kReplicateOnSplit = "replicate_on_split"; + + explicit TensorSliceDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + DataTypeVector output_types_; + std::vector output_shapes_; + bool is_files_ = false; + bool replicate_on_split_ = false; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_TENSOR_SLICE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/text_line_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/text_line_dataset_op.h new file mode 100644 index 00000000..3621b57a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/text_line_dataset_op.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_TEXT_LINE_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_TEXT_LINE_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class TextLineDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "TextLine"; + static constexpr const char* const kFileNames = "filenames"; + static constexpr const char* const kCompressionType = "compression_type"; + static constexpr const char* const kBufferSize = "buffer_size"; + + explicit TextLineDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_TEXT_LINE_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/tf_record_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/tf_record_dataset_op.h new file mode 100644 index 00000000..0cfbc667 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/tf_record_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" + +namespace tensorflow { +namespace data { + +class TFRecordDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "TFRecord"; + static constexpr const char* const kFileNames = "filenames"; + static constexpr const char* const kCompressionType = "compression_type"; + static constexpr const char* const kBufferSize = "buffer_size"; + static constexpr const char* const kByteOffsets = "byte_offsets"; + + explicit TFRecordDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + int op_version_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_TF_RECORD_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/window_dataset.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/window_dataset.h new file mode 100644 index 00000000..17e5b6b5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/window_dataset.h @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ + +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace data { + +// Creates a dataset representing an eagerly-collected window of elements. +// +// The `elements` argument defines the elements of the resulting +// dataset, which is stored in `out_dataset`. +// +// This dataset is constructed internally for use in datasets that +// build nested dataset expressions (e.g. the reducer function for +// GroupByWindowDataset). It efficiently supports multiple iterators on +// the same window without recomputation. +// +// REQUIRES: `output_types` must match the types of the respective +// element components in `elements`. 
+// REQUIRES: `output_shapes` must be compatible with the shapes of the +// respective element components in `elements`. +absl::Status NewWindow(std::vector<std::vector<Tensor>> elements, + DataTypeVector output_types, + std::vector<PartialTensorShape> output_shapes, + DatasetBase** out_dataset); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/window_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/window_dataset_op.h new file mode 100644 index 00000000..241e0f51 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/window_dataset_op.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_OP_H_ + +#include <vector> + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace data { + +class WindowDatasetOp : public UnaryDatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Window"; + static constexpr const char* const kInputDataset = "input_dataset"; + static constexpr const char* const kSize = "size"; + static constexpr const char* const kShift = "shift"; + static constexpr const char* const kStride = "stride"; + static constexpr const char* const kDropRemainder = "drop_remainder"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + + explicit WindowDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data/zip_dataset_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data/zip_dataset_op.h new file mode 100644 index 00000000..1e6b294b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data/zip_dataset_op.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
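Since window_dataset.h above only declares NewWindow(), a hypothetical call site may help make its contract concrete. The sketch below is not code from this patch: the MakeTwoElementWindow helper and the single-scalar-int32 element layout are assumptions chosen so that the output_types and output_shapes arguments line up with the REQUIRES comments above.

// Hypothetical usage sketch of NewWindow(); not part of the vendored headers.
#include <utility>
#include <vector>

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/window_dataset.h"

namespace tensorflow {
namespace data {

absl::Status MakeTwoElementWindow(DatasetBase** out_window) {
  // Each inner vector is one element of the window; here every element has a
  // single scalar int32 component, matching output_types/output_shapes below.
  std::vector<std::vector<Tensor>> elements;
  for (int32_t v : {1, 2}) {
    Tensor t(DT_INT32, TensorShape({}));
    t.scalar<int32_t>()() = v;
    std::vector<Tensor> element;
    element.push_back(std::move(t));
    elements.push_back(std::move(element));
  }
  return NewWindow(std::move(elements),
                   /*output_types=*/{DT_INT32},
                   /*output_shapes=*/{PartialTensorShape({})},
                   out_window);
}

}  // namespace data
}  // namespace tensorflow

Because the resulting dataset holds its elements eagerly in memory, multiple iterators over the same window avoid recomputation, which is the property the header comment calls out.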
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_ZIP_DATASET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_ZIP_DATASET_OP_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { + +class ZipDatasetOp : public DatasetOpKernel { + public: + static constexpr const char* const kDatasetType = "Zip"; + static constexpr const char* const kInputDatasets = "input_datasets"; + static constexpr const char* const kOutputTypes = "output_types"; + static constexpr const char* const kOutputShapes = "output_shapes"; + static constexpr const char* const kNumInputDatasets = "N"; + + explicit ZipDatasetOp(OpKernelConstruction* ctx); + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_ZIP_DATASET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/data_format_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/data_format_ops.h new file mode 100644 index 00000000..3d4568d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/data_format_ops.h @@ -0,0 +1,113 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_ +// Functor definition for data format dim mapping ops, must be compilable +// by nvcc. +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by DataFormatDimMapOP to do the computations. 
+template +struct DataFormatDimMap { + void operator()(const Device& d, typename TTypes::ConstFlat x, + typename TTypes::Flat y, const TTypes::Vec dst) { + if (dst.size() == 4) { + auto zero = x.constant(0); + auto one = x.constant(1); + auto two = x.constant(2); + + auto f_zero = x.constant(dst(0)); + auto f_one = x.constant(dst(1)); + auto f_two = x.constant(dst(2)); + auto f_three = x.constant(dst(3)); + + auto four = x.constant(4); + auto x_mod = (x + four) % 4; + + auto is_zero = (x_mod == zero); + auto is_one = (x_mod == one); + auto is_two = (x_mod == two); + + y.device(d) = is_zero.select( + f_zero, is_one.select(f_one, is_two.select(f_two, f_three))); + } else { + auto zero = x.constant(0); + auto one = x.constant(1); + auto two = x.constant(2); + auto three = x.constant(3); + + auto f_zero = x.constant(dst(0)); + auto f_one = x.constant(dst(1)); + auto f_two = x.constant(dst(2)); + auto f_three = x.constant(dst(3)); + auto f_four = x.constant(dst(4)); + + auto five = x.constant(5); + auto x_mod = (x + five) % 5; + + auto is_zero = (x_mod == zero); + auto is_one = (x_mod == one); + auto is_two = (x_mod == two); + auto is_three = (x_mod == three); + + y.device(d) = is_zero.select( + f_zero, + is_one.select( + f_one, is_two.select(f_two, is_three.select(f_three, f_four)))); + } + } +}; + +template +struct VecPermute { + explicit VecPermute(const Eigen::DSizes& dst) + : dst(dst) {} + Eigen::DSizes dimensions( + typename TTypes::ConstFlat input) const { + Eigen::DSizes result; + result[0] = input.dimension(0); + return result; + } + template + void eval(typename TTypes::ConstFlat input, Output& output, + const Device& d) const { + for (int i = 0; i < input.size(); ++i) { + output.template chip<0>(dst[i]).device(d) = input.template chip<0>(i); + } + } + + private: + Eigen::DSizes dst; +}; + +// Functor used by DataFormatVecPermuteOp to do the computations. +template +struct DataFormatVecPermute { + void operator()(const Device& d, typename TTypes::ConstFlat x, + typename TTypes::Flat y, + const Eigen::DSizes& dst) { + y.device(d) = x.customOp(VecPermute(dst)); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/debug_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/debug_ops.h new file mode 100644 index 00000000..f417caf2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/debug_ops.h @@ -0,0 +1,959 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
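Element-wise, the 4-D branch of DataFormatDimMap above is a wrap-around table lookup: negative axis indices wrap modulo 4, then the dst table gives the axis position in the destination format. A plain-C++ restatement under stated assumptions (the MapDim4 name is illustrative, and the real functor operates on whole Eigen tensors rather than single indices), using the NHWC-to-NCHW permutation {0, 2, 3, 1} as an example:

// Illustrative only; not part of the vendored header above.
#include <array>
#include <cstdint>
#include <iostream>

int32_t MapDim4(int32_t axis, const std::array<int32_t, 4>& dst) {
  // Same wrap-around as the functor's (x + 4) % 4, followed by the lookup.
  return dst[(axis + 4) % 4];
}

int main() {
  // For src_format "NHWC" and dst_format "NCHW": N->0, H->2, W->3, C->1.
  const std::array<int32_t, 4> nhwc_to_nchw = {0, 2, 3, 1};
  std::cout << MapDim4(1, nhwc_to_nchw) << "\n";   // H in NHWC -> axis 2 in NCHW
  std::cout << MapDim4(-1, nhwc_to_nchw) << "\n";  // C (axis -1) -> axis 1 in NCHW
  return 0;
}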
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/bfloat16.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/common_runtime/gpu/gpu_util.h" +#include "tensorflow/core/util/determinism.h" +#endif + +#if GOOGLE_CUDA +#include "tensorflow/core/platform/cuda.h" +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/rocm.h" +#endif + +#include "tensorflow/core/debug/debug_io_utils.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/debug_events_writer.h" + +namespace tensorflow { + +// Copy op for debugging. +// Performs CPU-to-CPU or GPU-to-GPU deep-copying of tensor, depending on the +// device on which the tensor is allocated. +class CopyOp : public OpKernel { + public: + explicit CopyOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name_)); + + std::vector debug_ops_spec; + OP_REQUIRES_OK(context, + context->GetAttr("debug_ops_spec", &debug_ops_spec)); + for (const string& debug_op_spec : debug_ops_spec) { + // Assume debug_op_spec has the format + // ;;, e.g., + // DebugIdentity;grpc://localhost:3333;1 + const std::vector items = str_util::Split(debug_op_spec, ";"); + OP_REQUIRES( + context, items.size() == 3, + errors::Internal( + "Unexpected number of semicolons in debug_ops_spec element: ", + debug_op_spec)); + debug_op_and_url_specs_.push_back( + DebugWatchAndURLSpec(strings::StrCat(tensor_name_, ":", items[0]), + items[1], items[2] == "1")); + } + } + + void Compute(OpKernelContext* context) override { + const Tensor& src_tensor = context->input(0); + + if (src_tensor.IsInitialized() && + DataTypeCanUseMemcpy(src_tensor.dtype()) && + DebugIO::IsCopyNodeGateOpen(debug_op_and_url_specs_)) { + // Source tensor is initialized and is mem-copyable. Make a copy. + Tensor* copied_tensor; + OP_REQUIRES_OK(context, context->allocate_output(0, src_tensor.shape(), + &copied_tensor)); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + Device* device = static_cast(context->device()); + // Determine if the input tensor is not on CPU (e.g., on GPU). + bool off_host_input = device->device_type() == DEVICE_GPU && + !context->input_alloc_attr(0).on_host(); + + if (off_host_input) { + DeviceContext* device_ctxt = context->op_device_context(); + // Input is not on host: deep-copy it from GPU to the same GPU. + Notification done_copy; + GPUUtil::CopyGPUTensorToSameGPU( + device, device_ctxt, &src_tensor, copied_tensor, + [&done_copy](const Status& s) { done_copy.Notify(); }); + done_copy.WaitForNotification(); + } else { + // The input tensor is on the host (CPU): deep-copy from CPU to CPU. + *copied_tensor = tensor::DeepCopy(src_tensor); + } +#else + *copied_tensor = tensor::DeepCopy(src_tensor); +#endif + } else { + // Source tensor is NOT initialized and/or is not mem-copyable: Forward + // the Tensor object. + context->set_output(0, src_tensor); + } + } + + bool IsExpensive() override { return false; } + + private: + string tensor_name_; + std::vector debug_op_and_url_specs_; +}; + +// Base class of all debug ops. 
+class BaseDebugOp : public OpKernel { + public: + explicit BaseDebugOp(const string& debug_op_name, + OpKernelConstruction* context) + : OpKernel(context), debug_op_name_(debug_op_name) { + OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls_)); + OP_REQUIRES_OK(context, context->GetAttr("gated_grpc", &gated_grpc_)); + + string device_name; + string tensor_name; + OP_REQUIRES_OK(context, context->GetAttr("device_name", &device_name)); + OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name)); + + std::vector name_items = str_util::Split(tensor_name, ':'); + string node_name; + int32_t output_slot = 0; + OP_REQUIRES(context, name_items.size() == 1 || name_items.size() == 2, + errors::InvalidArgument("Failed to parse tensor name: \"", + tensor_name, "\"")); + if (name_items.size() == 2) { + node_name = name_items[0]; + OP_REQUIRES( + context, absl::SimpleAtoi(name_items[1], &output_slot), + errors::InvalidArgument("Invalid string value for output_slot: \"", + name_items[1], "\"")); + } else if (name_items.size() == 1) { + node_name = name_items[0]; + } + + debug_watch_key_.reset( + new DebugNodeKey(device_name, node_name, output_slot, debug_op_name_)); + } + + bool IsExpensive() override { return false; } + + protected: + // Apply gRPC gating (if gated_grpc_ attribute is true). + // + // Returns false if and only if all grpc:// debug URLs of the debug op are + // disabled currently (i.e., gated off), in which case the debug op will emit + // an empty (size {0}) tensor of undefined data type. + bool ApplyGrpcGating(OpKernelContext* context) { + if (gated_grpc_ && !DebugIO::IsDebugNodeGateOpen( + debug_watch_key_->debug_node_name, debug_urls_)) { + // The entire node is gated off: Output an empty tensor and avoid + // expensive computation. + Tensor* output_tensor; + TensorShape shape({0}); + if (!context->allocate_output(0, shape, &output_tensor).ok()) { + LOG(ERROR) << "Debug node of watch key " + << debug_watch_key_->debug_node_name + << " failed to allocate empty tensor under gated-off state."; + } + return false; + } else { + return true; + } + } + + // Publish a tensor to all debug URLs of the debug op. + // Log an error if the publishing failed. + absl::Status PublishTensor(const Tensor& tensor, int64_t step_id = -1) { + if (debug_urls_.empty()) { + return absl::OkStatus(); + } else { + absl::Status status = DebugIO::PublishDebugTensor( + *debug_watch_key_, tensor, Env::Default()->NowMicros(), debug_urls_, + gated_grpc_, step_id); + if (!status.ok()) { + LOG(ERROR) << "Debug node of watch key " + << debug_watch_key_->debug_node_name + << " failed to publish debug tensor data to all URLs " + << absl::StrJoin(debug_urls_, ", ") + << ", due to: " << status.message(); + } + return status; + } + } + + void CompleteDebugNodeKey(const string& io_of_node, bool is_input, + int io_index) { + debug_watch_key_ = std::make_unique( + debug_watch_key_->device_name, debug_watch_key_->node_name, + debug_watch_key_->output_slot, debug_op_name_, io_of_node, is_input, + io_index); + } + + private: + const string debug_op_name_; + std::unique_ptr debug_watch_key_; + std::vector debug_urls_; + bool gated_grpc_; +}; + +// Identity op for debugging. +// Output slot 0 carries the debug signal and is always allocated on the +// host (CPU) as a non-Ref tensor. In the case of DebugIdentityOp, +// the debug signal is equal to the input tensor. 
+class DebugIdentityOp : public BaseDebugOp { + public: + explicit DebugIdentityOp(OpKernelConstruction* context) + : BaseDebugOp("DebugIdentity", context) {} + + void Compute(OpKernelContext* context) override { + if (!ApplyGrpcGating(context)) { + return; + } + + OP_REQUIRES_OK(context, PublishTensor(context->input(0))); + context->set_output(0, context->input(0)); + } +}; + +// Identity op for debugging. +// Output slot 0 carries the debug signal and is always allocated on the +// host (CPU) as a non-Ref tensor. In the case of DebugIdentityOp, +// the debug signal is equal to the input tensor. +class DebugIdentityV3Op : public BaseDebugOp { + public: + explicit DebugIdentityV3Op(OpKernelConstruction* context) + : BaseDebugOp("DebugIdentityV3", context) { + string io_of_node; + bool is_input; + int io_index; + OP_REQUIRES_OK(context, context->GetAttr("io_of_node", &io_of_node)); + OP_REQUIRES_OK(context, context->GetAttr("is_input", &is_input)); + OP_REQUIRES_OK(context, context->GetAttr("io_index", &io_index)); + if (!io_of_node.empty()) { + CompleteDebugNodeKey(io_of_node, is_input, io_index); + } + } + + void Compute(OpKernelContext* context) override { + if (!ApplyGrpcGating(context)) { + return; + } + + OP_REQUIRES_OK(context, + PublishTensor(context->input(0), context->step_id())); + context->set_output(0, context->input(0)); + } +}; + +// NaN-counter op for debugging. +template +class DebugNanCountOp : public BaseDebugOp { + public: + explicit DebugNanCountOp(OpKernelConstruction* context) + : BaseDebugOp("DebugNanCount", context) {} + + void Compute(OpKernelContext* context) override { + if (!ApplyGrpcGating(context)) { + return; + } + + Tensor* output_tensor; + const Tensor& input = context->input(0); + + // Use DT_INT64/int64 to be consistent with TensorShape::num_elements(). + int64_t nan_count = 0; + + // If the input is an uninitialized tensor, let nan_count be 0. + if (input.IsInitialized()) { + // Count NaNs. + const TensorShape& input_shape = input.shape(); + const T* input_flat = input.template flat().data(); + + for (int64_t i = 0; i < input_shape.num_elements(); ++i) { + if (Eigen::numext::isnan(static_cast(input_flat[i]))) { + nan_count++; + } + } + } + + TensorShape shape({1}); + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor)); + output_tensor->vec()(0) = nan_count; + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); + } +}; + +// Numeric summary op for debugging. 
+template +class DebugNumericSummaryOp : public BaseDebugOp { + public: + explicit DebugNumericSummaryOp(OpKernelConstruction* context) + : BaseDebugOp("DebugNumericSummary", context) { + OP_REQUIRES_OK(context, context->GetAttr("lower_bound", &lower_bound_)); + OP_REQUIRES_OK(context, context->GetAttr("upper_bound", &upper_bound_)); + OP_REQUIRES_OK(context, + context->GetAttr("mute_if_healthy", &mute_if_healthy_)); + } + + void Compute(OpKernelContext* context) override { + if (!ApplyGrpcGating(context)) { + return; + } + + Tensor* output_tensor; + const Tensor& input = context->input(0); + + int64_t is_initialized = 0; + int64_t element_count = 0; + int64_t negative_inf_count = 0; + int64_t negative_count = 0; + int64_t zero_count = 0; + int64_t positive_count = 0; + int64_t positive_inf_count = 0; + int64_t nan_count = 0; + double min = std::numeric_limits::infinity(); + double max = -std::numeric_limits::infinity(); + double sum = 0.0; + double mean = std::numeric_limits::quiet_NaN(); + double variance = std::numeric_limits::quiet_NaN(); + + // Equal to negative_count + zero_count + positive_count. + int64_t non_inf_nan_count = 0; + + const TensorShape& input_shape = input.shape(); + if (input.IsInitialized()) { + is_initialized = 1; + const T* input_flat = input.template flat().data(); + + element_count = input_shape.num_elements(); + const bool is_lower_bound_custom = !Eigen::numext::isinf(lower_bound_); + const bool is_upper_bound_custom = !Eigen::numext::isinf(upper_bound_); + + for (int64_t i = 0; i < element_count; ++i) { + const double x = static_cast(input_flat[i]); + if (Eigen::numext::isnan(x)) { + nan_count++; + } else if (Eigen::numext::isinf(x)) { + if (x < 0.0) { + negative_inf_count++; + } else { + positive_inf_count++; + } + } else { + if (is_lower_bound_custom && x <= lower_bound_) { + negative_inf_count++; + } else if (is_upper_bound_custom && x >= upper_bound_) { + positive_inf_count++; + } else if (x < 0.0) { + negative_count++; + } else if (x > 0.0) { + positive_count++; + } else { + zero_count++; + } + + if (x < min) { + min = x; + } + if (x > max) { + max = x; + } + + non_inf_nan_count++; + sum += x; + } + } + + if (non_inf_nan_count > 0) { + mean = sum / non_inf_nan_count; + + // Do a second pass to compute variance. 
+ variance = 0.0; + for (int64_t i = 0; i < element_count; ++i) { + const double x = static_cast(input_flat[i]); + if (!Eigen::numext::isnan(x) && !Eigen::numext::isinf(x)) { + variance += (x - mean) * (x - mean); + } + } + variance /= non_inf_nan_count; + } + } + + TensorShape shape({14 + input_shape.dims()}); + OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor)); + output_tensor->vec()(0) = static_cast(is_initialized); + output_tensor->vec()(1) = static_cast(element_count); + output_tensor->vec()(2) = static_cast(nan_count); + output_tensor->vec()(3) = static_cast(negative_inf_count); + output_tensor->vec()(4) = static_cast(negative_count); + output_tensor->vec()(5) = static_cast(zero_count); + output_tensor->vec()(6) = static_cast(positive_count); + output_tensor->vec()(7) = static_cast(positive_inf_count); + output_tensor->vec()(8) = min; + output_tensor->vec()(9) = max; + output_tensor->vec()(10) = mean; + output_tensor->vec()(11) = variance; + + output_tensor->vec()(12) = static_cast(input.dtype()); + output_tensor->vec()(13) = static_cast(input_shape.dims()); + for (size_t d = 0; d < input_shape.dims(); ++d) { + output_tensor->vec()(14 + d) = + static_cast(input_shape.dim_sizes()[d]); + } + + bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 && + positive_inf_count == 0; + if (!mute) { + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); + } + } + + private: + float lower_bound_; + float upper_bound_; + bool mute_if_healthy_; +}; + +// Identity op for tfdbg v2: Writes debug data using DebugEventsWriter. +class DebugIdentityV2Op : public OpKernel { + public: + explicit DebugIdentityV2Op(OpKernelConstruction* context) + : OpKernel(context), + device_name_(context->device()->name()), + output_slot_(-1), + tensor_debug_mode_(0), + tfdbg_run_id_() { + std::vector debug_urls; + OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls)); + for (const string& debug_url : debug_urls) { + if (absl::StartsWith(debug_url, DebugIO::kFileURLScheme)) { + dump_roots_.emplace_back( + debug_url.substr(strlen(DebugIO::kFileURLScheme))); + } else { + context->SetStatus( + errors::Internal("Unsupported debug URL schema in: ", debug_url)); + } + } + OP_REQUIRES_OK(context, + context->GetAttr("tfdbg_context_id", &tfdbg_context_id_)); + OP_REQUIRES_OK(context, context->GetAttr("op_name", &op_name_)); + OP_REQUIRES_OK(context, context->GetAttr("output_slot", &output_slot_)); + OP_REQUIRES_OK(context, + context->GetAttr("tensor_debug_mode", &tensor_debug_mode_)); + if (context->HasAttr("circular_buffer_size")) { + OP_REQUIRES_OK(context, context->GetAttr("circular_buffer_size", + &circular_buffer_size_)); + } else { + circular_buffer_size_ = + tfdbg::DebugEventsWriter::kDefaultCyclicBufferSize; + } + if (context->HasAttr("tfdbg_run_id")) { + OP_REQUIRES_OK(context, context->GetAttr("tfdbg_run_id", &tfdbg_run_id_)); + } + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor = context->input(0); + for (const string& dump_root : dump_roots_) { + tfdbg::DebugEventsWriter* debug_events_writer = + tfdbg::DebugEventsWriter::GetDebugEventsWriter( + dump_root, tfdbg_run_id_, circular_buffer_size_); + OP_REQUIRES_OK(context, debug_events_writer->WriteGraphExecutionTrace( + tfdbg_context_id_, device_name_, op_name_, + output_slot_, tensor_debug_mode_, tensor)); + } + context->set_output(0, tensor); + } + + private: + std::vector dump_roots_; + string tfdbg_context_id_; + string device_name_; + string op_name_; + int32 output_slot_; 
+ int32 tensor_debug_mode_; + int64_t circular_buffer_size_; + string tfdbg_run_id_; +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +struct CurtHealthLaunch { + void Run(const GPUDevice& d, const Tin* data, int size, Tout output[1]); +}; + +extern template struct CurtHealthLaunch; +extern template struct CurtHealthLaunch; +extern template struct CurtHealthLaunch; +extern template struct CurtHealthLaunch; +extern template struct CurtHealthLaunch; +extern template struct CurtHealthLaunch; + +template +struct ConciseHealthLaunch { + void Run(const GPUDevice& d, const Tin* data, int size, Tout output[3]); +}; + +extern template struct ConciseHealthLaunch; +extern template struct ConciseHealthLaunch; +extern template struct ConciseHealthLaunch; +extern template struct ConciseHealthLaunch; +extern template struct ConciseHealthLaunch; +extern template struct ConciseHealthLaunch; + +template +struct FullHealthLaunch { + void Run(const GPUDevice& d, const Tin* data, int size, Tout output[6]); +}; + +extern template struct FullHealthLaunch; +extern template struct FullHealthLaunch; +extern template struct FullHealthLaunch; +extern template struct FullHealthLaunch; +extern template struct FullHealthLaunch; +extern template struct FullHealthLaunch; + +template +struct ReduceInfNanThreeSlotsLaunch { + void Run(const GPUDevice& d, const Tin* data, int size, Tout output[3]); +}; + +extern template struct ReduceInfNanThreeSlotsLaunch; +extern template struct ReduceInfNanThreeSlotsLaunch; +extern template struct ReduceInfNanThreeSlotsLaunch; +extern template struct ReduceInfNanThreeSlotsLaunch; +extern template struct ReduceInfNanThreeSlotsLaunch; +extern template struct ReduceInfNanThreeSlotsLaunch; + +#endif + +template +class DebugNumericSummaryV2Op; + +// Numeric summary op for tfdbg v2: CPU Kernel. +template +class DebugNumericSummaryV2Op : public OpKernel { + public: + explicit DebugNumericSummaryV2Op(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("tensor_debug_mode", &tensor_debug_mode_)); + OP_REQUIRES_OK(context, context->GetAttr("tensor_id", &tensor_id_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor = context->input(0); + auto in = tensor.flat(); + const Tin* data = in.data(); + const int64_t size = in.size(); + Tensor* output_tensor; + Tout tensor_id = static_cast(tensor_id_); + const Tout num_elem = static_cast(context->input(0).NumElements()); + // Disregard lossy cast if mode is REDUCE_INF_NAN_THREE_SLOTS because + // that mode does not make use of tensor_id. + if (tensor_debug_mode_ != 8) { + OP_REQUIRES( + context, tensor_id_ <= kMaxTensorId, + errors::InvalidArgument("DebugNumericSummaryV2Op requires " + "tensor_id to be less than or equal to " + "(2^", + std::numeric_limits::digits, + "). Given tensor_id:", tensor_id_)); + } + + if (tensor_debug_mode_ == 2) { // CURT_HEALTH + TensorShape shape({2}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + output_tensor->flat()(0) = tensor_id; // Slot tensor id + output_tensor->flat()(1) = 0.0; // Has inf or nan + int fp_props = + std::accumulate(data, data + size, 0, [](const int x, const Tin& y) { + return Eigen::numext::isfinite(y) ? 
x : 1; + }); + if (fp_props) { + output_tensor->flat()(1) = 1.0; + } + } else if (tensor_debug_mode_ == 3) { // CONCISE_HEALTH + TensorShape shape({5}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + output_tensor->flat()(0) = tensor_id; + output_tensor->flat()(1) = num_elem; + + // Accumulator value [neg_inf_count, pos_inf_count, nan_count] + Tout fp_props[3] = {0.0, 0.0, 0.0}; + std::for_each(data, data + size, [&fp_props](const Tin& y) { + if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) { + // Do nothing: common case. + } else if (Eigen::numext::isinf(y)) { + if (y < static_cast(0.f)) { + ++fp_props[0]; + } else { + ++fp_props[1]; + } + } else if (Eigen::numext::isnan(y)) { + ++fp_props[2]; + } + }); + output_tensor->flat()(2) = fp_props[0]; // Slot for -inf count + output_tensor->flat()(3) = fp_props[1]; // Slot for inf count + output_tensor->flat()(4) = fp_props[2]; // Slot for nan count + } else if (tensor_debug_mode_ == 4) { // FULL HEALTH + TensorShape shape({11}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + int num_dims = tensor.dims(); + output_tensor->flat()(0) = tensor_id; + output_tensor->flat()(1) = -1.0; // TODO(144919262): Device ID + output_tensor->flat()(2) = static_cast(tensor.dtype()); + output_tensor->flat()(3) = static_cast(num_dims); + output_tensor->flat()(4) = num_elem; + + // Accumulator value [neg_inf_count, pos_inf_count, nan_count, neg_count, + // zero_count, pos_count] + Tout fp_props[6] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + std::for_each(data, data + size, [&fp_props](const Tin& y) { + if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) { + if (y < static_cast(0.f)) { + ++fp_props[3]; + } else if (y == static_cast(0.f)) { + ++fp_props[4]; + } else { + ++fp_props[5]; + } + } else if (Eigen::numext::isinf(y)) { + if (y < static_cast(0.f)) { + ++fp_props[0]; + } else { + ++fp_props[1]; + } + } else if (Eigen::numext::isnan(y)) { + ++fp_props[2]; + } + }); + output_tensor->flat()(5) = fp_props[0]; // Slot for -inf count + output_tensor->flat()(6) = fp_props[1]; // Slot for inf count + output_tensor->flat()(7) = fp_props[2]; // Slot for nan count. + output_tensor->flat()(8) = fp_props[3]; // Slot for neg count. + output_tensor->flat()(9) = fp_props[4]; // Slot for zero count. + output_tensor->flat()(10) = fp_props[5]; // Slot for pos count. + } else if (tensor_debug_mode_ == 5) { // SHAPE + TensorShape shape({10}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + + int num_dims = tensor.dims(); + output_tensor->flat()(0) = tensor_id; + output_tensor->flat()(1) = static_cast(tensor.dtype()); + output_tensor->flat()(2) = static_cast(num_dims); + output_tensor->flat()(3) = num_elem; + + // Tensor shape - stored as (6 columns) + // if num_dim is less than 6, we right pad the shape with zeros + // if num_dim is greater than 6, we truncate the head (left most) of the + // dimensions as they are more predictable than the last few (e.g. batch + // size as first dimension) + int dim_idx = 4; + for (int i = std::max(0, num_dims - kShapeDims); + i < std::max(6, num_dims); ++i) { + if (i < num_dims) { + output_tensor->flat()(dim_idx++) = + static_cast(tensor.dim_size(i)); + } else { + output_tensor->flat()(dim_idx++) = 0.0; + } + } + } else if (tensor_debug_mode_ == 8) { // REDUCE_INF_NAN_THREE_SLOTS. + TensorShape shape({3}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + output_tensor->flat()(0) = 0.0; // Slot for -inf. 
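The CONCISE_HEALTH and FULL_HEALTH branches above classify every element in a single pass. Below is a standalone restatement of that pass over a plain float buffer, using the <cmath> predicates instead of Eigen::numext (ClassifyElements is a hypothetical helper, not part of this header):

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Returns counts in the FULL_HEALTH slot order used above:
// [-inf, +inf, nan, negative, zero, positive].
inline std::array<int64_t, 6> ClassifyElements(const float* data, size_t size) {
  std::array<int64_t, 6> counts{};  // value-initialized to zero
  for (size_t i = 0; i < size; ++i) {
    const float y = data[i];
    if (std::isfinite(y)) {
      if (y < 0.f) {
        ++counts[3];
      } else if (y == 0.f) {
        ++counts[4];
      } else {
        ++counts[5];
      }
    } else if (std::isinf(y)) {
      ++counts[y < 0.f ? 0 : 1];
    } else {  // NaN
      ++counts[2];
    }
  }
  return counts;
}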
+ output_tensor->flat()(1) = 0.0; // Slot for inf. + output_tensor->flat()(2) = 0.0; // Slot for nan. + + int fp_props = + std::accumulate(data, data + size, 0, [](const int x, const Tin& y) { + int result = x; + if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) { + // Do nothing: common case. + } else if (Eigen::numext::isinf(y)) { + result |= y < static_cast(0.f) ? kNegInfBit : kPosInfBit; + } else if (Eigen::numext::isnan(y)) { + result |= kNaNBit; + } + return result; + }); + + if (fp_props & kNegInfBit) { + output_tensor->flat()(0) = -std::numeric_limits::infinity(); + } + if (fp_props & kPosInfBit) { + output_tensor->flat()(1) = std::numeric_limits::infinity(); + } + if (fp_props & kNaNBit) { + output_tensor->flat()(2) = std::numeric_limits::quiet_NaN(); + } + } else { + // TODO(cais): Implement other tensor debug modes in debug_event.proto. + context->SetStatus(errors::Unimplemented( + "Unimplemented tensor debug mode: ", tensor_debug_mode_)); + } + } + + private: + int tensor_debug_mode_; + int64_t tensor_id_; + static constexpr int kShapeDims = 6; + static constexpr int kNegInfBit = 0x01; + static constexpr int kPosInfBit = 0x02; + static constexpr int kNaNBit = 0x04; + static constexpr int64_t kMaxTensorId = 1LL + << std::numeric_limits::digits; +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +class DebugNumericSummaryV2Op : public AsyncOpKernel { + public: + typedef GPUDevice Device; + + explicit DebugNumericSummaryV2Op(OpKernelConstruction* context) + : AsyncOpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("tensor_debug_mode", &tensor_debug_mode_)); + OP_REQUIRES_OK(context, context->GetAttr("tensor_id", &tensor_id_)); + } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + Tensor* output_tensor; + Tout tensor_id = static_cast(tensor_id_); + const Tensor& tensor = context->input(0); + const Tout num_elem = static_cast(tensor.NumElements()); + const Device& d = context->eigen_device(); + auto input = tensor.flat(); + auto check_cb = [this, done]() { done(); }; + // Disregard lossy cast if mode is REDUCE_INF_NAN_THREE_SLOTS because + // that mode does not make use of tensor_id. + if (tensor_debug_mode_ != 8) { + OP_REQUIRES_ASYNC( + context, tensor_id_ <= kMaxTensorId, + errors::InvalidArgument("DebugNumericSummaryV2Op requires " + "tensor_id to be less than or equal to " + "(2^", + std::numeric_limits::digits, + "). Given tensor_id:", tensor_id_), + done); + } + + if (tensor_debug_mode_ == 2) { // CURT_HEALTH. + TensorShape shape({2}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, stream != nullptr, + errors::Internal("No GPU stream available."), done); + + se::DeviceMemoryBase output_tensor_ptr( + output_tensor->flat().data(), + output_tensor->flat().size()); + OP_REQUIRES_OK(context, + stream->MemZero(&output_tensor_ptr, 2 * sizeof(Tout))); + // Copy tensor_id to slot zero + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &tensor_id, + sizeof(Tout))); + if (num_elem == 0) { + done(); + return; + } + + // Call the GPU kernels for the numerical (inf/nan) checks. + auto input = context->input(0).flat(); + CurtHealthLaunch().Run(d, input.data(), input.size(), + output_tensor->flat().data() + 1); + + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, std::move(check_cb)); + } else if (tensor_debug_mode_ == 3) { // CONCISE_HEALTH. 
+ TensorShape shape({5}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + OP_REQUIRES_ASYNC(context, !tensorflow::OpDeterminismRequired(), + errors::Unimplemented( + "Determinism is not yet supported for " + "DebugNumericSummaryV2 when tensor_debug_mode is " + "CONCISE_HEALTH."), + done); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, stream != nullptr, + errors::Internal("No GPU stream available."), done); + + se::DeviceMemoryBase output_tensor_ptr( + output_tensor->flat().data(), + output_tensor->flat().size()); + OP_REQUIRES_OK(context, + stream->Memset32(&output_tensor_ptr, 0, 5 * sizeof(Tout))); + const Tout static_output[] = {tensor_id, num_elem}; + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &static_output, + 2 * sizeof(Tout))); + if (num_elem == 0) { + done(); + return; + } + + // Call the GPU kernels for the numerical (inf/nan) checks. + ConciseHealthLaunch().Run( + d, input.data(), input.size(), + output_tensor->flat().data() + 2); + + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, std::move(check_cb)); + } else if (tensor_debug_mode_ == 4) { // FULL HEALTH + TensorShape shape({11}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, stream != nullptr, + errors::Internal("No GPU stream available."), done); + OP_REQUIRES_ASYNC(context, !tensorflow::OpDeterminismRequired(), + errors::Unimplemented( + "Determinism is not yet supported for " + "DebugNumericSummaryV2 when tensor_debug_mode is " + "FULL_HEALTH."), + done); + + se::DeviceMemoryBase output_tensor_ptr( + output_tensor->flat().data(), + output_tensor->flat().size()); + OP_REQUIRES_OK( + context, stream->Memset32(&output_tensor_ptr, 0, 11 * sizeof(Tout))); + + int num_dims = tensor.dims(); + const Tout static_output[] = {tensor_id, + -1.0, // TODO(144919262): Device ID + static_cast(tensor.dtype()), + static_cast(num_dims), num_elem}; + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &static_output, + 5 * sizeof(Tout))); + if (num_elem == 0) { + done(); + return; + } + + // Call the GPU kernels for the numerical (inf/nan) checks and + // pos/neg/zero counts. 
+ FullHealthLaunch().Run(d, input.data(), input.size(), + output_tensor->flat().data() + 5); + + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, std::move(check_cb)); + } else if (tensor_debug_mode_ == 5) { // SHAPE + TensorShape shape({10}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, stream != nullptr, + errors::Internal("No GPU stream available."), done); + + se::DeviceMemoryBase output_tensor_ptr( + output_tensor->flat().data(), + output_tensor->flat().size()); + + int num_dims = tensor.dims(); + Tout static_output[10] = {tensor_id, + static_cast(tensor.dtype()), + static_cast(num_dims), + num_elem, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0}; + // Tensor shape: right pad zeros, truncate head + int dim_idx = 4; + for (int i = std::max(0, num_dims - 6); i < num_dims; ++i) { + static_output[dim_idx++] = static_cast(tensor.dim_size(i)); + } + // Write to device stream + OP_REQUIRES_OK(context, stream->Memcpy(&output_tensor_ptr, &static_output, + sizeof(Tout) * 10)); + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, std::move(check_cb)); + } else if (tensor_debug_mode_ == 8) { // REDUCE_INF_NAN_THREE_SLOTS. + TensorShape shape({3}); + OP_REQUIRES_OK(context, + context->allocate_output(0, shape, &output_tensor)); + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, stream != nullptr, + errors::Internal("No GPU stream available."), done); + + se::DeviceMemoryBase output_tensor_ptr( + output_tensor->flat().data(), + output_tensor->flat().size()); + OP_REQUIRES_OK( + context, + stream->Memset32(&output_tensor_ptr, 0, + output_tensor->flat().size() * sizeof(Tout))); + if (num_elem == 0) { + done(); + return; + } + + // Call the GPU kernels for the numerical (inf/nan) checks. + auto input = context->input(0).flat(); + ReduceInfNanThreeSlotsLaunch().Run( + d, input.data(), input.size(), output_tensor->flat().data()); + + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, std::move(check_cb)); + } else { + // TODO(cais): Implement other tensor debug modes in debug_event.proto. + context->SetStatus(errors::Unimplemented( + "Unimplemented tensor debug mode: ", tensor_debug_mode_)); + done(); + } + } + + private: + int tensor_debug_mode_; + int64_t tensor_id_; + static constexpr int64_t kMaxTensorId = 1L + << std::numeric_limits::digits; +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/deep_conv2d.h b/third_party/tflite-hdrs/tensorflow/core/kernels/deep_conv2d.h new file mode 100644 index 00000000..c484db38 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/deep_conv2d.h @@ -0,0 +1,117 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ +#define TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +class OpKernelContext; + +// DeepConv2D is a Conv2D implementation specialized for deep (i.e. large +// in_depth * out_depth product) convolutions (see deep_conv2d.cc for details). + +// DeepConv2DTransform is an interface for implementing transforms for +// DeepConv2D. Implementations must specify transform matrices and +// input/output/filter shapes. DeepConv2d computes: +// +// y = C[Ad * Bg] +// +// C: output transform matrix +// A: input data transform matrix +// B: filter transform matrix +// d: vectorized 2D data tile +// g: vectorized 2D filter tile +// y: vectorized 2D output tile + +template +class DeepConv2DTransform { + public: + virtual ~DeepConv2DTransform() {} + + virtual void GetFilterTransformMatrix(const int64_t rows, const int64_t cols, + T* transform_matrix) const = 0; + + virtual void GetInputTransformMatrix(const int64_t rows, const int64_t cols, + T* transform_matrix) const = 0; + + virtual void GetOutputTransformMatrix(const int64_t rows, const int64_t cols, + T* transform_matrix) const = 0; + + struct Shape { + Shape(int64_t r, int64_t c) : rows(r), cols(c) {} + int64_t rows; + int64_t cols; + }; + + virtual const Shape& filter_shape() const = 0; + virtual const Shape& input_shape() const = 0; + virtual const Shape& output_shape() const = 0; +}; + +// Conv2D arguments used by DeepConv2D implementation. +struct Conv2DArgs { + // Input layer dimensions + int batch; + int in_rows; + int in_cols; + int in_depth; + int filter_rows; + int filter_cols; + int pad_rows; + int pad_cols; + + // Output layer dimensions + int out_rows; + int out_cols; + int out_depth; + + Conv2DArgs() + : batch(0), + in_rows(0), + in_cols(0), + in_depth(0), + filter_rows(0), + filter_cols(0), + pad_rows(0), + pad_cols(0), + out_rows(0), + out_cols(0), + out_depth(0) {} +}; + +// Returns true if convolution operation specified by function arguments +// can use DeepConv2D implementation, and false otherwise. +// May return false based on parameters, cost, or whether feature is disabled. +bool CanUseDeepConv2D(int stride_rows, int stride_cols, int filter_rows, + int filter_cols, int in_depth, int out_depth, + int out_rows, int out_cols); + +namespace functor { + +// Calls DeepConv2D implementation (see deep_conv2d.cc for details). +template +struct DeepConv2D { + void operator()(OpKernelContext* ctx, const Conv2DArgs& args, const T* input, + const T* filter, T* output); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DEEP_CONV2D_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/dense_update_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/dense_update_functor.h new file mode 100644 index 00000000..c16db936 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/dense_update_functor.h @@ -0,0 +1,81 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_ + +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + + +enum DenseUpdateType { ADD, SUB, ASSIGN }; + +namespace functor { + +template +struct DenseUpdate { + void operator()(const Device& d, typename TTypes::Flat params, + typename TTypes::ConstFlat update); +}; + +template +struct DenseUpdate { + void operator()(const CPUDevice& d, typename TTypes::Flat params, + typename TTypes::ConstFlat update) { + params.device(d) += update; + } +}; + +template +struct DenseUpdate { + void operator()(const CPUDevice& d, typename TTypes::Flat params, + typename TTypes::ConstFlat update) { + params.device(d) -= update; + } +}; + +template +struct DenseUpdate { + void operator()(const CPUDevice& d, typename TTypes::Flat params, + typename TTypes::ConstFlat update) { + params.device(d) = update; + } +}; + + +} // end namespace functor + +template +absl::Status VariantCopyFn(OpKernelContext* context, const Tensor& from, + Tensor* to); + +template <> +absl::Status VariantCopyFn(OpKernelContext* context, + const Tensor& from, Tensor* to); +template <> +absl::Status VariantCopyFn(OpKernelContext* context, + const Tensor& from, Tensor* to); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/depthtospace_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/depthtospace_op.h new file mode 100644 index 00000000..63dba5d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/depthtospace_op.h @@ -0,0 +1,56 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace functor { + +// Functor used by DepthToSpaceOp to do the computations. 
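Stepping back to the DenseUpdate specializations above: each update type reduces to a single Eigen device expression over the flattened tensors. The sketch below reproduces the same pattern with plain Eigen tensor maps and a thread-pool device; it is an illustration only, not TensorFlow code, and the buffer sizes and thread counts are arbitrary.

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"
#include "unsupported/Eigen/CXX11/ThreadPool"
#include <vector>

int main() {
  Eigen::ThreadPool pool(/*num_threads=*/4);
  Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);

  std::vector<float> params_buf(8, 1.f), update_buf(8, 0.5f);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> params(params_buf.data(), 8);
  Eigen::TensorMap<Eigen::Tensor<const float, 1>> update(update_buf.data(), 8);

  // Equivalent of DenseUpdate<CPUDevice, T, ADD>: params += update, evaluated
  // on the thread-pool device. SUB and ASSIGN use -= and = respectively.
  params.device(device) += update;
  return 0;
}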
+// Implements a family of Depth to Space transforms for a 4D 'input' tensor +// to a 4D 'output' tensor, both tensors use type 'T' and layout 'data_format'. +// These transforms multiply the vertical and horizontal image sizes by +// 'block_size', and divide the depth dimension by (block_size * block_size) +// which must divide evenly. +// Each pixel in the input image is converted to a square block of pixels in +// the output image. The Y, X coordinates within each block comes from the +// high component of the input depth (channel) index. +// e.g. for data_format = NHWC: +// Each element in the input tensor can be specified via 6 coordinates, +// ordered by decreasing memory layout significance as: +// n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates +// within the input image, bX, bY means coordinates +// within the output block, oC means output channel). +// The output would be a transpose to the following layout: +// n,iY,bY,iX,bX,oC +template +struct DepthToSpaceOpFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); + + // This 5-D version is to support NCHW_VECT_C. + void operator()(const Device& d, typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DEPTHTOSPACE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/depthwise_conv_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/depthwise_conv_op.h new file mode 100644 index 00000000..1114caab --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/depthwise_conv_op.h @@ -0,0 +1,352 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/tensor_format.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { + +struct DepthwiseArgs { + // Input layer dimensions + int batch; + int in_rows; + int in_cols; + int in_depth; + int filter_rows; + int filter_cols; + int depth_multiplier; + int stride; + int pad_rows; // Amount of padding to the top of the input + int pad_cols; // Amount of padding to the left of the input + + // Output layer dimensions + int out_rows; + int out_cols; + int out_depth; + + DepthwiseArgs() + : batch(0), + in_rows(0), + in_cols(0), + in_depth(0), + filter_rows(0), + filter_cols(0), + depth_multiplier(0), + stride(0), + pad_rows(0), + pad_cols(0), + out_rows(0), + out_cols(0), + out_depth(0) {} +}; + +// Forward declaration. 
+class OpKernelContext; + +template +struct LaunchDepthwiseConvOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* input, const T* filter, T* output, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropInputOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* filter, T* in_backprop, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropFilterOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format); +}; + +bool UseCudnnWith16BitFloat(OpKernelContext* ctx, DataType dtype); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +struct LaunchDepthwiseConvOp { + void operator()(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* input, const T* filter, T* output, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropInputOp { + void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* filter, T* in_backprop, + TensorFormat data_format); +}; + +template +struct LaunchDepthwiseConvBackpropFilterOp { + void operator()(class OpKernelContext* ctx, const DepthwiseArgs& args, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format); +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow + +namespace tensorflow { +namespace functor { + +// Pads 'filter' to vector-register boundary along its inner dimension: +// filter_inner_dim_size = in_depth * depth_multiplier +// Requires 'filter' to have the following storage order: +// [filter_rows, filter_cols, in_depth, depth_multiplier] +// Returns zero-padded filter in 'padded_filter'. +// +// EX: +// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4 +// So we have a total of 3 * 2 = 6 filters, each of spatial size 2 x 2. +// +// filter [rows, cols, in_depth, depth_multiplier] +// [u0, v0, w0, x0] [y0, z0, u1, v1] [w1, x1, y1, z1] +// [u2, v2, w2, x2] [y2, z2, u3, v3] [w3, x3, y3, z3] +// +// padded_filter [rows, cols, in_depth, depth_multiplier] +// [u0, v0, w0, x0] [y0, z0, 0, 0] [u1, v1, w1, x1] [y1, z1, 0, 0] +// [u2, v2, w2, x2] [y2, z2, 0, 0] [u3, v3, w3, x3] [y3, z3, 0, 0] + +template +struct DepthwiseFilterPadOp { + void operator()(const DepthwiseArgs& args, const T* filter, + T* padded_filter) { + typedef typename Eigen::internal::packet_traits::type Packet; + static const int64_t kPacketSize = (sizeof(Packet) / sizeof(T)); + + // Calculate vectorized and scalar lengths of filter's inner dimension. + const int64_t filter_inner_dim_size = args.out_depth; + const int64_t vectorized_size = + (filter_inner_dim_size / kPacketSize) * kPacketSize; + const int64_t scalar_size = filter_inner_dim_size - vectorized_size; + // Calculate required padding and padded output buffer stride. + const int64_t pad_size = scalar_size > 0 ? kPacketSize - scalar_size : 0; + const int64_t padded_filter_stride = vectorized_size + kPacketSize; + + const int64_t filter_spatial_size = args.filter_rows * args.filter_cols; + for (int64_t i = 0; i < filter_spatial_size; ++i) { + const int64_t input_base = i * filter_inner_dim_size; + const int64_t output_base = i * padded_filter_stride; + // Write vectorized length of filter's inner dimension to output. 
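The DepthwiseFilterPadOp above pads each filter pixel's inner dimension up to the vector-register width. A standalone restatement of just that sizing arithmetic, with the worked example from the comment (ComputeFilterPadLayout is a hypothetical helper, not part of this header):

#include <cstdint>

// With in_depth = 3, depth_multiplier = 2 and a 4-wide packet, each filter
// pixel's 6 coefficients occupy an 8-element stride: 4 vectorized, 2 scalar,
// 2 zero-padding.
struct FilterPadLayout {
  int64_t vectorized_size, scalar_size, pad_size, padded_stride;
};

inline FilterPadLayout ComputeFilterPadLayout(int64_t filter_inner_dim_size,
                                              int64_t packet_size) {
  FilterPadLayout l;
  l.vectorized_size = (filter_inner_dim_size / packet_size) * packet_size;
  l.scalar_size = filter_inner_dim_size - l.vectorized_size;
  l.pad_size = l.scalar_size > 0 ? packet_size - l.scalar_size : 0;
  l.padded_stride = l.vectorized_size + packet_size;
  return l;
}
// ComputeFilterPadLayout(6, 4) -> {4, 2, 2, 8}, matching the EX comment above.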
+ for (int64_t j = 0; j < vectorized_size; j += kPacketSize) { + const auto v = Eigen::internal::ploadu(filter + input_base + j); + Eigen::internal::pstoreu(padded_filter + output_base + j, v); + } + // Write scalar length of filter's inner dimension to output. + for (int64_t j = 0; j < scalar_size; ++j) { + padded_filter[output_base + vectorized_size + j] = + filter[input_base + vectorized_size + j]; + } + // Pad the remainder of output to vector-register boundary. + for (int64_t j = 0; j < pad_size; ++j) { + padded_filter[output_base + vectorized_size + scalar_size + j] = + static_cast(0); + } + } + } +}; + +// Copies data from local region in 'input' specified by 'out_r' and 'out_'c' +// to 'input_buffer'. The copied data is replicated by factor +// 'args.depth_multiplier', and padded to vector register-width boundaries so +// that it is aligned for efficient traversal and vector multiply-add by the +// depthwise kernel. +// +// EX: +// in_depth = 3, depth_multiplier = 2, filter [2, 2], register_width = 4 +// +// input: [batch, in_rows, in_cols, in_depth] +// +// [a0, a1, a2, b0, b1, b2, ..., e0, e1, e2, f0, f1, f2, ...] +// +// input_buffer (register boundaries shown): +// [a0, a0, a1, a1] [a2, a2, 0, 0] in_row = 0, in_col = 0 +// [b0, b0, b1, b1] [b2, b2, 0, 0] in_row = 0, in_col = 1 +// [e0, e0, e1, e1] [e2, e2, 0, 0] in_row = 1, in_col = 0 +// [f0, f0, f1, f1] [f2, f2, 0, 0] in_row = 1, in_col = 1 +// +// Returns replicated and padded data from specified input region in +// 'input_buffer'. + +template +struct DepthwiseInputCopyOp { + void operator()(const DepthwiseArgs& args, + const int64_t padded_filter_inner_dim_size, + const int64_t out_r, const int64_t out_c, const T* input, + T* input_buffer) { + typedef typename Eigen::internal::packet_traits::type Packet; + static const int64_t kPacketSize = Eigen::internal::packet_traits::size; + + const int64_t kDepth = args.depth_multiplier; + // Calculate vectorized and scalar (residual) lengths for 'in_depth'. + const int64_t input_vectorized_size = + (args.in_depth / kPacketSize) * kPacketSize; + const int64_t input_scalar_size = args.in_depth - input_vectorized_size; + + // Calculate output padding length. + const int64_t output_scalar_size = args.out_depth % kPacketSize; + const int64_t output_pad_size = + output_scalar_size > 0 ? kPacketSize - output_scalar_size : 0; + + // Iterate through all rows x cols reading 'in_depth' from 'input' and + // replicating by 'depth_multiplier' into 'input_buffer' (otherwise + // zero-padding input buffer as needed). + auto* in_buf = input_buffer; + const int64_t in_r_start = out_r * args.stride - args.pad_rows; + const int64_t in_c_start = out_c * args.stride - args.pad_cols; + + // TODO: add a ploaddup variant for depth == 2 if needed. + if (kDepth > 1 && kDepth <= kPacketSize) { + for (int64_t f_r = 0; f_r < args.filter_rows; ++f_r) { + const int64_t in_r = in_r_start + f_r; + + for (int64_t f_c = 0; f_c < args.filter_cols; ++f_c) { + const int64_t in_c = in_c_start + f_c; + + if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && + in_c < args.in_cols) { + const auto* in = + input + (in_r * args.in_cols + in_c) * args.in_depth; + int64_t limit = args.in_depth; + // This will overwrite up to kPacketSize next elements, + // this is ok on all iterations except the last one, since + // we will write correct values on a next iteration. 
+ if (f_c == args.filter_cols - 1) { + limit -= (kPacketSize - kDepth) / kDepth + 1; + if (limit < 0) { + limit = 0; + } + } + // Copy vectorized portion of inner dimension. + for (int64_t d = 0; d < limit; d++) { + const auto p = Eigen::internal::pset1(in[d]); + Eigen::internal::pstoreu(in_buf, p); + in_buf += kDepth; + } + + // Copy the scalar portion. + for (int64_t d = limit; d < args.in_depth; d++) { + const auto value = in[d]; + for (int64_t dm = 0; dm < kDepth; dm++) { + in_buf[dm] = value; + } + in_buf += kDepth; + } + + // Pad the remainder of the output to vector register boundary. + for (int64_t d = 0; d < output_pad_size; ++d) { + in_buf[d] = static_cast(0); + } + in_buf += output_pad_size; + } else { + // Zero pad. + memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size); + in_buf += padded_filter_inner_dim_size; + } + } + } + } else if (kDepth > kPacketSize) { + // Calculate vectorized and scalar (residual) lengths for + // 'depth_multiplier'. This is used to efficiently replicate data for + // when 'depth_multiplier' > kPacketSize. + const int64_t dm_vectorized_size = (kDepth / kPacketSize) * kPacketSize; + + for (int64_t f_r = 0; f_r < args.filter_rows; ++f_r) { + const int64_t in_r = in_r_start + f_r; + + for (int64_t f_c = 0; f_c < args.filter_cols; ++f_c) { + const int64_t in_c = in_c_start + f_c; + + if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && + in_c < args.in_cols) { + const auto* in = + input + (in_r * args.in_cols + in_c) * args.in_depth; + // Copy vectorized portion of inner dimension. + for (int64_t d = 0; d < args.in_depth; d++) { + const auto p = Eigen::internal::pset1(in[d]); + for (int64_t dm = 0; dm < dm_vectorized_size; dm += kPacketSize) { + Eigen::internal::pstoreu(in_buf + dm, p); + } + // Overlapping store for the remainder. + Eigen::internal::pstoreu(in_buf + kDepth - kPacketSize, p); + in_buf += kDepth; + } + // Pad the remainder of the output to vector register boundary. + for (int64_t d = 0; d < output_pad_size; ++d) { + in_buf[d] = static_cast(0); + } + in_buf += output_pad_size; + } else { + // Zero pad. + memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size); + in_buf += padded_filter_inner_dim_size; + } + } + } + } else if (kDepth == 1) { + for (int64_t f_r = 0; f_r < args.filter_rows; ++f_r) { + const int64_t in_r = in_r_start + f_r; + + for (int64_t f_c = 0; f_c < args.filter_cols; ++f_c) { + const int64_t in_c = in_c_start + f_c; + + if (in_r >= 0 && in_r < args.in_rows && in_c >= 0 && + in_c < args.in_cols) { + const auto* in = + input + (in_r * args.in_cols + in_c) * args.in_depth; + for (int64_t d = 0; d < input_vectorized_size; d += kPacketSize) { + const auto p = Eigen::internal::ploadu(in + d); + Eigen::internal::pstoreu(in_buf, p); + in_buf += kPacketSize; + } + for (int64_t d = 0; d < input_scalar_size; ++d) { + T v = in[input_vectorized_size + d]; + in_buf[d] = v; + } + in_buf += input_scalar_size; + + // Pad the remainder of the output to vector register boundary. + for (int64_t d = 0; d < output_pad_size; ++d) { + in_buf[d] = static_cast(0); + } + in_buf += output_pad_size; + } else { + // Zero pad. 
+ memset(in_buf, 0, sizeof(T) * padded_filter_inner_dim_size); + in_buf += padded_filter_inner_dim_size; + } + } + } + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/depthwise_conv_op_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/depthwise_conv_op_gpu.h new file mode 100644 index 00000000..b058ef26 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/depthwise_conv_op_gpu.h @@ -0,0 +1,1759 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_GPU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/depthwise_conv_op.h" +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/determinism.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/tensor_format.h" + +#if defined(_MSC_VER) && !defined(__clang__) +#define UNROLL +#define NOUNROLL +#else +#define UNROLL _Pragma("unroll") +#define NOUNROLL _Pragma("nounroll") +#endif + +namespace tensorflow { + +namespace detail { +template +struct PseudoHalfType { + using Type = T; +}; +template <> +struct PseudoHalfType { + using Type = float; +}; +template <> +struct PseudoHalfType { + using Type = float; +}; +} // namespace detail + +using Eigen::GpuDevice; + +// Returns whether depthwise convolution forward or backward input pass can be +// performed using the faster ('Small') variant of the kernel. +inline EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall( + const DepthwiseArgs& args) { + return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 && + args.in_cols <= 32 && args.in_rows == args.out_rows && + args.in_cols == args.out_cols && args.pad_rows >= 0 && + args.pad_rows < args.filter_rows && args.pad_cols >= 0 && + args.pad_cols < args.filter_cols && + args.filter_rows * args.filter_cols <= + (args.in_rows + 1) / 2 * args.in_cols; +} + +// Returns whether depthwise convolution backward filter pass can be performed +// using the faster ('Small') variant of the kernel. 
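As a usage illustration of the forward/backprop-input eligibility check above, assuming a GPU build in which this header is compiled (ExampleFitsSmallKernel is a hypothetical helper and the sizes are arbitrary): a SAME-padded 3x3 depthwise convolution over a 32x32 input with stride 1 and depth_multiplier 1 satisfies every clause, including 3 * 3 <= (32 + 1) / 2 * 32.

#include "tensorflow/core/kernels/depthwise_conv_op.h"      // DepthwiseArgs
#include "tensorflow/core/kernels/depthwise_conv_op_gpu.h"  // CanLaunchDepthwiseConv2dGPUSmall

inline bool ExampleFitsSmallKernel() {
  tensorflow::DepthwiseArgs args;
  args.batch = 1;
  args.in_rows = 32;
  args.in_cols = 32;
  args.in_depth = 64;
  args.filter_rows = 3;
  args.filter_cols = 3;
  args.depth_multiplier = 1;
  args.stride = 1;
  args.pad_rows = 1;
  args.pad_cols = 1;
  args.out_rows = 32;
  args.out_cols = 32;
  args.out_depth = 64;
  // Expected to return true for this configuration.
  return tensorflow::CanLaunchDepthwiseConv2dGPUSmall(args);
}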
+inline EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall( + const DepthwiseArgs& args, const int block_height) { + return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 && + args.in_cols <= 32 && args.in_rows == args.out_rows && + args.in_cols == args.out_cols && args.pad_rows >= 0 && + args.pad_rows < args.filter_rows && args.pad_cols >= 0 && + args.pad_cols < args.filter_cols && block_height <= args.in_rows && + args.filter_rows * args.filter_cols <= args.in_cols * block_height; +} + +// The DepthwiseConv2dGPUKernels perform either forward or backprop input +// convolution depending on a template argument of this enum. +enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; + +// A GPU kernel to compute the depthwise convolution forward pass +// in NHWC format. +template +__global__ void __launch_bounds__(1024, 2) + DepthwiseConv2dGPUKernelNHWC(const DepthwiseArgs args, + const T* __restrict__ input, + const T* __restrict__ filter, + T* __restrict__ output, int num_outputs) { + typedef typename detail::PseudoHalfType::Type S; + const int in_height = args.in_rows; + const int in_width = args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; + const int stride = args.stride; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; + const int out_depth = args.out_depth; + + GPU_1D_KERNEL_LOOP(thread_id, num_outputs) { + // Compute the indexes of this thread in the output. + const int out_channel = thread_id % out_depth; + const int out_col = (thread_id / out_depth) % out_width; + const int out_row = (thread_id / out_depth / out_width) % out_height; + const int batch = thread_id / out_depth / out_width / out_height; + // Compute the input depth and the index of depth multiplier. + const int in_channel = out_channel / depth_multiplier; + const int multiplier = out_channel % depth_multiplier; + + // Decide if all input is valid, if yes, we can skip the boundary checks + // for each input. 
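The forward NHWC kernel above unflattens each thread's output index with the channel innermost, then width, height, and batch outermost. The same decomposition restated on the host, for reference (NhwcCoords and DecomposeNhwcIndex are hypothetical names, not TensorFlow APIs):

#include <cstdint>

struct NhwcCoords {
  int batch, row, col, channel;
};

inline NhwcCoords DecomposeNhwcIndex(int64_t thread_id, int out_height,
                                     int out_width, int out_depth) {
  NhwcCoords c;
  c.channel = static_cast<int>(thread_id % out_depth);
  c.col = static_cast<int>((thread_id / out_depth) % out_width);
  c.row = static_cast<int>((thread_id / out_depth / out_width) % out_height);
  c.batch = static_cast<int>(thread_id / out_depth / out_width / out_height);
  return c;
}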
+ const int input_row_start = out_row * stride - pad_height; + const int input_col_start = out_col * stride - pad_width; + const int input_row_end = input_row_start + filter_height; + const int input_col_end = input_col_start + filter_width; + + S sum = static_cast(0); + + const int input_offset_temp = in_height * batch; + if (input_row_start >= 0 && input_col_start >= 0 && + input_row_end < in_height && input_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + + const int input_offset = + in_channel + + in_depth * (in_col + in_width * (in_row + input_offset_temp)); + const int filter_offset = + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); + } + } + } else { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int in_col = input_col_start + filter_col; + + const int input_offset = + in_channel + + in_depth * (in_col + in_width * (in_row + input_offset_temp)); + const int filter_offset = + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); + } + } + } + } + output[thread_id] = static_cast(sum); + } +} + +// CUDA kernel to compute the depthwise convolution forward pass in NHWC format, +// tailored for small images up to 32x32. Stride and depth multiplier must be 1. +// Padding must be 'SAME', which allows to reuse the index computation. Only +// use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true. +// Tiles of the input and filter tensors are loaded into shared memory before +// performing the convolution. Each thread handles two elements per iteration, +// one each in the lower and upper half of a tile. +// Backprop input direction is the same as forward direction with the filter +// rotated by 180°. +// T is the tensors' data type. S is the math type the kernel uses. This is the +// same as T for all cases but pseudo half (which has T=Eigen::half, S=float). +template +__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( + const DepthwiseArgs args, const T* __restrict__ input, + const T* __restrict__ filter, T* __restrict__ output) { + typedef typename detail::PseudoHalfType::Type S; + assert(CanLaunchDepthwiseConv2dGPUSmall(args)); + // Holds block plus halo and filter data for blockDim.x depths. + GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory); + static_assert(sizeof(S) <= 8, "Insufficient alignment detected"); + S* const shared_data = reinterpret_cast(shared_memory); + + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? 
args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + + assert(blockDim.x == kBlockDepth); + assert(blockDim.y == args.in_cols); + const int block_height = blockDim.z; + + // These values are the same for all threads and could + // be precomputed on the CPU. + const int block_size = block_height * in_width * kBlockDepth; + const int in_row_size = in_width * in_depth; + const int in_size = in_height * in_row_size; + const int in_increment = (in_width - 1) * kBlockDepth; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int even_height = kKnownEvenHeight || (1 & ~in_height); + const int tile_height = in_height + filter_height - even_height; + const int tile_row_size = tile_width * kBlockDepth; + const int tile_size = tile_height * tile_row_size; + const int tile_offset = block_height * tile_row_size; + const int pad_offset = pad_height * tile_width + pad_width; + const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth; + const int in_blocks = batch_blocks * num_batches; + const int tensor_offset = + kKnownEvenHeight ? in_size / 2 : block_height * in_row_size; + + const int thread_depth = threadIdx.x; + const int thread_col = threadIdx.y; + const int thread_row = threadIdx.z; + + // Position in block. + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_pix * kBlockDepth + thread_depth; + + // Initialize tile, in particular the padding. + for (int i = thread_idx; i < tile_size; i += block_size) { + shared_data[i] = S(); + } + __syncthreads(); + + // Position in tensors. + const int tensor_idx = thread_pix * in_depth + thread_depth; + + // Position in (padded) shared memory. + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = data_pix * kBlockDepth + thread_depth; + + // Position in shared memory, offset by pad_height / pad_width. + const int tile_pix = data_pix + pad_offset; + const int tile_idx = tile_pix * kBlockDepth + thread_depth; + + const int max_channel = in_depth - thread_depth; + const int filter_write_offset = + thread_pix < filter_pixels ? tile_size + thread_idx : 0; + const int filter_read_offset = + tile_size + thread_depth + + (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockDepth); + const bool skip_second = + !kKnownEvenHeight && thread_row + (in_height & 1) == block_height; + + for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { + const int batch = b / batch_blocks; + const int block = b - batch * batch_blocks; + + const int start_channel = block * kBlockDepth; + const int filter_offset = tensor_idx + start_channel; + const int inout_offset = batch * in_size + filter_offset; + const bool channel_in_range = start_channel < max_channel; + + if (channel_in_range) { + const T* const in_ptr = inout_offset + input; + S* const tile_ptr = tile_idx + shared_data; + tile_ptr[0] = static_cast(ldg(in_ptr)); + if (!skip_second) { + tile_ptr[tile_offset] = static_cast(ldg(tensor_offset + in_ptr)); + } + + if (filter_write_offset != 0) { + shared_data[filter_write_offset] = + static_cast(ldg(filter_offset + filter)); + } + } + + // Note: the condition to reach this is uniform across the entire block. 
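In the NHWC small kernel above, the shared-memory tile holds the input block plus its halo, offset by the padding, so tile_idx works out to ((row + pad_height) * tile_width + (col + pad_width)) * kBlockDepth + depth for a block-local element at (row, col). A host-side restatement of that indexing (NhwcSmallTileIndex is a hypothetical helper):

// Shared-memory slot of a block-local input element, mirroring
// tile_idx = (data_pix + pad_offset) * kBlockDepth + thread_depth above.
inline int NhwcSmallTileIndex(int row, int col, int depth, int pad_height,
                              int pad_width, int tile_width, int block_depth) {
  const int tile_pix = (row + pad_height) * tile_width + (col + pad_width);
  return tile_pix * block_depth + depth;
}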
+ __syncthreads(); + + if (channel_in_range) { + S sum1 = S(); + S sum2 = S(); + int shared_offset = data_idx; + const S* filter_ptr = filter_read_offset + shared_data; + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { + if (kDirection == DIRECTION_BACKWARD) { + filter_ptr -= kBlockDepth; + } + const S filter_value = *filter_ptr; + const S* const tile_ptr = shared_offset + shared_data; + sum1 += filter_value * tile_ptr[0]; + sum2 += filter_value * tile_ptr[tile_offset]; + shared_offset += kBlockDepth; + if (kDirection == DIRECTION_FORWARD) { + filter_ptr += kBlockDepth; + } + } + shared_offset += in_increment; + } + T* const out_ptr = inout_offset + output; + out_ptr[0] = static_cast(sum1); + if (!skip_second) { + out_ptr[tensor_offset] = static_cast(sum2); + } + } + + // Note: the condition to reach this is uniform across the entire block. + __syncthreads(); + } +} + +// A GPU kernel to compute the depthwise convolution forward pass +// in NCHW format. +template +__global__ void __launch_bounds__(1024, 2) + DepthwiseConv2dGPUKernelNCHW(const DepthwiseArgs args, + const T* __restrict__ input, + const T* __restrict__ filter, + T* __restrict__ output, int num_outputs) { + typedef typename detail::PseudoHalfType::Type S; + const int in_height = args.in_rows; + const int in_width = args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const FastDividerUint32 depth_multiplier = + kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; + const int stride = args.stride; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_width = args.out_cols; + const FastDividerUint32 out_height = args.out_rows; + const FastDividerUint32 out_depth = args.out_depth; + + GPU_1D_KERNEL_LOOP(thread_id, num_outputs) { + // Compute the indexes of this thread in the output. + // + // We want coalesced reads so we make sure that each warp reads + // a contiguous chunk of memory. + // + // THIS IS PROBABLY WRONG, we are not doing coalesced reads + // into the input, because of the depth multiplier division... + const int out_col = thread_id % out_width; + const int out_row = (thread_id / out_width) % out_height; + const int out_channel = (thread_id / out_width / out_height) % out_depth; + const int batch = thread_id / out_width / out_height / out_depth; + + // Compute the input depth and the index of depth multiplier + // based off the output depth index that this thread is + // computing n. + const int in_channel = out_channel / depth_multiplier; + const int multiplier = out_channel % depth_multiplier; + + // Data is stored in the following format (let's assume we + // flatten the height and width into one contiguous dimension + // called "P". + // + // B1C1P1 B1C1P2 ..... B1C2P1 B1C2P2 .... + // B2C1P1 B2C1P2 ..... B2C2P1 B2C2P2 .... + // + // Each row contains in_depth * in_height * in_width values + // for each sample in the batch. + // + // We can further flatten it into: + // + // B1C1P1 B1C1P2 ..... + // B1C2P1 B1C2P2 .... + // B2C1P1 B2C1P2 ..... + // B2C2P1 B2C2P2 .... + // + // where each row is a contiguous array of all of the spatial + // pixels for a given batch and input depth. 
The following + // loop unrolls across the filter dimensions for a given thread, + // indexing into the filter value and the corresponding input + // patch. + // + // We can compute the index into the patch once right here. + const int input_offset_temp = + (batch * in_depth + in_channel) * (in_height * in_width); + + // Finally, we can iterate over the spatial dimensions and perform the + // convolution, writing into the output at the end. + // + // We perform an additional optimization, where we can determine + // whether the patch fits within the image indices statically, and + // avoid boundary checking within the loop. + const int input_row_start = out_row * stride - pad_height; + const int input_col_start = out_col * stride - pad_width; + const int input_row_end = input_row_start + filter_height; + const int input_col_end = input_col_start + filter_width; + + S sum = static_cast(0); + if (input_row_start >= 0 && input_col_start >= 0 && + input_row_end < in_height && input_col_end < in_width) { + // Loop that doesn't need to check for boundary conditions. + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + + const int input_offset = + (input_offset_temp) + (in_row * in_width) + in_col; + const int filter_offset = + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); + } + } + } else { + // Loop that needs to check for boundary conditions. + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + // TODO(vrv): the in_row check can be done outside of this loop; + // benchmark both methods to determine the better decision. + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int in_col = input_col_start + filter_col; + + // input_offset_temp indexes into the start of memory + // where the spatial data starts. + const int input_offset = + (input_offset_temp) + (in_row * in_width) + in_col; + + const int filter_offset = + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); + sum += static_cast(ldg(input + input_offset)) * + static_cast(ldg(filter + filter_offset)); + } + } + } + } + + output[thread_id] = static_cast(sum); + } +} + +// CUDA kernel to compute the depthwise convolution forward pass in NCHW format, +// tailored for small images up to 32x32. Stride and depth multiplier must be 1. +// Padding must be 'SAME', which allows to reuse the index computation. Only +// use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true. +// Tiles of the input and filter tensors are loaded into shared memory before +// performing the convolution. Each thread handles two elements per iteration, +// one each in the lower and upper half of a tile. +// Backprop input direction is the same as forward direction with the filter +// rotated by 180°. +// T is the tensors' data type. S is the math type the kernel uses. 
This is the +// same as T for all cases but pseudo half (which has T=Eigen::half, S=float). +template +__global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( + const DepthwiseArgs args, const T* __restrict__ input, + const T* __restrict__ filter, T* __restrict__ output) { + typedef typename detail::PseudoHalfType::Type S; + assert(CanLaunchDepthwiseConv2dGPUSmall(args)); + // Holds block plus halo and filter data for blockDim.z depths. + GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory); + static_assert(sizeof(S) <= 8, "Insufficient alignment detected"); + S* const shared_data = reinterpret_cast(shared_memory); + + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + + // Fixed blockDim.z, tailored for maximum grid size for images of size 16x16. + assert(blockDim.x == args.in_cols); + assert(blockDim.z == kBlockDepth); + const int block_height = blockDim.y; + + // These values are the same for all threads and could + // be precomputed on the CPU. + const int block_pixels = in_width * block_height; + const int block_size = block_pixels * kBlockDepth; + const int in_pixels = in_width * in_height; + const int in_increment = in_width - 1; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int even_height = kKnownEvenHeight || (1 & ~in_height); + const int tile_height = in_height + filter_height - even_height; + const int tile_pixels = tile_width * tile_height; + const int tile_size = tile_pixels * kBlockDepth; + const int tile_offset = block_height * tile_width; + const int pad_offset = pad_height * tile_width + pad_width; + const int in_total_depth = in_depth * num_batches; + const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth; + + const int thread_col = threadIdx.x; + const int thread_row = threadIdx.y; + const int thread_depth = threadIdx.z; + + // Position in block. + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_depth * block_pixels + thread_pix; + + // Initialize tile, in particular the padding. + for (int i = thread_idx; i < tile_size; i += block_size) { + shared_data[i] = S(); + } + __syncthreads(); + + // Position in tensors. + const int tensor_idx = thread_depth * in_pixels + thread_pix; + + // Position in (padded) shared memory. + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = thread_depth * tile_pixels + data_pix; + + // Position in shared memory, offset by pad_height / pad_width. + const int tile_idx = data_idx + pad_offset; + + // Filter is always in HWCK format, irrespective of the input/output format. + const int filter_pix = thread_idx / kBlockDepth; + const int filter_channel = thread_idx % kBlockDepth; + const int filter_idx = filter_pix * in_depth; + + const int max_channel = in_total_depth - thread_depth; + const int filter_write_offset = + filter_pix < filter_pixels ? tile_size + thread_idx : 0; + const int filter_read_offset = + tile_size + thread_depth + + (kDirection == DIRECTION_FORWARD ? 
0 : filter_pixels * kBlockDepth); + const bool skip_second = + !kKnownEvenHeight && thread_row + (in_height & 1) == block_height; + + for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { + const int channel = b * kBlockDepth; + + const int inout_offset = channel * in_pixels + tensor_idx; + const bool channel_in_range = channel < max_channel; + + if (channel_in_range) { + const T* const in_ptr = inout_offset + input; + S* const tile_ptr = tile_idx + shared_data; + tile_ptr[0] = static_cast(ldg(in_ptr)); + if (!skip_second) { + tile_ptr[tile_offset] = static_cast(ldg(block_pixels + in_ptr)); + } + } + + if (filter_write_offset != 0) { + const int filter_offset = + filter_idx + (channel + filter_channel) % in_depth; + shared_data[filter_write_offset] = + static_cast(ldg(filter_offset + filter)); + } + + // Note: the condition to reach this is uniform across the entire block. + __syncthreads(); + + if (channel_in_range) { + S sum1 = S(); + S sum2 = S(); + int shared_offset = data_idx; + const S* filter_ptr = filter_read_offset + shared_data; + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { + if (kDirection == DIRECTION_BACKWARD) { + filter_ptr -= kBlockDepth; + } + const S filter_value = *filter_ptr; + const S* const tile_ptr = shared_offset + shared_data; + sum1 += filter_value * tile_ptr[0]; + sum2 += filter_value * tile_ptr[tile_offset]; + ++shared_offset; + if (kDirection == DIRECTION_FORWARD) { + filter_ptr += kBlockDepth; + } + } + shared_offset += in_increment; + } + T* const out_ptr = inout_offset + output; + out_ptr[0] = static_cast(sum1); + if (!skip_second) { + out_ptr[block_pixels] = static_cast(sum2); + } + } + + // Note: the condition to reach this is uniform across the entire block. 
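+ // (All threads reach the barrier below the same number of times because the
+ // enclosing loop bound, in_blocks, depends only on blockIdx; the
+ // channel_in_range guards above affect only the loads and stores, never the
+ // barrier itself.)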
+ __syncthreads(); + } +} + +template +Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, + TensorFormat data_format) { + typedef typename detail::PseudoHalfType::Type S; + const int block_height = (args.in_rows + 1) / 2; + dim3 block_dim; + int block_count; + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); + switch (data_format) { + case FORMAT_NHWC: + block_dim = dim3(kBlockDepth, args.in_cols, block_height); + block_count = + args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth; + kernel = + DepthwiseConv2dGPUKernelNHWCSmall; + break; + case FORMAT_NCHW: + block_dim = dim3(args.in_cols, block_height, kBlockDepth); + block_count = + DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth; + kernel = + DepthwiseConv2dGPUKernelNCHWSmall; + break; + default: + return errors::InvalidArgument("FORMAT_", ToString(data_format), + " is not supported"); + } + const int tile_width = args.in_cols + args.filter_cols - 1; + const int tile_height = block_height * 2 + args.filter_rows - 1; + const int tile_pixels = tile_height * tile_width; + const int filter_pixels = args.filter_rows * args.filter_cols; + const int shared_memory_size = + kBlockDepth * (tile_pixels + filter_pixels) * sizeof(S); + const int num_outputs = args.out_rows * args.out_cols * block_count; + auto device = ctx->eigen_gpu_device(); + GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize( + num_outputs, device, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count, block_dim, + shared_memory_size, device.stream(), args, input, + filter, output)); + return OkStatus(); +} + +// Returns whether the context's GPU supports efficient fp16 math. +inline bool HasFastHalfMath(OpKernelContext* ctx) { + se::CudaComputeCapability compute_capability = + ctx->op_device_context()->stream()->GetCudaComputeCapability(); + // GPUs before sm_53 don't support fp16 math, and sm_61's fp16 math is slow. + return compute_capability.IsAtLeast(5, 3) && + compute_capability != se::CudaComputeCapability{6, 1}; +} + +template +Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, + TensorFormat data_format) { + if (args.in_rows & 1) { + return LaunchDepthwiseConv2dGPUSmall(ctx, args, input, filter, + output, data_format); + } else { + return LaunchDepthwiseConv2dGPUSmall( + ctx, args, input, filter, output, data_format); + } +} + +template +Status LaunchDepthwiseConv2dGPUSmall(OpKernelContext* ctx, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, + TensorFormat data_format) { + // Maximize (power of two) kBlockDepth while keeping a block within 1024 + // threads (2 pixels per thread). 
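+ // With two output pixels per thread, block_pixels counts the threads needed
+ // per depth slice, so the block uses block_pixels * kBlockDepth threads in
+ // total. A block of more than 256 pixels cannot hold four depth slices
+ // within 1024 threads, and one of more than 128 cannot hold eight, which is
+ // exactly what the two thresholds below test.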
+ const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols; + if (block_pixels > 256) { + return LaunchDepthwiseConv2dGPUSmall( + ctx, args, input, filter, output, data_format); + } else if (block_pixels > 128) { + return LaunchDepthwiseConv2dGPUSmall( + ctx, args, input, filter, output, data_format); + } else { + return LaunchDepthwiseConv2dGPUSmall( + ctx, args, input, filter, output, data_format); + } +} + +template +Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* input, const T* filter, T* output, + TensorFormat data_format) { + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); + switch (data_format) { + case FORMAT_NHWC: + kernel = + DepthwiseConv2dGPUKernelNHWC; + break; + case FORMAT_NCHW: + kernel = + DepthwiseConv2dGPUKernelNCHW; + break; + default: + return errors::InvalidArgument("FORMAT_", ToString(data_format), + " is not supported"); + } + const int num_outputs = + args.batch * args.out_rows * args.out_cols * args.out_depth; + auto device = ctx->eigen_gpu_device(); + GpuLaunchConfig config = + GetGpuLaunchConfig(num_outputs, device, kernel, 0, 0); + // The compile-time constant version runs faster with a single block. + const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 || + kKnownDepthMultiplier < 0 + ? std::numeric_limits::max() + : device.getNumGpuMultiProcessors(); + TF_CHECK_OK(GpuLaunchKernel(kernel, + std::min(max_block_count, config.block_count), + config.thread_per_block, 0, device.stream(), args, + input, filter, output, num_outputs)); + return OkStatus(); +} + +template +Status LaunchDepthwiseConv2dGPU(OpKernelContext* ctx, const DepthwiseArgs& args, + const T* input, const T* filter, T* output, + TensorFormat data_format) { + if (args.depth_multiplier == 1) { + if (CanLaunchDepthwiseConv2dGPUSmall(args)) { + return LaunchDepthwiseConv2dGPUSmall< + T, DIRECTION_FORWARD, kKnownFilterWidth, kKnownFilterHeight>( + ctx, args, input, filter, output, data_format); + } + + return LaunchDepthwiseConv2dGPU(ctx, args, input, filter, output, + data_format); + } else { + return LaunchDepthwiseConv2dGPU(ctx, args, input, filter, output, + data_format); + } +} + +// A simple launch pad to launch the GPU kernel for depthwise convolution. +template +void LaunchDepthwiseConvOp::operator()(OpKernelContext* ctx, + const DepthwiseArgs& args, + const T* input, + const T* filter, T* output, + TensorFormat data_format) { + if (args.filter_rows == 3 && args.filter_cols == 3) { + OP_REQUIRES_OK(ctx, LaunchDepthwiseConv2dGPU( + ctx, args, input, filter, output, data_format)); + } else { + OP_REQUIRES_OK(ctx, LaunchDepthwiseConv2dGPU( + ctx, args, input, filter, output, data_format)); + } +} + +// A GPU kernel to compute the depthwise convolution backprop w.r.t. input. +template +__global__ void __launch_bounds__(640, 2) + DepthwiseConv2dBackpropInputGPUKernelNHWC( + const DepthwiseArgs args, const T* __restrict__ out_backprop, + const T* __restrict__ filter, T* __restrict__ in_backprop, + int num_in_backprop) { + const int in_height = args.in_rows; + const int in_width = args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? 
args.depth_multiplier : kKnownDepthMultiplier; + const int stride = args.stride; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; + const int out_depth = args.out_depth; + + GPU_1D_KERNEL_LOOP(thread_id, num_in_backprop) { + // Compute the indexes of this thread in the output. + const int in_channel = thread_id % in_depth; + const int in_col = (thread_id / in_depth) % in_width; + const int in_row = (thread_id / in_depth / in_width) % in_height; + const int batch = thread_id / in_depth / in_width / in_height; + + T sum = static_cast(0); + + const int out_row_start = + tf_max(0, (in_row - filter_height + pad_height + stride) / stride); + const int out_row_end = + tf_min(out_height - 1, (in_row + pad_height) / stride); + const int out_col_start = + tf_max(0, (in_col - filter_width + pad_width + stride) / stride); + const int out_col_end = + tf_min(out_width - 1, (in_col + pad_width) / stride); + + NOUNROLL for (int out_row = out_row_start; out_row <= out_row_end; + ++out_row) { + const int filter_row = in_row + pad_height - out_row * stride; + const int temp_out_backprop_offset = + out_depth * out_width * (out_row + out_height * batch); + const int temp_filter_offset = filter_width * filter_row; + NOUNROLL for (int out_col = out_col_start; out_col <= out_col_end; + ++out_col) { + const int filter_col = in_col + pad_width - out_col * stride; + int filter_offset = + depth_multiplier * + (in_channel + in_depth * (filter_col + temp_filter_offset)); + const int out_backprop_offset = + out_depth * out_col + temp_out_backprop_offset; +#pragma unroll 6 + for (int i = 0; i < depth_multiplier; ++i) { + sum += ldg(out_backprop + out_backprop_offset + + in_channel * depth_multiplier + i) * + ldg(filter + filter_offset + i); + } + } + } + const int in_backprop_offset = + in_channel + + in_depth * (in_col + in_width * (in_row + in_height * batch)); + in_backprop[in_backprop_offset] = sum; + } +} + +template +__global__ void __launch_bounds__(640, 2) + DepthwiseConv2dBackpropInputGPUKernelNCHW( + const DepthwiseArgs args, const T* __restrict__ out_backprop, + const T* __restrict__ filter, T* __restrict__ in_backprop, + int num_in_backprop) { + const FastDividerUint32 in_height = args.in_rows; + const FastDividerUint32 in_width = args.in_cols; + const FastDividerUint32 in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; + const int stride = args.stride; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; + const int out_depth = args.out_depth; + + // TODO(vrv): Consider assigning threads to output and using + // atomics for accumulation, similar to the filter case. + GPU_1D_KERNEL_LOOP(thread_id, num_in_backprop) { + // Compute the indexes of this thread in the input. 
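+ // thread_id enumerates in_backprop in NCHW order, i.e.
+ // ((batch * in_depth + in_channel) * in_height + in_row) * in_width + in_col,
+ // so the mod/div chain below peels off column, row, channel and batch in
+ // turn. in_width, in_height and in_depth are FastDividerUint32 values, so
+ // the repeated / and % by these runtime-constant sizes avoid slow hardware
+ // integer division.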
+ const int in_col = thread_id % in_width; + const int in_row = (thread_id / in_width) % in_height; + const int in_channel = (thread_id / in_width / in_height) % in_depth; + const int batch = thread_id / in_depth / in_width / in_height; + + T sum = static_cast(0); + const int out_channel_start = in_channel * depth_multiplier; + const int out_channel_end = out_channel_start + depth_multiplier; + + const int out_row_start = + tf_max(0, (in_row - filter_height + pad_height + stride) / stride); + const int out_row_end = + tf_min(out_height - 1, (in_row + pad_height) / stride); + const int out_col_start = + tf_max(0, (in_col - filter_width + pad_width + stride) / stride); + const int out_col_end = + tf_min(out_width - 1, (in_col + pad_width) / stride); + + UNROLL for (int out_channel = out_channel_start; + out_channel < out_channel_end; ++out_channel) { + UNROLL for (int out_row = out_row_start; out_row <= out_row_end; + ++out_row) { + const int filter_row = in_row + pad_height - out_row * stride; + const int filter_dm = out_channel - out_channel_start; + + const int temp_filter_offset = filter_width * filter_row; + for (int out_col = out_col_start; out_col <= out_col_end; ++out_col) { + const int filter_col = in_col + pad_width - out_col * stride; + const int filter_offset = + filter_dm + + args.depth_multiplier * + (in_channel + in_depth * (filter_col + temp_filter_offset)); + + const int out_backprop_offset = + (batch * out_depth * out_height * out_width) + + (out_channel * out_height * out_width) + (out_row * out_width) + + (out_col); + + sum += ldg(out_backprop + out_backprop_offset) * + ldg(filter + filter_offset); + } + } + } + const int in_backprop_offset = (batch * in_height * in_width * in_depth) + + (in_channel * in_height * in_width) + + (in_row * in_width) + (in_col); + in_backprop[in_backprop_offset] = sum; + } +} + +template +Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx, + const DepthwiseArgs& args, + const T* out_backprop, + const T* filter, T* in_backprop, + TensorFormat data_format) { + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); + switch (data_format) { + case FORMAT_NHWC: + kernel = DepthwiseConv2dBackpropInputGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW: + kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + default: + return errors::InvalidArgument("FORMAT_", ToString(data_format), + " is not supported"); + } + const int num_in_backprop = + args.batch * args.in_rows * args.in_cols * args.in_depth; + auto device = ctx->eigen_gpu_device(); + int launch_bounds_value = 640; + GpuLaunchConfig config = GetGpuLaunchConfig(num_in_backprop, device, kernel, + 0, launch_bounds_value); + TF_CHECK_OK(GpuLaunchKernel( + kernel, config.block_count, config.thread_per_block, 0, device.stream(), + args, out_backprop, filter, in_backprop, num_in_backprop)); + return OkStatus(); +} + +template +Status LaunchDepthwiseConv2dBackpropInputGPU(OpKernelContext* ctx, + const DepthwiseArgs& args, + const T* out_backprop, + const T* filter, T* in_backprop, + TensorFormat data_format) { + if (args.depth_multiplier == 1) { + // This kernel doesn't currently work in all cases so it is disabled. + // TODO(b/150988950): Fix and reenable this kernel. 
+ if (/* CanLaunchDepthwiseConv2dGPUSmall(args) */ false) { + return LaunchDepthwiseConv2dGPUSmall< + T, DIRECTION_BACKWARD, kKnownFilterWidth, kKnownFilterHeight>( + ctx, args, out_backprop, filter, in_backprop, data_format); + } + + return LaunchDepthwiseConv2dBackpropInputGPU( + ctx, args, out_backprop, filter, in_backprop, data_format); + } else { + return LaunchDepthwiseConv2dBackpropInputGPU( + ctx, args, out_backprop, filter, in_backprop, data_format); + } +} + +// A simple launch pad to launch the GPU kernel for depthwise convolution. +template +void LaunchDepthwiseConvBackpropInputOp::operator()( + OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, + const T* filter, T* in_backprop, TensorFormat data_format) { + if (args.filter_rows == 3 && args.filter_cols == 3) { + OP_REQUIRES_OK( + ctx, LaunchDepthwiseConv2dBackpropInputGPU( + ctx, args, out_backprop, filter, in_backprop, data_format)); + } else { + OP_REQUIRES_OK( + ctx, LaunchDepthwiseConv2dBackpropInputGPU( + ctx, args, out_backprop, filter, in_backprop, data_format)); + } +} + +// A GPU kernel to compute the depthwise convolution backprop w.r.t. filter. +// TODO: Add fp32 accumulation to half calls of this function. This addition +// is non-trivial as the partial sums are added directly to the output +template +__global__ void __launch_bounds__(640, 2) + DepthwiseConv2dBackpropFilterGPUKernelNHWC( + const DepthwiseArgs args, const T* __restrict__ out_backprop, + const T* __restrict__ input, T* __restrict__ filter_backprop, + int num_out_backprop) { + const int in_height = args.in_rows; + const int in_width = args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int depth_multiplier = + kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; + const int stride = args.stride; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; + const int out_depth = args.out_depth; + + GPU_1D_KERNEL_LOOP(thread_id, num_out_backprop) { + // Compute the indexes of this thread in the output. + const int out_channel = thread_id % out_depth; + const int out_col = (thread_id / out_depth) % out_width; + const int out_row = (thread_id / out_depth / out_width) % out_height; + const int batch = thread_id / out_depth / out_width / out_height; + // Compute the input depth and the index of depth multiplier. + const int in_channel = out_channel / depth_multiplier; + const int dm = out_channel % depth_multiplier; + + // Decide if all input is valid, if yes, we can skip the boundary checks + // for each input. + const int in_row_start = out_row * stride - pad_height; + const int in_col_start = out_col * stride - pad_width; + const int in_row_end = in_row_start + filter_height; + const int in_col_end = in_col_start + filter_width; + + const int out_backprop_offset = + out_channel + + out_depth * (out_col + out_width * (out_row + out_height * batch)); + const T out_bp = ldg(out_backprop + out_backprop_offset); + if (in_row_start >= 0 && in_col_start >= 0 && in_row_end < in_height && + in_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; + // Avoid repeated computation. 
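+ // input_offset_temp folds the batch and row terms of the NHWC input index,
+ // leaving only the column and channel terms to be added per filter column.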
+ const int input_offset_temp = in_width * (in_row + in_height * batch); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + + const int input_offset = + in_channel + in_depth * (in_col + input_offset_temp); + T partial_sum = ldg(input + input_offset) * out_bp; + T* addr = + filter_backprop + + (dm + depth_multiplier * + (in_channel + + in_depth * (filter_col + filter_width * filter_row))); + GpuAtomicAdd(addr, partial_sum); + } + } + } else { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; + // Avoid repeated computation. + const int input_offset_temp = in_width * (in_row + in_height * batch); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + const int addr_temp = filter_width * filter_row; + + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int input_offset = + in_channel + in_depth * (in_col + input_offset_temp); + T partial_sum = ldg(input + input_offset) * out_bp; + T* addr = + filter_backprop + + (dm + depth_multiplier * + (in_channel + in_depth * (filter_col + addr_temp))); + // Potentially many threads can add to the same address so we have + // to use atomic add here. + // TODO(jmchen): If atomic add turns out to be slow, we can: + // 1. allocate multiple buffers for the gradients (one for each + // example in a batch, for example). This can reduce the + // contention on the destination; 2. Have each thread compute one + // gradient for an element in the filters. This should work well + // when the input depth is big and filter size is not too small. + GpuAtomicAdd(addr, partial_sum); + } + } + } + } + } +} + +// Device function to compute sub-warp sum reduction for a power-of-two group of +// neighboring threads. +template +__device__ __forceinline__ T WarpSumReduce(T val) { + // support only power-of-two widths. + assert(__popc(kWidth) == 1); + int sub_warp = GpuLaneId() / kWidth; + int zeros = sub_warp * kWidth; + unsigned mask = ((1UL << kWidth) - 1) << zeros; + for (int delta = kWidth / 2; delta > 0; delta /= 2) { + val += GpuShuffleXorSync(mask, val, delta); + } + return val; +} + +// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in +// NHWC format, tailored for small images up to 32x32. Stride and depth +// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if +// CanLaunchDepthwiseConv2dGPUSmall(args) returns true. +// Tiles of the input tensor are loaded into shared memory before performing the +// convolution. Per iteration and filter element, each thread first performs +// a partial convolution for two elements, one each in the lower and upper half +// of a tile. The intermediate result of all pixels of a warp are then +// accumulated and written to shared memory. Finally, the values in shared +// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed +// up in global memory using atomics. +// Requirements: threads per block must be multiple of 32 and <= launch_bounds, +// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth. +// T is the tensors' data type. S is the math type the kernel uses. This is the +// same as T for all cases but pseudo half (which has T=Eigen::half, S=float). 
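+// For illustration (values not taken from this file): a 16x16 input with
+// kBlockDepth = 2 gives in_rows * in_cols * kBlockDepth = 512, so the
+// requirement above forces kAccumPixels >= 512 / 64 = 8; the launcher further
+// below picks the smallest power of two that satisfies this bound.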
+template +__global__ +__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( + const DepthwiseArgs args, const T* __restrict__ output, + const T* __restrict__ input, T* __restrict__ filter) { + typedef typename detail::PseudoHalfType::Type S; + assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z)); + // Holds block plus halo and filter data for blockDim.x depths. + GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory); + static_assert(sizeof(S) <= 8, "Insufficient alignment detected"); + S* const shared_data = reinterpret_cast(shared_memory); + + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = blockDim.y; // slower (see b/62280718): args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + + assert(blockDim.x == kBlockDepth); + assert(blockDim.y == args.in_cols); + const int block_height = blockDim.z; + + // These values are the same for all threads and could + // be precomputed on the CPU. + const int block_size = block_height * in_width * kBlockDepth; + assert((block_size & 31) == 0); + const int in_row_size = in_width * in_depth; + const int in_size = in_height * in_row_size; + const int in_increment = (in_width - 1) * kBlockDepth; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int tile_height = 2 * block_height + filter_height - 1; + const int tile_row_size = tile_width * kBlockDepth; + const int tile_size = tile_height * tile_row_size; + const int tile_offset = block_height * tile_row_size; + const int pad_offset = pad_height * tile_width + pad_width; + const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth; + const int in_blocks = batch_blocks * num_batches; + const int tensor_offset = block_height * in_row_size; + // The accumulator has a fixed number of pixels that can be reduced by one + // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written. + assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth); + const int accum_increment = kAccumPixels * kBlockDepth; + const int accum_size = filter_pixels * accum_increment; + + const int thread_depth = threadIdx.x; + const int thread_col = threadIdx.y; + const int thread_row = threadIdx.z; + + // Position in block. + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_pix * kBlockDepth + thread_depth; + + // Initialize tile, in particular the padding and accumulator. + for (int i = thread_idx; i < tile_size + accum_size; i += block_size) { + shared_data[i] = S(); + } + __syncthreads(); + + // Position in tensors. + const int tensor_idx = thread_pix * in_depth + thread_depth; + + // Position in (padded) shared memory. + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = data_pix * kBlockDepth + thread_depth; + + // Position in shared memory, offset by pad_height / pad_width. + const int tile_pix = data_pix + pad_offset; + const int tile_idx = tile_pix * kBlockDepth + thread_depth; + + // Position in accumulator (kBlockDepth per warp, depth major). 
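+ // Threads are laid out depth-innermost here
+ // (thread_idx = thread_pix * kBlockDepth + thread_depth), so one warp spans
+ // 32 / kBlockDepth consecutive pixels across all kBlockDepth depths.
+ // accum_pix names that per-warp pixel group, and accum_idx addresses the
+ // accumulator depth-major, giving each (depth, pixel group) pair one slot
+ // per filter element.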
+ const int accum_pix = thread_pix / (32 / kBlockDepth); + const int accum_idx = thread_depth * kAccumPixels + accum_pix; + + const int max_channel = in_depth - thread_depth; + const int accum_offset = tile_size + accum_idx; + const bool skip_second = block_height + thread_row >= in_height; + + for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { + const int batch = b / batch_blocks; + const int block = b - batch * batch_blocks; + + const int start_channel = block * kBlockDepth; + const int filter_offset = tensor_idx + start_channel; + const int inout_offset = batch * in_size + filter_offset; + const bool channel_in_range = start_channel < max_channel; + + if (channel_in_range) { + const T* const in_ptr = inout_offset + input; + S* const tile_ptr = tile_idx + shared_data; + tile_ptr[0] = static_cast(ldg(in_ptr)); + if (!skip_second) { + tile_ptr[tile_offset] = static_cast(ldg(tensor_offset + in_ptr)); + } + } + + // Note: the condition to reach this is uniform across the entire block. + __syncthreads(); + unsigned active_threads = GpuBallotSync(kCudaWarpAll, channel_in_range); + + if (channel_in_range) { + const T* const out_ptr = inout_offset + output; + const S out1 = static_cast(ldg(out_ptr)); + const S out2 = + skip_second ? S() : static_cast(ldg(tensor_offset + out_ptr)); + int shared_offset = data_idx; + S* accum_ptr = accum_offset + shared_data; + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { + const S* const tile_ptr = shared_offset + shared_data; + S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; + // Warp-accumulate pixels of the same depth and write to accumulator. + for (int delta = 16; delta >= kBlockDepth; delta /= 2) { + val += GpuShuffleXorSync(active_threads, val, delta); + } + if (!(thread_idx & 32 - kBlockDepth) /* lane_idx < kBlockDepth */) { + *accum_ptr = val; + } + shared_offset += kBlockDepth; + accum_ptr += accum_increment; + } + shared_offset += in_increment; + } + } + + // Note: the condition to reach this is uniform across the entire block. + __syncthreads(); + + const S* const accum_data = tile_size + shared_data; + for (int i = thread_idx; i < accum_size; i += block_size) { + const int filter_idx = i / kAccumPixels; + const int filter_pix = filter_idx / kBlockDepth; + const int filter_channel = filter_idx % kBlockDepth + start_channel; + const int filter_offset = filter_pix * in_depth + filter_channel; + if (filter_channel < in_depth) { + S val = accum_data[i]; + // Warp-accumulate the pixels of the same depth from the accumulator. + val = WarpSumReduce(val); + if (!(thread_idx & kAccumPixels - 1)) { + GpuAtomicAdd(filter_offset + filter, static_cast(val)); + } + } + } + } +} + +// A GPU kernel to compute the depthwise convolution backprop w.r.t. filter. 
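+// This NCHW variant launches one 512-thread block per
+// (filter_col, filter_row, out_channel) triple (see the
+// dim3(filter_cols, filter_rows, out_depth) grid in the launcher below). The
+// block strides over every output position of every batch, accumulates a
+// per-thread partial sum, reduces it across the warp with gpuprim::WarpReduce,
+// and has lane 0 of each warp issue a single GpuAtomicAdd into
+// filter_backprop.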
+template +__global__ void __launch_bounds__(512, 2) + DepthwiseConv2dBackpropFilterGPUKernelNCHW( + const DepthwiseArgs args, const T* __restrict__ out_backprop, + const T* __restrict__ input, T* __restrict__ filter_backprop) { + const int batch_num = args.batch; + const int in_depth = args.in_depth; + const int in_height = args.in_rows; + const int in_width = args.in_cols; + const int filter_width = args.filter_cols; + const int stride_height = args.stride; + const int stride_width = args.stride; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_depth = args.out_depth; + const int out_height = args.out_rows; + const FastDividerUint32 out_width = args.out_cols; + const FastDividerUint32 depth_multiplier = args.depth_multiplier; + assert(gridDim.x == filter_width); + assert(gridDim.z == out_depth); + + typedef gpuprim::WarpReduce WarpReduce; + typename WarpReduce::TempStorage temp_storage; + + const int filter_w = blockIdx.x; + const int filter_h = blockIdx.y; + const int out_c = blockIdx.z; + + const int in_c = out_c / depth_multiplier; + const int dm = out_c % depth_multiplier; + const int filter_backprop_offset = + (((filter_h * filter_width) + filter_w) * in_depth + in_c) * + depth_multiplier + + dm; + const int out_spatial_size = out_height * out_width; + + T partial_sum = static_cast(0.f); + for (int batch = 0; batch < batch_num; batch++) { + const int input_offset_temp = (batch * in_depth + in_c) * in_height; + const int output_backprop_offset_temp = + (batch * out_depth + out_c) * out_height; + for (int i = threadIdx.x; i < out_spatial_size; i += blockDim.x) { + const int out_col = i % out_width; + const int out_row = i / out_width; + // We use the formula: `(in_row - filter_w + pad_left ) / stride = + // out_row` to compute corresponding in_row and out_row positions. Similar + // for in_col and out_col. + const int in_row = out_row * stride_height + filter_h - pad_height; + const int in_col = out_col * stride_width + filter_w - pad_width; + + if (in_row < 0 || in_col < 0 || in_row >= in_height || + in_col >= in_width) { + continue; + } + + int input_offset = (input_offset_temp + in_row) * in_width + in_col; + int output_backprop_offset = + (output_backprop_offset_temp + out_row) * out_width + out_col; + partial_sum += out_backprop[output_backprop_offset] * input[input_offset]; + } + } + + T val = WarpReduce(temp_storage).Sum(partial_sum); + if (gpuprim::LaneId() == 0) { + T* addr = filter_backprop + filter_backprop_offset; + GpuAtomicAdd(addr, val); + } +} + +// CUDA kernel to compute the depthwise convolution backward w.r.t. filter in +// NCHW format, tailored for small images up to 32x32. Stride and depth +// multiplier must be 1. Padding must be 'SAME'. Only use this kernel if +// CanLaunchDepthwiseConv2dGPUSmall(args) returns true. +// Tiles of the input tensor are loaded into shared memory before performing the +// convolution. Per iteration and filter element, each thread first performs +// a partial convolution for two elements, one each in the lower and upper half +// of a tile. The intermediate result of all pixels of a warp are then +// accumulated and written to shared memory. Finally, the values in shared +// memory are warp-accumulated (in chunks of kAccumPixels elements) and summed +// up in global memory using atomics. +// Requirements: threads per block must be multiple of 32 and <= launch_bounds, +// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth. 
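+// Compared to the NHWC small kernel above, the block layout is transposed:
+// depth lives in threadIdx.z and the image column in threadIdx.x (the launcher
+// builds dim3(in_cols, block_height, kBlockDepth) rather than
+// dim3(kBlockDepth, in_cols, block_height)), and the shared-memory tile is one
+// contiguous plane per depth slice instead of interleaving the depths within
+// each pixel.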
+template +__global__ +__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( + const DepthwiseArgs args, const T* __restrict__ output, + const T* __restrict__ input, T* __restrict__ filter) { + typedef typename detail::PseudoHalfType::Type S; + assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x)); + // Holds block plus halo and filter data for blockDim.z depths. + GPU_DYNAMIC_SHARED_MEM_DECL(8, unsigned char, shared_memory); + static_assert(sizeof(S) <= 8, "Insufficient alignment detected"); + S* const shared_data = reinterpret_cast(shared_memory); + + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = blockDim.x; // slower (see b/62280718): args.in_cols; + const int in_depth = args.in_depth; + const int filter_height = + kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + + assert(blockDim.x == args.in_cols); + assert(blockDim.z == kBlockDepth); + const int block_height = blockDim.y; + + // These values are the same for all threads and could + // be precomputed on the CPU. + const int block_pixels = in_width * block_height; + const int block_size = block_pixels * kBlockDepth; + assert((block_size & 31) == 0); + const int in_pixels = in_width * in_height; + const int in_increment = in_width - 1; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int tile_height = 2 * block_height + filter_height - 1; + const int tile_pixels = tile_width * tile_height; + const int tile_size = tile_pixels * kBlockDepth; + const int tile_offset = block_height * tile_width; + const int pad_offset = pad_height * tile_width + pad_width; + const int in_total_depth = in_depth * num_batches; + const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth; + // The accumulator has a fixed number of pixels that can be reduced by one + // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written. + assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth); + const int accum_increment = kAccumPixels * kBlockDepth; + const int accum_size = filter_pixels * accum_increment; + + const int thread_col = threadIdx.x; + const int thread_row = threadIdx.y; + const int thread_depth = threadIdx.z; + + // Position in block. + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_depth * block_pixels + thread_pix; + + // Initialize tile, in particular the padding and accumulator. + for (int i = thread_idx; i < tile_size + accum_size; i += block_size) { + shared_data[i] = S(); + } + __syncthreads(); + + // Position in tensors. + const int tensor_idx = thread_depth * in_pixels + thread_pix; + + // Position in (padded) shared memory. + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = thread_depth * tile_pixels + data_pix; + + // Position in shared memory, offset by pad_height / pad_width. + const int tile_idx = data_idx + pad_offset; + + // Position in accumulator (kBlockDepth per warp, depth major). 
+ const int accum_pix = thread_pix / (32 / kBlockDepth); + const int accum_idx = thread_depth * kAccumPixels + accum_pix; + + const int max_channel = in_total_depth - thread_depth; + const int accum_offset = tile_size + accum_idx; + const bool skip_second = block_height + thread_row >= in_height; + + for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { + const int channel = b * kBlockDepth; + + const int inout_offset = channel * in_pixels + tensor_idx; + const bool channel_in_range = channel < max_channel; + + if (channel_in_range) { + const T* const in_ptr = inout_offset + input; + S* const tile_ptr = tile_idx + shared_data; + tile_ptr[0] = static_cast(ldg(in_ptr)); + if (!skip_second) { + tile_ptr[tile_offset] = static_cast(ldg(block_pixels + in_ptr)); + } + } + + // Note: the condition to reach this is uniform across the entire block. + __syncthreads(); + unsigned active_threads = GpuBallotSync(kCudaWarpAll, channel_in_range); + + if (channel_in_range) { + const T* const out_ptr = inout_offset + output; + const S out1 = static_cast(ldg(out_ptr)); + const S out2 = + skip_second ? S() : static_cast(ldg(block_pixels + out_ptr)); + int shared_offset = data_idx; + S* accum_ptr = accum_offset + shared_data; + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { + const S* const tile_ptr = shared_offset + shared_data; + S val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; + // Warp-accumulate pixels of the same depth and write to accumulator. + for (int delta = 16 / kBlockDepth; delta > 0; delta /= 2) { + val += GpuShuffleXorSync(active_threads, val, delta); + } + if (!(thread_idx & 32 / kBlockDepth - 1)) { + *accum_ptr = val; // kBlockDepth threads per warp. + } + ++shared_offset; + accum_ptr += accum_increment; + } + shared_offset += in_increment; + } + } + + // Note: the condition to reach this is uniform across the entire block. + __syncthreads(); + + const S* const accum_data = tile_size + shared_data; + for (int i = thread_idx; i < accum_size; i += block_size) { + const int filter_idx = i / kAccumPixels; + const int filter_pix = filter_idx / kBlockDepth; + const int filter_channel = + (channel + filter_idx % kBlockDepth) % in_depth; + const int filter_offset = filter_pix * in_depth + filter_channel; + if (filter_channel < in_depth) { + S val = accum_data[i]; + // Warp-accumulate pixels of the same depth from the accumulator. 
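+ // WarpSumReduce performs a butterfly shuffle reduction over each aligned
+ // group of kAccumPixels lanes, so afterwards every lane of a group holds
+ // the group's total; the mask test below then lets only the first lane of
+ // each group perform the GpuAtomicAdd into the filter gradient.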
+ val = WarpSumReduce(val); + if (!(thread_idx & kAccumPixels - 1)) { + GpuAtomicAdd(filter_offset + filter, static_cast(val)); + } + } + } + } +} + +template +Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( + OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format) { + typedef typename detail::PseudoHalfType::Type S; + auto device = ctx->eigen_gpu_device(); + const int tile_width = args.in_cols + args.filter_cols - 1; + const int tile_height = block_height * 2 + args.filter_rows - 1; + const int tile_pixels = tile_height * tile_width; + const int filter_pixels = args.filter_rows * args.filter_cols; + const int shared_memory_size = + kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(S); + if (shared_memory_size > device.sharedMemPerBlock()) { + return errors::FailedPrecondition("Not enough shared memory"); + } + + dim3 block_dim; + int block_count; + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); + switch (data_format) { + case FORMAT_NHWC: + block_dim = dim3(kBlockDepth, args.in_cols, block_height); + block_count = + args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth; + kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; + break; + case FORMAT_NCHW: + block_dim = dim3(args.in_cols, block_height, kBlockDepth); + block_count = + DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth; + kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; + break; + default: + return errors::InvalidArgument("FORMAT_", ToString(data_format), + " is not supported"); + } + const int num_out_backprop = args.out_rows * args.out_cols * block_count; + GpuLaunchConfig config = GetGpuLaunchConfigFixedBlockSize( + num_out_backprop, device, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + TF_CHECK_OK(GpuLaunchKernel(kernel, config.block_count, block_dim, + shared_memory_size, device.stream(), args, + out_backprop, input, filter_backprop)); + return OkStatus(); +} + +template +Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( + OpKernelContext* ctx, const DepthwiseArgs& args, const int block_height, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format) { + // Minimize (power of two) kAccumPixels, while satisfying + // kAccumPixels * 32 >= block_height * in_width * kBlockDepth. 
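+ // A block of block_pixels threads contains block_pixels / 32 warps, and each
+ // warp deposits one partial sum per depth per filter element, so kAccumPixels
+ // (the number of pixel-group slots per depth) must be at least
+ // block_pixels / 32. The 512- and 256-pixel thresholds below therefore
+ // select kAccumPixels of 32, 16 and 8 respectively.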
+ const int block_pixels = block_height * args.in_cols * kBlockDepth; + if (block_pixels > 512) { + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 32>( + ctx, args, block_height, out_backprop, input, filter_backprop, + data_format); + } else if (block_pixels > 256) { + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 16>( + ctx, args, block_height, out_backprop, input, filter_backprop, + data_format); + } else { + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 8>( + ctx, args, block_height, out_backprop, input, filter_backprop, + data_format); + } +} + +template +Status TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( + OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, + const T* input, T* filter_backprop, TensorFormat data_format) { + // Maximize (power of two) kBlockDepth while keeping a block within 1024 + // threads (2 pixels per thread). + int block_depth = 8; + int block_height = (args.in_rows + 1) / 2; + int round_mask = 1; + for (; block_depth > 1; block_depth /= 2) { + // args.in_cols * block_height * kBlockDepth must be multiple of 32. + for (; block_height * args.in_cols * block_depth & 31; + round_mask = round_mask * 2 + 1) { + block_height = block_height + round_mask & ~round_mask; + } + int block_size = block_height * args.in_cols * block_depth; + if (block_size <= 1024) { + break; + } + } + + if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height)) { + return errors::FailedPrecondition("Cannot launch this configuration"); + } + + switch (block_depth) { + case 8: + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, 8>( + ctx, args, block_height, out_backprop, input, filter_backprop, + data_format); + case 4: + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, 4>( + ctx, args, block_height, out_backprop, input, filter_backprop, + data_format); + case 2: + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, 2>( + ctx, args, block_height, out_backprop, input, filter_backprop, + data_format); + default: + return errors::InvalidArgument("Unexpected block depth"); + } +} + +template +Status LaunchDepthwiseConv2dBackpropFilterGPU( + OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, + const T* input, T* filter_backprop, TensorFormat data_format) { + auto device = ctx->eigen_gpu_device(); + const int num_out_backprop = + args.batch * args.out_rows * args.out_cols * args.out_depth; + if (data_format == FORMAT_NHWC) { + auto kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + + int launch_bounds_value = 640; + GpuLaunchConfig config = GetGpuLaunchConfig(num_out_backprop, device, + kernel, 0, launch_bounds_value); + TF_CHECK_OK(GpuLaunchKernel( + kernel, config.block_count, config.thread_per_block, 0, device.stream(), + args, out_backprop, input, filter_backprop, num_out_backprop)); + } else if (data_format == FORMAT_NCHW) { + auto kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW; + dim3 blocks = dim3(args.filter_cols, args.filter_rows, args.out_depth); + dim3 threads = dim3(512, 1, 1); + + TF_CHECK_OK(GpuLaunchKernel(kernel, blocks, threads, 0, device.stream(), + args, out_backprop, input, filter_backprop)); + } else { + 
return errors::InvalidArgument("FORMAT_", ToString(data_format), + " is not supported"); + } + + return OkStatus(); +} + +template +Status LaunchDepthwiseConv2dBackpropFilterGPU( + OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, + const T* input, T* filter_backprop, TensorFormat data_format) { + if (args.depth_multiplier == 1) { + if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( + ctx, args, out_backprop, input, filter_backprop, data_format) + .ok()) { + return OkStatus(); + } + + return LaunchDepthwiseConv2dBackpropFilterGPU( + ctx, args, out_backprop, input, filter_backprop, data_format); + } else { + return LaunchDepthwiseConv2dBackpropFilterGPU( + ctx, args, out_backprop, input, filter_backprop, data_format); + } +} + +// A simple launch pad to launch the GPU kernel for depthwise convolution. +template +void LaunchDepthwiseConvBackpropFilterOp::operator()( + OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, + const T* input, T* filter_backprop, TensorFormat data_format) { + auto stream = ctx->op_device_context()->stream(); + + // It's simpler to catch this here than in + // DepthwiseConv2dNativeBackpropFilterOp + OP_REQUIRES( + ctx, !OpDeterminismRequired(), + errors::Unimplemented( + "A deterministic GPU implementation of DepthwiseConvBackpropFilter is" + " not available with this version of cuDNN. Please build with cuDNN" + " version 7.6.3 or later.")); + + // Initialize the results to 0. + int num_filter_backprop = + args.filter_rows * args.filter_cols * args.out_depth; + se::DeviceMemoryBase filter_bp_ptr(filter_backprop, num_filter_backprop); + OP_REQUIRES_OK( + ctx, stream->MemZero(&filter_bp_ptr, num_filter_backprop * sizeof(T))); + + if (args.filter_rows == 3 && args.filter_cols == 3) { + OP_REQUIRES_OK( + ctx, LaunchDepthwiseConv2dBackpropFilterGPU( + ctx, args, out_backprop, input, filter_backprop, data_format)); + } else { + OP_REQUIRES_OK( + ctx, LaunchDepthwiseConv2dBackpropFilterGPU( + ctx, args, out_backprop, input, filter_backprop, data_format)); + } +} +} // namespace tensorflow +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_DEPTHWISE_CONV_OP_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/diag_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/diag_op.h new file mode 100644 index 00000000..c41da62d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/diag_op.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DIAG_OP_H_ +#define TENSORFLOW_CORE_KERNELS_DIAG_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace functor { + +template +struct DiagFunctor { + absl::Status operator()(OpKernelContext* context, const int64_t size, + const T* in, T* out); +}; + +template +struct DiagPartFunctor { + absl::Status operator()(OpKernelContext* context, const int64_t size, + const T* in, T* out); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DIAG_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/dilation_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/dilation_ops.h new file mode 100644 index 00000000..4f0b944a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/dilation_ops.h @@ -0,0 +1,66 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +template +struct Dilation { + // We assume that the tensor sizes are correct. + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, int stride_rows, + int stride_cols, int rate_rows, int rate_cols, int pad_top, + int pad_left, typename TTypes::Tensor output); +}; + +template +struct DilationBackpropInput { + // We assume that the tensor sizes are correct. + // To avoid storing the argmax values during forward computation, we recompute + // the argmax during backward computation, which is the reason why we provide + // filter as argument to the backward computation routine. + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor in_backprop); +}; + +template +struct DilationBackpropFilter { + // We assume that the tensor sizes are correct. + // To avoid storing the argmax values during forward computation, we recompute + // the argmax during backward computation, which is the reason why we provide + // filter as argument to the backward computation routine. 
+ void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstTensor filter, + typename TTypes::ConstTensor out_backprop, + int stride_rows, int stride_cols, int rate_rows, + int rate_cols, int pad_top, int pad_left, + typename TTypes::Tensor filter_backprop); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DILATION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_activations.h b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_activations.h new file mode 100644 index 00000000..8224627f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_activations.h @@ -0,0 +1,122 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +namespace Eigen { + +/** scalar_sigmoid_fast_derivative_op + * \ingroup CXX11_NeuralNetworks_Module + * \brief Template functor to compute the fast derivative of a sigmoid + * + * Input should be the backpropagated gradient. + * + * \sa class CwiseUnaryOp, Cwise::sigmoid_fast_derivative() + */ +template +struct scalar_sigmoid_fast_derivative_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& y) const { + const T one = T(1); + return (one - y) * y; + } + + template + inline Packet packetOp(const Packet& y) const { + const Packet one = internal::pset1(1); + return internal::pmul(internal::psub(one, y), y); + } +}; + +namespace internal { +template +struct functor_traits > { + enum { + Cost = NumTraits::AddCost * 2 + NumTraits::MulCost, + PacketAccess = packet_traits::HasAdd && packet_traits::HasMul && + packet_traits::HasNegate + }; +}; +} // namespace internal + +/** scalar_tanh_fast_derivative_op + * \ingroup CXX11_NeuralNetworks_Module + * \brief Template functor to compute the fast derivative of a tanh + * + * Input should be the backpropagated gradient. + * + * \sa class CwiseUnaryOp, Cwise::tanh_fast_derivative() + */ +template +struct scalar_tanh_fast_derivative_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& y) const { + const T one = T(1); + return one - (y * y); + } + + template + inline Packet packetOp(const Packet& y) const { + const Packet one = internal::pset1(1); + return internal::psub(one, internal::pmul(y, y)); + } +}; + +namespace internal { +template +struct functor_traits > { + enum { + Cost = NumTraits::AddCost * 2 + NumTraits::MulCost * 1, + PacketAccess = packet_traits::HasAdd && packet_traits::HasMul && + packet_traits::HasNegate + }; +}; +} // namespace internal + +/** + * \ingroup CXX11_NeuralNetworks_Module + * \brief Template functor to clip the magnitude of the first scalar. 
+ * + * \sa class CwiseBinaryOp, MatrixBase::Clip + */ +template +struct scalar_clip_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& a, const Scalar& b) const { + return numext::mini(numext::maxi(a, -b), b); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet + packetOp(const Packet& a, const Packet& b) const { + return internal::pmin(internal::pmax(a, internal::pnegate(b)), b); + } +}; + +namespace internal { +template +struct functor_traits > { + enum { + Cost = NumTraits::AddCost * 3, + PacketAccess = packet_traits::HasMax && + packet_traits::HasMin && + packet_traits::HasNegate + }; +}; +} // namespace internal + +} // end namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_ACTIVATIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_attention.h b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_attention.h new file mode 100644 index 00000000..7eec12bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_attention.h @@ -0,0 +1,300 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +namespace Eigen { + +// Noise mode used when padding. +enum ExtractGlimpsesNoiseMode { + UNIFORM = 0, + GAUSSIAN = 1, + ZERO = 2, +}; + +/** ExtractGlimpses + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Extract glimpses from an input tensor. + * + * The input parameter is expected to be a col-major tensor with a rank of 4 + * (depth, x, y, and batch). The width and height parameters specify the + * extension of the returned glimpses. The offsets parameter specifies the x, y + * locations of the center of the glimpses relative to the center of the input + * image. The vector is expected to contain one IndexPair for each image in the + * batch dimension. The normalized boolean indicates if incoming coordinates are + * normalized so that 0.0 and 1.0 correspond to the minimum and maximum of each + * height and width dimension. The centered boolean indicates if incoming + * coordinates are centered relative to the image, in which case -1.0 and 1.0 + * correspond to minimum and maximum of each dimension while 0.0 corresponds to + * the center. + * + * The result can be assigned to a tensor of rank equal to that of the input. + * The result will be laid out in col-major order (depth, x, y, batch). The + * dimensions of the result will be equal to the dimensions of the input except + * for width and height which will be equal to the requested glimpse size. 
+ */ +namespace { + +template +struct GlimpseExtractionOp { + GlimpseExtractionOp(const Index width, const Index height, + const std::vector >& offsets, + const bool normalized, const bool centered, + const ExtractGlimpsesNoiseMode noise, const int version) + : width_(width), + height_(height), + offsets_(offsets), + normalized_(normalized), + centered_(centered), + noise_(noise), + version_(version) {} + + template + DSizes dimensions(const Input& input) const { + typedef typename internal::traits::Index IndexType; + typedef TensorRef::Scalar, 4, + internal::traits::Layout, IndexType> > + Ref; + Ref in(input); + + DSizes dims = in.dimensions(); + + dims[0] = in.dimension(0); + dims[1] = width_; + dims[2] = height_; + dims[3] = in.dimension(3); + return dims; + } + + template + EIGEN_DEVICE_FUNC void eval(const Input& input, Output& output, + const Device& device) const { + typedef typename internal::traits::Index IndexType; + typedef TensorRef::Scalar, 4, + internal::traits::Layout, IndexType> > + Ref; + Ref in(input); + const Index num_channels = in.dimension(0); + const Index input_width = in.dimension(1); + const Index input_height = in.dimension(2); + const Index batch_size = in.dimension(3); + eigen_assert(input_width > 0); + eigen_assert(input_height > 0); + internal::NormalRandomGenerator gen; + internal::UniformRandomGenerator unigen; + + for (Index i = 0; i < batch_size; ++i) { + float x = offsets_[i].first, y = offsets_[i].second; + + if (version_ == 1) { + // Un-normalize coordinates back to pixel space if normalized. + if (normalized_) { + x *= input_width; + y *= input_height; + } + // Un-center if coordinates are centered on the image center. + if (centered_) { + x /= 2.0f; + y /= 2.0f; + x += input_width / 2.0f; + y += input_height / 2.0f; + } + // Remove half of the glimpse window. + x -= width_ / 2.0f; + y -= height_ / 2.0f; + } else { + if (normalized_) { + // Un-normalize coordinates back to pixel space if normalized. + x *= input_width; + y *= input_height; + if (centered_) { + // Un-center if coordinates are centered on the image center. + x /= 2.0f; + y /= 2.0f; + x += input_width / 2.0f; + y += input_height / 2.0f; + // Remove half of the glimpse window. 
+ x -= width_ / 2.0f; + y -= height_ / 2.0f; + } + } else { + if (centered_) { + x += input_width / 2.0f; + y += input_height / 2.0f; + } + } + } + + const Index offset_x = (Index)x; + const Index offset_y = (Index)y; + Index glimpse_width = width_; + Index glimpse_height = height_; + bool partial_overlap = false; + DSizes slice_offset(0, offset_x, offset_y); + DSizes slice_extent(num_channels, width_, height_); + DSizes base_offset(0, 0, 0); + + if (offset_x < 0) { + slice_offset[1] = 0; + glimpse_width = (std::max)(0, width_ + offset_x); + slice_extent[1] = glimpse_width; + base_offset[1] = width_ - glimpse_width; + partial_overlap = true; + } else if (offset_x + width_ >= input_width) { + glimpse_width = (std::max)(0, input_width - offset_x); + slice_extent[1] = glimpse_width; + partial_overlap = true; + } + if (offset_y < 0) { + slice_offset[2] = 0; + glimpse_height = (std::max)(0, height_ + offset_y); + slice_extent[2] = glimpse_height; + base_offset[2] = height_ - glimpse_height; + partial_overlap = true; + } else if (offset_y + height_ >= input_height) { + glimpse_height = (std::max)(0, input_height - offset_y); + slice_extent[2] = glimpse_height; + partial_overlap = true; + } + slice_extent[1] = std::min(input_width, slice_extent[1]); + slice_extent[2] = std::min(input_height, slice_extent[2]); + + if (partial_overlap) { + switch (noise_) { + case ZERO: { + // Initialize the glimpse with zero noise. + output.template chip<3>(i).device(device) = + output.template chip<3>(i).constant(0); + } break; + case UNIFORM: { + // Initialize the glimpse with uniform noise. + typedef std::remove_const_t< + typename internal::traits::Scalar> + Scalar; + TensorFixedSize > mini; + mini.device(device) = input.template chip<3>(i).minimum(); + TensorFixedSize > range; + range.device(device) = (input.template chip<3>(i).maximum() - mini) + .template cast(); + + DSizes glimpse_size(num_channels, width_, height_); + TensorMap > tmp(nullptr, glimpse_size); + output.template chip<3>(i).device(device) = + mini.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size) + + (tmp.random(unigen) * + range.reshape(Sizes<1, 1, 1>()).broadcast(glimpse_size)) + .template cast(); + } break; + case GAUSSIAN: { + // Initialize the glimpse with white noise: compute the mean and + // sigma + // of each channel, and use them to shape the gaussian. 
+ DSizes glimpse_size(width_, height_); + DSizes input_size(input_width, input_height); + typedef std::remove_const_t< + typename internal::traits::Scalar> + Scalar; + + for (int j = 0; j < num_channels; ++j) { + TensorFixedSize > mean; + mean.device(device) = input.template chip<3>(i) + .template chip<0>(j) + .template cast() + .mean(); + TensorFixedSize > sigma; + sigma.device(device) = + (input.template chip<3>(i) + .template chip<0>(j) + .template cast() - + mean.reshape(Sizes<1, 1>()).broadcast(input_size)) + .square() + .mean() + .sqrt(); + TensorFixedSize > mini; + mini.device(device) = + input.template chip<3>(i).template chip<0>(j).minimum(); + TensorFixedSize > maxi; + maxi.device(device) = + input.template chip<3>(i).template chip<0>(j).maximum(); + + TensorMap > tmp(nullptr, glimpse_size); + output.template chip<3>(i).template chip<0>(j).device(device) = + (mean.reshape(Sizes<1, 1>()).broadcast(glimpse_size) + + (tmp.random(gen) * + sigma.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) + .template cast()) + .cwiseMin( + maxi.reshape(Sizes<1, 1>()).broadcast(glimpse_size)) + .cwiseMax( + mini.reshape(Sizes<1, 1>()).broadcast(glimpse_size)); + } + } break; + } + + // Copy the part of the glimpse that cover the input image if any. + if (glimpse_width == 0 || glimpse_height == 0) { + continue; + } + output.template chip<3>(i) + .slice(base_offset, slice_extent) + .device(device) = + input.template chip<3>(i).slice(slice_offset, slice_extent); + } else { + output.template chip<3>(i).device(device) = + input.template chip<3>(i).slice(slice_offset, slice_extent); + } + } + } + + private: + const Index width_; + const Index height_; + const std::vector > offsets_; + const bool normalized_; + const bool centered_; + const ExtractGlimpsesNoiseMode noise_; + const int version_; +}; +} // namespace + +template +EIGEN_ALWAYS_INLINE static const TensorCustomUnaryOp< + const GlimpseExtractionOp::Index>, + const Input> +ExtractGlimpses( + const Input& input, const typename internal::traits::Index width, + const typename internal::traits::Index height, + const std::vector >& offsets, const bool normalized = true, + const bool centered = true, + const ExtractGlimpsesNoiseMode noise = ExtractGlimpsesNoiseMode::UNIFORM, + const int version = 2) { + EIGEN_STATIC_ASSERT(internal::traits::Layout == ColMajor, + YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 4, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + typedef typename internal::traits::Index Index; + const GlimpseExtractionOp op(width, height, offsets, normalized, + centered, noise, version); + return input.customOp(op); +} + +} // end namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_ATTENTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h new file mode 100644 index 00000000..4ef1b924 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -0,0 +1,610 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/kernels/eigen_cuboid_convolution.h" + +namespace Eigen { + +/** CuboidConvolutionBackwardInput + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Computes the backprop for the input of a 3D convolution. + * + * The output_backward parameter is expected to be a tensor with a rank of 4 or + * more (channels, depth, height, width, and optionally others) + * The kernel parameter is expected to be a 5D tensor (filters, channels, + * kernel_depth, kernel_height, kernel_width) + * output_backward and kernel have to be in the same layout. + * + * The dimensions of the result will be filters, depth, height, width (and + * others if applicable). + * + * It is possible to swap the order of the depth, width and height dimensions + * provided that the same order is used in the input, the kernel, and the + * output. + * + * All dimension orders above are given for col-major, and should be reversed + * for row-major. + */ + +template +EIGEN_ALWAYS_INLINE static const std::conditional_t< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const Eigen::TensorForcedEvalOp::Index, + 2>, + const TensorShufflingOp< + const array< + typename internal::traits::Index, 5>, + const TensorReverseOp, + const Kernel>>>>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorVolumePatchOp>>>, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorVolumePatchOp>, + const Eigen::TensorForcedEvalOp::Index, + 2>, + const TensorShufflingOp< + const array< + typename internal::traits::Index, 5>, + const TensorReverseOp, + const Kernel>>>>>>> +CuboidConvolutionBackwardInput( + const Kernel& kernel, const OutputBackward& output_backward, + typename internal::traits::Index inputPlanes, + typename internal::traits::Index inputRows, + typename internal::traits::Index inputCols, + const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1) { + typedef typename internal::traits::Index TensorIndex; + const TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex>> + kern(kernel); + const TensorRef< + const Tensor::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex>> + out(output_backward); + + EIGEN_STATIC_ASSERT(internal::traits::Layout == + internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + static const bool isColMajor = + (internal::traits::Layout == ColMajor); + + static const int NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. 
This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? kern.dimensions()[0] : kern.dimensions()[4]; + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[3]; + const TensorIndex kernelPlanes = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[4] : kern.dimensions()[0]; + + const TensorIndex outputPlanes = + isColMajor ? out.dimensions()[1] : out.dimensions()[NumDims - 2]; + const TensorIndex outputRows = + isColMajor ? out.dimensions()[2] : out.dimensions()[NumDims - 3]; + const TensorIndex outputCols = + isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4]; + + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + // Computing the forward padding. + const TensorIndex forward_pad_top_z = numext::maxi( + 0, + ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2); + const TensorIndex forward_pad_top = numext::maxi( + 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2); + const TensorIndex forward_pad_left = numext::maxi( + 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2); + + const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z; + const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top; + const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left; + + const TensorIndex padding_bottom_z = inputPlanes - + (outputPlanes - 1) * plane_stride - 2 - + padding_top_z + kernelPlanesEff; + const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride - + 2 - padding_top + kernelRowsEff; + const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride - + 2 - padding_left + kernelColsEff; + + eigen_assert(padding_top_z >= 0); + eigen_assert(padding_top >= 0); + eigen_assert(padding_left >= 0); + eigen_assert(padding_bottom_z >= 0); + eigen_assert(padding_bottom >= 0); + eigen_assert(padding_right >= 0); + + // The kernel has dimensions : + // filters x channels x patch_planes x patch_rows x patch_cols. + // We need to reverse the kernel along the spatial dimensions. 
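// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the vendored header: the col-major
// reverse -> shuffle -> reshape chain applied to the kernel below, replayed on
// a standalone Eigen tensor so the bookkeeping is easy to follow. The helper
// name is invented for the example.
#include "unsupported/Eigen/CXX11/Tensor"

// Kernel layout (col-major): [filters, channels, planes, rows, cols].
inline Eigen::Tensor<float, 2> FlattenKernelForBackwardInput(
    const Eigen::Tensor<float, 5>& kernel) {
  const Eigen::array<bool, 5> reverse{{false, false, true, true, true}};
  const Eigen::array<int, 5> shuffle{{0, 2, 3, 4, 1}};  // -> [F, P, R, Q, C]
  const Eigen::Index F = kernel.dimension(0), C = kernel.dimension(1);
  const Eigen::Index P = kernel.dimension(2), R = kernel.dimension(3),
                     Q = kernel.dimension(4);
  // Collapse to the 2-D matrix that is contracted against the extracted
  // volume patches of output_backward.
  return kernel.reverse(reverse)
      .shuffle(shuffle)
      .reshape(Eigen::array<Eigen::Index, 2>{{F * P * R * Q, C}});
}
// ---------------------------------------------------------------------------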
+ Eigen::array kernel_reverse; + if (isColMajor) { + kernel_reverse[0] = false; + kernel_reverse[1] = false; + kernel_reverse[2] = true; + kernel_reverse[3] = true; + kernel_reverse[4] = true; + } else { + kernel_reverse[0] = true; + kernel_reverse[1] = true; + kernel_reverse[2] = true; + kernel_reverse[3] = false; + kernel_reverse[4] = false; + } + + // Reorder the dimensions to: + // filters x patch_planes x patch_rows x patch_cols x channels + array kernel_shuffle; + if (isColMajor) { + // From: filters x channels x planes x rows x cols + // To: filters x planes x rows x cols x channels + kernel_shuffle[0] = 0; + kernel_shuffle[1] = 2; + kernel_shuffle[2] = 3; + kernel_shuffle[3] = 4; + kernel_shuffle[4] = 1; + } else { + // From: cols x rows x planes x channels x filters + // To: channels x cols x rows x planes x filters + kernel_shuffle[0] = 3; + kernel_shuffle[1] = 0; + kernel_shuffle[2] = 1; + kernel_shuffle[3] = 2; + kernel_shuffle[4] = 4; + } + + // Collapse the dims + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols; + kernel_dims[1] = kernelChannels; + } else { + kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols; + kernel_dims[0] = kernelChannels; + } + + // The output_backward has dimensions out_depth X out_planes X out_rows X + // out_cols X OTHERS + // When we extract the image patches from output_backward, it will have + // dimensions: + // out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes * + // input_rows * input_cols * OTHERS) + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = + kernelFilters * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = inputPlanes * inputRows * inputCols; + for (int i = 4; i < NumDims; ++i) { + pre_contract_dims[1] *= out.dimension(i); + } + } else { + pre_contract_dims[1] = + kernelFilters * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[0] = inputPlanes * inputRows * inputCols; + for (int i = 0; i < NumDims - 4; ++i) { + pre_contract_dims[0] *= out.dimension(i); + } + } + + // We will contract along the collapsed dimension that contains the + // kernelFilters, kernelPlanes, kernelRows and kernelCols. 
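// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the vendored header: what a single
// IndexPair contraction means, on a tiny standalone example. The contract_dims
// set up below do the same job, naming which axis of the flattened kernel and
// of the patch matrix is summed away.
#include "unsupported/Eigen/CXX11/Tensor"

inline Eigen::Tensor<float, 2> IndexPairContractionExample() {
  Eigen::Tensor<float, 2> a(2, 3), b(3, 4);
  a.setRandom();
  b.setRandom();
  // Contract a's dimension 1 against b's dimension 0: an ordinary matrix
  // product with a [2, 4] result.
  const Eigen::array<Eigen::IndexPair<int>, 1> dims = {
      Eigen::IndexPair<int>(1, 0)};
  return a.contract(b, dims);
}
// ---------------------------------------------------------------------------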
+ array, 1> contract_dims; + if (isColMajor) { + // col-major: kernel.contract(output.patches) + contract_dims[0] = IndexPair(0, 0); + } else { + // row-major: output.patches.contract(kernel) + contract_dims[0] = IndexPair(1, 1); + } + + // Post contraction, the dimensions of the input_backprop is + // channels X input_planes X input_rows X input_cols X OTHERS + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelChannels; + post_contract_dims[1] = inputPlanes; + post_contract_dims[2] = inputRows; + post_contract_dims[3] = inputCols; + for (int i = 4; i < NumDims; ++i) { + post_contract_dims[i] = out.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelChannels; + post_contract_dims[NumDims - 2] = inputPlanes; + post_contract_dims[NumDims - 3] = inputRows; + post_contract_dims[NumDims - 4] = inputCols; + for (int i = 0; i < NumDims - 4; ++i) { + post_contract_dims[i] = out.dimension(i); + } + } + + return choose( + Cond::Layout == ColMajor>(), + kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) + .reshape(kernel_dims) + .eval() + .contract(output_backward + .extract_volume_patches( + kernelPlanes, kernelRows, kernelCols, 1, 1, 1, + plane_stride, row_stride, col_stride, padding_top_z, + padding_bottom_z, padding_top, padding_bottom, + padding_left, padding_right) + .reshape(pre_contract_dims), + contract_dims) + .reshape(post_contract_dims), + output_backward + .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1, + plane_stride, row_stride, col_stride, + padding_top_z, padding_bottom_z, padding_top, + padding_bottom, padding_left, padding_right) + .reshape(pre_contract_dims) + .contract(kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) + .reshape(kernel_dims) + .eval(), + contract_dims) + .reshape(post_contract_dims)); +} + +/** CuboidConvolutionBackwardKernel + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Computes the backprop for the filter of a 3D convolution. + * + * The output_backward parameter is expected to be a tensor with a rank of 4 or + * more (channels, depth, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_depth, kernel_height, kernel_width) + * output_backward and kernel have to be in the same layout. + * + * The dimensions of the result will be filters, depth, height, width (and + * others if applicable). + * + * It is possible to swap the order of the depth, width and height dimensions + * provided that the same order is used in the input, the kernel, and the + * output. + * + * All dimension orders above are given for col-major, and should be reversed + * for row-major. 
+ */ +template +EIGEN_ALWAYS_INLINE static const std::conditional_t< + internal::traits::Layout == ColMajor, + const TensorReverseOp< + const Eigen::array::Index, + internal::traits::NumDimensions>, + const Eigen::TensorShufflingOp< + const Eigen::array::Index, + internal::traits::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const Eigen::TensorForcedEvalOp::Index, + 2>, + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const OutputBackward>>>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorVolumePatchOp< + Dynamic, Dynamic, Dynamic, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const Input>>>>>>>>, + const TensorReverseOp< + const Eigen::array::Index, + internal::traits::NumDimensions>, + const Eigen::TensorShufflingOp< + const Eigen::array::Index, + internal::traits::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorVolumePatchOp< + Dynamic, Dynamic, Dynamic, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const Input>>>>, + const Eigen::TensorForcedEvalOp::Index, + 2>, + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const OutputBackward>>>>>>>> +CuboidConvolutionBackwardKernel( + const Input& input, const OutputBackward& output_backward, + typename internal::traits::Index kernelPlanes, + typename internal::traits::Index kernelRows, + typename internal::traits::Index kernelCols, + const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1, + const DenseIndex strideCols = 1) { + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex>> + in(input); + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex>> + out(output_backward); + + EIGEN_STATIC_ASSERT(internal::traits::Layout == + internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + static const bool isColMajor = (internal::traits::Layout == ColMajor); + + static const int NumDims = internal::traits::NumDimensions; + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == + internal::traits::NumDimensions, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + // We do not support higher dimensional backward convolutions, or convolutions + // without batch dimension. + // TODO(ezhulenev): Relax this constraint, and turn on tests without batch + // dimension in eigen_backward_cuboid_convolutions_test.cc. + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 5, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + const TensorIndex inputPlanes = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex inputRows = + isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + const TensorIndex inputCols = + isColMajor ? in.dimension(3) : in.dimension(NumDims - 4); + + const TensorIndex outputPlanes = + isColMajor ? 
out.dimension(1) : out.dimension(NumDims - 2); + const TensorIndex outputRows = + isColMajor ? out.dimension(2) : out.dimension(NumDims - 3); + const TensorIndex outputCols = + isColMajor ? out.dimension(3) : out.dimension(NumDims - 4); + + // Number of filters. This is the same as the output depth. + const TensorIndex kernelFilters = + isColMajor ? out.dimension(0) : out.dimension(NumDims - 1); + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? in.dimension(0) : in.dimension(NumDims - 1); + + // Number of batches in the input tensor. + const TensorIndex batch = + isColMajor ? in.dimension(4) : in.dimension(NumDims - 5); + + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + // Compute forward padding from input and output_backward dimensions. + const TensorIndex padPlanes = numext::maxi( + 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes); + const TensorIndex padRows = numext::maxi( + 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows); + const TensorIndex padCols = numext::maxi( + 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols); + + const TensorIndex padding_top_z = padPlanes / 2; + const TensorIndex padding_top = padRows / 2; + const TensorIndex padding_left = padCols / 2; + + // Compute paddings for output_backward before extracting patches. + const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1; + const auto expanded_out_rows = (outputRows - 1) * strideRows + 1; + const auto expanded_out_cols = (outputCols - 1) * strideCols + 1; + const auto padded_out_planes = inputPlanes + kernelPlanes - 1; + const auto padded_out_rows = inputRows + kernelRows - 1; + const auto padded_out_cols = inputCols + kernelCols - 1; + const auto top_pad_planes = kernelPlanes - 1 - padding_top_z; + const auto top_pad_rows = kernelRows - 1 - padding_top; + const auto left_pad_cols = kernelCols - 1 - padding_left; + const auto bottom_pad_planes = + padded_out_planes - expanded_out_planes - top_pad_planes; + const auto bottom_pad_rows = + padded_out_rows - expanded_out_rows - top_pad_rows; + const auto right_pad_cols = + padded_out_cols - expanded_out_cols - left_pad_cols; + + // Reorder output_backward dimensions. + array output_backward_shuffle; + if (isColMajor) { + // From: [out_depth, out_planes, out_rows, out_cols, batch] + // To: [batch, out_planes, out_rows, out_cols, out_depth] + output_backward_shuffle = {4, 1, 2, 3, 0}; + } else { + // From: [batch, out_cols, out_rows, out_planes, out_depth] + // To: [out_depth, out_cols, out_rows, out_planes, batch] + output_backward_shuffle = {4, 1, 2, 3, 0}; + } + + // Reorder input dimensions. + array input_shuffle; + if (isColMajor) { + // From: [in_depth, in_planes, in_rows, in_cols, batch] + // To: [in_depth, batch, in_planes, in_rows, in_cols] + input_shuffle = {0, 4, 1, 2, 3}; + } else { + // From: [batch, in_cols, in_rows, in_planes, in_depth] + // To: [in_cols, in_rows, in_planes, batch, in_depth] + input_shuffle = {1, 2, 3, 0, 4}; + } + + // Input is playing the role of a "kernel" in this convolution. 
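// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the vendored header: the shape bookkeeping
// of this kernel backprop, traced with invented sizes. C/P/R/Q/N are the
// forward input's depth, planes, rows, cols and batch; kP/kR/kC are the kernel
// extents and F is the output depth.
namespace cuboid_backward_kernel_shape_check {
constexpr long C = 3, P = 4, R = 5, Q = 6, N = 2;
constexpr long kP = 2, kR = 3, kC = 3, F = 8;

// The input, acting as the contraction "kernel", collapses to [C, N*P*R*Q]
// (input_dims below); the padded output_backward patches collapse to
// [N*P*R*Q, kP*kR*kC*F] (pre_contract_dims below). Contracting away the shared
// N*P*R*Q axis leaves [C, kP*kR*kC*F], which the reshape and shuffle below
// turn into the filter gradient [F, C, kP, kR, kC] without changing the
// element count:
static_assert(C * (kP * kR * kC * F) == F * C * kP * kR * kC,
              "contraction output matches the filter-gradient size");
}  // namespace cuboid_backward_kernel_shape_check
// ---------------------------------------------------------------------------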
+ DSizes input_dims; + if (isColMajor) { + input_dims[0] = kernelChannels; + input_dims[1] = batch * inputPlanes * inputRows * inputCols; + } else { + input_dims[1] = kernelChannels; + input_dims[0] = inputCols * inputRows * inputPlanes * batch; + } + + // Molds the output of the patch extraction result into a 2D tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols; + pre_contract_dims[1] = + kernelPlanes * kernelRows * kernelCols * kernelFilters; + } else { + pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch; + pre_contract_dims[0] = + kernelFilters * kernelCols * kernelRows * kernelPlanes; + } + + // We will contract along the collapsed dimension that contains the + // batch, inputPlanes, inputRows and inputCols. + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + // Dimensions after contraction. + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelChannels; + post_contract_dims[1] = kernelPlanes; + post_contract_dims[2] = kernelRows; + post_contract_dims[3] = kernelCols; + post_contract_dims[4] = kernelFilters; + } else { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = kernelCols; + post_contract_dims[2] = kernelRows; + post_contract_dims[3] = kernelPlanes; + post_contract_dims[4] = kernelChannels; + } + + // Reorder output of contraction to valid filter shape. + array kernel_shuffle; + if (isColMajor) { + // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth] + // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols] + kernel_shuffle = {4, 0, 1, 2, 3}; + } else { + // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth] + // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth] + kernel_shuffle = {1, 2, 3, 4, 0}; + } + + // Reverse kernel backprop dimensions. + array kernel_reverse; + if (isColMajor) { + kernel_reverse = {false, false, true, true, true}; + } else { + kernel_reverse = {true, true, true, false, false}; + } + + // Create convolution input (aka source of patches) from output backward + // tensor by shuffling dimensions. + const auto the_input = + output_backward.shuffle(output_backward_shuffle).eval(); + + // Create convolution kernel (aka filter) from input by shuffling and + // reshaping. 
+ const auto the_kernel = + input.shuffle(input_shuffle).reshape(input_dims).eval(); + + return choose(Cond::Layout == ColMajor>(), + the_kernel.contract( + the_input + .extract_volume_patches( + inputPlanes, inputRows, inputCols, 1, 1, 1, + stridePlanes, strideRows, strideCols, + top_pad_planes, bottom_pad_planes, top_pad_rows, + bottom_pad_rows, left_pad_cols, right_pad_cols) + .reshape(pre_contract_dims), + contract_dims), + the_input + .extract_volume_patches( + inputPlanes, inputRows, inputCols, 1, 1, 1, + stridePlanes, strideRows, strideCols, top_pad_planes, + bottom_pad_planes, top_pad_rows, bottom_pad_rows, + left_pad_cols, right_pad_cols) + .reshape(pre_contract_dims) + .contract(the_kernel, contract_dims)) + .reshape(post_contract_dims) + .shuffle(kernel_shuffle) + .reverse(kernel_reverse); +} + +} // end namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_CUBOID_CONVOLUTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h new file mode 100644 index 00000000..c21b6fe0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -0,0 +1,593 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" + +namespace Eigen { + +/** SpatialConvolutionBackwardInput + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Computes the backprop for the input of a 2D convolution. + * + * The output_backward parameter is expected to be a tensor with a rank of 3 or + * more (channels, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_height, kernel_width) + * The output_backward and the kernel must both be in col-major layout. The + * result will also be in col-major layout. + * + * If row_in_stride, col_in_stride > 1, then applies convolution with holes + * (aka atrous convolution), sampling every row_in_stride, col_in_stride input + * pixels. + * + * The result can be assigned to a tensor of rank equal to the rank of the + * output_backward. The dimensions of the result will be filters, height, width + * (and others if applicable). + * + * It is possible to swap the order of the width and height dimensions provided + * that the same order is used in the input, the kernel, and the output. 
+ * + */ +typedef IndexList, type2index<0>, type2index<1>, type2index<1>> + ReverseColMajor; +typedef IndexList, type2index<1>, type2index<0>, type2index<0>> + ReverseRowMajor; + +template +EIGEN_ALWAYS_INLINE static const std::conditional_t< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const Eigen::TensorForcedEvalOp::Index, 4>, + const Eigen::TensorForcedEvalOp>>>>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorImagePatchOp>>>, + TensorReshapingOp< + + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorImagePatchOp>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const Eigen::TensorForcedEvalOp::Index, 4>, + const Eigen::TensorForcedEvalOp>>>>>>> +SpatialConvolutionBackwardInput( + const Kernel& kernel, const OutputBackward& output_backward, + typename internal::traits::Index inputRows, + typename internal::traits::Index inputCols, + const DenseIndex row_stride = 1, const DenseIndex col_stride = 1, + const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) { + typedef typename internal::traits::Index TensorIndex; + typedef typename internal::traits::Scalar OutScalar; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex>> + kern(kernel); + TensorRef::NumDimensions, + internal::traits::Layout, TensorIndex>> + out(output_backward); + + EIGEN_STATIC_ASSERT(internal::traits::Layout == + internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + static const bool isColMajor = + (internal::traits::Layout == ColMajor); + + static const int NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? kern.dimensions()[0] : kern.dimensions()[3]; + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; + + // This is the effective kernel size, taking into account the (*_in_stride - + // 1) zero-values + // inserted between consecutive kernel elements in atrous convolution + const TensorIndex kernelRowsEff = + kernelRows + (kernelRows - 1) * (row_in_stride - 1); + const TensorIndex kernelColsEff = + kernelCols + (kernelCols - 1) * (col_in_stride - 1); + + const TensorIndex outputRows = isColMajor + ? output_backward.dimension(1) + : output_backward.dimension(NumDims - 2); + const TensorIndex outputCols = isColMajor + ? 
output_backward.dimension(2) + : output_backward.dimension(NumDims - 3); + + // Computing the forward padding + const TensorIndex forward_pad_top = numext::maxi( + 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2); + const TensorIndex forward_pad_left = numext::maxi( + 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2); + const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top; + const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left; + + const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride - + 2 - padding_top + kernelRowsEff; + const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride - + 2 - padding_left + kernelColsEff; + + eigen_assert(padding_top >= 0); + eigen_assert(padding_left >= 0); + eigen_assert(padding_bottom >= 0); + eigen_assert(padding_right >= 0); + + // The kernel has dimensions filters X channels X patch_rows X patch_cols + // We need to reverse the kernel along dimensions corresponding to rows and + // cols. + // TODO(yangke): we can make things slightly faster by collapsing the + // dimensions + // where we don't reverse. Try that once we have a faster compiler. + typedef std::conditional_t + Reverse; + Reverse kernel_reverse; + // Reorder the dimensions to: + // filters x patch_rows x patch_cols x channels + array kernel_shuffle; + if (isColMajor) { + // From: filters x channels x rows x cols + // To: filters x rows x cols x channels + kernel_shuffle[0] = 0; + kernel_shuffle[1] = 2; + kernel_shuffle[2] = 3; + kernel_shuffle[3] = 1; + } else { + // From: cols x rows x channels x filters + // To: channels x cols x rows x filters + kernel_shuffle[0] = 2; + kernel_shuffle[1] = 0; + kernel_shuffle[2] = 1; + kernel_shuffle[3] = 3; + } + + // Collapse the dims + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters * kernelRows * kernelCols; + kernel_dims[1] = kernelChannels; + } else { + kernel_dims[1] = kernelFilters * kernelRows * kernelCols; + kernel_dims[0] = kernelChannels; + } + + // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS + // When we extract the image patches from output_backward, it will have + // dimensions + // out_depth X (patch_rows * patch_cols) X (input_rows * input_cols * + // OTHERS) + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols; + pre_contract_dims[1] = inputRows * inputCols; + for (int i = 3; i < NumDims; ++i) { + pre_contract_dims[1] *= out.dimension(i); + } + } else { + pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols; + pre_contract_dims[0] = inputRows * inputCols; + for (int i = 0; i < NumDims - 3; ++i) { + pre_contract_dims[0] *= out.dimension(i); + } + } + + // We will contract along the collapsed dimension that contains the + // kernelFilters, the kernelRows and the kernelCols. 
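// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the vendored header: a standalone replay of
// the padding arithmetic computed earlier in this function, for one invented
// shape (rows only; the column case is identical).
#include <algorithm>
#include <cassert>

inline void BackwardInputPaddingExample() {
  const long inputRows = 5, outputRows = 3;   // forward conv: SAME, stride 2
  const long kernelRows = 3, row_stride = 2;  // row_in_stride == 1, so kernelRowsEff == kernelRows

  const long forward_pad_top = std::max<long>(
      0, ((outputRows - 1) * row_stride + kernelRows - inputRows) / 2);  // == 1
  const long padding_top = kernelRows - 1 - forward_pad_top;             // == 1
  const long padding_bottom = inputRows - (outputRows - 1) * row_stride -
                              2 - padding_top + kernelRows;              // == 1
  assert(padding_top >= 0 && padding_bottom >= 0);

  // output_backward is inflated by the forward stride and padded, after which
  // stride-1 patch extraction yields exactly one patch per input row.
  const long expanded_out_rows = (outputRows - 1) * row_stride + 1;           // == 5
  const long padded_rows = expanded_out_rows + padding_top + padding_bottom;  // == 7
  assert(padded_rows - kernelRows + 1 == inputRows);
}
// ---------------------------------------------------------------------------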
+ array, 1> contract_dims; + if (isColMajor) { + // col-major: kernel.contract(output.patches) + contract_dims[0] = IndexPair(0, 0); + } else { + // row-major: output.patches.contract(kernel) + contract_dims[0] = IndexPair(1, 1); + } + + // Post contraction, the dimensions of the input_backprop is + // channels X input_rows X input_cols X OTHERS + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelChannels; + post_contract_dims[1] = inputRows; + post_contract_dims[2] = inputCols; + for (int i = 3; i < NumDims; ++i) { + post_contract_dims[i] = out.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelChannels; + post_contract_dims[NumDims - 2] = inputRows; + post_contract_dims[NumDims - 3] = inputCols; + for (int i = 0; i < NumDims - 3; ++i) { + post_contract_dims[i] = out.dimension(i); + } + } + + // NOTE(ezhulenev): We do eval after reverse and shuffle, because tiled + // evaluation of these ops does not compose. Doing explicit eval is ~8x + // faster in micro benchmarks. + + return choose( + Cond::Layout == ColMajor>(), + kernel.reverse(kernel_reverse) + .eval() + .shuffle(kernel_shuffle) + .eval() + .reshape(kernel_dims) + .contract( + output_backward + .extract_image_patches( + kernelRows, kernelCols, 1, 1, row_in_stride, + col_in_stride, row_stride, col_stride, padding_top, + padding_bottom, padding_left, padding_right, OutScalar(0)) + .reshape(pre_contract_dims), + contract_dims) + .reshape(post_contract_dims), + output_backward + .extract_image_patches(kernelRows, kernelCols, 1, 1, row_in_stride, + col_in_stride, row_stride, col_stride, + padding_top, padding_bottom, padding_left, + padding_right, OutScalar(0)) + .reshape(pre_contract_dims) + .contract(kernel.reverse(kernel_reverse) + .eval() + .shuffle(kernel_shuffle) + .eval() + .reshape(kernel_dims), + contract_dims) + .reshape(post_contract_dims)); +} + +/** SpatialConvolutionBackwardKernel + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Computes the backprop for the filter of a 2D convolution. + * + * The output_backward parameter is expected to be a tensor with a rank of 3 or + * more (channels, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_height, kernel_width) + * The output_backward and the kernel must both be in col-major layout. The + * result will also be in col-major layout. + * + * If row_in_stride, col_stride > 1, then applies convolution with holes (aka + * atrous convolution), sampling every row_in_stride, col_in_stride input + * pixels. + * + * The result can be assigned to a tensor of rank equal to the rank of the + * output_backward. The dimensions of the result will be filters, height, width + * (and others if applicable). + * + * It is possible to swap the order of the width and height dimensions provided + * that the same order is used in the input, the kernel, and the output. 
+ * + */ + +template +EIGEN_ALWAYS_INLINE static const std::conditional_t< + internal::traits::Layout == ColMajor, + const TensorReverseOp< + const Eigen::array::Index, + internal::traits::NumDimensions>, + const Eigen::TensorForcedEvalOp::Index, + internal::traits::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const Input>>>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorImagePatchOp< + Dynamic, Dynamic, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const OutputBackward>>>>>>>>>, + const TensorReverseOp< + const Eigen::array::Index, + internal::traits::NumDimensions>, + const Eigen::TensorForcedEvalOp::Index, + internal::traits::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array< + IndexPair::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const TensorImagePatchOp< + Dynamic, Dynamic, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const OutputBackward>>>>, + const TensorReshapingOp< + const DSizes::Index, + 2>, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits::Index, + internal::traits::NumDimensions>, + const Input>>>>>>>>> +SpatialConvolutionBackwardKernel( + const Input& input, const OutputBackward& output_backward, + typename internal::traits::Index kernelRows, + typename internal::traits::Index kernelCols, + const DenseIndex row_stride = 1, const DenseIndex col_stride = 1, + const DenseIndex row_in_stride = 1, const DenseIndex col_in_stride = 1) { + typedef typename internal::traits::Index TensorIndex; + typedef typename internal::traits::Scalar OutScalar; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex>> + in(input); + TensorRef::NumDimensions, + internal::traits::Layout, TensorIndex>> + out(output_backward); + + EIGEN_STATIC_ASSERT(internal::traits::Layout == + internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + // stride and in_stride cannot both be larger than 1 + eigen_assert(!(row_stride > 1 && row_in_stride > 1)); + eigen_assert(!(col_stride > 1 && col_in_stride > 1)); + + static const bool isColMajor = (internal::traits::Layout == ColMajor); + + static const int NumDims = internal::traits::NumDimensions; + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == + internal::traits::NumDimensions, + YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT(NumDims == 4, YOU_MADE_A_PROGRAMMING_MISTAKE); + + const TensorIndex inputRows = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex inputCols = + isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + + const TensorIndex outputRows = isColMajor + ? output_backward.dimension(1) + : output_backward.dimension(NumDims - 2); + const TensorIndex outputCols = isColMajor + ? 
output_backward.dimension(2) + : output_backward.dimension(NumDims - 3); + + // Number of filters to apply. This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? out.dimensions()[0] : out.dimensions()[NumDims - 1]; + + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? in.dimensions()[0] : in.dimensions()[NumDims - 1]; + + // This is the effective kernel size, taking into account the + // (*_in_stride - 1) zero-values inserted between consecutive kernel + // elements in atrous convolution + const TensorIndex kernelRowsEff = + kernelRows + (kernelRows - 1) * (row_in_stride - 1); + const TensorIndex kernelColsEff = + kernelCols + (kernelCols - 1) * (col_in_stride - 1); + + // Number of batches (and other dimensions) in the input tensor. + TensorIndex batch = 1; + for (int d = 3; d < NumDims; ++d) { + batch *= isColMajor ? in.dimension(d) : in.dimension(NumDims - d - 1); + } + + // Computing the forward padding + const TensorIndex padRows = numext::maxi( + 0, (outputRows - 1) * row_stride + kernelRowsEff - inputRows); + const TensorIndex padCols = numext::maxi( + 0, (outputCols - 1) * col_stride + kernelColsEff - inputCols); + + TensorIndex padding_top = padRows / 2; + TensorIndex padding_left = padCols / 2; + + // Compute paddings for output_backward before extracting patches. + const TensorIndex expanded_out_rows = (outputRows - 1) * row_stride + 1; + const TensorIndex expanded_out_cols = (outputCols - 1) * col_stride + 1; + + const TensorIndex padded_out_rows = inputRows + kernelRowsEff - 1; + const TensorIndex padded_out_cols = inputCols + kernelColsEff - 1; + + const TensorIndex top_pad_rows = kernelRowsEff - 1 - padding_top; + const TensorIndex left_pad_cols = kernelColsEff - 1 - padding_left; + + const TensorIndex bottom_pad_rows = + padded_out_rows - expanded_out_rows - top_pad_rows; + const TensorIndex right_pad_cols = + padded_out_cols - expanded_out_cols - left_pad_cols; + + // Reorder output_backward dimensions. + array output_backward_shuffle; + if (isColMajor) { + // From: [out_depth, out_rows, out_cols, batch] + // To: [batch, out_rows, out_cols, out_depth] + output_backward_shuffle = {3, 1, 2, 0}; + } else { + // From: [batch, out_cols, out_rows, out_depth] + // To: [out_depth, out_cols, out_rows, batch] + output_backward_shuffle = {3, 1, 2, 0}; + } + + // Reorder input dimensions. + array input_shuffle; + if (isColMajor) { + // From: [in_depth, in_rows, in_cols, batch] + // To: [in_depth, batch, in_rows, in_cols] + input_shuffle = {0, 3, 1, 2}; + } else { + // From: [batch, in_cols, in_rows, in_depth] + // To: [in_cols, in_rows, batch, in_depth] + input_shuffle = {1, 2, 0, 3}; + } + + // Input is playing the role of a "kernel" in this convolution. 
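// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the vendored header: a typical col-major
// call of the SpatialConvolutionBackwardKernel entry point defined above, for
// a stride-1 SAME convolution. All shapes are invented for the example.
#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"

inline Eigen::Tensor<float, 4> FilterGradientExample() {
  Eigen::Tensor<float, 4> input(3, 32, 32, 8);         // [in_depth, rows, cols, batch]
  Eigen::Tensor<float, 4> output_grad(16, 32, 32, 8);  // [out_depth, rows, cols, batch]
  input.setRandom();
  output_grad.setRandom();
  // The filter gradient comes back as [out_depth, in_depth, kernel_rows,
  // kernel_cols], i.e. [16, 3, 3, 3] here.
  Eigen::Tensor<float, 4> filter_grad = Eigen::SpatialConvolutionBackwardKernel(
      input, output_grad, /*kernelRows=*/3, /*kernelCols=*/3);
  return filter_grad;
}
// ---------------------------------------------------------------------------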
+ DSizes input_dims; + if (isColMajor) { + input_dims[0] = kernelChannels; + input_dims[1] = batch * inputRows * inputCols; + } else { + input_dims[1] = kernelChannels; + input_dims[0] = inputCols * inputRows * batch; + } + + // Molds the output of the patch extraction result into a 2D tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = batch * inputRows * inputCols; + pre_contract_dims[1] = kernelRows * kernelCols * kernelFilters; + } else { + pre_contract_dims[1] = inputCols * inputRows * batch; + pre_contract_dims[0] = kernelFilters * kernelCols * kernelRows; + } + + // We will contract along the collapsed dimension that contains the + // batch, inputRows and inputCols. + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + // Dimensions after contraction. + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelChannels; + post_contract_dims[1] = kernelRows; + post_contract_dims[2] = kernelCols; + post_contract_dims[3] = kernelFilters; + } else { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = kernelCols; + post_contract_dims[2] = kernelRows; + post_contract_dims[3] = kernelChannels; + } + + // Reorder output of contraction to a valid filter shape. + array kernel_shuffle; + if (isColMajor) { + // From: [in_depth, kernel_rows, kernel_cols, out_depth] + // To: [out_depth, in_depth, kernel_rows, kernel_cols] + kernel_shuffle = {3, 0, 1, 2}; + } else { + // From: [out_depth, kernel_cols, kernel_rows, in_depth] + // To: [kernel_cols, kernel_rows, in_depth, out_depth] + kernel_shuffle = {1, 2, 3, 0}; + } + + // Reverse kernel backprop dimensions. + array kernel_reverse; + if (isColMajor) { + kernel_reverse = {false, false, true, true}; + } else { + kernel_reverse = {true, true, false, false}; + } + + // Create convolution input (aka source of patches) from output backward + // tensor by shuffling dimensions. + const auto output_backward_shuffled = + output_backward.shuffle(output_backward_shuffle).eval(); + + // Create convolution kernel (aka filter) from input by shuffling and + // reshaping. + const auto input_shuffled = + input.shuffle(input_shuffle).eval().reshape(input_dims); + + return choose( + Cond::Layout == ColMajor>(), + input_shuffled.contract( + output_backward_shuffled + .extract_image_patches(inputRows, inputCols, row_in_stride, + col_in_stride, 1, 1, row_stride, + col_stride, top_pad_rows, + bottom_pad_rows, left_pad_cols, + right_pad_cols, OutScalar(0)) + .reshape(pre_contract_dims), + contract_dims), + output_backward_shuffled + .extract_image_patches( + inputRows, inputCols, row_in_stride, col_in_stride, 1, 1, + row_stride, col_stride, top_pad_rows, bottom_pad_rows, + left_pad_cols, right_pad_cols, OutScalar(0)) + .reshape(pre_contract_dims) + .contract(input_shuffled, contract_dims)) + .reshape(post_contract_dims) + .shuffle(kernel_shuffle) + .eval() + .reverse(kernel_reverse); +} + +} // end namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_benchmark.h b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_benchmark.h new file mode 100644 index 00000000..e69a5976 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_benchmark.h @@ -0,0 +1,295 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h" +#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h" +#include "tensorflow/core/kernels/eigen_cuboid_convolution.h" +#include "tensorflow/core/platform/test_benchmark.h" + +using ::tensorflow::TTypes; + +template +class SpatialConvolutionBenchmarksSuite { + public: + using Input = TTypes::ConstTensor; + using Filter = TTypes::ConstTensor; + using Output = TTypes::Tensor; + + using Dimensions = Eigen::DSizes; + + SpatialConvolutionBenchmarksSuite(::testing::benchmark::State& state, + Device& device) + : state_(state), device_(device) {} + + Eigen::Index BufferSize(const Dimensions& dims) { + return dims.TotalSize() * sizeof(Scalar); + } + + void SpatialConvolution(Dimensions input_dims, Dimensions filter_dims) { + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + filter_dims[3]); // filter_count + + Scalar* input_data = + static_cast(device_.allocate(BufferSize(input_dims))); + Scalar* filter_data = + static_cast(device_.allocate(BufferSize(filter_dims))); + Scalar* output_data = + static_cast(device_.allocate(BufferSize(output_dims))); + + device_.memset(input_data, 123, BufferSize(input_dims)); + device_.memset(filter_data, 123, BufferSize(filter_dims)); + + Input input(input_data, input_dims); + Filter filter(filter_data, filter_dims); + Output output(output_data, output_dims); + + for (auto s : state_) { + output.device(device_) = Eigen::SpatialConvolution(input, filter); + tensorflow::testing::DoNotOptimize(output); + } + + device_.deallocate(input_data); + device_.deallocate(filter_data); + device_.deallocate(output_data); + } + + void SpatialConvolutionBackwardInput(Dimensions input_dims, + Dimensions filter_dims) { + using OutputBackward = TTypes::ConstTensor; + using InputBackward = TTypes::Tensor; + + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + filter_dims[3]); // filter_count + + // Assuming that the convolution had SAME padding. 
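// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the vendored header: roughly how one of the
// suite's methods is driven from a benchmark body. The registration
// boilerplate (a BENCHMARK-style macro from the framework pulled in via
// test_benchmark.h) is assumed and omitted here; shapes are invented.
#include "tensorflow/core/kernels/eigen_benchmark.h"

void BM_SpatialConvolution(::testing::benchmark::State& state) {
  Eigen::DefaultDevice device;  // single-threaded CPU device
  SpatialConvolutionBenchmarksSuite<float, Eigen::DefaultDevice> suite(state,
                                                                       device);
  // input:  [batch, height, width, depth]
  // filter: [rows, cols, in_depth, filter_count]
  suite.SpatialConvolution(Eigen::DSizes<Eigen::Index, 4>(32, 64, 64, 16),
                           Eigen::DSizes<Eigen::Index, 4>(3, 3, 16, 32));
}
// ---------------------------------------------------------------------------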
+ Eigen::Index input_rows = input_dims[1]; + Eigen::Index input_cols = input_dims[2]; + + Scalar* filter_data = + static_cast(device_.allocate(BufferSize(filter_dims))); + Scalar* output_backward_data = + static_cast(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast(device_.allocate(BufferSize(input_dims))); + + device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); + + Filter filter(filter_data, filter_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); + + for (auto s : state_) { + input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput( + filter, output_backward, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); + } + + device_.deallocate(filter_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); + } + + void SpatialConvolutionBackwardKernel(Dimensions input_dims, + Dimensions filter_dims) { + using OutputBackward = TTypes::ConstTensor; + using FilterBackward = TTypes::Tensor; + + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + filter_dims[3]); // filter_count + + // Assuming that the convolution had SAME padding. + Eigen::Index filter_rows = filter_dims[0]; + Eigen::Index filter_cols = filter_dims[1]; + + Scalar* input_data = + static_cast(device_.allocate(BufferSize(input_dims))); + Scalar* output_backward_data = + static_cast(device_.allocate(BufferSize(output_dims))); + Scalar* filter_backward_data = + static_cast(device_.allocate(BufferSize(filter_dims))); + + device_.memset(input_data, 123, BufferSize(input_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); + + Input input(input_data, input_dims); + OutputBackward output_backward(output_backward_data, input_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); + + for (auto s : state_) { + filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel( + input, output_backward, filter_rows, filter_cols); + tensorflow::testing::DoNotOptimize(filter_backward); + } + + device_.deallocate(input_data); + device_.deallocate(output_backward_data); + device_.deallocate(filter_backward_data); + } + + private: + ::testing::benchmark::State& state_; + + Device& device_; +}; + +template +class CuboidConvolutionBenchmarksSuite { + public: + using Input = TTypes::ConstTensor; + using Filter = TTypes::ConstTensor; + using Output = TTypes::Tensor; + + using Dimensions = Eigen::DSizes; + + CuboidConvolutionBenchmarksSuite(::testing::benchmark::State& state, + Device& device) + : state_(state), device_(device) {} + + Eigen::Index BufferSize(const Dimensions& dims) { + return dims.TotalSize() * sizeof(Scalar); + } + + void CuboidConvolution(Dimensions input_dims, Dimensions filter_dims) { + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + input_dims[3], // input_planes + filter_dims[4]); // filter_count + + Scalar* input_data = + static_cast(device_.allocate(BufferSize(input_dims))); + Scalar* filter_data = + static_cast(device_.allocate(BufferSize(filter_dims))); + Scalar* output_data = + static_cast(device_.allocate(BufferSize(output_dims))); + + device_.memset(input_data, 123, BufferSize(input_dims)); + device_.memset(filter_data, 123, BufferSize(filter_dims)); + + Input 
input(input_data, input_dims); + Filter filter(filter_data, filter_dims); + Output output(output_data, output_dims); + + for (auto s : state_) { + output.device(device_) = Eigen::CuboidConvolution(input, filter); + tensorflow::testing::DoNotOptimize(output); + } + + device_.deallocate(input_data); + device_.deallocate(filter_data); + device_.deallocate(output_data); + } + + void CuboidConvolutionBackwardInput(Dimensions input_dims, + Dimensions filter_dims) { + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + input_dims[3], // input_planes + filter_dims[4]); // filter_count + + using OutputBackward = TTypes::ConstTensor; + using InputBackward = TTypes::Tensor; + + // Assuming that the convolution had SAME padding. + Eigen::Index input_rows = input_dims[1]; + Eigen::Index input_cols = input_dims[2]; + Eigen::Index input_planes = input_dims[3]; + + Scalar* filter_data = + static_cast(device_.allocate(BufferSize(filter_dims))); + Scalar* output_backward_data = + static_cast(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast(device_.allocate(BufferSize(input_dims))); + + device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); + + Filter filter(filter_data, filter_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); + + for (auto s : state_) { + input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, input_planes, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); + } + + device_.deallocate(filter_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); + } + + void CuboidConvolutionBackwardKernel(Dimensions input_dims, + Dimensions filter_dims) { + using OutputBackward = TTypes::ConstTensor; + using FilterBackward = TTypes::Tensor; + + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + input_dims[3], // input_planes + filter_dims[4]); // filter_count + + // Assuming that the convolution had SAME padding. 
+ Eigen::Index filter_rows = filter_dims[0]; + Eigen::Index filter_cols = filter_dims[1]; + Eigen::Index filter_planes = filter_dims[2]; + + Scalar* input_data = + static_cast(device_.allocate(BufferSize(input_dims))); + Scalar* output_backward_data = + static_cast(device_.allocate(BufferSize(output_dims))); + Scalar* filter_backward_data = + static_cast(device_.allocate(BufferSize(filter_dims))); + + device_.memset(input_data, 123, BufferSize(input_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); + + Input input(input_data, input_dims); + OutputBackward output_backward(output_backward_data, output_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); + + for (auto s : state_) { + filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel( + input, output_backward, filter_planes, filter_rows, filter_cols); + tensorflow::testing::DoNotOptimize(filter_backward); + } + + device_.deallocate(input_data); + device_.deallocate(output_backward_data); + device_.deallocate(filter_backward_data); + } + + private: + ::testing::benchmark::State& state_; + Device& device_; +}; + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_cuboid_convolution.h b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_cuboid_convolution.h new file mode 100644 index 00000000..156c557e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_cuboid_convolution.h @@ -0,0 +1,1995 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +#include "xla/tsl/framework/convolution/eigen_convolution_helpers.h" + +namespace Eigen { + +namespace internal { + +#if !EIGEN_ALTIVEC_USE_CUSTOM_PACK +// WARNING: Most of the code here implicitly assumes that the matrix is in +// ColMajor layout. This is guaranteed by the tensor contraction (see +// TensorContraction.h). +// +// Inside Eigen a tensor contraction is represented by a matrix multiplication. +// We don't want to actually extract volume patches and reshape the result into +// a matrix (this involves allocating huge extra memory), so the patch +// extraction and reshape operations are implicit. +// +// TensorContractionInputMapper takes a matrix index and returns the coefficient +// (or the packet) of the "virtual tensor", that would be at that index if we +// were to actually reshape the result of patch extraction. +// +// TensorContractionSubMapper provides a similar view into the "virtual matrix" +// at the given vertical and horizontal offsets. 
+// +// "Virtual matrix" dimensions: +// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols +// 1: out_planes * out_height * out_width * OTHERS (e.g batches, etc...) +// +// *) extracted patches are continuous in memory (innermost dimension assuming +// col major layout) +// +// With this dimensions: +// row - offset within a single patch (in code: patchId) +// col - index of the extracted patch (in code: patchIndex) +// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) +// +template +class TensorContractionInputMapper< + Scalar_, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef Scalar_ Scalar; + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper VectorMapper; + typedef SubMapper LinearMapper; + typedef typename packet_traits::type Packet; + + EIGEN_DEVICE_FUNC + TensorContractionInputMapper( + const TensorEvaluator< + const TensorReshapingOp< + NewDimension, + const TensorVolumePatchOp >, + Device>& tensor, + const nocontract_t&, const nocontract_t&, const contract_t&, + const contract_t&) + : m_impl(tensor.impl().impl()) { + if (internal::traits::Layout == ColMajor) { + m_patch_depth = tensor.impl().dimensions()[0]; + m_patch_planes = tensor.impl().dimensions()[1]; + m_patch_rows = tensor.impl().dimensions()[2]; + m_patch_cols = tensor.impl().dimensions()[3]; + m_num_patches = tensor.impl().dimensions()[4]; + } else { + const int NumDims = tensor.impl().dimensions().size(); + m_patch_depth = tensor.impl().dimensions()[NumDims - 1]; + m_patch_planes = tensor.impl().dimensions()[NumDims - 2]; + m_patch_rows = tensor.impl().dimensions()[NumDims - 3]; + m_patch_cols = tensor.impl().dimensions()[NumDims - 4]; + m_num_patches = tensor.impl().dimensions()[NumDims - 5]; + } + + // Strides for navigating through the single patch. + m_patch_plane_stride = m_patch_depth; + m_patch_row_stride = m_patch_planes * m_patch_plane_stride; + m_patch_col_stride = m_patch_rows * m_patch_row_stride; + + // Strides for the output tensor. + // IMPORTANT: These strides are used to locate an element in a patch at a + // depth zero (channel), which is not quite the same as "traditional" + // stride. 
+ m_rowStride = m_patch_planes; + m_colStride = m_patch_rows * m_rowStride; + m_patchStride = m_colStride * m_patch_cols * m_patch_depth; + m_otherStride = m_patchStride * m_num_patches; + + m_outputPlanes = tensor.impl().outputPlanes(); + m_outputRows = tensor.impl().outputRows(); + m_outputCols = tensor.impl().outputCols(); + + m_outputPlanesRows = m_outputPlanes * m_outputRows; + + m_plane_strides = tensor.impl().userPlaneStride(); + m_row_strides = tensor.impl().userRowStride(); + m_col_strides = tensor.impl().userColStride(); + + m_in_plane_strides = tensor.impl().userInPlaneStride(); + m_in_row_strides = tensor.impl().userInRowStride(); + m_in_col_strides = tensor.impl().userInColStride(); + + m_patch_plane_inflate_strides = tensor.impl().planeInflateStride(); + m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); + m_patch_col_inflate_strides = tensor.impl().colInflateStride(); + + if (internal::traits::Layout == ColMajor) { + m_inputDepth = tensor.impl().impl().dimensions()[0]; + m_inputPlanes = tensor.impl().impl().dimensions()[1]; + m_inputRows = tensor.impl().impl().dimensions()[2]; + m_inputCols = tensor.impl().impl().dimensions()[3]; + } else { + const int NumDims = tensor.impl().impl().dimensions().size(); + m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1]; + m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2]; + m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3]; + m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4]; + } + + // Strides for navigating through the input tensor. + m_planeInputStride = m_inputDepth; + m_rowInputStride = m_inputDepth * m_inputPlanes; + m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes; + m_patchInputStride = + m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes; + + m_planePaddingTop = tensor.impl().planePaddingTop(); + m_rowPaddingTop = tensor.impl().rowPaddingTop(); + m_colPaddingLeft = tensor.impl().colPaddingLeft(); + + m_fastNumPatches = internal::TensorIntDivisor(m_num_patches); + + m_fastPatchPlaneStride = + internal::TensorIntDivisor(m_patch_plane_stride); + m_fastPatchRowStride = + internal::TensorIntDivisor(m_patch_row_stride); + m_fastPatchColStride = + internal::TensorIntDivisor(m_patch_col_stride); + + m_fastInputPlaneStride = + internal::TensorIntDivisor(m_patch_plane_inflate_strides); + m_fastInputRowStride = + internal::TensorIntDivisor(m_patch_row_inflate_strides); + m_fastInputColStride = + internal::TensorIntDivisor(m_patch_col_inflate_strides); + + m_fastRowStride = internal::TensorIntDivisor(m_rowStride); + m_fastColStride = internal::TensorIntDivisor(m_colStride); + + m_fastDimZero = internal::TensorIntDivisor(m_patch_depth); + m_fastOutputRows = internal::TensorIntDivisor(m_outputRows); + m_fastOutputPlanes = internal::TensorIntDivisor(m_outputPlanes); + m_fastOutputRows = internal::TensorIntDivisor(m_outputRows); + m_fastOutputCols = internal::TensorIntDivisor(m_outputCols); + + m_fastOutputPlanesRows = + internal::TensorIntDivisor(m_outputPlanesRows); + } + + EIGEN_DEVICE_FUNC + TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) + : m_impl(base_mapper.m_impl) { + m_patch_depth = base_mapper.m_patch_depth; + m_patch_planes = base_mapper.m_patch_planes; + m_patch_rows = base_mapper.m_patch_rows; + m_patch_cols = base_mapper.m_patch_cols; + m_num_patches = base_mapper.m_num_patches; + + m_patch_plane_stride = base_mapper.m_patch_plane_stride; + m_patch_row_stride = base_mapper.m_patch_row_stride; + m_patch_col_stride = 
base_mapper.m_patch_col_stride; + + m_rowStride = base_mapper.m_rowStride; + m_colStride = base_mapper.m_colStride; + m_patchStride = base_mapper.m_patchStride; + m_otherStride = base_mapper.m_otherStride; + + m_planeInputStride = base_mapper.m_planeInputStride; + m_rowInputStride = base_mapper.m_rowInputStride; + m_colInputStride = base_mapper.m_colInputStride; + m_patchInputStride = base_mapper.m_patchInputStride; + m_otherInputStride = base_mapper.m_otherInputStride; + + m_inputDepth = base_mapper.m_inputDepth; + m_inputPlanes = base_mapper.m_inputPlanes; + m_inputRows = base_mapper.m_inputRows; + m_inputCols = base_mapper.m_inputCols; + + m_outputPlanes = base_mapper.m_outputPlanes; + m_outputRows = base_mapper.m_outputRows; + m_outputCols = base_mapper.m_outputCols; + + m_plane_strides = base_mapper.m_plane_strides; + m_row_strides = base_mapper.m_row_strides; + m_col_strides = base_mapper.m_col_strides; + + m_in_plane_strides = base_mapper.m_in_plane_strides; + m_in_row_strides = base_mapper.m_in_row_strides; + m_in_col_strides = base_mapper.m_in_col_strides; + + m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides; + m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; + m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; + + m_planePaddingTop = base_mapper.m_planePaddingTop; + m_rowPaddingTop = base_mapper.m_rowPaddingTop; + m_colPaddingLeft = base_mapper.m_colPaddingLeft; + + m_outputPlanesRows = base_mapper.m_outputPlanesRows; + + m_fastNumPatches = base_mapper.m_fastNumPatches; + m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride; + m_fastPatchRowStride = base_mapper.m_fastPatchRowStride; + m_fastPatchColStride = base_mapper.m_fastPatchColStride; + m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride; + m_fastInputRowStride = base_mapper.m_fastInputRowStride; + m_fastInputColStride = base_mapper.m_fastInputColStride; + m_fastRowStride = base_mapper.m_fastRowStride; + m_fastColStride = base_mapper.m_fastColStride; + m_fastOutputPlanes = base_mapper.m_fastOutputPlanes; + m_fastOutputRows = base_mapper.m_fastOutputRows; + m_fastOutputCols = base_mapper.m_fastOutputCols; + m_fastDimZero = base_mapper.m_fastDimZero; + m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows; + } + + // If true, turns off some optimizations for loading packets since the image + // patches are "non-standard" such as there are non-trivial strides or + // inflations in the input. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_in_plane_strides != 1 || m_in_row_strides != 1 || + m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 || + m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { + return SubMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Load the coefficient at the patchIndex location instead of the usual + // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the + // gpu code. 
+ EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Load the packet at the patchIndex location instead of the usual m_rowIndex, + // m_colIndex, m_otherIndex. This is currently only used by the gpu code. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE const TensorEvaluator& impl() const { + return m_impl; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } + + private: + friend class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>; + + // Load coefficient from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset * m_in_col_strides; + const Index origInputCol = + (m_patch_col_inflate_strides == 1) + ? inputCol + : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); + + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset * m_in_row_strides; + const Index origInputRow = + (m_patch_row_inflate_strides == 1) + ? inputRow + : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); + + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides; + const Index origInputPlane = + (m_patch_plane_inflate_strides == 1) + ? inputPlane + : ((inputPlane >= 0) ? 
(inputPlane / m_fastInputPlaneStride) : 0); + + if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 || + origInputCol >= m_inputCols || origInputRow >= m_inputRows || + origInputPlane >= m_inputPlanes || + (inputCol != origInputCol * m_patch_col_inflate_strides) || + (inputRow != origInputRow * m_patch_row_inflate_strides) || + (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + origInputPlane * m_planeInputStride + + origInputRow * m_rowInputStride + + origInputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // This is the same as loadCoeff(...), but optimized for all `inflate_strides` + // and `in_strides` equal to 1 (template specialization without templates). + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + eigen_assert(!nonStandardPatches()); + + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + + const Index inputCol = colIndex + colOffset; + const Index inputRow = rowIndex + rowOffset; + const Index inputPlane = planeIndex + planeOffset; + + if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || + inputRow >= m_inputRows || inputPlane < 0 || + inputPlane >= m_inputPlanes) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputPlane * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // Load packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + + if (nonStandardPatches()) { + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + typedef decltype(m_impl) TensorEvaluatorT; + return loadPacketStandard( + patchId, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Helper function to load a 'partial' packet - this is the single row part of + // a packet that is split across two rows (but single column). In the + // 'partial' packet, the elements corresponding to the row (specified through + // rowOffset) are loaded and the rest of the elements are zero-filled into the + // 'partial' packet. This function is called from + // loadPacketStandardFromSingleColumnTwoRows(). This code path is exercised + // only when the packet type supports masked load and when the partial packet + // load is available in the TensorEvaluator. 
+ EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard( + Index planeIndex, Index rowIndex, Index colIndex, Index otherIndex, + Index patchId, const Index span[], const Index patchOffsets[], + Index colOffset, Index rowOffset) const { + const Index inputCol = colIndex + colOffset; + const Index inputRow = rowIndex + rowOffset; + const Index planeOffsets[2] = { + patchOffsets[0] - colOffset * m_colStride - rowOffset * m_rowStride, + patchOffsets[1] - colOffset * m_colStride - rowOffset * m_rowStride}; + const Index inputPlanes[2] = {planeIndex + planeOffsets[0], + planeIndex + planeOffsets[1]}; + + if (inputRow >= m_inputRows || inputRow < 0 || inputCol >= m_inputCols || + inputCol < 0 || inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) { + // Partial packet is all zeros + return internal::pset1(Scalar(0)); + } else if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) { + // From inputIndex-span[0], we need to load elements starting from index + // span[0] all the way upto (and including) span[1]. + const Index depth = patchId - patchOffsets[0] * patchDepth(); + const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + return m_impl.template partialPacket( + inputIndex - span[0], mask(span[0], span[1] + 1)); + } else { + // Using slow path for this partial packet. + // We need to load elements starting from index span[0] all the way upto + // (and including) span[1]. We split this load into 3 parts: + // 0 : span[0]-1 - Zeros will be loaded for these indices + // span[0] : span[1] - Elements will be loaded here for these indices + // span[1]+1 : packetSize-1 - Zeross will be loaded for these indices + const Index packetSize = internal::unpacket_traits::size; + EIGEN_ALIGN_MAX + std::remove_const_t values[packetSize]; + for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0); + for (int i = span[0]; i < span[1] + 1; ++i) + values[i] = loadCoeff(patchId - span[0] + i, planeIndex, rowIndex, + colIndex, otherIndex); + for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0); + return internal::pload(values); + } + } + + // Helper function to load a packet that is split across two rows (but single + // column). If required, this function is called from loadPacketStandard() + // when the packet type supports masked load and when the partial packet load + // is available in the TensorEvaluator. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnTwoRows( + Index patchId, Index planeIndex, Index rowIndex, Index colIndex, + Index otherIndex, const Index patchOffsets[], const Index colOffsets[], + const Index rowOffsets[]) const { + eigen_assert(colOffsets[1] == colOffsets[0] && + rowOffsets[1] == rowOffsets[0] + 1); + const Index packetSize = internal::unpacket_traits::size; + + // Packet to load will be split into 2 parts where each part spans a single + // row and both the parts span the same column. + // First determine where to split. 
+ const Index patchIdSplit = + (((rowOffsets[1] * m_rowStride) + (colOffsets[0] * m_colStride)) * + m_patch_depth) - + 1; + const Index patchOffsetSplit = patchIdSplit / m_fastDimZero; + + // patchIds[i]: patchId corresponding to partial packet i + // spans[i]: Start and end indices corresponding to the elements + // to be loaded for partial packet i + // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i + const Index patchIds[2] = {patchId, patchIdSplit + 1}; + const Index spans[2][2] = {{0, patchIdSplit - patchId}, + {patchIdSplit - patchId + 1, packetSize - 1}}; + const Index patchOffsets2Cols[2][2] = { + {patchOffsets[0], patchOffsetSplit}, + {patchOffsetSplit + 1, patchOffsets[1]}}; + + // Load partial packets and do bit-wise OR to generate required packet + return internal::por( + loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex, + patchIds[0], spans[0], patchOffsets2Cols[0], + colOffsets[0], rowOffsets[0]), + loadPartialPacketStandard(planeIndex, rowIndex, colIndex, otherIndex, + patchIds[1], spans[1], patchOffsets2Cols[1], + colOffsets[1], rowOffsets[1])); + } + + // Helper function to load a packet that is present in a single column and + // row. If required, this function is called from loadPacketStandard(). + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumnSingleRow( + Index patchId, Index planeIndex, Index rowIndex, Index colIndex, + Index otherIndex, const Index patchOffsets[], const Index colOffsets[], + const Index rowOffsets[], const Index inputCols[], + const Index inputRows[]) const { + eigen_assert(colOffsets[1] == colOffsets[0] && + rowOffsets[1] == rowOffsets[0]); + const Index planeOffsets[2] = { + patchOffsets[0] - colOffsets[0] * m_colStride - + rowOffsets[0] * m_rowStride, + patchOffsets[1] - colOffsets[1] * m_colStride - + rowOffsets[1] * m_rowStride}; + eigen_assert(planeOffsets[0] <= planeOffsets[1]); + const Index inputPlanes[2] = {planeIndex + planeOffsets[0], + planeIndex + planeOffsets[1]}; + + if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) { + return internal::pset1(Scalar(0)); + } + if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) { + const Index depth = patchId - patchOffsets[0] * patchDepth(); + const Index inputIndex = depth + inputPlanes[0] * m_planeInputStride + + inputRows[0] * m_rowInputStride + + inputCols[0] * m_colInputStride + otherIndex; + return m_impl.template packet(inputIndex); + } + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + + // Load standard packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + // This function will be called if partial packet loading is not available + // for the TensorEvaluator or if the packet type does not support masked + // load. 
+ template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< + !TensorEvaluatorHasPartialPacket::value, + PacketT>::type + loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex, + Index colIndex, Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + eigen_assert(!nonStandardPatches()); + + if ((patchDepth() % packetSize) == 0) { + return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } else { + // Offsets and input calculation here are identical to + // loadCoeffStandard(...), but repeated twice. + + const Index patchOffsets[2] = { + patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; + + const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, + patchOffsets[1] / m_fastColStride}; + eigen_assert(colOffsets[0] <= colOffsets[1]); + + const Index inputCols[2] = {colIndex + colOffsets[0], + colIndex + colOffsets[1]}; + if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { + return internal::pset1(Scalar(0)); + } + + if (inputCols[0] == inputCols[1]) { + const Index rowOffsets[2] = { + (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, + (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; + eigen_assert(rowOffsets[0] <= rowOffsets[1]); + const Index inputRows[2] = {rowIndex + rowOffsets[0], + rowIndex + rowOffsets[1]}; + + if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { + return internal::pset1(Scalar(0)); + } + + if (inputRows[0] == inputRows[1]) { + return loadPacketStandardFromSingleColumnSingleRow( + patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, + colOffsets, rowOffsets, inputCols, inputRows); + } + } + } + + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + + // Load standard packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + // This function will be called if partial packet loading is available for + // the TensorEvaluator and if the packet type supports masked load. + // The only difference between this and the other case is that if the packet + // to load is split across two rows (but in same column), then in this case + // instead of going to the slow (element-by-element) load, we load two packets + // - each containing elements from one of the rows (rest of the elements of + // the packets are zeroes), and then combine these two packets to generate the + // required packet. The idea is to enable fast load (if possible) of these + // 'partial' packets. + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename std::enable_if< + TensorEvaluatorHasPartialPacket::value, + PacketT>::type + loadPacketStandard(Index patchId, Index planeIndex, Index rowIndex, + Index colIndex, Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + eigen_assert(!nonStandardPatches()); + + if ((patchDepth() % packetSize) == 0) { + return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } else { + // Offsets and input calculation here are identical to + // loadCoeffStandard(...), but repeated twice. 
+ + const Index patchOffsets[2] = { + patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; + + const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, + patchOffsets[1] / m_fastColStride}; + eigen_assert(colOffsets[0] <= colOffsets[1]); + + const Index inputCols[2] = {colIndex + colOffsets[0], + colIndex + colOffsets[1]}; + if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { + return internal::pset1(Scalar(0)); + } + + if (inputCols[0] == inputCols[1]) { + const Index rowOffsets[2] = { + (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, + (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; + eigen_assert(rowOffsets[0] <= rowOffsets[1]); + const Index inputRows[2] = {rowIndex + rowOffsets[0], + rowIndex + rowOffsets[1]}; + + if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { + return internal::pset1(Scalar(0)); + } + + if (inputRows[0] == inputRows[1]) { + return loadPacketStandardFromSingleColumnSingleRow( + patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, + colOffsets, rowOffsets, inputCols, inputRows); + } + if (inputRows[0] + 1 == inputRows[1]) { + return loadPacketStandardFromSingleColumnTwoRows( + patchId, planeIndex, rowIndex, colIndex, otherIndex, patchOffsets, + colOffsets, rowOffsets); + } + } + } + + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + + eigen_assert(!nonStandardPatches()); + eigen_assert((patchDepth() % packetSize) == 0); + + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); + + const Index colOffset = patchOffset / m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + + const Index inputCol = colIndex + colOffset; + const Index inputRow = rowIndex + rowOffset; + const Index inputPlane = planeIndex + planeOffset; + + if (inputCol < 0 || inputRow < 0 || inputPlane < 0 || + inputCol >= m_inputCols || inputRow >= m_inputRows || + inputPlane >= m_inputPlanes) { + return internal::pset1(Scalar(0)); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputPlane * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + return m_impl.template packet(inputIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex, + Index colIndex, Index otherIndex) const { + const int packetSize = internal::unpacket_traits::size; + EIGEN_ALIGN_MAX + std::remove_const_t values[packetSize]; + for (int i = 0; i < packetSize; ++i) { + values[i] = + loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex); + } + Packet rslt = internal::pload(values); + return rslt; + } + + // Precompute the indices (plane, row, col, other) of the first element of + // the given patch index, within the output tensor of the TensorVolumePatchOp. 
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( + Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex, + Index& otherIndex) const { + const size_t NumInputDims = array_size< + typename TensorEvaluator::Dimensions>::value; + + // Check if patchIndex might contain batch and other dimensions. + otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches; + + // Compute index of the patch within the batch (and other dimensions). + const Index patch3DIndex = (NumInputDims == 4) + ? patchIndex + : (patchIndex - otherIndex * m_num_patches); + + otherIndex *= m_patchInputStride; + + colIndex = patch3DIndex / m_fastOutputPlanesRows; + rowIndex = + (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes; + planeIndex = + patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes; + + colIndex = colIndex * m_col_strides - m_colPaddingLeft; + rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; + planeIndex = planeIndex * m_plane_strides - m_planePaddingTop; + } + + Index m_patch_depth; // number of channels in the patch + Index m_patch_planes; // number of planes in the patch + Index m_patch_rows; // number of rows in the patch + Index m_patch_cols; // number of columns in the patch + Index m_num_patches; // number of patches to extract + + // Strides for navigating through the single patch. + Index m_patch_plane_stride; + Index m_patch_row_stride; + Index m_patch_col_stride; + + // Strides for the output tensor (depth is not the part of the stride). + Index m_rowStride; + Index m_colStride; + Index m_patchStride; + Index m_otherStride; + + Index m_planeInputStride; // Plane stride in the input tensor + Index m_rowInputStride; // Row stride in the input tensor + Index m_colInputStride; // Col stride in the input tensor + Index m_patchInputStride; // Patch stride in the input tensor + Index m_otherInputStride; + + Index m_inputDepth; // Depth of the input tensor + Index m_inputPlanes; // Number of planes in the input tensor + Index m_inputRows; // Number of rows in the input tensor + Index m_inputCols; // Number of cols in the input tensor + + Index m_outputPlanes; // Number of output planes + Index m_outputRows; // Number of output rows + Index m_outputCols; // Number of output cols + Index m_outputPlanesRows; // Cached outputPlanes * outputRows. + + Index m_plane_strides; // User specified plane stride + Index m_row_strides; // User specified row stride + Index m_col_strides; // User specified col stride + + // User specified plane/row/col atrous convolution strides. + Index m_in_plane_strides; + Index m_in_row_strides; + Index m_in_col_strides; + + // User specified plane/row/col inflation strides in the image patch. + Index m_patch_plane_inflate_strides; + Index m_patch_row_inflate_strides; + Index m_patch_col_inflate_strides; + + Index m_planePaddingTop; // Plane padding + Index m_rowPaddingTop; // Row padding + Index m_colPaddingLeft; // Column padding + + // Fast representation of various divisors. 
+ internal::TensorIntDivisor m_fastNumPatches; + + internal::TensorIntDivisor m_fastPatchPlaneStride; + internal::TensorIntDivisor m_fastPatchRowStride; + internal::TensorIntDivisor m_fastPatchColStride; + + internal::TensorIntDivisor m_fastInputPlaneStride; + internal::TensorIntDivisor m_fastInputRowStride; + internal::TensorIntDivisor m_fastInputColStride; + + internal::TensorIntDivisor m_fastRowStride; + internal::TensorIntDivisor m_fastColStride; + + internal::TensorIntDivisor m_fastDimZero; // aka output depth + internal::TensorIntDivisor m_fastOutputPlanes; + internal::TensorIntDivisor m_fastOutputRows; + internal::TensorIntDivisor m_fastOutputCols; + internal::TensorIntDivisor m_fastOutputPlanesRows; + + const TensorEvaluator m_impl; +}; + +template +class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef typename packet_traits::type Packet; + typedef typename packet_traits::half HalfPacket; + + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + ParentMapper; + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + typedef Self LinearMapper; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper), + m_depth_offset(vert_offset), + m_col_offset(horiz_offset) { + m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const Self& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper.m_base_mapper), + m_depth_offset(vert_offset + base_mapper.m_depth_offset), + m_col_offset(horiz_offset + base_mapper.m_col_offset) { + m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { + return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, + Index j) const { + return m_base_mapper(i + m_depth_offset, j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, + Index j) const { + return m_base_mapper.template loadPacket(i + m_depth_offset, + j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar + loadCoeffStandard(Index i) const { + return m_base_mapper.loadCoeffStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { + return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + loadPacketStandard(Index i) const { + typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT; + return m_base_mapper.template 
loadPacketStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + template + EIGEN_DEVICE_FUNC bool aligned(Index) const { + return false; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_base_mapper.nonStandardPatches(); + } + + // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row, + // plane and depth index respectively that fits into the peeled_k elements + // starting at m_depth_offset. + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { + const Index max_col = + fastPatchColStride().divide(m_depth_offset + peeled_k); + return std::min(1 + max_col, patchCols()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, + const Index col) const { + const Index max_row = fastPatchRowStride().divide( + m_depth_offset + peeled_k - col * patchColStride()); + return std::min(1 + max_row, patchRows()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col, + const Index row) const { + const Index max_plane = fastPatchPlaneStride().divide( + m_depth_offset + peeled_k - col * patchColStride() - + row * patchRowStride()); + return std::min(1 + max_plane, patchPlanes()); + } + + // MaxDepth uses only the remaining number of elements in the peeled_k. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, + const Index start_depth) const { + return std::min(start_depth + num_elements, patchDepth()); + } + + // Every register matters in this code, so sometimes to prevent register + // spilling, instead of the variable that you would expect to see, we use + // another one, that is guaranteed to have the same value. E.g. patch depth is + // always the same as input depth, and it's also the same as input plane + // stride. Bunch of other parameters have similar relations. 
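+  //
+  // A clarifying note on the offset helpers further below: the
+  // planeOffset()/rowOffset()/colOffset()/depthOffset() accessors decompose
+  // the linear within-patch offset k (m_depth_offset) as
+  //
+  //   k = ((colOffset * patchRows() + rowOffset) * patchPlanes() + planeOffset)
+  //       * patchDepth() + depthOffset,
+  //
+  // peeling off one factor at a time with the precomputed fast divisors.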
+ + typedef internal::TensorIntDivisor IndexDivisor; + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { + eigen_assert(m_base_mapper.m_patch_depth == + m_base_mapper.m_planeInputStride && + "Patch depth must be equal to plane input stride."); + return m_base_mapper.m_planeInputStride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { + eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride && + "Patch planes must be equal to row stride."); + return m_base_mapper.m_rowStride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { + return m_base_mapper.m_patch_rows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { + return m_base_mapper.m_patch_cols; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlaneStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && + "Patch depth must be equal to patch plane stride."); + return patchDepth(); + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRowStride() const { + return m_base_mapper.m_patch_row_stride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchColStride() const { + return m_base_mapper.m_patch_col_stride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && + "Patch depth must be equal to patch plane stride."); + return m_base_mapper.m_fastDimZero; // patch_depth + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { + return m_base_mapper.m_fastPatchRowStride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { + return m_base_mapper.m_fastPatchColStride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.template packet(inputIndex); + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Scalar coeffNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.coeff(inputIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const { + const Index p = m_planeIndex + plane; + return p < 0 || p >= m_base_mapper.m_inputPlanes; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { + const Index r = m_rowIndex + row; + return r < 0 || r >= m_base_mapper.m_inputRows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { + const Index c = m_colIndex + col; + return c < 0 || c >= m_base_mapper.m_inputCols; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row, + const Index col) const { + const Index p = m_planeIndex + plane; + const Index r = m_rowIndex + row; + const Index c = m_colIndex + col; + return p * m_base_mapper.m_planeInputStride + + r * m_base_mapper.m_rowInputStride + + c * m_base_mapper.m_colInputStride + m_otherIndex; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index planeOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + m_base_mapper.m_fastRowStride; + const Index planeOffset = patchOffset - + colOffset * m_base_mapper.m_colStride - + rowOffset * 
m_base_mapper.m_rowStride; + return planeOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index rowOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + m_base_mapper.m_fastRowStride; + return rowOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index colOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + return colOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index depthOffset() const { + return m_depth_offset % patchDepth(); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper + getLinearMapper(Index i, Index j) const { + return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); + } + + private: + const ParentMapper m_base_mapper; // Keeping a copy instead of a reference + // performs better in benchmarks. + + Index m_depth_offset; // First row in the input matrix + Index m_col_offset; // First col in the input matrix + + // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base + // indices for the first element in a patch specified by col_offset + // (see computeBaseIndices(...) for details). + Index m_planeIndex; + Index m_rowIndex; + Index m_colIndex; + Index m_otherIndex; +}; + +// Arrange a block of the right input matrix (in our case it's always a "virtual +// matrix" constructed from extracted volume patches) in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 +// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 +// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 +// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 +// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 +// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 +// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 +// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 +// A8 ... +// ... +// +// *) A, B, C, ... - patches extracted from the original input. +// *) A0, A1, A2 ... - values from the same patch at different offsets. +// +// The traversal (packed rhs memory) order (B0 besides A0 in memory): +// A0 B0 C0 D0 A1 B1 C1 D1 ... +// E0 F0 G0 H0 E1 F1 G1 H1 ... +// ... +// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4) +// +// This traversal order must be the same as in default gemm_pack_rhs defined in +// GeneralBlockPanelKernel.h. +// +// *) nr - number of registers along the 'n' dimension. +// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix +// Multiplication" paper. +// +// TODO(ezhulenev): Add support for squeezing reads along two innermost +// dimensions (see eigen_spatial_convolutions). 
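+//
+// Put differently: for each depth index k in the peeled region, the k-th
+// coefficients of the nr = 4 patches of the current column block are stored
+// contiguously (A0 B0 C0 D0, then A1 B1 C1 D1, ...); any columns left over
+// after the blocks of four are copied one at a time.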
+template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + + typedef SubMapper DataMapper; + typedef typename packet_traits::type Packet; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if ((packet_size % 4) == 0 && !non_standard_patches) { + // FAST PATH: + // Iterate over patch columns, rows and planes if we know that a single + // packet do not span across multiple planes, rows or columns. + if ((rhs.patchDepth() % packet_size) == 0) { + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + + const Index start_plane = ((c == start_col) && (r == start_row)) + ? rhs.planeOffset() + : 0; + const Index max_plane = rhs.maxPlane(peeled_k, c, r); + + const bool pad_row0 = pad_col0 || dm0.padRow(r); + const bool pad_row1 = pad_col1 || dm1.padRow(r); + const bool pad_row2 = pad_col2 || dm2.padRow(r); + const bool pad_row3 = pad_col3 || dm3.padRow(r); + + for (Index p = start_plane; p < max_plane; ++p) { + eigen_assert(k <= peeled_k); + + const bool pad0 = pad_row0 || dm0.padPlane(p); + const bool pad1 = pad_row1 || dm1.padPlane(p); + const bool pad2 = pad_row2 || dm2.padPlane(p); + const bool pad3 = pad_row3 || dm3.padPlane(p); + + const Index idx0 = dm0.baseIndex(p, r, c); + const Index idx1 = dm1.baseIndex(p, r, c); + const Index idx2 = dm2.baseIndex(p, r, c); + const Index idx3 = dm3.baseIndex(p, r, c); + + const Index start_depth = + ((c == start_col) && (r == start_row) && (p == start_plane)) + ? rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + for (Index d = start_depth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock kernel; + kernel.packet[0] = pad0 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel.packet[1] = pad1 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel.packet[2] = pad2 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel.packet[3] = pad3 ? 
pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + k += packet_size; + } + } + } + } + + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + + } else { + // Packet can span multiple planes, rows or columns, so we have to go + // though the slower "standard" path. + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel; + kernel.packet[0] = dm0.loadPacketStandard(k); + kernel.packet[1] = dm1.loadPacketStandard(k); + kernel.packet[2] = dm2.loadPacketStandard(k); + kernel.packet[3] = dm3.loadPacketStandard(k); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + } + } + } + + // Copy the remaining coefficients of the column block after the peeled_k. + if (!non_standard_patches) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // Copy the remaining columns one at a time (nr==1). + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Template specialization for packet_size = 2. We must special-case packet +// blocks with nr > packet_size, e.g. PacketBlock. +// +// TODO(ezhulenev): Add support for squeezing reads along two innermost +// dimensions (see eigen_spatial_convolutions). 
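+//
+// With a 2-wide packet there is no 4-wide ptranspose, so the fast path below
+// transposes two 2x2 packet blocks (dm0/dm1 and dm2/dm3) and stores them
+// interleaved (kernel0.packet[0], kernel1.packet[0], kernel0.packet[1],
+// kernel1.packet[1]), which reproduces the same A0 B0 C0 D0 A1 B1 C1 D1
+// ordering as the generic specialization above.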
+template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper DataMapper; + typedef typename packet_traits::type Packet; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const int packet_size = 2; + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if (!non_standard_patches) { + // FAST PATH: + // Iterate over patch columns, rows and planes if we know that a single + // packet do not span across multiple planes, rows or columns. + if ((rhs.patchDepth() % packet_size) == 0) { + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + + const Index start_plane = ((c == start_col) && (r == start_row)) + ? rhs.planeOffset() + : 0; + const Index max_plane = rhs.maxPlane(peeled_k, c, r); + + const bool pad_row0 = dm0.padRow(r); + const bool pad_row1 = dm1.padRow(r); + const bool pad_row2 = dm2.padRow(r); + const bool pad_row3 = dm3.padRow(r); + + for (Index p = start_plane; p < max_plane; ++p) { + eigen_assert(k <= peeled_k); + + const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p); + const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p); + const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p); + const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p); + + const Index idx0 = dm0.baseIndex(p, r, c); + const Index idx1 = dm1.baseIndex(p, r, c); + const Index idx2 = dm2.baseIndex(p, r, c); + const Index idx3 = dm3.baseIndex(p, r, c); + + const Index start_depth = + ((c == start_col) && (r == start_row) && (p == start_plane)) + ? rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + for (Index d = start_depth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = pad0 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel0.packet[1] = pad1 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel1.packet[0] = pad2 ? 
pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel1.packet[1] = pad3 ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + k += packet_size; + } + } + } + } + + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + + } else { + for (; k < peeled_k; k += packet_size) { + PacketBlock kernel0; + PacketBlock kernel1; + kernel0.packet[0] = dm0.loadPacketStandard(k); + kernel0.packet[1] = dm1.loadPacketStandard(k); + kernel1.packet[0] = dm2.loadPacketStandard(k); + kernel1.packet[1] = dm3.loadPacketStandard(k); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + } + } + } + + // Copy the remaining coefficients of the column block after the peeled_k. + if (!rhs.nonStandardPatches()) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // Copy the remaining columns one at a time (nr==1). + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Special case for non-vectorized types such as float16 (packet_size = 1). +template +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, + Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const Index packet_cols4 = (cols / 4) * 4; + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + if (!rhs.nonStandardPatches()) { + for (Index k = 0; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (Index k = 0; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // Copy the remaining columns one at a time (nr==1). 
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; +#endif + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +// Pack a block of the right input matrix (in our case it's always a "virtual +// matrix" constructed from extracted image patches) in contiguous block in +// column-major storage order. Knowing the properties of the original patch op +// we can do it more efficient than the default gemm_pack_colmajor_block. +// +// TODO(ezhulenev): gemm_pack_colmajor_block for spatial convolutions supports +// squeezing reads along the 2 innermost dimensions, add it here if needed. +template +struct gemm_pack_colmajor_block< + Scalar, StorageIndex, + TensorContractionSubMapper< + Scalar, StorageIndex, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + ColMajor> { + typedef TensorContractionSubMapper< + Scalar, StorageIndex, Rhs, + TensorEvaluator >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + + typedef SubMapper DataMapper; + typedef typename packet_traits::type Packet; + + EIGEN_DONT_INLINE + void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows, + StorageIndex cols) { + const bool standard_patches = !rhs.nonStandardPatches(); + + if (standard_patches && rhs.patchDepth() % packet_size == 0) { + packStandardPatches(block, rhs, rows, cols); + + } else if (standard_patches) { + packStandardPatches(block, rhs, rows, cols); + + } else { + // With non-standard patches we don't do any vectorized loads. + // TODO(ezhulenev): It doesn't look like that we should completely give up + // on packets. Make this code path faster! + for (StorageIndex col = 0; col < cols; ++col) { + SubMapper lm = rhs.getLinearMapper(0, col); + for (StorageIndex i = 0; i < rows; ++i) { + *block = lm(i); + ++block; + } + } + } + } + + private: + // Pack standard volume patches: + // + // - patch_depth_is_multiple_of_packet_size=true: We are guaranteed to have + // depth dimension size to be a multiple of packet size, so we can skip all + // non vectorized loads and checks. + // + template + EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block, + const DataMapper& rhs, + StorageIndex rows, + StorageIndex cols) { + eigen_assert(!rhs.nonStandardPatches()); + + // Give vectorized_rows the name used in all other gemm_pack_rhs above. + const Index peeled_k = (rows / packet_size) * packet_size; + + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (StorageIndex col = 0; col < cols; ++col) { + SubMapper lm = rhs.getLinearMapper(0, col); + + Index k = 0; + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); + const bool pad_col = lm.padCol(c); + + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + + const Index start_plane = + ((c == start_col) && (r == start_row)) ? rhs.planeOffset() : 0; + const Index max_plane = rhs.maxPlane(peeled_k, c, r); + const bool pad_row = pad_col || lm.padRow(r); + + for (Index p = start_plane; p < max_plane; ++p) { + eigen_assert(k <= peeled_k); + + const Index start_depth = + ((c == start_col) && (r == start_row) && (p == start_plane)) + ? 
rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + + const bool pad = pad_col || pad_row || lm.padPlane(p); + const Index base_idx = lm.baseIndex(p, r, c); + + if (patch_depth_is_multiple_of_packet_size) + eigen_assert((max_depth - start_depth) % packet_size == 0); + + // If patch depth is a multiple of packet size, it's guaranteed that + // we can process all values in depth dimension with packets. + const Index max_vectorized_depth = + patch_depth_is_multiple_of_packet_size + ? max_depth + : max_depth - packet_size; + + Index d = start_depth; + + // 1. Process depth dimension with vectorized instructions. + for (; d < max_vectorized_depth; d += packet_size) { + eigen_assert(k < peeled_k); + const Packet packet = pad ? pset1(Scalar(0)) + : rhs.packetNoPadding(d, base_idx); + internal::pstoreu(block, packet); + block += packet_size; + k += packet_size; + } + + // 2. Finish with coefficients. + if (!patch_depth_is_multiple_of_packet_size) { + for (; d < max_depth; d++) { + eigen_assert(k < peeled_k); + *block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx); + ++block; + ++k; + } + } + } + } + } + + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + + // Fill remaining elements using loadCoeffStandard. + for (; k < rows; ++k) { + *block = lm.loadCoeffStandard(k); + ++block; + } + } + } +}; +#endif // defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) + +} // namespace internal + +/** CuboidConvolution + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a 3D convolution over a multichannel input voxel block. + * + * The input parameter is expected to be a tensor with a rank of 4 or more + * (channels, depth, height, width, and optionally others). + * The kernel parameter is expected to be a 5D tensor (filters, channels, + * kernel_depth, kernel_height, kernel_width). + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be filters, depth, height, width + * (and others if applicable). + * + * The input and kernel have to be in the same layout, and both row-major and + * col-major are supported. The shapes given above are for col-major layout. + * For row-major, all dimensions should be reversed. + * + * It is possible to swap the order of the depth, width, and height dimensions + * provided that the same order is used in the input, the kernel, and the + * output. 
+ */ +template +EIGEN_ALWAYS_INLINE static const std::conditional_t< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorVolumePatchOp > > >, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorVolumePatchOp >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel> > > > +CuboidConvolution(const Input& input, const Kernel& kernel, + const Index stridePlanes = 1, const Index strideRows = 1, + const Index strideCols = 1, + const PaddingType padding_type = PADDING_SAME) { + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + kern(kernel); + + EIGEN_STATIC_ASSERT( + internal::traits::Layout == internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + static const bool isColMajor = (internal::traits::Layout == ColMajor); + static const int NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. This is the same as the output depth of the + // result. + const TensorIndex kernelFilters = + isColMajor ? kern.dimensions()[0] : kern.dimensions()[4]; + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[3]; + + // Spatial size of the kernel. + const TensorIndex kernelPlanes = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[4] : kern.dimensions()[0]; + + if (isColMajor) { + eigen_assert(kernelChannels == in.dimension(0)); + } else { + eigen_assert(kernelChannels == in.dimension(NumDims - 1)); + } + + const TensorIndex inputPlanes = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex inputRows = + isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + const TensorIndex inputCols = + isColMajor ? 
in.dimension(3) : in.dimension(NumDims - 4); + + TensorIndex out_planes; + TensorIndex out_height; + TensorIndex out_width; + switch (padding_type) { + case PADDING_VALID: + out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1, + static_cast(stridePlanes)); + out_height = Eigen::divup(inputRows - kernelRows + 1, + static_cast(strideRows)); + out_width = Eigen::divup(inputCols - kernelCols + 1, + static_cast(strideCols)); + break; + case PADDING_SAME: + out_planes = + Eigen::divup(inputPlanes, static_cast(stridePlanes)); + out_height = + Eigen::divup(inputRows, static_cast(strideRows)); + out_width = Eigen::divup(inputCols, static_cast(strideCols)); + break; + default: + out_planes = 0; + out_height = 0; + out_width = 0; + eigen_assert(false && "unexpected padding"); + } + + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols; + } else { + kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols; + kernel_dims[1] = kernelFilters; + } + + // Molds the output of the patch extraction result into a 2D tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = + kernelChannels * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = out_planes * out_height * out_width; + for (int i = 4; i < NumDims; ++i) { + pre_contract_dims[1] *= in.dimension(i); + } + } else { + pre_contract_dims[1] = + kernelChannels * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[0] = out_planes * out_height * out_width; + for (int i = 0; i < NumDims - 4; ++i) { + pre_contract_dims[0] *= in.dimension(i); + } + } + + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + // Molds the output of the contraction into the shape expected by the user + // (assuming ColMajor): + // - 1st dim: kernel filters + // - 2nd dim: output depth + // - 3nd dim: output height + // - 4rd dim: output width + // - 5th dim and beyond: everything else including batch size + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = out_planes; + post_contract_dims[2] = out_height; + post_contract_dims[3] = out_width; + for (int i = 4; i < NumDims; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelFilters; + post_contract_dims[NumDims - 2] = out_planes; + post_contract_dims[NumDims - 3] = out_height; + post_contract_dims[NumDims - 4] = out_width; + for (int i = 0; i < NumDims - 4; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } + + return choose( + Cond::Layout == ColMajor>(), + kernel.reshape(kernel_dims) + .contract(input + .extract_volume_patches( + kernelPlanes, kernelRows, kernelCols, stridePlanes, + strideRows, strideCols, padding_type) + .reshape(pre_contract_dims), + contract_dims) + .reshape(post_contract_dims), + input + .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, + stridePlanes, strideRows, strideCols, + padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims)); +} + +} // end namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_CUBOID_CONVOLUTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_pooling.h b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_pooling.h new file mode 
100644 index 00000000..ac701df0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/eigen_pooling.h @@ -0,0 +1,546 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +namespace Eigen { + +/** SpatialMaxPooling + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a max-pooling over a multichannel input image. + * + * The input parameter is expected to be a with a rank of 4 (channels, height, + * width, others in col-major, and the reverse of that in row-major). + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be channels, height, width, and + * others (in col-major, and the reverse of that if the input was row-major). + * + * The order of the width and height dimensions can be swapped if needed. + * + */ +template +EIGEN_ALWAYS_INLINE static const TensorReshapingOp< + const Eigen::DSizes::Index, + internal::traits::NumDimensions>, + const TensorReductionOp< + internal::MaxReducer< + std::remove_const_t::Scalar>>, + std::conditional_t< + internal::traits::Layout == ColMajor, + const Eigen::IndexList, Eigen::type2index<2>>, + const Eigen::IndexList, Eigen::type2index<3>>>, + const TensorImagePatchOp>> +SpatialMaxPooling(const Input& input, DenseIndex patchRows, + DenseIndex patchCols, DenseIndex strideRows, + DenseIndex strideCols, const PaddingType padding_type, + DenseIndex in_strideRows = 1, DenseIndex in_strideCols = 1) { + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 4, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + + const DenseIndex patchRowsEff = + patchRows + (patchRows - 1) * (in_strideRows - 1); + const DenseIndex patchColsEff = + patchCols + (patchCols - 1) * (in_strideCols - 1); + + static const bool isColMajor = (internal::traits::Layout == ColMajor); + static const int idxRows = isColMajor ? 1 : 2; + static const int idxCols = isColMajor ? 2 : 1; + + // Molds the output of the reduction into the shape expected by the user. 
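Both the CuboidConvolution output extents above and the pooling post_reduce_dims below are sized with the same divup arithmetic: VALID padding uses divup(input - effective_patch + 1, stride) and SAME uses divup(input, stride), where the effective patch extent folds in the input (dilation) stride. A small self-contained check of that arithmetic follows; the helper names are illustrative, not part of these headers.

#include <cassert>
#include <cstdio>

// Same rounding-up division the headers use via Eigen::divup.
long DivUp(long num, long den) { return (num + den - 1) / den; }

// Effective patch extent with an input (dilation) stride, mirroring
// patchRowsEff above: patch + (patch - 1) * (in_stride - 1).
long EffectivePatch(long patch, long in_stride) {
  return patch + (patch - 1) * (in_stride - 1);
}

long OutputSizeValid(long in, long patch, long stride, long in_stride = 1) {
  return DivUp(in - EffectivePatch(patch, in_stride) + 1, stride);
}

long OutputSizeSame(long in, long stride) { return DivUp(in, stride); }

int main() {
  // 10-wide input, 3-wide patch, stride 2, no dilation.
  assert(OutputSizeValid(10, 3, 2) == 4);  // divup(8, 2)
  assert(OutputSizeSame(10, 2) == 5);      // divup(10, 2)
  // With an input stride of 2 the 3-wide patch effectively covers 5 elements.
  assert(OutputSizeValid(10, 3, 2, 2) == 3);  // divup(6, 2)
  std::printf("output size checks passed\n");
  return 0;
}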
+ // (assuming col-major): + // - 1st dim: channels + // - 2nd dim: output height + // - 3rd dim: output width + // - 4th dim and beyond: everything else including batch size + Eigen::DSizes::NumDimensions> + post_reduce_dims; + post_reduce_dims[0] = in.dimension(0); + if (padding_type == PADDING_VALID) { + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)) - patchRowsEff + 1, + strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)) - patchColsEff + 1, + strideCols); + } else { + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)), strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)), strideCols); + } + post_reduce_dims[3] = in.dimension(3); + + // Take advantage of cxx11 to give the compiler information it can use to + // optimize the code. + std::conditional_t< + internal::traits::Layout == ColMajor, + const Eigen::IndexList, Eigen::type2index<2>>, + const Eigen::IndexList, Eigen::type2index<3>>> + reduction_dims; + + return input + .extract_image_patches( + patchRows, patchCols, strideRows, strideCols, in_strideRows, + in_strideCols, padding_type, + Eigen::NumTraits::Scalar>>::lowest()) + .maximum(reduction_dims) + .reshape(post_reduce_dims); +} + +/** CuboidMaxPooling + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a max-pooling over a multichannel input volume. + * + * The input parameter is expected to be a tensor with a rank of 5 (channels, + * depth, height, width, others in col-major, and the reverse of that in + * row-major). + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be channels, depth, height, width, + * and others (in col-major, and the reverse of that if the input was + * row-major). + * + * The order of the depth, width and height dimensions can be swapped if + * needed. + * + */ +template +EIGEN_ALWAYS_INLINE static const TensorReshapingOp< + const Eigen::DSizes::NumDimensions>, + const TensorReductionOp< + internal::MaxReducer< + std::remove_const_t::Scalar>>, + const Eigen::IndexList>, + const TensorReshapingOp< + const Eigen::DSizes, + const TensorVolumePatchOp>>> +CuboidMaxPooling(const Input& input, DenseIndex patchPlanes, + DenseIndex patchRows, DenseIndex patchCols, + DenseIndex stridePlanes, DenseIndex strideRows, + DenseIndex strideCols, const PaddingType padding_type) { + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 5, + YOU_MADE_A_PROGRAMMING_MISTAKE); + static const bool isColMajor = (internal::traits::Layout == ColMajor); + + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + + static const int idxPlanes = isColMajor ? 1 : 3; + static const int idxRows = 2; + static const int idxCols = isColMajor ? 
3 : 1; + + // Molds the output of the reduction into the shape expected by the used + // (assuming col-major): + // - 1st dim: channels + // - 2nd dim: output depth + // - 3rd dim: output height + // - 4th dim: output width + // - 5th dim and beyond: everything else including batch size + Eigen::DSizes::NumDimensions> + post_reduce_dims; + post_reduce_dims[0] = in.dimension(0); + if (padding_type == PADDING_VALID) { + post_reduce_dims[idxPlanes] = Eigen::divup( + static_cast(in.dimension(idxPlanes)) - patchPlanes + 1, + stridePlanes); + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)) - patchRows + 1, + strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)) - patchCols + 1, + strideCols); + } else { + post_reduce_dims[idxPlanes] = Eigen::divup( + static_cast(in.dimension(idxPlanes)), stridePlanes); + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)), strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)), strideCols); + } + post_reduce_dims[4] = in.dimension(4); + + Eigen::DSizes pre_reduce_dims; + pre_reduce_dims[1] = patchRows * patchCols * patchPlanes; + if (isColMajor) { + pre_reduce_dims[0] = post_reduce_dims[0]; + pre_reduce_dims[2] = post_reduce_dims[1] * post_reduce_dims[2] * + post_reduce_dims[3] * post_reduce_dims[4]; + } else { + pre_reduce_dims[0] = post_reduce_dims[0] * post_reduce_dims[1] * + post_reduce_dims[2] * post_reduce_dims[3]; + pre_reduce_dims[2] = post_reduce_dims[4]; + } + + typedef std::remove_const_t::Scalar> + CoeffReturnType; + + // Take advantage of cxx11 to give the compiler information it can use to + // optimize the code. + Eigen::IndexList > reduction_dims; + return input + .extract_volume_patches(patchPlanes, patchRows, patchCols, stridePlanes, + strideRows, strideCols, padding_type, + -Eigen::NumTraits::highest()) + .reshape(pre_reduce_dims) + .maximum(reduction_dims) + .reshape(post_reduce_dims); +} + +/** SpatialAvgPooling + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies an average pooling over a multichannel input image. + * + * The input parameter is expected to be a tensor with a rank of 4 (channels, + * height, width, others in col-major, and the reverse of that in row-major). + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be channels, height, width, and + * others (in col-major, and the reverse of that if the input was row-major). + * + * The order of the width and height dimensions can be swapped if needed. + * + */ +namespace internal { + +template +struct AvgPoolMeanReducer { +#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) && \ + !defined(__HIPCC__) + // We only support packet access for floats. 
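CuboidMaxPooling above avoids a multi-axis reduction by reshaping the extracted volume patches into a rank-3 tensor whose middle dimension holds all patch elements, and then reducing over that single dimension before reshaping to the final output. A rough sketch of how pre_reduce_dims and post_reduce_dims come out for a col-major input with VALID padding; the numbers are illustrative only.

#include <cstdio>

int main() {
  // Col-major 5D input: (channels, planes, rows, cols, batch).
  const long channels = 8, planes = 6, rows = 10, cols = 10, batch = 2;
  const long patchP = 2, patchR = 3, patchC = 3;
  const long strideP = 2, strideR = 2, strideC = 2;

  auto divup = [](long n, long d) { return (n + d - 1) / d; };

  // VALID-padding output extents, as in post_reduce_dims above.
  const long outP = divup(planes - patchP + 1, strideP);  // 3
  const long outR = divup(rows - patchR + 1, strideR);    // 4
  const long outC = divup(cols - patchC + 1, strideC);    // 4

  // pre_reduce_dims: (channels, patch volume, all output positions * batch).
  const long pre0 = channels;
  const long pre1 = patchP * patchR * patchC;  // the dimension being reduced
  const long pre2 = outP * outR * outC * batch;

  std::printf("pre_reduce_dims  = {%ld, %ld, %ld}\n", pre0, pre1, pre2);
  std::printf("post_reduce_dims = {%ld, %ld, %ld, %ld, %ld}\n",
              channels, outP, outR, outC, batch);
  return 0;
}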
+ static constexpr bool PacketAccess = internal::is_same::value; +#else + static const bool PacketAccess = false; +#endif + static constexpr bool IsStateful = true; + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE AvgPoolMeanReducer() : scalarCount_(0) { + typedef typename packet_traits::type Packet; +#if defined(__HIPCC__) + packetCount_ = 0; +#else + packetCount_ = pset1(T(0.0)); +#endif + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) { + if (t != -Eigen::NumTraits::highest()) { + (*accum) = (*accum) + t; + scalarCount_++; + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { + return static_cast(0); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { + eigen_assert(scalarCount_ > 0); + return accum / T(scalarCount_); + } + +#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) && \ + !defined(__HIPCC__) +#ifdef EIGEN_VECTORIZE_AVX512 +#define pequal(a, b) \ + _mm512_castsi512_ps( \ + _mm512_maskz_set1_epi32(_mm512_cmp_ps_mask(a, b, _CMP_EQ_UQ), -1)) + + // The ternarylogic function immediate determines the values in the result + // In the case below, 0xd8 implies (false_mask) ? (b) : (a) + // For details, refer to the vpternlogd instruction table at + // http://www.intel.com/content/dam/www/public/us/en/documents/manuals/64-ia-32-architectures-software-developer-vol-2c-manual.pdf + +#define psel(a, b, false_mask) \ + _mm512_castsi512_ps(_mm512_ternarylogic_epi32( \ + _mm512_castps_si512(a), _mm512_castps_si512(b), \ + _mm512_castps_si512(false_mask), 0xd8)) +#elif defined EIGEN_VECTORIZE_AVX +#define pequal(a, b) _mm256_cmp_ps(a, b, _CMP_EQ_UQ) +#define psel(a, b, false_mask) _mm256_blendv_ps(a, b, false_mask) +#else +#define pequal(a, b) _mm_cmpeq_ps(a, b) +#define psel(a, b, false_mask) \ + _mm_or_ps(_mm_andnot_ps(false_mask, a), _mm_and_ps(false_mask, b)) +#endif + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, + Packet* accum) { + reducePacketWithType(static_cast(0), p, accum); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacketWithType( + T, const Packet& p, Packet* accum) { + Packet skip_mask = + pequal(p, pset1(-Eigen::NumTraits::highest())); + (*accum) = padd(*accum, psel(p, pset1(0), skip_mask)); + packetCount_ = padd( + packetCount_, psel(pset1(1), pset1(0), skip_mask)); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { + return pset1(0); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet + finalizePacket(const Packet& vaccum) const { + return pdiv(vaccum, packetCount_); + } + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T + finalizeBoth(const T saccum, const Packet& vaccum) const { + return (saccum + predux(vaccum)) / (scalarCount_ + predux(packetCount_)); + } +#endif + + protected: + typedef typename packet_traits::type Packet; + int scalarCount_; +#if defined(__HIPCC__) + int packetCount_; +#else + Packet packetCount_; +#endif +}; + +template +struct reducer_traits, Device> { + enum { + Cost = 1, +#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) && \ + !defined(__HIPCC__) + // We only support packet access for floats. 
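The AvgPoolMeanReducer divides by the number of values that were actually inside the image: padded positions are filled with -highest() by the patch extraction and are skipped both in the scalar reduce() and, via the pequal/psel masks, in the packet path. A scalar-only sketch of that behaviour, standalone and not using the Eigen reducer interface:

#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

// Mean over a patch that ignores the padding sentinel, mirroring the
// scalar reduce()/finalize() pair of AvgPoolMeanReducer.
float MeanIgnoringPadding(const std::vector<float>& patch) {
  const float kPad = -std::numeric_limits<float>::max();  // -highest()
  float accum = 0.f;
  int count = 0;
  for (float v : patch) {
    if (v != kPad) {  // reduce(): padded entries are skipped entirely
      accum += v;
      ++count;
    }
  }
  assert(count > 0);     // finalize() asserts scalarCount_ > 0
  return accum / count;  // divide by valid values only, not the patch size
}

int main() {
  const float kPad = -std::numeric_limits<float>::max();
  // A 2x2 patch hanging over the image border: one padded entry.
  std::vector<float> patch = {1.f, 3.f, 5.f, kPad};
  assert(std::fabs(MeanIgnoringPadding(patch) - 3.f) < 1e-6f);  // (1+3+5)/3
  return 0;
}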
+ PacketAccess = true, +#else + PacketAccess = false, +#endif + IsStateful = true, + IsExactlyAssociative = false + }; +}; + +template <> +struct reducer_traits, GpuDevice> { + enum { + Cost = 1, + PacketAccess = false, + IsStateful = true, + IsExactlyAssociative = false + }; +}; + +} // namespace internal + +template +EIGEN_ALWAYS_INLINE static const TensorReshapingOp< + const Eigen::DSizes::Index, + internal::traits::NumDimensions>, + const TensorReductionOp< + internal::AvgPoolMeanReducer< + std::remove_const_t::Scalar>>, + std::conditional_t< + internal::traits::Layout == ColMajor, + const Eigen::IndexList, Eigen::type2index<2>>, + const Eigen::IndexList, Eigen::type2index<3>>>, + const TensorImagePatchOp>> +SpatialAvgPooling(const Input& input, DenseIndex patchRows, + DenseIndex patchCols, DenseIndex strideRows, + DenseIndex strideCols, const PaddingType padding_type, + DenseIndex in_strideRows = 1, DenseIndex in_strideCols = 1) { + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 4, + YOU_MADE_A_PROGRAMMING_MISTAKE); + + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + + const DenseIndex patchRowsEff = + patchRows + (patchRows - 1) * (in_strideRows - 1); + const DenseIndex patchColsEff = + patchCols + (patchCols - 1) * (in_strideCols - 1); + + static const bool isColMajor = (internal::traits::Layout == ColMajor); + static const int idxRows = isColMajor ? 1 : 2; + static const int idxCols = isColMajor ? 2 : 1; + + // Molds the output of the reduction into the shape expected by the user. + // (assuming col-major): + // - 1st dim: channels + // - 2nd dim: output height + // - 3rd dim: output width + // - 4th dim and beyond: everything else including batch size + Eigen::DSizes::NumDimensions> + post_reduce_dims; + post_reduce_dims[0] = in.dimension(0); + if (padding_type == PADDING_VALID) { + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)) - patchRowsEff + 1, + strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)) - patchColsEff + 1, + strideCols); + } else { + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)), strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)), strideCols); + } + post_reduce_dims[3] = in.dimension(3); + + typedef std::remove_const_t::Scalar> + CoeffReturnType; + internal::AvgPoolMeanReducer mean_with_nan; + + // Take advantage of cxx11 to give the compiler information it can use to + // optimize the code. + std::conditional_t< + internal::traits::Layout == ColMajor, + const Eigen::IndexList, Eigen::type2index<2>>, + const Eigen::IndexList, Eigen::type2index<3>>> + reduction_dims; + return input + .extract_image_patches(patchRows, patchCols, strideRows, strideCols, + in_strideRows, in_strideCols, padding_type, + -Eigen::NumTraits::highest()) + .reduce(reduction_dims, mean_with_nan) + .reshape(post_reduce_dims); +} + +/** CuboidAvgPooling + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies an average pooling over a multichannel input volume. + * + * The input parameter is expected to be a tensor with a rank of 5 (channels, + * depth, height, width, others, and the reverse of that in row-major). + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. 
The dimensions of the result will be channels, depth, width, and + * others (in col-major, and the reverse of that if the input was row-major). + * + * The order of the depth, width and height dimensions can be swapped if + * needed. + * + */ +template +EIGEN_ALWAYS_INLINE static const TensorReshapingOp< + const Eigen::DSizes::NumDimensions>, + const TensorReductionOp< + internal::AvgPoolMeanReducer< + std::remove_const_t::Scalar>>, + const Eigen::IndexList>, + const TensorReshapingOp< + const Eigen::DSizes, + const TensorVolumePatchOp>>> +CuboidAvgPooling(const Input& input, DenseIndex patchPlanes, + DenseIndex patchRows, DenseIndex patchCols, + DenseIndex stridePlanes, DenseIndex strideRows, + DenseIndex strideCols, const PaddingType padding_type) { + EIGEN_STATIC_ASSERT(internal::traits::NumDimensions == 5, + YOU_MADE_A_PROGRAMMING_MISTAKE); + static const bool isColMajor = (internal::traits::Layout == ColMajor); + + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + + static const int idxPlanes = isColMajor ? 1 : 3; + static const int idxRows = 2; + static const int idxCols = isColMajor ? 3 : 1; + // Molds the output of the reduction into the shape expected by the used + // (assuming col-major): + // - 1st dim: channels + // - 2nd dim: outupt depth + // - 3rd dim: output height + // - 4th dim: output width + // - 5th dim and beyond: everything else including batch size + Eigen::DSizes::NumDimensions> + post_reduce_dims; + post_reduce_dims[0] = in.dimension(0); + if (padding_type == PADDING_VALID) { + post_reduce_dims[idxPlanes] = Eigen::divup( + static_cast(in.dimension(idxPlanes)) - patchPlanes + 1, + stridePlanes); + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)) - patchRows + 1, + strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)) - patchCols + 1, + strideCols); + } else { + post_reduce_dims[idxPlanes] = Eigen::divup( + static_cast(in.dimension(idxPlanes)), stridePlanes); + post_reduce_dims[idxRows] = Eigen::divup( + static_cast(in.dimension(idxRows)), strideRows); + post_reduce_dims[idxCols] = Eigen::divup( + static_cast(in.dimension(idxCols)), strideCols); + } + post_reduce_dims[4] = in.dimension(4); + + Eigen::DSizes pre_reduce_dims; + pre_reduce_dims[1] = patchRows * patchCols * patchPlanes; + if (isColMajor) { + pre_reduce_dims[0] = post_reduce_dims[0]; + pre_reduce_dims[2] = post_reduce_dims[1] * post_reduce_dims[2] * + post_reduce_dims[3] * post_reduce_dims[4]; + } else { + pre_reduce_dims[0] = post_reduce_dims[0] * post_reduce_dims[1] * + post_reduce_dims[2] * post_reduce_dims[3]; + pre_reduce_dims[2] = post_reduce_dims[4]; + } + + typedef std::remove_const_t::Scalar> + CoeffReturnType; + internal::AvgPoolMeanReducer mean_with_nan; + + // Take advantage of cxx11 to give the compiler information it can use to + // optimize the code. 
+ Eigen::IndexList > reduction_dims; + return input + .extract_volume_patches(patchPlanes, patchRows, patchCols, stridePlanes, + strideRows, strideCols, padding_type, + -Eigen::NumTraits::highest()) + .reshape(pre_reduce_dims) + .reduce(reduction_dims, mean_with_nan) + .reshape(post_reduce_dims); +} + +} // end namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_EIGEN_POOLING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fake_quant_ops_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fake_quant_ops_functor.h new file mode 100644 index 00000000..5053b5f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fake_quant_ops_functor.h @@ -0,0 +1,290 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_ + +#include + +#define EIGEN_STACK_ALLOCATION_LIMIT 0 +#define EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float StdRound(float input) { +// On Android, std::round() isn't present, just round(). +#if defined(__ANDROID__) + return round(input); +#else + return std::round(input); +#endif +} + +namespace tensorflow { + +// Gymnastics with nudged zero point is to ensure that real zero maps to +// an integer, which is required for e.g. zero-padding in convolutional layers. +// Outputs nudged_min, nudged_max, nudged_scale. 
+EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void Nudge( + const float min, const float max, const int quant_min, const int quant_max, + float* nudged_min, float* nudged_max, float* scale, float* inv_scale) { + const float quant_min_float = static_cast(quant_min); + const float quant_max_float = static_cast(quant_max); + *scale = (max - min) / (quant_max_float - quant_min_float); + // Re-calculate the inverse to avoid loss of precision which would result + // from simply taking the reciprocal of *scale + *inv_scale = (quant_max_float - quant_min_float) / (max - min); + const float zero_point_from_min = quant_min_float - min / *scale; + const uint16 nudged_zero_point = [zero_point_from_min, quant_min, + quant_min_float, quant_max, + quant_max_float] { + if (zero_point_from_min < quant_min_float) { + return static_cast(quant_min); + } + if (zero_point_from_min > quant_max_float) { + return static_cast(quant_max); + } + return static_cast(StdRound(zero_point_from_min)); + }(); + *nudged_min = (quant_min_float - nudged_zero_point) * (*scale); + *nudged_max = (quant_max_float - nudged_zero_point) * (*scale); +} + +template +using ConstScalar = typename tensorflow::TTypes::ConstScalar; +template +using Scalar = typename tensorflow::TTypes::Scalar; +template +using ConstVec = typename tensorflow::TTypes::ConstVec; +template +using Vec = typename tensorflow::TTypes::Vec; +template +using ConstFlat = typename tensorflow::TTypes::ConstFlat; +template +using Flat = typename tensorflow::TTypes::Flat; + +// Functor called by FakeQuantWithMinMaxArgsOp to do the work. Compiles both +// for CPU and GPU. +template +struct FakeQuantWithMinMaxArgsFunctor { + void operator()(const Device& d, ConstFlat inputs, const float min, + const float max, const int quant_min, const int quant_max, + Flat outputs) { + eigen_assert(min <= 0.0f && "min should be <= 0.0"); + eigen_assert(max >= 0.0f && "max should be >= 0.0"); + eigen_assert(min < max && "min should be < max"); + + float nudged_min, nudged_max, nudged_scale, inv_nudged_scale; + Nudge(min, max, quant_min, quant_max, &nudged_min, &nudged_max, + &nudged_scale, &inv_nudged_scale); + + const float quant_zero = floor(-nudged_min * inv_nudged_scale + 0.5f); + + auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); + auto clamped_shifted = clamped - nudged_min; + outputs.device(d) = + (clamped_shifted * inv_nudged_scale - quant_zero + 0.5f).floor() * + nudged_scale; + } +}; + +// Functor called by FakeQuantWithMinMaxArgsGradientOp to do the work. Compiles +// both for CPU and GPU. +template +struct FakeQuantWithMinMaxArgsGradientFunctor { + void operator()(const Device& d, ConstFlat gradients, + ConstFlat inputs, const float min, const float max, + const int quant_min, const int quant_max, + Flat backprops) { + eigen_assert(min <= 0.0f && "min should be <= 0.0"); + eigen_assert(max >= 0.0f && "max should be >= 0.0"); + eigen_assert(min < max && "min should be < max"); + + float nudged_min, nudged_max, nudged_scale, inv_nudged_scale; + Nudge(min, max, quant_min, quant_max, &nudged_min, &nudged_max, + &nudged_scale, &inv_nudged_scale); + + auto between_nudged_min_max = + (inputs >= nudged_min && inputs <= nudged_max) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprops.device(d) = gradients * between_nudged_min_max; + } +}; + +// Functor called by FakeQuantWithMinMaxVarsOp to do the work. Compiles both +// for CPU and GPU. 
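Putting Nudge() and the args functor together: with min = -1.0, max = 3.0 and an 8-bit range [0, 255], the scale is 4/255, the raw zero point 63.75 nudges to 64, and the nudged range becomes roughly [-1.0039, 2.9961], so real zero maps exactly onto quantized 64. The sketch below re-implements that arithmetic as plain scalar code for illustration; it is not the header's API, just the same formulas applied to one value.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Same formulas as Nudge() and FakeQuantWithMinMaxArgsFunctor, written as
// standalone scalar code.
float FakeQuant(float x, float min, float max, int qmin, int qmax) {
  const float scale = (max - min) / float(qmax - qmin);
  const float inv_scale = float(qmax - qmin) / (max - min);
  const float zero_from_min = float(qmin) - min / scale;
  const float zero_point =
      std::round(std::max(float(qmin), std::min(float(qmax), zero_from_min)));
  const float nudged_min = (float(qmin) - zero_point) * scale;
  const float nudged_max = (float(qmax) - zero_point) * scale;

  const float clamped = std::min(nudged_max, std::max(nudged_min, x));
  const float quant_zero = std::floor(-nudged_min * inv_scale + 0.5f);
  // Quantize then dequantize: floor(shifted * inv_scale - zero + 0.5) * scale.
  return std::floor((clamped - nudged_min) * inv_scale - quant_zero + 0.5f) *
         scale;
}

int main() {
  // 0.5 snaps to the nearest representable value 32 * (4/255) ~= 0.50196.
  std::printf("%f\n", FakeQuant(0.5f, -1.f, 3.f, 0, 255));
  return 0;
}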
+template +struct FakeQuantWithMinMaxVarsFunctor { + void operator()(const Device& d, ConstFlat inputs, + ConstScalar min, ConstScalar max, + const int quant_min, const int quant_max, + Flat outputs) { + const float min_val = min(); + const float max_val = max(); + // If min and max are both zero, we should just return zero. + if (min_val == 0.0f && max_val == 0.0f) { + outputs.device(d) = outputs.constant(0.0f); + return; + } + float nudged_min, nudged_max, nudged_scale, inv_nudged_scale; + Nudge(min_val, max_val, quant_min, quant_max, &nudged_min, &nudged_max, + &nudged_scale, &inv_nudged_scale); + + const float quant_zero = floor(-nudged_min * inv_nudged_scale + 0.5f); + const auto nudged_scale_repl = inputs.constant(nudged_scale); + // const auto inv_nudged_scale_repl = inputs.constant(inv_nudged_scale); + + const auto clamped = inputs.cwiseMin(nudged_max).cwiseMax(nudged_min); + const auto clamped_shifted = clamped - nudged_min; + outputs.device(d) = + (clamped_shifted / nudged_scale_repl - quant_zero + 0.5f).floor() * + nudged_scale_repl; + } +}; + +// Functor called by FakeQuantWithMinMaxVarsGradientOp to do the work. Compiles +// both for CPU and GPU. +template +struct FakeQuantWithMinMaxVarsGradientFunctor { + void operator()(const Device& d, ConstFlat gradients, + ConstFlat inputs, ConstScalar min, + ConstScalar max, const int quant_min, + const int quant_max, Flat backprops_wrt_input, + Scalar backprop_wrt_min, + Scalar backprop_wrt_max) { + const float min_val = min(); + const float max_val = max(); + // If min and max are both zero, we propagate everything to inputs. + if (min_val == 0.0f && max_val == 0.0f) { + backprops_wrt_input.device(d) = gradients; + backprop_wrt_min.device(d) = backprop_wrt_min.constant(0.0f); + backprop_wrt_max.device(d) = backprop_wrt_max.constant(0.0f); + return; + } + float nudged_min, nudged_max, nudged_scale, inv_nudged_scale; + Nudge(min_val, max_val, quant_min, quant_max, &nudged_min, &nudged_max, + &nudged_scale, &inv_nudged_scale); + + const auto between_min_max = + (inputs >= nudged_min && inputs <= nudged_max) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprops_wrt_input.device(d) = gradients * between_min_max; + + const auto below_min = + (inputs < nudged_min) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprop_wrt_min.device(d) = (gradients * below_min).sum(); + + const auto above_max = + (inputs > nudged_max) + .select(inputs.constant(1.0f), inputs.constant(0.0f)); + backprop_wrt_max.device(d) = (gradients * above_max).sum(); + } +}; + +using Index = typename tensorflow::TTypes::ConstTensor::Index; + +// Functor called by FakeQuantWithMinMaxVarsPerChannelOp to do the work. +// Compiles both for CPU and GPU. +// +// Already verified: inputs, outputs are of shape [b, d], min, max are of shape +// [d]. +template +struct FakeQuantWithMinMaxVarsPerChannelFunctor { + void operator()(const Device& d, TTypes::ConstMatrix inputs, + ConstVec min, ConstVec max, const int quant_min, + const int quant_max, TTypes::Matrix outputs) { + for (Index i = 0; i < min.size(); ++i) { + const float min_val = min(i); + const float max_val = max(i); + // If min and max are both zero, we should just return zero. 
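The gradient functors above implement the usual straight-through estimator: gradients pass through unchanged wherever the input fell inside [nudged_min, nudged_max], and everything that was clipped is accumulated into the gradient of min (below the range) or max (above the range). A scalar sketch of that routing; the struct and function names are standalone illustrations, not the functor signature.

#include <cassert>
#include <vector>

struct FakeQuantGrads {
  std::vector<float> wrt_input;
  float wrt_min = 0.f;
  float wrt_max = 0.f;
};

// Route incoming gradients the same way FakeQuantWithMinMaxVarsGradientFunctor
// does, given already-nudged bounds.
FakeQuantGrads RouteGradients(const std::vector<float>& grads,
                              const std::vector<float>& inputs,
                              float nudged_min, float nudged_max) {
  FakeQuantGrads out;
  out.wrt_input.resize(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i] < nudged_min) {
      out.wrt_min += grads[i];      // clipped from below -> d/dmin
    } else if (inputs[i] > nudged_max) {
      out.wrt_max += grads[i];      // clipped from above -> d/dmax
    } else {
      out.wrt_input[i] = grads[i];  // straight-through inside the range
    }
  }
  return out;
}

int main() {
  FakeQuantGrads g = RouteGradients({1.f, 1.f, 1.f, 1.f},
                                    {-2.f, 0.f, 1.f, 5.f},
                                    /*nudged_min=*/-1.f, /*nudged_max=*/3.f);
  assert(g.wrt_input[1] == 1.f && g.wrt_input[2] == 1.f);
  assert(g.wrt_min == 1.f && g.wrt_max == 1.f);
  return 0;
}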
+ if (min_val == 0.0f && max_val == 0.0f) { + auto chip = outputs.chip<1>(i); + chip.device(d) = chip.constant(0.0f); + continue; + } + float nudged_min, nudged_max, nudged_scale, inv_nudged_scale; + Nudge(min_val, max_val, quant_min, quant_max, &nudged_min, &nudged_max, + &nudged_scale, &inv_nudged_scale); + + const float quant_zero = floor(-nudged_min * inv_nudged_scale + 0.5f); + + const auto clamped = + inputs.chip<1>(i).cwiseMin(nudged_max).cwiseMax(nudged_min); + const auto clamped_shifted = clamped - nudged_min; + + outputs.chip<1>(i).device(d) = + (clamped_shifted * inv_nudged_scale - quant_zero + 0.5f).floor() * + nudged_scale; + } + } +}; + +// Functor called by FakeQuantWithMinMaxVarsPerChannelGradientOp to do the work. +// Compiles both for CPU and GPU. +// +// Already verified: gradients, inputs, backprops_wrt_input are of shape [b, d], +// min, max, backprop_wrt_min, backprop_wrt_max are of shape [d]. +template +struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor { + void operator()(const Device& d, TTypes::ConstMatrix gradients, + TTypes::ConstMatrix inputs, ConstVec min, + ConstVec max, const int quant_min, const int quant_max, + TTypes::Matrix backprops_wrt_input, + Vec backprop_wrt_min, Vec backprop_wrt_max) { + for (Index i = 0; i < min.size(); ++i) { + const float min_val = min(i); + const float max_val = max(i); + const auto gradients_chip = gradients.chip<1>(i); + const auto inputs_chip = inputs.chip<1>(i); + // If min and max are both zero, we propagate everything to inputs. + if (min_val == 0.0f && max_val == 0.0f) { + backprops_wrt_input.chip<1>(i).device(d) = gradients_chip; + auto min_chip = backprop_wrt_min.chip<0>(i); + auto max_chip = backprop_wrt_max.chip<0>(i); + min_chip.device(d) = min_chip.constant(0.0f); + max_chip.device(d) = max_chip.constant(0.0f); + continue; + } + float nudged_min, nudged_max, nudged_scale, inv_nudged_scale; + Nudge(min_val, max_val, quant_min, quant_max, &nudged_min, &nudged_max, + &nudged_scale, &inv_nudged_scale); + + const auto between_min_max = + (inputs_chip >= nudged_min && inputs_chip <= nudged_max) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + backprops_wrt_input.chip<1>(i).device(d) = + gradients_chip * between_min_max; + + const auto below_min = + (inputs_chip < nudged_min) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + Eigen::DSizes reduce(0); + backprop_wrt_min.chip<0>(i).device(d) = + (gradients_chip * below_min).sum(reduce); + + const auto above_max = + (inputs_chip > nudged_max) + .select(inputs_chip.constant(1.0f), inputs_chip.constant(0.0f)); + backprop_wrt_max.chip<0>(i).device(d) = + (gradients_chip * above_max).sum(reduce); + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fifo_queue.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fifo_queue.h new file mode 100644 index 00000000..6648fe27 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fifo_queue.h @@ -0,0 +1,93 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ +#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_op.h" +#include "tensorflow/core/kernels/typed_queue.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class FIFOQueue : public TypedQueue > { + public: + FIFOQueue(int32_t capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + + // Implementations of QueueInterface methods -------------------------------- + + void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; + void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, + CallbackWithTuple callback) override; + absl::Status MatchesNodeDef(const NodeDef& node_def) override; + + int32 size() const override { + mutex_lock lock(mu_); + return queues_[0].size(); + } + + protected: + ~FIFOQueue() override {} + + // Helper for dequeuing a single element from queues_. + void DequeueLocked(OpKernelContext* ctx, Tuple* tuple) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + static absl::Status GetElementComponentFromBatch(const Tuple& tuple, + int64_t index, int component, + OpKernelContext* ctx, + Tensor* out_tensor); + + private: + FIFOQueue(const FIFOQueue&) = delete; + void operator=(const FIFOQueue&) = delete; +}; + +// Defines a FIFOQueueOp, which produces a Queue (specifically, one +// backed by FIFOQueue) that persists across different graph +// executions, and sessions. Running this op produces a single-element +// tensor of handles to Queues in the corresponding device. +class FIFOQueueOp : public TypedQueueOp { + public: + explicit FIFOQueueOp(OpKernelConstruction* context); + + private: + absl::Status CreateResource(QueueInterface** ret) override + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + std::vector component_shapes_; + FIFOQueueOp(const FIFOQueueOp&) = delete; + void operator=(const FIFOQueueOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fill_empty_rows_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fill_empty_rows_functor.h new file mode 100644 index 00000000..2298ed92 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fill_empty_rows_functor.h @@ -0,0 +1,271 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FILL_EMPTY_ROWS_OP_H_ +#define TENSORFLOW_CORE_KERNELS_FILL_EMPTY_ROWS_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +namespace tensorflow { + +namespace functor { + +template +struct FillEmptyRows { + // Note that the done callback is only used by the GPU implementation. + absl::Status operator()(OpKernelContext* context, + const Tensor& default_value_t, + const Tensor& indices_t, const Tensor& values_t, + const Tensor& dense_shape_t, + typename AsyncOpKernel::DoneCallback done = nullptr); +}; + +template +struct FillEmptyRows { + static constexpr int IndicesRank = RaggedOperands ? 1 : 2; + absl::Status operator()(OpKernelContext* context, + const Tensor& default_value_t, + const Tensor& indices_t, const Tensor& values_t, + const Tensor& dense_shape_t, + typename AsyncOpKernel::DoneCallback done) { + (void)done; // Unused (only used in GPU implementation) + const int kOutputIndicesOutput = 0; + const int kOutputValuesOutput = 1; + const int kEmptyRowIndicatorOutput = 2; + const int kReverseIndexMapOutput = 3; + + const T& default_value = default_value_t.scalar()(); + const auto indices = indices_t.tensor(); + const auto values = values_t.vec(); + const auto dense_shape = dense_shape_t.tensor(); + + const Tindex N = indices_t.shape().dim_size(0); + const Tindex dense_rows = dense_shape(0); + + bool* empty_row_indicator = nullptr; + if (context->output_required(kEmptyRowIndicatorOutput)) { + Tensor* empty_row_indicator_t = nullptr; + TensorShape output_shape; + TF_RETURN_IF_ERROR( + TensorShape::BuildTensorShape({dense_rows}, &output_shape)); + TF_RETURN_IF_ERROR(context->allocate_output( + kEmptyRowIndicatorOutput, output_shape, &empty_row_indicator_t)); + empty_row_indicator = empty_row_indicator_t->vec().data(); + } + Tindex* reverse_index_map = nullptr; + if (context->output_required(kReverseIndexMapOutput)) { + Tensor* reverse_index_map_t = nullptr; + TensorShape output_shape; + TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape({N}, &output_shape)); + TF_RETURN_IF_ERROR(context->allocate_output( + kReverseIndexMapOutput, output_shape, &reverse_index_map_t)); + reverse_index_map = reverse_index_map_t->vec().data(); + } + + const int rank = IndicesRank == 1 ? 
1 : indices_t.shape().dim_size(1); + + if (dense_rows == 0) { + if (N != 0) { + return errors::InvalidArgument( + "Received SparseTensor with dense_shape[0] = 0 but " + "indices.shape[0] = ", + N); + } + Tensor* output_indices_t; + TensorShape output_indices_shape; + TF_RETURN_IF_ERROR( + TensorShape::BuildTensorShape({0, rank}, &output_indices_shape)); + TF_RETURN_IF_ERROR(context->allocate_output( + kOutputIndicesOutput, output_indices_shape, &output_indices_t)); + Tensor* output_values_t; + TF_RETURN_IF_ERROR(context->allocate_output( + kOutputValuesOutput, TensorShape({0}), &output_values_t)); + + // Exit early, nothing more to do. + return absl::OkStatus(); + } + + auto vec_or_matrix = [](auto tensor, int index1, int index2) -> auto& { + std::array indices; + indices[0] = index1; + if (IndicesRank == 2) { + indices[1] = index2; + } + return std::apply(tensor, indices); + }; + + bool rows_are_ordered = true; + Tindex last_indices_row = 0; + std::vector csr_offset(dense_rows, 0); + for (int i = 0; i < N; ++i) { + const Tindex row = vec_or_matrix(indices, i, 0); + if (row < 0 || row >= dense_rows) { + return errors::InvalidArgument("indices(", i, ", 0) is invalid: ", row, + " >= ", dense_rows); + } + ++csr_offset[row]; + rows_are_ordered = rows_are_ordered & (row >= last_indices_row); + last_indices_row = row; + } + bool all_rows_full = true; + for (int row = 0; row < dense_rows; ++row) { + // csr_offset here describes the number of elements in this dense row + bool row_empty = (csr_offset[row] == 0); + if (empty_row_indicator) { + empty_row_indicator[row] = row_empty; + } + all_rows_full = all_rows_full & !row_empty; + // In filled version, each row has at least one element. + csr_offset[row] = std::max(csr_offset[row], Tindex{1}); + // Update csr_offset to represent the number of elements up to and + // including dense_row + 1: + // csr_offset(0) == #{elements of row 0} + // csr_offset(1) == #{elements of row 1} + #{elements of row 0} + // .. + // csr_offset(i) == starting index for elements in row i + 1. + if (row > 0) { + csr_offset[row] += csr_offset[row - 1]; + } + } + + if (all_rows_full && rows_are_ordered) { + context->set_output(kOutputIndicesOutput, indices_t); + context->set_output(kOutputValuesOutput, values_t); + if (reverse_index_map) { + for (Tindex i = 0; i < N; ++i) { + reverse_index_map[i] = i; + } + } + } else { + Tensor* output_indices_t; + const Tindex N_full = csr_offset[dense_rows - 1]; + TensorShape output_indices_shape; + if constexpr (RaggedOperands) { + TF_RETURN_IF_ERROR( + TensorShape::BuildTensorShape({N_full}, &output_indices_shape)); + } else { + TF_RETURN_IF_ERROR(TensorShape::BuildTensorShape( + {N_full, rank}, &output_indices_shape)); + } + TF_RETURN_IF_ERROR(context->allocate_output( + kOutputIndicesOutput, output_indices_shape, &output_indices_t)); + auto output_indices = output_indices_t->tensor(); + + Tensor* output_values_t; + TF_RETURN_IF_ERROR(context->allocate_output( + kOutputValuesOutput, TensorShape({N_full}), &output_values_t)); + auto output_values = output_values_t->vec(); + + std::vector filled_count(dense_rows, 0); + + // Fill in values for rows that are not missing + for (Tindex i = 0; i < N; ++i) { + const Tindex row = vec_or_matrix(indices, i, 0); + Tindex& offset = filled_count[row]; + const Tindex output_i = ((row == 0) ? 0 : csr_offset[row - 1]) + offset; + offset++; // Increment the filled count for this row. 
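The CPU path above first builds csr_offset as per-row counts, flags the empty rows, clamps every count to at least one (each empty row will receive one default-valued element), and then turns the counts into an inclusive prefix sum so csr_offset[row - 1] is the output position where row starts. A small standalone walk-through of that bookkeeping, with hypothetical data:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // First index column of the sparse input; dense_shape[0] = 4.
  const std::vector<long> rows = {0, 0, 2};
  const long dense_rows = 4;

  std::vector<long> csr_offset(dense_rows, 0);
  std::vector<bool> empty_row(dense_rows, false);

  for (long r : rows) ++csr_offset[r];  // per-row counts: {2, 0, 1, 0}

  for (long row = 0; row < dense_rows; ++row) {
    empty_row[row] = (csr_offset[row] == 0);          // {F, T, F, T}
    csr_offset[row] = std::max(csr_offset[row], 1L);  // every row gets >= 1 slot
    if (row > 0) csr_offset[row] += csr_offset[row - 1];
  }
  // Inclusive prefix sum: {2, 3, 4, 5}; the filled output has
  // N_full = csr_offset.back() = 5 elements (3 originals + 2 defaults).
  for (long row = 0; row < dense_rows; ++row) {
    std::printf("row %ld: empty=%d, offset=%ld\n", row, int(empty_row[row]),
                csr_offset[row]);
  }
  return 0;
}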
+ std::copy_n(&vec_or_matrix(indices, i, 0), rank, + &vec_or_matrix(output_indices, output_i, 0)); + output_values(output_i) = values(i); + // We'll need this reverse index map to backprop correctly. + if (reverse_index_map) { + reverse_index_map[i] = output_i; + } + } + + // Fill in values for rows that are missing + for (Tindex row = 0; row < dense_rows; ++row) { + const Tindex row_count = filled_count[row]; + if (row_count == 0) { // We haven't filled this row + const Tindex starting_index = (row == 0) ? 0 : csr_offset[row - 1]; + // Remaining index values were set to zero already. + // Just need to set the row index in the right location. + vec_or_matrix(output_indices, starting_index, 0) = row; + for (Tindex col = 1; col < rank; ++col) { + vec_or_matrix(output_indices, starting_index, col) = 0; + } + output_values(starting_index) = default_value; + } + } + } + + return absl::OkStatus(); + } +}; + +template +struct FillEmptyRowsGrad { + absl::Status operator()(OpKernelContext* context, + typename TTypes::ConstVec reverse_index_map, + typename TTypes::ConstVec grad_values, + typename TTypes::Vec d_values, + typename TTypes::Scalar d_default_value); +}; + +template +struct FillEmptyRowsGrad { + absl::Status operator()(OpKernelContext* context, + typename TTypes::ConstVec reverse_index_map, + typename TTypes::ConstVec grad_values, + typename TTypes::Vec d_values, + typename TTypes::Scalar d_default_value) { + const CPUDevice& device = context->eigen_device(); + const Tindex N = reverse_index_map.dimension(0); + const Tindex N_full = grad_values.dimension(0); + + T& d_default_value_scalar = d_default_value(); + d_default_value_scalar = T(); + + Tensor visited_t; + TF_RETURN_IF_ERROR( + context->allocate_temp(DT_BOOL, TensorShape({N_full}), &visited_t)); + auto visited = visited_t.vec(); + visited.device(device) = visited.constant(false); + + for (int i = 0; i < N; ++i) { + // Locate the index of the output of the forward prop associated + // with this location in the input of the forward prop. Copy + // the gradient into it. Mark it as visited. + int64_t reverse_index = reverse_index_map(i); + if (reverse_index < 0 || reverse_index >= N_full) { + return errors::InvalidArgument( + "Elements in reverse index must be in [0, ", N_full, ") but got ", + reverse_index); + } + d_values(i) = grad_values(reverse_index); + visited(reverse_index) = true; + } + for (int j = 0; j < N_full; ++j) { + // The default value gradient gets the accumulated remainder of + // the backprop values (since the default value was used to fill + // in these slots in the forward calculation). + if (!visited(j)) { + d_default_value_scalar += grad_values(j); + } + } + return absl::OkStatus(); + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FILL_EMPTY_ROWS_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fill_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fill_functor.h new file mode 100644 index 00000000..abdc10ee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fill_functor.h @@ -0,0 +1,96 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
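FillEmptyRowsGrad inverts the fill: each original value's gradient is gathered from grad_values through reverse_index_map, and every output slot that no original value mapped to (the rows that were filled with the default) contributes to the default value's gradient instead. A compact standalone sketch of that pass in plain STL:

#include <cassert>
#include <vector>

// d_values[i] = grad_values[reverse_index_map[i]];
// d_default  = sum of the grad_values entries nothing mapped to.
void FillEmptyRowsGradSketch(const std::vector<long>& reverse_index_map,
                             const std::vector<float>& grad_values,
                             std::vector<float>* d_values, float* d_default) {
  d_values->assign(reverse_index_map.size(), 0.f);
  std::vector<bool> visited(grad_values.size(), false);
  for (size_t i = 0; i < reverse_index_map.size(); ++i) {
    const long j = reverse_index_map[i];
    (*d_values)[i] = grad_values[j];
    visited[j] = true;
  }
  *d_default = 0.f;
  for (size_t j = 0; j < grad_values.size(); ++j) {
    if (!visited[j]) *d_default += grad_values[j];  // filled-in slots
  }
}

int main() {
  // 3 original values mapped into a filled output of 5 (see the forward pass).
  std::vector<float> d_values;
  float d_default = 0.f;
  FillEmptyRowsGradSketch({0, 1, 3}, {10.f, 20.f, 30.f, 40.f, 50.f},
                          &d_values, &d_default);
  assert(d_values[0] == 10.f && d_values[1] == 20.f && d_values[2] == 40.f);
  assert(d_default == 30.f + 50.f);  // unvisited slots 2 and 4
  return 0;
}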
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_ + +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace functor { + +template +struct FillFunctor { + // Computes on device "d": out = out.constant(in(0)), + void operator()(const Device& d, typename TTypes::Flat out, + typename TTypes::ConstScalar in); +}; + +template +struct SetZeroFunctor { + // Computes on device "d": out = out.setZero(), + void operator()(const Device& d, typename TTypes::Flat out); +}; + +// Partial specialization of SetZeroFunctor. +template +struct SetZeroFunctor { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Flat out); +}; + + +template <> +struct SetZeroFunctor { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Flat out); +}; + +template +struct SetOneFunctor { + // Computes on device "d": out = out.setOne(), + void operator()(const Device& d, typename TTypes::Flat out); +}; + +// Partial specialization of SetOneFunctor. +template +struct SetOneFunctor { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Flat out); +}; + + +template <> +struct SetOneFunctor { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Flat out); +}; + +template +struct SetNanFunctor { + void operator()(const Device& d, typename TTypes::Flat out); +}; + +// Partial specialization of SetNanFunctor. +template +struct SetNanFunctor { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Flat out); +}; + +template <> +struct SetNanFunctor { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Flat out); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fractional_pool_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fractional_pool_common.h new file mode 100644 index 00000000..0abb20d2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fractional_pool_common.h @@ -0,0 +1,79 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_ + +#include +#include + +#include "tensorflow/core/util/guarded_philox_random.h" + +namespace tensorflow { + +// Shuffle a container randomly, copied from random_shuffle_op.cc +template +static inline void RandomShuffle(Iter first, Iter last, const Random& uniform) { + if (first == last) { + return; + } + const auto stop = last - 1; + for (auto i = first; i != stop; ++i) { + using std::iter_swap; + iter_swap(i, i + uniform(last - i)); + } +} + +// Generate pooling sequence for fractional pooling along one dimension. +// +// Regular max/avg pooling can be viewed as a special case, in which given the +// * input_length: e.g. 10 +// * output_length: e.g. 5 +// it will generate pooling sequence as +// diff sequence: [2, 2, 2, 2, 2] +// or as +// cumulative sequence: [0, 2, 4, 6, 8, 10] +// +// In the case of fractional pooling, input_length is not an integer multiple of +// output_length, randomness plays a role when generating pooling sequence. +// There are two type of randomness (random vs pseudo-random) defined in paper: +// http://arxiv.org/abs/1412.6071 +// You can check the paper for the difference between these two types. +// +// In summary, the generated diff sequence satisfy the following properties for +// both types of randomness: +// * length(generated_diff_pooling_sequence) = output_length +// * sum(generated_diff_pooling_sequence) = input_length +// * Let's define floor(input_length / output_length) = K, then +// K <= generated_diff_pooling_sequence[i] <= K+1 +// For example, when input_length = 10, output_length = 6, the following are +// valid pooling sequence: +// * [1, 2, 2, 1, 2, 2] +// * [1, 1, 2, 2, 2, 2] +// [1, 3, 2, 2, 2, 2] is not valid. +// +// Args: +// input_length: See above explanation +// output_length: See above explanation +// generator: Parallel version of random number generator +// pseudo_random: Whether or not use pseudo-random +// Returns: +// pooling_sequence: This is the cumulative pooling sequence. +std::vector GeneratePoolingSequence(int input_length, + int output_length, + GuardedPhiloxRandom* generator, + bool pseudo_random); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/function_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/function_ops.h new file mode 100644 index 00000000..552e1e6c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/function_ops.h @@ -0,0 +1,91 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_ + +#include "tensorflow/core/framework/full_type_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +static const char* const kArgOp = FunctionLibraryDefinition::kArgOp; +static const char* const kDeviceArgOp = FunctionLibraryDefinition::kDeviceArgOp; +static const char* const kRetOp = FunctionLibraryDefinition::kRetOp; +static const char* const kDeviceRetOp = FunctionLibraryDefinition::kDeviceRetOp; + +class ArgOp : public OpKernel { + public: + explicit ArgOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + int index_; + DataType dtype_; + + ArgOp(const ArgOp&) = delete; + void operator=(const ArgOp&) = delete; +}; + +class RetvalOp : public OpKernel { + public: + explicit RetvalOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + int index_; + DataType dtype_; + + RetvalOp(const RetvalOp&) = delete; + void operator=(const RetvalOp&) = delete; +}; + +class RemoteCallOp : public AsyncOpKernel { + public: + explicit RemoteCallOp(OpKernelConstruction* ctx); + + ~RemoteCallOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + string TraceString(const OpKernelContext& ctx, bool verbose) const override; + + private: + NameAttrList func_; + DataTypeVector input_dtypes_; + DataTypeVector output_dtypes_; + // Note that in the future if all RemoteCall ops have full type + // information, the kernel will not need access to the "Tout" Attr and + // return_type_ will replace output_dtypes_. + FullTypeDef return_type_; + + mutex mu_; + typedef std::pair FunctionTarget; + std::map handle_cache_ + TF_GUARDED_BY(mu_); + + RemoteCallOp(const RemoteCallOp&) = delete; + void operator=(const RemoteCallOp&) = delete; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fused_batch_norm_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fused_batch_norm_op.h new file mode 100644 index 00000000..e50d80ae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fused_batch_norm_op.h @@ -0,0 +1,72 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_ +#define TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace functor { + +// FusedBatchNormEx op supports side inputs and activations: +// (1) batch_norm + activation +// (2) batch norm + side input + activation +enum class FusedBatchNormActivationMode { kIdentity, kRelu }; + +std::string ToString(FusedBatchNormActivationMode activation_mode); + +absl::Status ParseActivationMode(OpKernelConstruction* context, + FusedBatchNormActivationMode* activation_mode); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// This is a functor to launch custom CUDA kernel for FusedBatchNorm with side +// input and activation when 'is_training=False'. In training we rely on cuDNN. +template +struct FusedBatchNormInferenceFunctor { + void operator()(OpKernelContext* context, TensorFormat tensor_format, + typename TTypes::ConstTensor in, + typename TTypes::ConstVec scale, + typename TTypes::ConstVec offset, + typename TTypes::ConstVec estimated_mean, + typename TTypes::ConstVec estimated_variance, + typename TTypes::ConstTensor side_input, U epsilon, + FusedBatchNormActivationMode activation_mode, + typename TTypes::Tensor out); +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// Functor used by FusedBatchNormGradOp to do the computations when +// is_training=False. +template +struct FusedBatchNormFreezeGrad { + void operator()(OpKernelContext* context, const Tensor& y_backprop_input, + const Tensor& x_input, const Tensor& scale_input, + const Tensor& pop_mean_input, + const Tensor& pop_variance_input, U epsilon, + Tensor* x_backprop_output, Tensor* scale_backprop_output, + Tensor* offset_backprop_output) {} +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fused_eigen_output_kernels.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fused_eigen_output_kernels.h new file mode 100644 index 00000000..84a0d27b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fused_eigen_output_kernels.h @@ -0,0 +1,479 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Output kernels for fusing computation into Eigen Tensor contractions: +// (1) FusedConv2DOp +// (2) FusedMatMulOp +// +// Supported fused computations: +// (1) {Conv2D/MatMul} + BiasAdd + +// (2) {Conv2D/MatMul} + FusedBatchNorm + +// +// Activation: Relu, Relu6, Elu, etc... 
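+//
+// Illustrative sketch (not part of the upstream header): a fused Conv2D or
+// MatMul kernel that fuses {"BiasAdd", "Relu"} would be matched against a
+// pattern along the lines of
+//   FusedComputationPattern{FusedComputationType::kBiasAddWithRelu,
+//                           {"BiasAdd", "Relu"}}
+// by InitializeFusedComputation() declared below.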
+ +#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +enum class FusedComputationType { + kUndefined, + kBiasAdd, + kBiasAddWithRelu, + kBiasAddWithRelu6, + kBiasAddWithTanh, + kBiasAddWithSigmoid, + kBiasAddWithElu, + kBiasAddWithLeakyRelu, + kBiasAddWithGeluApproximate, + kBiasAddWithGeluExact, + kFusedBatchNorm, + kFusedBatchNormWithRelu, + kFusedBatchNormWithRelu6, + kFusedBatchNormWithElu, + kFusedBatchNormWithLeakyRelu +}; + +// We have to pass around additional arguments for all possible fusion types. +struct FusedComputationArgs { + float epsilon = 0.0; // Used by `FusedBatchNorm` fusion only + float leakyrelu_alpha = 0.0; // Used by `LeakyRelu` fusion only +}; + +struct FusedComputationPattern { + FusedComputationType fused_computation; + std::vector fused_ops; +}; + +// Parse attributes from the kernel construction context, and verifies that they +// specify valid fused computation pattern. +absl::Status InitializeFusedComputation( + OpKernelConstruction* context, const string& kernel_name, + const std::vector& patterns, + FusedComputationType* fused_computation, + FusedComputationArgs* fused_computation_args); + +// Type alias for the tensor contraction output mapper. +template +using ContractionOutputMapper = + Eigen::internal::blas_data_mapper; + +// Returns input expression without any transformations. +struct Identity { + template + static auto apply(XprType expr) -> XprType { + return expr; + }; +}; + +// Applies `Relu` to the passed input expression. +struct Relu { + template + static auto apply(XprType expr) + -> decltype(expr.cwiseMax(std::declval())) { + return expr.cwiseMax(static_cast(0)); + }; +}; + +// Applies `Relu6` to the passed input expression. +struct Relu6 { + template + static auto apply(XprType expr) + -> decltype(expr.cwiseMax(std::declval()) + .cwiseMin(std::declval())) { + return expr.cwiseMax(static_cast(0)) + .cwiseMin(static_cast(6)); + }; +}; + +// Applies `Tanh` to the passed input expression. +struct Tanh { + template + static auto apply(XprType expr) -> decltype(expr.tanh()) { + return expr.tanh(); + }; +}; + +// Applies `Sigmoid` to the passed input expression. +struct Sigmoid { + template + static auto apply(XprType expr) -> decltype(expr.sigmoid()) { + return expr.sigmoid(); + }; +}; + +// Applies `Elu` to the passed input expression. +struct Elu { + template + static auto apply(XprType expr) -> decltype( + (expr < std::declval()) + .select(expr.exp() - + expr.constant(std::declval()), + expr)) { + return (expr < static_cast(0)) + .select(expr.exp() - + expr.constant(static_cast(1)), + expr); + }; +}; + +// Applies `LeakyRelu` to the passed input expression. 
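+// (Illustrative note: i.e. f(x) = x for x >= 0 and f(x) = leakyrelu_alpha * x
+// otherwise, as implemented by the select() expression below.)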
+struct LeakyRelu { + template + static auto apply(XprType expr, const float leakyrelu_alpha) -> decltype( + (expr < std::declval()) + .select(expr * + expr.constant(std::declval()), + expr)) { + return (expr < static_cast(0)) + .select(expr * expr.constant(static_cast( + leakyrelu_alpha)), + expr); + }; +}; + +template +struct BiasAddArgs { + const T* bias_add_data = nullptr; + float leakyrelu_alpha; + + static bool IsSupported(FusedComputationType fusion) { + return fusion == FusedComputationType::kBiasAdd || + fusion == FusedComputationType::kBiasAddWithRelu || + fusion == FusedComputationType::kBiasAddWithRelu6 || + fusion == FusedComputationType::kBiasAddWithTanh || + fusion == FusedComputationType::kBiasAddWithSigmoid || + fusion == FusedComputationType::kBiasAddWithElu || + fusion == FusedComputationType::kBiasAddWithLeakyRelu; + } +}; + +template +struct FusedBatchNormArgs { + const T* scale_data = nullptr; + const T* offset_data = nullptr; + const T* estimated_mean_data = nullptr; + const T* estimated_variance_data = nullptr; + + // Precomputed expression: + // scaling_factor = (estimated_variance + epsilon).rsqrt() * scale + Eigen::Tensor scaling_factor; + + float leakyrelu_alpha; + + static bool IsSupported(FusedComputationType fusion) { + return fusion == FusedComputationType::kFusedBatchNorm || + fusion == FusedComputationType::kFusedBatchNormWithRelu || + fusion == FusedComputationType::kFusedBatchNormWithRelu6 || + fusion == FusedComputationType::kFusedBatchNormWithElu || + fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu; + } +}; + +// TensorContraction swaps lhs with rhs, and changes layout from RowMajor +// (default in Tensorflow) to ColMajor (preferred in Eigen), and computes matmul +// using these tensors. +// +// (1) Spatial Convolution (see eigen_spatial_convolutions.h): +// +// TensorContraction output matrix (before reshape) has a ColMajor layout, and +// has dimensions: +// - rows: output_channels +// - cols: all other dimensions +// +// First element in every column is: +// [batch ??, height ??, width ??, out_channel = i] +// +// We do not know what are the values of the 'batch', 'height', and 'width' +// here (if we know original dimensions, they can be computed from 'j'). +// +// Each column of an output block is a continuous slice along the output +// channel dimension, so we can use it to efficiently compute any +// transformation that depends only on a channel value (e.g. add channel +// bias). +// +// (2) Matrix Multiplication (see matmul_op.cc): +// +// For the `MxK * KxN` matrix multiplication, output matrix has a `MxN` +// dimensions. Each column in output block is a slice of the innermost +// dimension of the output matrix starting at offset 'i'. +// +// Example: In Tensorflow MatMul [8x32] * [32x64], each output block column +// will correspond to MatMul output row of size 64 (because Tensorflow uses +// row major storage order). + +// Output kernel that fuses BiasAdd operation into the output of tensor +// contraction + activation function defined by Activation. 
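+// Roughly, for every column of the output block it computes (illustrative
+// sketch under the layout described above, not upstream code):
+//   output(0:num_rows, col) =
+//       Activation::apply(output(0:num_rows, col) + bias(i:i + num_rows));
+// where `i` is the first output-channel row covered by this block.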
+template +struct BiasAddOutputKernel { + explicit BiasAddOutputKernel(const BiasAddArgs& args) + : bias_data(args.bias_add_data) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* bias_base = bias_data + i; + typename TTypes::UnalignedConstTensor bias(bias_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + Scalar* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + if constexpr (std::is_same_v) { + const auto expr = output + bias; + output = Activation::template apply(expr); + } else { + const auto bias_expr = bias.template cast(); + const auto expr = output + bias_expr; + output = Activation::template apply(expr); + } + } + } + + private: + const T* bias_data; +}; + +template +struct BiasAddOutputKernel { + explicit BiasAddOutputKernel(const BiasAddArgs& args) + : bias_data(args.bias_add_data), leakyrelu_alpha(args.leakyrelu_alpha) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* bias_base = bias_data + i; + typename TTypes::UnalignedConstTensor bias(bias_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + Scalar* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + if constexpr (std::is_same_v) { + const auto expr = output + bias; + output = + LeakyRelu::template apply(expr, leakyrelu_alpha); + } else { + const auto bias_expr = bias.template cast(); + const auto expr = output + bias_expr; + output = + LeakyRelu::template apply(expr, leakyrelu_alpha); + } + } + } + + private: + const T* bias_data; + float leakyrelu_alpha; +}; + +// Output kernel that fuses FusedBatchNorm operation into the output of tensor +// contraction + activation function defined by Activation. 
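+// Roughly, for every column of the output block it computes (illustrative
+// sketch):
+//   output = Activation::apply((output - estimated_mean) * scaling_factor
+//                              + offset);
+// with scaling_factor = (estimated_variance + epsilon).rsqrt() * scale
+// precomputed once in FusedBatchNormArgs above.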
+template +struct FusedBatchNormOutputKernel { + FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs& args) + : epsilon(epsilon), + scaling_factor_data(args.scaling_factor.data()), + offset_data(args.offset_data), + estimated_mean_data(args.estimated_mean_data) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* scaling_factor_base = scaling_factor_data + i; + const T* offset_base = offset_data + i; + const T* mean_base = estimated_mean_data + i; + + typename TTypes::UnalignedConstTensor scaling_factor(scaling_factor_base, + num_rows); + typename TTypes::UnalignedConstTensor offset(offset_base, num_rows); + typename TTypes::UnalignedConstTensor mean(mean_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + T* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + + auto scaled = (output - mean) * scaling_factor; + auto shifted = scaled + offset; + + output = Activation::template apply(shifted); + } + } + + private: + T epsilon; + const T* scaling_factor_data; + const T* offset_data; + const T* estimated_mean_data; +}; + +template +struct FusedBatchNormOutputKernel { + FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs& args) + : epsilon(epsilon), + scaling_factor_data(args.scaling_factor.data()), + offset_data(args.offset_data), + estimated_mean_data(args.estimated_mean_data), + leakyrelu_alpha(args.leakyrelu_alpha) {} + + template + EIGEN_ALWAYS_INLINE void operator()( + const ContractionOutputMapper& output_mapper, + const Eigen::TensorContractionParams& params, StorageIndex i, + StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const { + DCHECK(params.swapped_arguments); + + const T* scaling_factor_base = scaling_factor_data + i; + const T* offset_base = offset_data + i; + const T* mean_base = estimated_mean_data + i; + + typename TTypes::UnalignedConstTensor scaling_factor(scaling_factor_base, + num_rows); + typename TTypes::UnalignedConstTensor offset(offset_base, num_rows); + typename TTypes::UnalignedConstTensor mean(mean_base, num_rows); + + for (int col = 0; col < num_cols; ++col) { + T* output_base = &output_mapper(0, col); + typename TTypes::UnalignedTensor output(output_base, num_rows); + + auto scaled = (output - mean) * scaling_factor; + auto shifted = scaled + offset; + + output = LeakyRelu::template apply(shifted, + leakyrelu_alpha); + } + } + + private: + T epsilon; + const T* scaling_factor_data; + const T* offset_data; + const T* estimated_mean_data; + float leakyrelu_alpha; +}; + +// Type aliases for the output kernels, purely for the sake of better launch +// dispatching code readability. 
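+// (Illustrative example: a {Conv2D or MatMul} + BiasAdd + Relu fusion would be
+// launched with the WithBiasAddAndRelu<T> alias defined below.)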
+template +using WithBiasAdd = BiasAddOutputKernel; +template +using WithBiasAddAndRelu = BiasAddOutputKernel; +template +using WithBiasAddAndRelu6 = BiasAddOutputKernel; +template +using WithBiasAddAndTanh = BiasAddOutputKernel; +template +using WithBiasAddAndSigmoid = BiasAddOutputKernel; +template +using WithBiasAddAndElu = BiasAddOutputKernel; +template +using WithBiasAddAndLeakyRelu = BiasAddOutputKernel; +template +using WithFusedBatchNorm = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel; +template +using WithFusedBatchNormAndLeakyRelu = FusedBatchNormOutputKernel; + +template +absl::Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args, + const float* leakyrelu_alpha = nullptr) { + // Bias of the following dimensions: [ output_depth ] + const Tensor& bias = context->input(2); + + if (bias.dims() != 1) + return errors::InvalidArgument("bias must be 1-dimensional", + bias.shape().DebugString()); + + const auto data_ptr = [](const Tensor& tensor) -> const T* { + return reinterpret_cast(tensor.tensor_data().data()); + }; + + args->bias_add_data = data_ptr(bias); + + if (leakyrelu_alpha) { + args->leakyrelu_alpha = *leakyrelu_alpha; + } + + return absl::OkStatus(); +} + +template +absl::Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon, + FusedBatchNormArgs* args, + const float* leakyrelu_alpha = nullptr) { + const Tensor& scale = context->input(2); + const Tensor& offset = context->input(3); + const Tensor& estimated_mean = context->input(4); + const Tensor& estimated_variance = context->input(5); + + if (scale.dims() != 1) + return errors::InvalidArgument("scale must be 1-dimensional", + scale.shape().DebugString()); + if (offset.dims() != 1) + return errors::InvalidArgument("offset must be 1-dimensional", + offset.shape().DebugString()); + if (estimated_mean.dims() != 1) + return errors::InvalidArgument("estimated_mean must be 1-dimensional", + estimated_mean.shape().DebugString()); + if (estimated_variance.dims() != 1) + return errors::InvalidArgument("estimated_variance must be 1-dimensional", + estimated_variance.shape().DebugString()); + + const auto data_ptr = [](const Tensor& tensor) -> const T* { + return reinterpret_cast(tensor.tensor_data().data()); + }; + + args->scale_data = data_ptr(scale); + args->offset_data = data_ptr(offset); + args->estimated_mean_data = data_ptr(estimated_mean); + args->estimated_variance_data = data_ptr(estimated_variance); + + // Precompute scaling factor once for all output blocks (kernels). + args->scaling_factor = + (estimated_variance.flat() + static_cast(epsilon)).rsqrt() * + scale.flat(); + + if (leakyrelu_alpha) { + args->leakyrelu_alpha = *leakyrelu_alpha; + } + + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/fuzzing/fuzz_session.h b/third_party/tflite-hdrs/tensorflow/core/kernels/fuzzing/fuzz_session.h new file mode 100644 index 00000000..09c7563d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/fuzzing/fuzz_session.h @@ -0,0 +1,157 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ +#define TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/public/session.h" + +// Standard invoking function macro to dispatch to a fuzzer class. +#ifndef PLATFORM_WINDOWS +#define STANDARD_TF_FUZZ_FUNCTION(FuzzerClass) \ + extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { \ + static FuzzerClass* fuzzer = new FuzzerClass(); \ + return fuzzer->Fuzz(data, size); \ + } +#else +// We don't compile this for Windows, MSVC doesn't like it as pywrap in Windows +// links all the code into one big object file and there are conflicting +// function names. +#define STANDARD_TF_FUZZ_FUNCTION(FuzzerClass) +#endif + +// Standard builder for hooking one placeholder to one op. +#define SINGLE_INPUT_OP_BUILDER(dtype, opName) \ + void BuildGraph(const Scope& scope) override { \ + auto op_node = \ + tensorflow::ops::Placeholder(scope.WithOpName("input"), dtype); \ + (void)tensorflow::ops::opName(scope.WithOpName("output"), op_node); \ + } + +namespace tensorflow { +namespace fuzzing { + +// Create a TensorFlow session using a specific GraphDef created +// by BuildGraph(), and make it available for fuzzing. +// Users must override BuildGraph and FuzzImpl to specify +// (1) which operations are being fuzzed; and +// (2) How to translate the uint8_t* buffer from the fuzzer +// to a Tensor or Tensors that are semantically appropriate +// for the op under test. +// For the simple cases of testing a single op that takes a single +// input Tensor, use the SINGLE_INPUT_OP_BUILDER(dtype, opName) macro in place +// of defining BuildGraphDef. +// +// Typical use: +// class FooFuzzer : public FuzzSession { +// SINGLE_INPUT_OP_BUILDER(DT_INT8, Identity); +// void FuzzImpl(const uint8_t* data, size_t size) { +// ... convert data and size to a Tensor, pass it to: +// RunInputs({{"input", input_tensor}}); +// +class FuzzSession { + public: + FuzzSession() : initialized_(false) {} + virtual ~FuzzSession() {} + + // Constructs a Graph using the supplied Scope. + // By convention, the graph should have inputs named "input1", ... + // "inputN", and one output node, named "output". + // Users of FuzzSession should override this method to create their graph. + virtual void BuildGraph(const Scope& scope) = 0; + + // Implements the logic that converts an opaque byte buffer + // from the fuzzer to Tensor inputs to the graph. Users must override. + virtual void FuzzImpl(const uint8_t* data, size_t size) = 0; + + // Initializes the FuzzSession. Not safe for multithreading. + // Separate init function because the call to virtual BuildGraphDef + // can't be put into the constructor. 
+ Status InitIfNeeded() { + if (initialized_) { + return absl::OkStatus(); + } + initialized_ = true; + + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); + SessionOptions options; + session_ = std::unique_ptr(NewSession(options)); + + BuildGraph(root); + + GraphDef graph_def; + TF_CHECK_OK(root.ToGraphDef(&graph_def)); + + Status status = session_->Create(graph_def); + if (!status.ok()) { + // This is FATAL, because this code is designed to fuzz an op + // within a session. Failure to create the session means we + // can't send any data to the op. + LOG(FATAL) << "Could not create session: " << status.message(); + } + return status; + } + + // Runs the TF session by pulling on the "output" node, attaching + // the supplied input_tensor to the input node(s), and discarding + // any returned output. + // Note: We are ignoring Status from Run here since fuzzers don't need to + // check it (as that will slow them down and printing/logging is useless). + void RunInputs(const std::vector >& inputs) { + RunInputsWithStatus(inputs).IgnoreError(); + } + + // Same as RunInputs but don't ignore status + Status RunInputsWithStatus( + const std::vector >& inputs) { + return session_->Run(inputs, {}, {"output"}, nullptr); + } + + // Dispatches to FuzzImpl; small amount of sugar to keep the code + // of the per-op fuzzers tiny. + int Fuzz(const uint8_t* data, size_t size) { + Status status = InitIfNeeded(); + TF_CHECK_OK(status) << "Fuzzer graph initialization failed: " + << status.message(); + // No return value from fuzzing: Success is defined as "did not + // crash". The actual application results are irrelevant. + FuzzImpl(data, size); + return 0; + } + + private: + bool initialized_; + std::unique_ptr session_; +}; + +// A specialized fuzz implementation for ops that take +// a single string. Caller must still define the op +// to plumb by overriding BuildGraph or using +// a plumbing macro. +class FuzzStringInputOp : public FuzzSession { + void FuzzImpl(const uint8_t* data, size_t size) final { + Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); + input_tensor.scalar()() = + string(reinterpret_cast(data), size); + RunInputs({{"input", input_tensor}}); + } +}; + +} // end namespace fuzzing +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor.h new file mode 100644 index 00000000..607f3c80 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor.h @@ -0,0 +1,183 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_ + +#include "absl/base/prefetch.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +// Helper method to copy using memcpy. +template +SliceIndex HandleCopies(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + SliceIndex slice_elems, + typename TTypes::Tensor out) { + const SliceIndex indices_size = static_cast(indices.dimension(0)); + const SliceIndex batch_size = static_cast(params.dimension(0)); + const Index limit = static_cast(params.dimension(1)); + T* out_base = out.data(); + const T* params_base = params.data(); + if (static_slice_elems >= 0) { + // Give compiler static knowledge of the number of elements/bytes + slice_elems = static_slice_elems; + } + // Compute slice_bytes here so that static knowledge is available + const size_t slice_bytes = slice_elems * sizeof(T); + auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + mutex mu; + // Store the value of invalidate index for printing error information, it's a + // shared variable. + SliceIndex result = -1; + auto work = [&](int64_t start, int64_t end) { + SliceIndex batch_idx = static_cast(start / indices_size); + SliceIndex indices_idx = static_cast(start % indices_size); + SliceIndex batch_idx_end = static_cast(end / indices_size); + SliceIndex indices_idx_end = static_cast(end % indices_size); + + while ((batch_idx < batch_idx_end) || + (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) { + SliceIndex i_next = indices_idx + 1; + SliceIndex b_next = batch_idx + 1; + const Index index = internal::SubtleMustCopy(indices(indices_idx)); + if (!FastBoundsCheck(index, limit)) { + mutex_lock l(mu); + result = indices_idx; + return; + } + if ((batch_idx == batch_idx_end && i_next < indices_idx_end) || + (i_next < indices_size)) { + absl::PrefetchToLocalCache(¶ms(batch_idx, indices(i_next), 0)); + absl::PrefetchToLocalCache(&out(batch_idx, i_next, 0)); + b_next = batch_idx; + } else if (b_next <= batch_idx_end) { + absl::PrefetchToLocalCache(¶ms(b_next, indices(0), 0)); + absl::PrefetchToLocalCache(&out(b_next, 0, 0)); + i_next = 0; + } + // Copy using memcpy if possible, otherwise an Eigen loop + // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve + // ahead-of-time compilation binary size). + if (is_simple_type::value) { + // Avoid auto-promotion to Index from SliceIndex by casting. + memcpy( + out_base + (batch_idx * indices_size + indices_idx) * slice_elems, + params_base + (batch_idx * static_cast(limit) + + static_cast(index)) * + slice_elems, + slice_bytes); + } else { + // For non-"simple" types (e.g. strings). 
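+        // (Illustrative note: chip<0>(k) selects the k-th slice along the
+        // leading dimension, so the assignment below copies the gathered
+        // params slice into the output slice element by element rather than
+        // with memcpy, which would not be safe for types such as strings.)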
+ out.template chip<0>(batch_idx).template chip<0>(indices_idx) = + params.template chip<0>(batch_idx).template chip<0>(index); + } + indices_idx = i_next; + batch_idx = b_next; + } + }; + + Shard(worker_threads->num_threads, worker_threads->workers, + batch_size * indices_size, slice_elems * sizeof(T), work); + return result; +} + +template +struct GatherFunctorCPU { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + const int64_t indices_size = indices.size(); + const int64_t slice_size = out.dimension(2); + int64_t bad_i; + + const int64_t batch_size = params.dimension(0); + + bool use_large = (slice_size > std::numeric_limits::max() || + params.size() > std::numeric_limits::max() || + indices_size > std::numeric_limits::max() || + batch_size * indices_size * slice_size > + std::numeric_limits::max()); +#define CALL(elems) \ + do { \ + if (use_large) { \ + bad_i = HandleCopies(ctx, params, indices, \ + slice_size, out); \ + } else { \ + const int32 small_slice = static_cast(slice_size); \ + bad_i = HandleCopies(ctx, params, indices, \ + small_slice, out); \ + } \ + } while (0) + + if (slice_size == 10) + CALL(10); + else if (slice_size == 20) + CALL(20); + else + CALL(-1); +#undef CALL + + return bad_i; + } +}; + +template +struct GatherFunctor { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out); +}; + +template +struct GatherFunctor { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + return GatherFunctorCPU()(ctx, params, indices, out); + } +}; + +template +struct GatherFunctor { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + return GatherFunctorCPU()(ctx, params, indices, out); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_batched.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_batched.h new file mode 100644 index 00000000..41b809bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_batched.h @@ -0,0 +1,201 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_ +#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_ + +#include "absl/base/prefetch.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +// Helper method to copy using memcpy. +template +SliceIndex HandleCopiesBatched(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + SliceIndex slice_elems, + typename TTypes::Tensor out) { + const SliceIndex batch_size = static_cast(params.dimension(0)); + const SliceIndex outer_size = static_cast(params.dimension(1)); + const SliceIndex indices_size = + static_cast(indices.dimension(0)) / batch_size; + + const Index limit = static_cast(params.dimension(2)); + if (static_slice_elems >= 0) { + // Give compiler static knowledge of the number of elements/bytes + slice_elems = static_slice_elems; + } + // Compute slice_bytes here so that static knowledge is available + const size_t slice_bytes = slice_elems * sizeof(T); + auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + mutex mu; + // Store the value of invalidate index for printing error information, it's a + // shared variable. + SliceIndex result = -1; + auto work = [&](int64_t start, int64_t end) { + const int64_t r_start = start % (outer_size * indices_size); + SliceIndex batch_idx = static_cast( + start / (outer_size * indices_size)); + SliceIndex outer_idx = static_cast(r_start / indices_size); + SliceIndex indices_idx = static_cast(r_start % indices_size); + + SliceIndex batch_offset = batch_idx * indices_size; + for (; start < end; ++start) { + SliceIndex i_next = indices_idx + 1; + SliceIndex o_next = outer_idx; + SliceIndex b_next = batch_idx; + SliceIndex b_offset_next = batch_offset; + + if (i_next >= indices_size) { + i_next = 0; + if (++o_next >= outer_size) { + o_next = 0; + ++b_next; + b_offset_next += indices_size; + } + } + if (start + 1 < end) { + absl::PrefetchToLocalCache( + ¶ms(b_next, o_next, indices(b_offset_next + i_next), 0)); + absl::PrefetchToLocalCache(&out(b_next, o_next, i_next, 0)); + } + const Index index = internal::SubtleMustCopy( + indices(batch_offset + indices_idx)); + if (!FastBoundsCheck(index, limit)) { + mutex_lock l(mu); + result = batch_offset + indices_idx; + return; + } + + // Copy using memcpy if possible, otherwise an Eigen loop + // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve + // ahead-of-time compilation binary size). + if (is_simple_type::value) { + // Avoid auto-promotion to Index from SliceIndex by casting. + memcpy( + &out(batch_idx, outer_idx, indices_idx, 0), + ¶ms(batch_idx, outer_idx, static_cast(index), 0), + slice_bytes); + } else { + // For non-"simple" types (e.g. strings). 
+ out.template chip<0>(batch_idx) + .template chip<0>(outer_idx) + .template chip<0>(indices_idx) = + params.template chip<0>(batch_idx) + .template chip<0>(outer_idx) + .template chip<0>(static_cast(index)); + } + + indices_idx = i_next; + outer_idx = o_next; + batch_idx = b_next; + batch_offset = b_offset_next; + } + }; + + Shard(worker_threads->num_threads, worker_threads->workers, + batch_size * outer_size * indices_size, slice_elems * sizeof(T), work); + return result; +} + +template +struct GatherFunctorBatchedCPU { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + const int64_t indices_size = indices.size(); // Includes the batch_size. + const int64_t slice_size = out.dimension(3); + int64_t bad_i; + + const int64_t batch_size = params.dimension(0); + const int64_t outer_size = params.dimension(1); + + bool use_large = (slice_size > std::numeric_limits::max() || + params.size() > std::numeric_limits::max() || + indices_size > std::numeric_limits::max() || + batch_size * outer_size * indices_size * slice_size > + std::numeric_limits::max()); +#define CALL(elems) \ + do { \ + if (use_large) { \ + bad_i = HandleCopiesBatched( \ + ctx, params, indices, slice_size, out); \ + } else { \ + const int32 small_slice = static_cast(slice_size); \ + bad_i = HandleCopiesBatched( \ + ctx, params, indices, small_slice, out); \ + } \ + } while (0) + + // TODO(rmlarsen): Investigate whether these specializations are still + // needed and, if yes, whether the slice sizes are appropriate. + if (slice_size == 10) + CALL(10); + else if (slice_size == 20) + CALL(20); + else + CALL(-1); +#undef CALL + + return bad_i; + } +}; + +template +struct GatherFunctorBatched { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out); +}; + +template +struct GatherFunctorBatched { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + return GatherFunctorBatchedCPU()(ctx, params, indices, out); + } +}; + +template +struct GatherFunctorBatched { + int64_t operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + return GatherFunctorBatchedCPU()(ctx, params, indices, out); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h new file mode 100644 index 00000000..e2cb7597 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_batched_gpu.cu.h @@ -0,0 +1,183 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/gather_functor_batched.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +template +__global__ void GatherOpKernel(const ValueOrVec* __restrict__ params, + const Index* __restrict__ indices, + ValueOrVec* __restrict__ out, int64 outer_size, + int64 gather_dim_size, int64 indices_size, + int64 slice_size, int64 out_size) { + // params is a tensor of shape + // [batch_size, outer_size, gather_dim_size, slice_size]. + GPU_1D_KERNEL_LOOP(i, out_size) { + Index batch_i = 0; // The batch index into params to use for i. + Index outer_i = 0; // The outer index into params to use for i. + Index indices_i = 0; // The index into indices to use for i. + Index slice_i = 0; // Index into the current slice in params to use for i. + + const Index slices_count = i / slice_size; + if (is_batch_dims_zero) { + if (is_axis_zero) { + indices_i = slices_count; + } else { + outer_i = slices_count / indices_size; + indices_i = slices_count - outer_i * indices_size; + } + } else { + const Index entries_count = slices_count / indices_size; + if (is_axis_zero) { + batch_i = entries_count; + } else { + batch_i = entries_count / outer_size; + outer_i = entries_count - batch_i * outer_size; + } + indices_i = slices_count - entries_count * indices_size; + } + slice_i = i - slices_count * slice_size; + + // Index into the gather axis to use for i. + Index gather_i = ldg(indices + batch_i * indices_size + indices_i); + + // Check gather_i is in [0, gather_dim_size). + if (!FastBoundsCheck(gather_i, gather_dim_size)) { + // Set indices out of range to zero + // TODO(fpmc): Log an error for transfer back to host. + out[i] = ValueOrVec(0); + } else { + // Read params[batch_i, outer_i, gather_i, slice_i] and write it to the + // i'th position in out. 
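+      // (Illustrative example: params_i below is the row-major flattening of
+      // (batch_i, outer_i, gather_i, slice_i) over a
+      // [batch_size, outer_size, gather_dim_size, slice_size] layout; e.g.
+      // with outer_size = 2, gather_dim_size = 4 and slice_size = 3, the
+      // element (1, 0, 2, 1) lands at ((1 * 2 + 0) * 4 + 2) * 3 + 1 = 31.)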
+ Index params_i = ( + (batch_i * outer_size + outer_i) * gather_dim_size + gather_i + ) * slice_size + slice_i; + out[i] = params[params_i]; + } + } +} + +namespace detail { + +template +struct LaunchGatherKernelVectorized { + template + struct Impl { + template + Status operator()(const GPUDevice& d, const T* params, const Index* indices, + T* out, int64 outer_size, int64 gather_dim_size, + int64 indices_size, int64 slice_size, int64 out_size) { + DCHECK_EQ(slice_size % vec_size, 0); + DCHECK_EQ(out_size % vec_size, 0); + DCHECK_EQ(reinterpret_cast(params) % vec_size, 0); + DCHECK_EQ(reinterpret_cast(out) % vec_size, 0); + int64 out_size_vec = out_size / vec_size; + int64 slice_size_vec = slice_size / vec_size; + using Tvec = AlignedVector; + const Tvec* params_vec = reinterpret_cast(params); + Tvec* out_vec = reinterpret_cast(out); + + GpuLaunchConfig config = GetGpuLaunchConfig( + out_size_vec, d, + &GatherOpKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel( + GatherOpKernel, + config.block_count, config.thread_per_block, 0, d.stream(), + params_vec, indices, out_vec, outer_size, gather_dim_size, + indices_size, slice_size_vec, out_size_vec); + } + }; +}; + +} // namespace detail + +template +Status LaunchGatherKernel(const GPUDevice& d, const T* params, + const Index* indices, T* out, int64 outer_size, + int64 gather_dim_size, int64 indices_size, + int64 slice_size, int64 out_size) { + // Note that the GPU memory allocator always returns aligned buffers, so the + // alignment of data pointers is expected to be deterministic. + // There will be performance cliffs when slice_size is not aligned, but there + // is no easy way to handle the misalignment because each row will be aligned + // differently. + return DispatchToVectorized< + T, detail::LaunchGatherKernelVectorized< + is_axis_zero, is_batch_dims_zero>::template Impl>( + MinAlignmentOf(params, out, slice_size), d, params, indices, out, + outer_size, gather_dim_size, indices_size, slice_size, out_size); +} + +namespace functor { +template +struct GatherFunctorBatched { + int64 operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + const GPUDevice& d = ctx->eigen_gpu_device(); + const int64 out_size = out.size(); + if (out_size == 0) { + // We need a check here since the CPU version does useful error checking + // work if there are nonempty indices but empty slices, so the kernel is + // executed in that case. In the GPU case we don't know how to do error + // checking, so we skip the loop entirely. + return -1; + } + const bool is_batch_dims_zero = params.dimension(0) == 1; + const bool is_axis_zero = params.dimension(1) == 1; + const int64 outer_size = params.dimension(1); + const int64 gather_dim_size = params.dimension(2); + const int64 indices_size = indices.size() / params.dimension(0); + const int64 slice_size = params.dimension(3); + + const auto function = + is_axis_zero + ? (is_batch_dims_zero ? LaunchGatherKernel + : LaunchGatherKernel) + : (is_batch_dims_zero ? LaunchGatherKernel + : LaunchGatherKernel); + TF_CHECK_OK(function(d, params.data(), indices.data(), out.data(), + outer_size, gather_dim_size, indices_size, slice_size, + out_size)); + // TODO(fpmc): enable indices validation on GPU. + // Right now checking for indices out of bound in the kernel would + // require copying code between GPU/CPU, and thus slow. 
+ return -1; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_GPU_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_gpu.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_gpu.cu.h new file mode 100644 index 00000000..3ac0d912 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_functor_gpu.cu.h @@ -0,0 +1,165 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/gather_functor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +template +__global__ void GatherOpKernel(const ValueOrVec* __restrict__ params, + const Index* __restrict__ indices, + ValueOrVec* __restrict__ out, + int64 gather_dim_size, int64 indices_size, + int64 slice_size, int64 out_size) { + GPU_1D_KERNEL_LOOP(i, out_size) { + Index batch_i = 0; + Index indices_i = 0; + Index slice_i = 0; + if (is_axis_zero) { + indices_i = i / slice_size; + slice_i = i - indices_i * slice_size; + } else { + Index batch_indices_i = i / slice_size; + // The batch index into params to use for i. + batch_i = batch_indices_i / indices_size; + // The index into indices to use for i. + indices_i = batch_indices_i - batch_i * indices_size; + // Index into the current slice in params to use for i. + slice_i = i - batch_indices_i * slice_size; + } + + // Index into the gather axis to use for i. + Index gather_i = ldg(indices + indices_i); + + // Check gather_i is in [0, gather_dim_size). + if (!FastBoundsCheck(gather_i, gather_dim_size)) { + // Set indices out of range to zero + // TODO(fpmc): Log an error for transfer back to host. + out[i] = ValueOrVec(0); + } else { + // params is a [batch_size, gather_dim_size, slice_size] tensor. Read + // params[batch_i, gather_i, slice_i] and write it to the i'th position in + // out. 
+ Index params_i = + (batch_i * gather_dim_size + gather_i) * slice_size + slice_i; + out[i] = params[params_i]; + } + } +} + +namespace detail { + +template +struct LaunchGatherKernelVectorized { + template + struct Impl { + template + Status operator()(const GPUDevice& d, const T* params, const Index* indices, + T* out, int64 gather_dim_size, int64 indices_size, + int64 slice_size, int64 out_size) { + DCHECK_EQ(slice_size % vec_size, 0); + DCHECK_EQ(out_size % vec_size, 0); + DCHECK_EQ(reinterpret_cast(params) % vec_size, 0); + DCHECK_EQ(reinterpret_cast(out) % vec_size, 0); + int64 out_size_vec = out_size / vec_size; + int64 slice_size_vec = slice_size / vec_size; + using Tvec = AlignedVector; + const Tvec* params_vec = reinterpret_cast(params); + Tvec* out_vec = reinterpret_cast(out); + + GpuLaunchConfig config = GetGpuLaunchConfig( + out_size_vec, d, &GatherOpKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel( + GatherOpKernel, config.block_count, + config.thread_per_block, 0, d.stream(), params_vec, indices, out_vec, + gather_dim_size, indices_size, slice_size_vec, out_size_vec); + } + }; +}; + +} // namespace detail + +template +Status LaunchGatherKernel(const GPUDevice& d, const T* params, + const Index* indices, T* out, int64 gather_dim_size, + int64 indices_size, int64 slice_size, + int64 out_size) { + // Note that the GPU memory allocator always returns aligned buffers, so the + // alignment of data pointers is expected to be deterministic. + // There will be performance cliffs when slice_size is not aligned, but there + // is no easy way to handle the misalignment because each row will be aligned + // differently. + return DispatchToVectorized< + T, detail::LaunchGatherKernelVectorized::template Impl>( + MinAlignmentOf(params, out, slice_size), d, params, indices, out, + gather_dim_size, indices_size, slice_size, out_size); +} + +namespace functor { +template +struct GatherFunctor { + int64 operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor params, + typename TTypes::ConstFlat indices, + typename TTypes::Tensor out) { + const GPUDevice& d = ctx->eigen_gpu_device(); + const int64 out_size = out.size(); + if (out_size == 0) { + // We need a check here since the CPU version does useful error checking + // work if there are nonempty indices but empty slices, so the kernel is + // executed in that case. In the GPU case we don't know how to do error + // checking, so we skip the loop entirely. + return -1; + } + const bool is_axis_zero = params.dimension(0) == 1; + const int64 gather_dim_size = params.dimension(1); + const int64 indices_size = indices.size(); + const int64 slice_size = params.dimension(2); + + if (is_axis_zero) { + TF_CHECK_OK(LaunchGatherKernel(d, params.data(), indices.data(), + out.data(), gather_dim_size, + indices_size, slice_size, out_size)); + } else { + TF_CHECK_OK(LaunchGatherKernel( + d, params.data(), indices.data(), out.data(), gather_dim_size, + indices_size, slice_size, out_size)); + } + // TODO(fpmc): enable indices validation on GPU. + // Right now checking for indices out of bound in the kernel would + // require copying code between GPU/CPU, and thus slow. 
+ return -1; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_GPU_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gather_nd_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_nd_op.h new file mode 100644 index 00000000..b53e1348 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_nd_op.h @@ -0,0 +1,179 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_ +#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_ +// Functor definition for GatherOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/bad_indices_policy.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { +class OpKernelContext; +class Tensor; + +namespace functor { + +template +struct GatherNdSlice { + // Performs a slice gather op on (Tparams, Tindices), writing to Tout. + // Returns an index to Tindices if the value at that index is out of range. + // Returns -1 if all values of Tindices are in range. + Index operator()(const Device& d, const Index slice_size, + typename TTypes::Scalar Tscratch, + typename TTypes::ConstTensor Tparams, + typename TTypes::ConstMatrix Tindices, + typename TTypes::Matrix Tout); +}; + +template +absl::Status DoGatherNd( + OpKernelContext* c, const Tensor& params, const Tensor& indices, + Tensor* out, + BadIndicesPolicy bad_indices_policy = BadIndicesPolicy::kDefault) { + if (!TensorShapeUtils::IsVectorOrHigher(params.shape())) { + return errors::InvalidArgument("params must be at least a vector"); + } + if (!TensorShapeUtils::IsVectorOrHigher(indices.shape())) { + return errors::InvalidArgument("indices must be at least a vector"); + } + if (indices.dim_size(indices.dims() - 1) > params.dims()) { + return errors::InvalidArgument( + "index innermost dimension length must be <= params rank; saw: ", + indices.dim_size(indices.dims() - 1), " vs. 
", params.dims()); + } + + const TensorShape& indices_shape(indices.shape()); + const int64_t indices_nd = indices_shape.dim_size(indices_shape.dims() - 1); + + // Check that we have enough index space + int64_t N_big = 1; + for (int i = 0; i < indices_shape.dims() - 1; ++i) { + N_big *= indices_shape.dim_size(i); + } + if (N_big > std::numeric_limits::max()) { + return errors::InvalidArgument( + "indices has too many elements for int indexing: ", N_big, " > ", + std::numeric_limits::max()); + } + if (params.NumElements() > std::numeric_limits::max()) { + return errors::InvalidArgument("params.NumElements() too large for ", + DataTypeString(DataTypeToEnum::v()), + " indexing: ", params.NumElements(), " > ", + std::numeric_limits::max()); + } + + // The result shape is + // indices.shape[:-1] + params.shape[indices.shape[-1]:] + Index N_result = 1; + for (int i = 0; i < indices_shape.dims() - 1; ++i) { + N_result *= indices_shape.dim_size(i); + } + + const TensorShape& params_shape(params.shape()); + Index total_nd = params_shape.dims(); + + TensorShape result_shape(indices_shape); + result_shape.RemoveLastDims(1); + + int64_t slice_size_big = 1; + for (Index i = indices_nd; i < total_nd; ++i) { + slice_size_big *= params_shape.dim_size(i); + TF_RETURN_IF_ERROR(result_shape.AddDimWithStatus(params_shape.dim_size(i))); + } + + if (slice_size_big > std::numeric_limits::max()) { + return errors::InvalidArgument( + "slice size is too large for indexing: ", slice_size_big, " > ", + std::numeric_limits::max()); + } + + const Index slice_size = static_cast(slice_size_big); + + TF_RETURN_IF_ERROR( + c->allocate_temp(DataTypeToEnum::value, result_shape, out)); + + if (N_result > 0) { + if (params_shape.num_elements() == 0) { + return errors::InvalidArgument( + "Requested more than 0 entries, but " + "params is empty. Params shape: ", + params_shape.DebugString()); + } + + auto indices_mat = indices.flat_inner_dims(); + + Index bad_i = -1; + + // Request to copy slices / subtensors + // Make out a matrix with the slices the col size. + auto out_mat = out->shaped({N_result, slice_size}); + Tensor scratch; + TF_RETURN_IF_ERROR(c->allocate_temp(DT_INT32, TensorShape(), &scratch)); + auto scratch_scalar = scratch.scalar(); + + switch (indices_nd) { +#define PARAMS_CASE(IXDIM) \ + case IXDIM: { \ + functor::GatherNdSlice func; \ + auto params_flat = params.flat_outer_dims(); \ + bad_i = func(c->eigen_device(), slice_size, scratch_scalar, \ + params_flat, indices_mat, out_mat); \ + } break + PARAMS_CASE(0); + PARAMS_CASE(1); + PARAMS_CASE(2); + PARAMS_CASE(3); + PARAMS_CASE(4); + PARAMS_CASE(5); + PARAMS_CASE(6); + PARAMS_CASE(7); +#undef PARAMS_CASE + default: + return errors::InvalidArgument( + "Only indices.shape[-1] values between 1 and 7 " + "are currently supported. 
Requested rank: ", + indices_nd); + } + using CPUDevice = Eigen::ThreadPoolDevice; + + const bool check_bad_indices = + ((std::is_same::value && + bad_indices_policy == BadIndicesPolicy::kDefault) || + bad_indices_policy == BadIndicesPolicy::kError); + if (check_bad_indices && bad_i >= 0) { + auto shape = indices.shape(); + shape.RemoveLastDims(1); + return errors::InvalidArgument( + "indices", SliceDebugString(shape, bad_i), " = [", + str_util::Join( + gtl::ArraySlice(&indices_mat(bad_i, 0), indices_nd), ", "), + "] does not index into param shape ", params.shape().DebugString(), + ", node name: ", c->op_kernel().name()); + } + } + return absl::OkStatus(); +} + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_nd_op_cpu_impl.h new file mode 100644 index 00000000..524f303e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -0,0 +1,149 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ + +// Specialization of GatherNdSlice to CPU + +#define EIGEN_USE_THREADS + +#include + +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/gather_nd_op.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +namespace generator { + +template +class GatherNdSliceGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator( + const Index slice_size, typename TTypes::ConstMatrix Tindices, + typename TTypes::ConstTensor Tparams, + typename TTypes::Matrix Tout, std::atomic* error_loc) + : slice_size_(slice_size), + Tindices_(Tindices), + Tparams_(Tparams), + Tout_(Tout), + error_loc_(error_loc) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices( + const Index loc, Eigen::array* ix) const { + (*ix)[IXDIM] = 0; + bool out_of_bounds = false; + for (int i = 0; i < IXDIM; ++i) { + const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); + (*ix)[i] = ix_i; + out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); + } + return out_of_bounds; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 + operator()(const Eigen::array& loc_array) const { + const Index loc = loc_array[0]; + Eigen::array ix; + Eigen::array ix_out; + ix_out[0] = loc; + ix_out[1] = 0; + const bool out_of_bounds = GenerateIndices(loc, &ix); + if 
(TF_PREDICT_FALSE(out_of_bounds)) { + error_loc_->store(loc); + std::fill_n(&Tout_(ix_out), slice_size_, T()); + } else { + std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out)); + } + + return static_cast(0); // Return something... + } + + private: + const Index slice_size_; + const typename TTypes::ConstMatrix Tindices_; + const typename TTypes::ConstTensor Tparams_; + mutable typename TTypes::Matrix Tout_; + std::atomic* error_loc_; +}; + +} // namespace generator + +namespace functor { + +template +struct GatherNdSlice { + Index operator()(const CPUDevice& d, const Index slice_size, + typename TTypes::Scalar Tscratch, + typename TTypes::ConstTensor Tparams, + typename TTypes::ConstMatrix Tindices, + typename TTypes::Matrix Tout) { + std::atomic error_loc(-1); + const Eigen::Index batch_size = Tindices.dimension(0); + generator::GatherNdSliceGenerator gather_nd_generator( + slice_size, Tindices, Tparams, Tout, &error_loc); + + auto compute_shard = [&](Eigen::Index begin, Eigen::Index end) { + for (Eigen::Index i = begin; i < end; ++i) { + const Eigen::array loc{i}; + gather_nd_generator(loc); + } + }; + Eigen::Index bytes_moved = sizeof(T) * (slice_size + IXDIM); + auto cost = Eigen::TensorOpCost(bytes_moved /* bytes loaded */, + bytes_moved /* bytes stored */, + slice_size + IXDIM /* compute cycles */); + d.parallelFor(batch_size, cost, compute_shard); + + // error_loc() returns -1 if there's no out-of-bounds index, + // otherwise it returns the location of an OOB index in Tindices. + return error_loc.load(); + } +}; + +#define REGISTER_GATHER_ND_FULL(T, Index) \ + template Index \ + GatherNdSlice::operator()( \ + const CPUDevice& d, const Index slice_size, \ + typename TTypes::Scalar Tscratch, \ + typename TTypes::ConstTensor Tparams, \ + typename TTypes::ConstMatrix Tindices, \ + typename TTypes::Matrix Tout); + +#define REGISTER_GATHER_ND_CPU(type) \ + REGISTER_GATHER_ND_FULL(type, int16); \ + REGISTER_GATHER_ND_FULL(type, int32); \ + REGISTER_GATHER_ND_FULL(type, int64) + +TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); +TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU); +TF_CALL_float8_e5m2(REGISTER_GATHER_ND_CPU); +TF_CALL_float8_e4m3fn(REGISTER_GATHER_ND_CPU); + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gemm_functors.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gemm_functors.h new file mode 100644 index 00000000..8039353e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gemm_functors.h @@ -0,0 +1,153 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a set of different implementations for the basic matrix by matrix +// multiply function, commonly known as GEMM after the BLAS library's naming. 
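GatherNdSlice above shards the batch dimension with Eigen's ThreadPoolDevice::parallelFor, handing it a per-item TensorOpCost so the pool can pick a sensible shard size. A self-contained sketch of that pattern, with an illustrative thread count and per-element workload:

#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>
#include <unsupported/Eigen/CXX11/ThreadPool>
#include <vector>

int main() {
  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice device(&pool, 4);
  std::vector<float> out(1 << 16, 0.f);
  // Per-item cost estimate (bytes loaded, bytes stored, compute cycles); the
  // device uses it to decide how finely to shard the loop.
  const Eigen::TensorOpCost cost(sizeof(float), sizeof(float), /*cycles=*/8);
  device.parallelFor(static_cast<Eigen::Index>(out.size()), cost,
                     [&out](Eigen::Index begin, Eigen::Index end) {
                       for (Eigen::Index i = begin; i < end; ++i) {
                         out[i] = static_cast<float>(i) * 0.5f;
                       }
                     });
  return 0;
}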
+// Having a standard interface enables us to swap out implementations on +// different platforms, to make sure we're using the optimal version. They are +// implemented as C++ template functors, so they're easy to swap into all of the +// different kernels that use them. + +#if !defined(EIGEN_USE_THREADS) +#error "EIGEN_USE_THREADS must be enabled by all .cc files including this." +#endif // EIGEN_USE_THREADS + +#ifndef TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_ +#define TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_ + +#include + +#include +#include + +#include "tensorflow/core/common_runtime/threadpool_device.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +// Apple provides an optimized BLAS library that is better than Eigen for their +// devices, so use that if possible. +#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV) +#include +#define USE_CBLAS_GEMM +#endif // __APPLE__ + +// Older Raspberry Pi systems don't have NEON SIMD acceleration, so Eigen falls +// back to scalar code, but OpenBLAS has much faster support so prefer that. +#if defined(RASPBERRY_PI) && defined(USE_GEMM_FOR_CONV) && defined(USE_OPENBLAS) +#include +#define USE_CBLAS_GEMM +#endif + +// A readable but slow implementation of matrix multiplication, useful for +// debugging and understanding the algorithm. Use instead of FastGemmFunctor in +// the Im2ColConvFunctor template definition inside the op registration to +// enable. Assumes row-major ordering of the values in memory. +template +class ReferenceGemmFunctor { + public: + void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, + size_t k, const T1* a, size_t lda, const T2* b, size_t ldb, + T3* c, size_t ldc) { + const size_t a_i_stride = lda; + const size_t a_l_stride = 1; + const size_t b_j_stride = 1; + const size_t b_l_stride = ldb; + const size_t c_i_stride = ldc; + const size_t c_j_stride = 1; + size_t i, j, l; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + T3 total(0); + for (l = 0; l < k; l++) { + const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); + const T1 a_value = a[a_index]; + const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); + const T2 b_value = b[b_index]; + total += (a_value * b_value); + } + const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); + c[c_index] = total; + } + } + } +}; + +// Uses the optimized EigenTensor library to implement the matrix multiplication +// required by the Im2ColConvFunctor class. We supply the two input and one +// output types so that the accumulator can potentially be higher-precision than +// the inputs, even though we don't currently take advantage of this. +template +class FastGemmFunctor { + public: + void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, + size_t k, const T1* a, size_t lda, const T2* b, size_t ldb, + T3* c, size_t ldc) { + typename tensorflow::TTypes::Matrix a_matrix(a, m, k); + typename tensorflow::TTypes::Matrix b_matrix(b, k, n); + typename tensorflow::TTypes::Matrix c_matrix(c, m, n); + + Eigen::array, 1> dim_pair; + dim_pair[0].first = 1; + dim_pair[0].second = 0; + c_matrix.device(ctx->eigen_device()) = + a_matrix.contract(b_matrix, dim_pair); + } +}; + +// Use float32 accumulation for bfloat16 to deal with precision accumulation +// issues. 
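The bfloat16 specialization below exists because bfloat16 keeps only about 8 bits of significand (2 to 3 decimal digits), so accumulating a long dot product directly in bfloat16 stalls once the running sum outgrows the increments. A self-contained demonstration of the effect; the rounding helper truncates rather than rounds, purely for brevity:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Snap a float to a bfloat16-representable value by keeping the top 16 bits.
static float ToBF16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // sign, 8 exponent bits, top 7 mantissa bits
  std::memcpy(&x, &bits, sizeof(bits));
  return x;
}

int main() {
  float acc_bf16 = 0.f, acc_f32 = 0.f;
  for (int i = 0; i < 4096; ++i) {
    acc_bf16 = ToBF16(acc_bf16 + ToBF16(0.1f));  // accumulate in bfloat16
    acc_f32 += 0.1f;                             // accumulate in float32
  }
  // The bfloat16 accumulator stops growing long before the true sum of ~409.6,
  // while the float32 accumulator stays close to it.
  std::printf("bf16 acc = %.1f, f32 acc = %.1f\n", acc_bf16, acc_f32);
  return 0;
}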
+template <> +class FastGemmFunctor { + public: + void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, + size_t k, const Eigen::bfloat16* a, size_t lda, + const Eigen::bfloat16* b, size_t ldb, Eigen::bfloat16* c, + size_t ldc) { + using ConstMatrix = + typename tensorflow::TTypes::Matrix; + ConstMatrix a_matrix(a, m, k); + ConstMatrix b_matrix(b, k, n); + typename tensorflow::TTypes::Matrix c_matrix(c, m, n); + + Eigen::array, 1> dim_pair; + dim_pair[0].first = 1; + dim_pair[0].second = 0; + c_matrix.device(ctx->eigen_device()) = + a_matrix.cast() + .contract(b_matrix.cast(), dim_pair) + .template cast(); + } +}; + +// If we have a fast CBLAS library, use its implementation through a wrapper. +#if defined(USE_CBLAS_GEMM) +template <> +class FastGemmFunctor { + public: + void operator()(tensorflow::OpKernelContext* ctx, size_t m, size_t n, + size_t k, const float* a, size_t lda, const float* b, + size_t ldb, float* c, size_t ldc) { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a, + lda, b, ldb, 0.0f, c, ldc); + } +}; +#endif // USE_CBLAS_GEMM + +#endif // TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_device_array.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_device_array.h new file mode 100644 index 00000000..be0bd0e8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_device_array.h @@ -0,0 +1,125 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_ + +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) + +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_reference.h" +#include "tensorflow/core/kernels/gpu_device_array_gpu.h" + +namespace tensorflow { + +// Create an array of value on the host, to be sent to kernel using +// GpuDeviceArrayStruct. +// +// Usage: +// int size = ...; +// GpuDeviceArrayOnHost ptrs(context, size); +// OP_REQUIRES_OK(ptrs.Init()); +// for (int i = 0; i < size; ++i) { +// ptrs.Set(i, ...); +// } +// OP_REQUIRES_OK(ptrs.Finalize()); +// launchKernel(..., ptrs.data, ...); +// +// ValueType must be memcopyable. +template +class GpuDeviceArrayOnHost { + public: + GpuDeviceArrayOnHost(OpKernelContext* context, int32_t size) + : context_(context), + total_bytes_(static_cast(size) * sizeof(ValueType)) { + data_.size = size; + } + + Status Init() { + if (inlined()) { + values_ = data_.inline_values; + return OkStatus(); + } + + // Out-of-line: allocate data that will be memcopied. 
+ AllocatorAttributes attr; + attr.set_on_host(true); + attr.set_gpu_compatible(true); + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_INT8, TensorShape{total_bytes_}, + &out_of_line_values_on_host_, attr)); + values_ = reinterpret_cast( + out_of_line_values_on_host_.flat().data()); + return OkStatus(); + } + + void Set(int index, ValueType val) { + DCHECK(values_); // ensure Init was called. + DCHECK_LT(index, data_.size); + *(values_ + index) = val; + } + + Status Finalize() { + if (inlined()) { + return OkStatus(); + } + + // Out-of-line - copy pointers to device. + auto stream = context_->op_device_context()->stream(); + TensorReference tensor_ref(out_of_line_values_on_host_); + TF_RETURN_IF_ERROR(context_->allocate_temp( + DT_INT8, TensorShape{total_bytes_}, &out_of_line_values_on_gpu_)); + se::DeviceMemoryBase output_values_base{ + out_of_line_values_on_gpu_.flat().data(), + static_cast(total_bytes_)}; + TF_RETURN_IF_ERROR(stream->Memcpy( + &output_values_base, out_of_line_values_on_host_.flat().data(), + total_bytes_)); + context_->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, + [tensor_ref]() { tensor_ref.Unref(); }); + data_.out_of_line_values = reinterpret_cast( + out_of_line_values_on_gpu_.flat().data()); + return OkStatus(); + } + + const GpuDeviceArrayStruct& data() const { + // Ensure Finalize is called. + DCHECK(inlined() || out_of_line_values_on_gpu_.IsInitialized()); + return data_; + } + + private: + bool inlined() const { return data_.size <= MaxInlineValues; } + + OpKernelContext* const context_; + const int64_t total_bytes_; // total size of all pointers. + ValueType* values_ = nullptr; + GpuDeviceArrayStruct data_; + + Tensor out_of_line_values_on_host_; + Tensor out_of_line_values_on_gpu_; + + GpuDeviceArrayOnHost(const GpuDeviceArrayOnHost&) = delete; + void operator=(const GpuDeviceArrayOnHost&) = delete; +}; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_device_array_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_device_array_gpu.h new file mode 100644 index 00000000..15a09e3d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_device_array_gpu.h @@ -0,0 +1,50 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains structs and functions to be included in device code. + +#ifndef TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_ + +#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ + (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) + +namespace tensorflow { + +// To decode on the device side, use GetGpuDeviceArrayOnDevice. +// To encode on the host side, use GpuDeviceArrayOnHost. 
+template <typename ValueType, int MaxInlineValues = 8> +struct GpuDeviceArrayStruct { + int32 size; + // used if size <= MaxInlineValues; + ValueType inline_values[MaxInlineValues]; + ValueType* out_of_line_values = nullptr; // used if size > MaxInlineValues; +}; + +template <typename ValueType, int MaxInlineValues> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetGpuDeviceArrayOnDevice( + GpuDeviceArrayStruct<ValueType, MaxInlineValues>* data) { + if (data->size <= MaxInlineValues) { + return data->inline_values; + } else { + return data->out_of_line_values; + } +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_prim.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_prim.h new file mode 100644 index 00000000..bef22b50 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_prim.h @@ -0,0 +1,117 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ + +#include "tensorflow/core/platform/bfloat16.h" + +#if GOOGLE_CUDA +#include "cub/block/block_load.cuh" +#include "cub/block/block_scan.cuh" +#include "cub/block/block_store.cuh" +#include "cub/device/device_histogram.cuh" +#include "cub/device/device_radix_sort.cuh" +#include "cub/device/device_reduce.cuh" +#include "cub/device/device_scan.cuh" +#include "cub/device/device_segmented_radix_sort.cuh" +#include "cub/device/device_segmented_reduce.cuh" +#include "cub/device/device_select.cuh" +#include "cub/iterator/counting_input_iterator.cuh" +#include "cub/iterator/transform_input_iterator.cuh" +#include "cub/thread/thread_operators.cuh" +#include "cub/warp/warp_reduce.cuh" +#include "third_party/gpus/cuda/include/cusparse.h" + +namespace gpuprim = ::cub; + +// Required for sorting Eigen::half and bfloat16. 
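The specializations that follow only need to treat Eigen::half and Eigen::bfloat16 as raw 16-bit storage: half is 1 sign, 5 exponent and 10 mantissa bits (masks 0x8000, 0x7C00, 0x03FF), while bfloat16 is 1 sign, 8 exponent and 7 mantissa bits (0x8000, 0x7F80, 0x007F), i.e. the upper half of a float32. A small standalone illustration of the bfloat16 layout, using a truncating conversion for brevity:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float f = 3.14159f;
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));                     // type-pun safely
  const uint16_t bf16 = static_cast<uint16_t>(bits >> 16);  // keep top 16 bits
  std::printf("sign=%u exponent=0x%02x mantissa=0x%02x\n",
              (bf16 & 0x8000u) >> 15, (bf16 & 0x7F80u) >> 7, bf16 & 0x007Fu);
  return 0;
}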
+namespace cub { +template <> +__device__ __forceinline__ void ThreadStoreVolatilePtr( + Eigen::half *ptr, Eigen::half val, Int2Type /*is_primitive*/) { + *reinterpret_cast(ptr) = + Eigen::numext::bit_cast(val); +} + +template <> +__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( + Eigen::half *ptr, Int2Type /*is_primitive*/) { + uint16_t result = *reinterpret_cast(ptr); + return Eigen::numext::bit_cast(result); +} + +template <> +__device__ __forceinline__ void ThreadStoreVolatilePtr( + Eigen::bfloat16 *ptr, Eigen::bfloat16 val, + Int2Type /*is_primitive*/) { + *reinterpret_cast(ptr) = + Eigen::numext::bit_cast(val); +} + +template <> +__device__ __forceinline__ Eigen::bfloat16 +ThreadLoadVolatilePointer(Eigen::bfloat16 *ptr, + Int2Type /*is_primitive*/) { + uint16_t result = *reinterpret_cast(ptr); + return Eigen::numext::bit_cast(result); +} + +template <> +struct NumericTraits + : BaseTraits {}; +template <> +struct NumericTraits + : BaseTraits {}; +} // namespace cub +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/hipcub/hipcub.hpp" +#include "rocm/rocm_config.h" +namespace gpuprim = ::hipcub; + +// Required for sorting Eigen::half and bfloat16. +namespace rocprim { +namespace detail { +#if (TF_ROCM_VERSION >= 50200) +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7C00; + static constexpr uint16_t mantissa = 0x03FF; + using bit_type = uint16_t; +}; + +template <> +struct float_bit_mask { + static constexpr uint16_t sign_bit = 0x8000; + static constexpr uint16_t exponent = 0x7F80; + static constexpr uint16_t mantissa = 0x007F; + using bit_type = uint16_t; +}; +#endif +template <> +struct radix_key_codec_base + : radix_key_codec_floating {}; +template <> +struct radix_key_codec_base + : radix_key_codec_floating {}; +}; // namespace detail +}; // namespace rocprim + +#endif // TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_prim_helpers.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_prim_helpers.h new file mode 100644 index 00000000..52599890 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_prim_helpers.h @@ -0,0 +1,286 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "xla/stream_executor/stream.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +namespace detail { + +template +__global__ void RangeInitKernel(const T start, const T delta, const T size, + T* out) { + GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; } +} + +// Initialize out with range start, start + delta, start + 2 * delta, ... +template +Status RangeInit(const Eigen::GpuDevice& d, const T start, const T delta, + const T size, T* out) { + if (size == 0) return OkStatus(); + GpuLaunchConfig config = GetGpuLaunchConfig(size, d); + return GpuLaunchKernel(RangeInitKernel, config.block_count, + config.thread_per_block, 0, d.stream(), start, delta, + size, out); +} + +// Computes keys_out = sorted(keys_in), and indices_out = argsort(keys_in). +// If keys_out is not required, it can be set to nullptr. +// If indices_in is nullptr, the range of input indices [0, size) will be used. +template +Status GpuRadixSortImpl(OpKernelContext* context, int size, const Tkey* keys_in, + Tkey* keys_out, // Optional + const Tindex* indices_in, // Optional + Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) { + if (size == 0) return OkStatus(); + if (num_bits == 0) { + // Workaround for CUB failing when begin_bit = end_bit = 0 (e.g., when all + // keys are 0, so no sorting is needed). + se::Stream* stream = context->op_device_context()->stream(); + if (keys_out) { + // Copy keys_in to keys_out. + size_t num_bytes = size * sizeof(Tkey); + se::DeviceMemoryBase src(const_cast(keys_in), num_bytes); + se::DeviceMemoryBase dst(keys_out, num_bytes); + TF_RETURN_IF_ERROR(stream->Memcpy(&dst, src, num_bytes)); + } + if (indices_in) { + // Copy indices_in to indices_out. + size_t num_bytes = size * sizeof(Tindex); + se::DeviceMemoryBase src(const_cast(indices_in), num_bytes); + se::DeviceMemoryBase dst(indices_out, num_bytes); + TF_RETURN_IF_ERROR(stream->Memcpy(&dst, src, num_bytes)); + } else { + // Set output indices to range. + const Eigen::GpuDevice& device = + context->eigen_device(); + TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1), + Tindex(size), indices_out)); + } + return OkStatus(); + } + // Allocate temporary inputs/outputs if necessary. + Tensor tmp_indices_in; + if (!indices_in) { + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, TensorShape({size}), &tmp_indices_in)); + Tindex* mutable_indices_in = tmp_indices_in.flat().data(); + indices_in = mutable_indices_in; + const Eigen::GpuDevice& device = context->eigen_device(); + // Initialize indices_in to the input index range. + TF_RETURN_IF_ERROR(detail::RangeInit(device, Tindex(0), Tindex(1), + Tindex(size), mutable_indices_in)); + } + Tensor tmp_keys_out; + if (!keys_out) { + TF_RETURN_IF_ERROR(context->allocate_temp( + DataTypeToEnum::value, TensorShape({size}), &tmp_keys_out)); + keys_out = tmp_keys_out.flat().data(); + } + // Determine temporary device storage requirements. 
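The calls that follow use the usual CUB two-pass workspace idiom: call the algorithm once with a null temporary-storage pointer so it only reports the byte count it needs, allocate that buffer, then call it again with the same arguments to do the real work. A minimal standalone sketch with cub::DeviceRadixSort::SortPairs; the raw cudaMalloc and buffer names are illustrative, whereas the code above allocates the workspace as a DT_INT8 Tensor:

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Sorts n (key, value) pairs that already live in device memory.
cudaError_t SortPairsExample(const float* d_keys_in, float* d_keys_out,
                             const int* d_vals_in, int* d_vals_out, int n) {
  size_t temp_bytes = 0;
  // Pass 1: null workspace pointer, so CUB only computes temp_bytes.
  cub::DeviceRadixSort::SortPairs(nullptr, temp_bytes, d_keys_in, d_keys_out,
                                  d_vals_in, d_vals_out, n);
  void* d_temp = nullptr;
  cudaError_t err = cudaMalloc(&d_temp, temp_bytes);
  if (err != cudaSuccess) return err;
  // Pass 2: same arguments with a real workspace; this launches the sort.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                  d_vals_in, d_vals_out, n);
  return cudaFree(d_temp);
}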
+ Tensor temp_storage; + size_t temp_storage_bytes = 0; + const auto& cu_stream = GetGpuStream(context); + gpuError_t err; + if constexpr (Descending) { + err = gpuprim::DeviceRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, keys_in, keys_out, indices_in, indices_out, + size, /*begin_bit=*/0, /*end_bit=*/num_bits, cu_stream); + } else { + err = gpuprim::DeviceRadixSort::SortPairs( + nullptr, temp_storage_bytes, keys_in, keys_out, indices_in, indices_out, + size, /*begin_bit=*/0, /*end_bit=*/num_bits, cu_stream); + } + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceRadixSort::SortPairs to calculate " + "temp_storage_bytes, status: ", + cudaGetErrorString(err)); + } + // Allocate temporary storage. + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + // Sort indices by keys. + if constexpr (Descending) { + err = gpuprim::DeviceRadixSort::SortPairsDescending( + temp_storage.flat().data(), temp_storage_bytes, keys_in, keys_out, + indices_in, indices_out, size, /*begin_bit=*/0, /*end_bit=*/num_bits, + cu_stream); + } else { + err = gpuprim::DeviceRadixSort::SortPairs( + temp_storage.flat().data(), temp_storage_bytes, keys_in, keys_out, + indices_in, indices_out, size, /*begin_bit=*/0, /*end_bit=*/num_bits, + cu_stream); + } + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceRadixSort::SortPairs, " + "temp_storage_bytes: ", + temp_storage_bytes, "status: ", cudaGetErrorString(err)); + } + return OkStatus(); +} + +} // namespace detail + +template +Status GpuRadixSort(OpKernelContext* context, int size, const Tkey* keys_in, + Tkey* keys_out, // Optional + const Tindex* indices_in, // Optional + Tindex* indices_out, int num_bits = sizeof(Tkey) * 8) { + return detail::GpuRadixSortImpl( + context, size, keys_in, keys_out, indices_in, indices_out, num_bits); +} + +template +Status GpuRadixSortDescending(OpKernelContext* context, int size, + const Tkey* keys_in, + Tkey* keys_out, // Optional + const Tindex* indices_in, // Optional + Tindex* indices_out, + int num_bits = sizeof(Tkey) * 8) { + return detail::GpuRadixSortImpl( + context, size, keys_in, keys_out, indices_in, indices_out, num_bits); +} + +template +Status GpuInclusivePrefixSum(OpKernelContext* context, int size, + InputIteratorT input, OutputIteratorT output) { + static_assert( + !std::is_same::type, + bool>::value, + "GpuInclusivePrefixSum does not work correct with booleans, please use " + "TransformInputIterator to explicitly cast to an integer."); + if (size == 0) return OkStatus(); + const auto& cu_stream = GetGpuStream(context); + size_t temp_storage_bytes; + auto err = gpuprim::DeviceScan::InclusiveSum(nullptr, temp_storage_bytes, + input, output, size, cu_stream); + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceScan::InclusiveSum to calculate " + "temp_storage_bytes, status: ", + cudaGetErrorString(err)); + } + Tensor temp_storage; + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + err = gpuprim::DeviceScan::InclusiveSum(temp_storage.flat().data(), + temp_storage_bytes, input, output, + size, cu_stream); + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceScan::InclusiveSum, " + "temp_storage_bytes: ", + temp_storage_bytes, ", status: ", cudaGetErrorString(err)); + } + return OkStatus(); +} + +// Note that this behaves deterministically for repeat calls on 
the same device. +template +Status GpuSegmentedReduce( + OpKernelContext* context, int num_segments, ReduceOp reduce_op, + const T& initial_value, + InputIteratorT input, // [any] + OffsetIteratorT segment_offsets, // [num_segments + 1] + OutputIteratorT output) { // [num_segments] + if (num_segments == 0) return OkStatus(); + const auto& cu_stream = GetGpuStream(context); + size_t temp_storage_bytes; + auto err = gpuprim::DeviceSegmentedReduce::Reduce( + nullptr, temp_storage_bytes, input, output, num_segments, segment_offsets, + segment_offsets + 1, reduce_op, initial_value, cu_stream); + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce to calculate " + "temp_storage_bytes, status: ", + cudaGetErrorString(err)); + } + Tensor temp_storage; + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + err = gpuprim::DeviceSegmentedReduce::Reduce( + temp_storage.flat().data(), temp_storage_bytes, input, output, + num_segments, segment_offsets, segment_offsets + 1, reduce_op, + initial_value, cu_stream); + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceSegmentedReduce::Reduce" + ", temp_storage_bytes: ", + temp_storage_bytes, ", status: ", cudaGetErrorString(err)); + } + return OkStatus(); +} + +template +Status GpuSelectFlagged(OpKernelContext* context, int size, + InputIteratorT input, FlagIteratorT flags, + OutputIteratorT output, + NumSelectedT* out_num_selected = nullptr) { + const auto& cu_stream = GetGpuStream(context); + Tensor out_num_selected_t; + if (!out_num_selected) { + TF_RETURN_IF_ERROR( + context->allocate_temp(DataTypeToEnum::value, + TensorShape({}), &out_num_selected_t)); + out_num_selected = out_num_selected_t.scalar().data(); + } + size_t temp_storage_bytes; + auto err = + gpuprim::DeviceSelect::Flagged(nullptr, temp_storage_bytes, input, flags, + output, out_num_selected, size, cu_stream); + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceSelect::Flagged to calculate " + "temp_storage_bytes, status: ", + cudaGetErrorString(err)); + } + Tensor temp_storage; + TF_RETURN_IF_ERROR(context->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + err = gpuprim::DeviceSelect::Flagged(temp_storage.flat().data(), + temp_storage_bytes, input, flags, output, + out_num_selected, size, cu_stream); + if (err != 0) { + return errors::Internal( + "Failed to launch gpuprim::DeviceSelect::Flagged, temp_storage_bytes: ", + temp_storage_bytes, ", status: ", cudaGetErrorString(err)); + } + return OkStatus(); +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_utils.h new file mode 100644 index 00000000..8d511859 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/gpu_utils.h @@ -0,0 +1,448 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include + +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/lazy_op_runner.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace stream_executor { +class RedzoneAllocator; +} // namespace stream_executor + +namespace xla { +class AutotuneResult; +} // namespace xla + +namespace tensorflow { + +// Returns true if bfloat16 is directly supported in Ops and inputs shall not be +// casted to floats to perform the computations and then back. +bool IsBF16SupportedInOps(se::Stream* stream); + +class NodeDef; +using xla::AutotuneResult; + +template +se::DeviceMemory AsDeviceMemory(const T* gpu_memory) { + se::DeviceMemoryBase wrapped(const_cast(gpu_memory)); + se::DeviceMemory typed(wrapped); + return typed; +} + +// Return whether the redzone check is disabled. +// +// Controlled by the TF_DISABLE_RZ_CHECK environment variable. +bool RedzoneCheckDisabled(); + +// Return an allocated buffer with redzones the size of `buffer`. Does +// *not* copy the contents of the `buffer` into the newly allocated buffer: +// assumes that buffer is a pure out-parameter. +// +// Returns `buffer` if RedzoneCheckDisabled() is true. +// +// On error, return `buffer`, and log an error message (once). +se::DeviceMemoryBase WrapRedzoneBestEffort(se::RedzoneAllocator* rz_allocator, + se::DeviceMemoryBase buffer); + +// Check the passed allocator for redzone violations. +// If violations have occurred, mark the corresponding autotune result +// as a failure. +void CheckRedzones(const se::RedzoneAllocator& rz_allocator, + AutotuneResult* autotune_result); + +template +inline se::DeviceMemory AsDeviceMemory(const T* cuda_memory, uint64 size) { + se::DeviceMemoryBase wrapped(const_cast(cuda_memory), size * sizeof(T)); + se::DeviceMemory typed(wrapped); + return typed; +} + +// Returns whether cuBLASLt is enabled. +// +// Controlled by the TF_USE_CUBLASLT environment variable. +bool EnableCublasLtGemm(); + +namespace internal { + +template +struct AutotuneMapHasher { + std::size_t operator()(const Parameters& parameter) const { + return parameter.hash(); + } +}; + +} // namespace internal + +// A helper class that looks up the best autotuned config from parameters. +// Due to the noisy nature of autotune, especially with multiple devices, it +// only accepts a config if its margin exceeds a threshold. +// For the same shape configs, if a new best config matches the previous best, +// they get promoted; otherwise, the winner gets demoted. This process stops +// when the winner's score exceeds the threshold. 
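A minimal sketch of that promote/demote bookkeeping in isolation; the string-keyed map, Entry type and threshold handling are illustrative stand-ins, and the real AutotuneMap below additionally tracks an observation count and a global iteration cap:

#include <string>
#include <unordered_map>

struct Entry {
  std::string config;
  int score;
};

// One observation of `config` winning autotuning for `key`.
void Observe(std::unordered_map<std::string, Entry>& best, int threshold,
             const std::string& key, const std::string& config) {
  auto it = best.find(key);
  if (it == best.end()) {
    best.emplace(key, Entry{config, 1});  // first winner starts with one point
  } else if (it->second.score >= threshold) {
    return;                               // winner already settled
  } else if (it->second.config == config) {
    ++it->second.score;                   // same winner again: promote
  } else if (--it->second.score <= 0) {
    best.erase(it);                       // a challenger demoted it to zero
  }
}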
+// In a bad case when two configs are very close to each other and flips +// back and forth randomly, the expected number of experiments before autotune +// settles is O(threshold ^ 2). So we recommend that number of warmup runs +// for any benchmarks. +template > +class AutotuneMap { + public: + bool Find(const Parameters& params, Config* config) const { + mutex_lock lock(mu_); + auto iter = params_config_map_.find(params); + if (iter == params_config_map_.end() || + (iter->second.score < min_score_threshold_ && + iter->second.count <= max_autotune_count_)) { + return false; + } + *config = iter->second.config; + return true; + } + void Insert(const Parameters& params, const Config& config) { + mutex_lock lock(mu_); + auto iter = params_config_map_.find(params); + int new_score = 0; + if (iter == params_config_map_.end()) { + // Create a new entry if params is new. + VLOG(1) << GetActionSummary("creates", params, config); + params_config_map_.insert( + std::make_pair(params, ValueType{config, 1, 1})); + new_score = 1; + } else if (iter->second.score < min_score_threshold_ && + iter->second.count <= max_autotune_count_) { + DCHECK_GT(iter->second.score, 0); + if (iter->second.config != config) { + // If it is different from the current winner, demotes the winner. + VLOG(1) << GetActionSummary("demotes", params, config); + new_score = --iter->second.score; + ++iter->second.count; + if (new_score <= 0) { + VLOG(1) << GetActionSummary("erases", params, config); + params_config_map_.erase(iter); + } + } else { + // If it is the same as the current winner, promotes the winner. + VLOG(1) << GetActionSummary("promotes", params, config); + new_score = ++iter->second.score; + ++iter->second.count; + } + } + if (new_score >= min_score_threshold_) { + VLOG(1) << GetActionSummary("accepts", params, config); + } else if (autotune_global_count_ >= max_autotune_global_count_) { + // The autotuning exceeds the max iteration threshold and we accept the + // the winner if it exists in the map, otherwise we accept the current + // winner. + auto winner = params_config_map_.find(params); + if (winner == params_config_map_.end()) { + VLOG(1) << GetActionSummary("creates", params, config); + for (int i = 0; i < min_score_threshold_; ++i) { + VLOG(1) << GetActionSummary("promotes", params, config); + } + params_config_map_.insert( + std::make_pair(params, ValueType{config, min_score_threshold_, 1})); + } else { + int promotes_times = min_score_threshold_ - winner->second.score; + for (int i = 0; i < promotes_times; ++i) { + VLOG(1) << GetActionSummary("promotes", params, config); + } + winner->second.score = min_score_threshold_; + } + VLOG(1) << GetActionSummary("accepts", params, config); + } + autotune_global_count_++; + } + + std::unordered_map GetMap() const { + mutex_lock lock(mu_); + std::unordered_map map; + for (const auto& entry : params_config_map_) { + map.insert(std::make_pair(entry.first, entry.second.config)); + } + return map; + } + + // Only for testing + void ClearMap() { + mutex_lock lock(mu_); + params_config_map_.clear(); + } + + private: + // Underlying data structure of values in the map. 
+ struct ValueType { + Config config; + int32 score; + int32 count; + }; + AutotuneMap(const std::string& name) : name_(name) { + min_score_threshold_ = 1; + int min_warmup_iterations = 10; + const char* threshold_str = getenv("TF_AUTOTUNE_THRESHOLD"); + if (threshold_str != nullptr) { + VLOG(1) << "TF_AUTOTUNE_THRESHOLD = " << threshold_str; + strings::safe_strto32(threshold_str, &min_score_threshold_); + } + const char* min_warmup_iteration_str = + getenv("TF_AUTOTUNE_MIN_WARMUP_ITERATIONS"); + if (min_warmup_iteration_str != nullptr) { + strings::safe_strto32(min_warmup_iteration_str, &min_warmup_iterations); + } + min_score_threshold_ = std::max(min_score_threshold_, 1); + max_autotune_count_ = std::max( + 5 * min_score_threshold_ * min_score_threshold_, min_warmup_iterations); + max_autotune_global_count_ = 2 * max_autotune_count_; + autotune_global_count_ = 0; + } + + template + friend class AutotuneSingleton; + + std::string GetActionSummary(StringPiece action, const Parameters& params, + const Config& config) { + return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(), + string(action).c_str(), params.ToString().c_str(), + config.ToString().c_str()); + } + + mutable mutex mu_; + + std::unordered_map params_config_map_ + TF_GUARDED_BY(mu_); + std::string name_; + int32 min_score_threshold_; + int32 max_autotune_count_; + int32 max_autotune_global_count_; + int32 autotune_global_count_; + + AutotuneMap(const AutotuneMap&) = delete; + void operator=(const AutotuneMap&) = delete; +}; + +// A Singleton helper that manages the global autotune results by groups. +// The caller specified arbitrary Group type that can distinguish between +// different autotune results, even if their Parameters and Configs are the +// same. +template > +class AutotuneSingleton { + public: + typedef AutotuneMap AutotuneType; + static AutotuneType* GetInstance() { + static AutotuneType* instance = new AutotuneType(Group::name()); + return instance; + } +}; + +// Logs convolution results to customized back-storage. +void LogConvAutotuneResults(se::dnn::ConvolutionKind kind, + se::dnn::DataType element_type, + se::DeviceMemoryBase input_buffer, + se::DeviceMemoryBase filter_buffer, + se::DeviceMemoryBase output_buffer, + const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, + se::StreamExecutor* stream_exec, + absl::Span results); + +// Logs fused convolution results to customized back-storage. +void LogFusedConvForwardAutotuneResults( + se::dnn::DataType element_type, se::DeviceMemoryBase input_buffer, + se::DeviceMemoryBase filter_buffer, se::DeviceMemoryBase output_buffer, + se::DeviceMemoryBase bias_buffer, se::DeviceMemoryBase side_input_buffer, + const se::dnn::BatchDescriptor& input_desc, + const se::dnn::FilterDescriptor& filter_desc, + const se::dnn::BatchDescriptor& output_desc, + const se::dnn::ConvolutionDescriptor& conv_desc, double conv_scale, + double side_value_scale, se::dnn::ActivationMode activation_mode, + se::StreamExecutor* stream_exec, absl::Span results); + +// Logs fused matmul results to customized back-storage. 
+void LogFusedMatmulAutotuneResults( + se::dnn::DataType ab_dtype, se::dnn::DataType c_dtype, + se::DeviceMemoryBase a_buffer, se::DeviceMemoryBase b_buffer, + se::DeviceMemoryBase c_buffer, se::DeviceMemoryBase bias_buffer, + bool trans_a, bool trans_b, uint32_t m, uint32_t n, uint32_t k, int32_t lda, + int32_t ldb, int32_t ldc, se::dnn::ActivationMode activation_mode, + se::StreamExecutor* stream_exec, absl::Span results); + +// Autotuning map entry for cuDNN-frontend-capable APIs. +// +// The longer-term intent is to remove the AlgorithmConfig variant and make this +// contain only the two LazyOpRunners, but for the time being ROCm is stuck on +// the legacy API and requires an AlgorithmConfig. +template +class AutotuneEntry { + public: + AutotuneEntry() : is_algorithm_config_(true) {} + + // Initialize with legacy-API AlgorithmConfig; used for the ROCm backend only. + explicit AutotuneEntry(se::dnn::AlgorithmConfig config) + : is_algorithm_config_(true), algorithm_config_(std::move(config)) {} + + AutotuneEntry(std::shared_ptr> primary, + std::shared_ptr> no_scratch_fallback) + : is_algorithm_config_(false), + op_runners_{std::move(primary), std::move(no_scratch_fallback)} {} + + // Initialize from config data, without pre-cached runners, such as when + // loading AoT autotuning maps. + AutotuneEntry(se::dnn::AlgorithmDesc primary, + absl::optional no_scratch_fallback) + : AutotuneEntry(std::make_shared>(primary), + no_scratch_fallback + ? std::make_shared>( + *no_scratch_fallback) + : nullptr) {} + + // Initialize with pre-cached OpRunners, such as during autotuning. + static StatusOr FromOpRunners( + std::unique_ptr> primary, + std::unique_ptr> + no_cache_fallback) { + TF_ASSIGN_OR_RETURN( + auto primary_cache, + se::dnn::LazyOpRunner::FromOpRunner(std::move(primary))); + + if (no_cache_fallback) { + TF_ASSIGN_OR_RETURN(auto fallback_cache, + se::dnn::LazyOpRunner::FromOpRunner( + std::move(no_cache_fallback))); + return AutotuneEntry(std::move(primary_cache), std::move(fallback_cache)); + + } else { + return AutotuneEntry(std::move(primary_cache), nullptr); + } + } + + struct OpRunners { + OpRunners() = default; + + OpRunners(std::shared_ptr> primary_, + std::shared_ptr> no_scratch_fallback_) + : primary(std::move(primary_)), + no_scratch_fallback(std::move(no_scratch_fallback_)) {} + + // Null iff this 'OpRunners' is default-constructed as part of the + // fake-variant in AutotuneEntry; users outside gpu_utils.h itself should + // never see primary = nullptr. + std::shared_ptr> primary; + std::shared_ptr> no_scratch_fallback; // Nullable + + bool operator==(const OpRunners& other) const { + return *primary == *other.primary && + ((!no_scratch_fallback && !other.no_scratch_fallback) || + (no_scratch_fallback && other.no_scratch_fallback && + *no_scratch_fallback == *other.no_scratch_fallback)); + } + }; + + bool is_algorithm_config() const { return is_algorithm_config_; } + + const se::dnn::AlgorithmConfig& GetAlgorithmConfig() const { + DCHECK(is_algorithm_config_); + return algorithm_config_; + } + + const OpRunners& GetOpRunners() const { + DCHECK(!is_algorithm_config_); + return op_runners_; + } + + // AutotuneMap needs to test equality to keep track of the number of times an + // algorithm has won autotuning; for this purpose, we can use ToString to + // determine whether runners are equal. 
+ bool operator==(const AutotuneEntry& other) const { + if (is_algorithm_config_) { + return other.is_algorithm_config_ && + algorithm_config_ == other.algorithm_config_; + } + + return !other.is_algorithm_config_ && op_runners_ == other.op_runners_; + } + + bool operator!=(const AutotuneEntry& other) const { + return !(*this == other); + } + + std::string ToString() const { + if (is_algorithm_config_) { + return algorithm_config_.ToString(); + } + return absl::StrCat("{", op_runners_.primary->ToString(), ", ", + (op_runners_.no_scratch_fallback + ? op_runners_.no_scratch_fallback->ToString() + : "(op_runners have no fallback)"), + "}"); + } + + private: + // NVCC is broken, so we can't use absl::variant here. Just fake it with a + // bool and both fields. + bool is_algorithm_config_; + se::dnn::AlgorithmConfig algorithm_config_; + OpRunners op_runners_; +}; + +namespace internal { +StatusOr> BestCudnnConvAlgorithmIndices( + absl::Span results); +} // namespace internal + +// Returns the best algorithms for the config, one is the fastest, the other is +// other is fastest with 0 scratch space. Unsuccessful autotuning results are +// allowed and ignored. +StatusOr BestCudnnConvAlgorithm( + absl::Span results); + +// Explicitly-instantiated with ConvOp and FusedConvOp. +// +// The definition can't be in the header because including .pb.h files in +// headers is forbidden. +template +StatusOr> BestCudnnConvAlgorithm( + absl::Span results, + std::vector< + std::unique_ptr>> + runners); + +// Get the Dnn workspace limit from the environment variable, which is in MB. +// Return the workspace memory limit in bytes. If no value is set, return the +// default value. +int64_t GetDnnWorkspaceLimit(const string& envvar_in_mb, + int64_t default_value_in_bytes); + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_GPU_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/hinge-loss.h b/third_party/tflite-hdrs/tensorflow/core/kernels/hinge-loss.h new file mode 100644 index 00000000..51f11e04 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/hinge-loss.h @@ -0,0 +1,126 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_ +#define TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_ + +#include +#include + +#include "tensorflow/core/kernels/loss.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class HingeLossUpdater : public DualLossUpdater { + public: + // Computes the updated dual variable (corresponding) to a single example. The + // updated dual value maximizes the objective function of the dual + // optimization problem associated with hinge loss (conditioned on keeping the + // rest of the dual variables intact). 
The method below finds an optimal delta + // (difference between updated and previous dual value) using the update rule + // within SDCA procedure (see http://arxiv.org/pdf/1209.1873v2.pdf, page 5) + // and the particular form of conjugate function for hinge loss. + // + // The CoCoA+ modification is detailed in readme.md. + // + // TODO(sibyl-vie3Poto): Write up a doc with concrete derivation and point to it from + // here. + double ComputeUpdatedDual(const int num_loss_partitions, const double label, + const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm) const final { + // Intuitively there are 3 cases: + // a. new optimal value of the dual variable falls within the admissible + // range [0, 1]. In this case we set new dual to this value. + // b. new optimal value is < 0. Then, because of convexity, the optimal + // valid value for new dual = 0 + // c. new optimal value > 1.0. Then new optimal value should be set to 1.0. + const double candidate_optimal_dual = + current_dual + (label - wx) / (num_loss_partitions * example_weight * + weighted_example_norm); + if (label * candidate_optimal_dual < 0) { + return 0.0; + } + if (label * candidate_optimal_dual > 1.0) { + return label; + } + return candidate_optimal_dual; + } + + // Conjugate of hinge loss. This is computed as: + // \phi*(z) = z if z \in [-1, 0] and +infinity everywhere else. See for + // instance http://www.eecs.berkeley.edu/~wainwrig/stat241b/lec10.pdf + // Here we want the weighted version of the conjugate loss. It turns out, that + // if w is the weight of an example, the conjugate of the weighted hinge loss + // is given by: + // \phi*(z) = z if z \in [-w, 0] and +infinity everywhere else. Here the + // conjugate function depends not only on the weight of the example but also + // on its label. In particular: + // \phi_y*(z) = y*z if y*z \in [-w, 0] and +infinity everywhere else where + // y \in {-1,1}. The following method implements \phi_y*(-\alpha/w). + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { + // For binary classification, there are 2 conjugate functions, one per + // label value (-1 and 1). + const double y_alpha = current_dual * example_label; // y \alpha + if (y_alpha < 0 || y_alpha > 1.0) { + return std::numeric_limits::max(); + } + return -y_alpha * example_weight; + } + + // Hinge loss for binary classification for a single example. Hinge loss + // equals max(0, 1 - y * wx) (see https://en.wikipedia.org/wiki/Hinge_loss). + // For weighted instances loss should be multiplied by the instance weight. + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { + const double y_wx = example_label * wx; + return std::max(0.0, 1 - y_wx) * example_weight; + } + + double PrimalLossDerivative(const double wx, const double label, + const double example_weight) const final { + if (label * wx < 1) { + return -label * example_weight; + } + return 0; + } + + // The smoothness constant is 0 since the derivative of the loss is not + // Lipschitz + double SmoothnessConstant() const final { return 0; } + + // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively + // as expected by hinge loss. 
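As a worked check of ComputeUpdatedDual and ComputePrimalLoss above, take a single example with label y = 1, prediction wx = 0.2, example weight 1, one loss partition and weighted example norm 2: the primal hinge loss is max(0, 1 - 0.2) = 0.8 and the unclipped dual update is 0 + (1 - 0.2) / (1 * 1 * 2) = 0.4, which already lies in the admissible range. The same arithmetic as a standalone snippet:

#include <algorithm>
#include <cstdio>

int main() {
  const double y = 1.0, wx = 0.2, w = 1.0, norm = 2.0, current_dual = 0.0;
  const int num_partitions = 1;
  // Primal hinge loss: max(0, 1 - y * wx) * w.
  const double primal = std::max(0.0, 1.0 - y * wx) * w;
  // Unclipped dual update, then clipped to the admissible range.
  double dual = current_dual + (y - wx) / (num_partitions * w * norm);
  if (y * dual < 0.0) dual = 0.0;
  if (y * dual > 1.0) dual = y;
  std::printf("primal=%.3f dual=%.3f\n", primal, dual);  // primal=0.800 dual=0.400
  return 0;
}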
+ absl::Status ConvertLabel(float* const example_label) const final { + if (*example_label == 0.0) { + *example_label = -1; + return absl::OkStatus(); + } + if (*example_label == 1.0) { + return absl::OkStatus(); + } + return errors::InvalidArgument( + "Only labels of 0.0 or 1.0 are supported right now. " + "Found example with label: ", + *example_label); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/histogram_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/histogram_op.h new file mode 100644 index 00000000..cc6ea006 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/histogram_op.h @@ -0,0 +1,39 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_ +#define TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace functor { + +template +struct HistogramFixedWidthFunctor { + static absl::Status Compute( + OpKernelContext* context, + const typename TTypes::ConstTensor& values, + const typename TTypes::ConstTensor& value_range, int32_t nbins, + typename TTypes::Tensor& out); +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/host_constant_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/host_constant_op.h new file mode 100644 index 00000000..9ba151ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/host_constant_op.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +// HostConstantOp differs from ConstantOp in that its output is always +// in host memory. 
+class _HostConstantOp : public OpKernel { + public: + explicit _HostConstantOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + const Tensor* const_tensor() const override { return &tensor_; }; + ~_HostConstantOp() override {} + + private: + Tensor tensor_; + _HostConstantOp(const _HostConstantOp&) = delete; + void operator=(const _HostConstantOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/identity_n_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/identity_n_op.h new file mode 100644 index 00000000..7273731f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/identity_n_op.h @@ -0,0 +1,51 @@ +/* Copyright 2015-2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_ + +#include "tensorflow/core/framework/metrics.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class IdentityNOp : public OpKernel { + public: + explicit IdentityNOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OpInputList input; + OpOutputList output; + OP_REQUIRES_OK(context, context->input_list("input", &input)); + OP_REQUIRES_OK(context, context->output_list("output", &output)); + OP_REQUIRES(context, input.size() == output.size(), + errors::InvalidArgument("Input and output counts must match")); + if (absl::StrContains(name(), kTpuExecuteStagingNodeName)) { + // TPU staging node execution is used for measuring launch latency. + metrics::UpdateTpuVariableDistributionTime(EnvTime::NowMicros() - + context->start_time_usecs()); + } + for (int i = 0; i < input.size(); ++i) { + output.set(i, input[i]); + } + } + + bool IsExpensive() override { return false; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/identity_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/identity_op.h new file mode 100644 index 00000000..6b74868a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/identity_op.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class IdentityOp : public OpKernel { + public: + explicit IdentityOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + if (IsRefType(context->input_dtype(0))) { + context->forward_ref_input_to_ref_output(0, 0); + } else { + context->set_output(0, context->input(0)); + } + } + + bool IsExpensive() override { return false; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_contrast_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_contrast_op.h new file mode 100644 index 00000000..9981275c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_contrast_op.h @@ -0,0 +1,127 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_CONTRAST_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_CONTRAST_OP_H_ +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by AdjustContrastOp to do the computations. 
+template +struct AdjustContrast { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstScalar contrast_factor, + typename TTypes::ConstScalar min_value, + typename TTypes::ConstScalar max_value, + typename TTypes::Tensor mean_values, + typename TTypes::Tensor output) { + const int batch = input.dimension(0); + const int height = input.dimension(1); + const int width = input.dimension(2); + const int channels = input.dimension(3); + + Eigen::array scalar_broadcast; + scalar_broadcast[0] = batch; + scalar_broadcast[1] = height; + scalar_broadcast[2] = width; + scalar_broadcast[3] = channels; + + Eigen::IndexList, Eigen::type2index<2> > + reduction_axis; + Eigen::IndexList, int, int, Eigen::type2index<1> > + broadcast_dims; + broadcast_dims.set(1, height); + broadcast_dims.set(2, width); + Eigen::IndexList, Eigen::type2index<1>, int> + reshape_dims; + reshape_dims.set(0, batch); + reshape_dims.set(3, channels); + + Eigen::Sizes<1, 1, 1, 1> scalar; + float num_reduced_coeffs = height * width; + mean_values.device(d) = + (input.template cast().sum(reduction_axis).eval() / + num_reduced_coeffs) + .reshape(reshape_dims) + .broadcast(broadcast_dims); + + auto contrast_factor_tensor = + contrast_factor.reshape(scalar).broadcast(scalar_broadcast); + auto adjusted = + (input.template cast() - mean_values) * contrast_factor_tensor + + mean_values; + auto min_bcast = min_value.reshape(scalar).broadcast(scalar_broadcast); + auto max_bcast = max_value.reshape(scalar).broadcast(scalar_broadcast); + // TODO(wicke): This is rather slow and should be re-written as pure cuda. + output.device(d) = adjusted.cwiseMin(max_bcast).cwiseMax(min_bcast); + } +}; + +// Functor used by AdjustContrastOpv2 to do the computations. +template +struct AdjustContrastv2 { + void operator()(const Device& d, typename TTypes::ConstTensor input, + typename TTypes::ConstScalar contrast_factor, + typename TTypes::Tensor output) { + const int batch = input.dimension(0); + const int height = input.dimension(1); + const int width = input.dimension(2); + const int channels = input.dimension(3); + + Eigen::array scalar_broadcast; + scalar_broadcast[0] = batch; + scalar_broadcast[1] = height; + scalar_broadcast[2] = width; + scalar_broadcast[3] = channels; + + Eigen::IndexList, Eigen::type2index<1> > + reduction_axis; + Eigen::IndexList, int, int, Eigen::type2index<1> > + broadcast_dims; + broadcast_dims.set(1, height); + broadcast_dims.set(2, width); + Eigen::IndexList, Eigen::type2index<1>, int> + reshape_dims; + reshape_dims.set(0, batch); + reshape_dims.set(3, channels); + Eigen::IndexList, Eigen::type2index<2>, + Eigen::type2index<0>, Eigen::type2index<3> > + reduced_dims_first; + + Eigen::Sizes<1, 1, 1, 1> scalar; + float num_reduced_coeffs = height * width; + output.device(d) = (input.template cast() + .shuffle(reduced_dims_first) + .sum(reduction_axis) + .eval() / + num_reduced_coeffs) + .template cast() + .reshape(reshape_dims) + .broadcast(broadcast_dims); + auto contrast_factor_tensor = + contrast_factor.reshape(scalar).broadcast(scalar_broadcast); + auto adjusted = + (input - output).template cast() * contrast_factor_tensor; + output.device(d) += adjusted.template cast(); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_CONTRAST_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_hsv_gpu.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_hsv_gpu.cu.h new file mode 100644 index 
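Both contrast functors above interpolate each pixel toward the per-channel spatial mean, output = (input - mean) * contrast_factor + mean (the v1 functor additionally clamps to [min, max]). A scalar sketch of that expression for a single channel, purely illustrative:

#include <iostream>
#include <vector>

// Illustrative per-channel contrast adjustment for one channel of one image,
// mirroring the expression above: (x - mean) * factor + mean.
std::vector<float> AdjustContrastChannel(const std::vector<float>& pixels,
                                         float contrast_factor) {
  float mean = 0.f;
  for (float p : pixels) mean += p;
  mean /= pixels.size();
  std::vector<float> out(pixels.size());
  for (size_t i = 0; i < pixels.size(); ++i) {
    out[i] = (pixels[i] - mean) * contrast_factor + mean;
  }
  return out;
}

int main() {
  // Mean is 0.5; factor 2 doubles each pixel's distance from the mean.
  for (float v : AdjustContrastChannel({0.25f, 0.75f}, 2.0f)) {
    std::cout << v << " ";  // 0 1
  }
  std::cout << "\n";
}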
00000000..417ea652 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_hsv_gpu.cu.h @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_HSV_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_HSV_GPU_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace internal { + +typedef struct RgbTuple { + float r; + float g; + float b; +} RgbTuple; + +typedef struct HsvTuple { + float h; + float s; + float v; +} HsvTuple; + +inline __device__ HsvTuple rgb2hsv_cuda(const float r, const float g, + const float b) { + HsvTuple tuple; + const float M = fmaxf(r, fmaxf(g, b)); + const float m = fminf(r, fminf(g, b)); + const float chroma = M - m; + float h = 0.0f, s = 0.0f; + // hue + if (chroma > 0.0f) { + if (M == r) { + const float num = (g - b) / chroma; + const float sign = copysignf(1.0f, num); + h = ((sign < 0.0f) * 6.0f + sign * fmodf(sign * num, 6.0f)) / 6.0f; + } else if (M == g) { + h = ((b - r) / chroma + 2.0f) / 6.0f; + } else { + h = ((r - g) / chroma + 4.0f) / 6.0f; + } + } else { + h = 0.0f; + } + // saturation + if (M > 0.0) { + s = chroma / M; + } else { + s = 0.0f; + } + tuple.h = h; + tuple.s = s; + tuple.v = M; + return tuple; +} + +inline __device__ RgbTuple hsv2rgb_cuda(const float h, const float s, + const float v) { + RgbTuple tuple; + const float new_h = h * 6.0f; + const float chroma = v * s; + const float x = chroma * (1.0f - fabsf(fmodf(new_h, 2.0f) - 1.0f)); + const float new_m = v - chroma; + const bool between_0_and_1 = new_h >= 0.0f && new_h < 1.0f; + const bool between_1_and_2 = new_h >= 1.0f && new_h < 2.0f; + const bool between_2_and_3 = new_h >= 2.0f && new_h < 3.0f; + const bool between_3_and_4 = new_h >= 3.0f && new_h < 4.0f; + const bool between_4_and_5 = new_h >= 4.0f && new_h < 5.0f; + const bool between_5_and_6 = new_h >= 5.0f && new_h < 6.0f; + tuple.r = chroma * (between_0_and_1 || between_5_and_6) + + x * (between_1_and_2 || between_4_and_5) + new_m; + tuple.g = chroma * (between_1_and_2 || between_2_and_3) + + x * (between_0_and_1 || between_3_and_4) + new_m; + tuple.b = chroma * (between_3_and_4 || between_4_and_5) + + x * (between_2_and_3 || between_5_and_6) + new_m; + return tuple; +} + +template +__global__ void adjust_hsv_nhwc( + const int64 number_elements, const T* const __restrict__ input, + T* const __restrict__ output, const float* const __restrict__ hue_delta, + const float* const __restrict__ saturation_scale, + const float* const __restrict__ value_scale) { + // multiply by 3 since we're dealing with contiguous RGB bytes for each pixel + // (NHWC) + for (int64 idx = (blockDim.x * blockIdx.x + threadIdx.x) * 3; + idx < number_elements; idx += blockDim.x * gridDim.x * 3) { + if 
(!AdjustHue && !AdjustSaturation && !AdjustV) { + output[idx] = input[idx]; + output[idx + 1] = input[idx + 1]; + output[idx + 2] = input[idx + 2]; + continue; + } + const HsvTuple hsv = rgb2hsv_cuda(static_cast(input[idx]), + static_cast(input[idx + 1]), + static_cast(input[idx + 2])); + float new_h = hsv.h; + float new_s = hsv.s; + float new_v = hsv.v; + // hue adjustment + if (AdjustHue) { + const float delta = *hue_delta; + new_h = fmodf(hsv.h + delta, 1.0f); + if (new_h < 0.0f) { + new_h = fmodf(1.0f + new_h, 1.0f); + } + } + // saturation adjustment + if (AdjustSaturation && saturation_scale != nullptr) { + const float scale = *saturation_scale; + new_s = fminf(1.0f, fmaxf(0.0f, hsv.s * scale)); + } + // value adjustment + if (AdjustV && value_scale != nullptr) { + const float scale = *value_scale; + new_v = hsv.v * scale; + } + const RgbTuple rgb = hsv2rgb_cuda(new_h, new_s, new_v); + output[idx] = static_cast(rgb.r); + output[idx + 1] = static_cast(rgb.g); + output[idx + 2] = static_cast(rgb.b); + } +} + +} // namespace internal +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_HSV_GPU_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_hue_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_hue_op.h new file mode 100644 index 00000000..788b61bc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_hue_op.h @@ -0,0 +1,41 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_HUE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_HUE_OP_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +struct AdjustHueGPU { + void operator()(GPUDevice* device, const int64_t number_of_elements, + const T* const input, const float* const delta, + T* const output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_HUE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_saturation_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_saturation_op.h new file mode 100644 index 00000000..278161bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/adjust_saturation_op.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
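The device code above converts each pixel to HSV, shifts the hue, scales saturation and value, and converts back. A host-side sketch of just the per-component adjustments, assumed to mirror the kernel's per-pixel logic; the struct and function names are illustrative:

#include <algorithm>
#include <cmath>
#include <iostream>

struct Hsv { float h, s, v; };  // h in [0, 1), s in [0, 1], v >= 0

// Illustrative HSV-space adjustment: hue is shifted and wrapped into [0, 1),
// saturation is scaled and clamped to [0, 1], value is scaled.
Hsv AdjustHsv(Hsv in, float hue_delta, float sat_scale, float value_scale) {
  Hsv out = in;
  out.h = std::fmod(in.h + hue_delta, 1.0f);
  if (out.h < 0.0f) out.h = std::fmod(1.0f + out.h, 1.0f);
  out.s = std::min(1.0f, std::max(0.0f, in.s * sat_scale));
  out.v = in.v * value_scale;
  return out;
}

int main() {
  const Hsv adjusted = AdjustHsv({0.9f, 0.5f, 0.25f}, /*hue_delta=*/0.2f,
                                 /*sat_scale=*/3.0f, /*value_scale=*/2.0f);
  // Roughly 0.1 1 0.5: the hue wraps around, saturation saturates at 1.
  std::cout << adjusted.h << " " << adjusted.s << " " << adjusted.v << "\n";
}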
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_SATURATION_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_SATURATION_OP_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +struct AdjustSaturationGPU { + void operator()(GPUDevice* device, const int64_t number_of_elements, + const T* const input, const float* const scale, + T* const output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_ADJUST_SATURATION_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/colorspace_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/colorspace_op.h new file mode 100644 index 00000000..b71f058f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/colorspace_op.h @@ -0,0 +1,90 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_COLORSPACE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_COLORSPACE_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +namespace functor { + +template +struct RGBToHSV { + void operator()(const Device &d, + typename TTypes::ConstTensor input_data, + typename TTypes::Tensor range, + typename TTypes::Tensor output_data) { + auto H = output_data.template chip<1>(0); + auto S = output_data.template chip<1>(1); + auto V = output_data.template chip<1>(2); + + auto R = input_data.template chip<1>(0); + auto G = input_data.template chip<1>(1); + auto B = input_data.template chip<1>(2); + + Eigen::IndexList > channel_axis; + + V.device(d) = input_data.maximum(channel_axis); + + range.device(d) = V - input_data.minimum(channel_axis); + + S.device(d) = (V > T(0)).select(range / V, V.constant(T(0))); + + auto norm = range.inverse() * (T(1) / T(6)); + // TODO(wicke): all these assignments are only necessary because a combined + // expression is larger than kernel parameter space. A custom kernel is + // probably in order. 
+ H.device(d) = (R == V).select( + norm * (G - B), (G == V).select(norm * (B - R) + T(2) / T(6), + norm * (R - G) + T(4) / T(6))); + H.device(d) = (range > T(0)).select(H, H.constant(T(0))); + H.device(d) = (H < T(0)).select(H + T(1), H); + } +}; + +template +struct HSVToRGB { + void operator()(const Device &d, + typename TTypes::ConstTensor input_data, + typename TTypes::Tensor output_data) { + auto H = input_data.template chip<1>(0); + auto S = input_data.template chip<1>(1); + auto V = input_data.template chip<1>(2); + + // TODO(wicke): compute only the fractional part of H for robustness + auto dh = H * T(6); + auto dr = ((dh - T(3)).abs() - T(1)).cwiseMax(T(0)).cwiseMin(T(1)); + auto dg = (-(dh - T(2)).abs() + T(2)).cwiseMax(T(0)).cwiseMin(T(1)); + auto db = (-(dh - T(4)).abs() + T(2)).cwiseMax(T(0)).cwiseMin(T(1)); + auto one_s = -S + T(1); + + auto R = output_data.template chip<1>(0); + auto G = output_data.template chip<1>(1); + auto B = output_data.template chip<1>(2); + + R.device(d) = (one_s + S * dr) * V; + G.device(d) = (one_s + S * dg) * V; + B.device(d) = (one_s + S * db) * V; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_COLORSPACE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/crop_and_resize_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/crop_and_resize_op.h new file mode 100644 index 00000000..dd838ea5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/crop_and_resize_op.h @@ -0,0 +1,72 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_CROP_AND_RESIZE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_CROP_AND_RESIZE_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct CropAndResize { + // We assume that the tensor sizes are correct. + bool operator()(const OpKernelContext* context, + typename TTypes::ConstTensor image, + typename TTypes::ConstTensor boxes, + typename TTypes::ConstTensor box_ind, + const std::string& method_name, float extrapolation_value, + typename TTypes::Tensor crops); +}; + +template +struct CropAndResizeBackpropImage { + // We assume that the tensor sizes are correct. + bool operator()(const OpKernelContext* context, + typename TTypes::ConstTensor grads, + typename TTypes::ConstTensor boxes, + typename TTypes::ConstTensor box_ind, + typename TTypes::Tensor grads_image, + const std::string& method_name); +}; + +template +struct CropAndResizeBackpropBoxes { + // We assume that the tensor sizes are correct. 
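HSVToRGB above builds the output from the clamped hue ramps dr, dg, db blended with saturation and value. A scalar sketch of the same formulas outside the Eigen expression machinery, for intuition only:

#include <algorithm>
#include <cmath>
#include <iostream>

struct Rgb { float r, g, b; };

// Scalar version of the HSVToRGB formulas above: dh = 6h and the three
// clamped ramps dr, dg, db pick out the red/green/blue contributions.
Rgb HsvToRgb(float h, float s, float v) {
  const float dh = h * 6.0f;
  auto ramp = [](float x) { return std::min(1.0f, std::max(0.0f, x)); };
  const float dr = ramp(std::abs(dh - 3.0f) - 1.0f);
  const float dg = ramp(-std::abs(dh - 2.0f) + 2.0f);
  const float db = ramp(-std::abs(dh - 4.0f) + 2.0f);
  const float one_s = 1.0f - s;
  return {(one_s + s * dr) * v, (one_s + s * dg) * v, (one_s + s * db) * v};
}

int main() {
  const Rgb red = HsvToRgb(0.0f, 1.0f, 1.0f);            // pure red
  const Rgb green = HsvToRgb(1.0f / 3.0f, 1.0f, 1.0f);   // pure green
  std::cout << red.r << " " << red.g << " " << red.b << "\n";        // 1 0 0
  std::cout << green.r << " " << green.g << " " << green.b << "\n";  // 0 1 0
}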
+ bool operator()(const Device& d, typename TTypes::ConstTensor grads, + typename TTypes::ConstTensor image, + typename TTypes::ConstTensor boxes, + typename TTypes::ConstTensor box_ind, + typename TTypes::Tensor grads_boxes); +}; + +template +struct CheckValidBoxIndexHelper { + // Checks if all values in box_index are in [0, batch). + void operator()(const Device& d, + typename TTypes::ConstTensor box_index, int batch, + typename TTypes::Tensor isvalid) { + isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all(); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_CROP_AND_RESIZE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/extract_image_patches_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/extract_image_patches_op.h new file mode 100644 index 00000000..3dc2f323 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/extract_image_patches_op.h @@ -0,0 +1,51 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_EXTRACT_IMAGE_PATCHES_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_EXTRACT_IMAGE_PATCHES_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct ExtractImagePatchesForward { + void operator()(const Device& d, typename TTypes::ConstTensor input, + int patch_rows, int patch_cols, int stride_rows, + int stride_cols, int rate_rows, int rate_cols, + const Eigen::PaddingType& padding, + typename TTypes::Tensor output) { + // Need to swap row/col when calling Eigen, because our data is in + // NHWC format while Eigen assumes NWHC format. + MaybeWith32BitIndexing( + [&](auto input32, auto output32) { + output32.device(d) = + input32 + .extract_image_patches(patch_cols, patch_rows, stride_cols, + stride_rows, rate_cols, rate_rows, + padding) + .reshape(output32.dimensions()); + }, + input, output); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_EXTRACT_IMAGE_PATCHES_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/extract_volume_patches_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/extract_volume_patches_op.h new file mode 100644 index 00000000..9e134818 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/extract_volume_patches_op.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_EXTRACT_VOLUME_PATCHES_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_EXTRACT_VOLUME_PATCHES_OP_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +namespace tensorflow { +namespace functor { + +template +struct ExtractVolumePatchesForward { + void operator()(const Device& d, typename TTypes::ConstTensor input, + int patch_planes, int patch_rows, int patch_cols, + int stride_planes, int stride_rows, int stride_cols, + /* int rate_planes, int rate_rows, int rate_cols, */ + const Eigen::PaddingType& padding, + typename TTypes::Tensor output) { + MaybeWith32BitIndexing( + [&](auto input32, auto output32) { + output32.device(d) = + input32 + .extract_volume_patches(patch_cols, patch_rows, patch_planes, + stride_cols, stride_rows, + stride_planes, padding) + .reshape(output32.dimensions()); + }, + input, output); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_EXTRACT_VOLUME_PATCHES_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/image_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/image_ops.h new file mode 100644 index 00000000..914cb528 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/image_ops.h @@ -0,0 +1,278 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_IMAGE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_IMAGE_OPS_H_ + +// See docs in ../ops/image_ops.cc. + +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace generator { + +enum Interpolation { NEAREST, BILINEAR }; +enum Mode { FILL_REFLECT, FILL_WRAP, FILL_CONSTANT, FILL_NEAREST }; + +using Eigen::array; +using Eigen::DenseIndex; + +// Follow scipy's implementation +// https://github.com/scipy/scipy/blob/master/scipy/ndimage/src/ni_interpolation.c +template +struct MapCoordinate { + float operator()(const float out_coord, const DenseIndex len); +}; + +template +struct MapCoordinate { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord, + const DenseIndex len) { + // Reflect [abcd] to [dcba|abcd|dcba]. 
+ float in_coord = out_coord; + if (in_coord < 0) { + if (len <= 1) { + in_coord = 0; + } else { + const DenseIndex sz2 = 2 * len; + if (in_coord < sz2) { + in_coord = sz2 * static_cast(-in_coord / sz2) + in_coord; + } + in_coord = (in_coord < -len) ? in_coord + sz2 : -in_coord - 1; + } + } else if (in_coord > len - 1) { + if (len <= 1) { + in_coord = 0; + } else { + const DenseIndex sz2 = 2 * len; + in_coord -= sz2 * static_cast(in_coord / sz2); + if (in_coord >= len) { + in_coord = sz2 - in_coord - 1; + } + } + } + // clamp is necessary because when out_coord = 3.5 and len = 4, + // in_coord = 3.5 and will be rounded to 4 in nearest interpolation. + return Eigen::internal::scalar_clamp_op(0.0f, len - 1)(in_coord); + } +}; + +template +struct MapCoordinate { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord, + const DenseIndex len) { + // Wrap [abcd] to [abcd|abcd|abcd]. + float in_coord = out_coord; + if (in_coord < 0) { + if (len <= 1) { + in_coord = 0; + } else { + const DenseIndex sz = len - 1; + in_coord += len * (static_cast(-in_coord / sz) + 1); + } + } else if (in_coord > len - 1) { + if (len <= 1) { + in_coord = 0; + } else { + const DenseIndex sz = len - 1; + in_coord -= len * static_cast(in_coord / sz); + } + } + // clamp is necessary because when out_coord = -0.5 and len = 4, + // in_coord = 3.5 and will be rounded to 4 in nearest interpolation. + return Eigen::internal::scalar_clamp_op(0.0f, len - 1)(in_coord); + } +}; + +template +struct MapCoordinate { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord, + const DenseIndex len) { + return out_coord; + } +}; + +template +struct MapCoordinate { + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE float operator()(const float out_coord, + const DenseIndex len) { + return Eigen::internal::scalar_clamp_op(0.0f, len - 1)(out_coord); + } +}; + +template +class ProjectiveGenerator { + private: + typename TTypes::ConstTensor input_; + typename TTypes::ConstMatrix transforms_; + const Interpolation interpolation_; + const T fill_value_; + + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + ProjectiveGenerator(typename TTypes::ConstTensor input, + typename TTypes::ConstMatrix transforms, + const Interpolation interpolation, const T fill_value) + : input_(input), + transforms_(transforms), + interpolation_(interpolation), + fill_value_(fill_value) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + operator()(const array& coords) const { + const int64_t output_y = coords[1]; + const int64_t output_x = coords[2]; + const float* transform = + transforms_.dimension(0) == 1 + ? transforms_.data() + : &transforms_.data()[transforms_.dimension(1) * coords[0]]; + float projection = transform[6] * output_x + transform[7] * output_y + 1.f; + if (projection == 0) { + // Return the fill value for infinite coordinates, + // which are outside the input image + return fill_value_; + } + const float input_x = + (transform[0] * output_x + transform[1] * output_y + transform[2]) / + projection; + const float input_y = + (transform[3] * output_x + transform[4] * output_y + transform[5]) / + projection; + + // Map out-of-boundary input coordinates to in-boundary based on fill_mode. 
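The MapCoordinate specializations above fold out-of-range sampling coordinates back into [0, len - 1] for the reflect ([dcba|abcd|dcba]) and wrap ([abcd|abcd|abcd]) fill modes. A simplified integer sketch of the same idea (the header works on float coordinates; this only illustrates the folding):

#include <iostream>

// Reflect an integer coordinate into [0, len - 1]: [abcd] -> ...dcba|abcd|dcba...
int ReflectIndex(int i, int len) {
  if (len <= 1) return 0;
  const int period = 2 * len;
  i %= period;
  if (i < 0) i += period;
  return (i < len) ? i : period - 1 - i;
}

// Wrap an integer coordinate into [0, len - 1]: [abcd] -> ...abcd|abcd|abcd...
int WrapIndex(int i, int len) {
  if (len <= 1) return 0;
  i %= len;
  return (i < 0) ? i + len : i;
}

int main() {
  // len = 4, so valid coordinates are 0..3.
  std::cout << ReflectIndex(-1, 4) << " " << ReflectIndex(5, 4) << "\n";  // 0 2
  std::cout << WrapIndex(-1, 4) << " " << WrapIndex(5, 4) << "\n";        // 3 1
}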
+ auto map_functor = MapCoordinate(); + const float x = map_functor(input_x, input_.dimension(2)); + const float y = map_functor(input_y, input_.dimension(1)); + + const DenseIndex batch = coords[0]; + const DenseIndex channels = coords[3]; + switch (interpolation_) { + case NEAREST: + return nearest_interpolation(batch, y, x, channels, fill_value_); + case BILINEAR: + return bilinear_interpolation(batch, y, x, channels, fill_value_); + } + // Unreachable; ImageProjectiveTransform only uses INTERPOLATION_NEAREST + // or INTERPOLATION_BILINEAR. + return fill_value_; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + nearest_interpolation(const DenseIndex batch, const float y, const float x, + const DenseIndex channel, const T fill_value) const { + return read_with_fill_value(batch, DenseIndex(std::round(y)), + DenseIndex(std::round(x)), channel, fill_value); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + bilinear_interpolation(const DenseIndex batch, const float y, const float x, + const DenseIndex channel, const T fill_value) const { + const float y_floor = std::floor(y); + const float x_floor = std::floor(x); + const float y_ceil = y_floor + 1; + const float x_ceil = x_floor + 1; + // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor) + // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor) + const float value_yfloor = + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_floor), DenseIndex(x_ceil), + channel, fill_value)); + // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil) + // + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil) + const float value_yceil = + (x_ceil - x) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_floor), + channel, fill_value)) + + (x - x_floor) * static_cast(read_with_fill_value( + batch, DenseIndex(y_ceil), DenseIndex(x_ceil), + channel, fill_value)); + // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor) + // + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil) + return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value( + const DenseIndex batch, const DenseIndex y, const DenseIndex x, + const DenseIndex channel, const T fill_value) const { + // batch and channel must be correct, because they are passed unchanged from + // the input. + return (0 <= y && y < input_.dimension(1) && 0 <= x && + x < input_.dimension(2)) + ? 
input_(array{batch, y, x, channel}) + : fill_value; + } +}; + +} // end namespace generator + +namespace functor { + +using generator::Interpolation; +using generator::Mode; +using generator::ProjectiveGenerator; + +template +struct FillProjectiveTransform { + typedef typename TTypes::Tensor OutputType; + typedef typename TTypes::ConstTensor InputType; + typedef typename TTypes::ConstTensor TransformsType; + const Interpolation interpolation; + + explicit FillProjectiveTransform(Interpolation interpolation) + : interpolation(interpolation) {} + + EIGEN_ALWAYS_INLINE + void operator()(const Device& device, OutputType* output, + const InputType& images, const TransformsType& transform, + const Mode fill_mode, const T fill_value) const { + switch (fill_mode) { + case Mode::FILL_REFLECT: + output->device(device) = + output->generate(ProjectiveGenerator( + images, transform, interpolation, fill_value)); + break; + case Mode::FILL_WRAP: + output->device(device) = + output->generate(ProjectiveGenerator( + images, transform, interpolation, fill_value)); + break; + case Mode::FILL_CONSTANT: + output->device(device) = output->generate( + ProjectiveGenerator( + images, transform, interpolation, fill_value)); + break; + case Mode::FILL_NEAREST: + output->device(device) = + output->generate(ProjectiveGenerator( + images, transform, interpolation, fill_value)); + break; + } + } +}; + +} // end namespace functor + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_IMAGE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/mirror_pad_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/mirror_pad_op.h new file mode 100644 index 00000000..7c3df978 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/mirror_pad_op.h @@ -0,0 +1,445 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
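ProjectiveGenerator above maps each output pixel through an 8-parameter transform in homogeneous coordinates before sampling the input. A sketch of just that coordinate mapping, using the same parameter layout as the generator; the example transform is illustrative:

#include <array>
#include <iostream>

// Maps an output pixel (x, y) to input coordinates using the 8-parameter
// projective transform [a0..a7], as in the generator above:
//   x_in = (a0*x + a1*y + a2) / k,  y_in = (a3*x + a4*y + a5) / k,
//   k    =  a6*x + a7*y + 1.
std::array<float, 2> ProjectOutputToInput(const std::array<float, 8>& t,
                                          float x, float y) {
  const float k = t[6] * x + t[7] * y + 1.0f;
  return {(t[0] * x + t[1] * y + t[2]) / k, (t[3] * x + t[4] * y + t[5]) / k};
}

int main() {
  // Transform that reads the input 2 pixels to the right and 3 pixels down
  // of each output pixel (identity scale, no projective terms).
  const std::array<float, 8> translate = {1, 0, 2, 0, 1, 3, 0, 0};
  const auto in = ProjectOutputToInput(translate, 5.0f, 7.0f);
  std::cout << in[0] << " " << in[1] << "\n";  // 7 10
}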
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_MIRROR_PAD_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_MIRROR_PAD_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace Eigen { +template +class TensorMirrorPadOp; + +namespace internal { +template +struct traits> + : public traits { + typedef typename XprType::Scalar Scalar; + typedef traits XprTraits; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; + typedef typename XprType::Nested Nested; + typedef std::remove_reference_t _Nested; + static constexpr int NumDimensions = XprTraits::NumDimensions; + static constexpr int Layout = XprTraits::Layout; +}; + +template +struct eval, Eigen::Dense> { + typedef const TensorMirrorPadOp& type; +}; + +template +struct nested< + TensorMirrorPadOp, 1, + typename eval>::type> { + typedef TensorMirrorPadOp type; +}; +} // namespace internal + +template +class TensorMirrorPadOp + : public TensorBase, + ReadOnlyAccessors> { + public: + typedef typename Eigen::internal::traits::Scalar Scalar; + typedef typename Eigen::NumTraits::Real RealScalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename Eigen::internal::nested::type Nested; + typedef typename Eigen::internal::traits::StorageKind + StorageKind; + typedef typename Eigen::internal::traits::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMirrorPadOp( + const XprType& expr, const PaddingDimensions& padding_dims, Index offset) + : xpr_(expr), padding_dims_(padding_dims), offset_(offset) {} + + EIGEN_DEVICE_FUNC + const PaddingDimensions& padding() const { return padding_dims_; } + + EIGEN_DEVICE_FUNC + Index offset() const { return offset_; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + expression() const { + return xpr_; + } + + protected: + typename XprType::Nested xpr_; + const PaddingDimensions padding_dims_; + const Index offset_; +}; + +// Eval as rvalue +template +struct TensorEvaluator, + Device> { + typedef TensorMirrorPadOp XprType; + typedef typename XprType::Index Index; + static constexpr int Dims = internal::array_size::value; + typedef DSizes Dimensions; + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + // Copied from Eigen3 Github version 0e806c1. + typedef typename PacketType::type PacketReturnType; + + enum { + IsAligned = false, + PacketAccess = TensorEvaluator::PacketAccess, + BlockAccess = false, + BlockAccessV2 = false, + PreferBlockAccess = false, + Layout = TensorEvaluator::Layout, + CoordAccess = true, + RawAccess = false + }; + + //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===// + typedef internal::TensorBlockNotImplemented TensorBlock; + //===--------------------------------------------------------------------===// + + EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) + : impl_(op.expression(), device), padding_(op.padding()) { + EIGEN_STATIC_ASSERT(Dims > 0, YOU_MADE_A_PROGRAMMING_MISTAKE) + + // op.offset() == 0 if padding mode is symmetric. + // op.offset() == 1 if padding mode is reflect. 
+ eigen_assert(op.offset() == 0 || op.offset() == 1); + left_offset_ = -1 + op.offset(); + right_offset_ = -1 - op.offset(); + + // This should trigger compilation error if padding dimensions and + // expression dimensions do not match. + dimensions_ = impl_.dimensions(); + for (int dim = 0; dim < Dims; ++dim) { + eigen_assert(padding_[dim].first + op.offset() <= dimensions_[dim]); + eigen_assert(padding_[dim].second + op.offset() <= dimensions_[dim]); + dimensions_[dim] += padding_[dim].first + padding_[dim].second; + } + + const auto& input_dims = impl_.dimensions(); + if (static_cast(Layout) == static_cast(ColMajor)) { + input_strides_[0] = 1; + output_strides_[0] = 1; + for (int i = 0; i < Dims - 1; ++i) { + input_strides_[i + 1] = input_strides_[i] * input_dims[i]; + output_strides_[i + 1] = output_strides_[i] * dimensions_[i]; + } + } else { + input_strides_[numext::maxi(0, Dims - 1)] = 1; + output_strides_[numext::maxi(0, Dims - 1)] = 1; + for (int i = Dims - 1; i > 0; --i) { + input_strides_[i - 1] = input_strides_[i] * input_dims[i]; + output_strides_[i - 1] = output_strides_[i] * dimensions_[i]; + } + } + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { + return dimensions_; + } + + EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + impl_.evalSubExprsIfNeeded(nullptr); + return true; + } + + EIGEN_STRONG_INLINE void cleanup() { impl_.cleanup(); } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType + coeff(Index index) const { + eigen_assert(index < dimensions().TotalSize()); + const Index input_index = ToInputIndex(index); + return impl_.coeff(input_index); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType + coeff(array coords) const { + for (int dim = 0; dim < Dims; ++dim) { + coords[dim] = ToInputCoord(coords[dim], dim); + } + ReadInputHelper::CoordAccess> helper; + return helper(coords, input_strides_, impl_); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType + packet(Index index) const { + constexpr int kPacketSize = + internal::unpacket_traits::size; + + EIGEN_STATIC_ASSERT(kPacketSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index + kPacketSize <= dimensions().TotalSize()); + + // Find the effective inner-most dimension where padding actually happens. + // NOTE: This is independent of index argument, and can be done in the + // constructor to save computation. However, if packet access does not + // happen, then moving to constructor will incur needless overhead. + int dim = -1; + if (static_cast(Layout) == static_cast(ColMajor)) { + for (int k = 0; k < Dims; ++k) { + if (padding_[k].first != 0 || padding_[k].second != 0) { + dim = k; + break; + } + } + } else { + for (int k = Dims - 1; k >= 0; --k) { + if (padding_[k].first != 0 || padding_[k].second != 0) { + dim = k; + break; + } + } + } + + const Index input_index = ToInputIndex(index); + + // If dim < 0, this means there is no padding at all. + if (dim < 0) { + return impl_.template packet(input_index); + } + + // Check if the way from the begin of the packet to the end of the packet + // is paved with contiguous road. That is, the indices must be between the + // padded region in the effective inner-most dimension. 
+ const Index left = padding_[dim].first * output_strides_[dim]; + const Index right = + (dimensions_[dim] - padding_[dim].second) * output_strides_[dim]; + + const Index index_mod = index % (dimensions_[dim] * output_strides_[dim]); + if (left <= index_mod && (index_mod + kPacketSize - 1) < right) { + return impl_.template packet(input_index); + } + + // If the road is not contiguous, then fall back to coeff(). + EIGEN_ALIGN_MAX std::remove_const_t values[kPacketSize]; + values[0] = impl_.coeff(input_index); + for (int i = 1; i < kPacketSize; ++i) { + values[i] = coeff(index + i); + } + PacketReturnType result = internal::pload(values); + return result; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost + costPerCoeff(bool vectorized) const { + constexpr int kPacketSize = + internal::unpacket_traits::size; + + const double compute_cost = Dims * (7 * TensorOpCost::AddCost() + + 2 * TensorOpCost::MulCost() + + TensorOpCost::DivCost()); + return impl_.costPerCoeff(vectorized) + + TensorOpCost(1, 0, compute_cost, vectorized, kPacketSize); + } + + EIGEN_DEVICE_FUNC Scalar* data() const { return nullptr; } + + protected: + using Coords = array; + + // Full template specialization is not allowed within non-fully specialized + // template class. Adding a dummy parameter to make specializations partial. + template + struct ReadInputHelper; + + template + struct ReadInputHelper { + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index + operator()(const Coords& coord, const Coords& strides, const Eval& eval) { + Index index = 0; + for (int k = 0; k < Dims; ++k) { + index += coord[k] * strides[k]; + } + return eval.coeff(index); + } + }; + + template + struct ReadInputHelper { + template + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index + operator()(const Coords& coord, const Coords& strides, const Eval& eval) { + return eval.coeff(coord); + } + }; + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index ToInputCoord(Index k, + int dim) const { + const Index m = impl_.dimensions()[dim]; + k -= padding_[dim].first; + if (k < 0) { + return -k + left_offset_; + } + if (k < m) { + return k; + } + return m - (k - m) + right_offset_; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index + ToInputIndex(const Coords& coords) const { + Index input_index = 0; + for (int dim = 0; dim < Dims; ++dim) { + input_index += ToInputCoord(coords[dim], dim) * input_strides_[dim]; + } + return input_index; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index ToInputIndex(Index index) const { + Index input_index = 0; + if (static_cast(Layout) == static_cast(ColMajor)) { + for (int dim = Dims - 1; dim > 0; --dim) { + const Index k = index / output_strides_[dim]; + index -= k * output_strides_[dim]; + input_index += ToInputCoord(k, dim) * input_strides_[dim]; + } + input_index += ToInputCoord(index, 0); + } else { + for (int dim = 0; dim < Dims - 1; ++dim) { + const Index k = index / output_strides_[dim]; + index -= k * output_strides_[dim]; + input_index += ToInputCoord(k, dim) * input_strides_[dim]; + } + input_index += ToInputCoord(index, Dims - 1); + } + + return input_index; + } + + TensorEvaluator impl_; + PaddingDimensions padding_; + Dimensions dimensions_; + array input_strides_; + array output_strides_; + + Index left_offset_; + Index right_offset_; +}; +} // namespace Eigen + +namespace tensorflow { +namespace functor { + +// offset argument must be either 0 or 1. This controls whether the boundary +// values are replicated (offset == 0) or not replicated (offset == 1). 
+template +struct MirrorPad { + void operator()(const Device& device, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstMatrix padding, int offset) { + Eigen::array, Dims> padding_dims; + + for (int i = 0; i < Dims; ++i) { + padding_dims[i] = Eigen::IndexPair(padding(i, 0), padding(i, 1)); + } + + output.device(device) = MirrorPadOp(input, padding_dims, offset); + } + + template + static const Eigen::TensorMirrorPadOp + MirrorPadOp( + const Eigen::TensorBase& tensor, + const PaddingDimensions& padding, int offset) { + return Eigen::TensorMirrorPadOp( + static_cast(tensor), padding, offset); + } +}; + +// offset argument must be either 0 or 1. This controls whether the boundary +// values are replicated (offset == 0) or not replicated (offset == 1). +template +struct MirrorPadGrad { + void operator()(const Device& device, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + typename TTypes::ConstMatrix paddings, int offset, + typename TTypes::Tensor scratch) { + // Copy the gradient input into the scratch buffer. + scratch.device(device) = input; + + Eigen::array lhs_offsets; + Eigen::array rhs_offsets; + Eigen::array extents; + Eigen::array reverses; + + for (int i = 0; i < Dims; ++i) { + lhs_offsets[i] = 0; + rhs_offsets[i] = 0; + extents[i] = scratch.dimension(i); + reverses[i] = false; + } + + // At this point, the central part (non-padded area) does not include the + // gradients back-propagated through padded areas. Those gradient components + // need be added to the central part. + // + // Note that a gradient input element falls into a padded area iff in at + // least one dimension i, the coordinate x(i) is in the range (python-style) + // [:paddings(i,0)] or [-paddings(i,1):]. + + for (int i = 0; i < Dims; ++i) { + reverses[i] = true; + + // This handles the case when coordinate in dimension i is in the range + // [:paddings(i,0)]. This portion is added to the range + // [paddings(i,0) + offset:2 * paddings(i,0) + offset]. + if (paddings(i, 0) > 0) { + rhs_offsets[i] = 0; + lhs_offsets[i] = paddings(i, 0) + offset; + extents[i] = paddings(i, 0); + + scratch.slice(lhs_offsets, extents).device(device) += + scratch.slice(rhs_offsets, extents).reverse(reverses); + } + + // This handles the case when coordinate in dimension i is in the range + // [-paddings(i,1):]. This portion is added to the range + // [-2 * paddings(i,1) - offset:-paddings(i,1) - offset]. + if (paddings(i, 1) > 0) { + rhs_offsets[i] = scratch.dimension(i) - paddings(i, 1); + lhs_offsets[i] = rhs_offsets[i] - paddings(i, 1) - offset; + extents[i] = paddings(i, 1); + + scratch.slice(lhs_offsets, extents).device(device) += + scratch.slice(rhs_offsets, extents).reverse(reverses); + } + + reverses[i] = false; + lhs_offsets[i] = paddings(i, 0); + rhs_offsets[i] = paddings(i, 0); + extents[i] = output.dimension(i); + + // At this point, scratch buffer contains gradient input as if paddings + // for dimension k = 0,...,i are zeros. Therefore after the loop + // termination, the central part of the scratch buffer contains the folded + // gradients. + } + + // Copy the central part of the scratch buffer to the output. 
+ output.device(device) = scratch.slice(rhs_offsets, extents); + } +}; +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_MIRROR_PAD_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/mirror_pad_op_cpu_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/mirror_pad_op_cpu_impl.h new file mode 100644 index 00000000..b138ae0c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/mirror_pad_op_cpu_impl.h @@ -0,0 +1,47 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_MIRROR_PAD_OP_CPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_MIRROR_PAD_OP_CPU_IMPL_H_ + +#if CPU_PROVIDED_IXDIM +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/image/mirror_pad_op.h" + +namespace tensorflow { + +using CpuDevice = Eigen::ThreadPoolDevice; + +#define DEFINE_CPU_SPECS(T) \ + template struct functor::MirrorPad; \ + template struct functor::MirrorPad; +TF_CALL_POD_TYPES(DEFINE_CPU_SPECS); +TF_CALL_QUANTIZED_TYPES(DEFINE_CPU_SPECS); +TF_CALL_tstring(DEFINE_CPU_SPECS); +#undef DEFINE_CPU_SPECS + +#define DEFINE_CPU_SPECS(T) \ + template struct functor::MirrorPadGrad; \ + template struct functor::MirrorPadGrad; +TF_CALL_NUMBER_TYPES(DEFINE_CPU_SPECS); +#undef DEFINE_CPU_SPECS +} // namespace tensorflow + +#endif // CPU_PROVIDED_IXDIM +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_MIRROR_PAD_OP_CPU_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/non_max_suppression_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/non_max_suppression_op.h new file mode 100644 index 00000000..04828b07 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/non_max_suppression_op.h @@ -0,0 +1,50 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
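The MirrorPad/MirrorPadGrad functors above rely on the offset convention noted in their comments: offset 0 repeats the boundary value (SYMMETRIC), offset 1 does not (REFLECT). A 1-D sketch of that index mapping, written to match ToInputCoord earlier in the header; names are illustrative:

#include <iostream>
#include <vector>

// Maps an output coordinate of a 1-D mirror pad back to an input coordinate.
// offset == 0 -> SYMMETRIC (edge value repeated), offset == 1 -> REFLECT.
int MirrorPadInputCoord(int out_coord, int pad_before, int len, int offset) {
  const int k = out_coord - pad_before;
  if (k < 0) return -k - 1 + offset;         // left padded region
  if (k < len) return k;                     // interior
  return len - (k - len) - 1 - offset;       // right padded region
}

int main() {
  const std::vector<int> input = {1, 2, 3, 4};
  // Pad 2 on each side. SYMMETRIC prints 2 1 1 2 3 4 4 3,
  // REFLECT prints 3 2 1 2 3 4 3 2.
  for (int offset : {0, 1}) {
    for (int i = 0; i < 8; ++i) {
      std::cout << input[MirrorPadInputCoord(i, 2, 4, offset)] << " ";
    }
    std::cout << "\n";
  }
}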
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_NON_MAX_SUPPRESSION_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_NON_MAX_SUPPRESSION_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +extern const int kNmsBoxesPerTread; + +// Given descending sorted box list, apply non-maximal-suppression with given +// threshold and select boxes to keep. +// - d_sorted_boxes_float_ptr: a pointer to device memory float array +// containing the box corners for N boxes sorted in descending order of +// scores. +// - num_boxes: number of boxes. +// - iou_threshold: the intersection-over-union (iou) threshold for elimination. +// - d_selected_indices: is a device pointer to int array containing sorted +// indices of the boxes to keep. +// - h_num_boxes_to_keep: is a host pointer for returning number of items +// to keep. +// - flip_boxes: flag reorders the boxes use lower left and upper right +// corners if they are given in mixed format. +Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, + const float iou_threshold, int* d_selected_indices, + int* h_num_boxes_to_keep, OpKernelContext* context, + const int max_boxes, bool flip_boxes = false); +#endif + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_NON_MAX_SUPPRESSION_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/resize_bilinear_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/resize_bilinear_op.h new file mode 100644 index 00000000..1a304c2c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/resize_bilinear_op.h @@ -0,0 +1,46 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
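NmsGpu above only declares the GPU entry point. The algorithm it describes is the usual greedy pass over score-sorted boxes, discarding any box whose IoU with an already-kept box exceeds the threshold. A plain CPU sketch of that idea, with an assumed box layout and helper names:

#include <algorithm>
#include <iostream>
#include <vector>

struct Box { float y1, x1, y2, x2; };  // corners, y1 <= y2 and x1 <= x2

float Iou(const Box& a, const Box& b) {
  const float iy1 = std::max(a.y1, b.y1), ix1 = std::max(a.x1, b.x1);
  const float iy2 = std::min(a.y2, b.y2), ix2 = std::min(a.x2, b.x2);
  const float inter = std::max(0.f, iy2 - iy1) * std::max(0.f, ix2 - ix1);
  const float area_a = (a.y2 - a.y1) * (a.x2 - a.x1);
  const float area_b = (b.y2 - b.y1) * (b.x2 - b.x1);
  return inter / (area_a + area_b - inter);
}

// Greedy non-max suppression over boxes already sorted by descending score.
std::vector<int> GreedyNms(const std::vector<Box>& sorted_boxes,
                           float iou_threshold, int max_boxes) {
  std::vector<int> keep;
  for (int i = 0; i < static_cast<int>(sorted_boxes.size()) &&
                  static_cast<int>(keep.size()) < max_boxes; ++i) {
    bool suppressed = false;
    for (int j : keep) {
      if (Iou(sorted_boxes[i], sorted_boxes[j]) > iou_threshold) {
        suppressed = true;
        break;
      }
    }
    if (!suppressed) keep.push_back(i);
  }
  return keep;
}

int main() {
  const std::vector<Box> boxes = {{0, 0, 10, 10}, {1, 1, 11, 11}, {20, 20, 30, 30}};
  for (int i : GreedyNms(boxes, 0.5f, 10)) std::cout << i << " ";  // 0 2
  std::cout << "\n";
}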
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct ResizeBilinear { + void operator()(const Device& d, typename TTypes::ConstTensor images, + const float height_scale, const float width_scale, + const bool half_pixel_centers, + typename TTypes::Tensor resized_images); +}; + +template +struct ResizeBilinearGrad { + void operator()(const Device& d, + typename TTypes::ConstTensor input_grad, + const float height_scale, const float width_scale, + const bool half_pixel_centers, + typename TTypes::Tensor output_grad); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_RESIZE_BILINEAR_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/resize_nearest_neighbor_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/resize_nearest_neighbor_op.h new file mode 100644 index 00000000..e6797dfb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/resize_nearest_neighbor_op.h @@ -0,0 +1,45 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_RESIZE_NEAREST_NEIGHBOR_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_RESIZE_NEAREST_NEIGHBOR_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +template +struct ResizeNearestNeighbor { + bool operator()(const Device& d, typename TTypes::ConstTensor input, + const float height_scale, const float width_scale, + typename TTypes::Tensor output); +}; + +template +struct ResizeNearestNeighborGrad { + bool operator()(const Device& d, + typename TTypes::ConstTensor input_grad, + const float height_scale, const float width_scale, + typename TTypes::Tensor output_grad); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_RESIZE_NEAREST_NEIGHBOR_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/sampling_kernels.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/sampling_kernels.h new file mode 100644 index 00000000..6f889add --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/sampling_kernels.h @@ -0,0 +1,192 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
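As a rough illustration of what the ResizeBilinear functor above computes per output pixel, the sketch below samples one channel of a row-major H x W image with bilinear weights. It is a standalone example, not the TensorFlow kernel: SampleBilinear is an invented name, and the half_pixel_centers mapping shown is the commonly used (out + 0.5) * scale - 0.5 convention, assumed here for illustration.

#include <algorithm>
#include <cmath>
#include <vector>

float SampleBilinear(const std::vector<float>& img, int height, int width,
                     float out_y, float out_x, float height_scale,
                     float width_scale, bool half_pixel_centers) {
  const float in_y = half_pixel_centers ? (out_y + 0.5f) * height_scale - 0.5f
                                        : out_y * height_scale;
  const float in_x = half_pixel_centers ? (out_x + 0.5f) * width_scale - 0.5f
                                        : out_x * width_scale;
  const int y0 = std::clamp(static_cast<int>(std::floor(in_y)), 0, height - 1);
  const int x0 = std::clamp(static_cast<int>(std::floor(in_x)), 0, width - 1);
  const int y1 = std::min(y0 + 1, height - 1);
  const int x1 = std::min(x0 + 1, width - 1);
  const float dy = std::clamp(in_y - y0, 0.0f, 1.0f);
  const float dx = std::clamp(in_x - x0, 0.0f, 1.0f);
  const float top = img[y0 * width + x0] * (1 - dx) + img[y0 * width + x1] * dx;
  const float bot = img[y1 * width + x0] * (1 - dx) + img[y1 * width + x1] * dx;
  return top * (1 - dy) + bot * dy;
}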
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_SAMPLING_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_SAMPLING_KERNELS_H_ + +#include + +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace functor { +// Defines functions for different types of sampling kernels. +enum SamplingKernelType { + // Lanczos kernel with radius 1. Aliases but does not ring. + Lanczos1Kernel, + + // Lanczos kernel with radius 3. High-quality practical filter but may have + // some ringing especially on synthetic images. + Lanczos3Kernel, + + // Lanczos kernel with radius 5. Very-high-quality filter but may have + // stronger ringing. + Lanczos5Kernel, + + // Gaussian kernel with radius 3, sigma = 1.5 / 3. Less commonly used. + GaussianKernel, + + // Rectangle function. Equivalent to "nearest" sampling when upscaling. + // Has value 1 in interval (-0.5, 0.5), value 0.5 on edge, and 0 elsewhere. + BoxKernel, + + // Hat/tent function with radius 1. Equivalent to "bilinear" reconstruction + // when upsampling. + // Has value zero at -1.0 and 1.0. + TriangleKernel, + + // Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably + // good quality and faster than Lanczos3Kernel. + KeysCubicKernel, + + // Cubic non-interpolating scheme. For synthetic images (especially those + // lacking proper prefiltering), less ringing than Keys cubic kernel but less + // sharp. + MitchellCubicKernel, + + // Always insert new kernel types before this. + SamplingKernelTypeEnd +}; + +// Converts a string into the corresponding kernel type. +// Returns SamplingKernelTypeEnd if the string couldn't be converted. +SamplingKernelType SamplingKernelTypeFromString(const absl::string_view str); + +// A function object for a Lanczos kernel. +struct LanczosKernelFunc { + // Pass 1 for Lanczos1 kernel, 3 for Lanczos3 etc. + explicit LanczosKernelFunc(float _radius) : radius(_radius) {} + float operator()(float x) const { + constexpr float kPI = 3.14159265359; + x = std::abs(x); + if (x > radius) return 0.0; + // Need to special case the limit case of sin(x) / x when x is zero. + if (x <= 1e-3) { + return 1.0; + } + return radius * std::sin(kPI * x) * std::sin(kPI * x / radius) / + (kPI * kPI * x * x); + } + float Radius() const { return radius; } + const float radius; +}; + +struct GaussianKernelFunc { + static constexpr float kRadiusMultiplier = 3.0f; + // https://en.wikipedia.org/wiki/Gaussian_function + // We use sigma = 0.5, as suggested on p. 
4 of Ken Turkowski's "Filters + // for Common Resampling Tasks" for kernels with a support of 3 pixels: + // www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf + // This implies a radius of 1.5, + explicit GaussianKernelFunc(float _radius = 1.5f) + : radius(_radius), sigma(_radius / kRadiusMultiplier) {} + float operator()(float x) const { + x = std::abs(x); + if (x >= radius) return 0.0; + return std::exp(-x * x / (2.0 * sigma * sigma)); + } + float Radius() const { return radius; } + const float radius; + const float sigma; // Gaussian standard deviation +}; + +struct BoxKernelFunc { + float operator()(float x) const { + x = std::abs(x); + return x < 0.5f ? 1. : x == 0.5f ? 0.5f : 0.0f; + } + float Radius() const { return 1.f; } +}; + +struct TriangleKernelFunc { + // https://en.wikipedia.org/wiki/Triangle_function + float operator()(float x) const { + x = std::abs(x); + return x < 1.0f ? 1.0f - x : 0.0f; + } + float Radius() const { return 1.f; } +}; + +struct KeysCubicKernelFunc { + // http://ieeexplore.ieee.org/document/1163711/ + // R. G. Keys. Cubic convolution interpolation for digital image + // processing. IEEE Transactions on Acoustics, Speech, and Signal + // Processing, 29(6):1153–1160, 1981. + float operator()(float x) const { + x = std::abs(x); + if (x >= 2.0f) { + return 0.0f; + } else if (x >= 1.0f) { + return ((-0.5f * x + 2.5f) * x - 4.0f) * x + 2.0f; + } else { + return ((1.5f * x - 2.5f) * x) * x + 1.0f; + } + } + float Radius() const { return 2.f; } +}; + +struct MitchellCubicKernelFunc { + // https://doi.org/10.1145/378456.378514 + // D. P. Mitchell and A. N. Netravali. Reconstruction filters in computer + // graphics. Computer Graphics (Proceedings of ACM SIGGRAPH 1988), + // 22(4):221–228, 1988. + float operator()(float x) const { + x = std::abs(x); + if (x >= 2.0f) { + return 0.0f; + } else if (x >= 1.0f) { + return (((-7.0f / 18.0f) * x + 2.0f) * x - 10.0f / 3.0f) * x + + 16.0f / 9.0f; + } else { + return (((7.0f / 6.0f) * x - 2.0f) * x) * x + 8.0f / 9.0f; + } + } + float Radius() const { return 2.f; } +}; + +inline LanczosKernelFunc CreateLanczos1Kernel() { + return LanczosKernelFunc(1.0); +} + +inline LanczosKernelFunc CreateLanczos3Kernel() { + return LanczosKernelFunc(3.0); +} + +inline LanczosKernelFunc CreateLanczos5Kernel() { + return LanczosKernelFunc(5.0); +} + +inline GaussianKernelFunc CreateGaussianKernel() { + return GaussianKernelFunc(1.5); +} + +inline BoxKernelFunc CreateBoxKernel() { return BoxKernelFunc(); } + +inline TriangleKernelFunc CreateTriangleKernel() { + return TriangleKernelFunc(); +} + +inline KeysCubicKernelFunc CreateKeysCubicKernel() { + return KeysCubicKernelFunc(); +} + +inline MitchellCubicKernelFunc CreateMitchellCubicKernel() { + return MitchellCubicKernelFunc(); +} + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_SAMPLING_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/image/scale_and_translate_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/image/scale_and_translate_op.h new file mode 100644 index 00000000..672cc2a8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/image/scale_and_translate_op.h @@ -0,0 +1,76 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
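The kernel functors above all expose the same tiny interface, a weight function operator()(x) and a Radius(); downstream code (for example the span precomputation in scale_and_translate_op.h, which follows) evaluates the kernel at every input tap within the radius of a sample point and normalizes the weights. The sketch below shows that evaluation for one sample point in isolation; KernelWeights and the 1-D setting are illustrative assumptions, not TensorFlow code.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

// Returns the first contributing input index and the normalized weights for a
// sample centred at `center` (in input pixel coordinates, half-pixel origin).
template <typename Kernel>
std::pair<int64_t, std::vector<float>> KernelWeights(const Kernel& kernel,
                                                     int64_t input_size,
                                                     float center) {
  const float radius = kernel.Radius();
  const int64_t start = std::max<int64_t>(
      0, static_cast<int64_t>(std::ceil(center - radius - 0.5f)));
  const int64_t end = std::min<int64_t>(
      input_size - 1, static_cast<int64_t>(std::floor(center + radius - 0.5f)));
  std::vector<float> weights;
  float total = 0.0f;
  for (int64_t i = start; i <= end; ++i) {
    const float w = kernel(static_cast<float>(i) + 0.5f - center);
    weights.push_back(w);
    total += w;
  }
  if (total > 0.0f)
    for (float& w : weights) w /= total;
  return {start, weights};
}

// Example: KernelWeights(TriangleKernelFunc(), 10, 3.0f) yields start = 2 and
// weights {0.5, 0.5}, i.e. ordinary bilinear weights between pixels 2 and 3.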
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_SCALE_AND_TRANSLATE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMAGE_SCALE_AND_TRANSLATE_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/image/sampling_kernels.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace functor { + +// The scale and translate op works by scaling and translating the row and +// column dimensions separately. +// When scaling and translating the rows the set of input pixels and kernel +// weights used to compute a given output pixel within a row is constant across +// rows and can thus be precomputed and reused for every row. Similarly for the +// columns. This precomputed data structure is called a 'span'. + +// To compute the gradient we use the spans computed on the forward pass and +// essentially reverse them: we record for each input pixel which output +// pixels it contributes to. This means that the forward and backward passes +// use the same core algorithm, only the spans are computed differently. + +// A pre-computed span of pixels along a single dimension. +// The output pixel will be the weighted sum of pixels starting from start. +struct Spans { + // The maximum span size of any output pixel. + int span_size; + // int32 tensor of size [output_dim]. + Tensor starts; + // float tensor of size [output_dim, span_size]. + // The output pixel at x is computed as: + // dot_product(input[starts[x]:starts[x]+span_size], weights[x]). + Tensor weights; +}; + +// Gather spans in both dimensions. +// row_span_size, row_starts and row_weights correspond to the variables in +// the row Spans data structure, similarly for col_span_size etc. +// intermediate_buffer is a Tensor used to store the result of the +// resize in the column dimension and is of size: +// [batch_size, input_height, output_width, channels] +template +struct GatherSpans { + void operator()(OpKernelContext* context, const Device& d, int row_span_size, + typename TTypes::ConstTensor row_starts, + typename TTypes::ConstTensor row_weights, + int col_span_size, + typename TTypes::ConstTensor col_starts, + typename TTypes::ConstTensor col_weights, + typename TTypes::ConstTensor input_images, + typename TTypes::Tensor intermediate_buffer, + typename TTypes::Tensor output_images); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMAGE_SCALE_AND_TRANSLATE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/immutable_constant_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/immutable_constant_op.h new file mode 100644 index 00000000..264abc84 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/immutable_constant_op.h @@ -0,0 +1,50 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
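Once starts and weights have been precomputed, applying a span is just a sliding dot product per output pixel, exactly as the Spans comment above describes. The sketch below applies precomputed spans to a 1-D signal; ApplySpans and the flattened std::vector representation are illustrative, not the GatherSpans functor.

#include <cstdint>
#include <vector>

// starts[x] is the first contributing input index for output x; weights is a
// flattened [output_size, span_size] array, as documented for Spans above.
std::vector<float> ApplySpans(const std::vector<float>& input,
                              const std::vector<int32_t>& starts,
                              const std::vector<float>& weights,
                              int64_t span_size) {
  const int64_t output_size = static_cast<int64_t>(starts.size());
  std::vector<float> output(output_size, 0.0f);
  for (int64_t x = 0; x < output_size; ++x) {
    for (int64_t k = 0; k < span_size; ++k) {
      const int64_t in_idx = starts[x] + k;
      if (in_idx < static_cast<int64_t>(input.size()))
        output[x] += weights[x * span_size + k] * input[in_idx];
    }
  }
  return output;
}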
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class ImmutableConstantOp : public OpKernel { + public: + explicit ImmutableConstantOp(OpKernelConstruction* context); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + ~ImmutableConstantOp() override; + + // Names of attributes that are used by this op + static constexpr char const* kDTypeAttr = "dtype"; + static constexpr char const* kShapeAttr = "shape"; + static constexpr char const* kMemoryRegionNameAttr = "memory_region_name"; + + private: + string region_name_; + DataType dtype_; + TensorShape shape_; + ImmutableConstantOp(const ImmutableConstantOp&) = delete; + void operator=(const ImmutableConstantOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/in_topk_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/in_topk_op.h new file mode 100644 index 00000000..87777764 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/in_topk_op.h @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_ +#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// InTopK argument can be passed either via mode attribute (InTopK op), or as an +// input tensor (InTopKV2 op). 
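Before the functor itself, it may help to see the predicate it evaluates written without the TensorFlow types. The sketch below is standalone and illustrative (InTopKRow is an invented name): a target is in the top k iff fewer than k classes score strictly higher than the target class, and any non-finite prediction in the row forces the answer to false.

#include <cmath>
#include <vector>

bool InTopKRow(const std::vector<float>& row, int target, int k) {
  if (target < 0 || target >= static_cast<int>(row.size())) return false;
  if (!std::isfinite(row[target])) return false;
  int more_probable = 0;
  for (float p : row) {
    if (!std::isfinite(p)) return false;  // "cannot say" resolves to false
    if (p > row[target]) ++more_probable;
  }
  return more_probable < k;
}

// Example: row = {0.1, 0.8, 0.05, 0.05}, target = 0, k = 2.
// Only class 1 scores higher than the target, 1 < 2, so the result is true.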
+struct TopKArg { + int64_t k_value = -1; + const Tensor* k_tensor = nullptr; +}; + +template +struct InTopKFunctor { + template + using Dims = Eigen::DSizes; + + void operator()(OpKernelContext* context, + typename TTypes::ConstTensor predictions, + typename TTypes::ConstVec targets, const TopKArg k, + typename TTypes::Vec output) {} +}; + +template +struct InTopKFunctor { + void operator()(OpKernelContext* context, + typename TTypes::ConstTensor predictions, + typename TTypes::ConstVec targets, const TopKArg k, + typename TTypes::Vec output) { + const Eigen::Index num_targets = predictions.dimension(0); + const Eigen::Index num_classes = predictions.dimension(1); + + int64_t k_val = k.k_value; + if (k.k_tensor != nullptr) { + if (k.k_tensor->dtype() == DT_INT32) { + k_val = k.k_tensor->scalar()(); + } else { + k_val = k.k_tensor->scalar()(); + } + } + + for (int batch_idx = 0; batch_idx < num_targets; batch_idx++) { + auto target = internal::SubtleMustCopy(targets(batch_idx)); + + bool cannot_say = !FastBoundsCheck(target, num_classes) || + !std::isfinite(predictions(batch_idx, target)); + + int more_probable_classes = 0; + if (!cannot_say) { + const T target_prediction = predictions(batch_idx, target); + + for (int class_idx = 0; class_idx < num_classes; ++class_idx) { + T pred = predictions(batch_idx, class_idx); + if (!std::isfinite(pred)) { + cannot_say = true; + break; + } else if (pred > target_prediction) { + ++more_probable_classes; + if (more_probable_classes > k_val) break; + } + } + } + output(batch_idx) = cannot_say ? false : (more_probable_classes < k_val); + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/initializable_lookup_table.h b/third_party/tflite-hdrs/tensorflow/core/kernels/initializable_lookup_table.h new file mode 100644 index 00000000..c190fbd3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/initializable_lookup_table.h @@ -0,0 +1,271 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ +#define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ + +#include + +#include "tensorflow/core/framework/lookup_interface.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace lookup { + +// Base class for lookup tables that require initialization. +class InitializableLookupTable : public LookupInterface { + public: + class InitTableIterator; + class InitializerSerializer; + + // Performs batch lookups, for every element in the key tensor, Find returns + // the corresponding value into the values tensor. + // If an element is not present in the table, the given default value is used. + // + // For tables that require initialization, `Find` is available once the table + // is marked as initialized. 
+ // + // Returns the following statuses: + // - OK: when the find finishes successfully. + // - FailedPrecondition: if the table is not initialized. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + absl::Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, + const Tensor& default_value) final; + + // Returns errors::Unimplemented. + absl::Status Insert(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) final { + return errors::Unimplemented( + "Insert not supported by InitializableLookupTable implementations"); + } + + // Returns errors::Unimplemented. + absl::Status Remove(OpKernelContext* ctx, const Tensor& keys) final { + return errors::Unimplemented( + "Remove not supported by InitializableLookupTable implementations"); + } + + absl::Status ExportValues(OpKernelContext* context) override { + return errors::Unimplemented( + "ExportValues not supported by InitializableLookupTable " + "implementations"); + } + + absl::Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) final; + + TensorShape key_shape() const final { return TensorShape(); } + + TensorShape value_shape() const final { return TensorShape(); } + + // Returns whether the table was initialized and is ready to serve lookups. + bool is_initialized() const { + return is_initialized_.load(std::memory_order_acquire); + } + + // Initializes the table from the given init table iterator. + // + // Atomically, this operation prepares the table, populates it with the given + // iterator, and marks the table as initialized. + // + // Returns the following statuses: + // - OK: when the initialization was successful. + // - InvalidArgument: if any of the preconditions on the lookup key or value + // fails. + // - FailedPrecondition: if the table is already initialized and + // fail_if_initialized is set to true. + // - In addition, other implementations may provide another non-OK status + // specific to their failure modes. + absl::Status Initialize(InitTableIterator& iter); + + // Initializes the table from the given init table iterator. `serializer` may + // specify how to serialize the table initializer, so that the table can be + // serialized using its metadata (as opposed to serializing a handle to the + // table). + absl::Status Initialize(InitTableIterator& iter, + std::unique_ptr serializer); + + // Basic iterator to initialize lookup tables. + // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that + // the consumer may insert key-value pairs in batches. + // + // Then the iterator is exhausted, valid returns false and status returns + // Status::OutOfRange. + // + // This class is Thread-unsafe. + class InitTableIterator { + public: + InitTableIterator() {} + + virtual ~InitTableIterator() {} + + // Prepares the next batch of key and value tensors. + virtual void Next() = 0; + + // Returns true if keys and values point to valid tensors. + virtual bool Valid() const = 0; + + // Returns a tensor that contains the current batch of 'key' values. + virtual const Tensor& keys() const = 0; + + // Returns a tensor that contains the current batch of 'value' values. + virtual const Tensor& values() const = 0; + + // Returns an error if one has occurred, otherwise returns Status::OK. 
+ virtual absl::Status status() const = 0; + + // Returns the total number of elements that the iterator will produce. + // It might return -1 in case of error. + virtual int64_t total_size() const = 0; + + private: + InitTableIterator(const InitTableIterator&) = delete; + void operator=(const InitTableIterator&) = delete; + }; + + InitializableLookupTable* GetInitializableLookupTable() override { + return this; + } + + // Logic specifying how to represent an initializer as a GraphDef, so that a + // lookup table can be serialized using its metadata (as opposed to + // serializing the content of the table, or a handle to the table). + class InitializerSerializer { + public: + // A function which builds a graph so that executing `*out` will initialize + // `table`. + using SerializeFn = std::function; + // A function which performs any necessary cleanup for the serializer. + using CleanupFn = std::function; + + // Wraps serialization logic that requires no cleanup. + explicit InitializerSerializer(SerializeFn serialize) + : serialize_(std::move(serialize)), cleanup_([] {}) {} + + // Wraps serialization logic along with a cleanup function. `cleanup` will + // be run when the serializer is destroyed. + explicit InitializerSerializer(SerializeFn serialize, CleanupFn cleanup) + : serialize_(std::move(serialize)), cleanup_(std::move(cleanup)) {} + + ~InitializerSerializer() { cleanup_(); } + + // Builds a graph so that executing `*out` will initialize `table`. + absl::Status AsGraphDef(GraphDefBuilder* builder, Node* table, Node** out) { + return serialize_(builder, table, out); + } + + private: + SerializeFn serialize_; + CleanupFn cleanup_; + }; + + protected: + // Prepares and allocates the underlying data structure to store the given + // number of expected elements. + virtual absl::Status DoPrepare(size_t expected_num_elements) = 0; + + // Same as DoPrepare() but derived implementations might choose to skip + // calling get_expected_num_elements if size is not needed for DoPrepare. + virtual absl::Status DoLazyPrepare( + std::function get_expected_num_elements) { + int64_t expected_num_elements = get_expected_num_elements(); + if (expected_num_elements < 0) { + return errors::FailedPrecondition("Got negative expected_num_elements."); + } + return DoPrepare(expected_num_elements); + } + + // Populates the table in batches given keys and values as tensors into the + // underlying data structure. + virtual absl::Status DoInsert(const Tensor& keys, const Tensor& values) = 0; + + // Performs the batch find operation on the underlying data structure. + virtual absl::Status DoFind(const Tensor& keys, Tensor* values, + const Tensor& default_value) = 0; + + virtual absl::Status AreEntriesSame(const InitTableIterator& iter, + bool* result); + + mutex mu_; + + protected: + // When set, provides a mechanism for serializing the table initializer as + // GraphDef. + std::unique_ptr initializer_serializer_; + + private: + std::atomic is_initialized_{false}; +}; + +// Iterator to initialize tables given 'keys' and 'values' tensors. +// +// The two tensors are returned in the first iteration. It doesn't loop +// over each element of the tensor since insertions in the lookup table can +// process batches. +class KeyValueTensorIterator + : public InitializableLookupTable::InitTableIterator { + public: + // keys and values are not owned by the iterator. 
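The Initialize/InitTableIterator contract above boils down to a simple driver loop: prepare for iter.total_size() elements, insert batches while Valid() is true, and treat OutOfRange as successful exhaustion. The sketch below shows that loop against a deliberately simplified iterator interface (BatchIterator, Batch and the bool-based error handling are assumptions made for this example); it is illustrative, not the TensorFlow implementation.

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

struct Batch { std::vector<std::string> keys; std::vector<int64_t> values; };

// Simplified stand-ins for InitTableIterator and the DoPrepare/DoInsert hooks.
struct BatchIterator {
  virtual ~BatchIterator() = default;
  virtual bool Valid() const = 0;
  virtual const Batch& batch() const = 0;
  virtual void Next() = 0;
  virtual bool out_of_range() const = 0;  // true once exhausted without error
  virtual int64_t total_size() const = 0;
};

bool InitializeFromIterator(BatchIterator& iter,
                            const std::function<void(int64_t)>& prepare,
                            const std::function<void(const Batch&)>& insert) {
  prepare(iter.total_size());   // mirrors DoPrepare / DoLazyPrepare
  while (iter.Valid()) {
    insert(iter.batch());       // batched insertion, mirrors DoInsert
    iter.Next();
  }
  // A genuine error and normal exhaustion are distinguished by the status:
  // only OutOfRange (modelled here as out_of_range()) counts as success.
  return iter.out_of_range();
}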
+ explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values) + : keys_(keys), values_(values), valid_(true), status_(absl::OkStatus()) { + TensorShape key_shape = keys_->shape(); + if (!key_shape.IsSameSize(values_->shape())) { + valid_ = false; + status_ = errors::InvalidArgument( + "keys and values should have the same dimension.", + key_shape.DebugString(), " vs ", values_->shape().DebugString()); + } + if (key_shape.num_elements() == 0) { + valid_ = false; + status_ = + errors::InvalidArgument("keys and values cannot be empty tensors."); + } + } + + bool Valid() const override { return valid_; } + + void Next() override { + valid_ = false; + status_ = errors::OutOfRange("No more data."); + } + + const Tensor& keys() const override { return *keys_; } + + const Tensor& values() const override { return *values_; } + + absl::Status status() const override { return status_; } + + int64_t total_size() const override { + return keys_ == nullptr ? -1 : keys_->NumElements(); + } + + private: + KeyValueTensorIterator(const KeyValueTensorIterator&) = delete; + void operator=(const KeyValueTensorIterator&) = delete; + + const Tensor* keys_; // Doesn't own it. + const Tensor* values_; // Doesn't own it. + bool valid_; // true if the iterator points to an existing range. + absl::Status status_; +}; + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/inplace_ops_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/inplace_ops_functor.h new file mode 100644 index 00000000..e1707824 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/inplace_ops_functor.h @@ -0,0 +1,49 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace functor { + +template +absl::Status DoParallelConcat(const Device& device, const Tensor& value, + int32_t loc, Tensor* output); + +// Inplace update/add/sub values in 'y'. It computes +// y[i, :] = v if op is I_UPDATE +// y[i, :] += v if op is I_ADD +// y[i, :] -= v if op is I_SUB +// Returns an error if the operation fails. +enum InplaceOpType { + I_UPDATE, // x = y + I_ADD, // x += y + I_SUB, // x -= y +}; +template +absl::Status DoInplace(const Device& device, InplaceOpType op, const Tensor& i, + const Tensor& v, Tensor* y); +// Copies x into y. 
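The row-update semantics documented above for DoInplace (y[i, :] = v, += v or -= v) reduce to a per-row loop over a dense buffer. The sketch below is plain C++, not the TensorFlow functor; ApplyInplace and the row-major std::vector layout are illustrative assumptions.

#include <cstddef>
#include <cstdint>
#include <vector>

enum class InplaceOp { kUpdate, kAdd, kSub };

// `rows[r]` selects the destination row of y; v's r-th row supplies the data.
// Both y and v are dense row-major matrices with `cols` columns.
void ApplyInplace(InplaceOp op, const std::vector<int64_t>& rows,
                  const std::vector<float>& v, std::vector<float>& y,
                  int64_t cols) {
  for (size_t r = 0; r < rows.size(); ++r) {
    float* dst = y.data() + rows[r] * cols;        // y[i, :]
    const float* src = v.data() + static_cast<int64_t>(r) * cols;  // v[r, :]
    for (int64_t c = 0; c < cols; ++c) {
      switch (op) {
        case InplaceOp::kUpdate: dst[c] = src[c]; break;
        case InplaceOp::kAdd:    dst[c] += src[c]; break;
        case InplaceOp::kSub:    dst[c] -= src[c]; break;
      }
    }
  }
}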
+template +absl::Status DoCopy(const Device& device, const Tensor& x, Tensor* y); + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/kernel_platform_strings.h b/third_party/tflite-hdrs/tensorflow/core/kernels/kernel_platform_strings.h new file mode 100644 index 00000000..9bf40c30 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/kernel_platform_strings.h @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Generate platform strings for libtfkernel-* + +#ifndef TENSORFLOW_CORE_KERNELS_KERNEL_PLATFORM_STRINGS_H_ +#define TENSORFLOW_CORE_KERNELS_KERNEL_PLATFORM_STRINGS_H_ + +#include "tensorflow/core/platform/platform_strings.h" + +TF_PLATFORM_STRINGS() + +#endif // TENSORFLOW_CORE_KERNELS_KERNEL_PLATFORM_STRINGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/l2loss_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/l2loss_op.h new file mode 100644 index 00000000..2adaacbb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/l2loss_op.h @@ -0,0 +1,33 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_ +#define TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_ +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +template +struct L2LossOp : public OpKernel { + explicit L2LossOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/determinant_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/determinant_op.h new file mode 100644 index 00000000..6ace1bef --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/determinant_op.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Helper functor to compute Determinant from a partially pivoted LU +// factorization. +template +struct DeterminantFromPivotedLUFunctor { + void operator()(const Device& device, + typename TTypes::ConstTensor lu_factor, + const int* pivots, typename TTypes::Tensor output, + int* info); +}; + +// Helper functor to compute sign and log of the absolute value of the +// determinant from a partially pivoted LU factorization. +template +struct LogDeterminantFromPivotedLUFunctor { + void operator()(const Device& device, + typename TTypes::ConstTensor lu_factor, + const int* pivots, typename TTypes::Tensor sign, + typename TTypes::Tensor log_abs_det); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/eig_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/eig_op_impl.h new file mode 100644 index 00000000..220e6db5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/eig_op_impl.h @@ -0,0 +1,100 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_ + +// See docs in ../ops/linalg_ops.cc. 
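The computation behind DeterminantFromPivotedLUFunctor above is the classic identity: after a partially pivoted factorization P*A = L*U, det(A) equals the product of U's diagonal times (-1) raised to the number of row swaps. The standalone sketch below assumes LAPACK-style 1-based pivot indices and a row-major n*n buffer; it is not the TensorFlow functor.

#include <vector>

double DeterminantFromPivotedLU(const std::vector<double>& lu,  // row-major n*n
                                const std::vector<int>& pivots, int n) {
  double det = 1.0;
  for (int i = 0; i < n; ++i) {
    det *= lu[i * n + i];                 // diagonal of U
    if (pivots[i] != i + 1) det = -det;   // each off-diagonal pivot flips sign
  }
  return det;
}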
+ +#include "Eigen/Core" // from @eigen_archive +#include "Eigen/Eigenvalues" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/linalg/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/denormal.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +template +class EigOp : public LinearAlgebraOp { + public: + typedef LinearAlgebraOp Base; + + explicit EigOp(OpKernelConstruction* context) : Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_)); + } + + using TensorShapes = typename Base::TensorShapes; + using InputMatrix = typename Base::InputMatrix; + using InputMatrixMaps = typename Base::InputMatrixMaps; + using InputConstMatrixMap = typename Base::InputConstMatrixMap; + using InputConstMatrixMaps = typename Base::InputConstMatrixMaps; + + using OutputMatrix = typename Base::OutputMatrix; + using OutputMatrixMaps = typename Base::OutputMatrixMaps; + using OutputConstMatrixMap = typename Base::OutputConstMatrixMap; + using OutputConstMatrixMaps = typename Base::OutputConstMatrixMaps; + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + int64_t n = input_matrix_shapes[0].dim_size(0); + if (compute_v_) { + return TensorShapes({TensorShape({n}), TensorShape({n, n})}); + } else { + return TensorShapes({TensorShape({n})}); + } + } + + void ComputeMatrix(OpKernelContext* context, + const InputConstMatrixMaps& inputs, + OutputMatrixMaps* outputs) final { + const int64_t rows = inputs[0].rows(); + if (rows == 0) { + // If X is an empty matrix (0 rows, 0 col), X * X' == X. + // Therefore, we return X. + return; + } + + // This algorithm relies on denormals, so switch them back on locally. + port::ScopedDontFlushDenormal dont_flush_denormals; + + using EigenSolver = + std::conditional_t::IsComplex, + Eigen::ComplexEigenSolver, + Eigen::EigenSolver>; + EigenSolver eig(inputs[0], /*computeEigenvectors=*/compute_v_); + + OP_REQUIRES( + context, eig.info() == Eigen::Success, + errors::InvalidArgument("Eigen decomposition was not " + "successful. The input might not be valid.")); + + outputs->at(0) = eig.eigenvalues().template cast(); + if (compute_v_) { + outputs->at(1) = eig.eigenvectors(); + } + } + + private: + bool compute_v_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/einsum_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/einsum_op.h new file mode 100644 index 00000000..26daed1e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/einsum_op.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { +namespace functor { + +template +struct StrideFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + output.device(d) = input.stride(strides); + } +}; + +template +struct InflateFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::DSizes& strides, + typename TTypes::Tensor output) { + output.device(d) = input.inflate(strides); + } +}; +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/einsum_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/einsum_op_impl.h new file mode 100644 index 00000000..1d345be6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -0,0 +1,673 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_split.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/linalg/einsum_op.h" +#include "tensorflow/core/kernels/matmul_op_impl.h" +#include "tensorflow/core/kernels/reduction_ops_common.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/einsum_op_util.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/kernels/reduction_ops_common_gpu.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +using ShapeVec = absl::InlinedVector; +using Labels = absl::InlinedVector; +using OperandLabels = absl::InlinedVector; +using LabelCounts = absl::InlinedVector; +using OperandLabelCounts = absl::InlinedVector; +using LabelToDimSizes = absl::InlinedVector; + +struct EinsumHelper { + // Insert new (unnamed) broadcasting labels at the location of ellipsis. + static void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels, + int ellipsis_axis, Labels* labels, + LabelCounts* label_counts) { + labels->erase(labels->begin() + ellipsis_axis); + labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0); + std::iota(labels->begin() + ellipsis_axis, + labels->begin() + ellipsis_axis + num_bcast_dims, + num_named_labels); + // Increment label counts. Since these are new labels, the count is set + // to 1. + label_counts->resize(num_named_labels + num_bcast_dims, 1); + } + + // Record and validate the label to dimension mapping. Must be a named + // (non-broadcasting) label as broadcasting labels don't have a fixed + // dimension. + static absl::Status RecordLabelToDimension( + const int label, const int axis, const Tensor& input, + LabelToDimSizes* label_to_dim_sizes) { + const int64_t input_dim = input.dim_size(axis); + // We know that label_to_dim_sizes has the size to accommodate named labels. + if (label_to_dim_sizes->at(label) != 0 && + label_to_dim_sizes->at(label) != input_dim) { + return errors::InvalidArgument( + "Expected dimension ", label_to_dim_sizes->at(label), " at axis ", + axis, " of the input shaped ", input.shape().DebugString(), + " but got dimension ", input_dim); + } + (*label_to_dim_sizes)[label] = input_dim; + return absl::OkStatus(); + } + + // Validate input dimensions and populate unnamed labels and their label + // counts. 
+ static absl::Status ProcessDimensions( + const OpInputList& inputs, + const absl::InlinedVector& input_has_ellipsis, + const bool output_has_ellipsis, OperandLabels* input_labels, + Labels* output_labels, std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + LabelToDimSizes* label_to_dim_sizes) { + if (inputs.size() != input_labels->size()) { + return errors::InvalidArgument("Expected ", input_labels->size(), + " inputs but got: ", inputs.size()); + } + const int num_inputs = inputs.size(); + + // We infer the number of broadcasting dimensions by taking the maximum rank + // among the broadcasting subshapes of the input. + int max_bcast_dims = 0; + const int num_named_labels = label_types->size(); + label_to_dim_sizes->resize(num_named_labels); + for (int i = 0; i < num_inputs; ++i) { + Labels* labels = &(*input_labels)[i]; + + if (!input_has_ellipsis[i]) { + if (inputs[i].dims() != labels->size()) { + return errors::InvalidArgument("Expected input ", i, " to have rank ", + labels->size(), + " but got: ", inputs[i].dims()); + } + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i], + label_to_dim_sizes)); + } + continue; + } + + // Input has an ellipsis. + if (inputs[i].dims() + 1 < labels->size()) { + return errors::InvalidArgument( + "Expected input ", i, " to have rank at least ", labels->size() - 1, + " but got: ", inputs[i].dims()); + } + int ellipsis_axis = -1; + const int num_bcast_dims = inputs[i].dims() - labels->size() + 1; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = (*labels)[label_idx]; + if (label == kEllipsisLabel) { + ellipsis_axis = label_idx; + continue; + } + // Current label is not an ellipsis. + const int axis = + label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1); + TF_RETURN_IF_ERROR( + RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes)); + } + // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting + // dimensions. + if (ellipsis_axis != -1) { + InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, + labels, &input_label_counts->at(i)); + max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims); + } + } + if (!absl::c_linear_search(input_has_ellipsis, true) && + !output_has_ellipsis) { + return absl::OkStatus(); + } + // Insert broadcasting dimensions in the output labels. + auto it = + std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel); + if (it != output_labels->end()) { + const int ellipsis_axis = it - output_labels->begin(); + InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, + output_labels, output_label_counts); + } else if (max_bcast_dims > 0) { + return errors::InvalidArgument( + "Output contains ", max_bcast_dims, + " broadcasting dimension(s) but no ellipsis " + "(...) was found in the output subscripts."); + } + // Populate EinsumDimensionType for the new broadcasting labels. + label_types->resize(num_named_labels + max_bcast_dims, + EinsumDimensionType::kBroadcasting); + return absl::OkStatus(); + } + + // Permutes the labels according to the given permutation. 
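The core consistency rule enforced by RecordLabelToDimension and ProcessDimensions above is that every named label must bind to a single dimension size across all operands. A toy version of that check, independent of the TensorFlow types (BindLabels and the std::map bookkeeping are invented for this sketch), looks like this:

#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Returns an error message if `labels`/`dims` conflict with earlier bindings.
// Assumes labels.size() == dims.size().
std::optional<std::string> BindLabels(const std::string& labels,
                                      const std::vector<long long>& dims,
                                      std::map<char, long long>& label_to_dim) {
  for (size_t i = 0; i < labels.size(); ++i) {
    auto [it, inserted] = label_to_dim.emplace(labels[i], dims[i]);
    if (!inserted && it->second != dims[i]) {
      return "label '" + std::string(1, labels[i]) + "' bound to " +
             std::to_string(it->second) + " but operand has " +
             std::to_string(dims[i]);
    }
  }
  return std::nullopt;
}

// Example for "ij,jk->ik": BindLabels("ij", {2, 3}, m) succeeds, after which
// BindLabels("jk", {4, 5}, m) reports that 'j' is bound to 3 but got 4.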
+ static void PermuteLabels(const std::vector& permutation, + Labels* labels) { + Labels permuted_labels(labels->size()); + for (int i = 0; i < labels->size(); ++i) { + permuted_labels[i] = (*labels)[permutation[i]]; + } + labels->swap(permuted_labels); + } + + // Returns a reshaped input Tensor. The underlying buffer is not copied. + static absl::Status CopyFrom(const Tensor& input, const TensorShape& shape, + Tensor* output) { + if (output->CopyFrom(input, shape)) return absl::OkStatus(); + return errors::Internal( + "Encountered error while reshaping a Tensor of shape ", + input.shape().DebugString(), " to shape ", shape.DebugString()); + } + + // Returns whether transposing would be a no-op; whether input has rank < 2 or + // the permutation is the identity permutation. + static bool ShouldTranspose(const TensorShape& input_shape, + const std::vector& permutation) { + if (input_shape.dims() < 2) return false; + for (int i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) return true; + } + return false; + } + + // Transpose the input given a permutation. Returns a reference to the input + // if transposing is not necessary. + template + static absl::Status TransposeOperand(OpKernelContext* ctx, + const Tensor& input, + const std::vector& permutation, + Tensor* output) { + if (!ShouldTranspose(input.shape(), permutation)) { + return CopyFrom(input, input.shape(), output); + } + TensorShape transposed_shape; + for (int i = 0; i < input.dims(); ++i) { + TF_RETURN_IF_ERROR( + transposed_shape.AddDimWithStatus(input.dim_size(permutation[i]))); + } + // For empty Tensors, just change the shape. E.g. we may need to transpose + // from shape [1, 0, 5] to [5, 1, 0]. + if (input.NumElements() == 0) { + return CopyFrom(input, transposed_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); + const Device& device = ctx->eigen_device(); + TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output)); + return absl::OkStatus(); + } + + // If there are repeated labels in either the input or output, then this + // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. + template + static absl::Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, + const Labels& labels, + const LabelCounts& label_counts, + const bool should_inflate, + Tensor* output) { + // Return early if there are no repeated indices. + if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) { + return CopyFrom(input, input.shape(), output); + } + // We reshape so that each repeated label is compressed to one dimension. + // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27, + // 5]. Striding appropriately (in this case with strides 14 (=1+3+9) and 1) + // recovers the generalized diagonal of shape [3, 5]. + ShapeVec reshape; + ShapeVec strides; + // Strided and inflated shapes correspond to input and output shapes, + // respectively, should_inflate is true (vice-versa if should_inflate is + // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example. + ShapeVec strided_shape; + ShapeVec inflated_shape; + for (int label : labels) { + const int count = label_counts[label]; + const int current_axis = + should_inflate ? 
strided_shape.size() : inflated_shape.size(); + const int64_t dim = input.dim_size(current_axis); + strided_shape.push_back(dim); + inflated_shape.insert(inflated_shape.end(), count, dim); + const int64_t reshape_dim = MathUtil::IPow(dim, count); + reshape.push_back(reshape_dim); + // While taking the d-diagonal in a rank k Tensor, we take d + // equally-spaced elements including the first and last element. Then, (k + // - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1). + const int64_t stride = + (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1; + strides.push_back(stride); + } + + TensorShape output_shape = + TensorShape(should_inflate ? inflated_shape : strided_shape); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + const Device& device = ctx->eigen_device(); + switch (reshape.size()) { +#define NDIMS_CASE(N) \ + case N: { \ + if (should_inflate) { \ + auto output_map = output->shaped(reshape); \ + auto input_map = input.shaped(strided_shape); \ + functor::InflateFunctor()( \ + device, input_map, TensorShape(strides).AsEigenDSizes(), \ + output_map); \ + } else { \ + auto input_map = input.shaped(reshape); \ + auto output_map = output->shaped(strided_shape); \ + functor::StrideFunctor()( \ + device, input_map, TensorShape(strides).AsEigenDSizes(), \ + output_map); \ + } \ + } break; + NDIMS_CASE(1); + NDIMS_CASE(2); + NDIMS_CASE(3); + NDIMS_CASE(4); + NDIMS_CASE(5); + NDIMS_CASE(6); + default: + return errors::Unimplemented( + "Unsupported rank: ", reshape.size(), + " while handling repeated indices. Up to rank 6 is supported."); +#undef NDIMS_CASE + } + return absl::OkStatus(); + } + + // Returns true if the input dimensions are already sorted in the order + // [batch, contract, free, reduce]. Used to implement an optimization to avoid + // an extra transpose and instead uses (adj_x and adj_y) in BatchMatMul. + static bool ShouldSwapFreeAndContract( + const Labels& labels, + const std::vector& label_types) { + // Check that ordering is according to dimension type, with the role of + // free and contract dimensions swapped. + absl::InlinedVector remap = {0, 1, 3, 2, 4}; + for (int i = 0; i + 1 < labels.size(); ++i) { + const int dimtype_a = remap[label_types[labels[i]]]; + const int dimtype_b = remap[label_types[labels[i + 1]]]; + if (dimtype_a > dimtype_b || + (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) { + return false; + } + } + return true; + } + + template + static absl::Status ReduceOperand( + OpKernelContext* ctx, const Tensor& input, + const std::vector& label_types, + const LabelCounts& label_counts, Labels* labels, Labels* free_labels, + bool* swap_free_and_contract, Tensor* output) { + // Find the permutation to transpose the input dimensions in the order of + // EinsumDimensionType; i.e. batch, free, contract and reduce dimensions. + // This makes it more convenient to invoke Reduce/Contract operations. + std::vector permutation(input.dims()); + absl::c_iota(permutation, 0); + Tensor input_transposed; + // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y) + // flag during BatchMatMul. This is an extra optimization not necessary for + // correctness. 
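Working the striding formula above through the iiij -> ij example: the three repeated axes of size d = 3 compress into reshape_dim = 3^3 = 27, the last of the d equally spaced diagonal elements sits at index (d - 1) * stride = d^k - 1, so stride = (27 - 1) / (3 - 1) = 13 = 1 + 3 + 9, and the generalized diagonal lives at offsets 0, 13, 26 of the compressed axis. The sketch below extracts such a diagonal from a flat buffer; it is a standalone illustration, not the Eigen-based functor, and StridedDiagonal is an invented name.

#include <cstdint>
#include <vector>

// Extracts the generalized diagonal of an axis compressed from `count`
// repeated axes of size `dim`; `inner` is the product of the trailing dims
// (5 in the [3, 3, 3, 5] example above).
std::vector<float> StridedDiagonal(const std::vector<float>& input,
                                   int64_t dim, int64_t count, int64_t inner) {
  int64_t reshape_dim = 1;
  for (int64_t i = 0; i < count; ++i) reshape_dim *= dim;  // dim^count
  const int64_t stride =
      (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
  std::vector<float> out(dim * inner);
  for (int64_t d = 0; d < dim; ++d)
    for (int64_t j = 0; j < inner; ++j)
      out[d * inner + j] = input[(d * stride) * inner + j];
  return out;
}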
+ if (ShouldSwapFreeAndContract(*labels, label_types)) { + *swap_free_and_contract = true; + } else { + absl::c_sort(permutation, [&](int i, int j) { + int label_i = (*labels)[i]; + int label_j = (*labels)[j]; + return std::tie(label_types[label_i], label_i) < + std::tie(label_types[label_j], label_j); + }); + } + // Transpose the input so that EinsumDimensionTypes are in order. + TF_RETURN_IF_ERROR(TransposeOperand(ctx, input, permutation, + &input_transposed)); + PermuteLabels(permutation, labels); + + // Take the generalized diagonal for dimensions with repeated axis labels. + Tensor input_deduped; + labels->erase(std::unique(labels->begin(), labels->end()), labels->end()); + TF_RETURN_IF_ERROR( + StrideOrInflate(ctx, input_transposed, *labels, label_counts, + false /* should_inflate */, &input_deduped)); + + // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, + // reduce] where we've compacted the dimensions of each EinsumDimensionType. + absl::InlinedVector reshape(5, 1); + // The output shape is [batch shape] + [free size, contract size] + // That is, the batch shape is preserved (for broadcasting while + // contracting) while the free dims and contract dims are compressed to one + // dimension each. + TensorShape output_shape; + for (int label_idx = 0; label_idx < labels->size(); ++label_idx) { + const int label = labels->at(label_idx); + int64_t dim = input_deduped.dim_size(label_idx); + if (label_types[label] == EinsumDimensionType::kBroadcasting || + label_types[label] == EinsumDimensionType::kBatch) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(dim)); + } else if (label_types[label] == EinsumDimensionType::kFree) { + free_labels->push_back(label); + } + reshape[label_types[label]] *= dim; + } + if (*swap_free_and_contract) + std::swap(reshape[EinsumDimensionType::kFree], + reshape[EinsumDimensionType::kContract]); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kFree])); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(reshape[EinsumDimensionType::kContract])); + + if (reshape[EinsumDimensionType::kReduce] == + 1) { // No need to actually reduce. + return CopyFrom(input_deduped, output_shape, output); + } + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + using Reducer = Eigen::internal::SumReducer; + using Index = typename TTypes::Tensor::Index; + // Reduce along the last axis (i.e axis 1) of the rank-2 Tensor. + const int64_t output_size = reshape[kBroadcasting] * reshape[kBatch] * + reshape[kFree] * reshape[kContract]; + functor::ReduceFunctor::Reduce( + ctx, output->shaped({output_size}), + const_cast(input_deduped) + .shaped({output_size, reshape[kReduce]}), + Eigen::array({1}), Reducer()); + return absl::OkStatus(); + } + + // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. + static absl::Status ReshapeToRank3(const Tensor& input, int batch_size, + Tensor* output) { + const int rank = input.dims(); + TensorShape output_shape = {batch_size, input.dim_size(rank - 2), + input.dim_size(rank - 1)}; + return CopyFrom(input, output_shape, output); + } + + // Contracts the inputs along the last axis (or the second last if the + // corresponding value of swap_free_and_contract is true). The batch + // dimensions are broadcast to the output shape. + // TODO(anudhyan): BatchMatMul might devolve into a component-wise + // multiplication when the matrix shape is [1,1]; in this case BatchMatMul + // functor would be very inefficient. 
The functor should detect if this is the + // case and perform componentwise multiplication functor instead. + template + static absl::Status ContractOperands( + OpKernelContext* ctx, absl::Span inputs, + absl::Span swap_free_and_contract, Tensor* output) { + if (inputs.size() == 1) + return CopyFrom(inputs[0], inputs[0].shape(), output); + MatMulBCast bcast(inputs[0].shape().dim_sizes(), + inputs[1].shape().dim_sizes()); + if (!bcast.IsValid()) { + return errors::InvalidArgument( + "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(), + " vs. ", inputs[1].shape().DebugString()); + } + Tensor lhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs)); + Tensor rhs; + TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs)); + TensorShape output_shape = bcast.output_batch_shape(); + for (int i = 0; i < inputs.size(); ++i) { + const int64_t free_axis = + inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2); + TF_RETURN_IF_ERROR( + output_shape.AddDimWithStatus(inputs[i].dim_size(free_axis))); + } + bool trans_x = swap_free_and_contract[0]; + bool trans_y = !swap_free_and_contract[1]; + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, output_shape, output)); + if (lhs.NumElements() == 0 || rhs.NumElements() == 0) { + functor::SetZeroFunctor set_zero; + set_zero(ctx->eigen_device(), output->flat()); + return absl::OkStatus(); + } + Tensor output_reshaped; + TF_RETURN_IF_ERROR( + ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped)); + LaunchBatchMatMul::Launch(ctx, lhs, rhs, /*adj_x=*/false, + /*adj_y=*/false, trans_x, trans_y, + /*grad_x=*/false, /*grad_y=*/false, + bcast, &output_reshaped); + return absl::OkStatus(); + } +}; + +template +class EinsumOp : public OpKernel { + public: + explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("equation", &equation_)); + OP_REQUIRES_OK( + c, ParseEinsumEquation(equation_, &input_labels_, &output_labels_, + &label_types_, &input_label_counts_, + &output_label_counts_, &input_has_ellipsis_, + &output_has_ellipsis_)); + } + + void Compute(OpKernelContext* ctx) override { + OpInputList inputs; + OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs)); + + OperandLabels input_labels(input_labels_); + Labels output_labels(output_labels_); + std::vector label_types(label_types_); + OperandLabelCounts input_label_counts(input_label_counts_); + LabelCounts output_label_counts(output_label_counts_); + LabelToDimSizes label_to_dim_sizes; + + OP_REQUIRES_OK(ctx, EinsumHelper::ProcessDimensions( + inputs, input_has_ellipsis_, output_has_ellipsis_, + &input_labels, &output_labels, &label_types, + &input_label_counts, &output_label_counts, + &label_to_dim_sizes)); + + // The reduction phase (a) sums across reduction dimensions, (b) takes + // generalized diagonals, and (c) reshapes it into shape + // [(broadcasting) batch shape] + [F,C] + // where F and C denote the total (compacted) size of free and contract + // dimensions, respectively. 
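// --- Illustrative sketch (editor's addition, not part of the vendored header) ---
// The reduction phase described above turns e.g. the equation "abc,cd->ad"
// into (a) summing out the reduce label 'b' from the first operand and then
// (b) an ordinary contraction over the contract label 'c'. A minimal eager
// Eigen::Tensor version of that two-step decomposition, assuming Eigen's
// unsupported Tensor module is available (it is already used by these
// headers) and that the inputs are small enough to evaluate eagerly:
#include <unsupported/Eigen/CXX11/Tensor>

inline Eigen::Tensor<float, 2> EinsumAbcCdToAd(
    const Eigen::Tensor<float, 3>& lhs,    // indexed [a, b, c]
    const Eigen::Tensor<float, 2>& rhs) {  // indexed [c, d]
  // (a) Reduce: sum out axis 'b' -> shape [a, c].
  const Eigen::array<int, 1> reduce_b = {1};
  const Eigen::Tensor<float, 2> lhs_reduced = lhs.sum(reduce_b);
  // (b) Contract: pair axis 1 of lhs_reduced with axis 0 of rhs -> [a, d].
  const Eigen::array<Eigen::IndexPair<int>, 1> contract_c = {
      Eigen::IndexPair<int>(1, 0)};
  return lhs_reduced.contract(rhs, contract_c);
}
// --------------------------------------------------------------------------------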
+ const int num_inputs = inputs.size(); + OperandLabels free_labels(num_inputs); + absl::InlinedVector inputs_reduced(num_inputs); + absl::InlinedVector swap_free_and_contract(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + OP_REQUIRES_OK(ctx, + EinsumHelper::ReduceOperand( + ctx, inputs[i], label_types, input_label_counts[i], + &input_labels[i], &free_labels[i], + &swap_free_and_contract[i], &inputs_reduced[i])); + } + + // After reduction, the inputs should be reshaped to Tensors suitable for + // contraction. If num_inputs is 1, the reduced input is simply forwarded to + // the output. + Tensor contraction_output_reshaped; + OP_REQUIRES_OK(ctx, EinsumHelper::ContractOperands( + ctx, inputs_reduced, swap_free_and_contract, + &contraction_output_reshaped)); + + // Copy the batch labels from the contraction output. Recover the batch + // shape, which may have been broadcasted. + TensorShape result_shape = contraction_output_reshaped.shape(); + result_shape.RemoveLastDims(2); + + int num_labels = label_types.size(); + Labels result_labels; + // All batch dimensions should be present in the contracted result. First + // the broadcasting dimensions, then the named batch dimensions. + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBroadcasting) + result_labels.push_back(label); + } + for (int label = 0; label < num_labels; ++label) { + if (label_types[label] == EinsumDimensionType::kBatch) + result_labels.push_back(label); + } + for (int i = 0; i < num_inputs; ++i) { + for (int label : free_labels[i]) { + result_labels.push_back(label); + OP_REQUIRES_OK( + ctx, result_shape.AddDimWithStatus(label_to_dim_sizes[label])); + } + } + + // Reshape the contraction (or reduction) result to its expanded shape: + // [(broadcasted) batch shape] + [free shape 0] + [free shape 1]. + Tensor contraction_output; + OP_REQUIRES_OK( + ctx, EinsumHelper::CopyFrom(contraction_output_reshaped, result_shape, + &contraction_output)); + + // Inflate the output if necessary. (E.g. for the equation 'i->iii' which + // may arise while computing gradient of a regular Einsum). + // TODO(anudhyan): It's possible that Eigen's contract and inflate can be + // chained here to avoid materializing an intermediate. + Tensor output_inflated; + OP_REQUIRES_OK( + ctx, EinsumHelper::StrideOrInflate( + ctx, contraction_output, result_labels, output_label_counts, + true /* should_inflate */, &output_inflated)); + if (output_inflated.dims() > contraction_output.dims()) { + // We inflated the output. Modify result labels accordingly. + Labels inflated_labels; + for (int label : result_labels) { + inflated_labels.insert(inflated_labels.end(), + output_label_counts[label], label); + } + result_labels.swap(inflated_labels); + } + // Find the permutation to map the result labels to the output labels. Note + // that both the result and the final output may have the repeated labels, + // in which case the permutation preserves the left-to-right ordering. + // E.g. if result labels are [0, 0, 1] and output is [0, l, 0] then the + // permutation should be [0, 2, 1]. We also use the fact that repeated + // labels in the result are adjacent to each other. + std::vector output_permutation(output_labels.size()); + std::vector label_to_position(num_labels, -1); + for (int i = 0; i < result_labels.size(); ++i) { + // Remember the position of only the leftmost result label. 
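// --- Illustrative sketch (editor's addition, not part of the vendored header) ---
// The two loops around this point compute the permutation that maps result
// labels to output labels while preserving the left-to-right order of
// repeated labels. Standalone version of the same logic: result labels
// {0, 0, 1} and output labels {0, 1, 0} produce the permutation {0, 2, 1}.
#include <vector>

inline std::vector<int> RepeatedLabelPermutation(
    const std::vector<int>& result_labels,
    const std::vector<int>& output_labels, int num_labels) {
  std::vector<int> label_to_position(num_labels, -1);
  for (int i = 0; i < static_cast<int>(result_labels.size()); ++i) {
    // Remember only the leftmost occurrence of each label.
    if (label_to_position[result_labels[i]] == -1) {
      label_to_position[result_labels[i]] = i;
    }
  }
  std::vector<int> permutation(output_labels.size());
  for (int i = 0; i < static_cast<int>(output_labels.size()); ++i) {
    permutation[i] = label_to_position[output_labels[i]];
    // Repeated labels are adjacent in the result, so the next occurrence of
    // this label maps to the following position.
    label_to_position[output_labels[i]] += 1;
  }
  return permutation;
}
// --------------------------------------------------------------------------------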
+ if (label_to_position[result_labels[i]] == -1) { + label_to_position[result_labels[i]] = i; + } + } + for (int i = 0; i < output_labels.size(); ++i) { + output_permutation[i] = label_to_position[output_labels[i]]; + // We have found the leftmost occurrence. The next one would be adjacent. + label_to_position[output_labels[i]] += 1; + } + Tensor output; + OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( + ctx, output_inflated, output_permutation, &output)); + ctx->set_output(0, std::move(output)); + } + + string TraceString(const OpKernelContext& ctx, bool verbose) const override { + string op = profiler::TraceMeOp(name_view(), type_string_view()); + string equation = strings::StrCat("(", equation_, ")"); + if (verbose) { + string shape = ShapeTraceString(ctx); + if (!shape.empty()) { + return tsl::profiler::TraceMeEncode( + std::move(op), {{"equation", equation}, {"shape", shape}}); + } + } + return tsl::profiler::TraceMeEncode(std::move(op), + {{"equation", equation}}); + } + + private: + string equation_; + OperandLabels input_labels_; + Labels output_labels_; + std::vector label_types_; + OperandLabelCounts input_label_counts_; + LabelCounts output_label_counts_; + absl::InlinedVector input_has_ellipsis_; + bool output_has_ellipsis_ = false; +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// Forward declarations of the functor specializations for GPU. +namespace functor { +#define DECLARE_GPU_SPEC(T, N) \ + template <> \ + void StrideFunctor::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + typename TTypes::Tensor output); \ + extern template struct StrideFunctor; \ + template <> \ + void InflateFunctor::operator()( \ + const GPUDevice& d, typename TTypes::ConstTensor input, \ + const Eigen::DSizes& strides, \ + typename TTypes::Tensor output); \ + extern template struct InflateFunctor; + +#define DECLARE_GPU_SPECS(T) \ + DECLARE_GPU_SPEC(T, 1); \ + DECLARE_GPU_SPEC(T, 2); \ + DECLARE_GPU_SPEC(T, 3); \ + DECLARE_GPU_SPEC(T, 4); \ + DECLARE_GPU_SPEC(T, 5); \ + DECLARE_GPU_SPEC(T, 6); + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +// TODO(rocm): Enable once complex types are supported. +#if GOOGLE_CUDA +DECLARE_GPU_SPECS(complex64); +DECLARE_GPU_SPECS(complex128); +#endif +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPECS +} // namespace functor +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/eye_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/eye_functor.h new file mode 100644 index 00000000..c77372f0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/eye_functor.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EYE_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_EYE_FUNCTOR_H_ + +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct EyeFunctor { + void operator()(const Device& device, + typename TTypes::Tensor matrix_batch); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/linalg_ops_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/linalg_ops_common.h new file mode 100644 index 00000000..b4b98921 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/linalg_ops_common.h @@ -0,0 +1,224 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ + +// Classes to support linear algebra functionality, similar to the numpy.linalg +// module. Supports batch computation on several matrices at once, sharding the +// computations across different threads if necessary. +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +// Base class for linear algebra operators. +template +class LinearAlgebraOp : public OpKernel { + public: + explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override; + + protected: + using TensorShapes = absl::InlinedVector; + // Returns the number of leading inputs that are to be treated as matrix + // inputs. By default this is all the inputs. Derived classes can override + // this to tell the base class to ignore one or more trailing inputs. + virtual int NumMatrixInputs(const OpKernelContext* context) const { + return context->num_inputs(); + } + + // Returns true if the number of inputs and their shapes are as expected. + // Many ops take a single square input matrix, so we provide that as a default + // implementation for convenience. 
+ virtual void ValidateInputMatrixShapes( + OpKernelContext* context, const TensorShapes& input_matrix_shapes) const { + ValidateSingleSquareMatrix(context, input_matrix_shapes); + } + + // Convenience validators for common cases: + // + // Validate op taking a single matrix A. + static void ValidateSingleMatrix(OpKernelContext* context, + const TensorShapes& input_matrix_shapes); + // Validate op taking a single square matrix A. + static void ValidateSingleSquareMatrix( + OpKernelContext* context, const TensorShapes& input_matrix_shapes); + // Validate op taking two matrices A and B that have the same number of rows. + static void ValidateSolver(OpKernelContext* context, + const TensorShapes& input_matrix_shapes); + // Validate op taking two matrices A and B that have the same number of rows + // and A is square. + static void ValidateSquareSolver(OpKernelContext* context, + const TensorShapes& input_matrix_shapes); + + // Returns the output shapes of each individual matrix operation. Output + // matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0. + // + // The derived class may return a number of shapes (N) less than + // context->num_outputs() (M) to indicate that a only leading subset of + // the outputs will be populated. In this case, a dummy scalar tensor with + // value zero will be return for the last M-N outputs. + // + // For many ops, the output dimensions are the same as the input dimensions, + // so we provide that as a default implementation for convenience. + virtual TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const { + return input_matrix_shapes; + } + + // Returns the cost per matrix operation. This is used to determine the + // number of threads to use for parallelizing calls to ComputeMatrix in + // batch mode. Cost per unit is assumed to be roughly 1ns, based on comments + // in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n) + // * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a + // default implementation for convenience. + virtual int64_t GetCostPerUnit( + const TensorShapes& input_matrix_shapes) const { + double m = static_cast(input_matrix_shapes[0].dim_size(0)); + double n = static_cast(input_matrix_shapes[0].dim_size(1)); + double cost = std::max(m, n) * std::min(m, n) * std::min(m, n); + return cost >= static_cast(kint64max) ? kint64max + : static_cast(cost); + } + + // Returns true if it is safe to forward (alias) input to output buffer + // and expect the kernel to perform the computation inplace. 
+ virtual bool EnableInputForwarding() const { return true; } + + using InputMatrix = Eigen::Matrix; + using InputConstMatrixMap = Eigen::Map; + using InputMatrixMap = Eigen::Map; + using InputConstVectorMap = + Eigen::Map>; + using InputConstMatrixMaps = gtl::InlinedVector; + using InputMatrixMaps = gtl::InlinedVector; + using InputRealScalar = typename Eigen::NumTraits::Real; + + using OutputMatrix = Eigen::Matrix; + using OutputConstMatrixMap = Eigen::Map; + using OutputMatrixMap = Eigen::Map; + using OutputConstVectorMap = + Eigen::Map>; + using OutputConstMatrixMaps = gtl::InlinedVector; + using OutputMatrixMaps = gtl::InlinedVector; + using OutputRealScalar = typename Eigen::NumTraits::Real; + + // backward compatibility + using Scalar = OutputScalar; + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using MatrixMap = Eigen::Map; + using ConstVectorMap = + Eigen::Map>; + using ConstMatrixMaps = gtl::InlinedVector; + using MatrixMaps = gtl::InlinedVector; + using RealScalar = typename Eigen::NumTraits::Real; + + // Performs a single matrix computation given input matrices, and + // stores the result in outputs. For batch operations, this will be called + // repeatedly for a single call to Compute() when multiple matrices exist in + // input Tensors with rank > 2. In this case the calls to ComputeMatrix are + // parallelized. The number of threads used is determined by a cost model from + // the value returned by GetCostPerUnit(). + virtual void ComputeMatrix(OpKernelContext* context, + const InputConstMatrixMaps& inputs, + OutputMatrixMaps* outputs) = 0; + + private: + using TensorInputs = absl::InlinedVector; + using TensorOutputs = absl::InlinedVector; + // This function maps 2-d slices (matrices) of the input and output tensors + // using Eigen::Map and calls ComputeMatrix implemented in terms of the + // Eigen::MatrixBase API by the derived class. + // + // The 'matrix_index' parameter specifies the index of the matrix to be used + // from each input tensor, and the index of the matrix to be written to each + // output tensor. The input matrices are in row major order, and located at + // the memory addresses + // inputs[i].flat().data() + + // matrix_index * input_matrix_shapes[i].num_elements() + // for i in 0...inputs.size()-1. + // The output matrices are in row major order, and located at the memory + // address + // outputs[i]->flat().data() + + // matrix_index * output_matrix_shapes[i].num_elements(). + // for i in 0...outputs.size()-1. + // + void ComputeTensorSlice(OpKernelContext* context, int64_t matrix_index, + const TensorInputs& inputs, + const TensorShapes& input_matrix_shapes, + const TensorOutputs& outputs, + const TensorShapes& output_matrix_shapes); + + void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs, + TensorShapes* input_matrix_shapes, + TensorShape* batch_shape); + + void PrepareOutputs(OpKernelContext* context, + const TensorShapes& input_matrix_shapes, + const TensorShape& batch_shape, TensorOutputs* outputs, + TensorShapes* output_matrix_shapes); +}; + +// Declare LinearAlgebraOp, which is explicitly instantiated in +// linalg_ops_common.cc for half,float, double, complex64, and complex128. 
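// --- Illustrative sketch (editor's addition, not part of the vendored header) ---
// A hypothetical kernel showing how this base class is meant to be used:
// derived ops implement only ComputeMatrix() (plus optional validators and
// cost model), while batching, sharding and output allocation are inherited.
// The class name below is illustrative and does not exist in TensorFlow; it
// assumes the vendored headers are on the include path.
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"

namespace tensorflow {

template <class Scalar>
class HypotheticalMatrixSquareOp : public LinearAlgebraOp<Scalar> {
 public:
  typedef LinearAlgebraOp<Scalar> Base;
  using TensorShapes = typename Base::TensorShapes;
  using ConstMatrixMaps = typename Base::ConstMatrixMaps;
  using MatrixMaps = typename Base::MatrixMaps;

  explicit HypotheticalMatrixSquareOp(OpKernelConstruction* context)
      : Base(context) {}

  // Roughly one dense n x n matmul per input matrix, saturated at kint64max
  // like the default cost model above.
  int64_t GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final {
    const double n = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    const double cost = 2.0 * n * n * n;
    return cost >= static_cast<double>(kint64max) ? kint64max
                                                  : static_cast<int64_t>(cost);
  }

  // Computes A * A for each (square) matrix in the batch; the default
  // ValidateSingleSquareMatrix() and GetOutputMatrixShapes() already fit.
  void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
                     MatrixMaps* outputs) final {
    outputs->at(0).noalias() = inputs[0] * inputs[0];
  }
};

}  // namespace tensorflow
// --------------------------------------------------------------------------------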
+extern template class LinearAlgebraOp; +extern template class LinearAlgebraOp; +extern template class LinearAlgebraOp; +extern template class LinearAlgebraOp; +extern template class LinearAlgebraOp; + +} // namespace tensorflow + +#define INHERIT_LINALG_TYPEDEFS(Scalar) \ + typedef LinearAlgebraOp Base; \ + using RealScalar = typename Eigen::NumTraits::Real; \ + using Matrix = typename Base::Matrix; \ + using MatrixMap = typename Base::MatrixMap; \ + using MatrixMaps = typename Base::MatrixMaps; \ + using ConstMatrixMap = typename Base::ConstMatrixMap; \ + using ConstMatrixMaps = typename Base::ConstMatrixMaps; \ + using ConstVectorMap = typename Base::ConstVectorMap; \ + using TensorShapes = typename Base::TensorShapes; + +#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \ + REGISTER_KERNEL_BUILDER( \ + Name(OpName).Device(DEVICE_CPU).TypeConstraint("T"), OpClass) + +#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \ + REGISTER_KERNEL_BUILDER( \ + Name(OpName).Device(DEVICE_GPU).TypeConstraint("T"), OpClass) + +// Deprecated, use one of the device-specific macros above. +#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \ + REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_band_part_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_band_part_op.h new file mode 100644 index 00000000..2f68eba6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_band_part_op.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +template +struct MatrixBandPartFunctor { + void operator()(OpKernelContext* context, const Device& device, + int num_upper_diags, int num_lower_diags, + typename TTypes::ConstTensor input, + typename TTypes::Tensor output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_diag_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_diag_op.h new file mode 100644 index 00000000..01c875ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_diag_op.h @@ -0,0 +1,74 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_ + +// Generator definition for MatrixDiagOp, must be compilable by nvcc. + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +// Reads the diagonal packing alignment. +void ReadAlignment(OpKernelConstruction* context, + bool* left_align_superdiagonal, + bool* left_align_subdiagonal); + +// Calculates diagonal length and content offset (from aligning) of a diagonal. +// Returns a pair of integers {diag_len, content_offset}: +// - diag_len: The length of the diag_index-th diagonal. +// - content_offset: Each diagonal is stored as a row in the compact format. +// If the diagonal is shorter than max_diag_len, its content is aligned +// either to the left or right. content_offset is the index in the row +// where the first element of the diag-index-th diagonal is stored. It is +// always zero when the diagonal is left-aligned. +std::pair ComputeDiagLenAndContentOffset( + int diag_index, int max_diag_len, int num_rows, int num_cols, + bool left_align_superdiagonal, bool left_align_subdiagonal); + +template +struct MatrixDiagPart { + EIGEN_ALWAYS_INLINE static void Compute( + OpKernelContext* context, const Device& device, + typename TTypes::ConstTensor& input, + typename TTypes::Tensor& output, const Eigen::Index lower_diag_index, + const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len, + const T padding_value, const bool left_align_superdiagonal, + const bool left_align_subdiagonal); +}; + +template +struct MatrixDiag { + EIGEN_ALWAYS_INLINE static void Compute( + OpKernelContext* context, const Device& device, + typename TTypes::ConstTensor& diag, + typename TTypes::Tensor& output, + const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index, + const Eigen::Index max_diag_len, const T padding_value, + const bool left_align_superdiagonal, const bool left_align_subdiagonal); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_set_diag_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_set_diag_op.h new file mode 100644 index 00000000..449a3607 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_set_diag_op.h @@ -0,0 +1,42 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +template +struct MatrixSetDiag { + static void Compute(OpKernelContext* context, const Device& device, + typename TTypes::ConstTensor& input, + typename TTypes::ConstTensor& diag, + typename TTypes::Tensor& output, + const Eigen::Index lower_diag_index, + const Eigen::Index upper_diag_index, + const Eigen::Index max_diag_len, + const bool left_align_superdiagonal, + const bool left_align_subdiagonal); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h new file mode 100644 index 00000000..c75c494e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h @@ -0,0 +1,166 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_ + +// See docs in ../ops/linalg_ops.cc. 
+ +#include "Eigen/Cholesky" // from @eigen_archive +#include "Eigen/Core" // from @eigen_archive +#include "Eigen/QR" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/linalg/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +template +class MatrixSolveLsOp : public LinearAlgebraOp { + public: + typedef LinearAlgebraOp Base; + + explicit MatrixSolveLsOp(OpKernelConstruction* context) : Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("fast", &fast_)); + } + + using TensorShapes = typename Base::TensorShapes; + using Matrix = typename Base::Matrix; + using MatrixMaps = typename Base::MatrixMaps; + using ConstMatrixMap = typename Base::ConstMatrixMap; + using ConstMatrixMaps = typename Base::ConstMatrixMaps; + + // Tell the base class to ignore the regularization parameter + // in context->input(2). + int NumMatrixInputs(const OpKernelContext* context) const final { return 2; } + + void ValidateInputMatrixShapes( + OpKernelContext* context, + const TensorShapes& input_matrix_shapes) const final { + Base::ValidateSolver(context, input_matrix_shapes); + } + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + return TensorShapes({TensorShape({input_matrix_shapes[0].dim_size(1), + input_matrix_shapes[1].dim_size(1)})}); + } + + int64_t GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { + double m = static_cast(input_matrix_shapes[0].dim_size(0)); + double n = static_cast(input_matrix_shapes[0].dim_size(1)); + double num_rhss = static_cast(input_matrix_shapes[1].dim_size(1)); + double cost = std::max(m, n) * std::min(m, n) * (std::min(m, n) + num_rhss); + return cost >= static_cast(kint64max) ? kint64max + : static_cast(cost); + } + + bool EnableInputForwarding() const final { return false; } + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + const ConstMatrixMap& matrix = inputs[0]; + const ConstMatrixMap& rhs = inputs[1]; + const auto& l2_regularizer_in = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(l2_regularizer_in.shape()), + errors::InvalidArgument("l2_regularizer must be scalar, got shape ", + l2_regularizer_in.shape().DebugString())); + const double l2_regularizer = l2_regularizer_in.scalar()(); + OP_REQUIRES(context, l2_regularizer >= 0, + errors::InvalidArgument("l2_regularizer must be >= 0.")); + + const int64_t rows = matrix.rows(); + const int64_t cols = matrix.cols(); + if (rows == 0 || cols == 0 || rhs.rows() == 0 || rhs.cols() == 0) { + // The result is the empty matrix. + return; + } + if (fast_) { + // The fast branch assumes that matrix is not rank deficient and + // not too ill-conditioned. Specifically, the reciprocal condition number + // should be greater than the square root of the machine precision, i.e. + // 1 / cond(matrix) > sqrt(std::numeric_limits::epsilon()). + // This branch solves over- or underdetermined least-squares problems + // via the normal equations and Cholesky decomposition. 
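// --- Illustrative sketch (editor's addition, not part of the vendored header) ---
// Standalone Eigen version of the overdetermined fast path implemented just
// below: min ||A*x - b||^2 + l2 * ||x||^2 is solved through the normal
// equations (A^T A + l2 * I) x = A^T b with a Cholesky (LLT) factorization.
// Real code must check llt.info(), as the kernel below does, because the
// normal equations square the condition number of A.
#include <Eigen/Cholesky>
#include <Eigen/Core>

inline Eigen::VectorXd RegularizedNormalEquationsSolve(
    const Eigen::MatrixXd& a, const Eigen::VectorXd& b, double l2) {
  Eigen::MatrixXd gramian = a.transpose() * a;   // A^T A  (cols x cols)
  gramian.diagonal().array() += l2;              // + l2 * I
  const Eigen::LLT<Eigen::MatrixXd> llt(gramian);
  return llt.solve(a.transpose() * b);           // x = (A^T A + l2 I)^{-1} A^T b
}
// --------------------------------------------------------------------------------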
+ if (rows >= cols) { + // Overdetermined case (rows >= cols): Solves the ordinary (possibly + // regularized) least-squares problem + // min || A * X - RHS ||_F^2 + l2_regularizer ||X||_F^2 + // by solving the normal equations + // (A^T * A + l2_regularizer * I) X = A^T RHS + // using Cholesky decomposition. + Matrix gramian(cols, cols); + gramian.template triangularView() = + matrix.adjoint() * matrix; + if (l2_regularizer > 0) { + gramian += + (Scalar(l2_regularizer) * Matrix::Ones(cols, 1)).asDiagonal(); + } + const Eigen::LLT, Eigen::Lower> llt(gramian); + OP_REQUIRES( + context, llt.info() == Eigen::Success, + errors::InvalidArgument("Input matrix was rank deficient or " + "ill-conditioned. Try setting fast=False " + "or provide a larger l2_regularizer > 0.")); + outputs->at(0).noalias() = matrix.adjoint() * rhs; + llt.solveInPlace(outputs->at(0)); + } else { + // Underdetermined case (rows < cols): Solves the minimum-norm problem + // min ||X||_F^2 s.t. A*X = RHS + // by solving the normal equations of the second kind + // (A * A^T + l2_regularizer * I) Z = RHS, X = A^T * Z + // using Cholesky decomposition. + Matrix gramian(rows, rows); + gramian.template triangularView() = + matrix * matrix.adjoint(); + if (l2_regularizer > 0) { + gramian += + (Scalar(l2_regularizer) * Matrix::Ones(rows, 1)).asDiagonal(); + } + const Eigen::LLT, Eigen::Lower> llt(gramian); + OP_REQUIRES( + context, llt.info() == Eigen::Success, + errors::InvalidArgument("Input matrix was rank deficient or " + "ill-conditioned. Try setting fast=False " + "or provide an l2_regularizer > 0.")); + outputs->at(0).noalias() = matrix.adjoint() * llt.solve(rhs); + } + } else { + // Use complete orthogonal decomposition which is backwards stable and + // will compute the minimum-norm solution for rank-deficient matrices. + // This is 6-7 times slower than the fast path. + // + // TODO(rmlarsen): The implementation of + // Eigen::CompleteOrthogonalDecomposition is not blocked, so for + // matrices that do not fit in cache, it is significantly slower than + // the equivalent blocked LAPACK routine xGELSY (e.g. Eigen is ~3x + // slower for 4k x 4k matrices). + // See http://www.netlib.org/lapack/lawnspdf/lawn114.pdf + outputs->at(0) = matrix.completeOrthogonalDecomposition().solve(rhs); + } + } + + private: + bool fast_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h new file mode 100644 index 00000000..8e524347 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h @@ -0,0 +1,416 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/linalg_ops.cc. 
+// +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_ + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/linalg/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/matmul_bcast.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/gpu_solvers.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +se::DeviceMemory AsDeviceMemory(const Scalar* gpu_memory) { + se::DeviceMemoryBase wrapped(const_cast(gpu_memory)); + se::DeviceMemory typed(wrapped); + return typed; +} + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// Sequential batch matrix triangular solve kernel that calls Eigen's +// matrix triangular solve. +template +struct SequentialMatrixTriangularSolveKernel { + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using MatrixMap = Eigen::Map; + using RealScalar = typename Eigen::NumTraits::Real; + + static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t, + int slice) { + return ConstMatrixMap( + t.flat().data() + slice * t.dim_size(1) * t.dim_size(2), + t.dim_size(1), t.dim_size(2)); + } + + static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) { + return MatrixMap( + t->flat().data() + slice * t->dim_size(1) * t->dim_size(2), + t->dim_size(1), t->dim_size(2)); + } + + static void Run(const Tensor& in_x, const Tensor& in_y, bool lower, + bool adjoint, const MatMulBCast& bcast, Tensor* out, + int start, int limit) { + const bool should_bcast = bcast.IsBroadcastingRequired(); + const auto& x_batch_indices = bcast.x_batch_indices(); + const auto& y_batch_indices = bcast.y_batch_indices(); + for (int64_t i = start; i < limit; ++i) { + const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; + const int64_t y_batch_index = should_bcast ? 
y_batch_indices[i] : i; + auto matrix = ConstTensorSliceToEigenMatrix(in_x, x_batch_index); + auto rhs = ConstTensorSliceToEigenMatrix(in_y, y_batch_index); + auto output = TensorSliceToEigenMatrix(out, i); + if (lower) { + auto triangle = matrix.template triangularView(); + if (adjoint) { + output.noalias() = triangle.adjoint().solve(rhs); + } else { + output.noalias() = triangle.solve(rhs); + } + } else { + auto triangle = matrix.template triangularView(); + if (adjoint) { + output.noalias() = triangle.adjoint().solve(rhs); + } else { + output.noalias() = triangle.solve(rhs); + } + } + } + } +}; + +template +struct LaunchBatchMatrixTriangularSolve; + +template +struct LaunchBatchMatrixTriangularSolve { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adjoint, bool lower, + const MatMulBCast& bcast, Tensor* out) { + // Number of matrix triangular solves i.e. size of the batch. + const int64_t batch_size = bcast.output_batch_size(); + const int64_t cost_per_unit = + in_x.dim_size(1) * in_x.dim_size(1) * in_y.dim_size(2) / 2; + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using RealScalar = typename Eigen::NumTraits::Real; + + Shard(worker_threads.num_threads, worker_threads.workers, batch_size, + cost_per_unit, + [&in_x, &in_y, adjoint, lower, &bcast, out](int start, int limit) { + SequentialMatrixTriangularSolveKernel::Run( + in_x, in_y, lower, adjoint, bcast, out, start, limit); + }); + } +}; + +template +class BaseMatrixTriangularSolveOp : public OpKernel { + public: + explicit BaseMatrixTriangularSolveOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(context, context->GetAttr("adjoint", &adjoint_)); + } + + ~BaseMatrixTriangularSolveOp() override {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& in0 = ctx->input(0); + const Tensor& in1 = ctx->input(1); + + ValidateInputTensors(ctx, in0, in1); + if (!ctx->status().ok()) { + return; + } + + MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); + OP_REQUIRES( + ctx, bcast.IsValid(), + errors::InvalidArgument( + "In[0] and In[1] must have compatible batch dimensions: ", + in0.shape().DebugString(), " vs. ", in1.shape().DebugString())); + + TensorShape out_shape = bcast.output_batch_shape(); + auto batch_size = bcast.output_batch_size(); + auto d0 = in0.dim_size(in0.dims() - 2); + auto d1 = in0.dim_size(in0.dims() - 1); + Tensor in0_reshaped; + OP_REQUIRES( + ctx, + in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})), + errors::Internal("Failed to reshape In[0] from ", + in0.shape().DebugString())); + auto d2 = in1.dim_size(in1.dims() - 2); + auto d3 = in1.dim_size(in1.dims() - 1); + Tensor in1_reshaped; + OP_REQUIRES( + ctx, + in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})), + errors::Internal("Failed to reshape In[1] from ", + in1.shape().DebugString())); + if (adjoint_) std::swap(d0, d1); + OP_REQUIRES(ctx, d1 == d2, + errors::InvalidArgument( + "In[0] mismatch In[1] shape: ", d1, " vs. 
", d2, ": ", + in0.shape().DebugString(), " ", in1.shape().DebugString(), + " ", lower_, " ", adjoint_)); + OP_REQUIRES_OK(ctx, out_shape.AddDimWithStatus(d0)); + OP_REQUIRES_OK(ctx, out_shape.AddDimWithStatus(d3)); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + if (out->NumElements() == 0) { + return; + } + Tensor out_reshaped; + OP_REQUIRES(ctx, + out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})), + errors::Internal("Failed to reshape output from ", + out->shape().DebugString())); + LaunchBatchMatrixTriangularSolve::Launch( + ctx, in0_reshaped, in1_reshaped, adjoint_, lower_, bcast, + &out_reshaped); + } + + private: + virtual void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) = 0; + bool lower_; + bool adjoint_; +}; + +template +class MatrixTriangularSolveOp + : public BaseMatrixTriangularSolveOp { + public: + explicit MatrixTriangularSolveOp(OpKernelConstruction* context) + : BaseMatrixTriangularSolveOp(context) {} + + ~MatrixTriangularSolveOp() override {} + + private: + void ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) override { + const auto in0_num_dims = in0.dims(); + OP_REQUIRES( + ctx, in0_num_dims >= 2, + errors::InvalidArgument("In[0] ndims must be >= 2: ", in0_num_dims)); + + const auto in1_num_dims = in1.dims(); + OP_REQUIRES( + ctx, in1_num_dims >= 2, + errors::InvalidArgument("In[1] ndims must be >= 2: ", in1_num_dims)); + + const auto in0_last_dim = in0.dim_size(in0_num_dims - 1); + const auto in0_prev_dim = in0.dim_size(in0_num_dims - 2); + OP_REQUIRES(ctx, in0_last_dim == in0_prev_dim, + errors::InvalidArgument( + "In[0] matrices in the last dimensions must be square (", + in0_last_dim, " =/= ", in0_prev_dim, ")")); + } +}; + +#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_CPU(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + MatrixTriangularSolveOp); \ + REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + MatrixTriangularSolveOp); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +struct LaunchBatchMatrixTriangularSolve { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adjoint, bool lower, + const MatMulBCast& bcast, Tensor* out) { + auto* stream = context->op_device_context()->stream(); + + const uint64 m = in_x.dim_size(1); + const uint64 n = out->dim_size(2); + + // Do a memcpy when we don't need to broadcast. 
+ if (!bcast.IsBroadcastingRequired() || out->shape() == in_y.shape()) { + auto src_device_mem = AsDeviceMemory(in_y.template flat().data()); + auto dst_device_mem = AsDeviceMemory(out->template flat().data()); + OP_REQUIRES_OK(context, stream->MemcpyD2D(&dst_device_mem, src_device_mem, + bcast.y_batch_size() * m * n * + sizeof(Scalar))); + } else { + std::vector out_ptrs; + std::vector b_tmp_ptrs; + auto* b_base_ptr = in_y.template flat().data(); + const std::vector& b_batch_indices = bcast.y_batch_indices(); + for (int64_t i = 0; i < bcast.y_batch_size(); ++i) { + b_tmp_ptrs.push_back(b_base_ptr + i * m * n); + } + for (int64_t i = 0; i < bcast.output_batch_size(); ++i) { + auto src_device_mem = AsDeviceMemory(b_tmp_ptrs[b_batch_indices[i]]); + auto dst_device_mem = + AsDeviceMemory(out->template flat().data() + i * m * n); + OP_REQUIRES_OK(context, + stream->MemcpyD2D(&dst_device_mem, src_device_mem, + m * n * sizeof(Scalar))); + } + } + + if (out->NumElements() == 0) { + return; + } + +#if GOOGLE_CUDA + + cublasSideMode_t side = CUBLAS_SIDE_RIGHT; + cublasFillMode_t uplo; + cublasOperation_t trans; + cublasDiagType_t diag = CUBLAS_DIAG_NON_UNIT; + + // Cublas does + // output = matrix \ rhs + // where matrix, rhs and output are assumed to be in column major. + // We want the output to be in row-major, so we can compute + // output' = rhs' / matrix' (' stands for transpose) + // Upper/lower needs to be swapped for this. + + uplo = lower ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + trans = adjoint ? CUBLAS_OP_C : CUBLAS_OP_N; + +#elif TENSORFLOW_USE_ROCM + rocblas_side side = rocblas_side_right; + rocblas_fill uplo; + rocblas_operation trans; + rocblas_diagonal diag = rocblas_diagonal_non_unit; + + // rocblas does + // output = matrix \ rhs + // where matrix, rhs and output are assumed to be in column major. + // We want the output to be in row-major, so we can compute + // output' = rhs' / matrix' (' stands for transpose) + // Upper/lower needs to be swapped for this. + + uplo = lower ? rocblas_fill_upper : rocblas_fill_lower; + trans = adjoint ? rocblas_operation_conjugate_transpose + : rocblas_operation_none; + +#endif + + auto solver = absl::make_unique(context); + const uint64 leading_dim_matrix = m; + const uint64 leading_dim_output = n; + const uint64 colmajor_rows = n; + const uint64 colmajor_cols = m; + + const int64_t batch_size = bcast.output_batch_size(); + std::vector a_ptrs; + std::vector out_ptrs; + std::vector a_tmp_ptrs; + a_ptrs.reserve(batch_size); + out_ptrs.reserve(batch_size); + a_tmp_ptrs.reserve(bcast.x_batch_size()); + auto* a_base_ptr = in_x.template flat().data(); + auto* out_base_ptr = out->template flat().data(); + + if (!bcast.IsBroadcastingRequired()) { + for (int64_t i = 0; i < batch_size; ++i) { + a_ptrs.push_back(a_base_ptr + i * m * m); + out_ptrs.push_back(out_base_ptr + i * m * n); + } + } else { + const std::vector& a_batch_indices = bcast.x_batch_indices(); + for (int64_t i = 0; i < bcast.x_batch_size(); ++i) { + a_tmp_ptrs.push_back(a_base_ptr + i * m * m); + } + for (int64_t i = 0; i < batch_size; ++i) { + a_ptrs.push_back(a_tmp_ptrs[a_batch_indices[i]]); + out_ptrs.push_back(out_base_ptr + i * m * n); + } + } + + typedef Scalar Coefficient; + const Scalar alpha = Scalar(1.0); + + // TODO(b/146763573): Consider using Trsv here when the right hand side is + // a vector. This will require an explicit transpose since Trsv assumes + // CUBLAS_SIDE_LEFT. 
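// --- Illustrative sketch (editor's addition, not part of the vendored header) ---
// The side/uplo/trans juggling above exists because cuBLAS/rocBLAS solve in
// column-major while these tensors are row-major: a row-major matrix
// reinterpreted as column-major is its transpose, so instead of
//   output = matrix \ rhs
// the kernel asks BLAS for the right-sided solve
//   output^T = rhs^T / matrix^T,
// which also flips lower-triangular to upper-triangular. A standalone Eigen
// check of that identity (real scalars, so no conjugation is involved):
#include <Eigen/Core>

inline bool TransposedTriangularSolveMatches() {
  Eigen::Matrix3d L;
  L << 2, 0, 0,
       1, 3, 0,
       4, 5, 6;                                  // lower triangular, invertible
  const Eigen::Matrix3d B = Eigen::Matrix3d::Random();
  const Eigen::Matrix3d X = L.triangularView<Eigen::Lower>().solve(B);
  // Right-sided solve of the transposed system: Y * L^T = B^T, so Y = X^T.
  const Eigen::Matrix3d Y = L.transpose()
                                .triangularView<Eigen::Upper>()
                                .solve<Eigen::OnTheRight>(B.transpose());
  return (X - Y.transpose()).norm() < 1e-9;
}
// --------------------------------------------------------------------------------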
+ if (batch_size == 1) { + OP_REQUIRES_OK( + context, + solver->Trsm(side, uplo, trans, diag, colmajor_rows, colmajor_cols, + &alpha, a_ptrs[0], leading_dim_matrix /*lda*/, + out_ptrs[0], leading_dim_output /*ldb*/)); + } else { + // Heuristic for choosing between batched interface vs. non-batched + // interface. This is inspired by matrix_solve_op and can probably be + // tuned. + // TODO(b/146763573): Tune this heuristic. + const int kMaxMatrixSizeToBatchSizeRatio = 128; + const bool use_batched_solver = + m <= kMaxMatrixSizeToBatchSizeRatio * batch_size; + if (use_batched_solver) { + OP_REQUIRES_OK( + context, solver->TrsmBatched( + side, uplo, trans, diag, colmajor_rows, colmajor_cols, + &alpha, &a_ptrs[0], leading_dim_matrix /*lda*/, + &out_ptrs[0], leading_dim_output /*ldb*/, batch_size)); + } else { + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK( + context, solver->Trsm(side, uplo, trans, diag, colmajor_rows, + colmajor_cols, &alpha, a_ptrs[batch], + leading_dim_matrix /*lda*/, out_ptrs[batch], + leading_dim_output /*ldb*/)); + } + } + } + } +}; + +#define REGISTER_BATCH_MATRIX_TRIANGULAR_SOLVE_GPU(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("MatrixTriangularSolve") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + MatrixTriangularSolveOp); \ + REGISTER_KERNEL_BUILDER(Name("BatchMatrixTriangularSolve") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T"), \ + MatrixTriangularSolveOp); + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/qr_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/qr_op_impl.h new file mode 100644 index 00000000..c5a1823f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/qr_op_impl.h @@ -0,0 +1,318 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_ + +// See docs in ../ops/linalg_ops.cc. +// +// This header file is used by the individual qr_*op*.cc files for registering +// individual kernels. A separate file is used for each instantiated kernel to +// improve compilation times. 
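// --- Illustrative sketch (editor's addition, not part of the vendored header) ---
// QrOp in this header returns (Q, R) with shapes [m, m] and [m, n] when
// full_matrices is true, and [m, min(m, n)] and [min(m, n), n] otherwise.
// A standalone Eigen version of the economy-size case for comparison:
#include <algorithm>
#include <Eigen/Core>
#include <Eigen/QR>

inline void EconomyQr(const Eigen::MatrixXd& a, Eigen::MatrixXd* q,
                      Eigen::MatrixXd* r) {
  const Eigen::Index m = a.rows();
  const Eigen::Index n = a.cols();
  const Eigen::Index k = std::min(m, n);
  const Eigen::HouseholderQR<Eigen::MatrixXd> qr(a);
  // Expand only the first k Householder reflectors into an m x k Q.
  *q = qr.householderQ() * Eigen::MatrixXd::Identity(m, k);
  // R is the upper triangle of the top k x n block of the packed factor.
  *r = qr.matrixQR().topRows(k).triangularView<Eigen::Upper>();
}
// --------------------------------------------------------------------------------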
+#include +#include + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif + +#include "Eigen/QR" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/linalg/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/kernels/cwise_ops.h" +#include "tensorflow/core/kernels/linalg/eye_functor.h" +#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/util/gpu_solvers.h" +#endif + +namespace tensorflow { + +template +class QrOp : public LinearAlgebraOp { + public: + typedef LinearAlgebraOp Base; + + explicit QrOp(OpKernelConstruction* context) : Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_)); + } + + using TensorShapes = typename Base::TensorShapes; + + void ValidateInputMatrixShapes( + OpKernelContext* context, + const TensorShapes& input_matrix_shapes) const final { + Base::ValidateSingleMatrix(context, input_matrix_shapes); + } + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + int64_t m = input_matrix_shapes[0].dim_size(0); + int64_t n = input_matrix_shapes[0].dim_size(1); + int64_t min_size = std::min(m, n); + if (full_matrices_) { + return TensorShapes({TensorShape({m, m}), TensorShape({m, n})}); + } else { + return TensorShapes( + {TensorShape({m, min_size}), TensorShape({min_size, n})}); + } + } + + int64_t GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { + double m = static_cast(input_matrix_shapes[0].dim_size(0)); + double n = static_cast(input_matrix_shapes[0].dim_size(1)); + double max_size = std::max(m, n); + double min_size = std::min(m, n); + double cost = 2 * max_size * min_size * min_size - + 2 * min_size * min_size * min_size / 3.; + // TODO(jpoulson): Increase the cost if full_matrices is true in a manner + // that reflects the algorithm used for the expansion. + return cost >= static_cast(kint64max) ? kint64max + : static_cast(cost); + } + + using Matrix = typename Base::Matrix; + using MatrixMaps = typename Base::MatrixMaps; + using ConstMatrixMap = typename Base::ConstMatrixMap; + using ConstMatrixMaps = typename Base::ConstMatrixMaps; + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + Eigen::HouseholderQR qr(inputs[0]); + const int m = inputs[0].rows(); + const int n = inputs[0].cols(); + const int min_size = std::min(m, n); + + if (full_matrices_) { + outputs->at(0) = qr.householderQ(); + outputs->at(1) = qr.matrixQR().template triangularView(); + } else { + // TODO(jpoulson): Exploit the fact that Householder transformations can + // be expanded faster than they can be applied to an arbitrary matrix + // (Cf. LAPACK's DORGQR). 
+ Matrix tmp = Matrix::Identity(m, min_size); + outputs->at(0) = qr.householderQ() * tmp; + auto qr_top = qr.matrixQR().block(0, 0, min_size, n); + outputs->at(1) = qr_top.template triangularView(); + } + } + + private: + bool full_matrices_; + + QrOp(const QrOp&) = delete; + void operator=(const QrOp&) = delete; +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +typedef Eigen::GpuDevice GPUDevice; + +template +class QrOpGpu : public AsyncOpKernel { + public: + explicit QrOpGpu(OpKernelConstruction* context) : AsyncOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_)); + } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) final { + const Tensor& input = context->input(0); + const int ndims = input.dims(); + const int64_t m = input.dim_size(ndims - 2); + const int64_t n = input.dim_size(ndims - 1); + const int64_t min_size = std::min(m, n); + const int64_t batch_size = + input.template flat_inner_dims().dimension(0); + + // Validate inputs. + OP_REQUIRES_ASYNC( + context, ndims >= 2, + errors::InvalidArgument("Input must have rank >= 2, got ", ndims), + done); + + // Allocate output. + // If full_matrices_ is true then Q is m x m and R is m x n. + // Otherwise, Q is m x min(m, n), and R is min(m, n) x n. + Tensor* q; + TensorShape q_shape = input.shape(); + q_shape.set_dim(ndims - 1, full_matrices_ ? m : min_size); + OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, q_shape, &q), + done); + Tensor* r; + TensorShape r_shape = input.shape(); + r_shape.set_dim(ndims - 2, full_matrices_ ? m : min_size); + OP_REQUIRES_OK_ASYNC(context, context->allocate_output(1, r_shape, &r), + done); + + if (input.NumElements() == 0) { + done(); + return; + } + + // TODO(rmlarsen): Convert to std::make_unique when available. + std::unique_ptr solver(new GpuSolver(context)); + + // Allocate temporaries. + Tensor input_transposed; + TensorShape transposed_shape = input.shape(); + transposed_shape.set_dim(ndims - 2, input.dim_size(ndims - 1)); + transposed_shape.set_dim(ndims - 1, input.dim_size(ndims - 2)); + + OP_REQUIRES_OK_ASYNC( + context, + solver->allocate_scoped_tensor(DataTypeToEnum::value, + transposed_shape, &input_transposed), + done); + + Tensor tau; + OP_REQUIRES_OK_ASYNC(context, + solver->allocate_scoped_tensor( + DataTypeToEnum::value, + TensorShape({batch_size, min_size}), &tau), + done); + + // Transpose input, since cuSolver uses column-major, while TensorFlow uses + // row-major storage. + const GPUDevice& device = context->eigen_device(); + OP_REQUIRES_OK_ASYNC( + context, DoMatrixTranspose(device, input, &input_transposed), done); + + // Compute QR decomposition in-place in input_transposed. + std::vector dev_info; + dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "geqrf")); + auto input_transposed_reshaped = + input_transposed.flat_inner_dims(); + auto tau_matrix = tau.matrix(); + auto r_reshaped = r->flat_inner_dims(); + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK_ASYNC( + context, + solver->Geqrf(m, n, &input_transposed_reshaped(batch, 0, 0), m, + &tau_matrix(batch, 0), + dev_info.back().mutable_data() + batch), + done); + } + +#if GOOGLE_CUDA + cublasOperation_t transa = CUBLAS_OP_T; + cublasOperation_t transb = CUBLAS_OP_N; + cublasSideMode_t side = CUBLAS_SIDE_LEFT; +#elif TENSORFLOW_USE_ROCM + rocblas_operation transa = rocblas_operation_transpose; + rocblas_operation transb = rocblas_operation_none; + rocblas_side side = rocblas_side_left; +#endif + + // Generate R. 
R is equal to the upper triangle of the decomposition + // stored in input_transposed. Crop, transpose (to get back to row-major) + // and copy it to the output buffer. + if (full_matrices_ || m == n) { + OP_REQUIRES_OK_ASYNC( + context, DoMatrixTranspose(device, input_transposed, r), done); + } else { + const Scalar alpha(1); + const Scalar beta(0); + const Scalar* dummy = nullptr; + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK_ASYNC( + context, + solver->Geam(transa, transb, n, full_matrices_ ? m : min_size, + &alpha, &input_transposed_reshaped(batch, 0, 0), m, + &beta, dummy, n, &r_reshaped(batch, 0, 0), n), + done); + } + } + // Extract the upper triangle of r (i.e. zero out the strictly lower + // triangle). + functor::MatrixBandPartFunctor band_part; + auto r_reshaped_const = + const_cast(r)->flat_inner_dims(); + band_part(context, device, 0 /* num_lower_diags */, + -1 /* num_upper_diags */, r_reshaped_const, r_reshaped); + + // Generate Q from the decomposition in input_transposed. + if (m != n && (full_matrices_ || m < n)) { + // Generate full m x m matrix Q by computing the product Q^T * I, + // where the transpose is to get back to row-major form. + // In the complex case we actually form Q^H * I and conjugate it + // to get Q in row-major form. + functor::EyeFunctor eye; + auto q_reshaped = q->flat_inner_dims(); + eye(device, q_reshaped); +#if GOOGLE_CUDA + cublasOperation_t trans = CublasAdjointOp(); +#elif TENSORFLOW_USE_ROCM + rocblas_operation trans = RocblasAdjointOp(); +#endif + for (int batch = 0; batch < batch_size; ++batch) { + // Notice: It appears that Unmqr does not write a zero into *info upon + // success (probably a bug), so we simply re-use the info array already + // zeroed by Geqrf above. + OP_REQUIRES_OK_ASYNC( + context, + solver->Unmqr(side, trans, m, m, min_size, + &input_transposed_reshaped(batch, 0, 0), m, + &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m, + dev_info.back().mutable_data() + batch), + done); + } + if (Eigen::NumTraits::IsComplex) { + functor::UnaryFunctor> conj; + conj(device, q->flat() /*out*/, + const_cast(q)->flat() /*in*/); + } + } else { + // Generate m x n matrix Q. In this case we can use the more efficient + // algorithm in Ungqr to generate Q in place. + dev_info.push_back(solver->GetDeviceLapackInfo(batch_size, "orgqr")); + for (int batch = 0; batch < batch_size; ++batch) { + OP_REQUIRES_OK_ASYNC( + context, + solver->Ungqr( + m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m, + &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch), + done); + } + OP_REQUIRES_OK_ASYNC( + context, DoMatrixTranspose(device, input_transposed, q), done); + } + + // Asynchronously check return status from cuSolver kernels. + GpuSolver::CheckLapackInfoAndDeleteSolverAsync(std::move(solver), dev_info, + std::move(done)); + } + + private: + bool full_matrices_; + + QrOpGpu(const QrOpGpu&) = delete; + void operator=(const QrOpGpu&) = delete; +}; + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h new file mode 100644 index 00000000..4fba705f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_ + +// See docs in ../ops/linalg_ops.cc. + +#include "Eigen/Core" // from @eigen_archive +#include "Eigen/Eigenvalues" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/linalg/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/denormal.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +template +class SelfAdjointEigV2Op : public LinearAlgebraOp { + public: + typedef LinearAlgebraOp Base; + + explicit SelfAdjointEigV2Op(OpKernelConstruction* context) : Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_)); + } + + using TensorShapes = typename Base::TensorShapes; + using Matrix = typename Base::Matrix; + using MatrixMaps = typename Base::MatrixMaps; + using ConstMatrixMap = typename Base::ConstMatrixMap; + using ConstMatrixMaps = typename Base::ConstMatrixMaps; + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + int64_t n = input_matrix_shapes[0].dim_size(0); + if (compute_v_) { + return TensorShapes({TensorShape({n}), TensorShape({n, n})}); + } else { + return TensorShapes({TensorShape({n})}); + } + } + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + const int64_t rows = inputs[0].rows(); + if (rows == 0) { + // If X is an empty matrix (0 rows, 0 col), X * X' == X. + // Therefore, we return X. + return; + } + + // This algorithm relies on denormals, so switch them back on locally. + port::ScopedDontFlushDenormal dont_flush_denormals; + + Eigen::SelfAdjointEigenSolver eig( + inputs[0], + compute_v_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly); + // TODO(rmlarsen): Output more detailed error info on failure. + OP_REQUIRES( + context, eig.info() == Eigen::Success, + errors::InvalidArgument("Self-adjoint eigen decomposition was not " + "successful. The input might not be valid.")); + + outputs->at(0) = eig.eigenvalues().template cast(); + if (compute_v_) { + outputs->at(1) = eig.eigenvectors(); + } + } + + private: + bool compute_v_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/svd_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/svd_op_impl.h new file mode 100644 index 00000000..4e674585 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg/svd_op_impl.h @@ -0,0 +1,135 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_ + +// See docs in ../ops/linalg_ops.cc. +// +// This header file is used by the individual svd_*op*.cc files for registering +// individual kernels. A separate file is used for each instantiated kernel to +// improve compilation times. +#include + +#include "Eigen/SVD" // from @eigen_archive +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/linalg/linalg_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +template +class SvdOp : public LinearAlgebraOp { + public: + typedef LinearAlgebraOp Base; + + explicit SvdOp(OpKernelConstruction* context) : Base(context) { + OP_REQUIRES_OK(context, context->GetAttr("compute_uv", &compute_uv_)); + OP_REQUIRES_OK(context, context->GetAttr("full_matrices", &full_matrices_)); + } + + using TensorShapes = typename Base::TensorShapes; + + void ValidateInputMatrixShapes( + OpKernelContext* context, + const TensorShapes& input_matrix_shapes) const final { + Base::ValidateSingleMatrix(context, input_matrix_shapes); + } + + TensorShapes GetOutputMatrixShapes( + const TensorShapes& input_matrix_shapes) const final { + int64_t m = input_matrix_shapes[0].dim_size(0); + int64_t n = input_matrix_shapes[0].dim_size(1); + int64_t min_size = std::min(m, n); + if (compute_uv_) { + return TensorShapes({TensorShape({min_size}), + TensorShape({m, full_matrices_ ? m : min_size}), + TensorShape({n, full_matrices_ ? n : min_size})}); + } else { + return TensorShapes({TensorShape({min_size})}); + } + } + + // TODO(rmlarsen): This should depend on compute_uv. See b/30409375. + int64_t GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final { + double m = static_cast(input_matrix_shapes[0].dim_size(0)); + double n = static_cast(input_matrix_shapes[0].dim_size(1)); + double cost = 12 * std::max(m, n) * std::min(m, n) * std::min(m, n); + return cost >= static_cast(kint64max) ? kint64max + : static_cast(cost); + } + + using Matrix = typename Base::Matrix; + using MatrixMaps = typename Base::MatrixMaps; + using ConstMatrixMap = typename Base::ConstMatrixMap; + using ConstMatrixMaps = typename Base::ConstMatrixMaps; + + void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs, + MatrixMaps* outputs) final { + int64_t n = inputs[0].cols(); + int64_t m = inputs[0].rows(); + const bool empty = (m == 0 || n == 0); + int options = 0; // Don't compute singular vectors; + if (compute_uv_) { + options = full_matrices_ ? 
Eigen::ComputeFullU | Eigen::ComputeFullV + : Eigen::ComputeThinU | Eigen::ComputeThinV; + } + + if (empty) { + // For an empty matrix where only one dimension is zero, we still set + // U or V to the unit matrix for the dimension that is non-zero. + if (compute_uv_ && full_matrices_) { + if (m > 0) { + outputs->at(1) = Matrix::Identity(m, m); + } else { + outputs->at(2) = Matrix::Identity(n, n); + } + } + return; + } + + Eigen::BDCSVD svd(inputs[0], options); + if (svd.info() != Eigen::Success) { + LOG(ERROR) << "Eigen::BDCSVD failed with error code " << svd.info(); + outputs->at(0).fill(std::numeric_limits::quiet_NaN()); + if (compute_uv_) { + outputs->at(1).fill(std::numeric_limits::quiet_NaN()); + outputs->at(2).fill(std::numeric_limits::quiet_NaN()); + } + } else { + outputs->at(0) = svd.singularValues().template cast(); + if (compute_uv_) { + outputs->at(1) = svd.matrixU(); + outputs->at(2) = svd.matrixV(); + } + } + } + + private: + bool compute_uv_; + bool full_matrices_; + + SvdOp(const SvdOp&) = delete; + void operator=(const SvdOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/linalg_ops_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg_ops_common.h new file mode 100644 index 00000000..0aa69801 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/linalg_ops_common.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_ + +// Temporary forwarding header. +#include "tensorflow/core/kernels/linalg/linalg_ops_common.h" + +#endif // TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/list_kernels.h b/third_party/tflite-hdrs/tensorflow/core/kernels/list_kernels.h new file mode 100644 index 00000000..9837b087 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/list_kernels.h @@ -0,0 +1,1137 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/tensor_list.h" +#include "tensorflow/core/kernels/tensor_list_util.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/util/tensor_ops_util.h" +#include "tensorflow/core/util/util.h" + +// stream.h isn't available in some platforms such as Android, iOS, ChromiumOS, +// and Fuchsia. Only include it for platforms that PluggableDevice is tested on. +#if !defined(PLUGGABLE_DEVICE_SUPPORTED) && \ + (__x86_64__ || __i386__ || defined(__APPLE__) || defined(_WIN32)) && \ + !defined(ANDROID) && !defined(__ANDROID__) && !TARGET_OS_IOS && \ + !defined(PLATFORM_CHROMIUMOS) && !defined(__Fuchsia__) +#define PLUGGABLE_DEVICE_SUPPORTED +#endif + +#ifdef PLUGGABLE_DEVICE_SUPPORTED +#include "xla/stream_executor/stream.h" +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +absl::Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out); + +absl::Status GetElementShapeFromInput(OpKernelContext* c, + const TensorList& tensor_list, int index, + PartialTensorShape* element_shape); + +absl::Status GetInputList(OpKernelContext* c, int index, + const TensorList** list); + +absl::Status ForwardInputOrCreateNewList(OpKernelContext* c, + int32_t input_index, + int32_t output_index, + const TensorList& input_list, + TensorList** output_list); + +// TODO(penporn): Move this to a proper place. 
+inline bool IsPluggableDevice(OpKernelContext* c) { + return c->op_device_context() && c->op_device_context()->IsPluggableDevice(); +} + +template +inline void SetZero(OpKernelContext* ctx, Tensor& tensor) { +#ifdef PLUGGABLE_DEVICE_SUPPORTED + if (IsPluggableDevice(ctx)) { + auto ptr = + se::DeviceMemoryBase(tensor.flat().data(), tensor.TotalBytes()); + auto stream = ctx->op_device_context()->stream(); + auto result = stream->MemZero(&ptr, tensor.TotalBytes()).ok(); + DCHECK_EQ(true, result); + } else { +#endif // PLUGGABLE_DEVICE_SUPPORTED + functor::SetZeroFunctor()(ctx->eigen_device(), + tensor.flat()); +#ifdef PLUGGABLE_DEVICE_SUPPORTED + } +#endif // PLUGGABLE_DEVICE_SUPPORTED +} + +template +inline void CopyTensorPluggableDevice(OpKernelContext* ctx, Tensor& src, + Tensor& dst) { +#ifdef PLUGGABLE_DEVICE_SUPPORTED + auto src_t = src.unaligned_flat(); + auto dst_t = dst.flat(); + DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum::v())); + auto src_ptr = se::DeviceMemoryBase(src_t.data(), src.TotalBytes()); + auto dst_ptr = se::DeviceMemoryBase(dst_t.data(), dst.TotalBytes()); + auto stream = ctx->op_device_context()->stream(); + auto result = stream->Memcpy(&dst_ptr, src_ptr, src.TotalBytes()).ok(); + DCHECK_EQ(true, result); +#else + LOG(FATAL) // Crash OK. + << "PluggableDevice is not supported on this platform."; +#endif // PLUGGABLE_DEVICE_SUPPORTED +} + +template +inline void CopyTensor(OpKernelContext* ctx, Tensor& src, Tensor& dst) { + auto src_t = src.unaligned_flat(); + auto dst_t = dst.flat(); + dst_t.device(ctx->eigen_device()) = src_t; +} + +template +void ConcatPluggableDevice( + OpKernelContext* context, + const std::vector::ConstMatrix>>& + inputs, + typename TTypes::Matrix* output) { +#ifdef PLUGGABLE_DEVICE_SUPPORTED + DCHECK(DataTypeCanUseMemcpy(DataTypeToEnum::v())); + + se::Stream* stream = context->op_device_context()->stream(); + + size_t num_inputs = inputs.size(); + std::vector sizes; + sizes.reserve(num_inputs); + int64 row_size = 0; + for (const auto& input : inputs) { + sizes.push_back(input->dimension(1)); + row_size += sizes.back(); + } + + T* out = &(*output)(0, 0); + std::vector inp; + inp.reserve(num_inputs); + for (const auto& input : inputs) { + inp.push_back(&(*input)(0, 0)); + } + const int64 dim0 = output->dimension(0); + for (int64 i = 0; i < dim0; ++i) { + for (int64 j = 0; j < num_inputs; ++j) { + auto size = sizes[j]; + se::DeviceMemoryBase out_base{out, size * sizeof(T)}; + se::DeviceMemoryBase inp_base{const_cast(inp[j]), size * sizeof(T)}; + OP_REQUIRES_OK(context, + stream->Memcpy(&out_base, inp_base, size * sizeof(T))); + out += size; + inp[j] += size; + } + } +#else + LOG(FATAL) // Crash OK. 
+ << "PluggableDevice is not supported on this platform."; +#endif // PLUGGABLE_DEVICE_SUPPORTED +} + +template +class TensorListStack : public OpKernel { + public: + typedef std::vector::ConstMatrix>> + ConstMatrixVector; + explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_)); + } + + void Compute(OpKernelContext* c) override { + const TensorList* tensor_list = nullptr; + OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list)); + OP_REQUIRES( + c, element_dtype_ == tensor_list->element_dtype, + errors::InvalidArgument( + "Invalid data types; op elements ", DataTypeString(element_dtype_), + " but list elements ", DataTypeString(tensor_list->element_dtype))); + if (num_elements_ != -1) { + OP_REQUIRES(c, tensor_list->tensors().size() == num_elements_, + errors::InvalidArgument( + "Operation expected a list with ", num_elements_, + " elements but got a list with ", + tensor_list->tensors().size(), " elements.")); + } + PartialTensorShape partial_element_shape; + OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 1, + &partial_element_shape)); + OP_REQUIRES( + c, + partial_element_shape.IsFullyDefined() || + !tensor_list->tensors().empty(), + errors::InvalidArgument("Tried to stack elements of an empty ", + "list with non-fully-defined element_shape: ", + partial_element_shape.DebugString())); + + // Check that `element_shape` input tensor is compatible with the shapes of + // element tensors. + if (!tensor_list->element_shape.IsFullyDefined()) { + for (int i = 0; i < tensor_list->tensors().size(); ++i) { + const Tensor& t = tensor_list->tensors()[i]; + if (t.dtype() != DT_INVALID) { + PartialTensorShape tmp = partial_element_shape; + OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape)); + } + } + } + + // Compute the shape of the output tensor by pre-pending the leading dim to + // the element_shape. 
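A minimal standalone sketch of the shape rule described in the comment above: the stacked output shape is the fully defined element shape with the list length prepended as dimension 0. It uses plain std::vector in place of TensorShape, and StackedShape is an illustrative name, not a TensorFlow API.

#include <cstdint>
#include <iostream>
#include <vector>

// Output shape for stacking: prepend the number of list entries to the
// element shape, e.g. 5 elements of shape {2, 3} stack into {5, 2, 3}.
std::vector<int64_t> StackedShape(const std::vector<int64_t>& element_shape,
                                  int64_t num_elements) {
  std::vector<int64_t> out;
  out.reserve(element_shape.size() + 1);
  out.push_back(num_elements);  // leading dim = list length
  out.insert(out.end(), element_shape.begin(), element_shape.end());
  return out;
}

int main() {
  for (int64_t d : StackedShape({2, 3}, 5)) std::cout << d << ' ';  // prints: 5 2 3
  std::cout << '\n';
}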
+ TensorShape element_shape; + OP_REQUIRES(c, partial_element_shape.AsTensorShape(&element_shape), + errors::InvalidArgument( + "Tried to stack list which only contains uninitialized ", + "tensors and has a non-fully-defined element_shape: ", + partial_element_shape.DebugString())); + TensorShape output_shape = element_shape; + output_shape.InsertDim(0, tensor_list->tensors().size()); + Tensor* output; + OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); + if (output->NumElements() == 0) { + return; + } + + ConstMatrixVector inputs_flat; + inputs_flat.reserve(tensor_list->tensors().size()); + Tensor zeros; + for (const auto& t : tensor_list->tensors()) { + if (t.dtype() != DT_INVALID) { + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + t.shaped({1, t.NumElements()}))); + } else { + if (!zeros.NumElements()) { + AllocatorAttributes attr; + if (element_dtype_ == DT_VARIANT) { + attr.set_on_host(true); + } + OP_REQUIRES_OK( + c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr)); + SetZero(c, zeros); + } + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + const_cast(zeros).shaped( + {1, zeros.NumElements()}))); + } + } + auto output_flat = output->shaped({1, output->NumElements()}); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (std::is_same::value) { + ConcatGPU(c, inputs_flat, output, &output_flat); + return; + } +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (IsPluggableDevice(c)) { + ConcatPluggableDevice(c, inputs_flat, &output_flat); + } else { + ConcatCPU(c->device(), inputs_flat, &output_flat); + } + } + + private: + int num_elements_; + DataType element_dtype_; +}; + +template +class TensorListGetItem : public OpKernel { + public: + explicit TensorListGetItem(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + const TensorList* l = nullptr; + OP_REQUIRES_OK(c, GetInputList(c, 0, &l)); + OP_REQUIRES(c, element_dtype_ == l->element_dtype, + errors::InvalidArgument("Invalid data types; op elements ", + DataTypeString(element_dtype_), + " but list elements ", + DataTypeString(l->element_dtype))); + int32_t index = c->input(1).scalar()(); + OP_REQUIRES(c, index < l->tensors().size(), + errors::InvalidArgument("Trying to access element ", index, + " in a list with ", l->tensors().size(), + " elements.")); + if (l->tensors()[index].dtype() != DT_INVALID) { + c->set_output(0, l->tensors()[index]); + } else { + PartialTensorShape partial_element_shape; + OP_REQUIRES_OK( + c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape)); + TensorShape element_shape; + // If l->element_shape and the element_shape input are both not fully + // defined, try to infer the shape from other list elements. This requires + // that all initialized list elements have the same shape. + // NOTE(srbs): This might be a performance bottleneck since we are + // iterating over the entire list here. This is necessary for feature + // parity with TensorArray.read. TensorArray has a mode in which all + // elements are required to be of the same shape, TensorList does not. + // In that mode TensorArray sets the array's element_shape on the first + // write call. We could do something similar here if needed. 
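The loop that follows performs this inference with PartialTensorShape::MergeWith. As a rough standalone sketch of the merge rule, assuming -1 marks an unknown dimension (MergeShapes is an illustrative helper, not the TensorFlow implementation):

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Merge a partially known shape with a concrete element shape: unknown dims
// (-1) adopt the known value, and conflicting known dims are rejected.
std::vector<int64_t> MergeShapes(std::vector<int64_t> partial,
                                 const std::vector<int64_t>& known) {
  if (partial.size() != known.size())
    throw std::invalid_argument("rank mismatch");
  for (std::size_t i = 0; i < partial.size(); ++i) {
    if (partial[i] == -1) partial[i] = known[i];
    else if (partial[i] != known[i]) throw std::invalid_argument("dim mismatch");
  }
  return partial;
}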
+ if (!partial_element_shape.IsFullyDefined()) { + for (const Tensor& t : l->tensors()) { + if (t.dtype() != DT_INVALID) { + PartialTensorShape tmp = partial_element_shape; + OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape)); + } + } + } + OP_REQUIRES( + c, partial_element_shape.AsTensorShape(&element_shape), + errors::InvalidArgument("Trying to read an uninitialized tensor but ", + "element_shape is not fully defined: ", + partial_element_shape.DebugString(), + " and no list element is set.")); + Tensor* result; + AllocatorAttributes attr; + if (element_dtype_ == DT_VARIANT) { + attr.set_on_host(true); + } + OP_REQUIRES_OK(c, c->allocate_output(0, element_shape, &result, attr)); + SetZero(c, *result); + } + } + + private: + DataType element_dtype_; +}; + +template +class TensorListPopBack : public OpKernel { + public: + explicit TensorListPopBack(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + const TensorList* l = nullptr; + OP_REQUIRES_OK(c, GetInputList(c, 0, &l)); + OP_REQUIRES(c, element_dtype_ == l->element_dtype, + errors::InvalidArgument("Invalid data types; op elements ", + DataTypeString(element_dtype_), + " but list elements ", + DataTypeString(l->element_dtype))); + + OP_REQUIRES(c, !l->tensors().empty(), + errors::InvalidArgument("Trying to pop from an empty list.")); + + const Tensor& t = l->tensors().back(); + if (t.dtype() != DT_INVALID) { + c->set_output(1, t); + } else { + PartialTensorShape partial_element_shape; + OP_REQUIRES_OK( + c, GetElementShapeFromInput(c, *l, 1, &partial_element_shape)); + TensorShape element_shape; + OP_REQUIRES( + c, partial_element_shape.AsTensorShape(&element_shape), + errors::InvalidArgument("Trying to read an uninitialized tensor but ", + "element_shape is not fully defined.", + partial_element_shape.DebugString())); + Tensor* result; + AllocatorAttributes attr; + if (element_dtype_ == DT_VARIANT) { + attr.set_on_host(true); + } + OP_REQUIRES_OK(c, c->allocate_output(1, element_shape, &result, attr)); + SetZero(c, *result); + } + + TensorList* output_list = nullptr; + OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list)); + output_list->tensors().pop_back(); + } + + private: + DataType element_dtype_; +}; + +template +class TensorListConcat : public OpKernel { + public: + using ConstMatrixVector = + std::vector::ConstMatrix>>; + explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + if (c->HasAttr("element_shape")) { + OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape_)); + } + } + + void Compute(OpKernelContext* c) override { + PartialTensorShape element_shape_except_first_dim; + if (!element_shape_.unknown_rank()) { + auto dim_sizes = element_shape_.dim_sizes(); + OP_REQUIRES(c, !dim_sizes.empty(), + errors::InvalidArgument("element_shape must not be empty")); + element_shape_except_first_dim = + PartialTensorShape(absl::Span(dim_sizes).subspan(1)); + } + // Check that the input Variant tensor is indeed a TensorList and has the + // correct element type. 
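A rough analogue of the check announced above, with std::any standing in for the DT_VARIANT input: the kernel must confirm that the handle really holds a list and that the stored element type matches the op's element_dtype attribute before touching its contents. FakeList and GetListOrThrow are illustrative names only.

#include <any>
#include <stdexcept>
#include <string>

// A type-erased handle is only usable after both checks pass.
struct FakeList {
  std::string element_dtype;
};

const FakeList& GetListOrThrow(const std::any& input,
                               const std::string& expected_dtype) {
  const FakeList* list = std::any_cast<FakeList>(&input);
  if (list == nullptr) throw std::invalid_argument("Input is not a list.");
  if (list->element_dtype != expected_dtype)
    throw std::invalid_argument("Invalid data types; op elements " +
                                expected_dtype + " but list elements " +
                                list->element_dtype);
  return *list;
}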
+ const TensorList* tensor_list = nullptr; + OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list)); + OP_REQUIRES( + c, element_dtype_ == tensor_list->element_dtype, + errors::InvalidArgument( + "Invalid data types; op elements ", DataTypeString(element_dtype_), + " but list elements ", DataTypeString(tensor_list->element_dtype))); + // The leading dimension of all list elements if they are all the same. + // This is used as the leading dim of uninitialized tensors in the list + // if leading_dims is not provided. + int64_t first_dim = -1; + if (c->num_inputs() > 1) { + // TensorListConcatV2 + PartialTensorShape element_shape; + OP_REQUIRES_OK( + c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape)); + OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1, + errors::InvalidArgument( + "Concat requires elements to be at least vectors, ", + "found scalars instead.")); + // Split `element_shape` into `first_dim` and + // `element_shape_except_first_dim`. + first_dim = element_shape.dim_size(0); + element_shape_except_first_dim = element_shape; + element_shape_except_first_dim.RemoveDim(0); + } + // If the TensorList is empty, element_shape_except_first_dim must be fully + // defined. + OP_REQUIRES(c, + !tensor_list->tensors().empty() || + element_shape_except_first_dim.IsFullyDefined(), + errors::InvalidArgument( + "All except the first dimension must be fully defined ", + "when concating an empty tensor list. element_shape: ", + element_shape_except_first_dim.DebugString())); + // 1. Check that `element_shape_except_first_dim` input tensor is + // compatible with the shapes of element tensors. + // 2. Check that the elements have the same shape except the first dim. + // 3. If `first_dim` is known, check that it is compatible with the leading + // dims of all elements. + // 4. If `first_dim` is unknown (-1), check whether all initialized + // elements have the same leading dim and if so set `first_dim` to that + // value. 
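Check 4 is the subtle one. A standalone sketch of the inference it describes, with -1 standing both for an uninitialized element and for "no consistent value" (InferFirstDim is an illustrative name):

#include <cstdint>
#include <vector>

// Infer the shared leading dimension from the initialized elements only.
int64_t InferFirstDim(const std::vector<int64_t>& leading_dims) {
  int64_t inferred = -1;
  for (int64_t d : leading_dims) {
    if (d == -1) continue;              // uninitialized element, skip
    if (inferred == -1) inferred = d;   // first initialized element seen
    else if (inferred != d) return -1;  // inconsistent leading dims, give up
  }
  return inferred;
}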
+ if (!tensor_list->element_shape.IsFullyDefined()) { + bool check_dim = (first_dim == -1); + int64_t inferred_first_dim = first_dim; + for (int i = 0; i < tensor_list->tensors().size(); ++i) { + const Tensor& t = tensor_list->tensors()[i]; + if (t.dtype() != DT_INVALID) { + PartialTensorShape tmp = element_shape_except_first_dim; + OP_REQUIRES( + c, TensorShapeUtils::IsVectorOrHigher(t.shape()), + errors::InvalidArgument("Concat saw a scalar shape at index ", i, + " but requires at least vectors.")); + TensorShape shape_except_first_dim = TensorShape( + absl::Span(t.shape().dim_sizes()).subspan(1)); + OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim, + &element_shape_except_first_dim)); + OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0), + errors::InvalidArgument( + "First entry of element_shape input does not match ", + "the first dim of list element at index: ", i, + " Expected: ", first_dim, + " Actual: ", t.shape().dim_size(0))); + if (check_dim) { + if (inferred_first_dim == -1) { + inferred_first_dim = t.shape().dim_size(0); + } else if (inferred_first_dim != t.shape().dim_size(0)) { + inferred_first_dim = -1; + check_dim = false; + } + } + } + } + first_dim = inferred_first_dim; + } + TensorShape output_shape; + OP_REQUIRES(c, element_shape_except_first_dim.AsTensorShape(&output_shape), + errors::InvalidArgument( + "Trying to concat list with only uninitialized tensors ", + "but element_shape_except_first_dim is not fully defined: ", + element_shape_except_first_dim.DebugString())); + // Build the lengths_tensor and leading dim of the output tensor by + // iterating over all element tensors. + Tensor* lengths_tensor = nullptr; + OP_REQUIRES_OK(c, c->allocate_output(1, + TensorShape({static_cast( + tensor_list->tensors().size())}), + &lengths_tensor)); + auto lengths_tensor_vec = lengths_tensor->vec(); + int64_t leading_dim = 0; + for (size_t i = 0; i < tensor_list->tensors().size(); i++) { + int64_t dim; + if (tensor_list->tensors()[i].dtype() != DT_INVALID) { + dim = tensor_list->tensors()[i].shape().dim_size(0); + } else { + // If leading_dims is not provided or does not contain an entry for + // index i use the inferred `first_dim` if set. + if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) && + first_dim != -1) { + dim = first_dim; + } else { + OP_REQUIRES(c, c->num_inputs() > 2, + errors::InvalidArgument( + "Concating lists with uninitialized tensors is not ", + "supported in this version of TensorListConcat. ", + "Consider updating your GraphDef to run the newer ", + "version.")); + OP_REQUIRES(c, i < c->input(2).NumElements(), + errors::InvalidArgument( + "List contains uninitialized tensor at index ", i, + " but leading_dims has only ", + c->input(2).NumElements(), " elements.")); + dim = c->input(2).vec()(i); + } + } + leading_dim += dim; + lengths_tensor_vec(i) = dim; + } + output_shape.InsertDim(0, leading_dim); + Tensor* output; + // Allocate the output tensor and fill it up with the concated element + // tensors. + OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); + if (output->NumElements() == 0) { + return; + } + + ConstMatrixVector inputs_flat; + inputs_flat.reserve(tensor_list->tensors().size()); + // Store the zeros tensors in a vector to prevent them from being GC'ed till + // concat is complete. 
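The reason for that vector is lifetime, not bookkeeping: inputs_flat stores non-owning matrix views into the zeros tensors, so the owning tensors must outlive the concat that reads through those views. A small sketch of the same pattern with std::string_view in place of TTypes<T>::ConstMatrix (ConcatViews is illustrative):

#include <string>
#include <string_view>
#include <vector>

std::string ConcatViews() {
  std::vector<std::string> owners;        // keeps the backing storage alive
  std::vector<std::string_view> views;    // non-owning, like inputs_flat
  owners.push_back(std::string(3, '0'));  // placeholder "zeros" element
  views.push_back(owners.back());
  std::string out;
  for (std::string_view v : views) out.append(v);  // the "concat" reads the views
  return out;  // only now may `owners` be destroyed
}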
+ std::vector zeros_vec; + for (int i = 0; i < tensor_list->tensors().size(); i++) { + const Tensor& element_tensor = tensor_list->tensors()[i]; + if (element_tensor.dtype() != DT_INVALID) { + if (element_tensor.NumElements() > 0) { + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + element_tensor.shaped({1, element_tensor.NumElements()}))); + } + } else { + AllocatorAttributes attr; + if (element_dtype_ == DT_VARIANT) { + attr.set_on_host(true); + } + TensorShape element_shape = output_shape; + element_shape.set_dim(0, lengths_tensor_vec(i)); + zeros_vec.emplace_back(); + Tensor& zeros = zeros_vec.back(); + OP_REQUIRES_OK( + c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr)); + SetZero(c, zeros); + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + const_cast(zeros).shaped( + {1, zeros.NumElements()}))); + } + } + auto output_flat = output->shaped({1, output->NumElements()}); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (std::is_same::value) { + ConcatGPU(c, inputs_flat, output, &output_flat); + return; + } +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (IsPluggableDevice(c)) { + ConcatPluggableDevice(c, inputs_flat, &output_flat); + } else { + ConcatCPU(c->device(), inputs_flat, &output_flat); + } + } + + private: + DataType element_dtype_; + PartialTensorShape element_shape_; +}; + +template +class TensorListSplit : public OpKernel { + public: + TensorListSplit(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + Tensor* output_tensor; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr)); + PartialTensorShape element_shape; + OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape)); + OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1, + errors::InvalidArgument( + "TensorListSplit requires element_shape to be at least of ", + "rank 1, but saw: ", element_shape.DebugString())); + TensorList output_list; + const Tensor& input_tensor = c->input(0); + output_list.element_dtype = input_tensor.dtype(); + OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()), + errors::InvalidArgument( + "Tensor must be at least a vector, but saw shape: ", + input_tensor.shape().DebugString())); + TensorShape tensor_shape_without_first_dim(input_tensor.shape()); + tensor_shape_without_first_dim.RemoveDim(0); + PartialTensorShape element_shape_without_first_dim; + if (!element_shape.unknown_rank()) { + element_shape_without_first_dim = + PartialTensorShape(element_shape.dim_sizes()); + element_shape_without_first_dim.RemoveDim(0); + } + OP_REQUIRES(c, + element_shape_without_first_dim.IsCompatibleWith( + tensor_shape_without_first_dim), + errors::InvalidArgument( + "tensor shape ", input_tensor.shape().DebugString(), + " is not compatible with element_shape ", + element_shape.DebugString())); + output_list.element_shape = element_shape; + const Tensor& lengths = c->input(2); + OP_REQUIRES(c, TensorShapeUtils::IsVector(lengths.shape()), + errors::InvalidArgument( + "Expected lengths to be a vector, received shape: ", + lengths.shape().DebugString())); + output_list.tensors().reserve(lengths.shape().dim_size(0)); + + const auto copy_tensor = IsPluggableDevice(c) + ? 
&CopyTensorPluggableDevice + : &CopyTensor; + + int64_t start = 0; + int64_t end = 0; + for (int i = 0; i < lengths.shape().dim_size(0); ++i) { + int64_t length = lengths.vec()(i); + OP_REQUIRES( + c, length >= 0, + errors::InvalidArgument("Invalid value in lengths: ", length)); + end = start + length; + OP_REQUIRES(c, end <= input_tensor.shape().dim_size(0), + errors::InvalidArgument("Attempting to slice [", start, ", ", + end, "] from tensor with length ", + input_tensor.shape().dim_size(0))); + Tensor tmp = input_tensor.Slice(start, end); + start = end; + // TODO(apassos) maybe not always align; but weird compiler bugs seem to + // prevent this. + Tensor aligned; + OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned)); + copy_tensor(c, tmp, aligned); + output_list.tensors().emplace_back(aligned); + } + OP_REQUIRES(c, end == input_tensor.shape().dim_size(0), + errors::InvalidArgument( + "Unused values in tensor. Length of tensor: ", + input_tensor.shape().dim_size(0), " Values used: ", end)); + output_tensor->scalar()() = std::move(output_list); + } +}; + +template +class TensorListGather : public OpKernel { + public: + typedef std::vector::ConstMatrix>> + ConstMatrixVector; + explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + const TensorList* tensor_list = nullptr; + OP_REQUIRES_OK(c, GetInputList(c, 0, &tensor_list)); + OP_REQUIRES( + c, element_dtype_ == tensor_list->element_dtype, + errors::InvalidArgument( + "Invalid data types; op elements ", DataTypeString(element_dtype_), + " but list elements ", DataTypeString(tensor_list->element_dtype))); + const Tensor& indices = c->input(1); + PartialTensorShape partial_element_shape; + OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2, + &partial_element_shape)); + OP_REQUIRES( + c, partial_element_shape.IsFullyDefined() || indices.NumElements() > 0, + errors::InvalidArgument("Tried to gather 0-elements from " + "a list with non-fully-defined shape: ", + partial_element_shape.DebugString())); + + // Check that `element_shape` input tensor is compatible with the shapes of + // element tensors. + if (!tensor_list->element_shape.IsFullyDefined()) { + for (int index = 0; index < indices.NumElements(); ++index) { + const int i = indices.flat()(index); + + OP_REQUIRES(c, 0 <= i && i < tensor_list->tensors().size(), + absl::InvalidArgumentError(absl::StrCat( + "Trying to gather element ", i, " in a list with ", + tensor_list->tensors().size(), " elements."))); + + const Tensor& t = tensor_list->tensors()[i]; + if (t.dtype() != DT_INVALID) { + PartialTensorShape tmp = partial_element_shape; + OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape)); + } + } + } + + // Compute the shape of the output tensor by pre-pending the leading dim to + // the element_shape. 
+ TensorShape element_shape; + OP_REQUIRES( + c, partial_element_shape.AsTensorShape(&element_shape), + errors::InvalidArgument("Tried to gather uninitialized tensors from a ", + "list with non-fully-defined element_shape: ", + partial_element_shape.DebugString())); + TensorShape output_shape = element_shape; + output_shape.InsertDim(0, indices.NumElements()); + Tensor* output; + OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output)); + if (output->NumElements() == 0) { + return; + } + + ConstMatrixVector inputs_flat; + inputs_flat.reserve(indices.NumElements()); + Tensor zeros; + for (int index = 0; index < indices.NumElements(); ++index) { + const int i = indices.flat()(index); + OP_REQUIRES( + c, i < tensor_list->tensors().size(), + errors::InvalidArgument("Index ", i, " out o range; list only has ", + tensor_list->tensors().size(), " elements.")); + const Tensor& t = tensor_list->tensors()[i]; + if (t.dtype() != DT_INVALID) { + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + t.shaped({1, t.NumElements()}))); + } else { + if (!zeros.NumElements()) { + AllocatorAttributes attr; + if (element_dtype_ == DT_VARIANT) { + attr.set_on_host(true); + } + OP_REQUIRES_OK( + c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr)); + SetZero(c, zeros); + } + inputs_flat.emplace_back(new typename TTypes::ConstMatrix( + const_cast(zeros).shaped( + {1, zeros.NumElements()}))); + } + } + auto output_flat = output->shaped({1, output->NumElements()}); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (std::is_same::value) { + ConcatGPU(c, inputs_flat, output, &output_flat); + return; + } +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (IsPluggableDevice(c)) { + ConcatPluggableDevice(c, inputs_flat, &output_flat); + } else { + ConcatCPU(c->device(), inputs_flat, &output_flat); + } + } + + private: + DataType element_dtype_; +}; + +template +class TensorListFromTensor : public OpKernel { + public: + TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + Tensor* output_tensor; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr)); + PartialTensorShape element_shape; + OP_REQUIRES( + c, !TensorShapeUtils::IsMatrixOrHigher(c->input(1).shape()), + errors::InvalidArgument( + "TensorListFromTensor: element_shape must be at most rank 1 but ", + "has the shape of ", c->input(1).shape().DebugString())); + OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape)); + TensorList output_list; + const Tensor& t = c->input(0); + output_list.element_dtype = t.dtype(); + OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()), + errors::InvalidArgument( + "Tensor must be at least a vector, but saw shape: ", + t.shape().DebugString())); + TensorShape output_shape(t.shape()); + output_shape.RemoveDim(0); + OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape), + errors::InvalidArgument( + "Specified a list with shape ", element_shape.DebugString(), + " from a tensor with shape ", output_shape.DebugString())); + output_list.element_shape = element_shape; + output_list.tensors().reserve(t.shape().dim_size(0)); + + const auto copy_tensor = IsPluggableDevice(c) + ? 
&CopyTensorPluggableDevice + : &CopyTensor; + + for (int i = 0; i < t.shape().dim_size(0); ++i) { + Tensor tmp = t.Slice(i, i + 1); + TensorShape tmp_shape = tmp.shape(); + tmp_shape.RemoveDim(0); + OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape), + errors::Unknown("Unexpected shape error.")); + // TODO(apassos) maybe not always align; but weird compiler bugs seem to + // prevent this. + Tensor aligned; + OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned)); + copy_tensor(c, tmp, aligned); + output_list.tensors().push_back(aligned); + } + output_tensor->scalar()() = std::move(output_list); + } +}; + +// Scatters values in `value` into `list`. Assumes that `indices` are valid. +template +absl::Status Scatter(OpKernelContext* c, const Tensor& value, + const Tensor& indices, TensorList* list) { + const auto copy_tensor = IsPluggableDevice(c) ? &CopyTensorPluggableDevice + : &CopyTensor; + for (int index = 0; index < indices.NumElements(); ++index) { + const int i = indices.flat()(index); + Tensor tmp = value.Slice(index, index + 1); + TensorShape tmp_shape = tmp.shape(); + tmp_shape.RemoveDim(0); + if (!tmp.CopyFrom(tmp, tmp_shape)) { + return errors::Unknown("Unexpected shape error."); + } + // TODO(apassos) maybe not always align; but weird compiler bugs seem to + // prevent this. + Tensor aligned; + TF_RETURN_IF_ERROR(c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned)); + // TODO(apassos) do all slices in a single kernel invocation instead of + // many small ones. + copy_tensor(c, tmp, aligned); + std::swap(list->tensors()[i], aligned); + } + return absl::OkStatus(); +} + +template +class TensorListScatterIntoExistingList : public OpKernel { + public: + TensorListScatterIntoExistingList(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + const TensorList* l = nullptr; + OP_REQUIRES_OK(c, GetInputList(c, 0, &l)); + const Tensor& input_tensor = c->input(1); + const Tensor& indices = c->input(2); + + // Check that inputs are valid. + OP_REQUIRES(c, input_tensor.dtype() == l->element_dtype, + errors::InvalidArgument( + "Invalid data types; input tensor type: ", + DataTypeString(input_tensor.dtype()), + " list element_type: ", DataTypeString(l->element_dtype))); + OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()), + errors::InvalidArgument( + "Tensor must be at least a vector, but saw shape: ", + input_tensor.shape().DebugString())); + OP_REQUIRES(c, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument( + "Expected indices to be a vector, but received shape: ", + indices.shape().DebugString())); + OP_REQUIRES( + c, indices.NumElements() == input_tensor.shape().dim_size(0), + errors::InvalidArgument( + "Expected len(indices) == tensor.shape[0], but saw: ", + indices.NumElements(), " vs. ", input_tensor.shape().dim_size(0))); + + // Resize the list if needed to accommodate all indices. + TensorList* output_list = nullptr; + OP_REQUIRES_OK(c, ForwardInputOrCreateNewList(c, 0, 0, *l, &output_list)); + const auto indices_vec = indices.vec(); + int32_t max_index = + (indices.NumElements() == 0) + ? -1 + : *std::max_element(indices_vec.data(), + indices_vec.data() + indices.NumElements()); + if (max_index + 1 > output_list->tensors().size()) { + output_list->tensors().resize(max_index + 1); + } + + // Scatter the values. 
+ OP_REQUIRES_OK(c, + Scatter(c, input_tensor, indices, output_list)); + } +}; + +template +class TensorListScatter : public OpKernel { + public: + TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* c) override { + Tensor* output_tensor; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr)); + Tensor indices = c->input(1); + PartialTensorShape element_shape; + OP_REQUIRES( + c, !TensorShapeUtils::IsMatrixOrHigher(c->input(2).shape()), + errors::InvalidArgument( + "TensorListScatter: element_shape must be at most rank 1 but has ", + "the shape of ", c->input(2).shape().DebugString())); + OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape)); + // TensorListScatterV2 passes the num_elements input, TensorListScatter does + // not. + int num_elements = -1; + if (c->num_inputs() >= 4) { + OP_REQUIRES(c, TensorShapeUtils::IsScalar(c->input(3).shape()), + errors::InvalidArgument("num_elements must be a scalar")); + num_elements = c->input(3).scalar()(); + } + OP_REQUIRES(c, num_elements >= -1, + errors::InvalidArgument( + "TensorListScatter expects num_elements >= -1, found: ", + num_elements)); + TensorList output_list; + const Tensor& input_tensor = c->input(0); + output_list.element_dtype = input_tensor.dtype(); + OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input_tensor.shape()), + errors::InvalidArgument( + "Tensor must be at least a vector, but saw shape: ", + input_tensor.shape().DebugString())); + TensorShape output_shape(input_tensor.shape()); + output_shape.RemoveDim(0); + OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape), + errors::InvalidArgument( + "Specified a list with shape ", element_shape.DebugString(), + " from a tensor with shape ", output_shape.DebugString())); + output_list.element_shape = element_shape; + + OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0), + errors::InvalidArgument( + "Invalid number of rows in input tensor. Expected: ", + indices.NumElements(), + " Actual: ", input_tensor.shape().dim_size(0))); + + // Validate indices and resize output_list.tensors to fit the highest index. 
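A condensed sketch of the rule the following block enforces: indices must be non-negative, must stay below num_elements when that is fixed, and the list grows to cover both the highest index and num_elements, with the new slots left uninitialized (DT_INVALID tensors in the kernel). ValidatedListSize is an illustrative name, not part of the op.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::size_t ValidatedListSize(const std::vector<int32_t>& indices,
                              int64_t num_elements) {  // -1 means "unbounded"
  int64_t highest = -1;
  for (int32_t i : indices) {
    if (i < 0) throw std::invalid_argument("indices must be non-negative");
    if (num_elements != -1 && i >= num_elements)
      throw std::invalid_argument("index exceeds num_elements");
    highest = std::max<int64_t>(highest, i);
  }
  const int64_t required = (num_elements == -1) ? 0 : num_elements;
  return static_cast<std::size_t>(std::max(highest + 1, required));
}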
+ { + int highest_index = -1; + for (int index = 0; index < indices.NumElements(); ++index) { + const int i = indices.flat()(index); + OP_REQUIRES( + c, i >= 0, + errors::InvalidArgument( + "Indices in TensorListScatter must all be non-negative.")); + OP_REQUIRES(c, num_elements == -1 || i < num_elements, + errors::InvalidArgument( + "TensorListScatter: Trying to scatter at index ", i, + " in list with size ", num_elements)); + if (i > highest_index) { + highest_index = i; + } + } + output_list.tensors().resize(std::max(highest_index + 1, num_elements), + Tensor(DT_INVALID)); + } + + OP_REQUIRES_OK(c, + Scatter(c, input_tensor, indices, &output_list)); + output_tensor->scalar()() = std::move(output_list); + } +}; + +template +absl::Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a, + const TensorList& b, TensorList* out) { + return TensorListBinaryAdd(c, a, b, out, BinaryAddTensors); +} + +template +absl::Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, + TensorList* y) { + return TensorListZerosLike(c, x, y, ZerosLikeTensor); +} + +template +class TensorListPushBackBatch : public OpKernel { + public: + explicit TensorListPushBackBatch(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); + } + + void Compute(OpKernelContext* c) override { + const Tensor& input = c->input(1); + OP_REQUIRES(c, element_dtype_ == input.dtype(), + errors::InvalidArgument("Invalid data types; list elements ", + DataTypeString(element_dtype_), + " but tried to append ", + DataTypeString(input.dtype()))); + OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(input.shape()), + errors::InvalidArgument( + "Expected tensor to be at least a vector, but saw shape: ", + input.shape().DebugString())); + + const TensorShape& tls_shape = c->input(0).shape(); + + // For purposes of input forwarding, we want the least restrictive + // AllocatorAttributes possible. If we need to allocate later, + // we'll request the DT_VARIANT be allocated on host. + AllocatorAttributes attr; + + std::unique_ptr tls_alias = c->forward_input( + 0 /*input_index*/, 0 /*output_index*/, DT_VARIANT, tls_shape, + DEVICE_MEMORY /* input is always on DEVICE_MEMORY */, attr); + + bool ok_to_alias = tls_alias != nullptr; + if (tls_alias && tls_alias->dtype() == DT_VARIANT && + tls_alias->NumElements() > 0) { + auto alias_t = tls_alias->flat(); + for (int i = 0; i < tls_alias->NumElements(); ++i) { + TensorList* tl_i = alias_t(i).get(); + if (tl_i == nullptr || !tl_i->RefCountIsOne()) { + ok_to_alias = false; + break; + } + } + } + const Tensor& tls = ok_to_alias ? *tls_alias : c->input(0); + + OP_REQUIRES(c, tls.dtype() == DT_VARIANT, + errors::InvalidArgument( + "Expected input_handles dtype to be Variant, but saw: ", + DataTypeString(tls.dtype()))); + OP_REQUIRES(c, TensorShapeUtils::IsVector(tls_shape), + errors::InvalidArgument( + "Expected input_handles to be a vector, but saw shape: ", + tls_shape.DebugString())); + const int64_t batch_size = tls.NumElements(); + OP_REQUIRES(c, input.dim_size(0) == batch_size, + errors::InvalidArgument( + "Expected tensor.shape[0] == input_handles.size, but saw ", + input.dim_size(0), " vs. 
", batch_size)); + auto tls_t = tls.vec(); + + TensorShape input_element_shape = input.shape(); + input_element_shape.RemoveDim(0); + std::vector tl_batch; + for (int64_t b = 0; b < batch_size; ++b) { + const TensorList* l = tls_t(b).get(); + OP_REQUIRES(c, l != nullptr, + errors::InvalidArgument("Input handle at index ", b, + " is not a list. Saw: '", + tls_t(b).DebugString(), "'")); + OP_REQUIRES( + c, l->element_shape.IsCompatibleWith(input_element_shape), + errors::InvalidArgument( + "Tried to append a tensor with incompatible shape to a " + "list at index ", + b, ". Op element shape: ", input_element_shape.DebugString(), + " list shape: ", l->element_shape.DebugString())); + OP_REQUIRES(c, element_dtype_ == l->element_dtype, + errors::InvalidArgument( + "Invalid data type at index ", b, "; op elements ", + DataTypeString(element_dtype_), " but list elements ", + DataTypeString(l->element_dtype))); + tl_batch.push_back(l); + } + + Tensor* result; + + if (ok_to_alias) { + result = tls_alias.get(); + c->set_output(0, *result); + } else { + // DT_VARIANT tensors always allocated on host. + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK( + c, c->allocate_output(0, TensorShape{batch_size}, &result, attr)); + } + + if (batch_size == 0) { + return; + } + + auto input_t = input.flat_outer_dims(); + auto result_t = result->vec(); + + for (int64_t b = 0; b < batch_size; ++b) { + if (!ok_to_alias) { + result_t(b) = tl_batch[b]->Copy(); + } + TensorList* output = result_t(b).get(); + DCHECK(output != nullptr); + Tensor frame; + OP_REQUIRES_OK( + c, c->allocate_temp(element_dtype_, input_element_shape, &frame)); + if (input_element_shape.num_elements() > 0) { + auto frame_t = frame.flat(); + // TODO(penporn): Get this if out of the batch loop. + if (IsPluggableDevice(c)) { + // The chip method need Eigen Device, so need to use Tensor.Slice + // instead of chip for pluggable device. The input should be reshaped + // to 2-D and so can be sliced by batch dim. + auto input_t_shape = + TensorShape({input_t.dimension(0), input_t.dimension(1)}); + auto input_reshaped = Tensor(); + OP_REQUIRES(c, input_reshaped.CopyFrom(input, input_t_shape), + errors::Unknown("Unexpected shape error.")); + + auto input_batch = input_reshaped.Slice(b, b + 1); + CopyTensorPluggableDevice(c, input_batch, frame); + } else { + frame_t.device(c->eigen_device()) = + input_t.template chip<0>(b); + } + } + output->tensors().push_back(std::move(frame)); + } + } + + private: + DataType element_dtype_; +}; + +} // namespace tensorflow + +#undef PLUGGABLE_DEVICE_SUPPORTED +#endif // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/logging_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/logging_ops.h new file mode 100644 index 00000000..5cb12139 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/logging_ops.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_LOGGING_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_LOGGING_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class AssertOp : public OpKernel { + public: + explicit AssertOp(OpKernelConstruction* c); + void Compute(OpKernelContext* ctx) override; + + private: + int32 summarize_ = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LOGGING_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/logistic-loss.h b/third_party/tflite-hdrs/tensorflow/core/kernels/logistic-loss.h new file mode 100644 index 00000000..d848a1f3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/logistic-loss.h @@ -0,0 +1,134 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_ +#define TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_ + +#include + +#include "tensorflow/core/kernels/loss.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +class LogisticLossUpdater : public DualLossUpdater { + public: + // Adding vs. Averaging in Distributed Primal-Dual Optimization. + // Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, Peter + // Richtarik, Martin Takac http://arxiv.org/abs/1502.03508 + double ComputeUpdatedDual(const int num_loss_partitions, const double label, + const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm) const final { + // Newton algorithm converges quadratically so 10 steps will be largely + // enough to achieve a very good precision + static const int newton_total_steps = 10; + double x = 0; + for (int i = 0; i < newton_total_steps; ++i) { + x = NewtonStep(x, num_loss_partitions, label, wx, example_weight, + weighted_example_norm, current_dual); + } + return 0.5 * (1 + tanh(x)) / label; + } + + // Dual of logistic loss function. + // https://en.wikipedia.org/wiki/Convex_conjugate + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { + // Dual of the logistic loss function is + // ay * log(ay) + (1-ay) * log (1-ay), where a is the dual variable. + const double ay = current_dual * example_label; + const double log_ay = (ay > 0) ? log(ay) : 0; + const double one_minus_ay = 1 - ay; + const double log_one_minus_ay = (one_minus_ay > 0) ? log(one_minus_ay) : 0; + return ((ay * log_ay) + (one_minus_ay * log_one_minus_ay)) * example_weight; + } + + // Logistic loss for binary classification. 
+ // https://en.wikipedia.org/wiki/Loss_functions_for_classification + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { + // Logistic loss: + // log(1 + e^(-ywx)) + // log(e^0 + e^(-ywx)) + // a + log(e^(0-a) + e^(-ywx - a)), where a is max(0, -ywx) + // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ + const double y_wx = example_label * wx; + if (y_wx > 0) { + // 0 + log(e^(0) + e^(-ywx - 0)) + // log(1 + e^(-ywx)) + return log1p(exp(-y_wx)) * example_weight; + } + // -ywx + log(e^(ywx) + e^(-ywx + ywx)) + // log(e^(ywx) + e^(0)) - ywx + // log(1 + e^(ywx)) - ywx + return (log1p(exp(y_wx)) - y_wx) * example_weight; + } + + // Derivative of logistic loss + double PrimalLossDerivative(const double wx, const double label, + const double example_weight) const final { + double inverse_exp_term = 0; + if (label * wx > 0) { + inverse_exp_term = exp(-label * wx) / (1 + exp(-label * wx)); + } else { + inverse_exp_term = 1 / (1 + exp(label * wx)); + } + return -inverse_exp_term * label * example_weight; + } + + // The smoothness constant is 4 since the derivative of logistic loss, which + // is exp(-x) / (1 + exp(-x)) can be shown to 0.25-Lipschitz (its derivative + // is bounded by 0.25) + double SmoothnessConstant() const final { return 4; } + + // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively + // as expected by logistic regression. + absl::Status ConvertLabel(float* const example_label) const final { + if (*example_label == 0.0) { + *example_label = -1; + return absl::OkStatus(); + } + if (*example_label == 1.0) { + return absl::OkStatus(); + } + return errors::InvalidArgument( + "Only labels of 0.0 or 1.0 are supported right now. " + "Found example with label: ", + *example_label); + } + + private: + // We use Newton algorithm on a modified function (see readme.md). + double NewtonStep(const double x, const int num_loss_partitions, + const double label, const double wx, + const double example_weight, + const double weighted_example_norm, + const double current_dual) const { + const double tanhx = tanh(x); + const double numerator = -2 * label * x - wx - + num_loss_partitions * weighted_example_norm * + example_weight * + (0.5 * (1 + tanhx) / label - current_dual); + const double denominator = + -2 * label - num_loss_partitions * weighted_example_norm * + example_weight * (1 - tanhx * tanhx) * 0.5 / label; + return x - numerator / denominator; + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_table_init_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_table_init_op.h new file mode 100644 index 00000000..e94db921 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_table_init_op.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_ + +#include "tensorflow/core/kernels/initializable_lookup_table.h" + +namespace tensorflow { +namespace lookup { + +// Helper function to initialize an InitializableLookupTable from a text file. +absl::Status InitializeTableFromTextFile(const string& filename, + int64_t vocab_size, char delimiter, + int32_t key_index, int32_t value_index, + Env* env, + InitializableLookupTable* table); + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_table_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_table_op.h new file mode 100644 index 00000000..daa7f6e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_table_op.h @@ -0,0 +1,352 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/lookup_interface.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/lookup_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// Lookup table op that supports different table implementations specified by +// the 'Container' template. Container must be derived from LookupInterface. The +// key and value are of the templated type "key_dtype" and "value_dtype" +// respectively. +template +class LookupTableOp : public OpKernel { + public: + // ctx is not owned by this class. + explicit LookupTableOp(OpKernelConstruction* ctx) + : OpKernel(ctx), table_set_(false) { + if (ctx->output_type(0) == DT_RESOURCE) { + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(tensorflow::DT_RESOURCE, + tensorflow::TensorShape({}), &table_)); + } else { + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(tensorflow::DT_STRING, + tensorflow::TensorShape({2}), &table_)); + } + OP_REQUIRES_OK( + ctx, ctx->GetAttr("use_node_name_sharing", &use_node_name_sharing_)); + } + + // ctx is not owned by this function. 
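For orientation: this header only declares the templated kernel; upstream TensorFlow instantiates it through kernel-registration macros in lookup_table_op.cc, which is not part of this diff. A minimal sketch of that pattern is below; the attribute names and dtypes are illustrative, not copied from upstream.

// Illustrative only: wiring a concrete CPU HashTable kernel to the HashTableV2
// op. Container is the table implementation, followed by the key and value
// dtypes described in the class comment above.
REGISTER_KERNEL_BUILDER(
    Name("HashTableV2")
        .Device(DEVICE_CPU)
        .TypeConstraint<int64_t>("key_dtype")
        .TypeConstraint<int64_t>("value_dtype"),
    LookupTableOp<lookup::HashTable<int64_t, int64_t>, int64_t, int64_t>);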
+ void Compute(OpKernelContext* ctx) override { + mutex_lock l(mu_); + + if (!table_set_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(), + use_node_name_sharing_)); + } + + auto creator = + [ctx, this](lookup::LookupInterface** ret) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + lookup::LookupInterface* container = new Container(ctx, this); + if (!ctx->status().ok()) { + container->Unref(); + return ctx->status(); + } + if (ctx->track_allocations()) { + ctx->record_persistent_memory_allocation( + container->MemoryUsed() + table_.AllocatedBytes()); + } + *ret = container; + return absl::OkStatus(); + }; + + lookup::LookupInterface* table = nullptr; + OP_REQUIRES_OK(ctx, + cinfo_.resource_manager() + ->template LookupOrCreate( + cinfo_.container(), cinfo_.name(), &table, creator)); + core::ScopedUnref unref_me(table); + + OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes( + *table, DataTypeToEnum::v(), + DataTypeToEnum::v(), cinfo_.name())); + + if (ctx->expected_output_dtype(0) == DT_RESOURCE) { + if (!table_set_) { + auto h = table_.template scalar(); + h() = MakeResourceHandle( + ctx, cinfo_.container(), cinfo_.name()); + } + ctx->set_output(0, table_); + } else { + if (!table_set_) { + auto h = table_.template flat(); + h(0) = cinfo_.container(); + h(1) = cinfo_.name(); + } + ctx->set_output_ref(0, &mu_, &table_); + } + table_set_ = true; + } + + ~LookupTableOp() override { + // If the table object was not shared, delete it. + if (table_set_ && cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + private: + mutex mu_; + Tensor table_ TF_GUARDED_BY(mu_); + bool table_set_ TF_GUARDED_BY(mu_); + ContainerInfo cinfo_; + bool use_node_name_sharing_; + + LookupTableOp(const LookupTableOp&) = delete; + void operator=(const LookupTableOp&) = delete; +}; + +// An anonymous version of LookupTableOp, which creates a new table resource +// everytime `Compute` is called. The resource can only be accessed by the +// returned resource handle (e.g. it can't be looked up by a name in a resource +// manager). The resource will be automatically deleted when all resource +// handles pointing to it are gone. +template +class AnonymousLookupTableOp : public OpKernel { + public: + explicit AnonymousLookupTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + lookup::LookupInterface* table = new Container(ctx, this); + if (!ctx->status().ok()) { + table->Unref(); + return; + } + Tensor table_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(tensorflow::DT_RESOURCE, + tensorflow::TensorShape({}), &table_tensor)); + if (ctx->track_allocations()) { + ctx->record_persistent_memory_allocation(table->MemoryUsed() + + table_tensor.AllocatedBytes()); + } + table_tensor.scalar()() = + ResourceHandle::MakeRefCountingHandle( + table, ctx->device()->name()); + ctx->set_output(0, table_tensor); + } + + private: + AnonymousLookupTableOp(const AnonymousLookupTableOp&) = delete; + void operator=(const AnonymousLookupTableOp&) = delete; +}; + +namespace lookup { + +// Ensure that the compiler cannot elide a copy into a local, for +// bounds checking on source tensors that might be updated asynchronously for +// integral types. However non-integer variables are not allowed and therefore +// the local copy is unnecessary. 
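The comment above is terse about why the forced copy matters for integral types: the bounds check and the later use must observe the same value even if another thread rewrites the source tensor in between. Below is a self-contained toy illustration of that check-then-use discipline; a plain std::atomic stands in for the tensor storage, and all names are illustrative rather than upstream code.

#include <atomic>
#include <cstdint>
#include <cstdio>

// Shared storage that another thread may update at any time.
std::atomic<int64_t> shared_key{3};
constexpr int64_t kNumBuckets = 8;

int64_t LookupSlot() {
  // Take one snapshot, in the spirit of SubtleMustCopyIfIntegral: the bounds
  // check and the use below both see this same local value.
  const int64_t key = shared_key.load();
  if (key < 0 || key >= kNumBuckets) return -1;
  return key;  // guaranteed to be the value that passed the check
}

int main() { std::printf("%lld\n", static_cast<long long>(LookupSlot())); }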
+template +T SubtleMustCopyIfIntegral(const T& value) { + return internal::SubtleMustCopy(value); +} + +inline const tstring& SubtleMustCopyIfIntegral(const tstring& value) { + return value; +} + +inline const float SubtleMustCopyIfIntegral(const float value) { return value; } + +inline const double SubtleMustCopyIfIntegral(const double value) { + return value; +} + +inline const Variant& SubtleMustCopyIfIntegral(const Variant& value) { + return value; +} + +inline const ResourceHandle& SubtleMustCopyIfIntegral( + const ResourceHandle& value) { + return value; +} + +// Returns a unique node name starting with "base". +std::string UniqueNodeName(const std::string& base); + +// Lookup table that wraps an flat_hash_map, where the key and value data type +// is specified. +// +// This table is recommended for any variations to key values. +// +// For look up, the table is required to be initialized (allocated +// and populated). Once the table is marked as initialized it becomes read-only. +// +// Sample use case: +// +// HashTable table; // int64 -> int64. +// table.Initialize(...); +// table.Find(in_t, &out_t, default_t) +// +template +class HashTable : public InitializableLookupTable { + public: + HashTable(OpKernelContext* ctx, OpKernel* kernel) {} + + absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { + // We set use_node_name_sharing with a unique node name so that the resource + // can outlive the HashTableV2 kernel. This means that the lifetime of the + // HashTable resource will be tied to the lifetime of the resource manager + // it is created in. + // TODO(b/181695913): Provide a mechanism for deleting this resource + // earlier when appropriate. + Node* hash_table_node = ops::SourceOp( + "HashTableV2", builder->opts() + .WithName(UniqueNodeName("HashTableFromGraphDef")) + .WithAttr("key_dtype", key_dtype()) + .WithAttr("value_dtype", value_dtype()) + .WithAttr("use_node_name_sharing", true)); + if (table_.empty()) { + *out = hash_table_node; + return absl::OkStatus(); + } + + if (initializer_serializer_ == nullptr) { + std::string message = + "Failed to serialize lookup table: no initialization function was " + "specified. 
Falling back to serializing a handle to the table."; + LOG(WARNING) << message; + return errors::Unimplemented(message); + } + Node* initializer; + TF_RETURN_IF_ERROR(initializer_serializer_->AsGraphDef( + builder, hash_table_node, &initializer)); + *out = ops::UnaryOp("Identity", hash_table_node, + builder->opts().WithControlInput(initializer)); + return absl::OkStatus(); + } + + size_t size() const override { + if (!is_initialized()) + return 0; + else + return table_.size(); + } + + absl::Status ExportValues(OpKernelContext* context) override { + if (!is_initialized()) { + return errors::Aborted("HashTable is not initialized."); + } + + const int64_t size = table_.size(); + + Tensor* keys; + Tensor* values; + TF_RETURN_IF_ERROR( + context->allocate_output("keys", TensorShape({size}), &keys)); + TF_RETURN_IF_ERROR( + context->allocate_output("values", TensorShape({size}), &values)); + + auto keys_data = keys->flat(); + auto values_data = values->flat(); + int64_t i = 0; + for (auto it = table_.begin(); it != table_.end(); ++it, ++i) { + keys_data(i) = it->first; + values_data(i) = it->second; + } + return absl::OkStatus(); + } + + DataType key_dtype() const override { return DataTypeToEnum::v(); } + + DataType value_dtype() const override { return DataTypeToEnum::v(); } + + protected: + absl::Status DoPrepare(size_t size) override { + if (is_initialized()) { + return errors::Aborted("HashTable already initialized."); + } + if (size > 0) { + table_.reserve(size); + } + return absl::OkStatus(); + }; + + absl::Status DoLazyPrepare(std::function size_fn) override { + return DoPrepare(size_fn()); + } + + absl::Status DoInsert(const Tensor& keys, const Tensor& values) override { + const auto key_values = keys.flat(); + const auto value_values = values.flat(); + for (int64_t i = 0; i < key_values.size(); ++i) { + auto&& key = SubtleMustCopyIfIntegral(key_values(i)); + auto&& value = SubtleMustCopyIfIntegral(value_values(i)); + auto result = table_.try_emplace(key, value); + if (!result.second && result.first->second != value) { + return errors::FailedPrecondition( + "HashTable has different value for same key. Key ", key, " has ", + result.first->second, " and trying to add value ", value); + } + } + return absl::OkStatus(); + } + + absl::Status DoFind(const Tensor& key, Tensor* value, + const Tensor& default_value) override { + const V default_val = default_value.flat()(0); + const auto key_values = key.flat(); + auto value_values = value->flat(); + + for (int64_t i = 0; i < key_values.size(); ++i) { + value_values(i) = gtl::FindWithDefault( + table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); + } + return absl::OkStatus(); + } + + int64_t MemoryUsed() const override { + if (!is_initialized()) { + return 0; + } + const int64_t num_elements = table_.size(); + return num_elements * (sizeof(K) + sizeof(V)); + } + + private: + absl::flat_hash_map table_; +}; + +} // namespace lookup + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_util.h new file mode 100644 index 00000000..677c6a56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/lookup_util.h @@ -0,0 +1,76 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ + +#include "tensorflow/core/framework/lookup_interface.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/initializable_lookup_table.h" + +namespace tensorflow { +namespace data { +class DatasetBase; +} // namespace data +} // namespace tensorflow + +namespace tensorflow { +namespace lookup { + +// Gets the LookupTable stored in the ctx->resource_manager() with key +// passed by attribute with name input_name, returns null if the table +// doesn't exist. Use GetResourceLookupTable() or GetReferenceLookupTable() if +// the input dtype is known. +absl::Status GetLookupTable(absl::string_view input_name, OpKernelContext* ctx, + LookupInterface** table); +absl::Status GetResourceLookupTable(absl::string_view input_name, + OpKernelContext* ctx, + LookupInterface** table); +absl::Status GetReferenceLookupTable(absl::string_view input_name, + OpKernelContext* ctx, + LookupInterface** table); + +// Gets the InitializableLookupTable stored in the +// ctx->resource_manager() with key passed by attribute with name +// input_name, returns null if the table doesn't exist. +absl::Status GetInitializableLookupTable(absl::string_view input_name, + OpKernelContext* ctx, + InitializableLookupTable** table); + +// Verify that the given key_dtype and value_dtype matches the corresponding +// table's data types. +absl::Status CheckTableDataTypes(const LookupInterface& table, + DataType key_dtype, DataType value_dtype, + const string& table_name); + +// Initializes `table` from `filename`. +absl::Status InitializeTableFromTextFile(const string& filename, + int64_t vocab_size, char delimiter, + int32_t key_index, int32_t value_index, + int64_t offset, Env* env, + InitializableLookupTable* table); + +// Initializes `table` from `filename`. `func` may specify how to represent the +// initializer as a graphdef, so that the table can be serialized as metadata. +absl::Status InitializeTableFromTextFile( + const string& filename, int64_t vocab_size, char delimiter, + int32_t key_index, int32_t value_index, int64_t offset, Env* env, + std::unique_ptr serializer, + InitializableLookupTable* table); + +} // namespace lookup +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/loss.h b/third_party/tflite-hdrs/tensorflow/core/kernels/loss.h new file mode 100644 index 00000000..85893ba8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/loss.h @@ -0,0 +1,59 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_LOSS_H_ +#define TENSORFLOW_CORE_KERNELS_LOSS_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class DualLossUpdater { + public: + virtual ~DualLossUpdater() {} + + // Compute update dual (alpha), based on a single example. Various strategies + // can be employed here, like newton step and/or line search or approximate + // step that decreases the dual sub-optimality. + virtual double ComputeUpdatedDual( + const int num_loss_partitions, const double label, + const double example_weight, const double current_dual, const double wx, + const double weighted_example_norm) const = 0; + + // Compute dual loss based on the current dual (alpha), example label (y) + // and example weight (cost). + virtual double ComputeDualLoss(const double current_dual, + const double example_label, + const double example_weight) const = 0; + + // Compute the primal loss based on current estimate of log-odds(wx), + // example label (y) and example weight (cost). + virtual double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const = 0; + + // Primal loss derivative used to compute the dual residue in AdaSDCA + virtual double PrimalLossDerivative(const double wx, + const double example_label, + const double example_weight) const = 0; + + // This is gamma such that the loss derivative is 1/gamma Lipschitz + virtual double SmoothnessConstant() const = 0; + + // Converts binary example labels from 0.0 or 1.0 to appropriate range for + // each loss function. + virtual absl::Status ConvertLabel(float* const example_label) const = 0; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_LOSS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/map_kernels.h b/third_party/tflite-hdrs/tensorflow/core/kernels/map_kernels.h new file mode 100644 index 00000000..6949ff55 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/map_kernels.h @@ -0,0 +1,255 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/tensor_map.h" +#include "tensorflow/core/util/batch_util.h" +#include "tensorflow/core/util/tensor_ops_util.h" + +namespace tensorflow { + +inline absl::Status GetInputMap(OpKernelContext* ctx, int index, + const TensorMap** ret_map) { + if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) { + return errors::InvalidArgument("Input map must be a scalar. Saw: ", + ctx->input(index).shape().DebugString()); + } + const TensorMap* map = ctx->input(index).scalar()().get(); + if (map == nullptr) { + return errors::InvalidArgument( + "Input handle is not a map. Saw: '", + ctx->input(index).scalar()().DebugString(), "'"); + } + *ret_map = map; + return absl::OkStatus(); +} + +// TODO(kattian): change into templated function +inline absl::Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, + int32_t input_index, + int32_t output_index, + const TensorMap& input_map, + TensorMap** output_map) { + // Attempt to forward the input tensor to the output if possible. + std::unique_ptr maybe_output = ctx->forward_input( + input_index, output_index, DT_VARIANT, TensorShape{}, + ctx->input_memory_type(input_index), AllocatorAttributes()); + Tensor* output_tensor; + if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT && + maybe_output->NumElements() == 1) { + output_tensor = maybe_output.get(); + TensorMap* tmp_out = output_tensor->scalar()().get(); + if (tmp_out == nullptr) { + return errors::InvalidArgument( + "Expected input ", input_index, " to be a TensorMap but saw ", + output_tensor->scalar()().TypeName()); + } + if (tmp_out->RefCountIsOne()) { + // Woohoo, forwarding succeeded! + ctx->set_output(output_index, *output_tensor); + *output_map = tmp_out; + return absl::OkStatus(); + } + } + + // If forwarding is not possible allocate a new output tensor and copy + // the `input_map` to it. 
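The forwarding logic above is a copy-on-write decision keyed on the variant's reference count: mutate in place only when this kernel holds the sole reference, otherwise fall through to the copy below. A stripped-down, self-contained sketch of the same decision follows; a plain shared_ptr stands in for the ref-counted TensorMap and is illustrative only.

#include <map>
#include <memory>
#include <string>

using Map = std::map<std::string, int>;

// Mirrors the RefCountIsOne() fast path: reuse the existing object when we are
// the only owner, clone it before writing otherwise.
std::shared_ptr<Map> PrepareForWrite(const std::shared_ptr<Map>& in) {
  if (in.use_count() == 1) return in;   // sole owner: safe to mutate in place
  return std::make_shared<Map>(*in);    // shared: deep-copy first
}

int main() {
  auto a = std::make_shared<Map>();
  auto same = PrepareForWrite(a);  // a was the only owner, so this is a itself
  auto copy = PrepareForWrite(a);  // a and `same` now share it, so a real copy
  (*copy)["k"] = 1;                // leaves a / same untouched
  return 0;
}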
+ AllocatorAttributes attr; + attr.set_on_host(true); + TF_RETURN_IF_ERROR( + ctx->allocate_output(output_index, {}, &output_tensor, attr)); + output_tensor->scalar()() = input_map.Copy(); + + *output_map = output_tensor->scalar()().get(); + return absl::OkStatus(); +} + +class EmptyTensorMap : public OpKernel { + public: + explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + Tensor* result; + AllocatorAttributes attr; + attr.set_on_host(true); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr)); + TensorMap empty; + result->scalar()() = std::move(empty); + } +}; + +class TensorMapSize : public OpKernel { + public: + explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~TensorMapSize() override {} + + void Compute(OpKernelContext* ctx) override { + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); + Tensor* result; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result)); + result->scalar()() = map->tensors().size(); + } +}; + +class TensorMapLookup : public OpKernel { + public: + explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~TensorMapLookup() override {} + + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); + + OP_REQUIRES( + ctx, map->tensors().find(key) != map->tensors().end(), + errors::InvalidArgument("Trying to lookup non-existent key. Could not " + "find key \"" + + key.SummarizeValue(100) + "\".")); + + ctx->set_output(0, map->tensors().find(key)->second); + } +}; + +class TensorMapInsert : public OpKernel { + public: + explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~TensorMapInsert() override {} + + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const Tensor& value = ctx->input(2); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); + + TensorMap* output_map = nullptr; + OP_REQUIRES_OK(ctx, + ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map)); + output_map->replace(key, value); + } +}; + +class TensorMapErase : public OpKernel { + public: + explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); + + OP_REQUIRES( + ctx, map->tensors().find(key) != map->tensors().end(), + errors::InvalidArgument("Trying to erase non-existent item. 
Could not " + "find key \"" + + key.SummarizeValue(100) + "\".")); + + TensorMap* output_map = nullptr; + OP_REQUIRES_OK(ctx, + ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map)); + output_map->tensors().erase(key); + } +}; + +class TensorMapHasKey : public OpKernel { + public: + explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {} + ~TensorMapHasKey() override {} + + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); + Tensor* result; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result)); + result->scalar()() = map->tensors().find(key) != map->tensors().end(); + } +}; + +class TensorMapStackKeys : public OpKernel { + public: + explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_)); + } + ~TensorMapStackKeys() override {} + + void Compute(OpKernelContext* ctx) override { + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); + + OP_REQUIRES(ctx, map->size() != 0, + errors::InvalidArgument( + "TensorMapStackKeys cannot be called on empty map.")); + + auto it = map->tensors().begin(); + TensorShape output_shape = it->first.shape(); + output_shape.InsertDim(0, map->tensors().size()); + Tensor* result; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result)); + + int i = 0; + size_t sz = map->tensors().size(); + TensorShape key_shape = it->first.shape(); + while (it != map->tensors().end() && i < sz) { + OP_REQUIRES( + ctx, it->first.dtype() == key_dtype_, + errors::InvalidArgument("Key does not match requested dtype.")); + OP_REQUIRES( + ctx, it->first.shape() == key_shape, + errors::InvalidArgument("Keys must all have the same shape.")); + OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i)); + i++; + it++; + } + } + + private: + DataType key_dtype_; +}; + +template +absl::Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a, + const TensorMap& b, TensorMap* out) { + // Binary add returns a map containing the union of keys. + // Values with keys in the intersection are added. + out->tensors() = a.tensors(); + for (const std::pair& p : b.tensors()) { + absl::flat_hash_map::iterator it = + out->tensors().find(p.first); + if (it != out->tensors().end()) { + Tensor out_tensor; + TF_RETURN_IF_ERROR( + BinaryAddTensors(ctx, p.second, it->second, &out_tensor)); + it->second = out_tensor; + } else { + out->tensors().emplace(p.first, p.second); + } + } + return absl::OkStatus(); +} + +template +absl::Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x, + TensorMap* y) { + // Zeros like returns an empty map. + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_op.h new file mode 100644 index 00000000..94a39794 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_op.h @@ -0,0 +1,69 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/hash/hash.h" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +namespace tensorflow { +namespace functor { + +// Helpers to define tensor needed by MatMul op. +template +struct MatMulTypes { + typedef Eigen::TensorMap, Eigen::Aligned> + out_type; + typedef Eigen::TensorMap, + Eigen::Aligned> + in_type; +}; + +template +void MatMul(const Device& d, Out out, In0 in0, In1 in1, + const DimPair& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); +} + +template +struct MatMulFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, typename MatMulTypes::out_type out, + typename MatMulTypes::in_type in0, + typename MatMulTypes::in_type in1, + const Eigen::array, 1>& dim_pair); +}; + +} // end namespace functor + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +typedef Eigen::GpuDevice GPUDevice; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_op_impl.h new file mode 100644 index 00000000..50517dc9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_op_impl.h @@ -0,0 +1,1156 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/math_ops.cc. 
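matmul_op.h above expresses MatMul as a rank-2 Eigen tensor contraction. For readers new to that API, here is a self-contained sketch (it assumes the Eigen unsupported Tensor module is on the include path, as the header itself requires) showing that contracting dimension 1 of the left operand against dimension 0 of the right one is an ordinary matrix product.

#include <iostream>

#include "unsupported/Eigen/CXX11/Tensor"

int main() {
  Eigen::Tensor<float, 2> a(2, 3), b(3, 2);
  a.setValues({{1, 2, 3}, {4, 5, 6}});
  b.setValues({{1, 0}, {0, 1}, {1, 1}});
  // The same kind of dim_pair the MatMul functor receives: contract a's
  // second dimension with b's first one (no transposes).
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
  Eigen::Tensor<float, 2> c = a.contract(b, dims);
  std::cout << c << "\n";  // the 2x2 matrix product a * b
  return 0;
}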
+ +#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/bfloat16.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/matmul_autotune.h" +#include "tensorflow/core/util/matmul_bcast.h" +#include "tensorflow/core/util/work_sharder.h" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "xla/stream_executor/host_or_device_scalar.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/kernels/matmul_util.h" +#include "tensorflow/core/kernels/numeric_options_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_blas_lt.h" +#endif // GOOGLE_CUDA +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#if TF_HIPBLASLT +#include "xla/stream_executor/rocm/hip_blas_lt.h" +#endif +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace { + +// Returns the pair of dimensions along which to perform Tensor contraction to +// emulate matrix multiplication. +// For matrix multiplication of 2D Tensors X and Y, X is contracted along +// second dimension and Y is contracted along the first dimension (if neither X +// nor Y is adjointed). The dimension to contract along is switched when any +// operand is adjointed. +// See http://en.wikipedia.org/wiki/Tensor_contraction +inline Eigen::IndexPair ContractionDims(bool adj_x, + bool adj_y) { + return Eigen::IndexPair(adj_x ? 0 : 1, adj_y ? 1 : 0); +} + +// Parallel batch matmul kernel based on the multi-threaded tensor contraction +// in Eigen. +template +struct ParallelMatMulKernel { + static void Conjugate(const OpKernelContext* context, Tensor* out) { + const Eigen::ThreadPoolDevice d = context->eigen_cpu_device(); + auto z = out->tensor(); + z.device(d) = z.conjugate(); + } + + static void Run(const OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, + bool trans_y, const MatMulBCast& bcast, Tensor* out, + int batch_size) { + static_assert(IsComplex, "Complex type expected."); + auto Tx = in_x.tensor(); + auto Ty = in_y.tensor(); + auto Tz = out->tensor(); + // We use the identities + // conj(a) * conj(b) = conj(a * b) + // conj(a) * b = conj(a * conj(b)) + // to halve the number of cases. The final conjugation of the result is + // done at the end of LaunchBatchMatMul::Launch(). 
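Spelled out, the two identities the comment above relies on are plain complex-matrix algebra (a supplementary note, not upstream text), with the bar denoting elementwise conjugation:

\[
\overline{A}\,\overline{B} \;=\; \overline{A B},
\qquad
\overline{A}\,B \;=\; \overline{A\,\overline{B}},
\]

so when exactly one operand is adjointed only the right operand needs to be conjugated inside the contraction (the adj_x != adj_y branch below), and any remaining conjugation is applied once to the finished product by the Conjugate call at the end of LaunchBatchMatMul::Launch().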
+ Eigen::array, 1> contract_pairs; + contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y); + const Eigen::ThreadPoolDevice d = context->eigen_cpu_device(); + + const bool should_bcast = bcast.IsBroadcastingRequired(); + const auto& x_batch_indices = bcast.x_batch_indices(); + const auto& y_batch_indices = bcast.y_batch_indices(); + // TODO(rmlarsen): Consider launching these contractions asynchronously. + for (int64_t i = 0; i < batch_size; ++i) { + const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; + const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; + + auto x = Tx.template chip<0>(x_batch_index); + auto z = Tz.template chip<0>(i); + if (adj_x != adj_y) { + auto y = Ty.template chip<0>(y_batch_index).conjugate(); + z.device(d) = x.contract(y, contract_pairs); + } else { + auto y = Ty.template chip<0>(y_batch_index); + z.device(d) = x.contract(y, contract_pairs); + } + } + } +}; + +// The Eigen contraction kernel used here is very large and slow to compile, +// so we partially specialize ParallelMatMulKernel for real types to avoid all +// but one of the instantiations. +template +struct ParallelMatMulKernel { + static void Conjugate(const OpKernelContext* context, Tensor* out) {} + + static void Run(const OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, + bool trans_y, const MatMulBCast& bcast, Tensor* out, + int batch_size) { + const bool should_bcast = bcast.IsBroadcastingRequired(); + const Eigen::ThreadPoolDevice d = context->eigen_cpu_device(); + Eigen::array, 1> contract_pairs; + contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y); + if (batch_size == 1 && !should_bcast) { + auto Tx = in_x.flat_inner_dims(); + auto Ty = in_y.flat_inner_dims(); + auto Tz = out->flat_inner_dims(); + Tz.device(d) = Tx.contract(Ty, contract_pairs); + } else { + auto Tx = in_x.tensor(); + auto Ty = in_y.tensor(); + auto Tz = out->tensor(); + const auto& x_batch_indices = bcast.x_batch_indices(); + const auto& y_batch_indices = bcast.y_batch_indices(); + // TODO(rmlarsen): Consider launching these contractions asynchronously. + for (int64_t i = 0; i < batch_size; ++i) { + const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; + const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; + auto x = Tx.template chip<0>(x_batch_index); + auto y = Ty.template chip<0>(y_batch_index); + auto z = Tz.template chip<0>(i); + + z.device(d) = x.contract(y, contract_pairs); + } + } + } +}; + +// Basic y-combinator implementation. +template +struct YCombinatorImpl { + Func func; + template + decltype(auto) operator()(Args&&... args) const { + return func(*this, std::forward(args)...); + } +}; + +template +YCombinatorImpl> YCombinator(Func&& func) { + return YCombinatorImpl>{std::forward(func)}; +} + +// Sequential batch matmul kernel that calls the regular Eigen matmul. +// We prefer this over the tensor contraction because it performs +// better on vector-matrix and matrix-vector products. 
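The template heads of YCombinatorImpl and YCombinator above appear to have lost their parameter lists in this hunk (likely stripped angle-bracket text). The sketch below restores the conventional shape of this basic y-combinator and shows a toy use; treat it as a reconstruction under that assumption, not as the upstream source.

#include <cstdint>
#include <iostream>
#include <type_traits>
#include <utility>

// Reconstructed shape of the helper above: stores a callable and passes itself
// back in as the first argument, so a lambda can recurse without std::function.
template <typename Func>
struct YCombinatorImpl {
  Func func;
  template <typename... Args>
  decltype(auto) operator()(Args&&... args) const {
    return func(*this, std::forward<Args>(args)...);
  }
};

template <typename Func>
YCombinatorImpl<std::decay_t<Func>> YCombinator(Func&& func) {
  return YCombinatorImpl<std::decay_t<Func>>{std::forward<Func>(func)};
}

int main() {
  auto factorial = YCombinator([](auto& self, uint64_t n) -> uint64_t {
    return n <= 1 ? 1 : n * self(n - 1);
  });
  std::cout << factorial(5) << "\n";  // prints 120
  return 0;
}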
+template +struct SequentialMatMulKernel { + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using MatrixMap = Eigen::Map; + + static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t, + int slice) { + return ConstMatrixMap( + t.flat().data() + slice * t.dim_size(1) * t.dim_size(2), + t.dim_size(1), t.dim_size(2)); + } + + static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) { + return MatrixMap( + t->flat().data() + slice * t->dim_size(1) * t->dim_size(2), + t->dim_size(1), t->dim_size(2)); + } + + static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x, + bool adj_y, bool trans_x, bool trans_y, + const MatMulBCast& bcast, Tensor* out, int start, int limit) { + const bool should_bcast = bcast.IsBroadcastingRequired(); + const auto& x_batch_indices = bcast.x_batch_indices(); + const auto& y_batch_indices = bcast.y_batch_indices(); + for (int64_t i = start; i < limit; ++i) { + const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; + const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; + auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index); + auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index); + auto z = TensorSliceToEigenMatrix(out, i); + // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y + // and trans_y. + if (!adj_x && !trans_x) { + if (!adj_y && !trans_y) { + z.noalias() = x * y; + } else if (adj_y) { + z.noalias() = x * y.adjoint(); + } else { // trans_y == true + z.noalias() = x * y.transpose(); + } + } else if (adj_x) { + if (!adj_y && !trans_y) { + z.noalias() = x.adjoint() * y; + } else if (adj_y) { + z.noalias() = x.adjoint() * y.adjoint(); + } else { // trans_y == true + z.noalias() = x.adjoint() * y.transpose(); + } + } else { // trans_x == true + if (!adj_y && !trans_y) { + z.noalias() = x.transpose() * y; + } else if (adj_y) { + z.noalias() = x.transpose() * y.adjoint(); + } else { // trans_y == true + z.noalias() = x.transpose() * y.transpose(); + } + } + } + } +}; + +// For single-batch multiplications, manually parallize by splitting the output +// matrix. +template +struct SingleBatchParallelMatMulKernel { + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using MatrixMap = Eigen::Map; + + static ConstMatrixMap ConstTensorToEigenMatrix(const Tensor& t) { + return ConstMatrixMap(t.flat().data(), t.dim_size(1), + t.dim_size(2)); + } + + static MatrixMap TensorToEigenMatrix(Tensor* t) { + return MatrixMap(t->flat().data(), t->dim_size(1), t->dim_size(2)); + } + + static void Run(const CPUDevice& device, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, + bool trans_y, Tensor* out) { + using Eigen::Index; + Eigen::ThreadPoolInterface* pool = device.getPool(); + + Index m = (trans_x || adj_x) ? in_x.dim_size(2) : in_x.dim_size(1); + Index k = (trans_x || adj_x) ? in_x.dim_size(1) : in_x.dim_size(2); + Index n = (trans_y || adj_y) ? in_y.dim_size(1) : in_y.dim_size(2); + + auto x_mat = ConstTensorToEigenMatrix(in_x); + auto y_mat = ConstTensorToEigenMatrix(in_y); + auto out_mat = TensorToEigenMatrix(out); + + // Computes a block of the output matrix. + auto compute_matmul_block = [&x_mat, &y_mat, &out_mat, adj_x, trans_x, + adj_y, trans_y](Index row, Index col, + Index nrows, Index ncols) { + auto z = out_mat.block(row, col, nrows, ncols); + + // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y + // and trans_y. 
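The ConstMatrixMap/MatrixMap aliases at the top of SequentialMatMulKernel view slices of the input and output tensors as row-major Eigen matrices without copying. A self-contained sketch of that map-then-multiply pattern follows; buffers and shapes are illustrative.

#include <iostream>

#include "Eigen/Core"

int main() {
  using Matrix =
      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  float a_buf[6] = {1, 2, 3, 4, 5, 6};  // 2x3, row-major
  float b_buf[6] = {1, 0, 0, 1, 1, 1};  // 3x2, row-major
  float c_buf[4] = {};                  // 2x2 output written in place
  Eigen::Map<const Matrix> a(a_buf, 2, 3);
  Eigen::Map<const Matrix> b(b_buf, 3, 2);
  Eigen::Map<Matrix> c(c_buf, 2, 2);
  c.noalias() = a * b;  // result lands directly in c_buf, no aliasing temporary
  std::cout << c << "\n";
  return 0;
}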
+ if (!adj_x && !trans_x) { + auto x = x_mat.middleRows(row, nrows); + if (!adj_y && !trans_y) { + auto y = y_mat.middleCols(col, ncols); + z = x * y; + } else if (adj_y) { + auto y = y_mat.middleRows(col, ncols); + z.noalias() = x * y.adjoint(); + } else { // trans_y == true + auto y = y_mat.middleRows(col, ncols); + z.noalias() = x * y.transpose(); + } + } else if (adj_x) { + auto x = x_mat.middleCols(row, nrows); + if (!adj_y && !trans_y) { + auto y = y_mat.middleCols(col, ncols); + z.noalias() = x.adjoint() * y; + } else if (adj_y) { + auto y = y_mat.middleRows(col, ncols); + z.noalias() = x.adjoint() * y.adjoint(); + } else { // trans_y == true + auto y = y_mat.middleRows(col, ncols); + z.noalias() = x.adjoint() * y.transpose(); + } + } else { // trans_x == true + auto x = x_mat.middleCols(row, nrows); + if (!adj_y && !trans_y) { + auto y = y_mat.middleCols(col, ncols); + z.noalias() = x.transpose() * y; + } else if (adj_y) { + auto y = y_mat.middleRows(col, ncols); + z.noalias() = x.transpose() * y.adjoint(); + } else { // trans_y == true + auto y = y_mat.middleRows(col, ncols); + z.noalias() = x.transpose() * y.transpose(); + } + } + }; + + // Split the work across n threads, unless the total amount of work + // is small (e.g. 128 * 128) - in which case use fewer threads. This is + // the same heuristic value used in LaunchBatchMatMul below. + const int64_t kMaxCostOuterParallelism = 128 * 128; + Index work_limit = std::max((m * k * n) / pool->NumThreads(), + kMaxCostOuterParallelism); + // Blocks should have a size no smaller than 8 * kPacketSize, except perhaps + // for tail blocks. + constexpr int kPacketSize = Eigen::internal::packet_traits::size; + constexpr Index kBlockMin = 8 * kPacketSize; + + // Precompute how many blocks there will be. + auto compute_blocks = YCombinator([k, work_limit, kBlockMin]( + auto& compute_blocks, Index row, + Index col, Index nrows, + Index ncols) -> Index { + Index work = nrows * k * ncols; + Index blocks = 0; + while (work > work_limit && (nrows > kBlockMin || ncols > kBlockMin)) { + if (nrows > ncols) { + Index half = Eigen::divup(nrows / 2, kBlockMin) * kBlockMin; + blocks += 1 + compute_blocks(row + half, col, nrows - half, ncols); + nrows = half; + } else { + Index half = Eigen::divup(ncols / 2, kBlockMin) * kBlockMin; + blocks += 1 + compute_blocks(row, col + half, nrows, ncols - half); + ncols = half; + } + work = nrows * k * ncols; + } + return blocks; + }); + Index total_blocks = 1 + compute_blocks(0, 0, m, n); + + // Recursively split work according to the exact same heuristic as above. + Eigen::Barrier barrier(total_blocks); + auto handle_range = YCombinator( + [k, pool, &barrier, work_limit, kBlockMin, &compute_matmul_block]( + auto& handle_range, Index row, Index col, Index nrows, + Index ncols) -> void { + Index work = nrows * k * ncols; + while (work > work_limit && + (nrows > kBlockMin || ncols > kBlockMin)) { + if (nrows > ncols) { + Index half = Eigen::divup(nrows / 2, kBlockMin) * kBlockMin; + pool->Schedule([&handle_range, row, half, col, nrows, ncols]() { + handle_range(row + half, col, nrows - half, ncols); + }); + nrows = half; + } else { + Index half = Eigen::divup(ncols / 2, kBlockMin) * kBlockMin; + pool->Schedule([&handle_range, row, half, col, nrows, ncols]() { + handle_range(row, col + half, nrows, ncols - half); + }); + ncols = half; + } + work = nrows * k * ncols; + } + + if (nrows > 0 && ncols > 0) { + // Compute the output block. 
+ compute_matmul_block(row, col, nrows, ncols); + } + barrier.Notify(); + }); + handle_range(0, 0, m, n); + barrier.Wait(); + } +}; + +} // namespace + +template +struct LaunchBatchMatMul; + +template +struct LaunchBatchMatMul { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, + bool trans_y, bool grad_x, bool grad_y, + const MatMulBCast& bcast, Tensor* out) { + typedef ParallelMatMulKernel::IsComplex> + ParallelMatMulKernel; + bool conjugate_result = false; + + // Number of matrix multiplies i.e. size of the batch. + const int64_t batch_size = bcast.output_batch_size(); + const int64_t cost_per_unit = + in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2); + const int64_t small_dim = std::min( + std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2)); + // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of + // Jan 21, 2020. + const int64_t kMaxCostOuterParallelism = 128 * 128; // heuristic. + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous + // evaluation in Eigen Tensor. + if (small_dim > 1 && + (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) { + // Parallelize over inner dims. + // For large matrix products it is counter-productive to parallelize + // over the batch dimension. + ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x, + trans_y, bcast, out, batch_size); + conjugate_result = adj_x; + } else if (batch_size > 1) { + // Parallelize over outer dims. For small matrices and large batches, it + // is counter-productive to parallelize the inner matrix multiplies. + Shard(worker_threads.num_threads, worker_threads.workers, batch_size, + cost_per_unit, + [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out]( + int start, int limit) { + SequentialMatMulKernel::Run(in_x, in_y, adj_x, adj_y, + trans_x, trans_y, bcast, out, + start, limit); + }); + } else if (cost_per_unit > kMaxCostOuterParallelism) { + // Split along output blocks. + SingleBatchParallelMatMulKernel::Run(context->eigen_cpu_device(), + in_x, in_y, adj_x, adj_y, + trans_x, trans_y, out); + } else { + // Single small multiplication. + SequentialMatMulKernel::Run(in_x, in_y, adj_x, adj_y, trans_x, + trans_y, bcast, out, 0, batch_size); + } + + if (conjugate_result) { + // We used one of the identities + // conj(a) * conj(b) = conj(a * b) + // conj(a) * b = conj(a * conj(b)) + // above, we need to conjugate the final output. This is a + // no-op for non-complex types. + ParallelMatMulKernel::Conjugate(context, out); + } + } +}; + +#if GOOGLE_CUDA || TF_HIPBLASLT + +namespace { +// A dummy type to group matmul autotune results together. 
+struct BlasLtMatmulAutoTuneGroup { + static string name() { return "MatmulLt"; } +}; + +typedef AutotuneSingleton> + AutoTuneBatchMatmul; + +} // namespace + +#endif // GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +class BlasScratchAllocator : public se::ScratchAllocator { + public: + using Stream = se::Stream; + using DeviceMemoryBytes = se::DeviceMemory; + + BlasScratchAllocator(OpKernelContext* context) + : memory_limit_(0), total_byte_size_(0), context_(context) {} + + BlasScratchAllocator(OpKernelContext* context, int64_t memory_limit) + : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} + + int64_t GetMemoryLimitInBytes() override { return memory_limit_; } + + tsl::StatusOr AllocateBytes(int64_t byte_size) override { + Tensor temporary_memory; + + if (memory_limit_ > 0 && byte_size > memory_limit_) { + return tsl::Status{ + absl::StatusCode::kUnavailable, + absl::StrCat("Requested memory size (", byte_size, + ") exceeds the memory limit (", memory_limit_, ").")}; + } + AllocationAttributes allocation_attr; + allocation_attr.retry_on_failure = false; + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory)); + if (!allocation_status.ok()) { + return tsl::Status{ + absl::StatusCode::kUnavailable, + absl::StrCat("Failed to allocate requested memory of (", byte_size, + ").")}; + } + // Hold the reference of the allocated tensors until the end of the + // allocator. + allocated_tensors_.push_back(temporary_memory); + total_byte_size_ += byte_size; + return tsl::StatusOr(DeviceMemoryBytes::MakeFromByteSize( + temporary_memory.flat().data(), + temporary_memory.flat().size())); + } + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64_t memory_limit_; + int64_t total_byte_size_; + OpKernelContext* context_; + std::vector allocated_tensors_; +}; + +template +struct LaunchBatchMatMul { + static void Launch(OpKernelContext* context, const Tensor& in_x, + const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, + bool trans_y, bool grad_x, bool grad_y, + const MatMulBCast& bcast, Tensor* out) { + se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, + se::blas::Transpose::kTranspose, + se::blas::Transpose::kConjugateTranspose}; + const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1); + const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2); + const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2); + const int64_t batch_size = bcast.output_batch_size(); + auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)]; + auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 
1 : 0)]; + + auto* stream = context->op_device_context()->stream(); + OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); + + typedef se::DeviceMemory DeviceMemoryType; + std::vector a_device_memory; + std::vector b_device_memory; + std::vector c_device_memory; + std::vector a_ptrs; + std::vector b_ptrs; + std::vector c_ptrs; + a_device_memory.reserve(bcast.x_batch_size()); + b_device_memory.reserve(bcast.y_batch_size()); + c_device_memory.reserve(batch_size); + a_ptrs.reserve(batch_size); + b_ptrs.reserve(batch_size); + c_ptrs.reserve(batch_size); + auto* a_base_ptr = in_x.template flat().data(); + auto* b_base_ptr = in_y.template flat().data(); + auto* c_base_ptr = out->template flat().data(); + uint64 a_stride; + uint64 b_stride; + uint64 c_stride; + + bool is_full_broadcast = + std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; + + // Use float as coefficient type for half and bfloat16 precision inputs, + // otherwise use the input type. + constexpr bool is_16bit_input = std::is_same_v || + std::is_same_v; + using Coefficient = std::conditional_t; + + se::blas::CallContext call_context = se::blas::CallContext::kNone; + OP_REQUIRES(context, grad_x == false || grad_y == false, + errors::InvalidArgument( + "At least 1 of grad_x and grad_y shall be false")); + if (grad_x) { + call_context = se::blas::CallContext::kBackpropInput1; + } + if (grad_y) { + call_context = se::blas::CallContext::kBackpropInput2; + } +#if GOOGLE_CUDA || TF_HIPBLASLT + static const bool use_autotune = MatmulAutotuneEnable(); + bool bCublasLtSupport = true; + + const auto& cc = + stream->parent()->GetDeviceDescription().gpu_compute_capability(); + if (auto* procm = std::get_if(&cc)) { + bCublasLtSupport = procm->gfx9_mi200_or_later(); + } + + if (EnableCublasLtGemm() && bCublasLtSupport) { + static const int64_t max_scratch_size = + GetWorkspaceLimit(1LL << 32); // 4GB by default + + bool requires_mixed_broadcasting = + bcast.IsBroadcastingRequired() && !is_full_broadcast; + + if (!requires_mixed_broadcasting) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + + BlasLtMatmulPlanParams matmul_params{ + se::blas::ToDataType::value, + static_cast(m), + static_cast(n), + static_cast(k), + blas_transpose_a, + blas_transpose_b, + static_cast(batch_size), + /*broadcast_a=*/bcast.x_batch_size() == 1, + /*broadcast_b=*/bcast.y_batch_size() == 1}; + + std::optional max_algorithm_count; + if (!use_autotune) max_algorithm_count = 1; + absl::Mutex* pmu = nullptr; + auto plan_and_algorithms_or = PlanAndAlgorithms::GetOrCreate( + stream, matmul_params, &pmu, max_algorithm_count); + OP_REQUIRES_OK(context, plan_and_algorithms_or.status()); + absl::MutexLock lock(pmu); + const auto* plan_and_algorithms = + std::move(plan_and_algorithms_or).value(); + auto n_algorithms = plan_and_algorithms->algorithms.size(); + + se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm); + if (!use_autotune) { + algorithm_config.set_algorithm(0); + } else if (!AutoTuneBatchMatmul::GetInstance()->Find( + matmul_params, &algorithm_config)) { + VLOG(4) << "Autotuning BlasLtMatmul over " << n_algorithms + << " algorithms."; + se::blas::ProfileResult best_result; + se::blas::ProfileResult profile_result; + + for (size_t i = 0; i != n_algorithms; ++i) { + // 
Create a new scratch allocator with every autotuning run so that + // scratch space is deallocated between runs. + BlasScratchAllocator scratch_allocator(context, max_scratch_size); + Status cublas_launch_status = plan_and_algorithms->ExecuteOnStream( + stream, *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, + scratch_allocator, se::DeviceMemoryBase{}, &profile_result); + + VLOG(4) << " Autotune algorithm " << i + << " result: " << profile_result.elapsed_time_in_ms() + << " ms, valid=" << profile_result.is_valid() + << ", workspace_size=" + << plan_and_algorithms->algorithms[i].workspace_size; + + if (cublas_launch_status.ok() && profile_result.is_valid() && + profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + // Use index into algorithms array, instead of cublas internal ID. + best_result.set_algorithm(i); + } + } + + if (best_result.is_valid()) { + algorithm_config.set_algorithm(best_result.algorithm()); + } + // Each matmul parameter set gets one pass of + // autotune. If no algorithms works, kNoAlgorithm is added to the + // autotune map. + AutoTuneBatchMatmul::GetInstance()->Insert(matmul_params, + algorithm_config); + } + se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm(); + OP_REQUIRES(context, 0 <= algorithm_idx && algorithm_idx < n_algorithms, + errors::Internal("Missing/invalid BatchMatmul algorithm")); + BlasScratchAllocator scratch_allocator(context, max_scratch_size); + VLOG(4) << "Calling BlasLtMatMul: a.shape=(" << bcast.x_batch_size() + << ", " << in_x.dim_size(1) << ", " << in_x.dim_size(2) + << "), b.shape=(" << bcast.y_batch_size() << ", " + << in_y.dim_size(1) << ", " << in_y.dim_size(2) << "), m=" << m + << ", n=" << n << ", k=" << k << ", batch_size=" << batch_size + << "trans_x = " << trans_x << "trans_y = " << trans_y + << "adj_x = " << adj_x << "adj_y = " << adj_y; + + OP_REQUIRES_OK(context, plan_and_algorithms->ExecuteOnStream( + stream, *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], + algorithm_idx, scratch_allocator)); + } else { // requires mixed broadcasting + const std::vector& a_batch_indices = bcast.x_batch_indices(); + const std::vector& b_batch_indices = bcast.y_batch_indices(); + for (int64_t i = 0; i < bcast.x_batch_size(); ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + } + for (int64_t i = 0; i < bcast.y_batch_size(); ++i) { + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + } + for (int64_t i = 0; i < batch_size; ++i) { + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); + b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); + c_ptrs.push_back(&c_device_memory.back()); + } + + BlasScratchAllocator scratch_allocator(context, max_scratch_size); + auto blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No blas support for stream")); + bool blas_launch_status = blas->DoBlasGemmBatched( + stream, blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, + a_ptrs, adj_x || trans_x ? 
m : k, static_cast(0.0), + c_ptrs, n, batch_size, GetNumericOptions(), &scratch_allocator, + call_context); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMMBatched launch failed: a.shape=", + in_x.shape().DebugString(), + ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } + } else { +#endif // GOOGLE_CUDA + bool use_strided_batched = + (!bcast.IsBroadcastingRequired() || is_full_broadcast) && + batch_size > 1; + if (use_strided_batched) { + a_stride = bcast.x_batch_size() != 1 ? m * k : 0; + b_stride = bcast.y_batch_size() != 1 ? k * n : 0; + c_stride = m * n; + a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + } else if (!bcast.IsBroadcastingRequired()) { + for (int64_t i = 0; i < batch_size; ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory.back()); + b_ptrs.push_back(&b_device_memory.back()); + c_ptrs.push_back(&c_device_memory.back()); + } + } else { + const std::vector& a_batch_indices = bcast.x_batch_indices(); + const std::vector& b_batch_indices = bcast.y_batch_indices(); + for (int64_t i = 0; i < bcast.x_batch_size(); ++i) { + a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); + } + for (int64_t i = 0; i < bcast.y_batch_size(); ++i) { + b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); + } + for (int64_t i = 0; i < batch_size; ++i) { + c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); + a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); + b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); + c_ptrs.push_back(&c_device_memory.back()); + } + } + + // Blas does + // C = A x B + // where A, B and C are assumed to be in column major. + // We want the output to be in row-major, so we can compute + // C' = B' x A', where ' stands for transpose (not adjoint). + // TODO(yangzihao): Choose the best of the three strategies using + // autotune. + auto blas = stream->parent()->AsBlas(); + OP_REQUIRES(context, blas != nullptr, + absl::InternalError("No blas support for stream")); + if (batch_size == 1) { + // This is a regular matrix*matrix or matrix*vector multiply. Avoid the + // overhead of the scratch allocator and the batch interface. + // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS + if constexpr (!std::is_same_v && + !std::is_same_v) { + if (n == 1 && + blas_transpose_b != se::blas::Transpose::kConjugateTranspose && + blas_transpose_a != se::blas::Transpose::kConjugateTranspose) { + // This is a matrix*vector multiply so use GEMV to compute A * b. + // Here we are multiplying in the natural order, so we have to flip + // the transposition flag to compensate for the tensor being stored + // row-major. Since GEMV doesn't provide a way to just conjugate an + // argument, we have to defer those cases to GEMM below. + auto gemv_trans_a = + blas_transpose_a == se::blas::Transpose::kTranspose + ? se::blas::Transpose::kNoTranspose + : se::blas::Transpose::kTranspose; + bool blas_launch_status = blas->DoBlasGemv( + stream, gemv_trans_a, adj_x || trans_x ? 
m : k, + adj_x || trans_x ? k : m, static_cast(1.0), + *(a_ptrs[0]), adj_x || trans_x ? m : k, *(b_ptrs[0]), 1, + static_cast(0.0), c_ptrs[0], 1); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMV launch failed : a.shape=", + in_x.shape().DebugString(), ", b.shape=", + in_y.shape().DebugString(), ", m=", m, ", n=", n, ", k=", k)); + } + return; + } + } + + OP_REQUIRES_OK( + context, + blas->BlasGemm(stream, blas_transpose_b, blas_transpose_a, n, m, k, + *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]), + adj_x || trans_x ? m : k, c_ptrs[0], n, + GetNumericOptions(), call_context)); + } else if (use_strided_batched) { + OP_REQUIRES_OK( + context, blas->BlasGemmStridedBatched( + stream, blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), *b_ptrs[0], + adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], + adj_x || trans_x ? m : k, a_stride, + static_cast(0.0), c_ptrs[0], n, c_stride, + batch_size, GetNumericOptions(), call_context)); + } else { + BlasScratchAllocator scratch_allocator(context); + bool blas_launch_status = blas->DoBlasGemmBatched( + stream, blas_transpose_b, blas_transpose_a, n, m, k, + static_cast(1.0), b_ptrs, adj_y || trans_y ? k : n, + a_ptrs, adj_x || trans_x ? m : k, static_cast(0.0), + c_ptrs, n, batch_size, GetNumericOptions(), &scratch_allocator, + call_context); + if (!blas_launch_status) { + context->SetStatus(errors::Internal( + "Blas xGEMMBatched launch failed : a.shape=", + in_x.shape().DebugString(), + ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n, + ", k=", k, ", batch_size=", batch_size)); + } + } +#if GOOGLE_CUDA || TF_HIPBLASLT + } +#endif // GOOGLE_CUDA + } +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +inline void FastConvertToFloat(const T* src, float* dst, int64_t size) { + Eigen::Map> src_eigen(src, size); + Eigen::Map dst_eigen(dst, size); + dst_eigen = src_eigen.template cast(); +} + +template +inline void FastConvertFromFloat(const float* src, T* dst, int64_t size) { + Eigen::Map src_eigen(src, size); + Eigen::Map> dst_eigen(dst, size); + dst_eigen = src_eigen.template cast(); +} + +template <> +inline void FastConvertToFloat(const bfloat16* src, float* dst, + int64_t size) { + BFloat16ToFloat(src, dst, size); +} + +template <> +inline void FastConvertFromFloat(const float* src, bfloat16* dst, + int64_t size) { + FloatToBFloat16(src, dst, size); +} + +template +class BaseBatchMatMulOp : public OpKernel { + public: + explicit BaseBatchMatMulOp(OpKernelConstruction* context, + bool is_legacy_matmul) + : OpKernel(context) { + if (is_legacy_matmul) { + // The old MatMul kernel has "transpose_a/transpose_b" attributes. 
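// Illustrative sketch (not part of this patch): why the LaunchBatchMatMul code
// above calls the column-major BLAS routines with the operands swapped
// ("C' = B' x A'"). A column-major GEMM that computes B^T x A^T into a
// column-major buffer produces exactly the bytes of a row-major A x B, so no
// explicit transpose or copy is needed. Naive loops, plain C++, no BLAS
// dependency; small dense matrices chosen only for clarity.
#include <cassert>
#include <cstdio>
#include <vector>

// Column-major GEMM: C(m x n) = A(m x k) * B(k x n), leading dims = row counts.
static void GemmColMajor(int m, int n, int k, const std::vector<float>& A,
                         const std::vector<float>& B, std::vector<float>* C) {
  C->assign(static_cast<size_t>(m) * n, 0.0f);
  for (int j = 0; j < n; ++j)
    for (int p = 0; p < k; ++p)
      for (int i = 0; i < m; ++i)
        (*C)[i + j * m] += A[i + p * m] * B[p + j * k];
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C = A * B (2x2).
  std::vector<float> a = {1, 2, 3, 4, 5, 6};     // A[i][j] = a[i*3 + j]
  std::vector<float> b = {7, 8, 9, 10, 11, 12};  // B[i][j] = b[i*2 + j]
  // Reinterpreted as column-major, the same buffers already hold A^T (3x2)
  // and B^T (2x3). Compute column-major C' = B^T * A^T (2x2):
  std::vector<float> c;
  GemmColMajor(/*m=*/2, /*n=*/2, /*k=*/3, /*A=*/b, /*B=*/a, &c);
  // Reading c as row-major gives C = A * B directly.
  assert(c[0 * 2 + 0] == 58 && c[0 * 2 + 1] == 64);
  assert(c[1 * 2 + 0] == 139 && c[1 * 2 + 1] == 154);
  std::printf("C = [[%g, %g], [%g, %g]]\n", c[0], c[1], c[2], c[3]);
  return 0;
}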
+ OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_)); + OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_)); + adj_x_ = false; + adj_y_ = false; + OP_REQUIRES_OK(context, context->GetAttr("grad_a", &grad_input_1_)); + OP_REQUIRES_OK(context, context->GetAttr("grad_b", &grad_input_2_)); + } else { + OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_)); + OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_)); + trans_x_ = false; + trans_y_ = false; + OP_REQUIRES_OK(context, context->GetAttr("grad_x", &grad_input_1_)); + OP_REQUIRES_OK(context, context->GetAttr("grad_y", &grad_input_2_)); + } + } + + ~BaseBatchMatMulOp() override {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& in0 = ctx->input(0); + const Tensor& in1 = ctx->input(1); + + const absl::Status s = ValidateInputTensors(ctx, in0, in1); + if (!s.ok()) { + ctx->SetStatus(s); + return; + } + + MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); + OP_REQUIRES( + ctx, bcast.IsValid(), + errors::InvalidArgument( + "In[0] and In[1] must have compatible batch dimensions: ", + in0.shape().DebugString(), " vs. ", in1.shape().DebugString())); + + TensorShape out_shape = bcast.output_batch_shape(); + auto batch_size = bcast.output_batch_size(); + auto d0 = in0.dim_size(in0.dims() - 2); + auto d1 = in0.dim_size(in0.dims() - 1); + Tensor in0_reshaped; + OP_REQUIRES( + ctx, + in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})), + errors::Internal("Failed to reshape In[0] from ", + in0.shape().DebugString())); + auto d2 = in1.dim_size(in1.dims() - 2); + auto d3 = in1.dim_size(in1.dims() - 1); + Tensor in1_reshaped; + OP_REQUIRES( + ctx, + in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})), + errors::Internal("Failed to reshape In[1] from ", + in1.shape().DebugString())); + if (adj_x_ || trans_x_) std::swap(d0, d1); + if (adj_y_ || trans_y_) std::swap(d2, d3); + OP_REQUIRES( + ctx, d1 == d2, + errors::InvalidArgument( + "Matrix size-incompatible: In[0]: ", in0.shape().DebugString(), + ", In[1]: ", in1.shape().DebugString())); + OP_REQUIRES_OK(ctx, out_shape.AddDimWithStatus(d0)); + OP_REQUIRES_OK(ctx, out_shape.AddDimWithStatus(d3)); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); + if (out->NumElements() == 0) { + return; + } + if (in0.NumElements() == 0 || in1.NumElements() == 0) { + functor::SetZeroFunctor f; + f(ctx->eigen_device(), out->flat()); + return; + } + Tensor out_reshaped; + OP_REQUIRES(ctx, + out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})), + errors::Internal("Failed to reshape output from ", + out->shape().DebugString())); + + // b/307285203: There seems to be an overly aggressive compiler optimization + // that optimizes away these data pointers unless we explicitly check them. 
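// Illustrative sketch (not part of this patch): the shape bookkeeping that
// Compute() performs above, shown standalone. Each operand is viewed as
// [batch, rows, cols]; an adjoint/transpose flag swaps the last two dims, the
// contracted dimensions must agree, and the output is [batch, d0, d3].
// Batch broadcasting (handled by MatMulBCast in the real op) is omitted, and
// error reporting is reduced to a bool; the helper name is hypothetical.
#include <array>
#include <cstdint>
#include <cstdio>
#include <utility>

struct MatmulDims {
  int64_t batch, rows, cols;
};

// Returns false if the contracted dimensions are incompatible.
static bool BatchMatmulOutputShape(std::array<int64_t, 3> x,  // [b, d0, d1]
                                   std::array<int64_t, 3> y,  // [b, d2, d3]
                                   bool adj_or_trans_x, bool adj_or_trans_y,
                                   MatmulDims* out) {
  if (adj_or_trans_x) std::swap(x[1], x[2]);
  if (adj_or_trans_y) std::swap(y[1], y[2]);
  if (x[2] != y[1]) return false;  // inner dimensions must match
  *out = {x[0], x[1], y[2]};       // [batch, d0, d3]
  return true;
}

int main() {
  MatmulDims out;
  // x stored as [8, 32, 64] with trans_x, y as [8, 32, 16]  ->  [8, 64, 16].
  if (BatchMatmulOutputShape({8, 32, 64}, {8, 32, 16},
                             /*adj_or_trans_x=*/true, /*adj_or_trans_y=*/false,
                             &out)) {
    std::printf("out = [%lld, %lld, %lld]\n",
                static_cast<long long>(out.batch),
                static_cast<long long>(out.rows),
                static_cast<long long>(out.cols));
  }
  return 0;
}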
+ OP_REQUIRES(ctx, + in0_reshaped.data() != nullptr && + in1_reshaped.data() != nullptr && + out_reshaped.data() != nullptr, + absl::InternalError("Null data pointer encountered.")); + if constexpr (std::is_same_v && std::is_same_v && + (std::is_same_v || + std::is_same_v)) { + Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(), + &in0_reshaped_float)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(), + &in1_reshaped_float)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(), + &out_reshaped_float)); + + // TODO: Avoid extra copy to make (b)float16 matmul efficient on CPU. + FastConvertToFloat(in0_reshaped.flat().data(), + in0_reshaped_float.flat().data(), + in0_reshaped.NumElements()); + FastConvertToFloat(in1_reshaped.flat().data(), + in1_reshaped_float.flat().data(), + in1_reshaped.NumElements()); + + LaunchBatchMatMul::Launch( + ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_, + trans_y_, grad_input_1_, grad_input_2_, bcast, &out_reshaped_float); + FastConvertFromFloat(out_reshaped_float.flat().data(), + out_reshaped.flat().data(), + out->NumElements()); + } else { + // Cast tensor to desired type to reuse Eigen. + // TODO(b/178749687): remove this cast if Eigen supports this natively. + if constexpr (!std::is_same::value) { + in0_reshaped = CastTensor(in0_reshaped); + } + if constexpr (!std::is_same::value) { + in1_reshaped = CastTensor(in1_reshaped); + } + LaunchBatchMatMul::Launch( + ctx, in0_reshaped, in1_reshaped, adj_x_, adj_y_, trans_x_, trans_y_, + grad_input_1_, grad_input_2_, bcast, &out_reshaped); + } + } + + protected: + virtual absl::Status ValidateInputTensors(OpKernelContext* ctx, + const Tensor& in0, + const Tensor& in1) = 0; + + private: + // TODO(171979567) Make the ops take both adj and transpose attributes. + bool adj_x_ = false; + bool adj_y_ = false; + bool trans_x_ = false; + bool trans_y_ = false; + bool grad_input_1_ = false; + bool grad_input_2_ = false; + + // Cast `t` from `SrcT` to `DstT`. + template + Tensor CastTensor(const Tensor& t) { + Tensor res = Tensor(DataTypeToEnum::v(), t.shape()); + res.flat() = t.flat().template cast(); + return res; + } +}; + +// BatchMatMul Op implementation which disallows broadcasting. +template +class BatchMatMulOp : public BaseBatchMatMulOp { + public: + explicit BatchMatMulOp(OpKernelConstruction* context) + : BaseBatchMatMulOp(context, is_legacy_matmul) {} + + ~BatchMatMulOp() override {} + + private: + absl::Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) override { + // Disallow broadcasting support. Ensure that all batch dimensions of the + // input tensors match. + if (in0.dims() != in1.dims()) { + return errors::InvalidArgument( + "In[0] and In[1] has different ndims: ", in0.shape().DebugString(), + " vs. 
", in1.shape().DebugString()); + } + const int ndims = in0.dims(); + if (is_legacy_matmul) { + if (ndims != 2) { + return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ", + ndims); + } + } else { + if (ndims < 2) { + return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", + ndims); + } + for (int i = 0; i < ndims - 2; ++i) { + if (in0.dim_size(i) != in1.dim_size(i)) { + return errors::InvalidArgument( + "In[0].dim(", i, ") and In[1].dim(", i, + ") must be the same: ", in0.shape().DebugString(), " vs ", + in1.shape().DebugString()); + } + } + } + return absl::OkStatus(); + } +}; + +// BatchMatMul Op implementation with broadcasting support. +template +class BatchMatMulV2Op : public BaseBatchMatMulOp { + public: + explicit BatchMatMulV2Op(OpKernelConstruction* context) + : BaseBatchMatMulOp(context, + /* is_legacy_matmul= */ false) { + } + + ~BatchMatMulV2Op() override {} + + private: + absl::Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) override { + // Enable broadcasting support. Validity of broadcasting is checked in + // BaseBatchMatMulOp. + if (in0.dims() < 2) { + return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims()); + } + if (in1.dims() < 2) { + return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims()); + } + return absl::OkStatus(); + } +}; + +// Register for MatMul, BatchMatMul, BatchMatMulv2 where Tin = Tout. +#define REGISTER_BATCH_MATMUL_CPU(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ + BatchMatMulOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint("T"), \ + BatchMatMulV2Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_CPU).TypeConstraint("T"), \ + BatchMatMulOp) + +#define REGISTER_BATCH_MATMUL_GPU(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint("T"), \ + BatchMatMulOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint("T"), \ + BatchMatMulV2Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_GPU).TypeConstraint("T"), \ + BatchMatMulOp) + +// Register for BatchMatMulv3 where Ta, Tb and Tout are not the same. +#define REGISTER_BATCH_MATMUL_TOUT_CPU(Ta, Tb, Tout) \ + REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("Ta") \ + .TypeConstraint("Tb") \ + .TypeConstraint("Tout"), \ + BatchMatMulV2Op) + +#define REGISTER_BATCH_MATMUL_TOUT_GPU(Ta, Tb, Tout) \ + REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("Ta") \ + .TypeConstraint("Tb") \ + .TypeConstraint("Tout"), \ + BatchMatMulV2Op) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_util.h new file mode 100644 index 00000000..0b73b881 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/matmul_util.h @@ -0,0 +1,88 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_ + +#include +#include + +#if TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif + +#if GOOGLE_CUDA || TF_HIPBLASLT + +#include "absl/container/flat_hash_map.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/core/framework/types.h" +#include "tsl/platform/types.h" + +namespace tensorflow { + +// Get a workspace limit from the environment variable, which is in MB. +// Return the workspace memory limit in bytes. If no value is set, return the +// default value. +int64_t GetWorkspaceLimit(int64_t default_value_in_bytes); + +struct BlasLtMatmulPlanParams { + std::string ToString() const; + bool operator==(const BlasLtMatmulPlanParams& other) const; + + se::blas::DataType dtype; + size_t m; + size_t n; + size_t k; + se::blas::Transpose trans_a; + se::blas::Transpose trans_b; + size_t batch_count = 1; + bool broadcast_a = false; + bool broadcast_b = false; + se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault; +}; + +struct PlanAndAlgorithms { + static StatusOr GetOrCreate( + se::Stream* stream, const BlasLtMatmulPlanParams& params, + absl::Mutex** pmu, std::optional max_algorithm_count = std::nullopt); + + Status ExecuteOnStream( + se::Stream* stream, const se::DeviceMemoryBase& a, + const se::DeviceMemoryBase& b, se::DeviceMemoryBase& c, + size_t algorithm_idx, se::ScratchAllocator& scratch_allocator, + const se::DeviceMemoryBase& bias = se::DeviceMemoryBase{}, + se::blas::ProfileResult* profile_result = nullptr) const; + + se::gpu::BlasLt::MatmulPlanPtr plan; + std::vector algorithms; +}; + +namespace internal { + +inline auto AsTuple(const BlasLtMatmulPlanParams& p) { + return std::make_tuple(p.dtype, p.m, p.n, p.k, p.trans_a, p.trans_b, + p.batch_count, p.broadcast_a, p.broadcast_b, + p.epilogue); +} + +} // namespace internal + +template +H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) { + return H::combine(std::move(h), internal::AsTuple(params)); +} + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TF_HIPBLASLT + +#endif // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/maxpooling_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/maxpooling_op.h new file mode 100644 index 00000000..7c1d91d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/maxpooling_op.h @@ -0,0 +1,55 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_ +#define TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_ +// Functor definition for MaxPoolingOp, must be compilable by nvcc. + +#include "xla/tsl/framework/fixedpoint/FixedPoint.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/kernels/eigen_pooling.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +template +struct SpatialMaxPooling { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, int window_rows, + int window_cols, int row_stride, int col_stride, + const Eigen::PaddingType& padding) { + // Because we swap the layout, we swap the row/cols as well + output.swap_layout().device(d) = + Eigen::SpatialMaxPooling(input.swap_layout(), window_cols, window_rows, + col_stride, row_stride, padding); + } +}; + +template +struct SpatialMaxPooling { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, int window_rows, + int window_cols, int row_stride, int col_stride, + const Eigen::PaddingType& padding) {} +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/maxpooling_op_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/maxpooling_op_gpu.h new file mode 100644 index 00000000..650a01e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/maxpooling_op_gpu.h @@ -0,0 +1,86 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support +#endif + +#ifndef TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace functor { +// Run the forward pass of max pooling, optionally writing the argmax indices to +// the mask array, if it is not nullptr. If mask is passed in as nullptr, the +// argmax indices are not written. 
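// Illustrative sketch (not part of this patch): a CPU reference for the
// forward-with-optional-argmax contract declared below. For every output cell
// it takes the max over the (possibly padded) window and, when a mask pointer
// is supplied, records the flat index of that max so a backward pass can
// scatter gradients. Single image, single channel, plain C++; the real GPU
// functor additionally handles batch, channels, NaN propagation and the
// include_batch_in_index option.
#include <cstdint>
#include <limits>
#include <vector>

static void MaxPool2DReference(const std::vector<float>& in, int height,
                               int width, int kernel_h, int kernel_w,
                               int stride_h, int stride_w, int pad_t, int pad_l,
                               int out_h, int out_w, std::vector<float>* out,
                               std::vector<int64_t>* mask /* may be null */) {
  out->assign(static_cast<size_t>(out_h) * out_w,
              -std::numeric_limits<float>::infinity());
  if (mask) mask->assign(out->size(), -1);
  for (int oy = 0; oy < out_h; ++oy) {
    for (int ox = 0; ox < out_w; ++ox) {
      float best = -std::numeric_limits<float>::infinity();
      int64_t best_idx = -1;
      for (int ky = 0; ky < kernel_h; ++ky) {
        for (int kx = 0; kx < kernel_w; ++kx) {
          const int iy = oy * stride_h - pad_t + ky;
          const int ix = ox * stride_w - pad_l + kx;
          // Padded positions are simply skipped.
          if (iy < 0 || iy >= height || ix < 0 || ix >= width) continue;
          const int64_t idx = static_cast<int64_t>(iy) * width + ix;
          if (in[idx] > best) {
            best = in[idx];
            best_idx = idx;
          }
        }
      }
      (*out)[static_cast<size_t>(oy) * out_w + ox] = best;
      if (mask) (*mask)[static_cast<size_t>(oy) * out_w + ox] = best_idx;
    }
  }
}
// Example: a 4x4 input with a 2x2 kernel, stride 2 and no padding yields a
// 2x2 output plus, if requested, the four winning input indices.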
+template +struct MaxPoolForwardWithOptionalArgmax { + bool operator()(const T* bottom_data, const int batch, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_t, const int pad_l, T* top_data, int64_t* mask, + const Eigen::GpuDevice& d, bool propagate_nans, + const bool include_batch_in_index); +}; + +struct MaxPoolForwardNoMask_NCHW_VECT_C { + bool operator()(const int32* bottom_data, const int batch, const int height, + const int width, int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_t, const int pad_l, int32* top_data, + const Eigen::GpuDevice& d); +}; + +template +struct MaxPoolBackwardWithArgmax { + bool operator()(const int output_size, const int input_size, + const T* top_diff, const int64_t* mask, const int top_offset, + const int bottom_offset, T* bottom_diff, + const Eigen::GpuDevice& d, const bool include_batch_in_index); +}; + +template +struct MaxPoolGradBackwardWithArgmax { + bool operator()(const int output_size, const int input_size, + const T* top_diff, const int64_t* mask, const int top_offset, + const int bottom_offset, T* bottom_diff, + const Eigen::GpuDevice& d, const bool include_batch_in_index); +}; + +template +struct MaxPoolGradBackwardNoMask { + bool operator()(TensorFormat data_format, const T* bottom_data, + const T* output_data, const int batch, + const int pooled_height, const int pooled_width, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, + const T* top_diff, T* bottom_diff, const Eigen::GpuDevice& d); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/meta_support.h b/third_party/tflite-hdrs/tensorflow/core/kernels/meta_support.h new file mode 100644 index 00000000..b1e81b4f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/meta_support.h @@ -0,0 +1,112 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_META_SUPPORT_H_ +#define TENSORFLOW_CORE_KERNELS_META_SUPPORT_H_ + +#include "meta/multi_thread_gemm.h" +#include "meta/multi_thread_transform.h" +#include "meta/quantized_mul_kernels.h" +#include "meta/streams.h" +#include "meta/transform_kernels.h" + +#include "tensorflow/core/framework/numeric_types.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace meta { + +// Gemmlowp/meta is a small library of optimized Arm32/64 kernels for quantized +// matrix multiplication and other quantized computations. 
+ +// Set the maximum number of threads of computation that the internal workers +// pool can use. If num_threads is 0, then use intra_op_parallelism_threads. +void SetNumThreads(int num_threads); + +int GetNumThreads(); + +// Toggle the internal workers pool. If set to false, the computations will +// use the worker pool passed each time in the OpKernelContext. If set to true +// then the OpKernelContext will be ignored, and the internal optimized workers +// pool will be used. +// +// The internal workers pool is disabled by default (false). +void SetUseLocalContext(bool use_local_context); + +bool GetUseLocalContext(); + +// Toggles the codepath. Enabled by default (true) on supported platforms. +void SetEnabled(bool enabled); + +// Returns true if the codepath is supported and is enabled. Use this call +// before calling the compute functions. If the codepath is not supported, and +// any of the compute function is called, the library will log a FATAL error. +bool IsSupportedAndEnabled(); + +// Calculate the quantized matrix multiplication: +// +// for (i, j) in [0, m) x [0, n) do +// c_data[i, j] := +// sum((a_data[i, l] + offset_a) * (b_data[l, j] + offset_b)) : l in [0, k) +// +// If transpose_a is false the lhs operand has row major layout, otherwise +// column major. Similarly transpose_b describes the layout of the rhs operand. +// lda, ldb, and ldc are the strides of the lhs operand, rhs operand and the +// result arrays. +void QuantizedGemm(OpKernelContext* context, bool transpose_a, bool transpose_b, + const quint8* a_data, const quint8* b_data, qint32* c_data, + int m, int n, int k, int offset_a, int offset_b, int lda, + int ldb, int ldc); + +// Take an array of numbers from the range [input_min, input_max] quantized +// uniformly to int32 values, recover their float values, and then quantize +// them back uniformly to the range [output_min, output_max] as uint8. +// Saturate the uint8 values. +void Requantize(OpKernelContext* context, const qint32* input, int count, + float input_min, float input_max, float output_min, + float output_max, quint8* output); + +// Take an array of numbers from the range [range_min, range_max] quantized +// uniformly to uint8 values and recover their float values. +void Dequantize(OpKernelContext* context, const quint8* input, int count, + float range_min, float range_max, float* output); + +// Take an array of float values and quantize them uniformly to the range +// [range_min, range_max] expressed as uint8. Saturate the uint8 values. +void Quantize(OpKernelContext*, const float* input, int count, float range_min, + float range_max, quint8* output); + +// Take two arrays: the inputs and the bias quantized uniformly in the ranges +// [input_min, input_max], and [bias_min, bias_max] accordingly, as uint8 +// values. Recover their float values. Add the values. Quantize them back +// uniformly to the range [output_min, output_max] as int32. Saturate the +// int32 values. +void QuantizedBiasAdd(OpKernelContext* context, const quint8* input, + int input_count, const quint8* bias, int bias_count, + float input_min, float input_max, float bias_min, + float bias_max, float output_min, float output_max, + qint32* output); + +// Take an array of uint8 values and clamp them to the range [clamp_min, +// clamp_max]. 
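// Illustrative sketch (not part of this patch): naive reference loops for the
// quantized contracts documented above, which could serve to sanity-check the
// optimized gemmlowp/meta kernels. Row-major, no transposition, plain C++;
// the real QuantizedGemm also honors transpose flags and lda/ldb/ldc strides,
// and the affine uint8 mapping below (q -> range_min + q * range / 255) is a
// simplification of the scheme the quantized ops use.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// c[i, j] = sum_l (a[i, l] + offset_a) * (b[l, j] + offset_b)
static void QuantizedGemmReference(const std::vector<uint8_t>& a,
                                   const std::vector<uint8_t>& b,
                                   std::vector<int32_t>* c, int m, int n, int k,
                                   int offset_a, int offset_b) {
  c->assign(static_cast<size_t>(m) * n, 0);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      for (int l = 0; l < k; ++l)
        acc += (static_cast<int32_t>(a[i * k + l]) + offset_a) *
               (static_cast<int32_t>(b[l * n + j]) + offset_b);
      (*c)[i * n + j] = acc;
    }
}

// Recover the float represented by a uint8 quantized uniformly on
// [range_min, range_max].
static float DequantizeReference(uint8_t q, float range_min, float range_max) {
  const float scale = (range_max - range_min) / 255.0f;
  return range_min + scale * static_cast<float>(q);
}

// Quantize a float uniformly onto [range_min, range_max] as uint8, saturating.
static uint8_t QuantizeReference(float v, float range_min, float range_max) {
  const float scale = (range_max - range_min) / 255.0f;
  const float q = std::round((v - range_min) / scale);
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
}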
+void Clamp(OpKernelContext* context, const quint8* input, int input_count, + quint8 clamp_min, quint8 clamp_max, quint8* output); + +} // namespace meta +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_META_SUPPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc.h new file mode 100644 index 00000000..790b5f7b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc.h @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic class for computing MFCCs from spectrogram slices. + +#ifndef TENSORFLOW_CORE_KERNELS_MFCC_H_ +#define TENSORFLOW_CORE_KERNELS_MFCC_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/mfcc_dct.h" +#include "tensorflow/core/kernels/mfcc_mel_filterbank.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class Mfcc { + public: + Mfcc(); + bool Initialize(int input_length, double input_sample_rate); + + // Input is a single squared-magnitude spectrogram frame. The input spectrum + // is converted to linear magnitude and weighted into bands using a + // triangular mel filterbank, and a discrete cosine transform (DCT) of the + // values is taken. Output is populated with the lowest dct_coefficient_count + // of these values. 
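// Illustrative usage sketch (not part of this patch): configuring and running
// the Mfcc helper on one squared-magnitude spectrogram frame. The vector
// element type is assumed to be double (elided in the excerpt above), the
// frame length must equal the FFT-bin count passed to Initialize(), and the
// limit/channel/coefficient values below are arbitrary example choices.
#include <vector>

#include "tensorflow/core/kernels/mfcc.h"

void ComputeMfccForFrame(const std::vector<double>& squared_magnitudes,
                         double sample_rate_hz,
                         std::vector<double>* coefficients) {
  tensorflow::Mfcc mfcc;
  // All set_* calls must happen before Initialize().
  mfcc.set_lower_frequency_limit(20.0);
  mfcc.set_upper_frequency_limit(4000.0);
  mfcc.set_filterbank_channel_count(40);
  mfcc.set_dct_coefficient_count(13);
  if (!mfcc.Initialize(static_cast<int>(squared_magnitudes.size()),
                       sample_rate_hz)) {
    coefficients->clear();
    return;
  }
  mfcc.Compute(squared_magnitudes, coefficients);
}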
+ void Compute(const std::vector& spectrogram_frame, + std::vector* output) const; + + void set_upper_frequency_limit(double upper_frequency_limit) { + CHECK(!initialized_) << "Set frequency limits before calling Initialize."; + upper_frequency_limit_ = upper_frequency_limit; + } + + void set_lower_frequency_limit(double lower_frequency_limit) { + CHECK(!initialized_) << "Set frequency limits before calling Initialize."; + lower_frequency_limit_ = lower_frequency_limit; + } + + void set_filterbank_channel_count(int filterbank_channel_count) { + CHECK(!initialized_) << "Set channel count before calling Initialize."; + filterbank_channel_count_ = filterbank_channel_count; + } + + void set_dct_coefficient_count(int dct_coefficient_count) { + CHECK(!initialized_) << "Set coefficient count before calling Initialize."; + dct_coefficient_count_ = dct_coefficient_count; + } + + private: + MfccMelFilterbank mel_filterbank_; + MfccDct dct_; + bool initialized_; + double lower_frequency_limit_; + double upper_frequency_limit_; + int filterbank_channel_count_; + int dct_coefficient_count_; + Mfcc(const Mfcc&) = delete; + void operator=(const Mfcc&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MFCC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc_dct.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc_dct.h new file mode 100644 index 00000000..e7982d6a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc_dct.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic minimal DCT class for MFCC speech processing. + +#ifndef TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ +#define TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class MfccDct { + public: + MfccDct(); + bool Initialize(int input_length, int coefficient_count); + void Compute(const std::vector& input, + std::vector* output) const; + + private: + bool initialized_; + int coefficient_count_; + int input_length_; + std::vector > cosines_; + MfccDct(const MfccDct&) = delete; + void operator=(const MfccDct&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc_mel_filterbank.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc_mel_filterbank.h new file mode 100644 index 00000000..293d7745 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mfcc_mel_filterbank.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Basic class for applying a mel-scale mapping to a power spectrum. + +#ifndef TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ +#define TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ + +#include +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class MfccMelFilterbank { + public: + MfccMelFilterbank(); + bool Initialize(int input_length, // Number of unique FFT bins fftsize/2+1. + double input_sample_rate, int output_channel_count, + double lower_frequency_limit, double upper_frequency_limit); + + // Takes a squared-magnitude spectrogram slice as input, computes a + // triangular-mel-weighted linear-magnitude filterbank, and places the result + // in output. + void Compute(const std::vector& input, + std::vector* output) const; + + private: + double FreqToMel(double freq) const; + bool initialized_; + int num_channels_; + double sample_rate_; + int input_length_; + std::vector center_frequencies_; // In mel, for each mel channel. + + // Each FFT bin b contributes to two triangular mel channels, with + // proportion weights_[b] going into mel channel band_mapper_[b], and + // proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1. + // Thus, weights_ contains the weighting applied to each FFT bin for the + // upper-half of the triangular band. + std::vector weights_; // Right-side weight for this fft bin. + + // FFT bin i contributes to the upper side of mel channel band_mapper_[i] + std::vector band_mapper_; + int start_index_; // Lowest FFT bin used to calculate mel spectrum. + int end_index_; // Highest FFT bin used to calculate mel spectrum. + + MfccMelFilterbank(const MfccMelFilterbank&) = delete; + void operator=(const MfccMelFilterbank&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_batch_matmul_helper.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_batch_matmul_helper.h new file mode 100644 index 00000000..d7c1da14 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_batch_matmul_helper.h @@ -0,0 +1,104 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_BATCH_MATMUL_HELPER_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MKL_BATCH_MATMUL_HELPER_H_ +#if defined(INTEL_MKL) + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/matmul_bcast.h" + +namespace tensorflow { + +struct MklBatchMatMulHelper { + using dims = dnnl::memory::dims; + // This method makes the rank (ndims) of input same as the output by creating + // new axes to the input. For example, if input shape is [a, b, c, d] and + // output shape is [e, f, g, h, i, j], then the reshaped input would have a + // shape of [1, 1, a, b, c, d]. + void ExpandInputDimsToOutputShape(const TensorShape& input_shape, + const TensorShape& output_shape, + dims* reshaped_dims) { + auto ndims_input = input_shape.dims(); + auto ndims_output = output_shape.dims(); + auto dim_offset = ndims_output - ndims_input; + DCHECK(dim_offset > 0); + reshaped_dims->clear(); + reshaped_dims->resize(ndims_output, 1); + auto input_dims = input_shape.dim_sizes(); + for (int dim_idx = 0; dim_idx < ndims_input; ++dim_idx) + reshaped_dims->at(dim_idx + dim_offset) = input_dims[dim_idx]; + } + + std::unique_ptr CreateMatMulParams( + string& prefix, const TensorShape& lhs_shape, + const TensorShape& rhs_shape, const TensorShape& out_shape, bool& adj_x, + bool& adj_y) { + const auto ndims_lhs = lhs_shape.dims(); + const auto ndims_rhs = rhs_shape.dims(); + const auto ndims_out = out_shape.dims(); + auto lhs_dims = TFShapeToMklDnnDims(lhs_shape); + auto rhs_dims = TFShapeToMklDnnDims(rhs_shape); + auto out_dims = TFShapeToMklDnnDims(out_shape); + + // DNNL matmul_primitive requires ranks of inputs and output to be same. + // Create dnnl::memory::dims for inputs and output of same rank. + // It is assumed here that MatMulBCast object creates output_batch_shape as + // a conforming superset of input batch shapes, i.e., ndims_out >= + // ndims_lhs and ndims_out >= ndims_rhs. 
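// Illustrative sketch (not part of this patch): the rank-expansion rule that
// ExpandInputDimsToOutputShape above applies before the calls just below. An
// input of rank r is aligned to an output of rank R >= r by prepending
// (R - r) unit dimensions, e.g. [a, b, c, d] against a rank-6 output becomes
// [1, 1, a, b, c, d]. Plain vectors, standalone helper name.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> ExpandToRank(const std::vector<int64_t>& dims,
                                         size_t output_rank) {
  assert(dims.size() <= output_rank);
  std::vector<int64_t> expanded(output_rank, 1);
  std::copy(dims.begin(), dims.end(),
            expanded.begin() + (output_rank - dims.size()));
  return expanded;
}
// Example: ExpandToRank({4, 8, 16, 32}, 6) == {1, 1, 4, 8, 16, 32}.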
+ if (ndims_lhs < ndims_out) { + ExpandInputDimsToOutputShape(lhs_shape, out_shape, &lhs_dims); + } + if (ndims_rhs < ndims_out) { + ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims); + } + auto lhs_strides = CalculateTFStrides(lhs_dims); + auto rhs_strides = CalculateTFStrides(rhs_dims); + auto out_strides = CalculateTFStrides(out_dims); + + if (adj_x) { + int m_idx = ndims_out - 1; + int k_idx = ndims_out - 2; + memory::dim m = lhs_dims[m_idx]; // number of rows in x + std::swap(lhs_dims[m_idx], lhs_dims[k_idx]); + lhs_strides[m_idx] = m; + lhs_strides[k_idx] = 1; + } + + if (adj_y) { + int k_idx = ndims_out - 1; + int n_idx = ndims_out - 2; + memory::dim k = rhs_dims[k_idx]; // number of columns in x + std::swap(rhs_dims[k_idx], rhs_dims[n_idx]); + rhs_strides[k_idx] = k; + rhs_strides[n_idx] = 1; + } + + return std::make_unique(prefix, lhs_dims, rhs_dims, + out_dims, lhs_strides, rhs_strides, + out_strides); + } +}; + +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_BATCH_MATMUL_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_conv_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_conv_ops.h new file mode 100644 index 00000000..eac82bea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_conv_ops.h @@ -0,0 +1,711 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_ + +#ifdef INTEL_MKL +#include +#include +#include + +#include "dnnl.hpp" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/mkl_util.h" +#include "tensorflow/core/util/onednn_env_vars.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +using dnnl::convolution_forward; +using dnnl::prop_kind; +using dnnl::stream; + +namespace tensorflow { + +#ifndef ENABLE_ONEDNN_V3 +// Op descriptor is no longer supported in oneDNN v3.x. Instead, primitive +// descriptor will directly accept primitive parameters during creation. 
+using ConvFwdDesc = dnnl::convolution_forward::desc; +#endif // !ENABLE_ONEDNN_V3 +using ConvFwdPd = dnnl::convolution_forward::primitive_desc; + +class MklDnnConvUtil { + protected: + OpKernelContext* context_; // We don't own this. + std::vector strides_; + std::vector dilations_; + Padding padding_; + TensorFormat data_format_; + + public: + MklDnnConvUtil(OpKernelContext* context, const std::vector& strides, + Padding pad, TensorFormat fm, + const std::vector& dilations, bool is_depthwise = false) + : context_(context), + strides_(strides), + dilations_(dilations), + padding_(pad), + data_format_(fm) {} + + virtual ~MklDnnConvUtil() { context_ = nullptr; } + + // Calculate Convolution strides + virtual inline void GetStridesInMklOrder(memory::dims* strides) { + // For now we take the stride from the second and third dimensions only + // (we do not support striding on the batch or depth dimension). + DCHECK(strides); + if (strides_.size() == 4) { + int stride_rows = GetTensorDim(strides_, data_format_, 'H'); + int stride_cols = GetTensorDim(strides_, data_format_, 'W'); + *strides = {stride_rows, stride_cols}; + } else if (strides_.size() == 5) { + int stride_planes = GetTensorDim(strides_, data_format_, '0'); + int stride_rows = GetTensorDim(strides_, data_format_, '1'); + int stride_cols = GetTensorDim(strides_, data_format_, '2'); + *strides = {stride_planes, stride_rows, stride_cols}; + } + } + + // Calculate Convolution dilations + virtual inline void GetDilationsInMklOrder(memory::dims* dilations) { + // For now we take the dilation from the second and third dimensions only + // (we do not support dilation on the batch or depth dimension). + DCHECK(dilations); + if (dilations_.size() == 4) { + int dilations_rows = GetTensorDim(dilations_, data_format_, 'H'); + int dilations_cols = GetTensorDim(dilations_, data_format_, 'W'); + *dilations = {dilations_rows, dilations_cols}; + } else if (dilations_.size() == 5) { + int dilations_planes = GetTensorDim(dilations_, data_format_, '0'); + int dilations_rows = GetTensorDim(dilations_, data_format_, '1'); + int dilations_cols = GetTensorDim(dilations_, data_format_, '2'); + *dilations = {dilations_planes, dilations_rows, dilations_cols}; + } + } + + // Calculate Convolution input size in oneDNN order. oneDNN + // requires input in NCHW/NCDHW format. Function does not return anything. + // But errors arising from sanity checks are returned in context's + // status. 
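// Illustrative sketch (not part of this patch): the per-format dimension
// lookup that GetStridesInMklOrder / GetDilationsInMklOrder above rely on.
// For a rank-4 attribute vector laid out as NHWC or NCHW, the 'N', 'H', 'W',
// 'C' characters map to fixed indices; the helpers then keep only the spatial
// ('H', 'W') entries, since striding or dilation on the batch and depth
// dimensions is not supported. Standalone names, plain C++.
#include <cstddef>
#include <stdexcept>
#include <vector>

enum class Format4D { kNHWC, kNCHW };

static size_t DimIndex(Format4D format, char dim) {
  switch (dim) {
    case 'N': return 0;
    case 'C': return format == Format4D::kNHWC ? 3 : 1;
    case 'H': return format == Format4D::kNHWC ? 1 : 2;
    case 'W': return format == Format4D::kNHWC ? 2 : 3;
    default: throw std::invalid_argument("unknown dimension character");
  }
}

static int SpatialStride(const std::vector<int>& strides, Format4D format,
                         char dim) {
  return strides.at(DimIndex(format, dim));
}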
+ virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape, + memory::dims* input_dims) { +#define CHECK_BOUNDS(val, err_msg) \ + do { \ + OP_REQUIRES(context_, \ + FastBoundsCheck(val, std::numeric_limits::max()), \ + errors::InvalidArgument(err_msg)); \ + } while (0) + + DCHECK(input_dims); + + // Input channel + int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C'); + int input_depth = static_cast(input_depth_raw); + + // Input batch + int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N'); + CHECK_BOUNDS(input_batch_raw, "Input batch too large"); + int input_batch = static_cast(input_batch_raw); + + if (strides_.size() == 4) { // NCHW format for Conv2D + // Input rows/height + int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); + CHECK_BOUNDS(input_rows_raw, "Input rows too large"); + int input_rows = static_cast(input_rows_raw); + + // Input columns/width + int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); + CHECK_BOUNDS(input_cols_raw, "Input cols too large"); + int input_cols = static_cast(input_cols_raw); + + // oneDNN always requires input in NCHW format Conv2D. + std::vector input_sizes(4, -1); + input_sizes[MklDnnDims::Dim_N] = input_batch; + input_sizes[MklDnnDims::Dim_C] = input_depth; + input_sizes[MklDnnDims::Dim_H] = input_rows; + input_sizes[MklDnnDims::Dim_W] = input_cols; + *input_dims = input_sizes; + } else if (strides_.size() == 5) { // NCDHW format for Conv3D + // Input planes/third-dimension + int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0'); + CHECK_BOUNDS(input_planes_raw, "Input depth too large"); + int input_planes = static_cast(input_planes_raw); + + // Input rows/height + int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1'); + CHECK_BOUNDS(input_rows_raw, "Input rows too large"); + int input_rows = static_cast(input_rows_raw); + + // Input columns/width + int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2'); + CHECK_BOUNDS(input_cols_raw, "Input cols too large"); + int input_cols = static_cast(input_cols_raw); + + // oneDNN always requires input in NCDHW format for Conv3D. + std::vector input_sizes(5, -1); + input_sizes[MklDnnDims3D::Dim3d_N] = input_batch; + input_sizes[MklDnnDims3D::Dim3d_C] = input_depth; + input_sizes[MklDnnDims3D::Dim3d_D] = input_planes; + input_sizes[MklDnnDims3D::Dim3d_H] = input_rows; + input_sizes[MklDnnDims3D::Dim3d_W] = input_cols; + *input_dims = input_sizes; + } +#undef CHECK_BOUNDS + } + + // Calculate Convolution filter size in oneDNN order. + // oneDNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format. + // Function does not return anything. + // But errors arising from sanity checks are returned in context's + // status. This function differs from GetConvFilterSizeInMklOrder in + // parameter for input - it accepts src_shape since Convolution Backward + // Input gets shape of input tensor rather than actual tensor (Convolution + // forward gets actual tensor as input). + // + // TODO(intel-tf): Add similar function for input and filter in MklShape. + virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + memory::dims* filter_dims, + bool* is_grouped_convolution, + bool is_depthwise) { + DCHECK(filter_dims); + + OP_REQUIRES(context_, filter_shape.dims() == strides_.size(), + errors::InvalidArgument((strides_.size() == 4) + ? 
"filter must be 4-dimensional: " + : "filter must be 5-dimensional: ", + filter_shape.DebugString())); + + for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) { + OP_REQUIRES(context_, + FastBoundsCheck(filter_shape.dim_size(i), + std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); + } + + int input_depth = GetTensorDim(input_shape, data_format_, 'C'); + + if (strides_.size() == 4) { // Conv2D + // TF filter is always in (rows, cols, in_depth, out_depth) order. + int filter_rows = + static_cast(filter_shape.dim_size(TF_2DFILTER_DIM_H)); + int filter_cols = + static_cast(filter_shape.dim_size(TF_2DFILTER_DIM_W)); + int filter_in_depth = + static_cast(filter_shape.dim_size(TF_2DFILTER_DIM_I)); + int filter_out_depth = + static_cast(filter_shape.dim_size(TF_2DFILTER_DIM_O)); + OP_REQUIRES(context_, input_depth % filter_in_depth == 0, + errors::InvalidArgument( + "input depth must be evenly divisible by filter depth: ", + input_depth, " vs ", filter_in_depth)); + *is_grouped_convolution = filter_in_depth != input_depth; + int group_count = input_depth / filter_in_depth; + OP_REQUIRES(context_, group_count > 0, + errors::InvalidArgument( + "grouped convolution must have at least one group: ", + group_count, " groups")); + + // oneDNN always needs filter in OIHW format for regular convolutions + // and GOIHW for grouped/depthwise convolutions, + // OIHW = (out_depth, in_depth, rows, cols) + // GOIHW = (group, out_depth, in_depth, rows, cols) + // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1 + if (is_depthwise) { + std::vector filter_sizes(5, -1); + filter_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth; + filter_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth; + filter_sizes[MKL_GROUP_FILTER_DIM_I] = 1; + filter_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows; + filter_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols; + *filter_dims = filter_sizes; + } else if (*is_grouped_convolution) { + // TODO(intel-tf): Directly set filter_dims. Same for other places. + std::vector filter_sizes(5, -1); + filter_sizes[MKL_GROUP_FILTER_DIM_G] = group_count; + filter_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth / group_count; + filter_sizes[MKL_GROUP_FILTER_DIM_I] = filter_in_depth; + filter_sizes[MKL_GROUP_FILTER_DIM_H] = filter_rows; + filter_sizes[MKL_GROUP_FILTER_DIM_W] = filter_cols; + *filter_dims = filter_sizes; + } else { + std::vector filter_sizes(4, -1); + filter_sizes[MklDnnDims::Dim_O] = filter_out_depth; + filter_sizes[MklDnnDims::Dim_I] = filter_in_depth; + filter_sizes[MklDnnDims::Dim_H] = filter_rows; + filter_sizes[MklDnnDims::Dim_W] = filter_cols; + *filter_dims = filter_sizes; + } + } else { // Conv3D + OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3), + errors::InvalidArgument( + "input and filter must have the same depth: ", + input_depth, " vs ", filter_shape.dim_size(3))); + + // TF filter is always in (planes, rows, cols, in_depth, out_depth) order. + int filter_planes = + static_cast(filter_shape.dim_size(TF_3DFILTER_DIM_P)); + int filter_rows = + static_cast(filter_shape.dim_size(TF_3DFILTER_DIM_H)); + int filter_cols = + static_cast(filter_shape.dim_size(TF_3DFILTER_DIM_W)); + int filter_in_depth = + static_cast(filter_shape.dim_size(TF_3DFILTER_DIM_I)); + int filter_out_depth = + static_cast(filter_shape.dim_size(TF_3DFILTER_DIM_O)); + + // oneDNN always needs filter in OIDHW format. 
+ // OIDHW = (out_depth, in_depth, planes, rows, cols) + std::vector filter_sizes(5, -1); + filter_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth; + filter_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth; + filter_sizes[MklDnnDims3D::Dim3d_D] = filter_planes; + filter_sizes[MklDnnDims3D::Dim3d_H] = filter_rows; + filter_sizes[MklDnnDims3D::Dim3d_W] = filter_cols; + *filter_dims = filter_sizes; + } + } + + // Calculate Convolution filter size in oneDNN order. + // oneDNN requires filter in OIHW (Conv2D) or OIDHW(Conv3D format. + // Function does not return anything. But errors arising from sanity + // checks are returned in context's status. + virtual inline void GetFilterSizeInMklOrder(size_t src_index, + size_t filter_index, + memory::dims* filter_dims, + bool* is_grouped_convolution, + bool is_depthwise) { + DCHECK(filter_dims); + GetFilterSizeInMklOrder(GetTfShape(context_, src_index), + GetTfShape(context_, filter_index), filter_dims, + is_grouped_convolution, is_depthwise); + } + + // Calculate Bias size for 2D or 3D Convolution. Function does not + // return anything, but may set an error in context status. + virtual inline void GetBiasSizeInMklOrder(size_t bias_index, + memory::dims* bias_dims) { + const Tensor& bias = MklGetInput(context_, bias_index); + if (bias.dims() > 1) { + if (strides_.size() == 4) { + OP_REQUIRES( + context_, bias.dims() <= 4, + errors::InvalidArgument("For NHWC format, bias should have " + "4 or less dimensions", + bias.shape().DebugString())); + } else if (strides_.size() == 5) { + OP_REQUIRES( + context_, bias.dims() <= 5, + errors::InvalidArgument("For NDHWC format, bias should have " + "5 or less dimensions", + bias.shape().DebugString())); + } + // Make sure all the dims except channel(last) is 1 + for (int i = 0; i < bias.dims() - 1; i++) { + OP_REQUIRES( + context_, bias.dim_size(i) == 1, + errors::InvalidArgument("For bias_dims > 1, all except the last " + "dimension (channel) must be 1: ", + bias.shape().DebugString())); + } + *bias_dims = {static_cast(bias.dim_size(bias.dims() - 1))}; + } else { + *bias_dims = {static_cast(bias.dim_size(0))}; + } + } + + // Function to calculate output and padding size for 2D/3D convolution. + // + // Calculate output shape of Convolution in oneDNN and TensorFlow order. + // oneDNN uses NCHW(Conv2D) or NCDHW(Conv3D) for output order. + // But TensorFlow output will be in NHWC||NCHW(Conv2D) or + // NDHWC||NCDHW(Conv3D) format depending on data format. + // Function also calculates left, right, top and bottom pads. + // Function does not return any status which is set with context status. + // + // TODO(intel-tf): Add similar function for input and filter in MklShape. 
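// Illustrative sketch (not part of this patch): the filter-layout permutation
// performed by GetFilterSizeInMklOrder above for the regular and depthwise
// 2-D cases. TensorFlow stores a Conv2D filter as HWIO =
// (rows, cols, in_depth, out_depth); oneDNN wants OIHW, and depthwise
// convolutions want the grouped form GOIHW = (group, out_depth, in_depth,
// rows, cols) with G = in_depth and I = 1. Standalone helper names.
#include <cstdint>
#include <vector>

// filter_hwio = {rows, cols, in_depth, out_depth}
static std::vector<int64_t> FilterToOIHW(const std::vector<int64_t>& f) {
  return {f[3], f[2], f[0], f[1]};
}

static std::vector<int64_t> DepthwiseFilterToGOIHW(
    const std::vector<int64_t>& f) {
  // G = in_depth, O = channel multiplier (TF "out_depth"), I = 1.
  return {f[2], f[3], 1, f[0], f[1]};
}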
+ virtual inline void GetOutputAndPadSizeInMklOrder( + const TensorShape& input_shape, const TensorShape& filter_shape, + const memory::dims& strides, const memory::dims& dilations, + memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, + memory::dims* pad_l, memory::dims* pad_r, bool is_grouped_convolution, + bool pad_enabled = false, bool is_depthwise = false) { + DCHECK(output_dims_tf_order); + DCHECK(output_dims_mkl_order); + DCHECK(pad_l); + DCHECK(pad_r); + + bool is_conv2d = (strides_.size() == 4); + int input_planes, input_rows, input_cols; + if (is_conv2d) { + input_rows = GetTensorDim(input_shape, data_format_, 'H'); + input_cols = GetTensorDim(input_shape, data_format_, 'W'); + } else { + input_planes = GetTensorDim(input_shape, data_format_, '0'); + input_rows = GetTensorDim(input_shape, data_format_, '1'); + input_cols = GetTensorDim(input_shape, data_format_, '2'); + } + + // Filter dimension + // Conv2D: + // First dimension: rows/height. + // Second dimension: cols/width. + // Conv3D: + // First dimension: planes/depth. + // Second dimension: rows/height. + // Third dimension: cols/width. + + int filter_planes, filter_rows, filter_cols; + if (is_conv2d) { + filter_rows = filter_shape.dim_size(TF_2DFILTER_DIM_H); + filter_cols = filter_shape.dim_size(TF_2DFILTER_DIM_W); + } else { + filter_planes = filter_shape.dim_size(TF_3DFILTER_DIM_P); + filter_rows = filter_shape.dim_size(TF_3DFILTER_DIM_H); + filter_cols = filter_shape.dim_size(TF_3DFILTER_DIM_W); + } + + int stride_planes, stride_rows, stride_cols; + int dilation_planes, dilation_rows, dilation_cols; + if (is_conv2d) { + // Conv2D stride is a vector of 2 elements: {s_r, s_c} + stride_rows = strides[0]; + stride_cols = strides[1]; + dilation_rows = dilations[0]; + dilation_cols = dilations[1]; + } else { + // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c} + stride_planes = strides[0]; + stride_rows = strides[1]; + stride_cols = strides[2]; + dilation_planes = dilations[0]; + dilation_rows = dilations[1]; + dilation_cols = dilations[2]; + } + + // Output batch is same as input batch. + int out_batch = GetTensorDim(input_shape, data_format_, 'N'); + int out_depth; + + // TODO(intel-tf) add support for 3-D Depthwise + + // Output depth is same as last dimension for filters for regular + // convolutions and group convolutions. For depthwise it is in_depth * + // channel_multiplier. The channel_multiplier is the last dimension of + // TF filter for depthwise convolutions. + if (is_depthwise) { + out_depth = (filter_shape.dim_size(TF_2DFILTER_DIM_I) * + filter_shape.dim_size(TF_2DFILTER_DIM_O)); + } else if (is_grouped_convolution) { + out_depth = filter_shape.dim_size(TF_2DFILTER_DIM_O); + } else { + out_depth = filter_shape.dim_size( + is_conv2d ? 
static_cast(TF_2DFILTER_DIM_O) + : static_cast(TF_3DFILTER_DIM_O)); + } + + int64 out_rows = 0, out_cols = 0, out_planes = 0; + int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0; + int64 pad_front, pad_back; + + if (is_conv2d) { + Padding padding_type; + if (pad_enabled) { + padding_type = Padding::EXPLICIT; + pad_top = static_cast((*pad_l)[0]); + pad_left = static_cast((*pad_l)[1]); + pad_bottom = static_cast((*pad_r)[0]); + pad_right = static_cast((*pad_r)[1]); + } else { + padding_type = padding_; + } + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose( + input_rows, filter_rows, dilation_rows, stride_rows, + padding_type, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose( + input_cols, filter_cols, dilation_cols, stride_cols, + padding_type, &out_cols, &pad_left, &pad_right)); + } else { + Padding padding_type; + if (pad_enabled) { + padding_type = Padding::EXPLICIT; + pad_front = static_cast((*pad_l)[0]); + pad_top = static_cast((*pad_l)[1]); + pad_left = static_cast((*pad_l)[2]); + pad_back = static_cast((*pad_r)[0]); + pad_bottom = static_cast((*pad_r)[1]); + pad_right = static_cast((*pad_r)[2]); + } else { + padding_type = padding_; + } + OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( + input_planes, filter_planes, dilation_planes, + stride_planes, padding_type, &out_planes, + &pad_front, &pad_back)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose( + input_rows, filter_rows, dilation_rows, stride_rows, + padding_type, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose( + input_cols, filter_cols, dilation_cols, stride_cols, + padding_type, &out_cols, &pad_left, &pad_right)); + } + + if (is_conv2d) { + // If pad_enabled, i.e., pad and conv op are fused, then + // all pads are already passed from pad op through + // *pad_l and *pad_r and they don't need to be set here. + if (!pad_enabled) { + *pad_l = {static_cast(pad_top), static_cast(pad_left)}; + *pad_r = {static_cast(pad_bottom), static_cast(pad_right)}; + } + } else { + // If pad_enabled, i.e., pad and conv op are fused, then + // all pads are already passed from pad op through + // *pad_l and *pad_r and they don't need to be set here. + if (!pad_enabled) { + *pad_l = {static_cast(pad_front), static_cast(pad_top), + static_cast(pad_left)}; + *pad_r = {static_cast(pad_back), static_cast(pad_bottom), + static_cast(pad_right)}; + } + } + // Tensorflow output is in data_format order. + // Conv2D: NHWC or NCHW + // Conv3D: NDHWC or NCDHW + // oneDNN uses asymmetric padding. 
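// Illustrative sketch (not part of this patch): the arithmetic behind the
// GetWindowedOutputSizeVerbose calls above, for one spatial dimension. SAME
// padding keeps out = ceil(in / stride) and splits the required padding
// between the two sides, with any odd element going to the bottom/right
// (hence the asymmetric padding noted above); VALID uses no padding and the
// caller must ensure the effective kernel fits in the input.
#include <algorithm>
#include <cstdint>

struct WindowedOutput {
  int64_t size, pad_before, pad_after;
};

static WindowedOutput ComputeWindowedOutput(int64_t in, int64_t kernel,
                                            int64_t dilation, int64_t stride,
                                            bool same_padding) {
  const int64_t effective_kernel = (kernel - 1) * dilation + 1;
  WindowedOutput out{0, 0, 0};
  if (same_padding) {
    out.size = (in + stride - 1) / stride;  // ceil(in / stride)
    const int64_t pad_total =
        std::max<int64_t>((out.size - 1) * stride + effective_kernel - in, 0);
    out.pad_before = pad_total / 2;
    out.pad_after = pad_total - out.pad_before;
  } else {  // VALID
    out.size = (in - effective_kernel) / stride + 1;
  }
  return out;
}
// Example: in=7, kernel=3, dilation=1, stride=2, SAME -> size 4, pads {1, 2}.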
+ TensorShape out_shape; + if (is_conv2d) { + OP_REQUIRES_OK( + context_, ShapeFromFormatWithStatus(data_format_, out_batch, out_rows, + out_cols, out_depth, &out_shape)); + } else { + OP_REQUIRES_OK(context_, ShapeFromFormatWithStatus( + data_format_, out_batch, + {{out_planes, out_rows, out_cols}}, + out_depth, &out_shape)); + } + *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); + if (is_grouped_convolution) { + int out_depth = GetTensorDim(out_shape, data_format_, 'C'); + int input_depth = GetTensorDim(input_shape, data_format_, 'C'); + int filter_in_depth = + static_cast(filter_shape.dim_size(TF_2DFILTER_DIM_I)); + int num_groups = input_depth / filter_in_depth; + OP_REQUIRES( + context_, out_depth % num_groups == 0 && out_depth >= num_groups, + errors::InvalidArgument( + "output depth must be evenly divisible by number of groups: ", + out_depth, " vs ", num_groups)); + } + if (is_conv2d) { + // For Conv2D, oneDNN always needs output in NCHW format. + std::vector output_sizes(4, -1); + output_sizes[MklDnnDims::Dim_N] = out_batch; + output_sizes[MklDnnDims::Dim_C] = out_depth; + output_sizes[MklDnnDims::Dim_H] = static_cast(out_rows); + output_sizes[MklDnnDims::Dim_W] = static_cast(out_cols); + *output_dims_mkl_order = output_sizes; + } else { + std::vector output_sizes(5, -1); + output_sizes[MklDnnDims3D::Dim3d_N] = out_batch; + output_sizes[MklDnnDims3D::Dim3d_C] = out_depth; + output_sizes[MklDnnDims3D::Dim3d_D] = static_cast(out_planes); + output_sizes[MklDnnDims3D::Dim3d_H] = static_cast(out_rows); + output_sizes[MklDnnDims3D::Dim3d_W] = static_cast(out_cols); + *output_dims_mkl_order = output_sizes; + } + } + + // Calculate output and pad size of forward Convolution operator. + // See comment on GetConvOutputAndPadSizeInMklOrder for parameters. + // + // Function does not return anything, but sets error in context status. + inline void GetOutputAndPadSizeInMklOrder( + size_t src_index, size_t filter_index, const memory::dims& strides, + const memory::dims& dilations, memory::dims* output_dims_tf_order, + memory::dims* output_dims_mkl_order, memory::dims* pad_l, + memory::dims* pad_r, bool is_grouped_convolution, bool is_depthwise) { + DCHECK(output_dims_tf_order); + DCHECK(output_dims_mkl_order); + DCHECK(pad_l); + DCHECK(pad_r); + + auto input_tf_shape = GetTfShape(context_, src_index); + auto filter_tf_shape = GetTfShape(context_, filter_index); + + if (strides_.size() == 4) { + // Conv2D + OP_REQUIRES(context_, input_tf_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_tf_shape.DebugString())); + OP_REQUIRES(context_, filter_tf_shape.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional", + filter_tf_shape.DebugString())); + } else { + // Conv3D + OP_REQUIRES(context_, input_tf_shape.dims() == 5, + errors::InvalidArgument("input must be 5-dimensional", + input_tf_shape.DebugString())); + OP_REQUIRES(context_, filter_tf_shape.dims() == 5, + errors::InvalidArgument("filter must be 5-dimensional", + filter_tf_shape.DebugString())); + } + + GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides, + dilations, output_dims_tf_order, + output_dims_mkl_order, pad_l, pad_r, + is_grouped_convolution, is_depthwise); + } + + // Wrapper function to calculate input, filter, and output sizes of + // Conv2D/Conv3D in MKL order: + // Conv2D: NCHW for input and output; OIHW for filter. + // Conv3D: NCDHW for input and output; OIDHW for filter. + // Function also calculates output shape in Tensorflow order. 
+ // Additionally, it also calculates strides and paddings. + // + // Function does not return anything, but sets error in context status. + inline void GetConvFwdSizesInMklOrder( + const TensorShape& input_shape, const TensorShape& filter_shape, + memory::dims* input_dims, memory::dims* filter_dims, + memory::dims* strides, memory::dims* dilations, + memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order, + memory::dims* pad_l, memory::dims* pad_r, bool* is_grouped_convolution, + bool pad_enabled = false, bool is_depthwise = false) { + DCHECK(input_dims); + DCHECK(filter_dims); + DCHECK(strides); + DCHECK(dilations); + DCHECK(output_dims_tf_order); + DCHECK(output_dims_mkl_order); + DCHECK(pad_l); + DCHECK(pad_r); + + GetInputSizeInMklOrder(input_shape, input_dims); + if (!context_->status().ok()) return; + GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims, + is_grouped_convolution, is_depthwise); + if (!context_->status().ok()) return; + GetStridesInMklOrder(strides); + GetDilationsInMklOrder(dilations); + GetOutputAndPadSizeInMklOrder( + input_shape, filter_shape, *strides, *dilations, output_dims_tf_order, + output_dims_mkl_order, pad_l, pad_r, *is_grouped_convolution, + pad_enabled, is_depthwise); + if (!context_->status().ok()) return; + } +}; + +///////////////////////////////////////////////////////////////////// +/// Common class that implements ConvBackpropFilter and Input +///////////////////////////////////////////////////////////////////// + +template +class MklConvBackpropCommonOp : public OpKernel { + public: + ~MklConvBackpropCommonOp() {} + explicit MklConvBackpropCommonOp(OpKernelConstruction* context) + : OpKernel(context) { + string data_format_str; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + int stride_n = GetTensorDim(strides_, data_format_, 'N'); + int stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + context, (stride_n == 1 && stride_c == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + + // Depthwise Convolution doesn't have dilation parameter + if (!is_depthwise) { + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + if (strides_.size() == 4) { + // Check Conv2D dilations + OP_REQUIRES( + context, dilations_.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + int dilation_n = GetTensorDim(dilations_, data_format_, 'N'); + int dilation_c = GetTensorDim(dilations_, data_format_, 'C'); + int dilation_h = GetTensorDim(dilations_, data_format_, 'H'); + int dilation_w = GetTensorDim(dilations_, data_format_, 'W'); + OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + OP_REQUIRES( + context, dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + } + } else { + // Set dilations as 1 for depthwise conv + // for future support to align with Tensorflow + dilations_ = {1, 1, 1, 1}; + } + + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + protected: + // data members accessible to derived classes. 
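[Editor's note] The constructor above rejects strides and dilations in the batch ('N') and channel ('C') dimensions and requires positive spatial factors, using GetTensorDim to index by data format. A small standalone sketch of that validation, assuming only the NHWC/NCHW layouts handled here; DimIndex and ValidateWindowAttr are illustrative stand-ins, not TensorFlow helpers.

#include <stdexcept>
#include <string>
#include <vector>

// Returns the index of dimension 'dim' ('N','H','W','C') in the given layout.
int DimIndex(const std::string& format, char dim) {
  const std::string order = (format == "NCHW") ? "NCHW" : "NHWC";
  return static_cast<int>(order.find(dim));
}

// Mirrors the attribute checks: 4 values, no striding/dilation in N or C,
// and strictly positive spatial factors.
void ValidateWindowAttr(const std::vector<int>& v, const std::string& format,
                        const char* name) {
  if (v.size() != 4)
    throw std::invalid_argument(std::string(name) + " must have 4 values");
  if (v[DimIndex(format, 'N')] != 1 || v[DimIndex(format, 'C')] != 1)
    throw std::invalid_argument(std::string(name) +
                                " not supported in batch/depth dimensions");
  if (v[DimIndex(format, 'H')] <= 0 || v[DimIndex(format, 'W')] <= 0)
    throw std::invalid_argument(std::string(name) +
                                " must be > 0 in spatial dimensions");
}

int main() {
  ValidateWindowAttr({1, 2, 2, 1}, "NHWC", "strides");    // passes
  ValidateWindowAttr({1, 1, 2, 2}, "NCHW", "dilations");  // passes
  return 0;
}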
+ std::vector dilations_; + std::vector strides_; + Padding padding_; + TensorFormat data_format_; // NCHW or NHWC +}; + +///////////////////////////////////////////////////////////////////// +/// Dummy Mkl op that is just used for operators that are intermediate +/// output of node fusion in the graph +///////////////////////////////////////////////////////////////////// + +template +class MklDummyOp : public OpKernel { + public: + ~MklDummyOp() {} + + explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + TF_CHECK_OK( + errors::Unimplemented("This is a dummy op." + "It should not have been invoked.")); + } +}; + +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_CONV_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h new file mode 100644 index 00000000..a1d1268d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h @@ -0,0 +1,351 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_ELTWISE_ACTIVATION_BASE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MKL_ELTWISE_ACTIVATION_BASE_OP_H_ + +// See docs in ../ops/mkl_nn_ops.cc. + +#ifdef INTEL_MKL + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "dnnl.hpp" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/mkl_util.h" +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) +#include "tensorflow/core/platform/mutex.h" +#endif + +using dnnl::algorithm; +using dnnl::eltwise_forward; +using dnnl::memory; +using dnnl::prop_kind; +using dnnl::stream; + +using EltwiseFwdActivationPd = dnnl::eltwise_forward::primitive_desc; + +namespace tensorflow { +#ifndef ENABLE_ONEDNN_V3 +#define GET_MEMORY_DESC(md) md.data +#else +#define GET_MEMORY_DESC(md) md +#endif // !ENABLE_ONEDNN_V3 + +// TODO(tf-onednn): Consolidate this class with `MklEltWiseFwdParams` +// in `mkl_relu_op.cc`. 
+// +// The implementation of this class is very similar to it and it +// should be consolidated to one class +template +class MklEltwiseFwdActivationParams { + public: + memory::dims src_dims; + memory::desc src_md; +#ifdef ENABLE_ONEDNN_V3 + memory::desc dst_md; +#endif // ENABLE_ONEDNN_V3 + algorithm alg_kind; + float alpha; + float beta; + + MklEltwiseFwdActivationParams(memory::dims src_dims, memory::desc src_md, +#ifdef ENABLE_ONEDNN_V3 + memory::desc dst_md, +#endif // ENABLE_ONEDNN_V3 + algorithm alg_kind, float alpha, float beta) + : src_dims(src_dims), + src_md(src_md), +#ifdef ENABLE_ONEDNN_V3 + dst_md(dst_md), +#endif // ENABLE_ONEDNN_V3 + alg_kind(alg_kind), + alpha(alpha), + beta(beta) { + } +}; + +template +class MklEltwiseFwdActivationPrimitive : public MklPrimitive { + public: + explicit MklEltwiseFwdActivationPrimitive( + const MklEltwiseFwdActivationParams& fwdParams) + : MklPrimitive(engine(engine::kind::cpu, 0)) { + // create eltwise primitive + if (context_.eltwise_fwd == nullptr) { + Setup(fwdParams); + } + } + + ~MklEltwiseFwdActivationPrimitive() {} + + // Eltwise forward execute + // src_data: input data buffer of src + // dst_data: output data buffer of dst + void Execute(const T* src_data, T* dst_data, OpKernelContext* op_context) { +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + mutex_lock lock(primitive_execution_mu_); +#endif + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data))); + context_.dst_mem->set_data_handle(static_cast(dst_data)); + DCHECK_EQ(context_.fwd_primitives.size(), + context_.fwd_primitives_args.size()); + + std::vector net; + net.push_back(eltwise_forward(*context_.fwd_pd)); + std::vector net_args; + net_args.push_back( + {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}}); + // execute eltwise_fwd primitve + ExecutePrimitive(net, &net_args, GetEngine(), op_context); + + // After execution, set data handle back. 
+ context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + } + + std::shared_ptr GetEltwiseFwdActivationPd() { + return context_.fwd_pd; + } + + private: + // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh + struct EltwiseFwdActivationContext { + // oneDNN memory + std::shared_ptr src_mem; + std::shared_ptr dst_mem; + + // desc & primitive desc +#ifndef ENABLE_ONEDNN_V3 + std::shared_ptr fwd_desc; +#endif // !ENABLE_ONEDNN_V3 + std::shared_ptr fwd_pd; + + // memory desc + std::shared_ptr src_md; + std::shared_ptr dst_md; + + // memory primitive desc + std::shared_ptr src_mpd; + + // Eltwise primitive + std::shared_ptr eltwise_fwd; + + std::vector fwd_primitives; + + std::vector> fwd_primitives_args; + + EltwiseFwdActivationContext() + : src_mem(nullptr), + dst_mem(nullptr), +#ifndef ENABLE_ONEDNN_V3 + fwd_desc(nullptr), +#endif // !ENABLE_ONEDNN_V3 + fwd_pd(nullptr), + src_md(nullptr), + dst_md(nullptr), + src_mpd(nullptr), + eltwise_fwd(nullptr) { + } + }; + + // Eltwise forward primitive setup + void Setup(const MklEltwiseFwdActivationParams& fwdParams) { + // create memory descriptors for eltwise data with specified format + context_.src_md.reset(new memory::desc(GET_MEMORY_DESC(fwdParams.src_md))); + context_.src_mpd.reset(new memory::desc(*context_.src_md)); + + // Create an eltwise forward descriptor and primitive descriptor +#ifndef ENABLE_ONEDNN_V3 + context_.fwd_desc.reset(new eltwise_forward::desc( + prop_kind::forward, fwdParams.alg_kind, *context_.src_md, + fwdParams.alpha, fwdParams.beta)); + context_.fwd_pd.reset( + new EltwiseFwdActivationPd(*context_.fwd_desc, cpu_engine_)); +#else + context_.dst_md.reset(new memory::desc(fwdParams.dst_md)); + context_.fwd_pd.reset(new EltwiseFwdActivationPd( + cpu_engine_, prop_kind::forward, fwdParams.alg_kind, *context_.src_md, + *context_.dst_md, fwdParams.alpha, fwdParams.beta)); +#endif // !ENABLE_ONEDNN_V3 + auto fwd_pd = context_.fwd_pd.get(); + + // Create memory primitive based on dummy data + context_.src_mem.reset( + new memory(fwd_pd->src_desc(), cpu_engine_, DummyData)); + context_.dst_mem.reset( + new memory(fwd_pd->dst_desc(), cpu_engine_, DummyData)); + // Create eltwise primitive and add it to net + context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd)); + context_.fwd_primitives_args.push_back( + {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}}); + context_.fwd_primitives.push_back(*context_.eltwise_fwd); + } + + struct EltwiseFwdActivationContext context_; + +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + mutex primitive_execution_mu_; +#endif +}; + +template +class MklEltwiseFwdActivationPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklEltwiseFwdActivationPrimitive* Get( + const MklEltwiseFwdActivationParams& fwdParams) { + MklEltwiseFwdActivationPrimitive* eltwise_forward = nullptr; + + // Get a eltwise fwd primitive from the cached pool + eltwise_forward = static_cast*>( + MklEltwiseFwdActivationPrimitiveFactory::GetInstance() + .GetEltwiseFwdActivation(fwdParams)); + if (eltwise_forward == nullptr) { + eltwise_forward = new MklEltwiseFwdActivationPrimitive(fwdParams); + MklEltwiseFwdActivationPrimitiveFactory::GetInstance() + .SetEltwiseFwdActivation(fwdParams, eltwise_forward); + } + + return eltwise_forward; + } + + static MklEltwiseFwdActivationPrimitiveFactory& GetInstance() { + static MklEltwiseFwdActivationPrimitiveFactory instance_; + return instance_; + } + + private: + 
MklEltwiseFwdActivationPrimitiveFactory() {} + ~MklEltwiseFwdActivationPrimitiveFactory() {} + + static string CreateKey(const MklEltwiseFwdActivationParams& fwdParams) { + string prefix = "eltwise_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); + key_creator.AddAsKey(static_cast(fwdParams.alpha)); + key_creator.AddAsKey(static_cast(fwdParams.beta)); + return key_creator.GetKey(); + } + + MklPrimitive* GetEltwiseFwdActivation( + const MklEltwiseFwdActivationParams& fwdParams) { + string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetEltwiseFwdActivation( + const MklEltwiseFwdActivationParams& fwdParams, MklPrimitive* op) { + string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + +template +class MklEltwiseFwdActivationOpBase : public OpKernel { + public: + ~MklEltwiseFwdActivationOpBase() {} + + explicit MklEltwiseFwdActivationOpBase(OpKernelConstruction* context, + float alpha, float beta) + : OpKernel(context), alpha_(alpha), beta_(beta) {} + virtual void Compute_Scalar(OpKernelContext* context) = 0; + + void Compute(OpKernelContext* context) override { + try { + const Tensor& src_tensor = context->input(0); + TensorShape src_shape = src_tensor.shape(); + if (src_tensor.dims() == 0) { + Compute_Scalar(context); + return; + } + // Allocate output (dst) tensor + TensorShape dst_shape = src_shape; + Tensor* dst_tensor = nullptr; + // Nothing to compute, return. + if (src_shape.num_elements() == 0) { + OP_REQUIRES_OK(context, + context->allocate_output( + GetTensorDataIndex(0, context->num_outputs()), + dst_shape, &dst_tensor)); + return; + } + // Set DNN primitive - src + MklDnnData src(&cpu_engine); + memory::dims src_dims; + memory::desc src_md({}, memory::data_type::undef, + memory::format_tag::undef); + + src_dims = TFShapeToMklDnnDims(src_tensor.shape()); + auto src_strides = CalculateTFStrides(src_dims); + + // Create blocked memory descriptor + src_md = MklDnnData::CreateBlockedMemDesc(src_dims, src_strides); + +#ifdef ENABLE_ONEDNN_V3 + memory::desc dst_md = src_md; +#endif // ENABLE_ONEDNN_V3 + + // Try to get an eltwise forward primitive from caching pool + MklEltwiseFwdActivationParams fwdParams(src_dims, src_md, +#ifdef ENABLE_ONEDNN_V3 + dst_md, +#endif // ENABLE_ONEDNN_V3 + alg_kind, alpha_, beta_); + MklEltwiseFwdActivationPrimitive* eltwise_fwd = + MklEltwiseFwdActivationPrimitiveFactory::Get(fwdParams); + + const T* src_data = src_tensor.flat().data(); + + OP_REQUIRES_OK(context, context->allocate_output( + GetTensorDataIndex(0, context->num_outputs()), + dst_shape, &dst_tensor)); + + T* dst_data = dst_tensor->flat().data(); + // execute eltwise + eltwise_fwd->Execute(src_data, dst_data, context); + } catch (dnnl::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); + } + } + + private: + engine cpu_engine = engine(engine::kind::cpu, 0); + + protected: + float alpha_; + float beta_; +}; + +// TODO : Implement Eltwise bwd / eltwiseGrad class + +#undef GET_MEMORY_DESC + +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_ELTWISE_ACTIVATION_BASE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_kernel_util.h 
b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_kernel_util.h new file mode 100644 index 00000000..da600fb0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_kernel_util.h @@ -0,0 +1,135 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_KERNEL_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MKL_KERNEL_UTIL_H_ + +#ifdef INTEL_MKL + +#include "dnnl.hpp" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/public/session.h" +#include "tsl/platform/status.h" + +using dnnl::memory; + +using dnnl::memory; + +namespace tensorflow { + +class MklTestingUtil { + public: + static void RunMklQuantizeOp(const Tensor& input, const float input_min, + const float input_max, DataType type, + string mode, Tensor* output); + static void RunDequantizeOp(const Tensor& input, const Tensor& input_min, + const Tensor& input_max, string mode, + Tensor* output); + + static void RunGraph(const tensorflow::GraphDef graph_def, + const string& fetch, Tensor* output); + template + static void ComputeMinMax(const Tensor& tf_tensor, T* tensor_min, + T* tensor_max) { + auto eigen_tensor = tf_tensor.flat(); + Eigen::Tensor min = eigen_tensor.minimum(); + Eigen::Tensor max = eigen_tensor.maximum(); + *tensor_min = min(); + *tensor_max = max(); + } + + // This utility function mimics Quantization of float/bfloat16 tensor with + // oneDNN backend QuantizeV2 operation. Since the op signature requires min + // and max values to be in float type, min_tensor and max_tensor should have + // their dtype set to DT_FLOAT. + template + static Status GetQuantizationTensors(const Tensor& input, Tensor* output, + DataType out_type, const string mode, + Tensor* min_tensor, Tensor* max_tensor) { + if (min_tensor->dtype() != DT_FLOAT || max_tensor->dtype() != DT_FLOAT) { + return absl::UnimplementedError("Tensor must be float32."); + } + T min; + T max; + ComputeMinMax(input, &min, &max); + + float adjusted_min = static_cast(min); + float adjusted_max = static_cast(max); + if (mode == "SCALED") { + if (output->dtype() != DT_QINT8) { + return absl::UnimplementedError("Tensor must be QInt8 in SCALED mode."); + } + float range = std::max(std::abs(adjusted_min), std::abs(adjusted_max)); + adjusted_min = -range; + adjusted_max = range; + } + RunMklQuantizeOp(input, adjusted_min, adjusted_max, out_type, mode, output); + min_tensor->flat()(0) = adjusted_min; + max_tensor->flat()(0) = adjusted_max; + + return OkStatus(); + } +}; + +#ifdef ENABLE_ONEDNN_V3 +// Since oneDNN v3.x exposes only an opaque memory descriptor, it is no longer +// possible to cache the entire filter memory descriptor as is. So we store +// all relevant information about it in the following class. +// +// TODO(intel-tf): When oneDNN major version changes to v4.x, weight +// caching may not work as expected if the underlying memory descriptor +// has changed (i.e. 
compared to v3.x). We have to return a status here +// to catch oneDNN major version change to avoid unexpected results. +class FilterMemoryDesc { + public: + FilterMemoryDesc() {} + + explicit FilterMemoryDesc(int ndims, int inner_nblks, + memory::data_type data_type, + const memory::dims& dims, + const memory::dims& inner_blks, + const memory::dims& inner_idxs, + const memory::dims& strides) + : ndims_(ndims), + inner_nblks_(inner_nblks), + data_type_(data_type), + dims_(dims), + inner_blks_(inner_blks), + inner_idxs_(inner_idxs), + strides_(strides) {} + + ~FilterMemoryDesc() {} + + bool operator==(const FilterMemoryDesc& other) const { + return (ndims_ == other.ndims_ && inner_nblks_ == other.inner_nblks_ && + data_type_ == other.data_type_ && dims_ == other.dims_ && + inner_blks_ == other.inner_blks_ && + inner_idxs_ == other.inner_idxs_ && strides_ == other.strides_); + } + + private: + int ndims_; + int inner_nblks_; + memory::data_type data_type_; + memory::dims dims_; + memory::dims inner_blks_; + memory::dims inner_idxs_; + memory::dims strides_; +}; +#endif // ENABLE_ONEDNN_V3 +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_KERNEL_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h new file mode 100644 index 00000000..8af21582 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h @@ -0,0 +1,1219 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_ + +#if defined(INTEL_MKL) +#include +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "dnnl.hpp" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/kernels/mkl/mkl_kernel_util.h" +#include "tensorflow/core/util/mkl_util.h" +#include "tensorflow/core/util/onednn_env_vars.h" +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) +#include "tensorflow/core/platform/mutex.h" +#endif + +using dnnl::inner_product_forward; +using dnnl::primitive_attr; +using dnnl::prop_kind; +using dnnl::stream; + +namespace tensorflow { + +#ifndef ENABLE_ONEDNN_V3 +#define APPEND_ELTWISE(scale, alg, alpha, beta) \ + append_eltwise(scale, alg, alpha, beta) +#define APPEND_ELTWISE_RELU6(scale, alpha, beta) \ + append_eltwise(scale, dnnl::algorithm::eltwise_bounded_relu, alpha, beta) +#define OUTPUT_SCALE_DCHECK (post_op_param.name == "output_scale") +#define SET_MKL_LAYOUT(md) SetMklLayout(&md) +#define TSCALED_BIAS Tbias +#else +#define APPEND_ELTWISE(scale, alg, alpha, beta) \ + append_eltwise(alg, alpha, beta); \ + (void)scale +#define APPEND_ELTWISE_RELU6(scale, alpha, beta) \ + append_eltwise(dnnl::algorithm::eltwise_clip, 0.0, alpha); \ + (void)scale; \ + (void)beta +#define OUTPUT_SCALE_DCHECK \ + (post_op_param.name == "src_scale") || \ + (post_op_param.name == "wei_scale") || \ + (post_op_param.name == "dst_scale") +#define SET_MKL_LAYOUT(md) SetMklLayout(md) +#define TSCALED_BIAS float +#endif // !ENABLE_ONEDNN_V3 + +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) +#define FWD_STREAM , *fwd_stream +#else +#define FWD_STREAM +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 + +static Eigen::internal::CacheSizes cache_sizes = Eigen::internal::CacheSizes(); + +typedef Eigen::ThreadPoolDevice CPUDevice; +inline bool ExecuteSingleThreadedGemm(int64_t m, int64_t n, int64_t k, + int bytes) { + // Ideally we would like to determine blocking and then come up with + // a heuristic but what we are targeting are very small models whose + // total size is < x*L2. So we will do this simple calculation + // to determine if the matrix multiplication should be run on a single thread. + // TODO(Intel-tf): this needs to be vastly improved, perhaps at a lower level + // than the integration. + ptrdiff_t l2_size = cache_sizes.m_l2; + constexpr float kHeuristicMultiplier = 1.01; + const float mul_size = bytes * (m * n + k * (m + n)); + const float l2_heur = l2_size * kHeuristicMultiplier; + return (mul_size >= 0 && mul_size < l2_heur); +} + +// This structure aggregates multiple inputs to MklDnnMatMul* methods. 
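[Editor's note] ExecuteSingleThreadedGemm above routes very small matmuls to a single thread when the A, B and C buffers fit comfortably in L2, using the byte count bytes * (m*n + k*(m+n)) against 1.01 * L2. A standalone sketch of the same arithmetic, with the L2 size passed in explicitly instead of read from Eigen's CacheSizes; FitsInL2 and the 1 MiB figure are illustrative assumptions.

#include <cstdint>
#include <iostream>

// Returns true when an m-by-k times k-by-n product (element size 'bytes')
// is small enough that threading overhead likely outweighs the benefit.
bool FitsInL2(int64_t m, int64_t n, int64_t k, int bytes, int64_t l2_bytes) {
  constexpr double kHeuristicMultiplier = 1.01;
  const double working_set =
      static_cast<double>(bytes) * (m * n + k * (m + n));
  return working_set >= 0 && working_set < l2_bytes * kHeuristicMultiplier;
}

int main() {
  const int64_t l2 = 1 << 20;  // assume a 1 MiB L2 cache for illustration
  std::cout << FitsInL2(64, 64, 64, /*bytes=*/4, l2) << "\n";    // 1: ~48 KiB
  std::cout << FitsInL2(1024, 1024, 1024, 4, l2) << "\n";        // 0: ~12 MiB
  return 0;
}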
+struct MklDnnMatMulFwdParams { + memory::dims src_dims; + memory::dims weight_dims; + memory::dims bias_dims; + memory::dims dst_dims; + memory::format_tag src_format; + memory::format_tag weight_format; + memory::format_tag dst_format; + string dtypes = string(""); + bool const_weight; + struct PostOpParam { + string name; + std::vector param; + string partial_key; + }; + std::vector post_op_params; + string input_quant_mode; + + MklDnnMatMulFwdParams( + memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims, + memory::dims dst_dims, + memory::format_tag src_format = memory::format_tag::any, + memory::format_tag weight_format = memory::format_tag::any, + memory::format_tag dst_format = memory::format_tag::any, + bool const_weight = false) + : src_dims(src_dims), + weight_dims(weight_dims), + bias_dims(bias_dims), + dst_dims(dst_dims), + src_format(src_format), + weight_format(weight_format), + dst_format(dst_format), + const_weight(const_weight) {} +}; + +// With quantization, input, weight, bias, and output can have different types. +// So we use different template parameters for each type. +// TODO(intel-tf): The template type "T" is currently used to match the +// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h). +// In the future, with the removal of "T" from MklPrimitiveFactory, this class +// needs to drop "T". +template +class MklDnnMatMulFwdPrimitive : public MklPrimitive { + public: + explicit MklDnnMatMulFwdPrimitive( + const MklDnnMatMulFwdParams& matmulFwdParams) + : MklPrimitive(engine(engine::kind::cpu, 0)) { + // Create matmul primitive + if (context_.matmul_fwd == nullptr) { + Setup(matmulFwdParams); + } + } + + ~MklDnnMatMulFwdPrimitive() {} + + dnnl::memory::desc GetScratchPadDesc() { + return context_.fwd_pd->scratchpad_desc(); + } + + // Inner-product forward execute with bias: + // - src_data: input data buffer of src + // - weight_data: input data buffer of weight + // - bias_data: input data buffer of bias + // - dst_data: output data buffer of dst + // - sp_data: scratchpad data + void Execute(const Tinput* src_data, const Tweight* weight_data, + const void* bias_data, Toutput* dst_data, + const MklDnnMatMulFwdParams& matmul_fwd_params, void* sp_data, + std::shared_ptr fwd_stream) { +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + mutex_lock lock(primitive_execution_mu_); +#endif + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)) FWD_STREAM); + context_.weight_mem->set_data_handle( + static_cast(const_cast(weight_data)) FWD_STREAM); + context_.bias_mem->set_data_handle(const_cast(bias_data) FWD_STREAM); + context_.dst_mem->set_data_handle(static_cast(dst_data) FWD_STREAM); + context_.sp_mem->set_data_handle(sp_data FWD_STREAM); + auto const& post_op_params = matmul_fwd_params.post_op_params; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "src_scale") { + context_.src_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data())) FWD_STREAM); + } else if (post_op_param.name == "wei_scale") { + context_.wei_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data())) FWD_STREAM); + } else if (post_op_param.name == "dst_scale") { + context_.dst_scale_mem->set_data_handle(static_cast( + const_cast(post_op_param.param.data())) FWD_STREAM); + } + } + } + + execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); + + // After execution, set data handle back + 
context_.src_mem->set_data_handle(DummyData); + context_.weight_mem->set_data_handle(DummyData); + context_.bias_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + } + + std::shared_ptr + GetPrimitiveDesc() const { + return context_.fwd_pd; + } + + private: + // Primitive reuse context for inner-product Fwd op + struct MklDnnMatMulFwdContext { + // oneDNN memory. + std::shared_ptr src_mem; + std::shared_ptr weight_mem; + std::shared_ptr bias_mem; + std::shared_ptr dst_mem; + std::shared_ptr sp_mem; + // Quantization scale related memory + std::shared_ptr src_scale_mem; + std::shared_ptr wei_scale_mem; + std::shared_ptr dst_scale_mem; + + // Descriptor and primitive-descriptor for forward inner-product. +#ifndef ENABLE_ONEDNN_V3 + std::shared_ptr fwd_desc; +#endif // !ENABLE_ONEDNN_V3 + std::shared_ptr fwd_pd; + + // Memory descriptors. + std::shared_ptr src_md; + std::shared_ptr weight_md; + std::shared_ptr bias_md; + std::shared_ptr dst_md; + // Quantization scale related memory descriptors + std::shared_ptr src_scale_md; + std::shared_ptr wei_scale_md; + std::shared_ptr dst_scale_md; + + // Inner-product primitive. + std::shared_ptr matmul_fwd; + std::vector fwd_primitives; + + std::vector> net_args; + + MklDnnMatMulFwdContext() + : src_mem(nullptr), + weight_mem(nullptr), + bias_mem(nullptr), + dst_mem(nullptr), + sp_mem(nullptr), + src_scale_mem(nullptr), + wei_scale_mem(nullptr), + dst_scale_mem(nullptr), +#ifndef ENABLE_ONEDNN_V3 + fwd_desc(nullptr), +#endif // ENABLE_ONEDNN_V3 + fwd_pd(nullptr), + src_md(nullptr), + weight_md(nullptr), + bias_md(nullptr), + dst_md(nullptr), + src_scale_md(nullptr), + wei_scale_md(nullptr), + dst_scale_md(nullptr), + matmul_fwd(nullptr) { + } + }; + + void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) { + // Create memory descriptors for inner-product data without specified + // format. + context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims}, + MklDnnType(), + matmul_fwd_params.src_format)); + + context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims}, + MklDnnType(), +#ifdef DNNL_AARCH64_USE_ACL + memory::format_tag::any)); +#else + matmul_fwd_params.weight_format)); +#endif + + context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims}, + MklDnnType(), + matmul_fwd_params.dst_format)); + + memory::data_type bias_dt; +#ifndef ENABLE_ONEDNN_V3 + bias_dt = MklDnnType(); +#else + if (std::is_same::value) { + // For QuantizedMatMul, bias needs to be passed to oneDNN as float of + // bfloat16 (even if Tbias is qint32). + if (std::is_same::value && + matmul_fwd_params.input_quant_mode == "SCALED") { + bias_dt = MklDnnType(); + } else { + bias_dt = MklDnnType(); + } + } else { + bias_dt = MklDnnType(); + } +#endif // !ENABLE_ONEDNN_V3 + context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims}, + bias_dt, memory::format_tag::any)); + + // Create an inner-product. +#ifndef ENABLE_ONEDNN_V3 + context_.fwd_desc.reset(new inner_product_forward::desc( + matmul_fwd_params.const_weight ? 
prop_kind::forward_inference + : prop_kind::forward_training, + *context_.src_md, *context_.weight_md, *context_.bias_md, + *context_.dst_md)); + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); +#endif // !ENABLE_ONEDNN_V3 + + // Check if there is any fusion as post-ops + auto const& post_op_params = matmul_fwd_params.post_op_params; + dnnl::primitive_attr post_ops_attr; + post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + dnnl::post_ops post_ops; + std::unordered_map is_scale_set; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "Relu" || post_op_param.name == "LeakyRelu") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_relu, + op_alpha, op_beta); + } else if (post_op_param.name == "Relu6") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE_RELU6(op_scale, op_alpha, op_beta); + } else if (post_op_param.name == "Elu") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_elu, + op_alpha, op_beta); + } else if (post_op_param.name == "GeluApproximate") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_gelu_tanh, + op_alpha, op_beta); + } else if (post_op_param.name == "GeluExact") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_gelu_erf, + op_alpha, op_beta); + } else if (post_op_param.name == "Tanh") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_tanh, + op_alpha, op_beta); + } else if (post_op_param.name == "Sigmoid") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_logistic, + op_alpha, op_beta); + } else if (post_op_param.name == "linear") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.APPEND_ELTWISE(op_scale, dnnl::algorithm::eltwise_linear, + op_alpha, op_beta); +#ifndef ENABLE_ONEDNN_V3 + } else if (post_op_param.name == "output_scale") { + if (post_op_param.param.size() == 1) { + post_ops_attr.set_output_scales(0, post_op_param.param); + } else { + post_ops_attr.set_output_scales(2, post_op_param.param); + } +#else + } else if (post_op_param.name == "src_scale") { + is_scale_set.insert({"src", true}); + post_ops_attr.set_scales_mask(DNNL_ARG_SRC, 0); + 
context_.src_scale_md.reset(new memory::desc({1}, MklDnnType(), + memory::format_tag::x)); + context_.src_scale_mem.reset( + new memory(*context_.src_scale_md, cpu_engine_, DummyData)); + } else if (post_op_param.name == "wei_scale") { + is_scale_set.insert({"wei", true}); + const int scale_size = post_op_param.param.size(); + const int mask = scale_size == 1 ? 0 : 1; + post_ops_attr.set_scales_mask(DNNL_ARG_WEIGHTS, mask); + context_.wei_scale_md.reset(new memory::desc( + {scale_size}, MklDnnType(), memory::format_tag::x)); + context_.wei_scale_mem.reset( + new memory(*context_.wei_scale_md, cpu_engine_, DummyData)); + } else if (post_op_param.name == "dst_scale") { + is_scale_set.insert({"dst", true}); + const int scale_size = post_op_param.param.size(); + const int mask = scale_size == 1 ? 0 : 1; + post_ops_attr.set_scales_mask(DNNL_ARG_DST, mask); + context_.dst_scale_md.reset(new memory::desc({1}, MklDnnType(), + memory::format_tag::x)); + context_.dst_scale_mem.reset( + new memory(*context_.dst_scale_md, cpu_engine_, DummyData)); +#endif // !ENABLE_ONEDNN_V3 + } else if (post_op_param.name == "sum") { + DCHECK_EQ(post_op_param.param.size(), 1); + float op_scale = post_op_param.param[0]; + post_ops.append_sum(op_scale); + + } else { + DCHECK((post_op_param.name == "Relu") || + (post_op_param.name == "Relu6") || + (post_op_param.name == "Elu") || + (post_op_param.name == "GeluApproximate") || + (post_op_param.name == "GeluExact") || + (post_op_param.name == "Tanh") || + (post_op_param.name == "Sigmoid") || + (post_op_param.name == "sum") || + (post_op_param.name == "Leakyrelu") || OUTPUT_SCALE_DCHECK); + } + } + post_ops_attr.set_post_ops(post_ops); + } + +#ifndef ENABLE_ONEDNN_V3 + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, post_ops_attr, cpu_engine_)); +#else + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + cpu_engine_, + matmul_fwd_params.const_weight ? prop_kind::forward_inference + : prop_kind::forward_training, + *context_.src_md, *context_.weight_md, *context_.bias_md, + *context_.dst_md, post_ops_attr)); +#endif // !ENABLE_ONEDNN_V3 + + // Create memory primitive based on dummy data + context_.src_mem.reset( + new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData)); + context_.weight_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(), + cpu_engine_, DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData)); + context_.bias_mem.reset( + new memory(context_.fwd_pd.get()->bias_desc(), cpu_engine_, DummyData)); + auto scratchpad_md = context_.fwd_pd->scratchpad_desc(); + context_.sp_mem.reset( + new dnnl::memory(scratchpad_md, cpu_engine_, DummyData)); + + // Create inner-product primitive. 
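[Editor's note] Setup() above maps fusion names such as "Relu", "Relu6", "Elu", "Tanh", "Sigmoid" and "linear" onto oneDNN eltwise post-ops applied to the inner-product output (under ENABLE_ONEDNN_V3, "Relu6" is expressed as a clip, which is what APPEND_ELTWISE_RELU6 expands to). The sketch below only illustrates what a subset of those fused activations compute elementwise; ApplyPostOp is a hypothetical name and not part of oneDNN.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <string>

// Elementwise activation matching the fused post-op with the given name.
// alpha/beta carry the same meaning as the post-op parameters above.
double ApplyPostOp(const std::string& name, double x, double alpha,
                   double beta) {
  if (name == "Relu")    return x > 0 ? x : alpha * x;  // LeakyRelu if alpha != 0
  if (name == "Relu6")   return std::min(std::max(x, 0.0), 6.0);  // clip(0, 6)
  if (name == "Elu")     return x > 0 ? x : alpha * (std::exp(x) - 1.0);
  if (name == "Tanh")    return std::tanh(x);
  if (name == "Sigmoid") return 1.0 / (1.0 + std::exp(-x));
  if (name == "linear")  return alpha * x + beta;
  return x;  // unknown fusion: identity in this sketch
}

int main() {
  std::cout << ApplyPostOp("Relu6", 7.5, 0, 0) << "\n";       // 6
  std::cout << ApplyPostOp("linear", 2.0, 3.0, 1.0) << "\n";  // 7
  return 0;
}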
+ context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd)); + std::unordered_map net_args = { + {DNNL_ARG_SRC, *context_.src_mem}, + {DNNL_ARG_WEIGHTS, *context_.weight_mem}, + {DNNL_ARG_BIAS, *context_.bias_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.dst_mem}}; +#ifdef ENABLE_ONEDNN_V3 + if (is_scale_set["src"]) { + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, *context_.src_scale_mem}); + } + if (is_scale_set["wei"]) { + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, *context_.wei_scale_mem}); + } + if (is_scale_set["dst"]) { + net_args.insert( + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, *context_.dst_scale_mem}); + } +#endif // ENABLE_ONEDNN_V3 + context_.net_args.push_back(net_args); + context_.fwd_primitives.push_back(*context_.matmul_fwd); + return; + } + + struct MklDnnMatMulFwdContext context_; + +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + // Guards Execution() + mutex primitive_execution_mu_; +#endif +}; + +template +class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklDnnMatMulFwdPrimitive* Get( + const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) { + MklDnnMatMulFwdPrimitive* matmul_fwd = + nullptr; + + if (do_not_cache) { + // Always create new primitive + matmul_fwd = + new MklDnnMatMulFwdPrimitive( + mkldnn_matmul_fwd_dims); + } else { + // Try to find a suitable one in pool + matmul_fwd = dynamic_cast< + MklDnnMatMulFwdPrimitive*>( + MklDnnMatMulFwdPrimitiveFactory::GetInstance() + .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims)); + if (matmul_fwd == nullptr) { + matmul_fwd = + new MklDnnMatMulFwdPrimitive( + mkldnn_matmul_fwd_dims); + MklDnnMatMulFwdPrimitiveFactory::GetInstance() + .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd); + } + } + return matmul_fwd; + } + + private: + MklDnnMatMulFwdPrimitiveFactory() {} + ~MklDnnMatMulFwdPrimitiveFactory() {} + + static MklDnnMatMulFwdPrimitiveFactory& GetInstance() { + static MklDnnMatMulFwdPrimitiveFactory instance_; + return instance_; + } + + static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { + string prefix = "matmul_fwd_"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format); + + // Generate keys for post-ops + for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) { + if (post_op_param.name == "Relu" || post_op_param.name == "Relu6" || + post_op_param.name == "Elu" || post_op_param.name == "Tanh" || + post_op_param.name == "Sigmoid" || + post_op_param.name == "LeakyRelu" || + post_op_param.name == "GeluApproximate" || + post_op_param.name == "GeluExact" || post_op_param.name == "linear") { + DCHECK_EQ(post_op_param.param.size(), 3); + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.param[0]); + key_creator.AddAsKey(post_op_param.param[1]); + key_creator.AddAsKey(post_op_param.param[2]); + } else if (post_op_param.name == "sum") { + DCHECK_EQ(post_op_param.param.size(), 1); + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.param[0]); +#ifndef ENABLE_ONEDNN_V3 + } else if (post_op_param.name == "output_scale") { 
+#else + } else if (post_op_param.name == "src_scale" || + post_op_param.name == "wei_scale" || + post_op_param.name == "dst_scale") { +#endif // !ENABLE_ONEDNN_V3 + key_creator.AddAsKey(post_op_param.name); + if (post_op_param.partial_key.empty()) { + DCHECK_GE(post_op_param.param.size(), 1); + // Old Quantized MatMul kernels do not create part of key beforehand + // as primitive caching-key-creation optimization. + key_creator.AddAsKey(post_op_param.param[0]); + } else { + // New Quantized MatMul kernels pre-create partial key. + key_creator.AddAsKey(post_op_param.partial_key); + } + } else { + return string("not_a_key"); + } + } + return key_creator.GetKey(); + } + + MklPrimitive* GetMklDnnMatMulFwd( + const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { + string key = CreateKey(mkldnn_matmul_fwd_dims); + return this->GetOp(key); + } + + void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, + MklPrimitive* op) { + string key = CreateKey(mkldnn_matmul_fwd_dims); + this->SetOp(key, op); + } +}; + +template +class MklDnnMatMulOpBase : public OpKernel { + public: + explicit MklDnnMatMulOpBase(OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(OpKernelContext* context) override = 0; + + // Allocate output tensor. + virtual void AllocateOutputTensor( + OpKernelContext* context, + const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc, + const memory::dims& output_dims_mkl_order, + MklTensorFormat output_tf_format, Tensor** output_tensor, + bool native_format = false) { + DCHECK(output_tensor); + auto dst_pd = mkldnn_matmul_prim_desc.dst_desc(); + + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SET_MKL_LAYOUT(dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + + TensorShape output_tf_shape; + output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput))); + + if (native_format) { + output_tf_shape = output_mkl_shape.GetTfShape(); + } + // Allocate Output Tensor + AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor, + output_tf_shape, output_mkl_shape, native_format); + } + + // TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot + // be acquired before entering the function, since it is acquired + // inside the function. + inline bool IsWeightCacheEmpty(OpKernelContext* context) + TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock lock(mu_); + return (weight_oi_.NumElements() == 0); + } + + // Cache the converted weight in a tensor. + // Only one thread can execute this method at any given time. 
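[Editor's note] CreateKey() above serializes shapes, dtypes and post-op parameters into a string, and the factory keeps exactly one compiled primitive per distinct key (GetOp/SetOp on the base factory). A minimal standalone sketch of that caching pattern, with KeyBuilder, Primitive and PrimitiveFactory as simplified stand-ins for FactoryKeyCreator and the MklPrimitiveFactory machinery.

#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for FactoryKeyCreator: joins fields with delimiters.
class KeyBuilder {
 public:
  template <typename T>
  void Add(const T& v) { os_ << v << ':'; }
  void Add(const std::vector<int64_t>& dims) {
    for (auto d : dims) os_ << d << 'x';
    os_ << ':';
  }
  std::string Key() const { return os_.str(); }
 private:
  std::ostringstream os_;
};

struct Primitive {  // placeholder for a compiled oneDNN primitive
  explicit Primitive(std::string k) : key(std::move(k)) {}
  std::string key;
};

// One primitive per distinct key; reused by ops with identical parameters.
class PrimitiveFactory {
 public:
  std::shared_ptr<Primitive> Get(const std::string& key) {
    auto it = pool_.find(key);
    if (it != pool_.end()) return it->second;      // cache hit
    auto prim = std::make_shared<Primitive>(key);  // build once
    pool_.emplace(key, prim);
    return prim;
  }
 private:
  std::unordered_map<std::string, std::shared_ptr<Primitive>> pool_;
};

int main() {
  KeyBuilder kb;
  kb.Add(std::string("matmul_fwd_"));
  kb.Add(std::vector<int64_t>{32, 64});   // src dims
  kb.Add(std::vector<int64_t>{64, 128});  // weight dims
  PrimitiveFactory factory;
  auto a = factory.Get(kb.Key());
  auto b = factory.Get(kb.Key());
  std::cout << (a == b) << "\n";  // 1: second lookup reuses the cached object
  return 0;
}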
+ void CacheWeight( + OpKernelContext* context, + const std::shared_ptr& + matmul_fwd_pd, + Tweight* weight_data, const Tensor& weight_tensor, + MklDnnData& weight, const memory::desc& weight_md) + TF_LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + const Tensor& weight_t = weight_oi_; + + // If the weights are already cached, there's nothing to do + if (weight_t.NumElements() > 0) { + return; + } + +#ifdef ENABLE_ONEDNN_V3 + // For now, cache weights only for blocked format + if (weight_md.get_format_kind() != memory::format_kind::blocked) { + return; + } +#endif // ENABLE_ONEDNN_V3 + + // reorder and cache the weight + weight.SetUsrMem(weight_md, &weight_tensor); + weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(), cpu_engine_, + context); + weight_data = static_cast(weight.GetOpMem().get_data_handle()); + + size_t weight_size = matmul_fwd_pd.get()->weights_desc().get_size(); + TensorShape weight_tf_shape; + weight_tf_shape.AddDim(weight_size / sizeof(Tweight)); + + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + weight_tf_shape, &weight_oi_)); + + void* weight_oi_t_data = weight.GetTensorBuffer(&weight_oi_); + memcpy(weight_oi_t_data, weight_data, weight_size); + + // cache the memory descriptor + auto expected_md = matmul_fwd_pd->weights_desc(); +#ifndef ENABLE_ONEDNN_V3 + TensorShape weight_mkl_format; + weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight)); + + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + weight_mkl_format, &weight_oi_md_)); + *reinterpret_cast(weight_oi_md_.flat().data()) = + expected_md; +#else + weight_oi_md_ = FilterMemoryDesc( + expected_md.get_ndims(), expected_md.get_inner_nblks(), + expected_md.get_data_type(), expected_md.get_dims(), + expected_md.get_inner_blks(), expected_md.get_inner_idxs(), + expected_md.get_strides()); +#endif // !ENABLE_ONEDNN_V3 + } + + Tweight* GetCachedWeight(OpKernelContext* context, + const memory::desc& expected_md) + TF_LOCKS_EXCLUDED(mu_) { + tf_shared_lock lock(mu_); + const Tensor& weight_t = weight_oi_; +#ifndef ENABLE_ONEDNN_V3 + const Tensor& weight_md_t = weight_oi_md_; + + // Check if the memory descriptor of the cached weight is same as + // expected_md. if so use the cached memory, else return NULL + if (weight_md_t.flat().size()) { + const memory::desc& stored_md = + *(static_cast(weight_md_t.data())); + if (stored_md == expected_md) { + return static_cast( + const_cast(weight_t.flat().data())); + } + } + return nullptr; +#else + // Return the cached weights only if the dimensions of the cached weights + // and the current weights match. Otherwise, return nullptr. + // + // TODO(intel-tf): The following check assumes that all dimensions are + // known before checking for equality. We may have to modify it in the + // future once we support runtime dimensions (especially if the dimensions + // are still unknown at this point). 
+ if (weight_oi_md_ == + FilterMemoryDesc(expected_md.get_ndims(), expected_md.get_inner_nblks(), + expected_md.get_data_type(), expected_md.get_dims(), + expected_md.get_inner_blks(), + expected_md.get_inner_idxs(), + expected_md.get_strides())) { + return static_cast( + const_cast(weight_t.flat().data())); + } + return nullptr; +#endif // !ENABLE_ONEDNN_V3 + } + + bool IsBiasCacheEmpty() TF_LOCKS_EXCLUDED(bias_cache_mutex_) { + tf_shared_lock lock(bias_cache_mutex_); + return (cached_bias_data_pt_.NumElements() == 0); + } + + virtual bool IsCachedBiasValid(float, float) + TF_SHARED_LOCKS_REQUIRED(bias_cache_mutex_) { + return false; + } + + void CacheBias(OpKernelContext* ctx, const Tensor& temp_scaled_bias_tensor, + float min_input, float max_input) + TF_LOCKS_EXCLUDED(bias_cache_mutex_) { + mutex_lock lock(bias_cache_mutex_); + if (cached_bias_data_pt_.NumElements() > 0) { + return; + } + OP_REQUIRES_OK(ctx, ctx->allocate_temp(temp_scaled_bias_tensor.dtype(), + temp_scaled_bias_tensor.shape(), + &cached_bias_data_pt_)); + tensor::DeepCopy(temp_scaled_bias_tensor, &cached_bias_data_pt_); + saved_min_input_ = min_input; + saved_max_input_ = max_input; + } + + void GetCachedBias(float min_input, float max_input, void** bias_data) + TF_LOCKS_EXCLUDED(bias_cache_mutex_) { + tf_shared_lock lock(bias_cache_mutex_); + const Tensor& cached_bias_data = cached_bias_data_pt_; + if (IsCachedBiasValid(min_input, max_input)) { + *bias_data = static_cast(const_cast( + cached_bias_data.flat().data())); + } else { + *bias_data = nullptr; + } + } + + engine cpu_engine_ = engine(engine::kind::cpu, 0); + + protected: + // Tensor to save reordered weight + mutex mu_; + Tensor weight_oi_ TF_GUARDED_BY(mu_); +#ifndef ENABLE_ONEDNN_V3 + Tensor weight_oi_md_ TF_GUARDED_BY(mu_); +#else + FilterMemoryDesc weight_oi_md_ TF_GUARDED_BY(mu_); +#endif // !ENABLE_ONEDNN_V3 + + bool is_weight_const_; + + bool is_bias_const_; + mutex bias_cache_mutex_; + // Persistent tensor for cached bias. 
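[Editor's note] The weight and bias caches above take a shared lock for the cheap "is it cached?" and lookup paths and an exclusive lock for the one-time population (CacheWeight/CacheBias), so repeated executions reuse the reordered data. Below is a standalone sketch of that locking pattern using std::shared_mutex in place of TF's mutex, tf_shared_lock and mutex_lock; OneTimeCache is an illustrative name only.

#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <vector>

// Caches an expensively prepared buffer once; later calls reuse it.
class OneTimeCache {
 public:
  bool Empty() const {
    std::shared_lock<std::shared_mutex> lock(mu_);  // cheap, concurrent reads
    return cached_.empty();
  }
  void Fill(const std::vector<float>& prepared) {
    std::unique_lock<std::shared_mutex> lock(mu_);  // exclusive, one writer
    if (!cached_.empty()) return;                   // another thread won the race
    cached_ = prepared;
  }
  const float* Get() const {
    std::shared_lock<std::shared_mutex> lock(mu_);
    return cached_.empty() ? nullptr : cached_.data();
  }
 private:
  mutable std::shared_mutex mu_;
  std::vector<float> cached_;
};

int main() {
  OneTimeCache cache;
  if (cache.Empty()) cache.Fill({1.f, 2.f, 3.f});  // e.g. reordered weights
  std::cout << (cache.Get() != nullptr) << "\n";   // 1
  return 0;
}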
+ Tensor cached_bias_data_pt_ TF_GUARDED_BY(bias_cache_mutex_); + float saved_min_input_ = -std::numeric_limits::infinity(); + float saved_max_input_ = std::numeric_limits::infinity(); + + const int kInputIndexSrc = 0; + const int kInputIndexWeight = 1; + const int kInputIndexBias = 2; + const int kOutputIndexDst = 0; +}; + +using dnnl::matmul; + +namespace { + +struct MklMatMulParams { + string prefix; + memory::dims a_dims; + memory::dims b_dims; + memory::dims c_dims; + memory::dims a_strides; + memory::dims b_strides; + memory::dims c_strides; + memory::dim a_nnz; + struct PostOpParam { + string name; + std::vector param; + memory::dims dims; + memory::data_type data_type; + memory::format_tag format_tag; + }; + std::vector post_op_params; + + MklMatMulParams(string prefix, memory::dims a_dims, memory::dims b_dims, + memory::dims c_dims, memory::dims a_strides, + memory::dims b_strides, memory::dims c_strides, + memory::dim a_nnz = 0) + : prefix(prefix), + a_dims(a_dims), + b_dims(b_dims), + c_dims(c_dims), + a_strides(a_strides), + b_strides(b_strides), + c_strides(c_strides), + a_nnz(a_nnz) {} +}; + +template +class MklMatMulPrimitive : public MklPrimitive { + public: + explicit MklMatMulPrimitive(const MklMatMulParams& params) + : MklPrimitive(engine(engine::kind::cpu, 0)) { + // Create matmul primitive + Setup(params); + } + + ~MklMatMulPrimitive() {} + + dnnl::memory::desc GetScratchPadDesc() { + return context_.prim_desc->scratchpad_desc(); + } + + void Execute(const std::shared_ptr& stream, const Tlhs* a_data, + const Trhs* b_data, const Toutput* c_data, void* sp_data, + void* mul_data = nullptr, void* add_data = nullptr, + const int32_t* a_col_indices = nullptr, + const int32_t* a_row_pointers = nullptr) { +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + mutex_lock lock(primitive_execution_mu_); +#endif +#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3) + context_.a_mem->set_data_handle( + static_cast(const_cast(a_data)), *stream); + context_.b_mem->set_data_handle( + static_cast(const_cast(b_data)), *stream); + context_.c_mem->set_data_handle( + static_cast(const_cast(c_data)), *stream); + + if (sp_data != nullptr) context_.sp_mem->set_data_handle(sp_data, *stream); + if (mul_data != nullptr) + context_.mul_mem->set_data_handle(mul_data, *stream); + if (add_data != nullptr) + context_.add_mem->set_data_handle(add_data, *stream); +#else + if constexpr (CSR) { + context_.a_mem->set_data_handle( + static_cast(const_cast(a_data)), 0); + context_.a_mem->set_data_handle( + static_cast(const_cast(a_col_indices)), 1); + context_.a_mem->set_data_handle( + static_cast(const_cast(a_row_pointers)), 2); + } else { + context_.a_mem->set_data_handle( + static_cast(const_cast(a_data))); + } + context_.b_mem->set_data_handle( + static_cast(const_cast(b_data))); + context_.c_mem->set_data_handle( + static_cast(const_cast(c_data))); + if (sp_data != nullptr) context_.sp_mem->set_data_handle(sp_data); + if (mul_data != nullptr) context_.mul_mem->set_data_handle(mul_data); + if (add_data != nullptr) context_.add_mem->set_data_handle(add_data); +#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3 + execute_primitives(context_.matmul_primitives, stream, context_.net_args); + + // After execution, set data handle back + context_.a_mem->set_data_handle(DummyData); + context_.b_mem->set_data_handle(DummyData); + context_.c_mem->set_data_handle(DummyData); + if (sp_data != nullptr) context_.sp_mem->set_data_handle(DummyData); + if (mul_data != nullptr) 
context_.mul_mem->set_data_handle(DummyData); + if (add_data != nullptr) context_.add_mem->set_data_handle(DummyData); + } + + std::shared_ptr GetPrimitiveDesc() const { + return context_.prim_desc; + } + + private: + // Primitive reuse context for MatMul op + struct MklMatMulContext { + // oneDNN memory. + std::shared_ptr a_mem; + std::shared_ptr b_mem; + std::shared_ptr c_mem; + std::shared_ptr mul_mem; + std::shared_ptr add_mem; + std::shared_ptr sp_mem; + + // Descriptor and primitive-descriptor for MatMul. +#ifndef ENABLE_ONEDNN_V3 + std::shared_ptr desc; +#endif // !ENABLE_ONEDNN_V3 + std::shared_ptr prim_desc; + + // Memory descriptors. + std::shared_ptr a_md; + std::shared_ptr b_md; + std::shared_ptr c_md; + std::shared_ptr mul_md; + std::shared_ptr add_md; + + // MatMul primitive. + std::vector matmul_primitives; + std::vector> net_args; + + MklMatMulContext() + : a_mem(nullptr), + b_mem(nullptr), + c_mem(nullptr), + mul_mem(nullptr), + add_mem(nullptr), + sp_mem(nullptr), +#ifndef ENABLE_ONEDNN_V3 + desc(nullptr), +#endif // !ENABLE_ONEDNN_V3 + prim_desc(nullptr), + a_md(nullptr), + b_md(nullptr), + c_md(nullptr), + mul_md(nullptr), + add_md(nullptr) { + } + }; + + void Setup(const MklMatMulParams& params) { + std::shared_ptr matmul_primitive = nullptr; + + // Create MatMul descriptor and primitive descriptor. + if constexpr (CSR) { + // If it's a CSR matrix. +#ifdef ENABLE_ONEDNN_V3 + const auto tmp = memory::desc::csr( + params.a_dims, MklDnnType(), params.a_nnz, + dnnl::memory::data_type::s32, dnnl::memory::data_type::s32); + context_.a_md.reset(new memory::desc(tmp)); +#endif // ENABLE_ONEDNN_V3 + } else { + context_.a_md.reset(new memory::desc({params.a_dims}, MklDnnType(), + params.a_strides)); + } + + context_.b_md.reset(new memory::desc({params.b_dims}, MklDnnType(), +#ifdef DNNL_AARCH64_USE_ACL + memory::format_tag::any)); +#else + params.b_strides)); +#endif + context_.c_md.reset(new memory::desc({params.c_dims}, MklDnnType(), + params.c_strides)); + + // Create matmul. +#ifndef ENABLE_ONEDNN_V3 + context_.desc.reset( + new matmul::desc(*context_.a_md, *context_.b_md, *context_.c_md)); +#endif // !ENABLE_ONEDNN_V3 + + // Check if there is any fusion as post-ops + auto const& post_op_params = params.post_op_params; + dnnl::primitive_attr post_ops_attr; + dnnl::post_ops post_ops; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "output_scale") { +#ifndef ENABLE_ONEDNN_V3 + // TODO(intel-tf): Verify if this code is needed. If not, it needs to + // be removed. 
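[Editor's note] When the CSR template flag is set, Setup() above describes the A operand in compressed sparse row form (values plus s32 column indices and row pointers, with params.a_nnz non-zeros), and Execute() binds those three buffers to the one sparse memory object. The sketch below only builds a CSR view of a dense matrix to make the three-buffer layout concrete; DenseToCsr is a hypothetical helper, not oneDNN API.

#include <cstdint>
#include <iostream>
#include <vector>

// Builds CSR arrays (values, column indices, row pointers) for a dense
// row-major 'rows' x 'cols' matrix. row_ptrs ends up with rows+1 entries.
void DenseToCsr(const std::vector<float>& dense, int rows, int cols,
                std::vector<float>* values, std::vector<int32_t>* col_indices,
                std::vector<int32_t>* row_ptrs) {
  row_ptrs->assign(1, 0);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const float v = dense[r * cols + c];
      if (v != 0.0f) {
        values->push_back(v);
        col_indices->push_back(c);
      }
    }
    row_ptrs->push_back(static_cast<int32_t>(values->size()));
  }
}

int main() {
  // 2x3 matrix with 3 non-zeros (a_nnz == 3):
  //   [1 0 2]
  //   [0 0 3]
  std::vector<float> values;
  std::vector<int32_t> cols, rows;
  DenseToCsr({1, 0, 2, 0, 0, 3}, 2, 3, &values, &cols, &rows);
  std::cout << values.size() << "\n";  // 3
  std::cout << rows.back() << "\n";    // 3 (row_ptrs = {0, 2, 3})
  return 0;
}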
+ DCHECK_EQ(post_op_param.param.size(), 1); + std::vector scales; + scales.push_back(post_op_param.param[0]); + post_ops_attr.set_output_scales(0, scales); +#endif // !ENABLE_ONEDNN_V3 + } else if (post_op_param.name == "mul") { + context_.mul_md.reset(new memory::desc({post_op_param.dims}, + post_op_param.data_type, + post_op_param.format_tag)); + post_ops.append_binary(dnnl::algorithm::binary_mul, *context_.mul_md); + } else if (post_op_param.name == "add") { + context_.add_md.reset(new memory::desc({post_op_param.dims}, + post_op_param.data_type, + post_op_param.format_tag)); + post_ops.append_binary(dnnl::algorithm::binary_add, *context_.add_md); + } else { + DCHECK((post_op_param.name == "output_scale")); + } + } + post_ops_attr.set_post_ops(post_ops); + } + post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); +#ifndef ENABLE_ONEDNN_V3 + context_.prim_desc.reset( + new matmul::primitive_desc(*context_.desc, post_ops_attr, cpu_engine_)); +#else + context_.prim_desc.reset( + new matmul::primitive_desc(cpu_engine_, *context_.a_md, *context_.b_md, + *context_.c_md, post_ops_attr)); +#endif // !ENABLE_ONEDNN_V3 + + // Create memory primitive based on dummy data. + if constexpr (CSR) { + context_.a_mem.reset(new dnnl::memory(*context_.a_md, cpu_engine_, + std::vector(3, DummyData))); + } else { + context_.a_mem.reset( + new dnnl::memory(*context_.a_md, cpu_engine_, DummyData)); + } +#ifdef DNNL_AARCH64_USE_ACL + context_.b_mem.reset(new dnnl::memory( + context_.prim_desc.get()->weights_desc(), cpu_engine_, DummyData)); +#else + context_.b_mem.reset( + new dnnl::memory(*context_.b_md, cpu_engine_, DummyData)); +#endif + context_.c_mem.reset( + new dnnl::memory(*context_.c_md, cpu_engine_, DummyData)); + auto scratchpad_md = context_.prim_desc->scratchpad_desc(); + context_.sp_mem.reset( + new dnnl::memory(scratchpad_md, cpu_engine_, DummyData)); + + // Create matmul primitive. 
+ matmul_primitive.reset(new dnnl::matmul(*context_.prim_desc)); + context_.net_args.push_back({{DNNL_ARG_SRC, *context_.a_mem}, + {DNNL_ARG_WEIGHTS, *context_.b_mem}, + {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, + {DNNL_ARG_DST, *context_.c_mem}}); + if (!post_op_params.empty()) { + int count = 0; + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "mul") { + context_.mul_mem.reset( + new dnnl::memory(*context_.mul_md, cpu_engine_, DummyData)); + context_.net_args[0].insert( + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(count) | DNNL_ARG_SRC_1, + *context_.mul_mem}); + count++; + } else if (post_op_param.name == "add") { + context_.add_mem.reset( + new dnnl::memory(*context_.add_md, cpu_engine_, DummyData)); + context_.net_args[0].insert( + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(count) | DNNL_ARG_SRC_1, + *context_.add_mem}); + count++; + } + } + } + + context_.matmul_primitives.push_back(*matmul_primitive); + return; + } + + struct MklMatMulContext context_; +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + mutex primitive_execution_mu_; +#endif +}; + +template +class MklMatMulPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklMatMulPrimitive* Get( + const MklMatMulParams& params, bool do_not_cache) { + MklMatMulPrimitive* matmul_prim = nullptr; + + if (do_not_cache) { + // Always create new primitive + matmul_prim = new MklMatMulPrimitive(params); + } else { + // Try to find a suitable one in pool + matmul_prim = dynamic_cast*>( + MklMatMulPrimitiveFactory::GetInstance() + .GetMklMatMul(params)); + if (matmul_prim == nullptr) { + matmul_prim = new MklMatMulPrimitive(params); + MklMatMulPrimitiveFactory::GetInstance() + .SetMklMatMul(params, matmul_prim); + } + } + + return matmul_prim; + } + + private: + MklMatMulPrimitiveFactory() {} + ~MklMatMulPrimitiveFactory() {} + + static MklMatMulPrimitiveFactory& GetInstance() { + static MklMatMulPrimitiveFactory instance_; + return instance_; + } + + static string CreateKey(const MklMatMulParams& params) { + FactoryKeyCreator key_creator; + key_creator.AddAsKey(params.prefix); + key_creator.AddAsKey(params.a_dims); + key_creator.AddAsKey(params.b_dims); + key_creator.AddAsKey(params.c_dims); + key_creator.AddAsKey(params.a_strides); + key_creator.AddAsKey(params.b_strides); + key_creator.AddAsKey(params.c_strides); + key_creator.AddAsKey(typeid(T).name()); + key_creator.AddAsKey(typeid(Tlhs).name()); + key_creator.AddAsKey(typeid(Trhs).name()); + key_creator.AddAsKey(typeid(Toutput).name()); + + // Generate keys for post-ops + for (auto const& post_op_param : params.post_op_params) { + if (post_op_param.name == "output_scale") { + DCHECK_EQ(post_op_param.param.size(), 1); + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.param[0]); + } else if (post_op_param.name == "mul" || post_op_param.name == "add") { + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.dims); + } else { + return string("not_a_key"); + } + } + return key_creator.GetKey(); + } + + MklPrimitive* GetMklMatMul(const MklMatMulParams& params) { + string key = CreateKey(params); + return this->GetOp(key); + } + + void SetMklMatMul(const MklMatMulParams& params, MklPrimitive* op) { + string key = CreateKey(params); + this->SetOp(key, op); + } +}; + +template +void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k, + float alpha, const T* a, int64_t lda, const T* b, int64_t ldb, + float beta, T* c, int64_t ldc, OpKernelContext* ctx = nullptr) { + using dims = 
dnnl::memory::dims; + + // Prepare strides based on the transa and transb flags: transposed + // matrices have strides swapped + dims a_dims = dims{m, k}; + dims b_dims = dims{k, n}; + dims c_dims = dims{m, n}; + dims a_strides = tolower(transa) == 'n' ? dims{lda, 1} : dims{1, lda}; + dims b_strides = tolower(transb) == 'n' ? dims{ldb, 1} : dims{1, ldb}; + dims c_strides = dims{ldc, 1}; + + // MklMatMul uses const alpha and beta, make guarantee here to ensure + // they are never changed. + DCHECK_EQ(alpha, 1.0f); + DCHECK_EQ(beta, 0.f); + + MklMatMulParams params("dnnl_gemm", a_dims, b_dims, c_dims, a_strides, + b_strides, c_strides); + auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T)); + // Create the oneDNN wrapper over Eigen threadpool and set max threads + // in oneDNN. + Eigen::ThreadPoolInterface* eigen_interface = + EigenThreadPoolFromTfContext(ctx); + tsl::OneDnnThreadPool eigen_tp(eigen_interface, ThreadPoolUseCallerThread(), + st ? 1 : -1); + MklMatMulPrimitive* matmul_prim = + MklMatMulPrimitiveFactory::Get(params, 0); + + UserScratchPad scratch_pad; + scratch_pad.AllocateSPTensor(matmul_prim, ctx); + // Execute matmul primitive. + + std::shared_ptr cpu_stream; + + cpu_stream.reset(CreateStream(&eigen_tp, matmul_prim->GetEngine())); + matmul_prim->Execute(cpu_stream, a, b, c, scratch_pad.Get()); +} + +} // anonymous namespace + +#undef APPEND_ELTWISE +#undef APPEND_ELTWISE_RELU6 +#undef OUTPUT_SCALE_DCHECK +#undef SET_MKL_LAYOUT +#undef TSCALED_BIAS + +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_MATMUL_OPS_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h new file mode 100644 index 00000000..da031d5c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h @@ -0,0 +1,808 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_ + +#ifdef INTEL_MKL + +#include +#include +#include + +#include "dnnl.hpp" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/ops_util.h" +#include "tensorflow/core/util/mkl_util.h" +#include "tensorflow/core/util/padding.h" +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) +#include "tensorflow/core/platform/mutex.h" +#endif + +namespace tensorflow { + +#ifndef ENABLE_ONEDNN_V3 +#define GET_DIMS data.dims +#define SET_MKL_LAYOUT(md) SetMklLayout(&md) +#else +#define GET_DIMS get_dims() +#define SET_MKL_LAYOUT(md) SetMklLayout(md) +#endif // !ENABLE_ONEDNN_V3 + +using dnnl::pooling_backward; +using dnnl::pooling_forward; +using dnnl::prop_kind; +using dnnl::stream; + +using PoolingFwdPd = dnnl::pooling_forward::primitive_desc; +using PoolingBwdPd = dnnl::pooling_backward::primitive_desc; + +struct MklPoolingParams { + memory::dims src_dims; + memory::dims dst_dims; + memory::dims filter_dims; + memory::dims strides; +#ifdef ENABLE_ONEDNN_V3 + memory::dims dilations; +#endif // ENABLE_ONEDNN_V3 + memory::dims padding_left; + memory::dims padding_right; + dnnl::algorithm alg_kind; + dnnl::prop_kind prop_kind; + memory::format_tag src_format; + memory::desc src_md; + bool native_format; + + MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, + memory::dims filter_dims, memory::dims strides, +#ifdef ENABLE_ONEDNN_V3 + memory::dims dilations, +#endif // ENABLE_ONEDNN_V3 + memory::dims padding_left, memory::dims padding_right, + dnnl::algorithm alg_kind, dnnl::prop_kind prop_kind, + memory::format_tag src_format, memory::desc src_md, + bool native_format) + : src_dims(src_dims), + dst_dims(dst_dims), + filter_dims(filter_dims), + strides(strides), +#ifdef ENABLE_ONEDNN_V3 + dilations(dilations), +#endif // ENABLE_ONEDNN_V3 + padding_left(padding_left), + padding_right(padding_right), + alg_kind(alg_kind), + prop_kind(prop_kind), + src_format(src_format), + src_md(src_md), + native_format(native_format) { + } +}; + +template +class MklPoolingFwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams) + : MklPrimitive(engine(engine::kind::cpu, 0)) { + if (context_.fwd == nullptr) Setup(fwdParams); + } + + ~MklPoolingFwdPrimitive() {} + + // Pooling forward execute + // src_data: input data buffer of src + // ws_data: output data buffer of workspace + // dst_data: output data buffer of dst + void Execute(const T* src_data, T* dst_data, void* ws_data, + std::shared_ptr fwd_stream); + + std::shared_ptr GetPoolingFwdPd() const { + return context_.fwd_pd; + } + + memory::format_tag GetSrcMemoryFormat() const { return context_.src_fmt; } + memory::format_tag GetDstMemoryFormat() const { return context_.dst_fmt; } + + private: + void Setup(const MklPoolingParams& fwdParams); + + struct PoolingFwdContext { + // Algorithm. + dnnl::algorithm alg_kind; + + // Kind of propagation, forward or backward. + dnnl::prop_kind prop_kind; + + // Expected memory format. + memory::format_tag src_fmt; + memory::format_tag dst_fmt; + memory::format_tag ws_fmt; + + // Workspace shape. + memory::data_type ws_dt; + size_t ws_size; + + // oneDNN memory, just dummy data. 
+ std::shared_ptr ws_mem; + std::shared_ptr src_mem; + std::shared_ptr dst_mem; + + // Pooling forward descriptor and primitive descriptor. +#ifndef ENABLE_ONEDNN_V3 + std::shared_ptr fwd_desc; +#endif // !ENABLE_ONEDNN_V3 + std::shared_ptr fwd_pd; + + // Memory descriptor. + std::shared_ptr src_md; + std::shared_ptr dst_md; + + // Pooling primitive + std::shared_ptr fwd; + std::shared_ptr fwd_stream; + std::vector fwd_primitives; + + std::vector> net_args; + + PoolingFwdContext() + : src_fmt(memory::format_tag::any), + dst_fmt(memory::format_tag::any), + ws_fmt(memory::format_tag::any), + ws_dt(memory::data_type::u8), + ws_size(0), + ws_mem(nullptr), + src_mem(nullptr), + dst_mem(nullptr), +#ifndef ENABLE_ONEDNN_V3 + fwd_desc(nullptr), +#endif // !ENABLE_ONEDNN_V3 + fwd_pd(nullptr), + src_md(nullptr), + dst_md(nullptr), + fwd(nullptr) { + } + }; + + struct PoolingFwdContext context_; + +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + mutex primitive_execution_mu_; +#endif +}; + +template +class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklPoolingFwdPrimitive* Get(const MklPoolingParams& fwdParams) { + MklPoolingFwdPrimitive* pooling_forward = nullptr; + // Get pooling primitive from the pool + pooling_forward = static_cast*>( + MklPoolingFwdPrimitiveFactory::GetInstance().GetPoolingFwd( + fwdParams)); + + if (pooling_forward == nullptr) { + pooling_forward = new MklPoolingFwdPrimitive(fwdParams); + MklPoolingFwdPrimitiveFactory::GetInstance().SetPoolingFwd( + fwdParams, pooling_forward); + } + return pooling_forward; + } + + static MklPoolingFwdPrimitiveFactory& GetInstance() { + static MklPoolingFwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingFwdPrimitiveFactory() {} + ~MklPoolingFwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). 
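+  // Illustrative sketch (not part of the upstream TensorFlow header): a
+  // caller typically obtains and reuses a cached forward primitive like
+  // this, e.g. for T = float:
+  //
+  //   MklPoolingFwdPrimitive<float>* pooling_fwd =
+  //       MklPoolingFwdPrimitiveFactory<float>::Get(fwdParams);
+  //   pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_stream);
+  //
+  // Two Get() calls with identical MklPoolingParams produce the same key
+  // below and therefore share one primitive instead of rebuilding it.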
+ static string CreateKey(const MklPoolingParams& fwdParams) { + string prefix = "pooling_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(fwdParams.dst_dims); + key_creator.AddAsKey(fwdParams.filter_dims); + key_creator.AddAsKey(fwdParams.strides); +#ifdef ENABLE_ONEDNN_V3 + key_creator.AddAsKey(fwdParams.dilations); +#endif // ENABLE_ONEDNN_V3 + key_creator.AddAsKey(fwdParams.padding_left); + key_creator.AddAsKey(fwdParams.padding_right); + key_creator.AddAsKey(fwdParams.src_format); + key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); + key_creator.AddAsKey(static_cast(fwdParams.prop_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) { + string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) { + string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +}; + +template +class MklPoolingBwdPrimitive : public MklPrimitive { + public: + explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) + : MklPrimitive(engine(engine::kind::cpu, 0)) { + if (context_.bwd == nullptr) Setup(bwdParams); + } + + ~MklPoolingBwdPrimitive() {} + + // Pooling backward execute + // diff_dst_data: input data buffer of diff_dst + // diff_src_data: output data buffer of diff_src + // ws_data: input data buffer of workspace + void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data, + std::shared_ptr bwd_stream); + + public: + std::shared_ptr GetPoolingFwdPd() const { + return context_.fwd_pd; + } + std::shared_ptr GetPoolingBwdPd() const { + return context_.bwd_pd; + } + + dnnl::memory::data_type GetWorkspaceDataType() const { + return context_.ws_dt; + } + + private: + void Setup(const MklPoolingParams& bwdParams); + + // Primitive reuse context for pooling bwd ops + struct PoolingBwdContext { + // Algorithm. + dnnl::algorithm alg_kind; + + // Expected memory format. + memory::format_tag diff_src_fmt; + memory::format_tag diff_dst_fmt; + memory::format_tag ws_fmt; + + // Workspace attribute. + dnnl::memory::data_type ws_dt; + + // oneDNN memory. + std::shared_ptr ws_mem; + std::shared_ptr diff_src_mem; + std::shared_ptr diff_dst_mem; + + // Memory descriptors. + std::shared_ptr src_md; + std::shared_ptr dst_md; + + // Forward and backward pooling descriptors and primitive descriptors. +#ifndef ENABLE_ONEDNN_V3 + std::shared_ptr fwd_desc; + std::shared_ptr bwd_desc; +#endif // !ENABLE_ONEDNN_V3 + std::shared_ptr fwd_pd; + std::shared_ptr bwd_pd; + + // Backward pooling primitive. 
+ std::shared_ptr bwd; + std::shared_ptr bwd_stream; + + std::vector bwd_primitives; + std::vector> net_args; + + PoolingBwdContext() + : diff_src_fmt(memory::format_tag::any), + diff_dst_fmt(memory::format_tag::any), + ws_fmt(memory::format_tag::any), + ws_dt(memory::data_type::u8), + ws_mem(nullptr), + diff_src_mem(nullptr), + diff_dst_mem(nullptr), + src_md(nullptr), + dst_md(nullptr), +#ifndef ENABLE_ONEDNN_V3 + fwd_desc(nullptr), + bwd_desc(nullptr), +#endif // !ENABLE_ONEDNN_V3 + fwd_pd(nullptr), + bwd_pd(nullptr), + bwd(nullptr) { + } + }; + + struct PoolingBwdContext context_; +#if defined(DNNL_AARCH64_USE_ACL) && defined(ENABLE_ONEDNN_OPENMP) + mutex primitive_execution_mu_; +#endif +}; + +template +class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklPoolingBwdPrimitive* Get(const MklPoolingParams& bwdParams) { + MklPoolingBwdPrimitive* pooling_backward = nullptr; + + // Find a pooling backward primitive from the pool. + // If it does not exist, create a new one. + pooling_backward = static_cast*>( + MklPoolingBwdPrimitiveFactory::GetInstance().GetPoolingBwd( + bwdParams)); + if (pooling_backward == nullptr) { + pooling_backward = new MklPoolingBwdPrimitive(bwdParams); + MklPoolingBwdPrimitiveFactory::GetInstance().SetPoolingBwd( + bwdParams, pooling_backward); + } + return pooling_backward; + } + + static MklPoolingBwdPrimitiveFactory& GetInstance() { + static MklPoolingBwdPrimitiveFactory instance_; + return instance_; + } + + private: + MklPoolingBwdPrimitiveFactory() {} + ~MklPoolingBwdPrimitiveFactory() {} + + // The key to be created will be used to get/set pooling + // primitive op from reuse perspective. + // A pooling key is a string which concates key parameters + // as well as algorithm kind (max versus avg). 
+ static string CreateKey(const MklPoolingParams& bwdParams) { + string prefix = "pooling_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(bwdParams.dst_dims); + key_creator.AddAsKey(bwdParams.filter_dims); + key_creator.AddAsKey(bwdParams.strides); +#ifdef ENABLE_ONEDNN_V3 + key_creator.AddAsKey(bwdParams.dilations); +#endif // ENABLE_ONEDNN_V3 + key_creator.AddAsKey(bwdParams.padding_left); + key_creator.AddAsKey(bwdParams.padding_right); + key_creator.AddAsKey(bwdParams.src_format); + key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); + return key_creator.GetKey(); + } + + MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) { + string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) { + string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +}; + +typedef Eigen::ThreadPoolDevice CPUDevice; + +struct MklPoolParameters { + int depth; + + int tensor_in_planes; // Pool3D + int tensor_in_cols; + int tensor_in_rows; + int tensor_in_batch; + + int window_planes; // Pool3D + int window_rows; + int window_cols; + int depth_window; + + int planes_stride; // Pool3D + int row_stride; + int col_stride; + int depth_stride; + +#ifdef ENABLE_ONEDNN_V3 + int planes_dilation; // Pool3D + int row_dilation; + int col_dilation; +#endif // ENABLE_ONEDNN_V3 + + int64 out_planes; // Pool3D + int64 out_height; + int64 out_width; + int out_depth; + + int64 pad_P1; // Pool3D + int64 pad_P2; // Pool3D + int64 pad_left; + int64 pad_right; + int64 pad_top; + int64 pad_bottom; + int pad_depth; + + TensorFormat data_format; + MklPoolParameters() + : depth(0), + tensor_in_planes(0), + tensor_in_cols(0), + tensor_in_rows(0), + tensor_in_batch(0), + window_planes(0), + window_rows(0), + window_cols(0), + depth_window(0), + planes_stride(0), + row_stride(0), + col_stride(0), + depth_stride(0), +#ifdef ENABLE_ONEDNN_V3 + planes_dilation(0), + row_dilation(0), + col_dilation(0), +#endif // ENABLE_ONEDNN_V3 + out_planes(0), + out_height(0), + out_width(0), + out_depth(0), + pad_P1(0), + pad_P2(0), + pad_left(0), + pad_right(0), + pad_top(0), + pad_bottom(0), + pad_depth(0), + data_format(TensorFormat::FORMAT_NCHW) { + } + + // Updates context->status if there is an invalid input. + void Init(OpKernelContext* context, const std::vector& ksize, + const std::vector& stride, Padding padding, + TensorFormat data_format, const TensorShape& tensor_in_shape); + void Init(OpKernelContext* context, const std::vector& ksize, + const std::vector& stride, Padding padding, + TensorFormat data_format, const MklDnnShape* mkl_in_shape); + + private: + // Common initialization for TensorFlow and MKL formats + void Init(OpKernelContext* context, const std::vector& ksize, + const std::vector& stride, Padding padding, + TensorFormat data_format); +}; + +template +class MklPoolingOpBase : public OpKernel { + public: + explicit MklPoolingOpBase(OpKernelConstruction* context) + : OpKernel(context), workspace_enabled_(false) { + string data_format; + if (std::is_same::value || std::is_same::value) { + // Current quantized convolution doesn't have data_format attribute. 
+ data_format = "NHWC"; + } else { + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + } + OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_), + absl::InvalidArgumentError("Invalid data format")); + OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_)); + OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5, + absl::InvalidArgumentError("Sliding window ksize field must " + "specify 4 or 5 dimensions")); + for (int i = 0; i < this->ksize_.size(); ++i) { + OP_REQUIRES(context, this->ksize_[i] > 0, + errors::InvalidArgument( + absl::StrCat("Sliding window ksize must be positive. The " + "specified or inferred ksize is: ", + absl::StrJoin(ksize_, ",")))); + } + + OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_)); + OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5, + absl::InvalidArgumentError("Sliding window strides field must " + "specify 4 or 5 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_)); + OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1, + absl::UnimplementedError("Pooling is not yet supported on the " + "batch dimension.")); + bool is_pool2d = (this->ksize_.size() == 4); + this->tensor_format_mkldnn_ = + is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_) + : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_); + + this->data_format_mkldnn_ = + MklTensorFormatToMklDnnDataFormat(this->tensor_format_mkldnn_); + + // We may not get this attribute for this node if it does not go through + // graph rewrite pass. So we do not check for error while retrieving this + // attribute value. + auto status = + context->GetAttr("workspace_enabled", &this->workspace_enabled_); + (void)status; + } + void Compute(OpKernelContext* context) override = 0; + + protected: + // Calculate output shape of pooling op in oneDNN and TensorFlow order. + // oneDNN uses NCHW(Pool2D) or NCDHW(Pool3D) for output order. + // But TensorFlow output will be in NHWC/NCHW(Pool2D) or + // NDHWC/NCDHW(Pool3D) format depending on data format. Function expects + // output height and width to have already been int32 bounds-checked. + void GetOutputDims(const MklPoolParameters& mkl_pool_params, + memory::dims* output_dims_mkl_order) { + if (this->ksize_.size() == 4) { + // Pooling2D: oneDNN always needs output in NCHW format. + *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch, + mkl_pool_params.out_depth, + static_cast(mkl_pool_params.out_height), + static_cast(mkl_pool_params.out_width)}; + } else { + // Pooling3D: oneDNN always needs output in NCDHW format. 
+ *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch, + mkl_pool_params.out_depth, + static_cast(mkl_pool_params.out_planes), + static_cast(mkl_pool_params.out_height), + static_cast(mkl_pool_params.out_width)}; + } + } + + void InitMklPoolParameters(OpKernelContext* context, + MklPoolParameters* pool_params, + const MklDnnShape& original_input_mkl_shape, + const TensorShape& input_tensor_shape) { + if (!original_input_mkl_shape.IsMklTensor()) { + pool_params->Init(context, this->ksize_, this->stride_, this->padding_, + this->data_format_tf_, input_tensor_shape); + } else { + pool_params->Init(context, this->ksize_, this->stride_, this->padding_, + this->data_format_tf_, &original_input_mkl_shape); + } + } + + void PoolParamsToDims(const MklPoolParameters* pool_params, + memory::dims* filter_dims, memory::dims* strides, +#ifdef ENABLE_ONEDNN_V3 + memory::dims* dilations, +#endif // ENABLE_ONEDNN_V3 + memory::dims* padding_left, memory::dims* padding_right, + bool is_pool2d) { + if (is_pool2d) { + // Pool2D + *filter_dims = + memory::dims({pool_params->window_rows, pool_params->window_cols}); + *strides = + memory::dims({pool_params->row_stride, pool_params->col_stride}); +#ifdef ENABLE_ONEDNN_V3 + *dilations = + memory::dims({pool_params->row_dilation, pool_params->col_dilation}); +#endif // ENABLE_ONEDNN_V3 + *padding_left = memory::dims({static_cast(pool_params->pad_top), + static_cast(pool_params->pad_left)}); + *padding_right = memory::dims({static_cast(pool_params->pad_bottom), + static_cast(pool_params->pad_right)}); + } else { + // Pool3D + *filter_dims = + memory::dims({pool_params->window_planes, pool_params->window_rows, + pool_params->window_cols}); + *strides = + memory::dims({pool_params->planes_stride, pool_params->row_stride, + pool_params->col_stride}); +#ifdef ENABLE_ONEDNN_V3 + *dilations = + memory::dims({pool_params->planes_dilation, pool_params->row_dilation, + pool_params->col_dilation}); +#endif // ENABLE_ONEDNN_V3 + + *padding_left = memory::dims({static_cast(pool_params->pad_P1), + static_cast(pool_params->pad_top), + static_cast(pool_params->pad_left)}); + *padding_right = memory::dims({static_cast(pool_params->pad_P2), + static_cast(pool_params->pad_bottom), + static_cast(pool_params->pad_right)}); + } + } + + void AllocateEmptyOutputTensor(OpKernelContext* context, + const int kOutputIndex, + MklPoolParameters* pool_params, + const memory::dims output_dims_mkl_order, + Tensor** output_tensor) { + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(false); + TensorShape output_tf_shape; + if (pool_params->data_format == TensorFormat::FORMAT_NCHW) { + output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order); + } else { + memory::dims output_dims_order; + // determine Pooling2D (NHWC) or Pooling3D (NDHWC) + if (this->ksize_.size() == 4) { + output_dims_order = {pool_params->tensor_in_batch, + static_cast(pool_params->out_height), + static_cast(pool_params->out_width), + pool_params->out_depth}; + } else { + output_dims_order = {pool_params->tensor_in_batch, + static_cast(pool_params->out_planes), + static_cast(pool_params->out_height), + static_cast(pool_params->out_width), + pool_params->out_depth}; + } + output_tf_shape = MklDnnDimsToTFShape(output_dims_order); + } + AllocateOutputSetMklShape(context, kOutputIndex, output_tensor, + output_tf_shape, output_mkl_shape, + native_format_); + DCHECK(output_tensor); + } + + // Checks to make sure that the memory we need to allocate + // is a multiple of sizeof(T) + // returns the number of elements 
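+  // Worked example: a descriptor of 10 bytes with sizeof(T) == 4 returns 3,
+  // i.e. 12 bytes get allocated, so the TF tensor is never smaller than the
+  // oneDNN layout it has to hold.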
+ size_t GetNumTElements(const memory::desc& pd) { + size_t num_bytes = pd.get_size(); + size_t ret_val = num_bytes / sizeof(T); + if (num_bytes % sizeof(T) != 0) { + ret_val++; + } + return ret_val; + } + + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_tf_; + MklTensorFormat tensor_format_mkldnn_; + memory::format_tag data_format_mkldnn_; + bool workspace_enabled_; + bool native_format_ = false; +}; + +template +class MklPoolingForwardOpBase : public MklPoolingOpBase { + public: + explicit MklPoolingForwardOpBase(OpKernelConstruction* context) + : MklPoolingOpBase(context) {} + void Compute(OpKernelContext* context) override = 0; + + protected: + void ConfigureInput(OpKernelContext* context, + const MklDnnShape& input_mkl_shape, + const Tensor& input_tensor, + MklPoolParameters* pool_params, + MklDnnData* dnn_data_input) { + DCHECK(pool_params); + DCHECK(dnn_data_input); + TensorShape input_tensor_shape = input_tensor.shape(); + if (input_tensor.NumElements() != 0) { + memory::desc input_md = + input_mkl_shape.IsMklTensor() + ? input_mkl_shape.GetMklLayout() + : memory::desc( + (this->ksize_.size() == 4) + ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape, + this->data_format_tf_) + : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape, + this->data_format_tf_), + MklDnnType(), this->data_format_mkldnn_); + dnn_data_input->SetUsrMem(input_md, &input_tensor); + + if (this->ksize_.size() == 5) { + // Pool3D + std::vector input_sizes(5, -1); + input_sizes[MklDnnDims3D::Dim3d_N] = input_md.GET_DIMS[0]; + input_sizes[MklDnnDims3D::Dim3d_C] = input_md.GET_DIMS[1]; + input_sizes[MklDnnDims3D::Dim3d_D] = input_md.GET_DIMS[2]; + input_sizes[MklDnnDims3D::Dim3d_H] = input_md.GET_DIMS[3]; + input_sizes[MklDnnDims3D::Dim3d_W] = input_md.GET_DIMS[4]; + dnn_data_input->SetOpMemDesc(input_sizes, this->data_format_mkldnn_); + } + } + this->InitMklPoolParameters(context, pool_params, input_mkl_shape, + input_tensor_shape); + } + + void AllocateOutputTensor(OpKernelContext* context, + const PoolingFwdPd& pool_fwd_prim_desc, + const memory::dims output_dims_mkl_order, + const MklTensorFormat& output_tf_format, + Tensor** output_tensor) { + TensorShape output_tf_shape; + DCHECK(output_tensor); + memory::desc dst_pd = pool_fwd_prim_desc.dst_desc(); + + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SET_MKL_LAYOUT(dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + // Only allocate enough space for the elements we need. 
+ output_tf_shape.AddDim(this->GetNumTElements(dst_pd)); + + if (this->native_format_) { + output_tf_shape = output_mkl_shape.GetTfShape(); + } + AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor, + output_tf_shape, output_mkl_shape, + this->native_format_); + DCHECK(*output_tensor); + } + + void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor, + const MklDnnShape& input_mkl_shape) { + if (!input_mkl_shape.IsMklTensor()) { + OP_REQUIRES( + context, input_tensor.dims() == 4 || input_tensor.dims() == 5, + absl::InvalidArgumentError("Input must be 4 or 5-dimensional")); + } else { + OP_REQUIRES( + context, + input_mkl_shape.GetDimension() == 4 || + input_mkl_shape.GetDimension() == 5, + absl::InvalidArgumentError("Input shape must be 4 or 5-dimensional")); + } + } + const int kInputTensorIndexInput = 0; + const int kOutputTensorIndexOutput = 0; +}; // MklPoolingForwardBaseOp + +template +class MklPoolingBackwardOpBase : public MklPoolingOpBase { + public: + explicit MklPoolingBackwardOpBase(OpKernelConstruction* context) + : MklPoolingOpBase(context) {} + void Compute(OpKernelContext* context) override = 0; + + protected: + const int kOutputTensorIndexOutput = 0; + + void AllocateOutputTensor(OpKernelContext* context, + const PoolingBwdPd& pool_bkwd_prim_desc, + const memory::dims output_dims_mkl_order, + const MklTensorFormat& output_tf_format, + Tensor** output_tensor) { + DCHECK(output_tensor); + memory::desc dst_pd = pool_bkwd_prim_desc.diff_src_desc(); + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SET_MKL_LAYOUT(dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + + TensorShape output_tf_shape; + output_tf_shape.AddDim(this->GetNumTElements(dst_pd)); + if (this->native_format_) { + output_tf_shape = output_mkl_shape.GetTfShape(); + } + AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor, + output_tf_shape, output_mkl_shape, + this->native_format_); + DCHECK(*output_tensor); + } +}; + +#undef GET_DIMS +#undef SET_MKL_LAYOUT + +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_POOLING_OPS_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h new file mode 100644 index 00000000..0b6319c9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h @@ -0,0 +1,93 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_QUANTIZED_CONV_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MKL_QUANTIZED_CONV_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" + +#ifdef INTEL_MKL + +namespace tensorflow { +template +float MklFloatForOneQuantizedLevel(float range_min, float range_max) { + int64 highest = static_cast(Eigen::NumTraits::highest()); + int64 lowest = static_cast(Eigen::NumTraits::lowest()); + + // Adjusting for having a symmetric range. + // for example: for 8-bit [-127, 127] as opposed to [-128, 127]. + if (lowest < -highest) ++lowest; + + const float float_for_one_quantized_level = + (range_max - range_min) / (highest - lowest); + return float_for_one_quantized_level; +} + +template +void MklQuantizationRangeForMultiplication(float min_a, float max_a, + float min_b, float max_b, + float* min_c, float* max_c) { + const float a_float_for_one_quant_level = + MklFloatForOneQuantizedLevel(min_a, max_a); + const float b_float_for_one_quant_level = + MklFloatForOneQuantizedLevel(min_b, max_b); + + const int64 c_highest = static_cast(Eigen::NumTraits::highest()); + const int64 c_lowest = static_cast(Eigen::NumTraits::lowest()); + const float c_float_for_one_quant_level = + a_float_for_one_quant_level * b_float_for_one_quant_level; + + *min_c = c_float_for_one_quant_level * c_lowest; + *max_c = c_float_for_one_quant_level * c_highest; +} + +template +void MklQuantizationRangeForMultiplication(float min_a, float max_a, + const Tensor& min_b_vector, + const Tensor& max_b_vector, + Tensor** min_c_vector, + Tensor** max_c_vector) { + DCHECK(min_b_vector.NumElements() == (*min_c_vector)->NumElements()); + DCHECK(max_b_vector.NumElements() == (*max_c_vector)->NumElements()); + size_t n_channel = min_b_vector.NumElements(); + const int64 c_highest = static_cast(Eigen::NumTraits::highest()); + const int64 c_lowest = static_cast(Eigen::NumTraits::lowest()); + const float* min_b = min_b_vector.flat().data(); + const float* max_b = max_b_vector.flat().data(); + float* min_c = (*min_c_vector)->flat().data(); + float* max_c = (*max_c_vector)->flat().data(); + +#ifdef ENABLE_ONEDNN_OPENMP +#pragma omp parallel for +#endif // ENABLE_ONEDNN_OPENMP + // TODO(intel-tf): Add eigen parallel_for + for (int64_t n = 0; n < n_channel; ++n) { + float a_float_for_one_quant_level = + MklFloatForOneQuantizedLevel(min_a, max_a); + float b_float_for_one_quant_level = + MklFloatForOneQuantizedLevel(min_b[n], max_b[n]); + float c_float_for_one_quant_level = + a_float_for_one_quant_level * b_float_for_one_quant_level; + min_c[n] = c_float_for_one_quant_level * c_lowest; + max_c[n] = c_float_for_one_quant_level * c_highest; + } +} + +} // namespace tensorflow + +#endif // INTEL_MKL + +#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_QUANTIZED_CONV_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h new file mode 100644 index 00000000..59bf0e77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_binary_ops_test.h @@ -0,0 +1,474 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_BINARY_OPS_TEST_H_ +#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_BINARY_OPS_TEST_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// Base class for `BinaryOpsTest` fixture that has to be defined with a custom +// TF device if you want to use the test macros in this file. +class BinaryOpsTestBase : public OpsTestBase { + protected: + // This method should set the TF device, e.g. DEVICE_CPU, DEVICE_GPU. + void SetUp() override = 0; + + template + void SetOpKernel(const std::string& op_name, const TensorShape& lhs_shape, + const absl::InlinedVector& lhs_input, + const TensorShape& rhs_shape, + const absl::InlinedVector& rhs_input, + const test::OpsTestConfig& config) { + auto builder = NodeDefBuilder("some_name", op_name) + .Input(FakeInput(DataTypeToEnum::v())) + .Input(FakeInput(DataTypeToEnum::v())); + if (config.add_t) { + builder.Attr(config.input_attribute, DataTypeToEnum::v()); + } + if (config.add_tout) { + builder.Attr(config.output_attribute, DataTypeToEnum::v()); + } + TF_ASSERT_OK(builder.Finalize(node_def())); + + TF_ASSERT_OK(InitOp()); + AddInputFromArray(lhs_shape, lhs_input); + AddInputFromArray(rhs_shape, rhs_input); + } + + // Run fully specified tests. + + template + void RunAndExpectResult(const std::string& op_name, + const TensorShape& lhs_shape, + const absl::InlinedVector& lhs_input, + const TensorShape& rhs_shape, + const absl::InlinedVector& rhs_input, + const TensorShape& expected_shape, + const absl::InlinedVector& expected_output, + const test::OpsTestConfig& config) { + SetOpKernel(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input, + config); + TF_ASSERT_OK(RunOpKernel()); + + // Compare output to expectation. + Tensor expected_tensor(allocator(), DataTypeToEnum::value, + expected_shape); + test::FillValues(&expected_tensor, expected_output); + if (config.expect_strictly_equal) { + test::ExpectEqual(expected_tensor, *GetOutput(0), + config.supress_tolerance ? 
test::Tolerance::kNone + : test::Tolerance::kDefault); + } else { + test::ExpectClose(expected_tensor, *GetOutput(0), config.atol, + config.rtol); + } + } + + template + void RunAndExpectInvalidArgument(const std::string& op_name, + const TensorShape& lhs_shape, + const absl::InlinedVector& lhs_input, + const TensorShape& rhs_shape, + const absl::InlinedVector& rhs_input, + const test::OpsTestConfig& config) { + SetOpKernel(op_name, lhs_shape, lhs_input, rhs_shape, rhs_input, + config); + auto status = RunOpKernel(); + EXPECT_FALSE(status.ok()); + EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); + } + + // Run common test cases. + + template + void TestIncompatibleShapes(const std::string& op_name, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + const test::OpsTestConfig& config) { + // Prepare incompatibly shaped inputs. + TensorShape lhs_shape{3}; + TensorShape rhs_shape{2}; + auto repeated_lhs_input = + test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements()); + auto repeated_rhs_input = + test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); + + RunAndExpectInvalidArgument(op_name, lhs_shape, repeated_lhs_input, + rhs_shape, repeated_rhs_input, config); + } + + template + void TestEqualShapes(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + BaselineOutT (*baseline_callback)(BaselineT, BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + int64_t input_size = shape.num_elements(); + CHECK(lhs_input.size() <= input_size && rhs_input.size() <= input_size && + "expect input shape to hold all input values"); + auto repeated_lhs_input = + test::RepeatInputToMatchShape(lhs_input, input_size); + auto repeated_rhs_input = + test::RepeatInputToMatchShape(rhs_input, input_size); + + // Compute expected results. + absl::InlinedVector expected_output; + for (auto it_lhs = repeated_lhs_input.begin(), + it_rhs = repeated_rhs_input.begin(), + end = repeated_lhs_input.end(); + it_lhs != end; ++it_lhs, ++it_rhs) { + auto lhs = static_cast(*it_lhs); + auto rhs = static_cast(*it_rhs); + auto result = static_cast(baseline_callback(lhs, rhs)); + expected_output.push_back(result); + } + + RunAndExpectResult(op_name, shape, repeated_lhs_input, shape, + repeated_rhs_input, shape, expected_output, + config); + } + + template + void TestOneScalar(const std::string& op_name, T scalar_input, + const TensorShape& other_shape, + const absl::InlinedVector& other_input, + BaselineOutT (*baseline_callback)(BaselineT, BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape scalar_shape{}; + CHECK(other_input.size() <= other_shape.num_elements() && + "expect other input shape to hold all input values"); + auto repeated_other_input = + test::RepeatInputToMatchShape(other_input, other_shape.num_elements()); + + // Compute expected results. 
+ absl::InlinedVector expected_output; + for (auto it = repeated_other_input.begin(), + end = repeated_other_input.end(); + it != end; ++it) { + auto scalar = static_cast(scalar_input); + auto other_value = static_cast(*it); + auto result = static_cast(baseline_callback(scalar, other_value)); + expected_output.push_back(result); + } + + auto scalar_input_vector = test::InputAsVector({scalar_input}); + RunAndExpectResult(op_name, scalar_shape, scalar_input_vector, + other_shape, repeated_other_input, + /*expected_shape=*/other_shape, expected_output, + config); + } + + template + void TestOneEffectiveScalar(const std::string& op_name, T scalar_input, + const TensorShape& other_shape, + const absl::InlinedVector& other_input, + BaselineOutT (*baseline_callback)(BaselineT, + BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape effective_scalar_shape{1, 1, 1, 1, 1, 1, 1}; + CHECK(other_input.size() <= other_shape.num_elements() && + "expect other input shape to hold all input values"); + auto repeated_other_input = + test::RepeatInputToMatchShape(other_input, other_shape.num_elements()); + + // Compute expected results. + absl::InlinedVector expected_output; + for (auto it = repeated_other_input.begin(), + end = repeated_other_input.end(); + it != end; ++it) { + auto scalar = static_cast(scalar_input); + auto other_value = static_cast(*it); + auto result = static_cast(baseline_callback(scalar, other_value)); + expected_output.push_back(result); + } + + auto scalar_input_vector = test::InputAsVector({scalar_input}); + TensorShape expected_shape = other_shape; + while (expected_shape.dims() < effective_scalar_shape.dims()) { + expected_shape.InsertDim(0, 1); + } + RunAndExpectResult( + op_name, effective_scalar_shape, scalar_input_vector, other_shape, + repeated_other_input, expected_shape, expected_output, config); + } + + template + void TestBroadcastingExpand(const std::string& op_name, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + BaselineOutT (*baseline_callback)(BaselineT, + BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape lhs_shape{1}; + TensorShape rhs_shape{6}; + auto repeated_lhs_input = + test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements()); + auto repeated_rhs_input = + test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); + + // Compute expected results. + std::vector lhs_indices = {0, 0, 0, 0, 0, 0}; + std::vector rhs_indices = {0, 1, 2, 3, 4, 5}; + auto expected_output = + ComputeExpectedOutput( + lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input, + baseline_callback); + + RunAndExpectResult( + op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input, + /*expected_shape=*/rhs_shape, expected_output, config); + } + + template + void TestBroadcastingInDim(const std::string& op_name, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + BaselineOutT (*baseline_callback)(BaselineT, + BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape lhs_shape{3}; + TensorShape rhs_shape{2, 3}; + auto repeated_lhs_input = + test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements()); + auto repeated_rhs_input = + test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); + + // Compute expected results. 
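+    // Note: broadcasting the rank-1 lhs of shape {3} against the {2, 3} rhs
+    // repeats lhs elements 0, 1, 2 once per row, which is exactly what the
+    // hard-coded index vectors below encode.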
+ std::vector lhs_indices = {0, 1, 2, 0, 1, 2}; + std::vector rhs_indices = {0, 1, 2, 3, 4, 5}; + auto expected_output = + ComputeExpectedOutput( + lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input, + baseline_callback); + + RunAndExpectResult( + op_name, lhs_shape, repeated_lhs_input, rhs_shape, repeated_rhs_input, + /*expected_shape=*/rhs_shape, expected_output, config); + } + + template + void TestBroadcasting(const std::string& op_name, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + BaselineOutT (*baseline_callback)(BaselineT, BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape lhs_shape{2, 1}; + TensorShape rhs_shape{3}; + auto repeated_lhs_input = + test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements()); + auto repeated_rhs_input = + test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); + + // Compute expected results. + TensorShape expected_shape{2, 3}; + std::vector lhs_indices = {0, 0, 0, 1, 1, 1}; + std::vector rhs_indices = {0, 1, 2, 0, 1, 2}; + auto expected_output = + ComputeExpectedOutput( + lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input, + baseline_callback); + + RunAndExpectResult(op_name, lhs_shape, repeated_lhs_input, + rhs_shape, repeated_rhs_input, expected_shape, + expected_output, config); + } + + template + void TestBroadcastingRank6(const std::string& op_name, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + BaselineOutT (*baseline_callback)(BaselineT, + BaselineT), + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape lhs_shape{1, 2, 3, 1, 2, 1}; + TensorShape rhs_shape{1, 1, 1, 2, 3}; + auto repeated_lhs_input = + test::RepeatInputToMatchShape(lhs_input, lhs_shape.num_elements()); + auto repeated_rhs_input = + test::RepeatInputToMatchShape(rhs_input, rhs_shape.num_elements()); + + // Compute expected results. + TensorShape expected_shape{1, 2, 3, 1, 2, 3}; + std::vector lhs_indices = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, + 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, + 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11}; + std::vector rhs_indices = { + 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, + }; + auto expected_output = + ComputeExpectedOutput( + lhs_indices, repeated_lhs_input, rhs_indices, repeated_rhs_input, + baseline_callback); + + RunAndExpectResult(op_name, lhs_shape, repeated_lhs_input, + rhs_shape, repeated_rhs_input, expected_shape, + expected_output, config); + } + + template + void TestEmptyShapeBroadcasting(const std::string& op_name, + const absl::InlinedVector& lhs_input, + const absl::InlinedVector& rhs_input, + const test::OpsTestConfig& config) { + // Prepare inputs. + TensorShape lhs_shape{2, 0, 1}; + TensorShape rhs_shape{2, 0, 5}; + absl::InlinedVector empty_input = {}; + + // Define expected result. 
+ TensorShape expected_shape{2, 0, 5}; + absl::InlinedVector expected_output = {}; + + RunAndExpectResult(op_name, lhs_shape, empty_input, rhs_shape, + empty_input, expected_shape, expected_output, + config); + } + + private: + template + absl::InlinedVector ComputeExpectedOutput( + std::vector lhs_indices, absl::InlinedVector lhs_input, + std::vector rhs_indices, absl::InlinedVector rhs_input, + BaselineOutT (*baseline_callback)(BaselineT, BaselineT)) { + absl::InlinedVector expected_output; + for (int64_t i = 0; i < lhs_indices.size(); i++) { + auto lhs = static_cast(lhs_input[lhs_indices[i]]); + auto rhs = static_cast(rhs_input[rhs_indices[i]]); + auto result = static_cast(baseline_callback(lhs, rhs)); + expected_output.push_back(result); + } + return expected_output; + } +}; + +// Macros to easily generate common test cases. The macros use `BinaryOpsTest` +// fixture in order to share implementation across GPU and CPU platform tests. +// For specific inputs, please define your own test fixtures. +#define GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2( \ + op_name, test_name, T, BaselineT, OutT, BaselineOutT, lhs_input, \ + rhs_input, baseline_callback, config) \ + TEST_F(BinaryOpsTest, op_name##EqShapes##test_name) { \ + TestEqualShapes( \ + #op_name, /*shape=*/test::DefaultInputShape(), lhs_input, rhs_input, \ + baseline_callback, config); \ + } \ + TEST_F(BinaryOpsTest, op_name##IncompatibleShapes##test_name) { \ + TestIncompatibleShapes(#op_name, lhs_input, rhs_input, config); \ + } + +#define GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, BaselineT, OutT, \ + BaselineOutT, lhs_input, rhs_input, \ + baseline_callback, config) \ + \ + GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2( \ + op_name, test_name, T, BaselineT, OutT, BaselineOutT, lhs_input, \ + rhs_input, baseline_callback, config) \ + \ + TEST_F(BinaryOpsTest, op_name##OneScalar##test_name) { \ + TestOneScalar( \ + #op_name, /*scalar_input=*/lhs_input.front(), \ + /*other_shape=*/test::DefaultInputShape(), /*other_input=*/rhs_input, \ + baseline_callback, config); \ + } \ + \ + TEST_F(BinaryOpsTest, op_name##TestOneEffectiveScalar##test_name) { \ + TestOneEffectiveScalar( \ + #op_name, /*scalar_input=*/lhs_input.front(), \ + /*other_shape=*/test::DefaultInputShape(), /*other_input=*/rhs_input, \ + baseline_callback, config); \ + } \ + \ + TEST_F(BinaryOpsTest, op_name##BroadcastingExpand##test_name) { \ + TestBroadcastingExpand( \ + #op_name, lhs_input, rhs_input, baseline_callback, config); \ + } \ + \ + TEST_F(BinaryOpsTest, op_name##BroadcastingInDim##test_name) { \ + TestBroadcastingInDim( \ + #op_name, lhs_input, rhs_input, baseline_callback, config); \ + } \ + \ + TEST_F(BinaryOpsTest, op_name##Broadcasting##test_name) { \ + TestBroadcasting( \ + #op_name, lhs_input, rhs_input, baseline_callback, config); \ + } \ + \ + TEST_F(BinaryOpsTest, op_name##BroadcastingRank6##test_name) { \ + TestBroadcastingRank6( \ + #op_name, lhs_input, rhs_input, baseline_callback, config); \ + } \ + \ + TEST_F(BinaryOpsTest, op_name##EmptyShapeBroadcasting##test_name) { \ + TestEmptyShapeBroadcasting( \ + #op_name, lhs_input, rhs_input, config); \ + } + +#define GENERATE_DEFAULT_TESTS(op_name, test_name, T, OutT, baseline_callback, \ + config) \ + GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, \ + test::DefaultInput(), test::DefaultInput(), \ + baseline_callback, config) + +#define GENERATE_DEFAULT_TESTS_WITH_SPECIFIC_INPUT_VALUES( \ + op_name, test_name, T, OutT, lhs_input, rhs_input, baseline_callback, \ + config) \ + 
GENERATE_DEFAULT_TESTS_2(op_name, test_name, T, T, OutT, OutT, lhs_input, \ + rhs_input, baseline_callback, config) + +#define GENERATE_DEFAULT_NO_BROADCASTING_TESTS(op_name, test_name, T, OutT, \ + baseline_callback) \ + GENERATE_DEFAULT_NO_BROADCASTING_TESTS_2( \ + op_name, test_name, T, T, OutT, OutT, test::DefaultInput(), \ + test::DefaultInput(), baseline_callback, \ + test::OpsTestConfig().ExpectStrictlyEqual()) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_BINARY_OPS_TEST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_gpu_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_gpu_op.h new file mode 100644 index 00000000..c299e1c7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_gpu_op.h @@ -0,0 +1,117 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_ +#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/mlir_generated/base_op.h" + +namespace tensorflow { + +/// Register kernels. + +#define REGISTER_ALIASED_GPU_KERNEL(tf_op, mlir_op, input_type, output_type) \ + REGISTER_ALIASED_KERNEL(tf_op, mlir_op, GPU, input_type, output_type, \ + /*no additional_cstrs*/) + +// clang-format off +#define REGISTER_GPU_KERNEL(tf_op, input_type, output_type) \ + REGISTER_KERNEL(tf_op, GPU, input_type, output_type, /*no additional_cstrs*/) +// clang-format on + +#define REGISTER_COMPLEX_GPU_KERNEL(tf_op, input_type, output_type) \ + REGISTER_COMPLEX_KERNEL(tf_op, GPU, input_type, output_type) + +#define REGISTER_GPU_KERNEL_NO_TYPE_CONSTRAINT(tf_op, input_type) \ + REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, GPU, input_type) + +/// Unary kernels. 
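+// Note: the GPU-specific macros in this header only pin the platform
+// argument to GPU and forward to the platform-generic GENERATE_* /
+// REGISTER_* macros from base_op.h; a generated kernel .cc file typically
+// invokes one of them once per supported element type.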
+
+#define GENERATE_AND_REGISTER_UNARY_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, GPU, input_type, \
+                                     /*no additional_cstrs*/)
+
+#define GENERATE_AND_REGISTER_UNARY_GPU_KERNEL2(tf_op, input_type, \
+                                                output_type) \
+  GENERATE_AND_REGISTER_UNARY_KERNEL2(tf_op, GPU, input_type, output_type, \
+                                      /*no additional_cstrs*/)
+
+#define GENERATE_AND_REGISTER_UNARY_GPU_KERNEL3( \
+    tf_op, input_type, output_type, casted_input_type, casted_output_type) \
+  GENERATE_AND_REGISTER_UNARY_KERNEL3(tf_op, GPU, input_type, output_type, \
+                                      casted_input_type, casted_output_type, \
+                                      /*no additional_cstrs*/)
+
+#define GENERATE_AND_REGISTER_UNARY_JIT_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_AND_REGISTER_UNARY_JIT_KERNEL(tf_op, GPU, input_type, \
+                                         /*no additional_cstrs*/)
+
+#define GENERATE_UNARY_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_UNARY_KERNEL(tf_op, GPU, input_type)
+
+#define GENERATE_UNARY_GPU_KERNEL2(tf_op, input_type, output_type) \
+  GENERATE_UNARY_KERNEL2(tf_op, GPU, input_type, output_type)
+
+#define GENERATE_UNARY_GPU_KERNEL3(tf_op, input_type, output_type, \
+                                   casted_input_type, casted_output_type) \
+  GENERATE_UNARY_KERNEL3(tf_op, GPU, input_type, output_type, \
+                         casted_input_type, casted_output_type)
+
+/// Binary kernels.
+
+#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, GPU, input_type, \
+                                      /*no additional_cstrs*/)
+
+#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL2(tf_op, input_type, \
+                                                 output_type) \
+  GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, GPU, input_type, output_type, \
+                                       /*no additional_cstrs*/)
+
+#define GENERATE_AND_REGISTER_BINARY_GPU_KERNEL3( \
+    tf_op, input_type, output_type, casted_input_type, casted_output_type) \
+  GENERATE_AND_REGISTER_BINARY_KERNEL3(tf_op, GPU, input_type, output_type, \
+                                       casted_input_type, casted_output_type, \
+                                       /*no additional_cstrs*/)
+
+#define GENERATE_AND_REGISTER_BINARY_JIT_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_AND_REGISTER_BINARY_JIT_KERNEL(tf_op, GPU, input_type, \
+                                          /*no additional_cstrs*/)
+
+#define GENERATE_BINARY_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_BINARY_KERNEL(tf_op, GPU, input_type)
+
+#define GENERATE_BINARY_GPU_KERNEL2(tf_op, input_type, output_type) \
+  GENERATE_BINARY_KERNEL2(tf_op, GPU, input_type, output_type)
+
+#define GENERATE_BINARY_GPU_KERNEL3(tf_op, input_type, output_type, \
+                                    casted_input_type, casted_output_type) \
+  GENERATE_BINARY_KERNEL3(tf_op, GPU, input_type, output_type, \
+                          casted_input_type, casted_output_type)
+
+/// Ternary kernels.
+
+#define GENERATE_AND_REGISTER_TERNARY_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_AND_REGISTER_TERNARY_KERNEL(tf_op, GPU, input_type, \
+                                       /*no additional_cstrs*/)
+
+#define GENERATE_AND_REGISTER_TERNARY_JIT_GPU_KERNEL(tf_op, input_type) \
+  GENERATE_AND_REGISTER_TERNARY_JIT_KERNEL(tf_op, GPU, input_type, \
+                                           /*no additional_cstrs*/)
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_GPU_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_op.h
new file mode 100644
index 00000000..c7e92540
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_op.h
@@ -0,0 +1,346 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/ExecutionEngine/CRunnerUtils.h" // from @llvm-project +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +// Unranked memref descriptor as it is expected and returned by the external +// MLIR-generated "C" function. +struct UnrankedMemRef { + int64_t rank; + void* descriptor; +}; + +// Returns a pointer to an allocated MlirTensorBuffer that takes ownership of +// pre-allocated memory. +TensorBuffer* GetMlirTensorBuffer(const void* ptr, size_t size, + Allocator* allocator); + +/// Used to allocate descriptors on stack when they are small. + +constexpr int kMaxRankForOnStackDescriptors = 10; + +static constexpr size_t GetSizeOfDescriptor(int rank) { + return sizeof(void*) * (2 * rank + 3); +} + +using DescriptorBuffer = + llvm::SmallVector; + +/// Converts tensors to memory descriptors and back. + +UnrankedMemRef ConvertTensorToDescriptor(const Tensor& tensor, + DescriptorBuffer& buffer); + +TensorShape ExtractShapeFromDescriptor(UnrankedMemRef unranked_descriptor); + +template +Tensor ConvertDescriptorToTensor(UnrankedMemRef unranked_descriptor, + DataType TfDataType, Allocator* allocator) { + void* base_ptr = static_cast(unranked_descriptor.descriptor)[0]; + TensorShape result_shape = ExtractShapeFromDescriptor(unranked_descriptor); + TensorBuffer* buffer = GetMlirTensorBuffer( + base_ptr, sizeof(ElemType) * result_shape.num_elements(), allocator); + + // Tensor takes ownership of the buffer. + Tensor tensor{TfDataType, result_shape, buffer}; + // When Tensor is constructed, its ref-counter is incremented. We need to + // decrement it back. + buffer->Unref(); + return tensor; +} + +// OpKernel with Compute function that converts input tensors to unranked +// memref descriptors and calls the MLIR-generated unranked kernel. The outputs +// are converted back to tensors using MlirTensorBuffer to take ownership of +// pre-allocated memory. +template +class MLIROpKernel : public OpKernel { + public: + explicit MLIROpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + VLOG(4) << ctx->op_kernel().TraceString(*ctx, true); + + // Convert tensor arguments to unranked memory descriptors. + llvm::SmallVector buffers(ctx->num_inputs()); + llvm::SmallVector args; + for (int i = 0; i < ctx->num_inputs(); ++i) { + args.push_back(ConvertTensorToDescriptor(ctx->input(i), buffers[i])); + } + + UnrankedMemRef result_desc = Invoke(ctx, args); + if (!ctx->status().ok()) { + free(result_desc.descriptor); + return; + } + void* result_data_ptr = static_cast(result_desc.descriptor)[0]; + + // Detect input buffer reuse. 
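+    // If the MLIR-generated kernel wrote its result into one of the input buffers, forward that input as the output (bitcast when the output type differs) instead of wrapping the same memory in a new TensorBuffer below.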
+ for (int i = 0, end = ctx->num_inputs(); i < end; ++i) { + const Tensor& input = ctx->input(i); + if (input.data() == result_data_ptr) { + // Run a bitcast in case the output type is different. + Tensor output; + TensorShape result_shape = ExtractShapeFromDescriptor(result_desc); + OP_REQUIRES_OK( + ctx, output.BitcastFrom(input, CastedTfDataType, result_shape)); + + ctx->set_output(0, output); + free(result_desc.descriptor); + return; + } + } + + tensorflow::AllocatorAttributes attrs; + auto* allocator = ctx->get_allocator(attrs); + Tensor result_tensor = ConvertDescriptorToTensor( + result_desc, TfDataType, allocator); + if (TfDataType != CastedTfDataType) { + Tensor casted_result_tensor; + OP_REQUIRES_OK( + ctx, casted_result_tensor.BitcastFrom(result_tensor, CastedTfDataType, + result_tensor.shape())); + result_tensor = casted_result_tensor; + } + free(result_desc.descriptor); + ctx->set_output(0, result_tensor); + } + + protected: + virtual UnrankedMemRef Invoke( + OpKernelContext* ctx, llvm::SmallVectorImpl& args) = 0; +}; + +/// Generate C function and kernel names. + +#define MLIR_FUNCTION(tf_op, platform, input_type, output_type) \ + _mlir_ciface_##tf_op##_##platform##_##input_type##_##output_type + +#define MLIR_OP(tf_op, platform, input_type, output_type) \ + Mlir##tf_op##platform##input_type##output_type##Op + +/// Register kernels. + +#define REGISTER_ALIASED_KERNEL(tf_op, mlir_op, platform, input_type, \ + output_type, additional_cstrs) \ + REGISTER_KERNEL_BUILDER( \ + Name(#tf_op) \ + .Device(DEVICE_##platform) \ + .TypeConstraint::Type>("T") \ + additional_cstrs, \ + MLIR_OP(mlir_op, platform, input_type, output_type)); + +#define REGISTER_KERNEL(tf_op, platform, input_type, output_type, \ + additional_cstrs) \ + REGISTER_ALIASED_KERNEL(tf_op, tf_op, platform, input_type, output_type, \ + additional_cstrs) + +#define REGISTER_COMPLEX_KERNEL(tf_op, platform, input_type, output_type) \ + REGISTER_KERNEL_BUILDER( \ + Name(#tf_op) \ + .Device(DEVICE_##platform) \ + .TypeConstraint::Type>("T") \ + .TypeConstraint::Type>("Tout"), \ + MLIR_OP(tf_op, platform, input_type, output_type)); + +#define REGISTER_KERNEL_NO_TYPE_CONSTRAINT(tf_op, platform, input_type) \ + REGISTER_KERNEL_BUILDER(Name(#tf_op).Device(DEVICE_##platform), \ + MLIR_OP(tf_op, platform, input_type, input_type)); + +/// Unary kernels. 
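As a concrete reading of the naming macros above, an illustrative instantiation of `Abs` on GPU with DT_FLOAT input and output expands to:

MLIR_FUNCTION(Abs, GPU, DT_FLOAT, DT_FLOAT)  // -> _mlir_ciface_Abs_GPU_DT_FLOAT_DT_FLOAT (external C symbol)
MLIR_OP(Abs, GPU, DT_FLOAT, DT_FLOAT)        // -> MlirAbsGPUDT_FLOATDT_FLOATOp (generated OpKernel class)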
+ +#define GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, platform, input_type, \ + additional_cstrs) \ + GENERATE_UNARY_KERNEL(tf_op, platform, input_type) \ + REGISTER_KERNEL(tf_op, platform, input_type, input_type, additional_cstrs) + +#define GENERATE_AND_REGISTER_UNARY_KERNEL2(tf_op, platform, input_type, \ + output_type, additional_cstrs) \ + GENERATE_UNARY_KERNEL(tf_op, platform, input_type, output_type) \ + REGISTER_KERNEL(tf_op, platform, input_type, output_type, additional_cstrs) + +#define GENERATE_AND_REGISTER_UNARY_KERNEL3( \ + tf_op, platform, input_type, output_type, casted_input_type, \ + casted_output_type, additional_cstrs) \ + GENERATE_UNARY_KERNEL3(tf_op, platform, input_type, output_type, \ + casted_input_type, casted_output_type) \ + REGISTER_KERNEL(tf_op, platform, casted_input_type, casted_output_type, \ + additional_cstrs) + +#define GENERATE_AND_REGISTER_UNARY_JIT_KERNEL(tf_op, platform, input_type, \ + additional_cstrs) \ + GENERATE_AND_REGISTER_UNARY_KERNEL(tf_op, platform, input_type, \ + .Label(kJitKernelLabel) additional_cstrs) + +#define GENERATE_UNARY_KERNEL(tf_op, platform, input_type) \ + GENERATE_UNARY_KERNEL2(tf_op, platform, input_type, input_type) + +#define GENERATE_UNARY_KERNEL2(tf_op, platform, input_type, output_type) \ + GENERATE_UNARY_KERNEL3(tf_op, platform, input_type, output_type, input_type, \ + output_type) + +#define GENERATE_UNARY_KERNEL3(tf_op, platform, input_type, output_type, \ + casted_input_type, casted_output_type) \ + extern "C" void MLIR_FUNCTION(tf_op, platform, input_type, output_type)( \ + UnrankedMemRef * result, OpKernelContext * ctx, UnrankedMemRef * arg); \ + \ + namespace { \ + class MLIR_OP(tf_op, platform, casted_input_type, casted_output_type) \ + : public MLIROpKernel::Type, \ + casted_output_type> { \ + public: \ + using MLIROpKernel::MLIROpKernel; \ + \ + UnrankedMemRef Invoke( \ + OpKernelContext* ctx, \ + llvm::SmallVectorImpl& args) override { \ + UnrankedMemRef result; \ + MLIR_FUNCTION(tf_op, platform, input_type, output_type) \ + (&result, ctx, &args[0]); \ + return result; \ + } \ + }; \ + } + +/// Binary kernels. 
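Putting the unary pieces above together: an illustrative instantiation such as GENERATE_AND_REGISTER_UNARY_KERNEL(Abs, GPU, DT_FLOAT, /*no additional_cstrs*/) expands roughly as follows (sketch only; the real macro also threads the casted types and the MLIROpKernel template arguments):

// Declaration of the MLIR-generated entry point.
extern "C" void _mlir_ciface_Abs_GPU_DT_FLOAT_DT_FLOAT(
    UnrankedMemRef* result, OpKernelContext* ctx, UnrankedMemRef* arg);

namespace {
// Thin OpKernel whose Invoke() forwards to the generated C function.
class MlirAbsGPUDT_FLOATDT_FLOATOp
    : public MLIROpKernel</*template arguments elided in this sketch*/> {
 public:
  using MLIROpKernel::MLIROpKernel;
  UnrankedMemRef Invoke(OpKernelContext* ctx,
                        llvm::SmallVectorImpl<UnrankedMemRef>& args) override {
    UnrankedMemRef result;
    _mlir_ciface_Abs_GPU_DT_FLOAT_DT_FLOAT(&result, ctx, &args[0]);
    return result;
  }
};
}  // namespace

// Registration produced by REGISTER_KERNEL.
REGISTER_KERNEL_BUILDER(
    Name("Abs").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    MlirAbsGPUDT_FLOATDT_FLOATOp);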
+ +#define GENERATE_AND_REGISTER_BINARY_KERNEL(tf_op, platform, input_type, \ + additional_cstrs) \ + GENERATE_BINARY_KERNEL(tf_op, platform, input_type) \ + REGISTER_KERNEL(tf_op, platform, input_type, input_type, additional_cstrs) + +#define GENERATE_AND_REGISTER_BINARY_KERNEL2(tf_op, platform, input_type, \ + output_type, additional_cstrs) \ + GENERATE_BINARY_KERNEL2(tf_op, platform, input_type, output_type) \ + REGISTER_KERNEL(tf_op, platform, input_type, output_type, additional_cstrs) + +#define GENERATE_AND_REGISTER_BINARY_KERNEL3( \ + tf_op, platform, input_type, output_type, casted_input_type, \ + casted_output_type, additional_cstrs) \ + GENERATE_BINARY_KERNEL3(tf_op, platform, input_type, output_type, \ + casted_input_type, casted_output_type) \ + REGISTER_KERNEL(tf_op, platform, casted_input_type, casted_output_type, \ + additional_cstrs) + +#define GENERATE_AND_REGISTER_BINARY_JIT_KERNEL(tf_op, platform, input_type, \ + additional_cstrs) \ + GENERATE_AND_REGISTER_BINARY_KERNEL( \ + tf_op, platform, input_type, .Label(kJitKernelLabel) additional_cstrs) + +#define GENERATE_BINARY_KERNEL(tf_op, platform, input_type) \ + GENERATE_BINARY_KERNEL2(tf_op, platform, input_type, input_type) + +#define GENERATE_BINARY_KERNEL2(tf_op, platform, input_type, output_type) \ + GENERATE_BINARY_KERNEL3(tf_op, platform, input_type, output_type, \ + input_type, output_type) + +#define GENERATE_BINARY_KERNEL3(tf_op, platform, input_type, output_type, \ + casted_input_type, casted_output_type) \ + extern "C" void MLIR_FUNCTION(tf_op, platform, input_type, output_type)( \ + UnrankedMemRef * result, OpKernelContext * ctx, UnrankedMemRef * arg0, \ + UnrankedMemRef * arg1); \ + \ + namespace { \ + class MLIR_OP(tf_op, platform, casted_input_type, casted_output_type) \ + : public MLIROpKernel::Type, \ + casted_output_type> { \ + public: \ + using MLIROpKernel::MLIROpKernel; \ + \ + UnrankedMemRef Invoke( \ + OpKernelContext* ctx, \ + llvm::SmallVectorImpl& args) override { \ + UnrankedMemRef result; \ + MLIR_FUNCTION(tf_op, platform, input_type, output_type) \ + (&result, ctx, &args[0], &args[1]); \ + return result; \ + } \ + }; \ + } + +/// Ternary kernels. 
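The *_JIT_* variants above differ from the ahead-of-time ones only in appending `.Label(kJitKernelLabel)` through the `additional_cstrs` parameter, so a registration coming out of GENERATE_AND_REGISTER_BINARY_JIT_KERNEL looks roughly like this (op and type are illustrative):

REGISTER_KERNEL_BUILDER(
    Name("AddV2")
        .Device(DEVICE_GPU)
        .TypeConstraint<float>("T")
        .Label(kJitKernelLabel),
    MlirAddV2GPUDT_FLOATDT_FLOATOp);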
+ +#define GENERATE_AND_REGISTER_TERNARY_KERNEL(tf_op, platform, input_type, \ + additional_cstrs) \ + GENERATE_TERNARY_KERNEL(tf_op, platform, input_type) \ + REGISTER_KERNEL(tf_op, platform, input_type, input_type, additional_cstrs) + +#define GENERATE_AND_REGISTER_TERNARY_KERNEL2(tf_op, platform, input_type, \ + output_type, additional_cstrs) \ + GENERATE_TERNARY_KERNEL2(tf_op, platform, input_type, output_type) \ + REGISTER_KERNEL(tf_op, platform, input_type, output_type, additional_cstrs) + +#define GENERATE_AND_REGISTER_TERNARY_KERNEL3( \ + tf_op, platform, input_type, output_type, casted_input_type, \ + casted_output_type, additional_cstrs) \ + GENERATE_TERNARY_KERNEL3(tf_op, platform, input_type, output_type, \ + casted_input_type, casted_output_type) \ + REGISTER_KERNEL(tf_op, platform, casted_input_type, casted_output_type, \ + additional_cstrs) + +#define GENERATE_AND_REGISTER_TERNARY_JIT_KERNEL(tf_op, platform, input_type, \ + additional_cstrs) \ + GENERATE_AND_REGISTER_TERNARY_KERNEL( \ + tf_op, platform, input_type, .Label(kJitKernelLabel) additional_cstrs) + +#define GENERATE_TERNARY_KERNEL(tf_op, platform, input_type) \ + GENERATE_TERNARY_KERNEL2(tf_op, platform, input_type, input_type) + +#define GENERATE_TERNARY_KERNEL2(tf_op, platform, input_type, output_type) \ + GENERATE_TERNARY_KERNEL3(tf_op, platform, input_type, output_type, \ + input_type, output_type) + +#define GENERATE_TERNARY_KERNEL3(tf_op, platform, input_type, output_type, \ + casted_input_type, casted_output_type) \ + extern "C" void MLIR_FUNCTION(tf_op, platform, input_type, output_type)( \ + UnrankedMemRef * result, OpKernelContext * ctx, UnrankedMemRef * arg0, \ + UnrankedMemRef * arg1, UnrankedMemRef * arg2); \ + \ + namespace { \ + class MLIR_OP(tf_op, platform, casted_input_type, casted_output_type) \ + : public MLIROpKernel::Type, \ + casted_output_type> { \ + public: \ + using MLIROpKernel::MLIROpKernel; \ + \ + UnrankedMemRef Invoke( \ + OpKernelContext* ctx, \ + llvm::SmallVectorImpl& args) override { \ + UnrankedMemRef result; \ + MLIR_FUNCTION(tf_op, platform, input_type, output_type) \ + (&result, ctx, &args[0], &args[1], &args[2]); \ + return result; \ + } \ + }; \ + } + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_ops_test.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_ops_test.h new file mode 100644 index 00000000..d7a2a2d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_ops_test.h @@ -0,0 +1,324 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OPS_TEST_H_ +#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OPS_TEST_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace test { + +template +using is_integer = llvm::is_one_of; + +/// Helper functions to create or derive inputs of the right type and size. + +template +absl::InlinedVector InputAsVector( + std::initializer_list input) { + absl::InlinedVector result; + result.reserve(input.size()); + for (const LiteralT& value : input) { + result.push_back(static_cast(value)); + } + return result; +} + +template +absl::InlinedVector RepeatInputToMatchShape( + absl::InlinedVector input, int64_t size) { + absl::InlinedVector result; + result.reserve(size); + for (int64_t i = 0; i < size; i++) { + auto value = input[i % input.size()]; + result.push_back(value); + } + return result; +} + +template +absl::InlinedVector RepeatElements(absl::InlinedVector input, + int64_t num_repeats) { + absl::InlinedVector result; + result.reserve(input.size() * num_repeats); + for (T value : input) { + for (int64_t i = 0; i < num_repeats; ++i) { + result.push_back(value); + } + } + return result; +} + +/// Helper functions to get default input shapes. + +TensorShape DefaultInputShape(); +TensorShape DefaultInputShapeExceedingInt32(); + +/// Helper functions to configure tests. + +struct OpsTestConfig { + bool add_t = true; + bool add_tout = false; + // Only used for gpu_unary_ops_test. + bool expect_buffer_reuse = true; + bool expect_strictly_equal = false; + bool supress_tolerance = false; + // Negative atol/rtol will make ExpectClose use the default. + double atol = -1; + double rtol = -1; + std::string input_attribute = "T"; + std::string output_attribute = "Tout"; + bool jit_compilation = false; + OpsTestConfig ExpectStrictlyEqual() { + OpsTestConfig config = *this; + config.expect_strictly_equal = true; + return config; + } + OpsTestConfig SuppressTolerance() { + OpsTestConfig config = *this; + config.supress_tolerance = true; + return config; + } + OpsTestConfig NoBufferReuse() { + OpsTestConfig config = *this; + config.expect_buffer_reuse = false; + return config; + } + OpsTestConfig AddTout() { + OpsTestConfig config = *this; + config.add_tout = true; + return config; + } + OpsTestConfig NoT() { + OpsTestConfig config = *this; + config.add_t = false; + return config; + } + OpsTestConfig RTol(double new_rtol) { + OpsTestConfig config = *this; + config.rtol = new_rtol; + return config; + } + OpsTestConfig ATol(double new_atol) { + OpsTestConfig config = *this; + config.atol = new_atol; + return config; + } + OpsTestConfig InputAttribute(const std::string& attr) { + OpsTestConfig config = *this; + config.input_attribute = attr; + return config; + } + OpsTestConfig OutputAttribute(const std::string& attr) { + OpsTestConfig config = *this; + config.output_attribute = attr; + return config; + } + OpsTestConfig JITCompilation() { + OpsTestConfig config = *this; + config.jit_compilation = true; + return config; + } +}; + +/// Helper functions to get more specific input data. 
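One note on the OpsTestConfig struct above: every setter returns a modified copy, so test parameters are composed by chaining on a temporary, e.g. (illustrative):

auto config = test::OpsTestConfig()
                  .ExpectStrictlyEqual()
                  .NoBufferReuse()
                  .AddTout()
                  .RTol(1e-5);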
+ +template ::value, + bool> = true> +absl::InlinedVector NearZeroAndExtremeInput() { + return InputAsVector({-std::numeric_limits::infinity(), + -0.1, -0.0, 0.0, 0.1, + std::numeric_limits::infinity()}); +} + +template ::value, bool> = true> +absl::InlinedVector NearZeroAndExtremeInput() { + return InputAsVector({std::numeric_limits::min(), + std::numeric_limits::min() + 1, -1, 0, 1, + std::numeric_limits::max()}); +} + +template ::value, + bool> = true> +absl::InlinedVector NearZeroInfAndNanInput() { + return InputAsVector({-std::numeric_limits::quiet_NaN(), + -std::numeric_limits::infinity(), + -0.1, -0.0, 0.0, 0.1, + std::numeric_limits::infinity(), + std::numeric_limits::quiet_NaN()}); +} + +template ::value, + bool> = true> +absl::InlinedVector DefaultInputGreaterEqualOne() { + return test::InputAsVector( + {18.0, 9.0, 1.0, std::numeric_limits::max(), 42.0, 2.0, 1.0, + std::sqrt(std::numeric_limits::max()), 9.0, 18.0}); +} + +template ::value, + bool> = true> +absl::InlinedVector DefaultInputGreaterThanZero() { + return test::InputAsVector({18.0, 9.0, 1e-6, 1.0, 0.1, 1e-6, 0.1, + 0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0}); +} + +template ::value, + bool> = true> +absl::InlinedVector DefaultInputGreaterOrEqualToZero() { + return test::InputAsVector({18.0, 9.0, 1e-6, 0.0, 0.1, 1e-6, 0.1, + 0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0}); +} + +template ::value, + bool> = true> +absl::InlinedVector DefaultInputNonZero() { + return test::InputAsVector({18.0, 9.0, 1e-6, -0.1, 0.1, 1e-6, 0.1, + 0.2, 0.3, 0.5, 0.7, 0.9, 9.0, 18.0}); +} + +template ::value, bool> = true> +absl::InlinedVector DefaultInputNonZero() { + return test::InputAsVector({-18, -9, -1, 1, 3, 4, 5, 7, 9, 10, 18}); +} + +template ::value, + bool> = true> +absl::InlinedVector DefaultInputBetweenZeroAndOne() { + return test::InputAsVector({-0.999, -0.9, -0.8, -0.5, -0.1, -0.001, + -0, 0, 0.001, 0.1, 0.5, 0.8, 0.9, + 0.999}); +} + +template ::value, bool> = true> +absl::InlinedVector DefaultInputLessThanBitwidth() { + auto max_shift = sizeof(T) * 8 - 1; + absl::InlinedVector v; + for (auto i = 0; i < max_shift; ++i) v.push_back(i); + return v; +} + +/// Helper functions to get default input data. 
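The specific-input helpers above are overload sets selected via enable_if on the element type, so the same spelling works for floating-point, integral, and complex instantiations. Illustrative calls (the exact set of supported element types is an assumption here):

auto float_edges = test::NearZeroAndExtremeInput<float>();    // -inf, -0.1, -0.0, 0.0, 0.1, +inf
auto int_edges   = test::NearZeroAndExtremeInput<int64_t>();  // min, min+1, -1, 0, 1, max
auto shifts      = test::DefaultInputLessThanBitwidth<int32_t>();  // 0 .. 30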
+ +template ::value, bool> = true> +absl::InlinedVector DefaultInput() { + return InputAsVector({-18, -9, -1, 0, 0, 1, 1, 2, 3, 5, 7, 9, 9, 18}); +} + +template ::value, + bool> = true> +absl::InlinedVector DefaultInput() { + return InputAsVector({-18.0, -9.0, -0.7, -0.5, -0.3, -0.2, -0.1, + -1e-6, -0.0, 0.0, 1e-6, 0.1, 0.2, 0.3, 0.5, + 0.7, 0.9, 18.0}); +} + +template , + std::complex>::value, + bool> = true> +absl::InlinedVector DefaultInput() { + using ElementType = typename T::value_type; + auto input = test::DefaultInput(); + absl::InlinedVector complex_input; + for (ElementType value : input) { + complex_input.emplace_back(value, -value); + } + return complex_input; +} + +template , + std::complex>::value, + bool> = true> +absl::InlinedVector ComplexInputFromValues( + const absl::InlinedVector& real, + const absl::InlinedVector& imag) { + using ElementType = typename T::value_type; + absl::InlinedVector complex_input; + CHECK_EQ(real.size(), imag.size()); + for (size_t i = 0; i < real.size() && i < imag.size(); ++i) { + complex_input.emplace_back(real[i], imag[i]); + } + return complex_input; +} + +template , + std::complex>::value, + bool> = true> +absl::InlinedVector DefaultInputNonZero() { + auto real = test::DefaultInputNonZero(); + auto imag = real; + std::reverse(imag.begin(), imag.end()); + return test::ComplexInputFromValues(real, imag); +} + +template , + std::complex>::value, + bool> = true> +absl::InlinedVector DefaultInputGreaterOrEqualToZero() { + auto real = test::DefaultInputGreaterOrEqualToZero(); + auto imag = real; + std::reverse(imag.begin(), imag.end()); + return test::ComplexInputFromValues(real, imag); +} + +template , + std::complex>::value, + bool> = true> +absl::InlinedVector NearZeroInfAndNanInput() { + using ElementType = typename T::value_type; + auto input = test::NearZeroInfAndNanInput(); + absl::InlinedVector real; + absl::InlinedVector imag; + for (ElementType r : input) { + for (ElementType i : input) { + real.push_back(r); + imag.push_back(i); + } + } + return test::ComplexInputFromValues(real, imag); +} + +template ::value, bool> = true> +absl::InlinedVector DefaultInput() { + return InputAsVector({true, false, true, true, false}); +} + +} // namespace test +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_OPS_TEST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h new file mode 100644 index 00000000..5edb7e7d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h @@ -0,0 +1,219 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_UNARY_OPS_TEST_H_ +#define TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_UNARY_OPS_TEST_H_ + +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// Base class for `UnaryOpsTest` fixture that has to be defined with a custom TF +// device if you want to use the test macros in this file. +class UnaryOpsTestBase : public OpsTestBase { + protected: + // This method should set the TF device, e.g. DEVICE_CPU, DEVICE_GPU. + void SetUp() override = 0; + + template + void SetOpKernel(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + const test::OpsTestConfig& config) { + NodeDefBuilder builder("some_name", op_name); + builder.Input(FakeInput(DataTypeToEnum::v())); + if (config.add_t) { + builder.Attr(config.input_attribute, DataTypeToEnum::v()); + } + if (config.add_tout) { + builder.Attr(config.output_attribute, DataTypeToEnum::v()); + } + TF_ASSERT_OK(builder.Finalize(node_def())); + + TF_ASSERT_OK(InitOp()); + AddInputFromArray(shape, input); + } + + template + void RunAndExpectResult(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + const absl::InlinedVector& expected_output, + const test::OpsTestConfig& config) { + SetOpKernel(op_name, shape, input, config); + TF_ASSERT_OK(RunOpKernel()); + + // Assert buffer reuse if expected. + if (config.expect_buffer_reuse) { + void* arg_ptr_on_device = context_->input(0).data(); + void* result_ptr_on_device = context_->mutable_output(0)->data(); + ASSERT_EQ(arg_ptr_on_device, result_ptr_on_device); + } + + // Assert expected results. + Tensor expected_tensor(allocator(), DataTypeToEnum::value, shape); + test::FillValues(&expected_tensor, expected_output); + if (config.expect_strictly_equal) { + test::ExpectEqual(expected_tensor, *GetOutput(0), + config.supress_tolerance ? test::Tolerance::kNone + : test::Tolerance::kDefault); + } else { + test::ExpectClose(expected_tensor, *GetOutput(0), kAbsoluteTolerance, + kRelativeTolerance); + } + + // For JIT-compiled kernels, expect exactly one entry in the JIT cache for + // the current test. The cache is not affected by other tests as we always + // set up a new environment. + if (config.jit_compilation) { + ResourceMgr* mgr = context_->resource_manager(); + mlir::kernel_gen::tf_framework::JITCache* cache; + TF_ASSERT_OK(mgr->Lookup( + mgr->default_container(), + mlir::kernel_gen::tf_framework::JITCache::kDefaultResourceName, + &cache)); + core::ScopedUnref cache_ref(cache); + ASSERT_EQ(cache->Size(), 1); + } + } + + template + void TestImpl(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + const BaselineCallback& baseline_callback, + const test::OpsTestConfig& config) { + // Prepare inputs and compute expected results. 
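+    // Note: `input` may hold fewer values than shape.num_elements(); RepeatInputToMatchShape tiles it to the full element count before the baseline callback is applied element-wise.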
+ CHECK(input.size() <= shape.num_elements()); + auto repeated_input = + test::RepeatInputToMatchShape(input, shape.num_elements()); + absl::InlinedVector expected_output = + ComputeExpectedOutput(repeated_input, + baseline_callback); + + RunAndExpectResult(op_name, shape, repeated_input, expected_output, + config); + } + + template + void Test(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + const BaselineCallback& baseline_callback, + const test::OpsTestConfig& config) { + TestImpl(op_name, shape, input, baseline_callback, + config); + } + + // Allow deduction of overloaded function with const ref input. + template + void Test(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + BaselineOutT (*baseline_callback)(const BaselineT&), + const test::OpsTestConfig& config) { + TestImpl(op_name, shape, input, baseline_callback, + config); + } + + // Allow deduction of overloaded function with value input. + template + void Test(const std::string& op_name, const TensorShape& shape, + const absl::InlinedVector& input, + BaselineOutT (*baseline_callback)(BaselineT), + const test::OpsTestConfig& config) { + TestImpl(op_name, shape, input, baseline_callback, + config); + } + + template + void TestEmptyShape(const std::string& op_name, + const test::OpsTestConfig& config) { + TensorShape shape{0, 1, 2}; + absl::InlinedVector empty_input = {}; + absl::InlinedVector expected_output = {}; + RunAndExpectResult(op_name, shape, empty_input, expected_output, + config); + } + + private: + constexpr static double kAbsoluteTolerance = 0.001; + constexpr static double kRelativeTolerance = 0.001; + + template + absl::InlinedVector ComputeExpectedOutput( + absl::InlinedVector input, + const BaselineCallback& baseline_callback) { + absl::InlinedVector expected_output; + expected_output.reserve(input.size()); + for (int64_t i = 0; i < input.size(); i++) { + auto arg = static_cast(input[i]); + auto result = static_cast(baseline_callback(arg)); + expected_output.push_back(result); + } + return expected_output; + } +}; + +// Macros to easily generate common test cases. The macros use `UnaryOpsTest` +// fixture in order to share implementation across GPU and CPU platform tests. +// For specific inputs, please define your own test fixtures. 
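The macros defined below are typically driven from a per-device test file that supplies the `UnaryOpsTest` fixture. A hedged sketch of such a file (device setup and baseline are illustrative, not part of this header):

// Hypothetical gpu_unary_ops_test.cc-style usage.
class UnaryOpsTest : public UnaryOpsTestBase {
 protected:
  void SetUp() override {
    // Create the device under test and hand it to OpsTestBase.
    std::unique_ptr<tensorflow::Device> device_gpu(
        tensorflow::DeviceFactory::NewDevice("GPU", {}, "/job:a/replica:0/task:0"));
    SetDevice(tensorflow::DEVICE_GPU, std::move(device_gpu));
  }
};

GENERATE_DEFAULT_TEST(Abs, DT_FLOAT, DT_FLOAT, std::abs,
                      test::OpsTestConfig().ExpectStrictlyEqual())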
+#define GENERATE_DEFAULT_TEST(op_name, InT, OutT, baseline_callback, config) \ + GENERATE_DEFAULT_TEST_2(op_name, InT, InT, OutT, OutT, baseline_callback, \ + config) + +#define GENERATE_DEFAULT_TEST_2(op_name, InT, BaselineT, OutT, BaselineOutT, \ + baseline_callback, config) \ + GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \ + op_name, InT, BaselineT, OutT, BaselineOutT, \ + test::DefaultInput(), baseline_callback, config) + +#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES( \ + op_name, InT, OutT, input_values, baseline_callback, config) \ + GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \ + op_name, InT, InT, OutT, OutT, input_values, baseline_callback, config) + +#define GENERATE_DEFAULT_TEST_WITH_SPECIFIC_INPUT_VALUES_2( \ + op_name, InT, BaselineT, OutT, BaselineOutT, input_values, \ + baseline_callback, config) \ + TEST_F(UnaryOpsTest, op_name##InT##OutT) { \ + using NativeT = EnumToDataType::Type; \ + using NativeBaselineT = EnumToDataType::Type; \ + using NativeOutT = EnumToDataType::Type; \ + using NativeBaselineOutT = EnumToDataType::Type; \ + Test( \ + #op_name, test::DefaultInputShape(), input_values, baseline_callback, \ + config); \ + } \ + TEST_F(UnaryOpsTest, op_name##InT##OutT##EmptyShape) { \ + using NativeT = EnumToDataType::Type; \ + using NativeOutT = EnumToDataType::Type; \ + TestEmptyShape(#op_name, config); \ + } + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MLIR_GENERATED_BASE_UNARY_OPS_TEST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/multinomial_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/multinomial_op.h new file mode 100644 index 00000000..34e21236 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/multinomial_op.h @@ -0,0 +1,30 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_ + +namespace tensorflow { + +namespace functor { + +// Generic helper functor for the Multinomial Op. +template +struct MultinomialFunctor; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/nextafter_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/nextafter_op.h new file mode 100644 index 00000000..89a39f49 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/nextafter_op.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_NEXTAFTER_OP_H_ +#define TENSORFLOW_CORE_KERNELS_NEXTAFTER_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/kernels/cwise_ops.h" + +namespace tensorflow { +namespace functor { + +template +struct nextafter_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x1, + const T& x2) const { + return std::nextafter(x1, x2); + } +}; + +template +struct nextafter : base> {}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_NEXTAFTER_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/no_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/no_op.h new file mode 100644 index 00000000..9e16d069 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/no_op.h @@ -0,0 +1,32 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_NO_OP_H_ +#define TENSORFLOW_CORE_KERNELS_NO_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class NoOp : public OpKernel { + public: + explicit NoOp(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} + bool IsExpensive() override { return false; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_NO_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/nth_element_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/nth_element_op.h new file mode 100644 index 00000000..7a5ec3d0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/nth_element_op.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace functor { + +template +struct NthElementFunctor { + void operator()(OpKernelContext* context, const Tensor& input_tensor, + Tensor& output_tensor, int n); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/numeric_options_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/numeric_options_utils.h new file mode 100644 index 00000000..ced38d37 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/numeric_options_utils.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_NUMERIC_OPTIONS_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_NUMERIC_OPTIONS_UTILS_H_ + +#include "xla/stream_executor/numeric_options.h" +#include "xla/tsl/util/determinism.h" +#include "tensorflow/core/util/env_var.h" +#include "tsl/platform/tensor_float_32_utils.h" + +namespace tensorflow { + +inline stream_executor::NumericOptions GetNumericOptions() { + return stream_executor::NumericOptions{ + /*require_determinism=*/tsl::OpDeterminismRequired(), + /*allow_tf32=*/tsl::tensor_float_32_execution_enabled()}; +} + +inline stream_executor::NumericOptions GetNumericOptionsForCuDnn() { + static bool cudnn_deterministic_env_var = [] { + bool cudnn_deterministic = false; + TF_CHECK_OK(ReadBoolFromEnvVar("TF_CUDNN_DETERMINISTIC", + /*default_val=*/false, + &cudnn_deterministic)); + return cudnn_deterministic; + }(); + stream_executor::NumericOptions result = GetNumericOptions(); + result.require_determinism |= cudnn_deterministic_env_var; + return result; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_NUMERIC_OPTIONS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/one_hot_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/one_hot_op.h new file mode 100644 index 00000000..afcf287a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/one_hot_op.h @@ -0,0 +1,125 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/array_ops.cc + +#ifndef TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_ +// Generator definition for OneHotOp, must be compilable by nvcc. + +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +namespace generator { + +template +class OneGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE + OneGenerator(const typename TTypes::ConstMatrix& indices, + const typename TTypes::ConstScalar& on_value, + const typename TTypes::ConstScalar& off_value) + : indices_(indices), on_value_(on_value), off_value_(off_value) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + operator()(const Eigen::array& pre_depth_suff) const { + return (indices_(pre_depth_suff[0], pre_depth_suff[2]) == pre_depth_suff[1]) + ? on_value_() + : off_value_(); + } + + private: + const typename TTypes::ConstMatrix indices_; + const typename TTypes::ConstScalar on_value_; + const typename TTypes::ConstScalar off_value_; +}; + +} // namespace generator + +namespace functor { + +template +struct OneHot { + EIGEN_ALWAYS_INLINE static void Compute( + const Device& d, const typename TTypes::ConstMatrix& indices, + const typename TTypes::ConstScalar& on_value, + const typename TTypes::ConstScalar& off_value, + typename TTypes::Tensor* output) { + generator::OneGenerator generator(indices, on_value, off_value); + output->device(d) = output->generate(generator); + } +}; + +template +struct OneHot { + EIGEN_ALWAYS_INLINE static void Compute( + const CPUDevice& d, const typename TTypes::ConstMatrix& indices, + const typename TTypes::ConstScalar& on_value, + const typename TTypes::ConstScalar& off_value, + typename TTypes::Tensor* output) { + // Pre-fill output with `off_value`. + output->device(d) = output->constant(off_value()); + + // Iterate through indices and update on_value elements in the output. + Eigen::Index prefix_size = output->dimensions()[0]; + Eigen::Index depth_size = output->dimensions()[1]; + Eigen::Index suffix_size = output->dimensions()[2]; + + // Cost of setting one `on_value` coefficient. 
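+    // Eigen's parallelFor uses this per-coefficient cost estimate to decide how finely to shard the index loops below across threads.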
+ double bytes_loaded = sizeof(T); + double bytes_stored = sizeof(T); + double cycles = 0.0; + const Eigen::TensorOpCost cost(bytes_loaded, bytes_stored, cycles); + + if (suffix_size == 1) { + const auto func = [&](Eigen::Index start, Eigen::Index end) -> void { + for (Eigen::Index i = start; i < end; ++i) { + const TI depth = internal::SubtleMustCopy(indices(i, 0)); + if (FastBoundsCheck(depth, depth_size)) { + (*output)(i, depth, 0) = on_value(); + } + } + }; + d.parallelFor(prefix_size, cost, func); + } else { + const auto func = [&](Eigen::Index start, Eigen::Index end) -> void { + for (Eigen::Index i = start; i < end; ++i) { + const Eigen::Index d0 = i / suffix_size; + const Eigen::Index d1 = i - (d0 * suffix_size); + const TI depth = internal::SubtleMustCopy(indices(d0, d1)); + if (FastBoundsCheck(depth, depth_size)) { + (*output)(d0, depth, d1) = on_value(); + } + } + }; + d.parallelFor(prefix_size * suffix_size, cost * suffix_size, func); + } + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/ops_testutil.h b/third_party/tflite-hdrs/tensorflow/core/kernels/ops_testutil.h new file mode 100644 index 00000000..ef4a7cd5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/ops_testutil.h @@ -0,0 +1,212 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_ +#define TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/tensor_slice_reader_cache.h" + +namespace tensorflow { +namespace test { + +void SetOutputAttrs(OpKernelContext::Params* params, + std::vector* attrs); + +} // namespace test + +// Helpful functions to test operators. +// +// This class will eventually be replaced / heavily modified +// to use the BrainClient interface. +class OpsTestBase : public ::testing::Test { + public: + OpsTestBase(); + + ~OpsTestBase() override; + + // Allow kernel unit tests to run on GPU + void SetDevice(const DeviceType& device_type, std::unique_ptr device); + + void set_node_def(const NodeDef& node_def); + + // Clients can manipulate the underlying NodeDef via this accessor. + NodeDef* node_def(); + + // Initializes an operator that takes in 'input_types' as input + // and output types as output. + // + // Returns the status of initialization. + absl::Status InitOp(); + + // Only use this directly if you have a deprecated op that you need to test. + absl::Status InitOpWithGraphVersion(int graph_def_version); + + // Adds an input for every element described by the shape. + // 'input_mapping' maps an index (0...NumElements(shape)) to a + // value. + // + // TODO(vrv): Replace with something like a BrainClient Feed. + template + void AddInput(const TensorShape& shape, std::function input_mapping) { + test::FillFn(AddInput(DataTypeToEnum::v(), shape), input_mapping); + } + + // Like AddInput but takes in an explicit arrayslice of data. 
+ template + void AddInputFromArray(const TensorShape& shape, + const gtl::ArraySlice data) { + test::FillValues(AddInput(DataTypeToEnum::v(), shape), data); + } + + // Convenience function to add an input and populate it with the elements from + // an initializer list converting the types as needed. + template + void AddInputFromList(const TensorShape& shape, + std::initializer_list data) { + test::FillValues(AddInput(DataTypeToEnum::v(), shape), data); + } + + // Adds a Resource type as input. If is empty, uses the default + // container name. + template + void AddResourceInput(const string& container, const string& name, + T* resource) { + CHECK_GT(input_types_.size(), inputs_.size()) + << "Adding more inputs than types; perhaps you need to call MakeOp"; + ResourceMgr* rm = device_->resource_manager(); + std::string container_name = + container.empty() ? rm->default_container() : container; + EXPECT_TRUE(rm->Create(container_name, name, resource).ok()); + AddResourceInputInternal(container_name, name, TypeIndex::Make()); + } + + // Runs an operation producing 'num_outputs' outputs. + // + // Returns the context's status after running the operation. + absl::Status RunOpKernel(); + + // Returns the tensor input for 'input_index'. + // + // REQUIRES: 0 <= input_index < context_->num_inputs() + const Tensor& GetInput(int input_index) const; + + TensorValue mutable_input(int input_index); + + // Returns the tensor output for 'output_index'. + // + // REQUIRES: 0 <= output_index < context_->num_outputs() + Tensor* GetOutput(int output_index); + + Allocator* allocator(); + + OpKernel* op_kernel(); + + const DataTypeVector& output_types() const; + + void set_session_metadata(SessionMetadata session_metadata) { + session_metadata_ = std::move(session_metadata); + } + + const SessionMetadata& session_metadata() const { return session_metadata_; } + + protected: + void CreateContext(); + Tensor* AddInput(DataType dtype, const TensorShape& shape); + void AddResourceInputInternal(const std::string& container_name, + const std::string& name, + const TypeIndex& type_index); + + // device_mgr_ owns device_. + std::unique_ptr device_mgr_; + Device* device_; + + // The device allocator, or the managed_allocator_ below if running on GPU. + Allocator* allocator_; + + std::unique_ptr kernel_; + std::unique_ptr step_container_; + NodeDef node_def_; + DataTypeVector input_types_; + DeviceType device_type_; + + mutex lock_for_refs_; // Used as the Mutex for inputs added as refs + + absl::InlinedVector inputs_; + // Owns Tensors. + std::vector tensors_; + // Copies of the outputs in unified memory (host and device accessible). + std::vector managed_outputs_; + + // AllocatorAttributes for the allocators of the outputs. + std::vector out_alloc_attrs_; + checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper_; + CancellationManager default_cancellation_manager_; + std::unique_ptr params_; + std::unique_ptr context_; + // Unified memory allocator, only used when running on GPU. 
+ std::unique_ptr managed_allocator_; + + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + std::unique_ptr thread_pool_; + + SessionMetadata session_metadata_; + + private: + OpsTestBase(const OpsTestBase&) = delete; + void operator=(const OpsTestBase&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/ops_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/ops_util.h new file mode 100644 index 00000000..842dd798 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/ops_util.h @@ -0,0 +1,22 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_ + +// Placeholder for the ops_util library that is moved under core/framework. +#include "tensorflow/core/framework/ops_util.h" + +#endif // TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/pad_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/pad_op.h new file mode 100644 index 00000000..34a19dfc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/pad_op.h @@ -0,0 +1,56 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_PAD_OP_H_ +#define TENSORFLOW_CORE_KERNELS_PAD_OP_H_ +// Functor definition for PadOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by PadOp to do the computations. +template +struct Pad { + // Pad "input" into "output", as specified by "paddings" and "pad_value". + // See pad_op.cc for details. + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + Eigen::array, Dims> paddings, + T pad_value) { + MaybeWith32BitIndexing( + [&](auto output32, auto input32) { + output32.device(d) = input32.pad(paddings, pad_value); + }, + output, input); + } +}; + +template +struct Pad { + // In the scalar case we simply copy the input. 
+ void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + Eigen::array, 0>, T) { + output.device(d) = input; + } +}; +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_PAD_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/padding_fifo_queue.h b/third_party/tflite-hdrs/tensorflow/core/kernels/padding_fifo_queue.h new file mode 100644 index 00000000..74107e80 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/padding_fifo_queue.h @@ -0,0 +1,90 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_ +#define TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/fifo_queue.h" +#include "tensorflow/core/kernels/typed_queue.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class PaddingFIFOQueue : public FIFOQueue { + public: + PaddingFIFOQueue(int32_t capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + + absl::Status Initialize() override; + + // Implementations of QueueInterface methods -------------------------------- + + void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, + CallbackWithTuple callback) override; + absl::Status MatchesNodeDef(const NodeDef& node_def) override; + + protected: + absl::Status ValidateManyTuple(const Tuple& tuple) override; + absl::Status ValidateTuple(const Tuple& tuple) override; + absl::Status CompatibleNodeDefShapes(const NodeDef& node_def) const; + + // Convert a list of PartialTensorShape to a list of + // TensorShape. + // Any unknown dimension sizes are converted to 0. + // REQUIRED: All the input shapes have well defined rank. + static std::vector ConvertShapesPartialDimensionsToZero( + absl::Span partial_shapes); + + // Sets the values in the given element to zero. + static absl::Status SetElementZero(Tensor* element); + + // Copies element into the index^th slice (in the first dimension) + // of parent. Allows for the parent's slice to have a larger size + // than the element, and copies the element into the upper left hand + // corner of the slice. 
+ static absl::Status CopyElementToLargerSlice(const Tensor& element, + Tensor* parent, int index); + + std::vector partial_shapes_; + + private: + ~PaddingFIFOQueue() override {} + + static absl::Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple, + int component, OpKernelContext* ctx, + Tensor* out_tensor); + + static absl::Status IsSameSizeExceptZerosInFirst(const TensorShape& first, + const TensorShape& second); + + PaddingFIFOQueue(const PaddingFIFOQueue&) = delete; + void operator=(const PaddingFIFOQueue&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/parameterized_truncated_normal_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/parameterized_truncated_normal_op.h new file mode 100644 index 00000000..4df75c78 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/parameterized_truncated_normal_op.h @@ -0,0 +1,66 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/util/bcast.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace functor { + +// Sample a truncated normal random variable, with mean, stddev, minval, and +// maxval parameters for each batch. Uses two rejection sampling algorithms +// described in http://rd.springer.com/article/10.1007/BF00143942 and a randn +// rejection sampler when most of the normal is inside the bounds. +// +// Either minval may be -infinity, or maxval may be +infinity. If the interval +// (minval, maxval) is empty, the result is NaN. +template +struct TruncatedNormalFunctor { + void operator()(OpKernelContext* ctx, const Device& d, int64_t num_batches, + int64_t samples_per_batch, int64_t num_elements, + typename TTypes::ConstFlat means, + typename TTypes::ConstFlat stddevs, + typename TTypes::ConstFlat minvals, + typename TTypes::ConstFlat maxvals, + const random::PhiloxRandom& gen, + typename TTypes::Flat output); +}; + +// This version supports broadcasting of the arguments, as well as puts +// the sample dimension on the left. 
+template +struct TruncatedNormalFunctorV2 { + void operator()(OpKernelContext* ctx, const Device& d, int64_t num_batches, + int64_t samples_per_batch, int64_t num_elements, + const BCastList<4>& bcast, + typename TTypes::ConstFlat means, + typename TTypes::ConstFlat stddevs, + typename TTypes::ConstFlat minvals, + typename TTypes::ConstFlat maxvals, + const random::PhiloxRandom& gen, + typename TTypes::Flat output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/partitioned_function_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/partitioned_function_ops.h new file mode 100644 index 00000000..2b2ec8ea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/partitioned_function_ops.h @@ -0,0 +1,73 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_ + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +class NameAttrList; +class ConfigProto; + +// A `PartitionedCallOp` asynchronously executes a function, potentially across +// multiple devices but within a single process. The kernel places and +// partitions a given function's underlying graph, and executes each of the +// partitioned subgraphs as a function. +// +// TODO(akshayka): Support distributed execution. +class PartitionedCallOp : public AsyncOpKernel { + public: + explicit PartitionedCallOp(OpKernelConstruction* ctx); + + ~PartitionedCallOp() override; + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + protected: + absl::Status FillOutputDevices( + const FunctionLibraryRuntime& lib, const Device& cpu_device, + AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions* opts); + + absl::Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx, + std::vector* inputs, + FunctionLibraryRuntime::Handle* handle); + + void RunFunction(FunctionLibraryRuntime::Handle handle, + const std::vector& inputs, + FunctionLibraryRuntime* lib, OpKernelContext* ctx, + DoneCallback done); + + // Using unique pointers to avoid including proto headers in kernel headers + std::unique_ptr func_; + std::unique_ptr config_proto_; + string executor_type_; + bool shared_rendezvous_; + mutex mu_; + // Cache the handle per FLR because this kernel may be instantiated for + // a stateful op, different invocations of it may use different FLRs. + // Different device placements of PartitionedCallOp also use + // different FLRs. 
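+  // Sketch of the assumed flow (the .cc is not part of this header set):
+  // ComputeAsync() looks up the calling FLR in this map under mu_; on a miss
+  // it calls Instantiate() and caches the returned handle, and on a hit it
+  // reuses the cached handle and proceeds directly to RunFunction().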
+ gtl::FlatMap handles_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_PARTITIONED_FUNCTION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/poisson-loss.h b/third_party/tflite-hdrs/tensorflow/core/kernels/poisson-loss.h new file mode 100644 index 00000000..d946b066 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/poisson-loss.h @@ -0,0 +1,109 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_ +#define TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_ + +#include + +#include "tensorflow/core/kernels/loss.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +class PoissonLossUpdater : public DualLossUpdater { + public: + // Update is found by a Newton algorithm (see readme.md). + double ComputeUpdatedDual(const int num_loss_partitions, const double label, + const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm) const final { + // Newton algorithm converges quadratically so 10 steps will be largely + // enough to achieve a very good precision + static const int newton_total_steps = 10; + // Initialize the Newton optimization at x such that + // exp(x) = label - current_dual + const double y_minus_a = label - current_dual; + double x = (y_minus_a > 0) ? log(y_minus_a) : 0; + for (int i = 0; i < newton_total_steps; ++i) { + x = NewtonStep(x, num_loss_partitions, label, wx, example_weight, + weighted_example_norm, current_dual); + } + return label - exp(x); + } + + // Dual of poisson loss function. + // https://en.wikipedia.org/wiki/Convex_conjugate + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { + // Dual of the poisson loss function is + // (y-a)*(log(y-a)-1), where a is the dual variable. + // It is defined only for a::max(); + } + return y_minus_a * (log(y_minus_a) - 1) * example_weight; + } + + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { + return (exp(wx) - wx * example_label) * example_weight; + } + + double PrimalLossDerivative(const double wx, const double label, + const double example_weight) const final { + return (exp(wx) - label) * example_weight; + } + + // TODO(chapelle): We need to introduce a maximum_prediction parameter, + // expose that parameter to the user and have this method return + // 1.0/maximum_prediction. + // Setting this at 1 for now, it only impacts the adaptive sampling. + double SmoothnessConstant() const final { return 1; } + + absl::Status ConvertLabel(float* const example_label) const final { + if (*example_label < 0.0) { + return errors::InvalidArgument( + "Only non-negative labels can be used with the Poisson log loss. 
" + "Found example with label: ", *example_label); + } + return absl::OkStatus(); + } + + private: + // One Newton step (see readme.md). + double NewtonStep(const double x, const int num_loss_partitions, + const double label, const double wx, + const double example_weight, + const double weighted_example_norm, + const double current_dual) const { + const double expx = exp(x); + const double numerator = + x - wx - num_loss_partitions * weighted_example_norm * + example_weight * (label - current_dual - expx); + const double denominator = + 1 + num_loss_partitions * weighted_example_norm * example_weight * expx; + return x - numerator / denominator; + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_3d.h b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_3d.h new file mode 100644 index 00000000..c0a589ff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_3d.h @@ -0,0 +1,80 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_ +#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +enum PoolingType { MAX, AVG }; + +template +struct LaunchPoolingOp; + +template +struct LaunchAvgPooling3dGradOp; + +template +struct LaunchMaxPooling3dGradOp; + +template +struct LaunchMaxPooling3dGradGradOp; + +// A helper class to manage sizes and shapes for 3d pooling operations. +struct Pool3dParameters { + // Updates context->status if there is an invalid input. + Pool3dParameters(OpKernelContext* context, const std::vector& ksize, + const std::vector& stride, Padding padding, + TensorFormat data_format, + const TensorShape& tensor_in_shape); + + // Returns the shape of the output for "forward" pooling operations. + absl::Status forward_output_shape(TensorShape* shape); + + int depth; + + int tensor_in_planes; + int tensor_in_cols; + int tensor_in_rows; + int tensor_in_batch; + + int window_planes; + int window_cols; + int window_rows; + int depth_window; + + int plane_stride; + int col_stride; + int row_stride; + int depth_stride; + + int64_t out_plane; + int64_t out_height; + int64_t out_width; + + int64_t pad_planes; + int64_t pad_cols; + int64_t pad_rows; + + TensorFormat data_format; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_3d_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_3d_gpu.h new file mode 100644 index 00000000..002964a3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_3d_gpu.h @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support +#endif + +#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_ + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace functor { +template +struct MaxPool3dGradBackward { + bool operator()(TensorFormat data_format, const T* bottom_data, + const T* output_data, const int batch, const int pooled_plane, + const int pooled_height, const int pooled_width, + const int channels, const int plane, const int height, + const int width, const int kernel_p, const int kernel_h, + const int kernel_w, const int stride_p, const int stride_h, + const int stride_w, const int pad_p, const int pad_t, + const int pad_l, const T* top_diff, T* bottom_diff, + const Eigen::GpuDevice& d); +}; +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_common.h new file mode 100644 index 00000000..bb5dda56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_common.h @@ -0,0 +1,681 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_ + +#include + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/avgpooling_op.h" +#include "tensorflow/core/kernels/maxpooling_op.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/work_sharder.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/kernels/maxpooling_op_gpu.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +// A helper class to manage sizes and shapes for pooling operations. +struct PoolParameters { + // Updates context->status if there is an invalid input. + // explicit_paddings has eight elements if padding==EXPLIICT, and zero + // elements otherwise. + PoolParameters(OpKernelContext* context, const std::vector& ksize, + const std::vector& stride, Padding padding, + std::vector explicit_paddings, + TensorFormat data_format, const TensorShape& tensor_in_shape); + + // Returns the shape of the output for "forward" pooling operations. + absl::Status forward_output_shape(TensorShape* shape); + + int depth; + + int tensor_in_cols; + int tensor_in_rows; + int tensor_in_batch; + + int window_rows; + int window_cols; + int depth_window; + + int row_stride; + int col_stride; + int depth_stride; + + int64_t out_height; + int64_t out_width; + int out_depth; + + int64_t pad_top; + int64_t pad_bottom; + int64_t pad_left; + int64_t pad_right; + + int pad_depth; + + TensorFormat data_format; +}; + +// An implementation of MaxPooling (forward). +// TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op, +// QuantizedMaxPoolingOp depends on MaxPoolingOp so keep intact for now +template +class MaxPoolingOp : public OpKernel { + public: + explicit MaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) { + string data_format; + auto status = context->GetAttr("data_format", &data_format); + if (status.ok()) { + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Default MaxPoolingOp only supports NHWC ", + "on device type ", + DeviceTypeString(context->device_type()))); + } else { + data_format_ = FORMAT_NHWC; + } + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES( + context, + ksize_[0] > 0 && ksize_[1] > 0 && ksize_[2] > 0 && ksize_[3] > 0, + errors::InvalidArgument( + absl::StrCat("Sliding window ksize must be positive. 
The " + "specified or inferred ksize is: ", + absl::StrJoin(ksize_, ",")))); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + if (padding_ == Padding::EXPLICIT) { + OP_REQUIRES_OK( + context, context->GetAttr("explicit_paddings", &explicit_paddings_)); + } + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + PoolParameters params{ + context, ksize_, stride_, padding_, explicit_paddings_, + FORMAT_NHWC, tensor_in.shape()}; + if (!context->status().ok()) { + return; + } + + Tensor* output = nullptr; + TensorShape params_forward_output_shape; + OP_REQUIRES_OK(context, + params.forward_output_shape(¶ms_forward_output_shape)); + OP_REQUIRES_OK(context, context->allocate_output( + 0, params_forward_output_shape, &output)); + + if (params.depth_window > 1) { + // Validate spec against the current implementation. A + // relaxation of these requirements would be ideal. + OP_REQUIRES(context, params.depth % params.depth_window == 0, + errors::Unimplemented( + "Depthwise max pooling requires " + "the depth window to evenly divide the input depth.")); + OP_REQUIRES( + context, params.depth_window == params.depth_stride, + errors::Unimplemented("Depthwise max pooling requires " + "the depth window to equal the depth stride.")); + OP_REQUIRES( + context, padding_ != EXPLICIT, + errors::Unimplemented("Depthwise max pooling does not support " + "explicit padding.")); + + DepthwiseMaxPool(context, output, tensor_in, params); + } else { + // MaxPoolingOp is only called on the GPU when the eigen_tensor label + // is used. In this case, explicit padding is not supported + if (std::is_same::value && + padding_ == Padding::EXPLICIT) { + context->SetStatus(errors::Unimplemented( + "MaxPoolingOp does not support explicit padding.")); + return; + } + SpatialMaxPool(context, output, tensor_in, params, padding_); + } + } + + private: + // Single-threaded implementation of DepthwiseMaxPool which + // does not handle all of the same options as SpatialMaxPool + // (strict assumptions on no padding, stride). + // + // TODO(vrv): implement a more general depthwise-max pool that works + // on GPU as well. + void DepthwiseMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params) { + Eigen::Map> + in_by_pool(tensor_in.flat().data(), params.depth_window, + tensor_in.NumElements() / params.depth_window); + Eigen::Map> out_by_pool( + output->flat().data(), 1, output->NumElements()); + out_by_pool = in_by_pool.colwise().maxCoeff(); + } + + void SpatialMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params, + const Padding& padding) { + if (output->NumElements() == 0) { + return; + } + // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an + // EigenMatrix version that is currently faster than Eigen's + // Spatial MaxPooling implementation. + // + // TODO(vrv): Remove this once we no longer need it. 
+ if (std::is_same::value) { + Eigen::PaddingType pt = BrainPadding2EigenPadding(padding); + functor::SpatialMaxPooling()( + context->eigen_device(), output->tensor(), + tensor_in.tensor(), params.window_rows, params.window_cols, + params.row_stride, params.col_stride, pt); + } else { + typedef Eigen::Map> + ConstEigenMatrixMap; + typedef Eigen::Map> + EigenMatrixMap; + + ConstEigenMatrixMap in_mat(tensor_in.flat().data(), params.depth, + params.tensor_in_cols * params.tensor_in_rows * + params.tensor_in_batch); + EigenMatrixMap out_mat( + output->flat().data(), params.depth, + params.out_width * params.out_height * params.tensor_in_batch); + + const DeviceBase::CpuWorkerThreads& worker_threads = + *(context->device()->tensorflow_cpu_worker_threads()); + + // The following code basically does the following: + // 1. Flattens the input and output tensors into two dimensional arrays. + // tensor_in_as_matrix: + // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch) + // output_as_matrix: + // depth by (out_width * out_height * tensor_in_batch) + // + // 2. Walks through the set of columns in the flattened + // tensor_in_as_matrix, + // and updates the corresponding column(s) in output_as_matrix with the + // max value. + auto shard = [¶ms, &in_mat, &out_mat](int64_t start, int64_t limit) { + const int32_t in_rows = params.tensor_in_rows; + const int32_t in_cols = params.tensor_in_cols; + const int32_t pad_top = params.pad_top; + const int32_t pad_left = params.pad_left; + const int32_t window_rows = params.window_rows; + const int32_t window_cols = params.window_cols; + const int32_t row_stride = params.row_stride; + const int32_t col_stride = params.col_stride; + const int32_t out_height = params.out_height; + const int32_t out_width = params.out_width; + + { + // Initializes the output tensor with MIN. + const int32_t output_image_size = + out_height * out_width * params.depth; + EigenMatrixMap out_shard(out_mat.data() + start * output_image_size, + 1, (limit - start) * output_image_size); + out_shard.setConstant(Eigen::NumTraits::lowest()); + } + + for (int32_t b = start; b < limit; ++b) { + const int32_t out_offset_batch = b * out_height; + for (int32_t h = 0; h < in_rows; ++h) { + for (int32_t w = 0; w < in_cols; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int32_t hpad = h + pad_top; + const int32_t wpad = w + pad_left; + const int32_t h_start = + (hpad < window_rows) ? 0 + : (hpad - window_rows) / row_stride + 1; + const int32_t h_end = std::min(hpad / row_stride + 1, out_height); + const int32_t w_start = + (wpad < window_cols) ? 0 + : (wpad - window_cols) / col_stride + 1; + const int32_t w_end = std::min(wpad / col_stride + 1, out_width); + // compute elementwise max + const int32_t in_offset = (b * in_rows + h) * in_cols + w; + for (int32_t ph = h_start; ph < h_end; ++ph) { + const int32_t out_offset_base = + (out_offset_batch + ph) * out_width; + for (int32_t pw = w_start; pw < w_end; ++pw) { + const int32_t out_offset = out_offset_base + pw; + out_mat.col(out_offset) = + out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset)); + } + } + } + } + } + }; + + // TODO(andydavis) Consider sharding across batch x rows x cols. + // TODO(andydavis) Consider a higher resolution shard cost model. 
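+ // Worked example for the (h_start, h_end) inversion in the shard lambda
+ // above (illustrative values): with window_rows == 3, row_stride == 2,
+ // pad_top == 0 and out_height >= 3, input row h == 5 gives hpad == 5,
+ // h_start == (5 - 3) / 2 + 1 == 2 and h_end == min(5 / 2 + 1, out_height)
+ // == 3, so only output row 2 (covering input rows 4..6) gets cwiseMax-ed
+ // with this input column.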
+ const int64_t shard_cost = + params.tensor_in_rows * params.tensor_in_cols * params.depth; + Shard(worker_threads.num_threads, worker_threads.workers, + params.tensor_in_batch, shard_cost, shard); + } + } + + std::vector ksize_; + std::vector stride_; + Padding padding_; + std::vector explicit_paddings_; + TensorFormat data_format_; +}; + +template +struct LaunchMaxPoolingNoMask_NCHW_VECT_C; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template <> +struct LaunchMaxPoolingNoMask_NCHW_VECT_C { + static void launch(OpKernelContext* context, const PoolParameters& params, + const Tensor& input, Tensor* output) { +#if GOOGLE_CUDA + bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()( + reinterpret_cast(input.flat().data()), + params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, + params.depth, params.out_height, params.out_width, params.window_rows, + params.window_cols, params.row_stride, params.col_stride, + params.pad_top, params.pad_left, + reinterpret_cast(output->flat().data()), + context->eigen_gpu_device()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C")); + } +#else + // ROCm TODO: add support __vmaxs4 on ROCm + context->SetStatus(errors::Internal( + "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C")); +#endif // GOOGLE_CUDA + } +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +class MaxPoolingV2Op : public OpKernel { + public: + explicit MaxPoolingV2Op(OpKernelConstruction* context) : OpKernel(context) { + string data_format; + auto status = context->GetAttr("data_format", &data_format); + if (status.ok()) { + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, + data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW_VECT_C, + errors::InvalidArgument( + "MaxPoolingV2Op only supports NHWC or NCHW_VECT_C. 
Got: ", + data_format)); + } else { + data_format_ = FORMAT_NHWC; + } + if (context->num_inputs() == 1) { + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); + OP_REQUIRES(context, ksize_.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES( + context, + ksize_[0] > 0 && ksize_[1] > 0 && ksize_[2] > 0 && ksize_[3] > 0, + errors::InvalidArgument("Sliding window ksize must be positive.")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + } + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& tensor_in = context->input(0); + + std::vector ksize = ksize_; + std::vector stride = stride_; + + if (context->num_inputs() != 1) { + const Tensor& tensor_ksize = context->input(1); + auto value_ksize = tensor_ksize.flat(); + ksize.resize(tensor_ksize.shape().num_elements()); + std::copy_n(&value_ksize(0), ksize.size(), ksize.begin()); + + const Tensor& tensor_stride = context->input(2); + auto value_stride = tensor_stride.flat(); + stride.resize(tensor_stride.shape().num_elements()); + std::copy_n(&value_stride(0), stride.size(), stride.begin()); + } + + OP_REQUIRES(context, ksize.size() == 4, + errors::InvalidArgument("Sliding window ksize field must " + "specify 4 dimensions")); + OP_REQUIRES( + context, ksize[0] > 0 && ksize[1] > 0 && ksize[2] > 0 && ksize[3] > 0, + errors::InvalidArgument("Sliding window ksize must be positive.")); + OP_REQUIRES(context, stride.size() == 4, + errors::InvalidArgument("Sliding window stride field must " + "specify 4 dimensions")); + OP_REQUIRES(context, ksize[0] == 1 && stride[0] == 1, + errors::Unimplemented( + "Pooling is not yet supported on the batch dimension.")); + + PoolParameters params{ + context, + ksize, + stride, + padding_, + /*explicit_paddings=*/{}, + data_format_, + tensor_in.shape(), + }; + if (!context->status().ok()) { + return; + } + + Tensor* output = nullptr; + TensorShape params_forward_output_shape; + OP_REQUIRES_OK(context, + params.forward_output_shape(¶ms_forward_output_shape)); + OP_REQUIRES_OK(context, context->allocate_output( + 0, params_forward_output_shape, &output)); + + if (params.depth_window > 1) { + // Validate spec against the current implementation. A + // relaxation of these requirements would be ideal. + OP_REQUIRES(context, params.depth % params.depth_window == 0, + errors::Unimplemented( + "Depthwise max pooling requires " + "the depth window to evenly divide the input depth.")); + OP_REQUIRES( + context, params.depth_window == params.depth_stride, + errors::Unimplemented("Depthwise max pooling requires " + "the depth window to equal the depth stride.")); + + DepthwiseMaxPool(context, output, tensor_in, params); + } else { + SpatialMaxPool(context, output, tensor_in, params, padding_); + } + } + + private: + // Single-threaded implementation of DepthwiseMaxPool which + // does not handle all of the same options as SpatialMaxPool + // (strict assumptions on no padding, stride). + // + // TODO(vrv): implement a more general depthwise-max pool that works + // on GPU as well. 
+ void DepthwiseMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params) { + Eigen::Map> + in_by_pool(tensor_in.flat().data(), params.depth_window, + tensor_in.NumElements() / params.depth_window); + Eigen::Map> out_by_pool( + output->flat().data(), 1, output->NumElements()); + out_by_pool = in_by_pool.colwise().maxCoeff(); + } + + void SpatialMaxPool(OpKernelContext* context, Tensor* output, + const Tensor& tensor_in, const PoolParameters& params, + const Padding& padding) { + if (output->NumElements() == 0) { + return; + } + // On GPU, use Eigen's Spatial Max Pooling. On CPU, use an + // EigenMatrix version that is currently faster than Eigen's + // Spatial MaxPooling implementation. + // + // TODO(vrv): Remove this once we no longer need it. +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + if (std::is_same::value) { + Eigen::PaddingType pt = BrainPadding2EigenPadding(padding); + if (std::is_same::value) { + LaunchMaxPoolingNoMask_NCHW_VECT_C::launch( + context, params, tensor_in, output); + } else { + functor::SpatialMaxPooling()( + context->eigen_device(), output->tensor(), + tensor_in.tensor(), params.window_rows, params.window_cols, + params.row_stride, params.col_stride, pt); + } + } else +#endif + { + typedef Eigen::Map> + ConstEigenMatrixMap; + typedef Eigen::Map> + EigenMatrixMap; + + ConstEigenMatrixMap in_mat(tensor_in.flat().data(), params.depth, + params.tensor_in_cols * params.tensor_in_rows * + params.tensor_in_batch); + EigenMatrixMap out_mat( + output->flat().data(), params.depth, + params.out_width * params.out_height * params.tensor_in_batch); + + const DeviceBase::CpuWorkerThreads& worker_threads = + *(context->device()->tensorflow_cpu_worker_threads()); + + // The following code basically does the following: + // 1. Flattens the input and output tensors into two dimensional arrays. + // tensor_in_as_matrix: + // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch) + // output_as_matrix: + // depth by (out_width * out_height * tensor_in_batch) + // + // 2. Walks through the set of columns in the flattened + // tensor_in_as_matrix, + // and updates the corresponding column(s) in output_as_matrix with the + // max value. + auto shard = [¶ms, &in_mat, &out_mat](int64_t start, int64_t limit) { + const int32_t in_rows = params.tensor_in_rows; + const int32_t in_cols = params.tensor_in_cols; + const int32_t pad_top = params.pad_top; + const int32_t pad_left = params.pad_left; + const int32_t window_rows = params.window_rows; + const int32_t window_cols = params.window_cols; + const int32_t row_stride = params.row_stride; + const int32_t col_stride = params.col_stride; + const int32_t out_height = params.out_height; + const int32_t out_width = params.out_width; + + { + // Initializes the output tensor with MIN. + const int32_t output_image_size = + out_height * out_width * params.depth; + EigenMatrixMap out_shard(out_mat.data() + start * output_image_size, + 1, (limit - start) * output_image_size); + out_shard.setConstant(Eigen::NumTraits::lowest()); + } + + for (int32_t b = start; b < limit; ++b) { + const int32_t out_offset_batch = b * out_height; + for (int32_t h = 0; h < in_rows; ++h) { + for (int32_t w = 0; w < in_cols; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int32_t hpad = h + pad_top; + const int32_t wpad = w + pad_left; + const int32_t h_start = + (hpad < window_rows) ? 
0 + : (hpad - window_rows) / row_stride + 1; + const int32_t h_end = std::min(hpad / row_stride + 1, out_height); + const int32_t w_start = + (wpad < window_cols) ? 0 + : (wpad - window_cols) / col_stride + 1; + const int32_t w_end = std::min(wpad / col_stride + 1, out_width); + // compute elementwise max + const int32_t in_offset = (b * in_rows + h) * in_cols + w; + for (int32_t ph = h_start; ph < h_end; ++ph) { + const int32_t out_offset_base = + (out_offset_batch + ph) * out_width; + for (int32_t pw = w_start; pw < w_end; ++pw) { + const int32_t out_offset = out_offset_base + pw; + out_mat.col(out_offset) = + out_mat.col(out_offset).cwiseMax(in_mat.col(in_offset)); + } + } + } + } + } + }; + + // TODO(andydavis) Consider sharding across batch x rows x cols. + // TODO(andydavis) Consider a higher resolution shard cost model. + const int64_t shard_cost = + params.tensor_in_rows * params.tensor_in_cols * params.depth; + Shard(worker_threads.num_threads, worker_threads.workers, + params.tensor_in_batch, shard_cost, shard); + } + } + + std::vector ksize_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_; +}; + +template +void SpatialAvgPool(OpKernelContext* context, Tensor* output, + const Tensor& input, const PoolParameters& params, + const Padding& padding) { + if (output->NumElements() == 0) { + return; + } + typedef Eigen::Map> + ConstEigenMatrixMap; + typedef Eigen::Map> + EigenMatrixMap; + + auto in_flat = input.flat(); + auto out_flat = output->flat(); + + auto shard = [¶ms, &in_flat, &out_flat](int64_t start, int64_t limit) { + // Calculate indices for this shards chunk of work. + const int64_t input_image_size = + params.tensor_in_rows * params.tensor_in_cols * params.depth; + const int64_t output_image_size = + params.out_width * params.out_height * params.depth; + const int64_t shard_batch_size = limit - start; + + ConstEigenMatrixMap in_mat( + in_flat.data() + start * input_image_size, params.depth, + params.tensor_in_cols * params.tensor_in_rows * shard_batch_size); + EigenMatrixMap out_mat( + out_flat.data() + start * output_image_size, params.depth, + params.out_width * params.out_height * shard_batch_size); + Eigen::Matrix out_count(out_mat.cols()); + out_count.setZero(); + + // Initializes output to zero. + out_mat.setZero(); + + // The following code basically does the following: + // 1. Flattens the input and output tensors into two dimensional arrays. + // tensor_in_as_matrix: + // depth by (tensor_in_cols * tensor_in_rows * tensor_in_batch) + // output_as_matrix: + // depth by (out_width * out_height * tensor_in_batch) + // + // 2. Walks through the set of columns in the flattened + // tensor_in_as_matrix, + // and updates the corresponding column(s) in output_as_matrix with the + // average value. + for (int b = 0; b < shard_batch_size; ++b) { + for (int h = 0; h < params.tensor_in_rows; ++h) { + for (int w = 0; w < params.tensor_in_cols; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int hpad = h + params.pad_top; + const int wpad = w + params.pad_left; + const int h_start = + (hpad < params.window_rows) + ? 0 + : (hpad - params.window_rows) / params.row_stride + 1; + const int h_end = + std::min(hpad / params.row_stride + 1, params.out_height); + const int w_start = + (wpad < params.window_cols) + ? 
0 + : (wpad - params.window_cols) / params.col_stride + 1; + const int w_end = + std::min(wpad / params.col_stride + 1, params.out_width); + const int in_offset = + (b * params.tensor_in_rows + h) * params.tensor_in_cols + w; + Eigen::DSizes in_indices(0, in_offset); + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_offset = + (b * params.out_height + ph) * params.out_width + pw; + out_mat.col(out_offset) += in_mat.col(in_offset); + out_count(out_offset) += T(1); + } + } + } + } + } + + DCHECK_GT(out_count.minCoeff(), T(0)); + out_mat.array().rowwise() /= out_count.transpose().array(); + }; + + const int64_t work_unit_size = + params.tensor_in_rows * params.tensor_in_cols * params.depth; + // NOTE: Constants in calculation below were estimated based on benchmarking. + // Nanoseconds/work_unit for benchmarks ranged from 0.01 to 0.001, and + // so the factor 0.01 (i.e. 1/100) with a max of 10000, was chosen to limit + // the work unit cost to an operating range in which it empirically performed + // best. + const int64_t work_unit_cost = std::max(int64_t{10000}, work_unit_size / 100); + const DeviceBase::CpuWorkerThreads& worker_threads = + *(context->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, + params.tensor_in_batch, work_unit_cost, shard); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_common_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_common_gpu.h new file mode 100644 index 00000000..c5d51e59 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/pooling_ops_common_gpu.h @@ -0,0 +1,70 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda or ROCm support +#endif + +#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_ + +#include +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/avgpooling_op.h" +#include "tensorflow/core/kernels/maxpooling_op.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// A helper class that launch the cudnn pooling forward operations. 
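+// A typical call site, in sketch form (variable names here are illustrative,
+// and kMaximum is the StreamExecutor pooling mode assumed for max pooling):
+//   DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum,
+//                            ksize, stride, padding, explicit_paddings,
+//                            data_format, tensor_in, out_shape,
+//                            propagate_nans);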
+template +class DnnPoolingOp { + public: + typedef GPUDevice Device; + static void Compute(OpKernelContext* context, + se::dnn::PoolingMode pooling_mode, + const std::vector& size, + const std::vector& stride, Padding padding, + std::vector explicit_paddings, + TensorFormat data_format, const Tensor& tensor_in, + const TensorShape& tensor_out_shape, bool propagate_nans); +}; + +// A helper class that launch the cudnn pooling backward operations. +// The original input and output tensors are optional for AvgPoolGrad, but +// mandatory for MaxPoolGrad. +template +class DnnPoolingGradOp { + public: + typedef GPUDevice Device; + static void Compute(OpKernelContext* context, + se::dnn::PoolingMode pooling_mode, + const std::vector& size, + const std::vector& stride, Padding padding, + std::vector explicit_paddings, + TensorFormat data_format, const Tensor* tensor_in, + const Tensor* tensor_out, const Tensor& out_backprop, + const TensorShape& tensor_in_shape, bool propagate_nans); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/population_count_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/population_count_op.h new file mode 100644 index 00000000..2c981296 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/population_count_op.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace functor { + +template +struct PopulationCount { + void operator()(OpKernelContext* c, typename TTypes::ConstFlat input, + TTypes::Flat output); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_POPULATION_COUNT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/priority_queue.h b/third_party/tflite-hdrs/tensorflow/core/kernels/priority_queue.h new file mode 100644 index 00000000..f7ca800a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/priority_queue.h @@ -0,0 +1,95 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_ +#define TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/typed_queue.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using PriorityTensorPair = std::pair; + +struct ComparePriorityTensorPair { + // 0 is a higher priority than 1, -MAX_LONG is a higher priority + // than MAX_LONG, etc. Values coming in with a smaller + // priority number will bubble to the front of the queue. + bool operator()(const PriorityTensorPair& lhs, + const PriorityTensorPair& rhs) const { + return lhs.first > rhs.first; + } +}; + +class PriorityQueue + : public TypedQueue, + ComparePriorityTensorPair> > { + public: + PriorityQueue(int32_t capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + + absl::Status Initialize() + override; // Must be called before any other method. + + // Implementations of QueueInterface methods -------------------------------- + + void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx, + DoneCallback callback) override; + void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override; + void TryDequeueMany(int num_elements, OpKernelContext* ctx, + bool allow_small_batch, + CallbackWithTuple callback) override; + absl::Status MatchesNodeDef(const NodeDef& node_def) override; + absl::Status MatchesPriorityNodeDefTypes(const NodeDef& node_def) const; + absl::Status MatchesPriorityNodeDefShapes(const NodeDef& node_def) const; + + int32 size() const override { + mutex_lock lock(mu_); + return queues_[0].size(); + } + + private: + ~PriorityQueue() override {} + + // Helper for dequeuing a single element from queues_. + void DequeueLocked(OpKernelContext* ctx, Tuple* tuple) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + static absl::Status GetElementComponentFromBatch(const Tuple& tuple, + int index, int component, + OpKernelContext* ctx, + Tensor* out_element); + + PriorityQueue(const PriorityQueue&) = delete; + void operator=(const PriorityQueue&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/quantization_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/quantization_utils.h new file mode 100644 index 00000000..88bee911 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/quantization_utils.h @@ -0,0 +1,968 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ + +#include +#define EIGEN_USE_THREADS + +// This is a set of functions that standardizes how quantized values are +// interpreted as float numbers. +// All of the current implementations are for reference and have not been +// optimized. They should be implementable using fixed point representations +// to avoid a dependency on floating-point hardware. + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define QUANTIZATION_UTILS_USE_NEON +#include +#endif + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#include "public/gemmlowp.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { + +// We have to be able to detect and handle overflows in int32, so this function +// uses doubles and int64's to make sure we have enough room. +template +inline int64_t FloatToQuantizedUnclamped(float input, float range_min, + float range_max) { + const int64_t lowest_quantized = + static_cast(Eigen::NumTraits::lowest()); + if (range_min == range_max) { + return lowest_quantized; + } + const int number_of_bits = sizeof(T) * 8; + const int64_t number_of_steps = static_cast(1) << number_of_bits; + const double range_adjust = (number_of_steps / (number_of_steps - 1.0)); + const double range = ((range_max - range_min) * range_adjust); + const double range_scale = (number_of_steps / range); + int64_t quantized = + (round(input * range_scale) - round(range_min * range_scale)); + quantized += lowest_quantized; + return quantized; +} + +template <> +inline int64_t FloatToQuantizedUnclamped(float input, float range_min, + float range_max) { + return -1; +} + +// This converts the float into the final quantized type, clamping/saturating +// any over or underflows. +template +T FloatToQuantized(float input, float range_min, float range_max) { + if (std::is_same::value) { + // Specialization for float. This is used in reference implementation + // for float which is useful to compare performance between float + // and quantized type. + return input; + } + int64_t quantized = FloatToQuantizedUnclamped(input, range_min, range_max); + const int64_t lowest_quantized = + static_cast(Eigen::NumTraits::lowest()); + const int64_t highest_quantized = + static_cast(Eigen::NumTraits::highest()); + quantized = std::max(quantized, lowest_quantized); + quantized = std::min(quantized, highest_quantized); + return static_cast(static_cast(quantized)); +} + +template +float QuantizedToFloat(T input, float range_min, float range_max) { + if (std::is_same::value) { + // Specialization for float. This is used in reference implementation + // for float which is useful to compare performance between float + // and quantized type. 
+ return input; + } + if (range_min == range_max) { + return range_min; + } + const int number_of_bits = sizeof(T) * 8; + const int64_t number_of_steps = static_cast(1) << number_of_bits; + const double range_adjust = (number_of_steps / (number_of_steps - 1.0)); + const double range = ((range_max - range_min) * range_adjust); + const double range_scale = (range / number_of_steps); + const int64_t lowest_quantized = + static_cast(Eigen::NumTraits::lowest()); + const double offset_input = static_cast(input) - lowest_quantized; + // For compatibility with DEQUANTIZE_WITH_EIGEN, we should convert + // range_scale to a float, otherwise range_min_rounded might be slightly + // different. + const double range_min_rounded = + std::round(range_min / static_cast(range_scale)) * + static_cast(range_scale); + const double result = range_min_rounded + (offset_input * range_scale); + return static_cast(result); +} + +template +float FloatForOneQuantizedLevel(float range_min, float range_max) { + const int64_t highest = static_cast(Eigen::NumTraits::highest()); + const int64_t lowest = static_cast(Eigen::NumTraits::lowest()); + const float float_for_one_quantized_level = + (range_max - range_min) / (highest - lowest); + return float_for_one_quantized_level; +} + +template +void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b, + float max_b, float* min_c, + float* max_c) { + const float a_float_for_one_quant_level = + FloatForOneQuantizedLevel(min_a, max_a); + const float b_float_for_one_quant_level = + FloatForOneQuantizedLevel(min_b, max_b); + + const int64_t c_highest = + static_cast(Eigen::NumTraits::highest()); + const int64_t c_lowest = static_cast(Eigen::NumTraits::lowest()); + const float c_float_for_one_quant_level = + a_float_for_one_quant_level * b_float_for_one_quant_level; + + *min_c = c_float_for_one_quant_level * c_lowest; + *max_c = c_float_for_one_quant_level * c_highest; +} + +// input_array is an eigen Tensor. q2f is a QuantizedToFloatStruct. +// This evaluates to an eigen tensor expression, to be used like: +// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f); +#define DEQUANTIZE_WITH_EIGEN(input_array, q2f) \ + ((q2f.range_min_rounded - q2f.lowest_quantized() * q2f.range_scale) + \ + input_array.template cast() * q2f.range_scale) + +// input_array is an eigen Tensor. f2q is a FloatToQuantizedStruct. +// OutputType is the type of output (e.g. quint8). +// This evaluates to an eigen tensor expression, to be used like: +// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T); +#define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType) \ + ((input_array * f2q.range_scale).round() - \ + (f2q.range_min_scaled - f2q.lowest_quantized())) \ + .cwiseMax(f2q.lower_bound_float()) \ + .cwiseMin(f2q.upper_bound_float()) \ + .template cast() \ + .template cast() + +// For use with DEQUANTIZE_WITH_EIGEN. +template +struct QuantizedToFloatStruct { + static constexpr int number_of_bits = sizeof(T) * 8; + static constexpr int64_t number_of_steps = static_cast(1) + << number_of_bits; + + static float lowest_quantized() { + return static_cast(Eigen::NumTraits::lowest()); + } + + QuantizedToFloatStruct(float range_min, float range_max) + : range_min(range_min), + range_scale((range_max - range_min) / (number_of_steps - 1.0)), + range_min_rounded(range_max == range_min + ? range_min + : std::round(range_min / range_scale) * + range_scale) {} + + const float range_min; + const float range_scale; + const float range_min_rounded; +}; + +// For use with QUANTIZE_WITH_EIGEN. 
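+// Worked example of the mapping implemented above and encoded by the structs
+// below (range chosen for illustration): quantizing to quint8 over
+// [-1.0f, 1.0f] gives range_scale == 255 / 2 == 127.5, so
+// FloatToQuantized<quint8>(0.0f, -1.0f, 1.0f) == round(0) - round(-127.5) ==
+// 128, FloatToQuantized<quint8>(-1.0f, -1.0f, 1.0f) == 0, and
+// FloatToQuantized<quint8>(1.0f, -1.0f, 1.0f) saturates to 255; in the other
+// direction, QuantizedToFloat<quint8>(128, -1.0f, 1.0f) is approximately 0.0f.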
+template +struct FloatToQuantizedStruct { + static constexpr int number_of_bits = sizeof(T) * 8; + static constexpr int64_t number_of_steps = static_cast(1) + << number_of_bits; + static constexpr double range_adjust = + (number_of_steps / (number_of_steps - 1.0)); + + // Casting QInt32's lowest or highest to a float gives a float that can't be + // cast back to int32 or QInt32. Instead, use bounds that can be converted + // back to int32 without going outside the range of an int32. + static float lower_bound_float() { + return Eigen::numext::maxi( + static_cast(Eigen::NumTraits::lowest()), -2.147483648e+09f); + } + static float upper_bound_float() { + return Eigen::numext::mini( + static_cast(Eigen::NumTraits::highest()), +2.147483520e+09f); + } + + static float lowest_quantized() { + return static_cast(Eigen::NumTraits::lowest()); + } + + FloatToQuantizedStruct(float range_min, float range_max) + : range_min(range_min), + range_scale(range_max == range_min + ? 0.0 + : (number_of_steps - 1.0) / (range_max - range_min)), + range_min_scaled(std::round(range_min * range_scale)) {} + + const float range_min; + const float range_scale; + const float range_min_scaled; +}; + +template +inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input, + float min_new, float max_new) { + const float input_float = QuantizedToFloat(input, min_input, max_input); + return FloatToQuantized(input_float, min_new, max_new); +} + +template +inline void RequantizeManyInNewRange(const T1* input, int64_t count, + float min_input, float max_input, + float min_output, float max_output, + T2* output) { + for (size_t index = 0; index < count; ++index) { + const float input_float = + QuantizedToFloat(input[index], min_input, max_input); + output[index] = FloatToQuantized(input_float, min_output, max_output); + } +} + +// Because converting 32-bit accumulated results down to eight bit is a common +// case, we have a specialized code path to handle it as efficiently as +// possible using only fixed-point math for the inner loop. +inline void RequantizeManyInNewRangeReference(const qint32* input, + int64_t count, float min_input, + float max_input, float min_output, + float max_output, + quint8* output) { + // Initially we calculate all the constants we need once, before we go into + // the inner loop. If this is updated, also update the Eigen version. + const int fp_shift = 16; + const float input_range = max_input - min_input; + const float output_range = max_output - min_output; + const float recip_output_range = + output_range == 0.0 ? 0.0 : (255.0 / output_range); + const float input_rezero = (min_input + max_input) / 2.0; + const int64_t range_scale_fp = + output_range == 0.0 ? 0.0 + : static_cast(255.0 * (1 << fp_shift) * + input_range / output_range); + const int64_t input_offset_fp = + static_cast(input_rezero * recip_output_range * (1 << fp_shift)); + const int64_t output_offset_fp = + output_range == 0.0 + ? 0 + : std::lround((1 << fp_shift) * (min_output * 255.0) / output_range); + const int64_t rounding_delta = 1 << (fp_shift - 1); + + // Inside this loop we just do minimal adds, multiplies, and shifts, in a way + // that could be easily adapted for a SIMD implementation. It should also be + // possible to perform all the calculations in 32-bit rather than 64, but + // that's not been implemented yet. 
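+  // In rough terms: a qint32 code x stands for the float value
+  // x * input_range / 2^32 + input_rezero, and the target quint8 code is
+  // round((float_value - min_output) * 255 / output_range).  The loop below
+  // evaluates that expression with 16 fractional bits: (x * range_scale_fp)
+  // >> 32 supplies the x-dependent term, input_offset_fp and output_offset_fp
+  // fold in the two constant terms, and adding rounding_delta before the
+  // final >> fp_shift performs the rounding.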
+ for (int64_t index = 0; index < count; ++index) { + const int64_t input_value = static_cast(input[index]); + const int64_t fp_value = + ((input_value * range_scale_fp) >> 32) + input_offset_fp; + const int64_t offset_intermediate = fp_value - output_offset_fp; + const int64_t round_intermediate = offset_intermediate + rounding_delta; + int64_t quantized_int64 = round_intermediate >> fp_shift; + quantized_int64 = std::max(quantized_int64, int64_t{0}); + quantized_int64 = std::min(quantized_int64, int64_t{255}); + output[index] = static_cast(static_cast(quantized_int64)); + } +} + +// Another common case is converting eight bit inputs up to thirty two bits, so +// we have specialized fixed-point code to accelerate that. There is also a NEON +// version for ARM devices below. +inline void RequantizeManyInNewRange8To32BitReference( + const quint8* input, int64_t count, float min_input, float max_input, + float min_output, float max_output, qint32* output) { + const float code_0_float = QuantizedToFloat(0, min_input, max_input); + const float code_1_float = QuantizedToFloat(1, min_input, max_input); + const int64_t code_0_int64 = + FloatToQuantizedUnclamped(code_0_float, min_output, max_output); + const int64_t code_1_int64 = + FloatToQuantizedUnclamped(code_1_float, min_output, max_output); + const int32_t mult_int32 = code_1_int64 - code_0_int64; + const int64_t lowest_quantized = + static_cast(Eigen::NumTraits::lowest()); + const int64_t highest_quantized = + static_cast(Eigen::NumTraits::highest()); + for (int64_t i = 0; i < count; ++i) { + const int64_t input_value = static_cast(input[i]); + int64_t output_value = code_0_int64 + (input_value * mult_int32); + output_value = std::max(output_value, lowest_quantized); + output_value = std::min(output_value, highest_quantized); + output[i] = static_cast(output_value); + } +} + +#ifdef QUANTIZATION_UTILS_USE_NEON +// Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD +// intrinsics for ARM platforms. +inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count, + float min_input, float max_input, + float min_output, float max_output, + quint8* output) { + // Initially we calculate all the constants we need once, before we go into + // the inner loop. If this is updated, also update the Eigen version. + const int fp_shift = 16; + + // Calculate range variables in advance. + // Input range. + const float input_range = max_input - min_input; + // Output range. + const float output_range = max_output - min_output; + // Ratio of output range. + const float recip_output_range = + output_range == 0.0 ? 0.0 : (255.0 / output_range); + // Average of input range as zero position of input. + const float input_rezero = (min_input + max_input) / 2.0; + // In-out range scale. + const int32 range_scale_fp = + output_range == 0.0 ? 0.0 + : static_cast(255.0 * (1 << (fp_shift - 16)) * + input_range / output_range); + // Input zero position offset to output. + const int32 input_offset_fp = + static_cast(input_rezero * recip_output_range * (1 << fp_shift)); + // Output min offset. + const int32 output_offset_fp = + output_range == 0.0 + ? 
0 + : static_cast((1 << fp_shift) * (min_output * 255.0) / + output_range); + const int32 rounding_delta = 1 << (fp_shift - 1); + + // broadcast range to each lane + const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp); + const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp); + const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp); + const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta); + + int64 index = 0; + // Use SIMD to requantize. + for (; index < (count - 7); index += 8) { + const int32* input_ptr = &(input->value) + index; + const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr); + const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4); + const int32x4_t fp_value_low_32x4 = vaddq_s32( + input_offset_fp_32x4, + vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4)); + const int32x4_t fp_value_high_32x4 = vaddq_s32( + input_offset_fp_32x4, + vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4)); + const int32x4_t offset_intermediate_low_32x4 = + vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4); + const int32x4_t offset_intermediate_high_32x4 = + vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4); + const int32x4_t round_intermediate_low_32x4 = + vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4); + const int32x4_t round_intermediate_high_32x4 = + vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4); + const int16x4_t quantized_low_16x4 = + vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift)); + const int16x4_t quantized_high_16x4 = + vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift)); + const uint8x8_t quantized_8x8 = + vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4)); + uint8* output_ptr = &(output->value) + index; + vst1_u8(output_ptr, quantized_8x8); + } + + // Requantize remaining elements in array without SIMD. + for (; index < count; ++index) { + const int32 input_value = static_cast(input[index]); + const int32 fp_value = + static_cast( + (static_cast(input_value >> 16) * (range_scale_fp))) + + input_offset_fp; + const int32 offset_intermediate = fp_value - output_offset_fp; + const int32 round_intermediate = offset_intermediate + rounding_delta; + int32 quantized_int32 = round_intermediate >> fp_shift; + quantized_int32 = std::max(quantized_int32, 0); + quantized_int32 = std::min(quantized_int32, 255); + output[index] = static_cast(static_cast(quantized_int32)); + } +} + +template <> +inline void RequantizeManyInNewRange( + const qint32* input, int64 count, float min_input, float max_input, + float min_output, float max_output, quint8* output) { + const float input_range = max_input - min_input; + const float output_range = max_output - min_output; + if ((input_range / output_range) > 16384.0f) { + // Our NEON implementation uses 32-bit math and can't handle very + // large ranges, so fall back to the reference implementation. We don't + // expect these to be common in models, so this shouldn't be a performance + // problem in practice. + RequantizeManyInNewRangeReference(input, count, min_input, max_input, + min_output, max_output, output); + } else { + RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output, + max_output, output); + } +} + +// NEON accelerated 16bit rounded division by 2^n. 
+template +inline int16x8_t Divide16x8PowRound(const int16x8_t val) { + const int16x8_t val_sign = vshrq_n_s16(val, 15); + const int16x8_t val_xor = veorq_s16(val, val_sign); + const int16x8_t val_pos = vsubq_s16(val_xor, val_sign); + const int16x8_t shifted_val_pos = vrshrq_n_s16(val_pos, POW); + const int16x8_t shifted_val_pos_xor = veorq_s16(shifted_val_pos, val_sign); + const int16x8_t shifted_val = vsubq_s16(shifted_val_pos_xor, val_sign); + return shifted_val; +} + +// NEON accelerated 64bit rounded division by 2^n. +template +inline int64x2_t Divide64x2PowRound(const int64x2_t val) { + const int64x2_t val_sign = vshrq_n_s64(val, 63); + const int64x2_t val_xor = veorq_s64(val, val_sign); + const int64x2_t val_pos = vsubq_s64(val_xor, val_sign); + const int64x2_t shifted_val_pos = vrshrq_n_s64(val_pos, POW); + const int64x2_t shifted_val_pos_xor = veorq_s64(shifted_val_pos, val_sign); + const int64x2_t shifted_val = vsubq_s64(shifted_val_pos_xor, val_sign); + return shifted_val; +} + +// NEON accelerated 16bit division by 2^n. +// CAVEAT: The input must be greater than min-int16 to avoid underflow. +template +inline int16x8_t Divide16x8Pow(const int16x8_t val) { + static constexpr int16 FIRST_BIT_VAL = 0x0000000000000001; + static const int16x8_t FIRST_BIT = vmovq_n_s16(FIRST_BIT_VAL); + const int16x8_t val_sign = vshrq_n_s16(val, 15); + const int16x8_t neg_offset = vandq_s16(val_sign, FIRST_BIT); + const int16x8_t val_with_offset = vsubq_s16(val, neg_offset); + const int16x8_t shifted_wo_offset = + vsraq_n_s16(neg_offset, val_with_offset, POW); + return shifted_wo_offset; +} + +// NEON accelerated 64bit division by 2^n. +// CAVEAT: The input must be greater than min-int64 to avoid underflow. +template +inline int64x2_t Divide64x2Pow(const int64x2_t val) { + static constexpr int64 FIRST_BIT_VAL = 0x0000000000000001; + static const int64x2_t FIRST_BIT = vmovq_n_s64(FIRST_BIT_VAL); + const int64x2_t val_sign = vshrq_n_s64(val, 63); + const int64x2_t neg_offset = vandq_s64(val_sign, FIRST_BIT); + const int64x2_t val_with_offset = vsubq_s64(val, neg_offset); + const int64x2_t shifted_wo_offset = + vsraq_n_s64(neg_offset, val_with_offset, POW); + return shifted_wo_offset; +} + +// 32bit x 2 NEON accelerated lerp computation. 
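+// Both lerp helpers below evaluate the usual bilinear interpolation
+//   top    = top_left    + (top_right    - top_left)    * x_lerp
+//   bottom = bottom_left + (bottom_right - bottom_left) * x_lerp
+//   result = top + (bottom - top) * y_lerp
+// where x_lerp and y_lerp are RESOLUTION-bit fixed-point fractions, so each
+// multiply is paired with one of the divide-by-2^POW helpers above.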
+template +inline int32x2_t ComputeLerp32x2(const int32x2_t top_left, + const int32x2_t top_right, + const int32x2_t bottom_left, + const int32x2_t bottom_right, + const int32x2_t x_lerp, + const int32x2_t y_lerp) { + static_assert(RESOLUTION < 31, "RESOLUTION must be less than 31"); + constexpr int32 RESOLUTION_MULT32 = (1 << RESOLUTION); + static const int32x2_t RESOLUTION_MULT32x2 = vmov_n_s32(RESOLUTION_MULT32); + + const int64x2_t top_left_x_res = vmull_s32(top_left, RESOLUTION_MULT32x2); + const int64x2_t bottom_left_x_res = + vmull_s32(bottom_left, RESOLUTION_MULT32x2); + + const int32x2_t top_right_sub_top_left = vsub_s32(top_right, top_left); + const int64x2_t top_x_res = + vmlal_s32(top_left_x_res, top_right_sub_top_left, x_lerp); + const int32x2_t bottom_right_sub_bottom_left = + vsub_s32(bottom_right, bottom_left); + const int64x2_t bottom_x_res = + vmlal_s32(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp); + + const int64x2_t bottom_sub_top_x_res = vsubq_s64(bottom_x_res, top_x_res); + const int64x2_t bottom_sub_top = + Divide64x2Pow(bottom_sub_top_x_res); + const int32x2_t bottom_sub_top_32 = vqmovn_s64(bottom_sub_top); + const int64x2_t top_add_bottom_sub_top_mul_ylerp_x_res = + vmlal_s32(top_x_res, bottom_sub_top_32, y_lerp); + const int64x2_t retval = + Divide64x2PowRound(top_add_bottom_sub_top_mul_ylerp_x_res); + const int32x2_t retval32 = vqmovn_s64(retval); + return retval32; +} + +// 8bit x 8 NEON accelerated lerp computation. +template +inline uint8x8_t ComputeLerp8x8(const uint8x8_t top_left8x8, + const uint8x8_t top_right8x8, + const uint8x8_t bottom_left8x8, + const uint8x8_t bottom_right8x8, + const int16x8_t x_lerp, + const int16x8_t y_lerp) { + static_assert(RESOLUTION < 8, "RESOLUTION must be less than 8"); + constexpr uint8 RESOLUTION_MULT_VAL = (1 << RESOLUTION); + static const uint8x8_t RESOLUTION_MULT = vdup_n_u8(RESOLUTION_MULT_VAL); + + const int16x8_t top_left_x_res = + vreinterpretq_s16_u16(vmull_u8(top_left8x8, RESOLUTION_MULT)); + const int16x8_t bottom_left_x_res = + vreinterpretq_s16_u16(vmull_u8(bottom_left8x8, RESOLUTION_MULT)); + + const int16x8_t top_right_sub_top_left = + vreinterpretq_s16_u16(vsubl_u8(top_right8x8, top_left8x8)); + const int16x8_t top_x_res = + vmlaq_s16(top_left_x_res, top_right_sub_top_left, x_lerp); + + const int16x8_t bottom_right_sub_bottom_left = + vreinterpretq_s16_u16(vsubl_u8(bottom_right8x8, bottom_left8x8)); + const int16x8_t bottom_x_res = + vmlaq_s16(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp); + + const int16x8_t bottom_sub_top_x_res = vsubq_s16(bottom_x_res, top_x_res); + const int16x8_t bottom_sub_top = + Divide16x8Pow(bottom_sub_top_x_res); + const int16x8_t top_add_bottom_sub_top_mul_ylerp_x_res = + vmlaq_s16(top_x_res, bottom_sub_top, y_lerp); + const int16x8_t retval16 = + Divide16x8PowRound(top_add_bottom_sub_top_mul_ylerp_x_res); + const uint8x8_t retval = vmovn_u16(vreinterpretq_u16_s16(retval16)); + return retval; +} + +// Requantize 8 x 8 quints to 8 x 32 qints in parallel by neon +// Return std::array instead of pointer to leverage return value optimization +inline std::array Requantize8x8To32Neon( + const uint8* input_ptr, const int64x2_t input_0_64x2, + const int32x2_t input_mult_32x2) { + const uint8x8_t input_value_8x8 = vld1_u8(input_ptr); + const int16x8_t input_value_16x8 = + vreinterpretq_s16_u16(vmovl_u8(input_value_8x8)); + const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8); + const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8); + 
const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4); + const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4); + const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4); + const int32x2_t input_value_low_high_32x2 = + vget_high_s32(input_value_low_32x4); + const int32x2_t input_value_high_low_32x2 = + vget_low_s32(input_value_high_32x4); + const int32x2_t input_value_high_high_32x2 = + vget_high_s32(input_value_high_32x4); + const int64x2_t mult_result_low_low_64x2 = + vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2); + const int64x2_t mult_result_low_high_64x2 = + vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2); + const int64x2_t mult_result_high_low_64x2 = + vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2); + const int64x2_t mult_result_high_high_64x2 = + vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2); + const int32x2_t output_value_low_low_32x2 = + vqmovn_s64(mult_result_low_low_64x2); + const int32x2_t output_value_low_high_32x2 = + vqmovn_s64(mult_result_low_high_64x2); + const int32x2_t output_value_high_low_32x2 = + vqmovn_s64(mult_result_high_low_64x2); + const int32x2_t output_value_high_high_32x2 = + vqmovn_s64(mult_result_high_high_64x2); + const int32x4_t output_value_low_32x4 = + vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2); + const int32x4_t output_value_high_32x4 = + vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2); + return std::array{ + {output_value_low_32x4, output_value_high_32x4}}; +} + +// Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD +// intrinsics for ARM platforms. +template <> +inline void RequantizeManyInNewRange( + const quint8* input, int64 count, float min_input, float max_input, + float min_output, float max_output, qint32* output) { + // Pre-calculate zero position and multiplier. + // Calculate 0 and 1 value in float. + const float code_0_float = QuantizedToFloat(0, min_input, max_input); + const float code_1_float = QuantizedToFloat(1, min_input, max_input); + + // Cast 0 and 1 value in int64. + const int64 code_0_int64 = + FloatToQuantizedUnclamped(code_0_float, min_output, max_output); + const int64 code_1_int64 = + FloatToQuantizedUnclamped(code_1_float, min_output, max_output); + + // Calculate multiplier. + const int32 mult_int32 = static_cast(code_1_int64 - code_0_int64); + + // Broadcast 0 position and multiplier to lanes + const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64); + const int32x2_t mult_32x2 = vmov_n_s32(mult_int32); + + int64 i = 0; + + // Use SIMD to requantize array. + for (; i < (count - 7); i += 8) { + const uint8* input_ptr = &(input->value) + i; + int32* output_ptr = &(output->value) + i; + const std::array output_value = + Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2); + vst1q_s32(output_ptr + 0, output_value[0]); + vst1q_s32(output_ptr + 4, output_value[1]); + } + + // Requantize remaining elements in array without SIMD. 
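+  // At most seven elements can be left over here, because the SIMD loop above
+  // consumes the input in groups of eight.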
+ const int64 lowest_quantized = + static_cast(Eigen::NumTraits::lowest()); + const int64 highest_quantized = + static_cast(Eigen::NumTraits::highest()); + + for (; i < count; ++i) { + const int64 input_value = static_cast(input[i]); + int64 output_value = code_0_int64 + (input_value * mult_int32); + output_value = std::max(output_value, lowest_quantized); + output_value = std::min(output_value, highest_quantized); + output[i] = static_cast(output_value); + } +} + +#else + +// If SIMD implementations aren't available, then use these default reference +// versions. +template <> +inline void RequantizeManyInNewRange( + const qint32* input, int64_t count, float min_input, float max_input, + float min_output, float max_output, quint8* output) { + RequantizeManyInNewRangeReference(input, count, min_input, max_input, + min_output, max_output, output); +} + +template <> +inline void RequantizeManyInNewRange( + const quint8* input, int64_t count, float min_input, float max_input, + float min_output, float max_output, qint32* output) { + RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input, + min_output, max_output, output); +} + +#endif + +template +struct int64_right_shift_op { + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const int64_t operator()(const int64_t a) const { + return a >> shift; + } +}; + +// See RequantizeManyInNewRange() for a non-eigen reference implementation. +template +inline void RequantizeManyInNewRangeUsingEigen( + const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input, + float max_input, float min_output, float max_output, Tensor* output) { + auto input_array = input.flat(); + QuantizedToFloatStruct q2f(min_input, max_input); + auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f); + FloatToQuantizedStruct f2q(min_output, max_output); + auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2); + + output->flat().device(device) = input_requantized; +} + +// See RequantizeManyInNewRange() for a non-eigen reference implementation. +// +// Because converting 32-bit accumulated results down to eight bit is a common +// case, we have a specialized code path to handle it as efficiently as +// possible using only fixed-point math for the inner loop. +template <> +inline void RequantizeManyInNewRangeUsingEigen( + const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input, + float max_input, float min_output, float max_output, Tensor* output) { + // Initially we calculate all the constants we need once, before we go into + // the inner loop. If this is updated, also update the non-Eigen version. + const int fp_shift = 16; + const float input_range = max_input - min_input; + const float output_range = max_output - min_output; + const float recip_output_range = + output_range == 0.0 ? 0.0 : (255.0 / output_range); + const float input_rezero = (min_input + max_input) / 2.0; + const int64_t range_scale_fp = + output_range == 0.0 ? 0.0 + : static_cast(255.0 * (1 << fp_shift) * + input_range / output_range); + const int64_t input_offset_fp = + static_cast(input_rezero * recip_output_range * (1 << fp_shift)); + const int64_t output_offset_fp = + output_range == 0.0 + ? 0 + : std::lround((1 << fp_shift) * (min_output * 255.0) / output_range); + const int64_t rounding_delta = 1 << (fp_shift - 1); + + // Inside this eigen expression we just do minimal adds, multiplies, and + // shifts. It should be possible to perform all the calculations in 32-bit + // rather than 64, but that's not been implemented yet. 
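+  // This is the same computation as the scalar loop in
+  // RequantizeManyInNewRangeReference(), expressed as one Eigen pipeline:
+  //   q = clamp((((x * range_scale_fp) >> 32) + input_offset_fp
+  //              - output_offset_fp + rounding_delta) >> fp_shift, 0, 255)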
+  auto input_array = input.flat<qint32>();
+  auto fp_value = ((input_array.template cast<int64_t>() * range_scale_fp)
+                       .unaryExpr(int64_right_shift_op<32>())) +
+                  (input_offset_fp - output_offset_fp + rounding_delta);
+  auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>());
+  auto input_requantized = intermediate.cwiseMax(int64_t{0})
+                               .cwiseMin(int64_t{255})
+                               .template cast<int32>()
+                               .template cast<quint8>();
+  output->flat<quint8>().device(device) = input_requantized;
+}
+
+// REQUIRES: 'result->NumElements() == input.NumElements()'
+template <class T>
+void FloatTensorToQuantizedInPlaceUsingEigen(
+    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
+    float max, Tensor* result) {
+  DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
+  auto flat_input = input.flat<float>();
+  auto flat_result = result->flat<T>();
+  DCHECK_EQ(flat_input.size(), flat_result.size());
+
+  FloatToQuantizedStruct<T> f2q(min, max);
+  flat_result.device(device) = QUANTIZE_WITH_EIGEN(flat_input, f2q, T);
+}
+
+template <class T>
+void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max,
+                                   Tensor* result) {
+  DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
+  auto flat_input = input.flat<float>();
+  auto flat_result = result->flat<T>();
+  const int data_size = flat_input.size();
+  DCHECK(data_size == flat_result.size());
+  for (int i = 0; i < data_size; ++i) {
+    flat_result(i) = FloatToQuantized<T>(flat_input(i), min, max);
+  }
+}
+
+template <class T>
+Tensor FloatTensorToQuantized(const Tensor& input, float min, float max) {
+  Tensor result(DataTypeToEnum<T>::v(), input.shape());
+  FloatTensorToQuantizedInPlace<T>(input, min, max, &result);
+  return result;
+}
+
+// REQUIRES: 'result->NumElements() == input.NumElements()'
+template <class T>
+void QuantizedTensorToFloatInPlaceUsingEigen(
+    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
+    float max, Tensor* result) {
+  DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
+  auto flat_input = input.flat<T>();
+  auto flat_result = result->flat<float>();
+  const int data_size = flat_input.size();
+  DCHECK(data_size == flat_result.size());
+
+  QuantizedToFloatStruct<T> q2f(min, max);
+  flat_result.device(device) = DEQUANTIZE_WITH_EIGEN(flat_input, q2f);
+}
+
+// REQUIRES: 'result->NumElements() == input.NumElements()'
+template <class T>
+void QuantizedTensorToFloatInPlace(const Tensor& input, float min, float max,
+                                   Tensor* result) {
+  DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
+  auto flat_input = input.flat<T>();
+  auto flat_result = result->flat<float>();
+  const int data_size = flat_input.size();
+  DCHECK(data_size == flat_result.size());
+  for (int i = 0; i < data_size; ++i) {
+    flat_result(i) = QuantizedToFloat<T>(flat_input(i), min, max);
+  }
+}
+
+template <class T>
+Tensor QuantizedTensorToFloat(const Tensor& input, float min, float max) {
+  Tensor result(DT_FLOAT, input.shape());
+  QuantizedTensorToFloatInPlace<T>(input, min, max, &result);
+  return result;
+}
+
+void GetOutputMinAndMaxForQuantizedAdd(float input_min, float input_max,
+                                       float smaller_input_min,
+                                       float smaller_input_max,
+                                       float* output_min, float* output_max);
+
+// Add <input> and <smaller_input>.  If <smaller_input> has fewer elements than
+// <input>, then it is broadcast onto <input>.
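+// For example (shapes assumed only for illustration): adding a [C]-element
+// bias to an [N, H, W, C] activation tensor repeats the bias N * H * W times
+// across the flattened input, which is exactly the
+// input.NumElements() / smaller_input.NumElements() replication used below.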
+template +void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device, + const Tensor& input, float input_min, + float input_max, const Tensor& smaller_input, + float smaller_input_min, float smaller_input_max, + Tensor* output, float* output_min, + float* output_max) { + const auto& input_flat = input.flat(); + const auto& smaller_input_flat = smaller_input.flat(); + auto output_flat = output->flat(); + + GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min, + smaller_input_max, output_min, output_max); + // To do addition properly, we need to compensate for a possibly unbalanced + // zero point in the total representation. The quantized value that + // represents the real number zero needs to be subtracted before addition to + // make sure that the identity of zero + zero = zero holds. + const T3 zero_in_total_space = + FloatToQuantized(0.0f, *output_min, *output_max); + + const int64_t input_element_count = input.NumElements(); + const int64_t smaller_input_element_count = smaller_input.NumElements(); + + QuantizedToFloatStruct input_q2f(input_min, input_max); + QuantizedToFloatStruct smaller_input_q2f(smaller_input_min, + smaller_input_max); + FloatToQuantizedStruct f2q(*output_min, *output_max); + + auto smaller_input_float = + DEQUANTIZE_WITH_EIGEN(smaller_input_flat, smaller_input_q2f); + auto smaller_input_in_total_space = + QUANTIZE_WITH_EIGEN(smaller_input_float, f2q, T3); + + auto input_float = DEQUANTIZE_WITH_EIGEN(input_flat, input_q2f); + auto input_in_total_space = QUANTIZE_WITH_EIGEN(input_float, f2q, T3); + + Eigen::array bcast; + bcast[0] = input_element_count / smaller_input_element_count; + output_flat.device(device) = + input_in_total_space + + (smaller_input_in_total_space.broadcast(bcast) + zero_in_total_space); +} + +// This is a reference implementation of the bias addition for quantized +// buffers, designed to provide a clear specification for the result we +// want. We'll want to specialize this for particular hardware, and +// probably even fuse it with matrix multiplications in a lot of cases. It's +// important to show the clamping behavior we want in particular. +template +void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input, + float input_min, float input_max, const Tensor& smaller_input, + float smaller_input_min, float smaller_input_max, + Tensor* output, float* output_min, float* output_max) { + const auto& input_flat = input.flat(); + const auto& smaller_input_flat = smaller_input.flat(); + auto output_flat = output->flat(); + + GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min, + smaller_input_max, output_min, output_max); + // To do addition properly, we need to compensate for a possibly unbalanced + // zero point in the total representation. The quantized value that + // represents the real number zero needs to be subtracted before addition to + // make sure that the identity of zero + zero = zero holds. 
+ const T3 zero_in_total_space = + FloatToQuantized(0.0f, *output_min, *output_max); + + const int64_t input_element_count = input.NumElements(); + const int64_t smaller_input_element_count = smaller_input.NumElements(); + + float total_min = *output_min; + float total_max = *output_max; + const size_t how_many_iterations = + (input_element_count / smaller_input_element_count); + for (size_t iteration = 0; iteration < how_many_iterations; ++iteration) { + const size_t offset = iteration * smaller_input_element_count; + for (int c = 0; c < smaller_input_element_count; ++c) { + const int index = (offset + c); + // The two numbers we're going to add can each be in very different + // ranges (e.g. the quantized value '127' may represent very different + // real numbers in both) so we need to convert them to a common range + // before we sum them. + const T1 input_value = input_flat(index); + const T3 input_in_total_space = RequantizeInNewRange( + input_value, input_min, input_max, total_min, total_max); + const T2 smaller_input_value = smaller_input_flat(c); + const T3 smaller_input_in_total_space = + RequantizeInNewRange(smaller_input_value, smaller_input_min, + smaller_input_max, total_min, total_max); + const T3 total_pre = input_in_total_space + smaller_input_in_total_space; + // As noted above, we need to compensate for the offset of the actual + // zero point in the space we're operating in. + const T3 total = total_pre + zero_in_total_space; + output_flat(index) = total; + } + } +} + +// See gemmlowp/internal/multi_thread_gemm.h for the semantics of Execute. +class TensorflowGemmlowpWorkersPool { + public: + TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers) + : workers_(workers) {} + + ~TensorflowGemmlowpWorkersPool() { + // This workaround ensures that all worker tasks have exited methods in the + // BlockingCounter. Without this, there is a race where the context is torn + // down while the counter is in use. + counter_to_decrement_when_ready_.Reset(0); + } + + void Execute(const std::vector& tasks) { + assert(!tasks.empty()); + assert(workers_ != nullptr); + counter_to_decrement_when_ready_.Reset(tasks.size()); + for (gemmlowp::Task* task : tasks) { + workers_->Schedule([this, task]() { + // TODO(cwhipkey): get a local_allocator from a thread local storage. + gemmlowp::Allocator local_allocator; + CHECK(task != nullptr); + task->local_allocator = &local_allocator; + task->Run(); + counter_to_decrement_when_ready_.DecrementCount(); + }); + } + counter_to_decrement_when_ready_.Wait(); + for (gemmlowp::Task* task : tasks) { + delete task; + } + } + + private: + thread::ThreadPool* const workers_; + + // The BlockingCounter used to wait for the workers. 
+ gemmlowp::BlockingCounter counter_to_decrement_when_ready_; + + TensorflowGemmlowpWorkersPool(const TensorflowGemmlowpWorkersPool&) = delete; + void operator=(const TensorflowGemmlowpWorkersPool&) = delete; +}; + +class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase { + public: + TensorflowGemmContext(int num_threads, thread::ThreadPool* workers) + : workers_pool_(workers) { + set_max_num_threads(num_threads); + } + + TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; } + + private: + TensorflowGemmlowpWorkersPool workers_pool_; + + TensorflowGemmContext(const TensorflowGemmContext&) = delete; + void operator=(const TensorflowGemmContext&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/quantize_and_dequantize_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/quantize_and_dequantize_op.h new file mode 100644 index 00000000..253d667a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/quantize_and_dequantize_op.h @@ -0,0 +1,322 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/cwise_ops.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +enum QuantizerRoundMode { + // Round half up: if the fraction of y is exactly 0.5, then + // round(y) = y + 0.5 + // E.g., -5.5 gets rounded to -5, -5.4 goes to -5, + // 5.4 goes to 5, and 5.5 goes to 6. + ROUND_HALF_UP, + // Round half to even: if the fraction of y is exactly 0.5, then round(y) is + // the nearest even integer to y. + // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes + // -24, and -24.5 gets rounded to 24. + ROUND_HALF_TO_EVEN, +}; + +namespace functor { + +// TODO(pauldonnelly): 'signed_input' should really be called 'signed_output'. 
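+// The functors below perform "fake" quantization: the input is clamped to the
+// chosen range, scaled onto the simulated integer grid, rounded, and then
+// scaled straight back, i.e. roughly
+//   output = round(clamp(x, min_range, max_range) * scale) * inverse_scale,
+// so values stay in the floating-point type T throughout (see
+// ClampScaleAndRound() and ScaleAndRound() further down).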
+ +template +struct QuantizeAndDequantizeOneScaleFunctor { + void operator()(const Device& d, typename TTypes::ConstVec input, + bool signed_input, int num_bits, bool range_given, + Tensor* input_min_tensor, Tensor* input_max_tensor, + QuantizerRoundMode round_mode, bool narrow_range, + typename TTypes::Vec output); +}; + +template +struct QuantizeAndDequantizePerChannelFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + bool signed_input, int num_bits, bool range_given, + Tensor* input_min_tensor, Tensor* input_max_tensor, + QuantizerRoundMode round_mode, bool narrow_range, + typename TTypes::Tensor output); +}; + +template +struct QuantizeAndDequantizeOneScaleGradientFunctor { + void operator()(const Device& d, typename TTypes::ConstFlat gradient, + typename TTypes::ConstFlat input, + typename TTypes::ConstScalar input_min, + typename TTypes::ConstScalar input_max, + typename TTypes::Flat input_backprop, + typename TTypes::Scalar input_min_backprop, + typename TTypes::Scalar input_max_backprop); +}; + +template +struct QuantizeAndDequantizePerChannelGradientFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor gradient, + typename TTypes::ConstTensor input, + const Tensor* input_min_tensor, + const Tensor* input_max_tensor, + typename TTypes::Tensor input_backprop, + typename TTypes::Flat input_min_backprop, + typename TTypes::Flat input_max_backprop); +}; + +// The implementation below runs on both CPU and GPU. +template ::Vec, + typename ConstVec = typename TTypes::ConstVec> +void ClampScaleAndRound(const Device& d, ConstVec input, T min_range, + T max_range, T scale, T inverse_scale, Func round_func, + Vec output) { + output.device(d) = (input.cwiseMin(max_range).cwiseMax(min_range) * scale) + .unaryExpr(round_func) * + inverse_scale; +} + +// The implementation below runs on both CPU and GPU. +template ::Vec, + typename ConstVec = typename TTypes::ConstVec> +void ClampScaleAndRound(const Device& d, ConstVec input, T min_range, + T max_range, T scale, T inverse_scale, + QuantizerRoundMode round_mode, Vec output) { + switch (round_mode) { + case ROUND_HALF_TO_EVEN: + ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale, + Eigen::internal::scalar_round_half_to_even_op(), + output); + break; + case ROUND_HALF_UP: + ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale, + Eigen::internal::scalar_round_up_op(), output); + break; + } +} + +// The implementation below runs on both CPU and GPU. +template ::Vec, + typename ConstVec = typename TTypes::ConstVec> +void ScaleAndRound(const Device& d, ConstVec input, T scale, T inverse_scale, + Func round_func, Vec output) { + output.device(d) = (input * scale).unaryExpr(round_func) * inverse_scale; +} + +// The implementation below runs on both CPU and GPU. 
+template ::Vec, + typename ConstVec = typename TTypes::ConstVec> +void ScaleAndRound(const Device& d, ConstVec input, T scale, T inverse_scale, + QuantizerRoundMode round_mode, Vec output) { + switch (round_mode) { + case ROUND_HALF_TO_EVEN: + ScaleAndRound(d, input, scale, inverse_scale, + Eigen::internal::scalar_round_half_to_even_op(), output); + break; + case ROUND_HALF_UP: + ScaleAndRound(d, input, scale, inverse_scale, + Eigen::internal::scalar_round_up_op(), output); + break; + } +} + +template +void ComputeQuantizationRange(bool signed_input, int num_bits, + QuantizerRoundMode round_mode, bool narrow_range, + T* min_range, T* max_range, T* scale, + T* inverse_scale) { + // Calculate the range for the simulated integer quantization: + // e.g. [-127,127] for signed = true, narrow_range = true, num_bits = 8, + // or [-128,127] for signed = true, narrow_range = false, num_bits = 8, + // or [0, 255] for signed = false, num_bits = 8. + const int64_t min_quantized = + signed_input ? narrow_range ? -(1ULL << (num_bits - 1)) + 1 + : -(1ULL << (num_bits - 1)) + : 0; + const int64_t max_quantized = + signed_input ? (1ULL << (num_bits - 1)) - 1 : (1ULL << num_bits) - 1; + // Determine the maximum scaling factor that would scale + // [min_range, max_range] to not exceed [min_quantized, max_quantized], + // while keeping 0 unchanged. + const T scale_from_min_side = (min_quantized * *min_range > 0) + ? min_quantized / *min_range + : std::numeric_limits::max(); + const T scale_from_max_side = (max_quantized * *max_range > 0) + ? max_quantized / *max_range + : std::numeric_limits::max(); + + // Note: Avoids changing the side of the range that determines scale. + if (scale_from_min_side < scale_from_max_side) { + *scale = scale_from_min_side; + *inverse_scale = *min_range / min_quantized; + *max_range = max_quantized * *inverse_scale; + } else { + *scale = scale_from_max_side; + *inverse_scale = *max_range / max_quantized; + *min_range = min_quantized * *inverse_scale; + } +} + +// The implementation below runs on both CPU and GPU. +template +struct QuantizeAndDequantizeOneScaleImpl { + static void Compute(const Device& d, typename TTypes::ConstVec input, + bool signed_input, int num_bits, bool range_given, + Tensor* input_min_tensor, Tensor* input_max_tensor, + QuantizerRoundMode round_mode, bool narrow_range, + typename TTypes::Vec output) { + T min_range; + T max_range; + auto input_min = input_min_tensor->scalar(); + auto input_max = input_max_tensor->scalar(); + if (!range_given) { + input_min.device(d) = input.minimum(); + input_max.device(d) = input.maximum(); + d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T)); + d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T)); + } else { + // Copy the range values from their respective tensors on the host. + min_range = input_min_tensor->scalar()(); + max_range = input_max_tensor->scalar()(); + } + + T scale, inverse_scale; + ComputeQuantizationRange(signed_input, num_bits, round_mode, narrow_range, + &min_range, &max_range, &scale, &inverse_scale); + + if (range_given) { + // Note: The clamping here is to avoid overflow in the quantized type. + // The semantics of the op does not guarantee to clamp to the specified + // min_range and max_range - because we may have changed either min_range + // or max_range. 
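+      // For instance (example values only): with signed_input = true,
+      // num_bits = 8, narrow_range = false and a requested range of
+      // [-3.0, 6.0], ComputeQuantizationRange() keeps max_range = 6.0,
+      // picks scale = 127 / 6.0, and widens min_range to
+      // -128 * (6.0 / 127) ~= -6.05, so the clamp below uses the adjusted
+      // range rather than the requested one.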
+ ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale, + round_mode, output); + } else { + ScaleAndRound(d, input, scale, inverse_scale, round_mode, output); + } + } +}; + +// The implementation below runs on both CPU and GPU. + +template +struct QuantizeAndDequantizePerChannelImpl { + static void Compute(const Device& d, typename TTypes::ConstTensor input, + bool signed_input, int num_bits, bool range_given, + Tensor* input_min_tensor, Tensor* input_max_tensor, + QuantizerRoundMode round_mode, bool narrow_range, + typename TTypes::Tensor output) { + using Index = typename tensorflow::TTypes::ConstTensor::Index; + int num_channels = input.dimension(1); + auto input_min = input_min_tensor->vec(); + auto input_max = input_max_tensor->vec(); + std::vector min_range(num_channels); + std::vector max_range(num_channels); + + if (!range_given) { + Eigen::IndexList, Eigen::type2index<2> > reduce_dims; + input_min.device(d) = input.minimum(reduce_dims); + input_max.device(d) = input.maximum(reduce_dims); + d.memcpyDeviceToHost(min_range.data(), input_min.data(), + num_channels * sizeof(T)); + d.memcpyDeviceToHost(max_range.data(), input_max.data(), + num_channels * sizeof(T)); + } else { + // Copy the range values from their respective tensors on the host. + std::memcpy(min_range.data(), input_min_tensor->vec().data(), + num_channels * sizeof(T)); + std::memcpy(max_range.data(), input_max_tensor->vec().data(), + num_channels * sizeof(T)); + } + + for (Index i = 0; i < num_channels; ++i) { + const auto input_chip = input.template chip<1>(i); + auto output_chip = output.template chip<1>(i); + + T scale, inverse_scale; + ComputeQuantizationRange(signed_input, num_bits, round_mode, narrow_range, + &min_range[i], &max_range[i], &scale, + &inverse_scale); + if (range_given) { + ClampScaleAndRound(d, input_chip, min_range[i], max_range[i], scale, + inverse_scale, round_mode, output_chip); + } else { + ScaleAndRound(d, input_chip, scale, inverse_scale, round_mode, + output_chip); + } + } + } +}; + +template +struct QuantizeAndDequantizeOneScaleGradientImpl { + static void Compute(const Device& d, typename TTypes::ConstFlat gradient, + typename TTypes::ConstFlat input, + typename TTypes::ConstScalar input_min, + typename TTypes::ConstScalar input_max, + typename TTypes::Flat input_backprop, + typename TTypes::Scalar input_min_backprop, + typename TTypes::Scalar input_max_backprop) { + const T min_val = input_min(); + const T max_val = input_max(); + const auto in_range = + (input >= min_val && input <= max_val) + .select(input.constant(1.0f), input.constant(0.0f)); + input_backprop.device(d) = gradient * in_range; + input_min_backprop.device(d) = input_min_backprop.constant(0.0f); + input_max_backprop.device(d) = input_max_backprop.constant(0.0f); + } +}; + +template +struct QuantizeAndDequantizePerChannelGradientImpl { + static void Compute(const Device& d, + typename TTypes::ConstTensor gradient, + typename TTypes::ConstTensor input, + const Tensor* input_min_tensor, + const Tensor* input_max_tensor, + typename TTypes::Tensor input_backprop, + typename TTypes::Flat input_min_backprop, + typename TTypes::Flat input_max_backprop) { + using Index = typename tensorflow::TTypes::ConstTensor::Index; + auto input_min = input_min_tensor->vec(); + auto input_max = input_max_tensor->vec(); + int num_channels = input.dimension(1); + for (Index i = 0; i < num_channels; ++i) { + const auto gradient_chip = gradient.template chip<1>(i); + const auto input_chip = input.template chip<1>(i); + const T 
min_val = input_min(i); + const T max_val = input_max(i); + const auto in_range = + (input_chip >= min_val && input_chip <= max_val) + .select(input_chip.constant(1.0f), input_chip.constant(0.0f)); + input_backprop.template chip<1>(i).device(d) = gradient_chip * in_range; + } + input_min_backprop.device(d) = input_min_backprop.constant(0.0f); + input_max_backprop.device(d) = input_max_backprop.constant(0.0f); + } +}; + +} // end of namespace functor +} // end of namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_QUANTIZE_AND_DEQUANTIZE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/queue_base.h b/third_party/tflite-hdrs/tensorflow/core/kernels/queue_base.h new file mode 100644 index 00000000..d39ab454 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/queue_base.h @@ -0,0 +1,188 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ +#define TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ + +#include +#include + +#include "absl/base/macros.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Functionality common to asynchronous QueueInterface implementations. +class QueueBase : public QueueInterface { + public: + // As a possible value of 'capacity'. + static constexpr int32_t kUnbounded = INT_MAX; + + // Args: + // component_dtypes: The types of each component in a queue-element tuple. + // component_shapes: The shapes of each component in a queue-element tuple, + // which must either be empty (if the shapes are not specified) or + // or have the same size as component_dtypes. + // name: A name to use for the queue. 
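+  //   capacity: The maximum number of elements the queue may hold at once;
+  //     kUnbounded (declared above) is intended to mean "no practical limit".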
+ QueueBase(int32_t capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + + // Implementations of QueueInterface methods -------------------------------- + const DataTypeVector& component_dtypes() const override { + return component_dtypes_; + } + + absl::Status ValidateTuple(const Tuple& tuple) override; + absl::Status ValidateManyTuple(const Tuple& tuple) override; + + void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) override; + + // Other public methods ----------------------------------------------------- + const std::vector& component_shapes() const { + return component_shapes_; + } + + int32 capacity() const { return capacity_; } + + bool is_closed() const override { + mutex_lock lock(mu_); + return closed_; + } + + // Copies the index^th slice (in the first dimension) of parent into element. + static absl::Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64_t index); + + // Copies element into the index^th slice (in the first dimension) of parent. + // NOTE(mrry): This method is deprecated. Use + // `tensorflow::batch_util::CopySliceToElement()` defined in + // "./batch_util.h" instead. + ABSL_DEPRECATED( + "Use `tensorflow::batch_util::CopySliceToElement()` defined in " + "\"./batch_util.h\" instead.") + static absl::Status CopyElementToSlice(const Tensor& element, Tensor* parent, + int64_t index); + + protected: + enum Action { kEnqueue, kDequeue }; + enum RunResult { kNoProgress, kProgress, kComplete }; + + // Tries to enqueue/dequeue (or close) based on whatever is at the + // front of enqueue_attempts_/dequeue_attempts_. Appends to + // *finished the callback for any finished attempt (so it may be + // called once mu_ is released). Returns true if any progress was + // made. + struct CleanUp { + CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) + : finished(f), to_deregister(ct), cm(cm) {} + DoneCallback finished; + CancellationToken to_deregister; + CancellationManager* cm; + }; + + // Returns the number of components in a queue-element tuple. + int32 num_components() const { return component_dtypes_.size(); } + + // True if shapes were specified. If so, inputs will be validated + // against them, etc. + bool specified_shapes() const { return component_shapes_.size() > 0; } + + // Code common to Validate*Tuple(). + absl::Status ValidateTupleCommon(const Tuple& tuple) const; + + TensorShape ManyOutShape(int i, int64_t batch_size) { + TensorShape shape({batch_size}); + shape.AppendShape(component_shapes_[i]); + return shape; + } + + void Cancel(Action action, CancellationManager* cancellation_manager, + CancellationToken token); + + // Helper for cancelling all pending Enqueue(Many) operations when + // Close is called with cancel_pending_enqueues. + void CloseAndCancel(); + + bool TryAttemptLocked(Action action, std::vector* clean_up) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Tries to make progress on the enqueues or dequeues at the front + // of the *_attempts_ queues. + void FlushUnlocked(); + + ~QueueBase() override; + + // Helpers for implementing MatchesNodeDef(). 
+ static string ShapeListString(const absl::Span& shapes); + absl::Status MatchesNodeDefOp(const NodeDef& node_def, + const string& op) const; + absl::Status MatchesNodeDefCapacity(const NodeDef& node_def, + int32_t capacity) const; + absl::Status MatchesNodeDefTypes(const NodeDef& node_def) const; + absl::Status MatchesNodeDefShapes(const NodeDef& node_def) const; + + protected: + const int32 capacity_; + const DataTypeVector component_dtypes_; + const std::vector component_shapes_; + const string name_; + mutable mutex mu_; + bool closed_ TF_GUARDED_BY(mu_); + + struct Attempt; + typedef std::function RunCallback; + struct Attempt { + int32 elements_requested; + DoneCallback done_callback; // must be run outside mu_ + OpKernelContext* context; + CancellationManager* cancellation_manager; // not owned + CancellationToken cancellation_token; + RunCallback run_callback; // must be run while holding mu_ + bool is_cancelled; + Tuple tuple; + // tuples is used by some implementations allowing dynamic shapes. + std::vector tuples; + + Attempt(int32_t elements_requested, DoneCallback done_callback, + OpKernelContext* context, CancellationManager* cancellation_manager, + CancellationToken cancellation_token, RunCallback run_callback) + : elements_requested(elements_requested), + done_callback(done_callback), + context(context), + cancellation_manager(cancellation_manager), + cancellation_token(cancellation_token), + run_callback(run_callback), + is_cancelled(false) {} + }; + std::deque enqueue_attempts_ TF_GUARDED_BY(mu_); + std::deque dequeue_attempts_ TF_GUARDED_BY(mu_); + + QueueBase(const QueueBase&) = delete; + void operator=(const QueueBase&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/queue_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/queue_op.h new file mode 100644 index 00000000..57a771d9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/queue_op.h @@ -0,0 +1,279 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/queue_interface.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Defines a QueueOp, an abstract class for Queue construction ops. 
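+//
+// A concrete construction kernel is expected to subclass TypedQueueOp (below)
+// and build its queue inside CreateResource(), along these lines (sketch
+// only; MyQueue stands in for a real TypedQueue implementation):
+//
+//   class MyQueueOp : public TypedQueueOp {
+//    public:
+//     explicit MyQueueOp(OpKernelConstruction* ctx) : TypedQueueOp(ctx) {}
+//
+//    private:
+//     absl::Status CreateResource(QueueInterface** ret) override {
+//       MyQueue* queue = new MyQueue(capacity_, component_types_);
+//       return CreateTypedQueue(queue, ret);
+//     }
+//   };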
+class QueueOp : public ResourceOpKernel { + public: + QueueOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* context) override; + + protected: + // Variables accessible by subclasses + int32 capacity_; + DataTypeVector component_types_; + + private: + absl::Status VerifyResource(QueueInterface* queue) override; +}; + +class TypedQueueOp : public QueueOp { + public: + using QueueOp::QueueOp; + + protected: + template + absl::Status CreateTypedQueue(TypedQueue* queue, QueueInterface** ret) { + if (queue == nullptr) { + return errors::ResourceExhausted("Failed to allocate queue."); + } + *ret = queue; + return queue->Initialize(); + } +}; + +// Queue manipulator kernels + +class QueueOpKernel : public AsyncOpKernel { + public: + explicit QueueOpKernel(OpKernelConstruction* context); + + void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final; + + protected: + virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) = 0; +}; + +class QueueAccessOpKernel : public QueueOpKernel { + public: + explicit QueueAccessOpKernel(OpKernelConstruction* context); + + protected: + int64_t timeout_; +}; + +// Defines an EnqueueOp, the execution of which enqueues a tuple of +// tensors in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +class EnqueueOp : public QueueAccessOpKernel { + public: + explicit EnqueueOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + EnqueueOp(const EnqueueOp&) = delete; + void operator=(const EnqueueOp&) = delete; +}; + +// Defines an EnqueueManyOp, the execution of which slices each +// component of a tuple of tensors along the 0th dimension, and +// enqueues tuples of slices in the given Queue. +// +// The op has 1 + k inputs, where k is the number of components in the +// tuples stored in the given Queue: +// - Input 0: queue handle. +// - Input 1: 0th element of the tuple. +// - ... +// - Input (1+k): kth element of the tuple. +// +// N.B. All tuple components must have the same size in the 0th +// dimension. +class EnqueueManyOp : public QueueAccessOpKernel { + public: + explicit EnqueueManyOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~EnqueueManyOp() override; + + private: + EnqueueManyOp(const EnqueueManyOp&) = delete; + void operator=(const EnqueueManyOp&) = delete; +}; + +// Defines a DequeueOp, the execution of which dequeues a tuple of +// tensors from the given Queue. +// +// The op has one input, which is the handle of the appropriate +// Queue. The op has k outputs, where k is the number of components in +// the tuples stored in the given Queue, and output i is the ith +// component of the dequeued tuple. 
+class DequeueOp : public QueueAccessOpKernel { + public: + explicit DequeueOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueOp() override; + + private: + DequeueOp(const DequeueOp&) = delete; + void operator=(const DequeueOp&) = delete; +}; + +// Defines a DequeueManyOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +class DequeueManyOp : public QueueAccessOpKernel { + public: + explicit DequeueManyOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueManyOp() override; + + private: + DequeueManyOp(const DequeueManyOp&) = delete; + void operator=(const DequeueManyOp&) = delete; +}; + +// Defines a DequeueUpToOp, the execution of which concatenates the +// requested number of elements from the given Queue along the 0th +// dimension, and emits the result as a single tuple of tensors. +// +// The difference between this op and DequeueMany is the handling when +// the Queue is closed. While the DequeueMany op will return if there +// an error when there are less than num_elements elements left in the +// closed queue, this op will return between 1 and +// min(num_elements, elements_remaining_in_queue), and will not block. +// If there are no elements left, then the standard DequeueMany error +// is returned. +// +// This op only works if the underlying Queue implementation accepts +// the allow_small_batch = true parameter to TryDequeueMany. +// If it does not, an errors::Unimplemented exception is returned. +// +// The op has two inputs: +// - Input 0: the handle to a queue. +// - Input 1: the number of elements to dequeue. +// +// The op has k outputs, where k is the number of components in the +// tuples stored in the given Queue, and output i is the ith component +// of the dequeued tuple. +// +// The op has one attribute: allow_small_batch. If the Queue supports +// it, setting this to true causes the queue to return smaller +// (possibly zero length) batches when it is closed, up to however +// many elements are available when the op executes. In this case, +// the Queue does not block when closed. +class DequeueUpToOp : public QueueAccessOpKernel { + public: + explicit DequeueUpToOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + ~DequeueUpToOp() override; + + private: + DequeueUpToOp(const DequeueUpToOp&) = delete; + void operator=(const DequeueUpToOp&) = delete; +}; + +// Defines a QueueCloseOp, which closes the given Queue. Closing a +// Queue signals that no more elements will be enqueued in it. +// +// The op has one input, which is the handle of the appropriate Queue. 
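+//
+// If the op's cancel_pending_enqueues attribute is set, closing also cancels
+// any Enqueue(Many) attempts that are still blocked on the queue (see
+// QueueBase::CloseAndCancel()).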
+class QueueCloseOp : public QueueOpKernel { + public: + explicit QueueCloseOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + bool cancel_pending_enqueues_; + QueueCloseOp(const QueueCloseOp&) = delete; + void operator=(const QueueCloseOp&) = delete; +}; + +// Defines a QueueSizeOp, which computes the number of elements in the +// given Queue, and emits it as an output tensor. +// +// The op has one input, which is the handle of the appropriate Queue; +// and one output, which is a single-element tensor containing the current +// size of that Queue. +class QueueSizeOp : public QueueOpKernel { + public: + explicit QueueSizeOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + QueueSizeOp(const QueueSizeOp&) = delete; + void operator=(const QueueSizeOp&) = delete; +}; + +class QueueIsClosedOp : public QueueOpKernel { + public: + explicit QueueIsClosedOp(OpKernelConstruction* context); + + protected: + void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue, + DoneCallback callback) override; + + private: + QueueIsClosedOp(const QueueIsClosedOp&) = delete; + void operator=(const QueueIsClosedOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.h b/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.h new file mode 100644 index 00000000..7dc63ac8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.h @@ -0,0 +1,189 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include +#include "absl/strings/match.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ragged_tensor_variant.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_TO_VARIANT_OP_TEST_H_ +#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_TO_VARIANT_OP_TEST_H_ + +namespace tensorflow { + +class RaggedTensorToVariantKernelTest : public ::tensorflow::OpsTestBase { + protected: + // Builds the tensorflow test graph for the RaggedTensorToVariant op, and + // populates the `splits` input with the given values. + template + void BuildEncodeRaggedTensorGraph( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector& ragged_values, const bool batched) { + const auto values_dtype = DataTypeToEnum::v(); + const auto splits_dtype = DataTypeToEnum::v(); + int64_t num_splits = ragged_splits.size(); + TF_ASSERT_OK( + NodeDefBuilder("tested_op", "RaggedTensorToVariant") + .Input(FakeInput(num_splits, splits_dtype)) // ragged_splits + .Input(FakeInput(values_dtype)) // ragged_values + .Attr("RAGGED_RANK", num_splits) + .Attr("Tvalues", values_dtype) + .Attr("Tsplits", splits_dtype) + .Attr("batched_input", batched) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + for (const auto& splits : ragged_splits) { + int64_t splits_size = splits.size(); + AddInputFromArray(TensorShape({splits_size}), splits); + } + AddInputFromArray(ragged_values_shape, ragged_values); + } + + template + void BuildEncodeRaggedTensorGraph( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, const VALUE_TYPE& ragged_values, + const bool batched) { + const auto values_dtype = DataTypeToEnum::v(); + const auto splits_dtype = DataTypeToEnum::v(); + int64_t num_splits = ragged_splits.size(); + TF_ASSERT_OK( + NodeDefBuilder("tested_op", "RaggedTensorToVariant") + .Input(FakeInput(num_splits, splits_dtype)) // ragged_splits + .Input(FakeInput(values_dtype)) // ragged_values + .Attr("RAGGED_RANK", num_splits) + .Attr("Tvalues", values_dtype) + .Attr("Tsplits", splits_dtype) + .Attr("batched_input", batched) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + for (const auto& splits : ragged_splits) { + int64_t splits_size = splits.size(); + AddInputFromArray(TensorShape({splits_size}), splits); + } + AddInput(ragged_values_shape, + [&ragged_values](int i) { return ragged_values; }); + } + + template + RaggedTensorVariant CreateVariantFromRagged( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector& ragged_values) { + RaggedTensorVariant encoded; + for (auto ragged_split : ragged_splits) { + int splits_size = ragged_split.size(); + Tensor splits(DataTypeToEnum::v(), + TensorShape({splits_size})); + test::FillValues(&splits, ragged_split); + encoded.append_splits(splits); + } + Tensor 
values(DataTypeToEnum::v(), ragged_values_shape); + test::FillValues(&values, ragged_values); + encoded.set_values(values); + return encoded; + } + + template + RaggedTensorVariant CreateVariantFromRagged( + const std::vector>& ragged_splits, + const std::vector& ragged_values) { + int num_values = ragged_values.size(); + return CreateVariantFromRagged(ragged_splits, {num_values}, ragged_values); + } + + template + void ExpectRaggedTensorVariantEqual(const RaggedTensorVariant& expected, + const RaggedTensorVariant& actual) { + test::ExpectTensorEqual(actual.values(), expected.values()); + EXPECT_EQ(actual.ragged_rank(), expected.ragged_rank()); + for (int i = 0; i < actual.ragged_rank(); ++i) { + test::ExpectTensorEqual(actual.splits(i), expected.splits(i)); + } + } +}; + +class RaggedTensorToVariantGradientKernelTest + : public ::tensorflow::OpsTestBase { + protected: + // Builds the tensorflow test graph for the RaggedTensorToVariantGradient op, + // and populates the `encoded_ragged_grad`, `row_splits` and + // `dense_values_shape` input with the given values. + template + void BuildEncodeRaggedTensorGradientGraph( + const std::vector& encoded_ragged_grad, + const std::vector& row_splits, + const std::vector& dense_values_shape) { + const auto values_dtype = DataTypeToEnum::v(); + const auto splits_dtype = DataTypeToEnum::v(); + + TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToVariantGradient") + .Input(FakeInput(DT_VARIANT)) // encoded_ragged_grad + .Input(FakeInput(splits_dtype)) // row_splits + .Input(FakeInput(DT_INT32)) // dense_values_shape + .Attr("Tvalues", values_dtype) + .Attr("Tsplits", splits_dtype) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + int64_t encoded_ragged_grad_size = encoded_ragged_grad.size(); + AddInputFromArray(TensorShape({encoded_ragged_grad_size}), + encoded_ragged_grad); + + int64_t splits_size = row_splits.size(); + AddInputFromArray(TensorShape({splits_size}), row_splits); + + int64_t dense_values_shape_size = dense_values_shape.size(); + AddInputFromArray(TensorShape({dense_values_shape_size}), + dense_values_shape); + } + + template + RaggedTensorVariant CreateVariantFromRagged( + const std::vector>& ragged_splits, + const TensorShape& ragged_values_shape, + const std::vector& ragged_values) { + RaggedTensorVariant encoded; + for (auto ragged_split : ragged_splits) { + int splits_size = ragged_split.size(); + Tensor splits(DataTypeToEnum::v(), + TensorShape({splits_size})); + test::FillValues(&splits, ragged_split); + encoded.append_splits(splits); + } + Tensor values(DataTypeToEnum::v(), ragged_values_shape); + test::FillValues(&values, ragged_values); + encoded.set_values(values); + return encoded; + } +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_TO_VARIANT_OP_TEST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_tensor_variant.h b/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_tensor_variant.h new file mode 100644 index 00000000..1d2066b0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_tensor_variant.h @@ -0,0 +1,110 @@ +#include "tensorflow/core/framework/tensor_key.h" +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ +#define TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ + +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/kernels/cwise_ops_common.h" +#include "tensorflow/core/util/tensor_ops_util.h" + +namespace tensorflow { + +// Class used to store a RaggedTensor as a Variant scalar. +class RaggedTensorVariant { + public: + RaggedTensorVariant() {} + RaggedTensorVariant(Tensor values, const std::vector& nested_splits) + : values_(std::move(values)), nested_splits_(nested_splits) {} + + // Variant support methods. + string TypeName() const; + string DebugString() const; + void Encode(VariantTensorData* data) const; + bool Decode(const VariantTensorData& data); + + // The flat_values of the RaggedTensor. + const Tensor& values() const { return values_; } + Tensor* mutable_values() { return &values_; } + void set_values(const Tensor& new_values) { values_ = new_values; } + + // The nested row_splits of the RaggedTensor. + int ragged_rank() const { return nested_splits_.size(); } + const std::vector& nested_splits() const { return nested_splits_; } + std::vector* mutable_nested_splits() { return &nested_splits_; } + const Tensor& splits(int i) const { return nested_splits_[i]; } + Tensor* mutable_splits(int i) { return &nested_splits_[i]; } + void set_nested_splits(const std::vector& nested_splits) { + nested_splits_ = nested_splits; + } + void append_splits(const Tensor& splits) { nested_splits_.push_back(splits); } + + private: + Tensor values_; + std::vector nested_splits_; +}; + +template +absl::Status RaggedTensorVariantZerosLike(OpKernelContext* c, + const RaggedTensorVariant& x, + RaggedTensorVariant* y) { + y->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR( + ZerosLikeTensor(c, x.values(), y->mutable_values())); + return absl::OkStatus(); +} + +template +absl::Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, + const RaggedTensorVariant& x, + const RaggedTensorVariant& y, + RaggedTensorVariant* out) { + if (x.values().dtype() != y.values().dtype()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different dtypes. One is ", + DataTypeString(x.values().dtype()), " and the other is ", + DataTypeString(y.values().dtype())); + } + if (x.ragged_rank() != y.ragged_rank()) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants of different ragged rank. 
", "One is ", + x.ragged_rank(), " and the other is ", y.ragged_rank()); + } + for (int i = 0; i < x.ragged_rank(); ++i) { + if (TensorKey(x.splits(i)) != TensorKey(y.splits(i))) { + return errors::InvalidArgument( + "Can't add RaggedTensorVariants with different row_splits."); + } + } + out->set_nested_splits(x.nested_splits()); + TF_RETURN_IF_ERROR(BinaryAddTensors(c, x.values(), y.values(), + out->mutable_values())); + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RAGGED_TENSOR_VARIANT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_utils.h new file mode 100644 index 00000000..3ccd34a5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/ragged_utils.h @@ -0,0 +1,77 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_RAGGED_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_RAGGED_UTILS_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Utility functions for RaggedTensor + +// Verifies that the splits are valid for ragged tensor +template +absl::Status RaggedTensorVerifySplits(const Tensor& ragged_splits, + bool check_last_element, + int64_t num_ragged_values) { + auto flat_ragged_splits = ragged_splits.flat(); + + if (ragged_splits.dims() != 1) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid ragged splits: ragged splits must be rank 1 but is rank ", + ragged_splits.dims())); + } + + if (ragged_splits.NumElements() < 1) { + return absl::InvalidArgumentError( + "Invalid ragged splits: ragged splits must have at least one splits, " + "but is empty"); + } + + if (flat_ragged_splits(0) != static_cast(0)) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid ragged splits: first element of ragged splits " + " must be 0 but is ", + flat_ragged_splits(0))); + } + + SPLIT_TYPE last_split = 0; + for (int j = 1; j < ragged_splits.dim_size(0); j++) { + auto split = flat_ragged_splits(j); + if (split < last_split) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid ragged splits: ragged splits must be " + "monotonically increasing, but ragged_splits[", + j, "]=", split, " is smaller than row_splits[", j - 1, + "]=", last_split)); + } + last_split = split; + } + + if (check_last_element & last_split != num_ragged_values) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid ragged splits: last element of ragged splits must be ", + "the number of ragged values(", num_ragged_values, ") but is ", + last_split)); + } + + return absl::OkStatus(); +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RAGGED_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/random_binomial_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/random_binomial_op.h new file mode 
100644
index 00000000..e701e5ff
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/random_binomial_op.h
@@ -0,0 +1,61 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+// Sample a binomial random variable, with probs and counts for each batch.
+// Uses binomial inversion and a transformed rejection sampling method as
+// described in
+// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf.
+// Two different algorithms are employed, depending on the size of
+// counts * probs (or counts * (1 - probs) if probs > 0.5).
+// If counts * probs < 10, we simply sum up Geometric random variables until
+// they exceed count, and the number we used is binomially distributed.
+// In expectation, this will take O(counts * probs) time and require about the
+// same number of random variates.
+// This can be much cheaper than summing Bernoulli random variates, which would
+// always need O(counts) of them (so this approach requires fewer uniform
+// r.v.s and can be faster).
+//
+// If counts * probs > 10, we use a transformed-rejection algorithm based on
+// pairs of uniform random variates due to Hormann.
+// https://pdfs.semanticscholar.org/471b/c2726e25bbf8801ef781630a2c13f654268e.pdf
+// This algorithm has higher acceptance rates for large counts * probs, as the
+// proposal distribution becomes quite tight, requiring approximately two
+// uniform random variates as counts * probs becomes large.
+template <typename Device, typename T, typename U>
+struct RandomBinomialFunctor {
+  void operator()(OpKernelContext* ctx, const Device& d, int64_t num_batches,
+                  int64_t samples_per_batch, int64_t num_elements,
+                  typename TTypes<T>::ConstFlat counts,
+                  typename TTypes<T>::ConstFlat probs,
+                  const random::PhiloxRandom& gen,
+                  typename TTypes<U>::Flat output);
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_RANDOM_BINOMIAL_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/random_index_shuffle.h b/third_party/tflite-hdrs/tensorflow/core/kernels/random_index_shuffle.h
new file mode 100644
index 00000000..68b52ad6
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/random_index_shuffle.h
@@ -0,0 +1,45 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_INDEX_SHUFFLE_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_INDEX_SHUFFLE_H_
+
+#include <array>
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace random {
+
+// Returns the position of `index` in a permutation of [0, ..., max_index].
+//
+// `index` must be a number in [0, ..., max_index].
+// `key` is the random key for the permutation.
+// The returned index will also be in [0, ..., max_index]. For a fixed `key`
+// and `max_index`, the mapping from all possible `index` values to the
+// returned values forms a bijection.
+// `rounds` must be a positive even integer >= 4. Larger values improve the
+// 'randomness' of permutations for small `max_index` values. The time to
+// compute the result scales linearly with the number of rounds. We recommend
+// 8 rounds as a good trade-off.
+//
+// For more details on the algorithm see the top of the cc file.
+uint64_t index_shuffle(const uint64_t index, const std::array<uint32_t, 3>& key,
+                       const uint64_t max_index, const int32_t rounds);
+
+}  // namespace random
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_RANDOM_INDEX_SHUFFLE_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/random_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/random_op.h
new file mode 100644
index 00000000..ea16f54e
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/random_op.h
@@ -0,0 +1,64 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
+
+#include "unsupported/Eigen/CXX11/Tensor"  // from @eigen_archive
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+template <typename Device, class Distribution>
+struct FillPhiloxRandom;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+// Declares the partially CPU-specialized functor struct.
+//
+// NOTE: Due to inlining done by the compiler, you may need to add
+// explicit instantiation of the functor in random_op.cc. See example
+// functor::FillPhiloxRandom.
+//
+// This functor can take the PhiloxRandom input from either device memory `key`
+// and `counter` or a stack value `gen`. If both `key` and `counter` are not
+// nullptr, they provide the input; otherwise `gen` provides the input.
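To make the `key`/`counter` versus `gen` contract above concrete, here is a minimal sketch of what the functor boils down to for a single output group, assuming the TensorFlow headers added in this patch are on the include path (the function name OneUniformGroup is invented for illustration): the stateful path leaves `key`/`counter` null and passes a stack-constructed PhiloxRandom, and the distribution turns each generator state into kResultElementCount values.

#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random_distributions.h"

// Fills one group of outputs the way FillPhiloxRandom does per group, using
// the "stack value `gen`" input mode (i.e. key == nullptr, counter == nullptr).
void OneUniformGroup(float* out) {
  using tensorflow::random::PhiloxRandom;
  using Dist = tensorflow::random::UniformDistribution<PhiloxRandom, float>;
  PhiloxRandom gen(/*seed=*/301);   // stateful-op style: generator passed by value
  Dist dist;
  const auto samples = dist(&gen);  // one fixed-size group of uniform floats
  for (int i = 0; i < Dist::kResultElementCount; ++i) out[i] = samples[i];
}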
+template +struct FillPhiloxRandom { + void operator()(OpKernelContext* ctx, const CPUDevice& d, const uint64* key, + const uint64* counter, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64_t size, + Distribution dist); +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +typedef Eigen::GpuDevice GPUDevice; +// Declares the partially GPU-specialized functor struct. +template +struct FillPhiloxRandom { + void operator()(OpKernelContext* ctx, const GPUDevice& d, const uint64* key, + const uint64* counter, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64_t size, + Distribution dist); +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/random_op_cpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/random_op_cpu.h new file mode 100644 index 00000000..cfa927c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/random_op_cpu.h @@ -0,0 +1,193 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/random_op.h" +#include "tensorflow/core/kernels/random_ops_util.h" +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +#if EIGEN_COMP_GNUC && __cplusplus > 199711L +#define DISABLE_FLOAT_EQUALITY_WARNING \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"") +#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop") +#else +#define DISABLE_FLOAT_EQUALITY_WARNING +#define ENABLE_FLOAT_EQUALITY_WARNING +#endif + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { +using random::PhiloxRandom; +using random::SingleSampleAdapter; + +// The default implementation of the functor, which should never be invoked +// But we still need to provide implementation for now for the linker to work, +// since we do not support all the distributions yet. 
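The comment above describes a common pattern: declare a primary template whose body only fails at run time, so unsupported device/distribution combinations still link, and put the real work in per-device specializations. A self-contained sketch of that pattern with invented names (not TensorFlow code):

#include <cstdio>
#include <stdexcept>

struct CpuDevice {};
struct GpuDevice {};

// Primary template: links, but using it is a runtime error.
template <typename Device>
struct Fill {
  void operator()(int* data, int n) {
    (void)data;
    (void)n;
    throw std::runtime_error("Fill is not specialized for this device");
  }
};

// Per-device specialization that does the real work.
template <>
struct Fill<CpuDevice> {
  void operator()(int* data, int n) {
    for (int i = 0; i < n; ++i) data[i] = i;
  }
};

int main() {
  int buf[4];
  Fill<CpuDevice>()(buf, 4);     // OK: uses the CPU specialization
  // Fill<GpuDevice>()(buf, 4);  // compiles and links, but would throw at run time
  std::printf("%d %d %d %d\n", buf[0], buf[1], buf[2], buf[3]);
  return 0;
}

The trade-off is that a missing specialization is only detected when the op actually runs, which is exactly what the error message above reports.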
+template +struct FillPhiloxRandom { + typedef typename Distribution::ResultElementType T; + void operator()(OpKernelContext* ctx, const Device&, const uint64* key, + const uint64* counter, random::PhiloxRandom gen, T* data, + int64_t size, Distribution dist) { + OP_REQUIRES( + ctx, false, + errors::Internal( + "Default `FillPhiloxRandom` implementation should not be executed. " + "The cause of this error is probably that `FillPhiloxRandom` does " + "not support this device or random distribution yet.")); + } +}; + +// A class to fill a specified range of random groups +template +struct FillPhiloxRandomTask; + +// Specialization for distribution that takes a fixed number of samples for +// each output. +template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static void Run(random::PhiloxRandom gen, T* data, int64_t size, + int64_t start_group, int64_t limit_group, Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + gen.Skip(start_group); + int64_t offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64_t limit_group_full = std::min(limit_group, size / kGroupSize); + for (int64_t index = start_group; index < limit_group_full; ++index) { + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + int64_t remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&gen); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Specialization for distribution that takes a variable number of samples for +// each output. This will be slower due to the generality. 
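Before the variable-sample case, it helps to see the bookkeeping of the fixed-samples-per-output task above in isolation. A small self-contained sketch (FakeDist and DescribeGroups are invented stand-ins, not TensorFlow code) of how a range of generator groups maps onto the output buffer:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Stand-in for a Distribution with kResultElementCount == 4.
struct FakeDist {
  static constexpr int kResultElementCount = 4;
};

// Mirrors the fixed-sample indexing above: all full groups first, then a tail.
void DescribeGroups(int64_t size, int64_t start_group, int64_t limit_group) {
  const int kGroupSize = FakeDist::kResultElementCount;
  int64_t offset = start_group * kGroupSize;
  const int64_t limit_group_full = std::min(limit_group, size / kGroupSize);
  for (int64_t g = start_group; g < limit_group_full; ++g) {
    std::printf("group %lld -> outputs [%lld, %lld)\n", (long long)g,
                (long long)offset, (long long)(offset + kGroupSize));
    offset += kGroupSize;
  }
  if (limit_group_full < limit_group) {
    const int64_t remaining = size - limit_group_full * kGroupSize;
    std::printf("tail group -> outputs [%lld, %lld)\n", (long long)offset,
                (long long)(offset + remaining));
  }
}

int main() {
  // size = 10 with groups of 4: two full groups ([0,4), [4,8)) and a tail of 2.
  DescribeGroups(/*size=*/10, /*start_group=*/0, /*limit_group=*/3);
  return 0;
}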
+template +struct FillPhiloxRandomTask { + typedef typename Distribution::ResultElementType T; + static constexpr int64_t kReservedSamplesPerOutput = 256; + + static void Run(random::PhiloxRandom base_gen, T* data, int64_t size, + int64_t start_group, int64_t limit_group, Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + static const int kGeneratorSkipPerOutputGroup = + kGroupSize * kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + int64_t offset = start_group * kGroupSize; + + // First fill all the full-size groups + int64_t limit_group_full = std::min(limit_group, size / kGroupSize); + int64_t group_index; + for (group_index = start_group; group_index < limit_group_full; + ++group_index) { + // Reset the generator to the beginning of the output group region + // This is necessary if we want the results to be independent of order + // of work + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + kGroupSize, data + offset); + offset += kGroupSize; + } + + // If there are any remaining elements that need to be filled, process them + if (limit_group_full < limit_group) { + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + int64_t remaining_size = size - limit_group_full * kGroupSize; + auto samples = dist(&single_samples); + std::copy(&samples[0], &samples[0] + remaining_size, data + offset); + } + } +}; + +// Partial specialization for CPU to fill the entire region with randoms +// It splits the work into several tasks and run them in parallel +template +void FillPhiloxRandom::operator()( + OpKernelContext* ctx, const CPUDevice&, const uint64* key, + const uint64* counter, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64_t size, + Distribution dist) { + if (key != nullptr && counter != nullptr) { + gen = GetPhiloxRandomFromCounterKeyMem(counter, key); + } + + const int kGroupSize = Distribution::kResultElementCount; + + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + + int64_t total_group_count = (size + kGroupSize - 1) / kGroupSize; + + const int kGroupCost = + random::PhiloxRandom::kResultElementCount * + (random::PhiloxRandom::kElementCost + Distribution::kElementCost); + + Shard(worker_threads.num_threads, worker_threads.workers, total_group_count, + kGroupCost, + [&gen, data, size, dist](int64_t start_group, int64_t limit_group) { + FillPhiloxRandomTask< + Distribution, + Distribution::kVariableSamplesPerOutput>::Run(gen, data, size, + start_group, + limit_group, dist); + }); +} + +} // namespace functor + + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_CPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/random_op_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/random_op_gpu.h new file mode 100644 index 00000000..f8efa21d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/random_op_gpu.h @@ -0,0 +1,255 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_ + +#if defined(__CUDACC__) || TENSORFLOW_USE_ROCM + +#include "tensorflow/core/kernels/random_op.h" +#include "tensorflow/core/kernels/random_ops_util.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +namespace functor { + +template +struct FillPhiloxRandomKernel; + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter, + random::PhiloxRandom gen, T* data, int64 size, + Distribution dist); +}; + +template +struct FillPhiloxRandomKernel { + typedef typename Distribution::ResultElementType T; + PHILOX_DEVICE_INLINE void Run(const uint64* key, const uint64* counter, + random::PhiloxRandom base_gen, T* data, + int64 size, Distribution dist); +}; + +template +class SampleCopier { + public: + inline __device__ void operator()( + T* __restrict__ buf, + const tensorflow::random::Array& array) const { +#pragma unroll + for (int i = 0; i < ElementCount; i++) { + buf[i] = array[i]; + } + } +}; + +template <> +class SampleCopier { + public: + // Copies the elements from the array to buf. buf must be 128-bit aligned, + // which is true for tensor data, and all offsets that are a multiple of the + // vector size (because the vectors are 128 bits long). + inline __device__ void operator()( + float* __restrict__ buf, + const tensorflow::random::Array& array) const { + // NOTE(ringwalt): It's not safe to cast &array[0] to a float4, because they + // have 32-bit alignment vs 128-bit alignment. There seems to be no + // performance loss when assigning each element to a vector. + float4 vec; + vec.x = array[0]; + vec.y = array[1]; + vec.z = array[2]; + vec.w = array[3]; + float4* buf_vector = reinterpret_cast(buf); + *buf_vector = vec; + } +}; + +template <> +class SampleCopier { + public: + // Copies the elements from the array to buf. buf must be 128-bit aligned, + // which is true for tensor data, and all offsets that are a multiple of the + // vector size (because the vectors are 128 bits long). + inline __device__ void operator()( + int32* __restrict__ buf, + const tensorflow::random::Array& array) const { + ::int4 vec; + vec.x = array[0]; + vec.y = array[1]; + vec.z = array[2]; + vec.w = array[3]; + ::int4* buf_vector = reinterpret_cast<::int4*>(buf); + *buf_vector = vec; + } +}; + +template <> +class SampleCopier { + public: + // Copies the elements from the array to buf. buf must be 128-bit aligned, + // which is true for tensor data, and all offsets that are a multiple of the + // vector size (because the vectors are 128 bits long). 
+ inline __device__ void operator()( + double* __restrict__ buf, + const tensorflow::random::Array& array) const { + double2 vec; + vec.x = array[0]; + vec.y = array[1]; + double2* buf_vector = reinterpret_cast(buf); + *buf_vector = vec; + } +}; + +template <> +class SampleCopier { + public: + // Copies the elements from the array to buf. buf must be 128-bit aligned, + // which is true for tensor data, and all offsets that are a multiple of the + // vector size (because the vectors are 128 bits long). + inline __device__ void operator()( + int64* __restrict__ buf, + const tensorflow::random::Array& array) const { + longlong2 vec; + vec.x = array[0]; + vec.y = array[1]; + longlong2* buf_vector = reinterpret_cast(buf); + *buf_vector = vec; + } +}; + +// A cuda kernel to fill the data with random numbers from the specified +// distribution. Each output takes a fixed number of samples. +template +PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel::Run( + const uint64* key, const uint64* counter, random::PhiloxRandom gen, T* data, + int64 size, Distribution dist) { + const int kGroupSize = Distribution::kResultElementCount; + + const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int32 total_thread_count = gridDim.x * blockDim.x; + int64 offset = thread_id * kGroupSize; + if (key != nullptr && counter != nullptr) { + gen = GetPhiloxRandomFromCounterKeyMem(counter, key); + } + gen.Skip(thread_id); + + const SampleCopier copier; + while (offset + kGroupSize <= size) { + const typename Distribution::ResultType samples = dist(&gen); + copier(&data[offset], samples); + + offset += total_thread_count * kGroupSize; + gen.Skip(total_thread_count - 1); + } + + typename Distribution::ResultType samples = dist(&gen); + for (int i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } +} + +// A cuda kernel to fill the data with random numbers from the specified +// distribution. Each output takes a variable number of samples. 
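For contrast with the variable-sample kernel that follows, the fixed-sample kernel above interleaves threads across the output: thread t starts at offset t * kGroupSize, strides by total_thread_count * kGroupSize, and skips the generator by total_thread_count - 1 after each group so the Philox counter stays aligned with the output offset. A host-side sketch (invented names, no CUDA required) of which outputs and generator groups a single thread covers:

#include <cstdint>
#include <cstdio>

// Lists the output ranges one "thread" would fill, mirroring the grid-stride
// loop of the fixed-sample kernel: one generator group per stride step.
void ThreadCoverage(int64_t size, int thread_id, int total_thread_count,
                    int group_size) {
  int64_t offset = int64_t(thread_id) * group_size;
  int64_t generator_group = thread_id;  // gen.Skip(thread_id) before the loop
  while (offset + group_size <= size) {
    std::printf("thread %d: outputs [%lld, %lld) from generator group %lld\n",
                thread_id, (long long)offset, (long long)(offset + group_size),
                (long long)generator_group);
    offset += int64_t(total_thread_count) * group_size;
    generator_group += total_thread_count;  // Skip(total_thread_count - 1) + the call itself
  }
}

int main() {
  // 32 outputs, 4 threads, 4 samples per generator call: thread 0 covers
  // [0,4) and [16,20), thread 1 covers [4,8) and [20,24), and so on.
  for (int t = 0; t < 4; ++t) ThreadCoverage(32, t, 4, 4);
  return 0;
}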
+template +PHILOX_DEVICE_INLINE void FillPhiloxRandomKernel::Run( + const uint64* key, const uint64* counter, random::PhiloxRandom base_gen, + T* data, int64 size, Distribution dist) { + if (key != nullptr && counter != nullptr) { + base_gen = GetPhiloxRandomFromCounterKeyMem(counter, key); + } + + using random::PhiloxRandom; + using random::SingleSampleAdapter; + + const int kReservedSamplesPerOutput = 256; + const int kGroupSize = Distribution::kResultElementCount; + const int kGeneratorSkipPerOutputGroup = kGroupSize * + kReservedSamplesPerOutput / + PhiloxRandom::kResultElementCount; + + const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const int32 total_thread_count = gridDim.x * blockDim.x; + int64 group_index = thread_id; + int64 offset = group_index * kGroupSize; + + while (offset < size) { + // Since each output takes a variable number of samples, we need to + // realign the generator to the beginning for the current output group + PhiloxRandom gen = base_gen; + gen.Skip(group_index * kGeneratorSkipPerOutputGroup); + SingleSampleAdapter single_samples(&gen); + + typename Distribution::ResultType samples = dist(&single_samples); + + for (int i = 0; i < kGroupSize; ++i) { + if (offset >= size) { + return; + } + data[offset] = samples[i]; + ++offset; + } + + offset += (total_thread_count - 1) * kGroupSize; + group_index += total_thread_count; + } +} + +// A simple launch pad to call the correct function templates to fill the data +template +__global__ void __launch_bounds__(1024) + FillPhiloxRandomKernelLaunch(const uint64* key, const uint64* counter, + random::PhiloxRandom base_gen, + typename Distribution::ResultElementType* data, + int64 size, Distribution dist) { + FillPhiloxRandomKernel() + .Run(key, counter, base_gen, data, size, dist); +} + +// Partial specialization for GPU +template +void FillPhiloxRandom::operator()( + OpKernelContext*, const GPUDevice& d, const uint64* key, + const uint64* counter, random::PhiloxRandom gen, + typename Distribution::ResultElementType* data, int64 size, + Distribution dist) { + if (size == 0) return; + const int32 block_size = d.maxGpuThreadsPerBlock(); + const int32 num_blocks = + std::min( + d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(), + size + block_size - 1) / + block_size; + TF_CHECK_OK(GpuLaunchKernel(FillPhiloxRandomKernelLaunch, + num_blocks, block_size, 0, d.stream(), key, + counter, gen, data, size, dist)); +} + +} // namespace functor +} // namespace tensorflow + +#endif // defined(__CUDACC__) || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/random_ops_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/random_ops_util.h new file mode 100644 index 00000000..b9904569 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/random_ops_util.h @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_ + +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using random::PhiloxRandom; + +// The following 2 functions use the contract "lower 32 bits for the first +// uint32, higher 32 bits for the second". Note that this is endian-neutral, +// unlike a direct memory copy `memcpy(output, &input, 8)`. +PHILOX_DEVICE_INLINE void Uint64ToUint32s(uint64 input, uint32* output1, + uint32* output2) { + *output1 = static_cast(input); + *output2 = static_cast(input >> 32); +} + +PHILOX_DEVICE_INLINE uint64 Uint32sToUint64(uint32 input1, uint32 input2) { + auto u64_1 = static_cast(input1); + auto u64_2 = static_cast(input2); + return u64_1 | (u64_2 << 32); +} + +PHILOX_DEVICE_INLINE PhiloxRandom::ResultType GetCounterFromMem( + uint64 const* ptr) { + PhiloxRandom::ResultType counter; + Uint64ToUint32s(ptr[0], &counter[0], &counter[1]); + Uint64ToUint32s(ptr[1], &counter[2], &counter[3]); + return counter; +} + +PHILOX_DEVICE_INLINE void WriteCounterToMem( + PhiloxRandom::ResultType const& counter, uint64* ptr) { + ptr[0] = Uint32sToUint64(counter[0], counter[1]); + ptr[1] = Uint32sToUint64(counter[2], counter[3]); +} + +PHILOX_DEVICE_INLINE PhiloxRandom::Key GetKeyFromMem(uint64 const* ptr) { + PhiloxRandom::Key key; + Uint64ToUint32s(ptr[0], &key[0], &key[1]); + return key; +} + +PHILOX_DEVICE_INLINE void WriteKeyToMem(PhiloxRandom::Key const& key, + uint64* ptr) { + *ptr = Uint32sToUint64(key[0], key[1]); +} + +PHILOX_DEVICE_INLINE PhiloxRandom GetPhiloxRandomFromCounterKeyMem( + uint64 const* counter_ptr, uint64 const* key_ptr) { + return PhiloxRandom(GetCounterFromMem(counter_ptr), GetKeyFromMem(key_ptr)); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/random_poisson_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/random_poisson_op.h new file mode 100644 index 00000000..ca0dad4b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/random_poisson_op.h @@ -0,0 +1,38 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_ +#define TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { + +namespace functor { + +// Generic helper functor for the Random Poisson Op. 
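Returning briefly to the packing helpers in random_ops_util.h above: the "lower 32 bits first, upper 32 bits second" contract is easiest to check with a round trip. A minimal sketch, assuming that header is reachable on the include path under its guard path tensorflow/core/kernels/random_ops_util.h:

#include <cassert>

#include "tensorflow/core/kernels/random_ops_util.h"

int main() {
  const tensorflow::uint64 x = 0x0123456789abcdefULL;
  tensorflow::uint32 lo = 0, hi = 0;
  tensorflow::Uint64ToUint32s(x, &lo, &hi);  // lower 32 bits first...
  assert(lo == 0x89abcdefu);                 // ...then the upper 32 bits,
  assert(hi == 0x01234567u);                 // regardless of host endianness.
  assert(tensorflow::Uint32sToUint64(lo, hi) == x);
  return 0;
}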
+template +struct PoissonFunctor { + void operator()(OpKernelContext* ctx, const Device& d, const T* rate_flat, + int64_t num_rate, int64_t num_samples, + const random::PhiloxRandom& rng, U* samples_flat); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/range_sampler.h b/third_party/tflite-hdrs/tensorflow/core/kernels/range_sampler.h new file mode 100644 index 00000000..c49bbcc5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/range_sampler.h @@ -0,0 +1,243 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ +#define TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/weighted_picker.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tsl { +class Env; +} // namespace tsl +namespace tensorflow { +using Env = tsl::Env; + +// Abstract subclass for sampling from the set of non-negative integers +// [0, range) +class RangeSampler { + public: + explicit RangeSampler(int64_t range) : range_(range) { CHECK_GT(range_, 0); } + virtual ~RangeSampler(); + + // Sample a single value + virtual int64_t Sample(random::SimplePhilox* rnd) const = 0; + + // The probability that a single call to Sample() returns the given value. + // Assumes that value is in [0, range). No range checking is done. + virtual float Probability(int64_t value) const = 0; + + // Fill "batch" with samples from the distribution. + // If unique=true, then we re-pick each element until we get a + // value distinct from all previously picked values in the batch. + void SampleBatch(random::SimplePhilox* rnd, bool unique, + absl::Span batch) const; + + // Fill "batch" with samples from the distribution, and report + // "expected counts". + // + // The "expected count" of a value is an estimate of the expected + // number of occurrences of the value in the batch returned by a + // call to this function with the given parameters. If unique=true, + // the expected count is an inclusion probability. For details on + // this estimation, see the comment to "ExpectedCountHelper" in the + // .cc file. + // + // Expected counts for the elements of the returned "batch" are reported + // in the aligned array "batch_expected_count". + // + // The user can optionally provide "extras", containing values in the range. + // The expected counts for the extras are reported in the aligned array + // "extras_expected_count". 
+ // + // "batch_expected_count" must have size equal to 0 or to the size of "batch". + // "extras" and "extras_expected_count" must have equal size. + void SampleBatchGetExpectedCount( + random::SimplePhilox* rnd, bool unique, absl::Span batch, + absl::Span batch_expected_count, absl::Span extras, + absl::Span extras_expected_count) const; + + // Same as SampleBatchGetExpectedCount (see above), but with avoided values. + // We repick to avoid all of the values in "avoided_values". + // "avoided_values" is only supported with unique=true. If + // unique=false, then avoided_values must be empty. + virtual void SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, absl::Span batch, + absl::Span batch_expected_count, absl::Span extras, + absl::Span extras_expected_count, + absl::Span avoided_values) const; + + // Does this sampler need to be updated with values, e.g. UnigramSampler + virtual bool NeedsUpdates() const { return false; } + + // Updates the underlying distribution + virtual void Update(absl::Span values) { + LOG(FATAL) << "Update not supported for this sampler type."; + } + + int64_t range() { return range_; } + + protected: + const int64_t range_; +}; + +// An AllSampler only samples batches of size equal to range. +// It returns the entire range. +// It cannot sample single values. +class AllSampler : public RangeSampler { + public: + explicit AllSampler(int64_t range); + + ~AllSampler() override {} + + int64_t Sample(random::SimplePhilox* rnd) const override { + LOG(FATAL) << "Should not be called"; + return 0; + } + + float Probability(int64_t value) const override { + LOG(FATAL) << "Should not be called"; + return 0; + } + + void SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, absl::Span batch, + absl::Span batch_expected_count, absl::Span extras, + absl::Span extras_expected_count, + absl::Span avoided_values) const override; +}; + +class UniformSampler : public RangeSampler { + public: + explicit UniformSampler(int64_t range); + + ~UniformSampler() override {} + + int64_t Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64_t value) const override; + + private: + const float inv_range_; +}; + +class LogUniformSampler : public RangeSampler { + public: + explicit LogUniformSampler(int64_t range); + + ~LogUniformSampler() override {} + + int64_t Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64_t value) const override; + + private: + const double log_range_; +}; + +// Thread-unsafe unigram sampler +class ThreadUnsafeUnigramSampler : public RangeSampler { + public: + explicit ThreadUnsafeUnigramSampler(int64_t range); + ~ThreadUnsafeUnigramSampler() override {} + + int64_t Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64_t value) const override; + + bool NeedsUpdates() const override { return true; } + void Update(absl::Span values) override; + + private: + random::WeightedPicker picker_; +}; + +// Thread-safe unigram sampler +class UnigramSampler : public RangeSampler { + public: + explicit UnigramSampler(int64_t range); + ~UnigramSampler() override {} + + int64_t Sample(random::SimplePhilox* rnd) const override; + + float Probability(int64_t value) const override; + + // Overriding at a high level results in far fewer lock acquisitions. 
+ void SampleBatchGetExpectedCountAvoid( + random::SimplePhilox* rnd, bool unique, absl::Span batch, + absl::Span batch_expected_count, absl::Span extras, + absl::Span extras_expected_count, + absl::Span avoided_values) const override; + + bool NeedsUpdates() const override { return true; } + void Update(absl::Span values) override; + + private: + ThreadUnsafeUnigramSampler unsafe_sampler_ TF_GUARDED_BY(mu_); + mutable mutex mu_; +}; + +// A unigram sampler that uses a fixed unigram distribution read from a +// file or passed in as an in-memory array instead of building up the +// distribution from data on the fly. There is also an option to skew the +// distribution by applying a distortion power to the weights. +class FixedUnigramSampler : public RangeSampler { + public: + FixedUnigramSampler(int64_t range, float distortion, int32_t num_reserved_ids, + int32_t num_shards, int32_t shard); + // The vocab_file is assumed to be a CSV, with the last entry of each row a + // value representing the counts or probabilities for the corresponding ID. + absl::Status SetDistributionSampler(Env* env, const string& vocab_file); + absl::Status SetDistributionSampler(const std::vector& unigrams); + float Probability(int64_t value) const override; + + int64_t Sample(random::SimplePhilox* rnd) const override; + + private: + // Underlying distribution sampler. + std::unique_ptr dist_sampler_; + // Weights for individual samples. The probability of a sample i is defined + // as weights_.at(i) / total_weight_. + std::vector weights_; + // The total weights of all samples. + float total_weight_; + // Sharding information of the sampler. The whole vocabulary is sharded + // into num_shards_ smaller ranges and each sampler is responsible for one + // such smaller range, identified by the shard number. + int32 num_shards_; + int32 shard_; + float distortion_; + // Fill the sampler with the appropriate number of reserved IDs. + void FillReservedIds(int32_t num_reserved_ids); + // Load IDs to sample from a CSV file. It is assumed that the last item of + // each row contains a count or probability for the corresponding ID. + absl::Status LoadFromFile(Env* env, const string& vocab_file, + float distortion); + // Load from an in-memory array. + void LoadFromUnigrams(const std::vector& unigrams, float distortion); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/record_yielder.h b/third_party/tflite-hdrs/tensorflow/core/kernels/record_yielder.h new file mode 100644 index 00000000..7e4c0f5a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/record_yielder.h @@ -0,0 +1,160 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ +#define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { + +// RecordYielder produces value records from a set of tfrecord files +// in a random order. +// +// It guarantees that: +// 1) all records in tfrecords are yielded within every epoch; +// 2) each record is yielded only once within every epoch; +// 3) the order in which records are yielded is highly randomized. +// 4) the peak memory usage is roughly avg record size * +// (opts.bufsize + opts.parallelism * 16). +// +// Usage example: +// RecordYielder::Options opts; +// opts.file_pattern = "input-*"; +// opts.seed = 301; +// opts.bufsize = 1000000; // A randomized buffer with 1M records. +// opts.parallelism = 8; // Uses 8 tfrecord iterators to iterate +// // through all files. +// RecordYielder yielder(opts); +// string val; +// while (true) { +// yielder.YieldOne(&val); +// // process val +// } +// +// RecordYielder can be accessed by multiple threads concurrently. +class RecordYielder { + public: + struct Options { + // Glob pattern for tfrecords. + string file_pattern; + + // Random seed. It determines how data files are shuffled and how + // records are shuffled. + int64_t seed = 0; + + // Each epoch, all files are first shuffled according to the + // random seed and the epoch number, and then all files are + // left-shifted by file_shuffle_shift_ratio * num_files slots. If + // file_shuffle_shift_ratio is not within [0, 1), the + // implementation clip it to [0, 1). + float file_shuffle_shift_ratio = 0; + + // Randomization buffer keeps these many records. + uint64 bufsize = 1; + + // Uses these many concurrent tfrecord iterators to iterate through + // tfrecords. + int32 parallelism = 1; + + string compression_type; + }; + + explicit RecordYielder(OpKernelConstruction* context, + const RecordYielder::Options& opts); + ~RecordYielder(); + + RecordYielder(const RecordYielder&) = delete; + RecordYielder& operator=(const RecordYielder&) = delete; + + // Yields one 'value'. + absl::Status YieldOne(tstring* value); + + // Returns the current epoch number. + int64_t current_epoch() const { return epoch_; } + + private: + typedef RecordYielder ME; + + Options opts_; + + // Backgrounds threads. Owned. + thread::ThreadPool* thread_; + + // Epoch number. + std::atomic epoch_; + + mutex mu_; + + // Turned to true when this is deleted. + bool stop_ TF_GUARDED_BY(mu_) = false; + absl::Status status_ TF_GUARDED_BY(mu_); + + // PRG used for randomization. + std::mt19937_64 rnd_ TF_GUARDED_BY(mu_); + + // Randomization buffer. + std::vector buf_ TF_GUARDED_BY(mu_); + + // True iff we are draining an epoch. + bool epoch_end_ = false; + + int64_t num_records_added_in_epoch_ = 0; + int64_t num_records_yielded_in_epoch_ = 0; + + // Trigger when the main loop has exited. + Notification main_loop_done_; + + // condition_variables. 
+ condition_variable buf_empty_; + bool BufEmpty() const TF_SHARED_LOCKS_REQUIRED(mu_) { + return stop_ || buf_.empty(); + } + + condition_variable buf_not_full_; + bool BufNotFull() const TF_SHARED_LOCKS_REQUIRED(mu_) { + return stop_ || buf_.size() < opts_.bufsize; + } + + condition_variable buf_enough_; + bool BufEnough() const TF_SHARED_LOCKS_REQUIRED(mu_) { + // NOTE: Unless we are finishing an epoch, we want to make sure + // the buf_ contains enough randomized elements before yielding + // any. + return stop_ || !status_.ok() || (epoch_end_ && !buf_.empty()) || + (!epoch_end_ && + buf_.size() >= std::max(1, opts_.bufsize / 2)); + } + + void MainLoop(); + struct Shard; + void ShardLoop(Shard* shard); + bool ShouldFinish(const absl::Status& s); + bool Add(std::vector* values); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_gpu_kernels.cu.h new file mode 100644 index 00000000..a82e6c47 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -0,0 +1,1412 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_ +#define TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/kernels/reduction_ops.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/util/gpu_device_functions.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/permutation_input_iterator.h" +#include "tensorflow/core/util/transform_output_iterator.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::GpuDevice GPUDevice; + +template +struct SqrtOfReal { + __host__ __device__ T operator()(const T& a) const { + return T(Eigen::numext::sqrt(Eigen::numext::real(a))); + } +}; + +template +struct Sum { + __host__ __device__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +template +struct Prod { + __host__ __device__ T operator()(const T& a, const T& b) const { + return a * b; + } +}; + +template +struct Square { + __host__ __device__ T operator()(const T& a) const { + return Prod()(a, Eigen::numext::conj(a)); + } +}; + +template +struct DividesBy { + T divisor; + + __host__ __device__ explicit DividesBy(T divisor) : divisor(divisor) {} + + __host__ __device__ OUT_T operator()(const T& x) const { return x / divisor; } +}; + +struct MaxPropagateNaN { + template + __host__ __device__ inline T operator()(const T& a, const T& b) const { + return (a != a ? a : (a > b ? 
a : b)); + } +}; + +struct MinPropagateNaN { + template + __host__ __device__ inline T operator()(const T& a, const T& b) const { + return (a != a ? a : (a < b ? a : b)); + } +}; + +#if GOOGLE_CUDA +// TODO(rocm) : enable this once ROCm platform has support for complex datatypes +// +// needed to work around a compiler bug in nvcc - it doesn't seem to like +// the overloaded ops for std::complex +template <> +struct DividesBy> { + cuFloatComplex divisor; + + __host__ __device__ explicit DividesBy(std::complex divisor) + : divisor(make_cuComplex(divisor.real(), divisor.imag())) {} + + // implements + __host__ __device__ std::complex operator()( + const std::complex& x) const { + auto result = cuCdivf(make_cuComplex(x.real(), x.imag()), divisor); + return std::complex(result.x, result.y); + } +}; + +template <> +struct DividesBy> { + cuDoubleComplex divisor; + + __host__ __device__ explicit DividesBy(std::complex divisor) + : divisor(make_cuDoubleComplex(divisor.real(), divisor.imag())) {} + + // implements + __host__ __device__ std::complex operator()( + const std::complex& x) const { + auto result = cuCdiv(make_cuDoubleComplex(x.real(), x.imag()), divisor); + return std::complex(result.x, result.y); + } +}; +#endif // GOOGLE_CUDA + +template +struct DividesBy { + float divisor; + + __host__ __device__ explicit DividesBy(float divisor) : divisor(divisor) {} + + __host__ __device__ T operator()(const float& x) const { + return T(x / divisor); + } +}; + +template +struct HalfToFloat { + __host__ __device__ float operator()(const T& x) const { + return static_cast(x); + } +}; + +template +struct FloatToHalf { + __host__ __device__ T operator()(const float& x) const { + return static_cast(x); + } +}; + +struct And { + __host__ __device__ bool operator()(const bool& a, const bool& b) const { + return a && b; + } +}; + +struct Or { + __host__ __device__ bool operator()(const bool& a, const bool& b) const { + return a || b; + } +}; + +// each block does a grid strided loop and reduces its values locally +// the case of one block is used for low latency small reductions to scalars +template +__global__ __launch_bounds__(1024) void BlockReduceKernel( + T in, OUT_T out, int num_elems, Op op, + typename std::iterator_traits::value_type initVal) { + const int bid = blockIdx.x; + const int tid = threadIdx.x; + + const int gid = bid * blockDim.x + tid; + const int stride = blockDim.x * gridDim.x; + + typedef typename std::iterator_traits::value_type value_type; + + value_type sum = initVal; + if (gid < num_elems) { + sum = in[gid]; + for (int pos = gid + stride; pos < num_elems; pos += stride) { + sum = op(sum, in[pos]); + } + } + + typedef gpuprim::BlockReduce BlockReduce; + + __shared__ typename BlockReduce::TempStorage temp_storage; + + // only include input values in the reduction + // + // elements: ----------------- + // grid: |====|====|====|====|====| + const int num_elements_to_reduce = + max(min(static_cast(num_elems - bid * blockDim.x), num_threads), 0); + + sum = BlockReduce(temp_storage).Reduce(sum, op, num_elements_to_reduce); + + if (tid == 0) out[bid] = sum; +} + +// maps a warp to each row +template +__global__ __launch_bounds__(1024) void RowReduceKernel( + T in, OUT_T out, int num_rows, int num_cols, Op op, + typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + // Defensive index computation to avoid integer overflow. 
+ assert(blockDim.x % TF_RED_WARPSIZE == 0); + int warps_per_block = blockDim.x / TF_RED_WARPSIZE; + int warp_index = threadIdx.x / TF_RED_WARPSIZE; + const int row = blockIdx.x * warps_per_block + warp_index; + const int lane = threadIdx.x % TF_RED_WARPSIZE; + + if (num_cols == 1) { + int gid = threadIdx.x + blockIdx.x * blockDim.x; + if (gid < num_rows) out[gid] = in[gid]; + return; + } + + value_type sum = initVal; + int col = lane; + + if (row < num_rows && col < num_cols) { + sum = in[row * num_cols + col]; + col += TF_RED_WARPSIZE; + for (; col < num_cols; col += TF_RED_WARPSIZE) { + sum = op(sum, in[row * num_cols + col]); + } + } + + typedef gpuprim::WarpReduce WarpReduce; + + __shared__ typename WarpReduce::TempStorage temp_storage; + + sum = + WarpReduce(temp_storage).Reduce(sum, op, min(num_cols, TF_RED_WARPSIZE)); + + if (row < num_rows && lane == 0) out[row] = sum; +} + +template +struct storage_type { + T1 val; + __host__ __device__ storage_type() {} + __host__ __device__ operator T1() { return val; } + __host__ __device__ storage_type& operator=(const T1& in) { + val = in; + return *this; + } +}; + +template +struct storage_type> { + T2 real; + T2 imag; + __host__ __device__ storage_type() {} + __host__ __device__ operator std::complex() { + return std::complex(real, imag); + } + __host__ __device__ storage_type>& operator=( + const std::complex& in) { + real = in.real(); + imag = in.imag(); + return *this; + } +}; + +// Works only if there are <= 16 columns +// each warps sums over multiple rows at once +template +__global__ __launch_bounds__(1024) void ColumnReduceMax16ColumnsKernel( + T in, OUT_T out, int num_rows, int num_cols, Op op, + typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + int rows_per_warp = TF_RED_WARPSIZE / num_cols; + + const int lane = threadIdx.x % TF_RED_WARPSIZE; + const int lane_row = lane / num_cols; + + const int start_row_warp = + rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y); + const int start_row_lane = start_row_warp + lane_row; + int row = start_row_lane; + int col = lane % num_cols; + + value_type sum = initVal; + if (row * num_cols + col < num_rows * num_cols) + sum = in[row * num_cols + col]; + + // 1D array necessary due to bug in CUDA 9 compiler. + // TODO(nluehr) revert to 2D array when compiler is ready. 
+ // This is to mimic the following, but without any constructors: + // __shared__ storage_type partial_sums[TF_RED_WARPSIZE * + // (TF_RED_WARPSIZE+1)]; +#if GOOGLE_CUDA + __shared__ __align__(alignof(value_type)) char + partial_sums_raw[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1) * + sizeof(value_type)]; + value_type* partial_sums = reinterpret_cast(partial_sums_raw); +#elif TENSORFLOW_USE_ROCM + __shared__ storage_type + partial_sums[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1)]; +#endif + + row += rows_per_warp * gridDim.y * blockDim.y; + for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) { + int global_pos = row * num_cols + col; + if (global_pos < (num_rows * num_cols)) + sum = op(sum, in[row * num_cols + col]); + } + + const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp); + // not the most efficient way to do this sum + for (int i = 1; i < rows_in_this_warp; ++i) { + value_type tmp = gpuprim::ShuffleIndex( + sum, static_cast(threadIdx.x + i * num_cols), 0xffffffff); + if (lane < num_cols) sum = op(sum, tmp); + } + + if (lane < num_cols) + partial_sums[lane * (TF_RED_WARPSIZE + 1) + threadIdx.y] = sum; + + __syncthreads(); + + if (threadIdx.y == 0 && threadIdx.x < num_cols) { + value_type s = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1)]; + + if (blockDim.y > 1) { + for (int row = 1; row < blockDim.y; ++row) { + value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row]; + s = op(s, t); + } + } + + out[col * gridDim.y + blockIdx.y] = s; + } +} + +// Maps each block to a column range TF_RED_WARPSIZE wide +template +__global__ __launch_bounds__(1024) void ColumnReduceKernel( + T in, OUT_T out, int num_rows, int num_cols, Op op, + typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * TF_RED_WARPSIZE + threadIdx.x; + + value_type sum = initVal; + if (row < num_rows && col < num_cols) sum = in[row * num_cols + col]; + + // 1D array necessary due to bug in CUDA 9 compiler. + // TODO(nluehr) revert to 2D array when compiler is ready. 
+ // This is to mimic the following, but without constructors: + // __shared__ storage_type partial_sums[TF_RED_WARPSIZE * + // (TF_RED_WARPSIZE + 1)]; +#if GOOGLE_CUDA + __shared__ __align__(alignof(value_type)) char + partial_sums_raw[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1) * + sizeof(value_type)]; + value_type* partial_sums = reinterpret_cast(partial_sums_raw); +#elif TENSORFLOW_USE_ROCM + __shared__ storage_type + partial_sums[TF_RED_WARPSIZE * (TF_RED_WARPSIZE + 1)]; +#endif + + row += gridDim.y * blockDim.y; + + if (col < num_cols) { + for (; row < num_rows; row += gridDim.y * blockDim.y) { + sum = op(sum, in[row * num_cols + col]); + } + } + + partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + threadIdx.y] = sum; + + __syncthreads(); + + if (threadIdx.y == 0 && col < num_cols) { + value_type s = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1)]; + + // only include input values in the reduction + // elem block_rows + // - = + // - = + // # # block boundary + // - = + // - = + // # # block boundary + // - = + // = + const int numRowsThisBlock = + min(static_cast(blockDim.y), num_rows - blockIdx.y * blockDim.y); + + for (int row = 1; row < numRowsThisBlock; ++row) { + value_type t = partial_sums[threadIdx.x * (TF_RED_WARPSIZE + 1) + row]; + s = op(s, t); + } + + out[col * gridDim.y + blockIdx.y] = s; + } +} + +// does multiple warp size segmented reductions in parallel +// segments cannot cross warp boundaries (mainly used for reducing the segments +// that come from the Max16Columns column reduction kernel) +template +__global__ __launch_bounds__(1024) void CleanupSegments( + T partial_sums, OUT_T out, int num_rows, int num_cols, int segment_size, + Op op, typename std::iterator_traits::value_type initVal) { + typedef typename std::iterator_traits::value_type value_type; + const int tid = threadIdx.x + blockIdx.x * blockDim.x; + + value_type val = initVal; + if (tid < segment_size * num_cols) val = partial_sums[tid]; + + typedef gpuprim::WarpReduce WarpReduce; + + __shared__ typename WarpReduce::TempStorage temp_storage; + + const bool head_flag = (threadIdx.x % segment_size) == 0; + value_type sum = + WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op); + + if (head_flag && tid < segment_size * num_cols) { + out[tid / segment_size] = sum; + } +} + +// assigns one thread to a column +template +__global__ __launch_bounds__(1024) void ColumnReduceSimpleKernel( + T in, OUT_T out, int num_planes, int num_rows, int num_cols, Op op) { + typedef typename std::iterator_traits::value_type value_type; + const int gid = threadIdx.x + blockIdx.x * blockDim.x; + const int elems_per_plane = num_rows * num_cols; + + const int plane = gid / num_cols; + const int col = gid % num_cols; + + if (plane >= num_planes) return; + + if (num_rows == 1) { + out[plane * elems_per_plane + col] = in[plane * elems_per_plane + col]; + return; + } + + value_type sum = op(in[plane * elems_per_plane + col], + in[plane * elems_per_plane + num_cols + col]); + for (int row = 2; row < num_rows; ++row) { + sum = op(sum, in[plane * elems_per_plane + row * num_cols + col]); + } + + out[plane * num_cols + col] = sum; +} + +namespace { +constexpr int kUnroll = 8; +} + +template +__device__ __inline__ T ComputeSum(IN_T in_, const int plane, + const int num_out_rows, int num_rows, + int num_cols, const int col, Op op) { + const int out_rows = num_rows / (2 * kUnroll); + const int num_rem_rows = num_rows % (2 * kUnroll); + const int elems_per_plane = num_rows * num_cols; + T reg[2 * kUnroll]; + T sum; + int offset = 
0; + if (out_rows != 0) { + for (int i = 0; i < 2 * kUnroll; i++) { + reg[i] = + in_[plane * elems_per_plane + i * (num_out_rows * num_cols) + col]; + } + sum = reg[0]; + for (int i = 1; i < 2 * kUnroll; i++) { + sum = op(sum, reg[i]); + } + offset = 2 * kUnroll * (num_out_rows * num_cols); + } + + if (col < num_cols && num_rem_rows > 0) { + reg[0] = in_[plane * elems_per_plane + offset + 0 * num_cols + col]; + if (out_rows != 0) { + sum = op(sum, reg[0]); + } else { + sum = reg[0]; + } + for (int i = 1; i < num_rem_rows; i++) { + reg[0] = in_[plane * elems_per_plane + offset + i * num_cols + col]; + sum = op(sum, reg[0]); + } + } + return sum; +} + +template +__global__ __launch_bounds__(1024) void ColumnReduceInToTempKernel( + void* __restrict__ temp, int temp_in_offset, int temp_out_offset, IN_T in, + int num_planes, int num_rows, int num_cols, Op op) { + typedef typename std::iterator_traits::value_type value_type; + + value_type* t = (value_type*)temp; + value_type* out_ = t + temp_out_offset; + + const int gid = threadIdx.x + blockIdx.x * blockDim.x; + const int num_out_rows = max(1, num_rows / (2 * kUnroll)); + const int plane = gid / (num_out_rows * num_cols); + const int col = gid % (num_out_rows * num_cols); + + if (plane >= num_planes) return; + + value_type sum; + if (temp_in_offset == -1) { + auto in_ = in; + sum = ComputeSum(in_, plane, num_out_rows, num_rows, + num_cols, col, op); + } else { + auto in_ = t + temp_in_offset; + sum = ComputeSum(in_, plane, num_out_rows, + num_rows, num_cols, col, op); + } + out_[plane * num_out_rows * num_cols + col] = sum; +} + +template +__global__ __launch_bounds__(1024) void ColumnReduceTempToOutKernel( + void* __restrict__ temp, int temp_in_offset, T in, OUT_T out, + int num_planes, int num_rows, int num_cols, Op op) { + typedef typename std::iterator_traits::value_type value_type; + value_type* t = (value_type*)temp; + const int tid = threadIdx.x; + const int gid = threadIdx.x + blockIdx.x * blockDim.x; + int elems_per_plane = num_rows * num_cols; + + if (num_rows == 1) { + if (gid >= num_planes * num_cols) return; + if (temp_in_offset == -1) { + auto in_ = in; + out[gid] = in_[gid]; + } else { + auto in_ = t + temp_in_offset; + out[gid] = in_[gid]; + } + return; + } + + const int planes_per_block = 1; + const int plane = blockIdx.x * planes_per_block + tid / elems_per_plane; + // A thread block contains one or multiple plane(s), + // i.e. 
num_rows * num_cols <= blockDim.x + const int col = tid % elems_per_plane; + const int local_plane = plane % planes_per_block; + + if (tid >= planes_per_block * elems_per_plane || plane >= num_planes) return; + + GPU_DYNAMIC_SHARED_MEM_DECL(8, char, ss); + value_type* const smem = reinterpret_cast(ss); + + if (temp_in_offset == -1) { + auto in_ = in; + smem[local_plane * elems_per_plane + col] = + in_[plane * elems_per_plane + col]; + } else { + auto in_ = t + temp_in_offset; + smem[local_plane * elems_per_plane + col] = + in_[plane * elems_per_plane + col]; + } + __syncthreads(); + + int num_in_rows = num_rows; + int num_out_rows; + int num_rem_rows; + + int in_offset = 0; + int out_offset = blockDim.x; + + int in_elems_per_plane = elems_per_plane; + int out_elems_per_plane; + + while (num_in_rows > 1) { + num_out_rows = num_in_rows / 2; + num_rem_rows = num_in_rows % 2; + out_elems_per_plane = num_out_rows * num_cols; + + if (col < out_elems_per_plane) { + value_type sum; + sum = op(smem[in_offset + local_plane * in_elems_per_plane + col], + smem[in_offset + local_plane * in_elems_per_plane + + out_elems_per_plane + col]); + if (num_rem_rows == 1 && col < num_cols) { + sum = op(sum, smem[in_offset + local_plane * in_elems_per_plane + + 2 * out_elems_per_plane + col]); + } + smem[out_offset + local_plane * out_elems_per_plane + col] = sum; + } + + num_in_rows = num_out_rows; + in_elems_per_plane = out_elems_per_plane; + int t_offset = in_offset; + in_offset = out_offset; + out_offset = t_offset; + __syncthreads(); + } + + if (col < num_cols) { + out[plane * num_cols + col] = + smem[in_offset + local_plane * out_elems_per_plane + col]; + } +} + +struct RowOffset { + __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {} + + __host__ __device__ int operator()(const int& x) const { return cols_ * x; } + + int cols_; +}; + +struct GatherOp { + __host__ __device__ GatherOp(const int& extent_x, const int& extent_y, + const int& extent_z, bool kOne) + : extent_x_(extent_x), + extent_y_(extent_y), + extent_z_(extent_z), + kOne_(kOne) { + if (kOne_) + group_size_ = extent_y_; + else + group_size_ = extent_x_ * extent_z_; + } + + __host__ __device__ int operator()(const int& ind) const { + const int group = kOne_ ? ind / group_size_ : ind % group_size_; + const int offset = kOne_ ? 
ind % group_size_ : ind / group_size_; + + const int x = group / extent_z_; + const int z = group % extent_z_; + + return x * extent_y_ * extent_z_ + z + offset * extent_z_; + } + + int extent_x_; + int extent_y_; + int extent_z_; + bool kOne_; + int group_size_; +}; + +template +void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in, + int in_size, Op op, T init, + const gpuStream_t& cu_stream) { + // handle situations where low latency is important better than CUB + if (in_size <= 4096) { + const int num_blocks = 1; + const int num_threads = 256; + TF_CHECK_OK(GpuLaunchKernel(BlockReduceKernel, + num_blocks, num_threads, 0, cu_stream, in, out, + in_size, op, init)); + return; + } else if (in_size <= 1 << 18) { + const int num_threads = 256; + const int num_blocks = + std::min(TF_RED_WARPSIZE, Eigen::divup(in_size, num_threads)); + // it seems like tailoring this to the GPU + // would be more effective, but all attempts + // at making this a multiple of the number of + // multiprocessors have lead to lower perf + // in general + // TODO(eriche) investigate this more + + Tensor temp_storage; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DT_INT8, + TensorShape({static_cast(num_blocks * sizeof(T))}), + &temp_storage)); + + TF_CHECK_OK(GpuLaunchKernel(BlockReduceKernel, + num_blocks, num_threads, 0, cu_stream, in, + (T*)temp_storage.flat().data(), in_size, + op, init)); + + // take care that we only reduce blocks that had some valid elements in them + // TODO(eriche): CUB currently has a bug in HeadSegmentedReduce that + // requires it to be used with a full warp. Can reduce TF_RED_WARPSIZE -> + // num_blocks when this is fixed. + TF_CHECK_OK(GpuLaunchKernel(CleanupSegments, 1, + TF_RED_WARPSIZE, 0, cu_stream, + (T*)temp_storage.flat().data(), out, 1, + 1, num_blocks, op, init)); + return; + } + + size_t temp_storage_bytes = 0; + auto reduce = [&](void* temp_storage_ptr) { + auto success = + gpuprim::DeviceReduce::Reduce(temp_storage_ptr, temp_storage_bytes, in, + out, in_size, op, init, cu_stream); + + OP_REQUIRES( + ctx, success == 0, + errors::Internal("CUB reduce error ", GpuGetErrorString(success))); + }; + + reduce(nullptr); // Get required amount of temp storage. + + Tensor temp_storage; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + + reduce(temp_storage.flat().data()); // Do reduction. 
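LaunchScalarReduction above invokes the reduce lambda twice: first with a null workspace pointer so CUB only reports temp_storage_bytes, then again with the allocated buffer. A minimal sketch of the same idiom against plain CUB and cudaMalloc, assuming the stock cub::DeviceReduce::Sum signature; TensorFlow instead obtains the scratch space through allocate_temp on the op kernel context.

// Editor's sketch (hypothetical standalone helper) of the two-phase
// CUB temporary-storage pattern used above.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

float SumOnDevice(const float* d_in, int n) {
  float* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(float));

  // Phase 1: null workspace, CUB only fills in the required byte count.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

  // Phase 2: allocate the workspace and run the real reduction.
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

  float result = 0.f;
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  cudaFree(d_temp);
  cudaFree(d_out);
  return result;
}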
+} + +template +void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows, + int num_cols, Op op, T init, + const gpuStream_t& cu_stream) { + if (num_cols < 1024) { + const int threads_per_block = 128; + const int warps_per_block = threads_per_block / TF_RED_WARPSIZE; + int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block; + + TF_CHECK_OK(GpuLaunchKernel(RowReduceKernel, num_blocks, + threads_per_block, 0, cu_stream, in, out, + num_rows, num_cols, op, init)); + return; + } + + // setup segment offsets with counting and transform iterator + RowOffset row_offset_op(num_cols); + gpuprim::CountingInputIterator counting_iter(0); + gpuprim::TransformInputIterator> + transform_iter(counting_iter, row_offset_op); + + size_t temp_storage_bytes = 0; + auto reduce = [&](void* temp_storage_ptr) { + auto success = gpuprim::DeviceSegmentedReduce::Reduce( + temp_storage_ptr, temp_storage_bytes, in, out, num_rows, transform_iter, + transform_iter + 1, op, init, cu_stream); + + OP_REQUIRES(ctx, success == 0, + errors::Internal("CUB segmented reduce error", + GpuGetErrorString(success))); + }; + + reduce(nullptr); // Get required amount of temp storage. + + Tensor temp_storage; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + + reduce(temp_storage.flat().data()); // Do reduction. +} + +template +void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in, + int extent_x, int extent_y, Op op, T init, + const gpuStream_t& cu_stream) { + int rows_per_warp = TF_RED_WARPSIZE / extent_y; + dim3 block_dim( + TF_RED_WARPSIZE, + std::min(Eigen::divup(extent_x, rows_per_warp), (1024 / TF_RED_WARPSIZE)), + 1); + dim3 grid_dim(1, + Eigen::divup(static_cast(extent_x), + rows_per_warp * block_dim.y), + 1); + + grid_dim.y = std::min((int)grid_dim.y, TF_RED_WARPSIZE); + + if (grid_dim.y > 2 && grid_dim.y < TF_RED_WARPSIZE) { + int log2 = Log2Floor(grid_dim.y); + grid_dim.y = 1 << log2; + } + + if (grid_dim.y == 1) { + TF_CHECK_OK(GpuLaunchKernel(ColumnReduceMax16ColumnsKernel, + grid_dim, block_dim, 0, cu_stream, in, out, + extent_x, extent_y, op, init)); + } else { + Tensor temp_storage; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_INT8, + TensorShape({static_cast( + sizeof(T) * extent_y * grid_dim.y)}), + &temp_storage)); + TF_CHECK_OK(GpuLaunchKernel(ColumnReduceMax16ColumnsKernel, + grid_dim, block_dim, 0, cu_stream, in, + (T*)temp_storage.flat().data(), + extent_x, extent_y, op, init)); + + dim3 new_grid_dim( + (grid_dim.y * extent_y + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE, 1, + 1); + dim3 num_threads(128, 1, 1); + TF_CHECK_OK(GpuLaunchKernel(CleanupSegments, new_grid_dim, + num_threads, 0, cu_stream, + (T*)temp_storage.flat().data(), out, + extent_x, extent_y, grid_dim.y, op, init)); + } +} + +template +void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in, + int extent_x, int extent_y, Op op, + T init, const gpuStream_t& cu_stream) { + dim3 block_dim(TF_RED_WARPSIZE, std::min(extent_x, (1024 / TF_RED_WARPSIZE)), + 1); + dim3 grid_dim((extent_y + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE, 1, 1); + + if (grid_dim.x < 16) + grid_dim.y = std::min((extent_x + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE, + TF_RED_WARPSIZE); + + if (grid_dim.y > 2 && grid_dim.y < TF_RED_WARPSIZE) { + int log2 = Log2Floor(grid_dim.y); + grid_dim.y = 1 << log2; + } + + if (grid_dim.y == 1) { + TF_CHECK_OK(GpuLaunchKernel(ColumnReduceKernel, grid_dim, + block_dim, 0, cu_stream, in, out, 
extent_x, + extent_y, op, init)); + } else { + Tensor temp_storage; + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DT_INT8, + TensorShape({static_cast( + sizeof(T) * extent_y * grid_dim.y)}), + &temp_storage)); + + TF_CHECK_OK(GpuLaunchKernel( + ColumnReduceKernel, grid_dim, block_dim, 0, cu_stream, in, + (T*)temp_storage.flat().data(), extent_x, extent_y, op, init)); + + dim3 new_grid_dim( + (grid_dim.y * extent_y + (TF_RED_WARPSIZE - 1)) / TF_RED_WARPSIZE, 1, + 1); + TF_CHECK_OK(GpuLaunchKernel(CleanupSegments, new_grid_dim, + block_dim, 0, cu_stream, + (T*)temp_storage.flat().data(), out, + extent_x, extent_y, grid_dim.y, op, init)); + } +} + +template +void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in, + int extent_x, int extent_y, Op op, T init, + const gpuStream_t& cu_stream) { + if (extent_y <= 16) { + LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init, + cu_stream); + } else if (extent_y <= 4096) { + LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op, + init, cu_stream); + } else { + int threads_per_block = 128; + int num_blocks = Eigen::divup(extent_y, threads_per_block); + + TF_CHECK_OK(GpuLaunchKernel(ColumnReduceSimpleKernel, + num_blocks, threads_per_block, 0, cu_stream, in, + out, 1, extent_x, extent_y, op)); + } +} + +template +void Launch3DYReductionSimple(OpKernelContext* ctx, OUT_T out, IN_T in, + int extent_x, int extent_y, int extent_z, Op op, + T init, const gpuStream_t& cu_stream) { + int threads_per_block = 128; + int num_blocks = + (extent_x * extent_z + threads_per_block - 1) / threads_per_block; + + // TODO(eriche): this won't be very good in the case of small x + // small z and large y. + TF_CHECK_OK(GpuLaunchKernel(ColumnReduceSimpleKernel, + num_blocks, threads_per_block, 0, cu_stream, in, + out, extent_x, extent_y, extent_z, op)); +} + +template +void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x, + int extent_y, int extent_z, Op op, T init, + const gpuStream_t& cu_stream) { + int threads_per_block = 128; + + int n_group_in = extent_y; + int n_size = extent_z; + + // Calculate and allocate temporary space + std::size_t temp_storage_bytes = 0; + // A plane's size is n_group_in * n_size. We make sure no single plane crosses + // more than one thread block, meaning a thread block will handle one whole + // plane or multiple planes in the second stage. Also, It may handle a partial + // plane when n_size is too large and the while-loop will stop at + // n_group_in = 1, where we directly copy the temp to output in the next + // stage. 
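The staged kernels driven by the loop that follows repeatedly halve the middle dimension through a temporary buffer. For reference, the result they compute can be written as one naive CPU loop over an input viewed as [extent_x, extent_y, extent_z] reduced over its middle (Y) axis; a sketch under my own naming, not TensorFlow's.

// Editor's reference sketch of the semantics of the staged Y-axis reduction.
#include <vector>

template <typename T, typename Op>
std::vector<T> ReduceMiddleAxis(const std::vector<T>& in, int extent_x,
                                int extent_y, int extent_z, Op op, T init) {
  std::vector<T> out(static_cast<size_t>(extent_x) * extent_z, init);
  for (int x = 0; x < extent_x; ++x)
    for (int y = 0; y < extent_y; ++y)
      for (int z = 0; z < extent_z; ++z)
        out[x * extent_z + z] = op(out[x * extent_z + z],
                                   in[(x * extent_y + y) * extent_z + z]);
  return out;
}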
+ while (n_group_in >= 2 && n_group_in * n_size > threads_per_block) { + int n_group_out = std::max(1, n_group_in / (2 * kUnroll)); + temp_storage_bytes += n_group_out * n_size; + n_group_in = n_group_out; + } + temp_storage_bytes *= extent_x * sizeof(T); + Tensor temp_storage; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + + // Reduction + n_group_in = extent_y; + int temp_in_offset = -1; + int temp_out_offset = 0; + int num_blocks; + while (n_group_in >= 2 && n_group_in * n_size > threads_per_block) { + int n_group_out = std::max(1, n_group_in / (2 * kUnroll)); + num_blocks = + Eigen::divup(extent_x * n_group_out * n_size, threads_per_block); + TF_CHECK_OK(GpuLaunchKernel( + ColumnReduceInToTempKernel, num_blocks, threads_per_block, 0, + cu_stream, (void*)(temp_storage.flat().data()), temp_in_offset, + temp_out_offset, in, extent_x, n_group_in, extent_z, op)); + + n_group_in = n_group_out; + temp_in_offset = temp_out_offset; + temp_out_offset = temp_in_offset + extent_x * n_group_out * n_size; + } + + if (n_group_in * n_size <= threads_per_block) { + num_blocks = extent_x; + } else { + DCHECK_EQ(1, n_group_in); + num_blocks = Eigen::divup(extent_x * n_size, threads_per_block); + } + + TF_CHECK_OK(GpuLaunchKernel( + ColumnReduceTempToOutKernel, num_blocks, + threads_per_block, 2 * sizeof(T) * threads_per_block, cu_stream, + (void*)(temp_storage.flat().data()), temp_in_offset, in, out, + extent_x, n_group_in, extent_z, op)); +} + +template +void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x, + int extent_y, int extent_z, Op op, T init, + const gpuStream_t& cu_stream) { + // setup segment offsets with counting and transform iterator + RowOffset row_offset_op(extent_x * extent_z); + gpuprim::CountingInputIterator counting_iter(0); + gpuprim::TransformInputIterator> + transform_iter(counting_iter, row_offset_op); + + GatherOp gather_op(extent_x, extent_y, extent_z, false); + typedef gpuprim::TransformInputIterator> + gatherIterType; + gatherIterType gather_iter(counting_iter, gather_op); + + PermutationInputIterator permute_iter(in, + gather_iter); + + std::size_t temp_storage_bytes = 0; + auto reduce = [&](void* temp_storage_ptr) { + auto success = gpuprim::DeviceSegmentedReduce::Reduce( + temp_storage_ptr, temp_storage_bytes, permute_iter, out, extent_y, + transform_iter, transform_iter + 1, op, init, cu_stream); + + OP_REQUIRES(ctx, success == 0, + errors::Internal("CUB segmented reduce error", + GpuGetErrorString(success))); + }; + + reduce(nullptr); // Get required amount of temp storage. + + Tensor temp_storage; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + + reduce(temp_storage.flat().data()); // Do reduction. 
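The GatherOp / PermutationInputIterator combination in Launch3DXZReduction turns the axes-{0,2} reduction into a segmented reduction: with kOne == false, every element sharing a Y index is mapped into one contiguous segment of length extent_x * extent_z. A small standalone CPU check of that index mapping; the program and its names are mine.

// Editor's sketch: summing each segment produced by the GatherOp mapping
// reproduces a reduction of a [X, Y, Z] tensor over axes 0 and 2.
#include <cstdio>
#include <vector>

int main() {
  const int X = 2, Y = 3, Z = 4;
  std::vector<float> in(X * Y * Z);
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<float>(i);

  const int group_size = X * Z;  // the kOne == false path of GatherOp
  std::vector<float> out(Y, 0.f);
  for (int ind = 0; ind < X * Y * Z; ++ind) {
    const int group = ind % group_size;   // enumerates the (x, z) pairs
    const int offset = ind / group_size;  // the Y index, i.e. the segment id
    const int x = group / Z;
    const int z = group % Z;
    // Same address arithmetic as GatherOp: x*Y*Z + offset*Z + z.
    out[offset] += in[x * Y * Z + offset * Z + z];
  }
  for (int y = 0; y < Y; ++y) std::printf("out[%d] = %f\n", y, out[y]);
}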
+} + +namespace reduction_op_helper { + +template +struct IsSum { + constexpr static bool value = + (std::is_same::value || + std::is_same>::value || + std::is_same>::value); +}; + +template +struct IsMax { + constexpr static bool value = + (std::is_same::value || + std::is_same::value || + std::is_same< + Op, Eigen::internal::MaxReducer>::value); +}; + +template +struct IsMin { + constexpr static bool value = + (std::is_same::value || + std::is_same::value || + std::is_same< + Op, Eigen::internal::MinReducer>::value); +}; + +template +struct IsProd { + constexpr static bool value = + (std::is_same>::value || + std::is_same>::value); +}; + +template +struct IdentityValue { + static_assert(IsSum::value || IsMax::value || + IsMin::value || IsProd::value || + std::is_same::value || std::is_same::value, + "IdentityValue not yet defined for this type"); + + template + U operator()( + typename std::enable_if::value, U>::type t = U(0)) { + return t; + } + + template + U operator()(typename std::enable_if::value, U>::type t = + Eigen::NumTraits::lowest()) { + return t; + } + + template + U operator()(typename std::enable_if::value, U>::type t = + Eigen::NumTraits::highest()) { + return t; + } + + template + U operator()( + typename std::enable_if::value, U>::type t = U(1)) { + return t; + } + + template + U operator()(typename std::enable_if::value, + bool>::type t = true) { + return t; + } + + template + U operator()(typename std::enable_if::value, + bool>::type t = false) { + return t; + } +}; + +} // namespace reduction_op_helper + +template +void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank, + int in_dim0, int in_dim1, int in_dim2, int out_rank, + const ReductionAxes& reduction_axes, Op op) { + T init = reduction_op_helper::IdentityValue()(); + const gpuStream_t& cu_stream = GetGpuStream(ctx); + if (out_rank == 0) { + const int in_size = in_dim0 * in_dim1 * in_dim2; + LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream); + } else if (in_rank == 2 && out_rank == 1 && + reduction_axes[0] == 1) { // row reduction + LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream); + } else if (in_rank == 2 && out_rank == 1 && + reduction_axes[0] == 0) { // column reduction + LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream); + } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) { + int elems_per_thread = in_dim1 / (in_dim0 * in_dim2); + if (elems_per_thread >= 16) { + Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init, + cu_stream); + } else { + Launch3DYReductionSimple(ctx, out, in, in_dim0, in_dim1, in_dim2, op, + init, cu_stream); + } + } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 && + reduction_axes[1] == 2) { + Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init, + cu_stream); + } else { + std::stringstream ss; + ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank + << " " << out_rank; + if (out_rank == 1) ss << " " << reduction_axes[0]; + if (out_rank == 2) ss << " " << reduction_axes[1]; + LOG(FATAL) << ss.str(); + } +} + +template +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer); +}; + +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::SumReducer& reducer) { + ReduceImpl, T*, T*, ReductionAxes>( + 
ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + Sum()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::SumReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +// Specialization for bfloat16 with fp32 accumulation. +template <> +struct ReduceFunctor> { + template + static void Reduce( + OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::SumReducer& reducer) { + typedef gpuprim::TransformInputIterator, + Eigen::bfloat16*> + inputIterType; + inputIterType input_itr((Eigen::bfloat16*)in.data(), + HalfToFloat()); + + typedef TransformOutputIterator> + outputIterType; + outputIterType itr((Eigen::bfloat16*)out.data(), + FloatToHalf()); + + ReduceImpl(ctx, itr, input_itr, in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), + reduction_axes, gpuprim::Sum()); + } + + template + static void FillIdentity( + const GPUDevice& d, OUT_T out, + const Eigen::internal::SumReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +// TODO(rmlarsen): Specialize for float16. +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::EuclideanNormReducer& reducer) { + typedef gpuprim::TransformInputIterator, T*> inputIterType; + inputIterType input_itr((T*)in.data(), Square()); + typedef TransformOutputIterator> outputIterType; + outputIterType output_itr((T*)out.data(), SqrtOfReal()); + ReduceImpl, outputIterType, inputIterType, ReductionAxes>( + ctx, output_itr, input_itr, in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + Sum()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const functor::EuclideanNormReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::MeanReducer& reducer) { + int divisor = 1; + if (out.rank() == 0) + divisor = in.size(); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0) + divisor = in.dimension(0); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1) + divisor = in.dimension(1); + else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 && + reduction_axes[1] == 2) + divisor = in.dimension(0) * in.dimension(2); + else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1) + divisor = in.dimension(1); + + DividesBy div_op(static_cast(divisor)); + TransformOutputIterator> itr((T*)out.data(), div_op); + ReduceImpl, TransformOutputIterator>, T*, + ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(), + in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? 
in.dimension(2) : 1, out.rank(), + reduction_axes, Sum()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const functor::MeanReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template +void ReduceMeanWithFloatAccumulationImpl( + OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::MeanReducer& reducer) { + float divisor = 1.f; + if (out.rank() == 0) + divisor = in.size(); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0) + divisor = in.dimension(0); + else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1) + divisor = in.dimension(1); + else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 && + reduction_axes[1] == 2) + divisor = in.dimension(0) * in.dimension(2); + else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1) + divisor = in.dimension(1); + DividesBy div_op(divisor); + + typedef gpuprim::TransformInputIterator, T*> + inputIterType; + inputIterType input_itr((T*)in.data(), HalfToFloat()); + + typedef TransformOutputIterator> outputIterType; + outputIterType itr((T*)out.data(), div_op); + + ReduceImpl( + ctx, itr, input_itr, in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + gpuprim::Sum()); +} + +template <> +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::MeanReducer& reducer) { + ReduceMeanWithFloatAccumulationImpl(ctx, out, in, reduction_axes, reducer); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const functor::MeanReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template <> +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::MeanReducer& reducer) { + ReduceMeanWithFloatAccumulationImpl(ctx, out, in, reduction_axes, reducer); + } + + template + static void FillIdentity( + const GPUDevice& d, OUT_T out, + const functor::MeanReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template +struct ReduceFunctor> { + template + static void Reduce( + OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MaxReducer& reducer) { + ReduceImpl( + ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + MaxPropagateNaN()); + } + + template + static void FillIdentity( + const GPUDevice& d, OUT_T out, + const Eigen::internal::MaxReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template +struct ReduceFunctor> { + template + static void Reduce( + OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MinReducer& reducer) { + ReduceImpl( + ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? 
in.dimension(2) : 1, out.rank(), reduction_axes, + MinPropagateNaN()); + } + + template + static void FillIdentity( + const GPUDevice& d, OUT_T out, + const Eigen::internal::MinReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template +struct ReduceFunctor> { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::ProdReducer& reducer) { + ReduceImpl, T*, T*, ReductionAxes>( + ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + Prod()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::ProdReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template <> +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::AndReducer& reducer) { + ReduceImpl( + ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, + And()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::AndReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template <> +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::OrReducer& reducer) { + ReduceImpl( + ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0), + in.rank() >= 2 ? in.dimension(1) : 1, + in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or()); + } + + template + static void FillIdentity(const GPUDevice& d, OUT_T out, + const Eigen::internal::OrReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops.h new file mode 100644 index 00000000..510fbc93 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops.h @@ -0,0 +1,207 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ + +// Functor definitions for Reduction ops, must be compilable by nvcc. 
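The functors declared in this header ultimately delegate to Eigen's Tensor reduce(). For orientation, a minimal standalone example of an axis reduction on a plain Eigen tensor, spelled both with the sum() shorthand and with an explicit SumReducer as ReduceEigenImpl does below; the example is mine, not part of the header.

// Editor's orientation sketch: reducing an Eigen Tensor over one axis.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> m(2, 3);
  m.setValues({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});

  // Reduce over dimension 0, leaving a length-3 tensor of column sums.
  Eigen::array<Eigen::Index, 1> reduce_dim{{0}};
  Eigen::Tensor<float, 1> col_sums = m.sum(reduce_dim);

  // The same reduction spelled with an explicit reducer object.
  Eigen::internal::SumReducer<float> reducer;
  Eigen::Tensor<float, 1> col_sums2 = m.reduce(reduce_dim, reducer);

  std::cout << col_sums << "\n" << col_sums2 << "\n";  // 5 7 9, twice
}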
+ +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct ReducerTraits { + enum { IsScalarIdentity = true }; +}; + +// Dummy class used for template specialization for mean reduction, which is +// accomplished by SumReducer and on-the-fly division by the reduction factor. +template +struct MeanReducer { + Scalar initialize() const { return Scalar(0); } +}; + +// Dummy class used for template specialization for l2-norm reduction. +template +struct EuclideanNormReducer { + Scalar initialize() const { return Scalar(0); } +}; + +template +struct ReducerTraits> { + enum { IsScalarIdentity = false }; +}; + +template +struct ReduceEigenImpl { + void operator()(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, const Reducer& reducer) { + out.device(d) = in.reduce(reduction_axes, reducer); + } +}; + +// Specialization for BF16 Reducer to fix accuracy. +// TODO: All BF16 reducers should have specializations to fix accuracy. +#define CASTING_SPECIALIZATION(Reducer, ScalarType, IntermediateType) \ + template \ + struct ReduceEigenImpl> { \ + void operator()(const Device& d, OUT_T out, IN_T in, \ + const ReductionAxes& reduction_axes, \ + const Reducer& reducer) { \ + static_assert(std::is_same::value, \ + ""); \ + Reducer intermediate_reducer; \ + auto in_as_intermediate = in.template cast(); \ + out.device(d) = \ + in_as_intermediate.reduce(reduction_axes, intermediate_reducer) \ + .template cast(); \ + } \ + }; + +CASTING_SPECIALIZATION(Eigen::internal::SumReducer, bfloat16, float); +#undef CASTING_SPECIALIZATION + +template +struct ReduceEigenImpl> { + void operator()(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::MeanReducer& reducer) { + static_assert(std::is_same::value, ""); + Eigen::internal::SumReducer sum_reducer; + out.device(d) = in.reduce(reduction_axes, sum_reducer) / + static_cast(in.size() / out.size()); + } +}; + +// Specialization for which we do the reduction in IntermediateType to +// avoid integer overflow and fix bfloat16 accuracy in some models. +#define CASTING_SPECIALIZATION(ScalarType, IntermediateType) \ + template \ + struct ReduceEigenImpl> { \ + void operator()(const Device& d, OUT_T out, IN_T in, \ + const ReductionAxes& reduction_axes, \ + const functor::MeanReducer& reducer) { \ + static_assert(std::is_same::value, \ + ""); \ + Eigen::internal::SumReducer sum_reducer; \ + out.device(d) = (in.template cast().reduce( \ + reduction_axes, sum_reducer) / \ + static_cast(in.size() / out.size())) \ + .template cast(); \ + } \ + } + +CASTING_SPECIALIZATION(uint8, uint64); +CASTING_SPECIALIZATION(uint16, uint64); +CASTING_SPECIALIZATION(uint32, uint64); +CASTING_SPECIALIZATION(int8, int64_t); +CASTING_SPECIALIZATION(int16, int64_t); +CASTING_SPECIALIZATION(int32, int64_t); +CASTING_SPECIALIZATION(bfloat16, float); +#undef CASTING_SPECIALIZATION + +// TODO(rmlarsen): Refactor this such that taking the sqrt can be optional +// controlled by an attribute. 
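The CASTING_SPECIALIZATION blocks above accumulate in a wider IntermediateType before dividing because summing narrow integers (or bfloat16) in their own type overflows or loses precision long before the division happens. A small illustration with uint8; the example is mine.

// Editor's illustration of why the mean specializations accumulate in a
// wider type: a uint8 running sum wraps modulo 256, so the mean is wrong.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<uint8_t> v(1000, 200);  // true mean is 200

  uint8_t narrow_sum = 0;
  uint64_t wide_sum = 0;
  for (uint8_t x : v) {
    narrow_sum = static_cast<uint8_t>(narrow_sum + x);  // wraps: 200000 % 256 = 64
    wide_sum += x;                                      // exact: 200000
  }

  std::printf("wrapped sum %u -> bogus mean %zu\n",
              static_cast<unsigned>(narrow_sum),
              static_cast<size_t>(narrow_sum) / v.size());  // 64 -> 0
  std::printf("wide sum %llu -> mean %llu\n",
              static_cast<unsigned long long>(wide_sum),
              static_cast<unsigned long long>(wide_sum / v.size()));  // 200
}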
+template +struct ReduceEigenImpl> { + void operator()(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::EuclideanNormReducer& reducer) { + static_assert(std::is_same::value, ""); + Eigen::internal::SumReducer sum_reducer; + out.device(d) = + (in * in.conjugate()).reduce(reduction_axes, sum_reducer).sqrt(); + } +}; + +template +struct ReduceEigenImpl> { + void operator()(const Device& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const functor::EuclideanNormReducer& reducer) { + static_assert(std::is_same::value, ""); + Eigen::internal::SumReducer sum_reducer; + auto in_as_float = in.template cast(); + out.device(d) = (in_as_float * in_as_float.conjugate()) + .reduce(reduction_axes, sum_reducer) + .sqrt() + .template cast(); + } +}; + +// For most reducers, the identity is Reducer::initialize() +template +struct Identity { + static auto identity(const Reducer& reducer) + -> decltype(reducer.initialize()) { + return reducer.initialize(); + } +}; + +// MeanReducer is a special case, since it doesn't technically have an identity. +// Thus, ideally we'd return nan. However, mean is instantiated for integer +// types as well, so we do the nan override only for floating point types. +#define FIX_MEAN_IDENTITY(T) \ + template <> \ + struct Identity> { \ + static T identity(const functor::MeanReducer&) { \ + return Eigen::NumTraits::quiet_NaN(); \ + } \ + }; +FIX_MEAN_IDENTITY(Eigen::half) +FIX_MEAN_IDENTITY(Eigen::bfloat16) +FIX_MEAN_IDENTITY(float) +FIX_MEAN_IDENTITY(double) +#undef FIX_MEAN_IDENTITY + +template +void FillIdentityEigenImpl(const Device& d, OUT_T out, const Reducer& reducer) { + MaybeWith32BitIndexing( + [&](auto out32) { + out32.device(d) = out32.constant(Identity::identity(reducer)); + }, + out); +} + +template +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer); + + template + static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops_common.h new file mode 100644 index 00000000..6ce777f7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops_common.h @@ -0,0 +1,279 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is an internal header file intended to only be included as the +// front-matter in the implementation files of various reduction ops. It +// is a header file because we split the various reduction ops into their +// own compilation units to get more parallelism in compilation. 
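For reference, these are the identity elements each reducer in the previous header falls back to when the reduced extent is empty, expressed with std::numeric_limits (whose lowest()/max() correspond to Eigen's NumTraits lowest()/highest()). The struct below is a sketch of mine mirroring Identity and FIX_MEAN_IDENTITY, not TensorFlow code.

// Editor's sketch: identity elements used to fill empty reductions.
#include <cstdio>
#include <limits>

template <typename T>
struct ReductionIdentities {
  static constexpr T sum = T(0);
  static constexpr T prod = T(1);
  static constexpr T max = std::numeric_limits<T>::lowest();  // identity for max
  static constexpr T min = std::numeric_limits<T>::max();     // identity for min
  // Mean has no true identity; for floating-point types TF yields NaN.
  static T mean_of_empty() { return std::numeric_limits<T>::quiet_NaN(); }
};

int main() {
  std::printf("sum %f, prod %f, max id %e, min id %e, empty mean %f\n",
              double(ReductionIdentities<float>::sum),
              double(ReductionIdentities<float>::prod),
              double(ReductionIdentities<float>::max),
              double(ReductionIdentities<float>::min),
              double(ReductionIdentities<float>::mean_of_empty()));
}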
+ +#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_ + +#define EIGEN_USE_THREADS + +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/reduction_ops.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template +struct Constants { + // Derive Index type. int (32-bit) or long (64-bit) depending on the + // compile-time configuration. "float" here is not relevant. + // TODO(zhifengc): Moves the definition to TTypes. + typedef TTypes::Tensor::Index Index; + Eigen::array kZero; + Eigen::array kOne; + Eigen::array kZeroTwo; + + Constants() { + kZero[0] = 0; + kOne[0] = 1; + kZeroTwo[0] = 0; + kZeroTwo[1] = 2; + } +}; + +struct ConstantsBase { + const Eigen::IndexList> kZero; + const Eigen::IndexList> kOne; + const Eigen::IndexList, Eigen::type2index<2>> kZeroTwo; +}; +template <> +struct Constants : ConstantsBase {}; + +class ReductionHelper { + public: + ReductionHelper() : reduce_first_axis_(false) {} + + absl::Status Simplify(const Tensor& data, const Tensor& axis, + const bool keep_dims); + + // We need to do roughly: + // tmp_out = allocate(out_reshape()) + // tmp_out.reshape(out_reshape) = data.reshape(data_reshape).reduce(axes) + // out = tmp_out.reshape(out_shape) + + // The reduction result must be allocated with this shape. + TensorShape out_reshape() const; + + // The final output shape must be allocated with this shape. + TensorShape out_shape() const; + + // The reduction is on a reshaped tensor of this rank. + int ndims() const { return data_reshape_.size(); } + + // True if need to reduce the 0-th dimension. + bool reduce_first_axis() const { return reduce_first_axis_; } + + // The output is reshaped. + template + typename TTypes::Tensor out(Tensor* out) { + return out->shaped(out_reshape_); + } + + // The input is reshaped. + template + typename TTypes::ConstTensor in(const Tensor& data) { + return data.shaped(data_reshape_); + } + + // Shape of shuffled input + TensorShape data_reshape() const { + TensorShape shape; + for (auto s : data_reshape_) shape.AddDim(s); + return shape; + } + + // Shape with all reduction dimensions at the end + TensorShape shuffled_shape(); + + // Permutation of reduced dims needed to put reduction dimensions at the end + absl::InlinedVector permutation(); + + private: + bool reduce_first_axis_; // True if need to reduce the 0-th dimension. + absl::InlinedVector + data_reshape_; // Reshape data before reduction. + absl::InlinedVector out_shape_; // The final output shape. + absl::InlinedVector + out_reshape_; // Reshape output for reduction. +}; + +// For operations where the output is a reduction function along some +// dimensions of the input. 
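ReductionHelper::Simplify, declared further down in this header, collapses the input into at most three logical dimensions by merging adjacent axes that are uniformly reduced or uniformly kept, which is why the dispatch in ReductionOp only needs the rank-0-to-3 cases. A sketch of that collapsing step; names and the return type are mine, not TensorFlow's.

// Editor's sketch: merge runs of adjacent dimensions that are all reduced
// or all kept, so an arbitrary-rank reduction becomes a small-rank one.
#include <cstdint>
#include <vector>

struct CollapsedDims {
  std::vector<int64_t> dims;       // merged extents, alternating reduced/kept runs
  bool first_dim_reduced = false;  // whether dims[0] is a reduced run
};

CollapsedDims Collapse(const std::vector<int64_t>& shape,
                       const std::vector<bool>& reduced) {
  CollapsedDims out;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 1) continue;  // size-1 axes change nothing
    if (out.dims.empty()) {
      out.first_dim_reduced = reduced[i];
      out.dims.push_back(shape[i]);
      continue;
    }
    // Is the current run (the last entry of dims) a reduced run?
    const bool last_run_reduced = out.first_dim_reduced
                                      ? (out.dims.size() % 2 == 1)
                                      : (out.dims.size() % 2 == 0);
    if (reduced[i] == last_run_reduced) {
      out.dims.back() *= shape[i];  // extend the current run
    } else {
      out.dims.push_back(shape[i]);  // start a new run
    }
  }
  return out;
}
// E.g. shape [2,3,4,5] with reduced = {true,true,false,false} collapses to
// dims {6, 20} with the first (reduced) run of size 6.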
+template +class ReductionOp : public OpKernel { + public: + explicit ReductionOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + const DataType dt = DataTypeToEnum::v(); + const DataType pt = DataTypeToEnum::v(); + OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, pt}, {dt})); + + OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& data = ctx->input(0); + const Tensor& axes = ctx->input(1); + VLOG(1) << "data shape: " << data.shape().DebugString(); + VLOG(1) << "axes : " << axes.SummarizeValue(10); + + ReductionHelper helper; + OP_REQUIRES_OK(ctx, helper.Simplify(data, axes, keep_dims_)); + CHECK_GE(helper.ndims(), 0); + + bool is_scalar_identity = functor::ReducerTraits::IsScalarIdentity; + bool is_trivial = helper.ndims() == 0 || + (helper.ndims() == 1 && !helper.reduce_first_axis()); + if (is_scalar_identity && is_trivial) { + Tensor out; + // Special case. Reduces nothing and does not alter the input values. + if (!out.CopyFrom(data, helper.out_shape())) { + ctx->SetStatus(errors::Internal("Error during reduction copy.")); + } + ctx->set_output(0, out); + return; + } + + // We must allocate temp tensors using the same alloc attr as + // output(0) because it is returned as output(0) in the end. + const AllocatorAttributes alloc_attr = ctx->output_alloc_attr(0); + + Tensor tmp_out; + typedef functor::ReduceFunctor Functor; + Constants constants; + const Device& d = ctx->eigen_device(); + Reducer reducer; + + if (data.NumElements() > 0 && is_trivial && !is_scalar_identity) { + OP_REQUIRES_OK(ctx, ctx->allocate_temp(ctx->expected_output_dtype(0), + TensorShape({data.NumElements()}), + &tmp_out, alloc_attr)); + Functor::Reduce(ctx, tmp_out.flat(), + data.shaped({1, data.NumElements()}), + constants.kZero, reducer); + } else { + // A temporary tensor whose size matches the size of the reduced + // output. + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(ctx->expected_output_dtype(0), + helper.out_reshape(), &tmp_out, alloc_attr)); + + if (tmp_out.NumElements() == 0) { + // Nothing to do, fall through to final reshaping. + } else if (data.NumElements() == 0) { + // Degenerate reduction where the input is empty but the output is + // nonempty (thus tmp_out.NumElements() > 0), and we must fill the + // output with identity elements. Example: tf.reduce_sum(tf.zeros((0, + // 3)), [0]). Eigen sometimes crashes in this case, so we do it + // manually. + Functor::FillIdentity(d, tmp_out.flat(), reducer); + } else if ((helper.ndims() == 1) && helper.reduce_first_axis()) { + // Reduce to a scalar. + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), + constants.kZero, reducer); + } else if ((helper.ndims() == 2) && helper.reduce_first_axis()) { + // Can be viewed as a reduction of a matrix along 1st dimension. + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), + constants.kZero, reducer); + } else if ((helper.ndims() == 2) && !helper.reduce_first_axis()) { + // Can be viewed as a reduction of a matrix along 2nd dimension. + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), + constants.kOne, reducer); + } else if ((helper.ndims() == 3) && helper.reduce_first_axis()) { + // Can be viewed as a reduction of a 3D tensor along 1st and 3rd + // dimensions. + Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), + constants.kZeroTwo, reducer); + } else if ((helper.ndims() == 3) && !helper.reduce_first_axis()) { + // Can be viewed as a reduction of a 3D tensor along 2nd dimension. 
+ Functor::Reduce(ctx, helper.out(&tmp_out), helper.in(data), + constants.kOne, reducer); + } else { + // If we don't hit one of the cases above, transpose the data so that + // all reduced dimensions are last and reuse the 2-D -> 1-D case. + Tensor data_reshaped; + OP_REQUIRES(ctx, data_reshaped.CopyFrom(data, helper.data_reshape()), + errors::Internal("Error during reduction copy.")); + Tensor shuffled; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, + helper.shuffled_shape(), + &shuffled, alloc_attr)); + OP_REQUIRES_OK(ctx, DoTranspose(d, data_reshaped, helper.permutation(), + &shuffled)); + const int64_t unreduced = tmp_out.NumElements(); + const int64_t reduced = shuffled.NumElements() / unreduced; + const Tensor& const_shuffled = shuffled; + Functor::Reduce(ctx, tmp_out.flat(), + const_shuffled.shaped({unreduced, reduced}), + constants.kOne, reducer); + } + } + + // Set the real output using the contents of the reduction but the + // real expected output shape. The number of elements should + // match between the two shapes. + Tensor out; + OP_REQUIRES(ctx, out.CopyFrom(tmp_out, helper.out_shape()), + errors::Internal("Error during reduction copy.")); + ctx->set_output(0, out); + } + + private: + // True if the number of dimensions should be maintained. + bool keep_dims_; +}; + +namespace functor { + +template +struct ReduceFunctorBase { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer) { + const Device& d = ctx->eigen_device(); + ReduceEigenImpl reducer_impl; + reducer_impl(d, out, in, reduction_axes, reducer); + } + + template + static void FillIdentity(const Device& d, OUT_T out, const Reducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; + +template +struct ReduceFunctor + : ReduceFunctorBase {}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops_common_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops_common_gpu.h new file mode 100644 index 00000000..b7bdb07c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reduction_ops_common_gpu.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with GPU support +#endif + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct ReduceFunctor { + template + static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Reducer& reducer); + + template + static void FillIdentity(const Eigen::GpuDevice& d, OUT_T out, + const Reducer& reducer); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/redux_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/redux_functor.h new file mode 100644 index 00000000..41ab917a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/redux_functor.h @@ -0,0 +1,337 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ + +#define EIGEN_USE_THREADS + +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using CPUDevice = Eigen::ThreadPoolDevice; + +namespace functor { + +// Compute reduction over outer dimensions. +// Example: +// input: [D1, D2, ... , DN] +// -> +// output: [Di, ... , DN] where i belongs to set [1,N] +template +struct ReduceOuterDimensions { + ReduceOuterDimensions() {} + + template + void operator()(const CPUDevice& device, + const Eigen::DSizes& input_dims, + const Tensor& input, Tensor* output) const { + // Compute inner and outer dim after reshaping into 2d tensor. + const int num_output_dims = output->dims(); + auto output_dims = output->template flat().dimensions(); + + Eigen::Index inner_dim = 1, outer_dim = 1; + for (int i = 0; i < num_dims - num_output_dims; ++i) + outer_dim *= input_dims[i]; + for (int i = num_dims - num_output_dims; i < num_dims; ++i) + inner_dim *= input_dims[i]; + + if (1 == outer_dim) { + // Nothing to do but passing input to output. + output->template flat() = + input.template flat().template cast().reshape( + output_dims); + return; + } + + // Get device thread num. + const Eigen::Index num_threads = device.numThreads(); + + // If the inner dim parallelism is large enough + // TODO(ezhulenev): There seems to be no benefits in going this route. 
Check + // if this can be improved, or use better heuristic? + if (inner_dim > num_threads * 32) { + // Do not create more blocks than there are threads in a pool. + const Eigen::Index num_blocks = num_threads; + + // Block size along the outer dimension. + const Eigen::Index inner_block_size = Eigen::divup(inner_dim, num_blocks); + const InputT* input_data = input.template flat().data(); + + // Allocate temporary buffer for partial reductions. + Eigen::Tensor buffer( + {inner_dim}); + buffer.setZero(); + AccumT* buffer_data = buffer.data(); + + using Buffer = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + using Input = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + const auto compute = [inner_dim, outer_dim, num_blocks, inner_block_size, + input_data, buffer_data]( + Eigen::Index start, Eigen::Index limit) -> void { + DCHECK(start >= 0 && limit <= num_blocks); + Eigen::Index inner_dim_start = start * inner_block_size; + Eigen::Index inner_dim_limit = limit * inner_block_size; + inner_dim_limit = std::min(inner_dim, inner_dim_limit); + Eigen::Index my_job_len = inner_dim_limit - inner_dim_start; + + const InputT* my_job_start = input_data + inner_dim_start; + Buffer buf(buffer_data + inner_dim_start, my_job_len); + + for (Eigen::Index i = 0; i < outer_dim; ++i) { + auto in = Input(my_job_start + i * inner_dim, my_job_len); + auto cast = in.template cast(); + buf = Eigen::TensorCwiseBinaryOp(buf, cast); + } + }; + + // Compute cost of reducing a single block. + const Eigen::Index compute_size = outer_dim * inner_block_size; + const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT); + const Eigen::TensorOpCost cost( + compute_input_bytes, + 0, // We'll be mostly writing to L1, assume store cost is 0 + compute_size * Eigen::internal::functor_traits::Cost); + + device.parallelFor(num_blocks, cost, compute); + + // Write final result to the output. + output->template flat() = + buffer.template cast().reshape(output_dims); + } else { + // Compute block size along the outer dimension for efficiency. + const Eigen::Index parallel_cell_size = inner_dim; + const Eigen::Index total_workload = outer_dim * inner_dim; + const Eigen::Index max_parallelism = total_workload / parallel_cell_size; + + const Eigen::Index min_block_workload = 2000; + const Eigen::Index min_block_size = + Eigen::divup(min_block_workload, parallel_cell_size); + const Eigen::Index max_num_blocks = std::min( + max_parallelism, Eigen::divup(total_workload, min_block_size)); + + // Do not create more blocks than there are threads in a pool. + const Eigen::Index num_blocks = std::min(max_num_blocks, num_threads); + + // Block size along the outer dimension. + const Eigen::Index outer_block_size = Eigen::divup(outer_dim, num_blocks); + + const InputT* input_data = input.template flat().data(); + + // Allocate temporary buffer for partial reductions. 
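As an illustration, stripped of the thread-pool blocking above and below, ReduceOuterDimensions performs the computation in this hypothetical sketch (not part of the vendored header): view the input as {outer, inner} and fold the leading dimension away.

#include "unsupported/Eigen/CXX11/Tensor"

// out[k] = sum over o of in[o, k], where o runs over d0*d1 and k over d2.
inline Eigen::Tensor<float, 1> ReduceOuterSketch(
    const Eigen::Tensor<float, 3>& t) {
  Eigen::array<Eigen::Index, 2> two_d{
      {t.dimension(0) * t.dimension(1), t.dimension(2)}};
  Eigen::array<int, 1> outer_axis{{0}};
  return t.reshape(two_d).sum(outer_axis);  // one value per inner element
}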
+ Tensor buffer(DataTypeToEnum::v(), {num_blocks, inner_dim}); + buffer.template flat().setZero(); + AccumT* buffer_data = buffer.template flat().data(); + + using Buffer = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + using Input = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + + const auto compute = [inner_dim, num_blocks, outer_block_size, + buffer_data, input_data, outer_dim]( + Eigen::Index start, Eigen::Index limit) -> void { + DCHECK(start >= 0 && limit <= num_blocks); + Eigen::Index outer_dim_start = start * outer_block_size; + Eigen::Index outer_dim_limit = limit * outer_block_size; + outer_dim_limit = std::min(outer_dim, outer_dim_limit); + + Buffer buf(buffer_data + start * inner_dim, inner_dim); + for (Eigen::Index i = outer_dim_start; i < outer_dim_limit; ++i) { + auto in = Input(input_data + i * inner_dim, inner_dim); + auto cast = in.template cast(); + buf = Eigen::TensorCwiseBinaryOp(buf, cast); + } + }; + + // Compute cost of reducing a single block. + const Eigen::Index compute_size = outer_block_size * inner_dim; + const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT); + const Eigen::TensorOpCost cost( + compute_input_bytes, + 0, // We'll be mostly writing to L1, assume store cost is 0 + compute_size * Eigen::internal::functor_traits::Cost); + + device.parallelFor(num_blocks, cost, compute); + + // Aggregate partial results from temporary buffer into first block. + auto buf0 = Buffer(buffer_data, inner_dim); + // Just sum the buffer up, as inner dimensions is not large in this case. + for (int i = 1; i < num_blocks; ++i) { + auto buf = Buffer(buffer_data + i * inner_dim, inner_dim); + buf0 = Eigen::TensorCwiseBinaryOp(buf0, buf); + } + // Write final result to the output. + output->template flat() = + buf0.template cast().reshape(output_dims); + } + } +}; + +// Compute reduction to some serial middle dimensions (like a axis). +// Example: +// input: [D1, D2, ... , DN] +// -> +// output: [Di, ... , Dj] where i & j belongs to set [1,N]. +template +struct ReduceMiddleDimensions { + ReduceMiddleDimensions() {} + + template + void operator()(const CPUDevice& device, + const Eigen::DSizes& input_dims, + const Tensor& input, Tensor* output, + const int axis_begin_dim) const { + // Compute dims after reshaping into 3d tensor. + const int num_output_dims = output->dims(); + auto output_dims = output->template flat().dimensions(); + + Eigen::Index inner_dim = 1, middle_dim = 1, outer_dim = 1; + for (int i = 0; i < axis_begin_dim; ++i) outer_dim *= input_dims[i]; + for (int i = axis_begin_dim; i < axis_begin_dim + num_output_dims; ++i) + middle_dim *= input_dims[i]; + for (int i = axis_begin_dim + num_output_dims; i < num_dims; ++i) + inner_dim *= input_dims[i]; + + if ((1 == inner_dim * outer_dim)) { + // Nothing to do. + output->template flat() = + input.template flat().template cast().reshape( + output_dims); + return; + } + + // Compute block size along the outer dimension for efficiency. + const Eigen::Index parallel_cell_size = inner_dim; + const Eigen::Index max_parallelism = outer_dim * middle_dim; + const Eigen::Index total_workload = max_parallelism * inner_dim; + + const Eigen::Index min_block_workload = 2000; + const Eigen::Index min_block_size = + Eigen::divup(min_block_workload, parallel_cell_size); + const Eigen::Index max_num_blocks = + std::min(max_parallelism, Eigen::divup(total_workload, min_block_size)); + + // Do not create more blocks than there are threads in a pool. 
+ const Eigen::Index num_threads = device.numThreads(); + const Eigen::Index num_blocks = std::min(max_num_blocks, num_threads); + + // Block size along the outer dimension. + const Eigen::Index outer_block_size = + Eigen::divup(total_workload, num_blocks); + + const InputT* input_data = input.template flat().data(); + + // Allocate temporary buffer for partial reductions. + Eigen::Tensor buffer(num_blocks, middle_dim); + buffer.setZero(); + AccumT* buffer_data = buffer.data(); + + using Buffer = Eigen::TensorMap>; + using Input = Eigen::TensorMap>; + + Eigen::array reduction_axis = {0}; + Reducer reducer; + const BinaryFunctor binary_op; + + const auto compute = [inner_dim, middle_dim, input_data, buffer_data, + total_workload, num_blocks, outer_block_size, + reduction_axis, reducer, binary_op]( + Eigen::Index start, Eigen::Index limit) -> void { + DCHECK(start >= 0 && limit <= num_blocks); + Eigen::Index block_start = start * outer_block_size; + Eigen::Index block_limit = limit * outer_block_size; + block_limit = std::min(total_workload, block_limit); + Buffer buf(buffer_data + start * middle_dim, middle_dim); + + const int align_start = + ((block_start + inner_dim - 1) / inner_dim) * inner_dim; + const int align_end = (block_limit / inner_dim) * inner_dim; + + Eigen::Index coordinate = block_start / inner_dim % middle_dim; + Eigen::Tensor reduced = + Input(&input_data[block_start], align_start - block_start) + .reduce(reduction_axis, reducer) + .template cast(); + + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + + coordinate = align_start / inner_dim % middle_dim; + for (int i = align_start; i < align_end; i += inner_dim) { + reduced = Input(&input_data[i], inner_dim) + .reduce(reduction_axis, reducer) + .template cast(); + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + ++coordinate; + if (middle_dim == coordinate) coordinate = 0; + } + + reduced = Input(&input_data[align_end], block_limit - align_end) + .reduce(reduction_axis, reducer) + .template cast(); + buf(coordinate) = binary_op(buf(coordinate), reduced(0)); + }; + + // Compute cost of reducing a single block. + const Eigen::Index compute_size = outer_block_size * inner_dim; + const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT); + const Eigen::TensorOpCost cost( + compute_input_bytes, + 0, // We'll be mostly writing to L1, assume store cost is 0 + compute_size * Eigen::internal::functor_traits::Cost); + + device.parallelFor(num_blocks, cost, compute); + + using Output = Eigen::TensorMap< + Eigen::Tensor, + Eigen::Unaligned>; + // Aggregate partial results from temporary buffer into first block. + auto buf0 = Output(buffer_data, middle_dim); + // TODO(ezhulenev): Parallelize this loop for large inner dimensions? + for (int i = 1; i < num_blocks; ++i) { + auto buf = Output(buffer_data + i * middle_dim, middle_dim); + buf0 = Eigen::TensorCwiseBinaryOp(buf0, buf); + } + + // Write final result to the output. + output->template flat() = + buf0.template cast().reshape(output_dims); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REDUX_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reference_gemm.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reference_gemm.h new file mode 100644 index 00000000..9d0bb60e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reference_gemm.h @@ -0,0 +1,96 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
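For reference, the alignment and partial-buffer bookkeeping in ReduceMiddleDimensions above all serves a single computation, made explicit by this hypothetical unblocked sketch (not part of the vendored header): keep the middle axis and sum away the outer and inner ones.

#include "unsupported/Eigen/CXX11/Tensor"

// out[j] = sum over i, k of in[i, j, k]: what the blocked kernel computes for a
// rank-3 input with axis_begin_dim == 1 and a rank-1 output.
inline Eigen::Tensor<float, 1> ReduceMiddleSketch(
    const Eigen::Tensor<float, 3>& t) {
  Eigen::array<int, 2> outer_and_inner{{0, 2}};
  return t.sum(outer_and_inner);
}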
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ +#define TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ + +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/platform/types.h" + +// This is an unoptimized but debuggable implementation of the GEMM matrix +// multiply function, used to compare to faster but more opaque versions, or +// for bit depths or argument combinations that aren't supported by optimized +// code. +// It assumes the row-major convention used by TensorFlow, and implements +// C = A * B, like the standard BLAS GEMM interface. If the transpose flags are +// true, then the relevant matrix is treated as stored in column-major order. + +namespace tensorflow { +template +void ReferenceGemm(bool transpose_a, bool transpose_b, bool transpose_c, + size_t m, size_t n, size_t k, const T1* a, int32_t offset_a, + size_t lda, const T2* b, int32_t offset_b, size_t ldb, T3* c, + int32_t shift_c, int32_t offset_c, int32_t mult_c, + size_t ldc) { + int a_i_stride; + int a_l_stride; + if (transpose_a) { + a_i_stride = 1; + a_l_stride = lda; + } else { + a_i_stride = lda; + a_l_stride = 1; + } + int b_j_stride; + int b_l_stride; + if (transpose_b) { + b_j_stride = ldb; + b_l_stride = 1; + } else { + b_j_stride = 1; + b_l_stride = ldb; + } + int c_i_stride; + int c_j_stride; + if (transpose_c) { + c_i_stride = 1; + c_j_stride = ldc; + } else { + c_i_stride = ldc; + c_j_stride = 1; + } + + const int32_t highest = static_cast(Eigen::NumTraits::highest()); + const int32_t lowest = static_cast(Eigen::NumTraits::lowest()); + const int32_t rounding = (shift_c < 1) ? 0 : (1 << (shift_c - 1)); + + int i, j, l; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + int32_t total = 0; + for (l = 0; l < k; l++) { + const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); + const int32_t a_value = static_cast(a[a_index]) - offset_a; + const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); + const int32_t b_value = static_cast(b[b_index]) - offset_b; + total += (a_value * b_value); + } + const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); + int32_t output = ((((total + offset_c) * mult_c) + rounding) >> shift_c); + if (output > highest) { + output = highest; + } + if (output < lowest) { + output = lowest; + } + c[c_index] = static_cast(output); + } + } +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REFERENCE_GEMM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/relu_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/relu_op.h new file mode 100644 index 00000000..4b64a69f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/relu_op.h @@ -0,0 +1,283 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
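A small usage sketch for ReferenceGemm above, assuming the stripped template parameter list is <class T1, class T2, class T3> as in upstream TensorFlow. With all quantization offsets at zero, shift 0 and multiplier 1, the routine reduces to a plain row-major integer matmul; the helper name below is illustrative only.

#include <cstdint>
#include "tensorflow/core/kernels/reference_gemm.h"

inline void ReferenceGemmUsageSketch() {
  const uint8_t a[6] = {1, 2, 3, 4, 5, 6};     // A: 2x3 row-major, lda = 3
  const uint8_t b[6] = {7, 8, 9, 10, 11, 12};  // B: 3x2 row-major, ldb = 2
  int32_t c[4] = {0};                          // C: 2x2 row-major, ldc = 2
  tensorflow::ReferenceGemm<uint8_t, uint8_t, int32_t>(
      /*transpose_a=*/false, /*transpose_b=*/false, /*transpose_c=*/false,
      /*m=*/2, /*n=*/2, /*k=*/3,
      a, /*offset_a=*/0, /*lda=*/3,
      b, /*offset_b=*/0, /*ldb=*/2,
      c, /*shift_c=*/0, /*offset_c=*/0, /*mult_c=*/1, /*ldc=*/2);
  // c now holds {58, 64, 139, 154}, i.e. A * B.
}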
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. + +#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_ +#define TENSORFLOW_CORE_KERNELS_RELU_OP_H_ + +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/relu_op_functor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +template +class ReluOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Relu functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +// Out of line check to save code space (we have this code once, rather +// than once for every NDIMS * NumTypes * Num_different_relu_variants +// functions. +struct ReluHelpers { + static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g, + const Tensor& a) { + OP_REQUIRES(context, a.IsSameSize(g), + errors::InvalidArgument("g and a must be the same size")); + } + static bool ValidateSameSize(OpKernelContext* context, const Tensor& g, + const Tensor& a) { + ValidateSameSizeHelper(context, g, a); + return context->status().ok(); + } +}; + +template +class ReluGradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): either the inputs that were passed to ReluOp(), or its + // outputs (using either one yields the same result here). 
+ // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } +}; + +template +void ReluGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::ReluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); +} + +template +class Relu6Op : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Relu6 functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class Relu6GradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): inputs that were passed to Relu6Op() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } +}; + +template +void Relu6GradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::Relu6Grad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); +} + +template +class LeakyReluOp : public UnaryElementWiseOp> { + public: + explicit LeakyReluOp(OpKernelConstruction* context) + : UnaryElementWiseOp>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::LeakyRelu functor; + functor({context->eigen_device(), input.flat(), alpha_, + output->flat()}); + } + + private: + T alpha_; +}; + +template +class LeakyReluGradOp + : public BinaryElementWiseOp> { + public: + explicit LeakyReluGradOp(OpKernelConstruction* context) + : BinaryElementWiseOp>(context) { + float alpha_tmp; + OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp)); + alpha_ = T(alpha_tmp); + } + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, T alpha, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (inputs): either the inputs that were passed to LeakyReluOp(), or its + // outputs (using either one yields the same result here). 
+ // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, alpha_, output); + } + + private: + T alpha_; +}; + +template +void LeakyReluGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, + const Tensor& a, T alpha, + Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::LeakyReluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), alpha, + output->flat()); +}; + +template +class EluOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Elu functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class EluGradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (outputs): outputs of the EluOp() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } +}; + +template +void EluGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::EluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); +} + +template +class SeluOp : public UnaryElementWiseOp> { + public: + using UnaryElementWiseOp>::UnaryElementWiseOp; + + void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) { + functor::Selu functor; + functor(context->eigen_device(), input.flat(), + output->flat()); + } +}; + +template +class SeluGradOp : public BinaryElementWiseOp> { + public: + using BinaryElementWiseOp>::BinaryElementWiseOp; + + void OperateNoTemplate(OpKernelContext* context, const Tensor& g, + const Tensor& a, Tensor* output); + + // INPUTS: + // g (gradients): backpropagated gradients + // a (outputs): outputs of the SeluOp() + // OUTPUT: + // gradients to backprop + template + void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a, + Tensor* output) { + OperateNoTemplate(context, g, a, output); + } +}; + +template +void SeluGradOp::OperateNoTemplate(OpKernelContext* context, + const Tensor& g, const Tensor& a, + Tensor* output) { + if (!ReluHelpers::ValidateSameSize(context, g, a)) return; + functor::SeluGrad functor; + functor(context->eigen_device(), g.flat(), a.flat(), + output->flat()); +} + +} // namespace tensorflow + +#undef EIGEN_USE_THREADS + +#endif // TENSORFLOW_CORE_KERNELS_RELU_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/relu_op_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/relu_op_functor.h new file mode 100644 index 00000000..cacef949 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/relu_op_functor.h @@ -0,0 +1,215 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_ +// Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by ReluOp to do the computations. +template +struct Relu { + // Computes Relu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + activations.device(d) = + features.template cwiseMax(static_cast(0)); + } +}; + +// Functor used by ReluGradOp to do the computations. +template +struct ReluGrad { + // Computes ReluGrad backprops. + // + // gradients: gradients backpropagated to the Relu op. + // features: either the inputs that were passed to the Relu or, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the Relu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + // NOTE: When the activation is exactly zero, we do not propagate the + // associated gradient value. This allows the output of the Relu to be used, + // as well as its input. + backprops.device(d) = + gradients * (features > static_cast(0)).template cast(); + } +}; + +// Functor used by Relu6Op to do the computations. +template +struct Relu6 { + // Computes Relu6 activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + activations.device(d) = + features.template cwiseMax(static_cast(0)) + .template cwiseMin(static_cast(6)); + } +}; + +// Functor used by ReluGradOp to do the computations. +template +struct Relu6Grad { + // Computes Relu6Grad backprops. + // + // gradients: gradients backpropagated to the Relu6 op. + // features: inputs that where passed to the Relu6 op, or its outputs. + // backprops: gradients to backpropagate to the Relu6 inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + // NOTE: When the activation is exactly zero or six, we + // make sure not to propagate the associated gradient + // value. This allows "features" to be either the input or the output of + // the relu6. + backprops.device(d) = gradients * ((features > static_cast(0)) * + (features < static_cast(6))) + .template cast(); + } +}; + +// Functor used by LeakyReluOp to do the computations. +template +struct LeakyRelu { + // Computes LeakyRelu activation. + // + // features: any shape. + // activations: same shape as "features". 
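Per element, the two gradient functors above reduce to the scalar rules below (hypothetical reference helpers, useful for spot-checking a kernel). The strict inequalities are why "features" may be either the op's input or its output.

// Gradient passes through only on the open interval where the activation is
// not clipped; at exactly 0 (and 6 for Relu6) it is dropped.
inline float ReluGradRef(float grad, float feature) {
  return feature > 0.0f ? grad : 0.0f;
}
inline float Relu6GradRef(float grad, float feature) {
  return (feature > 0.0f && feature < 6.0f) ? grad : 0.0f;
}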
+ + // Need to bundle the args (to the LeakyRelu functor) within a struct + // Not doing so leads to Eigen kernel args not getting populated + // corretly for Eigen::half type (when building on the ROCM platform) + struct LeakyReluArgs { + const Device& d; + typename TTypes::ConstTensor features; + T alpha; + typename TTypes::Tensor activations; + }; + void operator()(LeakyReluArgs args) { + // Note that alpha might be > 1 or < 0, so we don't use cwiseMax here. + args.activations.device(args.d) = + (args.features > static_cast(0)) + .select(args.features, args.features * args.alpha); + } +}; + +// Functor used by LeakyReluGradOp to do the computations. +template +struct LeakyReluGrad { + // Computes LeakyReluGrad backprops. + // + // gradients: gradients backpropagated to the LeakyRelu op. + // features: either the inputs that were passed to the LeakyRelu or, or its + // outputs (using either one yields the same result here). + // backprops: gradients to backpropagate to the LeakyRelu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, T alpha, + typename TTypes::Tensor backprops) { + backprops.device(d) = + (features > static_cast(0)).select(gradients, gradients * alpha); + } +}; + +// Functor used by EluOp to do the computations. +template +struct Elu { + // Computes Elu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + // features.constant(?) + activations.device(d) = + (features < static_cast(0)) + .select(features.exp() - features.constant(static_cast(1)), + features); + } +}; + +// Functor used by EluGradOp to do the computations. +template +struct EluGrad { + // Computes EluGrad backprops. + // + // gradients: gradients backpropagated to the Elu op. + // activations: outputs of the Elu op. + // backprops: gradients to backpropagate to the Elu inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor activations, + typename TTypes::Tensor backprops) { + backprops.device(d) = + (activations < static_cast(0)) + .select((activations + static_cast(1)) * gradients, gradients); + } +}; + +// Functor used by SeluOp to do the computations. +template +struct Selu { + // Computes Selu activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + // features.constant(?) + const auto scale = static_cast(1.0507009873554804934193349852946); + const auto scale_alpha = static_cast(1.7580993408473768599402175208123); + const auto one = static_cast(1); + const auto zero = static_cast(0); + activations.device(d) = + (features < zero) + .select(scale_alpha * (features.exp() - features.constant(one)), + scale * features); + } +}; + +// Functor used by SeluGradOp to do the computations. +template +struct SeluGrad { + // Computes SeluGrad backprops. + // + // gradients: gradients backpropagated to the Selu op. + // activations: outputs of the Selu op. + // backprops: gradients to backpropagate to the Selu inputs. 
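Scalar reference versions of the remaining activation functors above, mirroring the Eigen expressions and using the same SELU constants as the header (hypothetical helpers, not part of the vendored file):

#include <cmath>

inline float LeakyReluRef(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;  // alpha may be > 1 or < 0, so no max()
}
inline float EluRef(float x) {
  return x < 0.0f ? std::exp(x) - 1.0f : x;
}
inline float SeluRef(float x) {
  const float scale = 1.0507009873554804934193349852946f;
  const float scale_alpha = 1.7580993408473768599402175208123f;
  return x < 0.0f ? scale_alpha * (std::exp(x) - 1.0f) : scale * x;
}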
+ void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor activations, + typename TTypes::Tensor backprops) { + const auto scale = static_cast(1.0507009873554804934193349852946); + const auto scale_alpha = static_cast(1.7580993408473768599402175208123); + backprops.device(d) = + (activations < static_cast(0)) + .select(gradients * (activations + scale_alpha), gradients * scale); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reshape_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reshape_op.h new file mode 100644 index 00000000..dd603374 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reshape_op.h @@ -0,0 +1,168 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/overflow.h" + +namespace tensorflow { + +// Note that this op is subclassed for QuantizedReshapeOp. +class ReshapeOp : public OpKernel { + public: + explicit ReshapeOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& sizes = context->input(1); + // Preliminary validation of sizes. + OP_REQUIRES( + context, + (TensorShapeUtils::IsVector(sizes.shape()) || + // TODO(rmlarsen): Disallow legacy use of scalars to represent shape. + TensorShapeUtils::IsScalar(sizes.shape())), + errors::InvalidArgument("sizes input must be 1-D, not ", + sizes.shape().DebugString())); + OP_REQUIRES( + context, sizes.NumElements() < TensorShape::MaxDimensions(), + errors::InvalidArgument("too many dimensions: must be < ", + TensorShape::MaxDimensions(), ", but received ", + sizes.NumElements())); + + // Compute the output shape. Determine product of specified + // dimensions, and find the index of the unspecified one. 
+ TensorShape shape; + int64_t product = 1; + int unknown_index = -1; + bool sizes_has_zero_dim; + switch (sizes.dtype()) { + case DT_INT32: + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); + break; + case DT_INT64: + OP_REQUIRES_OK(context, + ValidateSizes(sizes, &product, &unknown_index, + &shape, &sizes_has_zero_dim)); + break; + default: + context->CtxFailure(errors::InvalidArgument( + "desired shape must be a DT_INT32 or DT_INT64 vector, not a ", + DataTypeString(sizes.dtype()))); + return; + } + if (unknown_index != -1) { + int64_t input_num_elements = 1; + bool input_has_zero_dim = false; + for (int dim = 0; dim < input.dims(); dim++) { + // For zero dimension, we don't count it into `input_num_elements` + // unless `sizes` has no zero dimension, so we are still able to + // infer shapes for other dimensions. + if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) { + input_num_elements *= input.dim_size(dim); + } else { + input_has_zero_dim = true; + } + } + + const int64_t missing = input_num_elements / product; + if (!input_has_zero_dim) { + OP_REQUIRES( + context, product * missing == input_num_elements, + errors::InvalidArgument( + "Input to reshape is a tensor with ", input_num_elements, + " values, but the requested shape requires a multiple of ", + product)); + } + shape.set_dim(unknown_index, missing); + } + OP_REQUIRES(context, shape.num_elements() == input.NumElements(), + errors::InvalidArgument("Input to reshape is a tensor with ", + input.NumElements(), + " values, but the requested shape has ", + shape.num_elements())); + + // Actually produce the reshaped output. + Tensor output(input.dtype()); + CHECK(output.CopyFrom(input, shape)); + context->set_output(0, std::move(output)); + } + + bool IsExpensive() override { return false; } + + private: + template + absl::Status ValidateSizes(const Tensor& sizes, int64_t* product, + int* unknown_index, TensorShape* shape, + bool* has_zero_dim) { + *product = 1; + *unknown_index = -1; + *has_zero_dim = false; + const int64_t num_dims = sizes.NumElements(); + auto Svec = sizes.flat(); + for (int d = 0; d < num_dims; ++d) { + const Tshape size = Svec(d); + if (size == -1) { + if (*unknown_index != -1) { + return errors::InvalidArgument( + "Only one input size may be -1, not both ", *unknown_index, + " and ", d); + } + *unknown_index = d; + TF_RETURN_IF_ERROR(shape->AddDimWithStatus(1)); + } else if (size < 0) { + return errors::InvalidArgument("Size ", d, + " must be non-negative, not ", size); + } else if (size == 0) { + // We don't include zero-sized dimension in product, so that we can + // still calculate number of elements for non-zero-sized dimensions and + // therefore infer their shapes. 
+ TF_RETURN_IF_ERROR(shape->AddDimWithStatus(size)); + *has_zero_dim = true; + } else { + if (MultiplyWithoutOverflow(shape->num_elements(), size) < 0) { + string msg; + for (int ii = 0; ii < num_dims; ++ii) { + if (ii != 0) { + strings::StrAppend(&msg, ", "); + } + strings::StrAppend(&msg, Svec(ii)); + } + return errors::InvalidArgument("Shape [", msg, + "] has too many elements"); + } + TF_RETURN_IF_ERROR(shape->AddDimWithStatus(size)); + (*product) *= size; + } + } + return absl::OkStatus(); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reshape_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reshape_util.h new file mode 100644 index 00000000..1945712c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reshape_util.h @@ -0,0 +1,52 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class OpKernelContext; +class Tensor; + +// Reshapes the input indices and input shape to the target shape. +// Note: This template is explicitly instantiated for CPU and GPU devices. +template +void ReshapeSparseTensor(OpKernelContext *context, + const Tensor &input_indices_in, + const Tensor &input_shape_in, + const Tensor &target_shape_in, int output_indices_idx, + int output_shape_idx); + +namespace functor { + +template +struct ReshapeSparseTensorFunctor { + absl::Status operator()( + OpKernelContext *context, const TensorShape &input_shape, + const TensorShape &output_shape, + typename TTypes::ConstMatrix input_indices, + typename TTypes::Matrix output_indices) const; +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/resource_variable_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/resource_variable_ops.h new file mode 100644 index 00000000..1c8d7998 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/resource_variable_ops.h @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
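The -1 handling in ReshapeOp::Compute and ValidateSizes above boils down to the arithmetic in this hypothetical sketch, which ignores zero-sized dimensions and overflow checks for brevity. For a 24-element input and sizes {2, -1, 4}, the known product is 8, so the missing dimension is 3.

#include <cstdint>
#include <vector>

// Infer a single -1 entry so the total element count matches the input.
inline std::vector<int64_t> InferReshapeSketch(int64_t input_num_elements,
                                               std::vector<int64_t> sizes) {
  int64_t product = 1;
  int unknown_index = -1;
  for (int d = 0; d < static_cast<int>(sizes.size()); ++d) {
    if (sizes[d] == -1) unknown_index = d;
    else product *= sizes[d];
  }
  if (unknown_index != -1) sizes[unknown_index] = input_num_elements / product;
  return sizes;  // {2, -1, 4} with 24 elements -> {2, 3, 4}
}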
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_var.h" + +namespace tensorflow { + +class VarHandleOp : public OpKernel { + public: + explicit VarHandleOp(OpKernelConstruction* c); + void Compute(OpKernelContext* ctx) override; + const Tensor* const_tensor() const override { + return is_anonymous_ ? nullptr : &const_tensor_; + } + + private: + // Same fields as in ResourceHandleOp. + bool is_anonymous_; + string container_; + string name_; + string debug_name_; + Tensor const_tensor_; + + DtypeAndPartialTensorShape dtype_and_shape_; +}; + +class ReadVariableOp : public OpKernel { + public: + explicit ReadVariableOp(OpKernelConstruction* c); + void Compute(OpKernelContext* ctx) override; + + private: + DataType dtype_; +}; + +class ReadVariablesOp : public OpKernel { + public: + explicit ReadVariablesOp(OpKernelConstruction* c); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + + private: + DataTypeVector dtypes_; +}; + +class DestroyResourceOp : public OpKernel { + public: + explicit DestroyResourceOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + + private: + bool ignore_lookup_error_; +}; + +class DisableCopyOnReadOp : public OpKernel { + public: + explicit DisableCopyOnReadOp(OpKernelConstruction* c) : OpKernel(c) {} + void Compute(OpKernelContext* ctx) override; +}; + +template +class VariableShapeOp : public OpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* ctx) override { + core::RefCountPtr variable; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &variable)); + variable->mu()->lock_shared(); + TensorShape shape = variable->tensor()->shape(); + variable->mu()->unlock_shared(); + Tensor* output; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {shape.dims()}, &output)); + for (int i = 0; i < shape.dims(); ++i) { + output->flat()(i) = shape.dim_size(i); + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/resource_variable_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/resource_variable_util.h new file mode 100644 index 00000000..1222b4eb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/resource_variable_util.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
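VariableShapeOp above reads the shape with the variable's lock taken in shared mode, so concurrent readers are not serialized. A minimal standard-library illustration of that locking pattern (hypothetical type; the real code uses TensorFlow's mutex, not std::shared_mutex):

#include <shared_mutex>

struct SharedShapeReadSketch {
  mutable std::shared_mutex mu;
  int dims = 0;
  int Dims() const {
    std::shared_lock<std::shared_mutex> lock(mu);  // like lock_shared()/unlock_shared()
    return dims;
  }
};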
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_UTIL_H_ + +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +absl::Status ValidateAssignUpdateVariableOpShapes( + const TensorShape& variable_shape, const TensorShape& value_shape); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RESOURCE_VARIABLE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reverse_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reverse_op.h new file mode 100644 index 00000000..a2de766a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reverse_op.h @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by ReverseOp to do the computations. +template +struct Reverse { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::array& reverse_dims, + typename TTypes::Tensor output) { + output.device(d) = input.reverse(reverse_dims); + } +}; + +template +struct Reverse { + void operator()(const Device& d, typename TTypes::ConstTensor input, + const Eigen::array& reverse_dims, + typename TTypes::Tensor output) { + // Reversing a scalar is copying it. + output.device(d) = input; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/reverse_sequence_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/reverse_sequence_op.h new file mode 100644 index 00000000..f25794f3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/reverse_sequence_op.h @@ -0,0 +1,78 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_ +// Generator definition for ReverseSequenceOp, must be compilable by nvcc. 
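A minimal sketch of what the Reverse functor in reverse_op.h above evaluates when only the second axis of a rank-2 input is flagged in reverse_dims (hypothetical helper; the real functor is templated over Device, T and rank):

#include "unsupported/Eigen/CXX11/Tensor"

inline Eigen::Tensor<int, 2> ReverseColumnsSketch(const Eigen::Tensor<int, 2>& t) {
  Eigen::array<bool, 2> reverse_dims{{false, true}};  // flip axis 1 only
  return t.reverse(reverse_dims);
}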
+ +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace generator { + +template +class ReverseGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE ReverseGenerator( + typename TTypes::ConstTensor input, int32_t batch_dim, + int32_t seq_dim, typename TTypes::ConstVec seq_lengths) + : input_(input), + batch_dim_(batch_dim), + seq_dim_(seq_dim), + seq_lengths_(seq_lengths) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + operator()(const Eigen::array& coords) const { + Eigen::array new_coords = coords; + if (coords[seq_dim_] < seq_lengths_(coords[batch_dim_])) { + new_coords[seq_dim_] = + seq_lengths_(coords[batch_dim_]) - coords[seq_dim_] - 1; + } + + return input_(new_coords); + } + + private: + typename TTypes::ConstTensor input_; + int32 batch_dim_; + int32 seq_dim_; + typename TTypes::ConstVec seq_lengths_; +}; + +} // namespace generator + +namespace functor { + +template +struct ReverseSequence { + EIGEN_ALWAYS_INLINE static void Compute( + const Device& d, typename TTypes::ConstTensor input, + int32_t batch_dim, int32_t seq_dim, + typename TTypes::ConstVec seq_lengths, + typename TTypes::Tensor output) { + generator::ReverseGenerator generator(input, batch_dim, + seq_dim, seq_lengths); + output.device(d) = input.generate(generator); + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/blas_gemm.h b/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/blas_gemm.h new file mode 100644 index 00000000..dabacedd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/blas_gemm.h @@ -0,0 +1,97 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
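Per batch row, the ReverseGenerator above mirrors only the first seq_len elements along the sequence axis and leaves the padding untouched. A hypothetical scalar sketch over one [time] row (names illustrative only):

#include <algorithm>
#include <cstdint>
#include <vector>

inline std::vector<int> ReverseSequenceRowSketch(const std::vector<int>& row,
                                                 int64_t seq_len) {
  std::vector<int> out(row);
  const int64_t len =
      std::min<int64_t>(seq_len, static_cast<int64_t>(row.size()));
  for (int64_t t = 0; t < len; ++t) {
    out[t] = row[len - 1 - t];  // new_coords[seq_dim] = seq_len - t - 1
  }
  return out;  // {10, 20, 30, 40} with seq_len = 3 -> {30, 20, 10, 40}
}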
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RNN_BLAS_GEMM_H_ +#define TENSORFLOW_CORE_KERNELS_RNN_BLAS_GEMM_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_activations.h" +#include "tensorflow/core/platform/types.h" + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" +#endif + +namespace tensorflow { +class OpKernelContext; +namespace functor { + +template +struct TensorCuBlasGemm { + void operator()(OpKernelContext* ctx, bool transa, bool transb, uint64 m, + uint64 n, uint64 k, float alpha, const T* a, int lda, + const T* b, int ldb, float beta, T* c, int ldc); +}; + +template +struct gemm_compute_type { + typedef T type; +}; + +template <> +struct gemm_compute_type { + typedef float type; +}; + +template +struct TensorBlasGemm; + +template +struct TensorBlasGemm { + static void compute(OpKernelContext* ctx, const Device& d, bool transa, + bool transb, typename gemm_compute_type::type alpha, + typename TTypes::ConstMatrix a, + typename TTypes::ConstMatrix b, + typename gemm_compute_type::type beta, + typename TTypes::Matrix c) { + int64_t m = c.dimensions()[0]; + int64_t n = c.dimensions()[1]; + int64_t k = transa ? a.dimensions()[0] : a.dimensions()[1]; + + TensorCuBlasGemm()(ctx, transb, transa, n, m, k, alpha, b.data(), + transb ? k : n, a.data(), transa ? m : k, beta, + c.data(), n); + } +}; + +template +struct TensorBlasGemm { + static void compute(OpKernelContext* ctx, const Device& d, bool transa, + bool transb, typename gemm_compute_type::type alpha, + typename TTypes::ConstMatrix a, + typename TTypes::ConstMatrix b, + typename gemm_compute_type::type beta, + typename TTypes::Matrix c) { + Eigen::array, 1> contract_pairs; + contract_pairs[0] = + Eigen::IndexPair(transa == false, transb == true); + if (alpha == typename gemm_compute_type::type(1.f) && + beta == typename gemm_compute_type::type(0.f)) { + c.device(d) = a.contract(b, contract_pairs); + } else if (alpha == typename gemm_compute_type::type(1.f) && + beta == typename gemm_compute_type::type(1.f)) { + c.device(d) += a.contract(b, contract_pairs); + } else { + c.device(d) = c.constant(T(alpha)) * a.contract(b, contract_pairs) + + c.constant(T(beta)) * c; + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RNN_BLAS_GEMM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/gru_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/gru_ops.h new file mode 100644 index 00000000..8799401d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/gru_ops.h @@ -0,0 +1,189 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
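In the non-cuBLAS specialization above, the transpose flags are encoded in the contraction pair: with no transposes the pair is (1, 0), i.e. contract A's columns with B's rows, which is an ordinary matrix product. A hypothetical rank-2 float sketch of that path:

#include "unsupported/Eigen/CXX11/Tensor"

inline Eigen::Tensor<float, 2> ContractMatmulSketch(
    const Eigen::Tensor<float, 2>& a, const Eigen::Tensor<float, 2>& b) {
  Eigen::array<Eigen::IndexPair<int>, 1> pairs{{Eigen::IndexPair<int>(1, 0)}};
  return a.contract(b, pairs);  // c(i, j) = sum_k a(i, k) * b(k, j)
}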
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RNN_GRU_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_RNN_GRU_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/rnn/blas_gemm.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class OpKernelContext; + +namespace functor { + +struct GRUCell { + GRUCell(const int batch_size, const int input_size, const int cell_size) + : batch_size_(batch_size), + input_size_(input_size), + cell_size_(cell_size) {} + + inline Eigen::array x_offsets() const { return {0, 0}; } + + inline Eigen::array x_extends() const { + return {batch_size_, input_size_}; + } + + inline Eigen::array h_offsets() const { + return {0, input_size_}; + } + + inline Eigen::array h_extends() const { + return {batch_size_, cell_size_}; + } + + inline Eigen::array ru_r_offset() const { + return {0, 0}; + } + + inline Eigen::array ru_u_offset() const { + return {0, cell_size_}; + } + + inline Eigen::array cell_extents() const { + return {batch_size_, cell_size_}; + } + + protected: + const int batch_size_; + const int input_size_; + const int cell_size_; +}; + +template +struct GRUBlockCellFprop : public GRUCell { + GRUBlockCellFprop(const int batch_size, const int input_size, + const int cell_size) + : GRUCell(batch_size, input_size, cell_size) {} + + void operator()( + OpKernelContext* ctx, const Device& d, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix h_prev, + typename TTypes::ConstMatrix w_ru, typename TTypes::ConstMatrix w_c, + typename TTypes::ConstVec b_ru, typename TTypes::ConstVec b_c, + typename TTypes::Matrix r_u_bar, typename TTypes::Matrix r, + typename TTypes::Matrix u, typename TTypes::Matrix c, + typename TTypes::Matrix h, typename TTypes::Matrix x_h_prev, + typename TTypes::Matrix x_h_prevr) { + // Concat x_h_prev = [x, h_prev]. + x_h_prev.slice(x_offsets(), x_extends()).device(d) = x; + x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev; + + // r_u_bar = x_h_prev * w_ru + b_ru + typename TTypes::ConstMatrix const_x_h_prev(x_h_prev.data(), + x_h_prev.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, typename gemm_compute_type::type(1.f), + const_x_h_prev, w_ru, typename gemm_compute_type::type(0.f), + r_u_bar); + + // Creating a bias matrix for adding by broadcasting 'b_ru' + Eigen::array broadcast_shape({batch_size_, 1}); + Eigen::array b_ru_shape({1, b_ru.dimensions()[0]}); + r_u_bar.device(d) += b_ru.reshape(b_ru_shape).broadcast(broadcast_shape); + + // Slice r_u_bar into r, u and apply the sigmoid. + r.device(d) = (r_u_bar.slice(ru_r_offset(), cell_extents())).sigmoid(); + u.device(d) = (r_u_bar.slice(ru_u_offset(), cell_extents())).sigmoid(); + + // Concat x_h_prevr = [x,h_prev*r] + x_h_prevr.slice(x_offsets(), x_extends()).device(d) = x; + x_h_prevr.slice(h_offsets(), h_extends()).device(d) = h_prev * r; + + // c = tanh(x_h_prevr*w_c+b_c), Note b_c is broadcasted before adding. 
+ typename TTypes::ConstMatrix const_x_h_prevr(x_h_prevr.data(), + x_h_prevr.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, typename gemm_compute_type::type(1.f), + const_x_h_prevr, w_c, typename gemm_compute_type::type(0.f), c); + + Eigen::array b_c_shape({1, b_c.dimensions()[0]}); + c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape)); + c.device(d) = c.tanh(); + + // h= u*h_prev + (1-u)*c + h.device(d) = u * (h_prev - c) + c; + } +}; + +template +struct GRUBlockCellBprop : public GRUCell { + GRUBlockCellBprop(const int batch_size, const int input_size, + const int cell_size) + : GRUCell(batch_size, input_size, cell_size) {} + + void operator()( + OpKernelContext* ctx, const Device& d, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix h_prev, + typename TTypes::ConstMatrix w_ru, typename TTypes::ConstMatrix w_c, + typename TTypes::ConstVec b_ru, typename TTypes::ConstVec b_c, + typename TTypes::ConstMatrix r, typename TTypes::ConstMatrix u, + typename TTypes::ConstMatrix c, typename TTypes::ConstMatrix d_h, + typename TTypes::Matrix d_x, typename TTypes::Matrix d_h_prev, + typename TTypes::Matrix d_c_bar, + typename TTypes::Matrix d_r_bar_u_bar, + typename TTypes::Matrix d_r_bar, typename TTypes::Matrix d_u_bar, + typename TTypes::Matrix d_hr, + typename TTypes::Matrix d_x_comp1_and_h_prev_comp1, + typename TTypes::Matrix d_x_comp2_and_h_prevr) { + // d_c_bar = d_h*(1-u)*(1-(c*c)) + d_c_bar.device(d) = + ((d_h * (u.constant(T(1)) - u)) * (c.constant(T(1)) - c * c)); + + // d_u_bar = d_h*(h-c)*(u*(1-u)) + d_u_bar.device(d) = d_h * (h_prev - c) * u * (u.constant(T(1)) - u); + + // [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T + typename TTypes::ConstMatrix const_d_c_bar(d_c_bar.data(), + d_c_bar.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, true, typename gemm_compute_type::type(1.f), + const_d_c_bar, w_c, typename gemm_compute_type::type(0.f), + d_x_comp2_and_h_prevr); + + d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends()); + d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r); + + // d_r_bar_u_bar = concatenate(d_r_bar, d_u_bar) along axis = 1. + d_r_bar_u_bar.slice(ru_r_offset(), cell_extents()).device(d) = d_r_bar; + d_r_bar_u_bar.slice(ru_u_offset(), cell_extents()).device(d) = d_u_bar; + + // [1st_component_of_d_x 1st_component_of_d_h_prev] = [d_r_bar d_u_bar] X + // w_ru^T + typename TTypes::ConstMatrix const_d_r_bar_u_bar( + d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, true, typename gemm_compute_type::type(1.f), + const_d_r_bar_u_bar, w_ru, typename gemm_compute_type::type(0.f), + d_x_comp1_and_h_prev_comp1); + + // d_x = d_x_comp1 + d_x_comp2 + d_x.device(d) = (d_x_comp1_and_h_prev_comp1 + d_x_comp2_and_h_prevr) + .slice(x_offsets(), x_extends()); + + // d_h_prev = d_h_comp1 + d_hr*r + d_h*u + d_h_prev.device(d) = + d_x_comp1_and_h_prev_comp1.slice(h_offsets(), h_extends()) + + (d_hr * r) + (d_h * u); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RNN_GRU_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/lstm_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/lstm_ops.h new file mode 100644 index 00000000..f2457531 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/rnn/lstm_ops.h @@ -0,0 +1,308 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
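For a single input feature and a single unit, the gate algebra that GRUBlockCellFprop above implements with two GEMMs collapses to the scalar step in this hypothetical sketch (weight names are illustrative; sigmoid gates, tanh candidate). The last line shows that h = u*(h_prev - c) + c is the usual blend u*h_prev + (1 - u)*c with one multiply saved.

#include <cmath>

inline float GruCellStepSketch(float x, float h_prev,
                               float w_xr, float w_hr, float b_r,    // reset gate
                               float w_xu, float w_hu, float b_u,    // update gate
                               float w_xc, float w_hc, float b_c) {  // candidate
  const auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };
  const float r = sigmoid(w_xr * x + w_hr * h_prev + b_r);
  const float u = sigmoid(w_xu * x + w_hu * h_prev + b_u);
  const float c = std::tanh(w_xc * x + w_hc * (h_prev * r) + b_c);
  return u * (h_prev - c) + c;  // == u*h_prev + (1 - u)*c
}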
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_RNN_LSTM_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_RNN_LSTM_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_activations.h" +#include "tensorflow/core/kernels/rnn/blas_gemm.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class OpKernelContext; + +enum GateLayout { ICFO, IFCO }; + +constexpr int gate_c_offset(GateLayout gate_layout, int cell_size) { + return (gate_layout == ICFO) ? cell_size : cell_size * 2; +} + +constexpr int gate_f_offset(GateLayout gate_layout, int cell_size) { + return (gate_layout == ICFO) ? cell_size * 2 : cell_size; +} + +namespace functor { + +template +struct TensorZero { + void operator()(const Device& d, typename TTypes::Flat t) { + t.device(d) = t.constant(T(0)); + } +}; + +template +struct TensorUnalignedZero { + void operator()(const Device& d, typename TTypes::UnalignedFlat t) { + t.device(d) = t.constant(T(0)); + } +}; + +template +struct TensorCopy { + void operator()(const Device& d, typename TTypes::ConstFlat src, + typename TTypes::Flat dst) { + dst.device(d) = src; + } +}; + +template +struct TensorCopyUnaligned { + void operator()(const Device& d, typename TTypes::UnalignedConstFlat src, + typename TTypes::Flat dst) { + dst.device(d) = src; + } +}; + +template +struct TensorCopyToUnaligned { + void operator()(const Device& d, typename TTypes::ConstFlat src, + typename TTypes::UnalignedFlat dst) { + dst.device(d) = src; + } +}; + +template +struct TensorAdd { + void operator()(const Device& d, typename TTypes::ConstFlat a, + typename TTypes::ConstFlat b, typename TTypes::Flat c) { + c.device(d) = a + b; + } +}; + +template +struct TensorZeroPadding { + void operator()(const Device& d, const int64_t time_idx, + typename TTypes::ConstVec seq_len, + typename TTypes::Vec mask, typename TTypes::Matrix m) { + // mask is shape [batch_size]. + mask.device(d) = seq_len.constant(time_idx) < seq_len; + + // m_shape is [batch_size, 1]. + Eigen::array m_shape({m.dimensions()[0], 1}); + // broadcast_shape is [1, units]. + Eigen::array broadcast_shape({1, m.dimensions()[1]}); + + // m is shape [batch_size, units]. 
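The reshape/broadcast that TensorZeroPadding sets up above (and applies in the statement just below) is plain per-row masking: batch entries whose sequence has already ended at time_idx get their row zeroed. A minimal sketch on flat arrays, assuming row-major [batch_size, units] storage and illustrative names:

#include <cstdint>
#include <vector>

// Zero out rows for batch entries whose sequence length is exceeded at step
// time_idx: mask[b] = (time_idx < seq_len[b]), then m[b, :] *= mask[b].
void ZeroPadStep(int64_t time_idx, const std::vector<int64_t>& seq_len,
                 std::vector<float>& m, int units) {
  const int batch_size = static_cast<int>(seq_len.size());
  for (int b = 0; b < batch_size; ++b) {
    const float mask = (time_idx < seq_len[b]) ? 1.0f : 0.0f;
    for (int j = 0; j < units; ++j) m[b * units + j] *= mask;
  }
}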
+ m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape); + } +}; + +struct LSTMBlockCell { + LSTMBlockCell(const int batch_size, const int input_size, const int cell_size) + : batch_size_(batch_size), + input_size_(input_size), + cell_size_(cell_size) {} + + int batch_size() const { return batch_size_; } + + int input_size() const { return input_size_; } + + int cell_size() const { return cell_size_; } + + inline Eigen::array gates_i_offsets() const { + return {0, 0}; + } + + inline Eigen::array gates_c_offsets( + const GateLayout gate_layout) const { + return {0, gate_c_offset(gate_layout, cell_size_)}; + } + + inline Eigen::array gates_f_offsets( + const GateLayout gate_layout) const { + return {0, gate_f_offset(gate_layout, cell_size_)}; + } + + inline Eigen::array gates_o_offsets() const { + return {0, cell_size_ * 3}; + } + + inline Eigen::array cell_extents() const { + return {batch_size_, cell_size_}; + } + + inline Eigen::array xh_x_offsets() const { + return {0, 0}; + } + + inline Eigen::array xh_x_extents() const { + return {batch_size_, input_size_}; + } + + inline Eigen::array xh_h_offsets() const { + return {0, input_size_}; + } + + inline Eigen::array xh_h_extents() const { + return {batch_size_, cell_size_}; + } + + protected: + const int batch_size_; + const int input_size_; + const int cell_size_; +}; + +// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for +// GPUDevice implementation. +template +struct LSTMBlockCellFprop : public LSTMBlockCell { + LSTMBlockCellFprop(const int batch_size, const int input_size, + const int cell_size) + : LSTMBlockCell(batch_size, input_size, cell_size) {} + + void operator()(OpKernelContext* ctx, const Device& d, + const float forget_bias, const float cell_clip, + bool use_peephole, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, + typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, + typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, + typename TTypes::ConstVec b, typename TTypes::Matrix xh, + typename TTypes::Matrix i, typename TTypes::Matrix cs, + typename TTypes::Matrix f, typename TTypes::Matrix o, + typename TTypes::Matrix ci, typename TTypes::Matrix co, + typename TTypes::Matrix gates, + typename TTypes::Matrix h); +}; + +// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for +// GPUDevice implementation. 
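The two gate layouts above differ only in where the cell ('c') and forget ('f') blocks sit inside the fused [batch, 4*cell_size] gate matrix; the constexpr helpers encode exactly that. A small compile-time check of the offsets, assuming the declarations above are in scope:

// With cell_size == 8 the fused gate matrix is split into column blocks:
//   ICFO: i=[0,8)  c=[8,16)   f=[16,24)  o=[24,32)
//   IFCO: i=[0,8)  f=[8,16)   c=[16,24)  o=[24,32)
static_assert(tensorflow::gate_c_offset(tensorflow::ICFO, 8) == 8,  "ICFO: c follows i");
static_assert(tensorflow::gate_f_offset(tensorflow::ICFO, 8) == 16, "ICFO: f follows c");
static_assert(tensorflow::gate_c_offset(tensorflow::IFCO, 8) == 16, "IFCO: c follows f");
static_assert(tensorflow::gate_f_offset(tensorflow::IFCO, 8) == 8,  "IFCO: f follows i");
// The i block is always first and the o block always last (offset 3 * cell_size).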
+template +struct LSTMBlockCellBprop : public LSTMBlockCell { + LSTMBlockCellBprop(const int batch_size, const int input_size, + const int cell_size) + : LSTMBlockCell(batch_size, input_size, cell_size) {} + + void operator()( + OpKernelContext* ctx, const Device& d, bool use_peephole, + typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, + typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, + typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dgates, + typename TTypes::Matrix cs_prev_grad, typename TTypes::Vec wci_grad, + typename TTypes::Vec wcf_grad, typename TTypes::Vec wco_grad); +}; + +template +struct BlockLSTMBprop : public LSTMBlockCell { + BlockLSTMBprop(const int batch_size, const int input_size, + const int cell_size) + : LSTMBlockCell(batch_size, input_size, cell_size) {} + + void operator()( + OpKernelContext* ctx, const Device& d, bool use_peephole, + typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::Matrix xh, typename TTypes::ConstMatrix i, + typename TTypes::ConstMatrix cs, typename TTypes::ConstMatrix f, + typename TTypes::ConstMatrix o, typename TTypes::ConstMatrix ci, + typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dgates, + typename TTypes::Matrix cs_prev_grad, + typename TTypes::Matrix h_prev_grad, + typename TTypes::Matrix xh_grad, typename TTypes::Matrix x_grad, + typename TTypes::Matrix w_grad, typename TTypes::Vec wci_grad, + typename TTypes::Vec wcf_grad, typename TTypes::Vec wco_grad, + typename TTypes::Vec b_grad) { + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; + + Eigen::array p_shape({1, cell_size_}); + Eigen::array p_broadcast_shape({batch_size_, 1}); + if (use_peephole) { + dcs.device(d) = + dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; + + dgates.slice(gates_i_offsets(), cell_extents()).device(d) = di; + dgates.slice(gates_c_offsets(gate_layout), cell_extents()).device(d) = dci; + dgates.slice(gates_f_offsets(gate_layout), cell_extents()).device(d) = df; + dgates.slice(gates_o_offsets(), cell_extents()).device(d) = do_; + 
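Element-wise, the gate gradients computed above reduce to a handful of scalar formulas. A per-element sketch that mirrors the comments (peephole terms omitted), useful only as a cross-check against the Eigen expressions, not as a replacement for them:

#include <cmath>

// Per-element LSTM backprop on post-activation values: sigm'(z) = z*(1-z),
// tanh'(z) = 1 - z*z.  Peephole contributions are omitted for brevity.
struct LstmGateGrads {
  float do_, dcs, dci, df, di;
};

inline LstmGateGrads LstmCellBackprop(float i, float cs_prev, float f, float o,
                                      float ci, float co, float cs_grad,
                                      float h_grad) {
  LstmGateGrads g;
  g.do_ = o * (1.0f - o) * h_grad * co;             // do[t]
  g.dcs = (1.0f - co * co) * h_grad * o + cs_grad;  // dcs[t]
  g.dci = (1.0f - ci * ci) * g.dcs * i;             // dci[t]
  g.df  = f * (1.0f - f) * g.dcs * cs_prev;         // df[t]
  g.di  = i * (1.0f - i) * g.dcs * ci;              // di[t]
  return g;
}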
+ cs_prev_grad.device(d) = dcs * f; + if (use_peephole) { + cs_prev_grad.device(d) = + cs_prev_grad + + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // xh_grad. + typename TTypes::ConstMatrix const_dgates(dgates.data(), + dgates.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, true, 1.f, const_dgates, w, 0.f, xh_grad); + + // xh. + xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x; + xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev; + typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); + + // x_grad. + x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents()); + h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents()); + + // w_grad. + TensorBlasGemm::compute( + ctx, d, true, false, 1.f, const_xh, const_dgates, 1.f, w_grad); + + // b_grad. + b_grad.device(d) += dgates.sum(Eigen::array({0})); + + if (use_peephole) { + wci_grad.device(d) += (di * cs_prev).sum(Eigen::array({0})); + wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array({0})); + wco_grad.device(d) += (do_ * cs).sum(Eigen::array({0})); + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_RNN_LSTM_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/roll_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/roll_op.h new file mode 100644 index 00000000..7ae1d8f5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/roll_op.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_ROLL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_ROLL_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct Roll { + // dim_size - the size of each dimension + // dim_range - the number of indices over in the flattened tensor + // you need to skip in order to make it over from one side of a dimension + // to the other. Used to make the shifts wrap around after a threshold. 
+ // threshold - the index for each dimension that the roll starts to wrap + // back to the front + // isd - inner shift dimension + void operator()(const OpKernelContext* context, const int64_t num_elements, + const int num_dims, const absl::Span dim_size, + const T* input, T* output, + const absl::Span threshold, + const absl::Span dim_range, const int64_t isd); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_ROLL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/save_restore_tensor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/save_restore_tensor.h new file mode 100644 index 00000000..f5fac541 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/save_restore_tensor.h @@ -0,0 +1,73 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ +#define TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ + +#include "tensorflow/core/util/tensor_slice_reader.h" +#include "tensorflow/core/util/tensor_slice_writer.h" + +namespace tensorflow { + +class OpKernelContext; + +// Legacy / V1 checkpoint format. + +// Save input tensors in *context to a writer built from builder_func(). +// context must have the following inputs: +// 0: a single element string tensor that contains the file name. +// 1: names for the remaining tensors +// If save_slices is true: +// 2: shape and slice specifications. +// rest: tensors to save +void SaveTensors( + OpKernelContext* context, + checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, + bool save_slices); + +// Reads a single tensor from the reader built from open_func() and produces +// it as context->output(restore_index). "preferred_shard" is the same the +// TensorSliceReader preferred_shard parameter. +// +// context must have the following inputs: +// 0: a single element string tensor that contains the file name. +// 1: string tensor that names the outputs to be restored. +// If restore_slice is true: +// 2: shape and slice specification of the tensors to restore. +// +// restore_index indicates the variable name and slice to lookup +// in context(1) and (2). +void RestoreTensor(OpKernelContext* context, + checkpoint::TensorSliceReader::OpenTableFunction open_func, + int preferred_shard, bool restore_slice, int restore_index); + +// V2 checkpoint format. + +// Invokes the V2 checkpoint read path to read tensors. +// +// "context" is only used for allocating outputs. In particular, the inputs are +// explicitly provided and not accessed via the "input(i)" methods. +// REQUIRES: +// * "prefix" has 1 element, DT_STRING. +// * "tensor_names" and "shape_and_slices" shaped {N}, both DT_STRING. +// * "dtypes" has N elements, the datatypes of the to-restore tensors. 
+absl::Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, + const Tensor& tensor_names, + const Tensor& shape_and_slices, + absl::Span dtypes); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/scan_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/scan_ops.h new file mode 100644 index 00000000..ad3f2e1e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/scan_ops.h @@ -0,0 +1,146 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::Index Index; + +// TODO(b/154339590): Needs to be vectorized. +template +struct Scan { + void operator()(const Device& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out, const Reducer& reducer, + const bool reverse, const bool exclusive) { + // Perform the reverse ops directly with Eigen, which avoids copying the + // tensor twice compared to using individual ops. + Eigen::array dims; + dims[0] = false; + dims[1] = reverse; + dims[2] = false; + MaybeWith32BitIndexing( + [&](auto in32, auto out32) { + out32.device(d) = + in32.reverse(dims).scan(1, reducer, exclusive).reverse(dims); + }, + in, out); + } +}; + +template +struct LogSumExp { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, + const T& b) const { + auto mi = Eigen::internal::scalar_min_op()(a, b); + auto ma = Eigen::internal::scalar_max_op()(a, b); + + auto sub = Eigen::internal::scalar_difference_op(); + auto add = Eigen::internal::scalar_sum_op(); + auto exp = Eigen::internal::scalar_exp_op(); + auto log1p = Eigen::internal::scalar_log1p_op(); + auto cmp_lt = + Eigen::internal::scalar_cmp_op(); + + auto logsumexp = add(log1p(exp(sub(mi, ma))), ma); + return cmp_lt(ma, Eigen::NumTraits::lowest()) ? 
ma : logsumexp; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a, + const T& b) const { + auto mi = Eigen::internal::pmin(a, b); + auto ma = Eigen::internal::pmax(a, b); + using Eigen::internal::padd; + using Eigen::internal::pcmp_lt; + using Eigen::internal::pexp; + using Eigen::internal::plog1p; + using Eigen::internal::pset1; + using Eigen::internal::psub; + + auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma); + return pselect(pcmp_lt(ma, pset1(Eigen::NumTraits::lowest())), ma, + logsumexp); + } +}; + +template +struct LogSumExpReducer { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const { + LogSumExp logsumexp; + *accum = logsumexp(*accum, t); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p, + Packet* accum) const { + LogSumExp logsumexp; + *accum = logsumexp.packetOp(*accum, p); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const { + return -Eigen::NumTraits::infinity(); + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const { + return Eigen::internal::pset1(initialize()); + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const { + return accum; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet + finalizePacket(const Packet& vaccum) const { + return vaccum; + } + + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T + finalizeBoth(const T saccum, const Packet& vaccum) const { + auto max_reducer = Eigen::internal::MaxReducer(); + auto sum_reducer = Eigen::internal::SumReducer(); + auto exp = Eigen::internal::scalar_exp_op(); + auto cmp_lt = + Eigen::internal::scalar_cmp_op(); + auto log = Eigen::internal::scalar_log_op(); + auto add = Eigen::internal::scalar_sum_op(); + + using Eigen::internal::pexp; + using Eigen::internal::psub; + + // `ma = max(x1, ..., xn)` + // If the max of all of the `xi` is `-infinity` then the result is + // -infinity. If the max is larger than `-infinity` then it's safe to use + // for normalization even if the other elements are `-infinity`. + // + // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))` + auto ma = max_reducer.finalizeBoth(saccum, vaccum); + auto logsumexp = add(log(sum_reducer.finalizeBoth( + exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))), + ma); + return cmp_lt(ma, Eigen::NumTraits::lowest()) ? initialize() : logsumexp; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/scan_ops_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/scan_ops_gpu.h new file mode 100644 index 00000000..15b4e5e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/scan_ops_gpu.h @@ -0,0 +1,334 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#define CUB_USE_COOPERATIVE_GROUPS + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/kernels/scan_ops.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" +#include "tensorflow/core/util/permutation_input_iterator.h" +#include "tensorflow/core/util/permutation_output_iterator.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; +typedef Eigen::Index Index; + +namespace functor { + +// Map a contiguous range to the actual memory locations depending on which +// axis the scan is taking place over and whether or not reversed. +struct MapIndexToLocation { + __host__ __device__ MapIndexToLocation(int dimx, int dimy, int dimz, + bool reverse = false) + : dimx_(dimx), dimy_(dimy), dimz_(dimz), reverse_(reverse) {} + + __host__ __device__ int operator()(int id) const { + if (dimx_ == 1) { + int row = id % dimy_; + int col = id / dimy_; + + if (reverse_) return (dimy_ - row - 1) * dimz_ + col; + + return row * dimz_ + col; + } else if (dimz_ == 1) { + if (reverse_) { + int row = id / dimy_; + int col = id % dimy_; + return row * dimy_ + (dimy_ - col - 1); + } + return id; + } else { + int col = id % dimy_; + int tmp = id / dimy_; + + int row1 = id / (dimy_ * dimz_); + int col1 = tmp % dimz_; + + if (reverse_) + return row1 * dimy_ * dimz_ + (dimy_ - col - 1) * dimz_ + col1; + + return row1 * dimy_ * dimz_ + col * dimz_ + col1; + } + } + + int dimx_; + int dimy_; + int dimz_; + bool reverse_; +}; + +template +struct BlockPrefixCallbackOp { + // Running prefix + T running_total_; + Op op_; + + __device__ BlockPrefixCallbackOp(T running_total, Op op) + : running_total_(running_total), op_(op) {} + + // Callback operator to be entered by the first warp of threads in the block. + // tid 0 is responsible for returning a value for seeding the block-wide scan. 
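MapIndexToLocation above is what lets the scan kernel walk each sequence with a plain counting iterator: for a column scan (dimx == 1), consecutive ids map to locations strided by dimz, optionally reversed along the scan axis. A small host-side illustration of that branch, with made-up dimensions:

#include <cstdio>

// Reproduces the dimx == 1 branch of MapIndexToLocation: consecutive ids walk
// down one column (stride dimz), optionally reversed along the scan axis.
int ColumnScanLocation(int id, int dimy, int dimz, bool reverse) {
  const int row = id % dimy;
  const int col = id / dimy;
  return (reverse ? (dimy - row - 1) : row) * dimz + col;
}

int main() {
  // dimy = 4 rows, dimz = 3 columns: ids 0..3 cover column 0 at offsets
  // 0, 3, 6, 9 (or 9, 6, 3, 0 when reversed).
  for (int id = 0; id < 4; ++id)
    std::printf("%d -> %d (rev %d)\n", id, ColumnScanLocation(id, 4, 3, false),
                ColumnScanLocation(id, 4, 3, true));
}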
+ __device__ T operator()(T block_aggregate) { + T old_prefix = running_total_; + running_total_ = op_(old_prefix, block_aggregate); + return old_prefix; + } +}; + +template +struct Sum { + __host__ __device__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +template +struct Prod { + __host__ __device__ T operator()(const T& a, const T& b) const { + return a * b; + } +}; + +template +struct IsSum { + constexpr static bool value = + (std::is_same>::value || + std::is_same>::value); +}; + +template +struct IsProd { + constexpr static bool value = + (std::is_same>::value || + std::is_same>::value); +}; + +template +struct IsLogSumExp { + constexpr static bool value = (std::is_same>::value || + std::is_same>::value); +}; + +template +struct IdentityValue { + static_assert(IsSum::value || IsProd::value || + IsLogSumExp::value, + "IdentityValue not yet defined for this type."); + + template + __host__ __device__ U operator()( + typename std::enable_if::value, U>::type t = U(0)) { + return t; + } + + template + __host__ __device__ U operator()( + typename std::enable_if::value, U>::type t = U(1)) { + return t; + } + + template + __host__ __device__ U + operator()(typename std::enable_if::value, U>::type t = + U(Eigen::NumTraits::lowest())) { + return t; + } +}; + +// Each block is mapped to one sequence. A contiguous range is mapped to the +// appropriate locations in memory by the permutation iterators. This is +// ideal for 1-D and row based scans. Column scans would be better if they +// did a block load and then locally transposed. CUB's device wide scan is not +// used in the large 1D case, even though it would be more efficient, because +// it is not deterministic. +template +__launch_bounds__(BlockDim) __global__ + void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz, + bool exclusive, bool reverse, Op op) { + typedef gpuprim::BlockLoad + BlockLoad; + typedef gpuprim::BlockStore + BlockStore; + typedef gpuprim::BlockScan BlockScan; + + // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + + int problem_length = dimy; + + // Initialize running total + BlockPrefixCallbackOp prefix_op(IdentityValue()(), op); + + MapIndexToLocation map_op(dimx, dimy, dimz, reverse); + int block_start = problem_length * blockIdx.x; + // Have the block iterate over segments of items + for (int block_offset = block_start; + block_offset < block_start + problem_length; + block_offset += BlockDim * ItemsPerThread) { + int valid_items = min(BlockDim * ItemsPerThread, + problem_length - (block_offset % problem_length)); + + // first construct a counting iterator that has the desired start point + typedef gpuprim::TransformInputIterator> + MapIterType; + + gpuprim::CountingInputIterator counting_iter(block_offset); + + // Next map the iterator to the actual locations in memory + MapIterType map_iter(counting_iter, map_op); + + PermutationInputIterator permutein_iter(in, + map_iter); + PermutationOutputIterator permuteout_iter(out, + map_iter); + + // Load a segment of consecutive items that are blocked across threads + T thread_data[ItemsPerThread]; + BlockLoad(temp_storage.load).Load(permutein_iter, thread_data, valid_items); + __syncthreads(); + + // Collectively compute the block-wide scan + if (exclusive) { + BlockScan(temp_storage.scan) + .ExclusiveScan(thread_data, thread_data, op, prefix_op); + 
} else { + BlockScan(temp_storage.scan) + .InclusiveScan(thread_data, thread_data, op, prefix_op); + } + __syncthreads(); + + // Store scanned items to output segment + BlockStore(temp_storage.store) + .Store(permuteout_iter, thread_data, valid_items); + __syncthreads(); + } +} + +template +void LaunchScan(const GPUDevice& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out, Op op, const bool reverse, + const bool exclusive) { + const int items_per_thread = 4; + + int dimx = in.dimension(0); + int dimy = in.dimension(1); + int dimz = in.dimension(2); + int num_blocks = dimx * dimz; + + int ideal_block_size = dimy / items_per_thread; + const int rocm_threads_per_warp = 64; + ideal_block_size = std::max(ideal_block_size, rocm_threads_per_warp); + + // There seems to be a bug when the type is not float and block_size 1024. + // Launch on the smallest power of 2 block size that we can. + if (ideal_block_size >= 1024 && std::is_same::value) { + const int block_size = 1024; + TF_CHECK_OK( + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + } else if (ideal_block_size >= 512) { + const int block_size = 512; + TF_CHECK_OK( + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + } else if (ideal_block_size >= 256) { + const int block_size = 256; + TF_CHECK_OK( + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + } else if (ideal_block_size >= 128) { + const int block_size = 128; + TF_CHECK_OK( + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); +#if TENSORFLOW_COMPILER_IS_HIP_CLANG + // HIP-CLANG has some kind of problem here with 32 threads (possibly because + // the warpsize is 64). 
Reenable when working properly + } else if (true) { +#else + } else if (ideal_block_size >= 64) { +#endif + const int block_size = 64; + TF_CHECK_OK( + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + } else { + const int block_size = 32; + TF_CHECK_OK( + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + } +} + +template +struct Scan, T> { + void operator()(const GPUDevice& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out, + const Eigen::internal::SumReducer& reducer, + const bool reverse, const bool exclusive) { + LaunchScan>(d, in, out, Sum(), reverse, exclusive); + } +}; + +template +struct Scan, T> { + void operator()(const GPUDevice& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out, + const Eigen::internal::ProdReducer& reducer, + const bool reverse, const bool exclusive) { + LaunchScan>(d, in, out, Prod(), reverse, exclusive); + } +}; + +template +struct Scan, T> { + void operator()(const GPUDevice& d, typename TTypes::ConstTensor in, + typename TTypes::Tensor out, + const LogSumExpReducer& reducer, const bool reverse, + const bool exclusive) { + LaunchScan>(d, in, out, LogSumExp(), reverse, exclusive); + } +}; + +} // namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_functor.h new file mode 100644 index 00000000..dcfae9b7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_functor.h @@ -0,0 +1,414 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_ + +#include + +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/dense_update_functor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/determinism.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +class OpKernelContext; +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace scatter_op { + +enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX }; + +namespace internal { + +template +struct Assign {}; +template <> +struct Assign { + template + static void Run(Params p, Update u) { + p = u; + } + template + static void RunScalar(Params p, Update u) { + p.setConstant(u); + } +}; +template <> +struct Assign { + template + static void Run(Params p, Update u) { + p += u; + } + template + static void RunScalar(Params p, Update u) { + p = p + u; + } +}; +template <> +struct Assign { + template + static void Run(Params p, Update u) { + p -= u; + } + template + static void RunScalar(Params p, Update u) { + p = p + static_cast(-u); + } +}; +template <> +struct Assign { + template + static void Run(Params p, Update u) { + p *= u; + } + template + static void RunScalar(Params p, Update u) { + p = p * u; + } +}; +template <> +struct Assign { + template + static void Run(Params p, Update u) { + p /= u; + } + template + static void RunScalar(Params p, Update u) { + p = p / u; + } +}; +template <> +struct Assign { + // This method requires that Params and Update are tensor types. + template + static void Run(Params p, Update u) { + p = p.cwiseMin(u); + } + // Same thing, but for Update being a scalar type. + template + static void RunScalar(Params p, Update u) { + p = p.cwiseMin(u); + } +}; +template <> +struct Assign { + template + static void Run(Params p, Update u) { + p = p.cwiseMax(u); + } + template + static void RunScalar(Params p, Update u) { + p = p.cwiseMax(u); + } +}; + + +} // namespace internal +} // namespace scatter_op + +namespace functor { +template +struct ScatterFunctor { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices); +}; + +template +struct ScatterFunctorBase { + Index ParallelExecute(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + const Index kMaxLocks = 1024; + const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks; + // To reduce the number of locks and the memory usage, we divide the whole + // index space into kMaxLocks regions with each lock serializing access to + // a region. + mutex accessed[kMaxLocks]; + std::atomic bad_index(-1); + auto ParallelScatter = [&](Index start, Index end) { + for (Index i = start; i < end; ++i) { + // Grab the index and check its validity. 
Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in + // between). + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) { + bad_index = i; + return; + } + const Index lock_id = index / entries_per_lock; + // Copy last Ndim-1 dimensions of updates[i] to params[index] + { + mutex_lock l(accessed[lock_id]); + scatter_op::internal::Assign::Run(params.template chip<0>(index), + updates.template chip<0>(i)); + } + } + }; + const float kMovingCost = 2.5f; + float shard_cost = kMovingCost * params.dimension(1); + const DeviceBase::CpuWorkerThreads& worker_threads = + *(c->device()->tensorflow_cpu_worker_threads()); + Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost, + ParallelScatter); // TODO: Come up with a good cost estimate. + return bad_index; + } + Index SerialExecute(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; ++i) { + // Grab the index and check its validity. Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in + // between). + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Copy last Ndim-1 dimensions of updates[i] to params[index] + scatter_op::internal::Assign::Run(params.template chip<0>(index), + updates.template chip<0>(i)); + } + return -1; + } + + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { +#ifdef PLATFORM_GOOGLE + // The parallel version is significantly slower internally. Only call the + // serial version for now. + // TODO(penporn): Avoid locking in parallelization (sort beforehand). + return SerialExecute(c, d, params, updates, indices); +#else + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + const Index min_n_threshold = 1024; + const Index ser_par_ratio = 10000; + // For parallelizing the updates, duplicate entries need to be handled + // correctly. Multiple updates to the same index has to be serialized. + // This can lead to lock contention which may nullify the benefits of + // parallelization. Assuming uniform random distribution of the indices, we + // come up with a rough heuristic and determine whether the updates execute + // serially or parallelly. Also if 'N' is small, overheads of parallel + // execution outweigh its benefits and hence we check the value of N. + const bool execute_serial = N < min_n_threshold || + (N / limit) > ser_par_ratio || + OpDeterminismRequired(); + if (execute_serial) + return SerialExecute(c, d, params, updates, indices); + else + return ParallelExecute(c, d, params, updates, indices); +#endif // PLATFORM_GOOGLE + } +}; + +template +struct ScatterFunctorVariantAssignBase { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). 
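The execute_serial heuristic above reads as: stay serial when there is too little work, when updates per destination row are dense enough that lock contention would dominate, or when determinism is required. A standalone restatement with the thresholds copied from the code above and the determinism flag passed in as a parameter:

#include <cstdint>

// Restates the serial-vs-parallel decision from ScatterFunctorBase above.
// n    - number of indices/updates
// rows - params.dimension(0), i.e. the number of destination rows (assumed > 0)
bool ScatterShouldRunSerially(int64_t n, int64_t rows, bool determinism_required) {
  constexpr int64_t kMinNThreshold = 1024;  // too little work to parallelize
  constexpr int64_t kSerParRatio = 10000;   // expected updates per row; large
                                            // values mean heavy lock contention
  return n < kMinNThreshold || (n / rows) > kSerParRatio || determinism_required;
}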
+ const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + const Index cols = static_cast(params.dimension(1)); + DCHECK_EQ(N, updates.dimension(0)); + DCHECK_EQ(cols, updates.dimension(1)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in between). + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Copy last Ndim-1 dimensions of updates[i] to params[index] + for (int j = 0; j < cols; ++j) { + const Variant& to_scatter = updates(i, j); + params(index, j) = to_scatter; + } + } + return -1; + } +}; + +template +struct ScatterFunctor + : ScatterFunctorVariantAssignBase {}; + +template +struct ScatterFunctor + : ScatterFunctorVariantAssignBase {}; + + +template +struct ScatterFunctorBase { + Index operator()(OpKernelContext* c, const CPUDevice& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + if (!std::is_same::value) { + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in + // between). + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + memmove(params.data() + index * params.dimension(1), + updates.data() + i * updates.dimension(1), + updates.dimension(1) * sizeof(T)); + } + } else { + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in + // between). + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Copy last Ndim-1 dimensions of updates[i] to params[index] + scatter_op::internal::Assign::Run( + params.template chip<0>(index), updates.template chip<0>(i)); + } + } + return -1; + } +}; + +template +struct ScatterFunctor + : ScatterFunctorBase {}; + + +template +struct ScatterScalarFunctor { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices); +}; + +template +struct ScatterScalarFunctorBase { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in between). 
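The "grab the index once, then check it" discipline described in the comment above (and implemented just below via SubtleMustCopy and FastBoundsCheck) guards against the index changing between the check and the use. A generic sketch of the same pattern on plain containers; names and types here are illustrative:

#include <cstdint>
#include <vector>

// Copy-then-check: read each index exactly once into a local, validate the
// local, and index with the local.  Re-reading indices[i] after the check
// would reintroduce the race the comment above warns about.
// Returns the position of the first bad index, or -1 if all are in range.
int64_t CopyRows(const std::vector<float>& updates,    // [n, cols] row-major
                 const std::vector<int64_t>& indices,  // [n]
                 std::vector<float>& params,           // [rows, cols] row-major
                 int64_t rows, int64_t cols) {
  const int64_t n = static_cast<int64_t>(indices.size());
  for (int64_t i = 0; i < n; ++i) {
    const int64_t index = indices[i];          // single read into a local
    if (index < 0 || index >= rows) return i;  // FastBoundsCheck equivalent
    for (int64_t j = 0; j < cols; ++j)
      params[index * cols + j] = updates[i * cols + j];
  }
  return -1;
}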
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::Assign::RunScalar( + params.template chip<0>(index), update()); + } + return -1; + } +}; + +template +struct ScatterScalarFunctorVariantAssignBase { + Index operator()(OpKernelContext* c, const Device& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + const Index cols = static_cast(params.dimension(1)); + const Variant& to_scatter = update(); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in between). + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + for (Index j = 0; j < cols; ++j) { + params(index, j) = to_scatter; + } + } + return -1; + } +}; + +template +struct ScatterScalarFunctor + : ScatterScalarFunctorVariantAssignBase {}; +template +struct ScatterScalarFunctor + : ScatterScalarFunctorVariantAssignBase {}; + + +template +struct ScatterScalarFunctorBase { + Index operator()(OpKernelContext* c, const CPUDevice& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // indices and params sizes were validated in DoCompute(). + const Index N = static_cast(indices.size()); + const Index limit = static_cast(params.dimension(0)); + for (Index i = 0; i < N; i++) { + // Grab the index and check its validity. Do this carefully, + // to avoid checking the value and grabbing it again from + // memory a second time (a security risk since it may change in between). + const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i)); + if (!FastBoundsCheck(index, limit)) return i; + // Broadcast update to params[index] + scatter_op::internal::Assign::RunScalar( + params.template chip<0>(index), update()); + } + return -1; + } +}; + +template +struct ScatterScalarFunctor + : ScatterScalarFunctorBase {}; + + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_functor_gpu.cu.h new file mode 100644 index 00000000..61868b78 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_functor_gpu.cu.h @@ -0,0 +1,179 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/scatter_functor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace scatter_op_gpu { + +template +struct ScatterOpKernelBody; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* __restrict__ dest, T src) const { *dest = src; } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* __restrict__ dest, T src) const { + GpuAtomicAdd(dest, src); + } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* __restrict__ dest, T src) const { + GpuAtomicSub(dest, src); + } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* __restrict__ dest, T src) const { + GpuAtomicMul(dest, src); + } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* __restrict__ dest, T src) const { + GpuAtomicDiv(dest, src); + } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* __restrict__ dest, T src) const { + GpuAtomicMin(dest, src); + } +}; + +template +struct ScatterOpKernelBody { + __device__ void operator()(T* __restrict__ dest, T src) const { + GpuAtomicMax(dest, src); + } +}; + +template +__global__ void ScatterOpCustomKernel(T* __restrict__ params, + const T* __restrict__ updates, + const Index* __restrict__ indices, + Index first_dim_size, Index updates_size, + Index indices_size) { + Index update_block = updates_size / indices_size; + ScatterOpKernelBody body; + GPU_1D_KERNEL_LOOP(i, updates_size) { + int indices_i = i / update_block; + int updates_i = i; + int param_first_index = indices[indices_i]; + if (!(param_first_index >= 0 && param_first_index < first_dim_size)) { + // Ignore indices that are out of range. + continue; + } + int64 params_i = param_first_index * update_block + (i % update_block); + body(¶ms[params_i], ldg(updates + updates_i)); + } +} + +template +__global__ void ScatterScalarOpCustomKernel(T* __restrict__ params, + const T* __restrict__ update, + const Index* __restrict__ indices, + Index first_dim_size, + Index indices_size, + Index synthesized_updates_size) { + Index update_block = synthesized_updates_size / indices_size; + ScatterOpKernelBody body; + GPU_1D_KERNEL_LOOP(i, synthesized_updates_size) { + int indices_i = i / update_block; + int param_first_index = indices[indices_i]; + const T update_val = *update; + if (!(param_first_index >= 0 && param_first_index < first_dim_size)) { + // Ignore indices that are out of range. + continue; + } + int params_i = param_first_index * update_block + (i % update_block); + body(¶ms[params_i], update_val); + } +} + +} // namespace scatter_op_gpu + +namespace functor { +// Specialization for a GPU device. +template +struct ScatterFunctor { + Index operator()(OpKernelContext* c, const GPUDevice& d, + typename TTypes::Matrix params, + typename TTypes::ConstMatrix updates, + typename TTypes::ConstFlat indices) { + // TODO(b/31801742): Implement indices range check. The hardest part is + // with returning a value after the range check, as we do not want to do + // device to host memcpy during a stream. 
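Each thread of ScatterOpCustomKernel above handles one flat element of 'updates'; the arithmetic splits that position into (destination index, offset within the row) and recombines it with the scattered row, silently skipping out-of-range rows. A CPU-side sketch of the same mapping for the ASSIGN case, with illustrative names:

#include <cstdint>
#include <vector>

// Mirrors the index arithmetic of ScatterOpCustomKernel for UpdateOp::ASSIGN:
// flat update position i -> (indices_i, column) -> flat params position.
void ScatterAssignCpu(std::vector<float>& params, const std::vector<float>& updates,
                      const std::vector<int64_t>& indices, int64_t first_dim_size) {
  const int64_t updates_size = static_cast<int64_t>(updates.size());
  const int64_t indices_size = static_cast<int64_t>(indices.size());
  const int64_t update_block = updates_size / indices_size;  // elements per row
  for (int64_t i = 0; i < updates_size; ++i) {
    const int64_t indices_i = i / update_block;
    const int64_t row = indices[indices_i];
    if (row < 0 || row >= first_dim_size) continue;  // ignore out-of-range rows
    params[row * update_block + (i % update_block)] = updates[i];
  }
}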
+ const Index first_dim_size = params.dimension(0); + const Index indices_size = indices.size(); + const Index updates_size = updates.size(); + GpuLaunchConfig config = GetGpuLaunchConfig(updates_size, d); + TF_CHECK_OK(GpuLaunchKernel( + scatter_op_gpu::ScatterOpCustomKernel, config.block_count, + config.thread_per_block, 0, d.stream(), params.data(), updates.data(), + indices.data(), first_dim_size, updates_size, indices_size)); + return -1; + } +}; + +template +struct ScatterScalarFunctor { + Index operator()(OpKernelContext* c, const GPUDevice& d, + typename TTypes::Matrix params, + const typename TTypes::ConstScalar update, + typename TTypes::ConstFlat indices) { + // TODO(b/31801742): Implement indices range check. The hardest part is + // with returning a value after the range check, as we do not want to do + // device to host memcpy during a stream. + const Index first_dim_size = params.dimension(0); + const Index indices_size = indices.size(); + const Index synthesized_updates_size = indices_size * params.dimension(1); + GpuLaunchConfig config = GetGpuLaunchConfig(synthesized_updates_size, d); + TF_CHECK_OK(GpuLaunchKernel( + scatter_op_gpu::ScatterScalarOpCustomKernel, + config.block_count, config.thread_per_block, 0, d.stream(), + params.data(), update.data(), indices.data(), first_dim_size, + indices_size, synthesized_updates_size)); + return -1; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_op.h new file mode 100644 index 00000000..b736d4b0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_op.h @@ -0,0 +1,74 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +class OpKernelContext; + +namespace scatter_nd_op { + +enum class UpdateOp { ASSIGN, ADD, SUB, MIN, MAX }; + +} // namespace scatter_nd_op + +namespace functor { + +// Functor used by ScatterOp to do the computations. +template +struct ScatterNdFunctor { + // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index. 
+ Index operator()( + const Device& d, const Index slice_size, + const Eigen::array output_shape_prefix, + typename TTypes::Tensor Tparams, + typename TTypes::ConstTensor Tindices, + typename TTypes::ConstTensor Tupdates, + typename TTypes::Tensor Toutput); +}; + +// Scatter updates into indices in Tensor out. The argument allocate +// controls whether 'out' should be created. If allocate is true, +// *out will be updated to the scattered tensor upon successful completion. +// If allocate is false, out must point to a Tensor allocated with the +// right type (T) and shape. This tensor will not be zeroed out +// before the scatter is executed. +template +absl::Status DoScatterNd(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate); + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h new file mode 100644 index 00000000..c4cc570b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_op_cpu_impl.h @@ -0,0 +1,199 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ + +// Functor definitions for ScatterND ops, must be compilable by nvcc. 
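The ScatterNdFunctor contract above (return -1 on success, otherwise the position of the first bad index row) relies on flattening each IXDIM-column index row into a single row offset, which the CPU implementation that follows does with row-major strides over output_shape_prefix. A standalone sketch of that flattening with a per-coordinate bounds check:

#include <cstdint>
#include <vector>

// Flattens a multi-dimensional index into a row offset using row-major strides
// over the leading output dimensions.  Returns false if any coordinate is out
// of range (the caller can then report the offending row).
bool FlattenScatterNdIndex(const std::vector<int64_t>& index,  // [ixdim]
                           const std::vector<int64_t>& shape,  // [ixdim]
                           int64_t* offset) {
  const int ixdim = static_cast<int>(shape.size());
  std::vector<int64_t> strides(ixdim, 1);
  for (int d = ixdim - 2; d >= 0; --d) strides[d] = strides[d + 1] * shape[d + 1];
  int64_t flat = 0;
  for (int d = 0; d < ixdim; ++d) {
    if (index[d] < 0 || index[d] >= shape[d]) return false;  // bad coordinate
    flat += index[d] * strides[d];
  }
  *offset = flat;
  return true;
}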
+ +#define EIGEN_USE_THREADS + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/scatter_nd_op.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +class OpKernelContext; + +// Specialization of UpdateExecutor to CPU +namespace update_executor { + +template +class UpdateExecutor { + public: + EIGEN_STRONG_INLINE static void Execute(const T& device, Input value, + Update update, Output output); +}; + +template +class UpdateExecutor { + public: + EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, + Update update, Output output) { + output.device(device) = update; + } +}; + +template +class UpdateExecutor { + public: + EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, + Update update, Output output) { + output.device(device) += update; + } +}; + +template +class UpdateExecutor { + public: + EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, + Update update, Output output) { + output.device(device) -= update; + } +}; + +template +class UpdateExecutor { + public: + EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, + Update update, Output output) { + output.device(device) = output.cwiseMin(update); + } +}; + +template +class UpdateExecutor { + public: + EIGEN_STRONG_INLINE static void Execute(const T& device, Input /* input */, + Update update, Output output) { + output.device(device) = output.cwiseMax(update); + } +}; + +} // namespace update_executor + +namespace functor { + +// Implementation of update functor for CPU. +template +struct ScatterNdFunctor { + Index operator()( + const CPUDevice& d, const Index slice_size, + const Eigen::array output_shape_prefix, + typename TTypes::Tensor Tparams, + typename TTypes::ConstTensor Tindices, + typename TTypes::ConstTensor Tupdates, + typename TTypes::Tensor Toutput) { + // error_loc is -1 if there's no out-of-bounds index, + // otherwise it is the location of an OOB index in Tindices. + Index error_loc = -1; + + const Eigen::DenseIndex batch_size = Tindices.dimension(0); + + Index batch_strides[IXDIM]; + if (IXDIM > 0) { + batch_strides[IXDIM - 1] = 1; + } + for (int dim = IXDIM - 2; dim >= 0; --dim) { + batch_strides[dim] = + batch_strides[dim + 1] * output_shape_prefix[dim + 1]; + } + + for (Eigen::DenseIndex loc = 0; loc < batch_size; ++loc) { + Index i = 0; + bool out_of_bounds = false; + for (int dim = 0; dim < IXDIM; ++dim) { + const Index ix_d = internal::SubtleMustCopy(Tindices(loc, dim)); + out_of_bounds |= !FastBoundsCheck(ix_d, output_shape_prefix[dim]); + i += ix_d * batch_strides[dim]; + } + if (TF_PREDICT_FALSE(out_of_bounds)) { + error_loc = loc; + // Don't break the loop here, but continue to update the rest because + // the caller might ignore bad indices. 
+ continue; + } else { + auto input_chip = Toutput.template chip<0>(i); + auto output_chip = input_chip; + auto update_chip = Tupdates.template chip<0>(loc); + update_executor::UpdateExecutor< + CPUDevice, decltype(input_chip), decltype(update_chip), + decltype(output_chip), OP>::Execute(d, input_chip, update_chip, + output_chip); + } + } + + return error_loc; + } +}; + +#define REGISTER_SCATTER_ND_FULL(T, Index, op) \ + template Index \ + ScatterNdFunctor::operator()( \ + const CPUDevice& d, const Index slice_size, \ + const Eigen::array \ + output_shape_prefix, \ + typename TTypes::Tensor Tparams, \ + typename TTypes::ConstTensor Tindices, \ + typename TTypes::ConstTensor Tupdates, \ + typename TTypes::Tensor Toutput) + +#define REGISTER_SCATTER_ND_INDEX(type, op) \ + REGISTER_SCATTER_ND_FULL(type, int32, op); \ + REGISTER_SCATTER_ND_FULL(type, int64, op) + +#define REGISTER_SCATTER_ND_UPDATE(type) \ + REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ASSIGN); + +#define REGISTER_SCATTER_ND_MATH(type) \ + REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::ADD); \ + REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::SUB); + +#define REGISTER_SCATTER_ND_MIN_MAX(type) \ + REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MAX); \ + REGISTER_SCATTER_ND_INDEX(type, scatter_nd_op::UpdateOp::MIN); + +TF_CALL_ALL_TYPES(REGISTER_SCATTER_ND_UPDATE); +REGISTER_SCATTER_ND_INDEX(tstring, scatter_nd_op::UpdateOp::ADD); +TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH); +TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_ND_MIN_MAX); +TF_CALL_bool(REGISTER_SCATTER_ND_MATH); + +#undef REGISTER_SCATTER_ND_MATH +#undef REGISTER_SCATTER_ND_MIN_MAX +#undef REGISTER_SCATTER_ND_UPDATE +#undef REGISTER_SCATTER_ND_INDEX +#undef REGISTER_SCATTER_ND_FULL +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_OP_CPU_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_util.h new file mode 100644 index 00000000..5095e925 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/scatter_nd_util.h @@ -0,0 +1,47 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_ + +#include "xla/tsl/util/env_var.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +// Validates the input shapes for the ScatterNdUpdateOp +absl::Status ValidateScatterNdUpdateShape(const TensorShape& params_shape, + const TensorShape& indices_shape, + const TensorShape& updates_shape); + +inline bool DisableScatterOpDeterminism() { + static bool cached_disable = [] { + bool disable = false; + // When determinism is enabled, the kernels for various scatter ops like + // ScatterNdAdd will still use the faster non-deterministic versions if this + // environmental variable is true. This is useful if the user is certain the + // scatter inputs don't have duplicate indices (in which cases scatter ops + // are always deterministic), since the deterministic implementations are + // currently slow. + TF_CHECK_OK(tsl::ReadBoolFromEnvVar("TF_DISABLE_SCATTER_OP_DETERMINISM", + /*default_val=*/false, &disable)); + return disable; + }(); + return cached_disable; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SCATTER_ND_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sdca_internal.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sdca_internal.h new file mode 100644 index 00000000..8f5ac038 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sdca_internal.h @@ -0,0 +1,394 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_ +#define TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_ + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/loss.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/util/guarded_philox_random.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +namespace sdca { + +// Statistics computed with input (ModelWeights, Example). +struct ExampleStatistics { + // Logits for each class. 
+ // For binary case, this should be a vector of length 1; while for multiclass + // case, this vector has the same length as the number of classes, where each + // value corresponds to one class. + // Use InlinedVector to avoid heap allocation for small number of classes. + absl::InlinedVector wx; + + // Logits for each class, using the previous weights. + absl::InlinedVector prev_wx; + + // Sum of squared feature values occurring in the example divided by + // L2 * sum(example_weights). + double normalized_squared_norm = 0; + + // Num_weight_vectors equals to the number of classification classes in the + // multiclass case; while for binary case, it is 1. + ExampleStatistics(const int num_weight_vectors) + : wx(num_weight_vectors, 0.0), prev_wx(num_weight_vectors, 0.0) {} +}; + +class Regularizations { + public: + Regularizations() {} + + // Initialize() must be called immediately after construction. + absl::Status Initialize(OpKernelConstruction* const context) { + TF_RETURN_IF_ERROR(context->GetAttr("l1", &symmetric_l1_)); + TF_RETURN_IF_ERROR(context->GetAttr("l2", &symmetric_l2_)); + shrinkage_ = symmetric_l1_ / symmetric_l2_; + return absl::OkStatus(); + } + + // Proximal SDCA shrinking for L1 regularization. + double Shrink(const double weight) const { + const double shrinked = std::max(std::abs(weight) - shrinkage_, 0.0); + if (shrinked > 0.0) { + return std::copysign(shrinked, weight); + } + return 0.0; + } + + // Vectorized float variant of the above. + Eigen::Tensor EigenShrinkVector( + const Eigen::Tensor weights) const { + // Proximal step on the weights which is sign(w)*|w - shrinkage|+. + return weights.sign() * ((weights.abs() - weights.constant(shrinkage_)) + .cwiseMax(weights.constant(0.0))); + } + + // Matrix float variant of the above. + Eigen::Tensor EigenShrinkMatrix( + const Eigen::Tensor weights) const { + // Proximal step on the weights which is sign(w)*|w - shrinkage|+. + return weights.sign() * ((weights.abs() - weights.constant(shrinkage_)) + .cwiseMax(weights.constant(0.0))); + } + + float symmetric_l2() const { return symmetric_l2_; } + + private: + float symmetric_l1_ = 0; + float symmetric_l2_ = 0; + + // L1 divided by L2, pre-computed for use during weight shrinking. + double shrinkage_ = 0; + + Regularizations(const Regularizations&) = delete; + void operator=(const Regularizations&) = delete; +}; + +class ModelWeights; + +// Struct describing a single example. +class Example { + public: + // Compute matrix vector product between weights (a matrix) and features + // (a vector). This method also computes the normalized example norm used + // in SDCA update. + // For multiclass case, num_weight_vectors equals to the number of classes; + // while for binary case, it is 1. + const ExampleStatistics ComputeWxAndWeightedExampleNorm( + const int num_loss_partitions, const ModelWeights& model_weights, + const Regularizations& regularization, + const int num_weight_vectors) const; + + float example_label() const { return example_label_; } + + float example_weight() const { return example_weight_; } + + double squared_norm() const { return squared_norm_; } + + // Sparse features associated with the example. + // Indices and Values are the associated feature index, and values. Values + // can be optionally absent, in which we case we implicitly assume a value of + // 1.0f. + struct SparseFeatures { + std::unique_ptr::UnalignedConstVec> indices; + std::unique_ptr::UnalignedConstVec> + values; // nullptr encodes optional. 
+ }; + + // A dense vector which is a row-slice of the underlying matrix. + struct DenseVector { + // Returns a row slice from the matrix. + Eigen::TensorMap> Row() + const { + return Eigen::TensorMap>( + data_matrix.data() + row_index * data_matrix.dimension(1), + data_matrix.dimension(1)); + } + + // Returns a row slice as a 1 * F matrix, where F is the number of features. + Eigen::TensorMap> + RowAsMatrix() const { + return Eigen::TensorMap>( + data_matrix.data() + row_index * data_matrix.dimension(1), 1, + data_matrix.dimension(1)); + } + + const TTypes::ConstMatrix data_matrix; + const int64_t row_index; + }; + + private: + std::vector sparse_features_; + std::vector> dense_vectors_; + + float example_label_ = 0; + float example_weight_ = 0; + double squared_norm_ = 0; // sum squared norm of the features. + + // Examples fills Example in a multi-threaded way. + friend class Examples; + + // ModelWeights use each example for model update w += \alpha * x_{i}; + friend class ModelWeights; +}; + +// Weights related to features. For example, say you have two sets of sparse +// features i.e. age bracket and country, then FeatureWeightsDenseStorage hold +// the parameters for it. We keep track of the original weight passed in and the +// delta weight which the optimizer learns in each call to the optimizer. +class FeatureWeightsDenseStorage { + public: + FeatureWeightsDenseStorage(const TTypes::Matrix nominals, + TTypes::Matrix deltas) + : nominals_(nominals), deltas_(deltas) { + CHECK_GT(deltas.rank(), 1); + } + + // Check if a feature index is with-in the bounds. + bool IndexValid(const int64_t index) const { + return index >= 0 && index < deltas_.dimension(1); + } + + // Nominals here are the original weight matrix. + TTypes::Matrix nominals() const { return nominals_; } + + // Delta weights during mini-batch updates. + TTypes::Matrix deltas() const { return deltas_; } + + // Updates delta weights based on active dense features in the example and + // the corresponding dual residual. + void UpdateDenseDeltaWeights( + const Eigen::ThreadPoolDevice& device, + const Example::DenseVector& dense_vector, + const std::vector& normalized_bounded_dual_delta); + + private: + // The nominal value of the weight for a feature (indexed by its id). + const TTypes::Matrix nominals_; + // The accumulated delta weight for a feature (indexed by its id). + TTypes::Matrix deltas_; +}; + +// Similar to FeatureWeightsDenseStorage, but the underlying weights are stored +// in an unordered map. +class FeatureWeightsSparseStorage { + public: + FeatureWeightsSparseStorage(const TTypes::Vec indices, + const TTypes::Matrix nominals, + TTypes::Matrix deltas) + : nominals_(nominals), deltas_(deltas) { + // Create a map from sparse index to the dense index of the underlying + // storage. + for (int64_t j = 0; j < indices.size(); ++j) { + indices_to_id_[indices(j)] = j; + } + } + + // Check if a feature index exists. + bool IndexValid(const int64_t index) const { + return indices_to_id_.find(index) != indices_to_id_.end(); + } + + // Nominal value at a particular feature index and class label. + float nominals(const int class_id, const int64_t index) const { + auto it = indices_to_id_.find(index); + return nominals_(class_id, it->second); + } + + // Delta weights during mini-batch updates. 
+ float deltas(const int class_id, const int64_t index) const { + auto it = indices_to_id_.find(index); + return deltas_(class_id, it->second); + } + + // Updates delta weights based on active sparse features in the example and + // the corresponding dual residual. + void UpdateSparseDeltaWeights( + const Eigen::ThreadPoolDevice& device, + const Example::SparseFeatures& sparse_features, + const std::vector& normalized_bounded_dual_delta); + + private: + // The nominal value of the weight for a feature (indexed by its id). + const TTypes::Matrix nominals_; + // The accumulated delta weight for a feature (indexed by its id). + TTypes::Matrix deltas_; + // Map from feature index to an index to the dense vector. + std::unordered_map indices_to_id_; +}; + +// Weights in the model, wraps both current weights, and the delta weights +// for both sparse and dense features. +class ModelWeights { + public: + ModelWeights() {} + + bool SparseIndexValid(const int col, const int64_t index) const { + return sparse_weights_[col].IndexValid(index); + } + + bool DenseIndexValid(const int col, const int64_t index) const { + return dense_weights_[col].IndexValid(index); + } + + // Go through all the features present in the example, and update the + // weights based on the dual delta. + void UpdateDeltaWeights( + const Eigen::ThreadPoolDevice& device, const Example& example, + const std::vector& normalized_bounded_dual_delta); + + absl::Status Initialize(OpKernelContext* const context); + + const std::vector& sparse_weights() const { + return sparse_weights_; + } + + const std::vector& dense_weights() const { + return dense_weights_; + } + + private: + std::vector sparse_weights_; + std::vector dense_weights_; + + ModelWeights(const ModelWeights&) = delete; + void operator=(const ModelWeights&) = delete; +}; + +// Examples contains all the training examples that SDCA uses for a mini-batch. +class Examples { + public: + Examples() {} + + // Returns the Example at |example_index|. + const Example& example(const int example_index) const { + return examples_.at(example_index); + } + + int sampled_index(const int id) const { return sampled_index_[id]; } + + // Adaptive SDCA in the current implementation only works for + // binary classification, where the input argument for num_weight_vectors + // is 1. + absl::Status SampleAdaptiveProbabilities( + const int num_loss_partitions, const Regularizations& regularization, + const ModelWeights& model_weights, + const TTypes::Matrix example_state_data, + const std::unique_ptr& loss_updater, + const int num_weight_vectors); + + void RandomShuffle(); + + int num_examples() const { return examples_.size(); } + + int num_features() const { return num_features_; } + + // Initialize() must be called immediately after construction. + absl::Status Initialize(OpKernelContext* const context, + const ModelWeights& weights, int num_sparse_features, + int num_sparse_features_with_values, + int num_dense_features); + + private: + // Reads the input tensors, and builds the internal representation for sparse + // features per example. This function modifies the |examples| passed in + // to build the sparse representations. 
+ static absl::Status CreateSparseFeatureRepresentation( + const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples, + int num_sparse_features, const ModelWeights& weights, + const OpInputList& sparse_example_indices_inputs, + const OpInputList& sparse_feature_indices_inputs, + const OpInputList& sparse_feature_values_inputs, + std::vector* const examples); + + // Reads the input tensors, and builds the internal representation for dense + // features per example. This function modifies the |examples| passed in + // to build the sparse representations. + static absl::Status CreateDenseFeatureRepresentation( + const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples, + int num_dense_features, const ModelWeights& weights, + const OpInputList& dense_features_inputs, + std::vector* const examples); + + // Computes squared example norm per example i.e |x|^2. This function modifies + // the |examples| passed in and adds the squared norm per example. + static absl::Status ComputeSquaredNormPerExample( + const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples, + int num_sparse_features, int num_dense_features, + std::vector* const examples); + + // All examples in the batch. + std::vector examples_; + + // Adaptive sampling variables. + std::vector probabilities_; + std::vector sampled_index_; + std::vector sampled_count_; + + int num_features_ = 0; + + Examples(const Examples&) = delete; + void operator=(const Examples&) = delete; +}; + +} // namespace sdca +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SDCA_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/searchsorted_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/searchsorted_op.h new file mode 100644 index 00000000..fb4ade03 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/searchsorted_op.h @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace functor { + +template +struct UpperBoundFunctor { + // Searches for values in sorted_inputs and returns the greatest possible + // index where they maintain sorted order. + static absl::Status Compute( + OpKernelContext* context, + const typename TTypes::ConstTensor& sorted_inputs, + const typename TTypes::ConstTensor& values, int batch_size, + int num_inputs, int num_values, + typename TTypes::Tensor* output); +}; + +template +struct LowerBoundFunctor { + // Searches for values in sorted_inputs and returns the lowest possible + // index where they maintain sorted order. 
+ static absl::Status Compute( + OpKernelContext* context, + const typename TTypes::ConstTensor& sorted_inputs, + const typename TTypes::ConstTensor& values, int batch_size, + int num_inputs, int num_values, + typename TTypes::Tensor* output); +}; +} // namespace functor + +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops.h new file mode 100644 index 00000000..93aa9636 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops.h @@ -0,0 +1,173 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +class OpKernelContext; + +bool UseDeterministicSegmentReductions(); +bool DisableSegmentReductionOpDeterminismExceptions(); + +// Type of SparseSegmentReduction operation to perform gradient of. +enum class SparseSegmentReductionOperation { kSum, kMean, kSqrtN }; + +namespace functor { + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// Note that we define this ourselves to avoid a dependency on gpuprim. +struct Sum { + template + __host__ __device__ T operator()(const T& a, const T& b) const { + return a + b; + } +}; + +struct Prod { + template + __host__ __device__ T operator()(const T& a, const T& b) const { + return a * b; + } +}; + +// Note that we don't use gpuprim::Min/Max because they use operator<, which is +// not implemented for AlignedVector types. +struct Min { + template + __host__ __device__ T operator()(const T& a, const T& b) const { + return min(a, b); + } +}; + +struct Max { + template + __host__ __device__ T operator()(const T& a, const T& b) const { + return max(a, b); + } +}; + +template +struct ReduceOpIsAssociative {}; +template +struct ReduceOpIsAssociative : std::is_integral {}; +template +struct ReduceOpIsAssociative : std::is_integral {}; +template +struct ReduceOpIsAssociative : std::true_type {}; +template +struct ReduceOpIsAssociative : std::true_type {}; + +typedef Eigen::GpuDevice GPUDevice; +// Functor for SegmentReductionGPUOp. +// output_rows: the number of output segments (unique segment ids in +// 'segment_ids'). +// segment_ids_shape: shape of 'segment_ids' tensor. +// segment_ids: unsorted map from input to output segment ids at which to +// perform segment sum operation. +// data_size: size of input data tensor. +// data: input data tensor. 
+// output: output reshaped to {output_rows, output.size/output_rows} +template +struct SegmentReductionFunctor { + void operator()(OpKernelContext* ctx, const GPUDevice& d, + const Index output_rows, const TensorShape& segment_ids_shape, + bool is_mean, typename TTypes::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes::Tensor output); + static constexpr bool atomic_reduction_is_associative = + ReduceOpIsAssociative::value; +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape, + typename TTypes::ConstFlat segment_ids, + typename TTypes::ConstTensor data, + typename TTypes::Tensor output); +}; + +// Initial value functors. +template +struct Zero { + EIGEN_STRONG_INLINE T operator()() const { return T(0); } +}; + +template +struct One { + EIGEN_STRONG_INLINE T operator()() const { return T(1); } +}; + +template +struct Lowest { + EIGEN_STRONG_INLINE T operator()() const { + return Eigen::NumTraits::lowest(); + } +}; + +template +struct Highest { + EIGEN_STRONG_INLINE T operator()() const { + return Eigen::NumTraits::highest(); + } +}; + +template +struct SparseSegmentReductionFunctor { + absl::Status operator()(OpKernelContext* context, bool is_mean, bool is_sqrtn, + T default_value, + typename TTypes::ConstTensor input, + typename TTypes::ConstVec indices, + typename TTypes::ConstVec segment_ids, + typename TTypes::Tensor output); +}; + +template +struct SparseSegmentGradFunctor { + void operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename TTypes::ConstMatrix input_flat, + typename TTypes::ConstVec indices_vec, + typename TTypes::ConstVec segment_vec, + Tensor* output); +}; + +template +struct SparseSegmentGradV2Functor { + void operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename TTypes::ConstMatrix input_flat, + typename TTypes::ConstVec indices_vec, + typename TTypes::ConstVec segment_vec, + const TensorShape& dense_output_shape, + typename AsyncOpKernel::DoneCallback done); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h new file mode 100644 index 00000000..f0ba0ce2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h @@ -0,0 +1,1413 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_GPU_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/kernels/gpu_prim_helpers.h" +#include "tensorflow/core/kernels/segment_reduction_ops.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/util/determinism.h" +#include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/gpu_device_functions.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace +#include "tensorflow/core/util/permutation_input_iterator.h" + +#if (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +#include "tensorflow/core/platform/rocm.h" +#endif + +namespace tensorflow { + +using GPUDevice = Eigen::GpuDevice; + +// Non/Atomic reduction functors for the gpu. +#define DEFINE_REDUCE_UPDATE_OP_GPU(name, func) \ + struct name##OpGpu { \ + template \ + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(T* dest, \ + const T& value) { \ + func; \ + } \ + }; +DEFINE_REDUCE_UPDATE_OP_GPU(AtomicSum, GpuAtomicAdd(dest, value)) +DEFINE_REDUCE_UPDATE_OP_GPU(AtomicProd, GpuAtomicMul(dest, value)) +DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMax, GpuAtomicMax(dest, value)) +DEFINE_REDUCE_UPDATE_OP_GPU(AtomicMin, GpuAtomicMin(dest, value)) +DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicSum, *dest += value) +DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicProd, *dest *= value) +DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMax, *dest = max(*dest, value)) +DEFINE_REDUCE_UPDATE_OP_GPU(NonAtomicMin, *dest = min(*dest, value)) +#undef DEFINE_REDUCE_UPDATE_OP_GPU + +template +struct ReduceUpdateOpFor {}; + +#define DEFINE_REDUCE_UPDATE_OP_FOR(reduce_op, atomic, nonatomic) \ + template <> \ + struct ReduceUpdateOpFor { \ + using atomic_op = atomic; \ + using nonatomic_op = nonatomic; \ + }; +DEFINE_REDUCE_UPDATE_OP_FOR(functor::Sum, AtomicSumOpGpu, NonAtomicSumOpGpu) +DEFINE_REDUCE_UPDATE_OP_FOR(functor::Prod, AtomicProdOpGpu, NonAtomicProdOpGpu) +DEFINE_REDUCE_UPDATE_OP_FOR(functor::Max, AtomicMaxOpGpu, NonAtomicMaxOpGpu) +DEFINE_REDUCE_UPDATE_OP_FOR(functor::Min, AtomicMinOpGpu, NonAtomicMinOpGpu) +#undef DEFINE_REDUCE_UPDATE_OP_FOR + +// PR#61339: MSVC does not support compound-assignment operators on device + +// SortedSegmentReductionFunctor kernel reduces input data just as +// UnsortedSegmentReductionCustomKernel does except that input data +// is partitioned along the outer reduction dimension. This is +// because consecutive rows (elements in a row share the same +// outer dimension index) in the flattened 2D input data likely +// belong to the same segment in sorted segment sum operation. +// Therefore such partitioning strategy has two advantages over +// the UnsortedSegmentReductionFunctor kernel: +// 1. Each thread reduces across multiple rows before writing +// answers to the global memory, we can therefore +// write reduction results to global memory less often. +// 2. We may know that the current thread is the only contributor +// to an output element because of the increasing nature of segment +// ids. In such cases, we do not need to use atomic operations +// to write results to global memory. 
+// In the flattened view of input data (with only outer and inner +// dimension), every thread processes a strip of input data of +// size OuterDimTileSize x 1. This strip runs across multiple +// rows of input data and all reduction elements share one inner +// dimension index. +template +__global__ void SortedSegmentReductionCustomKernel( + const Index input_outer_dim_size, const Index inner_dim_size, + const Index output_outer_dim_size, const Index* __restrict__ segment_ids, + const T* __restrict__ input, T* __restrict__ output, + const Index total_stripe_count, const T initial_value) { + for (int stripe_index : GpuGridRangeX(total_stripe_count)) { + const Index segment_offset = stripe_index % inner_dim_size; + const Index input_outer_dim_index_base = + stripe_index / inner_dim_size * Index(OuterDimTileSize); + + T reduce_res = initial_value; + Index first_segment_id = segment_ids[input_outer_dim_index_base]; + Index last_output_segment_id = output_outer_dim_size; + + const Index actual_stripe_height = + min(Index(OuterDimTileSize), + input_outer_dim_size - input_outer_dim_index_base); + for (Index j = 0; j < actual_stripe_height; j++) { + Index current_output_segment_id = + segment_ids[input_outer_dim_index_base + j]; + // Decide whether to write result to global memory. Result is only written + // to global memory if we move to another segment. Otherwise we can keep + // accumulating locally. + if (current_output_segment_id > last_output_segment_id) { + const Index output_index = + last_output_segment_id * inner_dim_size + segment_offset; + // Decide whether to write result to global memory using atomic + // operations. + if (last_output_segment_id == first_segment_id) { + AtomicReductionF()(output + output_index, reduce_res); + } else { + ReductionF()(output + output_index, reduce_res); + } + reduce_res = initial_value; + } + ReductionF()( + &reduce_res, + ldg(input + (input_outer_dim_index_base + j) * inner_dim_size + + segment_offset)); + last_output_segment_id = current_output_segment_id; + } + // For the last result in a strip, always write using atomic operations + // due to possible race conditions with threads computing + // the following strip. 
+ const Index output_index = + last_output_segment_id * inner_dim_size + segment_offset; + AtomicReductionF()(output + output_index, reduce_res); + } +} + +template +__global__ void SegmentMeanNormalizeKernel( + SegmentId nsegments, Index ninner, + const Index* __restrict__ segment_offsets, // [nsegments + 1] + T* __restrict__ output) { // [nsegments, ninner] + for (SegmentId seg : GpuGridRangeY(nsegments)) { + SegmentId segment_size = segment_offsets[seg + 1] - segment_offsets[seg]; + segment_size = max(segment_size, Index(1)); // Avoid division by zero + T inv_norm = T(1) / static_cast(segment_size); + for (Index i : GpuGridRangeX(ninner)) { + output[seg * ninner + i] *= inv_norm; + } + } +} + +template +Status LaunchSegmentMeanNormalizeKernel( + const GPUDevice& d, SegmentId nsegments, Index ninner, + const Index* __restrict__ segment_offsets, // [nsegments + 1] + T* __restrict__ output) { // [nsegments, ninner] + Gpu2DLaunchConfig config = GetGpu2DLaunchConfig( + ninner, nsegments, d, SegmentMeanNormalizeKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(SegmentMeanNormalizeKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), nsegments, ninner, segment_offsets, + output); +} + +template +__global__ void SegmentSetEmptyKernel( + SegmentId nsegments, Index ninner, + const Index* __restrict__ segment_offsets, // [nsegments + 1] + const T empty_value, + T* __restrict__ output) { // [nsegments, ninner] + for (SegmentId seg : GpuGridRangeY(nsegments)) { + SegmentId segment_size = segment_offsets[seg + 1] - segment_offsets[seg]; + if (segment_size == 0) { + for (Index i : GpuGridRangeX(ninner)) { + output[seg * ninner + i] = empty_value; + } + } + } +} + +template +Status LaunchSegmentSetEmptyKernel( + const GPUDevice& d, SegmentId nsegments, Index ninner, + const Index* __restrict__ segment_offsets, // [nsegments + 1] + const T empty_value, + T* __restrict__ output) { // [nsegments, ninner] + Gpu2DLaunchConfig config = GetGpu2DLaunchConfig( + ninner, nsegments, d, SegmentSetEmptyKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(SegmentSetEmptyKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), nsegments, ninner, segment_offsets, + empty_value, output); +} + +// UnsortedSegmentSumKernel processes 'input_total_size' elements. +// Each element is mapped from input to output by a combination of its +// 'segment_ids' mapping and 'inner_dim_size'. 
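As a plain host-side reference for what the unsorted-segment GPU kernel below computes (the name and signature here are illustrative only, not part of the vendored header): every input row is reduced into the output row selected by its segment id, rows with out-of-range ids are skipped, and every output element starts from the op's initial value.

#include <cstdint>
#include <functional>
#include <vector>

void UnsortedSegmentReduceRef(
    const std::vector<float>& input,          // [num_rows, inner_dim_size]
    const std::vector<int64_t>& segment_ids,  // [num_rows]
    int64_t inner_dim_size, int64_t num_segments, float initial_value,
    const std::function<float(float, float)>& reduce,
    std::vector<float>* output) {             // [num_segments, inner_dim_size]
  output->assign(num_segments * inner_dim_size, initial_value);
  const int64_t num_rows = static_cast<int64_t>(segment_ids.size());
  for (int64_t row = 0; row < num_rows; ++row) {
    const int64_t seg = segment_ids[row];
    if (seg < 0 || seg >= num_segments) continue;  // Ignore invalid ids.
    for (int64_t j = 0; j < inner_dim_size; ++j) {
      float& dst = (*output)[seg * inner_dim_size + j];
      dst = reduce(dst, input[row * inner_dim_size + j]);
    }
  }
}

Called with reduce = std::plus<float>() and initial_value = 0 this matches an unsorted segment sum; min, max, and prod differ only in the reduction functor and initial value, which is exactly how the atomic/non-atomic GPU functors above are parameterized.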
+template +__global__ void UnsortedSegmentCustomKernel( + const int64_t input_outer_dim_size, const int64_t inner_dim_size, + const int64_t output_outer_dim_size, const Index* __restrict__ segment_ids, + const T* __restrict__ input, T* __restrict__ output) { + const int64_t input_total_size = input_outer_dim_size * inner_dim_size; + for (int64_t input_index : GpuGridRangeX(input_total_size)) { + const int64_t input_segment_index = input_index / inner_dim_size; + const int64_t segment_offset = input_index % inner_dim_size; + const Index output_segment_index = segment_ids[input_segment_index]; + if (output_segment_index < 0 || + output_segment_index >= output_outer_dim_size) { + continue; + } + const int64_t output_index = + output_segment_index * inner_dim_size + segment_offset; + KernelReductionFunctor()(output + output_index, ldg(input + input_index)); + } +} + +template +__global__ void SegmentOffsetsKernel( + Toffsets size, Tsegmentids nsegments, + const Tsegmentids* __restrict__ segment_ids, // [size] + Toffsets* __restrict__ segment_offsets) { // [nsegments + 1] + GPU_1D_KERNEL_LOOP(i, size + 1) { + // IDs are clipped to [-1, nsegments] so that out-of-bounds IDs are ignored. + // Note that we can't report invalid IDs from the GPU without incurring + // additional overhead. + auto clip = [&](Tsegmentids id) { + return min(max(Tsegmentids(-1), id), nsegments); + }; + const Tsegmentids cur_id = (i < size) ? clip(segment_ids[i]) : nsegments; + const Tsegmentids prev_id = + (i == 0) ? Tsegmentids(-1) : clip(segment_ids[i - 1]); + // At segment boundaries, write the offset for this ID and any missing IDs + // since the previous one. + for (Tsegmentids id = prev_id + 1; id <= cur_id; ++id) { + segment_offsets[id] = i; + } + } +} + +// Finds the start offset of each segment in the given sorted segment_ids +// vector. Missing IDs are given the same offset as the next ID so that they +// represent empty ranges. Invalid IDs (those that are outside the range +// [0, nsegments)) are ignored. The value at segment_offsets[0] is set to the +// start index of the first valid ID (e.g., 0 if all IDs are valid), and the +// value at segment_offsets[nsegments] is set to the end index of the last valid +// ID (e.g., nsegments if all IDs are valid). +template +Status LaunchSegmentOffsetsKernel( + const GPUDevice& d, Toffsets size, Tsegmentids nsegments, + const Tsegmentids* segment_ids, // [size] + Toffsets* segment_offsets) { // [nsegments + 1] + GpuLaunchConfig config = GetGpuLaunchConfig( + size + 1, d, &SegmentOffsetsKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(SegmentOffsetsKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), size, nsegments, segment_ids, + segment_offsets); +} + +template +struct RealTypeIfComplex { + using type = T; +}; + +template +struct RealTypeIfComplex> { + using type = Real; +}; + +// Reduces along columns of the thread block, returning the result in the first +// row of threads. +template +__device__ T ReduceBlockAlongCols(ReduceOp reduce_op, const T& value, + bool is_valid) { + GPU_DYNAMIC_SHARED_MEM_DECL(/*ALIGN=*/16, char, shared_memory_raw); + T* const shared_partial_reduction = + reinterpret_cast(shared_memory_raw); // [blockDim.y, blockDim.x] + const int x = threadIdx.x; + const int y = threadIdx.y; + T reduced = value; + // Reduce over the y dimension of the block. 
+ for (unsigned k = blockDim.y / 2; k > 0; k /= 2) { + if (is_valid && y < 2 * k) { + shared_partial_reduction[y * blockDim.x + x] = reduced; + } + __syncthreads(); + if (is_valid && y < k) { + reduced = reduce_op(reduced, + shared_partial_reduction[(y + k) * blockDim.x + x]); + } + __syncthreads(); + } + return reduced; +} + +// This kernel uses a 2D thread decomposition. The x dimension maps to the inner +// dimension of the input/output. The y grid dimension maps to segments, and y +// threads within a block cooperate to reduce over the block's segment. +// Note that Tinit is needed because Tvec and Treducevec may be vector types, +// but Tinit is always a scalar type. +// Note that the first dimension of input_vec is nouter if indices is not +// provided; otherwise it is indexed indirectly via indices and can have any +// size (as long as it spans at least the maximum value in indices). This also +// applies to the weights vector. +template +__global__ void SegmentReduceVectorKernel( + Toffsets nouter, Toffsets ninner_vec, Tsegmentids nsegments, + ReduceOp reduce_op, Tinit initial_value, Tinit empty_segment_value, + bool is_mean, bool is_sqrtn, + const Tvec* __restrict__ input_vec, // [nouter or any, ninner_vec] + const Toffsets* __restrict__ segment_offsets, // [nsegments + 1] + const Tindices* __restrict__ indices, // [nouter] (optional) + const Tweights* __restrict__ weights, // [nouter or any] (optional) + Tvec* __restrict__ output_vec) { // [nsegments, ninner_vec] + const int num_blocks_x = (ninner_vec - 1) / blockDim.x + 1; + // Grid-stride loop over inner dimension blocks. + for (Toffsets blk_x = blockIdx.x; blk_x < num_blocks_x; blk_x += gridDim.x) { + const Toffsets x = threadIdx.x + blk_x * blockDim.x; + const Toffsets y = threadIdx.y; + const bool x_ok = x < ninner_vec; + // Grid-stride loop over segment blocks, each processing one segment. + for (Tsegmentids seg = blockIdx.y; seg < nsegments; seg += gridDim.y) { + // Load segment range. + const Toffsets begin = segment_offsets[seg]; + const Toffsets end = segment_offsets[seg + 1]; + // Reduce over the segment. + Treducevec result = Treducevec(initial_value); + // Loop over the segment, reducing blockDim.y elements at a time. + for (Toffsets y_offset = begin; y_offset < end; y_offset += blockDim.y) { + const bool y_ok = (y_offset + y) < end; + // Perform indirect lookup if required. + const Toffsets y_idx = + indices && y_ok ? indices[y_offset + y] : y_offset + y; + const int64_t input_idx = static_cast(y_idx) * ninner_vec + x; + // Load the input row from global mem. + Treducevec block_result = + x_ok && y_ok ? input_vec[input_idx] : Tvec(initial_value); + // Apply weights if provided. + if (weights && y_ok) block_result = block_result * Tvec(weights[y_idx]); + // Reduce along the columns of the block, returning result in first row. + block_result = ReduceBlockAlongCols(reduce_op, block_result, x_ok); + if (y == 0 && x_ok) { + result = reduce_op(result, block_result); + } + } + // First row of the block stores the result to global memory. + if (y == 0 && x_ok) { + if (begin == end) { + // Empty segment. + result = Treducevec(empty_segment_value); + } else { + Tweights total_weight(end - begin); + // Normalize the results if necessary. + if (is_mean) { + result = result / Treducevec(total_weight); + } else if (is_sqrtn) { + result = + result / Treducevec(sqrt(static_cast(total_weight))); + } + } + // Cast from Treducevec to Tvec. 
+ const int64_t output_idx = static_cast(seg) * ninner_vec + x; + output_vec[output_idx] = static_cast(result); + } + } + } +} + +// Reduces input matrix within segments over the outer dimension. Empty segments +// always output empty_segment_value. +// If is_mean or is_sqrtn is true, the results are normalized using the +// corresponding function. +// If indices is not nullptr, input rows are accessed indirectly as +// input[indices[i]], instead of input[i]. +// Note: Treducevec is to allow reducing in higher precision than Tvec. +template +Status LaunchSegmentReduceVectorKernel( + const GPUDevice& d, Toffsets nouter, Toffsets ninner_vec, + Tsegmentids nsegments, ReduceOp reduce_op, Tinit initial_value, + Tinit empty_segment_value, bool is_mean, bool is_sqrtn, + const Tvec* input_vec, // [nouter or any, ninner_vec] + const Toffsets* segment_offsets, // [nsegments + 1] + const Tindices* indices, // [nouter] (optional) + const Tweights* weights, // [nouter or any] (optional) + Tvec* output_vec) { // [nsegments, ninner_vec] + static constexpr const int kMaxGridX = (1u << 31) - 1; + static constexpr const int kMaxGridY = (1u << 16) - 1; + const int max_block_size = 1024; // Can be tuned for perf (<= 1024) + const int min_block_size = 64; // Can be tuned for perf + const Toffsets ninner_pow2 = Toffsets(1) << Log2Ceiling64(ninner_vec); + // This is a heuristic that first allocates threads in the block to the inner + // (x) dimension (which is most efficient) and then allocates the rest to the + // reduction (y) dimension (which is less efficient but increases + // parallelism). + int block_x = std::min(ninner_pow2, static_cast(max_block_size)); + const Toffsets avg_reduce_size = + Eigen::divup(nouter, static_cast(nsegments)); + const Toffsets avg_reduce_size_pow2 = Toffsets(1) + << Log2Ceiling64(avg_reduce_size); + dim3 block( + block_x, + std::min(static_cast(Eigen::divup(min_block_size, block_x)), + avg_reduce_size_pow2)); + dim3 grid(std::min(Eigen::divup(ninner_vec, static_cast(block.x)), + static_cast(kMaxGridX)), + std::min(nsegments, static_cast(kMaxGridY))); + unsigned shared_memory_bytes = block.x * block.y * sizeof(Treducevec); + return GpuLaunchKernel( + SegmentReduceVectorKernel, + grid, block, shared_memory_bytes, d.stream(), nouter, ninner_vec, + nsegments, reduce_op, initial_value, empty_segment_value, is_mean, + is_sqrtn, input_vec, segment_offsets, indices, weights, output_vec); +} + +template +__global__ void SegmentReduceEpilogueKernel( + Tsegmentids nsegments, Tinit empty_segment_value, bool is_mean, + bool is_sqrtn, + const Treducevec* __restrict__ output_raw, // [nsegments] + const Toffsets* __restrict__ segment_offsets, // [nsegments + 1] + Tvec* __restrict__ output) { // [nsegments] + GPU_1D_KERNEL_LOOP(seg, nsegments) { + Toffsets segment_size = segment_offsets[seg + 1] - segment_offsets[seg]; + Treducevec val = output_raw[seg]; + if (segment_size == 0) { + // Empty segment. + val = Treducevec(empty_segment_value); + } else if (is_mean) { + val = val / Treducevec(segment_size); + } else if (is_sqrtn) { + val = val / Treducevec(sqrt(static_cast( + typename RealTypeIfComplex::type(segment_size)))); + } + // Cast from Treducevec to Tvec. + output[seg] = static_cast(val); + } +} + +// Normalizes output_raw based on segment size and casts from Treducevec to +// Tvec. If Tvec == Treducevec, this is safe to call with output_raw == output. 
+// Note that Treducevec is the type that was used for the reduction, which may +// be a higher-precision type than the output type Tvec (e.g., float vs. half). +template +Status LaunchSegmentReduceEpilogueKernel( + const GPUDevice& d, Tsegmentids nsegments, Tinit empty_segment_value, + bool is_mean, bool is_sqrtn, + const Treducevec* output_raw, // [nsegments] + const Toffsets* segment_offsets, // [nsegments + 1] + Tvec* output) { // [nsegments] + GpuLaunchConfig config = GetGpuLaunchConfig( + nsegments, d, + &SegmentReduceEpilogueKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(SegmentReduceEpilogueKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), nsegments, empty_segment_value, is_mean, + is_sqrtn, output_raw, segment_offsets, output); +} + +template +struct CastFunctor { + template + __device__ Tto operator()(const T& val) const { + return static_cast(val); + } +}; + +template +struct LookupAndScaleAndCastInputsFunctor { + LookupAndScaleAndCastInputsFunctor(const Tvec* input_vec, + const Tindices* indices, + const Tweights* weights) + : input_vec_(input_vec), indices_(indices), weights_(weights) {} + + template + __device__ Treducevec operator()(Toffsets idx) const { + if (indices_) idx = indices_[idx]; + Treducevec result = static_cast(input_vec_[idx]); + if (weights_) result = result * Tvec(weights_[idx]); + return result; + } + + private: + const Tvec* __restrict__ input_vec_; + const Tindices* __restrict__ indices_; + const Tweights* __restrict__ weights_; +}; + +template +struct CastIterator { + using FunctorTy = + LookupAndScaleAndCastInputsFunctor; + using InputIteratorTy = gpuprim::CountingInputIterator; + using IteratorTy = + gpuprim::TransformInputIterator; +}; + +template +typename CastIterator::IteratorTy +MakeLookupAndScaleAndCastInputsIterator(const Tvec* input_vec, + const Tindices* indices, + const Tweights* weights) { + using CastIteratorTy = + CastIterator; + typename CastIteratorTy::FunctorTy functor(input_vec, indices, weights); + return typename CastIteratorTy::IteratorTy( + typename CastIteratorTy::InputIteratorTy(Toffsets(0)), functor); +} + +template +Status SegmentReduceGPUImplNoInnerDim( + OpKernelContext* ctx, Toffsets nouter, Tsegmentids nsegments, + ReduceOp reduce_op, Tinit initial_value, Tinit empty_segment_value, + bool is_mean, bool is_sqrtn, + const Tvec* input_vec, // [nouter or any] + const Toffsets* segment_offsets, // [nsegments + 1] + const Tindices* indices, // [nouter] (optional) + const Tweights* weights, // [nouter or any] (optional) + Tvec* output_vec) { // [nsegments] + // Here we use gpuprim::DeviceSegmentedReduce (which is optimized for this + // shape) and add the additional required functionality using fancy input + // iterators and an epilogue kernel. + + // Note: This reinterpret cast is only needed to avoid compilation error + // when Tvec != Treducevec; the result is only used if Tvec == Treducevec. + Treducevec* output_raw_ptr = reinterpret_cast(output_vec); + Tensor output_raw; + bool need_temp_output = !std::is_same::value; + if (need_temp_output) { + // Note: We must allocate and reinterpret as bytes because Treducevec may + // be a vector type and they are not supported as Tensor dtypes. 
+ TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, + TensorShape({static_cast(nsegments * sizeof(Treducevec))}), + &output_raw)); + output_raw_ptr = + reinterpret_cast(output_raw.flat().data()); + } + auto input_iter = + MakeLookupAndScaleAndCastInputsIterator( + input_vec, indices, weights); + TF_RETURN_IF_ERROR(GpuSegmentedReduce(ctx, nsegments, reduce_op, + Treducevec(initial_value), input_iter, + segment_offsets, output_raw_ptr)); + bool need_epilogue = !std::is_same::value || + initial_value != empty_segment_value || is_mean || + is_sqrtn; + if (need_epilogue) { + const GPUDevice& device = ctx->eigen_gpu_device(); + // Normalize based on the segment size and cast results back to T. + TF_RETURN_IF_ERROR(LaunchSegmentReduceEpilogueKernel( + device, nsegments, empty_segment_value, is_mean, is_sqrtn, + output_raw_ptr, segment_offsets, output_vec)); + } + return OkStatus(); +} + +template +Status SegmentReduceGPUImpl( + OpKernelContext* ctx, Toffsets nouter, Toffsets ninner_vec, + Tsegmentids nsegments, ReduceOp reduce_op, Tinit initial_value, + Tinit empty_segment_value, bool is_mean, bool is_sqrtn, + const Tvec* input_vec, // [nouter or any, ninner_vec] + const Tsegmentids* segment_ids, // [nouter] + const Tindices* indices, // [nouter] (optional) + const Tweights* weights, // [nouter or any] (optional) + Tvec* output_vec) { // [nsegments, ninner_vec] + const GPUDevice& device = ctx->eigen_gpu_device(); + + if (nouter == 0) { + // Just set output to empty_segment_value. + GPUDevice d = ctx->template eigen_device(); + int64_t output_size = static_cast(nsegments) * ninner_vec; + GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d); + return GpuLaunchKernel(SetToValue, config.block_count, + config.thread_per_block, 0, d.stream(), output_size, + output_vec, empty_segment_value); + } + + // Allocate and compute segment_offsets. + Tensor segment_offsets; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({nsegments + 1}), + &segment_offsets)); + Toffsets* segment_offsets_ptr = segment_offsets.flat().data(); + TF_RETURN_IF_ERROR(LaunchSegmentOffsetsKernel( + device, nouter, nsegments, segment_ids, segment_offsets_ptr)); + + const Toffsets avg_reduce_size = + Eigen::divup(nouter, static_cast(nsegments)); + // This avg_reduce_size threshold is a performance heuristic. + if (ninner_vec == 1 && avg_reduce_size >= 512) { + // Here we use a gpuprim-based implementation that doesn't support an + // inner dimension but can be significantly faster for large reductions. + return SegmentReduceGPUImplNoInnerDim( + ctx, nouter, nsegments, reduce_op, initial_value, empty_segment_value, + is_mean, is_sqrtn, input_vec, segment_offsets_ptr, indices, weights, + output_vec); + } + // Here we use a custom kernel that is optimized for ninner_vec >= ~64 and + // gives decent performance for smaller cases. It also handles indices, + // casting to/from Treducevec, and normalizing the output. 
+ return LaunchSegmentReduceVectorKernel( + device, nouter, ninner_vec, nsegments, reduce_op, initial_value, + empty_segment_value, is_mean, is_sqrtn, input_vec, segment_offsets_ptr, + indices, weights, output_vec); +} + +template +struct SegmentReduceGPUVectorized { + template + struct Impl { + template + Status operator()(OpKernelContext* ctx, Toffsets nouter, Toffsets ninner, + Tsegmentids nsegments, ReduceOp reduce_op, + T initial_value, T empty_segment_value, bool is_mean, + bool is_sqrtn, const T* input, + const Tsegmentids* segment_ids, const Tindices* indices, + const Tweights* weights, T* output) { + DCHECK_EQ(ninner % vec_size, 0); + DCHECK_EQ(reinterpret_cast(input) % vec_size, 0); + DCHECK_EQ(reinterpret_cast(output) % vec_size, 0); + Toffsets ninner_vec = ninner / vec_size; + using Tvec = AlignedVector; + using Treducevec = AlignedVector; + const Tvec* input_vec = reinterpret_cast(input); + Tvec* output_vec = reinterpret_cast(output); + + return SegmentReduceGPUImpl( + ctx, nouter, ninner_vec, nsegments, reduce_op, initial_value, + empty_segment_value, is_mean, is_sqrtn, input_vec, segment_ids, + indices, weights, output_vec); + } + }; +}; + +// Reduces input matrix within segments over the outer dimension. Empty segments +// always output empty_segment_value. +// The segment_ids vector must be sorted. +// If is_mean or is_sqrtn is true, the results are normalized using the +// corresponding function. +// If indices is not nullptr, input rows are accessed indirectly as +// input[indices[i]], instead of input[i]. +// The implementation is deterministic. +// Note: Treduce is to allow reducing in higher precision than T. +template +Status SegmentReduceGPU(OpKernelContext* ctx, Toffsets nouter, Toffsets ninner, + Tsegmentids nsegments, ReduceOp reduce_op, + T initial_value, T empty_segment_value, bool is_mean, + bool is_sqrtn, + const T* input, // [nouter or any, ninner] + const Tsegmentids* segment_ids, // [nouter] + const Tindices* indices, // [nouter] (optional) + const Tweights* weights, // [nouter or any] (optional) + T* output) { // [nsegments, ninner] + if (ninner == 0 || nsegments == 0) return OkStatus(); + return DispatchToVectorized< + T, SegmentReduceGPUVectorized::template Impl>( + MinAlignmentOf(input, output, ninner), ctx, nouter, ninner, nsegments, + reduce_op, initial_value, empty_segment_value, is_mean, is_sqrtn, input, + segment_ids, indices, weights, output); +} + +template +__global__ void SegmentWeightsKernel( + SegmentId nsegments, SparseSegmentReductionOperation operation, + const Index* __restrict__ segment_offsets, // [nsegments + 1] + Tweights* __restrict__ weights) { // [nsegments] + GPU_1D_KERNEL_LOOP(i, nsegments) { + Index segment_size = segment_offsets[i + 1] - segment_offsets[i]; + segment_size = max(segment_size, Index(1)); // Avoid division by zero + if (operation == SparseSegmentReductionOperation::kMean) { + weights[i] = Tweights(1) / static_cast(segment_size); + } else if (operation == SparseSegmentReductionOperation::kSqrtN) { + weights[i] = Tweights(1) / sqrt(static_cast(segment_size)); + } + } +} + +template +Status LaunchSegmentWeightsKernel( + const GPUDevice& d, SegmentId nsegments, + SparseSegmentReductionOperation operation, + const Index* segment_offsets, // [nsegments + 1] + Tweights* weights) { // [nsegments] + GpuLaunchConfig config = GetGpuLaunchConfig( + nsegments, d, &SegmentWeightsKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(SegmentWeightsKernel, + config.block_count, 
config.thread_per_block, 0, + d.stream(), nsegments, operation, segment_offsets, + weights); +} + +template +struct ReduceType { + using type = T; +}; + +// Sum fp16 values using an fp32 accumulator to avoid numerical issues. +template <> +struct ReduceType { + using type = float; +}; + +template <> +struct ReduceType { + using type = float; +}; + +namespace functor { + +template +void SegmentReductionFunctor< + T, Index, InitialValueF, EmptySegmentValueF, + ReductionF>::operator()(OpKernelContext* ctx, const GPUDevice& d, + const Index output_rows, + const TensorShape& segment_ids_shape, bool is_mean, + typename TTypes::ConstFlat segment_ids, + const Index data_size, const T* data, + typename TTypes::Tensor output) { + if (output.size() == 0) { + return; + } + + // Launch kernel(s) to compute sorted segment reduction. + // Notes: + // *) 'input_total_size' is the total number of elements to process. + // *) 'segment_ids.shape' is a prefix of data's shape. + // *) 'input_outer_dim_size' is the total number of segments to process. + const Index input_total_size = data_size; + const Index input_outer_dim_size = segment_ids.dimension(0); + const Index input_inner_dim_size = input_total_size / input_outer_dim_size; + const Index num_segments = output.size() / input_inner_dim_size; + + bool use_deterministic_kernels = + UseDeterministicSegmentReductions() || + (OpDeterminismRequired() && !ReduceOpIsAssociative::value); + + // TODO(benbarsdell): If there are no performance concerns with the new + // deterministic kernels, remove this runtime check and the old + // non-deterministic kernels. + if (!use_deterministic_kernels) { + // Set 'output' to initial value. + GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d); + const T initial_value = InitialValueF()(); + TF_CHECK_OK(GpuLaunchKernel(SetToValue, config.block_count, + config.thread_per_block, 0, d.stream(), + output.size(), output.data(), initial_value)); + if (data_size == 0 || segment_ids_shape.num_elements() == 0) { + return; + } + + const int OuterDimTileSize = 8; + + const Index input_outer_dim_num_stripe = + Eigen::divup(input_outer_dim_size, Index(OuterDimTileSize)); + + const Index total_stripe_count = + input_inner_dim_size * input_outer_dim_num_stripe; + + config = GetGpuLaunchConfig(total_stripe_count, d); + TF_CHECK_OK(GpuLaunchKernel( + SortedSegmentReductionCustomKernel< + T, Index, OuterDimTileSize, + typename ReduceUpdateOpFor::nonatomic_op, + typename ReduceUpdateOpFor::atomic_op>, + config.block_count, config.thread_per_block, 0, d.stream(), + input_outer_dim_size, input_inner_dim_size, output_rows, + segment_ids.data(), data, output.data(), total_stripe_count, + initial_value)); + + const T empty_value = EmptySegmentValueF()(); + if (is_mean || initial_value != empty_value) { + Tensor segment_offsets; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({num_segments + 1}), + &segment_offsets)); + Index* segment_offsets_ptr = segment_offsets.flat().data(); + OP_REQUIRES_OK(ctx, LaunchSegmentOffsetsKernel( + d, input_outer_dim_size, num_segments, + segment_ids.data(), segment_offsets_ptr)); + + if (is_mean) { + OP_REQUIRES_OK(ctx, LaunchSegmentMeanNormalizeKernel( + d, num_segments, input_inner_dim_size, + segment_offsets_ptr, output.data())); + } + if (initial_value != empty_value) { + OP_REQUIRES_OK( + ctx, LaunchSegmentSetEmptyKernel( + d, num_segments, input_inner_dim_size, segment_offsets_ptr, + empty_value, output.data())); + } + } + } else { + using Treduce = typename 
ReduceType::type; + using Tweights = typename RealTypeIfComplex::type; + OP_REQUIRES_OK( + ctx, + SegmentReduceGPU( + ctx, input_outer_dim_size, input_inner_dim_size, num_segments, + ReductionF(), InitialValueF()(), EmptySegmentValueF()(), + /*is_mean=*/is_mean, /*is_sqrtn=*/false, data, segment_ids.data(), + /*indices=*/static_cast(nullptr), + /*weights=*/static_cast(nullptr), output.data())); + } +} + +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape, + typename TTypes::ConstFlat unsorted_segment_ids, + typename TTypes::ConstTensor data, + typename TTypes::Tensor output) { + if (output.size() == 0) { + return; + } + + bool use_deterministic_kernels = + UseDeterministicSegmentReductions() || + (!ReduceOpIsAssociative::value && + OpDeterminismRequired()); + + bool determinism_requirement_met = + use_deterministic_kernels || + ReduceOpIsAssociative::value || + !OpDeterminismRequired() || + DisableSegmentReductionOpDeterminismExceptions(); + OP_REQUIRES( + ctx, determinism_requirement_met, + errors::Unimplemented( + "Deterministic GPU implementation of unsorted segment reduction op" + " not available.")); + + // Launch kernel(s) to compute unsorted segment reduction. + // Notes: + // *) 'data_size' is the total number of elements to process. + // *) 'segment_ids.shape' is a prefix of data's shape. + // *) 'input_outer_dim_size' is the total number of segments to process. + const Index input_outer_dim_size = unsorted_segment_ids.dimension(0); + const Index input_inner_dim_size = data.dimension(1); + const Index output_outer_dim_size = output.dimension(0); + const Index num_segments = output.size() / input_inner_dim_size; + + // TODO(benbarsdell): If there are no performance concerns with the new + // deterministic kernels, remove this runtime check and the old + // non-deterministic kernels. + if (!use_deterministic_kernels) { + // Set 'output' to initial value. + GPUDevice d = ctx->template eigen_device(); + GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d); + TF_CHECK_OK(GpuLaunchKernel( + SetToValue, config.block_count, config.thread_per_block, 0, + d.stream(), output.size(), output.data(), InitialValueF()())); + const int64_t data_size = data.size(); + if (data_size == 0 || segment_ids_shape.num_elements() == 0) { + return; + } + config = GetGpuLaunchConfig(data_size, d); + TF_CHECK_OK(GpuLaunchKernel( + UnsortedSegmentCustomKernel< + T, Index, typename ReduceUpdateOpFor::atomic_op>, + config.block_count, config.thread_per_block, 0, d.stream(), + input_outer_dim_size, input_inner_dim_size, output_outer_dim_size, + unsorted_segment_ids.data(), data.data(), output.data())); + } else { + // Allocate temporary space and sort segment_ids, then call the sorted + // implem. + Tensor segment_ids; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + TensorShape({static_cast(input_outer_dim_size)}), + &segment_ids)); + Index* segment_ids_ptr = segment_ids.flat().data(); + Tensor sorted_indices; + OP_REQUIRES_OK( + ctx, ctx->allocate_temp( + DataTypeToEnum::value, + TensorShape({static_cast(input_outer_dim_size)}), + &sorted_indices)); + Index* sorted_indices_ptr = sorted_indices.flat().data(); + // Note: We must sort using all bits here because unsorted_segment_ids + // may contain negative values. 
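+      // For example, an id of -1 has its sign bit (and every other bit) set,
+      // so a reduced /*num_bits=*/ bound, as used elsewhere in this file for
+      // known-non-negative keys, could order it incorrectly. Negative ids are
+      // valid inputs here (rows with negative ids do not contribute to any
+      // output segment), so the sort below uses the full key width.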
+ OP_REQUIRES_OK( + ctx, GpuRadixSort(ctx, input_outer_dim_size, + /*keys_in=*/unsorted_segment_ids.data(), + /*keys_out=*/segment_ids_ptr, + /*indices_in=*/static_cast(nullptr), + /*indices_out=*/sorted_indices_ptr)); + using Treduce = typename ReduceType::type; + using Tweights = typename RealTypeIfComplex::type; + OP_REQUIRES_OK( + ctx, + SegmentReduceGPU( + ctx, input_outer_dim_size, input_inner_dim_size, num_segments, + ReductionF(), /*initial_value=*/InitialValueF()(), + /*empty_segment_value=*/InitialValueF()(), /*is_mean=*/false, + /*is_sqrtn=*/false, /*input=*/data.data(), + /*segment_ids=*/segment_ids_ptr, /*indices=*/sorted_indices_ptr, + /*weights=*/static_cast(nullptr), output.data())); + } + } +}; + +template +Status SparseSegmentReductionFunctor::operator()( + OpKernelContext* context, bool is_mean, bool is_sqrtn, T default_value, + typename TTypes::ConstTensor input, + typename TTypes::ConstVec indices, + typename TTypes::ConstVec segment_ids, + typename TTypes::Tensor output) { + using ReduceOp = functor::Sum; + using Treduce = typename ReduceType::type; + using Tweights = typename RealTypeIfComplex::type; + Index nouter = segment_ids.size(); + Index ninner = input.dimension(1); + SegmentId nsegments = output.dimension(0); + return SegmentReduceGPU( + context, /*nouter=*/nouter, /*ninner=*/ninner, + /*nsegments=*/nsegments, /*reduce_op=*/ReduceOp(), + /*initial_value=*/T(0), + /*empty_segment_value=*/default_value, + /*is_mean=*/is_mean, /*is_sqrtn=*/is_sqrtn, + /*input=*/input.data(), /*segment_ids=*/segment_ids.data(), + /*indices=*/indices.data(), /*weights=*/static_cast(nullptr), + /*output=*/output.data()); +} + +template +struct SparseSegmentGradFunctor { + void operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename TTypes::ConstMatrix input_flat, + typename TTypes::ConstVec indices_vec, + typename TTypes::ConstVec segment_vec, + Tensor* output) { + const GPUDevice& device = context->eigen_gpu_device(); + + auto output_flat = output->flat_outer_dims(); + const SegmentId nsegments = input_flat.dimension(0); + const Index ninner = input_flat.dimension(1); + const Index nouter = indices_vec.dimension(0); + const Index noutput = output_flat.dimension(0); + + // Allocate and compute segment weights (for Mean/SqrtN operations only). + Tensor weights; + using Tweights = typename RealTypeIfComplex::type; + Tweights* weights_ptr = nullptr; + if (operation != SparseSegmentReductionOperation::kSum) { + OP_REQUIRES_OK( + context, context->allocate_temp(DataTypeToEnum::value, + TensorShape({nsegments}), &weights)); + weights_ptr = weights.flat().data(); + // Allocate and compute segment_offsets. + Tensor segment_offsets; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nsegments + 1}), + &segment_offsets)); + Index* segment_offsets_ptr = segment_offsets.flat().data(); + OP_REQUIRES_OK(context, LaunchSegmentOffsetsKernel( + device, nouter, nsegments, segment_vec.data(), + segment_offsets_ptr)); + // Compute the weights based on the segment sizes using segment_offsets. + OP_REQUIRES_OK(context, LaunchSegmentWeightsKernel( + device, nsegments, operation, + segment_offsets_ptr, weights_ptr)); + } + + const Index* sorted_indices_ptr = indices_vec.data(); + const SegmentId* sorted_segment_ptr = segment_vec.data(); + Tensor tmp_sorted_indices; + Tensor tmp_sorted_segment; + if (noutput > 1) { + // Sort indices and permute segments. 
+ OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({nouter}), &tmp_sorted_indices)); + Index* tmp_sorted_indices_ptr = tmp_sorted_indices.flat().data(); + OP_REQUIRES_OK(context, context->allocate_temp( + DataTypeToEnum::value, + TensorShape({nouter}), &tmp_sorted_segment)); + SegmentId* tmp_sorted_segment_ptr = + tmp_sorted_segment.flat().data(); + OP_REQUIRES_OK(context, + GpuRadixSort(context, nouter, + /*keys_in=*/indices_vec.data(), + /*keys_out=*/tmp_sorted_indices_ptr, + /*indices_in=*/segment_vec.data(), + /*indices_out=*/tmp_sorted_segment_ptr, + /*num_bits=*/Log2Ceiling64(noutput))); + sorted_indices_ptr = tmp_sorted_indices_ptr; + sorted_segment_ptr = tmp_sorted_segment_ptr; + } + + // Compute the gradient using a weighted SegmentReduceGPU with the segment + // IDs and indices swapped. + using ReduceOp = functor::Sum; + using Treduce = typename ReduceType::type; + OP_REQUIRES_OK( + context, + SegmentReduceGPU( + context, /*nouter=*/static_cast(nouter), + /*ninner=*/static_cast(ninner), + /*nsegments=*/noutput, + /*reduce_op=*/ReduceOp(), + /*initial_value=*/T(0), + /*empty_segment_value=*/T(0), + /*is_mean=*/false, /*is_sqrtn=*/false, + /*input=*/input_flat.data(), /*segment_ids=*/sorted_indices_ptr, + /*indices=*/sorted_segment_ptr, /*weights=*/weights_ptr, + /*output=*/output_flat.data())); + } +}; + +template +struct EdgeIndicatorFunctor { + EdgeIndicatorFunctor(const TindicesCompact* sorted_indices) + : sorted_indices_(sorted_indices) {} + + template + __device__ bool operator()(Idx idx) const { + return idx == 0 ? false : sorted_indices_[idx] != sorted_indices_[idx - 1]; + } + + private: + const TindicesCompact* __restrict__ sorted_indices_; +}; + +template +__global__ void ScatterUniqueIndicesKernel( + Toffsets nouter, + EdgeIndicatorIter sorted_indices_edge_indicator, // [nouter] + const TindicesCompact* __restrict__ sorted_indices, // [nouter] + const Toffsets* __restrict__ sorted_indices_ids, // [nouter] + Tindices* __restrict__ sorted_unique_indices) { // [num_unique] + for (int i : GpuGridRangeX(nouter)) { + if (i == 0 || sorted_indices_edge_indicator[i]) { + sorted_unique_indices[sorted_indices_ids[i]] = + static_cast(sorted_indices[i]); + } + } +} + +template +Status LaunchScatterUniqueIndicesKernel( + const GPUDevice& d, Toffsets nouter, + EdgeIndicatorIter sorted_indices_edge_indicator, // [nouter] + const TindicesCompact* __restrict__ sorted_indices, // [nouter] + const Toffsets* __restrict__ sorted_indices_ids, // [nouter] + Tindices* __restrict__ sorted_unique_indices) { // [num_unique] + GpuLaunchConfig config = GetGpuLaunchConfig( + nouter, d, + &ScatterUniqueIndicesKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(ScatterUniqueIndicesKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), nouter, sorted_indices_edge_indicator, + sorted_indices, sorted_indices_ids, + sorted_unique_indices); +} + +template +struct SparseSegmentGradV2Functor { + void operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename TTypes::ConstMatrix input_flat, + typename TTypes::ConstVec indices_vec, + typename TTypes::ConstVec segment_vec, + const TensorShape& dense_output_shape, + typename AsyncOpKernel::DoneCallback done) { + const GPUDevice& device = context->eigen_gpu_device(); + + const int64_t nsegments = input_flat.dimension(0); + const int64_t ninner64 = input_flat.dimension(1); + const int64_t nouter64 = indices_vec.dimension(0); + // Note: 
nouter and ninner are not expected to be huge, so we use int32 to + // save memory bandwidth. + using Toffsets = int32; + OP_REQUIRES_ASYNC(context, nouter64 <= std::numeric_limits::max(), + absl::InvalidArgumentError( + absl::StrCat("Indices vector of length ", nouter64, + " is too large to fit in int32.")), + done); + const Toffsets nouter = static_cast(nouter64); + OP_REQUIRES_ASYNC(context, ninner64 <= std::numeric_limits::max(), + absl::InvalidArgumentError(absl::StrCat( + "Inner data dimension of size ", ninner64, + " is too large to fit in int32.")), + done); + const Toffsets ninner = static_cast(ninner64); + + // Cast indices to 32-bit to save memory bandwidth (the cost of the cast is + // worth it because the vector is used multiple times). + // Note that we can currently assume int32 is safe because the op's dense + // output_dim0 input is always int32. + using TindicesCompact = int32; + Tensor tmp_indices_internal; + const TindicesCompact* indices_internal_ptr; + if constexpr (std::is_same::value) { + indices_internal_ptr = indices_vec.data(); + } else { + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nouter}), &tmp_indices_internal), + done); + auto indices_vec_internal = tmp_indices_internal.flat(); + indices_vec_internal.device(device) = + indices_vec.template cast(); + indices_internal_ptr = indices_vec_internal.data(); + } + + // Cast segment IDs to smallest possible type to save memory bandwidth. + if (nsegments <= std::numeric_limits::max()) { + CastSegmentIdsThenImpl( + context, operation, nouter, ninner, nsegments, input_flat.data(), + tmp_indices_internal, indices_internal_ptr, segment_vec, + dense_output_shape, done); + } else if (sizeof(Tsegmentids) > sizeof(int32) && + nsegments <= std::numeric_limits::max()) { + CastSegmentIdsThenImpl( + context, operation, nouter, ninner, nsegments, input_flat.data(), + tmp_indices_internal, indices_internal_ptr, segment_vec, + dense_output_shape, done); + } else { + Impl( + context, operation, nouter, ninner, nsegments, input_flat.data(), + tmp_indices_internal, indices_internal_ptr, Tensor(), + segment_vec.data(), dense_output_shape, done); + } + } + + private: + using Tweights = typename RealTypeIfComplex::type; + + template + void CastSegmentIdsThenImpl( + OpKernelContext* context, SparseSegmentReductionOperation operation, + Toffsets nouter, Toffsets ninner, Tsegmentids_internal nsegments, + const T* input, Tensor indices_tensor, const TindicesCompact* indices, + typename TTypes::ConstVec segment_vec, + const TensorShape& dense_output_shape, + typename AsyncOpKernel::DoneCallback done) { + const GPUDevice& device = context->eigen_gpu_device(); + Tensor tmp_segment_internal; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nouter}), &tmp_segment_internal), + done); + auto segment_vec_internal = + tmp_segment_internal.flat(); + segment_vec_internal.device(device) = + segment_vec.template cast(); + + Impl( + context, operation, nouter, ninner, nsegments, input, indices_tensor, + indices, tmp_segment_internal, segment_vec_internal.data(), + dense_output_shape, done); + } + + template + void Impl(OpKernelContext* context, SparseSegmentReductionOperation operation, + Toffsets nouter, Toffsets ninner, Tsegmentids_internal nsegments, + const T* input, Tensor indices_tensor, + const TindicesCompact* indices, Tensor segment_ids_tensor, + const Tsegmentids_internal* segment_ids, + const TensorShape& dense_output_shape, + typename 
AsyncOpKernel::DoneCallback done) { + const int64_t dense_output_dim0 = dense_output_shape.dim_size(0); + + // Allocate and compute segment weights (for Mean/SqrtN operations only). + Tensor tmp_weights; + Tweights* weights_ptr = nullptr; + if (operation != SparseSegmentReductionOperation::kSum) { + ComputeSegmentWeights(context, operation, nsegments, nouter, segment_ids, + &tmp_weights, done); + weights_ptr = tmp_weights.flat().data(); + } + + const TindicesCompact* sorted_indices_ptr = indices; + const Tsegmentids_internal* permuted_segment_ptr = segment_ids; + Tensor tmp_sorted_indices; + Tensor tmp_permuted_segment; + if (dense_output_dim0 > 1) { + // Sort indices and permute segments. + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nouter}), &tmp_sorted_indices), + done); + TindicesCompact* tmp_sorted_indices_ptr = + tmp_sorted_indices.flat().data(); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nouter}), &tmp_permuted_segment), + done); + Tsegmentids_internal* tmp_permuted_segment_ptr = + tmp_permuted_segment.flat().data(); + OP_REQUIRES_OK_ASYNC( + context, + GpuRadixSort(context, nouter, + /*keys_in=*/indices, + /*keys_out=*/tmp_sorted_indices_ptr, + /*indices_in=*/segment_ids, + /*indices_out=*/tmp_permuted_segment_ptr, + /*num_bits=*/Log2Ceiling64(dense_output_dim0)), + done); + sorted_indices_ptr = tmp_sorted_indices_ptr; + permuted_segment_ptr = tmp_permuted_segment_ptr; + // The original tensors are no longer needed. + indices_tensor = Tensor(); + indices = nullptr; + segment_ids_tensor = Tensor(); + segment_ids = nullptr; + } + + using CountIter = gpuprim::CountingInputIterator; + using EdgeIndicatorIter = gpuprim::TransformInputIterator< + Toffsets, EdgeIndicatorFunctor, CountIter>; + EdgeIndicatorIter sorted_indices_edge_indicator( + CountIter(0), + EdgeIndicatorFunctor(sorted_indices_ptr)); + + Tensor tmp_sorted_indices_unique_ids; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nouter}), + &tmp_sorted_indices_unique_ids), + done); + Toffsets* sorted_indices_unique_ids_ptr = + tmp_sorted_indices_unique_ids.flat().data(); + OP_REQUIRES_OK_ASYNC( + context, + GpuInclusivePrefixSum(context, nouter, sorted_indices_edge_indicator, + sorted_indices_unique_ids_ptr), + done); + + se::Stream* stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, stream, + absl::InternalError("No GPU stream available."), done); + + // Copy the last element of sorted_indices_unique_ids back to the host to + // obtain num_unique. 
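+    // For example, with sorted_indices = [1, 1, 4, 4, 7] the edge indicator
+    // is [0, 0, 1, 0, 1] and its inclusive prefix sum is [0, 0, 1, 1, 2]:
+    // each row is tagged with the rank of its unique index, and the last
+    // element plus one (here 2 + 1 = 3) is the number of unique indices.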
+ ScratchSpace last_idx_host(context, 1, /*on_host=*/true); + OP_REQUIRES_OK_ASYNC( + context, + stream->Memcpy(last_idx_host.mutable_data(), + se::DeviceMemoryBase(const_cast( + sorted_indices_unique_ids_ptr) + + (nouter - 1), + sizeof(*last_idx_host.data())), + sizeof(*last_idx_host.data())), + done); + + auto async_finish_computation = + [this, context, dense_output_shape, nouter, ninner, input, + indices_tensor, tmp_sorted_indices, sorted_indices_ptr, + tmp_sorted_indices_unique_ids, sorted_indices_unique_ids_ptr, + segment_ids_tensor, tmp_permuted_segment, permuted_segment_ptr, + sorted_indices_edge_indicator, tmp_weights, weights_ptr, last_idx_host, + done]() -> void { + const GPUDevice& device = context->eigen_gpu_device(); + Toffsets num_unique = (*last_idx_host.data()) + 1; + + std::unique_ptr scoped_activation = + context->op_device_context()->stream()->parent()->Activate(); + + TensorShape output_shape = dense_output_shape; + OP_REQUIRES_OK_ASYNC(context, + output_shape.SetDimWithStatus(0, num_unique), done); + Tensor* output = nullptr; + T* output_ptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, output_shape, &output), done); + output_ptr = output->flat().data(); + + // Compute the gradient using a weighted SegmentReduceGPU with the segment + // IDs and indices swapped. + using ReduceOp = functor::Sum; + using Treduce = typename ReduceType::type; + OP_REQUIRES_OK_ASYNC(context, + SegmentReduceGPU( + context, /*nouter=*/nouter, + /*ninner=*/ninner, + /*nsegments=*/num_unique, + /*reduce_op=*/ReduceOp(), + /*initial_value=*/T(0), + /*empty_segment_value=*/T(0), + /*is_mean=*/false, /*is_sqrtn=*/false, + /*input=*/input, + /*segment_ids=*/sorted_indices_unique_ids_ptr, + /*indices=*/permuted_segment_ptr, + /*weights=*/weights_ptr, + /*output=*/output_ptr), + done); + + Tensor* sorted_unique_indices = nullptr; + Tindices* sorted_unique_indices_ptr; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output(1, TensorShape({num_unique}), + &sorted_unique_indices), + done); + sorted_unique_indices_ptr = + sorted_unique_indices->flat().data(); + + OP_REQUIRES_OK_ASYNC( + context, + LaunchScatterUniqueIndicesKernel( + device, nouter, sorted_indices_edge_indicator, sorted_indices_ptr, + sorted_indices_unique_ids_ptr, sorted_unique_indices_ptr), + done); + + done(); + }; + + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, async_finish_computation); + } + + template + void ComputeSegmentWeights(OpKernelContext* context, + SparseSegmentReductionOperation operation, + Tsegmentids_internal nsegments, Toffsets nouter, + const Tsegmentids_internal* segment_ids, + Tensor* tmp_weights, + typename AsyncOpKernel::DoneCallback done) { + const GPUDevice& device = context->eigen_gpu_device(); + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nsegments}), tmp_weights), + done); + Tweights* weights_ptr = tmp_weights->flat().data(); + // Allocate and compute segment_offsets. + Tensor tmp_segment_offsets; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({nsegments + 1}), + &tmp_segment_offsets), + done); + Toffsets* segment_offsets_ptr = tmp_segment_offsets.flat().data(); + OP_REQUIRES_OK_ASYNC( + context, + LaunchSegmentOffsetsKernel(device, nouter, nsegments, segment_ids, + segment_offsets_ptr), + done); + // Compute the weights based on the segment sizes using segment_offsets. 
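+    // SegmentWeightsKernel writes 1 / segment_size for kMean and
+    // 1 / sqrt(segment_size) for kSqrtN, clamping empty segments to size 1.
+    // E.g. segment_offsets = [0, 2, 2, 5] gives sizes [2, 0, 3] and kMean
+    // weights [0.5, 1.0, 1/3].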
+ OP_REQUIRES_OK_ASYNC( + context, + LaunchSegmentWeightsKernel(device, nsegments, operation, + segment_offsets_ptr, weights_ptr), + done); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_GPU_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops_impl.h new file mode 100644 index 00000000..d087bfae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/segment_reduction_ops_impl.h @@ -0,0 +1,1488 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/math_ops.cc. + +#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/platform/types.h" +#define EIGEN_USE_THREADS +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include "absl/container/flat_hash_map.h" +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/segment_reduction_ops.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/bfloat16.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/determinism.h" +#include "tensorflow/core/util/util.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#if GOOGLE_CUDA +#include "tensorflow/core/util/gpu_solvers.h" + +#elif TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/rocm.h" +#include "tensorflow/core/util/gpu_solvers.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace internal { + +absl::Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input, + const Tensor& segment_ids); +absl::Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, + OpKernelContext* context, + const Tensor& data, + const Tensor& segment_ids, + const 
Tensor& num_segments); +absl::Status ValidateSparseSegmentReduction(OpKernelContext* context, + const Tensor& input, + const Tensor& indices, + const Tensor& segment_ids, + bool has_num_segments); +} // namespace internal + +// This operator handles reducing segments along the first dimension. +// See core/ops/math_ops.cc for more details. +template +class SegmentReductionOp : public OpKernel { + public: + explicit SegmentReductionOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& segment_ids = context->input(1); + + OP_REQUIRES_OK(context, internal::ValidateSegmentReduction(context, input, + segment_ids)); + + const int64_t num_indices = segment_ids.NumElements(); + auto input_flat = input.flat_outer_dims(); + const int64_t num_col = input_flat.dimension(1); + + const auto segment_vec = segment_ids.vec(); + // Note that the current implementation assumes that segment_vec values are + // sorted. + const Index output_rows = + num_indices > 0 + ? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1 + : 0; + OP_REQUIRES(context, output_rows >= 0, + errors::InvalidArgument("segment ids must be >= 0")); + + OP_REQUIRES(context, input.dims() >= 1, + errors::InvalidArgument("Shape must be at least rank 1")); + + TensorShape output_shape = input.shape(); + // Since we're changing the first dimension of the shape, we need to make + // sure the new shape won't overflow. + OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, output_rows)); + + // Note that we do not initialize the output buffer with a default value, so + // we need to explicitly set missing indices to the default value. + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + if (num_indices == 0) return; + OP_REQUIRES(context, output_rows > 0, + errors::InvalidArgument("segment ids must be >= 0")); + auto output_flat = output->flat_outer_dims(); + + Eigen::IndexList > dims_to_reduce; + Index start = 0, end = 1; + + Index uninitialized_index = 0; // Index from which the output is not set. + Index out_index = internal::SubtleMustCopy(segment_vec(start)); + + // TODO(agarwal): if this loop becomes a bottleneck, consider sharding it + // across threads. + Eigen::DSizes out_slice_shape(num_col); + while (end <= num_indices) { + // We initialize next_index to 0 to avoid "warning: 'next_index' may be + // used uninitialized in this function" in the Mac build (since the + // compiler isn't smart enough to realize the code is safe). + Index next_index = 0; + if (end < num_indices) { + next_index = internal::SubtleMustCopy(segment_vec(end)); + if (out_index == next_index) { + ++end; + continue; + } + // We have a new segment here. Verify that the segment ids are growing. + OP_REQUIRES(context, out_index < next_index, + errors::InvalidArgument("segment ids are not increasing")); + } + + // Process segment [start, end) + const T* in_slice_ptr = &input_flat(start, 0); + typedef Eigen::TensorMap, + Eigen::Unaligned> + OutT; + + OP_REQUIRES( + context, FastBoundsCheck(out_index, output_rows), + errors::InvalidArgument( + "Segment id ", out_index, " out of range [0, ", output_rows, + "), possibly because 'segment_ids' input is not sorted.")); + + // If there is a gap between two indices, we need to set that gap to the + // default value. 
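+      // E.g. with sorted segment_ids = [0, 0, 3, 3], segments 1 and 2 have
+      // no input rows, so output rows 1 and 2 are filled with the default
+      // value here instead of being left uninitialized.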
+ if (out_index > uninitialized_index) { + Eigen::DSizes gap_slice_shape( + out_index - uninitialized_index, num_col); + Eigen::TensorMap, Eigen::Unaligned> + gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape); + gap_slice.setConstant(T(default_value)); + } + + T* out_slice_ptr = &output_flat(out_index, 0); + OutT out_slice(out_slice_ptr, out_slice_shape); + // We don't use out_slice.device(context->eigen_device) + // because these pieces of work are likely to be very small and + // the context switching overhead dwarfs any benefit we get from + // using another thread to do this work. + if (start == end - 1) { + typedef Eigen::TensorMap, + Eigen::Unaligned> + InT; + InT in_slice(in_slice_ptr, out_slice_shape); + out_slice = in_slice; + } else { + Eigen::DSizes in_slice_shape(end - start, + num_col); + typedef Eigen::TensorMap, + Eigen::Unaligned> + InT; + InT in_slice(in_slice_ptr, in_slice_shape); + + out_slice = in_slice.reduce(dims_to_reduce, Reducer()); + } + if (end >= num_indices) break; + start = end; + ++end; + uninitialized_index = out_index + 1; + out_index = next_index; + } + } +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// SegmentReductionGPUOp is a segment reduction operator implemented for GPU +// only. +// TODO: This implementation of SegmentReductionGPUOp is sometimes slower than +// its unsorted counterpart (mostly when problem size is small). +// This is due to the following two main reasons and a cost-effective way +// to resolve these problems is desirable. +// 1. Sorted segment reduction requires a memory transfer from device to host +// in order to know the size of the output dimension whereas unsorted +// segment reduction receives the size of the output dimension as an input +// parameter. +// 2. Sorted segment reduction is essentially a tiled version of unsorted +// segment reduction and therefore such optimization comes at an inherent +// cost. However such cost may not be justified when the problem size is +// small. When to use the tiled version or the untiled version depends on +// many factors including data alignments, ratio of calculation to memory +// traffic and obviously, the problem sizes. 
+template +class SegmentReductionGPUOp : public AsyncOpKernel { + public: + explicit SegmentReductionGPUOp(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + const Tensor& input = context->input(0); + const Tensor& segment_ids = context->input(1); + + OP_REQUIRES_ASYNC( + context, TensorShapeUtils::IsVector(segment_ids.shape()), + errors::InvalidArgument("segment_ids should be a vector."), done); + + OP_REQUIRES_ASYNC(context, input.dims() >= 1, + errors::InvalidArgument("Shape must be at least rank 1"), + done); + + const int64_t num_indices = segment_ids.NumElements(); + OP_REQUIRES_ASYNC( + context, num_indices == input.dim_size(0), + errors::InvalidArgument( + "segment_ids should be the same size as dimension 0 of" + " input."), + done); + + if (num_indices == 0) { + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, 0); + + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, output_shape, &output), done); + done(); + return; + } + + se::DeviceMemoryBase output_rows_device( + const_cast(segment_ids).template flat().data() + + (num_indices - 1)); + ScratchSpace output_rows_host(context, 1, /* on_host */ true); + + auto stream = context->op_device_context()->stream(); + OP_REQUIRES_OK_ASYNC(context, + stream->Memcpy(output_rows_host.mutable_data(), + output_rows_device, sizeof(Index)), + done); + + SegmentReductionFunctor functor_; + auto create_and_check_output = [context, output_rows_host, &input, + &segment_ids, &functor_, done]() { + // Ensure that within the callback, the proper GPU settings are + // configured. + auto stream = context->op_device_context()->stream(); + std::unique_ptr scoped_activation = + stream->parent()->Activate(); + + Index output_rows = *output_rows_host.data(); + output_rows++; + OP_REQUIRES_ASYNC(context, output_rows > 0, + errors::InvalidArgument("segment ids must be >= 0"), + done); + + TensorShape output_shape = input.shape(); + // Since we're changing the first dimension of the shape, we need to make + // sure the new shape won't overflow. + OP_REQUIRES_OK_ASYNC(context, + output_shape.SetDimWithStatus(0, output_rows), done); + + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, output_shape, &output), done); + + bool use_deterministic_kernels = + UseDeterministicSegmentReductions() || + (!SegmentReductionFunctor::atomic_reduction_is_associative && + OpDeterminismRequired()); + + // The determinism check is here, rather than inside the functor (as it is + // for the unsorted segment reduction ops) because the done callback + // (required for OP_REQUIRES_ASYNC) is not available inside the functor. 
+ bool determinism_requirement_met = + use_deterministic_kernels || + SegmentReductionFunctor::atomic_reduction_is_associative || + !OpDeterminismRequired() || + DisableSegmentReductionOpDeterminismExceptions(); + OP_REQUIRES_ASYNC( + context, determinism_requirement_met, + errors::Unimplemented( + "Deterministic GPU implementation of sorted segment reduction op" + " not available."), + done); + + auto output_flat = output->flat_outer_dims(); + auto data_ptr = input.template flat().data(); + auto segment_flat = segment_ids.flat(); + functor_(context, context->eigen_device(), output_rows, + segment_ids.shape(), IsMean, segment_flat, input.NumElements(), + data_ptr, output_flat); + + done(); + }; + + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, create_and_check_output); + } +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// ____________________________________________________________________________ +// Unsorted segment reduction ops. + +namespace functor { + +// The ReductionFunctor implementation for CPU. +template +struct UnsortedSegmentFunctor { + void operator()(OpKernelContext* ctx, const TensorShape& segment_ids_shape, + typename TTypes::ConstFlat segment_ids, + typename TTypes::ConstTensor data, + typename TTypes::Tensor output) { + auto cpu_device = ctx->eigen_cpu_device(); + output.device(cpu_device) = output.constant(InitialValueF()()); + if (data.size() == 0) { + return; + } + + // This functor will reduce `N` rows input to `num_segments` rows output. + const int64_t N = segment_ids.dimension(0); + const int64_t num_segments = output.dimension(0); + const int64_t inner_dim = data.dimension(1); + const T* data_ptr = data.data(); + T* out_ptr = output.data(); + ReductionF reduction; + + const bool is_inner_dim_1d = inner_dim == 1; + + // `num_real_segment` counts the rows actually reduced from input, + // the rows with negative segment index will be excluded. + // It will be used for cost model. + int64_t num_real_segment = N; + // `num_reductions` counts the rows actually reduced in output, + // the rows only filled with InitialValueF() will be excluded. + int64_t num_reductions = 0; + // `row_counter` records how many input rows will be reduced in each + // output row, the row only fills with InitialValueF() will keep 0. + // Length of non-zero elements is `num_reductions`. + std::vector row_counter(num_segments, 0); + + for (int64_t i = 0; i < N; ++i) { + Index j = internal::SubtleMustCopy(segment_ids(i)); + if (j < 0) { + --num_real_segment; + continue; + } + OP_REQUIRES(ctx, FastBoundsCheck(j, num_segments), + errors::InvalidArgument( + "segment_ids", SliceDebugString(segment_ids_shape, i), + " = ", j, " is out of range [0, ", num_segments, ")")); + if (row_counter[j] == 0) num_reductions++; + row_counter[j]++; + } + + // Nothing to reduce. All output values equal to `InitialValueF()`. + if (num_reductions == 0) return; + + // Parallelize by `num_segments`. It's simple, efficient and safe + // (no data dependency): + // + // input segment_ids num_segments operation + // | a0 | | 0 | worker 1: |0| f(a0, a1) + // | b0 | | 1 | worker 2: |1| f(b0, b1) + // N | c0 | | 2 | --> worker 3: |2| f(c0) + // | b1 | | 1 | + // | a1 | | 0 | + // + // TODO(intel-tf): Balance workload in `row_counter` to make parallelism + // more efficient. 
+ auto reductionWorker = [&](int64_t begin, int64_t end) -> void { + for (int64_t i = 0; i < N; i++) { + Index j = internal::SubtleMustCopy(segment_ids(i)); + // If `j` is in work scope of this worker, do the reduction. + if (j >= begin && j < end) { + reduction(data.template chip<0>(i), output.template chip<0>(j)); + } + } + }; + auto reductionWorker1D = [&](int64_t begin, int64_t end) -> void { + for (int64_t i = 0; i < N; i++) { + Index j = internal::SubtleMustCopy(segment_ids(i)); + // If `j` is in work scope of this worker, do the reduction. + if (j >= begin && j < end) { + reduction(data_ptr[i], out_ptr[j]); + } + } + }; + // Reduction functors includes Sum, Max, Min, etc. Simply consider it + // will cost 5 cycles per operation. + const int64_t kAverTaskSize = num_real_segment / num_segments; + const int64_t compute_cycles = 5 * inner_dim * kAverTaskSize; + const int64_t input_bytes = sizeof(T) * inner_dim * kAverTaskSize; + const int64_t output_bytes = sizeof(T) * inner_dim * kAverTaskSize; + const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles); + if (is_inner_dim_1d) { + cpu_device.parallelFor(num_segments, cost, reductionWorker1D); + } else { + cpu_device.parallelFor(num_segments, cost, reductionWorker); + } + } +}; + +template +using MatrixChip = Eigen::TensorChippingOp<0l, typename TTypes::Matrix>; + +template +using constMatrixChip = + Eigen::TensorChippingOp<0l, const typename TTypes::ConstMatrix>; + +// reduction functors +template +struct SumOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output += data; + } + void operator()(const T& data, T& output) { output += data; } +}; + +template +struct MaxOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output = data.cwiseMax(output); + } + void operator()(const T& data, T& output) { output = std::max(data, output); } +}; + +template +struct MinOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output = data.cwiseMin(output); + } + void operator()(const T& data, T& output) { output = std::min(data, output); } +}; + +template +struct ProdOp { + void operator()(const constMatrixChip data, MatrixChip output) { + output *= data; + } + void operator()(const T& data, T& output) { output *= data; } +}; +} // namespace functor + +// The UnsortedSegmentReduction OpKernel. The DeviceReductionFunctor +// is the device specific implementation of the reduction. These device +// specific implementations are templated themselves with the corresponding +// initial value functors and reduction functors. +template +class UnsortedSegmentReductionOp : public OpKernel { + public: + explicit UnsortedSegmentReductionOp(OpKernelConstruction* context) + : OpKernel(context), reduction_functor_(DeviceReductionFunctor()) {} + + void Compute(OpKernelContext* context) override { + const Tensor& data = context->input(0); + const Tensor& segment_ids = context->input(1); + const Tensor& num_segments = context->input(2); + OP_REQUIRES_OK(context, + internal::ValidateUnsortedSegmentReduction( + this, context, data, segment_ids, num_segments)); + const auto segment_flat = segment_ids.flat(); + const Index output_rows = internal::SubtleMustCopy(static_cast( + num_segments.dtype() == DT_INT32 ? 
num_segments.scalar()() + : num_segments.scalar()())); + OP_REQUIRES(context, output_rows >= 0, + errors::InvalidArgument("Input num_segments == ", output_rows, + " must not be negative.")); + TensorShape output_shape; + OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(output_rows)); + for (int i = segment_ids.dims(); i < data.dims(); i++) { + OP_REQUIRES_OK(context, output_shape.AddDimWithStatus(data.dim_size(i))); + } + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + auto output_flat = output->flat_outer_dims(); + auto data_flat = data.flat_inner_outer_dims(segment_ids.dims() - 1); + reduction_functor_(context, segment_ids.shape(), segment_flat, data_flat, + output_flat); + } + + protected: + DeviceReductionFunctor reduction_functor_; +}; + +// ____________________________________________________________________________ +// Sparse segment reduction ops. + +// Same as SegmentReductionOp but takes as input a "sparse" tensor, represented +// by two dense tensors, one containing the data, and the other containing +// indices into the data. +// +// The template parameters are: +// * Device: An Eigen device object, on which the kernel will execute. +// * T: The value type. +// * Index: The element type of the indices tensor (int32 or int64). +// * SegmentId: The element type of the segment_ids tensor (int32 or int64). +template +class SparseSegmentReductionOpBase : public OpKernel { + public: + explicit SparseSegmentReductionOpBase(OpKernelConstruction* context, + bool is_mean, bool is_sqrtn, + bool has_num_segments, T default_value) + : OpKernel(context), + dtidx_(DataTypeToEnum::v()), + is_mean_(is_mean), + is_sqrtn_(is_sqrtn), + has_num_segments_(has_num_segments), + default_value_(default_value) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& indices = context->input(1); + const Tensor& segment_ids = context->input(2); + + OP_REQUIRES_OK( + context, internal::ValidateSparseSegmentReduction( + context, input, indices, segment_ids, has_num_segments_)); + + Index output_rows = -1; + if (has_num_segments_) { + const Tensor& num_segments = context->input(3); + // Note that there is a Tnumsegments parameter on the op, but it is not + // plumbed through to here and so always takes its default value of int32. + output_rows = internal::SubtleMustCopy(num_segments.scalar()()); + } + const int64_t num_indices = indices.NumElements(); + + auto input_flat = input.flat_outer_dims(); + const int64_t num_col = input_flat.dimension(1); + const auto indices_vec = indices.vec(); + const auto segment_vec = segment_ids.vec(); + // Note that the current implementation assumes that segment_vec values are + // sorted. + const SegmentId last_segment_id = + num_indices > 0 ? segment_vec(num_indices - 1) : 0; + int64_t limit = dtidx_ == DataType::DT_INT32 ? kint32max : kint64max; + + OP_REQUIRES( + context, last_segment_id < limit, + errors::InvalidArgument("Last segment id must be < kintmax, got ", + last_segment_id, " limit ", limit)); + + const SegmentId last_segment_id_plus_one = + num_indices > 0 + ? 
internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1 + : 0; + + if (has_num_segments_) { + OP_REQUIRES( + context, output_rows >= last_segment_id_plus_one, + errors::InvalidArgument("segment ids must be < num_segments")); + } else { + output_rows = last_segment_id_plus_one; + } + OP_REQUIRES(context, output_rows >= 0, + errors::InvalidArgument("segment ids must be >= 0")); + + TensorShape output_shape = input.shape(); + OP_REQUIRES_OK( + context, output_shape.SetDimWithStatus(/*d=*/0, /*size=*/output_rows)); + + // Note that we do not initialize the output buffer with a default value, so + // we need to explicitly set missing indices to the default value. + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + if (num_indices == 0) { + if (output_rows > 0) { + output->flat_outer_dims().setConstant(default_value_); + } + return; + } + OP_REQUIRES(context, output_rows > 0, + errors::InvalidArgument("segment ids must be >= 0")); + auto output_flat = output->flat_outer_dims(); + + // If we use DT_BFLOAT16 or DT_HALF, we need to use DT_FLOAT for + // accumulation. We create a temp tensor to perform this accumulation for + // every segment. + Tensor temp; + if (input.dtype() == DT_BFLOAT16 || input.dtype() == DT_HALF) { + TensorShape temp_shape = output_shape; + OP_REQUIRES_OK(context, temp_shape.SetDimWithStatus(/*d=*/0, /*size=*/1)); + temp = tensorflow::Tensor(DT_FLOAT, temp_shape); + } + auto temp_flat = temp.flat_outer_dims(); + + int64_t start = 0, end = 1; + // Index from which the output is not initialized. + SegmentId uninitialized_index = 0; + SegmentId out_index = internal::SubtleMustCopy(segment_vec(start)); + + while (true) { + // We initialize next_index to 0 to avoid "warning: 'next_index' may be + // used uninitialized in this function" in the Mac build (since the + // compiler isn't smart enough to realize the code is safe). + SegmentId next_index = 0; + if (end < num_indices) { + next_index = internal::SubtleMustCopy(segment_vec(end)); + if (out_index == next_index) { + ++end; + continue; + } + // We have a new segment here. Verify that the segment ids are growing. + OP_REQUIRES(context, out_index < next_index, + errors::InvalidArgument("segment ids are not increasing")); + } + + OP_REQUIRES( + context, FastBoundsCheck(out_index, output_rows), + errors::InvalidArgument( + "Segment id ", out_index, " out of range [0, ", output_rows, + "), possibly because 'segment_ids' input is not sorted.")); + + // If there is a gap between two indices, we need to set that gap to the + // default value. + if (out_index > uninitialized_index) { + Eigen::DSizes gap_slice_shape( + out_index - uninitialized_index, num_col); + Eigen::TensorMap, Eigen::Unaligned> + gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape); + gap_slice.setConstant(default_value_); + } + + auto out = output_flat.template chip<0>(out_index); + auto temp = temp_flat.template chip<0>(0); + const int bad_offset = Reduce(input_flat, indices_vec, start, + end - start, out, temp); + OP_REQUIRES(context, bad_offset < 0, + errors::InvalidArgument( + "Bad: indices[", start + bad_offset, + "] == ", indices_vec(start + bad_offset), + " out of range [0, ", input_flat.dimension(0), ")")); + + start = end; + ++end; + uninitialized_index = out_index + 1; + out_index = next_index; + if (end > num_indices) break; + } + + // Fill the gap at the end with the default value. 
+ if (uninitialized_index < output_rows) { + Eigen::DSizes gap_slice_shape( + output_rows - uninitialized_index, num_col); + Eigen::TensorMap, Eigen::Unaligned> + gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape); + gap_slice.setConstant(default_value_); + } + } + + private: + const DataType dtidx_; + + template + using EnableIfBfloat16OrHalf = + typename std::enable_if::value || + std::is_same::value, + int>::type; + template + using EnableIfNotBfloat16OrHalf = + typename std::enable_if::value && + !std::is_same::value, + int>::type; + + template = 0> + EIGEN_ALWAYS_INLINE auto fetch_val( + const typename TTypes::ConstMatrix& input_flat, Tindex index) { + return input_flat.template chip<0>(index); + } + + template = 0> + EIGEN_ALWAYS_INLINE auto fetch_val( + const typename TTypes::ConstMatrix& input_flat, Tindex index) { + return input_flat.template chip<0>(index).template cast(); + } + + template + EIGEN_ALWAYS_INLINE Tout get_scaling_factor(int64_t num) { + Tout m(1); + if (is_mean_ && (num < 10)) { + m = Tout(num); + } + if (is_sqrtn_ && (num < 10)) { + m = Tout(sqrt(num)); + } + return Tout(1) / m; + } + + template = 0> + int64_t Reduce( + const typename TTypes::ConstMatrix& input_flat, + const typename TTypes::ConstVec& indices_vec, int64_t start, + int64_t num, Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> temp) { + return ReduceImpl(input_flat, indices_vec, start, num, + out, get_scaling_factor(num)); + } + + template = 0> + int64_t Reduce( + const typename TTypes::ConstMatrix& input_flat, + const typename TTypes::ConstVec& indices_vec, int64_t start, + int64_t num, Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> temp) { + int64_t res = + ReduceImpl(input_flat, indices_vec, start, num, + temp, get_scaling_factor(num)); + out = temp.template cast(); + return res; + } + + template + int64_t ReduceImpl( + const typename TTypes::ConstMatrix& input_flat, + const typename TTypes::ConstVec& indices_vec, int64_t start, + int64_t num, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + const Tout scaling_factor) { +#define INDEX(n, i) \ + const auto index##n = indices_vec(start + (i)); \ + if (!FastBoundsCheck(index##n, input_flat.dimension(0))) return (i); + +#define L(n) fetch_val(input_flat, index##n) + + if (num == 1) { + INDEX(0, 0); + out = L(0); + } else { + int64_t r = num & 7; + switch (r) { + case 2: { + INDEX(0, 0); + INDEX(1, 1); + out = (L(0) + L(1)) * scaling_factor; + break; + } + case 3: { + INDEX(0, 0); + INDEX(1, 1); + INDEX(2, 2); + out = (L(0) + L(1) + L(2)) * scaling_factor; + break; + } + case 4: { + INDEX(0, 0); + INDEX(1, 1); + INDEX(2, 2); + INDEX(3, 3); + out = (L(0) + L(1) + L(2) + L(3)) * scaling_factor; + break; + } + case 5: { + INDEX(0, 0); + INDEX(1, 1); + INDEX(2, 2); + INDEX(3, 3); + INDEX(4, 4); + out = (L(0) + L(1) + L(2) + L(3) + L(4)) * scaling_factor; + break; + } + case 6: { + INDEX(0, 0); + INDEX(1, 1); + INDEX(2, 2); + INDEX(3, 3); + INDEX(4, 4); + INDEX(5, 5); + out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5)) * scaling_factor; + break; + } + case 7: { + INDEX(0, 0); + INDEX(1, 1); + INDEX(2, 2); + INDEX(3, 3); + INDEX(4, 4); + INDEX(5, 5); + INDEX(6, 6); + out = + (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6)) * scaling_factor; + break; + } + case 0: { + INDEX(0, 0); + INDEX(1, 1); + INDEX(2, 2); + INDEX(3, 3); + INDEX(4, 4); + INDEX(5, 5); + INDEX(6, 6); + INDEX(7, 7); + out = (L(0) + L(1) + L(2) 
+ L(3) + L(4) + L(5) + L(6) + L(7)) * + scaling_factor; + r = 8; + break; + } + case 1: { + INDEX(0, 0); + INDEX(1, 1); + INDEX(2, 2); + INDEX(3, 3); + INDEX(4, 4); + INDEX(5, 5); + INDEX(6, 6); + INDEX(7, 7); + INDEX(8, 8); + out = (L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7) + L(8)) * + scaling_factor; + r = 9; + break; + } + } + for (; r < num; r += 8) { + INDEX(0, r); + INDEX(1, r + 1); + INDEX(2, r + 2); + INDEX(3, r + 3); + INDEX(4, r + 4); + INDEX(5, r + 5); + INDEX(6, r + 6); + INDEX(7, r + 7); + out += L(0) + L(1) + L(2) + L(3) + L(4) + L(5) + L(6) + L(7); + } + if (is_mean_ && num >= 10) { + out = out / static_cast(num); + } + if (is_sqrtn_ && num >= 10) { + out = out / static_cast(sqrt(num)); + } + } + + return -1; +#undef L +#undef INDEX + } + + const bool is_mean_; + const bool is_sqrtn_; + const bool has_num_segments_; + const T default_value_; +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +// Specialization for GPU. Must be Async because may need to wait for a host to +// device memcpy before allocating output. +template +class SparseSegmentReductionOpBase + : public AsyncOpKernel { + public: + explicit SparseSegmentReductionOpBase(OpKernelConstruction* context, + bool is_mean, bool is_sqrtn, + bool has_num_segments, T default_value) + : AsyncOpKernel(context), + is_mean_(is_mean), + is_sqrtn_(is_sqrtn), + has_num_segments_(has_num_segments), + default_value_(default_value) {} + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + const Tensor& input = context->input(0); + const Tensor& indices = context->input(1); + const Tensor& segment_ids = context->input(2); + + OP_REQUIRES_OK_ASYNC( + context, + internal::ValidateSparseSegmentReduction( + context, input, indices, segment_ids, has_num_segments_), + done); + + ScratchSpace last_segment_id_host(context, 1, /*on_host=*/true); + + auto create_and_check_output = [this, context, input, indices, segment_ids, + last_segment_id_host, done]() { + // Ensure that within the callback, the proper GPU settings are + // configured. + auto stream = context->op_device_context()->stream(); + std::unique_ptr scoped_activation = + stream->parent()->Activate(); + + SegmentId last_segment_id = *last_segment_id_host.data(); + SegmentId output_rows = last_segment_id + 1; + OP_REQUIRES_ASYNC(context, output_rows > 0, + errors::InvalidArgument("segment ids must be >= 0"), + done); + + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, output_rows); + + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, output_shape, &output), done); + + auto input_flat = input.flat_outer_dims(); + const auto indices_vec = indices.vec(); + const auto segment_ids_vec = segment_ids.vec(); + auto output_flat = output->flat_outer_dims(); + + functor::SparseSegmentReductionFunctor functor; + OP_REQUIRES_OK_ASYNC( + context, + functor(context, is_mean_, is_sqrtn_, default_value_, input_flat, + indices_vec, segment_ids_vec, output_flat), + done); + done(); + }; + + if (has_num_segments_) { + // No need to do any device to host memcpy, just compute synchronously. + const Tensor& num_segments_t = context->input(3); + SegmentId num_segments = + internal::SubtleMustCopy(num_segments_t.dtype() == DT_INT32 + ? 
num_segments_t.scalar()() + : num_segments_t.scalar()()); + *last_segment_id_host.mutable_data() = num_segments - 1; + create_and_check_output(); + } else { + const int64_t num_indices = indices.NumElements(); + if (num_indices == 0) { + TensorShape output_shape = input.shape(); + output_shape.set_dim(0, 0); + + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, output_shape, &output), done); + done(); + return; + } + + // Need to copy last element of segment_ids from device to host, and then + // asynchronously allocate the output and finish the computation. + se::DeviceMemoryBase last_segment_id_device( + const_cast(segment_ids).template flat().data() + + (num_indices - 1)); + auto stream = context->op_device_context()->stream(); + OP_REQUIRES_OK_ASYNC( + context, + stream->Memcpy(last_segment_id_host.mutable_data(), + last_segment_id_device, sizeof(SegmentId)), + done); + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, create_and_check_output); + } + } + + private: + const bool is_mean_; + const bool is_sqrtn_; + const bool has_num_segments_; + const T default_value_; +}; + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +class SparseSegmentReductionMeanOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context) + : SparseSegmentReductionOpBase( + context, true /*is_mean*/, false /*is_sqrtn*/, + false /* has_num_segments */, T(0) /* default_value */) {} +}; + +template +class SparseSegmentReductionMeanWithNumSegmentsOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionMeanWithNumSegmentsOp( + OpKernelConstruction* context) + : SparseSegmentReductionOpBase( + context, true /*is_mean*/, false /*is_sqrtn*/, + true /* has_num_segments */, T(0) /* default_value */) {} +}; + +template +class SparseSegmentReductionSqrtNOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context) + : SparseSegmentReductionOpBase( + context, false /*is_mean*/, true /*is_sqrtn*/, + false /* has_num_segments */, T(0) /* default_value */) {} +}; + +template +class SparseSegmentReductionSqrtNWithNumSegmentsOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionSqrtNWithNumSegmentsOp( + OpKernelConstruction* context) + : SparseSegmentReductionOpBase( + context, false /*is_mean*/, true /*is_sqrtn*/, + true /* has_num_segments */, T(0) /* default_value */) {} +}; + +template +class SparseSegmentReductionSumOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionSumOp(OpKernelConstruction* context) + : SparseSegmentReductionOpBase( + context, false /*is_mean*/, false /*is_sqrtn*/, + false /* has_num_segments */, T(0) /* default_value */) {} +}; + +template +class SparseSegmentReductionSumWithNumSegmentsOp + : public SparseSegmentReductionOpBase { + public: + explicit SparseSegmentReductionSumWithNumSegmentsOp( + OpKernelConstruction* context) + : SparseSegmentReductionOpBase( + context, false /*is_mean*/, false /*is_sqrtn*/, + true /* has_num_segments */, T(0) /* default_value */) {} +}; + +namespace functor { + +template +struct SparseSegmentGradFunctor { + void operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename TTypes::ConstMatrix input_flat, + typename TTypes::ConstVec indices_vec, + typename TTypes::ConstVec segment_vec, + Tensor* output) { + auto 
output_flat = output->flat_outer_dims(); + const int64_t N = indices_vec.size(); + const SegmentId M = output_flat.dimension(0); + + // Note that similar to SparseSegmentMean, we assume that segment_vec is + // already sorted and has non-negative values. + const SegmentId num_segments = input_flat.dimension(0); + const SegmentId last_segment_id_plus_one = + internal::SubtleMustCopy(segment_vec(N - 1)) + 1; + OP_REQUIRES(context, last_segment_id_plus_one <= num_segments, + absl::InvalidArgumentError("Invalid number of segments")); + + const auto scaling_or = + ComputeScalingFactors(operation, segment_vec, num_segments); + OP_REQUIRES_OK(context, scaling_or.status()); + const std::vector& scaling = scaling_or.value(); + + // If we use DT_BFLOAT16 or DT_HALF, we need to use DT_FLOAT for + // accumulation. We create a temp tensor to perform this accumulation for + // every segment. + Tensor temp; + if (output->dtype() == DT_BFLOAT16 || output->dtype() == DT_HALF) { + temp = tensorflow::Tensor(DT_FLOAT, output->shape()); + } + auto temp_flat = temp.flat_outer_dims(); + + if (output->dtype() == DT_BFLOAT16 || output->dtype() == DT_HALF) { + temp_flat.setZero(); + } else { + output_flat.setZero(); + } + + for (int64_t i = 0; i < N; ++i) { + const Index output_idx = internal::SubtleMustCopy(indices_vec(i)); + OP_REQUIRES(context, FastBoundsCheck(output_idx, M), + absl::InvalidArgumentError(absl::StrCat( + "Index ", output_idx, " out of range [0, ", M, ")."))); + + const SegmentId idx = internal::SubtleMustCopy(segment_vec(i)); + OP_REQUIRES( + context, FastBoundsCheck(idx, num_segments), + absl::InvalidArgumentError(absl::StrCat( + "Segment id ", idx, " out of range [0, ", num_segments, ")."))); + + const double scale = operation == SparseSegmentReductionOperation::kSum + ? 1.0 + : scaling[idx]; + Accumulate(input_flat.template chip<0>(idx), scale, + output_flat.template chip<0>(output_idx), + temp_flat.template chip<0>(output_idx)); + } + + // Copy the contents of the temp tensor to the output tensor. + if (output->dtype() == DT_BFLOAT16 || output->dtype() == DT_HALF) { + output_flat = temp_flat.template cast(); + } + } + + private: + template + using EnableIfBfloat16OrHalf = + typename std::enable_if::value || + std::is_same::value, + int>::type; + template + using EnableIfNotBfloat16OrHalf = + typename std::enable_if::value && + !std::is_same::value, + int>::type; + + template = 0> + void Accumulate( + Eigen::TensorChippingOp<0, const typename TTypes::ConstMatrix> in, + double scale, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> temp) { + out += in * static_cast(scale); + } + + template = 0> + void Accumulate( + Eigen::TensorChippingOp<0, const typename TTypes::ConstMatrix> in, + double scale, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> out, + Eigen::TensorChippingOp<0, typename TTypes::Matrix> temp) { + temp += in.template cast() * static_cast(scale); + } + + // Compute scaling factors for input. 
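+  // E.g. for segment_vec = [0, 0, 1] and num_segments = 2 the per-segment
+  // counts are [2, 1], giving scaling [0.5, 1.0] for kMean and
+  // [1/sqrt(2), 1.0] for kSqrtN; for kSum this returns an empty vector and a
+  // scale of 1.0 is used for every row.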
+ absl::StatusOr> ComputeScalingFactors( + SparseSegmentReductionOperation operation, + typename TTypes::ConstVec segment_vec, + const SegmentId num_segments) { + if (operation == SparseSegmentReductionOperation::kSum) { + return std::vector(0); + } + + std::vector scaling(num_segments, 0); + + for (int64_t i = 0; i < segment_vec.size(); ++i) { + const SegmentId idx = internal::SubtleMustCopy(segment_vec(i)); + if (!FastBoundsCheck(idx, num_segments)) { + return absl::InvalidArgumentError(absl::StrCat( + "Segment id ", idx, " out of range [0, ", num_segments, ").")); + } + scaling[idx] += 1; + } + + if (operation == SparseSegmentReductionOperation::kMean) { + for (size_t i = 0; i < scaling.size(); ++i) { + scaling[i] = 1.0 / std::max(scaling[i], 1.0); + } + } else { + for (size_t i = 0; i < scaling.size(); ++i) { + scaling[i] = 1.0 / sqrt(std::max(scaling[i], 1.0)); + } + } + + return scaling; + } +}; + +template +struct SparseSegmentGradV2Functor { + void operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename TTypes::ConstMatrix input_flat, + typename TTypes::ConstVec indices_vec, + typename TTypes::ConstVec segment_vec, + const TensorShape& dense_output_shape, + typename AsyncOpKernel::DoneCallback /*done*/) { + const int64_t N = indices_vec.size(); + const int64_t M = dense_output_shape.dim_size(0); + const SegmentId num_segments = input_flat.dimension(0); + const SegmentId last_segment_id_plus_one = + internal::SubtleMustCopy(segment_vec(N - 1)) + 1; + // Note: We do bounds-checking up front here so that it operates in the same + // order as the V1 implementation. + OP_REQUIRES(context, last_segment_id_plus_one <= num_segments, + errors::InvalidArgument("Invalid number of segments")); + for (int64_t i = 0; i < N; ++i) { + const Index output_idx = internal::SubtleMustCopy(indices_vec(i)); + OP_REQUIRES(context, FastBoundsCheck(output_idx, M), + errors::InvalidArgument("Index ", output_idx, + " out of range [0, ", M, ").")); + const SegmentId segment_id = internal::SubtleMustCopy(segment_vec(i)); + OP_REQUIRES( + context, FastBoundsCheck(segment_id, num_segments), + errors::InvalidArgument("Segment id ", segment_id, + " out of range [0, ", num_segments, ").")); + } + + std::vector permutation; + permutation.reserve(N); + for (int64_t i = 0; i < N; ++i) { + permutation.push_back(i); + } + std::stable_sort( + permutation.begin(), permutation.end(), + [&](Index a, Index b) { return indices_vec(a) < indices_vec(b); }); + std::vector sorted_indices; + std::vector permuted_segments; + sorted_indices.reserve(N); + permuted_segments.reserve(N); + for (Index j : permutation) { + sorted_indices.push_back(indices_vec(j)); + permuted_segments.push_back(segment_vec(j)); + } + + // Maps indices to unique index IDs. + absl::flat_hash_map unique_indices_map; + // The unique ID for each original index. + std::vector unique_index_ids; + unique_index_ids.reserve(N); + for (Index output_idx : sorted_indices) { + auto iter = + unique_indices_map.emplace(output_idx, unique_indices_map.size()) + .first; + Index unique_id = iter->second; + unique_index_ids.push_back(unique_id); + } + const int64_t num_unique = unique_indices_map.size(); + + // The original index for each unique ID. 
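For the V2 gradient path above, the output is compacted to the unique indices: a permutation is stable-sorted by index value, each distinct index gets a dense id in first-seen order, and the segment ids are carried along in the same order before the V1 functor is invoked. A standalone sketch of that deduplication step, assuming std::vector inputs; the struct and function names are illustrative only.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct DedupResult {
  std::vector<int64_t> unique_indices;    // original index for each unique id
  std::vector<int64_t> unique_index_ids;  // dense id for each sorted element
  std::vector<int64_t> permuted_segments; // segment ids in the sorted order
};

DedupResult DedupIndices(const std::vector<int64_t>& indices,
                         const std::vector<int64_t>& segment_ids) {
  const int64_t n = static_cast<int64_t>(indices.size());
  std::vector<int64_t> perm(n);
  for (int64_t i = 0; i < n; ++i) perm[i] = i;
  // Stable sort keeps duplicate indices in their original relative order.
  std::stable_sort(perm.begin(), perm.end(),
                   [&](int64_t a, int64_t b) { return indices[a] < indices[b]; });

  DedupResult r;
  std::unordered_map<int64_t, int64_t> id_of_index;
  for (int64_t j : perm) {
    auto it = id_of_index.emplace(indices[j],
                                  static_cast<int64_t>(id_of_index.size())).first;
    if (it->second == static_cast<int64_t>(r.unique_indices.size())) {
      r.unique_indices.push_back(indices[j]);  // first time this index appears
    }
    r.unique_index_ids.push_back(it->second);
    r.permuted_segments.push_back(segment_ids[j]);
  }
  return r;
}

int main() {
  DedupResult r = DedupIndices(/*indices=*/{5, 3, 5, 7}, /*segment_ids=*/{0, 0, 1, 1});
  for (int64_t v : r.unique_indices) std::cout << v << ' ';    // 3 5 7
  std::cout << '\n';
  for (int64_t v : r.unique_index_ids) std::cout << v << ' ';  // 0 1 1 2
  std::cout << '\n';
}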
+ Tensor* unique_indices = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(1, {num_unique}, &unique_indices)); + typename TTypes::Vec unique_indices_vec = + unique_indices->vec(); + for (const auto& idx_and_id : unique_indices_map) { + unique_indices_vec(idx_and_id.second) = idx_and_id.first; + } + + TensorShape output_shape = dense_output_shape; + OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, num_unique)); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + // Call the V1 implementation with the unique/permuted indices/segments. + typename TTypes::ConstVec unique_index_ids_vec( + unique_index_ids.data(), unique_index_ids.size()); + typename TTypes::ConstVec permuted_segment_vec( + permuted_segments.data(), permuted_segments.size()); + SparseSegmentGradFunctor()( + context, operation, input_flat, unique_index_ids_vec, + permuted_segment_vec, output); + } +}; + +} // namespace functor + +// Implements the common logic for the gradients of SparseSegmentReduction +// kernels. +// +// The template parameters are: +// * Device: An Eigen device object, on which the kernel will execute. +// * T: The value type. +// * Index: The element type of the indices tensor (int32 or int64). +// * SegmentId: The element type of the segment_ids tensor (int32 or int64). +template +class SparseSegmentGradOpBase : public OpKernel { + public: + explicit SparseSegmentGradOpBase(OpKernelConstruction* context, + SparseSegmentReductionOperation operation) + : OpKernel(context), operation_(operation) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& indices = context->input(1); + const Tensor& segment_ids = context->input(2); + const Tensor& output_dim0 = context->input(3); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()), + errors::InvalidArgument("indices should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()), + errors::InvalidArgument("segment_ids should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(output_dim0.shape()), + errors::InvalidArgument("output_dim0 should be a scalar.")); + + const int64_t N = indices.NumElements(); + OP_REQUIRES(context, N == segment_ids.NumElements(), + errors::InvalidArgument( + "segment_ids and indices should have same size.")); + const SegmentId M = internal::SubtleMustCopy(output_dim0.scalar()()); + + auto input_flat = input.flat_outer_dims(); + const auto indices_vec = indices.vec(); + const auto segment_vec = segment_ids.vec(); + + TensorShape output_shape = input.shape(); + OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(0, M)); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + if (M == 0 || N == 0) return; + + functor::SparseSegmentGradFunctor()( + context, operation_, input_flat, indices_vec, segment_vec, output); + } + + private: + const SparseSegmentReductionOperation operation_; +}; + +template +class SparseSegmentSumGradOp + : public SparseSegmentGradOpBase { + public: + explicit SparseSegmentSumGradOp(OpKernelConstruction* context) + : SparseSegmentGradOpBase( + context, SparseSegmentReductionOperation::kSum) {} +}; + +template +class SparseSegmentMeanGradOp + : public SparseSegmentGradOpBase { + public: + explicit SparseSegmentMeanGradOp(OpKernelConstruction* context) + : SparseSegmentGradOpBase( + context, SparseSegmentReductionOperation::kMean) {} +}; + +template +class 
SparseSegmentSqrtNGradOp + : public SparseSegmentGradOpBase { + public: + explicit SparseSegmentSqrtNGradOp(OpKernelConstruction* context) + : SparseSegmentGradOpBase( + context, SparseSegmentReductionOperation::kSqrtN) {} +}; + +template +class SparseSegmentGradV2OpCommon { + public: + absl::Status operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename AsyncOpKernel::DoneCallback done = nullptr) { + const Tensor& input = context->input(0); + const Tensor& indices = context->input(1); + const Tensor& segment_ids = context->input(2); + const Tensor& dense_output_dim0 = context->input(3); + + if (!TensorShapeUtils::IsVector(indices.shape())) { + return errors::InvalidArgument("indices should be a vector."); + } + if (!TensorShapeUtils::IsVector(segment_ids.shape())) { + return errors::InvalidArgument("segment_ids should be a vector."); + } + if (!TensorShapeUtils::IsScalar(dense_output_dim0.shape())) { + return errors::InvalidArgument("dense_output_dim0 should be a scalar."); + } + + const int64_t N = indices.NumElements(); + if (N != segment_ids.NumElements()) { + return errors::InvalidArgument( + "segment_ids and indices should have same size."); + } + const int32_t M = + internal::SubtleMustCopy(dense_output_dim0.scalar()()); + TensorShape dense_output_shape = input.shape(); + TF_RETURN_IF_ERROR(dense_output_shape.SetDimWithStatus(0, M)); + + if (M == 0 || N == 0) { + TensorShape output_shape = input.shape(); + TF_RETURN_IF_ERROR(output_shape.SetDimWithStatus(0, 0)); + Tensor* output = nullptr; + TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output)); + Tensor* sorted_unique_indices = nullptr; + TF_RETURN_IF_ERROR(context->allocate_output(1, TensorShape({0}), + &sorted_unique_indices)); + return absl::OkStatus(); + } + + auto input_flat = input.flat_outer_dims(); + const auto indices_vec = indices.vec(); + const auto segment_vec = segment_ids.vec(); + + functor::SparseSegmentGradV2Functor()( + context, operation, input_flat, indices_vec, segment_vec, + dense_output_shape, done); + + return absl::OkStatus(); + } +}; + +template +class SparseSegmentGradV2OpBase {}; + +// The CPU implementation is synchronous. +template +class SparseSegmentGradV2OpBase + : public OpKernel { + public: + explicit SparseSegmentGradV2OpBase(OpKernelConstruction* context, + SparseSegmentReductionOperation operation) + : OpKernel(context), operation_(operation) {} + + void Compute(OpKernelContext* context) override { + OP_REQUIRES_OK( + context, (SparseSegmentGradV2OpCommon()( + context, operation_))); + } + + private: + const SparseSegmentReductionOperation operation_; +}; + +// The GPU implementation is asynchronous. 
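The wrapper above dispatches on the device type: the primary template is empty, the CPU specialization is a synchronous kernel, and the GPU specialization that follows is asynchronous and threads a completion callback through the shared logic. A minimal sketch of that specialization pattern, using invented device tags and a stand-in helper rather than the real OpKernel/AsyncOpKernel classes, to show why the common code takes an optional done callback.

#include <functional>
#include <iostream>
#include <utility>

struct CpuDevice {};  // stand-ins for the Eigen device tags
struct GpuDevice {};

using DoneCallback = std::function<void()>;

// Shared logic; the callback is only used by the asynchronous flavor.
void RunCommon(const char* who, DoneCallback done = nullptr) {
  std::cout << who << " running the common grad-v2 logic\n";
  if (done) done();  // on GPU this would fire after a stream event completes
}

template <typename Device>
class GradV2Op {};  // primary template intentionally empty

template <>
class GradV2Op<CpuDevice> {  // synchronous
 public:
  void Compute() { RunCommon("CPU"); }
};

template <>
class GradV2Op<GpuDevice> {  // asynchronous
 public:
  void ComputeAsync(DoneCallback done) { RunCommon("GPU", std::move(done)); }
};

int main() {
  GradV2Op<CpuDevice>().Compute();
  GradV2Op<GpuDevice>().ComputeAsync([] { std::cout << "done callback\n"; });
}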
+template +class SparseSegmentGradV2OpBase + : public AsyncOpKernel { + public: + explicit SparseSegmentGradV2OpBase(OpKernelConstruction* context, + SparseSegmentReductionOperation operation) + : AsyncOpKernel(context), operation_(operation) {} + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + OP_REQUIRES_OK_ASYNC( + context, + (SparseSegmentGradV2OpCommon()( + context, operation_, done)), + done); + } + + private: + const SparseSegmentReductionOperation operation_; +}; + +template +class SparseSegmentSumGradV2Op + : public SparseSegmentGradV2OpBase { + public: + explicit SparseSegmentSumGradV2Op(OpKernelConstruction* context) + : SparseSegmentGradV2OpBase( + context, SparseSegmentReductionOperation::kSum) {} +}; + +template +class SparseSegmentMeanGradV2Op + : public SparseSegmentGradV2OpBase { + public: + explicit SparseSegmentMeanGradV2Op(OpKernelConstruction* context) + : SparseSegmentGradV2OpBase( + context, SparseSegmentReductionOperation::kMean) {} +}; + +template +class SparseSegmentSqrtNGradV2Op + : public SparseSegmentGradV2OpBase { + public: + explicit SparseSegmentSqrtNGradV2Op(OpKernelConstruction* context) + : SparseSegmentGradV2OpBase( + context, SparseSegmentReductionOperation::kSqrtN) {} +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sendrecv_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sendrecv_ops.h new file mode 100644 index 00000000..34f27d10 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sendrecv_ops.h @@ -0,0 +1,58 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +class SendOp : public OpKernel { + public: + explicit SendOp(OpKernelConstruction* ctx); + void Compute(OpKernelContext* ctx) override; + + string TraceString(const OpKernelContext& ctx, bool verbose) const override; + + private: + string key_prefix_; + Rendezvous::ParsedKey parsed_key_; + bool hostmem_sendrecv_; + + SendOp(const SendOp&) = delete; + void operator=(const SendOp&) = delete; +}; + +class RecvOp : public AsyncOpKernel { + public: + explicit RecvOp(OpKernelConstruction* ctx); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + string TraceString(const OpKernelContext& ctx, bool verbose) const override; + + private: + string key_prefix_; + Rendezvous::ParsedKey parsed_key_; + bool hostmem_sendrecv_; + + RecvOp(const RecvOp&) = delete; + void operator=(const RecvOp&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sequence_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sequence_ops.h new file mode 100644 index 00000000..fc81643c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sequence_ops.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SEQUENCE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_SEQUENCE_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +namespace functor { + +template +struct RangeFunctor { + void operator()(OpKernelContext* context, int64_t size, T start, T delta, + typename TTypes::Flat output) const; +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SEQUENCE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/shape_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/shape_ops.h new file mode 100644 index 00000000..d9c64c76 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/shape_ops.h @@ -0,0 +1,269 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_ + +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/variant_op_registry.h" + +namespace tensorflow { + +namespace shape_op_helpers { +inline absl::Status GetShape(OpKernelContext* ctx, int input_index, + TensorShape* shape) { + *shape = ctx->input(input_index).shape(); + return absl::OkStatus(); +} +} // namespace shape_op_helpers + +template +class ShapeOp : public OpKernel { + public: + explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape)); + const int rank = shape.dims(); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out)); + auto vec = out->vec(); + for (int i = 0; i < rank; ++i) { + int64_t dim_size = shape.dim_size(i); + if (out->dtype() == DT_INT32) { + OP_REQUIRES( + ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), + errors::InvalidArgument("Shape output type is 32-bit ", " but dim ", + i, " is ", dim_size)); + } + vec(i) = static_cast(dim_size); + } + } + + bool IsExpensive() override { return false; } +}; + +template +class ShapeNOp : public OpKernel { + public: + explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + for (int i = 0; i < ctx->num_inputs(); ++i) { + TensorShape shape; + OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, i, &shape)); + const int dims = shape.dims(); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out)); + auto vec = out->vec(); + + for (int j = 0; j < dims; ++j) { + int64_t dim_size = shape.dim_size(j); + if (out->dtype() == DT_INT32) { + OP_REQUIRES( + ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), + errors::InvalidArgument("ShapeN output type is 32-bit but shape ", + i, " dim ", j, " is ", dim_size)); + } + vec(j) = static_cast(dim_size); + } + } + } + + bool IsExpensive() override { return false; } +}; + +class RankOp : public OpKernel { + public: + explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape)); + const int rank = shape.dims(); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); + out->scalar()() = rank; + } + + bool IsExpensive() override { return false; } +}; + +template +class SizeOp : public OpKernel { + public: + explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape)); + const int64_t size = shape.num_elements(); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out)); + if (out->dtype() == DT_INT32) { + OP_REQUIRES( + ctx, FastBoundsCheck(size, std::numeric_limits::max()), + 
errors::InvalidArgument("Number of elements was larger than " + "representable by 32-bit output type")); + } + out->scalar()() = static_cast(size); + } + + bool IsExpensive() override { return false; } +}; + +template +class ExpandDimsOp : public OpKernel { + public: + explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& input_t = ctx->input(0); + OP_REQUIRES(ctx, input_t.dtype() != DT_VARIANT, + errors::InvalidArgument("ExpandDims on Variant not supported")); + + const Tensor& dim_t = ctx->input(1); + OP_REQUIRES( + ctx, (dim_t.NumElements() == 1), + errors::InvalidArgument("'dim' must be a tensor with a single value")); + DCHECK_EQ(dim_t.dtype(), DataTypeToEnum::v()); + Tdim dim = *static_cast(DMAHelper::base(&dim_t)); + const TensorShape& input_shape = input_t.shape(); + int input_dims = input_shape.dims(); + OP_REQUIRES(ctx, dim >= -1 - input_dims && dim <= input_dims, + errors::InvalidArgument("Tried to expand dim index ", dim, + " for tensor with ", input_dims, + " dimensions.")); + + // We emulate numpy's interpretation of the dim axis when + // -input.dims() >= dim <= input.dims(). + if (dim < 0) { + // Clamp to the end if needed. + dim = std::min(dim + input_dims + 1, input_dims); + } + + // Compute new shape with an additional dimension. + absl::InlinedVector output_shape_vec(input_dims + 1); + for (int64_t i = 0; i < dim; ++i) { + output_shape_vec[i] = input_shape.dim_size(i); + } + output_shape_vec[dim] = 1; + for (int64_t i = dim + 1; i < input_dims + 1; ++i) { + output_shape_vec[i] = input_shape.dim_size(i - 1); + } + TensorShape output_shape(output_shape_vec); + + Tensor output_t; + if (!output_t.CopyFrom(input_t, output_shape)) { + // This should never happen, since the sizes of the input and output + // should always be the same (we only expand the dimension with 1). + ctx->SetStatus( + errors::Internal("Could not expand dimension with input shape ", + ctx->input(0).shape().DebugString(), + " and output shape ", output_shape.DebugString())); + } + ctx->set_output(0, std::move(output_t)); + } + + bool IsExpensive() override { return false; } +}; + +class SqueezeOp : public OpKernel { + public: + explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + std::vector squeeze_dims; + OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); + squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); + } + + void Compute(OpKernelContext* ctx) override { + OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT, + errors::InvalidArgument("Squeeze on Variant not supported")); + + auto existing_dims = ctx->input(0).shape().dim_sizes(); + const int existing_dims_size = static_cast(existing_dims.size()); + std::vector new_shape; + + std::unordered_set wrapped_squeeze_dims; + wrapped_squeeze_dims.reserve(squeeze_dims_.size()); + // Validate squeeze dims against the input. + for (int32_t dim : squeeze_dims_) { + OP_REQUIRES( + ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()), + errors::InvalidArgument("Tried to squeeze dim index ", dim, + " for tensor with ", ctx->input(0).dims(), + " dimensions.")); + // If dim is < 0, we wrap around (-1 means the last element). + if (dim < 0) { + dim = existing_dims_size + dim; + } + + wrapped_squeeze_dims.insert(dim); + } + + for (int i = 0; i < existing_dims_size; ++i) { + auto existing_dim = existing_dims[i]; + + // If squeeze_set is non-empty, only squeeze those dimensions. 
+ if (!wrapped_squeeze_dims.empty()) { + if (wrapped_squeeze_dims.count(i) > 0) { + OP_REQUIRES(ctx, existing_dim == 1, + errors::InvalidArgument( + "Can not squeeze dim[", i, + "], expected a dimension of 1, got ", existing_dim)); + } else { + // This dimension is not being squeezed. + new_shape.push_back(existing_dim); + } + } else { + // Copy over all non-1-length dimensions. + if (existing_dim != 1) { + new_shape.push_back(existing_dim); + } + } + } + + const TensorShape output_shape(new_shape); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output)); + if (!output->CopyFrom(ctx->input(0), output_shape)) { + // This should never happen, since the sizes of the input and + // output should always be the same. + ctx->SetStatus(errors::Internal("Could not squeeze input with shape ", + ctx->input(0).shape().DebugString(), + " and output shape ", + output_shape.DebugString())); + } + } + + bool IsExpensive() override { return false; } + + private: + std::unordered_set squeeze_dims_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/shuffle_common.h b/third_party/tflite-hdrs/tensorflow/core/kernels/shuffle_common.h new file mode 100644 index 00000000..0eea7fd4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/shuffle_common.h @@ -0,0 +1,102 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Common utilities for random shuffling. + +#ifndef TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +// TODO(irving): If performance is critical, generate output directly instead +// of an in-place shuffle using a pseudorandom permutation like +// +// https://github.com/otherlab/geode/blob/master/geode/random/permute.cpp +// +// This is probably also the right thing if we want a GPU version of shuffling. + +// We use our own version of std::random_shuffle to guarantee that exactly +// size - 1 samples are used. 
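The comment above promises that the hand-rolled shuffle draws exactly size - 1 random samples: it is a Fisher-Yates pass that swaps position i with a uniform pick from the remaining tail and stops one element short of the end. A standalone sketch with a counting RNG wrapper to make the sample count visible; it assumes std::mt19937 rather than the Philox generator the kernel uses.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

int main() {
  std::mt19937 gen(42);
  int64_t samples = 0;
  // uniform(n) returns a value in [0, n), counting how often it is called.
  auto uniform = [&](uint64_t n) -> uint64_t {
    ++samples;
    return std::uniform_int_distribution<uint64_t>(0, n - 1)(gen);
  };

  std::vector<int> v{0, 1, 2, 3, 4, 5, 6, 7};
  // Fisher-Yates: swap element i with a uniformly chosen element in [i, end).
  for (size_t i = 0; i + 1 < v.size(); ++i) {
    std::swap(v[i], v[i + uniform(v.size() - i)]);
  }

  assert(samples == static_cast<int64_t>(v.size()) - 1);
  for (int x : v) std::cout << x << ' ';
  std::cout << "\n(" << samples << " samples drawn)\n";
}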
+template +static inline void ShuffleRange(Iter first, Iter last, Random& uniform) { + if (first == last) return; + const auto stop = last - 1; + for (auto i = first; i != stop; ++i) { + using std::iter_swap; + iter_swap(i, i + uniform(last - i)); + } +} + +template +static void IndexedShuffle(const int64_t size, const InT& input_mat, + OutT output_mat, Random& uniform) { + std::vector permutation(size); + for (IntT i = 0; i < size; i++) { + permutation[i] = i; + } + ShuffleRange(permutation.begin(), permutation.end(), uniform); + for (IntT i = 0; i < size; i++) { + output_mat.template chip<0>(i) = input_mat.template chip<0>(permutation[i]); + } +} + +template +absl::Status RandomShuffle( + OpKernelContext* context, const Tensor& input, int output_idx, + std::function get_rng) { + if (input.NumElements() <= 1 || input.dim_size(0) <= 1) { + // No shuffling is required, so copy input directly to output + context->set_output(output_idx, input); + } else { + // Reserve enough random samples for shuffling + const int64_t size = input.dim_size(0); + const int64_t samples = size - 1; + auto rng = get_rng(samples); + random::SingleSampleAdapter single(&rng); + const auto uniform = [&single](uint32 n) { return single() % n; }; + + if (input.dims() == 1) { + // For 1D data, copy and then shuffle in place + context->set_output(output_idx, tensor::DeepCopy(input)); + auto vec = context->mutable_output(output_idx)->vec(); + ShuffleRange(vec.data(), vec.data() + size, uniform); + } else { + // For >= 2D, shuffle indices and then copy across + Tensor* output = nullptr; + TF_RETURN_IF_ERROR( + context->allocate_output(output_idx, input.shape(), &output)); + const auto input_mat = input.flat_outer_dims(); + auto output_mat = output->flat_outer_dims(); + if (size < kint32max) { + IndexedShuffle(size, input_mat, output_mat, uniform); + } else { + IndexedShuffle(size, input_mat, output_mat, uniform); + } + } + } + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SHUFFLE_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/slice_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/slice_op.h new file mode 100644 index 00000000..1992c604 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/slice_op.h @@ -0,0 +1,45 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SLICE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SLICE_OP_H_ + +// Functor definition for SliceOp, must be compilable by nvcc. 
+ +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct Slice { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes) { + MaybeWith32BitIndexing( + [&](auto output32, auto input32, auto slice_indices32, + auto slice_sizes32) { + output32.device(d) = input32.slice(slice_indices32, slice_sizes32); + }, + output, input, slice_indices, slice_sizes); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SLICE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/slice_op_cpu_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/slice_op_cpu_impl.h new file mode 100644 index 00000000..9eda840a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/slice_op_cpu_impl.h @@ -0,0 +1,39 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/bfloat16.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/slice_op.h" + +namespace tensorflow { + +using CpuDevice = Eigen::ThreadPoolDevice; + +#define DEFINE_CPU_KERNELS(T) \ + template struct functor::Slice; + +TF_CALL_ALL_TYPES(DEFINE_CPU_KERNELS); + +#undef DEFINE_CPU_KERNELS + + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SLICE_OP_CPU_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/smooth-hinge-loss.h b/third_party/tflite-hdrs/tensorflow/core/kernels/smooth-hinge-loss.h new file mode 100644 index 00000000..8dc2c806 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/smooth-hinge-loss.h @@ -0,0 +1,114 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_ +#define TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_ + +#include + +#include "tensorflow/core/kernels/loss.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class SmoothHingeLossUpdater : public DualLossUpdater { + public: + // Computes the updated dual variable (corresponding) to a single example. The + // updated dual value maximizes the objective function of the dual + // optimization problem associated with smooth hinge loss. The computations + // are detailed in readme.md. + double ComputeUpdatedDual(const int num_partitions, const double label, + const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm) const final { + // Intuitively there are 3 cases: + // a. new optimal value of the dual variable falls within the admissible + // range [0, 1]. In this case we set new dual to this value. + // b. new optimal value is < 0. Then, because of convexity, the optimal + // valid value for new dual = 0 + // c. new optimal value > 1.0. Then new optimal value should be set to 1.0. + const double candidate_optimal_dual = + current_dual + + (label - wx - gamma * current_dual) / + (num_partitions * example_weight * weighted_example_norm + gamma); + if (label * candidate_optimal_dual < 0) { + return 0.0; + } + if (label * candidate_optimal_dual > 1.0) { + return label; + } + return candidate_optimal_dual; + } + + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { + // For binary classification, there are 2 conjugate functions, one per + // label value (-1 and 1). + const double y_alpha = current_dual * example_label; // y \alpha + if (y_alpha < 0 || y_alpha > 1.0) { + return std::numeric_limits::max(); + } + return (-y_alpha + 0.5 * gamma * current_dual * current_dual) * + example_weight; + } + + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { + const double y_wx = example_label * wx; + if (y_wx >= 1) return 0; + if (y_wx <= 1 - gamma) return (1 - y_wx - gamma / 2) * example_weight; + return (1 - y_wx) * (1 - y_wx) * example_weight * 0.5 / gamma; + } + + // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively + // as expected by smooth hinge loss. + absl::Status ConvertLabel(float* const example_label) const final { + if (*example_label == 0.0) { + *example_label = -1; + return absl::OkStatus(); + } + if (*example_label == 1.0) { + return absl::OkStatus(); + } + return errors::InvalidArgument( + "Only labels of 0.0 or 1.0 are supported right now. 
" + "Found example with label: ", + *example_label); + } + + double PrimalLossDerivative(const double wx, const double label, + const double example_weight) const final { + if (label * wx >= 1) { + return 0; + } + if (label * wx <= 1 - gamma) { + return -label; + } + return (wx - label) / gamma; + } + + double SmoothnessConstant() const final { return gamma; } + + private: + // Smoothness constant of smooth hinge loss + // TODO(sibyl-Aix6ihai): expose this parameter + const double gamma = 1; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_ +// TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/snapshot_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/snapshot_op.h new file mode 100644 index 00000000..1047b470 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/snapshot_op.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif + +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace functor { + +// Functor used by SnapshotOp. +template +struct Snapshot { + void operator()(const Device& device, + typename TTypes::ConstTensor input, + typename TTypes::Tensor output) { + device.memcpy(output.data(), input.data(), input.size() * sizeof(Scalar)); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/softmax_op_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/softmax_op_functor.h new file mode 100644 index 00000000..2ce16ce8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/softmax_op_functor.h @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_ +// Functor definition for SoftmaxOp, must be compilable by nvcc. 
+ +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by SoftmaxOp to do the computations. +template +struct SoftmaxFunctor { + // Computes Softmax or LogSoftmax activation. + // + // logits: dim: batch_size, num_classes. + // softmax: dims: batch_size, num_classes. + // log: boolean + void operator()(const Device& d, typename TTypes::ConstMatrix logits, + typename TTypes::Matrix softmax, const bool log); +}; + +// Eigen code implementing SoftmaxFunctor::operator() or +// LogSoftmaxFunctor::operator(). +// This code works for both CPU and GPU and is used by the functor +// specializations for both device types. +template +struct SoftmaxEigenImpl { + static void Compute(const Device& d, typename TTypes::ConstMatrix logits, + typename TTypes::Matrix softmax, const bool log) { + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + +// These arrays are used to reduce along the class dimension, and broadcast +// the resulting value to all classes. + Eigen::IndexList > along_class; + Eigen::IndexList > batch_by_one; + batch_by_one.set(0, batch_size); + Eigen::IndexList, int> one_by_class; + one_by_class.set(1, num_classes); + + // shifted_logits = logits - max(logits along classes); + auto shifted_logits = (logits - logits.maximum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + if (log) { + // Calculate the log of the softmax + // softmax = logits - max(logits along classes); + softmax.device(d) = shifted_logits; + // softmax = softmax - log(sum(exp(softmax along classes))); + softmax.device(d) = (softmax - softmax.exp() + .sum(along_class) + .log() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + } else { + // NOTE(touts): If you modify this implementation please run + // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc. + // + // softmax = exp(logits - max(logits along classes)); + softmax.device(d) = shifted_logits.exp(); + // softmax = softmax * (1 / sum(softmax along classes)); + softmax.device(d) = (softmax * softmax.sum(along_class) + .inverse() + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)); + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/softplus_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/softplus_op.h new file mode 100644 index 00000000..1fa271a6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/softplus_op.h @@ -0,0 +1,79 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_ +// Functor definition for SoftplusOp and SoftplusGradOp, must be compilable by +// nvcc. + +// clang-format off +#include "tensorflow/core/platform/bfloat16.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +// clang-format on +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by SoftplusOp to do the computations. +template +struct Softplus { + // Computes Softplus activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + // Choose a threshold on x below which exp(x) may underflow + // when added to 1, but for which exp(x) is always within epsilon of the + // true softplus(x). Offset of 2 from machine epsilon checked + // experimentally for float16, float32, float64. Checked against + // softplus implemented with numpy's log1p and numpy's logaddexp. + static const T threshold = + Eigen::numext::log(Eigen::NumTraits::epsilon()) + T(2); + // Value above which exp(x) may overflow, but softplus(x) == x + // is within machine epsilon. + auto too_large = features > features.constant(-threshold); + // Value below which exp(x) may underflow, but softplus(x) == exp(x) + // is within machine epsilon. + auto too_small = features < features.constant(threshold); + auto features_exp = features.exp(); + activations.device(d) = too_large.select( + features, // softplus(x) ~= x for x large + too_small.select(features_exp, // softplus(x) ~= exp(x) for x small + features_exp.log1p())); + } +}; + +// Functor used by SoftplusGradOp to do the computations. +template +struct SoftplusGrad { + // Computes SoftplusGrad backprops. + // + // gradients: gradients backpropagated to the Softplus op. + // features: inputs that where passed to the Softplus op. + // backprops: gradients to backpropagate to the Softplus inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + backprops.device(d) = + gradients / ((-features).exp() + features.constant(T(1))); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/softsign_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/softsign_op.h new file mode 100644 index 00000000..15de7288 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/softsign_op.h @@ -0,0 +1,60 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
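The Softplus functor above picks between three regimes around a threshold derived from machine epsilon: for large x it returns x (softplus(x) is x to within epsilon), for very negative x it returns exp(x), and otherwise log1p(exp(x)); the gradient is the logistic sigmoid. A standalone double-precision sketch of the same branching; the helper names are illustrative.

#include <cmath>
#include <iostream>
#include <limits>

// Numerically stable softplus(x) = log(1 + exp(x)).
double Softplus(double x) {
  // Below this threshold exp(x) underflows when added to 1; above its negation
  // softplus(x) equals x to within machine epsilon.
  static const double threshold =
      std::log(std::numeric_limits<double>::epsilon()) + 2.0;
  if (x > -threshold) return x;            // softplus(x) ~= x for large x
  if (x < threshold) return std::exp(x);   // softplus(x) ~= exp(x) for small x
  return std::log1p(std::exp(x));
}

// d/dx softplus(x) = sigmoid(x) = 1 / (1 + exp(-x)), scaled by the upstream grad.
double SoftplusGrad(double upstream_grad, double x) {
  return upstream_grad / (std::exp(-x) + 1.0);
}

int main() {
  for (double x : {-800.0, -1.0, 0.0, 1.0, 800.0}) {
    std::cout << "softplus(" << x << ") = " << Softplus(x)
              << ", grad = " << SoftplusGrad(1.0, x) << "\n";
  }
}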
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_ +// Functor definition for SoftsignOp and SoftsignGradOp, must be compilable by +// nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by SoftsignOp to do the computations. +template +struct Softsign { + // Computes Softsign activation. + // + // features: any shape. + // activations: same shape as "features". + void operator()(const Device& d, typename TTypes::ConstTensor features, + typename TTypes::Tensor activations) { + activations.device(d) = + features / (features.abs() + features.constant(T(1))); + } +}; + +// Functor used by SoftsignGradOp to do the computations. +template +struct SoftsignGrad { + // Computes SoftsignGrad backprops. + // + // gradients: gradients backpropagated to the Softsign op. + // features: inputs that were passed to the Softsign op. + // backprops: gradients to backpropagate to the Softsign inputs. + void operator()(const Device& d, typename TTypes::ConstTensor gradients, + typename TTypes::ConstTensor features, + typename TTypes::Tensor backprops) { + backprops.device(d) = + gradients / (features.abs() + features.constant(T(1))).square(); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/spacetobatch_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/spacetobatch_functor.h new file mode 100644 index 00000000..7838b5e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/spacetobatch_functor.h @@ -0,0 +1,114 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Maximum number of non-collapsible blocked dimensions supported by the +// {SpaceToBatch,BatchToSpace}ND operation. To change the limit, modify this +// constant and the TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS macro definition +// below. +constexpr int kMaxSpaceToBatchBlockDims = 4; + +// Expands to: +// MACRO(1, ## __VA_ARGS__) +// ... +// MACRO(kMaxSpaceToBatchBlockDims, ## __VA_ARGS__) +// +// Note: The space between the number and the comma is necessary for proper GCC +// comma handling: https://gcc.gnu.org/onlinedocs/cpp/Variadic-Macros.html +#define TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(MACRO, ...) 
\ + MACRO(1 /**/, ##__VA_ARGS__) \ + MACRO(2 /**/, ##__VA_ARGS__) \ + MACRO(3 /**/, ##__VA_ARGS__) \ + MACRO(4 /**/, ##__VA_ARGS__) \ + /**/ + +namespace internal { +namespace spacetobatch { + +template +void SubtleMustCopyFlatHelper(const Tensor& t, OutputType* output) { + const int64_t num_elements = t.shape().num_elements(); + output->resize(num_elements); + auto eigen_vec = t.flat(); + for (int64_t i = 0; i < num_elements; ++i) { + (*output)[i] = SubtleMustCopy(eigen_vec(i)); + } +} + +// Copies flat contents of `t` to std::vector-like `*output`, which is resized +// as needed. `OutputType` may be either `std::vector` or +// `gtl::InlinedVector`. +// +// Precondition: t.dtype() must be either DT_INT32 or DT_INT64. +template +void SubtleMustCopyFlat(const Tensor& t, OutputType* output) { + if (t.dtype() == DT_INT32) { + SubtleMustCopyFlatHelper(t, output); + } else { + SubtleMustCopyFlatHelper(t, output); + } +} + +} // namespace spacetobatch +} // namespace internal + +namespace functor { + +// Functor used by {SpaceToBatch,BatchToSpace}{ND,}Op to do the conversion. +// +// If B2S is false, then this performs the space-to-batch conversion. If B2S is +// true, then this performs the inverse batch-to-space conversion. +template +struct SpaceToBatchFunctor { + using InputT = typename std::conditional::type; + using OutputT = typename std::conditional::type; + // Implements the space to batch conversion. + // + // space_tensor: input tensor of space-to-batch operation. If B2S = false, + // then this is the input to the conversion. If B2S = true, then this + // is the output of the conversion. + // block_size: array of shape [NUM_BLOCK_DIMS] specifying the block sizes for + // dimensions 1 through NUM_BLOCK_DIMS. + // paddings: row-major array of shape [NUM_BLOCK_DIMS, 2] specifying the + // start and end padding for dimensions 1 through NUM_BLOCK_DIMS. + // batch_tensor: output tensor of the space-to-batch operation. If + // B2S = false, then this is the output of the conversion. If B2S = true, + // then this is the input to the conversion. + // + // The caller must ensure that the dimensions of the tensors are correct. + absl::Status operator()( + const Device& d, + typename TTypes::Tensor space_tensor, + const int64_t block_shape[NUM_BLOCK_DIMS], + const int64_t paddings[NUM_BLOCK_DIMS * 2], + typename TTypes::Tensor batch_tensor); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPACETOBATCH_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/spacetodepth_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/spacetodepth_op.h new file mode 100644 index 00000000..3cb1df5b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/spacetodepth_op.h @@ -0,0 +1,57 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPACETODEPTH_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPACETODEPTH_OP_H_ +// Functor definition for XentOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace functor { + +// Functor used by SpaceToDepthOp to do the computations. +// Implements a family of Space to Depth transforms for a 4D 'input' tensor +// to a 4D 'output' tensor, both tensors use type 'T' and layout 'data_format'. +// These transforms divide the vertical and horizontal image sizes by +// 'block_size', and multiply the depth dimension size by +// (block_size * block_size). The offset within each block_size * block_size +// patch within the image is combined with the input channel index to form +// the output channel index, with the Y, X coordinates within each block of +// the input image used as the high order component of the output channel. +// e.g. for data_format = NHWC: +// Each element in the input tensor can be specified via 6 coordinates, +// ordered by decreasing memory layout significance as: +// n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates +// within the output image, bX, bY means coordinates +// within the input block, iC means input channels). +// The output would be a transpose to the following layout: +// n,oY,oX,bY,bX,iC +template +struct SpaceToDepthOpFunctor { + void operator()(const Device& d, typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); + + // This 5-D version is to support NCHW_VECT_C. + void operator()(const Device& d, typename TTypes::ConstTensor input, + int block_size, typename TTypes::Tensor output); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPACETODEPTH_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/kernels.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/kernels.h new file mode 100644 index 00000000..aff14ca0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/kernels.h @@ -0,0 +1,257 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
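The functor described above treats SpaceToDepth as a pure re-indexing: each block_size x block_size spatial patch is folded into the channel dimension, so for NHWC the output channel is (bY * block_size + bX) * in_channels + iC. A standalone NHWC sketch over flat std::vector storage; it assumes H and W are divisible by the block size and is not the TF kernel.

#include <cstdint>
#include <iostream>
#include <vector>

// NHWC space-to-depth: [N, H, W, C] -> [N, H/bs, W/bs, bs*bs*C].
std::vector<float> SpaceToDepthNHWC(const std::vector<float>& in, int64_t n,
                                    int64_t h, int64_t w, int64_t c, int64_t bs) {
  const int64_t oh = h / bs, ow = w / bs, oc = bs * bs * c;
  std::vector<float> out(n * oh * ow * oc);
  for (int64_t b = 0; b < n; ++b)
    for (int64_t y = 0; y < h; ++y)
      for (int64_t x = 0; x < w; ++x)
        for (int64_t ch = 0; ch < c; ++ch) {
          const int64_t oy = y / bs, by = y % bs;  // output row, row in block
          const int64_t ox = x / bs, bx = x % bs;  // output col, col in block
          const int64_t out_ch = (by * bs + bx) * c + ch;
          out[((b * oh + oy) * ow + ox) * oc + out_ch] =
              in[((b * h + y) * w + x) * c + ch];
        }
  return out;
}

int main() {
  // A 1x2x2x1 input with block_size 2 becomes 1x1x1x4: [1, 2, 3, 4].
  std::vector<float> in{1, 2, 3, 4};
  for (float v : SpaceToDepthNHWC(in, 1, 2, 2, 1, 2)) std::cout << v << ' ';
  std::cout << '\n';
}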
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/sparse/sparse_matrix.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace functor { + +// Calculates number of nonzero entries per batch of a sorted rank-3 +// SparseTensor's indices. indices is expected to have columns +// corresponding to [batch, row, column], where indices[:,0] < B. +// +// REQUIRES: +// indices.dimension(1) == 3 +// nnz_per_batch.dimension(0) == B +template +struct CalculateNNZPerBatchMatrixFromIndices { + absl::Status operator()(OpKernelContext* c, + TTypes::ConstMatrix indices, + TTypes::Vec nnz_per_batch); +}; + +// Split a subset of a SparseTensors' indices into two vectors: +// COO row inds and COO col inds. Outputs are: +// +// coo_row_ind = indices[:, row_dim] +// coo_col_ind = indices[:, row_dim + 1] +// +// where n = coo_row_ind.size() +// and row_dim = #cols(indices) - 1 +// +// REQUIRES: +// host_dense_shape.size() in [2, 3] +// indices.dim_size(1) == host_dense_shape.size() +// coo_row_ind.size() == coo_col_ind.size() +// coo_row_ind.size() == indices.dim_size(0) +template +struct SparseTensorToCOOSparseMatrix { + void operator()(const Device& d, TTypes::ConstVec host_dense_shape, + TTypes::ConstMatrix indices, + TTypes::Vec coo_row_ind, + TTypes::Vec coo_col_ind); +}; + +// Write coo batch, row, and column vectors to output matrix indices: +// +// indices[:, row_dim] = coo_row_ind +// indices[:, col_dim] = coo_col_ind +// +// where row_dim = #cols(indices) - 1 and n = coo_row_ind.size(). +// In addition, if #cols(indices) == 3, also store the batch: +// +// indices[i, 0] = batch_of(i) where +// host_batch_ptrs(batch_of(i)) <= i < host_batch_ptrs(batch_of(i) + 1) +// +// REQUIRES: +// +// host_dense_shape.size() in [2, 3] +// indices.dim_size(1) == host_dense_shape.size() +// host_batch_ptr.size() == +// coo_row_ind.size() == coo_col_ind.size() +// +template +struct COOSparseMatrixToSparseTensor { + absl::Status operator()(OpKernelContext* c, + TTypes::ConstVec host_dense_shape, + TTypes::ConstVec host_batch_ptrs, + TTypes::Vec coo_row_ind, + TTypes::ConstVec coo_col_ind, + TTypes::Matrix indices); +}; + +// Convert a vector of coo row indices to csr row pointers. +// +// REQUIRES: +// +// csr_row_ptr.size() == rows + 1. +// max(coo_row_ptr) < rows. +// +template +struct COOSparseMatrixToCSRSparseMatrix { + absl::Status operator()(OpKernelContext* c, const int rows, const int cols, + TTypes::UnalignedVec coo_row_ind, + TTypes::UnalignedVec csr_row_ptr); +}; + +// Convert a matrix of (batched) coo row and column indices to CSR SparseMatrix +// batch ptrs, csr row pointers and coo column indices. +// +// REQUIRES: +// batch_ptr.size() == batch_size + 1 +// csr_row_ptr.size() == batch_size * (num_rows + 1) +// csr_col_ind.size() == total_nnz +// batch_size == 1 if rank == 2 +// +// where +// total_nnz = indices.dim_size(0) +// rank = indices.dim_size(1) +// Also csr_row_ptr should be initially filled with zeros. 
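The CPU functor described above turns (possibly batched) COO indices into CSR form: count the nonzeros that land in each row, prefix-sum the counts into row pointers, then scatter the column indices. A standalone single-batch sketch, assuming (row, col) index pairs sorted in row-major order and a zero-filled row-pointer array, as the comment requires; names are illustrative.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Builds csr_row_ptr (size num_rows + 1) and csr_col_ind from COO indices.
void CooToCsr(const std::vector<std::pair<int64_t, int64_t>>& coo,  // (row, col)
              int64_t num_rows, std::vector<int32_t>& csr_row_ptr,
              std::vector<int32_t>& csr_col_ind) {
  csr_row_ptr.assign(num_rows + 1, 0);  // must start zero-filled
  csr_col_ind.resize(coo.size());
  // Pass 1: histogram of nonzeros per row, shifted by one slot.
  for (const auto& rc : coo) ++csr_row_ptr[rc.first + 1];
  // Pass 2: prefix sum turns the counts into row offsets.
  for (int64_t r = 0; r < num_rows; ++r) csr_row_ptr[r + 1] += csr_row_ptr[r];
  // Pass 3: scatter column indices into their row ranges.
  std::vector<int32_t> next(csr_row_ptr.begin(), csr_row_ptr.end() - 1);
  for (const auto& rc : coo) {
    csr_col_ind[next[rc.first]++] = static_cast<int32_t>(rc.second);
  }
}

int main() {
  // A 3x4 matrix with nonzeros at (0,1), (0,3), (2,0).
  std::vector<std::pair<int64_t, int64_t>> coo{{0, 1}, {0, 3}, {2, 0}};
  std::vector<int32_t> row_ptr, col_ind;
  CooToCsr(coo, /*num_rows=*/3, row_ptr, col_ind);
  for (int v : row_ptr) std::cout << v << ' ';  // 0 2 2 3
  std::cout << "| ";
  for (int v : col_ind) std::cout << v << ' ';  // 1 3 0
  std::cout << '\n';
}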
+// +struct SparseTensorToCSRSparseMatrixCPUFunctor { + absl::Status operator()(int64_t batch_size, int num_rows, int num_cols, + TTypes::ConstMatrix indices, + TTypes::Vec batch_ptr, + TTypes::Vec csr_row_ptr, + TTypes::Vec csr_col_ind); +}; + +// Convert a vector of csr row pointers to coo row indices. +// +// REQUIRES: +// +// coo_row_ptr.size() == nnz. +// csr_row_ptr[-1] == nnz. +// +template +struct CSRSparseMatrixToCOOSparseMatrix { + absl::Status operator()(OpKernelContext* c, + TTypes::UnalignedConstVec csr_row_ptr, + TTypes::UnalignedVec coo_row_ind); +}; + +// Calculates C = matmul(A, B) or C = matmul(A, B)^T, where A is in CSR format +// and B and C are dense. +template +struct CSRSparseMatrixMatMul { + explicit CSRSparseMatrixMatMul(const bool transpose_output); + absl::Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, + typename TTypes::ConstMatrix b, + typename TTypes::Matrix c); +}; + +// Calculates y = A * x, y = A^T * x, or y = A^H * x, where A is in CSR format +// and x and y are dense vectors. +template +class CSRSparseMatrixMatVec { + CSRSparseMatrixMatVec(bool transpose_a, bool adjoint_a); + absl::Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, + const T* x, T* y); +}; + +// Calculates C = functor(A, B) where A and B are CSR and C is CSR +// with a different sparsity pattern. +template +struct CSRStructureModifyingFunctor { + virtual ~CSRStructureModifyingFunctor() {} + + virtual absl::Status Initialize() = 0; + + virtual absl::Status GetWorkspaceSize(const ConstCSRComponent& a, + const ConstCSRComponent& b, + size_t* bufferSize) = 0; + + virtual absl::Status GetOutputStructure(const ConstCSRComponent& a, + const ConstCSRComponent& b, + TTypes::UnalignedVec c_row_ptr, + int* output_nnz, void* workspace) = 0; + + virtual absl::Status Compute(const ConstCSRComponent& a, + const ConstCSRComponent& b, + CSRComponent* c, void* workspace) = 0; +}; + +// Calculates C = alpha * A + beta * B, where A and B are in CSR +// format, and alpha and beta are scalars on the host. +template +struct CSRSparseMatrixAdd : public CSRStructureModifyingFunctor { + explicit CSRSparseMatrixAdd(OpKernelContext* ctx, const T alpha, + const T beta); +}; + +// Calculates C = matmul(A, B), where A, B, and C are in CSR format. +template +struct CSRSparseSparseMatrixMatMul + : public CSRStructureModifyingFunctor { + explicit CSRSparseSparseMatrixMatMul(OpKernelContext* ctx, bool transpose_a, + bool transpose_b); +}; + +// Calculates Y = transpose(X) where X and Y are CSR format components. +template +struct CSRSparseMatrixTransposeComponent { + absl::Status operator()(OpKernelContext* ctx, const ConstCSRComponent& x, + CSRComponent* y); +}; + +// Calculates Y = transpose(X) where X and Y are in CSR format. +template +struct CSRSparseMatrixTranspose { + absl::Status operator()(OpKernelContext* ctx, bool conjugate, + const CSRSparseMatrix& input_matrix, + CSRSparseMatrix* output_matrix); +}; + +// Calculates Y = softmax(X) where X and Y are in CSR format; +// missing coefficients in X are treates as -inf (logits of 0 probability). 
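Because absent coefficients act as -inf (probability 0), the softmax declared just below reduces to an independent softmax over the stored values of each row. A rough per-row sketch under that reading, on plain arrays rather than the CSRSparseMatrix type (illustrative only):

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch: softmax over the stored entries of each CSR row, in place.
void CsrRowSoftmax(const std::vector<int>& row_ptr, std::vector<float>& values) {
  for (size_t r = 0; r + 1 < row_ptr.size(); ++r) {
    const int begin = row_ptr[r], end = row_ptr[r + 1];
    if (begin == end) continue;  // empty row: nothing stored, nothing to normalize
    const float row_max =
        *std::max_element(values.begin() + begin, values.begin() + end);
    float sum = 0.0f;
    for (int i = begin; i < end; ++i)
      sum += (values[i] = std::exp(values[i] - row_max));  // stable exponentials
    for (int i = begin; i < end; ++i) values[i] /= sum;
  }
}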
+template +struct CSRSparseMatrixSoftmax { + absl::Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits, + typename TTypes::Vec softmax_values); +}; + +template +struct CSRSparseMatrixSoftmaxGrad { + absl::Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax, + const CSRSparseMatrix& grad_softmax, + typename TTypes::Vec gradient_values); +}; + +template +class CSRSparseMatrixMulScalar { + public: + explicit CSRSparseMatrixMulScalar() {} + + absl::Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a, + typename TTypes::ConstScalar b, CSRSparseMatrix* c); +}; + +template +class CSRSparseMatrixBatchMulVec { + public: + explicit CSRSparseMatrixBatchMulVec() {} + + absl::Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a, + typename TTypes::ConstFlat b, CSRSparseMatrix* c); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/mat_mul_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/mat_mul_op.h new file mode 100644 index 00000000..3e55cfbc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/mat_mul_op.h @@ -0,0 +1,1018 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MAT_MUL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_MAT_MUL_OP_H_ + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif + +#include "Eigen/Core" // from @eigen_archive +#include "Eigen/SparseCore" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/cwise_ops_common.h" +#include "tensorflow/core/kernels/dense_update_functor.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/sparse/kernels.h" +#include "tensorflow/core/kernels/sparse/sparse_matrix.h" +#include "tensorflow/core/kernels/sparse/transpose_op.h" +#include "tensorflow/core/kernels/transpose_functor.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/threadpool.h" + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "tensorflow/core/util/cuda_sparse.h" +#include "tensorflow/core/util/gpu_solvers.h" +#endif + +namespace tensorflow { + +// TODO(anudhyan): These constants may be tuned based on the performance of +// 'benchmark_sparse_matrix_mat_vec_mul'. We would like to find constants +// which work across hardware platforms for typical matrix sizes. 
It should be +// possible to observe at least 30-50% improvement as we increase the number +// of threads by 1. If not, then it may we worth increasing kMaxShards and +// kNumShardsPerThread. However, once we have too many shards, latency may be +// dominated by per-shard overhead. +// +// Maximum number of shards into which to divide the computation for each CSR +// Sparse Matrix instance. +static constexpr int32_t kMaxShards = 20; +// Number of shards allocated to each thread. +static constexpr int32_t kNumShardsPerThread = 3; + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +// Abstract OpKernel to compute sparse-dense matrix multiplication. +// +// Implements a kernel which, given a SparseMatrix `a` and dense Tensor `b`, +// computes a dense Tensor `c` satisfying `c = a * b` where * denotes matrix +// multiplication. +// +// The boolean attributes `transpose_a` and `adjoint_a` will transpose or +// adjoint `a` before multiplication, respectively. At most one of these +// attributes must be set to True. Corresponding attributes will transpose or +// adjoint `b` or the output (after multiplication). +// +// The rank of both `a` and `b` must be equal and their shapes must be +// compatible for matrix multiplication. Otherwise, InvalidArgument runtime +// errors will be thrown. Only rank 2 or rank 3 inputs are supported. +// +template +class CSRMatMulOp : public OpKernel { + public: + explicit CSRMatMulOp(OpKernelConstruction* c) : OpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(c, c->GetAttr("transpose_b", &transpose_b_)); + bool adjoint_a; + OP_REQUIRES_OK(c, c->GetAttr("adjoint_a", &adjoint_a)); + OP_REQUIRES(c, !(adjoint_a && transpose_a_), + absl::InvalidArgumentError( + "Only one of adjoint_a and transpose_a may be true.")); + bool adjoint_b; + OP_REQUIRES_OK(c, c->GetAttr("adjoint_b", &adjoint_b)); + OP_REQUIRES(c, !(adjoint_b && transpose_b_), + absl::InvalidArgumentError( + "Only one of adjoint_b and transpose_b may be true.")); + OP_REQUIRES_OK(c, c->GetAttr("transpose_output", &transpose_output_)); + OP_REQUIRES_OK(c, c->GetAttr("conjugate_output", &conjugate_output_)); + transpose_a_ |= adjoint_a; + transpose_b_ |= adjoint_b; + if (is_complex::value) { + conjugate_a_ = adjoint_a; + conjugate_b_ = adjoint_b; + } else { + conjugate_a_ = false; + conjugate_b_ = false; + } + } + + ~CSRMatMulOp() override {} + + absl::Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a, + const Tensor& dense_tensor_b, int* rank, + int64_t* batch_size) { + if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) { + return absl::InvalidArgumentError(absl::StrCat( + "Input types don't match. a.dtype == ", + DataTypeString(sparse_matrix_a.dtype()), + " vs. b.dtype == ", DataTypeString(dense_tensor_b.dtype()))); + } + *rank = sparse_matrix_a.dims(); + // TODO(ebrevdo): Add support for broadcasting matmul. + if (*rank != dense_tensor_b.dims()) { + return absl::InvalidArgumentError( + absl::StrCat("Ranks of a and b must match, saw: ", *rank, " vs. ", + dense_tensor_b.dims(), ".")); + } + // A valid CSR SparseMatrix has rank 2 or rank 3. + *batch_size = (*rank == 2) ? 1 : dense_tensor_b.dim_size(0); + if (sparse_matrix_a.batch_size() != *batch_size) { + return absl::InvalidArgumentError(absl::StrCat( + "Batch sizes of a and b must match, saw: ", + sparse_matrix_a.batch_size(), " vs. 
", *batch_size, ".")); + } + const auto& a_dense_shape = sparse_matrix_a.dense_shape().vec(); + const int64_t a_inner_dim = + a_dense_shape(this->transpose_a_ ? *rank - 2 : *rank - 1); + const int64_t b_inner_dim = + dense_tensor_b.dim_size(this->transpose_b_ ? *rank - 1 : *rank - 2); + if (a_inner_dim != b_inner_dim) { + return absl::InvalidArgumentError( + absl::StrCat("Inner product dimensions of A and B do not agree. ", + "Shapes are: ", TensorShape(a_dense_shape).DebugString(), + " vs. ", dense_tensor_b.shape().DebugString())); + } + return absl::OkStatus(); + } + + public: + bool transpose_a_; + bool transpose_b_; + bool conjugate_a_; + bool conjugate_b_; + bool transpose_output_; + bool conjugate_output_; +}; + +// CPU Kernel to compute sparse-dense matrix multiplication. +// +// Uses Eigen SparseMatrix to compute the sparse-dense multiplication between +// a CSR SparseMatrix `a` and dense Tensor `b`. If intra-op parallelism is +// available, the implementation parallelizes the computation across each row +// of the sparse matrix. +template +class CSRMatMulCPUOp : public CSRMatMulOp { + using SparseMatrix = Eigen::SparseMatrix; + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using MatrixMap = Eigen::Map; + + public: + explicit CSRMatMulCPUOp(OpKernelConstruction* c) + : CSRMatMulOp(c) {} + + ~CSRMatMulCPUOp() override {} + + void Compute(OpKernelContext* ctx) final { + const CSRSparseMatrix* sparse_matrix_a; + OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &sparse_matrix_a)); + const Tensor& matrix_b = ctx->input(1); + + int rank; + int64_t batch_size; + OP_REQUIRES_OK(ctx, this->ValidateInputs(*sparse_matrix_a, matrix_b, &rank, + &batch_size)); + + const auto dense_shape = sparse_matrix_a->dense_shape().vec(); + int64_t num_lhs_rows = dense_shape(rank - 2); + int64_t num_lhs_cols = dense_shape(rank - 1); + int64_t num_rhs_rows = matrix_b.dim_size(rank - 2); + int64_t num_rhs_cols = matrix_b.dim_size(rank - 1); + + if (this->transpose_a_) { + std::swap(num_lhs_rows, num_lhs_cols); + } + + // Possibly transpose the dense Tensor b. + const Tensor* rhs = &matrix_b; + Tensor b_transposed; + if (this->transpose_b_) { + OP_REQUIRES_OK( + ctx, TransposeAndConjugateTensor(ctx, matrix_b, this->conjugate_b_, + &b_transposed)); + rhs = &b_transposed; + std::swap(num_rhs_rows, num_rhs_cols); + } + + // If we're transposing the output, then allocate a temporary buffer to + // store the output. Otherwise allocate the output directly. + Tensor* output = nullptr; + Tensor* matmul_result = nullptr; + Tensor output_transposed; + OP_REQUIRES_OK( + ctx, AllocateOutput(ctx, rank, batch_size, num_lhs_rows, num_rhs_cols, + this->transpose_output_, &output, + &output_transposed, &matmul_result)); + + if (!this->transpose_a_) { + SparseDenseMatMulWithoutTransposedLHS( + ctx, batch_size, num_lhs_rows, *sparse_matrix_a, *rhs, matmul_result); + } else { // transpose_a_ == true + SparseDenseMatMulWithTransposedLHS(ctx, batch_size, num_lhs_rows, + num_lhs_cols, *sparse_matrix_a, *rhs, + matmul_result); + } + + // Transpose (and conjugate) the output if necessary. + // Note that conjugate is only true if transpose is also true. + if (this->transpose_output_) { + OP_REQUIRES_OK( + ctx, TransposeAndConjugateAllocatedTensor( + ctx, output_transposed, this->conjugate_output_, output)); + } else if (this->conjugate_output_) { + functor::maybe_conj_inplace::run( + ctx->eigen_device(), output); + } + } + + private: + // Allocates the output with the appropriate shape. 
Additionally, if + // transpose_output is True, allocates a temporary buffer with the transposed + // output. 'matmul_result' points to either output or output_transposed, based + // on whether transpose_output is True. + absl::Status AllocateOutput(OpKernelContext* ctx, const int32_t rank, + const int64_t batch_size, const int64_t num_rows, + const int64_t num_cols, + const bool transpose_output, Tensor** output, + Tensor* output_transposed, + Tensor** matmul_result) { + TensorShape output_shape; + if (rank == 3) { + TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(batch_size)); + } + + if (!transpose_output) { + output_shape.AppendShape({num_rows, num_cols}); + TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output)); + *matmul_result = *output; + } else { + TensorShape output_transposed_shape = output_shape; + output_transposed_shape.AppendShape({num_rows, num_cols}); + output_shape.AppendShape({num_cols, num_rows}); + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum::value, + output_transposed_shape, + output_transposed)); + TF_RETURN_IF_ERROR(ctx->allocate_output(0, output_shape, output)); + *matmul_result = output_transposed; + } + return absl::OkStatus(); + } + + // Returns an Eigen::Ref expression of a sparse sub-matrix from the given + // contiguous segment of rows of the CSR Sparse Matrix. + Eigen::Ref GetSparseMatrixRef( + const CSRSparseMatrix& csr_matrix, const int batch_index, + const int64_t row_begin, const int64_t num_shard_rows, + std::vector* row_ptrs) { + // Compute the row pointers of the sparse sub-matrix. + row_ptrs->resize(num_shard_rows + 1); + const int64_t row_offset = + csr_matrix.row_pointers_vec(batch_index)(row_begin); + for (int64_t row_idx = 0; row_idx <= num_shard_rows; ++row_idx) { + row_ptrs->at(row_idx) = + csr_matrix.row_pointers_vec(batch_index)(row_begin + row_idx) - + row_offset; + } + const int64_t num_cols = + csr_matrix.dense_shape().vec()(csr_matrix.dims() - 1); + return Eigen::Map( + num_shard_rows /* num_rows */, num_cols /* num_cols */, + row_ptrs->at(num_shard_rows) /* total_nnz */, row_ptrs->data(), + csr_matrix.col_indices_vec(batch_index).data() + row_offset, + csr_matrix.values_vec(batch_index).data() + row_offset); + } + + // Sparse-Dense Matrix Multiplication between a CSRSparseMatrix (LHS) and a + // dense Tensor (RHS). + void SparseDenseMatMulWithoutTransposedLHS(OpKernelContext* ctx, + const int64_t batch_size, + const int64_t num_lhs_rows, + const CSRSparseMatrix& lhs, + const Tensor& rhs, + Tensor* output) { + // Parallelize matrix multiplication across batch dimensions and across + // rows in each batch. 
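To make the sharding constants concrete: the block size computed just below splits the row range into at least max(kMaxShards, kNumShardsPerThread * num_threads) shards. With 8 intra-op threads (an assumed example), that is max(20, 3 * 8) = 24 shards, so a 24,000-row sparse matrix is processed in blocks of roughly 1,000 rows, each handled by the ParallelFor callback.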
+ auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + const int32_t num_threads = worker_threads.num_threads; + const int64_t block_size = + num_lhs_rows / std::max(kMaxShards, kNumShardsPerThread * num_threads); + const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2); + const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1); + worker_threads.workers->ParallelFor( + batch_size * num_lhs_rows /* total */, + thread::ThreadPool::SchedulingParams( + thread::ThreadPool::SchedulingStrategy:: + kFixedBlockSize /* strategy */, + absl::nullopt /* cost_per_unit */, block_size), + [&](int64_t batch_and_row_begin, int64_t batch_and_row_end) { + HandleBatchAndRowRange( + num_lhs_rows, batch_and_row_begin, batch_and_row_end, + [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) { + const int64_t num_shard_rows = row_end - row_begin; + + // Define an Eigen::SparseMatrix over the row range: + // [row_begin, row_end) of the CSR SparseMatrix A. + std::vector row_ptrs; + auto sparse_matrix = GetSparseMatrixRef( + lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs); + + // Map the corresponding rows of the rhs. + ConstMatrixMap rhs_map(rhs.flat().data() + batch_idx * + num_rhs_rows * + num_rhs_cols, + num_rhs_rows, num_rhs_cols); + + // Write to the corresponding rows of the output matrix. + MatrixMap output_map( + output->flat().data() + + batch_idx * num_lhs_rows * num_rhs_cols + + row_begin * num_rhs_cols, + num_shard_rows, num_rhs_cols); + output_map.noalias() = sparse_matrix * rhs_map; + }); + }); + } + + // Sparse-Dense Matrix Multiplication assuming the CSRSparseMatrix (LHS) is + // to be transposed before the operation. + void SparseDenseMatMulWithTransposedLHS(OpKernelContext* ctx, + const int64_t batch_size, + const int64_t num_lhs_rows, + const int64_t num_lhs_cols, + const CSRSparseMatrix& lhs, + const Tensor& rhs, Tensor* output) { + auto device = ctx->eigen_device(); + auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); + const int32_t num_threads = worker_threads.num_threads; + const int64_t num_rhs_rows = rhs.dim_size(rhs.dims() - 2); + const int64_t num_rhs_cols = rhs.dim_size(rhs.dims() - 1); + // Usually, we want to avoid transposing the sparse matrix A since it may be + // an expensive operation. Instead, we use the identity (A^T B) = (B^T A)^T. + // We don't actually transpose B or the output because it is more convenient + // to have them in column major form. + // + // However, if A is hypersparse and B and C are huge, transposing A will be + // cheaper. In the future, we should have a cost model estimating the cost + // of transposing all matrices (A, B, C) to decide which variant to use. + + // Each thread writes to its own copy of the matrix product. These + // `num_threads` copies are summed together to obtain the final result. + Tensor matmul_result_buffer; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({num_threads + 1, + output->NumElements()}), + &matmul_result_buffer)); + functor::SetZeroFunctor set_zero; + set_zero(device, matmul_result_buffer.flat()); + + // Parallelize matrix multiplication across batch dimensions and across + // columns of A^T in each batch. These correspond to rows of A. 
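The identity (A^T B) = (B^T A)^T invoked earlier in this function is easy to sanity-check on small dense matrices; a minimal sketch (the include path for Eigen is an assumption, and the kernel of course applies the identity to a sparse A):

#include <Eigen/Dense>

// Sketch: verify (A^T * B) == (B^T * A)^T on random dense matrices.
bool TransposedLhsIdentityHolds() {
  Eigen::MatrixXf A = Eigen::MatrixXf::Random(3, 2);
  Eigen::MatrixXf B = Eigen::MatrixXf::Random(3, 4);
  Eigen::MatrixXf lhs = A.transpose() * B;                  // 2 x 4
  Eigen::MatrixXf rhs = (B.transpose() * A).transpose();    // 2 x 4
  return lhs.isApprox(rhs);
}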
+ const int64_t block_size = + num_lhs_cols / std::max(kMaxShards, kNumShardsPerThread * num_threads); + worker_threads.workers->ParallelForWithWorkerId( + batch_size * num_lhs_cols /* total */, + thread::ThreadPool::SchedulingParams( + thread::ThreadPool::SchedulingStrategy:: + kFixedBlockSize /* strategy */, + absl::nullopt /* cost_per_unit */, block_size), + [&](int64_t batch_and_row_begin, int64_t batch_and_row_end, int tid) { + HandleBatchAndRowRange( + num_lhs_cols, batch_and_row_begin, batch_and_row_end, + [&](int64_t batch_idx, int64_t row_begin, int64_t row_end) { + const int64_t num_shard_rows = row_end - row_begin; + + // Define a new sparse sub-matrix from the row range + // [row_begin, row_end) of the sparse matrix A. + std::vector row_ptrs; + auto sparse_matrix = GetSparseMatrixRef( + lhs, batch_idx, row_begin, num_shard_rows, &row_ptrs); + + // Map the corresponding `num_shard_rows` columns of B^T. + // This is the same as taking the `num_shard_rows` rows of B. + ConstMatrixMap b_dense_map( + rhs.flat().data() + + batch_idx * num_rhs_rows * num_rhs_cols + + row_begin * num_rhs_cols, + num_shard_rows, num_rhs_cols); + + // Map to the corresponding rows of the output. + MatrixMap output_map( + matmul_result_buffer.flat().data() + + tid * batch_size * num_lhs_rows * num_rhs_cols + + batch_idx * num_lhs_rows * num_rhs_cols, + num_lhs_rows, num_rhs_cols); + + // Compute the product C^T = B^T * A; restricted to the row + // range in the current shard. + if (this->conjugate_a_) { + output_map.transpose().noalias() += + b_dense_map.transpose() * sparse_matrix.conjugate(); + } else { + output_map.transpose().noalias() += + b_dense_map.transpose() * sparse_matrix; + } + }); + }); + + // Sum across each thread's matmul result. + using Reducer = Eigen::internal::SumReducer; + using Index = typename TTypes::Tensor::Index; + output->flat().device(device) = matmul_result_buffer.matrix().reduce( + Eigen::array({0}), Reducer()); + } + + // Given a range [batch_and_row_begin, batch_and_row_end) which is a + // contiguous subset of [0, num_rows * batch_size), calls the function + // fn(batch_idx, row_begin, row_end) for each batch index + // and the row range [row_begin, row_end) contained in the batch. + void HandleBatchAndRowRange( + const int64_t num_rows, const int64_t batch_and_row_begin, + const int64_t batch_and_row_end, + const std::function& fn) { + // Obtain the batch indices overlapping with the current shard. + const int64_t batch_begin = batch_and_row_begin / num_rows; + const int64_t batch_end_inclusive = batch_and_row_end / num_rows; + + for (int64_t batch_idx = batch_begin; batch_idx <= batch_end_inclusive; + ++batch_idx) { + // Find the contiguous set of rows which are contained in this shard as + // well as the current batch. We intersect with interval [batch_idx * + // num_rows, (batch_idx + 1) * num_rows) which denotes the current batch. + const int64_t current_batch_row_begin = + std::max(batch_and_row_begin, batch_idx * num_rows); + const int64_t current_batch_row_end = + std::min(batch_and_row_end, (batch_idx + 1) * num_rows); + + const int64_t row_begin = current_batch_row_begin % num_rows; + const int64_t num_shard_rows = + current_batch_row_end - current_batch_row_begin; + // Edge case for when current_batch_row_end is the first index of a new + // row. + if (num_shard_rows == 0) continue; + + fn(batch_idx, row_begin, row_begin + num_shard_rows); + } + } + + // Transposes (and optionally, conjugates) a given Tensor. 
Also allocates the + // required memory for the output Tensor. + absl::Status TransposeAndConjugateTensor(OpKernelContext* ctx, + const Tensor& input, bool conjugate, + Tensor* output) { + TensorShape transposed_shape = input.shape(); + transposed_shape.set_dim(input.dims() - 1, + input.dim_size(input.dims() - 2)); + transposed_shape.set_dim(input.dims() - 2, + input.dim_size(input.dims() - 1)); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(DataTypeToEnum::value, transposed_shape, output)); + return TransposeAndConjugateAllocatedTensor(ctx, input, conjugate, output); + } + + // Transposes (and optionally, conjugates) a given Tensor. The output should + // be already allocated. + absl::Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx, + const Tensor& input, + bool conjugate, + Tensor* output) { + if (conjugate) { + TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose( + ctx->eigen_device(), input, output)); + } else { + TF_RETURN_IF_ERROR( + DoMatrixTranspose(ctx->eigen_device(), input, output)); + } + return absl::OkStatus(); + } +}; + +// GPU Kernel to compute sparse-dense matrix multiplication. +template +class CSRMatMulGPUOp : public CSRMatMulOp { + using SparseMatrix = Eigen::SparseMatrix; + using Matrix = + Eigen::Matrix; + using ConstMatrixMap = Eigen::Map; + using MatrixMap = Eigen::Map; + + public: + explicit CSRMatMulGPUOp(OpKernelConstruction* c) + : CSRMatMulOp(c) {} + + ~CSRMatMulGPUOp() override {} + + void Compute(OpKernelContext* ctx) final { + const CSRSparseMatrix* a_matrix; + OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &a_matrix)); + const Tensor& b_t = ctx->input(1); + + int rank; + int64_t batch_size; + OP_REQUIRES_OK(ctx, + this->ValidateInputs(*a_matrix, b_t, &rank, &batch_size)); + + const Tensor& a_dense_shape_t = a_matrix->dense_shape(); + TensorShape a_dense_tensor_shape; + auto a_dense_shape = a_dense_shape_t.vec(); + OP_REQUIRES_OK( + ctx, TensorShapeUtils::MakeShape(a_dense_shape, &a_dense_tensor_shape)); + + const int row_dim = (rank == 2) ? 0 : 1; + const int64_t a_outer_dim = a_dense_tensor_shape.dim_size( + this->transpose_a_ ? row_dim + 1 : row_dim); + const int64_t b_inner_dim = + b_t.shape().dim_size(this->transpose_b_ ? row_dim + 1 : row_dim); + const int64_t b_outer_dim = + b_t.dim_size(this->transpose_b_ ? row_dim : row_dim + 1); + const int64_t b_slice_size = b_inner_dim * b_outer_dim; + + TensorShape c_shape; + if (rank == 3) { + OP_REQUIRES_OK(ctx, c_shape.AddDimWithStatus(batch_size)); + } + if (this->transpose_output_) { + OP_REQUIRES_OK(ctx, c_shape.AddDimWithStatus(b_outer_dim)); + OP_REQUIRES_OK(ctx, c_shape.AddDimWithStatus(a_outer_dim)); + } else { + OP_REQUIRES_OK(ctx, c_shape.AddDimWithStatus(a_outer_dim)); + OP_REQUIRES_OK(ctx, c_shape.AddDimWithStatus(b_outer_dim)); + } + + const int64_t c_matrix_lhs = c_shape.dim_size(row_dim); + const int64_t c_matrix_rhs = c_shape.dim_size(row_dim + 1); + const int64_t c_slice_size = c_matrix_lhs * c_matrix_rhs; + Tensor* c_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t)); + + const GPUDevice& d = ctx->eigen_device(); + bool use_matrix_vector_multiply = (b_outer_dim == 1); +#if TENSORFLOW_USE_ROCM + // ROCm hipsparse does not implement csrmv with transposed input a + use_matrix_vector_multiply = + use_matrix_vector_multiply && !this->transpose_a_; +#endif + if (use_matrix_vector_multiply) { + // Call matrix-vector multiply if b is a vector. 
+ TTypes::ConstVec a_dense_shape_comp( + a_dense_shape.data() + row_dim, 2); + Tensor b_conj_t; + const T* b_base_ptr = b_t.template flat().data(); + bool conjugate_a = this->conjugate_a_; + bool conjugate_output = this->conjugate_output_; + if (this->conjugate_b_) { + if (conjugate_a) { + // In this case we can use the identity + // conj(a) * conj(b) = conj(a * b) + // instead of creating a conjugated copy of b. + conjugate_a = false; + conjugate_output = !conjugate_output; + } else { + OP_REQUIRES_OK( + ctx, ctx->forward_input_or_allocate_temp( + {1}, DataTypeToEnum::value, b_t.shape(), &b_conj_t)); + functor::maybe_conj::run(d, b_t, &b_conj_t); + b_base_ptr = b_conj_t.template flat().data(); + } + } + + functor::CSRSparseMatrixMatVec csr_spmv(this->transpose_a_, + conjugate_a); + for (int i = 0; i < batch_size; ++i) { + auto a_row_ptr = a_matrix->row_pointers_vec(i); + auto a_col_ind = a_matrix->col_indices_vec(i); + auto a_values = a_matrix->values_vec(i); + ConstCSRComponent a_comp{a_row_ptr, a_col_ind, a_values, + a_dense_shape_comp}; + const T* b_i = b_base_ptr + i * b_slice_size; + T* c_i = &c_t->template flat()(i * c_slice_size); + absl::Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i); + OP_REQUIRES_OK(ctx, s); + } + if (conjugate_output) { + functor::maybe_conj_inplace::run(d, c_t); + } + return; + } + + functor::CSRSparseMatrixMatMul csr_spmmadd( + this->transpose_output_); + + Tensor c_mat_col_major_t; + if (!this->transpose_output_) { + // If transpose_output is false, we'll need to transpose the (col + // major) output of the csrgemm call to get proper (row-major) + // output. Which means we need to keep a temporary buffer to + // store the intermediate gemm output. + TensorShape c_mat_col_major_shape; + if (rank == 2) { + c_mat_col_major_shape = TensorShape({c_matrix_rhs, c_matrix_lhs}); + } else { + c_mat_col_major_shape = + TensorShape({batch_size, c_matrix_rhs, c_matrix_lhs}); + } + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(DataTypeToEnum::value, + c_mat_col_major_shape, &c_mat_col_major_t)); + } + + // If transpose_output is true, return the direct (column-major i.e., + // transposed) output of the csrgemm call. Otherwise we'll need + // to transpose it to row major format. + auto c_mat_col_major = (this->transpose_output_) + ? c_t->flat() + : c_mat_col_major_t.flat(); + + // Possibly transpose a. + const CSRSparseMatrix* a_input_matrix; + // If we need to transpose a, we will store the result temporarily + // in the object below. + CSRSparseMatrix a_matrix_transposed; + if (!this->transpose_a_) { + a_input_matrix = a_matrix; + } else { + functor::CSRSparseMatrixTranspose transpose; + OP_REQUIRES_OK(ctx, transpose(ctx, this->conjugate_a_, *a_matrix, + &a_matrix_transposed)); + a_input_matrix = &a_matrix_transposed; + } + + auto a_input_dense_shape = a_input_matrix->dense_shape().vec(); + + // Possibly transpose b. 
+ Tensor b_t_input; + if (!this->transpose_b_) { + b_t_input = b_t; + } else { + TensorShape b_t_transposed_shape; + if (rank == 3) { + OP_REQUIRES_OK(ctx, b_t_transposed_shape.AddDimWithStatus(batch_size)); + } + OP_REQUIRES_OK(ctx, b_t_transposed_shape.AddDimWithStatus( + b_t.dim_size(row_dim + 1))); + OP_REQUIRES_OK( + ctx, b_t_transposed_shape.AddDimWithStatus(b_t.dim_size(row_dim))); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum::value, + b_t_transposed_shape, &b_t_input)); + const GPUDevice& d = ctx->eigen_device(); + if (this->conjugate_b_) { + OP_REQUIRES_OK(ctx, DoConjugateMatrixTranspose(d, b_t /*input*/, + &b_t_input /*output*/)); + } else { + OP_REQUIRES_OK( + ctx, DoMatrixTranspose(d, b_t /*input*/, &b_t_input /*output*/)); + } + } + + // Dense shape of a batch component of A. + TTypes::ConstVec a_input_dense_shape_comp( + a_input_dense_shape.data() + row_dim, 2); + + auto b = b_t_input.flat(); + + for (int i = 0; i < batch_size; ++i) { + auto a_row_ptr = a_input_matrix->row_pointers_vec(i); + auto a_col_ind = a_input_matrix->col_indices_vec(i); + auto a_values = a_input_matrix->values_vec(i); + typename TTypes::UnalignedConstMatrix b_i(b.data() + i * b_slice_size, + {b_inner_dim, b_outer_dim}); + typename TTypes::UnalignedMatrix c_mat_col_major_i( + c_mat_col_major.data() + i * c_slice_size, + {c_matrix_lhs, c_matrix_rhs}); + ConstCSRComponent a_comp{a_row_ptr, a_col_ind, a_values, + a_input_dense_shape_comp}; + absl::Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i); + OP_REQUIRES_OK(ctx, s); + } + + if (!this->transpose_output_) { + // We need to return values in row major format, so transpose + // the column-major values in c_mat_col_major_t to row-major output c_t. + OP_REQUIRES_OK(ctx, DoMatrixTranspose(d, /*input=*/c_mat_col_major_t, + /*output=*/c_t)); + } + if (this->conjugate_output_) { + functor::maybe_conj_inplace::run(d, c_t); + } + } +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +namespace functor { + +namespace gpu_data_type { + +// GPUDataType::type translates from a C++ type (e.g. float) to a +// GPUDataType_t (e.g. CUDA_R_32F). 
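As a concrete illustration of the trait pattern described above, one specialization maps a C++ scalar type to the corresponding cuSPARSE data-type enum; a sketch of the float and double cases for the CUDA branch (the header location of cudaDataType_t is an assumption, and the vendored header additionally covers Eigen::half and the complex types):

#include <library_types.h>  // cudaDataType_t, CUDA_R_32F, ... (assumed CUDA toolkit header)

// Sketch: C++ type -> GPU data-type enum trait.
template <typename T>
struct GPUDataType;

template <>
struct GPUDataType<float> {
  static constexpr cudaDataType_t type = CUDA_R_32F;
};

template <>
struct GPUDataType<double> {
  static constexpr cudaDataType_t type = CUDA_R_64F;
};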
+template +struct GPUDataType; + +template <> +struct GPUDataType { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_R_16F; +#else + static constexpr hipDataType type = HIP_R_16F; +#endif +}; + +template <> +struct GPUDataType { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_R_32F; +#else + static constexpr hipDataType type = HIP_R_32F; +#endif +}; + +template <> +struct GPUDataType> { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_C_32F; +#else + static constexpr hipDataType type = HIP_C_32F; +#endif +}; + +template <> +struct GPUDataType { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_R_64F; +#else + static constexpr hipDataType type = HIP_R_64F; +#endif +}; + +template <> +struct GPUDataType> { +#if GOOGLE_CUDA + static constexpr cudaDataType_t type = CUDA_C_64F; +#else + static constexpr hipDataType type = HIP_C_64F; +#endif +}; + +} // namespace gpu_data_type + +template +class CSRSparseMatrixMatMul { + public: + explicit CSRSparseMatrixMatMul(const bool transpose_output) + : transpose_output_(transpose_output) {} + + Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, + typename TTypes::UnalignedConstMatrix b, + typename TTypes::UnalignedMatrix c) { + GpuSparse cuda_sparse(ctx); + TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); + { + // Use Csrmm/SpMM to calculate: + // C = alpha * op(A) * op(B) + beta * C + // where alpha = 1.0, beta = 0.0, A is sparse and B and C are dense. + // Note that Csrmm/Spmm assumes B and C are in column-major form; so we + // use transB == true, and manually transpose the output in place + // using blasgeam. + // TODO(ebrevdo,rmlarsen): Add support for transposition and adjoint. + + // Create alpha and beta scalars; alpha = 1.0, beta = 0.0 + // TODO(ebrevdo,rmlarsen): Add support for non-trivial alpha and beta. + const T alpha = 1; + const T beta = 0; + + // A is (m, k), Bt is (ldb, k) and Ct is (ldc, n) + const int k = b.dimension(0); + DCHECK_EQ(k, a.dense_shape_host(1)); + + // If transpose_output_ is true, then the c matrix we receive + // here is the direct row major output (into which we will store + // csrgemm's col major output). Otherwise it's a + // temporary tensor that will store the column major output that + // will eventually be transposed. + const int m = c.dimension(transpose_output_ ? 1 : 0); + const int n = c.dimension(transpose_output_ ? 0 : 1); + DCHECK_EQ(m, a.dense_shape_host(0)); + DCHECK_EQ(n, b.dimension(1)); + const int nnz = a.values.size(); + DCHECK_EQ(nnz, a.col_ind.size()); + + // ldb: leading dimension of B. If op(B)=B, it must be at least max(1, k) + // if op(A) = A and at least max (1, m) otherwise. If op(B) != B, it must + // be at least max(1, n). + const int ldb = n; + // ldc: leading dimension of C. It must be at least max(1, m) if + // op(A) = A and at least max(1, k) otherwise. + const int ldc = m; + + // transA must be non-transpose if transB is transpose (cusparse + // limitation). +#if GOOGLE_CUDA + const gpusparseOperation_t transA = CUSPARSE_OPERATION_NON_TRANSPOSE; +#elif TENSORFLOW_USE_ROCM + const gpusparseOperation_t transA = HIPSPARSE_OPERATION_NON_TRANSPOSE; +#endif + + // transB: b is row-major, and cusparse requires col-major b (or + // equivalently transB == transpose). this version is actually more + // efficient. 
+#if GOOGLE_CUDA && CUDA_VERSION >= 10020 + + const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE; + gpusparseSpMatDescr_t matA; + gpusparseDnMatDescr_t matB, matC; + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr( + &matA, m, k, nnz, const_cast(a.row_ptr.data()), + const_cast(a.col_ind.data()), const_cast(a.values.data()), + CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_BASE_ZERO, + gpu_data_type::GPUDataType::type)); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat( + &matB, n, k, ldb, const_cast(b.data()), + gpu_data_type::GPUDataType::type, CUSPARSE_ORDER_COL)); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateDnMat( + &matC, m, n, ldc, c.data(), gpu_data_type::GPUDataType::type, + CUSPARSE_ORDER_COL)); + +#if CUDA_VERSION >= 12000 + cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT; +#else + cusparseSpMMAlg_t algo = CUSPARSE_MM_ALG_DEFAULT; +#endif + size_t bufferSize = 0; + TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize( + transA, transB, &alpha, matA, matB, &beta, matC, algo, &bufferSize)); + + Tensor buffer; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(bufferSize)}), &buffer)); + DCHECK(buffer.flat().data() != nullptr); + + TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB, + &beta, matC, algo, + buffer.flat().data())); + + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matB)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroyDnMat(matC)); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseDestroySpMat(matA)); + +#elif TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 40200 + // Use SPMM + const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE; + gpusparseSpMatDescr_t matA; + gpusparseDnMatDescr_t matB, matC; + + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateCsr( + &matA, m, k, nnz, const_cast(a.row_ptr.data()), + const_cast(a.col_ind.data()), const_cast(a.values.data()), + HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_32I, HIPSPARSE_INDEX_BASE_ZERO, + gpu_data_type::GPUDataType::type)); + + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateDnMat( + &matB, n, k, ldb, const_cast(b.data()), + gpu_data_type::GPUDataType::type, HIPSPARSE_ORDER_COLUMN)); + + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateDnMat( + &matC, m, n, ldc, c.data(), gpu_data_type::GPUDataType::type, + HIPSPARSE_ORDER_COLUMN)); + + size_t bufferSize = 0; + TF_RETURN_IF_ERROR(cuda_sparse.SpMMBufferSize( + transA, transB, &alpha, matA, matB, &beta, matC, + HIPSPARSE_MM_ALG_DEFAULT, &bufferSize)); + + Tensor buffer; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(bufferSize)}), &buffer)); + DCHECK(buffer.flat().data() != nullptr); + + TF_RETURN_IF_ERROR(cuda_sparse.SpMM(transA, transB, &alpha, matA, matB, + &beta, matC, HIPSPARSE_MM_ALG_DEFAULT, + buffer.flat().data())); + + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseDestroyDnMat(matB)); + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseDestroyDnMat(matC)); + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseDestroySpMat(matA)); + +#else + +#if GOOGLE_CUDA + + const gpusparseOperation_t transB = CUSPARSE_OPERATION_TRANSPOSE; + + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); + +#elif TENSORFLOW_USE_ROCM + + const gpusparseOperation_t transB = HIPSPARSE_OPERATION_TRANSPOSE; + + gpusparseMatDescr_t descrA; + 
TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + se::wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseSetMatIndexBase( + descrA, HIPSPARSE_INDEX_BASE_ZERO)); +#endif // GOOGLE_CUDA + + TF_RETURN_IF_ERROR( + cuda_sparse.Csrmm(transA, transB, m, n, k, nnz, &alpha, descrA, + a.values.data(), a.row_ptr.data(), a.col_ind.data(), + b.data(), ldb, &beta, c.data(), ldc)); + +#endif // GOOGLE_CUDA && CUDA_VERSION >= 10020 + } + + return OkStatus(); + } + + private: + bool transpose_output_; +}; + +template +class CSRSparseMatrixMatVec { + public: + CSRSparseMatrixMatVec(bool transpose_a, bool conjugate_a) + : transA_(TransposeAndConjugateToGpuSparseOp(transpose_a, conjugate_a, + &status_)) {} + + Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, + const T* x, T* y) { + TF_RETURN_IF_ERROR(status_); + GpuSparse cuda_sparse(ctx); + TF_RETURN_IF_ERROR(cuda_sparse.Initialize()); + { + // Use Csrmv to calculate: + // y = alpha * op(A) * x + beta * y + // where alpha = 1.0, beta = 0.0, A is a sparse matrix and x and y are + // dense vectors. + + // Create alpha and beta scalars; alpha = 1.0, beta = 0.0 + // TODO(rmlarsen,ebrevdo): Add support for general alpha, beta. + const T alpha = 1; + const T beta = 0; + +#if GOOGLE_CUDA && CUDA_VERSION < 10020 + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatType(descrA, CUSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR( + cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO)); +#elif TENSORFLOW_USE_ROCM + gpusparseMatDescr_t descrA; + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateMatDescr(&descrA)); + TF_RETURN_IF_GPUSPARSE_ERROR( + se::wrap::hipsparseSetMatType(descrA, HIPSPARSE_MATRIX_TYPE_GENERAL)); + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseSetMatIndexBase( + descrA, HIPSPARSE_INDEX_BASE_ZERO)); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + + const int m = a.dense_shape_host(0); + const int n = a.dense_shape_host(1); + const int nnz = a.values.size(); + DCHECK_EQ(nnz, a.col_ind.size()); +#if GOOGLE_CUDA && (CUDA_VERSION >= 10020) + TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, + a.values.data(), a.row_ptr.data(), + a.col_ind.data(), x, &beta, y)); +#else + TF_RETURN_IF_ERROR(cuda_sparse.Csrmv(transA_, m, n, nnz, &alpha, descrA, + a.values.data(), a.row_ptr.data(), + a.col_ind.data(), x, &beta, y)); +#endif + } + + return OkStatus(); + } + + private: + Status status_; + const gpusparseOperation_t transA_; +}; + +} // namespace functor + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_MAT_MUL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/sparse_matrix.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/sparse_matrix.h new file mode 100644 index 00000000..8e5ff45f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/sparse_matrix.h @@ -0,0 +1,655 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_ + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +class CSRSparseMatrix { + // CreateCSRSparseMatrix is the main method used to construct a + // CSRSparseMatrix. The representations for both 2D and 3D + // (batched) CSR Sparse Matrices are the same: + // + // dtype: The datatype of the values. + // dense_shape: The dense shape of the matrix. + // * Host int64 vector, size 2 or 3. + // * Takes on values: (rows, cols) or (batch_size, rows, cols). + // batch_pointers: Batch offset pointers into col_indices and values. + // * Host int32 vector, size (batch_size + 1). + // * Takes on values: (0, nnz[0], nnz[0] + nnz[1], ..., total_nnz). + // row_pointers: Row offset pointers into col_indices and values. + // * Device int32 vector, size ((rows + 1) * batch_size). + // * Each block of size (rows + 1) takes on values: + // (0, num_rows{b}[0], num_rows{b}[0] + num_rows{b}[1], ..., nnz[b]). + // for b = 0 .. batch_size - 1. + // col_indices: Column values for the given row and column index. + // * Device int32 vector, size total_nnz. + // values: Actual values for the given row and column index. + // * Device dtype vector, size total_nnz. + // + // The storage agreement is such that for a given (batch, row, ix): + // offset = batch_pointers(batch) + row_pointers(batch * (rows + 1) + row) + // col = col_indices(offset + ix) + // val = values(offset + ix) + // where ix < #nnz columns in (batch, row). + // Then: + // matrix(batch, row, col) = val. + // + // All other elements in the dense representation are treated as 0 / empty. 
+ // + // For example, for a 2D sparse matrix m shaped (3, 4) such that: + // + // m[0, 0] = 1.0 + // m[0, 1] = 2.0 + // m[0, 2] = 3.0 + // m[2, 2] = 4.0 + // m[2, 3] = 5.0 + // + // The corresponding representation is: + // + // dtype: DT_FLOAT + // dense_shape: (3, 4) + // batch_pointers: (0, 5) + // row_pointers: (0, 3, 3, 5) + // col_indices: concat((0, 1, 2), (), (2, 3)) + // values: concat((1.0, 2.0, 3.0), (), (4.0, 5.0)) + // + // For a 3D sparse matrix m shaped (2, 3, 4) such that: + // + // m[0, 0, 0] = 1.0 + // m[0, 0, 2] = 2.0 + // m[0, 2, 3] = 3.0 + // m[1, 0, 3] = 4.0 + // m[1, 1, 0] = 5.0 + // + // The corresponding representation is: + // dtype: DT_FLOAT + // dense_shape: (2, 3, 4) + // batch_pointers: (0, 3, 5) + // row_pointers: concat((0, 2, 2, 3), (0, 1, 2, 2)) + // col_indices: concat(concat((0, 2), (), (3,)), + // concat((3,), (), (0,))) + // values: concat(concat((1.0, 2.0), (3.0,), ()), + /// concat((4.0,), (5.0,), ())) + // + public: + static constexpr const char kTypeName[] = "tensorflow::CSRSparseMatrix"; + + CSRSparseMatrix() : metadata_{false, DT_INVALID} {} + + CSRSparseMatrix(const CSRSparseMatrix& rhs) + : metadata_(rhs.metadata_), + dense_shape_(rhs.dense_shape_), + batch_pointers_(rhs.batch_pointers_), + row_pointers_(rhs.row_pointers_), + col_indices_(rhs.col_indices_), + values_(rhs.values_) { + SetupVecs(); + } + + CSRSparseMatrix(CSRSparseMatrix&& rhs) + : metadata_(rhs.metadata_), + dense_shape_(std::move(rhs.dense_shape_)), + batch_pointers_(std::move(rhs.batch_pointers_)), + row_pointers_(std::move(rhs.row_pointers_)), + col_indices_(std::move(rhs.col_indices_)), + values_(std::move(rhs.values_)) { + SetupVecs(); + rhs.metadata_.validated = false; + rhs.metadata_.dtype = DT_INVALID; + rhs.ClearVecs(); + } + + CSRSparseMatrix& operator=(CSRSparseMatrix&& rhs) { + if (this == &rhs) return *this; + metadata_ = rhs.metadata_; + metadata_.validated = rhs.metadata_.validated; + dense_shape_ = std::move(rhs.dense_shape_); + batch_pointers_ = std::move(rhs.batch_pointers_); + row_pointers_ = std::move(rhs.row_pointers_); + col_indices_ = std::move(rhs.col_indices_); + values_ = std::move(rhs.values_); + SetupVecs(); + rhs.metadata_ = {false, DT_INVALID}; + rhs.ClearVecs(); + return *this; + } + + static absl::Status CreateCSRSparseMatrix( + DataType dtype, + const Tensor& dense_shape, // on host + const Tensor& batch_pointers, // on host + const Tensor& row_pointers, const Tensor& col_indices, + const Tensor& values, CSRSparseMatrix* matrix) { + *matrix = CSRSparseMatrix(dtype, dense_shape, batch_pointers, row_pointers, + col_indices, values); + absl::Status s = matrix->Validate(); + matrix->metadata_.validated = s.ok(); + matrix->SetupVecs(); + return s; + } + + absl::Status Validate() const { + return ValidateTypesAndShapes(metadata_.dtype, dense_shape_, + batch_pointers_, row_pointers_, col_indices_, + values_); + } + + void Clear() { + metadata_ = {false, DT_INVALID}; + dense_shape_ = Tensor(); + batch_pointers_ = Tensor(); + row_pointers_ = Tensor(); + col_indices_ = Tensor(); + values_ = Tensor(); + ClearVecs(); + } + + bool valid() const { + return metadata_.validated && dense_shape_.IsInitialized() && + batch_pointers_.IsInitialized() && row_pointers_.IsInitialized() && + col_indices_.IsInitialized() && values_.IsInitialized() && + dense_shape_.NumElements() > 1 && + batch_pointers_.NumElements() > 0 && row_pointers_.NumElements() > 0; + } + + DataType dtype() const { + DCHECK(valid()); + return metadata_.dtype; + } + + inline int dims() const { + 
DCHECK(valid()); + return dense_shape_.NumElements(); + } + + inline int nnz(int batch) const { + DCHECK_LT(batch, batch_size()); + return (*batch_pointers_vec_)(batch + 1) - (*batch_pointers_vec_)(batch); + } + + inline int batch_offset(int batch) const { + DCHECK_LT(batch, batch_size()); + return (*batch_pointers_vec_)(batch); + } + + inline int total_nnz() const { + DCHECK(valid()); + return (*batch_pointers_vec_)(batch_size()); + } + + inline Tensor& dense_shape() { + DCHECK(valid()); + return dense_shape_; + } + + inline const Tensor& dense_shape() const { + DCHECK(valid()); + return dense_shape_; + } + + inline TTypes::UnalignedVec row_pointers_vec(int batch) { + DCHECK(valid()); + DCHECK_LT(batch, batch_size()); + const int64_t rows = dense_shape().vec()((dims() == 2) ? 0 : 1); + const int offset = batch * (rows + 1); + return TTypes::UnalignedVec(row_pointers_vec_->data() + offset, + rows + 1); + } + + inline TTypes::UnalignedConstVec row_pointers_vec(int batch) const { + DCHECK(valid()); + DCHECK_LT(batch, batch_size()); + const int64_t rows = dense_shape().vec()((dims() == 2) ? 0 : 1); + const int offset = batch * (rows + 1); + return TTypes::UnalignedConstVec(row_pointers_vec_->data() + offset, + rows + 1); + } + + inline TTypes::UnalignedVec col_indices_vec(int batch) { + DCHECK(valid()); + DCHECK_LT(batch, batch_size()); + const int offset = (*batch_pointers_vec_)(batch); + const int nnz_in_batch = nnz(batch); + return TTypes::UnalignedVec(col_indices_vec_->data() + offset, + nnz_in_batch); + } + + inline TTypes::UnalignedConstVec col_indices_vec(int batch) const { + DCHECK(valid()); + DCHECK_LT(batch, batch_size()); + const int offset = (*batch_pointers_vec_)(batch); + const int nnz_in_batch = nnz(batch); + return TTypes::UnalignedConstVec(col_indices_vec_->data() + offset, + nnz_in_batch); + } + + template + inline typename TTypes::UnalignedVec values_vec(int batch) { + DCHECK(valid()); + DCHECK_LT(batch, batch_size()); + const int offset = (*batch_pointers_vec_)(batch); + const int nnz_in_batch = nnz(batch); + return typename TTypes::UnalignedVec(values().vec().data() + offset, + nnz_in_batch); + } + + template + inline typename TTypes::UnalignedConstVec values_vec(int batch) const { + DCHECK(valid()); + DCHECK_LT(batch, batch_size()); + const int offset = (*batch_pointers_vec_)(batch); + const int nnz_in_batch = nnz(batch); + return typename TTypes::UnalignedConstVec( + values().vec().data() + offset, nnz_in_batch); + } + + inline Tensor& row_pointers() { + DCHECK(valid()); + return row_pointers_; + } + + inline const Tensor& row_pointers() const { + DCHECK(valid()); + return row_pointers_; + } + + inline Tensor& col_indices() { + DCHECK(valid()); + return col_indices_; + } + + inline const Tensor& col_indices() const { + DCHECK(valid()); + return col_indices_; + } + + inline Tensor& values() { + DCHECK(valid()); + return values_; + } + + inline const Tensor& values() const { + DCHECK(valid()); + return values_; + } + + inline Tensor& batch_pointers() { + DCHECK(valid()); + return batch_pointers_; + } + + inline const Tensor& batch_pointers() const { + DCHECK(valid()); + return batch_pointers_; + } + + std::string TypeName() const { return kTypeName; } + + // TODO(ebrevdo): A better debug string. + std::string DebugString() const { return dense_shape_.DebugString(); } + + // Returns the number of elements. This is equal to 1 if the + // CSRSparseMatrix is a singleton matrix (dense_shape is length 2). 
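Tying the accessors back to the storage agreement documented at the top of the class: reading a single element m(batch, row, col) is a batch offset plus a row-pointer offset, followed by a scan of that row's slice of col_indices. A rough standalone sketch on raw arrays (illustrative only; rows is the per-batch row count from dense_shape):

#include <cstdint>

// Sketch: element lookup under the CSRSparseMatrix storage agreement.
// Absent entries read as 0.
float CsrLookup(const int32_t* batch_pointers, const int32_t* row_pointers,
                const int32_t* col_indices, const float* values,
                int64_t rows, int batch, int row, int col) {
  const int32_t row_begin =
      batch_pointers[batch] + row_pointers[batch * (rows + 1) + row];
  const int32_t row_end =
      batch_pointers[batch] + row_pointers[batch * (rows + 1) + row + 1];
  for (int32_t ix = row_begin; ix < row_end; ++ix)
    if (col_indices[ix] == col) return values[ix];
  return 0.0f;  // not stored
}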
+ int batch_size() const { + DCHECK(valid()); + return batch_pointers_.NumElements() - 1; + } + + bool Decode(const VariantTensorData& p) { + if (p.tensors_.empty()) return false; + Metadata metadata; + if (!p.get_metadata(&metadata)) return false; + const bool validated = metadata.validated; + const DataType dtype = metadata.dtype; + + // p.tensors_ should contain tensors {dense_shape, batch_pointers, + // row_pointers, col_indices, values}. + if (p.tensors_.size() != 5) return false; + + Tensor dense_shape = p.tensors_[0]; + if (dense_shape.dtype() != DT_INT64) return false; + if (dense_shape.dims() != 1) return false; + int rank = dense_shape.dim_size(0); + if (rank < 2 || rank > 3) return false; + + Tensor batch_pointers(p.tensors_[1]); + Tensor row_pointers(p.tensors_[2]); + Tensor col_indices(p.tensors_[3]); + Tensor values(p.tensors_[4]); + + // Check that the validated bool is consistent with the data. + absl::Status s = ValidateTypesAndShapes(dtype, dense_shape, batch_pointers, + row_pointers, col_indices, values); + if (s.ok() != validated) return false; + + // Save to this object. + metadata_ = metadata; + dense_shape_ = std::move(dense_shape); + batch_pointers_ = std::move(batch_pointers); + row_pointers_ = std::move(row_pointers); + col_indices_ = std::move(col_indices); + values_ = std::move(values); + SetupVecs(); + return true; + } + + void Encode(VariantTensorData* p) const { + DCHECK(valid()); + + // Store metadata_ to p's metadata + p->set_metadata(metadata_); + + // Store dense_shape, row_pointers, col_indices, and values to p->tensors_. + p->tensors_.reserve(5); + p->tensors_.push_back(dense_shape_); + p->tensors_.push_back(batch_pointers_); + p->tensors_.push_back(row_pointers_); + p->tensors_.push_back(col_indices_); + p->tensors_.push_back(values_); + } + + // This static method copies CSRSparseMatrices in all directions: + // Host->Device, Device->Host, and Device->Device. + static absl::Status DeviceCopy( + const CSRSparseMatrix& from, CSRSparseMatrix* to, + const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { + VLOG(2) << "DeviceCopy from type: " << DataTypeString(from.dtype()) + << " and shape: " << from.dense_shape().DebugString(); + Tensor to_row_ptr(DT_INT32); + Tensor to_col_ind(DT_INT32); + Tensor to_values(from.dtype()); + TF_RETURN_IF_ERROR(copy(from.row_pointers(), &to_row_ptr)); + TF_RETURN_IF_ERROR(copy(from.col_indices(), &to_col_ind)); + TF_RETURN_IF_ERROR(copy(from.values(), &to_values)); + return CreateCSRSparseMatrix(from.dtype(), + from.dense_shape(), // Always on host. + from.batch_pointers(), // Always on host. 
+ to_row_ptr, to_col_ind, to_values, to); + } + + private: + CSRSparseMatrix(DataType dtype, const Tensor& dense_shape, + const Tensor& batch_pointers, const Tensor& row_pointers, + const Tensor& col_indices, const Tensor& values) + : metadata_{false, dtype}, + dense_shape_(dense_shape), + batch_pointers_(batch_pointers), + row_pointers_(row_pointers), + col_indices_(col_indices), + values_(values) {} + + void SetupVecs() { + if (!metadata_.validated) return; + batch_pointers_vec_.reset( + new TTypes::Vec(batch_pointers_.vec())); + row_pointers_vec_.reset(new TTypes::Vec(row_pointers_.vec())); + col_indices_vec_.reset(new TTypes::Vec(col_indices_.vec())); + } + + void ClearVecs() { + batch_pointers_vec_.reset(); + row_pointers_vec_.reset(); + col_indices_vec_.reset(); + } + + static absl::Status ValidateTypesAndShapes(DataType dtype, + const Tensor& dense_shape, + const Tensor& batch_pointers, + const Tensor& row_pointers, + const Tensor& col_indices, + const Tensor& values) { + // TODO(ebrevdo): Consider adding support for other floating point types + // (namely, float16). + if (dtype != DT_FLOAT && dtype != DT_DOUBLE && dtype != DT_COMPLEX64 && + dtype != DT_COMPLEX128) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: dtype = ", DataTypeString(dtype), + " not in {float32, float64, complex64, complex128}"); + } + // dense_shape checks + if (dense_shape.dtype() != DT_INT64) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: dense_shape.dtype() = ", + DataTypeString(dense_shape.dtype()), " != int64"); + } + if (dense_shape.dims() != 1) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: dense_shape should be a vector, but saw " + "tensor: ", + dense_shape.DebugString()); + } + int rank = dense_shape.dim_size(0); + if (rank < 2 || rank > 3) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: dense_shape should be a 2- or 3- vector, " + "but saw: ", + dense_shape.SummarizeValue(5)); + } + auto dense_shape_t = dense_shape.vec(); + const int64_t batch_size = (rank == 2) ? 1 : dense_shape_t(0); + const int64_t num_rows = (rank == 2) ? dense_shape_t(0) : dense_shape_t(1); + + if (batch_pointers.dtype() != DT_INT32) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: batch_pointers.dtype() = ", + DataTypeString(batch_pointers.dtype()), " != int32"); + } + if (batch_pointers.dims() != 1) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: batch_indices is not a vector, saw " + "shape: ", + batch_pointers.shape().DebugString()); + } + + // batch size checks + if (batch_size != batch_pointers.NumElements() - 1) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: dense_shape is ", + dense_shape.SummarizeValue(5), + " but batch pointers implies batch size is ", + batch_pointers.NumElements() - 1); + } + + if (row_pointers.dtype() != DT_INT32) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: row_pointers.dtype() = ", + DataTypeString(row_pointers.dtype()), " != int32"); + } + if (row_pointers.dims() != 1) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: row_pointers is not a vector, saw " + "shape: ", + row_pointers.shape().DebugString()); + } + if (row_pointers.dim_size(0) != batch_size * (num_rows + 1)) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: row_pointers should have size batch_size " + "* (num_rows + 1), saw shapes: ", + dense_shape.DebugString(), " vs. 
", + row_pointers.shape().DebugString()); + } + if (col_indices.dtype() != DT_INT32) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: col_indices.dtype() = ", + DataTypeString(col_indices.dtype()), " != int32"); + } + if (col_indices.dims() != 1) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: col_indices is not a vector, saw shape: ", + col_indices.shape().DebugString()); + } + if (values.dtype() != dtype) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: values.dtype() = ", + DataTypeString(values.dtype()), + " != dtype = ", DataTypeString(dtype)); + } + if (values.dims() != 1) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: values is not a vector, saw shape: ", + values.shape().DebugString()); + } + if (col_indices.dim_size(0) != values.dim_size(0)) { + return errors::InvalidArgument( + "CSRSparseMatrix::Validate: size(col_indices) = ", + col_indices.dim_size(0), " != size(values) = ", values.dim_size(0)); + } + return absl::OkStatus(); + } + + struct Metadata { + bool validated; + DataType dtype; + }; + Metadata metadata_; + Tensor dense_shape_; + Tensor batch_pointers_; + Tensor row_pointers_; + Tensor col_indices_; + Tensor values_; + std::unique_ptr::Vec> batch_pointers_vec_; + std::unique_ptr::Vec> row_pointers_vec_; + std::unique_ptr::Vec> col_indices_vec_; +}; + +// Call BinaryFunctor()(ctx, a, b, c) +// where T depends on a.dtype(). T will be one of: float, double, +// complex64, complex128. +template class BinaryFunctor> +absl::Status CSRSparseMatrixBinaryHelper(OpKernelContext* ctx, + const CSRSparseMatrix& a, + const CSRSparseMatrix& b, + CSRSparseMatrix* c) { + DataType dt = a.dtype(); + if (dt != b.dtype()) { + return errors::InvalidArgument( + "CSRSparseMatrixBinaryHelper: Inconsistent dtypes for input matrices, " + "a " + "dtype: ", + DataTypeString(dt), ", b dtype: ", DataTypeString(b.dtype())); + } + switch (dt) { + case DT_FLOAT: { + BinaryFunctor functor(ctx); + return functor(a, b, c); + } + case DT_DOUBLE: { + BinaryFunctor functor(ctx); + return functor(a, b, c); + } + case DT_COMPLEX64: { + BinaryFunctor functor(ctx); + return functor(a, b, c); + } + case DT_COMPLEX128: { + BinaryFunctor functor(ctx); + return functor(a, b, c); + } + default: + return errors::InvalidArgument( + "CSRSparseMatrixBinaryHelper: a.dtype (", DataTypeString(dt), + ") is not one of: float, double, complex64, complex128"); + } +} + +// Call UnaryFunctor()(ctx, a, b) +// where T depends on a.dtype(). T will be one of: float, double, +// complex64, complex128. 
+template <typename Device, template <typename, typename> class UnaryFunctor>
+absl::Status CSRSparseMatrixUnaryHelper(OpKernelContext* ctx,
+                                        const CSRSparseMatrix& a,
+                                        CSRSparseMatrix* b) {
+  DataType dt = a.dtype();
+  switch (dt) {
+    case DT_FLOAT: {
+      UnaryFunctor<Device, float> functor(ctx);
+      return functor(a, b);
+    }
+    case DT_DOUBLE: {
+      UnaryFunctor<Device, double> functor(ctx);
+      return functor(a, b);
+    }
+    case DT_COMPLEX64: {
+      UnaryFunctor<Device, complex64> functor(ctx);
+      return functor(a, b);
+    }
+    case DT_COMPLEX128: {
+      UnaryFunctor<Device, complex128> functor(ctx);
+      return functor(a, b);
+    }
+    default:
+      return errors::InvalidArgument(
+          "CSRSparseMatrixUnaryHelper: a.dtype (", DataTypeString(dt),
+          ") is not one of: float, double, complex64, complex128");
+  }
+}
+
+template <typename T>
+struct ConstCSRComponent {
+  TTypes<int32>::UnalignedConstVec row_ptr;
+  TTypes<int32>::UnalignedConstVec col_ind;
+  typename TTypes<T>::UnalignedConstVec values;
+  TTypes<int64_t>::ConstVec dense_shape_host;
+};
+
+template <typename T>
+struct CSRComponent {
+  TTypes<int32>::UnalignedVec row_ptr;
+  TTypes<int32>::UnalignedVec col_ind;
+  typename TTypes<T>::UnalignedVec values;
+  TTypes<int64_t>::Vec dense_shape_host;
+};
+
+template <typename T>
+absl::Status ExtractVariantFromInput(OpKernelContext* ctx, int index,
+                                     const T** value) {
+  const Tensor& input_t = ctx->input(index);
+  if (!TensorShapeUtils::IsScalar(input_t.shape())) {
+    return errors::InvalidArgument(
+        "Invalid input matrix: Shape must be rank 0 but is rank ",
+        input_t.dims());
+  }
+  const Variant& input_variant = input_t.scalar<Variant>()();
+  *value = input_variant.get<T>();
+  if (*value == nullptr) {
+    return errors::InvalidArgument("Could not retrieve Variant input ", index);
+  }
+  if (!(*value)->valid()) {
+    return errors::InvalidArgument("Variant input ", index, " is not valid.");
+  }
+  return absl::OkStatus();
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_SPARSE_MATRIX_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/transpose_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/transpose_op.h
new file mode 100644
index 00000000..2a8f0671
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/transpose_op.h
@@ -0,0 +1,73 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TRANSPOSE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_TRANSPOSE_OP_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/cwise_ops.h" + +namespace tensorflow { +namespace functor { + +template +struct maybe_conj_inplace { + static void run(const Device& d, Tensor* t) {} +}; + +template +struct maybe_conj_inplace { + static void run(const Device& d, Tensor* t) { + functor::UnaryFunctor> conj; + conj(d, t->flat() /*out*/, + const_cast(t)->flat() /*in*/); + } +}; + +template +struct maybe_conj_inplace { + static void run(const Device& d, Tensor* t) { + functor::UnaryFunctor> conj; + conj(d, t->flat() /*out*/, + const_cast(t)->flat() /*in*/); + } +}; + +template +struct maybe_conj { + static void run(const Device& d, const Tensor& in, Tensor* out) { *out = in; } +}; + +template +struct maybe_conj { + static void run(const Device& d, const Tensor& in, Tensor* out) { + functor::UnaryFunctor> conj; + conj(d, out->flat() /*out*/, in.flat() /*in*/); + } +}; + +template +struct maybe_conj { + static void run(const Device& d, const Tensor& in, Tensor* out) { + functor::UnaryFunctor> conj; + conj(d, out->flat() /*out*/, in.flat() /*in*/); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TRANSPOSE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/zeros_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/zeros_op.h new file mode 100644 index 00000000..2a86089e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse/zeros_op.h @@ -0,0 +1,85 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_ZEROS_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_ZEROS_OP_H_ + +#define EIGEN_USE_THREADS + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/dense_update_functor.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/sparse/sparse_matrix.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +struct CSRSparseMatrixZeros { + absl::Status operator()(OpKernelContext* c, DataType dtype, + const Tensor& dense_shape_t, + CSRSparseMatrix* matrix) { + auto dense_shape = dense_shape_t.vec(); + const int rank = dense_shape.size(); + if (!(rank == 2 || rank == 3)) { + return errors::InvalidArgument("sparse tensor must have rank == 2 or 3; ", + "but dense shape has ", rank, " entries"); + } + const int64_t batch_size = (rank == 2) ? 1 : dense_shape(0); + const int64_t rows = dense_shape((rank == 2) ? 0 : 1); + + Tensor batch_ptr_t(cpu_allocator(), DT_INT32, + TensorShape({batch_size + 1})); + batch_ptr_t.vec().setZero(); // On host. + + Allocator* allocator = c->device()->GetAllocator(AllocatorAttributes()); + // An all-zeros CSR matrix is composed of an empty set of column + // indices, an empty set of values, and a vector of all zero row + // pointers. The length of the row pointers vector is #rows + 1. + // Each row pointer is just an offset into the cols and + // values vectors, and those are empty, all coefficients are zero. + Tensor csr_row_ptr_t; + Tensor coo_col_ind_t(allocator, DT_INT32, TensorShape({0})); + Tensor csr_values_t(allocator, dtype, TensorShape({0})); + const Device& d = c->eigen_device(); + functor::SetZeroFunctor set_zero; + TF_RETURN_IF_ERROR(c->allocate_temp( + DT_INT32, TensorShape({batch_size * (rows + 1)}), &csr_row_ptr_t)); + set_zero(d, csr_row_ptr_t.flat()); + + TF_RETURN_IF_ERROR(CSRSparseMatrix::CreateCSRSparseMatrix( + dtype, dense_shape_t, batch_ptr_t, csr_row_ptr_t, coo_col_ind_t, + csr_values_t, matrix)); + + return absl::OkStatus(); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_ZEROS_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_concat_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_concat_op.h new file mode 100644 index 00000000..c13ae502 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_concat_op.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_CONCAT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_CONCAT_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct SparseConcatFunctor {
+  void operator()(OpKernelContext* context, const OpInputList& inds,
+                  const OpInputList& vals, const OpInputList& shapes,
+                  int concat_dim);
+};
+
+}  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_CONCAT_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_conditional_accumulator.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_conditional_accumulator.h
new file mode 100644
index 00000000..9d45d52b
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -0,0 +1,438 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+
+#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"
+
+namespace tensorflow {
+
+/**
+ * An aggregation object for adding sparse gradients, represented as a tuple of
+ * indices, values, and a (possibly empty) shape.
+ *
+ * The two main methods of this class are TryApplyGrad and TryTakeGrad.
+ *
+ * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
+ * successful if local_step >= global_step, i.e., if the gradient is not stale,
+ * having been computed using up-to-date information. Otherwise, the gradient
+ * is silently dropped.
+ *
+ * TryTakeGrad logs an attempt to read the average gradient. The attempt is
+ * blocked until the number of gradients accumulated (via TryApplyGrad) equals
+ * or exceeds the number requested by TryTakeGrad.
+ * Once this condition is satisfied, the following actions are taken:
+ * (1) the value of the average gradient is returned
+ * (2) the count of accumulated gradients is reset to 0
+ * (3) the internal global_step value (current_global_step_) is incremented by 1
+ *
+ * SparseConditionalAccumulator is the datatype-dependent templated sub-class of
+ * ConditionalAccumulatorBase. It implements the virtual arithmetic methods used
+ * for aggregating, averaging, allocating, and returning indexed slices.
+ */ +template +class SparseConditionalAccumulator + : public TypedConditionalAccumulatorBase< + std::tuple> { + public: + SparseConditionalAccumulator(const DataType& dtype, + const PartialTensorShape& shape, + const string& name, const string& reduction_type) + : TypedConditionalAccumulatorBase< + std::tuple>( + dtype, shape, name, reduction_type), + accum_val_(std::make_unique()) {} + + protected: + std::unique_ptr> accum_idx_vec_; + std::unique_ptr> count_element_; + + std::unique_ptr accum_val_; + + typedef Eigen::TensorMap, + Eigen::Unaligned> + SliceT; + typedef Eigen::TensorMap, + Eigen::Unaligned> + SliceConstT; + + absl::Status ValidateShape( + std::tuple* tensor, + bool has_known_shape) TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { + const Tensor* tensor_idx = std::get<0>(*tensor); + const Tensor* tensor_val = std::get<1>(*tensor); + const Tensor* tensor_shape = std::get<2>(*tensor); + int64_t grad_val_dims = tensor_val->dims(); + int64_t grad_dims = grad_val_dims; + + // Compare with provided shape + if (has_known_shape) { + if (shape_.dims() > tensor_shape->NumElements()) { + return errors::InvalidArgument( + "Shape mismatch: expected shape rank at least ", shape_.dims(), + ", got ", tensor_shape->NumElements()); + } + const auto tensor_shape_flat = tensor_shape->flat(); + for (int64_t i = 0; i < shape_.dims(); i++) { + if (shape_.dim_size(i) != -1 && + shape_.dim_size(i) != tensor_shape_flat(i)) { + return errors::InvalidArgument("Shape mismatch: expected shape dim ", + i, " to be ", shape_.dim_size(i), + ", got ", tensor_shape_flat(i)); + } + } + } + // Check that indices are within limits + if (shape_.dims() > 0 && shape_.dim_size(0) != -1 && + tensor_idx->dims() > 0) { + for (int64_t i = 0; i < tensor_idx->dim_size(0); i++) { + if (tensor_idx->vec()(i) >= shape_.dim_size(0)) { + return errors::InvalidArgument( + "Shape mismatch: index of slice ", i, " exceeded limits of shape", + "; index is ", tensor_idx->vec()(i), " exceeded ", + shape_.dim_size(0)); + } + } + } + + // Check values compatibility with accumulated gradient if available + if (counter_ > 0) { + int64_t accum_val_dims = accum_val_->dims(); + if (accum_val_dims != grad_val_dims) { + return errors::InvalidArgument("Shape mismatch: expected values rank ", + accum_val_dims, ", got ", grad_val_dims); + } + for (int64_t i = 1; i < accum_val_dims; i++) { + if (accum_val_->dim_size(i) != tensor_val->dim_size(i)) { + return errors::InvalidArgument("Shape mismatch: expected values dim ", + i, " to be ", accum_val_->dim_size(i), + ", got ", tensor_val->dim_size(i)); + } + } + } else { + // If there are no accumulated gradients, check against shape_ + if (shape_.dims() > grad_dims) { + return errors::InvalidArgument( + "Shape mismatch: expected values rank at least ", shape_.dims(), + ", got ", grad_dims); + } + // Check that values have correct dimensions + for (int64_t i = 1; i < shape_.dims(); i++) { + if (shape_.dim_size(i) != -1 && + shape_.dim_size(i) != tensor_val->dim_size(i)) { + return errors::InvalidArgument("Shape mismatch: expected values dim ", + i, " to be ", shape_.dim_size(i), + ", got ", tensor_val->dim_size(i)); + } + } + } + + return absl::OkStatus(); + } + + void AllocateAndAssignToAccumGradFunction( + OpKernelContext* ctx, + std::tuple* grad) override { + const Tensor* grad_idx = std::get<0>(*grad); + const Tensor* grad_val = std::get<1>(*grad); + + const int64_t nnz = grad_idx->dim_size(0); + + // Assign indices + accum_idx_vec_ = std::make_unique>(); + accum_idx_vec_->reserve(nnz); + for (int i = 0; i 
< nnz; i++) { + accum_idx_vec_->push_back(grad_idx->vec()(i)); + } + + // Assign values to accum_val_tensor + OP_REQUIRES_OK( + ctx, ctx->allocate_temp(dtype_, grad_val->shape(), accum_val_.get())); + accum_val_->flat().device(ctx->template eigen_device()) = + grad_val->flat(); + + // Assign count_element_ + count_element_ = std::make_unique>(nnz, 1); + + // Do not need shape; Assume that the op has checked that the shapes match, + // so grad's shape == shape_ + } + + void AddToAccumGradFunction( + OpKernelContext* ctx, + std::tuple* grad) override { + // Modeled after third_party/tensorflow/core/kernels/sparse_add_op + + const Tensor* grad_idx = std::get<0>(*grad); + const Tensor* grad_val = std::get<1>(*grad); + + const int64_t accum_nnz = accum_idx_vec_->size(); + const int64_t grad_nnz = grad_idx->dim_size(0); + + // Source enumerates the origin of a non-zero element: whether it is from + // the new gradient, the accumulated gradient, or the sum of both. + enum Source { from_accum, from_grad, from_accum_and_grad }; + + // (1) do a pass over inputs, and append values and indices to vectors + std::vector> entries_to_copy; + entries_to_copy.reserve(accum_nnz + grad_nnz); + + // Pass over all non-zero elements of both the gradient and the accumulated + // value, to identify where each non-zero element of the sum comes from. + // The input and output indexed slices are assumed to be ordered along + // increasing dimension number. + int64_t i = 0, j = 0; + int64_t sum_nnz = 0; + while (i < accum_nnz && j < grad_nnz) { + sum_nnz++; + switch (cmp(accum_idx_vec_.get(), grad_idx, i, j)) { + case -1: + entries_to_copy.emplace_back(from_accum, i, -1); + ++i; + break; + case 0: + entries_to_copy.emplace_back(from_accum_and_grad, i, j); + ++i; + ++j; + break; + case 1: + entries_to_copy.emplace_back(from_grad, -1, j); + ++j; + break; + } + } + + // Handle leftovers + while (i < accum_nnz) { + sum_nnz++; + entries_to_copy.emplace_back(from_accum, i, -1); + ++i; + } + while (j < grad_nnz) { + sum_nnz++; + entries_to_copy.emplace_back(from_grad, -1, j); + ++j; + } + + // (2) Copy or sum the non-zero elements into sum_indices and sum_tensor + std::vector* sum_indices_vec = new std::vector(); + sum_indices_vec->reserve(sum_nnz); + + std::vector* sum_counts = new std::vector(); + sum_counts->reserve(sum_nnz); + + Tensor* sum_tensor = new Tensor(); + + TensorShape sum_shape = grad_val->shape(); + sum_shape.set_dim(0, sum_nnz); + + OP_REQUIRES_OK(ctx, ctx->allocate_temp(dtype_, sum_shape, sum_tensor)); + auto sum_flat = sum_tensor->flat_outer_dims(); + auto accum_flat = accum_val_->flat_outer_dims(); + auto grad_flat = grad_val->flat_outer_dims(); + + const int64_t num_col = grad_flat.dimension(1); + + Eigen::DSizes slice_shape(num_col); + + for (i = 0; i < sum_nnz; ++i) { + const Source src = std::get<0>(entries_to_copy[i]); + const int64_t idx_a = std::get<1>(entries_to_copy[i]); + const int64_t idx_b = std::get<2>(entries_to_copy[i]); + T* sum_slice_ptr = &sum_flat(i, 0); + SliceT sum_slice(sum_slice_ptr, slice_shape); + if (src == from_accum) { + // Element comes from accumulator; directly copy data structures over + sum_indices_vec->push_back(accum_idx_vec_->at(idx_a)); + T* accum_slice_ptr = &accum_flat(idx_a, 0); + SliceT accum_slice(accum_slice_ptr, slice_shape); + sum_slice = accum_slice; + sum_counts->push_back(count_element_->at(idx_a)); + } else if (src == from_accum_and_grad) { + // Element is a sum of accumulated value and new gradient; + // compute sum here + 
sum_indices_vec->push_back(accum_idx_vec_->at(idx_a)); + const T* grad_slice_ptr = &grad_flat(idx_b, 0); + SliceConstT grad_slice(grad_slice_ptr, slice_shape); + T* accum_slice_ptr = &accum_flat(idx_a, 0); + SliceT accum_slice(accum_slice_ptr, slice_shape); + sum_slice = grad_slice + accum_slice; + sum_counts->push_back(count_element_->at(idx_a) + 1); + } else if (src == from_grad) { + // Element comes from new gradient; make a copy of indices and values + sum_indices_vec->push_back(grad_idx->vec()(idx_b)); + const T* grad_slice_ptr = &grad_flat(idx_b, 0); + SliceConstT grad_slice(grad_slice_ptr, slice_shape); + sum_slice = grad_slice; + sum_counts->push_back(1); + } + } + + // (3) Keep output, i.e., switch pointers to point to new data structures + // representing the sum + // Indices + accum_idx_vec_.reset(sum_indices_vec); + // Values + accum_val_.reset(sum_tensor); + // Counts + count_element_.reset(sum_counts); + + // No need to copy shape, since shape remains the same after sum. + } + + void DivideAccumGradByCounter(OpKernelContext* ctx) override + TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { + const int64_t nnz = count_element_->size(); + auto accum_flat = accum_val_->flat_outer_dims(); + std::vector count_typet; + std::transform(count_element_->begin(), count_element_->end(), + std::back_inserter(count_typet), + TypeConverter::ConvertUToT); + + // Option 1: divide all by counter + /* + std::transform( + &accum_flat(0,0), &accum_flat(nnz,0), &accum_flat(0,0), + std::bind2nd(std::divides(), + TypeConverter::ConvertUToT(this->counter_))); + */ + + // Option 2: average element-wise + Eigen::DSizes slice_shape(accum_flat.dimension(1)); + for (int64_t i = 0; i < nnz; i++) { + T* accum_slice_ptr = &accum_flat(i, 0); + SliceT accum_slice(accum_slice_ptr, slice_shape); + accum_slice.device(ctx->template eigen_device()) = + accum_slice / count_typet[i]; + } + } + + bool SetOutput(OpKernelContext* ctx) override { + bool is_successful = true; + if (is_successful) is_successful = ReturnIdxTensor(ctx); + if (is_successful) is_successful = ReturnValTensor(ctx); + if (is_successful) is_successful = ReturnShapeTensor(ctx); + return is_successful; + } + + bool GetAndValidateTensorInputForApplyGrad( + OpKernelContext* ctx, + std::tuple** tensor) override + TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { + // TODO(xinghao, jmchen): The roundabout way of getting attr from + // OpKernelContext (instead of OpKernelConstruction) is a hack, and should + // be fixed if it affects efficiency. 
+ bool has_known_shape = false; + OP_REQUIRES_OK_BOOLEAN( + ctx, GetNodeAttr(ctx->op_kernel().def(), "has_known_shape", + &has_known_shape)); + + // Get input gradient tensors + const Tensor* grad_idx_tensor; + OP_REQUIRES_OK_BOOLEAN(ctx, + ctx->input("gradient_indices", &grad_idx_tensor)); + const Tensor* grad_val_tensor; + OP_REQUIRES_OK_BOOLEAN(ctx, + ctx->input("gradient_values", &grad_val_tensor)); + const Tensor* grad_shape_tensor = nullptr; + if (has_known_shape) { + OP_REQUIRES_OK_BOOLEAN(ctx, + ctx->input("gradient_shape", &grad_shape_tensor)); + } + + // Checks + OP_REQUIRES_BOOLEAN( + ctx, TensorShapeUtils::IsVector(grad_idx_tensor->shape()), + errors::InvalidArgument( + "Input indices should be vector but received shape: ", + grad_idx_tensor->shape().DebugString())); + const int64_t nnz = grad_idx_tensor->dim_size(0); + OP_REQUIRES_BOOLEAN( + ctx, grad_val_tensor->dims() > 0, + errors::InvalidArgument("Values cannot be 0-dimensional.")); + OP_REQUIRES_BOOLEAN(ctx, grad_val_tensor->dim_size(0) == nnz, + errors::InvalidArgument("Expected ", nnz, + " non-empty input values, got ", + grad_val_tensor->dim_size(0))); + + *tensor = new std::tuple( + grad_idx_tensor, grad_val_tensor, grad_shape_tensor); + + OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor, has_known_shape)); + + return true; + } + + void CleanUpGradTensor(std::tuple* tensor) override { + if (tensor != nullptr) delete tensor; + } + + private: + inline int cmp(std::vector* a_idx, const Tensor* b_idx, + const int64_t a_row, const int64_t b_row) { + const int64_t a = a_idx->at(a_row); + const int64_t b = b_idx->vec()(b_row); + if (a < b) { + return -1; + } else if (a > b) { + return 1; + } + return 0; + } + + inline bool ReturnIdxTensor(OpKernelContext* ctx) { + Tensor* idx_tensor; + const int64_t nnz = accum_idx_vec_->size(); + OP_REQUIRES_OK_BOOLEAN(ctx, ctx->allocate_output(0, {nnz}, &idx_tensor)); + // If allocate_output fails, OP_REQUIRES_OK_BOOLEAN will short-circuit + // the remaining code and just return false + auto idx_tensor_vec = idx_tensor->vec(); + for (int i = 0; i < nnz; ++i) { + idx_tensor_vec(i) = accum_idx_vec_->at(i); + } + return true; + } + + inline bool ReturnValTensor(OpKernelContext* ctx) { + ctx->set_output(1, *accum_val_); + return true; + } + + inline bool ReturnShapeTensor(OpKernelContext* ctx) { + int64_t accum_val_dims = accum_val_->dims(); + Tensor* shape_tensor; + OP_REQUIRES_OK_BOOLEAN( + ctx, ctx->allocate_output(2, {accum_val_dims}, &shape_tensor)); + // If allocate_output fails, OP_REQUIRES_OK_BOOLEAN will short-circuit + // the remaining code and just return false + + // First dim of shape is defined by shape_, others by accum_val_->shape + shape_tensor->flat()(0) = + (shape_.dims() > 0) ? shape_.dim_size(0) : -1; + for (int64_t i = 1; i < accum_val_dims; i++) { + shape_tensor->flat()(i) = accum_val_->dim_size(i); + } + return true; + } + + SparseConditionalAccumulator(const SparseConditionalAccumulator&) = delete; + void operator=(const SparseConditionalAccumulator&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_matmul_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_matmul_op.h new file mode 100644 index 00000000..589a65af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_matmul_op.h @@ -0,0 +1,501 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_ + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/platform/byte_order.h" +#include "tensorflow/core/platform/types.h" + +#if defined(PLATFORM_WINDOWS) +#include "xla/tsl/platform/windows/intrinsics_port.h" +#endif + +namespace Eigen { +namespace internal { + +// Return the float representation of the bfloat16 value +// in the lower 16-bits of input +template +EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_l(const Packet& from) { + tensorflow::uint32 tmp; +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + tmp = (reinterpret_cast(from)) & 0xffff0000; +#else + tmp = (reinterpret_cast(from) << 16) & 0xffff0000; +#endif + return reinterpret_cast(tmp); +} + +// Return the float representation of the bfloat16 value +// in the upper 16-bits of input +template +EIGEN_DEVICE_FUNC inline Packet pexpand_bf16_u(const Packet& from) { + tensorflow::uint32 tmp; +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + tmp = (reinterpret_cast(from) << 16) & 0xffff0000; +#else + tmp = (reinterpret_cast(from)) & 0xffff0000; +#endif + return reinterpret_cast(tmp); +} + +// Specialization non-scalar version on non-sse. 
+// Enable vectorization on z13 and higher +#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \ + defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR) +template +EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) { + float r[4]; + tensorflow::uint32 p[4]; + pstoreu(r, from); + tensorflow::uint32* ir = reinterpret_cast(r); + p[0] = (ir[0] << 16) & 0xffff0000; + p[1] = ir[0] & 0xffff0000; + p[2] = (ir[1] << 16) & 0xffff0000; + p[3] = ir[1] & 0xffff0000; + return ploadu(reinterpret_cast(p)); +} + +template +EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) { + float r[4]; + tensorflow::uint32 p[4]; + pstoreu(r, from); + tensorflow::uint32* ir = reinterpret_cast(r); + p[0] = (ir[2] << 16) & 0xffff0000; + p[1] = ir[2] & 0xffff0000; + p[2] = (ir[3] << 16) & 0xffff0000; + p[3] = ir[3] & 0xffff0000; + return ploadu(reinterpret_cast(p)); +} +#endif + +template +EIGEN_DEVICE_FUNC inline Packet pinterleave4x64(const Packet& from) { + return from; +} + +template +EIGEN_DEVICE_FUNC inline Packet pbroadcast_first(const Packet& a) { + return a; +} + +template +EIGEN_DEVICE_FUNC inline Packet pbroadcast_second(const Packet& a) { + assert(false && "Not applicable to Scalar Values"); + return a; +} + +template +EIGEN_DEVICE_FUNC inline Packet pbroadcast_third(const Packet& a) { + assert(false && "Not applicable to Scalar Values"); + return a; +} + +template +EIGEN_DEVICE_FUNC inline Packet pbroadcast_fourth(const Packet& a) { + assert(false && "Not applicable to Scalar Values"); + return a; +} + +template +EIGEN_DEVICE_FUNC inline Packet pload4bf16( + const typename unpacket_traits::type* from) { + assert(false && "Not applicable to Scalar Values"); + return Packet(); +} + +template +EIGEN_DEVICE_FUNC inline Packet pload2bf16( + const typename unpacket_traits::type* from) { + assert(false && "Not applicable to Scalar Values"); + return Packet(); +} + +// Specialization for pload4bf16 and pload2bf16 for non-sse. +// Enable vectorization on z13 and higher. 
+#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) || \ + defined(EIGEN_VECTORIZE_NEON) || defined(EIGEN_VECTORIZE_ZVECTOR) +template <> +EIGEN_STRONG_INLINE Packet4f pload4bf16(const float* from) { + tensorflow::uint32 p[4]; + const tensorflow::uint32* ir = + reinterpret_cast(from); + p[0] = (ir[0] << 16) & 0xffff0000; + p[1] = ir[0] & 0xffff0000; + p[2] = (ir[1] << 16) & 0xffff0000; + p[3] = ir[1] & 0xffff0000; + return ploadu(reinterpret_cast(p)); +} + +template <> +EIGEN_STRONG_INLINE Packet4f pload2bf16(const float* from) { + tensorflow::uint32 p[4]; + const tensorflow::uint32* ir = + reinterpret_cast(from); + p[0] = (ir[0] << 16) & 0xffff0000; + p[1] = ir[0] & 0xffff0000; + p[2] = (ir[0] << 16) & 0xffff0000; + p[3] = ir[0] & 0xffff0000; + return ploadu(reinterpret_cast(p)); +} +#endif + +#if defined(EIGEN_VECTORIZE_NEON) +// Return a packet with the first value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_first(const Packet4f& a) { + return pset1(pfirst(a)); +} +template <> +EIGEN_STRONG_INLINE Packet2f pbroadcast_first(const Packet2f& a) { + return pset1(pfirst(a)); +} + +// Return a packet with the second value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_second(const Packet4f& a) { + return pset1(vgetq_lane_f32(a, 1)); +} +template <> +EIGEN_STRONG_INLINE Packet2f pbroadcast_second(const Packet2f& a) { + return pset1(vget_lane_f32(a, 1)); +} + +// Return a packet with the third value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_third(const Packet4f& a) { + return pset1(vgetq_lane_f32(a, 2)); +} + +// Return a packet with the fourth value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth(const Packet4f& a) { + return pset1(vgetq_lane_f32(a, 3)); +} +#endif + +#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX) +// Return a packet with the first value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_first(const Packet4f& a) { + return vec_splat(a, 0); +} + +// Return a packet with the second value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_second(const Packet4f& a) { + return vec_splat(a, 1); +} + +// Return a packet with the third value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_third(const Packet4f& a) { + return vec_splat(a, 2); +} + +// Return a packet with the fourth value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth(const Packet4f& a) { + return vec_splat(a, 3); +} +#endif + +#ifdef EIGEN_VECTORIZE_SSE2 +// For PacketSize of 4 floats the Packet is not modified +template <> +EIGEN_STRONG_INLINE Packet4f pinterleave4x64(const Packet4f& from) { + return from; +} + +// Return a Packet with 4 floats loaded from 4 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet4f pload4bf16(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from)); + return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)); +} + +// Return a Packet with 2 floats loaded from 2 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet4f pload2bf16(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castps_si128(_mm_load_ps1(from)); + return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)); +} + +// Return a Packet with 4 floats expanded from 4 bfloat16 values 
+// in the lower half of the 128-bit lane +template +EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_l(const Packet4f& from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castps_si128(from); + return _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp)); +} + +// Return a Packet with 4 floats expanded from 4 bfloat16 values +// in the upper half of the 128-bit lane +template +EIGEN_DEVICE_FUNC inline Packet4f pexpand_bf16_u(const Packet4f& from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castps_si128(from); + return _mm_castsi128_ps(_mm_unpackhi_epi16(zero, tmp)); +} + +// Return a packet with the first value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_first(const Packet4f& a) { + return _mm_set1_ps(pfirst(a)); +} + +// Return a packet with the second value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_second(const Packet4f& a) { + return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 1))); +} + +// Return a packet with the third value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_third(const Packet4f& a) { + return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 2))); +} + +// Return a packet with the fourth value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet4f pbroadcast_fourth(const Packet4f& a) { + return _mm_set1_ps(_mm_cvtss_f32(_mm_shuffle_ps(a, a, 3))); +} + +#endif + +#ifdef EIGEN_VECTORIZE_AVX512 +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_first(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(a); +} +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_second(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(1, 1, 1, 1))); +} +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_third(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 2, 2, 2))); +} +template <> +EIGEN_STRONG_INLINE Packet16f +pbroadcast_fourth(const Packet16f& a_in) { + Packet4f a = _mm512_castps512_ps128(a_in); + return _mm512_broadcastss_ps(_mm_shuffle_ps(a, a, _MM_SHUFFLE(3, 3, 3, 3))); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_first(const Packet8d& a_in) { + Packet2d a = _mm512_castpd512_pd128(a_in); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_second(const Packet8d& a_in) { + Packet2d a = _mm_permute_pd(_mm512_castpd512_pd128(a_in), 3); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_third(const Packet8d& a_in) { + Packet2d a = _mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet8d pbroadcast_fourth(const Packet8d& a_in) { + Packet2d a = + _mm_permute_pd(_mm256_extractf128_pd(_mm512_castpd512_pd256(a_in), 1), 3); + return _mm512_broadcastsd_pd(a); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_first(const Packet16i& a_in) { + Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(a); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_second(const Packet16i& a_in) { + Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(1, 1, 1, 1))); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_third(const Packet16i& a_in) { + 
Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(2, 2, 2, 2))); +} +template <> +EIGEN_STRONG_INLINE Packet16i +pbroadcast_fourth(const Packet16i& a_in) { + Packet4i a = _mm512_castsi512_si128(a_in); + return _mm512_broadcastd_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 3, 3))); +} +#endif + +#ifdef EIGEN_VECTORIZE_AVX +// For a Packet of Size 8 floats(256-bits), swap the 2nd and 3rd quadwords +template <> +EIGEN_STRONG_INLINE Packet8f pinterleave4x64(const Packet8f& from) { +#ifdef EIGEN_VECTORIZE_AVX2 + return _mm256_castsi256_ps(_mm256_permute4x64_epi64(_mm256_castps_si256(from), + _MM_SHUFFLE(3, 1, 2, 0))); +#else + auto tmp1 = _mm256_extract_epi32(_mm256_castps_si256(from), 2); + auto tmp2 = _mm256_extract_epi32(_mm256_castps_si256(from), 3); + auto tmp3 = _mm256_extract_epi32(_mm256_castps_si256(from), 4); + auto tmp4 = _mm256_extract_epi32(_mm256_castps_si256(from), 5); + auto tmp5 = _mm256_insert_epi32(_mm256_castps_si256(from), tmp1, 4); + tmp5 = _mm256_insert_epi32(tmp5, tmp2, 5); + tmp5 = _mm256_insert_epi32(tmp5, tmp3, 2); + tmp5 = _mm256_insert_epi32(tmp5, tmp4, 3); + return _mm256_castsi256_ps(tmp5); +#endif +} +// Return a Packet with 4 floats loaded from 4 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet8f pload4bf16(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from)); + return _mm256_castps128_ps256( + _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); +} +// Return a Packet with 2 floats loaded from 2 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet8f pload2bf16(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castps_si128(_mm_load_ps1(from)); + return _mm256_castps128_ps256( + _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); +} + +#ifdef EIGEN_VECTORIZE_AVX512 +// Return a Packet with 4 floats loaded from 4 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet16f pload4bf16(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castpd_si128(_mm_load_pd1((const double*)from)); + return _mm512_castps128_ps512( + _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); +} +// Return a Packet with 2 floats loaded from 2 bfloat16 values +template <> +EIGEN_STRONG_INLINE Packet16f pload2bf16(const float* from) { + __m128i zero = _mm_setzero_si128(); + __m128i tmp = _mm_castps_si128(_mm_load_ps1(from)); + return _mm512_castps128_ps512( + _mm_castsi128_ps(_mm_unpacklo_epi16(zero, tmp))); +} +#endif + +// For each 128-bit lane convert 4 bfloat to 4 float values from the lower half +// of the 128-bit lane +template +EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_l(const Packet8f& from) { +#ifdef EIGEN_VECTORIZE_AVX2 + __m256i zero = _mm256_setzero_si256(); + __m256i tmp = _mm256_castps_si256(from); + return _mm256_castsi256_ps(_mm256_unpacklo_epi16(zero, tmp)); +#else + __m128i zero = _mm_setzero_si128(); + __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0)); + __m128i res_l = _mm_unpacklo_epi16(zero, low); + __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1)); + __m128i res_h = _mm_unpacklo_epi16(zero, high); + __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l)); + res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1); + return res; +#endif +} + +// For each 128-bit lane convert 4 bfloat to 4 float values from the upper half +// of the 128-bit lane +template +EIGEN_DEVICE_FUNC inline Packet8f pexpand_bf16_u(const Packet8f& from) { +#ifdef 
EIGEN_VECTORIZE_AVX2 + __m256i zero = _mm256_setzero_si256(); + __m256i tmp = _mm256_castps_si256(from); + return _mm256_castsi256_ps(_mm256_unpackhi_epi16(zero, tmp)); +#else + __m128i zero = _mm_setzero_si128(); + __m128i low = _mm_castps_si128(_mm256_extractf128_ps(from, 0)); + __m128i res_l = _mm_unpackhi_epi16(zero, low); + __m128i high = _mm_castps_si128(_mm256_extractf128_ps(from, 1)); + __m128i res_h = _mm_unpackhi_epi16(zero, high); + __m256 res = _mm256_castps128_ps256(_mm_castsi128_ps(res_l)); + res = _mm256_insertf128_ps(res, _mm_castsi128_ps(res_h), 1); + return res; +#endif +} + +// Return a packet with the first value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet8f pbroadcast_first(const Packet8f& a) { + return _mm256_set1_ps(pfirst(a)); +} + +// Return a packet with the second value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet8f pbroadcast_second(const Packet8f& a) { + return _mm256_set1_ps( + _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 1)))); +} + +// Return a packet with the third value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet8f pbroadcast_third(const Packet8f& a) { + return _mm256_set1_ps( + _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 2)))); +} + +// Return a packet with the fourth value of the input Packet replicated +template <> +EIGEN_STRONG_INLINE Packet8f pbroadcast_fourth(const Packet8f& a) { + return _mm256_set1_ps( + _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_permute_ps(a, 3)))); +} + +#endif + +#ifdef EIGEN_VECTORIZE_AVX512 + +template +EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_l(const Packet16f& from) { + return _mm512_castsi512_ps(_mm512_slli_epi32( + _mm512_cvtepu16_epi32(_mm512_castsi512_si256(_mm512_castps_si512(from))), + 16)); +} + +template +EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) { + Packet16i tmp = _mm512_castps_si512(from); + Packet16i tmp2 = _mm512_alignr_epi32(tmp, tmp, 8); + return _mm512_castsi512_ps(_mm512_slli_epi32( + _mm512_cvtepu16_epi32(_mm512_castsi512_si256(tmp2)), 16)); +} + +#endif +} // namespace internal +} // namespace Eigen +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_reorder_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_reorder_op.h new file mode 100644 index 00000000..0af44c55 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_reorder_op.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_REORDER_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_REORDER_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct SparseReorderFunctor {
+  void operator()(OpKernelContext* context, const Tensor& input_ind,
+                  const Tensor& input_val, const Tensor& input_shape_in);
+};
+
+}  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_REORDER_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_slice_grad_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_slice_grad_op.h
new file mode 100644
index 00000000..6358ed02
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_slice_grad_op.h
@@ -0,0 +1,40 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_SLICE_GRAD_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_SLICE_GRAD_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct SparseSliceGradFunctor {
+  void operator()(OpKernelContext* ctx,
+                  typename TTypes<T>::ConstFlat backprop_val_grad,
+                  typename TTypes<int64_t>::ConstMatrix input_indices_mat,
+                  typename TTypes<int64_t>::ConstFlat input_start_flat,
+                  typename TTypes<int64_t>::ConstMatrix output_indices_mat,
+                  typename TTypes<T>::Flat val_grad) const;
+};
+
+}  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_SPARSE_SLICE_GRAD_OP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_slice_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_slice_op.h
new file mode 100644
index 00000000..62e0b0cc
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_slice_op.h
@@ -0,0 +1,39 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_SLICE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_SLICE_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +namespace functor { + +template +struct SparseSliceFunctor { + void operator()(OpKernelContext* context, const Tensor& input_indices, + const Tensor& input_values, const Tensor& input_shape, + const Tensor& input_start, const Tensor& input_size, + typename AsyncOpKernel::DoneCallback done = nullptr) const; +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_SLICE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_split_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_split_op.h new file mode 100644 index 00000000..7fba47a4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_split_op.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_SPLIT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_SPLIT_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace functor { + +template +struct SparseSplitFunctor { + void operator()(OpKernelContext* context, const Tensor& input_indices, + const Tensor& input_values, const TensorShape& dense_shape, + const int64_t axis, const int num_split, + typename AsyncOpKernel::DoneCallback done = nullptr); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_SPLIT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_tensor_dense_add_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_tensor_dense_add_op.h new file mode 100644 index 00000000..44a85785 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_tensor_dense_add_op.h @@ -0,0 +1,42 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/scatter_functor.h" + +namespace tensorflow { +namespace functor { + +// TODO(zongheng): this should be a general functor that powers SparseAdd and +// ScatterNd ops. It should be moved to its own head file, once the other ops +// are implemented. +template +struct ScatterNdFunctor { + // Returns -1 on success or a nonnegative i s.t. indices[i] is a bad index. + Index operator()(const Device& d, typename TTypes::ConstMatrix indices, + typename TTypes::ConstFlat updates, + typename TTypes::Tensor out); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h new file mode 100644 index 00000000..fef151ea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h @@ -0,0 +1,85 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace functor { + +template +struct SparseTensorDenseMatMulFunctor { + static EIGEN_ALWAYS_INLINE absl::Status Compute( + OpKernelContext* ctx, typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, typename TTypes::ConstMatrix b); +}; + +template +class MaybeAdjoint; + +template +class MaybeAdjoint { + public: + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MaybeAdjoint(MATRIX m) : m_(m) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename MATRIX::Scalar operator()( + const typename MATRIX::Index i, const typename MATRIX::Index j) const { + return m_(i, j); + } + + private: + const MATRIX m_; +}; + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T MaybeConj(T v) { + return Eigen::numext::conj(v); +} + +template +class MaybeAdjoint { + public: + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE MaybeAdjoint(MATRIX m) : m_(m) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename MATRIX::Scalar operator()( + const typename MATRIX::Index i, const typename MATRIX::Index j) const { + return Eigen::numext::conj(m_(j, i)); + } + + private: + const MATRIX m_; +}; + +template +struct SumType { + using type = T; +}; + +template <> +struct SumType { + using type = float; // Use fp32 accumulator for fp16 input values +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_to_dense_op_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_to_dense_op_gpu.h new file mode 100644 index 00000000..c19ffa72 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_to_dense_op_gpu.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error This file must only be included when building with Cuda +#endif + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TO_DENSE_OP_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_TO_DENSE_OP_GPU_H_ + +#include "xla/stream_executor/device_memory.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace functor { +template +struct LaunchSparseToDense { + void operator()(OpKernelContext* c, AsyncOpKernel::DoneCallback done, + AsyncOpKernel* op, bool validate_indices, + const Tensor& indices, const Tensor& values, + const Tensor& shape, const T default_value, Tensor* dense); +}; + +} // namespace functor + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TO_DENSE_OP_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_utils.h new file mode 100644 index 00000000..8f86b518 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_utils.h @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for writing OpKernels for sparse tensors. +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace sparse_utils { + +// Find the index i of the first element for which +// indices_mat(sparse_index_begin, 0) < indices_mat(i, 0). +// The search is conducted in the open interval +// [sparse_index_begin, indices_mat.dimension(0)) and when no such i is found, +// indices_mat.dimension(0) is returned. +// indices_mat(k, 0) should be non-decreasing over the interval +// [begin, indices_mat.dimension(0)). +// Requires 0 <= sparse_index_begin < indices_mat.dimension(0). +template +Tindices FindNextDenseRowStartIndex( + const Tindices sparse_index_begin, + const typename TTypes::ConstMatrix& indices_mat); + +// Returns the vector v of indices in indices_mat at which new dense matrix +// rows begin. +// v.front() = 0, v.back() = indices_mat.dimension(0), and for i > 0, +// v[i] - v[i-1] is the length of the ith dense row in indices_mat. +// *contains_empty_rows = true if and only if indices_mat contains empty rows +// (rows without values) between row 0 and the last row. +template +std::vector GetStartIndicesOfEachDenseRow( + const typename TTypes::ConstMatrix& indices_mat, + bool* contains_empty_rows); + +// Converts tensor.vec to an std::vector object, appends +// the value num_nonzero_entries_in_sparse_mat, and returns the result. 
+template +std::vector ParseRowStartIndices( + const tensorflow::Tensor& tensor, + const Tindices num_nonzero_entries_in_sparse_mat); + +// Returns true if and only if the sparse matrix indices_mat whose row start +// indices are represented by row_start_indices has empty dense rows +// (between its first and last dense rows). +// This function satisfies the identity row_start_indices == +// GetStartIndicesOfEachDenseRow(indices_mat, &return_value). +template +bool ContainsEmptyRows(const std::vector& row_start_indices); + +// Methods for validating sparse indices. +enum class IndexValidation { + kNone, // Indices are not used by the op, or are not directly accessible + // (e.g. on GPU). + kOrdered, // Indices must be unique, in lexicographical order, and within + // safe bounds. + kUnordered // Indices must be within safe bounds, but may repeat or appear + // out-of-order. +}; + +// Validates the three component tensors of a sparse tensor have the proper +// shapes. Also validates index values according to the method supplied. +template +absl::Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, + const Tensor& shape, + IndexValidation index_validation); + +} // namespace sparse_utils +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_xent_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_xent_op.h new file mode 100644 index 00000000..d0ad3c4b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/sparse_xent_op.h @@ -0,0 +1,232 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_ +// Functor definition for SparseXentOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace sparse_xent_helpers { + +template +typename TTypes::Tensor32Bit To32BitConst( + typename TTypes::Vec in) { + return To32Bit(typename TTypes::ConstVec(in.data(), in.dimensions())); +} + +template +typename TTypes::Tensor32Bit To32BitConst( + typename TTypes::Matrix in) { + return To32Bit(typename TTypes::ConstMatrix(in.data(), in.dimensions())); +} + +} // namespace sparse_xent_helpers + +namespace generator { + +// Generator for calculation of the sparse Xent loss. +// This generator takes the logits, the sum of the exponentiated +// logits, and the label indices. For each minibatch entry, ignoring +// the batch index b, it calculates: +// +// loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label } +// +// for j = 0 .. num_classes. 
This value must be summed over all j for +// the final loss. +template +class SparseXentLossGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator( + typename TTypes::Tensor32Bit logits, + typename TTypes::Tensor32Bit sum_exp_logits, + typename TTypes::Tensor32Bit labels, + const Index max_depth) + : logits_(logits), + sum_exp_logits_(sum_exp_logits), + labels_(labels), + max_depth_(max_depth) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + operator()(const Eigen::array& coords) const { + const int batch = coords[0]; + const int depth = coords[1]; + const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch)); + if (!FastBoundsCheck(label, max_depth_)) { + return Eigen::NumTraits::quiet_NaN(); + } + return TF_PREDICT_FALSE(label == depth) + ? (Eigen::numext::log(sum_exp_logits_(batch)) - logits_(coords)) + : T(0.0); + }; + + private: + typename TTypes::Tensor32Bit logits_; + typename TTypes::Tensor32Bit sum_exp_logits_; + typename TTypes::Tensor32Bit labels_; + const Index max_depth_; +}; + +// Generator for calculation of the sparse Xent gradient. +// This generator takes the exponentiated logits, their sums, and the label +// indices. For each minibatch entry, ignoring the batch index b, it calculates: +// +// exp_logits[j] / sum_exp_logits - 1{ j == label } +// +// for j = 0 .. num_classes. +template +class SparseXentGradGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator( + typename TTypes::Tensor32Bit exp_logits, + typename TTypes::Tensor32Bit sum_exp_logits, + typename TTypes::Tensor32Bit labels, + const Index max_depth) + : exp_logits_(exp_logits), + sum_exp_logits_(sum_exp_logits), + labels_(labels), + max_depth_(max_depth) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + operator()(const Eigen::array& coords) const { + const int batch = coords[0]; + const int depth = coords[1]; + const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch)); + if (!FastBoundsCheck(label, max_depth_)) { + return Eigen::NumTraits::quiet_NaN(); + } + T subtract = TF_PREDICT_FALSE(depth == label) ? T(1.0) : T(0.0); + return exp_logits_(coords) / sum_exp_logits_(batch) - subtract; + }; + + private: + typename TTypes::Tensor32Bit exp_logits_; + typename TTypes::Tensor32Bit sum_exp_logits_; + typename TTypes::Tensor32Bit labels_; + const Index max_depth_; +}; + +} // namespace generator + +namespace functor { + +template +struct RowMaxReduction { + // Computes the maximum across the rows of logits + // + // logits: batch_size, num_classes. + // maximum: temporary tensor, dims: batch_size, 1 + static inline void Compute(OpKernelContext* ctx, + typename TTypes::ConstMatrix logits, + typename TTypes::Vec maximum) { + Eigen::IndexList > along_row; + Device d = ctx->eigen_device(); + To32Bit(maximum).device(d) = To32Bit(logits).maximum(along_row); + } +}; + +// Functor used by SparseXentOp to do the computations. +template +struct SparseXentFunctor { + // Computes Cross Entropy loss and backprop. + // + // logits: batch_size, num_classes. + // labels: num_classes. + // scratch: temporary tensor, dims: batch_size, 1 + // loss: output tensor for the loss, dims: batch_size. + // backprop: output tensor for the backprop, dims: batch_size, num_classes. + void operator()(OpKernelContext* ctx, typename TTypes::ConstMatrix logits, + typename TTypes::ConstVec labels, + typename TTypes::Vec scratch, typename TTypes::Vec loss, + typename TTypes::Matrix backprop); +}; + +// Eigen code implementing SparseXentFunctor::operator(). 
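// ---------------------------------------------------------------------------
// Editorial illustration (not part of the vendored header): the loss and
// backprop formulas from the generator comments above, written out for a
// single row of logits using plain std::vector inputs. The max-logit
// subtraction mirrors the numerically stable form used by the Eigen
// implementation; assumes a non-empty row and 0 <= label < logits.size().
#include <algorithm>
#include <cmath>
#include <vector>

inline float SparseXentRowSketch(const std::vector<float>& logits, int label,
                                 std::vector<float>* backprop) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  float sum_exp = 0.0f;
  for (float l : logits) sum_exp += std::exp(l - max_logit);
  // loss = log(sum_j exp(logits[j] - max)) - (logits[label] - max)
  const float loss = std::log(sum_exp) - (logits[label] - max_logit);
  backprop->resize(logits.size());
  for (int j = 0; j < static_cast<int>(logits.size()); ++j) {
    // backprop[j] = exp(logits[j] - max) / sum_exp - 1{j == label}
    (*backprop)[j] =
        std::exp(logits[j] - max_logit) / sum_exp - (j == label ? 1.0f : 0.0f);
  }
  return loss;
}
// ---------------------------------------------------------------------------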
+// This code works for both CPU and GPU and is used by the functor +// specializations for both device types. +template +struct SparseXentEigenImpl { + static void Compute(OpKernelContext* ctx, + typename TTypes::ConstMatrix logits, + typename TTypes::ConstVec labels, + typename TTypes::Vec scratch, + typename TTypes::Vec loss, + typename TTypes::Matrix backprop) { + // NOTE(touts): This duplicates some of the computations in softmax_op + // because we need the intermediate (logits -max(logits)) values to + // avoid a log(exp()) in the computation of the loss. + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + +// These arrays are used to reduce along the class dimension, and broadcast +// the resulting value to all classes. + Eigen::IndexList > along_class; + Eigen::IndexList > batch_by_one; + batch_by_one.set(0, batch_size); + Eigen::IndexList batch_only; + batch_only.set(0, batch_size); + Eigen::IndexList, int> one_by_class; + one_by_class.set(1, num_classes); + + // scratch = max_logits along classes. + RowMaxReduction::Compute(ctx, logits, scratch); + + Device d = ctx->eigen_device(); + // backprop = logits - max_logits. + To32Bit(backprop).device(d) = + To32Bit(logits) - + To32Bit(scratch).reshape(batch_by_one).broadcast(one_by_class); + + // scratch = sum(exp(logits - max_logits)) along classes. + To32Bit(scratch).device(d) = To32Bit(backprop).exp().sum(along_class); + + // sum(-labels * + // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) + // along classes + generator::SparseXentLossGenerator sparse_xent_loss_gen( + sparse_xent_helpers::To32BitConst(backprop), + sparse_xent_helpers::To32BitConst(scratch), To32Bit(labels), + backprop.dimension(1) /* max_depth */); + To32Bit(loss).device(d) = + To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class); + + // backprop: prob - labels, where + // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) + To32Bit(backprop).device(d) = To32Bit(backprop).exp(); + generator::SparseXentGradGenerator sparse_xent_grad_gen( + sparse_xent_helpers::To32BitConst(backprop), + sparse_xent_helpers::To32BitConst(scratch), To32Bit(labels), + backprop.dimension(1) /* max_depth */); + To32Bit(backprop).device(d) = + To32Bit(backprop).generate(sparse_xent_grad_gen); + } +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/special_math/special_math_op_misc_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/special_math/special_math_op_misc_impl.h new file mode 100644 index 00000000..6b8bb7cb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/special_math/special_math_op_misc_impl.h @@ -0,0 +1,724 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPECIAL_MATH_SPECIAL_MATH_OP_MISC_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_SPECIAL_MATH_SPECIAL_MATH_OP_MISC_IMPL_H_ + +#define _USE_MATH_DEFINES +#include +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/cwise_ops.h" + +namespace Eigen { +namespace internal { + +// Implementation of Dawson's integral based on Cephes. + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_dawsn_interval_1(const Scalar& x) { + // Rational approximation on [0, 3.25) + const Scalar AN[] = { + Scalar(1.13681498971755972054E-11), Scalar(8.49262267667473811108E-10), + Scalar(1.94434204175553054283E-8), Scalar(9.53151741254484363489E-7), + Scalar(3.07828309874913200438E-6), Scalar(3.52513368520288738649E-4), + Scalar(-8.50149846724410912031E-4), Scalar(4.22618223005546594270E-2), + Scalar(-9.17480371773452345351E-2), Scalar(9.99999999999999994612E-1), + }; + const Scalar AD[] = { + Scalar(2.40372073066762605484E-11), Scalar(1.48864681368493396752E-9), + Scalar(5.21265281010541664570E-8), Scalar(1.27258478273186970203E-6), + Scalar(2.32490249820789513991E-5), Scalar(3.25524741826057911661E-4), + Scalar(3.48805814657162590916E-3), Scalar(2.79448531198828973716E-2), + Scalar(1.58874241960120565368E-1), Scalar(5.74918629489320327824E-1), + Scalar(1.00000000000000000539E0), + }; + const Scalar x2 = x * x; + Scalar y = (x * internal::ppolevl::run(x2, AN)) / + internal::ppolevl::run(x2, AD); + return y; +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_dawsn_interval_2(const Scalar& x) { + // Rational approximation on [3.25, 6.25) + const Scalar BN[] = { + Scalar(5.08955156417900903354E-1), Scalar(-2.44754418142697847934E-1), + Scalar(9.41512335303534411857E-2), Scalar(-2.18711255142039025206E-2), + Scalar(3.66207612329569181322E-3), Scalar(-4.23209114460388756528E-4), + Scalar(3.59641304793896631888E-5), Scalar(-2.14640351719968974225E-6), + Scalar(9.10010780076391431042E-8), Scalar(-2.40274520828250956942E-9), + Scalar(3.59233385440928410398E-11), + }; + const Scalar BD[] = { + Scalar(1.0), + Scalar(-6.31839869873368190192E-1), + Scalar(2.36706788228248691528E-1), + Scalar(-5.31806367003223277662E-2), + Scalar(8.48041718586295374409E-3), + Scalar(-9.47996768486665330168E-4), + Scalar(7.81025592944552338085E-5), + Scalar(-4.55875153252442634831E-6), + Scalar(1.89100358111421846170E-7), + Scalar(-4.91324691331920606875E-9), + Scalar(7.18466403235734541950E-11), + }; + const Scalar one = Scalar(1); + const Scalar half = Scalar(0.5); + + const Scalar inverse_x = one / x; + const Scalar inverse_x2 = inverse_x * inverse_x; + Scalar z = (internal::ppolevl::run(inverse_x2, BN) / + (x * internal::ppolevl::run(inverse_x2, BD))); + Scalar y = inverse_x2 * z + inverse_x; + return half * y; +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_dawsn_interval_3(const Scalar& x) { + // Rational approximation on [6.25, 1.0e9) + const Scalar CN[] = { + Scalar(-5.90592860534773254987E-1), Scalar(6.29235242724368800674E-1), + Scalar(-1.72858975380388136411E-1), Scalar(1.64837047825189632310E-2), + Scalar(-4.86827613020462700845E-4), + }; + const Scalar CD[] = { + Scalar(1.0), + Scalar(-2.69820057197544900361E0), + 
Scalar(1.73270799045947845857E0), + Scalar(-3.93708582281939493482E-1), + Scalar(3.44278924041233391079E-2), + Scalar(-9.73655226040941223894E-4), + }; + const Scalar one = Scalar(1); + const Scalar half = Scalar(0.5); + + const Scalar inverse_x = one / x; + Scalar inverse_x2 = inverse_x * inverse_x; + Scalar z = (internal::ppolevl::run(inverse_x2, CN) / + (x * internal::ppolevl::run(inverse_x2, CD))); + Scalar y = inverse_x2 * z + inverse_x; + return half * y; + return pmul(half, y); +} + +template +struct dawsn_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + const Scalar half = Scalar(0.5); + const Scalar a = Scalar(3.25); + const Scalar b = Scalar(6.25); + const Scalar c = Scalar(1.0e9); + + Scalar abs_x = pabs(x); + + Scalar dawsn; + if (abs_x < a) { + dawsn = generic_dawsn_interval_1(abs_x); + } else if (abs_x < b) { + dawsn = generic_dawsn_interval_2(abs_x); + } else if (abs_x < c) { + dawsn = generic_dawsn_interval_3(abs_x); + } else { + dawsn = half / x; + } + + if (x < Scalar(0)) { + dawsn = -dawsn; + } + return dawsn; + } +}; + +// Implementation of exponential integral, based on Cephes. + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_expint_interval_1(const Scalar& x) { + /* 0 < x <= 2 + Ei(x) - EUL - ln(x) = x A(x)/B(x) + Theoretical peak relative error 9.73e-18 */ + const Scalar A[] = { + Scalar(-5.350447357812542947283E0), Scalar(2.185049168816613393830E2), + Scalar(-4.176572384826693777058E3), Scalar(5.541176756393557601232E4), + Scalar(-3.313381331178144034309E5), Scalar(1.592627163384945414220E6), + }; + const Scalar B[] = { + Scalar(1.0), + Scalar(-5.250547959112862969197E1), + Scalar(1.259616186786790571525E3), + Scalar(-1.756549581973534652631E4), + Scalar(1.493062117002725991967E5), + Scalar(-7.294949239640527645655E5), + Scalar(1.592627163384945429726E6), + }; + + // Euler gamma. 
+ const Scalar EUL = Scalar(0.5772156649015329); + + const Scalar f = (internal::ppolevl::run(x, A) / + internal::ppolevl::run(x, B)); + return x * f + EUL + numext::log(x); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_expint_interval_2(const Scalar& x) { + /* 2 <= x <= 4 + x exp(-x) Ei(x) - 1 = 1/x A6(1/x) / B6(1/x) + Theoretical absolute error = 4.89e-17 */ + const Scalar A6[] = { + Scalar(1.981808503259689673238E-2), Scalar(-1.271645625984917501326E0), + Scalar(-2.088160335681228318920E0), Scalar(2.755544509187936721172E0), + Scalar(-4.409507048701600257171E-1), Scalar(4.665623805935891391017E-2), + Scalar(-1.545042679673485262580E-3), Scalar(7.059980605299617478514E-5), + }; + const Scalar B6[] = { + Scalar(1.0), + Scalar(1.476498670914921440652E0), + Scalar(5.629177174822436244827E-1), + Scalar(1.699017897879307263248E-1), + Scalar(2.291647179034212017463E-2), + Scalar(4.450150439728752875043E-3), + Scalar(1.727439612206521482874E-4), + Scalar(3.953167195549672482304E-5), + }; + + const Scalar one = Scalar(1.0); + Scalar w = one / x; + Scalar f = (internal::ppolevl::run(w, A6) / + internal::ppolevl::run(w, B6)); + f = w * f + one; + return numext::exp(x) * w * f; +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_expint_interval_3(const Scalar& x) { + /* 4 <= x <= 8 + x exp(-x) Ei(x) - 1 = 1/x A5(1/x) / B5(1/x) + Theoretical absolute error = 2.20e-17 */ + const Scalar A5[] = { + Scalar(-1.373215375871208729803E0), Scalar(-7.084559133740838761406E-1), + Scalar(1.580806855547941010501E0), Scalar(-2.601500427425622944234E-1), + Scalar(2.994674694113713763365E-2), Scalar(-1.038086040188744005513E-3), + Scalar(4.371064420753005429514E-5), Scalar(2.141783679522602903795E-6), + }; + const Scalar B5[] = { + Scalar(1.0), + Scalar(8.585231423622028380768E-1), + Scalar(4.483285822873995129957E-1), + Scalar(7.687932158124475434091E-2), + Scalar(2.449868241021887685904E-2), + Scalar(8.832165941927796567926E-4), + Scalar(4.590952299511353531215E-4), + Scalar(-4.729848351866523044863E-6), + Scalar(2.665195537390710170105E-6), + }; + + const Scalar one = Scalar(1.0); + Scalar w = one / x; + Scalar f = (internal::ppolevl::run(w, A5) / + internal::ppolevl::run(w, B5)); + f = w * f + one; + return numext::exp(x) * w * f; +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_expint_interval_4(const Scalar& x) { + /* 8 <= x <= 16 + x exp(-x) Ei(x) - 1 = 1/x R(1/x) + Theoretical peak absolute error = 1.07e-17 */ + const Scalar A2[] = { + Scalar(-2.106934601691916512584E0), Scalar(1.732733869664688041885E0), + Scalar(-2.423619178935841904839E-1), Scalar(2.322724180937565842585E-2), + Scalar(2.372880440493179832059E-4), Scalar(-8.343219561192552752335E-5), + Scalar(1.363408795605250394881E-5), Scalar(-3.655412321999253963714E-7), + Scalar(1.464941733975961318456E-8), Scalar(6.176407863710360207074E-10), + }; + const Scalar B2[] = { + Scalar(1.0), + Scalar(-2.298062239901678075778E-1), + Scalar(1.105077041474037862347E-1), + Scalar(-1.566542966630792353556E-2), + Scalar(2.761106850817352773874E-3), + Scalar(-2.089148012284048449115E-4), + Scalar(1.708528938807675304186E-5), + Scalar(-4.459311796356686423199E-7), + Scalar(1.394634930353847498145E-8), + Scalar(6.150865933977338354138E-10), + }; + + const Scalar one = Scalar(1.0); + Scalar w = one / x; + Scalar f = (internal::ppolevl::run(w, A2) / + internal::ppolevl::run(w, B2)); + f = w * f + one; + return numext::exp(x) * w * f; +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar 
+generic_expint_interval_5(const Scalar& x) { + /* 16 <= x <= 32 + x exp(-x) Ei(x) - 1 = 1/x A4(1/x) / B4(1/x) + Theoretical absolute error = 1.22e-17 */ + const Scalar A4[] = { + Scalar(-2.458119367674020323359E-1), Scalar(-1.483382253322077687183E-1), + Scalar(7.248291795735551591813E-2), Scalar(-1.348315687380940523823E-2), + Scalar(1.342775069788636972294E-3), Scalar(-7.942465637159712264564E-5), + Scalar(2.644179518984235952241E-6), Scalar(-4.239473659313765177195E-8), + }; + const Scalar B4[] = { + Scalar(1.0), + Scalar(-1.044225908443871106315E-1), + Scalar(-2.676453128101402655055E-1), + Scalar(9.695000254621984627876E-2), + Scalar(-1.601745692712991078208E-2), + Scalar(1.496414899205908021882E-3), + Scalar(-8.462452563778485013756E-5), + Scalar(2.728938403476726394024E-6), + Scalar(-4.239462431819542051337E-8), + }; + + const Scalar one = Scalar(1.0); + Scalar w = one / x; + Scalar f = (internal::ppolevl::run(w, A4) / + internal::ppolevl::run(w, B4)); + f = w * f + one; + return numext::exp(x) * w * f; +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_expint_interval_6(const Scalar& x) { + /* 32 <= x <= 64 + x exp(-x) Ei(x) - 1 = 1/x A7(1/x) / B7(1/x) + Theoretical absolute error = 7.71e-18 */ + const Scalar A7[] = { + Scalar(1.212561118105456670844E-1), Scalar(-5.823133179043894485122E-1), + Scalar(2.348887314557016779211E-1), Scalar(-3.040034318113248237280E-2), + Scalar(1.510082146865190661777E-3), Scalar(-2.523137095499571377122E-5), + }; + const Scalar B7[] = { + Scalar(1.0), + Scalar(-1.002252150365854016662E0), + Scalar(2.928709694872224144953E-1), + Scalar(-3.337004338674007801307E-2), + Scalar(1.560544881127388842819E-3), + Scalar(-2.523137093603234562648E-5), + }; + + const Scalar one = Scalar(1.0); + Scalar w = one / x; + Scalar f = (internal::ppolevl::run(w, A7) / + internal::ppolevl::run(w, B7)); + f = w * f + one; + return numext::exp(x) * w * f; +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_expint_interval_7(const Scalar& x) { + /* x > 64 + x exp(-x) Ei(x) - 1 = 1/x A3(1/x)/B3(1/x) + Theoretical absolute error = 6.15e-17 */ + const Scalar A3[] = { + Scalar(-7.657847078286127362028E-1), Scalar(6.886192415566705051750E-1), + Scalar(-2.132598113545206124553E-1), Scalar(3.346107552384193813594E-2), + Scalar(-3.076541477344756050249E-3), Scalar(1.747119316454907477380E-4), + Scalar(-6.103711682274170530369E-6), Scalar(1.218032765428652199087E-7), + Scalar(-1.086076102793290233007E-9), + }; + const Scalar B3[] = { + Scalar(1.0), + Scalar(-1.888802868662308731041E0), + Scalar(1.066691687211408896850E0), + Scalar(-2.751915982306380647738E-1), + Scalar(3.930852688233823569726E-2), + Scalar(-3.414684558602365085394E-3), + Scalar(1.866844370703555398195E-4), + Scalar(-6.345146083130515357861E-6), + Scalar(1.239754287483206878024E-7), + Scalar(-1.086076102793126632978E-9), + }; + + const Scalar one = Scalar(1.0); + Scalar w = one / x; + Scalar f = (internal::ppolevl::run(w, A3) / + internal::ppolevl::run(w, B3)); + f = w * f + one; + return numext::exp(x) * w * f; +} + +template +struct expint_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + const Scalar zero = Scalar(0.0); + const Scalar two = Scalar(2.0); + const Scalar four = Scalar(4.0); + const Scalar eight = Scalar(8.0); + const Scalar sixteen = Scalar(16.0); + const Scalar thirty_two = Scalar(32.0); + const Scalar sixty_four = Scalar(64.0); + const Scalar nan = Scalar(NumTraits::quiet_NaN()); + + if (x < zero) { + return nan; + } + + 
if (x < two) { + return generic_expint_interval_1(x); + } else if (x < four) { + return generic_expint_interval_2(x); + } else if (x < eight) { + return generic_expint_interval_3(x); + } else if (x < sixteen) { + return generic_expint_interval_4(x); + } else if (x < thirty_two) { + return generic_expint_interval_5(x); + } else if (x < sixty_four) { + return generic_expint_interval_6(x); + } + return generic_expint_interval_7(x); + } +}; + +// Implementation of Fresnel cosine and sine integrals, based on Cephes. + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_fresnel_cos_interval_1(const Scalar& x) { + const Scalar CN[] = { + Scalar(-4.98843114573573548651E-8), Scalar(9.50428062829859605134E-6), + Scalar(-6.45191435683965050962E-4), Scalar(1.88843319396703850064E-2), + Scalar(-2.05525900955013891793E-1), Scalar(9.99999999999999998822E-1), + }; + const Scalar CD[] = { + Scalar(3.99982968972495980367E-12), Scalar(9.15439215774657478799E-10), + Scalar(1.25001862479598821474E-7), Scalar(1.22262789024179030997E-5), + Scalar(8.68029542941784300606E-4), Scalar(4.12142090722199792936E-2), + Scalar(1.00000000000000000118E0), + }; + + const Scalar x2 = x * x; + Scalar x4 = x2 * x2; + return (x * internal::ppolevl::run(x4, CN) / + internal::ppolevl::run(x4, CD)); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_fresnel_sin_interval_1(const Scalar& x) { + const Scalar SN[] = { + Scalar(-2.99181919401019853726E3), Scalar(7.08840045257738576863E5), + Scalar(-6.29741486205862506537E7), Scalar(2.54890880573376359104E9), + Scalar(-4.42979518059697779103E10), Scalar(3.18016297876567817986E11), + }; + const Scalar SD[] = { + Scalar(1.0), + Scalar(2.81376268889994315696E2), + Scalar(4.55847810806532581675E4), + Scalar(5.17343888770096400730E6), + Scalar(4.19320245898111231129E8), + Scalar(2.24411795645340920940E10), + Scalar(6.07366389490084639049E11), + }; + + const Scalar x2 = x * x; + Scalar x4 = x2 * x2; + Scalar z = x * x2; + return (z * internal::ppolevl::run(x4, SN) / + internal::ppolevl::run(x4, SD)); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar +generic_fresnel_asymp(const Scalar& x, bool use_sin) { + const Scalar FN[] = { + Scalar(4.21543555043677546506E-1), Scalar(1.43407919780758885261E-1), + Scalar(1.15220955073585758835E-2), Scalar(3.45017939782574027900E-4), + Scalar(4.63613749287867322088E-6), Scalar(3.05568983790257605827E-8), + Scalar(1.02304514164907233465E-10), Scalar(1.72010743268161828879E-13), + Scalar(1.34283276233062758925E-16), Scalar(3.76329711269987889006E-20), + }; + const Scalar FD[] = { + Scalar(1.0), + Scalar(7.51586398353378947175E-1), + Scalar(1.16888925859191382142E-1), + Scalar(6.44051526508858611005E-3), + Scalar(1.55934409164153020873E-4), + Scalar(1.84627567348930545870E-6), + Scalar(1.12699224763999035261E-8), + Scalar(3.60140029589371370404E-11), + Scalar(5.88754533621578410010E-14), + Scalar(4.52001434074129701496E-17), + Scalar(1.25443237090011264384E-20), + }; + const Scalar GN[] = { + Scalar(5.04442073643383265887E-1), Scalar(1.97102833525523411709E-1), + Scalar(1.87648584092575249293E-2), Scalar(6.84079380915393090172E-4), + Scalar(1.15138826111884280931E-5), Scalar(9.82852443688422223854E-8), + Scalar(4.45344415861750144738E-10), Scalar(1.08268041139020870318E-12), + Scalar(1.37555460633261799868E-15), Scalar(8.36354435630677421531E-19), + Scalar(1.86958710162783235106E-22), + }; + const Scalar GD[] = { + Scalar(1.0), + Scalar(1.47495759925128324529E0), + Scalar(3.37748989120019970451E-1), + 
Scalar(2.53603741420338795122E-2), + Scalar(8.14679107184306179049E-4), + Scalar(1.27545075667729118702E-5), + Scalar(1.04314589657571990585E-7), + Scalar(4.60680728146520428211E-10), + Scalar(1.10273215066240270757E-12), + Scalar(1.38796531259578871258E-15), + Scalar(8.39158816283118707363E-19), + Scalar(1.86958710162783236342E-22), + }; + + const Scalar HALF_PI = Scalar(1.5707963267948966); + const Scalar PI = Scalar(EIGEN_PI); + const Scalar one = Scalar(1); + const Scalar half = Scalar(0.5); + + const Scalar x2 = x * x; + const Scalar t = one / pmul(PI, x2); + Scalar u = t * t; + + Scalar f = one - u * (internal::ppolevl::run(u, FN) / + internal::ppolevl::run(u, FD)); + Scalar g = (t * internal::ppolevl::run(u, GN) / + internal::ppolevl::run(u, GD)); + + const Scalar z = HALF_PI * x2; + const Scalar c = numext::cos(z); + const Scalar s = numext::sin(z); + const Scalar y = one / (PI * x); + if (use_sin) { + Scalar intermediate = f * c; + intermediate += g * s; + return half - intermediate * y; + } + Scalar intermediate = f * s; + intermediate -= g * c; + return half + intermediate * y; +} + +template +struct fresnel_cos_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + const Scalar zero = Scalar(0.); + const Scalar half = Scalar(0.5); + const Scalar a = Scalar(2.5625); + const Scalar b = Scalar(36974.0); + + const Scalar abs_x = numext::abs(x); + + if (abs_x > b) { + if (x < zero) { + return -half; + } + return half; + } + + const Scalar x2 = x * x; + + Scalar fresnel_cos; + if (x2 < a) { + fresnel_cos = generic_fresnel_cos_interval_1(abs_x); + } else { + fresnel_cos = generic_fresnel_asymp(abs_x, false); + } + if (x < zero) { + return -fresnel_cos; + } + return fresnel_cos; + } +}; + +template +struct fresnel_sin_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + const Scalar zero = Scalar(0.); + const Scalar half = Scalar(0.5); + const Scalar a = Scalar(2.5625); + const Scalar b = Scalar(36974.0); + const Scalar abs_x = numext::abs(x); + + if (abs_x > b) { + if (x < zero) { + return -half; + } + return half; + } + + const Scalar x2 = x * x; + + Scalar fresnel_sin; + if (x2 < a) { + fresnel_sin = generic_fresnel_sin_interval_1(abs_x); + } else { + fresnel_sin = generic_fresnel_asymp(abs_x, true); + } + + if (x < zero) { + return -fresnel_sin; + } + return fresnel_sin; + } +}; + +// Implementation of Spence's Integral based on Cephes. +template +struct spence_op { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar + operator()(const Scalar& x) const { + const Scalar A[] = { + Scalar(4.65128586073990045278E-5), Scalar(7.31589045238094711071E-3), + Scalar(1.33847639578309018650E-1), Scalar(8.79691311754530315341E-1), + Scalar(2.71149851196553469920E0), Scalar(4.25697156008121755724E0), + Scalar(3.29771340985225106936E0), Scalar(1.00000000000000000126E0), + }; + const Scalar B[] = { + Scalar(6.90990488912553276999E-4), Scalar(2.54043763932544379113E-2), + Scalar(2.82974860602568089943E-1), Scalar(1.41172597751831069617E0), + Scalar(3.63800533345137075418E0), Scalar(5.03278880143316990390E0), + Scalar(3.54771340985225096217E0), Scalar(9.99999999999999998740E-1), + }; + const Scalar zero = Scalar(0.0); + const Scalar one = Scalar(1.0); + const Scalar three_halves = Scalar(1.5); + const Scalar two = Scalar(2.0); + const Scalar half = Scalar(0.5); + const Scalar nan = Scalar(NumTraits::quiet_NaN()); + // pi**2 / 6. 
+ const Scalar PI2O6 = Scalar(EIGEN_PI * EIGEN_PI / 6.0); + + if (x < zero) { + return nan; + } else if (x == zero) { + return PI2O6; + } else if (x == one) { + return zero; + } + + Scalar y; + if (x < two) { + y = x; + } else { + y = one / x; + } + + Scalar w; + if (three_halves < y) { + w = one / y - one; + } else { + if (y < half) { + w = -y; + } else { + w = y - one; + } + } + Scalar spence = -w * (internal::ppolevl::run(w, A) / + internal::ppolevl::run(w, B)); + Scalar z = numext::log(y); + if (y < half) { + spence = -z * numext::log1p(-y) + PI2O6 - spence; + } + if (three_halves < x) { + spence = -half * z * z - spence; + } + return spence; + } +}; + +} // end namespace internal +} // end namespace Eigen + +namespace tensorflow { +namespace functor { + +template +struct dawsn : base> {}; + +template +struct expint : base> {}; + +template +struct fresnel_cos : base> {}; + +template +struct fresnel_sin : base> {}; + +template +struct spence : base> {}; + +// Bessel Functions + +template +struct bessel_i0 : base> {}; + +template +struct bessel_i0e : base> {}; + +template +struct bessel_i1 : base> {}; + +template +struct bessel_i1e : base> {}; + +template +struct bessel_k0 : base> {}; + +template +struct bessel_k0e : base> {}; + +template +struct bessel_k1 : base> {}; + +template +struct bessel_k1e : base> {}; + +template +struct bessel_j0 : base> {}; + +template +struct bessel_j1 : base> {}; + +template +struct bessel_y0 : base> {}; + +template +struct bessel_y1 : base> {}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPECIAL_MATH_SPECIAL_MATH_OP_MISC_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/spectrogram.h b/third_party/tflite-hdrs/tensorflow/core/kernels/spectrogram.h new file mode 100644 index 00000000..4b6b9c8b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/spectrogram.h @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Class for generating spectrogram slices from a waveform. +// Initialize() should be called before calls to other functions. Once +// Initialize() has been called and returned true, The Compute*() functions can +// be called repeatedly with sequential input data (ie. the first element of the +// next input vector directly follows the last element of the previous input +// vector). Whenever enough audio samples are buffered to produce a +// new frame, it will be placed in output. Output is cleared on each +// call to Compute*(). This class is thread-unsafe, and should only be +// called from one thread at a time. 
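//
// Usage sketch (editorial; the window/step values and input length are
// arbitrary example choices, not defaults of this class):
//
//   tensorflow::Spectrogram spectrogram;
//   if (spectrogram.Initialize(/*window_length=*/512, /*step_length=*/256)) {
//     std::vector<double> audio(16000, 0.0);  // e.g. 1 s of silence at 16 kHz
//     std::vector<std::vector<double>> power_slices;
//     spectrogram.ComputeSquaredMagnitudeSpectrogram(audio, &power_slices);
//     // power_slices[t] holds output_frequency_channels() squared-magnitude
//     // values for frame t; a later call continues from the samples still
//     // buffered internally unless Reset() is called first.
//   }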
+// With the default parameters, the output of this class should be very +// close to the results of the following MATLAB code: +// overlap_samples = window_length_samples - step_samples; +// window = hann(window_length_samples, 'periodic'); +// S = abs(spectrogram(audio, window, overlap_samples)).^2; + +#ifndef TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ +#define TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ + +#include +#include +#include + +#include "third_party/fft2d/fft.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +class Spectrogram { + public: + Spectrogram() : initialized_(false) {} + ~Spectrogram() {} + + // Initializes the class with a given window length and step length + // (both in samples). Internally a Hann window is used as the window + // function. Returns true on success, after which calls to Process() + // are possible. window_length must be greater than 1 and step + // length must be greater than 0. + bool Initialize(int window_length, int step_length); + + // Initialize with an explicit window instead of a length. + bool Initialize(const std::vector& window, int step_length); + + // Reset internal variables. + // Spectrogram keeps internal state: remaining input data from previous call. + // As a result it can produce different number of frames when you call + // ComputeComplexSpectrogram multiple times (even though input data + // has the same size). As it is shown in + // MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames + // in tensorflow/core/kernels/spectrogram_test.cc. + // But if you need to compute Spectrogram on input data without keeping + // internal state (and clear remaining input data from the previous call) + // you have to call Reset() before computing Spectrogram. + // For example in tensorflow/core/kernels/spectrogram_op.cc + bool Reset(); + + // Processes an arbitrary amount of audio data (contained in input) + // to yield complex spectrogram frames. After a successful call to + // Initialize(), Process() may be called repeatedly with new input data + // each time. The audio input is buffered internally, and the output + // vector is populated with as many temporally-ordered spectral slices + // as it is possible to generate from the input. The output is cleared + // on each call before the new frames (if any) are added. + // + // The template parameters can be float or double. + template + bool ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>* output); + + // This function works as the one above, but returns the power + // (the L2 norm, or the squared magnitude) of each complex value. + template + bool ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, + std::vector>* output); + + // Return reference to the window function used internally. + const std::vector& GetWindow() const { return window_; } + + // Return the number of frequency channels in the spectrogram. + int output_frequency_channels() const { return output_frequency_channels_; } + + private: + template + bool GetNextWindowOfSamples(const std::vector& input, + int* input_start); + void ProcessCoreFFT(); + + int fft_length_; + int output_frequency_channels_; + int window_length_; + int step_length_; + bool initialized_; + int samples_to_next_step_; + + std::vector window_; + std::vector fft_input_output_; + std::deque input_queue_; + + // Working data areas for the FFT routines. 
+ std::vector fft_integer_working_area_; + std::vector fft_double_working_area_; + + Spectrogram(const Spectrogram&) = delete; + void operator=(const Spectrogram&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/spectrogram_test_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/spectrogram_test_utils.h new file mode 100644 index 00000000..d4187076 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/spectrogram_test_utils.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// Reads a wav format file into a vector of floating-point values with range +// -1.0 to 1.0. +bool ReadWaveFileToVector(const string& file_name, std::vector* data); + +// Reads a binary file containing 32-bit floating point values in the +// form [real_1, imag_1, real_2, imag_2, ...] into a rectangular array +// of complex values where row_length is the length of each inner vector. +bool ReadRawFloatFileToComplexVector( + const string& file_name, int row_length, + std::vector > >* data); + +// Reads a CSV file of numbers in the format 1.1+2.2i,1.1,2.2i,3.3j into data. +void ReadCSVFileToComplexVectorOrDie( + const string& file_name, + std::vector > >* data); + +// Reads a 2D array of floats from an ASCII text file, where each line is a row +// of the array, and elements are separated by commas. +void ReadCSVFileToArrayOrDie(const string& filename, + std::vector >* array); + +// Write a binary file containing 64-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteDoubleVectorToFile(const string& file_name, + const std::vector& data); + +// Write a binary file containing 32-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteFloatVectorToFile(const string& file_name, + const std::vector& data); + +// Write a binary file containing 64-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteDoubleArrayToFile(const string& file_name, int size, + const double* data); + +// Write a binary file containing 32-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteFloatArrayToFile(const string& file_name, int size, + const float* data); + +// Write a binary file in the format read by +// ReadRawDoubleFileToComplexVector above. +bool WriteComplexVectorToRawFloatFile( + const string& file_name, + const std::vector > >& data); + +// Generate a sine wave with the provided parameters, and populate +// data with the samples. 
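// ---------------------------------------------------------------------------
// Editorial sketch: a conventional implementation of the SineWave() helper
// declared just below, included only to make the expected output concrete.
// The actual amplitude and phase conventions live in the corresponding .cc
// file, so treat this as an assumption rather than the real definition.
#include <cmath>
#include <vector>

inline void SineWaveSketch(int sample_rate, float frequency,
                           float duration_seconds, std::vector<float>* data) {
  const double kPi = 3.14159265358979323846;
  const int num_samples = static_cast<int>(sample_rate * duration_seconds);
  data->clear();
  data->reserve(num_samples);
  for (int i = 0; i < num_samples; ++i) {
    // sample i of a unit-amplitude sine at `frequency` Hz
    data->push_back(
        static_cast<float>(std::sin(2.0 * kPi * frequency * i / sample_rate)));
  }
}
// ---------------------------------------------------------------------------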
+void SineWave(int sample_rate, float frequency, float duration_seconds, + std::vector* data); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/split_lib.h b/third_party/tflite-hdrs/tensorflow/core/kernels/split_lib.h new file mode 100644 index 00000000..28257ed4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/split_lib.h @@ -0,0 +1,55 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_ +#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_ +// Functor definition for SplitOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +template +struct SplitCustom { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes); +}; + +template +struct Split { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes); +}; + +template +struct Split { + void operator()(const Eigen::ThreadPoolDevice& d, + typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& slice_sizes); +}; + + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/split_lib_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/split_lib_gpu.h new file mode 100644 index 00000000..ae767b07 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/split_lib_gpu.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ + +#define EIGEN_USE_THREADS +#define EIGEN_USE_GPU + +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/gpu_device_array_gpu.h" +#include "tensorflow/core/kernels/split_lib.h" + +namespace tensorflow { + +template +struct SplitOpGPULaunch { + void Run(const Eigen::GpuDevice& d, const T* input, int32_t prefix_dim_size, + int32_t split_dim_size, int32_t suffix_dim_size, + const GpuDeviceArrayStruct& output_ptr_data); +}; + +template +struct SplitVOpGPULaunch { + void Run(const Eigen::GpuDevice& d, bool fixed, const T* input, + int total_cols, int total_rows, + const GpuDeviceArrayStruct& output_scan, + const GpuDeviceArrayStruct& output_ptr_data); +}; + +// Explicit instantiations in split_lib_gpu.cu.cc. +#define REGISTER_GPU_KERNEL(T) \ + extern template struct SplitOpGPULaunch; \ + extern template struct SplitVOpGPULaunch; \ + extern template struct SplitVOpGPULaunch; \ + extern template struct SplitVOpGPULaunch; + +TF_CALL_uint8(REGISTER_GPU_KERNEL); +TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/squared-loss.h b/third_party/tflite-hdrs/tensorflow/core/kernels/squared-loss.h new file mode 100644 index 00000000..3b334d68 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/squared-loss.h @@ -0,0 +1,73 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_ +#define TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_ + +#include "tensorflow/core/kernels/loss.h" + +namespace tensorflow { + +class SquaredLossUpdater : public DualLossUpdater { + public: + // Closed form solution that decreases the dual squared loss. + // See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf for the derivation of + // the update rule when the example weights are equal to 1.0. + // Note: There is a typo in the formula in the paper: the denominator should + // be 1 + ||x_i||^2/(\lambda n) (without the 2 multiplier). + // + // The CoCoA+ modification is detailed in readme.md. + double ComputeUpdatedDual(const int num_loss_partitions, const double label, + const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm) const final { + const double delta_numerator = label - current_dual - wx; + const double delta_denominator = + 1 + num_loss_partitions * weighted_example_norm * example_weight; + return current_dual + delta_numerator / delta_denominator; + } + + // Dual of squared loss function. 
+ // https://en.wikipedia.org/wiki/Convex_conjugate + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { + // Dual of the squared loss function = b * (y + b/2), where b is the + // dual variable and y is the label. This is Dual(-b). + return current_dual * (0.5 * current_dual - example_label) * example_weight; + } + + // Squared loss for linear regression. + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { + const double error = wx - example_label; + return error * error * example_weight * 0.5; + } + + inline double PrimalLossDerivative(const double wx, const double label, + const double example_weight) const final { + return (wx - label) * example_weight; + } + + inline double SmoothnessConstant() const final { return 1.0; } + + // Labels don't require conversion for linear regression. + absl::Status ConvertLabel(float* const example_label) const final { + return absl::OkStatus(); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stack.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stack.h new file mode 100644 index 00000000..a9c6a607 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stack.h @@ -0,0 +1,77 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STACK_H_ +#define TENSORFLOW_CORE_KERNELS_STACK_H_ + +// See docs in ../ops/data_flow_ops.cc. + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// A per-run local stack. The stack uses a "per-step" resource manager which +// ensures that correct garbage collection on error or successful completion. +class StackOp : public OpKernel { + public: + explicit StackOp(OpKernelConstruction* context); + void Compute(OpKernelContext* ctx) override; + + private: + DataType elem_type_; + string stack_name_; + + StackOp(const StackOp&) = delete; + void operator=(const StackOp&) = delete; +}; + +class StackPushOp : public AsyncOpKernel { + public: + StackPushOp(OpKernelConstruction* context, bool allow_swapping); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + bool IsExpensive() override; + + private: + bool swap_memory_ = false; +}; + +// Templated helper to make it easier to register kernels with or without +// swapping. 
+template +class TemplatedStackPushOp : public StackPushOp { + public: + TemplatedStackPushOp(OpKernelConstruction* context) + : StackPushOp(context, allow_swapping) {} +}; + +class StackPopOp : public AsyncOpKernel { + public: + explicit StackPopOp(OpKernelConstruction* context); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + bool IsExpensive() override; +}; + +class StackCloseOp : public OpKernel { + public: + explicit StackCloseOp(OpKernelConstruction* context); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STACK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stateful_random_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stateful_random_ops.h new file mode 100644 index 00000000..21a08fa0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stateful_random_ops.h @@ -0,0 +1,42 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_ + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/random/philox_random.h" + +namespace tensorflow { + +// 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained +// in b/111604096 and cl/171681867), so we use signed int here. We choose int64 +// instead of int32 because `VarHandleOp` doesn't support int32 on GPU, and +// because of the "int32 problem". +using StateElementType = int64_t; +static constexpr DataType STATE_ELEMENT_DTYPE = DT_INT64; +static constexpr DataType ALGORITHM_DTYPE = STATE_ELEMENT_DTYPE; + +using random::PhiloxRandom; + +static constexpr int64_t PHILOX_MIN_STATE_SIZE = + (PhiloxRandom::ResultType::kElementCount + + PhiloxRandom::Key::kElementCount) / + 2; +static constexpr int64_t THREEFRY_MIN_STATE_SIZE = 2; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h new file mode 100644 index 00000000..74eb40f3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h @@ -0,0 +1,114 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_ + +#include "tensorflow/core/kernels/random_ops_util.h" +#include "tensorflow/core/kernels/stateful_random_ops.h" + +namespace tensorflow { + +PHILOX_DEVICE_INLINE PhiloxRandom +GetPhiloxRandomFromMem(StateElementType const* ptr) { + auto ptr_ = reinterpret_cast(ptr); + return GetPhiloxRandomFromCounterKeyMem(ptr_, ptr_ + 2); +} + +PHILOX_DEVICE_INLINE void WritePhiloxRandomToMem(PhiloxRandom const& philox, + StateElementType* ptr) { + auto ptr_ = reinterpret_cast(ptr); + WriteCounterToMem(philox.counter(), ptr_); + WriteKeyToMem(philox.key(), ptr_ + 2); +} + +PHILOX_DEVICE_INLINE PhiloxRandom SkipPhiloxRandom(PhiloxRandom const& philox, + uint64 output_size) { + auto new_philox = philox; + // Multiplier 256 is the same as in FillPhiloxRandomTask; do not change it + // just here. + auto delta = output_size * 256; + new_philox.Skip(delta); // do the actual increasing + return new_philox; +} + +PHILOX_DEVICE_INLINE void UpdateMemWithPhiloxRandom(PhiloxRandom const& philox, + uint64 output_size, + StateElementType* ptr) { + auto new_philox = SkipPhiloxRandom(philox, output_size); + WritePhiloxRandomToMem(new_philox, ptr); +} + +PHILOX_DEVICE_INLINE void UpdateCounterMemWithPhiloxRandom( + PhiloxRandom::ResultType const& counter, uint64 output_size, + StateElementType* ptr) { + auto philox = PhiloxRandom(counter, PhiloxRandom::Key() /*dummy*/); + auto new_philox = SkipPhiloxRandom(philox, output_size); + WriteCounterToMem(new_philox.counter(), reinterpret_cast(ptr)); +} + +namespace functor { + +// A per-device helper function that does the actual work for +// `UpdateVariableAndFill`. +// Reason to use functor: C++ doesn't allow function-template partial +// specialization. +template +struct UpdateVariableAndFill_Philox; + +template +struct RngSkip_Philox; + +} // end namespace functor + +using CPUDevice = Eigen::ThreadPoolDevice; + +class ScopedUnlockUnrefVar; + +struct UpdateVariableAndFill_Philox_Arg { + int64_t output_size; + int64_t alg_tag_skip; + ScopedUnlockUnrefVar* state_var_guard; + Tensor* state_tensor; +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +using GPUDevice = Eigen::GpuDevice; + +namespace functor { + +// Declares the partially GPU-specialized functor structs. 
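// ---------------------------------------------------------------------------
// Editorial illustration (not TensorFlow API): how the helpers above pack the
// Philox state into the int64 state vector. The 128-bit counter occupies
// state[0] and state[1] and the 64-bit key occupies state[2], which is where
// PHILOX_MIN_STATE_SIZE = (4 + 2) / 2 = 3 in stateful_random_ops.h comes from
// (4 counter words plus 2 key words of 32 bits, two per int64 element).
#include <cstdint>

struct PhiloxStateLayoutSketch {  // hypothetical mirror; field names illustrative
  std::uint32_t counter[4];  // written at ptr_, i.e. state[0..1]
  std::uint32_t key[2];      // written at ptr_ + 2, i.e. state[2]
};
static_assert(sizeof(PhiloxStateLayoutSketch) == 3 * sizeof(std::int64_t),
              "counter + key fill exactly PHILOX_MIN_STATE_SIZE elements");
// ---------------------------------------------------------------------------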
+// must be kept at <=6 arguments because of a gcc/clang ABI incompatibility bug +template +struct UpdateVariableAndFill_Philox { + void operator()(OpKernelContext* ctx, const GPUDevice& device, + Distribution dist, UpdateVariableAndFill_Philox_Arg* arg, + typename Distribution::ResultElementType* output_data); +}; + +template <> +struct RngSkip_Philox { + void operator()(const GPUDevice& device, const StateElementType* in_data, + uint64 delta, StateElementType* out_data); +}; + +} // end namespace functor + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATEFUL_RANDOM_OPS_CPU_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_gamma_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_gamma_op.h new file mode 100644 index 00000000..426dbd5e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_gamma_op.h @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_ +#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/random/philox_random.h" + +namespace tensorflow { + +namespace functor { + +// This functor can take the PhiloxRandom input from either device memory `key` +// and `counter` or a stack value `random`. If both `key` and `counter` are not +// nullptr, they provide the input; otherwise `random` provides the input. +template +struct StatelessRandomGammaFunctor { + static absl::Status Fill(OpKernelContext* ctx, const T* alpha_flat, + int64_t num_samples, int64_t num_alphas, + int64_t samples_per_alpha, const uint64* key, + const uint64* counter, + const random::PhiloxRandom& random, T* samples_flat); +}; + +} // namespace functor + +// Buffer that holds multiple samples. Operator()(random::PhiloxRandom*) returns +// a single sample from this buffer. If the buffer is empty, it first generates +// new samples using the provided distribution. +// +// If the call to Distribution::operator() returns samples[0...N-1], then this +// class returns samples in the following order: +// +// samples[N-1], samples[N-2],..., samples[1], samples[0] +// +// For comparison, random::SingleSampleAdapter returns samples in +// the following order: +// +// samples[0], samples[1],...,samples[N-2], samples[N-1]. 
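// ---------------------------------------------------------------------------
// Editorial illustration (nothing here is TensorFlow API): why the buffer
// below hands samples back in reverse. A fake distribution returns a fixed
// batch so the consumption order is easy to see; the decrement-then-index
// step mirrors RandomSampleBuffer::operator() below.
#include <array>
#include <cstdio>

struct FakeBatchDistribution {
  static constexpr int kResultElementCount = 4;
  using ResultType = std::array<int, kResultElementCount>;
  ResultType operator()() const { return {10, 20, 30, 40}; }
};

inline void PrintBufferOrderSketch() {
  FakeBatchDistribution distribution;
  FakeBatchDistribution::ResultType results = distribution();
  int remaining = FakeBatchDistribution::kResultElementCount;
  while (remaining > 0) {
    --remaining;
    std::printf("%d ", results[remaining]);  // prints: 40 30 20 10
  }
  // random::SingleSampleAdapter, by contrast, yields 10 20 30 40.
}
// ---------------------------------------------------------------------------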
+// +template +class RandomSampleBuffer { + public: + typedef typename Distribution::ResultElementType ResultElementType; + + PHILOX_DEVICE_INLINE + explicit RandomSampleBuffer(Distribution* distribution) + : distribution_(distribution), remaining_numbers_(0) {} + + PHILOX_DEVICE_INLINE + ResultElementType operator()(random::PhiloxRandom* random) { + if (remaining_numbers_ == 0) { + results_ = (*distribution_)(random); + remaining_numbers_ = Distribution::kResultElementCount; + } + + remaining_numbers_--; + return results_[remaining_numbers_]; + } + + // Mark this buffer as empty. The next call to operator() will fill it + // with new random numbers. + PHILOX_DEVICE_INLINE + void Clear() { remaining_numbers_ = 0; } + + private: + typedef typename Distribution::ResultType ResultType; + + Distribution* distribution_; + ResultType results_; + int remaining_numbers_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_GAMMA_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops.h new file mode 100644 index 00000000..42ce3bff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +// Generates a key and counter that can be used to seed a PhiloxRandom, +// generator, based on the seed value in `seed_t`. +// +// REQUIRES: `seed_t` must be a length-2 vector of type DT_INT{32,64}. +// `out_key` and `out_counter` must be non-null. +absl::Status GenerateKey(Tensor seed_t, random::PhiloxRandom::Key* out_key, + random::PhiloxRandom::ResultType* out_counter); + +// A base class for kernels of stateless RNG ops that take shape and seed as the +// first 2 inputs. +class StatelessRandomOpBase : public OpKernel { + public: + explicit StatelessRandomOpBase(OpKernelConstruction* context); + + void Compute(OpKernelContext* context) override; + + protected: + // The part of Compute that depends on device, type, and distribution. + // Must be a tail call because it doesn't report error via return value. 
+ virtual void Fill(OpKernelContext* context, random::PhiloxRandom random, + Tensor* output) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops_v2.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops_v2.h new file mode 100644 index 00000000..0b5b8945 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops_v2.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_ +#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/rng_alg.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { + +inline absl::Status CheckKeyCounterShape(int minimum_counter_size, + TensorShape const& key_shape, + TensorShape const& counter_shape) { + if (!(key_shape.dims() == 1 && key_shape.dim_size(0) == RNG_KEY_SIZE)) { + return errors::InvalidArgument( + "key must have shape [", RNG_KEY_SIZE, "], not ", + key_shape.DebugString(), + ". (Note that batched keys are not supported yet.)"); + } + if (!(counter_shape.dims() == 1 && + counter_shape.dim_size(0) >= minimum_counter_size)) { + return errors::InvalidArgument( + "counter must be a vector with length at least ", minimum_counter_size, + "; got shape: ", counter_shape.DebugString(), + ". (Note that batched counters are not supported yet.)"); + } + return absl::OkStatus(); +} + +// A base class for kernels of stateless RNG ops that take shape, key, counter +// and algorithm as the first 4 inputs. +class StatelessRandomOpBaseWithKeyCounter : public OpKernel { + public: + explicit StatelessRandomOpBaseWithKeyCounter(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + protected: + // The part of Compute that depends on device, type, and distribution. + // Must be a tail call because it doesn't report error via return value. + virtual void Fill(OpKernelContext* ctx, Algorithm alg, const Tensor& key, + const Tensor& counter, Tensor* output) = 0; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops_v2_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops_v2_util.h new file mode 100644 index 00000000..a5798342 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stateless_random_ops_v2_util.h @@ -0,0 +1,86 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
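`CheckKeyCounterShape` above only inspects shapes: the key must be a length-`RNG_KEY_SIZE` vector and the counter a vector at least as long as the algorithm requires. A sketch of the same checks on plain dimension vectors, assuming a key size of 1 (the actual `RNG_KEY_SIZE` constant lives in `rng_alg.h`, outside this diff), with a hypothetical `CheckShapes` helper:

```cpp
// Sketch of the key/counter shape checks, on plain shape vectors.
// kKeySize = 1 is an assumption standing in for RNG_KEY_SIZE.
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

constexpr int64_t kKeySize = 1;

// Returns an empty string on success, otherwise an error message.
std::string CheckShapes(const std::vector<int64_t>& key_shape,
                        const std::vector<int64_t>& counter_shape,
                        int64_t min_counter_size) {
  if (key_shape.size() != 1 || key_shape[0] != kKeySize) {
    return "key must have shape [" + std::to_string(kKeySize) + "]";
  }
  if (counter_shape.size() != 1 || counter_shape[0] < min_counter_size) {
    return "counter must be a vector with length at least " +
           std::to_string(min_counter_size);
  }
  return "";
}

int main() {
  // Philox's 128-bit counter spans two 64-bit words, hence min size 2 here.
  std::printf("ok: %s\n", CheckShapes({1}, {2}, 2).empty() ? "yes" : "no");
  std::printf("ok: %s\n", CheckShapes({1}, {1}, 2).empty() ? "yes" : "no");
  return 0;
}
```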
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_UTIL_H_ + +// Utilities for V2 stateless random ops' (non-XLA) kernels. + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/random_op.h" +#include "tensorflow/core/kernels/stateless_random_ops_v2.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { + +template +absl::Status GetScalar(const Tensor& tensor, int input_idx, T* result) { + auto dtype = DataTypeToEnum::v(); + if (tensor.dims() != 0) { + return errors::InvalidArgument("input ", std::to_string(input_idx), + " (0-based) must have shape [], not ", + tensor.shape().DebugString()); + } + if (tensor.dtype() != dtype) { + return errors::InvalidArgument("dtype of input ", std::to_string(input_idx), + " (0-based) must be ", DataTypeString(dtype), + ", not ", DataTypeString(tensor.dtype())); + } + *result = tensor.flat()(0); + return absl::OkStatus(); +} + +inline absl::StatusOr> +GetKeyCounterAlgFromInputs(OpKernelContext* ctx, int key_input_idx, + int counter_input_idx, int alg_input_idx) { + const Tensor& key_t = ctx->input(key_input_idx); + const Tensor& counter_t = ctx->input(counter_input_idx); + const Tensor& alg_t = ctx->input(alg_input_idx); + + int alg_id; + TF_RETURN_IF_ERROR(GetScalar(alg_t, alg_input_idx, &alg_id)); + Algorithm alg = Algorithm(alg_id); + if (alg == RNG_ALG_AUTO_SELECT) { + alg = RNG_ALG_PHILOX; + } + + TF_RETURN_IF_ERROR( + CheckKeyCounterShape(alg, key_t.shape(), counter_t.shape())); + return std::make_tuple(key_t, counter_t, alg); +} + +template +void FillRandomTensor(OpKernelContext* ctx, Algorithm alg, const Tensor& key, + const Tensor& counter, Distribution dist, + Tensor* tensor) { + typedef typename Distribution::ResultElementType T; + auto flat = tensor->flat(); + if (alg == RNG_ALG_PHILOX) { + // Reuse the compute kernels from the stateful random ops + auto key_data = key.flat().data(); + auto counter_data = counter.flat().data(); + functor::FillPhiloxRandom()( + ctx, ctx->eigen_device(), key_data, counter_data, + random::PhiloxRandom() /*dummy*/, flat.data(), flat.size(), dist); + } else { + OP_REQUIRES(ctx, false, + errors::InvalidArgument("Unsupported algorithm id: ", alg)); + } +} +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STATELESS_RANDOM_OPS_V2_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/stochastic_cast_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/stochastic_cast_op.h new file mode 100644 index 00000000..a1039b7f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/stochastic_cast_op.h @@ -0,0 +1,140 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
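`GetKeyCounterAlgFromInputs` and `FillRandomTensor` above resolve `RNG_ALG_AUTO_SELECT` to Philox and reject any other algorithm id. A compact sketch of that dispatch with a stand-in enum (the real `Algorithm` values come from `rng_alg.h`):

```cpp
// Sketch of the algorithm-dispatch pattern above. The enum values are
// stand-ins; they do not mirror the numeric RNG_ALG_* constants.
#include <cstdio>
#include <optional>
#include <string>

enum class Alg { kAutoSelect, kPhilox, kThreefry };

// Resolve auto-select to a concrete algorithm, reject anything unsupported.
std::optional<std::string> FillWith(Alg alg) {
  if (alg == Alg::kAutoSelect) alg = Alg::kPhilox;
  switch (alg) {
    case Alg::kPhilox:
      // ... the Philox fill kernel would run here ...
      return std::nullopt;  // success
    default:
      return "Unsupported algorithm id";
  }
}

int main() {
  std::printf("%s\n", FillWith(Alg::kAutoSelect) ? "error" : "filled");
  std::printf("%s\n", FillWith(Alg::kThreefry) ? "error" : "filled");
  return 0;
}
```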
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STOCHASTIC_CAST_OP_H_ +#define TENSORFLOW_CORE_KERNELS_STOCHASTIC_CAST_OP_H_ + +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/rng_alg.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/random/random_distributions.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace internal { + +// Base class that dispatches random algorithm, key and counter for +// StochasticCast ops. +class StochasticCastOpBase : public OpKernel { + public: + explicit StochasticCastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + protected: + // Subclasses can implement this rounding kernel with assumption that random + // algorithm, key, counter have been given. + virtual void RoundOff(OpKernelContext* ctx, Algorithm alg, const Tensor& key, + const Tensor& counter, Tensor* output) = 0; +}; + +} // namespace internal +} // namespace tensorflow + +namespace Eigen { +namespace internal { + +template +struct StochasticRoundToIntOp { + static_assert(std::is_integral::value, + "Integer type expected"); + typedef tensorflow::random::UniformDistribution + Distribution; + const Scalar max = + static_cast(std::numeric_limits::max()); + const Scalar min = + static_cast(std::numeric_limits::min()); + + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC explicit StochasticRoundToIntOp( + Generator* g) + : gen(g) {} + + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Scalar + operator()(const Scalar& s) const { + if (TF_PREDICT_FALSE(Eigen::numext::isnan(s))) { + return Scalar{0}; + } + if (s >= max) { + return max; + } + if (s <= min) { + return min; + } + // Already integer, doesn't need to be rounded. + if (Eigen::numext::floor(s) == s) { + return s; + } + // In order to match comparison-based algorithm on some hardware + // implementations which rounds abs(operand) up when random < + // abs(fractional), we deal with positive and negative operands differently. + // TODO(b/232442915): Revisit RNG multi-threading issue when needed. 
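The scalar path (continued just below) rounds a value up with probability equal to its fractional part, so the expected result equals the input. A self-contained sketch of that rule on doubles, with the NaN case kept and the integer-range clamping omitted:

```cpp
// Standalone sketch of the scalar stochastic-rounding rule above. Not the
// Eigen functor itself; just the same arithmetic on doubles.
#include <cmath>
#include <cstdio>
#include <random>

long long StochasticRound(double s, double u /* uniform in [0,1) */) {
  if (std::isnan(s)) return 0;
  if (std::floor(s) == s) return static_cast<long long>(s);
  // Negative and positive operands use mirrored thresholds, as in the header.
  return static_cast<long long>(
      s < 0 ? std::floor(s + u) : std::floor(s + 1.0 - u));
}

int main() {
  std::mt19937 gen(42);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  const double x = 2.25;  // should round to 3 about 25% of the time
  double sum = 0.0;
  const int kTrials = 100000;
  for (int i = 0; i < kTrials; ++i) sum += StochasticRound(x, uniform(gen));
  std::printf("input %.2f, mean rounded value %.4f\n", x, sum / kTrials);
  return 0;
}
```

Running this prints a mean close to 2.25: every individual result is an integer, but the rounding is unbiased in expectation.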
+ Distribution dist; + Scalar random = dist(gen)[0]; + if (s < 0) { + return Eigen::numext::floor(s + random); + } else { + return Eigen::numext::floor(s + Scalar{1} - random); + } + } + + template + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet packetOp(const Packet& p) const { + constexpr size_t kPacketSize = + Eigen::internal::unpacket_traits::size; + Scalar unpacked_random[kPacketSize]; + Distribution dist; + auto const sample = dist(gen); + for (int i = 0; i < kPacketSize; i += Distribution::kResultElementCount) { + int granularity = std::min(Distribution::kResultElementCount, + static_cast(kPacketSize - i)); + std::copy(&sample[0], &sample[0] + granularity, &unpacked_random[i]); + } + Packet random = pload(unpacked_random); + Packet rounded = + pselect(pcmp_eq(pfloor(p), p), p, + pselect(pcmp_lt(p, pzero(p)), pfloor(padd(p, random)), + pfloor(padd(p, psub(pset1(1), random))))); + // Handles out of range inputs. + Packet result = + pselect(pcmp_le(pset1(max), p), pset1(max), rounded); + result = + pselect(pcmp_le(p, pset1(min)), pset1(min), result); + // Handles NaN input. + return pselect(pcmp_eq(p, p), result, pset1(0)); + } + Generator* gen; +}; + +template +struct functor_traits< + StochasticRoundToIntOp> { + enum { + Cost = 3 * NumTraits::AddCost, + PacketAccess = + packet_traits::HasCmp && packet_traits::HasRound, + }; +}; + +// TODO(b/232442915): Add support for rounding floats to lower precision floats. + +} // namespace internal +} // namespace Eigen + +#endif // TENSORFLOW_CORE_KERNELS_STOCHASTIC_CAST_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op.h new file mode 100644 index 00000000..439f22e7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op.h @@ -0,0 +1,123 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_ + +// Functor definition for StridedSliceOp, must be compilable by nvcc. 
+ +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/strided_slice_op.h" + +namespace tensorflow { +namespace functor { + +template +struct StridedSlice { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& start_indices, + const Eigen::DSizes& stop_indices, + const Eigen::DSizes& strides) { + MaybeWith32BitIndexing( + [&](auto output32, auto input32, const auto& start_indices32, + const auto& stop_indices32, const auto& strides32) { + output32.device(d) = + input32.stridedSlice(start_indices32, stop_indices32, strides32); + }, + output, input, start_indices, stop_indices, strides); + } +}; + +template +struct InitOutput { + static void run(const Device& d, typename TTypes::Tensor output) { + output.device(d) = output.constant(T(0)); + } +}; + +template +struct InitOutput { + static void run(const Device& d, + typename TTypes::Tensor output) { + output.device(d) = output.constant(ResourceHandle()); + } +}; + +template +struct InitOutput { + static void run(const Device& d, + typename TTypes::Tensor output) { + output.device(d) = output.constant(tstring()); + } +}; + +template +struct StridedSliceGrad { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& start_indices, + const Eigen::DSizes& stop_indices, + const Eigen::DSizes& strides) { + InitOutput::run(d, output); + MaybeWith32BitIndexing( + [&](auto output32, const auto& start_indices32, + const auto& stop_indices32, const auto& strides32) { + output32.stridedSlice(start_indices32, stop_indices32, strides32) + .device(d) = input; + }, + output, start_indices, stop_indices, strides); + } +}; + +template +struct StridedSliceAssign { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, + const Eigen::DSizes& start_indices, + const Eigen::DSizes& stop_indices, + const Eigen::DSizes& strides, + const StridedSliceAssignBCast& bcast) { + MaybeWith32BitIndexing( + [&](auto output32, auto input32, const auto& start_indices32, + const auto& stop_indices32, const auto& strides32) { + if (bcast.IsBroadcastingRequired()) { + output32.stridedSlice(start_indices32, stop_indices32, strides32) + .device(d) = input32.broadcast(bcast.bcast()); + } else { + output32.stridedSlice(start_indices32, stop_indices32, strides32) + .device(d) = input32; + } + }, + output, input, start_indices, stop_indices, strides); + } +}; + +template +struct StridedSliceAssignScalar { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input) { + output.device(d) = input; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op_gpu_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op_gpu_impl.h new file mode 100644 index 00000000..23a3ff86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op_gpu_impl.h @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
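Each functor above ultimately evaluates Eigen's `stridedSlice`, which gathers `input[begin + k * stride]` along every dimension. A one-dimensional, Eigen-free sketch of that indexing (positive strides only):

```cpp
// 1-D sketch of the begin/end/stride indexing that StridedSlice performs.
// Positive strides only; the kernels above also handle the general case.
#include <cstdio>
#include <vector>

std::vector<int> StridedSlice1D(const std::vector<int>& input, int begin,
                                int end, int stride) {
  std::vector<int> out;
  for (int i = begin; i < end; i += stride) out.push_back(input[i]);
  return out;
}

int main() {
  std::vector<int> v = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  // begin=1, end=8, stride=3 selects indices 1, 4, 7.
  for (int x : StridedSlice1D(v, 1, 8, 3)) std::printf("%d ", x);
  std::printf("\n");  // prints: 1 4 7
  return 0;
}
```

`StridedSliceGrad` runs the mapping in reverse: `InitOutput` first fills the larger output with zeros (or empty strings/handles), and the incoming gradient is then scattered back into the strided positions.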
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_GPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_GPU_IMPL_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/strided_slice_op.h" + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define DEFINE_GPU_KERNELS(T) \ + template struct functor::StridedSlice; \ + template struct functor::StridedSlice; \ + template struct functor::StridedSlice; \ + template struct functor::StridedSlice; \ + template struct functor::StridedSlice; \ + template struct functor::StridedSlice; \ + template struct functor::StridedSlice; \ + template struct functor::StridedSlice; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceGrad; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssign; \ + template struct functor::StridedSliceAssignScalar; + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_GPU_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op_impl.h new file mode 100644 index 00000000..01e58c9b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/strided_slice_op_impl.h @@ -0,0 +1,304 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_ + +// Functor definition for StridedSliceOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/register_types_traits.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/dense_update_functor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/slice_op.h" +#include "tensorflow/core/kernels/strided_slice_op.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/mem.h" + +namespace tensorflow { + +template +void HandleStridedSliceCase(OpKernelContext* context, + const absl::Span& begin, + const absl::Span& end, + const absl::Span& strides, + const TensorShape& processing_shape, + bool is_simple_slice, Tensor* result); + +template +void HandleStridedSliceGradCase(OpKernelContext* context, + const absl::Span& begin, + const absl::Span& end, + const absl::Span& strides, + const TensorShape& processing_shape, + bool is_simple_slice, Tensor* result); + +template +class HandleStridedSliceAssignCase { + public: + void operator()(OpKernelContext* context, + const absl::Span& begin, + const absl::Span& end, + const absl::Span& strides, + const StridedSliceAssignBCast& bcast, Tensor* result); +}; +} // namespace tensorflow + +// The actual implementation. 
This is designed so multiple +// translation units can include this file in the form +// +// #define STRIDED_SLICE_INSTANTIATE_DIM 1 +// #include +// #undef STRIDED_SLICE_INSTANTIATE_DIM +// +#ifdef STRIDED_SLICE_INSTANTIATE_DIM + +namespace tensorflow { + +template +void HandleStridedSliceCase(OpKernelContext* context, + const absl::Span& begin, + const absl::Span& end, + const absl::Span& strides, + const TensorShape& processing_shape, + bool is_simple_slice, Tensor* result) { + typedef typename proxy_type::type Proxy; + + absl::InlinedVector processing_dims = + processing_shape.dim_sizes(); + if (is_simple_slice) { + Eigen::DSizes begin_di; + Eigen::DSizes sizes_di; + for (int i = 0; i < NDIM; ++i) { + begin_di[i] = begin[i]; + sizes_di[i] = end[i] - begin[i]; + } + functor::Slice()( + context->eigen_device(), + result->bit_casted_shaped(processing_dims), + context->input(0).bit_casted_tensor(), begin_di, sizes_di); + } else { + Eigen::DSizes begin_di; + Eigen::DSizes end_di; + Eigen::DSizes strides_di; + for (int i = 0; i < NDIM; ++i) { + begin_di[i] = begin[i]; + end_di[i] = end[i]; + strides_di[i] = strides[i]; + } + functor::StridedSlice()( + context->eigen_device(), + result->bit_casted_shaped(processing_dims), + context->input(0).bit_casted_tensor(), begin_di, end_di, + strides_di); + } +} + +template +void HandleStridedSliceGradCase(OpKernelContext* context, + const absl::Span& begin, + const absl::Span& end, + const absl::Span& strides, + const TensorShape& processing_shape, + bool is_simple_slice, Tensor* result) { + absl::InlinedVector processing_dims = + processing_shape.dim_sizes(); + + Eigen::DSizes begin_di; + Eigen::DSizes end_di; + Eigen::DSizes strides_di; + for (int i = 0; i < NDIM; ++i) { + begin_di[i] = begin[i]; + end_di[i] = end[i]; + strides_di[i] = strides[i]; + } + + typedef typename proxy_type::type Proxy; + functor::StridedSliceGrad()( + context->eigen_device(), result->bit_casted_tensor(), + context->input(4).bit_casted_shaped(processing_dims), + begin_di, end_di, strides_di); +} + +template +void HandleStridedSliceAssignCase::operator()( + OpKernelContext* context, const absl::Span& begin, + const absl::Span& end, + const absl::Span& strides, + const StridedSliceAssignBCast& bcast, Tensor* result) { + typedef typename proxy_type::type Proxy; + Eigen::DSizes begin_di; + Eigen::DSizes end_di; + Eigen::DSizes strides_di; + for (int i = 0; i < NDIM; ++i) { + begin_di[i] = begin[i]; + end_di[i] = end[i]; + strides_di[i] = strides[i]; + } + + constexpr int kRhsInput = 4; + const Tensor& input = context->input(kRhsInput); + functor::StridedSliceAssign()( + context->eigen_device(), result->bit_casted_tensor(), + input.bit_casted_shaped(bcast.reshape()), begin_di, end_di, + strides_di, bcast); +} + +template +class HandleStridedSliceAssignCase { + public: + enum { NDIM_PROXY = 1 }; + void operator()(OpKernelContext* context, + const absl::Span& begin, + const absl::Span& end, + const absl::Span& strides, + const StridedSliceAssignBCast& bcast, Tensor* result) { + absl::InlinedVector processing_dims(1); + processing_dims[0] = 1; + + typedef typename proxy_type::type Proxy; + functor::StridedSliceAssignScalar()( + context->eigen_device(), + result->bit_casted_shaped(processing_dims), + context->input(4).bit_casted_shaped(processing_dims)); + } +}; + +// NOTE(aselle): according to bsteiner, we need this because otherwise +// nvcc instantiates templates that are invalid. strided_slice_op_gpu.cu +// handles instantiates externally. 
It is important that this is done +// before the HandleXXCase's are instantiated to avoid duplicate +// specialization errors. + +#define PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM) \ + namespace functor { \ + template <> \ + void StridedSlice::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& start, \ + const Eigen::DSizes& stop, \ + const Eigen::DSizes& strides); \ + extern template struct StridedSlice; \ + template <> \ + void Slice::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& indices, \ + const Eigen::DSizes& sizes); \ + extern template struct Slice; \ + template <> \ + void StridedSliceGrad::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& start, \ + const Eigen::DSizes& stop, \ + const Eigen::DSizes& strides); \ + extern template struct StridedSliceGrad; \ + template <> \ + void StridedSliceAssign::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input, \ + const Eigen::DSizes& start, \ + const Eigen::DSizes& stop, \ + const Eigen::DSizes& strides, \ + const StridedSliceAssignBCast& bcast); \ + extern template struct StridedSliceAssign; \ + } // namespace functor +#define PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM) \ + namespace functor { \ + template <> \ + void StridedSliceAssignScalar::operator()( \ + const GPUDevice& d, typename TTypes::Tensor output, \ + typename TTypes::ConstTensor input); \ + extern template struct StridedSliceAssignScalar; \ + } // namespace functor + +// Dimension 0 only instantiates some functors. So we only need +// to prevent ones defined by PREVENT_INSTANTIATE_DIM0_ONLY +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if STRIDED_SLICE_INSTANTIATE_DIM == 0 +#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM0_ONLY(T, NDIM) +#else +#define PREVENT_INSTANTIATE(T, NDIM) PREVENT_INSTANTIATE_DIM1_AND_UP(T, NDIM) +#endif +#else +#define PREVENT_INSTANTIATE(T, NDIM) +#endif + +#define INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM) \ + template void HandleStridedSliceCase( \ + OpKernelContext * context, const gtl::ArraySlice& begin, \ + const gtl::ArraySlice& end, \ + const gtl::ArraySlice& strides, \ + const TensorShape& processing_shape, bool is_simple_slice, \ + Tensor* result); \ + template void HandleStridedSliceGradCase( \ + OpKernelContext * context, const gtl::ArraySlice& begin, \ + const gtl::ArraySlice& end, \ + const gtl::ArraySlice& strides, \ + const TensorShape& processing_shape, bool is_simple_slice, \ + Tensor* result); + +#define INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \ + template class HandleStridedSliceAssignCase; + +// Only some kernels need to be instantiated on dim 0. 
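The `extern template struct ...` lines above are explicit instantiation declarations: they stop the including translation unit from instantiating the functors itself, so only the definitions emitted by `strided_slice_op_gpu.cu` exist. A generic, single-file sketch of that mechanism with a hypothetical `SliceLike` template (in the real code the instantiation definition lives in a separate .cu file rather than the same file):

```cpp
#include <cstdio>

template <typename T, int NDIM>
struct SliceLike {
  T Pick(const T* data, int stride) const { return data[stride * NDIM]; }
};

// Explicit instantiation *declaration*: do not instantiate this
// specialization implicitly here; a definition is promised elsewhere.
extern template struct SliceLike<float, 2>;

int main() {
  const float data[] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  SliceLike<float, 2> s;
  std::printf("%g\n", s.Pick(data, 3));  // data[3 * 2] == 6
  return 0;
}

// Explicit instantiation *definition*: the one place the code is emitted.
// In the kernels above this lives in strided_slice_op_gpu.cu.
template struct SliceLike<float, 2>;
```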
+#if STRIDED_SLICE_INSTANTIATE_DIM == 0 +#define INSTANTIATE(DEVICE, T, DIM) \ + INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) +#else +#define INSTANTIATE(DEVICE, T, DIM) \ + INSTANTIATE_DIM0_AND_UP_HANDLERS(DEVICE, T, DIM) \ + INSTANTIATE_DIM1_AND_UP_HANDLERS(DEVICE, T, DIM) +#endif + +#define DECLARE_FOR_N_CPU(T) \ + INSTANTIATE(CPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM) + +#define PREVENT_FOR_N_GPU(T) \ + PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM) + +#define DECLARE_FOR_N_GPU(T) \ + INSTANTIATE(GPUDevice, T, STRIDED_SLICE_INSTANTIATE_DIM) + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TF_CALL_GPU_PROXY_TYPES(PREVENT_FOR_N_GPU); +TF_CALL_COMPLEX_TYPES(PREVENT_FOR_N_GPU); + +TF_CALL_INTEGRAL_TYPES(DECLARE_FOR_N_GPU); +TF_CALL_GPU_ALL_TYPES(DECLARE_FOR_N_GPU); +#endif // END GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU); +TF_CALL_QUANTIZED_TYPES(DECLARE_FOR_N_CPU); +TF_CALL_float8_e5m2(DECLARE_FOR_N_CPU); +TF_CALL_float8_e4m3fn(DECLARE_FOR_N_CPU); + +#undef INSTANTIATE +#undef DECLARE_FOR_N_CPU +#undef DECLARE_FOR_N_GPU + +} // end namespace tensorflow + +#endif // END STRIDED_SLICE_INSTANTIATE_DIM +#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/string_to_hash_bucket_fast_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/string_to_hash_bucket_fast_op.h new file mode 100644 index 00000000..f9119259 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/string_to_hash_bucket_fast_op.h @@ -0,0 +1,67 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_FAST_OP_H_ +#define TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_FAST_OP_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +template +class StringToHashBucketOp : public OpKernel { + public: + explicit StringToHashBucketOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat(); + + typedef decltype(input_flat.size()) Index; + for (Index i = 0; i < input_flat.size(); ++i) { + const uint64 input_hash = hash(input_flat(i)); + const uint64 bucket_id = input_hash % num_buckets_; + // The number of buckets is always in the positive range of int64 so is + // the resulting bucket_id. 
Casting the bucket_id from uint64 to int64 is + // safe. + output_flat(i) = static_cast(bucket_id); + } + } + + private: + int64_t num_buckets_; + + StringToHashBucketOp(const StringToHashBucketOp&) = delete; + void operator=(const StringToHashBucketOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_FAST_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/string_to_hash_bucket_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/string_to_hash_bucket_op.h new file mode 100644 index 00000000..71fba9b6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/string_to_hash_bucket_op.h @@ -0,0 +1,75 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { + +template +class StringToKeyedHashBucketOp : public OpKernel { + public: + explicit StringToKeyedHashBucketOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_buckets", &num_buckets_)); + + std::vector key; + OP_REQUIRES_OK(ctx, ctx->GetAttr("key", &key)); + OP_REQUIRES(ctx, key.size() == 2, + errors::InvalidArgument("Key must have 2 elements")); + std::memcpy(key_, key.data(), sizeof(key_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat(); + + typedef decltype(input_flat.size()) Index; + for (Index i = 0; i < input_flat.size(); ++i) { + const uint64 input_hash = hash(key_, input_flat(i)); + const uint64 bucket_id = input_hash % num_buckets_; + // The number of buckets is always in the positive range of int64 so is + // the resulting bucket_id. Casting the bucket_id from uint64 to int64 is + // safe. 
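Both hash-bucket kernels reduce to `bucket = hash(string) % num_buckets`, with the keyed variant mixing the caller-supplied `key_` into the hash. A sketch of the unkeyed mapping, using `std::hash` purely as a stand-in (it carries no cross-build stability guarantee, unlike the fingerprint these ops rely on):

```cpp
// Sketch of the hash-bucket mapping above. std::hash is only a placeholder
// for the fingerprint/keyed hash the real kernels use.
#include <cstdint>
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

int64_t ToBucket(const std::string& s, int64_t num_buckets) {
  const uint64_t h = std::hash<std::string>{}(s);
  // num_buckets is positive, so the modulo fits in int64 (see comment above).
  return static_cast<int64_t>(h % static_cast<uint64_t>(num_buckets));
}

int main() {
  const std::vector<std::string> words = {"cat", "dog", "cat"};
  for (const auto& w : words) {
    std::printf("%s -> bucket %lld\n", w.c_str(),
                static_cast<long long>(ToBucket(w, /*num_buckets=*/16)));
  }
  // Equal strings always land in the same bucket; distinct strings may collide.
  return 0;
}
```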
+ output_flat(i) = static_cast(bucket_id); + } + } + + private: + int64_t num_buckets_; + uint64 key_[2]; + + StringToKeyedHashBucketOp(const StringToKeyedHashBucketOp&) = delete; + void operator=(const StringToKeyedHashBucketOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRING_TO_HASH_BUCKET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/string_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/string_util.h new file mode 100644 index 00000000..58230d3d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/string_util.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { + +// Enumeration for unicode encodings. Used by ops such as +// tf.strings.unicode_encode and tf.strings.unicode_decode. +enum class UnicodeEncoding { UTF8, UTF16BE, UTF32BE }; + +// Enumeration for character units. Used by string such as +// tf.strings.length and tf.substr. +// TODO(edloper): Add support for: UTF32_CHAR, etc. +enum class CharUnit { BYTE, UTF8_CHAR }; + +// Whether or not the given byte is the trailing byte of a UTF-8/16/32 char. +inline bool IsTrailByte(char x) { return static_cast(x) < -0x40; } + +// Sets `encoding` based on `str`. +absl::Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); + +// Sets `unit` value based on `str`. +absl::Status ParseCharUnit(const string& str, CharUnit* unit); + +// Returns the number of Unicode characters in a UTF-8 string. +// Result may be incorrect if the input string is not valid UTF-8. +int32 UTF8StrLen(const string& str); + +// Get the next UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset, and +// should never be `null`. The function return true if successful. However, if +// the end of the string is reached before the requested characters, then the +// position will point to the end of string and this function will return false. +template +bool ForwardNUTF8CharPositions(const absl::string_view in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t size = in.size(); + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) { + // move forward one utf-8 character + do { + ++*pos; + } while (*pos < size && IsTrailByte(in[*pos])); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + +// Get the previous UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset with a +// positive value, relative to the beginning of the string, and should never be +// `null`. 
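`IsTrailByte` flags UTF-8 continuation bytes (bit pattern `10xxxxxx`), and the position helpers step over runs of them. A standalone sketch using the equivalent mask test:

```cpp
// Sketch of the trail-byte logic above: a UTF-8 continuation byte has the
// form 10xxxxxx, so counting characters means counting non-trail bytes.
#include <cstdio>
#include <string>

bool IsTrail(char c) { return (static_cast<unsigned char>(c) & 0xC0) == 0x80; }

int Utf8Len(const std::string& s) {
  int n = 0;
  for (char c : s) {
    if (!IsTrail(c)) ++n;
  }
  return n;
}

// Advance *pos by n characters; returns false if the string ends first.
bool ForwardNChars(const std::string& s, int n, size_t* pos) {
  int moved = 0;
  while (moved < n && *pos < s.size()) {
    do {
      ++*pos;  // step one byte, then skip any continuation bytes
    } while (*pos < s.size() && IsTrail(s[*pos]));
    ++moved;
  }
  return moved == n;
}

int main() {
  const std::string s = "a\xC3\xA9z";  // 'a', U+00E9 (two bytes), 'z'
  std::printf("chars: %d\n", Utf8Len(s));  // 3
  size_t pos = 0;
  ForwardNChars(s, 2, &pos);
  std::printf("byte offset after 2 chars: %zu\n", pos);  // 3
  return 0;
}
```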
The function return true if successful. However, if the beginning of +// the string is reached before the requested character, then the position will +// point to the beginning of the string and this function will return false. +template +bool BackNUTF8CharPositions(const absl::string_view in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t start = 0; + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) { + // move back one utf-8 character + do { + --*pos; + } while (IsTrailByte(in[*pos]) && *pos > start); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/summary_interface.h b/third_party/tflite-hdrs/tensorflow/core/kernels/summary_interface.h new file mode 100644 index 00000000..f423d4ab --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/summary_interface.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ +#define TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ + +#include + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class Event; +class GraphDef; + +// Main interface for the summary writer resource. +class SummaryWriterInterface : public ResourceBase { + public: + virtual ~SummaryWriterInterface() override {} + + // Flushes all unwritten messages in the queue. + virtual absl::Status Flush() = 0; + + // These are called in the OpKernel::Compute methods for the summary ops. + virtual absl::Status WriteTensor(int64_t global_step, Tensor t, + const string& tag, + const string& serialized_metadata) = 0; + + virtual absl::Status WriteScalar(int64_t global_step, Tensor t, + const string& tag) = 0; + + virtual absl::Status WriteHistogram(int64_t global_step, Tensor t, + const string& tag) = 0; + + virtual absl::Status WriteImage(int64_t global_step, Tensor t, + const string& tag, int max_images, + Tensor bad_color) = 0; + + virtual absl::Status WriteAudio(int64_t global_step, Tensor t, + const string& tag, int max_outputs_, + float sample_rate) = 0; + + virtual absl::Status WriteGraph(int64_t global_step, + std::unique_ptr graph) = 0; + + virtual absl::Status WriteEvent(std::unique_ptr e) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_array.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_array.h new file mode 100644 index 00000000..aef4a97b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_array.h @@ -0,0 +1,629 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_ +#define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_ + +#include + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/aggregate_ops.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace tensor_array { + +// Full implementations are in tensor_array.cc +template +absl::Status AddToTensor(OpKernelContext* ctx, Tensor* sum, + const Tensor* current, const Tensor* add) { + return errors::InvalidArgument( + "tensor_array::AddToTensor type not supported: ", + DataTypeString(DataTypeToEnum::value)); +} + +#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T) \ + template <> \ + Status AddToTensor(OpKernelContext * ctx, Tensor * sum, \ + const Tensor* current, const Tensor* add); + +#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T) +TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU) +#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T) +TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU); +TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU); +#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#undef TENSOR_ARRAY_WRITE_OR_ADD + +template +absl::Status TensorSetZero(OpKernelContext* ctx, Tensor* value) { + return errors::InvalidArgument( + "tensor_array::TensorSetZero type not supported: ", + DataTypeString(DataTypeToEnum::value)); +} + +#define TENSOR_ARRAY_SET_ZERO(Device, T) \ + template <> \ + Status TensorSetZero(OpKernelContext * ctx, Tensor * value); + +#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T) +TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU); +TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU); +#undef TENSOR_ARRAY_SET_ZERO_CPU + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T) +TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU); +TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU); +#undef TENSOR_ARRAY_SET_ZERO_GPU + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#undef TENSOR_ARRAY_SET_ZERO + +} // namespace tensor_array + +// The TensorArray object keeps an array of Tensors. It allows reading from the +// array and writing to the array. 
+// +// Important properties: +// * Usually, writing to a particular index in the TensorArray is allowed at +// most once per index. In a special case, writes with the flag +// multiple_writes_aggregate allow multiple writes to the same +// index. In this case, the writes are summed. +// * Multiple reads are supported. +// * Deep copies of Tensors are rarely made. The only time they are made is +// when WriteOrAggregate is called at least twice on the same index with the +// flag multiple_writes_aggregate = True. +// * Reading and Writing to the array is protected by a mutex. +// All operations on a TensorArray are thread-safe. +// * A TensorArray may be preemptively closed, which releases all +// memory associated with it. +// +// These properties together allow the TensorArray to work as a +// functional object and makes gradient computation easy. For +// example: +// * Write-Once semantics mean the gradient of a TensorArray Read never has to +// worry which of multiple writes to that index the gradient value +// is meant for. +// * Read-Many semantics (when using clear_after_read=false) allow the +// TensorArray to be read, packed, or concatenated multiple times; +// and the gradient operations use the multiple_writes_aggregate +// flag to aggregate the backprop writes. Multiple backprop writes to +// the same index are partial gradients corresponding to the +// multiple reads of that index in the forward phase. +// +class TensorArray : public ResourceBase { + public: + static std::atomic tensor_array_counter; + + // Construct a TensorArray for holding Tensors of type 'dtype' with + // 'N' elements. While the underlying storage is a std::vector and + // can hold more than MAX_INT entries, in practice we do not expect + // users to construct this many Tensors for storage in a TensorArray. + TensorArray(const string& key, const DataType& dtype, const Tensor& handle, + int32_t N, const PartialTensorShape& element_shape, + bool identical_element_shapes, bool dynamic_size, + bool multiple_writes_aggregate, bool is_grad, int32_t marked_size, + bool clear_after_read) + : key_(key), + dtype_(dtype), + handle_(handle), + closed_(false), + dynamic_size_(dynamic_size), + multiple_writes_aggregate_(multiple_writes_aggregate), + gradients_disallowed_(false), + clear_after_read_(clear_after_read), + is_grad_(is_grad), + marked_size_(marked_size), + element_shape_(element_shape), + identical_element_shapes_(identical_element_shapes), + tensors_(N) {} + + // Write Tensor 'value' to index 'index'. + // + // Preconditions: + // * The TensorArray is not closed + // * If the array has dynamic size: + // The index is >= 0 + // Otherwise: + // The index is in [0, N) where N == Size() + // * The dtype of the Tensor in 'value' matches the TensorArray's dtype. + // * If multiple_writes_aggregate is false: + // The Tensor at 'index' has not yet been written to. + // * If multiple_writes_aggregate is true: + // The Tensor at 'index' has the same shape as value. + // + // Side effects: + // * On the first write to 'index': + // - The underlying Tensor in 'value' has a new reference to it. + // - The index 'index' is marked as written. + // * If multiple_writes_aggregate is false, subsequent writes to 'index' + // raise an InvalidArgument error. + // * If multiple_writes_aggregate is true, subsequent writes to 'index': + // - The underlying Tensors in 'value' and from the first write + // are released and a local Tensor is created. + // - Index 'index' is also marked as local_copy. 
+ // - The gradients_disallowed flag is set true (GradientsAllowed() + // will now return false). + // + // Note, value is passed as a pointer because we its underlying + // Tensor's shape is accessed. Otherwise it is not modified. + template + absl::Status WriteOrAggregate(OpKernelContext* ctx, const int32_t index, + const Tensor* value) { + mutex_lock l(mu_); + return LockedWriteOrAggregate(ctx, index, value); + } + + template + absl::Status WriteOrAggregateMany(OpKernelContext* ctx, + const std::vector& indices, + std::vector* values) { + mutex_lock l(mu_); + int32_t i = 0; + for (const int32_t ix : indices) { + absl::Status s = + LockedWriteOrAggregate(ctx, ix, &(*values)[i]); + ++i; + TF_RETURN_IF_ERROR(s); + } + return absl::OkStatus(); + } + + // Read from index 'index' into Tensor 'value'. + // + // Preconditions: + // * The TensorArray is not closed + // * The index is in [0, N) + // * The Tensor at 'index' has been written to. + // * The Tensor at 'index' has not been read from with flag + // clear_after_read = true. + // + // Side effects: + // * If clear_after_read is true, the reference to the underlying + // Tensor is deleted. + // * The reference to the underlying Tensor at 'index' is copied to + // the returned '*value'. + // * The index is marked as read (it cannot be rewritten to). + template + absl::Status Read(OpKernelContext* ctx, const int32_t index, Tensor* value) { + mutex_lock l(mu_); + return LockedRead(ctx, index, value); + } + + template + absl::Status ReadMany(OpKernelContext* ctx, const std::vector& indices, + std::vector* values) { + mutex_lock l(mu_); + values->clear(); + values->resize(indices.size()); + int32_t i = 0; + for (const int32_t ix : indices) { + absl::Status s = LockedRead(ctx, ix, &(*values)[i]); + ++i; + if (!s.ok()) return s; + } + return absl::OkStatus(); + } + + DataType ElemType() const { return dtype_; } + + PartialTensorShape ElemShape() { + mutex_lock l(mu_); + return element_shape_; + } + + absl::Status SetElemShape(const PartialTensorShape& candidate) { + mutex_lock l(mu_); + PartialTensorShape new_element_shape_; + absl::Status s = element_shape_.MergeWith(candidate, &new_element_shape_); + if (!s.ok()) { + return s; + } + element_shape_ = new_element_shape_; + return absl::OkStatus(); + } + + string DebugString() const override { + mutex_lock l(mu_); + CHECK(!closed_); + return strings::StrCat("TensorArray[", tensors_.size(), "]"); + } + + bool IsClosed() { + mutex_lock l(mu_); + return closed_; + } + + // Return the size of the TensorArray. + absl::Status Size(int32* size) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(LockedReturnIfClosed()); + *size = tensors_.size(); + return absl::OkStatus(); + } + + // Record the size of the TensorArray after an unpack or split. + absl::Status SetMarkedSize(int32_t size) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(LockedReturnIfClosed()); + if (!is_grad_) { + marked_size_ = size; + } + return absl::OkStatus(); + } + + // Return the marked size of the TensorArray. + absl::Status MarkedSize(int32* size) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(LockedReturnIfClosed()); + *size = marked_size_; + return absl::OkStatus(); + } + + // Return the size that should be used by pack or concat op. + absl::Status PackOrConcatSize(int32* size) { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(LockedReturnIfClosed()); + *size = is_grad_ ? marked_size_ : tensors_.size(); + return absl::OkStatus(); + } + + // Once a TensorArray is being used for gradient calculations, it + // should be marked as no longer resizeable. 
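The write rules described above reduce to: each index is written at most once, unless `multiple_writes_aggregate` is set, in which case later writes are summed into the stored value and gradients are disallowed from then on. A toy, TensorFlow-free model of just that rule, with doubles standing in for Tensors:

```cpp
// Toy model of the write-once / aggregate-on-rewrite rule described above.
// Elements are doubles instead of Tensors; errors are reported as a bool.
#include <cstdio>
#include <optional>
#include <vector>

class ToyArray {
 public:
  ToyArray(int n, bool aggregate) : slots_(n), aggregate_(aggregate) {}

  // Returns false on an illegal second write.
  bool Write(int index, double value) {
    std::optional<double>& slot = slots_[index];
    if (!slot.has_value()) {
      slot = value;                 // first write: just store
      return true;
    }
    if (!aggregate_) return false;  // write-once violated
    *slot += value;                 // aggregate (e.g. summed backprop writes)
    gradients_disallowed_ = true;
    return true;
  }

  std::optional<double> Read(int index) const { return slots_[index]; }
  bool GradientsAllowed() const { return !gradients_disallowed_; }

 private:
  std::vector<std::optional<double>> slots_;
  bool aggregate_;
  bool gradients_disallowed_ = false;
};

int main() {
  ToyArray ta(/*n=*/2, /*aggregate=*/true);
  ta.Write(0, 1.5);
  ta.Write(0, 2.5);  // second write is summed rather than rejected
  std::printf("slot 0 = %.1f, gradients allowed: %s\n", *ta.Read(0),
              ta.GradientsAllowed() ? "yes" : "no");  // 4.0, no
  return 0;
}
```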
+ void DisableDynamicSize() { + mutex_lock l(mu_); + dynamic_size_ = false; + } + + bool HasDynamicSize() { + mutex_lock l(mu_); + return dynamic_size_; + } + + bool GradientsAllowed() { + mutex_lock l(mu_); + return !gradients_disallowed_; + } + + bool HasIdenticalElementShapes() const { return identical_element_shapes_; } + + // Copy the TensorShapes from another TensorArray into this one. + // If `shapes_to_prepend` is set, expands the rank of the copied shape by + // prepending the passed in shape prefix to the shape values in `rhs`. + // The sizes of the two TensorArrays must match and this one + // may not have any entries filled in. This performs a "soft copy", + // essentially filling the current TensorArray with virtual + // zero-tensors, which will be replaced by future aggregate writes, + // or instantiated by future reads. Requires a non-const pointer + // to the rhs to access its mutex. + absl::Status CopyShapesFrom(TensorArray* rhs, + const TensorShape* shape_to_prepend); + + // Clear the TensorArray, including any Tensor references, and mark as closed. + void ClearAndMarkClosed() { + mutex_lock l(mu_); + tensors_.clear(); + closed_ = true; + } + + mutex* mu() { return &mu_; } + Tensor* handle() { return &handle_; } + + ResourceHandle resource_handle(OpKernelContext* ctx) { + return ctx->step_container()->MakeResourceHandle( + key_, *ctx->device()); + } + + private: + absl::Status LockedWrite(OpKernelContext* ctx, const int32_t index, + Tensor* value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + template + absl::Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32_t index, + const Tensor* value) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + template + absl::Status LockedRead(OpKernelContext* ctx, const int32_t index, + Tensor* value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + absl::Status LockedReturnIfClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + " has already been closed."); + } + return absl::OkStatus(); + } + + const string key_; + + const DataType dtype_; + Tensor handle_; + + mutable mutex mu_; + + // Marks that the tensor_array_ has been cleared. + bool closed_ TF_GUARDED_BY(mu_); + + // Writes are allowed to grow the array. + bool dynamic_size_; + + // Multiple writes to the same index will result in summation of the + // values (used by backprop) + const bool multiple_writes_aggregate_; + + // If multiple Writes were attempted (e.g. via attribute + // multiple_writes_aggregate), then gradients are disallowed. + bool gradients_disallowed_ TF_GUARDED_BY(mu_); + + // After a read at an index, clear away its Tensor to release memory. + const bool clear_after_read_; + + // True iff this is a gradient tensor array. + const bool is_grad_; + + // The size of the TensorArray after a (legacy) unpack or split is performed. + // -1 if there has been no unpack or split performed on the TensorArray. + int32 marked_size_; + + // The shape of each element in the TensorArray, may be partially known or not + // known at all. + PartialTensorShape element_shape_ TF_GUARDED_BY(mu_); + + // Whether all elements in the TensorArray have identical shapes. + // This allows certain behaviors, like dynamically checking for + // consistent shapes on write, and being able to fill in properly + // shaped zero tensors on stack -- even if the initial element_shape + // was not fully defined. 
+ const bool identical_element_shapes_; + + // TensorAndState is used to keep track of the Tensors stored in the + // TensorArray, along with their shapes, and a boolean that determines whether + // they have already been read or not. + struct TensorAndState { + TensorAndState() + : written(false), read(false), cleared(false), local_copy(false) {} + Tensor tensor; + TensorShape shape; + bool written; // True if a Tensor has been written to the index. + bool read; // True if a Tensor has been written to and read from the index. + bool cleared; // True if a tensor has been read with + // clear_after_read = true; + + // Used by writes when multiple_writes_aggregate is true. In this + // case, the first time a value is written, it is a shallow copy. + // The second time a value is written, it is aggregated. However, + // in this case a new Tensor must be constructed to hold the + // aggregated value. This flag marks that such a Tensor is being + // used. All future writes will aggregate to the existing local Tensor. + bool local_copy; + }; + // The list of underlying Tensors and states. + std::vector tensors_ TF_GUARDED_BY(mu_); +}; + +template +absl::Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, + const int32_t index, + const Tensor* value) { + TF_RETURN_IF_ERROR(LockedReturnIfClosed()); + size_t index_size = static_cast(index); + if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) { + return errors::InvalidArgument( + "TensorArray ", handle_.vec()(1), ": Tried to write to index ", + index, " but array is not resizeable and size is: ", tensors_.size()); + } + if (dynamic_size_) { + // We must grow the internal TensorArray + if (index_size >= tensors_.capacity()) { + tensors_.reserve(2 * (index_size + 1)); + } + if (index_size >= tensors_.size()) { + tensors_.resize(index_size + 1); + } + } + TensorAndState& t = tensors_[index]; + + if (value->dtype() != dtype_) { + return errors::InvalidArgument( + "TensorArray ", handle_.vec()(1), + ": Could not write to TensorArray index ", index, + " because the value dtype is ", DataTypeString(value->dtype()), + " but TensorArray dtype is ", DataTypeString(dtype_), "."); + } + if (!element_shape_.IsCompatibleWith(value->shape())) { + return errors::InvalidArgument( + "TensorArray ", handle_.vec()(1), + ": Could not write to TensorArray index ", index, + " because the value shape is ", value->shape().DebugString(), + " which is incompatible with the TensorArray's inferred element " + "shape: ", + element_shape_.DebugString(), " (consider setting infer_shape=False)."); + } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) { + element_shape_ = PartialTensorShape(value->shape().dim_sizes()); + } + + if (t.read) { + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + ": Could not write to TensorArray index ", + index, " because it has already been read."); + } + + if (!multiple_writes_aggregate_ && t.written) { + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + ": Could not write to TensorArray index ", + index, + " because it has already been written to."); + } + + if (t.written) { + DCHECK(multiple_writes_aggregate_); + + // Check that value shape matches t.shape + if (value->shape() != t.shape) { + return errors::InvalidArgument( + "TensorArray ", handle_.vec()(1), + ": Could not aggregate to TensorArray index ", index, + " because the existing shape is ", t.shape.DebugString(), + " but the new input shape is ", value->shape().DebugString(), "."); + } + + if 
(!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) { + // If existing_t == nullptr but written == true, then what was stored + // was just a shape, which just means zeros. So all we must do in this + // case is copy the reference over and return early. + t.tensor = *value; + return absl::OkStatus(); + } + + Tensor* existing_t = &t.tensor; + + if (t.local_copy) { + absl::Status s = tensor_array::AddToTensor(ctx, existing_t, + existing_t, value); + TF_RETURN_IF_ERROR(s); + } else { + Tensor local_tensor; + TF_RETURN_IF_ERROR( + ctx->allocate_temp(dtype_, existing_t->shape(), &local_tensor)); + absl::Status s = tensor_array::AddToTensor(ctx, &local_tensor, + existing_t, value); + TF_RETURN_IF_ERROR(s); + t.tensor = local_tensor; + t.local_copy = true; + } + + // We've aggregated the values, so disallow backprop on this + // TensorArray. + gradients_disallowed_ = true; + } else { + t.tensor = *value; + t.shape = value->shape(); + t.written = true; + } + return absl::OkStatus(); +} + +template +absl::Status TensorArray::LockedRead(OpKernelContext* ctx, const int32_t index, + Tensor* value) { + TF_RETURN_IF_ERROR(LockedReturnIfClosed()); + if ((index < 0) || + (!is_grad_ && (static_cast(index) >= tensors_.size()))) { + return errors::InvalidArgument("Tried to read from index ", index, + " but array size is: ", tensors_.size()); + } + size_t index_t = static_cast(index); + if ((is_grad_ && (index_t >= tensors_.size() || !tensors_[index].written)) || + (!is_grad_ && (index_t < tensors_.size() && !tensors_[index].written))) { + // Special case returning zeros if this is a gradient read that happens + // after a stop_gradients call with dynamic forward TensorArrays. + // There is sometimes a race condition where the gradient is not + // written due to stop_gradients, but is later read. + TensorShape element_shape; + if (is_grad_ && index_t < tensors_.size() && + tensors_[index].shape.dims() > 0) { + // A gradient TensorArray has more specific gradient information + // available for each entry. A forward TensorArray must rely on + // the global element_shape_ to fill in zeros on read. + element_shape = tensors_[index].shape; + } else if (!element_shape_.IsFullyDefined()) { + return errors::InvalidArgument( + "TensorArray ", handle_.vec()(1), + ": Could not read from TensorArray index ", index, + ". Furthermore, the element shape is not fully defined: ", + element_shape_.DebugString(), + ". It is possible you are working with a resizeable TensorArray and " + "stop_gradients is not allowing the gradients to be written. If you " + "set the full " + "element_shape property on the forward TensorArray, the proper " + "all-zeros tensor " + "will be returned instead of incurring this error."); + } else { + element_shape_.AsTensorShape(&element_shape); // Always succeeds. + } + if (index_t >= tensors_.size()) { + // Fill in tensors_ up to index to have known shape. 
+ size_t old_tensors_size = tensors_.size(); + tensors_.resize(index + 1); + for (size_t i = old_tensors_size; i < index + 1; ++i) { + tensors_[i].shape = element_shape; + tensors_[i].written = true; + } + } else { + tensors_[index].shape = element_shape; + tensors_[index].written = true; + } + } + + TensorAndState& t = tensors_[index]; + + if (t.cleared) { + return errors::InvalidArgument("TensorArray ", handle_.vec()(1), + ": Could not read index ", index, + " twice because it was cleared after a " + "previous read (perhaps try setting " + "clear_after_read = false?)."); + } + + if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) { + // We stored just a shape, but no value. This means create and + // return zeros of the appropriate shape. + TF_RETURN_IF_ERROR(ctx->allocate_temp(dtype_, t.shape, &t.tensor)); + if (t.shape.num_elements() > 0) { + absl::Status s = tensor_array::TensorSetZero(ctx, &t.tensor); + if (!s.ok()) return s; + } + } + + // Data is available inside the tensor, copy the reference over. + *value = t.tensor; + + if (clear_after_read_) { + t.tensor = Tensor(); + t.cleared = true; + } + t.read = true; + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_cord.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_cord.h new file mode 100644 index 00000000..2d3d4e3f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_cord.h @@ -0,0 +1,363 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_CORD_H_ +#define TENSORFLOW_CORE_KERNELS_TENSOR_CORD_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/variant_tensor_data.h" + +namespace tensorflow { + +typedef void (*CordRepReleaser)(void*); + +class TensorCord { + // A TensorCord keeps a view into some data, and a cleanup method to clean up + // that data when the TensorCord destructor is called. Copying a TensorCord + // increments a reference count to the cleanup method, and so the cleanup + // method is only called when all copies of the original TensorCord are + // cleared. + // + // Example: + // + // const string& s = t.scalar()(); + // TensorCord tc(s, &t); + // ASSERT_EQ(s, tc.view()); + // TensorCord copy(tc); + // tc = TensorCord(); // cleanup not called; the reference is held by `copy`. + // copy = TensorCord(); // cleanup happens now, the reference is destroyed. 
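The example above is essentially shared ownership of a view plus a cleanup callback that runs when the last copy disappears. A minimal stand-in for that behaviour, with std::shared_ptr supplying the reference count (MiniCord and StringReleaser are illustrative, not the real TensorCord implementation):

    // Sketch only: a view plus a releaser that fires when the last copy of
    // the cord is destroyed.
    #include <cstdio>
    #include <memory>
    #include <string>
    #include <string_view>

    class MiniCord {
     public:
      using Releaser = void (*)(void*);

      MiniCord() = default;
      MiniCord(std::string_view view, Releaser releaser, void* memory)
          : view_(view), rep_(memory, releaser) {}  // deleter == releaser

      std::string_view view() const { return view_; }

     private:
      std::string_view view_;
      std::shared_ptr<void> rep_;  // refcount; runs the releaser at zero
    };

    static void StringReleaser(void* p) { delete static_cast<std::string*>(p); }

    int main() {
      auto* s = new std::string("some backing storage");
      MiniCord tc(*s, &StringReleaser, s);
      MiniCord copy = tc;   // two references to the backing string
      tc = MiniCord();      // releaser not called; `copy` still holds it
      std::printf("%.*s\n", static_cast<int>(copy.view().size()),
                  copy.view().data());
      copy = MiniCord();    // last reference gone: StringReleaser runs here
    }

The second example in the header comment, continued below, applies the same idea with an explicit TensorProto deleter instead of a string.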
+ // + // Another example: + // + // void TensorProtoDeleter(void* ptr) { + // delete static_cast(ptr); + // } + // + // auto p = std::make_unique(...); + // absl::string_view content(p->tensor_content()); + // TensorCord tc(content, TensorProtoDeleter, p.release()); + // + + public: + static constexpr const char kTypeName[] = "tensorflow::TensorCord"; + + TensorCord() : chunks_() {} + + ~TensorCord(); + + // Args: + // `view`: should point to a location in memory that is guaranteed to remain + // valid until `releaser` is called. + // `releaser`: A callback that will be executed when there are no references + // left on `view`. It will be called via `releaser(memory)`. + // `memory`: The argument passed to `releaser` when it is called. + // + // You are STRONGLY advised to provide a non-null `releaser`, and a pointer + // to the underlying data (while ensuring that the data will not be deleted + // until `releaser(memory)` is called). Otherwise the TensorCord may + // outlive the data backing `view`. + TensorCord(absl::string_view view, CordRepReleaser releaser, + void* memory = nullptr) + : chunks_({new CordRep(view, releaser, memory)}) {} + + // Args: + // `view`: should point to a location in memory backed by `tensor`, + // e.g., `view` is a string_view on a tstring which is an element + // of `tensor`. Furthermore, the associated tstring is not expected + // to be modified in such a way that the underlying memory will + // be changed after this TensorCord is created. + TensorCord(absl::string_view view, Tensor* tensor) + : chunks_({NewCordRepFromTensor(view, tensor)}) {} + + // Disallow construction with empty callback or empty tensor. + TensorCord(absl::string_view view, std::nullptr_t, void* memory) = delete; + TensorCord(absl::string_view view, std::nullptr_t) = delete; + + TensorCord(const TensorCord& other); + + TensorCord(TensorCord&& other) noexcept; + + TensorCord& operator=(const TensorCord& other); + + TensorCord& operator=(TensorCord&& other) noexcept; + + void Append(const TensorCord& other); + + void Append(absl::string_view view, CordRepReleaser releaser, + void* memory = nullptr); + + void Append(absl::string_view view, Tensor* tensor); + + // Disallow Appends with empty callbacks or empty tensors. + void Append(absl::string_view view, std::nullptr_t, void* memory) = delete; + void Append(absl::string_view view, std::nullptr_t) = delete; + + size_t size() const; + bool empty() const { return size() == 0; } + + // NOTE: This performs an expensive copy of the underlying data. + explicit operator string() const; + + class ChunkIterator { + public: + using iterator_category = std::input_iterator_tag; + using value_type = absl::string_view; + using difference_type = ptrdiff_t; + using pointer = const value_type*; + using reference = value_type; + + ChunkIterator& operator++(); + + ChunkIterator operator++(int) { + ChunkIterator tmp(*this); + operator++(); + return tmp; + } + + bool operator==(const ChunkIterator& other) const { + return (cord_ == other.cord_ && chunk_index_ == other.chunk_index_); + } + + bool operator!=(const ChunkIterator& other) const { + return !(*this == other); + } + reference operator*() const { + assert(cord_ != nullptr); + return view_; + } + pointer operator->() const { + assert(cord_ != nullptr); + return &view_; + } + + friend class TensorCord; + + private: + // Constructs a `begin()` iterator from `cord`. 
+ explicit ChunkIterator(const TensorCord* cord, int chunk_index); + + const TensorCord* const cord_; + int chunk_index_; + absl::string_view view_; + }; + + class ChunkRange { + public: + explicit ChunkRange(const TensorCord* cord) : cord_(cord) {} + + ChunkIterator begin() const { return ChunkIterator(cord_, 0); } + + ChunkIterator end() const { + return ChunkIterator(cord_, cord_->chunks_.size()); + } + + private: + const TensorCord* cord_; + }; + + // Note that the ordinary caveats of temporary lifetime extension apply: + // + // void Process() { + // for (absl::string_view chunk : CordFactory().Chunks()) { + // // The temporary Cord returned by CordFactory has been destroyed! + // } + // } + ChunkRange Chunks() const { return ChunkRange(this); } + + ChunkIterator chunk_begin() const { return ChunkIterator(this, 0); } + + ChunkIterator chunk_end() const { + return ChunkIterator(this, chunks_.size()); + } + + static string TypeName() { return kTypeName; } + + string DebugString() const { + return absl::StrCat(""); + } + + void Encode(VariantTensorData* data) const; + + bool Decode(VariantTensorData data); + + private: + void Cleanup(); + + class CordRep : public core::RefCounted { + public: + CordRep(absl::string_view view, CordRepReleaser releaser, + void* arg = nullptr) + : is_inline_(false), rep_(view, releaser, arg) {} + + // **WARNING** Only use this constructor if + // view.size() < CordRep::kMaxInlineSize. + explicit CordRep(absl::string_view view) : is_inline_(true), rep_(view) {} + + ~CordRep() override; + + absl::string_view view() const { + if (is_inline_) { + return absl::string_view( + rep_.internal.data() + 1, + *reinterpret_cast(rep_.internal.data())); + } else { + return rep_.external.view; + } + } + + private: + friend class TensorCord; + + struct ExternalRep { + absl::string_view view; + CordRepReleaser releaser; + void* arg; + + ExternalRep(absl::string_view view_, CordRepReleaser releaser_, + void* arg_) + : view(view_), releaser(releaser_), arg(arg_) {} + }; + + // We save the size in the first byte, so subtract 1. + static constexpr int kMaxInlineSize = sizeof(ExternalRep) - 1; + static_assert(kMaxInlineSize < 255, + "Cannot store size of InlineRep in a single byte."); + + // The first byte stores the size as a uint8. The rest of the bytes are the + // string itself. + using InlineRep = std::array; + + // Member variables. 
+ const bool is_inline_; + const union _rep_union { + InlineRep internal; + ExternalRep external; + + _rep_union(absl::string_view view, CordRepReleaser releaser, void* arg) + : external(view, releaser, arg) {} + + explicit _rep_union(absl::string_view view) { + DCHECK_LT(view.size(), kMaxInlineSize); + *reinterpret_cast(internal.data()) = view.size(); + std::memcpy(static_cast(internal.data() + 1), view.data(), + view.size()); + } + } rep_; + }; + + static TensorBuffer* TensorBufWithRef(Tensor* tensor); + static void TensorBufReleaser(void* tensor_buffer); + static void StringReleaser(void* str_ptr); + static CordRep* NewCordRepFromTensor(absl::string_view view, Tensor* tensor); + + absl::InlinedVector chunks_; +}; + +inline TensorCord::TensorCord(const TensorCord& other) + : chunks_(other.chunks_) { + for (auto* rep : chunks_) { + rep->Ref(); + } +} + +inline TensorCord::TensorCord(TensorCord&& other) noexcept + : chunks_(std::move(other.chunks_)) { + other.chunks_.clear(); +} + +inline TensorCord& TensorCord::operator=(const TensorCord& other) { + Cleanup(); + chunks_ = other.chunks_; + for (auto* rep : chunks_) { + rep->Ref(); + } + return *this; +} + +inline TensorCord& TensorCord::operator=(TensorCord&& other) noexcept { + Cleanup(); + std::swap(chunks_, other.chunks_); + return *this; +} + +inline void TensorCord::Append(const TensorCord& other) { + for (auto* rep : other.chunks_) { + chunks_.push_back(rep); + rep->Ref(); + } +} + +inline void TensorCord::Append(absl::string_view view, CordRepReleaser releaser, + void* memory) { + chunks_.push_back(new CordRep(view, releaser, memory)); +} + +inline void TensorCord::Append(absl::string_view view, Tensor* tensor) { + chunks_.push_back(NewCordRepFromTensor(view, tensor)); +} + +inline size_t TensorCord::size() const { + return (chunks_.empty()) + ? 0 + : std::accumulate(chunk_begin(), chunk_end(), 0, + [](size_t acc, absl::string_view b) { + return acc + b.size(); + }); +} + +inline TensorCord::ChunkIterator& TensorCord::ChunkIterator::operator++() { + assert(cord_ != nullptr); + assert(chunk_index_ < cord_->chunks_.size()); + chunk_index_ += 1; + if (chunk_index_ != cord_->chunks_.size()) { + view_ = cord_->chunks_[chunk_index_]->view(); + } + return *this; +} + +inline TensorCord::ChunkIterator::ChunkIterator(const TensorCord* cord, + int index) + : cord_(cord), chunk_index_(index) { + if (index < cord_->chunks_.size()) { + view_ = cord_->chunks_[index]->view(); + } +} + +inline TensorCord::CordRep* TensorCord::NewCordRepFromTensor( + absl::string_view view, Tensor* tensor) { + if (view.size() <= TensorCord::CordRep::kMaxInlineSize) { + return new CordRep(view); + } else { + return new CordRep(view, &TensorBufReleaser, TensorBufWithRef(tensor)); + } +} + +inline void TensorCord::Cleanup() { + if (chunks_.empty()) return; + for (auto* rep : chunks_) { + rep->Unref(); + } + chunks_.clear(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TENSOR_CORD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_flag_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_flag_utils.h new file mode 100644 index 00000000..f20ecad7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_flag_utils.h @@ -0,0 +1,78 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for parsing tensors as runtime flags. +#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_ + +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensor_flag_utils { + +// Converts tensor.vec to an std::vector object, appends +// the value num_nonzero_entries_in_sparse_mat, and returns the result. +template +std::vector ParseRowStartIndices( + const tensorflow::Tensor& tensor, + const Tindices num_nonzero_entries_in_sparse_mat); + +// Returns OkStatus() if and only if config is a float scalar or a matrix with +// dimensions M x 3. If config is a scalar then config must be in the range +// [0, 1.0). If config is a matrix then config must have shape M x 3, all of +// its entries must be positive, and entries in the last column may not +// exceed 1.0. If config is a matrix then it may not be empty. +absl::Status ValidateSparseMatrixShardingConfig(const Tensor& config); + +// Returns OkStatus() if and only if config is a float scalar or a non-empty +// matrix with dimensions M x 2. +absl::Status ValidateScalarQuantityShardingConfig(const Tensor& config); + +// Returns the last entry of the first row in config_mat for which the first +// two entries are no smaller than the respective entries in key. If no such +// row exists then returns the last entry in the last row in config_mat. +// config_mat may not be empty. +template +MatrixType FindConfigValueForKey( + const typename TTypes::ConstMatrix& config_mat, + const std::pair& key); + +// Returns the last entry of the first row in config_mat for which the first +// two entries are no smaller than the respective entries in key. If no such +// row exists then returns the last entry in the last row in config_mat. +// config_mat may not be empty. +template +MatrixType FindConfigValueForKey( + const typename TTypes::ConstMatrix& config_mat, const K key); + +// Returns largest multiple of bucket_size less than value. +// Expects 1 <= bucket_size <= value. +template +Tindices GetLinearBucket(const Tindices value, const Tindices bucket_size); + +// Returns the largest power of bucket_size less than value. +// Expects 1 <= bucket_size <= value. If bucket_size = 1, returns 1. +template +Tindices GetPowerBucket(const Tindices value, const Tindices bucket_size); + +} // namespace tensor_flag_utils +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TENSOR_FLAG_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_list.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_list.h new file mode 100644 index 00000000..5d3921cf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_list.h @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ +#define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/refcount.h" + +namespace tensorflow { + +// Variant compatible type for a list of tensors. This is mutable but instances +// should never be mutated after stored in a variant tensor. +// +// **NOTE**: TensorList stores a refcounted container of tf::Tensor objects, +// which are accessible via TensorList::tensors(). Because it is refcounted, +// straight copies of the form: +// +// TensorList b = a; +// b.tensors().push_back(t); // WARNING: This modifies a.tensors(). +// +// Do not create a true copy of the underlying container - but instead increment +// a reference count. Modifying b.tensors() modifies a.tensors(). In this way, +// TensorList should be considered similar to the tf::Tensor object. +// +// In order to get a copy of the underlying list, use the Copy method: +// +// TensorList b = a.Copy(); +// b.tensors().push_back(t); // This does not modify a.tensors(). +// +// Note that this is not a deep copy: the memory locations of the underlying +// tensors will still point to the same locations of the corresponding tensors +// in the original. To truly perform a deep copy, Device and Type-specific +// code needs to be applied to the underlying tensors as usual. +// +// The most important implication of RefCounted TLs is that OpKernels +// wishing to reuse TensorList inputs as outputs via context->forward_input() +// need to perform an additional check on the refcount of the TensorList, +// to ensure aliasing can be performed safely. 
For example: +// +// bool can_alias = false; +// auto fw = c->forward_input(..., DT_VARIANT, {}, ...); +// if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) { +// auto* tl = fw->scalar()().get(); +// if (tl && tl->RefCountIsOne()) { +// can_alias = true; +// } +// } +// +class TensorList { + public: + TensorList() : tensors_(new Tensors) {} + ~TensorList(); + + TensorList(const TensorList& other) + : element_shape(other.element_shape), + element_dtype(other.element_dtype), + max_num_elements(other.max_num_elements), + tensors_(other.tensors_) { + tensors_->Ref(); + } + + TensorList(TensorList&& rhs) + : element_shape(std::move(rhs.element_shape)), + element_dtype(rhs.element_dtype), + max_num_elements(rhs.max_num_elements), + tensors_(rhs.tensors_) { + rhs.tensors_ = nullptr; + } + + TensorList& operator=(const TensorList& rhs) { + if (this == &rhs) return *this; + element_shape = rhs.element_shape; + element_dtype = rhs.element_dtype; + max_num_elements = rhs.max_num_elements; + tensors_->Unref(); + tensors_ = rhs.tensors_; + tensors_->Ref(); + return *this; + } + + TensorList& operator=(TensorList&& rhs) { + if (this == &rhs) return *this; + element_shape = rhs.element_shape; + element_dtype = rhs.element_dtype; + max_num_elements = rhs.max_num_elements; + std::swap(tensors_, rhs.tensors_); + return *this; + } + + static const char kTypeName[]; + + string TypeName() const { return kTypeName; } + + void Encode(VariantTensorData* data) const; + + bool Decode(const VariantTensorData& data); + + // TODO(apassos) fill this out + string DebugString() const { return "TensorList"; } + + PartialTensorShape element_shape; + + DataType element_dtype; + + // The maximum allowed size of `tensors`. Defaults to -1 meaning that the size + // of `tensors` is unbounded. + int max_num_elements = -1; + + // Access to the underlying tensor container. + std::vector& tensors() { return tensors_->values_; } + const std::vector& tensors() const { return tensors_->values_; } + + // Get a new TensorList containing a copy of the underlying tensor container. + TensorList Copy() const { + TensorList out; + out.element_shape = element_shape; + out.element_dtype = element_dtype; + out.max_num_elements = max_num_elements; + // This performs a copy of the std::vector. + out.tensors_->values_ = tensors_->values_; + return out; + } + + // Is this TensorList the only one with a reference to the underlying + // container? + bool RefCountIsOne() const { return tensors_->RefCountIsOne(); } + + private: + class Tensors : public core::RefCounted { + public: + std::vector values_; + }; + Tensors* tensors_; +}; + +#if defined(PLATFORM_GOOGLE) +// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices. +// For 32-bit devices, it's acceptable not to inline. +static_assert(Variant::CanInlineType() || sizeof(void*) < 8, + "Must be able to inline TensorList into a Variant"); +#endif +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_list_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_list_util.h new file mode 100644 index 00000000..7ffabce8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_list_util.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_LIST_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_TENSOR_LIST_UTIL_H_ + +#include + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class OpKernelContext; +class TensorList; +class Tensor; + +absl::Status TensorListBinaryAdd( + OpKernelContext* c, const TensorList& a, const TensorList& b, + TensorList* out, + std::function + binary_add_func); + +absl::Status TensorListZerosLike( + OpKernelContext* c, const TensorList& x, TensorList* y, + std::function + zeros_like_func); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TENSOR_LIST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_map.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_map.h new file mode 100644 index 00000000..cb4c827c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_map.h @@ -0,0 +1,181 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ +#define TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_key.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/lib/core/refcount.h" + +namespace tensorflow { + +// Variant compatible type for a map of tensors. This is mutable but instances +// should never be mutated after stored in a variant tensor. +// +// **NOTE**: TensorMap stores a refcounted container of tf::Tensor objects, +// which are accessible via TensorMap::tensors(). Because it is refcounted, +// straight copies of the form: +// +// TensorMap b = a; +// b.tensors().insert(k,v); // WARNING: This modifies a.tensors(). +// +// Do not create a true copy of the underlying container - but instead increment +// a reference count. Modifying b.tensors() modifies a.tensors(). In this way, +// TensorMap should be considered similar to the tf::Tensor object. +// +// In order to get a copy of the underlying map, use the Copy method: +// +// TensorMap b = a.Copy(); +// b.tensors().insert(k, v); // This does not modify a.tensors(). +// +// Note that this is not a deep copy: the memory locations of the underlying +// tensors will still point to the same locations of the corresponding tensors +// in the original. 
To truly perform a deep copy, Device and Type-specific +// code needs to be applied to the underlying tensors as usual. +// +// The most important implication of RefCounted TensorMaps is that OpKernels +// wishing to reuse TensorMap inputs as outputs via context->forward_input() +// need to perform an additional check on the refcount of the TensorList, +// to ensure aliasing can be performed safely. For example: +// +// bool can_alias = false; +// auto fw = c->forward_input(..., DT_VARIANT, {}, ...); +// if (fw && fw->dtype() == DT_VARIANT && fw->NumElements() == 1) { +// auto* tl = fw->scalar()().get(); +// if (tl && tl->RefCountIsOne()) { +// can_alias = true; +// } +// } +// +class TensorMap { + public: + TensorMap() : tensors_(new Tensors) {} + ~TensorMap(); + + TensorMap(const TensorMap& other) : tensors_(other.tensors_) { + tensors_->Ref(); + } + + TensorMap(TensorMap&& rhs) : tensors_(rhs.tensors_) { + rhs.tensors_ = nullptr; + } + + TensorMap& operator=(const TensorMap& rhs) { + if (this == &rhs) return *this; + tensors_->Unref(); + tensors_ = rhs.tensors_; + tensors_->Ref(); + return *this; + } + + TensorMap& operator=(TensorMap&& rhs) { + if (this == &rhs) return *this; + std::swap(tensors_, rhs.tensors_); + return *this; + } + + static const char kTypeName[]; + + string TypeName() const { return kTypeName; } + + void Encode(VariantTensorData* data) const; + + bool Decode(const VariantTensorData& data); + + // TODO(apassos) fill this out + string DebugString() const { return "TensorMap"; } + + // Access to the underlying tensor container. + absl::flat_hash_map& tensors() { + return tensors_->values_; + } + + const absl::flat_hash_map& tensors() const { + return tensors_->values_; + } + + // Get a new TensorMap containing a copy of the underlying tensor container. + TensorMap Copy() const { + TensorMap out; + // This performs a copy of the absl::hashmap. + out.tensors_->values_ = tensors_->values_; + return out; + } + + // Insert key and value if the key does not already exist. + // Returns true if the insertion happens. + bool insert(const TensorKey& key, const Tensor& value) { + auto r = tensors_->values_.try_emplace(key, value); + return r.second; + } + + // Lookup given key. Returns iterator to found key or end. + absl::flat_hash_map::iterator find(TensorKey key) { + return tensors_->values_.find(key); + } + + Tensor& lookup(TensorKey key) { return tensors_->values_.find(key)->second; } + + Tensor& operator[](TensorKey& k) { return tensors_->values_[k]; } + + bool replace(const TensorKey& k, const Tensor& v) { + tensors_->values_[k] = v; + return true; + } + + // Removes element with given key. Return size of removed element. + size_t erase(TensorKey key) { return tensors_->values_.erase(key); } + + // Size returns the number of elements in the map + size_t size() const { return tensors_->values_.size(); } + + std::vector keys() const { + std::vector keys; + keys.reserve(tensors_->values_.size()); + absl::flat_hash_map::iterator it = + tensors_->values_.begin(); + while (it != tensors_->values_.end()) { + keys.push_back(it->first); + it++; + } + return keys; + } + + // Is this TensorMap the only one with a reference to the underlying + // container? + bool RefCountIsOne() const { return tensors_->RefCountIsOne(); } + + private: + class Tensors : public core::RefCounted { + public: + absl::flat_hash_map values_; + }; + Tensors* tensors_; +}; + +#if defined(PLATFORM_GOOGLE) +// TODO(ebrevdo): Identify why Variant inline size is smaller on mobile devices. 
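The copy-versus-alias rule spelled out in the class comment reduces to a refcounted container: plain assignment shares it, Copy() duplicates it, and RefCountIsOne() tells a kernel whether in-place reuse is safe. A standalone sketch with std::shared_ptr standing in for the core::RefCounted Tensors holder (MiniMap is illustrative, not the real TensorMap):

    // Sketch only: assignment aliases the container, Copy() does not.
    #include <cassert>
    #include <map>
    #include <memory>
    #include <string>

    class MiniMap {
     public:
      MiniMap() : values_(std::make_shared<std::map<std::string, int>>()) {}

      std::map<std::string, int>& values() { return *values_; }

      // Like TensorMap::Copy(): a new container holding the same entries.
      MiniMap Copy() const {
        MiniMap out;
        *out.values_ = *values_;
        return out;
      }

      bool RefCountIsOne() const { return values_.use_count() == 1; }

     private:
      std::shared_ptr<std::map<std::string, int>> values_;
    };

    int main() {
      MiniMap a;
      a.values()["k"] = 1;

      MiniMap aliased = a;           // shares the container
      aliased.values()["k"] = 2;
      assert(a.values()["k"] == 2);  // a sees the change
      assert(!a.RefCountIsOne());    // two owners: in-place reuse is unsafe

      MiniMap copied = a.Copy();     // independent container
      copied.values()["k"] = 3;
      assert(a.values()["k"] == 2);  // a is unaffected
    }
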
+// For 32-bit devices, it's acceptable not to inline. +static_assert(Variant::CanInlineType() || sizeof(void*) < 8, + "Must be able to inline TensorMap into a Variant"); +#endif +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TENSOR_MAP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_to_hash_bucket_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_to_hash_bucket_op.h new file mode 100644 index 00000000..cdf7dab2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tensor_to_hash_bucket_op.h @@ -0,0 +1,80 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_TO_HASH_BUCKET_OP_H_ +#define TENSORFLOW_CORE_KERNELS_TENSOR_TO_HASH_BUCKET_OP_H_ + +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace functor { + +template +struct LaunchTensorToHashBucket { + void operator()(OpKernelContext* c, const int64_t num_buckets, const T* input, + const int num_elems, int64_t* output) { + string format = "%"; + switch (DataTypeToEnum::value) { + case DT_INT8: + case DT_INT16: + case DT_INT32: + strings::Appendf(&format, "d"); + break; + case DT_INT64: + strings::Appendf(&format, "lld"); + break; + default: + bool type_not_supported = true; + OP_REQUIRES( + c, !type_not_supported, + errors::InvalidArgument("Type not supported: ", + DataTypeString(DataTypeToEnum::value))); + } + + for (int i = 0; i < num_elems; ++i) { + string input_str = strings::Printf(format.c_str(), input[i]); + const uint64 input_hash = Fingerprint64(input_str); + const uint64 bucket_id = input_hash % num_buckets; + // The number of buckets is always in the positive range of int64 so is + // the resulting bucket_id. Casting the bucket_id from uint64 to int64 is + // safe. 
+ output[i] = static_cast(bucket_id); + } + } +}; + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +struct LaunchTensorToHashBucket { + void operator()(OpKernelContext* c, const int64_t num_buckets, const T* input, + const int num_elems, int64_t* output); +}; +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TENSOR_TO_HASH_BUCKET_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor.h new file mode 100644 index 00000000..d5f27eca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor.h @@ -0,0 +1,110 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace internal { + +// Device-specific naive implementation for Tile. + +template +void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out, + const Tensor& in); + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +template +void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in); +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +template +void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in, + const gtl::ArraySlice broadcast_array) { + Eigen::array b; + for (int i = 0; i < NDIM; ++i) b[i] = broadcast_array[i]; + MaybeWith32BitIndexing( + [&](auto out32, auto in32) { out32.device(d) = in32.broadcast(b); }, + out->tensor(), in.tensor()); +} + +template +void TileUsingEigen(const Device& d, Tensor* out, const Tensor& in, + const gtl::ArraySlice) { + auto x = in.tensor(); + auto y = out->tensor(); + // In the scalar case we simply copy the input. 
+ y.device(d) = x; +} + +} // end namespace internal + +namespace functor { + +template +struct Tile { + void operator()(const Device& d, Tensor* out, const Tensor& in, + const gtl::ArraySlice broadcast_array) const { + switch (in.dims()) { + case 0: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + case 1: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + case 2: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + case 3: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + case 4: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + case 5: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + case 6: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + case 7: + internal::TileUsingEigen(d, out, in, + broadcast_array); + break; + default: + internal::TileSimple(d, out, in); + break; + } + } +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor_cpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor_cpu.h new file mode 100644 index 00000000..dee100e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor_cpu.h @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_CPU_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_CPU_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/tile_functor.h" + +namespace tensorflow { +namespace internal { + +template +void TileSimpleImpl(const Device& d, Tensor* out, const Tensor& in) { + const int ndims = in.dims(); + const int64_t nelem = out->NumElements(); + absl::InlinedVector in_strides = + ComputeStride(in.shape()); + absl::InlinedVector out_strides = + ComputeStride(out->shape()); + const T* p = in.flat().data(); + T* q = out->flat().data(); + + for (int64_t o_idx = 0; o_idx < nelem; ++o_idx) { + int64_t i_idx = 0; + int64_t t = o_idx; + for (int i = 0; i < ndims; ++i) { + i_idx += t / out_strides[i] % in.dim_size(i) * in_strides[i]; + t %= out_strides[i]; + } + q[o_idx] = p[i_idx]; + } +} + +template +void TileSimple(const Eigen::ThreadPoolDevice& d, Tensor* out, + const Tensor& in) { + return TileSimpleImpl(d, out, in); +} + +} // namespace internal +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_CPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor_gpu.h new file mode 100644 index 00000000..8d825a68 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_functor_gpu.h @@ -0,0 +1,91 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_GPU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/tile_functor.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { +namespace internal { + +template +__global__ void TileKernel(int nthreads, const T* __restrict__ src, + const int32* __restrict__ buf, const int32 ndims, + T* __restrict__ dst) { + const int32* in_strides = buf; + const int32* out_strides = buf + ndims; + const int32* in_dim_sizes = buf + ndims * 2; + GPU_1D_KERNEL_LOOP(o_idx, nthreads) { + int32 i_idx = 0; + int32 t = o_idx; + for (int i = 0; i < ndims; ++i) { + i_idx += t / out_strides[i] % in_dim_sizes[i] * in_strides[i]; + t %= out_strides[i]; + } + dst[o_idx] = ldg(src + i_idx); + } +} + +template +void TileSimple(const Eigen::GpuDevice& d, Tensor* out, const Tensor& in) { + // Ensures we can use 32-bit index. + const int64 in_nelem = in.NumElements(); + CHECK_LT(in_nelem, kint32max) << "Tensor too large to transpose on GPU"; + const int64 out_nelem = out->NumElements(); + CHECK_LT(out_nelem, kint32max) << "Tensor too large to transpose on GPU"; + // Pack strides and input dimension sizes into one buffer. + const int32 ndims = in.dims(); + gtl::InlinedVector host_buf(ndims * 3); + gtl::InlinedVector in_strides = ComputeStride(in.shape()); + gtl::InlinedVector out_strides = ComputeStride(out->shape()); + for (int i = 0; i < ndims; ++i) { + host_buf[i] = in_strides[i]; + host_buf[ndims + i] = out_strides[i]; + host_buf[ndims * 2 + i] = in.dim_size(i); + } + // Copies the input strides, output strides and input dimension sizes to the + // device. + auto num_bytes = sizeof(int32) * host_buf.size(); + auto dev_buf = d.allocate(num_bytes); + // NOTE: host_buf is not allocated by GpuHostAllocator, and + // therefore we are doing a sync copy effectively. + d.memcpyHostToDevice(dev_buf, host_buf.data(), num_bytes); + // Launch kernel to q[...] = p[...]. + const T* p = in.flat().data(); + T* q = out->flat().data(); + GpuLaunchConfig cfg = GetGpuLaunchConfig(out_nelem, d); + TF_CHECK_OK( + GpuLaunchKernel(TileKernel, cfg.block_count, cfg.thread_per_block, 0, + d.stream(), cfg.virtual_thread_count, p, + reinterpret_cast(dev_buf), ndims, q)); + // Safe to deallocate immediately after the kernel launch. 
+ d.deallocate(dev_buf); +} + +} // end namespace internal +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_cpu_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_cpu_impl.h new file mode 100644 index 00000000..066954a1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_cpu_impl.h @@ -0,0 +1,52 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/tile_ops_impl.h" + +namespace tensorflow { + +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// Register functors used for TileGradientOp. +#define DEFINE_DIM(T, NDIM) \ + template struct TileGrad; \ + template struct ReduceAndReshape; +#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM) + +TF_CALL_float(DEFINE_TYPE); +TF_CALL_bfloat16(DEFINE_TYPE); +TF_CALL_double(DEFINE_TYPE); +TF_CALL_int16(DEFINE_TYPE); +TF_CALL_int32(DEFINE_TYPE); +TF_CALL_int64(DEFINE_TYPE); +TF_CALL_half(DEFINE_TYPE); +TF_CALL_complex64(DEFINE_TYPE); +TF_CALL_complex128(DEFINE_TYPE); + +#undef DEFINE_DIM +#undef DEFINE_TYPE + + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_gpu_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_gpu_impl.h new file mode 100644 index 00000000..f1bbbf1e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_gpu_impl.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ + +// Header used to split up compilation of GPU tile ops. 
For each type you want +// to have tile ops, create a .cu.cc file containing +// +// #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// #include "tensorflow/core/kernels/tile_ops_gpu_impl.h" +// DEFINE_TILE_OPS(NDIM) +// #endif // GOOGLE_CUDA +// +// where NDIM is an integer. +// +// NOTE(keveman): Eigen's int8 and string versions don't compile yet with nvcc. + +#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM +#error "This header must be included inside with CUDA or ROCm defined" +#endif + +#define EIGEN_USE_GPU + +#include +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/kernels/tile_ops_impl.h" + +#define DEFINE_DIM(T, NDIM) \ + template struct TileGrad; \ + template struct ReduceAndReshape; + +#define DEFINE_TILE_OPS(NDIM) \ + namespace tensorflow { \ + namespace functor { \ + DEFINE_DIM(int16, NDIM) \ + DEFINE_DIM(int32, NDIM) \ + DEFINE_DIM(int64, NDIM) \ + DEFINE_DIM(Eigen::half, NDIM) \ + DEFINE_DIM(Eigen::bfloat16, NDIM) \ + DEFINE_DIM(float, NDIM) \ + DEFINE_DIM(double, NDIM) \ + DEFINE_DIM(complex64, NDIM) \ + DEFINE_DIM(complex128, NDIM) \ + } \ + } + +#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_GPU_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_impl.h b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_impl.h new file mode 100644 index 00000000..9f9a11b4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/tile_ops_impl.h @@ -0,0 +1,71 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_ +#define TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace functor { + +template +struct TileGrad { + void operator()(const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const Eigen::DSizes& indices, + const Eigen::DSizes& sizes, + bool first) const { + if (first) { + out.device(d) = in.slice(indices, sizes); + } else { + out.device(d) += in.slice(indices, sizes); + } + } +}; + +template +struct TileGrad { + void operator()(const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const Eigen::DSizes&, + const Eigen::DSizes&, + bool first) const { + if (first) { + out.device(d) = in; + } else { + out.device(d) += in; + } + } +}; + +template +struct ReduceAndReshape { + void operator()( + const Device& d, typename TTypes::Tensor out, + typename TTypes::ConstTensor in, + const Eigen::DSizes& reduce_dim, + const Eigen::DSizes& reshape_dim) const { + out.device(d) = in.sum(reduce_dim).reshape(reshape_dim); + } +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/topk_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/topk_op.h new file mode 100644 index 00000000..cdebb07f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/topk_op.h @@ -0,0 +1,42 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_H_ +#define TENSORFLOW_CORE_KERNELS_TOPK_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace functor { + +template +struct TopKFunctor { + static absl::Status Compute(OpKernelContext* context, bool sorted, int k, + const typename TTypes::ConstTensor& input, + const int64_t num_rows, const int64_t num_cols, + typename TTypes::Tensor values, + typename TTypes::Tensor indices); +}; + +} // end namespace functor + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/topk_op_gpu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/topk_op_gpu.h new file mode 100644 index 00000000..26162abc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/topk_op_gpu.h @@ -0,0 +1,597 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_ +#define TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/kernels/gpu_prim_helpers.h" +#include "tensorflow/core/kernels/topk_op.h" +#include "tensorflow/core/lib/gtl/top_n.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace impl { + +enum class HeapType { kMinHeap, kMaxHeap }; +enum class PreferIndices { kLower, kHigher }; + +template +struct Entry { + int index; + T value; + + // Test-only. + static bool greater(const Entry& a, const Entry& b) { + if (a.value == b.value) { + return a.index < b.index; + } + return a.value > b.value; + } +}; + +template +struct LinearData { + typedef impl::Entry Entry; + + __device__ Entry& operator[](std::size_t index) const { return data[index]; } + + __device__ int get_index(int i) const { return data[i].index; } + __device__ T get_value(int i) const { return data[i].value; } + + Entry* const data; +}; + +template +struct IndirectLinearData { + typedef impl::Entry Entry; + + __device__ Entry& operator[](std::size_t index) const { return data[index]; } + + __device__ int get_index(int i) const { + return backing_data[data[i].index].index; + } + __device__ T get_value(int i) const { return data[i].value; } + + Entry* const data; + Entry* const backing_data; +}; + +template +struct StridedData { + typedef impl::Entry Entry; + + __device__ Entry& operator[](std::size_t index) const { + return data[index * blockDim.x + threadIdx.x]; + } + + __device__ int get_index(int i) const { return (*this)[i].index; } + __device__ T get_value(int i) const { return (*this)[i].value; } + + Entry* const data; +}; + +// A heap of Entry that can either work as a min-heap or as a max-heap. 
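Before the templated IndexedHeap that follows, the ordering it encodes can be shown with an ordinary comparator: entries are compared by value, and ties are broken by preferring the lower index (the PreferIndices::kLower case). A small standalone illustration (SimpleEntry and MaxHeapAbove are illustrative names):

    // Sketch only: the value-then-lower-index ordering used for a max-heap.
    #include <algorithm>
    #include <cassert>
    #include <vector>

    struct SimpleEntry {
      int index;
      float value;
    };

    // True if `a` should sit above `b` in a max-heap that prefers lower
    // indices on ties.
    inline bool MaxHeapAbove(const SimpleEntry& a, const SimpleEntry& b) {
      if (a.value == b.value) return a.index < b.index;
      return a.value > b.value;
    }

    int main() {
      std::vector<SimpleEntry> entries = {
          {0, 1.f}, {1, 3.f}, {2, 3.f}, {3, 0.5f}};
      // std::make_heap wants a "less-than" comparison, i.e. cmp(a, b) is true
      // when a belongs *below* b, so the arguments are swapped.
      auto below = [](const SimpleEntry& a, const SimpleEntry& b) {
        return MaxHeapAbove(b, a);
      };
      std::make_heap(entries.begin(), entries.end(), below);
      // Indices 1 and 2 both hold the maximum 3.f; the lower index wins.
      assert(entries.front().index == 1 && entries.front().value == 3.f);
    }

heapTopK further down uses the same Entry type with the min-heap variant of this comparison to keep the current top-k candidates while scanning the input.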
+template class Data, typename T> +struct IndexedHeap { + typedef typename Data::Entry Entry; + const Data data; + __device__ IndexedHeap(const Data& d) : data(d) {} + + __device__ bool is_above(int left, int right) { + T left_value = data.get_value(left); + T right_value = data.get_value(right); + if (left_value == right_value) { + if (preferIndices == PreferIndices::kLower) { + return data.get_index(left) < data.get_index(right); + } else { + return data.get_index(left) > data.get_index(right); + } + } + if (heapType == HeapType::kMinHeap) { + return left_value < right_value; + } else { + return left_value > right_value; + } + } + + __device__ void assign(int i, const Entry& entry) { data[i] = entry; } + + __device__ void push_up(int i) { + int child = i; + int parent; + for (; child > 0; child = parent) { + parent = (child - 1) / 2; + if (!is_above(child, parent)) { + // Heap property satisfied. + break; + } + swap(child, parent); + } + } + + __device__ void swap(int a, int b) { + auto tmp = data[b]; + data[b] = data[a]; + data[a] = tmp; + } + + __device__ void push_root_down(int k) { push_down(0, k); } + + // MAX-HEAPIFY in Cormen + __device__ void push_down(int node, int k) { + while (true) { + const int left = 2 * node + 1; + const int right = left + 1; + int smallest = node; + if (left < k && is_above(left, smallest)) { + smallest = left; + } + if (right < k && is_above(right, smallest)) { + smallest = right; + } + if (smallest == node) { + break; + } + swap(smallest, node); + node = smallest; + } + } + + // BUILD-MAX-HEAPIFY in Cormen + __device__ void build(int k) { + for (int node = (k - 1) / 2; node >= 0; node--) { + push_down(node, k); + } + } + + // HEAP-EXTRACT-MAX in Cormen + __device__ void remove_root(int k) { + data[0] = data[k - 1]; + push_root_down(k - 1); + } + + // in-place HEAPSORT in Cormen + // This method destroys the heap property. + __device__ void sort(int k) { + for (int slot = k - 1; slot > 0; slot--) { + // This is like remove_root but we insert the element at the end. + swap(slot, 0); + // Heap is now an element smaller. + push_root_down(/*k=*/slot); + } + } + + __device__ void replace_root(const Entry& entry, int k) { + data[0] = entry; + push_root_down(k); + } + + __device__ const Entry& root() { return data[0]; } +}; + +template class Data, typename T> +__device__ IndexedHeap make_indexed_heap( + typename Data::Entry* data) { + return IndexedHeap{Data{data}}; +} + +// heapTopK walks over [input, input+length) with `step_size` stride starting at +// `start_index`. +// It builds a top-`k` heap that is stored in `heap_entries` using `Accessor` to +// access elements in `heap_entries`. If sorted=true, the elements will be +// sorted at the end. +template class Data = LinearData> +__device__ void heapTopK(const T* __restrict__ input, int length, int k, + Entry* __restrict__ heap_entries, + bool sorted = false, int start_index = 0, + int step_size = 1) { + assert(k <= length); + + auto heap = + make_indexed_heap( + heap_entries); + + int heap_end_index = start_index + k * step_size; + if (heap_end_index > length) { + heap_end_index = length; + } + // Initialize the min-heap. + for (int index = start_index, slot = 0; index < heap_end_index; + index += step_size, slot++) { + heap.assign(slot, {index, input[index]}); + } + + heap.build(k); + + // Now iterate over the remaining items. + // If an item is smaller than the min element, it is not amongst the top k. + // Otherwise, replace the min element with it and push upwards. 
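The loop that follows is the heart of the selection: any candidate smaller than the heap's root can never enter the top k. The same idea as a hedged CPU analogue of heapTopK (TopK below is an illustrative stand-in built on std::priority_queue, not the kernel's code path):

```cpp
#include <algorithm>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>

// Host-side analogue of heapTopK: keep the k largest values in a min-heap;
// anything smaller than the heap's root can never be in the top k.
std::vector<std::pair<float, int>> TopK(const std::vector<float>& input, int k) {
  using Entry = std::pair<float, int>;  // (value, index)
  auto cmp = [](const Entry& a, const Entry& b) { return a > b; };  // min-heap
  std::priority_queue<Entry, std::vector<Entry>, decltype(cmp)> heap(cmp);
  for (int i = 0; i < static_cast<int>(input.size()); ++i) {
    if (static_cast<int>(heap.size()) < k) {
      heap.push({input[i], i});
    } else if (input[i] > heap.top().first) {
      heap.pop();
      heap.push({input[i], i});
    }
  }
  std::vector<Entry> result;
  while (!heap.empty()) { result.push_back(heap.top()); heap.pop(); }
  std::reverse(result.begin(), result.end());  // descending by value
  return result;
}

int main() {
  for (const auto& [v, i] : TopK({3.f, 9.f, 1.f, 7.f, 5.f}, 3))
    std::printf("value=%g index=%d\n", v, i);
}
```

The GPU version distributes exactly this loop across threads, each owning one interleaved heap in shared memory, and then merges the per-thread results in mergeShards.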
+ for (int index = heap_end_index; index < length; index += step_size) { + // We prefer elements with lower indices. This is given here. + // Later elements automatically have higher indices, so can be discarded. + if (input[index] > heap.root().value) { + // This element should replace the min. + heap.replace_root({index, input[index]}, k); + } + } + + // Sort if wanted. + if (sorted) { + heap.sort(k); + } +} + +// mergeShards performs a top-k merge on `num_shards` many sorted streams that +// are sorted and stored in `entries` in a strided way: +// |s_1 1st|s_2 1st|...s_{num_shards} 1st|s_1 2nd|s_2 2nd|... +// The overall top k elements are written to `top_k_values` and their indices +// to top_k_indices. +// `top_k_heap` is used as temporary storage for the merge heap. +template +__device__ void mergeShards(int num_shards, int k, + Entry* __restrict__ entries, + Entry* __restrict__ top_k_heap, T* top_k_values, + int* top_k_indices) { + // If k < num_shards, we can use a min-heap with k elements to get the top k + // of the sorted blocks. + // If k > num_shards, we can initialize a min-heap with the top element from + // each sorted block. + const int heap_size = k < num_shards ? k : num_shards; + + // Min-heap part. + { + auto min_heap = IndexedHeap{ + IndirectLinearData{top_k_heap, entries}}; + // Initialize the heap as a min-heap. + for (int slot = 0; slot < heap_size; slot++) { + min_heap.assign(slot, {slot, entries[slot].value}); + } + min_heap.build(heap_size); + + // Now perform top k with the remaining shards (if num_shards > heap_size). + for (int shard = heap_size; shard < num_shards; shard++) { + const auto entry = entries[shard]; + const auto root = min_heap.root(); + if (entry.value < root.value) { + continue; + } + if (entry.value == root.value && + entry.index > entries[root.index].index) { + continue; + } + // This element should replace the min. + min_heap.replace_root({shard, entry.value}, heap_size); + } + } + + // Max-part. + { + // Turn the min-heap into a max-heap in-place. + auto max_heap = IndexedHeap{ + IndirectLinearData{top_k_heap, entries}}; + // Heapify into a max heap. + max_heap.build(heap_size); + + // Now extract the minimum k-1 times. + // k is treated specially. + const int last_k = k - 1; + for (int rank = 0; rank < last_k; rank++) { + const Entry& max_element = max_heap.root(); + top_k_values[rank] = max_element.value; + int shard_index = max_element.index; + top_k_indices[rank] = entries[shard_index].index; + int next_shard_index = shard_index + num_shards; + // For rank < k-1, each top k heap still contains at least 1 element, + // so we can draw a replacement. + max_heap.replace_root({next_shard_index, entries[next_shard_index].value}, + heap_size); + } + + // rank == last_k. 
+ const Entry& max_element = max_heap.root(); + top_k_values[last_k] = max_element.value; + int shard_index = max_element.index; + top_k_indices[last_k] = entries[shard_index].index; + } +} + +#if GOOGLE_CUDA +extern __shared__ char shared_memory[]; +#endif // GOOGLE_CUDA + +template +#if TENSORFLOW_USE_ROCM +__attribute__((amdgpu_flat_work_group_size(1, 256))) +#endif // TENSORFLOW_USE_ROCM +__global__ void +TopKKernel(const T* __restrict__ input, int length, int k, bool sorted, + T* __restrict__ output, int* __restrict__ indices) { +#if TENSORFLOW_USE_ROCM + HIP_DYNAMIC_SHARED(char, shared_memory); +#endif // TENSORFLOW_USE_ROCM + + const int batch_index = blockIdx.x; + const T* batch_input = input + batch_index * length; + + const int thread_index = threadIdx.x; + const int thread_count = blockDim.x; + + Entry* shared_entries = (Entry*)shared_memory; + + heapTopK(batch_input, length, k, shared_entries, true, + thread_index, thread_count); + + __syncthreads(); + if (thread_index == 0) { + const int offset = batch_index * k; + auto batch_output = output + offset; + auto batch_indices = indices + offset; + Entry* top_k_heap = shared_entries + thread_count * k; + + // TODO(blackhc): Erich says: Performance can likely be improved + // significantly by having the merge be done by multiple threads rather than + // just one. ModernGPU has some nice primitives that could help with this. + mergeShards(thread_count, k, shared_entries, top_k_heap, batch_output, + batch_indices); + } +} + +template +cudaError LaunchTopKKernel(const gpuStream_t& stream, int num_shards, + const T* input, int batch_size, int length, int k, + bool sorted, T* output, int* indices) { + // This code assumes that k is small enough that the computation + // fits inside shared memory (hard coded to 48KB). In practice this + // means k <= 3072 for T=float/int32 and k <= 2048 for T=double/int64. + // The calculation is: + // shared_memory_size / (2 * (sizeof(int) + sizeof(T))) < k. + + // Use as many shards as possible. + if (num_shards <= 0) { + constexpr auto shared_memory_size = 48 << 10; // 48 KB + const auto heap_size = k * sizeof(Entry); + // shared_memory_size = (num_shards + 1) * heap_size <=> + num_shards = shared_memory_size / heap_size - 1; + if (num_shards <= 0) { + num_shards = 1; + } + auto shard_size = length / num_shards; + auto min_shard_size = 2 * k; + if (shard_size < min_shard_size) { + num_shards = length / min_shard_size; + } + if (num_shards <= 0) { + num_shards = 1; +#if GOOGLE_CUDA + } else if (num_shards > 1024) { + num_shards = 1024; + } +#elif TENSORFLOW_USE_ROCM + // ROCm can't execute with 1024 and requires an explicit + // amdgpu_flat_work_group_size attribute with >256 + } else if (num_shards > 256) { + num_shards = 256; + } +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + } + // We are limited by the amount of shared memory we have per block. 
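The shard count chosen above is driven entirely by the 48 KB shared-memory budget mentioned in the comment: each shard needs a k-entry heap, plus one extra heap for the merge. A small, purely illustrative check of that bound (the 8-byte Entry assumes T = float):

```cpp
#include <cstdint>
#include <cstdio>

// Back-of-the-envelope check of the shared-memory budget described above,
// assuming a 48 KB limit and an 8-byte Entry (int index + float value).
int main() {
  const int64_t shared_memory_size = 48 << 10;
  const int64_t entry_size = sizeof(int) + sizeof(float);  // 8 bytes
  for (int k : {128, 1024, 3072}) {
    const int64_t heap_size = k * entry_size;
    const int64_t num_shards = shared_memory_size / heap_size - 1;
    std::printf("k=%d -> heap=%lld bytes, num_shards<=%lld\n", k,
                static_cast<long long>(heap_size),
                static_cast<long long>(num_shards));
  }
}
```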
+ auto shared_memory_size = (num_shards + 1) * k * sizeof(Entry); + + TF_CHECK_OK(GpuLaunchKernel(TopKKernel, batch_size, num_shards, + shared_memory_size, stream, input, length, k, + sorted, output, indices)); + return cudaGetLastError(); +} + +struct SegmentOffsetCreator { + EIGEN_DEVICE_FUNC + SegmentOffsetCreator(int num_cols) : num_cols_(num_cols) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const { + return idx * num_cols_; + } + + int num_cols_; +}; + +struct ColumnIndexCreator { + ColumnIndexCreator(int num_cols) : num_cols_(num_cols) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()( + const Eigen::array& ix) const { + return ix[0] % num_cols_; + } + + int num_cols_; +}; + +template +Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, + int num_cols, int k, + typename TTypes::Tensor values, + TTypes::Tensor indices) { + const GPUDevice& d = ctx->eigen_device(); + const auto& cu_stream = GetGpuStream(ctx); + size_t temp_storage_bytes = -1; + + // TODO(ebrevdo): Once gpuprim supports iterators for ValueT replace that + // tensor with an iterator that directly returns the correct value. + Tensor input_indices; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT32, TensorShape({num_rows, num_cols}), &input_indices)); + auto input_indices_t = To32Bit(input_indices.flat()); + input_indices_t.device(d) = + input_indices_t.generate(ColumnIndexCreator(num_cols)); + + gpuprim::CountingInputIterator counting_iter(0); + gpuprim::TransformInputIterator> + segment_offsets_t(counting_iter, SegmentOffsetCreator(num_cols)); + + Tensor temp_values; + Tensor temp_indices; + T* sorted_values_ptr; + int* sorted_indices_ptr; + if (k == num_cols) { + // Doing a full sort, no intermediate values needed. + sorted_values_ptr = values.data(); + sorted_indices_ptr = indices.data(); + } else { + // Need to create intermediate values for sorting. + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT32, TensorShape({num_rows, num_cols}), &temp_indices)); + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum::value, + TensorShape({num_rows, num_cols}), + &temp_values)); + sorted_indices_ptr = temp_indices.flat().data(); + sorted_values_ptr = temp_values.flat().data(); + } + + bool ran_nonsegmented_version = false; + if (num_rows == 1) { + // Note: DeviceSegmentedRadixSort is very slow when num_segments=1 because + // it only uses 1 SM per segment. Calling the un-segmented version is much + // faster in this case. 
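For reference, the begin/end offsets produced lazily by the counting and transform iterators above are simply multiples of num_cols, one per row of the [num_rows, num_cols] batch. A host-side picture of the same values (SegmentOffsets is an illustrative helper, not part of this header):

```cpp
#include <cstdio>
#include <vector>

// Segment i of a row-major [num_rows, num_cols] batch begins at element
// i * num_cols, and its end offset is the next segment's begin offset.
std::vector<int> SegmentOffsets(int num_rows, int num_cols) {
  std::vector<int> offsets(num_rows + 1);
  for (int i = 0; i <= num_rows; ++i) offsets[i] = i * num_cols;
  return offsets;
}

int main() {
  for (int off : SegmentOffsets(/*num_rows=*/3, /*num_cols=*/5))
    std::printf("%d ", off);  // 0 5 10 15
  std::printf("\n");
}
```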
+ TF_RETURN_IF_ERROR( + GpuRadixSortDescending(ctx, num_cols, /*keys_in=*/input, + /*keys_out=*/sorted_values_ptr, + /*indices_in=*/input_indices_t.data(), + /*indices_out=*/sorted_indices_ptr, + /*num_bits=*/sizeof(T) * 8)); + ran_nonsegmented_version = true; + } + if (!ran_nonsegmented_version) { + auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending( + /* d_temp_storage */ nullptr, + /* temp_storage_bytes */ temp_storage_bytes, + /* d_keys_in */ input, + /* d_keys_out */ sorted_values_ptr, + /* d_values_in */ input_indices_t.data(), + /* d_values_out */ sorted_indices_ptr, + /* num_items */ num_cols * num_rows, + /* num_segments */ num_rows, + /* d_begin_offsets */ segment_offsets_t, + /* d_end_offsets */ segment_offsets_t + 1, + /* begin_bit */ 0, + /* end_bit */ sizeof(T) * 8, + /* stream */ cu_stream); + if (err != cudaSuccess) { + return errors::Internal( + "TopKOp: Could not launch " + "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to calculate " + "temp_storage_bytes, status: ", + cudaGetErrorString(err)); + } + Tensor temp_storage; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending( + /* d_temp_storage */ temp_storage.flat().data(), + /* temp_storage_bytes */ temp_storage_bytes, + /* d_keys_in */ input, + /* d_keys_out */ sorted_values_ptr, + /* d_values_in */ input_indices_t.data(), + /* d_values_out */ sorted_indices_ptr, + /* num_items */ num_cols * num_rows, + /* num_segments */ num_rows, + /* d_begin_offsets */ segment_offsets_t, + /* d_end_offsets */ segment_offsets_t + 1, + /* begin_bit */ 0, + /* end_bit */ sizeof(T) * 8, + /* stream */ cu_stream); + if (err != cudaSuccess) { + return errors::Internal( + "TopKOp: Could not launch " + "gpuprim::DeviceSegmentedRadixSort::SortPairsDescending to sort " + "input, " + "temp_storage_bytes: ", + temp_storage_bytes, ", status: ", cudaGetErrorString(err)); + } + } + if (k < num_cols) { + // Need to copy subsets of sorted_indices and sorted_outputs to + // indices and outputs. + const Eigen::DSizes slice_indices{0, 0}; + const Eigen::DSizes slice_sizes{num_rows, k}; + To32Bit(indices).device(d) = + To32Bit(temp_indices.matrix()).slice(slice_indices, slice_sizes); + To32Bit(values).device(d) = + To32Bit(temp_values.matrix()).slice(slice_indices, slice_sizes); + } + return OkStatus(); +} + +} // namespace impl + +namespace functor { + +template +struct TopKFunctor { + static EIGEN_ALWAYS_INLINE Status + Compute(OpKernelContext* context, bool sorted, int k, + const typename TTypes::ConstTensor& input, const int64 num_rows, + const int64 num_cols, typename TTypes::Tensor values, + typename TTypes::Tensor indices) { + // For small k, use the heap implementation. For larger k, use + // the in-place gpuprim sort. For k == num_cols, always use the + // in-place gpuprim sort. The thresholds for n and k were determined + // empirically. 
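The condition in the dispatch below reads most easily as a small predicate; the thresholds are the empirically chosen ones from the code, reproduced here purely for illustration:

```cpp
#include <cstdio>

// Mirrors the dispatch condition used below: take the sort-based path for
// narrow rows, full sorts, or large k; otherwise use the heap kernel.
bool UseSortBasedTopK(int num_cols, int k) {
  return num_cols <= 1000 || k == num_cols || k >= 100;
}

int main() {
  std::printf("%d %d %d\n",
              UseSortBasedTopK(10000, 8),    // heap kernel (0)
              UseSortBasedTopK(512, 8),      // sort path (1)
              UseSortBasedTopK(10000, 256)); // sort path (1)
}
```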
+ if (num_cols <= 1000 || k == num_cols || k >= 100) { + return impl::LaunchSortKernel(context, input.data(), num_rows, num_cols, + k, values, indices); + } else { + const auto& cu_stream = GetGpuStream(context); + auto err = impl::LaunchTopKKernel(cu_stream, /* num_shards */ 0, + input.data(), num_rows, num_cols, k, + sorted, values.data(), indices.data()); + if (err != cudaSuccess) { + return errors::Internal( + "Could not launch TopKKernel: ", cudaGetErrorString(err), "."); + } else { + return OkStatus(); + } + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_GPU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/training_op_helpers.h b/third_party/tflite-hdrs/tensorflow/core/kernels/training_op_helpers.h new file mode 100644 index 00000000..83ee04fc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/training_op_helpers.h @@ -0,0 +1,301 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_ +#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/tsl/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_op_registry.h" +#include "tensorflow/core/kernels/dense_update_functor.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tsl/platform/mutex.h" + +namespace tensorflow { + +// Must be called before performing a sparse operation on a variable. Ensures +// that no concurrent dense operations can happen while holding the variable's +// lock. +// @param ctx OpKernelContext for variable tensor cloning +// @param var Variable to be shared +// @param lock_held Whether the variable mutex was already held or not +// NOTE: This function uses variable's `copy_on_read_mode` flag to decide if +// it should immediately return or continue to lock the variable mutex for more +// processing, and always sets the `copy_on_read_mode` flag to true when this +// function returns. However, there is no guarantee that another op won't set +// the `copy_on_read_mode` flag back to false after this function. +// Therefore, for the operation that requires `copy_on_read` to stay true during +// its execution, the caller needs to lock the variable mutex outside and call +// this function with `lock_held = true` to avoid double locking. 
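The definition that follows is essentially a double-checked test of the atomic copy_on_read_mode flag under the variable's mutex. A generic, self-contained sketch of that pattern (CopyOnReadState and EnsureCopyOnRead are illustrative names, not the TensorFlow types):

```cpp
#include <atomic>
#include <cstdio>
#include <mutex>

// Generic double-checked pattern analogous to EnsureSparseVariableAccess:
// check the atomic flag, take the lock, re-check, then do the one-time work.
struct CopyOnReadState {
  std::atomic<bool> copy_on_read_mode{false};
  std::mutex mu;
};

void EnsureCopyOnRead(CopyOnReadState* s) {
  if (s->copy_on_read_mode.load()) return;   // fast path, no lock
  std::lock_guard<std::mutex> lock(s->mu);
  if (s->copy_on_read_mode.load()) return;   // another thread won the race
  // ... one-time work (e.g. copying the backing tensor) would go here ...
  s->copy_on_read_mode.store(true);
}

int main() {
  CopyOnReadState state;
  EnsureCopyOnRead(&state);
  std::printf("copy_on_read_mode=%d\n", static_cast<int>(state.copy_on_read_mode.load()));
}
```

The only differences in the header above are that the one-time work is the tensor copy, and that callers who need the flag to stay set must hold the mutex themselves, as the comment explains.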
+template +absl::Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) { + if (var->copy_on_read_mode.load()) { + return absl::OkStatus(); + } + + tsl::mutex_lock ml(*var->mu()); + + // It may be possible that there are multiple threads that invoke + // `EnsureSparseVariableAccess` at the same time. If so, the first thread that + // enters this critical section will set the `copy_on_read_mode` flag to true. + // All other threads can then exit this critical section immediately. + if (var->copy_on_read_mode.load()) { + return absl::OkStatus(); + } + + // Once copy-on-read mode is True the refcount is guaranteed to be 1. This can + // also happen if there are no concurrent reads of the variable and + // copy-on-read mode is false. + if (var->tensor()->RefCountIsOne()) { + var->copy_on_read_mode.store(true); + return absl::OkStatus(); + } + Tensor tmp; + if (std::is_same::value) { + tsl::AllocatorAttributes attr; + attr.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(), + var->tensor()->shape(), &tmp, attr)); + + const auto elements_in = var->tensor()->flat(); + auto elements_out = tmp.flat(); + for (int64_t i = 0; i < elements_in.size(); ++i) { + elements_out(i) = elements_in(i); + } + } else { + tsl::AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + TF_RETURN_IF_ERROR(ctx->allocate_temp(var->tensor()->dtype(), + var->tensor()->shape(), &tmp, attr)); + functor::DenseUpdate copy_functor; + copy_functor(ctx->eigen_device(), tmp.flat(), + const_cast(var->tensor())->flat()); + } + *var->tensor() = tmp; + var->copy_on_read_mode.store(true); + return absl::OkStatus(); +} + +// Utility structure that releases a sequence of borrowed mutexes when it is +// deleted. +class VariableInputLockHolder { + public: + VariableInputLockHolder( + std::vector vars, + std::unique_ptr> locks, + std::unique_ptr> shared_locks) + : vars_(std::move(vars)), + locks_(std::move(locks)), + shared_locks_(std::move(shared_locks)) {} + + VariableInputLockHolder(VariableInputLockHolder&& other) + : vars_(std::move(other.vars_)), + locks_(std::move(other.locks_)), + shared_locks_(std::move(other.shared_locks_)) {} + + ~VariableInputLockHolder() { + // Release the locks before unrefing the Vars, because each lock + // is potentially borrowed from a Var in vars_. + locks_.reset(); + for (Var* var : vars_) { + var->Unref(); + } + } + + private: + std::vector vars_; + // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly, + // because a `std::vector` is not movable on all platforms. + std::unique_ptr> locks_; + std::unique_ptr> shared_locks_; +}; + +// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. +// +// If `input` corresponds to a `DT_RESOURCE`-type variable input, +// `*maybe_resource` will be updated to contain the underlying resource, and the +// caller will be responsible for calling `Unref()` on that resource. +template +tsl::mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, + Var** maybe_resource) { + *maybe_resource = nullptr; + if (ctx->input_dtype(input) == DT_RESOURCE) { + if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { + return (*maybe_resource)->mu(); + } else { + ctx->CtxFailureWithWarning( + absl::InternalError("Invalid variable reference.")); + return nullptr; + } + } + return ctx->input_ref_mutex(input); +} + +// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes +// in address order to mitigate deadlock. 
Returns a structure that, when +// deleted, will release the acquired mutexes. Safe to pass duplicates - will +// only lock each distinct mutex once. If sparse is true, will ensure the +// variable gets switched to copy-on-read mode before trying to acquire the +// locks. If do_lock is false, returns immediately for reference variables. For +// resource variables in copy-on-read-mode, it will grab a shared lock if +// do_lock is false, exclusive lock otherwise. Note that this silently doesn't +// lock mutexes for invalid variable references; in all usages this is followed +// by GetInputTensor which will signal a failure. +template +VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( + OpKernelContext* ctx, bool do_lock, bool sparse, + const std::vector& input_ids) { + bool any_resource = false; + for (auto i : input_ids) { + if (ctx->input_dtype(i) == DT_RESOURCE) { + any_resource = true; + break; + } + } + if (!do_lock && !any_resource) { + return VariableInputLockHolder({}, {}, {}); + } + std::vector vars; + std::vector mutexes; + std::vector acquire_order; + for (auto input : input_ids) { + Var* var; + tsl::mutex* mutex = GetTrainingVariableMutex(ctx, input, &var); + if (var) vars.push_back(var); + // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). + if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { + acquire_order.push_back(mutexes.size()); + mutexes.push_back(mutex); + } + } + + if (sparse) { + for (Var* var : vars) { + EnsureSparseVariableAccess(ctx, var).IgnoreError(); + } + } + + std::sort(acquire_order.begin(), acquire_order.end(), + [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); + + auto locks = std::make_unique>(); + auto shared_locks = std::make_unique>(); + locks->reserve(acquire_order.size()); + + for (auto acquire : acquire_order) { + tsl::mutex* mu = mutexes[acquire]; + if (mu != nullptr) { + if (!sparse || do_lock) { + locks->emplace_back(*mu); + } else { + shared_locks->emplace_back(*mu); + } + } + } + auto variableInputLock = + VariableInputLockHolder(vars, std::move(locks), std::move(shared_locks)); + return variableInputLock; +} + +void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, + int output); + +// This is for use with ResourceVariables to ensure *tensor has a +// reference count of 1 before you update it. +// REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held. +template +absl::Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor, + bool copy_on_read_mode) { + if (copy_on_read_mode || !tensor->RefCountIsOne()) { + // Tensor's buffer is in use by some read, so we need to copy before + // updating. + Tensor tmp; + if (std::is_same::value) { + tsl::AllocatorAttributes attr; + attr.set_on_host(true); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr)); + + const auto elements_in = tensor->flat(); + auto elements_out = tmp.flat(); + for (int64_t i = 0; i < elements_in.size(); ++i) { + elements_out(i) = elements_in(i); + } + } else { + tsl::AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr)); + functor::DenseUpdate copy_functor; + copy_functor(ctx->eigen_device(), tmp.flat(), + const_cast(tensor)->flat()); + } + *tensor = tmp; + } + return absl::OkStatus(); +} + +// This gives you `*out`, a tensor you can update, corresponding to a variable +// passed as input index `input`. 
This handles the differences between +// reference and resource variables. + +// For reference variables we can just grab the tensor, grabbing the lock if +// `lock_held` is False. +// +// For resource variables: +// * If sparse is true: return the underlying tensor. +// * If sparse is false: ensure its refcount is 1 (by potentially copying its +// contents), and then return the underlying tensor. +// `lock_held` is ignored for resource variables. +template +absl::Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, + bool lock_held, bool sparse, + Tensor* out) { + if (ctx->input_dtype(input) == DT_RESOURCE) { + core::RefCountPtr var; + TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var)); + if (sparse) { + var->mu()->assert_held_shared(); + *out = *var->tensor(); + return absl::OkStatus(); + } + var->mu()->assert_held(); + TF_RETURN_IF_ERROR(PrepareToUpdateVariable( + ctx, var->tensor(), var->copy_on_read_mode.load())); + *out = *var->tensor(); + return absl::OkStatus(); + } + *out = ctx->mutable_input(input, lock_held); + return absl::OkStatus(); +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/training_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/training_ops.h new file mode 100644 index 00000000..8f986d13 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/training_ops.h @@ -0,0 +1,322 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace functor { + +// Each training algorithm has a ApplyXYZ functor struct declared in +// this header file. They are specialized for different devices +// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc). 
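For a sense of what these functors compute, a hedged sketch of the gradient-descent case with plain Eigen vectors standing in for the TTypes flat views (the real CPU specialization lives in training_ops.cc and evaluates through the Eigen device; ApplyGradientDescentCpu is an illustrative name):

```cpp
#include <cstdio>
#include <Eigen/Dense>

// Roughly the update ApplyGradientDescent performs: var <- var - alpha * delta.
void ApplyGradientDescentCpu(Eigen::Ref<Eigen::VectorXf> var, float alpha,
                             const Eigen::Ref<const Eigen::VectorXf>& delta) {
  var -= alpha * delta;
}

int main() {
  Eigen::VectorXf var(3), delta(3);
  var << 1.f, 2.f, 3.f;
  delta << 0.5f, 0.5f, 0.5f;
  ApplyGradientDescentCpu(var, /*alpha=*/0.1f, delta);
  std::printf("%g %g %g\n", var(0), var(1), var(2));  // 0.95 1.95 2.95
}
```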
+ +template +struct ApplyGradientDescent { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::ConstScalar alpha, + typename TTypes::ConstFlat delta); +}; + +template +struct ApplyAdadelta { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat accum_update, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar rho, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad); +}; + +template +struct SparseApplyAdadelta { + void operator()(const Device& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::Matrix accum_update, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar rho, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstFlat indices); +}; + +template +struct FobosElasticNet { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstFlat grad); +}; + +template +struct ApplyProximalGradientDescent { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstFlat grad); +}; + +template +struct ApplyAdagrad { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstFlat grad, bool update_slots); +}; + +template +struct ApplyAdagradV2 { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad, bool update_slots); +}; + +template +struct ApplyAdagradDA { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat gradient_accum, + typename TTypes::Flat gradient_squared_accum, + typename TTypes::ConstScalar lr, int64_t global_step, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstFlat grad); +}; + +template +struct SparseApplyAdagrad { + // Note that epsilon is ignored if has_epsilon is false. 
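The sparse variant declared next applies the usual Adagrad update row by row, selecting variable rows through the indices vector. A hedged sketch with plain std::vector rows standing in for the matrix views (SparseAdagradUpdate is an illustrative name; epsilon is omitted, matching the has_epsilon == false case noted above):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// For each gradient row i targeting variable row indices[i]:
//   accum_row += grad_row^2;  var_row -= lr * grad_row / sqrt(accum_row).
void SparseAdagradUpdate(std::vector<std::vector<float>>& var,
                         std::vector<std::vector<float>>& accum,
                         const std::vector<std::vector<float>>& grad,
                         const std::vector<int>& indices, float lr) {
  for (size_t i = 0; i < indices.size(); ++i) {
    auto& v = var[indices[i]];
    auto& a = accum[indices[i]];
    const auto& g = grad[i];
    for (size_t j = 0; j < g.size(); ++j) {
      a[j] += g[j] * g[j];
      v[j] -= lr * g[j] / std::sqrt(a[j]);
    }
  }
}

int main() {
  std::vector<std::vector<float>> var{{1.f, 1.f}, {2.f, 2.f}};
  std::vector<std::vector<float>> accum{{0.f, 0.f}, {0.f, 0.f}};
  SparseAdagradUpdate(var, accum, {{0.5f, 0.5f}}, {1}, /*lr=*/0.1f);
  std::printf("%g %g\n", var[1][0], var[1][1]);  // 1.9 1.9
}
```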
+ absl::Status operator()(const Device& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstVec indices, + int64_t inner_dim, bool update_slots); +}; + +template +struct ApplyProximalAdagrad { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstFlat grad); +}; + +template +struct SparseApplyProximalAdagrad { + absl::Status operator()(const Device& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstVec indices, + int64_t inner_dim); +}; + +template +struct ApplyFtrl { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar lr_power); +}; + +template +struct ApplyFtrlMultiplyLinearByLr { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar lr_power); +}; + +template +struct ApplyFtrlV2 { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power); +}; + +template +struct ApplyFtrlV2MultiplyLinearByLr { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::Flat linear, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power); +}; + +template +struct SparseApplyFtrl { + absl::Status operator()(const Device& d, typename TTypes::Matrix var_flat, + typename TTypes::Matrix accum_flat, + typename TTypes::Matrix linear_flat, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power, + typename TTypes::ConstMatrix grad_flat, + typename TTypes::ConstVec indices_vec, + int64_t inner_dim, bool multiply_linear_by_lr); +}; + +template +struct ApplyMomentum { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar momentum, bool use_nesterov); +}; + +template +struct ApplyKerasMomentum { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstFlat grad, + typename TTypes::ConstScalar momentum, bool use_nesterov); +}; + +template +struct SparseApplyKerasMomentum { + Tindex 
operator()(const Device& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstFlat indices, + typename TTypes::ConstScalar momentum, + bool use_nesterov); +}; + +template +struct ApplyAdam { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat m, typename TTypes::Flat v, + typename TTypes::ConstScalar beta1_power, + typename TTypes::ConstScalar beta2_power, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar beta1, + typename TTypes::ConstScalar beta2, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad, bool use_nesterov); +}; + +template +struct ApplyAdamWithAmsgrad { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat m, typename TTypes::Flat v, + typename TTypes::Flat vhat, + typename TTypes::ConstScalar beta1_power, + typename TTypes::ConstScalar beta2_power, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar beta1, + typename TTypes::ConstScalar beta2, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad); +}; + +template +struct ApplyAdaMax { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat m, typename TTypes::Flat v, + typename TTypes::ConstScalar beta1_power, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar beta1, + typename TTypes::ConstScalar beta2, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad); +}; + +template +struct ApplyRMSProp { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat ms, typename TTypes::Flat mom, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar rho, + typename TTypes::ConstScalar momentum, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad); +}; + +template +struct ApplyCenteredRMSProp { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat mg, typename TTypes::Flat ms, + typename TTypes::Flat mom, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar rho, + typename TTypes::ConstScalar momentum, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstFlat grad); +}; + +template +struct ApplyAddSign { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat m, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar alpha, + typename TTypes::ConstScalar sign_decay, + typename TTypes::ConstScalar beta, + typename TTypes::ConstFlat grad); +}; + +template +struct ApplyPowerSign { + void operator()(const Device& d, typename TTypes::Flat var, + typename TTypes::Flat m, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar logbase, + typename TTypes::ConstScalar sign_decay, + typename TTypes::ConstScalar beta, + typename TTypes::ConstFlat grad); +}; + +} // end namespace functor +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/transpose_functor.h b/third_party/tflite-hdrs/tensorflow/core/kernels/transpose_functor.h new file mode 100644 index 00000000..f4c905b1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/transpose_functor.h @@ -0,0 +1,258 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_ +#define TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +// Transpose tensor 'in' into tensor 'out' according to dimension +// permutation 'perm'. +// +// REQUIRES: in.dtype() == out->dtype() +// REQUIRES: in.dims() == out->dims() +// REQUIRES: in.dims() == perm.size() +// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i) +template +absl::Status DoTranspose(const Device& device, const Tensor& in, + const absl::Span perm, Tensor* out); + +// Conjugate and transpose tensor 'in' into tensor 'out' according to dimension +// permutation 'perm'. +// +// REQUIRES: in.dtype() == out->dtype() +// REQUIRES: in.dims() == out->dims() +// REQUIRES: in.dims() == perm.size() +// REQUIRES: in.dim_size(perm[i]) == out->dim_size(i) +template +absl::Status DoConjugateTranspose(const Device& device, const Tensor& in, + const absl::Span perm, + Tensor* out); + +// Convenience versions of DoTranspose that only swap the last (inner) two +// dimensions. +template +absl::Status DoMatrixTranspose(const Device& device, const Tensor& in, + Tensor* out); + +// Convenience versions of DoConjugateTranspose that only swap the last (inner) +// two dimensions. +template +absl::Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in, + Tensor* out); + +// Primary device specific functor to be specialized for each device and type. +template +struct Transpose { + static void run(const Device& d, const Tensor& in, + const absl::Span perm, Tensor* out); +}; + +// Implementation details. +namespace internal { + +typedef absl::InlinedVector TransposeDimsVec; +typedef absl::InlinedVector TransposePermsVec; + +// Helper function that takes a tensor shape, a permutation, combines the +// neighboring shapes if their indices in the permutation are consecutive. +// The function outputs the combined shape and new permutation. +// Example: Tensor shape {2, 3, 4, 5, 120} and permutation {0, 4, 1, 2, 3} will +// produce new shape {2, 60, 120} and new permutation {0, 2, 1}. +inline void ReduceTransposeDimensions(const TensorShape& shape, + absl::Span perm, + TransposePermsVec* new_perm, + TransposeDimsVec* new_dims) { + CHECK_EQ(shape.dims(), perm.size()); + if (shape.dims() == 1) { + // If input dimension is already 1, no need to reduce dimension. + new_perm->resize(1); + (*new_perm)[0] = perm[0]; + (*new_dims)[0] = shape.dim_size(0); + return; + } + TransposePermsVec new_dim_position(shape.dims(), -1); + TransposeDimsVec combined_dims(shape.dims(), 0); + int cur_head = perm[0]; + new_dim_position[cur_head] = 0; + combined_dims[0] = shape.dim_size(cur_head); + int dim_idx = 0; + for (int perm_idx = 1; perm_idx < shape.dims(); ++perm_idx) { + // If two indices in permutation are consecutive numbers, combine their + // dimensions. 
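The loop below implements the coalescing described in the comment: runs of source dimensions that stay adjacent under the permutation collapse into one combined dimension. A standalone sketch of the same idea with plain vectors (CoalesceDims is an illustrative name), reproducing the {2, 3, 4, 5, 120} / {0, 4, 1, 2, 3} example from the comment above:

```cpp
#include <cstdio>
#include <vector>

// Merge runs of consecutive source dimensions into one combined dimension.
void CoalesceDims(const std::vector<long long>& dims, const std::vector<int>& perm,
                  std::vector<long long>* new_dims, std::vector<int>* new_perm) {
  const int n = static_cast<int>(dims.size());
  std::vector<int> new_pos(n, -1);        // source dim -> combined dim index
  std::vector<long long> combined(n, 0);  // combined sizes, in permuted order
  int cur = perm[0], idx = 0;
  new_pos[cur] = 0;
  combined[0] = dims[cur];
  for (int p = 1; p < n; ++p) {
    if (cur + 1 == perm[p]) {   // consecutive in the source: merge
      cur = perm[p];
      combined[idx] *= dims[cur];
    } else {                    // otherwise start a new combined dimension
      cur = perm[p];
      new_pos[cur] = ++idx;
      combined[idx] = dims[cur];
    }
  }
  new_perm->assign(idx + 1, 0);
  new_dims->assign(idx + 1, 0);
  int out = 0;
  for (int i = 0; i < n; ++i) {
    if (new_pos[i] >= 0) {
      (*new_perm)[out] = new_pos[i];
      (*new_dims)[out] = combined[new_pos[i]];
      ++out;
    }
  }
}

int main() {
  std::vector<long long> nd;
  std::vector<int> np;
  CoalesceDims({2, 3, 4, 5, 120}, {0, 4, 1, 2, 3}, &nd, &np);
  for (long long d : nd) std::printf("%lld ", d);  // 2 60 120
  std::printf("| ");
  for (int p : np) std::printf("%d ", p);          // 0 2 1
  std::printf("\n");
}
```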
+ if (cur_head + 1 == perm[perm_idx]) { + cur_head = perm[perm_idx]; + combined_dims[dim_idx] *= shape.dim_size(cur_head); + } else { + // Else start a new dimension. + cur_head = perm[perm_idx]; + dim_idx++; + new_dim_position[cur_head] = dim_idx; + combined_dims[dim_idx] = shape.dim_size(cur_head); + } + } + // Compact the new permutations and dimension sizes. + new_perm->resize(dim_idx + 1); + new_dims->resize(dim_idx + 1); + dim_idx = 0; + for (int i = 0; i < new_dim_position.size(); ++i) { + if (new_dim_position[i] >= 0) { + int new_perm_idx = new_dim_position[i]; + (*new_perm)[dim_idx] = new_perm_idx; + (*new_dims)[dim_idx] = combined_dims[new_perm_idx]; + dim_idx++; + } + } +} + +// If all non-singleton dimensions remain in ascending order, the shuffled +// singletons can be transposed by a reshape, saving a memory allocation & copy. +// |permutation| must be a permutation of {0, .., input_shape.dims() - 1}. +// That is, for all i, 0 <= perm[i] < input_shape.dims(). +// In practice, this is checked in TransposeOp::Compute prior to calling this +// function, and the function sits here to facilitate unit testing. +inline bool NonSingletonDimensionsAlign(const TensorShape& input_shape, + const std::vector& permutation) { + int last_nonsingleton_perm_dim = -1; + for (int perm_dim : permutation) { + if (input_shape.dim_size(perm_dim) == 1) { + continue; + } + if (perm_dim < last_nonsingleton_perm_dim) { + return false; + } + last_nonsingleton_perm_dim = perm_dim; + } + return true; +} + +// Uses Eigen to transpose. +template +void TransposeUsingEigen(const Device& d, const Tensor& in, + const absl::Span perm, bool conjugate, + Tensor* out) { + Eigen::array p; + for (int i = 0; i < NDIMS; ++i) p[i] = perm[i]; + auto x = typename TTypes::ConstTensor( + reinterpret_cast(in.tensor_data().data()), + in.shape().AsEigenDSizes()); + auto y = typename TTypes::Tensor( + reinterpret_cast(const_cast(out->tensor_data().data())), + out->shape().AsEigenDSizes()); + if (conjugate) { + y.device(d) = x.conjugate().shuffle(p); + } else { + y.device(d) = x.shuffle(p); + } +} + +template +absl::Status DoTransposeImpl(const Device& d, const Tensor& in, + const absl::Span perm, bool conjugate, + Tensor* out) { + CHECK_EQ(in.dims(), out->dims()); + CHECK_EQ(in.dims(), perm.size()); + CHECK_EQ(in.dtype(), out->dtype()); + switch (in.dtype()) { + case DT_BOOL: + case DT_INT8: + case DT_QINT8: + case DT_QUINT8: + case DT_UINT8: + case DT_FLOAT8_E5M2: + case DT_FLOAT8_E4M3FN: + Transpose::run(d, in, perm, out); + break; + + case DT_BFLOAT16: + case DT_HALF: + case DT_INT16: + case DT_QINT16: + case DT_QUINT16: + case DT_UINT16: + Transpose::run(d, in, perm, out); + break; + + case DT_FLOAT: + case DT_INT32: + case DT_QINT32: + case DT_UINT32: + Transpose::run(d, in, perm, out); + break; + + case DT_DOUBLE: + case DT_INT64: + case DT_UINT64: + Transpose::run(d, in, perm, out); + break; + + case DT_COMPLEX64: + if (conjugate) { +#if defined(__ANDROID__) and !defined(__clang__) + // Workaround for GCC compiler bug in Android toolchain. 
+ return errors::Unimplemented( + "Conjugate transpose of complex64 not supported for GCC on " + "Android."); +#else + Transpose::run(d, in, perm, out); +#endif + } else { + Transpose::run(d, in, perm, out); + } + break; + + case DT_COMPLEX128: + if (conjugate) { + Transpose::run(d, in, perm, + out); + } else { + Transpose::run(d, in, perm, + out); + } + break; + + case DT_STRING: + Transpose::run(d, in, perm, out); + break; + + default: + return errors::Unimplemented("Unsupported dtype on CPU: ", in.dtype()); + } + return absl::OkStatus(); +} + +template +inline absl::Status DoMatrixTransposeImpl(const Device& device, + const Tensor& in, bool conjugate, + Tensor* out) { + const int ndims = in.dims(); + if (ndims == 0) return absl::OkStatus(); + TransposePermsVec perm(ndims); + std::iota(perm.begin(), perm.end(), 0); + std::swap(perm[ndims - 2], perm[ndims - 1]); + return DoTransposeImpl(device, in, perm, conjugate, out); +} + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_FUNCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/transpose_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/transpose_op.h new file mode 100644 index 00000000..8f0405b6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/transpose_op.h @@ -0,0 +1,106 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +class TransposeOp : public OpKernel { + public: + explicit TransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + protected: + virtual absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, + Tensor* out) = 0; + virtual bool IsConjugate() const { return false; } +}; + +class TransposeCpuOp : public TransposeOp { + public: + explicit TransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} + + protected: + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; +}; + +#if defined(INTEL_MKL) +class MklTransposeCpuOp : public TransposeOp { + public: + explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} + + protected: + Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + gtl::ArraySlice perm, Tensor* out) override; +}; +#endif // INTEL_MKL + +class TransposeGpuOp : public TransposeOp { + public: + explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} + + protected: + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; +}; + + +// Conjugating transpose ops. 
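Both the plain and the conjugating variants below eventually call DoTranspose with an explicit permutation; for the matrix-transpose convenience wrappers that permutation is just the identity with the last two axes swapped, as DoMatrixTransposeImpl builds it. A minimal sketch (MatrixTransposePerm is an illustrative name):

```cpp
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Identity permutation with the last two axes swapped, e.g. {0, 1, 3, 2}.
std::vector<int> MatrixTransposePerm(int ndims) {
  std::vector<int> perm(ndims);
  std::iota(perm.begin(), perm.end(), 0);
  if (ndims >= 2) std::swap(perm[ndims - 2], perm[ndims - 1]);
  return perm;
}

int main() {
  for (int p : MatrixTransposePerm(4)) std::printf("%d ", p);  // 0 1 3 2
  std::printf("\n");
}
```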
+class ConjugateTransposeCpuOp : public TransposeOp { + public: + explicit ConjugateTransposeCpuOp(OpKernelConstruction* ctx) + : TransposeOp(ctx) {} + + protected: + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; + bool IsConjugate() const override { return true; } +}; + +#if defined(INTEL_MKL) +class MklConjugateTransposeCpuOp : public TransposeOp { + public: + explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx) + : TransposeOp(ctx) {} + + protected: + Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + gtl::ArraySlice perm, Tensor* out) override; + bool IsConjugate() const override { return true; } +}; +#endif // INTEL_MKL + +class ConjugateTransposeGpuOp : public TransposeOp { + public: + explicit ConjugateTransposeGpuOp(OpKernelConstruction* ctx) + : TransposeOp(ctx) {} + + protected: + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; + bool IsConjugate() const override { return true; } +}; + + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TRANSPOSE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/third_party/tflite-hdrs/tensorflow/core/kernels/typed_conditional_accumulator_base.h new file mode 100644 index 00000000..f6574416 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/typed_conditional_accumulator_base.h @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ +#define TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ + +#include "tensorflow/core/kernels/conditional_accumulator_base.h" + +namespace tensorflow { + +/* + * TypedConditionalAccumulatorBase is a templated companion of + * ConditionalAccumulatorBase which allows for subclasses to use different + * types for the input gradients. (See ConditionalAccumulator and + * SparseConditionalAccumulator.) + * + * TypedConditionalAccumulatorBase defines virtual methods and implements + * methods which depend on the gradient type. These are mainly methods that are + * used for adding a new gradient to the accumulator. + */ +template +class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase { + public: + TypedConditionalAccumulatorBase(const DataType& dtype, + const PartialTensorShape& shape, + const string& name, + const string& reduction_type) + : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {} + + /** + * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is + * successful (i.e., has its gradient applied) if its local_step >= + * current_global_step_ at the time the attempt is processed. Otherwise, if + * local_step < current_global_step_, the stale gradient is silently dropped. 
+ * + * local_step: Time-step at which the gradient was computed. + * grad: Gradient tensor to be added to the accumulator. + * ctx: Context in which the op is executed. + */ + void TryApplyGrad(int64_t local_step, OpKernelContext* ctx) override { + { + mutex_lock l(mu_); + if (local_step >= current_global_step_) { + GradientTensorType* grad = nullptr; + bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad); + if (is_valid) { + if (counter_ > 0) { + AddToAccumGradFunction(ctx, grad); + } else { + AllocateAndAssignToAccumGradFunction(ctx, grad); + } + counter_++; + } + CleanUpGradTensor(grad); + } + } + FlushUnlocked(); + } + + protected: + // Virtual methods to be implemented by sub-classes for different datatypes. + // Implements arithmetic operations specific to datatype. + virtual void AllocateAndAssignToAccumGradFunction( + OpKernelContext* ctx, GradientTensorType* grad) = 0; + + virtual void AddToAccumGradFunction(OpKernelContext* ctx, + GradientTensorType* grad) = 0; + + // Method for extracting and validating input provided in an OpKernelContext. + // Returns true if input was successfully retrieved and is valid. + // Gradient is returned via the GradientTensorType** tensor. + virtual bool GetAndValidateTensorInputForApplyGrad( + OpKernelContext* ctx, GradientTensorType** tensor) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; + + // Method for cleaning up any memory allocated in + // GetAndValidateTensorInputForApplyGrad + virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/typed_queue.h b/third_party/tflite-hdrs/tensorflow/core/kernels/typed_queue.h new file mode 100644 index 00000000..e4c82f0e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/typed_queue.h @@ -0,0 +1,118 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_ +#define TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/kernels/queue_base.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// TypedQueue builds on QueueBase, with backing class (SubQueue) +// known and stored within. Shared methods that need to have access +// to the backed data sit in this class. +template +class TypedQueue : public QueueBase { + public: + TypedQueue(const int32_t capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, + const string& name); + + virtual absl::Status Initialize(); // Must be called before any other method. 
+ + int64_t MemoryUsed() const override; + + protected: + std::vector queues_ TF_GUARDED_BY(mu_); +}; // class TypedQueue + +template +TypedQueue::TypedQueue( + int32_t capacity, const DataTypeVector& component_dtypes, + const std::vector& component_shapes, const string& name) + : QueueBase(capacity, component_dtypes, component_shapes, name) {} + +template +absl::Status TypedQueue::Initialize() { + if (component_dtypes_.empty()) { + return errors::InvalidArgument("Empty component types for queue ", name_); + } + if (!component_shapes_.empty() && + component_dtypes_.size() != component_shapes_.size()) { + return errors::InvalidArgument( + "Different number of component types. ", + "Types: ", DataTypeSliceString(component_dtypes_), + ", Shapes: ", ShapeListString(component_shapes_)); + } + + mutex_lock lock(mu_); + queues_.reserve(num_components()); + for (int i = 0; i < num_components(); ++i) { + queues_.push_back(SubQueue()); + } + return absl::OkStatus(); +} + +template +inline int64_t SizeOf(const SubQueue& sq) { + static_assert(sizeof(SubQueue) != sizeof(SubQueue), "SubQueue size unknown."); + return 0; +} + +template <> +inline int64_t SizeOf(const std::deque& sq) { + if (sq.empty()) { + return 0; + } + return sq.size() * sq.front().AllocatedBytes(); +} + +template <> +inline int64_t SizeOf(const std::vector& sq) { + if (sq.empty()) { + return 0; + } + return sq.size() * sq.front().AllocatedBytes(); +} + +using TensorPair = std::pair; + +template +int64_t SizeOf(const std::priority_queue& sq) { + if (sq.empty()) { + return 0; + } + return sq.size() * (sizeof(TensorPair) + sq.top().second.AllocatedBytes()); +} + +template +inline int64_t TypedQueue::MemoryUsed() const { + int memory_size = 0; + mutex_lock l(mu_); + for (const auto& sq : queues_) { + memory_size += SizeOf(sq); + } + return memory_size; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/uniform_quant_ops/math_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/uniform_quant_ops/math_utils.h new file mode 100644 index 00000000..5cd9c1b4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/uniform_quant_ops/math_utils.h @@ -0,0 +1,334 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_UNIFORM_QUANT_OPS_MATH_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_UNIFORM_QUANT_OPS_MATH_UTILS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +namespace internal { + +// Multiply by the effective quantized multiplier and shift. 
+// Caller is responsible for guaranteeing: +// quantized_multiplier >= 0 +// shift >= -31 && shift <= 30 +// The usage of this function is restricted to "multiply by quantized_multiplier +// and shift which were calcluated from QuantizeMultiplier() function below", +// so the conditions are expected to be met. +// +// Reference (TFLite MultiplyByQuantizedMultiplier with TFLITE_SINGLE_ROUNDING): +// https://github.com/tensorflow/tensorflow/blob/47c640a961874f644cd071752835c7b792450bb8/tensorflow/lite/kernels/internal/common.h#L145 +// Above implementation refers from ruy MultiplyByQuantizedMultiplier +// (https://github.com/google/ruy/blob/97ebb72aa0655c0af98896b317476a5d0dacad9c/ruy/apply_multiplier.cc) +// +// After mutiplying fixed point quantized_multiplier, apply single rounding +// operation (addition of 'round' to result and then shift right by +// total_shift). where round=(1 << (30 - shift)) and total_shift=(31 - shift) +inline int32_t MultiplyByQuantizedMultiplier(int32_t x, + int32_t quantized_multiplier, + int shift) { + const int64_t total_shift = 31 - shift; + const int64_t round = static_cast(1) << (total_shift - 1); + int64_t result = x * static_cast(quantized_multiplier) + round; + result = result >> total_shift; + + result = std::clamp( + result, static_cast(std::numeric_limits::min()), + static_cast(std::numeric_limits::max())); + return static_cast(result); +} + +} // namespace internal + +// Quantize eigen Tensor input_tensor using given inv_scale and zero_point, +// using the formula: +// quantized_val = floor(input_val * inv_scale + 0.5f) + zero_point +// +// The caller is reponsible for the validity of the inv_scale (Avoid precision +// loss from taking inverse, and ensure that inv_scale is a finite number.) +template +void AffineQuantize(const ConstTensorTin& input_tensor, float inv_scale, + int32_t zero_point, int32_t quantization_min_val, + int32_t quantization_max_val, TensorTout quantized_tensor) { + quantized_tensor = ((input_tensor.template cast() * inv_scale + 0.5f) + .floor() + .template cast() + + zero_point) + .cwiseMin(quantization_max_val) + .cwiseMax(quantization_min_val) + .template cast(); +} + +// Dequantize eigen Tensor input_tensor using given scale and zero_point, using +// the formula: +// dequantized_val = (input_val - zero_point) * scale +template +void AffineDequantize(const ConstTensorTin& input_tensor, float scale, + int32_t zero_point, TensorTout dequantized_tensor) { + dequantized_tensor = (((input_tensor.template cast() - zero_point)) + .template cast() * + scale) + .template cast(); +} + +// Given a portion of input float tensor, quantizes the data and writes output +// to the corresponding portion in quantized_tensor. The quantization scale and +// zero_point is calculated using the input data min and max. +// This function is used for dynamic range quantization in hybrid (float x qint) +// kernels. +// +// This function behavior aligns with TFLite AsymmetricQuantize() +// (https://github.com/tensorflow/tensorflow/blob/779d3824c8b38a622773940011ced0388697b951/tensorflow/lite/kernels/internal/reference/portable_tensor_utils.cc#L72) +// to achieve feature parity with TFLite which is required since supporting +// mobile executions is the one of the major use cases. 
The behavior is same +// except for following difference: TFLite AsymmetricQuantize() uses round(input +// / scale + zero_point), while AffineQuantize() uses floor(input_val * +// (1./scale) + 0.5) + zero_point +template +absl::Status AsymmetricQuantize(const ConstTensorTin& input_tensor, + int32_t quantization_min_val, + int32_t quantization_max_val, float& scale, + int32& zero_point, + TensorTout quantized_tensor) { + if (quantization_min_val >= quantization_max_val) { + // NOLINTNEXTLINE + return errors::InvalidArgument( + "quantization_min_val must be smaller than quantization_max_val. " + "Given ", + quantization_min_val, ", ", quantization_max_val); + } + + Eigen::Tensor input_tensor_min = + input_tensor.minimum(); + Eigen::Tensor input_tensor_max = + input_tensor.maximum(); + const double rmin = static_cast(std::min(0.0f, input_tensor_min())); + const double rmax = static_cast(std::max(0.0f, input_tensor_max())); + const double qmin_double = quantization_min_val; + const double qmax_double = quantization_max_val; + + float inv_scale = 0; + scale = (rmax - rmin) / (qmax_double - qmin_double); + if (rmax - rmin != 0) { + // Re-calculate the inverse instead of using (1./scale), to avoid loss of + // precision. + inv_scale = (qmax_double - qmin_double) / (rmax - rmin); + } + if (scale == 0 || !std::isfinite(inv_scale)) { + quantized_tensor.setZero(); + scale = 1.0; + zero_point = 0; + return absl::OkStatus(); + } + + // Using the scale calculated from the quantization range and data range, + // calculate zero point from quantization min and quantization max. + // Among those two, choose the zero point that has smaller error. + const double zero_point_from_min = qmin_double - rmin / scale; + const double zero_point_from_max = qmax_double - rmax / scale; + const double zero_point_from_min_error = + std::abs(qmin_double) + std::abs(rmin / scale); + const double zero_point_from_max_error = + std::abs(qmax_double) + std::abs(rmax / scale); + const double zero_point_double = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + int8_t nudged_zero_point = 0; + if (zero_point_double <= qmin_double) { + nudged_zero_point = quantization_min_val; + } else if (zero_point_double >= qmax_double) { + nudged_zero_point = quantization_max_val; + } else { + nudged_zero_point = static_cast(round(zero_point_double)); + } + zero_point = nudged_zero_point; + + AffineQuantize(input_tensor, inv_scale, zero_point, quantization_min_val, + quantization_max_val, quantized_tensor); + return absl::OkStatus(); +} + +// Given double_multiplier, quantize it where it is represented by two int32_t, +// quantized_multiplier and shift. +// +// double_multiplier must be a positive finite number. Otherwise returns +// InvalidArgument. +// +// Output quantized_multiplier is clamped to range [0, INT32_MAX], +// and shift is clamped to range [-31, 30]. +absl::Status QuantizeMultiplier(double double_multiplier, + int32_t& quantized_multiplier, int32_t& shift); + +// Requantize input_val given quantized effective_muliplier|shift and +// input|output zero_point. +// Effective multiplier and shift should be calculated from effective scale +// which is: +// (product of input scales) / (product of output scales). 
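Standalone worked example of the fixed-point arithmetic described above (the helper name and the numbers are mine, and the final clamp to the int32 range is omitted): under the formula round = 1 << (30 - shift), total_shift = 31 - shift, an effective multiplier of 0.5 (for instance input_scale 0.5 with output_scale 1.0) can be encoded as quantized_multiplier = 1 << 30 with shift = 0, because (1 << 30) * 2^(0 - 31) = 0.5.

#include <cstdint>
#include <iostream>

// Re-implements only the single-rounding multiply, without the clamp, to show
// the numbers involved.
int32_t MultiplyByQuantizedMultiplierSketch(int32_t x, int32_t quantized_multiplier,
                                            int shift) {
  const int64_t total_shift = 31 - shift;
  const int64_t round = int64_t{1} << (total_shift - 1);
  const int64_t result = int64_t{x} * quantized_multiplier + round;
  return static_cast<int32_t>(result >> total_shift);
}

int main() {
  const int32_t quantized_multiplier = int32_t{1} << 30;  // encodes 0.5 with shift = 0
  std::cout << MultiplyByQuantizedMultiplierSketch(100, quantized_multiplier, 0) << " "    // 50
            << MultiplyByQuantizedMultiplierSketch(101, quantized_multiplier, 0) << "\n";  // 51 (50.5 rounds up)
}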
+template +Tout AffineRequantizeWithQuantizedMultiplierAndShift( + Tin input_val, int32_t effective_quantized_multiplier, int effective_shift, + int32_t input_zero_point, int32_t output_zero_point, + int32_t quantization_min_val, int32_t quantization_max_val) { + const int32_t input = static_cast(input_val) - input_zero_point; + + const int32_t unclamped = + internal::MultiplyByQuantizedMultiplier( + input, effective_quantized_multiplier, effective_shift) + + output_zero_point; + + // Clamp with [quantization_min_val, quantization_max_val]. + return static_cast( + std::max(std::min(unclamped, quantization_max_val), + quantization_min_val)); +} + +namespace internal { + +// Requantize from per-tensor to per-tensor. +template +absl::Status PerTensorToPerTensorRequantize( + const Tensor& input, float input_scale, int32_t input_zero_point, + float output_scale, int32_t output_zero_point, int32_t quantization_min_val, + int32_t quantization_max_val, Tensor& output) { + const double effective_multiplier = + static_cast(input_scale) / output_scale; + int32_t effective_quantized_multiplier; + int32_t effective_shift; + TF_RETURN_IF_ERROR(QuantizeMultiplier( + effective_multiplier, effective_quantized_multiplier, effective_shift)); + + output.flat() = input.flat().unaryExpr( + [effective_quantized_multiplier, effective_shift, input_zero_point, + output_zero_point, quantization_min_val, + quantization_max_val](Tin input_val) { + return AffineRequantizeWithQuantizedMultiplierAndShift( + input_val, effective_quantized_multiplier, effective_shift, + input_zero_point, output_zero_point, quantization_min_val, + quantization_max_val); + }); + return absl::OkStatus(); +} + +// Requantize where the input or output contains any per-axis quantized cases. +// - From per-tensor to per-axis. +// - From per-axis to per-tensor. +// - From per-axis to per-axis. +template +absl::Status PerAxisRequantize(OpKernelContext* context, const Tensor& input, + const Tensor& input_scales, + const Tensor& input_zero_points, + const Tensor& output_scales, + const Tensor& output_zero_points, + int quantization_axis, + int32_t quantization_min_val, + int32_t quantization_max_val, Tensor& output) { + const bool input_per_axis_quantization = input_scales.dims() == 1; + const bool output_per_axis_quantization = output_scales.dims() == 1; + const auto& per_axis_scales_shape = input_per_axis_quantization + ? input_scales.shape() + : output_scales.shape(); + + Tensor effective_quantized_multipliers; + TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, per_axis_scales_shape, + &effective_quantized_multipliers)); + Tensor effective_shifts; + TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, per_axis_scales_shape, + &effective_shifts)); + + const float* input_scales_data = input_scales.flat().data(); + const float* output_scales_data = output_scales.flat().data(); + int32_t* effective_quantized_multipliers_data = + effective_quantized_multipliers.flat().data(); + int32_t* effective_shifts_data = effective_shifts.flat().data(); + + const int64_t quantization_dim_size = output.dim_size(quantization_axis); + + for (int64_t i = 0; i < quantization_dim_size; ++i) { + const double effective_multiplier = + static_cast( + input_scales_data[input_per_axis_quantization ? i : 0]) / + output_scales_data[output_per_axis_quantization ? 
i : 0]; + TF_RETURN_IF_ERROR(QuantizeMultiplier( + effective_multiplier, effective_quantized_multipliers_data[i], + effective_shifts_data[i])); + } + + const int32* input_zero_points_data = input_zero_points.flat().data(); + const int32* output_zero_points_data = + output_zero_points.flat().data(); + + auto input_tensor = + input.template flat_inner_outer_dims(quantization_axis - 1); + auto output_tensor = + output.template flat_inner_outer_dims(quantization_axis - 1); + + for (int i = 0; i < quantization_dim_size; ++i) { + output_tensor.template chip<1>(i) = + input_tensor.template chip<1>(i).unaryExpr( + [effective_quantized_multipliers_data, effective_shifts_data, + input_zero_points_data, output_zero_points_data, + quantization_min_val, quantization_max_val, + input_per_axis_quantization, output_per_axis_quantization, + i](Tin input_val) { + return AffineRequantizeWithQuantizedMultiplierAndShift( + input_val, effective_quantized_multipliers_data[i], + effective_shifts_data[i], + input_zero_points_data[input_per_axis_quantization ? i : 0], + output_zero_points_data[output_per_axis_quantization ? i : 0], + quantization_min_val, quantization_max_val); + }); + } + return absl::OkStatus(); +} + +} // namespace internal + +template +absl::Status EvalRequantize( + OpKernelContext* context, const Tensor& input, const Tensor& input_scales, + const Tensor& input_zero_points, const Tensor& output_scales, + const Tensor& output_zero_points, int input_quantization_axis, + int output_quantization_axis, int32_t quantization_min_val, + int32_t quantization_max_val, Tensor& output) { + if (input_quantization_axis == -1 && output_quantization_axis == -1) { + return internal::PerTensorToPerTensorRequantize( + input, input_scales.scalar()(), + input_zero_points.scalar()(), output_scales.scalar()(), + output_zero_points.scalar()(), quantization_min_val, + quantization_max_val, output); + } else { + const int quantization_axis = input_quantization_axis >= 0 + ? input_quantization_axis + : output_quantization_axis; + return internal::PerAxisRequantize( + context, input, input_scales, input_zero_points, output_scales, + output_zero_points, quantization_axis, quantization_min_val, + quantization_max_val, output); + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_UNIFORM_QUANT_OPS_MATH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h b/third_party/tflite-hdrs/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h new file mode 100644 index 00000000..4a303a3f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_UNIFORM_QUANT_OPS_TENSOR_UTILS_H_ +#define TENSORFLOW_CORE_KERNELS_UNIFORM_QUANT_OPS_TENSOR_UTILS_H_ + +#include "tensorflow/core/framework/ops_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Returns if all elements in given tensors are positive. +template +bool AllElementsPositive(const Tensor& tensor) { + Eigen::Tensor positive = + (tensor.flat() > 0).all(); + return positive(); +} + +// Given data tensor's shape and quantization params, returns if the shapes are +// valid. +absl::Status QuantizationAxisAndShapeValid(const TensorShape& data_shape, + const TensorShape& scales_shape, + const TensorShape& zero_points_shape, + int quantization_axis); + +// Given in_shape and perm to transpose, returns out shape after the transpose. +// perm must be a permutation of [0, 1, ..., in_shape.rank - 1]. The caller is +// responsible for guaranteeing it. +TensorShape TransposedShape(const TensorShape& in_shape, + const absl::Span perm); + +// Given in Tensor and perm to transpose, transpose in Tensor and write to out +// Tensor. +// perm must be a permutation of [0, 1, ..., in_shape.rank - 1]. The caller is +// responsible for guaranteeing it. +// Reference: +// https://github.com/tensorflow/tensorflow/blob/c09dc18b15a56f3e72a08c9f3a53e7ef347d159d/tensorflow/core/kernels/transpose_functor_cpu.cc#L35 +template +void Transpose(const Tensor& in, const absl::Span perm, + Tensor& out) { + absl::InlinedVector in_strides = + ComputeStride(in.shape()); + absl::InlinedVector out_strides = + ComputeStride(out.shape()); + const T* in_data = in.flat().data(); + T* out_data = out.flat().data(); + + for (int64_t out_idx = 0; out_idx < out.NumElements(); ++out_idx) { + int64_t in_idx = 0; + int64_t remain_out_idx = out_idx; + for (int dim = 0; dim < out.dims(); ++dim) { + const int64_t ratio = remain_out_idx / out_strides[dim]; + remain_out_idx -= ratio * out_strides[dim]; + in_idx += ratio * in_strides[perm[dim]]; + } + out_data[out_idx] = in_data[in_idx]; + } +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_UNIFORM_QUANT_OPS_TENSOR_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/unique_op_gpu.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/unique_op_gpu.cu.h new file mode 100644 index 00000000..23d7f89d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/unique_op_gpu.cu.h @@ -0,0 +1,449 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
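Standalone sketch of the stride walk performed by Transpose() above, for a made-up 2x3 row-major input and perm = {1, 0}; the real function operates on Tensor buffers and absl::InlinedVector strides, but the index arithmetic is the same.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> in_shape = {2, 3}, out_shape = {3, 2}, perm = {1, 0};
  const std::vector<int> in = {1, 2, 3, 4, 5, 6};  // row-major 2x3
  std::vector<int> out(in.size());

  // Row-major strides, mirroring what ComputeStride() provides in the header.
  auto strides = [](const std::vector<int64_t>& shape) {
    std::vector<int64_t> s(shape.size());
    int64_t acc = 1;
    for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
      s[d] = acc;
      acc *= shape[d];
    }
    return s;
  };
  const auto in_strides = strides(in_shape);
  const auto out_strides = strides(out_shape);

  // Decompose each output index into coordinates, then recombine them with the
  // permuted input strides, exactly as in the loop above.
  for (int64_t out_idx = 0; out_idx < static_cast<int64_t>(out.size()); ++out_idx) {
    int64_t in_idx = 0, remain = out_idx;
    for (size_t dim = 0; dim < out_shape.size(); ++dim) {
      const int64_t ratio = remain / out_strides[dim];
      remain -= ratio * out_strides[dim];
      in_idx += ratio * in_strides[perm[dim]];
    }
    out[out_idx] = in[in_idx];
  }
  for (int v : out) std::cout << v << ' ';  // prints: 1 4 2 5 3 6
}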
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/kernels/gpu_prim_helpers.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace + +#if TENSORFLOW_USE_ROCM +#include "tensorflow/core/platform/rocm.h" +#endif + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace unique_op_gpu { + +// Returns true iff index is at the end of a segment (which is equivalent to the +// beginning of the next segment). +template +struct SegmentIndicatorFunctor { + const T* __restrict__ sorted_input_ptr_; + SegmentIndicatorFunctor(const T* sorted_input_ptr) + : sorted_input_ptr_(sorted_input_ptr) {} + __device__ bool operator()(const TIndex& i) const { + return i > 0 && sorted_input_ptr_[i] != sorted_input_ptr_[i - 1]; + } +}; + +template +__global__ void ExtractFirstOccurrenceIndicesKernel( + int64_t input_size, int64_t uniq_size, + const TIndex* __restrict__ sorted_input_inds, + const TIndex* __restrict__ sorted_input_unique_ids, + TIndex* __restrict__ unique_input_inds, TIndex* __restrict__ segment_ends) { + GPU_1D_KERNEL_LOOP(i, input_size) { + TIndex sorted_input_unique_id = sorted_input_unique_ids[i]; + if (i == 0 || sorted_input_unique_id != sorted_input_unique_ids[i - 1]) { + unique_input_inds[sorted_input_unique_id] = sorted_input_inds[i]; + if (segment_ends) { + if (i == 0) { + // First thread writes the last element. + segment_ends[uniq_size - 1] = input_size; + } else { + segment_ends[sorted_input_unique_id - 1] = i; + } + } + } + } +} + +// Scatters the index of the first occurrence of each unique input value to +// unique_input_inds. +// If segment_ends is not nullptr, it is filled with the end index of each +// unique value's range in the sorted input (the last element is always set +// to input_size). 
+template +Status ExtractFirstOccurrenceIndices(const GPUDevice& d, int64_t input_size, + int64_t uniq_size, + const TIndex* sorted_input_inds, + const TIndex* sorted_input_unique_ids, + TIndex* unique_input_inds, + TIndex* segment_ends) { + CHECK_GT(input_size, 0); // Crash OK + GpuLaunchConfig config = GetGpuLaunchConfig( + input_size, d, &ExtractFirstOccurrenceIndicesKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(ExtractFirstOccurrenceIndicesKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), input_size, uniq_size, sorted_input_inds, + sorted_input_unique_ids, unique_input_inds, + segment_ends); +} + +template +__global__ void GatherOutputsAndInvertPermutationKernel( + int64_t uniq_size, const T* __restrict__ input, + const TIndex* __restrict__ sorted_unique_input_inds, + const TIndex* __restrict__ sorted_unique_perm, + const TIndex* __restrict__ segment_ends, T* __restrict__ output, + TIndex* __restrict__ inv_sorted_unique_perm, TIndex* __restrict__ count) { + GPU_1D_KERNEL_LOOP(i, uniq_size) { + output[i] = input[sorted_unique_input_inds[i]]; + auto j = sorted_unique_perm[i]; + inv_sorted_unique_perm[j] = i; + if (count) { + TIndex beg = j == 0 ? 0 : segment_ends[j - 1]; + TIndex end = segment_ends[j]; + count[i] = end - beg; + } + } +} + +// Gathers input values using sorted_unique_input_inds, and inverts the +// permutation specified by sorted_unique_perm. +template +Status GatherOutputsAndInvertPermutation(const GPUDevice& d, int64_t uniq_size, + const T* input, + const TIndex* sorted_unique_input_inds, + const TIndex* sorted_unique_perm, + const TIndex* segment_ends, T* output, + TIndex* inv_sorted_unique_perm, + TIndex* count) { + if (uniq_size == 0) return OkStatus(); + GpuLaunchConfig config = GetGpuLaunchConfig( + uniq_size, d, &GatherOutputsAndInvertPermutationKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(GatherOutputsAndInvertPermutationKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), uniq_size, input, sorted_unique_input_inds, + sorted_unique_perm, segment_ends, output, + inv_sorted_unique_perm, count); +} + +template +__global__ void LookupAndScatterUniqueIdsKernel( + int64_t input_size, const TIndex* sorted_input_inds, + const TIndex* __restrict__ sorted_input_unique_ids, + const TIndex* __restrict__ inv_sorted_unique_perm, + TIndex* __restrict__ idx) { + GPU_1D_KERNEL_LOOP(i, input_size) { + idx[sorted_input_inds[i]] = + inv_sorted_unique_perm[sorted_input_unique_ids[i]]; + } +} + +// Maps the values of sorted_input_unique_ids and scatters them to idx using +// sorted_input_inds. +template +Status LookupAndScatterUniqueIds(const GPUDevice& d, int64_t input_size, + const TIndex* sorted_input_inds, + const TIndex* sorted_input_unique_ids, + const TIndex* inv_sorted_unique_perm, + TIndex* idx) { + CHECK_GT(input_size, 0); // Crash OK + GpuLaunchConfig config = GetGpuLaunchConfig( + input_size, d, &LookupAndScatterUniqueIdsKernel, + /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0); + return GpuLaunchKernel(LookupAndScatterUniqueIdsKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), input_size, sorted_input_inds, + sorted_input_unique_ids, inv_sorted_unique_perm, idx); +} + +} // namespace unique_op_gpu + +// This only supports Unique[WithCounts], not Unique[WithCounts]V2. 
+template +class UniqueOpGPU : public AsyncOpKernel { + public: + explicit UniqueOpGPU(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + + template + void AllocateTemp(OpKernelContext* context, int64_t size, Tensor* tensor, + U** tensor_data, DoneCallback done) const { + OP_REQUIRES_OK_ASYNC(context, + context->allocate_temp(DataTypeToEnum::value, + TensorShape({size}), tensor), + done); + *tensor_data = tensor->flat().data(); + } + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override { + const Tensor& input = context->input(0); + // TODO(dga): Make unique polymorphic for returning int32 and int64 + // vectors to support large tensors. + OP_REQUIRES_ASYNC(context, + input.NumElements() <= std::numeric_limits::max(), + errors::InvalidArgument( + "unique does not support input tensors larger than ", + std::numeric_limits::max(), " elements"), + done); + + OP_REQUIRES_ASYNC(context, TensorShapeUtils::IsVector(input.shape()), + errors::InvalidArgument("unique expects a 1D vector."), + done); + + se::Stream* stream = context->op_device_context()->stream(); + OP_REQUIRES_ASYNC(context, stream, + errors::Internal("No GPU stream available."), done); + + int64_t input_size = input.NumElements(); + bool has_count_output = num_outputs() > 2; + if (input_size == 0) { + // Early exit for trivial case. + Tensor* t = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(0, TensorShape({0}), &t), done); + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(1, TensorShape({0}), &t), done); + if (has_count_output) { + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(2, TensorShape({0}), &t), done); + } + done(); + return; + } + + // The algorithm implemented here is as follows: + // input = [3, 5, 3, 4, 1, 4, 9, 8, 6, 3, 5, 7, 8, 8, 4, 6, 4, 2, 5, 6] + // 1) Sort the input to group equal values together in segments. + // sorted_input, sorted_input_inds = sort(input) + // sorted_input: + // [1, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 8, 8, 8, 9] + // sorted_input_inds: + // [4, 17, 0, 2, 9, 3, 5, 14, 16, 1, 10, 18, 8, 15, 19, 11, 7, 12, 13, 6] + // 2) Identify the boundaries between segments and use prefix sum to + // compute the unique ID for each sorted value. + // sorted_input_unique_ids = prefix_sum(indicator(sorted_input)) + // indicator(sorted_input): + // [0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1] + // sorted_input_unique_ids: + // [0, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 7, 7, 7, 8] + // 3) Extract the input index of the first occurrence of each unique value. + // If counts are required, also extract the end index of each segment. + // unique_input_inds[sorted_input_unique_ids] = + // sorted_input_inds (@ indicator) + // segment_ends[sorted_input_unique_ids[i] - 1] = i (@ indicator) + // unique_input_inds: [4, 17, 0, 3, 1, 8, 11, 7, 6] + // segment_ends: [1, 2, 5, 9, 12, 15, 16, 19, 20] + // 4) Sort the extracted unique input indices to put them in order of + // first appearance. + // sorted_unique_input_inds, sorted_unique_perm = + // sort(unique_input_inds) + // sorted_unique_input_inds: [0, 1, 3, 4, 6, 7, 8, 11, 17] + // sorted_unique_perm: [2, 4, 3, 0, 8, 7, 5, 6, 1] + // 5) Gather the sorted unique input values to produce output, and invert + // the second sort permutation to produce an inverse ID mapping. If + // counts are required, also take the adjacent difference between + // segment_ends indices to produce counts. 
+ // output = input[sorted_unique_input_inds] + // inv_sorted_unique_perm[sorted_unique_perm[i]] = i + // counts = adjacent_difference(segment_ends) + // output: [3, 5, 4, 1, 9, 8, 6, 7, 2] + // inv_sorted_unique_perm: [3, 8, 0, 2, 1, 6, 7, 5, 4] + // counts: [3, 3, 4, 1, 1, 3, 3, 1, 1] + // 6) Look up unique IDs via the inverse ID mapping and scatter them using + // the original sort permutation to produce the indices output. + // idx[sorted_input_inds] = + // inv_sorted_unique_perm[sorted_input_unique_ids] + // idx: [0, 1, 0, 2, 3, 2, 4, 5, 6, 0, 1, 7, 5, 5, 2, 6, 2, 8, 1, 6] + + Tensor sorted_input_inds; + TIndex* sorted_input_inds_ptr = nullptr; + AllocateTemp(context, input_size, &sorted_input_inds, + &sorted_input_inds_ptr, done); + if (!context->status().ok()) return; + + Tensor sorted_input; + T* sorted_input_ptr = nullptr; + AllocateTemp(context, input_size, &sorted_input, &sorted_input_ptr, done); + if (!context->status().ok()) return; + + const T* input_ptr = input.flat().data(); + OP_REQUIRES_OK_ASYNC( + context, + GpuRadixSort(context, input_size, /*keys_in=*/input_ptr, + /*keys_out=*/sorted_input_ptr, + /*indices_in=*/static_cast(nullptr), + /*indices_out=*/sorted_input_inds_ptr), + done); + + using namespace unique_op_gpu; + + // Create a fancy input iterator to indicate segment boundaries. + gpuprim::CountingInputIterator counting_iter(0); + gpuprim::TransformInputIterator, + gpuprim::CountingInputIterator> + segment_indicator_iter(counting_iter, {sorted_input_ptr}); + + Tensor sorted_input_unique_ids; + TIndex* sorted_input_unique_ids_ptr = nullptr; + AllocateTemp(context, input_size, &sorted_input_unique_ids, + &sorted_input_unique_ids_ptr, done); + if (!context->status().ok()) return; + + OP_REQUIRES_OK_ASYNC( + context, + GpuInclusivePrefixSum(context, input_size, segment_indicator_iter, + sorted_input_unique_ids_ptr), + done); + + // Copy the last element of sorted_input_unique_ids back to the host to + // obtain uniq_size. 
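+    // (The unique ids come from an inclusive prefix sum whose first entry is
+    // 0, so the number of unique values is the last id plus one.)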
+ ScratchSpace last_idx_host(context, 1, /*on_host=*/true); + OP_REQUIRES_OK_ASYNC( + context, + stream->Memcpy(last_idx_host.mutable_data(), + se::DeviceMemoryBase( + const_cast(sorted_input_unique_ids_ptr) + + (input_size - 1), + sizeof(*last_idx_host.data())), + sizeof(*last_idx_host.data())), + done); + + auto async_finish_computation = [this, context, input_size, input_ptr, + sorted_input_inds, sorted_input_inds_ptr, + sorted_input_unique_ids, + sorted_input_unique_ids_ptr, last_idx_host, + has_count_output, done]() -> void { + const GPUDevice& device = context->eigen_gpu_device(); + int64 uniq_size = (*last_idx_host.data()) + 1; + + std::unique_ptr scoped_activation = + context->op_device_context()->stream()->parent()->Activate(); + + Tensor unique_input_inds; + TIndex* unique_input_inds_ptr = nullptr; + AllocateTemp(context, uniq_size, &unique_input_inds, + &unique_input_inds_ptr, done); + if (!context->status().ok()) return; + + Tensor segment_ends; + TIndex* segment_ends_ptr = nullptr; + if (has_count_output) { + AllocateTemp(context, uniq_size, &segment_ends, &segment_ends_ptr, + done); + if (!context->status().ok()) return; + } + + OP_REQUIRES_OK_ASYNC( + context, + ExtractFirstOccurrenceIndices( + device, input_size, uniq_size, sorted_input_inds_ptr, + sorted_input_unique_ids_ptr, unique_input_inds_ptr, + segment_ends_ptr), + done); + + Tensor sorted_unique_input_inds; + TIndex* sorted_unique_input_inds_ptr = nullptr; + AllocateTemp(context, uniq_size, &sorted_unique_input_inds, + &sorted_unique_input_inds_ptr, done); + if (!context->status().ok()) return; + + Tensor sorted_unique_perm; + TIndex* sorted_unique_perm_ptr = nullptr; + AllocateTemp(context, uniq_size, &sorted_unique_perm, + &sorted_unique_perm_ptr, done); + if (!context->status().ok()) return; + + // Sort by input index so that output is in order of appearance. + OP_REQUIRES_OK_ASYNC( + context, + GpuRadixSort(context, uniq_size, + /*keys_in=*/unique_input_inds_ptr, + /*keys_out=*/sorted_unique_input_inds_ptr, + /*indices_in=*/static_cast(nullptr), + /*indices_out=*/sorted_unique_perm_ptr, + /*num_bits=*/Log2Ceiling(input_size)), + done); + + // Free temporary tensor that is no longer needed. + unique_input_inds = Tensor(); + unique_input_inds_ptr = nullptr; + + Tensor* output = nullptr; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output(0, TensorShape({uniq_size}), &output), done); + T* output_ptr = output->flat().data(); + + Tensor inv_sorted_unique_perm; + TIndex* inv_sorted_unique_perm_ptr = nullptr; + AllocateTemp(context, uniq_size, &inv_sorted_unique_perm, + &inv_sorted_unique_perm_ptr, done); + if (!context->status().ok()) return; + + TIndex* count_ptr = nullptr; + if (has_count_output) { + Tensor* count = nullptr; + OP_REQUIRES_OK_ASYNC( + context, + context->allocate_output(2, TensorShape({uniq_size}), &count), + done); + count_ptr = count->flat().data(); + } + + // Compute output and counts (if necessary). + OP_REQUIRES_OK_ASYNC( + context, + GatherOutputsAndInvertPermutation( + device, uniq_size, input_ptr, sorted_unique_input_inds_ptr, + sorted_unique_perm_ptr, segment_ends_ptr, output_ptr, + inv_sorted_unique_perm_ptr, count_ptr), + done); + + // Free temporary tensors that are no longer needed. 
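+      // (Assigning a default-constructed Tensor releases this reference to the
+      // underlying buffers, so the memory can be reused before the remaining
+      // outputs are allocated.)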
+ sorted_unique_perm = Tensor(); + sorted_unique_perm_ptr = nullptr; + sorted_unique_input_inds = Tensor(); + sorted_unique_input_inds_ptr = nullptr; + segment_ends = Tensor(); + segment_ends_ptr = nullptr; + + Tensor* idx = nullptr; + OP_REQUIRES_OK_ASYNC( + context, context->allocate_output(1, TensorShape({input_size}), &idx), + done); + TIndex* idx_ptr = idx->flat().data(); + + // Compute indices output. + OP_REQUIRES_OK_ASYNC( + context, + LookupAndScatterUniqueIds(device, input_size, sorted_input_inds_ptr, + sorted_input_unique_ids_ptr, + inv_sorted_unique_perm_ptr, idx_ptr), + done); + + done(); + }; + + context->device() + ->tensorflow_accelerator_device_info() + ->event_mgr->ThenExecute(stream, async_finish_computation); + } +}; + +} // end namespace tensorflow + +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_KERNELS_UNIQUE_OP_GPU_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/variable_ops.h b/third_party/tflite-hdrs/tensorflow/core/kernels/variable_ops.h new file mode 100644 index 00000000..035b583a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/variable_ops.h @@ -0,0 +1,47 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/resource_var.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class VariableOp : public OpKernel { + public: + explicit VariableOp(OpKernelConstruction* context); + void Compute(OpKernelContext* ctx) override; + + private: + DataType dtype_; + TensorShape shape_; + ContainerInfo cinfo_; + + VariableOp(const VariableOp&) = delete; + void operator=(const VariableOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/variant_ops_util.h b/third_party/tflite-hdrs/tensorflow/core/kernels/variant_ops_util.h new file mode 100644 index 00000000..d6d1e831 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/variant_ops_util.h @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_VARIANT_OPS_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_VARIANT_OPS_UTIL_H_ + +#include + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class OpKernelContext; +class Tensor; +class Variant; + +void AddNVariant(OpKernelContext* ctx, + std::function + binary_add_variant); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_VARIANT_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/where_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/where_op.h new file mode 100644 index 00000000..fceea011 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/where_op.h @@ -0,0 +1,65 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_H_ +#define TENSORFLOW_CORE_KERNELS_WHERE_OP_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +#define TF_CALL_WHERE_GPU_TYPES(m) \ + TF_CALL_int8(m); \ + TF_CALL_uint8(m); \ + TF_CALL_int64(m); \ + TF_CALL_float(m); \ + TF_CALL_double(m); \ + TF_CALL_complex64(m); \ + TF_CALL_complex128(m); \ + TF_CALL_bool(m); + +namespace functor { + +template +struct NumTrue { + EIGEN_ALWAYS_INLINE static absl::Status Compute( + OpKernelContext* ctx, const Device& d, + typename TTypes::ConstFlat input, + typename TTypes::UnalignedScalar num_true); +}; + +template +struct Where { + // Copies indices of true values in input into output. The pointer + // found_true should sit on the host. Compute should copy the + // number of true elements found into it. At the end, if + // *found_true != output.dimension(0), + // then the input may have changed between the initial counting of + // the true values and the call to Where. 
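+  // For example, with input = [true, false, false, true, true], NumTrue
+  // reports 3, output is allocated with 3 rows, and Compute writes the flat
+  // indices 0, 3 and 4; a *found_true of anything other than 3 then signals
+  // that the input changed in between.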
+ EIGEN_ALWAYS_INLINE static absl::Status Compute( + OpKernelContext* ctx, const Device& d, + typename TTypes::ConstTensor input, + typename TTypes::Matrix output, TIndex* found_true); +}; + +} // namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/where_op_gpu.cu.h b/third_party/tflite-hdrs/tensorflow/core/kernels/where_op_gpu.cu.h new file mode 100644 index 00000000..5eb03ec6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/where_op_gpu.cu.h @@ -0,0 +1,353 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_ +#define TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#define EIGEN_USE_GPU + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/kernels/where_op.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +template +__global__ void PropagateWhereIndicesKernel( + const TIndex output_rows, const typename Eigen::array strides, + int64* __restrict__ output) { + // TODO(ebrevdo): Use a multi-dimensional loop, increasing the + // dimensions of individual indices manually, instead of relying on + // a scalar loop variable and using integer division. 
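+  // The first element of each NDIM-wide output row holds a flat row-major
+  // index written by the flagged-select step (see Where::Compute below); the
+  // loop rewrites it in place as NDIM coordinates via repeated division by
+  // the row-major strides.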
+ GPU_1D_KERNEL_LOOP(i, output_rows) { + TIndex index_value = ldg(output + NDIM * i); +#pragma unroll + for (int c = 0; c < NDIM; ++c) { + *(output + NDIM * i + c) = index_value / strides[c]; + index_value %= strides[c]; + } + } +} + +namespace { + +template +struct IsNonzero { + EIGEN_DEVICE_FUNC IsNonzero() : zero(T(0)) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x) const { + return (x != zero); + } + const T zero; +}; + +template +struct CubDeviceReduceCount { + gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes, + const T* d_in, TIndex* d_out, int num_items, + gpuStream_t stream = 0, + bool debug_synchronous = false) { + IsNonzero is_nonzero; + gpuprim::TransformInputIterator, const T*> + is_nonzero_iter(d_in, is_nonzero); + return gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, + is_nonzero_iter, d_out, num_items, stream, + debug_synchronous); + } +}; + +template +struct CubDeviceReduceCount { + gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes, + const bool* d_in, TIndex* d_out, int num_items, + gpuStream_t stream = 0, + bool debug_synchronous = false) { + return gpuprim::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, d_in, + d_out, num_items, stream, + debug_synchronous); + } +}; + +template +struct CubDeviceSelectFlaggedCounter; + +template +struct CubDeviceSelectFlaggedCounter { + gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes, + const T* d_flags, OutputIterator d_out, + TIndex* d_num_selected_out, int num_items, + gpuStream_t stream = 0, + bool debug_synchronous = false) { + gpuprim::CountingInputIterator select_counter(0); + IsNonzero is_nonzero; + gpuprim::TransformInputIterator, const T*> + is_nonzero_iter(d_flags, is_nonzero); + return gpuprim::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, select_counter /*d_in*/, + is_nonzero_iter /*d_flags*/, d_out, d_num_selected_out, num_items, + stream, debug_synchronous); + } +}; + +template +struct CubDeviceSelectFlaggedCounter { + gpuError_t operator()(void* d_temp_storage, size_t& temp_storage_bytes, + const T* d_flags, OutputIterator d_out, + TIndex* d_num_selected_out, int num_items, + gpuStream_t stream = 0, + bool debug_synchronous = false) { + gpuprim::CountingInputIterator select_counter(0); + return gpuprim::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, select_counter /*d_in*/, d_flags, + d_out, d_num_selected_out, num_items, stream, debug_synchronous); + } +}; + +} // namespace + +template +struct NumTrue { + EIGEN_ALWAYS_INLINE static Status Compute( + OpKernelContext* ctx, const GPUDevice& d, + typename TTypes::ConstFlat input, + typename TTypes::UnalignedScalar num_true) { + const auto& cu_stream = GetGpuStream(ctx); + + std::size_t temp_storage_bytes = 0; + const T* input_data = input.data(); + TIndex* num_true_data = num_true.data(); + + // TODO(ebrevdo): sum doesn't work; perhaps need a different + // iterator? 
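+    // Standard two-phase gpuprim pattern: the first call, with a null
+    // temp-storage pointer, only computes temp_storage_bytes; the actual
+    // reduction happens in the second call once the scratch tensor exists.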
+ auto reducer = CubDeviceReduceCount(); + auto first_success = reducer(/*temp_storage*/ nullptr, temp_storage_bytes, + /*d_in*/ input_data, + /*d_out*/ num_true_data, + /*num_items*/ input.size(), + /*stream*/ cu_stream); + + if (first_success != gpuSuccess) { + return errors::Internal( + "WhereOp: Could not launch gpuprim::DeviceReduce::Sum to calculate " + "temp_storage_bytes, status: ", + GpuGetErrorString(first_success)); + } + + Tensor temp_storage; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + + auto second_success = reducer( + /*temp_storage*/ temp_storage.flat().data(), temp_storage_bytes, + /*d_in*/ input_data, + /*d_out*/ num_true_data, + /*num_items*/ input.size(), + /*stream*/ cu_stream); + + if (second_success != gpuSuccess) { + return errors::Internal( + "WhereOp: Could not launch gpuprim::DeviceReduce::Sum to count " + "number of true / nonzero indices. temp_storage_bytes: ", + temp_storage_bytes, ", status: ", GpuGetErrorString(second_success)); + } + + return OkStatus(); + } +}; + +#define NUMTRUE_GPU_FUNCTOR(T) \ + template struct NumTrue; \ + template struct NumTrue; + +// We only need to declare the NumTrue functor once, but this file is +// included from where_op_gpu_impl_X.cu.cc for X=1,2,... +// Only declare for X = 1. +#if GPU_PROVIDED_DIM == 1 + +TF_CALL_WHERE_GPU_TYPES(NUMTRUE_GPU_FUNCTOR); + +#endif // GPU_PROVIDED_DIM == 1 + +#undef NUMTRUE_GPU_FUNCTOR + +template +class WhereOutputIterator { + public: + // Required iterator traits + typedef WhereOutputIterator self_type; + typedef std::ptrdiff_t difference_type; + typedef void value_type; + typedef void pointer; + typedef int64& reference; + +#if (THRUST_VERSION >= 100700) + // Use Thrust's iterator categories so we can use these iterators in Thrust + // 1.7 (or newer) methods + typedef typename thrust::detail::iterator_facade_category< + thrust::device_system_tag, thrust::random_access_traversal_tag, + value_type, + reference>::type iterator_category; ///< The iterator category +#else + typedef std::random_access_iterator_tag + iterator_category; ///< The iterator category +#endif // THRUST_VERSION + + WhereOutputIterator(int64* ptr, const Eigen::DenseIndex max_row) + : ptr_(ptr), max_row_(max_row) {} + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int64& operator[](int n) const { + // If the selection mechanism finds too many true values (because + // the input tensor changed between allocation of output and now), + // we may accidentally try to write past the allowable memory. If + // valid is false, then we don't do this. Instead, we'll read off + // the number of items found in Flagged()'s d_num_selected_out at + // the end and confirm that it matches the number of rows of output. + const bool valid = FastBoundsCheck(n, max_row_); + return *(ptr_ + (valid ? 
(NDIM * n) : 0)); + } + + private: + int64* ptr_; + const Eigen::DenseIndex max_row_; +}; + +template +Eigen::array CalculateStrides( + typename TTypes::ConstTensor input) { + const Eigen::DSizes dims = input.dimensions(); + Eigen::array strides; + EIGEN_STATIC_ASSERT((static_cast(decltype(input)::Layout) == + static_cast(Eigen::RowMajor)), + INTERNAL_ERROR_INPUT_SHOULD_BE_ROWMAJOR); + strides[NDIM - 1] = 1; + for (int i = NDIM - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * dims[i + 1]; + } + return strides; +} + +template +struct Where { + EIGEN_ALWAYS_INLINE static Status Compute( + OpKernelContext* ctx, const GPUDevice& d, + typename TTypes::ConstTensor input, + typename TTypes::Matrix output, TIndex* found_true_host) { + if (output.dimension(0) == 0) { + // Nothing to do. + return OkStatus(); + } + + const auto& cu_stream = GetGpuStream(ctx); + + std::size_t temp_storage_bytes = 0; + + Tensor found_true_t; + TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum::v(), + TensorShape({}), &found_true_t)); + TIndex* found_true_device = found_true_t.scalar().data(); + + WhereOutputIterator output_iterator( + output.data(), + /* max_row */ output.dimension(0)); + + typedef std::decay DT; + CubDeviceSelectFlaggedCounter< + T, TIndex, decltype(output_iterator) /*OutputIterator*/, + std::is_convertible::value /*IsConvertibleToBool*/> + counter; + auto first_success = counter(/*temp_storage*/ nullptr, temp_storage_bytes, + /*d_flags*/ input.data(), + /*d_out*/ output_iterator, + /*d_num_selected_out*/ found_true_device, + /*num_items*/ input.size(), + /*stream*/ cu_stream); + if (first_success != gpuSuccess) { + return errors::Internal( + "WhereOp: Could not launch gpuprim::DeviceSelect::Flagged to " + "calculate " + "temp_storage_bytes, status: ", + GpuGetErrorString(first_success)); + } + + Tensor temp_storage; + TF_RETURN_IF_ERROR(ctx->allocate_temp( + DT_INT8, TensorShape({static_cast(temp_storage_bytes)}), + &temp_storage)); + + auto second_success = counter( + /*temp_storage*/ temp_storage.flat().data(), temp_storage_bytes, + /*d_flags*/ input.data(), + /*d_out*/ output_iterator, + /*d_num_selected_out*/ found_true_device, + /*num_items*/ input.size(), + /*stream*/ cu_stream); + + if (second_success != gpuSuccess) { + return errors::Internal( + "WhereOp: Could not launch gpuprim::DeviceSelect::Flagged to copy " + "indices out, status: ", + GpuGetErrorString(second_success)); + } + + // TODO(ebrevdo): Find a way to synchronously copy back data from + // found_true_device to *found_true_host. 
+ + const Eigen::array strides = + CalculateStrides(input); + const TIndex output_rows = output.dimension(0); + GpuLaunchConfig config = GetGpuLaunchConfig(output_rows, d); + TF_CHECK_OK(GpuLaunchKernel(PropagateWhereIndicesKernel, + config.block_count, config.thread_per_block, 0, + d.stream(), output_rows, strides, + output.data())); + + return OkStatus(); + } +}; + +#define DECLARE_GPU_SPEC_INDEX(Dims, T, TIndex) \ + template struct Where + +#define DECLARE_GPU_SPEC(T) \ + DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int32); \ + DECLARE_GPU_SPEC_INDEX(GPU_PROVIDED_DIM, T, int64) + +TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC); + +#undef DECLARE_GPU_SPEC +#undef DECLARE_GPU_SPEC_INDEX + +} // namespace functor + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/winograd_transform.h b/third_party/tflite-hdrs/tensorflow/core/kernels/winograd_transform.h new file mode 100644 index 00000000..4f4067e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/kernels/winograd_transform.h @@ -0,0 +1,377 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ +#define TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ + +#include "tensorflow/core/kernels/deep_conv2d.h" + +namespace tensorflow { + +// Winograd DeepConv2DTransform implementation for 3x3 filters. 
+// Details: +// *) Arithmetic complexity of computations: Shmuel Winograd +// *) Fast Algorithms for Convolutional Neural Networks: Lavin, Gray + +template +class WinogradTransform : public DeepConv2DTransform { + public: + typedef typename DeepConv2DTransform::Shape Shape; + + WinogradTransform() + : filter_shape_(3, 3), input_shape_(4, 4), output_shape_(2, 2) {} + + virtual void GetFilterTransformMatrix(const int64_t rows, const int64_t cols, + T* transform_matrix) const; + + virtual void GetInputTransformMatrix(const int64_t rows, const int64_t cols, + T* transform_matrix) const; + + virtual void GetOutputTransformMatrix(const int64_t rows, const int64_t cols, + T* transform_matrix) const; + + virtual const Shape& filter_shape() const { return filter_shape_; } + virtual const Shape& input_shape() const { return input_shape_; } + virtual const Shape& output_shape() const { return output_shape_; } + + private: + const Shape filter_shape_; + const Shape input_shape_; + const Shape output_shape_; +}; + +// The filter transform matrix is the kronecker product 'M * M' of the +// following matrix 'M': +// +// [ 1 0 0 ] +// [ 1/2 1/2 1/2 ] +// [ 1/2 -1/2 1/2 ] +// [ 0 0 1 ] +// +// The data layout of 'transform_matrix': +// [input_tile_spatial_size, filter_spatial_size] +// +template +void WinogradTransform::GetFilterTransformMatrix(const int64_t rows, + const int64_t cols, + T* transform_matrix) const { + CHECK_GT(rows, 0); + CHECK_GT(cols, 0); + memset(transform_matrix, 0, sizeof(T) * rows * cols); + + // Sub matrix [0,0] + transform_matrix[0 * cols + 0] = T(1.0); + + transform_matrix[1 * cols + 0] = T(0.5); + transform_matrix[1 * cols + 1] = T(0.5); + transform_matrix[1 * cols + 2] = T(0.5); + + transform_matrix[2 * cols + 0] = T(0.5); + transform_matrix[2 * cols + 1] = T(-0.5); + transform_matrix[2 * cols + 2] = T(0.5); + + transform_matrix[3 * cols + 2] = T(1.0); + + // Sub matrix [1,0] + transform_matrix[4 * cols + 0] = T(0.5); + + transform_matrix[5 * cols + 0] = T(0.25); + transform_matrix[5 * cols + 1] = T(0.25); + transform_matrix[5 * cols + 2] = T(0.25); + + transform_matrix[6 * cols + 0] = T(0.25); + transform_matrix[6 * cols + 1] = T(-0.25); + transform_matrix[6 * cols + 2] = T(0.25); + + transform_matrix[7 * cols + 2] = T(0.5); + + // Sub matrix [1,1] + transform_matrix[4 * cols + 3] = T(0.5); + + transform_matrix[5 * cols + 3] = T(0.25); + transform_matrix[5 * cols + 4] = T(0.25); + transform_matrix[5 * cols + 5] = T(0.25); + + transform_matrix[6 * cols + 3] = T(0.25); + transform_matrix[6 * cols + 4] = T(-0.25); + transform_matrix[6 * cols + 5] = T(0.25); + + transform_matrix[7 * cols + 5] = T(0.5); + + // Sub matrix [1,2] + transform_matrix[4 * cols + 6] = T(0.5); + + transform_matrix[5 * cols + 6] = T(0.25); + transform_matrix[5 * cols + 7] = T(0.25); + transform_matrix[5 * cols + 8] = T(0.25); + + transform_matrix[6 * cols + 6] = T(0.25); + transform_matrix[6 * cols + 7] = T(-0.25); + transform_matrix[6 * cols + 8] = T(0.25); + + transform_matrix[7 * cols + 8] = T(0.5); + + // Sub matrix [2,0] + transform_matrix[8 * cols + 0] = T(0.5); + + transform_matrix[9 * cols + 0] = T(0.25); + transform_matrix[9 * cols + 1] = T(0.25); + transform_matrix[9 * cols + 2] = T(0.25); + + transform_matrix[10 * cols + 0] = T(0.25); + transform_matrix[10 * cols + 1] = T(-0.25); + transform_matrix[10 * cols + 2] = T(0.25); + + transform_matrix[11 * cols + 2] = T(0.5); + + // Sub matrix [2,1] + transform_matrix[8 * cols + 3] = T(-0.5); + + transform_matrix[9 * cols + 3] = T(-0.25); + 
transform_matrix[9 * cols + 4] = T(-0.25); + transform_matrix[9 * cols + 5] = T(-0.25); + + transform_matrix[10 * cols + 3] = T(-0.25); + transform_matrix[10 * cols + 4] = T(0.25); + transform_matrix[10 * cols + 5] = T(-0.25); + + transform_matrix[11 * cols + 5] = T(-0.5); + + // Sub matrix [2,2] + transform_matrix[8 * cols + 6] = T(0.5); + + transform_matrix[9 * cols + 6] = T(0.25); + transform_matrix[9 * cols + 7] = T(0.25); + transform_matrix[9 * cols + 8] = T(0.25); + + transform_matrix[10 * cols + 6] = T(0.25); + transform_matrix[10 * cols + 7] = T(-0.25); + transform_matrix[10 * cols + 8] = T(0.25); + + transform_matrix[11 * cols + 8] = T(0.5); + + // Sub matrix [3,2] + transform_matrix[12 * cols + 6] = T(1.0); + + transform_matrix[13 * cols + 6] = T(0.5); + transform_matrix[13 * cols + 7] = T(0.5); + transform_matrix[13 * cols + 8] = T(0.5); + + transform_matrix[14 * cols + 6] = T(0.5); + transform_matrix[14 * cols + 7] = T(-0.5); + transform_matrix[14 * cols + 8] = T(0.5); + + transform_matrix[15 * cols + 8] = T(1.0); +} + +// The input transform matrix is the kronecker product 'M * M' of the +// following matrix 'M': +// +// [1 0 -1 0] +// [0 1 1 0] +// [0 -1 1 0] +// [0 1 0 -1] +// +// Data layout of 'transform_matrix': +// [tile_spatial_size, tile_spatial_size] +// +template +void WinogradTransform::GetInputTransformMatrix(const int64_t rows, + const int64_t cols, + T* transform_matrix) const { + CHECK_GT(rows, 0); + CHECK_GT(cols, 0); + memset(transform_matrix, 0, sizeof(T) * rows * cols); + + // Sub matrix [0,0] + transform_matrix[0 * cols + 0] = T(1.0); + transform_matrix[0 * cols + 2] = T(-1.0); + + transform_matrix[1 * cols + 1] = T(1.0); + transform_matrix[1 * cols + 2] = T(1.0); + + transform_matrix[2 * cols + 1] = T(-1.0); + transform_matrix[2 * cols + 2] = T(1.0); + + transform_matrix[3 * cols + 1] = T(1.0); + transform_matrix[3 * cols + 3] = T(-1.0); + + // Sub matrix [0,2] + transform_matrix[0 * cols + 8] = T(-1.0); + transform_matrix[0 * cols + 10] = T(1.0); + + transform_matrix[1 * cols + 9] = T(-1.0); + transform_matrix[1 * cols + 10] = T(-1.0); + + transform_matrix[2 * cols + 9] = T(1.0); + transform_matrix[2 * cols + 10] = T(-1.0); + + transform_matrix[3 * cols + 9] = T(-1.0); + transform_matrix[3 * cols + 11] = T(1.0); + + // Sub matrix [1,1] + transform_matrix[4 * cols + 4] = T(1.0); + transform_matrix[4 * cols + 6] = T(-1.0); + + transform_matrix[5 * cols + 5] = T(1.0); + transform_matrix[5 * cols + 6] = T(1.0); + + transform_matrix[6 * cols + 5] = T(-1.0); + transform_matrix[6 * cols + 6] = T(1.0); + + transform_matrix[7 * cols + 5] = T(1.0); + transform_matrix[7 * cols + 7] = T(-1.0); + + // Sub matrix [1,2] + transform_matrix[4 * cols + 8] = T(1.0); + transform_matrix[4 * cols + 10] = T(-1.0); + + transform_matrix[5 * cols + 9] = T(1.0); + transform_matrix[5 * cols + 10] = T(1.0); + + transform_matrix[6 * cols + 9] = T(-1.0); + transform_matrix[6 * cols + 10] = T(1.0); + + transform_matrix[7 * cols + 9] = T(1.0); + transform_matrix[7 * cols + 11] = T(-1.0); + + // Sub matrix [2,1] + transform_matrix[8 * cols + 4] = T(-1.0); + transform_matrix[8 * cols + 6] = T(1.0); + + transform_matrix[9 * cols + 5] = T(-1.0); + transform_matrix[9 * cols + 6] = T(-1.0); + + transform_matrix[10 * cols + 5] = T(1.0); + transform_matrix[10 * cols + 6] = T(-1.0); + + transform_matrix[11 * cols + 5] = T(-1.0); + transform_matrix[11 * cols + 7] = T(1.0); + + // Sub matrix [2,2] + transform_matrix[8 * cols + 8] = T(1.0); + transform_matrix[8 * cols + 10] = T(-1.0); + + 
transform_matrix[9 * cols + 9] = T(1.0); + transform_matrix[9 * cols + 10] = T(1.0); + + transform_matrix[10 * cols + 9] = T(-1.0); + transform_matrix[10 * cols + 10] = T(1.0); + + transform_matrix[11 * cols + 9] = T(1.0); + transform_matrix[11 * cols + 11] = T(-1.0); + + // Sub matrix [3,1] + transform_matrix[12 * cols + 4] = T(1.0); + transform_matrix[12 * cols + 6] = T(-1.0); + + transform_matrix[13 * cols + 5] = T(1.0); + transform_matrix[13 * cols + 6] = T(1.0); + + transform_matrix[14 * cols + 5] = T(-1.0); + transform_matrix[14 * cols + 6] = T(1.0); + + transform_matrix[15 * cols + 5] = T(1.0); + transform_matrix[15 * cols + 7] = T(-1.0); + + // Sub matrix [3,3] + transform_matrix[12 * cols + 12] = T(-1.0); + transform_matrix[12 * cols + 14] = T(1.0); + + transform_matrix[13 * cols + 13] = T(-1.0); + transform_matrix[13 * cols + 14] = T(-1.0); + + transform_matrix[14 * cols + 13] = T(1.0); + transform_matrix[14 * cols + 14] = T(-1.0); + + transform_matrix[15 * cols + 13] = T(-1.0); + transform_matrix[15 * cols + 15] = T(1.0); +}; + +// The output transform matrix is the kronecker product 'M * M' of the +// following matrix 'M': +// +// [1 1 1 0] +// [0 1 -1 -1] +// +// Data layout of 'transform_matrix': +// [out_tile_spatial_size, tile_spatial_size] +// +template +void WinogradTransform::GetOutputTransformMatrix(const int64_t rows, + const int64_t cols, + T* transform_matrix) const { + CHECK_GT(rows, 0); + CHECK_GT(cols, 0); + memset(transform_matrix, 0, sizeof(T) * rows * cols); + + // Sub matrix [0,0] + transform_matrix[0 * cols + 0] = T(1.0); + transform_matrix[0 * cols + 1] = T(1.0); + transform_matrix[0 * cols + 2] = T(1.0); + + transform_matrix[1 * cols + 1] = T(1.0); + transform_matrix[1 * cols + 2] = T(-1.0); + transform_matrix[1 * cols + 3] = T(-1.0); + + // Sub matrix [0,1] + transform_matrix[0 * cols + 4] = T(1.0); + transform_matrix[0 * cols + 5] = T(1.0); + transform_matrix[0 * cols + 6] = T(1.0); + + transform_matrix[1 * cols + 5] = T(1.0); + transform_matrix[1 * cols + 6] = T(-1.0); + transform_matrix[1 * cols + 7] = T(-1.0); + + // Sub matrix [0,2] + transform_matrix[0 * cols + 8] = T(1.0); + transform_matrix[0 * cols + 9] = T(1.0); + transform_matrix[0 * cols + 10] = T(1.0); + + transform_matrix[1 * cols + 9] = T(1.0); + transform_matrix[1 * cols + 10] = T(-1.0); + transform_matrix[1 * cols + 11] = T(-1.0); + + // Sub matrix [1,1] + transform_matrix[2 * cols + 4] = T(1.0); + transform_matrix[2 * cols + 5] = T(1.0); + transform_matrix[2 * cols + 6] = T(1.0); + + transform_matrix[3 * cols + 5] = T(1.0); + transform_matrix[3 * cols + 6] = T(-1.0); + transform_matrix[3 * cols + 7] = T(-1.0); + + // Sub matrix [1,2] + transform_matrix[2 * cols + 8] = T(-1.0); + transform_matrix[2 * cols + 9] = T(-1.0); + transform_matrix[2 * cols + 10] = T(-1.0); + + transform_matrix[3 * cols + 9] = T(-1.0); + transform_matrix[3 * cols + 10] = T(1.0); + transform_matrix[3 * cols + 11] = T(1.0); + + // Sub matrix [1,3] + transform_matrix[2 * cols + 12] = T(-1.0); + transform_matrix[2 * cols + 13] = T(-1.0); + transform_matrix[2 * cols + 14] = T(-1.0); + + transform_matrix[3 * cols + 13] = T(-1.0); + transform_matrix[3 * cols + 14] = T(1.0); + transform_matrix[3 * cols + 15] = T(1.0); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_WINOGRAD_TRANSFORM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/kernels/xent_op.h b/third_party/tflite-hdrs/tensorflow/core/kernels/xent_op.h new file mode 100644 index 00000000..07870f50 --- /dev/null +++ 
b/third_party/tflite-hdrs/tensorflow/core/kernels/xent_op.h @@ -0,0 +1,115 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_XENT_OP_H_ +#define TENSORFLOW_CORE_KERNELS_XENT_OP_H_ +// Functor definition for XentOp, must be compilable by nvcc. + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive + +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Functor used by XentOp to do the computations. +template +struct XentFunctor { + // Computes Cross Entropy loss and backprop. + // + // logits: batch_size, num_classes. + // labels: batch_size, num_classes. + // scratch: temporary tensor, dims: batch_size, 1 + // loss: output tensor for the loss, dims: batch_size. + // backprop: output tensor for the backprop, dims: batch_size, num_classes. + void operator()(const Device &d, + const Eigen::DSizes &shape, + const Eigen::array &logits_bcast, + const Eigen::array &labels_bcast, + typename TTypes::ConstMatrix logits, + typename TTypes::ConstMatrix labels, + typename TTypes::Matrix scratch, + typename TTypes::Vec loss, + typename TTypes::Matrix backprop); +}; + +// Eigen code implementing XentFunctor::operator(). +// This code works for both CPU and GPU and is used by the functor +// specializations for both device types. +template +struct XentEigenImpl { + static void Compute(const Device &d, + const Eigen::DSizes &shape, + const Eigen::array &logits_bcast, + const Eigen::array &labels_bcast, + typename TTypes::ConstMatrix logits, + typename TTypes::ConstMatrix labels, + typename TTypes::Matrix scratch, + typename TTypes::Vec loss, + typename TTypes::Matrix backprop) { + // NOTE(touts): This duplicates some of the computations in softmax_op + // because we need the intermediate (logits -max(logits)) values to + // avoid a log(exp()) in the computation of the loss. + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = shape[kBatchDim]; + const int num_classes = shape[kClassDim]; + +// These arrays are used to reduce along the class dimension, and broadcast +// the resulting value to all classes. + Eigen::IndexList > along_class; + Eigen::IndexList > batch_by_one; + batch_by_one.set(0, batch_size); + Eigen::IndexList batch_only; + batch_only.set(0, batch_size); + Eigen::IndexList, int> one_by_class; + one_by_class.set(1, num_classes); + + // max_logits along classes. + scratch.reshape(batch_only).device(d) = + logits.broadcast(logits_bcast).maximum(along_class); + + // logits - max_logits. + backprop.device(d) = + logits.broadcast(logits_bcast) - scratch.broadcast(one_by_class); + + // sum(exp(logits - max_logits)) along classes. + scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class); + + // NOTE(keveman): Eigen on GPU dispatches to an optimized implementation + // for an expression of the form lhs = rhs.sum(). 
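The functor above is the Eigen-vectorized form of a small amount of scalar arithmetic. As a reading aid, here is a one-row scalar restatement of the same steps (max shift, log-sum-exp, loss, gradient); XentRow is a hypothetical helper for illustration, not a TensorFlow API.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar version of the computation in XentEigenImpl::Compute for one row:
// shift by max(logits) and reuse the shifted logits for both loss and gradient.
float XentRow(const std::vector<float>& logits,
              const std::vector<float>& labels,
              std::vector<float>* backprop) {
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  backprop->resize(logits.size());
  float sum_exp = 0.0f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    (*backprop)[i] = logits[i] - max_logit;   // logits - max(logits)
    sum_exp += std::exp((*backprop)[i]);
  }
  const float log_sum_exp = std::log(sum_exp);
  float loss = 0.0f;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    // -labels * ((logits - max) - log(sum(exp(logits - max))))
    loss += labels[i] * (log_sum_exp - (*backprop)[i]);
    // prob - labels, with prob = exp(logits - max) / sum(exp(logits - max))
    (*backprop)[i] = std::exp((*backprop)[i]) / sum_exp - labels[i];
  }
  return loss;
}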
+ // lhs = -rhs.sum() doesn't match the above pattern, so folding in the + // negation before calling sum(). + // sum(-labels * + // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) + // along classes + loss.device(d) = (labels.broadcast(labels_bcast) * + (scratch.log().eval().broadcast(one_by_class) - backprop)) + .eval() + .sum(along_class); + + // backprop: prob - labels, where + // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) + backprop.device(d) = (backprop.exp() / scratch.broadcast(one_by_class)) - + labels.broadcast(labels_bcast); + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_XENT_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/bfloat16/bfloat16.h b/third_party/tflite-hdrs/tensorflow/core/lib/bfloat16/bfloat16.h new file mode 100644 index 00000000..d6ac77b6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/bfloat16/bfloat16.h @@ -0,0 +1,21 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ +#define TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ + +#include "tensorflow/core/platform/bfloat16.h" + +#endif // TENSORFLOW_CORE_LIB_BFLOAT16_BFLOAT16_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/arena.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/arena.h new file mode 100644 index 00000000..14d80422 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/arena.h @@ -0,0 +1,111 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TODO(vrv): Switch this to an open-sourced version of Arena. + +#ifndef TENSORFLOW_CORE_LIB_CORE_ARENA_H_ +#define TENSORFLOW_CORE_LIB_CORE_ARENA_H_ + +#include + +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace core { + +// This class is "thread-compatible": different threads can access the +// arena at the same time without locking, as long as they use only +// const methods. +class Arena { + public: + // Allocates a thread-compatible arena with the specified block size. 
+ explicit Arena(const size_t block_size); + ~Arena(); + + char* Alloc(const size_t size) { + return reinterpret_cast(GetMemory(size, 1)); + } + + char* AllocAligned(const size_t size, const size_t alignment) { + return reinterpret_cast(GetMemory(size, alignment)); + } + + void Reset(); + +// This should be the worst-case alignment for any type. This is +// good for IA-32, SPARC version 7 (the last one I know), and +// supposedly Alpha. i386 would be more time-efficient with a +// default alignment of 8, but ::operator new() uses alignment of 4, +// and an assertion will fail below after the call to MakeNewBlock() +// if you try to use a larger alignment. +#ifdef __i386__ + static const int kDefaultAlignment = 4; +#else + static constexpr int kDefaultAlignment = 8; +#endif + + protected: + bool SatisfyAlignment(const size_t alignment); + void MakeNewBlock(const uint32 alignment); + void* GetMemoryFallback(const size_t size, const int align); + void* GetMemory(const size_t size, const int align) { + assert(remaining_ <= block_size_); // an invariant + if (size > 0 && size < remaining_ && align == 1) { // common case + void* result = freestart_; + freestart_ += size; + remaining_ -= size; + return result; + } + return GetMemoryFallback(size, align); + } + + size_t remaining_; + + private: + struct AllocatedBlock { + char* mem; + size_t size; + }; + + // Allocate new block of at least block_size, with the specified + // alignment. + // The returned AllocatedBlock* is valid until the next call to AllocNewBlock + // or Reset (i.e. anything that might affect overflow_blocks_). + AllocatedBlock* AllocNewBlock(const size_t block_size, + const uint32 alignment); + + const size_t block_size_; + char* freestart_; // beginning of the free space in most recent block + char* freestart_when_empty_; // beginning of the free space when we're empty + // STL vector isn't as efficient as it could be, so we use an array at first + size_t blocks_alloced_; // how many of the first_blocks_ have been alloced + AllocatedBlock first_blocks_[16]; // the length of this array is arbitrary + // if the first_blocks_ aren't enough, expand into overflow_blocks_. + std::vector* overflow_blocks_; + + void FreeBlocks(); // Frees all except first block + + Arena(const Arena&) = delete; + void operator=(const Arena&) = delete; +}; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_CORE_ARENA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/bitmap.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/bitmap.h new file mode 100644 index 00000000..86e825db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/bitmap.h @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
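A hypothetical caller of the Arena interface above could look like the sketch below; the block size, allocation sizes, and alignment are invented for illustration. Memory is only reclaimed in bulk, via Reset() or the destructor.

#include "tensorflow/core/lib/core/arena.h"

void ArenaDemo() {
  tensorflow::core::Arena arena(8 << 10);   // 8 KB blocks
  char* a = arena.Alloc(128);               // unaligned fast path
  char* b = arena.AllocAligned(256, 16);    // 16-byte aligned
  // ... fill a and b; there is no per-allocation free ...
  (void)a;
  (void)b;
  arena.Reset();                            // drop everything, keep the first block
}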
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ +#define TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ + +#include "xla/tsl/lib/core/bitmap.h" + +namespace tensorflow { +namespace core { + +using Bitmap = tsl::core::Bitmap; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_CORE_BITMAP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/bits.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/bits.h new file mode 100644 index 00000000..8bcc448b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/bits.h @@ -0,0 +1,42 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_BITS_H_ +#define TENSORFLOW_CORE_LIB_CORE_BITS_H_ + +#include "xla/tsl/lib/core/bits.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// NOLINTBEGIN(misc-unused-using-decls) + +// Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. +using ::tsl::Log2Floor; +using ::tsl::Log2Floor64; + +// Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0. +using ::tsl::Log2Ceiling; +using ::tsl::Log2Ceiling64; + +using ::tsl::NextPowerOfTwo; +using ::tsl::NextPowerOfTwo64; + +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_CORE_BITS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/coding.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/coding.h new file mode 100644 index 00000000..47b645eb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/coding.h @@ -0,0 +1,26 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
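A quick illustration of the bits.h contracts quoted above, using the re-exported names; BitsDemo itself is just a scratch example.

#include "tensorflow/core/lib/core/bits.h"

void BitsDemo() {
  int f = tensorflow::Log2Floor(40);    // 5, since 2^5 = 32 <= 40 < 2^6
  int c = tensorflow::Log2Ceiling(40);  // 6, since 40 <= 2^6 = 64
  int z = tensorflow::Log2Floor(0);     // -1 by the documented contract
  (void)f;
  (void)c;
  (void)z;
}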
+==============================================================================*/ + +// Endian-neutral encoding: +// * Fixed-length numbers are encoded with least-significant byte first +// * In addition we support variable length "varint" encoding +// * Strings are encoded prefixed by their length in varint format + +#ifndef TENSORFLOW_CORE_LIB_CORE_CODING_H_ +#define TENSORFLOW_CORE_LIB_CORE_CODING_H_ + +#include "tensorflow/core/platform/coding.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_CODING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/errors.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/errors.h new file mode 100644 index 00000000..94154429 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/errors.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_ERRORS_H_ +#define TENSORFLOW_CORE_LIB_CORE_ERRORS_H_ + +#include "tensorflow/core/platform/errors.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_ERRORS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/notification.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/notification.h new file mode 100644 index 00000000..c22f695f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/notification.h @@ -0,0 +1,23 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_ +#define TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_ + +// Notification implementation is platform-dependent, to support +// alternative synchronization primitives. +#include "tensorflow/core/platform/notification.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/raw_coding.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/raw_coding.h new file mode 100644 index 00000000..b4adbb7f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/raw_coding.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
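The scheme described at the top of coding.h can be restated from scratch. The sketch below hand-rolls the two encodings purely for illustration; it does not claim the exact signatures of the helpers that platform/coding.h re-exports, and EncodeFixed32LE / PutVarint32Demo are made-up names.

#include <cstdint>
#include <string>

// Fixed-length: least-significant byte first, regardless of host endianness.
void EncodeFixed32LE(uint32_t v, std::string* dst) {
  for (int i = 0; i < 4; ++i) {
    dst->push_back(static_cast<char>((v >> (8 * i)) & 0xff));
  }
}

// Varint: 7 payload bits per byte, high bit set on every byte except the last,
// so values < 128 cost a single byte. A length-prefixed string is then just a
// varint length followed by the raw bytes.
void PutVarint32Demo(uint32_t v, std::string* dst) {
  while (v >= 0x80) {
    dst->push_back(static_cast<char>((v & 0x7f) | 0x80));
    v >>= 7;
  }
  dst->push_back(static_cast<char>(v));
}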
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_ +#define TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_ + +#include "tensorflow/core/platform/raw_coding.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/refcount.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/refcount.h new file mode 100644 index 00000000..3bc634af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/refcount.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_REFCOUNT_H_ +#define TENSORFLOW_CORE_LIB_CORE_REFCOUNT_H_ + +#include "tensorflow/core/platform/refcount.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_REFCOUNT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/status.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/status.h new file mode 100644 index 00000000..2146cbd5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/status.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_H_ +#define TENSORFLOW_CORE_LIB_CORE_STATUS_H_ + +#include "tensorflow/core/platform/status.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_STATUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/status_test_util.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/status_test_util.h new file mode 100644 index 00000000..3c604ee8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/status_test_util.h @@ -0,0 +1,22 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_ +#define TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_ + +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +#endif // TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/stringpiece.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/stringpiece.h new file mode 100644 index 00000000..d00ce8c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/stringpiece.h @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// StringPiece is a simple structure containing a pointer into some external +// storage and a size. The user of a StringPiece must ensure that the slice +// is not used after the corresponding external storage has been +// deallocated. +// +// Multiple threads can invoke const methods on a StringPiece without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same StringPiece must use +// external synchronization. + +#ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_ +#define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_ + +#include "tensorflow/core/platform/stringpiece.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool.h new file mode 100644 index 00000000..4aa4b69b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_ +#define TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_ + +#include "tensorflow/core/platform/threadpool.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool_interface.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool_interface.h new file mode 100644 index 00000000..1a51e38e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool_interface.h @@ -0,0 +1,21 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_THREADPOOL_INTERFACE_H_ +#define TENSORFLOW_CORE_LIB_CORE_THREADPOOL_INTERFACE_H_ + +#include "tensorflow/core/platform/threadpool_interface.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_THREADPOOL_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool_options.h b/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool_options.h new file mode 100644 index 00000000..64f7e647 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/core/threadpool_options.h @@ -0,0 +1,21 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_CORE_THREADPOOL_OPTIONS_H_ +#define TENSORFLOW_CORE_LIB_CORE_THREADPOOL_OPTIONS_H_ + +#include "tensorflow/core/platform/threadpool_options.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_CORE_THREADPOOL_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/db/sqlite.h b/third_party/tflite-hdrs/tensorflow/core/lib/db/sqlite.h new file mode 100644 index 00000000..992001e4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/db/sqlite.h @@ -0,0 +1,457 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_LIB_DB_SQLITE_H_ +#define TENSORFLOW_CORE_LIB_DB_SQLITE_H_ + +#include +#include +#include + +#include "absl/log/check.h" +#include "sqlite3.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/status.h" + +/// TensorFlow SQLite Veneer +/// +/// - Memory safety +/// - Less boilerplate +/// - Removes deprecated stuff +/// - Pretends UTF16 doesn't exist +/// - Transaction compile-time safety +/// - Statically loads our native extensions +/// - Error reporting via tensorflow::Status et al. +/// +/// SQLite>=3.8.2 needs to be supported until April 2019, which is when +/// Ubuntu 14.04 LTS becomes EOL. + +namespace tensorflow { + +class SqliteLock; +class SqliteStatement; +class SqliteTransaction; + +/// \brief SQLite connection object. +/// +/// The SQLite connection is closed automatically by the destructor. +/// Reference counting ensures that happens after its statements are +/// destructed. +/// +/// Instances are reference counted and can be shared between threads. +/// This class offers the same thread safety behaviors as the SQLite +/// API itself. +/// +/// This veneer uses auto-commit mode by default, which means a 4ms +/// fsync() happens after every write unless a SqliteTransaction is +/// used or WAL mode is enabled beforehand. +class TF_LOCKABLE Sqlite : public core::RefCounted { + public: + /// \brief Closes SQLite connection, which can take milliseconds. + ~Sqlite() override; + + /// \brief Opens SQLite database file. + /// + /// Most users will want to set flags to SQLITE_OPEN_READWRITE | + /// SQLITE_OPEN_CREATE. There are many other open flags; here are + /// notes on a few of them: + /// + /// - SQLITE_OPEN_READONLY: Allowed if no WAL journal is active. + /// - SQLITE_OPEN_SHAREDCACHE: Will be ignored because this veneer + /// doesn't support the unlock notify API. + /// - SQLITE_OPEN_NOMUTEX: Means access to this connection MUST be + /// serialized by the caller in accordance with the same contracts + /// implemented by this API. + /// + /// This function sets PRAGMA values from TF_SQLITE_* environment + /// variables. See sqlite.cc to learn more. + static absl::Status Open(const string& path, int flags, Sqlite** db); + + /// \brief Creates SQLite statement. + /// + /// This routine should never fail if sql is valid and does not + /// reference tables. When tables are referenced, system calls are + /// needed which can take microseconds. When the schema changes, this + /// routine will retry automatically and then possibly fail. + /// + /// The returned statement holds a reference to this object. + absl::Status Prepare(const absl::string_view& sql, SqliteStatement* stmt); + SqliteStatement PrepareOrDie(const absl::string_view& sql); + + /// \brief Returns extended result code of last error. 
+ /// + /// If the most recent API call was successful, the result is + /// undefined. The legacy result code can be obtained by saying + /// errcode() & 0xff. + int errcode() const TF_EXCLUSIVE_LOCKS_REQUIRED(this) { + return sqlite3_extended_errcode(db_); + } + + /// \brief Returns pointer to current error message state. + const char* errmsg() const TF_EXCLUSIVE_LOCKS_REQUIRED(this) { + return sqlite3_errmsg(db_); + } + + /// \brief Returns rowid assigned to last successful insert. + int64_t last_insert_rowid() const TF_EXCLUSIVE_LOCKS_REQUIRED(this) { + return sqlite3_last_insert_rowid(db_); + } + + /// \brief Returns number of rows directly changed by last write. + int64_t changes() const TF_EXCLUSIVE_LOCKS_REQUIRED(this) { + return sqlite3_changes(db_); + } + + private: + friend class SqliteLock; + friend class SqliteStatement; + friend class SqliteTransaction; + + Sqlite(sqlite3* db, sqlite3_stmt* begin, sqlite3_stmt* commit, + sqlite3_stmt* rollback) noexcept + : db_(db), begin_(begin), commit_(commit), rollback_(rollback) {} + + sqlite3* const db_; + sqlite3_stmt* const begin_; + sqlite3_stmt* const commit_; + sqlite3_stmt* const rollback_; + bool is_in_transaction_ = false; + + Sqlite(const Sqlite&) = delete; + void operator=(const Sqlite&) = delete; +}; + +/// \brief SQLite prepared statement. +/// +/// Instances can only be shared between threads if caller serializes +/// access from first Bind*() to *Reset(). +/// +/// When reusing a statement in a loop, be certain to not have jumps +/// betwixt Bind*() and *Reset(). +class SqliteStatement { + public: + /// \brief Initializes an empty statement to be assigned later. + SqliteStatement() noexcept = default; + + /// \brief Finalizes statement. + /// + /// This can take milliseconds if it was blocking the Sqlite + /// connection object from being freed. + ~SqliteStatement() { + sqlite3_finalize(stmt_); + if (db_ != nullptr) db_->Unref(); + } + + /// \brief Returns true if statement is initialized. + explicit operator bool() const { return stmt_ != nullptr; } + + /// \brief Returns SQL text from when this query was prepared. + const char* sql() const { return sqlite3_sql(stmt_); } + + /// \brief Number of bytes bound since last *Reset(). + uint64 size() { return size_; } + + /// \brief Executes query for fetching arbitrary rows. + /// + /// `is_done` will always be set to true unless SQLITE_ROW is + /// returned by the underlying API. If status() is already in an + /// error state, then this method is a no-op and the existing status + /// is returned. + /// + /// The OrDie version returns `!is_done` which, if true, indicates a + /// row is available. + /// + /// This statement should be Reset() or destructed when finished with + /// the result. + absl::Status Step(bool* is_done); + bool StepOrDie() TF_MUST_USE_RESULT; + + /// \brief Executes query when only one row is desired. + /// + /// If a row isn't returned, an internal error Status is returned + /// that won't be reflected in the connection error state. + /// + /// This statement should be Reset() or destructed when finished with + /// the result. + absl::Status StepOnce(); + const SqliteStatement& StepOnceOrDie(); + + /// \brief Executes query, ensures zero rows returned, then Reset(). + /// + /// If a row is returned, an internal error Status is returned that + /// won't be reflected in the connection error state. + absl::Status StepAndReset(); + void StepAndResetOrDie(); + + /// \brief Resets statement so it can be executed again. 
+ /// + /// Implementation note: This method diverges from canonical API + /// behavior by calling sqlite3_clear_bindings() in addition to + /// sqlite3_reset(). That makes the veneer safer; we haven't found a + /// super compelling reason yet to call them independently. + void Reset(); + + /// \brief Binds signed 64-bit integer to 1-indexed query parameter. + void BindInt(int parameter, int64_t value) { + Update(sqlite3_bind_int64(stmt_, parameter, value), parameter); + size_ += sizeof(int64_t); + } + void BindInt(const char* parameter, int64_t value) { + BindInt(GetParameterIndex(parameter), value); + } + + /// \brief Binds double to 1-indexed query parameter. + void BindDouble(int parameter, double value) { + Update(sqlite3_bind_double(stmt_, parameter, value), parameter); + size_ += sizeof(double); + } + void BindDouble(const char* parameter, double value) { + BindDouble(GetParameterIndex(parameter), value); + } + + /// \brief Copies UTF-8 text to 1-indexed query parameter. + /// + /// If NUL characters are present, they will still go in the DB and + /// be successfully retrieved by ColumnString(); however, the + /// behavior of these values with SQLite functions is undefined. + /// + /// When using the unsafe methods, the data must not be changed or + /// freed until this statement is Reset() or finalized. + void BindText(int parameter, const absl::string_view& text) { + Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), + SQLITE_TRANSIENT, SQLITE_UTF8), + parameter); + size_ += text.size(); + } + void BindText(const char* parameter, const absl::string_view& text) { + BindText(GetParameterIndex(parameter), text); + } + void BindTextUnsafe(int parameter, const absl::string_view& text) { + Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(), + SQLITE_STATIC, SQLITE_UTF8), + parameter); + size_ += text.size(); + } + void BindTextUnsafe(const char* parameter, const absl::string_view& text) { + BindTextUnsafe(GetParameterIndex(parameter), text); + } + + /// \brief Copies binary data to 1-indexed query parameter. + /// + /// When using the unsafe methods, the data must not be changed or + /// freed until this statement is Reset() or finalized. + void BindBlob(int parameter, const absl::string_view& blob) { + Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), + SQLITE_TRANSIENT), + parameter); + size_ += blob.size(); + } + void BindBlob(const char* parameter, const absl::string_view& blob) { + BindBlob(GetParameterIndex(parameter), blob); + } + void BindBlobUnsafe(int parameter, const absl::string_view& blob) { + Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(), + SQLITE_STATIC), + parameter); + size_ += blob.size(); + } + void BindBlobUnsafe(const char* parameter, const absl::string_view& text) { + BindBlobUnsafe(GetParameterIndex(parameter), text); + } + + /// \brief Returns number of columns in result set. + int ColumnCount() const TF_MUST_USE_RESULT { + return sqlite3_column_count(stmt_); + } + + /// \brief Returns type of 0-indexed column value in row data. + /// + /// Please note that SQLite is dynamically typed and the type of a + /// particular column can vary from row to row. + int ColumnType(int column) const TF_MUST_USE_RESULT { + return sqlite3_column_type(stmt_, column); + } + + /// \brief Returns 0-indexed column from row result coerced as an integer. 
+ int64_t ColumnInt(int column) const TF_MUST_USE_RESULT { + return sqlite3_column_int64(stmt_, column); + } + + /// \brief Returns 0-indexed column from row result coerced as a double. + double ColumnDouble(int column) const TF_MUST_USE_RESULT { + return sqlite3_column_double(stmt_, column); + } + + /// \brief Copies 0-indexed column from row result coerced as a string. + /// + /// NULL values are returned as empty string. This method should be + /// used for both BLOB and TEXT columns. See also: ColumnType(). + string ColumnString(int column) const TF_MUST_USE_RESULT { + auto data = sqlite3_column_blob(stmt_, column); + if (data == nullptr) return ""; + return {static_cast(data), + static_cast(ColumnSize(column))}; + } + + /// \brief Returns pointer to binary data at 0-indexed column. + /// + /// Empty values are returned as NULL. The returned memory will no + /// longer be valid the next time Step() or Reset() is called. No NUL + /// terminator is added. + absl::string_view ColumnStringUnsafe(int column) const TF_MUST_USE_RESULT { + return {static_cast(sqlite3_column_blob(stmt_, column)), + static_cast(ColumnSize(column))}; + } + + /// \brief Returns number of bytes stored at 0-indexed column. + int ColumnSize(int column) const TF_MUST_USE_RESULT { + return sqlite3_column_bytes(stmt_, column); + } + + /// \brief Move constructor, after which is reset to empty. + SqliteStatement(SqliteStatement&& other) noexcept + : db_(other.db_), stmt_(other.stmt_), bind_error_(other.bind_error_) { + other.db_ = nullptr; + other.stmt_ = nullptr; + other.bind_error_ = SQLITE_OK; + } + + /// \brief Move assignment, after which is reset to empty. + SqliteStatement& operator=(SqliteStatement&& other) noexcept { + if (&other != this) { + if (db_ != nullptr) db_->Unref(); + if (stmt_ != nullptr) sqlite3_finalize(stmt_); + db_ = other.db_; + stmt_ = other.stmt_; + bind_error_ = other.bind_error_; + size_ = other.size_; + other.db_ = nullptr; + other.stmt_ = nullptr; + other.bind_error_ = SQLITE_OK; + other.size_ = 0; + } + return *this; + } + + private: + friend class Sqlite; + + SqliteStatement(Sqlite* db, sqlite3_stmt* stmt) noexcept + : db_(db), stmt_(stmt) { + db_->Ref(); + } + + void Update(int rc, int parameter) { + // Binding strings can fail if they exceed length limit. + if (TF_PREDICT_FALSE(rc != SQLITE_OK)) { + if (bind_error_ == SQLITE_OK) { + bind_error_ = rc; + bind_error_parameter_ = parameter; + } + } + } + + int GetParameterIndex(const char* parameter) { + int index = sqlite3_bind_parameter_index(stmt_, parameter); + DCHECK(index > 0); // OK to compile away since it'll fail again + return index; + } + + Sqlite* db_ = nullptr; + sqlite3_stmt* stmt_ = nullptr; + int bind_error_ = SQLITE_OK; + int bind_error_parameter_ = 0; + uint64 size_ = 0; + + SqliteStatement(const SqliteStatement&) = delete; + void operator=(const SqliteStatement&) = delete; +}; + +/// \brief Reentrant SQLite connection object lock +/// +/// This is a no-op if SQLITE_OPEN_NOMUTEX was used. 
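Putting the pieces of the veneer above together, a hypothetical caller (the file path, table, and values are invented; error handling is reduced to TF_CHECK_OK and the *OrDie helpers declared above) might look like:

#include <cstdint>
#include <string>

#include "tensorflow/core/lib/db/sqlite.h"

void SqliteDemo() {
  tensorflow::Sqlite* db = nullptr;
  TF_CHECK_OK(tensorflow::Sqlite::Open(
      "/tmp/example.db", SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db));

  db->PrepareOrDie("CREATE TABLE IF NOT EXISTS t (id INTEGER, name TEXT)")
      .StepAndResetOrDie();

  auto insert = db->PrepareOrDie("INSERT INTO t (id, name) VALUES (?, ?)");
  insert.BindInt(1, 42);          // parameters are 1-indexed
  insert.BindText(2, "meaning");
  insert.StepAndResetOrDie();

  auto select = db->PrepareOrDie("SELECT id, name FROM t");
  bool is_done = false;
  TF_CHECK_OK(select.Step(&is_done));
  while (!is_done) {
    const int64_t id = select.ColumnInt(0);        // columns are 0-indexed
    const std::string name = select.ColumnString(1);
    (void)id;
    (void)name;
    TF_CHECK_OK(select.Step(&is_done));
  }
  db->Unref();  // connection is reference counted; live statements keep it alive
}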
+class TF_SCOPED_LOCKABLE SqliteLock { + public: + explicit SqliteLock(Sqlite& db) TF_EXCLUSIVE_LOCK_FUNCTION(db) + : mutex_(sqlite3_db_mutex(db.db_)) { + sqlite3_mutex_enter(mutex_); + } + SqliteLock(Sqlite& db, std::try_to_lock_t) TF_EXCLUSIVE_LOCK_FUNCTION(db) + : mutex_(sqlite3_db_mutex(db.db_)) { + if (TF_PREDICT_FALSE(sqlite3_mutex_try(mutex_) != SQLITE_OK)) { + is_locked_ = false; + } + } + ~SqliteLock() TF_UNLOCK_FUNCTION() { + if (is_locked_) sqlite3_mutex_leave(mutex_); + } + explicit operator bool() const { return is_locked_; } + + private: + sqlite3_mutex* const mutex_; + bool is_locked_ = true; + SqliteLock(const SqliteLock&) = delete; + void operator=(const SqliteLock&) = delete; +}; +#define SqliteLock(x) static_assert(0, "sqlite_lock_decl_missing_name"); + +/// \brief SQLite transaction scope. +/// +/// This class acquires an exclusive lock on the connection object (if +/// mutexes weren't disabled) and runs BEGIN / ROLLBACK automatically. +/// Unlike SqliteLock this scope is non-reentrant. To avoid program +/// crashes, business logic should use the TF_EXCLUSIVE_LOCK_FUNCTION and +/// TF_LOCKS_EXCLUDED annotations as much as possible. +class TF_SCOPED_LOCKABLE SqliteTransaction { + public: + /// \brief Locks db and begins deferred transaction. + /// + /// This will crash if a transaction is already active. + explicit SqliteTransaction(Sqlite& db) TF_EXCLUSIVE_LOCK_FUNCTION(db); + + /// \brief Runs ROLLBACK and unlocks. + ~SqliteTransaction() TF_UNLOCK_FUNCTION(); + + /// \brief Commits transaction. + /// + /// If this is successful, a new transaction will be started, which + /// is rolled back when exiting the scope. + absl::Status Commit(); + + private: + void Begin(); + Sqlite* const db_; + + SqliteTransaction(const SqliteTransaction&) = delete; + void operator=(const SqliteTransaction&) = delete; +}; + +#define SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(...) \ + TF_EXCLUSIVE_LOCKS_REQUIRED(__VA_ARGS__) +#define SQLITE_TRANSACTIONS_EXCLUDED(...) TF_LOCKS_EXCLUDED(__VA_ARGS__) + +inline SqliteStatement Sqlite::PrepareOrDie(const absl::string_view& sql) { + SqliteStatement stmt; + TF_CHECK_OK(Prepare(sql, &stmt)); + return stmt; +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_DB_SQLITE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gif/gif_io.h b/third_party/tflite-hdrs/tensorflow/core/lib/gif/gif_io.h new file mode 100644 index 00000000..ae7d5125 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gif/gif_io.h @@ -0,0 +1,52 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functions to read and write images in GIF format. +// +// The advantage over image/codec/png{enc,dec}oder.h is that this library +// supports both 8 and 16 bit images. +// +// The decoding routine accepts binary image data as a StringPiece. 
These are +// implicitly constructed from strings or char* so they're completely +// transparent to the caller. They're also very cheap to construct so this +// doesn't introduce any additional overhead. +// +// The primary benefit of StringPieces being, in this case, that APIs already +// returning StringPieces (e.g., Bigtable Scanner) or Cords (e.g., IOBuffer; +// only when they're flat, though) or protocol buffer fields typed to either of +// these can be decoded without copying the data into a C++ string. + +#ifndef TENSORFLOW_CORE_LIB_GIF_GIF_IO_H_ +#define TENSORFLOW_CORE_LIB_GIF_GIF_IO_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gif { + +uint8* Decode(const void* srcdata, int datasize, + const std::function& allocate_output, + string* error_string, bool expand_animations = true); + +} // namespace gif +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GIF_GIF_IO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/array_slice.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/array_slice.h new file mode 100644 index 00000000..ddacf4d2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/array_slice.h @@ -0,0 +1,42 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_ +#define TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_ + +#include "absl/base/macros.h" +#include "absl/types/span.h" +// TODO(timshen): This is kept only because lots of targets transitively depend +// on it. Remove all targets' dependencies. +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace gtl { + +template +using ArraySlice ABSL_DEPRECATE_AND_INLINE() = absl::Span; + +template +using MutableArraySlice ABSL_DEPRECATE_AND_INLINE() = absl::Span; + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/cleanup.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/cleanup.h new file mode 100644 index 00000000..3e54f828 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/cleanup.h @@ -0,0 +1,113 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its
+// destructor. The easiest way to use MakeCleanup is with a lambda argument,
+// capturing the return value in an 'auto' local variable. Most users will not
+// need more sophisticated syntax than that.
+//
+// Example:
+//   void func() {
+//     FILE* fp = fopen("data.txt", "r");
+//     if (fp == nullptr) return;
+//     auto fp_cleaner = gtl::MakeCleanup([fp] { fclose(fp); });
+//     // No matter what, fclose(fp) will happen.
+//     DataObject d;
+//     while (ReadDataObject(fp, &d)) {
+//       if (d.IsBad()) {
+//         LOG(ERROR) << "Bad Data";
+//         return;
+//       }
+//       PushGoodData(d);
+//     }
+//   }
+//
+// You can use Cleanup<F> directly, instead of using MakeCleanup and auto,
+// but there's rarely a reason to do that.
+//
+// You can call 'release()' on a Cleanup object to cancel the cleanup.
+
+#ifndef TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
+#define TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
+
+#include <type_traits>
+#include <utility>
+
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace gtl {
+
+// A move-only RAII object that calls a stored cleanup functor when
+// destroyed. Cleanup<F> is the return type of gtl::MakeCleanup(F).
+template <typename F>
+class Cleanup {
+ public:
+  Cleanup() : released_(true), f_() {}
+
+  template <typename G>
+  explicit Cleanup(G&& f)          // NOLINT
+      : f_(std::forward<G>(f)) {}  // NOLINT(build/c++11)
+
+  Cleanup(Cleanup&& src)  // NOLINT
+      : released_(src.is_released()), f_(src.release()) {}
+
+  // Implicitly move-constructible from any compatible Cleanup<G>.
+  // The source will be released as if src.release() were called.
+  // A moved-from Cleanup can be safely destroyed or reassigned.
+  template <typename G>
+  Cleanup(Cleanup<G>&& src)  // NOLINT
+      : released_(src.is_released()), f_(src.release()) {}
+
+  // Assignment to a Cleanup object behaves like destroying it
+  // and making a new one in its place, analogous to unique_ptr
+  // semantics.
+  Cleanup& operator=(Cleanup&& src) {  // NOLINT
+    if (!released_) f_();
+    released_ = src.released_;
+    f_ = src.release();
+    return *this;
+  }
+
+  ~Cleanup() {
+    if (!released_) f_();
+  }
+
+  // Releases the cleanup function instead of running it.
+  // Hint: use c.release()() to run early.
+  F release() {
+    released_ = true;
+    return std::move(f_);
+  }
+
+  bool is_released() const { return released_; }
+
+ private:
+  static_assert(!std::is_reference<F>::value, "F must not be a reference");
+
+  bool released_ = false;
+  F f_;
+};
+
+template <typename F, typename DecayF = typename std::decay<F>::type>
+TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) {
+  return Cleanup<DecayF>(std::forward<F>(f));
+}
+
+}  // namespace gtl
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/compactptrset.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/compactptrset.h
new file mode 100644
index 00000000..6655ac92
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/compactptrset.h
@@ -0,0 +1,29 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ +#define TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ + +#include "xla/tsl/lib/gtl/compactptrset.h" + +namespace tensorflow { +namespace gtl { + +using ::tsl::gtl::CompactPointerSet; // NOLINT(misc-unused-using-decls) + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/edit_distance.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/edit_distance.h new file mode 100644 index 00000000..94a5ad68 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/edit_distance.h @@ -0,0 +1,108 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_EDIT_DISTANCE_H_ +#define TENSORFLOW_CORE_LIB_GTL_EDIT_DISTANCE_H_ + +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { +namespace gtl { + +// Calculate the Levenshtein Edit Distance between two contiguous +// sequences, s and t, of type T. +// +// The Levenshtein distance is a symmetric distance defined as the +// smallest number of insertions, deletions, and substitutions +// required to convert sequence s to t (and vice versa). +// Note, this distance does not consider transpositions. +// +// For more details and a reference implementation, see: +// https://en.wikipedia.org/wiki/Levenshtein_distance +// +// This implementation has time complexity O(|s|*|t|) +// and space complexity O(min(|s|, |t|)), where +// |x| := x.size() +// +// A simple call to LevenshteinDistance looks like: +// +// int64 dist = LevenshteinDistance("hi", "bye", std::equal_to()); +// +template +inline int64_t LevenshteinDistance(const gtl::ArraySlice s, + const gtl::ArraySlice t, const Cmp& cmp) { + const int64_t s_size = s.size(); + const int64_t t_size = t.size(); + + if (t_size > s_size) return LevenshteinDistance(t, s, cmp); + + const T* s_data = s.data(); + const T* t_data = t.data(); + + if (t_size == 0) return s_size; + if (s == t) return 0; + + // Create work vector + absl::InlinedVector scratch_holder(t_size); + + int64_t* scratch = scratch_holder.data(); + + // Special case for i = 0: Distance between empty string and string + // of length j is just j. 
+ for (size_t j = 1; j < t_size; ++j) scratch[j - 1] = j; + + for (size_t i = 1; i <= s_size; ++i) { + // Invariant: scratch[j - 1] equals cost(i - 1, j). + int substitution_base_cost = i - 1; + int insertion_cost = i + 1; + for (size_t j = 1; j <= t_size; ++j) { + // Invariants: + // scratch[k - 1] = cost(i, k) for 0 < k < j. + // scratch[k - 1] = cost(i - 1, k) for j <= k <= t_size. + // substitution_base_cost = cost(i - 1, j - 1) + // insertion_cost = cost(i, j - 1) + const int replacement_cost = cmp(s_data[i - 1], t_data[j - 1]) ? 0 : 1; + const int substitution_cost = substitution_base_cost + replacement_cost; + const int deletion_cost = scratch[j - 1] + 1; + + // Select the cheapest edit. + const int cheapest = // = cost(i, j) + std::min(deletion_cost, std::min(insertion_cost, substitution_cost)); + + // Restore invariant for the next iteration of the loop. + substitution_base_cost = scratch[j - 1]; // = cost(i - 1, j) + scratch[j - 1] = cheapest; // = cost(i, j) + insertion_cost = cheapest + 1; // = cost(i, j) + 1 + } + } + return scratch[t_size - 1]; +} + +template +inline int64_t LevenshteinDistance(const Container1& s, const Container2& t, + const Cmp& cmp) { + return LevenshteinDistance( + gtl::ArraySlice(s.data(), s.size()), + gtl::ArraySlice(t.data(), t.size()), + cmp); +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_EDIT_DISTANCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatmap.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatmap.h new file mode 100644 index 00000000..3b112a71 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatmap.h @@ -0,0 +1,33 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ +#define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ + +#include "xla/tsl/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatrep.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { + +using tsl::gtl::FlatMap; // NOLINT(misc-unused-using-decls) + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatrep.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatrep.h new file mode 100644 index 00000000..59caa4b0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatrep.h @@ -0,0 +1,31 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
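The single scratch-row recurrence that edit_distance.h documents is easier to follow outside the template machinery (whose angle-bracketed parameters were stripped in this rendering of the diff). The sketch below is not the vendored code, just a plain std::string restatement of the same O(|s|*|t|) time, O(min(|s|,|t|)) space algorithm; the function name is made up:

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// Levenshtein distance using one scratch row. Mirrors the scheme described
// in edit_distance.h: iterate over the longer string row by row, keeping
// only cost(i, *) for the row currently being filled.
int64_t LevenshteinDistanceDemo(std::string s, std::string t) {
  if (t.size() > s.size()) std::swap(s, t);  // make t the shorter string
  if (t.empty()) return static_cast<int64_t>(s.size());

  // scratch[j] == cost(i, j) for the current row.
  std::vector<int64_t> scratch(t.size() + 1);
  for (size_t j = 0; j <= t.size(); ++j) scratch[j] = j;  // row i == 0

  for (size_t i = 1; i <= s.size(); ++i) {
    int64_t diag = scratch[0];  // cost(i - 1, j - 1)
    scratch[0] = i;             // cost(i, 0)
    for (size_t j = 1; j <= t.size(); ++j) {
      const int64_t up = scratch[j];  // cost(i - 1, j)
      const int64_t substitution = diag + (s[i - 1] == t[j - 1] ? 0 : 1);
      const int64_t deletion = up + 1;
      const int64_t insertion = scratch[j - 1] + 1;
      scratch[j] = std::min({substitution, deletion, insertion});
      diag = up;
    }
  }
  return scratch[t.size()];
}

// e.g. LevenshteinDistanceDemo("hi", "bye") == 3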
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ +#define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ + +#include "xla/tsl/lib/gtl/flatrep.h" + +namespace tensorflow { +namespace gtl { +namespace internal { + +using tsl::gtl::internal::FlatRep; // NOLINT(misc-unused-using-decls) + +} // namespace internal +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatset.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatset.h new file mode 100644 index 00000000..fcb7ed96 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/flatset.h @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ +#define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ + +#include "xla/tsl/lib/gtl/flatset.h" + +namespace tensorflow { +namespace gtl { + +using tsl::gtl::FlatSet; // NOLINT(misc-unused-using-decls) + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/inlined_vector.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/inlined_vector.h new file mode 100644 index 00000000..df9d1a24 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/inlined_vector.h @@ -0,0 +1,33 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ +#define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ + +#include "xla/tsl/lib/gtl/inlined_vector.h" // IWYU pragma: export +// TODO(kramerb): This is kept only because lots of targets transitively depend +// on it. Remove all targets' dependencies. 
+#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace gtl { + +using ::tsl::gtl::InlinedVector; // NOLINT(misc-unused-using-decls) + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/int_type.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/int_type.h new file mode 100644 index 00000000..c161ee91 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/int_type.h @@ -0,0 +1,30 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef TENSORFLOW_CORE_LIB_GTL_INT_TYPE_H_ +#define TENSORFLOW_CORE_LIB_GTL_INT_TYPE_H_ + +#include "xla/tsl/lib/gtl/int_type.h" + +namespace tensorflow { +namespace gtl { + +using ::tsl::gtl::IntType; // NOLINT(misc-unused-using-decls) + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_INT_TYPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/iterator_range.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/iterator_range.h new file mode 100644 index 00000000..ca980fd5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/iterator_range.h @@ -0,0 +1,39 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This provides a very simple, boring adaptor for a begin and end iterator +// into a range type. This should be used to build range views that work well +// with range based for loops and range based constructors. +// +// Note that code here follows more standards-based coding conventions as it +// is mirroring proposed interfaces for standardization. +// +// Converted from chandlerc@'s code to Google style by joshl@. 
+ +#ifndef TENSORFLOW_CORE_LIB_GTL_ITERATOR_RANGE_H_ +#define TENSORFLOW_CORE_LIB_GTL_ITERATOR_RANGE_H_ + +#include "xla/tsl/lib/gtl/iterator_range.h" + +namespace tensorflow { +namespace gtl { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::gtl::iterator_range; +using ::tsl::gtl::make_range; +// NOLINTEND(misc-unused-using-decls) +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_ITERATOR_RANGE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/manual_constructor.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/manual_constructor.h new file mode 100644 index 00000000..4431f5e1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/manual_constructor.h @@ -0,0 +1,245 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// ManualConstructor statically-allocates space in which to store some +// object, but does not initialize it. You can then call the constructor +// and destructor for the object yourself as you see fit. This is useful +// for memory management optimizations, where you want to initialize and +// destroy an object multiple times but only allocate it once. +// +// (When I say ManualConstructor statically allocates space, I mean that +// the ManualConstructor object itself is forced to be the right size.) + +#ifndef TENSORFLOW_CORE_LIB_GTL_MANUAL_CONSTRUCTOR_H_ +#define TENSORFLOW_CORE_LIB_GTL_MANUAL_CONSTRUCTOR_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mem.h" + +namespace tensorflow { +namespace gtl { +namespace internal { + +// +// Provides a char array with the exact same alignment as another type. The +// first parameter must be a complete type, the second parameter is how many +// of that type to provide space for. +// +// TF_LIB_GTL_ALIGNED_CHAR_ARRAY(struct stat, 16) storage_; +// +// Because MSVC and older GCCs require that the argument to their alignment +// construct to be a literal constant integer, we use a template instantiated +// at all the possible powers of two. 
+#ifndef SWIG +template +struct AlignType {}; +template +struct AlignType<0, size> { + typedef char result[size]; +}; +#if defined(_MSC_VER) +#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __declspec(align(X)) +#define TF_LIB_GTL_ALIGN_OF(T) __alignof(T) +#else +#define TF_LIB_GTL_ALIGN_ATTRIBUTE(X) __attribute__((aligned(X))) +#define TF_LIB_GTL_ALIGN_OF(T) __alignof__(T) +#endif + +#if defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) + +#define TF_LIB_GTL_ALIGNTYPE_TEMPLATE(X) \ + template \ + struct AlignType { \ + typedef TF_LIB_GTL_ALIGN_ATTRIBUTE(X) char result[size]; \ + } + +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(16); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(32); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(64); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(128); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(256); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(512); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(1024); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(2048); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(4096); +TF_LIB_GTL_ALIGNTYPE_TEMPLATE(8192); +// Any larger and MSVC++ will complain. + +#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \ + typename tensorflow::gtl::internal::AlignType::result + +#undef TF_LIB_GTL_ALIGNTYPE_TEMPLATE +#undef TF_LIB_GTL_ALIGN_ATTRIBUTE + +#else // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) +#error "You must define TF_LIB_GTL_ALIGNED_CHAR_ARRAY for your compiler." +#endif // defined(TF_LIB_GTL_ALIGN_ATTRIBUTE) + +#else // !SWIG + +// SWIG can't represent alignment and doesn't care about alignment on data +// members (it works fine without it). +template +struct AlignType { + typedef char result[Size]; +}; +#define TF_LIB_GTL_ALIGNED_CHAR_ARRAY(T, Size) \ + tensorflow::gtl::internal::AlignType::result + +// Enough to parse with SWIG, will never be used by running code. +#define TF_LIB_GTL_ALIGN_OF(Type) 16 + +#endif // !SWIG + +} // namespace internal +} // namespace gtl + +template +class ManualConstructor { + public: + // No constructor or destructor because one of the most useful uses of + // this class is as part of a union, and members of a union cannot have + // constructors or destructors. And, anyway, the whole point of this + // class is to bypass these. + + // Support users creating arrays of ManualConstructor<>s. This ensures that + // the array itself has the correct alignment. + static void* operator new[](size_t size) { + return port::AlignedMalloc(size, TF_LIB_GTL_ALIGN_OF(Type)); + } + static void operator delete[](void* mem) { port::AlignedFree(mem); } + + inline Type* get() { return reinterpret_cast(space_); } + inline const Type* get() const { + return reinterpret_cast(space_); + } + + inline Type* operator->() { return get(); } + inline const Type* operator->() const { return get(); } + + inline Type& operator*() { return *get(); } + inline const Type& operator*() const { return *get(); } + + inline void Init() { new (space_) Type; } + +// Init() constructs the Type instance using the given arguments +// (which are forwarded to Type's constructor). In C++11, Init() can +// take any number of arguments of any type, and forwards them perfectly. +// On pre-C++11 platforms, it can take up to 11 arguments, and may not be +// able to forward certain kinds of arguments. +// +// Note that Init() with no arguments performs default-initialization, +// not zero-initialization (i.e it behaves the same as "new Type;", not +// "new Type();"), so it will leave non-class types uninitialized. +#ifdef LANG_CXX11 + template + inline void Init(Ts&&... 
args) { // NOLINT + new (space_) Type(std::forward(args)...); // NOLINT + } +#else // !defined(LANG_CXX11) + template + inline void Init(const T1& p1) { + new (space_) Type(p1); + } + + template + inline void Init(const T1& p1, const T2& p2) { + new (space_) Type(p1, p2); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3) { + new (space_) Type(p1, p2, p3); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4) { + new (space_) Type(p1, p2, p3, p4); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5) { + new (space_) Type(p1, p2, p3, p4, p5); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6) { + new (space_) Type(p1, p2, p3, p4, p5, p6); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9, const T10& p10) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10); + } + + template + inline void Init(const T1& p1, const T2& p2, const T3& p3, const T4& p4, + const T5& p5, const T6& p6, const T7& p7, const T8& p8, + const T9& p9, const T10& p10, const T11& p11) { + new (space_) Type(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11); + } +#endif // LANG_CXX11 + + inline void Destroy() { get()->~Type(); } + + private: + TF_LIB_GTL_ALIGNED_CHAR_ARRAY(Type, 1) space_; +}; + +#undef TF_LIB_GTL_ALIGNED_CHAR_ARRAY +#undef TF_LIB_GTL_ALIGN_OF + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_MANUAL_CONSTRUCTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/map_util.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/map_util.h new file mode 100644 index 00000000..47d28e7d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/map_util.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file provides utility functions for use with STL map-like data +// structures, such as std::map and hash_map. Some functions will also work with +// sets, such as ContainsKey(). 
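Because ManualConstructor deliberately never runs constructors or destructors itself, the calling pattern is easy to get wrong. A minimal sketch of the intended Init()/Destroy() pairing, assuming the vendored header is on the include path; the std::string payload is just an example:

#include <string>

#include "tensorflow/core/lib/gtl/manual_constructor.h"

void ManualConstructorDemo() {
  // Storage with the size and alignment of std::string, but no object yet;
  // a std::string only exists between Init() and Destroy().
  tensorflow::ManualConstructor<std::string> slot;

  slot.Init("hello");       // placement-new: std::string("hello")
  slot->append(", world");  // operator-> reaches the constructed object
  slot.Destroy();           // explicit destructor call

  slot.Init(3, 'x');        // the same storage can be reused: "xxx"
  slot.Destroy();
}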
+ +#ifndef TENSORFLOW_CORE_LIB_GTL_MAP_UTIL_H_ +#define TENSORFLOW_CORE_LIB_GTL_MAP_UTIL_H_ + +#include "xla/tsl/lib/gtl/map_util.h" + +namespace tensorflow { +namespace gtl { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::gtl::EraseKeyReturnValuePtr; +using ::tsl::gtl::FindOrNull; +using ::tsl::gtl::FindPtrOrNull; +using ::tsl::gtl::FindWithDefault; +using ::tsl::gtl::InsertIfNotPresent; +using ::tsl::gtl::InsertOrUpdate; +using ::tsl::gtl::LookupOrInsert; +using ::tsl::gtl::ReverseMap; +// NOLINTEND(misc-unused-using-decls) +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_MAP_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/priority_queue_util.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/priority_queue_util.h new file mode 100644 index 00000000..93bf3d30 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/priority_queue_util.h @@ -0,0 +1,55 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_ +#define TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_ + +#include +#include +#include + +namespace tensorflow { +namespace gtl { + +// Removes the top element from a std::priority_queue and returns it. +// Supports movable types. +template +T ConsumeTop(std::priority_queue* q) { + // std::priority_queue is required to implement pop() as if it + // called: + // std::pop_heap() + // c.pop_back() + // unfortunately, it does not provide access to the removed element. + // If the element is move only (such as a unique_ptr), there is no way to + // reclaim it in the standard API. std::priority_queue does, however, expose + // the underlying container as a protected member, so we use that access + // to extract the desired element between those two calls. + using Q = std::priority_queue; + struct Expose : Q { + using Q::c; + using Q::comp; + }; + auto& c = q->*&Expose::c; + auto& comp = q->*&Expose::comp; + std::pop_heap(c.begin(), c.end(), comp); + auto r = std::move(c.back()); + c.pop_back(); + return r; +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/subtle/map_traits.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/subtle/map_traits.h new file mode 100644 index 00000000..c4cca1fb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/subtle/map_traits.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
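The protected-member trick behind ConsumeTop is the interesting part of priority_queue_util.h, and it is easier to see with the template parameter lists that this rendering of the diff dropped. The sketch below restates it standalone (names suffixed Demo are made up), followed by a move-only usage that plain top()/pop() cannot express:

#include <algorithm>
#include <memory>
#include <queue>
#include <utility>
#include <vector>

// Pop the top element of a std::priority_queue and move it out. The derived
// "Expose" type re-exports the protected container (c) and comparator (comp)
// so we can pop_heap ourselves and salvage the element before pop_back().
template <typename T, typename Container, typename Cmp>
T ConsumeTopDemo(std::priority_queue<T, Container, Cmp>* q) {
  using Q = std::priority_queue<T, Container, Cmp>;
  struct Expose : Q {
    using Q::c;
    using Q::comp;
  };
  auto& c = q->*&Expose::c;
  auto& comp = q->*&Expose::comp;
  std::pop_heap(c.begin(), c.end(), comp);  // moves the top element to the back
  T r = std::move(c.back());
  c.pop_back();
  return r;
}

// Usage with a move-only element type.
struct PtrLess {
  bool operator()(const std::unique_ptr<int>& a,
                  const std::unique_ptr<int>& b) const {
    return *a < *b;
  }
};

std::unique_ptr<int> TakeLargest(
    std::priority_queue<std::unique_ptr<int>,
                        std::vector<std::unique_ptr<int>>, PtrLess>* q) {
  // top() returns a const reference and pop() discards the element, so a
  // unique_ptr could not be recovered through the standard interface.
  return ConsumeTopDemo(q);
}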
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Traits classes for performing uniform lookup on different map value types. +// +// The access is computed as follows: +// +// 1. If T has a `first` or `second` field, use them. +// 2. Otherwise if it has `key()` or `value()` methods, use them. +// 3. Otherwise the program is ill-formed. +#ifndef TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#define TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ + +#include "xla/tsl/lib/gtl/subtle/map_traits.h" + +namespace tensorflow { +namespace gtl { +namespace subtle { +namespace internal_map_traits { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::gtl::subtle::internal_map_traits::GetKey; +using ::tsl::gtl::subtle::internal_map_traits::GetMapped; +using ::tsl::gtl::subtle::internal_map_traits::Rank0; +using ::tsl::gtl::subtle::internal_map_traits::Rank1; +// NOLINTEND(misc-unused-using-decls) + +} // namespace internal_map_traits +} // namespace subtle +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/gtl/top_n.h b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/top_n.h new file mode 100644 index 00000000..1f871e61 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/gtl/top_n.h @@ -0,0 +1,336 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This simple class finds the top n elements of an incrementally provided set +// of elements which you push one at a time. If the number of elements exceeds +// n, the lowest elements are incrementally dropped. At the end you get +// a vector of the top elements sorted in descending order (through Extract() or +// ExtractNondestructive()), or a vector of the top elements but not sorted +// (through ExtractUnsorted() or ExtractUnsortedNondestructive()). +// +// The value n is specified in the constructor. If there are p elements pushed +// altogether: +// The total storage requirements are O(min(n, p)) elements +// The running time is O(p * log(min(n, p))) comparisons +// If n is a constant, the total storage required is a constant and the running +// time is linear in p. +// +// NOTE(zhifengc): There is a way to do this in O(min(n, p)) storage and O(p) +// runtime. The basic idea is to repeatedly fill up a buffer of 2 * n elements, +// discarding the lowest n elements whenever the buffer is full using a linear- +// time median algorithm. 
This may have better performance when the input +// sequence is partially sorted. +// +// NOTE(zhifengc): This class should be redesigned to avoid reallocating a +// vector for each Extract. + +#ifndef TENSORFLOW_CORE_LIB_GTL_TOP_N_H_ +#define TENSORFLOW_CORE_LIB_GTL_TOP_N_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace gtl { + +// Cmp is an stl binary predicate. Note that Cmp is the "greater" predicate, +// not the more commonly used "less" predicate. +// +// If you use a "less" predicate here, the TopN will pick out the bottom N +// elements out of the ones passed to it, and it will return them sorted in +// ascending order. +// +// TopN is rule-of-zero copyable and movable if its members are. +template > +class TopN { + public: + // The TopN is in one of the three states: + // + // o UNORDERED: this is the state an instance is originally in, + // where the elements are completely orderless. + // + // o BOTTOM_KNOWN: in this state, we keep the invariant that there + // is at least one element in it, and the lowest element is at + // position 0. The elements in other positions remain + // unsorted. This state is reached if the state was originally + // UNORDERED and a peek_bottom() function call is invoked. + // + // o HEAP_SORTED: in this state, the array is kept as a heap and + // there are exactly limit_ elements in the array. This + // state is reached when at least (limit_+1) elements are + // pushed in. + // + // The state transition graph is at follows: + // + // peek_bottom() (limit_+1) elements pushed + // UNORDERED --------------> BOTTOM_KNOWN --------------------> HEAP_SORTED + // | ^ + // | (limit_+1) elements pushed | + // +-----------------------------------------------------------+ + + enum State { UNORDERED, BOTTOM_KNOWN, HEAP_SORTED }; + using UnsortedIterator = typename std::vector::const_iterator; + + // 'limit' is the maximum number of top results to return. + explicit TopN(size_t limit) : TopN(limit, Cmp()) {} + TopN(size_t limit, const Cmp &cmp) : limit_(limit), cmp_(cmp) {} + + size_t limit() const { return limit_; } + + // Number of elements currently held by this TopN object. This + // will be no greater than 'limit' passed to the constructor. + size_t size() const { return elements_.size(); } + + bool empty() const { return size() == 0; } + + // If you know how many elements you will push at the time you create the + // TopN object, you can call reserve to preallocate the memory that TopN + // will need to process all 'n' pushes. Calling this method is optional. + void reserve(size_t n) { + // We may need limit_+1 for the case where we transition from an unsorted + // set of limit_ elements to a heap. + elements_.reserve(std::min(n, limit_ + 1)); + } + + // Push 'v'. If the maximum number of elements was exceeded, drop the + // lowest element and return it in 'dropped' (if given). If the maximum is not + // exceeded, 'dropped' will remain unchanged. 'dropped' may be omitted or + // nullptr, in which case it is not filled in. + // Requires: T is CopyAssignable, Swappable + void push(const T &v) { push(v, nullptr); } + void push(const T &v, T *dropped) { PushInternal(v, dropped); } + + // Move overloads of push. 
+ // Requires: T is MoveAssignable, Swappable + void push(T &&v) { // NOLINT(build/c++11) + push(std::move(v), nullptr); + } + void push(T &&v, T *dropped) { // NOLINT(build/c++11) + PushInternal(std::move(v), dropped); + } + + // Peeks the bottom result without calling Extract() + const T &peek_bottom(); + + // Extract the elements as a vector sorted in descending order. The caller + // assumes ownership of the vector and must delete it when done. This is a + // destructive operation. The only method that can be called immediately + // after Extract() is Reset(). + std::vector *Extract(); + + // Similar to Extract(), but makes no guarantees the elements are in sorted + // order. As with Extract(), the caller assumes ownership of the vector and + // must delete it when done. This is a destructive operation. The only + // method that can be called immediately after ExtractUnsorted() is Reset(). + std::vector *ExtractUnsorted(); + + // A non-destructive version of Extract(). Copy the elements in a new vector + // sorted in descending order and return it. The caller assumes ownership of + // the new vector and must delete it when done. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + std::vector *ExtractNondestructive() const; + + // A non-destructive version of Extract(). Copy the elements to a given + // vector sorted in descending order. After calling + // ExtractNondestructive(), the caller can continue to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractNondestructive(std::vector *output) const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements in a new + // vector and return it, with no guarantees the elements are in sorted order. + // The caller assumes ownership of the new vector and must delete it when + // done. After calling ExtractUnsortedNondestructive(), the caller can + // continue to push() new elements. + std::vector *ExtractUnsortedNondestructive() const; + + // A non-destructive version of ExtractUnsorted(). Copy the elements into + // a given vector, with no guarantees the elements are in sorted order. + // After calling ExtractUnsortedNondestructive(), the caller can continue + // to push() new elements. + // Note: + // 1. The given argument must to be allocated. + // 2. Any data contained in the vector prior to the call will be deleted + // from it. After the call the vector will contain only the elements + // from the data structure. + void ExtractUnsortedNondestructive(std::vector *output) const; + + // Return an iterator to the beginning (end) of the container, + // with no guarantees about the order of iteration. These iterators are + // invalidated by mutation of the data structure. + UnsortedIterator unsorted_begin() const { return elements_.begin(); } + UnsortedIterator unsorted_end() const { return elements_.end(); } + + // Accessor for comparator template argument. + Cmp *comparator() { return &cmp_; } + + // This removes all elements. If Extract() or ExtractUnsorted() have been + // called, this will put it back in an empty but useable state. 
+ void Reset(); + + private: + template + void PushInternal(U &&v, T *dropped); // NOLINT(build/c++11) + + // elements_ can be in one of two states: + // elements_.size() <= limit_ && state_ != HEAP_SORTED: + // elements_ is an unsorted vector of elements pushed so far. + // elements_.size() == limit_ && state_ == HEAP_SORTED: + // elements_ is an stl heap. + std::vector elements_; + size_t limit_; // Maximum number of elements to find + Cmp cmp_; // Greater-than comparison function + State state_ = UNORDERED; +}; + +// ---------------------------------------------------------------------- +// Implementations of non-inline functions + +template +template +void TopN::PushInternal(U &&v, T *dropped) { // NOLINT(build/c++11) + if (limit_ == 0) { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + return; + } + if (state_ != HEAP_SORTED) { + // We may temporarily extend one beyond limit_ elements here. This is + // necessary for finding and removing the smallest element. + elements_.push_back(std::forward(v)); // NOLINT(build/c++11) + if (elements_.size() == limit_ + 1) { + // Transition from unsorted vector to a heap. + std::make_heap(elements_.begin(), elements_.end(), cmp_); + std::pop_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.back()); + elements_.pop_back(); // Restore to size limit_. + state_ = HEAP_SORTED; + } else if (state_ == UNORDERED || + cmp_(elements_.back(), elements_.front())) { + // Easy case: we just push the new element back + } else { + // To maintain the BOTTOM_KNOWN state, we need to make sure that + // the element at position 0 is always the smallest. So we put + // the new element at position 0 and push the original bottom + // element in the back. + // Warning: this code is subtle. + using std::swap; + swap(elements_.front(), elements_.back()); + } + + } else { + // Only insert the new element if it is greater than the least element. + if (cmp_(v, elements_.front())) { + // Remove the top (smallest) element of the min heap, then push the new + // value in. 
+ std::pop_heap(elements_.begin(), elements_.end(), cmp_); + if (dropped) *dropped = std::move(elements_.back()); + elements_.back() = std::forward(v); + std::push_heap(elements_.begin(), elements_.end(), cmp_); + } else { + if (dropped) *dropped = std::forward(v); // NOLINT(build/c++11) + } + } +} + +template +const T &TopN::peek_bottom() { + CHECK(!empty()); + if (state_ == UNORDERED) { + // We need to do a linear scan to find out the bottom element + int min_candidate = 0; + for (size_t i = 1; i < elements_.size(); ++i) { + if (cmp_(elements_[min_candidate], elements_[i])) { + min_candidate = i; + } + } + // By swapping the element at position 0 and the minimal + // element, we transition to the BOTTOM_KNOWN state + if (min_candidate != 0) { + using std::swap; + swap(elements_[0], elements_[min_candidate]); + } + state_ = BOTTOM_KNOWN; + } + return elements_.front(); +} + +template +std::vector *TopN::Extract() { + auto out = new std::vector; + out->swap(elements_); + if (state_ != HEAP_SORTED) { + std::sort(out->begin(), out->end(), cmp_); + } else { + std::sort_heap(out->begin(), out->end(), cmp_); + } + return out; +} + +template +std::vector *TopN::ExtractUnsorted() { + auto out = new std::vector; + out->swap(elements_); + return out; +} + +template +std::vector *TopN::ExtractNondestructive() const { + auto out = new std::vector; + ExtractNondestructive(out); + return out; +} + +template +void TopN::ExtractNondestructive(std::vector *output) const { + CHECK(output); + *output = elements_; + if (state_ != HEAP_SORTED) { + std::sort(output->begin(), output->end(), cmp_); + } else { + std::sort_heap(output->begin(), output->end(), cmp_); + } +} + +template +std::vector *TopN::ExtractUnsortedNondestructive() const { + auto elements = new std::vector; + ExtractUnsortedNondestructive(elements); + return elements; +} + +template +void TopN::ExtractUnsortedNondestructive(std::vector *output) const { + CHECK(output); + *output = elements_; +} + +template +void TopN::Reset() { + elements_.clear(); + state_ = UNORDERED; +} + +} // namespace gtl +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_GTL_TOP_N_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/hash/crc32c.h b/third_party/tflite-hdrs/tensorflow/core/lib/hash/crc32c.h new file mode 100644 index 00000000..7e8c8307 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/hash/crc32c.h @@ -0,0 +1,38 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
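A compact usage sketch of the push/Extract workflow that top_n.h describes, assuming the vendored header is on the include path; the limit and input values are arbitrary. Note that Extract() returns a heap-allocated vector the caller must own:

#include <memory>
#include <vector>

#include "tensorflow/core/lib/gtl/top_n.h"

std::vector<int> ThreeLargest(const std::vector<int>& values) {
  // The default comparator is the "greater" predicate, so this keeps the
  // three largest values and Extract() returns them in descending order.
  tensorflow::gtl::TopN<int> top(3);
  top.reserve(values.size());        // optional preallocation
  for (int v : values) top.push(v);  // lowest elements are dropped past the limit

  // Extract() is destructive; only Reset() may be called on `top` afterwards.
  std::unique_ptr<std::vector<int>> out(top.Extract());
  return *out;  // e.g. {9, 7, 5} for input {5, 1, 9, 3, 7}
}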
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_HASH_CRC32C_H_ +#define TENSORFLOW_CORE_LIB_HASH_CRC32C_H_ + +#include + +#include "xla/tsl/lib/hash/crc32c.h" +#include "tensorflow/core/platform/cord.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace crc32c { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::crc32c::Extend; +using tsl::crc32c::kMaskDelta; +using tsl::crc32c::Mask; +using tsl::crc32c::Unmask; +using tsl::crc32c::Value; +// NOLINTEND(misc-unused-using-decls) +} // namespace crc32c +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_HASH_CRC32C_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/hash/hash.h b/third_party/tflite-hdrs/tensorflow/core/lib/hash/hash.h new file mode 100644 index 00000000..fa2cc295 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/hash/hash.h @@ -0,0 +1,23 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Simple hash functions used for internal data structures + +#ifndef TENSORFLOW_CORE_LIB_HASH_HASH_H_ +#define TENSORFLOW_CORE_LIB_HASH_HASH_H_ + +#include "tensorflow/core/platform/hash.h" + +#endif // TENSORFLOW_CORE_LIB_HASH_HASH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/histogram/histogram.h b/third_party/tflite-hdrs/tensorflow/core/lib/histogram/histogram.h new file mode 100644 index 00000000..281e190f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/histogram/histogram.h @@ -0,0 +1,41 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_ +#define TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_ + +#include +#include + +#include "xla/tsl/lib/histogram/histogram.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using tsl::HistogramProto; // NOLINT + +namespace histogram { + +using tsl::histogram::Histogram; // NOLINT +using tsl::histogram::ThreadSafeHistogram; // NOLINT + +} // namespace histogram +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/block.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/block.h new file mode 100644 index 00000000..d3cfb88f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/block.h @@ -0,0 +1,28 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_BLOCK_H_ +#define TENSORFLOW_CORE_LIB_IO_BLOCK_H_ + +#include "xla/tsl/lib/io/block.h" +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { +using tsl::table::Block; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_BLOCK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/block_builder.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/block_builder.h new file mode 100644 index 00000000..b47278cb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/block_builder.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_BLOCK_BUILDER_H_ +#define TENSORFLOW_CORE_LIB_IO_BLOCK_BUILDER_H_ + +#include "xla/tsl/lib/io/block_builder.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace table { +using tsl::table::BlockBuilder; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_BLOCK_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/buffered_inputstream.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/buffered_inputstream.h new file mode 100644 index 00000000..15023e6a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/buffered_inputstream.h @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#define TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_ + +#include "xla/tsl/lib/io/buffered_inputstream.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { +namespace io { +using tsl::io::BufferedInputStream; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/cache.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/cache.h new file mode 100644 index 00000000..3afd011f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/cache.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_CACHE_H_ +#define TENSORFLOW_CORE_LIB_IO_CACHE_H_ + +#include "xla/tsl/lib/io/cache.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +using tsl::Slice; // NOLINT(misc-unused-using-decls) +namespace table { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::table::Cache; +using tsl::table::NewLRUCache; +// NOLINTEND(misc-unused-using-decls) +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/compression.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/compression.h new file mode 100644 index 00000000..628de375 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/compression.h @@ -0,0 +1,34 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ +#define TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ + +#include "xla/tsl/lib/io/compression.h" + +namespace tensorflow { +namespace io { +namespace compression { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::io::compression::kGzip; +using tsl::io::compression::kNone; +using tsl::io::compression::kSnappy; +using tsl::io::compression::kZlib; +// NOLINTEND(misc-unused-using-decls) +} // namespace compression +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/format.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/format.h new file mode 100644 index 00000000..49f96d19 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/format.h @@ -0,0 +1,36 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_FORMAT_H_ +#define TENSORFLOW_CORE_LIB_IO_FORMAT_H_ + +#include "xla/tsl/lib/io/format.h" +#include "tensorflow/core/lib/io/table_builder.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace table { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::table::BlockContents; +using tsl::table::BlockHandle; +using tsl::table::kBlockTrailerSize; +using tsl::table::kTableMagicNumber; +using tsl::table::ReadBlock; +// NOLINTEND(misc-unused-using-decls) +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_FORMAT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/inputbuffer.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/inputbuffer.h new file mode 100644 index 00000000..2573a816 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/inputbuffer.h @@ -0,0 +1,32 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_INPUTBUFFER_H_ +#define TENSORFLOW_CORE_LIB_IO_INPUTBUFFER_H_ + +#include "xla/tsl/lib/io/inputbuffer.h" +#include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +using tsl::io::InputBuffer; // NOLINT(misc-unused-using-decls) +} +} + +#endif // TENSORFLOW_CORE_LIB_IO_INPUTBUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/inputstream_interface.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/inputstream_interface.h new file mode 100644 index 00000000..f38489d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/inputstream_interface.h @@ -0,0 +1,31 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#define TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ + +#include "xla/tsl/lib/io/inputstream_interface.h" +#include "tensorflow/core/platform/cord.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +using tsl::io::InputStreamInterface; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/iterator.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/iterator.h new file mode 100644 index 00000000..4f3c0960 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/iterator.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An iterator yields a sequence of key/value pairs from a source. +// The following class defines the interface. Multiple implementations +// are provided by this library. In particular, iterators are provided +// to access the contents of a Table or a DB. +// +// Multiple threads can invoke const methods on an Iterator without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same Iterator must use +// external synchronization. + +#ifndef TENSORFLOW_CORE_LIB_IO_ITERATOR_H_ +#define TENSORFLOW_CORE_LIB_IO_ITERATOR_H_ + +#include "xla/tsl/lib/io/iterator.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace table { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::table::Iterator; +using tsl::table::NewEmptyIterator; +using tsl::table::NewErrorIterator; +// NOLINTEND(misc-unused-using-decls) +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_ITERATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/path.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/path.h new file mode 100644 index 00000000..f5deacd1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/path.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_PATH_H_ +#define TENSORFLOW_CORE_LIB_IO_PATH_H_ + +#include "tensorflow/core/platform/path.h" + +#endif // TENSORFLOW_CORE_LIB_IO_PATH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/proto_encode_helper.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/proto_encode_helper.h new file mode 100644 index 00000000..8ca1d5be --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/proto_encode_helper.h @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_ +#define TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_ + +#include "xla/tsl/lib/io/proto_encode_helper.h" +#include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace io { +using tsl::io::ProtoEncodeHelper; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/random_inputstream.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/random_inputstream.h new file mode 100644 index 00000000..70651bc6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/random_inputstream.h @@ -0,0 +1,30 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_ +#define TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_ + +#include "xla/tsl/lib/io/random_inputstream.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/platform/cord.h" +#include "tensorflow/core/platform/file_system.h" + +namespace tensorflow { +namespace io { +using tsl::io::RandomAccessInputStream; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/record_reader.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/record_reader.h new file mode 100644 index 00000000..c2a06c6b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/record_reader.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_ +#define TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_ + +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/stringpiece.h" +#if !defined(IS_SLIM_BUILD) +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputstream.h" +#endif // IS_SLIM_BUILD +#include "xla/tsl/lib/io/record_reader.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::io::RecordReader; +using tsl::io::RecordReaderOptions; +using tsl::io::SequentialRecordReader; +// NOLINTEND(misc-unused-using-decls) +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/record_writer.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/record_writer.h new file mode 100644 index 00000000..602de00e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/record_writer.h @@ -0,0 +1,41 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ +#define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ + +#include "tensorflow/core/lib/hash/crc32c.h" +#include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#if !defined(IS_SLIM_BUILD) +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_outputbuffer.h" +#endif // IS_SLIM_BUILD +#include "xla/tsl/lib/io/record_writer.h" +#include "tensorflow/core/platform/cord.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::io::RecordWriter; +using tsl::io::RecordWriterOptions; +// NOLINTEND(misc-unused-using-decls) +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/table.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/table.h new file mode 100644 index 00000000..0045829a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/table.h @@ -0,0 +1,28 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_H_ +#define TENSORFLOW_CORE_LIB_IO_TABLE_H_ + +#include "xla/tsl/lib/io/table.h" +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { +using tsl::table::Table; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_TABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/table_builder.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/table_builder.h new file mode 100644 index 00000000..52e27e9a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/table_builder.h @@ -0,0 +1,38 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TableBuilder provides the interface used to build a Table +// (an immutable and sorted map from keys to values). 
+// +// Multiple threads can invoke const methods on a TableBuilder without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same TableBuilder must use +// external synchronization. + +#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_ +#define TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_ + +#include "xla/tsl/lib/io/table_builder.h" +#include "tensorflow/core/lib/io/table_options.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace table { +using tsl::table::TableBuilder; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/table_options.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/table_options.h new file mode 100644 index 00000000..c16d4aca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/table_options.h @@ -0,0 +1,32 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_ +#define TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_ + +#include "xla/tsl/lib/io/table_options.h" + +namespace tensorflow { +namespace table { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::table::CompressionType; +using tsl::table::kNoCompression; +using tsl::table::kSnappyCompression; +using tsl::table::Options; +// NOLINTEND(misc-unused-using-decls) +} // namespace table +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/two_level_iterator.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/two_level_iterator.h new file mode 100644 index 00000000..c2b94de7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/two_level_iterator.h @@ -0,0 +1,28 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
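A quick usage sketch for the record_reader.h / record_writer.h shims introduced above (TFRecord-style I/O). This is illustrative only: it assumes the usual Env/WritableFile plumbing from tensorflow/core/platform/env.h and reduces error handling to TF_CHECK_OK.

```cpp
#include <memory>
#include <string>

#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"

// Round-trips one record through the forwarded TFRecord classes.
void RecordRoundTrip(const std::string& fname) {
  tensorflow::Env* env = tensorflow::Env::Default();

  // Write a single record.
  std::unique_ptr<tensorflow::WritableFile> out;
  TF_CHECK_OK(env->NewWritableFile(fname, &out));
  tensorflow::io::RecordWriter writer(out.get());  // tsl::io::RecordWriter via the shim
  TF_CHECK_OK(writer.WriteRecord("hello tfrecord"));
  TF_CHECK_OK(writer.Close());

  // Read it back sequentially.
  std::unique_ptr<tensorflow::RandomAccessFile> in;
  TF_CHECK_OK(env->NewRandomAccessFile(fname, &in));
  tensorflow::io::SequentialRecordReader reader(in.get());
  tensorflow::tstring record;
  TF_CHECK_OK(reader.ReadRecord(&record));
}
```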
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#define TENSORFLOW_CORE_LIB_IO_TWO_LEVEL_ITERATOR_H_ + +#include "xla/tsl/lib/io/two_level_iterator.h" +#include "tensorflow/core/lib/io/iterator.h" + +namespace tensorflow { +namespace table { +using tsl::table::NewTwoLevelIterator; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_TWO_LEVEL_ITERATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_compression_options.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_compression_options.h new file mode 100644 index 00000000..a0d43378 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_compression_options.h @@ -0,0 +1,28 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#define TENSORFLOW_CORE_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ + +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +using tsl::io::ZlibCompressionOptions; // NOLINT(misc-unused-using-decls) +} +} + +#endif // TENSORFLOW_CORE_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_inputstream.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_inputstream.h new file mode 100644 index 00000000..086493e3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_inputstream.h @@ -0,0 +1,33 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_INPUTSTREAM_H_ +#define TENSORFLOW_CORE_LIB_IO_ZLIB_INPUTSTREAM_H_ + +#include "xla/tsl/lib/io/zlib_inputstream.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +using tsl::io::ZlibInputStream; // NOLINT(misc-unused-using-decls); +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_ZLIB_INPUTSTREAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_outputbuffer.h b/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_outputbuffer.h new file mode 100644 index 00000000..7d3950f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -0,0 +1,34 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#define TENSORFLOW_CORE_LIB_IO_ZLIB_OUTPUTBUFFER_H_ + +#include "xla/tsl/lib/io/zlib_outputbuffer.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { +using tsl::io::ZlibOutputBuffer; // NOLINT(misc-unused-using-decls) +} +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_IO_ZLIB_OUTPUTBUFFER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/jpeg/jpeg_handle.h b/third_party/tflite-hdrs/tensorflow/core/lib/jpeg/jpeg_handle.h new file mode 100644 index 00000000..8b2dd418 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/jpeg/jpeg_handle.h @@ -0,0 +1,61 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file declares the functions and structures for memory I/O with libjpeg +// These functions are not meant to be used directly, see jpeg_mem.h instead. 
+ +#ifndef TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_ +#define TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_ + +#include "tensorflow/core/platform/jpeg.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace jpeg { + +// Handler for fatal JPEG library errors: clean up & return +void CatchError(j_common_ptr cinfo); + +typedef struct { + struct jpeg_destination_mgr pub; + JOCTET *buffer; + int bufsize; + int datacount; + tstring *dest; +} MemDestMgr; + +typedef struct { + struct jpeg_source_mgr pub; + const unsigned char *data; + unsigned long int datasize; + bool try_recover_truncated_jpeg; +} MemSourceMgr; + +void SetSrc(j_decompress_ptr cinfo, const void *data, + unsigned long int datasize, bool try_recover_truncated_jpeg); + +// JPEG destination: we will store all the data in a buffer "buffer" of total +// size "bufsize", if the buffer overflows, we will be in trouble. +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize); +// Same as above, except that buffer is only used as a temporary structure and +// is emptied into "destination" as soon as it fills up. +void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize, + tstring *destination); + +} // namespace jpeg +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/jpeg/jpeg_mem.h b/third_party/tflite-hdrs/tensorflow/core/lib/jpeg/jpeg_mem.h new file mode 100644 index 00000000..200e129b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/jpeg/jpeg_mem.h @@ -0,0 +1,163 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines functions to compress and uncompress JPEG files +// to and from memory. It provides interfaces for raw images +// (data array and size fields). +// Direct manipulation of JPEG strings are supplied: Flip, Rotate, Crop.. + +#ifndef TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_ +#define TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_ + +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/jpeg.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace jpeg { + +// Flags for Uncompress +struct UncompressFlags { + // ratio can be 1, 2, 4, or 8 and represent the denominator for the scaling + // factor (eg ratio = 4 means that the resulting image will be at 1/4 original + // size in both directions). + int ratio = 1; + + // The number of bytes per pixel (1, 3 or 4), or 0 for autodetect. + int components = 0; + + // If true, decoder will use a slower but nicer upscaling of the chroma + // planes (yuv420/422 only). 
+ bool fancy_upscaling = true; + + // If true, will attempt to fill in missing lines of truncated files + bool try_recover_truncated_jpeg = false; + + // The minimum required fraction of lines read before the image is accepted. + float min_acceptable_fraction = 1.0; + + // The distance in bytes from one scanline to the other. Should be at least + // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride + // used will be this minimal value. + int stride = 0; + + // Setting of J_DCT_METHOD enum in jpeglib.h, for choosing which + // algorithm to use for DCT/IDCT. + // + // Setting this has a quality/speed trade-off implication. + J_DCT_METHOD dct_method = JDCT_DEFAULT; + + // Settings of crop window before decompression. + bool crop = false; + // Vertical coordinate of the top-left corner of the result in the input. + int crop_x = 0; + // Horizontal coordinate of the top-left corner of the result in the input. + int crop_y = 0; + // Width of the output image. + int crop_width = 0; + // Height of the output image. + int crop_height = 0; +}; + +// Uncompress some raw JPEG data given by the pointer srcdata and the length +// datasize. +// - width and height are the address where to store the size of the +// uncompressed image in pixels. May be nullptr. +// - components is the address where the number of read components are +// stored. This is *output only*: to request a specific number of +// components use flags.components. May be nullptr. +// - nwarn is the address in which to store the number of warnings. +// May be nullptr. +// The function returns a pointer to the raw uncompressed data or NULL if +// there was an error. The caller of the function is responsible for +// freeing the memory (using delete []). +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int* width, int* height, + int* components, // Output only: useful with autodetect + int64_t* nwarn); + +// Version of Uncompress that allocates memory via a callback. The callback +// arguments are (width, height, components). If the size is known ahead of +// time this function can return an existing buffer; passing a callback allows +// the buffer to be shaped based on the JPEG header. The caller is responsible +// for freeing the memory *even along error paths*. +uint8* Uncompress(const void* srcdata, int datasize, + const UncompressFlags& flags, int64_t* nwarn, + std::function<uint8*(int, int, int)> allocate_output); + +// Read jpeg header and get image information. Returns true on success. +// The width, height, and components points may be null. +bool GetImageInfo(const void* srcdata, int datasize, int* width, int* height, + int* components); + +// Note: (format & 0xff) = number of components (<=> bytes per pixels) +enum Format { + FORMAT_GRAYSCALE = 0x001, // 1 byte/pixel + FORMAT_RGB = 0x003, // 3 bytes/pixel RGBRGBRGBRGB... + FORMAT_RGBA = 0x004, // 4 bytes/pixel RGBARGBARGBARGBA... + FORMAT_ABGR = 0x104 // 4 bytes/pixel ABGRABGRABGR...
+}; + +// Flags for compression +struct CompressFlags { + // Encoding of the input data for compression + Format format; + + // Quality of the compression from 0-100 + int quality = 95; + + // If true, create a jpeg image that loads progressively + bool progressive = false; + + // If true, reduce jpeg size without changing quality (at the cost of CPU/RAM) + bool optimize_jpeg_size = false; + + // See http://en.wikipedia.org/wiki/Chroma_subsampling + bool chroma_downsampling = true; + + // Resolution + int density_unit = 1; // 1 = in, 2 = cm + int x_density = 300; + int y_density = 300; + + // If not empty, embed this XMP metadata in the image header + absl::string_view xmp_metadata; + + // The distance in bytes from one scanline to the other. Should be at least + // equal to width*components*sizeof(JSAMPLE). If 0 is passed, the stride + // used will be this minimal value. + int stride = 0; +}; + +// Compress some raw image given in srcdata, the data is a 2D array of size +// stride*height with one of the formats enumerated above. +// The encoded data is returned as a string. +// If not empty, XMP metadata can be embedded in the image header +// On error, returns the empty string (which is never a valid jpeg). +tstring Compress(const void* srcdata, int width, int height, + const CompressFlags& flags); + +// On error, returns false and sets output to empty. +bool Compress(const void* srcdata, int width, int height, + const CompressFlags& flags, tstring* output); + +} // namespace jpeg +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/llvm_rtti/llvm_rtti.h b/third_party/tflite-hdrs/tensorflow/core/lib/llvm_rtti/llvm_rtti.h new file mode 100644 index 00000000..a159e76c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/llvm_rtti/llvm_rtti.h @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_LIB_LLVM_RTTI_LLVM_RTTI_H_ +#define TENSORFLOW_CORE_LIB_LLVM_RTTI_LLVM_RTTI_H_ + +#include "llvm/Support/Casting.h" + +namespace tensorflow { +using llvm::dyn_cast; +using llvm::isa; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_LLVM_RTTI_LLVM_RTTI_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/math/math_util.h b/third_party/tflite-hdrs/tensorflow/core/lib/math/math_util.h new file mode 100644 index 00000000..39bae7f4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/math/math_util.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
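The jpeg_mem.h declarations above (UncompressFlags, Uncompress, CompressFlags, Compress) are enough for a small in-memory round trip. A minimal sketch, with the input buffer and error handling left as placeholders:

```cpp
#include <cstdint>

#include "tensorflow/core/lib/jpeg/jpeg_mem.h"

// Decodes an in-memory JPEG to RGB pixels and re-encodes it at quality 90.
void JpegRoundTrip(const void* jpeg_bytes, int jpeg_len) {
  namespace jpeg = tensorflow::jpeg;

  jpeg::UncompressFlags uflags;
  uflags.components = 3;  // request RGB output
  int width = 0, height = 0, components = 0;
  int64_t nwarn = 0;
  tensorflow::uint8* pixels = jpeg::Uncompress(jpeg_bytes, jpeg_len, uflags,
                                               &width, &height, &components,
                                               &nwarn);
  if (pixels == nullptr) return;  // decode failed

  jpeg::CompressFlags cflags;
  cflags.format = jpeg::FORMAT_RGB;
  cflags.quality = 90;
  tensorflow::tstring encoded = jpeg::Compress(pixels, width, height, cflags);
  (void)encoded;

  delete[] pixels;  // Uncompress allocates with new[]; the caller frees
}
```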
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_ +#define TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_ + +#include "xla/tsl/lib/math/math_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::MathUtil; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/cell_reader-inl.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/cell_reader-inl.h new file mode 100644 index 00000000..f7be2b62 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/cell_reader-inl.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_CELL_READER_INL_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_CELL_READER_INL_H_ + +#include +#include +#include +#include +#include + +#include "xla/tsl/lib/monitoring/cell_reader-inl.h" +#include "tensorflow/core/lib/monitoring/collected_metrics.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/lib/monitoring/test_utils.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +namespace testing { +namespace internal { +using tsl::monitoring::testing::internal::CollectMetrics; +using tsl::monitoring::testing::internal::GetDelta; +using tsl::monitoring::testing::internal::GetLatestPoint; +using tsl::monitoring::testing::internal::GetLatestValueOrDefault; +using tsl::monitoring::testing::internal::GetMetricKind; +using tsl::monitoring::testing::internal::GetPoints; +using tsl::monitoring::testing::internal::GetValue; +} // namespace internal +} // namespace testing +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_CELL_READER_INL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/cell_reader.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/cell_reader.h new file mode 100644 index 00000000..fead3ceb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/cell_reader.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_CELL_READER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_CELL_READER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/tsl/lib/monitoring/cell_reader.h" +#include "tensorflow/core/lib/monitoring/cell_reader-inl.h" +#include "tensorflow/core/lib/monitoring/collected_metrics.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +namespace testing { +using tsl::monitoring::testing::CellReader; +} // namespace testing +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_CELL_READER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/collected_metrics.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/collected_metrics.h new file mode 100644 index 00000000..fe707016 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/collected_metrics.h @@ -0,0 +1,42 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Standard format in which the metrics are collected, before being exported. +// These are to be used only by the CollectionRegistry and exporters which +// collect metrics using the CollectionRegistry. 
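cell_reader.h re-exports the TSL CellReader test helper. A sketched test-side interaction, assuming a hypothetical counter metric defined through the counter.h shim that appears later in this diff:

```cpp
#include <cstdint>

#include "tensorflow/core/lib/monitoring/cell_reader.h"
#include "tensorflow/core/lib/monitoring/counter.h"

namespace {

// Hypothetical metric used only for this sketch.
auto* example_counter = tensorflow::monitoring::Counter<1>::New(
    "/tensorflow/example/counter", "Example counter for the sketch.", "op");

// Observes the per-label delta of the counter from a test.
void ExerciseCellReader() {
  tensorflow::monitoring::testing::CellReader<int64_t> reader(
      "/tensorflow/example/counter");
  example_counter->GetCell("MatMul")->IncrementBy(2);
  const int64_t delta = reader.Delta("MatMul");  // expected: 2
  (void)delta;
}

}  // namespace
```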
+ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ + +#include +#include +#include +#include + +#include "xla/tsl/lib/monitoring/collected_metrics.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/lib/monitoring/types.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +using tsl::monitoring::CollectedMetrics; +using tsl::monitoring::MetricDescriptor; +using tsl::monitoring::Point; +using tsl::monitoring::PointSet; +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTED_METRICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/collection_registry.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/collection_registry.h new file mode 100644 index 00000000..fa379115 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/collection_registry.h @@ -0,0 +1,79 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ + +#include "xla/tsl/lib/monitoring/collection_registry.h" +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on +// We use a null implementation for mobile platforms. 
+#ifdef IS_MOBILE_PLATFORM + +#include +#include +#include + +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/platform/macros.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +using tsl::monitoring::CollectionRegistry; +using tsl::monitoring::MetricCollector; +using tsl::monitoring::MetricCollectorGetter; +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#else // !defined(IS_MOBILE_PLATFORM) + +#include +#include +#include +#include + +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/monitoring/collected_metrics.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/lib/monitoring/types.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace monitoring { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::monitoring::CollectionRegistry; +using tsl::monitoring::Exporter; +using tsl::monitoring::MetricCollector; +using tsl::monitoring::MetricCollectorGetter; +using tsl::monitoring::exporter_registration::ExporterRegistration; +using tsl::monitoring::internal::Collector; +namespace test_util { +class CollectionRegistryTestAccess; +} // namespace test_util +// NOLINTEND(misc-unused-using-decls) +} // namespace monitoring +} // namespace tensorflow + +#endif // IS_MOBILE_PLATFORM + +#endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/counter.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/counter.h new file mode 100644 index 00000000..35f68891 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/counter.h @@ -0,0 +1,43 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ + +#include "xla/tsl/lib/monitoring/counter.h" +#ifdef IS_MOBILE_PLATFORM +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#else +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#endif +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { + +using tsl::monitoring::Counter; +using tsl::monitoring::CounterCell; + +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/gauge.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/gauge.h new file mode 100644 index 00000000..301f3683 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/gauge.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ + +#include "xla/tsl/lib/monitoring/gauge.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +using tsl::monitoring::Gauge; +using tsl::monitoring::GaugeCell; + +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_GAUGE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/metric_def.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/metric_def.h new file mode 100644 index 00000000..bd256d50 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/metric_def.h @@ -0,0 +1,38 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
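With the counter.h and gauge.h shims above, defining and updating metrics looks the same as with the underlying tsl::monitoring classes. A sketch with illustrative metric names:

```cpp
#include <string>

#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"

namespace {

// Illustrative metrics: one label on the counter, none on the gauge.
auto* request_count = tensorflow::monitoring::Counter<1>::New(
    "/example/requests", "Number of requests seen.", "method");
auto* backend_name = tensorflow::monitoring::Gauge<std::string, 0>::New(
    "/example/backend", "Currently selected backend.");

void RecordRequest() {
  request_count->GetCell("GET")->IncrementBy(1);
  backend_name->GetCell()->Set("gpu");
}

}  // namespace
```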
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ + +#include +#include +#include +#include + +#include "xla/tsl/lib/monitoring/metric_def.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/monitoring/types.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +using tsl::monitoring::MetricDef; +using tsl::monitoring::MetricKind; +using tsl::monitoring::ValueType; +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/percentile_sampler.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/percentile_sampler.h new file mode 100644 index 00000000..8ac77500 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/percentile_sampler.h @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_PERCENTILE_SAMPLER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_PERCENTILE_SAMPLER_H_ + +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +#include "xla/tsl/lib/monitoring/percentile_sampler.h" +// clang-format on + +// We replace this implementation with a null implementation for mobile +// platforms. 
+#ifdef IS_MOBILE_PLATFORM + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/lib/monitoring/types.h" +#include "tensorflow/core/platform/macros.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { + +using tsl::monitoring::PercentileSampler; +using tsl::monitoring::PercentileSamplerCell; + +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#else // IS_MOBILE_PLATFORM + +#include +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/lib/monitoring/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace monitoring { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::monitoring::PercentileSampler; +using tsl::monitoring::PercentileSamplerCell; +// NOLINTEND(misc-unused-using-decls) +} // namespace monitoring +} // namespace tensorflow + +#endif // IS_MOBILE_PLATFORM +#endif // TENSORFLOW_CORE_LIB_MONITORING_PERCENTILE_SAMPLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/sampler.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/sampler.h new file mode 100644 index 00000000..e794890a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/sampler.h @@ -0,0 +1,54 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ + +#include "xla/tsl/lib/monitoring/sampler.h" +#ifdef IS_MOBILE_PLATFORM + +#include + +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#else // IS_MOBILE_PLATFORM + +#include + +#include + +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/lib/monitoring/collection_registry.h" +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#endif +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { + +using tsl::monitoring::Buckets; +using tsl::monitoring::Sampler; +using tsl::monitoring::SamplerCell; +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/test_utils.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/test_utils.h new file mode 100644 index 00000000..a479c878 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/test_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_LIB_MONITORING_TEST_UTILS_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_TEST_UTILS_H_ + +#include + +#include "xla/tsl/lib/monitoring/test_utils.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/lib/monitoring/types.h" +#include "tensorflow/core/platform/statusor.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +namespace testing { +using tsl::monitoring::testing::Histogram; +using tsl::monitoring::testing::Percentiles; + +} // namespace testing +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/timed.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/timed.h new file mode 100644 index 00000000..c8ec0b8c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/timed.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
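sampler.h forwards Sampler, SamplerCell and the Buckets helper; a sketched histogram-style metric (names and bucket layout are illustrative):

```cpp
#include "tensorflow/core/lib/monitoring/sampler.h"

namespace {

// Illustrative histogram metric: 16 exponential buckets starting at 1, doubling.
auto* latency_sampler = tensorflow::monitoring::Sampler<1>::New(
    {"/example/latency", "Request latency in milliseconds.", "method"},
    tensorflow::monitoring::Buckets::Exponential(1, 2, 16));

void RecordLatency(double millis) {
  latency_sampler->GetCell("GET")->Add(millis);
}

}  // namespace
```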
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_TIMED_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_TIMED_H_ + +#include "xla/tsl/lib/monitoring/timed.h" +#include "tensorflow/core/platform/env_time.h" +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +using tsl::monitoring::MakeTimed; +using tsl::monitoring::Timed; +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_TIMED_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/types.h b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/types.h new file mode 100644 index 00000000..d84a7402 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/monitoring/types.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_MONITORING_TYPES_H_ +#define TENSORFLOW_CORE_LIB_MONITORING_TYPES_H_ + +#include +#include + +#include "xla/tsl/lib/monitoring/types.h" +#include "tensorflow/core/platform/types.h" + +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace monitoring { +using tsl::monitoring::PercentilePoint; +using tsl::monitoring::Percentiles; +using tsl::monitoring::UnitOfMeasure; + +} // namespace monitoring +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) +#endif // TENSORFLOW_CORE_LIB_MONITORING_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/png/png_io.h b/third_party/tflite-hdrs/tensorflow/core/lib/png/png_io.h new file mode 100644 index 00000000..a7fff84c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/png/png_io.h @@ -0,0 +1,116 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functions to read and write images in PNG format. 
+// +// The advantage over image/codec/png{enc,dec}ocder.h is that this library +// supports both 8 and 16 bit images. +// +// The decoding routine accepts binary image data as a StringPiece. These are +// implicitly constructed from strings or char* so they're completely +// transparent to the caller. They're also very cheap to construct so this +// doesn't introduce any additional overhead. +// +// The primary benefit of StringPieces being, in this case, that APIs already +// returning StringPieces (e.g., Bigtable Scanner) or Cords (e.g., IOBuffer; +// only when they're flat, though) or protocol buffer fields typed to either of +// these can be decoded without copying the data into a C++ string. + +#ifndef TENSORFLOW_CORE_LIB_PNG_PNG_IO_H_ +#define TENSORFLOW_CORE_LIB_PNG_PNG_IO_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "absl/base/casts.h" +#include "tensorflow/core/platform/png.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace png { + +// Handy container for decoding information and struct pointers +struct DecodeContext { + const uint8* data; + int data_left; + png_structp png_ptr; + png_infop info_ptr; + png_uint_32 width, height; + int num_passes; + int color_type; + int bit_depth; + int channels; + bool need_to_synthesize_16; + bool error_condition; + DecodeContext() : png_ptr(nullptr), info_ptr(nullptr) {} +}; + +bool DecodeHeader(absl::string_view png_string, int* width, int* height, + int* components, int* channel_bit_depth, + std::vector<std::pair<std::string, std::string> >* metadata); + +// Sample usage for reading PNG: +// +// string png_string; /* fill with input PNG format data */ +// DecodeContext context; +// CHECK(CommonInitDecode(png_string, 3 /*RGB*/, 8 /*uint8*/, &context)); +// char* image_buffer = new char[3*context.width*context.height]; +// CHECK(CommonFinishDecode(absl::bit_cast<png_bytep>(image_buffer), +// 3*context.width /*stride*/, &context)); +// +// desired_channels may be 0 to detected it from the input. + +bool CommonInitDecode(absl::string_view png_string, int desired_channels, + int desired_channel_bits, DecodeContext* context); + +bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context); + +// Normally called automatically from CommonFinishDecode. If CommonInitDecode +// is called but not CommonFinishDecode, call this to clean up. Safe to call +// extra times. +void CommonFreeDecode(DecodeContext* context); + +// Sample usage for writing PNG: +// +// uint16* image_buffer = new uint16[width*height]; /* fill with pixels */ +// string png_string; +// CHECK(WriteImageToBuffer(image_buffer, width, height, 2*width /*stride*/, +// 1 /*gray*/, 16 /*uint16*/, &png_string, NULL)); +// +// compression is in [-1,9], where 0 is fast and weak compression, 9 is slow +// and strong, and -1 is the zlib default. + +template <typename T> +bool WriteImageToBuffer( + const void* image, int width, int height, int row_bytes, int num_channels, + int channel_bits, int compression, T* png_string, + const std::vector<std::pair<std::string, std::string> >* metadata); + +// Explicit instantiations defined in png_io.cc.
+extern template bool WriteImageToBuffer<std::string>( + const void* image, int width, int height, int row_bytes, int num_channels, + int channel_bits, int compression, std::string* png_string, + const std::vector<std::pair<std::string, std::string> >* metadata); +extern template bool WriteImageToBuffer<tstring>( + const void* image, int width, int height, int row_bytes, int num_channels, + int channel_bits, int compression, tstring* png_string, + const std::vector<std::pair<std::string, std::string> >* metadata); + +} // namespace png +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_PNG_PNG_IO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/distribution_sampler.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/distribution_sampler.h new file mode 100644 index 00000000..6218d899 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/distribution_sampler.h @@ -0,0 +1,47 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// DistributionSampler allows generating a discrete random variable with a given +// distribution. +// The values taken by the variable are [0, N) and relative weights for each +// value are specified using a vector of size N. +// +// The Algorithm takes O(N) time to precompute data at construction time and +// takes O(1) time (2 random number generation, 2 lookups) for each sample. +// The data structure takes O(N) memory. +// +// In contrast, util/random/weighted-picker.h provides O(lg N) sampling. +// The advantage of that implementation is that weights can be adjusted +// dynamically, while DistributionSampler doesn't allow weight adjustment. +// +// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2. + +#ifndef TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ + +#include "xla/tsl/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace random { +using tsl::random::DistributionSampler; // NOLINT(misc-unused-using-decls) +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/exact_uniform_int.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/exact_uniform_int.h new file mode 100644 index 00000000..cd511d43 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/exact_uniform_int.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
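distribution_sampler.h forwards the Walker-alias DistributionSampler described in the header comment above. A sketch of O(1) weighted sampling, assuming the companion simple_philox.h shim referenced by its include list is also vendored; weights and seeds are arbitrary:

```cpp
#include <vector>

#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"

// Draws an index in [0, 3) with probability proportional to the weights.
int SampleIndex() {
  const std::vector<float> weights = {0.1f, 0.3f, 0.6f};
  tensorflow::random::DistributionSampler sampler(weights);

  tensorflow::random::PhiloxRandom philox(/*seed_lo=*/301, /*seed_hi=*/17);
  tensorflow::random::SimplePhilox rnd(&philox);
  return sampler.Sample(&rnd);  // O(1) per draw (Walker's alias method)
}
```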
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Exact uniform integers using rejection sampling + +#ifndef TENSORFLOW_CORE_LIB_RANDOM_EXACT_UNIFORM_INT_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_EXACT_UNIFORM_INT_H_ + +#include "xla/tsl/lib/random/exact_uniform_int.h" + +namespace tensorflow { +namespace random { +using tsl::random::ExactUniformInt; // NOLINT(misc-unused-using-decls) +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_RANDOM_EXACT_UNIFORM_INT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/philox_random.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/philox_random.h new file mode 100644 index 00000000..2fe4120f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/philox_random.h @@ -0,0 +1,35 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implement the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + +#ifndef TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_ + +#include "xla/tsl/lib/random/philox_random.h" + +namespace tensorflow { +namespace random { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::random::Array; +using tsl::random::PhiloxRandom; +// NOLINTEND(misc-unused-using-decls) + +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/random.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/random.h new file mode 100644 index 00000000..78dedde0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/random.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_H_ + +#include "tensorflow/core/platform/random.h" + +#endif // TENSORFLOW_CORE_LIB_RANDOM_RANDOM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/random_distributions.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/random_distributions.h new file mode 100644 index 00000000..57ce99a0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/random_distributions.h @@ -0,0 +1,41 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/lib/random/random_distributions.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions_utils.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace random { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::random::BoxMullerDouble; +using tsl::random::NormalDistribution; +using tsl::random::SignedAdd; +using tsl::random::SingleSampleAdapter; +using tsl::random::TruncatedNormalDistribution; +using tsl::random::Uint16ToGfloat16; +using tsl::random::Uint16ToHalf; +using tsl::random::UniformDistribution; +using tsl::random::UniformFullIntDistribution; +// NOLINTEND(misc-unused-using-decls) +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/random_distributions_utils.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/random_distributions_utils.h new file mode 100644 index 00000000..4c268049 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/random_distributions_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ + +#include + +#include + +#include "xla/tsl/lib/random/random_distributions_utils.h" +#include "tensorflow/core/lib/random/philox_random.h" + +namespace tensorflow { +namespace random { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::random::BoxMullerFloat; +using tsl::random::Uint32ToFloat; +using tsl::random::Uint64ToDouble; +// NOLINTEND(misc-unused-using-decls) +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/simple_philox.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/simple_philox.h new file mode 100644 index 00000000..7c94ca21 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/simple_philox.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_ + +#include "xla/tsl/lib/random/simple_philox.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random_distributions.h" + +namespace tensorflow { +namespace random { +using tsl::random::SimplePhilox; // NOLINT(misc-unused-using-decls) +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/random/weighted_picker.h b/third_party/tflite-hdrs/tensorflow/core/lib/random/weighted_picker.h new file mode 100644 index 00000000..ae404814 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/random/weighted_picker.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An abstraction to pick from one of N elements with a specified +// weight per element. +// +// The weight for a given element can be changed in O(lg N) time +// An element can be picked in O(lg N) time. +// +// Uses O(N) bytes of memory. +// +// Alternative: distribution-sampler.h allows O(1) time picking, but no weight +// adjustment after construction. 
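To make the tradeoff with DistributionSampler concrete, here is a minimal usage sketch for the WeightedPicker re-exported by this header. It is illustrative only, not part of the patch, and assumes the upstream tsl::random::WeightedPicker and SimplePhilox interfaces (a constructor taking N, set_weight(), and Pick()):

```cpp
// Illustrative sketch, not part of the patch: O(lg N) weighted picking with
// weights that can be adjusted after construction, in contrast to the O(1)
// but fixed-weight DistributionSampler.
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/lib/random/weighted_picker.h"

int PickWeighted() {
  tensorflow::random::PhiloxRandom philox(/*seed=*/301);
  tensorflow::random::SimplePhilox rnd(&philox);

  tensorflow::random::WeightedPicker picker(/*N=*/4);  // assumed: weights start at 1
  picker.set_weight(2, 10);   // O(lg N) update: element 2 now has weight 10
  return picker.Pick(&rnd);   // O(lg N) sample: returns an index in [0, 4)
}
```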
+ +#ifndef TENSORFLOW_CORE_LIB_RANDOM_WEIGHTED_PICKER_H_ +#define TENSORFLOW_CORE_LIB_RANDOM_WEIGHTED_PICKER_H_ + +#include + +#include "xla/tsl/lib/random/weighted_picker.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace random { +using tsl::random::WeightedPicker; // NOLINT(misc-unused-using-decls) +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_RANDOM_WEIGHTED_PICKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/base64.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/base64.h new file mode 100644 index 00000000..bb7cbfb3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/base64.h @@ -0,0 +1,21 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ + +#include "tensorflow/core/platform/base64.h" + +#endif // TENSORFLOW_CORE_LIB_STRINGS_BASE64_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/numbers.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/numbers.h new file mode 100644 index 00000000..cbc53d47 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/numbers.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_ + +#include "tensorflow/core/platform/numbers.h" + +#endif // TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/ordered_code.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/ordered_code.h new file mode 100644 index 00000000..e7485bd5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/ordered_code.h @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This module provides routines for encoding a sequence of typed +// entities into a string. The resulting strings can be +// lexicographically compared to yield the same comparison value that +// would have been generated if the encoded items had been compared +// one by one according to their type. +// +// More precisely, suppose: +// 1. string A is generated by encoding the sequence of items [A_1..A_n] +// 2. string B is generated by encoding the sequence of items [B_1..B_n] +// 3. The types match; i.e., for all i: A_i was encoded using +// the same routine as B_i +// Then: +// Comparing A vs. B lexicographically is the same as comparing +// the vectors [A_1..A_n] and [B_1..B_n] lexicographically. +// +// Furthermore, if n < m, the encoding of [A_1..A_n] is a strict prefix of +// [A_1..A_m] (unless m = n+1 and A_m is the empty string encoded with +// WriteTrailingString, in which case the encodings are equal). +// +// This module is often useful when generating multi-part sstable +// keys that have to be ordered in a particular fashion. + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_ORDERED_CODE_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_ORDERED_CODE_H_ + +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace strings { + +class OrderedCode { + public: + // ------------------------------------------------------------------- + // Encoding routines: each one of the following routines append + // one item to "*dest" in an encoding where larger values are + // ordered lexicographically after smaller values. + static void WriteString(string* dest, absl::string_view str); + static void WriteNumIncreasing(string* dest, uint64 num); + static void WriteSignedNumIncreasing(string* dest, int64_t num); + + // ------------------------------------------------------------------- + // Decoding routines: these extract an item earlier encoded using + // the corresponding WriteXXX() routines above. The item is read + // from "*src"; "*src" is modified to point past the decoded item; + // and if "result" is non-NULL, "*result" is modified to contain the + // result. In case of string result, the decoded string is appended to + // "*result". Returns true if the next item was read successfully, false + // otherwise. + static bool ReadString(absl::string_view* src, string* result); + static bool ReadNumIncreasing(absl::string_view* src, uint64* result); + static bool ReadSignedNumIncreasing(absl::string_view* src, int64_t* result); + + // Helper for testing: corrupt "*str" by changing the kth item separator + // in the string. + static void TEST_Corrupt(string* str, int k); + + // Helper for testing. + // SkipToNextSpecialByte is an internal routine defined in the .cc file + // with the following semantics. Return a pointer to the first byte + // in the range "[start..limit)" whose value is 0 or 255. If no such + // byte exists in the range, returns "limit". 
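As a quick orientation to the Write/Read pairs declared above, the following round-trip sketch (illustrative only, not part of the patch) checks that encoded keys compare lexicographically in the same order as the original values:

```cpp
// Illustrative sketch, not part of the patch: encode two numbers with
// OrderedCode and verify ordering and round-trip decoding.
#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/core/lib/strings/ordered_code.h"

bool OrderedCodeRoundTrip() {
  using tensorflow::strings::OrderedCode;

  std::string a, b;
  OrderedCode::WriteNumIncreasing(&a, 42);
  OrderedCode::WriteNumIncreasing(&b, 1000);
  const bool order_preserved = (a < b);  // same order as 42 < 1000

  // Decode `a` back; ReadNumIncreasing advances the view past the item.
  absl::string_view src(a);
  tensorflow::uint64 value = 0;
  const bool ok = OrderedCode::ReadNumIncreasing(&src, &value);
  return order_preserved && ok && value == 42 && src.empty();
}
```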
+ static const char* TEST_SkipToNextSpecialByte(const char* start, + const char* limit); + + private: + // This has only static methods, so disallow construction entirely + OrderedCode(); + OrderedCode(const OrderedCode&) = delete; + void operator=(const OrderedCode&) = delete; +}; + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_STRINGS_ORDERED_CODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/proto_serialization.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/proto_serialization.h new file mode 100644 index 00000000..e0c253f5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/proto_serialization.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_ + +#include "xla/tsl/lib/strings/proto_serialization.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::AreSerializedProtosEqual; +using ::tsl::DeterministicProtoHash64; +using ::tsl::SerializeToBufferDeterministic; +using ::tsl::SerializeToStringDeterministic; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/proto_text_util.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/proto_text_util.h new file mode 100644 index 00000000..ef73108b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/proto_text_util.h @@ -0,0 +1,169 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ + +#include "absl/strings/str_cat.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/numbers.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/core/platform/strcat.h" + +namespace tensorflow { +namespace strings { + +static constexpr char kColonSeparator[] = ": "; + +// Helper functions for writing proto-text output. 
+// Used by the code generated from tools/proto_text/gen_proto_text_lib.cc. +class ProtoTextOutput { + public: + // Construct a ProtoTextOutput that writes to If short_debug is true, + // outputs text to match proto.ShortDebugString(); else matches + // proto.DebugString(). + ProtoTextOutput(string* output, bool short_debug) + : output_(output), + short_debug_(short_debug), + field_separator_(short_debug ? " " : "\n") {} + + // Writes opening of nested message and increases indent level. + void OpenNestedMessage(const char field_name[]) { + StrAppend(output_, level_empty_ ? "" : field_separator_, indent_, + field_name, " {", field_separator_); + if (!short_debug_) StrAppend(&indent_, " "); + level_empty_ = true; + } + + // Writes close of nested message and decreases indent level. + void CloseNestedMessage() { + if (!short_debug_) indent_.resize(indent_.size() - 2); + StrAppend(output_, level_empty_ ? "" : field_separator_, indent_, "}"); + level_empty_ = false; + } + + // Print the close of the top-level message that was printed. + void CloseTopMessage() { + if (!short_debug_ && !level_empty_) StrAppend(output_, "\n"); + } + + // Appends a numeric value, like my_field: 123 + template + void AppendNumeric(const char field_name[], T value) { + AppendFieldAndValue(field_name, StrCat(value)); + } + + // Appends a numeric value, like my_field: 123, but only if value != 0. + template + void AppendNumericIfNotZero(const char field_name[], T value) { + if (value != 0) AppendNumeric(field_name, value); + } + + // Appends a bool value, either my_field: true or my_field: false. + void AppendBool(const char field_name[], bool value) { + AppendFieldAndValue(field_name, value ? "true" : "false"); + } + + // Appends a bool value, as my_field: true, only if value is true. + void AppendBoolIfTrue(const char field_name[], bool value) { + if (value) AppendBool(field_name, value); + } + + // Appends a string value, like my_field: "abc123". + void AppendString(const char field_name[], const string& value) { + AppendFieldAndValue(field_name, StrCat("\"", absl::CEscape(value), "\"")); + } + + // Appends a string value, like my_field: "abc123", but only if value is not + // empty. + void AppendStringIfNotEmpty(const char field_name[], const string& value) { + if (!value.empty()) AppendString(field_name, value); + } + + // Appends the string name of an enum, like my_field: FIRST_ENUM. + void AppendEnumName(const char field_name[], const string& name) { + AppendFieldAndValue(field_name, name); + } + + private: + void AppendFieldAndValue(const char field_name[], + absl::string_view value_text) { + absl::StrAppend(output_, level_empty_ ? "" : field_separator_, indent_, + field_name, kColonSeparator, value_text); + level_empty_ = false; + } + + string* const output_; + const bool short_debug_; + const string field_separator_; + string indent_; + + // False when at least one field has been output for the message at the + // current deepest level of nesting. + bool level_empty_ = true; + + ProtoTextOutput(const ProtoTextOutput&) = delete; + void operator=(const ProtoTextOutput&) = delete; +}; + +inline void ProtoSpaceAndComments(Scanner* scanner) { + for (;;) { + scanner->AnySpace(); + if (scanner->Peek() != '#') return; + // Skip until newline. + while (scanner->Peek('\n') != '\n') scanner->One(Scanner::ALL); + } +} + +// Parse the next numeric value from , returning false if parsing +// failed. 
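For reference, here is a small usage sketch of the ProtoTextOutput helper defined above (illustrative only, not part of the patch); it emits DebugString-style text for a hypothetical two-field message:

```cpp
// Illustrative sketch, not part of the patch: emit DebugString-style text
// using the ProtoTextOutput class defined in this header.
#include <string>

#include "tensorflow/core/lib/strings/proto_text_util.h"

std::string ExampleProtoText() {
  std::string out;
  tensorflow::strings::ProtoTextOutput o(&out, /*short_debug=*/false);
  o.AppendNumeric("id", 123);
  o.OpenNestedMessage("payload");
  o.AppendString("name", "hello");
  o.AppendBoolIfTrue("ready", true);
  o.CloseNestedMessage();
  o.CloseTopMessage();
  // out == "id: 123\npayload {\n  name: \"hello\"\n  ready: true\n}\n"
  return out;
}
```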
+template +bool ProtoParseNumericFromScanner(Scanner* scanner, T* value) { + absl::string_view numeric_str; + scanner->RestartCapture(); + if (!scanner->Many(Scanner::LETTER_DIGIT_DOT_PLUS_MINUS) + .GetResult(nullptr, &numeric_str)) { + return false; + } + + // Special case to disallow multiple leading zeroes, to match proto parsing. + int leading_zero = 0; + for (size_t i = 0; i < numeric_str.size(); ++i) { + const char ch = numeric_str[i]; + if (ch == '0') { + if (++leading_zero > 1) return false; + } else if (ch != '-') { + break; + } + } + + ProtoSpaceAndComments(scanner); + return SafeStringToNumeric(numeric_str, value); +} + +// Parse the next boolean value from , returning false if parsing +// failed. +bool ProtoParseBoolFromScanner(Scanner* scanner, bool* value); + +// Parse the next string literal from , returning false if parsing +// failed. +bool ProtoParseStringLiteralFromScanner(Scanner* scanner, string* value); + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_STRINGS_PROTO_TEXT_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/scanner.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/scanner.h new file mode 100644 index 00000000..c41e4475 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/scanner.h @@ -0,0 +1,21 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_SCANNER_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_SCANNER_H_ + +#include "tensorflow/core/platform/scanner.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_STRINGS_SCANNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/str_util.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/str_util.h new file mode 100644 index 00000000..a20cbdb5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/str_util.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_ + +#include "tensorflow/core/platform/str_util.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/strcat.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/strcat.h new file mode 100644 index 00000000..d728231f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/strcat.h @@ -0,0 +1,25 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// #status: RECOMMENDED +// #category: operations on strings +// #summary: Merges strings or numbers with no delimiter. +// +#ifndef TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_ + +#include "tensorflow/core/platform/strcat.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/strings/stringprintf.h b/third_party/tflite-hdrs/tensorflow/core/lib/strings/stringprintf.h new file mode 100644 index 00000000..836632d7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/strings/stringprintf.h @@ -0,0 +1,28 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Printf variants that place their output in a C++ string. +// +// Usage: +// string result = strings::Printf("%d %s\n", 10, "hello"); +// strings::SPrintf(&result, "%d %s\n", 10, "hello"); +// strings::Appendf(&result, "%d %s\n", 20, "there"); + +#ifndef TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_ +#define TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_ + +#include "tensorflow/core/platform/stringprintf.h" + +#endif // TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/lib/wav/wav_io.h b/third_party/tflite-hdrs/tensorflow/core/lib/wav/wav_io.h new file mode 100644 index 00000000..99a3df50 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/lib/wav/wav_io.h @@ -0,0 +1,105 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functions to write audio in WAV format. + +#ifndef TENSORFLOW_CORE_LIB_WAV_WAV_IO_H_ +#define TENSORFLOW_CORE_LIB_WAV_WAV_IO_H_ + +#include +#include + +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace wav { + +// Encode the provided interleaved buffer of audio as a signed 16-bit PCM +// little-endian WAV file. +// +// Example usage for 4 frames of an 8kHz stereo signal: +// First channel is -1, 1, -1, 1. +// Second channel is 0, 0, 0, 0. +// +// float audio_buffer[] = { -1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f}; +// string wav_string; +// if (EncodeAudioAsS16LEWav(audio_buffer, 8000, 2, 4, &wav_string).ok()) { +// // Use wav_string. +// } +template +absl::Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, + size_t num_channels, size_t num_frames, + T* wav_string); + +// Explicit instantiations defined in wav_io.cc. +extern template Status EncodeAudioAsS16LEWav( + const float* audio, size_t sample_rate, size_t num_channels, + size_t num_frames, std::string* wav_string); +extern template Status EncodeAudioAsS16LEWav(const float* audio, + size_t sample_rate, + size_t num_channels, + size_t num_frames, + tstring* wav_string); + +// Decodes the little-endian signed 16-bit PCM WAV file data (aka LIN16 +// encoding) into a float Tensor. The channels are encoded as the lowest +// dimension of the tensor, with the number of frames as the second. This means +// that a four frame stereo signal will have the shape [4, 2]. The sample rate +// is read from the file header, and an error is returned if the format is not +// supported. +// The results are output as floats within the range -1 to 1, +absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string, + std::vector* float_values, + uint32* sample_count, + uint16* channel_count, + uint32* sample_rate); + +// Everything below here is only exposed publicly for testing purposes. + +// Handles moving the data index forward, validating the arguments, and avoiding +// overflow or underflow. +absl::Status IncrementOffset(int old_offset, int64_t increment, size_t max_size, + int* new_offset); + +// This function is only exposed in the header for testing purposes, as a +// template that needs to be instantiated. Reads a typed numeric value from a +// stream of data. 
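For reference, a minimal sketch of calling the EncodeAudioAsS16LEWav declaration from earlier in this header (illustrative only, not part of the patch; the tone parameters are arbitrary):

```cpp
// Illustrative sketch, not part of the patch: encode one second of a 440 Hz
// mono tone as a 16-bit PCM WAV string.
#include <cmath>
#include <string>
#include <vector>

#include "tensorflow/core/lib/wav/wav_io.h"

bool MakeTestWav(std::string* wav) {
  constexpr size_t kSampleRate = 16000;
  const double kPi = 3.14159265358979323846;
  std::vector<float> samples(kSampleRate);
  for (size_t i = 0; i < samples.size(); ++i) {
    samples[i] = static_cast<float>(0.5 * std::sin(2.0 * kPi * 440.0 * i / kSampleRate));
  }
  return tensorflow::wav::EncodeAudioAsS16LEWav(
             samples.data(), kSampleRate, /*num_channels=*/1,
             /*num_frames=*/samples.size(), wav)
      .ok();
}
```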
+template +absl::Status ReadValue(const std::string& data, T* value, int* offset) { + int new_offset; + TF_RETURN_IF_ERROR( + IncrementOffset(*offset, sizeof(T), data.size(), &new_offset)); + if (port::kLittleEndian) { + memcpy(value, data.data() + *offset, sizeof(T)); + } else { + *value = 0; + const uint8* data_buf = + reinterpret_cast(data.data() + *offset); + int shift = 0; + for (int i = 0; i < sizeof(T); ++i, shift += 8) { + *value = *value | (data_buf[i] << shift); + } + } + *offset = new_offset; + return absl::OkStatus(); +} + +} // namespace wav +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_LIB_WAV_WAV_IO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/nccl/collective_communicator.h b/third_party/tflite-hdrs/tensorflow/core/nccl/collective_communicator.h new file mode 100644 index 00000000..484f1100 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/nccl/collective_communicator.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_NCCL_COLLECTIVE_COMMUNICATOR_H_ +#define TENSORFLOW_CORE_NCCL_COLLECTIVE_COMMUNICATOR_H_ + +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +// Creates a NcclCommunicator if built with NCCL support (unless configured to +// use no GPU devices), otherwise it returns nullptr. +std::unique_ptr MaybeCreateNcclCommunicator( + const ConfigProto& config); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_NCCL_COLLECTIVE_COMMUNICATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/nccl/nccl_manager.h b/third_party/tflite-hdrs/tensorflow/core/nccl/nccl_manager.h new file mode 100644 index 00000000..0e620139 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/nccl/nccl_manager.h @@ -0,0 +1,283 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_ +#define TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include + +// TODO(rmlarsen): Get rid of this workaround. "gpu_assert" is defined when +// setting EIGEN_USE_THREADS. But when defining EIGEN_USE_THREADS here, +// incAtomic and other CUDA specific symbols are no longer recognized. 
+#ifndef gpu_assert +#define gpu_assert(x) +#endif + +#include "absl/container/flat_hash_map.h" +#if GOOGLE_CUDA +#include "third_party/nccl/nccl.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#if (TF_ROCM_VERSION >= 50200) +#include "rocm/include/rccl/rccl.h" +#else +#include "rocm/include/rccl.h" +#endif +#include "tensorflow/core/common_runtime/gpu_device_context.h" +#endif +#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { + +// NCCL manager is used to make the asynchronous communicator calls and to +// manage the per-device streams used for communication. +// +// See nccl_ops.cc for example usage, including description of memory +// management and stream synchronization. +class NcclManager { + public: + typedef std::function DoneCallback; + NcclManager(); + ~NcclManager(); + + static NcclManager* instance(); + +#if TENSORFLOW_USE_ROCM + static int instance_count; +#endif + + // Calls `ncclGetUniqueId` and returns the id as a string. The returned value + // may be shared with other participants on different nodes and passed in to + // multi-node collective invocations. + string GenerateCommunicatorKey(); + + // A participant in a Collective. + struct Participant { + Participant(se::StreamExecutor* executor, se::Stream* tensor_stream, + const DeviceBase::AcceleratorDeviceInfo* info, + const Tensor* input, Tensor* output, int global_rank, + DoneCallback done_callback) + : executor(executor), + tensor_stream(tensor_stream), + event_mgr(info->event_mgr), + gpu_device_id(info->gpu_id), +#if TENSORFLOW_USE_ROCM + context(static_cast(info->default_context)), +#endif + input(input), + output(output), + global_rank(global_rank), + done_callback(std::move(done_callback)), + root(false) { + DCHECK(executor != nullptr); + DCHECK(event_mgr != nullptr); + DCHECK(tensor_stream != nullptr); + } + + // StreamExecutor for the device. Expected to be live for process lifetime. + se::StreamExecutor* const executor = nullptr; + + // `tensor_stream` is the stream that should be waited on to ensure + // `input`'s data is available on the GPU for the communication stream to + // access. It is also the stream that will use the produced data; + // `done_callback` is not called until the next kernel launched on `stream` + // would see the data. Owned by the caller, who must keep it live until + // `done_callback` is called. + se::Stream* const tensor_stream; + + // EventMgr which polls on executor. + // Owned by the caller, who must keep it live until `done_callback` is + // called. + EventMgr* const event_mgr; + + const int gpu_device_id; + +#if TENSORFLOW_USE_ROCM + GPUDeviceContext* const context; +#endif + + // Owned by the caller, who must keep it live until `done_callback` is + // called. Is NULL for participants that only receive data. + const Tensor* input; + + // Owned by the caller, who must keep it live until `done_callback` is + // called. Is NULL for participants that only send data. + Tensor* output; + + // Rank across all devices and all nodes. + // `global_rank` is not required for single-node collectives. + const int global_rank; + + // The callback which is called at the completion of the NCCL operation. + // When called, `output` has been set to the result of the operation. 
(note: + // the stream may not yet have been synced) + DoneCallback done_callback; + + // True if this is the root of the collective, e.g. source of broadcast. + bool root; + }; + + // Data that provides context for the collective operation, including the + // operation key, number of participants, and communicator key. + struct Context { + Context(const string& collective_key, int num_local_devices, + int num_global_devices, const string& communicator_key, + int source_rank) + : collective_key(collective_key), + num_local_devices(num_local_devices), + num_global_devices(num_global_devices), + communicator_key(communicator_key), + source_rank(source_rank) {} + + // Unique key for this collective instance + const string& collective_key; + + // Devices local to this node + int num_local_devices; + + // Devices across all nodes + int num_global_devices; + + // In order to use NCCL across nodes, the callee first has to generate a + // `communicator_key` via `GenerateCommunicatorKey()` function and share + // this with all the other nodes. Each node should pass in this + // `communicator_key` to the `NcclManager` functions. + // `communicator_key` is not required for single-node collectives and can be + // empty. + const string& communicator_key; + + // Rank of broadcast source. + int source_rank; + }; + + // Adds one participant to an all-reduce. + void AddToAllReduce(std::unique_ptr participant, + const Context& context, ncclRedOp_t reduction_op); + + // Adds one participant to an all-gather. + void AddToAllGather(std::unique_ptr participant, + const Context& context); + + // Adds one participant to a reduce-scatter. + void AddToReduceScatter(std::unique_ptr participant, + const Context& context, ncclRedOp_t reduction_op); + + // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender + // to all receivers. + void AddBroadcastSend(std::unique_ptr participant, + const Context& context); + void AddBroadcastRecv(std::unique_ptr participant, + const Context& context); + + // AddReduceSend and AddReduceRecv combine to send data from all senders + // to one receiver. + void AddReduceSend(std::unique_ptr participant, + const Context& context, ncclRedOp_t reduction_op); + void AddReduceRecv(std::unique_ptr participant, + const Context& context, ncclRedOp_t reduction_op); + + // Adds one participant to an all-to-all. + void AddToAllToAll(std::unique_ptr participant, + const Context& context); + + // Signals that the `Collective` corresponding to `key` is ready to launch + // across all nodes participating in this multi-node collective operation. + // + // This should only be called for multi-node collectives; single-node + // collectives are implicitly ready when all participants have called Add* + // function. + void SignalMultiNodeReady(const string& collective_key); + + // Aborts all collectives. After abortion, no further collectives can be + // launched with this NcclManager. + void StartAbort(const Status& s); + + // Resets a previously aborted NcclManager, making it available for future + // collectives. + void Reset(); + + private: + enum CollectiveType { + kAllReduce = 1, + kBroadcast = 2, + kReduce = 3, + kAllGather = 4, + kReduceScatter = 5, + kAllToAll = 6, + }; + struct Collective; + struct Communicator; + struct CommunicatorMember; + struct NcclStream; + + // Gets the `Communicator` object that will be used to enqueue NCCL kernels + // for `collective`, and returns it via `communicator`. + // + // This may involve creating CUDA streams and NCCL initialization. 
If a NCCL + // or CUDA error occurs in the process, this returns an INTERNAL error with + // the corresponding NCCL/CUDA error string. + Status GetCommunicator(Collective* collective, Communicator** communicator); + + // Adds a participant device to the local `Collective` instance corresponding + // to `collective_key`. Launches the `Collective` if it is ready, which it + // checks by calling `CheckReady()`. Also performs consistency and sanity + // checks before launching. + void AddParticipant(std::unique_ptr participant, + const Context& context, CollectiveType collective_type, + ncclRedOp_t reduction_op); + + // If `collective` is ready to run, removes it from the `collectives_` map and + // returns true. Otherwise returns false. + // Assumes `collective_key` corresponds to `collective`. + // + // A collective is ready to run when all local participants have called Add* + // function, and the collective is signalled globally ready via + // `SetMultiNodeReady`. + bool CheckReady(const string& collective_key, Collective* collective) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Run . This calls takes ownership of . + void RunCollective(Collective* collective); + void LoopKernelLaunches(NcclStream* stream); + + mutex mu_; + + // Maps key to collectives currently being assembled or run. + absl::flat_hash_map collectives_ TF_GUARDED_BY(mu_); + + // Maps a device to the communication streams that make up its collective. + // This is used to share the stream across different communicators that + // include the same device. + absl::flat_hash_map> + device_to_comm_streams_ TF_GUARDED_BY(mu_); + + std::vector> communicators_ TF_GUARDED_BY(mu_); + + Status status_ TF_GUARDED_BY(mu_); + + NcclManager(const NcclManager&) = delete; + void operator=(const NcclManager&) = delete; +}; + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_NCCL_NCCL_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/ops/compat/op_compatibility_lib.h b/third_party/tflite-hdrs/tensorflow/core/ops/compat/op_compatibility_lib.h new file mode 100644 index 00000000..776a6039 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/ops/compat/op_compatibility_lib.h @@ -0,0 +1,86 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_OPS_COMPAT_OP_COMPATIBILITY_LIB_H_ +#define TENSORFLOW_CORE_OPS_COMPAT_OP_COMPATIBILITY_LIB_H_ + +#include + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class OpCompatibilityLib { + public: + // `ops_prefix` is a filename prefix indicating where to find the + // ops files. + // `history_version` is used to construct the ops history file name. + // `*stable_ops` has an optional list of ops that we care about. + // If stable_ops == nullptr, we use all registered ops. 
+ // Otherwise ValidateCompatible() ignores ops not in *stable_ops + // and require all ops in *stable_ops to exist. + OpCompatibilityLib(const string& ops_prefix, const string& history_version, + const std::set* stable_ops); + + // Name of the file that contains the checked-in versions of *all* + // ops, with docs. + const string& ops_file() const { return ops_file_; } + + // Name of the file that contains all versions of *stable* ops, + // without docs. Op history is in (alphabetical, oldest-first) + // order. + const string& op_history_file() const { return op_history_file_; } + + // Name of the directory that contains all versions of *stable* ops, + // without docs. Op history is one file per op, in oldest-first + // order within the file. + const string& op_history_directory() const { return op_history_directory_; } + + // Should match the contents of ops_file(). Run before calling + // ValidateCompatible(). + string OpsString() const { + string result; + google::protobuf::TextFormat::PrintToString(op_list_, &result); + return result; + } + + // Returns the number of ops in OpsString(), includes all ops, not + // just stable ops. + int num_all_ops() const { return op_list_.op_size(); } + + // pairs representing op history. + typedef std::vector> OpHistory; + + // Make sure the current version of the *stable* ops are compatible + // with the historical versions, and if out_op_history != nullptr, + // generate a new history adding all changed ops. Sets + // *changed_ops/*added_ops to the number of changed/added ops + // (ignoring doc changes). + absl::Status ValidateCompatible(Env* env, int* changed_ops, int* added_ops, + OpHistory* out_op_history); + + private: + const string ops_file_; + const string op_history_file_; + const string op_history_directory_; + const std::set* stable_ops_; + OpList op_list_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_OPS_COMPAT_OP_COMPATIBILITY_LIB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/abi.h b/third_party/tflite-hdrs/tensorflow/core/platform/abi.h new file mode 100644 index 00000000..8191011a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/abi.h @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ABI_H_ +#define TENSORFLOW_CORE_PLATFORM_ABI_H_ + +#include "tsl/platform/abi.h" + +namespace tensorflow { +namespace port { + +using ::tsl::port::MaybeAbiDemangle; // NOLINT(misc-unused-using-decls) + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_ABI_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/base64.h b/third_party/tflite-hdrs/tensorflow/core/platform/base64.h new file mode 100644 index 00000000..126455fc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/base64.h @@ -0,0 +1,32 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_BASE64_H_ +#define TENSORFLOW_CORE_PLATFORM_BASE64_H_ + +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tsl/platform/base64.h" + +namespace tensorflow { + +using tsl::Base64Decode; // NOLINT +using tsl::Base64Encode; // NOLINT + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_BASE64_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/bfloat16.h b/third_party/tflite-hdrs/tensorflow/core/platform/bfloat16.h new file mode 100644 index 00000000..d6091aa2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/bfloat16.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_BFLOAT16_H_ +#define TENSORFLOW_CORE_PLATFORM_BFLOAT16_H_ + +// clang-format off +#include "tensorflow/core/platform/byte_order.h" +#include "tsl/platform/bfloat16.h" +// clang-format on + +namespace tensorflow { +typedef tsl::bfloat16 bfloat16; +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_BFLOAT16_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/blocking_counter.h b/third_party/tflite-hdrs/tensorflow/core/platform/blocking_counter.h new file mode 100644 index 00000000..4e629804 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/blocking_counter.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_ +#define TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_ + +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" +#include "tsl/platform/blocking_counter.h" + +namespace tensorflow { +using tsl::BlockingCounter; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_BLOCKING_COUNTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/byte_order.h b/third_party/tflite-hdrs/tensorflow/core/platform/byte_order.h new file mode 100644 index 00000000..f6e1d172 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/byte_order.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_BYTE_ORDER_H_ +#define TENSORFLOW_CORE_PLATFORM_BYTE_ORDER_H_ + +#include "tsl/platform/byte_order.h" + +namespace tensorflow { +namespace port { + +constexpr bool kLittleEndian = tsl::port::kLittleEndian; + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_BYTE_ORDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/casts.h b/third_party/tflite-hdrs/tensorflow/core/platform/casts.h new file mode 100644 index 00000000..791ac095 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/casts.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CASTS_H_ +#define TENSORFLOW_CORE_PLATFORM_CASTS_H_ + +#include "tsl/platform/casts.h" + +#endif // TENSORFLOW_CORE_PLATFORM_CASTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/auth_provider.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/auth_provider.h new file mode 100644 index 00000000..987cc39f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/auth_provider.h @@ -0,0 +1,32 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_ + +#include + +#include "xla/tsl/platform/cloud/auth_provider.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::AuthProvider; +using tsl::EmptyAuthProvider; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/compute_engine_metadata_client.h new file mode 100644 index 00000000..4c83d28a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/compute_engine_metadata_client.h @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ + +#include "xla/tsl/platform/cloud/compute_engine_metadata_client.h" +#include "tensorflow/core/platform/cloud/http_request.h" +#include "tensorflow/core/platform/retrying_utils.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +using tsl::ComputeEngineMetadataClient; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/compute_engine_zone_provider.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/compute_engine_zone_provider.h new file mode 100644 index 00000000..6b416481 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/compute_engine_zone_provider.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_
+
+#include "xla/tsl/platform/cloud/compute_engine_zone_provider.h"
+#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
+#include "tensorflow/core/platform/cloud/zone_provider.h"
+
+namespace tensorflow {
+using tsl::ComputeEngineZoneProvider;  // NOLINT(misc-unused-using-decls)
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/curl_http_request.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/curl_http_request.h
new file mode 100644
index 00000000..385091ff
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/curl_http_request.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_CURL_HTTP_REQUEST_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_CURL_HTTP_REQUEST_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <curl/curl.h>
+#include "xla/tsl/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/cloud/http_request.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+// NOLINTBEGIN(misc-unused-using-decls)
+using tsl::CurlHttpRequest;
+using tsl::LibCurl;
+// NOLINTEND(misc-unused-using-decls)
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_CLOUD_CURL_HTTP_REQUEST_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/expiring_lru_cache.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/expiring_lru_cache.h
new file mode 100644
index 00000000..03af7ee7
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/expiring_lru_cache.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_EXPIRING_LRU_CACHE_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_EXPIRING_LRU_CACHE_H_ + +#include +#include +#include +#include + +#include "xla/tsl/platform/cloud/expiring_lru_cache.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::ExpiringLRUCache; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_EXPIRING_LRU_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/file_block_cache.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/file_block_cache.h new file mode 100644 index 00000000..4c907437 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/file_block_cache.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_FILE_BLOCK_CACHE_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_FILE_BLOCK_CACHE_H_ + +#include +#include +#include +#include +#include +#include + +#include "xla/tsl/platform/cloud/file_block_cache.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::FileBlockCache; +using tsl::FileBlockCacheStatsInterface; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_FILE_BLOCK_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_dns_cache.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_dns_cache.h new file mode 100644 index 00000000..813bcd0e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_dns_cache.h @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
+
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "tensorflow/core/platform/cloud/http_request.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+const int64_t kDefaultRefreshRateSecs = 60;
+
+// DnsCache is a userspace DNS cache specialized for the GCS filesystem.
+//
+// Some environments have unreliable DNS resolvers. DnsCache ameliorates the
+// situation by radically reducing the number of DNS requests by performing
+// 2 DNS queries per minute (by default) on a background thread. Updated cache
+// entries are used to override curl's DNS resolution processes.
+class GcsDnsCache {
+ public:
+  // Default no-argument constructor.
+  GcsDnsCache() : GcsDnsCache(kDefaultRefreshRateSecs) {}
+
+  // Constructs a GcsDnsCache with the specified refresh rate.
+  GcsDnsCache(int64_t refresh_rate_secs)
+      : GcsDnsCache(Env::Default(), refresh_rate_secs) {}
+
+  GcsDnsCache(Env* env, int64_t refresh_rate_secs);
+
+  ~GcsDnsCache() {
+    mutex_lock l(mu_);
+    cancelled_ = true;
+    cond_var_.notify_one();
+  }
+
+  // Annotate the given HttpRequest with resolve overrides from the cache.
+  void AnnotateRequest(HttpRequest* request);
+
+ private:
+  static std::vector<string> ResolveName(const string& name);
+  static std::vector<std::vector<string>> ResolveNames(
+      const std::vector<string>& names);
+  void WorkerThread();
+
+  // Define a friend class for testing.
+  friend class GcsDnsCacheTest;
+
+  mutex mu_;
+  Env* env_;
+  condition_variable cond_var_;
+  std::default_random_engine random_ TF_GUARDED_BY(mu_);
+  bool started_ TF_GUARDED_BY(mu_) = false;
+  bool cancelled_ TF_GUARDED_BY(mu_) = false;
+  std::unique_ptr<Thread> worker_ TF_GUARDED_BY(mu_);  // After mutable vars.
+  const int64_t refresh_rate_secs_;
+
+  // Entries in this vector correspond to entries in kCachedDomainNames.
+  std::vector<std::vector<string>> addresses_ TF_GUARDED_BY(mu_);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_file_system.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_file_system.h
new file mode 100644
index 00000000..5545d2b2
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -0,0 +1,53 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_ + +#include +#include +#include +#include + +#include "xla/tsl/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/auth_provider.h" +#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" +#include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h" +#include "tensorflow/core/platform/cloud/expiring_lru_cache.h" +#include "tensorflow/core/platform/cloud/file_block_cache.h" +#include "tensorflow/core/platform/cloud/gcs_dns_cache.h" +#include "tensorflow/core/platform/cloud/gcs_throttle.h" +#include "tensorflow/core/platform/cloud/http_request.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/retrying_file_system.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::GcsFileSystem; +using tsl::GcsStatsInterface; +using tsl::GetEnvVar; +using tsl::kBlockSize; +using tsl::kDefaultBlockSize; +using tsl::kDefaultMaxCacheSize; +using tsl::kDefaultMaxStaleness; +using tsl::kMaxCacheSize; +using tsl::kMaxStaleness; +using tsl::RetryingGcsFileSystem; +using tsl::UploadSessionHandle; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_throttle.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_throttle.h new file mode 100644 index 00000000..e4a33a38 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/gcs_throttle.h @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ + +#include "xla/tsl/platform/cloud/gcs_throttle.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::GcsThrottle; +using tsl::GcsThrottleConfig; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_THROTTLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/google_auth_provider.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/google_auth_provider.h new file mode 100644 index 00000000..afefb308 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/google_auth_provider.h @@ -0,0 +1,32 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_ + +#include + +#include "xla/tsl/platform/cloud/google_auth_provider.h" +#include "tensorflow/core/platform/cloud/auth_provider.h" +#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +using tsl::GoogleAuthProvider; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/http_request.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/http_request.h new file mode 100644 index 00000000..aae023b5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/http_request.h @@ -0,0 +1,36 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_ + +#include +#include +#include + +#include "xla/tsl/platform/cloud/http_request.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::HttpRequest; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/http_request_fake.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/http_request_fake.h new file mode 100644 index 00000000..de1177ec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/http_request_fake.h @@ -0,0 +1,43 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ + +#include +#include +#include +#include + +#include +#include "xla/tsl/platform/cloud/http_request_fake.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::FakeHttpRequest; +using tsl::FakeHttpRequestFactory; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/now_seconds_env.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/now_seconds_env.h new file mode 100644 index 00000000..395e563c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/now_seconds_env.h @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_NOW_SECONDS_ENV_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_NOW_SECONDS_ENV_H_ + +#include "xla/tsl/platform/cloud/now_seconds_env.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::NowSecondsEnv; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_NOW_SECONDS_ENV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/oauth_client.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/oauth_client.h new file mode 100644 index 00000000..ca390c9f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/oauth_client.h @@ -0,0 +1,31 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ + +#include + +#include "json/json.h" +#include "xla/tsl/platform/cloud/oauth_client.h" +#include "tensorflow/core/platform/cloud/http_request.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +using tsl::OAuthClient; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/ram_file_block_cache.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/ram_file_block_cache.h new file mode 100644 index 00000000..d4de2b42 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/ram_file_block_cache.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_RAM_FILE_BLOCK_CACHE_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_RAM_FILE_BLOCK_CACHE_H_ + +#include +#include +#include +#include +#include +#include + +#include "xla/tsl/platform/cloud/ram_file_block_cache.h" +#include "tensorflow/core/platform/cloud/file_block_cache.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::RamFileBlockCache; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_RAM_FILE_BLOCK_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/time_util.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/time_util.h new file mode 100644 index 00000000..7110d13c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/time_util.h @@ -0,0 +1,26 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ + +#include "xla/tsl/platform/cloud/time_util.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +using tsl::ParseRfc3339Time; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_TIME_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cloud/zone_provider.h b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/zone_provider.h new file mode 100644 index 00000000..07ef0609 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cloud/zone_provider.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ + +#include + +#include "xla/tsl/platform/cloud/zone_provider.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +using tsl::ZoneProvider; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/coding.h b/third_party/tflite-hdrs/tensorflow/core/platform/coding.h new file mode 100644 index 00000000..091d7544 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/coding.h @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Endian-neutral encoding: +// * Fixed-length numbers are encoded with least-significant byte first +// * In addition we support variable length "varint" encoding +// * Strings are encoded prefixed by their length in varint format + +#ifndef TENSORFLOW_CORE_PLATFORM_CODING_H_ +#define TENSORFLOW_CORE_PLATFORM_CODING_H_ + +#include "tensorflow/core/platform/raw_coding.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/coding.h" + +namespace tensorflow { +namespace core { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::core::EncodeFixed16; +using tsl::core::EncodeFixed32; +using tsl::core::EncodeFixed64; +using tsl::core::EncodeVarint32; +using tsl::core::EncodeVarint64; +using tsl::core::GetVarint32; +using tsl::core::GetVarint32Ptr; +using tsl::core::GetVarint32PtrFallback; +using tsl::core::GetVarint64; +using tsl::core::GetVarint64Ptr; +using tsl::core::kMaxVarint32Bytes; +using tsl::core::kMaxVarint64Bytes; +using tsl::core::PutFixed16; +using tsl::core::PutFixed32; +using tsl::core::PutFixed64; +using tsl::core::PutVarint32; +using tsl::core::PutVarint64; +using tsl::core::VarintLength; +// NOLINTEND(misc-unused-using-decls) +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CODING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/context.h b/third_party/tflite-hdrs/tensorflow/core/platform/context.h new file mode 100644 index 00000000..f93b5695 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/context.h @@ -0,0 +1,32 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CONTEXT_H_ +#define TENSORFLOW_CORE_PLATFORM_CONTEXT_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tsl/platform/context.h" + +namespace tensorflow { + +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::Context; +using tsl::ContextKind; +using tsl::WithContext; +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cord.h b/third_party/tflite-hdrs/tensorflow/core/platform/cord.h new file mode 100644 index 00000000..fa7d2a5d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cord.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_ +#define TENSORFLOW_CORE_PLATFORM_CORD_H_ + +#include "tsl/platform/cord.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_PLATFORM_CORD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cpu_feature_guard.h b/third_party/tflite-hdrs/tensorflow/core/platform/cpu_feature_guard.h new file mode 100644 index 00000000..3d7bfe95 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cpu_feature_guard.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_ +#define TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_ + +namespace tensorflow { +namespace port { + +// Called by the framework when we expect heavy CPU computation and we want to +// be sure that the code has been compiled to run optimally on the current +// hardware. The first time it's called it will run lightweight checks of +// available SIMD acceleration features and log warnings about any that aren't +// used. +void InfoAboutUnusedCPUFeatures(); + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cpu_info.h b/third_party/tflite-hdrs/tensorflow/core/platform/cpu_info.h new file mode 100644 index 00000000..8e0b101b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cpu_info.h @@ -0,0 +1,94 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_ +#define TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_ + +#include + +// TODO(ahentz): This is not strictly required here but, for historical +// reasons, many people depend on cpu_info.h in order to use kLittleEndian. 
+#include "tensorflow/core/platform/byte_order.h" +#include "tsl/platform/cpu_info.h" + +namespace tensorflow { +namespace port { +using tsl::port::Aarch64CPU; +using tsl::port::ADX; +using tsl::port::AES; +using tsl::port::AMX_BF16; +using tsl::port::AMX_FP16; +using tsl::port::AMX_INT8; +using tsl::port::AMX_TILE; +using tsl::port::AVX; +using tsl::port::AVX2; +using tsl::port::AVX512_4FMAPS; +using tsl::port::AVX512_4VNNIW; +using tsl::port::AVX512_BF16; +using tsl::port::AVX512_FP16; +using tsl::port::AVX512_VNNI; +using tsl::port::AVX512BW; +using tsl::port::AVX512CD; +using tsl::port::AVX512DQ; +using tsl::port::AVX512ER; +using tsl::port::AVX512F; +using tsl::port::AVX512IFMA; +using tsl::port::AVX512PF; +using tsl::port::AVX512VBMI; +using tsl::port::AVX512VL; +using tsl::port::AVX_NE_CONVERT; +using tsl::port::AVX_VNNI; +using tsl::port::AVX_VNNI_INT8; +using tsl::port::BMI1; +using tsl::port::BMI2; +using tsl::port::CMOV; +using tsl::port::CMPXCHG16B; +using tsl::port::CMPXCHG8B; +using tsl::port::CPUFamily; +using tsl::port::CPUFeature; +using tsl::port::CPUIDNumSMT; +using tsl::port::CPUModelNum; +using tsl::port::CPUVendorIDString; +using tsl::port::F16C; +using tsl::port::FMA; +using tsl::port::GetCurrentCPU; +using tsl::port::HYPERVISOR; +using tsl::port::kUnknownCPU; +using tsl::port::MaxParallelism; +using tsl::port::MMX; +using tsl::port::NominalCPUFrequency; +using tsl::port::NumHyperthreadsPerCore; +using tsl::port::NumSchedulableCPUs; +using tsl::port::NumTotalCPUs; +using tsl::port::PCLMULQDQ; +using tsl::port::POPCNT; +using tsl::port::PREFETCHW; +using tsl::port::PREFETCHWT1; +using tsl::port::RDRAND; +using tsl::port::RDSEED; +using tsl::port::SMAP; +using tsl::port::SSE; +using tsl::port::SSE2; +using tsl::port::SSE3; +using tsl::port::SSE4_1; +using tsl::port::SSE4_2; +using tsl::port::SSSE3; +using tsl::port::TestAarch64CPU; +using tsl::port::TestCPUFeature; + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/crash_analysis.h b/third_party/tflite-hdrs/tensorflow/core/platform/crash_analysis.h new file mode 100644 index 00000000..c4555ee9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/crash_analysis.h @@ -0,0 +1,22 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CRASH_ANALYSIS_H_ +#define TENSORFLOW_CORE_PLATFORM_CRASH_ANALYSIS_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tsl/platform/crash_analysis.h" + +#endif // TENSORFLOW_CORE_PLATFORM_CRASH_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/ctstring.h b/third_party/tflite-hdrs/tensorflow/core/platform/ctstring.h new file mode 100644 index 00000000..3b9359d4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/ctstring.h @@ -0,0 +1,21 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CTSTRING_H_ +#define TENSORFLOW_CORE_PLATFORM_CTSTRING_H_ + +#include "tsl/platform/ctstring.h" + +#endif // TENSORFLOW_CORE_PLATFORM_CTSTRING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/ctstring_internal.h b/third_party/tflite-hdrs/tensorflow/core/platform/ctstring_internal.h new file mode 100644 index 00000000..c087dbca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/ctstring_internal.h @@ -0,0 +1,21 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CTSTRING_INTERNAL_H_ +#define TENSORFLOW_CORE_PLATFORM_CTSTRING_INTERNAL_H_ + +#include "tsl/platform/ctstring_internal.h" + +#endif // TENSORFLOW_CORE_PLATFORM_CTSTRING_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/cuda.h b/third_party/tflite-hdrs/tensorflow/core/platform/cuda.h new file mode 100644 index 00000000..d032f23a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/cuda.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CUDA_H_ +#define TENSORFLOW_CORE_PLATFORM_CUDA_H_ + +#include "tensorflow/core/platform/platform.h" // IWYU pragma: keep + +#endif // TENSORFLOW_CORE_PLATFORM_CUDA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/demangle.h b/third_party/tflite-hdrs/tensorflow/core/platform/demangle.h new file mode 100644 index 00000000..fd569122 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/demangle.h @@ -0,0 +1,28 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ +#define TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ + +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/demangle.h" + +namespace tensorflow { +namespace port { +using tsl::port::Demangle; +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_DEMANGLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/denormal.h b/third_party/tflite-hdrs/tensorflow/core/platform/denormal.h new file mode 100644 index 00000000..47dcf75c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/denormal.h @@ -0,0 +1,35 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_DENORMAL_H_ +#define TENSORFLOW_CORE_PLATFORM_DENORMAL_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tsl/platform/denormal.h" + +namespace tensorflow { +namespace port { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::port::DenormalState; +using tsl::port::GetDenormalState; +using tsl::port::ScopedDontFlushDenormal; +using tsl::port::ScopedFlushDenormal; +using tsl::port::ScopedRestoreFlushDenormalState; +using tsl::port::SetDenormalState; +// NOLINTEND(misc-unused-using-decls) +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_DENORMAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/dynamic_annotations.h b/third_party/tflite-hdrs/tensorflow/core/platform/dynamic_annotations.h new file mode 100644 index 00000000..795c978f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/dynamic_annotations.h @@ -0,0 +1,22 @@ +/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_ +#define TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tsl/platform/dynamic_annotations.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/enable_tf2_utils.h b/third_party/tflite-hdrs/tensorflow/core/platform/enable_tf2_utils.h new file mode 100644 index 00000000..856ee1f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/enable_tf2_utils.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ENABLE_TF2_UTILS_H_ +#define TENSORFLOW_CORE_PLATFORM_ENABLE_TF2_UTILS_H_ + +namespace tensorflow { + +// Sets the tf2 execution state. This can be used to indicate whether the user +// has explicitly asked for tf2 execution. +void set_tf2_execution(bool enabled); + +// Returns true or false depending on whether the user flag for tf2 execution +// has been set. The default is false. +bool tf2_execution_enabled(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_ENABLE_TF2_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/env.h b/third_party/tflite-hdrs/tensorflow/core/platform/env.h new file mode 100644 index 00000000..c88c758a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/env.h @@ -0,0 +1,61 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_ENV_H_
+#define TENSORFLOW_CORE_PLATFORM_ENV_H_
+
+#include <stdint.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/file_system.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/numa.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+#include "tsl/platform/env.h"
+
+namespace tensorflow {
+// NOLINTBEGIN(misc-unused-using-decls)
+using tsl::Env;
+using tsl::EnvWrapper;
+using tsl::FileSystemCopyFile;
+using tsl::ReadBinaryProto;
+using tsl::ReadFileToString;
+using tsl::ReadTextOrBinaryProto;
+using tsl::ReadTextProto;
+using tsl::setenv;
+using tsl::Thread;
+using tsl::ThreadOptions;
+using tsl::unsetenv;
+using tsl::WriteBinaryProto;
+using tsl::WriteStringToFile;
+using tsl::WriteTextProto;
+namespace register_file_system {
+using tsl::register_file_system::Register;
+}  // namespace register_file_system
+// NOLINTEND(misc-unused-using-decls)
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_ENV_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/env_time.h b/third_party/tflite-hdrs/tensorflow/core/platform/env_time.h
new file mode 100644
index 00000000..b2831965
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/platform/env_time.h
@@ -0,0 +1,27 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PLATFORM_ENV_TIME_H_
+#define TENSORFLOW_CORE_PLATFORM_ENV_TIME_H_
+
+#include <stdint.h>
+
+#include "tensorflow/core/platform/types.h"
+#include "tsl/platform/env_time.h"
+
+namespace tensorflow {
+using tsl::EnvTime;  // NOLINT
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_ENV_TIME_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/error_logging.h b/third_party/tflite-hdrs/tensorflow/core/platform/error_logging.h
new file mode 100644
index 00000000..378a0cb6
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/platform/error_logging.h
@@ -0,0 +1,25 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ERROR_LOGGING_H_ +#define TENSORFLOW_CORE_PLATFORM_ERROR_LOGGING_H_ + +#include "tsl/platform/error_logging.h" + +namespace tensorflow { +using tsl::error_logging::Log; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_ERROR_LOGGING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/error_payloads.h b/third_party/tflite-hdrs/tensorflow/core/platform/error_payloads.h new file mode 100644 index 00000000..7f1d8b61 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/error_payloads.h @@ -0,0 +1,50 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ERROR_PAYLOADS_H_ +#define TENSORFLOW_CORE_PLATFORM_ERROR_PAYLOADS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/protobuf/core_platform_payloads.pb.h" +// This file contains macros and payload keys for the error counter in +// EagerClient. + +namespace tsl { + +// Proto: tensorflow::core::platform::ErrorSourceProto +// Location: tensorflow/core/protobuf/core_platform_payloads.proto +// Usage: Payload key for recording the error raised source. Payload value is +// retrieved to update counter in +// tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc. +constexpr char kErrorSource[] = + "type.googleapis.com/tensorflow.core.platform.ErrorSourceProto"; + +// Set payload when status is not ok and ErrorSource payload hasn't been set. +// The code below will be used at every place where we would like to catch +// the error for the error counter in EagerClient. + +void OkOrSetErrorCounterPayload( + const tensorflow::core::platform::ErrorSourceProto::ErrorSource& + error_source, + absl::Status& status); +} // namespace tsl + +namespace tensorflow { +using tsl::kErrorSource; // NOLINT +using tsl::OkOrSetErrorCounterPayload; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_ERROR_PAYLOADS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/errors.h b/third_party/tflite-hdrs/tensorflow/core/platform/errors.h new file mode 100644 index 00000000..343edd91 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/errors.h @@ -0,0 +1,107 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ERRORS_H_ +#define TENSORFLOW_CORE_PLATFORM_ERRORS_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/str_util.h" +#include "tensorflow/core/platform/strcat.h" +#include "tsl/platform/errors.h" + +namespace tensorflow { +namespace errors { + +// NOLINTBEGIN(misc-unused-using-decls) +// Maps UNIX errors into a Status. +using error::OK; +using tsl::errors::Aborted; +using tsl::errors::AbortedWithPayloads; +using tsl::errors::AlreadyExists; +using tsl::errors::AlreadyExistsWithPayloads; +using tsl::errors::AppendToMessage; +using tsl::errors::Cancelled; +using tsl::errors::CancelledWithPayloads; +using tsl::errors::CopyPayloads; +using tsl::errors::Create; +using tsl::errors::CreateWithUpdatedMessage; +using tsl::errors::DataLoss; +using tsl::errors::DataLossWithPayloads; +using tsl::errors::DeadlineExceeded; +using tsl::errors::DeadlineExceededWithPayloads; +using tsl::errors::FailedPrecondition; +using tsl::errors::FailedPreconditionWithPayloads; +using tsl::errors::FormatColocationNodeForError; +using tsl::errors::FormatFunctionForError; +using tsl::errors::FormatNodeNameForError; +using tsl::errors::FormatNodeNamesForError; +using tsl::errors::FormatOriginalNodeLocationForError; +using tsl::errors::GetPayloads; +using tsl::errors::InsertPayloads; +using tsl::errors::Internal; +using tsl::errors::InternalWithPayloads; +using tsl::errors::InvalidArgument; +using tsl::errors::InvalidArgumentWithPayloads; +using tsl::errors::IOError; +using tsl::errors::IsAborted; +using tsl::errors::IsAlreadyExists; +using tsl::errors::IsCancelled; +using tsl::errors::IsDataLoss; +using tsl::errors::IsDeadlineExceeded; +using tsl::errors::IsFailedPrecondition; +using tsl::errors::IsInternal; +using tsl::errors::IsInvalidArgument; +using tsl::errors::IsNotFound; +using tsl::errors::IsOutOfRange; +using tsl::errors::IsPermissionDenied; +using tsl::errors::IsResourceExhausted; +using tsl::errors::IsUnauthenticated; +using tsl::errors::IsUnavailable; +using tsl::errors::IsUnimplemented; +using tsl::errors::IsUnknown; +using tsl::errors::NotFound; +using tsl::errors::NotFoundWithPayloads; +using tsl::errors::OutOfRange; +using tsl::errors::OutOfRangeWithPayloads; +using tsl::errors::PermissionDenied; +using tsl::errors::PermissionDeniedWithPayloads; +using tsl::errors::ReplaceErrorFromNonCommunicationOps; +using tsl::errors::ResourceExhausted; +using tsl::errors::ResourceExhaustedWithPayloads; +using tsl::errors::Unauthenticated; +using tsl::errors::UnauthenticatedWithPayloads; +using tsl::errors::Unavailable; +using tsl::errors::UnavailableWithPayloads; +using tsl::errors::Unimplemented; +using tsl::errors::UnimplementedWithPayloads; +using tsl::errors::Unknown; +using tsl::errors::UnknownPayloads; +namespace internal { +using tsl::errors::internal::PrepareForStrCat; +} +// 
NOLINTEND(misc-unused-using-decls) + +} // namespace errors +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_ERRORS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/file_statistics.h b/third_party/tflite-hdrs/tensorflow/core/platform/file_statistics.h new file mode 100644 index 00000000..b9059288 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/file_statistics.h @@ -0,0 +1,26 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ +#define TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ + +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/file_statistics.h" + +namespace tensorflow { +using tsl::FileStatistics; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_FILE_STATISTICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/file_system.h b/third_party/tflite-hdrs/tensorflow/core/platform/file_system.h new file mode 100644 index 00000000..14826a90 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/file_system.h @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_H_ + +#include + +#include +#include +#include +#include +#include + +#include "tensorflow/core/platform/cord.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/file_statistics.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/file_system.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::FileSystem; +using tsl::FileSystemRegistry; +using tsl::RandomAccessFile; +using tsl::ReadOnlyMemoryRegion; +using tsl::TransactionToken; +using tsl::WrappedFileSystem; +using tsl::WritableFile; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/file_system_helper.h b/third_party/tflite-hdrs/tensorflow/core/platform/file_system_helper.h new file mode 100644 index 00000000..01b3a92d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/file_system_helper.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_ +#define TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_ + +#include +#include + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tsl/platform/file_system_helper.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::Env; +using tsl::FileSystem; + +namespace internal { +using tsl::internal::FileExists; +using tsl::internal::GetMatchingPaths; +} // namespace internal +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_FILE_SYSTEM_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/fingerprint.h b/third_party/tflite-hdrs/tensorflow/core/platform/fingerprint.h new file mode 100644 index 00000000..d209799c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/fingerprint.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_FINGERPRINT_H_ +#define TENSORFLOW_CORE_PLATFORM_FINGERPRINT_H_ + +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/fingerprint.h" + +namespace tensorflow { + +using Fprint128 = tsl::Fprint128; +using Fprint128Hasher = tsl::Fprint128Hasher; + +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::Fingerprint128; +using tsl::Fingerprint32; +using tsl::Fingerprint64; +using tsl::FingerprintCat64; +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_FINGERPRINT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/float8.h b/third_party/tflite-hdrs/tensorflow/core/platform/float8.h new file mode 100644 index 00000000..e2cad449 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/float8.h @@ -0,0 +1,26 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_FLOAT8_H_ +#define TENSORFLOW_CORE_PLATFORM_FLOAT8_H_ + +#include "tsl/platform/ml_dtypes.h" + +namespace tensorflow { +typedef tsl::float8_e4m3fn float8_e4m3fn; +typedef tsl::float8_e5m2 float8_e5m2; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_FLOAT8_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/gif.h b/third_party/tflite-hdrs/tensorflow/core/platform/gif.h new file mode 100644 index 00000000..79af3822 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/gif.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_GIF_H_ +#define TENSORFLOW_CORE_PLATFORM_GIF_H_ + +#include "gif_lib.h" // from @gif + +#endif // TENSORFLOW_CORE_PLATFORM_GIF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/hash.h b/third_party/tflite-hdrs/tensorflow/core/platform/hash.h new file mode 100644 index 00000000..85364243 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/hash.h @@ -0,0 +1,35 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Simple hash functions used for internal data structures + +#ifndef TENSORFLOW_CORE_PLATFORM_HASH_H_ +#define TENSORFLOW_CORE_PLATFORM_HASH_H_ + +#include "tsl/platform/hash.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::hash; +using ::tsl::Hash32; +using ::tsl::Hash64; +using ::tsl::Hash64Combine; +using ::tsl::Hash64CombineUnordered; +using ::tsl::StringPieceHasher; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + + +#endif // TENSORFLOW_CORE_PLATFORM_HASH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/host_info.h b/third_party/tflite-hdrs/tensorflow/core/platform/host_info.h new file mode 100644 index 00000000..caab7ae3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/host_info.h @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_ +#define TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_ + +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/host_info.h" + +namespace tensorflow { +namespace port { +using tsl::port::Hostname; +using tsl::port::IOStatistics; +using tsl::port::JobName; +using tsl::port::JobUid; +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/human_readable_json.h b/third_party/tflite-hdrs/tensorflow/core/platform/human_readable_json.h new file mode 100644 index 00000000..73cc5165 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/human_readable_json.h @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ +#define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ + +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/human_readable_json.h" + +namespace tensorflow { +using tsl::HumanReadableJsonToProto; +using tsl::ProtoToHumanReadableJson; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/init_main.h b/third_party/tflite-hdrs/tensorflow/core/platform/init_main.h new file mode 100644 index 00000000..07b0620e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/init_main.h @@ -0,0 +1,27 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_ +#define TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_ + +#include "tsl/platform/init_main.h" + +namespace tensorflow { +namespace port { +using tsl::port::InitMain; +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/intrusive_ptr.h b/third_party/tflite-hdrs/tensorflow/core/platform/intrusive_ptr.h new file mode 100644 index 00000000..b46bf5d8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/intrusive_ptr.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PLATFORM_INTRUSIVE_PTR_H_ +#define TENSORFLOW_CORE_PLATFORM_INTRUSIVE_PTR_H_ + +#include + +#include "tsl/platform/intrusive_ptr.h" + +namespace tensorflow { +namespace core { + +template +using IntrusivePtr = tsl::core::IntrusivePtr; + +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_INTRUSIVE_PTR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/jpeg.h b/third_party/tflite-hdrs/tensorflow/core/platform/jpeg.h new file mode 100644 index 00000000..68dadd18 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/jpeg.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_JPEG_H_ +#define TENSORFLOW_CORE_PLATFORM_JPEG_H_ + +#include +#include +#include +#include + +extern "C" { +#include "jerror.h" // from @libjpeg_turbo // IWYU pragma: export +#include "jpeglib.h" // from @libjpeg_turbo // IWYU pragma: export +} + +#endif // TENSORFLOW_CORE_PLATFORM_JPEG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/load_library.h b/third_party/tflite-hdrs/tensorflow/core/platform/load_library.h new file mode 100644 index 00000000..6bb4a416 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/load_library.h @@ -0,0 +1,33 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_ +#define TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_ + +#include "tsl/platform/load_library.h" + +namespace tensorflow { + +namespace internal { + +using ::tsl::internal::FormatLibraryFileName; +using ::tsl::internal::GetSymbolFromLibrary; +using ::tsl::internal::LoadDynamicLibrary; + +} // namespace internal + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/logging.h b/third_party/tflite-hdrs/tensorflow/core/platform/logging.h new file mode 100644 index 00000000..0a5b0205 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/logging.h @@ -0,0 +1,36 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+#define TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+
+#include "tensorflow/core/platform/types.h"  // IWYU pragma: export
+#include "tsl/platform/logging.h"  // IWYU pragma: export
+
+// NOLINTBEGIN(misc-unused-using-decls)
+namespace tensorflow {
+namespace internal {
+using tsl::internal::LogString;
+}  // namespace internal
+using tsl::TFAddLogSink;
+using tsl::TFGetLogSinks;
+using tsl::TFLogEntry;
+using tsl::TFLogSink;
+using tsl::TFRemoveLogSink;
+using tsl::UpdateLogVerbosityIfDefined;
+}  // namespace tensorflow
+// NOLINTEND(misc-unused-using-decls)
+
+#endif  // TENSORFLOW_CORE_PLATFORM_LOGGING_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/macros.h b/third_party/tflite-hdrs/tensorflow/core/platform/macros.h
new file mode 100644
index 00000000..975f1c59
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/platform/macros.h
@@ -0,0 +1,29 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_MACROS_H_
+#define TENSORFLOW_CORE_PLATFORM_MACROS_H_
+
+#include "tsl/platform/macros.h"  // IWYU pragma: export
+
+namespace tensorflow {
+namespace internal {
+template <typename T>
+constexpr auto remove_unused_variable_compiler_warning =
+    tsl::internal::remove_unused_variable_compiler_warning<T>;
+}  // namespace internal
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_MACROS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/mem.h b/third_party/tflite-hdrs/tensorflow/core/platform/mem.h
new file mode 100644
index 00000000..20acf859
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/platform/mem.h
@@ -0,0 +1,42 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_MEM_H_
+#define TENSORFLOW_CORE_PLATFORM_MEM_H_
+
+#include "tsl/platform/mem.h"
+// TODO(cwhipkey): remove this when callers use annotations directly.
+#include "tensorflow/core/platform/dynamic_annotations.h" + +namespace tensorflow { +namespace port { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::port::AlignedFree; +using ::tsl::port::AlignedMalloc; +using ::tsl::port::AvailableRam; +using ::tsl::port::Free; +using ::tsl::port::GetMemoryBandwidthInfo; +using ::tsl::port::GetMemoryInfo; +using ::tsl::port::Malloc; +using ::tsl::port::MallocExtension_GetAllocatedSize; +using ::tsl::port::MallocExtension_ReleaseToSystem; +using ::tsl::port::MemoryBandwidthInfo; +using ::tsl::port::MemoryInfo; +using ::tsl::port::Realloc; +// NOLINTEND(misc-unused-using-decls) +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_MEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/mutex.h b/third_party/tflite-hdrs/tensorflow/core/platform/mutex.h new file mode 100644 index 00000000..4a8d76c4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/mutex.h @@ -0,0 +1,39 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_MUTEX_H_ +#define TENSORFLOW_CORE_PLATFORM_MUTEX_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/mutex.h" + +namespace tensorflow { + +using tsl::Condition; +using tsl::condition_variable; +using tsl::ConditionResult; +using tsl::kCond_MaybeNotified; +using tsl::kCond_Timeout; +using tsl::LINKER_INITIALIZED; +using tsl::LinkerInitialized; +using tsl::mutex; +using tsl::mutex_lock; +using tsl::tf_shared_lock; +using tsl::WaitForMilliseconds; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_MUTEX_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/net.h b/third_party/tflite-hdrs/tensorflow/core/platform/net.h new file mode 100644 index 00000000..4b9d51fc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/net.h @@ -0,0 +1,27 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_NET_H_ +#define TENSORFLOW_CORE_PLATFORM_NET_H_ + +#include "tsl/platform/net.h" + +namespace tensorflow { +namespace internal { +using ::tsl::internal::PickUnusedPortOrDie; // NOLINT(misc-unused-using-decls) +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_NET_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/notification.h b/third_party/tflite-hdrs/tensorflow/core/platform/notification.h new file mode 100644 index 00000000..a2d48a63 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/notification.h @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_NOTIFICATION_H_ +#define TENSORFLOW_CORE_PLATFORM_NOTIFICATION_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tsl/platform/notification.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::Notification; +using tsl::WaitForNotificationWithTimeout; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_NOTIFICATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/null_file_system.h b/third_party/tflite-hdrs/tensorflow/core/platform/null_file_system.h new file mode 100644 index 00000000..3fc7d179 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/null_file_system.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_ + +#include "tsl/platform/null_file_system.h" + +namespace tensorflow { +#ifndef SWIG +using ::tsl::NullFileSystem; // NOLINT(misc-unused-using-decls) +#endif + +// END_SKIP_DOXYGEN +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_NULL_FILE_SYSTEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/numa.h b/third_party/tflite-hdrs/tensorflow/core/platform/numa.h new file mode 100644 index 00000000..6333c01f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/numa.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_NUMA_H_ +#define TENSORFLOW_CORE_PLATFORM_NUMA_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/numa.h" + +namespace tensorflow { +namespace port { +using tsl::port::kNUMANoAffinity; +using tsl::port::NUMAEnabled; +using tsl::port::NUMAFree; +using tsl::port::NUMAGetMemAffinity; +using tsl::port::NUMAGetThreadNodeAffinity; +using tsl::port::NUMAMalloc; +using tsl::port::NUMANumNodes; +using tsl::port::NUMASetThreadNodeAffinity; +} // namespace port +} // namespace tensorflow +#endif // TENSORFLOW_CORE_PLATFORM_NUMA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/numbers.h b/third_party/tflite-hdrs/tensorflow/core/platform/numbers.h new file mode 100644 index 00000000..3164aab4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/numbers.h @@ -0,0 +1,52 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_NUMBERS_H_ +#define TENSORFLOW_CORE_PLATFORM_NUMBERS_H_ + +#include + +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/numbers.h" + +namespace tensorflow { +namespace strings { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::strings::DoubleToBuffer; +using tsl::strings::FastInt32ToBufferLeft; +using tsl::strings::FastInt64ToBufferLeft; +using tsl::strings::FastUInt32ToBufferLeft; +using tsl::strings::FastUInt64ToBufferLeft; +using tsl::strings::FloatToBuffer; +using tsl::strings::FpToString; +using tsl::strings::HexStringToUint64; +using tsl::strings::HumanReadableElapsedTime; +using tsl::strings::HumanReadableNum; +using tsl::strings::HumanReadableNumBytes; +using tsl::strings::kFastToBufferSize; +using tsl::strings::ProtoParseNumeric; +using tsl::strings::safe_strto32; +using tsl::strings::safe_strto64; +using tsl::strings::safe_strtod; +using tsl::strings::safe_strtof; +using tsl::strings::safe_strtou32; +using tsl::strings::safe_strtou64; +using tsl::strings::SafeStringToNumeric; +// NOLINTEND(misc-unused-using-decls) +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_NUMBERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/path.h b/third_party/tflite-hdrs/tensorflow/core/platform/path.h new file mode 100644 index 00000000..ca13a99f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/path.h @@ -0,0 +1,47 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PATH_H_ +#define TENSORFLOW_CORE_PLATFORM_PATH_H_ + +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/path.h" + +// NOLINTBEGIN(misc-unused-using-decls) +namespace tensorflow { +namespace io { +namespace internal { +using tsl::io::internal::JoinPathImpl; +} +#ifndef SWIG // variadic templates +using tsl::io::JoinPath; +#endif /* SWIG */ +using tsl::io::Basename; +using tsl::io::BasenamePrefix; +using tsl::io::CleanPath; +using tsl::io::CommonPathPrefix; +using tsl::io::CreateURI; +using tsl::io::Dirname; +using tsl::io::Extension; +using tsl::io::GetTempFilename; +using tsl::io::GetTestUndeclaredOutputsDir; +using tsl::io::IsAbsolutePath; +using tsl::io::ParseURI; +} // namespace io +} // namespace tensorflow +// NOLINTEND(misc-unused-using-decls) + +#endif // TENSORFLOW_CORE_PLATFORM_PATH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/platform.h b/third_party/tflite-hdrs/tensorflow/core/platform/platform.h new file mode 100644 index 00000000..6d5d9879 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/platform.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PLATFORM_H_ +#define TENSORFLOW_CORE_PLATFORM_PLATFORM_H_ + +#include "tsl/platform/platform.h" + +#endif // TENSORFLOW_CORE_PLATFORM_PLATFORM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/platform_strings.h b/third_party/tflite-hdrs/tensorflow/core/platform/platform_strings.h new file mode 100644 index 00000000..a42f7c76 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/platform_strings.h @@ -0,0 +1,362 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_H_ +#define TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_H_ + +// This header defines the macro TF_PLATFORM_STRINGS() which should be used +// once in each dynamically loadable TensorFlow module. It embeds static +// strings into the compilation unit that allow TensorFlow to determine what +// compilation options were in effect when the compilation unit was built. All +// compilation units within the same dynamically loadable library should be +// built with the same options (or at least, the strings should be embedded in +// the compilation unit built with the most restrictive options). + +// The platform strings embedded into a binary may be retrieved with the +// GetPlatformStrings function. + +// Rationale: +// We wish to load only those libraries that this CPU can execute. For +// example, we should not load a library compiled with avx256 instructions on a +// CPU that cannot execute them. +// +// One might think that one could dlopen() the library, and call a routine that +// would return which cpu type it was compiled for. Alas, this does not work, +// because at dlopen() time, a library containing C++ will execute constructors +// of class variables with static storage class. Even code that looks +// innocuous may use optional platform-specific instructions. For example, +// the fastest way to zero a region of memory might use optional instructions. +// +// One might think one could run a tool such as "objdump" to read flags from +// the libraries' headers, or perhaps disassemble each library to look for +// particular instructions. Unfortunately, the desired flags are not present +// in the headers, and disassembly can be prohibitively slow ("objdump -d" is +// very slow, for example). 
Moreover, a tool to examine the library may not
+// be present on the system unless the user has installed special packages (for
+// example, on Windows).
+//
+// Instead, we adopt a crude but straightforward solution: We require
+// developers to use the macro TF_PLATFORM_STRINGS() in their library, to
+// embed the compilation options as constant strings. The compiler's
+// predefined macros pick which strings are included. We then search for the
+// strings in the files, and then dlopen() only those libraries that have or
+// lack strings as needed.
+//
+// We adopt the approach of placing in the binary a fairly raw copy of the
+// predefined macros, rather than trying to interpret them in complex ways at
+// compile time. This allows the loading binary to alter its interpretation of
+// the strings without library developers having to recompile.
+
+#include <stdio.h>
+
+#include <string>
+#include <vector>
+
+// Aside from the header guard, the internal macros defined here have the form:
+//    TF_PLAT_STR_*
+
+// If a macro is removed from the list of tested macros, the major version in
+// the following version number should be incremented, and the minor version
+// set to zero. Otherwise, if a macro is added to the list of tested macros,
+// the minor number should be incremented.
+#define TF_PLAT_STR_VERSION_ "1.0"
+
+// Prefix of each option string indicator in the binary.
+// After the prefix, such strings have the form:
+//    [A-Za-z_0-9]=<value>
+// followed by a terminating nul. To simplify searching, this prefix is all
+// ASCII, starts with a nul, and contains no character twice.
+#define TF_PLAT_STR_MAGIC_PREFIX_ "\0S\\s\":^p*L}"
+
+// A helper macro for TF_PLAT_STR_AS_STR_().
+#define TF_PLAT_STR_STR_1_(x) #x
+
+// Yield a constant string corresponding to x, after macro expansion.
+#define TF_PLAT_STR_AS_STR_(x) TF_PLAT_STR_STR_1_(x)
+
+// An empty definition to make lists more uniform.
+#define TF_PLAT_STR_TERMINATOR_
+
+// TF_PLAT_STR_(x) introduces a constant string indicating whether a
+// particular compilation option has been turned on.
+//
+// In gcc and clang, we might imagine using something like
+//     #define TF_PLAT_STR_(x) \
+//         (sizeof (#x) != sizeof (TF_PLAT_STR_AS_STR_ (x))? \
+//          TF_PLAT_STR_MAGIC_PREFIX_ #x "=" TF_PLAT_STR_AS_STR_ (x) : \
+//          TF_PLAT_STR_MAGIC_PREFIX_ #x "=0"),
+// but some compilers (notably MSVC) place both "foo" and "bar" in the binary
+// when presented with
+//     (true? "foo" : "bar")
+// so we must use #if to select the strings we need, which is rather verbose.
+#define TF_PLAT_STR_(x) TF_PLAT_STR_MAGIC_PREFIX_ #x "=" TF_PLAT_STR_AS_STR_(x)
+
+// Include the #if machinery that sets the macros used below.
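+// [Editor's note, not part of the upstream header] A minimal worked example of
+// the TF_PLAT_STR_ machinery above together with the computed header included
+// below, assuming the compiler predefines __AVX__ to 1: the computed header
+// then defines
+//     #define TF_PLAT_STR___AVX__ TF_PLAT_STR_(__AVX__)
+// and TF_PLAT_STR_(__AVX__) expands to the string-literal concatenation
+//     TF_PLAT_STR_MAGIC_PREFIX_ "__AVX__" "=" "1"
+// i.e. the magic prefix followed by "__AVX__=1", which is what the loader can
+// later search for in the built library. When __AVX__ is not defined, the
+// computed header defines TF_PLAT_STR___AVX__ as empty, so no string is
+// embedded for that option.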
+// platform_strings_computed.h can be generated by filtering this header file +// through: +// awk ' +// header == "" { print; } +// /\*\// && header == "" { +// print "// Generated from platform_strings.h."; +// print ""; +// print "#ifndef TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_COMPUTED_H_"; +// print "#define TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_COMPUTED_H_"; +// print ""; +// header = 1; +// } +// /^#define TF_PLAT_STR_LIST_[a-zA-Z0-9_]*\(\) *\\$/ { active = 1; } +// /TF_PLAT_STR_TERMINATOR_/ { active = 0; } +// /^ *TF_PLAT_STR_[A-Za-z0-9_]* *\\$/ && active { +// x = $0; +// sub(/^ *TF_PLAT_STR_/, "", x); +// sub(/ *\\$/, "", x); +// printf ("#if defined(%s)\n", x); +// printf ("#define TF_PLAT_STR_%s TF_PLAT_STR_(%s)\n", x, x); +// printf ("#else\n"); +// printf ("#define TF_PLAT_STR_%s\n", x); +// printf ("#endif\n"); +// } +// END { +// print ""; +// print "#endif // TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_COMPUTED_H_"; +// }' +#include "tensorflow/core/platform/platform_strings_computed.h" + +// clang-format butchers the following lines. +// clang-format off + +// x86_64 and x86_32 optional features. +#define TF_PLAT_STR_LIST___x86_64__() \ + TF_PLAT_STR__M_IX86_FP \ + TF_PLAT_STR__NO_PREFETCHW \ + TF_PLAT_STR___3dNOW_A__ \ + TF_PLAT_STR___3dNOW__ \ + TF_PLAT_STR___ABM__ \ + TF_PLAT_STR___ADX__ \ + TF_PLAT_STR___AES__ \ + TF_PLAT_STR___AVX2__ \ + TF_PLAT_STR___AVX512BW__ \ + TF_PLAT_STR___AVX512CD__ \ + TF_PLAT_STR___AVX512DQ__ \ + TF_PLAT_STR___AVX512ER__ \ + TF_PLAT_STR___AVX512F__ \ + TF_PLAT_STR___AVX512IFMA__ \ + TF_PLAT_STR___AVX512PF__ \ + TF_PLAT_STR___AVX512VBMI__ \ + TF_PLAT_STR___AVX512VL__ \ + TF_PLAT_STR___AVX__ \ + TF_PLAT_STR___BMI2__ \ + TF_PLAT_STR___BMI__ \ + TF_PLAT_STR___CLFLUSHOPT__ \ + TF_PLAT_STR___CLZERO__ \ + TF_PLAT_STR___F16C__ \ + TF_PLAT_STR___FMA4__ \ + TF_PLAT_STR___FMA__ \ + TF_PLAT_STR___FP_FAST_FMA \ + TF_PLAT_STR___FP_FAST_FMAF \ + TF_PLAT_STR___FSGSBASE__ \ + TF_PLAT_STR___FXSR__ \ + TF_PLAT_STR___LWP__ \ + TF_PLAT_STR___LZCNT__ \ + TF_PLAT_STR___MMX__ \ + TF_PLAT_STR___MWAITX__ \ + TF_PLAT_STR___PCLMUL__ \ + TF_PLAT_STR___PKU__ \ + TF_PLAT_STR___POPCNT__ \ + TF_PLAT_STR___PRFCHW__ \ + TF_PLAT_STR___RDRND__ \ + TF_PLAT_STR___RDSEED__ \ + TF_PLAT_STR___RTM__ \ + TF_PLAT_STR___SHA__ \ + TF_PLAT_STR___SSE2_MATH__ \ + TF_PLAT_STR___SSE2__ \ + TF_PLAT_STR___SSE_MATH__ \ + TF_PLAT_STR___SSE__ \ + TF_PLAT_STR___SSE3__ \ + TF_PLAT_STR___SSE4A__ \ + TF_PLAT_STR___SSE4_1__ \ + TF_PLAT_STR___SSE4_2__ \ + TF_PLAT_STR___SSSE3__ \ + TF_PLAT_STR___TBM__ \ + TF_PLAT_STR___XOP__ \ + TF_PLAT_STR___XSAVEC__ \ + TF_PLAT_STR___XSAVEOPT__ \ + TF_PLAT_STR___XSAVES__ \ + TF_PLAT_STR___XSAVE__ \ + TF_PLAT_STR_TERMINATOR_ + +// PowerPC (64- and 32-bit) optional features. 
+#define TF_PLAT_STR_LIST___powerpc64__() \ + TF_PLAT_STR__SOFT_DOUBLE \ + TF_PLAT_STR__SOFT_FLOAT \ + TF_PLAT_STR___ALTIVEC__ \ + TF_PLAT_STR___APPLE_ALTIVEC__ \ + TF_PLAT_STR___CRYPTO__ \ + TF_PLAT_STR___FLOAT128_HARDWARE__ \ + TF_PLAT_STR___FLOAT128_TYPE__ \ + TF_PLAT_STR___FP_FAST_FMA \ + TF_PLAT_STR___FP_FAST_FMAF \ + TF_PLAT_STR___HTM__ \ + TF_PLAT_STR___NO_FPRS__ \ + TF_PLAT_STR___NO_LWSYNC__ \ + TF_PLAT_STR___POWER8_VECTOR__ \ + TF_PLAT_STR___POWER9_VECTOR__ \ + TF_PLAT_STR___PPC405__ \ + TF_PLAT_STR___QUAD_MEMORY_ATOMIC__ \ + TF_PLAT_STR___RECIPF__ \ + TF_PLAT_STR___RECIP_PRECISION__ \ + TF_PLAT_STR___RECIP__ \ + TF_PLAT_STR___RSQRTEF__ \ + TF_PLAT_STR___RSQRTE__ \ + TF_PLAT_STR___TM_FENCE__ \ + TF_PLAT_STR___UPPER_REGS_DF__ \ + TF_PLAT_STR___UPPER_REGS_SF__ \ + TF_PLAT_STR___VEC__ \ + TF_PLAT_STR___VSX__ \ + TF_PLAT_STR_TERMINATOR_ + +// aarch64 and 32-bit arm optional features +#define TF_PLAT_STR_LIST___aarch64__() \ + TF_PLAT_STR___ARM_ARCH \ + TF_PLAT_STR___ARM_FEATURE_CLZ \ + TF_PLAT_STR___ARM_FEATURE_CRC32 \ + TF_PLAT_STR___ARM_FEATURE_CRC32 \ + TF_PLAT_STR___ARM_FEATURE_CRYPTO \ + TF_PLAT_STR___ARM_FEATURE_DIRECTED_ROUNDING \ + TF_PLAT_STR___ARM_FEATURE_DSP \ + TF_PLAT_STR___ARM_FEATURE_FMA \ + TF_PLAT_STR___ARM_FEATURE_IDIV \ + TF_PLAT_STR___ARM_FEATURE_LDREX \ + TF_PLAT_STR___ARM_FEATURE_NUMERIC_MAXMIN \ + TF_PLAT_STR___ARM_FEATURE_QBIT \ + TF_PLAT_STR___ARM_FEATURE_QRDMX \ + TF_PLAT_STR___ARM_FEATURE_SAT \ + TF_PLAT_STR___ARM_FEATURE_SIMD32 \ + TF_PLAT_STR___ARM_FEATURE_UNALIGNED \ + TF_PLAT_STR___ARM_FP \ + TF_PLAT_STR___ARM_NEON_FP \ + TF_PLAT_STR___ARM_NEON__ \ + TF_PLAT_STR___ARM_WMMX \ + TF_PLAT_STR___IWMMXT2__ \ + TF_PLAT_STR___IWMMXT__ \ + TF_PLAT_STR___VFP_FP__ \ + TF_PLAT_STR_TERMINATOR_ + +// Generic features, including indication of architecture and OS. +// The _M_* macros are defined by Visual Studio. +// It doesn't define __LITTLE_ENDIAN__ or __BYTE_ORDER__; +// Windows is assumed to be little endian. 
+#define TF_PLAT_STR_LIST___generic__() \ + TF_PLAT_STR_TARGET_IPHONE_SIMULATOR \ + TF_PLAT_STR_TARGET_OS_IOS \ + TF_PLAT_STR_TARGET_OS_IPHONE \ + TF_PLAT_STR__MSC_VER \ + TF_PLAT_STR__M_ARM \ + TF_PLAT_STR__M_ARM64 \ + TF_PLAT_STR__M_ARM_ARMV7VE \ + TF_PLAT_STR__M_ARM_FP \ + TF_PLAT_STR__M_IX86 \ + TF_PLAT_STR__M_X64 \ + TF_PLAT_STR__WIN32 \ + TF_PLAT_STR__WIN64 \ + TF_PLAT_STR___ANDROID__ \ + TF_PLAT_STR___APPLE__ \ + TF_PLAT_STR___BYTE_ORDER__ \ + TF_PLAT_STR___CYGWIN__ \ + TF_PLAT_STR___FreeBSD__ \ + TF_PLAT_STR___LITTLE_ENDIAN__ \ + TF_PLAT_STR___NetBSD__ \ + TF_PLAT_STR___OpenBSD__ \ + TF_PLAT_STR_____MSYS__ \ + TF_PLAT_STR___aarch64__ \ + TF_PLAT_STR___alpha__ \ + TF_PLAT_STR___arm__ \ + TF_PLAT_STR___i386__ \ + TF_PLAT_STR___i686__ \ + TF_PLAT_STR___ia64__ \ + TF_PLAT_STR___linux__ \ + TF_PLAT_STR___mips32__ \ + TF_PLAT_STR___mips64__ \ + TF_PLAT_STR___powerpc64__ \ + TF_PLAT_STR___powerpc__ \ + TF_PLAT_STR___riscv___ \ + TF_PLAT_STR___s390x__ \ + TF_PLAT_STR___sparc64__ \ + TF_PLAT_STR___sparc__ \ + TF_PLAT_STR___x86_64__ \ + TF_PLAT_STR_TERMINATOR_ + +#if !defined(__x86_64__) && !defined(_M_X64) && \ + !defined(__i386__) && !defined(_M_IX86) +#undef TF_PLAT_STR_LIST___x86_64__ +#define TF_PLAT_STR_LIST___x86_64__() +#endif +#if !defined(__powerpc64__) && !defined(__powerpc__) +#undef TF_PLAT_STR_LIST___powerpc64__ +#define TF_PLAT_STR_LIST___powerpc64__() +#endif +#if !defined(__aarch64__) && !defined(_M_ARM64) && \ + !defined(__arm__) && !defined(_M_ARM) +#undef TF_PLAT_STR_LIST___aarch64__ +#define TF_PLAT_STR_LIST___aarch64__() +#endif + +// Macro to be used in each dynamically loadable library. +// +// The BSS global variable tf_cpu_option_global and the class +// instance tf_cpu_option_avoid_omit_class are needed to prevent +// compilers/linkers such as clang from omitting the static variable +// tf_cpu_option[], which would otherwise appear to be unused. We cannot make +// tf_cpu_option[] global, because we then might get multiply-defined symbols +// if TF_PLAT_STR() is used twice in the same library. +// (tf_cpu_option_global doesn't see such errors because it is +// defined in BSS, so multiple definitions are combined by the linker.) gcc's +// __attribute__((used)) is insufficient because it seems to be ignored by +// linkers. +#define TF_PLATFORM_STRINGS() \ + static const char tf_cpu_option[] = \ + TF_PLAT_STR_MAGIC_PREFIX_ "TF_PLAT_STR_VERSION=" TF_PLAT_STR_VERSION_ \ + TF_PLAT_STR_LIST___x86_64__() \ + TF_PLAT_STR_LIST___powerpc64__() \ + TF_PLAT_STR_LIST___aarch64__() \ + TF_PLAT_STR_LIST___generic__() \ + ; \ + const char *tf_cpu_option_global; \ + namespace { \ + class TFCPUOptionHelper { \ + public: \ + TFCPUOptionHelper() { \ + /* Compilers/linkers remove unused variables aggressively. The */ \ + /* following gyrations subvert most such optimizations. */ \ + tf_cpu_option_global = tf_cpu_option; \ + /* Nothing is printed because the string starts with a nul. */ \ + printf("%s%s", tf_cpu_option, ""); \ + } \ + } tf_cpu_option_avoid_omit_class; \ + } /* anonymous namespace */ +// clang-format on + +namespace tensorflow { + +// Retrieves the platform strings from the file at the given path and appends +// them to the given vector. If the returned int is non-zero, an error occurred +// reading the file and vector may or may not be modified. The returned error +// code is suitable for use with strerror(). 
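+//
+// [Editor's illustration, not part of the upstream header] A minimal usage
+// sketch. Assume a kernels library built from a source file that invokes the
+// macro once at namespace scope ("my_kernels.cc" and "my_kernels.so" are
+// hypothetical names):
+//
+//     // my_kernels.cc
+//     #include "tensorflow/core/platform/platform_strings.h"
+//     TF_PLATFORM_STRINGS()
+//
+// A loader can then read the embedded options back before deciding whether to
+// dlopen() the library:
+//
+//     std::vector<std::string> strings;
+//     if (tensorflow::GetPlatformStrings("my_kernels.so", &strings) != 0) {
+//       // Non-zero return: the file could not be read; see strerror().
+//     } else {
+//       for (const std::string& s : strings) {
+//         // Each entry records one predefined macro setting, e.g. "__AVX__=1";
+//         // compare against what the running CPU supports before loading.
+//       }
+//     }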
+int GetPlatformStrings(const std::string& path, + std::vector* found); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/platform_strings_computed.h b/third_party/tflite-hdrs/tensorflow/core/platform/platform_strings_computed.h new file mode 100644 index 00000000..6a17f3bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/platform_strings_computed.h @@ -0,0 +1,735 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Generated from platform_strings.h. + +#ifndef TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_COMPUTED_H_ +#define TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_COMPUTED_H_ + +#if defined(_M_IX86_FP) +#define TF_PLAT_STR__M_IX86_FP TF_PLAT_STR_(_M_IX86_FP) +#else +#define TF_PLAT_STR__M_IX86_FP +#endif +#if defined(_NO_PREFETCHW) +#define TF_PLAT_STR__NO_PREFETCHW TF_PLAT_STR_(_NO_PREFETCHW) +#else +#define TF_PLAT_STR__NO_PREFETCHW +#endif +#if defined(__3dNOW_A__) +#define TF_PLAT_STR___3dNOW_A__ TF_PLAT_STR_(__3dNOW_A__) +#else +#define TF_PLAT_STR___3dNOW_A__ +#endif +#if defined(__3dNOW__) +#define TF_PLAT_STR___3dNOW__ TF_PLAT_STR_(__3dNOW__) +#else +#define TF_PLAT_STR___3dNOW__ +#endif +#if defined(__ABM__) +#define TF_PLAT_STR___ABM__ TF_PLAT_STR_(__ABM__) +#else +#define TF_PLAT_STR___ABM__ +#endif +#if defined(__ADX__) +#define TF_PLAT_STR___ADX__ TF_PLAT_STR_(__ADX__) +#else +#define TF_PLAT_STR___ADX__ +#endif +#if defined(__AES__) +#define TF_PLAT_STR___AES__ TF_PLAT_STR_(__AES__) +#else +#define TF_PLAT_STR___AES__ +#endif +#if defined(__AVX2__) +#define TF_PLAT_STR___AVX2__ TF_PLAT_STR_(__AVX2__) +#else +#define TF_PLAT_STR___AVX2__ +#endif +#if defined(__AVX512BW__) +#define TF_PLAT_STR___AVX512BW__ TF_PLAT_STR_(__AVX512BW__) +#else +#define TF_PLAT_STR___AVX512BW__ +#endif +#if defined(__AVX512CD__) +#define TF_PLAT_STR___AVX512CD__ TF_PLAT_STR_(__AVX512CD__) +#else +#define TF_PLAT_STR___AVX512CD__ +#endif +#if defined(__AVX512DQ__) +#define TF_PLAT_STR___AVX512DQ__ TF_PLAT_STR_(__AVX512DQ__) +#else +#define TF_PLAT_STR___AVX512DQ__ +#endif +#if defined(__AVX512ER__) +#define TF_PLAT_STR___AVX512ER__ TF_PLAT_STR_(__AVX512ER__) +#else +#define TF_PLAT_STR___AVX512ER__ +#endif +#if defined(__AVX512F__) +#define TF_PLAT_STR___AVX512F__ TF_PLAT_STR_(__AVX512F__) +#else +#define TF_PLAT_STR___AVX512F__ +#endif +#if defined(__AVX512IFMA__) +#define TF_PLAT_STR___AVX512IFMA__ TF_PLAT_STR_(__AVX512IFMA__) +#else +#define TF_PLAT_STR___AVX512IFMA__ +#endif +#if defined(__AVX512PF__) +#define TF_PLAT_STR___AVX512PF__ TF_PLAT_STR_(__AVX512PF__) +#else +#define TF_PLAT_STR___AVX512PF__ +#endif +#if defined(__AVX512VBMI__) +#define TF_PLAT_STR___AVX512VBMI__ TF_PLAT_STR_(__AVX512VBMI__) +#else +#define TF_PLAT_STR___AVX512VBMI__ +#endif +#if defined(__AVX512VL__) +#define TF_PLAT_STR___AVX512VL__ TF_PLAT_STR_(__AVX512VL__) +#else +#define 
TF_PLAT_STR___AVX512VL__ +#endif +#if defined(__AVX__) +#define TF_PLAT_STR___AVX__ TF_PLAT_STR_(__AVX__) +#else +#define TF_PLAT_STR___AVX__ +#endif +#if defined(__BMI2__) +#define TF_PLAT_STR___BMI2__ TF_PLAT_STR_(__BMI2__) +#else +#define TF_PLAT_STR___BMI2__ +#endif +#if defined(__BMI__) +#define TF_PLAT_STR___BMI__ TF_PLAT_STR_(__BMI__) +#else +#define TF_PLAT_STR___BMI__ +#endif +#if defined(__CLFLUSHOPT__) +#define TF_PLAT_STR___CLFLUSHOPT__ TF_PLAT_STR_(__CLFLUSHOPT__) +#else +#define TF_PLAT_STR___CLFLUSHOPT__ +#endif +#if defined(__CLZERO__) +#define TF_PLAT_STR___CLZERO__ TF_PLAT_STR_(__CLZERO__) +#else +#define TF_PLAT_STR___CLZERO__ +#endif +#if defined(__F16C__) +#define TF_PLAT_STR___F16C__ TF_PLAT_STR_(__F16C__) +#else +#define TF_PLAT_STR___F16C__ +#endif +#if defined(__FMA4__) +#define TF_PLAT_STR___FMA4__ TF_PLAT_STR_(__FMA4__) +#else +#define TF_PLAT_STR___FMA4__ +#endif +#if defined(__FMA__) +#define TF_PLAT_STR___FMA__ TF_PLAT_STR_(__FMA__) +#else +#define TF_PLAT_STR___FMA__ +#endif +#if defined(__FP_FAST_FMA) +#define TF_PLAT_STR___FP_FAST_FMA TF_PLAT_STR_(__FP_FAST_FMA) +#else +#define TF_PLAT_STR___FP_FAST_FMA +#endif +#if defined(__FP_FAST_FMAF) +#define TF_PLAT_STR___FP_FAST_FMAF TF_PLAT_STR_(__FP_FAST_FMAF) +#else +#define TF_PLAT_STR___FP_FAST_FMAF +#endif +#if defined(__FSGSBASE__) +#define TF_PLAT_STR___FSGSBASE__ TF_PLAT_STR_(__FSGSBASE__) +#else +#define TF_PLAT_STR___FSGSBASE__ +#endif +#if defined(__FXSR__) +#define TF_PLAT_STR___FXSR__ TF_PLAT_STR_(__FXSR__) +#else +#define TF_PLAT_STR___FXSR__ +#endif +#if defined(__LWP__) +#define TF_PLAT_STR___LWP__ TF_PLAT_STR_(__LWP__) +#else +#define TF_PLAT_STR___LWP__ +#endif +#if defined(__LZCNT__) +#define TF_PLAT_STR___LZCNT__ TF_PLAT_STR_(__LZCNT__) +#else +#define TF_PLAT_STR___LZCNT__ +#endif +#if defined(__MMX__) +#define TF_PLAT_STR___MMX__ TF_PLAT_STR_(__MMX__) +#else +#define TF_PLAT_STR___MMX__ +#endif +#if defined(__MWAITX__) +#define TF_PLAT_STR___MWAITX__ TF_PLAT_STR_(__MWAITX__) +#else +#define TF_PLAT_STR___MWAITX__ +#endif +#if defined(__PCLMUL__) +#define TF_PLAT_STR___PCLMUL__ TF_PLAT_STR_(__PCLMUL__) +#else +#define TF_PLAT_STR___PCLMUL__ +#endif +#if defined(__PKU__) +#define TF_PLAT_STR___PKU__ TF_PLAT_STR_(__PKU__) +#else +#define TF_PLAT_STR___PKU__ +#endif +#if defined(__POPCNT__) +#define TF_PLAT_STR___POPCNT__ TF_PLAT_STR_(__POPCNT__) +#else +#define TF_PLAT_STR___POPCNT__ +#endif +#if defined(__PRFCHW__) +#define TF_PLAT_STR___PRFCHW__ TF_PLAT_STR_(__PRFCHW__) +#else +#define TF_PLAT_STR___PRFCHW__ +#endif +#if defined(__RDRND__) +#define TF_PLAT_STR___RDRND__ TF_PLAT_STR_(__RDRND__) +#else +#define TF_PLAT_STR___RDRND__ +#endif +#if defined(__RDSEED__) +#define TF_PLAT_STR___RDSEED__ TF_PLAT_STR_(__RDSEED__) +#else +#define TF_PLAT_STR___RDSEED__ +#endif +#if defined(__RTM__) +#define TF_PLAT_STR___RTM__ TF_PLAT_STR_(__RTM__) +#else +#define TF_PLAT_STR___RTM__ +#endif +#if defined(__SHA__) +#define TF_PLAT_STR___SHA__ TF_PLAT_STR_(__SHA__) +#else +#define TF_PLAT_STR___SHA__ +#endif +#if defined(__SSE2_MATH__) +#define TF_PLAT_STR___SSE2_MATH__ TF_PLAT_STR_(__SSE2_MATH__) +#else +#define TF_PLAT_STR___SSE2_MATH__ +#endif +#if defined(__SSE2__) +#define TF_PLAT_STR___SSE2__ TF_PLAT_STR_(__SSE2__) +#else +#define TF_PLAT_STR___SSE2__ +#endif +#if defined(__SSE_MATH__) +#define TF_PLAT_STR___SSE_MATH__ TF_PLAT_STR_(__SSE_MATH__) +#else +#define TF_PLAT_STR___SSE_MATH__ +#endif +#if defined(__SSE__) +#define TF_PLAT_STR___SSE__ TF_PLAT_STR_(__SSE__) +#else +#define TF_PLAT_STR___SSE__ 
+#endif +#if defined(__SSE3__) +#define TF_PLAT_STR___SSE3__ TF_PLAT_STR_(__SSE3__) +#else +#define TF_PLAT_STR___SSE3__ +#endif +#if defined(__SSE4A__) +#define TF_PLAT_STR___SSE4A__ TF_PLAT_STR_(__SSE4A__) +#else +#define TF_PLAT_STR___SSE4A__ +#endif +#if defined(__SSE4_1__) +#define TF_PLAT_STR___SSE4_1__ TF_PLAT_STR_(__SSE4_1__) +#else +#define TF_PLAT_STR___SSE4_1__ +#endif +#if defined(__SSE4_2__) +#define TF_PLAT_STR___SSE4_2__ TF_PLAT_STR_(__SSE4_2__) +#else +#define TF_PLAT_STR___SSE4_2__ +#endif +#if defined(__SSSE3__) +#define TF_PLAT_STR___SSSE3__ TF_PLAT_STR_(__SSSE3__) +#else +#define TF_PLAT_STR___SSSE3__ +#endif +#if defined(__TBM__) +#define TF_PLAT_STR___TBM__ TF_PLAT_STR_(__TBM__) +#else +#define TF_PLAT_STR___TBM__ +#endif +#if defined(__XOP__) +#define TF_PLAT_STR___XOP__ TF_PLAT_STR_(__XOP__) +#else +#define TF_PLAT_STR___XOP__ +#endif +#if defined(__XSAVEC__) +#define TF_PLAT_STR___XSAVEC__ TF_PLAT_STR_(__XSAVEC__) +#else +#define TF_PLAT_STR___XSAVEC__ +#endif +#if defined(__XSAVEOPT__) +#define TF_PLAT_STR___XSAVEOPT__ TF_PLAT_STR_(__XSAVEOPT__) +#else +#define TF_PLAT_STR___XSAVEOPT__ +#endif +#if defined(__XSAVES__) +#define TF_PLAT_STR___XSAVES__ TF_PLAT_STR_(__XSAVES__) +#else +#define TF_PLAT_STR___XSAVES__ +#endif +#if defined(__XSAVE__) +#define TF_PLAT_STR___XSAVE__ TF_PLAT_STR_(__XSAVE__) +#else +#define TF_PLAT_STR___XSAVE__ +#endif +#if defined(_SOFT_DOUBLE) +#define TF_PLAT_STR__SOFT_DOUBLE TF_PLAT_STR_(_SOFT_DOUBLE) +#else +#define TF_PLAT_STR__SOFT_DOUBLE +#endif +#if defined(_SOFT_FLOAT) +#define TF_PLAT_STR__SOFT_FLOAT TF_PLAT_STR_(_SOFT_FLOAT) +#else +#define TF_PLAT_STR__SOFT_FLOAT +#endif +#if defined(__ALTIVEC__) +#define TF_PLAT_STR___ALTIVEC__ TF_PLAT_STR_(__ALTIVEC__) +#else +#define TF_PLAT_STR___ALTIVEC__ +#endif +#if defined(__APPLE_ALTIVEC__) +#define TF_PLAT_STR___APPLE_ALTIVEC__ TF_PLAT_STR_(__APPLE_ALTIVEC__) +#else +#define TF_PLAT_STR___APPLE_ALTIVEC__ +#endif +#if defined(__CRYPTO__) +#define TF_PLAT_STR___CRYPTO__ TF_PLAT_STR_(__CRYPTO__) +#else +#define TF_PLAT_STR___CRYPTO__ +#endif +#if defined(__FLOAT128_HARDWARE__) +#define TF_PLAT_STR___FLOAT128_HARDWARE__ TF_PLAT_STR_(__FLOAT128_HARDWARE__) +#else +#define TF_PLAT_STR___FLOAT128_HARDWARE__ +#endif +#if defined(__FLOAT128_TYPE__) +#define TF_PLAT_STR___FLOAT128_TYPE__ TF_PLAT_STR_(__FLOAT128_TYPE__) +#else +#define TF_PLAT_STR___FLOAT128_TYPE__ +#endif +#if defined(__FP_FAST_FMA) +#define TF_PLAT_STR___FP_FAST_FMA TF_PLAT_STR_(__FP_FAST_FMA) +#else +#define TF_PLAT_STR___FP_FAST_FMA +#endif +#if defined(__FP_FAST_FMAF) +#define TF_PLAT_STR___FP_FAST_FMAF TF_PLAT_STR_(__FP_FAST_FMAF) +#else +#define TF_PLAT_STR___FP_FAST_FMAF +#endif +#if defined(__HTM__) +#define TF_PLAT_STR___HTM__ TF_PLAT_STR_(__HTM__) +#else +#define TF_PLAT_STR___HTM__ +#endif +#if defined(__NO_FPRS__) +#define TF_PLAT_STR___NO_FPRS__ TF_PLAT_STR_(__NO_FPRS__) +#else +#define TF_PLAT_STR___NO_FPRS__ +#endif +#if defined(__NO_LWSYNC__) +#define TF_PLAT_STR___NO_LWSYNC__ TF_PLAT_STR_(__NO_LWSYNC__) +#else +#define TF_PLAT_STR___NO_LWSYNC__ +#endif +#if defined(__POWER8_VECTOR__) +#define TF_PLAT_STR___POWER8_VECTOR__ TF_PLAT_STR_(__POWER8_VECTOR__) +#else +#define TF_PLAT_STR___POWER8_VECTOR__ +#endif +#if defined(__POWER9_VECTOR__) +#define TF_PLAT_STR___POWER9_VECTOR__ TF_PLAT_STR_(__POWER9_VECTOR__) +#else +#define TF_PLAT_STR___POWER9_VECTOR__ +#endif +#if defined(__PPC405__) +#define TF_PLAT_STR___PPC405__ TF_PLAT_STR_(__PPC405__) +#else +#define TF_PLAT_STR___PPC405__ +#endif +#if 
defined(__QUAD_MEMORY_ATOMIC__) +#define TF_PLAT_STR___QUAD_MEMORY_ATOMIC__ TF_PLAT_STR_(__QUAD_MEMORY_ATOMIC__) +#else +#define TF_PLAT_STR___QUAD_MEMORY_ATOMIC__ +#endif +#if defined(__RECIPF__) +#define TF_PLAT_STR___RECIPF__ TF_PLAT_STR_(__RECIPF__) +#else +#define TF_PLAT_STR___RECIPF__ +#endif +#if defined(__RECIP_PRECISION__) +#define TF_PLAT_STR___RECIP_PRECISION__ TF_PLAT_STR_(__RECIP_PRECISION__) +#else +#define TF_PLAT_STR___RECIP_PRECISION__ +#endif +#if defined(__RECIP__) +#define TF_PLAT_STR___RECIP__ TF_PLAT_STR_(__RECIP__) +#else +#define TF_PLAT_STR___RECIP__ +#endif +#if defined(__RSQRTEF__) +#define TF_PLAT_STR___RSQRTEF__ TF_PLAT_STR_(__RSQRTEF__) +#else +#define TF_PLAT_STR___RSQRTEF__ +#endif +#if defined(__RSQRTE__) +#define TF_PLAT_STR___RSQRTE__ TF_PLAT_STR_(__RSQRTE__) +#else +#define TF_PLAT_STR___RSQRTE__ +#endif +#if defined(__TM_FENCE__) +#define TF_PLAT_STR___TM_FENCE__ TF_PLAT_STR_(__TM_FENCE__) +#else +#define TF_PLAT_STR___TM_FENCE__ +#endif +#if defined(__UPPER_REGS_DF__) +#define TF_PLAT_STR___UPPER_REGS_DF__ TF_PLAT_STR_(__UPPER_REGS_DF__) +#else +#define TF_PLAT_STR___UPPER_REGS_DF__ +#endif +#if defined(__UPPER_REGS_SF__) +#define TF_PLAT_STR___UPPER_REGS_SF__ TF_PLAT_STR_(__UPPER_REGS_SF__) +#else +#define TF_PLAT_STR___UPPER_REGS_SF__ +#endif +#if defined(__VEC__) +#define TF_PLAT_STR___VEC__ TF_PLAT_STR_(__VEC__) +#else +#define TF_PLAT_STR___VEC__ +#endif +#if defined(__VSX__) +#define TF_PLAT_STR___VSX__ TF_PLAT_STR_(__VSX__) +#else +#define TF_PLAT_STR___VSX__ +#endif +#if defined(__ARM_ARCH) +#define TF_PLAT_STR___ARM_ARCH TF_PLAT_STR_(__ARM_ARCH) +#else +#define TF_PLAT_STR___ARM_ARCH +#endif +#if defined(__ARM_FEATURE_CLZ) +#define TF_PLAT_STR___ARM_FEATURE_CLZ TF_PLAT_STR_(__ARM_FEATURE_CLZ) +#else +#define TF_PLAT_STR___ARM_FEATURE_CLZ +#endif +#if defined(__ARM_FEATURE_CRC32) +#define TF_PLAT_STR___ARM_FEATURE_CRC32 TF_PLAT_STR_(__ARM_FEATURE_CRC32) +#else +#define TF_PLAT_STR___ARM_FEATURE_CRC32 +#endif +#if defined(__ARM_FEATURE_CRC32) +#define TF_PLAT_STR___ARM_FEATURE_CRC32 TF_PLAT_STR_(__ARM_FEATURE_CRC32) +#else +#define TF_PLAT_STR___ARM_FEATURE_CRC32 +#endif +#if defined(__ARM_FEATURE_CRYPTO) +#define TF_PLAT_STR___ARM_FEATURE_CRYPTO TF_PLAT_STR_(__ARM_FEATURE_CRYPTO) +#else +#define TF_PLAT_STR___ARM_FEATURE_CRYPTO +#endif +#if defined(__ARM_FEATURE_DIRECTED_ROUNDING) +#define TF_PLAT_STR___ARM_FEATURE_DIRECTED_ROUNDING \ + TF_PLAT_STR_(__ARM_FEATURE_DIRECTED_ROUNDING) +#else +#define TF_PLAT_STR___ARM_FEATURE_DIRECTED_ROUNDING +#endif +#if defined(__ARM_FEATURE_DSP) +#define TF_PLAT_STR___ARM_FEATURE_DSP TF_PLAT_STR_(__ARM_FEATURE_DSP) +#else +#define TF_PLAT_STR___ARM_FEATURE_DSP +#endif +#if defined(__ARM_FEATURE_FMA) +#define TF_PLAT_STR___ARM_FEATURE_FMA TF_PLAT_STR_(__ARM_FEATURE_FMA) +#else +#define TF_PLAT_STR___ARM_FEATURE_FMA +#endif +#if defined(__ARM_FEATURE_IDIV) +#define TF_PLAT_STR___ARM_FEATURE_IDIV TF_PLAT_STR_(__ARM_FEATURE_IDIV) +#else +#define TF_PLAT_STR___ARM_FEATURE_IDIV +#endif +#if defined(__ARM_FEATURE_LDREX) +#define TF_PLAT_STR___ARM_FEATURE_LDREX TF_PLAT_STR_(__ARM_FEATURE_LDREX) +#else +#define TF_PLAT_STR___ARM_FEATURE_LDREX +#endif +#if defined(__ARM_FEATURE_NUMERIC_MAXMIN) +#define TF_PLAT_STR___ARM_FEATURE_NUMERIC_MAXMIN \ + TF_PLAT_STR_(__ARM_FEATURE_NUMERIC_MAXMIN) +#else +#define TF_PLAT_STR___ARM_FEATURE_NUMERIC_MAXMIN +#endif +#if defined(__ARM_FEATURE_QBIT) +#define TF_PLAT_STR___ARM_FEATURE_QBIT TF_PLAT_STR_(__ARM_FEATURE_QBIT) +#else +#define TF_PLAT_STR___ARM_FEATURE_QBIT +#endif +#if 
defined(__ARM_FEATURE_QRDMX) +#define TF_PLAT_STR___ARM_FEATURE_QRDMX TF_PLAT_STR_(__ARM_FEATURE_QRDMX) +#else +#define TF_PLAT_STR___ARM_FEATURE_QRDMX +#endif +#if defined(__ARM_FEATURE_SAT) +#define TF_PLAT_STR___ARM_FEATURE_SAT TF_PLAT_STR_(__ARM_FEATURE_SAT) +#else +#define TF_PLAT_STR___ARM_FEATURE_SAT +#endif +#if defined(__ARM_FEATURE_SIMD32) +#define TF_PLAT_STR___ARM_FEATURE_SIMD32 TF_PLAT_STR_(__ARM_FEATURE_SIMD32) +#else +#define TF_PLAT_STR___ARM_FEATURE_SIMD32 +#endif +#if defined(__ARM_FEATURE_UNALIGNED) +#define TF_PLAT_STR___ARM_FEATURE_UNALIGNED \ + TF_PLAT_STR_(__ARM_FEATURE_UNALIGNED) +#else +#define TF_PLAT_STR___ARM_FEATURE_UNALIGNED +#endif +#if defined(__ARM_FP) +#define TF_PLAT_STR___ARM_FP TF_PLAT_STR_(__ARM_FP) +#else +#define TF_PLAT_STR___ARM_FP +#endif +#if defined(__ARM_NEON_FP) +#define TF_PLAT_STR___ARM_NEON_FP TF_PLAT_STR_(__ARM_NEON_FP) +#else +#define TF_PLAT_STR___ARM_NEON_FP +#endif +#if defined(__ARM_NEON__) +#define TF_PLAT_STR___ARM_NEON__ TF_PLAT_STR_(__ARM_NEON__) +#else +#define TF_PLAT_STR___ARM_NEON__ +#endif +#if defined(__ARM_WMMX) +#define TF_PLAT_STR___ARM_WMMX TF_PLAT_STR_(__ARM_WMMX) +#else +#define TF_PLAT_STR___ARM_WMMX +#endif +#if defined(__IWMMXT2__) +#define TF_PLAT_STR___IWMMXT2__ TF_PLAT_STR_(__IWMMXT2__) +#else +#define TF_PLAT_STR___IWMMXT2__ +#endif +#if defined(__IWMMXT__) +#define TF_PLAT_STR___IWMMXT__ TF_PLAT_STR_(__IWMMXT__) +#else +#define TF_PLAT_STR___IWMMXT__ +#endif +#if defined(__VFP_FP__) +#define TF_PLAT_STR___VFP_FP__ TF_PLAT_STR_(__VFP_FP__) +#else +#define TF_PLAT_STR___VFP_FP__ +#endif +#if defined(TARGET_IPHONE_SIMULATOR) +#define TF_PLAT_STR_TARGET_IPHONE_SIMULATOR \ + TF_PLAT_STR_(TARGET_IPHONE_SIMULATOR) +#else +#define TF_PLAT_STR_TARGET_IPHONE_SIMULATOR +#endif +#if defined(TARGET_OS_IOS) +#define TF_PLAT_STR_TARGET_OS_IOS TF_PLAT_STR_(TARGET_OS_IOS) +#else +#define TF_PLAT_STR_TARGET_OS_IOS +#endif +#if defined(TARGET_OS_IPHONE) +#define TF_PLAT_STR_TARGET_OS_IPHONE TF_PLAT_STR_(TARGET_OS_IPHONE) +#else +#define TF_PLAT_STR_TARGET_OS_IPHONE +#endif +#if defined(_MSC_VER) +#define TF_PLAT_STR__MSC_VER TF_PLAT_STR_(_MSC_VER) +#else +#define TF_PLAT_STR__MSC_VER +#endif +#if defined(_M_ARM) +#define TF_PLAT_STR__M_ARM TF_PLAT_STR_(_M_ARM) +#else +#define TF_PLAT_STR__M_ARM +#endif +#if defined(_M_ARM64) +#define TF_PLAT_STR__M_ARM64 TF_PLAT_STR_(_M_ARM64) +#else +#define TF_PLAT_STR__M_ARM64 +#endif +#if defined(_M_ARM_ARMV7VE) +#define TF_PLAT_STR__M_ARM_ARMV7VE TF_PLAT_STR_(_M_ARM_ARMV7VE) +#else +#define TF_PLAT_STR__M_ARM_ARMV7VE +#endif +#if defined(_M_ARM_FP) +#define TF_PLAT_STR__M_ARM_FP TF_PLAT_STR_(_M_ARM_FP) +#else +#define TF_PLAT_STR__M_ARM_FP +#endif +#if defined(_M_IX86) +#define TF_PLAT_STR__M_IX86 TF_PLAT_STR_(_M_IX86) +#else +#define TF_PLAT_STR__M_IX86 +#endif +#if defined(_M_X64) +#define TF_PLAT_STR__M_X64 TF_PLAT_STR_(_M_X64) +#else +#define TF_PLAT_STR__M_X64 +#endif +#if defined(_WIN32) +#define TF_PLAT_STR__WIN32 TF_PLAT_STR_(_WIN32) +#else +#define TF_PLAT_STR__WIN32 +#endif +#if defined(_WIN64) +#define TF_PLAT_STR__WIN64 TF_PLAT_STR_(_WIN64) +#else +#define TF_PLAT_STR__WIN64 +#endif +#if defined(__ANDROID__) +#define TF_PLAT_STR___ANDROID__ TF_PLAT_STR_(__ANDROID__) +#else +#define TF_PLAT_STR___ANDROID__ +#endif +#if defined(__APPLE__) +#define TF_PLAT_STR___APPLE__ TF_PLAT_STR_(__APPLE__) +#else +#define TF_PLAT_STR___APPLE__ +#endif +#if defined(__BYTE_ORDER__) +#define TF_PLAT_STR___BYTE_ORDER__ TF_PLAT_STR_(__BYTE_ORDER__) +#else +#define TF_PLAT_STR___BYTE_ORDER__ +#endif 
+#if defined(__CYGWIN__) +#define TF_PLAT_STR___CYGWIN__ TF_PLAT_STR_(__CYGWIN__) +#else +#define TF_PLAT_STR___CYGWIN__ +#endif +#if defined(__FreeBSD__) +#define TF_PLAT_STR___FreeBSD__ TF_PLAT_STR_(__FreeBSD__) +#else +#define TF_PLAT_STR___FreeBSD__ +#endif +#if defined(__LITTLE_ENDIAN__) +#define TF_PLAT_STR___LITTLE_ENDIAN__ TF_PLAT_STR_(__LITTLE_ENDIAN__) +#else +#define TF_PLAT_STR___LITTLE_ENDIAN__ +#endif +#if defined(__NetBSD__) +#define TF_PLAT_STR___NetBSD__ TF_PLAT_STR_(__NetBSD__) +#else +#define TF_PLAT_STR___NetBSD__ +#endif +#if defined(__OpenBSD__) +#define TF_PLAT_STR___OpenBSD__ TF_PLAT_STR_(__OpenBSD__) +#else +#define TF_PLAT_STR___OpenBSD__ +#endif +#if defined(____MSYS__) +#define TF_PLAT_STR_____MSYS__ TF_PLAT_STR_(____MSYS__) +#else +#define TF_PLAT_STR_____MSYS__ +#endif +#if defined(__aarch64__) +#define TF_PLAT_STR___aarch64__ TF_PLAT_STR_(__aarch64__) +#else +#define TF_PLAT_STR___aarch64__ +#endif +#if defined(__alpha__) +#define TF_PLAT_STR___alpha__ TF_PLAT_STR_(__alpha__) +#else +#define TF_PLAT_STR___alpha__ +#endif +#if defined(__arm__) +#define TF_PLAT_STR___arm__ TF_PLAT_STR_(__arm__) +#else +#define TF_PLAT_STR___arm__ +#endif +#if defined(__i386__) +#define TF_PLAT_STR___i386__ TF_PLAT_STR_(__i386__) +#else +#define TF_PLAT_STR___i386__ +#endif +#if defined(__i686__) +#define TF_PLAT_STR___i686__ TF_PLAT_STR_(__i686__) +#else +#define TF_PLAT_STR___i686__ +#endif +#if defined(__ia64__) +#define TF_PLAT_STR___ia64__ TF_PLAT_STR_(__ia64__) +#else +#define TF_PLAT_STR___ia64__ +#endif +#if defined(__linux__) +#define TF_PLAT_STR___linux__ TF_PLAT_STR_(__linux__) +#else +#define TF_PLAT_STR___linux__ +#endif +#if defined(__mips32__) +#define TF_PLAT_STR___mips32__ TF_PLAT_STR_(__mips32__) +#else +#define TF_PLAT_STR___mips32__ +#endif +#if defined(__mips64__) +#define TF_PLAT_STR___mips64__ TF_PLAT_STR_(__mips64__) +#else +#define TF_PLAT_STR___mips64__ +#endif +#if defined(__powerpc64__) +#define TF_PLAT_STR___powerpc64__ TF_PLAT_STR_(__powerpc64__) +#else +#define TF_PLAT_STR___powerpc64__ +#endif +#if defined(__powerpc__) +#define TF_PLAT_STR___powerpc__ TF_PLAT_STR_(__powerpc__) +#else +#define TF_PLAT_STR___powerpc__ +#endif +#if defined(__riscv___) +#define TF_PLAT_STR___riscv___ TF_PLAT_STR_(__riscv___) +#else +#define TF_PLAT_STR___riscv___ +#endif +#if defined(__s390x__) +#define TF_PLAT_STR___s390x__ TF_PLAT_STR_(__s390x__) +#else +#define TF_PLAT_STR___s390x__ +#endif +#if defined(__sparc64__) +#define TF_PLAT_STR___sparc64__ TF_PLAT_STR_(__sparc64__) +#else +#define TF_PLAT_STR___sparc64__ +#endif +#if defined(__sparc__) +#define TF_PLAT_STR___sparc__ TF_PLAT_STR_(__sparc__) +#else +#define TF_PLAT_STR___sparc__ +#endif +#if defined(__x86_64__) +#define TF_PLAT_STR___x86_64__ TF_PLAT_STR_(__x86_64__) +#else +#define TF_PLAT_STR___x86_64__ +#endif + +#endif // TENSORFLOW_CORE_PLATFORM_PLATFORM_STRINGS_COMPUTED_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/png.h b/third_party/tflite-hdrs/tensorflow/core/platform/png.h new file mode 100644 index 00000000..fc1a3421 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/png.h @@ -0,0 +1,30 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PNG_H_ +#define TENSORFLOW_CORE_PLATFORM_PNG_H_ + +#include "tensorflow/core/platform/platform.h" + +#if defined(PLATFORM_GOOGLE) && !defined(IS_MOBILE_PLATFORM) +#include "png.h" // from @png // IWYU pragma: export +#elif defined(PLATFORM_POSIX) || defined(PLATFORM_WINDOWS) || \ + defined(PLATFORM_POSIX_ANDROID) || defined(IS_MOBILE_PLATFORM) +#include <png.h> // IWYU pragma: export +#else +#error Define the appropriate PLATFORM_ macro for this platform +#endif + +#endif // TENSORFLOW_CORE_PLATFORM_PNG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/prefetch.h b/third_party/tflite-hdrs/tensorflow/core/platform/prefetch.h new file mode 100644 index 00000000..019493f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/prefetch.h @@ -0,0 +1,32 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PREFETCH_H_ +#define TENSORFLOW_CORE_PLATFORM_PREFETCH_H_ + +#include "tsl/platform/prefetch.h" + +namespace tensorflow { +namespace port { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::port::prefetch; +using ::tsl::port::PREFETCH_HINT_NTA; +using ::tsl::port::PREFETCH_HINT_T0; +using ::tsl::port::PrefetchHint; +// NOLINTEND(misc-unused-using-decls) +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PREFETCH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h new file mode 100644 index 00000000..610f507c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h @@ -0,0 +1,40 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_ +#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_ + +#include + +#include "xla/tsl/platform/profile_utils/android_armv7a_cpu_utils_helper.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h" +#include "tensorflow/core/platform/types.h" + +#if defined(__ANDROID__) && (__ANDROID_API__ >= 21) && \ + (defined(__ARM_ARCH_7A__) || defined(__aarch64__)) + +struct perf_event_attr; + +namespace tensorflow { +namespace profile_utils { +using tsl::profile_utils::AndroidArmV7ACpuUtilsHelper; // NOLINT +} // namespace profile_utils +} // namespace tensorflow + +#endif // defined(__ANDROID__) && (__ANDROID_API__ >= 21) && + // (defined(__ARM_ARCH_7A__) || defined(__aarch64__)) + +#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h new file mode 100644 index 00000000..da58a612 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_ +#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_ + +#include + +#include "xla/tsl/platform/profile_utils/clock_cycle_profiler.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/profile_utils/cpu_utils.h" + +namespace tensorflow { +using tsl::ClockCycleProfiler; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/cpu_utils.h b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/cpu_utils.h new file mode 100644 index 00000000..fde59166 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/cpu_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// This class is designed to get accurate profile for programs. + +#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_ +#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_ + +#include +#include + +#include "xla/tsl/platform/profile_utils/cpu_utils.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace profile_utils { +using tsl::profile_utils::CpuUtils; // NOLINT +} // namespace profile_utils + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h new file mode 100644 index 00000000..f9357c6c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_ +#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_ + +#include "xla/tsl/platform/profile_utils/i_cpu_utils_helper.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace profile_utils { +using tsl::profile_utils::ICpuUtilsHelper; // NOLINT +} // namespace profile_utils +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/protobuf.h b/third_party/tflite-hdrs/tensorflow/core/platform/protobuf.h new file mode 100644 index 00000000..d7dda8b3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/protobuf.h @@ -0,0 +1,39 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_ +#define TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/protobuf.h" + +namespace tensorflow { +namespace protobuf = tsl::protobuf; // NOLINT(misc-unused-alias-decls) +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::kProtobufInt64Typename; +using tsl::kProtobufUint64Typename; +using tsl::ParseFromTString; +using tsl::ParseProtoUnlimited; +using tsl::protobuf_int64; +using tsl::protobuf_uint64; +using tsl::ProtobufStringToString; +using tsl::SerializeToTString; +using tsl::SetProtobufStringSwapAllowed; +using tsl::TStringOutputStream; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/protobuf_internal.h b/third_party/tflite-hdrs/tensorflow/core/platform/protobuf_internal.h new file mode 100644 index 00000000..b766b42b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/protobuf_internal.h @@ -0,0 +1,45 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_ +#define TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_ + +#include "google/protobuf/any.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Utility for parsing an Any value with full or lite protos. +template <typename T> +absl::Status ParseAny(const google::protobuf::Any& any, T* message, + const string& type_name) { + CHECK_EQ(type_name, message->GetTypeName()); + if (!any.Is<T>()) { + return errors::FailedPrecondition( + "Expected Any type_url for: ", message->GetTypeName(), + ". Got: ", string(any.type_url().data(), any.type_url().size()), "."); + } + if (!any.UnpackTo(message)) { + return errors::FailedPrecondition("Failed to unpack: ", any.DebugString()); + } + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/ram_file_system.h b/third_party/tflite-hdrs/tensorflow/core/platform/ram_file_system.h new file mode 100644 index 00000000..2043737b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/ram_file_system.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_ + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/ram_file_system.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::RamFileSystem; +using tsl::RamRandomAccessFile; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RAM_FILE_SYSTEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/random.h b/third_party/tflite-hdrs/tensorflow/core/platform/random.h new file mode 100644 index 00000000..ceb54e4a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/random.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RANDOM_H_ +#define TENSORFLOW_CORE_PLATFORM_RANDOM_H_ + +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/random.h" + +namespace tensorflow { +namespace random { +using tsl::random::New64; // NOLINT +using tsl::random::New64DefaultSeed; // NOLINT +} // namespace random +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RANDOM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/raw_coding.h b/third_party/tflite-hdrs/tensorflow/core/platform/raw_coding.h new file mode 100644 index 00000000..9b3c31d6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/raw_coding.h @@ -0,0 +1,33 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RAW_CODING_H_ +#define TENSORFLOW_CORE_PLATFORM_RAW_CODING_H_ + +#include + +#include "tsl/platform/raw_coding.h" + +namespace tensorflow { +namespace core { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::core::DecodeFixed16; +using ::tsl::core::DecodeFixed32; +using ::tsl::core::DecodeFixed64; +// NOLINTEND(misc-unused-using-decls) +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RAW_CODING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/refcount.h b/third_party/tflite-hdrs/tensorflow/core/platform/refcount.h new file mode 100644 index 00000000..9d8b21b7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/refcount.h @@ -0,0 +1,36 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_ +#define TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_ + +#include "tensorflow/core/platform/mutex.h" +#include "tsl/platform/refcount.h" + +namespace tensorflow { +namespace core { +// NOLINTBEGIN(misc-unused-using-decls) +using ::tsl::core::RefCountDeleter; +using ::tsl::core::RefCounted; +using ::tsl::core::RefCountPtr; +using ::tsl::core::ScopedUnref; +using ::tsl::core::WeakNotifyFn; +using ::tsl::core::WeakPtr; +using ::tsl::core::WeakRefCounted; +// NOLINTEND(misc-unused-using-decls) +} // namespace core +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_REFCOUNT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/regexp.h b/third_party/tflite-hdrs/tensorflow/core/platform/regexp.h new file mode 100644 index 00000000..0c2025ad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/regexp.h @@ -0,0 +1,20 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_REGEXP_H_ +#define TENSORFLOW_CORE_PLATFORM_REGEXP_H_ +#include "tsl/platform/regexp.h" + +#endif // TENSORFLOW_CORE_PLATFORM_REGEXP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/resource.h b/third_party/tflite-hdrs/tensorflow/core/platform/resource.h new file mode 100644 index 00000000..1088b388 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/resource.h @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RESOURCE_H_ +#define TENSORFLOW_CORE_PLATFORM_RESOURCE_H_ + +#include + +#include "tsl/platform/resource.h" + +namespace tensorflow { + +using ::tsl::ResourceTagger; // NOLINT(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RESOURCE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/resource_loader.h b/third_party/tflite-hdrs/tensorflow/core/platform/resource_loader.h new file mode 100644 index 00000000..e4d6d56e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/resource_loader.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Small helper library to access "data" dependencies defined in BUILD files. +// Requires the relative paths starting from tensorflow/... +// For example, to get this file, a user would call: +// GetDataDependencyFilepath("tensorflow/core/platform/resource_loadder.h") + +#ifndef TENSORFLOW_CORE_PLATFORM_RESOURCE_LOADER_H_ +#define TENSORFLOW_CORE_PLATFORM_RESOURCE_LOADER_H_ + +#include "tsl/platform/resource_loader.h" + +namespace tensorflow { + +using tsl::GetDataDependencyFilepath; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RESOURCE_LOADER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/retrying_file_system.h b/third_party/tflite-hdrs/tensorflow/core/platform/retrying_file_system.h new file mode 100644 index 00000000..c8eb328c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/retrying_file_system.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RETRYING_FILE_SYSTEM_H_ +#define TENSORFLOW_CORE_PLATFORM_RETRYING_FILE_SYSTEM_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/retrying_utils.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/retrying_file_system.h" + +namespace tensorflow { + +using tsl::RetryingFileSystem; // NOLINT(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RETRYING_FILE_SYSTEM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/retrying_utils.h b/third_party/tflite-hdrs/tensorflow/core/platform/retrying_utils.h new file mode 100644 index 00000000..a42d02ad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/retrying_utils.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_RETRYING_UTILS_H_ +#define TENSORFLOW_CORE_PLATFORM_RETRYING_UTILS_H_ + +#include + +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/retrying_utils.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::RetryConfig; +using tsl::RetryingUtils; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_RETRYING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/rocm.h b/third_party/tflite-hdrs/tensorflow/core/platform/rocm.h new file mode 100644 index 00000000..8fc0fa9d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/rocm.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ROCM_H_ +#define TENSORFLOW_CORE_PLATFORM_ROCM_H_ + +#include "tensorflow/core/platform/platform.h" // IWYU pragma: keep + +#endif // TENSORFLOW_CORE_PLATFORM_ROCM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/rocm_rocdl_path.h b/third_party/tflite-hdrs/tensorflow/core/platform/rocm_rocdl_path.h new file mode 100644 index 00000000..dc656131 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/rocm_rocdl_path.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_ +#define TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_ + +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/rocm_rocdl_path.h" + +namespace tensorflow { +using tsl::RocdlRoot; // NOLINT +using tsl::RocmRoot; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_ROCM_ROCDL_PATH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/scanner.h b/third_party/tflite-hdrs/tensorflow/core/platform/scanner.h new file mode 100644 index 00000000..edea0a65 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/scanner.h @@ -0,0 +1,29 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_SCANNER_H_ +#define TENSORFLOW_CORE_PLATFORM_SCANNER_H_ + +#include "tsl/platform/scanner.h" + +namespace tensorflow { +namespace strings { + +using ::tsl::strings::Scanner; // NOLINT(misc-unused-using-decls) + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_SCANNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/setround.h b/third_party/tflite-hdrs/tensorflow/core/platform/setround.h new file mode 100644 index 00000000..efd1f03e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/setround.h @@ -0,0 +1,29 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_SETROUND_H_ +#define TENSORFLOW_CORE_PLATFORM_SETROUND_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tsl/platform/setround.h" + +namespace tensorflow { +namespace port { +using tsl::port::ScopedSetRound; // NOLINT + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_SETROUND_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/snappy.h b/third_party/tflite-hdrs/tensorflow/core/platform/snappy.h new file mode 100644 index 00000000..53fa5de6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/snappy.h @@ -0,0 +1,40 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_SNAPPY_H_ +#define TENSORFLOW_CORE_PLATFORM_SNAPPY_H_ + +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/snappy.h" + +#if !defined(PLATFORM_WINDOWS) +#include +#else +namespace tensorflow { +using tsl::iovec; +} // namespace tensorflow +#endif + +namespace tensorflow { +namespace port { +using tsl::port::Snappy_Compress; +using tsl::port::Snappy_CompressFromIOVec; +using tsl::port::Snappy_GetUncompressedLength; +using tsl::port::Snappy_Uncompress; +using tsl::port::Snappy_UncompressToIOVec; +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_SNAPPY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/stack_frame.h b/third_party/tflite-hdrs/tensorflow/core/platform/stack_frame.h new file mode 100644 index 00000000..cd5c3ff1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/stack_frame.h @@ -0,0 +1,25 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STACK_FRAME_H_ +#define TENSORFLOW_CORE_PLATFORM_STACK_FRAME_H_ + +#include "tsl/platform/stack_frame.h" + +namespace tensorflow { +typedef tsl::StackFrame StackFrame; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STACK_FRAME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/stacktrace.h b/third_party/tflite-hdrs/tensorflow/core/platform/stacktrace.h new file mode 100644 index 00000000..b8aaf464 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/stacktrace.h @@ -0,0 +1,30 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STACKTRACE_H_ +#define TENSORFLOW_CORE_PLATFORM_STACKTRACE_H_ + +#include "tensorflow/core/platform/platform.h" // IWYU pragma: export +#include "tsl/platform/stacktrace.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::CurrentStackTrace; +using tsl::DebugWriteToString; +using tsl::SavedStackTrace; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STACKTRACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/stacktrace_handler.h b/third_party/tflite-hdrs/tensorflow/core/platform/stacktrace_handler.h new file mode 100644 index 00000000..8a81a6a7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/stacktrace_handler.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_ +#define TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_ + +#include "tsl/platform/stacktrace_handler.h" + +namespace tensorflow { +namespace testing { + +// Installs signal handlers to print out stack trace. +// Although GoogleTest has support for generating stacktraces with abseil via +// https://github.com/google/googletest/pull/1653, this doesn't cover our use +// case of getting C++ stacktraces in our python tests. 
+using tsl::testing::InstallStacktraceHandler; + +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/status.h b/third_party/tflite-hdrs/tensorflow/core/platform/status.h new file mode 100644 index 00000000..99f66009 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/status.h @@ -0,0 +1,65 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STATUS_H_ +#define TENSORFLOW_CORE_PLATFORM_STATUS_H_ + +#include "absl/base/macros.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stack_frame.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/status.h" + +#if !defined(ABSL_DEPRECATE_AND_INLINE) +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +#ifdef SWIG +using tsl::FromAbslStatus; +using tsl::OkStatus; +using tsl::Status; +using tsl::ToAbslStatus; +#else +ABSL_DEPRECATE_AND_INLINE() +inline ::absl::Status FromAbslStatus(const ::absl::Status& s) { return s; } +ABSL_DEPRECATE_AND_INLINE() +inline ::absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } +ABSL_DEPRECATE_AND_INLINE() +inline ::absl::Status OkStatus() { return ::absl::OkStatus(); }; +using Status ABSL_DEPRECATE_AND_INLINE() = ::absl::Status; +#endif +using tsl::StatusCallback; +using tsl::StatusGroup; +using tsl::TfCheckOpHelper; +using tsl::TfCheckOpHelperOutOfLine; + +namespace errors { +#ifdef SWIG +using tsl::errors::Code; +#else +using Code ABSL_DEPRECATE_AND_INLINE() = ::absl::StatusCode; +#endif +using tsl::errors::GetStackTrace; +using tsl::errors::SetStackTrace; +} // namespace errors +// NOLINTEND(misc-unused-using-decls) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STATUS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/status_matchers.h b/third_party/tflite-hdrs/tensorflow/core/platform/status_matchers.h new file mode 100644 index 00000000..6fd5791f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/status_matchers.h @@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PLATFORM_STATUS_MATCHERS_H_ +#define TENSORFLOW_CORE_PLATFORM_STATUS_MATCHERS_H_ + + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tsl/platform/status_matchers.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) + +namespace testing { +namespace internal_status { +using tsl::testing::internal_status::GetStatus; +using tsl::testing::internal_status::IsOkAndHoldsMatcher; +using tsl::testing::internal_status::IsOkAndHoldsMatcherImpl; +using tsl::testing::internal_status::IsOkMatcher; +using tsl::testing::internal_status::MonoIsOkMatcherImpl; +using tsl::testing::internal_status::MonoStatusIsMatcherImpl; +using tsl::testing::internal_status::StatusIsMatcher; +using tsl::testing::internal_status::StatusIsMatcherCommonImpl; +} // namespace internal_status +using tsl::testing::IsOk; +using tsl::testing::IsOkAndHolds; +using tsl::testing::StatusIs; +// NOLINTEND(misc-unused-using-decls) +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STATUS_MATCHERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/statusor.h b/third_party/tflite-hdrs/tensorflow/core/platform/statusor.h new file mode 100644 index 00000000..1a5f77e8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/statusor.h @@ -0,0 +1,26 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STATUSOR_H_ +#define TENSORFLOW_CORE_PLATFORM_STATUSOR_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/statusor.h" +namespace tensorflow { +using tsl::StatusOr; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STATUSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/str_util.h b/third_party/tflite-hdrs/tensorflow/core/platform/str_util.h new file mode 100644 index 00000000..fbea09af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/str_util.h @@ -0,0 +1,61 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STR_UTIL_H_ +#define TENSORFLOW_CORE_PLATFORM_STR_UTIL_H_ + +#include +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/str_util.h" + +// Basic string utility routines +namespace tensorflow { +namespace str_util { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::str_util::AllowEmpty; +using tsl::str_util::ArgDefCase; +using tsl::str_util::CEscape; +using tsl::str_util::ConsumeLeadingDigits; +using tsl::str_util::ConsumeNonWhitespace; +using tsl::str_util::ConsumePrefix; +using tsl::str_util::ConsumeSuffix; +using tsl::str_util::CUnescape; +using tsl::str_util::EndsWith; +using tsl::str_util::Join; +using tsl::str_util::Lowercase; +using tsl::str_util::RemoveLeadingWhitespace; +using tsl::str_util::RemoveTrailingWhitespace; +using tsl::str_util::RemoveWhitespaceContext; +using tsl::str_util::SkipEmpty; +using tsl::str_util::SkipWhitespace; +using tsl::str_util::Split; +using tsl::str_util::StartsWith; +using tsl::str_util::StrContains; +using tsl::str_util::StringReplace; +using tsl::str_util::StripPrefix; +using tsl::str_util::StripSuffix; +using tsl::str_util::StripTrailingWhitespace; +using tsl::str_util::Strnlen; +using tsl::str_util::TitlecaseString; +using tsl::str_util::Uppercase; +// NOLINTEND(misc-unused-using-decls) +} // namespace str_util +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/strcat.h b/third_party/tflite-hdrs/tensorflow/core/platform/strcat.h new file mode 100644 index 00000000..9a11dd2d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/strcat.h @@ -0,0 +1,54 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
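The str_util.h shim above forwards to the tsl::str_util helpers. A short sketch of how a caller might combine them, assuming the upstream signatures (Split with a char delimiter plus a skip predicate, RemoveWhitespaceContext taking a string_view pointer); SplitCsvLine() is illustrative:

#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/str_util.h"

// Trims surrounding whitespace, then splits on commas, dropping empty fields.
std::vector<std::string> SplitCsvLine(absl::string_view line) {
  absl::string_view rest = line;
  tensorflow::str_util::RemoveWhitespaceContext(&rest);
  return tensorflow::str_util::Split(rest, ',',
                                     tensorflow::str_util::SkipEmpty());
}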
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STRCAT_H_ +#define TENSORFLOW_CORE_PLATFORM_STRCAT_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/numbers.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/strcat.h" + +namespace tensorflow { +namespace strings { + +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::strings::AlphaNum; +using tsl::strings::Hex; +using tsl::strings::kZeroPad10; +using tsl::strings::kZeroPad11; +using tsl::strings::kZeroPad12; +using tsl::strings::kZeroPad13; +using tsl::strings::kZeroPad14; +using tsl::strings::kZeroPad15; +using tsl::strings::kZeroPad16; +using tsl::strings::kZeroPad2; +using tsl::strings::kZeroPad3; +using tsl::strings::kZeroPad4; +using tsl::strings::kZeroPad5; +using tsl::strings::kZeroPad6; +using tsl::strings::kZeroPad7; +using tsl::strings::kZeroPad8; +using tsl::strings::kZeroPad9; +using tsl::strings::PadSpec; +using tsl::strings::StrAppend; +using tsl::strings::StrCat; +// NOLINTEND(misc-unused-using-decls) + +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STRCAT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/stream_executor.h b/third_party/tflite-hdrs/tensorflow/core/platform/stream_executor.h new file mode 100644 index 00000000..58acf8eb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/stream_executor.h @@ -0,0 +1,34 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_H_ +#define TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_H_ + +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/dso_loader.h" + +#endif // TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/stream_executor_no_cuda.h b/third_party/tflite-hdrs/tensorflow/core/platform/stream_executor_no_cuda.h new file mode 100644 index 00000000..e6013d76 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/stream_executor_no_cuda.h @@ -0,0 +1,33 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
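The strcat.h shim above re-exports tsl::strings::StrCat/StrAppend and the Hex/PadSpec helpers. A minimal sketch; DescribeDevice() and its inputs are illustrative only:

#include <string>

#include "tensorflow/core/platform/strcat.h"

// Builds "device:<id> @0x<16-digit hex address>" using the re-exported helpers.
std::string DescribeDevice(int id, unsigned long long addr) {
  std::string out = tensorflow::strings::StrCat("device:", id);
  tensorflow::strings::StrAppend(
      &out, " @0x",
      tensorflow::strings::Hex(addr, tensorflow::strings::kZeroPad16));
  return out;
}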
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_NO_CUDA_H_ +#define TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_NO_CUDA_H_ + +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tensorflow/core/platform/platform.h" +#include "tsl/platform/dso_loader.h" + +#endif // TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_NO_CUDA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/stringpiece.h b/third_party/tflite-hdrs/tensorflow/core/platform/stringpiece.h new file mode 100644 index 00000000..43f3d4a9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/stringpiece.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// StringPiece is a simple structure containing a pointer into some external +// storage and a size. The user of a StringPiece must ensure that the slice +// is not used after the corresponding external storage has been +// deallocated. +// +// Multiple threads can invoke const methods on a StringPiece without +// external synchronization, but if any of the threads may call a +// non-const method, all threads accessing the same StringPiece must use +// external synchronization. + +#ifndef TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_ +#define TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_ + +#include "absl/base/macros.h" +#include "tsl/platform/stringpiece.h" // IWYU pragma: export + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. 
+#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { + +using StringPiece ABSL_DEPRECATE_AND_INLINE() = absl::string_view; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STRINGPIECE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/stringprintf.h b/third_party/tflite-hdrs/tensorflow/core/platform/stringprintf.h new file mode 100644 index 00000000..27d30089 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/stringprintf.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Printf variants that place their output in a C++ string. +// +// Usage: +// string result = strings::Printf("%d %s\n", 10, "hello"); +// strings::Appendf(&result, "%d %s\n", 20, "there"); + +#ifndef TENSORFLOW_CORE_PLATFORM_STRINGPRINTF_H_ +#define TENSORFLOW_CORE_PLATFORM_STRINGPRINTF_H_ + +#include + +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/stringprintf.h" + +namespace tensorflow { +namespace strings { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::strings::Appendf; +using tsl::strings::Appendv; +using tsl::strings::Printf; +// NOLINTEND(misc-unused-using-decls) +} // namespace strings +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STRINGPRINTF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/strong_hash.h b/third_party/tflite-hdrs/tensorflow/core/platform/strong_hash.h new file mode 100644 index 00000000..c442103c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/strong_hash.h @@ -0,0 +1,45 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_ +#define TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_ + +#include "highwayhash/sip_hash.h" // from @highwayhash +#include "highwayhash/state_helpers.h" // from @highwayhash +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// This is a strong keyed hash function interface for strings. +// The hash function is deterministic on the content of the string within the +// process. The key of the hash is an array of 2 uint64 elements. 
+// A strong hash makes it difficult, if not infeasible, to compute inputs that +// hash to the same bucket. +// +// Usage: +// uint64 key[2] = {123, 456}; +// string input = "input string"; +// uint64 hash_value = StrongKeyedHash(key, input); +// +inline uint64 StrongKeyedHash(const tensorflow::uint64 (&key)[2], + const string& s) { + return highwayhash::StringHasher()( + {key[0], key[1]}, s); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_STRONG_HASH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/subprocess.h b/third_party/tflite-hdrs/tensorflow/core/platform/subprocess.h new file mode 100644 index 00000000..0406f529 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/subprocess.h @@ -0,0 +1,38 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_ +#define TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_ + +#include "xla/tsl/platform/subprocess.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::ACTION_CLOSE; +using tsl::ACTION_DUPPARENT; +using tsl::ACTION_PIPE; +using tsl::CHAN_STDERR; +using tsl::CHAN_STDIN; +using tsl::CHAN_STDOUT; +using tsl::Channel; +using tsl::ChannelAction; +using tsl::CreateSubProcess; +using tsl::SubProcess; +} // namespace tensorflow + +#include "tensorflow/core/platform/platform.h" + + +#endif // TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/tensor_coding.h b/third_party/tflite-hdrs/tensorflow/core/platform/tensor_coding.h new file mode 100644 index 00000000..b024e143 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/tensor_coding.h @@ -0,0 +1,137 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper routines for encoding/decoding tensor contents. +#ifndef TENSORFLOW_CORE_PLATFORM_TENSOR_CODING_H_ +#define TENSORFLOW_CORE_PLATFORM_TENSOR_CODING_H_ + +#include + +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace port { + +// Store src contents in *out. 
If backing memory for src is shared with *out,
+// will ref obj during the call and will arrange to unref obj when no
+// longer needed.
+void AssignRefCounted(absl::string_view src, core::RefCounted* obj,
+                      std::string* out);
+
+// Copy contents of src to dst[0,src.size()-1].
+inline void CopyToArray(const std::string& src, char* dst) {
+  memcpy(dst, src.data(), src.size());
+}
+
+// Copy subrange [pos:(pos + n)) from src to dst. If pos >= src.size() the
+// result is empty. If pos + n > src.size() the subrange [pos, size()) is
+// copied.
+inline void CopySubrangeToArray(const std::string& src, size_t pos, size_t n,
+                                char* dst) {
+  if (pos >= src.size()) return;
+  memcpy(dst, src.data() + pos, std::min(n, src.size() - pos));
+}
+
+// Store encoding of strings[0..n-1] in *out.
+void EncodeStringList(const tstring* strings, int64_t n, std::string* out);
+
+// Decode n strings from src and store in strings[0..n-1].
+// Returns true if successful, false on parse error.
+bool DecodeStringList(const std::string& src, tstring* strings, int64_t n);
+
+// Assigns base[0..bytes-1] to *s
+void CopyFromArray(std::string* s, const char* base, size_t bytes);
+
+// Encodes sequences of strings and serialized protocol buffers into a string.
+// Normal usage consists of zero or more calls to Append() and a single call to
+// Finalize().
+class StringListEncoder {
+ public:
+  virtual ~StringListEncoder() = default;
+
+  // Encodes the given protocol buffer. This may not be called after Finalize().
+  virtual void Append(const protobuf::MessageLite& m) = 0;
+
+  // Encodes the given string. This may not be called after Finalize().
+  virtual void Append(const std::string& s) = 0;
+
+  // Signals end of the encoding process. No other calls are allowed after this.
+  virtual void Finalize() = 0;
+};
+
+// Decodes a string into sequences of strings (which may represent serialized
+// protocol buffers). Normal usage involves a single call to ReadSizes() in
+// order to retrieve the length of all the strings in the sequence. For each
+// size returned a call to Data() is expected and will return the actual
+// string.
+class StringListDecoder {
+ public:
+  virtual ~StringListDecoder() = default;
+
+  // Populates the given vector with the lengths of each string in the sequence
+  // being decoded. Upon returning the vector is guaranteed to contain as many
+  // elements as there are strings in the sequence.
+  virtual bool ReadSizes(std::vector<uint32>* sizes) = 0;
+
+  // Returns a pointer to the next string in the sequence, then prepares for the
+  // next call by advancing 'size' characters in the sequence.
+  virtual const char* Data(uint32 size) = 0;
+};
+
+std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out);
+std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in);
+
+#if defined(TENSORFLOW_PROTOBUF_USES_CORD)
+// Store src contents in *out. If backing memory for src is shared with *out,
+// will ref obj during the call and will arrange to unref obj when no
+// longer needed.
+void AssignRefCounted(absl::string_view src, core::RefCounted* obj,
+                      absl::Cord* out);
+
+// TODO(kmensah): Macro guard this with a check for Cord support.
+inline void CopyToArray(const absl::Cord& src, char* dst) {
+  src.CopyToArray(dst);
+}
+
+// Copy n bytes of src to dst. If pos >= src.size() the result is empty.
+// If pos + n > src.size() the subrange [pos, size()) is copied.
+inline void CopySubrangeToArray(const absl::Cord& src, int64_t pos, int64_t n, + char* dst) { + src.Subcord(pos, n).CopyToArray(dst); +} + +// Store encoding of strings[0..n-1] in *out. +void EncodeStringList(const tstring* strings, int64_t n, absl::Cord* out); + +// Decode n strings from src and store in strings[0..n-1]. +// Returns true if successful, false on parse error. +bool DecodeStringList(const absl::Cord& src, std::string* strings, int64_t n); +bool DecodeStringList(const absl::Cord& src, tstring* strings, int64_t n); + +// Assigns base[0..bytes-1] to *c +void CopyFromArray(absl::Cord* c, const char* base, size_t bytes); + +std::unique_ptr NewStringListEncoder(absl::Cord* out); +std::unique_ptr NewStringListDecoder(const absl::Cord& in); +#endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) + +} // namespace port +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_TENSOR_CODING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/tensor_float_32_utils.h b/third_party/tflite-hdrs/tensorflow/core/platform/tensor_float_32_utils.h new file mode 100644 index 00000000..efcb9941 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/tensor_float_32_utils.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_TENSOR_FLOAT_32_UTILS_H_ +#define TENSORFLOW_CORE_PLATFORM_TENSOR_FLOAT_32_UTILS_H_ + +#include "tsl/platform/tensor_float_32_utils.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::enable_tensor_float_32_execution; +using tsl::tensor_float_32_execution_enabled; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_TENSOR_FLOAT_32_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/test.h b/third_party/tflite-hdrs/tensorflow/core/platform/test.h new file mode 100644 index 00000000..d57a08f3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/test.h @@ -0,0 +1,36 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
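A minimal round-trip sketch for the EncodeStringList/DecodeStringList helpers declared in tensor_coding.h above; the two sample strings are illustrative and error handling is reduced to the decoder's bool result:

#include <string>

#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/tstring.h"

// Encodes two tstrings into a flat std::string, then decodes them back.
bool RoundTrip() {
  tensorflow::tstring in[2] = {"alpha", "beta"};
  std::string encoded;
  tensorflow::port::EncodeStringList(in, /*n=*/2, &encoded);

  tensorflow::tstring out[2];
  return tensorflow::port::DecodeStringList(encoded, out, /*n=*/2);
}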
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_TEST_H_ +#define TENSORFLOW_CORE_PLATFORM_TEST_H_ + +#include // IWYU pragma: export +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/test.h" + +namespace tensorflow { + +namespace testing { +using tsl::testing::PickUnusedPortOrDie; +using tsl::testing::RandomSeed; +using tsl::testing::TensorFlowSrcRoot; +using tsl::testing::TmpDir; + +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_TEST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/test_benchmark.h b/third_party/tflite-hdrs/tensorflow/core/platform/test_benchmark.h new file mode 100644 index 00000000..ed964a89 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/test_benchmark.h @@ -0,0 +1,30 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Simple benchmarking facility. +#ifndef TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ +#define TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ + +#include "tsl/platform/test_benchmark.h" + +namespace tensorflow { +namespace testing { +using tsl::testing::DoNotOptimize; // NOLINT +using tsl::testing::InitializeBenchmarks; // NOLINT +using tsl::testing::RunBenchmarks; // NOLINT +} // namespace testing +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/thread_annotations.h b/third_party/tflite-hdrs/tensorflow/core/platform/thread_annotations.h new file mode 100644 index 00000000..4178265a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/thread_annotations.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file contains the macro definitions for thread safety +// annotations that allow the developers to document the locking policies +// of their multi-threaded code. The annotations can also help program +// analysis tools to identify potential thread safety issues. 
+// +// The primary documentation on these annotations is external: +// http://clang.llvm.org/docs/ThreadSafetyAnalysis.html +// +// The annotations are implemented using compiler attributes. +// Using the macros defined here instead of the raw attributes allows +// for portability and future compatibility. +// +// When referring to mutexes in the arguments of the attributes, you should +// use variable names or more complex expressions (e.g. my_object->mutex_) +// that evaluate to a concrete mutex object whenever possible. If the mutex +// you want to refer to is not in scope, you may use a member pointer +// (e.g. &MyClass::mutex_) to refer to a mutex in some (unknown) object. +// + +#ifndef TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_ +#define TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_ + +// IWYU pragma: private, include "third_party/tensorflow/core/platform/thread_annotations.h" +// IWYU pragma: friend third_party/tensorflow/core/platform/thread_annotations.h + +#include "tsl/platform/thread_annotations.h" // IWYU pragma: export + +#endif // TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/threadpool.h b/third_party/tflite-hdrs/tensorflow/core/platform/threadpool.h new file mode 100644 index 00000000..02129fd4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/threadpool.h @@ -0,0 +1,37 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_THREADPOOL_H_ +#define TENSORFLOW_CORE_PLATFORM_THREADPOOL_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace thread { +using tsl::thread::EigenEnvironment; // NOLINT +using tsl::thread::ThreadPool; // NOLINT + +} // namespace thread +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_THREADPOOL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/threadpool_interface.h b/third_party/tflite-hdrs/tensorflow/core/platform/threadpool_interface.h new file mode 100644 index 00000000..7e07e560 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/threadpool_interface.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
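The threadpool.h shim above re-exports tsl::thread::ThreadPool. A small sketch of the usual Schedule pattern, assuming the upstream constructor (Env*, name, num_threads) and that the destructor joins outstanding work; ParallelFill() is illustrative:

#include <cstddef>
#include <vector>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/threadpool.h"

// Fans independent per-index work out over the re-exported ThreadPool.
void ParallelFill(std::vector<int>& values) {
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                      "example_pool", /*num_threads=*/4);
  for (size_t i = 0; i < values.size(); ++i) {
    pool.Schedule([&values, i]() { values[i] = static_cast<int>(i * i); });
  }
  // The pool joins its workers when it goes out of scope, so all scheduled
  // tasks have run by the time this function returns.
}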
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_THREADPOOL_INTERFACE_H_ +#define TENSORFLOW_CORE_PLATFORM_THREADPOOL_INTERFACE_H_ + +#include "tsl/platform/threadpool_interface.h" + +namespace tensorflow { +namespace thread { + +using ThreadPoolInterface = tsl::thread::ThreadPoolInterface; + +} // namespace thread +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_THREADPOOL_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/threadpool_options.h b/third_party/tflite-hdrs/tensorflow/core/platform/threadpool_options.h new file mode 100644 index 00000000..c6237fa8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/threadpool_options.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_THREADPOOL_OPTIONS_H_ +#define TENSORFLOW_CORE_PLATFORM_THREADPOOL_OPTIONS_H_ + +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tsl/platform/threadpool_options.h" + +namespace tensorflow { +namespace thread { + +using tsl::thread::ThreadPoolOptions; // NOLINT + +} // namespace thread +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_THREADPOOL_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/tracing.h b/third_party/tflite-hdrs/tensorflow/core/platform/tracing.h new file mode 100644 index 00000000..24917a6d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/tracing.h @@ -0,0 +1,53 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_TRACING_H_ +#define TENSORFLOW_CORE_PLATFORM_TRACING_H_ + +// Tracing interface + +#include + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/tracing.h" + +namespace tensorflow { +namespace tracing { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::tracing::EventCategory; +using tsl::tracing::EventCollector; +using tsl::tracing::GetArgForName; +using tsl::tracing::GetEventCategoryName; +using tsl::tracing::GetEventCollector; +using tsl::tracing::GetLogDir; +using tsl::tracing::GetNumEventCategories; +using tsl::tracing::GetUniqueArg; +using tsl::tracing::RecordEvent; +using tsl::tracing::ScopedRegion; +using tsl::tracing::SetEventCollector; +// NOLINTEND(misc-unused-using-decls) +} // namespace tracing +} // namespace tensorflow + +#if defined(PLATFORM_GOOGLE) +#include "xla/tsl/platform/google/tracing_impl.h" +#else +#include "xla/tsl/platform/default/tracing_impl.h" +#endif + +#endif // TENSORFLOW_CORE_PLATFORM_TRACING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/tstring.h b/third_party/tflite-hdrs/tensorflow/core/platform/tstring.h new file mode 100644 index 00000000..7795811d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/tstring.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_TSTRING_H_ +#define TENSORFLOW_CORE_PLATFORM_TSTRING_H_ + +#include "tensorflow/core/platform/cord.h" +#include "tensorflow/core/platform/ctstring.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tsl/platform/tstring.h" + +namespace tensorflow { + +using tstring = tsl::tstring; +} + +#endif // TENSORFLOW_CORE_PLATFORM_TSTRING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/types.h b/third_party/tflite-hdrs/tensorflow/core/platform/types.h new file mode 100644 index 00000000..a3159bfe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/types.h @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_TYPES_H_ +#define TENSORFLOW_CORE_PLATFORM_TYPES_H_ + +#include "tensorflow/core/platform/bfloat16.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/tstring.h" +#include "tsl/platform/types.h" + +namespace tensorflow { + +// Alias tensorflow::string to std::string. +using tsl::string; + +using tsl::uint16; +using tsl::uint32; +using tsl::uint4; +using tsl::uint64; +using tsl::uint8; + +using tsl::int16; +using tsl::int32; +using tsl::int4; +using tsl::int64; +using tsl::int8; + +using tsl::float8_e4m3fn; +using tsl::float8_e5m2; + +static const uint8 kuint8max = tsl::kuint8max; +static const uint16 kuint16max = tsl::kuint16max; +static const uint32 kuint32max = tsl::kuint32max; +static const uint64 kuint64max = tsl::kuint64max; +static const int8_t kint8min = tsl::kint8min; +static const int8_t kint8max = tsl::kint8max; +static const int16_t kint16min = tsl::kint16min; +static const int16_t kint16max = tsl::kint16max; +static const int32_t kint32min = tsl::kint32min; +static const int32_t kint32max = tsl::kint32max; +static const int64_t kint64min = tsl::kint64min; +static const int64_t kint64max = tsl::kint64max; + +// A typedef for a uint64 used as a short fingerprint. +using tsl::bfloat16; +using tsl::Fprint; +using tsl::tstring; // NOLINT: suppress 'using decl 'tstring' is unused' +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/platform/unbounded_work_queue.h b/third_party/tflite-hdrs/tensorflow/core/platform/unbounded_work_queue.h new file mode 100644 index 00000000..cd6cdf97 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/platform/unbounded_work_queue.h @@ -0,0 +1,29 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_ +#define TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_ + +#include "tensorflow/core/platform/platform.h" +#include "tsl/platform/unbounded_work_queue.h" + +// An `UnboundedWorkQueue` feeds potentially-blocking work into a thread-pool +// whose size automatically increases with demand. + +namespace tensorflow { +using tsl::UnboundedWorkQueue; // NOLINT(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_UNBOUNDED_WORK_QUEUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/compute_inference_latency.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/compute_inference_latency.h new file mode 100644 index 00000000..91632c90 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/compute_inference_latency.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ + +#include +#include + +#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" + +namespace tensorflow::profiler { + +// Compute the inference latency from inference stats proto. +OverviewInferenceLatency ComputeInferenceLatencyResult( + const InferenceStats& inference_stats); + +} // namespace tensorflow::profiler + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_COMPUTE_INFERENCE_LATENCY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_analysis.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_analysis.h new file mode 100644 index 00000000..cdff8177 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_analysis.h @@ -0,0 +1,225 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/profiler/convert/dcn_utils.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +// Structure representing a DcnMessage using two entries: +// One for the start of the message and one for the end. +struct TimestampEvent { + uint64_t timestamp_ns; // TraceMe logging timestamp + uint64_t duration_ns; // 0 for start of message, duration for end of message + int32_t message_diff; // +1/-1 for start/end of message. + // Makes handling 0-sized messages easier and is + // convenient for the burst generation algorithm. + size_t size_diff; // +size/-size for start/end of message. + int32_t src_slice_id; // Source slice for message, used for stragglers +}; + +// We use an multi map since TimestampEvents will be ordered and we +// need separate entries for possible events happening at exactly the +// same time. 
+typedef std::multimap<uint64_t, std::shared_ptr<TimestampEvent>> TimestampMap;
+typedef absl::flat_hash_map<std::string, TimestampMap> CollectiveTimestampMap;
+
+// Straggler messages. These are shown at the end of the bursts they belong to.
+struct Straggler {
+  uint64_t duration_ns;       // Message duration in ns
+  uint64_t end_timestamp_ns;  // End of the message. For the last straggler
+                              // this will be the end of the burst
+  size_t size_bytes;          // Size of the message in bytes
+  int32_t src_slice_id;       // Source slice of the message
+  // TODO(emizan) Add host info.
+};
+
+static constexpr uint32_t kMaxStragglersPerBurst = 4;
+
+// DCN Burst description.
+// A burst is defined as a period of time during which there is at least one
+// message in the network. Since DCN traffic is bursty this structure is
+// convenient to summarize 100K+ messages in a few 10s of bursts.
+// Burst scope is flexible. In this analysis we have per-host bursts, which
+// include messages arriving on a single host independent of sender/target TPU
+// and collective. We also have per collective/TPU bursts which include messages
+// for a single collective+TPU combination.
+struct DcnBurst {
+  uint64_t start_timestamp_ns;        // Beginning of burst in ns
+  uint64_t end_timestamp_ns;          // End of burst in ns
+  uint64_t burst_size_bytes;          // Total number of bytes in burst
+  uint64_t num_messages;              // Messages in burst
+  uint64_t max_overlapping_messages;  // Max overlapping messages in burst
+  // Buffer of stragglers in a burst. Contains the last few messages in a burst
+  std::array<Straggler, kMaxStragglersPerBurst> stragglers;
+};
+
+// Class with functionality to generate DcnBursts out of TimestampEvents.
+// Burst creation is a non-trivial state machine
+class DcnBurstManager {
+ public:
+  DcnBurstManager() = default;
+  uint64_t TotalLatency() const { return total_latency_; }
+  void SetToDisplay(bool to_display) { to_display_ = to_display; }
+  bool ToDisplay() const { return to_display_; }
+  const std::vector<DcnBurst> &GetBursts() const { return bursts_; }
+
+  // Run burst state machine creation out of timestamp map.
+  void CreateBursts(const TimestampMap &tm_events);
+  // For debugging purposes.
+  void PrintBursts() {
+    for (const auto &burst : bursts_) {
+      LOG(INFO) << burst.start_timestamp_ns << " " << burst.end_timestamp_ns
+                << " " << burst.num_messages << " " << burst.burst_size_bytes
+                << " " << burst.max_overlapping_messages;
+    }
+  }
+
+ private:
+  std::vector<DcnBurst> bursts_;  // Bursts created by this manager
+  uint64_t total_latency_ = 0;    // Total latency of all bursts created
+  // Used to see if bursts will be displayed
+  bool to_display_ = false;  // Set to true to enable burst display
+
+  int32_t active_burst_messages_;  // Used by burst creation state machine.
+  DcnBurst active_burst_;          // Active burst in creation
+  uint32_t straggler_idx_;
+
+  // Initializes state machine when new burst is detected.
+  void ResetBurstState();
+};
+
+typedef absl::flat_hash_map<std::string, DcnBurstManager>
+    CollectiveBurstManager;
+
+class DcnEventsProcessor {
+ public:
+  DcnEventsProcessor() = delete;
+  DcnEventsProcessor(uint32_t num_tpu_tensor_cores, bool is_megacore);
+
+  uint32_t NumTpuTensorCores() const { return num_tpu_tensor_cores_; }
+  bool IsMegacore() const { return is_megacore_; }
+
+  // Populates available megascale messages from event metadata.
+  void SetupMessageInfo(const tensorflow::profiler::XPlaneVisitor &plane);
+
+  std::optional<int32_t> MegaScaleMessageId(absl::string_view msg_name) const {
+    auto iter = megascale_msg_.find(msg_name);
+    if (iter != megascale_msg_.end()) {
+      return iter->second;
+    }
+    return std::nullopt;
+  }
+
+  uint32_t NumReceivedMessages() const { return received_messages_.size(); }
+  const tensorflow::profiler::DcnMessage &GetMessage(uint32_t i) const {
+    return received_messages_[i];
+  }
+
+  // Checks if messages with msg event name have been found in event metadata.
+  bool HasDcnMessages(absl::string_view msg_name) const {
+    return (megascale_msg_.find(msg_name) != megascale_msg_.end());
+  }
+
+  const TimestampMap &HostTsMap() const { return host_ts_map_; }
+  const std::vector<DcnBurst> &GetHostBursts() const {
+    return host_dcn_bursts_.GetBursts();
+  }
+
+  // Main function to process receive messages, and call other functions
+  // to generate timestamp events and bursts.
+  void ProcessReceiveMessages(const tensorflow::profiler::XPlaneVisitor &plane);
+
+  // Update XPlanes using DCN traffic info
+  void AddHostDcnTrafficToXPlane(tensorflow::profiler::XPlane *host_xplane);
+  void AddTpuCollectiveDcnTrafficToXPlane(
+      tensorflow::profiler::XPlane *device_xplane);
+
+ private:
+  // Tensor cores and megacore flag for this host. DCN messages are sent to a
+  // TPU chip, so we need to know the number of tensor cores and whether
+  // megacore is used to map DCN traffic to the proper tensor core.
+  const uint32_t num_tpu_tensor_cores_;
+  const bool is_megacore_;
+
+  // Used for visualization of BW and computation of BW utilization.
+  static constexpr float kLimitLowHostDcnBw = 4.17;
+  static constexpr float kLimitMedHostDcnBw = 8.34;
+  static constexpr float kMaxHostDcnBw = 12.5;
+
+  std::vector<absl::string_view> registered_dcn_messages_;
+
+  // Available megascale messages for this trace.
+  absl::flat_hash_map<absl::string_view, int32_t> megascale_msg_;
+
+  std::vector<tensorflow::profiler::DcnMessage> received_messages_;
+
+  // TimestampMaps for messages that arrive to this host
+  // and for messages of distinct collectives going to different TPUs.
+  TimestampMap host_ts_map_;
+  std::vector<CollectiveTimestampMap> tpu_collective_ts_map_;
+
+  // DcnBurstManagers for bursts that arrive to this host
+  // and for bursts from distinct collectives going to different TPUs.
+  DcnBurstManager host_dcn_bursts_;
+  std::vector<CollectiveBurstManager> tpu_collective_bursts_;
+
+  // Find the TPU index a DCN message goes to.
+  uint32_t FindTpuIdx(int tpu);
+
+  // Generates BW info to display in the trace viewer.
+  // This includes trace event BW level string, mean BW per burst and
+  // utilization.
+  absl::string_view GetBwInfo(bool is_per_tpu, const DcnBurst &burst,
+                              float &burst_mean_bw,
+                              float &burst_bw_utilization);
+
+  // Qualify collectives to display on trace viewer.
+  // Qualified collectives are given a dedicated line, while for the rest
+  // we share a single line for their stragglers.
+  uint32_t NumCollectivesQualified(const std::vector<uint64_t> &latencies);
+  void QualifyCollectives();
+  // Export collective DCN activity to trace viewer.
+  void AddQualifiedCollectivesToXPlane(
+      tensorflow::profiler::XPlaneBuilder &plane_builder, uint32_t tpu_idx);
+  void AddUnqualifiedCollectivesToXPlane(
+      tensorflow::profiler::XPlaneBuilder &plane_builder, uint32_t tpu_idx);
+
+  // Create timestamp events for every message
+  void GenerateTimestampEvents(
+      const tensorflow::profiler::DcnMessage &dcn_message);
+  // For debugging purposes
+  void PrintTimestampEvents();
+  // Generate bursts (host and TPU/collective) from timestamp events.
+ void GenerateBursts(); +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h new file mode 100644 index 00000000..f0fc727a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" + +namespace tensorflow { +namespace profiler { + +using tensorflow::profiler::DcnSlackAnalysis; +using tensorflow::profiler::DcnSlackSummary; + +class DcnSlackAnalysisCombiner { + private: + absl::flat_hash_map slack_summary_; + + public: + // Combine the DCN Slack Summary in the DcnSlackAnalysis. + // The DcnSlackAnalysis consists of average durations, The combine phase, the + // summary consists of the total duration for all the occurrences. Finazile + // must be called to get the accurate value. + void Combine(const DcnSlackAnalysis& slack_analysis); + + // Finalize the DcnSlackSummary by converting total durations to averages. + DcnSlackAnalysis Finalize(); +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_SLACK_ANALYSIS_COMBINER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_utils.h new file mode 100644 index 00000000..e0dd3a17 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/dcn_utils.h @@ -0,0 +1,76 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
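A conceptual sketch only, not the DcnBurstManager implementation: it shows how burst boundaries fall out of the +1/-1 message_diff entries described for TimestampEvent above. The real CreateBursts() additionally tracks byte counts, overlap, and stragglers:

#include <cstdint>
#include <map>
#include <vector>

// A burst is open while at least one message is in flight.
struct Interval {
  uint64_t start_ns = 0;
  uint64_t end_ns = 0;
};

std::vector<Interval> FindBursts(
    const std::multimap<uint64_t, int32_t>& events) {  // timestamp -> +1/-1
  std::vector<Interval> bursts;
  int32_t active = 0;  // messages currently in flight
  for (const auto& [timestamp_ns, message_diff] : events) {
    if (active == 0 && message_diff > 0) {
      bursts.push_back({timestamp_ns, 0});  // first message opens a burst
    }
    active += message_diff;                 // +1 on start, -1 on end
    if (active == 0 && !bursts.empty()) {
      bursts.back().end_ns = timestamp_ns;  // last message closes the burst
    }
  }
  return bursts;
}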
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_ + +#include + +#include "xla/tsl/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +// DCN Message Validity +enum DcnMessageValidity { + // Valid message + DCN_MESSAGE_VALID = 1, + // Valid message, but should not go through DCN, so it should not use BW. + DCN_MESSAGE_VALID_LOOPBACK = 2, + // Invalid message with 0 duration due to clock skew. Should be ignored. + DCN_MESSAGE_INVALID_CLOCK_SKEW = 3, + // Message that cannot be decoded. Should be ignored. + DCN_MESSAGE_INVALID_BAD_KEY = 4 +}; + +// Structure representing a DCN event +struct DcnMessage { + // Unique collective that generated this message, format should be + // _, e.g. all_gather_34 + std::string collective_name = ""; + // Src info + // TODO(emizan) Add host info when you figure out how to get it from + // slice+tpu. + int32_t slice_src = -1; + int32_t tpu_src = -1; + // Dst info + int32_t slice_dst = -1; + int32_t tpu_dst = -1; + // Timing info in ns. Since MSXLA TraceMe's have us timestamps, we need to + // multiply by 1000 to get these timestamps. + uint64_t start_timestamp_ns = 0; + uint64_t end_timestamp_ns = 0; + uint64_t duration_us = 0; + // Size info + size_t size_bytes = 0; + // Chunk and Loop index + int32_t chunk_id = -1; + int32_t loop_index_id = -1; + // Is message valid/invalid and why + DcnMessageValidity validity_info = DCN_MESSAGE_INVALID_BAD_KEY; + // TBD: Add flow events in case you need to connect to other events pointed to + // by MSXLA TraceMe's +}; + +DcnMessage GetDcnMessageFromXEvent( + const tsl::profiler::XEventVisitor& event_visitor); + +// Check if the XEventVisitor is a DCN Message +bool IsDcnEvent(const tsl::profiler::XEventVisitor& event); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_DCN_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h new file mode 100644 index 00000000..b3a3a7c4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h @@ -0,0 +1,101 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_ + +#include +#include +#include + +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_graph_dumper.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/convert/tool_options.h" + +namespace tensorflow { +namespace profiler { + +// All the parameters for graph viewer. +struct GraphViewerParams { + // Whether to use GraphView or TxtView. 
+  std::string type;
+  // Parameters for GraphView.
+  std::string node_name;
+  int graph_width;
+  xla::HloRenderOptions render_options;
+  xla::RenderedGraphFormat format;
+  // Parameters for TxtView.
+  bool verbose;
+  bool show_metadata;
+};
+
+// Return mapping from style key word to op names separated by comma.
+// following hlo_graph_dumper styling
+absl::StatusOr<std::string> GetNodeStyles();
+
+// Parse tool options to get the parameters for graph viewer.
+absl::StatusOr<GraphViewerParams> ParseGraphViewerParams(
+    const ToolOptions& options);
+
+// Get graph render format.
+xla::RenderedGraphFormat GetRenderFormat(const std::string& format_string);
+
+// Convert `hlo_proto` to GraphView with the provided render options.
+absl::StatusOr<std::string> ConvertHloProtoToGraph(
+    const xla::HloProto& hlo_proto, const std::string& node_name,
+    int graph_width, const xla::HloRenderOptions& render_options,
+    const xla::RenderedGraphFormat& format);
+
+// Convert `hlo_proto` to ModelExplorer Graph JSON data.
+absl::StatusOr<std::string> ConvertHloProtoToMeGraph(
+    const xla::HloProto& hlo_proto, const std::string& node_name,
+    int graph_width);
+
+// Render graph with the provided render options.
+absl::StatusOr<std::string> RenderGraphView(
+    const xla::HloComputation& computation, absl::string_view label,
+    const xla::DebugOptions& debug_options, xla::RenderedGraphFormat format,
+    xla::HloRenderOptions hlo_render_options = {});
+
+// Render graph with centered node and depth
+absl::StatusOr<std::string> RenderGraphNeighborhoodAround(
+    const xla::HloInstruction& node, int radius,
+    xla::RenderedGraphFormat format,
+    xla::HloRenderOptions hlo_render_options = {},
+    const absl::flat_hash_set<const xla::HloInstruction*>& boundary = {});
+
+// Convert `hlo_proto` to StringView.
+absl::StatusOr<std::string> ConvertHloProtoToStringView(
+    const xla::HloProto& hlo_proto, bool verbose, bool metadata);
+
+// Convert dot into certain format
+absl::StatusOr<std::string> WrapDotInFormat(std::string dot,
+                                            xla::RenderedGraphFormat format);
+
+// Convert dot into visual graph in html
+std::string WrapDotInHtml(std::string dot);
+
+// Registers a function which implements RenderedGraphFormat::kUrl.
+// The input to the function is dot, and the output should be a URL or an error.
+// There can only be one active renderer, and the last call to this function
+// wins.
+void RegisterGraphvizURLRenderer(
+    std::function<absl::StatusOr<std::string>(absl::string_view dot)> renderer);
+
+}  // namespace profiler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_GRAPH_VIEW_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h
new file mode 100644
index 00000000..e7a681de
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.h
@@ -0,0 +1,44 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
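A minimal caller sketch for the graph-viewer entry points declared in hlo_proto_to_graph_view.h above; the node name, width, and kDot format are example inputs that real callers would normally obtain via ParseGraphViewerParams():

#include <string>

#include "absl/status/statusor.h"
#include "xla/service/hlo.pb.h"
#include "tensorflow/core/profiler/convert/hlo_proto_to_graph_view.h"

// Renders the neighborhood of one instruction from an HloProto as DOT text.
absl::StatusOr<std::string> RenderNode(const xla::HloProto& hlo_proto) {
  return tensorflow::profiler::ConvertHloProtoToGraph(
      hlo_proto, /*node_name=*/"fusion.1", /*graph_width=*/3,
      xla::HloRenderOptions(), xla::RenderedGraphFormat::kDot);
}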
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/service/hlo.pb.h" +#include "tensorflow/core/profiler/protobuf/memory_viewer_preprocess.pb.h" + +namespace tensorflow { +namespace profiler { + +constexpr int kSmallBufferSize = 16 * 1024; + +// Convert HloProto to PreprocessResult proto for memory visualization. +// small_buffer_size sets the byte size within which we collapse buffer entries +// for the max-heap display. +// is the index of heap simulator trace to be +// displayed. By default it is -1, which means the profiler will infer the heap +// simulator trace id from . +// By default the memory color is 0, which is HBM. +absl::StatusOr ConvertHloProtoToPreprocessResult( + const xla::HloProto& hlo_proto, + int64_t small_buffer_size = kSmallBufferSize, int64_t memory_color = 0); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_PROTO_TO_MEMORY_VISUALIZATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_to_tools_data.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_to_tools_data.h new file mode 100644 index 00000000..b567c973 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/hlo_to_tools_data.h @@ -0,0 +1,41 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/convert/repository.h" +#include "tensorflow/core/profiler/convert/tool_options.h" + +namespace tensorflow { +namespace profiler { + +// Convert HLO proto to tool specific data. +// must provide a "module_name" field to identify which HLO proto +// is used for the conversion. +// Return the serialized string of tool specific data when the conversion is +// successful, else return an error status. +absl::StatusOr ConvertHloProtoToToolData( + const SessionSnapshot& session_snapshot, absl::string_view tool_name, + const ToolOptions& options); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_HLO_TO_TOOLS_DATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats.h new file mode 100644 index 00000000..2789694a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/device_utils.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/event_span.h" + +namespace tensorflow { +namespace profiler { + +// Generates PerHostInferenceStats from the given trace events. +// For TPU, get time breakdown from device_traces. For GPU, get time breakdown +// from nonoverlapped_step_events. +// Get batching parameters from TFstreamz xplane in . +void GenerateInferenceStats( + const std::vector& device_traces, + const tensorflow::profiler::StepEvents& nonoverlapped_step_events, + const tsl::profiler::GroupMetadataMap& group_metadata_map, + const tensorflow::profiler::XSpace& xspace, + tsl::profiler::DeviceType device_type, int32_t host_id, + tensorflow::profiler::InferenceStats* inference_stats); + +// Parses model name from TFstreamz. +// Returns whether the parsing is successful and the actual model name. If +// parsing failed, returns false and an empty string. +std::pair ParseModelName(absl::string_view param); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_combiner.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_combiner.h new file mode 100644 index 00000000..ceccc9cc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_combiner.h @@ -0,0 +1,25 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_ +#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" + +namespace tensorflow::profiler { +void CombineInferenceStatsResult(int src_host_id, const InferenceStats& src, + InferenceStats* dst); +} // namespace tensorflow::profiler + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_COMBINER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_grouping.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_grouping.h new file mode 100644 index 00000000..7d60da0f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_grouping.h @@ -0,0 +1,29 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ + +#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" + +namespace tensorflow::profiler { + +// Change inference stats from per host to per model_id by doing a regroup. +// Future analysis of inference_stats will be on a per model_id basis. +void RegroupInferenceStatsByModel( + tensorflow::profiler::InferenceStats* inference_stats); + +} // namespace tensorflow::profiler + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_GROUPING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_sampler.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_sampler.h new file mode 100644 index 00000000..2706c16a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/inference_stats_sampler.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" + +namespace tensorflow::profiler { + +// Sampled inference stats of a model. +// The pointers of RequestDetail and BatchDetail point to the actual data stored +// in TfOpStats.InferenceStats. +struct SampledPerModelInferenceStats { + // Sampled requests and their percentile. + std::vector> + sampled_requests; + // Sampled batches and their percentile. + std::vector> + sampled_batches; +}; + +// All the sampled inference stats of a profile. +// TODO: Move to use SampledInferenceStatsProto if feasible. +using SampledInferenceStats = + absl::flat_hash_map; + +// Samples a subset of InferenceStats from based on sampling +// column and . +SampledInferenceStats SampleInferenceStats( + absl::string_view request_percentile_column, + absl::string_view batch_percentile_column, + const tensorflow::profiler::InferenceStats& inference_stats); + +} // namespace tensorflow::profiler + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_INFERENCE_STATS_SAMPLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h new file mode 100644 index 00000000..51348097 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/profiler/convert/repository.h" +#include "tensorflow/core/profiler/convert/xplane_to_op_stats.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" + +namespace tensorflow { +namespace profiler { + +// Converts and combines multiple XSpace protos into a single OpStats +// . +// Return the first error status during conversion, or return OkStatus() if +// there is no error. 
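// (In practice this is expected to be a two-phase combine: a StepIntersection
// is computed across the per-host OpStats to pick a common range of steps,
// and the per-host results are then merged into one; see op_stats_combiner.h
// later in this change for the ComputeStepIntersectionToMergeOpStats and
// CombineAllOpStats helpers this presumably builds on.)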
+absl::Status ConvertMultiXSpacesToCombinedOpStats( + const SessionSnapshot& session_snapshot, const OpStatsOptions& options, + OpStats* combined_op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XPLANES_TO_OP_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h new file mode 100644 index 00000000..3ea9af85 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/multi_xspace_to_inference_stats.h @@ -0,0 +1,27 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/profiler/convert/repository.h" +#include "tensorflow/core/profiler/protobuf/inference_stats.pb.h" +namespace tensorflow::profiler { +absl::Status ConvertMultiXSpaceToInferenceStats( + const SessionSnapshot& session_snapshot, absl::string_view request_column, + absl::string_view batch_column, InferenceStats* inference_stats); +} + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_MULTI_XSPACE_TO_INFERENCE_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_metrics_db_combiner.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_metrics_db_combiner.h new file mode 100644 index 00000000..76019da8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_metrics_db_combiner.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ + +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" + +namespace tensorflow { +namespace profiler { + +// Copies OpMetrics metadata (e.g., category, provenance) from src to dst. 
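// Only descriptive fields are copied here; accumulating counters such as
// occurrences and time are combined separately by CombineOpMetrics below.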
+void CopyOpMetricsMetadata(const OpMetrics& src, OpMetrics* dst); + +// Combines OpMetrics data (e.g., occurrences, time) from src into dst. +// If is set to true, update the dst->num_cores to +// calculate the number of cores a certain op occurs. +void CombineOpMetrics(const OpMetrics& src, OpMetrics* dst, + bool update_num_cores); + +// Combines the memory access breakdown. +void CombineMemoryAccessedBreakdown( + const protobuf::RepeatedPtrField& src, + protobuf::RepeatedPtrField* dst); + +// Helper to combine op metrics databases. +class OpMetricsDbCombiner : public OpMetricsDbBuilder { + public: + explicit OpMetricsDbCombiner(OpMetricsDb* dst) : OpMetricsDbBuilder(dst) {} + + // Combine the OpMetrics in OpMetricsDb to current OpMetricsDbCombiner. + // If is set to true, update the OpMetrics.num_cores to + // calculate the number of cores a certain op occurs. + void Combine(const OpMetricsDb& src, bool update_num_cores = true); +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_DB_COMBINER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_metrics_to_record.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_metrics_to_record.h new file mode 100644 index 00000000..37dfa14c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_metrics_to_record.h @@ -0,0 +1,343 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/device_utils.h" +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/utils/math_utils.h" + +namespace tensorflow { +namespace profiler { + +std::vector SortedOpMetricsDb(const OpMetricsDb& metrics_db, + int max_records = -1); + +inline double GigaFlopsPerSecondPerCore(const OpMetrics& metrics) { + // flops and time_ps are accumulated across all occurrences on all cores. + // time_ps is used instead of self_time_ps because flops for an op includes + // the flops executed by children (nested) ops. + return tsl::profiler::SafeDivide( + metrics.flops(), tsl::profiler::PicoToNano(metrics.time_ps())); +} + +inline double GigaModelFlopsPerSecondPerCore(const OpMetrics& metrics) { + // flops and time_ps are accumulated across all occurrences on all cores. + // time_ps is used instead of self_time_ps because flops for an op includes + // the flops executed by children (nested) ops. 
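  // Dividing a raw FLOP count by a duration converted from picoseconds to
  // nanoseconds yields FLOPs per nanosecond, i.e. 1e9 FLOPs per second, so
  // the result is already in GigaFLOPs/s without any extra scaling factor.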
+ return tsl::profiler::SafeDivide( + metrics.model_flops(), tsl::profiler::PicoToNano(metrics.time_ps())); +} + +// Return ByteAccessed for memory_space and operation_type. +inline double BytesAccessedPerCore( + const OpMetrics& metrics, uint64_t memory_space, + OpMetrics::MemoryAccessed::OperationType operation_type) { + uint64_t bytes = 0; + if (memory_space == MemorySpace::MEMORY_SPACE_ALL) { + bytes = metrics.bytes_accessed(); + } else { + for (const auto& breakdown : metrics.memory_accessed_breakdown()) { + // Count either on-chip or off-chip bytes. + if ((breakdown.operation_type() != operation_type) && + (operation_type != OpMetrics::MemoryAccessed::UNKNOWN)) { + continue; + } + if (((memory_space == MemorySpace::MEMORY_SPACE_HBM) && + (breakdown.memory_space() == MemorySpace::MEMORY_SPACE_HBM)) || + ((memory_space == MemorySpace::MEMORY_SPACE_ON_CHIP) && + (breakdown.memory_space() != MemorySpace::MEMORY_SPACE_HBM))) { + bytes += breakdown.bytes_accessed(); + } + } + } + return bytes; +} + +inline double GigaBytesPerSecondPerCore( + const OpMetrics& metrics, uint64_t memory_space, + OpMetrics::MemoryAccessed::OperationType operation_type) { + // bytes_accessed and time_ps are accumulated across all occurrences on all + // cores. + // time_ps is used instead of self_time_ps because bytes_accessed for an op + // includes the bytes accessed by children (nested) ops. + return tsl::profiler::SafeDivide( + BytesAccessedPerCore(metrics, memory_space, operation_type), + tsl::profiler::PicoToNano(metrics.time_ps())); +} + +inline double GibiBytesPerSecondPerCore( + const OpMetrics& metrics, uint64_t memory_space, + OpMetrics::MemoryAccessed::OperationType op_type) { + return tsl::profiler::GigaToGibi( + GigaBytesPerSecondPerCore(metrics, memory_space, op_type)); +} + +template +inline void SetExecutionTimes(const OpMetrics& metrics, Record* record) { + record->set_occurrences(metrics.occurrences()); + record->set_total_time_in_us(tsl::profiler::PicoToMicro(metrics.time_ps())); + record->set_avg_time_in_us( + SafeDivide(record->total_time_in_us(), metrics.occurrences())); + record->set_total_self_time_in_us( + tsl::profiler::PicoToMicro(metrics.self_time_ps())); + record->set_avg_self_time_in_us( + SafeDivide(record->total_self_time_in_us(), metrics.occurrences())); +} + +template +inline void SetTpuUnitFractions(const OpMetrics& metrics, Record* record) { + record->set_dma_stall_fraction( + tsl::profiler::SafeDivide(metrics.dma_stall_ps(), metrics.time_ps())); +} + +template +inline void SetRankAndTimeFractions(double total_time_us, + const Record& prev_record, Record* record) { + record->set_rank(prev_record.rank() + 1); + record->set_total_self_time_as_fraction( + SafeDivide(record->total_self_time_in_us(), total_time_us)); + record->set_cumulative_total_self_time_as_fraction( + prev_record.cumulative_total_self_time_as_fraction() + + record->total_self_time_as_fraction()); +} + +template +inline void SetRankAndDeviceTimeFractions(double total_time_us, + const Record& prev_record, + Record* record) { + record->set_rank(prev_record.rank() + 1); + record->set_device_total_self_time_as_fraction( + SafeDivide(record->total_self_time_in_us(), total_time_us)); + record->set_device_cumulative_total_self_time_as_fraction( + prev_record.device_cumulative_total_self_time_as_fraction() + + record->device_total_self_time_as_fraction()); +} + +template +inline void SetRankAndHostTimeFractions(double total_time_us, + const Record& prev_record, + Record* record) { + record->set_rank(prev_record.rank() 
+ 1); + record->set_host_total_self_time_as_fraction( + SafeDivide(record->total_self_time_in_us(), total_time_us)); + record->set_host_cumulative_total_self_time_as_fraction( + prev_record.host_cumulative_total_self_time_as_fraction() + + record->host_total_self_time_as_fraction()); +} + +// Returns the memory bandwidth in GigaBytes/s in the PerfEnv. +// memory space is chosen by index following order in xplane_to_op_stats.cc +static inline double GetMemoryPeakBandwidth(const PerfEnv& perf_env, + const int index) { + if (perf_env.peak_bws_giga_bytes_per_second_size() > index) { + return perf_env.peak_bws_giga_bytes_per_second(index); + } + return perf_env.peak_hbm_bw_giga_bytes_per_second(); +} + +template +inline void SetRooflineMetrics(const OpMetrics& metrics, const PerfEnv perf_env, + const RunEnvironment& run_env, Record* record) { + using ::tensorflow::profiler::MemorySpace; + using ::tensorflow::profiler::PerformanceInfo; + using ::tensorflow::profiler::PicoToNano; + + // Set overall performance metrics. + record->set_measured_flop_rate(GigaFlopsPerSecondPerCore(metrics)); + record->set_model_flop_rate(GigaModelFlopsPerSecondPerCore(metrics)); + record->set_measured_memory_bw(GibiBytesPerSecondPerCore( + metrics, tensorflow::profiler::MemorySpace::MEMORY_SPACE_ALL, + OpMetrics::MemoryAccessed::UNKNOWN)); + record->set_flops(metrics.flops()); + record->set_bytes_accessed(metrics.bytes_accessed()); + record->set_operational_intensity( + tsl::profiler::SafeDivide(metrics.flops(), metrics.bytes_accessed())); + // Set performance metrics per memory access type. + uint64_t hbm_bytes = 0; + uint64_t cmem_read_bytes = 0; + uint64_t cmem_write_bytes = 0; + uint64_t vmem_read_bytes = 0; + uint64_t vmem_write_bytes = 0; + for (const auto& memory_access : metrics.memory_accessed_breakdown()) { + if (memory_access.memory_space() == PerformanceInfo::MemoryAccessed::HBM) { + hbm_bytes += memory_access.bytes_accessed(); + } else if (memory_access.memory_space() == + PerformanceInfo::MemoryAccessed::CMEM) { + if (memory_access.operation_type() == OpMetrics::MemoryAccessed::READ) { + cmem_read_bytes += memory_access.bytes_accessed(); + } else if (memory_access.operation_type() == + OpMetrics::MemoryAccessed::WRITE) { + cmem_write_bytes += memory_access.bytes_accessed(); + } + } else if (memory_access.memory_space() == + PerformanceInfo::MemoryAccessed::VMEM) { + if (memory_access.operation_type() == OpMetrics::MemoryAccessed::READ) { + vmem_read_bytes += memory_access.bytes_accessed(); + } else if (memory_access.operation_type() == + OpMetrics::MemoryAccessed::WRITE) { + vmem_write_bytes += memory_access.bytes_accessed(); + } + } + } + if (metrics.memory_accessed_breakdown_size() == 0) { + // For legacy profiles without memory access breakdown, consider all memory + // access as HBM access. 
+ hbm_bytes = metrics.bytes_accessed(); + } + record->set_hbm_bw(tsl::profiler::GibibytesPerSecond( + hbm_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); + record->set_cmem_read_bw(tsl::profiler::GibibytesPerSecond( + cmem_read_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); + record->set_cmem_write_bw(tsl::profiler::GibibytesPerSecond( + cmem_write_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); + record->set_vmem_read_bw(tsl::profiler::GibibytesPerSecond( + vmem_read_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); + record->set_vmem_write_bw(tsl::profiler::GibibytesPerSecond( + vmem_write_bytes, tsl::profiler::PicoToNano(metrics.time_ps()))); + record->set_hbm_operational_intensity( + tsl::profiler::SafeDivide(metrics.flops(), hbm_bytes)); + record->set_cmem_read_operational_intensity( + tsl::profiler::SafeDivide(metrics.flops(), cmem_read_bytes)); + record->set_cmem_write_operational_intensity( + tsl::profiler::SafeDivide(metrics.flops(), cmem_write_bytes)); + record->set_vmem_read_operational_intensity( + tsl::profiler::SafeDivide(metrics.flops(), vmem_read_bytes)); + record->set_vmem_write_operational_intensity( + tsl::profiler::SafeDivide(metrics.flops(), vmem_write_bytes)); + // Resources considered for roofline analysis. + constexpr absl::string_view kUnknown = "Unknown"; + constexpr absl::string_view kCompute = "Compute"; + constexpr absl::string_view kHbm = "HBM"; + constexpr absl::string_view kCmemRead = "CMEM Read"; + constexpr absl::string_view kCmemWrite = "CMEM Write"; + constexpr absl::string_view kVmemRead = "VMEM Read"; + constexpr absl::string_view kVmemWrite = "VMEM Write"; + constexpr absl::string_view kShmL1 = "Shm/L1"; + // Compute the bound time assuming the peak capacity of each resource and + // choose the highest one as the bottleneck. See go/xprof-roofline-pxc for + // more details. + // NOTE: The roofline analysis result is the same for Megacore because every + // resource's capacity is doubled for Megacore so the comparison result is the + // same. 
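  // The selection below computes, for compute and for each memory resource, a
  // utilization ratio (achieved rate / peak rate) and declares the resource
  // with the highest utilization to be the bottleneck, recording its
  // operational intensity alongside it.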
+ absl::string_view bottleneck_resource = kUnknown; + double bottleneck_utilization = 0; + double bottleneck_operational_intensity = 0; + double peak_flops = + tsl::profiler::TeraToGiga(perf_env.peak_tera_flops_per_second()); + double flops_utilization = + SafeDivide(record->measured_flop_rate(), peak_flops); + if (bottleneck_utilization < flops_utilization) { + bottleneck_resource = kCompute; + bottleneck_utilization = flops_utilization; + bottleneck_operational_intensity = record->operational_intensity(); + } + double peak_hbm_bw = GetMemoryPeakBandwidth(perf_env, 0); + double hbm_bw_utilization = + SafeDivide(record->hbm_bw(), tsl::profiler::GigaToGibi(peak_hbm_bw)); + if (bottleneck_utilization < hbm_bw_utilization) { + bottleneck_resource = kHbm; + bottleneck_utilization = hbm_bw_utilization; + bottleneck_operational_intensity = record->hbm_operational_intensity(); + } + tensorflow::profiler::HardwareType hardware_type = run_env.hardware_type(); + if (hardware_type == tensorflow::profiler::HardwareType::TPU) { + if (cmem_read_bytes) { + double peak_cmem_read_bw = GetMemoryPeakBandwidth(perf_env, 3); + if (peak_cmem_read_bw) { + double cmem_read_bw_utilization = + SafeDivide(record->cmem_read_bw(), + tsl::profiler::GigaToGibi(peak_cmem_read_bw)); + if (bottleneck_utilization < cmem_read_bw_utilization) { + bottleneck_resource = kCmemRead; + bottleneck_utilization = cmem_read_bw_utilization; + bottleneck_operational_intensity = + record->cmem_read_operational_intensity(); + } + } + } + if (cmem_write_bytes) { + double peak_cmem_write_bw = GetMemoryPeakBandwidth(perf_env, 4); + if (peak_cmem_write_bw) { + double cmem_write_bw_utilization = + SafeDivide(record->cmem_write_bw(), + tsl::profiler::GigaToGibi(peak_cmem_write_bw)); + if (bottleneck_utilization < cmem_write_bw_utilization) { + bottleneck_resource = kCmemWrite; + bottleneck_utilization = cmem_write_bw_utilization; + bottleneck_operational_intensity = + record->cmem_write_operational_intensity(); + } + } + } + if (vmem_read_bytes) { + double peak_vmem_read_bw = GetMemoryPeakBandwidth(perf_env, 5); + if (peak_vmem_read_bw) { + double vmem_read_bw_utilization = + SafeDivide(record->vmem_read_bw(), + tsl::profiler::GigaToGibi(peak_vmem_read_bw)); + if (bottleneck_utilization < vmem_read_bw_utilization) { + bottleneck_resource = kVmemRead; + bottleneck_utilization = vmem_read_bw_utilization; + bottleneck_operational_intensity = + record->vmem_read_operational_intensity(); + } + } + } + if (vmem_write_bytes) { + double peak_vmem_write_bw = GetMemoryPeakBandwidth(perf_env, 6); + if (peak_vmem_write_bw) { + double vmem_write_bw_utilization = + SafeDivide(record->vmem_write_bw(), + tsl::profiler::GigaToGibi(peak_vmem_write_bw)); + if (bottleneck_utilization < vmem_write_bw_utilization) { + bottleneck_resource = kVmemWrite; + bottleneck_utilization = vmem_write_bw_utilization; + bottleneck_operational_intensity = + record->vmem_write_operational_intensity(); + } + } + } + } + if (hardware_type == tensorflow::profiler::HardwareType::GPU) { + double peak_shm_l1_bw = GetMemoryPeakBandwidth(perf_env, 2); + if (peak_shm_l1_bw) { + // Currently, we only have general read/write bandwidth in record. 
+ double shm_l1_bw_utilization = SafeDivide( + record->hbm_bw(), tsl::profiler::GigaToGibi(peak_shm_l1_bw)); + if (bottleneck_utilization < shm_l1_bw_utilization) { + bottleneck_resource = kShmL1; + bottleneck_utilization = shm_l1_bw_utilization; + bottleneck_operational_intensity = record->hbm_operational_intensity(); + } + } + } + record->set_bound_by(std::string(bottleneck_resource)); + record->set_bottleneck_operational_intensity( + bottleneck_operational_intensity); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_METRICS_TO_RECORD_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_profile_builder.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_profile_builder.h new file mode 100644 index 00000000..3d4e7abd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_profile_builder.h @@ -0,0 +1,157 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" + +namespace tensorflow { +namespace profiler { + +struct OpProfileOptions { + bool group_by_program = true; + bool group_by_deduplicated_name = true; + int children_per_node = 100; +}; + +// The structure of an op profile tree may looks like below: +// 1. group "by_program" +// - It starts from the root node, named as "by_program", and this node does +// not show up in op profile. +// - The children of root node is a list of hlo program node, named as the +// program/module name (eg. cluster.xx). +// - The children of a program node is hlo op category node, named as the +// category name (eg. data formatting). +// - The children of a category node is a list of op node or deduplicated +// group node: +// - For op that has duplicates, the child will be a deduplicated node, +// named like "copy.1111 and its deduplicate(s)". Its children will be all op +// nodes that are deduplicated. +// - For op that does not have duplicates, the child will be an op node +// under the op category (eg. copy.2222). +// +// Example path: "by_program" -> "main(...)" +// -> "data_formatting" -> "copy.12345 and its duplicate(s) -> "copy.12345" +// +// 2. group "by_category" +// Similarly to how the `by_program` op profile tree is constructed, +// `by_category` just removed the "program_node" layer: +// - It starts from the root node, named as "by_category", this node also does +// not show up in op profile. +// - The children of root node is a list of op category node, everything below +// is similar to above. 
+// - ... +// +// Example path: "by_category" -> "data_formatting" -> "copy.12345 and its +// duplicate(s) -> "copy.12345" +// +// How the op profile metrics are calculated: +// 1. For parent node in the nested structure like root node and program node: +// - time_ps will be accumulated from the self_time of all op nodes under it +// (might still be off a bit if the parent node has self_time, more details in +// b/333608397#comment5) +// - flops and memory access will only be accumulated from leaf op node under +// it to avoid double counting +// - unable to get occurrences of program executions now +// 2. For conceptual horizontal grouping node (eg.category, deduplicated) +// - all op_metris fields will be accumulated from leaf op node only in the +// group, to avoid double counting +class OpProfileBuilder { + public: + OpProfileBuilder(const OpProfileOptions& options, op_profile::Node* root, + const tensorflow::protobuf::Map* + program_name_map = nullptr); + + // Accumulate the op_metrics to the op_profile node tree + void AddOp(const OpMetrics& op_metrics); + + // Finalize the op_profile proto in a few steps (inter-dependent): + // 1. Reset time_ps for root node for more precise total time + // 2. Loop over the node to op_metrics map, populate corresponding op_metrics + // to the node.metrics + // 3. `SortAndPruneChildren` given query param `op_profile_limit` + // 4. `FinalizeDeduplicatedNodes` by coping the first op node data to the + // deduplicated node + void Finalize(double peak_gigaflops_per_second_per_core, + std::vector peak_mem_gibibytes_per_second_per_core, + uint64_t total_time_ps); + + private: + struct Category { + op_profile::Node* node; + absl::flat_hash_map deduplicated_nodes; + }; + + struct Program { + op_profile::Node* node; + absl::flat_hash_map categories; + }; + + std::string GenerateProgramName(uint64_t program_id) const; + + // Adds and returns a node for op_metrics. + // If op_metrics corresponds to a fusion, adds children to the node for the + // fused instructions. + // If deduplicated_node is not null, adds the node under it. + // Otherwise, if category is not null, adds the node under category. + // Otherwise, adds the node under root. + op_profile::Node* AddOpNode(const OpMetrics& op_metrics, + Category* category = nullptr, + op_profile::Node* deduplicated_node = nullptr); + + // Returns a node for op_metrics.deduplicated_name(). + // Adds a node to the tree if necessary. + op_profile::Node* LookupOrAddDeduplicatedNode(const OpMetrics& op_metrics, + Category* category); + + // Returns a node for op_metrics.category(). + // Adds a node to the tree if necessary. + // If program is not null, the category node is added under program. + // Otherwise, the category node is added under root. + Category* LookupOrAddCategoryNode(const OpMetrics& op_metrics, + Program* program); + + // Returns a node for op_metrics.hlo_module_id(). + // Adds a node to the Node tree if necessary. + Program* LookupOrAddProgramNode(const OpMetrics& op_metrics); + + OpProfileOptions options_; + op_profile::Node* root_; + + // Map to look up and aggregate OpMetrics. + absl::node_hash_map metrics_; + + // Maps to look up if a category / program / deduplicated node has + // already been added to the tree. + absl::flat_hash_map programs_map_; + absl::flat_hash_map category_map_; + + // Map to look up program names by id. 
+ const tensorflow::protobuf::Map* program_name_map_ = + nullptr; +}; +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_PROFILE_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stack.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stack.h new file mode 100644 index 00000000..6bfa4d77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stack.h @@ -0,0 +1,69 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace profiler { + +template +class OpStack { + public: + // Pushes an Op onto the stack. + void Push(uint32 op_id, std::unique_ptr op_info) { + stack_.emplace_back(op_id, std::move(op_info)); + } + + // Pops the Op with the given op_id from the stack. + std::unique_ptr Pop(uint32 op_id) { + // Pop until match or stack_ is empty. + std::unique_ptr result; + while (!stack_.empty()) { + auto back = std::move(stack_.back()); + stack_.pop_back(); + if (op_id == back.first) { + result = std::move(back.second); + break; + } + } + return result; + } + + // Returns the Op at the top of the stack. + OpInfo* Top() const { + return stack_.empty() ? nullptr : stack_.back().second.get(); + } + + // Returns true if the stack is empty. + bool Empty() const { return stack_.empty(); } + + // Clears the stack. + void Clear() { stack_.clear(); } + + private: + std::vector>> stack_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STACK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_combiner.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_combiner.h new file mode 100644 index 00000000..a8cb3c62 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_combiner.h @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/utils/step_intersection.h" + +namespace tensorflow { +namespace profiler { + +// Whether a host is a coordinator. +bool IsCoordinator(bool no_accelerator_in_system, HardwareType hardware_type); + +// Translates the core id from single host to the one for multiple-host. +// We need this translation because the device_ordinal was assigned when a +// single host response was given. Now, we need a global core_id to distinguish +// it with multiple hosts. +uint32 GlobalCoreId(int host_id, uint32 device_ordinal); + +// Combines the src map into the dst map. +// The src map keys are local core_ids. The src_host_id is used to convert them +// into global core_ids used as keys in the dst map. +// REQUIRED: cores from src_host_id are not already in dst. +template +void CombineCoreIdMap(int src_host_id, const CoreIdMap& src, CoreIdMap* dst) { + for (const auto& core_id_and_value : src) { + uint32 global_core_id = GlobalCoreId(src_host_id, core_id_and_value.first); + auto iter_and_inserted = + dst->insert({global_core_id, core_id_and_value.second}); + DCHECK(iter_and_inserted.second) + << "Duplicated core_id: " << iter_and_inserted.first->first; + } +} + +// A struct that contains all the information that is needed to combine OpStats. +struct OpStatsInfo { + OpStatsInfo(const OpStats* op_stats, HardwareType hardware_type, + int src_host_id) + : op_stats(op_stats), + hardware_type(hardware_type), + src_host_id(src_host_id) {} + const OpStats* op_stats; + HardwareType hardware_type; + int src_host_id; +}; + +// Returns true if there is no device (accelerator) in any of the hosts. +bool NoAcceleratorInSystem(const std::vector& all_op_stats_info); + +// Compute the StepIntersection to merge OpStats. +// Profiler will limit the number of steps to be at most . +StepIntersection ComputeStepIntersectionToMergeOpStats( + const std::vector& all_op_stats_info, + uint32 max_step_per_host); + +// Combine all the OpStats in using the steps in range +// . The result is stored in . +void CombineAllOpStats(const std::vector& all_op_stats_info, + const StepIntersection& step_intersection, + OpStats* combined_op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_COMBINER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h new file mode 100644 index 00000000..1037ef19 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_hlo_stats.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_ + +#include "tensorflow/core/profiler/protobuf/hlo_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" + +namespace tensorflow { +namespace profiler { +tensorflow::profiler::hlo_stats::HloStatsDatabase ConvertOpStatsToHloStats( + const tensorflow::profiler::OpStats& op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_HLO_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h new file mode 100644 index 00000000..c9de162e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h @@ -0,0 +1,90 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_ + +#include + +#include "google/protobuf/any.pb.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/util/stats_calculator.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" +#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" + +namespace tensorflow { +namespace profiler { + +StepSummary GetStepSummaryForSampleStats(const tsl::Stat& sample_stats); + +// If the percent of input-time spent on host-to-device transfer is greater than +// kHostToDeviceTimePercentAsSignificant, we should advise the +// user to optimize this transfer. +constexpr double kHostToDeviceTimePercentAsSignificant = 10.0; + +// If the percent of input-time spent on host-to-device transfer is greater than +// kHostToDeviceTimePercentAsDominant, we should ONLY advise the +// user to optimize this transfer; we won't bother to suggest optimization for +// tf.data. +constexpr double kHostToDeviceTimePercentAsDominant = 90.0; + +// Computes the summary of step time in milliseconds. 
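// Illustrative sketch (not part of the upstream header; the function name and
// the string labels are invented for illustration) of how the two thresholds
// above are typically combined when classifying the input pipeline. The real
// advice generation lives in the corresponding .cc file and is more nuanced.
inline const char* ClassifyHostToDeviceTransfer(double host_to_device_percent) {
  if (host_to_device_percent > kHostToDeviceTimePercentAsDominant) {
    // Transfer dominates the input time: only advise optimizing the transfer.
    return "optimize host-to-device transfer";
  }
  if (host_to_device_percent > kHostToDeviceTimePercentAsSignificant) {
    // Transfer is significant but not dominant: advise it alongside tf.data.
    return "optimize host-to-device transfer and tf.data";
  }
  return "focus on tf.data / other input-pipeline work";
}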
+StepSummary ComputeStepTimeSummaryInMs( + const ::tensorflow::protobuf::RepeatedPtrField& + grouped_by_step); + +void GenerateHostResult(const OpMetricsDb& host_tf_metrics_db, + InputPipelineAnalysisResult* result); + +InputPipelineAnalysisRecommendation GenerateRecommendation(); + +// Returns the performance bottleneck of the program executed. +BottleneckAnalysis ComputeBottleneckAnalysis( + const InputTimeBreakdown& input_time_breakdown, + const ::tensorflow::protobuf::RepeatedPtrField<::google::protobuf::Any>& + any_step_details); + +InputPipelineAnalysisResult ConvertOpStatsToInputPipelineAnalysis( + const OpStats& op_stats); + +// Returns true if explanation for "All Others" time is also included in +// input_statement. +bool InputAnalysis(double input_percent, double all_other_percent, + std::string* input_classification, + std::string* input_statement); + +void OutputAnalysis(double output_percent, std::string* output_classification, + std::string* output_statement); + +string GetSummaryNextStep(absl::string_view input_classification, + const InputTimeBreakdown& breakdown); + +// Returns the percentage of the input time that is spent on transferring the +// data from host to device. +double HostToDeviceTransferAsPercentOfInputTime( + const InputTimeBreakdown& breakdown); + +void AddErrorMessages(const OpStats& op_stats, + InputPipelineAnalysisResult* result); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_INPUT_PIPELINE_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_op_profile.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_op_profile.h new file mode 100644 index 00000000..1fcfefb5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_op_profile.h @@ -0,0 +1,56 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_ + +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" +#include "tensorflow/core/profiler/protobuf/op_profile.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" + +namespace tensorflow { +namespace profiler { + +// Assembles a hierarchical performance profile based on HLOs in the op metrics +// db. +// The node hierarchy is as following: +// by_category +// - combined_root +// - category 1 +// - category 2 +// - ... +// - idle +// by_program +// - program_1_root +// - category 1 +// - category 2 +// - ... +// - program_2_root +// - category 1 +// - ... +// - idle +// The nodes in the profile are sorted by time in decreasing order and pruned +// to reduce the profile size. Only 100 nodes are kept for level >= 3. +// See op_profile.proto for the detailed semantics of the returned profile. 
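// (The "100" above matches the default of the op_profile_limit parameter
// declared below; callers can raise or lower the pruning limit through it.)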
+void ConvertOpStatsToOpProfile( + const tensorflow::profiler::OpStats& op_stats, + tensorflow::profiler::HardwareType hardware_type, + tensorflow::profiler::op_profile::Profile& profile, + int op_profile_limit = 100); + +} // namespace profiler +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OP_PROFILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_overview_page.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_overview_page.h new file mode 100644 index 00000000..2911e956 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_overview_page.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" +#include "tensorflow/core/profiler/protobuf/input_pipeline.pb.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/overview_page.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +// Reports tf-function optimization opportunity in the Overview Page if the +// expensive-call-time percentage is over this threshold for at least one of +// the tf-functions profiled. +const double kTfFunctionReportThresholdInPercent = 20; + +// Reports eager-mode optimization opportunity in the Overview Page if the +// percent of Op time on host (or device) that is spent on eager mode is over +// this threshold. +const double kEagerReportThresholdInPercent = 10; + +// Reports outside-compilation opportunity in the Overview Page if the +// percent of Op time on device that is for outside compilation is over +// this threshold. 
+const double kOutsideCompilationThresholdInPercent = 5; + +void SetCommonRecommendation( + absl::string_view input_classification, absl::string_view input_statement, + absl::string_view output_statement, HardwareType hardware_type, + absl::string_view tf_function_statement_html, + absl::string_view eager_statement_html, + absl::string_view outside_compilation_statement_html, + OverviewPageRecommendation* re); + +OverviewPageRecommendation ComputeGenericRecommendation( + const BottleneckAnalysis& bottleneck, + const PrecisionStats& precision_stats); + +OverviewPageAnalysis ComputeAnalysisResult(const OpStats& op_stats); + +OverviewPageRunEnvironment ComputeRunEnvironment( + const RunEnvironment& run_environment); + +OverviewPage ConvertOpStatsToOverviewPage(const OpStats& op_stats); + +// Returns a html which provides tf-function related recommendation. +std::string TfFunctionRecommendationHtml(const TfFunctionDb& tf_function_db); + +// Returns a html which provides eager-mode related recommendation. +std::string EagerRecommendationHtml(double host_op_time_eager_percent, + double device_op_time_eager_percent); + +// Returns a html which provides outside-compilation related recommendation. +std::string OutsideCompilationRecommendationHtml( + double device_op_time_outside_compilation_percent); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_OVERVIEW_PAGE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h new file mode 100644 index 00000000..bd3d7406 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_pod_stats.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ + +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/pod_stats.pb.h" + +namespace tensorflow { +namespace profiler { + +PodStatsDatabase ConvertOpStatsToPodStats(const OpStats& op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h new file mode 100644 index 00000000..c45c9939 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ + +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/pod_viewer.pb.h" + +namespace tensorflow { +namespace profiler { + +PodViewerDatabase ConvertOpStatsToPodViewer(const OpStats& op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_POD_VIEWER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h new file mode 100644 index 00000000..d745b96f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_roofline_model.h @@ -0,0 +1,98 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
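Aside (not part of the patch): the pod-stats and pod-viewer converters above are single-call entry points that fan out from one OpStats proto. A hedged sketch of that fan-out, assuming the caller already has an OpStats in hand:

#include "tensorflow/core/profiler/convert/op_stats_to_pod_stats.h"
#include "tensorflow/core/profiler/convert/op_stats_to_pod_viewer.h"

namespace tensorflow {
namespace profiler {

// Hypothetical driver: each converter returns the proto consumed by the
// matching frontend tool.
void ConvertForPodTools(const OpStats& op_stats) {
  PodStatsDatabase pod_stats = ConvertOpStatsToPodStats(op_stats);
  PodViewerDatabase pod_viewer = ConvertOpStatsToPodViewer(op_stats);
  (void)pod_stats;
  (void)pod_viewer;
}

}  // namespace profiler
}  // namespace tensorflow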
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_ + +#include + +#include "tsl/platform/protobuf.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/roofline_model.pb.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" + +namespace tensorflow { +namespace profiler { + +using tensorflow::profiler::OpMetrics; +using tensorflow::profiler::roofline_model::RecordType; +using tensorflow::profiler::roofline_model::RooflineModelDatabase; +using tensorflow::profiler::roofline_model::RooflineModelRecord; + +RooflineModelRecord ConvertOpMetricsToRooflineModelRecord( + const OpStats& op_stats, const OpMetrics& metrics, RecordType record_type, + uint32_t step_num, uint64_t total_time_ps, + const RooflineModelDatabase& roofline_model_db, + bool include_infeed_outfeed); + +RooflineModelRecord GenerateRooflineModelProgramRecord( + const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type, + uint32_t step_num, const RooflineModelDatabase& roofline_model_db, + bool include_infeed_outfeed); + +tsl::protobuf::RepeatedPtrField +ConvertOpMetricsDbToRooflineModelRecords( + const OpStats& op_stats, const OpMetricsDb& db, RecordType record_type, + uint32_t step_num, const RooflineModelDatabase& roofline_model_db, + bool include_infeed_outfeed); + +tensorflow::profiler::roofline_model::RooflineModelDatabase +ConvertOpStatsToRooflineModel(const tensorflow::profiler::OpStats& tf_op_stats, + bool include_infeed_outfeed); + +tensorflow::profiler::roofline_model::RooflineModelDatabase +InitializeRooflineModelDatabaseFromOpStats(const OpStats& op_stats, + bool include_infeed_outfeed); +// Generate RooflineModelRecord for the HLO DB over the entire profiling +// duration including incomplete steps. +inline void AddRooflineModelRecordForProfileDuration( + const OpStats& op_stats, RooflineModelDatabase& roofline_model_db, + bool include_infeed_outfeed) { + *roofline_model_db.mutable_roofline_model_record() = + ConvertOpMetricsDbToRooflineModelRecords( + op_stats, op_stats.device_op_metrics_db(), RecordType::ALL, + /*step_num=*/0, roofline_model_db, include_infeed_outfeed); +} + +// Generate RooflineModelRecord for the HLO DB over complete steps only. +inline void AddRooflineModelRecordsForCompleteSteps( + const OpStats& op_stats, RooflineModelDatabase& roofline_model_db, + bool include_infeed_outfeed) { + if (op_stats.has_hlo_metrics_db_complete_steps_only()) { + *roofline_model_db.add_roofline_model_record() = + GenerateRooflineModelProgramRecord( + op_stats, op_stats.hlo_metrics_db_complete_steps_only(), + RecordType::AVERAGE_STEP, /*step_num=*/0, roofline_model_db, + include_infeed_outfeed); + } +} + +// Generate RooflineModelRecords for the per-step DBs. 
+inline void AddRooflineModelRecordsPerStep( + const OpStats& op_stats, RooflineModelDatabase& roofline_model_db, + bool include_infeed_outfeed) { + for (const auto& step_info : op_stats.step_db().step_sequence()) { + *roofline_model_db.add_roofline_model_record() = + GenerateRooflineModelProgramRecord( + op_stats, step_info.hlo_metrics_db(), RecordType::PER_STEP, + step_info.step_num(), roofline_model_db, include_infeed_outfeed); + } +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_ROOFLINE_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h new file mode 100644 index 00000000..3b8a06ef --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/op_stats_to_tf_stats.h @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ + +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/tf_stats.pb.h" + +namespace tensorflow { +namespace profiler { + +TfStatsDatabase ConvertOpStatsToTfStats(const OpStats& op_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_OP_STATS_TO_TF_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h new file mode 100644 index 00000000..4c86ed87 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/preprocess_single_host_xplane.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ + +#include "xla/tsl/profiler/utils/group_events.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +// Preprocess XSpaces before tools conversion. +// If step_grouping = true, perform events grouping for step tracking. 
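Aside (not part of the patch): the three inline helpers in op_stats_to_roofline_model.h compose into a full roofline database. The driver below is a hypothetical illustration of that composition, roughly what ConvertOpStatsToRooflineModel is expected to assemble; its actual implementation is not vendored here.

#include "tensorflow/core/profiler/convert/op_stats_to_roofline_model.h"

namespace tensorflow {
namespace profiler {

// Hypothetical: initialize the database from OpStats, then append records for
// the whole profiling duration, for complete steps only, and per step.
RooflineModelDatabase BuildRooflineModelForIllustration(
    const OpStats& op_stats, bool include_infeed_outfeed) {
  RooflineModelDatabase db = InitializeRooflineModelDatabaseFromOpStats(
      op_stats, include_infeed_outfeed);
  AddRooflineModelRecordForProfileDuration(op_stats, db,
                                           include_infeed_outfeed);
  AddRooflineModelRecordsForCompleteSteps(op_stats, db, include_infeed_outfeed);
  AddRooflineModelRecordsPerStep(op_stats, db, include_infeed_outfeed);
  return db;
}

}  // namespace profiler
}  // namespace tensorflow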
+// If derived_timeline, generate derived timeline (XLines). +// If group_metadata_map is not nullptr, populate the group metadata map. +void PreprocessSingleHostXSpace( + XSpace* space, bool step_grouping, bool derived_timeline, + tsl::profiler::GroupMetadataMap* group_metadata_map = nullptr); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PREPROCESS_SINGLE_HOST_XPLANE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/process_megascale_dcn.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/process_megascale_dcn.h new file mode 100644 index 00000000..794c2bea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/process_megascale_dcn.h @@ -0,0 +1,29 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ + +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +// Process Dcn Megascale TraceMe info. +void ProcessMegascaleDcn(XSpace* space); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PROCESS_MEGASCALE_DCN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/profile_time_breakdown.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/profile_time_breakdown.h new file mode 100644 index 00000000..1e3379be --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/profile_time_breakdown.h @@ -0,0 +1,244 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" + +namespace tensorflow { +namespace profiler { + +// Allows accumulating time spent in different HLO instruction categories to +// breakdown the total profile time and compute metrics of interest. 
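Aside (not part of the patch): a minimal sketch of calling PreprocessSingleHostXSpace from preprocess_single_host_xplane.h above, assuming the XSpace was loaded elsewhere. The flag values and the use of the group metadata map are assumptions about a typical call site, not vendored code.

#include "xla/tsl/profiler/utils/group_events.h"
#include "tensorflow/core/profiler/convert/preprocess_single_host_xplane.h"

namespace tensorflow {
namespace profiler {

// Hypothetical helper: group events for step tracking and build derived
// timelines before handing the XSpace to any tool converter.
void PreprocessForToolsSketch(XSpace* space) {
  tsl::profiler::GroupMetadataMap group_metadata_map;
  PreprocessSingleHostXSpace(space, /*step_grouping=*/true,
                             /*derived_timeline=*/true, &group_metadata_map);
  // group_metadata_map is now populated with per-group metadata that
  // downstream converters can consult.
}

}  // namespace profiler
}  // namespace tensorflow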
+class ProfileTimeBreakdown { + public: + // Category should be the operator category disambiguated by xprof instead of + // the original category from XLA. + // For a correct time breakdown, we need to use the self time of operators, + // instead of total time to avoid double counting. Note that for leaf ops, + // self time and total time are the same. + void IncrementCategoryTimePs(absl::string_view category, + uint64_t self_time_ps) { + time_ps_by_category_[category] += self_time_ps; + total_time_ps_ += self_time_ps; + } + + // Profile time cannot be smaller than the total time in all categories. + // If combining profiles across multiple cores, profile time should be the + // profiling duration multiplied by the number of cores that were profiled. + // go/autograppler_profile_time + void SetProfileTimePs(uint64_t profile_time_ps) { + DCHECK_LE(total_time_ps_, profile_time_ps); + profile_time_ps_ = profile_time_ps; + } + + // Breaks down "sparsecorev0 infeed" into two components: + // 1) "sparsecorev0 infeed wait": Time spent waiting on the SparseCoreV0. + // 2) "sparsecorev0 infeed transform": Time spent transforming activations in + // SparseCoreV0 layout into XLA layout. + // Even though 2) is part of the overall embedding computation, it is time + // spent doing work on the TensorCore. + void BreakdownSparseCoreV0Infeed(); + + // Duty cycle is the fraction of time an accelerator is being actively used. + // go/accelerator-metrics-definitions#common-accelerator-metrics + // go/ag-tpu-duty-cycle + double DutyCycle() const { return TimeFraction(OnDutyTimePs()); } + + double IdleFraction() const { return TimeFraction(IdleTimePs()); } + + double InfeedFraction() const { + return CategoryFraction(tsl::profiler::kHloInfeed); + } + + double OutfeedFraction() const { + return CategoryFraction(tsl::profiler::kHloOutfeed); + } + + double SparseCoreV0InfeedFraction() const { + return CategoriesFraction({tsl::profiler::kHloSparseCoreV0Infeed, + tsl::profiler::kHloSparseCoreV0InfeedWait, + tsl::profiler::kHloSparseCoreV0InfeedTransform}); + } + + double SparseCoreV0OutfeedFraction() const { + return CategoryFraction(tsl::profiler::kHloSparseCoreV0Outfeed); + } + + double AllReduceFraction() const { + return CategoryFraction(tsl::profiler::kHloAllReduce); + } + + double AllReduceFusionFraction() const { + return CategoryFraction(tsl::profiler::kHloAllReduceFusion); + } + + double SendRecvFraction() const { + return CategoriesFraction( + {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone, + tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); + } + + double HostSendRecvFraction() const { + return CategoriesFraction( + {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, + tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); + } + + double CategoriesFraction( + const std::initializer_list& categories) const { + return TimeFraction(CategoriesTimePs(categories)); + } + + double CategoryFraction(absl::string_view category) const { + return TimeFraction(CategoryTimePs(category)); + } + + uint64_t ProfileTimePs() const { return profile_time_ps_; } + + uint64_t TotalTimePs() const { return total_time_ps_; } + + uint64_t IdleTimePs() const { return profile_time_ps_ - total_time_ps_; } + + uint64_t OnDutyTimePs() const { return profile_time_ps_ - OffDutyTimePs(); } + + uint64_t OffDutyTimePs() const { + return IdleTimePs() + + CategoriesTimePs( + {tsl::profiler::kHloInfeed, tsl::profiler::kHloOutfeed, + tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone, + 
tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone, + tsl::profiler::kHloMegacoreFusion}); + } + + uint64_t InfeedTimePs() const { + return CategoryTimePs(tsl::profiler::kHloInfeed); + } + + uint64_t OutfeedTimePs() const { + return CategoryTimePs(tsl::profiler::kHloOutfeed); + } + + uint64_t SparseCoreV0InfeedWaitTimePs() const { + return CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedWait); + } + + uint64_t SparseCoreV0InfeedTransformTimePs() const { + return CategoryTimePs(tsl::profiler::kHloSparseCoreV0InfeedTransform); + } + + uint64_t SparseCoreV0OutfeedTimePs() const { + return CategoryTimePs(tsl::profiler::kHloSparseCoreV0Outfeed); + } + + uint64_t AllReduceOrAllToAllTimePs() const { + return CategoriesTimePs({tsl::profiler::kHloAllReduce, + tsl::profiler::kHloAllReduceFusion, + tsl::profiler::kHloAllToAll}); + } + + uint64_t SendTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloSend, tsl::profiler::kHloSendDone}); + } + + uint64_t RecvTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloRecv, tsl::profiler::kHloRecvDone}); + } + + uint64_t HostSendTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloHostSend, tsl::profiler::kHloHostSendDone}); + } + + uint64_t HostRecvTimePs() const { + return CategoriesTimePs( + {tsl::profiler::kHloHostRecv, tsl::profiler::kHloHostRecvDone}); + } + + // Megacore fusion runs different operations on each core, e.g., a convolution + // on one core and an all-reduce on the other core. In a trace, megacore + // fusion is the parent operation, and its self time is the time that the core + // executing the faster operation waits for the core executing the slower + // operation to reach the synchronization point. + uint64_t MegacoreFusionTimePs() const { + return CategoryTimePs(tsl::profiler::kHloMegacoreFusion); + } + + uint64_t HighFlopsComputeTimePs() const { + return CategoriesTimePs({tsl::profiler::kHloConvolution, + tsl::profiler::kHloConvolutionBaseDilated, + tsl::profiler::kHloConvolutionWindowDilated, + tsl::profiler::kHloConvolutionFusion, + tsl::profiler::kHloOutputFusion}); + } + + // Calculated according to the "TC busy time" defined in go/tpu_kpis + uint64_t TensorCoreBusyTimePs() const { + return profile_time_ps_ - OffDutyTimePs() - SparseCoreV0InfeedWaitTimePs(); + } + + uint64_t CategoriesTimePs( + const std::initializer_list& categories) const { + uint64_t time_ps = 0; + for (auto category : categories) { + time_ps += CategoryTimePs(category); + } + return time_ps; + } + + uint64_t CategoryTimePs(absl::string_view category) const { + auto iter = time_ps_by_category_.find(category); + return (iter == time_ps_by_category_.end()) ? 0 : iter->second; + } + + template + void ComputeCategoryFractions(Map& category_fractions) { + for (const auto& [category, time_ps] : time_ps_by_category_) { + category_fractions[category] = TimeFraction(time_ps); + } + } + + std::string DebugString() const; + + private: + // Overwrites the time attributed to the given category. + void SetCategoryTimePs(absl::string_view category, uint64_t time_ps); + + // Removes and returns the time attributed to the given category. + uint64_t PopCategoryTimePs(absl::string_view category); + + double TimeFraction(uint64_t time_ps) const { + return tsl::profiler::SafeDivide(time_ps, profile_time_ps_); + } + + absl::flat_hash_map time_ps_by_category_; + uint64_t total_time_ps_ = 0; // Sum of values in time_ps_by_category_. 
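Aside (not part of the patch): a usage sketch for ProfileTimeBreakdown, assuming per-category self times have already been extracted from an OpMetricsDb. The category strings and numbers are made up for illustration.

#include "tensorflow/core/profiler/convert/profile_time_breakdown.h"

namespace tensorflow {
namespace profiler {

// Illustrative values only: accumulate self times, set the profile duration,
// then read back the derived fractions.
double ExampleDutyCycle() {
  ProfileTimeBreakdown breakdown;
  breakdown.IncrementCategoryTimePs("convolution", /*self_time_ps=*/600'000);
  // "infeed" matches tsl::profiler::kHloInfeed, so it counts as off-duty time.
  breakdown.IncrementCategoryTimePs("infeed", /*self_time_ps=*/100'000);
  // Profile time must be at least the sum of per-category self times.
  breakdown.SetProfileTimePs(/*profile_time_ps=*/1'000'000);
  // Idle = 1'000'000 - 700'000 = 300'000 ps; off-duty adds the infeed time,
  // so DutyCycle() = (1'000'000 - 400'000) / 1'000'000 = 0.6.
  return breakdown.DutyCycle();
}

}  // namespace profiler
}  // namespace tensorflow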
+ uint64_t profile_time_ps_ = 0; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_PROFILE_TIME_BREAKDOWN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/repository.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/repository.h new file mode 100644 index 00000000..af990aa5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/repository.h @@ -0,0 +1,200 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/utils/hlo_module_map.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +constexpr char kAllHostsIdentifier[] = "ALL_HOSTS"; +constexpr char kNoHostIdentifier[] = "NO_HOST"; + +enum StoredDataType { + DCN_COLLECTIVE_STATS, +}; + +static auto* kHostDataSuffixes = + new std::vector>( + {{StoredDataType::DCN_COLLECTIVE_STATS, ".dcn_collective_stats.pb"}}); + +// File system directory snapshot of a profile session. +class SessionSnapshot { + public: + // Performs validation and creates SessionSnapshot. + // are the file paths to XSpace protos. + // Optionally, can contain the XSpace protos pre-loaded by the + // profiler plugin. + static absl::StatusOr Create( + std::vector xspace_paths, + std::optional>> xspaces); + + // Returns the number of XSpaces in the profile session. + size_t XSpaceSize() const { return xspace_paths_.size(); } + + // Gets XSpace proto. + // The caller of this function will take ownership of the XSpace. + absl::StatusOr> GetXSpace(size_t index) const; + + // Gets XSpace proto. + // The caller of this function will take ownership of the XSpace. + absl::StatusOr> GetXSpaceByName( + absl::string_view name) const; + + // Gets host name. + std::string GetHostname(size_t index) const; + + // Gets the run directory of the profile session. + absl::string_view GetSessionRunDir() const { return session_run_dir_; } + + // Gets whether the session has an accessible run dir. If false, any + // path-based file read will be disabled in this mode. + bool HasAccessibleRunDir() const { return has_accessible_run_dir_; } + + // Gets the path of the fast file for a given tool. + std::optional GetFilePath(absl::string_view toolname, + absl::string_view host) const; + + // Gets the name of the host data file. 
+ absl::StatusOr GetHostDataFileName(StoredDataType data_type, + std::string host) const; + + // Gets the path of the host data file. + absl::StatusOr> GetHostDataFilePath( + StoredDataType data_type, std::string host) const; + + /* Gets whether the cache file is present in run dir. First value indicates + whether cache file is present or not. Second value indicates the path of cache + file. Possible cases are: + 1. : If no cache file is present + 2. : If cache file is present but file contains no data_type + events + 3. : If cache file is present and file contains data_type + events + */ + absl::StatusOr> HasCacheFile( + StoredDataType data_type) const; + + template + absl::Status WriteBinaryProto(const StoredDataType data_type, + const std::string host, T& proto) const { + // Gets name for host data file. + TF_ASSIGN_OR_RETURN(std::string filename, + GetHostDataFileName(data_type, host)); + + std::string filepath = + tsl::profiler::ProfilerJoinPath(GetSessionRunDir(), filename); + + return tensorflow::WriteBinaryProto(tsl::Env::Default(), filepath, proto); + } + + template + absl::Status ReadBinaryProto(const StoredDataType data_type, + const std::string host, T* proto) const { + // Gets file path for host data. + TF_ASSIGN_OR_RETURN(std::optional filepath, + GetHostDataFilePath(data_type, host)); + if (filepath) { + return tensorflow::ReadBinaryProto(tsl::Env::Default(), filepath.value(), + proto); + } + + return absl::NotFoundError( + absl::StrCat("No binary proto found for ", host, " and ", data_type)); + } + + private: + SessionSnapshot(std::vector xspace_paths, + std::optional>> xspaces) + : xspace_paths_(std::move(xspace_paths)), + // If the snapshot was initialized by xspaces, the file path and run dir + // is a path tensorflow can't read from or write to so any file IO + // encapsulated in this class will be disabled in this mode. + has_accessible_run_dir_(!xspaces.has_value()), + xspaces_(std::move(xspaces)) { + session_run_dir_ = tensorflow::io::Dirname(xspace_paths_.at(0)); + for (size_t i = 0; i < xspace_paths_.size(); ++i) { + std::string host_name = GetHostname(i); + hostname_map_[host_name] = i; + } + } + + // File paths to XSpace protos. + std::vector xspace_paths_; + // The run directory of the profile session. + absl::string_view session_run_dir_; + + absl::flat_hash_map + hostname_map_; + + const bool has_accessible_run_dir_; + + // XSpace protos pre-loaded by the profiler plugin. + // TODO(profiler): Use blobstore paths to initialize SessionSnapshot instead + // of using pre-loaded XSpaces. + mutable std::optional>> xspaces_; +}; + +// Writes binary proto format T for a host and data_type to a session. +template +absl::Status WriteBinaryProto(const SessionSnapshot& session_snapshot, + const StoredDataType data_type, + const std::string& host, T& proto) { + return session_snapshot.WriteBinaryProto(data_type, host, proto); +} + +// Reads binary proto format T for a host and data_type to a session. +template +absl::Status ReadBinaryProto(const SessionSnapshot& session_snapshot, + const StoredDataType data_type, + const std::string& host, T* proto) { + return session_snapshot.ReadBinaryProto(data_type, host, proto); +} + +// Process HloModuleMap from all XSpaces in a session. 
+inline absl::StatusOr ProcessHloModuleMap( + const SessionSnapshot& session_snapshot) { + HloModuleMap hlo_module_map; + for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { + TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, + session_snapshot.GetXSpace(i)); + ProcessHloModuleMapFromXSpace(hlo_module_map, xspace.get()); + } + return hlo_module_map; +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_REPOSITORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/step_events_to_steps_db.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/step_events_to_steps_db.h new file mode 100644 index 00000000..9764c46c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/step_events_to_steps_db.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ + +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" +#include "tensorflow/core/profiler/utils/event_span.h" + +namespace tensorflow { +namespace profiler { + +TF_CONST_INIT extern const uint32 kDefaultGpuLocalCoreId; + +// Converts from overlapped Step-Events to StepDatabaseResult. +StepDatabaseResult ConvertStepEventsToStepDb( + bool has_device, bool maybe_drop_incomplete_steps, + StepEvents& overlapped_step_events); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_STEP_EVENTS_TO_STEPS_DB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/tool_options.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/tool_options.h new file mode 100644 index 00000000..85f285e7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/tool_options.h @@ -0,0 +1,71 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
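Aside (not part of the patch): a hedged sketch of the cache round-trip enabled by the WriteBinaryProto/ReadBinaryProto wrappers in repository.h above, using the DCN_COLLECTIVE_STATS suffix from kHostDataSuffixes. The proto type is a placeholder; the concrete message cached by the tools is not part of this diff.

#include <string>

#include "absl/status/status.h"
#include "tensorflow/core/profiler/convert/repository.h"

namespace tensorflow {
namespace profiler {

// StatsProto stands in for whatever proto a tool caches under
// StoredDataType::DCN_COLLECTIVE_STATS.
template <typename StatsProto>
absl::Status CacheRoundTripSketch(const SessionSnapshot& session_snapshot,
                                  const std::string& host, StatsProto& stats) {
  // Writes "<host>.dcn_collective_stats.pb" into the session run directory.
  absl::Status status = WriteBinaryProto(
      session_snapshot, StoredDataType::DCN_COLLECTIVE_STATS, host, stats);
  if (!status.ok()) return status;
  // Reads it back; a NotFoundError is returned if the file is absent.
  StatsProto reread;
  return ReadBinaryProto(session_snapshot, StoredDataType::DCN_COLLECTIVE_STATS,
                         host, &reread);
}

}  // namespace profiler
}  // namespace tensorflow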
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_format.h" + +namespace tensorflow { +namespace profiler { + +using ToolOptions = + absl::flat_hash_map>; + +// Helper function to get parameter from tool options. +template +std::optional GetParam(const ToolOptions& options, const std::string& key) { + const auto iter = options.find(key); + if (iter == options.end()) { + return std::nullopt; + } + + const T* result = std::get_if(&iter->second); + if (!result) { + return std::nullopt; + } + return *result; +} + +// Helper function to get parameter from tool options with default value. +template +T GetParamWithDefault(const ToolOptions& options, const std::string& key, + const T& default_param) { + if (auto param = GetParam(options, key)) { + return *param; + } + return default_param; +} + +inline std::string DebugString(const ToolOptions& options) { + std::string output; + for (const auto& [k, v] : options) { + absl::StrAppend( + &output, k, ":", + std::visit([](const auto& value) { return absl::StrCat(value); }, v), + ":", v.index(), ";"); + } + return absl::StrCat("{", output, "}"); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TOOL_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h new file mode 100644 index 00000000..352a2b77 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/tpu_input_pipeline_analysis_constants.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ + +#include "absl/strings/string_view.h" +#include "tsl/platform/macros.h" + +namespace tensorflow { +namespace profiler { + +TF_CONST_INIT extern const absl::string_view kProfileAllHostsDoc; +TF_CONST_INIT extern const absl::string_view kSparseCoreV0Name; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TPU_INPUT_PIPELINE_ANALYSIS_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_event_arguments_builder.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_event_arguments_builder.h new file mode 100644 index 00000000..73a0f81e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_event_arguments_builder.h @@ -0,0 +1,64 @@ +/* Copyright 2023 The TensorFlow Authors. 
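Aside (not part of the patch): a usage sketch for ToolOptions, GetParam, and GetParamWithDefault from tool_options.h above. The variant's alternatives were stripped in this vendored copy, so the example assumes it admits at least int and std::string; the option keys are placeholders, not names guaranteed by any tool.

#include <optional>
#include <string>

#include "tensorflow/core/profiler/convert/tool_options.h"

namespace tensorflow {
namespace profiler {

// Hypothetical keys "host" and "module_id", chosen only for illustration.
void ToolOptionsExample() {
  ToolOptions options;
  options["host"] = std::string("worker-0");
  options["module_id"] = 42;

  // std::nullopt if the key is missing or stores a different alternative.
  std::optional<std::string> host = GetParam<std::string>(options, "host");

  // Falls back to the default when the key is absent.
  int module_id =
      GetParamWithDefault<int>(options, "module_id", /*default_param=*/0);
  (void)host;
  (void)module_id;
}

}  // namespace profiler
}  // namespace tensorflow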
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENT_ARGUMENTS_BUILDER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENT_ARGUMENTS_BUILDER_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" + +namespace tensorflow { +namespace profiler { + +// Helper class for adding arguments to TraceEventsArguments. +class TraceEventArgumentsBuilder { + public: + explicit TraceEventArgumentsBuilder(TraceEventArguments* args) + : args_(args) {} + + void Append(absl::string_view key, absl::string_view value) { + auto* arg = args_->add_arg(); + arg->set_name(key.data(), key.size()); + arg->set_str_value(value.data(), value.size()); + } + + void Append(absl::string_view key, int64_t value) { + auto* arg = args_->add_arg(); + arg->set_name(key.data(), key.size()); + arg->set_int_value(value); + } + + void Append(absl::string_view key, uint64_t value) { + auto* arg = args_->add_arg(); + arg->set_name(key.data(), key.size()); + arg->set_uint_value(value); + } + + void Append(absl::string_view key, double value) { + auto* arg = args_->add_arg(); + arg->set_name(key.data(), key.size()); + arg->set_double_value(value); + } + + private: + TraceEventArguments* args_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENT_ARGUMENTS_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events.h new file mode 100644 index 00000000..0581aab2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events.h @@ -0,0 +1,513 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
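Aside (not part of the patch): a short sketch that exercises each Append overload of TraceEventArgumentsBuilder from trace_event_arguments_builder.h above; the argument names and values are illustrative.

#include <cstdint>

#include "tensorflow/core/profiler/convert/trace_viewer/trace_event_arguments_builder.h"

namespace tensorflow {
namespace profiler {

// Populates a TraceEventArguments message with one argument per overload.
void BuildExampleArguments(TraceEventArguments* args) {
  TraceEventArgumentsBuilder builder(args);
  builder.Append("kernel", "fusion.123");      // str_value
  builder.Append("occupancy_pct", 87.5);       // double_value
  builder.Append("bytes", uint64_t{1} << 20);  // uint_value
  builder.Append("step", int64_t{7});          // int_value
}

}  // namespace profiler
}  // namespace tensorflow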
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/bind_front.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "xla/tsl/lib/io/table.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h" +#include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" +#include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h" +#include "tensorflow/core/profiler/lib/context_types.h" +#include "tensorflow/core/profiler/protobuf/task.pb.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/status.h" +#include "tsl/profiler/lib/context_types.h" + +namespace tensorflow { +namespace profiler { + +// A track of events in the trace-viewer. +using TraceEventTrack = std::vector; + +// Merge-sorts the given event tracks. Each track must be sorted. +std::vector MergeEventTracks( + const std::vector& event_tracks); + +absl::Status DoStoreAsLevelDbTable( + std::unique_ptr& file, const Trace& trace, + const std::vector>& events_by_level); + +absl::Status DoLoadFromLevelDbTable( + const std::string& filename, + std::unique_ptr filter, + std::unique_ptr visibility_filter, + int64_t filter_by_visibility_threshold, Trace& trace, + bool& filter_by_visibility, + const std::function& copy_event_to_arena, + const std::function& add_arena_event); + +// Reads the trace metadata from a file with given path +absl::Status ReadFileTraceMetadata(std::string& filepath, Trace* trace); + +std::vector> GetEventsByLevel( + const Trace& trace, std::vector& events); + +// Return the minimum duration an event can have in `level`. +uint64_t LayerResolutionPs(unsigned level); + +// Returns bounds (in picoseconds) for the level that an event +// with `duration_ps` would go into. (upper >= duration_ps > lower) +std::pair GetLevelBoundsForDuration(uint64_t duration_ps); + +struct EventFactory { + TraceEvent* Create() { + events.push_back(std::make_unique()); + return events.back().get(); + } + std::vector> events; +}; + +struct DefaultStdHash { + size_t operator()(absl::string_view input) { + return std::hash()(input); + } +}; + +template +class TraceEventsContainerBase { + public: + TraceEventsContainerBase() { + arenas_.insert(std::make_shared()); + } + + // Movable but non-copyable. + TraceEventsContainerBase(TraceEventsContainerBase&&) = default; + TraceEventsContainerBase& operator=(TraceEventsContainerBase&&) = default; + TraceEventsContainerBase(const TraceEventsContainerBase&) = delete; + TraceEventsContainerBase& operator=(const TraceEventsContainerBase&) = delete; + + // Creates a TraceEvent prefilled with the given values. 
+ void AddCompleteEvent(absl::string_view name, uint32_t resource_id, + uint32_t device_id, tsl::profiler::Timespan timespan, + RawData* raw_data = nullptr, + std::optional group_id = std::nullopt, + std::optional serial = std::nullopt) { + TraceEvent* event = CreateArenaEvent(); + MaybeInternEventName(event, name); + event->set_resource_id(resource_id); + event->set_device_id(device_id); + event->set_timestamp_ps(timespan.begin_ps()); + if (timespan.duration_ps() != 0) { + event->set_duration_ps(timespan.duration_ps()); + } + if (raw_data) { + MaybeInternTraceArgument(raw_data); + raw_data->SerializePartialToString(event->mutable_raw_data()); + if (event->raw_data().empty()) event->clear_raw_data(); + } + if (group_id) { + event->set_group_id(*group_id); + } + if (serial && *serial > 0) { + event->set_serial(static_cast(*serial)); + } + AddArenaEvent(event); + } + + // Similar to above, but the TraceEvent also has an associated flow_id and + // flow_entry_type, to make it part of a flow. + void AddFlowEvent(absl::string_view name, uint32_t resource_id, + uint32_t device_id, tsl::profiler::Timespan timespan, + uint64_t flow_id, TraceEvent::FlowEntryType flow_entry_type, + tsl::profiler::ContextType flow_category = + tsl::profiler::ContextType::kGeneric, + RawData* raw_data = nullptr, + std::optional group_id = std::nullopt, + std::optional serial = std::nullopt) { + TraceEvent* event = CreateArenaEvent(); + MaybeInternEventName(event, name); + event->set_resource_id(resource_id); + event->set_device_id(device_id); + event->set_timestamp_ps(timespan.begin_ps()); + if (timespan.duration_ps() != 0) { + event->set_duration_ps(timespan.duration_ps()); + } + event->set_flow_id(flow_id); + event->set_flow_entry_type(flow_entry_type); + event->set_flow_category(static_cast(flow_category)); + if (raw_data) { + MaybeInternTraceArgument(raw_data); + raw_data->SerializePartialToString(event->mutable_raw_data()); + if (event->raw_data().empty()) event->clear_raw_data(); + } + if (group_id) { + event->set_group_id(*group_id); + } + if (serial && *serial > 0) { + event->set_serial(static_cast(*serial)); + } + AddArenaEvent(event); + } + + // Similar to above, but the "async" TraceEvent don't have a resource id, its + // name is used as "async channel" which are used as "thread" name. It has an + // associated unique flow_id and flow_entry_type to signal asynchronous + // start and end events and match up between them. 
+ void AddAsyncEvent(absl::string_view name, uint32_t device_id, + tsl::profiler::Timespan timespan, uint64_t flow_id, + TraceEvent::FlowEntryType flow_entry_type, + tsl::profiler::ContextType flow_category = + tsl::profiler::ContextType::kGeneric, + RawData* raw_data = nullptr, + std::optional group_id = std::nullopt, + std::optional serial = std::nullopt) { + TraceEvent* event = CreateArenaEvent(); + MaybeInternEventName(event, name); + event->set_device_id(device_id); + event->set_timestamp_ps(timespan.begin_ps()); + if (timespan.duration_ps() != 0) { + event->set_duration_ps(timespan.duration_ps()); + } + event->set_flow_id(flow_id); + event->set_flow_entry_type(flow_entry_type); + event->set_flow_category(static_cast(flow_category)); + if (raw_data) { + MaybeInternTraceArgument(raw_data); + raw_data->SerializePartialToString(event->mutable_raw_data()); + if (event->raw_data().empty()) event->clear_raw_data(); + } + if (group_id) { + event->set_group_id(*group_id); + } + if (serial && *serial > 0) { + event->set_serial(static_cast(*serial)); + } + AddArenaEvent(event); + } + + // Similar to above, but the TraceEvent also has an associated counter name + // and value in RawData.args. Counter events are per device, so no resource_id + // is passed. + void AddCounterEvent(absl::string_view name, uint32_t device_id, + uint64_t timestamp_ps, const RawData& raw_data, + std::optional serial = std::nullopt) { + TraceEvent* event = CreateArenaEvent(); + event->set_name(name.data(), name.size()); + event->set_device_id(device_id); + // Do not set resource_id for counter events, they are per device. + event->set_timestamp_ps(timestamp_ps); + DCHECK(raw_data.has_args()); + DCHECK_EQ(raw_data.args().arg_size(), 1); + DCHECK(raw_data.args().arg(0).has_uint_value()); + raw_data.SerializePartialToString(event->mutable_raw_data()); + if (serial && *serial > 0) { + event->set_serial(static_cast(*serial)); + } + AddArenaEvent(event); + } + + // Returns a device descriptor. + Device* MutableDevice(uint32_t device_id) { + return &(*trace_.mutable_devices())[device_id]; + } + + // Returns a resource descriptor, + Resource* MutableResource(uint32_t resource_id, uint32_t device_id) { + Device* device = MutableDevice(device_id); + return &(*device->mutable_resources())[resource_id]; + } + + // Adds metadata events to set the name of each device and resource. + // The arguments are callbacks that return the names given ids. + // This must be called after all AddEvent calls, and no more AddEvent + // calls should be made after calling AddMetadataEvents. + void AddMetadataEvents( + const std::function& device_name, + const std::function& resource_name) { + for (const auto& id_and_device : events_by_device_) { + uint32_t device_id = id_and_device.first; + auto& device = (*trace_.mutable_devices())[device_id]; + device.set_device_id(device_id); + device.set_name(device_name(device_id)); + const DeviceEvents& device_events = id_and_device.second; + for (const auto& id_and_resource : device_events.events_by_resource) { + uint32_t resource_id = id_and_resource.first; + auto& resource = (*device.mutable_resources())[resource_id]; + resource.set_resource_id(resource_id); + resource.set_name(resource_name(device_id, resource_id)); + resource.set_num_events(id_and_resource.second.size()); + } + } + } + + // Adds task metadata for the given host. + void AddTask(int host_id, const Task& task) { + (*trace_.mutable_tasks())[host_id] = task; + } + + // Stores the contents of this container in a level-db sstable file. 
+ absl::Status StoreAsLevelDbTable( + std::unique_ptr file) const { + Trace trace = trace_; + trace.set_num_events(NumEvents()); + auto events_by_level = EventsByLevel(); + return DoStoreAsLevelDbTable(file, trace, events_by_level); + } + + std::vector> GetTraceEventsByLevel() const { + return EventsByLevel(); + } + + // Loads the contents of this container from a level-db sstable file. + // In order to be efficient, requires resolution__ to be set. + // If span_ is not set, it is initialized from the loaded trace_. + absl::Status LoadFromLevelDbTable( + const std::string& filename, + std::unique_ptr filter = nullptr, + std::unique_ptr visibility = nullptr, + int64_t filter_by_visibility_threshold = -1LL) { + return DoLoadFromLevelDbTable( + filename, std::move(filter), std::move(visibility), + filter_by_visibility_threshold, trace_, filter_by_visibility_, + absl::bind_front(&TraceEventsContainerBase::CopyEventToArena, this), + absl::bind_front(&TraceEventsContainerBase::AddArenaEvent, this)); + } + + // Calls 'callback' with all events stored in this container. + template + void ForAllEvents(Callback callback) const { + for (const auto& [device_id, device] : events_by_device_) { + for (const auto& [counter_name, events] : device.counter_events_by_name) { + for (auto* event : events) { + callback(*event); + } + } + for (const auto& [resource_id, events] : device.events_by_resource) { + for (auto* event : events) { + callback(*event); + } + } + } + } + + // Calls 'callback' with all event tracks stored in this container. + template + void ForAllTracks(Callback callback) const { + for (const auto& [device_id, device] : events_by_device_) { + for (const auto& [counter_name, events] : device.counter_events_by_name) { + if (!events.empty()) { + if (ABSL_PREDICT_FALSE(!callback(device_id, counter_name, events))) + return; + } + } + for (const auto& [resource_id, events] : device.events_by_resource) { + if (!events.empty()) { + if (ABSL_PREDICT_FALSE(!callback(device_id, resource_id, events))) + return; + } + } + } + } + + // Calls 'callback' with all event tracks stored in this container. + template + void ForAllMutableTracks(Callback callback) const { + for (auto& [device_id, device] : events_by_device_) { + for (auto& [counter_name, events] : device.counter_events_by_name) { + if (!events.empty()) { + callback(device_id, counter_name, &events); + } + } + for (auto& [resource_id, events] : device.events_by_resource) { + if (!events.empty()) { + callback(device_id, resource_id, &events); + } + } + } + } + + // Calls 'callback' with all event flows stored in this container. + template + void ForAllFlows(Callback callback) const { + absl::flat_hash_map flows; + for (const auto& [device_id, device] : events_by_device_) { + // Counter events are not flow events. + for (const auto& [resource_id, events] : device.events_by_resource) { + for (auto* event : events) { + if (event->has_flow_id()) flows[event->flow_id()].push_back(event); + } + } + } + for (auto& [flow_id, combined_flow] : flows) { + // If the flow_id is reused, split into individual flows. + for (auto& flow : SplitEventFlow(std::move(combined_flow))) { + callback(flow_id, flow); + } + } + } + + // Returns the metadata for this trace container. + const Trace& trace() const { return trace_; } + + // Returns the number of events. 
+ size_t NumEvents() const { + size_t count = 0; + for (const auto& [device_id, device] : events_by_device_) { + for (const auto& [counter_name, events] : device.counter_events_by_name) { + count += events.size(); + } + for (const auto& [resource_id, events] : device.events_by_resource) { + count += events.size(); + } + } + return count; + } + + // Returns the number of tracks. + size_t NumTracks() const { + return std::accumulate( + events_by_device_.begin(), events_by_device_.end(), 0, + [](const size_t tracks, const std::pair item) { + return tracks + item.second.counter_events_by_name.size() + + item.second.events_by_resource.size(); + }); + } + + bool FilterByVisibility() const { return filter_by_visibility_; } + + protected: + // Allocates an event in the first of the arenas_. + TraceEvent* CreateArenaEvent() { return (*arenas_.begin())->Create(); } + + // Copies event into arenas_. + TraceEvent* CopyEventToArena(const TraceEvent& event) { + TraceEvent* copy = CreateArenaEvent(); + *copy = event; + return copy; + } + + // Adds an event from arenas_ to events_by_device_. + void AddArenaEvent(TraceEvent* event) { + ExpandTraceSpan(EventSpan(*event), &trace_); + DeviceEvents& device_events = events_by_device_[event->device_id()]; + if (!event->has_resource_id()) { + device_events.counter_events_by_name[event->name()].push_back(event); + } else { + device_events.events_by_resource[event->resource_id()].push_back(event); + } + } + + // Returns all events grouped by visibility level. + std::vector> EventsByLevel() const { + std::vector events = SortedEvents(); + return GetEventsByLevel(trace_, events); + } + + // Returns all events sorted using TraceEventsComparator. + // Helper for EventsByLevel(). + // REQUIRED: All events have been added and SortTracks() has been called. + std::vector SortedEvents() const { + std::vector event_tracks; + event_tracks.reserve(NumTracks()); + ForAllMutableTracks( + [&event_tracks](uint32_t device_id, + std::variant resource_id, + TraceEventTrack* events) { + event_tracks.push_back(events); + }); + return MergeEventTracks(event_tracks); + } + + uint64_t MaybeInternString(absl::string_view name) { + uint64_t fp = hash_(name); + auto& it = (*trace_.mutable_name_table())[fp]; + if (it.empty()) { + it = name; + } + return fp; + } + + void MaybeInternEventName(TraceEvent* event, absl::string_view name) { + static constexpr size_t kNameInternThreshold = 32; + if (name.size() > kNameInternThreshold) { + event->set_name_ref(MaybeInternString(name)); + } else { + event->set_name(name.data(), name.size()); + } + } + + void MaybeInternTraceArgument(RawData* raw_data) { + if (raw_data->has_args()) { + for (auto& arg : *raw_data->mutable_args()->mutable_arg()) { + constexpr size_t kTraceArgInternThreshold = 16; + if (arg.has_str_value() && + arg.str_value().size() > kTraceArgInternThreshold) { + // Use name table to string intern the trace argument. + if (arg.name() == "long_name" || arg.name() == "hlo_text") { + // Also mark it as potential stack frame. + arg.set_ref_value(MaybeInternString("@@" + arg.str_value())); + } else { + arg.set_ref_value(MaybeInternString(arg.str_value())); + } + } + } + } + } + + // Events shown within a single device. + struct DeviceEvents { + // Counter events, which are per-device (don't have resource_id), and are + // plotted in different tracks for each counter name. + absl::flat_hash_map counter_events_by_name; + + // Complete events and flow events, mapped by resource_id. 
+ std::map events_by_resource; + }; + + // Events, mapped by device_id. + mutable std::map events_by_device_; + + // Indicator on if visibility filtering is applied or not + // Currently skip visibility filtering only applies to ssTable + bool filter_by_visibility_ = true; + + // The arenas containing events constructed in this container or in containers + // that have been merged into this container. + using Arenas = absl::flat_hash_set>; + Arenas arenas_; + + Trace trace_; + Hash hash_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h new file mode 100644 index 00000000..24f63203 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_FILTER_INTERFACE_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_FILTER_INTERFACE_H_ + +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" + +namespace tensorflow { +namespace profiler { + +// Trace event filter interface. +class TraceEventsFilterInterface { + public: + virtual ~TraceEventsFilterInterface() = default; + + // Allow sub-classes to set up filtering by processing the trace, e.g., by + // capturing the names of devices and resources that need to be filtered. + virtual void SetUp(const Trace& trace) = 0; + + // Returns true if event should not be added to a TraceEventsContainer. + virtual bool Filter(const TraceEvent& event) = 0; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_FILTER_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h new file mode 100644 index 00000000..873a791d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h @@ -0,0 +1,610 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
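Aside (not part of the patch): a hedged end-to-end sketch for TraceEventsContainerBase from trace_events.h above: add one complete event, attach device/resource names, and count events. The concrete container alias used by the tools is defined elsewhere, so the function is templated on the container type, and the metadata callback signatures are assumed from their uses in the header.

#include <cstdint>
#include <string>

#include "xla/tsl/profiler/utils/timespan.h"
#include "tensorflow/core/profiler/convert/trace_viewer/trace_events.h"

namespace tensorflow {
namespace profiler {

// ContainerT is any instantiation of TraceEventsContainerBase; the alias the
// tools actually use is not part of this header.
template <typename ContainerT>
size_t BuildTinyTrace(ContainerT& container) {
  // One 5us complete event on device 0, resource 1, starting at t=1us.
  container.AddCompleteEvent(
      "MatMul", /*resource_id=*/1, /*device_id=*/0,
      tsl::profiler::Timespan(/*begin_ps=*/1'000'000,
                              /*duration_ps=*/5'000'000));
  // Name the device and resource tracks after all events have been added
  // (callback signatures assumed: name from id, and name from device+resource).
  container.AddMetadataEvents(
      [](uint32_t device_id) {
        return "/device:TPU:" + std::to_string(device_id);
      },
      [](uint32_t device_id, uint32_t resource_id) {
        return "stream " + std::to_string(resource_id);
      });
  return container.NumEvents();
}

}  // namespace profiler
}  // namespace tensorflow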
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_TO_JSON_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_TO_JSON_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" +#include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_color.h" +#include "tensorflow/core/profiler/lib/context_types.h" +#include "tensorflow/core/profiler/protobuf/task.pb.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" +#include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" +#include "tsl/platform/protobuf.h" +#include "tsl/profiler/lib/context_types.h" + +namespace tensorflow { +namespace profiler { + +// JSON generation options. +struct JsonTraceOptions { + using Details = std::vector>; + + // Options and values for filtering based on the "details" menu. + Details details; + + // If selected_device_ids is set, we add a field "selected_device_ids" + // in the Trace JSON. + std::optional> selected_device_ids; + + // Device IDs of devices whose resources should be sorted by name instead of + // by resource ID. + absl::flat_hash_set sort_resources_by_name; + + // Returns the color for an event. + TraceEventsColorerInterface* colorer = nullptr; + + bool generate_stack_frames = true; + bool use_new_backend = false; + std::string code_link; +}; + +// Counts generated JSON events by type. +class JsonEventCounter { + public: + JsonEventCounter() : event_count_(kNumEventTypes, 0) {} + ~JsonEventCounter() { LOG(INFO) << ToString(); } + + // Types of JSON events (bit.ly/trace-event-format) + enum EventType { + kCompleteEvent, + kCompleteEventWithFlow, + kCounterEvent, + kAsyncEvent, + }; + + void Inc(EventType e) { ++event_count_[e]; } + + std::string ToString() const { + std::string output = "Generated JSON events:"; + for (size_t i = 0; i < event_count_.size(); ++i) { + absl::StrAppend(&output, " ", kEventTypeName[i], ": ", event_count_[i]); + } + return output; + } + + private: + static constexpr absl::string_view kEventTypeName[] = { + "complete", + "complete+flow", + "counter", + "async", + }; + + static constexpr size_t kNumEventTypes = ABSL_ARRAYSIZE(kEventTypeName); + + absl::FixedArray event_count_; +}; + +// Adds a separator between elements of a JSON array or object. +template +class JsonSeparator { + public: + explicit JsonSeparator(IOBuffer* output) : output_(output) {} + + // Does nothing on the first call; adds a comma to the output on subsequent + // calls. + void Add() { + output_->Append(sep_); + sep_ = ","; + } + + private: + IOBuffer* output_; + absl::string_view sep_; +}; + +// Converts picoseconds to microseconds. +inline double PicosToMicros(uint64_t ps) { return ps / 1E6; } + +// Escapes the contents of "raw" in JSON style. +// Also adds double quotes to the beginning and end of the string. 
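Aside (not part of the patch): a small sketch of the JsonSeparator pattern from trace_events_to_json.h above, with a toy output buffer. The real IOBuffer type is supplied by callers of JsonEventWriter; the only interface assumed here is the variadic Append used throughout the header, and the stripped template parameter is assumed to be that buffer type.

#include <initializer_list>
#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h"

namespace tensorflow {
namespace profiler {

// Toy buffer for illustration only.
struct StringBuffer {
  template <typename... Ts>
  void Append(const Ts&... pieces) {
    absl::StrAppend(&data, pieces...);
  }
  std::string data;
};

// Emits "[1,2,3]": Add() is a no-op before the first element and appends ","
// before every subsequent one.
inline std::string JsonArrayExample() {
  StringBuffer out;
  JsonSeparator<StringBuffer> separator(&out);
  out.Append("[");
  for (int v : {1, 2, 3}) {
    separator.Add();
    out.Append(v);
  }
  out.Append("]");
  return out.data;
}

}  // namespace profiler
}  // namespace tensorflow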
+std::string JsonEscape(absl::string_view raw); + +std::string ProtoString(const tsl::protobuf::Message& pb); + +template +void WriteTpuData(const RawDataType& data, JsonSeparator* separator, + IOBuffer* output) {} + +// Writes JSON events from a TraceEvent. +template +class JsonEventWriter { + public: + JsonEventWriter(const TraceEventsColorerInterface* colorer, + const Trace& trace, + const std::map& references, + IOBuffer* output) + : colorer_(colorer), + trace_(trace), + references_(references), + output_(output) {} + + void WriteEvent(const TraceEvent& event) const { + std::optional async_event; + output_->Append(R"({"pid":)", event.device_id()); + if (event.has_resource_id()) { + output_->Append(R"(,"tid":)", event.resource_id()); + } + const std::string& event_name = + event.has_name_ref() ? trace_.name_table().at(event.name_ref()) + : event.name(); + output_->Append(R"(,"name":)", JsonEscape(event_name)); + tsl::profiler::Timespan span = EventSpan(event); + // "%.17g" is the default double format in proto2::util::JsonFormat. + absl::Format(output_, R"(,"ts":%.17g)", PicosToMicros(span.begin_ps())); + JsonEventCounter::EventType event_type = JsonEventCounter::kCounterEvent; + if (event.has_resource_id()) { + event_type = event.has_flow_id() + ? JsonEventCounter::kCompleteEventWithFlow + : JsonEventCounter::kCompleteEvent; + // A complete event must have a duration, otherwise trace-viewer will + // extend the event to the end of the trace and append "(Did Not Finish)" + // to its name. Make the minimum duration 1 picosecond. + uint64_t duration_ps = std::max(span.duration_ps(), uint64_t{1}); + absl::Format(output_, R"(,"dur":%.17g)", PicosToMicros(duration_ps)); + + if (std::optional color_id = colorer_->GetColor(event)) { + output_->Append(R"(,"cname":)", TraceViewerColorName(*color_id)); + } + + // FlowV2 + if (event_type == JsonEventCounter::kCompleteEventWithFlow) { + output_->Append(R"(,"bind_id":)", event.flow_id()); + if (event.has_flow_category()) { + tsl::profiler::ContextType type = + tsl::profiler::GetSafeContextType(event.flow_category()); + if (type != tsl::profiler::ContextType::kGeneric && + type != tsl::profiler::ContextType::kLegacy) { + const char* category = tsl::profiler::GetContextTypeString(type); + output_->Append(R"(,"cat":")", category, R"(")"); + } + } + switch (event.flow_entry_type()) { + case TraceEvent::FLOW_NONE: + // The caller prevents this case from happening. + break; + case TraceEvent::FLOW_START: + output_->Append(R"(,"flow_out":true)"); + break; + case TraceEvent::FLOW_MID: + output_->Append(R"(,"flow_in":true,"flow_out":true)"); + break; + case TraceEvent::FLOW_END: + output_->Append(R"(,"flow_in":true)"); + break; + } + } + output_->Append(R"(,"ph":"X")"); + } else { + event_type = event.has_flow_id() ? JsonEventCounter::kAsyncEvent + : JsonEventCounter::kCounterEvent; + if (event_type == JsonEventCounter::kCounterEvent) { + output_->Append(R"(,"ph":"C")"); + } else { // async events + output_->Append(R"(,"id":)", event.flow_id()); + if (event.has_flow_category()) { + tsl::profiler::ContextType type = + tsl::profiler::GetSafeContextType(event.flow_category()); + const char* category = tsl::profiler::GetContextTypeString(type); + output_->Append(R"(,"cat":")", category, R"(")"); + } + switch (event.flow_entry_type()) { + case TraceEvent::FLOW_NONE: + // The caller prevents this case from happening. 
+ break; + case TraceEvent::FLOW_START: + output_->Append(R"(,"ph":"b")"); + break; + case TraceEvent::FLOW_END: + output_->Append(R"(,"ph":"e")"); + break; + case TraceEvent::FLOW_MID: + output_->Append(R"(,"ph":"b")"); + async_event.emplace(event); + async_event->set_flow_entry_type(TraceEvent::FLOW_END); + async_event->set_timestamp_ps(event.timestamp_ps() + + event.duration_ps()); + async_event->clear_raw_data(); + break; + } + } + } + WriteArgs(event); + if (event.has_serial()) { + output_->Append(R"(,"z":)", event.serial()); + } + + output_->Append("}"); + counter_.Inc(event_type); + if (async_event) { + output_->Append(","); + WriteEvent(*async_event); + } + } + + private: + void WriteArgs(const TraceEvent& event) const { + if (!event.has_group_id() && !event.has_raw_data()) { + return; + } + output_->Append(R"(,"args":{)"); + std::optional stack_frames; + JsonSeparator separator(output_); + if (event.has_group_id()) { + separator.Add(); + output_->Append(R"("group_id":)", event.group_id()); + } + if (event.has_raw_data()) { + RawDataType data; + data.ParseFromString(event.raw_data()); + switch (data.raw_data_case()) { + case RawDataType::RAW_DATA_NOT_SET: + break; + case RawDataType::kTpuData: + WriteTpuData(data, &separator, output_); + break; + case RawDataType::kDmaActivity: + separator.Add(); + output_->Append(R"("DMA activity":)", + ProtoString(data.dma_activity())); + break; + case RawDataType::kArgs: + for (const auto& arg : data.args().arg()) { + switch (arg.value_case()) { + case TraceEventArguments::Argument::kStrValue: + separator.Add(); + WriteArg(arg.name(), arg.str_value()); + break; + case TraceEventArguments::Argument::kIntValue: + separator.Add(); + WriteArg(arg.name(), arg.int_value()); + break; + case TraceEventArguments::Argument::kUintValue: + separator.Add(); + WriteArg(arg.name(), arg.uint_value()); + break; + case TraceEventArguments::Argument::kDoubleValue: + separator.Add(); + WriteArg(arg.name(), arg.double_value()); + break; + case TraceEventArguments::Argument::kRefValue: { + const auto& it = trace_.name_table().find(arg.ref_value()); + if (it != trace_.name_table().end()) { + // Each event could only have one stack frame. + if (absl::StartsWith(it->second, "@@") && !stack_frames) { + stack_frames = arg.ref_value(); + } else { + separator.Add(); + WriteArg(arg.name(), it->second); + } + } + break; + } + case TraceEventArguments::Argument::VALUE_NOT_SET: + break; + } + } + break; + } + } + output_->Append("}"); + + // Write the optional stack frame. + if (stack_frames.has_value()) { + output_->Append(R"(,"sf":)", references_.at(*stack_frames), R"()"); + } + } + void WriteArg(absl::string_view name, absl::string_view value) const { + output_->Append(JsonEscape(name), ":", JsonEscape(value)); + } + void WriteArg(absl::string_view name, uint64_t value) const { + // Limit beyond which integers converted to 64-bit IEEE floating point may + // lose accuracy. JavaScript stores all numbers as doubles, quote the value + // to preserve accuracy. + // https://en.wikipedia.org/wiki/Double-precision_floating-point_format + constexpr uint64_t kIeeeLimit = 1ULL << 53; + if (value > kIeeeLimit) { + output_->Append(JsonEscape(name), ":\"", value, "\""); + } else { + output_->Append(JsonEscape(name), ":", value); + } + } + void WriteArg(absl::string_view name, int64_t value) const { + // Limit beyond which integers converted to 64-bit IEEE floating point may + // lose accuracy. JavaScript stores all numbers as doubles, quote the value + // to preserve accuracy. 
+ // https://en.wikipedia.org/wiki/Double-precision_floating-point_format + constexpr uint64_t kIeeeLimit = 1ULL << 53; + if (abs(value) > kIeeeLimit) { + output_->Append(JsonEscape(name), ":\"", value, "\""); + } else { + output_->Append(JsonEscape(name), ":", value); + } + } + void WriteArg(absl::string_view name, double value) const { + if (std::isfinite(value)) { + output_->Append(JsonEscape(name)); + // "%.17g" is the default double format in proto2::util::JsonFormat. + absl::Format(output_, ":%.17g", value); + } else if (std::isinf(value)) { + output_->Append(JsonEscape(name), R"(:"Infinity")"); + } else if (std::isinf(-value)) { + output_->Append(JsonEscape(name), R"(:"-Infinity")"); + } else { + output_->Append(JsonEscape(name), R"(:"NaN")"); + } + } + + const TraceEventsColorerInterface* colorer_; + const Trace& trace_; + const std::map& references_; + IOBuffer* output_; + mutable JsonEventCounter counter_; +}; + +template +void WriteTasks(const Trace& trace, IOBuffer* output) { + const auto& tasks = trace.tasks(); + if (tasks.empty()) return; + output->Append(R"("tasks":[)"); + JsonSeparator task_separator(output); + std::map ordered_tasks(tasks.begin(), tasks.end()); + for (const auto& entry : ordered_tasks) { + const uint32_t host_id = entry.first; + const auto& task = entry.second; + + task_separator.Add(); + output->Append("{"); + JsonSeparator field_separator(output); + field_separator.Add(); + output->Append(R"("host_id":)", host_id); + if (task.has_changelist()) { + field_separator.Add(); + output->Append(R"("changelist":)", task.changelist()); + } + if (task.has_clean_build()) { + field_separator.Add(); + output->Append(R"("clean_build":)", task.clean_build()); + } + if (task.has_build_time()) { + field_separator.Add(); + output->Append( + R"("build_time":)", + JsonEscape(absl::FormatTime(absl::FromUnixNanos(task.build_time()), + absl::UTCTimeZone()))); + } + if (task.has_build_target()) { + field_separator.Add(); + output->Append(R"("build_target":)", JsonEscape(task.build_target())); + } + if (task.has_command_line()) { + field_separator.Add(); + output->Append(R"("command_line":)", JsonEscape(task.command_line())); + } + if (task.has_start_time()) { + field_separator.Add(); + output->Append( + R"("start_time":)", + JsonEscape(absl::FormatTime(absl::FromUnixNanos(task.start_time()), + absl::UTCTimeZone()))); + } + if (task.has_gtc_freq_hz()) { + field_separator.Add(); + output->Append(R"("gtc_freq_hz":)", task.gtc_freq_hz()); + } + if (task.has_tensor_core_freq_hz()) { + field_separator.Add(); + output->Append(R"("tensor_core_freq_hz":)", task.tensor_core_freq_hz()); + } + if (task.has_sparse_core_freq_hz()) { + field_separator.Add(); + output->Append(R"("sparse_core_freq_hz":)", task.sparse_core_freq_hz()); + } + output->Append("}"); + } + output->Append("],"); +} + +template +void WriteStackFrames(const Trace& trace, + const std::map& references, + IOBuffer* output) { + const auto& name_table = trace.name_table(); + output->Append(R"("stackFrames":{)"); + JsonSeparator separator(output); + for (const auto& [fp, name] : name_table) { + if (!absl::StartsWith(name, "@@")) continue; + separator.Add(); + std::string_view name_view = name; + absl::ConsumePrefix(&name_view, "@@"); + output->Append(R"(")", references.at(fp), R"(":{"name":)", + JsonEscape(name_view), R"(})"); + } + output->Append("},"); +} + +template +void WriteDetails(const JsonTraceOptions::Details& details, IOBuffer* output) { + if (details.empty()) return; + output->Append(R"("details":[)"); + 
JsonSeparator separator(output); + for (const auto& detail : details) { + separator.Add(); + output->Append(R"({"name":)", JsonEscape(detail.first), R"(,"value":)", + detail.second ? "true" : "false", "}"); + } + output->Append("],"); +} + +template +void WriteSelectedDeviceIds( + const absl::optional>& selected_device_ids, + IOBuffer* output) { + if (!selected_device_ids.has_value()) return; + + output->Append(R"("selected_device_ids":[)"); + JsonSeparator separator(output); + for (const auto& device_id : selected_device_ids.value()) { + separator.Add(); + output->Append(device_id); + } + output->Append("],"); +} + +std::map BuildStackFrameReferences(const Trace& trace); + +template +void WriteReturnedEventsSize(const int events_size, IOBuffer* output) { + output->Append(R"("returnedEventsSize":)", events_size, R"(,)"); +} + +template +void WriteFilteredByVisibility(bool filtered_by_visibility, IOBuffer* output) { + absl::string_view filtered_by_visibility_str = + filtered_by_visibility ? "true" : "false"; + output->Append(R"("filteredByVisibility":)", filtered_by_visibility_str, + R"(,)"); +} + +template +void WriteTraceFullTimespan(const Trace* trace, IOBuffer* output) { + auto start_time_ms = trace->min_timestamp_ps() / 1000000000.0; + auto end_time_ms = trace->max_timestamp_ps() / 1000000000.0; + output->Append(R"("fullTimespan":[)", start_time_ms, R"(,)", end_time_ms, + R"(],)"); +} + +template +void TraceEventsToJson(const JsonTraceOptions& options, + const TraceEventsContainer& events, IOBuffer* output) { + // Set the displayTimeUnit to nanoseconds (default is milliseconds), so the UI + // uses higher-precision when manipulating event times. Note that the + // timestamps of trace events are always given in microseconds. + output->Append( + R"({"displayTimeUnit":"ns","metadata":{"highres-ticks":true}, "codeLink":")", + options.code_link, R"(",)"); + + output->Append(absl::StrFormat(R"("useNewBackend": %s,)", + options.use_new_backend ? "true" : "false")); + WriteDetails(options.details, output); + WriteSelectedDeviceIds(options.selected_device_ids, output); + WriteReturnedEventsSize(events.NumEvents(), output); + WriteFilteredByVisibility(events.FilterByVisibility(), output); + WriteTraceFullTimespan(&events.trace(), output); + + const Trace& trace = events.trace(); + + WriteTasks(trace, output); + + auto references = BuildStackFrameReferences(trace); + if (options.generate_stack_frames) { + WriteStackFrames(trace, references, output); + } + + output->Append(R"("traceEvents":[)"); + JsonSeparator separator(output); + // Write metadata events. 
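+  // Named devices emit "process_name" and every device emits "process_sort_index";
+  // named resources emit "thread_name" and, unless the device sorts its resources
+  // by name, "thread_sort_index". All metadata records use "ph":"M".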
+ std::map ordered_devices(trace.devices().begin(), + trace.devices().end()); + for (const auto& [device_id, device] : ordered_devices) { + if (device.has_name()) { + separator.Add(); + output->Append(R"({"args":{"name":)", JsonEscape(device.name()), + R"(},"name":"process_name","ph":"M","pid":)", device_id, + R"(,"thread_count":)", device.resources_size(), "}"); + } + separator.Add(); + output->Append(R"({"args":{"sort_index":)", device_id, + R"(},"name":"process_sort_index","ph":"M","pid":)", + device_id, "}"); + std::map ordered_resources(device.resources().begin(), + device.resources().end()); + for (const auto& [resource_id, resource] : ordered_resources) { + if (resource.has_name()) { + separator.Add(); + output->Append(R"({"args":{"name":)", JsonEscape(resource.name()), + R"(},"name":"thread_name","ph":"M","pid":)", device_id, + R"(,"tid":)", resource_id, "}"); + } + if (!options.sort_resources_by_name.count(device_id)) { + separator.Add(); + output->Append(R"({"args":{"sort_index":)", resource_id, + R"(},"name":"thread_sort_index","ph":"M","pid":)", + device_id, R"(,"tid":)", resource_id, "}"); + } + } + } + + TraceEventsColorerInterface* colorer = options.colorer; + DefaultTraceEventsColorer default_colorer; + if (colorer == nullptr) colorer = &default_colorer; + colorer->SetUp(trace); + + // Write events. + JsonEventWriter writer(colorer, trace, references, + output); + events.ForAllEvents([&](const TraceEvent& event) { + separator.Add(); + writer.WriteEvent(event); + }); + output->Append("]}"); +} + +class IOBufferAdapter { + public: + explicit IOBufferAdapter(std::string* output) : output_(output) {} + + template + inline void Append(AV&&... args) { + absl::StrAppend(output_, std::forward(args)...); + } + + // Support IOBufferAdapter as a sink object for absl::Format. + friend void AbslFormatFlush(IOBufferAdapter* buffer, absl::string_view s) { + absl::StrAppend(buffer->output_, s); + } + + private: + std::string* output_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_TO_JSON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h new file mode 100644 index 00000000..832da3f3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h @@ -0,0 +1,168 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_UTIL_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" + +namespace tensorflow { +namespace profiler { + +// Returns the resource name for the given (device_id, resource_id) in trace. +inline absl::string_view ResourceName(const Trace& trace, uint32_t device_id, + uint32_t resource_id) { + return trace.devices().at(device_id).resources().at(resource_id).name(); +} + +// Returns the resource name for the given event in trace. +inline absl::string_view ResourceName(const Trace& trace, + const TraceEvent& event) { + return ResourceName(trace, event.device_id(), event.resource_id()); +} + +// Functor that compares trace events for sorting. +// Trace events are sorted by timestamp_ps (ascending) and duration_ps +// (descending) so nested events are sorted from outer to innermost. +struct TraceEventsComparator { + bool operator()(const TraceEvent* a, const TraceEvent* b) const { + if (a->timestamp_ps() < b->timestamp_ps()) return true; + if (a->timestamp_ps() > b->timestamp_ps()) return false; + return (a->duration_ps() > b->duration_ps()); + } +}; + +// Creates a tsl::profiler::Timespan from a TraceEvent. +inline tsl::profiler::Timespan EventSpan(const TraceEvent& event) { + return tsl::profiler::Timespan(event.timestamp_ps(), event.duration_ps()); +} + +// Creates a tsl::profiler::Timespan from a Trace. +inline tsl::profiler::Timespan TraceSpan(const Trace& trace) { + return tsl::profiler::Timespan::FromEndPoints(trace.min_timestamp_ps(), + trace.max_timestamp_ps()); +} + +// A flow of events in the trace-viewer. +// All events in the flow have the same flow_id. +using TraceEventFlow = std::vector; + +// In case the flow_id was re-used, split into individual flows based on the +// flow_entry_type. +std::vector SplitEventFlow(TraceEventFlow&& flow); + +// Returns whether the flow is complete. +inline bool IsCompleteFlow(const TraceEventFlow& flow) { + DCHECK(!flow.empty()); + return flow.front()->flow_entry_type() == TraceEvent::FLOW_START && + flow.back()->flow_entry_type() == TraceEvent::FLOW_END; +} + +// Updates the timestamps of a Trace to ensure it includes the given +// tsl::profiler::Timespan. +void ExpandTraceSpan(const tsl::profiler::Timespan& span, Trace* trace); + +// Nway-merge implementation. + +// Reorders the elements of the range [first, last) to restore the heap +// condition (i.e. `std::is_heap(first, last, comp)`) following a change +// in the value of `*first`. +// +// REQUIRES: `first < last`, and [first, last) would be a valid heap if `*first` +// had a suitable value. +template +void push_down_root(RandIt first, RandIt last, Compare comp) { + size_t size = last - first; + size_t hole = 0; // root. 
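+  // Sift down: while a child compares greater than the saved root value (per
+  // comp), move that child up into the hole, then drop the value into place.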
+ auto value = std::move(*first); + while (true) { + size_t l_child = 2 * hole + 1; + size_t r_child = l_child + 1; + size_t max_child = l_child; + if (r_child < size && comp(first[l_child], first[r_child])) { + max_child = r_child; + } + if (max_child >= size) break; + if (!comp(value, first[max_child])) break; + first[hole] = std::move(first[max_child]); + hole = max_child; + } + first[hole] = std::move(value); +} + +// ContainerContainer could be a container of pointers to container. +template +Out nway_merge(const ContainerContainer& containers, Out out, Cmp cmp) { + using std::begin; + using std::end; + using In = decltype(begin(**begin(containers))); // The input iterator type. + using Range = std::pair; + std::vector sources; + for (const auto& container : containers) { + Range r(begin(*container), end(*container)); + if (r.first != r.second) { + sources.push_back(r); + } + } + if (sources.empty()) return out; + // Take a comparator for T and produce an inverse comparator + // for std::pair, In>, inverted so as to produce a min-heap. + auto heap_cmp = [&](const Range& a, const Range& b) { + // Compares b < a instead of a < b. + return cmp(*b.first, *a.first); + }; + std::make_heap(sources.begin(), sources.end(), heap_cmp); + while (true) { + Range& r = sources.front(); + *out = *r.first; + ++r.first; + ++out; + if (r.first == r.second) { + if (sources.size() == 1) return out; + r = std::move(sources.back()); + sources.pop_back(); + } + push_down_root(sources.begin(), sources.end(), heap_cmp); + } +} + +// Interface that allows defining classes that map XLines within a single XPlane +// to multiple virtual devices in trace viewer. +class ResourceGrouperInterface { + public: + virtual ~ResourceGrouperInterface() = default; + + virtual std::vector> + Devices() const = 0; + + virtual uint32_t GetDeviceId(uint32_t resource_id) const = 0; +}; + +std::unique_ptr CreateDefaultResourceGrouper( + uint32_t device_id, absl::string_view name); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_EVENTS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_color.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_color.h new file mode 100644 index 00000000..be2bb9f0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_color.h @@ -0,0 +1,98 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_VIEWER_COLOR_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_VIEWER_COLOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" + +namespace tensorflow { +namespace profiler { + +// Pre-defined color names (excluding "black" and "white") from: +// https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html. +// Possible value of TraceEvent.color_id +enum TraceViewerColor { + kThreadStateUninterruptible, + kThreadStateIowait, + kThreadStateRunning, + kThreadStateRunnable, + kThreadStateUnknown, + kBackgroundMemoryDump, + kLightMemoryDump, + kDetailedMemoryDump, + kVsyncHighlightColor, + kGenericWork, + kGood, + kBad, + kTerrible, + kGrey, + kYellow, + kOlive, + kRailResponse, + kRailAnimation, + kRailIdle, + kRailLoad, + kStartup, + kHeapDumpStackFrame, + kHeapDumpObjectType, + kHeapDumpChildNodeArrow, + kCqBuildRunning, + kCqBuildPassed, + kCqBuildFailed, + kCqBuildAbandoned, + kCqBuildAttemptRunnig, + kCqBuildAttemptPassed, + kCqBuildAttemptFailed, +}; + +// Number of named colors in TraceViewer. +constexpr uint32_t kNumTraceViewerColors = + TraceViewerColor::kCqBuildAttemptFailed + 1; + +// Returns the color name for a given color id. +// Used to decode the value in TraceEvent.color_id. +absl::string_view TraceViewerColorName(uint32_t color_id); + +// Trace event colorer interface. +class TraceEventsColorerInterface { + public: + virtual ~TraceEventsColorerInterface() = default; + + // Allow sub-classes to set up coloring by processing the trace, e.g., by + // capturing the names of devices and resources that need to be colored. + virtual void SetUp(const Trace& trace) = 0; + + // Returns the color for a trace event. + virtual std::optional GetColor(const TraceEvent& event) const = 0; +}; + +class DefaultTraceEventsColorer : public TraceEventsColorerInterface { + public: + void SetUp(const Trace& trace) override {} + + std::optional GetColor(const TraceEvent& event) const override { + return std::nullopt; + } +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_VIEWER_COLOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h new file mode 100644 index 00000000..13dfabe5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h @@ -0,0 +1,179 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_VIEWER_VISIBILITY_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_VIEWER_VISIBILITY_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h" +#include "tensorflow/core/profiler/protobuf/trace_events.pb.h" + +namespace tensorflow { +namespace profiler { + +// Determines whether an event will be visible in trace viewer within a visible +// tsl::profiler::Timespan at a certain resolution. +// Events must be evaluated in order by timestamp, because when an event is +// determined to be visible, the internal state of this class is updated. +class TraceViewerVisibility { + public: + // Create with visible timespan and resolution (in picoseconds). + // The visible timespan must have non-zero duration. + // If resolution is zero, no events are downsampled. + explicit TraceViewerVisibility(tsl::profiler::Timespan visible_span, + uint64_t resolution_ps = 0); + + // Returns true if the event overlaps the visible span and is distinguishable + // at resolution_ps. + bool Visible(const TraceEvent& event); + + // Returns true if the event is distinguishable at resolution_ps. + bool VisibleAtResolution(const TraceEvent& event); + + // Records that event is distinguishable at resolution_ps. + void SetVisibleAtResolution(const TraceEvent& event); + + tsl::profiler::Timespan VisibleSpan() const { return visible_span_; } + // TODO(tf-profiler) Rename ResolutionPs and resolution_ps to be more + // self-explanatory (eg. MinDurationPs) + uint64_t ResolutionPs() const { return resolution_ps_; } + + private: + // Identifier for one Trace Viewer row. + using RowId = std::pair; + using CounterRowId = std::pair; + + // Visibility for one Trace Viewer row. + class RowVisibility { + public: + // Returns the nesting depth for an event at begin_timestamp_ps. + size_t Depth(uint64_t begin_timestamp_ps) const; + + // Returns the end_timestamp_ps of the last visibile event at the given + // nesting depth. + std::optional LastEndTimestampPs(size_t depth) const { + std::optional result; + if (depth < last_end_timestamp_ps_.size()) { + result = last_end_timestamp_ps_[depth]; + } + return result; + } + + // Returns the arrow timestamp of the last visible flow event. + std::optional LastFlowTimestampPs() const { + return last_flow_timestamp_ps_; + } + + // Sets the last visible timestamp at the given nesting depth. + void SetLastEndTimestampPs(size_t depth, uint64_t timestamp_ps) { + last_end_timestamp_ps_.resize(depth); + last_end_timestamp_ps_.push_back(timestamp_ps); + } + + // Sets the last visible arrow timestamp. + void SetLastFlowTimestampPs(uint64_t timestamp_ps) { + last_flow_timestamp_ps_ = timestamp_ps; + } + + private: + // Stack of most recently visible event end times. A stack is used to handle + // nested events. + std::vector last_end_timestamp_ps_; + + // Timestamp of the arrow binding point of the last visible flow event. + std::optional last_flow_timestamp_ps_; + }; + + // Constructor arguments. + tsl::profiler::Timespan visible_span_; + uint64_t resolution_ps_; + + // Visibility data for all rows. + absl::flat_hash_map rows_; + + // Visibility of flows. + absl::flat_hash_map flows_; + + // Visibility data for counter events. 
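+  // Maps each counter row to the timestamp of its most recently kept (visible) counter event.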
+ absl::flat_hash_map last_counter_timestamp_ps_; +}; + +class TraceVisibilityFilter : public TraceEventsFilterInterface { + public: + // If visible_span.Instant(), all events are visible. + // If resolution is 0.0, events aren't downsampled. + TraceVisibilityFilter(tsl::profiler::Timespan visible_span, double resolution) + : resolution_(resolution), + visibility_(visible_span, ResolutionPs(visible_span.duration_ps())) {} + + tsl::profiler::Timespan VisibleSpan() const { + return visibility_.VisibleSpan(); + } + uint64_t ResolutionPs() const { return visibility_.ResolutionPs(); } + + void SetUp(const Trace& trace) override { + // Update visible_span with trace bounds and recompute the resolution in + // picoseconds. + tsl::profiler::Timespan visible_span = VisibleSpan(); + uint64_t start_time_ps = visible_span.begin_ps(); + uint64_t end_time_ps = visible_span.end_ps(); + if (end_time_ps == 0 && trace.has_max_timestamp_ps()) { + end_time_ps = trace.max_timestamp_ps(); + } + if (start_time_ps == 0 && trace.has_min_timestamp_ps()) { + start_time_ps = trace.min_timestamp_ps(); + } + visible_span = + tsl::profiler::Timespan::FromEndPoints(start_time_ps, end_time_ps); + visibility_ = TraceViewerVisibility( + visible_span, ResolutionPs(visible_span.duration_ps())); + } + + // Updates the visibility based on `resolution`. + void UpdateVisibility(double resolution) { + resolution_ = resolution; + visibility_ = TraceViewerVisibility( + visibility_.VisibleSpan(), + ResolutionPs(visibility_.VisibleSpan().duration_ps())); + } + + bool Filter(const TraceEvent& event) override { + return !visibility_.Visible(event); + } + + private: + // Returns the minimum duration in picoseconds that an event must have in + // order to be visible. + uint64_t ResolutionPs(uint64_t duration_ps) { + return (resolution_ == 0.0) ? 0 : std::llround(duration_ps / resolution_); + } + + double resolution_; // number of visible events per row + TraceViewerVisibility visibility_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_TRACE_VIEWER_TRACE_VIEWER_VISIBILITY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h new file mode 100644 index 00000000..68e0b491 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ + +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/convert/repository.h" +#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" + +namespace tensorflow { +namespace profiler { + +// Converts multiple XSpaces to dcn collective stats. +// Stores the dcn collective stats as files in the same directory +// as the xspace files. +absl::StatusOr ConvertMultiXSpaceToDcnCollectiveStats( + const SessionSnapshot& session_snapshot); + +// Returns whether there are dcn collective stats in the profile. +absl::StatusOr HasDcnCollectiveStatsInMultiXSpace( + const SessionSnapshot& session_snapshot); + +// Gets DcnSlackAnalysis proto for a host. +absl::StatusOr GetDcnSlackAnalysisByHostName( + const SessionSnapshot& session_snapshot, std::string hostname); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_DCN_COLLECTIVE_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_hlo.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_hlo.h new file mode 100644 index 00000000..2361ba6e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_hlo.h @@ -0,0 +1,42 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ + +#include + +#include "absl/strings/string_view.h" +#include "xla/service/hlo.pb.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/convert/repository.h" + +namespace tensorflow { +namespace profiler { + +// Get HLO proto by module name. +absl::StatusOr GetHloProtoByModuleName( + const SessionSnapshot& session_snapshot, absl::string_view module_name); + +// Converts multiple XSpaces to HLO protos. +// Stores the HLO protos as files in the same directory as the xspace files. +// Returns whether there are HLO protos in this profile. +absl::StatusOr ConvertMultiXSpaceToHloProto( + const SessionSnapshot& session_snapshot); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_HLO_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h new file mode 100644 index 00000000..7cf9430c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ + +#include + +#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/gpu_event_stats.h" +#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +void ConvertDeviceTraceXPlaneToKernelReports( + const XPlane& device_trace, + const std::function& + on_kernel_fn, + KernelReportMap* reports); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_KERNEL_STATS_DB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_memory_profile.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_memory_profile.h new file mode 100644 index 00000000..00f919d4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_memory_profile.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_ + +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +// Process the host threads XPlane and generate MemoryProfile result; at most +// max_num_snapshots will be displayed on the UI. +// REQUIRED: host_plane should have been grouped by calling GroupTfEvents(). 
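+// Illustrative call, assuming host_plane has already been grouped:
+//   MemoryProfile profile = ConvertXPlaneToMemoryProfile(host_plane, /*max_num_snapshots=*/500);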
+MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane, + int64_t max_num_snapshots = 1000); + +absl::Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace, + std::string* json_output); +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_MEMORY_PROFILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h new file mode 100644 index 00000000..c5d2a229 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h @@ -0,0 +1,59 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/op_utils.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +// Data per host thread for TensorFlow Op Metrics Database. +struct TfMetricsDbData { + // A database of TF-Op metrics for this core. + OpMetricsDb tf_metrics_db; + HostOpMetricsDbBuilder tf_metrics_db_builder{&tf_metrics_db}; +}; + +absl::flat_hash_map +CollectTfOpsFromHostThreadsXPlane(const XPlane& host_trace); + +TfMetricsDbData ConvertHostThreadsXLineToTfMetricsDbData( + const XLineVisitor& line, + const absl::flat_hash_map& tf_ops); + +void ConsumeTfMetricsDbData(TfMetricsDbData src, OpMetricsDbCombiner* dst); + +OpMetricsDb ConvertHostThreadsXPlaneToOpMetricsDb(const XPlane& host_trace); + +OpMetricsDb ConvertDeviceTraceXPlaneToOpMetricsDb(const XPlane& device_trace); + +// Convert TPU DeviceTrace XPlane to OpMetricDb +OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( + const XPlane& device_trace); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_METRICS_DB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_op_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_op_stats.h new file mode 100644 index 00000000..994efb03 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_op_stats.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_ + +#include + +#include "tensorflow/core/profiler/convert/repository.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" +#include "tensorflow/core/profiler/utils/hlo_proto_map.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +struct OpStatsOptions { + bool maybe_drop_incomplete_steps = false; + bool generate_op_metrics_db = false; + bool generate_step_db = false; + bool generate_kernel_stats_db = false; +}; + +// NOTE: call GroupTfEvents before if OpStats.step_db needs to be generated. +OpStats ConvertXSpaceToOpStats(const XSpace& space, + const OpStatsOptions& options); + +// Populates the program_id_to_name map in OpStats. +void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, + tensorflow::profiler::OpStats& op_stats); + +// Populates the given RunEnvironment with data from XSpace. +void SetRunEnvironment(const XSpace& space, RunEnvironment* env); + +// Propagate and dedup the diagnostics in XSpace and add to OpStats. +void PropagateXSpaceDiagnosticsToOpStats(const XSpace& space, + OpStats* op_stats); + +// Populates PerfEnv. +PerfEnv MakePerfEnv(double peak_tera_flops_per_second, + std::vector peak_bws); + +// Extracts PerfEnv from XPlane stats. +PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_OP_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_step_events.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_step_events.h new file mode 100644 index 00000000..acd84574 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_step_events.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ + +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/event_span.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +// Convert the host threads in XLine format to StepEvents format. 
If +// device_step_events is non-null, we will filter out events that only happens +// on CPU. +StepEvents ConvertHostThreadsXLineToStepEvents( + const XLineVisitor& line, const StepEvents* device_step_events); + +// Convert the host threads in XPlane format to StepEvents format. If +// device_step_events is non-null, we will filter out events that only happens +// on CPU. +StepEvents ConvertHostThreadsXPlaneToStepEvents( + const XPlane& host_trace, const StepEvents* device_step_events); + +// Convert the device trace in XLine format to StepEvents. +StepEvents ConvertDeviceTraceXLineToStepEvents(const XLineVisitor& line); + +// Convert the device trace in XPlane format to StepEvents. +StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_EVENTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_step_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_step_stats.h new file mode 100644 index 00000000..5d5ff20c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_step_stats.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_STATS_H_ + +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +// Converts XSpace collected by profiling a GPU device to StepStats. +void ConvertGpuXSpaceToStepStats(const XSpace& xspace, StepStats* step_stats); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_STEP_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h new file mode 100644 index 00000000..f5f53488 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +TF_CONST_INIT extern const int64_t kSlowCallThresholdPs; + +enum class BottleneckType { + kSlowSource, + kSlowDataService, + kSlowRemoteSource, + kSlowTransformationWithParallelVersion, + kSlowTransformationWithoutParallelVersion, + kOther, +}; + +BottleneckType GetBottleneckType(absl::string_view bottleneck_iterator_name); + +class CombinedTfDataStatsBuilder { + public: + explicit CombinedTfDataStatsBuilder( + CombinedTfDataStats* combined_tf_data_stats, + bool generate_suggestion = true) + : combined_tf_data_stats_(combined_tf_data_stats), + generate_suggestion_(generate_suggestion) {} + + void Add(absl::string_view host_name, XPlane* host_plane); + + // Finalizes by populating TfDataBottleneckAnalysis. + void Finalize(); + + private: + CombinedTfDataStats* combined_tf_data_stats_; + bool generate_suggestion_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_DATA_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tf_functions.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tf_functions.h new file mode 100644 index 00000000..fbff7cce --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tf_functions.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ + +#include + +#include "tensorflow/core/profiler/protobuf/tf_function.pb.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +// Converts from the given XLine to a TfFunctionDb. +TfFunctionDb ConvertHostThreadsXLineToTfFunctionDb(const XLineVisitor& line); + +// Returns a debugging string for the given TfFunctionDb. +std::string DebugString(TfFunctionDb tf_function_db); + +// Combines the tf-function statistics from src and dst into dst. 
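+// Illustrative aggregation (variable names assumed): CombineTfFunctionDb(per_host_db, &combined_db);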
+void CombineTfFunctionDb(const TfFunctionDb& src, TfFunctionDb* dst); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TF_FUNCTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tool_names.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tool_names.h new file mode 100644 index 00000000..a1e93694 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tool_names.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_ + +#include + +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/convert/repository.h" + +namespace tensorflow { +namespace profiler { + +// Gets the names of the available tools given a session snapshot. +// Returns a comma separated list of tool names. +absl::StatusOr GetAvailableToolNames( + const SessionSnapshot& session_snapshot); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOL_NAMES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tools_data.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tools_data.h new file mode 100644 index 00000000..8a40e03a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_tools_data.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/convert/repository.h" +#include "tensorflow/core/profiler/convert/tool_options.h" + +namespace tensorflow { +namespace profiler { + +// Convert XSpace protos to a tool specific data. +// Return the serialized string of tool specific data when the conversion is +// successful, else return error status. 
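+// Illustrative call (variable names assumed; tool_name would typically be one of
+// the names listed by GetAvailableToolNames):
+//   auto tool_data = ConvertMultiXSpacesToToolData(session_snapshot, tool_name, options);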
+absl::StatusOr ConvertMultiXSpacesToToolData( + const SessionSnapshot& session_snapshot, absl::string_view tool_name, + const ToolOptions& options); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TOOLS_DATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_trace_container.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_trace_container.h new file mode 100644 index 00000000..cdf3a72f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xplane_to_trace_container.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ + +#include "tensorflow/core/profiler/convert/trace_viewer/trace_events.h" +#include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +using TraceEventsContainer = TraceEventsContainerBase; + +// Converts XEvents within the XSpace into trace_viewer events container. +void ConvertXSpaceToTraceEventsContainer(absl::string_view hostname, + const XSpace& xspace, + TraceEventsContainer* container); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XPLANE_TO_TRACE_CONTAINER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h new file mode 100644 index 00000000..2f9e5551 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h @@ -0,0 +1,167 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" +#include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" +#include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" +#include "tensorflow/core/profiler/protobuf/topology.pb.h" +#include "tensorflow/core/profiler/utils/hlo_proto_map.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +using tensorflow::profiler::DcnSlackAnalysis; + +namespace dcn_analysis_internal { + +struct DcnOpState { + uint64_t start_time = 0; + uint64_t end_time = 0; + + // Duration of containing send/send-done/recv/recv-done ops that needs to be + // subtracted from the total duration + uint64_t overlapping_duration = 0; + std::string rendezvous_name; + std::string transfer_type; + uint64_t stall_duration_ns = 0; + std::string send_op_name; + int replica_group_size = 0; + + OpInstance send; + OpInstance send_done; + OpInstance recv; + OpInstance recv_done; +}; + +// Structure to extract and store the DcnHostEvents. +struct DcnHostEvent { + std::string rendezvous_name; + tsl::profiler::Timespan timespan; + int multi_slice_device_id; +}; + +// When visiting DcnHostEvents from the megascale planes, The events are stored +// in separate lines in an ascending (by time) order. The List allows insertion +// of multiple arrays of sorted events. +class DcnHostEventList { + public: + // Insert the event into the sorted list. + void insert(DcnHostEvent event); + + // Pop the events from the front that is included within the timestamp when + // available. + std::optional pop(const tsl::profiler::Timespan& timespan); + + // Number of events. 
+ int size() const { return events_.size(); } + + private: + std::list events_; + std::list::iterator iter_ = events_.begin(); +}; + +struct InstrMetadata { + xla::HloOpcode opcode; + uint64_t channel_id; + std::optional rendezvous_name; + int64_t size = 0; + std::optional transfer_type; +}; + +class DcnTracker { + public: + explicit DcnTracker(const tensorflow::profiler::HloProtoMap& hlo_proto_map, + bool is_megacore) + : hlo_proto_map_(hlo_proto_map), is_megacore_(is_megacore) {} + + absl::StatusOr GetInstructionMetadata(std::string_view module, + std::string_view instr); + + DcnSlackAnalysis Finalize(); + + void DebugString(); + + void VisitOp(const InstrMetadata& instr, + const tsl::profiler::XEventVisitor& visitor); + + void VisitHostEvent(const DcnHostEvent& event); + + void ProcessTopology(const tensorflow::profiler::Topology& topology); + + private: + DcnSlackAnalysis slack_analysis_; + absl::flat_hash_map rendezvous_to_op_map_; + absl::flat_hash_map channel_id_to_rendezvous_map_; + absl::flat_hash_map instruction_metadata_map_; + absl::flat_hash_map core_id_to_host_event_map_; + const tensorflow::profiler::HloProtoMap& hlo_proto_map_; + absl::flat_hash_map global_chip_id_to_local_index_map_; + absl::flat_hash_map> + hlo_module_cache_; + absl::flat_hash_map rendezvous_to_replica_group_size_map_; + bool is_megacore_ = true; + + absl::StatusOr GetInstrMetadataFromHloModule( + std::string_view module, std::string_view instr); + + void UpdateActiveOps(uint64_t duration); + + void SummarizeDcnSlackAnalysis(); + + std::optional GetCollectiveHostEvent( + int core_id, std::string_view rendezvous_name, + tsl::profiler::Timespan timespan); + + // GetLocalIndex when available, else return the global_device_id itself. + int GetLocalIndex(int dcn_device_id); + + // Get number of replica group + int GetReplicaGroupSize(const std::string& rendezvous_name, + const tsl::profiler::XEventVisitor& visitor); + + // Compute data transmitted size based on number of replica groups + uint64_t ComputeTransmittedDataSize(int64_t buffer_size, int group_size, + const std::string& transfer_type); +}; + +} // namespace dcn_analysis_internal + +// Convert Hlo Events in XSpace to Dcn Slack analysis. +DcnSlackAnalysis ConvertXSpaceToDcnSlackAnalysis( + const tensorflow::profiler::XSpace& xspace, + const tensorflow::profiler::XPlane* dcn_host_plane, + const tensorflow::profiler::Topology* topology, bool is_megacore = true); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_CONVERT_XSPACE_TO_DCN_SLACK_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h new file mode 100644 index 00000000..9ea9bdac --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h @@ -0,0 +1,113 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This checker checks the accelerator's utilization. +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ + +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/core/profiler/internal/advisor/checker.h" + +namespace tensorflow { +namespace tfprof { + +struct ExecStats { + public: + // Earliest start time of a step. + int64_t start_micros; + // Latest finish time of a step. + int64_t end_micros; + // The duration spent on running a kernel during a step. + int64_t exec_micros; +}; + +class AcceleratorUtilizationChecker : public Checker { + public: + string name() const override { return kCheckers[0]; } + + private: + AdviceProto::Checker Check(const AdvisorOptionsProto::CheckerOption& options, + const TFStats* stats) override { + if (!stats) { + absl::FPrintF( + stderr, "Missing profiles (e.g. graph, run_meta). Skip %s\n", name()); + return reports_; + } + for (const auto& n : stats->nodes()) { + BuildExecStats(n.second.get()); + } + return CheckInternal(); + } + + AdviceProto::Checker CheckInternal() { + for (const auto& s : accelerator_exec_stats_) { + const ExecStats& stat = s.second; + int64_t total_micros = stat.end_micros - stat.start_micros; + if (total_micros <= 0) continue; + double utilization = 1.0 * stat.exec_micros / total_micros; + if (utilization >= 0.5) { + reports_.add_reports(absl::StrFormat("device: %s utilization: %.2f", + s.first, utilization)); + } else if (utilization < 0.5 && utilization > 0.2) { + reports_.add_reports(absl::StrFormat("device: %s low utilization: %.2f", + s.first, utilization)); + } else if (utilization <= 0.2) { + reports_.add_reports(absl::StrFormat("device: %s low utilization: %.2f", + s.first, utilization)); + } + } + return reports_; + } + + void BuildExecStats(const TFGraphNode* node) { + const auto& execs = node->all_op_execs(); + if (execs.empty()) { + return; + } + if (!IsPlacedOnAccelerator(node->canonical_device())) { + return; + } + + if (accelerator_exec_stats_.find(node->canonical_device()) == + accelerator_exec_stats_.end()) { + accelerator_exec_stats_.insert( + std::pair(node->canonical_device(), ExecStats())); + } + ExecStats& stats = accelerator_exec_stats_.at(node->canonical_device()); + + // TODO(xpan): Use multiple steps? 
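+// (Worked example of the utilization arithmetic in CheckInternal() above,
+//  with made-up numbers: start_micros = 0, end_micros = 10,000,000 and
+//  exec_micros = 3,000,000 give utilization = 3,000,000 / 10,000,000 = 0.30,
+//  which lands in the "low utilization" bucket; 0.20 or below is likewise
+//  reported as low utilization, while 0.50 and above is reported as plain
+//  "utilization".)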
+ const ExecStep& exec = execs.rbegin()->second; + + if (stats.start_micros == 0) { + stats.start_micros = exec.all_start_micros(); + } else if (exec.all_start_micros() != 0) { + stats.start_micros = + std::min(stats.start_micros, exec.all_start_micros()); + } + stats.end_micros = std::max(stats.end_micros, exec.latest_end_micros()); + stats.exec_micros += exec.accelerator_exec_micros(); + } + + std::map accelerator_exec_stats_; + std::map ps_placement_; + AdviceProto::Checker reports_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_ACCELERATOR_UTILIZATION_CHECKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/checker.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/checker.h new file mode 100644 index 00000000..3fc345cc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/checker.h @@ -0,0 +1,51 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ + +#include "tensorflow/core/profiler/internal/tfprof_stats.h" +#include "tensorflow/core/profiler/tfprof_options.pb.h" + +namespace tensorflow { +namespace tfprof { + +// Append only. +static const char* const kCheckers[] = { + "AcceleratorUtilizationChecker", "OperationChecker", + "ExpensiveOperationChecker", + "JobChecker", // Internal checker. +}; + +class Checker { + public: + virtual ~Checker() = default; + + virtual string name() const = 0; + + AdviceProto::Checker Run(const AdvisorOptionsProto::CheckerOption& options, + const TFStats* stats) { + return Check(options, stats); + } + + protected: + virtual AdviceProto::Checker Check( + const AdvisorOptionsProto::CheckerOption& options, + const TFStats* stats) = 0; +}; +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_CHECKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h new file mode 100644 index 00000000..4ec0cb57 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h @@ -0,0 +1,143 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This checker checks the most expensive operations. +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_ + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/profiler/internal/advisor/checker.h" + +namespace tensorflow { +namespace tfprof { + +class ExpensiveOperationChecker : public Checker { + public: + string name() const override { return kCheckers[2]; } + + private: + AdviceProto::Checker Check(const AdvisorOptionsProto::CheckerOption& options, + const TFStats* stats) override { + if (!stats) { + absl::FPrintF( + stderr, "Missing profiles (e.g. graph, run_meta). Skip %s\n", name()); + return reports_; + } + if (stats->steps().empty()) { + absl::FPrintF(stderr, "Missing RunMetadata info. Skip %s\n", name()); + } + CheckOpView(stats); + CheckScopeView(stats); + CheckCodeView(stats); + return reports_; + } + + void CheckOpView(const TFStats* stats) { + if (stats->steps().empty()) { + absl::FPrintF(stderr, "Missing run_meta for %s\n", name()); + return; + } + Options opts(3, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, -1, "micros", {".*"}, {".*"}, + {}, {".*"}, {}, false, {"micros", "occurrence"}, "none", {}); + const MultiGraphNodeProto root = stats->ShowMultiGraphNode("op", opts); + if (root.children_size() == 0) { + return; + } + const MultiGraphNodeProto* node = &root; + std::vector outputs; + for (int i = 0; i < 3 && node->children_size() > 0; ++i) { + node = &node->children(0); + outputs.push_back(absl::StrFormat( + "top %d operation type: %s, " + "cpu: %s, accelerator: %s, total: %s (%.2f%%)", + i + 1, node->name(), FormatTime(node->cpu_exec_micros()), + FormatTime(node->accelerator_exec_micros()), + FormatTime(node->exec_micros()), + 100.0 * node->exec_micros() / (root.total_exec_micros() + 1e-10))); + } + reports_.add_reports(absl::StrJoin(outputs, "\n")); + } + + void CheckCodeView(const TFStats* stats) { + if (!stats->has_code_traces()) { + absl::FPrintF(stderr, "Missing op_log (code traces) for %s\n", name()); + return; + } + Options opts(100, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, -1, "micros", {".*"}, + {".*"}, {}, {".*"}, {}, false, {"micros"}, "none", {}); + const MultiGraphNodeProto root = stats->ShowMultiGraphNode("code", opts); + const MultiGraphNodeProto* node = &root; + // A trick here is: Usually, codes in library file are usually referenced + // only once, while user's own code are referenced multiple times. 
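+// (Illustration with hypothetical frame names: for a stack tree like
+//  root -> train.py:main -> model.py:build -> {Conv2D ops, MatMul ops},
+//  the outer frames each have exactly one child, so the loop below skips
+//  down that single-child chain and reporting starts at the first fan-out,
+//  which is usually the user's own code rather than shared library wrappers.)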
+ while (node->children_size() == 1) { + node = &node->children(0); + } + if (node->children_size() == 0) { + return; + } + + std::vector outputs; + CodeViewHelper(node, 0, &outputs); + reports_.add_reports(absl::StrJoin(outputs, "\n")); + } + + void CheckScopeView(const TFStats* stats) { + Options opts(100, 0, 0, 0, 0, 100, 0, 0, 0, 0, 0, -1, "micros", {".*"}, + {".*"}, {}, {".*"}, {}, false, {"micros"}, "none", {}); + const GraphNodeProto root = stats->ShowGraphNode("scope", opts); + if (root.children_size() == 0) { + return; + } + std::vector outputs; + for (int i = 0; i < 3 && i < root.children_size(); ++i) { + const GraphNodeProto& node = root.children(i); + outputs.push_back(absl::StrFormat( + "top %d graph node: %s, cpu: %s, accelerator: %s, total: %s", i + 1, + node.name(), FormatTime(node.cpu_exec_micros()), + FormatTime(node.accelerator_exec_micros()), + FormatTime(node.exec_micros()))); + } + reports_.add_reports(absl::StrJoin(outputs, "\n")); + } + + void CodeViewHelper(const MultiGraphNodeProto* node, int depth, + std::vector* outputs) { + if (node->children_size() <= 1 || depth > 3) { + return; + } + for (int j = 0; j < 3 && j < node->children_size(); ++j) { + const MultiGraphNodeProto* c = &node->children(j); + if (c->total_exec_micros() < 1000) { + continue; + } + outputs->push_back( + absl::StrFormat("%s%s, cpu: %s, accelerator: %s, total: %s", + std::string(depth * 2, ' '), c->name(), + FormatTime(c->total_cpu_exec_micros()), + FormatTime(c->total_accelerator_exec_micros()), + FormatTime(c->total_exec_micros()))); + CodeViewHelper(c, depth + 1, outputs); + } + } + + AdviceProto::Checker reports_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h new file mode 100644 index 00000000..6fc16cf9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/internal_checker_runner.h @@ -0,0 +1,34 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ + +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_options.pb.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +class TFStats; + +AdviceProto RunInternalCheckers(const AdvisorOptionsProto& options, + const TFStats* stats); + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_INTERNAL_CHECKER_RUNNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/operation_checker.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/operation_checker.h new file mode 100644 index 00000000..5142639f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/operation_checker.h @@ -0,0 +1,78 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This checker checks common wrong configurations of operations. +// +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ + +#include "absl/strings/str_format.h" +#include "tensorflow/core/profiler/internal/advisor/checker.h" + +namespace tensorflow { +namespace tfprof { + +class OperationChecker : public Checker { + public: + string name() const override { return kCheckers[1]; } + + private: + AdviceProto::Checker Check(const AdvisorOptionsProto::CheckerOption& options, + const TFStats* stats) override { + if (!stats) { + absl::FPrintF( + stderr, "Missing profiles (e.g. graph, run_meta). Skip %s\n", name()); + return reports_; + } + bool use_batch_norm = false; + bool use_fused_batch_norm = false; + bool recommend_nchw = false; + for (const auto& n : stats->nodes()) { + const TFGraphNode* node = n.second.get(); + if (node->name().find("BatchNorm") != node->name().npos) { + use_batch_norm = true; + } + if (node->op_types().find("FusedBatchNorm") != node->op_types().end()) { + use_fused_batch_norm = true; + } + + const AttrValue* attr = node->op_attrs("data_format"); + if (attr) { + if (attr->s() == "NHWC" && + IsPlacedOnAccelerator(node->canonical_device())) { + recommend_nchw = true; + } + } + } + if (use_batch_norm && !use_fused_batch_norm) { + reports_.add_reports( + "Maybe use faster FusedBatchNorm instead of BatchNorm"); + } + if (recommend_nchw) { + // TODO(xpan): Maybe print which Op supports NCHW. + reports_.add_reports( + "Found operation using NHWC data_format on GPU. 
Maybe " + "NCHW is faster."); + } + return reports_; + } + + private: + AdviceProto::Checker reports_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_OPERATION_CHECKER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h new file mode 100644 index 00000000..e1db57cc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h @@ -0,0 +1,84 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_ + +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h" +#include "tensorflow/core/profiler/internal/advisor/checker.h" +#include "tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h" +#include "tensorflow/core/profiler/internal/advisor/internal_checker_runner.h" +#include "tensorflow/core/profiler/internal/advisor/operation_checker.h" +#include "tensorflow/core/profiler/tfprof_options.pb.h" + +namespace tensorflow { +namespace tfprof { + +// The Advisor runs a list of Checkers, each checks a specific area. +class Advisor { + public: + Advisor(const TFStats* stats) : stats_(stats) {} + + static AdvisorOptionsProto DefaultOptions() { + AdvisorOptionsProto options; + std::vector checkers( + kCheckers, kCheckers + sizeof(kCheckers) / sizeof(*kCheckers)); + for (const string& checker : checkers) { + (*options.mutable_checkers())[checker]; + } + return options; + } + + AdviceProto Advise(const AdvisorOptionsProto& options) { + // Note: Release a checker's memory ASAP. 
+ AdviceProto ret = RunInternalCheckers(options, stats_); + + if (options.checkers().find(kCheckers[0]) != options.checkers().end()) { + AcceleratorUtilizationChecker au_checker; + (*ret.mutable_checkers())[kCheckers[0]].MergeFrom( + au_checker.Run(options.checkers().at(kCheckers[0]), stats_)); + } + if (options.checkers().find(kCheckers[1]) != options.checkers().end()) { + OperationChecker op_checker; + (*ret.mutable_checkers())[kCheckers[1]].MergeFrom( + op_checker.Run(options.checkers().at(kCheckers[1]), stats_)); + } + if (options.checkers().find(kCheckers[2]) != options.checkers().end()) { + ExpensiveOperationChecker expensive_op_checker; + (*ret.mutable_checkers())[kCheckers[2]].MergeFrom( + expensive_op_checker.Run(options.checkers().at(kCheckers[2]), + stats_)); + } + for (const auto& checker : ret.checkers()) { + absl::FPrintF(stdout, "\n%s:\n", checker.first); + for (const string& r : checker.second.reports()) { + absl::FPrintF(stdout, "%s\n", r); + } + } + fflush(stdout); + return ret; + } + + private: + const TFStats* stats_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/print_model_analysis.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/print_model_analysis.h new file mode 100644 index 00000000..ab1887a8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/print_model_analysis.h @@ -0,0 +1,66 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ + +#include + +namespace tensorflow { +namespace tfprof { +struct Options; + +// ********************** +// APIs in this file are only for swig. +// Talk to xpan@ if you want to call it directly! +// ********************* + +// Multi-step Profiler. +// +bool NewProfiler(const std::string* graph, const std::string* op_log); + +void DeleteProfiler(); + +double AddStep(int64_t step, const std::string* graph, + const std::string* run_meta, const std::string* op_log); + +// Write the profiler's profile to a proto buffer. +void WriteProfile(const std::string* filename); + +// Load the profile to profiler from a proto buffer file. +void ProfilerFromFile(const std::string* filename); + +// Returns a binary string that represents the serialized ProfileProto. +std::string SerializeToString(); + +std::string Profile(const std::string* command, const std::string* options); + +// Single-step Profiler. +// +// Interface defined for Python API swig. Calls the tfprof core API. +// 'graph', 'run_meta', 'op_log' are serialized GraphDef, RunMetadata, +// OpLogProto strings, respectively. +// 'graph', 'command' and 'options' are required. Others can be nullptr +// if not available. 
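+// (Illustrative sketch; the command "scope" and the serialized strings below
+//  are assumptions, prepared by the caller from GraphDef, RunMetadata,
+//  OpLogProto and option protos respectively:
+//    std::string graph = ..., run_meta = ..., op_log = ..., options = ...;
+//    std::string command = "scope";
+//    std::string result =
+//        PrintModelAnalysis(&graph, &run_meta, &op_log, &command, &options);
+//  result carries the analysis output as a string.)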
+std::string PrintModelAnalysis(const std::string* graph, + const std::string* run_meta, + const std::string* op_log, + const std::string* command, + const std::string* options); + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PRINT_MODEL_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_code.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_code.h new file mode 100644 index 00000000..5664fb0c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_code.h @@ -0,0 +1,96 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Build a tree structure based on the TensorFlow model's python code stacks. +// Stats are aggregated from descendants to ancestors. + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_show_multi.h" +#include "tensorflow/core/profiler/internal/tfprof_timeline.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/profile.pb.h" +#include "tensorflow/core/profiler/tfprof_log.pb.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +class PprofProfile { + public: + virtual ~PprofProfile() = default; + + virtual uint64 AddLocation(const CodeNode* callee, + const CodeNode* caller) = 0; + + virtual void AddSample(const CodeNode* leaf, + std::vector* call_ids) = 0; + + virtual absl::Status WritePprofProfile(const string& filename) = 0; +}; + +class TFCode : public TFMultiShow { + public: + TFCode() = default; + ~TFCode() override = default; + + // Add nodes to the code view. Called before Build() + void AddNode(TFGraphNode* node) override; + + // Build the code view structure. Called after all nodes + // are added via AddNode(). 
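+// (Usage sketch, assuming `nodes` is a caller-owned map from graph-node name
+//  to std::unique_ptr<TFGraphNode>:
+//    TFCode code_view;
+//    for (auto& n : nodes) code_view.AddNode(n.second.get());
+//    code_view.Build();
+//  After Build(), the view is rendered through the TFMultiShow interface,
+//  which dispatches to ShowInternal() below.)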
+ void Build() override; + + private: + const ShowMultiNode* ShowInternal(const Options& opts, + Timeline* timeline) override; + + std::vector SearchRoot(std::vector roots, + const std::vector& regexes); + + std::vector PrintScope(std::vector roots, + const Options& opts, int depth, + int last_ident); + + std::vector Account(const std::vector& roots, + const Options& opts); + + void Format(const CodeNode* root, const std::vector& nodes, + const Options& opts, string* display_str, + MultiGraphNodeProto* proto, std::vector* call_ids); + + string FormatNode(CodeNode* node, const Options& opts, int64_t indent) const; + string FormatNodeMemory(CodeNode* node, int64_t bytes, + int64_t total_bytes) const; + + std::unique_ptr root_; + std::unique_ptr graph_root_; + std::unique_ptr pprof_profile_; + std::map> grad_nodes_; + std::map forward_nodes_; +}; +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_constants.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_constants.h new file mode 100644 index 00000000..d4a47931 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_constants.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ + +namespace tensorflow { +namespace tfprof { + +// Op name of root of everything. Aggregates all stats. +static const char* const kTFProfRoot = "_TFProfRoot"; +// Op type for nodes that doesn't represent a physical node in the +// TensorFlow model. Only exist as a placehold to aggregate children. +// For example, kTFProfRoot belongs to this type. +static const char* const kTFGraphParent = "_TFGraphParent"; +static const char* const kTFScopeParent = "_kTFScopeParent"; +// Op type for tf.trainable_variables(). +static const char* const kTrainableVarType = "_trainable_variables"; +// Op type for tensors in the checkpoint file. +static const char* const kCkptVarType = "_checkpoint_variables"; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_graph.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_graph.h new file mode 100644 index 00000000..89ae0b37 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_graph.h @@ -0,0 +1,87 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Build a graph structure based on op inputs/outputs. The graph is a directed +// acyclic graph pointing *from outputs to inputs*. + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_show.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +// Organize tensorflow ops in a graph structure, pointing from output ops +// to input ops. +class TFGraph : public TFShow { + public: + explicit TFGraph(checkpoint::CheckpointReader* ckpt_reader) + : TFShow(ckpt_reader), root_(nullptr) {} + ~TFGraph() override = default; + + void AddNode(TFGraphNode* node) override; + + void Build() override; + + private: + const ShowNode* ShowInternal(const Options& opts, + Timeline* timeline) override; + + bool ShouldShowIfExtra(const ShowNode* node, const Options& opts, + int depth) const override { + return true; + } + + GraphNode* CreateParentNode(const string& name); + + std::vector SearchRoot(const std::vector& roots, + const std::vector& regexes, + std::set* visited); + + std::vector PrintGraph(std::vector roots, + const Options& opts, int depth, + int last_ident, std::set* visits); + + std::vector Account(const std::vector& roots, + const Options& opts, + std::set* visits); + + void Format(std::vector roots, string* display_str, + GraphNodeProto* proto); + + MemoryTracker memory_tracker_; + GraphNode* root_; + std::vector> node_defs_; + std::map> parent_nodes_; + std::map> nodes_map_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_node.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_node.h new file mode 100644 index 00000000..e0645654 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_node.h @@ -0,0 +1,920 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_description.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/profiler/tfprof_log.pb.h" +#include "tensorflow/core/profiler/tfprof_options.h" + +namespace tensorflow { +namespace tfprof { +std::vector ShapeProtoToVec(const TensorShapeProto& shape_pb); + +TensorShapeProto VecToShapeProto(const std::vector& shape_vec); + +class TFGraphNode; + +class CallStack { + public: + class Trace { + public: + Trace(const CodeDef::Trace* trace, + const std::map* id_to_string) + : trace_(trace), id_to_string_(id_to_string) {} + + int32 lineno() const { return trace_->lineno(); } + string file() const { + // Backward compatible with old proto files. + if (!trace_->file().empty()) return trace_->file(); + return id_to_string_->at(trace_->file_id()); + } + string function() const { + // Backward compatible with old proto files. + if (!trace_->function().empty()) return trace_->function(); + return id_to_string_->at(trace_->function_id()); + } + int32 func_start_line() const { return trace_->func_start_line(); } + + private: + const CodeDef::Trace* trace_; + const std::map* id_to_string_; + }; + + CallStack(const CodeDef& def, const std::map* id_to_string) + : def_(def) { + traces_.reserve(def.traces_size()); + for (const auto& t : def_.traces()) { + traces_.emplace_back(&t, id_to_string); + } + } + + const CodeDef& code_def() const { return def_; } + const std::vector& traces() const { return traces_; } + + private: + std::vector traces_; + CodeDef def_; +}; + +class ExecStep { + public: + ExecStep() = default; + + void AddTimeStats(const string& dev, const NodeExecStats& step_stat); + + void AddMemoryStats(const string& dev, const NodeExecStats& step_stat); + + int64_t run_count() const { return exec_.run_count(); } + // The execution time of an op. If it runs on accelerator, then it's + // accelerator_exec_micros(). Otherwise, it's CPU time. + int64_t exec_micros() const; + // The accelerator execution time of an op. 0 if not run on accelerator. + int64_t accelerator_exec_micros() const; + // The cpu execution time of an op. 
+ int64_t cpu_exec_micros() const; + + const std::map>>& op_execs() + const { + return op_execs_; + } + const std::map>>& cpu_execs() + const { + return cpu_execs_; + } + int64_t all_start_micros() const { return exec_.all_start_micros(); } + int64_t latest_end_micros() const { return exec_.latest_end_micros(); } + int64_t lastest_schedule_end_micros() const { + int64_t ret = 0; + for (const auto& exec : cpu_execs_) { + for (const auto& pair : exec.second) { + ret = std::max(ret, pair.first + pair.second); + } + } + return ret; + } + int64_t requested_bytes() const { + int64_t requested_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + requested_bytes += exec.requested_bytes(); + } + return requested_bytes; + } + int64_t peak_bytes() const { + int64_t peak_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + peak_bytes += exec.peak_bytes(); + } + return peak_bytes; + } + int64_t residual_bytes() const { + int64_t residual_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + residual_bytes += exec.residual_bytes(); + } + return residual_bytes; + } + int64_t output_bytes() const { + int64_t output_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + output_bytes += exec.output_bytes(); + } + return output_bytes; + } + int64_t accelerator_temp_bytes() const { + int64_t accelerator_temp_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + accelerator_temp_bytes += exec.accelerator_temp_bytes(); + } + return accelerator_temp_bytes; + } + int64_t host_temp_bytes() const { + int64_t host_temp_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + host_temp_bytes += exec.host_temp_bytes(); + } + return host_temp_bytes; + } + int64_t accelerator_persistent_bytes() const { + int64_t accelerator_persistent_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + accelerator_persistent_bytes += exec.accelerator_persistent_bytes(); + } + return accelerator_persistent_bytes; + } + int64_t host_persistent_bytes() const { + int64_t host_persistent_bytes = 0; + for (const ExecMemory& exec : memory_execs_) { + host_persistent_bytes += exec.host_persistent_bytes(); + } + return host_persistent_bytes; + } + std::map allocator_bytes_in_use() const { + std::map bytes_in_use; + for (const ExecMemory& exec : memory_execs_) { + bytes_in_use[exec.memory_micros()] = exec.allocator_bytes_in_use(); + } + return bytes_in_use; + } + + const std::vector& allocations() const { + return allocations_; + } + + const ExecProfile& ToProto() { + exec_.mutable_accelerator_execs()->clear(); + for (const auto& e : accelerator_execs_) { + auto& exec_time = (*exec_.mutable_accelerator_execs())[e.first]; + for (const auto& p : e.second) { + auto* t = exec_time.mutable_times()->Add(); + t->add_int64_values(p.first); + t->add_int64_values(p.second); + } + } + + exec_.mutable_cpu_execs()->clear(); + for (const auto& e : cpu_execs_) { + auto& exec_time = (*exec_.mutable_cpu_execs())[e.first]; + for (const auto& p : e.second) { + auto* t = exec_time.mutable_times()->Add(); + t->add_int64_values(p.first); + t->add_int64_values(p.second); + } + } + + exec_.mutable_devices()->Clear(); + exec_.mutable_devices()->Reserve(devices_.size()); + for (const string& d : devices_) { + exec_.add_devices(d); + } + exec_.mutable_allocations()->Clear(); + for (const auto& r : allocations_) { + exec_.add_allocations()->MergeFrom(r); + } + + exec_.mutable_memory_execs()->Clear(); + for (const auto& m : memory_execs_) { + exec_.add_memory_execs()->MergeFrom(m); + } + return exec_; + } + + void 
FromProto(const ExecProfile& exec) { + exec_.Clear(); + exec_.MergeFrom(exec); + + devices_.clear(); + devices_.insert(exec.devices().begin(), exec.devices().end()); + + accelerator_execs_.clear(); + cpu_execs_.clear(); + op_execs_.clear(); + + allocations_.clear(); + memory_execs_.clear(); + + for (const auto& exec_time : exec_.accelerator_execs()) { + auto& exec = accelerator_execs_[exec_time.first]; + auto& op_exec = op_execs_[exec_time.first]; + for (const auto& p : exec_time.second.times()) { + exec.push_back(std::make_pair(p.int64_values(0), p.int64_values(1))); + op_exec.push_back(std::make_pair(p.int64_values(0), p.int64_values(1))); + } + } + for (const auto& exec_time : exec_.cpu_execs()) { + auto& exec = cpu_execs_[exec_time.first]; + auto& op_exec = op_execs_[exec_time.first]; + for (const auto& p : exec_time.second.times()) { + exec.push_back(std::make_pair(p.int64_values(0), p.int64_values(1))); + op_exec.push_back(std::make_pair(p.int64_values(0), p.int64_values(1))); + } + } + for (const auto& r : exec_.allocations()) { + allocations_.push_back(r); + } + for (const auto& m : exec_.memory_execs()) { + memory_execs_.push_back(m); + } + } + + private: + ExecProfile exec_; + // device -> vector of {op_start_micros, op_exec_micros} pairs. + // accelerator_execs: gpu:id/stream:all -> {op_start_micros, op_exec_micros} + // For accelerator, vector size can be larger than 1, multiple kernel fires + // or in tf.while_loop. + std::map>> accelerator_execs_; + // cpu_execs: cpu/gpu:id -> {op_start_micros, op_exec_micros} + // For cpu, vector size can be larger than 1 if in tf.while_loop. + std::map>> cpu_execs_; + // combines accelerator_execs_ and cpu_execs_. + std::map>> op_execs_; + // Each ExecMemory corresponds to one scheduling of the op. Normally, + // there are multiple schedulings in while_loop. + std::vector memory_execs_; + // All devices the op is associated with (e.g. gpu:0 (scheduling), + // gpu:0:stream:xx (kernel exec), cpu:0 host) + std::set devices_; + + // The history of accelerator allocations and deallocations of this step. 
+ std::vector allocations_; +}; + +#define GRAPH_NODE_BYTES(type) \ + do { \ + if (execs_.empty()) { \ + return 0; \ + } \ + if (step >= 0) { \ + auto exec = execs_.find(step); \ + if (exec == execs_.end()) return 0; \ + return exec->second.type##_bytes(); \ + } \ + \ + int64_t bytes = 0; \ + for (const auto& exec : execs_) { \ + bytes += exec.second.type##_bytes(); \ + } \ + return bytes / execs_.size(); \ + } while (0) + +class TFGraphNode { + public: + TFGraphNode(const ProfileNode& node, const ProfileProto& profile, + const std::map* id_to_string, + const std::map>* nodes_map) { + nodes_map_ = nodes_map; + FromProto(node, profile, id_to_string); + } + + TFGraphNode(const NodeDef* node, int64_t id, + const std::map>* nodes_map) { + nodes_map_ = nodes_map; + node_.set_id(id); + node_.set_name(node->name()); + node_.set_op(node->op()); + node_.set_float_ops(0); + + for (const auto& attr : node->attr()) { + (*node_.mutable_attrs())[attr.first].MergeFrom(attr.second); + if (attr.first == "shape" && attr.second.has_shape()) { + if (!shape_.empty()) { + absl::FPrintF(stderr, "Found duplicated shapes!\n"); + continue; + } + shape_ = ShapeProtoToVec(attr.second.shape()); + } else if (attr.first == "_output_shapes" && attr.second.has_list()) { + if (!output_shapes_.empty()) { + absl::FPrintF(stderr, "Found duplicated output shapes!\n"); + continue; + } + for (int i = 0; i < attr.second.list().shape_size(); ++i) { + output_shapes_[i] = ShapeProtoToVec(attr.second.list().shape(i)); + } + } + } + op_types_.insert(node->op()); + } + + void AddInput(const string& input, int64_t output_index, int input_idx) { + inputs_[input_idx] = input; + src_output_idx_[input] = output_index; + } + + void AddOpType(const string& op_type) { op_types_.insert(op_type); } + + void AddStepStat(int64_t step, const string& device, + const NodeExecStats& step_stat); + + void AddFloatOps(int64_t float_ops) { node_.set_float_ops(float_ops); } + + // TODO(xpan): This could take a lot of memory. 
+ void AddCode(const CodeDef& code, + const std::map* id_to_string) { + if (!call_stack_) { + call_stack_ = std::make_unique(code, id_to_string); + } + } + + const string& name() const { return node_.name(); } + int64_t id() const { return node_.id(); } + const string& op() const { return node_.op(); } + const ProfileNode& node() { return node_; } + + bool trackable(int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) return false; + + if (exec->second.all_start_micros() == 0) return false; + if (node_.canonical_device().empty() || node_.host_device().empty()) { + return false; + } + return true; + } + + const ProfileNode& ToProto( + const std::map>& nodes_map) { + node_.clear_shape(); + node_.mutable_shape()->Reserve(shape().size()); + for (int64_t s : shape()) { + node_.add_shape(s); + } + + node_.clear_op_types(); + node_.mutable_op_types()->Reserve(op_types().size()); + for (const string& t : op_types()) { + node_.add_op_types(t); + } + + node_.clear_execs(); + for (auto& exec : execs_) { + auto& exec_pb = (*node_.mutable_execs())[exec.first]; + exec_pb.MergeFrom(exec.second.ToProto()); + } + + node_.clear_inputs(); + for (const auto& inp : inputs_) { + (*node_.mutable_inputs())[inp.first] = nodes_map.at(inp.second)->id(); + } + + node_.clear_input_shapes(); + for (const auto& s : input_shapes_) { + auto& shape = (*node_.mutable_input_shapes())[s.first]; + for (int64_t d : s.second) { + shape.add_int64_values(d); + } + } + + node_.clear_output_shapes(); + for (const auto& s : output_shapes_) { + auto& shape = (*node_.mutable_output_shapes())[s.first]; + for (int64_t d : s.second) { + shape.add_int64_values(d); + } + } + + node_.clear_src_output_index(); + for (const auto& s : src_output_idx_) { + int64_t id = nodes_map.at(s.first)->id(); + (*node_.mutable_src_output_index())[id] = s.second; + } + + if (call_stack_) { + node_.clear_trace(); + node_.mutable_trace()->MergeFrom(call_stack_->code_def()); + } + return node_; + } + + void FromProto(const ProfileNode& node, const ProfileProto& profile, + const std::map* id_to_string) { + node_.Clear(); + node_.MergeFrom(node); + + call_stack_ = std::make_unique(node.trace(), id_to_string); + + op_types_.clear(); + op_types_.insert(node_.op_types().begin(), node_.op_types().end()); + + shape_.clear(); + for (int64_t s : node_.shape()) { + shape_.push_back(s); + } + + execs_.clear(); + for (const auto& exec_pb : node.execs()) { + auto& exec = execs_[exec_pb.first]; + exec.FromProto(exec_pb.second); + } + + inputs_.clear(); + for (const auto& inp : node.inputs()) { + inputs_[inp.first] = profile.nodes().at(inp.second).name(); + } + + input_shapes_.clear(); + for (const auto& s : node.input_shapes()) { + auto& shape = input_shapes_[s.first]; + for (const int64_t d : s.second.int64_values()) { + shape.push_back(d); + } + } + + output_shapes_.clear(); + for (const auto& s : node.output_shapes()) { + auto& shape = output_shapes_[s.first]; + for (const int64_t d : s.second.int64_values()) { + shape.push_back(d); + } + } + + src_output_idx_.clear(); + for (const auto& s : node.src_output_index()) { + src_output_idx_[profile.nodes().at(s.first).name()] = s.second; + } + } + + const std::map& inputs() const { return inputs_; } + + // Number of times the graph node is executed. When step < 0, the + // average number of times executed across all steps. 
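+// (Worked example with assumed counts: if the node ran 2, 4 and 6 times in
+//  three recorded steps, run_count(-1) returns (2 + 4 + 6) / 3 = 4, a
+//  specific recorded step returns that step's own count, and a step with no
+//  recorded execution returns 0.)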
+ int64_t run_count(int64_t step) const { + if (execs_.empty()) { + return 0; + } + if (step >= 0) { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.run_count(); + } + int64_t total_run_count = 0; + for (const auto& exec : execs_) { + total_run_count += exec.second.run_count(); + } + return total_run_count / execs_.size(); + } + // This is overall computation time, including both cpu and accelerator. + // Note, cpu and accelerator might or might not run in parallel. + int64_t exec_micros(int64_t step) const { + // Empty when no RunMetadata is provided. + if (execs_.empty()) { + return 0; + } + if (step >= 0) { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.exec_micros(); + } + + int64_t total_micros = 0; + for (const auto& exec : execs_) { + total_micros += exec.second.exec_micros(); + } + return total_micros / execs_.size(); + } + + // This is accelerator computation time of a step, or average of + // multiple step, when step < 0. + int64_t accelerator_exec_micros(int64_t step) const { + // Empty when no RunMetadata is provided. + if (execs_.empty()) { + return 0; + } + if (step >= 0) { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.accelerator_exec_micros(); + } + + int64_t total_micros = 0; + for (const auto& exec : execs_) { + total_micros += exec.second.accelerator_exec_micros(); + } + return total_micros / execs_.size(); + } + + // This is cpu computation time of a step, or average of + // multiple step, when step < 0. + int64_t cpu_exec_micros(int64_t step) const { + // Empty when no RunMetadata is provided. + if (execs_.empty()) { + return 0; + } + if (step >= 0) { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.cpu_exec_micros(); + } + + int64_t total_micros = 0; + for (const auto& exec : execs_) { + total_micros += exec.second.cpu_exec_micros(); + } + return total_micros / execs_.size(); + } + + int64_t requested_bytes(int64_t step) const { GRAPH_NODE_BYTES(requested); } + int64_t peak_bytes(int64_t step) const { GRAPH_NODE_BYTES(peak); } + int64_t residual_bytes(int64_t step) const { GRAPH_NODE_BYTES(residual); } + int64_t output_bytes(int64_t step) const { GRAPH_NODE_BYTES(output); } + + int64_t all_start_micros(int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.all_start_micros(); + } + + int64_t latest_end_micros(int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.latest_end_micros(); + } + + int64_t lastest_schedule_end_micros(int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.lastest_schedule_end_micros(); + } + + const std::map>>& op_execs( + int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return empty_execs_; + } + return exec->second.op_execs(); + } + const std::map>>& cpu_execs( + int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return empty_execs_; + } + return exec->second.cpu_execs(); + } + + const std::map& all_op_execs() const { return execs_; } + + int64_t accelerator_temp_bytes(int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.accelerator_temp_bytes(); + } + int64_t host_temp_bytes(int64_t step) 
const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return 0; + } + return exec->second.host_temp_bytes(); + } + int64_t accelerator_persistent_bytes() const { + int64_t persistent_bytes = 0; + for (const auto& exec : execs_) { + persistent_bytes = std::max(persistent_bytes, + exec.second.accelerator_persistent_bytes()); + } + return persistent_bytes; + } + std::map allocator_bytes_in_use(int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return empty_bytes_in_use_; + } + return exec->second.allocator_bytes_in_use(); + } + + const std::vector& allocations(int64_t step) const { + auto exec = execs_.find(step); + if (exec == execs_.end()) { + return empty_allocations_; + } + return exec->second.allocations(); + } + + int64_t parameters() const { + if (!shape().empty()) { + int64_t params = 1; + bool complete_shape = true; + for (int64_t d : shape()) { + // Sometimes parameters could be <0 when a dim is unknown. + if (d < 0) { + complete_shape = false; + break; + } + params *= d; + } + if (complete_shape) { + return params; + } else { + LOG(INFO) << "Incomplete shape.\n"; + } + } + return 0; + } + + int64_t float_ops(int64_t step) const { + // If not run, return static analysis. + if (execs_.empty()) { + return node_.float_ops(); + } + // Otherwise, return dynamic float_ops. + return node_.float_ops() * run_count(step); + } + const CallStack* call_stack() { return call_stack_.get(); } + string canonical_device() const { return node_.canonical_device(); } + string host_device() const { return node_.host_device(); } + const std::set& op_types() const { return op_types_; } + + const AttrValue* op_attrs(const string& name) const { + const auto it = node_.attrs().find(name); + if (it == node_.attrs().end()) { + return nullptr; + } + return &it->second; + } + + const std::vector& shape() const { return shape_; } + + const std::map>& output_shapes() const { + return output_shapes_; + } + + std::map> input_shapes() const { + std::map> input_shapes; + for (const auto& inp : inputs_) { + // Always create an empty vec even if the shape info might be missing. + std::vector& shape_vec = input_shapes[inp.first]; + if (!nodes_map_) continue; + auto input_it = nodes_map_->find(inp.second); + if (input_it == nodes_map_->end()) continue; + auto output_it = src_output_idx_.find(inp.second); + if (output_it == src_output_idx_.end()) continue; + + const TFGraphNode* input_node = input_it->second.get(); + if (!input_node) continue; + const auto& output_shapes = input_node->output_shapes(); + const auto& output_shape = output_shapes.find(output_it->second); + if (output_shape == output_shapes.end()) continue; + + if (output_shape != input_node->output_shapes().end()) { + shape_vec.assign(output_shape->second.begin(), + output_shape->second.end()); + } + } + return input_shapes; + } + + private: + // maps graph node name to TFGraphNode. Not owned. + const std::map>* nodes_map_; + // inputs to the node. input index -> input node name. + std::map inputs_; + // The output index of the source node. + std::map src_output_idx_; + // proto for serialize/deserialized representation of the node. + ProfileNode node_; + // Python call stack that creates the name. + std::unique_ptr call_stack_; + // Shape of the node (e.g. Variable) if available. + std::vector shape_; + // Won't missing input_idx. But some shapes might be empty (unknown). + std::map> input_shapes_; + // Could miss output_idx if no _output_shapes attr. some shapes can also + // be empty. 
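+// (Regarding input_shapes() above, a sketch of the lookup with a made-up node
+//  name: for an edge recorded as inputs_[0] = "conv1" and
+//  src_output_idx_["conv1"] = 0, the shape is copied from
+//  nodes_map_->at("conv1")->output_shapes() at output index 0; if any link in
+//  that chain (node, source index, or the producer's output shapes) is
+//  missing, the input shape simply stays an empty vector.)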
+ std::map> output_shapes_; + + std::set op_types_; + + std::map execs_; + + // Placeholder for empty cases. + std::map empty_bytes_in_use_; + std::map>> empty_execs_; + std::vector empty_allocations_; +}; + +class TFMultiGraphNode { + public: + TFMultiGraphNode(const string& name) + : name_(name), + step_(-1), + run_count_(0), + exec_micros_(0), + accelerator_exec_micros_(0), + cpu_exec_micros_(0), + requested_bytes_(0), + peak_bytes_(0), + residual_bytes_(0), + output_bytes_(0), + float_ops_(0), + parameters_(0) {} + + bool SnapshotNodes(int64_t step, const std::vector& type_regexes) { + run_count_ = 0; + exec_micros_ = 0; + accelerator_exec_micros_ = 0; + cpu_exec_micros_ = 0; + + requested_bytes_ = 0; + peak_bytes_ = 0; + residual_bytes_ = 0; + output_bytes_ = 0; + + float_ops_ = 0; + parameters_ = 0; + op_types_.clear(); + shapes_.clear(); + devices_.clear(); + snapshot_nodes_.clear(); + + step_ = step; + std::vector nodes = pick_nodes(type_regexes); + + if (nodes.empty()) { + return (type_regexes.size() == 1 && type_regexes[0] == ".*"); + } + + for (const TFGraphNode* node : nodes) { + op_types_.insert(node->op_types().begin(), node->op_types().end()); + + run_count_ += node->run_count(step); + exec_micros_ += node->exec_micros(step); + accelerator_exec_micros_ += node->accelerator_exec_micros(step); + cpu_exec_micros_ += node->cpu_exec_micros(step); + + requested_bytes_ += node->requested_bytes(step); + peak_bytes_ += node->peak_bytes(step); + residual_bytes_ += node->residual_bytes(step); + output_bytes_ += node->output_bytes(step); + + float_ops_ += node->float_ops(step); + parameters_ += node->parameters(); + if (!node->shape().empty()) { + shapes_.push_back(node->shape()); + } + devices_.insert(node->canonical_device()); + snapshot_nodes_[node->name()] = node; + } + return true; + } + + int64_t step() const { return step_; } + + void AddGraphNode(const TFGraphNode* node) { + if (nodes_.find(node->name()) != nodes_.end()) { + return; + } + nodes_[node->name()] = node; + } + + const std::map& graph_nodes() const { + return snapshot_nodes_; + } + + const string& name() const { return name_; } + + int64_t run_count() const { return run_count_; } + int64_t exec_micros() const { return exec_micros_; } + int64_t accelerator_exec_micros() const { return accelerator_exec_micros_; } + int64_t cpu_exec_micros() const { return cpu_exec_micros_; } + + int64_t requested_bytes() const { return requested_bytes_; } + int64_t peak_bytes() const { return peak_bytes_; } + int64_t residual_bytes() const { return residual_bytes_; } + int64_t output_bytes() const { return output_bytes_; } + + int64_t float_ops() const { return float_ops_; } + + int64_t parameters() const { return parameters_; } + + const std::set& devices() const { return devices_; } + + const std::set& op_types() const { return op_types_; } + + const std::vector>& shapes() const { return shapes_; } + + private: + std::vector pick_nodes( + const std::vector& type_regexes) { + if (type_regexes.empty()) { + return {}; + } + std::vector ret; + if (type_regexes.size() == 1 && type_regexes[0] == ".*") { + for (const auto& n : nodes_) { + ret.push_back(n.second); + } + return ret; + } + + for (const string& regex : type_regexes) { + for (const auto& n : nodes_) { + for (const string& type : n.second->op_types()) { + if (RE2::FullMatch(type, regex)) { + ret.push_back(n.second); + break; + } + } + } + } + return ret; + } + + const string name_; + int64_t step_; + // Snapshot based on type_regexes + std::set op_types_; + int64_t run_count_; + 
int64_t exec_micros_; + int64_t accelerator_exec_micros_; + int64_t cpu_exec_micros_; + + int64_t requested_bytes_; + int64_t peak_bytes_; + int64_t residual_bytes_; + int64_t output_bytes_; + int64_t float_ops_; + int64_t parameters_; + std::set devices_; + std::vector> shapes_; + std::map snapshot_nodes_; + + // Overall data held by the TFMultiGraphNode. + std::map nodes_; +}; + +bool IsPlacedOnCPU(const string& device); +bool IsPlacedOnAccelerator(const string& device); +bool CountAsAcceleratorTime(const string& device); +bool CountAsCPUTime(const string& device); +bool IsCanonicalDevice(const string& device); + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_node_show.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_node_show.h new file mode 100644 index 00000000..e3d4b86a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_node_show.h @@ -0,0 +1,160 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Node classes used for different views. They are wrappers with "show" +// methods. +// +// ScopeNode is for scope view. GraphNode is for graph view, CodeNode +// is for code view and OpNode for op view. +// ScopeNode and GraphNode each maps to one TFGraphNode. +// CodeNode and OpNode each maps to one TFMultiGraphNode. 
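The pick_nodes() helper above selects nodes whose op types match any of the requested regexes, with the single pattern ".*" short-circuiting to "take everything". A standalone sketch of that selection logic follows; `FakeNode` is a toy stand-in for TFGraphNode, and std::regex replaces the RE2 dependency purely to keep the example self-contained:

```cpp
#include <iostream>
#include <map>
#include <regex>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for TFGraphNode: just a name plus its op types.
struct FakeNode {
  std::string name;
  std::set<std::string> op_types;
};

std::vector<const FakeNode*> PickNodes(
    const std::map<std::string, const FakeNode*>& nodes,
    const std::vector<std::string>& type_regexes) {
  std::vector<const FakeNode*> picked;
  if (type_regexes.empty()) return picked;
  // ".*" alone means "every node", without running the regex engine at all.
  if (type_regexes.size() == 1 && type_regexes[0] == ".*") {
    for (const auto& n : nodes) picked.push_back(n.second);
    return picked;
  }
  for (const std::string& pattern : type_regexes) {
    std::regex re(pattern);
    for (const auto& n : nodes) {
      for (const std::string& type : n.second->op_types) {
        if (std::regex_match(type, re)) {  // full match, like RE2::FullMatch
          picked.push_back(n.second);
          break;  // One matching op type is enough for this node.
        }
      }
    }
  }
  return picked;
}

int main() {
  FakeNode a{"conv1", {"Conv2D", "_trainable_variables"}};
  FakeNode b{"loss", {"Mean"}};
  std::map<std::string, const FakeNode*> nodes = {{a.name, &a}, {b.name, &b}};
  for (const FakeNode* n : PickNodes(nodes, {"Conv.*"})) {
    std::cout << n->name << "\n";  // prints: conv1
  }
}
```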
+ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_constants.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +class ShowNode { + public: + explicit ShowNode(const TFGraphNode* node); + virtual ~ShowNode() = default; + + const string& name() const { return node->name(); } + GraphNodeProto* mutable_proto(); + const GraphNodeProto& proto() const; + + void ReInit(int64_t step); + + void AggregateTotalStats(ShowNode* node); + + void AddSelfToTotalStats(); + + void ResetTotalStats(); + + const TFGraphNode* node; + bool account; + string formatted_str; + + protected: + GraphNodeProto proto_; +}; + +class GraphNode : public ShowNode { + public: + explicit GraphNode(TFGraphNode* node) : ShowNode(node) {} + + bool Trackable(int64_t step) const { return node->trackable(step); } + + std::vector children; + std::vector show_children; +}; + +class ScopeNode : public ShowNode { + public: + explicit ScopeNode(const TFGraphNode* node) : ShowNode(node) {} + ~ScopeNode() override = default; + + std::vector children; + std::vector show_children; +}; + +class ShowMultiNode { + public: + explicit ShowMultiNode(TFMultiGraphNode* node); + virtual ~ShowMultiNode() = default; + + bool ReInit(int64_t step, const std::vector& type_regexes); + + const string& name() const { return node->name(); } + MultiGraphNodeProto* mutable_proto(); + const MultiGraphNodeProto& proto() const; + + void AggregateTotalStats(ShowMultiNode* node); + + void AddSelfToTotalStats(); + + void ResetTotalStats(); + + TFMultiGraphNode* node; + bool account; + bool show; + string formatted_str; + + protected: + MultiGraphNodeProto proto_; +}; + +class CodeNode : public ShowMultiNode { + public: + CodeNode(TFMultiGraphNode* node, const CallStack::Trace* trace, + const string& suffix) + : ShowMultiNode(node), trace_(trace), suffix_(suffix) {} + ~CodeNode() override = default; + + CodeNode* AddChildren(const string& name, const CallStack::Trace* trace, + const string suffix) { + auto it = children_.find(name); + if (it != children_.end()) { + return it->second.get(); + } + + graph_children_.push_back(std::make_unique(name)); + auto child = &children_[name]; + *child = + std::make_unique(graph_children_.back().get(), trace, suffix); + children.push_back(child->get()); + return child->get(); + } + + bool has_trace() const { return trace_ != nullptr; } + int32 lineno() const { return trace_->lineno(); } + string file() const { return trace_->file(); } + string function() const { return trace_->function() + suffix_; } + int32 func_start_line() const { return trace_->func_start_line(); } + + std::vector children; + std::vector show_children; + + private: + const CallStack::Trace* trace_; + string suffix_; + std::vector> graph_children_; + std::map> children_; +}; + +class OpNode : public ShowMultiNode { + public: + explicit OpNode(TFMultiGraphNode* node) : ShowMultiNode(node) {} + ~OpNode() override = default; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_NODE_SHOW_H_ diff --git 
a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_op.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_op.h new file mode 100644 index 00000000..0aa4887e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_op.h @@ -0,0 +1,77 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Build a flat structure of ops. + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_show_multi.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +// Organize tensorflow ops in a graph structure, pointing from output ops +// to input ops. +class TFOp : public TFMultiShow { + public: + explicit TFOp() : TFMultiShow() {} + ~TFOp() override = default; + + void AddNode(TFGraphNode* node) override; + + void Build() override; + + private: + const ShowMultiNode* ShowInternal(const Options& opts, + Timeline* timeline) override; + + int64_t SearchRoot(std::vector nodes, + const std::vector& regexes); + + bool ShouldShowIfExtra(const ShowMultiNode* node, const Options& opts, + int depth) const override { + const int max_num_graph_nodes = node->node->graph_nodes().size(); + if (opts.min_occurrence > max_num_graph_nodes) { + return false; + } + return true; + } + + string FormatNode(OpNode* node, OpNode* root, const Options& opts) const; + string FormatMemoryNode(int64_t node_total_bytes, int64_t root_total_bytes, + int64_t node_bytes) const; + + std::unique_ptr root_; + std::map> cnodes_map_; + std::map> tfcnodes_map_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_scope.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_scope.h new file mode 100644 index 00000000..ede6d633 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_scope.h @@ -0,0 +1,76 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Build a tree structure based on the TensorFlow op names. +// For example, 'name1/name2' is a child of 'name1'. +// Stats are aggregated from descendants to ancestors. + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/checkpoint_reader.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_show.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +class TFScope : public TFShow { + public: + explicit TFScope(checkpoint::CheckpointReader* ckpt_reader) + : TFShow(ckpt_reader), root_(nullptr) {} + ~TFScope() override = default; + + void AddNode(TFGraphNode* node) override; + + void Build() override; + + private: + const ShowNode* ShowInternal(const Options& opts, + Timeline* timeline) override; + + ScopeNode* CreateParentNode(const string& name); + + std::vector SearchRoot(std::vector roots, + const std::vector& regexes); + + std::vector PrintScope(std::vector roots, + const Options& opts, int depth, + int last_ident); + + std::vector Account(const std::vector& roots, + const Options& opts); + + void Format(std::vector roots, string* display_str, + GraphNodeProto* proto); + + ScopeNode* root_; + std::vector> node_defs_; + std::map> parent_nodes_; + std::map> nodes_map_; +}; +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SCOPE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_show.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_show.h new file mode 100644 index 00000000..ef713cbe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_show.h @@ -0,0 +1,157 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Parent class and utilities for tfprof_graph and tfprof_scope. 
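The scope view declared above (TFScope) derives its tree purely from the `/`-separated op names, so 'name1/name2' hangs under 'name1', and then rolls stats up from descendants to ancestors. A small standalone sketch of that idea, with a toy `Scope` struct rather than the ScopeNode/TFGraphNode types:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Toy scope node: children keyed by the next path component.
struct Scope {
  int64_t self_micros = 0;   // stats recorded directly on this name
  int64_t total_micros = 0;  // self plus everything below it
  std::map<std::string, std::unique_ptr<Scope>> children;
};

// Insert "name1/name2/..." into the tree, creating parent scopes as needed.
void Insert(Scope* root, const std::string& name, int64_t micros) {
  Scope* cur = root;
  std::size_t start = 0;
  while (start <= name.size()) {
    std::size_t slash = name.find('/', start);
    std::string part = name.substr(
        start, slash == std::string::npos ? std::string::npos : slash - start);
    auto& child = cur->children[part];
    if (!child) child = std::make_unique<Scope>();
    cur = child.get();
    if (slash == std::string::npos) break;
    start = slash + 1;
  }
  cur->self_micros += micros;
}

// Aggregate stats from descendants to ancestors (post-order walk).
int64_t Aggregate(Scope* node) {
  node->total_micros = node->self_micros;
  for (auto& child : node->children) {
    node->total_micros += Aggregate(child.second.get());
  }
  return node->total_micros;
}

int main() {
  Scope root;
  Insert(&root, "name1/name2", 30);
  Insert(&root, "name1/name3", 20);
  Insert(&root, "name1", 5);
  Aggregate(&root);
  std::cout << root.children["name1"]->total_micros << "\n";  // 55
}
```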
+ +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/checkpoint_reader.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_constants.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_node_show.h" +#include "tensorflow/core/profiler/internal/tfprof_tensor.h" +#include "tensorflow/core/profiler/internal/tfprof_timeline.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { +class TFShow { + public: + explicit TFShow(checkpoint::CheckpointReader* ckpt_reader) + : ckpt_reader_(ckpt_reader) {} + virtual ~TFShow() = default; + virtual void AddNode(TFGraphNode* node) = 0; + virtual void Build() = 0; + virtual const GraphNodeProto& Show(const string& prefix, + const Options& opts) final; + + protected: + virtual const ShowNode* ShowInternal(const Options& opts, + Timeline* timeline) = 0; + + bool LookUpCheckPoint(const string& name, + std::unique_ptr* tensor); + + // Overridden by subclass if extra requirements need to be met. + virtual bool ShouldShowIfExtra(const ShowNode* node, const Options& opts, + int depth) const { + return true; + } + + bool ShouldShow(const ShowNode* node, const Options& opts, int depth) const; + + bool ShouldTrim(const ShowNode* node, + const std::vector& regexes) const; + + bool ReAccount(ShowNode* node, const Options& opts); + + string FormatNode(ShowNode* node, const Options& opts) const; + string FormatNodeMemory(ShowNode* node, int64_t bytes, + int64_t total_bytes) const; + + string FormatLegend(const Options& opts) const; + + template + std::vector SortNodes(const std::vector& nodes, const Options& opts) { + if (opts.order_by.empty() || nodes.empty()) { + return nodes; + } + std::vector sorted_nodes = nodes; + std::stable_sort(sorted_nodes.begin(), sorted_nodes.end(), + [&opts](const T* n1, const T* n2) { + if (n1->name() == kTFProfRoot) return true; + if (n2->name() == kTFProfRoot) return false; + bool name_cmp = n1->name() < n2->name(); + if (opts.order_by == kOrderBy[0]) { + return name_cmp; + } else if (opts.order_by == kOrderBy[1]) { + return n1->proto().total_requested_bytes() > + n2->proto().total_requested_bytes(); + } else if (opts.order_by == kOrderBy[2]) { + return n1->proto().total_peak_bytes() > + n2->proto().total_peak_bytes(); + } else if (opts.order_by == kOrderBy[3]) { + return n1->proto().total_residual_bytes() > + n2->proto().total_residual_bytes(); + } else if (opts.order_by == kOrderBy[4]) { + return n1->proto().total_output_bytes() > + n2->proto().total_output_bytes(); + } else if (opts.order_by == kOrderBy[5]) { + return n1->proto().total_exec_micros() > + n2->proto().total_exec_micros(); + } else if (opts.order_by == kOrderBy[6]) { + return n1->proto().total_accelerator_exec_micros() > + n2->proto().total_accelerator_exec_micros(); + } else if (opts.order_by == kOrderBy[7]) { + return n1->proto().total_cpu_exec_micros() > + n2->proto().total_cpu_exec_micros(); + } else if (opts.order_by == kOrderBy[8]) { + return n1->proto().total_parameters() > + n2->proto().total_parameters(); + } else if (opts.order_by == kOrderBy[9]) { + return n1->proto().total_float_ops() > + n2->proto().total_float_ops(); + } + return name_cmp; + }); + return 
sorted_nodes; + } + + checkpoint::CheckpointReader* ckpt_reader_; +}; + +template +string FormatTotalExecTime(const T* node, const Options& opts) { + string time = FormatTime(node->proto().total_exec_micros()); + if (node->account) { + time = FormatTime(node->proto().exec_micros()) + "/" + time; + } else { + time = "--/" + time; + } + return time; +} +template +string FormatCPUExecTime(const T* node, const Options& opts) { + string time = FormatTime(node->proto().total_cpu_exec_micros()); + if (node->account) { + time = FormatTime(node->proto().cpu_exec_micros()) + "/" + time; + } else { + time = "--/" + time; + } + return time; +} +template +string FormatAcceleratorExecTime(const T* node, const Options& opts) { + string time = FormatTime(node->proto().total_accelerator_exec_micros()); + if (node->account) { + time = FormatTime(node->proto().accelerator_exec_micros()) + "/" + time; + } else { + time = "--/" + time; + } + return time; +} +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_show_multi.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_show_multi.h new file mode 100644 index 00000000..1f424dd0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_show_multi.h @@ -0,0 +1,127 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Parent class and utilities for tfprof_code. + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_constants.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_node_show.h" +#include "tensorflow/core/profiler/internal/tfprof_show.h" +#include "tensorflow/core/profiler/internal/tfprof_tensor.h" +#include "tensorflow/core/profiler/internal/tfprof_timeline.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +class TFMultiShow { + public: + explicit TFMultiShow() = default; + virtual ~TFMultiShow() = default; + virtual void AddNode(TFGraphNode* node) = 0; + virtual void Build() = 0; + const MultiGraphNodeProto& Show(const string& prefix, const Options& opts); + + protected: + virtual const ShowMultiNode* ShowInternal(const Options& opts, + Timeline* timeline) = 0; + + bool LookUpCheckPoint(const string& name, + std::unique_ptr* tensor); + + // Overridden by subclass if extra requirements need to be met. 
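The SortNodes() template above is essentially a stable sort whose comparator is picked by the `-order_by` option: the profiler root is pinned to the front, the metric columns sort descending, and name order is both the `name` key and the fallback. A condensed standalone sketch of that dispatch, with `ToyNode` standing in for the proto-backed ShowNode/ShowMultiNode and only two of the keys:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  int64_t exec_micros = 0;
  int64_t requested_bytes = 0;
};

const char* const kRoot = "_TFProfRoot";

std::vector<ToyNode> SortNodes(std::vector<ToyNode> nodes,
                               const std::string& order_by) {
  std::stable_sort(
      nodes.begin(), nodes.end(),
      [&order_by](const ToyNode& a, const ToyNode& b) {
        if (a.name == kRoot) return true;   // root always sorts first
        if (b.name == kRoot) return false;
        if (order_by == "micros") return a.exec_micros > b.exec_micros;
        if (order_by == "bytes") return a.requested_bytes > b.requested_bytes;
        return a.name < b.name;             // default: ascending by name
      });
  return nodes;
}

int main() {
  std::vector<ToyNode> nodes = {{"b", 10, 100}, {kRoot, 0, 0}, {"a", 50, 10}};
  for (const ToyNode& n : SortNodes(nodes, "micros")) std::cout << n.name << " ";
  std::cout << "\n";  // _TFProfRoot a b
}
```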
+ virtual bool ShouldShowIfExtra(const ShowMultiNode* node, const Options& opts, + int depth) const { + return true; + } + + bool ShouldShow(const ShowMultiNode* node, const Options& opts, + int depth) const; + + bool ShouldTrim(const ShowMultiNode* node, + const std::vector& regexes) const; + + bool ReAccount(ShowMultiNode* node, const Options& opts); + + string FormatLegend(const Options& opts) const; + string FormatInputShapes(const MultiGraphNodeProto& proto) const; + std::vector FormatTimes(const ShowMultiNode* node, + const Options& opts) const; + + template + std::vector SortNodes(const std::vector& nodes, const Options& opts) { + if (opts.order_by.empty() || nodes.empty()) { + return nodes; + } + std::vector sorted_nodes = nodes; + std::stable_sort(sorted_nodes.begin(), sorted_nodes.end(), + [&opts](const T* n1, const T* n2) { + if (n1->name() == kTFProfRoot) return true; + if (n2->name() == kTFProfRoot) return false; + bool name_cmp = n1->name() < n2->name(); + if (opts.order_by == kOrderBy[0]) { + return name_cmp; + } else if (opts.order_by == kOrderBy[1]) { + return n1->proto().total_requested_bytes() > + n2->proto().total_requested_bytes(); + } else if (opts.order_by == kOrderBy[2]) { + return n1->proto().total_peak_bytes() > + n2->proto().total_peak_bytes(); + } else if (opts.order_by == kOrderBy[3]) { + return n1->proto().total_residual_bytes() > + n2->proto().total_residual_bytes(); + } else if (opts.order_by == kOrderBy[4]) { + return n1->proto().total_output_bytes() > + n2->proto().total_output_bytes(); + } else if (opts.order_by == kOrderBy[5]) { + return n1->proto().total_exec_micros() > + n2->proto().total_exec_micros(); + } else if (opts.order_by == kOrderBy[6]) { + return n1->proto().total_accelerator_exec_micros() > + n2->proto().total_accelerator_exec_micros(); + } else if (opts.order_by == kOrderBy[7]) { + return n1->proto().total_cpu_exec_micros() > + n2->proto().total_cpu_exec_micros(); + } else if (opts.order_by == kOrderBy[8]) { + return n1->proto().total_parameters() > + n2->proto().total_parameters(); + } else if (opts.order_by == kOrderBy[9]) { + return n1->proto().total_float_ops() > + n2->proto().total_float_ops(); + } else if (opts.order_by == kOrderBy[10]) { + return n1->node->graph_nodes().size() > + n2->node->graph_nodes().size(); + } + return name_cmp; + }); + return sorted_nodes; + } +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_SHOW_MULTI_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_stats.h new file mode 100644 index 00000000..67cbdf56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_stats.h @@ -0,0 +1,127 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Core API of tfprof. +// 1. 
Load protos generated from a tensorflow model. +// 2. Build in-memory representations of the tensorflow model, annotate the +// representation with various stats, such as params,times,memory,etc. +// 3. Accept command and options to selectively aggregate stats for analysis +// and print out the results. + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ + +#include +#include +#include +#include + +#include "tensorflow/c/checkpoint_reader.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/profiler/internal/tfprof_code.h" +#include "tensorflow/core/profiler/internal/tfprof_graph.h" +#include "tensorflow/core/profiler/internal/tfprof_node.h" +#include "tensorflow/core/profiler/internal/tfprof_op.h" +#include "tensorflow/core/profiler/internal/tfprof_scope.h" +#include "tensorflow/core/profiler/internal/tfprof_show.h" +#include "tensorflow/core/profiler/internal/tfprof_utils.h" +#include "tensorflow/core/profiler/tfprof_log.pb.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { +namespace tfprof { + +class TFStats { + public: + TFStats(std::unique_ptr graph, + std::unique_ptr run_meta, + std::unique_ptr op_log, + std::unique_ptr ckpt_reader); + + TFStats(const string& filename, + std::unique_ptr ckpt_reader); + + ~TFStats() = default; + + const std::map>& nodes() const { + return nodes_map_; + } + const std::set& steps() const { return steps_; } + bool has_code_traces() const { return has_code_traces_; } + double run_coverage() const { + return covered_nodes_.size() / (nodes_map_.size() + 1e-10); + } + + void BuildView(const string& cmd); + void BuildAllViews(); + + // Note: Must first BuildView(view_foo) before ShowXXX(view_foo) methods. + // + // Organize the TensorFlow model as different types of views, and generate + // outputs for profiling. + // TODO(xpan): Should it return reference here? + const GraphNodeProto& ShowGraphNode(const string& cmd, + const Options& opts) const; + const MultiGraphNodeProto& ShowMultiGraphNode(const string& cmd, + const Options& opts) const; + + // Add a (partial) graph to existing graph. + void AddGraph(std::unique_ptr graph); + + // Add a step of run time meta data. + void AddRunMeta(int64_t step, std::unique_ptr run_meta); + // Add tfprof operation meta data, such as customized op type, float_ops, + // and code traces. + void AddOpLogProto(std::unique_ptr op_log); + + void SerializeToString(string* content); + void WriteProfile(const string& filename); + + // For test purpose only. + void AddNodeForTest(int64_t step, std::unique_ptr node); + + private: + bool Validate(const Options& opts) const; + string MaybeReportMissingTrace() const; + + std::set steps_; + bool has_code_traces_; + bool miss_accelerator_stream_; + std::unique_ptr scope_view_; + std::unique_ptr graph_view_; + std::unique_ptr code_view_; + std::unique_ptr op_view_; + std::unique_ptr ckpt_reader_; + // TODO(xpan): Store TFGraphNode instead of TFGraphNode* to avoid large + // number of dynamic alloc. + // Maps from graph node name to TFGraphNode. 
+ std::map> nodes_map_; + GraphNodeProto empty_graph_node_; + MultiGraphNodeProto empty_multi_graph_node_; + + std::map id_to_string_; + // Graph nodes covered by RunMetadata, that is traced with run time stats. + std::set covered_nodes_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_tensor.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_tensor.h new file mode 100644 index 00000000..4a04b005 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_tensor.h @@ -0,0 +1,175 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// TFProf representation of a Tensor's value. +// 1. Multi-dimension tensor is flattened in row major, and stored in proto. +// 2. integer are up-casted to int64. floats are up-casted to double. string +// is not supported by TensorFlow CheckPointReader library, though it is +// supported in current code. + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ + +#include +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +class TFProfTensor { + public: + explicit TFProfTensor(std::unique_ptr tensor) + : tensor_(std::move(tensor)) { + Build(); + } + + // If pointers are provided, they are filled by the method. + void Display(string* formatted_str, TFProfTensorProto* tfprof_tensor_pb); + + private: + // Max length of tensor value displayed to CLI. + const int64_t kTFProfTenosrMaxDisplayLen = 10000; + // Max length after which a latency warning will be printed. 
+ const int64_t kTFProfTensorMaxWarnLen = 100000; + + void Build(); + + template + bool AddValue(const T& value, TFProfTensorProto* dim) { + std::ostringstream sstream; + sstream << value; + if (typeid(value) == typeid(double)) { + double double_val = 0.0; + CHECK(absl::SimpleAtod(sstream.str(), &double_val)); // Crash OK + dim->add_value_double(double_val); + absl::StrAppendFormat(&formatted_str_, "%.2f ", + dim->value_double(dim->value_double_size() - 1)); + } else if (typeid(value) == typeid(int64_t)) { + int64_t int64_val = 0; + CHECK(absl::SimpleAtoi(sstream.str(), &int64_val)); // Crash OK + dim->add_value_int64(int64_val); + absl::StrAppendFormat(&formatted_str_, "%d ", + dim->value_int64(dim->value_int64_size() - 1)); + } else if (typeid(value) == typeid(string)) { + dim->add_value_str(sstream.str()); + absl::StrAppend(&formatted_str_, "'", + dim->value_str(dim->value_str_size() - 1), "' "); + } else { + CHECK(false) << "Unsupported type: " << typeid(value).name(); + } + } + + // It assumes the flatten values are stored in row-major, which is mentioned + // indirectly at various places: + // TODO(xpan): Further verifying it. + template + int64_t BuildOutput(int64_t start, int depth, const std::vector& values, + TFProfTensorProto* dim) { + formatted_str_ += "["; + int64_t nstart = start; + if (tensor_->dims() == 0 && values.size() == 1) { + std::ostringstream sstream; + sstream << values[nstart]; + + if (typeid(values[nstart]) == typeid(double)) { + double double_val = 0.0; + CHECK(absl::SimpleAtod(sstream.str(), &double_val)); // Crash OK + dim->add_value_double(double_val); + absl::StrAppendFormat(&formatted_str_, "%.2f ", + dim->value_double(dim->value_double_size() - 1)); + } else if (typeid(values[nstart]) == typeid(int64_t)) { + int64_t int64_val = 0; + CHECK(absl::SimpleAtoi(sstream.str(), &int64_val)); // Crash OK + dim->add_value_int64(int64_val); + absl::StrAppendFormat(&formatted_str_, "%d ", + dim->value_int64(dim->value_int64_size() - 1)); + } else if (typeid(values[nstart]) == typeid(string)) { + dim->add_value_str(sstream.str()); + absl::StrAppend(&formatted_str_, "'", + dim->value_str(dim->value_str_size() - 1), "' "); + } else { + CHECK(false) << "Unsupported type: " << typeid(values[nstart]).name(); + } + } else { + for (int i = 0; i < tensor_->dim_size(depth); i++) { + // Last dimension, pull the values. + if (depth == tensor_->dims() - 1) { + std::ostringstream sstream; + sstream << values[nstart]; + + if (typeid(values[nstart]) == typeid(double)) { + double double_val = 0.0; + CHECK(absl::SimpleAtod(sstream.str(), &double_val)); // Crash OK + dim->add_value_double(double_val); + absl::StrAppendFormat( + &formatted_str_, "%.2f ", + dim->value_double(dim->value_double_size() - 1)); + } else if (typeid(values[nstart]) == typeid(int64_t)) { + int64_t int64_val = 0; + CHECK(absl::SimpleAtoi(sstream.str(), &int64_val)); // Crash OK + dim->add_value_int64(int64_val); + absl::StrAppendFormat( + &formatted_str_, "%d ", + dim->value_int64(dim->value_int64_size() - 1)); + } else if (typeid(values[nstart]) == typeid(string)) { + dim->add_value_str(sstream.str()); + absl::StrAppend(&formatted_str_, "'", + dim->value_str(dim->value_str_size() - 1), "' "); + } else { + CHECK(false) << "Unsupported type: " + << typeid(values[nstart]).name(); + } + ++nstart; + } else { + // Not-last dimension. Drill deeper. 
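BuildOutput() walks the flattened, row-major value buffer with one recursion level per tensor dimension: the last dimension consumes values, every other dimension drills one level deeper, as the comment above notes. A stripped-down standalone version of that recursion (plain int64 values instead of the proto and typeid machinery, and assuming rank >= 1 for brevity):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Recursively format a row-major buffer as nested brackets.
// dims: tensor shape; depth: current dimension; start: next unread index.
int64_t BuildOutput(int64_t start, int depth, const std::vector<int64_t>& values,
                    const std::vector<int64_t>& dims, std::string* out) {
  *out += "[";
  int64_t next = start;
  for (int64_t i = 0; i < dims[depth]; ++i) {
    if (depth == static_cast<int>(dims.size()) - 1) {
      *out += std::to_string(values[next]) + " ";  // last dimension: emit value
      ++next;
    } else {
      next = BuildOutput(next, depth + 1, values, dims, out);  // drill deeper
    }
  }
  *out += "]";
  return next;
}

int main() {
  // A 2x3 tensor stored row-major: {{1,2,3},{4,5,6}}.
  std::vector<int64_t> values = {1, 2, 3, 4, 5, 6};
  std::string out;
  BuildOutput(0, 0, values, {2, 3}, &out);
  std::cout << out << "\n";  // [[1 2 3 ][4 5 6 ]]
}
```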
+ nstart = BuildOutput(nstart, depth + 1, values, dim); + } + } + } + if (formatted_str_.length() > kTFProfTenosrMaxDisplayLen) { + formatted_str_ = formatted_str_.substr(0, kTFProfTenosrMaxDisplayLen); + } + formatted_str_ += "],\n"; + return nstart; + } + + template + void GetValueVec(std::vector* value_vec) { + // TODO(xpan): Address the huge tensor problem. + if (tensor_->NumElements() > kTFProfTensorMaxWarnLen) { + absl::FPrintF(stderr, "Showing huge tensor, the tool might halt...\n"); + } + auto values = tensor_->flat(); + for (int64_t i = 0; i < tensor_->NumElements(); i++) { + value_vec->push_back(static_cast(values(i))); + } + } + + TFProfTensorProto tfprof_tensor_pb_; + std::unique_ptr tensor_; + string formatted_str_; +}; +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_timeline.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_timeline.h new file mode 100644 index 00000000..b50c5633 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_timeline.h @@ -0,0 +1,197 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "json/json.h" +#include "tensorflow/core/profiler/internal/tfprof_node_show.h" + +namespace tensorflow { +namespace tfprof { + +typedef std::map Event; + +// Class for generating timeline json output. +class ChromeTraceFormatter { + public: + ChromeTraceFormatter() = default; + // The following methods creates timeline nodes. See chrome tracing format + // document for details. + Json::Value CreateEvent(const string& ph, const string& category, + const string& name, int64_t pid, int64_t tid, + int64_t ts); + + void EmitPID(const string& name, int64_t pid); + + void EmitRegion(int64_t ts, int64_t duration, int64_t pid, int64_t tid, + const string& category, const string& name, Json::Value args); + + void EmitFlowStart(const string& name, int64_t ts, int64_t pid, int64_t tid, + int64_t flow_id); + + void EmitFlowEnd(const string& name, int64_t ts, int64_t pid, int64_t tid, + int64_t flow_id); + + void EmitCounter(const string& category, const string& name, int64_t pid, + int64_t ts, const string& device, int64_t bytes, + const std::map>& tensor_mem); + + string Format(); + + private: + // A event is a visualization unit in timeline. + std::vector events_; + std::vector metadata_; +}; + +// A process (time series of events) in the timeline. +class Process { + public: + Process(const string& device, int64_t pid) : device(device), pid(pid) {} + + // Each lane is a map from start_time to end_time. + std::vector> lanes; + // device for the time series. 
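The ChromeTraceFormatter above emits events in the Chrome tracing format (viewable in chrome://tracing or Perfetto): each event carries a phase `ph`, a `pid`/`tid` lane, a microsecond timestamp `ts`, and, for complete events, a duration `dur`. The real class builds these records with jsoncpp `Json::Value` objects; the dependency-free sketch below only illustrates the shape of one such record:

```cpp
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

// Build one Chrome-trace "complete" event ("ph":"X") as a JSON string.
// Assumes name/category contain no characters that need JSON escaping.
std::string CompleteEvent(const std::string& category, const std::string& name,
                          int64_t pid, int64_t tid, int64_t ts_micros,
                          int64_t dur_micros) {
  std::ostringstream os;
  os << "{\"ph\":\"X\",\"cat\":\"" << category << "\",\"name\":\"" << name
     << "\",\"pid\":" << pid << ",\"tid\":" << tid << ",\"ts\":" << ts_micros
     << ",\"dur\":" << dur_micros << "}";
  return os.str();
}

int main() {
  // A 120us op on process 1 (a device lane), thread 0, starting at t=1000us.
  std::cout << "{\"traceEvents\":["
            << CompleteEvent("Op", "conv1", 1, 0, 1000, 120) << "]}\n";
}
```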
+ string device; + // unique id for the time series. + int64_t pid; +}; + +class TimeNode { + public: + TimeNode(Process* process, GraphNode* node, int64_t start_micros, + int64_t exec_micros) + : process(process), + node(node), + start_micros(start_micros), + exec_micros(exec_micros), + tid(-1) {} + virtual ~TimeNode() = default; + + const string& name() { return node->name(); } + + Process* process; + GraphNode* node; + int64_t start_micros; + int64_t exec_micros; + int64_t tid; + std::vector next_tnodes; +}; + +// Tracking the memory based on the op input/output, temporary bytes and +// persistent bytes. +// Currently, we calculate a "predicted" memory, but do not use it for display. +// The displayed memory timeline is directly from the TensorFlow allocator, +// which is the groundtruth. +class MemoryTracker { + public: + class Device { + public: + // map from tensor name to a pair of . + std::map> tensor_allocs; + // ground truth memory stats. time->bytes. + std::map allocations; + // tracked allocations, might miss some bytes. + std::map tracked_allocations; + }; + + void TrackNode(int64_t step, const GraphNode* node); + + const std::map& devices() const { return devices_; } + + private: + std::map devices_; +}; + +class Timeline { + public: + Timeline(int64_t step, const string& outfile) + : step_(step), outfile_(outfile) {} + ~Timeline() = default; + + int64_t step() const { return step_; } + void SetStep(int64_t step) { step_ = step; } + + void GenerateGraphTimeline(const std::vector& gnodes); + + void GenerateScopeTimeline(const ScopeNode* node); + + void GenerateCodeTimeline(const CodeNode* node); + + private: + void TrackNode(const GraphNode* node) { mem_tracker_.TrackNode(step_, node); } + + void OutputTimeline(); + + template + void EmitTreeNode(const Node* node, int64_t start_time, int64_t duration, + int64_t depth, std::set* visited_depth) { + if (visited_depth->find(depth) == visited_depth->end()) { + chrome_formatter_.EmitPID(absl::StrCat("Scope:", depth), depth); + visited_depth->insert(depth); + } + + Json::Value args(Json::objectValue); + args["name"] = Json::Value(node->name()); + args["op"] = Json::Value(node->name()); + chrome_formatter_.EmitRegion(start_time, duration, depth, 0, "Op", + node->name(), args); + + int64_t total_micros = 0; + int64_t c_start_time = start_time; + for (const Node* child : node->show_children) { + int64_t total_exec_micros = child->proto().total_exec_micros(); + if (total_exec_micros <= 0) { + continue; + } + EmitTreeNode(child, c_start_time, total_exec_micros, depth + 1, + visited_depth); + c_start_time += total_exec_micros; + total_micros += total_exec_micros; + } + CHECK(total_micros <= duration) << node->name() << " parent:" << duration + << " children:" << total_micros; + } + + void AllocateTimeNodes(GraphNode* gnode); + + void AllocateLanes(); + + int64_t AllocatePID(); + + int64_t step_; + const string outfile_; + int64_t next_pid_ = 0; + MemoryTracker mem_tracker_; + ChromeTraceFormatter chrome_formatter_; + std::map device_pids_; + + std::map> process_; + std::map>> + alloc_nodes_; + std::map>> tnodes_; +}; + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_TIMELINE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_utils.h new file mode 100644 index 00000000..7f4e49ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/internal/tfprof_utils.h @@ -0,0 
+1,73 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ + +#include +#include + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/profiler/tfprof_options.h" + +namespace tensorflow { +namespace tfprof { +string FormatNumber(int64_t n); + +string FormatTime(int64_t micros); + +string FormatMemory(int64_t bytes); + +string FormatShapes(const std::vector& shapes); + +absl::Status ParseCmdLine(const string& line, string* cmd, + tensorflow::tfprof::Options* opts); + +string StringReplace(const string& str, const string& oldsub, + const string& newsub); + +template +absl::Status ReadProtoFile(Env* env, const string& fname, T* proto, + bool binary_first) { + string out; + absl::Status s = ReadFileToString(env, fname, &out); + if (!s.ok()) return s; + + if (binary_first) { + if (ReadBinaryProto(tensorflow::Env::Default(), fname, proto).ok()) { + return absl::Status(); + } else if (protobuf::TextFormat::ParseFromString(out, proto)) { + return absl::Status(); + } + } else { + if (protobuf::TextFormat::ParseFromString(out, proto)) { + return absl::Status(); + } else if (ReadBinaryProto(tensorflow::Env::Default(), fname, proto).ok()) { + return absl::Status(); + } + } + return errors::InvalidArgument("Cannot parse proto file."); +} + +void PrintHelp(); + +// Generate helper message based on the command and options. +string QueryDoc(const string& cmd, const Options& opts); + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/annotated_traceme.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/annotated_traceme.h new file mode 100644 index 00000000..150b8097 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/annotated_traceme.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_ANNOTATED_TRACEME_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_ANNOTATED_TRACEME_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/profiler/lib/scoped_annotation.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace tensorflow { +namespace profiler { + +// Combination of TraceMe and ScopedAnnotation which share the same label. +// Optimization are done to ensure the label generation are done once. +class AnnotatedTraceMe { + public: + template + explicit AnnotatedTraceMe(NameGeneratorT&& name_generator, int level = 1) { + DCHECK_GE(level, 1); + bool annotation_enabled = tsl::profiler::ScopedAnnotation::IsEnabled(); + bool traceme_enabled = tsl::profiler::TraceMe::Active(level); + if (TF_PREDICT_TRUE(!annotation_enabled && !traceme_enabled)) { + return; + } + std::string name = name_generator(); + if (annotation_enabled) { + scoped_annotation_.emplace(name); + } + if (TF_PREDICT_TRUE(traceme_enabled)) { + trace_me_.emplace([&name] { return std::move(name); }, level); + } + } + + private: + std::optional trace_me_; + std::optional scoped_annotation_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_ANNOTATED_TRACEME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/connected_traceme.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/connected_traceme.h new file mode 100644 index 00000000..e696cdaf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/connected_traceme.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_CONNECTED_TRACEME_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_CONNECTED_TRACEME_H_ + +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/core/profiler/lib/context_types.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tsl/profiler/lib/connected_traceme.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. 
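The AnnotatedTraceMe constructor above only pays for building the label when at least one of its two consumers (the annotation stack or TraceMe) is active, and it generates the label exactly once even if both are. The pattern, reduced to a standalone sketch with plain bools and a callable in place of the real enabled checks:

```cpp
#include <iostream>
#include <string>

// Stand-ins for the two "is anyone listening?" checks.
bool AnnotationEnabled() { return false; }
bool TraceMeEnabled() { return true; }

// Calls name_generator at most once, and only if some consumer is enabled.
template <typename NameGeneratorT>
void Record(NameGeneratorT&& name_generator) {
  const bool annotate = AnnotationEnabled();
  const bool trace = TraceMeEnabled();
  if (!annotate && !trace) return;      // fast path: the label is never built
  std::string name = name_generator();  // built once, shared by both sinks
  if (annotate) std::cout << "annotate: " << name << "\n";
  if (trace) std::cout << "trace:    " << name << "\n";
}

int main() {
  int step = 42;
  // The lambda runs only because TraceMeEnabled() returned true above.
  Record([&] { return "train_step#" + std::to_string(step); });
}
```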
+#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using TraceMeConsumer ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::TraceMeConsumer; // NOLINT +using TraceMeProducer ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::TraceMeProducer; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_CONNECTED_TRACEME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/context_types.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/context_types.h new file mode 100644 index 00000000..dbb7fc2e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/context_types.h @@ -0,0 +1,48 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_CONTEXT_TYPES_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_CONTEXT_TYPES_H_ + +#include + +#include "absl/base/macros.h" +#include "tsl/profiler/lib/context_types.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using ContextType ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::ContextType; // NOLINT + +ABSL_DEPRECATE_AND_INLINE() +inline const char* GetContextTypeString( + tsl::profiler::ContextType context_type) { + return tsl::profiler::GetContextTypeString(context_type); +} + +ABSL_DEPRECATE_AND_INLINE() +inline tsl::profiler::ContextType GetSafeContextType(uint32_t context_type) { + return tsl::profiler::GetSafeContextType(context_type); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_CONTEXT_TYPES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/device_profiler_session.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/device_profiler_session.h new file mode 100644 index 00000000..179a3795 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/device_profiler_session.h @@ -0,0 +1,83 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_DEVICE_PROFILER_SESSION_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_DEVICE_PROFILER_SESSION_H_ + +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/status.h" + +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/profiler/convert/xplane_to_step_stats.h" +#include "tensorflow/core/profiler/lib/profiler_session.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#endif +#include "tsl/profiler/protobuf/profiler_options.pb.h" + +namespace tensorflow { + +// Wraps a ProfilerSession configured to collect only device traces. +// Returns data in StepStats format. +class DeviceProfilerSession { + public: + // Creates a DeviceProfilerSession and starts tracing. + // Traces GPU devices if present. + // Does not trace TPU devices (not supported). + static std::unique_ptr Create() { +#if !defined(IS_MOBILE_PLATFORM) + ProfileOptions options = tsl::ProfilerSession::DefaultOptions(); + options.set_host_tracer_level(0); + options.set_device_type(ProfileOptions::GPU); + return absl::WrapUnique(new DeviceProfilerSession(options)); +#else + return nullptr; +#endif + } + + // Stops tracing and converts the data to StepStats format. + // Should be called at most once. + absl::Status CollectData(StepStats* step_stats) { +#if defined(IS_MOBILE_PLATFORM) + return errors::Unimplemented("Profiling not supported on mobile platform."); +#else + profiler::XSpace space; + TF_RETURN_IF_ERROR(profiler_session_->CollectData(&space)); + profiler::ConvertGpuXSpaceToStepStats(space, step_stats); + return absl::OkStatus(); +#endif + } + + private: + // Constructs an instance of the class and starts profiling + explicit DeviceProfilerSession(const ProfileOptions& options) +#if !defined(IS_MOBILE_PLATFORM) + : profiler_session_(tsl::ProfilerSession::Create(options)) +#endif + { + } + + // DeviceProfilerSession is neither copyable nor movable. + DeviceProfilerSession(const DeviceProfilerSession&) = delete; + DeviceProfilerSession& operator=(const DeviceProfilerSession&) = delete; + +#if !defined(IS_MOBILE_PLATFORM) + // TODO(b/256013238) + std::unique_ptr profiler_session_; +#endif +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_PROFILER_LIB_DEVICE_PROFILER_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_controller.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_controller.h new file mode 100644 index 00000000..21936dcd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_controller.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_CONTROLLER_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_CONTROLLER_H_ + +#include + +#include "absl/base/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/lib/profiler_controller.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using ProfilerController ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::ProfilerController; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_CONTROLLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_factory.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_factory.h new file mode 100644 index 00000000..ebba761b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_factory.h @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_FACTORY_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_FACTORY_H_ + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "tensorflow/core/profiler/lib/profiler_interface.h" +#include "tsl/profiler/lib/profiler_factory.h" +#include "tsl/profiler/protobuf/profiler_options.pb.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +// A ProfilerFactory returns an instance of ProfilerInterface if ProfileOptions +// require it. Otherwise, it might return nullptr. +using ProfilerFactor ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::ProfilerFactory; // NOLINT + +// Registers a profiler factory. Should be invoked at most once per factory. +ABSL_DEPRECATE_AND_INLINE() +inline void RegisterProfilerFactory(tsl::profiler::ProfilerFactory factory) { + tsl::profiler::RegisterProfilerFactory(std::move(factory)); +} + +// Invokes all registered profiler factories with the given options, and +// returns the instantiated (non-null) profiler interfaces. +ABSL_DEPRECATE_AND_INLINE() +inline std::vector> +CreateProfilers(const tensorflow::ProfileOptions& options) { + return tsl::profiler::CreateProfilers(options); +} + +// For testing only. 
+ABSL_DEPRECATE_AND_INLINE() +inline void ClearRegisteredProfilersForTest() { + tsl::profiler::ClearRegisteredProfilersForTest(); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_interface.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_interface.h new file mode 100644 index 00000000..11423c1a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_interface.h @@ -0,0 +1,37 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_INTERFACE_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_INTERFACE_H_ + +#include "absl/base/macros.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/lib/profiler_interface.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using ProfilerInterface ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::ProfilerInterface; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_lock.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_lock.h new file mode 100644 index 00000000..7480df58 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_lock.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_LOCK_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_LOCK_H_ + +#include "absl/base/macros.h" +#include "tensorflow/core/platform/statusor.h" +#include "tsl/profiler/lib/profiler_lock.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. 
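The profiler_factory.h shim above forwards registration to the TSL registry. The sketch below shows how a backend could hook in, assuming tsl::profiler::ProfilerFactory is callable with the ProfileOptions proto and returns a unique_ptr to ProfilerInterface; the device_tracer_level() check and the nullptr placeholder are illustrative, not part of this diff.

#include <memory>

#include "tensorflow/core/profiler/lib/profiler_factory.h"

void RegisterMyDeviceProfiler() {
  tensorflow::profiler::RegisterProfilerFactory(
      [](const tensorflow::ProfileOptions& options)
          -> std::unique_ptr<tsl::profiler::ProfilerInterface> {
        // Only participate when device tracing was requested (assumed field).
        if (options.device_tracer_level() == 0) return nullptr;
        return nullptr;  // placeholder: construct the backend's profiler here
      });
}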
+#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using ProfilerLock ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::ProfilerLock; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_LOCK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_session.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_session.h new file mode 100644 index 00000000..76099cc1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/profiler_session.h @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_ + +#include "absl/base/macros.h" +#include "tsl/profiler/lib/profiler_session.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { + +using ProfilerSession ABSL_DEPRECATE_AND_INLINE() = + tsl::ProfilerSession; // NOLINT + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/scoped_annotation.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/scoped_annotation.h new file mode 100644 index 00000000..8fa9fd67 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/scoped_annotation.h @@ -0,0 +1,48 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_SCOPED_ANNOTATION_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_SCOPED_ANNOTATION_H_ + +#include + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/profiler/lib/scoped_annotation.h" + +#if !defined(IS_MOBILE_PLATFORM) +#include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#endif + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. 
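For the ProfilerSession alias in profiler_session.h above, a sketch mirroring how DeviceProfilerSession drives the same API earlier in this diff: build options, create a session, run the workload, then collect an XSpace. RunWorkload() and the chosen tracer level are assumptions.

#include <memory>

#include "absl/status/status.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"

void RunWorkload();  // hypothetical, assumed to exist elsewhere

void CaptureTrace() {
  tensorflow::ProfileOptions options = tsl::ProfilerSession::DefaultOptions();
  options.set_host_tracer_level(2);  // assumed level; 0 would disable host tracing
  std::unique_ptr<tsl::ProfilerSession> session =
      tsl::ProfilerSession::Create(options);
  RunWorkload();
  tensorflow::profiler::XSpace space;
  absl::Status status = session->CollectData(&space);
  // `space` now holds raw XPlane data; converters (e.g. xplane_to_step_stats)
  // turn it into tool-specific formats.
  if (!status.ok()) return;
}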
+#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using ScopedAnnotation ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::ScopedAnnotation; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_SCOPED_ANNOTATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h new file mode 100644 index 00000000..e44cdb3c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h @@ -0,0 +1,42 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_SCOPED_MEMORY_DEBUG_ANNOTATION_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_SCOPED_MEMORY_DEBUG_ANNOTATION_H_ + +#include +#include +#include +#include + +#include "absl/base/macros.h" +#include "tsl/profiler/lib/scoped_memory_debug_annotation.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using MemoryDebugAnnotation ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::MemoryDebugAnnotation; // NOLINT +using ScopedMemoryDebugAnnotation ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::ScopedMemoryDebugAnnotation; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_SCOPED_MEMORY_DEBUG_ANNOTATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/traceme.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/traceme.h new file mode 100644 index 00000000..23e48948 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/traceme.h @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
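A short sketch for the ScopedAnnotation alias in scoped_annotation.h above; the annotation label and the launch helpers are illustrative.

#include "tensorflow/core/profiler/lib/scoped_annotation.h"

void LaunchKernelA();  // hypothetical launch helpers
void LaunchKernelB();

void ForwardPass() {
  // Everything launched while `annotation` is alive is labelled with this
  // string; on GPU builds the label is propagated to device events via the
  // annotation stack included above.
  tsl::profiler::ScopedAnnotation annotation("MyModel/forward_pass");
  LaunchKernelA();
  LaunchKernelB();
}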
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_ + +#include "absl/base/macros.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" // IWYU pragma: export +#include "tsl/profiler/lib/traceme.h" + +#if !defined(IS_MOBILE_PLATFORM) +#include "xla/tsl/profiler/utils/time_utils.h" +#endif + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. +#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::kInfo; // NOLINT +using TraceMe ABSL_DEPRECATE_AND_INLINE() = tsl::profiler::TraceMe; // NOLINT +using TraceMeLevel ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::TraceMeLevel; // NOLINT + +ABSL_DEPRECATE_AND_INLINE() +inline int GetTFTraceMeLevel(bool is_expensive) { + return tsl::profiler::GetTFTraceMeLevel(is_expensive); +} + +ABSL_DEPRECATE_AND_INLINE() +inline bool TfOpDetailsEnabled() { return tsl::profiler::TfOpDetailsEnabled(); } + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/lib/traceme_encode.h b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/traceme_encode.h new file mode 100644 index 00000000..0ebd2051 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/lib/traceme_encode.h @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_LIB_TRACEME_ENCODE_H_ +#define TENSORFLOW_CORE_PROFILER_LIB_TRACEME_ENCODE_H_ + +#include + +#include +#include +#include + +#include "absl/base/macros.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tsl/profiler/lib/traceme_encode.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. 
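For the TraceMe alias and helpers in traceme.h above, a hedged sketch of the common pattern: a lambda name generator plus GetTFTraceMeLevel(). The op name, metadata, and DoWork() are illustrative.

#include "tensorflow/core/profiler/lib/traceme.h"

void DoWork();  // hypothetical workload

void Compute() {
  // The lambda is only evaluated while a profiler session is active, so the
  // encoding cost is not paid on the fast path.
  tsl::profiler::TraceMe trace(
      [] {
        return tsl::profiler::TraceMeEncode("MyOp:Compute", {{"batch", 64}});
      },
      tsl::profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
  DoWork();
}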
+#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +using TraceMeArg ABSL_DEPRECATE_AND_INLINE() = + tsl::profiler::TraceMeArg; // NOLINT + +ABSL_DEPRECATE_AND_INLINE() +inline std::string TraceMeEncode( + std::string name, std::initializer_list args) { + return tsl::profiler::TraceMeEncode(std::move(name), args); +} + +ABSL_DEPRECATE_AND_INLINE() +inline std::string TraceMeEncode( + absl::string_view name, + std::initializer_list args) { + return tsl::profiler::TraceMeEncode(name, args); +} + +ABSL_DEPRECATE_AND_INLINE() +inline std::string TraceMeEncode( + const char* name, std::initializer_list args) { + return tsl::profiler::TraceMeEncode(name, args); +} + +ABSL_DEPRECATE_AND_INLINE() +inline std::string TraceMeEncode( + std::initializer_list args) { + return tsl::profiler::TraceMeEncode(args); +} + +ABSL_DEPRECATE_AND_INLINE() +// Concatenates op_name and op_type. +inline std::string TraceMeOp(absl::string_view op_name, + absl::string_view op_type) { + return tsl::profiler::TraceMeOp(op_name, op_type); +} + +ABSL_DEPRECATE_AND_INLINE() +inline std::string TraceMeOp(const char* op_name, const char* op_type) { + return tsl::profiler::TraceMeOp(op_name, op_type); +} + +ABSL_DEPRECATE_AND_INLINE() +inline std::string TraceMeOp(std::string&& op_name, absl::string_view op_type) { + return tsl::profiler::TraceMeOp(op_name, op_type); +} + +ABSL_DEPRECATE_AND_INLINE() +// Concatenates op_name and op_type. +inline std::string TraceMeOpOverride(absl::string_view op_name, + absl::string_view op_type) { + return tsl::profiler::TraceMeOpOverride(op_name, op_type); +} + +ABSL_DEPRECATE_AND_INLINE() +inline std::string TraceMeOpOverride(const char* op_name, const char* op_type) { + return tsl::profiler::TraceMeOpOverride(op_name, op_type); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_LIB_TRACEME_ENCODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/profiler_client.h b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/profiler_client.h new file mode 100644 index 00000000..73563d1f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/profiler_client.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
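A sketch combining the TraceMeOp and TraceMeEncode helpers declared above; the op/type names and metadata keys are illustrative.

#include <cstdint>
#include <string>

#include "tensorflow/core/profiler/lib/traceme_encode.h"

std::string BuildTraceName(const std::string& op_name, int64_t allocated_bytes) {
  // TraceMeOp joins op name and type; TraceMeEncode appends key/value metadata
  // in the format the trace viewer understands.
  return tensorflow::profiler::TraceMeEncode(
      tensorflow::profiler::TraceMeOp(op_name, "MatMul"),
      {{"allocated_bytes", allocated_bytes}, {"source", "example"}});
}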
+==============================================================================*/ +// GRPC client to perform on-demand profiling + +#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "xla/tsl/profiler/rpc/client/profiler_client.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/profiler/protobuf/profiler_analysis.grpc.pb.h" +#include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::MonitorGrpc; // NOLINT +using tsl::profiler::NewSessionGrpc; // NOLINT +using tsl::profiler::ProfileGrpc; // NOLINT +using tsl::profiler::RemoteProfilerSession; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h new file mode 100644 index 00000000..3d0b9f58 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/rpc/client/profiler_client.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::AddressResolver; // NOLINT +using tsl::profiler::RemoteProfilerSessionManager; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/save_profile.h b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/save_profile.h new file mode 100644 index 00000000..1de60aeb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/client/save_profile.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_SAVE_PROFILE_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_SAVE_PROFILE_H_ + +#include +#include + +#include "xla/tsl/profiler/rpc/client/save_profile.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tsl/profiler/protobuf/profiler_service.pb.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::GetCurrentTimeStampAsString; // NOLINT +using tsl::profiler::GetTensorBoardProfilePluginDir; // NOLINT +using tsl::profiler::SaveGzippedToolData; // NOLINT +using tsl::profiler::SaveProfile; // NOLINT +using tsl::profiler::SaveXSpace; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_CLIENT_SAVE_PROFILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/grpc.h b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/grpc.h new file mode 100644 index 00000000..d37c535d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/grpc.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GRPC utilities + +#ifndef TENSORFLOW_CORE_PROFILER_RPC_GRPC_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_GRPC_H_ + +#include + +#include "grpcpp/security/credentials.h" +#include "grpcpp/security/server_credentials.h" + +namespace tensorflow { +namespace profiler { + +// Returns default credentials for use when creating a gRPC server. +std::shared_ptr<::grpc::ServerCredentials> GetDefaultServerCredentials(); + +// Returns default credentials for use when creating a gRPC channel. +std::shared_ptr<::grpc::ChannelCredentials> GetDefaultChannelCredentials(); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_GRPC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/profiler_server.h b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/profiler_server.h new file mode 100644 index 00000000..dec0a235 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/profiler_server.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
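For the credential helpers in grpc.h above, a hedged wiring sketch using the standard grpcpp ServerBuilder and CreateChannel calls; the service pointer, address format, and port handling are assumptions, not part of this diff.

#include <memory>
#include <string>

#include "grpcpp/grpcpp.h"
#include "tensorflow/core/profiler/rpc/grpc.h"

std::unique_ptr<::grpc::Server> StartServer(::grpc::Service* service, int port) {
  ::grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:" + std::to_string(port),
                           tensorflow::profiler::GetDefaultServerCredentials());
  builder.RegisterService(service);
  return builder.BuildAndStart();
}

std::shared_ptr<::grpc::Channel> Connect(const std::string& target) {
  return ::grpc::CreateChannel(
      target, tensorflow::profiler::GetDefaultChannelCredentials());
}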
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_ + +#include + +#include "grpcpp/grpcpp.h" +#include "xla/tsl/profiler/rpc/profiler_server.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::ProfilerServer; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/profiler_service_impl.h b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/profiler_service_impl.h new file mode 100644 index 00000000..f3b6a293 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/rpc/profiler_service_impl.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_ +#define TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_ + +#include + +#include "xla/tsl/profiler/rpc/profiler_service_impl.h" +#include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::CreateProfilerService; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_RPC_PROFILER_SERVICE_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/tfprof_options.h b/third_party/tflite-hdrs/tensorflow/core/profiler/tfprof_options.h new file mode 100644 index 00000000..61143b49 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/tfprof_options.h @@ -0,0 +1,186 @@ +/* Copyright 2016 The TensorFlow Authors All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_ +#define TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tfprof { +static const char* const kOptions[] = { + "-max_depth", + "-min_bytes", + "-min_peak_bytes", + "-min_residual_bytes", + "-min_output_bytes", + "-min_micros", + "-min_accelerator_micros", + "-min_cpu_micros", + "-min_params", + "-min_float_ops", + "-min_occurrence", + "-step", + "-order_by", + "-account_type_regexes", + "-start_name_regexes", + "-trim_name_regexes", + "-show_name_regexes", + "-hide_name_regexes", + "-account_displayed_op_only", + "-select", + "-output", +}; + +static const char* const kOrderBy[] = { + "name", "bytes", "peak_bytes", "residual_bytes", + "output_bytes", "micros", "accelerator_micros", "cpu_micros", + "params", "float_ops", "occurrence", +}; + +// Append Only. +// TODO(xpan): As we are adding more fields to be selected, we +// need to have a way to tell users what fields are available in which view. +static const char* const kShown[] = {"bytes", "micros", + "params", "float_ops", + "tensor_value", "device", + "op_types", "occurrence", + "input_shapes", "accelerator_micros", + "cpu_micros", "peak_bytes", + "residual_bytes", "output_bytes"}; + +static const char* const kCmds[] = { + "scope", "graph", "code", "op", "advise", "set", "help", +}; + +static const char* const kOutput[] = {"timeline", "stdout", "file", "pprof", + "none"}; + +static const char* const kTimelineOpts[] = { + "outfile", +}; + +static const char* const kTimelineRequiredOpts[] = {"outfile"}; + +static const char* const kFileOpts[] = { + "outfile", +}; + +static const char* const kFileRequiredOpts[] = { + "outfile", +}; + +static const char* const kPprofOpts[] = { + "outfile", +}; + +static const char* const kPprofRequiredOpts[] = { + "outfile", +}; + +struct Options { + public: + static absl::Status FromProtoStr(const string& opts_proto_str, Options* opts); + + virtual ~Options() {} + Options() + : Options(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, "", {}, {}, {}, {}, {}, + false, {}, "", {}) {} + + Options(int max_depth, int64_t min_bytes, int64_t min_peak_bytes, + int64_t min_residual_bytes, int64_t min_output_bytes, + int64_t min_micros, int64_t min_accelerator_micros, + int64_t min_cpu_micros, int64_t min_params, int64_t min_float_ops, + int64_t min_occurrence, int64_t step, const string& order_by, + const std::vector& account_type_regexes, + const std::vector& start_name_regexes, + const std::vector& trim_name_regexes, + const std::vector& show_name_regexes, + const std::vector& hide_name_regexes, + bool account_displayed_op_only, const std::vector& select, + const string& output_type, + const std::map& output_options) + : max_depth(max_depth), + min_bytes(min_bytes), + min_peak_bytes(min_peak_bytes), + min_residual_bytes(min_residual_bytes), + min_output_bytes(min_output_bytes), + min_micros(min_micros), + min_accelerator_micros(min_accelerator_micros), + min_cpu_micros(min_cpu_micros), + min_params(min_params), + min_float_ops(min_float_ops), + min_occurrence(min_occurrence), + step(step), + order_by(order_by), + account_type_regexes(account_type_regexes), + start_name_regexes(start_name_regexes), + 
trim_name_regexes(trim_name_regexes), + show_name_regexes(show_name_regexes), + hide_name_regexes(hide_name_regexes), + account_displayed_op_only(account_displayed_op_only), + select(select.begin(), select.end()), + output_type(output_type), + output_options(output_options) {} + + string ToString() const; + + int max_depth; + int64_t min_bytes; + int64_t min_peak_bytes; + int64_t min_residual_bytes; + int64_t min_output_bytes; + int64_t min_micros; + int64_t min_accelerator_micros; + int64_t min_cpu_micros; + int64_t min_params; + int64_t min_float_ops; + int64_t min_occurrence; + int64_t step; + string order_by; + + std::vector account_type_regexes; + std::vector start_name_regexes; + std::vector trim_name_regexes; + std::vector show_name_regexes; + std::vector hide_name_regexes; + bool account_displayed_op_only; + + std::set select; + + string output_type; + std::map output_options; +}; + +// Parse the -output option. +// 'output_opt': User input string with format: output_type:key=value,key=value. +// 'output_type' and 'output_options' are extracted from 'output_opt'. +absl::Status ParseOutput(const string& output_opt, string* output_type, + std::map* output_options); + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/cost_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/cost_utils.h new file mode 100644 index 00000000..7ea14fe9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/cost_utils.h @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/grappler/costs/cost_estimator.h" +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +// This is a wrapper of tensorflow::grappler::OpLevelCostEstimator and use +// tracing time information to estimate the roof line stats for each traced +// tensorflow op. +class TfOpRoofLineCostEstimator + : public tensorflow::grappler::OpLevelCostEstimator { + public: + TfOpRoofLineCostEstimator() = default; + ~TfOpRoofLineCostEstimator() override; + + grappler::DeviceInfo GetDeviceInfo( + const DeviceProperties& device) const override; + + struct OpRoofLineStats { + uint64 flops = 0LL; + uint64 bytes_accessed = 0LL; + bool inaccurate = false; + }; + OpRoofLineStats Predict(const XEventVisitor& event); + + private: + absl::flat_hash_set + unsupported_ops_; // summary for unsupported ops. 
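For the tfprof Options/ParseOutput API declared in tfprof_options.h above, a small sketch of parsing an "-output" value in the documented output_type:key=value,key=value form; the concrete option string is illustrative.

#include <map>
#include <string>

#include "absl/status/status.h"
#include "tensorflow/core/profiler/tfprof_options.h"

void ParseTimelineOutput() {
  std::string output_type;
  std::map<std::string, std::string> output_options;
  // Expected result on success: output_type == "timeline",
  // output_options["outfile"] == "/tmp/timeline.json".
  absl::Status status = tensorflow::tfprof::ParseOutput(
      "timeline:outfile=/tmp/timeline.json", &output_type, &output_options);
  if (!status.ok()) {
    // e.g. unknown output type or a missing required option.
  }
}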
+ + TfOpRoofLineCostEstimator(const TfOpRoofLineCostEstimator&) = delete; + void operator=(const TfOpRoofLineCostEstimator&) = delete; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_COST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/derived_timeline.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/derived_timeline.h new file mode 100644 index 00000000..6d2b5e5b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/derived_timeline.h @@ -0,0 +1,202 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +// Store the mapping from child scope range id to parent scope range id, which +// logically form a scope range call stack tree/forest. +typedef absl::flat_hash_map + ScopeRangeIdTree; + +// Helper for deriving XEvents. +class DerivedXEventBuilder { + public: + DerivedXEventBuilder(XEventBuilder event, std::optional group_id, + std::optional scope_range_id = std::nullopt); + + bool ShouldExpand(const XEventMetadata& event_metadata, + std::optional group_id, + std::optional scope_range_id = std::nullopt) const; + + void Expand(tsl::profiler::Timespan event_span); + tsl::profiler::Timespan GetTimespan() const { return event_.GetTimespan(); } + void SetTimespan(tsl::profiler::Timespan event_span) { + event_.SetTimespan(event_span); + } + + template + void SetOrAddStatValue(const XStatMetadata& metadata, ValueT&& value) { + event_.SetOrAddStatValue(metadata, std::forward(value)); + } + + private: + XEventBuilder event_; + std::optional group_id_; + std::optional scope_range_id_; +}; + +// Helper for deriving an XLine from events in another XLine. +class DerivedXLineBuilder { + public: + DerivedXLineBuilder(XPlaneBuilder* plane, int64_t line_id, + absl::string_view name, int64_t timestamp_ns, + std::vector dependent_lines); + + XLineBuilder& Line() { return line_; } + + // Either merges event with the last event or creates a new event on this + // XLine. group_id and low_level_event_name may be passed to separate + // consecutive invocations of the same event, depending on the XEvent type: + // TF-op, TF name scope: both group_id and low_level_event_name are used. + // HLO-op, step: only group_id is used. 
+ // HLO module, source: both group_id and low_level_event_name are NOT used. + // If scope_range_id is provided, it will be compared with the one in the + // event which is to be merged with. If they are different, merging is not + // allowed. + void ExpandOrAddEvent(const XEventMetadata& event_metadata, + tsl::profiler::Timespan event_span, + std::optional group_id, + std::optional scope_range_id = std::nullopt); + + // The multi-level version of ExpandOrAddEvent. Here, the XEvents at different + // levels all share the same group_id and low_level_event_name. + // Conceptually, the scope_range_ids should be of same length as the + // events_metadata_per_level. However, if it is shorter, this function will + // assume the missing elements at the end of scope_range_ids vector with the + // value of std::nullopt; and if it is longer, the extra elements in + // scope_range_ids will be ignored. + void ExpandOrAddEvents( + const std::vector& events_metadata_per_level, + tsl::profiler::Timespan event_span, std::optional group_id, + absl::Span> scope_range_ids = {}); + + // Reset the last events lower than or equal to the given level. + void ResetLastEvents(int level = 0); + + // To avoid using templates while need hide its implementation in .cc file, + // use two functions to set stat value for int64_t and uint64_t here. + void AddStatToLevelEvent(int level, const XStatMetadata& metadata, + int64_t value); + + void AddStatToLevelEvent(int level, const XStatMetadata& metadata, + uint64_t value); + + const XStatMetadata* GetCorrelationIdMetadata() const { + return correlation_id_metadata_; + } + + const XStatMetadata* GetCudaGraphIdMetadata() const { + return cuda_graph_id_metadata_; + } + + private: + // If the last event of the given level has the same metadata, expands it to + // include the time until the given event's end time. + // Otherwise, adds a new event and clears last_event_by_level_ for the levels + // below the given level and all levels of the dependent lines. Clearing + // last_event_by_level_ prevents a nested event from growing larger than the + // parent event(s). + void ExpandOrAddLevelEvent(const XEventMetadata& event_metadata, + tsl::profiler::Timespan event_span, + std::optional group_id, + std::optional scope_range_id, int level); + void AdjustDurationForTraceViewer(int level); + + const XStatMetadata* group_id_stat_metadata_ = nullptr; + const XStatMetadata* correlation_id_metadata_ = nullptr; + const XStatMetadata* cuda_graph_id_metadata_ = nullptr; + + XLineBuilder line_; + absl::flat_hash_map> + last_event_by_level_; + std::vector dependent_lines_; + bool is_gpu_plane_ = false; +}; + +struct Symbol { + absl::string_view tf_op_name; + std::string source_info; + std::string hlo_text; +}; + +using SymbolResolver = std::function program_id, + absl::string_view hlo_module_name, + absl::string_view hlo_op)>; + +// Derives TF name scope and op events from the TF op's fully qualified name +// with the name of the originating low-level event. +void ProcessTfOpEvent(absl::string_view tf_op_full_name, + tsl::profiler::Timespan event_span, + std::optional group_id, + XPlaneBuilder& plane_builder, + DerivedXLineBuilder& tf_name_scope_line_builder, + DerivedXLineBuilder& tf_op_line_builder); + +// Derives "Steps" line from group_id XStat in XEvents. 
+void DeriveStepEventsFromGroups( + const tsl::profiler::GroupMetadataMap& group_metadata_map, + XPlane* device_trace); + +// Derives "TensorFlow Ops", "TensorFlow Name Scope", "XLA Ops" and "XLA Module" +// lines in an NVIDIA_GPU device trace from data passed as ScopedAnnotations and +// stored as XStats in XEvents corresponding to GPU Kernels. Consecutive +// annotations with the same value are merged into a single event except for XLA +// modules. The device_trace is both input and output. +void DeriveEventsFromAnnotations( + const SymbolResolver& symbol_resolver, XPlane* device_trace, + const ScopeRangeIdTree* scope_range_id_tree = nullptr); + +// Derives "Launch Activities Summary" line from host trace. +void DeriveEventsFromHostTrace( + const XPlane* host_trace, + const tsl::profiler::GroupMetadataMap& group_metadata_map, + std::vector device_traces); + +// Loops through XPlanes of input XSpace, if it is "device" XPlane, generating +// derived timelines for the plane by calling DeriveEventsFromAnnotations. +void GenerateDerivedTimeLines( + const tsl::profiler::GroupMetadataMap& group_metadata_map, XSpace* space); + +// Derives `Tensorflow Ops`, `Tensorflow Name Scope` and `Source Code` lines +// from device_trace. +void DeriveLinesFromStats(tensorflow::profiler::XPlane* device_trace); + +// Devices Framework Op and Module lines for XLA:CPU ops. +void DeriveLinesForXlaCpuOps(tensorflow::profiler::XPlane* host_trace); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_DERIVED_TIMELINE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/device_caps_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/device_caps_utils.h new file mode 100644 index 00000000..db6bf44e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/device_caps_utils.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ + +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +void SetDeviceCaps(const DeviceCapabilities& caps, XPlane* plane); +DeviceCapabilities GetDeviceCaps(const XPlane& plane); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_DEVICE_CAPS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/diagnostics.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/diagnostics.h new file mode 100644 index 00000000..e5c41751 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/diagnostics.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
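A round-trip sketch for SetDeviceCaps/GetDeviceCaps in device_caps_utils.h above. The DeviceCapabilities field names (num_cores, clock_rate_in_ghz) are assumptions about hardware_types.proto and only serve to show the flow.

#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/device_caps_utils.h"

void TagPlaneWithCaps(tensorflow::profiler::XPlane* device_plane) {
  tensorflow::profiler::DeviceCapabilities caps;
  caps.set_num_cores(80);            // assumed proto field names, shown only
  caps.set_clock_rate_in_ghz(1.53);  // to illustrate the round trip
  tensorflow::profiler::SetDeviceCaps(caps, device_plane);

  // Later consumers recover the same capabilities from the plane's stats.
  tensorflow::profiler::DeviceCapabilities read_back =
      tensorflow::profiler::GetDeviceCaps(*device_plane);
  (void)read_back;
}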
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/profiler/protobuf/diagnostics.pb.h" +#include "tensorflow/core/profiler/protobuf/op_stats.pb.h" + +namespace tensorflow { +namespace profiler { + +// Error message that the visualization is based on incomplete step. +TF_CONST_INIT extern const absl::string_view kErrorIncompleteStep; + +// Error message that no step marker is seen and visualization contains no +// step info. +TF_CONST_INIT extern const absl::string_view kErrorNoStepMarker; + +TF_CONST_INIT extern const absl::string_view kNoDeviceTraceCollected; + +TF_CONST_INIT extern const absl::string_view kStepsDropped; + +void PopulateStepDiagnostics(const OpStats& op_stats, Diagnostics* diag); + +void PopulateOverviewDiagnostics(const OpStats& op_stats, Diagnostics* diag); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_DIAGNOSTICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/event_span.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/event_span.h new file mode 100644 index 00000000..f1e3a5b7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/event_span.h @@ -0,0 +1,268 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" + +namespace tensorflow { +namespace profiler { + +// The various event types. Enumerations are numbered such that a bigger number +// has a higher priority than a smaller number when used in execution-time +// breakdown. +enum EventType { + // No event associated with the time. It could be that the machine was idle or + // executing some events which were not traced. + UNKNOWN_TIME = 0, + // Host is computing. 
+ HOST_COMPUTE = 10, + // Host is preprocessing the data before the execution on device. + HOST_PREPROCESS = 20, + // Host is postprocessing the data after the execution on device. + HOST_POSTPROCESS = 30, + // Host is batching data (for inference). + HOST_BATCH_FORMATION = 40, + // Host runtime, like memory allocation and etc. + HOST_RUNTIME = 50, + // Host is compiling. + HOST_COMPILE = 60, + // Host-to-host communication. + HOST_TO_HOST = 70, + // Host-to-device communication. + HOST_TO_DEVICE = 80, + // Host is preparing to launch a computation on device. + HOST_PREPARE = 90, + // Assigns a smaller priority to DEVICE_COLLECTIVES than HOST_WAIT_INPUT, + // because if an all-reduce event is overlapped with an host-wait-input event, + // we want to count it as waiting for input. + // Collective Ops such as All-Reduce. + DEVICE_COLLECTIVES = 100, + // Host is waiting for input. + HOST_WAIT_INPUT = 110, + // Device-to-device communication. + DEVICE_TO_DEVICE = 120, + // Device-to-host communication. + DEVICE_TO_HOST = 130, + // Device is computing with 32-bit precision. + DEVICE_COMPUTE_32 = 140, + // Device is computing with 16-bit precision. + DEVICE_COMPUTE_16 = 150, + // Device is waiting for another device. + DEVICE_WAIT_DEVICE = 160, + // Device is waiting for host. + DEVICE_WAIT_HOST = 170, + LAST_EVENT_TYPE = DEVICE_WAIT_HOST +}; + +// Generic event types that shown to the user. +enum GenericEventType { + kFirstGenericEventType = 1, + // Device is computing. + kDeviceCompute = kFirstGenericEventType, + // Device-to-device communication. + kDeviceToDevice, + // Collective Ops such as All-Reduce and NCCL. + kDeviceCollectives, + // Host is computing. + kHostCompute, + // Host is preparing to launch a computation on device. + kHostPrepare, + // Device waiting for input from the host. + kInput, + // Device sending output to the host. + kOutput, + // Host is compling. + kCompile, + // No recognized event associated with the time. + kAllOthers, + kLastGenericEventType = kAllOthers, +}; + +// Contains the type and timespan of an event. +struct EventTypeSpan { + EventType type; // type of this event. + tsl::profiler::Timespan span; // timespan of this event. + EventTypeSpan(EventType t, tsl::profiler::Timespan s) : type(t), span(s) {} + // Equality test. + bool operator==(const EventTypeSpan& other) const { + return type == other.type && span == other.span; + } + // Inequality test. + bool operator!=(const EventTypeSpan& other) const { + return !(*this == other); + } +}; + +enum class StepMarkerType { + // "TraceContext" TraceMe events. + kExplicitHostStepMarker, + // Identified by group_events (e.g., FunctionRun, SessionRun). + kImplicitHostStepMarker, + // Derived from the result of group_events. A device step marker starts with + // the first device event of the group and ends with the last event of the + // group. + kDeviceStepMarker, +}; + +// Record of an event that is used as a step marker. +struct StepMarker { + StepMarkerType type; + std::string event_name; // name of this event. + std::string step_name; + tsl::profiler::Timespan span; // timespan of this event. + StepMarker(StepMarkerType step_marker_type, absl::string_view name, + tsl::profiler::Timespan s) + : type(step_marker_type), event_name(name), span(s) {} + // Equality test. + bool operator==(const StepMarker& other) const { + return type == other.type && event_name == other.event_name && + span == other.span; + } + // Inequality test. 
+ bool operator!=(const StepMarker& other) const { return !(*this == other); } +}; + +// Details of a step. Note that this could be the result of combining the +// StepDetails of the same step executed on different cores. +class StepDetails { + public: + StepDetails() : device_memory_transfers_(3) {} + + const std::vector& Markers() const { return markers_; } + const std::vector& Events() const { return events_; } + + const absl::flat_hash_map& Collectives() const { + return collectives_; + } + const std::vector& DeviceMemoryTransfers() const { + return device_memory_transfers_; + } + + absl::flat_hash_map& PerCoreOpMetricsDb() { + return per_core_op_metrics_db_; + } + // Returns the step time. + tsl::profiler::Timespan StepTime() const; + // Adds a step-marker to this step. + void AddMarker(const StepMarker& m); + // Adds an EventTypeSpan to this step. + void AddEvent(const EventTypeSpan& e); + // Adds a collective op to this step. + void AddCollectiveOpEvent(uint64 core_id, const AllReduceInfo& e); + // Appends device memory transfer events to this step. + // Only event type of HOST_TO_DEVICE/DEVICE_TO_DEVICE/DEVICE_TO_HOST are + // allowed. + void AddDeviceMemoryTransferEvent(EventType event_type, + const tsl::profiler::Timespan& time_span, + uint64 bytes); + // Returns the step name. + std::string StepName() const { return step_name_; } + // Sets the name of this step. + void SetStepName(std::string step_name) { step_name_ = step_name; } + + // Converts from overlapped events to non-overlapped events. + StepDetails ToNonOverlapped() const; + + // Combines other. + void Combine(const StepDetails& other); + + // Equality test. + bool operator==(const StepDetails& other) const; + // Inequality test. + bool operator!=(const StepDetails& other) const { return !(*this == other); } + + // Returns a string that prints the content of this object. + std::string DebugString() const; + + void SetPerCoreOpMetricsDb(OpMetricsDb db, uint32 core_id) { + per_core_op_metrics_db_[core_id] = db; + } + + private: + // Accumulates the device memory transfers from another step to this step. + void AggregateDeviceMemoryTransfers( + const std::vector& device_memory_transfers); + + // All step-markers found for marking this step in the traces. There could be + // multiple step-markers for a single step for different reasons. One such + // reason is that there may be one step-marker for the same step on each core; + // so after combining the StepDetails from multiple cores, there would be + // multiple step-markers for the same step. + std::vector markers_; + // All events belonging to this step. + std::vector events_; + // Collective operation related events such as all-reduce etc. + absl::flat_hash_map collectives_; + // Device memory transfers (including time and bytes involved). + // TODO(jiesun): Consider to use IntervalSet instead of just sum up the event + // durations. + std::vector device_memory_transfers_; + std::string step_name_; + + absl::flat_hash_map per_core_op_metrics_db_; +}; + +// Map from step_id to the events happened in that step. +using StepEvents = absl::flat_hash_map; + +// Equality test for StepEvents. +bool operator==(const StepEvents& a, const StepEvents& b); + +// Returns the name of the given EventType. +std::string PrintEventType(EventType event_type); + +// Returns the string of the given GenericEventType. +absl::string_view GetGenericEventTypeStr(GenericEventType event_type); + +// Returns a string that prints the given EventTypeSpan. 
+std::string PrintEventTypeSpan(const EventTypeSpan& event_type_span); + +// Returns a string that prints the given StepMarker. +std::string PrintStepMarker(const StepMarker& step_marker); + +// Returns a string that prints the given StepEvents. +std::string PrintStepEvents(const StepEvents& step_events); + +// Unions the map of StepEvents and combines the src StepEvents into dst. +void UnionCombineStepEvents(const StepEvents& src, StepEvents* dst); + +// Intersects the map of StepEvents and combines the src StepEvents into dst. +void IntersectCombineStepEvents(const StepEvents& src, StepEvents* dst); + +// Converts from overlapped events to non-overlapped events. +std::vector ToNonOverlappedEvents( + const std::vector& overlapped_events); + +// Converts from overlapped step-events to non-overlapped step events. +StepEvents ToNonOverlappedStepEvents(const StepEvents& overlapped_step_events); + +// Returns the precision stats of the given non-overlapped step events. +PrecisionStats ComputePrecisionStats( + const StepEvents& nonoverlapped_step_events); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_EVENT_SPAN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/gpu_event_stats.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/gpu_event_stats.h new file mode 100644 index 00000000..1c711249 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/gpu_event_stats.h @@ -0,0 +1,83 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +// Stats from a GPU stream XEvent. +struct GpuEventStats { + explicit GpuEventStats(const XEventVisitor* event); + + bool IsKernel() const { return !kernel_details.empty(); } + bool IsMemCpy() const { return !memcpy_details.empty(); } + bool IsCudaGraphExecution() const { return cuda_graph_exec_id.has_value(); } + + bool IsXlaOp() const { return !hlo_op_names.empty(); } + bool IsTfOp() const { return !tf_op_fullname.empty(); } + + // Stats from TensorFlow. + absl::string_view tf_op_fullname; + absl::string_view equation; + absl::string_view tensor_shapes; + + // Stats from XLA. + std::vector hlo_op_names; + absl::string_view hlo_module_name; + std::optional program_id; + + // Stats from CUPTI. + absl::string_view kernel_details; + absl::string_view memcpy_details; + std::optional correlation_id; + std::optional scope_range_id; + + // Stats derived by grouping. 
+ std::optional group_id; + bool is_eager = false; + std::optional cuda_graph_exec_id; + std::optional cuda_graph_id_for_inner_node; +}; + +// Stats for a host-side GPU launch XEvent. +struct LaunchEventStats { + explicit LaunchEventStats(const XEventVisitor* event); + + bool IsLaunch() const { + return device_id.has_value() && correlation_id.has_value(); + } + + // Stats from CUPTI. + std::optional device_id; + std::optional correlation_id; + + // Stat derived by grouping. + std::optional group_id; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_GPU_EVENT_STATS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hardware_type_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hardware_type_utils.h new file mode 100644 index 00000000..41b1bd4b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hardware_type_utils.h @@ -0,0 +1,82 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" + +namespace tensorflow { +namespace profiler { + +struct GpuFlopCapabilities { + struct FlopCapabilityOnPrecisions { + double fp64_tflops = 0; + double fp32_tflops = 0; // also for tf32 for nvidia tensor core + double bf16_tflops = 0; + double fp16_tflops = 0; + double fp8_tflops = 0; + double int8_tops = 0; + double fp4_tflops = 0; + double int4_tops = 0; + + void ScaleWith(double scale) { + fp64_tflops *= scale; + fp32_tflops *= scale; + bf16_tflops *= scale; + fp16_tflops *= scale; + fp8_tflops *= scale; + int8_tops *= scale; + fp4_tflops *= scale; + int4_tops *= scale; + } + }; + + FlopCapabilityOnPrecisions cuda_core; + FlopCapabilityOnPrecisions tensor_core; + bool has_tensor_core_sparsity_support = false; + + void ScaleWith(double scale) { + cuda_core.ScaleWith(scale); + tensor_core.ScaleWith(scale); + } +}; + +// Get peak single precision throughput of the GPU in GFLOPS per +// streaming multiprocessor. +// TODO: Need design on how to use the sparsity capability of FLOPs. +double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap); + +// for Nvidia GPU, return shared memory bandwidth in Bytes Per Second on +// one single SM given the GPU core freq in device_cap. +double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap); + +// Returns the GPU model name from the given DeviceCapabilities. +// For nvidia GPUs, the name is like "Nvidia GPU (Kepler)" or "Nvidia GPU +// (Turing)". For AMD GPUs, the name is like "AMD GPU - gfx-10XX series". +// The model name here for Nvidia GPU in fact refers to its microarchitecture +// name. 
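A sketch of how the per-SM peak declared above might be scaled to a whole device. This assumes DeviceCapabilities exposes the SM count as num_cores(), which is not shown in this diff; the helper name DevicePeakGflops is illustrative.

#include "tensorflow/core/profiler/utils/hardware_type_utils.h"

namespace tensorflow {
namespace profiler {

// Whole-device peak in GFLOPS: per-SM peak times the assumed SM count.
inline double DevicePeakGflops(const DeviceCapabilities& device_cap) {
  return GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores();
}

}  // namespace profiler
}  // namespace tensorflow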
+absl::string_view GpuModelName(const DeviceCapabilities& device_cap); + +HardwareType ParseHardwareType(absl::string_view device_type); + +// Returns true if the given hardware type has a device. +bool HasDevice(HardwareType x); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_HARDWARE_TYPE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_module_map.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_module_map.h new file mode 100644 index 00000000..1ea242f6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_module_map.h @@ -0,0 +1,212 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_cost_analysis.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +class HloInstructionInterface { + public: + virtual ~HloInstructionInterface() = default; + virtual absl::string_view Name() const = 0; + virtual xla::HloOpcode HloOpcode() const = 0; + virtual absl::string_view Category() const = 0; + virtual std::string HloOpcodeString() const = 0; + virtual const xla::OpMetadata& Metadata() const = 0; + virtual size_t flops() const = 0; + virtual size_t bytes_accessed() const = 0; + virtual std::string_view op_full_name() const = 0; + virtual std::string source_info() const = 0; + virtual bool isRoot() const = 0; + virtual bool IsFusion() const = 0; + virtual const std::string& Expression() const = 0; + + virtual void ProcessXlaCostAnalysis( + const xla::HloCostAnalysis* cost_analysis) = 0; +}; + +// This wrapper allows caching the results of HloInstruction methods. +// This wrapper is not thread safe. 
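A short usage sketch for the hardware-type helpers declared above; IsDeviceHardware is a hypothetical wrapper, not something declared in the vendored header.

#include "absl/strings/string_view.h"
#include "tensorflow/core/profiler/utils/hardware_type_utils.h"

namespace tensorflow {
namespace profiler {

// True when the device-type string parses to a hardware type that has a
// device behind it (see HasDevice above).
inline bool IsDeviceHardware(absl::string_view device_type) {
  return HasDevice(ParseHardwareType(device_type));
}

}  // namespace profiler
}  // namespace tensorflow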
+class HloInstructionWrapper : public HloInstructionInterface { + public: + explicit HloInstructionWrapper( + const xla::HloInstruction* instr, + const xla::HloCostAnalysis* cost_analysis = nullptr); + + // Non copyable + HloInstructionWrapper(const HloInstructionWrapper&) = delete; + HloInstructionWrapper& operator=(const HloInstructionWrapper&) = delete; + // Movable. + HloInstructionWrapper(HloInstructionWrapper&&) = default; + HloInstructionWrapper& operator=(HloInstructionWrapper&&) = default; + + absl::string_view Name() const override { return instr_->name(); } + + xla::HloOpcode HloOpcode() const override { return instr_->opcode(); } + + absl::string_view Category() const override { return category_; } + + std::string HloOpcodeString() const override { + return std::string(xla::HloOpcodeString(instr_->opcode())); + } + + const xla::OpMetadata& Metadata() const override { + return instr_->metadata(); + } + + size_t flops() const override { return flops_; } + size_t bytes_accessed() const override { return bytes_accessed_; } + + std::string_view op_full_name() const override { return op_full_name_; } + std::string source_info() const override; + + bool isRoot() const override { return instr_->IsRoot(); } + bool IsFusion() const override { return !fused_children_.empty(); }; + + void ProcessXlaCostAnalysis( + const xla::HloCostAnalysis* cost_analysis) override { + if (cost_analysis == nullptr) return; + flops_ = cost_analysis->flop_count(*instr_); + bytes_accessed_ = cost_analysis->bytes_accessed(*instr_); + } + + const std::string& Expression() const override { return expression_; } + + void AddFusedChild(const HloInstructionWrapper* child) { + fused_children_.push_back(child); + }; + + const std::vector& FusedChildren() const { + return fused_children_; + } + + private: + const xla::HloInstruction* instr_; + std::vector fused_children_; + std::string op_full_name_; + size_t flops_ = 0; + size_t bytes_accessed_ = 0; + std::string category_; + std::string expression_; +}; + +// Helper class for accessing HloModule. +class HloModuleInterface { + public: + virtual ~HloModuleInterface() = default; + + // If the module contains no instructions. + virtual bool Empty() const = 0; + virtual absl::string_view Name() const = 0; + // Function to populated nested childs= instructions in a fusion. + virtual void GatherFusionInstructions(xla::HloInstruction* inst) = 0; +}; + +// Wraps HLO module and provides an interface that maps HLO names to +// HloInstructionWrappers. +class HloModuleWrapper : public HloModuleInterface { + public: + explicit HloModuleWrapper( + const xla::HloProto& hlo_proto, + std::function shape_func = nullptr); + + explicit HloModuleWrapper( + std::unique_ptr module, + std::function shape_func); + + const HloInstructionWrapper* GetHloInstruction( + absl::string_view hlo_name) const; + HloInstructionWrapper* GetMutableHloInstruction(absl::string_view hlo_name); + + bool Empty() const override { return instructions_by_name_.empty(); } + + absl::string_view Name() const override { return module_->name(); } + void GatherFusionInstructions(xla::HloInstruction* inst) override; + + private: + std::unique_ptr module_; + + // Map of HloInstructionWrappers by name. + using HloInstructionMap = + absl::flat_hash_map; + HloInstructionMap instructions_by_name_; +}; + +// Map of HloModuleWrappers by program_id. 
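An illustrative lookup through the wrapper classes above. The shape-function argument is left at its default, so flops() may be zero unless a cost analysis was attached when the wrapper was built; FlopsForInstruction is a hypothetical helper.

#include "absl/strings/string_view.h"
#include "tensorflow/core/profiler/utils/hlo_module_map.h"

namespace tensorflow {
namespace profiler {

// Builds a wrapper around an HLO proto and reads the flop count cached for
// one instruction, or 0 if the instruction is not present.
inline size_t FlopsForInstruction(const xla::HloProto& hlo_proto,
                                  absl::string_view hlo_name) {
  HloModuleWrapper module(hlo_proto);
  const HloInstructionWrapper* instr = module.GetHloInstruction(hlo_name);
  return instr != nullptr ? instr->flops() : 0;
}

}  // namespace profiler
}  // namespace tensorflow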
+using HloModuleMap = + absl::flat_hash_map; + +void AddHloProto(HloModuleMap& hlo_module_map, uint64_t program_id, + const xla::HloProto& hlo_proto); + +// Process HloModuleMap from single XSpace. +void ProcessHloModuleMapFromXSpace(HloModuleMap& hlo_module_map, + const XSpace* space); + +// WARNING: The returned pointer will be invalidated if HloModuleMap is mutated. +inline const HloModuleWrapper* GetHloModule(const HloModuleMap* hlo_module_map, + uint64_t program_id) { + if (hlo_module_map == nullptr) return nullptr; + auto iter = hlo_module_map->find(program_id); + if (iter == hlo_module_map->end()) return nullptr; + return &iter->second; +} + +inline const HloInstructionWrapper* GetHloInstruction( + const HloModuleMap& hlo_module_map, std::optional program_id, + absl::string_view hlo_name) { + if (!program_id.has_value()) return nullptr; + const auto* hlo_module = GetHloModule(&hlo_module_map, *program_id); + if (hlo_module == nullptr) return nullptr; + return hlo_module->GetHloInstruction(hlo_name); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_MAP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_module_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_module_utils.h new file mode 100644 index 00000000..100671de --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_module_utils.h @@ -0,0 +1,81 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ + +#include +#include + +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" + +namespace tensorflow { +namespace profiler { + +// Sometimes HLO produce a huge string (>100MB). Limit the name size to 1MB. +static constexpr size_t kMaxHlolNameSize = 1000000; + +inline const xla::HloInstruction* FindInstruction(const xla::HloModule& module, + std::string node_name) { + if (absl::StartsWith(node_name, "%")) { + node_name.erase(node_name.begin()); + } + for (const xla::HloComputation* computation : module.computations()) { + auto instrs = computation->instructions(); + auto it = absl::c_find_if(instrs, [&](const xla::HloInstruction* instr) { + // Try with and without "%" at the beginning of the node name. 
+ return absl::EqualsIgnoreCase(instr->name(), node_name) || + absl::EqualsIgnoreCase(instr->name(), + absl::StrCat("%", node_name)); + }); + if (it != instrs.end()) { + return *it; + } + } + return nullptr; +} + +inline const xla::HloComputation* FindComputation( + const xla::HloModule& module, const std::string& comp_name) { + for (const xla::HloComputation* computation : module.computations()) { + if (absl::EqualsIgnoreCase(computation->name(), comp_name)) { + return computation; + } + } + return nullptr; +} + +inline std::string UncachedExpression(const xla::HloInstruction* instr, + bool skip_expression, size_t max_size) { + if (skip_expression) { + return ""; + } + static const auto* hlo_print_options = + new xla::HloPrintOptions(xla::HloPrintOptions() + .set_print_metadata(false) + .set_print_backend_config(false) + .set_print_infeed_outfeed_config(false)); + std::string expression = instr->ToString(*hlo_print_options); + if (expression.size() > max_size) { + expression.resize(max_size); + } + return expression; +} +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_MODULE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_proto_map.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_proto_map.h new file mode 100644 index 00000000..cb376966 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_proto_map.h @@ -0,0 +1,84 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/service/hlo.pb.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +absl::flat_hash_map> +ParseHloProtosFromXSpace(const XSpace& space); + +class HloProtoMap { + public: + void AddHloProtosFromXSpace(const XSpace& space); + + void AddHloProto(uint64_t program_id, + std::unique_ptr hlo_proto); + // Returns whether is new to HloProtoMap. + bool AddHloProto(uint64_t program_id, const xla::HloProto* hlo_proto); + + size_t size() const { return hlo_protos_by_program_id_.size(); } + + auto begin() const { return hlo_protos_by_program_id_.begin(); } + auto end() const { return hlo_protos_by_program_id_.end(); } + + bool contains(absl::string_view name) const { + return hlo_protos_by_name_.contains(name); + } + + bool contains(uint64_t program_id) const { + return hlo_protos_by_program_id_.contains(program_id); + } + + // Returns a list of module names (not sorted). + std::vector GetModuleList() const; + + // Returns a list of module names sorted alphabetically. 
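A sketch combining the lookup and printing helpers above, assuming an xla::HloModule is available from elsewhere in the profiler; InstructionExpressionOrEmpty is a hypothetical caller, not part of the header.

#include <string>

#include "tensorflow/core/profiler/utils/hlo_module_utils.h"

namespace tensorflow {
namespace profiler {

// Returns the size-capped text form of the named instruction, or an empty
// string when the module does not contain it.
inline std::string InstructionExpressionOrEmpty(const xla::HloModule& module,
                                                const std::string& node_name) {
  const xla::HloInstruction* instr = FindInstruction(module, node_name);
  if (instr == nullptr) return "";
  return UncachedExpression(instr, /*skip_expression=*/false,
                            kMaxHlolNameSize);
}

}  // namespace profiler
}  // namespace tensorflow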
+ std::vector GetSortedModuleList() const; + + // Returns a list of hlo module names sorted first by heap trace size and then + // by hlo module name alphabetically. + std::vector GetSortedModuleListByHeapTraceSize() const; + + absl::StatusOr GetHloProtoByModuleName( + absl::string_view module_name) const; + + absl::StatusOr GetHloProtoByProgramId( + uint64_t program_id) const; + + private: + absl::flat_hash_map hlo_protos_by_program_id_; + absl::flat_hash_map hlo_protos_by_name_; + std::vector> owned_hlo_protos_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_MAP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_proto_to_module.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_proto_to_module.h new file mode 100644 index 00000000..d89b919d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/hlo_proto_to_module.h @@ -0,0 +1,36 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ + +#include + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo.pb.h" + +namespace tensorflow { +namespace profiler { + +absl::StatusOr> ConvertHloProtoToModule( + const xla::HloProto& hlo_proto); + +std::unique_ptr ConvertHloProtoToModuleIgnoringErrors( + const xla::HloProto& hlo_proto); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_HLO_PROTO_TO_MODULE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/host_offload_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/host_offload_utils.h new file mode 100644 index 00000000..4bb96f2e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/host_offload_utils.h @@ -0,0 +1,73 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/layout.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_visitor.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +struct LineBuilderAndEventEndTimeFrontier { + XLineBuilder line_builder; + uint64_t event_end_time_frontier_ns; +}; + +class HostOffloadEventProcessor { + public: + HostOffloadEventProcessor(XPlaneBuilder* plane_builder, + uint64_t start_timestamp_ns) + : plane_builder_(plane_builder), + start_timestamp_ns_(start_timestamp_ns) {} + ~HostOffloadEventProcessor() = default; + + void ProcessHostOffloadOpEvent(const XEventVisitor& event, + std::optional group_id); + + bool IsHostOffloadOpName(const XEventVisitor& event) const; + + private: + std::string GetOffloadInstructionID(absl::string_view op_name) const; + std::string GetOffloadInstructionName(absl::string_view op_name) const; + + absl::flat_hash_map> + seen_events_; + std::string host_memory_label_ = + absl::StrCat("S(", xla::Layout::kHostMemorySpace, ")"); + + XPlaneBuilder* plane_builder_; + uint64_t start_timestamp_ns_; + + std::vector + host_offload_op_line_builders_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_HOST_OFFLOAD_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/html_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/html_utils.h new file mode 100644 index 00000000..215d9f51 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/html_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace tensorflow { +namespace profiler { + +// Creates a html that links to the given url with the given text. 
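A sketch of how the host-offload processor declared above might be driven per event. The group id is passed as std::nullopt here because its exact integer type is not visible in this diff; MaybeProcessOffloadEvent is a hypothetical helper.

#include <optional>

#include "tensorflow/core/profiler/utils/host_offload_utils.h"

namespace tensorflow {
namespace profiler {

// Routes an event to the processor only when its name matches a known
// host-offload op; other events are left untouched.
inline void MaybeProcessOffloadEvent(const XEventVisitor& event,
                                     HostOffloadEventProcessor& processor) {
  if (processor.IsHostOffloadOpName(event)) {
    processor.ProcessHostOffloadOpEvent(event, std::nullopt);
  }
}

}  // namespace profiler
}  // namespace tensorflow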
+inline std::string AnchorElement(absl::string_view url, + absl::string_view text) { + return absl::StrCat("", text, ""); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_HTML_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/kernel_stats_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/kernel_stats_utils.h new file mode 100644 index 00000000..ee6f56d8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/kernel_stats_utils.h @@ -0,0 +1,135 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" + +namespace tensorflow { +namespace profiler { + +// Populates kernel launch information from a kKernelDetails XStat. +void ParseKernelLaunchParams(absl::string_view xstat_kernel_details, + KernelReport* kernel); + +// Returns true if kernel uses TensorCores. +bool IsKernelUsingTensorCore(absl::string_view kernel_name); + +// Returns true if operation is eligible to use TensorCores. +bool IsOpTensorCoreEligible(absl::string_view tf_op_name); + +// Returns true if Einsum equation is eligible to use TensorCores. +bool IsEinsumTensorCoreEligible(absl::string_view equation); + +// Less than comparator for Kernel Reports. +struct KernelReportLessThanComparator { + bool operator()(const KernelReport& lhs, const KernelReport& rhs) const; +}; + +// Equal to comparator for Kernel Reports. +struct KernelReportEqualToComparator { + bool operator()(const KernelReport& lhs, const KernelReport& rhs) const; +}; + +// Sorts kernel reorts by total duration descendingly. +// Keeps only the top kernel reports with long kernel duration in the given +// KernelStatsDb. Kernel reports with shorter kernel duration are dropped. +void SortAndKeepTopKDurationKernelReportsInDb(KernelStatsDb* kernel_stats_db); + +struct KernelReportValue { + uint64 total_duration_ns = 0; + uint64 min_duration_ns = 0; + uint64 max_duration_ns = 0; + uint64 occurrences = 0; +}; + +struct KernelKeyWrap { + const KernelReport* key; + template + friend H AbslHashValue(H h, KernelKeyWrap wrap) { + // Kernel reports are grouped by these fields, hence they are used as + // hashing criteria. 
+ // clang-format off + return H::combine( + std::move(h), + wrap.key->is_kernel_using_tensor_core(), + wrap.key->is_op_tensor_core_eligible(), + wrap.key->block_dim(0), + wrap.key->block_dim(1), + wrap.key->block_dim(2), + wrap.key->grid_dim(0), + wrap.key->grid_dim(1), + wrap.key->grid_dim(2), + wrap.key->registers_per_thread(), + wrap.key->static_shmem_bytes(), + wrap.key->dynamic_shmem_bytes(), + wrap.key->name(), + wrap.key->op_name()); + // clang-format on + } +}; + +struct KernelHash { + size_t operator()(const KernelReport& key) const { + return absl::Hash()(KernelKeyWrap{&key}); + } +}; + +using KernelReportMap = + absl::flat_hash_map; + +// Copies the top kernel reports with long kernel duration into the given +// KernelStatsDb. +void CopyTopKDurationKernelReportsToDb(const KernelReportMap& reports, + KernelStatsDb* dst); + +// Inserts or aggregates KernelReports into the given KernelReportMap. +void InsertOrUpdateKernelReport(const KernelReport& kernel, + const KernelReportValue& value, + KernelReportMap* dst); + +// Aggregates values from one KernelReportMap into another. +void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst); + +// Kernel stats aggregated at TF operation level. +struct OpLevelKernelStats { + // Whether op is eligible to use TensorCore. + bool is_op_tensor_core_eligible = false; + // The accumulated duration of all the kernels launched in this op. + uint64 total_duration_ns = 0; + // The accumulated duration of all the kernels using TensorCore in this op. + // If this value is not 0, at least one of the kernels launched by this op + // is using TensorCore. + uint64 tensor_core_duration_ns = 0; +}; + +using KernelStatsByOpName = + absl::flat_hash_map; + +// Groups KernelReport in by tensorflow operation name. +KernelStatsByOpName GroupKernelReportsByOpName( + const KernelStatsDb& kernel_stats_db); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_KERNEL_STATS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/math_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/math_utils.h new file mode 100644 index 00000000..380884ee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/math_utils.h @@ -0,0 +1,120 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_MATH_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_MATH_UTILS_H_ + +#include + +#include "absl/base/macros.h" +#include "xla/tsl/profiler/utils/math_utils.h" + +// TODO: b/323943471 - This macro should eventually be provided by Abseil. 
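A small derived metric built from the per-op kernel aggregation declared above; TensorCoreTimeFraction is illustrative and not part of the vendored header.

#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"

namespace tensorflow {
namespace profiler {

// Fraction of this op's kernel time that ran on Tensor Cores, in [0, 1].
inline double TensorCoreTimeFraction(const OpLevelKernelStats& stats) {
  if (stats.total_duration_ns == 0) return 0.0;
  return static_cast<double>(stats.tensor_core_duration_ns) /
         static_cast<double>(stats.total_duration_ns);
}

}  // namespace profiler
}  // namespace tensorflow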
+#ifndef ABSL_DEPRECATE_AND_INLINE +#define ABSL_DEPRECATE_AND_INLINE() +#endif + +namespace tensorflow { +namespace profiler { + +ABSL_DEPRECATE_AND_INLINE() +inline double CyclesToSeconds(double cycles, double frequency_hz) { + return tsl::profiler::CyclesToSeconds(cycles, frequency_hz); +} + +ABSL_DEPRECATE_AND_INLINE() +inline double GibibytesPerSecond(double gigabytes, double ns) { + return tsl::profiler::GibibytesPerSecond(gigabytes, ns); +} + +ABSL_DEPRECATE_AND_INLINE() +inline double GibiToGiga(double gibi) { + return tsl::profiler::GibiToGiga(gibi); +} + +ABSL_DEPRECATE_AND_INLINE() +inline double GigaToGibi(double giga) { + return tsl::profiler::GigaToGibi(giga); +} + +ABSL_DEPRECATE_AND_INLINE() +inline double GigaToTera(double giga) { + return tsl::profiler::GigaToTera(giga); +} + +ABSL_DEPRECATE_AND_INLINE() +inline double GigaToUni(double giga) { return tsl::profiler::GigaToUni(giga); } + +ABSL_DEPRECATE_AND_INLINE() +inline double MicroToMilli(double u) { return tsl::profiler::MicroToMilli(u); } + +ABSL_DEPRECATE_AND_INLINE() +inline double MicroToNano(double u) { return tsl::profiler::MicroToNano(u); } + +ABSL_DEPRECATE_AND_INLINE() +inline uint64_t MilliToNano(double m) { return tsl::profiler::MilliToNano(m); } + +ABSL_DEPRECATE_AND_INLINE() +inline uint64_t MilliToPico(double m) { return tsl::profiler::MilliToPico(m); } + +ABSL_DEPRECATE_AND_INLINE() +inline double MilliToUni(double m) { return tsl::profiler::MilliToUni(m); } + +ABSL_DEPRECATE_AND_INLINE() +inline double NanoToMicro(uint64_t n) { return tsl::profiler::NanoToMicro(n); } + +ABSL_DEPRECATE_AND_INLINE() +inline double NanoToMilli(uint64_t n) { return tsl::profiler::NanoToMilli(n); } + +ABSL_DEPRECATE_AND_INLINE() +inline uint64_t NanoToPico(uint64_t n) { return tsl::profiler::NanoToPico(n); } + +ABSL_DEPRECATE_AND_INLINE() +inline double PicoToMicro(uint64_t p) { return tsl::profiler::PicoToMicro(p); } + +ABSL_DEPRECATE_AND_INLINE() +inline double PicoToMilli(uint64_t p) { return tsl::profiler::PicoToMilli(p); } + +ABSL_DEPRECATE_AND_INLINE() +inline double PicoToNano(uint64_t p) { return tsl::profiler::PicoToNano(p); } + +ABSL_DEPRECATE_AND_INLINE() +inline double PicoToUni(uint64_t p) { return tsl::profiler::PicoToUni(p); } + +ABSL_DEPRECATE_AND_INLINE() +inline double SafeDivide(double dividend, double divisor) { + return tsl::profiler::SafeDivide(dividend, divisor); +} +ABSL_DEPRECATE_AND_INLINE() +inline double TeraToGiga(double tera) { + return tsl::profiler::TeraToGiga(tera); +} + +ABSL_DEPRECATE_AND_INLINE() +inline double UniToGiga(double uni) { return tsl::profiler::UniToGiga(uni); } + +ABSL_DEPRECATE_AND_INLINE() +inline double UniToMicro(double uni) { return tsl::profiler::UniToMicro(uni); } + +ABSL_DEPRECATE_AND_INLINE() +inline uint64_t UniToNano(double uni) { return tsl::profiler::UniToNano(uni); } + +ABSL_DEPRECATE_AND_INLINE() +inline uint64_t UniToPico(double uni) { return tsl::profiler::UniToPico(uni); } + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_MATH_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/op_metrics_db_utils.h new file mode 100644 index 00000000..e3ff3fcc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -0,0 +1,138 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" + +namespace tensorflow { +namespace profiler { + +// The name of OpMetrics to represent the idle time. +TF_CONST_INIT extern const absl::string_view kIdle; +// The core index to add to sparse core index in op metrics. +TF_CONST_INIT extern const uint32_t kSparseCoreIndexStart; + +// Helps build an op metrics database (borrowed). +// Enables fast lookup of existing ops and prevents the creation of duplicate +// ops. It is the user's responsibility to ensure an op metrics database +// outlives its builder, and that no ops are added to the database outside of +// the builder. +class OpMetricsDbBuilder { + public: + // Create with a borrowed op database. + // REQUIRED: The op database must be empty. + explicit OpMetricsDbBuilder(OpMetricsDb* db); + + protected: + // Looks up the given OP name. If it is already in the database, + // return its OpMetrics; otherwise, insert a new one. + OpMetrics* LookupOrInsertNewOpMetrics(uint64 hlo_module_id, + absl::string_view name, + uint64_t fingerprint); + + OpMetricsDb* db() { return db_; } + + private: + // Map op (hlo_module_id, name) to the corresponding metrics in the op + // database. + absl::flat_hash_map> + op_metrics_map_; + + // The op database. + OpMetricsDb* db_; +}; + +// Helps build an op metrics database (borrowed) from XEvents, +class XEventsOpMetricsDbBuilder { + public: + // Add OpMetric from XEventVisitor. + void AddOpMetric(const tsl::profiler::XEventVisitor& xevent); + + // Finalize OpMetricDb and add total time and Idle op. + OpMetricsDb Finalize(uint64_t total_time); + + // Finalize OpMetricDb, but the total time is unknown at the moment, So ignore + // the total time and Idle Op and will be handled by the caller. + OpMetricsDb Finalize(); + + private: + using OpMetricBySymbol = + absl::flat_hash_map; + absl::flat_hash_map + flat_op_metric_; +}; + +// Sets the total time for OpMetricsDb, ensuring idle time is not negative. +inline void SetTotalTimePs(OpMetricsDb& db, uint64_t total_time_ps) { + db.set_total_time_ps(std::max(db.total_op_time_ps(), total_time_ps)); +} + +// Returns the total time in OpMetricsDb, optionally excluding the idle time. +inline uint64_t TotalTimePs(const OpMetricsDb& db, bool exclude_idle = false) { + return exclude_idle ? db.total_op_time_ps() : db.total_time_ps(); +} + +// Returns the ratio of time that is idle (no op execution) over total time. 
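A sketch using only the two inline helpers declared above: record the profile's wall time, then read back the implied idle time. ImpliedIdleTimePs is a hypothetical helper name.

#include <cstdint>

#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h"

namespace tensorflow {
namespace profiler {

// Sets the db's total time (clamped so idle can never be negative) and
// returns total time minus op time, i.e. the idle portion.
inline uint64_t ImpliedIdleTimePs(OpMetricsDb& db, uint64_t wall_time_ps) {
  SetTotalTimePs(db, wall_time_ps);
  return TotalTimePs(db) - TotalTimePs(db, /*exclude_idle=*/true);
}

}  // namespace profiler
}  // namespace tensorflow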
+double IdleTimeRatio(const OpMetricsDb& db); + +// Returns the idle time in picoseconds. +uint64 IdleTimePs(const OpMetricsDb& db); + +// Populates an OpMetrics record representing idle time, i.e., the amount of +// time spent without any op execution. +void SetIdleOp(uint64_t idle_time_ps, OpMetrics& metrics); + +// Adds an OpMetrics record representing idle time, i.e., the amount of time +// spent without any op execution. +// REQUIRED: All ops must have been added to the database and the total time +// must have been set. +void AddIdleOp(OpMetricsDb& db); + +// Returns true if the given metrics represents idle time. +inline bool IsIdleOp(const OpMetrics& metrics) { + return metrics.category() == kIdle; +} + +// Returns the time spent in children (nested) ops. +inline uint64_t ChildrenTimePs(const OpMetrics& metrics) { + return metrics.time_ps() - metrics.self_time_ps(); +} + +// Returns the ratio of time spent sending data from the host to the device +// relative to the total time the host was active. +std::optional HostInfeedEnqueueRatio(const OpMetricsDb& db); + +// Converts from the device op metrics to Tf-op metrics. +OpMetricsDb CreateTfMetricsDbFromDeviceOpMetricsDb( + const OpMetricsDb& device_op_metrics_db, bool with_idle = true); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_OP_METRICS_DB_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/op_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/op_utils.h new file mode 100644 index 00000000..b3329b08 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/op_utils.h @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ + +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" +#include "tensorflow/core/profiler/utils/hlo_module_map.h" +#include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" + +namespace tensorflow { +namespace profiler { + +// Annotate the op_metrics with the metadata from the instr_wrapper. +void EnterOpMetadata(OpMetrics* op_metrics, + const HloInstructionWrapper* instr_wrapper); +void EnterOpMetadataFromHloModuleMap(OpMetrics* op_metrics, + const HloModuleMap& hlo_module_map); + +void AddFusionChildrenToOpMetricsFromHloInstruction( + OpMetrics* op_metrics, const HloInstructionWrapper* instr_wrapper); + +class HostOpMetricsDbBuilder : public OpMetricsDbBuilder { + public: + explicit HostOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} + + // A function that will be called when the end of an OP is + // observed on a trace, where: + // name = the OP name. 
+ // category = the OP category. + // is_eager = whether this OP is eagerly executed. + // time_ps = the total execution time of the OP in picoseconds, including + // the execution time of its children. + // children_time_ps = the execution time of the children of this OP in + // picoseconds + void EnterOp(absl::string_view name, absl::string_view category, + bool is_eager, uint64 time_ps, uint64 children_time_ps); + + // Updates total_host_infeed_enq_duration_ps_ and + // total_host_infeed_enq_duration_ps_. + void EnterHostInfeedEnqueue(tsl::profiler::Timespan host_infeed_enqueue); + + private: + // The tsl::profiler::Timespan of the last InfeedEnqueue op on this thread. + tsl::profiler::Timespan last_host_infeed_enqueue_; +}; + +class DeviceOpMetricsDbBuilder : public OpMetricsDbBuilder { + public: + explicit DeviceOpMetricsDbBuilder(OpMetricsDb* db) : OpMetricsDbBuilder(db) {} + + // A function that will be called when the end of an OP is + // observed on a trace, where: + // program_id = the ID of the program that contains this OP. + // name = the OP name. + // category = the OP category. + // provenance = the provenance of this OP (e.g. original TF OP). + // is_eager = whether this OP is eagerly executed. + // occurrences = the number of occurrences of this OP. + // time_ps = the total execution time of the OP in picoseconds, including + // the execution time of its children. + // children_time_ps = the execution time of the children of this OP in + // picoseconds. + // flops = the number of floating-point operations computed. + // bytes_accessed = the sum of bytes read and bytes written by this OP. + // memory_accessed_breakdown = the breakdown of memory accessed by operation + // type and memory space. + void EnterOp(uint64 program_id, absl::string_view name, + absl::string_view category, absl::string_view provenance, + absl::string_view deduplicated_name, bool is_eager, + uint64 occurrences, uint64 time_ps, uint64 children_time_ps, + int64_t flops, int64_t bytes_accessed, + const protobuf::RepeatedPtrField& + memory_accessed_breakdown = {}, + int64_t model_flops = 0); + + void EnterOpMetadata(uint64 program_id, absl::string_view program_name, + absl::string_view category, absl::string_view provenance, + absl::string_view deduplicated_name, bool is_eager, + absl::string_view long_name = ""); + + void EnterOpMetadataFromHloModuleMap(uint64 program_id, + absl::string_view op_name, + const HloModuleMap& hlo_module_map); +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/step_intersection.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/step_intersection.h new file mode 100644 index 00000000..cf2961ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/step_intersection.h @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" + +namespace tensorflow { +namespace profiler { + +// Description of how two step sequences are aligned. +struct StepsAlignment { + uint32 begin_subordinate_idx; // where the alignment begins on the + // subordinate steps. + uint32 begin_chief_idx; // where the alignment begins on the chief steps. + uint32 num_steps; // aligned for how many steps. +}; + +class StepIntersection { + public: + StepIntersection( + uint32 max_steps, + const absl::flat_hash_map& + perhost_stepdb); + + // Returns the number of steps in the intersection. + uint32 NumSteps() const { return end_chief_idx_ - begin_chief_idx_; } + + // Returns the value of empty_intersect_ (see the explanation of + // empty_intersect_ below). + bool EmptyIntersect() const { return empty_intersect_; } + + // Returns the step numbers for the destination (i.e. the intersection + // result). + std::vector DstStepNumbers() const; + + // Returns the index to the step in the given host that corresponds to the + // first step in the intersection. + uint32 FirstStepIndex(uint32 host_id) const; + + // Returns the number of steps dropped due to the max_steps constraint + // specified in the constructor. + uint32 StepsDropped() const { return steps_dropped_; } + + std::string DebugString() const; + + private: + absl::flat_hash_map perhost_alignment_; + uint32 + chief_host_id_; // the host whose step sequence is selected as the chief. + uint32 steps_dropped_; // number of steps dropped. + // If NumSteps() is 0, empty_intersect indicates one of two possible reasons: + // (i) At least one host has some steps, but the intersection over all hosts + // is empty. In this case, empty_intersect is true, + // (ii) None of the hosts has any steps. In this case, empty_intersect is + // false. + // If NumSteps() > 0, empty_intersect is don't care. + bool empty_intersect_; + // The begin and end indices to the chief step sequence for this step + // intersection. Note that the begin index is inclusive but the end index is + // exclusive. + uint32 begin_chief_idx_; + uint32 end_chief_idx_; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_STEP_INTERSECTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tfstreamz_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tfstreamz_utils.h new file mode 100644 index 00000000..25b7436c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tfstreamz_utils.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ + +#include +#include + +#include "tensorflow/core/lib/monitoring/collected_metrics.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace profiler { + +struct TfStreamzSnapshot { + std::unique_ptr metrics; + uint64 start_time_ns; // time before collection. + uint64 end_time_ns; // time after collection. +}; + +absl::Status SerializeToXPlane(const std::vector& snapshots, + XPlane* plane, uint64 line_start_time_ns); + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_TFSTREAMZ_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h new file mode 100644 index 00000000..731481a4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ + +#include + +#include "tensorflow/core/profiler/protobuf/steps_db.pb.h" + +namespace tensorflow { +namespace profiler { + +// Total duration of infeed from host or SparseCoreV0 to TensorCore. +inline uint64_t InfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.infeed_duration_ps() + tpu.wait_for_scv0_duration_ps() + + tpu.scv0_infeed_transform_ps(); +} + +// Total duration of outfeed from TensorCore to host or SparseCoreV0. +inline uint64_t OutfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.host_outfeed_ps() + tpu.scv0_outfeed_ps(); +} + +// Total duration of infeed from host to SparseCoreV0. +inline uint64_t ScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.wait_for_scv0_duration_ps() * tpu.scv0_infeed_percent() / 100.0; +} + +// Total duration of SparseCoreV0 compute. +inline uint64_t ScV0ComputeDurationPs(const TpuStepBreakdown& tpu) { + return tpu.wait_for_scv0_duration_ps() - ScV0InfeedDurationPs(tpu); +} + +// Total duration of infeed from host to TensorCore or SparseCoreV0. +inline uint64_t TcPlusScV0InfeedDurationPs(const TpuStepBreakdown& tpu) { + return tpu.infeed_duration_ps() + ScV0InfeedDurationPs(tpu); +} + +// Total duration of send and recv ops. +inline uint64_t SendRecvDurationPs(const TpuStepBreakdown& tpu) { + return tpu.send_duration_ps() + tpu.recv_duration_ps(); +} + +// Total duration of host send and host recv ops. 
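A sketch of a derived ratio built from the TPU step-breakdown helpers above; the step duration is assumed to come from the surrounding step database, and InfeedFraction is an illustrative helper, not part of the header.

#include <cstdint>

#include "tensorflow/core/profiler/utils/tpu_step_breakdown_utils.h"

namespace tensorflow {
namespace profiler {

// Fraction of a TPU step spent blocked on any infeed path, in [0, 1].
inline double InfeedFraction(const TpuStepBreakdown& tpu,
                             uint64_t step_duration_ps) {
  if (step_duration_ps == 0) return 0.0;
  return static_cast<double>(InfeedDurationPs(tpu)) /
         static_cast<double>(step_duration_ps);
}

}  // namespace profiler
}  // namespace tensorflow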
+inline uint64_t HostSendRecvDurationPs(const TpuStepBreakdown& tpu) { + return tpu.host_send_duration_ps() + tpu.host_recv_duration_ps(); +} + +// Total duration TensorCore spends waiting for host. +inline uint64_t WaitForHostDurationPs(const TpuStepBreakdown& tpu) { + return tpu.infeed_duration_ps() + tpu.host_outfeed_ps() + + HostSendRecvDurationPs(tpu) + tpu.tc_idle_ps(); +} + +// Total duration TensorCore spends waiting for host or SparseCoreV0. +inline uint64_t WaitForHostOrScV0DurationPs(const TpuStepBreakdown& tpu) { + return WaitForHostDurationPs(tpu) + tpu.wait_for_scv0_duration_ps(); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_BREAKDOWN_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tpu_step_details_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tpu_step_details_utils.h new file mode 100644 index 00000000..d26e4973 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/tpu_step_details_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ + +#include + +#include "tensorflow/core/profiler/protobuf/tpu_input_pipeline.pb.h" + +namespace tensorflow { +namespace profiler { + +inline double ComputeTimeMs(const PerTpuStepDetails& details) { + return details.tc_compute_time_ms() + details.scv0_compute_time_ms(); +} + +inline double InfeedTimeMs(const PerTpuStepDetails& details) { + return details.tc_infeed_time_ms() + details.scv0_infeed_time_ms(); +} + +inline double AllReduceTimeMs(const PerTpuStepDetails& details) { + return details.all_reduce_compute_time_ms() + + details.all_reduce_sync_time_ms(); +} + +inline double NonIdleTimeMs(const PerTpuStepDetails& details) { + return ComputeTimeMs(details) + InfeedTimeMs(details) + + AllReduceTimeMs(details) + details.tc_outfeed_time_ms(); +} + +// Time spent by a training step on TPU. +inline double StepTimeMs(const PerTpuStepDetails& details) { + return NonIdleTimeMs(details) + details.tc_idle_time_ms(); +} + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_TPU_STEP_DETAILS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/trace_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/trace_utils.h new file mode 100644 index 00000000..89e2b4cd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/trace_utils.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_TRACE_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_TRACE_UTILS_H_ + +#include "xla/tsl/profiler/utils/trace_utils.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::IsDerivedThreadId; // NOLINT +using tsl::profiler::kFirstDeviceId; // NOLINT +using tsl::profiler::kHostThreadsDeviceId; // NOLINT +using tsl::profiler::kLastDeviceId; // NOLINT +using tsl::profiler::kThreadIdDerivedMax; // NOLINT +using tsl::profiler::kThreadIdDerivedMin; // NOLINT +using tsl::profiler::kThreadIdHloModule; // NOLINT +using tsl::profiler::kThreadIdHloOp; // NOLINT +using tsl::profiler::kThreadIdHostOffloadOpEnd; // NOLINT +using tsl::profiler::kThreadIdHostOffloadOpStart; // NOLINT +using tsl::profiler::kThreadIdKernelLaunch; // NOLINT +using tsl::profiler::kThreadIdOverhead; // NOLINT +using tsl::profiler::kThreadIdSource; // NOLINT +using tsl::profiler::kThreadIdStepInfo; // NOLINT +using tsl::profiler::kThreadIdTfNameScope; // NOLINT +using tsl::profiler::kThreadIdTfOp; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_TRACE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_builder.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_builder.h new file mode 100644 index 00000000..c0e2c39b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_builder.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_BUILDER_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_BUILDER_H_ + +#include + +#include +#include +#include +#include + +#include "xla/tsl/profiler/utils/xplane_builder.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::XEventBuilder; // NOLINT +using tsl::profiler::XLineBuilder; // NOLINT +using tsl::profiler::XPlaneBuilder; // NOLINT +using tsl::profiler::XStatsBuilder; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_schema.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_schema.h new file mode 100644 index 00000000..cfa748bf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_schema.h @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_ + +#include "xla/tsl/profiler/utils/xplane_schema.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::FindHostEventType; // NOLINT +using tsl::profiler::FindStatType; // NOLINT +using tsl::profiler::FindTfOpEventType; // NOLINT +using tsl::profiler::GetHostEventTypeStr; // NOLINT +using tsl::profiler::GetStatTypeStr; // NOLINT +using tsl::profiler::GpuPlaneName; // NOLINT +using tsl::profiler::HostEventType; // NOLINT +using tsl::profiler::IsHostEventType; // NOLINT +using tsl::profiler::IsInternalEvent; // NOLINT +using tsl::profiler::IsInternalStat; // NOLINT +using tsl::profiler::IsStatType; // NOLINT +using tsl::profiler::kCuptiDriverApiPlaneName; // NOLINT +using tsl::profiler::kCustomPlanePrefix; // NOLINT +using tsl::profiler::kDeviceVendorAMD; // NOLINT +using tsl::profiler::kDeviceVendorNvidia; // NOLINT +using tsl::profiler::kGpuPlanePrefix; // NOLINT +using tsl::profiler::kHostOffloadOpLineName; // NOLINT +using tsl::profiler::kHostThreadsPlaneName; // NOLINT +using tsl::profiler::kKernelLaunchLineName; // NOLINT +using tsl::profiler::kMegaScaleBarrier; // NOLINT +using tsl::profiler::kMegaScaleD2HTransferFinished; // NOLINT +using tsl::profiler::kMegaScaleD2HTransferStart; // NOLINT +using tsl::profiler::kMegaScaleDcnReceive; // NOLINT +using tsl::profiler::kMegaScaleDcnSend; // NOLINT +using tsl::profiler::kMegaScaleDcnSendFinished; // NOLINT +using tsl::profiler::kMegaScaleH2DTransferFinished; // NOLINT +using tsl::profiler::kMegaScaleH2DTransferStart; // NOLINT +using tsl::profiler::kMegaScaleHostCommand; // NOLINT +using tsl::profiler::kMegaScaleTopologyDiscovery; // NOLINT +using tsl::profiler::kMetadataPlaneName; // NOLINT +using tsl::profiler::kPythonTracerPlaneName; // NOLINT +using 
tsl::profiler::kRoctracerApiPlaneName; // NOLINT +using tsl::profiler::kSourceLineName; // NOLINT +using tsl::profiler::kSparseCorePlaneRegex; // NOLINT +using tsl::profiler::kStepLineName; // NOLINT +using tsl::profiler::kTensorFlowNameScopeLineName; // NOLINT +using tsl::profiler::kTensorFlowOpLineName; // NOLINT +using tsl::profiler::kTFStreamzPlaneName; // NOLINT +using tsl::profiler::kTpuPlanePrefix; // NOLINT +using tsl::profiler::kTpuPlaneRegex; // NOLINT +using tsl::profiler::kTpuRuntimePlaneName; // NOLINT +using tsl::profiler::kXlaAsyncOpLineName; // NOLINT +using tsl::profiler::kXlaModuleLineName; // NOLINT +using tsl::profiler::kXlaOpLineName; // NOLINT +using tsl::profiler::kXProfMetadataBufferSize; // NOLINT +using tsl::profiler::kXProfMetadataFlow; // NOLINT +using tsl::profiler::kXProfMetadataKey; // NOLINT +using tsl::profiler::kXProfMetadataTransfers; // NOLINT +using tsl::profiler::StatType; // NOLINT +using tsl::profiler::TpuPlaneName; // NOLINT +using tsl::profiler::XFlow; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_test_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_test_utils.h new file mode 100644 index 00000000..c2619394 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_test_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/utils/xplane_builder.h" +#include "tensorflow/core/profiler/utils/xplane_schema.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::CreateTfFunctionCallEvent; // NOLINT +using tsl::profiler::CreateXEvent; // NOLINT +using tsl::profiler::GetOrCreateGpuXPlane; // NOLINT +using tsl::profiler::GetOrCreateHostXPlane; // NOLINT +using tsl::profiler::GetOrCreateTpuXPlane; // NOLINT +using tsl::profiler::XStatValue; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_utils.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_utils.h new file mode 100644 index 00000000..9292ed6a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_utils.h @@ -0,0 +1,66 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_UTILS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_UTILS_H_ + +#include +#include +#include +#include + +#include "xla/tsl/profiler/utils/xplane_utils.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::AddFlowsToXplane; // NOLINT +using tsl::profiler::AggregateXPlane; // NOLINT +using tsl::profiler::FindLinesWithId; // NOLINT +using tsl::profiler::FindLineWithId; // NOLINT +using tsl::profiler::FindLineWithName; // NOLINT +using tsl::profiler::FindMutablePlanes; // NOLINT +using tsl::profiler::FindMutablePlanesWithPrefix; // NOLINT +using tsl::profiler::FindMutablePlaneWithName; // NOLINT +using tsl::profiler::FindOrAddMutablePlaneWithName; // NOLINT +using tsl::profiler::FindOrAddMutableStat; // NOLINT +using tsl::profiler::FindPlanes; // NOLINT +using tsl::profiler::FindPlanesWithNames; // NOLINT +using tsl::profiler::FindPlanesWithPrefix; // NOLINT +using tsl::profiler::FindPlaneWithName; // NOLINT +using tsl::profiler::GetDevicePlaneFingerprint; // NOLINT +using tsl::profiler::GetSortedEvents; // NOLINT +using tsl::profiler::GetStartTimestampNs; // NOLINT +using tsl::profiler::IsEmpty; // NOLINT +using tsl::profiler::MergePlanes; // NOLINT +using tsl::profiler::NormalizeTimestamps; // NOLINT +using tsl::profiler::RemoveEmptyLines; // NOLINT +using tsl::profiler::RemoveEmptyPlanes; // NOLINT +using tsl::profiler::RemoveEvents; // NOLINT +using tsl::profiler::RemoveLine; // NOLINT +using tsl::profiler::RemovePlane; // NOLINT +using tsl::profiler::RemovePlanes; // NOLINT +using tsl::profiler::SortPlanesById; // NOLINT +using tsl::profiler::SortXLinesBy; // NOLINT +using tsl::profiler::SortXPlane; // NOLINT +using tsl::profiler::SortXSpace; // NOLINT +using tsl::profiler::XEventContextTracker; // NOLINT +using tsl::profiler::XEventsComparator; // NOLINT +using tsl::profiler::XEventTimespan; // NOLINT +using tsl::profiler::XLinesComparatorByName; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_visitor.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_visitor.h new file mode 100644 index 00000000..81db4a4f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xplane_visitor.h @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_ + +#include "xla/tsl/profiler/utils/xplane_visitor.h" + +namespace tensorflow { +namespace profiler { + +using tsl::profiler::TypeGetter; // NOLINT +using tsl::profiler::TypeGetterList; // NOLINT +using tsl::profiler::XEventMetadataVisitor; // NOLINT +using tsl::profiler::XEventVisitor; // NOLINT +using tsl::profiler::XLineVisitor; // NOLINT +using tsl::profiler::XPlaneVisitor; // NOLINT +using tsl::profiler::XStatsOwner; // NOLINT +using tsl::profiler::XStatVisitor; // NOLINT + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h new file mode 100644 index 00000000..6977295c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" + +namespace tensorflow { +namespace profiler { + +// XProfGpuCostAnalysis provides additional cost analysis for XProf, which +// normalizes the flops to the device flops based on input bit widths. +class XProfGpuCostAnalysis : public xla::gpu::GpuHloCostAnalysis { + public: + explicit XProfGpuCostAnalysis(const xla::HloCostAnalysis::Options& options) + : xla::gpu::GpuHloCostAnalysis(options) {} + + absl::Status Postprocess(const xla::HloInstruction* hlo) override; + + int64_t GetDeviceFlopsAdjustment(const xla::HloInstruction& hlo); + + protected: + std::unique_ptr CreateNestedCostAnalysis() override; + + private: + static inline constexpr absl::string_view kDeviceFlopsAdjustment = + "device_flops_adjustment"; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/public/session.h b/third_party/tflite-hdrs/tensorflow/core/public/session.h new file mode 100644 index 00000000..b16a5955 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/public/session.h @@ -0,0 +1,362 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PUBLIC_SESSION_H_ +#define TENSORFLOW_CORE_PUBLIC_SESSION_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/threadpool_options.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { + +class DeviceMgr; + +/// \brief A Session instance lets a caller drive a TensorFlow graph +/// computation. +/// +/// When a Session is created with a given target, a new Session object +/// is bound to the universe of resources specified by that target. +/// Those resources are available to this session to perform +/// computation described in the GraphDef. After extending the session +/// with a graph, the caller uses the Run() API to perform the +/// computation and potentially fetch outputs as Tensors. +/// +/// Example: +/// +/// ```c++ +/// +/// tensorflow::GraphDef graph; +/// // ... Create or load graph into "graph". +/// +/// // This example uses the default options which connects +/// // to a local runtime. +/// tensorflow::SessionOptions options; +/// std::unique_ptr<tensorflow::Session> +/// session(tensorflow::NewSession(options)); +/// +/// // Create the session with this graph. +/// tensorflow::Status s = session->Create(graph); +/// if (!s.ok()) { ... } +/// +/// // Run the graph and fetch the first output of the "output" +/// // operation, and also run to but do not return anything +/// // for the "update_state" operation. +/// std::vector<tensorflow::Tensor> outputs; +/// s = session->Run({}, {"output:0"}, {"update_state"}, &outputs); +/// if (!s.ok()) { ... } +/// +/// // Map the output as a flattened float tensor, and do something +/// // with it. +/// auto output_tensor = outputs[0].flat<float>(); +/// if (output_tensor(0) > 0.5) { ... } +/// +/// // Close the session to release the resources associated with +/// // this session. +/// session->Close(); +/// +/// ``` +/// +/// A Session allows concurrent calls to Run(), though a Session must +/// be created / extended by a single thread. +/// +/// Only one thread must call Close(), and Close() must only be called +/// after all other calls to Run() have returned. +class Session { + public: + Session(); + virtual ~Session(); + + /// \brief Create the graph to be used for the session. + /// + /// Returns an error if this session has already been created with a + /// graph. To re-use the session with a different graph, the caller + /// must Close() the session first. + virtual absl::Status Create(const GraphDef& graph) = 0; +#ifndef SWIG + virtual absl::Status Create(GraphDef&& graph) { return Create(graph); } +#endif + + /// \brief Adds operations to the graph that is already registered with the + /// Session.
+ /// + /// The names of new operations in "graph" must not exist in the + /// graph that is already registered. + virtual absl::Status Extend(const GraphDef& graph) = 0; +#ifndef SWIG + virtual absl::Status Extend(GraphDef&& graph) { return Extend(graph); } +#endif + + /// \brief Runs the graph with the provided input tensors and fills + /// `outputs` for the endpoints specified in `output_tensor_names`. + /// Runs to but does not return Tensors for the nodes in + /// `target_tensor_names`. + /// + /// The order of tensors in `outputs` will match the order provided + /// by `output_tensor_names`. + /// + /// If `Run` returns `OK()`, then `outputs->size()` will be equal to + /// `output_tensor_names.size()`. If `Run` does not return `OK()`, the + /// state of `outputs` is undefined. + /// + /// REQUIRES: The name of each Tensor of the input or output must + /// match a "Tensor endpoint" in the `GraphDef` passed to `Create()`. + /// + /// REQUIRES: At least one of `output_tensor_names` and + /// `target_tensor_names` must be non-empty. + /// + /// REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty. + virtual absl::Status Run( + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_tensor_names, + std::vector* outputs) = 0; + + /// \brief Implementations which support `RunOptions`. + // + /// NOTE: This API is still experimental and may change. + virtual absl::Status Create(const RunOptions& run_options, + const GraphDef& graph) { + return absl::UnimplementedError( + "Create(const RunOptions& run_options, const GraphDef& graph) is not " + "supported for this session."); + } + virtual absl::Status Extend(const RunOptions& run_options, + const GraphDef& graph) { + return absl::UnimplementedError( + "Extend(const RunOptions& run_options, const GraphDef& graph) is not " + "supported for this session."); + } +#ifndef SWIG + virtual absl::Status Create(const RunOptions& run_options, GraphDef&& graph) { + return Create(run_options, graph); + } + virtual absl::Status Extend(const RunOptions& run_options, GraphDef&& graph) { + return Extend(run_options, graph); + } +#endif + virtual absl::Status Close(const RunOptions& run_options) { + return absl::UnimplementedError( + "Close(const RunOptions& run_options) is not supported for this " + "session."); + } + + /// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and + /// to retrieve non-Tensor metadata output via a `RunMetadata` proto for this + /// step. `run_metadata` may be nullptr, in which case any metadata output is + /// discarded. + /// NOTE: This API is still experimental and may change. + virtual absl::Status Run( + const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_tensor_names, + std::vector* outputs, RunMetadata* run_metadata); + + /// \brief Like `Run` with `RunOptions` proto, but allows user to provide + /// custom threadpool implementation via ThreadPoolOptions. + /// NOTE: This API is still experimental and may change. + virtual absl::Status Run( + const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_tensor_names, + std::vector* outputs, RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options) { + return absl::UnimplementedError( + "Run with threadpool is not supported for this session."); + } + + /// \brief Sets up a graph for partial execution. 
All future feeds and + /// fetches are specified by `input_names` and `output_names`. Returns + /// `handle` that can be used to perform a sequence of partial feeds and + /// fetches. + /// NOTE: This API is still experimental and may change. + virtual absl::Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + std::string* handle); + + /// \brief Continues the pending execution specified by `handle` with the + /// provided input tensors and fills `outputs` for the endpoints specified + /// in `output_names`. + /// NOTE: This API is still experimental and may change. + virtual absl::Status PRun( + const std::string& handle, + const std::vector >& inputs, + const std::vector& output_names, + std::vector* outputs); + + /// \brief List devices in the session. + /// + /// Retrieves the list of available devices within the session, and populates + /// *response. This API is optional. If it is unimplemented, Status will + /// return a corresponding error message, and *response will be unmodified. + virtual absl::Status ListDevices(std::vector* response) = 0; + + /// \brief Closes this session. + /// + /// Closing a session releases the resources used by this session + /// on the TensorFlow runtime (specified during session creation by + /// the `SessionOptions::target` field). + virtual absl::Status Close() = 0; + + // NOTE(ashankar): As of July 2017, this method was added to facilitate some + // experimentation. Reconsider/re-evaluate after September 2017. + // + // Sets `*output` to the `DeviceMgr` that owns accessible devices in the + // address-space of the caller. + virtual absl::Status LocalDeviceManager(const DeviceMgr** output) { + return absl::UnimplementedError( + "LocalDeviceManager is not supported for this session."); + } + + /// \brief A handle to a subgraph, created with `Session::MakeCallable()`. + typedef int64_t CallableHandle; + + /// \brief Creates a `handle` for invoking the subgraph defined by + /// `callable_options`. + /// NOTE: This API is still experimental and may change. + virtual absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + return absl::UnimplementedError( + "MakeCallable is not supported for this session."); + } + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. + /// + /// The order of tensors in `feed_tensors` must and `fetch_tensors` will + /// match the order of names in `CallableOptions::feed()` and + /// `CallableOptions::fetch()` when this subgraph was created. + /// NOTE: This API is still experimental and may change. + virtual absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { + return absl::UnimplementedError( + "RunCallable is not supported for this session."); + } + + /// \brief Invokes the subgraph named by `handle` with the given options and + /// input tensors. User can provide custom threadpool implementation via + /// threadpool_options. + /// + /// The order of tensors in `feed_tensors` must and `fetch_tensors` will + /// match the order of names in `CallableOptions::feed()` and + /// `CallableOptions::fetch()` when this subgraph was created. + /// NOTE: This API is still experimental and may change. 
+ virtual absl::Status RunCallable( + CallableHandle handle, const std::vector& feed_tensors, + std::vector* fetch_tensors, RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options) { + return absl::UnimplementedError( + "RunCallable with threadpool is not supported for this session."); + } + + /// \brief Releases resources associated with the given `handle` in this + /// session. + /// NOTE: This API is still experimental and may change. + virtual absl::Status ReleaseCallable(CallableHandle handle) { + return absl::UnimplementedError( + "ReleaseCallable is not supported for this session."); + } + + /// \brief Release global graph-related state in this session. + /// + /// After calling `this->Finalize()`, calls to `this->Run()` with previously + /// unseen feeds and fetches, and calls to `this->MakeCallable()` will fail. + /// Using `MakeCallable()` and `RunCallable()` is recommended, because + /// explicit callable creation makes it clearer where the `Finalize()` call + /// should be placed. + /// + /// This API can be used in conjunction with a "warmup" phase to reduce the + /// memory consumed by the session: + /// + /// 1. Call `Session::Create()`. + /// 2. Call `Session::MakeCallable()` for all subgraphs that you will execute + /// in the session. + /// 3. Call `Session::Finalize()` to release global graph-related state. + /// 4. Call `Session::RunCallable()` with the handle(s) created in step 2. + /// + /// NOTE: This API is still experimental and may change. + virtual absl::Status Finalize() { + return absl::UnimplementedError( + "Finalize is not supported for this session."); + } +}; + +/// \brief Create a new session with the given options. +/// +/// If session creation succeeds, the new `Session` will be stored in +/// `*out_session`, the caller will take ownership of the returned +/// `*out_session`, and this function will return `OK()`. Otherwise, this +/// function will return an error status and set *out_session to nullptr. +absl::Status NewSession(const SessionOptions& options, Session** out_session); + +/// \brief Resets resource containers associated with a target. +/// +/// Reset() allows misbehaving or slow sessions to be aborted and closed, and +/// causes their resources eventually to be released. Reset() does not wait +/// for the computations in old sessions to cease; it merely starts the +/// process of tearing them down. However, if a new session is started after +/// a Reset(), the new session is isolated from changes that old sessions +/// (started prior to the Reset()) may continue to make to resources, provided +/// all those resources are in containers listed in "containers". +/// +/// Old sessions may continue to have side-effects on resources not in +/// containers listed in "containers", and thus may affect future +/// sessions' results in ways that are hard to predict. Thus, if well-defined +/// behavior is desired, it is recommended that all containers be listed in +/// "containers". +/// +/// `containers` is a vector of string representation of resource container +/// names. When a resource container is reset, the resources held by the +/// container will be released. In particular, all Variables in the container +/// will become undefined. If the "containers" vector is empty, the default +/// container is assumed. If the "containers" vector is non-empty, the +/// default container should be listed explicitly. +/// +/// If Reset succeeds, this function will return `OK()`. Otherwise, this +/// function will return an error status. 
+absl::Status Reset(const SessionOptions& options, + const std::vector& containers); + +/// \brief Create a new session with the given options. +/// +/// If a new `Session` object could not be created, this function will +/// return nullptr. +/// +/// *Strongly prefer* the version of NewSession that returns Status, +/// which contains more helpful error information. +Session* NewSession(const SessionOptions& options); + +/// \brief Export the metric that indicates the session is created. +void SetSessionCreatedMetric(); + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_PUBLIC_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/public/session_options.h b/third_party/tflite-hdrs/tensorflow/core/public/session_options.h new file mode 100644 index 00000000..92134528 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/public/session_options.h @@ -0,0 +1,67 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PUBLIC_SESSION_OPTIONS_H_ +#define TENSORFLOW_CORE_PUBLIC_SESSION_OPTIONS_H_ + +#include + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tsl { +class Env; +} // namespace tsl +namespace tensorflow { + +/// Configuration information for a Session. +struct SessionOptions { + /// The environment to use. + tsl::Env* env; + + /// \brief The TensorFlow runtime to connect to. + /// + /// If 'target' is empty or unspecified, the local TensorFlow runtime + /// implementation will be used. Otherwise, the TensorFlow engine + /// defined by 'target' will be used to perform all computations. + /// + /// "target" can be either a single entry or a comma separated list + /// of entries. Each entry is a resolvable address of the + /// following format: + /// local + /// ip:port + /// host:port + /// ... other system-specific formats to identify tasks and jobs ... + /// + /// NOTE: at the moment 'local' maps to an in-process service-based + /// runtime. + /// + /// Upon creation, a single session affines itself to one of the + /// remote processes, with possible load balancing choices when the + /// "target" resolves to a list of possible processes. + /// + /// If the session disconnects from the remote process during its + /// lifetime, session calls may fail immediately. + std::string target; + + /// Configuration options. + ConfigProto config; + + SessionOptions(); +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_PUBLIC_SESSION_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/public/version.h b/third_party/tflite-hdrs/tensorflow/core/public/version.h new file mode 100644 index 00000000..72ec42a5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/public/version.h @@ -0,0 +1,127 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PUBLIC_VERSION_H_ +#define TENSORFLOW_CORE_PUBLIC_VERSION_H_ + +// TensorFlow uses semantic versioning, see http://semver.org/. + +// Also update tensorflow/tensorflow.bzl and +// tensorflow/tools/pip_package/setup.py +#define TF_MAJOR_VERSION 2 +#define TF_MINOR_VERSION 19 +#define TF_PATCH_VERSION 0 + +// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", +// "-beta", "-rc", "-rc.1") +#define TF_VERSION_SUFFIX "" + +#define TF_STR_HELPER(x) #x +#define TF_STR(x) TF_STR_HELPER(x) + +// e.g. "0.5.0" or "0.6.0-alpha". +#define TF_VERSION_STRING \ + (TF_STR(TF_MAJOR_VERSION) "." TF_STR(TF_MINOR_VERSION) "." TF_STR( \ + TF_PATCH_VERSION) TF_VERSION_SUFFIX) + +// GraphDef compatibility versions (the versions field in graph.proto). +// +// Each graph has producer and min_consumer versions, and each +// consumer has its own version and a min_producer. In addition, graphs can +// mark specific consumer versions as bad (to prevent bugs from executing). +// A consumer will execute a graph if the consumer's version is at least the +// graph's min_consumer, the graph's producer version is at least the consumer's +// min_producer, and the consumer version isn't specifically disallowed by the +// graph. +// +// By default, newly created graphs have producer version TF_GRAPH_DEF_VERSION +// min_consumer TF_GRAPH_DEF_MIN_CONSUMER, and no other bad consumer versions. +// +// Version history: +// +// 0. Graphs created before GraphDef versioning +// 1. First real version (2dec2015) +// 2. adjust_contrast only takes float, doesn't perform clamping (11dec2015) +// 3. Remove TileGrad, since it was equivalent to reduce_sum (30dec2015) +// 4. When support for this version is removed, we can safely make AttrValue +// parsing more strict with respect to empty list values (see +// 111635679, 7jan2016). +// 5. Graphs are wholly-validated during Session::Create() (7jan2016). +// 6. TensorFlow is scalar strict within Google (27jan2016). +// 7. Remove TopK in favor of TopKV2 (5feb2016). +// 8. Replace RandomCrop from C++ with pure Python (5feb2016). +// 9. Deprecate batch_norm_with_global_normalization (16feb2016). +// 10. Deprecate conv3d_backprop_{filter,input} (10jun2016). +// 11. Deprecate {batch}_self_adjoint_eig (3aug2016). +// 12. Graph consumers understand the node_def field of FunctionDef (22aug2016). +// 13. Deprecate multiple batch linear algebra ops (9sep2016). +// 14. Deprecate batch_matrix_* ops. (10sep2016). +// 15. Deprecate batch_fft_* ops. (14sep2016). +// 16. Deprecate tensor_array (v1) ops in favor of v2 (10nov2016). +// 17. Deprecate inv (11nov2016). +// 17. Expose reverse_v2 (10nov2016) +// 18. Add VariableV2 (30nov2016) +// 19. Deprecated ops created by models moved out of core SkipGram, NegTrain. +// (08dec2016) +// 20. Catch all version 1.0 changes to Python API generation. 
SplitV is now +// used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is +// now used by tf.concat. Graphs use flooring +// division and mod semantics. TensorArrayV3. (12dec2016) +// Also considered the version for when it is required for reduction +// ops' indices to be scalar or vector, and not higher rank. +// Some earlier graph def versions allowed this. +// 21. Dropped FunctionDef.Node support, switched to node_def introduced +// in version 12. (11jan2017) +// 22. Placeholder now can specify and enforce scalar and partial +// shapes, particularly when restoring a graph from GraphDef +// produced at version 22 or later. (04/10/2016) +// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2. +// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017) +// 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15). +// 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25). +// 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating +// whether default-valued attrs have been stripped from the nodes in the +// GraphDef. (7dec2017) +// 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops +// deprecated in favor of V2 ops. (2018/01/23) +// 28. Deprecate MatrixExponential op in favor of Python implementation. +// (2018/08/21). +// (2019/02/15). Added `control_ret` field to FunctionDef proto, and +// `control_output` field to OpDef proto. +// 29. Deprecate StatefulStandardNormal op in favor of StatefulStandardNormalV2. +// (2019/03/25). +// (2019/04/17). Added `arg_attr` field to FunctionDefProto. +// 30. (2019/05/09) First date based GraphDef version. GraphDef +// versions advance by 1 each day after this point. + +#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 +#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 +#define TF_GRAPH_DEF_VERSION 2102 // Updated: 2025/1/9 + +// Checkpoint compatibility versions (the versions field in SavedSliceMeta). +// +// The checkpoint versions have the same semantics as GraphDef versions, but the +// numbering scheme is separate. We have no plans to ever deprecate checkpoint +// versions, but it's good to have this in place in case we ever need to. +// +// Version history: +// +// 0. Checkpoints saved before checkpoint versioning. +// 1. First real version (10feb2015). +#define TF_CHECKPOINT_VERSION_MIN_PRODUCER 0 +#define TF_CHECKPOINT_VERSION_MIN_CONSUMER 0 +#define TF_CHECKPOINT_VERSION 1 + +#endif // TENSORFLOW_CORE_PUBLIC_VERSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/bef_executor_flags.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/bef_executor_flags.h new file mode 100644 index 00000000..eccc43de --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/bef_executor_flags.h @@ -0,0 +1,51 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_BEF_EXECUTOR_FLAGS_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_BEF_EXECUTOR_FLAGS_H_ + +#include "absl/flags/declare.h" +#include "absl/flags/flag.h" +#include "tfrt/bef_executor_driver/bef_executor_driver.h" // from @tf_runtime + +namespace tfrt { +ABSL_CONST_INIT extern const char kDefaultInputFilename[]; + +struct HostAllocatorTypeWrapper { + HostAllocatorTypeWrapper(HostAllocatorType type) : type(type) {} + operator HostAllocatorType() { return type; } + HostAllocatorType type; +}; + +} // namespace tfrt + +ABSL_DECLARE_FLAG(std::string, input_filename); +ABSL_DECLARE_FLAG(std::string, shared_libs); +ABSL_DECLARE_FLAG(std::string, functions); +ABSL_DECLARE_FLAG(std::string, test_init_function); +ABSL_DECLARE_FLAG(std::string, work_queue_type); +ABSL_DECLARE_FLAG(tfrt::HostAllocatorTypeWrapper, host_allocator_type); + +namespace tfrt { + +bool AbslParseFlag(absl::string_view text, + tfrt::HostAllocatorTypeWrapper* host_allocator_type, + std::string* error); + +std::string AbslUnparseFlag(tfrt::HostAllocatorTypeWrapper host_allocator_type); + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_BEF_EXECUTOR_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/conversion/conversion.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/conversion/conversion.h new file mode 100644 index 00000000..c31855e2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/conversion/conversion.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements conversion function between RuntimeFallback and +// KernelFallback. + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_CONVERSION_CONVERSION_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_CONVERSION_CONVERSION_H_ + +namespace tfrt { + +class TensorConversionFnRegistry; + +} + +namespace tensorflow { +namespace tfd { +void RegisterRuntimeFallbackTensorToKernelFallbackConversionFn( + tfrt::TensorConversionFnRegistry* registry); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_CONVERSION_CONVERSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/attr_util.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/attr_util.h new file mode 100644 index 00000000..4abbb4f8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/attr_util.h @@ -0,0 +1,54 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_ATTR_UTIL_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_ATTR_UTIL_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "llvm/ADT/StringMap.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/runtime_fallback/util/attr_util.h" +#include "tensorflow/core/util/padding.h" +#include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime +#include "tfrt/host_context/kernel_utils.h" // from @tf_runtime + +namespace tensorflow { + +// Map from attribute name to a string value representation. +typedef llvm::StringMap AttrMap; + +// Parse value from the given string input. +absl::Status ParseValue(absl::string_view input, bool* value); +absl::Status ParseValue(absl::string_view input, int32* value); +absl::Status ParseValue(absl::string_view input, DataType* value); +absl::Status ParseValue(absl::string_view input, std::string* value); +absl::Status ParseValue(absl::string_view input, std::vector* value); +absl::Status ParseValue(absl::string_view input, Padding* value); + +absl::Status AddOpAttr(const std::string& name, const std::string& attr_value, + tfrt::OpAttrs* opattrs); + +absl::Status FillOpAttrs(tfrt::RemainingAttributes attrs, + tfrt::OpAttrs* opattrs); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_ATTR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/conversion/conversion.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/conversion/conversion.h new file mode 100644 index 00000000..782e31f7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/conversion/conversion.h @@ -0,0 +1,42 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements conversion function between KernelFallback and Host +// Tensor. 
+ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_CONVERSION_CONVERSION_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_CONVERSION_CONVERSION_H_ + +#include "tfrt/support/forward_decls.h" // from @tf_runtime +namespace tfrt { + +class TensorConversionFnRegistry; +class DenseHostTensor; +class CpuDevice; +class Device; +class ExecutionContext; +} + +namespace tensorflow { +class KernelFallbackTensor; +namespace tfd { + +void RegisterKernelFallbackTensorConversionFn( + tfrt::TensorConversionFnRegistry* registry); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_CONVERSION_CONVERSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h new file mode 100644 index 00000000..6cfbf88c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h @@ -0,0 +1,249 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__ + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/collective.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/graph_executor/config.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime +#include "tfrt/support/pointer_util.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +// FallbackResourceArray holds the tensors that are computed only once during +// initialization and read-only afterwards. +class FallbackResourceArray { + public: + // Sets `tensor` in the array at `index`. `index` should be dense and + // duplicate indices are not allowed. + void SetResource(int index, tfrt_stub::ImmutableTensor tensor); + + // Returns the resource tensor wrapped in AsyncValue value at `index`. + tfrt::AsyncValuePtr GetResource(int index) const { + return resource_async_values_.at(index).AsPtr(); + } + + // Returns the resource tensor at `index`. 
+ const tfrt_stub::FallbackTensor& GetResourceAsFallbackTensor( + int index) const { + return GetResource(index).get(); + } + + private: + // `resources_` holds the ownership of all the resource tensors. Note that it + // may not be a one-to-one mapping between `resources_` and + // `resource_async_values_`. + std::vector> resources_; + + // Storage for async values with manually managed lifetime. + std::vector>> + resource_storage_; + + // `resource_async_values_` holds the UnRefCountedAsyncValue of the fallback + // tensors that can be directly used by fallback kernels in the graph. + std::vector> + resource_async_values_; +}; + +// Per-request state in kernel falllback compat mode. +class KernelFallbackCompatRequestState { + public: + // NOTE: This is the constructor for training. + KernelFallbackCompatRequestState( + std::function)>* runner, + const tensorflow::DeviceMgr* device_manager, int64_t step_id, + tfrt::OwnedOrUnownedPtr step_container, + std::unique_ptr collective_executor, + core::RefCountPtr rendezvous, + tfrt_stub::OpKernelRunnerTable* runner_table, + FallbackResourceArray* resource_array, + tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool, + const absl::optional& model_metadata, + const tensorflow::ProcessFunctionLibraryRuntime* pflr); + + // NOTE: This is the constructor for inference. + KernelFallbackCompatRequestState( + std::function)>* runner, + const tensorflow::DeviceMgr* device_manager, int64_t step_id, + tfrt_stub::OpKernelRunnerTable* runner_table, + FallbackResourceArray* resource_array, + tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool, + const absl::optional& model_metadata, + const tensorflow::ProcessFunctionLibraryRuntime* pflr); + + int64_t step_id() const { return step_id_; } + + // Returns the user-specified custom device corresponding to the given device. + // It is currently only used for configure per-request intra op threadpool. 
+ tensorflow::Device* custom_device(const tensorflow::Device* device) const { + auto it = custom_device_.find(device); + if (it == custom_device_.end()) return nullptr; + return it->second.get(); + } + + tensorflow::Device* cpu_device() const { return cpu_device_; } + tensorflow::FunctionLibraryRuntime* cpu_function_library_runtime() const { + return cpu_function_library_runtime_; + } + + ScopedStepContainer* step_container() const { return step_container_.get(); } + + const tensorflow::DeviceMgr& device_manager() const { + return *device_manager_; + } + + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime() const { + return *pflr_; + } + + CollectiveExecutor* collective_executor() const { + return collective_executor_; + } + + tfrt_stub::OpKernelRunnerTable* runner_table() const { return runner_table_; } + + FallbackResourceArray* resource_array() const { return resource_array_; } + + std::function)>* runner() const { return runner_; } + + CancellationManager* cancellation_manager() const { + return cancellation_manager_; + } + void set_cancellation_manager(CancellationManager* cancellation_manager) { + cancellation_manager_ = cancellation_manager; + } + + RendezvousInterface* rendezvous() const { return rendezvous_.get(); } + + void set_log_device_placement(bool log) { log_device_placement_ = log; } + bool log_device_placement() const { return log_device_placement_; } + + tensorflow::thread::ThreadPoolInterface* intra_op_threadpool() const { + return intra_op_threadpool_; + } + + const SessionMetadata& session_metadata() const { return session_metadata_; } + + // Nullable. + tensorflow::tfrt_stub::CostRecorder* cost_recorder() const { + return cost_recorder_; + } + void set_cost_recorder(tensorflow::tfrt_stub::CostRecorder* cost_recorder) { + cost_recorder_ = cost_recorder; + } + + // Nullable. + tfrt::ResourceContext* client_graph_resource_context() const { + return client_graph_resource_context_; + } + void set_client_graph_resource_context( + tfrt::ResourceContext* client_graph_resource_context) { + client_graph_resource_context_ = client_graph_resource_context; + } + + void set_runtime_config( + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config) { + runtime_config_ = runtime_config; + } + + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config() const { + return runtime_config_; + } + + private: + int64_t step_id_ = 0; + // Below are resources needed by current tensorflow. + std::function)>* runner_ = nullptr; + ::tfrt::OwnedOrUnownedPtr step_container_; + absl::flat_hash_map> + custom_device_; + std::unique_ptr custom_cpu_device_; + tensorflow::Device* cpu_device_ = nullptr; + tensorflow::FunctionLibraryRuntime* cpu_function_library_runtime_ = nullptr; + std::unique_ptr collective_executor_handle_; + CollectiveExecutor* collective_executor_ = nullptr; + core::RefCountPtr rendezvous_; + CancellationManager* cancellation_manager_ = nullptr; + + const tensorflow::DeviceMgr* device_manager_ = nullptr; + + // `runner_table` holds the prepopulated tensorflow::OpKernel instances for + // kernel fallback compat mode. + tfrt_stub::OpKernelRunnerTable* runner_table_ = nullptr; + + // Resource array is used for keeping static values in the runtime. It is + // accessed through tfrt_fallback_async.set_resource and + // tfrt_fallback_async.get_resource kernels. + FallbackResourceArray* resource_array_ = nullptr; + + tensorflow::thread::ThreadPoolInterface* intra_op_threadpool_ = nullptr; + + // Model metadata used for monitoring and tracing purpose. 
+ SessionMetadata session_metadata_; + + const tensorflow::ProcessFunctionLibraryRuntime* pflr_ = nullptr; + + bool log_device_placement_ = false; + + // Records the cost per op. + tensorflow::tfrt_stub::CostRecorder* cost_recorder_ = nullptr; + + tfrt::ResourceContext* client_graph_resource_context_ = nullptr; + + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config_ = nullptr; +}; + +// Set up fallback context with common tensorflow states such as devices, +// function library runtime. They will be forwarded to tensorflow::OpKernel as +// in tensorflow::Executor. If `runner` is nullptr, internally it will use a +// default runner that executes tasks in the caller thread. +absl::Status SetUpKernelFallbackCompatRequestContext( + tfrt::RequestContextBuilder* builder, + const tensorflow::DeviceMgr* device_manager, + const tensorflow::ProcessFunctionLibraryRuntime* pflr, + tfrt_stub::OpKernelRunnerTable* runner_table, + FallbackResourceArray* resource_array, + tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool, + const std::optional& model_metadata, + std::function)>* runner, + tfrt_stub::CostRecorder* cost_recorder, + tfrt::ResourceContext* client_graph_resource_context, + tensorflow::CancellationManager* cancellation_manager, + const tensorflow::tfrt_stub::RuntimeConfig* runtime_config); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_COMPAT_REQUEST_STATE_H__ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute.h new file mode 100644 index 00000000..f0c6359b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute.h @@ -0,0 +1,51 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Provides a way to execute a TensorFlow kernel using TFRT kernel fallback. + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime +#include "tfrt/host_context/async_value.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/support/ref_count.h" // from @tf_runtime + +namespace tfrt { +class AsyncKernelFrame; +} // namespace tfrt + +namespace tensorflow { +namespace tfd { + +enum KernelFallbackOutputType { + TENSOR = 0, // Output type is tensorflow::Tensor + KERNEL_FALLBACK_TENSOR = 1 // Output type is KernelFallbackTensor +}; + +// Runs kernel asynchronously. +// `frame` must contain tensorflow::Tensor inputs and pre-allocated +// tensorflow::Tensor or tfrt::KernelFallbackTensor outputs. 
+bool KernelFallbackExecute( + const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name, + llvm::ArrayRef arguments, + llvm::MutableArrayRef> results, + const tfrt::OpAttrsRef& attrs, KernelFallbackOutputType output_type); +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h new file mode 100644 index 00000000..a3888486 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h @@ -0,0 +1,56 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_COMPAT_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_COMPAT_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime +#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime +#include "tfrt/host_context/chain.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/kernel_utils.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/tensor/tensor.h" // from @tf_runtime + +namespace tfrt { +class SyncKernelFrame; +} // namespace tfrt + +namespace tensorflow { +namespace tfd { + +ABSL_CONST_INIT extern const char kOpKernelRunnerCacheResourceName[]; + +// The CoreRuntime dispatch function to run a TF kernel in kernel fallback +// compat mode. 
+tfrt::AsyncValueRef KernelFallbackExecuteCompatCoreRuntimeDispatch( + const tfrt::ExecutionContext& exec_ctx, tfrt::string_view op_name, + tfrt::string_view device_name, llvm::ArrayRef arguments, + llvm::MutableArrayRef> results, + const KernelFallbackCompatRequestState& fallback_request_state, + const tfrt_stub::OpKernelRunner& op_kernel_runner); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_COMPAT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h new file mode 100644 index 00000000..cf3e0014 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat_eager.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_COMPAT_EAGER_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_COMPAT_EAGER_H_ + +#include + +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tfrt/host_context/execution_context.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +// Runner_table can be nullptr. In that case, kernel_fallback will use +// the default runner_table. +absl::Status SetUpKernelFallbackCompatRequestContext( + tfrt::RequestContextBuilder* builder, + tfrt_stub::OpKernelRunnerTable* runner_table, + tensorflow::EagerContext* eager_context, + tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool = nullptr, + const absl::optional& model_metadata = std::nullopt); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_EXECUTE_COMPAT_EAGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.h new file mode 100644 index 00000000..b003c4e9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_op_handler.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file declares KernelFallbackOpHandler, responsible for running TFRT ops +// on Tensorflow. + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_OP_HANDLER_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_OP_HANDLER_H_ + +#include "llvm/Support/Error.h" +#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime +#include "tfrt/core_runtime/op_handler.h" // from @tf_runtime +#include "tfrt/host_context/device.h" // from @tf_runtime +#include "tfrt/support/ref_count.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +llvm::Expected CreateKernelFallbackOpHandler( + tfrt::CoreRuntime* runtime, tfrt::RCReference device); + +} // namespace tfd +} // namespace tensorflow +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_OP_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h new file mode 100644 index 00000000..8ade7d00 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_tensor.h @@ -0,0 +1,66 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file declares TF kernel fallback tensor. + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_TENSOR_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_TENSOR_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tfrt/dtype/dtype.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/tensor/tensor.h" // from @tf_runtime +#include "tfrt/tensor/tensor_shape.h" // from @tf_runtime + +namespace tensorflow { + +class BaseKernelFallbackTensor : public tfrt::Tensor { + public: + explicit BaseKernelFallbackTensor(::tensorflow::Tensor tensor); + BaseKernelFallbackTensor(const tfrt::TensorShape& shape, tfrt::DType dtype, + ::tensorflow::Tensor tensor); + + void Print(tfrt::raw_ostream& os) const override; + + const ::tensorflow::Tensor* GetTensor() const { return &tensor_; } + + private: + ::tensorflow::Tensor tensor_; + bool is_valid_type_; +}; + +class KernelFallbackTensor final + : public BaseKernelFallbackTensor, + public tfrt::TensorTraits { + public: + explicit KernelFallbackTensor(::tensorflow::Tensor tensor) + : BaseKernelFallbackTensor(std::move(tensor)) {} + KernelFallbackTensor(const tfrt::TensorShape& shape, tfrt::DType dtype, + ::tensorflow::Tensor tensor) + : BaseKernelFallbackTensor(shape, dtype, std::move(tensor)) {} + + static KernelFallbackTensor Create(const tensorflow::Tensor& tensor) { + return KernelFallbackTensor(tensor); + } + + // Tensor type name for KernelFallbackTensor. 
+ static const char* name() { return "KernelFallback"; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h new file mode 100644 index 00000000..b3879716 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h @@ -0,0 +1,55 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_UTILS_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_UTILS_H_ + +#include + +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tfrt/host_context/async_value.h" // from @tf_runtime +#include "tfrt/host_context/sync_kernel_utils.h" // from @tf_runtime +#include "tfrt/host_context/value.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/support/variant.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +std::function)>* GetDefaultRunner(); + +using TfInputs = + tfrt::Variant, + tfrt::RepeatedSyncArguments&>; + +// Sets up the OpKernelcontext::Params in `run_state` with the objects and data +// in `runner`, `fallback_request_state` and `device`. +void SetUpParams(const tensorflow::tfrt_stub::OpKernelRunner& runner, + const KernelFallbackCompatRequestState& fallback_request_state, + tensorflow::Device* device, + tensorflow::tfrt_stub::OpKernelRunState& run_state); + +// Return the device to be used for the fallback kernel execution. The device is +// guaranteed to be alive during the graph execution. +tensorflow::Device* GetDeviceFromFallbackState( + const KernelFallbackCompatRequestState& fallback_request_state, + const tfrt_stub::OpKernelRunner& kernel_runner); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_KERNEL_FALLBACK_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/tensor_util.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/tensor_util.h new file mode 100644 index 00000000..6126f104 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/tensor_util.h @@ -0,0 +1,132 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TENSOR_UTIL_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TENSOR_UTIL_H_ + +#include + +#include "tensorflow/core/common_runtime/copy_tensor.h" +#include "tensorflow/core/framework/device.h" +#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime +#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime + +namespace tfrt { +class Device; +} // namespace tfrt + +namespace tensorflow { +class KernelFallbackTensor; +namespace tfd { + +// Transfers tensor `src` from `src_device` to `dst_device`. +// Returns the transferred tensor on `dst_device` wrapped as +// `TensorWrapperType`. +template +tfrt::AsyncValueRef TransferTensorToDevice( + const tfrt::ExecutionContext& exec_ctx, const Tensor& src, + Device* src_device, Device* dst_device) { + const bool is_same_device = + (src_device == dst_device) || (src_device->name() == dst_device->name()); + + // Note: source and destination CPU devices are expected to be on the same + // host. Currently TFRT doesn't support checking if a CPU is remote CPU, + // we may consider adding a remote CPU device type in the future. + const bool src_cpu = + src_device->tensorflow_accelerator_device_info() == nullptr; + const bool dst_cpu = + dst_device->tensorflow_accelerator_device_info() == nullptr; + const bool is_between_cpu_devices = dst_cpu && src_cpu; + + if (is_same_device || is_between_cpu_devices) { + return tfrt::MakeAvailableAsyncValueRef(src); + } + + if (!dst_cpu && (src.dtype() != tensorflow::DT_VARIANT && + !tensorflow::DataTypeCanUseMemcpy(src.dtype()))) { + return tfrt::MakeErrorAsyncValueRef(absl::InternalError(tfrt::StrCat( + "Can't copy Tensor with type ", tensorflow::DataTypeString(src.dtype()), + " to device ", dst_device->name(), "."))); + } + tensorflow::AllocatorAttributes attr; + if (src.dtype() == tensorflow::DT_VARIANT) { + attr.set_on_host(true); + } + tensorflow::Tensor dst(dst_device->GetAllocator(attr), src.dtype(), + src.shape()); + if (src.shape().num_elements() == 0) { + return tfrt::MakeAvailableAsyncValueRef(dst); + } + + auto result = tfrt::MakeUnconstructedAsyncValueRef(); + bool enqueued = tfrt::EnqueueBlockingWork( + exec_ctx.host(), [result = result.CopyRef(), src_cpu, dst_cpu, src_device, + dst_device, src, dst = std::move(dst)]() mutable { + tensorflow::DeviceContext* src_device_context = nullptr; + if (!src_cpu) { + src_device_context = + src_device->tensorflow_accelerator_device_info()->default_context; + } + tensorflow::DeviceContext* dst_device_context = nullptr; + if (!dst_cpu) { + dst_device_context = + dst_device->tensorflow_accelerator_device_info()->default_context; + } + // TODO(tfrt-devs): The Sync() call below may be more aggressive than + // necessary. It is based on knowledge of implementation details - that + // GPU devices are implemented using 3 streams - one for host->device + // copies, one for device->host copies and one for sending operations to + // the GPU. 
With that setup, Sync()ing across all 3 streams should be + // sufficient but more than necessary (since it waits for operations + // that might have nothing to do with this tensor to complete). + absl::Status s = src_device->Sync(); + if (!s.ok()) { + result.SetError(absl::InternalError(s.message())); + return; + } + tensorflow::Notification n; + absl::Status status; + tensorflow::CopyTensor::ViaDMA( + "copy", src_device_context, dst_device_context, src_device, + dst_device, tensorflow::AllocatorAttributes(), + tensorflow::AllocatorAttributes(), &src, &dst, + 0 /*dev_to_dev_stream_index*/, + [&status, &n](const absl::Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + if (status.ok()) { + result.emplace(std::move(dst)); + } + }); + + if (!enqueued) { + return tfrt::MakeErrorAsyncValueRef(absl::InternalError( + "Failed to enqueue blocking task to transfer tensor.")); + } + return result; +} + +tfrt::AsyncValueRef TransferTensorToDevice( + const tfrt::ExecutionContext& exec_ctx, const KernelFallbackTensor& tensor, + const tfrt::Device& src_device, const tfrt::Device& dst_device); + +llvm::Expected GetTfDevice(const tfrt::ExecutionContext& exec_ctx, + const tfrt::Device& device); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TENSOR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h new file mode 100644 index 00000000..e370fde5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/kernel/tfrt_op_kernel.h @@ -0,0 +1,317 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Compatibility layer for calling directly into a TensorFlow kernel via TFRT, +// bypassing the existing TensorFlow runtime. This file defines: +// +// TFRTOpKernel +// TFRTOpKernelConstruction +// TFRTOpKernelContext +// +// Note that these are standalone objects that do not share a base class with +// TF's corresponding OpKernel, OpKernelConstruction, and OpKernelContext types. +// There is no common base class to avoid virtual call overhead. Kernels that +// support these fallback types must be templated: see +// core/kernels/aggregate_ops.h for an example. 
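To make the templating requirement above concrete, here is a minimal, hypothetical sketch (not part of the vendored header): a kernel written once against interchangeable kernel/construction/context template parameters, so the same body can be instantiated with TF's regular OpKernel types or with the TFRT* fallback types declared in this header.

// Illustrative sketch only; ExampleIdentityOp is a hypothetical name and the
// code is assumed to live inside namespace tensorflow (or be fully qualified).
template <typename OpKernelT, typename OpKernelConstructionT,
          typename OpKernelContextT>
class ExampleIdentityOp : public OpKernelT {
 public:
  explicit ExampleIdentityOp(OpKernelConstructionT* ctx) : OpKernelT(ctx) {}

  void Compute(OpKernelContextT* ctx) override {
    // Both context flavors expose input(int) and set_output(int, Tensor), so
    // forwarding input 0 to output 0 works unchanged for either runtime.
    ctx->set_output(0, ctx->input(0));
  }
};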
+ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_ + +#include +#include +#include +#include + +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ManagedStatic.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/runtime_fallback/kernel/attr_util.h" +#include "tensorflow/core/runtime_fallback/util/attr_util.h" +#include "tfrt/common/compat/eigen/thread_pool_device.h" // from @tf_runtime +#include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime + +namespace tfrt { +class AsyncKernelFrame; +} // namespace tfrt + +namespace tensorflow { + +class TFRTOpKernel; +class TFRTOpMeta; +class Tensor; +class TensorShape; + +////////////////////////////////////////////////////////////////////// +// OpKernel interface. +////////////////////////////////////////////////////////////////////// +class TFRTOpKernelConstruction { + public: + explicit TFRTOpKernelConstruction(const tfrt::OpAttrsRef& attributes); + + template + absl::Status GetAttr(absl::string_view attr_name, T* value) const; + + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); + + absl::Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs) { + // TODO(annarev): Move MatchSignatureHelper out of op_kernel.h + // and call it here. + return absl::OkStatus(); + } + + const std::optional& error(); + + private: + const tfrt::OpAttrsRef& attributes_; + // If an error occurs, the error message is stored here. + std::optional error_; +}; + +template <> +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + std::string* value) const; + +template <> +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + DataType* value) const; + +template <> +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + Padding* value) const; + +template <> +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + std::vector* value) const; + +absl::Status MissingAttributeError(absl::string_view attr_name); + +template +absl::Status TFRTOpKernelConstruction::GetAttr(absl::string_view attr_name, + T* value) const { + bool success = attributes_.Get( + llvm::StringRef(attr_name.data(), attr_name.size()), value); + if (!success) { + return MissingAttributeError(attr_name); + } + return absl::OkStatus(); +} + +// An implementation of OpKernelContext that fetches inputs from a +// tfrt::AsyncKernelFrame. Outputs and errors are stored internally. +class TFRTOpKernelContext { + public: + explicit TFRTOpKernelContext( + llvm::ArrayRef> inputs, + int num_outputs, const TFRTOpMeta* op_meta, tfrt::HostContext* host); + const Tensor& output(int index); + const std::optional& error(); + + // OpKernelContext interface implementation. 
+ bool ValidateInputsAreSameShape(TFRTOpKernel* op); + const Tensor& input(int index); + int num_inputs() const; + void set_output(int index, const Tensor& tensor); + int num_outputs() const; + bool forward_input_to_output_with_shape(int input_index, int output_index, + const TensorShape& output_shape, + Tensor** output) { + return false; + } + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + absl::Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor); + DataType expected_output_dtype(int i) const; + + template + const EigenDeviceType& eigen_device() const; + + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); + + private: + llvm::ArrayRef> inputs_; + const TFRTOpMeta* op_meta_; + + // The kernel's outputs are kept here. We can't directly store outputs in the + // AsyncKernelFrame because we must temporarily store allocate_output's Tensor + // somewhere until the Tensor is initialized. If we stored the allocated + // Tensor directly in the AsyncKernelFrame, the frame's output becomes + // available and downstream kernels may use the allocated (but uninitialized) + // Tensor. + std::vector outputs_; + + // If an error occurs, the error message is stored here. + std::optional error_; + + tfrt::compat::EigenHostContext eigen_host_context_; +}; + +class TFRTOpKernel { + public: + explicit TFRTOpKernel(TFRTOpKernelConstruction* context) {} + virtual ~TFRTOpKernel() = default; + virtual void Compute(TFRTOpKernelContext* context) = 0; +}; + +inline void CheckNotInComputeAsync(TFRTOpKernelConstruction*, const char*) {} +inline void CheckNotInComputeAsync(TFRTOpKernelContext*, const char*) {} + +////////////////////////////////////////////////////////////////////// +// Forwarding op metadata. +////////////////////////////////////////////////////////////////////// + +// Op metadata. For now TFRTOpMeta only stores the op's output types. +class TFRTOpMeta { + public: + explicit TFRTOpMeta(std::vector output_types); + DataType output_type(int index) const; + + private: + std::vector output_types_; +}; + +// Construct a TFRTOpMeta from .Input(), .Output(), and .Attr() +// specifications. This supports the same syntax as TF's REGISTER_OP macro, but +// this implementation only supports a subset of the full language. +// +// Currently, this only supports single-tensor outputs with fixed type. +// TODO(lauj) Support attribute outputs and compound attribute types as used by +// AddN. +class TFRTOpMetaBuilder { + public: + explicit TFRTOpMetaBuilder(absl::string_view op_name); + TFRTOpMetaBuilder& Output(absl::string_view output_spec); + TFRTOpMetaBuilder& Input(absl::string_view input_spec); + TFRTOpMetaBuilder& Attr(absl::string_view attr_spec); + + const string& op_name() const; + TFRTOpMeta BuildMeta() const; + + private: + string op_name_; + std::vector output_types_; +}; + +// Map from op name to TFRTOpMeta. +class TFRTOpMetaMap { + public: + TFRTOpMetaMap(); + void RegisterOpMeta(const TFRTOpMetaBuilder& op_builder); + + // Returns nullptr if there is no metadata for op_name. + const TFRTOpMeta* GetOpMeta(absl::string_view op_name) const; + + private: + llvm::StringMap op_metas_; +}; + +extern llvm::ManagedStatic tfrt_forwarding_op_meta_map; + +// Implementation detail for REGISTER_KERNEL_FALLBACK_OP. 
This helps with +// evaluating the .Input()/.Output()/.Attr() clauses in the REGISTER_OP syntax +// before calling BuildMeta(). +class TFRTOpRegisterer { + public: + TFRTOpRegisterer( // NOLINT(google-explicit-constructor) + const TFRTOpMetaBuilder& op_builder); +}; + +#define REGISTER_KERNEL_FALLBACK_OP(name) \ + REGISTER_KERNEL_FALLBACK_OP_UNIQ_HELPER(__COUNTER__, name) + +#define REGISTER_KERNEL_FALLBACK_OP_UNIQ_HELPER(ctr, name) \ + REGISTER_KERNEL_FALLBACK_OP_UNIQ(ctr, name) + +#define REGISTER_KERNEL_FALLBACK_OP_UNIQ(ctr, name) \ + static TFRTOpRegisterer global_tfrt_forwarding_op_meta_builder_##ctr##_ = \ + TFRTOpMetaBuilder(name) + +////////////////////////////////////////////////////////////////////// +// Forwarding kernel registration. +////////////////////////////////////////////////////////////////////// + +// Represents Kernel Fallback kernel registration information. +struct TFRTOpKernelReg { + using CallbackT = + std::unique_ptr (*)(TFRTOpKernelConstruction*); + + explicit TFRTOpKernelReg(CallbackT callback) : callback(callback) {} + + // Callback that creates a kernel. + CallbackT callback; + // Map from attribute names to type it must match. + // For e.g. foo: DT_FLOAT indicates that foo attribute + // must be a tfdtype attribute with type float. + llvm::StringMap type_constraints; +}; + +class TFRTOpKernelFactories { + public: + TFRTOpKernelFactories(); + void RegisterFactory(absl::string_view kernel_class_name, + TFRTOpKernelReg kernel_info); + + // Creates a kernel with the given name and passes op_kernel_construction + // to kernel constructor. + // Returns the constructed kernel on success. + // In case of failure, returns a nullptr. Kernel creation can fail in one + // of the following cases: + // 1. Kernel with the given name is not found. + // 2. Attributes in op_kernel_construction don't match type constraints + // for any of the kernels with this name. + // Note that we consider a constraint to be "not matched" if attribute + // it applies to is not in op_kernel_construction. + std::unique_ptr CreateKernel( + absl::string_view kernel_class_name, + TFRTOpKernelConstruction* op_kernel_construction) const; + + private: + llvm::StringMap> factories_; +}; + +// TODO(lauj) Should we move these kernel registrations to tfrt::KernelRegistry? +extern llvm::ManagedStatic + tfrt_forwarding_kernel_factories; + +#define REGISTER_KERNEL_FALLBACK_KERNEL(name, ...) \ + REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ_HELPER(__COUNTER__, name, __VA_ARGS__) + +#define REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ_HELPER(ctr, name, ...) \ + REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ(ctr, name, __VA_ARGS__) + +#define REGISTER_KERNEL_FALLBACK_KERNEL_UNIQ(ctr, name, ...) \ + static bool global_tfrt_forwarding_kernel_##ctr##_registered_ = []() { \ + ::tensorflow::tfrt_forwarding_kernel_factories->RegisterFactory( \ + name, TFRTOpKernelReg([](TFRTOpKernelConstruction* construction) \ + -> std::unique_ptr { \ + return std::make_unique<__VA_ARGS__>(construction); \ + })); \ + return true; \ + }(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_KERNEL_TFRT_OP_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/conversion_function.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/conversion_function.h new file mode 100644 index 00000000..d8537e6c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/conversion_function.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
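Usage illustration only (not part of the vendored headers): how the registration macros defined above are typically exercised, reusing the hypothetical ExampleIdentityOp sketched earlier.

// Declare the op's metadata (REGISTER_OP-like syntax), then bind a kernel
// factory to it. "ExampleIdentity" and ExampleIdentityOp are hypothetical.
REGISTER_KERNEL_FALLBACK_OP("ExampleIdentity").Output("out: float");

REGISTER_KERNEL_FALLBACK_KERNEL(
    "ExampleIdentity",
    ExampleIdentityOp<TFRTOpKernel, TFRTOpKernelConstruction,
                      TFRTOpKernelContext>);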
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements TFRuntimeFallback tensor conversion function for +// converting to host tensor. + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_CONVERSION_FUNCTION_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_CONVERSION_FUNCTION_H_ + +#include "tfrt/support/forward_decls.h" // from @tf_runtime + +namespace tfrt { + +class TensorConversionFnRegistry; +class CpuDevice; +class ExecutionContext; +class DenseHostTensor; +} + +namespace tensorflow { +namespace tfd { +class RuntimeFallbackTensor; + +tfrt::Expected +ConvertRuntimeFallbackTensorToDenseHostTensor( + const RuntimeFallbackTensor &tensor, const tfrt::CpuDevice &src, + const tfrt::CpuDevice &dst, const tfrt::ExecutionContext &exec_ctx); + +// Register conversion functions for TFRuntimeFallbackTensors. +void RegisterTFRuntimeFallbackTensorToHostConversionFn( + tfrt::TensorConversionFnRegistry* registry); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_CONVERSION_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h new file mode 100644 index 00000000..ef45282a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h @@ -0,0 +1,285 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_FALLBACK_BATCH_KERNEL_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_FALLBACK_BATCH_KERNEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/batch_kernels.h" +#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_resource_base.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" +#include "tensorflow/core/kernels/batching_util/warmup.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/random.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tfrt/host_context/resource_context.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +class BatchFunctionFallbackKernelBase : public AsyncOpKernel { + public: + explicit BatchFunctionFallbackKernelBase(OpKernelConstruction* c); + + protected: + // Validates 'allowed_batch_sizes_'. The entries must increase monotonically, + // and the last one must equal 'max_batch_size_'. + absl::Status ValidateAllowedBatchSizes() const; + + // Initialize vars by reading from op-kernel-construction. + // Vars + // - enable_adaptive_batch_threads_ + // true if value of attribute `kEnableAdaptiveSchedulerAttr` is true, or + // if `num_batch_threads` is not positive. + // - adaptive_batch_scheduler_options_ + // Read from corresponding attributes as long as they are set. + void SetAdaptiveBatchSchedulerOptions(OpKernelConstruction* c, + int32_t num_batch_threads); + + static int32 NumBatchThreadsFromEnvironmentWithDefault( + int default_num_batch_threads); + static thread::ThreadPool* GetOrCreateBatchThreadsPool(); + static constexpr int64_t kBatchThreadPoolSize = 128; + + std::string container_; + std::string shared_name_; + std::string batcher_queue_; + int32_t num_batch_threads_; + int32_t max_batch_size_; + int32_t batch_timeout_micros_; + int32_t max_enqueued_batches_; + std::vector allowed_batch_sizes_; + int32 low_priority_max_batch_size_; + int32 low_priority_batch_timeout_micros_; + int32 low_priority_max_enqueued_batches_; + std::vector low_priority_allowed_batch_sizes_; + std::string mixed_priority_policy_; + bool enable_large_batch_splitting_; + bool has_attribute_enable_large_batch_splitting_; + bool disable_padding_; + std::string batch_padding_policy_; + + // Parameters for adaptive batch scheduler only. + // Note 'num_batch_threads_' above is shared by two implementations of batch + // scheduler. + // Per-model inflight batches parameters. 
+ static constexpr int64_t kMinInflightBatches = 16; + static constexpr int64_t kInitialInflightBatches = 16; + static constexpr int64_t kBatchesToAverageOver = 10; + static constexpr int64_t kMaxInflightBatches = 64; + bool enable_adaptive_batch_threads_ = false; + struct AdaptiveBatchSchedulerOptions { + int32 min_in_flight_batches_limit = kMinInflightBatches; + int32 initial_in_flight_batches_limit = kInitialInflightBatches; + int32 max_in_flight_batches_limit = kMaxInflightBatches; + int32 batches_to_average_over = kBatchesToAverageOver; + }; + std::optional + adaptive_batch_scheduler_options_ = std::nullopt; +}; + +// Legacy TF kernel which is a variant of tf.BatchFunction. +template +class BatchFunctionFallbackKernel : public BatchFunctionFallbackKernelBase { + public: + using BatchFunctionType = typename BatchResourceType::BatchFunctionType; + + explicit BatchFunctionFallbackKernel(OpKernelConstruction* c) + : BatchFunctionFallbackKernelBase(c) { + int64_t handle; + OP_REQUIRES_OK(c, c->GetAttr("opaque_function_handle", &handle)); + batch_function_ = BatchResourceType::CastHandleToFunction(handle); + } + + void ComputeAsync(OpKernelContext* c, DoneCallback done) final; + + private: + BatchFunctionType batch_function_; +}; + +template +void BatchFunctionFallbackKernel::ComputeAsync( + OpKernelContext* c, DoneCallback done) { + RecordBatchSplitUsage(has_attribute_enable_large_batch_splitting_ + ? std::make_optional(enable_large_batch_splitting_) + : std::nullopt, + GetModelName(c)); + RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c)); + OP_REQUIRES_VALUE(tfrt::ResourceContext * client_graph_resource_context, c, + BatchResourceType::GetClientGraphResourceContext(c)); + OP_REQUIRES_ASYNC( + c, client_graph_resource_context != nullptr, + errors::FailedPrecondition("client graph resource context not found"), + done); + std::function< + absl::StatusOr>()> + creator; + if (adaptive_batch_scheduler_options_ != std::nullopt) { + creator = [this, c]() + -> absl::StatusOr> { + serving::AdaptiveSharedBatchScheduler< + serving::BatchResourceBase::BatchTask>::Options + adaptive_shared_batch_scheduler_options; + adaptive_shared_batch_scheduler_options.thread_pool_name = + "adaptive_batch_threads"; + adaptive_shared_batch_scheduler_options.num_batch_threads = + adaptive_batch_scheduler_options_->max_in_flight_batches_limit; + adaptive_shared_batch_scheduler_options.thread_pool = + GetOrCreateBatchThreadsPool(); + + // When we explicitly specify 'thread_pool', you'd think ASBS would ignore + // 'num_batch_threads', but in fact ASBS still uses num_batch_threads as + // the max number of in-flight batches. It makes no sense to have more + // in-flight batches than threads (it would result in strictly bad + // batching decisions), so we cap this parameter (which otherwise comes + // from the saved model) to the actual number of batch threads (which + // comes from a process-wide environment variable). + // + // We have to apply the same capping to min_ and initial_ + // in_flight_batches_limit below to produce valid configurations. + adaptive_shared_batch_scheduler_options.num_batch_threads = std::min( + NumBatchThreadsFromEnvironmentWithDefault(kBatchThreadPoolSize), + adaptive_batch_scheduler_options_->max_in_flight_batches_limit); + + // adaptive_shared_batch_scheduler_options.full_batch_scheduling_boost_micros + // is 0 (default value) intentionally, so tasks are scheduled in a FIFO + // way. 
+ // Two rationales to use default value (zero) for + // `full_batch_scheduling_boost_micros` + // 1) In this way, tasks scheduling policy is FIFO. Compared with round + // robin (what shared batch scheduler does), FIFO ensures that model + // with low QPS (i.e., models enqueue fewer tasks in the shared queue) + // will be processed timely. + // 2) If set, `full_batch_scheduling_boost_micros` should be of order + // the batch processing latency (which varies on a model basis). + // If a non-zero value is not set properly, it harms tail latency. + adaptive_shared_batch_scheduler_options.min_in_flight_batches_limit = + std::min( + NumBatchThreadsFromEnvironmentWithDefault(kBatchThreadPoolSize), + adaptive_batch_scheduler_options_->min_in_flight_batches_limit); + adaptive_shared_batch_scheduler_options + .initial_in_flight_batches_limit = std::min( + NumBatchThreadsFromEnvironmentWithDefault(kBatchThreadPoolSize), + adaptive_batch_scheduler_options_->initial_in_flight_batches_limit); + adaptive_shared_batch_scheduler_options.batches_to_average_over = + adaptive_batch_scheduler_options_->batches_to_average_over; + adaptive_shared_batch_scheduler_options.fifo_scheduling = true; + + std::unique_ptr new_resource; + auto status = BatchResourceType::Create( + c, adaptive_shared_batch_scheduler_options, max_batch_size_, + batch_timeout_micros_, max_enqueued_batches_, allowed_batch_sizes_, + batch_function_, disable_padding_, &new_resource); + if (!status.ok()) return status; + if (c->session_metadata() != nullptr) { + new_resource->set_session_metadata(*c->session_metadata()); + } + return tensorflow::core::RefCountPtr( + new_resource.release()); + }; + } else { + creator = [this, c]() + -> absl::StatusOr> { + serving::BatchResourceOptions batch_resource_options; + TF_ASSIGN_OR_RETURN( + batch_resource_options.mixed_priority_batching_policy, + serving::GetMixedPriorityBatchingPolicy(mixed_priority_policy_)); + batch_resource_options.num_batch_threads = num_batch_threads_; + batch_resource_options.max_batch_size = max_batch_size_; + batch_resource_options.batch_timeout_micros = batch_timeout_micros_; + batch_resource_options.max_enqueued_batches = max_enqueued_batches_; + batch_resource_options.allowed_batch_sizes = allowed_batch_sizes_; + batch_resource_options.batch_padding_policy = batch_padding_policy_; + batch_resource_options.low_priority_max_batch_size = + low_priority_max_batch_size_; + batch_resource_options.low_priority_batch_timeout_micros = + low_priority_batch_timeout_micros_; + batch_resource_options.low_priority_max_enqueued_batches = + low_priority_max_enqueued_batches_; + batch_resource_options.low_priority_allowed_batch_sizes = + low_priority_allowed_batch_sizes_; + + serving::ModelBatchStats& model_batch_stats = + serving::GlobalBatchStatsRegistry().model( + /* model_name= */ std::string(GetModelName(c)), + /* op_name= */ c->op_kernel().name()); + model_batch_stats.SetBatchTimeoutMicros(batch_timeout_micros_); + model_batch_stats.SetNumBatchThreads(num_batch_threads_); + + std::unique_ptr new_resource; + auto status = BatchResourceType::Create( + c, batch_resource_options, batch_function_, + enable_large_batch_splitting_, disable_padding_, &new_resource); + if (!status.ok()) return status; + if (c->session_metadata() != nullptr) { + new_resource->set_session_metadata(*c->session_metadata()); + } + return tensorflow::core::RefCountPtr( + new_resource.release()); + }; + } + + auto br = client_graph_resource_context->GetOrCreateResource< + tensorflow::core::RefCountPtr>(shared_name_, 
creator); + if (!br.ok()) OP_REQUIRES_OK_ASYNC(c, br.status(), done); + auto expected_name = BatchResourceType::GetBatchFunctionName(batch_function_); + auto received_name = + BatchResourceType::GetBatchFunctionName((*br)->get()->batch_function()); + + // TODO(b/187173237): When we can guarantee only 1 copy of BEF function is + // generated for the batched function, we can assert the pointers are equal + OP_REQUIRES_ASYNC( + c, expected_name == received_name, + errors::InvalidArgument(absl::StrCat( + "Provided BEF function doesn't match with BatchResource. Expected:", + expected_name, " Received:", received_name)), + done); + const uint64_t guid = random::New64(); + auto create_batch_task_fn = [c]() { + return BatchResourceType::CreateBatchTask(c); + }; + absl::Status status; + if (serving::ShouldWarmupAllBatchSizes(c)) { + status = (*br)->get()->RegisterWarmupInputs(guid, c, batcher_queue_, + create_batch_task_fn, done); + } else { + status = (*br)->get()->RegisterInput(guid, c, batcher_queue_, + create_batch_task_fn, done); + } + OP_REQUIRES_OK_ASYNC(c, status, done); + // Assume br calls done, so nothing to do here. +} + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_FALLBACK_BATCH_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/kernel_utils.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/kernel_utils.h new file mode 100644 index 00000000..e4978b80 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/kernel_utils.h @@ -0,0 +1,161 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file declares kernel utils. 
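A small self-contained illustration (not part of the vendored header) of the device-name helper declared below; the full device name is only an example value.

#include <cstdio>

// Sketch: ConvertTfDeviceNameToTfrtDefault keeps only the trailing
// "<TYPE>:<ordinal>" suffix (the last five characters) of a full TF device
// name, so the same device can be looked up under either naming scheme.
void PrintSimplifiedDeviceName() {
  const char* full = "/job:localhost/replica:0/task:0/device:GPU:0";
  std::printf("%s\n", tensorflow::tfd::ConvertTfDeviceNameToTfrtDefault(full));
  // Prints "GPU:0".
}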
+ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_KERNEL_UTILS_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_KERNEL_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tfrt/core_runtime/core_runtime_op.h" // from @tf_runtime +#include "tfrt/dtype/dtype.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/support/error_util.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/tensor/tensor_shape.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +template +struct AutoReleaser { + void operator()(T* p) const { p->Release(); } +}; +template +using AutoReleasePtr = std::unique_ptr>; + +using OwnedEagerContext = AutoReleasePtr; +using OwnedEagerOperation = AutoReleasePtr; +using OwnedTensorHandle = AutoReleasePtr; +using OwnedAbstractTensorInterface = AutoReleasePtr; + +// Check if a TensorHandle physically resides on GPU. +inline bool IsGpuTensorHandle(const tensorflow::TensorHandle& handle) { + absl::Status dummy_status; + // BackingDeviceName is where the tensor is physically located, not where the + // op that produces the tensor is. + // Note that dummy_status is never set in TensorHandle::BackingDeviceName. + absl::string_view device_name = handle.BackingDeviceName(&dummy_status); + return absl::StrContains(device_name, "GPU"); +} + +// TODO(zhangqiaorjc): Allowlist more dtypes as tfrt GPU supports more. +// RuntimeFallbackTensor of supported dtypes below will be eagerly converted to +// tfrt::DenseGpuTensor after each RuntimeFallbackOpHandler::Execute. +inline bool IsSupportedByTFRTGpu(DataType dtype) { + switch (dtype) { + default: + return false; + case DataType::DT_FLOAT: + case DataType::DT_DOUBLE: + case DataType::DT_INT32: + return true; + } +} + +// TODO(b/165872892): Remove this method. +// This method is needed because we use different device name in TF-TFRT +// integration and mlir test. In TF-TFRT integration, we reuse the device full +// name (e.g. /job:localhost/replica:0/task:0/device:GPU:0) from TF. But in mlir +// test, we use simplified device name "GPU:0". And lot of things in fallback +// need to be used in both cases. As a result, we need to look up the device +// with both device names. +inline const char* ConvertTfDeviceNameToTfrtDefault(const char* device_name) { + assert(strlen(device_name) >= 5); + return &device_name[strlen(device_name) - 5]; +} + +// Create and initialize EagerContext. +tfrt::Expected InitEagerContext(); + +tfrt::Expected InitEagerContext( + DynamicDeviceMgr* device_mgr, const SessionOptions& session_opts, + ContextDevicePlacementPolicy default_device_placement_policy, + bool is_async); + +// Obtain EagerContext from ExecutionContext. +tfrt::Expected GetEagerContext(tfrt::ExecutionContext exec_ctx); + +// Return the CoreRuntimeOp for `op_name` using fallback op_handler. 
+llvm::Expected GetFallbackOp(tfrt::string_view op_name, + tfrt::HostContext* host); + +constexpr char kEagerContextResourceName[] = "EagerContextResourceName"; + +class EagerContextResource { + public: + explicit EagerContextResource() + : device_mgr_(std::make_unique()), + ctx_{InitEagerContext( + device_mgr_.get(), tensorflow::SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + /*is_async=*/false)} {} + explicit EagerContextResource( + const SessionOptions& session_opts, + ContextDevicePlacementPolicy default_device_placement_policy, + bool is_async) + : device_mgr_(std::make_unique()), + ctx_{InitEagerContext(device_mgr_.get(), session_opts, + default_device_placement_policy, is_async)} {} + + tfrt::Expected GetTFEagerContext() { + if (!ctx_) return ctx_.takeError(); + return ctx_.get().get(); + } + + DynamicDeviceMgr* GetDeviceMgr() { return device_mgr_.get(); } + + llvm::Error AddDevices(std::vector> devices) { + if (!ctx_) return ctx_.takeError(); + absl::Status s = dynamic_cast( + ctx_.get()->local_device_mgr()) + ->AddDevices(std::move(devices)); + if (!s.ok()) return tfrt::MakeStringError(s.message()); + ctx_.get()->InitPrioritizedDeviceTypeList(); + ctx_.get()->pflr()->InitializeDeviceAndFlr(); + return llvm::Error::success(); + } + + private: + // EagerContext uses this device_mgr_ as local_device_mgr. We manage the + // device_mgr_ here to allow TFRT to add new devices after EagerContext + // initialization. + // Today, TFRT only adds TPU devices after EagerContext initialization. + std::unique_ptr device_mgr_; + + tfrt::Expected ctx_; +}; + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_KERNEL_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/op_logger.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/op_logger.h new file mode 100644 index 00000000..c920715d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/op_logger.h @@ -0,0 +1,64 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines a logger for op names.
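Illustrative sketch only (not part of the vendored header): recording an op name through the OpLogger shared context declared below, assuming TFRT's HostContext::GetOrCreateSharedContext<T>() accessor is available.

// `host` is assumed to be a live tfrt::HostContext*; the logger is created on
// first use and shared for the lifetime of the host context.
void LogFallbackOp(tfrt::HostContext* host, tfrt::string_view op_name) {
  auto& logger =
      host->GetOrCreateSharedContext<tensorflow::tfd::OpLogger>();
  logger.LogOp(op_name);
}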
+ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_OP_LOGGER_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_OP_LOGGER_H_ + +#include +#include + +#include "absl/memory/memory.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "tfrt/host_context/shared_context.h" // from @tf_runtime +#include "tfrt/support/concurrent_vector.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime + +namespace tfrt { +class HostContext; +} + +namespace tensorflow { +namespace tfd { + +class OpLogger : public tfrt::SharedContext { + public: + explicit OpLogger(tfrt::HostContext* host) + : op_names_(std::make_unique>(8)) {} + + void LogOp(tfrt::string_view op_name) { + op_names_->emplace_back(op_name.str()); + } + + tfrt::ArrayRef GetLoggedOps() const { + absl::Span span = op_names_->ToConstSpan(); + return tfrt::ArrayRef(span.data(), span.size()); + } + + // Cannot be called concurrently with any API in this class. + void Clear() { + op_names_ = std::make_unique>(8); + } + + private: + std::unique_ptr> op_names_; +}; + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_OP_LOGGER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h new file mode 100644 index 00000000..833b92f7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_kernels.h @@ -0,0 +1,57 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file declares kernels for running TFRT ops/kernels via TF runtime +// fallback. + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_KERNELS_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_KERNELS_H_ + +#include + +#include "llvm/Support/Error.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" +#include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime +#include "tfrt/host_context/async_value.h" // from @tf_runtime +#include "tfrt/host_context/chain.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/shared_context.h" // from @tf_runtime +#include "tfrt/tensor/tensor.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +// Create an EagerOperation to run the op, taking tensorflow::TensorHandle and +// returning tensorflow::AbstractTensorHandle*. +absl::Status CallEagerExecute( + const tfrt::ExecutionContext& exec_ctx, EagerContext* eager_ctx, + const char* op_name, const char* device_name, + llvm::ArrayRef input_tensor_handles, + const tfrt::OpAttrsRef& attrs, + llvm::MutableArrayRef + result_tensor_handles); + +// Take and return RuntimeFallbackTensors. 
+tfrt::AsyncValueRef RuntimeFallbackExecute( + const tfrt::ExecutionContext& exec_ctx, const char* op_name, + const char* device_name, tfrt::ArrayRef arguments, + const tfrt::OpAttrsRef& attrs, + tfrt::MutableArrayRef> results); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h new file mode 100644 index 00000000..54d404fe --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_op_handler.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file declares RuntimeFallbackOpHandler, responsible for running TFRT ops +// on Tensorflow. + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_OP_HANDLER_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_OP_HANDLER_H_ + +#include + +#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime +#include "tfrt/core_runtime/op_handler.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +llvm::Expected CreateRuntimeFallbackOpHandler( + tfrt::CoreRuntime* runtime, tfrt::string_view tf_device_name); +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_OP_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h new file mode 100644 index 00000000..53c6ab69 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/runtime/runtime_fallback_tensor.h @@ -0,0 +1,80 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file declares TF runtime fallback tensor. 
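For orientation, a hedged sketch (not part of the vendored headers) of wiring the CreateRuntimeFallbackOpHandler factory declared just above into a CoreRuntime. It assumes the Expected payload is a tfrt::OpHandler* and that CoreRuntime::RegisterOpHandler is available, as in upstream TensorFlow/TFRT; the handler name "tf_fallback" is an arbitrary example.

// Sketch: create the fallback handler for the CPU device and register it under
// a name so TFRT ops can be dispatched through the TF runtime fallback.
llvm::Error RegisterTfFallbackHandler(tfrt::CoreRuntime* corert) {
  auto handler = tensorflow::tfd::CreateRuntimeFallbackOpHandler(
      corert, "/job:localhost/replica:0/task:0/device:CPU:0");
  if (!handler) return handler.takeError();
  corert->RegisterOpHandler("tf_fallback", handler.get());
  return llvm::Error::success();
}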
+ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_ + +#include "llvm/ADT/STLExtras.h" +#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h" +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime +#include "tfrt/tensor/host_tensor.h" // from @tf_runtime +#include "tfrt/tensor/string_host_tensor.h" // from @tf_runtime +#include "tfrt/tensor/tensor.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +class RuntimeFallbackTensor final + : public tfrt::Tensor, + public tfrt::TensorTraits { + public: + explicit RuntimeFallbackTensor(const tfrt::TensorShape& shape, + tfrt::DType dtype, OwnedTensorHandle th); + + void Print(tfrt::raw_ostream& os) const override; + + // Note that this method does not add ref to the return tensor_handle. + TensorHandle* GetTensorHandle() const { return tensor_handle_.get(); } + + // Tensor type name for RuntimeFallbackTensor. + static const char* name() { return "RuntimeFallback"; } + + private: + template + static void PrintTensorValues(void* data, ssize_t size, + llvm::raw_ostream& os) { + llvm::ArrayRef elements = llvm::ArrayRef(static_cast(data), size); + llvm::interleaveComma(elements, os); + } + + OwnedTensorHandle tensor_handle_; +}; + +llvm::SmallVector GetShape( + AbstractTensorInterface* tensor_interface); + +tfrt::Expected CopyTfStringTensorToStringHostTensor( + AbstractTensorInterface* tensor_interface, tfrt::HostContext* host); + +tfrt::Expected +CreateRuntimeFallbackTensorFromTfTensorHandle(OwnedTensorHandle owned_th, + tfrt::HostContext* host); + +RuntimeFallbackTensor MoveDHTToRuntimeFallbackTensor( + tfrt::DenseHostTensor&& dht, tfrt::HostContext* host); + +RuntimeFallbackTensor CopyRefDHTToRuntimeFallbackTensor( + const tfrt::DenseHostTensor& dht, tfrt::HostContext* host); + +RuntimeFallbackTensor CopySHTToRuntimeFallbackTensor( + const tfrt::StringHostTensor& sht, tfrt::HostContext* host); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_RUNTIME_RUNTIME_FALLBACK_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/test/coreruntime_driver.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/test/coreruntime_driver.h new file mode 100644 index 00000000..00ea4b0e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/test/coreruntime_driver.h @@ -0,0 +1,79 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_TEST_CORERUNTIME_DRIVER_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_TEST_CORERUNTIME_DRIVER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime +#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime +#include "tfrt/host_context/chain.h" // from @tf_runtime +#include "tfrt/host_context/location.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime + +namespace tfrt { + +class OpHandle; +class OpHandler; +class OpAttrsRef; +class TensorHandle; + +class CoreRuntimeDriver final : public tfrt::LocationHandler { + public: + explicit CoreRuntimeDriver(); + + void Execute(string_view op_name, + tfrt::MutableArrayRef args, + const tfrt::OpAttrsRef& attrs, + tfrt::MutableArrayRef results, + tfrt::string_view filename, int line); + + ExecutionContext CreateExecutionContext(tfrt::string_view filename, int line); + + void InitializeCpuRuntimeFallbackOpHandler(); + + void InitializeGpuRuntimeFallbackOpHandler(int gpu_ordinal); + + void InitializeCpuKernelFallbackOpHandler(); + + HostContext* GetHost() const; + + CoreRuntimeOp MakeOp(string_view op_name); + + void WaitForHostContextQuiesce(); + + DecodedLocation DecodeLocation(Location loc) const override; + + private: + explicit CoreRuntimeDriver(std::unique_ptr corert); + + std::unique_ptr corert_; + tfrt::OpHandler* op_handler_; + tfrt::AsyncValueRef chain_; + tfrt::ResourceContext resource_context_; + + // `location_map_` is a map from (filename, line) to the opaque location data, + // which is the index in `locations_`. + absl::flat_hash_map, int> location_map_; + std::vector> locations_; +}; + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_TEST_CORERUNTIME_DRIVER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/attr_util.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/attr_util.h new file mode 100644 index 00000000..2bb7f137 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/attr_util.h @@ -0,0 +1,99 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_ATTR_UTIL_H_
+#define TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_ATTR_UTIL_H_
+
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "llvm/ADT/StringRef.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/status.h"
+#include "tfrt/bef/bef_encoding.h"  // from @tf_runtime
+#include "tfrt/core_runtime/op_attr_type.h"  // from @tf_runtime
+#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
+#include "tfrt/host_context/host_context.h"  // from @tf_runtime
+#include "tfrt/host_context/kernel_utils.h"  // from @tf_runtime
+#include "tfrt/support/forward_decls.h"  // from @tf_runtime
+
+namespace tensorflow {
+namespace tfd {
+
+// Converts a TFRT string_view to the Abseil version.
+inline absl::string_view ToAbslStringView(tfrt::string_view sv) {
+  return absl::string_view(sv.data(), sv.size());
+}
+
+// Parses the string representation of the DataType in `dtype` into `data_type`.
+// Aborts the program for unsupported dtypes.
+absl::Status ParseTfDataType(absl::string_view dtype, DataType* data_type);
+
+// The following 2 functions convert between Tensorflow DataTypes and
+// OpAttrTypes. The mapping between OpAttrType and DataType is defined in
+// attr_type.def. Aborts on unsupported types.
+DataType ConvertToTfDataType(tfrt::OpAttrType op_attr_type);
+tfrt::OpAttrType ConvertFromTfDataType(DataType data_type);
+
+// The following 2 functions convert between BEF attribute types and Tensorflow
+// DataTypes. Aborts on unsupported datatypes.
+DataType ConvertBefAttrTypeToTfDataType(tfrt::DType attr_type);
+tfrt::DType ConvertTfDataTypeToBefAttrType(DataType data_type);
+
+// Parses the tensor valued `attr_value` and constructs the tensor with its
+// contents in `tensor`. Returns OK status on success, INVALID_ARGUMENT on
+// failure.
+absl::Status ParseTensorAttrValue(absl::string_view attr_value,
+                                  tensorflow::Tensor* tensor);
+
+// Parses a string of the form "[1,2,3,...]" in `attr_value` and returns the
+// constituent dimension sizes (shape) in `shape_val`. Returns
+// INVALID_ARGUMENT on invalid input.
+absl::Status ParseTensorShapeAttrValue(absl::string_view attr_value,
+                                       std::vector<int64_t>* shape_val);
+
+// Parses a boolean from `attr_value` into `bool_val` and returns OK status on
+// success. Returns INVALID_ARGUMENT on invalid input.
+absl::Status ParseBoolAttrValue(absl::string_view attr_value, bool* bool_val);
+
+// Parses an int64_t from `attr_value` into `int_val` and returns OK status on
+// success. Returns INVALID_ARGUMENT on invalid input.
+absl::Status ParseIntAttrValue(absl::string_view attr_value, int64_t* int_val);
+
+inline std::vector<absl::string_view> AttrValueSplit(absl::string_view str) {
+  return absl::StrSplit(str, absl::MaxSplits('$', 1));
+}
+
+// Returns true if `attr_name` is an attribute that is not required by TFRT
+// (usually added by stages higher in the lowering process).
+bool IsUnusedAttribute(absl::string_view attr_name);
+
+// Fills in the passed in AttrValueMap `attr_value_map` with attributes from
+// `attrs`.
+llvm::Error FillAttrValueMap(const tfrt::OpAttrsRef& attrs,
+                             tfrt::HostContext* host,
+                             AttrValueMap* attr_value_map);
+
+// Fills in the passed in AttrValueMap `attr_value_map`.
+absl::Status SetUpAttrValueMap(tfrt::AggregateAttr op_attr_array, + tfrt::AggregateAttr op_func_attr_array, + tensorflow::AttrValueMap* attr_value_map); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_ATTR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/fallback_test_util.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/fallback_test_util.h new file mode 100644 index 00000000..cdfa6331 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/fallback_test_util.h @@ -0,0 +1,33 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_FALLBACK_TEST_UTIL_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_FALLBACK_TEST_UTIL_H_ + +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +tfrt::ExecutionContext CreateFallbackTestExecutionContext( + tfrt::HostContext* host, tfrt::ResourceContext* resource_context, + tensorflow::thread::ThreadPoolInterface* user_intra_op_threadpool = + nullptr); + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_FALLBACK_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/tensor_metadata.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/tensor_metadata.h new file mode 100644 index 00000000..f192ace8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/tensor_metadata.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TENSOR_METADATA_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TENSOR_METADATA_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/runtime_fallback/util/type_util.h" +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/tensor/tensor_metadata.h" // from @tf_runtime + +namespace tensorflow::tfd { + +// Retrieves TFRT TensorMetadata from a tensorflow::Tensor. 
+inline tfrt::TensorMetadata GetTensorMetadata( + const tensorflow::Tensor& tf_tensor) { + auto dtype = tfd::GetTfrtDtype(tf_tensor.dtype()); + auto dim_sizes = tf_tensor.shape().dim_sizes(); + static_assert(sizeof(tfrt::Index) == sizeof(dim_sizes.front()), + "Invalid dimension type size"); + auto shape = llvm::ArrayRef(reinterpret_cast(dim_sizes.data()), + dim_sizes.size()); + return tfrt::TensorMetadata(dtype, shape); +} + +} // namespace tensorflow::tfd + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TENSOR_METADATA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/tensor_util.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/tensor_util.h new file mode 100644 index 00000000..f974edf2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/tensor_util.h @@ -0,0 +1,68 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TENSOR_UTIL_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TENSOR_UTIL_H_ + +#include +#include + +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/runtime_fallback/util/tensor_metadata.h" // IWYU pragma: export +#include "tfrt/dtype/dtype.h" // from @tf_runtime +#include "tfrt/host_context/host_buffer.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/tensor/string_host_tensor.h" // from @tf_runtime +#include "tfrt/tensor/tensor_shape.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +struct TFTensorDeleter { + void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); } +}; +using OwnedTFTensor = std::unique_ptr; + +// Moves one ref on HostBuffer to tensorflow::Tensor. +tensorflow::Tensor MoveHostBufferToTfTensor( + tfrt::RCReference host_buffer, tfrt::DType dtype, + const tfrt::TensorShape& shape); + +// Creates a tensorflow::Tensor based on StringHostTensor. +tensorflow::Tensor CopyShtToTfTensor(const tfrt::StringHostTensor& sht); + +// Converts tfrt shape to tensorflow shape. +inline tensorflow::TensorShape GetTfShape(const tfrt::TensorShape& shape) { + llvm::SmallVector dimensions; + shape.GetDimensions(&dimensions); + llvm::SmallVector dims(dimensions.begin(), dimensions.end()); + return tensorflow::TensorShape(dims); +} + +inline void CheckBoolCompatibility() { + // sizeof(bool) is implementation defined. The following may only work when + // sizeof(bool) is 1. + // + // TODO(tfrt-devs): It is still undefined behavior to directly cast char* + // between bool* and access the data. Consider allocating target objects and + // using memcpy instead. 
+ static_assert(sizeof(bool) == 1, "Only support when bool is 1 byte."); +} + +} // namespace tfd +} // namespace tensorflow +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TENSOR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/type_util.h b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/type_util.h new file mode 100644 index 00000000..32a859c7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/runtime_fallback/util/type_util.h @@ -0,0 +1,59 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TYPE_UTIL_H_ +#define TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TYPE_UTIL_H_ + +#include "llvm/Support/ErrorHandling.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tfrt/dtype/dtype.h" // from @tf_runtime + +namespace tensorflow { +namespace tfd { + +// Map tfrt::Dtype to TF_DataType. +inline DataType GetTfDataType(tfrt::DType dtype) { + switch (dtype) { + case tfrt::DType::Invalid: + case tfrt::DType::Unsupported: + case tfrt::DType::Resource: + DCHECK(false) << "invalid dtype"; + return DataType::DT_INVALID; +#define DTYPE(TFRT_ENUM, DT_ENUM) \ + case tfrt::DType::TFRT_ENUM: \ + return DataType::DT_ENUM; +#include "tensorflow/core/runtime_fallback/util/dtype.def" // NOLINT + } +} + +inline tfrt::DType GetTfrtDtype(DataType dtype) { + switch (dtype) { + default: + return tfrt::DType(tfrt::DType::Unsupported); + case DataType::DT_INVALID: + return tfrt::DType(); + case DataType::DT_RESOURCE: + return tfrt::DType(tfrt::DType::Resource); +#define DTYPE(TFRT_ENUM, DT_ENUM) \ + case DataType::DT_ENUM: \ + return tfrt::DType(tfrt::DType::TFRT_ENUM); +#include "tensorflow/core/runtime_fallback/util/dtype.def" // NOLINT + } +} + +} // namespace tfd +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_TYPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/summary/schema.h b/third_party/tflite-hdrs/tensorflow/core/summary/schema.h new file mode 100644 index 00000000..dc13bbfb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/summary/schema.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_SUMMARY_SCHEMA_H_ +#define TENSORFLOW_CORE_SUMMARY_SCHEMA_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/db/sqlite.h" + +namespace tensorflow { + +constexpr uint32 kTensorboardSqliteApplicationId = 0xfeedabee; + +/// \brief Creates TensorBoard SQLite tables and indexes. +/// +/// If they are already created, this has no effect. If schema +/// migrations are necessary, they will be performed with logging. +absl::Status SetupTensorboardSqliteDb(Sqlite* db); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_SUMMARY_SCHEMA_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/summary/summary_converter.h b/third_party/tflite-hdrs/tensorflow/core/summary/summary_converter.h new file mode 100644 index 00000000..ab196692 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/summary/summary_converter.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_SUMMARY_SUMMARY_CONVERTER_H_ +#define TENSORFLOW_CORE_SUMMARY_SUMMARY_CONVERTER_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// TODO(jart): Delete these methods in favor of new Python implementation. +absl::Status AddTensorAsScalarToSummary(const Tensor& t, const string& tag, + Summary* s); +absl::Status AddTensorAsHistogramToSummary(const Tensor& t, const string& tag, + Summary* s); +absl::Status AddTensorAsImageToSummary(const Tensor& tensor, const string& tag, + int max_images, const Tensor& bad_color, + Summary* s); +absl::Status AddTensorAsAudioToSummary(const Tensor& tensor, const string& tag, + int max_outputs, float sample_rate, + Summary* s); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_SUMMARY_SUMMARY_CONVERTER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/summary/summary_db_writer.h b/third_party/tflite-hdrs/tensorflow/core/summary/summary_db_writer.h new file mode 100644 index 00000000..545f849e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/summary/summary_db_writer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_SUMMARY_SUMMARY_DB_WRITER_H_ +#define TENSORFLOW_CORE_SUMMARY_SUMMARY_DB_WRITER_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { + +/// \brief Creates SQLite SummaryWriterInterface. +/// +/// This can be used to write tensors from the execution graph directly +/// to a database. The schema must be created beforehand. Entries in +/// Users, Experiments, and Runs tables will be created automatically +/// if they don't already exist. +/// +/// Please note that the type signature of this function may change in +/// the future if support for other DBs is added to core. +/// +/// The result holds a new reference to db. +absl::Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, + const string& run_name, + const string& user_name, Env* env, + SummaryWriterInterface** result); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_SUMMARY_SUMMARY_DB_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/summary/summary_file_writer.h b/third_party/tflite-hdrs/tensorflow/core/summary/summary_file_writer.h new file mode 100644 index 00000000..847e7cb8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/summary/summary_file_writer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_SUMMARY_SUMMARY_FILE_WRITER_H_ +#define TENSORFLOW_CORE_SUMMARY_SUMMARY_FILE_WRITER_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +/// \brief Creates SummaryWriterInterface which writes to a file. +/// +/// The file is an append-only records file of tf.Event protos. That +/// makes this summary writer suitable for file systems like GCS. +/// +/// It will enqueue up to max_queue summaries, and flush at least every +/// flush_millis milliseconds. The summaries will be written to the +/// directory specified by logdir and with the filename suffixed by +/// filename_suffix. The caller owns a reference to result if the +/// returned status is ok. The Env object must not be destroyed until +/// after the returned writer. 
+absl::Status CreateSummaryFileWriter(int max_queue, int flush_millis, + const string& logdir, + const string& filename_suffix, Env* env, + SummaryWriterInterface** result); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_SUMMARY_SUMMARY_FILE_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/async_value_tensor.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/async_value_tensor.h new file mode 100644 index 00000000..06e99f8f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/async_value_tensor.h @@ -0,0 +1,72 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_COMMON_ASYNC_VALUE_TENSOR_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_ASYNC_VALUE_TENSOR_H_ + +#include +#include + +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/types.h" +#include "tfrt/support/forward_decls.h" // from @tf_runtime +#include "tfrt/support/ref_count.h" // from @tf_runtime + +namespace tensorflow { + +// The implementation of a Tensor for an AsyncValue and PjRtBuffer. We used it +// to integrate TF with TFRT. +// TODO(b/243983834) After the migration of using PjRt for data transfer is +// completed, GetAsyncRef and SetAsyncRef will be removed and this class will be +// renamed to PjRtBufferTensor. +class AsyncValueTensor { + public: + // Downcast from a Tensor to an AsyncValueTensor. Return nullptr if the + // downcast fails. + static AsyncValueTensor* FromTensor(const Tensor* tensor); + + const tfrt::RCReference& GetAsyncRef(); + + void SetAsyncRef(tfrt::RCReference av_ref); + + std::shared_ptr GetBuffer(); + + void SetBuffer(std::shared_ptr buffer); + + // Convert from a raw pointer to an AsyncValueTensor, removing the pointer + // tag. + static AsyncValueTensor* FromOpaquePointer(void* ptr); + + // Convert to a raw pointer from an AsyncValueTensor, adding the pointer tag. + static void* ToOpaquePointer(AsyncValueTensor* tensor); + + private: + tfrt::RCReference av_ref_; + std::shared_ptr buffer_; +}; + +class AsyncValueAllocator : public Allocator { + public: + void* AllocateRaw(size_t alignment, size_t num_bytes) override; + void DeallocateRaw(void* ptr) override; + + bool AllocatesOpaqueHandle() const override { return true; } + string Name() override { return "async-value"; } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_COMMON_ASYNC_VALUE_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/create_pjrt_client_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/create_pjrt_client_util.h new file mode 100644 index 00000000..fe8dfbb8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/create_pjrt_client_util.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_COMMON_CREATE_PJRT_CLIENT_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_CREATE_PJRT_CLIENT_UTIL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// Gets PJRT client from TFGlobalResourceManager. If it is not found, creates a +// PJRT client and adds it to TFGlobalResourceManager. Different `DeviceType` +// can choose to create the PJRT client explicitly (e.g. in ops) and add it to +// TFGlobalResourceManager, or create a PJRT client on the first use implicitly +// in this method. +// The inputs are the device_type of the caller, and an optional +// set of device IDs `allowed_devices` for which the stream executor will be +// created. `allowed_devices` is only used for GPU. +// TODO(b/260802979): consider passing `XlaPlatformInfo` for the options to +// create a client, or creating a class similar to `LocalClientOptions`. +// TODO(b/280111106): make PjrtClientFactoryOptions an input of +// GetOrCreatePjRtClient. +absl::StatusOr GetOrCreatePjRtClient( + const DeviceType& device_type, + std::optional> allowed_devices = std::nullopt); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_COMMON_CREATE_PJRT_CLIENT_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/global_state.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/global_state.h new file mode 100644 index 00000000..117a4365 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/global_state.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_COMMON_GLOBAL_STATE_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_GLOBAL_STATE_H_ + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tfrt/host_context/host_context.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_global { + +class GlobalHostContext { + public: + static void Set(::tfrt::HostContext* host_ctx); + static ::tfrt::HostContext* Get(); + + private: + static ::tfrt::HostContext* host_ctx_; +}; + +// A global resource manager in TF core framework. It can be used to store +// resources that are per host. 
+ResourceMgr* GetTFGlobalResourceMgr(); + +} // namespace tfrt_global +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_COMMON_GLOBAL_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/metrics.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/metrics.h new file mode 100644 index 00000000..e8486b68 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/metrics.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_COMMON_METRICS_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_METRICS_H_ + +#include +#include + +#include "xla/tsl/lib/monitoring/sampler.h" + +namespace tensorflow { +namespace tfrt_metrics { + +tsl::monitoring::SamplerCell* GetTfrtGraphExecutorLatencySampler( + const std::string& model_name, int64_t model_version, + const std::string& graph_name); + +tsl::monitoring::SamplerCell* GetTfrtDeviceExecutionLatency( + const std::string& model_name, int64_t model_version); + +} // namespace tfrt_metrics +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_COMMON_METRICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_client_factory_options.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_client_factory_options.h new file mode 100644 index 00000000..70e3092c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_client_factory_options.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_OPTIONS_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_OPTIONS_H_ + +#include +#include +#include + +namespace xla { +// PjrtClientFactoryOptions store arguments to create PJRT client. +// Caller is responsible to set option value for corresponding PJRT client +// factory. 
+struct PjrtClientFactoryOptions { + struct GpuClientCreateOptions { + bool asynchronous = false; + int node_id = 0; + std::optional> allowed_devices = std::nullopt; + std::optional platform_name = std::nullopt; + }; + + struct CpuClientCreateOptions { + bool asynchronous = false; + }; + GpuClientCreateOptions gpu_options; + CpuClientCreateOptions cpu_options; +}; +} // namespace xla + +#endif // TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_client_factory_registry.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_client_factory_registry.h new file mode 100644 index 00000000..2a04e9af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_client_factory_registry.h @@ -0,0 +1,73 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_REGISTRY_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_REGISTRY_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/tsl/framework/device_type.h" +#include "tensorflow/core/framework/registration/registration.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/tfrt/common/pjrt_client_factory_options.h" +#include "tsl/platform/thread_annotations.h" + +namespace xla { + +using PjrtClientFactory = + std::function>( + const PjrtClientFactoryOptions&)>; + +// The Pjrt client factory registry holds all the registered client factories. +class PjrtClientFactoryRegistry { + public: + explicit PjrtClientFactoryRegistry() = default; + + // Registers PjrtClientFactory with DeviceType as key. + tensorflow::InitOnStartupMarker RegisterPjrtClientFactory( + const tsl::DeviceType& device_type, + const PjrtClientFactory& client_factory); + + // Given the device type, finds related PjrtClientFactory function which takes + // factory option and returns PjrtClient if succeeds. + absl::StatusOr> GetPjrtClient( + const tsl::DeviceType& device_type, + const PjrtClientFactoryOptions& options); + + // Returns singleton instance of PjrtClientFactoryRegistry class. + static PjrtClientFactoryRegistry& Get(); + + private: + absl::flat_hash_map registry_ + TF_GUARDED_BY(mu_); + + mutable ::tensorflow::mutex mu_; +}; + +// The `REGISTER_PJRT_CLIENT_FACTORY()` calls RegisterPjrtClientFactory on +// program startup. 
+#define REGISTER_PJRT_CLIENT_FACTORY(pjrt_client, device_type, client_factory) \ + static ::tensorflow::InitOnStartupMarker const register_##pjrt_client = \ + ::tensorflow::InitOnStartupMarker{} \ + << ::xla::PjrtClientFactoryRegistry::Get().RegisterPjrtClientFactory( \ + tsl::DeviceType(device_type), client_factory) + +} // namespace xla + +#endif // TENSORFLOW_CORE_TFRT_COMMON_PJRT_CLIENT_FACTORY_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_state.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_state.h new file mode 100644 index 00000000..c3df6806 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_state.h @@ -0,0 +1,90 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_COMMON_PJRT_STATE_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_PJRT_STATE_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "xla/client/local_client.h" +#include "xla/pjrt/local_device_state.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/stream_executor/integrations/tf_allocator_adapter.h" +#include "xla/tsl/framework/allocator.h" +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +const char kPjRtStateResourceName[] = "pjrt_state"; +using PjRtClientsMap = std::map>; + +// Information needed to create a PjRt GPU Client which is used when creating +// a client after after information about remote devices is available. +struct PjRtGpuClientCreationInfo { + std::set allowed_devices; + std::unique_ptr allocator; + std::unique_ptr host_memory_allocator; + std::map> local_device_states; + xla::LocalClient* local_client; +}; + +// The class for the state related to PjRt. It contains a map from `DeviceType` +// to `PjRtClient`. It will be stored in the global `ResourceManager`. +class PjRtState : public ResourceBase { + public: + static PjRtState* Create(); + absl::StatusOr GetPjRtClient(const DeviceType& device_type); + absl::StatusOr GetOrCreatePjRtClient( + const DeviceType& device_type); + absl::Status SetPjRtClient(const DeviceType& device_type, + std::unique_ptr client); + // Moves PJRT client to `unused_`. The PJRT client moved to `unused_` will not + // be returned by `GetPjRtClient`. + absl::Status MovePjRtClientToUnused(const DeviceType& device_type); + string DebugString() const override; + + // Saves information needed to create a PJRT client (to enable creating a + // client with remote devices). + absl::Status SetPjRtGpuClientCreationInfo( + std::unique_ptr info); + + // Retrieves information needed to create a PJRT client (for creating a + // client with remote devices). 
+ PjRtGpuClientCreationInfo* GetPjRtGpuClientCreationInfo(); + + private: + explicit PjRtState() {} + absl::Mutex mu_; + PjRtClientsMap clients_ ABSL_GUARDED_BY(mu_); + // Store the PJRT clients that are no longer used to guarantee that PJRT + // clients outlive PJRT buffers. + std::vector> unused_ ABSL_GUARDED_BY(mu_); + + std::unique_ptr pjrt_gpu_client_creation_info_ + ABSL_GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_COMMON_PJRT_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_util.h new file mode 100644 index 00000000..aaba7ad9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/common/pjrt_util.h @@ -0,0 +1,45 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_COMMON_PJRT_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_COMMON_PJRT_UTIL_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/pjrt/pjrt_client.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/common/pjrt_state.h" + +namespace tensorflow { + +// Sets PJRT client for device_type in TFGlobalResourceManager. If a PJRT client +// for this device_type already exists, the existing PJRT client will not be +// destroyed, and will be kept alive in an "unused client" vector. PJRT API +// semantics require the PJRT client to outlive PJRT buffers. +absl::Status SetPjRtClientInTFGlobalResourceManager( + const DeviceType& device_type, std::unique_ptr client); + +// Gets (the most recent) PJRT client for device_type from +// TFGlobalResourceManager. +absl::StatusOr GetPjRtClient(const DeviceType& device_type); + +absl::Status SetPjRtGpuClientCreationInfoInTFGlobalResourceManager( + std::unique_ptr info); +absl::StatusOr GetPjRtGpuClientCreationInfo(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_COMMON_PJRT_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/cost_recorder.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/cost_recorder.h new file mode 100644 index 00000000..e1d1b7f4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/cost_recorder.h @@ -0,0 +1,69 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file defines a recorder for op cost measurement. + +#ifndef TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ +#define TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/thread_annotations.h" + +namespace tensorflow { +namespace tfrt_stub { + +// Thread-safe. +// Maintains the execution durations by `op_key`. Note that `op_key` is only +// unique within a model. +class CostRecorder { + public: + // Records an execution duration for the op keyed by `op_key`. + void RecordCost(int64_t op_key, uint64_t execution_time); + + // Returns the normalized average execution duration of the op keyed by + // `op_key`. If there is no record for `op_key`, returns the uint32_t::max to + // avoid stream merging. Note that we don't use uint64_t::max because + // otherwise adding op costs would cause overflow. + uint64_t GetCost(int64_t op_key) const; + + // Writes the op cost map (in format of `OpCostMapProto`) to a file specified + // by the env var name `MesuredCostPathEnvVarName()`. + // TODO(b/263837451): Fix the op_key unstableness during serialization. + absl::Status WriteToFile() const; + + size_t size() const; + + static const char* MesuredCostPathEnvVarName() { + return "TF_TFRT_MEASURED_COST_PATH"; + } + + private: + mutable tensorflow::mutex op_cost_map_mutex_; + // Map op key to {sum of op execution duration, #occurences of the op}. + absl::flat_hash_map> op_cost_map_ + TF_GUARDED_BY(op_cost_map_mutex_); +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_FALLBACK_COST_RECORDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/device_with_custom_allocator.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/device_with_custom_allocator.h new file mode 100644 index 00000000..f04e95f9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/device_with_custom_allocator.h @@ -0,0 +1,101 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_FALLBACK_DEVICE_WITH_CUSTOM_ALLOCATOR_H_ +#define TENSORFLOW_CORE_TFRT_FALLBACK_DEVICE_WITH_CUSTOM_ALLOCATOR_H_ + +#include + +#include "xla/tsl/framework/allocator.h" +#include "tensorflow/core/framework/device.h" + +namespace tensorflow { +namespace tfrt_stub { + +class DeviceWithCustomAllocator : public tensorflow::Device { + public: + DeviceWithCustomAllocator(tensorflow::Device* device, + tensorflow::Allocator* allocator) + : Device(device->env(), device->attributes()), + device_(device), + allocator_(allocator) { + DCHECK(device_); + DCHECK(allocator_); + } + + Allocator* GetAllocator(AllocatorAttributes attr) override { + return allocator_; + } + + const DeviceBase* UnderlyingDevice() const override { + return device_->UnderlyingDevice(); + } + DeviceBase* UnderlyingDevice() override { + return device_->UnderlyingDevice(); + } + + const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override { + return device_->tensorflow_cpu_worker_threads(); + } + + Allocator* GetScopedAllocator(AllocatorAttributes attr, + int64_t step_id) override { + return device_->GetScopedAllocator(attr, step_id); + } + + ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { + return device_->GetScopedAllocatorMgr(); + } + + const Eigen::ThreadPoolDevice* eigen_cpu_device() override { + return device_->eigen_cpu_device(); + } + + thread::ThreadPool* tensorflow_device_thread_pool() override { + return device_->tensorflow_device_thread_pool(); + } + + bool has_eigen_cpu_device() const override { + return device_->has_eigen_cpu_device(); + } + + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override { + return device_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor); + } + + void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, + const DeviceContext* device_context, + StatusCallback done) override { + device_->CopyTensorInSameDevice(input_tensor, output_tensor, device_context, + std::move(done)); + } + + absl::Status Sync() override { return device_->Sync(); } + + // Returns the resource manager associated w/ this device. + ResourceMgr* resource_manager() override { + return device_->resource_manager(); + } + + private: + tensorflow::Device* device_ = nullptr; + tensorflow::Allocator* allocator_ = nullptr; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_FALLBACK_DEVICE_WITH_CUSTOM_ALLOCATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/fallback_state.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/fallback_state.h new file mode 100644 index 00000000..ffbf0695 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/fallback_state.h @@ -0,0 +1,105 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_FALLBACK_FALLBACK_STATE_H_ +#define TENSORFLOW_CORE_TFRT_FALLBACK_FALLBACK_STATE_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace tfrt_stub { + +// FallbackState contains the necessary runtime states (eg. Devices) used in +// current tensorflow. It also provides methods used in current tensorflow. +class FallbackState { + public: + // The FunctionDefLibrary is passed in to initialize the + // ProcessFunctionLibraryRuntime member of this class + static absl::StatusOr> Create( + const SessionOptions &session_options, + const tensorflow::FunctionDefLibrary &fdef_lib); + + static absl::StatusOr> CreateWithCpuDevice( + const SessionOptions &session_options, + const tensorflow::FunctionDefLibrary &fdef_lib); + + static absl::StatusOr> CreateWithMockGpuDevice( + const SessionOptions &session_options, + const tensorflow::FunctionDefLibrary &fdef_lib); + + static absl::StatusOr> CreateWithDeviceMgr( + const SessionOptions &session_options, + const tensorflow::FunctionDefLibrary &fdef_lib, + absl::Nonnull device_mgr); + + FallbackState(const SessionOptions &session_options, + std::variant>, + absl::Nonnull> + device_mgr, + const tensorflow::FunctionDefLibrary &fdef_lib); + + // Create GraphExecutionState from the `graph_def`. The result will contain a + // preprocessed graph with runtime information such as devices. + absl::StatusOr> + CreateGraphExecutionState(GraphDef graph_def, bool run_placer = true, + bool enable_tf2xla_mlir_bridge = true) const; + + // Adds `func_def` to the function library. + absl::Status AddFunctionDef(const FunctionDef &func_def); + + const SessionOptions &session_options() const { return session_options_; } + + const DeviceMgr &device_manager() const { return *device_manager_ptr_; } + + DeviceMgr &device_manager() { return *device_manager_ptr_; } + + const DeviceSet &device_set() const { return device_set_; } + + const ProcessFunctionLibraryRuntime &process_function_library_runtime() + const { + return pflr_; + } + + const FunctionLibraryDefinition &func_lib_def() const { + return func_lib_def_; + } + + private: + SessionOptions session_options_; + DynamicDeviceMgr device_manager_; + absl::Nonnull device_manager_ptr_; + DeviceSet device_set_; + FunctionLibraryDefinition func_lib_def_; + ProcessFunctionLibraryRuntime pflr_; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_FALLBACK_FALLBACK_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/op_kernel_runner.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/op_kernel_runner.h new file mode 100644 index 00000000..317d0956 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/op_kernel_runner.h @@ -0,0 +1,236 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_ +#define TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tfrt_stub { + +class OpKernelRunner { + public: + static absl::StatusOr Create( + absl::string_view op_name, absl::string_view node_name, + absl::string_view device_name, int num_args, + const std::function& + attr_builder, + const tensorflow::DeviceMgr& device_manager, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime); + + ABSL_DEPRECATED("Please use the Create() method that takes node_name.") + static absl::StatusOr Create( + absl::string_view op_name, absl::string_view device_name, int num_args, + const std::function& + attr_builder, + const tensorflow::DeviceMgr& device_manager, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime) { + return Create(op_name, /*node_name=*/op_name, device_name, num_args, + attr_builder, device_manager, + process_function_library_runtime); + } + + static absl::StatusOr Create( + absl::string_view op_name, absl::string_view node_name, int num_args, + const std::function& + attr_builder, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime, + tensorflow::Device* device); + + ABSL_DEPRECATED("Please use the Create() method that takes node_name.") + static absl::StatusOr Create( + absl::string_view op_name, int num_args, + const std::function& + attr_builder, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime, + tensorflow::Device* device) { + return Create(op_name, /*node_name=*/op_name, num_args, attr_builder, + process_function_library_runtime, device); + } + + OpKernelRunner() = default; + + explicit operator bool() const { return op_kernel_ != nullptr; } + + void Run(OpKernelContext* context) const { + DVLOG(1) << "KernelFallbackExecuteCompat Running Op: " + << op_kernel_->def().DebugString() + << ", on Device: " << context->device()->name(); + + // For TFRT GPU or TPU, we currently only run xla 
clusters on GPU or TPU, + // and all other ops are run on CPU. + + op_kernel_->Compute(context); + } + + void RunAsync(OpKernelContext* context, + AsyncOpKernel::DoneCallback done_callback) const; + + bool IsAsync() const { return info_->is_async; } + + tensorflow::OpKernel* op_kernel() const { return op_kernel_.get(); } + tensorflow::Device* device() const { return info_->device; } + tensorflow::FunctionLibraryRuntime* function_library_runtime() const { + return info_->function_library_runtime; + } + tensorflow::ResourceMgr* resource_manager() const { + return info_->resource_manager; + } + + absl::Span input_alloc_attrs() const { + return input_alloc_attrs_; + } + absl::Span output_alloc_attrs() const { + return output_alloc_attrs_; + } + + private: + explicit OpKernelRunner( + tensorflow::Device* device, + tensorflow::FunctionLibraryRuntime* function_library_runtime, + std::unique_ptr op_kernel); + + std::unique_ptr op_kernel_; + absl::Span input_alloc_attrs_; + absl::Span output_alloc_attrs_; + + struct Info { + tensorflow::Device* device = nullptr; + tensorflow::FunctionLibraryRuntime* function_library_runtime = nullptr; + tensorflow::ResourceMgr* resource_manager = nullptr; + bool is_async = false; + absl::InlinedVector input_alloc_attrs; + absl::InlinedVector output_alloc_attrs; + }; + std::unique_ptr info_; +}; + +// OpKernelRunState keeps the states needed for per-kernel execution. +struct OpKernelRunState { + std::vector tensor_buffers; + std::vector input_tf_tensor_values; + OpKernelContext::Params params; + absl::InlinedVector input_tf_tensors; + + OpKernelRunState() = default; + OpKernelRunState(absl::Span tensor_values, + const OpKernelContext::Params& p, + tensorflow::DeviceBase* device = nullptr) { + // `input_tf_tensor_values` contains the reference to all tensor used, + // while `input_tf_tensors` only contains those needs ownership so their + // sizes may not match. For this copy assignment, we conservatively copy all + // tensors. + input_tf_tensors.reserve(tensor_values.size()); + for (const auto& tensor_value : tensor_values) { + input_tf_tensors.push_back(*tensor_value.tensor); + } + for (auto& tensor : input_tf_tensors) { + input_tf_tensor_values.emplace_back(&tensor); + } + + // Since `input_tf_tensor_values` and `params` contains pointers to + // `input_tf_tensors`, we need to change those pointers to the correct ones + // after copying. + params = p; + params.inputs = input_tf_tensor_values; + // Clear eigen_gpu_device to ensure OpKernelContext constructor will make a + // new eigen GPU device. + params.eigen_gpu_device = nullptr; + if (device != nullptr) params.device = device; + } + + OpKernelRunState(const OpKernelRunState& other) = delete; + OpKernelRunState& operator=(const OpKernelRunState& other) = delete; + + ~OpKernelRunState() = default; +}; + +// OpKernelRunnerTable for keeping OpKernelRunner instances to avoid expensive +// reinstantiation of OpKernel and other repeated setup per kernel execution. +// OpKernelRunnerTable is thread-compatible. +class OpKernelRunnerTable { + public: + OpKernelRunnerTable() = default; + + // Return true if it successfully inserts `runner`. `index` is supposed to be + // dense. + bool Insert(int64_t index, OpKernelRunner runner) { + if (runners_.size() <= index) runners_.resize(index + 1); + if (runners_[index]) return false; + runners_[index] = std::move(runner); + return true; + } + + // Return the OpKernelRunner at the corresponding `index` in the table. The + // result can never be nullptr. 
It is a fatal error to use an index that is + // not in the table. Note that the returned pointer will be invalidated if + // Insert() is called. + const OpKernelRunner* Get(int64_t index) const { + // Out of bounds vector access will throw an exception and anyway will crash + // the binary, prefer a more readable error message. + CHECK_GT(runners_.size(), index) // Crash OK + << "runner index is out of bounds: index=" << index + << " size=" << runners_.size(); + CHECK(runners_[index]) // Crash OK + << "runner is not available: index=" << index; + return GetUnsafe(index); + } + + const OpKernelRunner* GetUnsafe(int64_t index) const { + DCHECK_GT(runners_.size(), index); + auto& result = runners_[index]; + DCHECK(result); + return &result; + } + + private: + std::vector runners_; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h new file mode 100644 index 00000000..64f1060e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/fallback/op_kernel_runner_cache.h @@ -0,0 +1,74 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_CACHE_H_ +#define TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_CACHE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tfrt/host_context/location.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +class OpLocationKey { + public: + explicit OpLocationKey(tfrt::Location loc) : loc_(loc) {} + + template + friend H AbslHashValue(H h, const OpLocationKey& key) { + // NOTE: Each BEF file has its own LocationHandler. Using LocationHandler + // as part of cache key here can avoid cache collision between different + // BEF file. + return H::combine(std::move(h), key.loc_.data, key.loc_.GetHandler()); + } + + friend bool operator==(const OpLocationKey& x, const OpLocationKey& y) { + return x.loc_.data == y.loc_.data && + x.loc_.GetHandler() == y.loc_.GetHandler(); + } + + private: + tfrt::Location loc_; +}; + +// OpKernelRunnerCache is similar to OpKernelRunnerTable but thread-safe. 
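For context, a minimal usage sketch of the OpKernelRunnerTable declared above (not part of the patch): populate the table once at load time with a dense index per kernel, then fetch the cached runner on the hot path. The names `CacheRunner`, `RunCached`, and `kernel_index` are illustrative, and the snippet assumes a TensorFlow/TFRT source tree where this header builds.

#include <cstdint>
#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"

namespace example {

using tensorflow::tfrt_stub::OpKernelRunner;
using tensorflow::tfrt_stub::OpKernelRunnerTable;

// Returns false if a runner was already cached at `kernel_index`.
bool CacheRunner(OpKernelRunnerTable& table, int64_t kernel_index,
                 OpKernelRunner runner) {
  return table.Insert(kernel_index, std::move(runner));
}

void RunCached(const OpKernelRunnerTable& table, int64_t kernel_index,
               tensorflow::OpKernelContext* ctx) {
  // Get() CHECK-fails on a missing or out-of-range index, so the hot path can
  // assume a valid runner here.
  table.Get(kernel_index)->Run(ctx);
}

}  // namespace example

The OpKernelRunnerCache declared next serves the same purpose, but it is keyed by op location and guarded by a mutex, so it can be populated lazily from concurrent requests.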
+class OpKernelRunnerCache { + public: + OpKernelRunnerCache() = default; + + absl::StatusOr GetOrCreate( + tfrt::Location loc, absl::string_view op_name, + absl::string_view device_name, int num_args, + const std::function& + attr_builder, + const tensorflow::DeviceMgr& device_manager, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime); + + private: + mutable mutex mu_; + absl::flat_hash_map> map_ + TF_GUARDED_BY(mu_); +}; + +} // namespace tfrt_stub +} // namespace tensorflow +#endif // TENSORFLOW_CORE_TFRT_FALLBACK_OP_KERNEL_RUNNER_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h new file mode 100644 index 00000000..5c51b8d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h @@ -0,0 +1,73 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_GPU_RUNNER_H_ +#define TENSORFLOW_CORE_TFRT_GPU_KERNEL_GPU_RUNNER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "xla/tsl/framework/serving_device_selector.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tensorflow/core/tfrt/utils/gpu_variables_table.h" +#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime + +namespace tensorflow { +namespace gpu { + +constexpr char kGpuRunnerResourceName[] = "GpuRunnerResource"; + +struct GpuRunInputs { + std::vector args; + int num_outputs; + std::vector resource_indices; + std::vector used_output_indices; + std::string func_name; + Device* cpu_device; + absl::flat_hash_map gpu_devices; + const tfd::KernelFallbackCompatRequestState* fallback_request_state; + tfrt::HostContext* host_ctx; +}; + +class GpuRunner { + public: + explicit GpuRunner(tsl::ServingDeviceSelector* serving_device_selector) + : serving_device_selector_(serving_device_selector) {} + + // This compiles the given program and runs the given input tensors in + // `run_inputs`, and returns the output tensor AsyncValues. 
+ absl::StatusOr< + llvm::SmallVector>> + Run(GpuRunInputs run_inputs); + + private: + tsl::ServingDeviceSelector* serving_device_selector_; + tfrt::gpu::GpuVariablesTable vars_table_; +}; + +} // namespace gpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GPU_KERNEL_GPU_RUNNER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h new file mode 100644 index 00000000..f36356b8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h @@ -0,0 +1,37 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ +#define TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ +#include "absl/status/status.h" +#include "xla/tsl/framework/serving_device_selector_policies.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" + +namespace tensorflow { +namespace gpu { + +struct GpuRunnerOptions { + int num_gpu_streams = 1; + tsl::ServingDeviceSelectorPolicy serving_selector_policy = + tsl::ServingDeviceSelectorPolicy::kRoundRobin; +}; + +absl::Status InitTfrtGpu(const GpuRunnerOptions& options, + tensorflow::tfrt_stub::Runtime& runtime); + +} // namespace gpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/config.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/config.h new file mode 100644 index 00000000..b0e3fbf1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/config.h @@ -0,0 +1,84 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
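For context, a minimal sketch (not part of the patch) of enabling the TFRT GPU runner through the init API in tfrt_gpu_init.h above. The stream count is an illustrative value, and `runtime` is assumed to be a live Runtime owned by the caller.

#include "tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h"

absl::Status EnableTfrtGpu(tensorflow::tfrt_stub::Runtime& runtime) {
  tensorflow::gpu::GpuRunnerOptions options;
  options.num_gpu_streams = 2;  // default is 1
  options.serving_selector_policy =
      tsl::ServingDeviceSelectorPolicy::kRoundRobin;  // same as the default
  return tensorflow::gpu::InitTfrtGpu(options, runtime);
}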
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_CONFIG_H_ +#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_CONFIG_H_ + +#include + +#include "google/protobuf/any.pb.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/tfrt/graph_executor/config.pb.h" + +namespace tensorflow { +namespace tfrt_stub { + +// The helper class for building RuntimeConfigProto and retrieving configs of +// certain types from the RuntimeConfigProto. +class RuntimeConfig { + public: + RuntimeConfig() = default; + + static absl::StatusOr CreateFromProto( + RuntimeConfigProto proto); + + template + absl::Status Add(const ConcreteProto& config) { + const auto& full_name = config.GetDescriptor()->full_name(); + if (map_.contains(full_name)) { + return absl::AlreadyExistsError( + absl::StrCat(full_name, " already exists in ModelConfig.")); + } + + size_t id = proto_.config_size(); + if (!proto_.add_config()->PackFrom(config)) { + return absl::InvalidArgumentError( + absl::StrCat("Failed to pack proto to Any: ", full_name)); + } + map_[full_name] = id; + return absl::OkStatus(); + } + + template + absl::StatusOr Get() const { + const auto& full_name = ConcreteProto::GetDescriptor()->full_name(); + auto iter = map_.find(full_name); + + if (iter == map_.end()) { + return absl::NotFoundError( + absl::StrCat(full_name, " not found in ModelConfig.")); + } + + ConcreteProto config; + if (!proto_.config().at(iter->second).UnpackTo(&config)) { + return absl::DataLossError( + absl::StrCat("Failed to unpack proto: ", full_name)); + } + return config; + } + + const RuntimeConfigProto& ToProto() const { return proto_; } + + private: + RuntimeConfigProto proto_; + absl::flat_hash_map map_; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_CONFIG_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/executable_context.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/executable_context.h new file mode 100644 index 00000000..fb02ab34 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/executable_context.h @@ -0,0 +1,65 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
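For context, a sketch (not part of the patch) of the Add/Get round trip on the RuntimeConfig helper above. `MyTuningProto` is a hypothetical protobuf message standing in for a real model-specific config proto; it is not defined anywhere in this change.

#include "tensorflow/core/tfrt/graph_executor/config.h"

absl::Status AttachTuning(tensorflow::tfrt_stub::RuntimeConfig& runtime_config,
                          const MyTuningProto& tuning) {
  // Add() packs the message into the underlying RuntimeConfigProto as a
  // google.protobuf.Any keyed by the message's full name, so adding the same
  // message type twice yields AlreadyExistsError.
  return runtime_config.Add(tuning);
}

absl::StatusOr<MyTuningProto> ReadTuning(
    const tensorflow::tfrt_stub::RuntimeConfig& runtime_config) {
  // Get<T>() looks the entry up by T's full name and unpacks it, returning an
  // error if the message type was never added.
  return runtime_config.Get<MyTuningProto>();
}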
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_EXECUTABLE_CONTEXT_H_ +#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_EXECUTABLE_CONTEXT_H_ + +#include +#include + +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/bef_executor/bef_file.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime +#include "tfrt/support/ref_count.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +// Stores executable-related data. +struct ExecutableContext { + ExecutableContext(mlrt::bc::Buffer bytecode_buffer, + std::unique_ptr bytecode_executable) + : bytecode_buffer(std::move(bytecode_buffer)), + bytecode_executable(std::move(bytecode_executable)) {} + + ExecutableContext(tfrt::BefBuffer bef, + tfrt::RCReference bef_file) + : bef(std::move(bef)), bef_file(std::move(bef_file)) {} + + bool IsForMlrt() const { return bytecode_executable != nullptr; } + + // Only one set of values will be filled. + + // For the MLRT path. + mlrt::bc::Buffer bytecode_buffer; + std::unique_ptr bytecode_executable; + + // For the TFRT path. + tfrt::BefBuffer bef; + tfrt::RCReference bef_file; + + // There are some resources that need re-creating when the executable is + // re-created, so a resource context is stored along with the executable. + // This resource context is meant to be passed to the op kernels for their + // references. See the comment above `GraphExecutor::resource_context_` + // about the todo to merge that resource context with this one. + tfrt::ResourceContext resource_context; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_EXECUTABLE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/export_mlir.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/export_mlir.h new file mode 100644 index 00000000..ac687711 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/export_mlir.h @@ -0,0 +1,64 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
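For context, a small sketch (not part of the patch): callers of the ExecutableContext struct above dispatch on which of its two member sets is populated.

#include "tensorflow/core/tfrt/graph_executor/executable_context.h"

const char* ExecutableFlavor(
    const tensorflow::tfrt_stub::ExecutableContext& context) {
  // IsForMlrt() is true when the MLRT bytecode executable is set; otherwise
  // the BEF buffer/file members are the active ones.
  return context.IsForMlrt() ? "mlrt-bytecode" : "tfrt-bef";
}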
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_EXPORT_MLIR_H_ +#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_EXPORT_MLIR_H_ + +#include +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project + +namespace tensorflow { +namespace tfrt_stub { + +class XsymbolUploader { + public: + virtual ~XsymbolUploader() = default; + + virtual std::string MaybeUploadMlirToXsymbol(mlir::ModuleOp module) { + return ""; + } +}; + +class XsymbolUploaderRegistry { + public: + XsymbolUploaderRegistry() + : xsymbol_uploader_(std::make_unique()) {} + + void Register(std::unique_ptr xsymbol_uploader) { + xsymbol_uploader_ = std::move(xsymbol_uploader); + } + + XsymbolUploader &Get() const { return *xsymbol_uploader_; } + + private: + std::unique_ptr xsymbol_uploader_; +}; + +inline XsymbolUploaderRegistry &GetGlobalXsymbolUploaderRegistry() { + static auto *const registry = new XsymbolUploaderRegistry; + return *registry; +} + +inline std::string MaybeUploadMlirToXsymbol(mlir::ModuleOp module) { + return GetGlobalXsymbolUploaderRegistry().Get().MaybeUploadMlirToXsymbol( + module); +} + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_EXPORT_MLIR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/graph_execution_options.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/graph_execution_options.h new file mode 100644 index 00000000..32d0a007 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/graph_execution_options.h @@ -0,0 +1,162 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTION_OPTIONS_H_ +#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTION_OPTIONS_H_ + +#include +#include +#include +#include + +#include "absl/time/time.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/tfrt/graph_executor/config.h" +#include "tensorflow/core/tfrt/runtime/work_queue_interface.h" + +namespace tensorflow { +namespace tfrt_stub { + +class Runtime; + +// General options for graph execution. +struct GraphExecutionOptions { + explicit GraphExecutionOptions(const tensorflow::tfrt_stub::Runtime* rt) + : runtime(rt) { + DCHECK(runtime); + } + + // If true, when creating an optimized subgraph, Placer and Grappler will + // also run on the functions. + bool run_placer_grappler_on_functions = false; + + // If true, the function optimizer in the grappler will be enabled, and + // optimizations like function inlining will be applied. 
+ bool enable_grappler_function_optimizer = false; + + // Whether to enable TFRT GPU. + bool enable_tfrt_gpu = false; + + // The number of virtual GPUs to create on a physical GPU. + int tfrt_gpu_parallelism = 1; + + // if not zero, override the reserved memory space for gpu system. + int gpu_system_memory_size_in_mb = 0; + + // Whether to use gpurt.compile_and_execute for GPU. + // TODO(b/294895431): Remove the flag and default to the fused op. + bool tfrt_use_fused_gpu_op = false; + + // Runtime configuration. Refer to tensorflow::tfrt_stub::Runtime class for + // more details. It must not be nullptr; + const tensorflow::tfrt_stub::Runtime* runtime = nullptr; + + // Model metadata used for monitoring and tracing. + tensorflow::SessionMetadata model_metadata; + + // The model-specific runtime configurations. + tensorflow::tfrt_stub::RuntimeConfig runtime_config; + + // TODO(b/266251216): Maybe flip the default value. + [[deprecated( + "Use CostAnalysisOptions's `CostAnalysisOptions::ONCE` instead")]] bool + enable_online_cost_analysis = false; + + // Determines how often op costs are recorded, and how often these costs + // are used to re-compile the executable. Note to users: CostAnalysisOptions + // is overwritten when `enable_online_cost_analysis = true`. + struct CostAnalysisOptions { + enum CostAnalysisVersion { + kDisabled, + kOnce, // Cost recording and recompilation occurs on the first run only. + kPeriodic, // This is experimental. + }; + CostAnalysisVersion version = kDisabled; + + // Time between resets in Op cost estimates. Upon reset, the executable + // will be recompiled. + // However, a reset always occurs after the first execution. + absl::Duration reset_interval = absl::ZeroDuration(); + + // Number of times to record costs before resetting Op cost estimates. + // However, a reset always occurs after the first execution. + int updates_per_interval = 1; + }; + + CostAnalysisOptions cost_analysis_options; + + // If true, the MLRT interpreter will be used instead of the BEF executor. + // This option is experimental. + bool enable_mlrt = false; + + // If true, the IFRT will be used instead of the TPU Runner. + // This option is experimental. + bool use_ifrt = false; + + tensorflow::TfrtCompileOptions compile_options; +}; + +std::ostream& operator<<(std::ostream& os, + const GraphExecutionOptions& options); + +// Per-request options for graph execution. +struct GraphExecutionRunOptions { + std::optional deadline; + + // Priority of the request. Larger number means higher priority. + int priority = 0; + + // If true, the input specs will be checked before running, and an error + // will be raised upon mismatch. + bool validate_input_specs = false; + + // TODO(b/279197040) Remove after b/279197040 is fixed. + // If true, the input specs will be checked before running, and an error + // will be logged upon mismatch. + bool validate_input_specs_dry_run = false; + + // The thread pool used for this run. If it is nullptr, a default one set + // in the tensorflow::tfrt_stub::Runtime will be used. + tensorflow::tfrt_stub::WorkQueueInterface* work_queue = nullptr; + + // If true, just-in-time host compilation is disabled, and then if the + // specified graph is not compiled, the execution will return an error. + bool disable_compilation = false; + + std::function)> + streamed_output_callback; + + // The optional name for debugging purposes. If empty, the runtime will pick a + // name e.g. the joined string of input names and output names. 
+ std::string name; +}; + +// Creates the default `SessionOptions` from a `GraphExecutionOptions`. +// The created `SessionOptions` contains the Grappler configs. +tensorflow::SessionOptions CreateDefaultSessionOptions( + const GraphExecutionOptions& options); + +// Updates TPU target to fallback if bridge uncompatible, otherwise TPU runtime. +void UpdateTpuTargetByBridgeCompatibility( + tensorflow::tfrt_stub::GraphExecutionOptions& options, + const tensorflow::GraphDef& graph_def); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTION_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/graph_executor.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/graph_executor.h new file mode 100644 index 00000000..18375802 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/graph_executor.h @@ -0,0 +1,394 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ +#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/monitoring/sampler.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/graph_executor/executable_context.h" +#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" +#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/function.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/value.h" 
+#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/runtime/stream.h" +#include "tensorflow/core/tfrt/runtime/work_queue_interface.h" +#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h" +#include "tsl/platform/thread_annotations.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/bef_executor/bef_file.h" // from @tf_runtime +#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/function.h" // from @tf_runtime +#include "tfrt/host_context/request_deadline_tracker.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime +#include "tfrt/support/ref_count.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +// Contains request related info. +struct RequestInfo { + tfrt::RCReference tfrt_request_context; + // If this request needs to create a new queue, it is stored here. Otherwise, + // it can be nullptr. + std::unique_ptr request_queue_owner; + // The inter-op thread pool to be used for this request, and it must not be + // nullptr. If `request_queue_owner` is not nullptr, then `request_queue` is + // the raw pointer inside `request_queue_owner`. + WorkQueueInterface* request_queue = nullptr; + // The task runner used by tensorflow::OpKernel. + std::function)> runner; + + tensorflow::CancellationManager cancellation_manager; +}; + +struct SymbolUids { + std::string tf_symbol_uid; + std::string tfrt_symbol_uid; +}; + +// Creates a `RequestInfo` given relative data. +// Note: `resource_context` is per-graph-executor and +// `client_graph_resource_context` is per-loaded-client-graph. See the comment +// above `GraphExecutor::resource_context_` about the todo to merge these two. +absl::StatusOr> CreateRequestInfo( + const GraphExecutionOptions& options, + const GraphExecutionRunOptions& run_options, + tensorflow::tfrt_stub::WorkQueueInterface* work_queue, + tfrt::ResourceContext* resource_context, + tfrt::ResourceContext* client_graph_resource_context, + OpKernelRunnerTable* runner_table, + tfd::FallbackResourceArray* resource_array, FallbackState& fallback_state, + const ProcessFunctionLibraryRuntime& process_function_library_runtime, + CostRecorder* cost_recorder = nullptr); + +// Runs on a function given input/output and other info. +// Note: `resource_context` is per-graph-executor and +// `client_graph_resource_context` is per-loaded-client-graph. See the comment +// above `GraphExecutor::resource_context_` about the todo to merge these two. +// +// TODO(chky): Refactor this function to take `LoadedClientGraph` instead of +// having a long list of parameters. 
+absl::Status GraphExecutionRunOnFunction( + const GraphExecutionOptions& options, + const GraphExecutionRunOptions& run_options, + absl::string_view signature_name, const SymbolUids& symbol_uids, + const tfrt::Function* func, const mlrt::LoadedExecutable* loaded_executable, + absl::Span inputs, + std::vector* outputs, + tfrt::ResourceContext* resource_context, + tfrt::ResourceContext* client_graph_resource_context, + OpKernelRunnerTable* runner_table, + tfd::FallbackResourceArray* resource_array, const Runtime& runtime, + FallbackState& fallback_state, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime, + tfrt::RequestDeadlineTracker* req_deadline_tracker, + std::optional stream_callback_id, + CostRecorder* cost_recorder = nullptr); + +// Runs a MLRT function for executing tensorflow graphs. +absl::Status RunMlrtFunction( + mlrt::bc::Function function, + const mlrt::LoadedExecutable& loaded_executable, + const tsl::RCReference& request_context, + tfrt::ConcurrentWorkQueue& work_queue, + absl::Span inputs, + std::vector* outputs, + SyncResourceState* sync_resource_state); + +// Loads (if not yet) and runs a subgraph in a graph as per each request. +class GraphExecutor { + public: + using Options = GraphExecutionOptions; + using RunOptions = GraphExecutionRunOptions; + + // The loading result of a `ClientGraph`. + class LoadedClientGraph { + public: + LoadedClientGraph(std::string name, SymbolUids symbol_uids, + GraphExecutor* graph_executor, + std::unique_ptr mlir_context, + mlir::OwningOpRef tf_mlir_with_op_keys, + mlir::OwningOpRef tfrt_mlir, + std::shared_ptr executable_context, + std::optional stream_callback_id, + bool is_restore, FunctionLibraryDefinition flib_def, + tsl::monitoring::SamplerCell* latency_sampler); + + // Returns this instance's CostRecorder if it is time to update costs, + // else returns nullptr. Only allows one non-null return value at a time + // in order to provide thread-safety. If do_recompilation becomes `true`, + // then recompiles using updated costs occurs. + CostRecorder* MaybeGetCostRecorder(absl::Time now, bool* do_recompilation); + // Updates the op cost values in this `LoadedClientGraph` with records from + // `cost_recorder`. + absl::Status UpdateCost(const CostRecorder& cost_recorder, + const Runtime& runtime); + // Updates `cost_analysis_data_` to make it accurate for the next execution. + // Assumes a cost update occurred this cycle. + void UpdateCostAnalysisData(absl::Time now, bool do_recompilation); + // Getters. + std::shared_ptr executable_context() const { + tensorflow::mutex_lock lock(executable_context_mu_); + return executable_context_; + } + absl::string_view name() const { return name_; } + const SymbolUids& symbol_uids() const { return symbol_uids_; } + + OpKernelRunnerTable& runner_table() { return runner_table_; } + tfd::FallbackResourceArray& resource_array() { return resource_array_; } + SyncResourceState& sync_resource_state() { return sync_resource_state_; } + + std::optional stream_callback_id() const { + return stream_callback_id_; + } + + bool is_restore() const { return is_restore_; } + + const ProcessFunctionLibraryRuntime& process_function_library_runtime() + const { + return pflr_; + } + tsl::monitoring::SamplerCell* latency_sampler() { return latency_sampler_; } + + private: + std::string name_; + SymbolUids symbol_uids_; + GraphExecutor* graph_executor_ = nullptr; + + // `mlir_context_` is declared here because the resources declared later may + // hold references to the MLIR objects. 
+ std::unique_ptr mlir_context_; + + struct CostAnalysisData { + mutable tensorflow::mutex mu; + // Ensures only one GraphExecutor thread updates costs at a time. + bool is_available TF_GUARDED_BY(mu) = false; + // Maintains the book-keeping of op costs. + std::unique_ptr cost_recorder; + // For recompilation in MLRT, TFRT respectively. + mlir::OwningOpRef tf_mlir_with_op_keys; + mlir::OwningOpRef tfrt_mlir; + // Start of current cost measurement cycle. + absl::Time start_time TF_GUARDED_BY(mu) = absl::Now(); + // Cost recordings within the current measurement cycle. + int num_cost_updates TF_GUARDED_BY(mu) = 0; + }; + CostAnalysisData cost_analysis_data_; + + OpKernelRunnerTable runner_table_; + tfd::FallbackResourceArray resource_array_; + mutable tensorflow::mutex executable_context_mu_; + // Can be updated if online cost analysis is enabled. + std::shared_ptr executable_context_ + TF_GUARDED_BY(executable_context_mu_); + SyncResourceState sync_resource_state_; + + std::optional stream_callback_id_; + bool is_restore_; + FunctionLibraryDefinition flib_def_; + ProcessFunctionLibraryRuntime pflr_; + tsl::monitoring::SamplerCell* latency_sampler_; + }; + + // A subgraph constructed by specifying input/output tensors. + struct ClientGraph { + // The human-readable name for the graph, e.g. the signature_name in the + // saved model. + std::string name; + // The feed nodes for the corresponding inputs, but they might not be in the + // original order and if there are more than one original inputs mapped to + // the same feed node, only one is picked here. + tensorflow::GraphImportConfig::InputArrays input_nodes; + // The fetch nodes for the outputs, which should be in the original order. + std::vector output_nodes; + // The target nodes that should be run but not returned as outputs. + std::vector target_nodes; + }; + + // Creates a `GraphExecutor` given the args. + static absl::StatusOr> Create( + Options options, std::unique_ptr fallback_state, + std::unique_ptr resource_context, + tensorflow::GraphDef graph_def, + std::unique_ptr kernel_registry, + tensorflow::tfrt_stub::RuntimeConfig* runtime_config = nullptr); + + // Ctor. Public for `Create()`. Do not use directly. + GraphExecutor(Options options, std::unique_ptr fallback_state, + std::unique_ptr resource_context, + std::unique_ptr + graph_execution_state, + std::unique_ptr kernel_registry, + tensorflow::tfrt_stub::RuntimeConfig* runtime_config = nullptr); + + // Runs on the graph according to given input/output. + absl::Status Run( + const RunOptions& run_options, + absl::Span> inputs, + absl::Span output_tensor_names, + absl::Span target_tensor_names, + std::vector* outputs); + + // Runs the graph identified by `graph_name` using the input `inputs` and + // stores the output of the execution in `outputs`. It is the client's + // responsibility to ensure `graph_name` corresponds to logically different + // graphs, since this name is used to lookup compiled graphs in the cache. The + // graph is run synchronously with the TFRT interpreter. + absl::Status RunWithSyncInterpreter( + const std::string& graph_name, absl::Span input_values, + absl::Span input_names, + absl::Span input_dtypes, + absl::Span output_tensor_names, + absl::Span target_tensor_names, + absl::Span outputs); + + // Extends the current graph by `graph`. + absl::Status Extend(const GraphDef& graph); + + tensorflow::tfrt_stub::TfrtGraphExecutionState& graph_execution_state() + const { + return *graph_execution_state_; + } + + // Returns the underlying runtime. 
+ const tensorflow::tfrt_stub::Runtime& runtime() const { + DCHECK(options_.runtime); + return *options_.runtime; + } + + tfrt::ResourceContext& resource_context() { return *resource_context_; } + + const Options& options() const { return options_; } + const FallbackState& fallback_state() const { return *fallback_state_; } + FallbackState& fallback_state() { return *fallback_state_; } + + // Compiles graph for `graph_name` and runs any initializers. + absl::Status CompileGraph( + const std::string& graph_name, + absl::Span input_tensor_names, + absl::Span input_tensor_dtypes, + absl::Span output_tensor_names, + absl::Span target_tensor_names); + + const mlrt::KernelRegistry& kernel_registry() const { + return *kernel_registry_; + } + + private: + // A set of methods to load a client graph. + absl::StatusOr> + LoadClientGraph( + const GraphExecutor::ClientGraph& client_graph, + tensorflow::tfrt_stub::WorkQueueInterface* work_queue, + absl::Span> inputs); + absl::StatusOr> + ImportAndCompileClientGraph( + const GraphExecutor::ClientGraph& client_graph, + absl::Span> inputs); + absl::StatusOr< + std::pair>> + ImportClientGraphToMlirModule(const GraphExecutor::ClientGraph& client_graph, + mlir::MLIRContext* context) const; + absl::StatusOr CompileMlirModuleToBef( + mlir::ModuleOp module) const; + + absl::Status InitBef(LoadedClientGraph* loaded_client_graph, + tensorflow::tfrt_stub::WorkQueueInterface* work_queue); + + absl::Status InitBytecode(LoadedClientGraph* loaded_graph); + + // Returns a `LoadedClientGraph` given input/output tensor info. If there is + // no existing one yet, creates one first. + absl::StatusOr> + GetOrCreateLoadedClientGraph( + const RunOptions& run_options, + absl::Span input_tensor_names, + absl::Span input_tensor_dtypes, + absl::Span output_tensor_names, + absl::Span target_tensor_names, + tensorflow::tfrt_stub::WorkQueueInterface* work_queue, + absl::string_view graph_name = "", + absl::Span> inputs = {}) + TF_LOCKS_EXCLUDED(loaded_client_graphs_mu_); + + Options options_; + std::unique_ptr fallback_state_; + + std::unique_ptr + graph_execution_state_; + + tfrt::RequestDeadlineTracker req_deadline_tracker_; + + tensorflow::mutex loaded_client_graphs_mu_; + // Caches `LoadedClientGraph` by the joined name. + // For pointer stability of values in `absl::flat_hash_map<>`, additional + // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.) + absl::flat_hash_map> + loaded_client_graphs_ TF_GUARDED_BY(loaded_client_graphs_mu_); + + std::unique_ptr kernel_registry_; + + std::unique_ptr resource_context_; + + protected: + // For testing basic Cost Analysis functionality. + absl::Duration simulated_duration_ = absl::ZeroDuration(); + tensorflow::mutex num_recompilations_mu_; + int num_recompilations_ TF_GUARDED_BY(num_recompilations_mu_) = 0; +}; + +void RegisterMlirDialect(mlir::DialectRegistry& registry, + tensorflow::BackendCompiler* backend_compiler); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/sync_resource_state.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/sync_resource_state.h new file mode 100644 index 00000000..9bb9bb58 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/graph_executor/sync_resource_state.h @@ -0,0 +1,66 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
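For context, a minimal sketch (not part of the patch) of building the Options (an alias of GraphExecutionOptions from graph_execution_options.h earlier in this patch) that a GraphExecutor is created with, using only fields declared there. `runtime` is assumed to be a live tensorflow::tfrt_stub::Runtime owned elsewhere.

#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"

tensorflow::tfrt_stub::GraphExecutionOptions MakeOptions(
    const tensorflow::tfrt_stub::Runtime* runtime) {
  // The constructor DCHECKs that `runtime` is non-null.
  tensorflow::tfrt_stub::GraphExecutionOptions options(runtime);
  options.enable_mlrt = true;  // opt into the experimental MLRT interpreter
  // Record op costs on the first run only, then recompile with them.
  options.cost_analysis_options.version = tensorflow::tfrt_stub::
      GraphExecutionOptions::CostAnalysisOptions::kOnce;
  return options;
}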
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_SYNC_RESOURCE_STATE_H_ +#define TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_SYNC_RESOURCE_STATE_H_ + +#include +#include + +#include "tensorflow/core/tfrt/utils/any_ptr.h" +#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +class SyncResourceState { + public: + // Sets `dht` in the array at `index`. `index` should be dense and + // duplicate indices are not allowed. + void SetResourceDht(int index, tfrt::DenseHostTensor dht) { + if (resource_dht_.size() <= index) { + resource_dht_.resize(index + 1); + } + + resource_dht_[index] = std::move(dht); + } + + tfrt::DenseHostTensor GetResourceDht(int index) const { + return resource_dht_.at(index).CopyRef(); + } + + template + void Set(int index, T* resource) { + if (resources_.size() <= index) { + resources_.resize(index + 1); + } + + resources_[index] = tfrt::AnyPtr(resource); + } + + template + T* Get(int index) const { + return resources_.at(index).get(); + } + + private: + std::vector resource_dht_; + // TODO(b/288899457): Consider provide a simpler solution than forking AnyPtr. + std::vector resources_; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_GRAPH_EXECUTOR_SYNC_RESOURCE_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/checkpoint_loader.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/checkpoint_loader.h new file mode 100644 index 00000000..e47c78bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/checkpoint_loader.h @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
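For context, a minimal sketch (not part of the patch) of the typed Set/Get pair on SyncResourceState above. `Counters` and index 0 are illustrative; the stored pointer is not owned by SyncResourceState and must outlive it.

#include <cstdint>

#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h"

struct Counters {
  int64_t num_runs = 0;
};

void BumpRunCount(tensorflow::tfrt_stub::SyncResourceState& state,
                  Counters* counters) {
  state.Set(/*index=*/0, counters);              // indices should stay dense
  ++state.Get<Counters>(/*index=*/0)->num_runs;  // typed lookup at that index
}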
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/kernel/context.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +// TODO(b/352551302) Move the unit test in ifrt_ops_kernel for restore to test +// this class's APIs. +// Implement the `CheckpointLoaderInterface` by using RestoreV2. +class CheckpointLoader { + public: + struct PrepareRestoreArgs { + mlir::MLIRContext* context; + tensorflow::MetaGraphDef meta_graph_def; + tfrt_stub::FallbackState* fallback_state; + std::string saved_model_dir; + bool run_placer_grappler_on_functions; + }; + + explicit CheckpointLoader( + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue, + bool use_async_restore = true) + : ifrt_restore_tensor_registry_(ifrt_restore_tensor_registry), + checkpoint_loader_work_queue_(checkpoint_loader_work_queue), + use_async_restore_(use_async_restore) {} + virtual ~CheckpointLoader() = default; + + // Called before `Load` to do some preparation work. + virtual absl::Status PrepareRestore(const PrepareRestoreArgs& args); + + // Load the checkpoint. This API is designed to be compatible with the + // `tf_mlrt.ifrt_restore_variable` kernel. + virtual absl::Status Load( + const tensorflow::tfrt_stub::FallbackTensor& prefix, + const std::vector& var_handles, + const tensorflow::tfrt_stub::FallbackTensor& tensor_names, + const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices, + absl::Span restored_dtypes, + const std::vector& truncate_in_cast, tf_mlrt::Context& context); + + protected: + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry_; + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue_; + bool use_async_restore_ = true; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/grid.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/grid.h new file mode 100644 index 00000000..28e52809 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/grid.h @@ -0,0 +1,77 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
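For context, a minimal sketch (not part of the patch) of constructing the CheckpointLoader declared above. Both pointers are assumed to outlive the loader; the last argument selects asynchronous restore.

#include <memory>

#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h"

std::unique_ptr<tensorflow::ifrt_serving::CheckpointLoader> MakeLoader(
    tensorflow::ifrt_serving::IfrtRestoreTensorRegistry* restore_registry,
    tfrt::ConcurrentWorkQueue* restore_work_queue) {
  // Per the comments above, PrepareRestore() is the step to call before Load().
  return std::make_unique<tensorflow::ifrt_serving::CheckpointLoader>(
      restore_registry, restore_work_queue, /*use_async_restore=*/true);
}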
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_GRID_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_GRID_H_ + +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_format.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Coordinates that identify a particular point in a 4-d grid (usually a TPU +// topology). +struct GridCoords { + int dim[4]; + + constexpr GridCoords(int d0, int d1, int d2, int d3) : dim{d0, d1, d2, d3} {} + GridCoords() : GridCoords(0, 0, 0, 0) {} + + static GridCoords Zeroes() { return GridCoords(0, 0, 0, 0); } + static GridCoords Ones() { return GridCoords(1, 1, 1, 1); } + + int operator[](int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, 4); + return dim[i]; + } + + int& operator[](int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, 4); + return dim[i]; + } + + int Product() const { return dim[0] * dim[1] * dim[2] * dim[3]; } + + std::string ToString() const; + + template + friend void AbslStringify(Sink& sink, const GridCoords& value) { + absl::Format(&sink, "%s", value.ToString()); + } + + friend bool operator==(const GridCoords& a, const GridCoords& b) { + return a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3]; + } + + friend std::ostream& operator<<(std::ostream& os, const GridCoords& c) { + return os << c.ToString(); + } + + template + friend H AbslHashValue(H h, const GridCoords& c) { + return H::combine(std::move(h), c[0], c[1], c[2], c[3]); + } +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_GRID_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_device_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_device_utils.h new file mode 100644 index 00000000..f779aa62 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_device_utils.h @@ -0,0 +1,63 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
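For context, a minimal sketch (not part of the patch) exercising GridCoords as declared above; the coordinate values are illustrative.

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/tfrt/ifrt/grid.h"

bool GridCoordsExample() {
  using tensorflow::ifrt_serving::GridCoords;
  constexpr GridCoords kTopology(2, 2, 1, 1);
  // Product() gives the number of grid points: 2 * 2 * 1 * 1 = 4.
  const int num_points = kTopology.Product();
  // AbslHashValue and operator== make GridCoords usable as a hash-set key.
  absl::flat_hash_set<GridCoords> seen;
  seen.insert(kTopology);
  return num_points == 4 && seen.contains(GridCoords(2, 2, 1, 1));
}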
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_DEVICE_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_DEVICE_UTILS_H_ + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Returns the assigned IFRT devices based on the device assignment attribute. +// +// params: +// ifrt_client: The ifrt client. +// num_replicas: The number of replicas. +// num_cores_per_replica: The number of cores per replica. +// +// device_assignment: The device assignment array encoded as +// [x0,y0,z0,core0,x1,y1,z1,core1, ...]. Optional. If not provided, the +// devices will be assigned based on the default order returned by the IFRT +// client. +// +// returns: +// The assigned devices. +absl::StatusOr> GetAssignedIfrtDevices( + const xla::ifrt::Client& ifrt_client, int num_replicas, + int num_cores_per_replica, + std::optional> device_assignment); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_DEVICE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h new file mode 100644 index 00000000..25b9e6c3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h @@ -0,0 +1,104 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_EXECUTABLE_REGISTRY_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_EXECUTABLE_REGISTRY_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Maintains a process-wide map from program ids to executables. Used by the +// `IfrtCall` TensorFlow op to look up executables and invoke them. +// +// Invoking a TPU program inside a `IfrtCall` TF op requires being +// able to retrieve an executable for the given program. 
Since there's no easy +// way to pass non-serializable attributes to TF ops, we encode a program id +// (that is unique within a process) as an attribute of a `IfrtCall` op and +// use this registry class to let the `IfrtCall` op look up an executable +// during TF op execution. +class ServingExecutableRegistry { + public: + // RAII handle for registered executables. + class Handle { + public: + Handle(); // Constructs an empty handle. + + // Move only. + Handle(Handle&& other); + Handle& operator=(Handle&& other); + Handle(const Handle&) = delete; + Handle& operator=(const Handle&) = delete; + + ~Handle(); + + // Returns the program id that the handle represents, or `std::nullopt` if + // the handle is empty. + std::optional program_id() const { return program_id_; } + + // Unregisters the owned executable, if any, early (before the destructor). + // Calling this method multiple times is a no-op. + void Release(); + + // Freezes the program's compilation. After Freeze() is called, no new model + // signature will be compiled. Using a signature or an input shape that + // wasn't compiled before the freeze will lead to an error. + absl::Status Freeze(); + + private: + friend class ServingExecutableRegistry; + + // Can only be constructed by `ServingExecutableRegistry::Register()`. + explicit Handle(int64_t program_id); + + // Program id. `std::nullopt` if the handle is already released. + std::optional program_id_; + }; + + // Registers an executable under the given program id. Returns an RAII handle + // that unregisters the program at its destruction. + static absl::StatusOr Register( + int64_t program_id, std::unique_ptr executable); + + // Looks up an executable registered under the given program id, or returns + // nullptr if there's no such program. + static IfrtServingExecutable* Lookup(int64_t program_id); + + private: + friend class Handle; + friend class IfrtBackendCompilerTest; + + static absl::Mutex mu_; + + // Mapping from program ids to executables. + static absl::flat_hash_map>* const + executables_ ABSL_GUARDED_BY(&mu_); +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_EXECUTABLE_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h new file mode 100644 index 00000000..d488d936 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h @@ -0,0 +1,98 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
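For context, a minimal sketch (not part of the patch) of the register/lookup flow above. The template arguments lost in extraction are assumed here to be absl::StatusOr<Handle> for Register() and std::unique_ptr<IfrtServingExecutable> for its second parameter, consistent with the surrounding declarations.

#include <cstdint>
#include <memory>
#include <utility>

#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h"

namespace ifrt = tensorflow::ifrt_serving;

absl::StatusOr<ifrt::ServingExecutableRegistry::Handle> RegisterProgram(
    int64_t program_id,
    std::unique_ptr<ifrt::IfrtServingExecutable> executable) {
  // The returned RAII Handle unregisters the program when it is destroyed
  // (or when Release() is called explicitly).
  return ifrt::ServingExecutableRegistry::Register(program_id,
                                                   std::move(executable));
}

ifrt::IfrtServingExecutable* FindProgram(int64_t program_id) {
  // Lookup() returns nullptr when nothing is registered under `program_id`.
  return ifrt::ServingExecutableRegistry::Lookup(program_id);
}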
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/future.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace tensorflow { +namespace ifrt_serving { + +// This class is thread safe. +class IfrtLoadedVariableRegistry { + public: + // The key is per variable tensor per device assignment. For single -device + // program, variables can be loaded on multiple devices with core selection. + // For SPMD program, we currently assume all devices will be used, so we use + // vector to make it compatible with SPMD. + struct Key { + // We use a vector to make it compatible with SPMD because the order of the + // devices used for sharding must match the order of the devices used for + // xla compilation. + std::vector device_ids; + std::string input_name; + xla::HloSharding hlo_sharding; + template + friend H AbslHashValue(H h, const Key& key) { + h = H::combine(std::move(h), key.input_name, key.device_ids, + key.hlo_sharding); + return h; + } + + friend bool operator==(const Key& x, const Key& y) { + return x.input_name == y.input_name && x.device_ids == y.device_ids && + x.hlo_sharding == y.hlo_sharding; + } + + std::string ToString() const { + return absl::StrCat(input_name, ":", absl::StrJoin(device_ids, ","), ":", + hlo_sharding.ToString()); + } + }; + + struct LoadedVariable { + xla::ifrt::Future> array; + }; + using LoadedVariableConstructor = + absl::AnyInvocable() const>; + + // Tries to register a loaded variable with the given name. + // Returns an error if the named array does not already exists and + // loaded_variable_constructor failed to create an array. Note that it returns + // OK if the named array already exists. + // loaded_variable_constructor is invoked in the caller thread. + absl::Status TryRegisterLoadedVariable( + const Key& key, LoadedVariableConstructor&& loaded_variable_constructor) + ABSL_LOCKS_EXCLUDED(mutex_); + + absl::StatusOr GetLoadedVariable(const Key& key) const + ABSL_LOCKS_EXCLUDED(mutex_); + + private: + mutable absl::Mutex mutex_; + absl::flat_hash_map loaded_variable_map_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h new file mode 100644 index 00000000..6fea3a57 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h @@ -0,0 +1,73 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
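For context, a sketch (not part of the patch) of building the registry Key above and querying it. The element type of `device_ids` and the GetLoadedVariable return type were lost in extraction; they are assumed here to be int and absl::StatusOr<LoadedVariable>. The variable name and device order are illustrative.

#include "xla/hlo/ir/hlo_sharding.h"
#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h"

absl::StatusOr<
    tensorflow::ifrt_serving::IfrtLoadedVariableRegistry::LoadedVariable>
FindLoadedVariable(
    const tensorflow::ifrt_serving::IfrtLoadedVariableRegistry& registry) {
  tensorflow::ifrt_serving::IfrtLoadedVariableRegistry::Key key{
      /*device_ids=*/{0, 1},
      /*input_name=*/"my_model/dense/kernel",
      /*hlo_sharding=*/xla::HloSharding::Replicate()};
  // Fails until TryRegisterLoadedVariable() has run for this key.
  return registry.GetLoadedVariable(key);
}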
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_UTILS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/client.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tsl/platform/threadpool.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +// An index to indicate a non per-core executable bundle cache. +inline constexpr int kNoCoreSelectedIndex = -1; + +// TODO(b/352551302) Delete VariableDeviceShardingConfigProto. +struct VariableDeviceShardingConfig { + std::vector device_ids; + xla::HloSharding hlo_sharding; +}; + +absl::StatusOr GetDtypeAndShape( + const ResourceHandle& resource_handle); + +// Returns the runtime name from the resource handle. The name will be concat of +// handle's container name and handle's name. +std::string GetRuntimeNameFromVarHandle(const ResourceHandle& handle); + +// Loads a restored tensor as an IFRT loaded variable and set the restored +// tensor in the `restored_tensor_promise` as output. It is an async loading. We +// look for the restored tensor in `ifrt_restore_tensor_registry` and save a +// future of IFRT loaded variable in `ifrt_loaded_variable_registry`. The caller +// can look for the actual loaded variable value in +// `ifrt_loaded_variable_registry`. +absl::Status AsyncLoadRestoredTensorAsIfrtLoadedVariable( + absl::string_view runtime_name, + std::shared_ptr ifrt_client, + const tsl::thread::ThreadPool& thread_pool, + const ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry, + ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry, + tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, + const VariableDeviceShardingConfig& sharding_config); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_model_context.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_model_context.h new file mode 100644 index 00000000..7c41a947 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_model_context.h @@ -0,0 +1,191 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/topology.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/threadpool.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +inline constexpr absl::string_view kIfrtModelContextName = "IfrtModelContext"; + +// Device specific configuration not available through ifrt. This should be +// rare. +struct DeviceConfig { + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn = + tensorflow::IdentityShapeRepresentationFn(); +}; + +// The runtime context for ifrt to be used in TFRT serving. +// +// This class is thread compatible. 
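+// Illustrative construction sketch (not part of the upstream header;
+// `client`, `core_selector`, `thread_pool` and `handle` are placeholders):
+//
+//   IfrtModelContext model_context(client, core_selector, thread_pool,
+//                                  /*compilation_environment_proto=*/nullptr);
+//   // Handles returned by ServingExecutableRegistry::Register() are stored
+//   // here so the executables stay registered for the model's lifetime.
+//   model_context.RegisterHandle(std::move(handle));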
+class IfrtModelContext { + public: + explicit IfrtModelContext( + std::shared_ptr client, + IfrtServingCoreSelector* ifrt_serving_core_selector, + tsl::thread::ThreadPool* thread_pool, + std::unique_ptr compilation_environment_proto) + : client_(std::move(client)), + ifrt_serving_core_selector_(ifrt_serving_core_selector), + thread_pool_(*thread_pool), + compilation_environment_proto_( + std::move(compilation_environment_proto)) {} + IfrtModelContext( + std::shared_ptr client, + IfrtServingCoreSelector* ifrt_serving_core_selector, + tsl::thread::ThreadPool* thread_pool, tensorflow::DeviceMgr* device_mgr, + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, + std::unique_ptr compilation_environment_proto, + std::shared_ptr topology, TfToHloCompiler* tf_to_hlo_compiler, + IfrtPersistentCompilationCache* persistent_compilation_cache = nullptr) + : client_(std::move(client)), + topology_(topology), + ifrt_serving_core_selector_(ifrt_serving_core_selector), + thread_pool_(*thread_pool), + device_mgr_(device_mgr), + shape_representation_fn_(shape_representation_fn), + compilation_environment_proto_( + std::move(compilation_environment_proto)), + tf_to_hlo_compiler_(tf_to_hlo_compiler), + persistent_compilation_cache_(persistent_compilation_cache) {} + + void RegisterHandle(ServingExecutableRegistry::Handle handle) { + handles_.push_back(std::move(handle)); + } + + std::shared_ptr GetClient() const { return client_; } + + const tensorflow::XlaHelpers::ShapeRepresentationFn& + GetShapeRepresentationFn() const { + return shape_representation_fn_; + } + + tsl::thread::ThreadPool& GetThreadPool() const; + + const IfrtLoadedVariableRegistry& GetLoadedVariableRegistry() const { + return loaded_variable_registry_; + } + IfrtLoadedVariableRegistry& GetLoadedVariableRegistry() { + return loaded_variable_registry_; + } + + const IfrtRestoreTensorRegistry& GetRestoreTensorRegistry() const { + return restore_tensor_registry_; + } + IfrtRestoreTensorRegistry& GetRestoreTensorRegistry() { + return restore_tensor_registry_; + } + + IfrtPersistentCompilationCache* GetPersistentCompilationCache() const { + return persistent_compilation_cache_; + } + + tensorflow::DeviceMgr* GetDeviceMgr() const { return device_mgr_; } + IfrtServingCoreSelector* GetIfrtServingCoreSelector() const { + return ifrt_serving_core_selector_; + } + + tfrt::ConcurrentWorkQueue* checkpoint_loader_queue() const { + return checkpoint_loader_queue_; + } + void set_checkpoint_loader_queue(tfrt::ConcurrentWorkQueue* work_queue) { + checkpoint_loader_queue_ = work_queue; + } + + void set_default_signature_inputs( + const DefaultSignatureInputConfig& default_signature_inputs) { + default_signature_inputs_ = default_signature_inputs; + } + + const DefaultSignatureInputConfig& default_signature_inputs() const { + return default_signature_inputs_; + } + + tsl::protobuf::Message* GetCompilationEnvironmentProto() const { + return compilation_environment_proto_.get(); + } + + TfToHloCompiler* GetTfToHloCompiler() const { return tf_to_hlo_compiler_; } + + // Freeze the model: release the resources such as host tensors that are used + // by the device only. The caller guarantees all resources released in this + // function is no longer in use in regular execution path. + // After Freeze() is called, no new model signature will be compiled. Using a + // signature or an input shape that wasn't compiled before the freeze will + // leads to an error. 
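+  // Illustrative call order (a sketch, not from the upstream header): freeze
+  // only once every signature that will ever be served has been compiled.
+  //
+  //   TF_RETURN_IF_ERROR(model_context.Freeze());
+  //   DCHECK(model_context.IsFrozen());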
+ absl::Status Freeze(); + + bool IsFrozen() const { return frozen_; } + + private: + std::shared_ptr client_; + // Keep hardware specific topology info alive. This is currently used for + // shape determination. + std::shared_ptr topology_; + + IfrtServingCoreSelector* ifrt_serving_core_selector_; // May be nullptr + tsl::thread::ThreadPool& thread_pool_; + + tensorflow::DeviceMgr* device_mgr_ = nullptr; // Not owned. + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_ = + tensorflow::IdentityShapeRepresentationFn(); + std::unique_ptr compilation_environment_proto_ = + nullptr; + + // Dedicated work queue for heavy task such as variable tensor restoration. + tfrt::ConcurrentWorkQueue* checkpoint_loader_queue_ = nullptr; + + std::vector handles_; + + DefaultSignatureInputConfig default_signature_inputs_; + + IfrtLoadedVariableRegistry loaded_variable_registry_; + IfrtRestoreTensorRegistry restore_tensor_registry_; + TfToHloCompiler* tf_to_hlo_compiler_ = nullptr; + IfrtPersistentCompilationCache* persistent_compilation_cache_ = nullptr; + bool frozen_ = false; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h new file mode 100644 index 00000000..da9528ea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" + +namespace tensorflow { +namespace ifrt_serving { + +inline constexpr absl::string_view kIfrtModelRestoreContextName = + "IfrtModelRestoreContext"; + +// A resource context that holds the `CheckpointLoader` for a model. We need a +// different context than `IfrtModelContext` because `IfrtModelContext` is too +// large to be a dependency of other libraries. 
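+// Illustrative sketch (not part of the upstream header; `loader` stands for a
+// std::unique_ptr<CheckpointLoader> built elsewhere):
+//
+//   auto restore_context =
+//       std::make_unique<IfrtModelRestoreContext>(std::move(loader));
+//   CheckpointLoader* checkpoint_loader = restore_context->checkpoint_loader();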
+class IfrtModelRestoreContext { + public: + explicit IfrtModelRestoreContext( + std::unique_ptr checkpoint_loader) + : checkpoint_loader_(std::move(checkpoint_loader)) {} + + CheckpointLoader* checkpoint_loader() const { + return checkpoint_loader_.get(); + } + + private: + std::unique_ptr checkpoint_loader_; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h new file mode 100644 index 00000000..56c76329 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h @@ -0,0 +1,75 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_PERSISTENT_COMPILATION_CACHE_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_PERSISTENT_COMPILATION_CACHE_H_ + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/program.h" +#include "xla/tsl/concurrency/ref_count.h" +namespace tensorflow { +namespace ifrt_serving { + +class IfrtPersistentCompilationCache { + public: + IfrtPersistentCompilationCache() = default; + virtual ~IfrtPersistentCompilationCache() = default; + + // The implementation of this API should be thread-safe. It generates a key + // for looking up the executable in the persistent cache and it will return + // the LoadedExecutable if hits cache. Otherwise, it will call the `value_fn` + // to generate and return the LoadedExecutable. + virtual absl::StatusOr> + LookupLoadedExecutableOrCreate( + std::unique_ptr hlo_program, + tsl::RCReference device_list, + const xla::CompileOptions& xla_compile_options, + const std::vector>& + loaded_host_callbacks, + xla::ifrt::Client* client, + absl::AnyInvocable< + absl::StatusOr>( + std::unique_ptr program, + std::unique_ptr options)> + value_fn); + + // The implementation of this API should be thread-safe. It generates a key + // for looking up the Tf2HloResult in the persistent cache and it will return + // the Tf2HloResult if hits cache. Otherwise, it will call the `value_fn` to + // generate and return the Tf2HloResult. 
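+  // Illustrative call sketch (not part of the upstream header; `cache`,
+  // `arg` and `compiler` are placeholders):
+  //
+  //   TF_ASSIGN_OR_RETURN(Tf2HloResult result,
+  //                       cache->LookupTf2HloResultOrCreate(std::move(arg),
+  //                                                         compiler));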
+ virtual absl::StatusOr LookupTf2HloResultOrCreate( + Tf2HloArg tf2hlo_arg, TfToHloCompiler* tf_to_hlo_compiler); + + virtual bool IsXlaCompilationCacheEnabled() const { return false; } + virtual bool IsTf2HloCompilationCacheEnabled() const { return false; } +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_PERSISTENT_COMPILATION_CACHE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h new file mode 100644 index 00000000..73b7fec3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h @@ -0,0 +1,74 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_RESTORE_TENSOR_REGISTRY_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_RESTORE_TENSOR_REGISTRY_H_ + +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "xla/python/ifrt/future.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +// This class is thread safe. +class IfrtRestoreTensorRegistry { + public: + struct RestoredTensorInfo { + bool used_by_host = false; + DtypeAndShape dtype_and_shape; + xla::ifrt::Future tensor_future; + }; + // Tries to register a loaded variable with the given name. + // Returns an error if the named tensor already exists. + absl::Status TryRegister(absl::string_view name, + RestoredTensorInfo restored_tensor_info) + ABSL_LOCKS_EXCLUDED(mutex_); + + xla::ifrt::Future GetRestoredTensor( + absl::string_view name) const ABSL_LOCKS_EXCLUDED(mutex_); + + // Sets the tensor as used by the host. To ensure a tensor's host memory + // is released, this function must be called at least once before the Freeze. + absl::Status SetUsedByHost(absl::string_view name) + ABSL_LOCKS_EXCLUDED(mutex_); + + absl::StatusOr GetDtypeAndShape(absl::string_view name) const + ABSL_LOCKS_EXCLUDED(mutex_); + + // Part of freezing the model is to release the host tensors that are used by + // the device only. The caller guarantees those tensors are already loaded to + // the device. 
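+  // Illustrative lifecycle sketch (not part of the upstream header; `info`
+  // is a placeholder RestoredTensorInfo):
+  //
+  //   TF_RETURN_IF_ERROR(registry.TryRegister("dense/kernel", info));
+  //   // Mark tensors that the host still needs so Freeze() keeps them.
+  //   TF_RETURN_IF_ERROR(registry.SetUsedByHost("dense/kernel"));
+  //   registry.Freeze();  // Releases host copies that are device-only.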
+ void Freeze() ABSL_LOCKS_EXCLUDED(mutex_); + + private: + mutable absl::Mutex mutex_; + absl::flat_hash_map restored_tensors_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_RESTORE_TENSOR_REGISTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h new file mode 100644 index 00000000..a4505cba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_CORE_SELECTOR_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_CORE_SELECTOR_H_ + +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "xla/tsl/framework/serving_device_selector.h" +namespace tensorflow { +namespace ifrt_serving { + +// A wrapper of a `tsl::ServingDeviceSelector` that will be responsible for the +// core selection during Ifrt TPU execution. +class IfrtServingCoreSelector { + public: + explicit IfrtServingCoreSelector(tsl::ServingDeviceSelector* device_selector, + int num_cores); + // Reserves a device for the given `program_id`. The `program_id` is used to + // identify an IFRT executable and should be the key of + // `tensorflow::ifrt_serving::ServingExecutableRegistry `. + tsl::DeviceReservation ReserveDevice(int64_t program_id); + + private: + tsl::ServingDeviceSelector* device_selector_; + + absl::Mutex mu_; + // A counter of the number of runs for each program. For a given program, it + // is used to determine if the core selector should treat the incoming request + // as a warmup request to warm up a core. + absl::flat_hash_map run_counter_ ABSL_GUARDED_BY(mu_); + int num_cores_; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_CORE_SELECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h new file mode 100644 index 00000000..b9402d25 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -0,0 +1,254 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_H_ + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/sharding.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" +#include "tensorflow/core/tfrt/ifrt/tf_host_callback.h" +#include "tsl/platform/threadpool.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +class IfrtServingExecutable { + public: + static absl::StatusOr> Create( + int64_t program_id, absl::string_view model_name, + absl::string_view signature_name, + mlir::OwningOpRef module, + std::shared_ptr client, + tsl::thread::ThreadPool* thread_pool, + IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry, + const IfrtRestoreTensorRegistry* ifrt_restore, + tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, + tensorflow::DeviceMgr* device_mgr, + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, + IfrtServingCoreSelector* ifrt_serving_core_selector, + tsl::protobuf::Message* compilation_environment_proto, + TfToHloCompiler* tf_to_hlo_compiler, + IfrtPersistentCompilationCache* persistent_compilation_cache); + + // Movable but not copyable. + IfrtServingExecutable(IfrtServingExecutable&& other) = default; + IfrtServingExecutable& operator=(IfrtServingExecutable&& other) = default; + IfrtServingExecutable(const IfrtServingExecutable& other) = delete; + IfrtServingExecutable& operator=(const IfrtServingExecutable& other) = delete; + + absl::string_view model_name() const { return model_name_; } + absl::string_view signature_name() const { return signature_name_; } + + // Executes the computation. + // variable_arg_indices are in sorted order. + absl::StatusOr> Execute( + absl::Span inputs, + absl::Span variable_arg_indices); + + // Freezes the model. 
After the Freeze(), JIT compile is not supported and + // Execute() will return error if inputs contain uncompiled shapes. + void Freeze(); + + int num_executables() const { + absl::MutexLock lock(&mutex_); + return executable_bundles_.size(); + } + + private: + friend class IfrtBackendCompilerTest; + // In memory cache key. + struct Key { + std::vector input_shapes; + template + friend H AbslHashValue(H h, const Key& key) { + for (const auto& shape : key.input_shapes) { + for (auto size : shape.dim_sizes()) { + h = H::combine(std::move(h), size); + } + } + return h; + } + + friend bool operator==(const Key& x, const Key& y) { + return x.input_shapes == y.input_shapes; + } + }; + + struct CachedExecutableBundle { + std::unique_ptr ifrt_executable; + tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + std::vector> host_callbacks; + + CachedExecutableBundle() = default; + // Move only + CachedExecutableBundle(CachedExecutableBundle&& other) = default; + CachedExecutableBundle& operator=(CachedExecutableBundle&& other) = default; + CachedExecutableBundle(const CachedExecutableBundle& other) = delete; + CachedExecutableBundle& operator=(const CachedExecutableBundle& other) = + delete; + }; + + IfrtServingExecutable( + int64_t program_id, absl::string_view model_name, + absl::string_view signature_name, + mlir::OwningOpRef module, + std::shared_ptr client, + tsl::thread::ThreadPool* thread_pool, + IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry, + const IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, + tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, + tensorflow::DeviceMgr* device_mgr, + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, + IfrtServingCoreSelector* ifrt_serving_core_selector, + tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata, + tsl::RCReference assigned_device_list, + tsl::protobuf::Message* compilation_environment_proto, + TfToHloCompiler* tf_to_hlo_compiler, + IfrtPersistentCompilationCache* persistent_compilation_cache) + : program_id_(program_id), + model_name_(std::string(model_name)), + signature_name_(std::string(signature_name)), + module_(std::move(module)), + original_compile_metadata_(std::move(original_compile_metadata)), + assigned_device_list_(std::move(assigned_device_list)), + ifrt_client_(std::move(client)), + thread_pool_(*thread_pool), + ifrt_loaded_variable_registry_(*ifrt_loaded_variable_registry), + ifrt_restore_tensor_registry_(*ifrt_restore_tensor_registry), + checkpoint_loader_queue_(checkpoint_loader_queue), + device_mgr_(device_mgr), + shape_representation_fn_(std::move(shape_representation_fn)), + ifrt_serving_core_selector_(std::move(ifrt_serving_core_selector)), + compilation_environment_proto_(compilation_environment_proto), + tf_to_hlo_compiler_(tf_to_hlo_compiler), + persistent_compilation_cache_(persistent_compilation_cache) {} + + int64_t program_id_; + using SharedCachedExecutableBundle = std::shared_ptr; + + std::string model_name_; + std::string signature_name_; + + mlir::OwningOpRef module_ ABSL_GUARDED_BY(mutex_); + // The original compile metadata. We need to keep it around to be able to + // test portable execution condition even if the Module itself is already + // released. 
+ tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata_; + const tsl::RCReference assigned_device_list_; + + std::shared_ptr ifrt_client_; + tsl::thread::ThreadPool& thread_pool_; + + IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry_; + const IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry_; + tfrt::ConcurrentWorkQueue* checkpoint_loader_queue_; + tensorflow::DeviceMgr* device_mgr_; // Not owned. For host callback. + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_; + IfrtServingCoreSelector* ifrt_serving_core_selector_; + + tsl::protobuf::Message* + compilation_environment_proto_; // NOT OWNED. can be nullptr. + + mutable absl::Mutex mutex_; + absl::flat_hash_map> + executable_bundles_ ABSL_GUARDED_BY(mutex_); + + bool is_frozen_ ABSL_GUARDED_BY(mutex_) = false; + + // The tf_to_hlo_compiler_ is not owned by this executable. It is expected to + // be alive during the lifetime of the executable. + TfToHloCompiler* tf_to_hlo_compiler_; + + // The persistent compilation cache is a global cache and is not owned by + // this executable. When it is nullptr, the persistent compilation cache is + // disabled at ifrt serving level. + IfrtPersistentCompilationCache* persistent_compilation_cache_; + + // Asynchronously load the restored variable tensors to Ifrt array. + absl::Status AsyncLoadIfrtArray( + absl::Span inputs, + absl::Span variable_arg_indices, + const CachedExecutableBundle& executable_bundle, + const tsl::RCReference& devices); + + absl::StatusOr> ConvertTensorToArray( + const tensorflow::Tensor& tensor, + const tsl::RCReference& device_list, + const xla::OpSharding& sharding); + + xla::ifrt::Future LookUpOrCreateExecutable( + const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, + absl::Span dtypes_and_shapes, + absl::Span variable_arg_indices); + absl::StatusOr + CreateExecutableSynchronously( + mlir::OwningOpRef module_copy, + const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, + absl::Span dtypes_and_shapes, + absl::Span variable_arg_indices); + + absl::StatusOr> CreateSharding( + int num_devices, const xla::ifrt::Shape& arg_xla_shape, + const xla::ifrt::Shape& sharded_shapes); + + std::vector GetArgShape( + int arg_index, const CachedExecutableBundle& entry); + + bool UsePortableExecution( + const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata); +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h new file mode 100644 index 00000000..9f527765 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h @@ -0,0 +1,89 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_TEST_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_TEST_UTIL_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/tsl/framework/test_util/mock_serving_device_selector.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" +#include "tsl/platform/threadpool.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { +namespace test_utils { + +// A test helper class to create and IfrtServingExecutable. +class IfrtServingExecutableTestHelper { + public: + explicit IfrtServingExecutableTestHelper( + tsl::test_util::MockServingDeviceSelector* device_selector); + + // Creates an IfrtServingExecutable with the given program id. + // Note the instance of this class must outlive the returned + // IfrtServingExecutable. + std::unique_ptr MakeExecutable( + int64_t program_id, std::string mlir_module_path); + + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry() { + return &ifrt_restore_tensor_registry_; + } + + int num_cores() const { return client_->addressable_device_count(); } + + private: + static constexpr int kThreadPoolNumThreads = 16; + + tsl::test_util::MockServingDeviceSelector* device_selector_; // Not owned. + std::unique_ptr core_selector_; + std::shared_ptr client_; + std::unique_ptr thread_pool_; + IfrtLoadedVariableRegistry ifrt_loaded_variable_registry_; + IfrtRestoreTensorRegistry ifrt_restore_tensor_registry_; + std::unique_ptr work_queue_; + std::unique_ptr device_mgr_; + + mlir::DialectRegistry registry_; + std::unique_ptr context_; + std::unique_ptr + ifrt_persistent_compilation_cache_; + TfToHloCompiler tf_to_hlo_compiler_; +}; + +// Returns the path to the MLIR module for the given module name. +std::string GetMlirModulePath(absl::string_view module_name); + +} // namespace test_utils +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_TEST_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h new file mode 100644 index 00000000..6235e414 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h @@ -0,0 +1,45 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_TENSOR_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_TENSOR_UTILS_H_ + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +absl::StatusOr ToTensorDataType( + xla::ifrt::DType ifrt_dtype); + +absl::StatusOr ToIfrtDType(tensorflow::DataType tensor_dtype); + +xla::ifrt::Shape ToIfrtShape(const tensorflow::TensorShape& shape); + +tensorflow::TensorShape ToTensorShape(const xla::ifrt::Shape& shape); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_TENSOR_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/sharding_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/sharding_utils.h new file mode 100644 index 00000000..a777a822 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/sharding_utils.h @@ -0,0 +1,80 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_SHARDING_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_SHARDING_UTILS_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/future.h" +#include "xla/tsl/concurrency/ref_count.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tsl/platform/threadpool.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Create a tensor from the given host tensor based on given device ids and +// sharding information. +absl::StatusOr> MakeArrayFromTensor( + xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, + absl::Span device_ids, const xla::HloSharding& hlo_sharding, + const tsl::thread::ThreadPool& thread_pool); + +// A variant of the above api. 
The difference is that the user passes in +// device_list directly instead of a list of device_ids. +absl::StatusOr> MakeArrayFromTensor( + xla::ifrt::Client& ifrt_client, const tensorflow::Tensor& input_tensor, + const tsl::RCReference& device_list, + const xla::HloSharding& hlo_sharding, + const tsl::thread::ThreadPool& thread_pool); + +// Reshard an disassembled array list back to one single tensor +// based on given sharding spec. +// +// input_array: the input device buffers. +// +// hlo_sharding: sharding spec that describes how the input device buffers are +// sharded. +// +// device_list: list of devices that is aligned with the order of device buffers +// in the `input_array`. +// +xla::ifrt::Future MakeTensorFromArray( + xla::ifrt::Client& ifrt_client, xla::ifrt::Array& input_array, + const xla::HloSharding& hlo_sharding, + const tsl::RCReference& device_list, + tsl::thread::ThreadPool& thread_pool); + +// A wrapper around xla::ShapeUtil::ByteStrides to get the byte strides of a +// TensorFlow tensor. +std::optional> GetByteStrides( + tensorflow::DataType dtype, const tensorflow::TensorShape& shape); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_SHARDING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/tf_host_callback.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/tf_host_callback.h new file mode 100644 index 00000000..2b19f239 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/ifrt/tf_host_callback.h @@ -0,0 +1,85 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_TF_HOST_CALLBACK_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_TF_HOST_CALLBACK_H_ + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +// A host callback implementation to run a TF graph. +// TODO(b/332774825): Use TFRT executor for host callback. +class TfHostCallback { + public: + // Creates a TfHostCallback instance. `device_mgr` ptr is guaranteed to be + // alive throughout the lifetime of model. 
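+  // Illustrative sketch (not part of the upstream header; `functions`,
+  // "host_fn", the shape specs and the buffers are placeholders):
+  //
+  //   TF_ASSIGN_OR_RETURN(
+  //       auto callback,
+  //       TfHostCallback::Create(functions, "host_fn", operand_specs,
+  //                              result_specs, device_mgr.get()));
+  //   // `inputs`/`outputs` point at host buffers laid out according to the
+  //   // operand/result specs and must stay alive for the duration of Call().
+  //   TF_RETURN_IF_ERROR(callback->Call(inputs, outputs));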
+ static absl::StatusOr> Create( + absl::Span functions, + absl::string_view entry_function_name, + absl::Span operand_type_and_shapes, + absl::Span result_type_and_shapes, + tensorflow::DeviceMgr* device_mgr); + + // The host callback function takes two pointer arrays, each element of which + // points to allocated host buffer in host layout according to corresponding + // operand or result's shape. The buffers are only guaranteed to be alive + // during the call. + absl::Status Call(void** inputs, void** outputs); + + private: + TfHostCallback(absl::string_view entry_function_name, + absl::Span operand_type_and_shapes, + absl::Span result_type_and_shape, + tensorflow::EagerContextPtr ctx) + : ctx_(std::move(ctx)), + entry_function_name_(entry_function_name), + operand_type_and_shapes_(operand_type_and_shapes.begin(), + operand_type_and_shapes.end()), + result_type_and_shapes_(result_type_and_shape.begin(), + result_type_and_shape.end()) {} + + // Per-callback TF Eager context. + tensorflow::EagerContextPtr ctx_; + + // Entry function name to be called on invocation. + std::string entry_function_name_; + + std::vector operand_type_and_shapes_; + std::vector result_type_and_shapes_; +}; + +absl::StatusOr> +CreateTfDynamicDeviceMgr(); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_TF_HOST_CALLBACK_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/ifrt_program_ops.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/ifrt_program_ops.h new file mode 100644 index 00000000..463647ec --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/ifrt_program_ops.h @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_KERNELS_IFRT_PROGRAM_OPS_H_ +#define TENSORFLOW_CORE_TFRT_KERNELS_IFRT_PROGRAM_OPS_H_ + +#include + +#include +#include + +#include "absl/base/call_once.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" + +namespace tensorflow { +namespace tfrt_stub { + +// TensorFlow op that calls a Ifrt program registered in `ProgramRegistry`. +class IfrtCallOp : public tensorflow::OpKernel { + public: + explicit IfrtCallOp(tensorflow::OpKernelConstruction* ctx); + + IfrtCallOp(const IfrtCallOp& other) = delete; + IfrtCallOp& operator=(const IfrtCallOp& other) = delete; + + void Compute(tensorflow::OpKernelContext* ctx) override; + + private: + // Op attributes. + int64_t program_id_; + + std::vector variable_names_; + std::vector variable_arg_indices_; + + // Ifrt program to be called. Cached after the first call. + absl::once_flag init_once_; + tensorflow::ifrt_serving::IfrtServingExecutable* executable_; // Not owned. 
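+
+  // Sketch of how the cached executable can be resolved on first use (the
+  // actual Compute() implementation lives in the .cc file and may differ):
+  //
+  //   absl::call_once(init_once_, [&]() {
+  //     executable_ =
+  //         tensorflow::ifrt_serving::ServingExecutableRegistry::Lookup(
+  //             program_id_);
+  //   });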
+}; + +} // namespace tfrt_stub +} // namespace tensorflow +#endif // TENSORFLOW_CORE_TFRT_KERNELS_IFRT_PROGRAM_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops.h new file mode 100644 index 00000000..bef61d92 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops.h @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_H_ +#define TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/tfrt/runtime/stream.h" + +namespace tensorflow { +namespace tfrt_stub { + +// TensorFlow op that immediately sends results back to the serving controller. +class PwStreamResultsOp : public tensorflow::OpKernel { + public: + explicit PwStreamResultsOp(tensorflow::OpKernelConstruction* ctx); + + PwStreamResultsOp(const PwStreamResultsOp& other) = delete; + PwStreamResultsOp& operator=(const PwStreamResultsOp& other) = delete; + + void Compute(tensorflow::OpKernelContext* ctx) override; + + private: + // Op attributes. + std::string controller_address_; + std::string model_name_; + StreamCallbackId callback_id_; + std::vector names_; + + std::unique_ptr stream_; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops_util.h new file mode 100644 index 00000000..b6fa0223 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops_util.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_UTIL_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace tfrt_stub { + +// Unbatches `tensors` according to the step ids and returns a list of (step_id, +// unbatched_tensors) pairs. 
+// +// If `step_ids` is a scalar, each tensor in `tensors` is treated as if they are +// not batched and the entire tensor is associated with the single step id. +// +// If `step_ids` is a 1-D tensor, this tensor represents the step id of each +// example in the batch. Tensors in `tensors` are "unbatched" along the leading +// dimension according to the step id tensor and the unbatched tensors are +// associated with the corresponding step ids. +absl::StatusOr>>> +UnbatchStreamResults(const tensorflow::Tensor& step_ids, + absl::Span tensors); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops_util_constants.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops_util_constants.h new file mode 100644 index 00000000..ef8bad04 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/kernels/stream_ops_util_constants.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_UTIL_CONSTANTS_H_ +#define TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_UTIL_CONSTANTS_H_ + +#include + +namespace tensorflow { +namespace tfrt_stub { + +// Step id and batch id are packed together to a 64 bit integer in the stream +// callback. Step id takes the MSB 32 bit. +inline constexpr size_t kStepIdBitSize = 32; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_KERNELS_STREAM_OPS_UTIL_CONSTANTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mla/mla_test_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mla/mla_test_utils.h new file mode 100644 index 00000000..c445e537 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mla/mla_test_utils.h @@ -0,0 +1,47 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLA_MLA_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_MLA_MLA_UTILS_H_ + +// This file contains stub implementations for Google internal MLA APIs. 
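+// The stubs below return empty values or tensorflow::errors::Unimplemented in
+// OSS builds, so callers are expected to check the returned Status before
+// relying on any MLA functionality.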
+ +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace tfrt_stub { + +inline std::string CopySavedModelFromTestDataToTempDir( + absl::string_view tf_dir, absl::string_view saved_model_name) { + return ""; +} + +inline Status ConvertSavedModelAndAddToMla( + absl::string_view saved_model_path, const int saved_model_version, + const std::unordered_set& tags, + const std::vector& entry_points, + absl::string_view mla_module_name) { + return tensorflow::errors::Unimplemented("Not supported in OSS"); +} + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLA_MLA_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mla/mla_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mla/mla_utils.h new file mode 100644 index 00000000..79965f4c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mla/mla_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLA_MLA_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_MLA_MLA_UTILS_H_ + +// This file contains stub implementations for Google internal MLA APIs. + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace tfrt_stub { + +inline StatusOr GetSavedModelDirFromMlaDir( + absl::string_view mla_dir) { + return tensorflow::errors::Unimplemented("Not supported in OSS"); +} + +inline bool IsMlarchive(absl::string_view saved_model_dir) { return false; } + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLA_MLA_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/attribute/attribute.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/attribute/attribute.h new file mode 100644 index 00000000..f27118ea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/attribute/attribute.h @@ -0,0 +1,128 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_ATTRIBUTE_ATTRIBUTE_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_ATTRIBUTE_ATTRIBUTE_H_ + +#include + +#include "absl/status/statusor.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" + +namespace tensorflow { +namespace tf_mlrt { + +class ShapeAttr { + public: + struct StorageType { + using Self = StorageType; + DEFINE_BYTECODE_FIELD(uint8_t, unranked); + DEFINE_BYTECODE_FIELD(mlrt::bc::Vector, dims); + }; + + class Constructor { + public: + Constructor(mlrt::bc::Allocator* allocator, mlrt::bc::BcAddr_t address) + : allocator_(allocator), address_(address) {} + + void set_unranked(bool unranked) { + StorageType::construct_unranked(allocator_, address_, unranked); + } + + template + auto construct_shape(Args&&... args) { + return StorageType::construct_dims(allocator_, address_, + std::forward(args)...); + } + + mlrt::bc::BcAddr_t address() const { return address_; } + + private: + mlrt::bc::Allocator* allocator_; + mlrt::bc::BcAddr_t address_; + }; + using NonTrivialConstructorType = Constructor; + + explicit ShapeAttr(const char* p) : p_(p) {} + + bool unranked() const { return StorageType::read_unranked(p_); } + mlrt::bc::Vector dims() const { return StorageType::read_dims(p_); } + + private: + const char* p_ = nullptr; +}; + +class TensorAttr { + public: + struct StorageType { + using Self = StorageType; + DEFINE_BYTECODE_FIELD(tensorflow::DataType, dtype); + DEFINE_BYTECODE_FIELD(uint64_t, num_elements); + DEFINE_BYTECODE_FIELD(mlrt::bc::Vector, shape); + DEFINE_BYTECODE_FIELD(mlrt::bc::Vector, data); + }; + + class Constructor { + public: + Constructor(mlrt::bc::Allocator* allocator, mlrt::bc::BcAddr_t address, + tensorflow::DataType dtype) + : allocator_(allocator), address_(address) { + StorageType::construct_dtype(allocator_, address_, dtype); + } + + void set_num_elements(size_t num) { + StorageType::construct_num_elements(allocator_, address_, num); + } + + template + auto construct_shape(Args&&... args) { + return StorageType::construct_shape(allocator_, address_, + std::forward(args)...); + } + template + auto construct_data(Args&&... 
args) { + return StorageType::construct_data(allocator_, address_, + std::forward(args)...); + } + + mlrt::bc::BcAddr_t address() const { return address_; } + + private: + mlrt::bc::Allocator* allocator_; + mlrt::bc::BcAddr_t address_; + }; + using NonTrivialConstructorType = Constructor; + + explicit TensorAttr(const char* p) : p_(p) {} + + tensorflow::DataType dtype() const { return StorageType::read_dtype(p_); } + mlrt::bc::Vector shape() const { + return StorageType::read_shape(p_); + } + mlrt::bc::Vector data() const { return StorageType::read_data(p_); } + + private: + const char* p_ = nullptr; +}; + +absl::StatusOr EncodeTensorflowAttribute( + const mlrt::ModuleEmitterContext& module_context, mlir::Attribute attr); + +} // namespace tf_mlrt +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLRT_ATTRIBUTE_ATTRIBUTE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h new file mode 100644 index 00000000..f82666f1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/bytecode.h @@ -0,0 +1,526 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_BYTECODE_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_BYTECODE_H_ + +// This file defines bytecode primitives that can be used to build bytecode +// structures. This library is C++17 compliant and portable for different +// platforms. It should be also as effcient as plain C++ structs on common +// platforms. +// +// Usage: +// +// class CustomStruct { +// public: +// // The actual storage of this CustomStruct should be defined as a member +// // struct of this class. Defining storage struct is almost as simple as +// // defining a plain C++ struct; +// struct Storage { +// using Self = Storage; +// // DEFINE_BYTECODE_FIELD will generate helpers for reading and +// constructing +// // the field in bytecode. +// DEFINE_BYTECODE_FIELD(uint32_t, x); +// DEFINE_BYTECODE_FIELD(bc::Vector, y); +// }; +// +// // If the storage involves indirection like std::vector, a member class +// // Constructor should be also provided. +// class Constructor { +// public: +// // The Constructor will use `allocator` to allocate indirect storage, +// // though the direct storage is assumed to be already allocated using +// // the same allocator starting at `address`. +// explicit Constructor(Allocator* allocator, BcAddr_t address) +// : allocator_(allocator), address_(address) {} +// +// // Setting trivial fields only need to call construct_ +// // provided by DEFINE_BYTECODE_FIELD. +// void set_x(uint32_t x) { +// Storage::construct_x(allocator_, address_, x); +// } +// +// // Setting non-trivial fields only need to call construct_ +// // provided by DEFINE_BYTECODE_FIELD and also return the field's +// constructor. 
bc::Vector::Constructor construct_y(size_t +// y_size) { +// return Storage::construct_y(allocator_, address_, y_size); +// } +// +// BcAddr_t address() const { return address_; } +// +// private: +// bc::Allocator* allocator_; +// BcAddr_t address_; +// }; +// using NonTrivialConstructorType = Constructor; +// +// explicit CustomStruct(const char* p) : p_(p) {} +// +// // Reading fields needs only calling read_ methods provided by +// // DEFINE_BYTECODE_FIELD. +// uint32_t x() const { return Storage::read_x(p_); } +// bc::Vector y() const { return Storage::read_y(p_); } +// +// private: +// // The CustomStruct can contain only the pointer to the actual memory +// // blob. So fields need not be touched if not necessary, which would +// // otherwise incurs overhead. +// const char* p_; +// }; + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/string_view.h" + +namespace mlrt { +namespace bc { + +using BcAddr_t = uint64_t; + +class Buffer { + public: + char* Get(BcAddr_t address) { + DCHECK_LT(address, buffer_.size()); + return &buffer_.at(address); + } + + char* data() { return buffer_.data(); } + const char* data() const { return buffer_.data(); } + size_t size() const { return buffer_.size(); } + bool empty() const { return buffer_.empty(); } + + void shrink_to_fit() { buffer_.shrink_to_fit(); } + + private: + static_assert(alignof(std::max_align_t) >= 8, + "The bytecode buffer needs to be at least 8-byte aligned."); + std::vector buffer_; + + friend class Allocator; +}; + +class Allocator { + public: + explicit Allocator(Buffer* buffer) : buffer_(buffer) { + DCHECK(buffer != nullptr); + } + + BcAddr_t Allocate(size_t size, size_t alignment) { + DCHECK_LE(alignment, 8); + + // Calculate the next buffer size that is greater or equal to the previous + // buffer size, and is also aligned to `alignment`. + size_t next_align = + (buffer_->buffer_.size() + alignment - 1) / alignment * alignment; + + buffer_->buffer_.resize(next_align + size); + + return next_align; + } + + template + BcAddr_t Allocate() { + static_assert(std::is_trivial::value, "T must be trivial."); + return Allocate(sizeof(T), alignof(T)); + } + + size_t size() const { return buffer_->size(); } + + char* raw(BcAddr_t address) { return buffer_->Get(address); } + + private: + Buffer* buffer_; +}; + +// AccessTraits encapsulates the fundamental Read() and Construct() methods for +// reading and constructing bytecode data structures. + +// AccessTraits specialized for trivial types. +template +struct AccessTraits { + using StorageType = T; + static_assert(std::is_trivial::value, + "StorageType must be trivial."); + + using ConstructorType = void; + + static T Read(const char* p) { + // To be compliant with C++ standard on object lifetime and strict aliasing + // rules, we have to copy the data from memory to construct a new object. + // This is fine on most platforms as the copy can be optimized away, + // assuming `p` is sufficiently aligned. + T value; + std::memcpy(&value, p, sizeof(T)); + return value; + } + + template + static BcAddr_t Construct(Allocator* allocator, BcAddr_t address, + Args&&... args) { + // Similar to Read(), memcpy is used to serialize data to bytecode. + T value(std::forward(args)...); + std::memcpy(allocator->raw(address), &value, sizeof(T)); + return address; + } + + // Place the bytes directly for this trivial type T. It also supports placing + // bytes for a contiguous array of T. 
The number of bytes, `size` must not be + // greater than `num` * sizeof(T). + static void Place(Allocator* allocator, BcAddr_t address, const char* data, + size_t size, size_t num = 1) { + CHECK_LE(size, num * sizeof(T)); // Crash Ok + std::memcpy(allocator->raw(address), data, size); + } +}; + +// AccessTraits specialized for non-trivial types. +template +struct AccessTraits> { + // Non-trivial types should provide a member struct `StorageType` to + // specify the storage layout. + using StorageType = typename T::StorageType; + static_assert(std::is_trivial::value, + "StorageType must be trivial."); + + // Non-trivial types should provide a member type `NonTrivialConstructorType` + // for constructing storages. + using ConstructorType = typename T::NonTrivialConstructorType; + + static T Read(const char* p) { + // Reading non-trivial types is simply constructing the bytecode type with + // the pointer to the memory blob. All reading methods are encapsulated in + // `T`. + return T(p); + } + + template + static ConstructorType Construct(Allocator* allocator, BcAddr_t address, + Args&&... args) { + // Constructing non-trivial types is simply creating the corresponding + // constructor. + return ConstructorType(allocator, address, std::forward(args)...); + } +}; + +// The bytecode counterparts of malloc() and operator new() are also provided. +template +BcAddr_t Allocate(Allocator* allocator) { + return allocator->Allocate::StorageType>(); +} +template +auto New(Allocator* allocator, Args&&... args) { + auto address = Allocate(allocator); + return AccessTraits::Construct(allocator, address, + std::forward(args)...); +} + +// The iterator for reading bytecode data. It uses AccessTraits::Read() for +// reading the data. It is an input iterator as we cannot return the type-safe +// reference to the data in bytecode in a C++ compliant way due to object +// lifetime and strict aliasing rule. 
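+//
+// For illustration only, a minimal sketch (hypothetical variable names,
+// assuming the Buffer/Allocator/AccessTraits helpers above) of writing a
+// trivial value into a bytecode buffer and reading it back:
+//
+//   bc::Buffer buffer;
+//   bc::Allocator alloc(&buffer);
+//   // For a trivial T, New<T>() returns the address of the constructed value.
+//   bc::BcAddr_t addr = bc::New<uint32_t>(&alloc, 42);
+//   uint32_t v = bc::AccessTraits<uint32_t>::Read(buffer.Get(addr));  // v == 42
+//
+// ReadIterator<T>, defined below, re-reads elements the same way when stepping
+// over a packed array of T.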
+template +class ReadIterator { + using StorageType = typename AccessTraits::StorageType; + + public: + using difference_type = std::ptrdiff_t; + using value_type = std::remove_cv_t; + using pointer = void; + using reference = value_type; + using iterator_category = std::input_iterator_tag; + + explicit ReadIterator(const char* data) : data_(data) {} + + const char* data() const { return data_; } + + value_type operator*() const { return AccessTraits::Read(data_); } + + ReadIterator& operator++() { + data_ += sizeof(StorageType); + return *this; + } + + ReadIterator operator++(int) { + ReadIterator r = *this; + data_ += sizeof(StorageType); + return r; + } + + ReadIterator& operator+=(difference_type offset) { + data_ += offset * sizeof(StorageType); + return *this; + } + + ReadIterator operator+(difference_type offset) const { + ReadIterator r = *this; + r += offset; + return r; + } + + ReadIterator& operator--() { + data_ -= sizeof(StorageType); + return *this; + } + + ReadIterator operator--(int) { + ReadIterator r = *this; + data_ -= sizeof(StorageType); + return r; + } + + ReadIterator& operator-=(difference_type offset) { + data_ -= offset * sizeof(StorageType); + return *this; + } + + ReadIterator operator-(difference_type offset) const { + ReadIterator r = *this; + r -= offset; + return r; + } + + difference_type operator-(const ReadIterator& other) const { + DCHECK_EQ((data_ - other.data_) % sizeof(StorageType), 0); + return (data_ - other.data_) / sizeof(StorageType); + } + + friend bool operator==(const ReadIterator& a, const ReadIterator& b) { + return a.data_ == b.data_; + } + + friend bool operator!=(const ReadIterator& a, const ReadIterator& b) { + return !(a == b); + } + + friend bool operator<(const ReadIterator& a, const ReadIterator& b) { + return a.data_ < b.data_; + } + + friend bool operator<=(const ReadIterator& a, const ReadIterator& b) { + return a.data_ <= b.data_; + } + + friend bool operator>(const ReadIterator& a, const ReadIterator& b) { + return a.data_ > b.data_; + } + + friend bool operator>=(const ReadIterator& a, const ReadIterator& b) { + return a.data_ >= b.data_; + } + + private: + const char* data_ = nullptr; +}; + +// DEFINE_BYTECODE_FIELD provides helper functions for reading and constructing +// member fields in bytecode. +#define DEFINE_BYTECODE_FIELD(Type, name) \ + typename ::mlrt::bc::AccessTraits::StorageType name; \ + static const char* name##_pointer(const char* base) { \ + return base + offsetof(Self, name); \ + } \ + static ::mlrt::bc::BcAddr_t name##_address(::mlrt::bc::BcAddr_t base) { \ + return base + offsetof(Self, name); \ + } \ + static Type read_##name(const char* base) { \ + return ::mlrt::bc::AccessTraits::Read(name##_pointer(base)); \ + } \ + template \ + static auto construct_##name(::mlrt::bc::Allocator* allocator, \ + ::mlrt::bc::BcAddr_t base, Args&&... args) { \ + return ::mlrt::bc::AccessTraits::Construct( \ + allocator, name##_address(base), std::forward(args)...); \ + } \ + static_assert( \ + std::is_trivial< \ + typename ::mlrt::bc::AccessTraits::StorageType>::value, \ + "Bytecode storage types must be trivial.") + +// Defines a bytecode vector. 
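+//
+// For illustration only, a minimal sketch (hypothetical variable names) of
+// constructing a bc::Vector<uint32_t> and reading it back through a view:
+//
+//   bc::Buffer buffer;
+//   bc::Allocator alloc(&buffer);
+//   auto ctor = bc::New<bc::Vector<uint32_t>>(&alloc, /*size=*/3);
+//   ctor.Assign({1, 2, 3});
+//
+//   bc::Vector<uint32_t> view(buffer.Get(ctor.address()));
+//   uint32_t first = view[0];  // first == 1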
+template +class Vector { + public: + struct Storage { + using Self = Storage; + DEFINE_BYTECODE_FIELD(SizeType, size); + DEFINE_BYTECODE_FIELD(SizeType, offset); + }; + static_assert(std::is_trivial::value, "StorageType is trivial"); + static_assert(std::is_standard_layout::value, + "StorageType has standard layout"); + static_assert(sizeof(Storage) == 2 * sizeof(SizeType)); + static_assert(alignof(Storage) == alignof(SizeType)); + + using StorageType = Storage; + using ElementStorageType = typename AccessTraits::StorageType; + + using value_type = T; + using iterator = ReadIterator; + using const_iterator = iterator; + + class Constructor { + public: + Constructor(Allocator* allocator, BcAddr_t address, size_t size) + : allocator_(allocator), address_(address) { + DCHECK_GE(allocator->size(), address + sizeof(StorageType)); + size_t data_start = allocator->Allocate(size * sizeof(ElementStorageType), + alignof(ElementStorageType)); + + CHECK_LT(size, std::numeric_limits::max()); // Crash Ok + CHECK_LT(data_start - address, // Crash Ok + std::numeric_limits::max()); + storage_.size = size; + storage_.offset = data_start - address; + AccessTraits::Construct(allocator, address, storage_); + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + Constructor(Allocator* allocator, BcAddr_t address, + const std::vector& vec) + : Constructor(allocator, address, vec.size()) { + Assign(vec.begin(), vec.end()); + } + + template + auto ConstructAt(size_t index, Args&&... args) { + DCHECK_LT(index, size()); + return AccessTraits::Construct(allocator_, GetElementAddress(index), + std::forward(args)...); + } + + template + void Assign(std::initializer_list ilist) { + DCHECK_EQ(ilist.size(), size()); + Assign(ilist.begin(), ilist.end()); + } + + template + void Assign(const Range& range) { + DCHECK_EQ(std::distance(std::begin(range), std::end(range)), size()); + Assign(std::begin(range), std::end(range)); + } + + template + void Assign(Iter begin, Iter end) { + size_t i = 0; + for (; begin != end; ++begin) { + ConstructAt(i++, *begin); + } + DCHECK_EQ(i, size()); + } + + // If T is a trivial inplace type like int32_t, we can place the bytes for + // this vector directly instead of constructing the elements one by one. + template < + typename U = T, + typename std::enable_if< + std::is_same_v::ConstructorType, void>, + int>::type = 0> + void Place(const char* data, size_t size) { + AccessTraits::Place(allocator_, address_ + storage_.offset, data, size, + storage_.size); + } + + // TODO(chky): Implement iterators for construction. 
+ + size_t size() const { return storage_.size; } + BcAddr_t address() const { return address_; } + + private: + BcAddr_t GetElementAddress(size_t index) const { + return address_ + storage_.offset + index * sizeof(ElementStorageType); + } + + Allocator* allocator_; + BcAddr_t address_; + Vector::Storage storage_; + }; + using NonTrivialConstructorType = Constructor; + + explicit Vector(const char* p) : p_(p) { + static_assert(!std::is_trivial_v); + DCHECK(p_ != nullptr); + } + Vector() { + static_assert(!std::is_trivial_v); + static Storage kEmptyStorage{0, 0}; + p_ = reinterpret_cast(&kEmptyStorage); + } + + const char* data() const { return p_ + offset(); } + + size_t size() const { return StorageType::read_size(p_); } + bool empty() const { return size() == 0; } + + iterator begin() const { return iterator(data()); } + iterator end() const { + return iterator(data() + size() * sizeof(ElementStorageType)); + } + + T operator[](size_t index) const { + DCHECK_LT(index, size()); + auto iter = begin(); + iter += index; + return *iter; + } + + private: + SizeType offset() const { return StorageType::read_offset(p_); } + + const char* p_; +}; + +class String : public Vector { + public: + using Base = Vector; + using Base::Base; + + class Constructor : public Base::Constructor { + public: + using Base::Constructor::Assign; + + Constructor(Allocator* allocator, BcAddr_t address, absl::string_view str) + : Base::Constructor(allocator, address, str.size()) { + Assign(str.begin(), str.end()); + } + }; + using NonTrivialConstructorType = Constructor; + + using Base::data; + using Base::size; + + std::string str() const { return std::string(data(), size()); } + absl::string_view Get() const { return absl::string_view(data(), size()); } + + operator absl::string_view() const { // NOLINT + return absl::string_view(data(), size()); + } + + friend bool operator==(String x, absl::string_view y) { return x.Get() == y; } + friend bool operator==(absl::string_view x, String y) { return x == y.Get(); } +}; + +} // namespace bc +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_BYTECODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/executable.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/executable.h new file mode 100644 index 00000000..2f6f9c0e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/executable.h @@ -0,0 +1,90 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_EXECUTABLE_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_EXECUTABLE_H_ + +#include "tensorflow/core/tfrt/mlrt/bytecode/function.h" + +namespace mlrt { +namespace bc { + +// Defines the bytecode format for the executable, which contains the following +// section: +// 1) kernel_names: an ordered list of strings for kernel names that appear in +// this file. 
The `code` fields of kernels in `functions` will be indices to +// this list. +// +// 2) attributes: an ordered list of strings that are raw bytes. It is kernel +// implementations' resposiblity to decode the bytes properly. The `attributes` +// field of kernels in `functions` will be indices to this list. +// +// 3) functions: an order list of functions, which contains kernels and other +// metadata. Please refer to function.h for its detailed format. +class Executable { + public: + struct StorageType { + using Self = StorageType; + DEFINE_BYTECODE_FIELD(Vector, kernel_names); + DEFINE_BYTECODE_FIELD(Vector, functions); + DEFINE_BYTECODE_FIELD(Vector, attributes); + }; + + class Constructor { + public: + Constructor(Allocator* allocator, BcAddr_t address) + : allocator_(allocator), address_(address) {} + + template + auto construct_kernel_names(Args&&... args) { + return StorageType::construct_kernel_names(allocator_, address_, + std::forward(args)...); + } + + template + auto construct_attributes(Args&&... args) { + return StorageType::construct_attributes(allocator_, address_, + std::forward(args)...); + } + + template + auto construct_functions(Args&&... args) { + return StorageType::construct_functions(allocator_, address_, + std::forward(args)...); + } + + BcAddr_t address() const { return address_; } + + private: + Allocator* allocator_; + BcAddr_t address_; + }; + using NonTrivialConstructorType = Constructor; + + explicit Executable(const char* p) : p_(p) {} + + Vector kernel_names() const { + return StorageType::read_kernel_names(p_); + } + Vector functions() const { return StorageType::read_functions(p_); } + Vector attributes() const { return StorageType::read_attributes(p_); } + + private: + const char* p_; +}; + +} // namespace bc +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_EXECUTABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/function.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/function.h new file mode 100644 index 00000000..c85fc40d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/function.h @@ -0,0 +1,110 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_FUNCTION_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_FUNCTION_H_ + +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" + +namespace mlrt { +namespace bc { + +class Function { + public: + struct StorageType { + using Self = StorageType; + DEFINE_BYTECODE_FIELD(String, name); + DEFINE_BYTECODE_FIELD(uint32_t, num_regs); + DEFINE_BYTECODE_FIELD(Vector, input_regs); + DEFINE_BYTECODE_FIELD(Vector, output_regs); + DEFINE_BYTECODE_FIELD(Vector, output_last_uses); + DEFINE_BYTECODE_FIELD(Vector, kernels); + }; + + class Constructor { + public: + Constructor(Allocator* allocator, BcAddr_t address) + : allocator_(allocator), address_(address) {} + + template + auto construct_name(Args&&... args) { + return StorageType::construct_name(allocator_, address_, + std::forward(args)...); + } + + void set_num_regs(uint32_t num_regs) { + StorageType::construct_num_regs(allocator_, address_, num_regs); + } + + template + auto construct_input_regs(Args&&... args) { + return StorageType::construct_input_regs(allocator_, address_, + std::forward(args)...); + } + + template + auto construct_output_regs(Args&&... args) { + return StorageType::construct_output_regs(allocator_, address_, + std::forward(args)...); + } + + template + auto construct_output_last_uses(Args&&... args) { + return StorageType::construct_output_last_uses( + allocator_, address_, std::forward(args)...); + } + + template + auto construct_kernels(Args&&... args) { + return StorageType::construct_kernels(allocator_, address_, + std::forward(args)...); + } + + BcAddr_t address() const { return address_; } + + private: + Allocator* allocator_; + BcAddr_t address_; + }; + using NonTrivialConstructorType = Constructor; + + Function() = default; + // NOLINTNEXTLINE(google-explicit-constructor) + Function(std::nullptr_t) : p_(nullptr) {} + explicit Function(const char* p) : p_(p) {} + + String name() const { return StorageType::read_name(p_); } + uint32_t num_regs() const { return StorageType::read_num_regs(p_); } + Vector input_regs() const { + return StorageType::read_input_regs(p_); + } + Vector output_regs() const { + return StorageType::read_output_regs(p_); + } + Vector output_last_uses() const { + return StorageType::read_output_last_uses(p_); + } + Vector kernels() const { return StorageType::read_kernels(p_); } + + explicit operator bool() const { return p_ != nullptr; } + + private: + const char* p_ = nullptr; +}; + +} // namespace bc +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_FUNCTION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/kernel.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/kernel.h new file mode 100644 index 00000000..b4e6f53b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/kernel.h @@ -0,0 +1,90 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_KERNEL_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_KERNEL_H_ + +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" + +namespace mlrt { +namespace bc { + +class Kernel { + public: + struct StorageType { + using Self = StorageType; + DEFINE_BYTECODE_FIELD(uint32_t, code); + DEFINE_BYTECODE_FIELD(bc::Vector, arguments); + DEFINE_BYTECODE_FIELD(bc::Vector, results); + DEFINE_BYTECODE_FIELD(bc::Vector, attributes); + DEFINE_BYTECODE_FIELD(bc::Vector, last_uses); + }; + + class Constructor { + public: + Constructor(Allocator* allocator, BcAddr_t address) + : allocator_(allocator), address_(address) {} + + void set_code(uint32_t code) { + StorageType::construct_code(allocator_, address_, code); + } + + template + auto construct_arguments(Args&&... args) { + return StorageType::construct_arguments(allocator_, address_, + std::forward(args)...); + } + template + auto construct_results(Args&&... args) { + return StorageType::construct_results(allocator_, address_, + std::forward(args)...); + } + template + auto construct_attributes(Args&&... args) { + return StorageType::construct_attributes(allocator_, address_, + std::forward(args)...); + } + template + auto construct_last_uses(Args&&... args) { + return StorageType::construct_last_uses(allocator_, address_, + std::forward(args)...); + } + + BcAddr_t address() const { return address_; } + + private: + Allocator* allocator_; + BcAddr_t address_; + }; + using NonTrivialConstructorType = Constructor; + + explicit Kernel(const char* p) : p_(p) {} + Kernel() : p_(nullptr) {} + + uint32_t code() const { return StorageType::read_code(p_); } + Vector arguments() const { return StorageType::read_arguments(p_); } + Vector results() const { return StorageType::read_results(p_); } + Vector attributes() const { + return StorageType::read_attributes(p_); + } + Vector last_uses() const { return StorageType::read_last_uses(p_); } + + private: + const char* p_; +}; + +} // namespace bc +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/span.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/span.h new file mode 100644 index 00000000..bf8ce7eb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/bytecode/span.h @@ -0,0 +1,86 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_SPAN_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_SPAN_H_ + +#include +#include + +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" + +namespace mlrt { +namespace bc { + +// Span is a range view of contiguous byte region like bc::Vector. 
It reads the +// array size and start pointer eagerly, so that the range can be adapted. +template +class Span { + public: + using value_type = T; + using iterator = ReadIterator; + using const_iterator = iterator; + + Span() = default; + Span(const char* data, size_t size) : data_(data), size_(size) {} + + template + Span(const Vector& vec) // NOLINT(google-explicit-constructor) + : Span(vec.data(), vec.size()) {} + Span(const String& vec) // NOLINT(google-explicit-constructor) + : Span(vec.data(), vec.size()) {} + Span(const std::vector& vec) // NOLINT(google-explicit-constructor) + : Span(reinterpret_cast(vec.data()), vec.size()) {} + + const char* data() const { return data_; } + const char* data(size_t index) const { return data_ + index * sizeof(T); } + + iterator begin() const { return iterator(data_); } + iterator end() const { return iterator(data_ + size_ * sizeof(T)); } + T back() const { + DCHECK_GT(size_, 0); + return *iterator(data_ + (size_ - 1) * sizeof(T)); + } + + T operator[](size_t index) const { + DCHECK_LT(index, size()); + auto iter = begin(); + iter += index; + return *iter; + } + + size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + + Span drop_front(size_t num = 1) const { + auto beg = begin(); + beg += num; + DCHECK_GE(size(), num); + return Span(beg.data(), size() - num); + } + + Span drop_back(size_t num = 1) const { + DCHECK_GE(size(), num); + return Span(data(), size() - num); + } + + private: + const char* data_ = nullptr; + size_t size_ = 0; +}; + +} // namespace bc +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_BYTECODE_SPAN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h new file mode 100644 index 00000000..3f033492 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h @@ -0,0 +1,177 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ASYNC_HANDLE_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ASYNC_HANDLE_H_ + +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/future.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/value.h" +#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime +#include "tfrt/concurrency/chain.h" // from @tf_runtime + +namespace mlrt { + +// mlrt::AsyncHandle is a specialized future for mananging context of an async +// execution. +// +// Example usage: +// +// // Create the context the async execution by copying the current context. +// auto [promise, handle] = AsyncHandle::Allocate(current_context); +// +// // Set up completion signal through the `promise` created. 
+// handle.execution_context().set_exit_handler( +// [promise = std::move(promise)]() { promise.Finish(); }); +// +// // Launch execution. +// thread_pool.Schedule([&execution_context = handle.execution_context()](){ +// execution_context.Call(...); +// Execute(execution_context); +// }); +// +// // Pass `handle` to places that need to wait for the execution. +// other_execution_context.Await(std::move(handle)); +// +class AsyncHandle { + public: + class Promise { + public: + Promise(const Promise&) = delete; + Promise& operator=(const Promise&) = delete; + Promise(Promise&&) = default; + Promise& operator=(Promise&&) = default; + + ~Promise() { + DCHECK(!shared_state_ || shared_state_.IsAvailable()) + << "A non-empty promise must be fulfilled."; + } + + void Finish(absl::Status status) && { + if (status.ok()) { + shared_state_.SetStateConcrete(); + } else { + shared_state_.SetError(std::move(status)); + } + } + + // We don't need HandleError() method for AsyncHandle::Promise because it is + // managed by the framework internally and should never be placed in the + // register. + + private: + explicit Promise(tsl::AsyncValueRef shared_state) + : shared_state_(std::move(shared_state)) {} + tsl::AsyncValueRef shared_state_; + + friend class AsyncHandle; + }; + + // Allocate an AsyncHandle and the corresponding promise. + static std::pair Allocate( + const ExecutionContext& current); + + AsyncHandle(const AsyncHandle&) = delete; + AsyncHandle& operator=(const AsyncHandle&) = delete; + AsyncHandle(AsyncHandle&&) = default; + AsyncHandle& operator=(AsyncHandle&&) = default; + + ~AsyncHandle() { + CHECK(!shared_state_ || shared_state_.IsAvailable()) // Crash OK + << "A non-empty AsyncHandle must be awaited."; + } + + // Then() enqueues a callback which will be called when the future is + // fulfilled with either an error or a value. + // + // The following Then() overloads accept a callback with the following + // signatures: + // + // 1) void(absl::Status) + // The argument is the status of this future in ready state. + // + // 2) void() + // There is no argument. The callback will be called whenever it is ready. + + template >> + typename std::enable_if, void>::type Then( + F then) && { + CHECK(shared_state_); // Crash OK + auto* shared_state_ptr = shared_state_.GetAsyncValue(); + shared_state_ptr->AndThen([shared_state = std::move(shared_state_), + execution_context = + std::move(execution_context_), + then = std::move(then)]() mutable { + future_internal::InvokeThen(std::move(then), shared_state.GetAsyncValue(), + future_internal::ArgTag()); + }); + } + + template >> + typename std::enable_if, void>::type Then(F then) && { + CHECK(shared_state_); // Crash OK + auto* shared_state_ptr = shared_state_.GetAsyncValue(); + shared_state_ptr->AndThen( + [shared_state = std::move(shared_state_), + execution_context = std::move(execution_context_), + then = std::move(then)]() mutable { std::move(then)(); }); + } + + void HandleError(Value* arg) { + if (!shared_state_ || shared_state_.IsAvailable()) { + // This is an empty handle or it is already finished. 
+ return; + } + + auto& execution_context = *arg->Get(); + execution_context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: unwind AsyncHandle of context ", + absl::Hex(reinterpret_cast(execution_context_.get())), + " from context ", + absl::Hex(reinterpret_cast(&execution_context)), + " of state ", execution_context.state_))); + execution_context.Await(std::move(*this)); + } + + bool IsReady() const { return shared_state_.IsAvailable(); } + bool IsError() const { return shared_state_.IsError(); } + + const absl::Status& GetError() const { return shared_state_.GetError(); } + + ExecutionContext& execution_context() { return *execution_context_; } + + private: + AsyncHandle(std::unique_ptr execution_context, + tsl::AsyncValueRef shared_state) + : execution_context_(std::move(execution_context)), + shared_state_(std::move(shared_state)) { + DCHECK(execution_context_); + DCHECK(shared_state_); + } + + std::unique_ptr execution_context_; + tsl::AsyncValueRef shared_state_; +}; + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ASYNC_HANDLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h new file mode 100644 index 00000000..485aeceb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h @@ -0,0 +1,87 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ATTRIBUTE_SPAN_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ATTRIBUTE_SPAN_H_ + +#include +#include + +#include "absl/log/check.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/span.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/iterator.h" + +namespace mlrt { +namespace attribute_internal { + +// LINT.IfChange(mlrt_attributes) +template +inline constexpr bool kCanAttributeBeInlined = + (std::is_integral_v || + std::is_floating_point_v)&&(sizeof(T) <= sizeof(uint32_t)); +// LINT.ThenChange(../../../../compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc:mlrt_attributes) + +} // namespace attribute_internal + +class AttributeSpan { + class Iterator + : public iterator_internal::IteratorBase> { + public: + using IteratorBase>::IteratorBase; + }; + + public: + using value_type = bc::String; + using iterator = Iterator; + using const_iterator = iterator; + + AttributeSpan(bc::Span attr_indices, + bc::Span attributes) + : attr_indices_(attr_indices), attributes_(attributes) {} + + bc::String operator[](size_t id) const { + return attributes_[attr_indices_[id]]; + } + + template + T GetAs(size_t id) const { + if constexpr (std::is_same_v) { + return attributes_[attr_indices_[id]]; + } + + if constexpr (attribute_internal::kCanAttributeBeInlined) { + return bc::AccessTraits::Read(attr_indices_.data(id)); + } + + return bc::AccessTraits::Read(attributes_[attr_indices_[id]].data()); + } + + size_t size() const { return attr_indices_.size(); } + + iterator begin() const { + return iterator(attr_indices_.begin(), attributes_); + } + iterator end() const { return iterator(attr_indices_.end(), attributes_); } + + private: + bc::Span attr_indices_; + bc::Span attributes_; +}; + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ATTRIBUTE_SPAN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.h new file mode 100644 index 00000000..11273211 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/builtin_kernels.h @@ -0,0 +1,73 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_BUILTIN_KERNELS_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_BUILTIN_KERNELS_H_ + +#include + +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/future.h" +#include "tsl/profiler/lib/traceme.h" + +namespace mlrt { + +void CallOp(KernelFrame& frame); +void ReturnOp(KernelFrame& frame); + +void AsyncOp(KernelFrame& frame); +void AwaitHandleOp(KernelFrame& frame); + +// The base class for the PromiseReturnOp. 
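+//
+// For illustration only, a hypothetical kernel derived from it (the names
+// `MyPromiseReturnOp` and `"test.promise_return"` are made up, and the
+// argument layout is only an assumption for this sketch):
+//
+//   class MyPromiseReturnOp
+//       : public PromiseReturnOpBase<MyPromiseReturnOp> {
+//    public:
+//     using PromiseReturnOpBase::PromiseReturnOpBase;
+//     static constexpr char kName[] = "test.promise_return";
+//
+//     mlrt::Promise& promise() const {
+//       return arguments()[0].Get<mlrt::Promise>();
+//     }
+//     int32_t value() const { return arguments()[1].Get<int32_t>(); }
+//     bool value_last_use() const { return last_uses()[1]; }
+//   };
+//
+//   // Registration then picks up kName and Invoke():
+//   registry.Register<MyPromiseReturnOp>();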
+template +class PromiseReturnOpBase : public KernelFrame { + public: + using KernelFrame::KernelFrame; + + Promise& promise() const { + return static_cast(this)->promise(); + } + + decltype(auto) value() const { + return static_cast(this)->value(); + } + + bool value_last_use() const { + return static_cast(this)->value_last_use(); + } + + void Invoke() { + tsl::profiler::TraceMe trace_me(Derived::kName); + + // Set the execution context to kReturn state so that the callbacks in the + // futures, which may invoke Resume(), knows we are exiting. + execution_context().Return({}); + auto& p = promise(); + + using ValueType = std::decay_t; + + decltype(auto) value = this->value(); + if (value_last_use()) { + std::move(p).template Set(std::move(value)); + } else { + std::move(p).template Set(value); + } + } +}; + +void RegisterBuiltinKernels(KernelRegistry& registry); + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_BUILTIN_KERNELS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/context.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/context.h new file mode 100644 index 00000000..35329ced --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/context.h @@ -0,0 +1,595 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_CONTEXT_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/function.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/span.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/register_span.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/value.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace mlrt { + +class KernelFrame; +class ExecutionContext; + +class Future; +template +Future AwaitAll(FutureLikeContainer futures, ResultRefContainer results); +template +Future AwaitAll(FutureLikeContainer futures); + +using KernelImplementation = void (*)(KernelFrame); + +class KernelRegistry { + public: + void Register(absl::string_view name, KernelImplementation kernel); + + KernelImplementation Get(absl::string_view name) const; + + template + void Register(absl::string_view name); + + template + void Register() { + Register(KernelClass::kName); + } + + void Merge(const KernelRegistry& other); + + private: + absl::flat_hash_map map_; +}; + +class LoadedExecutable { + public: + LoadedExecutable(bc::Executable executable, + const KernelRegistry& kernel_registry); + + absl::Span kernels() const { return kernels_; } + + bc::Function GetFunction(absl::string_view name) const { + if (auto iter = functions_.find(name); iter != functions_.end()) { + return iter->second; + } + + return nullptr; + } + + bc::Executable executable() const { return executable_; } + + private: + bc::Executable executable_; + + absl::flat_hash_map functions_; + std::vector kernels_; +}; + +// A helper structure that holds states for a kernel. Typical usuage is that a +// control kernel wants to call a function and then come back to the same +// kernel, e.g. WhileOp. +struct KernelContext { + // Any non-zero value indicates the kernel just reentered. + int reenter = 0; + // Registers for callee. + std::vector registers; +}; + +namespace execute_internal { + +void UnwindOnError(ExecutionContext& context, int64_t pc); + +} + +class FunctionContext { + public: + FunctionContext(bc::Function function, ExecutionContext* execution_context) + : pc_(0), + registers_(function.num_regs()), + function_object_(function), + execution_context_(execution_context) { + DCHECK(execution_context); + } + + FunctionContext(const FunctionContext&) = delete; + FunctionContext& operator=(const FunctionContext&) = delete; + FunctionContext(FunctionContext&&) = default; + FunctionContext& operator=(FunctionContext&&) = default; + + ExecutionContext& execution_context() { return *execution_context_; } + + const bc::Function& function_object() const { return function_object_; } + + absl::Span regs() { return absl::MakeSpan(registers_); } + + // Argument passing is via either copy or move. 
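+  // For illustration (hypothetical values): with last_uses = {true, false},
+  // args[0] is moved into its input register while args[1] is copied, so the
+  // caller may still read args[1] after the call, e.g.
+  //
+  //   function_context.Call(last_uses, absl::MakeSpan(args),
+  //                         absl::MakeSpan(results));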
+ template + void Call(bc::Span last_uses, Args args, Results results) { + auto idx_iter = function_object_.input_regs().begin(); + + DCHECK_EQ(function_object_.input_regs().size(), args.size()); + + DCHECK_EQ(args.size(), last_uses.size()); + auto last_use_iter = last_uses.begin(); + for (auto& arg : args) { + if (*last_use_iter) { + registers_[*idx_iter] = std::move(arg); + } else { + registers_[*idx_iter] = arg; + } + ++idx_iter; + ++last_use_iter; + } + + results_.reserve(results.size()); + for (auto& result : results) { + results_.push_back(&result); + } + } + + // Argument passing is via move. + template + void CallByMove(Args args, Results results) { + auto idx_iter = function_object_.input_regs().begin(); + + DCHECK_EQ(function_object_.input_regs().size(), args.size()); + + for (auto& arg : args) { + registers_[*idx_iter] = std::move(arg); + ++idx_iter; + } + + results_.reserve(results.size()); + for (auto& result : results) { + results_.push_back(&result); + } + } + + // The return operation copies or moves (if not a ref) the results. + void Return(RegisterSpan results) { + DCHECK_EQ(results.size(), function_object_.output_regs().size()); + auto result_iter = results.begin(); + auto output_last_uses = function_object_.output_last_uses(); + + for (int i = 0; i < results_.size(); ++i) { + auto* result = results_[i]; + + if (!output_last_uses.empty() && output_last_uses[i]) { + // We only move the result only if it is the last use. + *result = std::move(*result_iter); + } else { + *result = *result_iter; + } + ++result_iter; + } + } + + const KernelContext& kernel_context() const { return kernel_context_; } + KernelContext& kernel_context() { return kernel_context_; } + + private: + int64_t pc_; + std::vector registers_; + std::vector results_; + bc::Function function_object_; + KernelContext kernel_context_; + + ExecutionContext* execution_context_ = nullptr; + + friend class ExecutionContext; + friend void Execute(ExecutionContext& context); + friend void execute_internal::UnwindOnError(ExecutionContext& context, + int64_t pc); +}; + +namespace context_internal { + +inline std::atomic& GetNextId() { + static std::atomic next_id = 0; + return next_id; +} + +class UserContextBase { + public: + virtual ~UserContextBase(); + + virtual std::unique_ptr Copy() const = 0; +}; + +} // namespace context_internal + +// Every user context should inherit from this class. Internally it generates a +// unique id for each user context type for internal management. 
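+//
+// For illustration only, a minimal sketch (the type name is hypothetical) of a
+// user context and how it is attached to and fetched from an ExecutionContext:
+//
+//   class MyUserContext : public mlrt::UserContext<MyUserContext> {
+//    public:
+//     explicit MyUserContext(int value) : value_(value) {}
+//     int value() const { return value_; }
+//
+//    private:
+//     int value_ = 0;
+//   };
+//
+//   execution_context.AddUserContext(std::make_unique<MyUserContext>(42));
+//   int v = execution_context.GetUserContext<MyUserContext>().value();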
+template +class UserContext : public context_internal::UserContextBase { + public: + using Base = context_internal::UserContextBase; + + static int id() { return id_; } + + std::unique_ptr Copy() const final { + return std::make_unique(*static_cast(this)); + } + + private: + inline static int id_ = context_internal::GetNextId()++; +}; + +class ExecutionContext { + public: + explicit ExecutionContext(const LoadedExecutable* loaded_executable) + : user_contexts_(context_internal::GetNextId().load()), + loaded_executable_(loaded_executable) {} + + ExecutionContext( + const LoadedExecutable* loaded_executable, + std::vector> + user_contexts, + const std::vector>& user_error_loggers) + : user_contexts_(std::move(user_contexts)), + user_error_loggers_(user_error_loggers), + loaded_executable_(loaded_executable) {} + + void set_exit_handler(absl::AnyInvocable exit_handler) { + exit_handler_ = std::move(exit_handler); + } + + tfrt::ConcurrentWorkQueue* work_queue() const { return work_queue_; } + + void set_work_queue(tfrt::ConcurrentWorkQueue* work_queue) { + work_queue_ = work_queue; + } + + template + void Call(bc::Function function_object, bc::Span last_uses, + Args args, Results results) { + auto& function_context = + function_stack_.emplace_back(function_object, this); + function_context.Call(last_uses, args, results); + state_ = State::kReady; + } + + template + void CallByMove(bc::Function function_object, Args args, Results results) { + auto& function_context = + function_stack_.emplace_back(function_object, this); + function_context.CallByMove(args, results); + state_ = State::kReady; + } + + void Return(RegisterSpan results) { + auto& function_context = function_stack_.back(); + function_context.Return(results); + state_ = State::kReturn; + } + + size_t function_stack_size() const { return function_stack_.size(); } + FunctionContext& function_context() { return function_stack_.back(); } + + // Enqueues the current execution to the wait list of the `future`. Once the + // `future` is ready, the execution will be resumed. And the value will be + // populated in `result` if it is not an error. 
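+  // For illustration (hypothetical future and value type): awaiting a Future
+  // that will eventually hold an int32_t might look like:
+  //
+  //   Value result;
+  //   execution_context.Await<int32_t>(std::move(future), &result);
+  //   // On success `result` is populated; on error the context transitions
+  //   // to an error state.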
+ template + void Await(FutureLike future, Value* result) { + if (future.IsReady()) { + if (future.IsError()) { + Fail(future.GetError()); + } else { + std::move(future).Then( + [result](T value) { result->Set(std::move(value)); }); + } + return; + } + + state_ = State::kSuspended; + suspend_handler_ = [this, result, future = std::move(future)]( + absl::AnyInvocable resume) mutable { + std::move(future).Then([this, result, resume = std::move(resume)]( + absl::StatusOr value) mutable { + if (!value.ok()) { + Fail(std::move(value).status()); + } else { + result->Set(*std::move(value)); + state_ = State::kRunning; + } + + std::move(resume)(); + }); + }; + } + + template + void Await(FutureLike future) { + if (future.IsReady()) { + if (future.IsError()) { + Fail(future.GetError()); + } + return; + } + + state_ = State::kSuspended; + suspend_handler_ = [this, future = std::move(future)]( + absl::AnyInvocable resume) mutable { + std::move(future).Then( + [this, resume = std::move(resume)](absl::Status status) mutable { + if (!status.ok()) { + Fail(std::move(status)); + } else { + state_ = State::kRunning; + } + + std::move(resume)(); + }); + }; + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE void AwaitAll(FutureLikeContainer futures, + ResultRefContainer results) { + auto future = mlrt::AwaitAll(futures, results); + + if (future.IsReady()) { + if (future.IsError()) { + Fail(future.GetError()); + } + return; + } + + state_ = State::kSuspended; + suspend_handler_ = [this, future = std::move(future)]( + absl::AnyInvocable resume) mutable { + std::move(future).Then( + [this, resume = std::move(resume)](absl::Status status) mutable { + state_ = State::kRunning; + + if (!status.ok()) { + Fail(status); + } + + std::move(resume)(); + }); + }; + } + + template + ABSL_ATTRIBUTE_ALWAYS_INLINE void AwaitAll(FutureLikeContainer futures) { + auto future = mlrt::AwaitAll(futures); + + if (future.IsReady()) { + if (future.IsError()) { + Fail(future.GetError()); + } + return; + } + + state_ = State::kSuspended; + suspend_handler_ = [this, future = std::move(future)]( + absl::AnyInvocable resume) mutable { + std::move(future).Then( + [this, resume = std::move(resume)](absl::Status status) mutable { + state_ = State::kRunning; + + if (!status.ok()) { + Fail(status); + } + + std::move(resume)(); + }); + }; + } + + const LoadedExecutable& loaded_executable() const { + return *loaded_executable_; + } + + void Fail(absl::Status status) { + state_ = State::kError; + status_ = std::move(status); + } + + void FailOnCancellation() { Fail(absl::CancelledError()); } + + const absl::Status& status() const { return status_; } + + // Add an instance of user context to the execution context. + template + void AddUserContext(std::unique_ptr user_context) { + static_assert(std::is_base_of_v, T>); + DCHECK_LT(T::id(), user_contexts_.size()); + user_contexts_[T::id()] = std::move(user_context); + } + + // Return an reference to the user context. 
+ template + T& GetUserContext() const { + static_assert(std::is_base_of_v, T>); + DCHECK_LT(T::id(), user_contexts_.size()); + return *static_cast(user_contexts_[T::id()].get()); + } + + std::vector> + CopyUserContexts() const { + std::vector> + user_contexts; + user_contexts.reserve(user_contexts_.size()); + for (const auto& user_context : user_contexts_) { + if (user_context) { + user_contexts.push_back(user_context->Copy()); + } else { + user_contexts.push_back(nullptr); + } + } + return user_contexts; + } + + void AddUserErrorLogger(std::function error_logger) { + user_error_loggers_.push_back(error_logger); + } + + const std::vector>& user_error_loggers() + const { + return user_error_loggers_; + } + + void LogError(absl::Status status) { + for (auto& error_logger : user_error_loggers_) { + error_logger(status); + } + } + + enum class State { + // The function is pushed to the stack, and ready for execution. + kReady = 0, + + // The function is being executed and has not reached the return op yet. + kRunning, + + // The function finished executing the return op, and ready for being popped + // from the stack. + kReturn, + + // The function is suspended from execution due to context switches. + kSuspended, + + // The execution reports an error in the current thread, and the execution + // will be aborted by cleaning the states. + kError + }; + State state() const { return state_; } + + private: + absl::InlinedVector function_stack_; + + State state_ = State::kReady; + + absl::Status status_; + + // The `suspend_handler_` is a callable whose argument is another callable + // that resumes the execution (or error handling). + absl::AnyInvocable resume) &&> + suspend_handler_; + absl::AnyInvocable exit_handler_; + + tfrt::ConcurrentWorkQueue* work_queue_ = nullptr; + + std::vector> + user_contexts_; + + std::vector> user_error_loggers_; + + const LoadedExecutable* loaded_executable_ = nullptr; + + friend class AsyncHandle; + friend void Execute(ExecutionContext& context); + friend void execute_internal::UnwindOnError(ExecutionContext& context, + int64_t pc); +}; + +class KernelFrame { + public: + struct State { + State(absl::Span regs, bc::Span attrs, + ExecutionContext* execution_context) + : regs(regs), attrs(attrs), execution_context(execution_context) { + DCHECK(execution_context); + } + + explicit State(FunctionContext* function_context) + : State(function_context->regs(), + function_context->execution_context() + .loaded_executable() + .executable() + .attributes(), + &function_context->execution_context()) {} + + bc::Kernel kernel; + absl::Span regs; + bc::Span attrs; + ExecutionContext* execution_context = nullptr; + }; + + explicit KernelFrame(State* state) : state_(state) { DCHECK(state_); } + + template + operator T() const { // NOLINT + return T(state_); + } + + RegisterSpan arguments() const { + return RegisterSpan(kernel().arguments(), regs()); + } + + RegisterSpan results() const { + return RegisterSpan(kernel().results(), regs()); + } + + AttributeSpan attributes() const { + return AttributeSpan(kernel().attributes(), attrs()); + } + + bc::Span last_uses() const { return kernel().last_uses(); } + + ExecutionContext& execution_context() { return *state_->execution_context; } + const ExecutionContext& execution_context() const { + return *state_->execution_context; + } + + void set_kernel(bc::Kernel kernel) { this->kernel() = kernel; } + + private: + bc::Kernel& kernel() { return state_->kernel; } + const bc::Kernel& kernel() const { return state_->kernel; } + + absl::Span 
regs() const { return state_->regs; } + bc::Span attrs() const { return state_->attrs; } + + State* state_ = nullptr; + + friend void Execute(ExecutionContext& context); +}; + +template +inline void KernelRegistry::Register(absl::string_view name) { + Register( + name, +[](KernelFrame frame) { KernelClass(frame).Invoke(); }); +} + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/execute.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/execute.h new file mode 100644 index 00000000..7492d44a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/execute.h @@ -0,0 +1,26 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_EXECUTE_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_EXECUTE_H_ + +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" + +namespace mlrt { + +void Execute(ExecutionContext& context); + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_EXECUTE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/future.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/future.h new file mode 100644 index 00000000..fd32214c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/future.h @@ -0,0 +1,348 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_FUTURE_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_FUTURE_H_ + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/check.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tfrt/concurrency/async_value.h" // from @tf_runtime +#include "tfrt/concurrency/async_value_ref.h" // from @tf_runtime + +namespace mlrt { +namespace future_internal { + +// The overloads of GetArgumentType() are used to get the argument type of a +// callable. 
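+//
+// For illustration (a sketch): for a callback such as `[](absl::Status s) {}`,
+// ArgumentType<F> below deduces absl::Status, which selects the InvokeThen()
+// overload that forwards the future's status:
+//
+//   auto cb = [](absl::Status s) {};
+//   static_assert(std::is_same_v<ArgumentType<decltype(cb)>, absl::Status>);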
+void GetArgumentType(void (*)()); +template +void GetArgumentType(void (F::*)()); +template +void GetArgumentType(void (F::*)() const); +template +Arg GetArgumentType(void (*)(Arg)); +template +Arg GetArgumentType(void (F::*)(Arg)); +template +Arg GetArgumentType(void (F::*)(Arg) const); +template +decltype(GetArgumentType(&F::operator())) GetArgumentType(F); + +template +using ArgumentType = decltype(GetArgumentType(std::declval())); + +template +struct ArgTag {}; + +// The overloads of InvokeThen() are used to invoke different implementation +// according to `then`'s argument type. +template +ABSL_ATTRIBUTE_ALWAYS_INLINE void InvokeThen(F&& then, + tsl::AsyncValue* shared_state, + ArgTag) { + auto& arg = shared_state->get(); + if (shared_state->IsUnique()) { + std::forward(then)(std::move(arg)); + } else { + std::forward(then)(arg); + } +} + +template +ABSL_ATTRIBUTE_ALWAYS_INLINE void InvokeThen(F&& then, + tsl::AsyncValue* shared_state, + ArgTag) { + if (shared_state->IsError()) { + std::forward(then)(shared_state->GetError()); + } else { + std::forward(then)(absl::OkStatus()); + } +} + +template +ABSL_ATTRIBUTE_ALWAYS_INLINE void InvokeThen(F&& then, + tsl::AsyncValue* shared_state, + ArgTag>) { + if (shared_state->IsError()) { + std::forward(then)(shared_state->GetError()); + } else { + InvokeThen(std::forward(then), shared_state, ArgTag()); + } +} + +} // namespace future_internal + +struct Control {}; + +// mlrt::Future is similar to std::shared_future but type-erased. +class Future { + public: + // Constructs a mlrt::Future directly from tsl::AsyncValue. This is used to + // integrate MLRT with existing systems that uses AsyncValue directly. For new + // use cases, creating mlrt::Future through mlrt::Promise is preferred. + template + explicit Future(tsl::AsyncValueRef async_value) + : shared_state_(std::move(async_value)) {} + + Future(const Future& other) = default; + Future& operator=(const Future& other) = default; + Future(Future&& other) = default; + Future& operator=(Future&& other) = default; + + explicit operator bool() const { return shared_state_ != nullptr; } + + bool IsReady() const { + DCHECK(shared_state_); + return shared_state_->IsAvailable(); + } + + bool IsError() const { + DCHECK(shared_state_); + return shared_state_->IsError(); + } + + template + const T& Get() const { + DCHECK(shared_state_); + return shared_state_->get(); + } + + const absl::Status& GetError() const { + DCHECK(shared_state_); + return shared_state_->GetError(); + } + + // Then() enqueues a callback which will be called when the future is + // fulfilled with either an error or a value. + // + // The following Then() overloads accept a callback with the following + // signatures: + // + // 1) void(absl::StatusOr) + // The argument can be either the error or the value. + // + // 2) void(absl::Status) + // The argument is the status of this future in ready state. + // + // 3) void(T) + // The argument is the fulfilled value. It is undefined behavior if there + // is an error. + // + // 4) void() + // There is no argument. The callback will be called whenever it is ready. 
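// [Editorial note, not part of the vendored header] An illustrative usage
// sketch for the Then() overload forms documented above. It assumes only the
// mlrt::Promise / mlrt::Future API declared in this file; ThenOverloadsExample
// is a hypothetical helper.
#include <utility>

#include "absl/status/statusor.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/future.h"

inline void ThenOverloadsExample() {
  auto promise = mlrt::Promise::Allocate<int>();
  mlrt::Future future = promise.GetFuture();

  // Form (1): the callback receives either the value or the error.
  mlrt::Future(future).Then(
      [](absl::StatusOr<int> v) { /* use *v or v.status() */ });

  // Form (4): no argument; runs once the future is ready. Then() is
  // &&-qualified, hence the move of the last copy.
  std::move(future).Then([]() { /* ready */ });

  // Fulfilling the promise runs every callback queued above.
  std::move(promise).Set<int>(42);
}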
+ + template >> + typename std::enable_if_t, void> Then(F then) && { + DCHECK(shared_state_); + auto* shared_state_ptr = shared_state_.get(); + shared_state_ptr->AndThen([shared_state = std::move(shared_state_), + then = std::move(then)]() mutable { + future_internal::InvokeThen(std::move(then), shared_state.get(), + future_internal::ArgTag()); + }); + } + + template >> + typename std::enable_if_t, void> Then(F then) && { + DCHECK(shared_state_); + auto* shared_state_ptr = shared_state_.get(); + shared_state_ptr->AndThen( + [shared_state = std::move(shared_state_), + then = std::move(then)]() mutable { std::move(then)(); }); + } + + size_t UseCount() const { + DCHECK(shared_state_); + return shared_state_->NumRef(); + } + + // We don't need HandleError() method for Future because + // AsyncHandle::HandleError() is enough for error handling for async + // execution. + + private: + friend class Promise; + + explicit Future(tsl::RCReference shared_state) + : shared_state_(std::move(shared_state)) {} + + tsl::RCReference shared_state_; +}; + +// mlrt::Promise is similar to std::promise but type-erased. +class Promise { + public: + template + static Promise Allocate() { + return Promise(tsl::MakeUnconstructedAsyncValueRef().ReleaseRCRef()); + } + + ~Promise() { + DCHECK(!shared_state_ || shared_state_->IsAvailable()) + << "A non-empty promise must be fulfilled."; + } + + Promise(const Promise&) = delete; + Promise& operator=(const Promise&) = delete; + Promise(Promise&&) = default; + Promise& operator=(Promise&&) = default; + + Future GetFuture() const { return Future(shared_state_); } + + template + void Set(Args&&... args) && { + DCHECK(shared_state_); + + auto shared_state = std::move(shared_state_); + auto* shared_state_ptr = shared_state.get(); + + // Since each waiter will hold a reference to the shared state, we can drop + // the reference in mlrt::Promise::Set() in order to trigger passing by move + // for the last waiter. + if (!shared_state->IsUnique()) { + shared_state.reset(); + } + + shared_state_ptr->emplace(std::forward(args)...); + } + + void SetError(absl::Status status) && { + DCHECK(shared_state_); + + DCHECK(!status.ok()); + shared_state_->SetError(std::move(status)); + shared_state_.reset(); + } + + void HandleError(Value* arg) && { + if (!shared_state_ || shared_state_->IsAvailable()) { + // This is an empty promise or it is already fulfilled. + return; + } + + auto& execution_context = *arg->Get(); + DCHECK(!execution_context.status().ok()); + + std::move(*this).SetError(execution_context.status()); + } + + explicit operator bool() const { return shared_state_ != nullptr; } + + private: + explicit Promise(tsl::RCReference shared_state) + : shared_state_(std::move(shared_state)) {} + + tsl::RCReference shared_state_; +}; + +namespace future_internal { + +struct State { + State(int size, mlrt::Promise promise) + : count(size), promise(std::move(promise)) {} + + std::atomic count; + mlrt::Promise promise; + + absl::Mutex mu; + absl::Status status; + + void SetError(absl::Status status) { + absl::MutexLock lock(&mu); + this->status = std::move(status); + } + + // Returns true if it is the last consumer of the state. If this method + // returns false, *this object might be destroyed anytime so the data can no + // longer be accessed after it returns false. 
+ bool DecrementCount() { + if (count.fetch_sub(1, std::memory_order_acq_rel) == 1) { + if (status.ok()) { + std::move(promise).Set(Control()); + } else { + std::move(promise).SetError(std::move(status)); + } + return true; + } + return false; + } +}; + +} // namespace future_internal + +template +ABSL_ATTRIBUTE_ALWAYS_INLINE Future AwaitAll(FutureLikeContainer futures, + ResultRefContainer results) { + DCHECK(!futures.empty()); + + auto promise = Promise::Allocate(); + auto await_all = promise.GetFuture(); + auto* state = new future_internal::State(futures.size(), std::move(promise)); + + DCHECK_EQ(futures.size(), results.size()); + for (int i = 0; i < futures.size(); ++i) { + auto& future = futures[i]; + std::move(future).Then( + [state, result = &results[i]](absl::StatusOr value) { + if (value.ok()) { + result->Set(std::move(*value)); + } else { + state->SetError(std::move(value).status()); + } + + if (state->DecrementCount()) { + delete state; + } + }); + } + + return await_all; +} + +template +ABSL_ATTRIBUTE_ALWAYS_INLINE Future AwaitAll(FutureLikeContainer futures) { + DCHECK(!futures.empty()); + + auto promise = Promise::Allocate(); + auto await_all = promise.GetFuture(); + auto* state = new future_internal::State(futures.size(), std::move(promise)); + + for (int i = 0; i < futures.size(); ++i) { + auto& future = futures[i]; + std::move(future).Then([state](absl::Status status) { + if (!status.ok()) { + state->SetError(std::move(status)); + } + + if (state->DecrementCount()) { + delete state; + } + }); + } + + return await_all; +} + +// TODO(chky): Implement type-safe version of Future and Promise. + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_FUTURE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h new file mode 100644 index 00000000..2b1d967a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h @@ -0,0 +1,126 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_INTERPRETER_TESTUTIL_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_INTERPRETER_TESTUTIL_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h" + +namespace mlrt { +namespace testing { + +class SymbolTable { + public: + int Def(absl::string_view name) { + auto iter = reg_names_.find(name); + if (iter != reg_names_.end()) { + return iter->second; + } + + int& id = reg_names_[name]; + id = next_reg_id_++; + + return id; + } + + std::vector Def(absl::Span names) { + return DefOrUse(names, + [this](absl::string_view name) { return Def(name); }); + } + + int Use(absl::string_view name) const { + DCHECK(reg_names_.contains(name)); + return reg_names_.at(name); + } + + std::vector Use(absl::Span names) { + return DefOrUse(names, + [this](absl::string_view name) { return Use(name); }); + } + + size_t size() const { return reg_names_.size(); } + + private: + std::vector DefOrUse( + absl::Span names, + absl::FunctionRef def_or_use) { + std::vector ids; + ids.reserve(names.size()); + for (const auto& name : names) { + ids.push_back(def_or_use(name)); + } + return ids; + } + + absl::flat_hash_map reg_names_; + int next_reg_id_ = 0; +}; + +class AttributeTable { + public: + explicit AttributeTable(bc::Vector::Constructor attributes_ctor) + : ctor_(attributes_ctor) {} + + void Add(absl::string_view name, absl::string_view value) { + handles_[name] = next_id_; + ctor_.ConstructAt(next_id_++, value); + } + + void Add(absl::string_view name, const char* value) { + Add(name, absl::string_view(value)); + } + + void AddInline(absl::string_view name, absl::string_view value) { + DCHECK_LE(value.size(), sizeof(uint32_t)); + std::memcpy(&handles_[name], value.data(), value.size()); + } + + template , int> = 0> + void Add(absl::string_view name, T value) { + AddInline(name, absl::string_view(reinterpret_cast(&value), + sizeof(value))); + } + + template && + !attribute_internal::kCanAttributeBeInlined, + int> = 0> + void Add(absl::string_view name, T value) { + Add(name, absl::string_view(reinterpret_cast(&value), + sizeof(value))); + } + + uint32_t GetHandle(absl::string_view name) { return handles_.at(name); } + + private: + bc::Vector::Constructor ctor_; + int next_id_ = 0; + absl::flat_hash_map handles_; +}; + +} // namespace testing +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_INTERPRETER_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/iterator.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/iterator.h new file mode 100644 index 00000000..582e7def --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/iterator.h @@ -0,0 +1,131 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ITERATOR_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ITERATOR_H_ + +#include + +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" + +namespace mlrt { +namespace iterator_internal { + +template +class IteratorBase { + const Iter& self() const { return static_cast(*this); } + Iter& self() { return static_cast(*this); } + + public: + using difference_type = std::ptrdiff_t; + using value_type = ValueType; + using pointer = ValueType*; + using reference = ValueType&; + using iterator_category = std::random_access_iterator_tag; + + explicit IteratorBase(bc::ReadIterator index_iter, + ValueRangeType values) + : index_iter_(index_iter), values_(values) {} + + reference operator*() const { return values_[*index_iter_]; } + + pointer operator->() const { return &values_[*index_iter_]; } + + reference operator[](difference_type i) const { + return values_[*(index_iter_ + i)]; + } + + Iter& operator+=(difference_type d) { + index_iter_ += d; + return self(); + } + + Iter& operator-=(difference_type d) { + index_iter_ -= d; + return self(); + } + + Iter& operator++() { + ++index_iter_; + return self(); + } + + Iter operator++(int) { + Iter r = self(); + ++index_iter_; + return r; + } + + Iter& operator--() { + --index_iter_; + return self(); + } + + Iter operator--(int) { + Iter r = self(); + --index_iter_; + return r; + } + + Iter operator+(difference_type d) const { + Iter r = self(); + r += d; + return r; + } + + friend Iter operator+(difference_type d, const Iter& i) { return i + d; } + + Iter operator-(difference_type d) const { + Iter r = self(); + r -= d; + return r; + } + + difference_type operator-(const Iter& other) const { + return index_iter_ - other.index_iter_; + } + + friend bool operator==(const Iter& a, const Iter& b) { + return a.index_iter_ == b.index_iter_; + } + + friend bool operator!=(const Iter& a, const Iter& b) { + return a.index_iter_ != b.index_iter_; + } + + friend bool operator<(const Iter& a, const Iter& b) { + return a.index_iter_ < b.index_iter_; + } + + friend bool operator<=(const Iter& a, const Iter& b) { + return a.index_iter_ <= b.index_iter_; + } + + friend bool operator>(const Iter& a, const Iter& b) { + return a.index_iter_ > b.index_iter_; + } + + friend bool operator>=(const Iter& a, const Iter& b) { + return a.index_iter_ >= b.index_iter_; + } + + private: + bc::ReadIterator index_iter_; + ValueRangeType values_; +}; + +} // namespace iterator_internal +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_ITERATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/register_span.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/register_span.h new file mode 100644 index 00000000..1fd575ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/register_span.h @@ -0,0 +1,225 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_REGISTER_SPAN_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_REGISTER_SPAN_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/span.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/iterator.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/value.h" + +namespace mlrt { + +class RegisterIterator + : public iterator_internal::IteratorBase> { + public: + using IteratorBase>::IteratorBase; +}; + +class ConstRegisterIterator + : public iterator_internal::IteratorBase> { + using IteratorBase>::IteratorBase; +}; + +class RegisterSpan { + public: + using value_type = Value; + using size_type = size_t; + using difference_type = std::ptrdiff_t; + using reference = Value&; + using const_reference = const Value&; + using pointer = Value*; + using const_pointer = const Value*; + using iterator = RegisterIterator; + using const_iterator = ConstRegisterIterator; + + RegisterSpan() = default; + RegisterSpan(bc::Span reg_indices, absl::Span regs) + : reg_indices_(reg_indices), regs_(regs) {} + + Value& operator[](size_t idx) { return regs_[reg_indices_[idx]]; } + const Value& operator[](size_t idx) const { return regs_[reg_indices_[idx]]; } + Value& back() const { return regs_[reg_indices_.back()]; } + + size_t size() const { return reg_indices_.size(); } + + iterator begin() const { return iterator(reg_indices_.begin(), regs_); } + iterator end() const { return iterator(reg_indices_.end(), regs_); } + + RegisterSpan drop_front(int num = 1) { + return RegisterSpan(reg_indices_.drop_front(num), regs_); + } + + RegisterSpan drop_back(int num = 1) { + return RegisterSpan(reg_indices_.drop_back(num), regs_); + } + + private: + bc::Span reg_indices_; + absl::Span regs_; +}; + +template +class RegisterValueIterator { + using Iter = RegisterValueIterator; + + public: + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T*; + using reference = T&; + using iterator_category = std::random_access_iterator_tag; + + explicit RegisterValueIterator(RegisterIterator reg_iter) + : reg_iter_(reg_iter) {} + + reference operator*() const { return (*reg_iter_).Get(); } + + pointer operator->() const { return &(*reg_iter_).Get(); } + + reference operator[](difference_type i) const { + return (*(reg_iter_ + i)).Get(); + } + + Iter& operator+=(difference_type d) { + reg_iter_ += d; + return *this; + } + + Iter& operator-=(difference_type d) { + reg_iter_ -= d; + return *this; + } + + Iter& operator++() { + ++reg_iter_; + return *this; + } + + Iter operator++(int) { + Iter r = *this; + ++reg_iter_; + return r; + } + + Iter& operator--() { + --reg_iter_; + return *this; + } + + Iter operator--(int) { + Iter r = *this; + --reg_iter_; + return r; + } + + Iter operator+(difference_type d) const { + Iter r = *this; + r += d; + return r; + } + + friend 
Iter operator+(difference_type d, const Iter& i) { return i + d; } + + Iter operator-(difference_type d) const { + Iter r = *this; + r -= d; + return r; + } + + difference_type operator-(const Iter& other) const { + return reg_iter_ - other.reg_iter_; + } + + friend bool operator==(const Iter& a, const Iter& b) { + return a.reg_iter_ == b.reg_iter_; + } + + friend bool operator!=(const Iter& a, const Iter& b) { + return a.reg_iter_ != b.reg_iter_; + } + + friend bool operator<(const Iter& a, const Iter& b) { + return a.reg_iter_ < b.reg_iter_; + } + + friend bool operator<=(const Iter& a, const Iter& b) { + return a.reg_iter_ <= b.reg_iter_; + } + + friend bool operator>(const Iter& a, const Iter& b) { + return a.reg_iter_ > b.reg_iter_; + } + + friend bool operator>=(const Iter& a, const Iter& b) { + return a.reg_iter_ >= b.reg_iter_; + } + + private: + RegisterIterator reg_iter_; +}; + +template +class RegisterValueSpan { + public: + using value_type = T; + using size_type = size_t; + using difference_type = std::ptrdiff_t; + using reference = T&; + using const_reference = const T&; + using pointer = T*; + using const_pointer = const T*; + using iterator = RegisterValueIterator; + using const_iterator = RegisterValueIterator; + + RegisterValueSpan(bc::Span reg_indices, absl::Span regs) + : reg_span_(reg_indices, regs) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + RegisterValueSpan(RegisterSpan reg_span) : reg_span_(reg_span) {} + + T& operator[](size_t idx) { return reg_span_[idx].Get(); } + const T& operator[](size_t idx) const { return reg_span_[idx].Get(); } + + void Destroy(size_t idx) { reg_span_[idx].Destroy(); } + + size_t size() const { return reg_span_.size(); } + + iterator begin() const { return iterator(reg_span_.begin()); } + iterator end() const { return iterator(reg_span_.end()); } + + bool empty() const { return size() == 0; } + + RegisterValueSpan drop_front(int num = 1) { + return reg_span_.drop_front(num); + } + + RegisterValueSpan drop_back(int num = 1) { return reg_span_.drop_back(num); } + + RegisterSpan reg_span() const { return reg_span_; } + + private: + RegisterSpan reg_span_; +}; + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_REGISTER_SPAN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/value.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/value.h new file mode 100644 index 00000000..a7113a7c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/interpreter/value.h @@ -0,0 +1,419 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_VALUE_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_VALUE_H_ + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" + +namespace mlrt { + +class Value; + +namespace value_internal { + +struct InPlaceStorageT { + // Many tensor implementations like tensorflow::Tensor requires multiple + // words, and we'd like to keep these values inplace. + // + // TODO(chky): Consider a better size for inplace storage. + alignas(8) char data[56]; +}; + +template +using IsInPlaceStorage = + std::integral_constant::value>; + +// Since we type-erase the value to be put in class Value, we need to an enum +// value to select the operation that should be applied on the type-erased +// value. +enum class Action { + kDestroy = 0, // Destructor + kCopy, // Copy constructor/assignment + kMove, // Move constructor/assignment + kError, // Error handler + kTypeInfo // Get type info +}; + +struct TypeInfo {}; + +using HandlerFuncPtr = TypeInfo* (*)(Action, Value*, Value*); + +template +class InPlaceHandler; +template +class OutOfPlaceHandler; + +template +using Handler = std::conditional_t::value, + InPlaceHandler, OutOfPlaceHandler>; + +template +struct HasHandleError : std::false_type {}; + +template +struct HasHandleError< + T, std::void_t().HandleError(nullptr))>> + : std::true_type {}; + +} // namespace value_internal + +// A container for type-erased value. The value should be at least copy +// constructable to be put into this container. This container has both move and +// copy semantics, but if the concrete value does not support copy, calling the +// copy operations on this class will result in undefined behavior. +class alignas(64) Value { + public: + // Value is default constructible. The payload is unset in the default + // constructed Value. + Value() = default; + + Value(const Value&); + Value& operator=(const Value&); + Value(Value&&) noexcept; + Value& operator=(Value&&) noexcept; + + // Construct Value and store `t` as the payload. + template , Value>, + int>::type = 0> + explicit Value(T&& t); + + template , Value>, + int>::type = 0> + Value& operator=(T&& value) { + Set(std::forward(value)); + return *this; + } + + ~Value(); + + // Get() function returns the payload of the Value object in the requested + // type. + // + // Dynamic type checking is performed in the debug mode. + template + T& Get(); + + template + const T& Get() const; + + // Emplace() constructs the payload object of type T in place with the given + // args. If the value is already initialized, the original value will be + // destroyed. + template + void Emplace(Args&&... args); + + // Construct() constructs the payload object of type T in place with the given + // args. The value should be uninitialized before calling this method. + // Otherwise the behavior is undefined. + template + void Construct(Args&&... args); + + // Destroy() destroys the payload object of type T. The value must be already + // initialized with a value of type T. Otherwise the behavior is undefined. + template + void Destroy(); + + // Set() stores the argument `t` as the payload of Value. + template + void Set(T&& t); + + // Reset the Value object to empty. + void Reset(); + + // Call T::HandleError() method on the underlying value of type T. If T does + // not have a HandleError() method, this method does nothing. + void HandleError(Value& arg); + + // Check if Value contains a payload. 
+ bool HasValue() const { return handler_ != nullptr; } + + // Check if Value contains object of type T. + template + bool IsType() const; + + // Check if object of type T is stored in place. + template + static constexpr bool IsInPlace() { + return value_internal::IsInPlaceStorage::value; + } + + private: + union { + value_internal::InPlaceStorageT storage_{}; + void* value_; + }; + value_internal::HandlerFuncPtr handler_ = nullptr; + + template + friend class value_internal::InPlaceHandler; + template + friend class value_internal::OutOfPlaceHandler; +}; + +// We only optimize the code for 64-bit architectures for now. +static_assert(sizeof(Value) == 64 || sizeof(void*) != 8); + +// ----------------------------------------------------------- +// Implementation details. + +namespace value_internal { + +template +TypeInfo* GetTypeInfo(); + +template ::value, int> = 0> +void HandleErrorInternal(Value* self, Value* arg) { + std::move(self->Get()).HandleError(arg); +} + +template ::value, int> = 0> +static void HandleErrorInternal(Value* self, Value* arg) {} + +template +struct InPlaceHandler { + template + static void Construct(Value* self, Args&&... args) { + new (&self->storage_) T(std::forward(args)...); + self->handler_ = &Handle; + } + + static TypeInfo* Handle(Action action, Value* self, Value* other) { + switch (action) { + case Action::kDestroy: + Destroy(self); + return nullptr; + case Action::kCopy: + Copy(self, other); + return nullptr; + case Action::kMove: + Move(self, other); + return nullptr; + case Action::kError: + HandleError(self, other); + return nullptr; + case Action::kTypeInfo: + return GetTypeInfo(); + } + } + + static void Destroy(Value* self) { + DCHECK(self->HasValue()); + auto* p = std::launder(reinterpret_cast(&self->storage_)); + p->~T(); + self->handler_ = nullptr; + } + + template ::value, int> = 0> + static void CopyInternal(Value* self, Value* dest) { + DCHECK(self->HasValue() && !dest->HasValue()); + Construct(dest, *std::launder(reinterpret_cast(&self->storage_))); + } + + template ::value, int> = 0> + static void CopyInternal(Value* self, Value* dest) { + LOG(FATAL) << "Copying a mlrt::Value whose underlying type is " // Crash Ok + "not copyable is a runtime error."; + } + + static void Copy(Value* self, Value* dest) { CopyInternal(self, dest); } + + static void Move(Value* self, Value* dest) { + DCHECK(self->HasValue() && !dest->HasValue()); + Construct(dest, + std::move(*std::launder(reinterpret_cast(&self->storage_)))); + Destroy(self); + } + + static void HandleError(Value* self, Value* arg) { + HandleErrorInternal(self, arg); + } +}; + +template +struct OutOfPlaceHandler { + template + static void Construct(Value* self, Args&&... 
args) { + self->value_ = new T(std::forward(args)...); + self->handler_ = &Handle; + } + + static TypeInfo* Handle(Action action, Value* self, Value* other) { + switch (action) { + case Action::kDestroy: + Destroy(self); + return nullptr; + case Action::kCopy: + Copy(self, other); + return nullptr; + case Action::kMove: + Move(self, other); + return nullptr; + case Action::kError: + HandleError(self, other); + return nullptr; + case Action::kTypeInfo: + return GetTypeInfo(); + } + } + + static void Destroy(Value* self) { + DCHECK(self->HasValue()); + delete static_cast(self->value_); + self->handler_ = nullptr; + } + + template ::value, int> = 0> + static void CopyInternal(Value* self, Value* dest) { + DCHECK(self->HasValue() && !dest->HasValue()); + Construct(dest, *static_cast(self->value_)); + } + + template ::value, int> = 0> + static void CopyInternal(Value* self, Value* dest) { + LOG(FATAL) << "Copying a mlrt::Value whose underlying type is " // Crash Ok + "not copyable is a runtime error."; + } + + static void Copy(Value* self, Value* dest) { CopyInternal(self, dest); } + + static void Move(Value* self, Value* dest) { + DCHECK(self->HasValue() && !dest->HasValue()); + dest->value_ = self->value_; + dest->handler_ = &Handle; + self->handler_ = nullptr; + } + + static void HandleError(Value* self, Value* arg) { + HandleErrorInternal(self, arg); + } +}; + +template +__attribute__((noinline)) TypeInfo* GetTypeInfo() { + static TypeInfo kTypeInfo; + return &kTypeInfo; +} + +} // namespace value_internal + +template , Value>, int>::type> +Value::Value(T&& t) { + Construct>(std::forward(t)); +} + +inline Value::Value(const Value& v) { + if (v.HasValue()) + v.handler_(value_internal::Action::kCopy, const_cast(&v), this); +} + +inline Value& Value::operator=(const Value& v) { + Reset(); + if (v.HasValue()) + v.handler_(value_internal::Action::kCopy, const_cast(&v), this); + return *this; +} + +inline Value::Value(Value&& v) noexcept { + if (v.HasValue()) v.handler_(value_internal::Action::kMove, &v, this); +} + +inline Value& Value::operator=(Value&& v) noexcept { + Reset(); + if (v.HasValue()) v.handler_(value_internal::Action::kMove, &v, this); + return *this; +} + +inline void Value::HandleError(Value& arg) { + if (HasValue()) handler_(value_internal::Action::kError, this, &arg); +} + +inline Value::~Value() { Reset(); } + +template +T& Value::Get() { + return const_cast(static_cast(this)->Get()); +} + +template +const T& Value::Get() const { + DCHECK(IsType()); + + if constexpr (IsInPlace()) { + return *std::launder(reinterpret_cast(&storage_)); + } + + return *static_cast(value_); +} + +// Emplace() constructs the payload object of type T in place with the given +// args. +template +void Value::Emplace(Args&&... args) { + Reset(); + Construct>(std::forward(args)...); +} + +// Set() stores the argument `t` as the payload of Value. +template +void Value::Set(T&& t) { + Emplace(std::forward(t)); +} + +template +void Value::Construct(Args&&... args) { + DCHECK(!HasValue()); + static_assert(!std::is_same_v); + value_internal::Handler::Construct(this, std::forward(args)...); +} + +template +void Value::Destroy() { + DCHECK(HasValue()); + DCHECK(IsType()); + static_assert(!std::is_same_v); + value_internal::Handler::Destroy(this); +} + +// Reset the Value object to empty. 
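// [Editorial note, not part of the vendored header] An illustrative usage
// sketch that exercises only the mlrt::Value API declared above; ValueExample
// is a hypothetical helper.
#include <string>

#include "tensorflow/core/tfrt/mlrt/interpreter/value.h"

inline void ValueExample() {
  mlrt::Value v(std::string("hello"));  // type-erased payload
  v.Get<std::string>() += " world";     // Get<T>() is type-checked in debug builds

  v.Emplace<int>(42);  // destroys the string, then stores an int
  // Small payloads such as int fit into the in-place storage.
  static_assert(mlrt::Value::IsInPlace<int>());

  v.Reset();  // back to the empty state; v.HasValue() is now false
}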
+inline void Value::Reset() { + if (handler_ == nullptr) return; + handler_(value_internal::Action::kDestroy, this, nullptr); +} + +template +bool Value::IsType() const { + return handler_(value_internal::Action::kTypeInfo, const_cast(this), + nullptr) == value_internal::GetTypeInfo(); +} + +} // namespace mlrt + +#endif // TENSORFLOW_CORE_TFRT_MLRT_INTERPRETER_VALUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h new file mode 100644 index 00000000..a7b8e5f1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h @@ -0,0 +1,28 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_KERNEL_BATCH_KERNEL_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_KERNEL_BATCH_KERNEL_H_ + +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" + +namespace tensorflow { +namespace tf_mlrt { + +void RegisterTfMlrtBatchKernels(mlrt::KernelRegistry& registry); + +} // namespace tf_mlrt +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLRT_KERNEL_BATCH_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/context.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/context.h new file mode 100644 index 00000000..fa682f22 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/context.h @@ -0,0 +1,132 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_KERNEL_CONTEXT_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_KERNEL_CONTEXT_H_ + +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime + +namespace tensorflow { +namespace tf_mlrt { + +// The context for tensorflow::OpKernel. 
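// [Editorial note, not part of the vendored header] An illustrative sketch of
// how a kernel body typically reaches this per-request Context through
// mlrt::ExecutionContext::GetUserContext<T>() declared earlier in this diff.
// ExampleKernel and its kName are hypothetical; IsCancelled(), Fail() and
// KernelRegistry::Register<KernelClass>() are the APIs shown elsewhere in
// these headers.
#include "absl/status/status.h"
#include "tensorflow/core/tfrt/mlrt/interpreter/context.h"
#include "tensorflow/core/tfrt/mlrt/kernel/context.h"

struct ExampleKernel : mlrt::KernelFrame {
  using KernelFrame::KernelFrame;
  static constexpr char kName[] = "tf_mlrt.example_kernel";  // hypothetical

  void Invoke() {
    auto& ctx =
        execution_context().GetUserContext<tensorflow::tf_mlrt::Context>();
    if (ctx.IsCancelled()) {
      execution_context().Fail(absl::CancelledError("request cancelled"));
      return;
    }
    // ... operate on arguments() / results() ...
  }
};

// Registration would then look roughly like:
//   registry.Register<ExampleKernel>(ExampleKernel::kName);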
+class Context : public mlrt::UserContext { + public: + explicit Context( + const tfd::KernelFallbackCompatRequestState* fallback_request_state, + tfrt::ResourceContext* resource_context, + const tfrt::CancellationContext* cancellation_context = nullptr) + : fallback_request_state_(fallback_request_state), + op_kernel_context_(fallback_request_state_), + resource_context_(resource_context), + cancellation_context_(cancellation_context) { + DCHECK(resource_context_); + } + + Context(const Context&) = default; + Context& operator=(const Context&) = default; + + const tfd::KernelFallbackCompatRequestState& fallback_request_state() const { + return *fallback_request_state_; + } + void set_fallback_request_state( + const tfd::KernelFallbackCompatRequestState* fallback_request_state) { + DCHECK(fallback_request_state); + fallback_request_state_ = fallback_request_state; + } + + OpKernelContext::Params& params() { return op_kernel_context_.params; } + OpKernelContext& op_kernel_context() { + return op_kernel_context_.op_kernel_context; + } + + tfrt::ResourceContext& resource_context() const { return *resource_context_; } + + const tfrt::CancellationContext* cancellation_context() const { + return cancellation_context_; + } + + tfrt_stub::OpKernelRunState& run_state() { + // Keep states needed by kernel execution in a thread local storage to avoid + // repeated reallocation and destruction of them. + thread_local tfrt_stub::OpKernelRunState run_state; + return run_state; + } + + // Return true if there is a cancellation request. + bool IsCancelled() { + return cancellation_context_ != nullptr && + cancellation_context_->IsCancelled(); + } + + private: + const tfd::KernelFallbackCompatRequestState* fallback_request_state_ = + nullptr; + + struct CopyableOpKernelContext { + OpKernelContext::Params params; + OpKernelContext op_kernel_context; + + explicit CopyableOpKernelContext( + const tfd::KernelFallbackCompatRequestState* fallback_request_state) + : params(), + op_kernel_context( + [this, fallback_request_state]() { + DCHECK(fallback_request_state); + params.step_id = fallback_request_state->step_id(); + auto* device = fallback_request_state->cpu_device(); + params.device = device; + // Still use original device's resource_manager. + params.resource_manager = device->resource_manager(); + params.step_container = + fallback_request_state->step_container(); + // Following two parameters are used to support executing + // tf.data via fallback. 
+ params.function_library = + fallback_request_state->cpu_function_library_runtime(); + params.runner = fallback_request_state->runner(); + params.collective_executor = + fallback_request_state->collective_executor(); + params.rendezvous = fallback_request_state->rendezvous(); + params.session_metadata = + &fallback_request_state->session_metadata(); + params.cancellation_manager = + fallback_request_state->cancellation_manager(); + return ¶ms; + }(), + 0) {} + CopyableOpKernelContext(const CopyableOpKernelContext& other) + : params(other.params), + op_kernel_context(¶ms, other.op_kernel_context.num_outputs()) {} + CopyableOpKernelContext& operator=(const CopyableOpKernelContext& other) { + params = other.params; + op_kernel_context.ResetOutputs(other.op_kernel_context.num_outputs()); + return *this; + } + ~CopyableOpKernelContext() { op_kernel_context.ResetOutputs(); } + }; + CopyableOpKernelContext op_kernel_context_; + + tfrt::ResourceContext* resource_context_ = nullptr; + const tfrt::CancellationContext* cancellation_context_; +}; + +} // namespace tf_mlrt +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLRT_KERNEL_CONTEXT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/kernel.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/kernel.h new file mode 100644 index 00000000..36ee01d1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/kernel.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_KERNEL_KERNEL_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_KERNEL_KERNEL_H_ + +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" + +namespace tensorflow { +namespace tf_mlrt { + +mlrt::KernelRegistry& GetTfMlrtOptionalKernelRegistry(); + +void RegisterTfMlrtKernels(mlrt::KernelRegistry& registry); + +} // namespace tf_mlrt +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLRT_KERNEL_KERNEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h new file mode 100644 index 00000000..daecf14a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h @@ -0,0 +1,150 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_MLRT_KERNEL_KERNEL_RUNNER_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_KERNEL_KERNEL_RUNNER_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/base/optimization.h" +#include "absl/cleanup/cleanup.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" +#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/register_span.h" +#include "tensorflow/core/tfrt/mlrt/kernel/context.h" + +namespace tensorflow { +namespace tf_mlrt { + +void LaunchAsyncOpKernel(const tfrt_stub::OpKernelRunner& kernel_runner, + const tfrt_stub::OpKernelRunState& run_state, + const OpKernelContext::Params& params, + mlrt::RegisterSpan results, + std::shared_ptr custom_device); + +inline void SetUpParams(const tfrt_stub::OpKernelRunner& kernel_runner, + absl::Span input_tf_tensor_values, + OpKernelContext::Params& params) { + params.inputs = input_tf_tensor_values; + params.op_kernel = kernel_runner.op_kernel(); + params.input_alloc_attrs = kernel_runner.input_alloc_attrs(); + params.output_attr_array = kernel_runner.output_alloc_attrs().data(); +} + +template +void ExecuteKernelRunner( + Frame& frame, Context& context, + const tfd::KernelFallbackCompatRequestState& fallback_request_state, + const tfrt_stub::OpKernelRunner& kernel_runner) { + tsl::profiler::TraceMe trace_me([&]() -> std::string { + return tsl::profiler::TraceMeOp( + kernel_runner.op_kernel()->name_view(), + kernel_runner.op_kernel()->type_string_view()); + }); + + auto args = frame.args(); + auto last_uses = frame.last_uses(); + + auto& run_state = context.run_state(); + auto& tensor_buffers = run_state.tensor_buffers; + + auto clean_up_inputs = absl::MakeCleanup([&]() { + for (const auto* buffer : tensor_buffers) { + DCHECK(buffer); + buffer->Unref(); + } + tensor_buffers.clear(); + }); + + // Prepare the input tensors. + auto& input_tf_tensor_values = run_state.input_tf_tensor_values; + input_tf_tensor_values.resize(args.size()); + for (int i = 0; i < args.size(); ++i) { + auto& fallback_tensor = args[i]; + // If the argument is immutable or it is the last use in the current scope, + // we can just keep the reference without copying that invovles expensive + // atomic reference counting. And if it is the last use, it can enable + // buffer forwarding optimization in many tensorflow OpKernels. + if (!fallback_tensor.is_immutable() && !last_uses[i]) { + if (const auto* buffer = fallback_tensor.buffer()) { + buffer->Ref(); + tensor_buffers.push_back(buffer); + } + } + input_tf_tensor_values[i].tensor = &fallback_tensor.tensor(); + } + + auto& params = context.params(); + SetUpParams(kernel_runner, input_tf_tensor_values, params); + + auto results = frame.results(); + + if constexpr (!IsAsync) { + tensorflow::DeviceBase* device = nullptr; + if constexpr (Frame::kUseCustomDevice) { + // If the kernel is using custom device, save the current device and + // change to the custom device. 
+ device = params.device; + params.device = frame.device().get(); + } + + auto& op_kernel_context = context.op_kernel_context(); + op_kernel_context.ResetOutputs(results.size()); + + kernel_runner.Run(&op_kernel_context); + + if constexpr (Frame::kUseCustomDevice) { + // We need to restore the device as params will be reused by kernels + // invoked later. + params.device = device; + } + + if (ABSL_PREDICT_FALSE(!op_kernel_context.status().ok())) { + frame.execution_context().Fail(op_kernel_context.status()); + return; + } + + for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { + DCHECK(op_kernel_context.mutable_output(i)); + results[i].template Emplace( + std::move(*op_kernel_context.mutable_output(i))); + } + } else { + std::shared_ptr custom_device = nullptr; + if constexpr (Frame::kUseCustomDevice) { + custom_device = frame.device(); + } + + LaunchAsyncOpKernel(kernel_runner, run_state, params, results, + std::move(custom_device)); + } + + auto reg_span = args.reg_span(); + for (int i = 0; i < last_uses.size(); ++i) { + if (last_uses[i]) { + reg_span[i].template Destroy(); + } + } +} + +} // namespace tf_mlrt +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLRT_KERNEL_KERNEL_RUNNER_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h new file mode 100644 index 00000000..d194b687 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_MLRT_KERNEL_SHARD_RESTORE_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_MLRT_KERNEL_SHARD_RESTORE_UTIL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" + +namespace tensorflow { +namespace tf_mlrt { + +// Shard variables into cluster of roughly the same size. +// +// `num_shards` is the number of shards to create. +// `variable_sizes` is the sizes of the variables. +// +// Returns a list of clusters, each of which is represented +// as a vector of variable indices. +std::vector> ShardVariables( + int num_shards, absl::Span variable_sizes); + +} // namespace tf_mlrt +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_MLRT_KERNEL_SHARD_RESTORE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h new file mode 100644 index 00000000..87baccab --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h @@ -0,0 +1,458 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_ +#define TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/context.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/tfrt/runtime/work_queue_interface.h" +#include "tfrt/host_context/task_function.h" // from @tf_runtime +namespace Eigen { +struct ThreadPoolDevice; +} + +namespace tfrt { +namespace tf { + +class RunHandler; + +// Options for RunHanler. +struct RunHandlerOptions { + RunHandlerOptions() : priority(0) {} + + // Request priority. + int priority; +}; + +// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers +// that can be used for tracking op work for a given inference request. +// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes +// 'active' when its unique_ptr is returned by Get() and is being used by a +// client. It becomes 'inactive' once more when its unique_ptr gets destroyed. +// +// Expected usage: +// +// * Create a single RunHandlerPool (say run_handler_pool_). +// +// * When an inference request is invoked, obtain a handler by: +// auto handler = run_handler_pool_->Get(); +// +// * Use handler for scheduling all inter-op work by: +// handler->ScheduleInterOpClosure(closure); +// +// This class is thread safe. +class RunHandlerPool { + public: + struct Options { + // The number of main threads. + int num_inter_op_threads = 1; + + // The number of complimentary threads. + int num_intra_op_threads = 1; + + // The number of max concurrent handlers. + int max_concurrent_handler = 128; + + // The number of sub thread pool configed. + int num_sub_thread_pool = 1; + + // The number of threads in each sub thread pool. The length of the vector + // should equal to num_sub_thread_pool. + std::vector num_threads_in_sub_thread_pool = {1}; + + // The percentage of requests the first N sub thread pool handles. The + // length of the vector should equal to num_sub_thread_pool. For example, + // {0.5, 1} means the first sub thread pool will handle the first 50% + // requests based on priority and the second thread pool will handle the + // second 50% requests based on priority. + std::vector sub_thread_request_percentage = {1.0}; + + // Sleep time for non blocking threads if there is no pending task. + int non_blocking_threads_sleep_time_micro_sec = 1000; + + // Max sleep time for blocking threads if there is no pending task and no + // new task wakes up the thread. + int blocking_threads_max_sleep_time_micro_sec = 1000; + + // If true, use adaptive waiting time. + bool use_adaptive_waiting_time = true; + + // If true, threads won't wake itself up if there is no active requests. 
+ bool wait_if_no_active_request = true; + + // If true, threads will be waken up by new tasks. + bool enable_wake_up = true; + }; + explicit RunHandlerPool(Options options); + ~RunHandlerPool(); + + // Returns an inactive RunHandler from the pool. + // + // RunHandlers in RunHandlerPool are initially 'inactive'. + // A RunHandler becomes 'active' when its unique_ptr its returned by Get() + // and is being used by a client. It becomes 'inactive' once more when the + // unique_ptr is destroyed. + // + // Will block unless there is an inactive handler. + std::unique_ptr Get( + int64_t step_id = 0, int64_t timeout_in_ms = 0, + const RunHandlerOptions& options = RunHandlerOptions()); + + // Get the priorities for active handlers. The return result is with the same + // order of the active handler list. + std::vector GetActiveHandlerPrioritiesForTesting() const; + + // Block until the system is quiescent (no pending work and no inflight work). + void Quiesce() const; + + private: + class Impl; + friend class RunHandler; + + std::unique_ptr impl_; +}; + +// RunHandler can be used to schedule inter/intra-op closures to run on a global +// pool shared across all Session::Run(s). The closures are enqueued to a +// handler specific queue, from which the work is stolen in a priority order +// (time of the Get() call). +// +// It can only be created via RunHandlerPool::Get(). +// +// This class can be used instead of directly scheduling closures on a global +// pool since it maintains a global view across all sessions and optimizes pool +// scheduling to improve (median and tail) latency. +// +// This class is thread safe. +class RunHandler { + public: + void ScheduleInterOpClosure(TaskFunction fn); + void ScheduleIntraOpClosure(TaskFunction fn); + + tensorflow::thread::ThreadPoolInterface* AsIntraThreadPoolInterface() const; + + int NumThreads() const; + + int64_t step_id() const; + + ~RunHandler(); + + private: + class Impl; + friend class RunHandlerPool::Impl; + + explicit RunHandler(Impl* impl); + + Impl* impl_; // NOT OWNED. +}; + +namespace internal { + +// TODO(azaks): Refactor with thread:ThreadPool +class RunHandlerEnvironment { + public: + typedef tensorflow::Thread EnvThread; + struct TaskImpl { + TaskFunction f; + tensorflow::Context context; + uint64_t trace_id; + }; + tensorflow::Env* const env_; + const tensorflow::ThreadOptions thread_options_; + const std::string name_; + + public: + struct Task { + std::unique_ptr f; + }; + + RunHandlerEnvironment(tensorflow::Env* env, + const tensorflow::ThreadOptions& thread_options, + const std::string& name); + + EnvThread* CreateThread(std::function f); + + Task CreateTask(TaskFunction f); + + void ExecuteTask(const Task& t); +}; + +typedef typename RunHandlerEnvironment::Task Task; +typedef Eigen::RunQueue Queue; + +// To reduce cache misses, we use a doubly-linked list of Waiter structs and +// queue them in LIFO order rather than the FIFO order used by a single +// condition variable. 
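// [Editorial note, not part of the vendored header] A standalone sketch of the
// intrusive circular-list technique described above, with hypothetical names:
// waiters are pushed at the head, so the most recently parked thread (with the
// warmest cache) is the first one woken.
struct WaiterNode {
  WaiterNode() { next = prev = this; }
  WaiterNode* next;
  WaiterNode* prev;
};

// Link `w` right after the sentinel `head`, i.e. LIFO order.
inline void PushFront(WaiterNode* head, WaiterNode* w) {
  w->next = head->next;
  w->prev = head;
  head->next->prev = w;
  head->next = w;
}

// Unlink `w` and make it a singleton list again.
inline void Unlink(WaiterNode* w) {
  w->next->prev = w->prev;
  w->prev->next = w->next;
  w->next = w;
  w->prev = w;
}

// To wake in LIFO order, take head->next (when head->next != head), notify its
// condition variable, and Unlink() it while holding the list mutex.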
+struct Waiter { + Waiter() { + next = this; + prev = this; + } + tensorflow::condition_variable cv; + int num_waiting_threads = 0; + tensorflow::mutex mu; + Waiter* next; + Waiter* prev; +}; + +class ThreadWorkSource { + public: + ThreadWorkSource(); + + ~ThreadWorkSource(); + + Task EnqueueTask(Task t, bool is_blocking, bool enable_wake_up); + + Task PopBlockingTask(); + + Task PopNonBlockingTask(int start_index, bool search_from_all_queue); + + int TaskQueueSize(bool is_blocking); + + int64_t GetTracemeId(); + + void SetTracemeId(int64_t value); + + void SetWaiter(uint64_t version, Waiter* waiter, tensorflow::mutex* mutex); + + int64_t GetInflightTaskCount(bool is_blocking); + + void IncrementInflightTaskCount(bool is_blocking); + + void DecrementInflightTaskCount(bool is_blocking); + + int64_t GetPendingTaskCount(); + + void IncrementPendingTaskCount(); + + void DecrementPendingTaskCount(); + + unsigned NonBlockingWorkShardingFactor(); + + std::string ToString(); + + private: + struct NonBlockingQueue { + tensorflow::mutex queue_op_mu; + char pad[128]; + Queue queue; + }; + + int32_t non_blocking_work_sharding_factor_; + Eigen::MaxSizeVector non_blocking_work_queues_; + + // The number of tasks that are executing now. + std::atomic blocking_inflight_; + std::atomic non_blocking_inflight_; + + // The number of tasks that are enqueued and not finished. + std::atomic pending_tasks_; + + Queue blocking_work_queue_; + tensorflow::mutex blocking_queue_op_mu_; + char pad_[128]; + tensorflow::mutex waiters_mu_; + Waiter queue_waiters_ TF_GUARDED_BY(waiters_mu_); + std::atomic traceme_id_; + + tensorflow::mutex run_handler_waiter_mu_; + uint64_t version_ TF_GUARDED_BY(run_handler_waiter_mu_); + tensorflow::mutex* sub_thread_pool_waiter_mu_ + TF_GUARDED_BY(run_handler_waiter_mu_); + Waiter* sub_thread_pool_waiter_ TF_GUARDED_BY(run_handler_waiter_mu_); +}; + +class RunHandlerThreadPool { + public: + struct Options { + int num_blocking_threads; + int num_non_blocking_threads; + bool wait_if_no_active_request; + int non_blocking_threads_sleep_time_micro_sec; + int blocking_threads_max_sleep_time_micro_sec; + bool use_adaptive_waiting_time; + bool enable_wake_up; + int max_concurrent_handler; + std::vector num_threads_in_sub_thread_pool; + std::vector sub_thread_request_percentage; + Options(int num_blocking_threads, int num_non_blocking_threads, + bool wait_if_no_active_request, + int non_blocking_threads_sleep_time_micro_sec, + int blocking_threads_max_sleep_time_micro_sec, + bool use_adaptive_waiting_time, bool enable_wake_up, + int max_concurrent_handler, + const std::vector& num_threads_in_sub_thread_pool, + const std::vector& sub_thread_request_percentage) + : num_blocking_threads(num_blocking_threads), + num_non_blocking_threads(num_non_blocking_threads), + wait_if_no_active_request(wait_if_no_active_request), + non_blocking_threads_sleep_time_micro_sec( + non_blocking_threads_sleep_time_micro_sec), + blocking_threads_max_sleep_time_micro_sec( + blocking_threads_max_sleep_time_micro_sec), + use_adaptive_waiting_time(use_adaptive_waiting_time), + enable_wake_up(enable_wake_up), + max_concurrent_handler(max_concurrent_handler), + num_threads_in_sub_thread_pool(num_threads_in_sub_thread_pool), + sub_thread_request_percentage(sub_thread_request_percentage) {} + }; + struct PerThread { + constexpr PerThread() : pool(nullptr), thread_id(-1) {} + RunHandlerThreadPool* pool; // Parent pool, or null for normal threads. + int thread_id; // Worker thread index in pool. 
+ }; + + RunHandlerThreadPool(Options options, tensorflow::Env* env, + const tensorflow::ThreadOptions& thread_options, + const std::string& name, + Eigen::MaxSizeVector* waiters_mu, + Eigen::MaxSizeVector* queue_waiters); + + ~RunHandlerThreadPool(); + + void Start(); + + void StartOneThreadForTesting(); + + void AddWorkToQueue(ThreadWorkSource* tws, bool is_blocking, TaskFunction fn); + + // Set work queues from which the thread 'tid' can steal its work. + void SetThreadWorkSources( + int tid, uint64_t version, + const Eigen::MaxSizeVector& thread_work_sources); + + PerThread* GetPerThread(); + + int CurrentThreadId() const; + + int NumThreads() const; + + int NumBlockingThreads() const; + + int NumNonBlockingThreads() const; + + void WorkerLoop(int thread_id, bool may_steal_blocking_work); + + // Search tasks from the request range searching_range_start to + // searching_range_end. If there are no tasks in the search range and + // may_steal_blocking_work is true, then search from all requests. + Task FindTask( + int searching_range_start, int searching_range_end, int thread_id, + int sub_thread_pool_id, int max_blocking_inflight, + bool may_steal_blocking_work, + const Eigen::MaxSizeVector& thread_work_sources, + bool* task_from_blocking_queue, ThreadWorkSource** tws); + + void WaitForWorkInSubThreadPool(int thread_id, bool is_blocking, + int sub_thread_pool_id); + + private: + struct ThreadData { + ThreadData(); + tensorflow::mutex mu; + uint64_t new_version; + tensorflow::condition_variable sources_not_empty; + std::unique_ptr thread; + int current_index; + std::unique_ptr> + new_thread_work_sources TF_GUARDED_BY(mu); + + uint64_t current_version; + // Should only be accessed by one thread. + std::unique_ptr> + current_thread_work_sources; + + int sub_thread_pool_id; + }; + + const int num_threads_; + const int num_blocking_threads_; + const int num_non_blocking_threads_; + const bool adaptive_sleep_time_; + const bool wait_if_no_active_request_; + const int non_blocking_thread_sleep_time_; + const int blocking_thread_max_waiting_time_; + const bool enable_wake_up_; + Eigen::MaxSizeVector thread_data_; + internal::RunHandlerEnvironment env_; + std::atomic cancelled_; + std::string name_; + Eigen::MaxSizeVector* waiters_mu_; + Eigen::MaxSizeVector* queue_waiters_; + + std::vector num_threads_in_sub_thread_pool_; + + // Threads in each sub thread pool will search tasks from + // the end_request_percentage of the previous sub thread pool to its own + // end_request_percentage in a round-robin fashion.
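The round-robin search-range scheme described in the comment above (and the `sub_thread_pool_end_request_percentage_` member that follows) is hard to picture from the declarations alone. The standalone sketch below is only my reading of that comment, not code from the upstream implementation:

```cpp
// Illustration only: derives the [start, end) slice of active requests that a
// given sub thread pool scans, from cumulative end-request percentages.
#include <utility>
#include <vector>

std::pair<int, int> SubThreadPoolSearchRange(
    const std::vector<double>& end_request_percentage,
    int num_active_requests, int sub_thread_pool_id) {
  const double start_percentage =
      sub_thread_pool_id == 0
          ? 0.0
          : end_request_percentage[sub_thread_pool_id - 1];
  const double end_percentage = end_request_percentage[sub_thread_pool_id];
  return {static_cast<int>(start_percentage * num_active_requests),
          static_cast<int>(end_percentage * num_active_requests)};
}
```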
+ std::vector sub_thread_pool_end_request_percentage_; +}; + +} // namespace internal + +class RunHandlerWorkQueue : public tensorflow::tfrt_stub::WorkQueueInterface { + public: + explicit RunHandlerWorkQueue(std::unique_ptr run_handler) + : WorkQueueInterface(run_handler->step_id(), + run_handler->AsIntraThreadPoolInterface()), + run_handler_(std::move(run_handler)) { + DCHECK(run_handler_); + } + ~RunHandlerWorkQueue() override = default; + + std::string name() const override { return "run_handler"; } + + int GetParallelismLevel() const override; + + void AddTask(TaskFunction work) override; + + std::optional AddBlockingTask(TaskFunction work, + bool allow_queuing) override; + + void Await( + llvm::ArrayRef> values) override; + + bool IsInWorkerThread() const override; + + void Quiesce() override { + LOG(FATAL) << "RunHandlerWorkQueue::Quiesce() is not " // Crash OK + "implemented, and supposed to be removed."; + } + + private: + std::unique_ptr run_handler_; +}; + +} // end namespace tf +} // end namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_concurrent_work_queue.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_concurrent_work_queue.h new file mode 100644 index 00000000..23dd6c86 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_concurrent_work_queue.h @@ -0,0 +1,142 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_ +#define TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/tfrt/run_handler_thread_pool/run_handler.h" +#include "tensorflow/core/tfrt/runtime/work_queue_interface.h" +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/support/thread_environment.h" // from @tf_runtime +#include "third_party/concurrent_work_queue/lib/blocking_work_queue.h" +#include "third_party/concurrent_work_queue/lib/non_blocking_work_queue.h" + +namespace tfrt { +namespace tf { + +// Concurrent Work Queue based on Run Handler thread Pool. All tasks are queued +// based on requests. +class RunHandlerThreadWorkQueue + : public tensorflow::tfrt_stub::WorkQueueInterface { + public: + struct Options { + // The number of threads used for the main thread pool. + int num_main_threads; + + // The number of threads used for complementary thread pool. + int num_complementary_threads; + + // Timeout for InitRequest(). + // The timeout may trigger as the work queue limits the number of concurrent + // in-flight requests for better latency. 
+ int64_t init_timeout_ms; + + // The maximum number of concurrent handlers. + int max_concurrent_handler = 128; + + // The number of sub thread pools configured. + int num_sub_thread_pool = 1; + + // The number of threads in each sub thread pool. The length of the vector + // should equal num_sub_thread_pool. + std::vector num_threads_in_sub_thread_pool = {1}; + + // The percentage of requests that the first N sub thread pools handle. The + // length of the vector should equal num_sub_thread_pool. + std::vector sub_thread_request_percentage = {1.0}; + + // Sleep time for non-blocking threads if there is no pending task. + int non_blocking_threads_sleep_time_micro_sec = 1000; + + // Max sleep time for blocking threads if there is no pending task and no + // new task wakes up the thread. + int blocking_threads_max_sleep_time_micro_sec = 1000; + + // If true, use adaptive waiting time. + bool use_adaptive_waiting_time = true; + + // If true, threads won't wake themselves up if there are no active requests. + bool wait_if_no_active_request = true; + + // If true, threads will be woken up by new tasks. + bool enable_wake_up = true; + }; + + explicit RunHandlerThreadWorkQueue(const Options& options); + ~RunHandlerThreadWorkQueue() override = default; + + std::string name() const override { + return tensorflow::strings::StrCat( + "RunHandlerThreadWorkQueue C++ work queue (", options_.num_main_threads, + " main threads, ", options_.num_complementary_threads, + " complementary threads)"); + } + + absl::StatusOr> + InitializeRequest(int64_t request_id) const override; + + int GetParallelismLevel() const override { + return options_.num_main_threads + options_.num_complementary_threads; + } + + void AddTask(TaskFunction work) override; + + std::optional AddBlockingTask(TaskFunction work, + bool allow_queuing) override; + + void Quiesce() override; + + void Await(ArrayRef> values) override; + + bool IsInWorkerThread() const override; + + private: + Options options_; + + // Handler pool. + // Each request acquires a handler from the pool and releases it back to the + // pool once it is done. + std::unique_ptr handler_pool_; + + // An id assigned to each request for tracing purposes. + static std::atomic_int_fast64_t step_id_counter_; + + // QuiescingState for non_blocking_work_queue_ and blocking_work_queue_. + std::unique_ptr<::tfrt::internal::QuiescingState> quiescing_state_; + + // Nonblocking queue used for cases without execution context. + ::tfrt::internal::NonBlockingWorkQueue + non_blocking_work_queue_; + + // Blocking queue used for cases without execution context. + ::tfrt::internal::BlockingWorkQueue + blocking_work_queue_; +}; + +std::ostream& operator<<(std::ostream& strm, + const RunHandlerThreadWorkQueue::Options& options); +} // namespace tf +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_CONCURRENT_WORK_QUEUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.h new file mode 100644 index 00000000..acf15a70 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.h @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_UTIL_H_ + +#include +#include +#include + +namespace tfrt { +namespace tf { + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. Return 'default_value' otherwise. +double ParamFromEnvWithDefault(const char* var_name, double default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. The value must be in format val1,val2... Return +// 'default_value' otherwise. +std::vector ParamFromEnvWithDefault(const char* var_name, + std::vector default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. The value must be in format val1,val2... Return +// 'default_value' otherwise. +std::vector ParamFromEnvWithDefault(const char* var_name, + std::vector default_value); + +// Look up environment variable named 'var_name' and return the value if it +// exist and can be parsed. Return 'default_value' otherwise. +bool ParamFromEnvBoolWithDefault(const char* var_name, bool default_value); + +} // namespace tf +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_RUN_HANDLER_THREAD_POOL_RUN_HANDLER_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/runtime.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/runtime.h new file mode 100644 index 00000000..1a6925c1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/runtime.h @@ -0,0 +1,248 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" +#include "tensorflow/core/tfrt/runtime/work_queue_interface.h" +#include "tsl/platform/errors.h" +#include "tfrt/core_runtime/core_runtime.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +// ModelRuntimeContext provides model contexts for injected backends to +// initialize their per-model states. +class ModelRuntimeContext { + public: + ModelRuntimeContext(const GraphExecutionOptions* graph_execution_options, + std::string export_dir, + tfrt::ResourceContext* resource_context) + : graph_execution_options_(graph_execution_options), + export_dir_(std::move(export_dir)), + resource_context_(resource_context) { + DCHECK(graph_execution_options_); + DCHECK(resource_context_); + } + + absl::string_view name() const { + return graph_execution_options_->model_metadata.name(); + } + int64_t version() const { + return graph_execution_options_->model_metadata.version(); + } + + absl::string_view export_dir() const { return export_dir_; } + + const GraphDef* graph_def() const { return graph_def_; } + void set_graph_def(const GraphDef* graph_def) { graph_def_ = graph_def; } + + const CallableOptions* callable_options() const { return callable_options_; } + void set_callable_options(const CallableOptions* callable_options) { + callable_options_ = callable_options; + } + + FunctionLibraryDefinition* function_library_definition() const { + return flib_def_; + } + void set_function_library_definition(FunctionLibraryDefinition* flib_def) { + flib_def_ = flib_def; + } + + tensorflow::DeviceMgr* device_mgr() const { return device_mgr_; } + void set_device_mgr(tensorflow::DeviceMgr* device_mgr) { + device_mgr_ = device_mgr; + } + + bool is_local_session() const { return is_local_session_; } + + void set_is_local_session(bool is_local_session) { + is_local_session_ = is_local_session; + } + + tfrt::ResourceContext& resource_context() { return *resource_context_; } + + const GraphExecutionOptions& graph_execution_options() const { + return *graph_execution_options_; + } + + absl::string_view checkpoint_path() const { return checkpoint_path_; } + + void set_checkpoint_path(absl::string_view checkpoint_path) { + checkpoint_path_ = checkpoint_path; + } + + private: + const GraphExecutionOptions* graph_execution_options_ = nullptr; + + std::string export_dir_; + const GraphDef* graph_def_ = nullptr; + const CallableOptions* callable_options_ = nullptr; + tfrt::ResourceContext* resource_context_ = nullptr; + tensorflow::DeviceMgr* device_mgr_ = nullptr; + + FunctionLibraryDefinition* flib_def_ = nullptr; + + bool is_local_session_ = false; + std::string checkpoint_path_; +}; + +// This defines the runtime 
abstraction in tensorflow for TFRT. It is supposed +// to provide tensorflow specific functionalities that are implemented using +// TFRT. Currently, the only intended uses for this class are: +// 1) Creating the runtime instance with user specified dependencies (eg. +// thread pool). +// 2) Creating tensors that can be used by the runtime. +// +// It is temporary and will be replaced by the official +// tensorflow::experimental::cc::Runtime when it lands. +class Runtime { + public: + // Creates a runtime instance with specified threading configuration. Returns + // null upon creation error. + static std::unique_ptr Create(int num_inter_op_threads, + int num_intra_op_threads = 0); + + // Creates a runtime instance with the specified work_queue. Returns null upon + // creation error. + static std::unique_ptr Create( + std::unique_ptr work_queue); + + ~Runtime(); + Runtime(Runtime&&) = default; + Runtime& operator=(Runtime&&) = default; + + // TODO(tfrt-devs): Add methods for creating TFRT tensors. + + // TODO(chky): Make this method private as it should be only used by + // tfrt::SavedModel. Simply making tfrt::SavedModel a friend class does not + // work because the it resides in a different namespace. But we should + // consider moving it to the same namespace. + tfrt::CoreRuntime* core_runtime() const { return core_runtime_.get(); } + WorkQueueInterface* work_queue() const { return work_queue_; } + + // `AddCreateRuntimeResourceFn` allows the client to inject per model + // resources that are related to system-wide concepts, such as devices, when + // loading a SavedModel. + // + // A longer term plan is to use a Device concept for this purpose, so that + // Runtime contains a vector of Devices. Since it will take some time to + // iterate on the Device concept and integrate with the existing + // `tfrt::Device` class, we use the callback function as a temporary solution. + // + // The argument `fn` should be thread-safe. + void AddCreateRuntimeResourceFn( + std::function fn) { + runtime_resource_fns_.emplace_back( + [fn = std::move(fn)](ModelRuntimeContext& model_context) { + fn(&model_context.resource_context()); + return absl::OkStatus(); + }); + } + + void AddCreateRuntimeResourceFn( + std::function fn) { + runtime_resource_fns_.emplace_back(std::move(fn)); + } + + // `CreateRuntimeResources` populates `resource_ctx` with runtime-related + // resources. + // + // This function is thread-safe. + absl::Status CreateRuntimeResources( + ModelRuntimeContext& model_context) const { + for (auto& fn : runtime_resource_fns_) { + TF_RETURN_IF_ERROR(fn(model_context)); + } + return absl::OkStatus(); + } + + ABSL_DEPRECATED("Use the overload that take ModelRuntimeContext instead.") + void CreateRuntimeResources(const GraphExecutionOptions& options, + tfrt::ResourceContext* resource_ctx) const { + ModelRuntimeContext model_context( + &options, options.compile_options.saved_model_dir, resource_ctx); + for (auto& fn : runtime_resource_fns_) { + auto status = fn(model_context); + if (!status.ok()) { + LOG(ERROR) << "Failed to create runtime resource: " << status; + return; + } + } + } + + void SetCreateRequestQueueFn( + std::function< + absl::StatusOr>(int64_t)> + create_request_queue_fn) { + create_request_queue_fn_ = std::move(create_request_queue_fn); + } + + // Creates a work queue for a request. 
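As an orientation aid for this vendored runtime API, here is a minimal, hypothetical setup sketch (not part of runtime.h). It assumes only the declarations above: `Runtime::Create(int, int)` returning `std::unique_ptr<Runtime>` and the `ModelRuntimeContext`-based overload of `AddCreateRuntimeResourceFn`.

```cpp
// Hypothetical usage sketch, not upstream code.
#include <memory>

#include "absl/status/status.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"

std::unique_ptr<tensorflow::tfrt_stub::Runtime> MakeRuntime() {
  // Create a runtime backed by a default work queue with 4 inter-op threads.
  auto runtime = tensorflow::tfrt_stub::Runtime::Create(
      /*num_inter_op_threads=*/4);
  if (!runtime) return nullptr;  // Create() returns null on error.

  // Register a thread-safe callback that populates per-model resources when a
  // SavedModel is loaded.
  runtime->AddCreateRuntimeResourceFn(
      [](tensorflow::tfrt_stub::ModelRuntimeContext& model_context)
          -> absl::Status {
        // Per-model resources would be created via
        // model_context.resource_context() here.
        return absl::OkStatus();
      });
  return runtime;
}
```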
+ absl::StatusOr> CreateRequestQueue( + int64_t request_id) const { + if (create_request_queue_fn_) { + return create_request_queue_fn_(request_id); + } + + return work_queue_->InitializeRequest(request_id); + } + + private: + explicit Runtime(std::unique_ptr core_runtime, + WorkQueueInterface* work_queue); + + std::unique_ptr core_runtime_; + std::function>(int64_t)> + create_request_queue_fn_; + WorkQueueInterface* work_queue_ = nullptr; + std::vector> + runtime_resource_fns_; +}; + +// Get a singleton instance of tfrt_stub::Runtime. Returns nullptr until +// SetGlobalRuntime has been called. +// Not thread safe. +Runtime* GetGlobalRuntime(); + +// Instantiates the singleton instance of tfrt_stub::Runtime by transferring +// an instance of tfrt_stub::Runtime. +// Not thread safe. +void SetGlobalRuntime(std::unique_ptr runtime); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_RUNTIME_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/step_id.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/step_id.h new file mode 100644 index 00000000..f9de1a7d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/step_id.h @@ -0,0 +1,110 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_ + +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/core/tfrt/kernels/stream_ops_util_constants.h" + +namespace tensorflow { +namespace tfrt_stub { + +// A base template for common utilities for a type safe id. +template +struct SafeId { + SafeId() : id(0) {} + explicit constexpr SafeId(int64_t id) : id(id) {} + + using Base = SafeId; + + int64_t id; + + friend bool operator==(const Derived& x, const Derived& y) { + return x.id == y.id; + } + + template + friend void AbslStringify(Sink& sink, const Derived& x) { + absl::Format(&sink, "%d", x.id); + } + + template + friend H AbslHashValue(H h, const Derived& x) { + return H::combine(std::move(h), x.id); + } +}; + +// A type-safe step id. +struct StepId : SafeId { + using Base::Base; + + bool valid() const { return id != 0; } + static constexpr StepId GetInvalidStepId() { return StepId(0); } +}; + +// The initial value of the step id. +std::atomic& GetGlobalInitialStepId(); + +// StepIdGenerator provides the utility to generate a monotonically increasing +// step id. And the number of bits can be configured at compile time. The step +// id is positive and the maximum value is 2^(kStepIdBitSize)-1. +class StepIdGenerator { + public: + StepIdGenerator() : next_id_(GetGlobalInitialStepId().load()) {} + + StepIdGenerator(const StepIdGenerator&) = delete; + StepIdGenerator& operator=(const StepIdGenerator&) = delete; + + // Generates a positive step id that is within the bit-range specified by + // `kStepIdBitSize`. 
+ StepId GetNextStepId() { + uint64_t next_id = next_id_.fetch_add(1, std::memory_order_relaxed); + // Use kStepIdBitSize bits because we need to pack it with batch id if batch + // function is used. + static_assert(kStepIdBitSize <= 32); + next_id = (next_id & ((1ull << kStepIdBitSize) - 1)); + + if (next_id == 0) { + return GetNextStepId(); + } + + return StepId(static_cast(next_id)); + } + + private: + std::atomic next_id_{0}; +}; + +// Set up the initial step_id used by StepIdGenerator. This class is +// test-only. +class TEST_ScopedInitialStepId { + public: + explicit TEST_ScopedInitialStepId(uint64_t step_id); + ~TEST_ScopedInitialStepId(); + + TEST_ScopedInitialStepId(const TEST_ScopedInitialStepId&) = delete; + TEST_ScopedInitialStepId& operator=(const TEST_ScopedInitialStepId&) = delete; + + private: + uint64_t step_id_ = 0; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/stream.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/stream.h new file mode 100644 index 00000000..03b0784b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/stream.h @@ -0,0 +1,285 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and limitations +under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/runtime/step_id.h" +#include "tsl/platform/env.h" +#include "tsl/platform/threadpool_interface.h" + +namespace tensorflow { +namespace tfrt_stub { + +struct StreamedResult { + absl::flat_hash_map tensors; + absl::Time enqueued_time; +}; + +struct StreamCallbackId : SafeId { + using Base::Base; +}; + +// An interface that abstracts communication between the +// `StreamCallbackRegistry` and the stream controller backend. 
+class StreamControllerInterface { + public: + explicit StreamControllerInterface(std::string controller_address) + : controller_address_(std::move(controller_address)) {} + virtual ~StreamControllerInterface() = default; + + absl::string_view controller_address() const { return controller_address_; } + + virtual void RecordDequeueLatency(absl::string_view model_name, + absl::Duration latency) {} + + virtual void RecordCallbackLatency(absl::string_view model_name, + absl::Duration latency) {} + + private: + std::string controller_address_; +}; + +// An interface that abstracts the communication from the `PwStreamResultsOp` +// worker to the controller. +class StreamWorkerInterface { + public: + explicit StreamWorkerInterface(std::string controller_address) + : controller_address_(std::move(controller_address)) {} + virtual ~StreamWorkerInterface() = default; + + absl::string_view controller_address() const { return controller_address_; } + + virtual void RecordSendLatency(absl::string_view model_name, + absl::Duration latency) {} + virtual absl::Status InvokeStreamCallback( + const StreamCallbackId& callback_id, + const std::vector& names, + const std::vector>>& + responses) = 0; + + private: + std::string controller_address_; +}; + +class ScopedStreamCallback; + +class StreamInterfaceFactory { + public: + using CreateWorkerStreamInterfaceFn = + std::function>( + absl::string_view)>; + + void RegisterController( + absl::AnyInvocable< + absl::StatusOr>() const> + interface_factory) { + absl::MutexLock lock(&mu_); + controller_interface_factory_ = std::move(interface_factory); + } + + absl::StatusOr> + CreateControllerStreamInterface() const { + absl::MutexLock lock(&mu_); + return controller_interface_factory_(); + } + + void RegisterWorker(CreateWorkerStreamInterfaceFn interface_factory) { + absl::MutexLock lock(&mu_); + worker_interface_factory_ = std::move(interface_factory); + } + + CreateWorkerStreamInterfaceFn CreateWorkerStreamInterface() const { + absl::MutexLock lock(&mu_); + return worker_interface_factory_; + } + + private: + mutable absl::Mutex mu_; + absl::AnyInvocable< + absl::StatusOr>() const> + controller_interface_factory_ ABSL_GUARDED_BY(mu_) = []() { + return absl::InternalError( + "The factory for StreamControllerInterface is not registered."); + }; + + CreateWorkerStreamInterfaceFn worker_interface_factory_ ABSL_GUARDED_BY(mu_) = + [](absl::string_view) { + return absl::InternalError( + "The factory for StreamWorkerInterface is not registered."); + }; +}; + +// Returns the global factory for the stream interface. The factory for the +// stream interface must be registered first before calling +// GetGlobalStreamCallbackRegistry(). +StreamInterfaceFactory& GetGlobalStreamInterfaceFactory(); + +// Mapping from tuples of (callback_id, step_id) to callback states. The mapping +// is stored in a global variable so that it can be shared between +// `ScopedStreamCallback` and `InvokeStreamCallbackOp`. +// +// This class is thread-safe. +class StreamCallbackRegistry { + public: + explicit StreamCallbackRegistry( + std::unique_ptr interface) + : interface_(std::move(interface)) { + DCHECK(interface_); + } + + // Registers a callback under the given id. A stream callback is uniquely + // identified by a tuple of a callback id (unique to each executable) and a + // step id (unique to each invocation of a given executable). Returns an RAII + // object that removes the callback from the registry on its deallocation, or + // an error if the id already exists in the registry. 
+ // + // If a program runs `tf.PwStreamResults` with a matching callback/step id, + // `callback` will be called with the arguments of `tf.PwStreamResults`. + // + // All invocations to `callback` are handled serially by a single thread, so + // `callback` doesn't need to be thread-safe even if multiple + // `tf.PwStreamResults` ops may run concurrently. + absl::StatusOr Register( + absl::string_view model_name, StreamCallbackId callback_id, + StepId step_id, + absl::AnyInvocable< + void(absl::flat_hash_map)> + callback); + + absl::Status Invoke(tsl::thread::ThreadPoolInterface* thread_pool, + StreamCallbackId callback_id, StepId step_id, + StreamedResult result); + + StreamControllerInterface& stream_interface() const { return *interface_; } + + private: + friend class ScopedStreamCallback; + + class CallbackState { + public: + CallbackState(StreamCallbackRegistry* registry, + absl::string_view model_name, StreamCallbackId callback_id, + StepId step_id, + absl::AnyInvocable)> + callback) + : registry_(registry), + model_name_(model_name), + callback_id_(callback_id), + step_id_(step_id), + callback_(std::move(callback)) { + DCHECK(registry_); + } + + // Invokes the callback in `thread_pool` with `result`. + absl::Status Invoke(tsl::thread::ThreadPoolInterface* thread_pool, + StreamedResult result); + + // Closes the callback so that it can no longer be invoked. This method also + // waits for outstanding results to finish. + void Close(); + + private: + StreamControllerInterface& interface() { + return registry_->stream_interface(); + } + void InvokeCallback(StreamedResult result); + + StreamCallbackRegistry* registry_ = nullptr; + std::string model_name_; + StreamCallbackId callback_id_; + StepId step_id_; + absl::AnyInvocable)> + callback_; + + absl::Mutex mu_; + bool closed_ ABSL_GUARDED_BY(mu_) = false; + int num_outstanding_ ABSL_GUARDED_BY(mu_) = 0; + }; + + std::unique_ptr Unregister(StreamCallbackId callback_id, + StepId step_id); + + std::unique_ptr interface_; + + mutable absl::Mutex mu_; + absl::flat_hash_map, + std::unique_ptr> + stream_callbacks_ ABSL_GUARDED_BY(mu_); +}; + +// Returns the global registry for the stream callbacks. The stream interface +// must have been registered through GetGlobalStreamInterfaceFactory() before +// calling this function. +StreamCallbackRegistry& GetGlobalStreamCallbackRegistry(); + +// Creates a new stream callback id and rewrites the given module with +// information required to trigger this callback remotely. Returns the callback +// id, or `std::nullopt` if the module has no stream outputs. +absl::StatusOr> CreateStreamCallbackId( + absl::string_view model_name, mlir::ModuleOp module); + +// Implements an RAII object that registers a callback to be called on receiving +// streamed tensors. +class ScopedStreamCallback { + public: + ScopedStreamCallback() = default; + + // Moveable but not copyable. 
+ ScopedStreamCallback(ScopedStreamCallback&& other); + ScopedStreamCallback& operator=(ScopedStreamCallback&& other); + + ~ScopedStreamCallback() { Unregister(); } + + private: + friend class StreamCallbackRegistry; + + explicit ScopedStreamCallback(StreamCallbackRegistry* registry, + StreamCallbackId callback_id, StepId step_id) + : registry_(registry), callback_id_(callback_id), step_id_(step_id) {} + + void Unregister(); + + StreamCallbackRegistry* registry_ = nullptr; + std::optional callback_id_; + StepId step_id_ = StepId::GetInvalidStepId(); +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_STREAM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h new file mode 100644 index 00000000..be7acaee --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h @@ -0,0 +1,90 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_TF_THREADPOOL_CONCURRENT_WORK_QUEUE_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_TF_THREADPOOL_CONCURRENT_WORK_QUEUE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tensorflow/core/tfrt/runtime/work_queue_interface.h" +#include "tfrt/host_context/async_value.h" // from @tf_runtime +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/task_function.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +// This class defines a work queue based on the WorkQueueInterface that uses the +// Tensorflow threadpools to execute inter-op and intra-op closures. 
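Reviewer note: the work queue declared next is typically wired into the Runtime from runtime.h above. The sketch below is hypothetical, not upstream code; it relies on `CreateDefaultTfThreadPoolWorkQueue`, declared at the end of this header, and on the `Runtime::Create` overload that takes a work queue.

```cpp
// Hypothetical wiring sketch, not upstream code.
#include <memory>
#include <utility>

#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h"

std::unique_ptr<tensorflow::tfrt_stub::Runtime> MakeRuntimeWithTfThreadPools() {
  // Both thread counts must be larger than zero, per the factory's contract.
  auto work_queue = tensorflow::tfrt_stub::CreateDefaultTfThreadPoolWorkQueue(
      /*num_inter_op_threads=*/4, /*num_intra_op_threads=*/4);
  return tensorflow::tfrt_stub::Runtime::Create(std::move(work_queue));
}
```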
+class TfThreadPoolWorkQueue : public WorkQueueInterface { + public: + TfThreadPoolWorkQueue( + tensorflow::thread::ThreadPoolInterface* intra_op_threadpool, + tensorflow::thread::ThreadPoolInterface* inter_op_threadpool) + : TfThreadPoolWorkQueue(/*id=*/0, intra_op_threadpool, + inter_op_threadpool) {} + + TfThreadPoolWorkQueue( + int64_t id, tensorflow::thread::ThreadPoolInterface* intra_op_threadpool, + tensorflow::thread::ThreadPoolInterface* inter_op_threadpool) + : WorkQueueInterface(id, intra_op_threadpool), + intra_op_threadpool_(intra_op_threadpool), + inter_op_threadpool_(inter_op_threadpool) {} + + absl::StatusOr> InitializeRequest( + int64_t request_id) const override; + + int GetParallelismLevel() const override { + return inter_op_threadpool_->NumThreads(); + } + std::string name() const override { return "TfThreadPoolWorkQueue"; } + + void AddTask(tfrt::TaskFunction work) override; + + std::optional AddBlockingTask( + tfrt::TaskFunction work, bool allow_queuing) override; + + ABSL_DEPRECATED("Use the destructor instead.") + void Quiesce() override; + + void Await( + tfrt::ArrayRef<::tfrt::RCReference<::tfrt::AsyncValue>> values) override; + + bool IsInWorkerThread() const override; + + private: + tensorflow::thread::ThreadPoolInterface* intra_op_threadpool_ = nullptr; + tensorflow::thread::ThreadPoolInterface* inter_op_threadpool_ = nullptr; +}; + +// Create a default TfThreadPoolWorkQueue that is implemented by +// tensorflow::thread::ThreadPool. `num_inter_op_threads` and +// `num_intra_op_threads` must be larger than zero. +std::unique_ptr CreateDefaultTfThreadPoolWorkQueue( + int num_inter_op_threads, int num_intra_op_threads); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_TF_THREADPOOL_CONCURRENT_WORK_QUEUE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/work_queue_interface.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/work_queue_interface.h new file mode 100644 index 00000000..4eca1d72 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/runtime/work_queue_interface.h @@ -0,0 +1,113 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_WORK_QUEUE_INTERFACE_H_ +#define TENSORFLOW_CORE_TFRT_RUNTIME_WORK_QUEUE_INTERFACE_H_ + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/context.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/threadpool_interface.h" +#include "tensorflow/core/profiler/lib/connected_traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime +#include "tfrt/support/error_util.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +// This is an intermediate interface in tensorflow for injecting thread pool +// implementation into TFRT. We can add savedmodel/tensorflow specific +// methods (eg. create an intra op thread pool) without changing TFRT core. +class WorkQueueInterface : public tfrt::ConcurrentWorkQueue { + public: + WorkQueueInterface() = default; + explicit WorkQueueInterface(int64_t id) : id_(id) {} + explicit WorkQueueInterface(int64_t id, + thread::ThreadPoolInterface* intra_op_threadpool) + : id_(id), intra_op_threadpool_(intra_op_threadpool) {} + ~WorkQueueInterface() override = 0; + + int64_t id() const { return id_; } + + thread::ThreadPoolInterface* GetIntraOpThreadPool() const { + return intra_op_threadpool_; + } + + // Returns per-request work queue if possible. A nullptr should be returned if + // the implementation does not implement the per-request work queue. + // + // TODO(b/198671794): Remove per-request concepts from the work queue + // interface so that the interface is more composable. Per-request logic + // should be handled separately. + ABSL_DEPRECATED("Create the instance directly instead.") + virtual absl::StatusOr> InitializeRequest( + int64_t request_id) const { + return {nullptr}; + } + + private: + int64_t id_ = 0; + thread::ThreadPoolInterface* intra_op_threadpool_ = nullptr; +}; + +inline WorkQueueInterface::~WorkQueueInterface() = default; + +// Creates a WorkQueueInterface from a ConcurrentWorkQueue. The returned +// WorkQueueInterface simply delegates all its public methods to the specified +// ConcurrentWorkQueue. +std::unique_ptr WrapDefaultWorkQueue( + std::unique_ptr work_queue); + +// Creates a WorkQueueInterface from a ConcurrentWorkQueue. The returned +// WorkQueueInterface simply delegates all its public methods to the specified +// ConcurrentWorkQueue. The `intra_thread_pool` is stored and will be passed out +// when `InitializeRequest()` is called. +std::unique_ptr WrapDefaultWorkQueue( + std::unique_ptr work_queue, + thread::ThreadPoolInterface* intra_thread_pool); + +// A helper function that wraps tasks with traceme events. 
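The helper announced by the comment above (and defined immediately below) is most easily understood from a call site. The one here is hypothetical, assuming `WrapWork` deduces the closure type and that `AddTask` on `WorkQueueInterface` accepts the returned `tfrt::TaskFunction`:

```cpp
// Hypothetical call site, not upstream code.
#include <cstdint>
#include <utility>

#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"

void SubmitTracedTask(tensorflow::tfrt_stub::WorkQueueInterface& queue,
                      int64_t step_id) {
  // Wraps the closure so matching producer/consumer trace events are emitted
  // around its execution.
  tfrt::TaskFunction task = tensorflow::tfrt_stub::WrapWork(
      step_id, /*name=*/"example_task", [] { /* actual work */ });
  queue.AddTask(std::move(task));
}
```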
+template +tfrt::TaskFunction WrapWork(int64_t id, absl::string_view name, + Callable&& work) { + tensorflow::Context context(tensorflow::ContextKind::kThread); + tsl::profiler::TraceMeProducer producer( + [&]() { return absl::StrCat("producer_", name); }, + tsl::profiler::ContextType::kTfrtExecutor); + return tfrt::TaskFunction([traceme_id = producer.GetContextId(), + name = std::string(name), + context = std::move(context), + work = std::forward(work)]() mutable { + tsl::profiler::TraceMeConsumer consumer( + [&]() { return absl::StrCat("consumer_", name); }, + tsl::profiler::ContextType::kTfrtExecutor, traceme_id, + tsl::profiler::TraceMeLevel::kInfo); + tensorflow::WithContext wc(context); + std::forward(work)(); + }); +} + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_RUNTIME_WORK_QUEUE_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.h new file mode 100644 index 00000000..cddf15cc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/python/saved_model_load_and_run.h @@ -0,0 +1,49 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_PYTHON_SAVED_MODEL_LOAD_AND_RUN_H_ +#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_PYTHON_SAVED_MODEL_LOAD_AND_RUN_H_ + +#include + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" +#include "tensorflow/core/tfrt/saved_model/saved_model.h" + +namespace tensorflow::tfrt_stub { + +absl::StatusOr> LoadSavedModel( + absl::string_view saved_model_dir, + const std::unordered_set& tags); + +std::vector RunConvertor(PyObject* args); + +absl::Status Run( + SavedModel* saved_model, + const tensorflow::tfrt_stub::GraphExecutionRunOptions& run_options, + absl::string_view name, const std::vector& inputs, + std::vector* outputs); +} // namespace tensorflow::tfrt_stub + +#endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_PYTHON_SAVED_MODEL_LOAD_AND_RUN_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model.h new file mode 100644 index 00000000..b4c050aa --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model.h @@ -0,0 +1,353 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_ +#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" +#include "tensorflow/core/tfrt/graph_executor/graph_executor.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/saved_model/saved_model_util.h" +#include "tsl/platform/protobuf.h" +#include "tfrt/host_context/function.h" // from @tf_runtime +#include "tfrt/host_context/request_deadline_tracker.h" // from @tf_runtime +#include "tfrt/host_context/resource_context.h" // from @tf_runtime + +namespace tfrt { +class BEFFile; +class HostContext; +} // namespace tfrt + +namespace tensorflow { +namespace tfrt_stub { + +class FunctionMetadata { + public: + explicit FunctionMetadata(const internal::Signature* signature) + : signature_(signature) { + assert(signature); + } + + const std::vector& GetInputNames() const { + return signature_->input_names; + } + + const std::vector& GetInputSpecs() const { + return signature_->input_specs; + } + + const std::vector& GetOutputNames() const { + return signature_->output_names; + } + + const std::vector& GetOutputSpecs() const { + return signature_->output_specs; + } + + const protobuf::Map& GetDefaultInputs() const { + return signature_->default_inputs; + } + + private: + friend class SavedModelImpl; + + const internal::Signature* signature_ = nullptr; +}; + +// SavedModel represents the in-memory states (graphs and variables) loaded from +// a tensorflow saved model directory. +class SavedModel { + public: + struct Options { + explicit Options(const Runtime* rt) : graph_execution_options(rt) {} + + // If true, the loading of any signature (or signature combination) will be + // deferred until the first corresponding invocationof running. Otherwise, + // the individual signatures will be loaded along with the saved model. + bool enable_lazy_loading = false; + + // If true, we'll attempt to find MLArchive within the given loading path. + // If not found, will use the path as a normal SavedModel directory. + // + // This field is deprecated. + bool maybe_load_from_mla = false; + + // If true, the lazy loading path will use tfrt_stub::GraphExecutor. + // + // TODO(b/216379787): Remove this option once b/279197040 is unblocked. + bool lazy_loading_use_graph_executor = false; + + // True if and only if SavedModel is being loaded to generate AOT results. + bool aot_generation = false; + + // Make a best-effort guess at the model type and emit a metric. 
E.g. + // detecting JAX models by looking for the `XlaCallModule` op in the + // MetaGraphDef. + bool emit_model_type_metric = false; + + GraphExecutionOptions graph_execution_options; + }; + + // Per-request options. + using RunOptions = GraphExecutionRunOptions; + + explicit SavedModel(const Runtime* runtime) : options_(runtime) { + DCHECK(runtime); + } + explicit SavedModel(Options options, + std::unique_ptr graph_executor) + : options_(std::move(options)), + graph_executor_(std::move(graph_executor)) {} + virtual ~SavedModel(); + + const SessionMetadata& model_metadata() const { + return options_.graph_execution_options.model_metadata; + } + + const Runtime& runtime() const { + DCHECK(options_.graph_execution_options.runtime); + return *options_.graph_execution_options.runtime; + } + tfrt::HostContext* GetHostContext() const; + + GraphExecutor& graph_executor() const { return *graph_executor_; } + + // Returns meta graph def. Note that the graph_def field in the MetaGraphDef + // has already been removed. + // + // TODO(b/191931702): Change the method to return SignatureDefs instead. + virtual const tensorflow::MetaGraphDef& GetMetaGraphDef() const = 0; + + // Returns all the function names. + virtual std::vector GetFunctionNames() const = 0; + + // Returns the `FunctionMetadata` for a function. If the function is not + // found, returns nullopt instead. + virtual std::optional GetFunctionMetadata( + absl::string_view func_name) const = 0; + + // Runs the signature specified by `name`. Both `inputs` and `outputs` + // are all host tensors. The `outputs` must be non-null. If the returned + // status is non-OK, the `outputs` are invalid. + virtual absl::Status Run(const RunOptions& run_options, + absl::string_view name, + absl::Span inputs, + std::vector* outputs) = 0; + + // Runs the signatures specified by `names`. Both `inputs` and `outputs` are + // all host tensors. The `outputs` must be non-null. If the returned status is + // non-OK, the `outputs` are invalid. + // + // NOTE: If the given signatures have overlapping input nodes, the input + // tensors for these overlapping nodes must be the same. Having different + // input tensors for overlapping nodes results UNDEFINED BEHAVIOR. + // + // NOTE: The input/output tensors can only be dense tensors (as opposed to + // sparse tensors or composite tensors). + virtual absl::Status RunMultipleSignatures( + const RunOptions& run_options, absl::Span names, + absl::Span> multi_inputs, + std::vector>* multi_outputs) = 0; + + // Runs the graphs specified by the tensor names terminal tensors (eg. feed + // tensors, fetch tesnors) in the graph. + virtual absl::Status RunByTensorNames( + const RunOptions& run_options, + absl::Span> inputs, + absl::Span output_tensor_names, + absl::Span target_node_names, + std::vector* outputs) = 0; + + protected: + const FallbackState& fallback_state() const { + return graph_executor_->fallback_state(); + } + FallbackState& fallback_state() { return graph_executor_->fallback_state(); } + + const Options options_; + std::unique_ptr graph_executor_; +}; + +using SignatureMap = absl::flat_hash_map; +using ::tensorflow::StatusOr; + +class SavedModelImpl final : public SavedModel { + public: + struct JoinedSignature; + + // Loads all SignatureDefs in a MetaGraphDef that matches the `tags` in the + // tensorflow saved model from `saved_model_dir`. Refer to + // http://g3doc/learning/serving/g3doc/saved_model/overview.md + // for explanations on SavedModel. 
+ // + // If `options.maybe_load_from_mla` is true, tries opening `saved_model_dir` + // as an MLA. If it's not an MLA, uses it as a normal SavedModel directory. + static absl::StatusOr> LoadSavedModel( + Options options, absl::string_view saved_model_dir, + const std::unordered_set& tags); + + // Loads all SignatureDefs in `meta_graph_def`. Refer to + // http://g3doc/learning/serving/g3doc/saved_model/overview.md + // for explanations on SavedModel. + static absl::StatusOr> LoadSavedModel( + Options options, tensorflow::MetaGraphDef meta_graph_def, + absl::string_view saved_model_dir); + + SavedModelImpl( + Options options, SymbolUids symbol_uids, + tensorflow::MetaGraphDef meta_graph_def, tfrt::BefBuffer bef, + tfrt::RCReference bef_file, mlrt::bc::Buffer bytecode, + std::optional loaded_executable, + absl::flat_hash_map signatures, + std::unique_ptr runner_table, + std::unique_ptr resource_array, + std::unique_ptr graph_executor); + + ~SavedModelImpl() override = default; + + SavedModelImpl(const SavedModelImpl&) = delete; + SavedModelImpl& operator=(const SavedModelImpl&) = delete; + + const tensorflow::MetaGraphDef& GetMetaGraphDef() const override; + + std::vector GetFunctionNames() const override; + + std::optional GetFunctionMetadata( + absl::string_view func_name) const override; + + absl::Status Run(const RunOptions& run_options, absl::string_view name, + absl::Span inputs, + std::vector* outputs) override; + + absl::Status RunMultipleSignatures( + const RunOptions& run_options, absl::Span names, + absl::Span> multi_inputs, + std::vector>* multi_outputs) override; + + absl::Status RunByTensorNames( + const RunOptions& run_options, + absl::Span> inputs, + absl::Span output_tensor_names, + absl::Span target_node_names, + std::vector* outputs) override; + + private: + // The result of loading signature(s). + struct LoadingResult { + std::string name; + SymbolUids symbol_uids; + + // For the MLRT path. + mlrt::bc::Buffer bytecode_buffer; + std::unique_ptr bytecode_executable; + + // For the TFRT path. + tfrt::BefBuffer bef; + tfrt::RCReference bef_file; + + std::unique_ptr runner_table; + std::unique_ptr resource_array; + + // There are some resources that need re-creating when the executable is + // re-created, so a resource context is stored along with the executable. + // This resource context is meant to be passed to the op kernels for their + // references. See the comment above `GraphExecutor::resource_context_` + // about the todo to merge that resource context with this one. + std::unique_ptr resource_context; + }; + + // Imports a subgraph as an MLIR module with the specified `input_nodes`, + // `output_nodes`. + absl::StatusOr> ImportSubgraph( + mlir::MLIRContext* context, absl::string_view name, + const tensorflow::GraphImportConfig::InputArrays& input_nodes, + const std::vector& output_nodes, + const std::vector& target_nodes); + + // Given the joined signature, loads the subgraph and returns loading result. + absl::StatusOr> + LoadJoinedSignature(const JoinedSignature& joined_signature) + TF_EXCLUSIVE_LOCKS_REQUIRED(loading_result_cache_mu_); + + // Returns the loading result given the signature names. + absl::StatusOr> + GetOrCreateLoadingResult(const RunOptions& run_options, + absl::Span names) + TF_LOCKS_EXCLUDED(loading_result_cache_mu_); + + SymbolUids symbol_uids_; + // `meta_graph_def_` only contains metadata of the model. The graph_def field + // is removed. + // + // TODO(b/191931702): We should only keep content that are actually used + // (eg. 
SignatureDefs), instead of keeping the whole saved model, to avoid + // unnecessary memory usage. + tensorflow::MetaGraphDef meta_graph_def_; + tfrt::BefBuffer bef_; + tfrt::RCReference bef_file_; + + mlrt::bc::Buffer bytecode_; + std::optional loaded_executable_; + + tfrt::RequestDeadlineTracker req_deadline_tracker_; + absl::flat_hash_map signatures_; + std::unique_ptr runner_table_; + std::unique_ptr resource_array_; + tensorflow::mutex loading_result_cache_mu_; + // For pointer stability of values in `absl::flat_hash_map<>`, additional + // `std::unique_ptr<>` is necessary. (See https://abseil.io/tips/136.) + absl::flat_hash_map> + loading_result_cache_ TF_GUARDED_BY(loading_result_cache_mu_); +}; + +class SavedModelMiraImpl; + +} // namespace tfrt_stub +} // namespace tensorflow + +namespace tfrt { + +using SavedModel = ::tensorflow::tfrt_stub::SavedModel; +using SavedModelImpl = ::tensorflow::tfrt_stub::SavedModelImpl; +using SavedModelMiraImpl = ::tensorflow::tfrt_stub::SavedModelMiraImpl; +using TensorSpec = ::tensorflow::tfrt_stub::TensorSpec; +using FunctionMetadata = ::tensorflow::tfrt_stub::FunctionMetadata; + +namespace internal { +using Signature = ::tensorflow::tfrt_stub::internal::Signature; +} + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h new file mode 100644 index 00000000..27db2c92 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_aot_compile.h @@ -0,0 +1,105 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_ +#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "tensorflow/compiler/jit/device_compilation_cluster_signature.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/service/compiler.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime + +namespace tensorflow::tfrt_stub { +struct AotOptions { + AotOptions(); + std::unordered_set tags = {}; + std::shared_ptr graph_execution_options; + // TODO(b/296466237): support compiling for multiple signature functions. 
+ // The signature name to be AOT compiled. + std::string signature_name; +}; + +struct AotResult { + using ExecutableMap = + absl::flat_hash_map; + std::variant buffer; + // TODO(b/296466237): Investigate whether the whole FunctionDefLibrary should + // be put here. + // XLA cluster functions corresponding to `XlaLaunch` op, generated during + // bridge. + std::vector xla_functions; +}; + +// AOT compiles saved_model in input_model_dir and returns AotResult, otherwise +// returns error. +absl::StatusOr AotCompileSavedModel( + absl::string_view input_model_dir, AotOptions aot_options = {}); + +// TODO(b/296466237): Add unit test. +// Runs bridge and compiles the generated XLA functions corresponding to the +// signature function with name `siganture_name` in MetaGraphDef. +// `input_shapes` maps input signature node name to its tensor shape, and is +// used to make up for the missing input shape information in the graph if any +// so that shape inference pass in bridge can proceed correctly. Returns +// AotResult::ExecutableMap as compilation result, which maps function +// signatures to serialized executables. +absl::StatusOr AotCompileXlaFunctionsInMetaGraphDef( + const MetaGraphDef& meta_graph_def, const std::string& signature_name, + const absl::flat_hash_map& + input_shapes, + const tensorflow::FunctionDefLibrary& fdef_lib, + const tensorflow::SessionOptions& session_options, + const mlir::DialectRegistry& registry, const AotOptions& aot_options, + absl::string_view input_model_dir, ModelRuntimeContext& model_context); + +// TODO(b/296466237): make this function general for all devices. +// AOT compiles `function` into PjRtExecutable. It is the counterpart of the JIT +// version `CompileToPjRtLoadedExecutable`. `compilation_result` contains the +// generated XLA computation. +absl::StatusOr> +AotCompileToGpuPjRtExecutable( + const FunctionLibraryDefinition* flib_def, const NameAttrList& function, + int graph_def_version, const std::vector& args, + bool has_ref_vars, bool may_alias_resource_update, + const stream_executor::GpuTargetConfigProto& gpu_target_config, + XlaCompiler::CompilationResult** compilation_result); + +// Returns serialized PJRT loaded GPU executable. This function requires GPU +// device to be present during compilation. +absl::StatusOr AotCompileToGpuPjRtLoadedExecutableWithDevice( + const FunctionLibraryDefinition* flib_def, const NameAttrList& function, + int graph_def_version, const std::vector& args, + bool has_ref_vars, bool may_alias_resource_update, + XlaCompiler::CompilationResult** compilation_result); +} // namespace tensorflow::tfrt_stub + +#endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_AOT_COMPILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_import_input.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_import_input.h new file mode 100644 index 00000000..5ff375a0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_import_input.h @@ -0,0 +1,67 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_IMPORT_INPUT_H_ +#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_IMPORT_INPUT_H_ + +#include +#include + +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/graph_executor/config.h" +#include "tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h" + +namespace tensorflow { +namespace tfrt_stub { + +// TfrtSavedModelMLIRImportInput implements SavedModelMLIRImportInput, so that +// it can perform customization (eg. Placer and Grappler) on the input graph to +// the MLIR importer. +class TfrtSavedModelMLIRImportInput : public SavedModelMLIRImportInput { + public: + static absl::StatusOr Create( + const FallbackState& fallback_state, const MetaGraphDef* meta_graph_def, + const GraphDebugInfo& debug_info, + bool run_placer_grappler_on_nested_functions = false, + tensorflow::tfrt_stub::RuntimeConfig* runtime_config = nullptr); + + TfrtSavedModelMLIRImportInput( + const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info, + std::unique_ptr graph_execution_state); + + absl::StatusOr GetSubGraph( + absl::string_view name, GraphImportConfig& graph_import_config) override; + + // Return the time used by grappler. + absl::Duration GetGrapplerDuration() const { return grappler_duration_; } + + // Return the time used by functionalization. + absl::Duration GetFunctionalizationDuration() const { + return functionalization_duration_; + } + + private: + std::unique_ptr graph_execution_state_; + absl::flat_hash_map> + optimized_graphs_; + + absl::Duration functionalization_duration_; + absl::Duration grappler_duration_; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_IMPORT_INPUT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_testutil.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_testutil.h new file mode 100644 index 00000000..c0a69cd9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_testutil.h @@ -0,0 +1,127 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_TESTUTIL_H_ +#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_TESTUTIL_H_ + +#include + +#include +#include +#include +#include +#include + +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/saved_model/saved_model.h" +#include "tfrt/host_context/host_context.h" // from @tf_runtime + +#if defined(PLATFORM_GOOGLE) +ABSL_DECLARE_FLAG(bool, enable_optimizer); +ABSL_DECLARE_FLAG(std::string, force_data_format); +#endif + +namespace tensorflow { +namespace tfrt_stub { + +std::unique_ptr DefaultTfrtRuntime( + int num_threads); + +struct UserSavedModelOptions { + bool enable_mlrt = false; + bool enable_optimizer = false; + bool enable_grappler = false; + std::string force_data_format = ""; +}; + +SavedModel::Options DefaultSavedModelOptions( + tensorflow::tfrt_stub::Runtime* runtime, + std::optional user_options = std::nullopt); + +class TFRTSavedModelTest { + public: + explicit TFRTSavedModelTest(const std::string& saved_model_dir); + TFRTSavedModelTest(const std::string& saved_model_dir, + std::unique_ptr runtime); + + SavedModel* GetSavedModel() { return saved_model_.get(); } + + tfrt::HostContext* GetHostContext() const { + return saved_model_->GetHostContext(); + } + + private: + std::unique_ptr runtime_; + std::unique_ptr saved_model_; +}; + +template +tensorflow::Tensor CreateTfTensor(absl::Span shape, + absl::Span data) { + tensorflow::Tensor tensor(tensorflow::DataTypeToEnum::value, + tensorflow::TensorShape(shape)); + auto flat = tensor.flat(); + for (int i = 0; i < data.size(); ++i) { + flat(i) = data[i]; + } + return tensor; +} + +template +std::vector GetTfTensorData(const tensorflow::Tensor& tensor) { + return std::vector(tensor.flat().data(), + tensor.flat().data() + tensor.NumElements()); +} + +inline tensorflow::Tensor CreateTfStringTensor( + absl::Span shape, absl::Span data) { + return CreateTfTensor(shape, data); +} + +void ComputeCurrentTFResult(const std::string& saved_model_dir, + const std::string& signature_name, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + std::vector* outputs, + bool enable_mlir_bridge = false, + bool disable_grappler = false); + +// Compute the results using TF1 session loaded from the saved model. In +// addition to returning the result tensors, it also fills `bundle` with the +// loaded savedmodel. This is useful as sometimes the result tensors may only be +// valid when the bundle is alive. 
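// A minimal call sketch for the overload declared below; the model path,
// signature name, and tensor names are illustrative placeholders (and
// `input_tensor` is assumed to be a tensorflow::Tensor already in scope),
// not values taken from this header.
//
//   tensorflow::SavedModelBundle bundle;
//   std::vector<tensorflow::Tensor> outputs;
//   ComputeCurrentTFResult("/tmp/saved_model", "serving_default",
//                          /*input_names=*/{"x:0"}, /*inputs=*/{input_tensor},
//                          /*output_names=*/{"y:0"}, &outputs, &bundle);
//   // `outputs` remains valid for as long as `bundle` stays alive.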
+void ComputeCurrentTFResult(const std::string& saved_model_dir, + const std::string& signature_name, + const std::vector& input_names, + const std::vector& inputs, + const std::vector& output_names, + std::vector* outputs, + tensorflow::SavedModelBundle* bundle, + bool enable_mlir_bridge = false, + bool disable_grappler = false); + +void ExpectTensorEqual(const tensorflow::Tensor& x, const tensorflow::Tensor& y, + std::optional error = std::nullopt); + +SavedModel::Options DefaultTpuModelOptions( + tensorflow::tfrt_stub::Runtime* runtime, + tensorflow::TfrtDeviceInfraTarget device_target); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_TESTUTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_util.h new file mode 100644 index 00000000..409a31a0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/saved_model_util.h @@ -0,0 +1,154 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_UTIL_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/graph_executor/config.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tsl/platform/protobuf.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +// Filename for serialized BEF Buffer. +inline constexpr char kBefBufferFileName[] = "serialized_bef.mlir.bef"; + +// Filename for serialized MLRT bytecode Buffer. +inline constexpr char kMlrtBufferFileName[] = "serialized_mlrt.mlir.mlrt"; + +// Filename for serialized MLIR_MODULE. +inline constexpr char kMlirModuleFilename[] = "serialized_mlir.mlir"; + +// Subdirectory where AoT Packages are saved +inline constexpr char kAotPackagesDirectory[] = "aot_packages"; + +// TODO(tfrt-dev): Replace tfrt::TensorSpec with tensorflow::TensorSpec once the +// latter is checked in. 
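// A small construction sketch for the struct below; the dtype and shape are
// illustrative values only.
//
//   TensorSpec ranked_spec(tensorflow::DT_FLOAT,
//                          tensorflow::PartialTensorShape({-1, 128}));
//   TensorSpec dtype_only_spec(tensorflow::DT_INT64);  // shape left unknown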
+struct TensorSpec { + tensorflow::DataType dtype; + tensorflow::PartialTensorShape shape; + + explicit TensorSpec(tensorflow::DataType dtype) : dtype(dtype) {} + TensorSpec(tensorflow::DataType dtype, tensorflow::PartialTensorShape shape) + : dtype(dtype), shape(std::move(shape)) {} +}; + +inline bool operator==(const TensorSpec& a, const TensorSpec& b) { + return a.dtype == b.dtype && a.shape.IsIdenticalTo(b.shape); +} + +namespace internal { + +struct Signature { + // The following three fields should have the same size. + std::vector input_names; + std::vector input_specs; + std::vector input_devices; + + // The following two fields should have the same size. + std::vector output_names; + std::vector output_specs; + protobuf::Map default_inputs; +}; + +} // namespace internal + +// If `import_signature_names` is non-empty, this function only imports the +// graph that corresponds to this list. +absl::StatusOr> ImportSavedModel( + mlir::MLIRContext* context, const tensorflow::MetaGraphDef& meta_graph_def, + const FallbackState& fallback_state, std::string saved_model_dir, + bool import_user_signatures, bool run_placer_grappler_on_functions, + const std::vector& import_signature_names = {}, + tensorflow::tfrt_stub::RuntimeConfig* runtime_config = nullptr); + +absl::StatusOr ReadSavedModel( + absl::string_view saved_model_dir, + const std::unordered_set& tags); + +using SignatureMap = absl::flat_hash_map; +using ::tensorflow::StatusOr; + +struct Initializer { + std::string name; + std::vector inputs; +}; + +struct InitializersAndSignatures { + // Initializers are kept in a certain order as they need to be executed in + // that order. + std::vector initializers; + SignatureMap signature_map; +}; + +// If `saved_model_dir` is non-empty, this function fills in the Initializer's +// inputs in the returned result. +absl::StatusOr GetInitializersAndSignatures( + mlir::ModuleOp module, absl::string_view saved_model_dir = ""); + +std::string GetAotPackagePath(absl::string_view saved_model_dir); + +std::string GetBefFilePath(std::string aot_package_directory); + +std::string GetMlirFilePath(const std::string& aot_package_directory); + +// TODO(b/295241000): Implement MLIR deserialization to skip it AoT and remove +// redundant steps +absl::StatusOr LoadBefAndMlir( + const TfrtCompileOptions& options, mlir::ModuleOp mlir_module, + const std::string& saved_model_dir, + tfrt_stub::FallbackState* fallback_state); + +absl::StatusOr LoadMlrtAndMlir( + const TfrtCompileOptions& options, mlir::ModuleOp mlir_module, + const std::string& saved_model_dir, + tfrt_stub::FallbackState* fallback_state); + +absl::Status DeserializeAoTMlirModule( + absl::string_view saved_model_dir, mlir::MLIRContext* context, + mlir::OwningOpRef* mlir_module); + +CallableOptions CombineSignatureDefs( + const google::protobuf::Map& signature_defs); + +void RegisterTfrtDialectsForAot(mlir::DialectRegistry& registry); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_SAVED_MODEL_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/utils/serialize_utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/utils/serialize_utils.h new file mode 100644 index 00000000..6708b44a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/saved_model/utils/serialize_utils.h @@ -0,0 +1,54 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_UTILS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" +#include "tsl/platform/env.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime + +namespace tensorflow { +namespace tfrt_stub { + +// Serializes the BefBuffer into a file. +absl::Status SerializeBEF(const tfrt::BefBuffer &bef, + const std::string &filepath); + +// Deserializes BEF file from filepath into a BEFBuffer. +absl::StatusOr DeserializeBEFBuffer( + const std::string &filepath); + +// Serializes the MLRTBytecodeBuffer into a file. +absl::Status SerializeMLRTBytecode(const mlrt::bc::Buffer &byteCode, + const std::string &filepath); + +// Deserializes byte code from the given filepath into a MLRTBytecodeBuffer. +absl::StatusOr DeserializeMlrtBytecodeBuffer( + const std::string &filepath); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_SAVED_MODEL_UTILS_SERIALIZE_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/stubs/model_config_stub.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/stubs/model_config_stub.h new file mode 100644 index 00000000..6518fd21 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/stubs/model_config_stub.h @@ -0,0 +1,49 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_STUBS_MODEL_CONFIG_STUB_H_ +#define TENSORFLOW_CORE_TFRT_STUBS_MODEL_CONFIG_STUB_H_ + +#include + +#include "absl/log/log.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tensorflow/core/tfrt/saved_model/saved_model_util.h" + +namespace tensorflow { +namespace tfrt_stub { + +// TODO(b/299140515): Deprecate this stub and OSS the implementation. +// The tfrt model config stub that provides interface for internal and OSS +// with different impls. 
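// A registration sketch for the stub declared below; `MyModelConfigStub` is a
// hypothetical subclass used only for illustration.
//
//   class MyModelConfigStub : public ModelConfigStub {
//     void GetDefaultInputsFromModelConfig(ModelRuntimeContext& context,
//                                          SignatureMap& signatures) override {
//       // Populate `signatures` with environment-specific default inputs.
//     }
//   };
//   static bool registered =
//       RegisterModelConfigStub(std::make_unique<MyModelConfigStub>());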
+class ModelConfigStub { + public: + virtual ~ModelConfigStub() = default; + + virtual void GetDefaultInputsFromModelConfig(ModelRuntimeContext& context, + SignatureMap& signatures) { + LOG(INFO) << "Unimplemented in non internal env"; + } +}; + +// The return value is to facilitate the global registration. +bool RegisterModelConfigStub(std::unique_ptr stub); + +void GetDefaultInputsFromModelConfig(ModelRuntimeContext& context, + SignatureMap& signatures); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_STUBS_MODEL_CONFIG_STUB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h new file mode 100644 index 00000000..d27fe02d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h @@ -0,0 +1,59 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_STUBS_TFRT_NATIVE_LOWERING_STUB_H_ +#define TENSORFLOW_CORE_TFRT_STUBS_TFRT_NATIVE_LOWERING_STUB_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/core/tfrt/graph_executor/executable_context.h" +#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" +#include "tensorflow/core/tfrt/mlrt/interpreter/context.h" +#include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/host_context/host_context.h" // from @tf_runtime + +namespace tfrt { + +// The tfrt native lowering stub that provides interface for internal and OSS +// with different impls. +class TfrtNativeLoweringStub { + public: + virtual ~TfrtNativeLoweringStub() = default; + virtual void AddSyncContext( + mlrt::ExecutionContext& execution_context, HostContext& host_context, + tensorflow::tfrt_stub::SyncResourceState* sync_state) {} + virtual absl::StatusOr< + std::shared_ptr> + BuildExecutableContext(mlir::ModuleOp module, + const mlrt::KernelRegistry& kernel_registry) { + return absl::UnimplementedError(""); + } +}; + +void RegisterTfrtNativeLoweringStub( + std::unique_ptr stub); + +void AddSyncContext(mlrt::ExecutionContext& execution_context, + tfrt::HostContext& host_context, + tensorflow::tfrt_stub::SyncResourceState* sync_state); + +absl::StatusOr> +BuildExecutableContext(mlir::ModuleOp module, + const mlrt::KernelRegistry& kernel_registry); +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_STUBS_TFRT_NATIVE_LOWERING_STUB_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/tfrt_session/tfrt_session.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/tfrt_session/tfrt_session.h new file mode 100644 index 00000000..84de49eb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/tfrt_session/tfrt_session.h @@ -0,0 +1,121 @@ +/* Copyright 2021 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_TFRT_SESSION_TFRT_SESSION_H_ +#define TENSORFLOW_CORE_TFRT_TFRT_SESSION_TFRT_SESSION_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/session_factory.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tsl/platform/thread_annotations.h" + +namespace tensorflow { + +// Struct exposing a few threadpool configuration options. These +// correspond to the options in RunHandlerThreadWorkQueue::Options. +struct TfrtThreadpoolOptions { + // Number of threads used for running graphs. + int32_t num_main_threads = port::MaxParallelism(); + + // Time to wait for the init function to complete. + absl::Duration init_timeout = absl::Milliseconds(100); + + // Maximum number of concurrent RunHandlers. + int32_t max_concurrent_handler = 128; + + // Number of sub thread pools. + int32_t num_sub_thread_pool = 1; +}; + +struct TfrtSessionOptions { + TfrtThreadpoolOptions threadpool_options; + tensorflow::tfrt_stub::Runtime* runtime = nullptr; + bool enable_mlrt = false; + // Should only set one of `use_tpu` and `use_gpu` and `backend_compiler`. + bool use_tpu = false; + bool use_gpu = false; + tensorflow::BackendCompiler* backend_compiler = nullptr; +}; + +// Factory class to create `TfrtSession` instances. +class TfrtSessionFactory : public tensorflow::SessionFactory { + public: + TfrtSessionFactory(); + + bool AcceptsOptions(const SessionOptions& options) override; + + absl::Status NewSession(const SessionOptions& options, + Session** out_session) override + TF_LOCKS_EXCLUDED(mutex_); + + // This should only be used for the sake initializing resources for + // Python executables. It should only be called before main. + // + // Due to lack of applications and a concern for the ordering of initializers, + // this may only be called once. + using RuntimeInitializer = absl::Status (*)(tfrt_stub::Runtime*); + static void RegisterInitializer(RuntimeInitializer initializer); + + // May not be called within code holding mutex_. 
+ static tfrt_stub::Runtime* GetRuntime(); + + private: + class ThreadPoolManager; + friend absl::Status InitializeTfrtSession(const TfrtSessionOptions& options); + friend absl::Status UpdateTfrtSessionOptionsLocked( + const TfrtSessionOptions& options); + absl::Status InitializeLocked(const TfrtSessionOptions& options) + TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + bool IsInitialized() const TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + return runtime_ != nullptr; + } + + mutable absl::Mutex mutex_; + mutable absl::Mutex runtime_mutex_; + tensorflow::tfrt_stub::Runtime* runtime_ TF_GUARDED_BY(mutex_) = nullptr; + std::unique_ptr owned_runtime_ + TF_GUARDED_BY(mutex_); + + TfrtDeviceInfraTarget device_target_ TF_GUARDED_BY(mutex_) = + TfrtDeviceInfraTarget::kCpu; + bool tpu_use_tpu_runner_ TF_GUARDED_BY(mutex_) = false; + bool use_gpu_ TF_GUARDED_BY(mutex_) = false; + std::unique_ptr thread_pool_manager_ TF_GUARDED_BY(mutex_); + bool enable_mlrt_ TF_GUARDED_BY(mutex_) = false; + tensorflow::BackendCompiler* backend_compiler_ TF_GUARDED_BY(mutex_); + std::unique_ptr device_manager_; +}; + +// Configures the TfrtSessionFactory according to `options`. Should not be +// called within functions that are passed into +// `TfrtSessionFactory::RegisterInitializer`, because it acquires `mutex_`. +absl::Status InitializeTfrtSession(const TfrtSessionOptions& options); + +// Version of `InitializeTfrtSession` that can be used within functions passed +// into `TfrtSessionFactory::RegisterInitializer`. +absl::Status UpdateTfrtSessionOptionsLocked(const TfrtSessionOptions& options); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_TFRT_SESSION_TFRT_SESSION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/tfrt_session/tfrt_session_init.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/tfrt_session/tfrt_session_init.h new file mode 100644 index 00000000..7891a0a8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/tfrt_session/tfrt_session_init.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_TFRT_SESSION_TFRT_SESSION_INIT_H_ +#define TENSORFLOW_CORE_TFRT_TFRT_SESSION_TFRT_SESSION_INIT_H_ + +#include "tensorflow/core/common_runtime/local_session_selection.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Use TfrtSession as the Session implementation for local session. +// +// TODO(jingdong): Merge this function with the InitializeTfrtSession() in +// tfrt_session.h after we decouple TPU logic from TfrtSession. 
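// A startup sketch for the function defined below; the surrounding code is
// assumed to run inside a function returning absl::Status, and error handling
// is otherwise elided.
//
//   TF_RETURN_IF_ERROR(tensorflow::InitializeTfrtSession());
//   // Subsequent NewSession(SessionOptions()) calls now create TfrtSession
//   // instances, because the default local session implementation is TFRT.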
+inline absl::Status InitializeTfrtSession() { + SetDefaultLocalSessionImpl(LocalSessionImpl::kTfrtSession); + return absl::OkStatus(); +} + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_TFRT_SESSION_TFRT_SESSION_INIT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/any_ptr.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/any_ptr.h new file mode 100644 index 00000000..8b1a496c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/any_ptr.h @@ -0,0 +1,170 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_UTILS_ANY_PTR_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_ANY_PTR_H_ + +#include +#include + +namespace tfrt { + +/// A (sort of) type-safe void*. Appears as null if a caller attempts to use it +/// as the wrong type. +/// +/// Example use: +/// +/// // A function that returns an AnyPtr: +/// AnyPtr StringOrInt() { +/// if (use_string) { +/// return AnyPtr(&some_string); +/// } else { +/// return AnyPtr(&some_int); +/// } +/// } +/// +/// // Use an AnyPtr at the correct type: +/// AnyPtr ptr = StringOrInt(); +/// if (ptr.get() != nullptr) { +/// DoSomethingWithInt(*ptr.get()); +/// } else if (ptr.get() != nullptr) { +/// DoSomethingWithString(*ptr.get()); +/// } else { +/// // Handle error. +/// } +/// +/// Typical best practice for this class is to use it when two disjoint pieces +/// of code must agree on type, but intermediate code is type agnostic. Large +/// chains of conditionals that handle a multitude of types is discouraged as an +/// anti-pattern. +/// +/// Note that this will appear null even if T is somewhere on the underlying +/// type's inheritance hierarchy, if you must use the object at some other type +/// you must do so explicitly when constructing an AnyPtr, like so: +/// +/// SomeObject object; +/// AnyPtr any_ptr(static_cast(&object)); +/// SomeInterface* interface = any_ptr.get(); +/// +/// This class is a value type; It can be copied or assigned. It performs no +/// internal allocations and should be relatively cheap to copy or return by +/// value. +class AnyPtr { + public: + /// AnyPtr is void and null by default. + AnyPtr() : type_id_(FastTypeId()), ptr_(nullptr) {} + + /// Implicit construction from nullptr. + AnyPtr(std::nullptr_t) : AnyPtr() {} // NOLINT + + /// Construct from a pointer to any type. + template + AnyPtr(T* ptr) // NOLINT + : type_id_(FastTypeId()), + // We need a double cast here, first to drop the type, and second to + // drop constness. We always cast back to the appropriate type and + // constness in get<>(), since FastTypeId is different for a const and + // non-const T. + ptr_(const_cast(reinterpret_cast(ptr))) {} + + /// Accessor for the underlying pointer if it is of type T, otherwise null. 
+ template + T* get() const { + if (type_id_ != FastTypeId()) { + return nullptr; + } + return reinterpret_cast(ptr_); + } + + private: + template + static size_t FastTypeId() { + // Use a static variable to get a unique per-type address. + static int dummy; + return reinterpret_cast(&dummy); + } + + // The code for the type of 'ptr_'. + std::size_t type_id_; + + // The underlying pointer. + void* ptr_; +}; + +/// Like AnyPtr, but owns the pointed-to object (calls delete upon destruction). +/// This class is move-only, like std::unique_ptr. +class UniqueAnyPtr { + public: + /// UniqueAnyPtr is void and null by default. + UniqueAnyPtr() = default; + UniqueAnyPtr(std::nullptr_t) : UniqueAnyPtr() {} // NOLINT + + /// Construct from a unique pointer to any type. + template + explicit UniqueAnyPtr(std::unique_ptr ptr) + : ptr_(ptr.release()), deleter_(DeleterForType()) {} + + ~UniqueAnyPtr() { deleter_(ptr_); } + + // Disable copy. + UniqueAnyPtr(const UniqueAnyPtr& other) = delete; + UniqueAnyPtr& operator=(const UniqueAnyPtr& other) = delete; + + // Allow move. + UniqueAnyPtr(UniqueAnyPtr&& other) noexcept { swap(other); } + + UniqueAnyPtr& operator=(UniqueAnyPtr&& other) noexcept { + swap(other); + return *this; + } + + /// Accessor for the underlying pointer if it is of type T, otherwise null. + template + T* get() const { + return ptr_.get(); + } + + /// Accessor for the underlying pointer as an AnyPtr. + const AnyPtr& as_any_ptr() const { return ptr_; } + + void swap(UniqueAnyPtr& other) noexcept { + using ::std::swap; + swap(ptr_, other.ptr_); + swap(deleter_, other.deleter_); + } + + private: + // We use a raw function pointer. This eliminates the copy and calling + // overhead of std::function. + using Deleter = void (*)(AnyPtr ptr); + + // Returns a 'Deleter' that will delete it's argument as an instance of 'T'. + // Always returns the same value for the same 'T'. + template + static Deleter DeleterForType() { + return [](AnyPtr ptr) { delete ptr.get(); }; + } + + static Deleter NoOpDeleter() { + return [](AnyPtr ptr) {}; + } + + AnyPtr ptr_ = nullptr; + Deleter deleter_ = NoOpDeleter(); +}; + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_UTILS_ANY_PTR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/bridge_graph_analysis.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/bridge_graph_analysis.h new file mode 100644 index 00000000..3fe4d67d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/bridge_graph_analysis.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_BRIDGE_GRAPH_ANALYSIS_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_BRIDGE_GRAPH_ANALYSIS_H_ + +#include "tensorflow/core/platform/status.h" + +namespace tfrt { + +inline tensorflow::Status CheckTpuMlirBridgeCompatibility( + const tensorflow::GraphDef& graph_def) { + return tensorflow::OkStatus(); +} + +inline tensorflow::Status CheckSpmdGraph( + const tensorflow::GraphDef& graph_def) { + return tensorflow::OkStatus(); +} + +} // namespace tfrt + + +#endif // TENSORFLOW_CORE_TFRT_UTILS_BRIDGE_GRAPH_ANALYSIS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.h new file mode 100644 index 00000000..068c19ba --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter.h @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_DEBUG_NODE_IO_DUMP_REWRITER_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_DEBUG_NODE_IO_DUMP_REWRITER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace tfrt_stub { + +// Rewrites `graph` by inserting dump nodes for `nodes_to_dump`. During graph +// execution, the inputs and outputs of `nodes_to_dump` will be dumped to the +// folder specified by env var `TF_DUMP_GRAPH_PREFIX`. +absl::Status InsertDumpOps( + Graph& graph, const absl::flat_hash_set& nodes_to_dump, + absl::string_view dump_dir = ""); +// Similar to the above, but rewrites a `meta_graph_def`. +absl::Status InsertDumpOps( + MetaGraphDef& meta_graph_def, + const absl::flat_hash_set& nodes_to_dump, + absl::string_view dump_dir = ""); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_UTILS_DEBUG_NODE_IO_DUMP_REWRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/device_variables_table.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/device_variables_table.h new file mode 100644 index 00000000..1b1a742e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/device_variables_table.h @@ -0,0 +1,98 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_DEVICE_VARIABLES_TABLE_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_DEVICE_VARIABLES_TABLE_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "llvm/ADT/FunctionExtras.h" +#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime + +namespace tfrt { + +// A variable table that keeps track of the device copies of host tensors. +// The same variable can have multiple copies on devices (e.g., on different TPU +// cores), and hence they are differenticated via `copy_index`. +// The table maps from to device tensor. +template +class DeviceVariablesTable { + public: + virtual ~DeviceVariablesTable() { ClearDeviceVariablesTable(); } + + void AddOrUpdateDeviceVariable( + const HostTensorType& host_tensor, int copy_index, + AsyncValueRef device_tensor) { + absl::MutexLock lock(&device_variables_mu_); + device_variables_table_.insert_or_assign( + std::make_pair(GetHostTensorDataPtr(host_tensor), copy_index), + std::move(device_tensor)); + } + + AsyncValueRef GetDeviceVariable( + const HostTensorType& host_tensor, int copy_index) { + absl::ReaderMutexLock lock(&device_variables_mu_); + auto it = device_variables_table_.find( + std::make_pair(GetHostTensorDataPtr(host_tensor), copy_index)); + return it != device_variables_table_.end() + ? it->second.CopyRef() + : AsyncValueRef(); + } + + AsyncValueRef GetOrAddDeviceVariable( + const HostTensorType& host_tensor, int copy_index, + llvm::unique_function)> creator) { + absl::ReleasableMutexLock lock(&device_variables_mu_); + auto it = device_variables_table_.find( + std::make_pair(GetHostTensorDataPtr(host_tensor), copy_index)); + if (it != device_variables_table_.end()) return it->second.CopyRef(); + + auto device_tensor = MakeUnconstructedAsyncValueRef(); + device_variables_table_.insert( + {std::make_pair(GetHostTensorDataPtr(host_tensor), copy_index), + device_tensor.CopyRef()}); + lock.Release(); + creator(device_tensor.CopyRef()); + return device_tensor; + } + + void ClearDeviceVariablesTable() { + absl::MutexLock lock(&device_variables_mu_); + device_variables_table_.clear(); + } + + int size() { + absl::ReaderMutexLock lock(&device_variables_mu_); + return device_variables_table_.size(); + } + + protected: + // Get the host tensor data pointer, which is used as a part of the table key. + virtual const void* GetHostTensorDataPtr( + const HostTensorType& host_tensor) = 0; + + private: + absl::Mutex device_variables_mu_; + + // A map from to device tensor. + absl::flat_hash_map, + AsyncValueRef> + device_variables_table_ ABSL_GUARDED_BY(device_variables_mu_); +}; + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_UTILS_DEVICE_VARIABLES_TABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/error_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/error_util.h new file mode 100644 index 00000000..229b854a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/error_util.h @@ -0,0 +1,80 @@ +/* Copyright 2021 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_ERROR_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_ERROR_UTIL_H_ + +#include + +#include "tensorflow/core/platform/status.h" +#include "tfrt/support/error_util.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime + +namespace tfrt { +class DecodedDiagnostic; + +tfrt::ErrorCode ConvertTfErrorCodeToTfrtErrorCode(const absl::Status& status); + +absl::Status CreateTfErrorStatus(const DecodedDiagnostic& error); + +absl::Status ToTfStatus(const AsyncValue* av); + +inline std::string MakeStatusString(absl::Status status) { + switch (static_cast(status.code())) { + case absl::StatusCode::kOk: + return "OK"; + case absl::StatusCode::kCancelled: + return absl::StrCat("Cancelled: ", status.message()); + case absl::StatusCode::kUnknown: + return absl::StrCat("Unknown: ", status.message()); + case absl::StatusCode::kInvalidArgument: + return absl::StrCat("Invalid argument: ", status.message()); + case absl::StatusCode::kDeadlineExceeded: + return absl::StrCat("Deadline exceeded: ", status.message()); + case absl::StatusCode::kNotFound: + return absl::StrCat("Not found: ", status.message()); + case absl::StatusCode::kAlreadyExists: + return absl::StrCat("Already exists: ", status.message()); + case absl::StatusCode::kPermissionDenied: + return absl::StrCat("Permission denied: ", status.message()); + case absl::StatusCode::kUnauthenticated: + return absl::StrCat("Unauthenticated: ", status.message()); + case absl::StatusCode::kResourceExhausted: + return absl::StrCat("Resource exhausted: ", status.message()); + case absl::StatusCode::kFailedPrecondition: + return absl::StrCat("Failed precondition: ", status.message()); + case absl::StatusCode::kAborted: + return absl::StrCat("Aborted: ", status.message()); + case absl::StatusCode::kOutOfRange: + return absl::StrCat("Out of range: ", status.message()); + case absl::StatusCode::kUnimplemented: + return absl::StrCat("Unimplemented: ", status.message()); + case absl::StatusCode::kInternal: + return absl::StrCat("Internal: ", status.message()); + case absl::StatusCode::kUnavailable: + return absl::StrCat("Unavailable: ", status.message()); + case absl::StatusCode::kDataLoss: + return absl::StrCat("Data loss: ", status.message()); + default: + return absl::StrCat("Unknown code: ", status.message()); + } +} + +inline llvm::Error MakeStatusError(absl::Status status) { + return MakeStringError(MakeStatusString(status)); +} + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_UTILS_ERROR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/fallback_tensor.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/fallback_tensor.h new file mode 100644 index 00000000..c5b81f36 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/fallback_tensor.h @@ -0,0 +1,104 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_ + +#include + +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/tensor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace tensorflow { +namespace tfrt_stub { + +// A special tensor wrapper for immutable tensors that live a long time and are +// reused across steps in a program, eg. weights. +class ImmutableTensor { + public: + ImmutableTensor() = default; + // Create an ImmutableTensor by copying the content in `tensor`. + static ImmutableTensor Create(tensorflow::Tensor tensor); + + // Accessors for this underlying tensor. Users must not modify its content. It + // is guaranteed that RefCountIsOne() always return false for the tensor. + tensorflow::Tensor& tensor() { return tensor_; } + const tensorflow::Tensor& tensor() const { return tensor_; } + + private: + explicit ImmutableTensor(tensorflow::Tensor tensor) + : tensor_(std::move(tensor)) { + DCHECK(!tensor_.RefCountIsOne()) + << "Immutable tensors' buffers cannot be forwarded."; + } + + tensorflow::Tensor tensor_; +}; + +// A wrapper class over normal tensors and immutable tensors. This class is used +// as the currency type in TFRT fallback execution. Note that this class does +// not own the underlying tensor if it is an immutable tensor. +class FallbackTensor { + public: + FallbackTensor() = default; + + explicit FallbackTensor(const tensorflow::Tensor& tensor) : tensor_(tensor) {} + explicit FallbackTensor(tensorflow::Tensor&& tensor) + : tensor_(std::move(tensor)) {} + + explicit FallbackTensor(ImmutableTensor* immutable_tensor) + : tensor_(immutable_tensor->tensor()), is_immutable_(true) {} + + FallbackTensor(const FallbackTensor& other) { *this = other; } + FallbackTensor& operator=(const FallbackTensor& other) { + tsl::profiler::TraceMe trace_me("FallbackTensor::Copy"); + if (!other.is_immutable() && other.buffer() != nullptr) { + // Create a new TensorBuffer which contains a new atomic counter for each + // result, to avoid downstream threads contending the original atomic + // counter. + tensor_ = std::move( + tensorflow::tfrt_stub::ImmutableTensor::Create(other.tensor()) + .tensor()); + } else { + // For immutable tensors or empty tensors, we just need to copy the + // pointer as they don't incur atomic operations when they are referenced. 
+ tensor_ = other.tensor(); + } + is_immutable_ = true; + return *this; + } + + FallbackTensor(FallbackTensor&&) noexcept = default; + FallbackTensor& operator=(FallbackTensor&&) noexcept = default; + + const TensorBuffer* buffer() const { + return tensorflow::DMAHelper::buffer(&tensor()); + } + TensorBuffer* buffer() { return tensorflow::DMAHelper::buffer(&tensor()); } + + bool is_immutable() const { return is_immutable_; } + + tensorflow::Tensor& tensor() { return tensor_; } + const tensorflow::Tensor& tensor() const { return tensor_; } + + private: + tensorflow::Tensor tensor_; + bool is_immutable_ = false; +}; + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_UTILS_FALLBACK_TENSOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/gpu_variables_table.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/gpu_variables_table.h new file mode 100644 index 00000000..7e413f56 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/gpu_variables_table.h @@ -0,0 +1,42 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_GPU_VARIABLES_TABLE_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_GPU_VARIABLES_TABLE_H_ + +#include "tensorflow/core/tfrt/utils/device_variables_table.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" + +namespace tfrt { +namespace gpu { + +// This is for creating/getting GpuVariablesTable object in the execution +// context at runtime. +constexpr char kGpuVariablesTableResourceName[] = "GpuVariablesTableResource"; + +// A variable table that keeps track of the device copies of GPU host tensors. +class GpuVariablesTable + : public DeviceVariablesTable { + private: + const void* GetHostTensorDataPtr( + const tensorflow::tfrt_stub::FallbackTensor& host_tensor) override { + return host_tensor.tensor().data(); + } +}; + +} // namespace gpu +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_UTILS_GPU_VARIABLES_TABLE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/graph_partition.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/graph_partition.h new file mode 100644 index 00000000..4f5cedd2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/graph_partition.h @@ -0,0 +1,67 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_GRAPH_PARTITION_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_GRAPH_PARTITION_H_ + +#include +#include +#include + +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace tfrt_stub { + +// Inserts send/recv ops to `graph` if nodes are assigned to multiple devices. +// Specifically, nodes on the same device will be wrapped in a function and +// invoked by a PartitionedCall op. All PartitionedCall ops are connected to a +// StatefulPartitionedCall op (which behaves as a 'stateful IdentityN') to +// protect them from being pruned in the subsequent MLIR lowering passes +// (b/232026253). +// +// The following shows a simple example of using this method. +// +// The original graph has four nodes that are placed on different devices. +// +// -----> op1(host) ------ +// / \ +// input(host) output(host) +// \ / +// -----> op2(device) ------ +// +// Calling this method will return the following graph, where `op1` is wrapped +// in the function invoked by `PartitionedCall_1`, and `op2` is wrapped in the +// function invoked by `PartitionedCall_2`. Both of them have a data dependency +// with the `StatefulPartitionedCall` op. +// +// input ---> PartitionedCall_1 ---- +// \ +// StatefulPartitionedCall ---> output +// / +// PartitionedCall_2 ---- +// +absl::StatusOr> InsertTransferOps( + const std::string& graph_func_name, const DeviceSet& device_set, + const Device* host_device, const std::vector& inputs, + const std::vector& outputs, + const std::vector& control_outputs, + std::unique_ptr graph); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_UTILS_GRAPH_PARTITION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/tensor_util.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/tensor_util.h new file mode 100644 index 00000000..358d7604 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/tensor_util.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_TENSOR_UTIL_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_TENSOR_UTIL_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/statusor.h" +#include "tfrt/core_runtime/tensor_handle.h" // from @tf_runtime +#include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime +#include "tfrt/tensor/tensor.h" // from @tf_runtime + +namespace tfrt { + +// Converts a tfrt::Tensor to tensorflow::Tensor. +llvm::Expected TFRTTensorToTFTensor(const Tensor& tensor); + +// Converts a tensorflow::Tensor to tfrt::TensorHandle. 
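// A conversion sketch for the declaration below; `tf_tensor` and `host_ctx`
// are assumed to be a tensorflow::Tensor and a valid tfrt::HostContext*
// already in scope, and the element type of the returned AsyncValueRef is
// assumed to be tfrt::TensorHandle.
//
//   AsyncValueRef<TensorHandle> handle =
//       TFTensorToTFRTTensorHandle(tf_tensor, host_ctx);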
+AsyncValueRef TFTensorToTFRTTensorHandle( + const tensorflow::Tensor& tf_tensor, HostContext* host_ctx); + +// Creates a TFRT TensorHandle using the shape and data in a tensorflow tensor. +absl::StatusOr CreateTensorHandleFromTFTensor( + const tensorflow::Tensor& tensor, HostContext* host); + +// Creates a tensorflow tensor using the shape and data in a TFRT tensorhandle. +absl::StatusOr CreateTFTensorFromTensorHandle( + const TensorHandle& tensor_handle); + +// Converts a tensorflow::Tensor to tfrt::DenseHostTensor. +// TODO(tfrt-devs): consider generalize to TFTensorToTFRTTensor +Expected ConvertTfTensorToDHT(tensorflow::Tensor tf_tensor); + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_UTILS_TENSOR_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h new file mode 100644 index 00000000..2912c2ca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.h @@ -0,0 +1,145 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v1/mlir_bridge_config_v1.pb.h" +#include "tensorflow/core/common_runtime/graph_execution_state.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" +#include "tensorflow/core/tfrt/graph_executor/config.h" + +namespace tensorflow { +namespace tfrt_stub { + +// This is a TFRT variant of `tensorflow::GraphExecutionState`. It wraps +// `tensorflow::GraphExecutionState` and adds TFRT-specific adjustments. +// +// Responsible for generating an executable `Graph` from the original `GraphDef` +// that specifies the complete graph and from `GraphImportConfig` that specifies +// input/output nodes. +// +// Thread-safe. +class TfrtGraphExecutionState { + public: + struct OptimizationResult { + std::unique_ptr graph; + absl::Duration functionalization_duration; + absl::Duration grappler_duration; + }; + + struct Options { + bool run_placer_grappler_on_functions = false; + bool run_placer_on_graph = true; + }; + + // Creates a `GraphExecutionState` given `graph_def` and `fallback_state`. + static absl::StatusOr> Create( + const Options& options, tensorflow::GraphDef graph_def, + const FallbackState& fallback_state, + tensorflow::tfrt_stub::RuntimeConfig* runtime_config = nullptr); + + // Ctor. 
Do not use directly. Public only for `std::make_unique<>()`. + TfrtGraphExecutionState( + const Options& options, + std::unique_ptr graph_execution_state, + const FallbackState& fallback_state, + absl::flat_hash_set functions_to_optimize) + : options_(options), + graph_execution_state_(std::move(graph_execution_state)), + fallback_state_(fallback_state), + functions_to_optimize_(std::move(functions_to_optimize)) {} + + // Creates an optimized graph by pruning with `graph_import_config` and + // best-effort Grappler run. + absl::StatusOr CreateOptimizedGraph( + tensorflow::GraphImportConfig& graph_import_config); + + // Extends the current graph by `graph`. + absl::Status Extend(const GraphDef& graph); + + // Return the preprocessed full graph. Note that it does not contain the + // function library in the original graph. + const tensorflow::Graph& graph() const { + absl::MutexLock lock(&graph_execution_state_mu_); + DCHECK(graph_execution_state_->full_graph()); + return *graph_execution_state_->full_graph(); + } + + // The original graph. + const GraphDef* original_graph_def() const { + absl::MutexLock lock(&graph_execution_state_mu_); + return graph_execution_state_->original_graph_def(); + } + + // Return the function library in the original graph. + const FunctionLibraryDefinition& flib_def() const { + absl::MutexLock lock(&graph_execution_state_mu_); + return graph_execution_state_->flib_def(); + } + + private: + absl::StatusOr> OptimizeGraph( + const tensorflow::Graph& graph, + const tensorflow::BuildGraphOptions& build_graph_options); + + Options options_; + + std::unique_ptr graph_execution_state_ + ABSL_GUARDED_BY(graph_execution_state_mu_); + // We need this mutex even thought `GraphExecutionState` is thread-safe, + // because `swap()` is not thread-safe. + mutable absl::Mutex graph_execution_state_mu_; + + const FallbackState& fallback_state_; + // Only valid if `options_.run_placer_grappler_on_functions` is true. + absl::flat_hash_set functions_to_optimize_ + ABSL_GUARDED_BY(graph_execution_state_mu_); +}; + +// Prunes the `graph_def` using the feed/fetch nodes specified in +// `callable_options`. It is a TFRT-specific version that it performs more +// pruning (e.g., prunes the input edges to the feed nodes) than +// `ComputeTransitiveFanin()` so that the graph can be functionalized properly +// later. +absl::Status PruneGraphDef(GraphDef& graph_def, + const CallableOptions& callable_options); + +// Eliminates ref variables in V1 control flow, which is required for +// functionalization. Current strategy is to insert an identity node between +// each ref node and its ref input and in-place update the ref node to its +// non-ref counterpart. +absl::Status EliminateRefVariablesFromV1ControlFlow(GraphDef& graph_def); + +// Removes the "_input_shapes" attribute of functions in the graph. +void RemoveInputShapesInFunctions(tensorflow::GraphDef& graph_def); + +} // namespace tfrt_stub +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_UTILS_TFRT_GRAPH_EXECUTION_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/thread_pool.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/thread_pool.h new file mode 100644 index 00000000..0efe9133 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/thread_pool.h @@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
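Stepping back to the TfrtGraphExecutionState API declared above: a minimal usage sketch, assuming Create() returns absl::StatusOr<std::unique_ptr<TfrtGraphExecutionState>> and that `graph_def` (a tensorflow::GraphDef) and `fallback_state` have been prepared by the caller; the enclosing function is assumed to return absl::Status. Names here are illustrative, not taken from the diff.

// Hedged sketch: create the execution state, then build an optimized graph.
TfrtGraphExecutionState::Options state_options;
state_options.run_placer_grappler_on_functions = false;

TF_ASSIGN_OR_RETURN(
    std::unique_ptr<TfrtGraphExecutionState> execution_state,
    TfrtGraphExecutionState::Create(state_options, std::move(graph_def),
                                    fallback_state));

tensorflow::GraphImportConfig import_config;
// Feeds/fetches would be filled in from the request before this call.
TF_ASSIGN_OR_RETURN(TfrtGraphExecutionState::OptimizationResult optimized,
                    execution_state->CreateOptimizedGraph(import_config));
// optimized.graph holds the pruned, Grappler-optimized graph;
// optimized.functionalization_duration / grappler_duration carry timings.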
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TFRT_UTILS_THREAD_POOL_H_
+#define TENSORFLOW_CORE_TFRT_UTILS_THREAD_POOL_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/threadpool.h"
+#include "tensorflow/core/platform/threadpool_interface.h"
+
+namespace tensorflow {
+namespace tfrt_stub {
+
+class TfThreadPool : public thread::ThreadPoolInterface {
+ public:
+  explicit TfThreadPool(const std::string& name, int num_threads)
+      : underlying_threadpool_(tensorflow::Env::Default(), name, num_threads) {}
+
+  void Schedule(std::function<void()> fn) override {
+    underlying_threadpool_.Schedule(std::move(fn));
+  }
+
+  void ScheduleWithHint(std::function<void()> fn, int start, int end) override {
+    underlying_threadpool_.ScheduleWithHint(std::move(fn), start, end);
+  }
+
+  void Cancel() override {
+    underlying_threadpool_.AsEigenThreadPool()->Cancel();
+  }
+
+  int NumThreads() const override {
+    return underlying_threadpool_.NumThreads();
+  }
+
+  int CurrentThreadId() const override {
+    return underlying_threadpool_.CurrentThreadId();
+  }
+
+ private:
+  tensorflow::thread::ThreadPool underlying_threadpool_;
+};
+
+} // namespace tfrt_stub
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_TFRT_UTILS_THREAD_POOL_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/utils.h b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/utils.h
new file mode 100644
index 00000000..970de920
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/tfrt/utils/utils.h
@@ -0,0 +1,136 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
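For the TfThreadPool wrapper above, a small hypothetical usage sketch, assuming Schedule() takes a std::function<void()> work item; the pool name and thread count are illustrative.

// Wrap a TF thread pool behind the Eigen-style ThreadPoolInterface so it can
// be handed to code that expects that interface.
tensorflow::tfrt_stub::TfThreadPool intra_op_pool(/*name=*/"tf_intra",
                                                  /*num_threads=*/4);
intra_op_pool.Schedule([] {
  // ... work item ...
});
// NumThreads() and CurrentThreadId() simply forward to the wrapped
// tensorflow::thread::ThreadPool.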
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_UTILS_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_UTILS_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/tfrt/runtime/runtime.h" +#include "tfrt/bef/bef_buffer.h" // from @tf_runtime +#include "tfrt/dtype/dtype.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime + +namespace tensorflow { +class Device; +} // namespace tensorflow + +namespace tfrt { + +class BEFFile; +class ExecutionContext; +class HostContext; + +typedef absl::InlinedVector TfrtDataTypeVector; +typedef absl::Span TfrtDataTypeSlice; + +DType ConvertTfDTypeToTfrtDType(tensorflow::DataType dtype); + +// Runs the runtime initialization function. A runtime initialization function +// is added by runtime/compiler workflow and is not present in the original +// savedmodel. +// +// TODO(b/178714905): We should avoid special handling on initialization by +// letting compiler to handle it. +absl::Status RunRuntimeInitializer(const tfrt::ExecutionContext& exec_ctx, + tfrt::BEFFile* bef_file, + absl::string_view fallback_init_func); + +// Creates dummy TF devices from the input device names. Currently this method +// is used to create the TPU_SYSTEM device for worker server. +void CreateDummyTfDevices( + const std::vector& device_names, + std::vector>* dummy_tf_devices); + +// Creates and add dummy TFRT devices from the input device names. Currently +// this method is used to create the TPU_SYSTEM device for worker server. +void AddDummyTfrtDevices(const std::vector& device_names, + tfrt::HostContext* host_ctx); + +// Creates a BEF file from a BEF buffer. `runtime` is used to provide host +// context for opening `bef`. +absl::StatusOr> CreateBefFileFromBefBuffer( + const tensorflow::tfrt_stub::Runtime& runtime, const tfrt::BefBuffer& bef); + +// Returns a unique integer within this process. +int64_t GetUniqueInt(); + +// Returns current CPU time. +uint64_t GetCpuClockCycle(); + +// A list of macros similar to `TF_RETURN_IF_ERROR`, with additional model +// loading stage info. +#define RETURN_IF_ERROR_IN_IMPORT(...) \ + RETURN_IF_ERROR_WITH_STAGE_INFO("GraphDef proto -> MLIR", __VA_ARGS__) + +#define RETURN_IF_ERROR_IN_COMPILE(...) \ + RETURN_IF_ERROR_WITH_STAGE_INFO( \ + "TF dialect -> TFRT dialect, compiler issue, please contact the TFRT " \ + "team", \ + __VA_ARGS__) + +#define RETURN_IF_ERROR_IN_INIT(...) \ + RETURN_IF_ERROR_WITH_STAGE_INFO("Initialize TFRT", __VA_ARGS__) + +#define RETURN_IF_ERROR_WITH_STAGE_INFO(stage, ...) \ + do { \ + ::tensorflow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) { \ + return ::tensorflow::errors::CreateWithUpdatedMessage( \ + _status, \ + ::tensorflow::strings::StrCat(stage, ": ", _status.message())); \ + } \ + } while (0) + +// A list of macros similar to `TF_ASSIGN_OR_RETURN`, with additional model +// loading stage info. 
+#define ASSIGN_OR_RETURN_IN_IMPORT(lhs, rexpr) \ + ASSIGN_OR_RETURN_WITH_STAGE_INFO("GraphDef proto -> MLIR", lhs, rexpr) + +#define ASSIGN_OR_RETURN_IN_COMPILE(lhs, rexpr) \ + ASSIGN_OR_RETURN_WITH_STAGE_INFO( \ + "TF dialect -> TFRT dialect, compiler issue, please contact the TFRT " \ + "team", \ + lhs, rexpr) + +#define ASSIGN_OR_RETURN_IN_INIT(lhs, rexpr) \ + ASSIGN_OR_RETURN_WITH_STAGE_INFO("Initialize TFRT", lhs, rexpr) + +#define ASSIGN_OR_RETURN_WITH_STAGE_INFO(stage, lhs, rexpr) \ + ASSIGN_OR_RETURN_WITH_STAGE_INFO_IMPL( \ + TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), stage, lhs, \ + rexpr) + +#define ASSIGN_OR_RETURN_WITH_STAGE_INFO_IMPL(statusor, stage, lhs, rexpr) \ + auto statusor = (rexpr); \ + if (TF_PREDICT_FALSE(!statusor.ok())) { \ + const auto& _status = statusor.status(); \ + return ::tensorflow::errors::CreateWithUpdatedMessage( \ + _status, \ + ::tensorflow::strings::StrCat(stage, ": ", _status.message())); \ + } \ + lhs = std::move(statusor.value()) + +} // namespace tfrt + +#endif // TENSORFLOW_CORE_TFRT_UTILS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/combine_tpu_embedding_load_retrieve_pass.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/combine_tpu_embedding_load_retrieve_pass.h new file mode 100644 index 00000000..746ec93d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/combine_tpu_embedding_load_retrieve_pass.h @@ -0,0 +1,36 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COMBINE_TPU_EMBEDDING_LOAD_RETRIEVE_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COMBINE_TPU_EMBEDDING_LOAD_RETRIEVE_PASS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Merges per-table TPUEmbedding load and retrieve operators into global +// operators. +class CombineTPUEmbeddingLoadRetrievePass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COMBINE_TPU_EMBEDDING_LOAD_RETRIEVE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/cond_builder.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/cond_builder.h new file mode 100644 index 00000000..dd827a3b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/cond_builder.h @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
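For the stage-info macros defined in utils.h above, a hedged sketch of the intended usage. `ImportGraphDefToMlir`, `CompileMlirToBef`, and `RunInitializers` are hypothetical helpers standing in for the real import, compile, and initialization steps; only the wrapping pattern is meant to be taken from the header.

// On failure, the returned status message is prefixed with the loading stage,
// e.g. "GraphDef proto -> MLIR: <original error message>".
tensorflow::Status LoadModelSketch() {
  ASSIGN_OR_RETURN_IN_IMPORT(auto mlir_module, ImportGraphDefToMlir());
  ASSIGN_OR_RETURN_IN_COMPILE(auto bef_buffer, CompileMlirToBef(mlir_module));
  RETURN_IF_ERROR_IN_INIT(RunInitializers(bef_buffer));
  return absl::OkStatus();
}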
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Conditional builder. +// Convenience builder to make it easy to construct a conditional. E.g., +// Node* pred = ...; +// CondBuilder cb("cond", g); +// auto switch_var = cb.AddInput("var", DT_RESOURCE); +// g->AddEdge(pred, 0, cb.pred(), 0); +// Will create the nodes of a conditional that takes as input a resource +// variable ("var") as input and that switches on pred. +// +// This currently only handles the case needed by distributed_tpu_rewrite_pass +// and is not completely general. +class CondBuilder { + public: + enum Branch { kElseBranch = 0, kThenBranch = 1 }; + + CondBuilder(std::string name, std::string device, const NodeDebugInfo& debug, + Graph* graph); + + // Returns node corresponding to the predicate input. + Node* pred(); + + // Returns node corresponding to switch_f branch of predicate switch. + Node* switch_f(); + + // Returns node corresponding to switch_t branch of predicate switch. + Node* switch_t(); + + // Returns node corresponding to control successor. + Node* control_successor(); + + // Returns the Switch node to feed a value of the given type into the + // conditional. + absl::Status AddInput(const std::string& input_name, const DataType& type, + const std::string& device, const NodeDebugInfo& debug, + Node** input); + + private: + Node* control_successor_; + Node* switch_f_; + Node* switch_t_; + Node* pred_; + Graph* const graph_; + const std::string name_; + const std::string device_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_COND_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/configure_tpu_embedding_rewrite_pass.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/configure_tpu_embedding_rewrite_pass.h new file mode 100644 index 00000000..977447f2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/configure_tpu_embedding_rewrite_pass.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Rewrites ConfigureTPUEmbedding Op into nodes which set up TPUEmbedding. 
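Supplementing the usage outline in the CondBuilder comment above, a hedged sketch that follows the declared constructor and AddInput signatures; `pred` and `graph` are assumed to exist already, and the device string is illustrative.

// Build a conditional that switches a resource-variable input on `pred`.
NodeDebugInfo debug(pred->def());
CondBuilder cb(/*name=*/"cond", /*device=*/"/job:worker/device:CPU:0", debug,
               graph);
Node* switch_var = nullptr;
TF_RETURN_IF_ERROR(cb.AddInput("var", DT_RESOURCE,
                               "/job:worker/device:CPU:0", debug,
                               &switch_var));
graph->AddEdge(pred, 0, cb.pred(), 0);
// cb.switch_t() / cb.switch_f() expose the then/else sides of the switch, and
// cb.control_successor() can be used to sequence downstream nodes.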
+ +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_CONFIGURE_TPU_EMBEDDING_REWRITE_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_CONFIGURE_TPU_EMBEDDING_REWRITE_PASS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// TODO(shizhiw): Clean up embedding related code from +// distributed_tpu_configuration_rewrite_pass.cc. +// Replaces dummy ConfigureTPUEmbedding Ops assigned to TPU_SYSTEM +// devices with nodes which will set up TPU Embedding. +class ConfigureTPUEmbeddingRewritePass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_CONFIGURE_TPU_EMBEDDING_REWRITE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h new file mode 100644 index 00000000..ecde017f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_configuration_rewrite_pass.h @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Rewrites ConfigureDistributedTPU Op into a graph that configures each host. +// +// See the comment at the top of +// third_party/tensorflow/core/ops/tpu_configuration_ops.cc to see the +// sequence of Ops used to configure a distributed TPU system. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Replaces dummy ConfigureDistributedTPU Ops assigned to TPU_SYSTEM +// devices with _ConfigureDistributedTPU and _WaitForDistributedTPU +// Ops on TPU_SYSTEM, and _InitializeHostForDistributedTPU on the CPU +// device of each host in the same job as the given TPU_SYSTEM device. +class DistributedTPUConfigurationRewritePass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +// Replaces dummy ShutdownDistributedTPU Ops assigned to TPU_SYSTEM +// devices with _ShutdownDistributedTPU Ops on TPU_SYSTEM and +// _DisconnectHostFromDistributedTPUSystem on the CPU device of each +// host in the same job as the given TPU_SYSTEM device. 
+class DistributedTPUShutdownRewritePass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_CONFIGURATION_REWRITE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h new file mode 100644 index 00000000..ae4bfc8b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.h @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for TPU rewrite passes. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "xla/status_macros.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +class DistributedTPURewriteHelpers { + public: + // Given a user-assigned device string, system_spec_string, parse it into + // system_spec. Verify that the device type is either TPU_SYSTEM or + // unassigned, and in the latter case set it to TPU_SYSTEM:0. Having set the + // type, verify that the spec matches a unique device in device_set, and + // return that device in system_device. The normal use case is for + // system_spec_string to identify the TPU_SYSTEM on replica 0, task 0 of the + // job that contains the TPU hardware. + // TODO(b/110910013): Possibly remove the tpu system device. + static absl::Status GetSystemDevice(const string& system_spec_string, + const DeviceSet& device_set, + DeviceNameUtils::ParsedName* system_spec, + Device** system_device); + + // Given a parsed system spec (e.g., the one returned above from + // GetSystemDeviceName), return in host_devices the TPU_SYSTEM:0 device on + // every host in the spec's job. If the spec does not include an explicit job, + // "localhost" is used. Returns an error if system_spec matches devices from + // a multiple jobs or replicas. + static absl::Status GetHostSystemDevices( + const DeviceNameUtils::ParsedName& system_spec, + const DeviceSet& device_set, std::vector* host_system_devices); + + // Given a parsed system spec (e.g., the one returned above from + // GetSystemDeviceName), sets `*tpu_devices` to a per-host vector of the TPU + // devices on every host in the spec's job. 
If the spec does not include an + // explicit job, "localhost" is used. Sets `*num_tpus_per_host` to the number + // of TPU devices in each host, and verifies that each host in the job has + // the same number of TPU devices. + // Returns an error if system_spec matches devices from a multiple jobs or + // replicas. + static absl::Status GetTPUDevices( + const DeviceNameUtils::ParsedName& system_spec, + const DeviceSet& device_set, int* num_tpus_per_host, + std::vector>* tpu_devices); + + // Perform 'action' on every node in 'graph' of type + // 'node_type'. This function is designed for use with configuration + // Ops that have no inputs or outputs. The arguments passed to 'action' are: + // 'configuration_node_name': the name of the node that matched + // 'configuration_device_name': the name of the device that the + // matching node is placed on + // 'host_devices': the set of TPU_SYSTEM devices on hosts with TPUs that are + // in the same system as the node that matched. + // 'input_dependencies': the set of nodes that have control edges to + // the matching node. + // 'output_dependencies': the set of output port, destination node, input port + // triples that have edges from the matching node. Input port is + // Graph::kControlSlot for a control edge. + // 'graph': the graph being mutated. + struct OutputDependency { + int src_output; + Node* dst; + int dst_input; + }; + static absl::Status ForConfigurationNodeMatchingType( + const string& node_type, Graph* graph, const DeviceSet& device_set, + const std::function< + absl::Status(const NodeDef& configuration_node_def, + const string& configuration_device_name, + const std::vector& host_devices, + const std::vector& input_dependencies, + const std::vector& output_dependencies, + Graph* graph)>& action); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_HELPERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h new file mode 100644 index 00000000..2c31b6d8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.h @@ -0,0 +1,619 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Rewrites TPUReplicate nodes into replicated computations on TPU. +// +// To represent a distributed TPU computation, we use the +// TPUReplicate operator, that describes a subgraph (represented as a +// Tensorflow function) to replicate across a TPU pod. +// +// Model parallelism and data parallelism: +// --------------------------------------- +// We support two different kinds of parallelism on TPU: +// * data parallelism (replication), or parallelization across batches, and +// * model parallelism, or parallelization within a batch. 
+// +// The function passed to a TPUReplicate operator is replicated many +// times across a TPU pod (data parallelism). The `num_replicas` attribute +// controls how many replicas of the computation to create. Replicas are mostly +// independent; replicas can only communicate using the CrossReplicaSum +// operator, which is typically used to communicate gradients during training. +// +// Each replica may optionally use more than one TPU core (model +// parallelism). The `num_cores_per_replica` attribute controls how many cores +// there are per replica. For each core, there is a virtual TPU_REPLICATED_CORE +// device that is only valid within replicated TPU computations (e.g., +// TPU_REPLICATED_CORE:0, TPU_REPLICATED_CORE:1, etc.); each TPU_REPLICATED_CORE +// device corresponds to one TPU core in every replica. +// Each replica has runs its own copy of the computation assigned to each +// TPU_REPLICATED_CORE device. +// +// The Python code is responsible for providing a device_assignment that +// describes how the replicated logical cores map to physical cores on the TPU +// topology. +// +// Inputs to TPUReplicate: +// ------------------------------ +// The TPUReplicate operator takes three kinds of inputs, in the +// following order: +// * per-replica inputs. If there are three per-replica inputs (A, B, C) and two +// replicas, the first six arguments to TPUReplicate will be: +// A0 B0 C0 A1 B1 C1 +// where Ai is the A input to the i-th replica. +// * distributed inputs. These inputs follow the per-replica inputs. +// If there are two distributed inputs (E, F) and two replicas, the following +// arguments to TPUReplicate will be: E F. +// But there is local E and F on each replica. +// * broadcast inputs. These inputs follow the distributed inputs. All +// replicas receive a copy of each of these inputs. +// * variables. Resource variables accessed by the computation follow the +// broadcast inputs. +// +// For example, for a computation with two replicas, three per-replica inputs +// (A, B, C), two distributed inputs(E, F), two broadcast inputs (X, Y), and two +// variables (V, W), the arguments to TPUReplicate will be: +// A0 B0 C0 A1 B1 C1 E F X Y V W +// and each replica will receive the following arguments: +// A B C E F X Y V W +// +// Distributed TPU compilation requires that the shapes of all operators +// be known statically at compilation time, before any nodes have executed. +// Shapes are determined using shape information emitted by InferShapes. It +// is not possible to replicate Tensorflow operators with unknown or dynamic +// shapes for TPU at present. +// +// Graph rewrite: +// -------------- +// Compilation replaces TPUReplicate operators with: +// * a single TPUCompile node that compiles the computations, +// * one TPUExecute node for each TPU device in the system that +// executes the relevant computation, +// * one ReadVariableOp for each variable accessed by the replicated +// computation, +// * one AssignVariableOp for each variable accessed by the replicated +// computation. An assignment is built even if a variable is only read by the +// computation. We do not know which variables are written until we apply the +// XlaCompiler to the computation, but that does not happen until after the +// rewrite. Conservatively, we write back the values of all variables after +// the computation completes. +// TODO(phawkins): only write back variables that the computation may write. 
+// * one Shape node for each Tensor or Variable input to the computation whose +// shape is not statically known at rewrite time. The input shapes are fed +// to the TPUCompile node. +// +// To ensure that the reads and writes seem to happen at the right time in the +// graph execution, we add control edges from all predecessors of the original +// TPUReplicate operator to each of the ReadVariableOp operators. +// Similarly, we add control edges from all of the AssignVariableOp operators to +// all of the successors of the TPUReplicate operator. +// +// The TPUReplicate rewrite must run before placement, since resource +// variable inputs will have DT_RESOURCE, which cannot be sent across devices, +// leading to objections from the placer. The rewrite rewrites the resource +// accesses into explicit ReadVariableOp and AssignVariableOp operators that the +// placer is free to colocate with the variables. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/jit/shape_inference.h" +#include "xla/service/computation_placer.h" +#include "xla/stream_executor/tpu/tpu_topology.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +// Replaces clusters assigned to TPU_SYSTEM devices with +// TPUCompile and TPUExecute nodes assigned to the corresponding +// TPU devices. +class DistributedTPURewritePass : public GraphOptimizationPass { + public: + static void SetDistributedTpuRewritePassOptions( + bool distribute_vars, bool allow_xla_spmd_partition, + bool replicate_inputs_outputs_by_default_for_xla_spmd, + bool enable_cross_replica_sharding_mirrored_variables, + bool enable_automatic_model_parallelism, bool enable_xla_param_broadcast, + bool enable_multicore_locking, bool use_nd_sharding_ops); + + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for the use of unit tests. + + // See comment at the top of the file for how the inputs are ordered. + // Encapsulates the different TPU replicated node input and output + // information, and provide common APIs over them. 
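Before the ParameterInfo class itself (defined just below), a hedged worked example of the bookkeeping it provides, using the two-replica example from the file comment above: per-replica inputs (A, B, C), distributed inputs (E, F), broadcast inputs (X, Y), variables (V, W), no guaranteed constants, and an assumed single return value per replica.

// The host-side argument list is A0 B0 C0 A1 B1 C1 E F X Y V W.
DistributedTPURewritePass::ParameterInfo params(
    /*num_replicas=*/2, /*num_per_replica_args=*/3,
    /*num_distributed_args=*/2, /*num_broadcast_args=*/2,
    /*num_variables=*/2, /*num_guaranteed_constants=*/0,
    /*num_retvals_per_replica=*/1);
CHECK_EQ(params.NumInputsFromHost(), 12);         // A0..C1, E, F, X, Y, V, W
CHECK_EQ(params.NumInputsToEachReplica(), 9);     // A B C E F X Y V W
CHECK_EQ(params.FirstBroadcastArgFromHost(), 8);  // index of X in the flat list
// In the per-replica argument view: indices 0-2 are per-replica args,
// 3-4 distributed, 5-6 broadcast, 7-8 variables.
CHECK(params.IsPerReplicaArg(0));
CHECK(params.IsDistributedArg(3));
CHECK(params.IsBroadcastArg(5));
CHECK(params.IsVariableArg(7));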
+ class ParameterInfo { + public: + ParameterInfo() = default; + ParameterInfo(int64_t num_replicas, int64_t num_per_replica_args, + int64_t num_distributed_args, int64_t num_broadcast_args, + int64_t num_variables, int64_t num_guaranteed_constants, + int64_t num_retvals_per_replica) + : num_replicas_(num_replicas), + num_per_replica_args_(num_per_replica_args), + num_distributed_args_(num_distributed_args), + num_broadcast_args_(num_broadcast_args), + num_variables_(num_variables), + num_guaranteed_constants_(num_guaranteed_constants), + num_retvals_per_replica_(num_retvals_per_replica) {} + + int64_t NumReplicas() const { return num_replicas_; } + + int64_t NumPerReplicaArgs() const { return num_per_replica_args_; } + + int64_t NumDistributedArgs() const { return num_distributed_args_; } + + int64_t NumBroadcastArgs() const { return num_broadcast_args_; } + + int64_t NumVariables() const { return num_variables_; } + + int64_t NumGuaranteedConstants() const { return num_guaranteed_constants_; } + + int64_t NumRetvalsPerReplica() const { return num_retvals_per_replica_; } + + bool IsPerReplicaArg(int64_t index) const { + return index < num_per_replica_args_; + } + + bool IsDistributedArg(int64_t index) const { + return index >= num_per_replica_args_ && + index < (num_per_replica_args_ + num_distributed_args_); + } + + bool IsBroadcastArg(int64_t index) const { + return (index >= num_per_replica_args_ + num_distributed_args_) && + index < (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_); + } + + bool IsVariableArg(int64_t index) const { + return index >= (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_) && + index < (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_); + } + + bool IsConstantArg(int64_t index) const { + return index >= (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_) && + index < (num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_ + + num_guaranteed_constants_); + } + + // Returns the number of inputs which has been received by the host. + int64_t NumInputsFromHost() const { + return num_replicas_ * num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_ + num_guaranteed_constants_; + } + + // Returns the number of inputs which will be sent to each replica. + int64_t NumInputsToEachReplica() const { + return num_per_replica_args_ + num_distributed_args_ + + num_broadcast_args_ + num_variables_ + num_guaranteed_constants_; + } + + // Returns the total number of output values returned to the host (for all + // replicas). + int64_t NumOutputsToHost() const { + return num_replicas_ * num_retvals_per_replica_; + } + + // Returns the position of the first per-replica argument, within the set + // of all hosts arguments. + // Broadcast arguments follow the distributed arguments. + int64_t FirstBroadcastArgFromHost() const { + return num_replicas_ * num_per_replica_args_ + num_distributed_args_; + } + + // Indices of mirrored variables across replicas, which should be + // categorized as per_replica_args. 
+ const std::set& mirrored_variable_indices() const { + return mirrored_variable_indices_; + } + std::set* mutable_mirrored_variable_indices() { + return &mirrored_variable_indices_; + } + + private: + int64_t num_replicas_ = 1; + int64_t num_per_replica_args_ = 0; + int64_t num_distributed_args_ = 0; + int64_t num_broadcast_args_ = 0; + int64_t num_variables_ = 0; + int64_t num_guaranteed_constants_ = 0; + int64_t num_retvals_per_replica_ = 0; + std::set mirrored_variable_indices_; + }; + + // Mapping from TPUReplicate cluster name to tpu device names. Value is a + // mapping from [replica][core] to a TF device name. + typedef absl::flat_hash_map>> + TPUReplicateDeviceNamesMapping; + + // Determines which devices to use to run the computation. + // Inputs: + // * num_tpus_per_task: the number of TPU devices attached to each task + // * tpu_devices: a [task][device] collection of TPU devices + // * num_replicas: the number of replicas requested + // * num_cores_per_replica: the number of cores in each computation instance + // * topology_attr: the topology TPUReplicate attribute + // * device_assignment_attr: the device_assignment TPUReplicate attribute + // Outputs: + // * tf_device_assignment: a mapping from [replica][core] to a TF device name + // * devices_to_lock: a flat array of integer indices corresponding to devices + // that are used in this computation. They will be locked before the + // TPUExecute kernels are run, to ensure that the kernels from concurrent + // multi-core executions are enqueued consistently, i.e., all kernels from + // computation A before any kernel from computation B, thus preventing + // deadlock. + // * xla_device_assignment: a mapping from [replica][core] to a linearized TPU + // coordinate. + // TODO(phawkins): change tf_device_assignment to an xla::Array2D. + static absl::Status BuildDeviceAssignment( + const tpu::TpuTopologyExternal& topology, int num_tpus_per_task, + const std::vector>& tpu_devices, int num_replicas, + int num_cores_per_replica, const std::string& topology_attr, + absl::Span device_assignment_attr, + std::vector>* tf_device_assignment, + std::vector* devices_to_lock, + std::unique_ptr* xla_device_assignment); + + // Returns the `computation` graph attached to TPUReplicate operator + // `node`. `flr` is a FunctionLibraryRuntime to use when + // instantiating the function body. Sets `*arg_types` and + // `*retval_types` to the argument/return types of the function. + static absl::Status GetComputationForTPUReplicateOp( + const NameAttrList& function, FunctionLibraryRuntime* flr, + Graph* computation, DataTypeVector* arg_types, + DataTypeVector* retval_types); + + // Returns the shapes of the argument tensors and return values of the + // TPUReplicate operator `node` using the _output_shapes, + // _output_handle_shapes, and _output_handle_types annotations on the input + // nodes. Expects inputs in the following order (see comment at top of file): + // * num_replicas * num_per_replica_args per-replica inputs, + // * num_broadcast_args broadcast inputs, + // * num_variables variable inputs. + // Returns an error if the input shapes to `node` are not statically known. + // Also verifies that all replicas have identical input shapes for their + // per-replica inputs. + static absl::Status GetArgAndRetvalShapes( + const GraphShapeInfo& shape_info, const Node& node, + const ParameterInfo& params_info, std::vector* arg_shapes, + std::vector* retval_shapes); + + // Assigns arguments and return values to cores. 
The assignment is represented + // as an XLA op sharding, so that an argument can be replicated across cores. + // `arg_sharding` and `retval_sharding` are vectors of shardings indexed by + // argument/retval number. + // `arg_fast_mem` is vector of fast_mem indication which is indexed by + // argument number. + static absl::Status AssignArgsAndRetvalsToCores( + int num_cores_per_replica, const ParameterInfo& params_info, + const DataTypeVector& arg_types, + const std::vector& arg_shapes, + const DataTypeVector& retval_types, + const std::vector& retval_shapes, const Graph& graph, + const Node* replicate_node, FunctionLibraryRuntime* flr, + bool allow_parameter_replication_for_spmd, + std::vector<::xla::OpSharding>* arg_sharding, + std::vector* arg_fast_mem, + std::vector<::xla::OpSharding>* retval_sharding, + std::vector* arg_names); + + // Populates `*variables` with the "variables" inputs to `index`-th output of + // `node`. + struct VariableInput { + Node* node; + int index; + + // Type of the variable's value. Note that this is different to the type of + // the output of 'variable', which is always DT_RESOURCE. + DataType dtype; + }; + static absl::Status FindVariableInputs(const Node& node, + const NameRangeMap& input_range_map, + std::vector* variables); + + // Populates '*guaranteed_constants' with the "guaranteed_constants" inputs + // to 'node'. + static absl::Status FindGuaranteedConstantInputs( + const Node& node, const NameRangeMap& input_range_map, + std::vector* guaranteed_constants); + + // Builds Shape nodes that compute the shapes of arguments whose shapes are + // not statically known. + static absl::Status BuildDynamicShapeNodes( + const Node& replicate_node, const std::vector& arg_shapes, + const ParameterInfo& params_info, + const std::vector& variable_reads, Graph* graph, + std::vector* dynamic_shape_nodes); + + // Builds a TPUCompile node that compiles the computation in + // `function_names`. calls `nodes`. + // TODO(b/33943292): at present, for model parallelism with Send/Recv to work + // the `nodes` must correspond to the computations assigned to TPU:0, + // TPU:1, ... in order since XLA hard-codes the chip IDs in the generated + // executables. + static absl::Status BuildCompileNode( + const Node* replicate_node, const NameAttrList& function, + uint64_t library_fingerprint, const ParameterInfo& params_info, + const std::vector& arg_shapes, + const DataTypeVector& arg_types, + const std::vector& guaranteed_constant_nodes, + const std::string& session_handle, + const std::vector<::xla::OpSharding>& arg_sharding, + const std::vector& arg_fast_mem, + const std::vector& arg_names, + const std::vector<::xla::OpSharding>& retval_sharding, + int num_cores_per_replica, const std::string& compile_device, + const xla::DeviceAssignment* xla_device_assignment, + const std::vector& dynamic_shape_nodes, Graph* graph, + Node** compile_node, int64_t autotuner_thresh); + + // Builds a TPUCompileSucceededAssert node that verifies that compilation + // succeeded and replaces the TPUCompilationStatus node in the graph. + static absl::Status BuildCompilationStatusReturnNodes( + Node* replicate_node, Node* compile_node, + absl::Span devices_to_lock, Node** control_after_compilation, + Node** multilock_acquire, Graph* graph); + + // Builds ReadVariableOp nodes that read `variables`, with a control + // edges that ensure they happen after `control_predecessor`. 
+ static absl::Status BuildVariableReads( + absl::Span variables, Node* control_predecessor, + Graph* graph, std::vector* variable_reads); + + // Returns true if graph or functions contain resource write op, otherwise + // return false. + // TODO(b/137048563): Recognize unused resource rewrite op. + static bool ContainsResourceWriteOp(const Graph& graph, + const FunctionLibraryDefinition& fld); + // Struct that describes a variable value to be written back from TPUExecute. + struct VariableWrite { + // A node:output pair containing a boolean tensor that determines whether + // the value should be written back. + Node* predicate; + int predicate_output; + + // A node:output pair containing the value to be written back. + Node* value; + int value_output; + }; + + // Builds AssignVariableOp nodes that write `variables` with the values from + // `variable_writes`, with control edges that ensure the writes happen before + // `control_successor`. + static absl::Status BuildVariableWrites( + absl::Span variables, Node* control_successor, + absl::Span variable_writes, Graph* graph); + + // Builds TPUExecute operators assigned to each TPU device + // involved in the computation. + // Arguments: + // * `params_info` is the structure containing the information about the + // TPUReplicate node inputs and outputs. + // * `num_tasks` is the number of TensorFlow tasks in the slice. + // * `num_cores_per_replica` is the number of cores which are dedicated to + // each replica. + // * `replicate_node` is the original TPUReplicate node. + // * `arg_names` are the names of the arguments to the computation function + // passed as argument to TPUReplicate, including per-replica, + // broadcast, and variable arguments. + // * `arg_types` are the corresponding types of the arguments. + // * `arg_shapes` are the corresponding shapes (and handle types/shapes, if + // applicable). + // * `arg_shardings` and `retval_shardings` are mappings from + // arguments/return indices to shardings, as returned by + // `AssignArgsAndRetvalsToCores`. + // * `pod_devices` lists the devices to assign to each core of each replica. + // * `variable_reads` is a vectors of ReadVariableOp operators, one for each + // variable argument to the computation. + // * The execute operators will have a control edge from + // `control_predecessor` and another control edge to `control_successor`. + // Populates '*variable_writes' with information about variable values to + // write back. + static absl::Status BuildExecuteNodes( + const ParameterInfo& params_info, int num_tasks, + int num_cores_per_replica, const Node& replicate_node, + const std::vector& arg_names, + const DataTypeVector& arg_types, + const std::vector& arg_shapes, + const DataTypeVector& retval_types, + const std::vector<::xla::OpSharding>& arg_shardings, + const std::vector<::xla::OpSharding>& retval_shardings, + const std::vector>& tpu_device_names, + Node* compile_node, const std::vector& variable_reads, + Node* control_predecessor, Node* control_successor, + Node* multilock_acquire, std::vector* variable_writes, + Graph* graph); + + // Connects the compile node to all the host transfer nodes, and removes the + // key placeholder node that was previously standing in for it. + // Arguments: + // * `compile_node` is the TPUCompile node that has been added to the graph. + // * `key_placeholder_node` is the placeholder node to send the key to all the + // host + // * transfer nodes in the original graph. + // * `graph` is the graph being rewritten. 
+ static absl::Status ConnectHostComputeNodes(Node* compile_node, + Node* key_placeholder_node, + Graph* graph); + + // Map from a Node in an outside_compilation cluster in the original graph to + // the list of Nodes, one for each replica, that it is expanded into during + // replication. + typedef absl::node_hash_map> NodeToNodeReplicasMap; + + // Map from the name of an outside_compilation cluster to the model-parallel + // core index that the HostCompute Op should be placed on in that cluster. + typedef std::map HostComputeCoreMap; + + // Map from the name of an outside_compilation cluster to the list of Nodes + // that should run on the host for that cluster. + typedef std::map> OutsideCompilationNodeMap; + + // Copies the outside_compilation nodes in a cluster to create replica + // replica_index. + static absl::Status CopyOutsideCompilationNodes( + int replica_index, const std::vector& outside_compilation_nodes, + const DeviceNameUtils::ParsedName& tpu_device, + const DeviceNameUtils::ParsedName& partial_device, + NodeToNodeReplicasMap* node_images, Graph* graph); + + // Replicates all the nodes in outside_compilation clusters in a compiled + // computation. + static absl::Status ReplicateOutsideCompilationNodes( + const std::vector>& tf_device_assignment, + const HostComputeCoreMap& host_compute_core, + const OutsideCompilationNodeMap& outside_compilation_nodes, + NodeToNodeReplicasMap* node_images, Graph* graph); + + // Lifts the edges between original outside_compilation nodes in a cluster + // onto their replicas. + static absl::Status CopyOutsideCompilationEdges( + const std::vector& outside_compilation_nodes, + const NodeToNodeReplicasMap& node_images, + std::unordered_map outside_compilation_inputs, + Graph* graph); + + // Lifts all the edges in outside_compilation clusters in a compiled + // computation to their replicas. + static absl::Status ReplicateOutsideCompilationEdges( + const OutsideCompilationNodeMap& outside_compilation_nodes, + const NodeToNodeReplicasMap& node_images, + std::unordered_map outside_compilation_inputs, + Graph* graph); + + // Removes all the original outside_compilation nodes from the graph, + // following replication. + static absl::Status RemoveOutsideCompilationNodes( + const NodeToNodeReplicasMap& node_images, Graph* graph); + + // Lowers outside compilation functional nodes (If/While/function call). + // Otherwise, when we have multiple workers, device placer will not be able to + // place nodes if outside compilation has DT_RESOURCE inputs (e.g. a + // DT_RESOURCE input fed into multiple While nodes on different devices). + static absl::Status LowerOutsideCompilationFunctionalNodes( + Graph* g, FunctionLibraryDefinition& flib_def, + const TPUReplicateDeviceNamesMapping& tpu_replicate_device_names_mapping); + + // Parses the 'host_compute_core' attribute on replicate_node to get the + // replicated core id of each outside_compilation cluster. + static absl::Status ParseHostComputeCores( + const Node& replicate_node, + const OutsideCompilationNodeMap& outside_compilation_nodes, + HostComputeCoreMap* host_compute_core); + + // Gets the physical topology information about the TPU system. 
+ static absl::Status GetDeviceTopology( + const DeviceSet& device_set, const Node& replicate_node, + int* num_replicas, int* num_cores_per_replica, int* num_tasks, + std::vector>* tf_device_assignment, + std::vector* devices_to_lock, + std::unique_ptr* xla_device_assignment, + std::string* tpu_compilation_device); + + // Gets the types of args, retvals, and parameters. + static absl::Status GetIOTypes( + int num_replicas, const Node& replicate_node, FunctionLibraryRuntime* flr, + Graph* graph, NameRangeMap* input_name_map, const NameAttrList** function, + std::unique_ptr* computation, DataTypeVector* arg_types, + DataTypeVector* retval_types, ParameterInfo* params_info); + + // Find known constants and deals with variable reads. + static absl::Status DealWithConstantsAndVariables( + const Node& replicate_node, const NameRangeMap& input_name_map, + Graph* graph, Node* host_transfer_sequencer, Node* control_before, + Node* control_after, absl::Span variable_nodes, + std::vector* guaranteed_constant_nodes, + std::vector* variable_reads); + + // Adds NoOp nodes for sequencing computation and variable reads/writes. + static absl::Status BuildSequencingNodes( + const std::string& tpu_compilation_device, const Node& replicate_node, + Graph* graph, Node** host_transfer_sequencer, Node** control_before, + Node** control_after); + + // Performs the pass's rewrite on a TPUReplicate node `node`. + static absl::Status RewriteTPUReplicateNode( + const std::string& session_handle, const DeviceSet& device_set, + Node* replicate_node, FunctionLibraryDefinition* flib_def, + FunctionLibraryRuntime* flr, Node* host_compute_key_placeholder_node, + const OutsideCompilationNodeMap& outside_compilation_nodes, + const std::vector& head_tail_outside_compilation_nodes, + NodeToNodeReplicasMap* outside_compilation_node_images, Graph* graph, + const GraphShapeInfo& shape_info, + TPUReplicateDeviceNamesMapping* tpu_replicate_device_names_mapping, + int64_t autotuner_thresh); + + // Performs host training loop optimization. For example, when TPUExecute + // node is inside a while loop, then model weight variables can be sharded + // in XLA preferred layout and then unsharded only at the very last iteration + // to reduce the number of all_gather. + static absl::Status PerformHostTrainingLoopOptimization( + Graph* graph, FunctionLibraryDefinition* flib_def, + FunctionLibraryRuntime* flr); + + // Heuristically place some nodes with unassigned devices on TPUs for + // performance reasons. + static absl::Status PlaceUnassignedDeviceNodesOnTPUIfPossible(Graph* graph); + + // Updates the head and tail outside compiled nodes so that nodes have the + // correct device and removes the replication and outside compilation + // attributes so that these nodes do not trigger further graph optimization + // passes. 
+ static absl::Status UpdateHeadTailOutsideCompilation( + const std::vector>& tf_device_assignment, + const std::vector& head_tail_outside_compilation_nodes); + + private: + static bool distribute_vars_; + static bool allow_xla_spmd_partition_; + static bool replicate_inputs_outputs_by_default_for_xla_spmd_; + static bool enable_cross_replica_sharding_mirrored_variables_; + static bool enable_automatic_model_parallelism_; + static bool enable_xla_param_broadcast_; + static bool enable_multicore_locking_; + static bool use_nd_sharding_ops_; + absl::Status InternalRun(const GraphOptimizationPassOptions& options); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h new file mode 100644 index 00000000..ad4d74c2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass_internal.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_ + +#include + +namespace tensorflow { + +// Implementation details of distributed_tpu_rewrite_pass.cc, please DO NOT +// depend on these. +namespace internal { + +// When set to a value >= 0, overrides the node_id. Used for getting +// deterministic node_ids during testing. +void OverrideNodeIdForTesting(int64_t node_id); + +// Retrieves the node id, used to make some node names unique in the rewrite +// pass. +uint64_t GetNodeId(); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_DISTRIBUTED_TPU_REWRITE_PASS_INTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h new file mode 100644 index 00000000..37ba029c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Rewrites computations generated by the tpu.replicate() Python code into +// TPUReplicate operators. +// +// The tpu.replicate() does two main things: +// a) marks operators that make up a TPU computation with the attribute +// _tpu_replicate=XYZ, where XYZ is a unique key. +// b) adds TPUReplicatedInput and TPUReplicatedOutput nodes to represent +// replicated inputs. These nodes are not marked with the _tpu_replicate +// attribute. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/jit/encapsulate_util.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Encapsulates nodes marked with the _tpu_replicate attribute into +// TPUReplicate operators. +class EncapsulateTPUComputationsPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + // The following methods are public only for unit tests. + + // This pass has two stages: + // a) first, we call the EncapsulateSubgraphsPass to encapsulate all nodes + // marked with the same _tpu_replicate attribute into functions. These + // functions contain the computations to be passed to TPUReplicate. During + // encapsulation, we sort the arguments into the order expected by + // TPUReplicate. + static absl::Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); + + // b) we rewrite the function calls generated in phase (a) into TPUReplicate + // operators. We also flatten the TPUReplicatedInput and + // TPUReplicatedOutput replicated input and output nodes of the function + // call into the replicated input and outputs of the TPUReplicate operator. + static absl::Status BuildTPUReplicateOps(Graph* graph); +}; + +// Graph optimization pass that calls `ExtractOutsideCompilation` for all XLA +// computation nodes. +class ExtractOutsideCompilationPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + static absl::Status ProcessHeadTailOutsideCompilation( + const std::string& outside_compilation_attr_name, int* lifted_arg_count, + std::unordered_map* clusters, Graph* g, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_ENCAPSULATE_TPU_COMPUTATIONS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h new file mode 100644 index 00000000..a3d2e01f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/host_training_loop_optimization_util.h @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tpu { + +struct LoopArgInfo { + std::string enter_node_name; + // Exit nodes are optional for loop invariant while loop args. + std::optional exit_node_name; +}; + +struct HostTrainingLoopInfo { + // Name and attribute information about the function in which + // host training loop is included. If host training loop is not + // inside a function call, then `function_name` and `function_attrs` + // are nullopt. + std::optional encapsulating_function_name; + std::optional encapsulating_function_attrs; + + // TPU Compile node as within a host training loop. + std::string compile_node_name; + + // Name of the while loop in which TPU compile op is located. + std::string while_loop_name; + + // Name of the node that represents loop condition. + std::string loop_cond_node_name; + + // Exit and Enter node names for each loop arguments. + std::vector loop_arguments; + + std::unordered_set loop_nodes; // NOLINT +}; + +// Walks through the `graph`, recursively if functional nodes exist, and +// identifies all host training loops. Host training loops are the inner +// most while loops that encapsulates TPUCompileOp node. This would be +// later used/analyzed to introduce host loop specific optimizations such +// as adding sharded weight update. +absl::Status DetectHostTrainingLoop( + const std::string* current_function_name, + const AttrValueMap* current_function_attr, + const FunctionLibraryDefinition* library, Graph* graph, + FunctionLibraryRuntime* flr, + std::vector* host_training_loops_info); + +// Injects VariableReshardOps to before and after TPUExecute op inside +// host training loop body. This effectively applies sharded weight update +// on model weight variables. +absl::Status AddReshardOp(Graph* graph, + const HostTrainingLoopInfo& host_loop_info); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_HOST_TRAINING_LOOP_OPTIMIZATION_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h new file mode 100644 index 00000000..27304087 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/incomplete_nodedef_builder.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_INCOMPLETE_NODEDEF_BUILDER_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_INCOMPLETE_NODEDEF_BUILDER_H_ + +#include + +#include "absl/status/status.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Convenience builder to build NodeDefs without specifying the inputs. This is +// similar to NodeDefBuilder except inputs are not specified. +// TODO(jpienaar): Clean up NodeDefBuilder and remove this class. +class IncompleteNodeDefBuilder { + public: + IncompleteNodeDefBuilder(const string& name, const string& op, + const NodeDebugInfo& debug); + + IncompleteNodeDefBuilder& AddAttr(const string& attr, const DataType& type); + IncompleteNodeDefBuilder& AddAttr(const string& attr, int val); + + IncompleteNodeDefBuilder& Device(const string& device); + + absl::Status Build(Graph* graph, Node** n); + + static IncompleteNodeDefBuilder Identity(const string& name, + const DataType& type, + const NodeDebugInfo& debug); + static IncompleteNodeDefBuilder Merge(const string& name, + const DataType& type, + const NodeDebugInfo& debug, int n); + static IncompleteNodeDefBuilder Switch(const string& name, + const DataType& type, + const NodeDebugInfo& debug); + + private: + NodeDef nodedef_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_INCOMPLETE_NODEDEF_BUILDER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/tpu_embedding_rewrite_pass_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/tpu_embedding_rewrite_pass_utils.h new file mode 100644 index 00000000..3589dd83 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/tpu_embedding_rewrite_pass_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_TPU_EMBEDDING_REWRITE_PASS_UTILS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_TPU_EMBEDDING_REWRITE_PASS_UTILS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Adds a new TensorFlow graph node, with the output convention matching most TF +// code rather than the order used by Graph::AddNode(). +absl::Status AddNode(const NodeDef& n_def, Node** n, Graph* graph); + +// Replaces one TensorFlow graph node with another (specified by a NodeDef), +// moving all the edges. +absl::Status ReplaceNode(const NodeDef& to_def, Node* from, Node** to, + Graph* graph); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_TPU_EMBEDDING_REWRITE_PASS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/tpu_embedding_software_deduplication_rewrite_pass.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/tpu_embedding_software_deduplication_rewrite_pass.h new file mode 100644 index 00000000..7be458b0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/tpu_embedding_software_deduplication_rewrite_pass.h @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_TPU_EMBEDDING_SOFTWARE_DEDUPLICATION_REWRITE_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_TPU_EMBEDDING_SOFTWARE_DEDUPLICATION_REWRITE_PASS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// Rewrites the graph and function defs in the specified +// GraphOptimizationPassOptions object for software deduplication. +// +// For the graph, groups the RecvTPUEmbeddingActivations and +// SendTPUEmbeddingGradients nodes by their _tpu_replicate attribute. For each +// such group: +// 1. Inserts a XlaRecvTPUEmbeddingDeduplicationData node into the graph. +// 2. Replaces the public RecvTPUEmbeddingActivations node (if present) with the +// internal XlaRecvTPUEmbeddingActivations node. +// 3. Replaces the public SendTPUEmbeddingGradients node (if present) with the +// internal XlaSendTPUEmbeddingGradients node. +// 4. Connects the outputs of the XlaRecvTPUEmbeddingDeduplicationData node with +// the inputs of the XlaRecvTPUEmbeddingActivations and +// XlaSendTPUEmbeddingGradients nodes. +// +// Iterates through the list of functions in the specified +// GraphOptimizationPassOptions object. Performs the same steps 1-4 specified +// above for each function. 
+// +// If multiple RecvTPUEmbeddingActivations nodes or SendTPUEmbeddingGradients +// nodes are present in the same function or in the same _tpu_replicate group, +// an InvalidArgument error is returned to the caller. +class TPUEmbeddingSoftwareDeduplicationRewritePass : + public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_TPU_EMBEDDING_SOFTWARE_DEDUPLICATION_REWRITE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.h new file mode 100644 index 00000000..1c1dd7bb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/update_tpu_embedding_ops_passes.h @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Rewrites ConfigureTPUEmbedding Op into nodes which set up TPUEmbedding. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_UPDATE_TPU_EMBEDDING_OPS_PASSES_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_UPDATE_TPU_EMBEDDING_OPS_PASSES_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class UpdateTPUEmbeddingEnqueueOrdinalPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +class UpdateTPUEmbeddingModePass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; + + static absl::Status GetEnqueueOpsFromGraph( + Graph* graph, absl::flat_hash_map* enqueue); + static absl::Status UpdateGraphEnqueueOp(bool training, Graph* graph, + Node* enqueue); + static absl::Status GetEnqueueOpsFromFunctionDef( + FunctionDef* function, std::map* enqueue); + static absl::Status UpdateFunctionDefEnqueueOp(int enqueue, bool training, + FunctionDef* function, + bool* updated); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_UPDATE_TPU_EMBEDDING_OPS_PASSES_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h new file mode 100644 index 00000000..eaaaf1cf --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Optimization pass that merges VarHandleOps and ReadVariableOps into their +// fused forms. +// +// The goal of this pass is to fix a latency problem sometimes observed in +// inference benchmarks. Often a inference step starts by reading the value of +// many weights. Reading a resource variable requires a VarHandleOp and a +// ReadVariableOp per variable. Running hundreds of trivial ops can add hundreds +// of microseconds of latency to the critical path of an inference step. The +// inter-op latency of the executor can be easily hundreds of nanoseconds, which +// rapidly adds up over many inexpensive ops. +// +// This pass merges VarHandleOps that have only the graph source node as a +// predecessor into a single VarHandlesOp that reads all at once. +// It then merges ReadVariablesOp that have no control inputs and originate from +// the same handle op into a single large ReadVariablesOp. + +#ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_VARIABLE_MERGER_PASS_H_ +#define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_VARIABLE_MERGER_PASS_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +class VariableMergerPass : public GraphOptimizationPass { + public: + absl::Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_VARIABLE_MERGER_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/compiled_subgraph.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/compiled_subgraph.h new file mode 100644 index 00000000..bec973c6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/compiled_subgraph.h @@ -0,0 +1,173 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" + +namespace tensorflow { +namespace tpu { + +// Forward declaration to avoid circular dependency. +class TpuCompilationCacheInterface; + +// Cache for compiled TPU program. 
+// +// Each key identifies a unique subgraph, and the value is the vector of +// protos that are emitted by compiling the subgraph. +// +// When a subgraph is considered for compilation, the client calls +// +// auto subgraph_key = ; +// auto compile_function = ; +// auto per_step_ref_holder = ; +// int64 uid; +// std::vector proto_key; +// CompileIfKeyAbsent(subgraph_key, per_step_ref_holder, &uid, &proto_key, +// compile_function); +// +// where subgraph_key is the key computed for the subgraph. On success, +// proto_key contains a vector of keys, where proto_key[i] can be used to look +// up the ith proto compiled from the subgraph, and uid contains an identifier +// that can be used in place of key for clients that require cheap +// serializable handles. If the compiled protos were not present in the cache, +// compile_function would be called to generate them. per_step_ref_holder +// extends the lifetime of cached results: it is guaranteed that the protos +// indicated in proto_key will be available for lookup for at least as long as +// per_step_ref_holder is not deleted. +// +// If the caller passes nullptr instead of a per_step_ref_holder then the +// caller is responsible for calling Release(subgraph_key) once for every call +// to CompileIfKeyAbsent(subgraph_key, ...) to discard the reference to the +// compilation results, after the caller is sure it will not look up the +// compiled executables again. +// +// Subsequently the client can call +// +// std::unique_ptr entry; +// Lookup(proto_key, &entry); +// auto proto = entry->get(); +// +// or +// +// std::unique_ptr entry; +// Lookup(uid, proto_index, &entry); +// auto proto = entry->get(); +// +// to access a cached proto. +// TODO(misard) Switch the existing TPU ops to use uid+proto_index instead of +// string keys for proto lookups. +// +// +// Usage details within the system: +// +// This cache lives in the resource manager of the TPU_SYSTEM device where the +// compiler runs, typically worker 0 of the system. The cache is discarded and +// a new one created whenever the system is reinitialized. +// +// A compiled subgraph is placed into the cache using a key that is a +// combination of the function name, guaranteed_constants, the shapes of the +// dynamic inputs to the subgraph, and the function library in use at the time +// of execution. +// +// Whenever a compile Op is run, it looks to see if there is already an entry +// in the cache corresponding to that Op and the current dynamic shapes, and +// creates one if not. The entry is marked as most recently used in the cache +// by the compile Op. The entry is reference counted. The cache owns one entry +// , and each step that has executed a compile Op referring to the entry owns +// a reference until that step completes. +// +// If the cache exceeds a configured storage limit, entries are marked for +// eviction in order of least recently used. An entry is not evicted until all +// references to it are discarded, so an entry that is marked for eviction can +// still be looked up by the execute Ops in a running step. If another Compile +// Op looks up an entry that is marked for eviction, the entry will be +// unmarked and set to most recently used. +// +struct CompiledSubgraph : public core::RefCounted { + TpuCompilationCacheInterface* parent = nullptr; // Not owned. + + bool initialized = false; + + // The Status returned by the compilation function when the entry is + // initialized. This status will be returned to any client that requests the + // entry. 
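+  // For example, if the compile function fails, the error is recorded here
+  // and every later lookup of this entry observes that same error.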
+ absl::Status initialization_status; + + // Counter to keep track of LRU entries for the eviction policy. + int64_t last_use = -1; + + // The unique key describing this entry. + std::string subgraph_key; + + // The uid describing this entry. + int64_t uid; + + // Compilation cache proto key to identify the cache entry. + std::vector proto_key; + + // Fingerprints of sharding programs if there is any. + std::vector sharding_key; + + // The number of 'external' client-held references to the entry. + int external_references = 0; + + // The sum of the SpaceUsed of each of the elements of programs; an estimate + // of how much RAM the entry consumes, used to determine when entries must + // be marked for eviction. + int64_t total_size = 0; + + // Debug info in case we miss. + std::string cache_entry_debug_string; + + // Entries representing the associated sharding and unsharding programs, + // which share the same life time of the owning main entry, so we always use + // the main entry's ref count. + std::unique_ptr sharding_entry; + std::unique_ptr unsharding_entry; + + // Only used for the nested sharding/unsharding entries to point to the + // owning main entry. + CompiledSubgraph* main_entry = nullptr; + + // Compiled TPU program group. + std::unique_ptr tpu_program_group; + + // Computes total program size. + size_t ComputeTotalSize() const { + CHECK_EQ(total_size, 0); + int64_t size = tpu_program_group->program_size(); + + if (sharding_entry != nullptr) { + size += sharding_entry->total_size; + } + if (unsharding_entry != nullptr) { + size += unsharding_entry->total_size; + } + return size; + } +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/infeed_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/infeed_ops.h new file mode 100644 index 00000000..d6e24cf4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/infeed_ops.h @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_ + +#include +#include + +#include "xla/shape.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/transfer_ops.h" + +namespace tensorflow { + +// TODO(b/65200690): Rework this when there is a callback based infeed API to +// StreamExecutor. + +// The InfeedEnqueue op is used to deliver data to the device infeed queue. 
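+// Enqueue here is host-to-device: the op linearizes a host tensor into the
+// device's infeed queue (the prelinearized-buffer variant below skips the
+// linearization step).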
+class TpuInfeedEnqueueOp : public TpuTransferAsyncOpKernel { + public: + explicit TpuInfeedEnqueueOp( + OpKernelConstruction* ctx, + std::unique_ptr transfer_op); + absl::Status DoWork(OpKernelContext* ctx, int device_ordinal) override; + + private: + TensorShape shape_; + DataType dtype_; + xla::Shape xla_shape_; + + TpuInfeedEnqueueOp(const TpuInfeedEnqueueOp&) = delete; + TpuInfeedEnqueueOp& operator=(const TpuInfeedEnqueueOp&) = delete; +}; + +// The InfeedEnqueueTuple op is used on the host to deliver multiple tensors to +// the device infeed queue as an XLA tuple. +class TpuInfeedEnqueueTupleOp : public TpuTransferAsyncOpKernel { + public: + explicit TpuInfeedEnqueueTupleOp( + OpKernelConstruction* ctx, + std::unique_ptr transfer_op); + absl::Status DoWork(OpKernelContext* ctx, int device_ordinal) override; + + private: + std::vector shapes_; + DataTypeVector dtypes_; + xla::Shape tuple_shape_; + + TpuInfeedEnqueueTupleOp(const TpuInfeedEnqueueTupleOp&) = delete; + TpuInfeedEnqueueTupleOp& operator=(const TpuInfeedEnqueueTupleOp&) = delete; +}; + +// The InfeedEnqueuePrelinearizedBufferOp op is used to transfer prelinearized +// buffers to the device infeed queue. +class InfeedEnqueuePrelinearizedBufferOp : public TpuTransferAsyncOpKernel { + public: + explicit InfeedEnqueuePrelinearizedBufferOp( + OpKernelConstruction* ctx, + std::unique_ptr transfer_op); + + absl::Status DoWork(OpKernelContext* ctx, int device_ordinal) override; + + private: + InfeedEnqueuePrelinearizedBufferOp( + const InfeedEnqueuePrelinearizedBufferOp&) = delete; + InfeedEnqueuePrelinearizedBufferOp& operator=( + const InfeedEnqueuePrelinearizedBufferOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_INFEED_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/outfeed_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/outfeed_ops.h new file mode 100644 index 00000000..8f1562e8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/outfeed_ops.h @@ -0,0 +1,139 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_ + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/transfer_ops.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep + +namespace tensorflow { + +// The OutfeedDequeue op is used to retrieve a single tensor from the device +// outfeed queue. +template +class TpuOutfeedDequeueOp : public T { + public: + explicit TpuOutfeedDequeueOp( + OpKernelConstruction* ctx, + std::unique_ptr transfer_op) + : T(ctx, "outfeed_dequeue", 1, std::move(transfer_op)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape_, &xla_shape_)); + } + + absl::Status DoWork(OpKernelContext* ctx, int device_ordinal) override { + Tensor* output; + TF_RETURN_IF_ERROR(ctx->allocate_output(0, shape_, &output)); + + // Transfer from the outfeed interface of the device. + xla::MutableBorrowingLiteral literal; + TF_RETURN_IF_ERROR( + HostTensorToMutableBorrowingLiteral(xla_shape_, output, &literal)); + + VLOG(1) << "TransferLiteralFromOutfeed " + << xla::ShapeUtil::HumanStringWithLayout(xla_shape_); + + TF_RETURN_IF_ERROR( + T::transfer_op_->TransferLiteralFromOutfeed(device_ordinal, literal)); + + VLOG(1) << "TransferLiteralFromOutfeed complete."; + + return absl::OkStatus(); + } + + private: + TensorShape shape_; + DataType dtype_; + xla::Shape xla_shape_; + + TpuOutfeedDequeueOp(const TpuOutfeedDequeueOp&) = delete; + TpuOutfeedDequeueOp& operator=(const TpuOutfeedDequeueOp&) = delete; +}; + +// The OutfeedDequeueTuple op is used to retrieve multiple tensors from the +// device outfeed queue. +template +class TpuOutfeedDequeueTupleOp : public T { + public: + explicit TpuOutfeedDequeueTupleOp( + OpKernelConstruction* ctx, + std::unique_ptr transfer_op) + : T(ctx, "outfeed_dequeue", 1, std::move(transfer_op)) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shapes", &shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtypes", &dtypes_)); + OP_REQUIRES( + ctx, shapes_.size() == dtypes_.size(), + errors::InvalidArgument("shapes and dtypes must be the same length.")); + // The `dtypes` list is inferred from the supplied inputs, so it + // is always the correct length. 
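+    // Precompute the per-output XLA shapes (and, below, the aggregate tuple
+    // shape) once at construction time so DoWork() can reuse them.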
+ for (int i = 0; i < shapes_.size(); i++) { + xla::Shape xla_shape; + OP_REQUIRES_OK(ctx, + TensorShapeToXLAShape(dtypes_[i], shapes_[i], &xla_shape)); + xla_shapes_.push_back(xla_shape); + } + tuple_shape_ = xla::ShapeUtil::MakeTupleShape(xla_shapes_); + } + + absl::Status DoWork(OpKernelContext* ctx, int device_ordinal) override { + VLOG(1) << "TransferLiteralFromOutfeed " + << xla::ShapeUtil::HumanStringWithLayout(tuple_shape_); + + for (int i = 0; i < shapes_.size(); ++i) { + Tensor* output; + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shapes_[i], &output)); + + xla::MutableBorrowingLiteral literal; + TF_RETURN_IF_ERROR(HostTensorToMutableBorrowingLiteral(xla_shapes_[i], + output, &literal)); + TF_RETURN_IF_ERROR( + T::transfer_op_->TransferLiteralFromOutfeed(device_ordinal, literal)); + } + return absl::OkStatus(); + } + + private: + std::vector shapes_; + DataTypeVector dtypes_; + std::vector xla_shapes_; + xla::Shape tuple_shape_; + + TpuOutfeedDequeueTupleOp(const TpuOutfeedDequeueTupleOp&) = delete; + TpuOutfeedDequeueTupleOp& operator=(const TpuOutfeedDequeueTupleOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_OUTFEED_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sharding_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sharding_utils.h new file mode 100644 index 00000000..e557c5dd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sharding_utils.h @@ -0,0 +1,455 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "Eigen/Core" // from @eigen_archive +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace sharding_internal { +absl::Status ValidateShapesForSlice(absl::string_view input_name, + const Tensor* input, + const std::vector& num_splits, + const std::vector& paddings); +template +Eigen::DSizes TF_ATTRIBUTE_NOINLINE +ShapeAsEigenDSizes(const TensorShape& shape); +template +Eigen::DSizes ShapeAsEigenDSizes( + const TensorShape& shape) { + return shape.AsEigenDSizes(); +} + +} // namespace sharding_internal + +// Converts flatten index to start indices (subscript scaled with slice shape) +// for determining where to start a slice in the input tensor. 
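+// For example, assuming the last dimension varies fastest (an illustrative
+// sketch, not taken verbatim from the implementation): with
+// num_partitions = {2, 3} and slice_shape = {4, 5}, flat index 5 decomposes
+// into subscripts {5 / 3, 5 % 3} = {1, 2}, giving start indices
+// {1 * 4, 2 * 5} = {4, 10}.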
+template +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); +template <> +Eigen::DSizes TF_ATTRIBUTE_NOINLINE GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, int index); + +template +Eigen::DSizes GetSliceIndices( + absl::Span num_partitions, + const Eigen::DSizes& slice_shape, + const int index) { + return Eigen::DSizes(); +} + +// Shared base class to save code space +template +class XlaNDSplitter { + public: + static absl::StatusOr> Create( + const std::vector& num_splits, int num_slices, + const std::vector& paddings, bool has_paddings) { + if (num_splits.size() != paddings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_splits size ", num_splits.size(), + " mismatch with paddings size ", paddings.size(), ".")); + } + + int splits_cnt = 1; + for (auto split : num_splits) { + splits_cnt *= split; + } + + if (num_slices != splits_cnt) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect num_slices ", splits_cnt, " but got ", num_slices)); + } + + return XlaNDSplitter(num_splits, num_slices, paddings, + has_paddings); + } + + // Split the given input. + // + // The splitted outputs are stored into tensors allocated by + // `allocate_output_fn`. In the simple case of pass through (no split and no + // padding), the output is stored through the fast path by + // `assign_or_copy_value_fn`. + absl::Status Split( + const Tensor* input, absl::string_view input_name, + const std::function& assign_or_copy_value_fn, + const std::function& allocate_output_fn, + const Device& device) { + if (num_splits_.size() != paddings_.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_splits size ", num_splits_.size(), + " mismatch with paddings size ", paddings_.size(), ".")); + } + + const int rank = input->shape().dims(); + const auto& input_shape = input->shape().dim_sizes(); + + TF_RETURN_IF_ERROR(sharding_internal::ValidateShapesForSlice( + input_name, input, num_splits_, paddings_)); + + TensorShape output_slice_shape; + for (int i = 0; i < rank; ++i) { + output_slice_shape.AddDim((input_shape[i] + paddings_[i]) / + ((num_slices_ == 1) ? 
1 : num_splits_[i])); + } + if (num_slices_ == 1 && !has_paddings_) { + // Handle simple case first + TF_RETURN_IF_ERROR(assign_or_copy_value_fn(*input)); + } else { + std::vector output_slices(num_slices_); + for (int i = 0; i < num_slices_; i++) { + TF_RETURN_IF_ERROR(allocate_output_fn( + /*index=*/i, output_slice_shape, &output_slices[i])); + } + + if (rank == 1) { + SliceAndMaybePad<1>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 2) { + SliceAndMaybePad<2>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 3) { + SliceAndMaybePad<3>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 4) { + SliceAndMaybePad<4>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 5) { + SliceAndMaybePad<5>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 6) { + SliceAndMaybePad<6>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 7) { + SliceAndMaybePad<7>(device, input, input_shape, output_slice_shape, + output_slices); + } else if (rank == 8) { + SliceAndMaybePad<8>(device, input, input_shape, output_slice_shape, + output_slices); + } + } + return absl::OkStatus(); + } + + private: + template + class SliceAndMaybePadState { + public: + int num_complete_pad_dims_; + int num_partial_pad_dims_; + TensorShape non_padded_slice_shape_; + Eigen::array, Rank> slice_paddings_; + Eigen::DSizes slice_indices_; + Eigen::DSizes output_slice_shape_dsizes_; + Eigen::DSizes non_padded_slice_shape_dsizes_; + + TF_ATTRIBUTE_NOINLINE SliceAndMaybePadState( + absl::Span num_splits, + const absl::Span input_shape, + const TensorShape& output_slice_shape, int slice_index) { + output_slice_shape_dsizes_ = + sharding_internal::ShapeAsEigenDSizes(output_slice_shape); + num_complete_pad_dims_ = 0; + num_partial_pad_dims_ = 0; + slice_indices_ = GetSliceIndices( + num_splits, output_slice_shape_dsizes_, slice_index); + + // Calculate paddings necessary for slice instead of padding input and + // slicing subsequently to reduce temporary memory allocation. + for (int dim = 0; dim < Rank; ++dim) { + const int64_t dim_size = input_shape[dim]; + const int64_t out_dim = output_slice_shape_dsizes_[dim]; + int64_t non_padded_dim = 0; + if (slice_indices_[dim] >= dim_size) { + // Complete padding. + slice_indices_[dim] = dim_size; + non_padded_dim = 0; + slice_paddings_[dim] = {0, out_dim}; + num_complete_pad_dims_++; + } else if (slice_indices_[dim] + out_dim > dim_size) { + // Partial padding. 
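+          // Only the first (dim_size - slice_index) elements in this
+          // dimension come from the input; the remaining positions of the
+          // output slice are filled with padding.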
+ non_padded_dim = dim_size - slice_indices_[dim]; + slice_paddings_[dim] = {0, out_dim - non_padded_dim}; + num_partial_pad_dims_++; + } else { + non_padded_dim = out_dim; + } + non_padded_slice_shape_.AddDim(non_padded_dim); + } + non_padded_slice_shape_dsizes_ = + sharding_internal::ShapeAsEigenDSizes(non_padded_slice_shape_); + } + }; + + std::vector num_splits_; + int num_slices_; + std::vector paddings_; + bool has_paddings_; + + explicit XlaNDSplitter(const std::vector& num_splits, int num_slices, + const std::vector& paddings, + bool has_paddings) + : num_splits_(num_splits), + num_slices_(num_slices), + paddings_(paddings), + has_paddings_(has_paddings) {} + + void TF_ATTRIBUTE_NOINLINE SetToConstant(Tensor* output_slice, + const Device& device) { + auto output_flat = output_slice->flat(); + output_flat.device(device) = output_flat.constant(T()); + } + + template + void TF_ATTRIBUTE_NOINLINE AssignFromInput( + Tensor* output_slice, const Device& device, const Tensor* input, + const Eigen::DSizes& slice_indices, + const Eigen::DSizes& output_slice_shape_dsizes) { + output_slice->tensor().device(device) = + input->tensor().slice(slice_indices, + output_slice_shape_dsizes); + } + + template + void TF_ATTRIBUTE_NOINLINE + SliceAndMaybePad(const Device& device, const Tensor* input, + const absl::Span input_shape, + const TensorShape& output_slice_shape, + const std::vector& output_slices) { + const auto& input_tensor = input->tensor(); + // Slice shape with optional padding. + for (int i = 0; i < num_slices_; ++i) { + Tensor* output_slice = output_slices[i]; + SliceAndMaybePadState r(num_splits_, input_shape, + output_slice_shape, i); + if (r.num_complete_pad_dims_ == Rank || + (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0)) { + // Need to init padding + SetToConstant(output_slice, device); + } + if (r.num_complete_pad_dims_ == Rank) { + // Done + } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { + output_slice->tensor() + .slice(Eigen::DSizes(), + r.non_padded_slice_shape_dsizes_) + .device(device) = input_tensor.slice( + r.slice_indices_, r.non_padded_slice_shape_dsizes_); + } else { + AssignFromInput(output_slice, device, input, r.slice_indices_, + r.output_slice_shape_dsizes_); + } + } + } +}; + +// Shared base class to save code space +template +class XlaNDConcatenator { + public: + static absl::StatusOr> Create( + const std::vector& num_concats, int num_slices, + const std::vector& paddings, bool has_paddings) { + if (num_concats.size() != paddings.size()) { + return absl::InvalidArgumentError( + absl::StrCat("num_concats size ", num_concats.size(), + " mismatch with paddings size ", paddings.size(), ".")); + } + + int concats_cnt = 1; + for (auto concat : num_concats) { + concats_cnt *= concat; + } + + if (num_slices != concats_cnt) { + return absl::InvalidArgumentError(absl::StrCat( + "Expect num_slices ", concats_cnt, " but got ", num_slices)); + } + + return XlaNDConcatenator(num_concats, num_slices, paddings, + has_paddings); + } + absl::Status ComputeInternal( + absl::Span inputs, + const std::function& assign_or_copy_value_fn, + const std::function()>& get_output_fn, + const Device& device) { + const int rank = inputs[0].shape().dims(); + + if (rank < 1 || rank > 8) { + return absl::InvalidArgumentError(absl::StrCat( + "'inputs' tensors must have rank in range (0, 8], but got ", rank, + ".")); + } + + if (num_slices_ == 1 && !has_paddings_) { + // Simple case + return assign_or_copy_value_fn(inputs[0]); + } + + 
TF_ASSIGN_OR_RETURN(Tensor * output, get_output_fn()); + + if (rank == 1) { + MaybeUnpadAndAssign<1>(device, inputs, output); + } else if (rank == 2) { + MaybeUnpadAndAssign<2>(device, inputs, output); + } else if (rank == 3) { + MaybeUnpadAndAssign<3>(device, inputs, output); + } else if (rank == 4) { + MaybeUnpadAndAssign<4>(device, inputs, output); + } else if (rank == 5) { + MaybeUnpadAndAssign<5>(device, inputs, output); + } else if (rank == 6) { + MaybeUnpadAndAssign<6>(device, inputs, output); + } else if (rank == 7) { + MaybeUnpadAndAssign<7>(device, inputs, output); + } else if (rank == 8) { + MaybeUnpadAndAssign<8>(device, inputs, output); + } + return absl::OkStatus(); + } + + private: + template + class MaybeUnpadAndAssignState { + public: + int num_complete_pad_dims_; + int num_partial_pad_dims_; + TensorShape non_padded_slice_shape_; + Eigen::DSizes slice_shape_dsizes_; + Eigen::array, Rank> slice_paddings_; + Eigen::DSizes slice_indices_; + Eigen::DSizes output_slice_shape_dsizes_; + Eigen::DSizes non_padded_slice_shape_dsizes_; + + TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssignState( + absl::Span num_concats, const Tensor& input0, + Tensor* output, int slice_index) { + slice_shape_dsizes_ = input0.shape().AsEigenDSizes(); + slice_indices_ = + GetSliceIndices(num_concats, slice_shape_dsizes_, slice_index); + num_complete_pad_dims_ = 0; + num_partial_pad_dims_ = 0; + // Calculate paddings necessary to strip from slice. + for (int dim = 0; dim < Rank; ++dim) { + const int64_t dim_size = output->shape().dim_size(dim); + int64_t non_padded_dim = 0; + if (slice_indices_[dim] >= dim_size) { + // Complete padding. + slice_indices_[dim] = dim_size; + non_padded_dim = 0; + num_complete_pad_dims_++; + } else if (slice_indices_[dim] + slice_shape_dsizes_[dim] > dim_size) { + // Partial padding. 
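+          // This input slice extends past the output boundary; only its
+          // leading (dim_size - slice_index) elements are kept and the
+          // padded tail is dropped.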
+ non_padded_dim = dim_size - slice_indices_[dim]; + num_partial_pad_dims_++; + } else { + non_padded_dim = slice_shape_dsizes_[dim]; + } + non_padded_slice_shape_.AddDim(non_padded_dim); + } + non_padded_slice_shape_dsizes_ = + non_padded_slice_shape_.AsEigenDSizes(); + } + }; + + std::vector num_concats_; + int num_slices_; + std::vector paddings_; + bool has_paddings_; + + explicit TF_ATTRIBUTE_NOINLINE XlaNDConcatenator( + const std::vector& num_concats, int num_slices, + const std::vector& paddings, bool has_paddings) + : num_concats_(num_concats), + num_slices_(num_slices), + paddings_(paddings), + has_paddings_(has_paddings) {} + + template + void TF_ATTRIBUTE_NOINLINE MaybeUnpadAndAssign( + const Device& device, absl::Span inputs, Tensor* output) { + for (int i = 0; i < num_slices_; ++i) { + MaybeUnpadAndAssignState r(num_concats_, inputs[0], output, i); + if (r.num_complete_pad_dims_ == Rank) { + continue; + } else if (r.num_complete_pad_dims_ > 0 || r.num_partial_pad_dims_ > 0) { + output->tensor() + .slice(r.slice_indices_, r.non_padded_slice_shape_dsizes_) + .device(device) = inputs[i].tensor().slice( + Eigen::DSizes(), + r.non_padded_slice_shape_dsizes_); + } else { + output->tensor() + .slice(r.slice_indices_, r.slice_shape_dsizes_) + .device(device) = inputs[i].tensor(); + } + } + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SHARDING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_layout.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_layout.h new file mode 100644 index 00000000..9f4697c2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_layout.h @@ -0,0 +1,132 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_LAYOUT_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_LAYOUT_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/btree_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/tpu/kernels/sparse_core_layout.pb.h" + +namespace tensorflow::tpu { + +// A class to figure out which tables to stack. +class SparseCoreLayoutStacker { + public: + // Constructor. Arguments: + // num_partitions: How many shards the sparse core shards are concatenated + // into (usually one per TPU chip). + // NOTE: As of Q4 2023, SPMD is not supported by the sparse core python + // libraries so we don't support it here. + // sparse_cores_per_partition: Number of sparsecore per partition + // disable_table_stacking: Should not stack tables. + explicit SparseCoreLayoutStacker(int num_partitions, + bool disable_table_stacking = false, + int sparse_cores_per_partition = 4); + + // Change various limits. You must call these before calling Addtable. 
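+  // A hypothetical call sequence (the table name, group, and limit values
+  // below are illustrative only, not taken from a real configuration):
+  //
+  //   SparseCoreLayoutStacker stacker(/*num_partitions=*/4);
+  //   stacker.SetStackingRowLimit(1 << 20);  // must precede AddTable()
+  //   TF_RETURN_IF_ERROR(stacker.AddTable("table_a", /*table_height=*/1000,
+  //                                       /*table_width=*/16, "group_0",
+  //                                       /*output_samples=*/128));
+  //   TF_ASSIGN_OR_RETURN(auto layouts, stacker.GetLayouts());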
+ void SetActivationMemoryBytesLimit(int64_t activation_mem_bytes_limit) { + CHECK(stacks_by_group_.empty()) << "must call before AddTable"; + activation_mem_bytes_limit_ = activation_mem_bytes_limit; + } + void SetVariableShardBytesLimit(int64_t variable_shard_bytes_limit) { + CHECK(stacks_by_group_.empty()) << "must call before AddTable"; + variable_shard_bytes_limit_ = variable_shard_bytes_limit; + } + void SetStackingEnabled(bool stacking_enabled) { + CHECK(stacks_by_group_.empty()) << "must call before AddTable"; + stacking_enabled_ = stacking_enabled; + } + void SetStackingRowLimit(int64_t row_limit) { + CHECK(stacks_by_group_.empty()) << "must call before AddTable"; + row_limit_ = row_limit; + } + void SetStackingTableLimit(int table_limit) { + CHECK(stacks_by_group_.empty()) << "must call before AddTable"; + table_limit_ = table_limit; + } + + // Add a new table. Arguments: + // table_name: How this table will be referred to. + // table_height: The number of rows. + // table_width: The number of columns in the input layer. For storage, this + // will be rounded up to a multiple of eight, but the padding columns will + // be stripped off when fed into the rest of the model. + // group: An arbitrary identifier that should be derived from the optimizer + // and hyperparameters. Only tables with the same group and rounded + // table_width can be stacked. The actual contents of this field are not + // particularly meaningful except they are used to construct the + // stack_name field in the SparseCoreTableLayout. + // output_samples: How many times a row from this table will have to be + // returned per batch. This is ordinarily the batch size unless we lay out + // several values from the same example in a sequence, or if multiple + // features share the same table. + // + // Be sure you call AddTable in a deterministic order; the details of the + // stacking will depend on the order you call AddTable. + absl::Status AddTable(absl::string_view table_name, int64_t table_height, + int64_t table_width, absl::string_view group, + int64_t output_samples); + + // Get the information about each table out. + absl::StatusOr GetLayouts(); + + private: + struct TableStack { + // A name we give the stack while we're constructing it. The name will be + // overridden later to be equal to the names of the tables. + std::string temporary_name; + int64_t padded_width = 0; + int64_t unsharded_height = 0; + int64_t total_activation_mem_bytes = 0; + int64_t total_variable_shard_bytes = 0; + + // While we're filling out this structure, we can't fill out all the fields + // in the SparseCoreTableLayout; we fill out as many of them as we can. + std::vector incomplete_tables; + }; + + const int num_partitions_; + const int sparse_cores_per_partition_; + const int num_sparse_cores_; + + bool stacking_enabled_ = true; + int64_t activation_mem_bytes_limit_ = 0; + int64_t variable_shard_bytes_limit_ = 0; + // Sparse core ops use signed int for row numbers so we had better not stack + // beyond this limit. + int64_t row_limit_ = (1LL << 31) - 1; + + // The maximum number of tables in any stack. + int table_limit_ = std::numeric_limits::max(); + + // All the stacks that we currently know about. Note that we use a btree_map + // rather than a flat_hash_map so the resulting order is deterministic as long + // as we are called in a deterministic order. Key is (padded_width, group). 
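+  // For example (group names purely illustrative), (16, "adam_g0") and
+  // (24, "adam_g0") are distinct keys, so tables with different padded
+  // widths are never stacked together even within the same group.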
+ absl::btree_map, std::vector> + stacks_by_group_; +}; + +} // namespace tensorflow::tpu + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_LAYOUT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h new file mode 100644 index 00000000..8667d49a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_STATS_HANDLER_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_STATS_HANDLER_H_ + +#include +#include + +enum class StatsType { + NUM_MINIBATCHES_PER_SC, + MAX_IDS_PER_PARTITION, + MAX_UNIQUE_IDS_PER_PARTITION, + IDS_PER_PARTITION, + UNIQUE_IDS_PER_PARTITION, + DROPPED_ID_COUNT, +}; + +class SparseCoreOpsStatsHandler { + public: + virtual ~SparseCoreOpsStatsHandler() = default; + virtual void Record( + StatsType type, int64_t value, std::string device_name, + std::string table_name) { /* Default implementation does nothing */ + } +}; + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_STATS_HANDLER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_ops_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_ops_utils.h new file mode 100644 index 00000000..dc9b028e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_ops_utils.h @@ -0,0 +1,75 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_UTILS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Pad value used for SparseCore mini batching logic. 
+const int32_t kXlaPadValue = std::numeric_limits::max(); + +std::vector ConvertBinarySplitsToBucketSplits(int64 split, + int max_division_level); + +int64 ConvertBucketSplitsToBinarySplits(std::vector bucket_splits, + int max_division_level); + +absl::Status ValidateInputCombiner(const std::string& combiner); + +std::function GetCombinerScaleContributionFunction( + absl::string_view combiner); + +std::function GetCombinerScaleTransformFunction( + absl::string_view combiner); + +// Stacks tables, so long as table have the same 'group' index. We assume that +// all tables with a given group index have the same width. Returns a list of +// list of table names, in alphabetical order. +std::vector> GetTableStacks( + const std::vector& table_height, + const std::vector& table_width, + const std::vector& table_num_samples, + const std::vector& table_group, + const std::vector& table_names, int64_t num_tpu_chips); + +int GetMinibatchMaxDivisionLevel(); + +bool GetDisableTableStacking(); + +int64_t GetXlaSparseCoreStackingMemLimit(); + +int64_t GetXlaSparseCoreStackingTableShardLimit(); + +absl::Status GetMaxIdsAndUniquesExternal(const std::string& program_key, + const std::string& table_name, + int64_t num_samples_per_sparse_core, + int64_t feature_width, + int64_t* max_ids_per_partition, + int64_t* max_unique_ids_per_partition); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_OPS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h new file mode 100644 index 00000000..96b39458 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h @@ -0,0 +1,261 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/tpu/kernels/sparse_core_ops_stats_handler.h" + +namespace tensorflow { + +// Struct to describe an embedding lookup input data. +struct EmbeddingLookupInput { + // Which replica it belongs. + int32 replica_id; + // Token id. + int32 token_id; + // Sample id. + int32 sample_id; + // Gain. 
+ float gain; + + EmbeddingLookupInput(int32 replica_id, int32 token_id, int32 sample_id, + float gain) + : replica_id(replica_id), + token_id(token_id), + sample_id(sample_id), + gain(gain) {} +}; + +absl::Status ValidateInputs(const Tensor& indices_or_row_splits, + const Tensor& values, const Tensor& weights, + int sample_count); + +// Compute the row id list before padding. +absl::Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, + int32 total_id_count, + int32 sample_count, + int32* row_ids_before_padding); + +class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel { + public: + explicit GetMinibatchesInCsrWithPhysicalReplicaOp(OpKernelConstruction* ctx); + ~GetMinibatchesInCsrWithPhysicalReplicaOp() override = default; + GetMinibatchesInCsrWithPhysicalReplicaOp( + const GetMinibatchesInCsrWithPhysicalReplicaOp&) = delete; + GetMinibatchesInCsrWithPhysicalReplicaOp& operator=( + const GetMinibatchesInCsrWithPhysicalReplicaOp&) = delete; + + void Compute(OpKernelContext* ctx) override; + + protected: + int sample_count_ = 1; + int feature_width_ = 1; + int64_t num_sc_per_chip_; + std::string table_name_; + std::unique_ptr sparse_core_ops_stats_handler_; + + bool allow_id_dropping_for_minibatching_ = false; + + private: + int num_replica_ = 1; + int max_minibatches_per_sc_ = 1; + int max_ids_per_chip_per_sample_ = 1; + int table_vocab_size_ = 1; + std::string device_name_; +}; + +class GetMinibatchSplitsWithPhysicalReplicaOp : public OpKernel { + public: + explicit GetMinibatchSplitsWithPhysicalReplicaOp(OpKernelConstruction* ctx); + ~GetMinibatchSplitsWithPhysicalReplicaOp() override = default; + GetMinibatchSplitsWithPhysicalReplicaOp( + const GetMinibatchSplitsWithPhysicalReplicaOp&) = delete; + GetMinibatchSplitsWithPhysicalReplicaOp& operator=( + const GetMinibatchSplitsWithPhysicalReplicaOp&) = delete; + + void Compute(OpKernelContext* ctx) override; + + protected: + virtual void CalculateHeadroom(int32 this_max_ids, int32 this_max_uniques, + tstring program_key, + int64_t max_ids_per_partition, + int64_t max_unique_ids_per_partition, + int32_t dropped_id_count) {} + virtual inline int32_t CalculateBucketIdWithHashing(int32_t col_id, + int32_t num_buckets) { + // TODO(pineapplejuice233): Add a proper hashing function here. 
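The bucketing hook above is left as a TODO, and the default implementation that follows simply takes col_id % num_buckets. A hypothetical alternative (not TensorFlow code) would mix the bits of the column id with a published public-domain 32-bit integer hash before reducing it to a bucket, so consecutive ids do not land in consecutive buckets:

#include <cstdint>

// Hypothetical sketch only. Assumes num_buckets > 0.
inline int32_t ExampleBucketIdWithMixing(int32_t col_id, int32_t num_buckets) {
  uint32_t x = static_cast<uint32_t>(col_id);
  x ^= x >> 16;
  x *= 0x7feb352dU;  // Mixing constants from a public-domain 32-bit hash.
  x ^= x >> 15;
  x *= 0x846ca68bU;
  x ^= x >> 16;
  return static_cast<int32_t>(x % static_cast<uint32_t>(num_buckets));
}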
+ return col_id % num_buckets; + } + + std::string device_name_; + std::string table_name_; + std::unique_ptr sparse_core_ops_stats_handler_; + bool allow_id_dropping_for_minibatching_ = false; + bool allow_id_shuffling_for_minibatching_ = false; + + private: + int num_replica_ = 1; + int sample_count_ = 1; + int table_vocab_size_ = 1; + int feature_width_ = 1; + int64_t num_sc_per_chip_; +}; + +class StoreMinibatchStatisticsInFdoOp : public OpKernel { + public: + explicit StoreMinibatchStatisticsInFdoOp(OpKernelConstruction* ctx); + ~StoreMinibatchStatisticsInFdoOp() override = default; + StoreMinibatchStatisticsInFdoOp(const StoreMinibatchStatisticsInFdoOp&) = + delete; + StoreMinibatchStatisticsInFdoOp& operator=( + const StoreMinibatchStatisticsInFdoOp&) = delete; + + void Compute(OpKernelContext* ctx) override; + + protected: + virtual void CalculateHeadroom(int32 this_max_ids, int32 this_max_uniques, + tstring program_key, + int64_t max_ids_per_partition, + int64_t max_unique_ids_per_partition) {} + std::string device_name_; + std::string table_name_; + + private: + int num_replica_ = 1; + int sample_count_ = 1; + int feature_width_ = 1; + int64_t num_sc_per_chip_; +}; + +// TODO(pineapplejuice233): Unify this op with ConvertToListOfCooTensorsV2Op. +class ConvertToListOfSparseCoreCooTensorsOp : public OpKernel { + public: + explicit ConvertToListOfSparseCoreCooTensorsOp(OpKernelConstruction* ctx); + ~ConvertToListOfSparseCoreCooTensorsOp() override = default; + ConvertToListOfSparseCoreCooTensorsOp( + const ConvertToListOfSparseCoreCooTensorsOp&) = delete; + ConvertToListOfSparseCoreCooTensorsOp& operator=( + const ConvertToListOfSparseCoreCooTensorsOp&) = delete; + + void Compute(OpKernelContext* ctx) override; + + private: + void WriteToOutputTensor(int32* row_ids, int32* col_ids, float* gains, + int32* row_ids_tensor_ptr, int32* col_ids_tensor_ptr, + float* gains_tensor_ptr, int32_t begin_index, + int32_t end_index, int32_t sc_id, + std::optional> gains_rescale); + int sample_count_; + int num_sc_per_chip_; + int per_sc_sample_count_; + int row_offset_; + int col_offset_; + int col_shift_; + int num_sc_shards_; + int stacked_table_sample_count_; + int num_sc_shards_bit_mod_; + int num_sc_shards_bit_mod_inv_; + int per_sc_row_offset_; + int per_sc_stacked_table_sample_count_; + std::string combiner_; +}; + +class SortListOfSparseCoreCooTensorsOp : public OpKernel { + public: + explicit SortListOfSparseCoreCooTensorsOp(OpKernelConstruction* ctx); + ~SortListOfSparseCoreCooTensorsOp() override = default; + SortListOfSparseCoreCooTensorsOp(const SortListOfSparseCoreCooTensorsOp&) = + delete; + SortListOfSparseCoreCooTensorsOp& operator=( + const SortListOfSparseCoreCooTensorsOp&) = delete; + + void Compute(OpKernelContext* ctx) override; + + private: + int32_t num_sc_per_chip_; + int32_t feature_width_; + int32_t num_replica_; + int32_t num_physical_replica_; + int32_t num_physical_replica_bit_; + int32_t max_ids_per_sparse_core_; + int32_t max_unique_ids_per_sparse_core_; + std::string table_name_; + std::vector sample_count_list_; + std::vector col_offset_list_; + std::map> col_offset_to_feature_id_; +}; + +class ConvertToSparseCoreCsrWrappedCooTensorOp : public OpKernel { + public: + explicit ConvertToSparseCoreCsrWrappedCooTensorOp(OpKernelConstruction* ctx); + ~ConvertToSparseCoreCsrWrappedCooTensorOp() override = default; + ConvertToSparseCoreCsrWrappedCooTensorOp( + const ConvertToSparseCoreCsrWrappedCooTensorOp&) = delete; + ConvertToSparseCoreCsrWrappedCooTensorOp& operator=( 
+ const ConvertToSparseCoreCsrWrappedCooTensorOp&) = delete; + + void Compute(OpKernelContext* ctx) override; + + private: + int32_t num_sc_per_chip_; + int32_t table_vocab_size_; + int32_t feature_width_; + int32_t num_replica_; + int32_t sample_count_per_sc_; + int32_t max_minibatches_per_sc_; + int32_t max_ids_per_chip_per_sample_; + bool allow_id_dropping_; + std::string table_name_; + std::string device_name_; +}; + +class GetStatsFromListOfSparseCoreCooTensorsOp : public OpKernel { + public: + explicit GetStatsFromListOfSparseCoreCooTensorsOp(OpKernelConstruction* ctx); + ~GetStatsFromListOfSparseCoreCooTensorsOp() override = default; + GetStatsFromListOfSparseCoreCooTensorsOp( + const GetStatsFromListOfSparseCoreCooTensorsOp&) = delete; + GetStatsFromListOfSparseCoreCooTensorsOp& operator=( + const GetStatsFromListOfSparseCoreCooTensorsOp&) = delete; + + void Compute(OpKernelContext* ctx) override; + + private: + int32_t num_sc_per_chip_; + int32_t feature_width_; + int32_t num_replica_; + int32_t num_physical_replica_; + int32_t num_physical_replica_bit_; + std::string table_name_; + std::vector sample_count_list_; + std::vector col_offset_list_; + std::map> col_offset_to_feature_id_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_PREPROCESS_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_xla_flags_defaults.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_xla_flags_defaults.h new file mode 100644 index 00000000..42a25836 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_xla_flags_defaults.h @@ -0,0 +1,30 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_FLAGS_DEFAULTS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_FLAGS_DEFAULTS_H_ + +#include + +namespace tensorflow { + +constexpr int kDefaultSparseCoreMinibatchMaxDivisionLevel = 6; +constexpr bool kDefaultDisableTableStacking = false; +constexpr int64_t kDefaultXlaSparseCoreStackingMemLimit = 2097152; +constexpr int64_t kDefaultXlaSparseCoreStackingTableShardLimit = 2147483648; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_FLAGS_DEFAULTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_xla_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_xla_ops.h new file mode 100644 index 00000000..71995cb9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/sparse_core_xla_ops.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_OPS_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/macros.h" + +// RAII helper to set the frontend attribute for the target chip to the SC. +// Automatically restores the frontend attributes on exit. +class UseSparseCoreFrontendAttributes { + public: + explicit UseSparseCoreFrontendAttributes(xla::XlaBuilder* builder) + : builder_(builder), + original_frontend_attributes_(builder->frontend_attributes()) { + xla::FrontendAttributes sc_attributes = original_frontend_attributes_; + (*sc_attributes.mutable_map())["_xla_compute_type"] = "sparse"; + builder_->SetFrontendAttributes(sc_attributes); + } + + ~UseSparseCoreFrontendAttributes() { + builder_->SetFrontendAttributes(original_frontend_attributes_); + } + + private: + xla::XlaBuilder* builder_; + const xla::FrontendAttributes original_frontend_attributes_; + + UseSparseCoreFrontendAttributes(const UseSparseCoreFrontendAttributes&) = + delete; + void operator=(const UseSparseCoreFrontendAttributes&) = delete; +}; + +#endif // TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h new file mode 100644 index 00000000..52d2d8b3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_ + +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" + +namespace tensorflow { +namespace tpu { + +// Cache entry to hold a `TpuProgramGroupInterface` object that can be used to +// fetch a TPU program for a given TPU core index. +class TpuCompilationCacheEntry { + public: + explicit TpuCompilationCacheEntry( + const TpuProgramGroupInterface* tpu_program_group, int core_index) + : tpu_program_group_(tpu_program_group), core_index_(core_index) {} + + // Constructor for an empty entry. 
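As a usage note for the UseSparseCoreFrontendAttributes guard declared in sparse_core_xla_ops.h above: it is a scoped RAII helper, so SparseCore-targeted ops are emitted inside a block and the builder's previous frontend attributes come back automatically. A minimal sketch; the builder name and the constant op are placeholders, not part of this change:

#include "xla/hlo/builder/xla_builder.h"

void BuildSparseCoreExample() {
  xla::XlaBuilder builder("sparse_core_example");
  {
    UseSparseCoreFrontendAttributes sc_scope(&builder);
    // Ops emitted in this scope carry "_xla_compute_type" = "sparse".
    xla::XlaOp one = xla::ConstantR0<float>(&builder, 1.0f);
    (void)one;
  }
  // The builder's previous frontend attributes are restored here.
}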
+ TpuCompilationCacheEntry() : tpu_program_group_(nullptr), core_index_(-1) {} + + const TpuProgramGroupInterface* tpu_program_group() const { + return tpu_program_group_; + } + + int core_index() const { return core_index_; } + + private: + const TpuProgramGroupInterface* tpu_program_group_; + int core_index_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h new file mode 100644 index 00000000..c85376a7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_UNLOADER_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_UNLOADER_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/macros.h" + +namespace tensorflow { +namespace tpu { + +class TpuCompilationCacheEntryUnloader : public ResourceBase { + public: + explicit TpuCompilationCacheEntryUnloader(TpuCompilationCacheInterface* cache) + : cache_(cache) { + // Hold a reference to the cache until the unloader is destroyed. + cache_->Ref(); + VLOG(1) << "Will unload compilation cache entries when session closes."; + } + + ~TpuCompilationCacheEntryUnloader() override { + absl::MutexLock lock(&mu_); + for (int64_t uid : cache_entry_uids_) { + absl::Status s = cache_->MarkEntryForEviction(uid); + if (!s.ok()) { + LOG(WARNING) << "MarkEntryForEviction in " + "~CompilationCacheEntryUnloader fails with error " + << s; + } + } + // Release our reference to the cache. + cache_->Unref(); + } + + // Add cache entry uid to be unloaded in destructor. + void AddCacheEntryUid(int64_t uid) { + absl::MutexLock lock(&mu_); + cache_entry_uids_.insert(uid); + } + + std::string DebugString() const override { + return "CompilationCacheEntryUnloader"; + } + + private: + TpuCompilationCacheEntryUnloader(const TpuCompilationCacheEntryUnloader&) = + delete; + void operator=(const TpuCompilationCacheEntryUnloader&) = delete; + mutable absl::Mutex mu_; + TpuCompilationCacheInterface* cache_; // Not owned. 
+ absl::flat_hash_set cache_entry_uids_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_ENTRY_UNLOADER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h new file mode 100644 index 00000000..d415feae --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_ + +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" + +namespace tensorflow { +namespace tpu { + +class TpuCompilationCacheExternal : public TpuCompilationCacheInterface { + public: + explicit TpuCompilationCacheExternal(int64_t max_cache_size) + : TpuCompilationCacheInterface(max_cache_size) {} + + std::string DebugString() const override { + return "TpuCompilationCacheExternal"; + } + + private: + // Creates a new entry by running initialize_programs and places it in the + // cache to be looked up by key. The new entry is in the 'marked for eviction' + // state (not present in entries_by_last_use_) and the caller is expected to + // call LookupEntryMarkedForEviction after InitializeEntry. + // + // **InitializeEntry releases mu_ during the call to initialize_programs.** + CompiledSubgraph* InitializeEntry( + const std::string& key, + const std::function& + initialize_program, + const TpuCompilationCacheKey& subgraph_key) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(TpuCompilationCacheInterface::mu_) override; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_EXTERNAL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h new file mode 100644 index 00000000..4710f916 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
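A lifecycle sketch for the TpuCompilationCacheEntryUnloader above, assuming `cache` is a live TpuCompilationCacheInterface* and `uid` identifies a compiled subgraph; resource-manager registration and error handling are omitted:

void TrackForUnload(tensorflow::tpu::TpuCompilationCacheInterface* cache,
                    int64_t uid) {
  // The constructor takes its own reference on the cache.
  auto* unloader =
      new tensorflow::tpu::TpuCompilationCacheEntryUnloader(cache);
  unloader->AddCacheEntryUid(uid);  // Record each entry compiled in the session.
  // When the owning session closes, dropping the last reference runs the
  // destructor, which calls MarkEntryForEviction(uid) and releases the cache.
  unloader->Unref();
}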
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_ + +#include + +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" + +namespace tensorflow { +namespace tpu { + +std::function GetCompilationCacheCreateFn(); + +void SetCompilationCacheCreateFn( + std::function fn); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_FACTORY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h new file mode 100644 index 00000000..b7c6b7c3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h @@ -0,0 +1,236 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Copied from auto-generated gRPC code in order to enable using grpc_call.h +// for raw message handling. 
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_ + +#include +#include + +#include "grpcpp/generic/async_generic_service.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/client_callback.h" +#include "grpcpp/impl/codegen/client_context.h" +#include "grpcpp/impl/codegen/completion_queue.h" +#include "grpcpp/impl/codegen/method_handler.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/server_callback.h" +#include "grpcpp/impl/codegen/server_context.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" + +#if defined(LIBTPU_ON_GCE) +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#else +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" // copybara" +#endif +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" + +namespace tensorflow { +namespace tpu { +namespace grpc { +class TpuCompilationCacheService final { + public: + using RequestType = ::tensorflow::tpu::GetTpuProgramRequest; +#if defined(LIBTPU_ON_GCE) + using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal; +#else + using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse; +#endif + + // N.B. This must be synchronized with the method order in + // tpu_compilation_cache.proto. + enum class MethodId { kGetTpuProgram = 0 }; + + static constexpr char const* service_full_name() { +#if defined(LIBTPU_ON_GCE) + return "tensorflow.tpu.TpuCompilationCacheServiceExternal"; +#else + return "tensorflow.tpu.TpuCompilationCacheService"; +#endif + } + class StubInterface { + public: + virtual ~StubInterface() = default; + // This method requests the cached proto that the TPU execute op has + // been instructed to execute. 
+ virtual ::grpc::Status GetTpuProgram(::grpc::ClientContext* context, + const RequestType& request, + ResponseType* response) = 0; + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface> + AsyncGetTpuProgram(::grpc::ClientContext* context, + const RequestType& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncResponseReaderInterface>( + AsyncGetTpuProgramRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReaderInterface> + PrepareAsyncGetTpuProgram(::grpc::ClientContext* context, + const RequestType& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr< + ::grpc::ClientAsyncResponseReaderInterface>( + PrepareAsyncGetTpuProgramRaw(context, request, cq)); + } + + private: + virtual ::grpc::ClientAsyncResponseReaderInterface* + AsyncGetTpuProgramRaw(::grpc::ClientContext* context, + const RequestType& request, + ::grpc::CompletionQueue* cq) = 0; + virtual ::grpc::ClientAsyncResponseReaderInterface* + PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context, + const RequestType& request, + ::grpc::CompletionQueue* cq) = 0; + }; + class Stub final : public StubInterface { + public: + explicit Stub(const std::shared_ptr<::grpc::ChannelInterface>& channel); + ::grpc::Status GetTpuProgram(::grpc::ClientContext* context, + const RequestType& request, + ResponseType* response) override; + std::unique_ptr<::grpc::ClientAsyncResponseReader> + AsyncGetTpuProgram(::grpc::ClientContext* context, + const RequestType& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader>( + AsyncGetTpuProgramRaw(context, request, cq)); + } + std::unique_ptr<::grpc::ClientAsyncResponseReader> + PrepareAsyncGetTpuProgram(::grpc::ClientContext* context, + const RequestType& request, + ::grpc::CompletionQueue* cq) { + return std::unique_ptr<::grpc::ClientAsyncResponseReader>( + PrepareAsyncGetTpuProgramRaw(context, request, cq)); + } + + private: + std::shared_ptr<::grpc::ChannelInterface> channel_; + ::grpc::ClientAsyncResponseReader* AsyncGetTpuProgramRaw( + ::grpc::ClientContext* context, const RequestType& request, + ::grpc::CompletionQueue* cq) override; + ::grpc::ClientAsyncResponseReader* + PrepareAsyncGetTpuProgramRaw(::grpc::ClientContext* context, + const RequestType& request, + ::grpc::CompletionQueue* cq) override; + const ::grpc::internal::RpcMethod rpcmethod_get_tpu_program_; + }; + static std::unique_ptr NewStub( + const std::shared_ptr<::grpc::ChannelInterface>& channel, + const ::grpc::StubOptions& options = ::grpc::StubOptions()); + + class Service : public ::grpc::Service { + public: + Service(); + ~Service() override; + // This method requests the cached proto that the TPU execute op has + // been instructed to execute. 
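For context, the hand-copied stub above is driven from the client side like any generated gRPC stub. A hedged sketch of a blocking lookup follows; the channel target is a placeholder, insecure credentials are for illustration only (a CreateChannelCredentials() helper is declared later in this diff), and the request fields are not filled in because they live in tpu_compilation_cache_common.proto rather than this header:

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"

void LookupProgramOverRpc() {
  using tensorflow::tpu::grpc::TpuCompilationCacheService;

  auto channel = ::grpc::CreateChannel("dns:///tpu-compile-master:8470",
                                       ::grpc::InsecureChannelCredentials());
  auto stub = TpuCompilationCacheService::NewStub(channel);

  ::grpc::ClientContext context;
  TpuCompilationCacheService::RequestType request;   // GetTpuProgramRequest
  TpuCompilationCacheService::ResponseType response;
  // Populate the request's key / fetch-target fields as defined in
  // tpu_compilation_cache_common.proto (not shown in this header).
  ::grpc::Status status = stub->GetTpuProgram(&context, request, &response);
  if (!status.ok()) {
    // Handle or propagate the RPC error.
  }
}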
+ virtual ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, + const RequestType* request, + ResponseType* response); + }; + template + class WithAsyncMethod_GetTpuProgram : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service* service) {} + + public: + WithAsyncMethod_GetTpuProgram() { ::grpc::Service::MarkMethodAsync(0); } + ~WithAsyncMethod_GetTpuProgram() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, + const RequestType* request, + ResponseType* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + void RequestGetTpuProgram( + ::grpc::ServerContext* context, RequestType* request, + ::grpc::ServerAsyncResponseWriter* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, + new_call_cq, notification_cq, tag); + } + + // Make RequestAsyncUnary accessible to grpc_call.h + using ::grpc::Service::RequestAsyncUnary; + }; + typedef WithAsyncMethod_GetTpuProgram AsyncService; + template + class WithGenericMethod_GetTpuProgram : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service* service) {} + + public: + WithGenericMethod_GetTpuProgram() { ::grpc::Service::MarkMethodGeneric(0); } + ~WithGenericMethod_GetTpuProgram() override { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, + const RequestType* request, + ResponseType* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; + template + class WithStreamedUnaryMethod_GetTpuProgram : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service* service) {} + + public: + WithStreamedUnaryMethod_GetTpuProgram() { + ::grpc::Service::MarkMethodStreamed( + 0, + new ::grpc::internal::StreamedUnaryHandler( + std::bind(&WithStreamedUnaryMethod_GetTpuProgram< + BaseClass>::StreamedGetTpuProgram, + this, std::placeholders::_1, std::placeholders::_2))); + } + ~WithStreamedUnaryMethod_GetTpuProgram() override { + BaseClassMustBeDerivedFromService(this); + } + // disable regular version of this method + ::grpc::Status GetTpuProgram(::grpc::ServerContext* context, + const RequestType* request, + ResponseType* response) override { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + // replace default version of method with streamed unary + virtual ::grpc::Status StreamedGetTpuProgram( + ::grpc::ServerContext* context, + ::grpc::ServerUnaryStreamer* + server_unary_streamer) = 0; + }; + typedef WithStreamedUnaryMethod_GetTpuProgram StreamedUnaryService; + typedef Service SplitStreamedService; + typedef WithStreamedUnaryMethod_GetTpuProgram StreamedService; +}; +} // namespace grpc +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_GRPC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h new file mode 100644 index 00000000..eca894ed --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h @@ -0,0 +1,335 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "xla/util.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/tpu/kernels/compiled_subgraph.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h" +#include "tensorflow/core/tpu/kernels/trace_util.h" + +namespace tensorflow { +namespace tpu { + +// Base class that holds references to compiled protos so that the protos are +// not garbage-collected before being used by execute ops. Use +// TpuCompilationCache::MakePerStepRefHolder to create an instance of a concrete +// ref holder object. +class CompilationRefHolder : public ResourceBase { + public: + ~CompilationRefHolder() override = default; +}; + +// Wrapper for a cache entry returned by all the TpuCompilationCacheInterface +// `Lookup` methods, and ensures the underlying proto is not garbage-collected +// until the client discards the ptr. +class CompilationCacheEntryRef { + public: + CompilationCacheEntryRef(); + CompilationCacheEntryRef(TpuCompilationCacheInterface* parent, + CompiledSubgraph* entry, int index); + + virtual ~CompilationCacheEntryRef(); + + // Returns a TpuCompilationCacheEntry that should not be used beyond the + // lifetime of the CompilationCacheEntryRef. + virtual TpuCompilationCacheEntry get(); + + // Mutates this ref to point to the entry's subentry (for + // sharding/unsharding) or main entry (unchanged) as specified by + // fetch_target. The refcount is kept unchanged, since we only track the + // refcount of the main entry. The entry ref needs to point to the main + // entry before this call. + // + // If the requested subentry does not exist, the ref will point to a nullptr + // entry, and the original entry will be unref'ed. + virtual absl::Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target); + + protected: + TpuCompilationCacheInterface* parent_; // Not owned. + // A reference to entry_ is acquired in the constructor and released via + // parent->DiscardEntryRefs in the destructor. 
+ CompiledSubgraph* entry_; + // The index of the program in entry_ that is returned by the get method. + int index_; +}; + +class TpuCompilationCacheInterface : public ResourceBase { + public: + explicit TpuCompilationCacheInterface(int64_t max_cache_size); + ~TpuCompilationCacheInterface() override; + + // Ensures there is an entry for key present in the cache. By the time + // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache + // for key, and that entry will remain valid at least until + // per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a + // key that is not in the cache will evaluate compile_function to compute the + // value to use in the entry. Subsequent calls with the same key will block + // until compile_function completes. Other cache reads and inserts may proceed + // on other threads while compile_function is executing. If + // per_step_ref_holder is nullptr then the caller is responsible for calling + // Release(subgraph_key) to manually discard its reference to the compiled + // program, once the caller will not look up the compiled program again. + // + // compile_function should compile the subgraph represented by key and fill in + // one TPUExecutableProto per model-parallel core into its passed argument. It + // should return OK if and only if compilation succeeds. The executable proto + // vector will be discarded on non-OK status. + absl::Status CompileIfKeyAbsent( + const TpuCompilationCacheKey& subgraph_key, + const SessionMetadata* session_metadata, + CompilationRefHolder* per_step_ref_holder, int64_t* uid, + std::vector* proto_key, + std::vector* sharding_key, + std::vector* may_modify_variables, + absl::Span* hlo_metadatas, + const std::function& + compile_function); + + // Differences between MarkEntryForEviction and Release: + // There are two modes of managing cache entries: + // 1) LRU eviction + pinning; 2) manual. + // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent. + // Otherwise it is manual mode (mainly used by XRT). + // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache + // entries when callers know that they do not need them anymore. + // Release should only be used in mode 2) to explicitly remove an entry. + + // Mark the entry indexed by `subgraph_uid` for eviction. This should only be + // called if per_step_ref_holder was NOT nullptr in the corresponding call to + // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64 + // subgraph_uid). + absl::Status MarkEntryForEviction(int64_t subgraph_uid); + + // Manually discards a reference to the compiled subgraph. This should only be + // called if per_step_ref_holder was nullptr in the corresponding call to + // CompileIfKeyAbsent(subgraph_key, ...). + absl::Status Release(int64_t subgraph_uid); + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by key. On success a pointer to an EntryRef + // holding the program is returned in entry. + absl::Status Lookup(const std::string& proto_key, + std::unique_ptr* entry); + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by uid. On success a pointer to an EntryRef + // holding the program is returned in entry. + absl::Status Lookup(int64_t uid, int proto_index, + std::unique_ptr* entry); + + // Looks up the subgraph represented by uid, and returns the vector of keys, + // one per core, corresponding to that subgraph. 
+ absl::Status GetKeysFromUid(int64_t uid, std::vector* keys); + + // Makes a reference holder for this cache, that can be stored in the per-step + // resource manager and will ensure that compiled entries persist until the + // end of a step. + CompilationRefHolder* MakePerStepRefHolder(); + + // Convenience method called by ~RefHolder without mu_ held. Calls + // DiscardEntryRef on every element of entries. + void DiscardEntryRefs(absl::Span entries); + + std::string DebugString() const override { return "TpuCompilationCacheBase"; } + + protected: + std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) { + if (!key.has_guaranteed_const) { + return key.prefix; + } + return absl::StrCat(key.prefix, "|", key.session_handle, "|", + key.guaranteed_const_fingerprint()); + } + + // Private implementation of the generic CompilationRefHolder that knows about + // CompiledSubgraph entries. + class RefHolder : public CompilationRefHolder { + public: + explicit RefHolder(TpuCompilationCacheInterface* parent); + ~RefHolder() override; + + // Adds entry to the list of entries that will be released when the + // RefHolder is destroyed. Each entry is released via a call to + // parent_->DiscardEntryRefs. + void AddRef(CompiledSubgraph* entry); + + std::string DebugString() const override; + + private: + TpuCompilationCacheInterface* parent_; // Not owned. + std::vector entries_; + }; + + // The bulk of implementation of CompileIfKeyAbsent() with the exception + // of unloading programs that corresponds to possibly removed cache + // entries. The split helps to manage locking since we prefer to perform + // unloading without holding extra locks. + absl::Status CompileIfKeyAbsentHelper( + const TpuCompilationCacheKey& subgraph_key, + const SessionMetadata* session_metadata, + CompilationRefHolder* per_step_ref_holder, int64_t* uid, + std::vector* proto_key, + std::vector* sharding_key, + std::vector* may_modify_variables, + std::vector* removed_entries, + absl::Span* hlo_metadatas, + const std::function& + compile_function); + + // This is called by the cache when entry is marked for eviction; by + // a RefHolder (via DiscardEntryRefs) when a step completes; and by + // an EntryRefImpl when it is destroyed. Releases one reference to entry + // if more than 1 remains. If only one reference is left, the entry is removed + // from cache_ and is returned to the caller; which must eventually call + // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef + // to avoid holding the lock during program unloading. + ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef( + CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Marks the oldest unmarked entry for eviction. Requires that there is at + // least one such entry. In case the evicted entry had only 1 reference it + // is removed from the cache and returned to the caller which must eventually + // call UnloadAndDestroy. + ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Updates datastructures to indicate that entry, which had been marked for + // eviction, has been looked up. This is called by CompileIfKeyAbsent when an + // entry is newly created, or an entry that has been marked for eviction but + // not yet evicted is looked up. + // + // First the entry is unmarked for eviction, i.e. 
the cache gains a reference + // to entry, entry's last_use field is set to be the most recent value of + // use_counter_ and entries_by_last_use_ is updated accordingly. + // + // Next, the size of the cache is examined to see if any other entries need to + // be marked for eviction now that entry has been unmarked. While the total + // size of unmarked cached entries is greater than max_cache_size_, entries + // are marked for eviction in LRU order. The most recently used entry is never + // marked for eviction, so an entry larger than the max cache size will remain + // in the cache until it is replaced by something else. In case some entries + // actually were removed from the cache, they are a returned to the caller via + // removed_entries. The caller must eventually delete them by calling + // UnloadAndDestroy. + void LookupEntryMarkedForEviction( + CompiledSubgraph* entry, std::vector* removed_entries) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Removes the entry with given key from cache. + size_t RemoveEntry(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Inserts the given key and entry to cache. + void InsertEntry(const std::string& key, CompiledSubgraph* entry) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the cache key matching given subgraph_key. + std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Creates a new entry by running initialize_programs and places it in the + // cache to be looked up by key. The new entry is in the 'marked for eviction' + // state (not present in entries_by_last_use_) and the caller is expected to + // call LookupEntryMarkedForEviction after InitializeEntry. + // + // **InitializeEntry releases mu_ during the call to initialize_programs.** + virtual CompiledSubgraph* InitializeEntry( + const std::string& key, + const std::function& + initialize_programs, + const TpuCompilationCacheKey& subgraph_key) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; + + // Unloads the program associated with the entry from all local devices + // and deletes the entry itself. It is assumed no one else has a reference + // to it and all related keys had already been removed from the cache. + // The call can perform device IO so no locks should be held while calling it. + void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_); + + // The maximum size of entries that are stored in the cache before entries are + // marked for eviction. + const int64_t max_cache_size_; + // Mutex to protect access to shared resources under multi-threading + // environment. + absl::Mutex mu_; + // The total size of entries that are stored and not marked for eviction. + int64_t cache_size_ ABSL_GUARDED_BY(mu_) = 0; + // The total size of entries that are marked for eviction. + int64_t marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0; + // The value to assign to the last_use field of the next entry that is looked + // up. + int64_t use_counter_ ABSL_GUARDED_BY(mu_) = 0; + // session_key_map_ and fingerprint_key_map_ are used for looking up the + // cache_ key matching a given subgraph key. When doing a lookup, check + // session_key_map_ first to avoid unnecessay fingerprint computation. + // Map from key prefix + session_handle to a cache_ key. + absl::node_hash_map session_key_map_ + ABSL_GUARDED_BY(mu_); + // Map from key prefix + fingerprint to a cache_ key. 
+ absl::node_hash_map fingerprint_key_map_ + ABSL_GUARDED_BY(mu_); + // All the subgraph entries that can be looked up in the cache. An entry is + // marked for eviction iff it is present in cache_ and not in + // entries_by_last_use_. + std::unordered_map cache_ + ABSL_GUARDED_BY(mu_); + // All the subgraph entries that can be looked up in the cache, indexed by + // uid. + absl::flat_hash_map entries_by_uid_ + ABSL_GUARDED_BY(mu_); + // All the protos that can be looked up in the cache, indexed by proto + // key. The value of the map is a subgraph and the index of the proto compiled + // for that subgraph. + std::unordered_map> + entries_by_proto_key_ ABSL_GUARDED_BY(mu_); + // Map from last_use to entry, used to mark entries for eviction in LRU + // order. If an entry's last_use counter is not present as a key in + // entries_by_last_use_ then the entry has been marked for eviction. + std::map entries_by_last_use_ + ABSL_GUARDED_BY(mu_); + + TpuCompilationMetrics tpu_compilation_metrics_; + + private: + TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete; + TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) = + delete; +}; +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h new file mode 100644 index 00000000..59086f84 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_ + +#include +#include +#include + +#include "absl/strings/str_cat.h" + +namespace tensorflow { +namespace tpu { + +struct TpuCompilationCacheKey { + // Prefix of the key. + std::string prefix; + + // A boolean flag to specify if `guaranteed_const` is used. Guarantee const is + // normally used in TPU inference to avoid re-copying unchanged variables onto + // the TPU device. It promises the value is identical for every execution in + // the same session even if the actual value changes in later executions. + bool has_guaranteed_const = false; + + // Unique session identifier. It is set when `has_guaranteed_const` is true. + std::string session_handle; + + // Unique session identifier for TPU compilation; it should be a 64 bit + // positive integer, which can uniquely distinguish a live session. + // TPU compiler may use this information to choose dynamically provided + // compilation options without hurting reproducibility for debugging. + uint64_t session_id; + + // Fingerprint of `guaranteed_const` value. 
It is set when the value of the + // `has_guaranteed_const` is true. Produce the value when necessary. + std::function guaranteed_const_fingerprint; + + // A more verbose key for debugging purpose. + std::string debug_string; + + // Constructs the TPU compilation cache key by concatenating the `prefix`, + // `session_handle` and `guaranteed_const_fingerprint`. + std::string ToString() const { + if (!has_guaranteed_const) { + return prefix; + } + return absl::StrCat(prefix, "|", session_handle, "|", + guaranteed_const_fingerprint()); + } + + explicit TpuCompilationCacheKey() = default; + explicit TpuCompilationCacheKey(const std::string& p) : prefix(p) {} +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_KEY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h new file mode 100644 index 00000000..40b4f862 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOCAL_LOOKUP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOCAL_LOOKUP_H_ + +#include +#include +#include + +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" + +namespace tensorflow { +namespace tpu { + +// Class for looking up TPU programs when the execute and compile Op are in the +// same address space. The proto is simply looked up in the compilation cache, +// without any serialization taking place. +class TpuCompilationCacheLocalLookup : public TpuCompilationCacheLookup { + public: + explicit TpuCompilationCacheLocalLookup(TpuCompilationCacheInterface* cache); + ~TpuCompilationCacheLocalLookup() override; + + absl::Status Lookup(const std::string& proto_key, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) override; + + absl::Status Lookup(int64_t uid, int proto_index, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) override; + + std::string DebugString() const override; + + private: + // The subgraph compilation cache, in the same process address space where the + // lookups are happening. 
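Putting the TpuCompilationCacheKey fields above together: with guaranteed constants, the effective cache key is prefix|session_handle|fingerprint, and the fingerprint is only computed when ToString() needs it. A sketch, assuming the elided std::function member returns the fingerprint as a string; the values are placeholders:

#include <string>

std::string BuildExampleKey() {
  tensorflow::tpu::TpuCompilationCacheKey key("subgraph_prefix");
  key.has_guaranteed_const = true;
  key.session_handle = "session-42";
  key.guaranteed_const_fingerprint = [] {
    // Real code would hash the guaranteed-const inputs; fixed for illustration.
    return std::string("deadbeef");
  };
  return key.ToString();  // "subgraph_prefix|session-42|deadbeef"
}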
+ TpuCompilationCacheInterface* cache_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOCAL_LOOKUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h new file mode 100644 index 00000000..0cdfe64e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/framework/resource_base.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" + +namespace tensorflow { +namespace tpu { + +// TODO(b/162241759): consider merging TpuCompilationCacheLookup and +// TpuCompilationCacheInterface. +// Base class allowing Execute Ops to look up TPU programs. Different subclasses +// are used when the execute Op is in the same address space as the compile Op, +// and when they need to communicate over RPC. +class TpuCompilationCacheLookup : public ResourceBase { + public: + ~TpuCompilationCacheLookup() override = default; + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by key. On success a wrapper for the proto is + // returned in program. The wrapper is guaranteed to be valid only during the + // execution of the Op requesting the proto. + // + // Only one of the main, sharding, unsharding entries is fetched, as specified + // in fetch_target. + // + // If the compilation does not create sharding/unsharding programs, but the + // fetch_target requests one of them, then after this call + // (*entry)->get().get_executable() will return nullptr. + virtual absl::Status Lookup(const std::string& proto_key, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) = 0; + + virtual absl::Status Lookup( + const std::string& proto_key, + std::unique_ptr* entry) { + return Lookup(proto_key, std::move(entry), + CompilationCacheFetchTarget::MAIN); + } + + // Looks up an executable corresponding to the model-parallel core index of + // the subgraph represented by uid. On success a wrapper for the proto is + // returned in program. The wrapper is guaranteed to be valid only during the + // execution of the Op requesting the proto. 
+ virtual absl::Status Lookup(int64_t uid, int proto_index, + std::unique_ptr* entry, + CompilationCacheFetchTarget fetch_target) = 0; + + virtual absl::Status Lookup( + int64_t uid, int proto_index, + std::unique_ptr* entry) { + return Lookup(uid, proto_index, std::move(entry), + CompilationCacheFetchTarget::MAIN); + } +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOOKUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h new file mode 100644 index 00000000..e8666ec6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_ + +#include +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" + +namespace tensorflow { +namespace tpu { + +// Class for looking up and caching TPU program via RPC. +class TpuCompilationCacheRpcLookup : public TpuCompilationCacheLookup { + public: + using StubType = tpu::grpc::TpuCompilationCacheService::Stub; + + TpuCompilationCacheRpcLookup(const string& server_address, + int64_t max_cache_size); + ~TpuCompilationCacheRpcLookup() override = default; + + absl::Status Lookup(const string& proto_key, + std::unique_ptr* entry, + tpu::CompilationCacheFetchTarget fetch_target) override; + + absl::Status Lookup(int64_t uid, int proto_index, + std::unique_ptr* entry, + tpu::CompilationCacheFetchTarget fetch_target) override; + + string DebugString() const override; + + private: + // Helper method to make the RPC request to the central cache. + absl::Status RemoteLookupLocked(const string& local_proto_key, + const tpu::GetTpuProgramRequest& request, + std::shared_ptr* cache_entry) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Helper method to adjust datastructures after a cache lookup. + // We use `removed_entries` so that actual CacheEntry destruction happens + // outside the lock. + void PostLookupLocked( + std::shared_ptr* cache_entry, + std::unique_ptr* entry, + std::vector>* removed_entries) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // The maximum size of entries that are stored in the cache before entries are + // evicted. 
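A usage sketch for the lookup interface above, using the local flavor (the RPC-backed lookup in this diff exposes the same surface). It assumes the elided template argument of the unique_ptr out-parameter is CompilationCacheEntryRef, as the interface header earlier in the diff suggests, and that `cache` and `proto_key` come from a prior compilation:

#include <memory>
#include <string>

void LookupExample(tensorflow::tpu::TpuCompilationCacheInterface* cache,
                   const std::string& proto_key) {
  using tensorflow::tpu::CompilationCacheEntryRef;
  using tensorflow::tpu::CompilationCacheFetchTarget;

  tensorflow::tpu::TpuCompilationCacheLookup* lookup =
      new tensorflow::tpu::TpuCompilationCacheLocalLookup(cache);

  std::unique_ptr<CompilationCacheEntryRef> entry;
  absl::Status s =
      lookup->Lookup(proto_key, &entry, CompilationCacheFetchTarget::MAIN);
  if (s.ok()) {
    // The wrapper is only valid while `entry` (and the requesting op) lives.
    tensorflow::tpu::TpuCompilationCacheEntry cache_entry = entry->get();
    // cache_entry.tpu_program_group(), cache_entry.core_index(), ...
  }
  lookup->Unref();  // Drop the initial reference taken at construction.
}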
+ const int64_t max_cache_size_; + + std::unique_ptr stub_; + + // Protect concurrent access to member variables below. + mutable absl::Mutex mu_; + + // The total size of entries in the cache. + int64_t cache_size_ ABSL_GUARDED_BY(mu_) = 0; + // The value to assign to the last_use field of the next entry that is looked + // up. + int64_t use_counter_ ABSL_GUARDED_BY(mu_) = 0; + // The entries that can be looked up in the cache. An entry is deleted from + // the cache as soon as it is evicted, but the underlying shared_ptr won't be + // freed until any wrappers holding it go out of scope. + std::unordered_map> cache_ + ABSL_GUARDED_BY(mu_); + // Map from last_use to entry, used to evict entries in LRU order. + std::map entries_by_last_use_ ABSL_GUARDED_BY(mu_); +}; +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_LOOKUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h new file mode 100644 index 00000000..7a2f25f0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h @@ -0,0 +1,99 @@ +#include "absl/status/statusor.h" +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_SUPPORT_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_SUPPORT_H_ + +#include +#include +#include +#include +#include + +#include "grpcpp/security/credentials.h" +#include "grpcpp/support/slice.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" + +namespace tensorflow { +namespace tpu { + +// A cache entry for remote TPU compilation. +struct CacheEntry { + CacheEntry() : size(0), last_use(-1) {} + virtual ~CacheEntry() { + if (tpu_program_group != nullptr) { + tpu_program_group->UnloadAndDestroyPrograms(); + } + } + std::unique_ptr tpu_program_group; + std::string key; + int64_t size; + + // An integer-based monotonically increasing counter used by the TPU + // compilation cache to sort and evict the least recently used entry when the + // cache size exceeded the maximum size limit. The value is initialized to + // `-1` as an initial value. + int64_t last_use; +}; + +// Implementation of `CompilationCacheEntryRef` that holds a shared_ptr to the +// local cache entry until the wrapper is destroyed. 
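The eviction bookkeeping above (a use counter, a running size total, a key-indexed map, and a second map ordered by last use) is a classic LRU-by-counter scheme, and the CacheWrapper declared next is what keeps an evicted entry alive until its last user finishes. A self-contained sketch of that scheme, with hypothetical names and none of the locking or RPC details, looks like this:

    #include <cstdint>
    #include <map>
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Entry {
      std::string key;
      int64_t size = 0;
      int64_t last_use = -1;
    };

    struct LruCache {
      int64_t max_size;
      int64_t current_size = 0;
      int64_t use_counter = 0;
      std::unordered_map<std::string, std::shared_ptr<Entry>> by_key;
      std::map<int64_t, Entry*> by_last_use;

      // Inserts `e`, marks it most recently used, and evicts in LRU order
      // while over budget. Evicted shared_ptrs are returned so the caller can
      // destroy them outside any lock, mirroring the `removed_entries`
      // pattern in the class above.
      std::vector<std::shared_ptr<Entry>> Insert(std::shared_ptr<Entry> e) {
        e->last_use = use_counter++;
        by_last_use[e->last_use] = e.get();
        current_size += e->size;
        by_key[e->key] = e;
        std::vector<std::shared_ptr<Entry>> removed;
        while (current_size > max_size && by_last_use.size() > 1) {
          Entry* victim = by_last_use.begin()->second;
          if (victim == e.get()) break;  // never evict the entry just added
          by_last_use.erase(by_last_use.begin());
          current_size -= victim->size;
          removed.push_back(by_key[victim->key]);
          by_key.erase(victim->key);
        }
        return removed;
      }
    };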
+class CacheWrapper : public CompilationCacheEntryRef { + public: + explicit CacheWrapper(std::shared_ptr entry) + : cache_entry_(std::move(entry)) {} + ~CacheWrapper() override = default; + + TpuCompilationCacheEntry get() override { + if (cache_entry_->size == 0) { + // Create an empty entry if the size is 0. This corresponds to + // non-existing sharding/unsharding entries. + return TpuCompilationCacheEntry(); + } + return TpuCompilationCacheEntry(cache_entry_->tpu_program_group.get(), + /*core_index=*/0); + } + + absl::Status ToSubEntryRef( + CompilationCacheFetchTarget fetch_target) override { + LOG(FATAL) << "Not implemented by designed."; + } + + private: + std::shared_ptr cache_entry_; +}; + +// Creates gRPC channel credentials for the current runtime env. +std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials(); + +// Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The +// `cache_entry` will be instantiated by the function. +template +absl::Status DeserializeRpcResponseToCacheEntry( + absl::string_view local_proto_key, ResponseType* response, + std::shared_ptr* cache_entry); + +// Serializes `TpuCompilationCacheEntry` to gRPC bufer slices. +absl::StatusOr> SerializeCacheEntryToBufferSlices( + const TpuCompilationCacheEntry& cache_entry); +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_RPC_SUPPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h new file mode 100644 index 00000000..6dd644d3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SERVICE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SERVICE_H_ + +#include +#include + +#include "grpcpp/server_builder.h" +#include "xla/tsl/distributed_runtime/rpc/grpc_call.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" + +namespace tensorflow { +// gRPC service for handling CompilationCache requests. +// To avoid OOMs during execution, this service using the asynchronous raw gRPC +// interface to serialize cache results directly to gRPC byte buffers. This +// allows us to control serialization concurrency and avoids making an extra +// copy of the program cache for each worker. 
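The comment above describes the standard asynchronous (completion-queue) gRPC serving pattern that the Start()/HandleRPCsLoop() members below implement. A generic, heavily simplified sketch of that polling loop follows; PendingCall here is a hypothetical stand-in for the tsl::Call tag type used by the real service, and this is not code from this change.

    #include <atomic>
    #include "grpcpp/grpcpp.h"

    // Hypothetical per-request state machine; each pending RPC registers its
    // own address with the completion queue as the tag.
    struct PendingCall {
      // Advance this call's request -> serialize -> finish state machine.
      // A real server would write into a grpc::ByteBuffer and call Finish();
      // here the call is simply retired.
      void Proceed(bool /*ok*/) { delete this; }
    };

    void HandleRpcsLoop(grpc::ServerCompletionQueue* cq,
                        std::atomic<bool>* running) {
      void* tag = nullptr;
      bool ok = false;
      // Next() blocks until some RPC event completes; the tag identifies the
      // pending call it belongs to.
      while (cq->Next(&tag, &ok)) {
        auto* call = static_cast<PendingCall*>(tag);
        if (running->load()) {
          call->Proceed(ok);
        } else {
          delete call;  // draining during shutdown
        }
      }
    }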
+class TpuCompilationCacheService { + public: + using ServiceType = ::tensorflow::tpu::grpc::TpuCompilationCacheService; + using AsyncService = ServiceType::AsyncService; + + TpuCompilationCacheService(::grpc::ServerBuilder* server_builder, + tpu::TpuCompilationCacheInterface* cache); + ~TpuCompilationCacheService(); + + void Start(); + bool Shutdown(int timeout_sec); + void SetMemoryQuota(size_t max_bytes); + + private: + void HandleRPCsLoop(); + + using GetTpuProgramCall = + tsl::Call; + + // Schedule the cache fetch into the serving thread pool. + void HandleGetTpuProgram(GetTpuProgramCall* call); + + // Performs the actual cache fetch and serialization. + void GetTpuProgram(GetTpuProgramCall* call); + + std::atomic running_; + tpu::TpuCompilationCacheInterface* cache_; + ::grpc::ServerBuilder* server_builder_; + std::unique_ptr<::grpc::Server> server_; + std::unique_ptr<::grpc::ServerCompletionQueue> cq_; + std::unique_ptr thread_pool_; + std::unique_ptr polling_thread_; + AsyncService service_; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_SERVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_metrics.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_metrics.h new file mode 100644 index 00000000..f201fd27 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compilation_metrics.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_METRICS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_METRICS_H_ + +#include + +#include "absl/strings/string_view.h" + +namespace tensorflow { +namespace tpu { + +// Tracks Tpu compilation and cache metrics. +class TpuCompilationMetrics { + public: + // Increments the number of cache lookup count. + static void IncrementCacheLookupCount(bool is_cache_hit, + absl::string_view session_name); + + // Sets the total count of cache entries. + static void SetCacheEntryCount(int64_t count); + + // Increments number of compilation. + static void IncrementCompilationCount(absl::string_view session_name); +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_METRICS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op.h new file mode 100644 index 00000000..4b8956f9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op.h @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_H_ + +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h" + +namespace tensorflow { +namespace tpu { + +// The TPUCompile operator compiles a Tensorflow function into a +// TPU executable to be run by TPUExecute. +// +class TpuCompileOp : public OpKernel { + public: + explicit TpuCompileOp(OpKernelConstruction* ctx); + + TpuCompileOp(const TpuCompileOp&) = delete; + TpuCompileOp& operator=(const TpuCompileOp&) = delete; + + ~TpuCompileOp() override = default; + + void Compute(OpKernelContext* ctx) override; + + private: + std::unique_ptr impl_; +}; + +// The TPUCompile operator compiles a MLIR module into a +// TPU executable to be run by TPUExecute. +// +class TpuCompileMlirOp : public OpKernel { + public: + explicit TpuCompileMlirOp(OpKernelConstruction* ctx); + + TpuCompileMlirOp(const TpuCompileMlirOp&) = delete; + TpuCompileMlirOp& operator=(const TpuCompileMlirOp&) = delete; + + ~TpuCompileMlirOp() override = default; + + void Compute(OpKernelContext* ctx) override; + + private: + std::unique_ptr impl_; +}; + +class TpuCompileSucceededAssertOp : public OpKernel { + public: + explicit TpuCompileSucceededAssertOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + TpuCompileSucceededAssertOp(const TpuCompileSucceededAssertOp&) = delete; + TpuCompileSucceededAssertOp& operator=(const TpuCompileSucceededAssertOp&) = + delete; + + ~TpuCompileSucceededAssertOp() override = default; + + void Compute(OpKernelContext* ctx) override; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_common.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_common.h new file mode 100644 index 00000000..4c4dfdd0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_common.h @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_COMMON_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_COMMON_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/jit/shape_inference.h" +#include "xla/stream_executor/tpu/c_api_decl.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" // IWYU pragma: keep + +namespace tensorflow { +namespace tpu { + +// Forward declaration, defined below. +class TpuCompileOpKernelCommon; + +// A base factory class for creating a `TpuCompileOpKernelImpl` variant. +// By design, the actual factory can only be set once. +class CompileOpImplFactory { + public: + virtual ~CompileOpImplFactory() = default; + + virtual absl::StatusOr> + CreateNonMlirImpl(OpKernelConstruction* ctx) = 0; + + virtual absl::StatusOr> + CreateMlirImpl(OpKernelConstruction* ctx) = 0; + + static CompileOpImplFactory* Get(); + static void Register(CompileOpImplFactory* factory); + + private: + static CompileOpImplFactory* factory_; +}; + +// Abstract base class for TpuCompileOpKernel implementation. +class TpuCompileOpKernelCommon { + public: + TpuCompileOpKernelCommon(const std::string& mlir_module, + const tpu::TPUCompileMetadataProto metadata, + int num_computations, bool return_hlo_protos, + bool unload_cache_on_session_close) + : metadata_(metadata), + use_mlir_(true), + mlir_module_(mlir_module), + num_computations_(num_computations), + return_hlo_protos_(return_hlo_protos), + unload_cache_entry_on_session_close_(unload_cache_on_session_close), + persistent_cache_(nullptr) { + mlir_module_fingerprint_ = tensorflow::Fingerprint64(mlir_module_); + } + + TpuCompileOpKernelCommon( + const NameAttrList& function, const tpu::TPUCompileMetadataProto metadata, + int num_computations, bool return_hlo_protos, + bool unload_cache_on_session_close, + std::unique_ptr persistent_cache) + : metadata_(metadata), + use_mlir_(false), + function_(function), + num_computations_(num_computations), + return_hlo_protos_(return_hlo_protos), + unload_cache_entry_on_session_close_(unload_cache_on_session_close), + persistent_cache_(std::move(persistent_cache)) {} + + TpuCompileOpKernelCommon(const TpuCompileOpKernelCommon&) = delete; + TpuCompileOpKernelCommon& operator=(const TpuCompileOpKernelCommon&) = delete; + + virtual ~TpuCompileOpKernelCommon() = default; + + void Compute(OpKernelContext* ctx); + + // Lowers Mlir or TF Function computation into HLO IR and using XLA compiler + // compiles into TPU programs ready for execution. 
+ virtual absl::Status Compile( + const std::variant& computation, + const XLA_TpuMeshState* mesh_state, + const std::vector& arg_shapes, + const TpuCompilationCacheKey* key, + TpuProgramGroupInterface* tpu_program_group) = 0; + + // Performs shape inference on `computation`, filling shape_info with operator + // shapes. The shapes of the _Arg nodes are taken from `arg_shapes`. + static absl::Status RunShapeInferenceOnComputation( + const tpu::TPUCompileMetadataProto& metadata, + const std::vector& arg_shapes, Graph* graph, + FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info); + + protected: + absl::Status ComputeInternal(OpKernelContext* ctx); + + // Compile TPU program locally and populate the host compilation cache. + absl::Status CompileLocallyAndFillHostCache( + FunctionLibraryRuntime* flib_runtime, + const SessionMetadata* session_metadata, + const TpuMeshStateInterface* mesh_state, + const std::vector& dynamic_shapes, + const OpInputList& guaranteed_constants, + const tpu::TpuCompilationCacheKey& key, + TpuProgramGroupInterface* tpu_program_group); + + absl::Status CompileLocallyAndFillHostCacheInternal( + FunctionLibraryRuntime* flib_runtime, + const SessionMetadata* session_metadata, + const TpuMeshStateInterface* mesh_state, + const std::vector& dynamic_shapes, + const OpInputList& guaranteed_constants, + const tpu::TpuCompilationCacheKey& key, + TpuProgramGroupInterface* tpu_program_group); + + // Lookup from persistent compilation cache and populate both host cache and + // persistent cache. + virtual absl::Status LookupPersistentCompilationCacheAndFillCaches( + FunctionLibraryRuntime* flib_runtime, + const SessionMetadata* session_metadata, + const TpuMeshStateInterface* mesh_state, + const std::vector& dynamic_shapes, + const OpInputList& guaranteed_constants, + TpuPersistentCompilationCacheInterface* persistent_cache, + const tpu::TpuCompilationCacheKey& key, + TpuProgramGroupInterface* tpu_program_group) { + LOG(FATAL) << "Lookup from a persistent cache is NOT supported."; + } + + // Sleeps for `kSleepSeconds` seconds to give time for TPUCompileOp to finish + // before terminating peacefully. + static void ExitCountdown(tsl::Env* env, + std::shared_ptr> done); + + // Converts the `dynamic_shapes` arguments to the compile operator into + // TensorShapes. + static absl::Status GetDynamicShapes(OpKernelContext* ctx, + std::vector* shapes); + + tpu::TPUCompileMetadataProto metadata_; + + // Whether to compile given MLIR module in `mlir_module` instead of + // TensorFlow function referenced in `function_`. + bool use_mlir_; + + // Function containing the computation to compile. + NameAttrList function_; + + // A serialized MLIR ModuleOp. + std::string mlir_module_; + // Fingerprint of the MLIR Module created once on construction to avoid paying + // the cost on each invocation. + uint64 mlir_module_fingerprint_ = 0; + + // Number of different programs to compile. This maps to number of cores in + // each replica. + int num_computations_; + + // A flag to populate HLO protos field in CompilationResultProto. The HLO + // metadata could be large so default to not populating it unless explicitly + // requested. + bool return_hlo_protos_; + + // If enabled, DirectSession::Close will unload cache entries created during + // the lifetime of the session. + bool unload_cache_entry_on_session_close_; + + // Persistent cache for compiled TPU program for inference. 
+ std::unique_ptr persistent_cache_; + + absl::Status RegisterXLAFingerprints( + const std::vector& arg_shapes, + TpuProgramGroupInterface* tpu_program_group, uint64 fingerprint); +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_COMMON_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h new file mode 100644 index 00000000..1f5fdb52 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h @@ -0,0 +1,68 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" + +namespace tensorflow { +namespace tpu { + +// Base class for TpuCompileOp and TpuCompileMlirOp. +// Depends on whether it is given a computation in the form of serialized MLIR +// module or a Tensorflow function, TpuCompileOpKernelImpl converts computation +// into XLA HLO and then into a TPU execuable binary. 
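The Compile() override declared below takes a std::variant whose template arguments are elided in this header copy; judging by tpu_compile_op_support.h later in this diff, they are MlirToHloArgs and FunctionToHloArgs. A minimal, self-contained illustration of dispatching on such a two-alternative variant, using stand-in types rather than the real ones:

    #include <string>
    #include <variant>

    struct MlirArgs { std::string serialized_module; };
    struct FunctionArgs { std::string function_name; };

    // Returns a label for whichever lowering path the computation selects.
    const char* LoweringPath(const std::variant<MlirArgs, FunctionArgs>& c) {
      return std::holds_alternative<MlirArgs>(c) ? "mlir-to-hlo"
                                                 : "function-to-hlo";
    }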
+class TpuCompileOpKernelImpl : public TpuCompileOpKernelCommon { + public: + TpuCompileOpKernelImpl(const std::string& mlir_module, + const tpu::TPUCompileMetadataProto& metadata, + int num_computations, bool return_hlo_protos, + bool unload_cache_on_session_close) + : TpuCompileOpKernelCommon(mlir_module, metadata, num_computations, + return_hlo_protos, + unload_cache_on_session_close) {} + + TpuCompileOpKernelImpl(const NameAttrList& function, + const tpu::TPUCompileMetadataProto& metadata, + int num_computations, bool return_hlo_protos, + bool unload_cache_on_session_close) + : TpuCompileOpKernelCommon( + function, metadata, num_computations, return_hlo_protos, + unload_cache_on_session_close, /*persistent_cache=*/nullptr) {} + + absl::Status Compile( + const std::variant& computation, + const XLA_TpuMeshState* mesh_state, + const std::vector& arg_shapes, + const TpuCompilationCacheKey* key, + TpuProgramGroupInterface* tpu_program_group) override; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_options.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_options.h new file mode 100644 index 00000000..b81fe4a3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_options.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_ + +#include + +namespace tensorflow { +namespace internal { + +// Setter and getter that determine how TPUCompile responds to cancelled +// compilation. By default this is true, meaning cancelled compilation will +// abort the process, since that's the only mechanism we have available. +// +// Setting this to false allows the process to remain alive, and should only be +// used in tests. +void SetTpuCompilationCancellationTerminatesProcess(bool b); +bool TpuCompilationCancellationTerminatesProcess(); + +// Setter and getter that determine whether TPU compilation failure will cause +// chips to close. By default this is true, it is suitable for training. For +// inference, we never want servers to die and thus chips will keep alive. +// See b/109873767. 
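These option pairs are plain process-wide toggles; their definitions live in the corresponding .cc files, which this header-only import does not carry. As a guess at the shape of that unseen implementation (an assumption, not its contents), such a setter/getter pair is commonly backed by a single atomic:

    #include <atomic>

    namespace {
    // Defaults to true, matching the documented behavior above.
    std::atomic<bool> compilation_failure_closes_chips{true};
    }  // namespace

    void SetTpuCompilationFailureClosesChips(bool value) {
      compilation_failure_closes_chips.store(value);
    }

    bool TpuCompilationFailureClosesChips() {
      return compilation_failure_closes_chips.load();
    }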
+void SetTpuCompilationFailureClosesChips(bool value); +bool TpuCompilationFailureClosesChips(); + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_support.h new file mode 100644 index 00000000..e7ec4ac6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -0,0 +1,168 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" + +namespace tensorflow { +namespace tpu { + +// List of parameters for lowering Mlir to HLO IR. +// If mlir_module_op is set, it will be used instead of mlir_module. +struct MlirToHloArgs { + absl::string_view mlir_module; + ConfigProto::Experimental::MlirBridgeRollout rollout_state = + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED; + std::optional mlir_module_op; +}; + +// Variant of guaranteed constant tensors types. +using GuaranteedConsts = std::variant, + const OpInputList* const>; + +// List of parameters for lowering function library definition to HLO IR. +struct FunctionToHloArgs { + const NameAttrList* const function; + const FunctionLibraryDefinition* flib_def; + int graph_def_version; + GuaranteedConsts guaranteed_constants; +}; + +// Persistent cache for compiled TPU program and the related compiler metadata +// intended for TPU inference. +// TODO(henrytan): there is an opportunity to consolidate the interface with the +// `TpuCompilationCacheInterface` once `TpuPersistentCompilationCache` is +// converted into a ref count based class. 
+class TpuPersistentCompilationCacheInterface { + public: + virtual ~TpuPersistentCompilationCacheInterface() = default; + + // Returns the location where cache entries are stored. + virtual std::string cache_location() const = 0; +}; + +// Describes the position of an argument or return value after the computation +// has been partitioned into cores. +struct ShardingAndIndex { + // Sharding across cores. + ::xla::OpSharding sharding; + // Argument/return value number. If sharding is single-core, `indices` has a + // single element; otherwise, it has num_cores elements. + std::vector indices; +}; + +// TODO(b/158279168): Dedup with internal version. +// Return the per-device shape for a `shape` with a given `sharding`. +xla::Shape GetPerDeviceShape(const xla::Shape& shape, + const xla::HloSharding& sharding, int64_t device); + +absl::StatusOr> CreateModuleConfig( + const xla::ProgramShape& program_shape, + absl::Span argument_shapes, + std::optional result_layout, + std::optional device_assignment, + int replica_count, int num_partitions, + const xla::DebugOptions* debug_options, const int* seed, + const int* launch_id, const bool* alias_passthrough_params, + const xla::FusionConfigCollection* fusion_config_collection, + const std::vector>* fusion_config); + +absl::StatusOr> CreateModuleConfig( + const xla::ProgramShape& program_shape, + absl::Span argument_shapes, + std::optional result_layout, + std::optional device_assignment, + int replica_count, int num_partitions, + const xla::DebugOptions* debug_options); + +xla::ShapeTree GetSubtree( + const xla::ShapeTree& tuple_shape_tree, + int element_index); + +xla::Shape GetPerDeviceShape(const xla::Shape& shape, + const xla::HloSharding& sharding, int64_t device); + +absl::Status AddVariableUpdatesToCores( + const TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + const std::vector& arg_core_mapping, + std::vector* may_modify_variables, + std::vector>* per_core_output_shapes, + std::vector>>* per_core_variable_indices); + +absl::Status ComputeOutputShapesForEachCore( + const tpu::TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + std::vector>* per_core_output_shapes); + +absl::Status CreateHloModules( + const TPUCompileMetadataProto& metadata, + const XlaCompiler::CompilationResult& compilation_result, + const std::optional& device_assignment, + std::vector>* hlo_modules); + +absl::StatusOr CreateTpuCompilationRequest( + const std::variant& computation, + const TPUCompileMetadataProto& metadata, + const std::vector& arg_shapes); + +absl::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx, + TPUCompileMetadataProto* metadata, + NameAttrList* function_name, + std::string* mlir_module); + +// Computes shapes for each argument. Uses both the static shape from the +// metadata, and the dynamic shapes where the static shape is not +// defined. There must be one dynamic_shape for each argument with a +// partially defined shape, in index order. 
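The contract in the comment above can be made concrete with a small, self-contained sketch: fully defined static shapes are used as-is, and each partially defined argument consumes the next dynamic shape in index order. Shapes are modeled here as plain dim vectors with -1 for an unknown dimension; this is not the TensorShape-based implementation behind ComputeArgumentShapes.

    #include <cstdint>
    #include <vector>

    using DimVec = std::vector<int64_t>;

    bool MergeArgumentShapes(const std::vector<DimVec>& static_shapes,
                             const std::vector<DimVec>& dynamic_shapes,
                             std::vector<DimVec>* arg_shapes) {
      size_t next_dynamic = 0;
      arg_shapes->clear();
      for (const DimVec& s : static_shapes) {
        bool fully_defined = true;
        for (int64_t d : s) fully_defined &= (d >= 0);
        if (fully_defined) {
          arg_shapes->push_back(s);
        } else {
          if (next_dynamic >= dynamic_shapes.size()) return false;  // missing shape
          arg_shapes->push_back(dynamic_shapes[next_dynamic++]);
        }
      }
      // Every supplied dynamic shape should have been consumed.
      return next_dynamic == dynamic_shapes.size();
    }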
+absl::Status ComputeArgumentShapes( + const TPUCompileMetadataProto& metadata, + const std::vector& dynamic_shapes, + std::vector* arg_shapes); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_configuration_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_configuration_ops.h new file mode 100644 index 00000000..fe5eeb22 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_configuration_ops.h @@ -0,0 +1,176 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" + +namespace tensorflow { + +absl::Status CreateTpuCompilationCache( + ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache); + +absl::StatusOr> ConstructDevicesPerHost( + OpKernelContext* ctx); + +// The ConfigureDistributedTpu op is used to start an TPUDriver from +// TensorFlow. It should be run on a TPU_SYSTEM device and returns the +// connection host:port for the CompilationCacheServer. The +// CompilationCacheServer will remain live until the device's Resource Manager +// is cleared or a ShutdownDistributedTpuOp is run on the same device. +class ConfigureDistributedTpuOp : public OpKernel { + public: + explicit ConfigureDistributedTpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES(ctx, ctx->num_inputs() > 0, + absl::InternalError( + "_ConfigureDistributedTPU needs at least one input")); + } + void Compute(OpKernelContext* ctx) override; + ~ConfigureDistributedTpuOp() override = default; + + private: + // ConfigureDistributedTpuOp is neither copyable nor movable. + ConfigureDistributedTpuOp(const ConfigureDistributedTpuOp&) = delete; + ConfigureDistributedTpuOp& operator=(const ConfigureDistributedTpuOp&) = + delete; +}; + +// The WaitForDistributedTpuOp op is used to block execution until +// the distributed Tpu system has started up. It must be run on +// the same TPU_SYSTEM device that ConfigureDistributedTpuOp was run +// on, after all of the InitializeHostForDistributedTpuOp Ops have +// completed. 
+class WaitForDistributedTpuOp : public OpKernel { + public: + explicit WaitForDistributedTpuOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, + ctx->GetAttr("startup_timeout_sec", &startup_timeout_sec_)); + OP_REQUIRES( + ctx, startup_timeout_sec_ > 0, + absl::InvalidArgumentError(absl::StrCat( + "startup_timeout_sec ", startup_timeout_sec_, " must be >0"))); + } + void Compute(OpKernelContext* ctx) override; + ~WaitForDistributedTpuOp() override = default; + + private: + // The time to wait for all hosts to start up. + int startup_timeout_sec_; + + // WaitForDistributedTpuOp is neither copyable nor movable. + WaitForDistributedTpuOp(const WaitForDistributedTpuOp&) = delete; + WaitForDistributedTpuOp& operator=(const WaitForDistributedTpuOp&) = delete; +}; + +// The ShutdownDistributedTpu op is used to stop a running TPUDriver from +// TensorFlow. It should be run on the TPU_SYSTEM device where +// ConfigureDistributedTpuOp was run. +class ShutdownDistributedTpuOp : public OpKernel { + public: + explicit ShutdownDistributedTpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + ~ShutdownDistributedTpuOp() override = default; + + private: + // ShutdownDistributedTpuOp is neither copyable nor movable. + ShutdownDistributedTpuOp(const ShutdownDistributedTpuOp&) = delete; + ShutdownDistributedTpuOp& operator=(const ShutdownDistributedTpuOp&) = delete; +}; + +// The InitializeHostForDistributedTpu op is used to initialize the +// TPUPlatform on a host in a distributed TPU system. It should be +// run on every host containing TPU devices before any other Ops that use +// TPU are run. +class InitializeHostForDistributedTpuOp : public OpKernel { + public: + explicit InitializeHostForDistributedTpuOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + ctx->GetAttr("enable_whole_mesh_compilations", + &enable_whole_mesh_compilations_) + .IgnoreError(); + ctx->GetAttr("tpu_cancellation_closes_chips", + &tpu_cancellation_closes_chips_) + .IgnoreError(); + } + + void Compute(OpKernelContext* ctx) override; + + ~InitializeHostForDistributedTpuOp() override = default; + + private: + // InitializeHostForDistributedTpuOp is neither copyable nor movable. + InitializeHostForDistributedTpuOp(const InitializeHostForDistributedTpuOp&) = + delete; + InitializeHostForDistributedTpuOp& operator=( + const InitializeHostForDistributedTpuOp&) = delete; + + bool enable_whole_mesh_compilations_ = false; + int tpu_cancellation_closes_chips_ = 0; +}; + +// The SetGlobalTPUArray op is used to initialize the TPUPlatform on a +// host in a distributed TPU system. It should be run on every host +// containing TPU devices before any other Ops that use TPU are run. +class SetGlobalTPUArrayOp : public OpKernel { + public: + explicit SetGlobalTPUArrayOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + ~SetGlobalTPUArrayOp() override = default; + + private: + // SetGlobalTPUArrayOp is neither copyable nor movable. + SetGlobalTPUArrayOp(const SetGlobalTPUArrayOp&) = delete; + SetGlobalTPUArrayOp& operator=(const SetGlobalTPUArrayOp&) = delete; +}; + +// The DisconnectDistributedTpuChips op is used to disconnect all the chips on a +// host from a running TPUDriver instance. It should be run on every host +// containing TPU devices before the ShutdownDistributedTpuOp is run on +// the TPU_SYSTEM. 
+class DisconnectDistributedTpuChipsOp : public OpKernel { + public: + explicit DisconnectDistributedTpuChipsOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; + + ~DisconnectDistributedTpuChipsOp() override = default; + + private: + // DisconnectDistributedTpuChipsOp is neither copyable nor movable. + DisconnectDistributedTpuChipsOp(const DisconnectDistributedTpuChipsOp&) = + delete; + DisconnectDistributedTpuChipsOp& operator=( + const DisconnectDistributedTpuChipsOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_CONFIGURATION_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_engine_state_interface.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_engine_state_interface.h new file mode 100644 index 00000000..73b0a492 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_engine_state_interface.h @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_ENGINE_STATE_INTERFACE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_ENGINE_STATE_INTERFACE_H_ + +#include + +#include "xla/stream_executor/tpu/tpu_api.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +class TpuEmbeddingEngineState; + +namespace tpu { + +const char kTpuEmbeddingEngineStateInterfaceResourceName[] = + "tpu_embedding_engine_state"; + +class TpuEmbeddingEngineStateInterface : public ResourceBase { + public: + explicit TpuEmbeddingEngineStateInterface(XLA_TpuEmbeddingEngineState* handle) + : engine_state_(handle) {} + + ~TpuEmbeddingEngineStateInterface() override { + if (engine_state_ != nullptr) { + stream_executor::tpu::OpsApiFn()->TpuEmbeddingEngineState_FreeFn( + engine_state_); + } + } + + tensorflow::TpuEmbeddingEngineState* GetState() const { + if (engine_state_ == nullptr) { + return nullptr; + } + return static_cast( + stream_executor::tpu::OpsApiFn()->TpuEmbeddingEngineState_GetStateFn( + engine_state_)); + } + + static TpuEmbeddingEngineStateInterface* Create() { + XLA_TpuEmbeddingEngineState* state = nullptr; + if (stream_executor::tpu::OpsApiFn()->TpuEmbeddingEngineState_CreateFn != + nullptr) { + state = + stream_executor::tpu::OpsApiFn()->TpuEmbeddingEngineState_CreateFn(); + } + return new TpuEmbeddingEngineStateInterface(state); + } + + string DebugString() const override { + return "TpuEmbeddingEngineStateInterface"; + } + + private: + XLA_TpuEmbeddingEngineState* engine_state_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_ENGINE_STATE_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_enqueue_ops.h 
b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_enqueue_ops.h new file mode 100644 index 00000000..e06c02c9 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_enqueue_ops.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_ENQUEUE_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_ENQUEUE_OPS_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" + +namespace tensorflow { + +// Validates that all the combiners passed are one of the following: sum, mean, +// or sqrtn. +absl::Status ValidateCombiners(absl::Span combiners); + +// Validates the `mode_override` input of the TPUEnqueue* ops, and, if correct, +// sets the `mode` to pass on to the TPU Embedding manager. +absl::Status GetValidatedModeOverride( + const string& mode_override, tpu::TPUEmbeddingConfiguration::Mode* mode); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_ENQUEUE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_load_retrieve_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_load_retrieve_ops.h new file mode 100644 index 00000000..51459c6a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_embedding_load_retrieve_ops.h @@ -0,0 +1,99 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Ops to load and retrieve embeddings for TPU Embedding. + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_LOAD_RETRIEVE_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_LOAD_RETRIEVE_OPS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" +#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" + +namespace tensorflow { + +// The LoadAllTPUEmbeddingParameters op is used to load initial embedding +// table parameters onto a host that has already been configured using +// ConfigureTPUEmbeddingHost. 
This Op should be used when TPUEmbedding is part +// of a training loop. The Op takes four input lists of tensors. Each list has +// one entry per embedding table, but some entries are ignored based on the +// particular optimization algorithm used for each table. parameters is the +// initial values of the embedding tables, and auxiliary[1-3] are the initial +// values of the auxiliary parameters. +class LoadAllTPUEmbeddingParametersOp : public OpKernel { + public: + explicit LoadAllTPUEmbeddingParametersOp(OpKernelConstruction* ctx); + ~LoadAllTPUEmbeddingParametersOp() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + void GetStateVariables( + OpKernelContext* ctx, + std::array>, + tpu::kMaxAuxiliaryParameterCount + 1>& state_variable_vector); + + private: + tpu::TPUEmbeddingConfiguration config_; + std::vector table_shapes_; + + LoadAllTPUEmbeddingParametersOp(const LoadAllTPUEmbeddingParametersOp&) = + delete; + void operator=(const LoadAllTPUEmbeddingParametersOp&) = delete; +}; + +// The RetrieveAllTPUEmbeddingParameters op is used to retrieve updated +// embedding table parameters from a TPU that has already been +// configured using ConfigureTPUEmbeddingHostOp. This Op should be used when +// TPUEmbedding is part of a training loop. The Op returns four output lists of +// tensors. Each list has one entry per embedding table, but entries are empty +// when the relevant table does not have that number of auxiliary parameters. +// The parameters output is the updated values of the embedding tables, and +// auxiliary[1-3] are the updated values of the auxiliary parameters. + +// Currently, this op is the only method to make sure that the TPUEmbedding has +// completed execution of the mini-batches enqueued so far. +// TODO(misard, b/34936670): Add a TensorFlow op that waits till all +// minibatches have been processed by the TPUEmbedding on the current host. +class RetrieveAllTPUEmbeddingParametersOp : public OpKernel { + public: + explicit RetrieveAllTPUEmbeddingParametersOp(OpKernelConstruction* ctx); + ~RetrieveAllTPUEmbeddingParametersOp() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + void GetStateVariables( + OpKernelContext* ctx, + std::array>, + tpu::kMaxAuxiliaryParameterCount + 1>& state_variable_vector, + std::vector& num_state_variables); + + tpu::TPUEmbeddingConfiguration config_; + std::vector table_shapes_; + + RetrieveAllTPUEmbeddingParametersOp( + const RetrieveAllTPUEmbeddingParametersOp&) = delete; + void operator=(const RetrieveAllTPUEmbeddingParametersOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EMBEDDING_LOAD_RETRIEVE_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_execute_op.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_execute_op.h new file mode 100644 index 00000000..d0e70dbc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_execute_op.h @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_ + +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// Op that executes a precompiled TPU computation. +class TPUExecuteOp : public AsyncOpKernel { + public: + explicit TPUExecuteOp(OpKernelConstruction* context); + ~TPUExecuteOp() override; + + AsyncOpKernel* AsAsync() override; + + void Compute(OpKernelContext* context) override; + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + protected: + // Used by TPUExecuteAndUpdateVariablesOp to set the fused variable reads and + // updates indices in the XLA computation. The two vectors must have the same + // size, and a pair of read index and write index represents a variable's + // input to the program and its updated value from the program. If the + // variable is not updated, use -1 as the output index. + std::vector fused_device_var_reads_in_computation_inputs_; + std::vector fused_device_var_updates_in_computation_outputs_; + + private: + absl::Status DoWork(OpKernelContext* context); + + TPUExecuteOp(const TPUExecuteOp&) = delete; + void operator=(const TPUExecuteOp&) = delete; +}; + +// A variant of TPUExecuteOp that contains fused device variable reads and +// updates. +class TPUExecuteAndUpdateVariablesOp : public TPUExecuteOp { + public: + explicit TPUExecuteAndUpdateVariablesOp(OpKernelConstruction* context); + ~TPUExecuteAndUpdateVariablesOp() override = default; + + private: + TPUExecuteAndUpdateVariablesOp(const TPUExecuteAndUpdateVariablesOp&) = + delete; + void operator=(const TPUExecuteAndUpdateVariablesOp&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_execute_op_options.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_execute_op_options.h new file mode 100644 index 00000000..950fb884 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_execute_op_options.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_OPTIONS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_OPTIONS_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace internal { + +enum class TpuCancellationClosesChipsMode : int { + kUnset = 0, // fallback to other configuration, e.g. 
absl flag + kEnabled = 1, // Close TPU chips when cancellation happens + kDisabled = 2, // Do not close TPU chips when cancellation happens +}; + +// Set TPU cancellation closing chips mode from an integer. See the enum +// definition of `TpuCancellationClosesChipsConfig` above for valid values. +absl::Status SetTpuCancellationClosesChips(int val); + +// Get whether to close chips when TPUExecutionOp is cancelled. If unset, return +// the value specified by the `default_value` argument. +bool TpuCancellationClosesChipsGetOrDefault(bool default_value); +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_EXECUTE_OP_OPTIONS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h new file mode 100644 index 00000000..fe98817c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h @@ -0,0 +1,95 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_FINGERPRINT_LOOKUP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_FINGERPRINT_LOOKUP_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/stringpiece.h" + +namespace tensorflow { +namespace tpu { + +// A class that holds the key-value pair of fingerprints. By calling the +// Register method, this class can map the key to the value. Note that this +// class holds invariant key-value pairs. That is, it does not allow updating +// key-value pairs, nor N-key-to-1-value and 1-key-to-M-value pairs. If such +// cases occur, the class keeps the earliest registered pairs and discards any +// violating pairs. +// +// Example: +// TpuFingerprintLookup fingerprint_lookup; +// +// // Register key-intermediate pair. +// fingerprint_lookup.RegisterKeyValuePair("key1", "intermediate1"); +// // Register intermediate-value pair. +// fingerprint_lookup.RegisterKeyValuePair("intermediate1", "value1"); +// +// // Lookup fingerprint with key. +// std::string fingerprint = fingerprint_lookup.Lookup("key1"); +// +// TODO(chiachenc): use templates and add Unregister methods. +class TpuFingerprintLookup : public ResourceBase { + public: + // Creates an instance of TpuFingerprintLookup. + static TpuFingerprintLookup* Create(); + + // Register key-intermediate pair + void RegisterKeyAndIntermediatePair(uint64 key, uint64 intermediate); + + // Register intermediate-value pair. A successful registration requires a + // preceding RegisterKeyAndIntermediatePair. Return true if successfully + // registering a key-value pair; otherwise, return false. 
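Note that the class comment's usage example registers string keys, while the API declared below is keyed by uint64 fingerprints. Under that reading, the two-phase registration might be exercised roughly as follows; this is an illustrative sketch with made-up values, and ref-counting/ownership of the ResourceBase is omitted.

    // Assumed headers: tpu_fingerprint_lookup.h plus <optional>/<string>.
    void RegisterAndLookupExample() {
      tensorflow::tpu::TpuFingerprintLookup* lookup =
          tensorflow::tpu::TpuFingerprintLookup::Create();
      // Phase 1: tie the key to an intermediate fingerprint.
      lookup->RegisterKeyAndIntermediatePair(/*key=*/0x1234,
                                             /*intermediate=*/0x9999);
      // Phase 2: tie the intermediate to the final value; returns false if no
      // matching phase-1 registration was seen or the pair would conflict.
      bool registered = lookup->RegisterIntermediateAndValuePair(
          /*intermediate=*/0x9999, "program-fingerprint");
      std::optional<tensorflow::StringPiece> fp = lookup->Lookup(/*key=*/0x1234);
      if (registered && fp.has_value()) {
        // *fp is a view into storage owned by the lookup table.
      }
    }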
+ bool RegisterIntermediateAndValuePair(uint64 intermediate, std::string value); + + // Look up fingerprint with key. + // Return std::nullopt if not found. + std::optional<::tensorflow::StringPiece> Lookup(uint64 key); + + size_t num_valid() { + absl::MutexLock lock(&mu_); + return key_to_value_.size(); + } + + std::string DebugString() const override { return "TpuFingerprintLookup"; } + + private: + explicit TpuFingerprintLookup() {} + + absl::Mutex mu_; + // Main storage for lookup + absl::node_hash_map key_to_value_ ABSL_GUARDED_BY(mu_); + + // An auxiliary storage to ensure 1-to-1 and invariant key-value pair + absl::node_hash_map value_to_key_ ABSL_GUARDED_BY(mu_); + + // An auxiliary storage to keep intermediate-key pairs. + absl::flat_hash_map intermediate_to_key_ ABSL_GUARDED_BY(mu_); + + TpuFingerprintLookup(const TpuFingerprintLookup&) = delete; + TpuFingerprintLookup& operator=(const TpuFingerprintLookup&) = delete; +}; +} // namespace tpu +} // namespace tensorflow +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_FINGERPRINT_LOOKUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_functional_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_functional_ops.h new file mode 100644 index 00000000..3c8287af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_functional_ops.h @@ -0,0 +1,383 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/call_once.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/jit/shape_inference.h" +#include "xla/stream_executor/tpu/tpu_api.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/reffed_status_callback.h" +#include "absl/container/flat_hash_map.h" + +namespace tensorflow { +// Holds node's shape information for Concat/Split. +using EdgeShapes = absl::flat_hash_map>; +using GroupedEdges = + absl::flat_hash_map>; + +// Contains attrs "T", "sharding", "_tpu_replicate" for each XlaSharding op that +// we find as part of searching for inputs to models that are replicated. +using XlaShardingInfoMap = absl::flat_hash_map< + std::string, std::tuple>; + +// Contains attrs "T", and a pointer to tpu_replicated_metadata for ctrl dep +// for each TpuReplicatedInput op that we find as part of searching for inputs +// to models that are replicated. +using TpuReplicatedInputInfoMap = + absl::flat_hash_map>; + +namespace tpu_functional_internal { + +// Helper functions for graph rewrites. 
+GroupedEdges GroupTensorsForInputPacking( + const EdgeShapes& tpu_input_shapes, + const absl::flat_hash_map& tpu_input_dtypes, + bool input_shape_opt, bool group_tensors_for_packing); +GroupedEdges GroupTensorsForOutputPacking(Graph* graph, + EdgeShapes& tpu_output_shapes, + GraphShapeInfo* shape_info); + +absl::Status CreateConcatAndSplitNodesForInputTensor( + Graph* graph, const string& cluster_name, EdgeShapes* tpu_input_shapes, + const absl::flat_hash_map>& + grouped_input_edges, + int32_t minimum_input_tensors_packing, bool xla_spmd_input_sharded, + const XlaShardingInfoMap& xla_sharding_info, + const TpuReplicatedInputInfoMap& tpu_replicated_input_info); +absl::Status CreateConcatAndSplitNodesForOutputTensor( + Graph* graph, const string& cluster_name, EdgeShapes* tpu_output_shapes, + GraphShapeInfo* tpu_inferred_info, GroupedEdges shape_to_output, + int32_t minimum_output_tensors_packing); + +absl::Status InsertReshapeNodePairs(Graph* graph, const string& cluster_name, + EdgeShapes* tpu_input_shapes, + int num_cores_per_replica); + +} // namespace tpu_functional_internal + +typedef FunctionLibraryRuntime::Handle FHandle; + +// A `TPUPartitionedCallOp` asynchronously executes a function on exactly one +// TPU core and potentially across multiple other devices, but within a single +// process. The kernel places and partitions the function's underlying graph, +// executing each of the partitioned subgraphs as a function. +// +// The core on which the TPU computation is executed must be specified via the +// `device_ordinal` input. Different invocations of this op may specify +// different device ordinals, making it possible to map TPU computations to +// different cores at runtime. Currently, macro-substitution of device ordinals +// is only supported for the following whitelisted ops: +// * TPUExecute +// * InfeedEnqueue +// * InfeedEnqueueTuple +// +// Attempting to compute a TPUPartitionedCallOp whose function body has a +// non-whitelisted node bearing an attribute named "device_ordinal" will result +// in an error. +// +// TODO(akshayka): This class duplicates most of the logic of +// `PartitionedCallOp`; once that class and this one have evolved to stable +// states, and if at that time they remain sufficiently similar, either unify +// them in one op or set up an inheritance structure that allows for code reuse. +class TPUPartitionedCallOp : public AsyncOpKernel { + public: + explicit TPUPartitionedCallOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + pool_(ctx->env(), "InitializeVarOnTPUPool", 1), + library_runtime_(nullptr) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + // If the importer has set the original function name, it means the function + // attribute is referring to a rewritten function, but we need to use the + // original function name in order to find it in the function library. + std::string orig_f; + if (ctx->GetAttr("_orig_f", &orig_f).ok()) { + func_.set_name(orig_f); + } + auto status = ctx->GetAttr("autotuner_thresh", &autotuner_thresh_); + if (!status.ok()) { + autotuner_thresh_ = 0; + } + stream_executor::tpu::OpsApiFn()->TfTpu_GetTpuPartitionedCallParamsFn( + &runtime_params_); + } + + ~TPUPartitionedCallOp() override = default; + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + private: + struct DeviceAndFHandle { + std::string device; + FHandle handle; + + // The FLD passed to `library_runtime_` as an overlay function library for + // instantiation of function `handle`. 
This is a snapshot of the currrent + // `flib_def_`. Since `flib_def_` can be changed concurrently by another + // graph rewrite when executing `handle`, we need to make sure each + // `handle` uses a different FLD to avoid races. See b/181149591. + std::unique_ptr flib_def; + }; + + struct TPUMetadata { + tpu::TopologyProto topology; + int num_cores_per_replica = 1; + std::vector device_assignment; + }; + + // This method is thread-safe. + absl::Status GetTpuCoreOrdinal(OpKernelContext* ctx, uint64 input_hash, + int64_t* ordinal_selector_req_id, + int32_t* core_ordinal); + + // Helper to create and initialize a TPU variable given a CPU variable + // var: the CPU variable created by the user + // ndef: the node def of the corresponding TPU var handle that we created + // device_ordinal: TPU device ordinal on which to initialize this variable + absl::Status InitializeVarOnTPU(OpKernelContext* ctx, + const core::RefCountPtr& var, + NodeDef* ndef, int device_ordinal, + bool fast_mem) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Helper to create and initialize partitioned TPU variables given a CPU + // variable with XLA sharding annotation. + // var: the CPU variable created by the user. + // ndefs: the node def of the corresponding TPU var handle on all the logical + // cores. + // split_dim: the partition dimension of the variable. If -1, the variable is + // replicated. + // device_ordinal: The index of the TPU core that is scheduled to run + // the computation. In the case of XLA SPMD, it is the "primary" core, which + // is the smallest index of all the cores. + absl::Status InitializeShardedVarOnTPU(OpKernelContext* ctx, + const core::RefCountPtr& var, + std::vector& ndefs, + int split_dim, + const std::vector& tpu_devices) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Check if any of the immediate successors of node has attribute + // "_tpu_replicate". + bool IsInputToTPUReplicate(Node* node) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Replace an _Arg node of type DT_RESOURCE by a VarHandleOp on TPU + absl::Status ReplaceResourceArgsWithVarHandleOps( + Graph* graph, OpKernelContext* ctx, int device_ordinal, + bool enable_spmd_xla_partitioning, const TPUMetadata& tpu_metadata) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Replace a _Arg node indicates a variable on CPU host by sharded/replicated + // variables on all logical TPU devices. + absl::Status ReplaceAndPartitionXLAShardingVariable( + Graph* graph, OpKernelContext* ctx, int device_ordinal, + ResourceHandle& handle, Node* variable, const TPUMetadata& tpu_metadata) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + absl::Status ShardInputsWithXlaSharding(Graph* graph, + const std::string& cluster_name, + int num_cores_per_replica, + OpKernelContext* ctx) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Rewrite the graph for input and output optimiazations. + // TODO(ylc): Move this function to Graph optimization pass. + absl::Status OptimizeTpuInputOutputTensors( + Graph* graph, bool enable_spmd_xla_partitioning, + int num_cores_per_replica, + std::map>& named_input_shapes, + OpKernelContext* ctx) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + absl::Status InferShapesWithResourceVar( + Graph* graph, OpKernelContext* ctx, + std::map& arg_shapes, + GraphShapeInfo* tpu_inferred_info); + + // Copies the graph backing `func_` into `graph`. 
+ absl::Status GetGraphFromFunction(Graph* graph, int device_ordinal, + bool* use_spmd_for_xla_partitioning, + TPUMetadata* tpu_metadata) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Places the graph carried by `optimization_options` and runs graph + // optimization passes (pre-placement, post-placement, and post-rewrite). + absl::Status PlacementHelper( + const DeviceSet& device_set, + const GraphOptimizationPassOptions& optimization_options, + const string& function_name); + // Partitions `graph`, populates `subgraphs` with the partitions, and runs + // the post-partitioning graph optimization passes. + absl::Status PartitionHelper( + const DeviceSet& device_set, + const GraphOptimizationPassOptions& optimization_options, Graph* graph, + std::unordered_map>* subgraphs); + + // Adds and instantiates a function backed by `graph` with name + // `function_name` on device `target_device`, storing the handle in `handle`. + // If `out_flib_def` is not null, it will be set to a copy of `flib_def_` and + // used for instantiation. + absl::Status InstantiatePartition( + const Graph& graph, const string& function_name, + const string& target_device, FHandle* handle, + std::unique_ptr* out_flib_def) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + // Adds and instantiates functions for each subgraph in `subgraphs` after + // rewriting nodes' `device_ordinal` attributes to match `replica_id` when + // num_cores_per_replica == 1. + absl::Status InstantiateFunctionsFromSubgraphs( + const DeviceSet& device_set, int replica_id, uint64 cache_hash, + int num_cores_per_replica, + std::unordered_map> subgraphs) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Rewrites `graph` such that the device ordinal attributes of all whitelisted + // nodes (see `IsSupportedTPUOp`) are set to `device_ordinal`; + // `*modified` is set to true if the graph is modified in the process (i.e., + // if it contains a whitelisted node), otherwise is unmodified. + // + // Returns an error if + // (1) the graph contains a non-whitelisted node that carries an attribute + // with name "device_ordinal", or + // (2) the set of device ordinals found among the graph's nodes has + // cardinality greater than 1. 
+ absl::Status SetDeviceOrdinal(const DeviceSet& device_set, int device_ordinal, + Graph* graph, bool* modified) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + void ExecuteRemoteFunction(const FunctionLibraryRuntime::Options& opts, + FHandle handle, OpKernelContext* ctx, + ReffedStatusCallback* done) + ABSL_LOCKS_EXCLUDED(mu_); + void ExecuteLocalFunction(const FunctionLibraryRuntime::Options& opts, + const OpInputList& arguments, FHandle handle, + OpKernelContext* ctx, ReffedStatusCallback* done) + ABSL_LOCKS_EXCLUDED(mu_); + void ExecuteFunctions(const std::vector& functions, + OpKernelContext* ctx, int device_ordinal, + int64_t ordinal_selector_req_id, DoneCallback done) + ABSL_LOCKS_EXCLUDED(mu_); + + absl::Status ShouldUseRemoteExecutionForFn(const std::string& target_device, + bool* remote_execution) { + DeviceNameUtils::ParsedName target_device_parsed; + DeviceNameUtils::ParsedName local_device_parsed; + + if (!DeviceNameUtils::ParseFullOrLocalName(target_device, + &target_device_parsed)) { + return errors::InvalidArgument("Cannot parse target device ", + target_device); + } + if (!DeviceNameUtils::ParseFullOrLocalName(local_device_name_, + &local_device_parsed)) { + return errors::InvalidArgument("Cannot parse local device ", + local_device_name_); + } + + if (DeviceNameUtils::AreCompatibleDevNames(target_device_parsed, + local_device_parsed)) { + *remote_execution = false; + } else { + *remote_execution = true; + } + return absl::OkStatus(); + } + + // Init once flagas. + absl::once_flag once_; + absl::once_flag ordinal_selector_once_; + + // Device manager and device set. + const DeviceMgr* device_mgr_; + DeviceSet device_set_; + + // Threadpool. + thread::ThreadPool pool_; + + // `func_` is the original function supplied to this OpKernel. + NameAttrList func_; + string local_device_name_; + // Maps from cache key to their corresponding functions, which are + // represented as (device, handle) pairs. + gtl::FlatMap> partition_cache_ + ABSL_GUARDED_BY(mu_); + + // A set contains seen ordinals. Used by variable initialization on TPU. + absl::flat_hash_set seen_ordinals_; + + // Record the indices of the _Arg with type DT_RESOURCE that goes + // into a TPU Op. + std::vector replaced_input_indices_; + + absl::Mutex mu_; + // Function shards are added to the `flib_def_`, and later on it'll create + // a copy of `flib_def_` to pass to `library_runtime_` as an overlay function + // library for instantiation. + std::unique_ptr flib_def_; + FunctionLibraryRuntime* library_runtime_; + + // Used to uniquify function names in `flib_def_`. + uint32 suffix_ = 0; + + // Minimum number of run steps (batches) necessary to trigger xla autotuner. + int autotuner_thresh_ = 0; + + // TPU core selection. + std::shared_ptr ordinal_selector_; + + // Maps input hash to TF fingerprint. + absl::flat_hash_map inputs_to_fingerprint_; + + // List of TPU devices + std::vector tpu_devices_; + + TpuPartitionedCall_Params runtime_params_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_FUNCTIONAL_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h new file mode 100644 index 00000000..6e84dde2 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_ + +#include + +#include "xla/stream_executor/tpu/tpu_api.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { + +class TpuMeshCommonState; + +namespace tpu { + +const char kTpuMeshStateInterfaceResourceName[] = "tpu_mesh_common_state"; + +class TpuMeshStateInterface : public tensorflow::ResourceBase { + public: + explicit TpuMeshStateInterface(XLA_TpuMeshState* handle) + : mesh_state_(handle) {} + + ~TpuMeshStateInterface() override { + if (mesh_state_ != nullptr) { + stream_executor::tpu::OpsApiFn()->TpuMeshState_FreeFn(mesh_state_); + } + } + + static TpuMeshStateInterface* Create() { + XLA_TpuMeshState* state = nullptr; + if (stream_executor::tpu::OpsApiFn()->TpuMeshState_CreateFn != nullptr) { + state = stream_executor::tpu::OpsApiFn()->TpuMeshState_CreateFn(); + } + return new TpuMeshStateInterface(state); + } + + const XLA_TpuMeshState* data() const { return mesh_state_; } + + tensorflow::TpuMeshCommonState* mesh_common_state() const { + if (mesh_state_ == nullptr) { + return nullptr; + } + return static_cast( + stream_executor::tpu::OpsApiFn()->TpuMeshState_MeshCommonStateFn( + mesh_state_)); + } + + // Returns whether we should include the device assignment as a static field + // to the TPU program. This also determines whether we should include the + // device assignment as part of the compilation cache key. + bool NeedsStaticDeviceAssignment(const TPUCompileMetadataProto& metadata, + TpuCoreTypeEnum tpu_core_type) const { + if (mesh_state_ == nullptr) { + return false; + } + // Static device assignment enables XLA to perform certain optimization when + // all cores are used in the replicated computation. + return metadata.num_cores_per_replica() * metadata.num_replicas() == + stream_executor::tpu::OpsApiFn()->TpuTopology_AvailableCoreCountFn( + mesh_state_, tpu_core_type); + } + + string DebugString() const override { return "TpuMeshStateInterface"; } + + private: + XLA_TpuMeshState* mesh_state_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_MESH_STATE_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_op_consts.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_op_consts.h new file mode 100644 index 00000000..cbf2c994 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_op_consts.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_OP_CONSTS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_OP_CONSTS_H_ + +#include "absl/base/attributes.h" + +namespace tensorflow { +namespace tpu { + +// Resource names in the ResourceMgr. +// +// Name of cache for compiled TPU ISA protos. CompilationCache is created by +// ConfigureDistributedTpuOp, so only the master has a CompilationCache. +ABSL_CONST_INIT extern const char kCompilationCacheResourceName[]; +// Name of base class allowing Execute Ops to look up ISA protos. +// CompiledProtoCache is created by InitializeHostForDistributedTpuOp, so each +// tpu_worker has a CompiledProtoCache. +ABSL_CONST_INIT extern const char kCompiledProtoCacheResourceName[]; +// Name of cache unloader for compiled TPU ISA protos. Cache unloader should be +// put into TPU_SYSTEM device resource manager. Inference may use it to unload +// cache entries created during lifetime of a DirectSession. +ABSL_CONST_INIT extern const char kCompilationCacheUnloaderResourceName[]; +// TBD +ABSL_CONST_INIT extern const char kFingerprintLookupResourceName[]; + +} // namespace tpu +} // namespace tensorflow +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_OP_CONSTS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_op_util.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_op_util.h new file mode 100644 index 00000000..d0ca805f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_op_util.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_OP_UTIL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_OP_UTIL_H_ + +#include + +#include "absl/strings/string_view.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" +#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" + +namespace tensorflow { +namespace tpu { +// Creates a fingerprint given the name and the vector of shapes. +uint64 CreateFingerprintWithNameAndShapes( + uint64 name, const std::vector& shapes); + +// Creates a unique compilation cache `key`. 
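As a quick illustration of the fingerprint helper declared above (CreateCompilationCacheKey, declared just below, takes many more inputs and is omitted here), the following hedged sketch fingerprints a function name together with its argument shapes. The shapes are invented, and obtaining the name fingerprint via TpuCompileInterface::Get()->FingerprintString() (declared in tpu_compile_interface.h later in this change) is an assumption about typical usage, not a documented contract.

    // Hypothetical: combine a name fingerprint with argument shapes.
    uint64_t name_fp =
        TpuCompileInterface::Get()->FingerprintString("my_tpu_function");
    std::vector<tensorflow::TensorShape> arg_shapes = {
        tensorflow::TensorShape({128, 64}), tensorflow::TensorShape({64})};
    uint64_t shape_aware_fp =
        tensorflow::tpu::CreateFingerprintWithNameAndShapes(name_fp, arg_shapes);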
+TpuCompilationCacheKey CreateCompilationCacheKey( + absl::string_view function_name, uint64 function_library_fingerprint, + uint64 mlir_module_fingerprint, const OpInputList& guaranteed_constants, + const std::vector& dynamic_shapes, + const TPUCompileMetadataProto& metadata, + const TpuMeshStateInterface& mesh_state, uint64_t session_id = 0, + ResourceMgr* resource_mgr = nullptr); +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_OP_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h new file mode 100644 index 00000000..9ea689b3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h @@ -0,0 +1,62 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ + +#include + +#include "xla/stream_executor/tpu/tpu_api.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h" + +namespace tensorflow { +namespace tpu { + +// A reserved ID for deferred core selection. Intentionally set at a number +// that is more than the number of cores available in a future system. +constexpr int32_t kDeferredCoreSelectionReserved = -8193; + +class TPUOrdinalSelector : TPUOrdinalSelectorInterface { + public: + explicit TPUOrdinalSelector(int num_cores_per_replica = 1) { + stream_executor::tpu::OpsApiFn()->TfTpuOrdinalSelector_CreateFn( + &ordinal_selector_, num_cores_per_replica); + } + ~TPUOrdinalSelector() override { + stream_executor::tpu::OpsApiFn()->TfTpuOrdinalSelector_DestroyFn( + ordinal_selector_); + } + int64_t GetOrdinal(std::optional key, int64_t* req_id) override { + int64_t ordinal; + stream_executor::tpu::OpsApiFn()->TfTpuOrdinalSelector_GetOrdinalFn( + ordinal_selector_, key, req_id, &ordinal); + return ordinal; + } + void DequeueFromCoreSelector(int32_t device_ordinal, + int64_t req_id) override { + stream_executor::tpu::OpsApiFn() + ->TfTpuOrdinalSelector_DequeueFromCoreSelectorFn( + ordinal_selector_, device_ordinal, req_id); + } + + private: + TfTpuOrdinalSelector* ordinal_selector_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h new file mode 100644 index 00000000..040959d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_INTERFACE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_INTERFACE_H_ + +#include + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace tpu { + +class TPUOrdinalSelectorInterface { + public: + virtual ~TPUOrdinalSelectorInterface() = default; + virtual int64_t GetOrdinal(std::optional key, int64_t* req_id) = 0; + virtual void DequeueFromCoreSelector(int32_t device_ordinal, + int64_t req_id) = 0; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_pod_state.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_pod_state.h new file mode 100644 index 00000000..b24a512d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_pod_state.h @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h" + +namespace tensorflow { + +// Name of tpu pod state. +ABSL_CONST_INIT extern const char kTpuPodStateResourceName[]; + +// Wrapper to hold centralized state for the distributed TPU in the TPU_SYSTEM +// device's resource manager. +class TpuPodState : public ResourceBase { + public: + // The port number given by isa_cache_port will be freed with + // RecycleUnusedPort in the destructor if it is non-negative. + TpuPodState(int service_port, + std::unique_ptr cache_service); + + ~TpuPodState() override; + + string DebugString() const override; + + private: + std::unique_ptr cache_service_; + int service_port_; +}; + +// Returns the TPU pod state or an error. +absl::Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state); + +// Checks whether the TPU POD state configuration is present within the resource +// manager. +bool HasTPUPodState(const ResourceMgr* rmgr); + +// Construct TpuPodState. 
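A short sketch of how the two query helpers declared here might be used alongside the ConstructTpuPodState declaration that follows. The helper name is invented, and the reference-counting comment reflects the usual ResourceMgr convention; it is an assumption, not something this header guarantees.

    // Hypothetical helper: log the pod state if it has already been configured.
    absl::Status MaybeLogPodState(const tensorflow::ResourceMgr* rmgr) {
      if (!tensorflow::HasTPUPodState(rmgr)) {
        return absl::OkStatus();  // Distributed TPU has not been configured yet.
      }
      tensorflow::TpuPodState* pod_state = nullptr;
      TF_RETURN_IF_ERROR(tensorflow::GetTPUPodState(rmgr, &pod_state));
      LOG(INFO) << pod_state->DebugString();
      pod_state->Unref();  // Assumed: the lookup returns a referenced resource.
      return absl::OkStatus();
    }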
+absl::Status ConstructTpuPodState( + ResourceMgr* rmgr, const std::vector& num_devices_per_host, + tpu::TpuCompilationCacheInterface* compilation_cache, + std::string* host_config_proto); + +absl::Status GetServerAddressAndPort(std::string* server_address, + int* serving_port); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_POD_STATE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_program_group.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_program_group.h new file mode 100644 index 00000000..1b82d17b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_program_group.h @@ -0,0 +1,189 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/compile_only_client.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo.pb.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "xla/stream_executor/tpu/tpu_platform_interface.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h" +#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" + +namespace tensorflow { +namespace tpu { + +class TpuAotCompilationOptions : public xla::AotCompilationOptions { + public: + explicit TpuAotCompilationOptions(int64_t replica_count) + : num_cores_(0), replica_count_(replica_count) {} + + // Returns the ID of the platform to which these options apply. + se::Platform::Id PlatformId() const override { + LOG(FATAL) << "Not implemented."; + return nullptr; + }; + + void set_num_cores(int64_t tpu_cores) { num_cores_ = tpu_cores; } + int64_t replica_count() const override { return replica_count_; } + int64_t num_cores() const override { return num_cores_; } + + void set_allow_separate_sharding_programs(bool allow) { + allow_separate_sharding_programs_ = allow; + } + bool allow_separate_sharding_programs() const { + return allow_separate_sharding_programs_; + } + + std::vector + shardable_value_update_pairs() const { + return shardable_value_update_pairs_; + } + void set_shardable_value_update_pairs( + std::vector pairs) { + shardable_value_update_pairs_ = std::move(pairs); + } + + private: + int64_t num_cores_; + int64_t replica_count_; + + // Whether to allow the compiler to create separte sharding and unsharding + // programs, and modify the original program's input/output sharded size. 
This + // is used for XLA-chosen sharding on parameters without an on-device loop: + // the caller can invoke sharding first, then (repeatedly) invoke the sharded + // main program, and finally invoke the unsharding program when it needs the + // full output. + bool allow_separate_sharding_programs_ = false; + + // The list of input/output pairs in the main program that could be sharded. + std::vector + shardable_value_update_pairs_; +}; + +class TpuProgramGroup : public TpuProgramGroupInterface { + public: + using Status = absl::Status; + + // Compiles Mlir or TF function computation by lowering into HLO IR and + // returns TPU programs ready for execution. + static Status CompileAndBuild( + const TpuCompilationRequestProto& compilation_request, + const XLA_TpuMeshState* mesh_state, + TpuProgramGroupInterface* tpu_program_group_interface); + + + // Initializes `TpuProgramGroup` object with `xla_tpu_programs`. + void Initialize(absl::Span xla_tpu_programs); + + TpuProgramGroup() = default; + TpuProgramGroup(TpuProgramGroup&& other); + TpuProgramGroup& operator=(TpuProgramGroup&&) = delete; + + bool has_sharding_program() const override; + + size_t program_count() const override; + + int64_t program_size() const override; + + bool LogProgramMemorySummary() override; + + void UnloadAndDestroyPrograms() override; + + const std::vector& may_modify_variables_list() const override; + void set_may_modify_variables(const std::vector& may_modify_variables); + bool may_modify_variables(int index) const override; + + const std::vector& fingerprints() const; + void set_fingerprints(); + + const std::string& fingerprint(int index) const override; + + const std::vector& tpu_programs() const; + std::vector tpu_programs(TpuProgramShardingType type) const; + const XLA_TpuProgram* tpu_program(int index) const override; + void set_tpu_programs(absl::Span tpu_programs); + + const TPUExecutableInfoProto& executable_info(int index) const override; + + const TPUHostTransferInfoProto& host_transfer_info(int index) const override; + void set_hlo_metadatas(absl::Span hlo_metadatas); + const xla::HloProto* hlo_metadata(int index) const; + absl::Span hlo_metadatas() const override; + + // Deserializes `GetTpuProgramResponse` protos from remote cache. + Status DeserializeFromRpcResponseProtos( + const std::vector& rpc_response_protos); + + // Serializes executable proto from the TPU program for the given core + // `index`. + Status SerializeExecutable(int index, + TpuExecutableSerializedProto* executable) const; + + // Serializes compiler metadata of the TPU program for the given core `index`. + Status SerializeCompilerMetadata( + int index, CompilerMetadataSerializedProto* compiler_metadata) const; + + // Serializes host compute metadata of the TPU program for the given core + // `index`. + Status SerializeHostComputeMetadata( + int index, + HostComputeMetadataSerializedProto* host_compute_metadata) const; + + private: + TPUExecutableInfoProto ConstructExecutableInfo( + const XLA_TpuProgram* tpu_program); + TPUHostTransferInfoProto ConstructHostTransferInfo( + const XLA_TpuProgram* tpu_program); + xla::HloProto ConstructHloMetadata(const XLA_TpuProgram* tpu_program); + + // Update `hlo_metadatas__ptrs_` array from `hlo_metadatas_`. This needs to be + // called on `hlo_metadatas_` change(s). + void RefreshHloMetadatasPtrs(); + + std::vector may_modify_variables_; + std::vector tpu_program_fingerprints_; + + std::vector tpu_programs_; // Not owned. 
+ std::vector executable_infos_; + std::vector host_transfer_infos_; + + // To be consistent with the TpuProgramGroupInterface::hlo_metadatas() + // signature, we store HloProto values in hlo_metadatas_ when + // set_hlo_metadata(...) is called, and return their pointers from + // hlo_metadatas_ptrs_ when hlo_metadatas() is called. hlo_metadata_ptrs_ is + // refreshed whenever hlo_metadatas_ is set or the object is moved. + std::vector hlo_metadatas_; // Owned. + std::vector hlo_metadatas_ptrs_; + + TpuProgramGroup(const TpuProgramGroup&) = delete; + void operator=(const TpuProgramGroup&) = delete; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_program_group_interface.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_program_group_interface.h new file mode 100644 index 00000000..02e91f5b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_program_group_interface.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_INTERFACE_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_INTERFACE_H_ + +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/service/hlo.pb.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" + +namespace tensorflow { +namespace tpu { + +// An interface to holds all the programs and metadatas generated by the +// compiler, including those for the sharding/unsharding programs. +class TpuProgramGroupInterface { + public: + virtual ~TpuProgramGroupInterface() = default; + + // Check if whether sharding/unsharding program exists. + virtual bool has_sharding_program() const = 0; + + // Computes program count. + virtual size_t program_count() const = 0; + + // Computes total program size. + virtual int64_t program_size() const = 0; + + // Unloads and destroys safely TPU programs. + virtual void UnloadAndDestroyPrograms() = 0; + + // Logs program memory summary. + virtual bool LogProgramMemorySummary() = 0; + + // Program fingerprints. + virtual const std::string& fingerprint(int index) const = 0; + + // Hlo metadatas. The pointers can only be used as long as the cache entry is + // referenced. + virtual absl::Span hlo_metadatas() const = 0; + + // Boolean array to indicate if the modification of variables are + // allowed. + virtual const std::vector& may_modify_variables_list() const = 0; + + // Gets may modify variables value of the TPU program for the given core + // `index`. 
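To make the shape of this interface concrete, here is a hedged consumer sketch that only uses the accessors declared above and immediately below; it assumes program_count() and the per-core index space line up, which is how the comments describe the API. The function name is invented for illustration.

    // Hypothetical: log a one-line summary per compiled TPU core program.
    void LogProgramGroupSummary(
        const tensorflow::tpu::TpuProgramGroupInterface& group) {
      for (size_t i = 0; i < group.program_count(); ++i) {
        const int index = static_cast<int>(i);
        LOG(INFO) << "core " << index
                  << " fingerprint=" << group.fingerprint(index)
                  << " may_modify_variables=" << group.may_modify_variables(index)
                  << " has_program=" << (group.tpu_program(index) != nullptr);
      }
    }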
+ virtual bool may_modify_variables(int index) const = 0; + + // Get Executable Info Proto + virtual const TPUExecutableInfoProto& executable_info(int index) const = 0; + + // Get HostTransferInfo Proto + virtual const TPUHostTransferInfoProto& host_transfer_info( + int index) const = 0; + + // Get XLA_TpuProgram Proto + virtual const XLA_TpuProgram* tpu_program(int index) const = 0; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_GROUP_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h new file mode 100644 index 00000000..1d6c7bab --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_ + +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" + +namespace tensorflow { + +// Op that changes the sharding state for a set of on-device variables. The +// sharding state is represented as the key of the compilation that generated +// the sharding/unsharding programs along with the main program. The op checks +// if the current sharding state matches the desired one, and if not, uses the +// sharding/unsharding programs to transform the variables to the desired state. +class TPUReshardVariablesOpKernel : public AsyncOpKernel { + public: + explicit TPUReshardVariablesOpKernel(OpKernelConstruction* context); + ~TPUReshardVariablesOpKernel() override; + + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + private: + absl::Status DoWork(OpKernelContext* context); + absl::Status DoTpuExecute(OpKernelContext* context, const Tensor& format_key, + tpu::CompilationCacheFetchTarget fetch_target); + + int64_t num_vars_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h new file mode 100644 index 00000000..c731cc10 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_ + +#include +#include + +#include "tensorflow/compiler/jit/variable_info.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" + +namespace tensorflow { +namespace tpu { +namespace reshard_variables { + +absl::Status FlushProgramMemory(se::Platform* platform, int device_ordinal); + +absl::Status CheckIsValidKey(const Tensor& key); + +bool IsDefaultKey(const Tensor& key); + +absl::Status GetComputationCacheEntry( + const Tensor& key, string* rendezvous_key_base, + std::unique_ptr* entry, + tpu::CompilationCacheFetchTarget fetch_target); + +absl::StatusOr> BuildInputBuffers( + OpKernelContext* context, const std::vector& variables, + const xla::Shape& input_host_shape, xla::Backend* backend, + int device_ordinal, se::Stream* stream); + +absl::Status PerformCompaction(stream_executor::Stream* stream); + +absl::Status UpdateOutputVariables( + OpKernelContext* context, xla::ScopedShapedBuffer result_buffers, + absl::Span output_tensor_shape_protos, + xla::Backend* backend, se::Stream* stream, int device_ordinal, + const std::vector& variables, + const std::shared_ptr& definition_event); + +} // namespace reshard_variables +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_RESHARD_VARIABLES_OP_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_util.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_util.h new file mode 100644 index 00000000..0b0aedd4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/tpu_util.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_ + +#include +#include +#include + +#include "grpcpp/server_builder.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/compile_only_client.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" + +namespace tensorflow { +namespace tpu { + +// Utility to get session_name from `SessionMetadata`. `SessionMetadata` may +// be null. +std::string SessionNameFromMetadata(const SessionMetadata* session_metadata); + +// Generates cache proto key for a given computation on a TPU core. +std::string ProtoKeyForComputation(const std::string& key, int core); + +// Returns a TpuCompilationCacheKey parsed from given key or an error. +absl::StatusOr ParseCompilationCacheKey( + const std::string& key); + +xla::CompileOnlyClient::AotXlaComputationInstance +BuildAotXlaComputationInstance( + const XlaCompiler::CompilationResult& compilation_result); + +// Returns true if TPU compilation is enabled. +bool IsTpuCompilationEnabled(); + +// Converts an int64 host memory `tensor` to a `shape`. +absl::Status ShapeTensorToTensorShape(const Tensor& tensor, TensorShape* shape); + +absl::Status DynamicShapesToTensorShapes(const OpInputList& dynamic_shapes, + std::vector* shapes); +absl::Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes, + std::vector* shapes); + +// Creates gRPC ServerBuilder. +absl::StatusOr> CreateServerBuilder( + int serving_port); +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/trace_util.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/trace_util.h new file mode 100644 index 00000000..e1c96233 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/trace_util.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_ + +#ifdef PLATFORM_GOOGLE +#include "base/tracer.h" // IWYU pragma: export +#else +#undef TRACESTRING +#define TRACESTRING(x) +#undef TRACELITERAL +#define TRACELITERAL(x) +#endif + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TRACE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/transfer_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/transfer_ops.h new file mode 100644 index 00000000..3c12d22f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/kernels/transfer_ops.h @@ -0,0 +1,140 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_ + +#include +#include +#include + +#include "xla/literal.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/tpu/noncopyable_buffer.h" +#include "xla/stream_executor/tpu/tpu_platform_interface.h" +#include "xla/stream_executor/tpu/tpu_transfer_manager_interface.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/threadpool.h" + +namespace tensorflow { + +class TpuTransferOpInterface { + public: + virtual ~TpuTransferOpInterface() = default; + virtual void Cancel() = 0; + virtual absl::StatusOr GetDeviceOrdinal(OpKernelContext* ctx) = 0; + + virtual absl::Status TransferBuffersToInfeed( + int device_ordinal, + const std::deque& buffers) = 0; + virtual absl::Status TransferLiteralToInfeed( + int device_ordinal, const xla::LiteralSlice& literal) = 0; + virtual absl::Status TransferLiteralFromOutfeed( + int device_ordinal, xla::MutableBorrowingLiteral literal) = 0; +}; + +// Base class providing common functionality for async ops that transfer from +// host to TPU. +class TpuTransferAsyncOpKernelBase : public AsyncOpKernel { + public: + explicit TpuTransferAsyncOpKernelBase( + OpKernelConstruction* ctx, const std::string& transfer_type, + int number_of_threads, + std::unique_ptr transfer_op); + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; + + protected: + virtual absl::Status DoWork(OpKernelContext* context, int device_ordinal) = 0; + + absl::Status RunTransferWithOrdinal(OpKernelContext* ctx, int device_ordinal); + std::string transfer_type_; + std::unique_ptr transfer_op_; + + private: + virtual absl::Status RunTransfer(OpKernelContext* ctx) = 0; + + std::unique_ptr thread_pool_; + mutex mu_; + + // TpuTransferAsyncOpKernelBase is neither copyable nor movable. + TpuTransferAsyncOpKernelBase(const TpuTransferAsyncOpKernelBase&) = delete; + TpuTransferAsyncOpKernelBase& operator=(const TpuTransferAsyncOpKernelBase&) = + delete; +}; + +class TpuTransferAsyncOpKernel : public TpuTransferAsyncOpKernelBase { + public: + explicit TpuTransferAsyncOpKernel( + OpKernelConstruction* ctx, const std::string& transfer_type, + int number_of_threads, + std::unique_ptr transfer_op); + + private: + absl::Status RunTransfer(OpKernelContext* ctx) override; + int device_ordinal_; + + // TpuTransferAsyncOpKernel is neither copyable nor movable. 
+ TpuTransferAsyncOpKernel(const TpuTransferAsyncOpKernel&) = delete; + TpuTransferAsyncOpKernel& operator=(const TpuTransferAsyncOpKernel&) = delete; +}; + +class TpuTransferAsyncDynamicOrdinalOpKernel + : public TpuTransferAsyncOpKernelBase { + public: + explicit TpuTransferAsyncDynamicOrdinalOpKernel( + OpKernelConstruction* ctx, const std::string& transfer_type, + int number_of_threads, + std::unique_ptr transfer_op); + + private: + absl::Status RunTransfer(OpKernelContext* ctx) override; + + // TpuTransferAsyncDynamicOpKernel is neither copyable nor movable. + TpuTransferAsyncDynamicOrdinalOpKernel( + const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete; + TpuTransferAsyncDynamicOrdinalOpKernel& operator=( + const TpuTransferAsyncDynamicOrdinalOpKernel&) = delete; +}; + +class StreamExecutorTransferOpImpl : public TpuTransferOpInterface { + public: + explicit StreamExecutorTransferOpImpl(); + ~StreamExecutorTransferOpImpl() override = default; + void Cancel() override; + absl::StatusOr GetDeviceOrdinal(OpKernelContext* ctx) override; + + absl::Status TransferBuffersToInfeed( + int device_ordinal, + const std::deque& buffers) override; + absl::Status TransferLiteralToInfeed( + int device_ordinal, const xla::LiteralSlice& literal) override; + + absl::Status TransferLiteralFromOutfeed( + int device_ordinal, xla::MutableBorrowingLiteral literal) override; + + private: + absl::StatusOr GetStreamExecutor( + int device_ordinal); + xla::TpuTransferManagerInterface* transfer_manager_; + tpu::TpuPlatformInterface* tpu_platform_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TRANSFER_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/ops/tpu_embedding_ops.h b/third_party/tflite-hdrs/tensorflow/core/tpu/ops/tpu_embedding_ops.h new file mode 100644 index 00000000..324f2b4e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/ops/tpu_embedding_ops.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_OPS_TPU_EMBEDDING_OPS_H_ +#define TENSORFLOW_CORE_TPU_OPS_TPU_EMBEDDING_OPS_H_ + +#include +#include + +namespace tensorflow { +// Get the names of the LoadTPUEmbedding*Parameters ops. +std::vector GetPerTableLoadOptimizationParametersOps(); + +// Get the names of the RetrieveTPUEmbedding*Parameters ops. +std::vector GetPerTableRetrieveOptimizationParametersOps(); + +// Type enum of elements in deduplication data tuple. 
+enum DedupTupleElementType { + kInteger = 0, + kFloat = 1, +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_OPS_TPU_EMBEDDING_OPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/ops/tpu_embedding_shape_util.h b/third_party/tflite-hdrs/tensorflow/core/tpu/ops/tpu_embedding_shape_util.h new file mode 100644 index 00000000..1d1e9138 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/ops/tpu_embedding_shape_util.h @@ -0,0 +1,70 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_OPS_TPU_EMBEDDING_SHAPE_UTIL_H_ +#define TENSORFLOW_CORE_TPU_OPS_TPU_EMBEDDING_SHAPE_UTIL_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" + +namespace tensorflow { +namespace tpu { + +// Utility class for inferring TpuEmbedding shape information. +class TpuEmbeddingShapeUtil { + public: + // Compute the shape of one embedding table stored on the + // TpuEmbeddingEngine. The table descriptor from the TpuEmbedding + // configuration is supplied in config. On success, shape is populated with + // the shape of the embedding table that will be loaded or retrieved using + // Ops such as {Load,Retrieve}TpuEmbedding*Parameters. + static absl::Status ComputeOneTableShape(int64_t vocabulary_size, + int table_dimension, int shard_id, + int num_shards, + TensorShapeProto* shape); + + // Compute the shapes of the embedding tables stored on the + // TpuEmbeddingEngine. The TpuEmbedding configuration is supplied in + // config. On success, shapes is populated with the shape of each embedding + // table that will be loaded or retrieved using Ops such as + // {Load,Retrieve}AllTpuEmbeddingParameters. + static absl::Status ComputeTableShapes( + absl::Span vocabulary_sizes, + absl::Span table_dimensions, int shard_id, int num_shards, + std::vector* shapes); + + static absl::Status ComputeTableShapes( + const tensorflow::tpu::TPUEmbeddingConfiguration& config, int shard_id, + int num_shards, std::vector* shapes); + + static TensorShapeProto MakeEmpty2DShape(); + + private: + // Compute the number of embedding IDs per embedding table shard. + // There are as many shards as the number of hosts in the job. 
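The shard sizing described in the comment above is essentially a block partition of the vocabulary across hosts. The snippet below is illustrative only: it shows one plausible convention (ceiling division, with the final shard taking the remainder) and makes no claim about what the ComputeNumEmbeddingIdsPerShard helper declared just below actually returns; the function name is invented.

    #include <algorithm>
    #include <cstdint>

    // Illustrative block partition of vocabulary_size ids over num_shards.
    int64_t ApproxIdsPerShard(int64_t vocabulary_size, int shard_id,
                              int num_shards) {
      const int64_t per_shard = (vocabulary_size + num_shards - 1) / num_shards;
      const int64_t begin = static_cast<int64_t>(shard_id) * per_shard;
      const int64_t end = std::min(vocabulary_size, begin + per_shard);
      return std::max<int64_t>(0, end - begin);
    }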
+ static absl::StatusOr ComputeNumEmbeddingIdsPerShard( + int64_t vocabulary_size, int shard_id, int num_shards); +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_OPS_TPU_EMBEDDING_SHAPE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_compile.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_compile.h new file mode 100644 index 00000000..f606c7e5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_compile.h @@ -0,0 +1,72 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_COMPILE_H_ +#define TENSORFLOW_CORE_TPU_TPU_COMPILE_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/compile_only_client.h" +#include "xla/shape.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" + +namespace tensorflow { +namespace tpu { +namespace internal { + +// Performs shape inference on the body of `graph`. Shapes for arguments +// are taken from `metadata` and `arg_shapes`. +absl::Status RunShapeInferenceOnComputation( + const tpu::TPUCompileMetadataProto& metadata, + const std::vector& arg_shapes, Graph* graph, + FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info); +} // namespace internal + +// Converts a TF Function into XLA HLO, stores generated HLO module and +// accompanying metadata in CompilationResult. +absl::Status CompileTFFunctionToHlo( + const FunctionLibraryDefinition& flib_def, int graph_def_version, + const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + const std::vector& arg_shapes, const DeviceType& device_type, + const GuaranteedConsts& guaranteed_constants, const NameAttrList& function, + const tpu::TPUCompileMetadataProto& metadata, + xla::CompileOnlyClient* client, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes, + bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result); + +// Gets information regarding how input arguments are sharded across multiple +// cores. 
+absl::Status GetShardingInfo( + const tpu::TPUCompileMetadataProto& metadata, + absl::Span arg_shapes, + const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, + std::vector* arg_core_mapping, + std::vector>* per_core_arg_shapes); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_COMPILE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_compile_interface.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_compile_interface.h new file mode 100644 index 00000000..a97e721b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_compile_interface.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_ +#define TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_ + +#include + +#include "absl/strings/string_view.h" + +// Some legacy code requires different implementations for operations like +// fingerprint/hashing during compilation and/or graph rewriting. These +// alternate implementations can be registered (via a module initializer) to +// change the default behavior. +class TpuCompileInterface { + public: + virtual ~TpuCompileInterface() {} + static TpuCompileInterface* Get(); + static bool RegisterImplementation(TpuCompileInterface* impl); + + virtual uint64_t FingerprintString(absl::string_view str) = 0; + + // Proto: tensorflow::tpu::CompilationResultProto + // Location: tensorflow/core/protobuf/tpu/compilation_result.proto + static inline constexpr char kTpuCompileErrorPayloadKey[] = + "type.googleapis.com/tensorflow.tpu.CompilationResultProto"; + + // Unique string added to the error message for permanent errors during + // XLA:TPU compilation. This can be used by TensorFlow models to distinguish + // compilation errors from transient errors created by TPU worker preemptions + // and restarts. + static inline constexpr char kTpuCompileErrorMessage[] = + "XLA:TPU compile permanent error"; +}; + +#endif // TENSORFLOW_CORE_TPU_TPU_COMPILE_INTERFACE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_configuration.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_configuration.h new file mode 100644 index 00000000..0fbdb0f3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_configuration.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
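TpuCompileInterface above is an explicit seam: its comment says an alternate fingerprint/hash implementation can be swapped in through a module initializer. The sketch below shows what such a registration could look like; the std::hash based fingerprint is a stand-in chosen for the example, not the hash used by the default implementation.

```cpp
#include <cstdint>
#include <functional>
#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"

namespace {

// Alternate implementation that fingerprints with std::hash. Purely
// illustrative; any deterministic 64-bit hash could be plugged in here.
class StdHashCompileInterface : public TpuCompileInterface {
 public:
  uint64_t FingerprintString(absl::string_view str) override {
    return std::hash<std::string>{}(std::string(str));
  }
};

// Module initializer: installs the alternate implementation at load time.
[[maybe_unused]] const bool kRegistered =
    TpuCompileInterface::RegisterImplementation(new StdHashCompileInterface());

}  // namespace
```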
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIGURATION_H_ +#define TENSORFLOW_CORE_TPU_TPU_CONFIGURATION_H_ + +#include "tensorflow/core/framework/resource_mgr.h" + +namespace tensorflow { + +void MaybeInitializeTPUSystemForTests(); + +// Returns a process-wide global ResourceMgr. +ResourceMgr* GetTPUConfigResourceMgr(bool initialize_first = true); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_CONFIGURATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_defs.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_defs.h new file mode 100644 index 00000000..b5c36680 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_defs.h @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Common definitions related to TPUs. + +#ifndef TENSORFLOW_CORE_TPU_TPU_DEFS_H_ +#define TENSORFLOW_CORE_TPU_TPU_DEFS_H_ + +#include + +#include "tensorflow/core/framework/types.pb.h" + +namespace tensorflow { + +// Name of the TPU device, which corresponds to a single core. +extern const char* const DEVICE_TPU_NODE; // "TPU"; + +// The TPU_REPLICATED_CORE device is a virtual device corresponding to one core +// of a replicated TPU computation. Only valid within the body of a +// TPUReplicate computation. +extern const char* const DEVICE_TPU_REPLICATED_CORE; + +// DEVICE_TPU_SYSTEM is now defined in tensorflow/core/framework/types.h/.cc + +// Name of the XLA_TPU_JIT compilation device, which is an internal device to +// compile graphs for TPU. Not registered as a device; no operators can be +// assigned to this device by a user. +extern const char* const DEVICE_TPU_XLA_JIT; // "XLA_TPU_JIT"; + +// Attribute used internally to pass "is_mirrored_variable" attribute on +// TPUReplicatedInput nodes to _TPUReplicate. +extern const char* const TPUREPLICATE_MIRRORED_VAR_INDICES_ATTR; + +// Attribute used internally to annotate ops which might consume TPU FastMem +// variable. +extern const char* const TPU_FAST_MEM_ATTR; // "_TPU_FAST_MEM" + +extern const char* const kTPUReplicateAttr; +extern const char* const kOutsideCompilationAttr; + +// Supported types for TPUs. 
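GetTPUConfigResourceMgr above hands back a process-wide ResourceMgr, which is how TPU configuration state gets shared across sessions in one process. Below is a hedged sketch of stashing and re-fetching a resource in it; it assumes the standard ResourceMgr::LookupOrCreate API from resource_mgr.h, and the CounterResource type plus the container/name strings are invented for the example.

```cpp
#include <cstdint>
#include <string>

#include "absl/status/status.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/tpu_configuration.h"

namespace tensorflow {

// A trivial refcounted resource used only for illustration.
class CounterResource : public ResourceBase {
 public:
  std::string DebugString() const override { return "CounterResource"; }
  int64_t value = 0;
};

absl::Status BumpSharedCounter() {
  ResourceMgr* rmgr = GetTPUConfigResourceMgr();

  CounterResource* counter = nullptr;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<CounterResource>(
      "example_container", "shared_counter", &counter,
      [](CounterResource** out) {
        *out = new CounterResource();
        return absl::OkStatus();
      }));

  ++counter->value;
  counter->Unref();  // LookupOrCreate hands back an owned reference.
  return absl::OkStatus();
}

}  // namespace tensorflow
```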
+inline constexpr std::array kTpuAllTypes = { + {DT_INT32, DT_UINT32, DT_FLOAT8_E4M3FN, DT_FLOAT8_E5M2, DT_HALF, + DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_BOOL, DT_COMPLEX64, + DT_INT64, DT_UINT64, DT_QINT8, DT_QUINT8, DT_QINT32, + DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT4, + DT_UINT4}}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_DEFS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h new file mode 100644 index 00000000..063f6668 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h @@ -0,0 +1,44 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_PROTO_REWRITE_H_ +#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_PROTO_REWRITE_H_ + +#include "absl/status/status.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" + +namespace tensorflow { + +// Validates the TPU embedding configuration has been populated correctly and +// fills in missing fields. The user model is expected to fill in exactly one of +// the following: +// +// (1) batch_size_per_tensor_core and TableDescriptor.num_features, or +// (2) feature_descriptor. +// +// (1) If the user model fills in batch_size_per_tensor_core and +// TableDescriptor.num_features, this function validates that the +// feature_descriptor has not been filled in, and then populates +// feature_descriptor with appropriate values. +// +// (2) If the user model fills in feature_descriptor, this function validates +// that batch_size_per_tensor_core and TableDescriptor.num_features have not +// been filled in, and then populated them with appropriate values. +absl::Status PopulateMissingFieldsInTPUEmbeddingConfig( + tpu::TPUEmbeddingConfiguration* config); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_PROTO_REWRITE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_configuration_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_configuration_utils.h new file mode 100644 index 00000000..3ac55d17 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_configuration_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
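PopulateMissingFieldsInTPUEmbeddingConfig, declared in the header above, fills in whichever half of the batch/feature description the user model left out. The sketch below walks case (1): batch_size_per_tensor_core and TableDescriptor.num_features are set and feature_descriptor is derived. The name/vocabulary_size/dimension setters are assumed fields of the TPUEmbeddingConfiguration proto and are not spelled out in the header text above.

```cpp
#include <iostream>

#include "absl/status/status.h"
#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
#include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"

int main() {
  tensorflow::tpu::TPUEmbeddingConfiguration config;

  // Case (1): describe the batch size and per-table feature counts only.
  config.set_batch_size_per_tensor_core(128);
  auto* table = config.add_table_descriptor();
  table->set_name("watched_videos");    // assumed field
  table->set_vocabulary_size(1 << 20);  // assumed field
  table->set_dimension(64);             // assumed field
  table->set_num_features(3);

  // The rewrite validates the config and derives feature_descriptor entries.
  const absl::Status status =
      tensorflow::PopulateMissingFieldsInTPUEmbeddingConfig(&config);
  if (!status.ok()) {
    std::cerr << status << "\n";
    return 1;
  }
  std::cout << "feature descriptors: " << config.feature_descriptor_size()
            << "\n";
}
```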
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_UTILS_H_ +#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_UTILS_H_ + +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" + +namespace tensorflow { +namespace tpu { + +// Returns the total number of unique dynamic input tags used in optimizers. If +// the tag specific is erroneous, returns an invalid argument error. For correct +// tag specification, see the comment next to the OptimizerDynamicInput proto in +// //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto. +absl::StatusOr ComputeTotalTagCountForOptimizerDynamicInputs( + const tensorflow::tpu::TPUEmbeddingConfiguration& tpu_embedding_config); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_errors.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_errors.h new file mode 100644 index 00000000..42a91124 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_errors.h @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_ERRORS_H_ +#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_ERRORS_H_ + +#include + +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" + +namespace tensorflow::tpu { + +// The payload URL for TPU embedding initialization permanent errors. +constexpr absl::string_view kTpuEmbeddingErrorUrl = + "type.googleapis.com/tensorflow.tpu.TPUEmbeddingError"; + +constexpr absl::string_view kTpuEmbeddingErrorMessage = + "TPUEmbedding permanent error"; + +// Appends a payload of type tensorflow::tpu::kTpuEmbeddingErrorUrl to the +// tensorflow::Status obj if the status is NOT OK. Returns the +// tensorflow::Status obj unchanged if the status is OK. +absl::Status AppendTpuEmbeddingErrorPayload(absl::Status obj); + +// Appends a payload of type tensorflow::tpu::kTpuEmbeddingErrorUrl to the +// tensorflow::Status obj if the status is NOT OK. Returns obj.value() if the +// status is OK. +template +StatusOr AppendTpuEmbeddingErrorPayload(StatusOr obj) { + if (obj.ok()) { + return std::move(obj.value()); + } else { + const std::string error_message = + absl::StrCat(kTpuEmbeddingErrorMessage, ". 
", obj.status().message()); + absl::Status status(obj.status().code(), error_message); + TPUEmbeddingError error_payload; + status.SetPayload(kTpuEmbeddingErrorUrl, + absl::Cord(error_payload.SerializeAsString())); + return status; + } +} + +// Returns true if the tensorflow::Status obj has a payload of type +// tensorflow::tpu::kTpuEmbeddingErrorUrl. +bool HasTpuEmbeddingErrorPayload(const absl::Status& status); + +// Returns true if the tensorflow::Status obj error message contains +// tensorflow::tpu::kTpuEmbeddingErrorMessage as a substring. +bool HasTpuEmbeddingErrorMessage(const absl::Status& status); + +} // namespace tensorflow::tpu + +#endif // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_ERRORS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h new file mode 100644 index 00000000..43643fbd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h @@ -0,0 +1,136 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_ +#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_ + +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h" + +namespace tensorflow { +namespace tpu { + +using OptimizationAlgorithm = OptimizationParameters::ParametersCase; + +// Returns the name of the optimization algorithm. +std::string GetOptimizationAlgorithmName(OptimizationAlgorithm alg); + +// Returns a user-friendly name for the optimization algorithm. +std::string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg); + +// Returns all supported optimization algorithms. +std::vector GetOptimizationAlgorithms(); + +enum class GradientAccumulationSupport { + // Accumulation cannot be used with this optimizer. + kNotSupported, + + // Accumulation is allowed and changes optimizer behavior. + kSupported, +}; + +// Returns the number of optimization parameter vectors used by the optimization +// algorithm, excluding the weights themselves and assuming no gradient +// accumulation. +absl::Status GetBaseAuxiliaryParameterCount( + const OptimizationParameters ¶ms, int *count); + +// Returns whether (and how) an optimization algorithm supports gradient +// accumulation. +absl::Status GetGradientAccumulationSupport( + const OptimizationParameters ¶ms, GradientAccumulationSupport *support); + +// Returns whether both the given set of optimization parameters has gradient +// accumulation turned on and that the algorithm used supports it or should +// ignore that setting. 
Returns an error if gradient accumulation is enabled and +// the algorithm does not support it. +absl::Status UseGradientAccumulation(const OptimizationParameters ¶ms, + bool *use_gradient_accumulation); + +// Returns the parameter specifications for the optimization algorithm (the main +// parameters first, followed by any auxiliary parameters such as Adagrad +// accumulators). +absl::Status GetOptimizationAlgorithmStateVariables( + const OptimizationParameters ¶ms, + std::vector *state_variables); + +// Returns the set of dynamic input tags used by the optimization algorithm. +// This includes both dynamic learning rates and other hyperparameters (e.g., +// step counters for the frequency aware Adagrad optimizer). +absl::flat_hash_set GetOptimizerDynamicInputTags( + const OptimizationParameters ¶ms); + +// Returns the set of dynamic hyperparameter tags used by the optimization +// algorithm. This includes other hyperparameters used by the optimization +// algorithm (e.g., step counters for the frequency aware Adagrad optimizer). It +// excludes the dynamic learning rate tag. +absl::flat_hash_set GetOptimizerHyperParameterTags( + const OptimizationParameters ¶ms); + +// Returns true if the optimization algorithm uses dynamic inputs in its +// computation. +bool UsesDynamicInputsInOptimizer(const OptimizationParameters ¶ms); + +// Maximum value of auxiliary_parametery_count for any optimization algorithm. +// This count is used by TPU embedding load/retrieve and needs to be independent +// of any particular TPU version and hence, we take the maximum across all TPU +// versions. +static constexpr int kMaxAuxiliaryParameterCount = 7; + +// Fill value for gradient accumulators. This is a denormal so that it will be +// flushed to zero on the current TPU platforms and needs to continue to have +// the following properties in the future: +// +// 1. Does not have the same bit pattern as a zero and can be distinguished from +// it using integer operations. +// 2. Treated as zero by floating-point arithmetic operations (at least addition +// and subtraction). +// 3. Cannot be produced by any floating-point arithmetic operation, including +// those involving itself. +// +// It does not need to compare equal or not equal to zero in floating point. We +// need to use a non-zero value here because some optimization algorithms are +// not no-ops on zero gradients, so we need to distinguish an accumulated +// gradient of zero from one that has been cleared after its gradients have +// already been applied to the parameters and accumulators. +inline float GradientAccumulatorInitialValue() { + return absl::bit_cast(1); +} + +// Generic shape function for per-optimization-algorithm load ops. +class LoadOpShapeFunction { + public: + // Computes resulting shape and does parameter checking. + absl::Status operator()(shape_inference::InferenceContext *c) const; +}; + +// Generic shape function for per-optimization-algorithm retrieve ops. +class RetrieveOpShapeFunction { + public: + // Computes resulting shape and does parameter checking. 
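GradientAccumulatorInitialValue above relies on absl::bit_cast<float>(1) being the smallest positive denormal: TPU arithmetic flushes it to zero, yet its bit pattern still distinguishes a freshly initialized accumulator from one holding a true zero. The snippet below demonstrates only the bit-level property; ordinary CPU arithmetic does not necessarily flush denormals, so the flush-to-zero behavior is described but not asserted.

```cpp
#include <cstdint>
#include <iostream>

#include "absl/base/casts.h"

int main() {
  const float fill = absl::bit_cast<float>(1);  // smallest positive denormal
  const float zero = 0.0f;

  // Property 1: the bit patterns differ, so integer comparisons can tell the
  // fill value apart from a real zero even though it "acts like" zero.
  const uint32_t fill_bits = absl::bit_cast<uint32_t>(fill);  // == 1
  const uint32_t zero_bits = absl::bit_cast<uint32_t>(zero);  // == 0
  std::cout << std::boolalpha << (fill_bits != zero_bits) << "\n";  // true

  // Property 2 (on TPU): with flush-to-zero arithmetic, grad + fill == grad,
  // so the sentinel never perturbs accumulated gradients. On a host CPU this
  // depends on the FPU's denormal handling, so it is not checked here.
  std::cout << fill << "\n";  // ~1.4e-45 where denormals are preserved
}
```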
+ absl::Status operator()(shape_inference::InferenceContext *c) const; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_output_layout_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_output_layout_utils.h new file mode 100644 index 00000000..f05b774b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_output_layout_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_ +#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_ + +#include + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" + +namespace tensorflow { +namespace tpu { + +// Computes the shape of the output tensors from an embedding configuration. +absl::Status ComputeOutputTensorShapes( + const tensorflow::tpu::TPUEmbeddingConfiguration& config, + std::vector* shapes); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.h new file mode 100644 index 00000000..957de62b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_SPMD_SHARDING_UTILS_H_ +#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_SPMD_SHARDING_UTILS_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/platform/statusor.h" + +namespace tensorflow { +namespace tpu { + +// Gets SPMD manual sharding annotation from the input shape. If the shape is a +// scalar (rank = 0), the tensor is replicated across all the cores within the +// replica. 
If the shape is a non-scalar (rank >= 1), the tensor is sharded on +// dimension `0' across all the cores within the same replica. +absl::StatusOr SpmdShardingAnnotationOnFirstDim( + const xla::Shape& shape, int core_count_per_replica, + xla::XlaBuilder* builder); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_SPMD_SHARDING_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_execute.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_execute.h new file mode 100644 index 00000000..bfa177d6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_execute.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_ +#define TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" +#include "xla/service/hlo.pb.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/tpu/tpu_node_context.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/framework/cancellation.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h" + +namespace tensorflow { + +// Runs a TPU executable. `input_allocations` and `output_allocations` are +// non-owning pointers to the root buffers of each argument/result tuple. +// `output_shape` is the output shape of the XLA computation from which +// `program` was derived. If `session_module` is not nullptr, it will be filled +// with the input and output literals of the execution. +absl::StatusOr TPUExecute( + const TPUExecutableInfoProto& executable, + const TPUHostTransferInfoProto& host_transfers, + const xla::HloProto& hlo_metadata, + std::vector arguments, + const std::string& rendezvous_key_base, uint32 rng_seed, + tpu::TpuNodeContext* node_context, xla::DeviceAssignment* device_assignment, + CancellationManager* cancellation_manager, OpKernelContext* ctx, + stream_executor::Stream* stream, + stream_executor::Stream* host_to_device_stream, + const XLA_TpuProgram* tpu_program); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_EXECUTE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_fingerprint_utils.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_fingerprint_utils.h new file mode 100644 index 00000000..c7cb99db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_fingerprint_utils.h @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_ +#define TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_ + +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +// Computes a fingerprint of the contents of `library`. +absl::Status FingerprintFunctionLibrary( + const FunctionLibraryDefinition& library, uint64_t& fingerprint); +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_FINGERPRINT_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_global_init.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_global_init.h new file mode 100644 index 00000000..4d6dd064 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_global_init.h @@ -0,0 +1,78 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TPU_TPU_GLOBAL_INIT_H_ +#define TENSORFLOW_CORE_TPU_TPU_GLOBAL_INIT_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" + +namespace tensorflow { + +// Initializes the TPU system globally. The state of initialization can then be +// shared by different sessions running on these TPUs, on the same process. This +// API is provided for multi-tenant usecases where multiple sessions in a +// process are using the same set of TPUs. +// +// Returns status errors if initialization is unsuccessful and returns the TPU +// TopologyProto as an output parameter. +// +// REQUIRES: +// * Call this API before any sessions using TPUs are run. +// * If you are using this API for initialization, please don't use the TPU +// configuration ops within your graph. This will cause errors to be returned +// from the API which is called second. +// +// DISTRIBUTED SETUP: +// To properly initialize a TPU topology that is beyond donut level, caller is +// required to provide correct following arguments: +// +// 1. job_name +// The name of the job under distributed settings. For example, if the job is +// '/job:tpu_worker/replica:0/task:0/...', the "tpu_worker" is the desired +// job_name here. +// +// 2. session_target +// The target string that will be used to create a Session and run the +// distributed TPU initialization graph. Generally this would be the master +// session from the cluster. 
+// +// 3.device_set +// The GLOBAL set of devices in the distributed setting, including proper +// "TPU_SYSTEM" devices across all tasks. +// For example, device_set should contain two "TPU_SYSTEM" devices on 2 tasks +// for a 4x2 (2 TPU workers) setup, and other non "TPU_SYSTEM" devices. +absl::Status InitializeTPUSystemGlobally(absl::string_view job_name, + absl::string_view session_target, + const DeviceSet& device_set, Env* env, + tpu::TopologyProto* tpu_topology); + +absl::Status InitializeTPUSystemGlobally(Env* env, + tpu::TopologyProto* tpu_topology); + +absl::Status InitializeTPUSystemGlobally(); + +} // namespace tensorflow + +// Many clients rely on ADL to lookup InitializeTPUSystemGlobally, now that Env +// moved to namespace tsl they are all broken without these forwarding +// declarations. +namespace tsl { +using tensorflow::InitializeTPUSystemGlobally; // NOLINT +} + +#endif // TENSORFLOW_CORE_TPU_TPU_GLOBAL_INIT_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_init_mode.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_init_mode.h new file mode 100644 index 00000000..0f8ad389 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_init_mode.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_INIT_MODE_H_ +#define TENSORFLOW_CORE_TPU_TPU_INIT_MODE_H_ + +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +enum class TPUInitMode : int { kNone, kGlobal, kRegular }; + +// Sets the TPU initialization mode appropriately. +// +// Requires that mode is not kNone, and mode doesn't transition kGlobal +// <-> kRegular. +// +// IMPLEMENTATION DETAILS: +// Used internally to record the current mode and type of API used for TPU +// initialization in a global static variable. +absl::Status SetTPUInitMode(TPUInitMode mode); + +// Returns the current TPUInitMode. +TPUInitMode GetTPUInitMode(); + +namespace test { + +// Forces the tpu init mode to be changed. +void ForceSetTPUInitMode(TPUInitMode mode); + +} // namespace test + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_INIT_MODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_model_server_initializer.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_model_server_initializer.h new file mode 100644 index 00000000..7ebafaea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_model_server_initializer.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
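For the single-host case, the Env-only overload above is sufficient; the distributed arguments (job_name, session_target, device_set) only matter beyond one worker. A hedged sketch of the simple path follows. The TopologyProto accessors used below (num_tasks, num_tpu_devices_per_task) are assumed from the topology proto and are not defined in this header.

```cpp
#include <iostream>

#include "absl/status/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
#include "tensorflow/core/tpu/tpu_global_init.h"

int main() {
  tensorflow::tpu::TopologyProto topology;

  // Single-process initialization: run before any TPU session, and do not
  // also use the TPU configuration ops in the graph, per the comment above.
  const absl::Status status = tensorflow::InitializeTPUSystemGlobally(
      tensorflow::Env::Default(), &topology);
  if (!status.ok()) {
    std::cerr << "TPU init failed: " << status << "\n";
    return 1;
  }

  std::cout << "tasks: " << topology.num_tasks()                    // assumed
            << ", TPUs per task: " << topology.num_tpu_devices_per_task()
            << "\n";
}
```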
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_MODEL_SERVER_INITIALIZER_H_ +#define TENSORFLOW_CORE_TPU_TPU_MODEL_SERVER_INITIALIZER_H_ + +#include "xla/stream_executor/tpu/libtftpu.h" +#include "xla/stream_executor/tpu/tpu_executor_c_api.h" +#include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace tpu {} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_MODEL_SERVER_INITIALIZER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_node_device_util.h b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_node_device_util.h new file mode 100644 index 00000000..b784727f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/tpu_node_device_util.h @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_NODE_DEVICE_UTIL_H_ +#define TENSORFLOW_CORE_TPU_TPU_NODE_DEVICE_UTIL_H_ + +#include "tensorflow/core/framework/kernel_def.pb.h" + +namespace tensorflow { + +// This is a BackendOpFilter. (see tensorflow/compiler/tf2xla/xla_op_registry.h) +// It returns true if the op should be registered on the device, it may +// optionally modify the KernelDef. +bool TpuOpFilter(KernelDef* kdef); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_TPU_NODE_DEVICE_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/tpu/virtual_device.h b/third_party/tflite-hdrs/tensorflow/core/tpu/virtual_device.h new file mode 100644 index 00000000..08233ece --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/tpu/virtual_device.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
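TpuOpFilter above is a BackendOpFilter in the sense of tf2xla's xla_op_registry.h, so the usual wiring is a backend registration that pairs the filter with the compilation device name and its supported types. The sketch below assumes the REGISTER_XLA_BACKEND macro from xla_op_registry.h together with the DEVICE_TPU_XLA_JIT and kTpuAllTypes constants from tpu_defs.h earlier in this patch; treat it as an illustration of the hookup, not as the exact registration this repository performs.

```cpp
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_node_device_util.h"

namespace tensorflow {
namespace {

// Registers the XLA_TPU_JIT backend: ops are considered for this compilation
// device only if TpuOpFilter approves (and possibly rewrites) their KernelDef.
REGISTER_XLA_BACKEND(DEVICE_TPU_XLA_JIT, kTpuAllTypes, TpuOpFilter);

}  // namespace
}  // namespace tensorflow
```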
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_VIRTUAL_DEVICE_H_ +#define TENSORFLOW_CORE_TPU_VIRTUAL_DEVICE_H_ + +#include "tensorflow/core/common_runtime/device.h" + +namespace tensorflow { + +// A dummy device that exists primarily for operator placement, without +// corresponding directly to a piece of hardware. +class VirtualDevice : public Device { + public: + VirtualDevice(Env* env, const DeviceAttributes& device_attributes); + + absl::Status Sync() override; + Allocator* GetAllocator(AllocatorAttributes attr) override; + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; + absl::Status TryGetDeviceContext(DeviceContext** out_context) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_VIRTUAL_DEVICE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/cf_sink/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/cf_sink/pass.h new file mode 100644 index 00000000..6db168c5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/cf_sink/pass.h @@ -0,0 +1,31 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_CF_SINK_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_CF_SINK_PASS_H_ + +#include + +namespace mlir { +class Pass; + +namespace tfg { + +std::unique_ptr CreateControlFlowSinkPass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_CF_SINK_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/consolidate_attrs/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/consolidate_attrs/pass.h new file mode 100644 index 00000000..e945c18a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/consolidate_attrs/pass.h @@ -0,0 +1,29 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_CONSOLIDATE_ATTRS_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_CONSOLIDATE_ATTRS_PASS_H_ + +#include + +namespace mlir { +class Pass; +namespace tfg { +std::unique_ptr CreateConsolidateAttributesPass(); +std::unique_ptr CreatePrepareAttributesForExportPass(); +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_CONSOLIDATE_ATTRS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/const_dedupe_hoist/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/const_dedupe_hoist/pass.h new file mode 100644 index 00000000..4b9b8133 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/const_dedupe_hoist/pass.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_CONST_DEDUPE_HOIST_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_CONST_DEDUPE_HOIST_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +#define GEN_PASS_DECL_DEDUPEANDHOISTCONSTANT +#include "tensorflow/core/transforms/passes.h.inc" + +std::unique_ptr CreateDedupeAndHoistConstantPass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_CONST_DEDUPE_HOIST_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/constant_folding/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/constant_folding/pass.h new file mode 100644 index 00000000..99603d4f --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/constant_folding/pass.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TRANSFORMS_CONSTANT_FOLDING_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_CONSTANT_FOLDING_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +#define GEN_PASS_DECL_CONSTANTFOLDINGPASS +#include "tensorflow/core/transforms/passes.h.inc" + +// Create a constant folding pass. 
+std::unique_ptr CreateConstantFoldingPass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_CONSTANT_FOLDING_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/cse/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/cse/pass.h new file mode 100644 index 00000000..65f1d24b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/cse/pass.h @@ -0,0 +1,29 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_CSE_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_CSE_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { +std::unique_ptr CreateCSEPass(); +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_CSE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/drop_unregistered_attribute/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/drop_unregistered_attribute/pass.h new file mode 100644 index 00000000..b0b46bca --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/drop_unregistered_attribute/pass.h @@ -0,0 +1,34 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_DROP_UNREGISTERED_ATTRIBUTE_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_DROP_UNREGISTERED_ATTRIBUTE_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +#define GEN_PASS_DECL_DROPOUTPUTSHAPESATTR +#include "tensorflow/core/transforms/passes.h.inc" + +std::unique_ptr CreateDropOutputShapesAttrPass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_DROP_UNREGISTERED_ATTRIBUTE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.h new file mode 100644 index 00000000..186ab2d5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
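The pass constructors declared in these headers (CreateConstantFoldingPass, CreateCSEPass, CreateDropOutputShapesAttrPass, and the others) are ordinary MLIR pass factories, so they compose through a standard mlir::PassManager. A sketch of a small TFG cleanup pipeline follows; it assumes the module already holds tfg dialect IR and that the factories return std::unique_ptr<mlir::Pass>, which the stripped template arguments above do not show.

```cpp
#include "mlir/IR/BuiltinOps.h"          // from @llvm-project
#include "mlir/Pass/PassManager.h"       // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/core/transforms/constant_folding/pass.h"
#include "tensorflow/core/transforms/cse/pass.h"
#include "tensorflow/core/transforms/drop_unregistered_attribute/pass.h"

// Runs constant folding, CSE, and output-shapes-attr dropping on `module`.
mlir::LogicalResult RunTfgCleanup(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::tfg::CreateConstantFoldingPass());
  pm.addPass(mlir::tfg::CreateCSEPass());
  pm.addPass(mlir::tfg::CreateDropOutputShapesAttrPass());
  return pm.run(module);
}
```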
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the constructor for the eliminate passthrough iteration +// arguments pass. + +#ifndef TENSORFLOW_CORE_TRANSFORMS_ELIMINATE_PASSTHROUGH_ITER_ARGS_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_ELIMINATE_PASSTHROUGH_ITER_ARGS_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { +// Creates a pass that eliminates passthrough iteration arguments from +// region-based loop operations. +std::unique_ptr CreateEliminatePassthroughIterArgsPass(); +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_ELIMINATE_PASSTHROUGH_ITER_ARGS_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/func_to_graph/func_to_graph.h b/third_party/tflite-hdrs/tensorflow/core/transforms/func_to_graph/func_to_graph.h new file mode 100644 index 00000000..5cab621b --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/func_to_graph/func_to_graph.h @@ -0,0 +1,33 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_FUNC_TO_GRAPH_FUNC_TO_GRAPH_H_ +#define TENSORFLOW_CORE_TRANSFORMS_FUNC_TO_GRAPH_FUNC_TO_GRAPH_H_ + +#include "tensorflow/core/ir/ops.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +namespace tfg { + +// Lowers a lifted graph func back to the graph. The uses of function arguments +// will be replaced with the associated value according to +// `tfg.lifted_value_attr` attribute. +absl::Status FuncToGraph(GraphFuncOp func); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_FUNC_TO_GRAPH_FUNC_TO_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/func_to_graph/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/func_to_graph/pass.h new file mode 100644 index 00000000..498aabc6 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/func_to_graph/pass.h @@ -0,0 +1,33 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_FUNC_TO_GRAPH_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_FUNC_TO_GRAPH_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +// Creates a pass which turns the function to a graph. Note that only the +// function which is lifted from graph is valid. +std::unique_ptr CreateFuncToGraphPass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_FUNC_TO_GRAPH_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/functional_to_region/impl.h b/third_party/tflite-hdrs/tensorflow/core/transforms/functional_to_region/impl.h new file mode 100644 index 00000000..023843d4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/functional_to_region/impl.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_FUNCTIONAL_TO_REGION_IMPL_H_ +#define TENSORFLOW_CORE_TRANSFORMS_FUNCTIONAL_TO_REGION_IMPL_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +void PopulateFunctionalToRegionPatterns(RewritePatternSet &patterns, + SymbolTable &table); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_FUNCTIONAL_TO_REGION_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/functional_to_region/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/functional_to_region/pass.h new file mode 100644 index 00000000..362f11b8 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/functional_to_region/pass.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_FUNCTIONAL_TO_REGION_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_FUNCTIONAL_TO_REGION_PASS_H_ + +#include + +namespace mlir { +class Pass; + +namespace tfg { + +std::unique_ptr CreateFunctionalToRegionPass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_FUNCTIONAL_TO_REGION_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/graph_compactor/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_compactor/pass.h new file mode 100644 index 00000000..6767bb59 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_compactor/pass.h @@ -0,0 +1,31 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_GRAPH_COMPACTOR_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_GRAPH_COMPACTOR_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { +std::unique_ptr CreateNameCompressPass(); +std::unique_ptr CreateStripDefaultAttrsPass(); +std::unique_ptr CreateAddDefaultAttrsPass(); +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_GRAPH_COMPACTOR_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/graph_to_func/graph_to_func.h b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_to_func/graph_to_func.h new file mode 100644 index 00000000..94723c96 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_to_func/graph_to_func.h @@ -0,0 +1,45 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_GRAPH_TO_FUNC_GRAPH_TO_FUNC_H_ +#define TENSORFLOW_CORE_TRANSFORMS_GRAPH_TO_FUNC_GRAPH_TO_FUNC_H_ + +#include + +#include "tensorflow/core/ir/ops.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +namespace tfg { + +// Lifts a graph into a function, using the provided array of `feeds` for +// function arguments, `fetches` for function returned values, and +// `control_rets` for returned control values. The Graph op is replaced in-place +// by a GraphFuncOp with a name defined in the dialect. 
+absl::Status GraphToFunc(GraphOp graph, ArrayRef feeds, + ArrayRef fetches, ArrayRef control_rets); + +// Lifts a graph into a function, using the provided array of `feeds` for +// function arguments, `fetches` for function returned values, and +// `control_rets` for returned control values. The Graph op is replaced in-place +// by a GraphFuncOp with a name defined in the dialect. +absl::Status GraphToFunc(GraphOp graph, ArrayRef feeds_names, + ArrayRef fetches_names, + ArrayRef control_rets); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_GRAPH_TO_FUNC_GRAPH_TO_FUNC_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/graph_to_func/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_to_func/pass.h new file mode 100644 index 00000000..798f5c95 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_to_func/pass.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_GRAPH_TO_FUNC_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_GRAPH_TO_FUNC_PASS_H_ + +#include +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +#define GEN_PASS_DECL_GRAPHTOFUNC +#include "tensorflow/core/transforms/passes.h.inc" + +// Returns a pass that runs on a Module and expects to find a single GraphOp +// to transform into a function. The provided feeds and fetches are used to form +// the function arguments and returned values. +std::unique_ptr CreateGraphToFuncPass( + ArrayRef feeds = {}, ArrayRef fetches = {}, + ArrayRef control_rets = {}); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_GRAPH_TO_FUNC_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/graph_transform_wrapper.h b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_transform_wrapper.h new file mode 100644 index 00000000..030f428c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/graph_transform_wrapper.h @@ -0,0 +1,46 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_GRAPH_TRANSFORM_WRAPPER_H_ +#define TENSORFLOW_CORE_TRANSFORMS_GRAPH_TRANSFORM_WRAPPER_H_ + +#include +#include + +#include "llvm/ADT/STLFunctionalExtras.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" + +namespace mlir { +namespace tfg { + +// Runs a sequence of passes over Graph* and attached function library. The +// Graph* is converted to TFG, provided passes executed and the passed in Graph* +// replaced. If the pass fails, then graph is not modified. +// +// This is meant for simple interop where there is a Graph* currently. Passes +// created here are constrained to run on Module ops. +absl::Status RunTransformOnGraph( + tensorflow::Graph* graph, + const std::initializer_list< + llvm::function_ref()>>& passes, + const tensorflow::GraphDebugInfo& debug_info = {}); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_GRAPH_TRANSFORM_WRAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/legacy_call/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/legacy_call/pass.h new file mode 100644 index 00000000..95faae4d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/legacy_call/pass.h @@ -0,0 +1,29 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_LEGACY_CALL_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_LEGACY_CALL_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { +std::unique_ptr CreateLiftLegacyCallPass(); +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_LEGACY_CALL_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/pass_registration.h b/third_party/tflite-hdrs/tensorflow/core/transforms/pass_registration.h new file mode 100644 index 00000000..55aa88ed --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/pass_registration.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_PASS_REGISTRATION_H_ +#define TENSORFLOW_CORE_TRANSFORMS_PASS_REGISTRATION_H_ + +#include + +#include "tensorflow/core/transforms/cf_sink/pass.h" +#include "tensorflow/core/transforms/consolidate_attrs/pass.h" +#include "tensorflow/core/transforms/const_dedupe_hoist/pass.h" +#include "tensorflow/core/transforms/constant_folding/pass.h" +#include "tensorflow/core/transforms/cse/pass.h" +#include "tensorflow/core/transforms/drop_unregistered_attribute/pass.h" +#include "tensorflow/core/transforms/eliminate_passthrough_iter_args/pass.h" +#include "tensorflow/core/transforms/func_to_graph/pass.h" +#include "tensorflow/core/transforms/functional_to_region/pass.h" +#include "tensorflow/core/transforms/graph_compactor/pass.h" +#include "tensorflow/core/transforms/graph_to_func/pass.h" +#include "tensorflow/core/transforms/legacy_call/pass.h" +#include "tensorflow/core/transforms/region_to_functional/pass.h" +#include "tensorflow/core/transforms/remapper/pass.h" +#include "tensorflow/core/transforms/shape_inference/pass.h" +#include "tensorflow/core/transforms/toposort/pass.h" + +namespace mlir { +namespace tfg { + +// Generate the code for registering passes for command-line parsing. +#define GEN_PASS_REGISTRATION +#include "tensorflow/core/transforms/passes.h.inc" + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_PASS_REGISTRATION_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/region_to_functional/impl.h b/third_party/tflite-hdrs/tensorflow/core/transforms/region_to_functional/impl.h new file mode 100644 index 00000000..77c3ec91 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/region_to_functional/impl.h @@ -0,0 +1,34 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_REGION_TO_FUNCTIONAL_IMPL_H_ +#define TENSORFLOW_CORE_TRANSFORMS_REGION_TO_FUNCTIONAL_IMPL_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +// Populate the patterns to convert region ops to functional ops. Please refer +// to `tfg-region-to-functional` pass description. 
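// Illustrative sketch (not part of the vendored headers): how pattern
// population entry points such as PopulateRegionToFunctionalPatterns below
// are typically consumed with MLIR's greedy rewrite driver. The driver entry
// point applyPatternsAndFoldGreedily is assumed to match the LLVM/MLIR
// revision vendored here; `module` is a placeholder.
#include <utility>
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "tensorflow/core/transforms/region_to_functional/impl.h"

mlir::LogicalResult ConvertRegionsToFunctional(mlir::ModuleOp module) {
  mlir::SymbolTable table(module);  // receives functions outlined by the rewrites
  mlir::RewritePatternSet patterns(module.getContext());
  mlir::tfg::PopulateRegionToFunctionalPatterns(patterns, table,
                                                /*force_control_capture=*/true);
  // Apply the patterns to a fixpoint over the whole module.
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}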
+void PopulateRegionToFunctionalPatterns(RewritePatternSet &patterns, + SymbolTable &table, + bool force_control_capture = false); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_REGION_TO_FUNCTIONAL_IMPL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/region_to_functional/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/region_to_functional/pass.h new file mode 100644 index 00000000..e1dc6fc5 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/region_to_functional/pass.h @@ -0,0 +1,38 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_REGION_TO_FUNCTIONAL_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_REGION_TO_FUNCTIONAL_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +#define GEN_PASS_DECL_REGIONTOFUNCTIONAL +#include "tensorflow/core/transforms/passes.h.inc" + +// Creates a conversion pass from region control-flow to functional +// control-flow. If `force_control_capture` is set, then all region control-flow +// ops are guaranteed to be converted to functional form by capturing implicit +// control tokens using a `Const` node. +std::unique_ptr CreateRegionToFunctionalPass( + bool force_control_capture = false); +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_REGION_TO_FUNCTIONAL_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/remapper/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/remapper/pass.h new file mode 100644 index 00000000..aadb124a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/remapper/pass.h @@ -0,0 +1,37 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_REMAPPER_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_REMAPPER_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +#define GEN_PASS_DECL_REMAPPER +#include "tensorflow/core/transforms/passes.h.inc" + +// Creates a remapper pass to remap the operations onto other opreations which +// decrease the amount of operations to perform a computation. 
+std::unique_ptr CreateRemapperPass(bool enable_onednn_patterns = false, + bool xla_auto_clustering = false); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_REMAPPER_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/remapper/remapping_helper.h b/third_party/tflite-hdrs/tensorflow/core/transforms/remapper/remapping_helper.h new file mode 100644 index 00000000..1d8db8fc --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/remapper/remapping_helper.h @@ -0,0 +1,245 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TRANSFORMS_REMAPPER_REMAPPING_HELPER_H_ +#define TENSORFLOW_CORE_TRANSFORMS_REMAPPER_REMAPPING_HELPER_H_ + +#include + +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/transforms/utils/op_cat_helper.h" +#include "tensorflow/core/transforms/utils/utils.h" + +namespace mlir { +namespace tfg { + +// The following structures store info of the operations to be fused. These +// are mainly used for combining operands info and attributes for a fused +// operation. They are also used for some predicate functions like +// `IsCpuCompatible` and `IsGpuCompatible` to check if the relevant fusion is +// supported on CPU and GPU, respectively. Another reason to keep these +// structures is to follow similar logics in current grappler-remapper. +// TODO(intel-tf): Remove redundancies once the similar functionality is +// achieved by tfg-remapper. +struct ContractionBiasAdd { + Operation* contraction; + Operation* bias_add; +}; + +struct ContractionBiasAddActivation { + Operation* contraction; + Operation* bias_add; + Operation* activation; +}; + +struct ContractionBiasAddAdd { + Operation* contraction; + Operation* bias_add; + Operation* add; +}; + +struct ContractionBiasAddAddActivation { + Operation* contraction; + Operation* bias_add; + Operation* add; + Operation* activation; +}; + +struct FusedBatchNormEx { + Operation* fused_batch_norm; + Value side_input; + Operation* activation; +}; + +class OpPropertyHelper : public OpCatHelper { + public: + OpPropertyHelper() = default; + explicit OpPropertyHelper(TFGraphDialect* dialect, + bool onednn_enabled = false, + bool xla_auto_clustering = false) + : OpCatHelper(dialect), + is_onednn_enabled_(onednn_enabled), + is_xla_auto_clustering_enabled_(xla_auto_clustering) {} + + bool HasControlOperandsOrResultUsers(Operation* op) const { + TFOp wrapper_op(op); + bool has_ctl_operands = !(wrapper_op.getControlOperands().empty()); + bool has_ctl_ret_users = !(wrapper_op.controlRet().getUsers().empty()); + if (has_ctl_operands || has_ctl_ret_users) + return true; + else + return false; + } + + // This function is to be used for an operation that has at least 1 + // non-control result. 
+ bool HasAtMostOneUserOfResult0(Operation* op) const { + // All tfg operation has 1 control result. When the operation has at least 1 + // non-control result, the number of results should be at least 2. + return op->getNumResults() > 1 && + (op->getResult(0).hasOneUse() || op->getResult(0).use_empty()); + } + + bool IsContraction(Operation* op) const { + return dialect_->IsConv2D(op) || dialect_->IsConv3D(op) || + dialect_->IsDepthwiseConv2dNative(op) || dialect_->IsMatMul(op); + } + + bool HaveSameDataType(Operation* lhs_op, Operation* rhs_op, + StringRef attr_name = "T") const { + auto lhs_attr = lhs_op->getAttrOfType(attr_name); + auto rhs_attr = rhs_op->getAttrOfType(attr_name); + if (!lhs_attr || !rhs_attr) return false; + return lhs_attr == rhs_attr; + } + + // This function is currently used by contraction ops. + bool IsGpuCompatibleDataType(Operation* contraction_op, + StringRef attr_name = "T") const { + auto attr = contraction_op->getAttrOfType(attr_name); + if (!attr) return false; + Type dtype = attr.getValue(); + if (dialect_->IsConv2D(contraction_op)) { + return mlir::isa(dtype); + } else if (dialect_->IsMatMul(contraction_op)) { + return mlir::isa(dtype); + } else { + return false; + } + } + + // This function is currently used by contraction ops. + bool IsCpuCompatibleDataType(Operation* contraction_op, + StringRef attr_name = "T") const { + auto attr = contraction_op->getAttrOfType(attr_name); + if (!attr) return false; + Type dtype = attr.getValue(); + if (is_onednn_enabled_) { + // Only contraction ops (MatMul, Conv2D, Conv3D, and + // DepthwiseConv2dNative) and BatchMatMul are supported. BatchMatMul + // fusions are handled differently than contraction ops. + bool is_supported = IsContraction(contraction_op) || + dialect_->IsAnyBatchMatMul(contraction_op); + return is_supported && mlir::isa(dtype); + } + + if (dialect_->IsConv2D(contraction_op)) { + return mlir::isa(dtype); + } else if (dialect_->IsMatMul(contraction_op)) { + return mlir::isa(dtype); + } else { + return false; + } + } + + // This function is currently used by convolution type op + bool IsGpuCompatibleDataFormat(Operation* conv_op, + StringRef attr_name = "data_format") const { + StringRef data_format; + if (auto attr = conv_op->getAttrOfType(attr_name)) { + data_format = attr.getValue(); + } else { + return false; + } + if (dialect_->IsConv2D(conv_op)) { + return data_format == "NHWC" || data_format == "NCHW"; + } else { + return false; + } + } + + // This function is currently used by convolution type op + bool IsCpuCompatibleDataFormat(Operation* conv_op, + StringRef attr_name = "data_format") const { + StringRef data_format; + if (auto attr = conv_op->getAttrOfType(attr_name)) { + data_format = attr.getValue(); + } else { + return false; + } + if (dialect_->IsConv2D(conv_op)) { + return data_format == "NHWC" || + (is_onednn_enabled_ && data_format == "NCHW"); + } else if (dialect_->IsConv3D(conv_op)) { + return data_format == "NDHWC" || + (is_onednn_enabled_ && data_format == "NCDHW"); + } else { + return false; + } + } + + bool IsGpuCompatible(const ContractionBiasAddActivation& pattern) const { +#if TENSORFLOW_USE_ROCM + // ROCm does not support _FusedConv2D. Does it suppport _FusedMatMul? + return false; +#endif + // The TF->XLA bridge does not support `_FusedMatMul` so we avoid creating + // this op. Furthermore, XLA already does this fusion internally so there + // is no true benefit from doing this optimization if XLA is going to + // compile the unfused operations anyway. 
+ if (is_xla_auto_clustering_enabled_) return false; + if (!util::OpHasDevice(pattern.contraction, tensorflow::DEVICE_GPU)) + return false; + if (!dialect_->IsRelu(pattern.activation)) return false; + if (dialect_->IsMatMul(pattern.contraction)) { + return IsGpuCompatibleDataType(pattern.contraction); + } else { + // TODO(intel-tf): Add spatial convolution support on GPU + return false; + } + } + + // Currently GPU does not supprt contraction + bias_add + bool IsGpuCompatible(const ContractionBiasAdd&) const { return false; } + + bool IsCpuCompatible(Operation* contraction_op) const { + if (!util::OpHasDevice(contraction_op, tensorflow::DEVICE_CPU)) + return false; + if (dialect_->IsConv2D(contraction_op) || + dialect_->IsConv3D(contraction_op)) { + return IsCpuCompatibleDataType(contraction_op) && + IsCpuCompatibleDataFormat(contraction_op); + } else if (dialect_->IsMatMul(contraction_op) || + dialect_->IsAnyBatchMatMul(contraction_op) || + dialect_->IsDepthwiseConv2dNative(contraction_op)) { + return IsCpuCompatibleDataType(contraction_op); + } else { + return false; + } + } + + template + bool IsDeviceCompatible(const Pattern& pattern) const { + // Currently, this function is used by contraction based fussion. + if constexpr (!std::is_same::value && + !std::is_same::value && + !std::is_same::value && + !std::is_same::value) { + return false; + } + return IsGpuCompatible(pattern) || IsCpuCompatible(pattern.contraction); + } + + bool isOneDNNEnabled() const { return is_onednn_enabled_; } + + private: + bool is_onednn_enabled_; + bool is_xla_auto_clustering_enabled_; +}; + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_REMAPPER_REMAPPING_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/shape_inference/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/shape_inference/pass.h new file mode 100644 index 00000000..046d6556 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/shape_inference/pass.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_SHAPE_INFERENCE_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_SHAPE_INFERENCE_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +#define GEN_PASS_DECL_SHAPEINFERENCE +#include "tensorflow/core/transforms/passes.h.inc" + +// Pass that infers the output shape of operations. 
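// Illustrative sketch (not part of the vendored headers): composing a few of
// the pass factories declared in these transforms headers into a pipeline.
// Each factory is assumed to return std::unique_ptr<mlir::Pass> as in the
// upstream tree; the particular pass ordering here is arbitrary.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/core/transforms/remapper/pass.h"
#include "tensorflow/core/transforms/shape_inference/pass.h"
#include "tensorflow/core/transforms/toposort/pass.h"

mlir::LogicalResult RunExamplePipeline(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::tfg::CreateTopoSortPass());
  pm.addPass(mlir::tfg::CreateShapeInferencePass());
  pm.addPass(mlir::tfg::CreateRemapperPass(/*enable_onednn_patterns=*/false,
                                           /*xla_auto_clustering=*/false));
  return pm.run(module);
}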
+std::unique_ptr CreateShapeInferencePass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_SHAPE_INFERENCE_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/toposort/pass.h b/third_party/tflite-hdrs/tensorflow/core/transforms/toposort/pass.h new file mode 100644 index 00000000..84760adb --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/toposort/pass.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_TOPOSORT_PASS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_TOPOSORT_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/core/ir/dialect.h" + +namespace mlir { +namespace tfg { + +// Sort topologically (following SSA defs-uses edges) the given block. +// The sort is stable. Optionally accepts an instance of the TFG dialect for +// virtually breaking NextIteration -> Merge cycles. +void SortTopologically(Block *block, TFGraphDialect *dialect = nullptr); + +// Programmatically create a pass that topologically sort graphs. +std::unique_ptr CreateTopoSortPass(); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_TOPOSORT_PASS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/utils/eval_utils.h b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/eval_utils.h new file mode 100644 index 00000000..28128938 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/eval_utils.h @@ -0,0 +1,73 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_UTILS_EVAL_UTILS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_UTILS_EVAL_UTILS_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/ir/tf_op_wrapper.h" + +namespace Eigen { +class ThreadPoolDevice; +} // namespace Eigen + +namespace mlir { +namespace tfg { +namespace util { + +// A simple CPU device for operation evaluation. +class SimpleDevice : public tensorflow::DeviceBase { + public: + SimpleDevice(); + ~SimpleDevice() override; + + absl::Status MakeTensorFromProto( + const tensorflow::TensorProto& tensor_proto, + const tensorflow::AllocatorAttributes alloc_attrs, + tensorflow::Tensor* tensor) override; + + tensorflow::Allocator* GetAllocator( + tensorflow::AllocatorAttributes attr) override; + + const std::string& device_type() const override { return device_type_; } + + private: + std::unique_ptr eigen_worker_; + tensorflow::DeviceBase::CpuWorkerThreads eigen_worker_threads_; + std::unique_ptr eigen_device_; + const std::string device_type_ = tensorflow::DEVICE_CPU; +}; + +// Attempts to evaluates an MLIR Operation with the op registered kernel. The op +// is always executed on the local host CPU irrespective of the device attribute +// of the given op. The results will be filled in the results vector. +LogicalResult EvaluateOperation(tensorflow::DeviceBase* cpu_device, + tensorflow::ResourceMgr* resource_mgr, TFOp op, + ArrayRef operands, + SmallVectorImpl& results); +} // namespace util +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_UTILS_EVAL_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/utils/op_cat_helper.h b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/op_cat_helper.h new file mode 100644 index 00000000..8cb212bd --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/op_cat_helper.h @@ -0,0 +1,54 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TRANSFORMS_UTILS_OP_CAT_HELPER_H_ +#define TENSORFLOW_CORE_TRANSFORMS_UTILS_OP_CAT_HELPER_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/core/ir/dialect.h" +#include "tensorflow/core/ir/tf_op_wrapper.h" + +namespace mlir { +namespace tfg { +// A Helper class to identify if an op belongs to certain op category. 
+class OpCatHelper { + public: + OpCatHelper() = default; + explicit OpCatHelper(TFGraphDialect *dialect) : dialect_(dialect) {} + + bool IsAggregate(TFOp op); + bool IsCommutative(TFOp op); + + // Returns true if it's a splat tensor type and has the splat value 1. + bool IsOnes(TFOp op); + // Returns true if it's a splat tensor type and has the splat value 0. + bool IsZeros(TFOp op); + + // Returns true if the op is known to use persistent memory to store its + // value. + bool IsPersistent(TFOp op); + + // Returns true if the op belongs to the NC_DATASET class (see graph/graph.h). + bool IsDataset(TFOp op); + + TFGraphDialect *getDialect() const { return dialect_; } + + protected: + TFGraphDialect *dialect_; +}; +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_UTILS_OP_CAT_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/utils/pdll/utils.h b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/pdll/utils.h new file mode 100644 index 00000000..75f33509 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/pdll/utils.h @@ -0,0 +1,30 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_UTILS_PDLL_UTILS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_UTILS_PDLL_UTILS_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace tfg { + +// Register the common utils. +void RegisterPDLLUtils(RewritePatternSet &patterns); + +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_UTILS_PDLL_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/transforms/utils/utils.h b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/utils.h new file mode 100644 index 00000000..9b4b2e2a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/transforms/utils/utils.h @@ -0,0 +1,82 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TRANSFORMS_UTILS_UTILS_H_ +#define TENSORFLOW_CORE_TRANSFORMS_UTILS_UTILS_H_ + +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/core/ir/dialect.h" + +namespace mlir { + +class Operation; +class NamedAttrList; + +namespace tfg { +namespace util { + +// Returns true if the op has the requested device attribute. +bool OpHasDevice(Operation *op, const char *device_name); + +// Erase the attribute starts with "_". +void EraseRegularNodeAttributes(NamedAttrList &attr_list); + +// When rewriting an operation 1-to-1, intrinsic attributes are manually +// forwarded, modified, or dropped. For example, when `If` is rewritten to +// `IfRegion`, +// +// 1. `Tout` is forwarded as is, +// 2. `then_branch` is changed to `then_attrs` which contain the attribute +// dictionary part of the `#tf_type.func`, and +// 3. `Tin` is dropped. +// +// Non-intrinsic attributes, e.g. `_tpu_cluster`, are blindly forwarded to the +// new operation. +void ForwardNonIntrinsicAttributes(Operation *src, Operation *dst); + +// Add an argument to a loop region. This inserts the new data argument and +// control argument at the correct positions and returns them. Also, this +// function updates any preserved argument attributes by inserting a null. +struct LoopRegionArgumentUpdate { + BlockArgument data, ctl; +}; +LoopRegionArgumentUpdate LoopRegionAddArgument(Region ®ion, Type type); + +// Erase an argument from a loop region. This erases the corresponding control +// argument. Also, this function updates any preserved argument attributes by +// deleting them. +void LoopRegionEraseArgument(Region ®ion, unsigned index); + +// Indicate that a result has been added to a loop region. Call this function to +// update the preserved result attributes. +void LoopRegionResultAdded(Region ®ion, unsigned num = 1); + +// Indicate that a result has been erased from a loop region. Call this function +// to update the preserved result attributes. +void LoopRegionResultErased(Region ®ion, unsigned index); + +// Erase operands from an op that might have an `operand_segment_sizes` , +// updating the attribute in-place if present. +void SizedOperandSegmentsEraseOperands(Operation *op, + ArrayRef indices); +void SizedOperandSegmentsEraseOperands(Operation *op, + const llvm::BitVector &erase); + +} // namespace util +} // namespace tfg +} // namespace mlir + +#endif // TENSORFLOW_CORE_TRANSFORMS_UTILS_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/activation_mode.h b/third_party/tflite-hdrs/tensorflow/core/util/activation_mode.h new file mode 100644 index 00000000..2c2e6476 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/activation_mode.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_ +#define TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_ + +// This file contains helper routines to deal with activation mode in various +// ops and kernels. + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// ActivationMode: the activation function we apply to the input tensor: +enum ActivationMode { + NONE = 0, + SIGMOID = 1, + RELU = 2, + RELU6 = 3, + RELUX = 4, + TANH = 5, + BANDPASS = 6, +}; + +// Specialization to parse an attribute directly into a ActivationMode enum. +absl::Status GetActivationModeFromString(const string& str_value, + ActivationMode* value); + +inline absl::string_view ToString(ActivationMode mode) { + switch (mode) { + case NONE: + return "NONE"; + case SIGMOID: + return "SIGMOID"; + case RELU: + return "RELU"; + case RELU6: + return "RELU6"; + case RELUX: + return "RELUX"; + case TANH: + return "TANH"; + case BANDPASS: + return "BANDPASS"; + } +} + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/autotune_serialize.h b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/autotune_serialize.h new file mode 100644 index 00000000..745eb1ad --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/autotune_serialize.h @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// For Google-internal use only. +// +// Supports serializing the autotune maps to string +// (SerializeAutotuneMaps), as well as deserializing them from +// string and injecting them into TF runtime +// (LoadSerializedAutotuneMaps). +// +// Aims to speed up the warmup time of neural nets. + +#ifndef TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_AUTOTUNE_SERIALIZE_H_ +#define TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_AUTOTUNE_SERIALIZE_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// TODO(b/189530096) Support autotune maps for more ops. +// Loads autotune maps from string output by SerializeAutotuneMaps and uses +// them to update the runtime autotune maps. +absl::Status LoadSerializedAutotuneMaps(absl::string_view s); + +// Serializes all the autotune maps into a string that can be decoded by +// LoadSerializedAutotuneMaps. +absl::Status SerializeAutotuneMaps(std::string* output); + +// Resets all autotune maps. For test use only. 
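// Illustrative sketch (not part of the vendored headers): a minimal use of
// the serialize/load pair declared above to carry autotune results across
// processes. Persisting the blob is left to the caller; the function names
// here are placeholders.
#include <string>
#include "absl/status/status.h"
#include "tensorflow/core/util/autotune_maps/autotune_serialize.h"

// After representative graphs have run once, snapshot the autotune maps.
absl::Status SnapshotAutotuneState(std::string* blob) {
  return tensorflow::SerializeAutotuneMaps(blob);
}

// On the next startup, inject the saved results so warmup skips autotuning.
absl::Status RestoreAutotuneState(const std::string& blob) {
  return tensorflow::LoadSerializedAutotuneMaps(blob);
}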
+void ResetAutotuneMaps(); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_AUTOTUNE_SERIALIZE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_autotune_maps.h b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_autotune_maps.h new file mode 100644 index 00000000..7c00348a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_autotune_maps.h @@ -0,0 +1,60 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// For Google-internal use only. +// +// This file defines the map data structure for storing autotuning results for +// fused_conv2d_bias_activation_op_kernels. +// +// The key of the map uniquely identifies a convolution operation that runs on a +// particular device model while the value might be the autotuned algorithm we +// choose for the conv. +// +// This map will be merged after fused_conv2d_bias_activation_op_kernels is +// merged into conv_ops_fused_impl.h (b/177365158, b/189530096) + +#ifndef TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_AUTOTUNE_MAPS_H_ +#define TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_AUTOTUNE_MAPS_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include + +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/autotune_maps/conv_parameters.h" + +namespace tensorflow { + +// A dummy type to group forward convolution autotune results together. +struct ConvAutotuneGroup { + static string name() { return "Conv"; } +}; + +using ConvAutotuneMap = AutotuneSingleton>; + +// A dummy type to group fused convolution autotune results together. +struct ConvFusedAutotuneGroup { + static string name() { return "FusedConv"; } +}; + +using FusedConvAutotuneMap = + AutotuneSingleton>; + +} // namespace tensorflow +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_AUTOTUNE_MAPS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_map_wrapper.h b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_map_wrapper.h new file mode 100644 index 00000000..39ce9845 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_map_wrapper.h @@ -0,0 +1,66 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_MAP_WRAPPER_H_ +#define TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_MAP_WRAPPER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" + +namespace tensorflow { + +// This class is a thin wrapper around `ConvMapProto::Entry`. It is used to +// provide opaque accessors to an entry's key and value without exposing the +// internal structure of the entry. +class ConvMapWrapper { + public: + using OpaqueKey = std::string; + using OpaqueValue = std::string; + + // Creates an `ConvMapWrapper` from a key and value. The provided key and + // value must be ones that were previously returned by calls to `Key()` and + // `Value()`. + static absl::StatusOr FromKeyAndValue(OpaqueKey key, + OpaqueValue value); + + // An opaque string that can be used as a key for this autotuning result. + // Do not rely on the format of this string. + OpaqueKey Key() const; + + // An opaque string that encodes the autotuning result. + // Do not rely on the format of this string. + OpaqueValue Value() const; + + static std::vector ConvMapToWrappers( + const ConvMapProto& autotune_results); + + // Returns the `ConvMapProto` proto that corresponds to the provided + // wrappers. + static absl::StatusOr ConvMapFromWrappers( + const std::vector& wrappers); + + private: + explicit ConvMapWrapper(const ConvMapProto::Entry& entry) + : conv_map_entry_(entry) {} + + ConvMapProto::Entry conv_map_entry_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_MAP_WRAPPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_parameters.h b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_parameters.h new file mode 100644 index 00000000..6658fa6e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/autotune_maps/conv_parameters.h @@ -0,0 +1,137 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_PARAMETERS_H_ +#define TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_PARAMETERS_H_ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#include "absl/types/optional.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" + +namespace tensorflow { +// Uniquely identifies a convolution operation that runs on a particular device +// model. +// +// This can serve as a hashtable key, where the value might be the autotuned +// algorithm we choose for the conv. +// +// All of the data in this class other than the device_id is stored in the +// ConvParametersProto, so it can be easily serialized (for the purposes of +// ahead-of-time autotuning). 
+// +// When using the cudnn frontend API, two autotuning results for two different +// GPUs of the same model are not interchangeable, because an autotuning result +// includes a cudnn execution plan, which is tied to the GPU. As a result, we +// need to create separate ConvParameters objects for them. +class ConvParameters { + public: + struct FusionInfo { + // For some implementations (e.g. cuDNN new backend) these scales are part + // of the algorithm, not part of the parameters an algorithm take. They need + // to be used to distinguish different algorithms. + double conv_scale; + double side_input_scale; + double leakyrelu_alpha; + stream_executor::dnn::ActivationMode activation_mode; + bool is_contrib; + }; + + // LINT.IfChange(conv_parameters_version) + // A positive number that denotes the version of this class. Should be + // incremented everytime this class or ConvParametersProto are updated in a + // way that may invalidate autotune results. + static constexpr int kVersion = 3; + // LINT.ThenChange() + + // We have three kinds of convolutions today. Vanilla unfused convolutions, + // fused convolutions, and fused convolutions as implemented in the `contrib` + // directory. The two fused convolutions ultimately correspond to the same + // cudnn calls, but have slightly different semantics (e.g. they interpret + // padding differently). + ConvParameters( + se::StreamExecutor* stream_exec, int64_t batch, int64_t in_depths, + absl::Span in, int data_format, int64_t out_depths, + absl::Span filter, absl::Span dilation, + absl::Span stride, absl::Span padding, + DataType dtype, int group_count, + absl::optional fusion_info = absl::optional(), + // This argument should be set only for test use. + int version = kVersion); + + ConvParameters(int device_id, const ConvParametersProto& proto); + + ConvParameters(se::StreamExecutor* stream_exec, + const ConvParametersProto& proto) + : ConvParameters(stream_exec->device_ordinal(), proto) {} + + bool operator==(const ConvParameters& other) const; + + bool operator!=(const ConvParameters& other) const { + return !(*this == other); + } + uint64 hash() const { return hash_code_; } + + string ToString() const; + + const ConvParametersProto& proto() const { return proto_; } + + private: + int device_id_; + ConvParametersProto proto_; + uint64 hash_code_; +}; + +class MatmulParameters { + public: + // LINT.IfChange(matmul_parameters_version) + // A positive number that denotes the version of this class. Should be + // incremented everytime this class or ConvParametersProto are updated in a + // way that may invalidate autotune results. + static constexpr int kVersion = 2; + // LINT.ThenChange() + + MatmulParameters(se::StreamExecutor* stream_exec, DataType ab_dtype, + DataType c_dtype, bool trans_a, bool trans_b, uint64_t m, + uint64_t n, uint64_t k, int64_t lda, int64_t ldb, + int64_t ldc, + stream_executor::dnn::ActivationMode activation_mode, + // This argument should be set only for test use. 
+ int version = kVersion); + + MatmulParameters(se::StreamExecutor* stream_exec, + const MatmulParametersProto& proto); + + bool operator==(const MatmulParameters& other) const; + + bool operator!=(const MatmulParameters& other) const { + return !(*this == other); + } + uint64 hash() const { return hash_code_; } + + string ToString() const; + + const MatmulParametersProto& proto() const { return proto_; } + + private: + int device_id_; + MatmulParametersProto proto_; + uint64 hash_code_; +}; + +} // namespace tensorflow +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_UTIL_AUTOTUNE_MAPS_CONV_PARAMETERS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/bad_indices_policy.h b/third_party/tflite-hdrs/tensorflow/core/util/bad_indices_policy.h new file mode 100644 index 00000000..ee8f4a89 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/bad_indices_policy.h @@ -0,0 +1,39 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_BAD_INDICES_POLICY_H_ +#define TENSORFLOW_CORE_UTIL_BAD_INDICES_POLICY_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" + +namespace tensorflow { +enum class BadIndicesPolicy { + // Default behavior: return an error on CPU and ignore on GPU. This is because + // we handle bad indices differently on CPU and GPU before this policy is + // introduced. + kDefault, + // Return an error. + kError, + // Ignore bad indices. + kIgnore, +}; + +absl::StatusOr BadIndicesPolicyFromString( + absl::string_view str); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_BAD_INDICES_POLICY_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/batch_util.h b/third_party/tflite-hdrs/tensorflow/core/util/batch_util.h new file mode 100644 index 00000000..176c229a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/batch_util.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_UTIL_BATCH_UTIL_H_ +#define TENSORFLOW_CORE_UTIL_BATCH_UTIL_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace batch_util { + +// Copies element into the index^th slice of parent (in the 0th dimension). 
+// +// NOTE(mrry): The `element` argument is taken by value. Use `std::move()` +// to move the `element` argument into this function, and the implementation +// may be able to optimize the copy to a move. This is particularly important +// for DT_STRING tensors. +absl::Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index); + +// Copies the index^th slice of parent (in the 0th dimension) into element. +absl::Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64_t index); + +// Copies 'num_slices' contiguous slices from 'src' tensor starting from index +// 'src_offset' into target tensor 'dst', and places them into slices +// starting from 'dst_offset'. +// +// This function requires 'src' and 'dst' to have compatible shapes. That is it +// requires cum_prod(src.shape[1:] == cum_prod(dst->shape[1:]). For example if +// source is of shape [x, 2, 1] and dst is a tensor of shape [y, 1, 2], this +// function can still proceed successfully. +absl::Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); + +// Copies the index^th slice of parent (in the 0th dimension) into element. +// +// NOTE(mrry): The implementation may be able to optimize the copy to a move. +// This is particularly important for DT_STRING tensors. +absl::Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, + int64_t index); + +// Moves `src` Tensor's data in [src_offset, src_offset+num_slices) along +// the first dimension if possible. Otherwise, copy them into `dst`. +absl::Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); + +// Zero-initializes the tensor `element` using the scalar stored in `padding`. +// Both `element` and `padding` must have matching `dtype`. +absl::Status SetElementZero(Tensor* element, const Tensor& padding); + +// Copies `element` into a (0th dimension) slice of `parent`, assuming +// the shape of `element` is strictly not larger along any axis than a +// slice. +absl::Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index); + +} // namespace batch_util +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_BATCH_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/bcast.h b/third_party/tflite-hdrs/tensorflow/core/util/bcast.h new file mode 100644 index 00000000..61d1fb5a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/bcast.h @@ -0,0 +1,427 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_BCAST_H_ +#define TENSORFLOW_CORE_UTIL_BCAST_H_ + +#include +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Returns the mapping from the output batch indices to the corresponding +// input's batch indices, given the input's "reshape" and "bcast" shapes as +// returned by the BCastList helper class. The i'th element denotes the +// (flattened) batch index of the input that must be used to compute the i'th +// batch output. +// +inline void ComputeBatchIndices( + const int64_t output_batch_size, + const absl::InlinedVector& reshape, + const absl::InlinedVector& bcast, + std::vector* out_indices) { + // Populates the mapping in out_indices. This algorithm is identical to + // the following steps: + // - Reshape {0, 1, ..., input_batch_size - 1} to the input shape. + // - Broadcast to the output shape. + // - Reshape back to a flat 1D vector. + out_indices->resize(output_batch_size); + int64_t num_output_elements = 1; + int64_t num_input_elements = 1; + for (int64_t i = reshape.size() - 1; i >= 0; --i) { + // Replicate the already populated mapping an additional (dim - 1) times. + // If we are broadcasting, just copy the existing mapping. + // Otherwise, add another dimension from the input shape. + const int64_t dim = std::max(reshape[i], bcast[i]); + const int64_t incr = bcast[i] > 1 ? 0 : num_input_elements; + for (int64_t k = 0; k < (dim - 1) * num_output_elements; ++k) { + (*out_indices)[num_output_elements + k] = (*out_indices)[k] + incr; + } + num_output_elements *= dim; + num_input_elements *= reshape[i]; + } +} + +template +class BCastList { + public: + // A vector of int64 representing the shape of tensor. The 0-th + // element is the outer-most dimension and the last element is the + // inner-most dimension. Note that we do not use TensorShape since + // it's more convenient to manipulate Vec directly for this module. + typedef absl::InlinedVector Vec; + + // Constructs all helper shapes, following the aforementioned rules. + // + // If "fewer_dims_optimization" is set to true (the default), the + // implementation tries to reduce intermediate dimensions needed to be more + // efficient. This is transparent to the caller. + // + // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have + // the same number of dimensions as the larger of the two inputs. + // + // If return_flattened_batch_indices is true, the implementation will compute + // for each output member of the flattened output, which batch indices of + // each input correspond to it. This is disabled by default. + explicit BCastList(const Vec (&x)[N], bool fewer_dims_optimization = true, + bool return_flattened_batch_indices = false); + ~BCastList() = default; + + // Returns true iff two operands are compatible according to the + // broadcasting rule. + bool IsValid() const { return valid_; } + bool IsBroadcastingRequired() const { return broadcasting_required_; } + + // If and only if IsValid(), the following fields can be used in + // implementing a broadcasted binary tensor operation according to + // the broadcasting rule. 
+ const Vec& reshape(int i) const { return reshape_[i]; } + const Vec& bcast(int i) const { return bcast_[i]; } + const Vec& result_shape() const { return result_; } + const Vec& output_shape() const { return output_; } + const Vec& grad_reduce_idx(int i) const { return grad_reduce_idx_[i]; } + int64_t output_batch_size() const { return output_batch_size_; } + + // Returns the mapping from the flattened output batch indices to x's + // flattened batch indices. The result is a vector of length + // output_batch_size(). To compute the i'th batch output, a binary matmul-like + // operation should use the `x_batch_indices()[i]`th batch index of `x`. + // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. + const std::vector& batch_indices(int i) const { + return batch_indices_[i]; + } + + protected: + bool valid_ = true; + bool broadcasting_required_ = true; + Vec reshape_[N]; + Vec bcast_[N]; + Vec result_; + Vec output_; + Vec grad_reduce_idx_[N]; + + int64_t output_batch_size_; + std::vector batch_indices_[N]; + + static void Reverse(Vec* shape) { + std::reverse(shape->begin(), shape->end()); + } + + BCastList(const BCastList&) = delete; + void operator=(const BCastList&) = delete; +}; + +template +BCastList::BCastList(const BCastList::Vec (&x)[N], + const bool fewer_dims_optimization, + const bool return_flattened_batch_indices) { + typedef BCastList::Vec Vec; + + // Safely multiplies dimensions taking into account symbolic shapes. + auto mul_dims = [](int64_t dim1, int64_t dim2) -> int64_t { + return dim1 != 0 && dim2 != 0 && (dim1 < 0 || dim2 < 0) ? -1 : dim1 * dim2; + }; + + bool all_equal = true; + size_t largest_rank = 0; + output_batch_size_ = 1; + for (int i = 0; i < N; ++i) { + if (x[i] != x[0]) { + all_equal = false; + } + if (x[i].size() > largest_rank) { + largest_rank = x[i].size(); + } + } + if (all_equal) { + broadcasting_required_ = false; + } + if (all_equal && TF_PREDICT_TRUE(fewer_dims_optimization)) { + // Fast path for common case of identical shapes. + int64_t elements = 1; + const int rank = x[0].size(); + output_.resize(rank); + for (int i = 0; i < rank; i++) { + const int64_t dim = x[0][i]; + elements = mul_dims(elements, dim); + output_[i] = dim; + } + result_.push_back(elements); + output_batch_size_ = elements; + for (int i = 0; i < N; ++i) { + reshape_[i].push_back(elements); + bcast_[i].push_back(1); + } + // grad_reduce_ is left as empty + return; + } + + // Reverse all the shapes for convenience + // After the reverse, 0-th is the inner-most dimension. + Vec copy[N]; + for (int i = 0; i < N; ++i) { + copy[i] = x[i]; + Reverse(©[i]); + } + + // 1-extend and align all vectors. + for (int i = 0; i < N; ++i) { + if (copy[i].size() < largest_rank) { + copy[i].resize(largest_rank, 1); + } + } + // Going through each dimension starting from the inner-most + // dimension, compares dimension of x and y. They are compatible if + // they are equal or either is 1. + + // indices of j-th component of each input. + bool prev_is_one[N]; + bool current_is_one[N]; + for (int i = 0; i < N; ++i) { + prev_is_one[i] = false; + current_is_one[i] = false; + } + bool output_dim_set = false; + int64_t output_dim = -1; + bool none_is_one = true; + bool set_one = false; + for (int j = 0; j < largest_rank; ++j) { + output_dim = -1; + output_dim_set = false; + none_is_one = true; + // Find which indices are 1. + for (int i = 0; i < N; ++i) { + // Keep track of which indices are 1. 
+ if (copy[i][j] == 1) { + current_is_one[i] = true; + none_is_one = false; + } else { + current_is_one[i] = false; + if (!output_dim_set || copy[i][j] == output_dim) { + output_dim = copy[i][j]; + output_dim_set = true; + } else { + valid_ = false; + return; + } + } + } + output_.push_back(output_dim_set ? output_dim : 1); + output_batch_size_ = mul_dims(output_batch_size_, output_.back()); + // All dimensions are 1. + if (!output_dim_set) { + if (!TF_PREDICT_TRUE(fewer_dims_optimization)) { + for (int i = 0; i < N; ++i) { + bcast_[i].push_back(1); + reshape_[i].push_back(1); + } + result_.push_back(1); + } + for (int i = 0; i < N; ++i) { + grad_reduce_idx_[i].push_back(largest_rank - 1 - j); + } + // This will skip updating the previous state to the current one. We'll + // explain why this is safe below. + // Consider the previous state P, current state C and the next state N. + // In the case where N also is all ones (N == C), we'll do the same + // optimization here (push back one dimensions if we need to), which is + // safe and is expected. + // + // When N != C, we'll continue as usual. However, we might trigger the + // next block if N == P (because we didn't update the previous state). + // We trigger the next block if `fewer_dims_optimization` is true. + // This means that we did not modify and broadcast / reshapes in this + // block (we skipped updating, since the one dimensions can be ignored). + // In essence, we only need to check whether the previous non-one state is + // equal to the current non-one state. + + continue; + } else if (TF_PREDICT_TRUE(fewer_dims_optimization) && + std::equal(current_is_one, current_is_one + N, prev_is_one) && + set_one) { + // It is a run of the same broadcasting case as last time. + // We can reshape the input so that fewer dimensions + // are involved in the intermediate computation. + result_.back() = mul_dims(result_.back(), output_dim); + for (int i = 0; i < N; ++i) { + reshape_[i].back() = mul_dims(reshape_[i].back(), copy[i][j]); + bcast_[i].back() = + mul_dims(bcast_[i].back(), current_is_one[i] ? output_dim : 1); + if (current_is_one[i] && !none_is_one) { + grad_reduce_idx_[i].push_back(largest_rank - 1 - j); + } + } + } else { + result_.push_back(output_dim); + for (int i = 0; i < N; ++i) { + reshape_[i].push_back(copy[i][j]); + bcast_[i].push_back(current_is_one[i] ? output_dim : 1); + if (current_is_one[i] && !none_is_one) { + grad_reduce_idx_[i].push_back(largest_rank - 1 - j); + } + } + } + set_one = true; + for (int i = 0; i < N; ++i) { + prev_is_one[i] = current_is_one[i]; + } + } + if (result_.empty()) { + result_.push_back(1); + for (int i = 0; i < N; ++i) { + reshape_[i].push_back(1); + bcast_[i].push_back(1); + } + } + // Do something about batches. + for (int i = 0; i < N; ++i) { + Reverse(&reshape_[i]); + Reverse(&bcast_[i]); + Reverse(&grad_reduce_idx_[i]); + } + Reverse(&result_); + Reverse(&output_); + // Only compute batch indices when we need broadcasting, and we aren't doing + // needless work (when the output size is 0 or the + // return_flattened_batch_indices isn't enabled). + if (return_flattened_batch_indices && broadcasting_required_ && + output_batch_size_ > 0) { + for (int i = 0; i < N; ++i) { + ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i], + &batch_indices_[i]); + } + } +} + +// BCast is a helper for broadcasting binary tensor operation. +// TensorFlow's broadcasting rule follows that of numpy (See +// http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html). 
+// +// The rule has the following properties: +// +// 1. suffix matching: the rule starts with the right-most +// dimension, and works towards the left-most dimension. Since +// TensorFlow is row-major, the right-most dimension (the last +// element in the shape of a tensor) is the inner-most, a.k.a. +// the fastest changing, dimension. +// +// 2. Two dimensions are compatible for broadcasting if both are the +// same or either is 1. +// +// BCast takes the shape of two tensors and computes a few vectors of +// int32 that are useful for the caller to reshape the tensors, apply +// the right broadcasts to them, compute the broadcasted operation, +// and possibly the gradients. In a nutshell, the caller is expected +// to compute the broadcasted operation as following: +// +// BCast b(x.shape(), y.shape()); +// output = x.reshape(b.x_reshape()).broadcast(b.x_bcast()) +// _op_ +// y.reshape(b.y_reshape()).broadcast(b.y_bcast()) +// +// For the gradient computation, +// grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx) +// .reshape(x.shape()) +// grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx) +// .reshape(y.shape()) +// backprop_x and backprop_y are functionals of the binary function "op", +// e.g., +// for +, backprop_x(x, y) = backprop_y(x, y) = 1; +// for *, backprop_x(x, y) = y, backprop_y(x, y) = x; +// for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2; +// +// The multiplication in the grad * backprop_x itself is also +// broadcasting following the same rule. +class BCast : public BCastList<2> { + public: + // Constructs all helper shapes, following the aforementioned rules. + // + // If "fewer_dims_optimization" is set to true (the default), the + // implementation tries to reduce intermediate dimensions needed to be more + // efficient. This is transparent to the caller. + // + // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have + // the same number of dimensions as the larger of the two inputs. + typedef absl::InlinedVector Vec; + + BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true, + const bool return_flattened_batch_indices = false) + : BCastList<2>({x, y}, fewer_dims_optimization, + return_flattened_batch_indices) {} + + ~BCast() = default; + + // If and only if IsValid(), the following fields can be used in + // implementing a broadcasted binary tensor operation according to + // the broadcasting rule. + const Vec& x_reshape() const { return reshape_[0]; } + const Vec& x_bcast() const { return bcast_[0]; } + const Vec& y_reshape() const { return reshape_[1]; } + const Vec& y_bcast() const { return bcast_[1]; } + const Vec& result_shape() const { return result_; } + const Vec& output_shape() const { return output_; } + const Vec& grad_x_reduce_idx() const { return grad_reduce_idx_[0]; } + const Vec& grad_y_reduce_idx() const { return grad_reduce_idx_[1]; } + + // Returns the mapping from the flattened output batch indices to x's + // flattened batch indices. The result is a vector of length + // output_batch_size(). To compute the i'th batch output, a binary matmul-like + // operation should use the `x_batch_indices()[i]`th batch index of `x`. + // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. + const std::vector& x_batch_indices() const { + return batch_indices_[0]; + } + // Returns the mapping from the flattened output batch indices to y's + // flattened batch indices. Similar to x_batch_indices(). 
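+  // Illustrative example (worked out by hand, not taken from the library's
+  // tests): for x of shape [2, 1] and y of shape [2, 3] the helper produces
+  // x_reshape() = [2, 1], x_bcast() = [1, 3], y_reshape() = [2, 3],
+  // y_bcast() = [1, 1], and output_batch_size() == 6. With
+  // return_flattened_batch_indices enabled this yields
+  //   x_batch_indices() = {0, 0, 0, 1, 1, 1}
+  //   y_batch_indices() = {0, 1, 2, 3, 4, 5}
+  // i.e. every output batch reads the matching y batch, while each of x's two
+  // flattened batches is reused for three consecutive outputs.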
+ // Note: Returns an empty vector if broadcasting is not required. Callers + // should only use this when IsBroadcastingRequired() returns true. + const std::vector& y_batch_indices() const { + return batch_indices_[1]; + } + + template + static Eigen::array ToIndexArrayType( + const BCast::Vec& vec) { + CHECK_EQ(vec.size(), NDIMS); + Eigen::array ret; + for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i]; + return ret; + } + + template + static Eigen::array ToIndexArray( + const BCast::Vec& vec) { + return ToIndexArrayType(vec); + } + + // Static helpers. + static Vec FromShape(const TensorShape& shape); + static TensorShape ToShape(const Vec& vec); + + private: + BCast(const BCast&) = delete; + void operator=(const BCast&) = delete; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_BCAST_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/command_line_flags.h b/third_party/tflite-hdrs/tensorflow/core/util/command_line_flags.h new file mode 100644 index 00000000..ebc58f7e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/command_line_flags.h @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H_ + +#include +#include +#include + +#include "xla/tsl/util/command_line_flags.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +using tsl::Flag; // NOLINT +using tsl::Flags; // NOLINT +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_COMMAND_LINE_FLAGS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_entry.h b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_entry.h new file mode 100644 index 00000000..c8a23036 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -0,0 +1,154 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// LINT.IfChange + +#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ +#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ + +#include +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ctc/ctc_loss_util.h" + +namespace tensorflow { +namespace ctc { + +// The ctc_beam_search namespace holds several classes meant to be accessed only +// in case of extending the CTCBeamSearch decoder to allow custom scoring +// functions. +// +// BeamEntry is exposed through template arguments BeamScorer and BeamComparer +// of CTCBeamSearch (ctc_beam_search.h). +namespace ctc_beam_search { + +struct EmptyBeamState {}; + +template +struct BeamProbability { + BeamProbability() + : total(kLogZero()), blank(kLogZero()), label(kLogZero()) {} + void Reset() { + total = kLogZero(); + blank = kLogZero(); + label = kLogZero(); + } + T total; + T blank; + T label; +}; + +template +class BeamRoot; + +template +struct BeamEntry { + // BeamRoot::AddEntry() serves as the factory method. + friend BeamEntry* BeamRoot::AddEntry( + BeamEntry* p, int l); + inline bool Active() const { return newp.total != kLogZero(); } + // Return the child at the given index, or construct a new one in-place if + // none was found. + BeamEntry& GetChild(int ind) { + auto entry = children.emplace(ind, nullptr); + auto& child_entry = entry.first->second; + // If this is a new child, populate the BeamEntry*. + if (entry.second) { + child_entry = beam_root->AddEntry(this, ind); + } + return *child_entry; + } + std::vector LabelSeq(bool merge_repeated) const { + std::vector labels; + int prev_label = -1; + const BeamEntry* c = this; + while (c->parent != nullptr) { // Checking c->parent to skip root leaf. + if (!merge_repeated || c->label != prev_label) { + labels.push_back(c->label); + } + prev_label = c->label; + c = c->parent; + } + std::reverse(labels.begin(), labels.end()); + return labels; + } + + BeamEntry* parent; + int label; + // All instances of child BeamEntry are owned by *beam_root. + gtl::FlatMap*> children; + BeamProbability oldp; + BeamProbability newp; + CTCBeamState state; + + private: + // Constructor giving parent, label, and the beam_root. + // The object pointed to by p cannot be copied and should not be moved, + // otherwise parent will become invalid. + // This private constructor is only called through the factory method + // BeamRoot::AddEntry(). + BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) + : parent(p), label(l), beam_root(beam_root) {} + BeamRoot* beam_root; + BeamEntry(const BeamEntry&) = delete; + void operator=(const BeamEntry&) = delete; +}; + +// This class owns all instances of BeamEntry. This is used to avoid recursive +// destructor call during destruction. 
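+// A minimal usage sketch, assuming template parameters
+// <class T, class CTCBeamState> instantiated as <float, EmptyBeamState>:
+//
+//   BeamRoot<float, EmptyBeamState> root(nullptr, -1);
+//   BeamEntry<float, EmptyBeamState>& a = root.RootEntry()->GetChild(0);
+//   BeamEntry<float, EmptyBeamState>& ab = a.GetChild(1);
+//   std::vector<int> labels = ab.LabelSeq(/*merge_repeated=*/false);  // {0, 1}
+//
+// Every BeamEntry created through GetChild() is allocated via AddEntry() and
+// is therefore owned by the BeamRoot, so tearing down the root releases the
+// whole prefix tree without recursive destructor calls.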
+template +class BeamRoot { + public: + BeamRoot(BeamEntry* p, int l) { + root_entry_ = AddEntry(p, l); + } + BeamRoot(const BeamRoot&) = delete; + BeamRoot& operator=(const BeamRoot&) = delete; + + BeamEntry* AddEntry(BeamEntry* p, int l) { + auto* new_entry = new BeamEntry(p, l, this); + beam_entries_.emplace_back(new_entry); + return new_entry; + } + BeamEntry* RootEntry() const { return root_entry_; } + + private: + BeamEntry* root_entry_ = nullptr; + std::vector>> beam_entries_; +}; + +// BeamComparer is the default beam comparer provided in CTCBeamSearch. +template +class BeamComparer { + public: + virtual ~BeamComparer() {} + virtual bool inline operator()(const BeamEntry* a, + const BeamEntry* b) const { + return a->newp.total > b->newp.total; + } +}; + +} // namespace ctc_beam_search + +} // namespace ctc +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ +// LINT.ThenChange(//tensorflow/lite/kernels/ctc/ctc_beam_entry.h) diff --git a/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_scorer.h b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_scorer.h new file mode 100644 index 00000000..1ea370f4 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_scorer.h @@ -0,0 +1,77 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// LINT.IfChange + +// Collection of scoring classes that can be extended and provided to the +// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a +// language model). +// +// To build a custom scorer extend and implement the pure virtual methods from +// BeamScorerInterface. The default CTC decoding behavior is implemented +// through BaseBeamScorer. + +#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_ +#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_ + +#include "tensorflow/core/util/ctc/ctc_beam_entry.h" + +namespace tensorflow { +namespace ctc { + +// Base implementation of a beam scorer used by default by the decoder that can +// be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex +// scoring is required. Its main purpose is to provide a thin layer for +// integrating language model scoring easily. +template +class BaseBeamScorer { + public: + virtual ~BaseBeamScorer() {} + // State initialization. + virtual void InitializeState(CTCBeamState* root) const {} + // ExpandState is called when expanding a beam to one of its children. + // Called at most once per child beam. In the simplest case, no state + // expansion is done. + virtual void ExpandState(const CTCBeamState& from_state, int from_label, + CTCBeamState* to_state, int to_label) const {} + // ExpandStateEnd is called after decoding has finished. Its purpose is to + // allow a final scoring of the beam in its current state, before resorting + // and retrieving the TopN requested candidates. Called at most once per beam. 
+ virtual void ExpandStateEnd(CTCBeamState* state) const {} + // GetStateExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandState. The score is + // multiplied (log-addition) with the input score at the current step from + // the network. + // + // The score returned should be a log-probability. In the simplest case, as + // there's no state expansion logic, the expansion score is zero. + virtual T GetStateExpansionScore(const CTCBeamState& state, + T previous_score) const { + return previous_score; + } + // GetStateEndExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandStateEnd. The score is + // multiplied (log-addition) with the final probability of the beam. + // + // The score returned should be a log-probability. + virtual T GetStateEndExpansionScore(const CTCBeamState& state) const { + return T(0); + } +}; + +} // namespace ctc +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SCORER_H_ +// LINT.ThenChange(//tensorflow/lite/kernels/ctc/ctc_beam_scorer.h) diff --git a/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_search.h b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_search.h new file mode 100644 index 00000000..a592d7a3 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_beam_search.h @@ -0,0 +1,437 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// LINT.IfChange + +#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ +#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ + +#include +#include +#include +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/top_n.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ctc/ctc_beam_entry.h" +#include "tensorflow/core/util/ctc/ctc_beam_scorer.h" +#include "tensorflow/core/util/ctc/ctc_decoder.h" +#include "tensorflow/core/util/ctc/ctc_loss_util.h" + +namespace tensorflow { +namespace ctc { + +template > +class CTCBeamSearchDecoder : public CTCDecoder { + // Beam Search + // + // Example (GravesTh Fig. 7.5): + // a - + // P = [ 0.3 0.7 ] t = 0 + // [ 0.4 0.6 ] t = 1 + // + // Then P(l = -) = P(--) = 0.7 * 0.6 = 0.42 + // P(l = a) = P(a-) + P(aa) + P(-a) = 0.3*0.4 + ... = 0.58 + // + // In this case, Best Path decoding is suboptimal. + // + // For Beam Search, we use the following main recurrence relations: + // + // Relation 1: + // ---------------------------------------------------------- Eq. 1 + // P(l=abcd @ t=7) = P(l=abc @ t=6) * P(d @ 7) + // + P(l=abcd @ t=6) * (P(d @ 7) + P(- @ 7)) + // where P(l=? @ t=7), ? 
= a, ab, abc, abcd are all stored and + // updated recursively in the beam entry. + // + // Relation 2: + // ---------------------------------------------------------- Eq. 2 + // P(l=abc? @ t=3) = P(l=abc @ t=2) * P(? @ 3) + // for ? in a, b, d, ..., (not including c or the blank index), + // and the recurrence starts from the beam entry for P(l=abc @ t=2). + // + // For this case, the length of the new sequence equals t+1 (t + // starts at 0). This special case can be calculated as: + // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) + // but we calculate it recursively for speed purposes. + typedef ctc_beam_search::BeamEntry BeamEntry; + typedef ctc_beam_search::BeamRoot BeamRoot; + typedef ctc_beam_search::BeamProbability BeamProbability; + + public: + typedef BaseBeamScorer DefaultBeamScorer; + + // The beam search decoder is constructed specifying the beam_width (number of + // candidates to keep at each decoding timestep) and a beam scorer (used for + // custom scoring, for example enabling the use of a language model). + // The ownership of the scorer remains with the caller. The default + // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the + // standard beam search. + CTCBeamSearchDecoder(int num_classes, int beam_width, + BaseBeamScorer* scorer, + int batch_size = 1, bool merge_repeated = false) + : CTCDecoder(num_classes, batch_size, merge_repeated), + beam_width_(beam_width), + leaves_(beam_width), + beam_scorer_(CHECK_NOTNULL(scorer)) { + Reset(); + } + + ~CTCBeamSearchDecoder() override {} + + // Run the hibernating beam search algorithm on the given input. + absl::Status Decode(const typename CTCDecoder::SequenceLength& seq_len, + const std::vector::Input>& input, + std::vector::Output>* output, + typename CTCDecoder::ScoreOutput* scores) override; + + // Calculate the next step of the beam search and update the internal state. + template + void Step(const Vector& log_input_t); + + template + T GetTopK(const int K, const Vector& input, std::vector* top_k_logits, + std::vector* top_k_indices); + + // Retrieve the beam scorer instance used during decoding. + BaseBeamScorer* GetBeamScorer() const { + return beam_scorer_; + } + + // Set label selection parameters for faster decoding. + // See comments for label_selection_size_ and label_selection_margin_. + void SetLabelSelectionParameters(int label_selection_size, + T label_selection_margin) { + label_selection_size_ = label_selection_size; + label_selection_margin_ = label_selection_margin; + } + + // Reset the beam search + void Reset(); + + // Extract the top n paths at current time step + absl::Status TopPaths(int n, std::vector>* paths, + std::vector* log_probs, bool merge_repeated) const; + + private: + int beam_width_; + + // Label selection is designed to avoid possibly very expensive scorer calls, + // by pruning the hypotheses based on the input alone. + // Label selection size controls how many items in each beam are passed + // through to the beam scorer. Only items with top N input scores are + // considered. + // Label selection margin controls the difference between minimal input score + // (versus the best scoring label) for an item to be passed to the beam + // scorer. This margin is expressed in terms of log-probability. + // Default is to do no label selection. + // For more detail: https://research.google.com/pubs/pub44823.html + int label_selection_size_ = 0; // zero means unlimited + T label_selection_margin_ = -1; // -1 means unlimited. 
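+  // For example (illustrative values, not defaults):
+  //   decoder.SetLabelSelectionParameters(40, T(9.0));
+  // restricts each expansion step to the 40 highest-scoring labels and skips
+  // any label whose logit is more than 9.0 below the best logit at that step,
+  // before the beam scorer is ever consulted.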
+ + gtl::TopN leaves_; + std::unique_ptr beam_root_; + BaseBeamScorer* beam_scorer_; + + CTCBeamSearchDecoder(const CTCBeamSearchDecoder&) = delete; + void operator=(const CTCBeamSearchDecoder&) = delete; +}; + +template +absl::Status CTCBeamSearchDecoder::Decode( + const typename CTCDecoder::SequenceLength& seq_len, + const std::vector::Input>& input, + std::vector::Output>* output, + typename CTCDecoder::ScoreOutput* scores) { + // Storage for top paths. + std::vector> beams; + std::vector beam_log_probabilities; + int top_n = output->size(); + if (std::any_of(output->begin(), output->end(), + [this](const typename CTCDecoder::Output& output) -> bool { + return output.size() < this->batch_size_; + })) { + return errors::InvalidArgument( + "output needs to be of size at least (top_n, batch_size)."); + } + if (scores->rows() < this->batch_size_ || scores->cols() < top_n) { + return errors::InvalidArgument( + "scores needs to be of size at least (batch_size, top_n)."); + } + + for (int b = 0; b < this->batch_size_; ++b) { + int seq_len_b = seq_len[b]; + Reset(); + + for (int t = 0; t < seq_len_b; ++t) { + // Pass log-probabilities for this example + time. + Step(input[t].row(b)); + } // for (int t... + + // O(n * log(n)) + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + for (int i = 0; i < branches->size(); ++i) { + BeamEntry* entry = (*branches)[i]; + beam_scorer_->ExpandStateEnd(&entry->state); + entry->newp.total += + beam_scorer_->GetStateEndExpansionScore(entry->state); + leaves_.push(entry); + } + + absl::Status status = + TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_); + if (!status.ok()) { + return status; + } + + CHECK_EQ(top_n, beam_log_probabilities.size()); + CHECK_EQ(beams.size(), beam_log_probabilities.size()); + + for (int i = 0; i < top_n; ++i) { + // Copy output to the correct beam + batch + (*output)[i][b].swap(beams[i]); + (*scores)(b, i) = -beam_log_probabilities[i]; + } + } // for (int b... + return absl::OkStatus(); +} + +template +template +T CTCBeamSearchDecoder::GetTopK( + const int K, const Vector& input, std::vector* top_k_logits, + std::vector* top_k_indices) { + // Find Top K choices, complexity nk in worst case. The array input is read + // just once. + CHECK_EQ(this->num_classes_, input.size()); + top_k_logits->clear(); + top_k_indices->clear(); + top_k_logits->resize(K, -INFINITY); + top_k_indices->resize(K, -1); + for (int j = 0; j < this->num_classes_ - 1; ++j) { + const T logit = input(j); + if (logit > (*top_k_logits)[K - 1]) { + int k = K - 1; + while (k > 0 && logit > (*top_k_logits)[k - 1]) { + (*top_k_logits)[k] = (*top_k_logits)[k - 1]; + (*top_k_indices)[k] = (*top_k_indices)[k - 1]; + k--; + } + (*top_k_logits)[k] = logit; + (*top_k_indices)[k] = j; + } + } + // Return max value which is in 0th index or blank character logit + return std::max((*top_k_logits)[0], input(this->num_classes_ - 1)); +} + +template +template +void CTCBeamSearchDecoder::Step( + const Vector& raw_input) { + std::vector top_k_logits; + std::vector top_k_indices; + const bool top_k = + (label_selection_size_ > 0 && label_selection_size_ < raw_input.size()); + // Number of character classes to consider in each step. + const int max_classes = + top_k ? label_selection_size_ : (this->num_classes_ - 1); + // Get max coefficient and remove it from raw_input later. 
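+  // The per-step log-probabilities used below are
+  //   log p(j) = raw_input(j) - norm_offset, with
+  //   norm_offset = max_coeff + log(sum_j exp(raw_input(j) - max_coeff)),
+  // i.e. a numerically stable log-softmax; subtracting max_coeff first keeps
+  // the exponentials from overflowing.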
+ T max_coeff; + if (top_k) { + max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits, + &top_k_indices); + } else { + max_coeff = raw_input.maxCoeff(); + } + // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))). + T logsumexp = T(0.0); + for (int j = 0; j < raw_input.size(); ++j) { + logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff); + } + logsumexp = Eigen::numext::log(logsumexp); + // Final normalization offset to get correct log probabilities. + T norm_offset = max_coeff + logsumexp; + + const T label_selection_input_min = + (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) + : -std::numeric_limits::infinity(); + + // Extract the beams sorted in decreasing new probability + CHECK_EQ(this->num_classes_, raw_input.size()); + + std::unique_ptr> branches(leaves_.Extract()); + leaves_.Reset(); + + for (BeamEntry* b : *branches) { + // P(.. @ t) becomes the new P(.. @ t-1) + b->oldp = b->newp; + } + + for (BeamEntry* b : *branches) { + if (b->parent != nullptr) { // if not the root + if (b->parent->Active()) { + // If last two sequence characters are identical: + // Plabel(l=acc @ t=6) = (Plabel(l=acc @ t=5) + // + Pblank(l=ac @ t=5)) + // else: + // Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5) + // + P(l=ab @ t=5)) + T previous = (b->label == b->parent->label) ? b->parent->oldp.blank + : b->parent->oldp.total; + b->newp.label = + LogSumExp(b->newp.label, + beam_scorer_->GetStateExpansionScore(b->state, previous)); + } + // Plabel(l=abc @ t=6) *= P(c @ 6) + b->newp.label += raw_input(b->label) - norm_offset; + } + // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) + b->newp.blank = b->oldp.total + raw_input(this->blank_index_) - norm_offset; + // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) + b->newp.total = LogSumExp(b->newp.blank, b->newp.label); + + // Push the entry back to the top paths list. + // Note, this will always fill leaves back up in sorted order. + leaves_.push(b); + } + + // we need to resort branches in descending oldp order. + + // branches is in descending oldp order because it was + // originally in descending newp order and we copied newp to oldp. + + // Grow new leaves + for (BeamEntry* b : *branches) { + // A new leaf (represented by its BeamProbability) is a candidate + // iff its total probability is nonzero and either the beam list + // isn't full, or the lowest probability entry in the beam has a + // lower probability than the leaf. + auto is_candidate = [this](const BeamProbability& prob) { + return (prob.total > kLogZero() && + (leaves_.size() < beam_width_ || + prob.total > leaves_.peek_bottom()->newp.total)); + }; + + if (!is_candidate(b->oldp)) { + continue; + } + + for (int ind = 0; ind < max_classes; ind++) { + const int label = top_k ? top_k_indices[ind] : ind; + const T logit = top_k ? top_k_logits[ind] : raw_input(ind); + // Perform label selection: if input for this label looks very + // unpromising, never evaluate it with a scorer. + // We may compare logits instead of log probabilities, + // since the difference is the same in both cases. 
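+      // (Both sides of the comparison omit the same norm_offset, so comparing
+      // raw logits against max_coeff - label_selection_margin_ is equivalent
+      // to comparing log-probabilities against the margin.)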
+ if (logit < label_selection_input_min) { + continue; + } + BeamEntry& c = b->GetChild(label); + if (!c.Active()) { + // Pblank(l=abcd @ t=6) = 0 + c.newp.blank = kLogZero(); + // If new child label is identical to beam label: + // Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6) + // Otherwise: + // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) + beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); + T previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; + c.newp.label = logit - norm_offset + + beam_scorer_->GetStateExpansionScore(c.state, previous); + // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) + c.newp.total = c.newp.label; + + if (is_candidate(c.newp)) { + // Before adding the new node to the beam, check if the beam + // is already at maximum width. + if (leaves_.size() == beam_width_) { + // Bottom is no longer in the beam search. Reset + // its probability; signal it's no longer in the beam search. + BeamEntry* bottom = leaves_.peek_bottom(); + bottom->newp.Reset(); + } + leaves_.push(&c); + } else { + // Deactivate child. + c.oldp.Reset(); + c.newp.Reset(); + } + } + } + } // for (BeamEntry* b... +} + +template +void CTCBeamSearchDecoder::Reset() { + leaves_.Reset(); + + // This beam root, and all of its children, will be in memory until + // the next reset. + beam_root_.reset(new BeamRoot(nullptr, -1)); + beam_root_->RootEntry()->newp.total = T(0.0); // ln(1) + beam_root_->RootEntry()->newp.blank = T(0.0); // ln(1) + + // Add the root as the initial leaf. + leaves_.push(beam_root_->RootEntry()); + + // Call initialize state on the root object. + beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); +} + +template +absl::Status CTCBeamSearchDecoder::TopPaths( + int n, std::vector>* paths, std::vector* log_probs, + bool merge_repeated) const { + CHECK_NOTNULL(paths)->clear(); + CHECK_NOTNULL(log_probs)->clear(); + if (n > beam_width_) { + return errors::InvalidArgument("requested more paths than the beam width."); + } + if (n > leaves_.size()) { + return errors::InvalidArgument( + "Less leaves in the beam search than requested."); + } + + gtl::TopN top_branches(n); + + // O(beam_width_ * log(n)), space complexity is O(n) + for (auto it = leaves_.unsorted_begin(); it != leaves_.unsorted_end(); ++it) { + top_branches.push(*it); + } + // O(n * log(n)) + std::unique_ptr> branches(top_branches.Extract()); + + for (int i = 0; i < n; ++i) { + BeamEntry* e((*branches)[i]); + paths->push_back(e->LabelSeq(merge_repeated)); + log_probs->push_back(e->newp.total); + } + return absl::OkStatus(); +} + +} // namespace ctc +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ +// LINT.ThenChange(//tensorflow/lite/kernels/ctc/ctc_beam_search.h) diff --git a/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_decoder.h b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_decoder.h new file mode 100644 index 00000000..8e6b3477 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_decoder.h @@ -0,0 +1,122 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// LINT.IfChange + +#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ +#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ + +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace ctc { + +// The CTCDecoder is an abstract interface to be implemented when providing a +// decoding method on the timestep output of a RNN trained with CTC loss. +// +// The two types of decoding available are: +// - greedy path, through the CTCGreedyDecoder +// - beam search, through the CTCBeamSearchDecoder +template +class CTCDecoder { + public: + typedef Eigen::Map SequenceLength; + typedef Eigen::Map> + Input; + typedef std::vector> Output; + typedef Eigen::Map> + ScoreOutput; + + CTCDecoder(int num_classes, int batch_size, bool merge_repeated) + : num_classes_(num_classes), + blank_index_(num_classes - 1), + batch_size_(batch_size), + merge_repeated_(merge_repeated) {} + + virtual ~CTCDecoder() {} + + // Dimensionality of the input/output is expected to be: + // - seq_len[b] - b = 0 to batch_size_ + // - input[t].rows(b) - t = 0 to timesteps; b = 0 t batch_size_ + // - output.size() specifies the number of beams to be returned. + // - scores(b, i) - b = 0 to batch_size; i = 0 to output.size() + virtual absl::Status Decode(const SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + ScoreOutput* scores) = 0; + + int batch_size() { return batch_size_; } + int num_classes() { return num_classes_; } + + protected: + int num_classes_; + int blank_index_; + int batch_size_; + bool merge_repeated_; +}; + +// CTCGreedyDecoder is an implementation of the simple best path decoding +// algorithm, selecting at each timestep the most likely class at each timestep. 
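+// For example (illustrative): with num_classes = 3 (so the blank index is 2)
+// and per-timestep argmax classes {1, 1, 2, 1}:
+//   merge_repeated = false  ->  output {1, 1, 1}  (blanks are always dropped)
+//   merge_repeated = true   ->  output {1, 1}     (the repeat at t = 1 is
+//                                                  merged; the blank at t = 2
+//                                                  separates the final 1)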
+template +class CTCGreedyDecoder : public CTCDecoder { + public: + typedef CTCDecoder Decoder; + CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated) + : CTCDecoder(num_classes, batch_size, merge_repeated) {} + + absl::Status Decode(const typename CTCDecoder::SequenceLength& seq_len, + const std::vector::Input>& input, + std::vector::Output>* output, + typename CTCDecoder::ScoreOutput* scores) override { + if (output->empty() || (*output)[0].size() < Decoder::batch_size_) { + return errors::InvalidArgument( + "output needs to be of size at least (1, batch_size)."); + } + if (scores->rows() < Decoder::batch_size_ || scores->cols() == 0) { + return errors::InvalidArgument( + "scores needs to be of size at least (batch_size, 1)."); + } + // For each batch entry, identify the transitions + for (int b = 0; b < Decoder::batch_size_; ++b) { + int seq_len_b = seq_len[b]; + // Only writing to beam 0 + std::vector& output_b = (*output)[0][b]; + + int prev_class_ix = -1; + (*scores)(b, 0) = 0; + for (int t = 0; t < seq_len_b; ++t) { + auto row = input[t].row(b); + int max_class_ix; + (*scores)(b, 0) += -row.maxCoeff(&max_class_ix); + if (max_class_ix != Decoder::blank_index_ && + !(Decoder::merge_repeated_ && max_class_ix == prev_class_ix)) { + output_b.push_back(max_class_ix); + } + prev_class_ix = max_class_ix; + } + } + return absl::OkStatus(); + } +}; + +} // namespace ctc +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ +// LINT.ThenChange(//tensorflow/lite/kernels/ctc/ctc_decoder.h) diff --git a/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_loss_calculator.h b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_loss_calculator.h new file mode 100644 index 00000000..12c4ac0a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_loss_calculator.h @@ -0,0 +1,544 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_ +#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_ + +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/ctc/ctc_loss_util.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +namespace ctc { + +template +class CTCLossCalculator { + // Connectionist Temporal Classification Loss + // + // Implementation by kanishkarao@, posenhuang@, and ebrevdo@. + // + // The CTC Loss layer learns a *transition* probability value for each + // input time step. The transitions are on the class alphabet + // {0, 1, ..., N-2} + // where N is the depth of the input layer (the size of the alphabet is N-1). 
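+  // For example (illustrative, using the default convention described here):
+  // with a 27-symbol alphabet the input layer must have depth N = 28, training
+  // labels take values in {0, ..., 26}, and index 27 is the reserved blank.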
+ // Note: The token N-1 is reserved for the "no transition" output, so + // make sure that your input layer has a depth that's one larger than + // the set of classes you're training on. Also make sure that your + // training labels do not have a class value of N-1, as training will skip + // these examples. + // + // Reference materials: + // GravesTh: Alex Graves, "Supervised Sequence Labeling with Recurrent + // Neural Networks" (PhD Thesis), Technische Universit¨at M¨unchen. + public: + typedef std::vector> LabelSequences; + using Matrix = Eigen::Matrix; + // typedef Eigen::MatrixXd Matrix; + using Array = Eigen::Array; + // typedef Eigen::ArrayXd Array; + using InputMap = Eigen::Map; + // typedef Eigen::Map InputMap; + using OutputMap = Eigen::Map; + // typedef Eigen::Map OutputMap; + + CTCLossCalculator(int blank_index, int output_delay) + : blank_index_(blank_index), output_delay_(output_delay) {} + + template + absl::Status CalculateLoss( + const VectorIn& seq_len, const LabelSequences& labels, + const std::vector& inputs, bool preprocess_collapse_repeated, + bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs, + VectorOut* loss, std::vector* gradients, + DeviceBase::CpuWorkerThreads* workers = nullptr) const; + + private: + void CalculateForwardVariables(const std::vector& l_prime, + const Matrix& y, bool ctc_merge_repeated, + Matrix* log_alpha) const; + + void CalculateBackwardVariables(const std::vector& l_prime, + const Matrix& y, bool ctc_merge_repeated, + Matrix* log_beta) const; + + void CalculateGradient(const std::vector& l_prime, const Matrix& y, + const Matrix& log_alpha, const Matrix& log_beta, + T log_p_z_x, Matrix* dy) const; + + void GetLPrimeIndices(const std::vector& l, + std::vector* l_prime) const; + + // Helper function that calculates the l_prime indices for all + // batches at the same time, and identifies errors for any given + // batch. Return value: + // max_{b in batch_size} l_primes[b].size() + template + absl::Status PopulateLPrimes(bool preprocess_collapse_repeated, + bool ignore_longer_outputs_than_inputs, + int batch_size, int num_classes, + const Vector& seq_len, + const LabelSequences& labels, + size_t* max_u_prime, + LabelSequences* l_primes) const; + + // Utility indices for the CTC algorithm. + int blank_index_; + + // Delay for target labels in time steps. + // The delay in time steps before the output sequence. 
+ const int output_delay_; +}; + +template +template +absl::Status CTCLossCalculator::CalculateLoss( + const VectorIn& seq_len, const LabelSequences& labels, + const std::vector& inputs, bool preprocess_collapse_repeated, + bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs, + VectorOut* loss, std::vector* gradients, + DeviceBase::CpuWorkerThreads* workers) const { + using Eigen::numext::log; + + auto num_time_steps = inputs.size(); + + if (loss == nullptr) { + return errors::InvalidArgument("loss == nullptr"); + } + + bool requires_backprop = (gradients != nullptr); + + auto batch_size = inputs[0].rows(); + auto num_classes = inputs[0].cols(); + + if (loss->size() != batch_size) { + return errors::InvalidArgument("loss.size() != batch_size"); + } + loss->setZero(); + + for (int t = 1; t < num_time_steps; ++t) { + if (inputs[t].rows() != batch_size) { + return errors::InvalidArgument("Expected batch size at t: ", t, + " to be: ", batch_size, + " but got: ", inputs[t].rows()); + } + if (inputs[t].cols() != num_classes) { + return errors::InvalidArgument("Expected class count at t: ", t, + " to be: ", num_classes, + " but got: ", inputs[t].cols()); + } + } + + // Check validity of sequence_length array values. + auto max_seq_len = seq_len(0); + for (int b = 0; b < batch_size; b++) { + if (seq_len(b) < 0) { + return errors::InvalidArgument("seq_len(", b, ") < 0"); + } + if (seq_len(b) > num_time_steps) { + return errors::InvalidArgument("seq_len(", b, ") > num_time_steps"); + } + max_seq_len = std::max(seq_len(b), max_seq_len); + } + + // Calculate the modified label sequence l' for each batch element, + // and calculate the maximum necessary allocation size. + LabelSequences l_primes(batch_size); + size_t max_u_prime = 0; + absl::Status l_p_ret = PopulateLPrimes( + preprocess_collapse_repeated, ignore_longer_outputs_than_inputs, + batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes); + if (!l_p_ret.ok()) { + return l_p_ret; + } + + // Process each item in a batch in parallel, using at most kMaxThreads. + auto ComputeLossAndGradients = [this, num_classes, &labels, &l_primes, + &seq_len, &inputs, requires_backprop, + ctc_merge_repeated, + ignore_longer_outputs_than_inputs, &loss, + &gradients](int64_t start_row, + int64_t limit_row) { + for (int b = start_row; b < limit_row; b++) { + // Return zero gradient for empty sequences or sequences with labels + // longer than input, which is not supported by CTC. + if (seq_len(b) == 0 || + (ignore_longer_outputs_than_inputs && + labels[b].size() > seq_len(b) - this->output_delay_)) { + VLOG(1) << "The sequence length is either zero or shorter than the " + "target output (CTC works only with shorter target sequence " + "than input sequence). You can turn this into a warning by " + "using the flag ignore_longer_outputs_than_inputs - " + << b << ": " << absl::StrJoin(labels[b], " "); + continue; + } + + // For each batch element, log(alpha) and log(beta). + // row size is: u_prime == l_prime.size() + // col size is: seq_len[b] - output_delay_ + const std::vector& l_prime = l_primes[b]; + + Matrix log_alpha_b(l_prime.size(), seq_len(b) - this->output_delay_); + Matrix log_beta_b(l_prime.size(), seq_len(b) - this->output_delay_); + + // Work matrices, pre-allocated to the size required by this batch item. + Matrix y(num_classes, seq_len(b)); + Matrix dy; + if (requires_backprop) { + dy = Matrix::Zero(y.rows(), y.cols()); + } + + // For this batch, we'll only work with this shortened sequence_length. 
+ Matrix y_b = y.leftCols(seq_len(b)); + + // Convert label from DistBelief + // y, prob are in num_classes x seq_len(b) + // Output activations. + Array y_b_col; + for (int t = 0; t < seq_len(b); t++) { + // Calculate the softmax of y_b. Use original precision + // arithmetic for the sum. + T max_coeff = inputs[t].row(b).maxCoeff(); + y_b_col = (inputs[t].row(b).array() - max_coeff).exp(); + y_b.col(t) = y_b_col / y_b_col.sum(); + } + + // Compute forward, backward. + // Forward variables. + CalculateForwardVariables(l_prime, y_b, ctc_merge_repeated, &log_alpha_b); + // Backward variables. + CalculateBackwardVariables(l_prime, y_b, ctc_merge_repeated, &log_beta_b); + + // The loss is computed as the log(p(z|x)) between the target and + // prediction. Do lazy evaluation of log_prob here. + T log_p_z_x = kLogZero(); + for (int u = 0; u < l_prime.size(); ++u) { + // (GravesTh) Eq 7.26, sum over all paths for t = 0. + log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0)); + } + + (*loss)(b) = -log_p_z_x; // Use negative log loss for display. + + // We compute the derivative if needed. + if (requires_backprop) { + // Gradients with respect to input activations. + // Calculate gradient. + dy.setZero(); + CalculateGradient(l_prime, y_b, log_alpha_b, log_beta_b, log_p_z_x, + &dy); + + // Convert gradient for current sample to DistBelief. + for (int t = 0; t < seq_len(b); t++) { + (*gradients)[t].row(b).array() = dy.col(t); + } + } + } // for (int b = ... + }; + if (workers) { + // *Rough* estimate of the cost for one item in the batch. + // Forward, Backward: O(T * U (= 2L + 1)), Gradients: O(T * (U + L)). + // + // softmax: T * L * (Cost(Exp) + Cost(Div))softmax + + // fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) + + // grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add)). + const int64_t cost_exp = Eigen::internal::functor_traits< + Eigen::internal::scalar_exp_op>::Cost; + const int64_t cost_log = Eigen::internal::functor_traits< + Eigen::internal::scalar_log_op>::Cost; + const int64_t cost_log_sum_exp = + Eigen::TensorOpCost::AddCost() + cost_exp + cost_log; + const int64_t cost = + max_seq_len * num_classes * + (cost_exp + Eigen::TensorOpCost::DivCost()) + + max_seq_len * 2 * (2 * num_classes + 1) * + (cost_log_sum_exp + cost_log) + + max_seq_len * + ((2 * num_classes + 1) * cost_log_sum_exp + + num_classes * (cost_exp + Eigen::TensorOpCost::AddCost())); + Shard(workers->num_threads, workers->workers, batch_size, cost, + ComputeLossAndGradients); + } else { + ComputeLossAndGradients(0, batch_size); + } + return absl::OkStatus(); +} + +template +template +absl::Status CTCLossCalculator::PopulateLPrimes( + bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs, + int batch_size, int num_classes, const Vector& seq_len, + const LabelSequences& labels, size_t* max_u_prime, + LabelSequences* l_primes) const { + // labels is a Label array of size batch_size + if (labels.size() != batch_size) { + return errors::InvalidArgument( + "labels.size() != batch_size: ", labels.size(), " vs. ", batch_size); + } + + *max_u_prime = 0; // keep track of longest l' modified label sequence. + for (int b = 0; b < batch_size; b++) { + // Assume label is in Label proto + const std::vector& label = labels[b]; + if (label.size() == 0) { + return errors::InvalidArgument("Labels length is zero in batch ", b); + } + + // If debugging: output the labels coming into training. 
+ // + VLOG(2) << "label for batch: " << b << ": " << absl::StrJoin(label, " "); + + // Target indices, length = U. + std::vector l; + + // Convert label from DistBelief + bool finished_sequence = false; + for (int i = 0; i < label.size(); ++i) { + if (i == 0 || !preprocess_collapse_repeated || label[i] != label[i - 1]) { + if (label[i] >= num_classes - 1) { + finished_sequence = true; + } else { + if (finished_sequence) { + // Saw an invalid sequence with non-null following null + // labels. + return errors::InvalidArgument( + "Saw a non-null label (index >= num_classes - 1) " + "following a ", + "null label, batch: ", b, " num_classes: ", num_classes, + " labels: ", absl::StrJoin(label, ","), + " labels seen so far: ", absl::StrJoin(l, ",")); + } + l.push_back(label[i]); + } + } + } + + for (int l_i : l) { + if (l_i < 0) { + return errors::InvalidArgument( + "All labels must be nonnegative integers, batch: ", b, + " labels: ", absl::StrJoin(l, ",")); + } else if (l_i >= num_classes) { + return errors::InvalidArgument( + "No label may be greater than num_classes. ", + "num_classes: ", num_classes, ", batch: ", b, + " labels: ", absl::StrJoin(l, ",")); + } + } + if (!ignore_longer_outputs_than_inputs) { + // Make sure there is enough time to output the target indices. + int time = seq_len(b) - output_delay_; + int required_time = label.size(); + if (required_time > time) { + return errors::InvalidArgument( + "Not enough time for target transition sequence (" + "required: ", + required_time, ", available: ", time, ")", b, + "You can turn this error into a warning by using the flag " + "ignore_longer_outputs_than_inputs"); + } + } + // Target indices with blanks before each index and a blank at the end. + // Length U' = 2U + 1. + // Convert l to l_prime + GetLPrimeIndices(l, &l_primes->at(b)); + *max_u_prime = std::max(*max_u_prime, l_primes->at(b).size()); + } + return absl::OkStatus(); +} + +// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3. +// Starting with t = 0 instead of t = 1 used in the text. +// Based on Kanishka's CTC. +template +void CTCLossCalculator::CalculateForwardVariables( + const std::vector& l_prime, const Matrix& y, bool ctc_merge_repeated, + Matrix* log_alpha) const { + using Eigen::numext::log; + + // Number of cols is the number of time steps = number of cols in target + // after the output delay. + log_alpha->setConstant(kLogZero()); + + int U = l_prime.size(); + int T = log_alpha->cols(); + + CHECK_EQ(U, log_alpha->rows()); + + // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6. + log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_)); + // Below, l_prime[1] == labels[0] + auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_; + log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_)); + + for (int t = 1; t < T; ++t) { + // If there is not enough time to output the remaining labels or + // some labels have been skipped, then let log_alpha(u, t) continue to + // be kLogZero. + for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1)); + ++u) { + // Begin (GravesTh) Eq 7.9 + // Add in the u, t - 1 term. + auto sum_log_alpha = kLogZero(); + if (ctc_merge_repeated || l_prime[u] == blank_index_) { + sum_log_alpha = log_alpha->coeff(u, t - 1); + } + + // Add in the u - 1, t - 1 term. + if (u > 0) { + sum_log_alpha = + LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1)); + } + + // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2). 
+ if (u > 1) { + const bool matching_labels_merge = + ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]); + if (l_prime[u] != blank_index_ && !matching_labels_merge) { + sum_log_alpha = + LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1)); + } + } + // Multiply the summed alphas with the activation log probability. + log_alpha->coeffRef(u, t) = + log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha; + } // End (GravesTh) Eq 7.9. + } +} + +// Calculates the beta(t, u) as described in (GravesTh) Section 7.3. +template +void CTCLossCalculator::CalculateBackwardVariables( + const std::vector& l_prime, const Matrix& y, bool ctc_merge_repeated, + Matrix* log_beta) const { + // Number of cols is the number of time steps = number of cols in target. + // Matrix log_beta = + // Matrix::Constant(l_prime.size(), y.cols() - output_delay_, + // kLogZero); + using Eigen::numext::log; + + log_beta->setConstant(kLogZero()); + int T = log_beta->cols(); + int U = l_prime.size(); + CHECK_EQ(U, log_beta->rows()); + + // Initial beta values in (GravesTh) Eq 7.13: log of probability 1. + for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0; + + for (int t = T - 1 - 1; t >= 0; --t) { + // If there is not enough time to output the remaining labels or + // some labels have been skipped, then let log_beta(u, t) continue to + // be kLogZero. + for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1)); + ++u) { + // Begin (GravesTh) Eq 7.15 + // Add in the u, t + 1 term. + if (ctc_merge_repeated || l_prime[u] == blank_index_) { + log_beta->coeffRef(u, t) = + LogSumExp(log_beta->coeff(u, t), + log_beta->coeff(u, t + 1) + + log(y(l_prime[u], output_delay_ + t + 1))); + } + + // Add in the u + 1, t + 1 term. + if (u + 1 < U) { + log_beta->coeffRef(u, t) = + LogSumExp(log_beta->coeff(u, t), + log_beta->coeff(u + 1, t + 1) + + log(y(l_prime[u + 1], output_delay_ + t + 1))); + } + + // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2). + if (u + 2 < U) { + const bool matching_labels_merge = + ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]); + if (l_prime[u] != blank_index_ && !matching_labels_merge) { + // Add in u + 2 term. + log_beta->coeffRef(u, t) = + LogSumExp(log_beta->coeff(u, t), + log_beta->coeff(u + 2, t + 1) + + log(y(l_prime[u + 2], output_delay_ + t + 1))); + } + } // End (GravesTh) Eq. 7.15 + } + } +} + +// Using (GravesTh) Eq 7.26 & 7.34. +template +void CTCLossCalculator::CalculateGradient(const std::vector& l_prime, + const Matrix& y, + const Matrix& log_alpha, + const Matrix& log_beta, + TT log_p_z_x, Matrix* dy) const { + // Only working with the leftmost part of dy for this batch element. + auto dy_b = dy->leftCols(y.cols()); + + // It is possible that no valid path is found if the activations for the + // targets are zero. + if (log_p_z_x == kLogZero()) { + LOG(WARNING) << "No valid path found."; + dy_b = y; + return; + } + + int L = y.rows(); + int T = y.cols(); + int U = l_prime.size(); + + for (int t = 0; t < T - output_delay_; ++t) { + Array prob_sum(L); + prob_sum.setConstant(kLogZero()); + + for (int u = 0; u < U; ++u) { + int l = l_prime[u]; + prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t)); + } + + for (int l = 0; l < L; ++l) { + // Negative term in (GravesTh) Eq 7.28. 
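+      // Here prob_sum[l] = log(sum over u with l_prime[u] == l of
+      // alpha(u, t) * beta(u, t)), so the term below equals
+      // (sum_u alpha * beta) / P(z|x), exponentiated only after the
+      // subtraction in log space to avoid underflow.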
+ auto negative_term = expf(prob_sum[l] - log_p_z_x); + + dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term; + } + } +} + +template +void CTCLossCalculator::GetLPrimeIndices(const std::vector& l, + std::vector* l_prime) const { + // Assumption is that l_prime is empty. + l_prime->reserve(2 * l.size() + 1); + + for (auto label : l) { + l_prime->push_back(blank_index_); + l_prime->push_back(label); + } + // Add final blank to l'. + l_prime->push_back(blank_index_); +} + +} // namespace ctc +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_CALCULATOR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_loss_util.h b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_loss_util.h new file mode 100644 index 00000000..e9fc99af --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/ctc/ctc_loss_util.h @@ -0,0 +1,55 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// LINT.IfChange + +#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ +#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ + +#include +#include + +namespace tensorflow { +namespace ctc { + +template +constexpr T kLogZero() { + return -std::numeric_limits::infinity(); // NOLINT +} + +// Add logarithmic probabilities using: +// ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a))) +// The two inputs are assumed to be log probabilities. +// (GravesTh) Eq. 7.18 +template +inline T LogSumExp(T log_prob_1, T log_prob_2) { + // const T kLogZero = -std::numeric_limits::infinity(); + // Always have 'b' be the smaller number to avoid the exponential from + // blowing up. + if (log_prob_1 == kLogZero()) { + return log_prob_2; + } else if (log_prob_2 == kLogZero()) { + return log_prob_1; + } else { + return (log_prob_1 > log_prob_2) + ? log_prob_1 + log1pf(expf(log_prob_2 - log_prob_1)) + : log_prob_2 + log1pf(expf(log_prob_1 - log_prob_2)); + } +} + +} // namespace ctc +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ +// LINT.ThenChange(//tensorflow/lite/kernels/ctc/ctc_loss_util.h) diff --git a/third_party/tflite-hdrs/tensorflow/core/util/cuda_sparse.h b/third_party/tflite-hdrs/tensorflow/core/util/cuda_sparse.h new file mode 100644 index 00000000..ca3ac8ff --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/cuda_sparse.h @@ -0,0 +1,722 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_CUDA_SPARSE_H_ +#define TENSORFLOW_CORE_UTIL_CUDA_SPARSE_H_ + +// This header declares the class GpuSparse, which contains wrappers of +// cuSparse libraries for use in TensorFlow kernels. + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include +#include + +#if GOOGLE_CUDA + +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cusparse.h" + +using gpusparseStatus_t = cusparseStatus_t; +using gpusparseOperation_t = cusparseOperation_t; +using gpusparseMatDescr_t = cusparseMatDescr_t; +using gpusparseAction_t = cusparseAction_t; +using gpusparseHandle_t = cusparseHandle_t; +using gpuStream_t = cudaStream_t; +#if CUDA_VERSION >= 10020 +using gpusparseDnMatDescr_t = cusparseDnMatDescr_t; +using gpusparseSpMatDescr_t = cusparseSpMatDescr_t; +using gpusparseSpMMAlg_t = cusparseSpMMAlg_t; +#endif + +#define GPUSPARSE(postfix) CUSPARSE_##postfix +#define gpusparse(postfix) cusparse##postfix + +#elif TENSORFLOW_USE_ROCM + +#include "rocm/rocm_config.h" +#include "xla/stream_executor/rocm/hipsparse_wrapper.h" + +using gpusparseStatus_t = hipsparseStatus_t; +using gpusparseOperation_t = hipsparseOperation_t; +using gpusparseMatDescr_t = hipsparseMatDescr_t; +using gpusparseAction_t = hipsparseAction_t; +using gpusparseHandle_t = hipsparseHandle_t; +using gpuStream_t = hipStream_t; +#if TF_ROCM_VERSION >= 40200 +using gpusparseDnMatDescr_t = hipsparseDnMatDescr_t; +using gpusparseSpMatDescr_t = hipsparseSpMatDescr_t; +using gpusparseSpMMAlg_t = hipsparseSpMMAlg_t; +#endif +#define GPUSPARSE(postfix) HIPSPARSE_##postfix +#define gpusparse(postfix) hipsparse##postfix + +#endif + +#include "xla/stream_executor/data_type.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/public/version.h" + +#if GOOGLE_CUDA +#include "xla/stream_executor/cuda/cuda_blas_utils.h" +#endif + +// Macro that specializes a sparse method for all 4 standard +// numeric types. 
+// TODO: reuse with cuda_solvers +#define TF_CALL_LAPACK_TYPES(m) \ + m(float, S) m(double, D) m(std::complex, C) m(std::complex, Z) + +namespace tensorflow { + +inline std::string ConvertGPUSparseErrorToString( + const gpusparseStatus_t status) { + switch (status) { +#define STRINGIZE(q) #q +#define RETURN_IF_STATUS(err) \ + case err: \ + return STRINGIZE(err); + +#if GOOGLE_CUDA + + RETURN_IF_STATUS(CUSPARSE_STATUS_SUCCESS) + RETURN_IF_STATUS(CUSPARSE_STATUS_NOT_INITIALIZED) + RETURN_IF_STATUS(CUSPARSE_STATUS_ALLOC_FAILED) + RETURN_IF_STATUS(CUSPARSE_STATUS_INVALID_VALUE) + RETURN_IF_STATUS(CUSPARSE_STATUS_ARCH_MISMATCH) + RETURN_IF_STATUS(CUSPARSE_STATUS_MAPPING_ERROR) + RETURN_IF_STATUS(CUSPARSE_STATUS_EXECUTION_FAILED) + RETURN_IF_STATUS(CUSPARSE_STATUS_INTERNAL_ERROR) + RETURN_IF_STATUS(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED) + + default: + return strings::StrCat("Unknown CUSPARSE error: ", + static_cast(status)); +#elif TENSORFLOW_USE_ROCM + + RETURN_IF_STATUS(HIPSPARSE_STATUS_SUCCESS) + RETURN_IF_STATUS(HIPSPARSE_STATUS_NOT_INITIALIZED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_ALLOC_FAILED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_INVALID_VALUE) + RETURN_IF_STATUS(HIPSPARSE_STATUS_ARCH_MISMATCH) + RETURN_IF_STATUS(HIPSPARSE_STATUS_MAPPING_ERROR) + RETURN_IF_STATUS(HIPSPARSE_STATUS_EXECUTION_FAILED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_INTERNAL_ERROR) + RETURN_IF_STATUS(HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED) + RETURN_IF_STATUS(HIPSPARSE_STATUS_ZERO_PIVOT) + + default: + return strings::StrCat("Unknown hipSPARSE error: ", + static_cast(status)); +#endif + +#undef RETURN_IF_STATUS +#undef STRINGIZE + } +} + +#if GOOGLE_CUDA + +#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \ + do { \ + auto status = (expr); \ + if (TF_PREDICT_FALSE(status != CUSPARSE_STATUS_SUCCESS)) { \ + return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \ + "): cuSparse call failed with status ", \ + ConvertGPUSparseErrorToString(status)); \ + } \ + } while (0) + +#elif TENSORFLOW_USE_ROCM + +#define TF_RETURN_IF_GPUSPARSE_ERROR(expr) \ + do { \ + auto status = (expr); \ + if (TF_PREDICT_FALSE(status != HIPSPARSE_STATUS_SUCCESS)) { \ + return errors::Internal(__FILE__, ":", __LINE__, " (", TF_STR(expr), \ + "): hipSPARSE call failed with status ", \ + ConvertGPUSparseErrorToString(status)); \ + } \ + } while (0) + +#endif + +inline gpusparseOperation_t TransposeAndConjugateToGpuSparseOp(bool transpose, + bool conjugate, + Status* status) { +#if GOOGLE_CUDA + if (transpose) { + return conjugate ? CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE + : CUSPARSE_OPERATION_TRANSPOSE; + } else { + if (conjugate) { + DCHECK(status != nullptr); + *status = errors::InvalidArgument( + "Conjugate == True and transpose == False is not supported."); + } + return CUSPARSE_OPERATION_NON_TRANSPOSE; + } +#elif TENSORFLOW_USE_ROCM + if (transpose) { + return conjugate ? 
HIPSPARSE_OPERATION_CONJUGATE_TRANSPOSE + : HIPSPARSE_OPERATION_TRANSPOSE; + } else { + if (conjugate) { + DCHECK(status != nullptr); + *status = errors::InvalidArgument( + "Conjugate == True and transpose == False is not supported."); + } + return HIPSPARSE_OPERATION_NON_TRANSPOSE; + } +#endif +} + +#if GOOGLE_CUDA && (CUDA_VERSION >= 12000) + +template +struct ToGpuSparseIndexType; +template <> +struct ToGpuSparseIndexType { + static constexpr cusparseIndexType_t value = CUSPARSE_INDEX_32I; +}; +template <> +struct ToGpuSparseIndexType { + static constexpr cusparseIndexType_t value = CUSPARSE_INDEX_64I; +}; + +class GpuSparseSpGEMMDescr { + public: + GpuSparseSpGEMMDescr() : initialized_(false) {} + ~GpuSparseSpGEMMDescr() { + if (initialized_) { + cusparseSpGEMM_destroyDescr(descr_); + } + } + Status Initialize() { + if (initialized_) { + return errors::Internal("Double initializion of GpuSparseSpGEMMDescr."); + } + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseSpGEMM_createDescr(&descr_)); + initialized_ = true; + return OkStatus(); + } + cusparseSpGEMMDescr_t& get() { return descr_; } + + private: + bool initialized_; + cusparseSpGEMMDescr_t descr_; + + GpuSparseSpGEMMDescr(const GpuSparseSpGEMMDescr&) = delete; + void operator=(const GpuSparseSpGEMMDescr&) = delete; +}; + +class GpuSparseSpMatDescr { + public: + GpuSparseSpMatDescr() : initialized_(false) {} + ~GpuSparseSpMatDescr() { + if (initialized_) { + cusparseDestroySpMat(descr_); + } + } + template + Status InitializeCsr(int64_t rows, int64_t cols, int64_t nnz, + IndexType* csrRowOffsets, IndexType* csrColInd, + FloatType* csrValues) { + if (initialized_) { + return errors::Internal("Double initializion of gpusparseSpMatDescr."); + } + using stream_executor::cuda::AsCudaDataType; + using stream_executor::dnn::ToDataType; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsr( + &descr_, rows, cols, nnz, csrRowOffsets, csrColInd, csrValues, + ToGpuSparseIndexType::value, + ToGpuSparseIndexType::value, CUSPARSE_INDEX_BASE_ZERO, + AsCudaDataType(ToDataType::value))); + initialized_ = true; + return OkStatus(); + } + gpusparseSpMatDescr_t& get() { return descr_; } + + private: + bool initialized_; + cusparseSpMatDescr_t descr_; + GpuSparseSpMatDescr(const GpuSparseSpMatDescr&) = delete; + void operator=(const GpuSparseSpMatDescr&) = delete; +}; + +class GpuSparseConstSpMatDescr { + public: + GpuSparseConstSpMatDescr() : initialized_(false) {} + ~GpuSparseConstSpMatDescr() { + if (initialized_) { + cusparseDestroySpMat(descr_); + } + } + template + Status InitializeCsr(int64_t rows, int64_t cols, int64_t nnz, + const IndexType* csrRowOffsets, + const IndexType* csrColInd, const FloatType* csrValues) { + if (initialized_) { + return errors::Internal("Double initializion of gpusparseSpMatDescr."); + } + using stream_executor::cuda::AsCudaDataType; + using stream_executor::dnn::ToDataType; + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateConstCsr( + &descr_, rows, cols, nnz, csrRowOffsets, csrColInd, csrValues, + ToGpuSparseIndexType::value, + ToGpuSparseIndexType::value, CUSPARSE_INDEX_BASE_ZERO, + AsCudaDataType(ToDataType::value))); + initialized_ = true; + return OkStatus(); + } + cusparseConstSpMatDescr_t& get() { return descr_; } + + private: + bool initialized_; + cusparseConstSpMatDescr_t descr_; + GpuSparseConstSpMatDescr(const GpuSparseConstSpMatDescr&) = delete; + void operator=(const GpuSparseConstSpMatDescr&) = delete; +}; + +#endif + +// The GpuSparse class provides a simplified templated API for cuSparse +// 
(http://docs.nvidia.com/cuda/cusparse/index.html). +// An object of this class wraps static cuSparse instances, +// and will launch Cuda kernels on the stream wrapped by the GPU device +// in the OpKernelContext provided to the constructor. +// +// Notice: All the computational member functions are asynchronous and simply +// launch one or more Cuda kernels on the Cuda stream wrapped by the GpuSparse +// object. + +class GpuSparse { + public: + // This object stores a pointer to context, which must outlive it. + explicit GpuSparse(OpKernelContext* context); + virtual ~GpuSparse() {} + + // This initializes the GpuSparse class if it hasn't + // been initialized yet. All following public methods require the + // class has been initialized. Can be run multiple times; all + // subsequent calls after the first have no effect. + Status Initialize(); // Move to constructor? + + // ==================================================================== + // Wrappers for cuSparse start here. + // + + // Solves tridiagonal system of equations. + // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2 + template + Status Gtsv2(int m, int n, const Scalar* dl, const Scalar* d, + const Scalar* du, Scalar* B, int ldb, void* pBuffer) const; + + // Computes the size of a temporary buffer used by Gtsv2. + // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize + template + Status Gtsv2BufferSizeExt(int m, int n, const Scalar* dl, const Scalar* d, + const Scalar* du, const Scalar* B, int ldb, + size_t* bufferSizeInBytes) const; + + // Solves tridiagonal system of equations without partial pivoting. + // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot + template + Status Gtsv2NoPivot(int m, int n, const Scalar* dl, const Scalar* d, + const Scalar* du, Scalar* B, int ldb, + void* pBuffer) const; + + // Computes the size of a temporary buffer used by Gtsv2NoPivot. + // See: + // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize + template + Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar* dl, + const Scalar* d, const Scalar* du, + const Scalar* B, int ldb, + size_t* bufferSizeInBytes) const; + + // Solves a batch of tridiagonal systems of equations. Doesn't support + // multiple right-hand sides per each system. Doesn't do pivoting. + // See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch + template + Status Gtsv2StridedBatch(int m, const Scalar* dl, const Scalar* d, + const Scalar* du, Scalar* x, int batchCount, + int batchStride, void* pBuffer) const; + + // Computes the size of a temporary buffer used by Gtsv2StridedBatch. + // See: + // https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize + template + Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar* dl, + const Scalar* d, const Scalar* du, + const Scalar* x, int batchCount, + int batchStride, + size_t* bufferSizeInBytes) const; + + // Compresses the indices of rows or columns. It can be interpreted as a + // conversion from COO to CSR sparse storage format. See: + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csr2coo. + Status Csr2coo(const int* CsrRowPtr, int nnz, int m, int* cooRowInd) const; + + // Uncompresses the indices of rows or columns. It can be interpreted as a + // conversion from CSR to COO sparse storage format. See: + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-coo2csr. 
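// A minimal usage sketch for the conversion wrapper declared above. The
// OpKernelContext and the device-resident buffers are hypothetical, and the
// sketch assumes a TensorFlow GPU build in which this header and the usual
// status macros are available.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
inline absl::Status Csr2cooSketch(tensorflow::OpKernelContext* ctx,
                                  const int* csr_row_ptr,  // device, m + 1 entries
                                  int* coo_row_ind,        // device, nnz entries
                                  int nnz, int m) {
  tensorflow::GpuSparse gpu_sparse(ctx);
  TF_RETURN_IF_ERROR(gpu_sparse.Initialize());
  // Convert between the compressed row-pointer form and per-nonzero row
  // indices; the call is asynchronous on the context's GPU stream.
  TF_RETURN_IF_ERROR(gpu_sparse.Csr2coo(csr_row_ptr, nnz, m, coo_row_ind));
  return absl::OkStatus();
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The companion conversion wrapper described just above, Coo2csr, is declared
// next: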
+ Status Coo2csr(const int* cooRowInd, int nnz, int m, int* csrRowPtr) const; + +#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || \ + (TENSORFLOW_USE_ROCM && TF_ROCM_VERSION < 40200) + // Sparse-dense matrix multiplication C = alpha * op(A) * op(B) + beta * C, + // where A is a sparse matrix in CSR format, B and C are dense tall + // matrices. This routine allows transposition of matrix B, which + // may improve performance. See: + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmm2 + // + // **NOTE** Matrices B and C are expected to be in column-major + // order; to make them consistent with TensorFlow they + // must be transposed (or the matmul op's pre/post-processing must take this + // into account). + // + // **NOTE** This is an in-place operation for data in C. + template + Status Csrmm(gpusparseOperation_t transA, gpusparseOperation_t transB, int m, + int n, int k, int nnz, const Scalar* alpha_host, + const gpusparseMatDescr_t descrA, const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const Scalar* B, int ldb, const Scalar* beta_host, Scalar* C, + int ldc) const; +#else // CUDA_VERSION >=10200 || TF_ROCM_VERSION >= 40200 + // Workspace size query for sparse-dense matrix multiplication. Helper + // function for SpMM which computes y = alpha * op(A) * op(B) + beta * C, + // where A is a sparse matrix in CSR format, B and C are dense matricies in + // column-major format. Returns needed workspace size in bytes. + template + Status SpMMBufferSize(gpusparseOperation_t transA, + gpusparseOperation_t transB, const Scalar* alpha, + const gpusparseSpMatDescr_t matA, + const gpusparseDnMatDescr_t matB, const Scalar* beta, + gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg, + size_t* bufferSize) const; + + // Sparse-dense matrix multiplication y = alpha * op(A) * op(B) + beta * C, + // where A is a sparse matrix in CSR format, B and C are dense matricies in + // column-major format. Buffer is assumed to be at least as large as the + // workspace size returned by SpMMBufferSize(). + // + // **NOTE** This is an in-place operation for data in C. + template + Status SpMM(gpusparseOperation_t transA, gpusparseOperation_t transB, + const Scalar* alpha, const gpusparseSpMatDescr_t matA, + const gpusparseDnMatDescr_t matB, const Scalar* beta, + gpusparseDnMatDescr_t matC, gpusparseSpMMAlg_t alg, + int8* buffer) const; +#endif + + // Sparse-dense vector multiplication y = alpha * op(A) * x + beta * y, + // where A is a sparse matrix in CSR format, x and y are dense vectors. See: + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrmv_mergepath + // + // **NOTE** This is an in-place operation for data in y. +#if (GOOGLE_CUDA && (CUDA_VERSION < 10020)) || TENSORFLOW_USE_ROCM + template + Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz, + const Scalar* alpha_host, const gpusparseMatDescr_t descrA, + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* x, + const Scalar* beta_host, Scalar* y) const; +#else + template + Status Csrmv(gpusparseOperation_t transA, int m, int n, int nnz, + const Scalar* alpha_host, const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const Scalar* x, const Scalar* beta_host, Scalar* y) const; +#endif // CUDA_VERSION < 10020 + + // Computes workspace size for sparse - sparse matrix addition of matrices + // stored in CSR format. 
+ template + Status CsrgeamBufferSizeExt( + int m, int n, const Scalar* alpha, const gpusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* beta, + const gpusparseMatDescr_t descrB, int nnzB, const Scalar* csrSortedValB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const gpusparseMatDescr_t descrC, Scalar* csrSortedValC, + int* csrSortedRowPtrC, int* csrSortedColIndC, size_t* bufferSize); + + // Computes sparse-sparse matrix addition of matrices + // stored in CSR format. This is part one: calculate nnz of the + // output. csrSortedRowPtrC must be preallocated on device with + // m + 1 entries. See: + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam. + Status CsrgeamNnz(int m, int n, const gpusparseMatDescr_t descrA, int nnzA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, + int* nnzTotalDevHostPtr, void* workspace); + + // Computes sparse - sparse matrix addition of matrices + // stored in CSR format. This is part two: perform sparse-sparse + // addition. csrValC and csrColIndC must be allocated on the device + // with nnzTotalDevHostPtr entries (as calculated by CsrgeamNnz). See: + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csrgeam. + template + Status Csrgeam(int m, int n, const Scalar* alpha, + const gpusparseMatDescr_t descrA, int nnzA, + const Scalar* csrSortedValA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, const Scalar* beta, + const gpusparseMatDescr_t descrB, int nnzB, + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, + Scalar* csrSortedValC, int* csrSortedRowPtrC, + int* csrSortedColIndC, void* workspace); + + // Computes sparse-sparse matrix multiplication of matrices + // stored in CSR format. +#if TENSORFLOW_USE_ROCM + // Part one: calculate nnz of the output. + // csrSortedRowPtrC must be preallocated on device with m + 1 entries. + Status CsrgemmNnz(gpusparseOperation_t transA, gpusparseOperation_t transB, + int m, int k, int n, const gpusparseMatDescr_t descrA, + int nnzA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, + int* nnzTotalDevHostPtr); + // Part two: perform sparse-sparse matmul. + // csrValC and csrColIndC must be allocated on the device with + // nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). + template + Status Csrgemm(gpusparseOperation_t transA, gpusparseOperation_t transB, + int m, int k, int n, const gpusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, + Scalar* csrSortedValC, int* csrSortedRowPtrC, + int* csrSortedColIndC); +#elif CUDA_VERSION < 12000 + // Part zero: calculate required workspace size. 
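// Taken together, the Csrgeam* entry points above follow the usual staged
// cuSPARSE protocol for computing C = alpha * A + beta * B in CSR form:
//   1. CsrgeamBufferSizeExt(...)  -> required scratch-workspace size in bytes;
//      allocate a temporary buffer of that size.
//   2. CsrgeamNnz(...)            -> number of nonzeros of C plus its row
//      pointer (csrSortedRowPtrC, preallocated with m + 1 entries); then
//      allocate csrSortedValC / csrSortedColIndC with that many entries.
//   3. Csrgeam(...)               -> fill in C's values and column indices,
//      reusing the same workspace.
// The sparse-sparse matmul (Csrgemm) declarations that follow are staged the
// same way, starting with the workspace-size query below: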
+ template + Status CsrgemmBufferSize( + int m, int n, int k, const gpusparseMatDescr_t descrA, int nnzA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, csrgemm2Info_t info, size_t* workspaceBytes); + // Part one: calculate nnz of the output. + // csrSortedRowPtrC must be preallocated on device with m + 1 entries. + Status CsrgemmNnz(int m, int n, int k, const gpusparseMatDescr_t descrA, + int nnzA, const int* csrSortedRowPtrA, + const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const int* csrSortedRowPtrB, const int* csrSortedColIndB, + const gpusparseMatDescr_t descrC, int* csrSortedRowPtrC, + int* nnzTotalDevHostPtr, csrgemm2Info_t info, + void* workspace); + // Part two: perform sparse-sparse matmul. + // csrValC and csrColIndC must be allocated on the device with + // nnzTotalDevHostPtr entries (as calculated by CsrgemmNnz). + template + Status Csrgemm(int m, int n, int k, const gpusparseMatDescr_t descrA, + int nnzA, const Scalar* csrSortedValA, + const int* csrSortedRowPtrA, const int* csrSortedColIndA, + const gpusparseMatDescr_t descrB, int nnzB, + const Scalar* csrSortedValB, const int* csrSortedRowPtrB, + const int* csrSortedColIndB, const gpusparseMatDescr_t descrC, + Scalar* csrSortedValC, int* csrSortedRowPtrC, + int* csrSortedColIndC, const csrgemm2Info_t info, + void* workspace); +#else // CUDA_VERSION >= 12000 + template + Status SpGEMM_workEstimation(GpuSparseConstSpMatDescr& matA, + GpuSparseConstSpMatDescr& matB, + GpuSparseSpMatDescr& matC, + GpuSparseSpGEMMDescr& spgemmDescr, + size_t* bufferSize1, void* externalBuffer1); + template + Status SpGEMM_compute(GpuSparseConstSpMatDescr& matA, + GpuSparseConstSpMatDescr& matB, + GpuSparseSpMatDescr& matC, + GpuSparseSpGEMMDescr& spgemmDescr, size_t* bufferSize2, + void* externalBuffer2); + template + Status SpGEMM_copy(GpuSparseConstSpMatDescr& matA, + GpuSparseConstSpMatDescr& matB, GpuSparseSpMatDescr& matC, + GpuSparseSpGEMMDescr& spgemmDescr); +#endif + + // In-place reordering of unsorted CSR to sorted CSR. + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-lt-t-gt-csru2csr + template + Status Csru2csr(int m, int n, int nnz, const gpusparseMatDescr_t descrA, + Scalar* csrVal, const int* csrRowPtr, int* csrColInd); + + // Converts from CSR to CSC format (equivalently, transpose). + // http://docs.nvidia.com/cuda/cusparse/index.html#cusparse-csr2cscEx + template + Status Csr2csc(int m, int n, int nnz, const Scalar* csrVal, + const int* csrRowPtr, const int* csrColInd, Scalar* cscVal, + int* cscRowInd, int* cscColPtr, + const gpusparseAction_t copyValues); + + private: + bool initialized_; + OpKernelContext* context_; // not owned. + gpuStream_t gpu_stream_; + gpusparseHandle_t* gpusparse_handle_; // not owned. + + GpuSparse(const GpuSparse&) = delete; + void operator=(const GpuSparse&) = delete; +}; + +// A wrapper class to ensure that a CUDA sparse matrix descriptor is initialized +// only once. 
For more details on the descriptor (gpusparseMatDescr_t), see: +// https://docs.nvidia.com/cuda/cusparse/index.html#cusparsematdescrt +class GpuSparseMatrixDescriptor { + public: + explicit GpuSparseMatrixDescriptor() : initialized_(false) {} + + GpuSparseMatrixDescriptor(GpuSparseMatrixDescriptor&& rhs) + : initialized_(rhs.initialized_), descr_(std::move(rhs.descr_)) { + rhs.initialized_ = false; + } + + GpuSparseMatrixDescriptor& operator=(GpuSparseMatrixDescriptor&& rhs) { + if (this == &rhs) return *this; + Release(); + initialized_ = rhs.initialized_; + descr_ = std::move(rhs.descr_); + rhs.initialized_ = false; + return *this; + } + + ~GpuSparseMatrixDescriptor() { Release(); } + + // Initializes the underlying descriptor. Will fail on the second call if + // called more than once. + Status Initialize() { + DCHECK(!initialized_); +#if GOOGLE_CUDA + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateMatDescr(&descr_)); +#elif TENSORFLOW_USE_ROCM + TF_RETURN_IF_GPUSPARSE_ERROR(se::wrap::hipsparseCreateMatDescr(&descr_)); +#endif + initialized_ = true; + return OkStatus(); + } + + gpusparseMatDescr_t& descr() { + DCHECK(initialized_); + return descr_; + } + + const gpusparseMatDescr_t& descr() const { + DCHECK(initialized_); + return descr_; + } + + private: + void Release() { + if (initialized_) { +#if GOOGLE_CUDA + cusparseDestroyMatDescr(descr_); +#elif TENSORFLOW_USE_ROCM + se::wrap::hipsparseDestroyMatDescr(descr_); +#endif + initialized_ = false; + } + } + + bool initialized_; + gpusparseMatDescr_t descr_; + + GpuSparseMatrixDescriptor(const GpuSparseMatrixDescriptor&) = delete; + void operator=(const GpuSparseMatrixDescriptor&) = delete; +}; + +#if GOOGLE_CUDA + +// A wrapper class to ensure that an unsorted/sorted CSR conversion information +// struct (csru2csrInfo_t) is initialized only once. See: +// https://docs.nvidia.com/cuda/cusparse/index.html#csru2csr +class GpuSparseCsrSortingConversionInfo { + public: + explicit GpuSparseCsrSortingConversionInfo() : initialized_(false) {} + + GpuSparseCsrSortingConversionInfo(GpuSparseCsrSortingConversionInfo&& rhs) + : initialized_(rhs.initialized_), info_(std::move(rhs.info_)) { + rhs.initialized_ = false; + } + + GpuSparseCsrSortingConversionInfo& operator=( + GpuSparseCsrSortingConversionInfo&& rhs) { + if (this == &rhs) return *this; + Release(); + initialized_ = rhs.initialized_; + info_ = std::move(rhs.info_); + rhs.initialized_ = false; + return *this; + } + + ~GpuSparseCsrSortingConversionInfo() { Release(); } + + // Initializes the underlying info. Will fail on the second call if called + // more than once. 
+ Status Initialize() { + DCHECK(!initialized_); + TF_RETURN_IF_GPUSPARSE_ERROR(cusparseCreateCsru2csrInfo(&info_)); + initialized_ = true; + return OkStatus(); + } + + csru2csrInfo_t& info() { + DCHECK(initialized_); + return info_; + } + + const csru2csrInfo_t& info() const { + DCHECK(initialized_); + return info_; + } + + private: + void Release() { + if (initialized_) { + cusparseDestroyCsru2csrInfo(info_); + initialized_ = false; + } + } + + bool initialized_; + csru2csrInfo_t info_; + + GpuSparseCsrSortingConversionInfo(const GpuSparseCsrSortingConversionInfo&) = + delete; + void operator=(const GpuSparseCsrSortingConversionInfo&) = delete; +}; + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#endif // TENSORFLOW_CORE_UTIL_CUDA_SPARSE_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/debug_data_dumper.h b/third_party/tflite-hdrs/tensorflow/core/util/debug_data_dumper.h new file mode 100644 index 00000000..44eee52c --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/debug_data_dumper.h @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_DEBUG_DATA_DUMPER_H_ +#define TENSORFLOW_CORE_UTIL_DEBUG_DATA_DUMPER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/platform/mutex.h" + +#define DEBUG_DATA_DUMPER() ::tensorflow::DebugDataDumper::Global() + +inline constexpr const char* kDebugGroupMain = "main"; +inline constexpr const char* kDebugGroupOpStacktrace = "op_stacktrace"; +inline constexpr const char* kDebugGroupGraphOptPass = "graph_opt_pass"; +inline constexpr const char* kDebugGroupBridgePhase1Clustering = + "bridge_phase1_clustering"; +inline constexpr const char* kDebugGroupRuntimeLowering = "runtime_lowering"; +inline constexpr const char* kDebugGroupBridgePhase1ExecutorExport = + "bridge_phase1_executor_export"; +inline constexpr const char* kDebugGroupBridgePhase2 = "bridge_phase2"; +inline constexpr const char* kDebugGroupDTensorMlir = "dtensor_mlir"; +inline constexpr const char* kDebugGroupDTensorGraph = "dtensor_graph"; +inline constexpr const char* kDebugGroupDTensorLayout = "dtensor_layout"; + +namespace tensorflow { + +class FunctionLibraryDefinition; +class Graph; + +//////////////////////////////////////////////////////////////////////////////// +// This class is responsible for dumping debugging data (e.g., GraphDef, MLIR). +// +// To dump GraphDef/MLIRs, take the following steps: +// * Set envvar TF_DUMP_GRAPH_PREFIX to your target dump directory. +// * Set envvar TF_DUMP_GRAPH_NAME_FILTER to '*' to dump all graphs, +// or a name filter to dump graphs with a name containing it. +// * Set envvar TF_DUMP_GRAPH_GROUPS to your dump groups (comma-separated). +// +// The dumped graphs then can be found in your target dump directory. 
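// For example, to dump only graphs whose name contains "train_step", for the
// "main" group (the values here are illustrative):
//
//   TF_DUMP_GRAPH_PREFIX=/tmp/tf_graph_dumps
//   TF_DUMP_GRAPH_NAME_FILTER=train_step
//   TF_DUMP_GRAPH_GROUPS=main
//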
+// The filename of the dump looks like this: +// ... +// +// This is what each field means: +// * : The name of your dump. +// * : The order of dumps of a specific name. +// Lower orders are executed before higher orders. +// * : The group of your dump, e.g., main. +// * : The tag of your dump, e.g., your pass name. +// +// Example dump files are: +// __inference_train_step_441.0.main.before_pre_placement_passes.pbtxt +// __inference_train_step_441.1.main.before_placer.pbtxt +// __inference_train_step_441.2.main.before_post_placement_passes.pbtxt +// __inference_train_step_441.3.main.before_graph_optimization.pbtxt +// __inference_train_step_441.4.main.after_graph_optimization.pbtxt +// __inference_train_step_441.5.main.before_post_rewrite_for_exec_passes.pbtxt +//////////////////////////////////////////////////////////////////////////////// +class DebugDataDumper { + public: + // Get the singleton instance. + static DebugDataDumper* Global(); + + // Initialize the debug data dumper. + void LoadEnvvars(); + + // Check if we should dump debug data. + // We should dump debug data only if the followings are true: + // 1. Envvar TF_DUMP_GRAPH_PREFIX is set to your target dump directory. + // 2. This condition is true if one of the followings is true. + // 2.1. TF_DUMP_GRAPH_NAME_FILTER is set to '*' + // 2.2. TF_DUMP_GRAPH_NAME_FILTER is set to a name filter + // which is a substr of name. + // 3. The group is defined in TF_DUMP_GRAPH_GROUPS. + bool ShouldDump(const std::string& name, const std::string& group) const; + + // Dump op creation callstacks, if ShouldDump returns true. + void DumpOpCreationStackTraces(const std::string& name, + const std::string& group, + const std::string& tag, const Graph* graph); + + // Dump a graph, if ShouldDump returns true. + void DumpGraph(const std::string& name, const std::string& group, + const std::string& tag, const Graph* graph, + const FunctionLibraryDefinition* func_lib_def, + bool bypass_filter = false); + + // Get the dump file basename. Dump file basenames are in this format: + // ... + // + // What each field means is explained on the class level comment. + std::string GetDumpFilename(const std::string& name, const std::string& group, + const std::string& tag); + + private: + DebugDataDumper(); + + // Get next dump id for a name. + int GetNextDumpId(const std::string& name) { + // Use a lock to make sure this is thread safe. + const mutex_lock lock(lock_); + return dump_order_ids_[name]++; + } + + // A dict to maintain the mapping from dump name to its current dump id. + absl::flat_hash_map dump_order_ids_; + + // A mutex to make sure this is thread safe. + tensorflow::mutex lock_; + + // The name filter. + std::optional name_filter_; + + // The groups filter. + std::set groups_filter_; + + // A flag indicating whether to dump wrapped graphs. + bool dump_wrapped_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_DEBUG_DATA_DUMPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/debug_events_writer.h b/third_party/tflite-hdrs/tensorflow/core/util/debug_events_writer.h new file mode 100644 index 00000000..7b104279 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/debug_events_writer.h @@ -0,0 +1,277 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_DEBUG_EVENTS_WRITER_H_ +#define TENSORFLOW_CORE_UTIL_DEBUG_EVENTS_WRITER_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/debug_event.pb.h" + +namespace tensorflow { +namespace tfdbg { + +// The set of files generated by a debugged TensorFlow program. +enum DebugEventFileType { + METADATA, + SOURCE_FILES, + STACK_FRAMES, + GRAPHS, + EXECUTION, + GRAPH_EXECUTION_TRACES, +}; + +// Helper class for DebugEventsWriter. +// This class manages the writing of data to a single TFRecord file. +// Each object of the DebugEventsWriter class below involves multiple +// TFRecord files, and hence utilizes multiple objects of this helper class. +class SingleDebugEventFileWriter { + public: + explicit SingleDebugEventFileWriter(const string& file_path); + + absl::Status Init(); + + void WriteSerializedDebugEvent(absl::string_view debug_event_str); + + absl::Status Flush(); + absl::Status Close(); + + const string FileName(); + + private: + Env* env_; + const string file_path_; + std::atomic_int_fast32_t num_outstanding_events_; + + std::unique_ptr writable_file_; + std::unique_ptr record_writer_ TF_PT_GUARDED_BY(writer_mu_); + mutex writer_mu_; +}; + +// The DebugEvents writer class. +class DebugEventsWriter { + public: +#ifndef SWIG + // Prefix of version string present in the first entry of every event file. + // Default size of each circular buffer (unit: number of DebugEvent protos). + static constexpr const int64_t kDefaultCyclicBufferSize = 1000; + + static constexpr const char* kFileNamePrefix = "tfdbg_events"; + static constexpr const char* kMetadataSuffix = "metadata"; + static constexpr const char* kSourceFilesSuffix = "source_files"; + static constexpr const char* kStackFramesSuffix = "stack_frames"; + static constexpr const char* kGraphsSuffix = "graphs"; + static constexpr const char* kExecutionSuffix = "execution"; + static constexpr const char* kGraphExecutionTracesSuffix = + "graph_execution_traces"; + + static constexpr const char* kVersionPrefix = "debug.Event:"; + static constexpr const int kCurrentFormatVersion = 1; +#endif + + // Get the DebugEventsWriter for the given dump_root. + // For a given dump_root value, it is a singleton. tfdbg event files come in + // sets of six. The singleton pattern avoids storing multiple sets in a single + // folder, which might cause confusion. + // + // If an instance of DebugEventsWriter has already been created at a + // `dump_root`, calling this method with the same `dump_root` will return + // the existing instance. + // + // Args: + // dump_root: Dump root directory. If it doesn't exist, will be created. + // tfdbg_run_id: Debugging run ID of the writer. 
+ // circular_buffer_size: Circular buffer size (in number of DebugEvent + // protos). If set to a value <=0, will abolish the circular-buffer + // behavior. + // Returns: + // A pointer to a DebugEventsWriter object: a per-dump_root singleton. + static DebugEventsWriter* GetDebugEventsWriter(const string& dump_root, + const string& tfdbg_run_id, + int64_t circular_buffer_size); + // Look up existing events writer by dump_root. + // If no DebugEventsWriter has been created at the dump_root, a non-OK + // Status will be returned. Else an OK status will be returned, with + // the pointer to the existing instance provided by reference. + static absl::Status LookUpDebugEventsWriter( + const string& dump_root, DebugEventsWriter** debug_events_writer); + ~DebugEventsWriter(); + + // Sets the debug event filenames and opens file for writing. + // All files (see the DebugEventFileType enum) share the same prefix and + // differ only in their suffixes. If not called by user, will be invoked + // automatically by a call to FileName() or any of the Write*() methods(). + // Idempotent: if the metadata file exists and is open, this is a no-op. + // If on the other hand the file was opened, but has since disappeared (e.g. + // deleted by another process), this will open a new file. + absl::Status Init(); + + // The four DebugEvent fields below are written _without_ the circular + // buffer. Source file contents are written to the *.source_files file. + // Takes ownership of source_file. + absl::Status WriteSourceFile(SourceFile* source_file); + // Stack frames are written to the *.code_locations file. + // Takes ownership of stack_frame_with_id. + absl::Status WriteStackFrameWithId(StackFrameWithId* stack_frame_with_id); + // Graph op creation events are written to the *.graphs file. + // Takes ownership of graph_op_creation. + absl::Status WriteGraphOpCreation(GraphOpCreation* graph_op_creation); + // Debugged graphs are written to the *.graphs file. + // Takes ownership of debugged_graph. + absl::Status WriteDebuggedGraph(DebuggedGraph* debugged_graph); + + // The two DebugEvent fields below are written to the circular buffer + // and saved to disk only at the FlushExecutionFiles() call. + // Execution events (eager execution of an op or a tf.function) are written + // to the *.execution file. Takes ownership of execution. + absl::Status WriteExecution(Execution* execution); + // Graph execution traces (graph-internal tensor values or their summaries) + // are written to the *.graph_execution_traces file. + // Takes ownership of graph_execution_trace. + absl::Status WriteGraphExecutionTrace( + GraphExecutionTrace* graph_execution_trace); + + // Write a graph execution trace without using a protocol buffer. + // Instead, pass the raw values related to the graph execution trace. + // Args: + // tfdbg_context_id: A unique ID for the context of interest, e.g., a + // concreted compiled tf.function that the op of interest belongs to. + // op_name: Name of the op that this graph execution trace is concerned + // with. Applicable only to the single-tensor trace case. For cases in + // which the trace concerns multiple tensors, this is an empty string. + // output_slot: Output slot index of the op that this trace is concerned + // with. + // tensor_debug_mode: An integer that represents the tensor-debug mode + // enum. tensor_value: The value of the tensor that describes the + // tensor(s) + // that this trace is concerned with. 
The semantics of this tensor value + // depends on the value of `tensor_debug_mode`. + absl::Status WriteGraphExecutionTrace(const string& tfdbg_context_id, + const string& device_name, + const string& op_name, + int32_t output_slot, + int32_t tensor_debug_mode, + const Tensor& tensor_value); + + // Writes a serialized DebugEvent to one of the debug-events files + // concerned with the non-execution events: the SOURCE_FILES, STACK_FRAMES + // and GRAPHS files. + // NOTE: Actually used in the Python binding, to avoid overhead of + // serializing and parsing protos at the language interface. + void WriteSerializedNonExecutionDebugEvent(const string& debug_event_str, + DebugEventFileType type); + + // Writes a serialized DebugEvent to one of the debug-events files + // concerned with the execution-related events: the EXECUTION and + // GRAPH_EXECUTION_TRACES files. This involves the cyclic-buffer behavior if + // circular_buffer_size is configured to be >0. + // NOTE: Actually used in the Python binding, to avoid overhead of + // serializing and parsing protos at the language interface. + void WriteSerializedExecutionDebugEvent(const string& debug_event_str, + DebugEventFileType type); + + // Given name of the device, retrieve a unique integer ID. As a side effect, + // if this is the first time this object encounters the device name, + // writes a DebuggedDevice proto to the .graphs file in the file set. + int RegisterDeviceAndGetId(const string& device_name); + + // EventWriter automatically flushes and closes on destruction, but + // this method is provided for users who want to write to disk sooner + // and/or check for success. + // FlushNonExecutionFiles() pushes outstanding DebugEvents not written + // events to the circular buffer to their respective files. + absl::Status FlushNonExecutionFiles(); + + // Writes current contents of the circular buffers to their respective + // debug event files and clears the circular buffers. + absl::Status FlushExecutionFiles(); + + // Close() calls FlushNonExecutionFiles() and FlushExecutionFiles() + // and then closes the current debug events files. + absl::Status Close(); + + private: + static std::unordered_map>* + + // Get a static map from dump-root path to DebugEventsWriter objects. + // This helps the per-dump-root singletone pattern. + GetDebugEventsWriterMap(); + + // Guards calls to the GetDebugEventsWriter() method. + static mutex factory_mu_; + + DebugEventsWriter(const string& dump_root, const string& tfdbg_run_id, + int64_t circular_buffer_size); + + // Get the path prefix. The same for all files, which differ only in the + // suffix. + string FileName(DebugEventFileType type); + + // Initialize the TFRecord writer for non-metadata file type. 
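// A minimal call-sequence sketch for the public API above; the dump root and
// run id are hypothetical, and the Execution proto is left unpopulated here.
inline absl::Status WriteOneExecutionEventSketch() {
  namespace tfdbg = tensorflow::tfdbg;
  // Per-dump_root singleton; keep a circular buffer of 1000 DebugEvent protos.
  tfdbg::DebugEventsWriter* writer =
      tfdbg::DebugEventsWriter::GetDebugEventsWriter(
          "/tmp/tfdbg_dump", /*tfdbg_run_id=*/"run_0",
          /*circular_buffer_size=*/1000);
  TF_RETURN_IF_ERROR(writer->Init());
  // WriteExecution takes ownership of the proto.
  auto* execution = new tensorflow::Execution();  // populate fields as needed
  TF_RETURN_IF_ERROR(writer->WriteExecution(execution));
  // Execution events stay in the circular buffer until explicitly flushed.
  TF_RETURN_IF_ERROR(writer->FlushExecutionFiles());
  return writer->Close();
}
// The remaining private declarations below are the implementation helpers
// behind this flow.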
+ absl::Status InitNonMetadataFile(DebugEventFileType type); + + absl::Status SerializeAndWriteDebugEvent(DebugEvent* debug_event, + DebugEventFileType type); + + void SelectWriter(DebugEventFileType type, + std::unique_ptr** writer); + const string GetSuffix(DebugEventFileType type); + string GetFileNameInternal(DebugEventFileType type); + + Env* env_; + const string dump_root_; + const string tfdbg_run_id_; + + string file_prefix_; + bool is_initialized_ TF_GUARDED_BY(initialization_mu_); + mutex initialization_mu_; + + const int64_t circular_buffer_size_; + std::deque execution_buffer_ TF_GUARDED_BY(execution_buffer_mu_); + mutex execution_buffer_mu_; + std::deque graph_execution_trace_buffer_ + TF_GUARDED_BY(graph_execution_trace_buffer_mu_); + mutex graph_execution_trace_buffer_mu_; + + absl::flat_hash_map device_name_to_id_ TF_GUARDED_BY(device_mu_); + mutex device_mu_; + + std::unique_ptr metadata_writer_; + std::unique_ptr source_files_writer_; + std::unique_ptr stack_frames_writer_; + std::unique_ptr graphs_writer_; + std::unique_ptr execution_writer_; + std::unique_ptr graph_execution_traces_writer_; + + DebugEventsWriter(const DebugEventsWriter&) = delete; + void operator=(const DebugEventsWriter&) = delete; + + friend class DebugEventsWriterTest; +}; + +} // namespace tfdbg +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_DEBUG_EVENTS_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/determinism.h b/third_party/tflite-hdrs/tensorflow/core/util/determinism.h new file mode 100644 index 00000000..136534ea --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/determinism.h @@ -0,0 +1,29 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_DETERMINISM_H_ +#define TENSORFLOW_CORE_UTIL_DETERMINISM_H_ + +#include "xla/tsl/util/determinism.h" + +namespace tensorflow { + +using tsl::EnableOpDeterminism; +using tsl::OpDeterminismRequired; +using tsl::OpOrderDeterminismRequired; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_DETERMINISM_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/device_name_utils.h b/third_party/tflite-hdrs/tensorflow/core/util/device_name_utils.h new file mode 100644 index 00000000..28b5b0f1 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/device_name_utils.h @@ -0,0 +1,27 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_ +#define TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_ + +#include "xla/tsl/util/device_name_utils.h" + +namespace tensorflow { +// NOLINTBEGIN(misc-unused-using-decls) +using tsl::DeviceNameUtils; +// NOLINTEND(misc-unused-using-decls) +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/dump_graph.h b/third_party/tflite-hdrs/tensorflow/core/util/dump_graph.h new file mode 100644 index 00000000..0d0c5575 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/dump_graph.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for +// debugging. + +#ifndef TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_ +#define TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Dumps 'graph_def' to a file, as a GraphDef text or binary proto. Returns the +// file name chosen. The format is determined by the TF_DUMP_GRAPH_FMT +// environment variable (TXT or BIN). +// +// If the TF_DUMP_GRAPH_PREFIX environment variable is "-", then instead the +// GraphDef will be logged (using the LOG() macro). +// +// Automatically picks a file name. Prefixes 'name' with the value of the +// TF_DUMP_GRAPH_PREFIX environment variable if 'dirname' is empty, and suffixes +// 'name' with '.pbtxt' or '.pb'. If a graph has already been dumped by +// this process with the same name, suffixes with "_n.pb(txt)", where 'n' is a +// sequence number. +string DumpGraphDefToFile(const string& name, GraphDef const& graph_def, + const string& dirname = ""); + +// Similar to DumpGraphDefToFile, use CostGraphDef instead of GraphDef. +string DumpCostGraphDefToFile(const string& name, CostGraphDef const& graph_def, + const string& dirname = ""); + +// Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph' +// and an optional function library 'flib_def'. Returns the file name chosen. +string DumpGraphToFile(const string& name, Graph const& graph, + const FunctionLibraryDefinition* flib_def = nullptr, + const string& dirname = ""); + +// Similar to DumpGraphDefToFile, but dumps a function as a FunctionDef text +// proto. Returns the file name chosen. +string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef, + const string& dirname = ""); + +// Similar to DumpGraphDefToFile, but dumps a proto of any type. 
Returns the +// file name chosen. +string DumpProtoToFile(const string& name, + tensorflow::protobuf::Message const& proto, + const string& dirname = ""); + +// Sets a custom Graph dumper. If set, this dumper will be used to dump graphs +// instead via DumpGraphToFile. As the custom dumper may not produce protobufs, +// allow specifying a file suffix/extension too. +void SetGraphDumper( + std::function + dumper, + string suffix = ".pbtxt"); + +// Dump data to a file. +// This function will create a WritableFile and pass it to the dumper. +// The dumper callback will be responsible for writing data to the file. +string DumpToFile(const string& name, const string& dirname, + const string& suffix, absl::string_view type_name, + std::function dumper); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_DUMP_GRAPH_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/einsum_op_util.h b/third_party/tflite-hdrs/tensorflow/core/util/einsum_op_util.h new file mode 100644 index 00000000..6155b8a0 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/einsum_op_util.h @@ -0,0 +1,72 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ +#define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace tensorflow { + +using Labels = absl::InlinedVector; +using OperandLabels = absl::InlinedVector; +using LabelCounts = absl::InlinedVector; +using OperandLabelCounts = absl::InlinedVector; + +// Dummy axis label used to denote an ellipsis in an input or output subscript. +constexpr int kEllipsisLabel = -1; + +// Each dimension is categorized into exactly one of five types based on +// whether its corresponding label is present in the input and/or the output +// subscripts. +enum EinsumDimensionType { + // Batch dimensions are those present in two inputs as well as the output. + // They are part of the batch dimensions during Tensor contraction. Such + // dimensions may be broadcasting dimensions (those mapping to ellipsis) + // or explicit batch dimensions corresponding to named axis labels. + kBroadcasting = 0, + kBatch = 1, + // Free dimensions are present in exactly one of the inputs, and also the + // output. These are non-contracted axes in the Tensor contraction. + kFree = 2, + // Contract dimensions are present in two inputs, but not the output. These + // dimensions are contracted in Tensor contraction. + kContract = 3, + // Reduce dimensions are present in exactly one input; and not in the output + // and are summed over prior to Tensor contraction. + kReduce = 4, +}; + +// Parses and validates an einsum equation in explicit form. 
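// Concretely, given the categories above:
//   "bij,bjk->bik"       : 'b' is kBatch, 'i' and 'k' are kFree,
//                          'j' is kContract.
//   "...ij,...jk->...ik" : the axes covered by "..." are kBroadcasting.
//   "ab,bc->c"           : 'a' is kReduce (summed out before the
//                          contraction), 'b' is kContract, 'c' is kFree.
// The helpers below parse and validate an equation into exactly this
// labeling: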
+absl::Status ValidateEinsumEquation( + const string& equation, absl::InlinedVector* input_subscripts, + string* output_subscript); + +// Parses and validates the equation and the input shapes. Single character +// labels are integerized and we populate input and output label subscripts +// and corresponding counts. Also create the mapping from (named) labels to +// their EinsumDimensionType. +absl::Status ParseEinsumEquation( + const string& equation, OperandLabels* input_labels, Labels* output_labels, + std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + absl::InlinedVector* input_has_ellipsis, + bool* output_has_ellipsis); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/env_var.h b/third_party/tflite-hdrs/tensorflow/core/util/env_var.h new file mode 100644 index 00000000..faad6153 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/env_var.h @@ -0,0 +1,34 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_ENV_VAR_H_ +#define TENSORFLOW_CORE_UTIL_ENV_VAR_H_ + +#include "xla/tsl/util/env_var.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using tsl::ReadBoolFromEnvVar; +using tsl::ReadFloatFromEnvVar; +using tsl::ReadInt64FromEnvVar; +using tsl::ReadStringFromEnvVar; +using tsl::ReadStringsFromEnvVar; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_ENV_VAR_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/equal_graph_def.h b/third_party/tflite-hdrs/tensorflow/core/util/equal_graph_def.h new file mode 100644 index 00000000..9803b2db --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/equal_graph_def.h @@ -0,0 +1,100 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_EQUAL_GRAPH_DEF_H_ +#define TENSORFLOW_CORE_UTIL_EQUAL_GRAPH_DEF_H_ + +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +class GraphDef; +class NodeDef; + +struct EqualGraphDefOptions { + // Should internal attributes (attribute names that start with '_') be + // ignored? + bool ignore_internal_attrs = true; +}; + +// Determines if actual and expected are equal, ignoring versions and ordering +// of nodes, attrs, and control inputs. If the GraphDefs are different and +// diff != nullptr, *diff is set to an explanation of the difference. Note that +// we use node names to match up nodes between the graphs, and so the naming of +// nodes must be consistent. +bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, + string* diff, const EqualGraphDefOptions& options = {}); + +// Returns a hash of `gdef` that is consistent with EqualGraphDef. In other +// words, if two graph defs compare equal according to EqualGraphDef, +// GraphDefHash will return the same value for both of them when called +// with the same `options` that was used in the call to EqualGraphDef. +// Similarly to protobuf deterministic serialization, hash value is +// guaranteed to be stable only for a given binary. In particular, one should +// probably not persist the returned value. +uint64 GraphDefHash(const GraphDef& gdef, + const EqualGraphDefOptions& options = {}); + +// Determines if actual and expected are equal, ignoring: ordering of +// attrs, internal attributes (if set in `options`), and control inputs. +// +// If the NodeDefs are different and +// diff != nullptr, *diff is set to an explanation of the difference. +bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, + const EqualGraphDefOptions& options = {}); + +// Returns a hash of `ndef` that is consistent with EqualNodeDef. In other +// words, if two node defs compare equal according to EqualNodeDef, NodeDefHash +// will return the same value for both of them when called with the same +// `options` that was used in the call to EqualNodeDef. +// Similarly to protobuf deterministic serialization, hash value is +// guaranteed to be stable only for a given binary. In particular, one should +// probably not persist the returned value. +uint64 NodeDefHash(const NodeDef& ndef, + const EqualGraphDefOptions& options = {}); + +// Determines if actual and expected are equal, ignoring ordering. If they're +// different and diff != nullptr, *diff is set to an explanation of the +// difference. +bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField& actual, + const protobuf::RepeatedPtrField& expected, + string* diff, + const EqualGraphDefOptions& options = {}); + +// Returns a hash of `ndefs` that is consistent with EqualRepeatedNodeDef. +// In other words, if two ndefs compare equal according to +// EqualRepeatedNodeDef, RepeatedNodeDefHash will return the same value for +// both of them when called with the same `options` that was used in +// the call to EqualRepeatedNodeDef. +// Similarly to protobuf deterministic serialization, hash value is +// guaranteed to be stable only for a given binary. In particular, one should +// probably not persist the returned value. 
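// A test-style sketch of how the equality helpers and hashes above are
// typically used together (graph construction elided):
//
//   GraphDef expected = ...;
//   GraphDef actual = ...;
//   string diff;
//   EXPECT_TRUE(EqualGraphDef(actual, expected, &diff)) << diff;
//   EXPECT_EQ(GraphDefHash(expected), GraphDefHash(actual));
//
// The TF_EXPECT_GRAPH_EQ macro defined at the end of this header wraps the
// EqualGraphDef check for use in tests.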
+uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField& ndefs, + const EqualGraphDefOptions& options = {}); + +#define TF_EXPECT_GRAPH_EQ(expected, actual) \ + do { \ + string diff; \ + EXPECT_TRUE(EqualGraphDef(actual, expected, &diff)) \ + << diff << "\nExpected:\n" \ + << SummarizeGraphDef(expected) << "\nActual:\n" \ + << SummarizeGraphDef(actual); \ + } while (false) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_EQUAL_GRAPH_DEF_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/events_writer.h b/third_party/tflite-hdrs/tensorflow/core/util/events_writer.h new file mode 100644 index 00000000..a06eac7d --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/events_writer.h @@ -0,0 +1,103 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_ +#define TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_ + +#include +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/record_writer.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/event.pb.h" + +namespace tensorflow { + +class EventsWriter { + public: +#ifndef SWIG + // Prefix of version string present in the first entry of every event file. + static constexpr const char* kVersionPrefix = "brain.Event:"; + static constexpr const int kCurrentVersion = 2; + static constexpr const char* kWriterSourceMetadata = + "tensorflow.core.util.events_writer"; +#endif + + // Events files typically have a name of the form + // '/some/file/path/my.file.out.events.[timestamp].[hostname][suffix]' + // To create and EventWriter, the user should provide file_prefix = + // '/some/file/path/my.file' + // The EventsWriter will append '.out.events.[timestamp].[hostname][suffix]' + // to the ultimate filename once Init() is called. + // Note that it is not recommended to simultaneously have two + // EventWriters writing to the same file_prefix. + explicit EventsWriter(const std::string& file_prefix); + ~EventsWriter(); + + // Sets the event file filename and opens file for writing. If not called by + // user, will be invoked automatically by a call to FileName() or Write*(). + // Returns false if the file could not be opened. Idempotent: if file exists + // and is open this is a no-op. If on the other hand the file was opened, + // but has since disappeared (e.g. deleted by another process), this will open + // a new file with a new timestamp in its filename. + absl::Status Init(); + absl::Status InitWithSuffix(const std::string& suffix); + + // Returns the filename for the current events file: + // filename_ = [file_prefix_].out.events.[timestamp].[hostname][suffix] + std::string FileName(); + + // Append "event" to the file. The "tensorflow::" part is for swig happiness. 
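// A minimal write path, as a sketch (the file prefix and event fields are
// illustrative):
//
//   EventsWriter writer("/tmp/logs/my.file");
//   Event event;
//   event.set_wall_time(Env::Default()->NowMicros() / 1.0e6);
//   event.set_step(42);
//   writer.WriteEvent(event);
//   TF_CHECK_OK(writer.Flush());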
+ void WriteEvent(const tensorflow::Event& event); + + // Append "event_str", a serialized Event, to the file. + // Note that this function does NOT check that de-serializing event_str + // results in a valid Event proto. The tensorflow:: bit makes SWIG happy. + void WriteSerializedEvent(absl::string_view event_str); + + // EventWriter automatically flushes and closes on destruction, but + // these two methods are provided for users who want to write to disk sooner + // and/or check for success. + // Flush() pushes outstanding events to disk. Returns false if the + // events file could not be created, or if the file exists but could not + // be written too. + // Close() calls Flush() and then closes the current events file. + // Returns true only if both the flush and the closure were successful. + absl::Status Flush(); + absl::Status Close(); + + private: + absl::Status FileStillExists(); // OK if event_file_path_ exists. + absl::Status InitIfNeeded(); + + Env* env_; + const std::string file_prefix_; + std::string file_suffix_; + std::string filename_; + std::unique_ptr recordio_file_; + std::unique_ptr recordio_writer_; + int num_outstanding_events_; +#ifndef SWIG + EventsWriter(const EventsWriter&) = delete; + void operator=(const EventsWriter&) = delete; +#endif +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/example_proto_fast_parsing.h b/third_party/tflite-hdrs/tensorflow/core/util/example_proto_fast_parsing.h new file mode 100644 index 00000000..6ba6d89a --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/example_proto_fast_parsing.h @@ -0,0 +1,172 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ +#define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { +namespace example { + +// FastParseExampleConfig defines how to parse features in Example. +// Each sub-config is responsible for one feature identified with feature_name. +// FastParseExampleConfig can't have two sub-configs with the same feature_name. +// dtype identifies the type of output vector and the kind of Feature expected +// in Example. 
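// For example (a sketch; feature names, shapes and defaults are illustrative),
// a config with one fixed-length float feature "age" of shape [1] and one
// variable-length string feature "tags" could be built as:
//
//   FastParseExampleConfig config;
//   config.dense.emplace_back("age", DT_FLOAT, PartialTensorShape({1}),
//                             Tensor(),  // empty default => feature required
//                             /*variable_length=*/false,
//                             /*elements_per_stride=*/1);
//   config.sparse.emplace_back("tags", DT_STRING);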
+struct FastParseExampleConfig { + struct Dense { + Dense(absl::string_view feature_name, DataType dtype, + PartialTensorShape shape, Tensor default_value, bool variable_length, + std::size_t elements_per_stride) + : feature_name(feature_name), // TODO(mrry): Switch to preallocated + // tstring when this is available. + dtype(dtype), + shape(std::move(shape)), + default_value(std::move(default_value)), + variable_length(variable_length), + elements_per_stride(elements_per_stride) {} + Dense() = default; + + tstring feature_name; + DataType dtype; + // These 2 fields correspond exactly to dense_shapes and dense_defaults in + // ParseExample op. + // Documentation is available in: tensorflow/core/ops/parsing_ops.cc + PartialTensorShape shape; + Tensor default_value; + bool variable_length; + std::size_t elements_per_stride; + }; + + struct Sparse { + Sparse(absl::string_view feature_name, DataType dtype) + : feature_name(feature_name), // TODO(mrry): Switch to preallocated + // tstring when this is available. + dtype(dtype) {} + Sparse() = default; + + tstring feature_name; + DataType dtype; + }; + + struct Ragged { + Ragged(absl::string_view feature_name, DataType dtype, + DataType splits_dtype) + : feature_name(feature_name), // TODO(mrry): Switch to preallocated + // tstring when this is available. + dtype(dtype), + splits_dtype(splits_dtype) {} + Ragged() = default; + + tstring feature_name; + DataType dtype; + DataType splits_dtype; + }; + + std::vector dense; + std::vector sparse; + std::vector ragged; + + // If `true`, `Result::feature_stats` will contain one + // `PerExampleFeatureStats` for each serialized example in the input. + bool collect_feature_stats = false; +}; + +// Statistics about the features in each example passed to +// `FastParse[Single]Example()`. +// +// TODO(b/111553342): The gathered statistics currently have two limitations: +// * Feature names that appear more than once will be counted multiple times. +// * The feature values count only represents the counts for features that were +// requested in the `FastParseExampleConfig`. +// These could be addressed with additional work at runtime. +struct PerExampleFeatureStats { + // The number of feature names in an example. + size_t features_count = 0; + + // The sum of the number of values in each feature that is parsed. + size_t feature_values_count = 0; +}; + +// This is exactly the output of TF's ParseExample Op. +// Documentation is available in: tensorflow/core/ops/parsing_ops.cc +struct Result { + std::vector sparse_indices; + std::vector sparse_values; + std::vector sparse_shapes; + std::vector dense_values; + std::vector ragged_values; + std::vector ragged_splits; + std::vector ragged_outer_splits; // For SequenceExamples + + // This vector will be populated with one element per example if + // `FastParseExampleConfig::collect_feature_stats` is set to `true`. + std::vector feature_stats; +}; + +// Parses a batch of serialized Example protos and converts them into result +// according to given config. +// Given example names have to either be empty or the same size as serialized. +// example_names are used only for error messages. +absl::Status FastParseExample(const FastParseExampleConfig& config, + absl::Span serialized, + absl::Span example_names, + thread::ThreadPool* thread_pool, Result* result); + +// TODO(mrry): Move the hash table construction into the config object. 
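// Illustrative call (a sketch; the inputs are placeholders):
//
//   std::vector<tstring> serialized = ...;  // serialized Example protos
//   thread::ThreadPool* pool = ...;         // worker pool used for parsing
//   Result result;
//   TF_RETURN_IF_ERROR(FastParseExample(config, serialized,
//                                       /*example_names=*/{}, pool, &result));
//   // Outputs in `result` follow the order of config.dense / config.sparse /
//   // config.ragged.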
+typedef FastParseExampleConfig FastParseSingleExampleConfig; + +absl::Status FastParseSingleExample(const FastParseSingleExampleConfig& config, + absl::string_view serialized, + Result* result); + +// Parses a batch of serialized SequenceExample protos and converts them into +// result according to given config. +// Given example names have to either be empty or the same size as serialized. +// example_names are used only for error messages. +// (If batch=true, then this parses a single SequenceExample.) +absl::Status FastParseSequenceExample( + const example::FastParseExampleConfig& context_config, + const example::FastParseExampleConfig& sequence_config, + absl::Span serialized, + absl::Span example_names, thread::ThreadPool* thread_pool, + example::Result* context_result, example::Result* sequence_result, + std::vector* dense_feature_lengths, bool is_batch = true); + +// This function parses serialized Example and populates given example. +// It uses the same specialized parser as FastParseExample which is efficient. +// But then constructs Example which is relatively slow. +// It is exported here as a convenient API to test parser part separately. +bool TestFastParse(const string& serialized, Example* example); + +} // namespace example +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_FAST_PARSING_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/example_proto_helper.h b/third_party/tflite-hdrs/tensorflow/core/util/example_proto_helper.h new file mode 100644 index 00000000..801aae37 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/example_proto_helper.h @@ -0,0 +1,369 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ +#define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ + +#include +#include +#include + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +// This is a set of helper methods that will make it possible to share +// tensorflow::Example proto Tensor conversion code inside the ExampleParserOp +// OpKernel as well as in external code. +namespace tensorflow { + +// "Dense" feature configuration. +struct FixedLenFeature { + string key; + DataType dtype; + TensorShape shape; + Tensor default_value; + string values_output_tensor_name; +}; + +// "Sparse" feature configuration. 
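// For example (a sketch; the key and output tensor names are illustrative), a
// variable-length string feature emitted as a SparseTensor could be
// configured as:
//
//   VarLenFeature tags;
//   tags.key = "tags";
//   tags.dtype = DT_STRING;
//   tags.values_output_tensor_name = "tags_values";
//   tags.indices_output_tensor_name = "tags_indices";
//   tags.shapes_output_tensor_name = "tags_shape";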
+struct VarLenFeature { + string key; + DataType dtype; + string values_output_tensor_name; + string indices_output_tensor_name; + string shapes_output_tensor_name; +}; + +// Given a single tensorflow::Example, with an optional example name +// at a particular index within a batch, and dense and sparse feature +// configurations from fixed_len_features, var_len_features, this method +// updates the dense value tensor and the sparse values temporary vector +// of tensors. The indexing of the output vectors correspond 1:1 to the +// indexing of the feature configuration vectors. +// +// The fixed_len_features and var_len_features maps are assume to be +// have disjoint key fields from the Feature map in the tensorflow.Example +// proto. +// +// For each sparse feature, the sparse values temporary vector holds a +// tensor for each Example. Each tensor is either empty or filled, depending +// on if the sparse feature value is set for the Example. This +// temporary structure is needed because we need to know the total number +// of filled elements in the batch to get the proper final sparse tensor +// shapes allocated. After the entire batch is processed, +// GetSparseTensorShape can be used to calculate the final shapes and +// CopyIntoSparseTensor can be used to copy from the temporary vector +// into the final allocated tensors. +absl::Status SingleExampleProtoToTensors( + const Example& example, const string& name, int batch_index, + const std::vector& fixed_len_features, + const std::vector& var_len_features, + std::vector* output_dense_values_tensor, + std::vector>* output_sparse_values_tmp); + +// The shape of the indices and values tensors associated with a SparseTensor +// are dependent on the contents of the batch. +struct VarLenFeatureBatchShapes { + TensorShape indices_shape; + TensorShape values_shape; + int max_num_features; +}; + +// Get the shape of the sparse values and indices tensors for the batch, +// given how many of the tensors in the temporary sparse values vector +// are actually filled. +absl::Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, + const std::vector& sparse_values_tmp, + int batch_size, + VarLenFeatureBatchShapes* output_shapes); + +// A method to convert a batch of tensorflow::Example protos into output +// tensors. This method is useful if there already is a batch of deserialized +// Example protos in memory (such as a serving use-case) and we do not wish +// to incur an extraneous serialize/deserialize. It is intended +// as an outside of OpKernel compatible replacement for the functionality of +// ExampleParserOp. In a serving setting, this method could be used to produce +// a feed_dict of Tensors that could bypass the ExampleParserOp. +// +// Note that unlike SingleExampleProtoToTensors, output tensors are +// allocated using a provided Allocator within this method. +absl::Status BatchExampleProtoToTensors( + const std::vector& examples, + const std::vector& names, + const std::vector& fixed_len_features, + const std::vector& var_len_features, Allocator* allocator, + std::vector* output_dense_values_tensor, + std::vector* output_sparse_indices_tensor, + std::vector* output_sparse_values_tensor, + std::vector* output_sparse_shapes_tensor); + +// Check that the given dtype is one that is compatible with +// tensorflow::Example protocol buffer feature values. +absl::Status CheckValidType(const DataType& dtype); + +// Check that the provided Feature proto message's oneof value +// matches that of the provided dtype. 
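// For example, a Feature whose oneof holds a float_list matches DT_FLOAT, an
// int64_list matches DT_INT64, and a bytes_list matches DT_STRING; any other
// combination reports no match via *match.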
+absl::Status CheckTypesMatch(const Feature& feature, const DataType& dtype, + bool* match); + +// For a single Example, copy a dense feature value into an output +// dense value tensor Out at the provided out_index offset. +absl::Status FeatureDenseCopy(std::size_t out_index, const string& name, + const string& key, const DataType& dtype, + const TensorShape& shape, const Feature& feature, + Tensor* out); + +// Copy the value a provided Tensor into an output dense_value tensor Out +// at the provided out_index offset. +void RowDenseCopy(const std::size_t& out_index, const DataType& dtype, + const Tensor& in, Tensor* out); + +// For a single Example, and given sparse feature return a temporary output +// Tensor suitable for being collected in the temporary sparse value vector. +Tensor FeatureSparseCopy(std::size_t batch, const string& key, + const DataType& dtype, const Feature& feature); + +// Copy a temporary Tensor into the final sparse indices and values +// tensor at a given batch index and element offset. This method +// assumes that the indices/values Tensors have been properly allocated +// for the batch. +int64_t CopyIntoSparseTensor(const Tensor& in, int batch, int64_t offset, + Tensor* indices, Tensor* values); + +// Check that each dense_shape has known rank and inner dimensions; and +// update variable_length (whether the outer dimension is None) and +// elements_per_stride for each denes_shape. +absl::Status GetDenseShapes(const std::vector& dense_shapes, + std::vector* variable_length, + std::vector* elements_per_stride); + +// Parses the attributes passed to ParseExample. +// REQUIRES: Init must be called after construction. +struct ParseExampleAttrs { + public: + template + absl::Status Init(ContextType* ctx, int op_version = 1) { + TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types", &sparse_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Tdense", &dense_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("dense_shapes", &dense_shapes)); + TF_RETURN_IF_ERROR( + GetDenseShapes(dense_shapes, &variable_length, &elements_per_stride)); + switch (op_version) { + case 1: + TF_RETURN_IF_ERROR(ctx->GetAttr("Nsparse", &num_sparse)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Ndense", &num_dense)); + break; + case 2: + TF_RETURN_IF_ERROR( + ctx->GetAttr("ragged_value_types", &ragged_value_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("num_sparse", &num_sparse)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("ragged_split_types", &ragged_split_types)); + break; + default: + return errors::InvalidArgument("Unexpected op_version", op_version); + } + return FinishInit(op_version); + } + + int64_t num_sparse; + int64_t num_dense; + int64_t num_ragged; + std::vector sparse_types; + std::vector dense_types; + std::vector ragged_value_types; + std::vector ragged_split_types; + std::vector dense_shapes; + std::vector variable_length; + std::vector elements_per_stride; + + private: + absl::Status FinishInit( + int op_version); // for context-independent parts of Init. +}; + +// Parses the attributes passed to ParseSingleExample. +// REQUIRES: Init must be called after construction. 
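// Typical use, as a sketch, inside the constructor of a parsing kernel (the
// kernel class name here is illustrative):
//
//   explicit MyParseSingleExampleOp(OpKernelConstruction* ctx)
//       : OpKernel(ctx) {
//     OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
//   }
//   ...
//   ParseSingleExampleAttrs attrs_;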
+struct ParseSingleExampleAttrs { + public: + template + absl::Status Init(ContextType* ctx) { + TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_keys", &sparse_keys)); + TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types", &sparse_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("dense_keys", &dense_keys)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Tdense", &dense_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("dense_shapes", &dense_shapes)); + + int num_sparse; + TF_RETURN_IF_ERROR(ctx->GetAttr("num_sparse", &num_sparse)); + if (num_sparse != sparse_keys.size() || num_sparse != sparse_types.size()) { + return errors::InvalidArgument( + "num_sparse (", num_sparse, ") must match the size of sparse_keys (", + sparse_keys.size(), ") and sparse_types (", sparse_types.size(), ")"); + } + + TF_RETURN_IF_ERROR( + GetDenseShapes(dense_shapes, &variable_length, &elements_per_stride)); + return FinishInit(); + } + + std::vector sparse_keys; + std::vector sparse_types; + std::vector dense_keys; + std::vector dense_types; + std::vector dense_shapes; + std::vector variable_length; + std::vector elements_per_stride; + + private: + absl::Status FinishInit(); // for context-independent parts of Init. +}; + +// Parses the attributes passed to ParseSequenceExample. +// REQUIRES: Init must be called after construction. +struct ParseSequenceExampleAttrs { + public: + template + absl::Status Init(ContextType* ctx, int op_version = 1) { + switch (op_version) { + case 1: { + std::vector missing_empty_vector; + TF_RETURN_IF_ERROR(ctx->GetAttr( + "feature_list_dense_missing_assumed_empty", &missing_empty_vector)); + for (const string& feature : missing_empty_vector) { + feature_list_dense_missing_assumed_empty.insert(feature); + } + } + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_sparse_keys", &context_sparse_keys)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_dense_keys", &context_dense_keys)); + TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_sparse_keys", + &feature_list_sparse_keys)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_keys", &feature_list_dense_keys)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense)); + break; + case 2: + TF_RETURN_IF_ERROR(ctx->GetAttr("context_ragged_value_types", + &context_ragged_value_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("context_ragged_split_types", + &context_ragged_split_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_ragged_value_types", + &feature_list_ragged_value_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_ragged_split_types", + &feature_list_ragged_split_types)); + break; + default: + return errors::InvalidArgument("Unexpected op_version", op_version); + } + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_sparse_types", &context_sparse_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_dense_shapes", &context_dense_shapes)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes)); + return FinishInit(op_version); + } + + std::unordered_set feature_list_dense_missing_assumed_empty; + int64_t 
num_context_sparse; + int64_t num_context_dense; + int64_t num_context_ragged; + int64_t num_feature_list_sparse; + int64_t num_feature_list_dense; + int64_t num_feature_list_ragged; + std::vector context_sparse_keys; + std::vector context_dense_keys; + std::vector feature_list_sparse_keys; + std::vector feature_list_dense_keys; + std::vector context_sparse_types; + std::vector context_dense_types; + std::vector context_dense_shapes; + std::vector feature_list_sparse_types; + std::vector feature_list_dense_types; + std::vector feature_list_dense_shapes; + std::vector context_ragged_value_types; + std::vector context_ragged_split_types; + std::vector feature_list_ragged_value_types; + std::vector feature_list_ragged_split_types; + + private: + absl::Status FinishInit( + int op_version); // for context-independent parts of Init. +}; + +// Parses the attributes passed to ParseSingleSequenceExample. +// REQUIRES: Init must be called after construction. +struct ParseSingleSequenceExampleAttrs { + public: + template + absl::Status Init(ContextType* ctx) { + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_sparse_types", &context_sparse_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_dense_shapes", &context_dense_shapes)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes)); + return FinishInit(); + } + + int64_t num_context_sparse; + int64_t num_context_dense; + int64_t num_feature_list_sparse; + int64_t num_feature_list_dense; + std::vector context_sparse_types; + std::vector context_dense_types; + std::vector context_dense_shapes; + std::vector feature_list_sparse_types; + std::vector feature_list_dense_types; + std::vector feature_list_dense_shapes; + + private: + absl::Status FinishInit(); // for context-independent parts of Init. +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/exec_on_stall.h b/third_party/tflite-hdrs/tensorflow/core/util/exec_on_stall.h new file mode 100644 index 00000000..d4a6c552 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/exec_on_stall.h @@ -0,0 +1,89 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_ +#define TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_ + +#include + +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +// An object that executes a particular function only if it +// is not deleted within the allotted number of seconds. +// +// This can be useful in diagnosing deadlocks, stalls and memory leaks +// without logging too aggressively. +class ExecuteOnStall { + public: + // delay_secs: If the object still exists after this many seconds, + // execute f. + // f: The function to be executed, for example a detailed log of the + // the state of an object to which this is attached. + // poll_microseconds: The spawned thread will wake and test whether + // the destructor has been invoked this frequently. + ExecuteOnStall(int delay_secs, std::function f, + int32_t poll_microseconds = 100) + : disabled_(false), + joined_(false), + env_(Env::Default()), + f_(f), + poll_microseconds_(poll_microseconds) { + deadline_ = env_->NowMicros() + 1000000 * delay_secs; + env_->SchedClosure([this]() { + while (env_->NowMicros() < deadline_) { + { + mutex_lock l(mu_); + if (disabled_) { + break; + } + } + env_->SleepForMicroseconds(poll_microseconds_); + } + { + mutex_lock l(mu_); + if (!disabled_) { + f_(); + } + joined_ = true; + cond_var_.notify_all(); + } + }); + } + + ~ExecuteOnStall() { + // Wait for spawned thread to terminate. + mutex_lock l(mu_); + disabled_ = true; + if (!joined_) { + cond_var_.wait(l); + } + } + + private: + mutex mu_; + condition_variable cond_var_; + bool disabled_ TF_GUARDED_BY(mu_); + bool joined_ TF_GUARDED_BY(mu_); + Env* env_; + std::function f_; + int64_t deadline_; + int32 poll_microseconds_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/fake_clock_env.h b/third_party/tflite-hdrs/tensorflow/core/util/fake_clock_env.h new file mode 100644 index 00000000..2ded1708 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/fake_clock_env.h @@ -0,0 +1,58 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_FAKE_CLOCK_ENV_H_ +#define TENSORFLOW_CORE_UTIL_FAKE_CLOCK_ENV_H_ + +#include +#include +#include + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// An Env implementation with a fake clock for NowMicros(). +// The clock doesn't advance on its own. It advances +// via an explicit AdvanceByMicroseconds() method. 
All other Env virtual methods +// pass through to a wrapped Env. +class FakeClockEnv : public EnvWrapper { + public: + explicit FakeClockEnv(Env* wrapped); + ~FakeClockEnv() override = default; + + // Advance the clock by a certain number of microseconds. + void AdvanceByMicroseconds(int64_t micros); + + // Returns the current time of FakeClockEnv in microseconds. + uint64 NowMicros() const override; + + private: + mutable mutex mu_; + uint64 current_time_ TF_GUARDED_BY(mu_) = 0; + + FakeClockEnv(const FakeClockEnv&) = delete; + void operator=(const FakeClockEnv&) = delete; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_FAKE_CLOCK_ENV_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/gpu_cuda_alias.h b/third_party/tflite-hdrs/tensorflow/core/util/gpu_cuda_alias.h new file mode 100644 index 00000000..0a15d15e --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/gpu_cuda_alias.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_GPU_CUDA_ALIAS_H_ +#define TENSORFLOW_CORE_UTIL_GPU_CUDA_ALIAS_H_ + +// Several forwarding macros are defined in this file to serve for backward +// compatibility usage as we migrating from CUDA prefixed function to GPU +// prefixed functions. Both Cuda and ROCm can unify under the new GPU prefix +// naming scheme. In the migration period, we provide equivalent CUDA* and GPU* +// function. Over time, all CUDA* functions will be deprecated. + +namespace tensorflow { + +// CREATE_CUDA_HOST_FUNCTION_ALIAS forward the host function to its CUDA Alias. +#ifndef TENSORFLOW_USE_ROCM +#define CREATE_CUDA_HOST_FUNCTION_ALIAS(func, cuda_alias) \ + template \ + auto cuda_alias(Args&&... args) \ + ->decltype(func(std::forward(args)...)) { \ + return func(std::forward(args)...); \ + } +#else +#define CREATE_CUDA_HOST_FUNCTION_ALIAS(func, cuda_alias) +#endif + +// CREATE_CUDA_DEVICE_FUNCTION_ALIAS forward the device function to its CUDA +// Alias. +#ifndef TENSORFLOW_USE_ROCM +#define CREATE_CUDA_DEVICE_FUNCTION_ALIAS(func, cuda_alias) \ + template \ + __device__ auto cuda_alias(Args&&... args) \ + ->decltype(func(std::forward(args)...)) { \ + return func(std::forward(args)...); \ + } +#else +#define CREATE_CUDA_DEVICE_FUNCTION_ALIAS(func, cuda_alias) +#endif + +// CREATE_CUDA_TYPE_ALIAS forward the type to its CUDA Alias. 
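// For example (illustrative names), given a host helper
//   absl::Status GpuDoThing(int n);
// the invocation
//   CREATE_CUDA_HOST_FUNCTION_ALIAS(GpuDoThing, CudaDoThing);
// generates a forwarding CudaDoThing(...) that simply calls GpuDoThing(...),
// and
//   CREATE_CUDA_TYPE_ALIAS(GpuLaunchConfig, CudaLaunchConfig);
// expands to "using CudaLaunchConfig = GpuLaunchConfig;". In ROCm builds all
// three macros expand to nothing, so only the Gpu-prefixed names exist.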
+#ifndef TENSORFLOW_USE_ROCM +#define CREATE_CUDA_TYPE_ALIAS(type, cuda_alias) using cuda_alias = type; +#else +#define CREATE_CUDA_TYPE_ALIAS(type, cuda_alias) +#endif +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_GPU_CUDA_ALIAS_H_ diff --git a/third_party/tflite-hdrs/tensorflow/core/util/gpu_device_functions.h b/third_party/tflite-hdrs/tensorflow/core/util/gpu_device_functions.h new file mode 100644 index 00000000..bb9ff8c7 --- /dev/null +++ b/third_party/tflite-hdrs/tensorflow/core/util/gpu_device_functions.h @@ -0,0 +1,1002 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_ +#define TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_ + +/** + * Wrappers and helpers for CUDA device code. + * + * Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide + * backwards compatibility, see go/volta-porting for details. + * Provides atomic operations on types that aren't natively supported. + * Defines a number of macros and types providing a shared interface + * to either CUDA or ROCm APIs, depending on the build. + */ + +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM + +#include +#include + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#else +#include "rocm/include/hip/hip_complex.h" +#endif + +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_cuda_alias.h" + +#if GOOGLE_CUDA +using gpuStream_t = cudaStream_t; +using gpuEvent_t = cudaEvent_t; +#define gpuEventRecord cudaEventRecord +#define gpuEventSynchronize cudaEventSynchronize +#define gpuEventDestroy cudaEventDestroy +#define gpuEventCreate cudaEventCreate +#define gpuEventCreateWithFlags cudaEventCreateWithFlags +#define gpuEventDisableTiming cudaEventDisableTiming +#define gpuDeviceSynchronize cudaDeviceSynchronize +#define gpuFree cudaFree +#elif TENSORFLOW_USE_ROCM +using gpuStream_t = hipStream_t; +using gpuEvent_t = hipEvent_t; +using cudaError = int; +using cudaError_t = int; +#define cudaSuccess 0 +#define cudaGetLastError hipGetLastError +#define gpuEventRecord hipEventRecord +#define gpuEventDestroy hipEventDestroy +#define gpuEventSynchronize hipEventSynchronize +#define gpuEventCreate hipEventCreate +#define gpuEventCreateWithFlags hipEventCreateWithFlags +#define gpuEventDisableTiming hipEventDisableTiming +#define gpuDeviceSynchronize hipDeviceSynchronize +#define gpuFree hipFree +static std::string cudaGetErrorString(int err) { return std::to_string(err); } +#endif + +#define TF_RETURN_IF_CUDA_ERROR(result) \ + do { \ + cudaError_t error(result); \ + if (!TF_PREDICT_TRUE(error == cudaSuccess)) { \ + return absl::InternalError( \ + absl::StrCat("Cuda call failed with ", cudaGetErrorString(error))); \ + } \ + } while (0) + +#define TF_OP_REQUIRES_CUDA_SUCCESS(context, result) \ + do { \ + cudaError_t error(result); \ + if 
(!TF_PREDICT_TRUE(error == cudaSuccess)) { \ + context->SetStatus(absl::InternalError( \ + absl::StrCat("Cuda call failed with", cudaGetErrorString(error)))); \ + return; \ + } \ + } while (0) + +namespace tensorflow { +// According to HIP developer guide at +// https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#assert +// assert is not supported by HIP. While we are waiting for assert support in +// hip kernels, the assert call should be macroed to NOP so that it does not +// block us from creating a debug build +#if TENSORFLOW_USE_ROCM +#undef assert +#define assert(x) \ + {} +#endif + +namespace detail { + +// Helper for range-based for loop using 'delta' increments. +// Usage: see GpuGridRange?() functions below. +template +class GpuGridRange { + struct Iterator { + __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {} + __device__ T operator*() const { return index_; } + __device__ Iterator& operator++() { + index_ += delta_; + return *this; + } + __device__ bool operator!=(const Iterator& other) const { + bool greater = index_ > other.index_; + bool less = index_ < other.index_; + // Anything past an end iterator (delta_ == 0) is equal. + // In range-based for loops, this optimizes to 'return less'. + if (!other.delta_) { + return less; + } + if (!delta_) { + return greater; + } + return less || greater; + } + + private: + T index_; + const T delta_; + }; + + public: + __device__ GpuGridRange(T begin, T delta, T end) + : begin_(begin), delta_(delta), end_(end) {} + + __device__ Iterator begin() const { return Iterator{begin_, delta_}; } + __device__ Iterator end() const { return Iterator{end_, 0}; } + + private: + T begin_; + T delta_; + T end_; +}; + +#ifndef TENSORFLOW_USE_ROCM +template +using CudaGridRange = GpuGridRange; +#endif +} // namespace detail + +// Helper to visit indices in the range 0 <= i < count, using the x-coordinate +// of the global thread index. That is, each index i is visited by all threads +// with the same x-coordinate. +// Usage: for(int i : GpuGridRangeX(count)) { visit(i); } +template +__device__ detail::GpuGridRange GpuGridRangeX(T count) { + return detail::GpuGridRange( + /*begin=*/blockIdx.x * blockDim.x + threadIdx.x, + /*delta=*/gridDim.x * blockDim.x, /*end=*/count); +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeX, CudaGridRangeX); + +// Helper to visit indices in the range 0 <= i < count using the y-coordinate. +// Usage: for(int i : GpuGridRangeY(count)) { visit(i); } +template +__device__ detail::GpuGridRange GpuGridRangeY(T count) { + return detail::GpuGridRange( + /*begin=*/blockIdx.y * blockDim.y + threadIdx.y, + /*delta=*/gridDim.y * blockDim.y, /*end=*/count); +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeY, CudaGridRangeY); + +// Helper to visit indices in the range 0 <= i < count using the z-coordinate. +// Usage: for(int i : GpuGridRangeZ(count)) { visit(i); } +template +__device__ detail::GpuGridRange GpuGridRangeZ(T count) { + return detail::GpuGridRange( + /*begin=*/blockIdx.z * blockDim.z + threadIdx.z, + /*delta=*/gridDim.z * blockDim.z, /*end=*/count); +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuGridRangeZ, CudaGridRangeZ); + +// Mask for all 32 threads in a warp. +__device__ const unsigned kCudaWarpAll = 0xffffffff; +// ROCM TODO add ROCM implementation +// Mask for all 64 threads in a wavefront. 
+__device__ const unsigned kGpuWarpAll = 0xffffffff; + +// Returns the warp lane ID of the calling thread +__device__ inline unsigned GpuLaneId() { + unsigned int lane_id; +#if GOOGLE_CUDA +#if __clang__ + return __nvvm_read_ptx_sreg_laneid(); +#else // __clang__ + asm("mov.u32 %0, %%laneid;" : "=r"(lane_id)); +#endif // __clang__ +#elif TENSORFLOW_USE_ROCM + lane_id = __lane_id(); +#endif + return lane_id; +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuLaneId, CudaLaneId); + +namespace detail { +// Returns true if mask is a valid parameter for __shfl*sync to return a well +// defined value, assuming the calling lane will read from src_lane as part of +// the shuffle operation. +// +// Specifically, returns true iff mask has the calling lane bit and the src_lane +// bit set, and the src_lane calls this function with the same mask value +// (required for the two threads to wait for each other). +// +// On Volta, for some invalid masks, this function hangs or returns false +// positives, because the implementation shuffles with the same mask that +// we are validating. Run on Pascal if you suspect that the mask is incorrect. +__device__ inline bool GpuValidateShuffleSyncMask(unsigned mask, + unsigned src_lane) { + unsigned src_dst_mask = 1u << GpuLaneId() | 1u << src_lane; +#if GOOGLE_CUDA + unsigned src_lane_mask = __shfl_sync(mask, mask, src_lane); +#else // TENSORFLOW_USE_ROCM + unsigned src_lane_mask = + __shfl(static_cast(mask), static_cast(src_lane)); +#endif + return (src_dst_mask & ~mask) == 0 && src_lane_mask == mask; +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuValidateShuffleSyncMask, + CudaValidateShuffleSyncMask); + +// Returns the actual source lane for shuffle. +__device__ inline unsigned GpuShuffleGetSrcLane(int src_lane, int width) { + int lane_id = GpuLaneId(); + int lane_base = lane_id & ~width + 1; + int lane_offset = src_lane & width - 1; + return lane_base + lane_offset; +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleGetSrcLane, CudaShuffleGetSrcLane); + +// Returns the source lane for shuffle up. +__device__ inline unsigned GpuShuffleUpGetSrcLane(unsigned delta, int width) { + unsigned lane_id = GpuLaneId(); + if ((lane_id & width - 1) < delta) { + return lane_id; + } + return lane_id - delta; +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleUpGetSrcLane, + CudaShuffleUpGetSrcLane); + +// Returns the source lane for shuffle down. +__device__ inline unsigned GpuShuffleDownGetSrcLane(unsigned delta, int width) { + unsigned lane_id = GpuLaneId(); + if ((lane_id & width - 1) + delta >= width) { + return lane_id; + } + return lane_id + delta; +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleDownGetSrcLane, + CudaShuffleDownGetSrcLane); + +// Returns the source lane for shuffle xor. +__device__ inline unsigned GpuShuffleXorGetSrcLane(int lane_mask, int width) { + int lane_id = GpuLaneId(); + int src_lane = lane_id ^ lane_mask; + if (src_lane > (lane_id | width - 1)) { + return lane_id; + } + return src_lane; +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleXorGetSrcLane, + CudaShuffleXorGetSrcLane); +} // namespace detail + +// For all *_sync wrappers below, it is illegal to synchronize threads from +// different program locations, because that is not supported before sm_70. +// In other words, all threads in 'mask' must call the functions in convergence. +// Code that requires sm_70 (and CUDA 9) may use the intrinsic directly. +// +// It is also illegal to shuffle with a mask that produces an undefined result +// for any of the threads. 
Specifically, all source threads of the shuffle +// must have their corresponding bit in 'mask' set. + +// Wrapper for __syncwarp. No-op for CUDA 8 and earlier. +__device__ inline void GpuSyncWarp(unsigned mask = kCudaWarpAll) { + assert(mask & 1u << GpuLaneId()); +#if GOOGLE_CUDA + __syncwarp(mask); +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuSyncWarp, CudaSyncWarp); + +// Wrapper for __ballot_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline unsigned GpuBallotSync(unsigned mask, int pred) { + assert(mask & 1u << GpuLaneId()); +#if GOOGLE_CUDA + return __ballot_sync(mask, pred); +#else // TENSORFLOW_USE_ROCM + return __ballot(pred) & mask; // Apply mask to match __ballot_sync's spec. +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuBallotSync, CudaBallotSync); + +// Wrapper for __any_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline int GpuAnySync(unsigned mask, int pred) { + assert(mask & 1u << GpuLaneId()); +#if GOOGLE_CUDA + return __any_sync(mask, pred); +#else // TENSORFLOW_USE_ROCM + return __any(pred); +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAnySync, CudaAnySync); + +// Wrapper for __all_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +__device__ inline int GpuAllSync(unsigned mask, int pred) { + assert(mask & 1u << GpuLaneId()); +#if GOOGLE_CUDA + return __all_sync(mask, pred); +#else // TENSORFLOW_USE_ROCM + return __all(pred); +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAllSync, CudaAllSync); + +// Wrapper for __shfl_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. +template +__device__ T GpuShuffleSync(unsigned mask, T value, int src_lane, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::GpuValidateShuffleSyncMask( + mask, detail::GpuShuffleGetSrcLane(src_lane, width))); +#if GOOGLE_CUDA + return __shfl_sync(mask, value, src_lane, width); +#else // TENSORFLOW_USE_ROCM + return __shfl(value, src_lane, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double GpuShuffleSync(unsigned mask, double value, + int src_lane, int width = warpSize) { +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = GpuShuffleSync(mask, hi, src_lane, width); + lo = GpuShuffleSync(mask, lo, src_lane, width); + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl(static_cast(hi), src_lane, width); + lo = __shfl(static_cast(lo), src_lane, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleSync, CudaShuffleSync); + +// Wrapper for __shfl_up_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. 
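// For example (a sketch, assuming a fully populated warp), an inclusive
// prefix sum across the warp:
//
//   float v = ...;                         // one value per lane
//   for (int d = 1; d < warpSize; d *= 2) {
//     float up = GpuShuffleUpSync(kCudaWarpAll, v, d);
//     if (GpuLaneId() >= d) v += up;
//   }
//   // Lane i now holds the sum of the values of lanes 0..i.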
+template +__device__ inline T GpuShuffleUpSync(unsigned mask, T value, unsigned delta, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::GpuValidateShuffleSyncMask( + mask, detail::GpuShuffleUpGetSrcLane(delta, width))); +#if GOOGLE_CUDA + return __shfl_up_sync(mask, value, delta, width); +#else // TENSORFLOW_USE_ROCM + return __shfl_up(value, delta, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double GpuShuffleUpSync(unsigned mask, double value, + unsigned delta, + int width = warpSize) { +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = GpuShuffleUpSync(mask, hi, delta, width); + lo = GpuShuffleUpSync(mask, lo, delta, width); + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl_up(static_cast(hi), delta, width); + lo = __shfl_up(static_cast(lo), delta, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleUpSync, CudaShuffleUpSync); + +// Wrapper for __shfl_down_sync. All threads in 'mask' must call this function +// in convergence, see comment above for details. +template +__device__ inline T GpuShuffleDownSync(unsigned mask, T value, unsigned delta, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::GpuValidateShuffleSyncMask( + mask, detail::GpuShuffleDownGetSrcLane(delta, width))); +#if GOOGLE_CUDA + return __shfl_down_sync(mask, value, delta, width); +#else // TENSORFLOW_USE_ROCM + return __shfl_down(value, delta, width); +#endif +} + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double GpuShuffleDownSync(unsigned mask, double value, + unsigned delta, + int width = warpSize) { +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = GpuShuffleDownSync(mask, hi, delta, width); + lo = GpuShuffleDownSync(mask, lo, delta, width); + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl_down(static_cast(hi), delta, width); + lo = __shfl_down(static_cast(lo), delta, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleDownSync, CudaShuffleDownSync); + +// Wrapper for __shfl_xor_sync. All threads in 'mask' must call this function in +// convergence, see comment above for details. 
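// For example (a sketch, assuming a fully populated warp), a butterfly sum
// reduction that leaves the warp-wide total in every lane:
//
//   float v = ...;                         // one value per lane
//   for (int m = warpSize / 2; m > 0; m /= 2) {
//     v += GpuShuffleXorSync(kCudaWarpAll, v, m);
//   }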
+template +__device__ T GpuShuffleXorSync(unsigned mask, T value, int lane_mask, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::GpuValidateShuffleSyncMask( + mask, detail::GpuShuffleXorGetSrcLane(lane_mask, width))); +#if GOOGLE_CUDA + return __shfl_xor_sync(mask, value, lane_mask, width); +#elif TENSORFLOW_USE_ROCM + // ROCM TODO: check if HIP should be changed to cope with more types + return __shfl_xor(static_cast(value), lane_mask, width); +#endif +} + +#if TENSORFLOW_USE_ROCM +__device__ inline Eigen::half GpuShuffleXorSync(unsigned mask, + Eigen::half value, + int lane_mask, + int width = warpSize) { + assert(!(width & width - 1)); + assert(detail::GpuValidateShuffleSyncMask( + mask, detail::GpuShuffleXorGetSrcLane(lane_mask, width))); + // TODO(rocm): This doesn't preserve NaN payload and flushes denorms to zero, + // maybe this should be implemented differently? + return static_cast( + __shfl_xor(static_cast(value), lane_mask, width)); +} +#endif + +// Variant of the (undocumented) version from the CUDA SDK, but using unsigned +// instead of float for lo and hi (which is incorrect with ftz, for example). +// See b/69446944. +__device__ inline double GpuShuffleXorSync(unsigned mask, double value, + int lane_mask, + int width = warpSize) { +#if GOOGLE_CUDA + auto tmp = __double_as_longlong(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = GpuShuffleXorSync(mask, hi, lane_mask, width); + lo = GpuShuffleXorSync(mask, lo, lane_mask, width); + return __longlong_as_double(static_cast(hi) << 32 | lo); +#elif TENSORFLOW_USE_ROCM + auto tmp = static_cast(value); + auto lo = static_cast(tmp); + auto hi = static_cast(tmp >> 32); + hi = __shfl_xor(static_cast(hi), lane_mask, width); + lo = __shfl_xor(static_cast(lo), lane_mask, width); + return static_cast(static_cast(hi) << 32 | + static_cast(lo)); +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuShuffleXorSync, CudaShuffleXorSync); + +// Wrapper for __ldg. +template +__host__ __device__ T GpuLdg(const T* address) { +#if __CUDA_ARCH__ >= 350 + return __ldg(address); +#else + return *address; +#endif +} + +__host__ __device__ inline bool GpuLdg(const bool* address) { + return GpuLdg(reinterpret_cast(address)) != 0; +} + +__host__ __device__ inline std::complex GpuLdg( + const std::complex* address) { +#if __CUDA_ARCH__ >= 350 + float2 mem = __ldg(reinterpret_cast(address)); + return std::complex(mem.x, mem.y); +#else + return *address; +#endif +} + +__host__ __device__ inline std::complex GpuLdg( + const std::complex* address) { +#if __CUDA_ARCH__ >= 350 + double2 mem = __ldg(reinterpret_cast(address)); + return std::complex(mem.x, mem.y); +#else + return *address; +#endif +} +CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuLdg, CudaLdg); + +// Zeroes count elements starting at ptr using all threads of a 1-D grid. +// Note: this function does not synchronize, and therefore the memory range is +// not guaranteed to be zero until the next kernel launch. +template +__global__ void SetZero(const int count, T* __restrict__ ptr) { + // Check that the grid is one dimensional and index doesn't overflow. + assert(blockDim.y == 1); + assert(blockDim.z == 1); + assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x); + for (int i : GpuGridRangeX(count)) { + ptr[i] = T(0); + } +} + +// Helper to set all tensor entries to a specific value. +template +__global__ void SetToValue(const int count, T* __restrict__ ptr, Tvalue value) { + // Check that the grid is one dimensional and index doesn't overflow. 
+  assert(blockDim.y == 1);
+  assert(blockDim.z == 1);
+  assert(blockDim.x * gridDim.x / blockDim.x == gridDim.x);
+  for (int i : GpuGridRangeX(count)) {
+    ptr[i] = static_cast<T>(value);
+  }
+}
+
+namespace detail {
+// Helper function for atomic accumulation implemented as CAS.
+template <typename T, typename F>
+__device__ T GpuAtomicCasHelper(T* ptr, F accumulate) {
+  T old = *ptr;
+  T assumed;
+  do {
+    assumed = old;
+    old = atomicCAS(ptr, assumed, accumulate(assumed));
+  } while (assumed != old);
+  return old;
+}
+CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicCasHelper, CudaAtomicCasHelper);
+
+// Overload for floating point (using integer comparison to handle NaN
+// correctly).
+template <typename F>
+__device__ float GpuAtomicCasHelper(float* ptr, F accumulate) {
+  return __int_as_float(
+      GpuAtomicCasHelper(reinterpret_cast<int32*>(ptr), [accumulate](int32 a) {
+        return __float_as_int(accumulate(__int_as_float(a)));
+      }));
+}
+template <typename F>
+__device__ double GpuAtomicCasHelper(double* ptr, F accumulate) {
+#if TENSORFLOW_USE_ROCM
+  // FIXME: remove the workaround below once bug is fixed.
+  // HIP has a bug in the implementation of __longlong_as_double,
+  // so work around it by using reinterpret_cast instead.
+  uint64_t result =
+      GpuAtomicCasHelper(reinterpret_cast<tensorflow::uint64*>(ptr),
+                         [accumulate](tensorflow::uint64 a) {
+                           return __double_as_longlong(
+                               accumulate(*(reinterpret_cast<double*>(&a))));
+                         });
+  return *(reinterpret_cast<double*>(&result));
+#else
+  return __longlong_as_double(GpuAtomicCasHelper(
+      reinterpret_cast<tensorflow::uint64*>(ptr),
+      [accumulate](tensorflow::uint64 a) {
+        return __double_as_longlong(accumulate(__longlong_as_double(a)));
+      }));
+#endif
+}
+
+// Overload of above function for half. Note that we don't have
+// atomicCAS() for anything less than 32 bits, so we need to include the
+// other 16 bits in the operation.
+//
+// This version is going to be very slow under high concurrency, since most
+// threads will be spinning on failing their compare-and-swap tests. (The fact
+// that we get false sharing on the neighboring fp16 makes this even worse.)
+// If you are doing a large reduction, you are much better off doing the
+// intermediate steps in fp32 and then switching to fp16 as late as you can in
+// the calculations.
+//
+// Note: Assumes little endian.
+template <typename F>
+__device__ Eigen::half GpuAtomicCasHelper(Eigen::half* ptr, F accumulate) {
+#if defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__)
+  static_assert(__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__, "Not little endian");
+#endif
+  intptr_t intptr = reinterpret_cast<intptr_t>(ptr);
+  assert(!(intptr & 0x1));  // should be 2-aligned.
+  if (intptr & 0x2) {
+    // The half is in the second part of the uint32 (upper 16 bits).
+    uint32* address = reinterpret_cast<uint32*>(intptr - 2);
+    uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
+      unsigned short high = static_cast<unsigned short>(arg >> 16);
+      Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(high));
+      return (static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc)) << 16) |
+             (arg & 0xffff);
+    });
+    return Eigen::numext::bit_cast<Eigen::half>(
+        static_cast<uint16>(result >> 16));
+  } else {
+    // The half is in the first part of the uint32 (lower 16 bits).
+    uint32* address = reinterpret_cast<uint32*>(intptr);
+    uint32 result = GpuAtomicCasHelper(address, [accumulate](uint32 arg) {
+      unsigned short low = static_cast<unsigned short>(arg & 0xffff);
+      Eigen::half acc = accumulate(Eigen::numext::bit_cast<Eigen::half>(low));
+      return (arg & 0xffff0000) |
+             static_cast<uint32>(Eigen::numext::bit_cast<uint16>(acc));
+    });
+    return Eigen::numext::bit_cast<Eigen::half>(
+        static_cast<uint16>(result & 0xffff));
+  }
+}
+
+template <typename F>
+__device__ Eigen::bfloat16 GpuAtomicCasHelper(Eigen::bfloat16* ptr,
+                                              F accumulate) {
+  Eigen::half ret = detail::GpuAtomicCasHelper(
+      reinterpret_cast<Eigen::half*>(ptr), [accumulate](Eigen::half a) {
+        Eigen::bfloat16 acc =
+            accumulate(Eigen::numext::bit_cast<Eigen::bfloat16>(a));
+        return Eigen::numext::bit_cast<Eigen::half>(acc);
+      });
+  return Eigen::numext::bit_cast<Eigen::bfloat16>(ret);
+}
+
+template <typename F>
+__device__ long long GpuAtomicCasHelper(long long* ptr, F accumulate) {
+  return static_cast<long long>(
+      GpuAtomicCasHelper(reinterpret_cast<unsigned long long*>(ptr),
+                         [accumulate](unsigned long long a) {
+                           return static_cast<unsigned long long>(
+                               accumulate(static_cast<long long>(a)));
+                         }));
+}
+
+template <typename From, typename To>
+using ToTypeIfConvertible =
+    typename std::enable_if<std::is_convertible<From, To>::value, To>::type;
+
+template <typename T>
+struct CudaSupportedTypeImpl {
+  using type = T;
+};
+
+template <>
+struct CudaSupportedTypeImpl<long long> {
+  using type = unsigned long long;
+};
+
+template <>
+struct CudaSupportedTypeImpl<unsigned long> {
+  using type =
+      typename std::conditional<sizeof(unsigned long) == sizeof(unsigned int),
+                                unsigned int, unsigned long long>::type;
+};
+
+template <>
+struct CudaSupportedTypeImpl<long> {
+  // This cast should be safe since modular (wrap-around) addition works fine.
+  // However, signed overflow is not handled correctly since it's undefined
+  // behavior.
+  using type = typename CudaSupportedTypeImpl<unsigned long>::type;
+};
+
+template <typename T>
+using CudaSupportedType = typename CudaSupportedTypeImpl<T>::type;
+
+template <typename T>
+__device__ CudaSupportedType<T>* ToCudaSupportedPtr(T* ptr) {
+  return reinterpret_cast<CudaSupportedType<T>*>(ptr);
+}
+
+}  // namespace detail
+
+// CUDA provides atomic ops, but not for all types. We provide wrappers
+// for some ops and provide implementation for all reasonable types.
+
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicAdd(T* ptr, U value) {
+  return atomicAdd(detail::ToCudaSupportedPtr(ptr), value);
+}
+
+__device__ inline Eigen::half GpuAtomicAdd(Eigen::half* ptr,
+                                           Eigen::half value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::half a) { return a + value; });
+}
+
+__device__ inline Eigen::bfloat16 GpuAtomicAdd(Eigen::bfloat16* ptr,
+                                               Eigen::bfloat16 value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::bfloat16 a) { return a + value; });
+}
+
+#if (__CUDA_ARCH__ < 600) || TENSORFLOW_USE_ROCM
+__device__ inline double GpuAtomicAdd(double* ptr, double value) {
+  return detail::GpuAtomicCasHelper(ptr,
+                                    [value](double a) { return a + value; });
+}
+#endif
+
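+// Illustrative note (added in review, not part of the upstream TensorFlow
+// header): the fp16/bf16 overloads above fall back to the CAS loop in
+// detail::GpuAtomicCasHelper, so large reductions are usually better off
+// accumulating per-thread partial sums in float and converting once at the
+// end. A minimal sketch with a hypothetical SumToHalf kernel:
+//
+//   __global__ void SumToHalf(const float* in, Eigen::half* out, int n) {
+//     float partial = 0.f;
+//     for (int i : GpuGridRangeX(n)) partial += in[i];
+//     GpuAtomicAdd(out, static_cast<Eigen::half>(partial));
+//   }
+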
+// GpuAtomicAdd
+// Specializations of GpuAtomicAdd for complex types, which GpuAtomicAdd does
+// not support. We treat a std::complex<T>* as a T* (the C++ standard section
+// 26.4.4 allows this explicitly) and atomic add the real and imaginary
+// components individually. The operation as a whole is not atomic, but we can
+// safely treat the components independently for the purpose of accumulating.
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+__device__ inline std::complex<float> GpuAtomicAdd(std::complex<float>* ptr,
+                                                   std::complex<float> value) {
+  auto ptr_scalar = reinterpret_cast<float*>(ptr);
+  return std::complex<float>(GpuAtomicAdd(ptr_scalar, value.real()),
+                             GpuAtomicAdd(ptr_scalar + 1, value.imag()));
+}
+
+__device__ inline std::complex<double> GpuAtomicAdd(
+    std::complex<double>* ptr, std::complex<double> value) {
+  auto ptr_scalar = reinterpret_cast<double*>(ptr);
+  return std::complex<double>(GpuAtomicAdd(ptr_scalar, value.real()),
+                              GpuAtomicAdd(ptr_scalar + 1, value.imag()));
+}
+#endif
+CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicAdd, CudaAtomicAdd);
+
+// GpuAtomicSub
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicSub(T* ptr, U value) {
+  return atomicSub(ptr, value);
+}
+
+// Specializations of subtraction which add the negative value.
+__device__ inline float GpuAtomicSub(float* ptr, float value) {
+  return GpuAtomicAdd(ptr, -value);
+}
+
+__device__ inline double GpuAtomicSub(double* ptr, double value) {
+  return GpuAtomicAdd(ptr, -value);
+}
+
+__device__ inline int64_t GpuAtomicSub(int64_t* ptr, int64_t value) {
+  return GpuAtomicAdd(ptr, -value);
+}
+
+__device__ inline tensorflow::uint64 GpuAtomicSub(tensorflow::uint64* ptr,
+                                                  tensorflow::uint64 value) {
+  return GpuAtomicAdd(ptr, -static_cast<int64_t>(value));
+}
+
+__device__ inline Eigen::half GpuAtomicSub(Eigen::half* ptr,
+                                           Eigen::half value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::half a) { return a - value; });
+}
+
+__device__ inline Eigen::bfloat16 GpuAtomicSub(Eigen::bfloat16* ptr,
+                                               Eigen::bfloat16 value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::bfloat16 a) { return a - value; });
+}
+
+CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicSub, CudaAtomicSub);
+
+// GpuAtomicMax
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMax(T* ptr, U value) {
+  return atomicMax(detail::ToCudaSupportedPtr(ptr), value);
+}
+
+#if TENSORFLOW_USE_ROCM
+
+/*
+ * CUDA runtime headers have the following defined
+ *   __device__ int max(int, int)
+ *   __device__ float max(float, float)
+ *   __device__ double max(double, double)
+ *
+ * and many others, whereas HIP runtime headers only have the "int" version.
+ *
+ * Therefore we need to special-case the ROCm version to call the correct
+ * underlying routines for float and double types.
+ *
+ */
+
+__device__ inline float GpuAtomicMax(float* ptr, float value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](float a) { return fmaxf(a, value); });
+}
+
+__device__ inline double GpuAtomicMax(double* ptr, double value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](double a) { return fmax(a, value); });
+}
+
+#else
+
+__device__ inline float GpuAtomicMax(float* ptr, float value) {
+  return detail::GpuAtomicCasHelper(ptr,
+                                    [value](float a) { return max(a, value); });
+}
+
+__device__ inline double GpuAtomicMax(double* ptr, double value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](double a) { return max(a, value); });
+}
+
+#endif
+
+__device__ inline Eigen::half GpuAtomicMax(Eigen::half* ptr,
+                                           Eigen::half value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::half a) { return max(a, value); });
+}
+
+__device__ inline Eigen::bfloat16 GpuAtomicMax(Eigen::bfloat16* ptr,
+                                               Eigen::bfloat16 value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::bfloat16 a) { return max(a, value); });
+}
+
+#if TENSORFLOW_USE_ROCM || (__CUDA_ARCH__ < 320)
+__device__ inline tensorflow::uint64 GpuAtomicMax(tensorflow::uint64* ptr,
+                                                  tensorflow::uint64 value) {
+  return detail::GpuAtomicCasHelper(
+      detail::ToCudaSupportedPtr(ptr),
+      [value](tensorflow::uint64 a) { return max(a, value); });
+}
+
+__device__ inline int64_t GpuAtomicMax(int64_t* ptr, int64_t value) {
+  return detail::GpuAtomicCasHelper(
+      detail::ToCudaSupportedPtr(ptr),
+      [value](int64_t a) { return max(a, value); });
+}
+#endif
+CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMax, CudaAtomicMax);
+
+// GpuAtomicMin
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMin(T* ptr, U value) {
+  return atomicMin(detail::ToCudaSupportedPtr(ptr), value);
+}
+
+#if TENSORFLOW_USE_ROCM
+
+/*
+ * CUDA runtime headers have the following defined
+ *   __device__ int min(int, int)
+ *   __device__ float min(float, float)
+ *   __device__ double min(double, double)
+ *
+ * and many others, whereas HIP runtime headers only have the "int" version.
+ *
+ * Therefore we need to special-case the ROCm version to call the correct
+ * underlying routines for float and double types.
+ *
+ */
+
+__device__ inline float GpuAtomicMin(float* ptr, float value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](float a) { return fminf(a, value); });
+}
+
+__device__ inline double GpuAtomicMin(double* ptr, double value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](double a) { return fmin(a, value); });
+}
+
+#else
+
+__device__ inline float GpuAtomicMin(float* ptr, float value) {
+  return detail::GpuAtomicCasHelper(ptr,
+                                    [value](float a) { return min(a, value); });
+}
+
+__device__ inline double GpuAtomicMin(double* ptr, double value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](double a) { return min(a, value); });
+}
+
+#endif
+
+__device__ inline Eigen::half GpuAtomicMin(Eigen::half* ptr,
+                                           Eigen::half value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::half a) { return min(a, value); });
+}
+
+__device__ inline Eigen::bfloat16 GpuAtomicMin(Eigen::bfloat16* ptr,
+                                               Eigen::bfloat16 value) {
+  return detail::GpuAtomicCasHelper(
+      ptr, [value](Eigen::bfloat16 a) { return min(a, value); });
+}
+
+#if TENSORFLOW_USE_ROCM || (__CUDA_ARCH__ < 320)
+__device__ inline tensorflow::uint64 GpuAtomicMin(tensorflow::uint64* ptr,
+                                                  tensorflow::uint64 value) {
+  return detail::GpuAtomicCasHelper(
+      detail::ToCudaSupportedPtr(ptr),
+      [value](tensorflow::uint64 a) { return min(a, value); });
+}
+
+__device__ inline int64_t GpuAtomicMin(int64_t* ptr, int64_t value) {
+  return detail::GpuAtomicCasHelper(
+      detail::ToCudaSupportedPtr(ptr),
+      [value](int64_t a) { return min(a, value); });
+}
+#endif
+CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMin, CudaAtomicMin);
+
+// GpuAtomicMul
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMul(T* ptr, U value) {
+  return detail::GpuAtomicCasHelper(ptr, [value](T a) { return a * value; });
+}
+CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicMul, CudaAtomicMul);
+
+// GpuAtomicDiv
+template <typename T, typename U>
+__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicDiv(T* ptr, U value) {
+  return detail::GpuAtomicCasHelper(ptr, [value](T a) { return a / value; });
+}
+CREATE_CUDA_DEVICE_FUNCTION_ALIAS(GpuAtomicDiv, CudaAtomicDiv);
+
+// Import all specialized std::complex device operators in namespace tensorflow.
+#if GOOGLE_CUDA && defined(EIGEN_USING_STD_COMPLEX_OPERATORS)
+EIGEN_USING_STD_COMPLEX_OPERATORS
+#endif  // GOOGLE_CUDA
+
+namespace functor {
+// Import all specialized std::complex device operators in namespace functor.
+#if GOOGLE_CUDA && defined(EIGEN_USING_STD_COMPLEX_OPERATORS)
+EIGEN_USING_STD_COMPLEX_OPERATORS
+#endif  // GOOGLE_CUDA
+
+// ROCm hcc (clang) has severe difficulties dealing with std::complex directly
+// due to a header issue. This template assists in casting std::complex into
+// the corresponding internal ROCm types.
+template <typename T>
+struct MapComplexToHipComplex {
+  typedef T TM;
+};
+
+#if TENSORFLOW_USE_ROCM
+template <>
+struct MapComplexToHipComplex<std::complex<float> > {
+  typedef hipFloatComplex TM;
+};
+
+template <>
+struct MapComplexToHipComplex<std::complex<double> > {
+  typedef hipDoubleComplex TM;
+};
+#endif
+};  // namespace functor
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#endif  // TENSORFLOW_CORE_UTIL_GPU_DEVICE_FUNCTIONS_H_
diff --git a/third_party/tflite-hdrs/tensorflow/core/util/gpu_kernel_helper.h b/third_party/tflite-hdrs/tensorflow/core/util/gpu_kernel_helper.h
new file mode 100644
index 00000000..ae9894cc
--- /dev/null
+++ b/third_party/tflite-hdrs/tensorflow/core/util/gpu_kernel_helper.h
@@ -0,0 +1,524 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_
+#define TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#include <type_traits>
+
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda_fp16.h"
+#endif
+#include "tensorflow/core/util/gpu_cuda_alias.h"
+#include "tensorflow/core/util/gpu_device_functions.h"
+#include "tensorflow/core/util/gpu_launch_config.h"
+
+#if GOOGLE_CUDA
+#define TF_RED_WARPSIZE 32
+#elif TENSORFLOW_USE_ROCM
+#define TF_RED_WARPSIZE 64
+#endif
+
+// Deprecated, use 'for(int i : GpuGridRangeX(n))' instead.
+#define GPU_1D_KERNEL_LOOP(i, n) \
+  for (int i : ::tensorflow::GpuGridRangeX(n))
+#define CUDA_1D_KERNEL_LOOP(i, n) \
+  for (int i : ::tensorflow::GpuGridRangeX(n))
+
+// Deprecated, use 'for(int i : GpuGridRange?(n))' instead.
+#define GPU_AXIS_KERNEL_LOOP(i, n, axis) \
+  for (int i : ::tensorflow::GpuGridRange##axis(n))
+#define CUDA_AXIS_KERNEL_LOOP(i, n, axis) \
+  for (int i : ::tensorflow::GpuGridRange##axis(n))
+
+#if GOOGLE_CUDA
+#define gpuSuccess cudaSuccess
+using gpuStream_t = cudaStream_t;
+using gpuError_t = cudaError_t;
+#elif TENSORFLOW_USE_ROCM
+#define gpuSuccess hipSuccess
+using gpuStream_t = hipStream_t;
+using gpuError_t = hipError_t;
+#endif
+
+// macro wrapper to declare dynamic shared memory
+#if GOOGLE_CUDA
+
+#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
+  extern __shared__ __align__(ALIGN)                   \
+      TYPE NAME[]
+
+#elif TENSORFLOW_USE_ROCM
+
+#define GPU_DYNAMIC_SHARED_MEM_DECL(ALIGN, TYPE, NAME) \
+  HIP_DYNAMIC_SHARED(TYPE, NAME)
+
+#endif
+
+namespace tensorflow {
+
+#if GOOGLE_CUDA
+// cudaGetErrorString is available to both host and device
+__host__ __device__ inline const char* GpuGetErrorString(cudaError_t error) {
+  return cudaGetErrorString(error);
+}
+#elif TENSORFLOW_USE_ROCM
+// hipGetErrorString is available on host side only
+inline const char* GpuGetErrorString(hipError_t error) {
+  return hipGetErrorString(error);
+}
+#endif
+
+// Returns a raw reference to the current cuda stream. Required by a
+// number of kernel calls (for which StreamInterface* does not work),
+// e.g. CUB and certain cuBLAS primitives.
+inline gpuStream_t GetGpuStream(OpKernelContext* context) {
+  void* opaque_stream = CHECK_NOTNULL(context->op_device_context()
+                                          ->stream()
+                                          ->platform_specific_handle()
+                                          .stream);
+  return reinterpret_cast<gpuStream_t>(opaque_stream);
+}
+
+// Launches a GPU kernel through cudaLaunchKernel in CUDA environment, or
+// hipLaunchKernel in ROCm environment with the given arguments.
+//
+// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
+template <typename... Ts, typename... Args>
+Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
+                       size_t shared_memory_size_bytes, gpuStream_t stream,
+                       Args... arguments) {
+  static_assert(detail::NoneIsReference<Ts...>(),
+                "Kernels with reference arguments have undefined behaviour.");
+  if (grid_dim.x * grid_dim.y * grid_dim.z > 0 &&
+      block_dim.x * block_dim.y * block_dim.z > 0) {
+#if GOOGLE_CUDA
+    auto func_ptr = absl::bit_cast<const void*>(function);
+    // Cast arguments and forward them as an array of pointers.
+    auto args_tuple = std::tuple<Ts...>(arguments...);
+    auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
+    auto result =
+        cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(),
+                         shared_memory_size_bytes, stream);
+    if (result != cudaSuccess) {
+      return errors::Internal(cudaGetErrorString(result));
+    }
+#elif TENSORFLOW_USE_ROCM
+    hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
+                       stream, std::forward<Args>(arguments)...);
+    TF_RETURN_IF_CUDA_ERROR(hipGetLastError());
+#endif
+  }
+  return OkStatus();
+}
+
+// Perfect forwarding to make CudaLaunchKernel available to both ROCm and CUDA
+// builds
+template <typename... Args>
+auto CudaLaunchKernel(Args&&... args)
+    -> decltype(GpuLaunchKernel(std::forward<Args>(args)...)) {
+  return GpuLaunchKernel(std::forward<Args>(args)...);
+}
+
+__host__ __device__ inline tensorflow::bfloat16 GpuLdg(
+    const tensorflow::bfloat16* address) {
+  return Eigen::numext::bit_cast<tensorflow::bfloat16>(
+      GpuLdg(reinterpret_cast<const uint16_t*>(address)));
+}
+// Already aliased in gpu_device_functions.h
+
+template <typename T>
+__host__ __device__ inline T ldg(const T* ptr) {
+  return GpuLdg(ptr);
+}
+
+template <typename T>
+__host__ __device__ inline const T& tf_min(const T& x, const T& y) {
+  return x < y ? x : y;
+}
+
+template <typename T>
+__host__ __device__ inline const T& tf_max(const T& x, const T& y) {
+  return x < y ? y : x;
+}
+
+// Overloads of the above functions for float and double.
+__host__ __device__ inline float tf_min(float x, float y) {
+  return fminf(x, y);
+}
+__host__ __device__ inline double tf_min(double x, double y) {
+  return fmin(x, y);
+}
+__host__ __device__ inline float tf_max(float x, float y) {
+  return fmaxf(x, y);
+}
+__host__ __device__ inline double tf_max(double x, double y) {
+  return fmax(x, y);
+}
+
+#ifdef _MSC_VER
+#if _MSC_VER >= 1930
+using std::max;
+using std::min;
+__host__ __device__ inline int tf_min(int x, int y) { return min(x, y); }
+__host__ __device__ inline int tf_max(int x, int y) { return max(x, y); }
+#endif
+#endif
+
+// ROCM TODO re-enable them after adding fp16 support logic
+#if GOOGLE_CUDA
+__device__ inline Eigen::half GpuShuffleSync(unsigned mask, Eigen::half value,
+                                             int src_lane,
+                                             int width = warpSize) {
+  return Eigen::half(
+      GpuShuffleSync(mask, static_cast<uint16>(value), src_lane, width));
+}
+// Aliased in gpu_device_functions.h
+
+__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleUpSync(
+    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+  return Eigen::half(
+      GpuShuffleUpSync(mask, static_cast<uint16>(value), delta, width));
+}
+// Aliased in gpu_device_functions.h
+
+__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleDownSync(
+    unsigned mask, Eigen::half value, int delta, int width = warpSize) {
+  return Eigen::half(
+      GpuShuffleDownSync(mask, static_cast<uint16>(value), delta, width));
+}
+// Aliased in gpu_device_functions.h
+
+__device__ EIGEN_ALWAYS_INLINE Eigen::half GpuShuffleXorSync(
+    unsigned mask, Eigen::half value, int lane_mask, int width = warpSize) {
+  return Eigen::half(
+      GpuShuffleXorSync(mask, static_cast<uint16>(value), lane_mask, width));
+}
+// Aliased in gpu_device_functions.h
+#endif
+
+#ifdef __CUDA_ARCH__
+#define UNROLL_ON_DEVICE _Pragma("unroll")
+#else
+#define UNROLL_ON_DEVICE
+#endif
+
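+// Illustrative note (added in review, not part of the upstream TensorFlow
+// header): the AlignedVector class below is typically used by reinterpreting
+// suitably aligned pointers so the compiler emits vectorized loads/stores.
+// A sketch with hypothetical 'in', 'out', and 'n' (n a multiple of 4, both
+// pointers 16-byte aligned), inside a __global__ kernel:
+//
+//   using Vec4 = AlignedVector<float, 4>;
+//   const Vec4* in4 = reinterpret_cast<const Vec4*>(in);
+//   Vec4* out4 = reinterpret_cast<Vec4*>(out);
+//   for (int i : GpuGridRangeX(n / 4)) out4[i] = in4[i];
+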
+// Represents an aligned array of N elements of T. Data pointers can be
+// reinterpreted as this type to generate vectorized loads/stores in a kernel.
+template <typename T, int N>
+class alignas(alignof(T) * N) AlignedVector {
+ public:
+  typedef T value_type;
+  static constexpr const int kSize = N;
+
+  AlignedVector() = default;
+
+  // Uniform initialization.
+  __host__ __device__ explicit AlignedVector(value_type uniform) {
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
+  }
+  // Uniform initialization with explicit conversion.
+  // Note: This is required for T=Eigen::half because it only supports explicit
+  // conversions from other types and its template constructor is too relaxed
+  // to be able to use std::is_constructible.
+  template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
+                                                int>::type = 0>
+  __host__ __device__ explicit AlignedVector(U uniform_u) {
+    value_type uniform(uniform_u);
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = uniform; }
+  }
+  // Implicit conversion.
+  template <typename U, typename std::enable_if<
+                            std::is_convertible<U, T>::value, int>::type = 0>
+  __host__ __device__ AlignedVector(const AlignedVector<U, N>& other) {
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] = other[i]; }
+  }
+  // Explicit conversion.
+  template <typename U, typename std::enable_if<
+                            !std::is_convertible<U, T>::value &&
+                                std::is_constructible<T, U>::value,
+                            int>::type = 0>
+  __host__ __device__ explicit AlignedVector(const AlignedVector<U, N>& other) {
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {
+      values_[i] = T(other[i]);
+    }
+  }
+
+  __host__ __device__ value_type& operator[](int i) { return values_[i]; }
+  __host__ __device__ const value_type& operator[](int i) const {
+    return values_[i];
+  }
+
+#define DEFINE_BINARY_UPDATE_OPERATOR(op)                                      \
+  __host__ __device__ AlignedVector& operator op(const AlignedVector& rhs) {   \
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) { values_[i] op rhs[i]; } \
+    return *this;                                                              \
+  }
+  DEFINE_BINARY_UPDATE_OPERATOR(+=)
+  DEFINE_BINARY_UPDATE_OPERATOR(-=)
+  DEFINE_BINARY_UPDATE_OPERATOR(*=)
+  DEFINE_BINARY_UPDATE_OPERATOR(/=)
+#undef DEFINE_BINARY_UPDATE_OPERATOR
+
+#define DEFINE_BINARY_OPERATOR(op)                          \
+  friend __host__ __device__ AlignedVector operator op(     \
+      const AlignedVector& lhs, const AlignedVector& rhs) { \
+    AlignedVector ret;                                      \
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {      \
+      ret[i] = lhs[i] op rhs[i];                            \
+    }                                                       \
+    return ret;                                             \
+  }
+  DEFINE_BINARY_OPERATOR(+)
+  DEFINE_BINARY_OPERATOR(-)
+  DEFINE_BINARY_OPERATOR(*)
+  DEFINE_BINARY_OPERATOR(/)
+#undef DEFINE_BINARY_OPERATOR
+
+#define DEFINE_BINARY_FUNCTION(func)                                        \
+  friend __host__ __device__ AlignedVector func(const AlignedVector& lhs,   \
+                                                const AlignedVector& rhs) { \
+    AlignedVector ret;                                                      \
+    UNROLL_ON_DEVICE for (int i = 0; i < kSize; ++i) {                      \
+      ret[i] = func(lhs[i], rhs[i]);                                        \
+    }                                                                       \
+    return ret;                                                             \
+  }
+  DEFINE_BINARY_FUNCTION(min)
+  DEFINE_BINARY_FUNCTION(max)
+#undef DEFINE_BINARY_FUNCTION
+
+ private:
+  value_type values_[N];
+};
+
+#undef UNROLL_ON_DEVICE
+
+// Returns the maximum power-of-two alignment (in units of elements, not bytes)
+// of a stride or pointer value.
+inline int64_t alignment_of(int64_t element_stride) {
+  // A zero/nullptr value means that the stride/pointer is not used, so it
+  // effectively has infinite alignment.
+  constexpr int64_t kMaxAlignment = 512;
+  if (element_stride == 0) return kMaxAlignment;
+  return element_stride & -element_stride;
+}
+
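+// Illustrative note (added in review, not part of the upstream TensorFlow
+// header): the alignment helpers here (alignment_of above, MinAlignmentOf
+// below) are typically combined to choose the widest safe vector width,
+// e.g. with hypothetical 'in_ptr' and 'out_ptr':
+//
+//   int64_t vec_size = std::min<int64_t>(4, MinAlignmentOf(in_ptr, out_ptr));
+//   // vec_size is now 1, 2 or 4 depending on the pointer alignments.
+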
+template <typename T>
+inline int64_t alignment_of(T* ptr) {
+  const intptr_t ptr_val = reinterpret_cast<intptr_t>(ptr);
+  // Pointers should always be aligned to sizeof(T) bytes.
+  DCHECK_EQ(ptr_val % sizeof(T), 0);
+  // Note that we want the alignment in elements, not bytes.
+  return alignment_of(ptr_val / sizeof(T));
+}
+
+template <typename... Args>
+int64_t MinAlignmentOf(Args... args) {
+  return std::min({alignment_of(args)...});
+}
+
+namespace detail {
+
+template <int VecSize, template <int vec_size> class Functor>
+struct DispatchToVectorizedHelper {
+  template <typename... Args>
+  Status operator()(int64_t max_vec_size, Args&&... args) const {
+    if (max_vec_size >= VecSize) {
+      return Functor<VecSize>()(std::forward<Args>(args)...);
+    }
+    return DispatchToVectorizedHelper<VecSize / 2, Functor>()(
+        max_vec_size, std::forward<Args>(args)...);
+  }
+};
+template